├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── examples
│   ├── dual_ema_on_apple.py
│   └── smart_beta.py
├── requirements.txt
├── setup.py
├── spectre
│   ├── __init__.py
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── arrow.py
│   │   ├── csv.py
│   │   ├── dataloader.py
│   │   ├── iex.py
│   │   ├── memory.py
│   │   ├── quandl.py
│   │   └── yahoo.py
│   ├── factors
│   │   ├── __init__.py
│   │   ├── basic.py
│   │   ├── datafactor.py
│   │   ├── engine.py
│   │   ├── factor.py
│   │   ├── feature.py
│   │   ├── filter.py
│   │   ├── label.py
│   │   ├── multiprocessing.py
│   │   ├── statistical.py
│   │   └── technical.py
│   ├── parallel
│   │   ├── __init__.py
│   │   ├── algorithmic.py
│   │   └── constants.py
│   ├── plotting
│   │   ├── __init__.py
│   │   ├── chart.py
│   │   ├── factor_diagram.py
│   │   └── returns_chart.py
│   └── trading
│       ├── __init__.py
│       ├── algorithm.py
│       ├── blotter.py
│       ├── calendar.py
│       ├── event.py
│       ├── metric.py
│       ├── portfolio.py
│       ├── position.py
│       └── stopmodel.py
└── tests
    ├── __init__.py
    ├── benchmarks_spectre.ipynb
    ├── benchmarks_zipline.ipynb
    ├── data
    │   ├── 5mins
    │   │   ├── AAPL_2018.csv
    │   │   ├── AAPL_2019.csv
    │   │   ├── MSFT_2018.csv
    │   │   └── MSFT_2019.csv
    │   ├── daily
    │   │   ├── AAPL.csv
    │   │   └── MSFT.csv
    │   ├── dividends
    │   │   ├── AAPL.csv
    │   │   └── MSFT.csv
    │   └── splits
    │       ├── AAPL.csv
    │       └── MSFT.csv
    ├── test_blotter.py
    ├── test_custom_factor.py
    ├── test_data_factor.py
    ├── test_data_loader.py
    ├── test_event.py
    ├── test_factor.py
    ├── test_metric.py
    ├── test_parallel_algo.py
    └── test_trading_algorithm.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | /.vscode
106 | /.idea
107 | /.note.md
108 | /tests/data/yahoo
109 | /private_factors
110 | /tests/data/orders_2020-06-01.csv
111 | /tests/data/orders_2020-06-02.csv
112 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "spectre/data/iex_fetcher"]
2 | path = spectre/data/iex_fetcher
3 | url = https://github.com/Heerozh/iex_fetcher.git
4 | branch = lib
5 |
--------------------------------------------------------------------------------
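Note that `spectre/data/iex_fetcher` is a git submodule, so a plain clone leaves that
directory empty. Clone with `git clone --recursive https://github.com/Heerozh/spectre.git`,
or run `git submodule update --init` in an existing checkout, before using the IEX
downloader.

--------------------------------------------------------------------------------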
/examples/dual_ema_on_apple.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 |
8 | from spectre import factors, trading
9 | from spectre.data import YahooDownloader, ArrowLoader
10 | import pandas as pd
11 |
12 |
13 | class AppleDualEma(trading.CustomAlgorithm):
14 | invested = False
15 | asset = 'AAPL'
16 |
17 | def initialize(self):
18 | # setup engine
19 | engine = self.get_factor_engine()
20 | engine.to_cuda()
21 |
22 | universe = factors.StaticAssets({self.asset})
23 | engine.set_filter(universe)
24 |
25 | # add your factors
26 | fast_ema = factors.EMA(20)
27 | slow_ema = factors.EMA(40)
28 | engine.add(fast_ema, 'fast_ema')
29 | engine.add(slow_ema, 'slow_ema')
30 | engine.add(fast_ema > slow_ema, 'buy_signal')
31 | engine.add(fast_ema < slow_ema, 'sell_signal')
32 | engine.add(factors.OHLCV.close, 'price')
33 |
34 | # schedule rebalance before market close
35 | self.schedule_rebalance(trading.event.MarketClose(self.rebalance, offset_ns=-10000))
36 |
37 | # simulation parameters
38 | self.blotter.capital_base = 10000
39 | self.blotter.set_commission(percentage=0, per_share=0.005, minimum=1)
40 |
41 | def rebalance(self, data: pd.DataFrame, history: pd.DataFrame):
42 | asset_data = data.loc[self.asset]
43 | buy, sell = False, False
44 | if asset_data.buy_signal and not self.invested:
45 | self.blotter.order(self.asset, 100)
46 | self.invested = True
47 | buy = True
48 | elif asset_data.sell_signal and self.invested:
49 | self.blotter.order(self.asset, -100)
50 | self.invested = False
51 | sell = True
52 |
53 | self.record(AAPL=asset_data.price,
54 | short_ema=asset_data.fast_ema,
55 | long_ema=asset_data.slow_ema,
56 | buy=buy,
57 | sell=sell)
58 |
59 | def terminate(self, records: 'pd.DataFrame'):
60 | # plotting results
61 | self.plot(benchmark='SPY')
62 |
63 | import matplotlib.pyplot as plt
64 | fig = plt.figure()
65 | ax1 = fig.add_subplot(211)
66 | ax1.set_ylabel('Price (USD)')
67 |
68 | records[['AAPL', 'short_ema', 'long_ema']].plot(ax=ax1)
69 |
70 | ax1.plot(
71 | records.index[records.buy],
72 | records.loc[records.buy, 'long_ema'],
73 | '^',
74 | markersize=10,
75 | color='m',
76 | )
77 | ax1.plot(
78 | records.index[records.sell],
79 | records.loc[records.sell, 'short_ema'],
80 | 'v',
81 | markersize=10,
82 | color='k',
83 | )
84 | plt.legend(loc=0)
85 | plt.gcf().set_size_inches(18, 8)
86 |
87 | plt.show()
88 |
89 |
90 | if __name__ == '__main__':
91 | import plotly.io as pio
92 | pio.renderers.default = "browser"
93 |
94 | import argparse
95 | parser = argparse.ArgumentParser()
96 |     parser.add_argument("--download", action="store_true", help="download yahoo data")
97 | args = parser.parse_args()
98 |
99 | if args.download:
100 | YahooDownloader.ingest(
101 | start_date="2001", save_to="./yahoo",
102 | symbols=None, skip_exists=True)
103 |
104 | loader = ArrowLoader('./yahoo/yahoo.feather')
105 | results = trading.run_backtest(loader, AppleDualEma, '2013-01-01', '2018-01-01')
106 | print(results.transactions)
107 |
--------------------------------------------------------------------------------
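To run this example end-to-end: pass `--download` on the first run so
`YahooDownloader.ingest` populates ./yahoo; subsequent runs read the cached
./yahoo/yahoo.feather directly and skip the download step.

--------------------------------------------------------------------------------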
/examples/smart_beta.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 |
8 | from spectre import factors, trading
9 | from spectre.data import YahooDownloader, ArrowLoader
10 | import pandas as pd
11 |
12 |
13 | class SmartBeta(trading.CustomAlgorithm):
14 | def initialize(self):
15 | # setup engine
16 | engine = self.get_factor_engine()
17 | engine.to_cuda()
18 |
19 | universe = factors.AverageDollarVolume(win=120).top(500)
20 | engine.set_filter(universe)
21 |
22 | # SP500 factor
23 | sp500 = factors.AverageDollarVolume(win=63)
24 |         # our alpha puts more weight on NVDA! StaticAssets returns True (1) for NVDA
25 |         # and False (0) for all the others
26 | alpha = sp500 * (factors.StaticAssets({'NVDA'})*5 + 1)
27 | engine.add(alpha.to_weight(demean=False), 'weight')
28 |
29 | # schedule rebalance before market close
30 | self.schedule_rebalance(trading.event.MarketClose(self.rebalance, offset_ns=-10000))
31 |
32 | # simulation parameters
33 | self.blotter.capital_base = 1000000
34 | self.blotter.set_commission(percentage=0, per_share=0.005, minimum=1)
35 |
36 | def rebalance(self, data: pd.DataFrame, history: pd.DataFrame):
37 | self.blotter.batch_order_target_percent(data.index, data.weight)
38 |         # close positions in assets that are no longer in our universe
39 | removes = self.blotter.portfolio.positions.keys() - set(data.index)
40 | self.blotter.batch_order_target_percent(removes, [0] * len(removes))
41 |
42 | def terminate(self, records: 'pd.DataFrame'):
43 | # plotting results
44 | self.plot(benchmark='SPY')
45 |
46 |
47 | if __name__ == '__main__':
48 | import plotly.io as pio
49 | pio.renderers.default = "browser"
50 |
51 | import argparse
52 | parser = argparse.ArgumentParser()
53 |     parser.add_argument("--download", action="store_true", help="download yahoo data")
54 | args = parser.parse_args()
55 |
56 | if args.download:
57 | YahooDownloader.ingest(
58 | start_date="2001", save_to="./yahoo",
59 | symbols=None, skip_exists=True)
60 |
61 | loader = ArrowLoader('./yahoo/yahoo.feather')
62 | results = trading.run_backtest(loader, SmartBeta, '2013-01-01', '2018-01-01')
63 | print(results.transactions)
64 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | # requires Python >= 3.5 (not pip-installable; see python_requires in setup.py)
3 | pyarrow
4 | numpy
5 | pandas>=0.22
6 | bs4
7 | lxml
8 | torch>=1.3
9 | plotly
10 | tqdm
11 |
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="spectre",
5 | author='Zhang Jianhao',
6 | author_email='heeroz@gmail.com',
7 |     description='GPU-accelerated parallel quantitative trading library',
8 | long_description=open('README.md', encoding='utf-8').read(),
9 | license='Apache 2.0',
10 | keywords='quantitative analysis backtesting parallel algorithmic trading',
11 | url='https://github.com/Heerozh/spectre',
12 | classifiers=[
13 | "Programming Language :: Python :: 3",
14 | "Operating System :: OS Independent",
15 | ],
16 | python_requires='>=3.5',
17 |
18 | version="0.5",
19 | packages=['spectre', 'spectre.data', 'spectre.factors', 'spectre.parallel', 'spectre.trading',
20 | 'spectre.plotting'],
21 | )
22 |
--------------------------------------------------------------------------------
/spectre/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from . import data
3 | from . import factors
4 | from . import parallel
5 | from . import trading
6 | from . import plotting
7 |
8 |
9 | __all__ = [
10 | "data",
11 | "factors",
12 | "parallel",
13 | "trading",
14 | "plotting",
15 | ]
16 |
--------------------------------------------------------------------------------
/spectre/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Global:
5 | # default float type for engine (not affect dataloader)
6 | float_type = torch.float32
7 |
8 |
--------------------------------------------------------------------------------
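A minimal sketch of overriding this default (assuming, as `LinearWeightedAverage` in
spectre/factors/basic.py shows, that factors read `Global.float_type` when they are
constructed):

    import torch
    from spectre.config import Global

    # trade speed and memory for precision; set this before constructing any
    # factors or engines, because some factors cache weight tensors using the
    # dtype that is current at their __init__ time
    Global.float_type = torch.float64

--------------------------------------------------------------------------------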
/spectre/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import (
2 | DataLoader,
3 | DataLoaderFastGetter,
4 | )
5 | from .arrow import (
6 | ArrowLoader,
7 | )
8 | from .csv import (
9 | CsvDirLoader,
10 | )
11 | from .memory import (
12 | MemoryLoader,
13 | )
14 | from .quandl import (
15 | QuandlLoader,
16 | )
17 | from .yahoo import (
18 | YahooDownloader,
19 | )
20 |
--------------------------------------------------------------------------------
/spectre/data/arrow.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import pandas as pd
8 | import os
9 | import warnings
10 | from .dataloader import DataLoader
11 |
12 |
13 | class ArrowLoader(DataLoader):
14 | """ Read from persistent data. """
15 |
16 | def __init__(self, path: str = None, keep_in_memory: bool = True) -> None:
17 | if not os.path.exists(path + '.meta'):
18 | raise FileNotFoundError(os.path.abspath(path + '.meta'))
19 |
20 | # pandas 0.22 has the fastest MultiIndex
21 | if pd.__version__.startswith('0.22'):
22 | import feather
23 | cols = feather.read_dataframe(path + '.meta')
24 | else:
25 | cols = pd.read_feather(path + '.meta')
26 |
27 | ohlcv = cols.ohlcv.values
28 | adjustments = cols.adjustments.values[:2]
29 | if adjustments[0] is None:
30 | adjustments = None
31 | super().__init__(path, ohlcv, adjustments)
32 | self.keep_in_memory = keep_in_memory
33 | self._cache = None
34 | self._filter = None
35 |
36 | @classmethod
37 | def _last_modified(cls, file_path) -> float:
38 | if not os.path.isfile(file_path):
39 | return 0
40 | else:
41 | return os.path.getmtime(file_path)
42 |
43 | @property
44 | def last_modified(self) -> float:
45 | return self._last_modified(self._path)
46 |
47 | @classmethod
48 | def ingest(cls, source: DataLoader, save_to, force: bool = False) -> None:
49 | if not force and (source.last_modified <= cls._last_modified(save_to)):
50 | warnings.warn("You called `ingest()`, but `source` seems unchanged, "
51 | "no ingestion required. Set `force=True` to re-ingest.",
52 | RuntimeWarning)
53 | return
54 |
55 | df = source.test_load()
56 | df.reset_index(inplace=True)
57 | df.to_feather(save_to)
58 |
59 | meta = pd.DataFrame(columns=['ohlcv', 'adjustments'])
60 | meta.ohlcv = source.ohlcv
61 | meta.adjustments[:2] = source.adjustments
62 | # meta.loc[:2, "adjustments"] = source.adjustments
63 | meta.to_feather(save_to + '.meta')
64 |
65 | def filter(self, func):
66 | self._filter = func
67 | self._cache = None
68 |
69 | def _load(self) -> pd.DataFrame:
70 | if self._cache is not None:
71 | return self._cache
72 |
73 | if pd.__version__.startswith('0.22'):
74 | import feather
75 | df = feather.read_dataframe(self._path)
76 | else:
77 | df = pd.read_feather(self._path)
78 | df.set_index(['date', 'asset'], inplace=True)
79 |
80 | if self._filter is not None:
81 | df = self._filter(df)
82 |
83 | if self.keep_in_memory:
84 | self._cache = df
85 | return df
86 |
--------------------------------------------------------------------------------
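A usage sketch for ArrowLoader (illustrative only; the feather path follows the
examples above, and `filter` is the hook defined in this file):

    import pandas as pd
    from spectre.data import ArrowLoader

    loader = ArrowLoader('./yahoo/yahoo.feather')  # expects yahoo.feather.meta next to it

    # optionally trim what gets cached in memory; the callable receives the full
    # DataFrame (MultiIndex ['date', 'asset']) and returns the filtered one
    loader.filter(lambda df: df[df.index.get_level_values('asset').isin(['AAPL', 'MSFT'])])

    df = loader.load(pd.Timestamp('2017-01-03', tz='UTC'),
                     pd.Timestamp('2017-12-29', tz='UTC'))

--------------------------------------------------------------------------------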
/spectre/data/csv.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import numpy as np
8 | import pandas as pd
9 | import os
10 | import glob
11 | import warnings
12 | from .dataloader import DataLoader
13 |
14 |
15 | class CsvDirLoader(DataLoader):
16 | def __init__(self, prices_path: str, prices_by_year=False, earliest_date: pd.Timestamp = None,
17 | dividends_path=None, splits_path=None, file_pattern='*.csv',
18 | calender_asset: str = None, align_by_time=False,
19 | ohlcv=('open', 'high', 'low', 'close', 'volume'), adjustments=None,
20 | split_ratio_is_inverse=False, split_ratio_is_fraction=False,
21 | prices_index='date', dividends_index='exDate', splits_index='exDate', **read_csv):
22 | """
23 | Load data from csv dir
24 |         :param prices_path: prices csv folder, structured as one csv per stock.
25 |             When duplicate index values occur in `prices_index`, the loader keeps the
26 |             last row and drops the others.
27 |         :param prices_by_year: If the price file names look like 'spy_2017.csv', set this to True.
28 | :param earliest_date: Data before this date will not be saved to memory. Note: Use the
29 | same time zone as in the csv files.
30 |         :param dividends_path: dividends csv folder, structured as one csv per stock.
31 |             For duplicate data, the loader first drops exactly identical rows; rows with the
32 |             same `dividends_index` but different 'dividend amount' values are then summed up.
33 |             If `dividends_path` is not set, the `adjustments[0]` column is considered to be
34 |             included in the prices csv.
35 |         :param splits_path: splits csv folder, structured as one csv per stock.
36 |             When duplicate index values occur in `splits_index`, the loader uses the last
37 |             non-NaN 'split ratio' and drops the others.
38 |             If `splits_path` is not set, the `adjustments[1]` column is considered to be
39 |             included in the prices csv.
40 | :param file_pattern: csv file name pattern, default is '*.csv'.
41 |         :param calender_asset: asset name used as the trading calendar, like 'SPY', for
42 |             cleaning up non-trading-time data.
43 |         :param align_by_time: if True and `calender_asset` is not None, the df index will be
44 |             the full product of 'date' and 'asset'.
45 | :param ohlcv: Required, OHLCV column names. When you don't need to use `adjustments` and
46 | `factors.OHLCV`, you can set this to None.
47 | :param adjustments: Optional, `dividend amount` and `splits ratio` column names.
48 |         :param split_ratio_is_inverse: If the split ratio is calculated as to/from, set to True.
49 |             For example, a 2-for-1 split gives to/from = 2; a 1-for-15 reverse split gives to/from = 0.0666...
50 | :param split_ratio_is_fraction: If split ratio in csv is fraction string, set to True.
51 |         :param prices_index: `index_col` for csv files in prices_path.
52 |         :param dividends_index: `index_col` for csv files in dividends_path.
53 |         :param splits_index: `index_col` for csv files in splits_path.
54 |         :param read_csv: Extra parameters passed to every pd.read_csv call.
55 | """
56 | if adjustments is None:
57 | super().__init__(prices_path, ohlcv, None)
58 | else:
59 | super().__init__(prices_path, ohlcv, ('ex-dividend', 'split_ratio'))
60 |
61 | assert 'index_col' not in read_csv, \
62 | "`index_col` cannot be used here. Use `prices_index` and `dividends_index` and " \
63 | "`splits_index` instead."
64 | if 'dtype' not in read_csv:
65 | warnings.warn("It is recommended to set the `dtype` parameter and use float32 whenever "
66 | "possible. Example: dtype = {'Open': np.float32, 'High': np.float32, "
67 | "'Low': np.float32, 'Close': np.float32, 'Volume': np.float64}",
68 | RuntimeWarning)
69 | self._adjustment_cols = adjustments
70 | self._split_ratio_is_inverse = split_ratio_is_inverse
71 | self._split_ratio_is_fraction = split_ratio_is_fraction
72 | self._prices_by_year = prices_by_year
73 | self._earliest_date = earliest_date
74 | self._dividends_path = dividends_path
75 | self._splits_path = splits_path
76 | self._file_pattern = file_pattern
77 | self._calender = calender_asset
78 | self._prices_index = prices_index
79 | self._dividends_index = dividends_index
80 | self._splits_index = splits_index
81 | self._read_csv = read_csv
82 | self._align_by_time = align_by_time
83 |
84 | @property
85 | def last_modified(self) -> float:
86 | pattern = os.path.join(self._path, self._file_pattern)
87 | files = glob.glob(pattern)
88 | if len(files) == 0:
89 | raise ValueError("Dir '{}' does not contains any csv files.".format(self._path))
90 | return max([os.path.getmtime(fn) for fn in files])
91 |
92 | def _walk_split_by_year_dir(self, csv_path, index_col):
93 | years = set(pd.date_range(self._earliest_date or 0, pd.Timestamp.now()).year)
94 | pattern = os.path.join(csv_path, self._file_pattern)
95 | files = glob.glob(pattern)
96 | assets = {}
97 | for fn in files:
98 | base = os.path.basename(fn)
99 | symbol, year = base[:-9].upper(), int(base[-8:-4]) # like 'spy_2011.csv'
100 | if year in years:
101 | if symbol in assets:
102 | assets[symbol].append(fn)
103 | else:
104 | assets[symbol] = [fn, ]
105 |
106 | def multi_read_csv(file_list):
107 | df = pd.concat([pd.read_csv(_fn, index_col=index_col, **self._read_csv)
108 | for _fn in file_list])
109 | if not isinstance(df.index, pd.DatetimeIndex):
110 | raise ValueError(
111 | "df must index by datetime, set correct `read_csv`, "
112 | "for example index_col='date', parse_dates=True. "
113 | "For mixed-timezone like daylight saving time, "
114 | "set date_parser=lambda col: pd.to_datetime(col, utc=True)")
115 |
116 | return df[~df.index.duplicated(keep='last')]
117 |
118 | dfs = {symbol: multi_read_csv(file_list) for symbol, file_list in assets.items()}
119 | return dfs
120 |
121 | def _walk_dir(self, csv_path, index_col):
122 | pattern = os.path.join(csv_path, self._file_pattern)
123 | files = glob.glob(pattern)
124 | if len(files) == 0:
125 | raise ValueError("There are no files is {}".format(csv_path))
126 |
127 | def symbol(file):
128 | return os.path.basename(file)[:-4].upper()
129 |
130 | def read_csv(file):
131 | df = pd.read_csv(file, index_col=index_col, **self._read_csv)
132 | if len(df.index.dropna()) == 0:
133 | return None
134 | if not isinstance(df.index, pd.DatetimeIndex):
135 | raise ValueError(
136 | "df must index by datetime, set correct `read_csv`, "
137 | "for example parse_dates=True. "
138 | "For mixed-timezone like daylight saving time, "
139 | "set date_parser=lambda col: pd.to_datetime(col, utc=True)")
140 | return df[self._earliest_date:]
141 |
142 | dfs = {symbol(fn): read_csv(fn) for fn in files}
143 | return dfs
144 |
145 | def _load(self):
146 | if self._prices_by_year:
147 | dfs = self._walk_split_by_year_dir(self._path, self._prices_index)
148 | else:
149 | dfs = self._walk_dir(self._path, self._prices_index)
150 | dfs = {k: v[~v.index.duplicated(keep='last')] for k, v in dfs.items() if v is not None}
151 | df = pd.concat(dfs, sort=False)
152 | df = df.rename_axis(['asset', 'date'])
153 |
154 | if self.ohlcv is not None:
155 |             # Treat 0 as NaN and ffill here; if this is removed in the future, first convert 0 to NaN, then df.ffill
156 | df[list(self.ohlcv)] = df[list(self.ohlcv)].replace(to_replace=0, method='ffill')
157 |
158 | if self._dividends_path is not None:
159 | dfs = self._walk_dir(self._dividends_path, self._dividends_index)
160 | ex_div_col = self._adjustment_cols[0]
161 | div_index = self._dividends_index
162 |
163 | def _agg_duplicated(_df):
164 | if _df is None or ex_div_col not in _df:
165 | return None
166 | _df = _df.reset_index().drop_duplicates()
167 | _df = _df.dropna(subset=[ex_div_col])
168 | _df = _df.set_index(div_index)[ex_div_col]
169 | return _df.groupby(level=0).agg('sum')
170 |
171 | dfs = {k: _agg_duplicated(v) for k, v in dfs.items()}
172 | div = pd.concat(dfs, sort=False)
173 | div = div.reindex(df.index)
174 | div = div.fillna(0)
175 | div.name = self._adjustments[0]
176 | # div = df.rename_axis(['asset', 'date'])
177 | df = pd.concat([df, div], axis=1)
178 |
179 | if self._splits_path is not None:
180 | dfs = self._walk_dir(self._splits_path, self._splits_index)
181 | sp_rto_col = self._adjustment_cols[1]
182 |
183 | def _drop_na_and_duplicated(_df):
184 | if _df is None or sp_rto_col not in _df:
185 | return None
186 | _df = _df.dropna(subset=[sp_rto_col])[sp_rto_col]
187 | return _df[~_df.index.duplicated(keep='last')]
188 |
189 | dfs = {k: _drop_na_and_duplicated(v) for k, v in dfs.items()}
190 | splits = pd.concat(dfs, sort=False)
191 | if self._split_ratio_is_fraction:
192 | from fractions import Fraction
193 |
194 | def fraction_2_float(x):
195 | try:
196 | return float(Fraction(x))
197 | except (ValueError, ZeroDivisionError):
198 | return np.nan
199 |
200 | splits = splits.apply(fraction_2_float)
201 | splits = splits.reindex(df.index)
202 | splits = splits.fillna(1)
203 | splits.name = self._adjustments[1]
204 | df = pd.concat([df, splits], axis=1)
205 |
206 | df = df.swaplevel(0, 1).sort_index(level=0)
207 |
208 | if self._calender:
209 | # drop the data of the non-trading day by calender,
210 | # because there may be some one-line junk data in non-trading day,
211 | # causing extra row of nan to all others assets.
212 | df = self._align_to(df, self._calender, self._align_by_time)
213 | df.sort_index(level=[0, 1], inplace=True)
214 | df = self._format(df, self._split_ratio_is_inverse)
215 | return df
216 |
--------------------------------------------------------------------------------
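A construction sketch mirroring the tests/data layout at the top of this repo; the
column names (uOpen, exDate, amount, ratio) are assumptions borrowed from the
IEX-style loader in spectre/data/iex.py, so adjust them to the actual csv headers:

    import numpy as np
    from spectre.data import CsvDirLoader

    loader = CsvDirLoader(
        './tests/data/daily',                  # one csv per stock: AAPL.csv, MSFT.csv
        calender_asset='AAPL',                 # use AAPL's rows as the trading calendar
        dividends_path='./tests/data/dividends',
        splits_path='./tests/data/splits',
        ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
        adjustments=('amount', 'ratio'),
        prices_index='date', dividends_index='exDate', splits_index='exDate',
        parse_dates=True,
        dtype={'uOpen': np.float32, 'uHigh': np.float32, 'uLow': np.float32,
               'uClose': np.float32, 'uVolume': np.float64},
    )
    df = loader.test_load()  # validates index names, timezone and adjustment columns

--------------------------------------------------------------------------------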
/spectre/data/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from typing import Optional
8 | import pandas as pd
9 | import numpy as np
10 |
11 |
12 | class DataLoader:
13 | def __init__(self, path: str, ohlcv=('open', 'high', 'low', 'close', 'volume'),
14 | adjustments=('ex-dividend', 'split_ratio')) -> None:
15 | self._path = path
16 | self._ohlcv = ohlcv
17 | self._adjustments = adjustments
18 |
19 | @property
20 | def ohlcv(self):
21 | return self._ohlcv
22 |
23 | @property
24 | def adjustments(self):
25 | return self._adjustments
26 |
27 | @property
28 | def adjustment_multipliers(self):
29 | return ['price_multi', 'vol_multi']
30 |
31 | @property
32 | def time_category(self):
33 | return '_time_cat_id'
34 |
35 | @property
36 | def last_modified(self) -> float:
37 | """ data source last modification time """
38 | raise NotImplementedError("abstractmethod")
39 |
40 | @property
41 | def min_timedelta(self) -> pd.Timedelta:
42 | """ Minimum time delta of date index """
43 | date_idx = self.load().index.levels[0]
44 | return min(date_idx[1:] - date_idx[:-1])
45 |
46 | @classmethod
47 | def _align_to(cls, df, calender_asset, align_by_time=False):
48 | """ helper method for align index """
49 | index = df.loc[(slice(None), calender_asset), :].index.get_level_values(0)
50 | df = df[df.index.get_level_values(0).isin(index)]
51 | df.index = df.index.remove_unused_levels()
52 | if align_by_time:
53 | df = df.reindex(pd.MultiIndex.from_product(df.index.levels))
54 |             # df = df.unstack(level=1).stack(dropna=False)  # faster, but removed in newer pandas versions
55 |
56 | def trim_nans(x):
57 | dts = x.index.get_level_values(0)
58 | first = x.first_valid_index()
59 | last = x.last_valid_index()
60 | if first is None:
61 | return None
62 | mask = (dts >= first[0]) & (dts <= last[0])
63 | return x[mask]
64 |
65 | df = df.groupby(level=1, group_keys=False).apply(trim_nans)
66 | return df
67 |
68 | def _format(self, df, split_ratio_is_inverse=False) -> pd.DataFrame:
69 | """
70 | Format the data as we want it. df index must be in order [datetime, asset_name]
71 | * change index name to ['date', 'asset']
72 | * change asset column type to category
73 |         * convert date index to utc timezone
74 | * create time_cat column
75 | * create adjustment multipliers columns
76 | """
77 | # print(pd.Timestamp.now(), 'Formatting index...')
78 | df = df.rename_axis(['date', 'asset'])
79 | # speed up asset index search time
80 | df = df.reset_index()
81 |         asset_type = pd.api.types.CategoricalDtype(categories=np.sort(pd.unique(df.asset)),
82 |                                                    ordered=True)
83 | df.asset = df.asset.astype(asset_type)
84 | # format index and convert to utc timezone-aware
85 | df.set_index(['date', 'asset'], inplace=True)
86 | if df.index.levels[0].tzinfo is None:
87 | df = df.tz_localize('UTC', level=0, copy=False)
88 | else:
89 | df = df.tz_convert('UTC', level=0, copy=False)
90 | df.sort_index(level=[0, 1], inplace=True)
91 | # generate time key for parallel
92 | # print(pd.Timestamp.now(), 'Formatting time key for gpu sorting...')
93 | date_index = df.index.get_level_values(0)
94 | unique_date = date_index.unique()
95 | time_cat = dict(zip(unique_date, range(len(unique_date))))
96 | # cat = np.fromiter(map(lambda x: time_cat[x], date_index), dtype=int)
97 | df[self.time_category] = date_index.map(time_cat)
98 | # print(pd.Timestamp.now(), 'Done.')
99 |
100 | # Process dividends and split
101 | if self.adjustments is not None:
102 | div_col = self.adjustments[0]
103 | spr_col = self.adjustments[1]
104 | close_col = self.ohlcv[3]
105 | price_multi_col = self.adjustment_multipliers[0]
106 | vol_multi_col = self.adjustment_multipliers[1]
107 | if split_ratio_is_inverse:
108 | df[spr_col] = 1 / df[spr_col]
109 |
110 | # move ex-div up 1 row
111 | groupby = df.groupby(level=1, observed=False)
112 | last = pd.DataFrame.last_valid_index
113 | ex_div = groupby[div_col].shift(-1)
114 | ex_div.loc[groupby.apply(last)] = 0
115 | sp_rto = groupby[spr_col].shift(-1)
116 | sp_rto.loc[groupby.apply(last)] = 1
117 |
118 | df[div_col] = ex_div
119 | df[spr_col] = sp_rto
120 |
121 | # generate dividend multipliers
122 | price_multi = (1 - ex_div / df[close_col]) * sp_rto
123 | price_multi = price_multi[::-1].groupby(level=1, observed=False).cumprod()[::-1]
124 | df[price_multi_col] = price_multi.astype(np.float32)
125 | vol_multi = (1 / sp_rto)[::-1].groupby(level=1, observed=False).cumprod()[::-1]
126 | df[vol_multi_col] = vol_multi.astype(np.float32)
127 |
128 | return df
129 |
130 | def _load(self) -> pd.DataFrame:
131 | """
132 | Return dataframe with multi-index ['date', 'asset']
133 |
134 | You need to call `self.test_load()` in your test case to check if the format
135 | you returned is correct.
136 | """
137 | raise NotImplementedError("abstractmethod")
138 |
139 | def test_load(self):
140 | """
141 | Basic test for the format returned by _load(),
142 | If you write your own Loader, call this method at your test case.
143 | """
144 | df = self._load()
145 |
146 | assert df.index.names == ['date', 'asset'], \
147 | "df.index.names should be ['date', 'asset'] "
148 | assert not any(df.index.duplicated()), \
149 | "There are duplicate indexes in df, you need handle them up."
150 | assert df.index.is_monotonic_increasing, \
151 | "df.index must be sorted, try using df.sort_index(level=0, inplace=True)"
152 | assert str(df.index.levels[0].tzinfo) == 'UTC', \
153 | "df.index.date must be UTC timezone."
154 | assert df.index.levels[-1].ordered, \
155 | "df.index.asset must ordered categorical."
156 | assert self.time_category in df, \
157 | "You must create a time_category column, convert time to category id"
158 |
159 | if self.adjustments:
160 | assert all(x in df for x in self.adjustments), \
161 | "Adjustments columns `{}` not found.".format(self.adjustments)
162 | assert all(x in df for x in self.adjustment_multipliers), \
163 | "Adjustment multipliers columns `{}` not found.".format(self.adjustment_multipliers)
164 |             assert not any(df[self.adjustments[0]].isna()), \
165 |                 "There are NaN values in the ex-dividend column; they should be filled with 0."
166 |             assert not any(df[self.adjustments[1]].isna()), \
167 |                 "There are NaN values in the split_ratio column; they should be filled with 1."
168 |         assert not any(df[self.time_category].isna()), \
169 |             "There are NaN values in the time_category column; they should be filled with time ids."
170 | return df
171 |
172 | def load(self, start: Optional[pd.Timestamp] = None, end: Optional[pd.Timestamp] = None,
173 | backwards: int = 0) -> pd.DataFrame:
174 | df = self._load()
175 |
176 | index = df.index.levels[0]
177 |
178 | if start is None:
179 | start = index[0]
180 | if end is None:
181 | end = index[-1]
182 |
183 | if index[0] > start:
184 | raise ValueError(
185 | f"`start` time ({start}) cannot be less " \
186 | f"than earliest time of data: {index[0]}."
187 | )
188 |
189 | if index[-1] < end:
190 | raise ValueError(
191 | f"`end` time ({end}) cannot be greater " \
192 | f"than latest time of data: {index[-1]}."
193 | )
194 |
195 | start_loc = index.get_indexer([start], method='bfill')[0]
196 | backward_loc = max(start_loc - backwards, 0)
197 | end_loc = index.get_indexer([end], method='ffill')[0]
198 | assert end_loc >= start_loc, 'There is no data between `start` and `end`.'
199 |
200 | backward_start = index[backward_loc]
201 | return df.loc[backward_start:end]
202 |
203 |
204 | class DataLoaderFastGetter:
205 | """Fast get method for dataloader's DataFrame"""
206 | class DictLikeCursor:
207 | def __init__(self, parent, row_slice, column_id):
208 | self.parent = parent
209 | self.row_slice = row_slice
210 | self.data = parent.raw_data[row_slice, column_id]
211 | self.index = parent.asset_index[row_slice]
212 | self.length = len(self.index)
213 |
214 | def get_datetime_index(self):
215 | return self.parent.indexes[0][self.row_slice]
216 |
217 | def __getitem__(self, asset):
218 | asset_id = self.parent.asset_to_code[asset]
219 | cursor_index = self.index
220 | i = cursor_index.searchsorted(asset_id)
221 | if i >= self.length:
222 | raise KeyError('{} not found'.format(asset))
223 | if cursor_index[i] != asset_id:
224 | raise KeyError('{} not found'.format(asset))
225 | return self.data[i]
226 |
227 | def items(self):
228 | idx = self.index
229 | code_to_asset = self.parent.code_to_asset
230 | for i in range(self.length):
231 | code = idx[i]
232 | name = code_to_asset[code]
233 | yield name, self.data[i]
234 |
235 | def get(self, asset, default=None):
236 | try:
237 | return self[asset]
238 | except KeyError:
239 | return default
240 |
241 | def __init__(self, df):
242 | cat = df.index.get_level_values(1)
243 |
244 | self.source = df
245 | self.raw_data = df.values
246 | self.columns = df.columns
247 | self.indexes = [df.index.get_level_values(0), cat]
248 | self.asset_index = cat.codes
249 | self.asset_to_code = {v: k for k, v in enumerate(cat.categories)}
250 | self.code_to_asset = dict(enumerate(cat.categories))
251 | self.last_row_slice = None
252 |
253 | def get_slice(self, start, stop):
254 | if isinstance(start, slice):
255 | return start
256 | idx = self.indexes[0]
257 | stop = stop or start
258 | row_slice = slice(idx.searchsorted(start), idx.searchsorted(stop, side='right'))
259 | return row_slice
260 |
261 | def get_as_dict(self, start, stop=None, column_id=slice(None)):
262 | row_slice = self.get_slice(start, stop)
263 | cur = self.DictLikeCursor(self, row_slice, column_id)
264 | self.last_row_slice = row_slice
265 | return cur
266 |
267 | def get_as_df(self, start, stop=None):
268 | """550x faster than .loc[], 3x faster than .iloc[]"""
269 | row_slice = self.get_slice(start, stop)
270 | data = self.raw_data[row_slice]
271 | index = self.indexes[1][row_slice]
272 | self.last_row_slice = row_slice
273 | return pd.DataFrame(data, index=index, columns=self.columns)
274 |
--------------------------------------------------------------------------------
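A subclassing sketch: `_load()` must return the MultiIndex frame described above,
which is easiest to guarantee by routing the raw frame through `self._format()`.
The parquet source and class name here are hypothetical:

    import os
    import pandas as pd
    from spectre.data import DataLoader

    class ParquetLoader(DataLoader):               # hypothetical example loader
        def __init__(self, path: str) -> None:
            # this source carries no dividend/split columns
            super().__init__(path, ohlcv=('open', 'high', 'low', 'close', 'volume'),
                             adjustments=None)

        @property
        def last_modified(self) -> float:
            return os.path.getmtime(self._path)

        def _load(self) -> pd.DataFrame:
            df = pd.read_parquet(self._path)       # columns: date, asset, open, ...
            df.set_index(['date', 'asset'], inplace=True)
            # _format renames the index levels, converts dates to UTC, makes the
            # asset level an ordered categorical and adds the time-category column
            return self._format(df)

Calling `ParquetLoader(path).test_load()` in a test case then verifies the contract
asserted above.

--------------------------------------------------------------------------------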
/spectre/data/iex.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import os
8 | import numpy as np
9 | from .arrow import ArrowLoader; from .csv import CsvDirLoader
10 | from .iex_fetcher import iex
11 |
12 |
13 | class IexDownloader: # todo IexDownloader unfinished
14 | @classmethod
15 | def _concat(cls):
16 | pass
17 |
18 | @classmethod
19 | def ingest(cls, iex_key, save_to, range_='5y', symbols: list = None, skip_exists=True):
20 | """
21 | Download data from IEX. Please note that downloading all the data will cost around $60.
22 | :param iex_key: your private api key of IEX account.
23 | :param save_to: path to folder
24 | :param range_: historical range, supports 5y, 2y, 1y, ytd, 6m, 3m, 1m.
25 |         :param symbols: list of symbols to download. If None, download all stocks
26 |             (not including delisted ones).
27 | :param skip_exists: skip if file exists, useful for resume from interruption.
28 | """
29 | from tqdm.auto import tqdm
30 | print("Download prices from IEX...")
31 | iex.init(iex_key, api='cloud')
32 |
33 | calender_asset = None
34 | if symbols is None:
35 | symbols = iex.Reference.symbols()
36 |             types = ((symbols.type == 'ad') | (symbols.type == 'cs')) & (symbols.exchange != 'OTC')
37 |             symbols = symbols[types].symbol.values.tolist()
38 |             symbols.extend(['SPY', 'QQQ'])
39 | calender_asset = 'SPY'
40 |
41 | def download(event, folder):
42 | for symbol in tqdm(symbols):
43 | csv_path = os.path.join(folder, '{}.csv'.format(symbol))
44 | if os.path.exists(csv_path) and skip_exists:
45 | continue
46 | if event == 'chart':
47 | iex.Stock(symbol).chart(range_).to_csv(csv_path)
48 | elif event == 'dividends':
49 | iex.Stock(symbol).dividends(range_).to_csv(csv_path)
50 | elif event == 'splits':
51 | iex.Stock(symbol).splits(range_).to_csv(csv_path)
52 |
53 | print('Ingest prices...')
54 | prices_dir = os.path.join(save_to, 'daily')
55 | if not os.path.exists(prices_dir):
56 | os.makedirs(prices_dir)
57 | download('chart', prices_dir)
58 |
59 | print('Ingest dividends...')
60 | div_dir = os.path.join(save_to, 'dividends')
61 | if not os.path.exists(div_dir):
62 | os.makedirs(div_dir)
63 | download('dividends', div_dir)
64 |
65 | print('Ingest splits...')
66 | sp_dir = os.path.join(save_to, 'splits')
67 | if not os.path.exists(sp_dir):
68 | os.makedirs(sp_dir)
69 | download('splits', sp_dir)
70 |
71 | print('Converting...')
72 | use_cols = {'date', 'uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume', 'exDate', 'amount',
73 | 'ratio'}
74 | loader = CsvDirLoader(
75 | prices_dir, calender_asset=calender_asset,
76 | dividends_path=div_dir, splits_path=sp_dir,
77 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), adjustments=('amount', 'ratio'),
78 | prices_index='date', dividends_index='exDate', splits_index='exDate',
79 | parse_dates=True, usecols=lambda x: x in use_cols,
80 | dtype={'uOpen': np.float32, 'uHigh': np.float32, 'uLow': np.float32,
81 | 'uClose': np.float32,
82 | 'uVolume': np.float64, 'amount': np.float64, 'ratio': np.float64})
83 |
84 |         arrow_file = os.path.join(save_to, 'iex.feather')
85 | ArrowLoader.ingest(source=loader, save_to=arrow_file, force=True)
86 |
87 | print('Ingest completed! Use `loader = spectre.data.ArrowLoader(r"{}")` '
88 | 'to load your data.'.format(arrow_file))
89 |
90 | @classmethod
91 | def update(cls, range_, temp_path, save_to):
92 | # todo download and concat
93 | pass
94 |
--------------------------------------------------------------------------------
/spectre/data/memory.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from .dataloader import DataLoader
8 |
9 |
10 | class MemoryLoader(DataLoader):
11 | """ Convert pd.Dataframe to spectre.data.DataLoader """
12 | def __init__(self, df, ohlcv=None, adjustments=None) -> None:
13 | super().__init__("", ohlcv, adjustments)
14 | self.df = self._format(df)
15 | self.test_load()
16 |
17 | @property
18 | def last_modified(self) -> float:
19 | return 1
20 |
21 | def _load(self):
22 | return self.df
23 |
--------------------------------------------------------------------------------
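A toy sketch (prices invented) showing the expected input shape; `_format()` and
`test_load()` both run inside MemoryLoader's __init__:

    import pandas as pd
    from spectre.data import MemoryLoader

    idx = pd.MultiIndex.from_product(
        [pd.date_range('2020-01-01', periods=3), ['AAPL', 'MSFT']],
        names=['date', 'asset'])
    df = pd.DataFrame({'close': [300., 160., 301., 161., 302., 162.]}, index=idx)

    loader = MemoryLoader(df)   # ohlcv and adjustments default to None
    print(loader.load())

--------------------------------------------------------------------------------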
/spectre/data/quandl.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from zipfile import ZipFile
8 | import numpy as np
9 | import pandas as pd
10 | from .dataloader import DataLoader
11 |
12 |
13 | class QuandlLoader(DataLoader):
14 | @property
15 | def last_modified(self) -> float:
16 | """ the quandl data is no longer updated, so return a fixed value """
17 | return 1
18 |
19 | def __init__(self, file: str, calender_asset='AAPL') -> None:
20 | """
21 | Usage:
22 | download data first:
23 | https://www.quandl.com/api/v3/datatables/WIKI/PRICES.csv?qopts.export=true&api_key=[yourapi_key]
24 | then:
25 | loader = data.QuandlLoader('./quandl/WIKI_PRICES.zip')
26 | """
27 | super().__init__(file,
28 | ohlcv=('open', 'high', 'low', 'close', 'volume'),
29 | adjustments=('ex-dividend', 'split_ratio'))
30 | self._calender = calender_asset
31 |
32 | def _load(self) -> pd.DataFrame:
33 | with ZipFile(self._path) as pkg:
34 | with pkg.open(pkg.namelist()[0]) as csv:
35 | df = pd.read_csv(csv, parse_dates=['date'],
36 | usecols=['ticker', 'date', 'open', 'high', 'low', 'close',
37 | 'volume', 'ex-dividend', 'split_ratio', ],
38 | dtype={
39 | 'open': np.float32, 'high': np.float32, 'low': np.float32,
40 | 'close': np.float32, 'volume': np.float64,
41 | 'ex-dividend': np.float64, 'split_ratio': np.float64
42 | })
43 |
44 | df.set_index(['date', 'ticker'], inplace=True)
45 | df.split_ratio.loc[("2001-09-12", 'GMT')] = 1 # fix nan
46 | if self._calender:
47 | df = self._align_to(df, self._calender)
48 | df.sort_index(level=[0, 1], inplace=True)
49 | df = self._format(df, split_ratio_is_inverse=True)
50 |
51 | return df
52 |
--------------------------------------------------------------------------------
/spectre/data/yahoo.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import datetime
8 | import os
9 | import time
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from .arrow import ArrowLoader
14 | from .csv import CsvDirLoader
15 |
16 |
17 | class YahooDownloader:
18 |
19 | @classmethod
20 | def ingest(cls, start_date: str, save_to: str, symbols: list = None, skip_exists=True) -> None:
21 | """
22 | Download data from yahoo.
23 |         :param start_date: start date of the historical data, e.g. '2001' or '2010-01-01'.
24 | :param save_to: path to folder
25 |         :param symbols: list of symbols to download. If None, download the S&P 500 components.
26 | :param skip_exists: skip if file exists, useful for resume from interruption.
27 | """
28 | import requests
29 | import re
30 | from tqdm.auto import tqdm
31 |
32 | print("Download prices from yahoo...")
33 |
34 | start_date = pd.to_datetime(start_date, utc=True)
35 |
36 | calender_asset = None
37 | if symbols is None:
38 | etf = pd.read_html(requests.get(
39 | 'https://etfdailynews.com/etf/spy/', headers={'User-agent': 'Mozilla/5.0'}
40 | ).text, attrs={'id': 'etfs-that-own'})
41 | symbols = [x for x in etf[0].Symbol.values.tolist() if isinstance(x, str)]
42 | symbols.extend(['SPY', 'QQQ'])
43 | calender_asset = 'SPY'
44 |
45 | session = requests.Session()
46 | page = session.get('https://finance.yahoo.com/quote/IBM/history?p=IBM')
47 | # CrumbStore
48 | m = re.search('"CrumbStore":{"crumb":"(.*?)"}', page.text)
49 | crumb = m.group(1)
50 | crumb = crumb.encode('ascii').decode('unicode-escape')
51 |
52 | def download(event, folder):
53 | start = int(start_date.timestamp())
54 | now = int(datetime.datetime.now().timestamp())
55 | for symbol in tqdm(symbols):
56 | symbol = symbol.replace('.', '-')
57 | csv_path = os.path.join(folder, '{}.csv'.format(symbol))
58 | if os.path.exists(csv_path) and skip_exists:
59 | continue
60 | url = "https://query1.finance.yahoo.com/v7/finance/download/" \
61 | "{}?period1={}&period2={}&interval=1d&events={}&crumb={}".format(
62 | symbol, start, now, event, crumb)
63 |
64 | retry = 0.25
65 | while True:
66 | req = session.get(url)
67 | if req.status_code != requests.codes.ok:
68 | if 'No data found' in req.text:
69 | print('Symbol invalid, skipped: {}.'.format(symbol))
70 | break
71 | retry *= 2
72 | if retry >= 5:
73 |                             print('Get {} failed after 4 retries, skipped, reason: {}'.format(
74 | symbol, req.text))
75 | break
76 | else:
77 | time.sleep(retry)
78 | continue
79 | with open(csv_path, 'wb') as f:
80 | f.write(req.content)
81 | break
82 |
83 | print('Ingest prices...')
84 | prices_dir = os.path.join(save_to, 'daily')
85 | if not os.path.exists(prices_dir):
86 | os.makedirs(prices_dir)
87 | download('history', prices_dir)
88 |
89 |         # yahoo price data is already split-adjusted
90 | # print('Ingest dividends...')
91 | # div_dir = os.path.join(save_to, 'dividends')
92 | # if not os.path.exists(div_dir):
93 | # os.makedirs(div_dir)
94 | # download('div', div_dir)
95 |
96 | # print('Ingest splits...')
97 | # sp_dir = os.path.join(save_to, 'splits')
98 | # if not os.path.exists(sp_dir):
99 | # os.makedirs(sp_dir)
100 | # download('split', sp_dir)
101 |
102 | session.close()
103 |
104 | print('Converting...')
105 | loader = CsvDirLoader(
106 | prices_dir, calender_asset=calender_asset,
107 | # dividends_path=div_dir,
108 | # splits_path=sp_dir,
109 | ohlcv=('Open', 'High', 'Low', 'Close', 'Volume'),
110 | # adjustments=('Dividends', 'Stock Splits'),
111 | prices_index='Date',
112 | # dividends_index='Date', splits_index='Date', split_ratio_is_fraction=True,
113 | parse_dates=True,
114 | dtype={'Open': np.float32, 'High': np.float32, 'Low': np.float32,
115 | 'Close': np.float32,
116 | 'Volume': np.float64, 'Dividends': np.float64})
117 |
118 | arrow_file = os.path.join(save_to, 'yahoo.feather')
119 | ArrowLoader.ingest(source=loader, save_to=arrow_file, force=True)
120 |
121 | print('Ingest completed! Use `loader = spectre.data.ArrowLoader(r"{}")` '
122 | 'to load your data.'.format(arrow_file))
123 |
--------------------------------------------------------------------------------
/spectre/factors/__init__.py:
--------------------------------------------------------------------------------
1 | from .engine import (
2 | FactorEngine,
3 | OHLCV,
4 | )
5 |
6 | from .factor import (
7 | BaseFactor, PlaceHolderFactor,
8 | CustomFactor,
9 | CrossSectionFactor,
10 | RankFactor, RollingRankFactor,
11 | ZScoreFactor, RollingZScoreFactor,
12 | XSMax, XSMin,
13 | QuantileClassifier,
14 | SumFactor, ProdFactor, UniqueTSSumFactor,
15 | MADClampFactor,
16 | WinsorizingFactor,
17 | IQRNormalityFactor,
18 | )
19 |
20 | from .datafactor import (
21 | ColumnDataFactor,
22 | AdjustedColumnDataFactor,
23 | AssetClassifierDataFactor,
24 | SeriesDataFactor,
25 | DatetimeDataFactor,
26 | )
27 |
28 | from .filter import (
29 | FilterFactor,
30 | StaticAssets,
31 | OneHotEncoder,
32 | AndFactor,
33 | AnyFilter,
34 | AllFilter,
35 | PlaceHolderFilter,
36 | )
37 |
38 | from .multiprocessing import (
39 | CPUParallelFactor
40 | )
41 |
42 | from .basic import (
43 | Returns,
44 | LogReturns,
45 | SimpleMovingAverage, MA, SMA,
46 | WeightedAverageValue,
47 | LinearWeightedAverage,
48 | VWAP,
49 | ExponentialWeightedMovingAverage, EMA,
50 | AverageDollarVolume,
51 | AnnualizedVolatility,
52 | ElementWiseMax, ElementWiseMin,
53 | RollingArgMax, RollingArgMin,
54 | ConstantsFactor,
55 | )
56 |
57 | from .technical import (
58 | BollingerBands, BBANDS,
59 | MovingAverageConvergenceDivergenceSignal, MACD,
60 | TrueRange, TRANGE,
61 | RSI,
62 | FastStochasticOscillator, STOCHF,
63 | )
64 |
65 | from .statistical import (
66 | StandardDeviation, STDDEV,
67 | RollingHigh, MAX,
68 | RollingLow, MIN,
69 | RollingLinearRegression,
70 | RollingMomentum,
71 | RollingQuantile,
72 | HalfLifeMeanReversion,
73 | RollingCorrelation,
74 | RollingCovariance,
75 | XSMaxCorrCoef,
76 | InformationCoefficient, RankWeightedInformationCoefficient,
77 | RollingInformationCoefficient,
78 | TTest1Samp, StudentCDF,
79 | CrossSectionR2,
80 | FactorWiseKthValue, FactorWiseZScore,
81 | )
82 |
83 | from .feature import (
84 | MarketDispersion,
85 | MarketReturn,
86 | MarketVolatility,
87 | AdvanceDeclineRatio,
88 | AssetData,
89 | MONTH, WEEKDAY, QUARTER, TIME,
90 | IS_JANUARY, IS_DECEMBER, IS_MONTH_END, IS_MONTH_START, IS_QUARTER_END, IS_QUARTER_START,
91 | )
92 |
93 | from .label import (
94 | RollingFirst,
95 | ForwardSignalData,
96 | )
97 |
--------------------------------------------------------------------------------
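A sketch tying a few of these exports together. Constructing `FactorEngine(loader)`
directly and the `engine.run(start, end)` call are assumptions; the examples at the
top of the repo obtain the engine via `get_factor_engine()` inside an algorithm:

    from spectre import factors
    from spectre.data import ArrowLoader

    loader = ArrowLoader('./yahoo/yahoo.feather')
    engine = factors.FactorEngine(loader)        # assumed direct construction

    # liquid-universe filter plus two moving-average factors
    engine.set_filter(factors.AverageDollarVolume(win=120).top(500))
    engine.add(factors.MA(5), 'ma5')
    engine.add(factors.EMA(20), 'ema20')

    df = engine.run('2017-01-03', '2017-12-29')  # assumed signature; factor values per date/asset

--------------------------------------------------------------------------------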
/spectre/factors/basic.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from typing import Sequence
8 | import numpy as np
9 | import torch
10 | import math
11 | from .factor import BaseFactor, CustomFactor
12 | from ..parallel import nansum, nanmean
13 | from .engine import OHLCV
14 | from ..config import Global
15 |
16 |
17 | class Returns(CustomFactor):
18 | """ Returns by tick (not by time) """
19 | inputs = [OHLCV.close]
20 | win = 2
21 | _min_win = 2
22 |
23 | def compute(self, closes):
24 |         # missing data is treated as delisted; the return is calculated on the last day's data
25 | return closes.last_nonnan(offset=1) / closes.first() - 1
26 |
27 |
28 | class LogReturns(CustomFactor):
29 | inputs = [OHLCV.close]
30 | win = 2
31 | _min_win = 2
32 |
33 | def compute(self, closes):
34 | return (closes.last_nonnan() / closes.first()).log()
35 |
36 |
37 | class SimpleMovingAverage(CustomFactor):
38 | inputs = [OHLCV.close]
39 | _min_win = 2
40 |
41 | def compute(self, data):
42 | return data.nanmean()
43 |
44 |
45 | class WeightedAverageValue(CustomFactor):
46 | _min_win = 2
47 |
48 | def compute(self, base, weight):
49 | def _weight_mean(_base, _weight):
50 | return nansum(_base * _weight, dim=2) / nansum(_weight, dim=2)
51 |
52 | return base.agg(_weight_mean, weight)
53 |
54 |
55 | class LinearWeightedAverage(CustomFactor):
56 | _min_win = 2
57 |
58 | def __init__(self, win=None, inputs=None):
59 | super().__init__(win, inputs)
60 | self.weight = torch.arange(1, self.win + 1, dtype=Global.float_type)
61 | self.weight = self.weight / self.weight.sum()
62 |
63 | def pre_compute_(self, engine, start, end) -> None:
64 | super().pre_compute_(engine, start, end)
65 | self.weight = self.weight.to(device=engine.device, copy=False)
66 |
67 | def compute(self, base):
68 | def _weight_mean(_base):
69 | return nansum(_base * self.weight, dim=2)
70 |
71 | return base.agg(_weight_mean)
72 |
73 |
74 | class VWAP(WeightedAverageValue):
75 | inputs = (OHLCV.close, OHLCV.volume)
76 |
77 |
78 | class ExponentialWeightedMovingAverage(CustomFactor):
79 | inputs = [OHLCV.close]
80 | win = 2
81 | _min_win = 2
82 |
83 | def __init__(self, span: int = None, inputs: Sequence[BaseFactor] = None,
84 | adjust=False, half_life: float = None):
85 | if span is not None:
86 | self.alpha = (2.0 / (1.0 + span))
87 |             # Length required to achieve 99.97% accuracy: np.log(1-99.97/100) / np.log(alpha),
88 |             # which simplifies to roughly 4 * (span+1). 3.45 achieves 99.90%, 2.26 achieves 99.00%.
89 | self.win = int(4.5 * (span + 1))
90 | else:
91 | self.alpha = 1 - math.exp(math.log(0.5) / half_life)
92 |             self.win = int(15 * half_life)
93 |
94 | super().__init__(None, inputs)
95 | self.adjust = adjust
96 | self.weight = np.full(self.win, 1 - self.alpha) ** np.arange(self.win - 1, -1, -1)
97 | if self.adjust:
98 | self.weight = self.weight / sum(self.weight) # to sum one
99 |
100 | def pre_compute_(self, engine, start, end) -> None:
101 | super().pre_compute_(engine, start, end)
102 | if not isinstance(self.weight, torch.Tensor):
103 | self.weight = torch.tensor(self.weight, dtype=Global.float_type, device=engine.device)
104 |
105 | def compute(self, data):
106 | self.weight = self.weight.to(device=data.device)
107 | weighted_mean = data.agg(lambda x: nansum(x * self.weight, dim=2))
108 | if self.adjust:
109 | return weighted_mean
110 | else:
111 | shifted = data.last().roll(self.win - 1, dims=1)
112 | shifted[:, 0:self.win - 1] = 0
113 | alpha = self.alpha
114 | return alpha * weighted_mean + (shifted * (1 - alpha) ** self.win)
115 |
116 |
117 | class AverageDollarVolume(CustomFactor):
118 | inputs = [OHLCV.close, OHLCV.volume]
119 |
120 | def compute(self, closes, volumes):
121 | if self.win == 1:
122 | return closes * volumes
123 | else:
124 | return closes.agg(lambda c, v: nanmean(c * v, dim=2), volumes)
125 |
126 |
127 | class AnnualizedVolatility(CustomFactor):
128 | inputs = [Returns(win=2), 252]
129 |     win = 20
130 | _min_win = 2
131 |
132 | def compute(self, returns, annualization_factor):
133 | return returns.nanstd() * (annualization_factor ** .5)
134 |
135 |
136 | class ElementWiseMax(CustomFactor):
137 | _min_win = 1
138 |
139 | def __init__(self, win=None, inputs=None):
140 | super().__init__(win, inputs)
141 | assert self.win == 1
142 |
143 | @classmethod
144 | def binary_fill_na(cls, a, b, value):
145 | a = a.clone()
146 | b = b.clone()
147 | if a.dtype != b.dtype or a.dtype not in (torch.float32, torch.float64, torch.float16):
148 | a = a.to(Global.float_type)
149 | b = b.to(Global.float_type)
150 |
151 | a.masked_fill_(torch.isnan(a), value)
152 | b.masked_fill_(torch.isnan(b), value)
153 | return a, b
154 |
155 | def compute(self, a, b):
156 | ret = torch.max(*ElementWiseMax.binary_fill_na(a, b, -np.inf))
157 | ret.masked_fill_(torch.isinf(ret), np.nan)
158 | return ret
159 |
160 |
161 | class ElementWiseMin(CustomFactor):
162 | _min_win = 1
163 |
164 | def __init__(self, win=None, inputs=None):
165 | super().__init__(win, inputs)
166 | assert self.win == 1
167 |
168 | def compute(self, a, b):
169 | ret = torch.min(*ElementWiseMax.binary_fill_na(a, b, np.inf))
170 | ret.masked_fill_(torch.isinf(ret), np.nan)
171 | return ret
172 |
173 |
174 | class RollingArgMax(CustomFactor):
175 | _min_win = 2
176 |
177 | def compute(self, data):
178 | def _argmax(_data):
179 | ret = (_data.argmax(dim=2) + 1.) / self.win
180 | return ret.to(Global.float_type)
181 |
182 | return data.agg(_argmax)
183 |
184 |
185 | class RollingArgMin(CustomFactor):
186 | _min_win = 2
187 |
188 | def compute(self, data):
189 | def _argmin(_data):
190 | ret = (_data.argmin(dim=2) + 1.) / self.win
191 | return ret.to(Global.float_type)
192 |
193 | return data.agg(_argmin)
194 |
195 |
196 | class ConstantsFactor(CustomFactor):
197 | def __init__(self, value, like=OHLCV.open):
198 | self.value = value
199 | super().__init__(1, inputs=[like])
200 |
201 | def compute(self, x):
202 | return torch.full(x.shape, self.value, device=x.device, dtype=x.dtype)
203 |
204 |
205 | MA = SimpleMovingAverage
206 | SMA = SimpleMovingAverage
207 | EMA = ExponentialWeightedMovingAverage
208 |
--------------------------------------------------------------------------------
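New rolling factors follow the SimpleMovingAverage pattern above: declare `inputs`,
a window, and reduce the rolling view in `compute`. A sketch, not part of the
library:

    from spectre.factors import CustomFactor, OHLCV

    class CloseOverMean(CustomFactor):   # hypothetical factor
        """Last close divided by the rolling mean of close."""
        inputs = [OHLCV.close]
        win = 20
        _min_win = 2

        def compute(self, closes):
            # `closes` is a rolling window view; last() and nanmean() reduce the
            # window dimension, as in Returns and SimpleMovingAverage above
            return closes.last() / closes.nanmean()

--------------------------------------------------------------------------------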
/spectre/factors/datafactor.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import numpy as np
8 | import torch
9 | import pandas as pd
10 |
11 | from typing import Optional, Sequence, Union
12 | from ..parallel import nanlast
13 | from .factor import BaseFactor, CustomFactor
14 | from ..config import Global
15 |
16 |
17 | class ColumnDataFactor(BaseFactor):
18 | def __init__(self, inputs: Optional[Sequence[str]] = None,
19 | should_delay=True, dtype=None) -> None:
20 | super().__init__()
21 | if inputs:
22 | self.inputs = inputs
23 |         assert (3 > len(self.inputs) > 0), \
24 |             "ColumnDataFactor's `inputs` can only contain one data column and its " \
25 |             "corresponding adjustments column"
26 | self._data = None
27 | self._multiplier = None
28 | self._should_delay = should_delay
29 | self.dtype = dtype
30 |
31 | @property
32 | def adjustments(self):
33 | return self._multiplier
34 |
35 | def get_total_backwards_(self) -> int:
36 | return 0
37 |
38 | def should_delay(self) -> bool:
39 | return self._should_delay
40 |
41 | def pre_compute_(self, engine, start, end) -> None:
42 | super().pre_compute_(engine, start, end)
43 | if self._data is None:
44 | self._data = engine.column_to_tensor_(self.inputs[0])
45 | if self.dtype is not None:
46 | self._data = self._data.to(self.dtype)
47 | self._data = engine.group_by_(self._data, self.groupby)
48 | if len(self.inputs) > 1 and self.inputs[1] in engine.dataframe_:
49 | self._multiplier = engine.column_to_tensor_(self.inputs[1])
50 | self._multiplier = engine.group_by_(self._multiplier, self.groupby)
51 | if self.dtype is not None:
52 | self._multiplier = self._multiplier.to(self.dtype)
53 | else:
54 | self._multiplier = None
55 | self._clean_required = True
56 |
57 | def clean_up_(self, force=False) -> None:
58 | super().clean_up_(force)
59 | self._data = None
60 | self._multiplier = None
61 | self._clean_required = False
62 |
63 | def compute_(self, stream: Union[torch.cuda.Stream, None]) -> torch.Tensor:
64 | return self._data
65 |
66 | def compute(self, *inputs: Sequence[torch.Tensor]) -> torch.Tensor:
67 | pass
68 |
69 | # def adjusted_shift(self, periods=1):
70 | # factor = AdjustedShiftFactor(win=periods, inputs=(self,))
71 | # return factor
72 | #
73 | #
74 | # class AdjustedShiftFactor(CustomFactor):
75 | # """ Shift the root datafactor """
76 | #
77 | # def compute(self, data) -> torch.Tensor:
78 | # return data.first()
79 |
80 |
81 | class AdjustedColumnDataFactor(CustomFactor):
82 | def __init__(self, data: ColumnDataFactor):
83 | super().__init__(1, (data,))
84 | self.parent = data
85 |
86 | def compute(self, data) -> torch.Tensor:
87 | multi = self.parent.adjustments
88 | if multi is None:
89 | return data
90 | return data * multi / nanlast(multi, dim=1)[:, None]
91 |
92 |
93 | class AssetClassifierDataFactor(BaseFactor):
94 | """ Dict to categorical output for asset, slow """
95 | def __init__(self, sector: dict, default: int):
96 | super().__init__()
97 | self.sector = sector
98 | self.default = default
99 | self._data = None
100 |
101 | def get_total_backwards_(self) -> int:
102 | return 0
103 |
104 | def should_delay(self) -> bool:
105 | return False
106 |
107 | def pre_compute_(self, engine, start, end) -> None:
108 | super().pre_compute_(engine, start, end)
109 | assets = engine.dataframe_index[1]
110 | sector = self.sector
111 | default = self.default
112 | data = [sector.get(asset, default) for asset in assets] # slow
113 | data = torch.tensor(data, device=engine.device, dtype=Global.float_type)
114 | self._data = engine.group_by_(data, self.groupby)
115 |
116 | def clean_up_(self, force=False) -> None:
117 | super().clean_up_(force)
118 | self._data = None
119 |
120 | def compute_(self, stream: Union[torch.cuda.Stream, None]) -> torch.Tensor:
121 | return self._data
122 |
123 | def compute(self, *inputs: Sequence[torch.Tensor]) -> torch.Tensor:
124 | pass
125 |
126 |
127 | class SeriesDataFactor(ColumnDataFactor):
128 |     """ Adds a (date, asset)-indexed pd.Series to the engine as a data column; slow """
129 | def __init__(self, series: pd.Series, fill_na=None, should_delay=True):
130 | self.series = series
131 | self.fill_na = fill_na
132 | assert series.index.names == ['date', 'asset'], \
133 |             "series.index.names should be ['date', 'asset']"
134 | super().__init__(inputs=[str(series.name)], should_delay=should_delay)
135 |
136 | def pre_compute_(self, engine, start, end) -> None:
137 | if self.series.name not in engine.dataframe_.columns:
138 | engine._dataframe = engine.dataframe_.join(self.series)
139 | if self.fill_na is not None:
140 | engine._dataframe[self.series.name] = engine._dataframe[self.series.name].\
141 | groupby(level=1).fillna(method=self.fill_na)
142 | super().pre_compute_(engine, start, end)
143 |
144 |
145 | class DatetimeDataFactor(BaseFactor):
146 |     """ Exposes an attribute of the engine's datetime index as a DataFactor """
147 | _instance = {}
148 |
149 | def __new__(cls, attr):
150 | if attr not in cls._instance:
151 | cls._instance[attr] = super().__new__(cls)
152 | return cls._instance[attr]
153 |
154 | def __init__(self, attr) -> None:
155 | super().__init__()
156 | self._data = None
157 | self.attr = attr
158 |
159 | def get_total_backwards_(self) -> int:
160 | return 0
161 |
162 | def should_delay(self) -> bool:
163 | return False
164 |
165 | def pre_compute_(self, engine, start, end) -> None:
166 | super().pre_compute_(engine, start, end)
167 | if self._data is None:
168 | data = getattr(engine.dataframe_index[0], self.attr) # slow
169 | if not isinstance(data, np.ndarray):
170 | data = data.values
171 | data = torch.from_numpy(data).to(
172 | device=engine.device, dtype=Global.float_type, non_blocking=True)
173 | self._data = engine.group_by_(data, self.groupby)
174 | self._clean_required = True
175 |
176 | def clean_up_(self, force=False) -> None:
177 | super().clean_up_(force)
178 | self._data = None
179 | self._clean_required = False
180 |
181 | def compute_(self, stream: Union[torch.cuda.Stream, None]) -> torch.Tensor:
182 | return self._data
183 |
184 | def compute(self, *inputs: Sequence[torch.Tensor]) -> torch.Tensor:
185 | pass
186 |
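A minimal usage sketch for the data factors above. The `engine` variable, the example
series, and the assumption that these classes are re-exported from `spectre.factors`
are illustrative, not taken from this file:

    import pandas as pd
    from spectre import factors

    # hypothetical external data, indexed exactly as SeriesDataFactor asserts
    sentiment = pd.Series(
        [0.5, -0.1],
        index=pd.MultiIndex.from_tuples(
            [(pd.Timestamp('2019-01-02', tz='UTC'), 'AAPL'),
             (pd.Timestamp('2019-01-02', tz='UTC'), 'MSFT')],
            names=['date', 'asset']),
        name='sentiment')

    engine.add(factors.SeriesDataFactor(sentiment, fill_na='ffill'), 'sent')
    engine.add(factors.DatetimeDataFactor('weekday'), 'weekday')  # one shared instance per attr
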
--------------------------------------------------------------------------------
/spectre/factors/feature.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import warnings
8 | from .datafactor import DatetimeDataFactor
9 | from .factor import CrossSectionFactor, CustomFactor
10 | from .basic import Returns
11 | from ..parallel import nanstd, nanmean, nansum
12 | from ..config import Global
13 |
14 |
15 | # ----------- Common Market Features -----------
16 |
17 |
18 | class MarketDispersion(CrossSectionFactor):
19 |     """Cross-sectional standard deviation of the universe's stock returns."""
20 | inputs = (Returns(), )
21 | win = 1
22 |
23 | def compute(self, returns):
24 | ret = nanstd(returns, dim=1).unsqueeze(-1)
25 | return ret.expand(ret.shape[0], returns.shape[1])
26 |
27 |
28 | class MarketReturn(CrossSectionFactor):
29 |     """Cross-sectional mean return of the universe's stocks."""
30 | inputs = (Returns(), )
31 | win = 1
32 |
33 | def compute(self, returns):
34 | ret = nanmean(returns, dim=1).unsqueeze(-1)
35 | return ret.expand(ret.shape[0], returns.shape[1])
36 |
37 |
38 | class MarketVolatility(CustomFactor):
39 |     """Rolling standard deviation of MarketReturn."""
40 | inputs = (MarketReturn(), 252)
41 | win = 252
42 | _min_win = 2
43 |
44 | def compute(self, returns, annualization_factor):
45 | return (returns.nanvar() * annualization_factor) ** 0.5
46 |
47 |
48 | class AdvanceDeclineRatio(CrossSectionFactor):
49 |     """Intended to be smoothed with MA; could be applied to volume too"""
50 | inputs = (Returns(), )
51 | win = 1
52 |
53 | def compute(self, returns):
54 | advancing = nansum(returns > 0, dim=1).to(Global.float_type)
55 | declining = nansum(returns < 0, dim=1).to(Global.float_type)
56 | ratio = (advancing / declining).unsqueeze(-1)
57 | return ratio.expand(ratio.shape[0], returns.shape[1])
58 |
59 |
60 | # ----------- Asset-specific data -----------
61 |
62 |
63 | class AssetData(CustomFactor):
64 | def __init__(self, asset, factor):
65 | self.asset = asset
66 | self.asset_ind = None
67 | super().__init__(win=1, inputs=[factor])
68 |
69 | def pre_compute_(self, engine, start, end):
70 | super().pre_compute_(engine, start, end)
71 | if not engine.align_by_time:
72 |             warnings.warn("Make sure your data is aligned by time, otherwise the data will "
73 |                           "be disordered. Alternatively, set engine.align_by_time = True.",
74 | RuntimeWarning)
75 | self.asset_ind = engine.dataframe_index[1].unique().categories.get_loc(self.asset)
76 |
77 | def compute(self, data):
78 | ret = data[self.asset_ind]
79 | return ret.expand(data.shape[0], ret.shape[0])
80 |
81 |
82 | # ----------- Common Calendar Features -----------
83 |
84 |
85 | MONTH = DatetimeDataFactor('month')
86 | WEEKDAY = DatetimeDataFactor('weekday')
87 | QUARTER = DatetimeDataFactor('quarter')
88 | TIME = DatetimeDataFactor('hour') + DatetimeDataFactor('minute') / 60.0
89 |
90 | IS_JANUARY = MONTH == 1
91 | IS_DECEMBER = MONTH == 12
92 | # Note: shift(-1) may fail the engine.test_lookahead_bias(),
93 | # but this method is the fastest, so be it.
94 | IS_MONTH_END = MONTH.shift(-1) != MONTH
95 | IS_MONTH_START = MONTH.shift(1) != MONTH
96 | IS_QUARTER_END = QUARTER.shift(-1) != QUARTER
97 | IS_QUARTER_START = QUARTER.shift(1) != QUARTER
98 |
99 |
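A sketch of composing the predefined calendar features above; that `engine` exists, that
filters can be added as output columns, and that `&` combines these module-level factors
via the package's filter operators are all assumptions for illustration:

    # a January-effect style condition built from the calendar factors above
    jan_start = IS_JANUARY & IS_MONTH_START
    engine.add(jan_start, 'jan_start')
    engine.add(MarketDispersion(), 'dispersion')
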
--------------------------------------------------------------------------------
/spectre/factors/filter.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from abc import ABC
8 | from typing import Set
9 | from .factor import BaseFactor, CustomFactor, MultiRetSelector
10 | from ..parallel import Rolling
11 | import torch
12 | import warnings
13 |
14 |
15 | class FilterFactor(CustomFactor, ABC):
16 | def __getitem__(self, key):
17 | return FilterMultiRetSelector(inputs=(self, key))
18 |
19 | def shift(self, periods=1):
20 | factor = FilterRawShiftFactor(inputs=(self,))
21 | factor.periods = periods
22 | return factor
23 |
24 | def sum(self, win):
25 | raise ValueError("FilterFactor does not support `.sum()` method, "
26 | "please convert to float by using `filter_factor.float()`")
27 |
28 | def filter(self, mask):
29 | raise ValueError("FilterFactor does not support local filtering `.filter()` method, "
30 | "please convert to float by using `filter_factor.float()`")
31 |
32 | def ts_any(self, win):
33 | """ Return True if Rolling window contains any True """
34 | return AnyFilter(win, inputs=(self,))
35 |
36 | def ts_all(self, win):
37 | """ Return True if Rolling window all are True """
38 | return AllFilter(win, inputs=(self,))
39 |
40 |
41 | class FilterMultiRetSelector(MultiRetSelector, FilterFactor):
42 |     """MultiRetSelector returns a CustomFactor; we override it here to return a FilterFactor"""
43 | pass
44 |
45 |
46 | class PlaceHolderFilter(FilterFactor):
47 | def compute(self, data: torch.Tensor) -> torch.Tensor:
48 | return data
49 |
50 |
51 | class FilterRawShiftFactor(FilterFactor):
52 |     """Workaround for "roll_cuda" not being implemented for 'Bool': shift via int8 (char)"""
53 | periods = 1
54 |
55 | def compute(self, data: torch.Tensor) -> torch.Tensor:
56 | shift = data.char().roll(self.periods, dims=1)
57 | if self.periods > 0:
58 | shift[:, 0:self.periods] = 0
59 | else:
60 | shift[:, self.periods:] = 0
61 |
62 | return shift.bool()
63 |
64 |
65 | class StaticAssets(FilterFactor):
66 |     """Useful for removing specific outliers or debugging particular assets"""
67 | def __init__(self, assets: Set[str]):
68 | from .engine import OHLCV
69 | super().__init__(win=1, inputs=[OHLCV.open])
70 | self.assets = assets
71 |
72 | def compute(self, data: torch.Tensor) -> torch.Tensor:
73 | s = self._revert_to_series(data)
74 | ret = s.index.isin(self.assets, level=1)
75 | return self._regroup(ret)
76 |
77 |
78 | class OneHotEncoder(FilterFactor):
79 | def __init__(self, input_: BaseFactor):
80 | super().__init__(1, [input_])
81 |
82 | def compute(self, data: torch.Tensor) -> torch.Tensor:
83 | classes = torch.unique(data, sorted=False)
84 | classes = classes[~torch.isnan(classes)]
85 | one_hot = []
86 | if classes.shape[0] > 1000:
87 |             warnings.warn("One-hot encoding with a large number of classes: ({})."
88 | .format(classes.shape[0]),
89 | RuntimeWarning)
90 | for i in range(classes.shape[0]):
91 | one_hot.append((data == classes[i]).unsqueeze(-1))
92 | return torch.cat(one_hot, dim=-1)
93 |
94 |
95 | class AnyFilter(FilterFactor):
96 | _min_win = 2
97 |
98 | def compute(self, data: Rolling) -> torch.Tensor:
99 | return data.values.any(dim=2)
100 |
101 |
102 | class AllFilter(FilterFactor):
103 | _min_win = 2
104 |
105 | def compute(self, data: Rolling) -> torch.Tensor:
106 | return data.values.all(dim=2)
107 |
108 |
109 | class AnyNonNaNFactor(FilterFactor):
110 | _min_win = 2
111 |
112 | def compute(self, data: Rolling) -> torch.Tensor:
113 | return (~torch.isnan(data.values)).any(dim=2)
114 |
115 |
116 | class AllNonNaNFactor(FilterFactor):
117 | _min_win = 2
118 |
119 | def compute(self, data: Rolling) -> torch.Tensor:
120 | return (~torch.isnan(data.values)).all(dim=2)
121 |
122 |
123 | class InvertFactor(FilterFactor):
124 | def compute(self, left) -> torch.Tensor:
125 | return ~left
126 |
127 |
128 | class OrFactor(FilterFactor):
129 | def compute(self, left, right) -> torch.Tensor:
130 | return left | right
131 |
132 |
133 | class XorFactor(FilterFactor):
134 | def compute(self, left, right) -> torch.Tensor:
135 | return left ^ right
136 |
137 |
138 | class AndFactor(FilterFactor):
139 | def compute(self, left, right) -> torch.Tensor:
140 | return left & right
141 |
142 |
143 | class LtFactor(FilterFactor):
144 | def compute(self, left, right) -> torch.Tensor:
145 | return torch.lt(left, right)
146 |
147 |
148 | class LeFactor(FilterFactor):
149 | def compute(self, left, right) -> torch.Tensor:
150 | return torch.le(left, right)
151 |
152 |
153 | class GtFactor(FilterFactor):
154 | def compute(self, left, right) -> torch.Tensor:
155 | return torch.gt(left, right)
156 |
157 |
158 | class GeFactor(FilterFactor):
159 | def compute(self, left, right) -> torch.Tensor:
160 | return torch.ge(left, right)
161 |
162 |
163 | class EqFactor(FilterFactor):
164 | def compute(self, left, right) -> torch.Tensor:
165 | return torch.eq(left, right)
166 |
167 |
168 | class NeFactor(FilterFactor):
169 | def compute(self, left, right) -> torch.Tensor:
170 | return torch.ne(left, right)
171 |
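A sketch of the rolling filter helpers above; the import path is an assumption, `Returns`
comes from .basic, and `>` producing a GtFactor is assumed to be wired up on BaseFactor:

    from spectre.factors import Returns, StaticAssets

    universe = StaticAssets({'AAPL', 'MSFT'})  # restrict to two assets for debugging
    up_day = Returns() > 0                     # a FilterFactor (GtFactor)
    any_up = up_day.ts_any(5)                  # True if any up day in the window
    all_up = up_day.ts_all(5)                  # True only after 5 straight up days
    mixed = any_up & ~all_up                   # AndFactor / InvertFactor from above
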
--------------------------------------------------------------------------------
/spectre/factors/label.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from .factor import CustomFactor
8 | from ..parallel import masked_first
9 |
10 |
11 | class RollingFirst(CustomFactor):
12 | win = 2
13 | _min_win = 2
14 |
15 | def __init__(self, win, data, mask):
16 | super().__init__(win, inputs=(data, mask))
17 |
18 | def compute(self, data, mask):
19 | def _first_filter(_data, _mask):
20 | first_signal_price = masked_first(_data, _mask, dim=2)
21 | return first_signal_price
22 |
23 | return data.agg(_first_filter, mask)
24 |
25 |
26 | class ForwardSignalData(RollingFirst):
27 | """Data in future window periods where signal = True. Lookahead biased."""
28 | def __init__(self, win, data, signal):
29 | super().__init__(win, data.shift(-win+1), signal.shift(-win+1))
30 |
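A sketch of ForwardSignalData as a labeling target, e.g. the close at the first entry
signal within the next 10 bars; the signal and import path are illustrative assumptions:

    from spectre.factors import OHLCV

    signal = OHLCV.close > OHLCV.open                 # a placeholder entry condition
    entry_price = ForwardSignalData(10, OHLCV.close, signal)
    label = entry_price / OHLCV.close - 1             # forward return to that entry
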
--------------------------------------------------------------------------------
/spectre/factors/multiprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from typing import Optional, Sequence
8 | from .factor import BaseFactor, CustomFactor
9 | from .datafactor import ColumnDataFactor
10 | from ..parallel import Rolling
11 | import pandas as pd
12 | import numpy as np
13 | from multiprocessing import Pool, cpu_count
14 | from multiprocessing.pool import ThreadPool
15 |
16 |
17 | class CPFCaller:
18 | inputs = None
19 | win = None
20 | callback = None
21 |
22 | def split_call(self, splits):
23 | split_inputs = [[]] * len(self.inputs)
24 | for i, data in enumerate(self.inputs):
25 | if isinstance(data, pd.DataFrame):
26 | split_inputs[i] = [data.iloc[beg:end] for beg, end in splits]
27 | else:
28 | split_inputs[i] = [data] * len(splits)
29 | return np.array([self.callback(*params) for params in zip(*split_inputs)])
30 |
31 |
32 | class CPUParallelFactor(CustomFactor):
33 | """
34 | Use CPU multi-process/thread instead of GPU to process each window of data.
35 |     Useful when your calculations can only be done on the CPU.
36 | 
37 |     The performance of this method is not ideal; it is definitely not as fast as
38 |     using a vectorized library directly.
39 | """
40 |
41 | def __init__(self, win: Optional[int] = None, inputs: Optional[Sequence[BaseFactor]] = None,
42 | multiprocess=False, core=None):
43 | """
44 |         `multiprocess=True` may not work on Windows if your code is written in a notebook
45 |         cell, so it is recommended that you put the CPUParallelFactor code in a .py file.
46 | """
47 | super().__init__(win, inputs)
48 |
49 | for data in inputs:
50 | if isinstance(data, ColumnDataFactor):
51 | raise ValueError('Cannot use ColumnDataFactor in CPUParallelFactor, '
52 | 'please use AdjustedColumnDataFactor instead.')
53 | if multiprocess:
54 | self.pool = Pool
55 | else:
56 | self.pool = ThreadPool
57 | if core is None:
58 | self.core = cpu_count()
59 | else:
60 | self.core = core
61 |
62 | def compute(self, *inputs):
63 | n_cores = self.core
64 | origin_input = None
65 | date_count = 0
66 |
67 | converted_inputs = []
68 | for data in inputs:
69 | if isinstance(data, Rolling):
70 | s = self._revert_to_series(data.last())
71 | unstacked = s.unstack(level=1)
72 | converted_inputs.append(unstacked)
73 | if origin_input is None:
74 | origin_input = s
75 | date_count = len(unstacked)
76 | else:
77 | converted_inputs.append(data)
78 |
79 | backwards = self.get_total_backwards_()
80 | first_win_beg = backwards - self.win + 1
81 | first_win_end = backwards + 1
82 | windows = date_count - backwards
83 | ranges = list(zip(range(first_win_beg, first_win_beg + windows),
84 | range(first_win_end, date_count + 1)))
85 | caller = CPFCaller()
86 | caller.inputs = converted_inputs
87 | caller.callback = type(self).mp_compute
88 |
89 | if len(ranges) < n_cores:
90 | n_cores = len(ranges)
91 | split_range = np.array_split(ranges, n_cores)
92 |
93 | with self.pool(n_cores) as p:
94 | pool_ret = p.map(caller.split_call, split_range)
95 |
96 | pool_ret = np.concatenate(pool_ret)
97 | ret = pd.Series(index=origin_input.index, dtype='float64').unstack(level=1)
98 | if pool_ret.shape != ret.iloc[backwards:].shape:
99 | raise ValueError('return value shape {} != original {}'.format(
100 | pool_ret.shape, ret.iloc[backwards:].shape))
101 |
102 | ret.iloc[backwards:] = pool_ret
103 | ret = ret.stack(dropna=False)[origin_input.index]
104 |
105 | return self._regroup(ret)
106 |
107 | @staticmethod
108 | def mp_compute(*inputs) -> np.array:
109 | """
110 |         You will receive one window of the input data as a pd.DataFrame.
111 |         The table below shows how it looks when win=3:
112 | | date | A | AA | ... | ZET |
113 | |------------|------|--------|-----|--------|
114 | | 2020-01-01 | 11.1 | xxx.xx | ... | 123.45 |
115 | | 2020-01-02 | ... | ... | ... | ... |
116 | | 2020-01-03 | 22.2 | xxx.xx | ... | 234.56 |
117 |
118 | You should return an np.array of length `input.shape[1]`
119 | """
120 | raise NotImplementedError("abstractmethod")
121 |
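A sketch subclass under the contract documented above: each mp_compute call receives one
date-window DataFrame per rolling input and returns one value per asset column. The import
path is an assumption; note raw OHLCV columns must be wrapped in AdjustedColumnDataFactor:

    import numpy as np
    from spectre.factors import AdjustedColumnDataFactor, OHLCV

    class WindowSkew(CPUParallelFactor):
        @staticmethod
        def mp_compute(prices) -> np.array:
            # prices: a win x assets DataFrame; pandas skew runs on the CPU
            return prices.skew(axis=0).values

    f = WindowSkew(win=20, inputs=[AdjustedColumnDataFactor(OHLCV.close)])
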
--------------------------------------------------------------------------------
/spectre/factors/statistical.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import torch
8 | import math
9 | import numpy as np
10 | from .factor import CustomFactor, CrossSectionFactor
11 | from .engine import OHLCV
12 | from ..parallel import (linear_regression_1d, quantile, pearsonr, unmasked_mean, unmasked_sum,
13 | nanmean, nanstd, covariance, nanvar)
14 | from ..parallel import DeviceConstant
15 | from ..config import Global
16 |
17 |
18 | class StandardDeviation(CustomFactor):
19 | inputs = [OHLCV.close]
20 | _min_win = 2
21 | ddof = 0
22 |
23 | def compute(self, data):
24 | return data.nanstd(ddof=self.ddof)
25 |
26 |
27 | class XSStandardDeviation(CrossSectionFactor):
28 | inputs = [OHLCV.close]
29 | ddof = 0
30 |
31 | def compute(self, data):
32 | return nanstd(data, ddof=self.ddof).unsqueeze(-1).expand(data.shape[0], data.shape[1])
33 |
34 |
35 | class RollingHigh(CustomFactor):
36 | inputs = (OHLCV.close,)
37 | win = 5
38 | _min_win = 2
39 |
40 | def compute(self, data):
41 | return data.nanmax()
42 |
43 |
44 | class RollingLow(CustomFactor):
45 | inputs = (OHLCV.close,)
46 | win = 5
47 | _min_win = 2
48 |
49 | def compute(self, data):
50 | return data.nanmin()
51 |
52 |
53 | class RollingLinearRegression(CustomFactor):
54 | _min_win = 2
55 |
56 | def __init__(self, win, x, y):
57 | super().__init__(win=win, inputs=[x, y])
58 |
59 | def compute(self, x, y):
60 | def lin_reg(_y, _x=None):
61 | if _x is None:
62 | _x = DeviceConstant.get(_y.device).arange(_y.shape[2], dtype=_y.dtype)
63 | _x = _x.expand(_y.shape[0], _y.shape[1], _x.shape[0])
64 | m, b = linear_regression_1d(_x, _y, dim=2)
65 | return torch.cat([m.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
66 | if x is None:
67 | return y.agg(lin_reg)
68 | else:
69 | return y.agg(lin_reg, x)
70 |
71 | @property
72 | def coef(self):
73 | return self[0]
74 |
75 | @property
76 | def intercept(self):
77 | return self[1]
78 |
79 |
80 | class RollingMomentum(CustomFactor):
81 | inputs = (OHLCV.close,)
82 | win = 20
83 | _min_win = 2
84 |
85 | def compute(self, prices):
86 | def polynomial_reg(_y):
87 | x = DeviceConstant.get(_y.device).arange(_y.shape[2], dtype=_y.dtype)
88 | ones = torch.ones(x.shape[0], device=_y.device, dtype=_y.dtype)
89 | x = torch.stack([ones, x, x ** 2]).T
90 | x = x.expand(_y.shape[0], _y.shape[1], x.shape[0], x.shape[1])
91 |
92 | xt = x.transpose(-2, -1)
93 | ret = (xt @ x).inverse() @ xt @ _y.unsqueeze(-1)
94 | return ret.squeeze(-1)
95 |
96 | return prices.agg(polynomial_reg)
97 |
98 | @property
99 | def gain(self):
100 |         """gain > 0 means the stock is gaining; otherwise it is losing."""
101 | return self[1]
102 |
103 | @property
104 | def accelerate(self):
105 |         """accelerate > 0 means the stock is accelerating; otherwise it is decelerating."""
106 | return self[2]
107 |
108 | @property
109 | def intercept(self):
110 | return self[0]
111 |
112 |
113 | class RollingQuantile(CustomFactor):
114 | inputs = (OHLCV.close, 5)
115 | _min_win = 2
116 |
117 | def compute(self, data, bins):
118 | def _quantile(_data):
119 | return quantile(_data, bins, dim=2)[:, :, -1]
120 | return data.agg(_quantile)
121 |
122 |
123 | class HalfLifeMeanReversion(CustomFactor):
124 | _min_win = 2
125 |
126 | def __init__(self, win, data, mean, mask=None):
127 | lag = data.shift(1) - mean
128 | diff = data - data.shift(1)
129 | lag.set_mask(mask)
130 | diff.set_mask(mask)
131 | super().__init__(win=win, inputs=[lag, diff, math.log(2)])
132 |
133 | def compute(self, lag, diff, ln2):
134 | def calc_h(_x, _y):
135 | _lambda, _ = linear_regression_1d(_x, _y, dim=2)
136 | return -ln2 / _lambda
137 | return lag.agg(calc_h, diff)
138 |
139 |
140 | class RollingCorrelation(CustomFactor):
141 | _min_win = 2
142 |
143 | def compute(self, x, y):
144 | def _corr(_x, _y):
145 | return pearsonr(_x, _y, dim=2, ddof=1)
146 | return x.agg(_corr, y)
147 |
148 |
149 | class RollingCovariance(CustomFactor):
150 | _min_win = 2
151 |
152 | def compute(self, x, y):
153 | def _cov(_x, _y):
154 | return covariance(_x, _y, dim=2, ddof=1)
155 | return x.agg(_cov, y)
156 |
157 |
158 | class XSMaxCorrCoef(CrossSectionFactor):
159 | """
160 |     Returns, for each x, the maximum correlation coefficient with the other xs
161 | """
162 |
163 | def compute(self, *xs):
164 | x = torch.stack(xs, dim=1)
165 | x_bar = nanmean(x, dim=2).unsqueeze(-1)
166 | demean = x - x_bar
167 | demean.masked_fill_(torch.isnan(demean), 0)
168 | cov = demean @ demean.transpose(-2, -1)
169 | cov = cov / (x.shape[-1] - 1)
170 | diag = cov[:, range(len(xs)), range(len(xs)), None]
171 | std = diag ** 0.5
172 | corr = cov / std / std.transpose(-2, -1)
173 | # set auto corr to zero
174 | corr[:, range(len(xs)), range(len(xs))] = 0
175 | max_corr = corr.max(dim=2).values.unsqueeze(-2)
176 | return max_corr.expand(x.shape[0], x.shape[2], x.shape[1])
177 |
178 |
179 | class InformationCoefficient(CrossSectionFactor):
180 | """
181 | Cross-Section IC, the ic value of all assets is the same.
182 | """
183 | def __init__(self, x, y, mask=None, weight=None):
184 | super().__init__(win=1, inputs=[x, y, weight], mask=mask)
185 |
186 | def compute(self, x, y, weight):
187 | if weight is None:
188 | ic = pearsonr(x, y, dim=1, ddof=1)
189 | else:
190 | xy = x * y
191 | mask = torch.isnan(x * y)
192 | w = weight / unmasked_sum(weight, mask=mask, dim=1).unsqueeze(-1)
193 | x_bar = unmasked_sum(w * x, mask=mask, dim=1)
194 | y_bar = unmasked_sum(w * y, mask=mask, dim=1)
195 | cov_xy = unmasked_sum(w * xy, mask=mask, dim=1) - x_bar * y_bar
196 | var_x = unmasked_sum(w * x ** 2, mask=mask, dim=1) - x_bar ** 2
197 | var_y = unmasked_sum(w * y ** 2, mask=mask, dim=1) - y_bar ** 2
198 | ic = cov_xy / (var_x * var_y) ** 0.5
199 | return ic.unsqueeze(-1).expand(ic.shape[0], y.shape[1])
200 |
201 | def to_ir(self, win, ddof=1):
202 |         # Use a CrossSectionFactor and unfold manually, because with a CustomFactor the ir
203 |         # value would be inconsistent when some assets have no data (e.g. newly listed);
204 |         # the ir value should not depend on which assets are present.
205 | class RollingIC2IR(CrossSectionFactor):
206 | def __init__(self, win_, inputs):
207 | super().__init__(1, inputs)
208 | self.rolling_win = win_
209 |
210 | def compute(self, ic):
211 | x = ic[:, 0]
212 | nan_stack = x.new_full((self.rolling_win - 1,), np.nan)
213 | new_x = torch.cat((nan_stack, x), dim=0)
214 | rolling_ic = new_x.unfold(0, self.rolling_win, 1)
215 |
216 | # Fundamental Law of Active Management: ir = ic * sqrt(b), 1/sqrt(b) = std(ic)
217 | ir = nanmean(rolling_ic, dim=1) / nanstd(rolling_ic, dim=1, ddof=ddof)
218 | return ir.unsqueeze(-1).expand(ic.shape)
219 | return RollingIC2IR(win_=win, inputs=[self])
220 |
221 |
222 | class RollingInformationCoefficient(RollingCorrelation):
223 | """
224 |     Rolling IC: calculates the IC between two historical series for each asset.
225 | """
226 | def to_ir(self, win):
227 | std = StandardDeviation(win=win, inputs=(self,))
228 | std.ddof = 1
229 | mean = self.sum(win) / win
230 |
231 | return mean / std
232 |
233 |
234 | class RankWeightedInformationCoefficient(InformationCoefficient):
235 | def __init__(self, x, y, half_life, mask=None):
236 | alpha = np.exp((np.log(0.5) / half_life))
237 | y_rank = y.rank(ascending=False, mask=mask) - 1
238 | weight = alpha ** y_rank
239 | super().__init__(x, y, mask=mask, weight=weight)
240 |
241 |
242 | class TTest1Samp(CustomFactor):
243 | _min_win = 2
244 |
245 | def compute(self, a, pop_mean):
246 | def _ttest(_x):
247 | d = nanmean(_x, dim=2) - pop_mean
248 | v = nanvar(_x, dim=2, ddof=1)
249 | denom = torch.sqrt(v / self._min_win)
250 | t = d / denom
251 | return t
252 | return a.agg(_ttest)
253 |
254 |
255 | class StudentCDF(CrossSectionFactor):
256 | """
257 |     Note: for performance, this factor uses the cross-section mean of the t-values.
258 | """
259 | DefaultPrecision = 0.001
260 |
261 | def compute(self, t, dof, precision):
262 | reduced_t = nanmean(t, dim=1)
263 | p = torch.zeros_like(reduced_t)
264 | dof = torch.tensor(dof, dtype=torch.float64, device=t.device)
265 | for i, v in enumerate(reduced_t.cpu()):
266 | if np.isnan(v):
267 | p[i] = torch.nan
268 | elif np.isinf(v):
269 | p[i] = 1
270 | elif v < -9:
271 | p[i] = 0
272 | else:
273 | x = torch.arange(-9, v, precision, device=t.device)
274 | p[i] = torch.e ** torch.lgamma((dof + 1) / 2) / (
275 | torch.sqrt(dof * torch.pi) * torch.e ** torch.lgamma(dof / 2)) * (
276 | torch.trapezoid((1 + x ** 2 / dof) ** (-dof / 2 - 1 / 2), x)
277 | )
278 | return p.unsqueeze(-1).expand(t.shape)
279 |
280 |
281 | class CrossSectionR2(CrossSectionFactor):
282 | def __init__(self, y, y_pred, mask, total_r2=False):
283 | super().__init__(win=1, inputs=[y, y_pred], mask=mask)
284 | self.total_r2 = total_r2
285 |
286 | def compute(self, y, y_pred):
287 | mask = torch.isnan(y_pred) | torch.isnan(y)
288 | ss_err = unmasked_sum((y - y_pred) ** 2, mask, dim=1)
289 | if self.total_r2:
290 |             # Use this to compute r2 against total market return; otherwise it is relative return.
291 | ss_tot = unmasked_sum(y ** 2, mask, dim=1)
292 | else:
293 | y_bar = unmasked_mean(y, mask, dim=1).unsqueeze(-1)
294 | ss_tot = unmasked_sum((y - y_bar) ** 2, mask, dim=1)
295 | r2 = -ss_err / ss_tot + 1
296 | r2[(~mask).to(Global.float_type).sum(dim=1) < 2] = np.nan
297 | return r2.unsqueeze(-1).expand(r2.shape[0], y.shape[1])
298 |
299 |
300 | class FactorWiseKthValue(CrossSectionFactor):
301 | """ The kth value of all factors sorted in ascending order, grouped by each datetime """
302 | def __init__(self, kth, inputs=None):
303 | super().__init__(1, inputs)
304 | self.kth = kth
305 |
306 | def compute(self, *data):
307 | mx = torch.stack([nanmean(x, dim=1) for x in data], dim=-1)
308 | nans = torch.isnan(mx)
309 | mx.masked_fill_(nans, -np.inf)
310 | ret = torch.kthvalue(mx, self.kth, dim=1, keepdim=True).values
311 | return ret.expand(ret.shape[0], data[0].shape[1])
312 |
313 |
314 | class FactorWiseZScore(CrossSectionFactor):
315 | def compute(self, *data):
316 | mx = torch.stack([nanmean(x, dim=1) for x in data], dim=-1)
317 | ret = (mx - nanmean(mx, dim=1).unsqueeze(-1)) / nanstd(mx, dim=1).unsqueeze(-1)
318 | return ret.unsqueeze(-2).repeat(1, data[0].shape[1], 1)
319 |
320 |
321 | STDDEV = StandardDeviation
322 | MAX = RollingHigh
323 | MIN = RollingLow
324 |
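A sketch of the IC and regression utilities above; `my_factor` and `fwd_returns` are
placeholders for factors defined elsewhere:

    ic = InformationCoefficient(my_factor.rank(), fwd_returns)
    ir = ic.to_ir(win=64)      # rolling 64-bar IR, per the Fundamental Law comment above

    reg = RollingLinearRegression(20, None, my_factor)  # x=None regresses against time
    slope, intercept = reg.coef, reg.intercept
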
--------------------------------------------------------------------------------
/spectre/factors/technical.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from typing import Optional, Sequence
8 | from .factor import BaseFactor, CustomFactor
9 | from .basic import MA, EMA
10 | from .statistical import STDDEV
11 | from .engine import OHLCV
12 | from ..parallel import nanmean
13 | import numpy as np
14 | import torch
15 |
16 |
17 | class BollingerBands(CustomFactor):
18 |     """ usage: BBANDS(win, inputs=[OHLCV.close, k]); k is a constant, normally 2 """
19 | inputs = (OHLCV.close, 2)
20 | win = 20
21 | _min_win = 2
22 |
23 | def __init__(self, win: Optional[int] = None, inputs: Optional[Sequence[BaseFactor]] = None):
24 | super().__init__(win, inputs)
25 | if len(self.inputs) < 2:
26 |             raise ValueError("BollingerBands requires 2 inputs, "
27 |                              "inputs=[OHLCV.close, k]; k is a constant, normally 2.")
28 | comm_inputs = (self.inputs[0],)
29 | k = self.inputs[1]
30 | self.inputs = (self.inputs[0],
31 | MA(win=self.win, inputs=comm_inputs),
32 | STDDEV(win=self.win, inputs=comm_inputs),
33 | k)
34 | self.win = 1
35 |
36 | def compute(self, closes, ma, std, k):
37 | d = k * std
38 | up = ma + d
39 | down = ma - d
40 | return torch.cat([up.unsqueeze(-1), ma.unsqueeze(-1), down.unsqueeze(-1)], dim=-1)
41 |
42 | def normalized(self):
43 | return NormalizedBollingerBands(self.win, self.inputs)
44 |
45 |
46 | class NormalizedBollingerBands(CustomFactor):
47 | def compute(self, closes, ma, std, k):
48 | return (closes - ma) / (k * std)
49 |
50 |
51 | class MovingAverageConvergenceDivergenceSignal(EMA):
52 | """
53 | engine.add( MACD(fast, slow, sign, inputs=[OHLCV.close]) )
54 | or
55 | engine.add( MACD().normalized() )
56 | """
57 | inputs = (OHLCV.close,)
58 | win = 9
59 | _min_win = 2
60 |
61 | def __init__(self, fast=12, slow=26, sign=9, inputs: Optional[Sequence[BaseFactor]] = None,
62 | adjust=False):
63 | super().__init__(sign, inputs, adjust)
64 | self.inputs = (EMA(inputs=self.inputs, span=fast) - EMA(inputs=self.inputs, span=slow),)
65 |
66 | def normalized(self):
67 |         # To avoid doubling the calculation, reuse the `inputs` factor here
68 | macd = self.inputs[0]
69 | sign = self
70 | return macd - sign
71 |
72 |
73 | class TrueRange(CustomFactor):
74 | """ATR = MA(14, inputs=(TrueRange(),))"""
75 | inputs = (OHLCV.high, OHLCV.low, OHLCV.close)
76 | win = 2
77 | _min_win = 2
78 |
79 | def compute(self, highs, lows, closes):
80 | high_to_low = highs.last() - lows.last()
81 | high_to_prev_close = (highs.last() - closes.first()).abs()
82 | low_to_prev_close = (lows.last() - closes.first()).abs()
83 | max1 = high_to_low.where(high_to_low > high_to_prev_close, high_to_prev_close)
84 | return max1.where(max1 > low_to_prev_close, low_to_prev_close)
85 |
86 |
87 | class RSI(CustomFactor):
88 | """ usage: RSI(win, inputs=[OHLCV.close]) """
89 | inputs = (OHLCV.close,)
90 | win = 14
91 | _min_win = 2
92 |
93 | def __init__(self, win: Optional[int] = None, inputs: Optional[Sequence[BaseFactor]] = None):
94 | super().__init__(win, inputs)
95 | self.win = self.win + 1 # +1 for 1 day diff
96 |
97 | def compute(self, closes):
98 | def _rsi(_closes):
99 | shift = _closes.roll(1, dims=2)
100 | shift = shift.contiguous()
101 | shift[:, :, 0] = np.nan
102 | diff = _closes - shift
103 | del shift, _closes
104 | up = diff.clamp(min=0)
105 | down = diff.clamp(max=0)
106 | del diff
107 |             # Cutler's RSI: more stable, independent of data length
108 | up = nanmean(up[:, :, 1:], dim=2)
109 | down = nanmean(down[:, :, 1:], dim=2).abs()
110 | return 100 - (100 / (1 + up / down))
111 | # Wilder's RSI
112 | # up = up.ewm(com=14-1, adjust=False).mean()
113 | # down = down.ewm(com=14-1, adjust=False).mean().abs()
114 | return closes.agg(_rsi)
115 |
116 | def normalized(self):
117 | return self / 50 - 1
118 |
119 |
120 | class FastStochasticOscillator(CustomFactor):
121 | """ usage: STOCHF(win, inputs=[OHLCV.high, OHLCV.low, OHLCV.close]) """
122 | inputs = (OHLCV.high, OHLCV.low, OHLCV.close)
123 | win = 14
124 | _min_win = 2
125 |
126 | def compute(self, highs, lows, closes):
127 | highest_highs = highs.nanmax()
128 | lowest_lows = lows.nanmin()
129 | k = (closes.last() - lowest_lows) / (highest_highs - lowest_lows)
130 |
131 | return k * 100
132 |
133 | def normalized(self):
134 | return self / 100 - 0.5
135 |
136 |
137 | BBANDS = BollingerBands
138 | MACD = MovingAverageConvergenceDivergenceSignal
139 | TRANGE = TrueRange
140 | STOCHF = FastStochasticOscillator
141 |
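A sketch using the aliases defined at the bottom of this file; the existing FactorEngine
`engine` is an assumption:

    engine.add(BBANDS(win=20, inputs=(OHLCV.close, 2)).normalized(), 'bb_pos')
    engine.add(MACD(12, 26, 9).normalized(), 'macd_hist')
    engine.add(MA(14, inputs=(TRANGE(),)), 'atr')  # ATR, per the TrueRange docstring
    engine.add(STOCHF().normalized(), 'stochf')
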
--------------------------------------------------------------------------------
/spectre/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .algorithmic import (
2 | ParallelGroupBy, DummyParallelGroupBy,
3 | Rolling,
4 | nansum, unmasked_sum,
5 | nanmean, unmasked_mean,
6 | nanvar,
7 | nanstd,
8 | masked_last,
9 | masked_first,
10 | nanlast,
11 | nanmax,
12 | nanmin,
13 | pad_2d,
14 | rankdata,
15 | covariance,
16 | pearsonr,
17 | spearman,
18 | linear_regression_1d,
19 | quantile,
20 | masked_kth_value_1d,
21 | clamp_1d_,
22 | )
23 |
24 | from .constants import DeviceConstant
25 |
--------------------------------------------------------------------------------
/spectre/parallel/constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TensorConstant:
5 | def __init__(self, device):
6 | self.device = device
7 | self.linspace_cache = {}
8 | self.r_linspace_cache = {}
9 | self.arange_cache = {}
10 |
11 | def linspace(self, size, dtype):
12 | if dtype in self.linspace_cache:
13 | w = self.linspace_cache[dtype]
14 | if size <= len(w):
15 | return w[:size]
16 |
17 | self.linspace_cache[dtype] = new = torch.linspace(
18 | 0.0, 0.9, size, dtype=dtype, device=self.device)
19 | return new
20 |
21 | def r_linspace(self, size, dtype):
22 | if dtype in self.r_linspace_cache:
23 | w = self.r_linspace_cache[dtype]
24 | if size <= len(w):
25 | return w[:size]
26 |
27 | self.r_linspace_cache[dtype] = new = torch.linspace(
28 | 0.9, 0.0, size, dtype=dtype, device=self.device)
29 | return new
30 |
31 | def arange(self, size, dtype):
32 | if dtype in self.arange_cache:
33 | w = self.arange_cache[dtype]
34 | if size <= len(w):
35 | return w[:size]
36 |
37 | self.arange_cache[dtype] = new = torch.arange(size, dtype=dtype, device=self.device)
38 | return new
39 |
40 |
41 | class DeviceConstant:
42 | constants = {}
43 |
44 | @classmethod
45 | def clean(cls):
46 | cls.constants = {}
47 |
48 | @classmethod
49 | def get(cls, device):
50 | if device in cls.constants:
51 | return cls.constants[device]
52 |
53 | cls.constants[device] = new = TensorConstant(device)
54 | return new
55 |
56 |
57 |
58 |
59 |
60 |
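A sketch of the caching behavior above: the first request allocates and caches one tensor
per device and dtype, and any later request of equal or smaller size returns a slice of
that cache:

    import torch

    cpu = torch.device('cpu')
    a = DeviceConstant.get(cpu).arange(10, dtype=torch.float32)  # allocates the cache
    b = DeviceConstant.get(cpu).arange(5, dtype=torch.float32)   # a slice of the cache
    assert torch.equal(b, a[:5])
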
--------------------------------------------------------------------------------
/spectre/plotting/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .returns_chart import (
3 | plot_quantile_and_cumulative_returns,
4 | cumulative_returns_fig,
5 | )
6 |
7 | from .factor_diagram import (
8 | plot_factor_diagram,
9 | )
10 |
11 | from .chart import (
12 | plot_chart,
13 | )
14 |
--------------------------------------------------------------------------------
/spectre/plotting/chart.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import warnings
8 |
9 |
10 | def plot_chart(df_prices, ohlcv, df_factor, trace_types=None, styles=None, inline=True):
11 | import plotly.graph_objects as go
12 |
13 | # group df by asset
14 | asset_group = df_factor.index.get_level_values(1).remove_unused_categories()
15 | dfs = list(df_factor.groupby(asset_group))
16 | if len(dfs) > 5:
17 |         warnings.warn("Too many assets ({}); only plotting the first 5.".format(len(dfs)),
18 | RuntimeWarning)
19 | dfs = dfs[:5]
20 | dfs = [(asset, factors) for (asset, factors) in dfs if factors.shape[0] > 0]
21 |     trace_types = trace_types or {}
22 |
23 | # init styles
24 |     styles = styles or {}
25 | styles['price'] = styles.get('price', {})
26 | styles['volume'] = styles.get('volume', {})
27 |
28 | # default styles
29 | styles['height'] = styles.get('height', 500)
30 | styles['price']['line'] = styles['price'].get('line', dict(width=1))
31 | styles['price']['name'] = styles['price'].get('name', 'price')
32 | styles['volume']['opacity'] = styles['volume'].get('opacity', 0.2)
33 | styles['volume']['yaxis'] = styles['volume'].get('yaxis', 'y2')
34 | styles['volume']['name'] = styles['volume'].get('name', 'volume')
35 |
36 | # get y_axes
37 | y_axes = set()
38 | for k, v in styles.items():
39 | if not isinstance(v, dict):
40 | continue
41 | if 'yaxis' in v:
42 | y_axes.add('yaxis' + v['yaxis'][1:])
43 | if 'yref' in v:
44 | y_axes.add('yaxis' + v['yref'][1:])
45 |
46 | figs = {}
47 | # plotting
48 | for i, (asset, factors) in enumerate(dfs):
49 | fig = go.Figure()
50 | figs[asset] = fig
51 |
52 | factors = factors.droplevel(level=1)
53 | start, end = factors.index[0], factors.index[-1]
54 |
55 | prices = df_prices.loc[(slice(start, end), asset), :].droplevel(level=1)
56 | index = prices.index.strftime("%y-%m-%d %H%M%S")
57 | if ohlcv[0] is not None:
58 | # add candlestick
59 | fig.add_trace(
60 | go.Candlestick(x=index, open=prices[ohlcv[0]], high=prices[ohlcv[1]],
61 | low=prices[ohlcv[2]], close=prices[ohlcv[3]], **styles['price']))
62 | fig.add_trace(
63 | go.Bar(x=index, y=prices[ohlcv[4]], **styles['volume']))
64 | else:
65 | # add line plot
66 | if ohlcv[3] is not None:
67 | fig.add_trace(
68 | go.Scatter(x=index, y=prices[ohlcv[3]], **styles['price']))
69 | if ohlcv[4] is not None:
70 | fig.add_trace(
71 | go.Scatter(x=index, y=prices[ohlcv[4]], **styles['volume']))
72 |
73 | # add factors
74 | for col in factors.columns:
75 | trace_type = trace_types.get(col, 'Scatter')
76 | if trace_type is None:
77 | continue
78 | style = styles.get(col, {})
79 | style['name'] = style.get('name', col)
80 | fig.add_trace(getattr(go, trace_type)(x=index, y=factors[col], **style))
81 |
82 | new_axis = dict(anchor="free", overlaying="y", side="right", position=1)
83 | alpha_ordered_axises = list(y_axes)
84 | alpha_ordered_axises.sort()
85 | for y_axis in alpha_ordered_axises:
86 | fig.update_layout(**{y_axis: new_axis})
87 | new_axis['position'] -= 0.03
88 | x_right = new_axis['position'] + 0.03
89 |
90 | fig.update_layout(xaxis=dict(domain=[0, x_right]))
91 | fig.update_xaxes(rangeslider=dict(visible=False))
92 | fig.update_yaxes(showgrid=False, scaleanchor="x", scaleratio=1)
93 | fig.update_layout(legend=dict(xanchor='right', x=x_right, y=1, bgcolor='rgba(0,0,0,0)'))
94 | fig.update_layout(height=styles['height'], barmode='group', bargap=0.5, margin={'t': 50},
95 | title=asset)
96 |
97 | if inline:
98 | fig.show()
99 |
100 | return figs
101 |
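A sketch of calling plot_chart; the data wiring is hypothetical, but both frames share a
(date, asset) MultiIndex with a categorical asset level, and `ohlcv` names the price
columns in order:

    figs = plot_chart(
        df_prices,                                        # prices from the loader/engine
        ohlcv=('open', 'high', 'low', 'close', 'volume'),
        df_factor=df_factor,                              # factor values to overlay
        styles={'my_factor': {'yaxis': 'y3'}},            # give one factor its own axis
        inline=False)
    figs['AAPL'].show()
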
--------------------------------------------------------------------------------
/spectre/plotting/factor_diagram.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from itertools import cycle, islice
8 |
9 |
10 | def plot_factor_diagram(factor):
11 | import plotly.graph_objects as go
12 | from ..factors import BaseFactor, CustomFactor
13 | from ..factors import ColumnDataFactor, DatetimeDataFactor
14 |
15 | color = [
16 | "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)", "rgba(44, 160, 44, 0.8)",
17 | "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)", "rgba(140, 86, 75, 0.8)",
18 | "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)", "rgba(188, 189, 34, 0.8)",
19 | "rgba(23, 190, 207, 0.8)", "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)",
20 | "rgba(44, 160, 44, 0.8)", "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)",
21 | "rgba(140, 86, 75, 0.8)", "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)",
22 | "rgba(188, 189, 34, 0.8)", "rgba(23, 190, 207, 0.8)", "rgba(31, 119, 180, 0.8)",
23 | "rgba(255, 127, 14, 0.8)", "rgba(44, 160, 44, 0.8)", "rgba(214, 39, 40, 0.8)",
24 | "rgba(148, 103, 189, 0.8)", "rgba(140, 86, 75, 0.8)", "rgba(227, 119, 194, 0.8)",
25 | "rgba(127, 127, 127, 0.8)", "rgba(188, 189, 34, 0.8)", "rgba(23, 190, 207, 0.8)",
26 | "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)", "rgba(44, 160, 44, 0.8)",
27 | "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)", "magenta",
28 | "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)", "rgba(188, 189, 34, 0.8)",
29 | "rgba(23, 190, 207, 0.8)", "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)",
30 | "rgba(44, 160, 44, 0.8)", "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)",
31 | "rgba(140, 86, 75, 0.8)", "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)"
32 | ]
33 |
34 | factor_id = dict()
35 | label = []
36 | source = []
37 | target = []
38 | value = []
39 | line_label = []
40 |
41 | def add_node(this, parent_label_id, parent_label, parent_win):
42 | class_id = id(this)
43 |
44 | if class_id in factor_id:
45 | this_label_id = factor_id[class_id]
46 | else:
47 | this_label_id = len(label)
48 | if isinstance(this, ColumnDataFactor):
49 | label.append(this.inputs[0])
50 |             elif isinstance(this, DatetimeDataFactor):
51 | label.append(this.attr)
52 | else:
53 | label.append(type(this).__name__)
54 |
55 | if parent_label_id is not None:
56 | source.append(parent_label_id)
57 | target.append(this_label_id)
58 | value.append(parent_win)
59 | line_label.append(parent_label)
60 |
61 | if class_id in factor_id:
62 | return
63 |
64 | if isinstance(this, CustomFactor):
65 | this_win = this.win
66 | else:
67 | this_win = 1
68 |
69 | factor_id[class_id] = this_label_id
70 | if isinstance(this, CustomFactor):
71 | if this.inputs:
72 | for upstream in this.inputs:
73 | if isinstance(upstream, BaseFactor):
74 | add_node(upstream, this_label_id, 'inputs', this_win)
75 |
76 | if this._mask is not None:
77 | add_node(this._mask, this_label_id, 'mask', this_win)
78 |
79 | add_node(factor, None, None, None)
80 |
81 | fig = go.Figure(data=[go.Sankey(
82 | valueformat=".0f",
83 | valuesuffix="win",
84 | node=dict(
85 | pad=15,
86 | thickness=15,
87 | line=dict(color="black", width=0.5),
88 | label=label,
89 | color=list(islice(cycle(color), len(label)))
90 | ),
91 | # Add links
92 | link=dict(
93 | source=source,
94 | target=target,
95 | value=value,
96 | label=line_label
97 | ))])
98 |
99 | fig.update_layout(title_text="Factor Diagram")
100 | fig.show()
101 |
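A sketch of plot_factor_diagram on a small composite factor; the Sankey link widths
correspond to each upstream node's window size (the import path is an assumption):

    from spectre.factors import MA, OHLCV

    fast_minus_slow = MA(5, inputs=(OHLCV.close,)) - MA(20, inputs=(OHLCV.close,))
    plot_factor_diagram(fast_minus_slow)  # renders the factor's dependency tree
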
--------------------------------------------------------------------------------
/spectre/plotting/returns_chart.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import math
8 | from itertools import cycle
9 | import sys
10 |
11 |
12 | DEFAULT_COLORS = [
13 | 'rgb(99, 110, 250)', 'rgb(239, 85, 59)', 'rgb(0, 204, 150)',
14 | 'rgb(171, 99, 250)', 'rgb(255, 161, 90)', 'rgb(25, 211, 243)'
15 | ]
16 |
17 |
18 | def plot_quantile_and_cumulative_returns(factor_data, mean_ret):
19 | """
20 | Plotly Installation:
21 | https://github.com/plotly/plotly.py#jupyterlab-support-python-35
22 | """
23 | quantiles = mean_ret.index
24 |
25 | if 'plotly.graph_objects' not in sys.modules:
26 | print('Importing plotly, it may take a while...')
27 | import plotly.graph_objects as go
28 | import plotly.subplots as subplots
29 |
30 | x = quantiles
31 | factors = mean_ret.columns.levels[0]
32 | periods = list(mean_ret.columns.levels[1])
33 | periods.sort(key=lambda cn: int(cn[:-1]))
34 | rows = math.ceil(len(factors))
35 |
36 | colors = dict(zip(periods + ['to'], cycle(DEFAULT_COLORS)))
37 | quantile_styles = {
38 | period: {'name': period, 'legendgroup': period,
39 |                  'hovertemplate': 'Quantile:%{x}<br>'
40 | 'Return: %{y:.2f}bps ±%{error_y.array:.2f}bps',
41 | 'marker': {'color': colors[period]}}
42 | for period in periods
43 | }
44 | cumulative_styles = {
45 | period: {'name': period, 'mode': 'lines', 'legendgroup': period, 'showlegend': False,
46 |                  'hovertemplate': 'Date:%{x}<br>'
47 | 'Return: %{y:.3f}%',
48 | 'marker': {'color': colors[period]}}
49 | for period in periods
50 | }
51 | turnover_styles = {'opacity': 0.2, 'name': 'turnover', 'legendgroup': 'turnover',
52 | 'marker': {'color': colors['to']}}
53 |
54 | specs = [[{}, {"secondary_y": True}]] * rows
55 | fig = subplots.make_subplots(
56 | rows=rows, cols=2,
57 | vertical_spacing=0.06 / rows,
58 | horizontal_spacing=0.06,
59 | specs=specs,
60 | subplot_titles=['Quantile Return', 'Portfolio cumulative returns'],
61 | )
62 |
63 | mean_ret = mean_ret * 10000
64 | for i, factor in enumerate(factors):
65 | row = i + 1
66 | weight_col = (factor, 'factor_weight')
67 | weighted = factor_data['Returns'].multiply(factor_data[weight_col], axis=0)
68 | factor_return = weighted.groupby(level='date').sum()
69 | for j, period in enumerate(periods):
70 | y = mean_ret.loc[:, (factor, period, 'mean')]
71 | err_y = mean_ret.loc[:, (factor, period, 'sem')]
72 | fig.add_trace(go.Bar(
73 | x=x, y=y, error_y=dict(type='data', array=err_y, thickness=0.2),
74 | yaxis='y1', **quantile_styles[period]
75 | ), row=row, col=1)
76 | quantile_styles[period]['showlegend'] = False
77 |
78 | open_period = period
79 | if period.endswith('D'):
80 | open_period = 'b' + open_period
81 | try:
82 | cum_ret = factor_return[period].resample(open_period).mean().dropna()
83 | except ValueError as e:
84 |             print("pandas re-sampling failed, try setting "
85 | "`engine.timezone = 'your local timezone'`")
86 | cum_ret = factor_return[period].resample(period).mean().dropna()
87 | cum_ret = (cum_ret + 1).cumprod() * 100 - 100
88 | fig.add_trace(go.Scatter(
89 | x=cum_ret.index, y=cum_ret.values, yaxis='y2', **cumulative_styles[period]
90 | ), row=row, col=2)
91 |
92 | fig.update_xaxes(type="category", row=row, col=1)
93 |
94 | fig.add_shape(go.layout.Shape(
95 | type="line", line=dict(width=1),
96 | y0=0, y1=0, x0=factor_return.index[0], x1=factor_return.index[-1],
97 | ), row=row, col=2)
98 |
99 | weight_diff = factor_data[weight_col].unstack(level=[1]).diff()
100 | to = weight_diff.abs().sum(axis=1) * 100
101 | resample = int(len(to) / 64)
102 | if resample > 0:
103 | to = to.fillna(0).rolling(resample).mean()[::resample]
104 | fig.add_trace(go.Bar(x=to.index, y=to.values, **turnover_styles),
105 | secondary_y=True, row=row, col=2)
106 | turnover_styles['showlegend'] = False
107 |
108 | fig.update_yaxes(title_text=factor, row=row, col=1, matches='y1')
109 | fig.update_yaxes(row=row, col=2, ticksuffix='%')
110 | fig.update_yaxes(row=row, col=2, secondary_y=False, matches='y2')
111 |
112 | fig.update_layout(height=300 * rows, barmode='group', bargap=0.5, margin={'t': 50})
113 | return fig
114 |
115 |
116 | def cumulative_returns_fig(returns, positions, transactions, benchmark, annual_risk_free, start=0):
117 | from ..trading import turnover, sharpe_ratio, drawdown, annual_volatility
118 |
119 | import plotly.graph_objects as go
120 | import plotly.subplots as subplots
121 |
122 | fig = subplots.make_subplots(specs=[[{"secondary_y": True}]])
123 |
124 | cum_ret = returns.iloc[start:]
125 | cum_ret = (cum_ret + 1).cumprod()
126 | fig.add_trace(go.Scatter(x=cum_ret.index, y=cum_ret.values * 100 - 100, name='portfolio',
127 |                              hovertemplate='Date:%{x}<br>Return: %{y:.3f}%'))
128 | fig.add_shape(go.layout.Shape(y0=0, y1=0, x0=cum_ret.index[0], x1=cum_ret.index[-1],
129 | type="line", line=dict(width=1)))
130 |
131 | if benchmark is not None:
132 | cum_bench = benchmark.iloc[start:]
133 | cum_bench = (cum_bench + 1).cumprod()
134 | fig.add_trace(go.Scatter(x=cum_bench.index, y=cum_bench.values * 100 - 100,
135 | name='benchmark', line=dict(width=0.5)))
136 |
137 | fig.add_shape(go.layout.Shape(
138 | type="rect", xref="x", yref="paper", opacity=0.5, line_width=0,
139 | fillcolor="LightGoldenrodYellow", layer="below",
140 | y0=0, y1=1, x0=cum_ret.idxmax(), x1=cum_ret[cum_ret.idxmax():].idxmin(),
141 | ))
142 |
143 | to = turnover(positions, transactions).iloc[start:] * 100
144 | resample = int(len(to) / 126)
145 | if resample > 0:
146 | to = to.fillna(0).rolling(resample).mean()[::resample]
147 | fig.add_trace(go.Bar(x=to.index, y=to.values, opacity=0.2, name='turnover'),
148 | secondary_y=True)
149 |
150 | sr = sharpe_ratio(returns, annual_risk_free)
151 | dd, ddd = drawdown(cum_ret)
152 | mdd = abs(dd.min())
153 | mdd_dur = ddd.max()
154 | vol = annual_volatility(returns) * 100
155 |
156 | if benchmark is not None:
157 | bench_sr = sharpe_ratio(benchmark, annual_risk_free)
158 | bench_vol = annual_volatility(benchmark) * 100
159 | else:
160 | bench_sr = 0
161 | bench_vol = 0
162 |
163 | ann = go.layout.Annotation(
164 | x=0.01, y=0.98, xref="paper", yref="paper",
165 | showarrow=False, borderwidth=1, bordercolor='black', align='left',
166 |         text="Overall (portfolio/benchmark)<br>"
167 |              "SharpeRatio: {:.3f}/{:.3f}<br>"
168 |              "MaxDrawDown: {:.2f}%, {} Days<br>"
169 |              "AnnualVolatility: {:.2f}%/{:.2f}%"
170 | .format(sr, bench_sr, mdd * 100, mdd_dur, vol, bench_vol),
171 | )
172 |
173 | fig.update_layout(height=400, annotations=[ann], margin={'t': 50})
174 | fig.update_xaxes(tickformat='%Y-%m-%d')
175 | fig.update_yaxes(title_text='cumulative return', ticksuffix='%', secondary_y=False)
176 | fig.update_yaxes(title_text='turnover', ticksuffix='%', secondary_y=True)
177 | return fig
178 |
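A sketch of cumulative_returns_fig fed from a finished backtest; the exact attribute and
column names on `results` are assumptions:

    results = run_backtest(loader, MyAlg, '2018-01-01', '2019-01-01')
    fig = cumulative_returns_fig(
        results.returns['returns'],       # daily portfolio returns (assumed name)
        results.positions, results.transactions,
        benchmark=None, annual_risk_free=0.04)
    fig.show()
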
--------------------------------------------------------------------------------
/spectre/trading/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Trading algorithm architecture
3 | ------------------------------
4 |
5 | +----------+
6 | | +<---Network
7 | |Broker API|
8 | Live | +<--+
9 | +----+-----+ |
10 | | |
11 | +------+-------+ |
12 | +----fire-EveryBarData-event------+LiveDataLoader| |
13 | | +------+-------+ |
14 | | | |
15 | | +------------------+ +---------+ +-----+------+ |
16 | | |MarketEventManager| |Algorithm| |FactorEngine| |
17 | | +---------+--------+ +----+----+ +-----+------+ |
18 | | | | | |
19 | | +----+-----+ +-----+-----+ +-+-+ |
20 | +----->+fire_event+-->+_run_engine+--->+run+---+ |
21 | +----+-----+ +-----+----++ +---+ | |
22 | | | ^ | |
23 | | | +--save-data--+ |
24 | | | |
25 | time +----+-----+ +----+----+ --------+ |
26 | trigger->+fire close+--->+rebalance+--->+Blotter+----+
27 | +----------+ +---------+ +-------+
28 |
29 |
30 | +-------------+
31 | Back-test |CsvDataLoader|
32 | +-----+-------+
33 | |
34 | +----------------------+ +---------+ +-----+------+
35 | |SimulationEventManager| |Algorithm| |FactorEngine|
36 | +------------+---------+ +----+----+ +-----+------+
37 | | | |
38 | +------+------+ +----+-----+ +-+-+
39 | |loop data row+--->+run_engine+--->+run+---+
40 | +------+---+--+ ++---+----++ +---+ |
41 | | ^ | | ^ |
42 | | +-return-+ +---return----+
43 | | |
44 | | +----+----+ +-----------------+
45 | +---------->+rebalance+-->+SimulationBlotter|
46 | +---------+ +-----------------+
47 |
48 |
49 | Pseudo-code for back-test and live
50 | ----------------------------------
51 |
52 | class MyAlg(trading.CustomAlgorithm):
53 | def initialize(self):
54 | engine = self.get_factor_engine()
55 | factor = ....
56 |         engine.add(factor, 'your_factor')
57 |
58 | # 10000 ns before market close
59 | self.schedule_rebalance(trading.event.MarketClose(self.rebalance, -10000))
60 |
61 | self.blotter.set_commission() # only works on back-test
62 | self.blotter.set_slippage() # only works on back-test
63 |
64 | def rebalance(self, data, history):
65 | weight = data.your_factor / data.your_factor.sum()
66 | self.order_to_percent(data.index, weight)
67 |
68 | record(...)
69 |
70 | def terminate(self):
71 | plot()
72 |
73 | # Back-test
74 | -----------------
75 | loader = spectre.data.CsvDirLoader(...)
76 | blotter = spectre.trading.SimulationBlotter(loader)
77 | evt_mgr = spectre.trading.SimulationEventManager()
78 |     alg = MyAlg(blotter, main=loader)
79 | evt_mgr.subscribe(alg)
80 | evt_mgr.subscribe(blotter)
81 | evt_mgr.run('2018-01-01', '2019-01-01')
82 |
83 | ## Or the helper function:
84 | spectre.trading.run_backtest(loader, MyAlg, '2018-01-01', '2019-01-01')
85 |
86 | # Live
87 | ----------------
88 | class YourBrokerAPI:
89 | class LiveDataLoader(EventReceiver, DataLoader):
90 | def on_run():
91 | self.schedule(event.Always(read_data))
92 | def read_data(self):
93 | api.asio.read()
94 | agg_data(_cache)
95 | _cache.resample(self.rule)
96 | if new_bar:
97 | self.fire_event(event.EveryBarData)
98 | def load(...):
99 | return self._cache[xx:xx]
100 | ...
101 |
102 | broker_api = YourBrokerAPI()
103 | loader = broker_api.LiveDataLoader(rule='5mins')
104 | blotter = broker_api.LiveBlotter()
105 |
106 | evt_mgr = spectre.trading.MarketEventManager(calendar_2020)
107 | evt_mgr.subscribe(loader)
108 |
109 | alg = MyAlg(blotter, main=loader)
110 | evt_mgr.subscribe(alg)
111 | evt_mgr.run()
112 |
113 | """
114 | from .event import (
115 | Event,
116 | EveryBarData,
117 | Always,
118 | MarketOpen,
119 | MarketClose,
120 | EventReceiver,
121 | EventManager,
122 | MarketEventManager,
123 | )
124 | from .calendar import (
125 | Calendar,
126 | CNCalendar,
127 | JPCalendar,
128 | )
129 | from .algorithm import (
130 | CustomAlgorithm,
131 | SimulationEventManager
132 | )
133 | from .stopmodel import (
134 | StopModel,
135 | TrailingStopModel,
136 | PnLDecayTrailingStopModel,
137 | TimeDecayTrailingStopModel
138 | )
139 | from .position import (
140 | Position,
141 | )
142 | from .portfolio import (
143 | Portfolio,
144 | )
145 | from .blotter import (
146 | BaseBlotter,
147 | SimulationBlotter,
148 | ManualBlotter,
149 | CommissionModel,
150 | DailyCurbModel,
151 | )
152 | from .metric import (
153 | drawdown,
154 | sharpe_ratio,
155 | turnover,
156 | annual_volatility,
157 | )
158 |
159 |
160 | def run_backtest(loader: 'DataLoader', alg_type: 'Type[CustomAlgorithm]', start, end,
161 | delay_factor=True, ohlcv=None):
162 |     # force python to free memory, otherwise we may encounter cuda out-of-memory errors
163 | import gc
164 | import pandas as pd
165 | gc.collect()
166 |
167 | _blotter = SimulationBlotter(loader, start=pd.to_datetime(start, utc=True), ohlcv=ohlcv)
168 | evt_mgr = SimulationEventManager()
169 | alg = alg_type(_blotter, main=loader)
170 | evt_mgr.subscribe(_blotter)
171 | evt_mgr.subscribe(alg)
172 | evt_mgr.run(start, end, delay_factor)
173 |
174 | return alg.results
175 |
176 |
177 | def get_algorithm_data(loader: 'DataLoader', alg_type: 'Type[CustomAlgorithm]',
178 | start, end, delay_factor=True):
179 | import pandas as pd
180 | start, end = pd.to_datetime(start, utc=True), pd.to_datetime(end, utc=True)
181 |
182 | _blotter = SimulationBlotter(loader, start=start)
183 | evt_mgr = SimulationEventManager()
184 | alg = alg_type(_blotter, main=loader)
185 | evt_mgr.subscribe(_blotter)
186 | evt_mgr.subscribe(alg)
187 | alg.on_run()
188 |
189 | return alg.run_engine(start, end, delay_factor)
190 |
--------------------------------------------------------------------------------
/spectre/trading/calendar.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from collections import defaultdict
8 | from typing import Dict
9 | import pandas as pd
10 |
11 |
12 | class Calendar:
13 | """
14 | Usage:
15 |     Call build() first to get a business-day calendar,
16 |     then manually add holidays by calling set_as_holiday().
17 |     For half-day sessions, use remove_events() to remove all events that day, then add_event() manually.
18 |     The US holiday calendar can be found at https://iextrading.com/trading/
19 | """
20 |
21 | def __init__(self) -> None:
22 | self.events = defaultdict(list)
23 | self.timezone = None
24 |
25 | def build(self, start: str, end: str, daily_events: Dict[str, str], tz='UTC', freq='B',
26 | pop_passed=True):
27 |         """ build("2020", "2021", {'Open': '9:00:00', 'Close': '15:00:00'}) """
28 | self.timezone = tz
29 | days = pd.date_range(pd.Timestamp(start, tz=tz).normalize(),
30 | pd.Timestamp(end, tz=tz).normalize(),
31 | tz=tz, freq=freq)
32 | if len(days) == 0:
33 |             raise ValueError("Empty date range between now({}) and end({})".format(
34 | pd.Timestamp.now(tz=tz).normalize(), end))
35 |
36 | self.events = {name: [day + pd.Timedelta(time) for day in days]
37 | for name, time in daily_events.items()}
38 |
39 | if pop_passed:
40 | for k, _ in self.events.items():
41 | self.pop_passed(k)
42 |
43 | def add_event(self, event: str, datetime: pd.Timestamp):
44 | self.events[event].append(datetime)
45 | self.events[event].sort()
46 |
47 | def remove_events(self, date: pd.Timestamp):
48 | self.events = {
49 | event: [dt for dt in dts if dt.normalize() != date]
50 | for event, dts in self.events.items()
51 | }
52 |
53 | def set_as_holiday(self, date: pd.Timestamp):
54 | return self.remove_events(date)
55 |
56 |     def hr_now(self):
57 |         """ Return the current time """
58 |         # TODO: use a high-resolution clock
59 |         return pd.Timestamp.now(self.timezone)
60 |
61 |     def pop_passed(self, event_name):
62 |         """ Remove events that have already passed """
63 |         now = self.hr_now()
64 |         # each event occurs at most once per day, so this loop cannot
65 |         # remove more than the pending occurrences
66 |         dts = self.events[event_name]
67 |         # guard against an empty list so a fully consumed event
68 |         # does not raise IndexError
69 |         while dts and dts[0] <= now:
70 |             del dts[0]
71 |         return self
72 |
73 |     def today_next(self):
74 |         """ Return today's remaining events """
75 |         now = self.hr_now()
76 |         return {
77 |             event: dts[0]
78 |             for event, dts in self.events.items()
79 |             if dts and dts[0].normalize() == now.normalize()
80 |         }
81 |
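# Usage sketch (illustrative, not part of this module): build a calendar and
# convert one hypothetical half-day session to an early 13:00 close. Relies on
# the module's existing `import pandas as pd`.
#
#   cal = Calendar()
#   cal.build('2020-01-01', '2021-01-01',
#             {'Open': '9:30:00', 'Close': '16:00:00'},
#             tz='America/New_York', pop_passed=False)
#   half_day = pd.Timestamp('2020-11-27', tz='America/New_York')
#   cal.remove_events(half_day)  # drop that day's regular events
#   cal.add_event('Open', half_day + pd.Timedelta('9:30:00'))
#   cal.add_event('Close', half_day + pd.Timedelta('13:00:00'))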
82 |
83 | class CNCalendar(Calendar):
84 | """
85 | CN holiday calendar: http://www.sse.com.cn/disclosure/dealinstruc/closed/
86 | """
87 |     # updated manually every year
88 | closed = [
89 | *pd.date_range('2020-06-25', '2020-06-28', freq='D'),
90 | *pd.date_range('2020-10-01', '2020-10-08', freq='D'),
91 |
92 | *pd.date_range('2021-01-01', '2021-01-03', freq='D'),
93 | *pd.date_range('2021-02-11', '2021-02-17', freq='D'),
94 | *pd.date_range('2021-04-03', '2021-04-05', freq='D'),
95 | *pd.date_range('2021-05-01', '2021-05-05', freq='D'),
96 | *pd.date_range('2021-06-12', '2021-06-14', freq='D'),
97 | *pd.date_range('2021-09-19', '2021-09-21', freq='D'),
98 | *pd.date_range('2021-10-01', '2021-10-07', freq='D'),
99 |
100 | *pd.date_range('2022-01-01', '2022-01-03', freq='D'),
101 | *pd.date_range('2022-01-31', '2022-02-06', freq='D'),
102 | *pd.date_range('2022-04-03', '2022-04-05', freq='D'),
103 | *pd.date_range('2022-04-30', '2022-05-04', freq='D'),
104 | *pd.date_range('2022-06-03', '2022-06-05', freq='D'),
105 | *pd.date_range('2022-09-10', '2022-09-12', freq='D'),
106 | *pd.date_range('2022-10-01', '2022-10-07', freq='D'),
107 |
108 | *pd.date_range('2023-01-01', '2023-01-02', freq='D'),
109 | *pd.date_range('2023-01-21', '2023-01-27', freq='D'),
110 | *pd.date_range('2023-04-05', '2023-04-05', freq='D'),
111 | *pd.date_range('2023-04-29', '2023-05-03', freq='D'),
112 | *pd.date_range('2023-06-22', '2023-06-24', freq='D'),
113 | *pd.date_range('2023-09-29', '2023-10-06', freq='D'),
114 |
115 | *pd.date_range('2024-01-01', '2024-01-01', freq='D'),
116 | *pd.date_range('2024-02-09', '2024-02-17', freq='D'),
117 | *pd.date_range('2024-04-04', '2024-04-06', freq='D'),
118 | *pd.date_range('2024-05-01', '2024-05-05', freq='D'),
119 | *pd.date_range('2024-06-10', '2024-06-10', freq='D'),
120 | *pd.date_range('2024-09-15', '2024-09-17', freq='D'),
121 | *pd.date_range('2024-10-01', '2024-10-07', freq='D'),
122 | ]
123 |
124 | daily_events = {
125 | 'DayStart': '00:00:00',
126 | 'PreOpen': '9:15:00',
127 | 'Open': '9:30:00',
128 | 'Lunch': '11:30:00',
129 | 'LunchEnd': '13:00:00',
130 | 'Close': '15:00:00',
131 | 'DayEnd': '23:59:59'
132 | }
133 |
134 | def __init__(self, start=None, pop_passed=True):
135 | super().__init__()
136 | timezone = 'Asia/Shanghai'
137 | if start is None:
138 |             start = pd.Timestamp.now(timezone).normalize()  # self.timezone is not set until build()
139 | assert start.year <= CNCalendar.closed[-1].year
140 | self.build(
141 | start=str(start),
142 | end=str(CNCalendar.closed[-1].year + 1),
143 | daily_events=self.daily_events,
144 | tz=timezone, pop_passed=pop_passed)
145 | for d in CNCalendar.closed:
146 | self.set_as_holiday(d.tz_localize(timezone))
147 |
148 |
149 | class JPCalendar(Calendar):
150 | """
151 | JP holiday calendar: https://www.jpx.co.jp/corporate/about-jpx/calendar/index.html
152 | """
153 | closed = [
154 | *pd.date_range(f'{pd.Timestamp.now().year}-01-01',
155 | f'{pd.Timestamp.now().year}-01-03', freq='D'),
156 | *pd.date_range(f'{pd.Timestamp.now().year+1}-01-01',
157 | f'{pd.Timestamp.now().year+1}-01-03', freq='D'),
158 | *pd.date_range(f'{pd.Timestamp.now().year}-12-31',
159 | f'{pd.Timestamp.now().year}-12-31', freq='D'),
160 | *pd.date_range(f'{pd.Timestamp.now().year+1}-12-31',
161 | f'{pd.Timestamp.now().year+1}-12-31', freq='D'),
162 |
163 |         # updated manually every year
164 | *pd.date_range('2023-01-09', '2023-01-09', freq='D'),
165 | *pd.date_range('2023-02-11', '2023-02-11', freq='D'),
166 | *pd.date_range('2023-02-23', '2023-02-23', freq='D'),
167 | *pd.date_range('2023-03-21', '2023-03-21', freq='D'),
168 | *pd.date_range('2023-04-29', '2023-04-29', freq='D'),
169 | *pd.date_range('2023-05-03', '2023-05-05', freq='D'),
170 | *pd.date_range('2023-07-17', '2023-07-17', freq='D'),
171 | *pd.date_range('2023-08-11', '2023-08-11', freq='D'),
172 | *pd.date_range('2023-09-18', '2023-09-18', freq='D'),
173 | *pd.date_range('2023-09-23', '2023-09-23', freq='D'),
174 | *pd.date_range('2023-10-09', '2023-10-09', freq='D'),
175 | *pd.date_range('2023-11-03', '2023-11-03', freq='D'),
176 | *pd.date_range('2023-11-23', '2023-11-23', freq='D'),
177 |
178 | *pd.date_range('2024-01-08', '2024-01-08', freq='D'),
179 | *pd.date_range('2024-02-11', '2024-02-12', freq='D'),
180 | *pd.date_range('2024-02-23', '2024-02-23', freq='D'),
181 | *pd.date_range('2024-03-20', '2024-03-20', freq='D'),
182 | *pd.date_range('2024-04-29', '2024-04-29', freq='D'),
183 | *pd.date_range('2024-05-03', '2024-05-06', freq='D'),
184 | *pd.date_range('2024-07-15', '2024-07-15', freq='D'),
185 | *pd.date_range('2024-08-11', '2024-08-12', freq='D'),
186 | *pd.date_range('2024-09-16', '2024-09-16', freq='D'),
187 | *pd.date_range('2024-09-22', '2024-09-23', freq='D'),
188 | *pd.date_range('2024-10-14', '2024-10-14', freq='D'),
189 | *pd.date_range('2024-11-03', '2024-11-04', freq='D'),
190 | *pd.date_range('2024-11-23', '2024-11-23', freq='D'),
191 | ]
192 |
193 | daily_events = {
194 | 'DayStart': '00:00:00',
195 | 'PreOpen': '8:00:00',
196 | 'Open': '9:00:00',
197 | 'Lunch': '11:30:00',
198 | 'LunchEnd': '12:30:00',
199 | 'Close': '15:00:00',
200 | 'DayEnd': '23:59:59'
201 | }
202 |
203 | def __init__(self, start=None, pop_passed=True):
204 | super().__init__()
205 | timezone = 'Asia/Tokyo'
206 | if start is None:
207 |             start = pd.Timestamp.now(timezone).normalize()  # self.timezone is not set until build()
208 | assert start.year <= self.closed[-1].year
209 | self.build(
210 | start=str(start),
211 | end=str(self.closed[-1].year + 1),
212 | daily_events=self.daily_events,
213 | tz=timezone, pop_passed=pop_passed)
214 | for d in self.closed:
215 | self.set_as_holiday(d.tz_localize(timezone))
216 |
--------------------------------------------------------------------------------
/spectre/trading/event.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import time
8 | from typing import Type
9 | import pandas as pd
10 | from .calendar import Calendar
11 |
12 |
13 | class Event:
14 | def __init__(self, callback) -> None:
15 | self.callback = callback # Callable[[Object], None]
16 |
17 | def on_schedule(self, evt_mgr):
18 | pass
19 |
20 | def should_trigger(self) -> bool:
21 | raise NotImplementedError("abstractmethod")
22 |
23 |
24 | class EveryBarData(Event):
25 | """This event is triggered passively"""
26 | def should_trigger(self) -> bool:
27 | return False
28 |
29 |
30 | class Always(Event):
31 |     """The Always event is useful for live data IO functions, e.g. asyncio's run_until_complete()"""
32 | def should_trigger(self) -> bool:
33 | return True
34 |
35 |
36 | class CalendarEvent(Event):
37 |     """ This event is for live trading; it will not work in back-testing """
38 | def __init__(self, calendar_event_name, callback, offset_ns=0) -> None:
39 | super().__init__(callback)
40 | self.offset = offset_ns
41 | self.calendar = None
42 | self.event_name = calendar_event_name
43 | self.trigger_time = None
44 |
45 | def on_schedule(self, evt_mgr):
46 | try:
47 | self.calendar = evt_mgr.calendar
48 | self.calculate_range()
49 | except AttributeError:
50 | pass
51 |
52 | def calculate_range(self):
53 | self.trigger_time = self.calendar.events[self.event_name][0] + \
54 | pd.Timedelta(self.offset, unit='ns')
55 |
56 | def should_trigger(self) -> bool:
57 | if self.calendar.hr_now() >= self.trigger_time:
58 | self.calendar.pop_passed(self.event_name)
59 | self.calculate_range()
60 | return True
61 | return False
62 |
63 |
64 | class MarketOpen(CalendarEvent):
65 | """ Works on both live and backtest """
66 | def __init__(self, callback, offset_ns=0) -> None:
67 | super().__init__('Open', callback, offset_ns)
68 |
69 |
70 | class MarketClose(CalendarEvent):
71 | """ Works on both live and backtest """
72 | def __init__(self, callback, offset_ns=0) -> None:
73 | super().__init__('Close', callback, offset_ns)
74 |
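# Scheduling sketch (illustrative): inside an EventReceiver such as a
# CustomAlgorithm, the offset is in nanoseconds, so the hypothetical
# `self.rebalance` callback below fires 5 minutes after the open:
#
#   self.schedule(MarketOpen(self.rebalance, offset_ns=5 * 60 * 10 ** 9))
#   self.schedule(MarketClose(self.record_metrics))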
75 |
76 | # ----------------------------------------------------------------
77 |
78 |
79 | class EventReceiver:
80 | def __init__(self) -> None:
81 | self._event_manager = None
82 |
83 | def unsubscribe(self):
84 | if self._event_manager is not None:
85 | self._event_manager.unsubscribe(self)
86 |
87 | def schedule(self, evt: Event):
88 | self._event_manager.schedule(self, evt)
89 |
90 | def stop_event_manager(self):
91 | self._event_manager.stop()
92 |
93 | def fire_event(self, evt_type: Type[Event]):
94 | self._event_manager.fire_event(self, evt_type)
95 |
96 | def on_run(self):
97 | pass
98 |
99 | def on_end_of_run(self):
100 | pass
101 |
102 |
103 | class EventManager:
104 | def __init__(self) -> None:
105 | self._subscribers = dict()
106 | self._stop = False
107 |
108 | def subscribe(self, receiver: EventReceiver):
109 | assert receiver not in self._subscribers, 'Duplicate subscribe'
110 | self._subscribers[receiver] = []
111 | receiver._event_manager = self
112 |
113 | def unsubscribe(self, receiver: EventReceiver):
114 |         assert receiver in self._subscribers, 'Subscriber does not exist'
115 | del self._subscribers[receiver]
116 | receiver._event_manager = None
117 |
118 | def schedule(self, receiver: EventReceiver, event: Event):
119 | self._subscribers[receiver].append(event)
120 | event.on_schedule(self)
121 |
122 | def fire_event(self, source, evt_type: Type[Event]):
123 | for r, events in self._subscribers.items():
124 | for evt in events:
125 | if isinstance(evt, evt_type):
126 | evt.callback(source)
127 |
128 | def stop(self):
129 | self._stop = True
130 |
131 | def _beg_run(self):
132 | if not self._subscribers:
133 |             raise ValueError("At least one subscriber is required.")
134 |
135 | for r, events in self._subscribers.items():
136 | # clear scheduled events
137 | events.clear()
138 | r.on_run()
139 |
140 | def _end_run(self):
141 | for r in self._subscribers.keys():
142 | r.on_end_of_run()
143 |
144 | def _run_once(self):
145 | for r, events in self._subscribers.items():
146 | for event in events:
147 | if event.should_trigger():
148 | event.callback(self)
149 |
150 | def run(self, *params):
151 | self._beg_run()
152 |
153 | while not self._stop:
154 | time.sleep(0.001)
155 | self._run_once()
156 |
157 | self._end_run()
158 |
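# Minimal end-to-end sketch (illustrative): a hypothetical receiver that stops
# the loop on its first Always tick.
#
#   class TickOnce(EventReceiver):
#       def on_run(self):
#           self.schedule(Always(self.tick))
#
#       def tick(self, _source):
#           print('tick')
#           self.stop_event_manager()
#
#   mgr = EventManager()
#   mgr.subscribe(TickOnce())
#   mgr.run()  # prints 'tick' once, then _end_run() is called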
159 |
160 | # ----------------------------------------------------------------
161 |
162 |
163 | class MarketEventManager(EventManager):
164 | def __init__(self, calendar: Calendar) -> None:
165 | self.calendar = calendar
166 | super().__init__()
167 |
168 |
--------------------------------------------------------------------------------
/spectre/trading/metric.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import numpy as np
8 | import pandas as pd
9 |
10 |
11 | def drawdown(cumulative_returns):
12 | max_ret = cumulative_returns.cummax()
13 | dd = cumulative_returns / max_ret - 1
14 | dd_group = 0
15 |
16 | def drawdown_split(x):
17 | nonlocal dd_group
18 | if dd[x] == 0:
19 | dd_group += 1
20 | return dd_group
21 |
22 | dd_duration = dd.groupby(drawdown_split).cumcount()
23 | return dd, dd_duration
24 |
25 |
26 | def sharpe_ratio(daily_returns: pd.Series, annual_risk_free_rate):
27 | risk_adj_ret = daily_returns.sub(annual_risk_free_rate/252)
28 | annual_factor = np.sqrt(252)
29 | return annual_factor * risk_adj_ret.mean() / risk_adj_ret.std(ddof=1)
30 |
31 |
32 | def turnover(positions, transactions):
33 | if transactions.shape[0] == 0:
34 | return transactions.amount
35 | value_trades = (transactions.amount * transactions.fill_price).abs()
36 | value_trades = value_trades.groupby(value_trades.index.normalize()).sum()
37 | return value_trades / positions.value.sum(axis=1)
38 |
39 |
40 | def annual_volatility(daily_returns: pd.Series):
41 | volatility = daily_returns.std(ddof=1)
42 | annual_factor = np.sqrt(252)
43 | return annual_factor * volatility
44 |
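# Worked sketch (illustrative): the metrics applied to a tiny synthetic daily
# return series.
#
#   rets = pd.Series([0.001, 0.002, -0.001, 0.003],
#                    index=pd.date_range('2020-01-01', periods=4))
#   sharpe_ratio(rets, annual_risk_free_rate=0.0)  # sqrt(252) * mean / std
#   annual_volatility(rets)                        # std(ddof=1) * sqrt(252)
#   dd, duration = drawdown((1 + rets).cumprod())  # dd <= 0; duration in bars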
--------------------------------------------------------------------------------
/spectre/trading/portfolio.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from typing import Union, Callable
8 | import pandas as pd
9 | import numpy as np
10 | from .position import Position
11 | from .stopmodel import StopModel
12 |
13 |
14 | class Portfolio:
15 | def __init__(self, stop_model: StopModel = None):
16 | self._history = []
17 | self._positions = dict()
18 | self._cash = 0
19 | self._funds_change = []
20 | self._current_dt = None
21 | self.stop_model = stop_model
22 |
23 | def set_stop_model(self, stop_model: StopModel):
24 | """
25 | Set default portfolio stop model.
26 | Stop model can make more strategic triggers than stop orders.
27 | """
28 | self.stop_model = stop_model
29 |
30 | @property
31 | def history(self):
32 | ret = pd.DataFrame(self._history + [self._get_today_record()])
33 | ret.columns = pd.MultiIndex.from_tuples(ret.columns)
34 | ret = ret.set_index('index').sort_index(axis=0).sort_index(axis=1)
35 | return ret
36 |
37 | @property
38 | def fund_history(self):
39 | return pd.DataFrame(self._funds_change).set_index('index')
40 |
41 | @property
42 | def returns(self):
43 | if self._funds_change:
44 | value = self.history.value.sum(axis=1)
45 | funds = pd.DataFrame(self._funds_change)
46 | funds['index'] = funds['index'].fillna(value.index[0])
47 | funds = funds.set_index('index').amount
48 | return value.sub(funds, fill_value=0) / value.shift(1) - 1
49 | else:
50 | return self.history.value.sum(axis=1).pct_change()
51 |
52 | @property
53 | def positions(self):
54 | return self._positions
55 |
56 | @property
57 | def cash(self):
58 | return self._cash
59 |
60 | @property
61 | def value(self):
62 | # for asset, shares in self.positions.items():
63 | # if self._last_price[asset] != self._last_price[asset]:
64 | # raise ValueError('{}({}) is nan in {}'.format(asset, shares, self._current_dt))
65 | values = [pos.value for asset, pos in self.positions.items() if pos.shares != 0]
66 | return sum(values) + self._cash
67 |
68 | @property
69 | def leverage(self):
70 | values = [pos.value for asset, pos in self.positions.items() if pos.shares != 0]
71 | return sum(np.abs(values)) / (sum(values) + self._cash)
72 |
73 | @property
74 | def current_dt(self):
75 | return self._current_dt
76 |
77 | def __repr__(self):
78 | return "" + str(self.history)[11:]
79 |
80 | def clear(self):
81 | self.__init__(self.stop_model)
82 |
83 | def shares(self, asset):
84 | try:
85 | return self._positions[asset].shares
86 | except KeyError:
87 | return 0
88 |
89 | def _get_today_record(self):
90 | current_date = self._current_dt.normalize()
91 | record = {('index', ''): current_date, ('value', 'cash'): self._cash}
92 | for asset, pos in self._positions.items():
93 | record[('avg_px', asset)] = pos.average_price
94 | record[('shares', asset)] = pos.shares
95 | record[('value', asset)] = pos.value
96 | return record
97 |
98 | def set_datetime(self, dt):
99 | if isinstance(dt, str):
100 | dt = pd.Timestamp(dt)
101 | date = dt.normalize()
102 | if self._current_dt is not None:
103 | current_date = self._current_dt.normalize()
104 | if dt < self._current_dt:
105 | raise ValueError('Cannot set a date less than the current date')
106 | elif date > current_date:
107 | # today add to history
108 | self._history.append(self._get_today_record())
109 |
110 | self._current_dt = dt
111 | for pos in self._positions.values():
112 | pos.current_dt = dt
113 |
114 | def update(self, asset, amount, fill_price, commission) -> float:
115 |         """Add `amount` to the asset position, recalculating average_price and realized P&L"""
116 | assert self._current_dt is not None
117 | if amount == 0:
118 | return 0
119 | if asset in self._positions:
120 | empty, realized = self._positions[asset].update(
121 | amount, fill_price, commission, self._current_dt)
122 | if empty:
123 | del self._positions[asset]
124 | return realized
125 | else:
126 | self._positions[asset] = Position(
127 | amount, fill_price, commission, self._current_dt, self.stop_model)
128 | return 0
129 |
130 | def update_cash(self, amount, is_funds=False):
131 |         """ is_funds: whether this cash update is a funds transfer (deposit/withdrawal) """
132 |         assert amount == amount  # NaN check
133 | self._cash += amount
134 | if is_funds:
135 | if self._current_dt is None:
136 | current_date = None
137 | else:
138 | current_date = self._current_dt.normalize()
139 | self._funds_change.append({'index': current_date, 'amount': amount})
140 |
141 | def process_split(self, asset, inverse_ratio: float, last_price):
142 | if asset not in self._positions:
143 | return
144 | pos = self._positions[asset]
145 | cash = pos.process_split(inverse_ratio, last_price)
146 | self.update_cash(cash)
147 |
148 | def process_dividend(self, asset, amount, tax):
149 | if asset not in self._positions:
150 | return
151 | pos = self._positions[asset]
152 | cash = pos.process_dividend(amount, tax)
153 | self.update_cash(cash)
154 |
155 | def process_borrow_interest(self, day_passed, money_interest_rate, stock_interest_rate):
156 | interest = 0
157 | for asset, pos in self._positions.items():
158 |             # Sometimes pos.last_price has not been updated by this point, which turns
159 |             # cash into NaN; the missing update has not been located yet
160 |             if pos.shares < 0 and pos.value == pos.value:  # NaN guard
160 | interest += pos.value * (stock_interest_rate / 365) * day_passed
161 | if self._cash < 0:
162 | interest += self._cash * (money_interest_rate / 365) * day_passed
163 | self.update_cash(interest)
164 |
165 | def _update_value_func(self, func):
166 | for asset, pos in self._positions.items():
167 | price = func(asset)
168 |             if price and price == price:  # skip None, zero and NaN prices
169 | pos.last_price = price
170 |
171 | def _update_value_dict(self, prices):
172 | for asset, pos in self._positions.items():
173 | price = prices.get(asset, np.nan)
174 |             if price == price:  # skip NaN prices
175 | pos.last_price = price
176 |
177 | def update_value(self, prices: Union[Callable, dict]):
178 | if callable(prices):
179 | self._update_value_func(prices)
180 | elif isinstance(prices, dict):
181 | self._update_value_dict(prices)
182 | else:
183 |             raise ValueError('prices must be either callable or a dict')
184 |
185 | def check_stop_trigger(self, *args):
186 | ret = []
187 | for asset in list(self._positions.keys()):
188 | pos = self._positions[asset]
189 | ret.append(pos.check_stop_trigger(asset, -pos.shares, *args))
190 | return ret
191 |
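# Usage sketch (illustrative): the basic Portfolio lifecycle. Asset keys are
# plain strings here; note that update() adjusts the position only, so cash
# must be deducted separately (the blotter does this in real use).
#
#   pf = Portfolio()
#   pf.set_datetime('2020-01-02 09:30:00')
#   pf.update_cash(10000, is_funds=True)            # initial deposit
#   pf.update('AAPL', 10, fill_price=100., commission=1.)
#   pf.update_cash(-10 * 100. - 1.)                 # pay for the shares
#   pf.update_value({'AAPL': 105.})                 # mark to market
#   pf.value     # 10 * 105 + remaining cash
#   pf.leverage  # gross exposure / net liquidation value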
--------------------------------------------------------------------------------
/spectre/trading/position.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import math
8 | from typing import Tuple
9 |
10 | from .stopmodel import StopModel
11 |
12 |
13 | def sign(x):
14 | return math.copysign(1, x)
15 |
16 |
17 | class Position:
18 | def __init__(self, shares: int, fill_price: float, commission: float,
19 | dt, stop_model: StopModel = None):
20 | self._shares = shares
21 | self._average_price = fill_price + commission / shares
22 | self._last_price = fill_price
23 | self._realized = 0
24 | self._open_dt = dt
25 | self.current_dt = dt
26 | self.stop_model = stop_model
27 | self.stop_tracker = None
28 | if stop_model is not None:
29 |             self.stop_tracker = stop_model.new_tracker(
30 |                 fill_price, inverse=self._shares < 0)
31 | self.stop_tracker.tracking_position = self
32 |
33 | @property
34 | def open_dt(self):
35 | return self._open_dt
36 |
37 | @property
38 | def period(self):
39 | return self.current_dt - self._open_dt
40 |
41 | @property
42 | def value(self):
43 | return self._shares * self._last_price
44 |
45 | @property
46 | def shares(self):
47 | return self._shares
48 |
49 | @property
50 | def average_price(self):
51 | return self._average_price
52 |
53 | @property
54 | def last_price(self):
55 | return self._last_price
56 |
57 | @last_price.setter
58 | def last_price(self, last_price: float):
59 | self._last_price = last_price
60 | if self.stop_tracker:
61 | self.stop_tracker.update_price(last_price)
62 |
63 | @property
64 | def realized(self):
65 | return self._realized
66 |
67 | @property
68 | def unrealized(self):
69 | return (self._last_price - self._average_price) * self._shares
70 |
71 | @property
72 | def unrealized_percent(self):
73 | return (self._last_price / self._average_price - 1) * sign(self._shares)
74 |
75 | def update(self, amount: int, fill_price: float, commission: float, dt) -> Tuple[bool, float]:
76 |         """
77 |         Add `amount` to the position; fill_price and commission are used to recalculate
78 |         average_price and P&L. Returns (True, realized) when the position becomes empty.
79 |         """
80 | before_shares = self._shares
81 | before_avg_px = self._average_price
82 | after_shares = before_shares + amount
83 |
84 | # If the position is reversed, it will be filled in 2 steps
85 | if after_shares != 0 and sign(after_shares) != sign(before_shares):
86 | fill_1 = amount - after_shares
87 | fill_2 = amount - fill_1
88 | per_comm = commission / amount
89 | # close position
90 | _, realized = self.update(fill_1, fill_price, per_comm * fill_1, dt)
91 | # open a new position
92 | self.__init__(fill_2, fill_price, per_comm * fill_2, dt, stop_model=self.stop_model)
93 | return False, realized
94 | else:
95 | cum_cost = self._average_price * before_shares + amount * fill_price + commission
96 | self._shares = after_shares
97 | if after_shares == 0:
98 | self._average_price = 0
99 | realized = -cum_cost - self._realized
100 | self._realized = -cum_cost
101 | self.last_price = fill_price
102 | return True, realized
103 | else:
104 | self._average_price = cum_cost / after_shares
105 | if after_shares < before_shares:
106 | realized = (before_avg_px - self._average_price) * abs(after_shares)
107 | else:
108 | realized = 0
109 | self._realized += realized
110 | self.last_price = fill_price
111 | return False, realized
112 |
113 | def process_split(self, inverse_ratio: float, last_price: float) -> float:
114 |         if inverse_ratio != inverse_ratio or inverse_ratio == 1:  # NaN or 1:1 split is a no-op
115 | return 0
116 | sp = self._shares * inverse_ratio
117 | cash = 0
118 | if inverse_ratio < 1: # reverse split remaining to cash
119 |             remaining = int(self._shares - int(sp) / inverse_ratio)  # computed this way for better precision
120 | if remaining != 0:
121 | cash = remaining * last_price
122 | self._shares = int(round(sp, 5))
123 | self._average_price = self._average_price / inverse_ratio
124 | self.last_price = last_price / inverse_ratio
125 |
126 | if self.stop_tracker:
127 | self.stop_tracker.process_split(last_price)
128 | return cash
129 |
130 | def process_dividend(self, amount: float, tax: float) -> float:
131 |         if amount != amount or amount == 0:  # NaN or zero dividend is a no-op
132 | return 0
133 | self._average_price -= amount
134 | self.last_price -= amount + tax
135 | cash = self._shares * amount
136 | return cash
137 |
138 | def check_stop_trigger(self, *args):
139 | if self.stop_tracker:
140 | return self.stop_tracker.check_trigger(*args)
141 |
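# Behavior sketch (illustrative): reversing a position is filled in two steps,
# closing the old side first and reopening the remainder on the new side.
#
#   import pandas as pd
#   dt = pd.Timestamp('2020-01-02')
#   pos = Position(10, fill_price=100., commission=0., dt=dt)
#   empty, realized = pos.update(-15, fill_price=110., commission=0., dt=dt)
#   # step 1 closes the +10 @ 110 (realized +100), step 2 opens -5 @ 110
#   (empty, realized, pos.shares)  # (False, 100.0, -5)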
--------------------------------------------------------------------------------
/spectre/trading/stopmodel.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import math
8 |
9 |
10 | def sign(x):
11 | return math.copysign(1, x)
12 |
13 |
14 | class PriceTracker:
15 | def __init__(self, current_price, recorder=max):
16 | self.last_price = current_price
17 | self.recorder = recorder
18 | self.recorded_price = current_price
19 | self.tracking_position = None
20 |
21 | def update_price(self, last_price):
22 | self.recorded_price = self.recorder(self.recorded_price, last_price)
23 | self.last_price = last_price
24 |
25 | def process_split(self, inverse_ratio: float):
26 | self.recorded_price /= inverse_ratio
27 |
28 |
29 | # -----------------------------------------------------------------------------
30 |
31 |
32 | class StopTracker(PriceTracker):
33 | def __init__(self, current_price, stop_price, callback):
34 | super().__init__(current_price, lambda _, x: x)
35 | self._stop_price = stop_price
36 | self.stop_loss = stop_price < current_price
37 | self.callback = callback
38 |
39 | @property
40 | def stop_price(self):
41 | return self._stop_price
42 |
43 | def fire(self, *args):
44 | if callable(self.callback):
45 | return self.callback(*args)
46 | else:
47 | return self.callback
48 |
49 | def check_trigger(self, *args):
50 | if self.stop_loss:
51 | if self.last_price <= self.stop_price:
52 | return self.fire(*args)
53 | else:
54 | if self.last_price >= self.stop_price:
55 | return self.fire(*args)
56 | return False
57 |
58 |
59 | class StopModel:
60 | def __init__(self, ratio: float, callback=None):
61 | self.ratio = ratio
62 | self.callback = callback
63 |
64 | def new_tracker(self, current_price, inverse):
65 | if inverse:
66 | stop_price = current_price * (1 - self.ratio)
67 | else:
68 | stop_price = current_price * (1 + self.ratio)
69 | return StopTracker(current_price, stop_price, self.callback)
70 |
71 |
72 | # -----------------------------------------------------------------------------
73 |
74 |
75 | class TrailingStopTracker(StopTracker):
76 | def __init__(self, current_price, ratio, callback):
77 | self.ratio = ratio
78 | stop_price = current_price * (1 + self.ratio)
79 | StopTracker.__init__(self, current_price, stop_price, callback=callback)
80 | PriceTracker.__init__(self, current_price, recorder=max if ratio < 0 else min)
81 |
82 | @property
83 | def stop_price(self):
84 | return self.recorded_price * (1 + self.ratio)
85 |
86 |
87 | class TrailingStopModel(StopModel):
88 | """
89 | Unlike trailing stop order, the ratio in this model is relative to the highest / lowest price,
90 | so -0.1 means stop price is 90% of the highest price from now to the future; 0.1 means stop
91 | price is 110% of the lowest price from now to the future.
92 | """
93 | def new_tracker(self, current_price, inverse):
94 | ratio = -self.ratio if inverse else self.ratio
95 | return TrailingStopTracker(current_price, ratio, self.callback)
96 |
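# Behavior sketch (illustrative): a -10% trailing stop on a long position. The
# stop follows the recorded high, so it triggers once the price falls 10%
# below the highest price seen.
#
#   model = TrailingStopModel(-0.10, callback=lambda: 'SELL')
#   tracker = model.new_tracker(100.0, inverse=False)
#   tracker.update_price(120.0)     # new high -> stop price now 108.0
#   tracker.update_price(109.0)
#   tracker.check_trigger()         # False: 109 > 108
#   tracker.update_price(107.0)
#   tracker.check_trigger()         # 'SELL': 107 <= 108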
97 |
98 | # -----------------------------------------------------------------------------
99 |
100 |
101 | class DecayTrailingStopTracker(TrailingStopTracker):
102 | def __init__(self, current_price, ratio, target, decay_rate, max_decay, callback):
103 | self.initial_ratio = ratio
104 | self.max_decay = max_decay
105 | self.decay_rate = decay_rate
106 | self.target = target
107 | super().__init__(current_price, ratio, callback)
108 |
109 | @property
110 | def current(self):
111 | raise NotImplementedError("abstractmethod")
112 |
113 | @property
114 | def stop_price(self):
115 | decay = max(self.decay_rate ** (self.current / self.target), self.max_decay)
116 | self.ratio = self.initial_ratio * decay
117 | return self.recorded_price * (1 + self.ratio)
118 |
119 |
120 | class PnLDecayTrailingStopTracker(DecayTrailingStopTracker):
121 | @property
122 | def current(self):
123 | pos = self.tracking_position
124 | pnl = (self.recorded_price / pos.average_price - 1) * sign(pos.shares)
125 | pnl = max(pnl, 0) if self.target > 0 else min(pnl, 0)
126 | return pnl
127 |
128 |
129 | class PnLDecayTrailingStopModel(StopModel):
130 |     """
131 |     Exponentially decays the stop ratio: `ratio * decay_rate ^ (PnL% / PnL_target%)`.
132 |     For a stop-gain model, `pnl_target` should be a loss target (negative).
133 | 
134 |     So the `ratio` shrinks as PnL% approaches the target, and once PnL% exceeds
135 |     PnL_target%, any small move in the opposite direction will trigger the stop.
136 |     """
137 |
138 | def __init__(self, ratio: float, pnl_target: float, callback=None,
139 | decay_rate=0.05, max_decay=0):
140 | super().__init__(ratio, callback)
141 | self.decay_rate = decay_rate
142 | self.pnl_target = pnl_target
143 | self.max_decay = max_decay
144 |
145 | def new_tracker(self, current_price, inverse):
146 | ratio = -self.ratio if inverse else self.ratio
147 | return PnLDecayTrailingStopTracker(
148 | current_price, ratio, self.pnl_target, self.decay_rate, self.max_decay, self.callback)
149 |
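# Numeric sketch (illustrative): with ratio=-0.10, decay_rate=0.05 and
# pnl_target=0.20, a position up 10% (half the target) tightens the trailing
# ratio to roughly -2.2% below the high:
#
#   decay = 0.05 ** (0.10 / 0.20)  # ~0.2236; floored at max_decay=0
#   ratio = -0.10 * decay          # ~-0.0224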
150 |
151 | class TimeDecayTrailingStopTracker(DecayTrailingStopTracker):
152 | @property
153 | def current(self):
154 | pos = self.tracking_position
155 | return pos.period
156 |
157 |
158 | class TimeDecayTrailingStopModel(StopModel):
159 | def __init__(self, ratio: float, period_target: 'pd.Timedelta', callback=None,
160 | decay_rate=0.05, max_decay=0):
161 | super().__init__(ratio, callback)
162 | self.decay_rate = decay_rate
163 | self.period_target = period_target
164 | self.max_decay = max_decay
165 |
166 | def new_tracker(self, current_price, inverse):
167 | ratio = -self.ratio if inverse else self.ratio
168 | return TimeDecayTrailingStopTracker(
169 | current_price, ratio, self.period_target, self.decay_rate, self.max_decay,
170 | self.callback)
171 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Heerozh/spectre/c0e7fd974227b376a5790cc66d49f65d58929af8/tests/__init__.py
--------------------------------------------------------------------------------
/tests/data/5mins/AAPL_2018.csv:
--------------------------------------------------------------------------------
1 | Date,Open,High,Low,Close,Volume,BarCount,Wap
2 | 2018-12-28 09:30:00-05:00,157.48,157.65,155.76,156.03,14530,4350,156.849
3 | 2018-12-28 09:35:00-05:00,156.01,156.05,155.19,155.39,8493,4002,155.539
4 | 2018-12-28 09:40:00-05:00,155.42,156.09,155.19,156.05,6825,3040,155.681
5 | 2018-12-28 09:45:00-05:00,156.08,156.84,156.02,156.39,9933,4339,156.506
6 | 2018-12-28 09:50:00-05:00,156.4,156.59,155.84,156.42,7377,3527,156.208
7 | 2018-12-28 09:55:00-05:00,156.45,156.88,156.11,156.73,5646,2717,156.479
8 | 2018-12-28 10:00:00-05:00,156.75,156.84,155.8,156.0,5533,2701,156.255
9 | 2018-12-28 10:05:00-05:00,156.02,156.75,156.02,156.62,7276,3691,156.534
10 | 2018-12-28 10:10:00-05:00,156.63,156.77,155.81,156.06,6216,2970,156.293
11 | 2018-12-28 10:15:00-05:00,156.08,156.45,155.9,156.33,4541,2229,156.208
12 | 2018-12-28 10:20:00-05:00,156.35,156.5,155.59,155.73,5287,2537,156.113
13 | 2018-12-28 10:25:00-05:00,155.73,155.85,155.16,155.27,5340,2661,155.449
14 | 2018-12-28 10:30:00-05:00,155.26,155.78,155.12,155.29,4231,2074,155.413
15 | 2018-12-28 10:35:00-05:00,155.28,155.31,154.55,155.12,6388,3004,154.854
16 | 2018-12-28 10:40:00-05:00,155.1,155.44,154.9,154.9,4302,2056,155.183
17 | 2018-12-28 10:45:00-05:00,154.9,155.72,154.9,155.56,4212,2067,155.377
18 | 2018-12-28 10:50:00-05:00,155.59,155.99,155.49,155.59,5054,2425,155.797
19 | 2018-12-28 10:55:00-05:00,155.59,155.73,155.22,155.31,2992,1379,155.451
20 | 2018-12-28 11:00:00-05:00,155.32,155.5,155.07,155.16,2550,1394,155.303
21 | 2018-12-28 11:05:00-05:00,155.16,155.74,155.0,155.66,3333,1718,155.311
22 | 2018-12-28 11:10:00-05:00,155.64,155.9,155.59,155.78,4292,1816,155.775
23 | 2018-12-28 11:15:00-05:00,155.78,156.22,155.55,155.68,6413,2925,155.854
24 | 2018-12-28 11:20:00-05:00,155.66,155.84,155.44,155.64,2844,1457,155.617
25 | 2018-12-28 11:25:00-05:00,155.64,156.2,155.59,156.08,3821,1654,155.913
26 | 2018-12-28 11:30:00-05:00,156.07,156.25,155.87,156.22,3717,1776,156.109
27 | 2018-12-28 11:35:00-05:00,156.24,156.93,156.21,156.86,5446,2236,156.597
28 | 2018-12-28 11:40:00-05:00,156.86,157.11,156.73,156.74,4495,1927,156.897
29 | 2018-12-28 11:45:00-05:00,156.75,156.82,156.47,156.82,3335,1365,156.661
30 | 2018-12-28 11:50:00-05:00,156.81,157.2,156.63,157.02,3219,1377,156.944
31 | 2018-12-28 11:55:00-05:00,157.02,157.09,156.8,156.92,2776,1195,156.973
32 | 2018-12-28 12:00:00-05:00,156.93,157.23,156.68,157.22,3919,1594,156.939
33 | 2018-12-28 12:05:00-05:00,157.2,157.21,156.65,156.7,3105,1272,156.987
34 | 2018-12-28 12:10:00-05:00,156.72,156.81,156.4,156.49,3073,1286,156.564
35 | 2018-12-28 12:15:00-05:00,156.49,156.56,156.2,156.48,3369,1459,156.393
36 | 2018-12-28 12:20:00-05:00,156.48,156.96,156.48,156.83,3758,1671,156.719
37 | 2018-12-28 12:25:00-05:00,156.83,156.88,156.51,156.68,2637,1244,156.664
38 | 2018-12-28 12:30:00-05:00,156.67,156.92,156.61,156.77,2327,1135,156.753
39 | 2018-12-28 12:35:00-05:00,156.76,156.76,156.39,156.7,2507,1443,156.583
40 | 2018-12-28 12:40:00-05:00,156.71,156.81,156.54,156.57,2913,1232,156.679
41 | 2018-12-28 12:45:00-05:00,156.57,156.63,156.33,156.43,2341,1174,156.457
42 | 2018-12-28 12:50:00-05:00,156.43,156.64,156.22,156.51,2540,1175,156.426
43 | 2018-12-28 12:55:00-05:00,156.53,156.58,156.03,156.21,2431,1100,156.253
44 | 2018-12-28 13:00:00-05:00,156.22,156.36,156.14,156.17,1835,1017,156.251
45 | 2018-12-28 13:05:00-05:00,156.17,156.37,156.07,156.29,2076,1050,156.216
46 | 2018-12-28 13:10:00-05:00,156.29,156.38,156.07,156.29,2189,1182,156.255
47 | 2018-12-28 13:15:00-05:00,156.29,156.42,156.06,156.07,2135,855,156.229
48 | 2018-12-28 13:20:00-05:00,156.08,156.13,155.53,155.55,3625,1417,155.892
49 | 2018-12-28 13:25:00-05:00,155.54,155.77,155.41,155.71,3012,1368,155.581
50 | 2018-12-28 13:30:00-05:00,155.68,155.76,155.27,155.4,2850,1304,155.468
51 | 2018-12-28 13:35:00-05:00,155.39,155.6,155.29,155.47,1733,704,155.429
52 | 2018-12-28 13:40:00-05:00,155.46,155.67,155.38,155.46,1712,710,155.528
53 | 2018-12-28 13:45:00-05:00,155.47,156.1,155.47,155.99,2716,1133,155.854
54 | 2018-12-28 13:50:00-05:00,155.98,156.24,155.82,155.9,2310,1072,156.029
55 | 2018-12-28 13:55:00-05:00,155.9,156.21,155.9,156.03,1814,770,156.031
56 | 2018-12-28 14:00:00-05:00,156.05,156.66,156.05,156.61,3788,1495,156.444
57 | 2018-12-28 14:05:00-05:00,156.63,156.98,156.62,156.91,3227,1420,156.81
58 | 2018-12-28 14:10:00-05:00,156.89,157.35,156.76,157.22,4298,1810,157.114
59 | 2018-12-28 14:15:00-05:00,157.23,157.58,157.22,157.26,4489,2017,157.43
60 | 2018-12-28 14:20:00-05:00,157.25,157.45,157.04,157.37,3218,1411,157.265
61 | 2018-12-28 14:25:00-05:00,157.37,157.58,157.22,157.29,2641,1237,157.412
62 | 2018-12-28 14:30:00-05:00,157.29,157.45,157.04,157.45,3288,1398,157.228
63 | 2018-12-28 14:35:00-05:00,157.43,157.65,157.32,157.63,3294,1656,157.468
64 | 2018-12-28 14:40:00-05:00,157.65,158.1,157.49,157.99,5939,2373,157.848
65 | 2018-12-28 14:45:00-05:00,158.0,158.05,157.81,157.87,3579,1737,157.946
66 | 2018-12-28 14:50:00-05:00,157.88,158.52,157.82,158.2,5207,2384,158.115
67 | 2018-12-28 14:55:00-05:00,158.19,158.47,158.11,158.41,4225,1955,158.303
68 | 2018-12-28 15:00:00-05:00,158.4,158.51,157.95,158.02,3988,2069,158.24
69 | 2018-12-28 15:05:00-05:00,158.02,158.04,157.12,157.39,5006,2637,157.536
70 | 2018-12-28 15:10:00-05:00,157.38,157.67,156.8,157.58,5464,2690,157.17
71 | 2018-12-28 15:15:00-05:00,157.58,157.59,156.84,156.95,3124,1621,157.235
72 | 2018-12-28 15:20:00-05:00,156.93,157.04,156.67,157.01,4691,2531,156.862
73 | 2018-12-28 15:25:00-05:00,157.01,157.13,156.23,156.54,5140,2827,156.53
74 | 2018-12-28 15:30:00-05:00,156.54,156.63,156.16,156.47,4636,2461,156.407
75 | 2018-12-28 15:35:00-05:00,156.47,156.71,156.14,156.2,4395,2461,156.387
76 | 2018-12-28 15:40:00-05:00,156.19,156.46,155.9,155.94,5763,3510,156.215
77 | 2018-12-28 15:45:00-05:00,155.92,156.24,155.71,156.1,5309,3007,155.956
78 | 2018-12-28 15:50:00-05:00,156.1,156.33,155.67,155.75,5943,3780,156.01
79 | 2018-12-28 15:55:00-05:00,155.75,156.5,155.68,156.23,12132,7388,156.05
80 | 2018-12-31 09:30:00-05:00,158.5,158.85,157.96,158.71,11487,3277,158.395
81 | 2018-12-31 09:35:00-05:00,158.69,159.36,158.68,159.04,8422,3572,159.054
82 | 2018-12-31 09:40:00-05:00,159.04,159.17,158.7,158.77,4954,2159,158.883
83 | 2018-12-31 09:45:00-05:00,158.78,158.88,158.29,158.31,4922,2010,158.581
84 | 2018-12-31 09:50:00-05:00,158.32,158.33,157.67,157.75,5488,2321,157.919
85 | 2018-12-31 09:55:00-05:00,157.74,157.94,157.48,157.93,6145,1986,157.715
86 | 2018-12-31 10:00:00-05:00,157.92,158.23,157.91,158.01,4460,1626,158.084
87 | 2018-12-31 10:05:00-05:00,157.99,158.24,157.45,157.47,5103,2032,157.748
88 | 2018-12-31 10:10:00-05:00,157.46,157.69,157.0,157.03,5865,2401,157.407
89 | 2018-12-31 10:15:00-05:00,157.03,157.17,156.88,157.09,4178,1808,157.021
90 | 2018-12-31 10:20:00-05:00,157.07,157.23,156.93,157.06,3010,1582,157.094
91 | 2018-12-31 10:25:00-05:00,157.05,157.48,157.05,157.46,3333,1669,157.276
92 | 2018-12-31 10:30:00-05:00,157.47,157.6,156.81,156.9,4413,1990,157.136
93 | 2018-12-31 10:35:00-05:00,156.9,156.9,156.51,156.7,3970,1727,156.698
94 | 2018-12-31 10:40:00-05:00,156.68,157.0,156.63,156.88,4122,1707,156.811
95 | 2018-12-31 10:45:00-05:00,156.86,157.04,156.68,156.72,2654,1403,156.896
96 | 2018-12-31 10:50:00-05:00,156.71,156.81,156.54,156.55,2362,1134,156.675
97 | 2018-12-31 10:55:00-05:00,156.54,157.07,156.51,157.06,4184,1529,156.75
98 | 2018-12-31 11:00:00-05:00,157.06,157.39,156.91,157.36,3507,1648,157.141
99 | 2018-12-31 11:05:00-05:00,157.35,157.58,157.24,157.39,3958,1789,157.391
100 | 2018-12-31 11:10:00-05:00,157.4,157.43,157.11,157.33,3269,1602,157.257
101 | 2018-12-31 11:15:00-05:00,157.32,157.5,157.14,157.31,3153,1466,157.349
102 | 2018-12-31 11:20:00-05:00,157.32,157.35,156.94,156.98,3473,1694,157.074
103 | 2018-12-31 11:25:00-05:00,156.96,157.7,156.93,157.56,3962,1998,157.363
104 | 2018-12-31 11:30:00-05:00,157.52,157.62,157.07,157.17,3163,1530,157.399
105 | 2018-12-31 11:35:00-05:00,157.18,157.68,156.99,157.55,3753,1770,157.424
106 | 2018-12-31 11:40:00-05:00,157.54,157.61,157.33,157.6,3133,1718,157.455
107 | 2018-12-31 11:45:00-05:00,157.58,157.92,157.51,157.89,3430,1571,157.713
108 | 2018-12-31 11:50:00-05:00,157.89,158.17,157.69,158.08,4052,1510,157.994
109 | 2018-12-31 11:55:00-05:00,158.07,158.1,157.77,157.96,2303,1086,157.935
110 | 2018-12-31 12:00:00-05:00,157.99,158.02,157.74,157.91,2049,918,157.902
111 | 2018-12-31 12:05:00-05:00,157.94,158.1,157.8,158.05,2311,1026,157.975
112 | 2018-12-31 12:10:00-05:00,158.02,158.09,157.77,158.05,3076,1433,157.929
113 | 2018-12-31 12:15:00-05:00,158.05,158.24,157.83,158.14,3164,1309,158.073
114 | 2018-12-31 12:20:00-05:00,158.13,158.48,158.11,158.31,3742,1573,158.33
115 | 2018-12-31 12:25:00-05:00,158.3,158.58,158.26,158.53,3390,1408,158.442
116 | 2018-12-31 12:30:00-05:00,158.52,158.72,158.51,158.7,2436,1161,158.637
117 | 2018-12-31 12:35:00-05:00,158.69,158.85,158.63,158.83,2101,1090,158.749
118 | 2018-12-31 12:40:00-05:00,158.84,158.94,158.68,158.93,2507,1316,158.81
119 | 2018-12-31 12:45:00-05:00,158.92,158.93,158.58,158.6,1787,777,158.741
120 | 2018-12-31 12:50:00-05:00,158.6,158.61,158.3,158.38,2516,1178,158.407
121 | 2018-12-31 12:55:00-05:00,158.37,158.44,158.17,158.37,2206,1008,158.305
122 | 2018-12-31 13:00:00-05:00,158.38,158.48,158.13,158.24,2175,950,158.343
123 | 2018-12-31 13:05:00-05:00,158.22,158.3,158.09,158.23,2073,1001,158.201
124 | 2018-12-31 13:10:00-05:00,158.25,158.43,158.18,158.36,2329,1056,158.332
125 | 2018-12-31 13:15:00-05:00,158.36,158.48,158.28,158.3,2355,1013,158.406
126 | 2018-12-31 13:20:00-05:00,158.31,158.36,158.25,158.33,1367,589,158.298
127 | 2018-12-31 13:25:00-05:00,158.33,158.43,158.25,158.4,1178,533,158.35
128 | 2018-12-31 13:30:00-05:00,158.4,158.44,158.15,158.19,1647,828,158.28
129 | 2018-12-31 13:35:00-05:00,158.17,158.2,157.85,157.91,2175,999,158.002
130 | 2018-12-31 13:40:00-05:00,157.9,158.02,157.7,157.74,2708,1183,157.902
131 | 2018-12-31 13:45:00-05:00,157.75,157.8,157.63,157.76,1756,803,157.705
132 | 2018-12-31 13:50:00-05:00,157.76,157.85,157.43,157.55,2624,1153,157.616
133 | 2018-12-31 13:55:00-05:00,157.54,157.67,157.5,157.57,1422,762,157.589
134 | 2018-12-31 14:00:00-05:00,157.58,157.6,157.33,157.39,2295,1189,157.472
135 | 2018-12-31 14:05:00-05:00,157.4,157.82,157.36,157.68,2740,1145,157.673
136 | 2018-12-31 14:10:00-05:00,157.68,157.72,157.41,157.49,2317,906,157.563
137 | 2018-12-31 14:15:00-05:00,157.48,157.79,157.29,157.77,2965,1441,157.569
138 | 2018-12-31 14:20:00-05:00,157.77,157.79,157.2,157.35,3435,1518,157.469
139 | 2018-12-31 14:25:00-05:00,157.34,157.38,157.05,157.09,2088,905,157.244
140 | 2018-12-31 14:30:00-05:00,157.07,157.17,156.92,156.97,2976,1275,157.023
141 | 2018-12-31 14:35:00-05:00,156.96,157.37,156.85,157.32,2535,1311,157.081
142 | 2018-12-31 14:40:00-05:00,157.32,157.46,157.22,157.45,2421,876,157.333
143 | 2018-12-31 14:45:00-05:00,157.42,157.61,157.34,157.44,1626,610,157.511
144 | 2018-12-31 14:50:00-05:00,157.45,157.74,157.44,157.56,2233,1126,157.636
145 | 2018-12-31 14:55:00-05:00,157.55,157.67,157.3,157.37,1931,903,157.517
146 | 2018-12-31 15:00:00-05:00,157.36,157.69,157.25,157.29,2686,1399,157.485
147 | 2018-12-31 15:05:00-05:00,157.31,157.47,157.26,157.29,1841,1014,157.384
148 | 2018-12-31 15:10:00-05:00,157.28,157.37,157.07,157.09,1909,989,157.217
149 | 2018-12-31 15:15:00-05:00,157.1,157.26,156.9,157.26,3223,1568,157.066
150 | 2018-12-31 15:20:00-05:00,157.25,157.36,156.99,157.07,2318,1119,157.187
151 | 2018-12-31 15:25:00-05:00,157.06,157.33,156.97,157.0,3231,1581,157.142
152 | 2018-12-31 15:30:00-05:00,157.0,157.44,157.0,157.3,2822,1376,157.304
153 | 2018-12-31 15:35:00-05:00,157.29,157.34,157.08,157.25,3303,1454,157.205
154 | 2018-12-31 15:40:00-05:00,157.26,157.48,157.14,157.34,4036,2078,157.337
155 | 2018-12-31 15:45:00-05:00,157.35,157.49,157.06,157.15,4345,2222,157.262
156 | 2018-12-31 15:50:00-05:00,157.14,157.27,157.02,157.12,5169,2549,157.148
157 | 2018-12-31 15:55:00-05:00,157.14,157.95,156.48,157.94,14342,8274,157.104
158 |
--------------------------------------------------------------------------------
/tests/data/dividends/AAPL.csv:
--------------------------------------------------------------------------------
1 | exDate,paymentDate,recordDate,declaredDate,amount,flag,currency,description,frequency,date
2 | 2015-02-05,2015-02-12,2015-02-09,2015-01-27,0.47,Cash,USD,Ordinary Shares,quarterly,2019-12-07
3 | 2015-05-07,2015-05-14,2015-05-11,2015-04-27,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07
4 | 2015-08-06,2015-08-13,2015-08-10,2015-07-21,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07
5 | 2015-11-05,2015-11-12,2015-11-09,2015-10-27,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07
6 | 2016-02-04,2016-02-11,2016-02-08,2016-01-26,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07
7 | 2016-05-05,2016-05-12,2016-05-09,2016-04-26,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07
8 | 2016-08-04,2016-08-11,2016-08-08,2016-07-26,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07
9 | 2016-11-03,2016-11-10,2016-11-07,2016-10-25,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07
10 | 2017-02-09,2017-02-16,2017-02-13,2017-01-31,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07
11 | 2017-05-11,2017-05-18,2017-05-15,2017-05-02,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07
12 | 2017-08-10,2017-08-17,2017-08-14,2017-08-01,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07
13 | 2017-11-10,2017-11-16,2017-11-13,2017-11-02,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07
14 | 2018-02-09,2018-02-15,2018-02-12,2018-02-01,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07
15 | 2018-05-11,2018-05-17,2018-05-14,2018-05-01,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07
16 | 2018-08-10,2018-08-16,2018-08-13,2018-07-31,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07
17 | 2018-11-08,2018-11-15,2018-11-12,2018-11-01,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07
18 | 2019-02-08,2019-02-14,2019-02-11,2019-01-29,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07
19 | 2019-05-10,2019-05-16,2019-05-13,2019-04-30,0.77,Cash,USD,Ordinary Shares,quarterly,2019-12-07
20 | 2019-08-09,2019-08-15,2019-08-12,2019-07-30,0.77,Cash,USD,Ordinary Shares,quarterly,2019-12-07
21 | 2019-11-07,2019-11-14,2019-11-11,2019-10-30,0.77,Cash,USD,Ordinary Shares,quarterly,2019-12-07
22 |
--------------------------------------------------------------------------------
/tests/data/dividends/MSFT.csv:
--------------------------------------------------------------------------------
1 | exDate,paymentDate,recordDate,declaredDate,amount,flag,currency,description,frequency,date
2 | 2015-02-17,2015-03-12,2015-02-19,2014-12-03,0.31,Cash,USD,Ordinary Shares,quarterly,2019-12-07
3 | 2015-05-19,2015-06-11,2015-05-21,2015-03-10,0.31,Cash,USD,Ordinary Shares,quarterly,2019-12-07
4 | 2015-08-18,2015-09-10,2015-08-20,2015-06-09,0.31,Cash,USD,Ordinary Shares,quarterly,2019-12-07
5 | 2015-11-17,2015-12-10,2015-11-19,2015-09-15,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07
6 | 2016-02-16,2016-03-10,2016-02-18,2015-12-02,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07
7 | 2016-05-17,2016-06-09,2016-05-19,2016-03-15,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07
8 | 2016-08-16,2016-09-08,2016-08-18,2016-06-14,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07
9 | 2016-11-15,2016-12-08,2016-11-17,2016-09-20,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07
10 | 2017-02-14,2017-03-09,2017-02-16,2016-11-30,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07
11 | 2017-05-16,2017-06-08,2017-05-18,2017-03-14,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07
12 | 2017-08-15,2017-09-14,2017-08-17,2017-06-13,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07
13 | 2017-11-15,2017-12-14,2017-11-16,2017-09-19,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07
14 | 2018-02-14,2018-03-08,2018-02-15,2017-11-29,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07
15 | 2018-05-16,2018-06-14,2018-05-17,2018-03-12,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07
16 | 2018-08-15,2018-09-13,2018-08-16,2018-06-13,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07
17 | 2018-11-14,2018-12-13,2018-11-15,2018-09-18,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07
18 | 2019-01-11,2019-03-14,2019-02-21,2018-11-28,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07
19 | 2019-01-11,2019-03-14,2019-02-21,2018-11-28,nan,Cash,USD,Ordinary Shares,quarterly,2019-12-07
20 | 2019-01-11,2019-03-14,2019-02-21,2018-11-28,0.11,Cash,USD,Ordinary Shares,quarterly,2019-12-07
21 | 2019-05-15,2019-06-13,2019-05-16,2019-03-11,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07
22 | 2019-08-14,2019-09-12,2019-08-15,2019-06-12,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07
23 | 2019-11-20,2019-12-12,2019-11-21,2019-09-18,0.51,Cash,USD,Ordinary Shares,quarterly,2019-12-07
24 |
--------------------------------------------------------------------------------
/tests/data/splits/AAPL.csv:
--------------------------------------------------------------------------------
1 | exDate,date
2 | ,2019-12-07 23:30:17.719069
3 |
--------------------------------------------------------------------------------
/tests/data/splits/MSFT.csv:
--------------------------------------------------------------------------------
1 | exDate,declaredDate,ratio,toFactor,fromFactor,description,date
2 | 2019-01-14,2017-07-10,5,1,15,1-for-15 Reverse Split,2019-12-07
3 | 2019-01-14,2017-07-10,15,1,15,1-for-15 Reverse Split,2019-12-07
4 | 2019-01-14,2017-05-12,nan,1,15,1-for-15 Reverse Split,2019-12-07
5 |
--------------------------------------------------------------------------------
/tests/test_custom_factor.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spectre
3 | import numpy as np
4 | from numpy.testing import assert_array_equal
5 | import torch
6 | import warnings
7 | from os.path import dirname
8 |
9 | data_dir = dirname(__file__) + '/data/'
10 |
11 |
12 | class TestCustomFactorLib(unittest.TestCase):
13 |
14 | def test_custom_factor(self):
15 | warnings.filterwarnings("ignore", module='spectre')
16 | # test backward tree
17 | a = spectre.factors.CustomFactor(win=2)
18 | b = spectre.factors.CustomFactor(win=3, inputs=(a,))
19 | c = spectre.factors.CustomFactor(win=3, inputs=(b,))
20 | self.assertEqual(5, c.get_total_backwards_())
21 | m = spectre.factors.CustomFactor(win=10)
22 | c.set_mask(m)
23 | self.assertEqual(9, c.get_total_backwards_())
24 |
25 | a1 = spectre.factors.CustomFactor(win=10)
26 | a2 = spectre.factors.CustomFactor(win=5)
27 | b1 = spectre.factors.CustomFactor(win=20, inputs=(a1, a2))
28 | b2 = spectre.factors.CustomFactor(win=100, inputs=(a2,))
29 | c1 = spectre.factors.CustomFactor(win=100, inputs=(b1,))
30 | self.assertEqual(9, a1.get_total_backwards_())
31 | self.assertEqual(4, a2.get_total_backwards_())
32 | self.assertEqual(28, b1.get_total_backwards_())
33 | self.assertEqual(103, b2.get_total_backwards_())
34 | self.assertEqual(127, c1.get_total_backwards_())
35 |
36 | # test inheritance
37 | loader = spectre.data.CsvDirLoader(
38 | data_dir + '/daily/', ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
39 | prices_index='date', parse_dates=True,
40 | )
41 | engine = spectre.factors.FactorEngine(loader)
42 |
43 | class TestFactor(spectre.factors.CustomFactor):
44 | inputs = [spectre.factors.OHLCV.open]
45 |
46 | def compute(self, close):
47 | return torch.tensor(np.arange(close.nelement()).reshape(close.shape))
48 |
49 | class TestFactor2(spectre.factors.CustomFactor):
50 | inputs = []
51 |
52 | def compute(self):
53 | return torch.tensor([1])
54 |
55 | engine.add(TestFactor2(), 'test2')
56 | self.assertRaisesRegex(ValueError, "The return data shape.*test2.*",
57 | engine.run, '2019-01-11', '2019-01-15', False)
58 | engine.remove_all_factors()
59 | test_f1 = TestFactor()
60 |
61 | class TestFactor2(spectre.factors.CustomFactor):
62 | inputs = [test_f1]
63 |
64 | def compute(self, test_input):
65 | return torch.tensor(np.cumsum(test_input.numpy(), axis=1))
66 |
67 | engine.add(test_f1, 'test1')
68 | self.assertRaisesRegex(KeyError, ".*exists.*",
69 | engine.add, TestFactor(), 'test1')
70 |
71 | engine.add(TestFactor2(), 'test2')
72 |
73 | for f in engine._factors.values():
74 | f.pre_compute_(engine, '2019-01-11', '2019-01-15')
75 | self.assertEqual(2, test_f1._ref_count)
76 | for f in engine._factors.values():
77 | f._ref_count = 0
78 |
79 | df = engine.run('2019-01-11', '2019-01-15', delay_factor=False)
80 | self.assertEqual(0, test_f1._ref_count)
81 | assert_array_equal([0, 3, 1, 4, 2, 5], df['test1'].values)
82 | assert_array_equal([0, 3, 1, 7, 3, 12], df['test2'].values)
83 |
--------------------------------------------------------------------------------
/tests/test_data_factor.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spectre
3 | from numpy.testing import assert_array_equal
4 | from os.path import dirname
5 | import warnings
6 |
7 | data_dir = dirname(__file__) + '/data/'
8 |
9 |
10 | class TestDataFactorLib(unittest.TestCase):
11 | def test_datafactor_value(self):
12 | warnings.filterwarnings("ignore", module='spectre')
13 | loader = spectre.data.CsvDirLoader(
14 | data_dir + '/daily/',
15 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
16 | prices_index='date', parse_dates=True,
17 | )
18 | engine = spectre.factors.FactorEngine(loader)
19 | engine.add(spectre.factors.OHLCV.volume, 'CpVol')
20 | df = engine.run('2019-01-11', '2019-01-15')
21 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'CpVol'].values,
22 | (28065422, 33834032))
23 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'CpVol'].values,
24 | (28627674, 28720936))
25 |
26 | engine.add(spectre.factors.ColumnDataFactor(inputs=('changePercent',)), 'Chg')
27 | df = engine.run('2019-01-11', '2019-01-15')
28 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'Chg'].values,
29 | (-0.9835, -1.5724))
30 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'Chg'].values,
31 | (-0.8025, -0.7489))
32 |
33 | engine.remove_all_factors()
34 | engine.add(spectre.factors.OHLCV.open, 'open')
35 | df = engine.run('2019-01-11', '2019-01-15', delay_factor=False)
36 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'open'].values,
37 | (155.72, 155.19, 150.81))
38 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'open'].values,
39 | (104.65, 104.9, 103.19))
40 |
41 | df = engine.run('2019-01-11', '2019-01-15')
42 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'open'].values,
43 | (155.72, 155.19, 150.81))
44 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'open'].values,
45 | (104.65, 104.9, 103.19))
--------------------------------------------------------------------------------
/tests/test_data_loader.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spectre
3 | import os
4 | import pandas as pd
5 | import numpy as np
6 | from numpy.testing import assert_almost_equal
7 | from os.path import dirname
8 | import warnings
9 |
10 | data_dir = dirname(__file__) + '/data/'
11 |
12 |
13 | class TestDataLoaderLib(unittest.TestCase):
14 | def _assertDFFirstLastEqual(self, tdf, col, expected_first, expected_last):
15 | self.assertAlmostEqual(tdf.loc[tdf.index[0], col], expected_first)
16 | self.assertAlmostEqual(tdf.loc[tdf.index[-1], col], expected_last)
17 |
18 | def test_required_parameters(self):
19 | loader = spectre.data.CsvDirLoader(data_dir + '/daily/')
20 | self.assertRaisesRegex(ValueError, "df must index by datetime.*",
21 | loader.load, '2019-01-01', '2019-01-15', 0)
22 | loader = spectre.data.CsvDirLoader(data_dir + '/daily/', prices_index='date', )
23 | self.assertRaisesRegex(ValueError, "df must index by datetime.*",
24 | loader.load, '2019-01-01', '2019-01-15', 0)
25 |
26 | def test_csv_loader_value(self):
27 | loader = spectre.data.CsvDirLoader(
28 | data_dir + '/daily/', calender_asset='AAPL', prices_index='date', parse_dates=True, )
29 | start, end = pd.Timestamp('2019-01-01', tz='UTC'), pd.Timestamp('2019-01-15', tz='UTC')
30 |
31 | # test backward
32 | df = loader.load(start, end, 11)
33 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'AAPL'), :], 'close', 173.43, 158.09)
34 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'close', 106.57, 105.36)
35 |
36 | # test value
37 | df = loader.load(start, end, 0)
38 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'AAPL'), :], 'close', 160.35, 158.09)
39 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'close', 100.1, 105.36)
40 | self._assertDFFirstLastEqual(df.loc[(slice('2019-01-11', '2019-01-12'), 'MSFT'), :],
41 | 'close', 104.5, 104.5)
42 | start, end = pd.Timestamp('2019-01-11', tz='UTC'), pd.Timestamp('2019-01-12', tz='UTC')
43 | df = loader.load(start, end, 0)
44 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'close', 104.5, 104.5)
45 |
46 | loader.test_load()
47 |
48 | def test_csv_split_loader_value(self):
49 | loader = spectre.data.CsvDirLoader(
50 | data_dir + '/5mins/', prices_by_year=True, prices_index='Date', parse_dates=True,
51 | ohlcv=None)
52 | start = pd.Timestamp('2019-01-02 14:30:00', tz='UTC')
53 | end = pd.Timestamp('2019-01-15', tz='UTC')
54 | loader.load(start, end, 0)
55 |
56 | start = pd.Timestamp('2018-12-31 14:50:00', tz='America/New_York').tz_convert('UTC')
57 | end = pd.Timestamp('2019-01-02 10:00:00', tz='America/New_York').tz_convert('UTC')
58 | df = loader.load(start, end, 0)
59 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'AAPL'), :], 'Open', 157.45, 155.17)
60 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'Open', 101.44, 99.55)
61 |
62 | loader.test_load()
63 |
64 | def test_csv_div_split(self):
65 | warnings.filterwarnings("ignore", module='spectre')
66 | start, end = pd.Timestamp('2019-01-02', tz='UTC'), pd.Timestamp('2019-01-15', tz='UTC')
67 | loader = spectre.data.CsvDirLoader(
68 | prices_path=data_dir + '/daily/', earliest_date=start.tz_convert(None),
69 | calender_asset='AAPL',
70 | dividends_path=data_dir + '/dividends/', splits_path=data_dir + '/splits/',
71 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), adjustments=('amount', 'ratio'),
72 | prices_index='date', dividends_index='exDate', splits_index='exDate',
73 | parse_dates=True, )
74 | loader.test_load()
75 |
76 | df = loader.load(start, end, 0)
77 |
78 | # test value
79 | self.assertAlmostEqual(df.loc[('2019-01-09', 'MSFT'), 'ex-dividend'], 0.57)
80 |
81 | # test adjustments in engine
82 | engine = spectre.factors.FactorEngine(loader)
83 | engine.add(spectre.factors.AdjustedColumnDataFactor(spectre.factors.OHLCV.volume), 'vol')
84 | engine.add(spectre.factors.AdjustedColumnDataFactor(spectre.factors.OHLCV.open), 'open')
85 | df = engine.run(start, end, delay_factor=False)
86 |
87 | expected_msft_open = [1526.24849, 1548.329113, 1536.244448, 1541.16783, 1563.696033,
88 | 1585.47827, 1569.750105, 104.9, 103.19]
89 | expected_msft_vol = [2947962.0000, 3067160.6000, 2443784.2667, 2176777.6000,
90 | 2190846.8000, 2018093.5333, 1908511.6000, 28720936.0000, 32882983.0000]
91 | expected_aapl_open = [155.9200, 147.6300, 148.8400, 148.9000, 150.0000, 157.4400, 154.1000,
92 | 155.7200, 155.1900, 150.8100]
93 | expected_aapl_vol = [37932561, 92707401, 59457561, 56974905, 42839940, 45105063,
94 | 35793075, 28065422, 33834032, 29426699]
95 |
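|         # The adjusted MSFT values sit far above the raw ~100 level,
|         # presumably because the split/dividend fixtures under tests/data
|         # use artificial ratios; only the bars after the ex-date stay raw.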
96 | assert_almost_equal(df.loc[(slice(None), 'MSFT'), 'open'], expected_msft_open, decimal=4)
97 | assert_almost_equal(df.loc[(slice(None), 'AAPL'), 'open'], expected_aapl_open, decimal=4)
98 | assert_almost_equal(df.loc[(slice(None), 'MSFT'), 'vol'], expected_msft_vol, decimal=0)
99 | assert_almost_equal(df.loc[(slice(None), 'AAPL'), 'vol'], expected_aapl_vol, decimal=4)
100 |
101 |         # rolling adjustment test
102 | result = []
103 |
104 | class RollingAdjTest(spectre.factors.CustomFactor):
105 | win = 10
106 |
107 | def compute(self, data):
108 | result.append(data.agg(lambda x: x[:, -1]))
109 | return data.last()
110 |
111 | engine = spectre.factors.FactorEngine(loader)
112 | engine.add(RollingAdjTest(inputs=[spectre.factors.OHLCV.volume]), 'vol')
113 | engine.add(RollingAdjTest(inputs=[spectre.factors.OHLCV.open]), 'open')
114 | engine.run(end, end, delay_factor=False)
115 |
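|         # result[0] holds the volume windows and result[1] the open windows;
|         # within each, index 0 is AAPL and 1 is MSFT. MSFT has one bar fewer
|         # than AAPL in this 10-bar window, hence the trailing NaN pad.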
116 | assert_almost_equal(result[0][0], expected_aapl_vol, decimal=4)
117 | assert_almost_equal(result[0][1], expected_msft_vol+[np.nan], decimal=0)
118 | assert_almost_equal(result[1][0], expected_aapl_open, decimal=4)
119 | assert_almost_equal(result[1][1], expected_msft_open+[np.nan], decimal=4)
120 |
121 | def test_no_ohlcv(self):
122 | warnings.filterwarnings("ignore", module='spectre')
123 | start, end = pd.Timestamp('2019-01-02'), pd.Timestamp('2019-01-15')
124 | loader = spectre.data.CsvDirLoader(
125 | prices_path=data_dir + '/daily/', earliest_date=start, calender_asset='AAPL',
126 | ohlcv=None, adjustments=None,
127 | prices_index='date',
128 | parse_dates=True, )
129 | engine = spectre.factors.FactorEngine(loader)
130 | engine.add(spectre.factors.ColumnDataFactor(inputs=['uOpen']), 'open')
131 | engine.run(start, end, delay_factor=False)
132 |
133 | @unittest.skipUnless(os.getenv('COVERAGE_RUNNING'), "too slow, run manually")
134 | def test_yahoo(self):
135 | yahoo_path = data_dir + '/yahoo/'
136 | try:
137 | os.remove(yahoo_path + 'yahoo.feather')
138 | os.remove(yahoo_path + 'yahoo.feather.meta')
139 | except FileNotFoundError:
140 | pass
141 |
142 | spectre.data.YahooDownloader.ingest("2011", yahoo_path, ['IBM', 'AAPL'], skip_exists=False)
143 | loader = spectre.data.ArrowLoader(yahoo_path + 'yahoo.feather')
144 | df = loader._load()
145 | self.assertEqual(['AAPL', 'IBM'], list(df.index.levels[1]))
146 |
147 | @unittest.skipUnless(os.getenv('COVERAGE_RUNNING'), "too slow, run manually")
148 | def test_QuandlLoader(self):
149 | quandl_path = data_dir + '../../../historical_data/us/prices/quandl/'
150 | try:
151 | os.remove(quandl_path + 'wiki_prices.feather')
152 | os.remove(quandl_path + 'wiki_prices.feather.meta')
153 | except FileNotFoundError:
154 | pass
155 |
156 | spectre.data.ArrowLoader.ingest(
157 | spectre.data.QuandlLoader(quandl_path + 'WIKI_PRICES.zip'),
158 | quandl_path + 'wiki_prices.feather'
159 | )
160 |
161 | loader = spectre.data.ArrowLoader(quandl_path + 'wiki_prices.feather')
162 |
163 | spectre.parallel.Rolling._split_multi = 80
164 | engine = spectre.factors.FactorEngine(loader)
165 | engine.add(spectre.factors.MA(100), 'ma')
166 | engine.to_cuda()
167 | df = engine.run("2014-01-02", "2014-01-02", delay_factor=False)
168 | # expected result comes from zipline
169 | assert_almost_equal(df.head().values.T,
170 | [[51.388700, 49.194407, 599.280580, 28.336585, 12.7058]], decimal=4)
171 | assert_almost_equal(df.tail().values.T,
172 | [[86.087988, 3.602880, 7.364000, 31.428209, 27.605950]], decimal=4)
173 |
174 |         # regression test for a bug in the last row
175 | engine.run("2016-12-15", "2017-01-02")
176 | df = engine._dataframe.loc[(slice('2016-12-15', '2017-12-15'), 'STJ'), :]
177 | assert df.price_multi.values[-1] == 1
178 |
179 | def test_fast_get(self):
180 | loader = spectre.data.CsvDirLoader(
181 | data_dir + '/daily/', prices_index='date', parse_dates=True, )
182 | df = loader.load()[list(loader.ohlcv)]
183 | getter = spectre.data.DataLoaderFastGetter(df)
184 |
185 | table = getter.get_as_dict(pd.Timestamp('2018-01-02', tz='UTC'), column_id=3)
186 | self.assertAlmostEqual(df.loc[("2018-01-02", 'MSFT')].close, table['MSFT'])
187 | self.assertAlmostEqual(df.loc[("2018-01-02", 'AAPL')].close, table['AAPL'])
188 | self.assertRaises(KeyError, table.__getitem__, 'A')
189 | table = dict(table.items())
190 | self.assertAlmostEqual(df.loc[("2018-01-02", 'MSFT')].close, table['MSFT'])
191 | self.assertAlmostEqual(df.loc[("2018-01-02", 'AAPL')].close, table['AAPL'])
192 |
193 | table = getter.get_as_dict(pd.Timestamp('2018-01-02', tz='UTC'))
194 | np.testing.assert_array_almost_equal(df.loc[("2018-01-02", 'MSFT')].values, table['MSFT'])
195 | np.testing.assert_array_almost_equal(df.loc[("2018-01-02", 'AAPL')].values, table['AAPL'])
196 |
197 | result_df = getter.get_as_df(pd.Timestamp('2018-01-02', tz='UTC'))
198 | expected = df.xs("2018-01-02")
199 | pd.testing.assert_frame_equal(expected, result_df)
200 |
201 | table = getter.get_as_dict(pd.Timestamp('2019-01-05', tz='UTC'), column_id=3)
202 | self.assertTrue(np.isnan(table['MSFT']))
203 | self.assertRaises(KeyError, table.__getitem__, 'AAPL')
204 |
205 | table = getter.get_as_dict(pd.Timestamp('2019-01-10', tz='UTC'), column_id=3)
206 | self.assertRaises(KeyError, table.__getitem__, 'MSFT')
207 |
208 |         # test 5-minute bars
209 | loader = spectre.data.CsvDirLoader(
210 | data_dir + '/5mins/', prices_by_year=True, prices_index='Date', parse_dates=True,
211 | ohlcv=None)
212 | df = loader.load()
213 | getter = spectre.data.DataLoaderFastGetter(df)
214 | table = getter.get_as_dict(
215 | pd.Timestamp('2018-12-20 00:00:00+00:00', tz='UTC'),
216 | pd.Timestamp('2018-12-20 23:59:59+00:00', tz='UTC'),
217 | column_id=0)
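|         # With two timestamps, get_as_dict appears to slice a time range;
|         # every returned row should fall on a single session, which the
|         # normalize().unique() length check below asserts.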
218 | self.assertTrue(len(table.get_datetime_index().normalize().unique()) == 1)
219 |
--------------------------------------------------------------------------------
/tests/test_event.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spectre
3 | from os.path import dirname
4 | import pandas as pd
5 |
6 | data_dir = dirname(__file__) + '/data/'
7 |
8 |
9 | class TestTradingEvent(unittest.TestCase):
10 |
11 | def test_event_mgr(self):
12 | class TestEventReceiver(spectre.trading.EventReceiver):
13 | fired = 0
14 |
15 | def on_run(self):
16 | self.schedule(spectre.trading.event.Always(self.test_always))
17 | self.schedule(spectre.trading.event.EveryBarData(self.test_every_bar))
18 |
19 | def test_always(self, _):
20 | self.fired += 1
21 |
22 | def test_every_bar(self, _):
23 | self.fired += 1
24 | if self.fired == 2:
25 | self.stop_event_manager()
26 |
27 | class StopFirer(spectre.trading.EventReceiver):
28 | def on_run(self):
29 | self.schedule(spectre.trading.event.Always(self.test))
30 |
31 | def test(self, _):
32 | self.fire_event(spectre.trading.event.EveryBarData)
33 |
34 | rcv = TestEventReceiver()
35 |
36 | evt_mgr = spectre.trading.EventManager()
37 | evt_mgr.subscribe(rcv)
38 | rcv.unsubscribe()
39 | self.assertEqual(0, len(evt_mgr._subscribers))
40 | self.assertRaisesRegex(ValueError, 'At least one subscriber.', evt_mgr.run)
41 | self.assertRaisesRegex(AssertionError, 'Subscriber not exists', evt_mgr.unsubscribe, rcv)
42 |
43 | evt_mgr.subscribe(rcv)
44 | evt_mgr.subscribe(StopFirer())
45 | self.assertRaisesRegex(AssertionError, 'Duplicate subscribe', evt_mgr.subscribe, rcv)
46 |
47 | evt_mgr.run()
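|         # Expected count: one Always tick, plus one EveryBarData fired by
|         # StopFirer, after which test_every_bar stops the manager -> 2.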
48 | self.assertEqual(2, rcv.fired)
49 |
50 | def test_calendar(self):
51 | tz = 'America/New_York'
52 | end = pd.Timestamp.now(tz=tz) + pd.DateOffset(days=10)
53 | first = pd.date_range(pd.Timestamp.now(tz=tz).normalize(), end, freq='B')[0]
54 | if pd.Timestamp.now(tz=tz) > (first + pd.Timedelta("9:00:00")):
55 | first = first + pd.offsets.BDay(1)
56 | holiday = first + pd.offsets.BDay(2)
57 | test_now = first + pd.offsets.BDay(1) + pd.Timedelta("10:00:00")
58 |
59 | calendar = spectre.trading.Calendar()
60 | calendar.build(start=str(pd.Timestamp.now(tz=tz).normalize()), end=str(end.date()),
61 | daily_events={'Open': '9:00:00', 'Close': '15:00:00'},
62 | tz=tz)
63 | calendar.set_as_holiday(holiday)
64 |
65 | self.assertEqual(first + pd.Timedelta("9:00:00"),
66 | calendar.events['Open'][0])
67 |
68 | calendar.hr_now = lambda: test_now
69 |
70 | calendar.pop_passed('Open')
71 |
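|         # hr_now is mocked to 10:00 one business day after `first`, so that
|         # day's 9:00 Open has passed; the following Open falls on `holiday`
|         # and is skipped, leaving first + 3 business days at 9:00.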
72 | self.assertEqual(test_now.normalize() + pd.offsets.BDay(2) + pd.Timedelta("9:00:00"),
73 | calendar.events['Open'][0])
74 |
75 |         # build() should reject an end date earlier than the start
76 | self.assertRaises(ValueError, calendar.build,
77 | start=str(pd.Timestamp.now(tz=tz).normalize()), end='2019',
78 | daily_events={'Open': '9:00:00', 'Close': '15:00:00'},
79 | tz=tz)
80 |
--------------------------------------------------------------------------------
/tests/test_metric.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spectre
3 | import pandas as pd
4 | from numpy import nan
5 | from numpy.testing import assert_almost_equal
6 |
7 |
8 | class TestMetric(unittest.TestCase):
9 |
10 | def test_metrics(self):
11 | ret = pd.Series([0.0022, 0.0090, -0.0067, 0.0052, 0.0030, -0.0012, -0.0091, 0.0082,
12 | -0.0071, 0.0093],
13 | index=pd.date_range('2040-01-01', periods=10))
14 |
15 | self.assertAlmostEqual(2.9101144, spectre.trading.sharpe_ratio(ret, 0.00))
16 | self.assertAlmostEqual(2.5492371, spectre.trading.sharpe_ratio(ret, 0.04))
17 |
18 | dd, ddu = spectre.trading.drawdown((ret+1).cumprod())
19 | vol = spectre.trading.annual_volatility(ret)
20 |
21 | self.assertAlmostEqual(0.0102891, dd.abs().max())
22 | self.assertAlmostEqual(5, ddu.max())
23 | self.assertAlmostEqual(0.110841, vol)
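|         # Sanity math: sharpe = mean(ret - rf/252) / std(ret, ddof=1)
|         # * sqrt(252) reproduces 2.9101 (rf=0) and 2.5492 (rf=4% annual),
|         # and std(ret, ddof=1) * sqrt(252) ~= 0.1108, so the metrics appear
|         # to annualize with 252 trading days.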
24 |
25 | txn = pd.DataFrame([['AAPL', 384, 155.92, 157.09960, 1.92],
26 | ['AAPL', -384, 158.61, 157.41695, 1.92]],
27 | columns=['symbol', 'amount', 'price',
28 | 'fill_price', 'commission'],
29 | index=['2040-01-01', '2040-01-02'])
30 | txn.index = pd.to_datetime(txn.index)
31 |
32 | pos = pd.DataFrame([[384, 384*155.92, 10000],
33 | [nan, nan, 10000+384*158.61]],
34 | columns=pd.MultiIndex.from_tuples(
35 | [('shares', 'AAPL'), ('value', 'AAPL'), ('value', 'cash')]),
36 | index=['2040-01-01', '2040-01-02'])
37 | pos.index = pd.to_datetime(pos.index)
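|         # Worked check: day-1 turnover = 384 * 157.0996 / (384 * 155.92
|         # + 10000) ~= 0.8634; day-2 = 384 * 157.41695 / (10000
|         # + 384 * 158.61) ~= 0.8525, i.e. traded value over that day's
|         # total portfolio value.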
38 | assert_almost_equal([0.8633665, 0.8525076], spectre.trading.turnover(pos, txn))
39 |
--------------------------------------------------------------------------------
/tests/test_parallel_algo.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import spectre
4 | from numpy.testing import assert_array_equal, assert_almost_equal
5 | import numpy as np
6 | import pandas as pd
7 |
8 |
9 | class TestParallelAlgorithm(unittest.TestCase):
10 | def test_groupby(self):
11 | test_x = torch.tensor([1, 2, 10, 3, 11, 20, 4, 21, 5, 12, 13, 14, 15], dtype=torch.float32)
12 | test_k = torch.tensor([1, 1, 2, 1, 2, 3, 1, 3, 1, 2, 2, 2, 2])
13 |
14 | groupby = spectre.parallel.ParallelGroupBy(test_k)
15 | groups = groupby.split(test_x)
16 | assert_array_equal([1., 2., 3., 4., 5., np.nan], groups[0].tolist())
17 | assert_array_equal([10., 11., 12., 13., 14., 15.], groups[1].tolist())
18 | assert_array_equal([20., 21., np.nan, np.nan, np.nan, np.nan], groups[2].tolist())
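|         # Groups are NaN-padded to the longest group's length so they stack
|         # into one dense tensor; revert() scatters values back to their
|         # original positions, dropping the padding.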
19 |
20 | revert_x = groupby.revert(groups)
21 | assert_array_equal(revert_x.tolist(), test_x.tolist())
22 |
23 | def test_rolling(self):
24 | x = torch.tensor([[164.0000, 163.7100, 158.6100, 145.230],
25 | [104.6100, 104.4200, 101.3000, 102.280]])
26 | expected = torch.tensor(
27 | [[np.nan, np.nan, 486.3200, 467.5500],
28 | [np.nan, np.nan, 310.3300, 308.0000]])
29 |
30 | self.assertRegex(str(spectre.parallel.Rolling(x, 3)),
31 | "spectre.parallel.Rolling object(.|\n)*tensor(.|\n)*")
32 | s = spectre.parallel.Rolling(x, 3).sum()
33 | assert_almost_equal(expected.numpy(), s.numpy(), decimal=4)
34 |
35 | # test adjustment
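|         # Factors are applied relative to the last bar of each window
|         # (price_i * y_i / y_last), as the hand-built expectation below
|         # shows: 164.00 * (0.25 / 0.5) == 164.00 / 2, and so on.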
36 | y = torch.tensor([[0.25, 0.25, 0.5, 1],
37 | [0.6, 0.75, 0.75, 1]])
38 | s = spectre.parallel.Rolling(x, 3, y).sum()
39 | expected = torch.tensor([
40 | [
41 | np.nan, np.nan,
42 | sum([164.0000 / 2, 163.7100 / 2, 158.6100]),
43 | sum([163.7100 / 4, 158.6100 / 2, 145.230]),
44 | ],
45 | [
46 | np.nan, np.nan,
47 | sum([104.6100 * (0.6 / 0.75), 104.4200, 101.3000]),
48 | sum([104.4200 * 0.75, 101.3000 * 0.75, 102.280]),
49 | ]
50 | ])
51 | assert_almost_equal(expected.numpy(), s.numpy(), decimal=4)
52 |
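|         # Large-input smoke test: a 1024 x 102400 rolling sum should finish
|         # without exhausting memory (Rolling presumably processes the
|         # windows in chunks).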
53 | x = torch.zeros([1024, 102400], dtype=torch.float64)
54 | spectre.parallel.Rolling(x, 252).sum()
55 |
56 | def test_nan(self):
57 | # dim=1
58 | data = [[1, 2, 1], [4, np.nan, 2], [7, 8, 1]]
59 | result = spectre.parallel.nanmean(torch.tensor(data, dtype=torch.float))
60 | expected = np.nanmean(data, axis=1)
61 | assert_almost_equal(expected, result, decimal=6)
62 |
63 | result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float))
64 | expected = np.nanstd(data, axis=1)
65 | assert_almost_equal(expected, result, decimal=6)
66 |
67 | result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float), ddof=1)
68 | expected = np.nanstd(data, axis=1, ddof=1)
69 | assert_almost_equal(expected, result, decimal=6)
70 |
71 | # dim=2
72 | data = [[[np.nan, 1, 2], [1, 2, 1]], [[np.nan, 4, np.nan], [4, np.nan, 2]],
73 | [[np.nan, 7, 8], [7, 8, 1]]]
74 | result = spectre.parallel.nanmean(torch.tensor(data, dtype=torch.float), dim=2)
75 | expected = np.nanmean(data, axis=2)
76 | assert_almost_equal(expected, result, decimal=6)
77 |
78 | result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float), dim=2)
79 | expected = np.nanstd(data, axis=2)
80 | assert_almost_equal(expected, result, decimal=6)
81 |
82 |         # nanlast: last non-NaN value along the given dim
83 | data = [[1, 2, np.nan], [4, np.nan, 2], [7, 8, 1]]
84 | result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float).cuda())
85 | expected = [2., 2., 1.]
86 | assert_almost_equal(expected, result.cpu(), decimal=6)
87 |
88 | data = [[[1, 2, np.nan], [4, np.nan, 2], [7, 8, 1]]]
89 | result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float).cuda(), dim=2)
90 | expected = [[2., 2., 1.]]
91 | assert_almost_equal(expected, result.cpu(), decimal=6)
92 |
93 | data = [1, 2, np.nan, 4, np.nan, 2, 7, 8, 1]
94 | result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float), dim=0)
95 | expected = [1.]
96 | assert_almost_equal(expected, result, decimal=6)
97 |
98 | data = [1, 2, np.nan]
99 | mask = [False, True, True]
100 | result = spectre.parallel.masked_first(
101 | torch.tensor(data, dtype=torch.float), torch.tensor(mask, dtype=torch.bool), dim=0)
102 | expected = [2.]
103 | assert_almost_equal(expected, result, decimal=6)
104 |
105 | mask = [False, False, True]
106 | result = spectre.parallel.masked_first(
107 | torch.tensor(data, dtype=torch.float), torch.tensor(mask, dtype=torch.bool), dim=0)
108 | expected = [np.nan]
109 | assert_almost_equal(expected, result, decimal=6)
110 |
111 | # nanmin/max
112 | data = [[1, 2, -14, np.nan, 2], [99999, 8, 1, np.nan, 2]]
113 | result = spectre.parallel.nanmax(torch.tensor(data, dtype=torch.float))
114 | expected = np.nanmax(data, axis=1)
115 | assert_almost_equal(expected, result, decimal=6)
116 |
117 | result = spectre.parallel.nanmin(torch.tensor(data, dtype=torch.float))
118 | expected = np.nanmin(data, axis=1)
119 | assert_almost_equal(expected, result, decimal=6)
120 |
121 | def test_stat(self):
122 | x = torch.tensor([[1., 2, 3, 4, 5], [10, 12, 13, 14, 16], [2, 2, 2, 2, 2, ]])
123 | y = torch.tensor([[-1., 2, 3, 4, -5], [11, 12, -13, 14, 15], [2, 2, 2, 2, 2, ]])
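|         # np.cov(x, y) returns a 2N x 2N block matrix; expected[:N, N:] is
|         # the cross-covariance block, and its diagonal holds the pairwise
|         # cov(x_i, y_i) that spectre.parallel.covariance computes row-wise.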
124 | result = spectre.parallel.covariance(x, y, ddof=1)
125 | expected = np.cov(x, y, ddof=1)
126 | expected = expected[:x.shape[0], x.shape[0]:]
127 | assert_almost_equal(np.diag(expected), result, decimal=6)
128 |
129 | coef, intcp = spectre.parallel.linear_regression_1d(x, y)
130 | from sklearn.linear_model import LinearRegression
131 | for i in range(3):
132 | reg = LinearRegression().fit(x[i, :, None], y[i, :, None])
133 | assert_almost_equal(reg.coef_, coef[i], decimal=6)
134 |
135 | # test pearsonr
136 | result = spectre.parallel.pearsonr(x, y)
137 | from scipy import stats
138 | for i in range(3):
139 | expected, _ = stats.pearsonr(x[i].tolist(), y[i].tolist())
140 | assert_almost_equal(expected, result[i], decimal=6)
141 |
142 |         # test spearman
143 | rank_x = spectre.parallel.rankdata(x)
144 | rank_y = spectre.parallel.rankdata(y)
145 | result = spectre.parallel.spearman(rank_x, rank_y)
146 | print(result)
147 | from scipy import stats
148 | for i in range(3):
149 | expected, _ = stats.spearmanr(x[i].tolist(), y[i].tolist())
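|             # scipy yields NaN for the constant third row (zero variance);
|             # NaN != NaN, so this swaps in 1, which is evidently spectre's
|             # convention for constant series.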
150 | if expected != expected:
151 | expected = 1
152 | assert_almost_equal(expected, result[i], decimal=6)
153 |
154 | # test quantile
155 | x = torch.tensor([[1, 2, np.nan, 3, 4, 5, 6], [3, 4, 5, 1.01, np.nan, 1.02, 1.03]])
156 | result = spectre.parallel.quantile(x, 5, dim=1)
157 | expected = pd.qcut(x[0].tolist(), 5, labels=False)
158 | assert_array_equal(expected, result[0])
159 | expected = pd.qcut(x[1].tolist(), 5, labels=False)
160 | assert_array_equal(expected, result[1])
161 |
162 | x = torch.tensor(
163 | [[[1, 2, np.nan, 3, 4, 5, 6],
164 | [3, 4, 5, 1.01, np.nan, 1.02, 1.03]],
165 | [[1, 2, 2.1, 3, 4, 5, 6],
166 | [3, 4, 5, np.nan, np.nan, 1.02, 1.03]]])
167 | result = spectre.parallel.quantile(x, 5, dim=2)
168 | expected = pd.qcut(x[0, 0].tolist(), 5, labels=False)
169 | assert_array_equal(expected, result[0, 0])
170 | expected = pd.qcut(x[0, 1].tolist(), 5, labels=False)
171 | assert_array_equal(expected, result[0, 1])
172 | expected = pd.qcut(x[1, 0].tolist(), 5, labels=False)
173 | assert_array_equal(expected, result[1, 0])
174 | expected = pd.qcut(x[1, 1].tolist(), 5, labels=False)
175 | assert_array_equal(expected, result[1, 1])
176 |
177 |         # regression test: a single-row input must not be squeezed
178 | x = torch.tensor([[1., 2, 3, 4, 5]])
179 | y = torch.tensor([[-1., 2, 3, 4, -5]])
180 | coef, intcp = spectre.parallel.linear_regression_1d(x, y)
181 | reg = LinearRegression().fit(x[0, :, None], y[0, :, None])
182 | assert_almost_equal(reg.coef_, coef[0], decimal=6)
183 |
184 | # test median
185 | x = torch.tensor([[1, 2, np.nan, 3, 4, 5, 6], [3, 4, 5, 1.01, np.nan, 1.02, 1.03]])
186 | u = torch.tensor([[True, False, True, True, True, True, True, ],
187 | [True, True, True, True, True, True, True]])
188 | np_x = torch.masked_fill(x, ~u, np.nan)
189 | median = np.nanmedian(np_x.cpu(), axis=1)
190 | median = np.expand_dims(median, axis=1)
191 | expected, _ = spectre.parallel.masked_kth_value_1d(x, u, [0.5], dim=1)
192 | assert_almost_equal(median, expected[0], decimal=6)
193 |
194 | def test_pad2d(self):
195 |         x = torch.tensor([[np.nan, 1, 1, np.nan, 1, np.nan, np.nan, 0, np.nan, 0, np.nan, np.nan, 0,
196 |                            np.nan, -1, np.nan, -1, np.nan, np.nan, np.nan, 1],
197 |                           [np.nan, 1, 0, np.nan, 1, np.nan, np.nan, 1, np.nan, -1, np.nan, -1, 0,
198 |                            np.nan, -1, np.nan, -1, np.nan, np.nan, np.nan, 1]])
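|         # pad_2d forward-fills NaNs along each row, mirrored below by a
|         # per-row pandas ffill().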
199 | result = spectre.parallel.pad_2d(x)
200 |
201 | expected = [pd.Series(x[0].numpy()).ffill(), pd.Series(x[1].numpy()).ffill()]
202 | assert_almost_equal(expected, result)
203 |
--------------------------------------------------------------------------------
/tests/test_trading_algorithm.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spectre
3 | import pandas as pd
4 | from os.path import dirname
5 | from numpy import nan
6 | from numpy.testing import assert_array_equal, assert_almost_equal
7 |
8 | data_dir = dirname(__file__) + '/data/'
9 |
10 |
11 | class TestTradingAlgorithm(unittest.TestCase):
12 |
13 | def test_simulation_event_manager(self):
14 | loader = spectre.data.CsvDirLoader(
15 | data_dir + '/daily/', ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
16 | prices_index='date', parse_dates=True,
17 | )
18 | parent = self
19 |
20 | class MockTestAlg(spectre.trading.CustomAlgorithm):
21 | _data = None
22 |             # records bar dates to verify the elapsed time is correct
23 | _bar_dates = []
24 | # test if the sequence of events per day is correct
25 | _seq = 0
26 |
27 | def __init__(self):
28 | self.blotter = spectre.trading.SimulationBlotter(loader)
29 |
30 | def clear(self):
31 | pass
32 |
33 | def run_engine(self, start, end, _=False):
34 | engine = spectre.factors.FactorEngine(loader)
35 | f = spectre.factors.MA(5)
36 | engine.add(f, 'f')
37 | # self._engines = {'main': engine}
38 | df = engine.run(start, end)
39 | return df, df.loc[df.index.get_level_values(0)[-1]]
40 |
41 | def initialize(self):
42 | self.schedule(spectre.trading.event.EveryBarData(
43 | lambda x: self.test_every_bar(self._data)
44 | ))
45 | self.schedule(spectre.trading.event.MarketOpen(self.test_before_open, -1000))
46 | self.schedule(spectre.trading.event.MarketOpen(self.test_open, 0))
47 | self.schedule(spectre.trading.event.MarketClose(self.test_before_close, -1000))
48 | self.schedule(spectre.trading.event.MarketClose(self.test_close, 1000))
49 |
50 | def _run_engine(self, source):
51 | self._data, _ = self.run_engine(None, None)
52 |
53 | def on_run(self):
54 | self.schedule(spectre.trading.event.EveryBarData(
55 | self._run_engine
56 | ))
57 | self.initialize()
58 |
59 | def on_end_of_run(self):
60 | pass
61 |
62 | def test_every_bar(self, data):
63 | self._seq += 1
64 | parent.assertEqual(1, self._seq)
65 |
66 | today = data.index.get_level_values(0)[-1]
67 | self._bar_dates.append(data.index.get_level_values(0)[-1])
68 | if today > pd.Timestamp("2019-01-10", tz='UTC'):
69 | self.stop_event_manager()
70 |
71 | def test_before_open(self, source):
72 | self._seq += 1
73 | parent.assertEqual(2, self._seq)
74 |
75 | def test_open(self, source):
76 | self._seq += 1
77 | parent.assertEqual(3, self._seq)
78 |
79 | def test_before_close(self, source):
80 | self._seq += 1
81 | parent.assertEqual(4, self._seq)
82 |
83 | def test_close(self, source):
84 | self._seq += 1
85 | parent.assertEqual(5, self._seq)
86 | self._seq = 0
87 |
88 | rcv = MockTestAlg()
89 |
90 | evt_mgr = spectre.trading.SimulationEventManager()
91 | evt_mgr.subscribe(rcv)
92 | evt_mgr.run("2019-01-01", "2019-01-15")
93 |
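|         # The backtest starts 2019-01-01, but the first bar lands on
|         # 2019-01-03, consistent with the default one-bar factor delay.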
94 | self.assertEqual(rcv._bar_dates[0], pd.Timestamp("2019-01-03", tz='UTC'))
95 | self.assertEqual(rcv._bar_dates[1], pd.Timestamp("2019-01-04", tz='UTC'))
96 |         # stop was requested on the first bar after 2019-01-10, so the run ends on 01-11
97 | self.assertEqual(rcv._bar_dates[-1], pd.Timestamp("2019-01-11", tz='UTC'))
98 |
99 | def test_one_engine_algorithm(self):
100 | self_test = self
101 |
102 | class OneEngineAlg(spectre.trading.CustomAlgorithm):
103 | def initialize(self):
104 | engine = self.get_factor_engine()
105 | ma5 = spectre.factors.MA(5)
106 | engine.add(ma5, 'ma5')
107 | engine.set_filter(ma5.top(5))
108 |
109 | self.schedule_rebalance(spectre.trading.event.MarketOpen(self.rebalance))
110 |
111 |                 self.blotter.long_only = True
112 | self.blotter.set_commission(0, 0, 0)
113 | self.blotter.set_slippage(0, 0)
114 |
115 | def rebalance(self, data, history):
116 | if 103.98 <= data.loc['MSFT', 'ma5'] <= 103.99:
117 | data = data.drop('MSFT', axis=0)
118 | if 'MSFT' in data.index:
119 | self_test.assertAlmostEqual(1268.4657576628665, data.loc['MSFT', 'ma5'])
120 | weights = data.ma5 / data.ma5.sum()
121 | assets = data.index
122 | self.blotter.batch_order_target_percent(assets, weights)
123 |
124 | def terminate(self, _):
125 | pass
126 |
127 | loader = spectre.data.CsvDirLoader(
128 | data_dir + '/daily/', calender_asset='AAPL',
129 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
130 | dividends_path=data_dir + '/dividends/', splits_path=data_dir + '/splits/',
131 | adjustments=('amount', 'ratio'),
132 | prices_index='date', dividends_index='exDate', splits_index='exDate', parse_dates=True,
133 | )
134 | results = spectre.trading.run_backtest(
135 | loader, OneEngineAlg, "2019-01-11", "2019-01-15")
136 |
137 |         # test that factor delay and the resulting orders are correct
138 | # --- day1 ---
139 | aapl_shares1 = int(1e5/155.19)
140 | aapl_cost1 = aapl_shares1*155.19
141 | cash1 = 1e5-aapl_cost1
142 | aapl_value_eod1 = aapl_shares1 * 157
143 | # --- day2 ---
144 | aapl_weight2 = 155.854 / (155.854+1268.466)
145 | msft_weight2 = 1268.466 / (155.854+1268.466)
146 | value_bod2 = aapl_shares1 * 150.81 + cash1
147 | aapl_shares_change = aapl_weight2 * value_bod2 / 150.81
148 | aapl_shares_change = int(round(aapl_shares_change)) - aapl_shares1
149 | aapl_shares2 = aapl_shares1 + aapl_shares_change
150 | aapl_basis = (155.19 * aapl_shares1 + aapl_shares_change * 150.81) / aapl_shares2
151 | aapl_value2 = aapl_shares2 * 156.94
152 | msft_shares2 = int(round(msft_weight2 * value_bod2 / 103.19))
153 | msft_value2 = msft_shares2 * 108.85
154 | cash2 = 1e5-aapl_cost1 + (aapl_shares1-aapl_shares2) * 150.81 - msft_shares2 * 103.19
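|         # Day 1 ends all-AAPL (MSFT's ma5 falls in the 103.98-103.99 window
|         # and is dropped); day 2 re-weights by the adjusted ma5 values.
|         # Fills use the next bar's open, with commission and slippage
|         # zeroed out in initialize().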
155 | expected = pd.DataFrame([[nan, nan, nan, nan, nan, nan, 100000.00],
156 | [155.19, nan, aapl_shares1, nan, aapl_value_eod1, nan, cash1],
157 | [aapl_basis, 103.19, aapl_shares2, msft_shares2, aapl_value2,
158 | msft_value2, cash2]],
159 | columns=pd.MultiIndex.from_tuples(
160 | [('avg_px', 'AAPL'), ('avg_px', 'MSFT'),
161 | ('shares', 'AAPL'), ('shares', 'MSFT'),
162 | ('value', 'AAPL'), ('value', 'MSFT'),
163 | ('value', 'cash')]),
164 | index=[pd.Timestamp("2019-01-13", tz='UTC'),
165 | pd.Timestamp("2019-01-14", tz='UTC'),
166 | pd.Timestamp("2019-01-15", tz='UTC')])
167 | expected.index.name = 'index'
168 | pd.testing.assert_frame_equal(expected, results.positions)
169 |
170 | def test_two_engine_algorithm(self):
171 | class TwoEngineAlg(spectre.trading.CustomAlgorithm):
172 | def initialize(self):
173 | engine_main = self.get_factor_engine('main')
174 | engine_test = self.get_factor_engine('test')
175 |
176 | ma5 = spectre.factors.MA(5)
177 | ma4 = spectre.factors.MA(4)
178 | engine_main.add(ma5, 'ma5')
179 | engine_test.add(ma4, 'ma4')
180 |
181 | self.schedule_rebalance(spectre.trading.event.MarketClose(
182 | self.rebalance, offset_ns=-10000))
183 |
184 | self.blotter.set_commission(0, 0.005, 1)
185 | self.blotter.set_slippage(0, 0.4)
186 |
187 | def rebalance(self, data, history):
188 | mask = data['test'].ma4 > data['main'].ma5
189 | masked_test = data['test'][mask]
190 | assets = masked_test.index
191 | weights = masked_test.ma4 / masked_test.ma4.sum()
192 | for asset, weight in zip(assets, weights):
193 | self.blotter.order_target_percent(asset, weight)
194 |
195 | def terminate(self, _):
196 | pass
197 |
198 | loader = spectre.data.CsvDirLoader(
199 | data_dir + '/daily/', calender_asset='AAPL',
200 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
201 | dividends_path=data_dir + '/dividends/', splits_path=data_dir + '/splits/',
202 | adjustments=('amount', 'ratio'),
203 | prices_index='date', dividends_index='exDate', splits_index='exDate', parse_dates=True,
204 | )
205 | blotter = spectre.trading.SimulationBlotter(loader)
206 | evt_mgr = spectre.trading.SimulationEventManager()
207 | alg = TwoEngineAlg(blotter, main=loader, test=loader)
208 | evt_mgr.subscribe(alg)
209 | evt_mgr.subscribe(blotter)
210 |
211 | evt_mgr.run("2019-01-10", "2019-01-15")
212 | first_run = str(blotter)
213 | evt_mgr.run("2019-01-10", "2019-01-15")
214 |
215 |         # the two runs should produce identical results.
216 | self.assertEqual(first_run, str(blotter))
217 | assert_array_equal(['AAPL', 'MSFT', 'AAPL', 'MSFT', 'AAPL'],
218 | blotter.get_transactions().symbol.values)
219 |
220 | def test_record(self):
221 | recorder = spectre.trading.CustomAlgorithm(None, main=None)._recorder
222 | recorder.record("2019-01-10", dict(a=1, b=2))
223 | recorder.record("2019-01-11", dict(a=2, b=3, c=4))
224 | df = recorder.to_df()
225 | expected = pd.DataFrame([[1, 2, nan],
226 | [2, 3, 4]],
227 | columns=['a', 'b', 'c'],
228 | index=["2019-01-10",
229 | "2019-01-11"])
230 | expected.index.name = 'date'
231 | pd.testing.assert_frame_equal(expected, df)
232 |
233 | def test_intraday_algorithm(self):
234 | class IntradayAlg(spectre.trading.CustomAlgorithm):
235 | order_shares = 0.3
236 |
237 | def initialize(self):
238 | engine_main = self.get_factor_engine()
239 | ma5 = spectre.factors.MA(5)
240 | engine_main.add(ma5, 'ma5')
241 |
242 | self.schedule_rebalance(spectre.trading.event.MarketClose(
243 | self.rebalance, offset_ns=-10000))
244 |
245 | self.blotter.set_commission(0, 0.005, 1)
246 | self.blotter.set_slippage(0, 0.4)
247 |
248 | def rebalance(self, data, history):
249 | self.blotter.order_target_percent('AAPL', self.order_shares)
250 | self.order_shares = -self.order_shares
251 |
252 | def terminate(self, _):
253 | pass
254 |
255 | loader = spectre.data.CsvDirLoader(
256 | data_dir + '/5mins/', prices_by_year=True, prices_index='Date',
257 | ohlcv=('Open', 'High', 'Low', 'Close', 'Volume'), parse_dates=True, )
258 | results = spectre.trading.run_backtest(loader, IntradayAlg, "2019-01-01", "2019-01-05")
259 |
260 | self.assertAlmostEqual(157.92, results.transactions.loc['2019-01-02 20:55:00+00:00'].price)
261 | self.assertAlmostEqual(142.09, results.transactions.loc['2019-01-03 20:55:00+00:00'].price)
262 | self.assertAlmostEqual(148.26, results.transactions.loc['2019-01-04 20:55:00+00:00'].price)
263 |
264 | class IntradayAlgOpen(spectre.trading.CustomAlgorithm):
265 | order_shares = 0.3
266 |
267 | def initialize(self):
268 | engine_main = self.get_factor_engine()
269 | ma5 = spectre.factors.MA(5)
270 | engine_main.add(ma5, 'ma5')
271 |
272 | self.schedule_rebalance(spectre.trading.event.MarketOpen(
273 | self.rebalance))
274 |
275 | self.blotter.set_commission(0, 0.005, 1)
276 | self.blotter.set_slippage(0, 0.4)
277 |
278 | def rebalance(self, data, history):
279 | self.blotter.order_target_percent('AAPL', self.order_shares)
280 | self.order_shares = -self.order_shares
281 |
282 | def terminate(self, _):
283 | pass
284 |
285 | loader = spectre.data.CsvDirLoader(
286 | data_dir + '/5mins/', prices_by_year=True, prices_index='Date',
287 | ohlcv=('Open', 'High', 'Low', 'Close', 'Volume'), parse_dates=True, )
288 | results = spectre.trading.run_backtest(loader, IntradayAlgOpen, "2019-01-01", "2019-01-05")
289 | assert_almost_equal([143.95, 144.58], results.transactions.price)
290 |
--------------------------------------------------------------------------------