├── .gitignore ├── LICENSE ├── README.md ├── atpy ├── __init__.py ├── backtesting │ ├── __init__.py │ ├── data_replay.py │ ├── environments.py │ ├── mock_exchange.py │ └── random_strategy.py ├── data │ ├── __init__.py │ ├── cache │ │ ├── __init__.py │ │ ├── influxdb_cache.py │ │ ├── influxdb_cache_requests.py │ │ ├── lmdb_cache.py │ │ └── postgres_cache.py │ ├── intrinio │ │ ├── __init__.py │ │ ├── api.py │ │ └── influxdb_cache.py │ ├── iqfeed │ │ ├── __init__.py │ │ ├── bar_util.py │ │ ├── filters.py │ │ ├── iqfeed_bar_data_provider.py │ │ ├── iqfeed_history_provider.py │ │ ├── iqfeed_influxdb_cache.py │ │ ├── iqfeed_influxdb_cache_requests.py │ │ ├── iqfeed_level_1_provider.py │ │ ├── iqfeed_news_provider.py │ │ ├── iqfeed_postgres_cache.py │ │ └── util.py │ ├── latest_data_snapshot.py │ ├── quandl │ │ ├── __init__.py │ │ ├── api.py │ │ ├── influxdb_cache.py │ │ └── postgres_cache.py │ ├── splits_dividends.py │ ├── tradingcalendar.py │ ├── ts_util.py │ └── util.py ├── ibapi │ ├── __init__.py │ └── ib_events.py ├── ml │ ├── __init__.py │ ├── cross_validation.py │ ├── frac_diff_features.py │ ├── labeling.py │ └── util.py └── portfolio │ ├── __init__.py │ ├── order.py │ └── portfolio_manager.py ├── scripts ├── iqfeed_to_postgres_bars_1d.py ├── iqfeed_to_postgres_bars_1m.py ├── iqfeed_to_postgres_bars_5m.py ├── iqfeed_to_postgres_bars_60m.py ├── postgres_to_lmdb_bars_1d.py ├── postgres_to_lmdb_bars_1m.py ├── postgres_to_lmdb_bars_5m.py ├── postgres_to_lmdb_bars_60m.py ├── quandl_sf0_to_postgres.py ├── update_influxdb_cache.py ├── update_influxdb_fundamentals_cache.py ├── update_postgres_adjustments_cache.py └── update_postgres_cache.py ├── setup.py └── tests ├── __init__.py ├── backtesting ├── __init__.py ├── test_data_replay.py ├── test_environments.py └── test_mock_exchange.py ├── data ├── __init__.py ├── test_splits_dividends.py ├── test_talib.py └── test_ts_utils.py ├── ibapi ├── __init__.py └── test_ibapi.py ├── intrinio ├── __init__.py └── test_api.py ├── iqfeed ├── __init__.py ├── test_bar_data_provider.py ├── test_history_provider.py ├── test_iqfeed_influxdb_cache.py ├── test_iqfeed_influxdb_cache_requests.py ├── test_iqfeed_postgres_cache.py ├── test_news_provider.py └── test_streaming_level_1.py ├── ml ├── __init__.py ├── test_cross_validation.py ├── test_data_pipeline.py ├── test_data_util.py └── test_labeling.py ├── portfolio ├── __init__.py └── test_portfolio_manager.py └── quandl ├── __init__.py └── test_api.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | .idea/ 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Ivan Vasilev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Event-based Algorithmic Trading For Python 2 | 3 | Event-based algorithmic trading library. Event handling is implemented via the [pyevents](https://github.com/ivan-vasilev/pyevents) 4 | library. The main features are: 5 | 6 | * Real-time and historical bar and tick data from [IQFeed](http://www.iqfeed.net/) via [@pyiqfeed](https://github.com/akapur/pyiqfeed). The data is provided as pandas MultiIndex dataframes. For this to work, you need an IQFeed subscription. 7 | * API integration with [Quandl](https://www.quandl.com/) and [INTRINIO](https://intrinio.com/). 8 | * Storing and retrieving historical data and other datasets with [PostgreSQL](https://www.postgresql.org) and [InfluxDB](https://www.influxdata.com/). Again, the data is provided via pandas dataframes. 9 | * Placing orders via the [Interactive Brokers Python API](https://github.com/InteractiveBrokers/tws-api-public). For this to work, you need an Interactive Brokers account. 10 | 11 | For more information on how to use the library, please check the unit tests.
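As a quick orientation — a minimal, hypothetical sketch (the frame contents and source names below are made up; the unit tests remain the authoritative examples) — the data replay component in `atpy/backtesting/data_replay.py` can be driven directly with in-memory pandas dataframes:

```python
import pandas as pd

from atpy.backtesting.data_replay import DataReplay

# two hypothetical price series, each indexed by timestamp
bars_a = pd.DataFrame({'close': [10.0, 10.5, 10.2]},
                      index=pd.date_range('2017-01-02', periods=3, tz='UTC'))
bars_b = pd.DataFrame({'close': [20.0, 19.5]},
                      index=pd.date_range('2017-01-03', periods=2, tz='UTC'))

dr = DataReplay() \
    .add_source(iter([bars_a]), 'source_a') \
    .add_source(iter([bars_b]), 'source_b')

# events come out sorted by time; each event carries a 'timestamp' plus, for every
# source that has data at that time, the corresponding dataframe slice
for event in dr:
    print(event['timestamp'], [k for k in event if k != 'timestamp'])
```

The same `DataReplay` instance can be wrapped in `DataReplayEvents` to publish each step to pyevents listeners instead of iterating it manually.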
12 | -------------------------------------------------------------------------------- /atpy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/__init__.py -------------------------------------------------------------------------------- /atpy/backtesting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/backtesting/__init__.py -------------------------------------------------------------------------------- /atpy/backtesting/data_replay.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import typing 4 | 5 | import pandas as pd 6 | 7 | from atpy.data.ts_util import overlap_by_symbol 8 | from pyevents.events import EventFilter 9 | 10 | 11 | class DataReplay(object): 12 | """Replay data from multiple sources, sorted by time. Each source provides a dataframe.""" 13 | 14 | def __init__(self): 15 | self._sources_defs = list() 16 | self._is_running = False 17 | 18 | def __iter__(self): 19 | if self._is_running: 20 | raise Exception("Cannot start iteration while the generator is working") 21 | 22 | self._is_running = True 23 | 24 | self._data = dict() 25 | 26 | self._timeline = None 27 | 28 | self._current_time = None 29 | 30 | sources = dict() 31 | 32 | for (iterator, name, historical_depth, listeners) in self._sources_defs: 33 | sources[name] = (iter(iterator), historical_depth, listeners) 34 | 35 | self._sources = sources 36 | 37 | return self 38 | 39 | def __next__(self): 40 | # delete "expired" dataframes and obtain new data from the providers 41 | for e, (dp, historical_depth, listeners) in dict(self._sources).items(): 42 | if e not in self._data or \ 43 | (self._current_time is not None and self._get_datetime_level(self._data[e].index)[-1] <= self._current_time): 44 | self._timeline = None 45 | 46 | now = datetime.datetime.now() 47 | try: 48 | df = next(dp) 49 | while df is not None and df.empty: 50 | df = next(dp) 51 | except StopIteration: 52 | df = None 53 | 54 | if df is not None: 55 | logging.getLogger(__name__).debug('Obtained data ' + str(e) + ' in ' + str(datetime.datetime.now() - now)) 56 | 57 | # prepend old data if exists 58 | self._data[e] = overlap_by_symbol(self._data[e], df, historical_depth) if e in self._data and historical_depth > 0 else df 59 | 60 | if listeners is not None: 61 | listeners({'type': 'pre_data', e + '_full': self._data[e]}) 62 | else: 63 | if e in self._data: 64 | del self._data[e] 65 | 66 | del self._sources[e] 67 | 68 | # build timeline 69 | if self._timeline is None and self._data: 70 | now = datetime.datetime.now() 71 | 72 | indices = [self._get_datetime_level(df.index) for df in self._data.values()] 73 | tzs = {ind.tz for ind in indices} 74 | 75 | if len(tzs) > 1: 76 | raise Exception("Multiple timezones detected") 77 | 78 | ind = indices[0].union_many(indices[1:]).unique().sort_values() 79 | 80 | self._timeline = pd.DataFrame(index=ind) 81 | 82 | for e, df in self._data.items(): 83 | ind = self._get_datetime_level(df.index) 84 | self._timeline[e] = False 85 | self._timeline.loc[ind, e] = True 86 | 87 | logging.getLogger(__name__).debug('Built timeline in ' + str(datetime.datetime.now() - now)) 88 | 89 | # produce results 90 | if self._timeline is not None: 91 | result = dict() 92 | 93 | if 
self._current_time is None or self._current_time < self._timeline.index[0]: 94 | self._current_time, current_index = self._timeline.index[0], 0 95 | elif self._current_time in self._timeline.index: 96 | current_index = self._timeline.index.get_loc(self._current_time) + 1 97 | self._current_time = self._timeline.index[current_index] 98 | else: 99 | self._current_time = self._timeline.loc[self._timeline.index > self._current_time].iloc[0].name 100 | current_index = self._timeline.index.get_loc(self._current_time) 101 | 102 | result['timestamp'] = self._current_time.to_pydatetime() 103 | 104 | row = self._timeline.iloc[current_index] 105 | 106 | for e in [e for e in row.index if row[e]]: 107 | df = self._data[e] 108 | _, historical_depth, _ = self._sources[e] 109 | ind = self._get_datetime_level(df) 110 | result[e] = df.loc[ind[max(0, ind.get_loc(self._current_time) - historical_depth)]:self._current_time] 111 | 112 | return result 113 | else: 114 | raise StopIteration() 115 | 116 | @staticmethod 117 | def _get_datetime_level(index): 118 | if isinstance(index, pd.DataFrame) or isinstance(index, pd.Series): 119 | index = index.index 120 | 121 | if isinstance(index, pd.DatetimeIndex): 122 | return index 123 | elif isinstance(index, pd.MultiIndex): 124 | return [l for l in index.levels if isinstance(l, pd.DatetimeIndex)][0] 125 | 126 | def add_source(self, data_provider: typing.Union[typing.Iterator, typing.Callable], name: str, historical_depth: int = 0, listeners: typing.Callable = None): 127 | """ 128 | Add source for data generation 129 | :param data_provider: return pd.DataFrame with either DateTimeIndex or MultiIndex, where one of the levels is of datetime type 130 | :param name: data set name for each of the data sources 131 | :param historical_depth: whether to return only the current element or with historical depth 132 | :param listeners: Fire event after each data provider request. 133 | This is necessary, because the data replay functionality is combining the new/old dataframes for continuity. 134 | Process data, once obtained from the data provider (applied once for the whole chunk). 
135 | :return: self 136 | """ 137 | if self._is_running: 138 | raise Exception("Cannot add sources while the generator is working") 139 | 140 | self._sources_defs.append((data_provider, name, historical_depth, listeners)) 141 | 142 | return self 143 | 144 | 145 | class DataReplayEvents(object): 146 | """Add source for data generation""" 147 | 148 | def __init__(self, listeners, data_replay: DataReplay, event_name: str): 149 | self.listeners = listeners 150 | self.data_replay = data_replay 151 | self.event_name = event_name 152 | 153 | def start(self): 154 | for d in self.data_replay: 155 | d['type'] = self.event_name 156 | self.listeners(d) 157 | 158 | def event_filter(self) -> EventFilter: 159 | """ 160 | Return event filter, which only calls the listener for the main data replay event 161 | """ 162 | 163 | return EventFilter(listeners=self.listeners, 164 | event_filter=lambda e: True if e['type'] == self.event_name else False) 165 | 166 | def event_filter_by_source(self, source_name: str) -> EventFilter: 167 | """ 168 | Return event filter, which only calls the listener for the main data replay event 169 | if source_name exists in the event 170 | :param source_name: transform the event to the dataframe of source_name only 171 | """ 172 | 173 | return EventFilter(listeners=self.listeners, 174 | event_filter= 175 | lambda e: True if 'timestamp' in e and e['type'] == self.event_name and source_name in e else False, 176 | event_transformer=lambda e: (e[source_name],)) 177 | 178 | def event_filter_function(self, source_name: str = None) -> typing.Callable: 179 | """ 180 | Return event filter function, which returns True for data replay events only 181 | If source_name is specified, it also filters for source name 182 | :return the event itself if this is a data replay event and source_name exists in the dict (if source_name is specified) 183 | """ 184 | 185 | if source_name is not None: 186 | return lambda e: e[source_name] if 'timestamp' in e and e['type'] == self.event_name and source_name in e else None 187 | else: 188 | return lambda e: e if e['type'] == self.event_name else None 189 | -------------------------------------------------------------------------------- /atpy/backtesting/mock_exchange.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import typing 4 | 5 | import pandas as pd 6 | 7 | import atpy.portfolio.order as orders 8 | from atpy.data.iqfeed.util import get_last_value 9 | from pyevents.events import EventFilter 10 | 11 | 12 | class MockExchange(object): 13 | """ 14 | Mock exchange for executing trades based on the current streaming prices. Works with realtime and historical data. 15 | """ 16 | 17 | def __init__(self, 18 | listeners, 19 | order_requests_event_stream=None, 20 | bar_event_stream=None, 21 | tick_event_stream=None, 22 | order_processor: typing.Callable = None, 23 | commission_loss: typing.Callable = None): 24 | """ 25 | :param order_requests_event_stream: event stream for order events 26 | :param bar_event_stream: event stream for bar data events 27 | :param tick_event_stream: event stream for tick data events 28 | :param order_processor: a function which takes the current bar/tick volume and price and the current order. 29 | It applies some logic to return allowed volume and price for the order, given the current conditions. 
30 | This function might apply slippage and so on 31 | :param commission_loss: apply commission loss to the price 32 | """ 33 | 34 | order_requests_event_stream += self.process_order_request 35 | 36 | if bar_event_stream is not None: 37 | bar_event_stream += self.process_bar_data 38 | 39 | if tick_event_stream is not None: 40 | tick_event_stream += self.process_tick_data 41 | 42 | self.listeners = listeners 43 | 44 | self.order_processor = order_processor if order_processor is not None else lambda order, price, volume: (price, volume) 45 | self.commission_loss = commission_loss if commission_loss is not None else lambda o: 0 46 | 47 | self._pending_orders = list() 48 | self._lock = threading.RLock() 49 | 50 | def process_order_request(self, order): 51 | with self._lock: 52 | self._pending_orders.append(order) 53 | 54 | def process_tick_data(self, data): 55 | with self._lock: 56 | matching_orders = [o for o in self._pending_orders if o.symbol == data['symbol']] 57 | for o in matching_orders: 58 | data = get_last_value(data) 59 | if o.order_type == orders.Type.BUY: 60 | if 'tick_id' in data: 61 | price, volume = self.order_processor(o, data['ask'], data['last_size']) 62 | 63 | o.add_position(volume, price) 64 | else: 65 | price, volume = self.order_processor(order=o, 66 | price=data['ask'] if data['ask_size'] > 0 else data['most_recent_trade'], 67 | volume=data['ask_size'] if data['ask_size'] > 0 else data['most_recent_trade_size']) 68 | 69 | o.add_position(volume, price) 70 | 71 | o.commission = self.commission_loss(o) 72 | elif o.order_type == orders.Type.SELL: 73 | if 'tick_id' in data: 74 | price, volume = self.order_processor(o, data['bid'], data['last_size']) 75 | o.add_position(price, volume) 76 | else: 77 | price, volume = self.order_processor(order=o, 78 | price=data['bid'] if data['bid_size'] > 0 else data['most_recent_trade'], 79 | volume=data['bid_size'] if data['bid_size'] > 0 else data['most_recent_trade_size']) 80 | 81 | o.add_position(price, volume) 82 | 83 | o.commission = self.commission_loss(o) 84 | if o.fulfill_time is not None: 85 | self._pending_orders.remove(o) 86 | 87 | logging.getLogger(__name__).info("Order fulfilled: " + str(o)) 88 | 89 | self.listeners({'type': 'order_fulfilled', 'data': o}) 90 | 91 | def process_bar_data(self, data): 92 | with self._lock: 93 | symbols = data.index.get_level_values(level='symbol') 94 | 95 | symbol_ind = data.index.names.index('symbol') 96 | 97 | for o in [o for o in self._pending_orders if o.symbol in symbols]: 98 | ix = pd.IndexSlice[:, o.symbol] if symbol_ind == 1 else pd.IndexSlice[o.symbol, :] 99 | slc = data.loc[ix, :] 100 | 101 | if not slc.empty: 102 | price, volume = self.order_processor(o, slc.iloc[-1]['close'], slc.iloc[-1]['volume']) 103 | 104 | o.add_position(min(o.quantity - o.obtained_quantity, volume), price) 105 | 106 | o.commission = self.commission_loss(o) 107 | 108 | if o.fulfill_time is not None: 109 | self._pending_orders.remove(o) 110 | logging.getLogger(__name__).info("Order fulfilled: " + str(o)) 111 | 112 | self.listeners({'type': 'order_fulfilled', 'data': o}) 113 | 114 | def fulfilled_orders_stream(self): 115 | return EventFilter(listeners=self.listeners, 116 | event_filter=lambda e: True if 'type' in e and e['type'] == 'order_fulfilled' else False, 117 | event_transformer=lambda e: (e['data'],)) 118 | 119 | 120 | class StaticSlippageLoss: 121 | """Apply static loss value to account for slippage per each order""" 122 | 123 | def __init__(self, loss_rate: float, max_order_volume: float = 1.0): 124 | """ 
125 | :param loss_rate: slippage loss rate [0:1] coefficient for each order 126 | :param max_order_volume: [0:1] coefficient, which says how much of the available volume can be assigned to this order 127 | """ 128 | self.loss_rate = loss_rate 129 | self.max_order_volume = max_order_volume 130 | 131 | def __call__(self, order: orders.BaseOrder, price: float, volume: int): 132 | if order.order_type == orders.Type.BUY: 133 | return price + self.loss_rate * price, int(volume * self.max_order_volume) 134 | elif order.order_type == orders.Type.SELL: 135 | return price - self.loss_rate * price, int(volume * self.max_order_volume) 136 | 137 | 138 | class PerShareCommissionLoss: 139 | """Apply commission loss for each share""" 140 | 141 | def __init__(self, value): 142 | self.value = value 143 | 144 | def __call__(self, o: orders.BaseOrder): 145 | return o.obtained_quantity * self.value 146 | -------------------------------------------------------------------------------- /atpy/backtesting/random_strategy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | import pandas as pd 5 | 6 | from atpy.portfolio.portfolio_manager import PortfolioManager, MarketOrder, Type 7 | from pyevents.events import EventFilter 8 | 9 | 10 | class RandomStrategy: 11 | """Random buy/sell on each step""" 12 | 13 | def __init__(self, listeners, bar_event_stream, portfolio_manager: PortfolioManager, max_buys_per_step=1, max_sells_per_step=1): 14 | """ 15 | :param listeners: listeners environment 16 | :param bar_event_stream: bar events 17 | :param portfolio_manager: Portfolio manager 18 | :param max_buys_per_step: maximum buy orders per time step (one bar) 19 | :param max_sells_per_step: maximum sell orders per time step (one bar) 20 | """ 21 | self.listeners = listeners 22 | bar_event_stream += self.on_bar_event 23 | 24 | self.portfolio_manager = portfolio_manager 25 | self.max_buys_per_step = max_buys_per_step 26 | self.max_sells_per_step = max_sells_per_step 27 | 28 | def on_bar_event(self, data): 29 | buys = random.randint(0, min(len(data.index.get_level_values('symbol')), self.max_buys_per_step)) 30 | 31 | for _ in range(buys): 32 | symbol = data.sample().index.get_level_values('symbol')[0] 33 | volume = random.randint(1, data.loc[pd.IndexSlice[:, symbol], :].iloc[-1]['volume']) 34 | 35 | o = MarketOrder(Type.BUY, symbol, volume) 36 | 37 | logging.getLogger(__name__).debug('Placing new order ' + str(o)) 38 | 39 | self.listeners({'type': 'order_request', 'data': o}) 40 | 41 | quantities = self.portfolio_manager.quantity() 42 | sells = random.randint(0, min(len(quantities), self.max_sells_per_step)) 43 | 44 | selected_symbols = set() 45 | orders = list() 46 | for _ in range(sells): 47 | symbol, volume = random.choice(list(quantities.items())) 48 | while symbol in selected_symbols: 49 | symbol, volume = random.choice(list(quantities.items())) 50 | 51 | selected_symbols.add(symbol) 52 | orders.append(MarketOrder(Type.SELL, symbol, random.randint(1, min(self.portfolio_manager.quantity(symbol), volume)))) 53 | 54 | for o in orders: 55 | logging.getLogger(__name__).debug('Placing new order ' + str(o)) 56 | self.listeners({'type': 'order_request', 'data': o}) 57 | 58 | def order_requests_stream(self): 59 | return EventFilter(listeners=self.listeners, 60 | event_filter=lambda e: True if ('type' in e and e['type'] == 'order_request') else False, 61 | event_transformer=lambda e: (e['data'],)) 62 | 
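The order_processor and commission_loss hooks that MockExchange accepts above are plain callables, so their behavior can be illustrated in isolation. A minimal sketch follows (the symbol, quantities and rates are made up; MarketOrder and Type are imported the same way random_strategy.py imports them):

from atpy.backtesting.mock_exchange import StaticSlippageLoss, PerShareCommissionLoss
from atpy.portfolio.portfolio_manager import MarketOrder, Type

slippage = StaticSlippageLoss(loss_rate=0.01, max_order_volume=0.5)
commission = PerShareCommissionLoss(0.005)

order = MarketOrder(Type.BUY, 'AAPL', 100)  # hypothetical buy order for 100 shares

# a buy is filled 1% above the quoted price, and gets at most half of the available volume
price, volume = slippage(order, price=50.0, volume=300)  # -> roughly (50.5, 150)

# the commission charged is per obtained share: order.obtained_quantity * 0.005
order.commission = commission(order)

Both callables are passed to MockExchange through its order_processor and commission_loss constructor arguments, which is how the backtests in this package model slippage and trading costs.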
-------------------------------------------------------------------------------- /atpy/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/data/__init__.py -------------------------------------------------------------------------------- /atpy/data/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/data/cache/__init__.py -------------------------------------------------------------------------------- /atpy/data/cache/influxdb_cache.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import queue 4 | import threading 5 | import typing 6 | from functools import partial 7 | 8 | from dateutil import tz 9 | from dateutil.parser import parse 10 | from dateutil.relativedelta import relativedelta 11 | from influxdb import InfluxDBClient, DataFrameClient 12 | 13 | 14 | class BarsFilter(typing.NamedTuple): 15 | ticker: typing.Union[list, str] 16 | interval_len: int 17 | interval_type: str 18 | bgn_prd: datetime.datetime 19 | 20 | 21 | def ranges(client: InfluxDBClient): 22 | """ 23 | :return: list of latest times for each entry grouped by symbol and interval 24 | """ 25 | parse_time = lambda t: parse(t).replace(tzinfo=tz.gettz('UTC')) 26 | 27 | points = InfluxDBClient.query(client, "select FIRST(close), symbol, interval, time from bars group by symbol, interval").get_points() 28 | firsts = {(entry['symbol'], int(entry['interval'].split('_')[0]), entry['interval'].split('_')[1]): parse_time(entry['time']) for entry in points} 29 | 30 | points = InfluxDBClient.query(client, "select LAST(close), symbol, interval, time from bars group by symbol, interval").get_points() 31 | lasts = {(entry['symbol'], int(entry['interval'].split('_')[0]), entry['interval'].split('_')[1]): parse_time(entry['time']) for entry in points} 32 | 33 | result = {k: (firsts[k], lasts[k]) for k in firsts.keys() & lasts.keys()} 34 | 35 | return result 36 | 37 | 38 | def update_to_latest(client: DataFrameClient, noncache_provider: typing.Callable, new_symbols: set = None, time_delta_back: relativedelta = relativedelta(years=5), skip_if_older_than: relativedelta = None): 39 | """ 40 | Update existing entries in the database to the most current values 41 | :param client: DataFrameClient client 42 | :param noncache_provider: Non cache data provider 43 | :param new_symbols: additional symbols to add {(symbol, interval_len, interval_type), ...}} 44 | :param time_delta_back: start 45 | :param skip_if_older_than: skip symbol update if the symbol is older than... 
46 | :return: 47 | """ 48 | filters = dict() 49 | 50 | new_symbols = set() if new_symbols is None else new_symbols 51 | 52 | if skip_if_older_than is not None: 53 | skip_if_older_than = (datetime.datetime.utcnow().replace(tzinfo=tz.gettz('UTC')) - skip_if_older_than).astimezone(tz.gettz('US/Eastern')) 54 | 55 | for key, time in [(e[0], e[1][1]) for e in ranges(client).items()]: 56 | if key in new_symbols: 57 | new_symbols.remove(key) 58 | 59 | if skip_if_older_than is None or time > skip_if_older_than: 60 | bgn_prd = datetime.datetime.combine(time.date(), datetime.datetime.min.time()).replace(tzinfo=tz.gettz('US/Eastern')) 61 | filters[BarsFilter(ticker=key[0], bgn_prd=bgn_prd, interval_len=key[1], interval_type=key[2])] = None 62 | 63 | bgn_prd = datetime.datetime.combine(datetime.datetime.utcnow().date() - time_delta_back, datetime.datetime.min.time()).replace(tzinfo=tz.gettz('US/Eastern')) 64 | for (symbol, interval_len, interval_type) in new_symbols: 65 | filters[BarsFilter(ticker=symbol, bgn_prd=bgn_prd, interval_len=interval_len, interval_type=interval_type)] = None 66 | 67 | logging.getLogger(__name__).info("Updating " + str(len(filters)) + " total symbols and intervals; New symbols and intervals: " + str(len(new_symbols))) 68 | 69 | q = queue.Queue(maxsize=100) 70 | 71 | threading.Thread(target=partial(noncache_provider, filters=filters, q=q), daemon=True).start() 72 | 73 | try: 74 | for i, tupl in enumerate(iter(q.get, None)): 75 | ft, to_cache = filters[tupl[0]], tupl[1] 76 | 77 | if to_cache is not None and not to_cache.empty: 78 | # Prepare data 79 | for c in [c for c in to_cache.columns if c not in ['symbol', 'open', 'high', 'low', 'close', 'volume']]: 80 | to_cache.drop(c, axis=1, inplace=True) 81 | 82 | to_cache['interval'] = str(ft.interval_len) + '_' + ft.interval_type 83 | 84 | if to_cache.iloc[0].name == ft.bgn_prd: 85 | to_cache = to_cache.iloc[1:] 86 | 87 | try: 88 | client.write_points(to_cache, 'bars', protocol='line', tag_columns=['symbol', 'interval'], time_precision='s') 89 | except Exception as err: 90 | logging.getLogger(__name__).exception(err) 91 | 92 | if i > 0 and (i % 20 == 0 or i == len(filters)): 93 | logging.getLogger(__name__).info("Cached " + str(i) + " queries") 94 | finally: 95 | client.close() 96 | 97 | 98 | def add_adjustments(client: InfluxDBClient, adjustments: list, provider: str): 99 | """ 100 | add a list of splits/dividends to the database 101 | :param client: influxdb client 102 | :param adjustments: list of adjustments of the type [(timestamp: datetime.date, symbol: str, typ: str, value), ...] 
103 | :param provider: data provider 104 | """ 105 | points = [_get_adjustment_json_query(*a, provider=provider) for a in adjustments] 106 | return InfluxDBClient.write_points(client, points, protocol='json', time_precision='s') 107 | 108 | 109 | def add_adjustment(client: InfluxDBClient, timestamp: datetime.date, symbol: str, typ: str, value: float, provider: str): 110 | """ 111 | add splits/dividends to the database 112 | :param client: influxdb client 113 | :param timestamp: date of the adjustment 114 | :param symbol: symbol 115 | :param typ: 'split' or 'dividend' 116 | :param value: split_factor/dividend_rate 117 | :param provider: data provider 118 | """ 119 | json_body = _get_adjustment_json_query(timestamp=timestamp, symbol=symbol, typ=typ, value=value, provider=provider) 120 | return InfluxDBClient.write_points(client, [json_body], protocol='json', time_precision='s') 121 | 122 | 123 | def _get_adjustment_json_query(timestamp: datetime.date, symbol: str, typ: str, value: float, provider: str): 124 | return { 125 | "measurement": "splits_dividends", 126 | "tags": { 127 | "symbol": symbol, 128 | "provider": provider, 129 | }, 130 | 131 | "time": datetime.datetime.combine(timestamp, datetime.datetime.min.time()), 132 | "fields": {'value': value, 'type': typ} 133 | } 134 | -------------------------------------------------------------------------------- /atpy/data/cache/lmdb_cache.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import zlib 3 | 4 | import lmdb 5 | 6 | 7 | def write(key: str, value, lmdb_path: str, compress=True): 8 | with lmdb.open(lmdb_path, map_size=int(1e12)) as lmdb_env, lmdb_env.begin(write=True) as lmdb_txn: 9 | lmdb_txn.put(key.encode(), zlib.compress(pickle.dumps(value)) if compress else pickle.dumps(value)) 10 | 11 | 12 | def read(key: str, lmdb_path: str, decompress=True): 13 | with lmdb.open(lmdb_path) as lmdb_env, lmdb_env.begin() as lmdb_txn: 14 | result = lmdb_txn.get(key.encode()) 15 | if result is not None: 16 | if decompress: 17 | result = zlib.decompress(result) 18 | 19 | result = pickle.loads(result) 20 | 21 | return result 22 | 23 | 24 | def read_pickle(key: str, lmdb_path: str, decompress=True): 25 | return read(key=key, lmdb_path=lmdb_path, decompress=decompress) 26 | -------------------------------------------------------------------------------- /atpy/data/intrinio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/data/intrinio/__init__.py -------------------------------------------------------------------------------- /atpy/data/intrinio/api.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import logging 4 | import os 5 | import queue 6 | import threading 7 | import typing 8 | from io import StringIO 9 | from multiprocessing.pool import ThreadPool 10 | 11 | import pandas as pd 12 | import requests.sessions as sessions 13 | 14 | 15 | def to_dataframe(csv_str: str): 16 | """ 17 | Convert csv result to DataFrame 18 | :param csv_str: csv string 19 | :return: pd.DataFrame 20 | """ 21 | dates = [d for d in csv_str.split('\n', 1)[0].split(',') if 'date' in d or 'period' in d] 22 | return pd.read_csv(StringIO(csv_str), parse_dates=dates) 23 | 24 | 25 | def get_csv(sess: sessions.Session, endpoint: str, **parameters): 26 | """ 27 | get csv data from the Intrinio API 28 | :param 
sess: session 29 | :param endpoint: endpoint 30 | :param parameters: query parameters 31 | :return: csv result 32 | """ 33 | auth = os.getenv('INTRINIO_USERNAME'), os.getenv('INTRINIO_PASSWORD') 34 | 35 | url = '{}/{}'.format('https://api.intrinio.com', endpoint + ('' if endpoint.endswith('.csv') else '.csv')) 36 | 37 | if 'page_size' not in parameters: 38 | parameters['page_size'] = 10000 39 | 40 | pages = list() 41 | 42 | for page_number in itertools.count(): 43 | parameters['page_number'] = page_number + 1 44 | 45 | response = sess.request('GET', url, params=parameters, auth=auth, verify=True) 46 | if not response.ok: 47 | try: 48 | response.raise_for_status() 49 | except Exception as err: 50 | logging.getLogger(__name__).error(err) 51 | 52 | new_lines = response.content.decode('utf-8').count('\n') 53 | 54 | if new_lines == 1: 55 | break 56 | 57 | info, columns, page = response.content.decode('utf-8').split('\n', 2) 58 | 59 | if page_number == 0: 60 | info = {s.split(':')[0]: s.split(':')[1] for s in info.split(',')} 61 | total_pages = int(info['TOTAL_PAGES']) 62 | pages.append(columns.lower() + '\n') 63 | 64 | pages.append(page) 65 | 66 | if len(page) == 0 or page_number + 1 == total_pages: 67 | break 68 | 69 | return ''.join(pages) if len(pages) > 0 else None 70 | 71 | 72 | def get_data(filters: typing.List[dict], threads=1, async=False, processor: typing.Callable = None): 73 | """ 74 | Get async data for a list of filters. Works only for the historical API 75 | :param filters: a list of filters 76 | :param threads: number of threads for data retrieval 77 | :param async: if True, return queue. Otherwise, wait for the results 78 | :param processor: process the results 79 | 80 | For full list of available parameters check http://docs.intrinio.com/#historical-data 81 | 82 | :return Queue or pd.DataFrame with identifier, date set as multi index 83 | """ 84 | q = queue.Queue(100) 85 | pool = ThreadPool(threads) 86 | global_counter = {'c': 0} 87 | lock = threading.Lock() 88 | no_data = set() 89 | 90 | with sessions.Session() as sess: 91 | def mp_worker(f): 92 | try: 93 | data = get_csv(sess, **f) 94 | except Exception as err: 95 | data = None 96 | logging.getLogger(__name__).exception(err) 97 | 98 | if data is not None: 99 | q.put(processor(data, **f) if processor is not None else (json.dumps(f), data)) 100 | else: 101 | no_data.add(json.dumps(f)) 102 | 103 | with lock: 104 | global_counter['c'] += 1 105 | cnt = global_counter['c'] 106 | if cnt == len(filters): 107 | q.put(None) 108 | 109 | if cnt % 20 == 0 or cnt == len(filters): 110 | logging.getLogger(__name__).info("Loaded " + str(cnt) + " queries") 111 | if len(no_data) > 0: 112 | no_data_list = list(no_data) 113 | no_data_list.sort() 114 | logging.getLogger(__name__).info("No data found for " + str(len(no_data_list)) + " queries: " + str(no_data_list)) 115 | no_data.clear() 116 | 117 | if threads > 1 and len(filters) > 1: 118 | pool.map(mp_worker, (f for f in filters)) 119 | pool.close() 120 | else: 121 | for f in filters: 122 | mp_worker(f) 123 | 124 | if not async: 125 | result = dict() 126 | for job in iter(q.get, None): 127 | if job is None: 128 | break 129 | 130 | result[job[0]] = job[1] 131 | 132 | return result 133 | else: 134 | return q 135 | 136 | 137 | def get_historical_data(filters: typing.List[dict], threads=1, async=False): 138 | for f in filters: 139 | if 'endpoint' not in f: 140 | f['endpoint'] = 'historical_data' 141 | elif f['endpoint'] != 'historical_data': 142 | raise Exception("Only historical data is allowed with 
this request") 143 | 144 | result = get_data(filters, 145 | threads=threads, 146 | async=async, 147 | processor=_historical_data_processor) 148 | 149 | if not async and isinstance(result, dict): 150 | result = pd.concat(result) 151 | result.index.set_names('symbol', level=0, inplace=True) 152 | result = result.tz_localize('UTC', level=1, copy=False) 153 | 154 | return result 155 | 156 | 157 | def _historical_data_processor(csv_str: str, **parameters): 158 | """ 159 | Get historical data for given item and identifier 160 | :param csv_str: csv string 161 | :return pd.DataFrame with date set as index 162 | """ 163 | 164 | result = to_dataframe(csv_str) 165 | tag = result.columns[1] 166 | result['tag'] = tag 167 | result.rename(columns={tag: 'value'}, inplace=True) 168 | result.set_index(['date', 'tag'], drop=True, inplace=True, append=True) 169 | result.reset_index(level=0, inplace=True, drop=True) 170 | 171 | return parameters['identifier'], result 172 | 173 | 174 | class IntrinioEvents(object): 175 | """ 176 | Intrinio requests via events 177 | """ 178 | 179 | def __init__(self, listeners): 180 | self.listeners = listeners 181 | self.listeners += self.listener 182 | 183 | def listener(self, event): 184 | if event['type'] == 'intrinio_request': 185 | with sessions.Session() as sess: 186 | endpoint = event['endpoint'] if event['endpoint'].endswith('.csv') else event['endpoint'] + '.csv' 187 | if 'parameters' in event: 188 | result = get_csv(sess, endpoint=endpoint, **event['parameters']) 189 | else: 190 | result = get_csv(sess, endpoint=endpoint) 191 | 192 | if 'dataframe' in event: 193 | result = to_dataframe(result) 194 | 195 | self.listeners({'type': 'intrinio_request_result', 'data': result}) 196 | elif event['type'] == 'intrinio_historical_data': 197 | data = event['data'] if isinstance(event['data'], list) else event['data'] 198 | result = get_historical_data(data, 199 | threads=event['threads'] if 'threads' in event else 1, 200 | async=event['async'] if 'async' in event else False) 201 | 202 | self.listeners({'type': 'intrinio_historical_data_result', 'data': result}) 203 | -------------------------------------------------------------------------------- /atpy/data/intrinio/influxdb_cache.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import threading 4 | import typing 5 | 6 | import numpy as np 7 | from dateutil import tz 8 | from dateutil.parser import parse 9 | from dateutil.relativedelta import relativedelta 10 | from influxdb import InfluxDBClient, DataFrameClient 11 | 12 | from atpy.data.intrinio.api import get_historical_data 13 | 14 | 15 | class ClientFactory(object): 16 | def __init__(self, **kwargs): 17 | self.kwargs = kwargs 18 | 19 | def new_client(self): 20 | return InfluxDBClient(**self.kwargs) 21 | 22 | def new_df_client(self): 23 | return DataFrameClient(**self.kwargs) 24 | 25 | 26 | class InfluxDBCache(object): 27 | """ 28 | InfluxDB Intrinio cache using abstract data provider 29 | """ 30 | 31 | def __init__(self, client_factory: ClientFactory, listeners=None, time_delta_back: relativedelta = relativedelta(years=5)): 32 | self.client_factory = client_factory 33 | self.listeners = listeners 34 | self._time_delta_back = time_delta_back 35 | self._synchronized_symbols = set() 36 | self._lock = threading.RLock() 37 | 38 | def __enter__(self): 39 | self.client = self.client_factory.new_df_client() 40 | 41 | return self 42 | 43 | def __exit__(self, exception_type, exception_value, traceback): 44 
| self.client.close() 45 | 46 | @property 47 | def ranges(self): 48 | """ 49 | :return: list of latest times for each entry grouped by symbol and tag 50 | """ 51 | parse_time = lambda t: parse(t).replace(tzinfo=tz.gettz('UTC')) 52 | 53 | points = InfluxDBClient.query(self.client, "select FIRST(value), symbol, itag, time from intrinio_tags group by symbol, itag").get_points() 54 | firsts = {(entry['symbol'], entry['itag']): parse_time(entry['time']) for entry in points} 55 | 56 | points = InfluxDBClient.query(self.client, "select LAST(value), symbol, itag, time from intrinio_tags group by symbol, itag").get_points() 57 | lasts = {(entry['symbol'], entry['itag']): parse_time(entry['time']) for entry in points} 58 | 59 | result = {k: (firsts[k], lasts[k]) for k in firsts.keys() & lasts.keys()} 60 | 61 | return result 62 | 63 | def _request_noncache_data(self, filters: typing.List[dict], async=False): 64 | """ 65 | request filter data 66 | :param filters: list of dicts for data request 67 | :return: 68 | """ 69 | return get_historical_data(filters=filters, async=async) 70 | 71 | def update_to_latest(self, new_symbols: typing.Set[typing.Tuple] = None, skip_if_older_than: relativedelta = None): 72 | """ 73 | Update existing entries in the database to the most current values 74 | :param new_symbols: additional symbols to add {(symbol, interval_len, interval_type), ...}} 75 | :param skip_if_older_than: skip symbol update if the symbol is older than... 76 | :return: 77 | """ 78 | filters = list() 79 | 80 | new_symbols = set() if new_symbols is None else new_symbols 81 | 82 | if skip_if_older_than is not None: 83 | skip_if_older_than = (datetime.datetime.utcnow().replace(tzinfo=tz.gettz('UTC')) - skip_if_older_than).astimezone(tz.gettz('US/Eastern')) 84 | 85 | ranges = self.ranges 86 | for key, time in [(e[0], e[1][1]) for e in ranges.items()]: 87 | if key in new_symbols: 88 | new_symbols.remove(key) 89 | 90 | if skip_if_older_than is None or time > skip_if_older_than: 91 | filters.append({'identifier': key[0], 'item': key[1], 'start_date': time.date()}) 92 | 93 | start_date = datetime.datetime.utcnow().date() - self._time_delta_back 94 | for (symbol, tag) in new_symbols: 95 | filters.append({'identifier': symbol, 'item': tag, 'start_date': start_date.strftime('%Y-%m-%d')}) 96 | 97 | logging.getLogger(__name__).info("Updating " + str(len(filters)) + " total symbols and intervals; New symbols and intervals: " + str(len(new_symbols))) 98 | 99 | q = self._request_noncache_data(filters, async=True) 100 | 101 | def worker(): 102 | client = self.client_factory.new_df_client() 103 | 104 | try: 105 | for i, tupl in enumerate(iter(q.get, None)): 106 | if tupl is None: 107 | return 108 | 109 | s, to_cache = tupl 110 | 111 | if to_cache is not None and not to_cache.empty: 112 | to_cache['symbol'] = s 113 | to_cache.reset_index(level=['tag'], inplace=True) 114 | to_cache.rename(columns={'tag': 'itag'}, inplace=True) 115 | try: 116 | client.write_points(to_cache, 'intrinio_tags', protocol='line', tag_columns=['symbol', 'itag'], time_precision='s') 117 | except Exception as err: 118 | logging.getLogger(__name__).exception(err) 119 | 120 | if i > 0 and (i % 20 == 0 or i == len(filters)): 121 | logging.getLogger(__name__).info("Cached " + str(i) + " queries") 122 | finally: 123 | client.close() 124 | 125 | t = threading.Thread(target=worker) 126 | t.start() 127 | t.join() 128 | 129 | def request_data(self, symbols: typing.Union[set, str] = None, tags: typing.Union[set, str] = None, start_date: datetime.date = None, 
end_date: datetime.date = None): 130 | query = "SELECT * FROM intrinio_tags" 131 | 132 | where = list() 133 | if symbols is not None: 134 | if isinstance(symbols, set) and len(symbols) > 0: 135 | where.append("symbol =~ /{}/".format("|".join(['^' + s + '$' for s in symbols]))) 136 | elif isinstance(symbols, str) and len(symbols) > 0: 137 | where.append("symbol = '{}'".format(symbols)) 138 | 139 | if tags is not None: 140 | if isinstance(tags, set) and len(tags) > 0: 141 | where.append("itag =~ /{}/".format("|".join(['^' + s + '$' for s in tags]))) 142 | elif isinstance(symbols, str) and len(tags) > 0: 143 | where.append("itag = '{}'".format(tags)) 144 | 145 | if start_date is not None: 146 | start_date = datetime.datetime.combine(start_date, datetime.datetime.min.time()) 147 | where.append("time >= '{}'".format(start_date)) 148 | 149 | if end_date is not None: 150 | end_date = datetime.datetime.combine(end_date, datetime.datetime.min.time()) 151 | where.append("time <= '{}'".format(end_date)) 152 | 153 | if len(where) > 0: 154 | query += " WHERE " + " AND ".join(where) 155 | 156 | result = self.client.query(query, chunked=True) 157 | 158 | if len(result) > 0: 159 | result = result['intrinio_tags'] 160 | 161 | result.rename(columns={'itag': 'tag'}, inplace=True) 162 | result.set_index(['tag', 'symbol'], drop=True, inplace=True, append=True) 163 | result.index.rename('date', level=0, inplace=True) 164 | result = result.reorder_levels(['symbol', 'date', 'tag']) 165 | result.sort_index(inplace=True) 166 | 167 | if result['value'].dtype != np.float: 168 | result['value'] = result['value'].astype(np.float) 169 | else: 170 | result = None 171 | 172 | return result 173 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/data/iqfeed/__init__.py -------------------------------------------------------------------------------- /atpy/data/iqfeed/bar_util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def merge_snapshots(s1: pd.DataFrame, s2: pd.DataFrame) -> pd.DataFrame: 5 | """ 6 | merge two bar dataframes (with multiindex [symbol, timestamp] 7 | :param s1: DataFrame 8 | :param s2: DataFrame 9 | :return: DataFrame 10 | """ 11 | if s1.empty and not s2.empty: 12 | return s2 13 | elif s2.empty and not s1.empty: 14 | return s1 15 | elif s1.empty and s2.empty: 16 | return 17 | 18 | return pd.concat([s1, s2]).sort_index() 19 | 20 | 21 | def reindex_and_fill(df: pd.DataFrame, index) -> pd.DataFrame: 22 | """ 23 | reindex DataFrame using new index and fill the missing values 24 | :param df: DataFrame 25 | :param index: new index to use 26 | :return: 27 | """ 28 | 29 | df = df.reindex(index) 30 | df.drop(['symbol', 'timestamp'], axis=1, inplace=True) 31 | df.reset_index(inplace=True) 32 | df.set_index(index, inplace=True) 33 | 34 | for c in [c for c in ['volume', 'number_of_trades'] if c in df.columns]: 35 | df[c].fillna(0, inplace=True) 36 | 37 | if 'close' in df.columns: 38 | df['close'] = df.groupby(level=0)['close'].fillna(method='ffill') 39 | df['close'] = df.groupby(level=0)['close'].fillna(method='backfill') 40 | op = df['close'] 41 | 42 | for c in [c for c in ['open', 'high', 'low'] if c in df.columns]: 43 | df[c].fillna(op, inplace=True) 44 | 45 | df = 
df.groupby(level=0).fillna(method='ffill') 46 | 47 | df = df.groupby(level=0).fillna(method='backfill') 48 | 49 | return df 50 | 51 | 52 | def expand(df: pd.DataFrame, steps: int, max_length: int=None) -> pd.DataFrame: 53 | """ 54 | expand DataFrame with steps at the end 55 | :param df: DataFrame 56 | :param steps: number of steps to expand at the end 57 | :param max_length: if max_length reached, truncate from the beginning 58 | :return: 59 | """ 60 | 61 | if len(df.index.levels[1]) < 2: 62 | return df 63 | 64 | diff = df.index.levels[1][1] - df.index.levels[1][0] 65 | new_index = df.index.levels[1].append(pd.date_range(df.index.levels[1][len(df.index.levels[1]) - 1] + diff, periods=steps, freq=diff)) 66 | 67 | multi_index = pd.MultiIndex.from_product([df.index.levels[0], new_index], names=['symbol', 'timestamp']).sort_values() 68 | 69 | result = df.reindex(multi_index) 70 | 71 | if max_length is not None: 72 | result = df.groupby(level=0).tail(max_length) 73 | 74 | return result 75 | 76 | 77 | def synchronize_timestamps(df) -> pd.DataFrame: 78 | """ 79 | synchronize the timestamps for all symbols in a DataFrame. Fill the values automatically 80 | :param df: DataFrame 81 | :return: df 82 | """ 83 | multi_index = pd.MultiIndex.from_product([df.index.levels[0].unique(), df.index.levels[1].unique()], names=['symbol', 'timestamp']).sort_values() 84 | return reindex_and_fill(df, multi_index) 85 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/filters.py: -------------------------------------------------------------------------------- 1 | from abc import * 2 | from typing import NamedTuple 3 | 4 | 5 | class FilterProvider(metaclass=ABCMeta): 6 | """Base namedtuple filter provider generator/iterator interface""" 7 | 8 | @abstractmethod 9 | def __iter__(self): 10 | return 11 | 12 | @abstractmethod 13 | def __next__(self) -> NamedTuple: 14 | return 15 | 16 | 17 | class DefaultFilterProvider(FilterProvider): 18 | """Default filter provider, which contains a list of filters""" 19 | 20 | def __init__(self, repeat=True): 21 | self._filters = list() 22 | self._repeat = repeat 23 | 24 | def __iadd__(self, fn): 25 | self._filters.append(fn) 26 | return self 27 | 28 | def __isub__(self, fn): 29 | self._filters.remove(fn) 30 | return self 31 | 32 | def __iter__(self): 33 | self.__counter = 0 34 | return self 35 | 36 | def __next__(self) -> FilterProvider: 37 | if not self._repeat and self.__counter >= len(self._filters): 38 | raise StopIteration 39 | else: 40 | self.__counter += 1 41 | return self._filters[self.__counter % len(self._filters)] 42 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/iqfeed_influxdb_cache.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import functools 3 | import json 4 | import logging 5 | 6 | from dateutil import tz 7 | from influxdb import InfluxDBClient 8 | 9 | from atpy.data.cache.influxdb_cache import add_adjustments 10 | from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider, BarsInPeriodFilter, BarsDailyForDatesFilter 11 | 12 | 13 | def noncache_provider(history: IQFeedHistoryProvider): 14 | def _request_noncache_data(filters, q, h: IQFeedHistoryProvider): 15 | """ 16 | :return: request data from data provider (has to be UTC localized) 17 | """ 18 | new_filters = list() 19 | filters_copy = filters.copy() 20 | 21 | for f in filters_copy: 22 | if f.interval_type == 's': 23 | new_f = 
BarsInPeriodFilter(ticker=f.ticker, bgn_prd=f.bgn_prd.astimezone(tz.gettz('US/Eastern')) if f.bgn_prd is not None else None, end_prd=None, interval_len=f.interval_len, interval_type=f.interval_type) 24 | elif f.interval_type == 'd': 25 | new_f = BarsDailyForDatesFilter(ticker=f.ticker, bgn_dt=f.bgn_prd.date() if f.bgn_prd is not None else None, end_dt=None) 26 | 27 | filters[new_f] = f 28 | 29 | new_filters.append(new_f) 30 | 31 | h.request_data_by_filters(new_filters, q) 32 | 33 | return functools.partial(_request_noncache_data, h=history) 34 | 35 | 36 | def update_fundamentals(client: InfluxDBClient, fundamentals: list): 37 | points = list() 38 | for f in fundamentals: 39 | points.append( 40 | { 41 | "measurement": "iqfeed_fundamentals", 42 | "tags": { 43 | "symbol": f['symbol'], 44 | }, 45 | "time": datetime.datetime.combine(datetime.datetime.utcnow().date(), datetime.datetime.min.time()), 46 | "fields": { 47 | "data": json.dumps(f, default=lambda x: x.isoformat() if isinstance(x, datetime.datetime) else str(x)), 48 | } 49 | } 50 | ) 51 | 52 | try: 53 | InfluxDBClient.write_points(client, points, protocol='json', time_precision='s') 54 | except Exception as err: 55 | logging.getLogger(__name__).error(err) 56 | 57 | 58 | def update_splits_dividends(client: InfluxDBClient, fundamentals: list): 59 | points = list() 60 | for f in fundamentals: 61 | if f['split_factor_1_date'] is not None and f['split_factor_1'] is not None: 62 | points.append((f['split_factor_1_date'], f['symbol'], 'split', f['split_factor_1'])) 63 | 64 | if f['split_factor_2_date'] is not None and f['split_factor_2'] is not None: 65 | points.append((f['split_factor_2_date'], f['symbol'], 'split', f['split_factor_2'])) 66 | 67 | if f['ex-dividend_date'] is not None and f['dividend_amount'] is not None: 68 | points.append((f['ex-dividend_date'], f['symbol'], 'dividend', f['dividend_amount'])) 69 | 70 | add_adjustments(client=client, adjustments=points, provider='iqfeed') 71 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/iqfeed_influxdb_cache_requests.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | 4 | from influxdb import InfluxDBClient 5 | 6 | 7 | def get_cache_fundamentals(client: InfluxDBClient, symbol: typing.Union[list, str] = None): 8 | query = "SELECT * FROM iqfeed_fundamentals" 9 | if isinstance(symbol, list) and len(symbol) > 0: 10 | query += " WHERE symbol =~ /{}/".format("|".join(['^' + s + '$' for s in symbol])) 11 | elif isinstance(symbol, str) and len(symbol) > 0: 12 | query += " WHERE symbol = '{}'".format(symbol) 13 | 14 | result = {f['symbol']: {**json.loads(f['data']), **{'last_update': f['time']}} for f in list(InfluxDBClient.query(client, query, chunked=True).get_points())} 15 | 16 | return result[symbol] if isinstance(symbol, str) else result 17 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/iqfeed_news_provider.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import threading 3 | from datetime import timedelta 4 | from typing import List 5 | 6 | import pandas as pd 7 | from dateutil import tz 8 | 9 | import atpy.data.iqfeed.util as iqfeedutil 10 | import pyiqfeed as iq 11 | from atpy.data.iqfeed.filters import * 12 | 13 | 14 | class NewsFilter(NamedTuple): 15 | """ 16 | News filter parameters 17 | """ 18 | sources: List[str] 19 | symbols: List[str] 20 | date: 
datetime.date 21 | timeout: int 22 | limit: int 23 | 24 | 25 | NewsFilter.__new__.__defaults__ = (None, None, None, None, 100000) 26 | 27 | 28 | class DefaultNewsFilterProvider(DefaultFilterProvider): 29 | """Default news filter provider, which contains a list of filters""" 30 | 31 | def _default_filter(self): 32 | return NewsFilter() 33 | 34 | 35 | class IQFeedNewsProvider(object): 36 | """ 37 | IQFeed news provider (not streaming). See the unit test on how to use 38 | """ 39 | 40 | def __init__(self, attach_text=False, key_suffix=''): 41 | """ 42 | :param attach_text: attach news text (separate request for each news item) 43 | :param key_suffix: suffix in the output dictionary 44 | """ 45 | self.attach_text = attach_text 46 | self.conn = None 47 | self.key_suffix = key_suffix 48 | 49 | def __enter__(self): 50 | iqfeedutil.launch_service() 51 | 52 | self.conn = iq.NewsConn() 53 | self.conn.connect() 54 | self.cfg = self.conn.request_news_config() 55 | 56 | return self 57 | 58 | def __exit__(self, exception_type, exception_value, traceback): 59 | """Disconnect connection etc""" 60 | self.conn.disconnect() 61 | self.quote_conn = None 62 | 63 | def __del__(self): 64 | if self.conn is not None: 65 | self.conn.disconnect() 66 | self.cfg = None 67 | 68 | def __getattr__(self, name): 69 | if self.conn is not None: 70 | return getattr(self.conn, name) 71 | else: 72 | raise AttributeError 73 | 74 | def request_news(self, f: NewsFilter): 75 | _headlines = self.conn.request_news_headlines(sources=f.sources, symbols=f.symbols, date=f.date, limit=f.limit, timeout=f.timeout) 76 | headlines = [h._asdict() for h in _headlines] 77 | 78 | processed_data = None 79 | 80 | for h in headlines: 81 | if self.attach_text: 82 | h['text'] = self.conn.request_news_story(h['story_id']).story 83 | 84 | h['timestamp' + self.key_suffix] = (datetime.datetime.combine(h.pop('story_date').astype(datetime.date), datetime.datetime.min.time()) + timedelta(microseconds=h.pop('story_time'))).replace( 85 | tzinfo=tz.gettz('US/Eastern')).astimezone(tz.gettz('UTC')) 86 | 87 | if processed_data is None: 88 | processed_data = {f + self.key_suffix: list() for f in h.keys()} 89 | 90 | for key, value in h.items(): 91 | processed_data[key + self.key_suffix].append(value) 92 | 93 | result = pd.DataFrame(processed_data) 94 | 95 | result = result.set_index(['timestamp', 'story_id'], drop=False, append=False).iloc[::-1] 96 | 97 | return result 98 | 99 | 100 | class IQFeedNewsListener(IQFeedNewsProvider): 101 | """ 102 | IQFeed news listener (not streaming). 
See the unit test on how to use 103 | """ 104 | 105 | def __init__(self, listeners, attach_text=False, key_suffix='', filter_provider=DefaultNewsFilterProvider()): 106 | """ 107 | :param listeners: event listeners 108 | :param attach_text: attach news text (separate request for each news item) 109 | :param key_suffix: suffix in the output dictionary 110 | :param filter_provider: iterator for filters 111 | """ 112 | super().__init__(attach_text=attach_text, key_suffix=key_suffix) 113 | self.listeners = listeners 114 | self.conn = None 115 | self.filter_provider = filter_provider 116 | 117 | def __enter__(self): 118 | super().__enter__() 119 | iqfeedutil.launch_service() 120 | 121 | self.conn = iq.NewsConn() 122 | self.conn.connect() 123 | self.cfg = self.conn.request_news_config() 124 | 125 | self.is_running = True 126 | self.producer_thread = threading.Thread(target=self.produce, daemon=True) 127 | self.producer_thread.start() 128 | 129 | return self 130 | 131 | def __exit__(self, exception_type, exception_value, traceback): 132 | """Disconnect connection etc""" 133 | self.conn.disconnect() 134 | self.quote_conn = None 135 | self.is_running = False 136 | 137 | def __del__(self): 138 | if self.conn is not None: 139 | self.conn.disconnect() 140 | self.cfg = None 141 | 142 | def __getattr__(self, name): 143 | if self.conn is not None: 144 | return getattr(self.conn, name) 145 | else: 146 | raise AttributeError 147 | 148 | def produce(self): 149 | for f in self.filter_provider: 150 | result = super().request_news(f) 151 | 152 | self.listeners({'type': 'news_batch', 'data': result}) 153 | 154 | if not self.is_running: 155 | break 156 | 157 | def batch_provider(self): 158 | return iqfeedutil.IQFeedDataProvider(self.listeners, accept_event=lambda e: True if e['type'] == 'news_batch' else False) 159 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/iqfeed_postgres_cache.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import functools 3 | import json 4 | import typing 5 | 6 | import pandas as pd 7 | from dateutil import tz 8 | 9 | from atpy.data.cache.postgres_cache import insert_json 10 | from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider, BarsInPeriodFilter, BarsDailyForDatesFilter 11 | 12 | 13 | def noncache_provider(history: IQFeedHistoryProvider): 14 | def _request_noncache_data(filters: dict, q, h: IQFeedHistoryProvider): 15 | """ 16 | :return: request data from data provider (has to be UTC localized) 17 | """ 18 | new_filters = list() 19 | filters_copy = filters.copy() 20 | 21 | for f in filters_copy: 22 | if f.interval_type == 's': 23 | new_f = BarsInPeriodFilter(ticker=f.ticker, bgn_prd=f.bgn_prd.astimezone(tz.gettz('US/Eastern')) if f.bgn_prd is not None else None, end_prd=None, interval_len=f.interval_len, interval_type=f.interval_type) 24 | elif f.interval_type == 'd': 25 | new_f = BarsDailyForDatesFilter(ticker=f.ticker, bgn_dt=f.bgn_prd.date() if f.bgn_prd is not None else None, end_dt=None) 26 | 27 | filters[new_f] = f 28 | 29 | new_filters.append(new_f) 30 | 31 | h.request_data_by_filters(new_filters, q) 32 | 33 | return functools.partial(_request_noncache_data, h=history) 34 | 35 | 36 | def update_fundamentals(conn, fundamentals: dict, table_name: str = 'json_data'): 37 | to_store = list() 38 | for v in fundamentals.values(): 39 | v['provider'] = 'iqfeed' 40 | v['type'] = 'fundamentals' 41 | 42 | to_store.append(json.dumps(v, default=lambda x: 
x.isoformat() if isinstance(x, datetime.datetime) else str(x))) 43 | 44 | insert_json(conn=conn, table_name=table_name, data='\n'.join(to_store)) 45 | 46 | 47 | def request_fundamentals(conn, symbol: typing.Union[list, str], table_name: str = 'json_data'): 48 | where = " WHERE json_data ->> 'type' = 'fundamentals' AND json_data ->> 'provider' = 'iqfeed'" 49 | params = list() 50 | 51 | if isinstance(symbol, list): 52 | where += " AND json_data ->> 'symbol' IN (%s)" % ','.join(['%s'] * len(symbol)) 53 | params += symbol 54 | elif isinstance(symbol, str): 55 | where += " AND json_data ->> 'symbol' = %s" 56 | params.append(symbol) 57 | 58 | cursor = conn.cursor() 59 | cursor.execute("select * from {0} {1}".format(table_name, where), params) 60 | records = cursor.fetchall() 61 | 62 | if len(records) > 0: 63 | df = pd.DataFrame([x[0] for x in records]).drop(['type', 'provider'], axis=1) 64 | 65 | return df 66 | -------------------------------------------------------------------------------- /atpy/data/iqfeed/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import queue 4 | import tempfile 5 | import typing 6 | import zipfile 7 | from collections import OrderedDict, deque 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import requests 12 | 13 | import pyiqfeed as iq 14 | 15 | 16 | def dtn_credentials(): 17 | return os.environ['DTN_PRODUCT_ID'], os.environ['DTN_LOGIN'], os.environ['DTN_PASSWORD'], 'Debugging' 18 | 19 | 20 | def launch_service(): 21 | """Check if IQFeed.exe is running and start if not""" 22 | dtn_product_id, dtn_login, dtn_password, version = dtn_credentials() 23 | 24 | svc = iq.FeedService(product=dtn_product_id, 25 | version=version, 26 | login=dtn_login, 27 | password=dtn_password) 28 | 29 | headless = bool(os.environ["DTN_HEADLESS"]) if "DTN_HEADLESS" in os.environ else "DISPLAY" not in os.environ 30 | logging.getLogger(__name__).info("Launching IQFeed service in " + ("headless mode" if headless else "non headless mode")) 31 | 32 | svc.launch(headless=headless) 33 | 34 | 35 | def iqfeed_to_df(data: typing.Collection): 36 | """ 37 | Create minibatch-type data frame based on the pyiqfeed data format 38 | :param data: data list 39 | :return: 40 | """ 41 | result = None 42 | 43 | for i, datum in enumerate(data): 44 | datum = datum[0] if len(datum) == 1 else datum 45 | 46 | if result is None: 47 | result = OrderedDict( 48 | [(n.replace(" ", "_").lower(), 49 | np.empty((len(data),), d.dtype if str(d.dtype) not in ('|S4', '|S2', '|S3') else object)) 50 | for n, d in zip(datum.dtype.names, datum)]) 51 | 52 | for j, f in enumerate(datum.dtype.names): 53 | d = datum[j] 54 | if isinstance(d, bytes): 55 | d = d.decode('ascii') 56 | 57 | result[f.replace(" ", "_").lower()][i] = d 58 | 59 | return pd.DataFrame(result) 60 | 61 | 62 | def iqfeed_to_deque(data: typing.Iterable, maxlen: int = None): 63 | """ 64 | Create minibatch-type dict of deques based on the pyiqfeed data format 65 | :param data: data list 66 | :param maxlen: maximum deque length 67 | :return: 68 | """ 69 | result = None 70 | 71 | for i, datum in enumerate(data): 72 | datum = datum[0] if len(datum) == 1 else datum 73 | 74 | if result is None: 75 | result = OrderedDict( 76 | [(n.replace(" ", "_").lower(), 77 | deque(maxlen=maxlen)) 78 | for n, d in zip(datum.dtype.names, datum)]) 79 | 80 | for j, f in enumerate(datum.dtype.names): 81 | d = datum[j] 82 | if isinstance(datum[j], bytes): 83 | d = datum[j].decode('ascii') 84 | 85 | 
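            # store the (ASCII-decoded) field value in its per-column deque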
result[f.replace(" ", "_").lower()].append(d) 86 | 87 | return result 88 | 89 | 90 | def get_last_value(data: dict) -> dict: 91 | """ 92 | If the data is a result is a time-serires (dict of collections), return the last one 93 | :param data: data list 94 | :return: 95 | """ 96 | return OrderedDict([(k, v[-1] if isinstance(v, typing.Collection) else v) for k, v in data.items()]) 97 | 98 | 99 | def iqfeed_to_dict(data): 100 | """ 101 | Turn one iqfeed data item to dict 102 | :param data: data list 103 | :return: 104 | """ 105 | data = data[0] if len(data) == 1 else data 106 | 107 | result = OrderedDict([(n.replace(" ", "_").lower(), d) for n, d in zip(data.dtype.names, data)]) 108 | 109 | for k, v in result.items(): 110 | if isinstance(v, bytes): 111 | result[k] = v.decode('ascii') 112 | elif pd.isnull(v): 113 | result[k] = None 114 | 115 | return result 116 | 117 | 118 | def get_symbols(symbols_file: str = None, flt: dict = None): 119 | """ 120 | Get available symbols and information about them 121 | 122 | :param symbols_file: location of the symbols file (if None, the file is downloaded) 123 | :param flt: filter for the symbols 124 | """ 125 | 126 | with tempfile.TemporaryDirectory() as td: 127 | if symbols_file is not None: 128 | logging.getLogger(__name__).info("Symbols: " + symbols_file) 129 | zipfile.ZipFile(symbols_file).extractall(td) 130 | else: 131 | with tempfile.TemporaryFile() as tf: 132 | logging.getLogger(__name__).info("Downloading symbol list... ") 133 | tf.write(requests.get('http://www.dtniq.com/product/mktsymbols_v2.zip', allow_redirects=True).content) 134 | zipfile.ZipFile(tf).extractall(td) 135 | 136 | with open(os.path.join(td, 'mktsymbols_v2.txt')) as f: 137 | content = f.readlines() 138 | 139 | logging.getLogger(__name__).debug("Filtering companies...") 140 | 141 | flt = {'SECURITY TYPE': 'EQUITY', 'EXCHANGE': {'NYSE', 'NASDAQ'}} if flt is None else flt 142 | 143 | cols = content[0].split('\t') 144 | positions = {cols.index(k): v if isinstance(v, set) else {v} for k, v in flt.items()} 145 | 146 | result = dict() 147 | for c in content[1:]: 148 | split = c.split('\t') 149 | if all([split[col] in positions[col] for col in positions]): 150 | result[split[0]] = {cols[i]: split[i] for i in range(1, len(cols))} 151 | 152 | logging.getLogger(__name__).debug("Done") 153 | 154 | return result 155 | 156 | 157 | class IQFeedDataProvider(object): 158 | """Streaming data provider generator/iterator interface""" 159 | 160 | def __init__(self, listeners, accept_event): 161 | self._queue = queue.Queue() 162 | self.listeners = listeners 163 | self.accept_event = accept_event 164 | 165 | def _populate_queue(self, event): 166 | if self.accept_event(event): 167 | self._queue.put(event['data']) 168 | 169 | def __iter__(self): 170 | return self 171 | 172 | def __next__(self) -> map: 173 | return self._queue.get() 174 | 175 | def __enter__(self): 176 | self.listeners += self._populate_queue 177 | 178 | return self 179 | 180 | def __exit__(self, exception_type, exception_value, traceback): 181 | self.listeners -= self._populate_queue 182 | -------------------------------------------------------------------------------- /atpy/data/latest_data_snapshot.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import typing 3 | from collections import OrderedDict 4 | 5 | from atpy.data.iqfeed.util import * 6 | 7 | 8 | class LatestDataSnapshot(object): 9 | """Listen and maintain a dataframe of the latest data events""" 10 | 11 | def 
__init__(self, listeners, event: typing.Union[str, typing.Set[str]], depth=0, fire_update: bool = False): 12 | """ 13 | :param listeners: listeners 14 | :param event: event or list of events to accept 15 | :param depth: keep depth of the snapshot 16 | :param fire_update: whether to fire an event in case of snapshot update 17 | """ 18 | 19 | self.listeners = listeners 20 | self.listeners += self.on_event 21 | 22 | self.depth = depth 23 | self.event = {event} if isinstance(event, str) else event 24 | self._fire_update = fire_update 25 | self._snapshot = OrderedDict() 26 | 27 | self._rlock = threading.RLock() 28 | 29 | pd.set_option('mode.chained_assignment', 'warn') 30 | 31 | def update_snapshot(self, data): 32 | with self._rlock: 33 | timestamp_cols = [c for c in data if pd.core.dtypes.common.is_datetimelike(data[c])] 34 | if timestamp_cols: 35 | ind = data[timestamp_cols[0]].unique() 36 | else: 37 | if isinstance(data.index, pd.DatetimeIndex): 38 | ind = data.index 39 | elif isinstance(data.index, pd.MultiIndex): 40 | ind = data.index.levels[0] 41 | else: 42 | ind = None 43 | 44 | if not isinstance(ind, pd.DatetimeIndex): 45 | raise Exception("Only first level DateTimeIndex is supported") 46 | 47 | for i in ind[-self.depth:]: 48 | self._snapshot[i] = pd.concat([self._snapshot[i], data.loc[i]]) if i in self._snapshot else data.loc[i] 49 | 50 | while len(self._snapshot) > self.depth: 51 | self._snapshot.popitem() 52 | 53 | def on_event(self, event): 54 | if event['type'] in self.event: 55 | self.update_snapshot(event['data']) 56 | if self._fire_update: 57 | self.listeners({'type': event['type'] + '_snapshot', 'data': pd.concat(self._snapshot), 'new_data': pd.concat(self._snapshot)}) 58 | elif event['type'] == 'request_latest': 59 | self.listeners({'type': 'snapshot', 'data': pd.concat(self._snapshot)}) 60 | -------------------------------------------------------------------------------- /atpy/data/quandl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/data/quandl/__init__.py -------------------------------------------------------------------------------- /atpy/data/quandl/api.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import queue 5 | import tempfile 6 | import threading 7 | import typing 8 | import zipfile 9 | from enum import Enum 10 | from multiprocessing.pool import ThreadPool 11 | 12 | import pandas as pd 13 | import quandl 14 | 15 | 16 | def get_time_series(filters: typing.List[dict], threads=1, async=False, processor: typing.Callable = None): 17 | """ 18 | Get async data for a list of filters. Works only for the historical API 19 | :param filters: a list of filters 20 | :param threads: number of threads for data retrieval 21 | :param async: if True, return queue. Otherwise, wait for the results 22 | :param processor: process each result 23 | :return Queue or pd.DataFrame with identifier, date set as multi index 24 | """ 25 | 26 | return __get_data(filters=filters, api_type=__APIType.TIME_SERIES, threads=threads, async=async, processor=processor) 27 | 28 | 29 | def get_table(filters: typing.List[dict], threads=1, async=False, processor: typing.Callable = None): 30 | """ 31 | Get async data for a list of filters. 
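    Each filter is a dict of keyword arguments forwarded to quandl.get_table and must include a
    'datatable_code' entry. Hypothetical example:

        get_table([{'datatable_code': 'WIKI/PRICES', 'ticker': 'AAPL'}], threads=2)
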
Works only for the historical API 32 | :param filters: a list of filters 33 | :param threads: number of threads for data retrieval 34 | :param async: if True, return queue. Otherwise, wait for the results 35 | :param processor: process each result 36 | :return Queue or pd.DataFrame with identifier, date set as multi index 37 | """ 38 | 39 | return __get_data(filters=filters, api_type=__APIType.TABLES, threads=threads, async=async, processor=processor) 40 | 41 | 42 | class __APIType(Enum): 43 | TIME_SERIES = 1 44 | TABLES = 2 45 | 46 | 47 | def __get_data(filters: typing.List[dict], api_type: __APIType, threads=1, async=False, processor: typing.Callable = None): 48 | """ 49 | Get async data for a list of filters using the tables or time series api 50 | :param filters: a list of filters 51 | :param api_type: whether to use time series or tables 52 | :param threads: number of threads for data retrieval 53 | :param async: if True, return queue. Otherwise, wait for the results 54 | :param processor: process each result 55 | :return Queue or pd.DataFrame with identifier, date set as multi index 56 | """ 57 | api_k = os.environ['QUANDL_API_KEY'] if 'QUANDL_API_KEY' in os.environ else None 58 | q = queue.Queue(100) 59 | global_counter = {'c': 0} 60 | lock = threading.Lock() 61 | no_data = set() 62 | 63 | def mp_worker(f): 64 | try: 65 | data = None 66 | if api_type == __APIType.TIME_SERIES: 67 | data = quandl.get(**f, paginate=True, api_key=api_k) 68 | if data is not None: 69 | data = data.tz_localize('UTC', copy=False) 70 | q.put((f['dataset'], processor(data, **f) if processor is not None else data)) 71 | elif api_type == __APIType.TABLES: 72 | data = quandl.get_table(**f, paginate=True, api_key=api_k) 73 | if data is not None: 74 | q.put((f['datatable_code'], processor(data, **f) if processor is not None else data)) 75 | 76 | except Exception as err: 77 | data = None 78 | logging.getLogger(__name__).exception(err) 79 | 80 | if data is None: 81 | no_data.add(f) 82 | 83 | with lock: 84 | global_counter['c'] += 1 85 | cnt = global_counter['c'] 86 | if cnt == len(filters): 87 | q.put(None) 88 | 89 | if cnt % 20 == 0 or cnt == len(filters): 90 | logging.getLogger(__name__).info("Loaded " + str(cnt) + " queries") 91 | if len(no_data) > 0: 92 | no_data_list = list(no_data) 93 | no_data_list.sort() 94 | logging.getLogger(__name__).info("No data found for " + str(len(no_data_list)) + " datasets: " + str(no_data_list)) 95 | no_data.clear() 96 | 97 | if threads > 1 and len(filters) > 1: 98 | pool = ThreadPool(threads) 99 | pool.map(mp_worker, (f for f in filters)) 100 | pool.close() 101 | else: 102 | for f in filters: 103 | mp_worker(f) 104 | 105 | if not async: 106 | result = dict() 107 | while True: 108 | job = q.get() 109 | if job is None: 110 | break 111 | 112 | if job[0] in result: 113 | current = result[job[0]] 114 | if isinstance(current, list): 115 | current.append(job[1]) 116 | else: 117 | result[job[0]] = [result[job[0]], job[1]] 118 | else: 119 | result[job[0]] = job[1] 120 | 121 | return result 122 | else: 123 | return q 124 | 125 | 126 | def bulkdownload(dataset: str, chunksize=None): 127 | with tempfile.TemporaryDirectory() as td: 128 | filename = os.path.join(td, dataset + '.zip') 129 | logging.getLogger(__name__).info("Downloading dataset " + dataset + " to " + filename) 130 | quandl.bulkdownload(dataset, filename=filename, api_key=os.environ['QUANDL_API_KEY'] if 'QUANDL_API_KEY' in os.environ else None) 131 | zipfile.ZipFile(filename).extractall(td) 132 | 133 | 
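        # the extracted archive holds a single CSV; it is read back with pandas (in chunks when chunksize is set)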
logging.getLogger(__name__).info("Done... Start yielding dataframes") 134 | 135 | return pd.read_csv(glob.glob(os.path.join(td, '*.csv'))[0], header=None, chunksize=chunksize, parse_dates=[1]) 136 | 137 | 138 | def get_sf1(filters: typing.List[dict], threads=1, async=False): 139 | """ 140 | return core us fundamental data 141 | :param filters: list of filters 142 | :param threads: number of request threads 143 | :param async: wait for the result or return a queue 144 | :return: 145 | """ 146 | 147 | def _sf1_processor(df, dataset): 148 | df.rename(columns={'Value': 'value'}, inplace=True) 149 | df.index.rename('date', inplace=True) 150 | df = df.tz_localize('UTC', copy=False) 151 | df['symbol'], df['indicator'], df['dimension'] = dataset.split('/')[1].split('_') 152 | df.set_index(['symbol', 'indicator', 'dimension'], drop=True, inplace=True, append=True) 153 | 154 | return df 155 | 156 | result = get_time_series(filters, 157 | threads=threads, 158 | async=async, 159 | processor=_sf1_processor) 160 | 161 | if not async and isinstance(result, list): 162 | result = pd.concat(result) 163 | result.sort_index(inplace=True, ascending=True) 164 | 165 | return result 166 | 167 | 168 | def bulkdownload_sf0(): 169 | df = bulkdownload(dataset='SF0', chunksize=None) 170 | sid = df[0] 171 | df.drop(0, axis=1, inplace=True) 172 | df = pd.concat([df, sid.str.split('_', expand=True)], axis=1, copy=False) 173 | df.columns = ['date', 'value', 'symbol', 'indicator', 'dimension'] 174 | df.set_index(['date', 'symbol', 'indicator', 'dimension'], drop=True, inplace=True, append=False) 175 | 176 | return df 177 | 178 | 179 | class QuandlEvents(object): 180 | """ 181 | Quandl requests via events 182 | """ 183 | 184 | def __init__(self, listeners): 185 | self.listeners = listeners 186 | self.listeners += self.listener 187 | 188 | def listener(self, event): 189 | if event['type'] == 'quandl_timeseries_request': 190 | result = get_time_series(event['data'] if isinstance(event['data'], list) else event['data'], 191 | threads=event['threads'] if 'threads' in event else 1, 192 | async=event['async'] if 'async' in event else False) 193 | 194 | self.listeners({'type': 'quandl_timeseries_result', 'data': result}) 195 | elif event['type'] == 'quandl_table_request': 196 | result = get_table(event['data'] if isinstance(event['data'], list) else event['data'], 197 | threads=event['threads'] if 'threads' in event else 1, 198 | async=event['async'] if 'async' in event else False) 199 | 200 | self.listeners({'type': 'quandl_table_result', 'data': result}) 201 | -------------------------------------------------------------------------------- /atpy/data/quandl/influxdb_cache.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import typing 4 | 5 | import pandas as pd 6 | from influxdb import DataFrameClient 7 | 8 | from atpy.data.quandl.api import bulkdownload 9 | 10 | 11 | class InfluxDBCache(object): 12 | """ 13 | InfluxDB quandl cache using abstract data provider 14 | """ 15 | 16 | def __init__(self, client: DataFrameClient, listeners=None): 17 | self.client = client 18 | self.listeners = listeners 19 | 20 | def __enter__(self): 21 | logging.basicConfig(level=logging.INFO) 22 | return self 23 | 24 | def __exit__(self, exception_type, exception_value, traceback): 25 | self.client.close() 26 | 27 | def add_dataset_to_cache(self, dataset: str): 28 | self.add_to_cache(measurement='quandl_' + dataset, dfs=bulkdownload(dataset=dataset, chunksize=100000)) 29 | 
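    # writes each dataframe chunk to the 'quandl_' + measurement InfluxDB measurement (line protocol, second precision),
    # logging progress every five chunks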
30 | def add_to_cache(self, measurement: str, dfs: typing.Iterator[pd.DataFrame], tag_columns: list=None): 31 | for i, df in enumerate(dfs): 32 | self.client.write_points(df, 'quandl_' + measurement, tag_columns=tag_columns, protocol='line', time_precision='s') 33 | 34 | if i > 0 and i % 5 == 0: 35 | logging.getLogger(__name__).info("Cached " + str(i) + " queries") 36 | 37 | if i > 0 and i % 5 != 0: 38 | logging.getLogger(__name__).info("Cached " + str(i) + " queries") 39 | 40 | def request_data(self, dataset, tags: dict = None, start_date: datetime.date = None, end_date: datetime.date = None): 41 | query = "SELECT * FROM quandl_" + dataset 42 | 43 | where = list() 44 | if tags is not None: 45 | for t, v in tags.items(): 46 | if isinstance(v, set) and len(v) > 0: 47 | where.append(t + " =~ /{}/".format("|".join(['^' + s + '$' for s in v]))) 48 | elif isinstance(v, str) and len(v) > 0: 49 | where.append(t + " = '{}'".format(v)) 50 | 51 | if start_date is not None: 52 | start_date = datetime.datetime.combine(start_date, datetime.datetime.min.time()) 53 | where.append("time >= '{}'".format(start_date)) 54 | 55 | if end_date is not None: 56 | end_date = datetime.datetime.combine(end_date, datetime.datetime.min.time()) 57 | where.append("time <= '{}'".format(end_date)) 58 | 59 | if len(where) > 0: 60 | query += " WHERE " + " AND ".join(where) 61 | 62 | result = self.client.query(query, chunked=True) 63 | 64 | if len(result) > 0: 65 | result = result["quandl_" + dataset] 66 | result.index.rename('date', inplace=True) 67 | else: 68 | result = None 69 | 70 | return result 71 | -------------------------------------------------------------------------------- /atpy/data/quandl/postgres_cache.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import typing 3 | 4 | import pandas as pd 5 | import psycopg2 6 | from dateutil.relativedelta import relativedelta 7 | 8 | from atpy.data.cache.postgres_cache import insert_df 9 | from atpy.data.quandl.api import bulkdownload_sf0 10 | from atpy.data.ts_util import slice_periods 11 | 12 | 13 | create_sf = \ 14 | """ 15 | -- Table: public.{0} 16 | 17 | -- DROP TABLE public.{0}; 18 | 19 | CREATE TABLE public.{0} 20 | ( 21 | date timestamp(6) without time zone NOT NULL, 22 | symbol character varying COLLATE pg_catalog."default" NOT NULL, 23 | indicator character varying COLLATE pg_catalog."default" NOT NULL, 24 | dimension character varying COLLATE pg_catalog."default" NOT NULL, 25 | value double precision 26 | ) 27 | WITH ( 28 | OIDS = FALSE 29 | ) 30 | TABLESPACE pg_default; 31 | 32 | ALTER TABLE public.{0} 33 | OWNER to postgres; 34 | 35 | """ 36 | 37 | create_sf_indices = \ 38 | """ 39 | -- Index: ix_{0}_date 40 | 41 | -- DROP INDEX public.ix_{0}_date; 42 | 43 | CREATE INDEX ix_{0}_date 44 | ON public.{0} USING btree 45 | (date) 46 | TABLESPACE pg_default; 47 | 48 | ALTER TABLE public.{0} 49 | CLUSTER ON ix_{0}_date; 50 | 51 | -- Index: ix_{0}_dimension 52 | 53 | -- DROP INDEX public.ix_{0}_dimension; 54 | 55 | CREATE INDEX ix_{0}_dimension 56 | ON public.{0} USING btree 57 | (dimension COLLATE pg_catalog."default") 58 | TABLESPACE pg_default; 59 | 60 | -- Index: ix_{0}_indicator 61 | 62 | -- DROP INDEX public.ix_{0}_indicator; 63 | 64 | CREATE INDEX ix_{0}_indicator 65 | ON public.{0} USING btree 66 | (indicator COLLATE pg_catalog."default") 67 | TABLESPACE pg_default; 68 | 69 | -- Index: ix_{0}_symbol 70 | 71 | -- DROP INDEX public.ix_{0}_symbol; 72 | 73 | CREATE INDEX ix_{0}_symbol 74 | ON public.{0} 
USING btree 75 | (symbol COLLATE pg_catalog."default") 76 | TABLESPACE pg_default; 77 | """ 78 | 79 | 80 | def bulkinsert_SF0(url: str, table_name: str = 'quandl_sf0'): 81 | con = psycopg2.connect(url) 82 | con.autocommit = True 83 | cur = con.cursor() 84 | 85 | cur.execute("DROP TABLE IF EXISTS {0};".format(table_name)) 86 | 87 | cur.execute(create_sf.format(table_name)) 88 | 89 | data = bulkdownload_sf0() 90 | con = psycopg2.connect(url) 91 | con.autocommit = True 92 | insert_df(con, table_name, data) 93 | 94 | cur.execute(create_sf_indices.format(table_name)) 95 | 96 | 97 | def request_sf(conn, symbol: typing.Union[list, str] = None, bgn_prd: datetime.datetime = None, end_prd: datetime.datetime = None, table_name: str = 'quandl_SF0', selection='*'): 98 | """ 99 | Request bar data 100 | :param conn: connection 101 | :param table_name: table name 102 | :param symbol: symbol or a list of symbols 103 | :param bgn_prd: start period (including) 104 | :param end_prd: end period (excluding) 105 | :param selection: what to select 106 | :return: dataframe 107 | """ 108 | where = " WHERE 1=1" 109 | params = list() 110 | 111 | if isinstance(symbol, list): 112 | where += " AND symbol IN (%s)" % ','.join(['%s'] * len(symbol)) 113 | params += symbol 114 | elif isinstance(symbol, str): 115 | where += " AND symbol = %s" 116 | params.append(symbol) 117 | 118 | if bgn_prd is not None: 119 | where += " AND date >= %s" 120 | params.append(str(bgn_prd)) 121 | 122 | if end_prd is not None: 123 | where += " AND date <= %s" 124 | params.append(str(end_prd)) 125 | 126 | df = pd.read_sql("SELECT " + selection + " FROM " + table_name + where + " ORDER BY date, symbol", con=conn, index_col=['date', 'symbol', 'indicator', 'dimension'], params=params) 127 | 128 | if not df.empty: 129 | df = df.tz_localize('UTC', level='date', copy=False) 130 | 131 | return df 132 | 133 | 134 | class SFInPeriodProvider(object): 135 | """ 136 | SF (0 or 1) dataset in period provider 137 | """ 138 | 139 | def __init__(self, conn, bgn_prd: datetime.datetime, delta: relativedelta, symbol: typing.Union[list, str] = None, ascend: bool = True, table_name: str = 'quandl_SF0', overlap: relativedelta = None): 140 | self._periods = slice_periods(bgn_prd=bgn_prd, delta=delta, ascend=ascend, overlap=overlap) 141 | 142 | self.conn = conn 143 | self.symbol = symbol 144 | self.ascending = ascend 145 | self.table_name = table_name 146 | 147 | def __iter__(self): 148 | self._deltas = -1 149 | return self 150 | 151 | def __next__(self): 152 | self._deltas += 1 153 | 154 | if self._deltas < len(self._periods): 155 | result = request_sf(conn=self.conn, symbol=self.symbol, bgn_prd=self._periods[self._deltas][0], end_prd=self._periods[self._deltas][1], table_name=self.table_name) 156 | if result.empty: 157 | raise StopIteration 158 | 159 | return result 160 | else: 161 | raise StopIteration 162 | -------------------------------------------------------------------------------- /atpy/data/splits_dividends.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def adjust_df(data: pd.DataFrame, adjustments: pd.DataFrame): 8 | """ 9 | IMPORTANT !!! This method supports MultiIndex dataframes 10 | :param data: dataframe with data. 11 | :param adjustments: list of adjustments in the form of [(date, split_factor/dividend_amount, 'split'/'dividend'), ...] 
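        (in this implementation the adjustments arrive as a DataFrame whose MultiIndex carries the date, symbol
        and 'split'/'dividend' type in its first three levels, with the factor/amount in the first column)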
12 | :return adjusted data 13 | """ 14 | if not data.empty and not adjustments.empty: 15 | idx = pd.IndexSlice 16 | start = data.iloc[0].name[0] if isinstance(data.iloc[0].name, tuple) else data.iloc[0].name 17 | end = data.iloc[-1].name[0] if isinstance(data.iloc[0].name, tuple) else data.iloc[-1].name 18 | 19 | adjustments = adjustments.loc[idx[start:end, list(adjustments.index.levels[1]), :, :], :].sort_index(ascending=False) 20 | 21 | if isinstance(data.index, pd.MultiIndex): 22 | for (_, row) in adjustments.iterrows(): 23 | if row.name[2] == 'split': 24 | adjust_split_multiindex(data=data, symbol=row.name[1], split_date=row.name[0], split_factor=row[0]) 25 | elif row.name[2] == 'dividend': 26 | adjust_dividend_multiindex(data=data, symbol=row.name[1], dividend_date=row.name[0], dividend_amount=row[0]) 27 | else: 28 | for (_, row) in adjustments.iterrows(): 29 | if row.name[2] == 'split': 30 | adjust_split(data=data, split_date=row.name[0], split_factor=row[0]) 31 | elif row.name[2] == 'dividend': 32 | adjust_dividend(data=data, dividend_date=row.name[0], dividend_amount=row[0]) 33 | 34 | return data 35 | 36 | 37 | def adjust(data, adjustments: pd.DataFrame): 38 | """ 39 | IMPORTANT !!! This method supports single index df 40 | :param data: dataframe with data. 41 | :param adjustments: list of adjustments in the form of [(date, split_factor/dividend_amount, 'split'/'dividend'), ...] 42 | :return adjusted data 43 | """ 44 | adjustments.sort(key=lambda x: x[0], reverse=True) 45 | 46 | for (_, row) in adjustments.iterrows(): 47 | if row.name[2] == 'split': 48 | adjust_split(data=data, split_date=row.name[0], split_factor=row[0]) 49 | elif row.name[2] == 'dividend': 50 | adjust_dividend(data=data, dividend_date=row.name[0], dividend_amount=row[0]) 51 | 52 | return data 53 | 54 | 55 | def adjust_dividend(data, dividend_amount: float, dividend_date: datetime.date): 56 | if isinstance(data, pd.DataFrame): 57 | if len(data) > 0: 58 | dividend_date = datetime.datetime.combine(dividend_date, datetime.datetime.min.time()).replace(tzinfo=data.iloc[0]['timestamp'].tz) 59 | 60 | if dividend_date > data.iloc[0]['timestamp']: 61 | for c in [c for c in ['close', 'high', 'open', 'low', 'ask', 'bid', 'last'] if c in data.columns]: 62 | data[c] -= dividend_amount 63 | elif dividend_date > data.iloc[-1]['timestamp']: 64 | for c in [c for c in ['close', 'high', 'open', 'low', 'ask', 'bid', 'last'] if c in data.columns]: 65 | data.loc[data['timestamp'] < dividend_date, c] -= dividend_amount 66 | elif dividend_date > data['timestamp']: 67 | for c in [c for c in {'close', 'high', 'open', 'low', 'ask', 'bid', 'last'} if c in data.keys()]: 68 | data[c] -= dividend_amount 69 | 70 | 71 | def adjust_split(data, split_factor: float, split_date: datetime.date): 72 | if split_factor > 0: 73 | if isinstance(data, pd.DataFrame): 74 | if len(data) > 0: 75 | split_date = datetime.datetime.combine(split_date, datetime.datetime.min.time()).replace(tzinfo=data.iloc[0]['timestamp'].tz) 76 | 77 | if split_date > data.iloc[-1]['timestamp']: 78 | for c in [c for c in ['volume', 'total_volume', 'last_size'] if c in data.columns]: 79 | data[c] = (data[c] * (1 / split_factor)).astype(np.uint64) 80 | 81 | for c in [c for c in ['close', 'high', 'open', 'low', 'ask', 'bid', 'last'] if c in data.columns]: 82 | data[c] *= split_factor 83 | elif split_date > data.iloc[0]['timestamp']: 84 | for c in [c for c in ['volume', 'total_volume', 'last_size'] if c in data.columns]: 85 | data.loc[data['timestamp'] < split_date, c] *= (1 / 
split_factor) 86 | data[c] = data[c].astype(np.uint64) 87 | 88 | for c in [c for c in ['close', 'high', 'open', 'low', 'ask', 'bid', 'last'] if c in data.columns]: 89 | data.loc[data['timestamp'] < split_date, c] *= split_factor 90 | elif split_date > data['timestamp']: 91 | for c in [c for c in {'close', 'high', 'open', 'low', 'volume', 'total_volume', 'ask', 'bid', 'last', 'last_size'} if c in data.keys()]: 92 | if c in ('volume', 'total_volume', 'last_size'): 93 | data[c] = int(data[c] * (1 / split_factor)) 94 | else: 95 | data[c] *= split_factor 96 | 97 | 98 | def adjust_dividend_multiindex(data: pd.DataFrame, symbol: str, dividend_amount: float, dividend_date: datetime.date): 99 | if len(data) > 0 and symbol in data.index.levels[1]: 100 | dividend_date = datetime.datetime.combine(dividend_date, datetime.datetime.min.time()).replace(tzinfo=data.iloc[0].name[0].tz) 101 | 102 | idx = pd.IndexSlice 103 | 104 | if data.loc[idx[:, symbol], :].iloc[0].name[0] <= dividend_date <= data.loc[idx[:, symbol], :].iloc[-1].name[0]: 105 | for c in [c for c in ['close', 'high', 'open', 'low', 'ask', 'bid', 'last'] if c in data.columns]: 106 | data.loc[idx[:dividend_date, symbol], c] -= dividend_amount 107 | 108 | 109 | def adjust_split_multiindex(data: pd.DataFrame, symbol: str, split_factor: float, split_date: datetime.date): 110 | if split_factor > 0 and len(data) > 0 and symbol in data.index.levels[1]: 111 | split_date = datetime.datetime.combine(split_date, datetime.datetime.min.time()).replace(tzinfo=data.iloc[0].name[0].tz) 112 | 113 | idx = pd.IndexSlice 114 | 115 | if data.loc[idx[:, symbol], :].iloc[0].name[0] <= split_date <= data.loc[idx[:, symbol], :].iloc[-1].name[0]: 116 | for c in [c for c in ['volume', 'total_volume', 'last_size'] if c in data.columns]: 117 | data.loc[idx[:split_date, symbol], c] *= (1 / split_factor) 118 | data.loc[idx[:split_date, symbol], c] = data[c].astype(np.uint64) 119 | for c in [c for c in ['close', 'high', 'open', 'low', 'ask', 'bid', 'last'] if c in data.columns]: 120 | data.loc[idx[:split_date, symbol], c] *= split_factor 121 | 122 | 123 | def exclude_splits(data: pd.Series, splits: pd.Series, quarantine_length: int): 124 | """ 125 | exclude data based on proximity to split event 126 | :param data: single/multiindex series with boolean values for each timestamp - True - include, False - exclude. 127 | :param splits: splits dataframe 128 | :param quarantine_length: number of moments to exclude 129 | :return adjusted data 130 | """ 131 | 132 | if isinstance(data.index, pd.MultiIndex): 133 | def tmp(datum): 134 | symbol = datum.index[0][datum.index.names.index('symbol')] 135 | datum_index = datum.xs(symbol, level='symbol').index 136 | 137 | try: 138 | for _s in splits.xs(symbol, level='symbol').index: 139 | _i = datum_index.searchsorted(_s[0]) 140 | if 0 < _i < datum_index.size: 141 | datum.iloc[_i: _i + quarantine_length] = False 142 | except KeyError: 143 | pass 144 | 145 | return datum 146 | 147 | data = data.groupby(level='symbol').apply(tmp) 148 | else: 149 | for s in splits.index: 150 | i = data.index.searchsorted(s[0]) 151 | if 0 < i < data.size: 152 | data.iloc[i: i + quarantine_length] = False 153 | 154 | return data 155 | -------------------------------------------------------------------------------- /atpy/data/ts_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Time series utils. 
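
Helpers for splitting dataframes into trading/after-hours periods, slicing a date range into
delta-sized intervals and overlapping consecutive batches by symbol. Illustrative sketch of the
period slicing (uses only this module's imports plus the standard library):

    import datetime
    from dateutil.relativedelta import relativedelta

    periods = slice_periods(bgn_prd=datetime.datetime(2017, 1, 1, tzinfo=datetime.timezone.utc),
                            delta=relativedelta(months=1))
    # -> [(2017-01-01, 2017-02-01), (2017-02-01, 2017-03-01), ...], the last interval ending "now"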
3 | """ 4 | import datetime 5 | import queue 6 | import threading 7 | import typing 8 | 9 | import pandas as pd 10 | from dateutil.relativedelta import relativedelta 11 | 12 | import atpy.data.tradingcalendar as tcal 13 | 14 | 15 | def set_periods(df: pd.DataFrame): 16 | """ 17 | Split the dataset into trading/after hours segments using columns (e.g. trading hours or after hours) 18 | :param df: dataframe (first index have to be datetime) 19 | :return sliced period 20 | """ 21 | 22 | df['period'] = 'after-hours' 23 | 24 | lc = tcal.open_and_closes.loc[df.iloc[0].name[0].date():df.iloc[-1].name[0].date()] 25 | 26 | xs = pd.IndexSlice 27 | 28 | def a(x): 29 | df.loc[xs[x['market_open']:x['market_close'], :] if isinstance(df.index, pd.MultiIndex) else xs[x['market_open']:x['market_close']], 'period'] = 'trading-hours' 30 | 31 | lc.apply(a, axis=1) 32 | 33 | 34 | __open_and_closes_series = pd.concat([tcal.open_and_closes['market_open'], tcal.open_and_closes['market_close']]).sort_values() 35 | __closes_series = tcal.open_and_closes['market_close'] 36 | 37 | 38 | def current_period(df: pd.DataFrame): 39 | """ 40 | Slice only the current period (e.g. trading hours or after hours) 41 | :param df: dataframe (first index have to be datettime) 42 | :return sliced period 43 | """ 44 | 45 | most_recent = df.iloc[-1].name[0] 46 | 47 | try: 48 | current_hours = __open_and_closes_series.loc[most_recent.date()] 49 | except KeyError: 50 | current_hours = None 51 | 52 | if current_hours is not None: 53 | if most_recent > current_hours[1]: 54 | result, period = df.loc[current_hours[1]:], 'after-hours' 55 | elif most_recent < current_hours[0]: 56 | lc = __closes_series.loc[df.iloc[0].name[0].date(): most_recent.date()] 57 | if len(lc) > 1: 58 | result, period = df.loc[lc[-2]:], 'after-hours' 59 | else: 60 | result, period = df, 'after-hours' 61 | else: 62 | result, period = df.loc[current_hours[0]:], 'trading-hours' 63 | else: 64 | lc = __closes_series.loc[df.iloc[0].name[0].date(): most_recent.date()] 65 | if len(lc) > 0: 66 | result, period = df.loc[lc.iloc[-1]:], 'after-hours' 67 | else: 68 | result, period = df, 'after-hours' 69 | 70 | return result, period 71 | 72 | 73 | def current_phase(dttme): 74 | """ 75 | Get current phase (trading/after-hours) 76 | :param dttme: datetime 77 | :return phase 78 | """ 79 | 80 | if dttme.date() in tcal.trading_days: 81 | current_hours = tcal.open_and_closes.loc[dttme.date()] 82 | return 'trading-hours' if current_hours['market_open'] <= dttme <= current_hours['market_close'] else 'after-hours' 83 | else: 84 | return 'after-hours' 85 | 86 | 87 | def current_day(df: pd.DataFrame, tz=None): 88 | """ 89 | Slice only the current day data 90 | :param df: dataframe (first index have to be datettime) 91 | :param tz: timezone 92 | :return sliced period 93 | """ 94 | d = df.iloc[-1].name[0].normalize() 95 | if tz is not None: 96 | d = d.tz_convert(tz).tz_localize(None).tz_localize(d.tzinfo) 97 | 98 | xs = pd.IndexSlice 99 | 100 | return df.loc[xs[d:, :] if isinstance(df.index, pd.MultiIndex) else xs[d:]] 101 | 102 | 103 | def slice_periods(bgn_prd: datetime.datetime, delta: relativedelta, ascend: bool = True, overlap: relativedelta = None): 104 | """ 105 | Split time interval in delta-sized intervals 106 | :param bgn_prd: begin period 107 | :param delta: delta 108 | :param ascend: ascending/descending 109 | :param overlap: whether to provide overlap within the intervals 110 | :return sliced period 111 | """ 112 | 113 | overlap = overlap if overlap is not None else 
relativedelta(days=0) 114 | 115 | result = list() 116 | if ascend: 117 | now = datetime.datetime.now(tz=bgn_prd.tzinfo) 118 | 119 | while bgn_prd < now: 120 | end_prd = min(bgn_prd + delta + overlap, now) 121 | result.append((bgn_prd, end_prd)) 122 | bgn_prd = bgn_prd + delta 123 | else: 124 | end_prd = datetime.datetime.now(tz=bgn_prd.tzinfo) 125 | 126 | while end_prd > bgn_prd: 127 | result.append((max(end_prd - delta - overlap, bgn_prd), end_prd)) 128 | end_prd = end_prd - delta 129 | 130 | return result 131 | 132 | 133 | def gaps(df: pd.DataFrame): 134 | """ 135 | Compute percent changes in the price 136 | :param df: pandas OHLC DataFrame 137 | :return DataFrame with changes 138 | """ 139 | 140 | result = df.groupby('symbol', level='symbol').agg({'low': 'min', 'high': 'max'}) 141 | 142 | low = result['low'] 143 | result = (result['high'] - low) / low 144 | 145 | return result 146 | 147 | 148 | def rolling_mean(df: pd.DataFrame, window: int, column: typing.Union[typing.List, str] = 'close'): 149 | """ 150 | Compute the rolling mean over a column 151 | :param df: pandas OHLC DataFrame 152 | :param window: window size OHLC DataFrame 153 | :param column: a column (or list of columns, where to apply the rolling mean) 154 | :return DataFrame with changes 155 | """ 156 | return df[column].groupby(level='symbol', group_keys=False).rolling(window).mean() 157 | 158 | 159 | def ohlc_mean(df: pd.DataFrame): 160 | """ 161 | Compute the mean value of o/h/l/c 162 | :param df: pandas OHLC DataFrame 163 | :return DataFrame with changes 164 | """ 165 | df[['open', 'open', 'high', 'low']].mean(axis=1) 166 | 167 | 168 | def overlap_by_symbol(old_df: pd.DataFrame, new_df: pd.DataFrame, overlap: int): 169 | """ 170 | Overlap dataframes for timestamp continuity. Prepend the end of old_df to the beginning of new_df, grouped by symbol. 
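    For example, with overlap=2 the last two rows of old_df for each symbol that also appears in new_df are prepended.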
171 | If no symbol exists, just overlap the dataframes 172 | :param old_df: old dataframe 173 | :param new_df: new dataframe 174 | :param overlap: number of time steps to overlap 175 | :return DataFrame with changes 176 | """ 177 | if isinstance(old_df.index, pd.MultiIndex) and isinstance(new_df.index, pd.MultiIndex): 178 | old_df_tail = old_df.groupby(level='symbol').tail(overlap) 179 | 180 | old_df_tail = old_df_tail.drop(set(old_df_tail.index.get_level_values('symbol')) - set(new_df.index.get_level_values('symbol')), level='symbol') 181 | 182 | return pd.concat([old_df_tail, new_df], sort=True) 183 | else: 184 | return pd.concat([old_df.tail(overlap), new_df], sort=True) 185 | 186 | 187 | class AsyncInPeriodProvider(object): 188 | """ 189 | Run InPeriodProvider in async mode 190 | """ 191 | 192 | def __init__(self, in_period_provider: typing.Iterable): 193 | """ 194 | :param in_period_provider: provider 195 | """ 196 | 197 | self.in_period_provider = in_period_provider 198 | 199 | def __iter__(self): 200 | self._q = queue.Queue() 201 | 202 | self.in_period_provider.__iter__() 203 | 204 | def it(): 205 | for i in self.in_period_provider: 206 | self._q.put(i) 207 | 208 | self._q.put(None) 209 | 210 | threading.Thread(target=it, daemon=True).start() 211 | 212 | return self 213 | 214 | def __next__(self): 215 | result = self._q.get() 216 | if result is None: 217 | raise StopIteration() 218 | 219 | return result 220 | -------------------------------------------------------------------------------- /atpy/data/util.py: -------------------------------------------------------------------------------- 1 | from ftplib import FTP 2 | from io import StringIO 3 | 4 | import pandas as pd 5 | 6 | 7 | def _get_nasdaq_symbol_file(filename): 8 | ftp = FTP('ftp.nasdaqtrader.com') 9 | ftp.login() 10 | ftp.cwd('symboldirectory') 11 | 12 | class Reader: 13 | def __init__(self): 14 | self.data = "" 15 | 16 | def __call__(self, s): 17 | self.data += s.decode('ascii') 18 | 19 | r = Reader() 20 | 21 | ftp.retrbinary('RETR ' + filename, r) 22 | return pd.read_csv(StringIO(r.data), sep="|")[:-1] 23 | 24 | 25 | def get_nasdaq_listed_companies(): 26 | result = _get_nasdaq_symbol_file('nasdaqlisted.txt') 27 | result = result.loc[(result['Financial Status'] == 'N') & (result['Test Issue'] == 'N')] 28 | 29 | include_only = set() 30 | include_only_index = list() 31 | for i in range(result.shape[0]): 32 | s = result.iloc[i] 33 | if len(s['Symbol']) < 5 or s['Symbol'][:4] not in include_only: 34 | include_only_index.append(True) 35 | include_only.add(s['Symbol']) 36 | else: 37 | include_only_index.append(False) 38 | 39 | return result[include_only_index] 40 | 41 | 42 | def get_non_nasdaq_listed_companies(): 43 | result = _get_nasdaq_symbol_file('otherlisted.txt') 44 | result = result[result['Test Issue'] == 'N'] 45 | 46 | return result 47 | 48 | 49 | def get_us_listed_companies(): 50 | nd = get_nasdaq_listed_companies() 51 | non_nd = get_non_nasdaq_listed_companies() 52 | symbols = list(set(list(non_nd[0]) + list(nd[0]))) 53 | symbols.sort() 54 | 55 | return pd.DataFrame(symbols) 56 | 57 | 58 | def get_s_and_p_500(): 59 | return pd.read_csv('https://raw.githubusercontent.com/datasets/s-and-p-500-companies/master/data/constituents.csv').set_index('Symbol', drop=True) 60 | 61 | 62 | def resample_bars(df: pd.DataFrame, rule: str, period_id: str = 'right') -> pd.DataFrame: 63 | """ 64 | Resample bars in higher periods 65 | :param df: data frame 66 | :param rule: conversion target period (for reference see 
pandas.DataFrame.resample) 67 | :param period_id: whether to associate the bar with the beginning or the end of the interval 68 | (the inclusion is also closed to the left or right respectively) 69 | """ 70 | if isinstance(df.index, pd.MultiIndex): 71 | result = df.groupby(level='symbol', group_keys=False, sort=False) \ 72 | .resample(rule, closed=period_id, label=period_id, level='timestamp') \ 73 | .agg({'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last', 'volume': 'sum'}) \ 74 | .dropna() 75 | else: 76 | result = df.resample(rule, closed=period_id, label=period_id) \ 77 | .agg({'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last', 'volume': 'sum'}) \ 78 | .dropna() 79 | 80 | return result 81 | -------------------------------------------------------------------------------- /atpy/ibapi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/ibapi/__init__.py -------------------------------------------------------------------------------- /atpy/ibapi/ib_events.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import pandas as pd 4 | from ibapi.client import EClient 5 | from ibapi.common import OrderId 6 | from ibapi.contract import Contract 7 | from ibapi.order import Order 8 | from ibapi.wrapper import EWrapper 9 | 10 | import atpy.portfolio.order as orders 11 | from atpy.portfolio.order import * 12 | 13 | 14 | class DefaultWrapper(EWrapper): 15 | def __init__(self, listeners): 16 | EWrapper.__init__(self) 17 | 18 | self.listeners = listeners 19 | self.next_valid_order_id = -1 20 | self._pending_orders = dict() 21 | self._has_valid_id = threading.Event() 22 | self._lock = threading.RLock() 23 | self._positions = None 24 | 25 | def nextValidId(self, orderId: int): 26 | self.next_valid_order_id = orderId 27 | self._has_valid_id.set() 28 | 29 | def orderStatus(self, orderId: OrderId, status: str, filled: float, 30 | remaining: float, avgFillPrice: float, permId: int, 31 | parentId: int, lastFillPrice: float, clientId: int, 32 | whyHeld: str, mktCapPrice: float): 33 | 34 | if status == 'Filled': 35 | if orderId in self._pending_orders: 36 | order = self._pending_orders[orderId] 37 | del self._pending_orders[orderId] 38 | order.uid = orderId 39 | 40 | self.listeners({'type': 'order_fulfilled', 'data': order}) 41 | else: 42 | self.listeners({'type': 'order_fulfilled', 'data': orderId}) 43 | elif status in ('Inactive', 'ApiCanceled', 'Cancelled') and orderId in self._pending_orders: 44 | del self._pending_orders[orderId] 45 | 46 | def position(self, account: str, contract: Contract, position: float, avgCost: float): 47 | """This event returns real-time positions for all accounts in 48 | response to the reqPositions() method.""" 49 | 50 | with self._lock: 51 | if self._positions is None: 52 | self._positions = {k: list() for k in list(contract.__dict__.keys()) + ['position', 'avgCost']} 53 | 54 | for k, v in {**contract.__dict__, **{'position': position, 'avgCost': avgCost}}.items(): 55 | self._positions[k].append(v) 56 | 57 | def positionEnd(self): 58 | """This is called once all position data for a given request are 59 | received and functions as an end marker for the position() data. 
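        The accumulated rows are then published downstream as a single pandas DataFrame in an 'ibapi_positions' event.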
""" 60 | 61 | with self._lock: 62 | data = None if self._positions is None else pd.DataFrame.from_dict(self._positions) 63 | self._positions = None 64 | 65 | if data is not None: 66 | self.listeners({'type': 'ibapi_positions', 'data': data}) 67 | 68 | def error(self, reqId: int, errorCode: int, errorString: str): 69 | super().error(reqId=reqId, errorCode=errorCode, errorString=errorString) 70 | self.listeners({'type': 'ibapi_error', 'data': {'reqId': reqId, 'errorCode': errorCode, 'errorString': errorString}}) 71 | 72 | 73 | class DefaultClient(EClient): 74 | def __init__(self, wrapper): 75 | EClient.__init__(self, wrapper) 76 | 77 | 78 | class IBEvents(DefaultWrapper, DefaultClient): 79 | def __init__(self, listeners, ipaddress, portid, clientid): 80 | DefaultWrapper.__init__(self, listeners) 81 | DefaultClient.__init__(self, wrapper=self) 82 | 83 | self.listeners = listeners 84 | self.listeners += self.on_event 85 | 86 | self.ipaddress = ipaddress 87 | self.portid = portid 88 | self.clientid = clientid 89 | self.lock = threading.RLock() 90 | 91 | def __enter__(self): 92 | self.connect(self.ipaddress, self.portid, self.clientid) 93 | 94 | if self.isConnected(): 95 | thread = threading.Thread(target=self.run) 96 | thread.start() 97 | 98 | setattr(self, "_thread", thread) 99 | 100 | self.reqIds(-1) 101 | else: 102 | raise Exception("Not connected. First connect via IB Gateway or TWS") 103 | 104 | def __exit__(self, exception_type, exception_value, traceback): 105 | """Disconnect connection etc""" 106 | self.done = True 107 | 108 | def on_event(self, event): 109 | if event['type'] == 'order_request': 110 | self.process_order_request(event['data']) 111 | elif event['type'] == 'positions_request': 112 | self.reqPositions() 113 | 114 | def process_order_request(self, order): 115 | with self.lock: 116 | ibcontract = Contract() 117 | ibcontract.symbol = order.symbol 118 | ibcontract.secType = "STK" 119 | ibcontract.currency = "USD" 120 | ibcontract.exchange = "SMART" 121 | 122 | iborder = Order() 123 | 124 | iborder.action = 'BUY' if order.order_type == Type.BUY else 'SELL' if order.order_type == Type.SELL else None 125 | 126 | if isinstance(order, orders.MarketOrder): 127 | iborder.orderType = "MKT" 128 | elif isinstance(order, orders.LimitOrder): 129 | iborder.orderType = "LMT" 130 | order.lmtPrice = order.price 131 | elif isinstance(order, orders.StopMarketOrder): 132 | iborder.orderType = "STP" 133 | order.auxPrice = order.price 134 | elif isinstance(order, orders.StopLimitOrder): 135 | iborder.orderType = "STP LMT" 136 | order.lmtPrice = order.limit_price 137 | order.auxPrice = order.stop_price 138 | 139 | iborder.totalQuantity = order.quantity 140 | 141 | self._has_valid_id.wait() 142 | 143 | self._pending_orders[self.next_valid_order_id] = order 144 | 145 | self.placeOrder(self.next_valid_order_id, ibcontract, iborder) 146 | 147 | self.next_valid_order_id += 1 148 | 149 | def reqPositions(self): 150 | with self.lock: 151 | if self._positions is None: 152 | super().reqPositions() 153 | -------------------------------------------------------------------------------- /atpy/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/ml/__init__.py -------------------------------------------------------------------------------- /atpy/ml/cross_validation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 
| import pandas as pd 3 | 4 | """ 5 | Chapter 6 of Advances in Financial Machine Learning book by Marcos Lopez de Prado 6 | """ 7 | 8 | 9 | def cv_split(data: pd.DataFrame, splits: int, current_split: int, embargo: float = 0.01): 10 | """ 11 | cross validation data split, as implemented in chapter 6 of the book 12 | :param data: single/multiindex dataframe with boolean values for each timestamp - True - include, False - exclude. 13 | :param splits: number of splits 14 | :param current_split: current split (0 based) 15 | :param embargo: current split (0 based) 16 | :return timestamps that exclude the purged areas and all other splits except the specified 17 | """ 18 | 19 | if isinstance(data.index, pd.MultiIndex): 20 | data = data.reset_index(level=data.index.names.index('symbol'), drop=True).sort_index() 21 | 22 | result = pd.Series(True, index=data.index, dtype=np.bool) 23 | 24 | purge = cv_purge(data['interval_end'], splits=splits, current_split=current_split, embargo=embargo) 25 | 26 | if purge.size > 0: 27 | if current_split == 0: 28 | result.loc[purge[-1]:] = False 29 | elif current_split == splits - 1: 30 | result.loc[:purge[0]] = False 31 | else: 32 | result.loc[:purge[0]] = False 33 | result.loc[purge[-1]:] = False 34 | 35 | result.loc[purge] = False 36 | 37 | return result[result].index 38 | 39 | 40 | def cv_split_reverse(data: pd.DataFrame, splits: int, current_split: int, embargo: float = 0.01): 41 | """ 42 | cross validation data split, as implemented in chapter 6 of the book 43 | :param data: single/multiindex dataframe with boolean values for each timestamp - True - include, False - exclude. 44 | :param splits: number of splits 45 | :param current_split: current split (0 based) 46 | :param embargo: current split (0 based) 47 | :return timestamps that exclude the purged areas and the specified split, but including all other splits 48 | """ 49 | 50 | if isinstance(data.index, pd.MultiIndex): 51 | data = data.reset_index(level=data.index.names.index('symbol'), drop=True).sort_index() 52 | 53 | result = pd.Series(True, index=data.index, dtype=np.bool) 54 | 55 | purge = cv_purge(data['interval_end'], splits=splits, current_split=current_split, embargo=embargo) 56 | 57 | if purge.size > 0: 58 | if current_split == 0: 59 | result.loc[:purge[-1]] = False 60 | elif current_split == splits - 1: 61 | result.loc[purge[0]:] = False 62 | else: 63 | result.loc[purge[0]:purge[-1]] = False 64 | 65 | result.loc[purge] = False 66 | 67 | return result[result].index 68 | 69 | 70 | def cv_purge(data: pd.Series, splits: int, current_split: int, embargo: float = 0.01): 71 | """ 72 | cross validation data split purge areas, as implemented in chapter 6 of the book 73 | :param data: single index series with pd.DateTimeIndex as index (interval start) and pd.DateTimeIndex as value (interval end). 
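    Each value marks when the event window that labels the observation closes.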
74 | This series is produced by the labeling method 75 | :param splits: number of splits 76 | :param current_split: current split (0 based) 77 | :param embargo: current split (0 based) 78 | :return boolean series indicating the purged data marked as false 79 | """ 80 | result = pd.Series(True, index=data.index, dtype=np.bool) 81 | 82 | timedelta = data.index.max() - data.index.min() 83 | split_length = timedelta / splits 84 | current_split_start, current_split_end = data.index.min() + split_length * current_split, data.index.min() + split_length * (current_split + 1) 85 | 86 | result.loc[((data.index < current_split_start) & (data >= current_split_start)) 87 | | ((data.index <= current_split_end) & (data > current_split_end))] = False 88 | 89 | if current_split < splits - 1 and embargo > 0: 90 | embargo_start = result[~result] 91 | embargo_start = embargo_start.index[-1] if embargo_start.size > 0 else current_split_end 92 | result[embargo_start:embargo_start + timedelta * embargo] = False 93 | 94 | return result[~result].index 95 | -------------------------------------------------------------------------------- /atpy/ml/frac_diff_features.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import cpu_count 2 | from multiprocessing.pool import Pool 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | """ 8 | Chapter 5 of Advances in Financial Machine Learning book by Marcos Lopez de Prado 9 | """ 10 | 11 | 12 | def get_weights_ffd(d, threshold=1e-5): 13 | """ 14 | Obtain the weights for the binomial representation of time series 15 | :param d: coefficient 16 | :param threshold: threshold 17 | """ 18 | 19 | w, k = [1.], 1 20 | while abs(w[-1]) > threshold: 21 | w_ = -w[-1] / k * (d - k + 1) 22 | w.append(w_) 23 | k += 1 24 | 25 | w = np.array(w[::-1]) 26 | 27 | return w 28 | 29 | 30 | def _frac_diff_ffd(data: pd.Series, d: float, threshold=1e-5): 31 | """ 32 | Fractionally Differentiated Features Fixed Window 33 | :param data: data 34 | :param d: difference coefficient 35 | :param threshold: threshold 36 | """ 37 | # 1) Compute weights for the longest series 38 | w = get_weights_ffd(d, threshold) 39 | # 2) Apply weights to values 40 | return data.rolling(w.size, min_periods=w.size).apply(lambda x: np.dot(x, w), raw=True).dropna() 41 | 42 | 43 | def frac_diff_ffd(data: pd.Series, d: float, threshold=1e-5, parallel=False): 44 | """ 45 | Fractionally Differentiated Features Fixed Window 46 | :param data: data 47 | :param d: difference coefficient 48 | :param threshold: threshold 49 | :param parallel: run in parallel 50 | """ 51 | if isinstance(data.index, pd.MultiIndex): 52 | if parallel: 53 | with Pool(cpu_count()) as p: 54 | ret_list = p.starmap(_frac_diff_ffd, [(group, d, threshold) for name, group in data.groupby(level='symbol', group_keys=False, sort=False)]) 55 | return pd.concat(ret_list) 56 | else: 57 | return data.groupby(level='symbol', group_keys=False, sort=False).apply(_frac_diff_ffd, d=d, threshold=threshold) 58 | else: 59 | return _frac_diff_ffd(data=data, d=d, threshold=threshold) 60 | 61 | # def plot_weights(d_range, n_plots): 62 | # import matplotlib.pyplot as plt 63 | # 64 | # w = pd.DataFrame() 65 | # for d in np.linspace(d_range[0], d_range[1], n_plots): 66 | # w_ = get_weights_ffd(d) 67 | # w_ = pd.DataFrame(w_, index=range(w_.shape[0])[::-1], columns=[d]) 68 | # w = w.join(w_, how='outer') 69 | # ax = w.plot() 70 | # ax.legend(loc='upper left'); 71 | # plt.show() 72 | # return 73 | # 74 | # 75 | # if __name__ 
== '__main__': 76 | # plot_weights(d_range=[0, 1], n_plots=11) 77 | # plot_weights(d_range=[1, 2], n_plots=11) 78 | -------------------------------------------------------------------------------- /atpy/portfolio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/atpy/portfolio/__init__.py -------------------------------------------------------------------------------- /atpy/portfolio/order.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import uuid 3 | from abc import ABCMeta 4 | from enum import Enum 5 | 6 | from dateutil import tz 7 | 8 | 9 | class Type(Enum): 10 | BUY = 1 11 | SELL = 2 12 | 13 | 14 | class BaseOrder(object, metaclass=ABCMeta): 15 | def __init__(self, order_type: Type, symbol: str, quantity: int, uid=None): 16 | self.uid = uid if uid is not None else uuid.uuid4() 17 | self.order_type = order_type 18 | self.symbol = symbol 19 | 20 | if quantity <= 0: 21 | raise ValueError("quantity > 0") 22 | 23 | self.quantity = quantity 24 | 25 | self.__obtained_positions = list() 26 | self.request_time = datetime.datetime.utcnow().replace(tzinfo=tz.gettz('UTC')) 27 | self.commission = 0 28 | self.__fulfill_time = None 29 | 30 | @property 31 | def fulfill_time(self): 32 | return self.__fulfill_time 33 | 34 | @fulfill_time.setter 35 | def fulfill_time(self, fulfill_time): 36 | if self.obtained_quantity != self.quantity: 37 | raise Exception("Order is not fulfilled. Obtained %d of %d" % (self.obtained_quantity, + str(self.quantity))) 38 | 39 | self.__fulfill_time = fulfill_time 40 | 41 | @property 42 | def obtained_quantity(self): 43 | return sum([op[0] for op in self.__obtained_positions]) 44 | 45 | def add_position(self, quantity, price): 46 | if self.obtained_quantity >= self.quantity: 47 | raise Exception("Order already fulfilled") 48 | 49 | self.__obtained_positions.append((quantity if self.quantity - self.obtained_quantity >= quantity else self.quantity - self.obtained_quantity, price)) 50 | 51 | if self.obtained_quantity >= self.quantity: 52 | self.__fulfill_time = datetime.datetime.utcnow().replace(tzinfo=tz.gettz('UTC')) 53 | 54 | return True 55 | 56 | @property 57 | def cost(self): 58 | return sum([p[0] * p[1] for p in self.__obtained_positions]) 59 | 60 | @property 61 | def last_cost_per_share(self): 62 | return self.__obtained_positions[-1][1] 63 | 64 | def __str__(self): 65 | result = str(self.order_type).split('.')[1] + " " + self.symbol + " " + str(self.quantity) 66 | if self.obtained_quantity > 0: 67 | result += "; fulfilled: %d for %.3f" % (self.obtained_quantity, self.cost) 68 | if self.__fulfill_time is not None: 69 | result += " in %ss" % str(self.__fulfill_time - self.request_time) 70 | 71 | if self.commission is not None: 72 | result += "; commission: %.3f;" % self.commission 73 | 74 | return result 75 | 76 | 77 | class MarketOrder(BaseOrder): 78 | pass 79 | 80 | 81 | class LimitOrder(BaseOrder): 82 | def __init__(self, order_type: Type, symbol: str, quantity: int, price: float, uid=None): 83 | super().__init__(order_type, symbol, quantity, uid=uid) 84 | self.price = price 85 | 86 | def add_position(self, quantity, price): 87 | if (self.order_type == Type.BUY and self.price < price) or (self.order_type == Type.SELL and self.price > price): 88 | return False 89 | 90 | return super().add_position(quantity, price) 91 | 92 | 93 | class StopMarketOrder(BaseOrder): 94 | def 
__init__(self, order_type: Type, symbol: str, quantity: int, price: float, uid=None): 95 | super().__init__(order_type, symbol, quantity, uid=uid) 96 | 97 | self.price = price 98 | self._is_market = False 99 | 100 | def add_position(self, quantity, price): 101 | if (self.order_type == Type.BUY and self.price >= price) or (self.order_type == Type.SELL and self.price <= price): 102 | self._is_market = True 103 | 104 | return super().add_position(quantity, price) if self._is_market else False 105 | 106 | 107 | class StopLimitOrder(BaseOrder): 108 | def __init__(self, order_type: Type, symbol: str, quantity: int, stop_price: float, limit_price: float, uid=None): 109 | super().__init__(order_type, symbol, quantity, uid=uid) 110 | 111 | self.stop_price = stop_price 112 | self.limit_price = limit_price 113 | self._is_limit = False 114 | 115 | def add_position(self, quantity, price): 116 | if (self.order_type == Type.BUY and self.stop_price >= price) or (self.order_type == Type.SELL and self.stop_price <= price): 117 | self._is_limit = True 118 | 119 | if self._is_limit and ((self.order_type == Type.BUY and self.limit_price >= price) or (self.order_type == Type.SELL and self.limit_price <= price)): 120 | return super().add_position(quantity, price) 121 | 122 | return False 123 | -------------------------------------------------------------------------------- /atpy/portfolio/portfolio_manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | from collections.abc import Collection 4 | 5 | import pandas as pd 6 | 7 | from atpy.portfolio.order import * 8 | from pyevents.events import EventFilter 9 | 10 | 11 | class PortfolioManager(object): 12 | """Orders portfolio manager""" 13 | 14 | def __init__(self, listeners, initial_capital: float, fulfilled_orders_event_stream, bar_event_stream=None, tick_event_stream=None, uid=None, orders=None): 15 | """ 16 | :param fulfilled_orders_event_stream: event stream for fulfilled order events 17 | :param bar_event_stream: event stream for bar data events 18 | :param tick_event_stream: event stream for tick data events 19 | :param uid: unique id for this portfolio manager 20 | :param orders: a list of pre-existing orders 21 | """ 22 | 23 | self.listeners = listeners 24 | 25 | fulfilled_orders_event_stream += self.add_order 26 | 27 | if bar_event_stream is not None: 28 | bar_event_stream += self.process_bar_data 29 | 30 | if tick_event_stream is not None: 31 | tick_event_stream += self.process_tick_data 32 | 33 | self.initial_capital = initial_capital 34 | self._id = uid if uid is not None else uuid.uuid4() 35 | self.orders = orders if orders is not None else list() 36 | self._lock = threading.RLock() 37 | self._values = dict() 38 | 39 | def add_order(self, order: BaseOrder): 40 | with self._lock: 41 | if order.fulfill_time is None: 42 | raise Exception("Order has no fulfill_time set") 43 | 44 | if len([o for o in self.orders if o.uid == order.uid]) > 0: 45 | raise Exception("Attempt to fulfill existing order") 46 | 47 | if order.order_type == Type.SELL and self._quantity(order.symbol) < order.quantity: 48 | raise Exception("Attempt to sell more shares than available") 49 | 50 | if order.order_type == Type.BUY and self._capital < order.cost: 51 | raise Exception("Not enough capital to fulfill order") 52 | 53 | self.orders.append(order) 54 | 55 | self.listeners({'type': 'watch_ticks', 'data': order.symbol}) 56 | self.listeners({'type': 'portfolio_update', 'data': self}) 57 | 58 | def 
portfolio_updates_stream(self): 59 | return EventFilter(listeners=self.listeners, 60 | event_filter=lambda e: True if ('type' in e and e['type'] == 'portfolio_update') else False, 61 | event_transformer=lambda e: (e['data'],)) 62 | 63 | @property 64 | def symbols(self): 65 | """Get list of all orders/symbols""" 66 | 67 | return set([o.symbol for o in self.orders]) 68 | 69 | @property 70 | def capital(self): 71 | """Get available capital (including orders)""" 72 | 73 | with self._lock: 74 | return self._capital 75 | 76 | @property 77 | def _capital(self): 78 | turnover = 0 79 | commissions = 0 80 | for o in self.orders: 81 | cost = o.cost 82 | if o.order_type == Type.SELL: 83 | turnover += cost 84 | elif o.order_type == Type.BUY: 85 | turnover -= cost 86 | 87 | commissions += o.commission 88 | 89 | return self.initial_capital + turnover - commissions 90 | 91 | def quantity(self, symbol=None): 92 | with self._lock: 93 | return self._quantity(symbol=symbol) 94 | 95 | def _quantity(self, symbol=None): 96 | if symbol is not None: 97 | quantity = 0 98 | 99 | for o in [o for o in self.orders if o.symbol == symbol]: 100 | if o.order_type == Type.BUY: 101 | quantity += o.quantity 102 | elif o.order_type == Type.SELL: 103 | quantity -= o.quantity 104 | 105 | return quantity 106 | else: 107 | result = dict() 108 | for s in set([o.symbol for o in self.orders]): 109 | qty = self._quantity(s) 110 | if qty > 0: 111 | result[s] = qty 112 | 113 | return result 114 | 115 | def value(self, symbol=None, multiply_by_quantity=False): 116 | with self._lock: 117 | return self._value(symbol=symbol, multiply_by_quantity=multiply_by_quantity) 118 | 119 | def _value(self, symbol=None, multiply_by_quantity=False): 120 | if symbol is not None: 121 | if symbol not in self._values: 122 | logging.getLogger(__name__).debug("No current information available for %s. Falling back to last traded price" % symbol) 123 | symbol_orders = [o for o in self.orders if o.symbol == symbol] 124 | order = sorted(symbol_orders, key=lambda o: o.fulfill_time, reverse=True)[0] 125 | return order.last_cost_per_share * (self._quantity(symbol=symbol) if multiply_by_quantity else 1) 126 | else: 127 | return self._values[symbol] * (self._quantity(symbol=symbol) if multiply_by_quantity else 1) 128 | else: 129 | result = dict() 130 | for s in set([o.symbol for o in self.orders]): 131 | result[s] = self._value(symbol=s, multiply_by_quantity=multiply_by_quantity) 132 | 133 | return result 134 | 135 | def process_tick_data(self, data): 136 | with self._lock: 137 | symbol = data['symbol'] 138 | if symbol in [o.symbol for o in self.orders]: 139 | self._values[symbol] = data['bid'][-1] if isinstance(data['bid'], Collection) else data['bid'] 140 | self.listeners({'type': 'portfolio_value_update', 'data': self}) 141 | 142 | def process_bar_data(self, data): 143 | with self._lock: 144 | symbols = data.index.get_level_values(level='symbol') 145 | 146 | for o in [o for o in self.orders if o.symbol in symbols]: 147 | slc = data.loc[pd.IndexSlice[:, o.symbol], 'close'] 148 | if not slc.empty: 149 | self._values[o.symbol] = slc[-1] 150 | self.listeners({'type': 'portfolio_value_update', 'data': self}) 151 | 152 | def __getstate__(self): 153 | # Copy the object's state from self.__dict__ which contains 154 | # all our instance attributes. Always use the dict.copy() 155 | # method to avoid modifying the original state. 156 | state = self.__dict__.copy() 157 | # Remove the unpicklable entries. 
158 | del state['_lock'] 159 | del state['listeners'] 160 | 161 | return state 162 | 163 | def __setstate__(self, state): 164 | # Restore instance attributes (i.e., _lock). 165 | self.__dict__.update(state) 166 | self._lock = threading.RLock() 167 | -------------------------------------------------------------------------------- /scripts/iqfeed_to_postgres_bars_1d.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | import argparse 3 | import logging 4 | import os 5 | 6 | if __name__ == "__main__": 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | parser = argparse.ArgumentParser(description="PostgreSQL and IQFeed configuration") 10 | parser.add_argument('-drop', action='store_true', help="Drop the database") 11 | parser.add_argument('-cluster', action='store_true', help="Cluster the table after inserts") 12 | args = parser.parse_args() 13 | 14 | query = "python3 update_postgres_cache.py " + \ 15 | ("-drop" if args.drop else "") + \ 16 | (" -cluster" if args.cluster else "") + \ 17 | " -url='" + os.environ['POSTGRESQL_CACHE'] + "'" + \ 18 | " -table_name='bars_1d'" + \ 19 | " -interval_len=1" + \ 20 | " -interval_type='d'" + \ 21 | " -skip_if_older=30" 22 | 23 | os.system(query) 24 | -------------------------------------------------------------------------------- /scripts/iqfeed_to_postgres_bars_1m.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | import argparse 3 | import logging 4 | import os 5 | 6 | if __name__ == "__main__": 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | parser = argparse.ArgumentParser(description="PostgreSQL and IQFeed configuration") 10 | parser.add_argument('-drop', action='store_true', help="Drop the database") 11 | parser.add_argument('-cluster', action='store_true', help="Cluster the table after inserts") 12 | args = parser.parse_args() 13 | 14 | query = "python3 update_postgres_cache.py " + \ 15 | ("-drop" if args.drop else "") + \ 16 | (" -cluster" if args.cluster else "") + \ 17 | " -url='" + os.environ['POSTGRESQL_CACHE'] + "'" + \ 18 | " -table_name='bars_1m'" + \ 19 | " -interval_len=60" + \ 20 | " -interval_type='s'" + \ 21 | " -skip_if_older=30" 22 | 23 | os.system(query) 24 | -------------------------------------------------------------------------------- /scripts/iqfeed_to_postgres_bars_5m.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | import argparse 3 | import logging 4 | import os 5 | 6 | if __name__ == "__main__": 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | parser = argparse.ArgumentParser(description="PostgreSQL and IQFeed configuration") 10 | parser.add_argument('-drop', action='store_true', help="Drop the database") 11 | parser.add_argument('-cluster', action='store_true', help="Cluster the table after inserts") 12 | args = parser.parse_args() 13 | 14 | query = "python3 update_postgres_cache.py " + \ 15 | ("-drop" if args.drop else "") + \ 16 | (" -cluster" if args.cluster else "") + \ 17 | " -url='" + os.environ['POSTGRESQL_CACHE'] + "'" + \ 18 | " -table_name='bars_5m'" + \ 19 | " -interval_len=300" + \ 20 | " -interval_type='s'" + \ 21 | " -skip_if_older=30" 22 | 23 | os.system(query) 24 | -------------------------------------------------------------------------------- /scripts/iqfeed_to_postgres_bars_60m.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | import argparse 3 | import logging 4 | import 
os 5 | 6 | if __name__ == "__main__": 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | parser = argparse.ArgumentParser(description="PostgreSQL and IQFeed configuration") 10 | parser.add_argument('-drop', action='store_true', help="Drop the database") 11 | parser.add_argument('-cluster', action='store_true', help="Cluster the table after inserts") 12 | args = parser.parse_args() 13 | 14 | query = "python3 update_postgres_cache.py " + \ 15 | ("-drop" if args.drop else "") + \ 16 | (" -cluster" if args.cluster else "") + \ 17 | " -url='" + os.environ['POSTGRESQL_CACHE'] + "'" + \ 18 | " -table_name='bars_60m'" + \ 19 | " -interval_len=3600" + \ 20 | " -interval_type='s'" 21 | 22 | os.system(query) 23 | -------------------------------------------------------------------------------- /scripts/postgres_to_lmdb_bars_1d.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | import argparse 4 | import datetime 5 | import logging 6 | import os 7 | 8 | import psycopg2 9 | from dateutil.relativedelta import relativedelta 10 | 11 | from atpy.data.cache.lmdb_cache import * 12 | from atpy.data.cache.postgres_cache import BarsInPeriodProvider 13 | 14 | if __name__ == "__main__": 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | parser = argparse.ArgumentParser(description="PostgreSQL to LMDB configuration") 18 | parser.add_argument('-lmdb_path', type=str, default=None, help="LMDB Path") 19 | parser.add_argument('-delta_back', type=int, default=8, help="Default number of years to look back") 20 | args = parser.parse_args() 21 | 22 | lmdb_path = args.lmdb_path if args.lmdb_path is not None else os.environ['ATPY_LMDB_PATH'] 23 | 24 | con = psycopg2.connect(os.environ['POSTGRESQL_CACHE']) 25 | 26 | now = datetime.datetime.now() 27 | bgn_prd = datetime.datetime(now.year - args.delta_back, 1, 1) 28 | 29 | bars_in_period = BarsInPeriodProvider(conn=con, interval_len=1, interval_type='d', bars_table='bars_1d', bgn_prd=bgn_prd, delta=relativedelta(years=1), 30 | overlap=relativedelta(microseconds=-1)) 31 | 32 | for i, df in enumerate(bars_in_period): 33 | logging.info('Saving ' + bars_in_period.current_cache_key()) 34 | write(bars_in_period.current_cache_key(), df, lmdb_path) 35 | -------------------------------------------------------------------------------- /scripts/postgres_to_lmdb_bars_1m.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | import argparse 4 | import datetime 5 | import functools 6 | import logging 7 | import os 8 | 9 | import psycopg2 10 | from dateutil.relativedelta import relativedelta 11 | 12 | from atpy.data.cache.lmdb_cache import * 13 | from atpy.data.cache.postgres_cache import BarsInPeriodProvider, request_adjustments 14 | from atpy.data.splits_dividends import adjust_df 15 | 16 | if __name__ == "__main__": 17 | logging.basicConfig(level=logging.INFO) 18 | 19 | parser = argparse.ArgumentParser(description="PostgreSQL to LMDB configuration") 20 | parser.add_argument('-lmdb_path', type=str, default=None, help="LMDB Path") 21 | parser.add_argument('-delta_back', type=int, default=8, help="Default number of years to look back") 22 | parser.add_argument('-adjust_splits', action='store_true', default=True, help="Adjust splits before saving") 23 | parser.add_argument('-adjust_dividends', action='store_true', default=False, help="Adjust dividends before saving") 24 | 25 | args = parser.parse_args() 26 | 27 | lmdb_path = args.lmdb_path if args.lmdb_path is not None else 
os.environ['ATPY_LMDB_PATH'] 28 | 29 | con = psycopg2.connect(os.environ['POSTGRESQL_CACHE']) 30 | 31 | adjustments = None 32 | if args.adjust_splits and args.adjust_dividends: 33 | adjustments = request_adjustments(conn=con, table_name='splits_dividends') 34 | elif args.adjust_splits: 35 | adjustments = request_adjustments(conn=con, table_name='splits_dividends', adj_type='split') 36 | elif args.adjust_dividends: 37 | adjustments = request_adjustments(conn=con, table_name='splits_dividends', adj_type='dividend') 38 | 39 | now = datetime.datetime.now() 40 | bgn_prd = datetime.datetime(now.year - args.delta_back, 1, 1) 41 | bgn_prd = bgn_prd + relativedelta(days=7 - bgn_prd.weekday()) 42 | 43 | cache_read = functools.partial(read_pickle, lmdb_path=lmdb_path) 44 | bars_in_period = BarsInPeriodProvider(conn=con, interval_len=60, interval_type='s', bars_table='bars_1m', bgn_prd=bgn_prd, delta=relativedelta(days=7), 45 | overlap=relativedelta(microseconds=-1), cache=cache_read) 46 | 47 | for i, df in enumerate(bars_in_period): 48 | if cache_read(bars_in_period.current_cache_key()) is None: 49 | if adjustments is not None: 50 | adjust_df(df, adjustments) 51 | 52 | write(bars_in_period.current_cache_key(), df, lmdb_path) 53 | logging.info('Saving ' + bars_in_period.current_cache_key()) 54 | else: 55 | logging.info('Cache hit on ' + bars_in_period.current_cache_key()) 56 | -------------------------------------------------------------------------------- /scripts/postgres_to_lmdb_bars_5m.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | import argparse 4 | import datetime 5 | import functools 6 | import logging 7 | import os 8 | 9 | import psycopg2 10 | from dateutil.relativedelta import relativedelta 11 | 12 | from atpy.data.cache.lmdb_cache import * 13 | from atpy.data.cache.postgres_cache import BarsInPeriodProvider, request_adjustments 14 | from atpy.data.splits_dividends import adjust_df 15 | 16 | if __name__ == "__main__": 17 | logging.basicConfig(level=logging.INFO) 18 | 19 | parser = argparse.ArgumentParser(description="PostgreSQL to LMDB configuration") 20 | parser.add_argument('-lmdb_path', type=str, default=None, help="LMDB Path") 21 | parser.add_argument('-delta_back', type=int, default=8, help="Default number of years to look back") 22 | parser.add_argument('-adjust_splits', action='store_true', default=True, help="Adjust splits before saving") 23 | parser.add_argument('-adjust_dividends', action='store_true', default=False, help="Adjust dividends before saving") 24 | 25 | args = parser.parse_args() 26 | 27 | lmdb_path = args.lmdb_path if args.lmdb_path is not None else os.environ['ATPY_LMDB_PATH'] 28 | 29 | con = psycopg2.connect(os.environ['POSTGRESQL_CACHE']) 30 | 31 | adjustments = None 32 | if args.adjust_splits and args.adjust_dividends: 33 | adjustments = request_adjustments(conn=con, table_name='splits_dividends') 34 | elif args.adjust_splits: 35 | adjustments = request_adjustments(conn=con, table_name='splits_dividends', adj_type='split') 36 | elif args.adjust_dividends: 37 | adjustments = request_adjustments(conn=con, table_name='splits_dividends', adj_type='dividend') 38 | 39 | now = datetime.datetime.now() 40 | bgn_prd = datetime.datetime(now.year - args.delta_back, 1, 1) 41 | bgn_prd = bgn_prd + relativedelta(days=7 - bgn_prd.weekday()) 42 | 43 | cache_read = functools.partial(read_pickle, lmdb_path=lmdb_path) 44 | bars_in_period = BarsInPeriodProvider(conn=con, interval_len=300, interval_type='s', 
bars_table='bars_5m', bgn_prd=bgn_prd, delta=relativedelta(days=7), 45 | overlap=relativedelta(microseconds=-1), cache=cache_read) 46 | 47 | for i, df in enumerate(bars_in_period): 48 | if cache_read(bars_in_period.current_cache_key()) is None: 49 | if adjustments is not None: 50 | adjust_df(df, adjustments) 51 | 52 | write(bars_in_period.current_cache_key(), df, lmdb_path) 53 | logging.info('Saving ' + bars_in_period.current_cache_key()) 54 | else: 55 | logging.info('Cache hit on ' + bars_in_period.current_cache_key()) 56 | -------------------------------------------------------------------------------- /scripts/postgres_to_lmdb_bars_60m.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | import argparse 4 | import datetime 5 | import functools 6 | import logging 7 | import os 8 | 9 | import psycopg2 10 | from dateutil.relativedelta import relativedelta 11 | 12 | from atpy.data.cache.lmdb_cache import * 13 | from atpy.data.cache.postgres_cache import BarsInPeriodProvider 14 | from atpy.data.cache.postgres_cache import request_adjustments 15 | from atpy.data.splits_dividends import adjust_df 16 | 17 | if __name__ == "__main__": 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | parser = argparse.ArgumentParser(description="PostgreSQL to LMDB configuration") 21 | parser.add_argument('-lmdb_path', type=str, default=None, help="LMDB Path") 22 | parser.add_argument('-delta_back', type=int, default=8, help="Default number of years to look back") 23 | parser.add_argument('-adjust_splits', action='store_true', default=True, help="Adjust splits before saving") 24 | parser.add_argument('-adjust_dividends', action='store_true', default=False, help="Adjust dividends before saving") 25 | 26 | args = parser.parse_args() 27 | 28 | lmdb_path = args.lmdb_path if args.lmdb_path is not None else os.environ['ATPY_LMDB_PATH'] 29 | 30 | con = psycopg2.connect(os.environ['POSTGRESQL_CACHE']) 31 | 32 | adjustments = None 33 | if args.adjust_splits and args.adjust_dividends: 34 | adjustments = request_adjustments(conn=con, table_name='splits_dividends') 35 | elif args.adjust_splits: 36 | adjustments = request_adjustments(conn=con, table_name='splits_dividends', adj_type='split') 37 | elif args.adjust_dividends: 38 | adjustments = request_adjustments(conn=con, table_name='splits_dividends', adj_type='dividend') 39 | 40 | now = datetime.datetime.now() 41 | bgn_prd = datetime.datetime(now.year - args.delta_back, 1, 1) 42 | bgn_prd = bgn_prd + relativedelta(days=7 - bgn_prd.weekday()) 43 | 44 | cache_read = functools.partial(read_pickle, lmdb_path=lmdb_path) 45 | bars_in_period = BarsInPeriodProvider(conn=con, interval_len=3600, interval_type='s', bars_table='bars_60m', bgn_prd=bgn_prd, delta=relativedelta(days=7), 46 | overlap=relativedelta(microseconds=-1), cache=cache_read) 47 | 48 | for i, df in enumerate(bars_in_period): 49 | if cache_read(bars_in_period.current_cache_key()) is None: 50 | if adjustments is not None: 51 | adjust_df(df, adjustments) 52 | 53 | write(bars_in_period.current_cache_key(), df, lmdb_path) 54 | logging.info('Saving ' + bars_in_period.current_cache_key()) 55 | else: 56 | logging.info('Cache hit on ' + bars_in_period.current_cache_key()) 57 | -------------------------------------------------------------------------------- /scripts/quandl_sf0_to_postgres.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | import os 4 | 5 | import psycopg2 6 | from sqlalchemy import 
create_engine 7 | 8 | from atpy.data.quandl.postgres_cache import bulkinsert_SF0 9 | 10 | if __name__ == "__main__": 11 | table_name = 'quandl_sf0' 12 | url = os.environ['POSTGRESQL_CACHE'] 13 | con = psycopg2.connect(url) 14 | con.autocommit = True 15 | 16 | engine = create_engine(url) 17 | 18 | bulkinsert_SF0(url, table_name=table_name) 19 | -------------------------------------------------------------------------------- /scripts/update_influxdb_cache.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | """ 4 | Script that populates the InfluxDB cache initially and the updates it incrementally 5 | """ 6 | 7 | import argparse 8 | import logging 9 | 10 | from dateutil.relativedelta import relativedelta 11 | from influxdb import DataFrameClient 12 | 13 | import atpy.data.iqfeed.util as iqutil 14 | from atpy.data.cache.influxdb_cache import update_to_latest 15 | from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider 16 | from atpy.data.iqfeed.iqfeed_influxdb_cache import noncache_provider 17 | 18 | if __name__ == "__main__": 19 | logging.basicConfig(level=logging.INFO) 20 | 21 | parser = argparse.ArgumentParser(description="InfluxDB and IQFeed configuration") 22 | 23 | parser.add_argument('-host', type=str, default='localhost', help="InfluxDB location host") 24 | parser.add_argument('-port', type=int, default=8086, help="InfluxDB host port") 25 | parser.add_argument('-user', type=str, default='root', help="InfluxDB username") 26 | parser.add_argument('-password', type=str, default='root', help="InfluxDB password") 27 | parser.add_argument('-database', type=str, default='cache', help="InfluxDB database name") 28 | parser.add_argument('-drop', action='store_true', help="Drop the database") 29 | parser.add_argument('-skip_if_older', type=int, default=None, help="Skip symbols, which are in the database, but have no activity for more than N previous days") 30 | parser.add_argument('-interval_len', type=int, default=None, required=True, help="Interval length") 31 | parser.add_argument('-interval_type', type=str, default='s', help="Interval type (seconds, days, etc)") 32 | parser.add_argument('-iqfeed_conn', type=int, default=10, help="Number of historical connections to IQFeed") 33 | parser.add_argument('-delta_back', type=int, default=10, help="Default number of years to look back") 34 | parser.add_argument('-symbols_file', type=str, default=None, help="location to locally saved symbols file (to prevent downloading it every time)") 35 | args = parser.parse_args() 36 | 37 | client = DataFrameClient(host=args.host, port=args.port, username=args.user, password=args.password, database=args.database, pool_size=1) 38 | 39 | logging.getLogger(__name__).info("Updating database with arguments: " + str(args)) 40 | 41 | if args.drop: 42 | client.drop_database(args.database) 43 | 44 | if args.database not in [d['name'] for d in client.get_list_database()]: 45 | client.create_database(args.database) 46 | client.query("ALTER RETENTION POLICY autogen ON cache DURATION INF REPLICATION 1 SHARD DURATION 2600w DEFAULT") 47 | 48 | client.switch_database(args.database) 49 | 50 | with IQFeedHistoryProvider(num_connections=args.iqfeed_conn) as history: 51 | all_symbols = {(s, args.interval_len, args.interval_type) for s in set(iqutil.get_symbols(symbols_file=args.symbols_file).keys())} 52 | update_to_latest(client=client, noncache_provider=noncache_provider(history), new_symbols=all_symbols, 
time_delta_back=relativedelta(years=args.delta_back), 53 | skip_if_older_than=relativedelta(days=args.skip_if_older) if args.skip_if_older is not None else None) 54 | 55 | client.close() 56 | -------------------------------------------------------------------------------- /scripts/update_influxdb_fundamentals_cache.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | 3 | """ 4 | Script that updates the splits/dividends/fundamentals cache 5 | """ 6 | 7 | import argparse 8 | import logging 9 | 10 | from influxdb import InfluxDBClient 11 | 12 | import atpy.data.iqfeed.util as iqutil 13 | from atpy.data.iqfeed.iqfeed_influxdb_cache import update_fundamentals, update_splits_dividends 14 | from atpy.data.iqfeed.iqfeed_level_1_provider import IQFeedLevel1Listener 15 | from atpy.data.iqfeed.iqfeed_level_1_provider import get_fundamentals 16 | 17 | if __name__ == "__main__": 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | parser = argparse.ArgumentParser(description="InfluxDB and IQFeed configuration") 21 | 22 | parser.add_argument('-host', type=str, default='localhost', help="InfluxDB location host") 23 | parser.add_argument('-port', type=int, default=8086, help="InfluxDB host port") 24 | parser.add_argument('-user', type=str, default='root', help="InfluxDB username") 25 | parser.add_argument('-password', type=str, default='root', help="InfluxDB password") 26 | parser.add_argument('-drop', action='store_true', help="Drop the measurements") 27 | parser.add_argument('-database', type=str, default='cache', help="InfluxDB database name") 28 | parser.add_argument('-update_fundamentals', default=True, help="Update Fundamental data") 29 | parser.add_argument('-update_splits_dividends', default=True, help="Update Splits and dividends") 30 | parser.add_argument('-symbols_file', type=str, default=None, help="location to locally saved symbols file (to prevent downloading it every time)") 31 | args = parser.parse_args() 32 | 33 | client = InfluxDBClient(host=args.host, port=args.port, username=args.user, password=args.password, database=args.database, pool_size=1) 34 | 35 | logging.getLogger(__name__).info("Updating database with arguments: " + str(args)) 36 | 37 | if args.drop: 38 | client.drop_measurement('iqfeed_fundamentals') 39 | client.query('DELETE FROM splits_dividends WHERE provider="iqfeed"') 40 | 41 | if args.database not in [d['name'] for d in client.get_list_database()]: 42 | client.create_database(args.database) 43 | client.query("ALTER RETENTION POLICY autogen ON cache DURATION INF REPLICATION 1 SHARD DURATION 2600w DEFAULT") 44 | 45 | client.switch_database(args.database) 46 | 47 | with IQFeedLevel1Listener(fire_ticks=False) as listener: 48 | all_symbols = set(iqutil.get_symbols(symbols_file=args.symbols_file).keys()) 49 | 50 | fundamentals = get_fundamentals(all_symbols) 51 | 52 | if args.update_fundamentals: 53 | update_fundamentals(client=client, fundamentals=fundamentals.values()) 54 | 55 | if args.update_splits_dividends: 56 | update_splits_dividends(client=client, fundamentals=fundamentals.values()) 57 | 58 | client.close() 59 | -------------------------------------------------------------------------------- /scripts/update_postgres_adjustments_cache.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | """ 3 | Script that updates the bars/splits/dividends/fundamentals cache 4 | """ 5 | 6 | import argparse 7 | import logging 8 | import os 9 | 10 | import psycopg2 11 | 12 | 
import atpy.data.iqfeed.util as iqutil 13 | from atpy.data.cache.postgres_cache import insert_df_json, create_json_data 14 | from atpy.data.iqfeed.iqfeed_level_1_provider import get_splits_dividends, IQFeedLevel1Listener 15 | from pyevents.events import SyncListeners 16 | 17 | if __name__ == "__main__": 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | parser = argparse.ArgumentParser(description="PostgreSQL and IQFeed configuration") 21 | 22 | parser.add_argument('-url', type=str, default=os.environ['POSTGRESQL_CACHE'], help="PostgreSQL connection string") 23 | parser.add_argument('-symbols_file', type=str, default=None, help="location to locally saved symbols file (to prevent downloading it every time)") 24 | 25 | args = parser.parse_args() 26 | 27 | con = psycopg2.connect(args.url) 28 | con.autocommit = True 29 | 30 | all_symbols = set(iqutil.get_symbols(symbols_file=args.symbols_file).keys()) 31 | 32 | with IQFeedLevel1Listener(listeners=SyncListeners(), fire_ticks=False) as listener: 33 | adjustments = get_splits_dividends(all_symbols, listener.conn) 34 | 35 | table_name = 'json_data' 36 | cur = con.cursor() 37 | cur.execute("DROP TABLE IF EXISTS {0};".format(table_name)) 38 | cur.execute(create_json_data.format(table_name)) 39 | 40 | insert_df_json(con, table_name, adjustments) 41 | -------------------------------------------------------------------------------- /scripts/update_postgres_cache.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | """ 3 | Script that updates the bars/splits/dividends/fundamentals cache 4 | """ 5 | 6 | import argparse 7 | import logging 8 | 9 | import psycopg2 10 | from dateutil.relativedelta import relativedelta 11 | 12 | import atpy.data.iqfeed.util as iqutil 13 | from atpy.data.cache.postgres_cache import update_to_latest 14 | from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider 15 | from atpy.data.iqfeed.iqfeed_postgres_cache import noncache_provider 16 | 17 | if __name__ == "__main__": 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | parser = argparse.ArgumentParser(description="PostgreSQL and IQFeed configuration") 21 | 22 | parser.add_argument('-url', type=str, default=None, help="PostgreSQL connection string") 23 | parser.add_argument('-drop', action='store_true', help="Drop the table") 24 | parser.add_argument('-table_name', type=str, default=None, required=True, help="PostgreSQL table name") 25 | parser.add_argument('-cluster', action='store_true', help="Cluster the table after the operation") 26 | 27 | parser.add_argument('-interval_len', type=int, default=None, help="Interval length") 28 | parser.add_argument('-interval_type', type=str, default='s', help="Interval type (seconds, days, etc)") 29 | parser.add_argument('-skip_if_older', type=int, default=None, help="Skip symbols, which are in the database, but have no activity for more than N previous days") 30 | parser.add_argument('-delta_back', type=int, default=10, help="Default number of years to look back") 31 | parser.add_argument('-iqfeed_conn', type=int, default=10, help="Number of historical connections to IQFeed") 32 | 33 | parser.add_argument('-symbols_file', type=str, default=None, help="location to locally saved symbols file (to prevent downloading it every time)") 34 | 35 | args = parser.parse_args() 36 | 37 | con = psycopg2.connect(args.url) 38 | con.autocommit = True 39 | 40 | if args.drop: 41 | cur = con.cursor() 42 | cur.execute("DROP TABLE IF EXISTS {0};".format(args.table_name)) 43 | 44 
| if args.interval_len is None or args.interval_type is None: 45 | parser.error('-interval_len and -interval_type are required') 46 | 47 | with IQFeedHistoryProvider(num_connections=args.iqfeed_conn) as history: 48 | all_symbols = set((s, args.interval_len, args.interval_type) for s in set(iqutil.get_symbols(symbols_file=args.symbols_file).keys())) 49 | update_to_latest(url=args.url, bars_table=args.table_name, noncache_provider=noncache_provider(history), symbols=all_symbols, time_delta_back=relativedelta(years=args.delta_back), 50 | skip_if_older_than=relativedelta(days=args.skip_if_older) if args.skip_if_older is not None else None, cluster=args.cluster) 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """A setuptools based setup module. 2 | 3 | See: 4 | https://packaging.python.org/en/latest/distributing.html 5 | https://github.com/pypa/sampleproject 6 | """ 7 | 8 | # Always prefer setuptools over distutils 9 | from setuptools import setup, find_packages 10 | # To use a consistent encoding 11 | from codecs import open 12 | from os import path 13 | 14 | here = path.abspath(path.dirname(__file__)) 15 | 16 | # Get the long description from the README file 17 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 18 | long_description = f.read() 19 | 20 | setup( 21 | name='atpy', 22 | 23 | # Versions should comply with PEP440. For a discussion on single-sourcing 24 | # the version across setup.py and the project code, see 25 | # https://packaging.python.org/en/latest/single_source_version.html 26 | version='0.0.1', 27 | 28 | description='Algo trading configuration', 29 | long_description=long_description, 30 | 31 | # The project's main homepage. 32 | url='https://github.com/ivan-vasilev/atpy', 33 | 34 | # Author details 35 | author='Ivan Vasilev', 36 | author_email='ivanvasilev@gmail.com', 37 | 38 | # Choose your license 39 | license='MIT', 40 | 41 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 42 | classifiers=[ 43 | # How mature is this project? Common values are 44 | # 3 - Alpha 45 | # 4 - Beta 46 | # 5 - Production/Stable 47 | 'Development Status :: 3 - Alpha', 48 | 49 | # Indicate who your project is intended for 50 | 'Intended Audience :: Developers', 51 | 'Topic :: Software Development :: Algo trading', 52 | 53 | # Pick your license as you wish (should match "license" above) 54 | 'License :: OSI Approved :: MIT License', 55 | 56 | # Specify the Python versions you support here. In particular, ensure 57 | # that you indicate whether you support Python 2, Python 3 or both. 58 | 'Programming Language :: Python :: 3.6', 59 | ], 60 | 61 | # What does your project relate to? 62 | keywords='algorithmic trading alpha', 63 | 64 | # You can just specify the packages manually here if your project is 65 | # simple. Or you can use find_packages(). 66 | packages=find_packages(exclude=['contrib', 'docs', 'tests']), 67 | 68 | # Alternatively, if you want to distribute just a my_module.py, uncomment 69 | # this: 70 | # py_modules=["my_module"], 71 | 72 | # List run-time dependencies here. These will be installed by pip when 73 | # your project is installed. 
For an analysis of "install_requires" vs pip's 74 | # requirements files see: 75 | # https://packaging.python.org/en/latest/requirements.html 76 | install_requires=['pyiqfeed', 'pandas', 'numpy', 'pyevents==0.0.1', 'numba'], 77 | 78 | extras_require={ 79 | 'pyevents_util': ['pyevents_util'], 80 | 'influxdb': ['influxdb'], 81 | 'TA-Lib': ['TA-Lib'], 82 | 'quandl': ['quandl'], 83 | 'postgres': ['psycopg2-binary', 'lmdb'], 84 | 'sqlalchemy': ['sqlalchemy'], 85 | }, 86 | 87 | dependency_links=[ 88 | "git+https://github.com/ivan-vasilev/pyevents#egg=pyevents-0.0.1", 89 | "git+https://github.com/ivan-vasilev/pyevents_util#egg=pyevents_util", 90 | "git+https://github.com/akapur/pyiqfeed#egg=pyiqfeed" 91 | ], 92 | 93 | # List additional groups of dependencies here (e.g. development 94 | # dependencies). You can install these using the following syntax, 95 | # for example: 96 | # $ pip install -e .[dev,test] 97 | # extras_require={ 98 | # 'dev': ['check-manifest'], 99 | # 'test': ['coverage'], 100 | # }, 101 | 102 | # If there are data files included in your packages that need to be 103 | # installed, specify them here. If using Python 2.6 or less, then these 104 | # have to be included in MANIFEST.in as well. 105 | # package_data={ 106 | # 'sample': ['package_data.dat'], 107 | # }, 108 | 109 | # Although 'package_data' is the preferred approach, in some case you may 110 | # need to place data files outside of your packages. See: 111 | # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa 112 | # In this case, 'data_file' will be installed into '/my_data' 113 | # data_files=[('my_data', ['data/data_file'])], 114 | 115 | # To provide executable scripts, use entry points in preference to the 116 | # "scripts" keyword. Entry points provide cross-platform support and allow 117 | # pip to create the appropriate form of executable for the target platform. 
118 | # entry_points={ 119 | # 'console_scripts': [ 120 | # 'sample=sample:main', 121 | # ], 122 | # }, 123 | ) 124 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # the inclusion of the tests module is not meant to offer best practices for 2 | # testing in general, but rather to support the `find_packages` example in 3 | # setup.py that excludes installing the "tests" package 4 | -------------------------------------------------------------------------------- /tests/backtesting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/backtesting/__init__.py -------------------------------------------------------------------------------- /tests/backtesting/test_environments.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from atpy.backtesting.environments import * 4 | from atpy.data.iqfeed.iqfeed_postgres_cache import * 5 | from pyevents.events import SyncListeners 6 | 7 | 8 | class TestEnvironments(unittest.TestCase): 9 | 10 | def test_postgre_ohlc(self): 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | listeners = SyncListeners() 14 | 15 | dre = data_replay_events(listeners) 16 | data_event_stream = dre.event_filter() 17 | event_stream_1d, filter_1d = add_postgres_ohlc_1d(dre, bgn_prd=datetime.datetime.now() - relativedelta(months=2)) 18 | event_stream_5m, filter_5m = add_postgres_ohlc_5m(dre, bgn_prd=datetime.datetime.now() - relativedelta(months=2)) 19 | add_current_period(listeners, filter_5m) 20 | add_current_phase(data_event_stream) 21 | add_daily_log(data_event_stream) 22 | add_rolling_mean(event_stream_1d, window=5) 23 | add_gaps(listeners, filter_1d) 24 | 25 | dct = {'bars_5m': 0, 'bars_1d': 0, 'latest_5m': None, 'latest_1d': None, 'phases': set(), 'periods': set(), 'phase_start': False} 26 | 27 | def asserts(e): 28 | if e['type'] == 'data': 29 | self.assertTrue(isinstance(e, dict)) 30 | 31 | if 'bars_5m' in e: 32 | self.assertTrue(isinstance(e['bars_5m'], pd.DataFrame)) 33 | self.assertFalse(e['bars_5m'].empty) 34 | dct['bars_5m'] += 1 35 | 36 | if dct['latest_5m'] is not None: 37 | self.assertGreater(e['bars_5m'].iloc[-1].name[0], dct['latest_5m']) 38 | 39 | dct['latest_5m'] = e['bars_5m'].iloc[-1].name[0] 40 | self.assertTrue('bars_5m_current_period' in e) 41 | self.assertTrue('period_name' in e) 42 | dct['periods'].add(e['period_name']) 43 | 44 | if e['period_start'] is True: 45 | dct['period_start'] = True 46 | 47 | if 'bars_1d' in e: 48 | self.assertTrue(isinstance(e['bars_1d'], pd.DataFrame)) 49 | self.assertFalse(e['bars_1d'].empty) 50 | self.assertTrue('close_rm_5' in e['bars_1d'].columns) 51 | dct['bars_1d'] += 1 52 | 53 | if dct['latest_1d'] is not None: 54 | self.assertGreater(e['bars_1d'].iloc[-1].name[0], dct['latest_1d']) 55 | 56 | dct['latest_1d'] = e['bars_1d'].iloc[-1].name[0] 57 | 58 | self.assertTrue('bars_1d_gaps' in e) 59 | self.assertTrue('current_phase' in e) 60 | 61 | dct['phases'].add(e['current_phase']) 62 | 63 | if e['phase_start'] is True: 64 | dct['phase_start'] = True 65 | 66 | listeners += asserts 67 | dre.start() 68 | 69 | self.assertGreater(dct['bars_5m'], 0) 70 | self.assertGreater(dct['bars_1d'], 0) 71 | self.assertIsNotNone(dct['latest_5m']) 72 | self.assertIsNotNone(dct['latest_1d']) 73 | 
self.assertEqual(dct['periods'], {'trading-hours', 'after-hours'}) 74 | self.assertTrue(dct['period_start']) 75 | 76 | # TODO 77 | def test_postgre_backtest(self): 78 | logging.basicConfig(level=logging.INFO) 79 | 80 | listeners = SyncListeners() 81 | 82 | dre = data_replay_events(listeners) 83 | 84 | event_stream_1m, filter_1m = add_postgres_ohlc_1m(dre, bgn_prd=datetime.datetime.now() - relativedelta(years=10)) 85 | 86 | strategy = add_random_strategy(listeners, 87 | portfolio_manager=None, 88 | bar_event_stream=event_stream_1m) 89 | 90 | me = add_mock_exchange(listeners, 91 | order_requests_stream=strategy.order_requests_stream(), 92 | bar_event_stream=event_stream_1m, 93 | slippage_loss_ratio=0.1, 94 | commission_per_share=0.05) 95 | 96 | pm = add_portfolio_manager(listeners=listeners, 97 | fulfilled_orders_stream=me.fulfilled_orders_stream(), 98 | bar_event_stream=event_stream_1m, 99 | initial_capital=10000000) 100 | 101 | strategy.portfolio_manager = pm 102 | 103 | add_daily_log(dre.event_filter()) 104 | 105 | dct = {'bars_1m': 0, 'latest_1m': None} 106 | 107 | # def asserts(e): 108 | # if e['type'] == 'data': 109 | # self.assertTrue(isinstance(e, dict)) 110 | # 111 | # if 'bars_1m' in e: 112 | # self.assertTrue(isinstance(e['bars_1m'], pd.DataFrame)) 113 | # self.assertFalse(e['bars_1m'].empty) 114 | # dct['bars_1m'] += 1 115 | # 116 | # if dct['latest_1m'] is not None: 117 | # self.assertGreater(e['bars_1m'].iloc[-1].name[0], dct['latest_1m']) 118 | # 119 | # dct['latest_1m'] = e['bars_1m'].iloc[-1].name[0] 120 | # 121 | # listeners += asserts 122 | dre.start() 123 | 124 | self.assertGreater(dct['bars_1m'], 0) 125 | 126 | def test_postgre_ohlc_quandl_sf0(self): 127 | logging.basicConfig(level=logging.INFO) 128 | 129 | listeners = SyncListeners() 130 | 131 | dre = data_replay_events(listeners) 132 | data_event_stream = dre.event_filter() 133 | 134 | event_stream_1d, filter_1d = add_postgres_ohlc_1d(dre, bgn_prd=datetime.datetime.now() - relativedelta(months=2)) 135 | add_daily_log(data_event_stream) 136 | add_current_period(listeners, filter_1d) 137 | add_quandl_sf(dre, bgn_prd=datetime.datetime.now() - relativedelta(years=2)) 138 | 139 | dct = {'bars_1d': 0, 'quandl_sf0': 0, 'latest_1d': None, 'latest_quandl_sf0': None} 140 | 141 | def asserts(e): 142 | if e['type'] == 'data': 143 | self.assertTrue(isinstance(e, dict)) 144 | 145 | if 'bars_1d' in e: 146 | self.assertTrue(isinstance(e['bars_1d'], pd.DataFrame)) 147 | self.assertFalse(e['bars_1d'].empty) 148 | dct['bars_1d'] += 1 149 | 150 | if dct['latest_1d'] is not None: 151 | self.assertGreater(e['bars_1d'].iloc[-1].name[0], dct['latest_1d']) 152 | 153 | dct['latest_1d'] = e['bars_1d'].iloc[-1].name[0] 154 | 155 | if 'quandl_sf0' in e: 156 | self.assertTrue(isinstance(e['quandl_sf0'], pd.DataFrame)) 157 | self.assertFalse(e['quandl_sf0'].empty) 158 | dct['quandl_sf0'] += 1 159 | 160 | if dct['latest_quandl_sf0'] is not None: 161 | self.assertGreater(e['quandl_sf0'].iloc[-1].name[0], dct['latest_quandl_sf0']) 162 | 163 | dct['latest_quandl_sf0'] = e['quandl_sf0'].iloc[-1].name[0] 164 | 165 | listeners += asserts 166 | dre.start() 167 | 168 | self.assertGreater(dct['bars_1d'], 0) 169 | self.assertGreater(dct['quandl_sf0'], 0) 170 | self.assertIsNotNone(dct['latest_1d']) 171 | self.assertIsNotNone(dct['latest_quandl_sf0']) 172 | 173 | 174 | if __name__ == '__main__': 175 | unittest.main() 176 | -------------------------------------------------------------------------------- /tests/data/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/test_splits_dividends.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import random 4 | import threading 5 | import unittest 6 | 7 | import pandas as pd 8 | 9 | import pyiqfeed as iq 10 | from atpy.data.iqfeed.filters import DefaultFilterProvider 11 | from atpy.data.iqfeed.iqfeed_history_provider import BarsInPeriodFilter, IQFeedHistoryEvents, IQFeedHistoryProvider, BarsFilter 12 | from atpy.data.iqfeed.iqfeed_level_1_provider import get_splits_dividends 13 | from atpy.data.splits_dividends import exclude_splits 14 | from pyevents.events import AsyncListeners 15 | 16 | 17 | class TestSplitsDividends(unittest.TestCase): 18 | """ 19 | Test splits/dividends functionality 20 | """ 21 | 22 | def test_bar_split_adjust_1(self): 23 | filter_provider = DefaultFilterProvider() 24 | filter_provider += BarsInPeriodFilter(ticker="PLUS", bgn_prd=datetime.datetime(2017, 3, 31), end_prd=datetime.datetime(2017, 4, 5), interval_len=3600, ascend=True, interval_type='s', max_ticks=100) 25 | 26 | listeners = AsyncListeners() 27 | 28 | with IQFeedHistoryEvents(listeners=listeners, fire_batches=True, filter_provider=filter_provider, timestamp_first=True, num_connections=2) as listener, listener.batch_provider() as provider: 29 | e1 = threading.Event() 30 | 31 | def process_bar(event): 32 | if event['type'] == 'bar_batch': 33 | d = event['data'] 34 | try: 35 | self.assertLess(d['open'].max(), 68) 36 | self.assertGreater(d['open'].min(), 65) 37 | finally: 38 | e1.set() 39 | 40 | listeners += process_bar 41 | 42 | listener.start() 43 | 44 | e1.wait() 45 | 46 | for i, d in enumerate(provider): 47 | self.assertLess(d['open'].max(), 68) 48 | self.assertGreater(d['open'].min(), 65) 49 | 50 | if i == 1: 51 | break 52 | 53 | def test_bar_split_adjust_2(self): 54 | filter_provider = DefaultFilterProvider() 55 | filter_provider += BarsInPeriodFilter(ticker=["PLUS", "AAPL"], bgn_prd=datetime.datetime(2017, 3, 31), end_prd=datetime.datetime(2017, 4, 5), interval_len=3600, ascend=True, interval_type='s') 56 | 57 | listeners = AsyncListeners() 58 | 59 | with IQFeedHistoryEvents(listeners=listeners, fire_batches=True, filter_provider=filter_provider, sync_timestamps=False, timestamp_first=True, num_connections=2) as listener, listener.batch_provider() as provider: 60 | listener.start() 61 | 62 | for i, d in enumerate(provider): 63 | idx = pd.IndexSlice 64 | 65 | self.assertLess(d.loc[idx[:, 'PLUS'], 'open'].max(), 68) 66 | self.assertGreater(d.loc[idx[:, 'PLUS'], 'open'].min(), 65) 67 | self.assertGreater(d.loc[idx[:, 'AAPL'], 'open'].min(), 142) 68 | 69 | if i == 1: 70 | break 71 | 72 | def test_exclude_splits(self): 73 | with IQFeedHistoryProvider() as provider: 74 | # single index 75 | f = BarsInPeriodFilter(ticker="PLUS", bgn_prd=datetime.datetime(2017, 3, 31), end_prd=datetime.datetime(2017, 4, 5), interval_len=3600, ascend=True, interval_type='s', max_ticks=100) 76 | 77 | data = provider.request_data(f, sync_timestamps=False) 78 | data['include'] = True 79 | data = data['include'].copy() 80 | 81 | conn = iq.QuoteConn() 82 | conn.connect() 83 | try: 84 | sd = get_splits_dividends(f.ticker, conn=conn) 85 | finally: 86 | conn.disconnect() 87 | 88 | result = 
exclude_splits(data, sd['value'].xs('split', level='type'), 10) 89 | 90 | self.assertTrue(result[~result].size == 10) 91 | 92 | # multiindex 93 | f = BarsInPeriodFilter(ticker=["PLUS", "IBM"], bgn_prd=datetime.datetime(2017, 3, 31), end_prd=datetime.datetime(2017, 4, 5), interval_len=3600, ascend=True, interval_type='s', max_ticks=100) 94 | 95 | data = provider.request_data(f, sync_timestamps=False) 96 | data['include'] = True 97 | data = data['include'].copy() 98 | 99 | conn = iq.QuoteConn() 100 | conn.connect() 101 | try: 102 | sd = get_splits_dividends(f.ticker, conn=conn) 103 | finally: 104 | conn.disconnect() 105 | 106 | result = exclude_splits(data, sd['value'].xs('split', level='type'), 10) 107 | 108 | self.assertTrue(result[~result].size == 10) 109 | 110 | def test_exclude_splits_performance(self): 111 | logging.basicConfig(level=logging.DEBUG) 112 | 113 | batch_len = 15000 114 | batch_width = 4000 115 | 116 | now = datetime.datetime.now() 117 | with IQFeedHistoryProvider() as provider: 118 | df1 = provider.request_data(BarsFilter(ticker="PLUS", interval_len=3600, interval_type='s', max_bars=batch_len), sync_timestamps=False) 119 | 120 | df = {'PLUS': df1} 121 | for i in range(batch_width): 122 | df['PLUS_' + str(i)] = df1.sample(random.randint(int(len(df1) / 3), len(df1) - 1)) 123 | 124 | df = pd.concat(df, sort=True) 125 | df.index.set_names(['symbol', 'timestamp'], inplace=True) 126 | df['include'] = True 127 | data = df['include'] 128 | 129 | conn = iq.QuoteConn() 130 | conn.connect() 131 | try: 132 | sd = get_splits_dividends("PLUS", conn=conn).xs('split', level='type') 133 | finally: 134 | conn.disconnect() 135 | 136 | splits = list() 137 | for l in df.index.levels[0]: 138 | ind_cp = sd.index.set_levels([l], level=1) 139 | for i, v in enumerate(sd): 140 | ind_cp.values[i] = (sd.index.values[i][0], l, sd.index.values[i][2]) 141 | 142 | cp = pd.DataFrame(data=sd.values, index=ind_cp) 143 | 144 | splits.append(cp) 145 | 146 | splits = pd.concat(splits, sort=True) 147 | 148 | logging.getLogger(__name__).debug('Random data generated in ' + str(datetime.datetime.now() - now) + ' with shapes ' + str(df.shape)) 149 | 150 | now = datetime.datetime.now() 151 | 152 | result = exclude_splits(data, splits, 10) 153 | 154 | logging.getLogger(__name__).debug('Task done in ' + str(datetime.datetime.now() - now) + ' with shapes ' + str(result.shape)) 155 | 156 | self.assertTrue(result[~result].size > 10) 157 | self.assertTrue(result[result].size > 0) 158 | -------------------------------------------------------------------------------- /tests/data/test_talib.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pandas as pd 4 | import talib 5 | from talib import abstract 6 | from pandas.util.testing import assert_index_equal 7 | 8 | from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider, BarsFilter 9 | 10 | 11 | class TestTALib(unittest.TestCase): 12 | """Demonstrate how to use TA-lib""" 13 | 14 | def test_ta_lib_function_api(self): 15 | """Test the functional interface of TA-Lib""" 16 | 17 | with IQFeedHistoryProvider() as provider: 18 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=300, interval_type='s', max_bars=1000), sync_timestamps=False) 19 | close = df['close'] 20 | 21 | output = talib.SMA(close) 22 | self.assertTrue(isinstance(output, pd.Series)) 23 | self.assertFalse(output.empty) 24 | self.assertTrue(pd.isna(output[0])) 25 | self.assertFalse(pd.isna(output[-1])) 26 | 
self.assertEqual(close.shape, output.shape) 27 | self.assertEqual(close.dtype, output.dtype) 28 | assert_index_equal(close.index, output.index) 29 | 30 | bbands = talib.BBANDS(close, matype=talib.MA_Type.T3) 31 | for bband in bbands: 32 | self.assertTrue(isinstance(bband, pd.Series)) 33 | self.assertFalse(bband.empty) 34 | self.assertTrue(pd.isna(bband[0])) 35 | self.assertFalse(pd.isna(bband[-1])) 36 | self.assertEqual(close.shape, bband.shape) 37 | self.assertEqual(close.dtype, bband.dtype) 38 | assert_index_equal(close.index, bband.index) 39 | 40 | def test_ta_lib_abstract_api(self): 41 | """Test the abstract API of TA-Lib""" 42 | 43 | with IQFeedHistoryProvider() as provider: 44 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=300, interval_type='s', max_bars=1000), sync_timestamps=False) 45 | close = df['close'] 46 | 47 | output = abstract.SMA(df) 48 | self.assertTrue(isinstance(output, pd.Series)) 49 | self.assertFalse(output.empty) 50 | self.assertTrue(pd.isna(output[0])) 51 | self.assertFalse(pd.isna(output[-1])) 52 | self.assertEqual(close.shape, output.shape) 53 | self.assertEqual(close.dtype, output.dtype) 54 | assert_index_equal(close.index, output.index) 55 | 56 | bbands = abstract.BBANDS(df, matype=talib.MA_Type.T3) 57 | self.assertTrue(isinstance(bbands, pd.DataFrame)) 58 | assert_index_equal(close.index, bbands.index) 59 | 60 | for _, bband in bbands.iteritems(): 61 | self.assertTrue(isinstance(bband, pd.Series)) 62 | self.assertFalse(bband.empty) 63 | self.assertEqual(close.shape, bband.shape) 64 | self.assertEqual(close.dtype, bband.dtype) 65 | self.assertTrue(pd.isna(bband[0])) 66 | self.assertFalse(pd.isna(bband[-1])) 67 | 68 | stoch = abstract.STOCH(df, 5, 3, 0, 3, 0) 69 | self.assertTrue(isinstance(stoch, pd.DataFrame)) 70 | assert_index_equal(close.index, stoch.index) 71 | 72 | for _, s in stoch.iteritems(): 73 | self.assertTrue(isinstance(s, pd.Series)) 74 | self.assertFalse(s.empty) 75 | self.assertEqual(close.shape, s.shape) 76 | self.assertEqual(close.dtype, s.dtype) 77 | self.assertTrue(pd.isna(s[0])) 78 | self.assertFalse(pd.isna(s[-1])) 79 | -------------------------------------------------------------------------------- /tests/data/test_ts_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import random 4 | import unittest 5 | 6 | import pandas as pd 7 | 8 | import atpy.data.tradingcalendar as tcal 9 | from atpy.backtesting.data_replay import DataReplay 10 | from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider, BarsFilter 11 | from atpy.data.ts_util import current_period, set_periods, current_day 12 | 13 | 14 | class TestTSUtils(unittest.TestCase): 15 | 16 | def test_set_periods(self): 17 | batch_len = 1000 18 | 19 | with IQFeedHistoryProvider() as provider: 20 | # One symbol, all periods 21 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=300, interval_type='s', max_bars=batch_len), sync_timestamps=False) 22 | 23 | set_periods(df) 24 | self.assertTrue('period' in df.columns) 25 | self.assertEqual(len(pd.unique(df['period'].dropna())), 2) 26 | self.assertEqual(len(df['period'].dropna()), len(df['period'])) 27 | 28 | # Multiple symbols, all periods 29 | df = provider.request_data(BarsFilter(ticker=["AAPL", "IBM"], interval_len=300, interval_type='s', max_bars=batch_len), sync_timestamps=False).swaplevel(0, 1).sort_index() 30 | 31 | set_periods(df) 32 | self.assertTrue('period' in df.columns) 33 | 
self.assertEqual(len(pd.unique(df['period'].dropna())), 2) 34 | self.assertEqual(len(df['period'].dropna()), len(df['period'])) 35 | 36 | # Multiple symbols, N periods 37 | df = provider.request_data(BarsFilter(ticker=["AAPL", "IBM"], interval_len=300, interval_type='s', max_bars=batch_len), sync_timestamps=False).swaplevel(0, 1).sort_index() 38 | lc = tcal.open_and_closes.loc[min(df['timestamp']): max(df['timestamp'])].iloc[::-1] 39 | xs = pd.IndexSlice 40 | df = df.loc[xs[:lc.iloc[0]['market_close'], :]].iloc[:-3] 41 | set_periods(df) 42 | self.assertTrue('period' in df.columns) 43 | self.assertEqual(len(pd.unique(df['period'].dropna())), 2) 44 | self.assertEqual(len(df['period'].dropna()), len(df['period'])) 45 | 46 | def test_set_periods_performance(self): 47 | logging.basicConfig(level=logging.DEBUG) 48 | 49 | batch_len = 10000 50 | batch_width = 1000 51 | 52 | with IQFeedHistoryProvider() as provider: 53 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=60, interval_type='s', max_bars=batch_len), sync_timestamps=False) 54 | 55 | dfs = {'AAPL': df} 56 | for i in range(batch_width): 57 | dfs['AAPL_' + str(i)] = df.sample(random.randint(int(len(df) / 3), len(df) - 1)) 58 | 59 | dfs = pd.concat(dfs).swaplevel(0, 1).sort_index() 60 | 61 | now = datetime.datetime.now() 62 | set_periods(dfs) 63 | logging.getLogger(__name__).debug('Time elapsed ' + str(datetime.datetime.now() - now) + ' for ' + str(batch_len) + ' steps; ' + str(batch_width) + ' width') 64 | 65 | def test_current_period(self): 66 | batch_len = 1000 67 | 68 | with IQFeedHistoryProvider() as provider: 69 | # One symbol, all periods 70 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=300, interval_type='s', max_bars=batch_len), sync_timestamps=False) 71 | 72 | slc, period = current_period(df) 73 | self.assertTrue(period in ('trading-hours', 'after-hours')) 74 | self.assertGreater(len(df), len(slc)) 75 | 76 | # Multiple symbols, all periods 77 | df = provider.request_data(BarsFilter(ticker=["AAPL", "IBM"], interval_len=300, interval_type='s', max_bars=batch_len), sync_timestamps=False).swaplevel(0, 1).sort_index() 78 | 79 | slc, period = current_period(df) 80 | self.assertTrue(period in ('trading-hours', 'after-hours')) 81 | self.assertGreater(len(df), len(slc)) 82 | 83 | df = provider.request_data(BarsFilter(ticker=["AAPL", "IBM"], interval_len=300, interval_type='s', max_bars=batch_len), sync_timestamps=False).swaplevel(0, 1).sort_index() 84 | lc = tcal.open_and_closes.loc[min(df['timestamp']): max(df['timestamp'])].iloc[::-1] 85 | xs = pd.IndexSlice 86 | df = df.loc[xs[:lc.iloc[0]['market_close'], :]].iloc[:-3] 87 | 88 | slc, period = current_period(df) 89 | self.assertTrue(period in ('trading-hours', 'after-hours')) 90 | self.assertGreater(len(df), len(slc)) 91 | 92 | def test_current_period_2(self): 93 | logging.basicConfig(level=logging.DEBUG) 94 | 95 | batch_len = 10000 96 | batch_width = 2000 97 | 98 | with IQFeedHistoryProvider() as provider: 99 | l1, l2 = list(), list() 100 | 101 | dr = DataReplay().add_source(l1, 'e1', historical_depth=1000) 102 | 103 | now = datetime.datetime.now() 104 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=60, interval_type='s', max_bars=batch_len), sync_timestamps=False) 105 | 106 | dfs1 = {'AAPL': df} 107 | for i in range(batch_width): 108 | dfs1['AAPL_' + str(i)] = df.sample(random.randint(int(len(df) / 3), len(df) - 1)) 109 | 110 | df = pd.concat(dfs1).swaplevel(0, 1) 111 | df.reset_index(level='symbol', inplace=True) 112 | 
df.sort_index(inplace=True) 113 | df.set_index('level_1', drop=False, append=True, inplace=True) 114 | l1.append(df) 115 | 116 | logging.getLogger(__name__).debug('Random data generated in ' + str(datetime.datetime.now() - now) + ' with shapes ' + str(df.shape)) 117 | 118 | now = datetime.datetime.now() 119 | 120 | for i, r in enumerate(dr): 121 | if i % 1000 == 0 and i > 0: 122 | new_now = datetime.datetime.now() 123 | elapsed = new_now - now 124 | logging.getLogger(__name__).debug('Time elapsed ' + str(elapsed) + ' for ' + str(i) + ' iterations; ' + str(elapsed / 1000) + ' per iteration') 125 | self.assertGreater(10000, (elapsed / 1000).microseconds) 126 | now = new_now 127 | 128 | for e in r: 129 | period, phase = current_period(r[e]) 130 | self.assertTrue(not period.empty) 131 | self.assertTrue(phase in ('trading-hours', 'after-hours')) 132 | 133 | elapsed = datetime.datetime.now() - now 134 | logging.getLogger(__name__).debug('Time elapsed ' + str(elapsed) + ' for ' + str(i + 1) + ' iterations; ' + str(elapsed / (i % 1000)) + ' per iteration') 135 | 136 | def test_current_day(self): 137 | logging.basicConfig(level=logging.DEBUG) 138 | 139 | batch_len = 10000 140 | batch_width = 5000 141 | 142 | with IQFeedHistoryProvider() as provider: 143 | l1, l2 = list(), list() 144 | 145 | dr = DataReplay().add_source(l1, 'e1', historical_depth=100) 146 | 147 | now = datetime.datetime.now() 148 | df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=3600, interval_type='s', max_bars=batch_len), sync_timestamps=False) 149 | 150 | dfs1 = {'AAPL': df} 151 | for i in range(batch_width): 152 | dfs1['AAPL_' + str(i)] = df.sample(random.randint(int(len(df) / 3), len(df) - 1)) 153 | 154 | df = pd.concat(dfs1).swaplevel(0, 1) 155 | df.reset_index(level='symbol', inplace=True) 156 | df.sort_index(inplace=True) 157 | df.set_index('level_1', drop=False, append=True, inplace=True) 158 | l1.append(df) 159 | 160 | logging.getLogger(__name__).debug('Random data generated in ' + str(datetime.datetime.now() - now) + ' with shapes ' + str(df.shape)) 161 | 162 | now = datetime.datetime.now() 163 | 164 | for i, r in enumerate(dr): 165 | if i % 1000 == 0 and i > 0: 166 | new_now = datetime.datetime.now() 167 | elapsed = new_now - now 168 | logging.getLogger(__name__).debug('Time elapsed ' + str(elapsed) + ' for ' + str(i) + ' iterations; ' + str(elapsed / 1000) + ' per iteration') 169 | self.assertGreater(10000, (elapsed / 1000).microseconds) 170 | now = new_now 171 | 172 | for e in r: 173 | current_day(r[e], 'US/Eastern') 174 | period = current_day(r[e]) 175 | self.assertTrue(not period.empty) 176 | self.assertEqual(period.iloc[0].name[0].date(), period.iloc[1].name[0].date()) 177 | 178 | elapsed = datetime.datetime.now() - now 179 | logging.getLogger(__name__).debug('Time elapsed ' + str(elapsed) + ' for ' + str(i + 1) + ' iterations; ' + str(elapsed / (i % 1000)) + ' per iteration') 180 | -------------------------------------------------------------------------------- /tests/ibapi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/ibapi/__init__.py -------------------------------------------------------------------------------- /tests/ibapi/test_ibapi.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import unittest 3 | 4 | import pandas as pd 5 | 6 | from atpy.ibapi.ib_events import IBEvents 7 | from 
atpy.portfolio.order import * 8 | from pyevents.events import AsyncListeners 9 | 10 | 11 | class TestIBApi(unittest.TestCase): 12 | """ 13 | Test IB API Orders 14 | """ 15 | 16 | def test_1(self): 17 | e_orders = {'GOOG': threading.Event(), 'AAPL': threading.Event()} 18 | e_cancel = threading.Event() 19 | e_positions = threading.Event() 20 | 21 | listeners = AsyncListeners() 22 | 23 | class CustomIBEvents(IBEvents): 24 | def cancel_all_orders(self): 25 | self.reqOpenOrders() 26 | 27 | def openOrder(self, orderId, contract, order, orderState): 28 | super().openOrder(orderId, contract, order, orderState) 29 | if orderState.status == 'PreSubmitted': 30 | self.cancelOrder(orderId) 31 | e_orders[contract.symbol].set() 32 | 33 | def openOrderEnd(self): 34 | super().openOrderEnd() 35 | e_cancel.set() 36 | 37 | ibe = CustomIBEvents(listeners=listeners, ipaddress="127.0.0.1", portid=4002, clientid=0) 38 | 39 | with ibe: 40 | listeners += lambda x: e_orders['GOOG'].set() if isinstance(x['data'], BaseOrder) and x['type'] == 'order_fulfilled' and x['data'].symbol == 'GOOG' else None 41 | listeners({'type': 'order_request', 'data': MarketOrder(Type.BUY, 'GOOG', 1)}) 42 | 43 | listeners += lambda x: e_orders['AAPL'].set() if isinstance(x['data'], BaseOrder) and x['type'] == 'order_fulfilled' and x['data'].symbol == 'AAPL' else None 44 | listeners({'type': 'order_request', 'data': MarketOrder(Type.BUY, 'AAPL', 1)}) 45 | 46 | listeners += lambda x: e_positions.set() if isinstance(x['data'], pd.DataFrame) and x['type'] == 'ibapi_positions' else None 47 | listeners({'type': 'positions_request', 'data': None}) 48 | 49 | for e in e_orders.values(): 50 | e.wait() 51 | 52 | e_positions.wait() 53 | 54 | ibe.cancel_all_orders() 55 | 56 | e_cancel.wait() 57 | -------------------------------------------------------------------------------- /tests/intrinio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/intrinio/__init__.py -------------------------------------------------------------------------------- /tests/intrinio/test_api.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import unittest 3 | 4 | import pandas as pd 5 | from dateutil.relativedelta import relativedelta 6 | from pandas.util.testing import assert_frame_equal 7 | 8 | from atpy.data.intrinio.api import IntrinioEvents 9 | from atpy.data.intrinio.influxdb_cache import InfluxDBCache, ClientFactory 10 | from pyevents.events import SyncListeners 11 | 12 | 13 | class TestIntrinioAPI(unittest.TestCase): 14 | 15 | def test_1(self): 16 | listeners = SyncListeners() 17 | IntrinioEvents(listeners) 18 | 19 | results = list() 20 | 21 | def listener(event): 22 | if event['type'] == 'intrinio_request_result': 23 | results.append(event['data']) 24 | 25 | listeners += listener 26 | 27 | listeners({'type': 'intrinio_request', 'endpoint': 'companies', 'dataframe': True, 'parameters': {'query': 'Computer'}}) 28 | 29 | data = results[0] 30 | 31 | self.assertTrue(isinstance(data, pd.DataFrame)) 32 | self.assertGreater(len(data), 0) 33 | 34 | def test_2(self): 35 | listeners = SyncListeners() 36 | IntrinioEvents(listeners) 37 | 38 | results = list() 39 | 40 | def listener(event): 41 | if event['type'] == 'intrinio_historical_data_result': 42 | results.append(event['data']) 43 | 44 | listeners += listener 45 | 46 | listeners({'type': 'intrinio_historical_data', 47 | 'data': 
[{'endpoint': 'historical_data', 'identifier': 'GOOG', 'item': 'totalrevenue'}, {'endpoint': 'historical_data', 'identifier': 'YHOO', 'item': 'totalrevenue'}], 48 | 'threads': 1, 49 | 'async': False}) 50 | 51 | data = results[0] 52 | 53 | self.assertTrue(isinstance(data, pd.DataFrame)) 54 | self.assertGreater(len(data), 0) 55 | self.assertTrue(isinstance(data.index, pd.MultiIndex)) 56 | 57 | def test_3(self): 58 | listeners = SyncListeners() 59 | IntrinioEvents(listeners) 60 | 61 | client_factory = ClientFactory(host='localhost', port=8086, username='root', password='root', database='test_cache') 62 | client = client_factory.new_client() 63 | 64 | try: 65 | client.create_database('test_cache') 66 | client.switch_database('test_cache') 67 | 68 | with InfluxDBCache(client_factory=client_factory, listeners=listeners, time_delta_back=relativedelta(years=20)) as cache: 69 | cache.update_to_latest({('GOOG', 'operatingrevenue'), ('FB', 'operatingrevenue'), ('YHOO', 'operatingrevenue')}) 70 | now = datetime.datetime.now() 71 | cached = cache.request_data(symbols={'GOOG', 'MSFT', 'YHOO', 'FB'}, tags={'operatingrevenue'}, start_date=datetime.date(year=now.year - 4, month=now.month, day=now.day), 72 | end_date=datetime.date(year=now.year - 2, month=now.month, day=now.day)) 73 | 74 | self.assertIsNotNone(cached) 75 | self.assertGreater(len(cached), 0) 76 | self.assertGreaterEqual(now.year - 2, cached.index.levels[1].max().year) 77 | self.assertGreaterEqual(cached.index.levels[1].min().year, now.year - 4) 78 | 79 | cached = cache.request_data(symbols={'GOOG', 'FB'}, tags={'operatingrevenue'}) 80 | 81 | listeners = SyncListeners() 82 | IntrinioEvents(listeners) 83 | 84 | non_cached = list() 85 | 86 | def listener(event): 87 | if event['type'] == 'intrinio_historical_data_result': 88 | non_cached.append(event['data']) 89 | 90 | listeners += listener 91 | 92 | listeners({'type': 'intrinio_historical_data', 93 | 'data': [{'endpoint': 'historical_data', 'identifier': 'GOOG', 'item': 'operatingrevenue', 'sort_order': 'asc'}, 94 | {'endpoint': 'historical_data', 'identifier': 'FB', 'item': 'operatingrevenue', 'sort_order': 'asc'}], 95 | 'async': False}) 96 | 97 | non_cached = non_cached[0] 98 | 99 | assert_frame_equal(cached, non_cached) 100 | finally: 101 | client.drop_database('test_cache') 102 | client.close() 103 | 104 | 105 | if __name__ == '__main__': 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /tests/iqfeed/__init__.py: -------------------------------------------------------------------------------- 1 | # the inclusion of the tests module is not meant to offer best practices for 2 | # testing in general, but rather to support the `find_packages` example in 3 | # setup.py that excludes installing the "tests" package 4 | -------------------------------------------------------------------------------- /tests/iqfeed/test_bar_data_provider.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import unittest 3 | 4 | from pandas.util.testing import assert_frame_equal 5 | 6 | from atpy.data.iqfeed.iqfeed_bar_data_provider import * 7 | from atpy.data.util import get_nasdaq_listed_companies, resample_bars 8 | from pyevents.events import AsyncListeners, SyncListeners 9 | 10 | 11 | class TestIQFeedBarData(unittest.TestCase): 12 | """ 13 | IQFeed bar data test, which checks whether the class works in basic terms 14 | """ 15 | 16 | def test_provider(self): 17 | listeners = AsyncListeners() 18 | 
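        # The listener exposes two event streams: all_full_bars_event_stream() delivers completed
        # ("full") bars as a (DataFrame, symbol) pair with mkt_snapshot_depth rows and 8 columns,
        # while bar_updates_event_stream() delivers intra-bar updates; watching symbols is
        # triggered by posting a 'watch_bars' event on the listeners bus.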
19 | with IQFeedBarDataListener(listeners=listeners, mkt_snapshot_depth=10, interval_len=60) as listener: 20 | # test bars 21 | e1 = {'GOOG': threading.Event(), 'IBM': threading.Event()} 22 | 23 | def bar_listener(df, symbol): 24 | self.assertTrue(symbol in ['IBM', 'GOOG']) 25 | self.assertEqual(len(df), listener.mkt_snapshot_depth) 26 | self.assertEqual(df.shape[1], 8), 27 | e1[symbol].set() 28 | 29 | full_bars = listener.all_full_bars_event_stream() 30 | full_bars += bar_listener 31 | 32 | # test market snapshot 33 | e3 = threading.Event() 34 | 35 | bar_updates = listener.bar_updates_event_stream() 36 | bar_updates += lambda data, symbol: [self.assertEqual(symbol.shape[1], 8), e3.set()] 37 | 38 | listeners({'type': 'watch_bars', 'data': {'symbol': ['GOOG', 'IBM'], 'update': 1}}) 39 | 40 | for e in e1.values(): 41 | e.wait() 42 | 43 | e3.wait() 44 | 45 | def test_resample(self): 46 | listeners = SyncListeners() 47 | 48 | with IQFeedBarDataListener(listeners=listeners, mkt_snapshot_depth=100, interval_len=60) as listener: 49 | # test bars 50 | e1 = {'GOOG': threading.Event(), 'IBM': threading.Event()} 51 | 52 | def bar_listener(df, symbol): 53 | resampled_df = resample_bars(df, '5min') 54 | self.assertLess(len(resampled_df), len(df)) 55 | self.assertEqual(df['volume'].sum(), resampled_df['volume'].sum()) 56 | e1[symbol].set() 57 | 58 | full_bars = listener.all_full_bars_event_stream() 59 | full_bars += bar_listener 60 | 61 | listeners({'type': 'watch_bars', 'data': {'symbol': ['GOOG', 'IBM'], 'update': 1}}) 62 | 63 | for e in e1.values(): 64 | e.wait() 65 | 66 | def test_listener(self): 67 | listeners = AsyncListeners() 68 | 69 | with IQFeedBarDataListener(listeners=listeners, interval_len=300, interval_type='s', mkt_snapshot_depth=10) as listener: 70 | e1 = threading.Event() 71 | full_bars_filter = listener.all_full_bars_event_stream() 72 | full_bars_filter += lambda data: [self.assertEqual(data.index[0][1], 'SPY'), e1.set()] 73 | 74 | e2 = threading.Event() 75 | updates_filter = listener.all_full_bars_event_stream() 76 | updates_filter += lambda data: [self.assertEqual(data.index[0][1], 'SPY'), e2.set()] 77 | 78 | listener.watch_bars(symbol='SPY') 79 | 80 | e1.wait() 81 | e2.wait() 82 | 83 | def test_correctness_small(self): 84 | self._test_correctness('IBM') 85 | 86 | def test_nasdaq_correctness_large(self): 87 | nasdaq = get_nasdaq_listed_companies() 88 | nasdaq = nasdaq.loc[nasdaq['Market Category'] == 'Q'] 89 | nasdaq = nasdaq.sample(400) 90 | 91 | self._test_correctness(nasdaq['Symbol'].to_list()) 92 | 93 | def _test_correctness(self, symbols): 94 | logging.basicConfig(level=logging.DEBUG) 95 | 96 | listeners = SyncListeners() 97 | depth = 5 98 | with IQFeedBarDataListener(listeners=listeners, mkt_snapshot_depth=depth, interval_len=60, interval_type='s', adjust_history=False, update_interval=1) as listener: 99 | dfs = dict() 100 | 101 | te = threading.Event() 102 | 103 | def full_bar_listener(df, symbol): 104 | self.assertEqual(df.shape[0], depth) 105 | dfs[symbol] = df.copy(deep=True) 106 | 107 | full_bars = listener.all_full_bars_event_stream() 108 | full_bars += full_bar_listener 109 | 110 | conditions = {'ind_equal': False, 'ind_not_equal': False} 111 | 112 | def bar_update_listener(df, symbol): 113 | self.assertEqual(df.shape[0], depth) 114 | old_df = dfs[symbol] 115 | 116 | if old_df.index.equals(df.index): 117 | assert_frame_equal(old_df.iloc[:-1], df.iloc[:-1], check_index_type=False) 118 | conditions['ind_equal'] = True 119 | else: 120 | assert_frame_equal(old_df.iloc[1:], 
df.iloc[:-1], check_index_type=False) 121 | conditions['ind_not_equal'] = True 122 | 123 | try: 124 | assert_frame_equal(old_df.iloc[-1:], df.iloc[-1:], check_index_type=False) 125 | except AssertionError: 126 | pass 127 | else: 128 | raise AssertionError 129 | 130 | if conditions['ind_equal'] is True and conditions['ind_not_equal'] is True: 131 | te.set() 132 | 133 | bar_updates = listener.bar_updates_event_stream() 134 | bar_updates += bar_update_listener 135 | 136 | listener.watch_bars(symbols) 137 | 138 | te.wait() 139 | 140 | @unittest.skip('Run manually') 141 | def test_nasdaq_performance(self): 142 | listeners = AsyncListeners() 143 | import time 144 | nasdaq = get_nasdaq_listed_companies() 145 | nasdaq = nasdaq.loc[nasdaq['Market Category'] == 'Q'] 146 | nasdaq = nasdaq.sample(480) 147 | with IQFeedBarDataListener(listeners=listeners, mkt_snapshot_depth=200, interval_len=1, interval_type='s', adjust_history=False) as listener: 148 | listener.watch_bars(nasdaq['Symbol'].to_list()) 149 | time.sleep(1000) 150 | 151 | 152 | if __name__ == '__main__': 153 | unittest.main() 154 | -------------------------------------------------------------------------------- /tests/iqfeed/test_iqfeed_influxdb_cache_requests.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import unittest 3 | 4 | from pandas.util.testing import assert_frame_equal 5 | 6 | from atpy.data.cache.influxdb_cache_requests import * 7 | from atpy.data.iqfeed.iqfeed_influxdb_cache import * 8 | from atpy.data.iqfeed.iqfeed_level_1_provider import get_fundamentals 9 | from pyevents.events import AsyncListeners 10 | from atpy.data.splits_dividends import adjust_df 11 | import pyiqfeed as iq 12 | 13 | 14 | class TestInfluxDBCacheRequests(unittest.TestCase): 15 | """ 16 | Test InfluxDBCache 17 | """ 18 | 19 | def setUp(self): 20 | self._client = DataFrameClient(host='localhost', port=8086, username='root', password='root', database='test_cache') 21 | 22 | self._client.drop_database('test_cache') 23 | self._client.create_database('test_cache') 24 | self._client.switch_database('test_cache') 25 | 26 | def tearDown(self): 27 | self._client.drop_database('test_cache') 28 | self._client.close() 29 | 30 | def test_request_ohlc(self): 31 | listeners = AsyncListeners() 32 | 33 | with IQFeedHistoryProvider(num_connections=2) as history: 34 | streaming_conn = iq.QuoteConn() 35 | streaming_conn.connect() 36 | 37 | end_prd = datetime.datetime(2017, 5, 1) 38 | 39 | # test single symbol request 40 | filters = (BarsInPeriodFilter(ticker="IBM", bgn_prd=datetime.datetime(2017, 4, 1), end_prd=end_prd, interval_len=3600, ascend=True, interval_type='s'), 41 | BarsInPeriodFilter(ticker="AAPL", bgn_prd=datetime.datetime(2017, 4, 1), end_prd=end_prd, interval_len=3600, ascend=True, interval_type='s'), 42 | BarsInPeriodFilter(ticker="AAPL", bgn_prd=datetime.datetime(2017, 4, 1), end_prd=end_prd, interval_len=600, ascend=True, interval_type='s')) 43 | 44 | update_splits_dividends(client=self._client, fundamentals=get_fundamentals({'IBM', 'AAPL'}, streaming_conn).values()) 45 | adjusted = list() 46 | 47 | for f in filters: 48 | datum = history.request_data(f, sync_timestamps=False) 49 | datum.drop('timestamp', axis=1, inplace=True) 50 | datum['interval'] = str(f.interval_len) + '_' + f.interval_type 51 | self._client.write_points(datum, 'bars', protocol='line', tag_columns=['symbol', 'interval'], time_precision='s') 52 | datum.drop('interval', axis=1, inplace=True) 53 | 54 | datum = 
history.request_data(f, sync_timestamps=False) 55 | 56 | adjust_df(datum, get_adjustments(client=self._client, symbol=f.ticker)) 57 | adjusted.append(datum) 58 | 59 | cache_requests = InfluxDBOHLCRequest(client=self._client, interval_len=f.interval_len, interval_type=f.interval_type) 60 | _, test_data = cache_requests.request(symbol=f.ticker) 61 | adjust_df(test_data, get_adjustments(client=self._client, symbol=f.ticker)) 62 | del datum['total_volume'] 63 | del datum['number_of_trades'] 64 | assert_frame_equal(datum, test_data) 65 | 66 | for datum, f in zip(adjusted, filters): 67 | cache_requests = InfluxDBOHLCRequest(client=self._client, interval_len=f.interval_len, interval_type=f.interval_type) 68 | _, test_data = cache_requests.request(symbol=f.ticker) 69 | _, test_data_limit = cache_requests.request(symbol=f.ticker, bgn_prd=f.bgn_prd + relativedelta(days=7), end_prd=f.end_prd - relativedelta(days=7)) 70 | 71 | self.assertGreater(len(test_data_limit), 0) 72 | self.assertLess(len(test_data_limit), len(test_data)) 73 | 74 | # test multisymbol request 75 | requested_data = history.request_data(BarsInPeriodFilter(ticker=["AAPL", "IBM"], bgn_prd=datetime.datetime(2017, 4, 1), end_prd=end_prd, interval_len=3600, ascend=True, interval_type='s'), sync_timestamps=False) 76 | requested_data = requested_data.swaplevel(0, 1).sort_index() 77 | del requested_data['total_volume'] 78 | del requested_data['number_of_trades'] 79 | 80 | cache_requests = InfluxDBOHLCRequest(client=self._client, interval_len=3600, listeners=listeners) 81 | _, test_data = cache_requests.request(symbol=['IBM', 'AAPL', 'TSG'], bgn_prd=datetime.datetime(2017, 4, 1), end_prd=end_prd) 82 | assert_frame_equal(requested_data, test_data) 83 | 84 | # test any symbol request 85 | requested_data = history.request_data(BarsInPeriodFilter(ticker=["AAPL", "IBM"], bgn_prd=datetime.datetime(2017, 4, 1), end_prd=end_prd, interval_len=3600, ascend=True, interval_type='s'), sync_timestamps=False) 86 | requested_data = requested_data.swaplevel(0, 1).sort_index() 87 | 88 | del requested_data['total_volume'] 89 | del requested_data['number_of_trades'] 90 | 91 | e = threading.Event() 92 | 93 | def listen(event): 94 | if event['type'] == 'cache_result': 95 | assert_frame_equal(requested_data, event['data'][0]) 96 | e.set() 97 | 98 | listeners += listen 99 | 100 | listeners({'type': 'request_ohlc', 'data': {'bgn_prd': datetime.datetime(2017, 4, 1), 'end_prd': end_prd}}) 101 | 102 | e.wait() 103 | 104 | streaming_conn.disconnect() 105 | 106 | 107 | if __name__ == '__main__': 108 | unittest.main() 109 | -------------------------------------------------------------------------------- /tests/iqfeed/test_news_provider.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from atpy.data.iqfeed.iqfeed_news_provider import * 4 | from pyevents.events import AsyncListeners 5 | 6 | 7 | class TestIQFeedNews(unittest.TestCase): 8 | """ 9 | IQFeed news test, which checks whether the class works in basic terms 10 | """ 11 | 12 | def test_provider(self): 13 | filter_provider = DefaultNewsFilterProvider() 14 | filter_provider += NewsFilter(symbols=['AAPL'], limit=10) 15 | 16 | listeners = AsyncListeners() 17 | 18 | with IQFeedNewsListener(listeners=listeners, attach_text=True, filter_provider=filter_provider) as listener, listener.batch_provider() as provider: 19 | e1 = threading.Event() 20 | 21 | def process_batch_listener(event): 22 | if event['type'] == 'news_batch': 23 | batch = event['data'] 24 | 
self.assertEqual(len(batch), 10) 25 | self.assertEqual(len(batch.columns), 6) 26 | self.assertTrue('text' in batch.columns) 27 | self.assertTrue('AAPL' in batch['symbol_list'][0] or 'IBM' in batch['symbol_list'][0]) 28 | 29 | e1.set() 30 | 31 | listeners += process_batch_listener 32 | 33 | e1.wait() 34 | 35 | for i, d in enumerate(provider): 36 | self.assertEqual(len(d), 10) 37 | self.assertEqual(len(d.columns), 6) 38 | self.assertTrue('text' in d.columns) 39 | self.assertTrue('AAPL' in d['symbol_list'][0] or 'IBM' in d['symbol_list'][0]) 40 | 41 | if i == 1: 42 | break 43 | 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/iqfeed/test_streaming_level_1.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import OrderedDict 3 | 4 | from atpy.data.iqfeed.iqfeed_level_1_provider import * 5 | from atpy.data.util import get_nasdaq_listed_companies 6 | from pyevents.events import * 7 | 8 | 9 | class TestIQFeedLevel1(unittest.TestCase): 10 | """ 11 | IQFeed streaming news test, which checks whether the class works in basic terms 12 | """ 13 | 14 | def test_fundamentals(self): 15 | listeners = AsyncListeners() 16 | with IQFeedLevel1Listener(listeners=listeners) as listener: 17 | ffilter = listener.fundamentals_filter() 18 | 19 | listener.watch('IBM') 20 | listener.watch('AAPL') 21 | listener.watch('GOOG') 22 | listener.watch('MSFT') 23 | listener.watch('SPY') 24 | listener.request_watches() 25 | e1 = threading.Event() 26 | 27 | def on_fund_item(fund): 28 | try: 29 | self.assertTrue(fund['symbol'] in {'SPY', 'AAPL', 'IBM', 'GOOG', 'MSFT'}) 30 | self.assertEqual(len(fund), 50) 31 | finally: 32 | e1.set() 33 | 34 | ffilter += on_fund_item 35 | 36 | e1.wait() 37 | 38 | def test_get_fundamentals(self): 39 | funds = get_fundamentals({'TRC', 'IBM', 'AAPL', 'GOOG', 'MSFT'}) 40 | self.assertTrue('AAPL' in funds and 'IBM' in funds and 'GOOG' in funds and 'MSFT' in funds and 'TRC' in funds) 41 | for _, v in funds.items(): 42 | self.assertGreater(len(v), 0) 43 | 44 | def test_update_summary(self): 45 | listeners = AsyncListeners() 46 | 47 | with IQFeedLevel1Listener(listeners=listeners) as listener: 48 | e1 = threading.Event() 49 | 50 | def on_summary_item(data): 51 | try: 52 | self.assertEqual(len(data), 16) 53 | finally: 54 | e1.set() 55 | 56 | summary_filter = listener.level_1_summary_filter() 57 | summary_filter += on_summary_item 58 | 59 | e2 = threading.Event() 60 | 61 | def on_update_item(data): 62 | try: 63 | self.assertEqual(len(data), 16) 64 | finally: 65 | e2.set() 66 | 67 | update_filter = listener.level_1_update_filter() 68 | update_filter += on_update_item 69 | 70 | listener.watch('IBM') 71 | listener.watch('AAPL') 72 | listener.watch('GOOG') 73 | listener.watch('MSFT') 74 | listener.watch('SPY') 75 | 76 | e1.wait() 77 | e2.wait() 78 | 79 | def test_update_summary_deque(self): 80 | listeners = SyncListeners() 81 | 82 | mkt_snapshot_depth = 100 83 | with IQFeedLevel1Listener(listeners=listeners, mkt_snapshot_depth=mkt_snapshot_depth) as listener: 84 | e1 = threading.Event() 85 | 86 | def on_summary_item(data): 87 | try: 88 | self.assertEqual(len(data), 16) 89 | self.assertTrue(isinstance(data, OrderedDict)) 90 | self.assertEqual(next(iter(data.keys())), 'symbol') 91 | self.assertTrue(isinstance(data['symbol'], str)) 92 | finally: 93 | e1.set() 94 | 95 | summary_filter = listener.level_1_summary_filter() 96 | summary_filter 
+= on_summary_item 97 | 98 | e2 = threading.Event() 99 | 100 | def on_update_item(data): 101 | try: 102 | self.assertEqual(len(data), 16) 103 | self.assertTrue(isinstance(data, OrderedDict)) 104 | self.assertGreater(len(next(iter(data.values()))), 1) 105 | self.assertEqual(next(iter(data.keys())), 'symbol') 106 | self.assertTrue(isinstance(data['symbol'], str)) 107 | finally: 108 | e2.set() 109 | 110 | update_filter = listener.level_1_update_filter() 111 | update_filter += on_update_item 112 | 113 | listener.watch('IBM') 114 | listener.watch('AAPL') 115 | listener.watch('GOOG') 116 | listener.watch('MSFT') 117 | listener.watch('SPY') 118 | 119 | e1.wait() 120 | e2.wait() 121 | 122 | def test_update_summary_deque_trades_only(self): 123 | listeners = SyncListeners() 124 | 125 | mkt_snapshot_depth = 2 126 | with IQFeedLevel1Listener(listeners=listeners, mkt_snapshot_depth=mkt_snapshot_depth) as listener: 127 | e1 = threading.Event() 128 | 129 | def on_summary_item(data): 130 | try: 131 | self.assertEqual(len(data), 16) 132 | self.assertTrue(isinstance(data, OrderedDict)) 133 | finally: 134 | e1.set() 135 | 136 | summary_filter = listener.level_1_summary_filter() 137 | summary_filter += on_summary_item 138 | 139 | e2 = threading.Event() 140 | 141 | def on_update_item(data): 142 | try: 143 | self.assertEqual(len(data), 16) 144 | self.assertTrue(isinstance(data, OrderedDict)) 145 | self.assertGreater(len(next(iter(data.values()))), 1) 146 | trade_times = data['most_recent_trade_time'] 147 | self.assertEqual(len(trade_times), len(set(trade_times))) 148 | finally: 149 | if len(trade_times) == mkt_snapshot_depth: 150 | e2.set() 151 | 152 | update_filter = listener.level_1_update_filter() 153 | update_filter += on_update_item 154 | 155 | listener.watch_trades('IBM') 156 | listener.watch_trades('AAPL') 157 | listener.watch_trades('GOOG') 158 | listener.watch_trades('MSFT') 159 | listener.watch_trades('SPY') 160 | 161 | e1.wait() 162 | e2.wait() 163 | 164 | def test_news(self): 165 | listeners = AsyncListeners() 166 | 167 | with IQFeedLevel1Listener(listeners=listeners) as listener: 168 | e1 = threading.Event() 169 | 170 | def on_news_item(news_item): 171 | try: 172 | self.assertEqual(len(news_item), 6) 173 | self.assertGreater(len(news_item['headline']), 0) 174 | finally: 175 | e1.set() 176 | 177 | news_filter = listener.news_filter() 178 | news_filter += on_news_item 179 | 180 | listener.news_on() 181 | listener.watch('IBM') 182 | listener.watch('AAPL') 183 | listener.watch('GOOG') 184 | listener.watch('SPY') 185 | 186 | e1.wait() 187 | 188 | @unittest.skip('Run manually') 189 | def test_nasdaq_quotes_and_trades_performance(self): 190 | logging.basicConfig(level=logging.DEBUG) 191 | 192 | listeners = AsyncListeners() 193 | import time 194 | nasdaq = get_nasdaq_listed_companies() 195 | nasdaq = nasdaq.loc[nasdaq['Market Category'] == 'Q'] 196 | nasdaq = nasdaq.sample(480) 197 | with IQFeedLevel1Listener(listeners=listeners, mkt_snapshot_depth=200) as listener: 198 | listener.watch(nasdaq['Symbol'].to_list()) 199 | time.sleep(1000) 200 | 201 | @unittest.skip('Run manually') 202 | def test_nasdaq_trades_performance(self): 203 | logging.basicConfig(level=logging.DEBUG) 204 | 205 | listeners = AsyncListeners() 206 | import time 207 | nasdaq = get_nasdaq_listed_companies() 208 | nasdaq = nasdaq.loc[nasdaq['Market Category'] == 'Q'] 209 | nasdaq = nasdaq.sample(480) 210 | with IQFeedLevel1Listener(listeners=listeners, mkt_snapshot_depth=200) as listener: 211 | 
listener.watch_trades(nasdaq['Symbol'].to_list()) 212 | time.sleep(1000) 213 | 214 | 215 | if __name__ == '__main__': 216 | unittest.main() 217 | -------------------------------------------------------------------------------- /tests/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/ml/__init__.py -------------------------------------------------------------------------------- /tests/ml/test_data_pipeline.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import psycopg2 4 | from sqlalchemy import create_engine 5 | 6 | from atpy.data.cache.postgres_cache import BarsBySymbolProvider, create_bars, bars_indices, create_json_data, insert_df_json, request_adjustments 7 | from atpy.data.iqfeed.iqfeed_history_provider import * 8 | from atpy.data.splits_dividends import exclude_splits 9 | from atpy.ml.frac_diff_features import frac_diff_ffd 10 | from atpy.ml.labeling import triple_barriers 11 | from atpy.ml.util import * 12 | 13 | 14 | class TestDataPipeline(unittest.TestCase): 15 | 16 | def setUp(self): 17 | self.interval = 14400 18 | 19 | def __generate_temp_pipeline(self, url): 20 | with IQFeedHistoryProvider(num_connections=2) as history: 21 | # historical data 22 | engine = create_engine(url) 23 | con = psycopg2.connect(url) 24 | con.autocommit = True 25 | 26 | cur = con.cursor() 27 | 28 | cur.execute(create_bars.format('bars_test')) 29 | cur.execute(bars_indices.format('bars_test')) 30 | 31 | filters = (BarsFilter(ticker="IBM", interval_len=self.interval, interval_type='s', max_bars=10000), 32 | BarsFilter(ticker="AAPL", interval_len=self.interval, interval_type='s', max_bars=10000), 33 | BarsFilter(ticker="MSFT", interval_len=self.interval, interval_type='s', max_bars=10000), 34 | BarsFilter(ticker="FB", interval_len=self.interval, interval_type='s', max_bars=10000), 35 | BarsFilter(ticker="GOOG", interval_len=self.interval, interval_type='s', max_bars=10000)) 36 | 37 | data = [history.request_data(f, sync_timestamps=False) for f in filters] 38 | 39 | for datum, f in zip(data, filters): 40 | del datum['timestamp'] 41 | del datum['total_volume'] 42 | del datum['number_of_trades'] 43 | datum['symbol'] = f.ticker 44 | datum['interval'] = str(self.interval) + '_s' 45 | 46 | datum = datum.tz_localize(None) 47 | datum.to_sql('bars_test', con=engine, if_exists='append') 48 | 49 | # adjustments 50 | adjustments = get_splits_dividends({'IBM', 'AAPL', 'MSFT', 'FB', 'GOOG'}) 51 | 52 | cur = con.cursor() 53 | 54 | cur.execute(create_json_data.format('json_data_test')) 55 | 56 | insert_df_json(con, 'json_data_test', adjustments) 57 | 58 | def test_data_generation_pipeline(self): 59 | url = 'postgresql://postgres:postgres@localhost:5432/test' 60 | con = psycopg2.connect(url) 61 | con.autocommit = True 62 | 63 | try: 64 | self.__generate_temp_pipeline(url) 65 | bars_per_symbol = BarsBySymbolProvider(conn=con, records_per_query=50000, interval_len=self.interval, interval_type='s', table_name='bars_test') 66 | 67 | for df in bars_per_symbol: 68 | orig = df 69 | df['pt'] = 0.001 70 | df['sl'] = 0.001 71 | 72 | adj = request_adjustments(con, 'json_data_test', symbol=list(df.index.get_level_values('symbol').unique()), adj_type='split') 73 | self.assertTrue(adj.size > 0) 74 | 75 | df = triple_barriers(df['close'], df['pt'], sl=df['sl'], vb=pd.Timedelta(str(self.interval * 10) + 's'), parallel=False) 76 | 
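                # triple_barriers labels each close by the first barrier touched: the 0.1%
                # profit-take ('pt') / stop-loss ('sl') thresholds set above, or the vertical
                # time barrier 'vb' of 10 bar intervals; parallel=False keeps it single-process.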
self.assertTrue(df.size > 0) 77 | 78 | df['include'] = True 79 | df['include'] = exclude_splits(df['include'], adj['value'].xs('split', level='type'), 10) 80 | self.assertTrue(df['include'].max()) 81 | self.assertFalse(df['include'].min()) 82 | 83 | tmp = orig['close'].to_frame() 84 | tmp['threshold'] = 0.05 85 | to_include = cumsum_filter(tmp['close'], tmp['threshold'], parallel=True) 86 | df.loc[~df.index.isin(to_include), 'include'] = False 87 | df.loc[df['interval_end'].isnull(), 'include'] = False 88 | 89 | self.assertTrue(df['include'].max()) 90 | self.assertFalse(df['include'].min()) 91 | 92 | df['frac_diff'] = frac_diff_ffd(orig['close'], 0.4) 93 | finally: 94 | con.cursor().execute("DROP TABLE IF EXISTS bars_test;") 95 | con.cursor().execute("DROP TABLE IF EXISTS json_data_test;") 96 | -------------------------------------------------------------------------------- /tests/portfolio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/portfolio/__init__.py -------------------------------------------------------------------------------- /tests/quandl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivan-vasilev/atpy/abe72832ae8cec818b0e67989892c25456e9e5f5/tests/quandl/__init__.py --------------------------------------------------------------------------------
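
The unit tests above double as usage documentation: create a SyncListeners/AsyncListeners bus, register callbacks with +=, and drive the components either through events or directly through providers such as IQFeedHistoryProvider. Below is a minimal sketch of that pattern outside the unittest harness; it assumes a running local IQFeed connection and uses only names that already appear in the tests above.

from atpy.data.iqfeed.iqfeed_history_provider import IQFeedHistoryProvider, BarsFilter
from atpy.data.ts_util import current_period, set_periods

# Request 1000 five-minute AAPL bars and tag each row with its trading session,
# mirroring tests/data/test_ts_utils.py (requires a running IQFeed connection).
with IQFeedHistoryProvider() as provider:
    df = provider.request_data(BarsFilter(ticker="AAPL", interval_len=300, interval_type='s', max_bars=1000),
                               sync_timestamps=False)

    set_periods(df)  # adds a 'period' column ('trading-hours' / 'after-hours')
    slc, period = current_period(df)  # slice of the latest period plus its label
    print(period, len(slc), len(df))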