├── .gitattributes ├── .gitignore ├── .travis.yml ├── LICENSE.TXT ├── algotrader ├── __init__.py ├── analyzer │ ├── __init__.py │ ├── drawdown.py │ ├── performance.py │ └── pnl.py ├── app │ ├── __init__.py │ ├── backtest_runner.py │ ├── live_ats.py │ └── mkt_data_importer.py ├── chart │ ├── __init__.py │ └── plotter.py ├── model │ ├── __init__.py │ ├── market_data_pb2.py │ ├── model_factory.py │ ├── ref_data_pb2.py │ ├── time_series2_pb2.py │ ├── time_series_pb2.py │ ├── timeseries_runner.py │ └── trade_data_pb2.py ├── model2 │ └── __init__.py ├── provider │ ├── __init__.py │ ├── broker │ │ ├── __init__.py │ │ ├── ib │ │ │ ├── __init__.py │ │ │ ├── ib_broker.py │ │ │ ├── ib_model_factory.py │ │ │ └── ib_socket.py │ │ └── sim │ │ │ ├── __init__.py │ │ │ ├── commission.py │ │ │ ├── data_processor.py │ │ │ ├── fill_strategy.py │ │ │ ├── order_handler.py │ │ │ ├── sim_config.py │ │ │ ├── simulator.py │ │ │ └── slippage.py │ ├── datastore │ │ ├── __init__.py │ │ ├── cass.py │ │ ├── inmemory.py │ │ └── mongodb.py │ └── feed │ │ ├── __init__.py │ │ ├── csv.py │ │ ├── pandas_h5.py │ │ ├── pandas_memory.py │ │ └── pandas_web.py ├── strategy │ ├── __init__.py │ ├── alpha_formula.py │ ├── cross_sectional_mean_reverting.py │ ├── down_2pct_strategy.py │ ├── ema_strategy.py │ ├── merton_optimal.py │ ├── pair_trading.py │ ├── sma_strategy.py │ ├── vix_future.py │ └── volatility_made_simple.py ├── technical │ ├── __init__.py │ ├── atr.py │ ├── bb.py │ ├── historical_volatility.py │ ├── kfpairregression.py │ ├── ma.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── corr.py │ │ ├── cross_sessional_apply.py │ │ ├── make_vector.py │ │ ├── pairwise.py │ │ └── rank.py │ ├── roc.py │ ├── rolling_apply.py │ ├── rsi.py │ ├── stats.py │ ├── talib_wrapper.py │ └── talib_wrapper_gen.py ├── trading │ ├── __init__.py │ ├── account.py │ ├── bar_aggregator.py │ ├── clock.py │ ├── config.py │ ├── context.py │ ├── data_series.py │ ├── event.py │ ├── instrument_data.py │ ├── order.py │ ├── portfolio.py │ 
├── position.py │ ├── ref_data.py │ ├── sequence.py │ └── subscription.py └── utils │ ├── __init__.py │ ├── data_series.py │ ├── date.py │ ├── indicator.py │ ├── logging.py │ ├── market_data.py │ ├── model.py │ ├── protobuf_to_dict.py │ ├── py2to3.py │ ├── ref_data.py │ ├── sde_sim.py │ └── trade_data.py ├── config ├── backtest.json ├── backtest.yaml ├── config.txt ├── data_import.yaml ├── down2%.yaml └── live_ib.yaml ├── data ├── refdata │ ├── ccy.csv │ ├── exch.csv │ └── instrument.csv └── tradedata │ ├── 1.csv │ ├── fb.csv │ ├── goog.csv │ ├── msft.csv │ └── spy.csv ├── poc ├── __init__.py ├── backtest_in_memory.py ├── backtest_pairtrading.py ├── cassandra_sample.py ├── copy_test.py ├── ib_demo.py ├── ib_tester.py ├── msgpack_numpy.py ├── oms_client.py ├── oms_server.py ├── pandas.ipynb ├── pyfolio_playground.py ├── rpy2.py ├── rxpy_scheduling.py ├── ser_deser.py ├── tensor_flow.ipynb ├── theano_check1.py ├── time_test.py └── zerorpc_patch.py ├── proto └── algotrader │ └── model │ ├── market_data.proto │ ├── ref_data.proto │ ├── time_series.proto │ ├── time_series2.proto │ └── trade_data.proto ├── readme.md ├── requirements.txt ├── scripts ├── __init__.py ├── cassandra │ ├── algotrader.cql │ └── algotrader.old.cql ├── eoddata_symbol_importer.py ├── event_calendar_downloader.py ├── gen_proto.sh ├── ib_inst_utils.py ├── kdb │ ├── bar.q │ ├── quote.q │ └── trade.q ├── netfronds_tickdata.py ├── start_mongo.sh └── vix_inst_importer.py ├── tests ├── __init__.py ├── integration_tests │ ├── __init__.py │ ├── test_data_store.py │ └── test_persistence_mongo.py ├── sample_factory.py ├── test_bar.py ├── test_bar_aggregator.py ├── test_broker.py ├── test_broker_mgr.py ├── test_clock.py ├── test_cmp_functional_backtest.py ├── test_config.py ├── test_data_series.py ├── test_data_series_utils.py ├── test_feed.py ├── test_in_memory_db.py ├── test_indicator.py ├── test_instrument_data.py ├── test_ma.py ├── test_market_data_processor.py ├── test_model_factory.py ├── 
test_model_utils.py ├── test_order.py ├── test_order_handler.py ├── test_persistence_indicator.py ├── test_persistence_strategy.py ├── test_pipeline.py ├── test_pipeline_pairwise.py ├── test_plot.py ├── test_portfolio.py ├── test_position.py ├── test_ref_data.py ├── test_rolling.py ├── test_ser_deser.py ├── test_suite.py └── test_talib_wrapper.py └── todo.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | 3 | eol=lf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb_checkpoints 3 | *.iml 4 | *.ipr 5 | *.iws 6 | *.pyc 7 | .RData 8 | .Rhistory 9 | out 10 | classes 11 | build 12 | data/refdata/eoddata -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | sudo: required 4 | 5 | python: 6 | - "2.7" 7 | # command to install dependencies 8 | 9 | before_install: 10 | - wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz 11 | - tar xvfz ta-lib-0.4.0-src.tar.gz 12 | - cd ta-lib 13 | - ./configure --prefix=/usr 14 | - make 15 | - sudo make install 16 | - cd .. 17 | 18 | install: "pip install -r requirements.txt" 19 | # command to run tests 20 | script: nosetests tests/test_suite.py -------------------------------------------------------------------------------- /LICENSE.TXT: -------------------------------------------------------------------------------- 1 | Python-Trading 2 | 3 | Copyright 2016 Alex Yu 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. -------------------------------------------------------------------------------- /algotrader/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from algotrader.utils.model import get_model_id 3 | 4 | class HasId(object): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | @abc.abstractmethod 8 | def id(self) -> str: 9 | raise NotImplementedError() 10 | 11 | 12 | 13 | class Startable(object): 14 | __metaclass__ = abc.ABCMeta 15 | 16 | def __init__(self): 17 | self.started = False 18 | 19 | def start(self, app_context = None) -> None: 20 | self.app_context = app_context 21 | if not hasattr(self, "started") or not self.started: 22 | self.started = True 23 | self._start(app_context=app_context) 24 | 25 | def stop(self) -> None: 26 | if hasattr(self, "started") and self.started: 27 | self._stop() 28 | self.started = False 29 | 30 | def reset(self) -> None: 31 | pass 32 | 33 | def _start(self, app_context = None) -> None: 34 | pass 35 | 36 | def _stop(self) -> None: 37 | pass 38 | 39 | 40 | class Context(object): 41 | __metaclass__ = abc.ABCMeta 42 | 43 | def __init__(self): 44 | self.startables = [] 45 | 46 | def add_startable(self, startable: Startable) -> Startable: 47 | self.startables.append(startable) 48 | return startable 49 | 50 | def start(self) -> None: 51 | for startable in self.startables: 52 | startable.start(self) 53 | 54 | def stop(self) -> None: 55 | for startable in reversed(self.startables): 56 | startable.stop() 57 | 58 | @abc.abstractmethod 59 | def get_data_store(self): 60 | 
raise NotImplementedError() 61 | 62 | @abc.abstractmethod 63 | def get_broker(self): 64 | raise NotImplementedError() 65 | 66 | @abc.abstractmethod 67 | def get_feed(self): 68 | raise NotImplementedError() 69 | 70 | @abc.abstractmethod 71 | def get_portfolio(self): 72 | raise NotImplementedError() 73 | 74 | 75 | class Manager(Startable, HasId): 76 | __metaclass__ = abc.ABCMeta 77 | 78 | def __init__(self): 79 | super(Manager, self).__init__() 80 | 81 | 82 | class SimpleManager(Manager): 83 | __metaclass__ = abc.ABCMeta 84 | 85 | __slots__ = ( 86 | 'item_dict' 87 | ) 88 | 89 | def __init__(self): 90 | super(SimpleManager, self).__init__() 91 | self.item_dict = {} 92 | 93 | def get(self, id): 94 | return self.item_dict.get(id, None) 95 | 96 | def add(self, item): 97 | self.item_dict[get_model_id(item)] = item 98 | 99 | def all_items(self): 100 | return [item for item in self.item_dict.values()] 101 | 102 | def has_item(self, id) -> bool: 103 | return id in self.item_dict 104 | 105 | def load_all(self): 106 | pass 107 | 108 | def save_all(self): 109 | pass 110 | 111 | def _start(self, app_context: Context) -> None: 112 | self.load_all() 113 | 114 | def _stop(self) -> None: 115 | self.save_all() 116 | 117 | for item in self.item_dict.values(): 118 | if isinstance(item, Startable): 119 | item.stop() 120 | 121 | self.reset() 122 | 123 | def reset(self) -> None: 124 | self.item_dict.clear() 125 | -------------------------------------------------------------------------------- /algotrader/analyzer/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Analyzer(object): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | @abc.abstractmethod 8 | def update(self, timestamp: int, total_equity: float): 9 | return 10 | -------------------------------------------------------------------------------- /algotrader/analyzer/drawdown.py: -------------------------------------------------------------------------------- 1 | from 
algotrader.analyzer import Analyzer 2 | from algotrader.trading.data_series import DataSeries 3 | 4 | 5 | class DrawDownAnalyzer(Analyzer): 6 | DrawDown = "DrawDown" 7 | DrawDownPct = "DrawDown%" 8 | HighEquity = "HighEquity" 9 | LowEquity = "LowEquity" 10 | CurrentRunUp = "CurrentRunUp" 11 | CurrentDrawDown = "CurrentDrawDown" 12 | 13 | def __init__(self, portfolio, state): 14 | self.portfolio = portfolio 15 | self.state = state 16 | self.series = DataSeries(time_series=self.state.drawdown.series) 17 | 18 | def update(self, timestamp: int, total_equity: float): 19 | if self.portfolio.performance.series.size() == 1: 20 | self.state.drawdown.low_equity = total_equity 21 | self.state.drawdown.high_equity = total_equity 22 | else: 23 | if total_equity > self.state.drawdown.high_equity: 24 | self.state.drawdown.high_equity = total_equity 25 | self.state.drawdown.low_equity = total_equity 26 | self.state.drawdown.current_drawdown = 0 27 | elif total_equity < self.state.drawdown.low_equity: 28 | self.state.drawdown.low_equity = total_equity 29 | self.state.drawdown.current_run_up = 0 30 | elif total_equity > self.state.drawdown.low_equity and total_equity < self.state.drawdown.high_equity: 31 | self.state.drawdown.current_drawdown = 1 - total_equity / self.state.drawdown.high_equity 32 | self.state.drawdown.current_run_up = total_equity / self.state.drawdown.low_equity - 1 33 | 34 | if self.portfolio.performance.series.size() >= 2: 35 | self.state.drawdown.last_drawdown = total_equity - self.state.drawdown.high_equity 36 | 37 | if self.state.drawdown.high_equity != 0: 38 | self.state.drawdown.last_drawdown_pct = abs( 39 | self.state.drawdown.last_drawdown / self.state.drawdown.high_equity) 40 | self.series.add(timestamp=timestamp, data={self.DrawDown: self.state.drawdown.last_drawdown, 41 | self.DrawDownPct: self.state.drawdown.last_drawdown_pct}) 42 | 43 | def get_result(self): 44 | return {self.DrawDown: self.state.drawdown.last_drawdown, 45 | self.DrawDownPct: 
self.state.drawdown.last_drawdown_pct, 46 | self.HighEquity: self.state.drawdown.high_equity, 47 | self.LowEquity: self.state.drawdown.low_equity, 48 | self.CurrentRunUp: self.state.drawdown.current_run_up, 49 | self.CurrentDrawDown: self.state.drawdown.current_drawdown} 50 | 51 | def get_series(self, keys=None): 52 | keys = keys if keys else [self.DrawDown, self.DrawDownPct] 53 | return self.series.get_series(keys) 54 | 55 | def last_drawdown(self) -> float: 56 | return self.state.drawdown.last_drawdown 57 | 58 | def last_drawdown_pct(self) -> float: 59 | return self.state.drawdown.last_drawdown_pct 60 | 61 | def high_equity(self) -> float: 62 | return self.state.drawdown.high_equity 63 | 64 | def low_equity(self) -> float: 65 | return self.state.drawdown.low_equity 66 | 67 | def current_run_up(self) -> float: 68 | return self.state.drawdown.current_run_up 69 | 70 | def current_drawdown(self) -> float: 71 | return self.state.drawdown.current_drawdown 72 | -------------------------------------------------------------------------------- /algotrader/analyzer/performance.py: -------------------------------------------------------------------------------- 1 | from algotrader.analyzer import Analyzer 2 | from algotrader.trading.data_series import DataSeries 3 | 4 | 5 | class PerformanceAnalyzer(Analyzer): 6 | Performance = "Performance" 7 | StockValue = "stock_value" 8 | Cash = "cash" 9 | TotalEquity = "total_equity" 10 | 11 | def __init__(self, portfolio, state): 12 | self.portfolio = portfolio 13 | self.state = state 14 | self.series = DataSeries(time_series=self.state.performance.series) 15 | 16 | def update(self, timestamp: int, total_equity: float): 17 | self.state.performance.total_equity = total_equity 18 | self.series.add(timestamp=timestamp, 19 | data={self.StockValue: self.state.stock_value, 20 | self.Cash: self.state.cash, 21 | self.TotalEquity: total_equity}) 22 | 23 | def get_result(self): 24 | return {self.StockValue: self.state.stock_value, 25 | 
self.Cash: self.state.cash, 26 | self.TotalEquity: self.state.performance.total_equity} 27 | 28 | def get_series(self, keys=None): 29 | keys = keys if keys else [self.StockValue, self.Cash, self.TotalEquity] 30 | return self.series.get_series(keys) 31 | 32 | def now(self, key): 33 | return self.series.now(key) 34 | 35 | def total_equity(self) -> float: 36 | return self.state.performance.total_equity 37 | -------------------------------------------------------------------------------- /algotrader/analyzer/pnl.py: -------------------------------------------------------------------------------- 1 | from algotrader.analyzer import Analyzer 2 | 3 | from algotrader.analyzer.performance import PerformanceAnalyzer 4 | from algotrader.trading.data_series import DataSeries 5 | 6 | 7 | class PnlAnalyzer(Analyzer): 8 | Pnl = "Pnl" 9 | 10 | def __init__(self, portfolio, state): 11 | self.portfolio = portfolio 12 | self.state = state 13 | self.series = DataSeries(time_series=self.state.pnl.series) 14 | 15 | def update(self, timestamp: int, total_equity: float): 16 | performance_series = self.portfolio.performance.series 17 | 18 | if self.series.size() >= 2: 19 | self.state.pnl.last_pnl = performance_series.get_by_idx(-1, PerformanceAnalyzer.TotalEquity) - \ 20 | performance_series.get_by_idx(-2, PerformanceAnalyzer.TotalEquity) 21 | 22 | self.series.add(timestamp=timestamp, data={self.Pnl: self.state.pnl.last_pnl}) 23 | else: 24 | self.state.pnl.last_pnl = 0 25 | self.series.add(timestamp=timestamp, data={self.Pnl: self.state.pnl.last_pnl}) 26 | 27 | def get_result(self): 28 | return {self.Pnl: self.state.pnl.last_pnl} 29 | 30 | def get_series(self, keys=None): 31 | keys = keys if keys else self.Pnl 32 | return {self.Pnl: self.series.get_series(self.Pnl)} 33 | 34 | def last_pnl(self) -> float: 35 | return self.state.pnl.last_pnl 36 | -------------------------------------------------------------------------------- /algotrader/app/__init__.py: 
-------------------------------------------------------------------------------- 1 | from algotrader import Startable, Context 2 | 3 | 4 | class Application(Startable): 5 | DataImport = "DataImport" 6 | LiveTrading = "LiveTrading" 7 | BackTesting = "BackTesting" 8 | 9 | def _start(self, app_context: Context) -> None: 10 | try: 11 | self.init() 12 | self.run() 13 | finally: 14 | self.stop() 15 | 16 | def init(self) -> None: 17 | pass 18 | 19 | def run(self) -> None: 20 | pass 21 | 22 | def _stop(self) -> None: 23 | self.app_context.stop() 24 | -------------------------------------------------------------------------------- /algotrader/app/backtest_runner.py: -------------------------------------------------------------------------------- 1 | from algotrader.app import Application 2 | 3 | from algotrader.chart.plotter import StrategyPlotter 4 | from algotrader.trading.config import Config, load_from_yaml 5 | from algotrader.trading.context import ApplicationContext 6 | from algotrader.utils.logging import logger 7 | 8 | 9 | class BacktestRunner(Application): 10 | def init(self): 11 | self.config = self.app_context.config 12 | 13 | self.is_plot = self.config.get_app_config("plot", default=True) 14 | 15 | def run(self): 16 | logger.info("starting BackTest") 17 | 18 | self.app_context.start() 19 | self.portfolio = self.app_context.portf_mgr.get_or_new_portfolio(self.config.get_app_config("portfolioId"), 20 | self.config.get_app_config( 21 | "portfolioInitialcash")) 22 | self.strategy = self.app_context.stg_mgr.get_or_new_stg(self.config.get_app_config("stgId"), 23 | self.config.get_app_config("stgCls")) 24 | 25 | self.initial_result = self.portfolio.get_result() 26 | self.app_context.add_startable(self.portfolio) 27 | self.portfolio.start(self.app_context) 28 | self.strategy.start(self.app_context) 29 | 30 | result = self.portfolio.get_result() 31 | print("Initial:", self.initial_result) 32 | print("Final:", result) 33 | if self.is_plot: 34 | self.plot() 35 | 36 | def 
plot(self): 37 | # pyfolio 38 | ret = self.portfolio.get_return() 39 | # import pyfolio as pf 40 | # pf.create_returns_tear_sheet(ret) 41 | # pf.create_full_tear_sheet(ret) 42 | 43 | # build in plot 44 | 45 | plotter = StrategyPlotter(self.strategy) 46 | plotter.plot(instrument=self.app_context.config.get_app_config("instrumentIds")[0]) 47 | 48 | 49 | def main(): 50 | config = Config( 51 | load_from_yaml("../../config/backtest.yaml"), 52 | load_from_yaml("../../config/down2%.yaml")) 53 | 54 | app_context = ApplicationContext(config=config) 55 | 56 | BacktestRunner().start(app_context) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /algotrader/app/live_ats.py: -------------------------------------------------------------------------------- 1 | from algotrader.app import Application 2 | from algotrader.trading.config import Config, load_from_yaml 3 | from algotrader.trading.context import ApplicationContext 4 | from algotrader.utils.logging import logger 5 | 6 | 7 | class ATSRunner(Application): 8 | def init(self): 9 | self.config = self.config 10 | 11 | self.portfolio = self.app_context.portf_mgr.get_or_new_portfolio(self.config.get_app_config("portfolioId"), 12 | self.config.get_app_config( 13 | "portfolioInitialcash")) 14 | self.app_context.add_startable(self.portfolio) 15 | 16 | self.strategy = self.app_context.stg_mgr.get_or_new_stg(self.config.get_app_config("stgId"), 17 | self.config.get_app_config("stgCls")) 18 | self.app_context.add_startable(self.strategy) 19 | 20 | def run(self): 21 | logger.info("starting ATS") 22 | 23 | self.app_context.start() 24 | self.strategy.start(self.app_context) 25 | 26 | logger.info("ATS started, presss Ctrl-C to stop") 27 | 28 | 29 | def main(): 30 | config = Config( 31 | load_from_yaml("../../config/live_ib.yaml"), 32 | load_from_yaml("../../config/down2%.yaml")) 33 | 34 | app_context = ApplicationContext(config=config) 35 | 36 | 
ATSRunner().start(app_context) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /algotrader/app/mkt_data_importer.py: -------------------------------------------------------------------------------- 1 | from gevent import monkey 2 | 3 | monkey.patch_all() 4 | from algotrader.trading.subscription import MarketDataSubscriber 5 | from algotrader.app import Application 6 | from algotrader.utils.logging import logger 7 | import time 8 | from algotrader.trading.config import Config, load_from_yaml 9 | from algotrader.trading.context import ApplicationContext 10 | from algotrader.utils.market_data import build_subscription_requests 11 | 12 | 13 | class MktDataImporter(Application, MarketDataSubscriber): 14 | def run(self): 15 | logger.info("importing data") 16 | self.app_context.start() 17 | config = self.app_context.config 18 | 19 | feed = self.app_context.provider_mgr.get(config.get_app_config("feedId")) 20 | feed.start(self.app_context) 21 | instruments = self.app_context.ref_data_mgr.get_insts_by_ids(config.get_app_config("instrumentIds")) 22 | 23 | for sub_req in build_subscription_requests(feed.id(), instruments, 24 | config.get_app_config("subscriptionTypes"), 25 | config.get_app_config("fromDate"), 26 | config.get_app_config("toDate")): 27 | feed.subscribe_mktdata(sub_req) 28 | 29 | logger.info("ATS started, presss Ctrl-C to stop") 30 | for i in range(1, 1000): 31 | time.sleep(1) 32 | logger.info(".") 33 | 34 | 35 | def main(): 36 | config = Config(load_from_yaml("../../config/data_import.yaml"), 37 | {"Application": {"feedId": "Yahoo"}}) 38 | app_context = ApplicationContext(config=config) 39 | MktDataImporter().start(app_context) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /algotrader/chart/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/algotrader/chart/__init__.py -------------------------------------------------------------------------------- /algotrader/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/algotrader/model/__init__.py -------------------------------------------------------------------------------- /algotrader/model2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/algotrader/model2/__init__.py -------------------------------------------------------------------------------- /algotrader/provider/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from algotrader import Startable, HasId 4 | 5 | 6 | class Provider(Startable, HasId): 7 | __metaclass__ = abc.ABCMeta 8 | 9 | def __init__(self): 10 | super(Provider, self).__init__() 11 | 12 | 13 | from algotrader import SimpleManager 14 | 15 | from algotrader.provider.broker.ib.ib_broker import IBBroker 16 | from algotrader.provider.broker.sim.simulator import Simulator 17 | from algotrader.provider.feed.csv import CSVDataFeed 18 | from algotrader.provider.feed.pandas_web import PandasWebDataFeed 19 | from algotrader.provider.feed.pandas_memory import PandasMemoryDataFeed 20 | from algotrader.provider.datastore.inmemory import InMemoryDataStore 21 | from algotrader.provider.datastore.mongodb import MongoDBDataStore 22 | 23 | 24 | class ProviderManager(SimpleManager): 25 | def __init__(self): 26 | super(ProviderManager, self).__init__() 27 | 28 | self.add(Simulator()) 29 | self.add(IBBroker()) 30 | 31 | 
self.add(MongoDBDataStore()) 32 | self.add(InMemoryDataStore()) 33 | 34 | self.add(CSVDataFeed()) 35 | self.add(PandasWebDataFeed()) 36 | self.add(PandasMemoryDataFeed()) 37 | 38 | def id(self): 39 | return "ProviderManager" 40 | -------------------------------------------------------------------------------- /algotrader/provider/broker/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from algotrader.provider import Provider 4 | from algotrader.trading.event import OrderEventHandler 5 | 6 | 7 | class Broker(Provider, OrderEventHandler): 8 | Simulator = "Simulator" 9 | IB = "IB" 10 | 11 | __metaclass__ = abc.ABCMeta 12 | 13 | def __init__(self): 14 | super(Provider, self).__init__() 15 | 16 | def _get_broker_config(self, path: str, default=None): 17 | return self.app_context.config.get_broker_config(self.id(), path, default=default) 18 | -------------------------------------------------------------------------------- /algotrader/provider/broker/ib/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | OrderAttribute = namedtuple('TimeInForce', 4 | ['DAY', 'GTC', 'OPG', 'IOC', 'GTD', 'FOK', 'DTC']) 5 | 6 | time_in_force = OrderAttribute('DAY', 'GTC', 'OPG', 'IOC', 'GTD', 'FOK', 'DTC') 7 | -------------------------------------------------------------------------------- /algotrader/provider/broker/sim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/algotrader/provider/broker/sim/__init__.py -------------------------------------------------------------------------------- /algotrader/provider/broker/sim/commission.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Commission(object): 5 | Default = 0 6 | 7 | 
__metaclass__ = abc.ABCMeta 8 | 9 | @abc.abstractmethod 10 | def calc(self, new_ord_req, price, qty): 11 | raise NotImplementedError() 12 | 13 | 14 | class NoCommission(Commission): 15 | def calc(self, new_ord_req, price, qty): 16 | return 0 17 | 18 | 19 | class FixedPerTrade(Commission): 20 | def __init__(self, amount): 21 | self.amount = amount 22 | 23 | def calc(self, new_ord_req, price, qty): 24 | return self.amount 25 | 26 | 27 | class TradePercentage(Commission): 28 | def __init__(self, percentage): 29 | self.percentage = percentage 30 | 31 | def calc(self, new_ord_req, price, qty): 32 | return price * qty * self.percentage 33 | -------------------------------------------------------------------------------- /algotrader/provider/broker/sim/data_processor.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import math 3 | 4 | from algotrader.model.market_data_pb2 import Trade 5 | from algotrader.provider.broker.sim.sim_config import SimConfig 6 | from algotrader.utils.trade_data import is_buy, is_sell 7 | 8 | 9 | class MarketDataProcessor(object): 10 | __metaclass__ = abc.ABCMeta 11 | 12 | @abc.abstractmethod 13 | def get_price(self, new_ord_req, market_data, config, new_order=False): 14 | raise NotImplementedError() 15 | 16 | @abc.abstractmethod 17 | def get_qty(self, new_ord_req, market_data, config): 18 | raise NotImplementedError() 19 | 20 | 21 | class BarProcessor(MarketDataProcessor): 22 | def get_price(self, new_ord_req, market_data, config, new_order=False): 23 | if config.fill_on_bar_mode == SimConfig.FillMode.LAST or config.fill_on_bar_mode == SimConfig.FillMode.NEXT_CLOSE: 24 | return market_data.close 25 | elif not new_order and config.fill_on_bar_mode == SimConfig.FillMode.NEXT_OPEN: 26 | return market_data.open 27 | return 0.0 28 | 29 | def get_qty(self, new_ord_req, market_data, config): 30 | if config.partial_fill: 31 | bar_vol = math.trunc( 32 | market_data.vol if not config.bar_vol_ratio 
else market_data.vol * config.bar_vol_ratio) 33 | return min(new_ord_req.qty, bar_vol) 34 | return new_ord_req.qty 35 | 36 | 37 | class QuoteProcessor(MarketDataProcessor): 38 | def get_price(self, new_ord_req, market_data, config, new_order=False): 39 | if is_buy(new_ord_req) and market_data.ask > 0: 40 | return market_data.ask 41 | elif is_sell(new_ord_req) and market_data.bid > 0: 42 | return market_data.bid 43 | return 0.0 44 | 45 | def get_qty(self, new_ord_req, market_data, config): 46 | if config.partial_fill: 47 | if is_buy(new_ord_req): 48 | return min(market_data.ask_size, new_ord_req.qty) 49 | elif is_sell(new_ord_req): 50 | return min(market_data.bid_size, new_ord_req.qty) 51 | return new_ord_req.qty 52 | 53 | 54 | class TradeProcessor(MarketDataProcessor): 55 | def get_price(self, new_ord_req, market_data, config, new_order=False): 56 | if market_data and isinstance(market_data, Trade): 57 | if market_data.price > 0: 58 | return market_data.price 59 | return 0.0 60 | 61 | def get_qty(self, new_ord_req, market_data, config): 62 | if config.partial_fill: 63 | return min(new_ord_req.qty, market_data.size) 64 | return new_ord_req.qty 65 | -------------------------------------------------------------------------------- /algotrader/provider/broker/sim/fill_strategy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from algotrader.model.market_data_pb2 import Bar, Quote, Trade 4 | from algotrader.model.trade_data_pb2 import * 5 | from algotrader.provider.broker.sim.order_handler import MarketOrderHandler, LimitOrderHandler, StopLimitOrderHandler, \ 6 | StopOrderHandler, TrailingStopOrderHandler 7 | from algotrader.provider.broker.sim.sim_config import SimConfig 8 | from algotrader.provider.broker.sim.slippage import NoSlippage 9 | 10 | 11 | class FillStrategy(object): 12 | Default = 0 13 | 14 | __metaclass__ = abc.ABCMeta 15 | 16 | @abc.abstractmethod 17 | def process_new_order(self, new_ord_req): 18 | 
raise NotImplementedError() 19 | 20 | @abc.abstractmethod 21 | def process_w_market_data(self, market_data): 22 | raise NotImplementedError() 23 | 24 | @abc.abstractmethod 25 | def process_w_price_qty(self, new_ord_req, price, qty): 26 | raise NotImplementedError() 27 | 28 | 29 | class DefaultFillStrategy(FillStrategy): 30 | def __init__(self, app_context=None, sim_config=None, slippage=None): 31 | self.app_context = app_context 32 | self.__sim_config = sim_config if sim_config else SimConfig() 33 | self.__slippage = slippage if slippage else NoSlippage() 34 | self.__market_ord_handler = MarketOrderHandler(self.__sim_config, self.__slippage) 35 | self.__limit_ord_handler = LimitOrderHandler(self.__sim_config) 36 | self.__stop_limit_ord_handler = StopLimitOrderHandler(self.__sim_config) 37 | self.__stop_ord_handler = StopOrderHandler(self.__sim_config, self.__slippage) 38 | self.__trailing_stop_ord_handler = TrailingStopOrderHandler(self.__sim_config, self.__slippage) 39 | 40 | def process_new_order(self, new_ord_req): 41 | fill_info = None 42 | config = self.__sim_config 43 | 44 | quote = self.app_context.inst_data_mgr.get_quote(new_ord_req.inst_id) 45 | trade = self.app_context.inst_data_mgr.get_trade(new_ord_req.inst_id) 46 | bar = self.app_context.inst_data_mgr.get_bar(new_ord_req.inst_id) 47 | 48 | if not fill_info and config.fill_on_quote and config.fill_on_bar_mode == SimConfig.FillMode.LAST and quote: 49 | fill_info = self.process_w_market_data(new_ord_req, quote, True) 50 | elif not fill_info and config.fill_on_trade and config.fill_on_trade_mode == SimConfig.FillMode.LAST and trade: 51 | fill_info = self.process_w_market_data(new_ord_req, trade, True) 52 | elif not fill_info and config.fill_on_bar and config.fill_on_bar_mode == SimConfig.FillMode.LAST and bar: 53 | fill_info = self.process_w_market_data(new_ord_req, bar, True) 54 | 55 | return fill_info 56 | 57 | def process_w_market_data(self, new_ord_req, event, new_order=False): 58 | 59 | config = 
self.__sim_config 60 | 61 | if not event \ 62 | or (isinstance(event, Bar) and not config.fill_on_bar) \ 63 | or (isinstance(event, Trade) and not config.fill_on_trade) \ 64 | or (isinstance(event, Quote) and not config.fill_on_quote): 65 | return None 66 | 67 | if new_ord_req.type == Market: 68 | return self.__market_ord_handler.process(new_ord_req, event, new_order) 69 | elif new_ord_req.type == Limit: 70 | return self.__limit_ord_handler.process(new_ord_req, event, new_order) 71 | elif new_ord_req.type == StopLimit: 72 | return self.__stop_limit_ord_handler.process(new_ord_req, event, new_order) 73 | elif new_ord_req.type == Stop: 74 | return self.__stop_ord_handler.process(new_ord_req, event, new_order) 75 | elif new_ord_req.type == TrailingStop: 76 | return self.__trailing_stop_ord_handler.process(new_ord_req, event, new_order) 77 | assert False 78 | 79 | def process_w_price_qty(self, new_ord_req, price, qty): 80 | if new_ord_req.type == Market: 81 | return self.__market_ord_handler.process_w_price_qty(new_ord_req, price, qty) 82 | elif new_ord_req.type == Limit: 83 | return self.__limit_ord_handler.process_w_price_qty(new_ord_req, price, qty) 84 | elif new_ord_req.type == StopLimit: 85 | return self.__stop_limit_ord_handler.process_w_price_qty(new_ord_req, price, qty) 86 | elif new_ord_req.type == Stop: 87 | return self.__stop_ord_handler.process_w_price_qty(new_ord_req, price, qty) 88 | elif new_ord_req.type == TrailingStop: 89 | return self.__trailing_stop_ord_handler.process_w_price_qty(new_ord_req, price, qty) 90 | return None 91 | -------------------------------------------------------------------------------- /algotrader/provider/broker/sim/sim_config.py: -------------------------------------------------------------------------------- 1 | class SimConfig: 2 | class FillMode: 3 | LAST = 0 4 | NEXT_OPEN = 1 5 | NEXT_CLOSE = 2 6 | 7 | def __init__(self, partial_fill=True, 8 | fill_on_quote=True, 9 | fill_on_trade=True, 10 | fill_on_bar=True, 11 | 
fill_on_quote_mode=FillMode.LAST, 12 | fill_on_trade_mode=FillMode.LAST, 13 | fill_on_bar_mode=FillMode.LAST, 14 | bar_vol_ratio=1): 15 | self.partial_fill = partial_fill 16 | self.fill_on_quote = fill_on_quote 17 | self.fill_on_trade = fill_on_trade 18 | self.fill_on_bar = fill_on_bar 19 | self.fill_on_quote_mode = fill_on_quote_mode 20 | self.fill_on_trade_mode = fill_on_trade_mode 21 | self.fill_on_bar_mode = fill_on_bar_mode 22 | self.bar_vol_ratio = bar_vol_ratio if 0 < bar_vol_ratio <= 1 else 1 23 | -------------------------------------------------------------------------------- /algotrader/provider/broker/sim/slippage.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from algotrader.utils.trade_data import is_buy 4 | 5 | 6 | class Slippage(object): 7 | __metaclass__ = abc.ABCMeta 8 | 9 | def calc_price_w_bar(self, new_ord_req, price, qty, bar): 10 | return self.calc_price(new_ord_req, price, qty, bar.vol) 11 | 12 | def calc_price_w_quote(self, new_ord_req, price, qty, quote): 13 | if is_buy(new_ord_req): 14 | return self.calc_price(new_ord_req, price, qty, quote.bid_size) 15 | else: 16 | return self.calc_price(new_ord_req, price, qty, quote.ask_size) 17 | 18 | def calc_price_w_trade(self, new_ord_req, price, qty, trade): 19 | return self.calc_price(new_ord_req, price, qty, trade.size) 20 | 21 | @abc.abstractmethod 22 | def calc_price(self, new_ord_req, price, qty, avail_qty): 23 | raise NotImplementedError() 24 | 25 | 26 | class NoSlippage(Slippage): 27 | def calc_price(self, new_ord_req, price, qty, avail_qty): 28 | return price 29 | 30 | 31 | class VolumeShareSlippage(Slippage): 32 | def __init__(self, price_impact=0.1): 33 | self.price_impact = price_impact 34 | 35 | def calc_price(self, new_ord_req, price, qty, avail_qty): 36 | 37 | vol_share = float(qty) / float(avail_qty) 38 | impacted_price = vol_share ** 2 * self.price_impact 39 | if is_buy(new_ord_req): 40 | return price * (1 + 
impacted_price) 41 | else: 42 | return price * (1 - impacted_price) 43 | -------------------------------------------------------------------------------- /algotrader/provider/feed/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import pandas as pd 4 | 5 | from algotrader.model.market_data_pb2 import * 6 | from algotrader.model.model_factory import ModelFactory 7 | from algotrader.provider import Provider 8 | from algotrader.utils.date import datestr_to_unixtimemillis, datetime_to_unixtimemillis 9 | from algotrader.utils.market_data import D1 10 | 11 | 12 | class Feed(Provider): 13 | CSV = "CSV" 14 | PandasMemory = "PandasMemory" 15 | PandasH5 = "PandasH5" 16 | PandasWeb = "PandasWeb" 17 | Yahoo = "Yahoo" 18 | Google = "Google" 19 | Quandl = "Quandl" 20 | 21 | __metaclass__ = abc.ABCMeta 22 | 23 | def __init__(self): 24 | super(Provider, self).__init__() 25 | 26 | @abc.abstractmethod 27 | def subscribe_mktdata(self, *sub_reqs): 28 | raise NotImplementedError() 29 | 30 | @abc.abstractmethod 31 | def unsubscribe_mktdata(self, *sub_reqs): 32 | raise NotImplementedError() 33 | 34 | def _get_feed_config(self, path: str, default=None): 35 | return self.app_context.config.get_feed_config(self.id(), path, default=default) 36 | 37 | 38 | class PandasDataFeed(Feed): 39 | __metaclass__ = abc.ABCMeta 40 | 41 | def subscribe_mktdata(self, *sub_reqs): 42 | self._verify_subscription(*sub_reqs); 43 | sub_req_ranges = {} 44 | insts = {} 45 | for sub_req in sub_reqs: 46 | insts[sub_req.inst_id] = self.app_context.ref_data_mgr.get_inst(inst_id=sub_req.inst_id) 47 | sub_req_ranges[sub_req.inst_id] = ( 48 | datestr_to_unixtimemillis(str(sub_req.from_date)), datestr_to_unixtimemillis(str(sub_req.to_date))) 49 | 50 | dfs = self._load_dataframes(insts, *sub_reqs) 51 | self._publish(dfs, sub_req_ranges, insts) 52 | 53 | def _verify_subscription(self, *sub_reqs): 54 | for sub_req in sub_reqs: 55 | if not sub_req.from_date 
or sub_req.type != MarketDataSubscriptionRequest.Bar or sub_req.bar_type != Bar.Time or sub_req.bar_size != D1: 56 | raise RuntimeError("only HistDataSubscriptionKey is supported!") 57 | 58 | def _within_range(self, inst_id, timestamp, sub_req_ranges): 59 | sub_req_range = sub_req_ranges[inst_id] 60 | return timestamp >= sub_req_range[0] and (not sub_req_range[1] or timestamp < sub_req_range[1]) 61 | 62 | def _publish(self, dfs, sub_req_ranges, insts): 63 | df = pd.concat(dfs).sort_index(0, ascending=True) 64 | 65 | for index, row in df.iterrows(): 66 | inst = insts[row['InstId']] 67 | timestamp = datetime_to_unixtimemillis(index) 68 | if self._within_range(row['InstId'], timestamp, sub_req_ranges): 69 | bar = self._build_bar(row, timestamp) 70 | self.app_context.event_bus.data_subject.on_next(bar) 71 | 72 | def _build_bar(self, row, timestamp) -> Bar: 73 | return ModelFactory.build_bar( 74 | inst_id=row['InstId'], 75 | type=Bar.Time, 76 | provider_id=row['ProviderId'], 77 | timestamp=timestamp, 78 | open=row['Open'], 79 | high=row['High'], 80 | low=row['Low'], 81 | close=row['Close'], 82 | vol=row['Volume'], 83 | adj_close=row['Adj Close'] if 'Adj Close' in row else None, 84 | size=row['BarSize']) 85 | 86 | @abc.abstractmethod 87 | def _load_dataframes(self, insts, *sub_reqs): 88 | raise NotImplementedError() 89 | 90 | def unsubscribe_mktdata(self, *sub_reqs): 91 | pass 92 | 93 | def _stop(self): 94 | pass 95 | -------------------------------------------------------------------------------- /algotrader/provider/feed/csv.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from algotrader.provider.feed import Feed, PandasDataFeed 4 | from algotrader import Context 5 | 6 | 7 | class CSVDataFeed(PandasDataFeed): 8 | dateparse = lambda x: pd.datetime.strptime(x, '%Y-%m-%d') 9 | 10 | def __init__(self): 11 | super(CSVDataFeed, self).__init__() 12 | 13 | def _start(self, app_context : Context) -> None: 14 | 
self.path = self._get_feed_config("path") 15 | 16 | def id(self): 17 | return Feed.CSV 18 | 19 | def _load_dataframes(self, insts, *sub_reqs): 20 | dfs = [] 21 | for sub_req in sub_reqs: 22 | inst = insts[sub_req.inst_id] 23 | df = pd.read_csv('%s/%s.csv' % (self.path, inst.symbol.lower()), index_col='Date', parse_dates=['Date'], 24 | date_parser=CSVDataFeed.dateparse) 25 | df['InstId'] = sub_req.inst_id 26 | df['ProviderId'] = sub_req.md_provider_id 27 | df['BarSize'] = sub_req.bar_size 28 | dfs.append(df) 29 | return dfs 30 | -------------------------------------------------------------------------------- /algotrader/provider/feed/pandas_h5.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from algotrader import Context 4 | from algotrader.provider.feed import Feed, PandasDataFeed 5 | 6 | 7 | class PandaH5DataFeed(PandasDataFeed): 8 | """ 9 | This is a class to make a data feed from dataframe we already have in memory 10 | """ 11 | 12 | def __init__(self): 13 | super(PandaH5DataFeed, self).__init__() 14 | 15 | def _start(self, app_context: Context) -> None: 16 | self.h5file = self._get_feed_config("path") 17 | 18 | def id(self): 19 | return Feed.PandasH5 20 | 21 | def _load_dataframes(self, insts, *sub_reqs): 22 | dfs = [] 23 | with pd.HDFStore(self.h5file) as store: 24 | for sub_req in sub_reqs: 25 | df = store[sub_req.inst_id] 26 | df['InstId'] = sub_req.inst_id 27 | df['ProviderId'] = sub_req.md_provider_id 28 | df['BarSize'] = sub_req.bar_size 29 | dfs.append(df) 30 | return dfs 31 | -------------------------------------------------------------------------------- /algotrader/provider/feed/pandas_memory.py: -------------------------------------------------------------------------------- 1 | from algotrader import Context 2 | from algotrader.provider.feed import PandasDataFeed, Feed 3 | 4 | 5 | class PandasMemoryDataFeed(PandasDataFeed): 6 | """ 7 | This is a class to make a data feed from 
dataframe we already have in memory 8 | """ 9 | 10 | def __init__(self): 11 | super(PandasMemoryDataFeed, self).__init__() 12 | 13 | def _start(self, app_context: Context) -> None: 14 | pass 15 | 16 | def set_data_frame(self, dict_of_df): 17 | self.dict_of_df = dict_of_df 18 | 19 | def id(self): 20 | return Feed.PandasMemory 21 | 22 | def _load_dataframes(self, insts, *sub_reqs): 23 | dfs = [] 24 | for sub_req in sub_reqs: 25 | df = self.dict_of_df[sub_req.inst_id] 26 | df['InstId'] = sub_req.inst_id 27 | df['ProviderId'] = sub_req.md_provider_id 28 | df['BarSize'] = sub_req.bar_size 29 | 30 | dfs.append(df) 31 | return dfs 32 | -------------------------------------------------------------------------------- /algotrader/provider/feed/pandas_web.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from pandas_datareader import data 4 | 5 | from algotrader import Context 6 | from algotrader.provider.feed import Feed, PandasDataFeed 7 | from algotrader.utils.date import * 8 | 9 | 10 | class PandasWebDataFeed(PandasDataFeed): 11 | Supported = set(['Yahoo', 'Google']) 12 | 13 | def __init__(self): 14 | super(PandasWebDataFeed, self).__init__() 15 | 16 | def _start(self, app_context: Context) -> None: 17 | pass 18 | 19 | def id(self): 20 | return Feed.PandasWeb 21 | 22 | @abc.abstractmethod 23 | def process_row(self, row): 24 | raise NotImplementedError 25 | 26 | def _verify_subscription(self, *sub_reqs): 27 | super(PandasWebDataFeed, self)._verify_subscription(*sub_reqs) 28 | for sub_req in sub_reqs: 29 | if not sub_req.md_provider_id or sub_req.md_provider_id not in PandasWebDataFeed.Supported: 30 | raise RuntimeError("only yahoo and goolge is supported!") 31 | 32 | def _load_dataframes(self, insts, *sub_reqs): 33 | dfs = [] 34 | for sub_req in sub_reqs: 35 | inst = insts[sub_req.inst_id] 36 | df = data.DataReader(inst.symbol.lower(), sub_req.md_provider_id.lower(), 37 | datestr_to_date(str(sub_req.from_date)), 38 
| datestr_to_date(str(sub_req.to_date))) 39 | df['InstId'] = sub_req.inst_id 40 | df['ProviderId'] = sub_req.md_provider_id 41 | df['BarSize'] = sub_req.bar_size 42 | dfs.append(df) 43 | return dfs 44 | -------------------------------------------------------------------------------- /algotrader/strategy/alpha_formula.py: -------------------------------------------------------------------------------- 1 | from algotrader.event.order import OrdAction 2 | from algotrader.strategy.strategy import Strategy 3 | from algotrader.technical.roc import ROC 4 | from algotrader.technical.pipeline.pairwise import PairCorrelation 5 | from algotrader.technical.pipeline.make_vector import MakeVector 6 | from algotrader.technical.pipeline.rank import Rank 7 | from algotrader.utils import logger 8 | import numpy as np 9 | 10 | 11 | class AlphaFormula3(Strategy): 12 | def __init__(self, stg_id=None, stg_configs=None): 13 | super(AlphaFormula3, self).__init__(stg_id=stg_id, stg_configs=stg_configs) 14 | self.day_count = 0 15 | self.order = None 16 | 17 | def _start(self, app_context, **kwargs): 18 | self.length = self.get_stg_config_value("length", 10) 19 | 20 | self.bars = [self.app_context.inst_data_mgr.get_series( 21 | "Bar.%s.Time.300" % i) for i in self.app_context.app_config.instrument_ids] 22 | 23 | for bar in self.bars: 24 | bar.start(app_context) 25 | 26 | self.opens = MakeVector(self.bars, input_key='Open') 27 | self.volumes = MakeVector(self.bars, input_key="Volume") 28 | self.rank_opens = Rank(self.bars, input_key='open') 29 | self.rank_opens.start(app_context) 30 | 31 | self.rank_volumes = Rank(self.bars, input_key='Volume') 32 | self.rank_volumes.start(app_context) 33 | # 34 | self.pair_correlation = PairCorrelation(self.rank_opens, self.rank_volumes, length=self.length) 35 | self.pair_correlation.start(app_context) 36 | 37 | super(AlphaFormula3, self)._start(app_context, **kwargs) 38 | 39 | def _stop(self): 40 | super(AlphaFormula3, self)._stop() 41 | 42 | def 
on_bar(self, bar): 43 | # rank = self.rank_opens.now('value') 44 | # logger.info("[%s] %s" % (self.__class__.__name__, rank)) 45 | # if np.all(np.isnan(rank)): 46 | # return 47 | corr = self.pair_correlation.now('value') 48 | if np.any(np.isnan(corr)): 49 | return 50 | 51 | 52 | weight = [corr[i, i+2] for i in range(len(self.bars))] 53 | # weight = rank 54 | weight = -1*weight[0] 55 | 56 | portfolio = self.get_portfolio() 57 | allocation = portfolio.total_equity * weight 58 | delta = allocation - portfolio.stock_value 59 | 60 | index = self.app_context.app_config.instrument_ids.index(bar.inst_id) 61 | qty = delta[index] 62 | # logger.info("%s,B,%.2f" % (bar.timestamp, bar.close)) 63 | self.order = self.market_order(inst_id=bar.inst_id, action=OrdAction.BUY, qty=qty) if qty > 0 else \ 64 | self.market_order(inst_id=bar.inst_id, action=OrdAction.SELL, qty=-qty) 65 | 66 | 67 | -------------------------------------------------------------------------------- /algotrader/strategy/cross_sectional_mean_reverting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 10/10/16 3 | Author = jchan 4 | """ 5 | __author__ = 'jchan' 6 | -------------------------------------------------------------------------------- /algotrader/strategy/down_2pct_strategy.py: -------------------------------------------------------------------------------- 1 | from algotrader import Context 2 | from algotrader.model.trade_data_pb2 import * 3 | from algotrader.strategy import Strategy 4 | from algotrader.technical.roc import ROC 5 | 6 | 7 | class Down2PctStrategy(Strategy): 8 | def __init__(self, stg_id: str, stg_cls: str, state: StrategyState = None): 9 | super(Down2PctStrategy, self).__init__(stg_id=stg_id, stg_cls=stg_cls, state=state) 10 | self.day_count = 0 11 | self.order = None 12 | 13 | def _start(self, app_context: Context) -> None: 14 | self.qty = self._get_stg_config("qty", default=1) 15 | 16 | self.close = 
self.app_context.inst_data_mgr.get_series( 17 | "Bar.%s.Time.86400" % app_context.config.get_app_config("instrumentIds")[0]) 18 | self.close.start(app_context) 19 | 20 | self.roc = ROC(inputs=self.close, input_keys='close', length=1) 21 | self.roc.start(app_context) 22 | 23 | super(Down2PctStrategy, self)._start(app_context) 24 | 25 | def _stop(self): 26 | super(Down2PctStrategy, self)._stop() 27 | 28 | def on_bar(self, bar): 29 | if self.order is None: 30 | if self.roc.now('value') < -0.02: 31 | # logger.info("%s,B,%.2f" % (bar.timestamp, bar.close)) 32 | self.order = self.market_order(inst_id=bar.inst_id, action=Buy, qty=self.qty) 33 | self.day_count = 0 34 | else: 35 | self.day_count += 1 36 | if self.day_count >= 5: 37 | # logger.info("%s,S,%.2f" % (bar.timestamp, bar.close)) 38 | self.market_order(inst_id=bar.inst_id, action=Sell, qty=self.qty) 39 | self.order = None 40 | -------------------------------------------------------------------------------- /algotrader/strategy/ema_strategy.py: -------------------------------------------------------------------------------- 1 | from algotrader import Context 2 | from algotrader.model.trade_data_pb2 import * 3 | from algotrader.strategy import Strategy 4 | from algotrader.technical.talib_wrapper import EMA 5 | from algotrader.utils.logging import logger 6 | 7 | 8 | class EMAStrategy(Strategy): 9 | def __init__(self, stg_id: str, stg_cls: str, state: StrategyState = None): 10 | super(EMAStrategy, self).__init__(stg_id=stg_id, stg_cls=stg_cls, state=state) 11 | self.buy_order = None 12 | 13 | def _start(self, app_context: Context) -> None: 14 | 15 | self.instruments = app_context.config.get_app_config("instrumentIds") 16 | self.qty = self._get_stg_config("qty", default=1) 17 | 18 | self.bar = self.app_context.inst_data_mgr.get_series( 19 | "Bar.%s.Time.86400" % self.instruments[0]) 20 | self.bar.start(app_context) 21 | 22 | self.ema_fast = EMA(self.bar, 'close', 10) 23 | self.ema_fast.start(app_context) 24 | 25 | 
self.ema_slow = EMA(self.bar, 'close', 25) 26 | self.ema_slow.start(app_context) 27 | 28 | super(EMAStrategy, self)._start(app_context) 29 | 30 | def _stop(self): 31 | super(EMAStrategy, self)._stop() 32 | 33 | def on_bar(self, bar): 34 | if self.buy_order is None and self.ema_fast.now('value') > self.ema_slow.now('value'): 35 | self.buy_order = self.market_order(inst_id=bar.inst_id, action=Buy, qty=self.qty) 36 | logger.info("%s,B,%s,%s,%.2f,%.2f,%.2f" % ( 37 | bar.timestamp, self.buy_order.cl_id, self.buy_order.cl_ord_id, bar.close, self.ema_fast.now('value'), 38 | self.ema_slow.now('value'))) 39 | elif self.buy_order is not None and self.ema_fast.now('value') < self.ema_slow.now('value'): 40 | sell_order = self.market_order(inst_id=bar.inst_id, action=Sell, qty=self.qty) 41 | logger.info("%s,S,%s,%s,%.2f,%.2f,%.2f" % ( 42 | bar.timestamp, sell_order.cl_id, sell_order.cl_ord_id, bar.close, self.ema_fast.now('value'), 43 | self.ema_slow.now('value'))) 44 | -------------------------------------------------------------------------------- /algotrader/strategy/merton_optimal.py: -------------------------------------------------------------------------------- 1 | from algotrader import Context 2 | from algotrader.model.trade_data_pb2 import * 3 | from algotrader.strategy import Strategy 4 | 5 | 6 | class MertonOptimalBaby(Strategy): 7 | """ 8 | This is the baby version that assume appreciation rate and the volatility of the underlying is known 9 | in advance before constructing the strategy 10 | in reality this is not true 11 | So for more advanced version the strategy itself should able to call statistical inference logic to 12 | get the appreciation rate and volatility of the asset 13 | 14 | So now this class is used as testing purpose 15 | """ 16 | 17 | def __init__(self, stg_id: str, stg_cls: str, state: StrategyState = None): 18 | super(MertonOptimalBaby, self).__init__(stg_id=stg_id, stg_cls=stg_cls, state=state) 19 | self.buy_order = None 20 | 21 | def 
_start(self, app_context: Context) -> None: 22 | self.arate = self._get_stg_config("arate", default=1) 23 | self.vol = self._get_stg_config("vol", default=1) 24 | 25 | self.bar = self.app_context.inst_data_mgr.get_series( 26 | "Bar.%s.Time.86400" % app_context.config.get_app_config("instrumentIds")[0]) 27 | 28 | self.bar.start(app_context) 29 | 30 | self.optimal_weight = self.arate / self.vol ** 2 # assume risk free rate is zero 31 | 32 | super(MertonOptimalBaby, self)._start(app_context) 33 | 34 | def _stop(self): 35 | super(MertonOptimalBaby, self)._stop() 36 | 37 | def on_bar(self, bar): 38 | # we have to rebalance on each bar 39 | # print bar 40 | portfolio = self.get_portfolio() 41 | allocation = portfolio.total_equity * self.optimal_weight 42 | delta = allocation - portfolio.stock_value 43 | if delta > 0: 44 | qty = delta / bar.close # assume no lot size here 45 | self.market_order(inst_id=bar.inst_id, action=Buy, qty=qty) 46 | else: 47 | qty = -delta / bar.close # assume no lot size here 48 | self.market_order(inst_id=bar.inst_id, action=Sell, qty=qty) 49 | -------------------------------------------------------------------------------- /algotrader/strategy/pair_trading.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import rx 4 | from rx.subjects import BehaviorSubject 5 | 6 | from algotrader import Context 7 | from algotrader.model.trade_data_pb2 import * 8 | from algotrader.strategy import Strategy 9 | 10 | 11 | class PairTradingWithOUSpread(Strategy): 12 | """ 13 | This is the baby version that assume the asset we are trading paris 14 | that the spread follows Ornstein-Uhlenbeck mean reverting process with known parameters in advance 15 | in reality this is not true 16 | So for more advanced version the strategy itself should able to call statistical inference logic to 17 | get the appreciation rate and volatility of the asset 18 | 19 | So now this class is used as testing purpose 20 | """ 21 | 
22 | def __init__(self, stg_id: str, stg_cls: str, state: StrategyState = None): 23 | super(PairTradingWithOUSpread, self).__init__(stg_id=stg_id, stg_cls=stg_cls, state=state) 24 | self.buy_order = None 25 | 26 | def _start(self, app_context: Context) -> None: 27 | self.ou_params = self._get_stg_config("ou_params", default=1) 28 | self.gamma = self._get_stg_config("gamma", default=1) 29 | 30 | self.instruments = app_context.config.get_app_config("instrumentIds") 31 | self.bar_0 = self.app_context.inst_data_mgr.get_series( 32 | "Bar.%s.Time.86400" % self.instruments[0]) 33 | self.bar_1 = self.app_context.inst_data_mgr.get_series( 34 | "Bar.%s.Time.86400" % self.instruments[1]) 35 | 36 | self.bar_0.start(app_context) 37 | self.bar_1.start(app_context) 38 | 39 | self.log_spot_0 = BehaviorSubject(0) 40 | self.log_spot_1 = BehaviorSubject(0) 41 | self.spread_stream = rx.Observable \ 42 | .zip(self.log_spot_0, self.log_spot_1, lambda x, y: [x, y, x - y]) \ 43 | .subscribe(self.rebalance) 44 | 45 | super(PairTradingWithOUSpread, self)._start(app_context) 46 | 47 | def _stop(self): 48 | super(PairTradingWithOUSpread, self)._stop() 49 | 50 | def on_bar(self, bar): 51 | # logger.info("%s,%s,%.2f" % (bar.inst_id, bar.timestamp, bar.close)) 52 | if bar.inst_id == self.instruments[0]: 53 | self.log_spot_0.on_next(math.log(bar.close)) 54 | elif bar.inst_id == self.instruments[1]: 55 | self.log_spot_1.on_next(math.log(bar.close)) 56 | 57 | def rebalance(self, spread_triple): 58 | if spread_triple[0] == 0: 59 | return 60 | # we have to rebalance on each bar 61 | k = self.ou_params['k'] 62 | eta = self.ou_params['eta'] 63 | theta = self.ou_params['theta'] 64 | spread = spread_triple[2] 65 | 66 | weight = k * (spread - theta) / eta ** 2 67 | portfolio = self.get_portfolio() 68 | allocation_0 = -portfolio.total_equity * weight 69 | allocation_1 = portfolio.total_equity * weight 70 | # TODO: need to check if the portoflio.positions is empty 71 | delta_0 = allocation_0 72 | delta_1 = 
allocation_1 73 | if self.instruments[0] in portfolio.positions.keys(): 74 | delta_0 = allocation_0 - portfolio.positions[self.instruments[0]].current_value() 75 | if self.instruments[1] in portfolio.positions.keys(): 76 | delta_1 = allocation_1 - portfolio.positions[self.instruments[1]].current_value() 77 | 78 | qty = abs(delta_0) / spread_triple[0] # assume no lot size here 79 | if delta_0 > 0: 80 | self.market_order(inst_id=self.instruments[0], action=Buy, qty=qty) 81 | else: 82 | self.market_order(inst_id=self.instruments[0], action=Sell, qty=qty) 83 | 84 | qty = abs(delta_1) / spread_triple[1] # assume no lot size here 85 | if delta_1 > 0: 86 | self.market_order(inst_id=self.instruments[1], action=Buy, qty=qty) 87 | else: 88 | self.market_order(inst_id=self.instruments[1], action=Sell, qty=qty) 89 | -------------------------------------------------------------------------------- /algotrader/strategy/sma_strategy.py: -------------------------------------------------------------------------------- 1 | from algotrader import Context 2 | from algotrader.model.trade_data_pb2 import * 3 | from algotrader.strategy import Strategy 4 | from algotrader.technical.ma import SMA 5 | from algotrader.utils.logging import logger 6 | 7 | 8 | class SMAStrategy(Strategy): 9 | def __init__(self, stg_id: str, stg_cls: str, state: StrategyState = None): 10 | super(SMAStrategy, self).__init__(stg_id=stg_id, stg_cls=stg_cls, state=state) 11 | self.buy_order = None 12 | 13 | def _start(self, app_context: Context) -> None: 14 | self.instruments = app_context.config.get_app_config("instrumentIds") 15 | self.qty = self._get_stg_config("qty", default=1) 16 | self.bar = self.app_context.inst_data_mgr.get_series( 17 | "Bar.%s.Time.86400" % self.instruments[0]) 18 | 19 | self.sma_fast = SMA(self.bar, 'close', 10) 20 | self.sma_fast.start(app_context) 21 | 22 | self.sma_slow = SMA(self.bar, 'close', 25) 23 | self.sma_slow.start(app_context) 24 | 25 | super(SMAStrategy, 
self)._start(app_context) 26 | 27 | def _stop(self): 28 | super(SMAStrategy, self)._stop() 29 | 30 | def on_bar(self, bar): 31 | if self.buy_order is None and self.sma_fast.now('value') > self.sma_slow.now('value'): 32 | self.buy_order = self.market_order(inst_id=bar.inst_id, action=Buy, qty=self.qty) 33 | logger.info("%s,B,%s,%s,%.2f,%.2f,%.2f" % ( 34 | bar.timestamp, self.buy_order.cl_id, self.buy_order.cl_ord_id, bar.close, self.sma_fast.now('value'), 35 | self.sma_slow.now('value'))) 36 | elif self.buy_order is not None and self.sma_fast.now('value') < self.sma_slow.now('value'): 37 | sell_order = self.market_order(inst_id=bar.inst_id, action=Sell, qty=self.qty) 38 | logger.info("%s,S,%s,%s,%.2f,%.2f,%.2f" % ( 39 | bar.timestamp, sell_order.cl_id, sell_order.cl_ord_id, bar.close, self.sma_fast.now('value'), 40 | self.sma_slow.now('value'))) 41 | self.buy_order = None 42 | -------------------------------------------------------------------------------- /algotrader/strategy/vix_future.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 11/5/16 3 | Author = jchan 4 | """ 5 | __author__ = 'jchan' 6 | 7 | from algotrader.event.order import OrdAction 8 | from algotrader.strategy.strategy import Strategy 9 | from algotrader.technical.roc import ROC 10 | from algotrader.technical.pipeline import PipeLine 11 | from algotrader.technical.pipeline.pairwise import PairCorrelation 12 | from algotrader.technical.pipeline.make_vector import MakeVector 13 | from algotrader.technical.pipeline.rank import Rank 14 | from algotrader.technical.pipeline.cross_sessional_apply import Delta, Log 15 | from algotrader.technical.pipeline.pairwise import Minus, Divides 16 | from algotrader.utils import logger 17 | import numpy as np 18 | 19 | class VIXFuture(Strategy): 20 | def __init__(self, stg_id=None, stg_configs=None): 21 | super(VIXFuture, self).__init__(stg_id=stg_id, stg_configs=stg_configs) 22 | self.day_count = 0 23 | 
self.order = None 24 | 25 | def _start(self, app_context, **kwargs): 26 | self.length = self.get_stg_config_value("length", 10) 27 | 28 | self.bars = [self.app_context.inst_data_mgr.get_series( 29 | "Bar.%s.Time.300" % i) for i in self.app_context.app_config.instrument_ids] 30 | 31 | for bar in self.bars: 32 | bar.start(app_context) 33 | 34 | self.opens = MakeVector(self.bars, input_key='Open') 35 | self.volumes = MakeVector(self.bars, input_key="Volume") 36 | self.rank_opens = Rank(self.bars, input_key='open') 37 | self.rank_opens.start(app_context) 38 | 39 | self.rank_volumes = Rank(self.bars, input_key='Volume') 40 | self.rank_volumes.start(app_context) 41 | # 42 | self.pair_correlation = PairCorrelation(self.rank_opens, self.rank_volumes, length=self.length) 43 | self.pair_correlation.start(app_context) 44 | 45 | super(AlphaFormula3, self)._start(app_context, **kwargs) 46 | 47 | def _stop(self): 48 | super(AlphaFormula3, self)._stop() 49 | 50 | def on_bar(self, bar): 51 | # rank = self.rank_opens.now('value') 52 | # logger.info("[%s] %s" % (self.__class__.__name__, rank)) 53 | # if np.all(np.isnan(rank)): 54 | # return 55 | corr = self.pair_correlation.now('value') 56 | if np.any(np.isnan(corr)): 57 | return 58 | 59 | weight = [corr[i, i+2] for i in range(len(self.bars))] 60 | # weight = rank 61 | weight = -1*weight[0] 62 | 63 | portfolio = self.get_portfolio() 64 | allocation = portfolio.total_equity * weight 65 | delta = allocation - portfolio.stock_value 66 | 67 | index = self.app_context.app_config.instrument_ids.index(bar.inst_id) 68 | qty = delta[index] 69 | # logger.info("%s,B,%.2f" % (bar.timestamp, bar.close)) 70 | self.order = self.market_order(inst_id=bar.inst_id, action=OrdAction.BUY, qty=qty) if qty > 0 else \ 71 | self.market_order(inst_id=bar.inst_id, action=OrdAction.SELL, qty=-qty) 72 | 73 | -------------------------------------------------------------------------------- /algotrader/technical/atr.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from algotrader.technical import Indicator 4 | from algotrader.technical.ma import SMA 5 | 6 | 7 | class ATR(Indicator): 8 | __slots__ = ( 9 | 'length', 10 | '__prev_close', 11 | '__value', 12 | '__average', 13 | ) 14 | 15 | def __init__(self, time_series=None, inputs=None, input_keys=['high', 'low', 'close'], desc="Average True Range", 16 | length=14): 17 | super(ATR, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 18 | length=length) 19 | self.length = self.get_int_config("length", 14) 20 | self.__prev_close = None 21 | self.__value = None 22 | self.__average = SMA(inputs=inputs, length=self.length) 23 | 24 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 25 | sma_input = {} 26 | high = data['high'] 27 | low = data['low'] 28 | close = data['close'] 29 | 30 | if self.__prev_close is None: 31 | tr = high - low 32 | else: 33 | tr1 = high - low 34 | tr2 = abs(high - self.__prev_close) 35 | tr3 = abs(low - self.__prev_close) 36 | tr = max(max(tr1, tr2), tr3) 37 | 38 | self.__prev_close = close 39 | 40 | sma_input[Indicator.VALUE] = tr 41 | self.__average.add(timestamp=timestamp, data=sma_input) 42 | 43 | result = {} 44 | # result['timestamp'] = data['timestamp'] 45 | result[Indicator.VALUE] = self.__average.now(Indicator.VALUE) 46 | self.add(timestamp=timestamp, data=result) 47 | -------------------------------------------------------------------------------- /algotrader/technical/bb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | from algotrader.technical import Indicator 5 | from algotrader.technical.ma import SMA 6 | from algotrader.technical.stats import STD 7 | 8 | 9 | class BB(Indicator): 10 | UPPER = 'uppper' 11 | LOWER = 'lower' 12 | 13 | __slots__ = ( 14 | 'length', 15 | 
'num_std' 16 | '__sma', 17 | '__std_dev', 18 | ) 19 | 20 | def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Bollinger Bands", length=14, num_std=3): 21 | super(SMA, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 22 | length=length, num_std=num_std) 23 | self.length = self.get_int_config("length", 14) 24 | self.num_std = self.get_int_config("num_std", 3) 25 | self.__sma = SMA(inputs=inputs, length=self.length) 26 | self.__std_dev = STD(inputs=inputs, length=self.length) 27 | 28 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 29 | result = {} 30 | sma = self.__sma.now(self.input_keys[0]) 31 | std = self.__std_dev.now(self.input_keys[0]) 32 | if not np.isnan(sma): 33 | upper = sma + std * self.num_std 34 | lower = sma - std * self.num_std 35 | 36 | result[BB.UPPER] = upper 37 | result[BB.LOWER] = lower 38 | result[Indicator.VALUE] = sma 39 | else: 40 | result[BB.UPPER] = np.nan 41 | result[BB.LOWER] = np.nan 42 | result[Indicator.VALUE] = np.nan 43 | 44 | self.add(timestamp=timestamp, data=result) 45 | -------------------------------------------------------------------------------- /algotrader/technical/historical_volatility.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from typing import Dict 5 | 6 | from algotrader.technical import Indicator 7 | 8 | 9 | class HistoricalVolatility(Indicator): 10 | __slots__ = ( 11 | 'length' 12 | 'ann_factor' 13 | ) 14 | 15 | def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Historical Volatility", length=0, ann_factor=252): 16 | super(HistoricalVolatility, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 17 | length=length,ann_factor=ann_factor) 18 | self.length = self.get_int_config("length", 0) 19 | self.ann_factor = self.get_int_config("ann_factor", 252) 20 | 21 | def 
_process_update(self, source: str, timestamp: int, data: Dict[str, float]): 22 | result = {} 23 | if self.first_input.size() >= self.length: 24 | sum_ret_sq = 0.0 25 | for idx in range(self.first_input.size() - self.length + 1, self.first_input.size()): 26 | x_t = self.first_input.get_by_idx(idx, self.first_input_keys[0]) 27 | x_t_1 = self.first_input.get_by_idx(idx - 1, self.first_input_keys[0]) 28 | ret = math.log(x_t / x_t_1) 29 | sum_ret_sq += ret ** 2 30 | result[Indicator.VALUE] = math.sqrt(self.ann_factor * sum_ret_sq / self.length) 31 | else: 32 | result[Indicator.VALUE] = np.nan 33 | 34 | self.add(timestamp=timestamp, data=result) 35 | -------------------------------------------------------------------------------- /algotrader/technical/kfpairregression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 7/7/16 3 | Author = jchan 4 | """ 5 | __author__ = 'jchan' 6 | 7 | import numpy as np 8 | from pykalman import KalmanFilter 9 | from typing import Dict 10 | 11 | from algotrader.technical import Indicator 12 | 13 | 14 | class KalmanFilteringPairRegression(Indicator): 15 | SLOPE = 'slope' 16 | INTERCEPT = 'intercept' 17 | 18 | __slots__ = ( 19 | 'length' 20 | ) 21 | 22 | def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Kalman Filter Regression", length=10): 23 | super(KalmanFilteringPairRegression, self).__init__(time_series=time_series, inputs=inputs, 24 | input_keys=input_keys, desc=desc, 25 | keys=['slope', 'intercept'], 26 | default_key='slope', 27 | length=length) 28 | self.length = self.get_int_config("length", 10) 29 | delta = 1e-5 30 | self.trans_cov = delta / (1 - delta) * np.eye(2) 31 | # super(KalmanFilteringPairRegression, self)._update_from_inputs() 32 | 33 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 34 | result = {} 35 | 36 | if input.size() >= self.length: 37 | 38 | independent_var = self.first_input.get_by_idx_range(key=None, 
start_idx=0, end_idx=-1) 39 | symbol_set = set(self.first_input.keys) 40 | depend_symbol = symbol_set.difference(self.first_input.default_key) 41 | depend_var = self.first_input.get_by_idx_range(key=depend_symbol, start_idx=0, end_idx=-1) 42 | 43 | obs_mat = np.vstack([independent_var.values, np.ones(independent_var.values.shape)]).T[:, np.newaxis] 44 | model = KalmanFilter(n_dim_obs=1, n_dim_state=2, 45 | initial_state_mean=np.zeros(2), 46 | initial_state_covariance=np.ones((2, 2)), 47 | transition_matrices=np.eye(2), 48 | observation_matrices=obs_mat, 49 | observation_covariance=1.0, 50 | transition_covariance=self.trans_cov) 51 | 52 | state_means, state_covs = model.filter(depend_var.values) 53 | slope = state_means[:, 0][-1] 54 | result[Indicator.VALUE] = slope 55 | result[KalmanFilteringPairRegression.SLOPE] = slope 56 | result[KalmanFilteringPairRegression.SLOPE] = state_means[:, 1][-1] 57 | self.add(timestamp=timestamp, data=result) 58 | 59 | else: 60 | result[Indicator.VALUE] = np.nan 61 | result[KalmanFilteringPairRegression.SLOPE] = np.nan 62 | result[KalmanFilteringPairRegression.SLOPE] = np.nan 63 | self.add(timestamp=timestamp, data=result) 64 | -------------------------------------------------------------------------------- /algotrader/technical/ma.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | from algotrader.technical import Indicator 5 | 6 | 7 | class SMA(Indicator): 8 | Length = 10 9 | __slots__ = ( 10 | 'length' 11 | ) 12 | 13 | def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Simple Moving Average", length=0): 14 | super(SMA, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 15 | length=length) 16 | self.length = self.get_int_config("length", SMA.Length) 17 | 18 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 19 | result = {} 20 | if 
self.first_input.size() >= self.length: 21 | value = 0.0 22 | for idx in range(self.first_input.size() - self.length, self.first_input.size()): 23 | value += self.first_input.get_by_idx(idx, self.first_input_keys) 24 | value = round(value / float(self.length), 8) 25 | result[Indicator.VALUE] = value 26 | else: 27 | result[Indicator.VALUE] = np.nan 28 | 29 | self.add(timestamp=timestamp, data=result) 30 | -------------------------------------------------------------------------------- /algotrader/technical/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | from typing import Dict, List 4 | 5 | from algotrader import Context 6 | from algotrader.technical import Indicator 7 | 8 | 9 | class PipeLine(Indicator): 10 | def __init__(self, time_series=None, inputs=None, input_keys=None, desc=None, 11 | keys: List[str] = None, default_output_key: str = 'value', **kwargs): 12 | 13 | super(PipeLine, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 14 | keys=keys, default_output_key=default_output_key, **kwargs) 15 | 16 | self.length = self.get_int_config("length", 1) 17 | self.__curr_timestamp = None 18 | 19 | def _start(self, app_context: Context) -> None: 20 | super(PipeLine, self)._start(self.app_context) 21 | self.numPipes = len(self.input_series) 22 | self._flush_and_create() 23 | 24 | def _stop(self): 25 | pass 26 | 27 | def _flush_and_create(self): 28 | self.cache = OrderedDict(zip(list(self.input_names_pos.keys()), [None for _ in range(len(self.input_series))])) 29 | 30 | def all_filled(self): 31 | """ 32 | PipeLine specify function, check in all input in self.inputs have been updated 33 | :return: 34 | """ 35 | has_none = np.sum(np.array([v is None for v in self.cache.values()])) 36 | return False if has_none > 0 else True 37 | 38 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 
39 | if timestamp != self.__curr_timestamp: 40 | self.__curr_timestamp = timestamp 41 | self._flush_and_create() 42 | 43 | if source in self.input_names_pos: 44 | idx = self.input_names_pos[source] 45 | self.cache[source] = self.get_input(idx).get_by_idx( 46 | keys=self.get_input_keys(idx=idx), 47 | idx=slice(-self.length, None, None)) 48 | 49 | def numPipes(self): 50 | return self.numPipes 51 | 52 | def shape(self): 53 | raise NotImplementedError() 54 | -------------------------------------------------------------------------------- /algotrader/technical/pipeline/corr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | from algotrader.technical.pipeline import PipeLine 5 | 6 | 7 | class Corr(PipeLine): 8 | def __init__(self, time_series=None, inputs=None, input_keys='close', 9 | desc="Correlation", length=30): 10 | super(Corr, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 11 | length=length) 12 | 13 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 14 | super(Corr, self)._process_update(source=source, timestamp=timestamp, data=data) 15 | result = {} 16 | if self.inputs[0].size() > self.length: 17 | if self.all_filled(): 18 | result[PipeLine.VALUE] = self.df.corr() 19 | else: 20 | result[PipeLine.VALUE] = self._default_output() 21 | else: 22 | result[PipeLine.VALUE] = self._default_output() 23 | self.add(timestamp=timestamp, data=result) 24 | 25 | def _default_output(self): 26 | na_array = np.empty(shape=self.shape()) 27 | na_array[:] = np.nan 28 | return na_array 29 | 30 | def shape(self): 31 | return np.array([self.numPipes, self.numPipes]) 32 | -------------------------------------------------------------------------------- /algotrader/technical/pipeline/make_vector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import 
Dict 3 | 4 | from algotrader.technical.pipeline import PipeLine 5 | 6 | 7 | class MakeVector(PipeLine): 8 | def __init__(self, time_series=None, inputs=None, input_keys='close', 9 | desc="Bundle and Sync DataSeries to Vector"): 10 | super(MakeVector, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 11 | length=1) 12 | 13 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 14 | super(MakeVector, self)._process_update(source=source, timestamp=timestamp, data=data) 15 | result = {} 16 | if self.get_input(0).size() >= self.length: 17 | if self.all_filled(): 18 | packed_matrix = np.transpose(np.array(self.cache.values())) 19 | result[PipeLine.VALUE] = packed_matrix 20 | else: 21 | result[PipeLine.VALUE] = self._default_output() 22 | else: 23 | result[PipeLine.VALUE] = self._default_output() 24 | 25 | self.add(timestamp=timestamp, data=result) 26 | 27 | def _default_output(self): 28 | na_array = np.empty(shape=self.shape()) 29 | na_array[:] = np.nan 30 | return na_array 31 | 32 | def shape(self): 33 | return np.array([1, self.numPipes]) 34 | -------------------------------------------------------------------------------- /algotrader/technical/pipeline/rank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from typing import Dict 4 | 5 | from algotrader.technical.pipeline import PipeLine 6 | 7 | 8 | class Rank(PipeLine): 9 | def __init__(self, time_series=None, inputs=None, input_keys='close', desc="Rank", ascending=True): 10 | super(Rank, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 11 | ascending=ascending) 12 | 13 | self.ascending = self.get_bool_config("ascending", True) 14 | 15 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 16 | super(Rank, self)._process_update(source=source, timestamp=timestamp, data=data) 17 | result = {} 18 | if 
self.all_filled(): 19 | df = pd.DataFrame(self.cache) 20 | result[PipeLine.VALUE] = ((df.rank(axis=1, ascending=self.ascending) - 1) / (df.shape[1] - 1)).tail( 21 | 1).values 22 | else: 23 | result[PipeLine.VALUE] = self._default_output() 24 | 25 | self.add(timestamp=timestamp, data=result) 26 | 27 | def _default_output(self): 28 | na_array = np.empty(shape=self.shape()) 29 | na_array[:] = np.nan 30 | return na_array 31 | 32 | def shape(self): 33 | return np.array([1, self.numPipes]) 34 | -------------------------------------------------------------------------------- /algotrader/technical/roc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | from algotrader.technical import Indicator 5 | 6 | 7 | def roc(prev_value, curr_value): 8 | if prev_value != 0.0: 9 | return (curr_value - prev_value) / prev_value 10 | return np.nan 11 | 12 | 13 | class ROC(Indicator): 14 | __slots__ = ( 15 | 'length' 16 | ) 17 | 18 | def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Rate Of Change", length=1): 19 | super(ROC, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc, 20 | length=length) 21 | self.length = self.get_int_config("length", 1) 22 | 23 | def _process_update(self, source: str, timestamp: int, data: Dict[str, float]): 24 | result = {} 25 | if self.first_input.size() > self.length: 26 | prev_value = self.first_input.ago(self.length, self.first_input_keys) 27 | curr_value = self.first_input.now(self.first_input_keys) 28 | result[Indicator.VALUE] = roc(prev_value, curr_value) 29 | else: 30 | result[Indicator.VALUE] = np.nan 31 | 32 | self.add(timestamp=timestamp, data=result) 33 | -------------------------------------------------------------------------------- /algotrader/technical/rolling_apply.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 
from typing import Dict

from algotrader.technical import Indicator


class RollingApply(Indicator):
    """Apply ``func`` to the trailing ``length`` values of the first input."""
    __slots__ = (
        'length',
        'func'
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Rolling Apply", length=0, func=np.std):
        super(RollingApply, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                           length=length)
        self.func = func
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            # NOTE: with length == 0 this slice is [0:], i.e. the whole history.
            window = self.first_input.get_by_idx(keys=self.first_input_keys, idx=slice(-self.length, None, None))
            result[Indicator.VALUE] = self.func(window)
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)


class StdDev(RollingApply):
    def __init__(self, time_series=None, inputs=None, input_keys='close', desc="Rolling Standard Deviation", length=30):
        super(StdDev, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                     length=length, func=lambda x: np.std(x, axis=0))


def pd_skew_wrapper(x):
    """Sample skewness of ``x`` via ``pandas.Series.skew``."""
    return pd.Series(x).skew()


def pd_kurtosis_wrapper(x):
    """Excess kurtosis of ``x`` via ``pandas.Series.kurtosis``."""
    return pd.Series(x).kurtosis()


def pd_kurt_wrapper(x):
    """Excess kurtosis of ``x`` via ``pandas.Series.kurt`` (alias of ``kurtosis``)."""
    return pd.Series(x).kurt()


class Skew(RollingApply):
    def __init__(self, time_series=None, inputs=None, input_keys='close', desc="Rolling Skew", length=30):
        super(Skew, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                   length=length, func=pd_skew_wrapper)


class Kurtosis(RollingApply):
    def __init__(self, time_series=None, inputs=None, input_keys='close', desc="Rolling Kurtosis", length=30):
        super(Kurtosis, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                       length=length, func=pd_kurtosis_wrapper)


class Kurt(RollingApply):
    def __init__(self, time_series=None, inputs=None, input_keys='close', desc="Rolling Kurt", length=30):
        super(Kurt, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                   length=length, func=pd_kurt_wrapper)

# ------------------------------------------------------------------------------
# /algotrader/technical/rsi.py
# ------------------------------------------------------------------------------
import numpy as np
from typing import Dict

from algotrader.technical import Indicator


def gain_loss(prev_value, next_value):
    """Split the move from ``prev_value`` to ``next_value`` into non-negative ``(gain, loss)``."""
    change = next_value - prev_value
    if change < 0:
        return 0, abs(change)
    return change, 0


# [begin, end)
def avg_gain_loss(series, input_key, begin, end):
    """Mean gain and mean loss over the consecutive changes inside ``[begin, end)``.

    Returns ``(0, 0)`` when the range holds fewer than two points.
    """
    range_len = end - begin
    if range_len < 2:
        return 0, 0

    gain = 0
    loss = 0
    for i in range(begin + 1, end):
        curr_gain, curr_loss = gain_loss(series[i - 1, input_key], series[i, input_key])
        gain += curr_gain
        loss += curr_loss
    return gain / float(range_len - 1), loss / float(range_len - 1)


def rsi(values, input_key, length):
    """Wilder RSI of ``values[:, input_key]``; NaN until more than ``length`` points exist.

    The averages are seeded with the first ``length`` changes (``length + 1`` points)
    and then smoothed with ``(prev * (length - 1) + current) / length``.
    """
    assert (length > 1)
    if len(values) <= length:
        return np.nan

    avg_gain, avg_loss = avg_gain_loss(values, input_key, 0, length + 1)
    for i in range(length + 1, len(values)):
        gain, loss = gain_loss(values[i - 1, input_key], values[i, input_key])
        avg_gain = (avg_gain * (length - 1) + gain) / float(length)
        avg_loss = (avg_loss * (length - 1) + loss) / float(length)

    if avg_loss == 0:
        return 100
    rs = avg_gain / avg_loss
    return 100 - 100 / (1 + rs)


class RSI(Indicator):
    """Streaming Wilder RSI of the first input."""
    __slots__ = (
        'length',
        '__prev_gain',
        '__prev_loss'
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Relative Strength Indicator", length=14):
        super(RSI, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                  length=length)
        self.length = self.get_int_config("length", 14)
        self.__prev_gain = None
        self.__prev_loss = None

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        """Seed the averages from the full history on the first eligible update, then smooth incrementally."""
        result = {}
        if self.first_input.size() > self.length:
            key = self.first_input_keys[0]
            if self.__prev_gain is None:
                avg_gain, avg_loss = avg_gain_loss(self.first_input, key, 0, self.first_input.size())
            else:
                curr_gain, curr_loss = gain_loss(self.first_input.ago(1, key), self.first_input.now(key))
                avg_gain = (self.__prev_gain * (self.length - 1) + curr_gain) / float(self.length)
                avg_loss = (self.__prev_loss * (self.length - 1) + curr_loss) / float(self.length)

            if avg_loss == 0:
                rsi_value = 100
            else:
                rs = avg_gain / avg_loss
                rsi_value = 100 - 100 / (1 + rs)
            self.__prev_gain = avg_gain
            self.__prev_loss = avg_loss

            result[Indicator.VALUE] = rsi_value
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)

# ------------------------------------------------------------------------------
# /algotrader/technical/stats.py
# ------------------------------------------------------------------------------
import numpy as np
from typing import Dict

from algotrader.technical import Indicator


class MAX(Indicator):
    """Maximum of the first input over its trailing ``length`` points."""
    __slots__ = (
        'length',
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Maximum", length=0):
        super(MAX, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                  length=length)
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            result[Indicator.VALUE] = self.first_input.max(-self.length, self.first_input_keys[0])
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)


class MIN(Indicator):
    """Minimum of the first input over its trailing ``length`` points."""
    __slots__ = (
        'length',
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Minimum", length=0):
        super(MIN, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                  length=length)
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            result[Indicator.VALUE] = self.first_input.min(-self.length, self.first_input_keys[0])
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)


class STD(Indicator):
    """Standard deviation of the first input over its trailing ``length`` points."""
    __slots__ = (
        'length',
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Standard Deviation", length=0):
        super(STD, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                  length=length)
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            result[Indicator.VALUE] = self.first_input.std(-self.length, self.first_input_keys[0])
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)


class VAR(Indicator):
    """Variance of the first input over its trailing ``length`` points."""
    __slots__ = (
        'length',
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="Variance", length=0):
        super(VAR, self).__init__(time_series=time_series, inputs=inputs, input_keys=input_keys, desc=desc,
                                  length=length)
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            # Variance is the square of the series' own std, so both stay on the same ddof convention.
            result[Indicator.VALUE] = self.first_input.std(-self.length, self.first_input_keys[0]) ** 2
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)

# ------------------------------------------------------------------------------
# /algotrader/technical/talib_wrapper.py
# ------------------------------------------------------------------------------
import numpy as np
import talib
from typing import Dict

from algotrader.technical import Indicator


def _ds_field_to_numpy(ds, field, idx):
    """Return ``ds`` values at ``idx`` for the first key equal to ``field`` (case-insensitive), else None."""
    for k in ds.keys:
        if k.lower() == field:
            return np.array(ds.get_by_idx(keys=k, idx=idx))
    return None


def ds_to_high_numpy(ds, idx):
    return _ds_field_to_numpy(ds, 'high', idx)


def ds_to_low_numpy(ds, idx):
    return _ds_field_to_numpy(ds, 'low', idx)


def ds_to_open_list(ds, idx):
    return _ds_field_to_numpy(ds, 'open', idx)


def ds_to_close_numpy(ds, idx):
    return _ds_field_to_numpy(ds, 'close', idx)


def ds_to_volume_numpy(ds, idx):
    return _ds_field_to_numpy(ds, 'volume', idx)


def call_talib_with_hlcv(ds, count, talib_func, *args, **kwargs):
    """Call ``talib_func(high, low, close, volume, ...)`` on the last ``count`` points.

    Returns None when any of the four columns is missing from ``ds``.
    """
    idx = slice(-count, None, None)
    high = ds_to_high_numpy(ds, idx)
    low = ds_to_low_numpy(ds, idx)
    close = ds_to_close_numpy(ds, idx)
    volume = ds_to_volume_numpy(ds, idx)

    if any(arr is None for arr in (high, low, close, volume)):
        return None

    return talib_func(high, low, close, volume, *args, **kwargs)


class SMA(Indicator):
    """TA-Lib simple moving average of the first input."""
    __slots__ = (
        'length',
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="TALib Simple Moving Average", length=0):
        if time_series:
            super(SMA, self).__init__(time_series=time_series)
        else:
            super(SMA, self).__init__(inputs=inputs, input_keys=input_keys, desc=desc, length=length)
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            window = np.array(self.first_input.get_by_idx(keys=self.first_input_keys,
                                                          idx=slice(-self.length, None, None)))
            result[Indicator.VALUE] = talib.SMA(window, timeperiod=self.length)[-1]
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)


class EMA(Indicator):
    """TA-Lib exponential moving average of the first input (computed over the trailing window only)."""
    __slots__ = (
        'length',
    )

    def __init__(self, time_series=None, inputs=None, input_keys=None, desc="TALib Exponential Moving Average",
                 length=0):
        if time_series:
            super(EMA, self).__init__(time_series=time_series)
        else:
            super(EMA, self).__init__(inputs=inputs, input_keys=input_keys, desc=desc, length=length)
        self.length = self.get_int_config("length", 0)

    def _process_update(self, source: str, timestamp: int, data: Dict[str, float]):
        result = {}
        if self.first_input.size() >= self.length:
            window = np.array(self.first_input.get_by_idx(keys=self.first_input_keys,
                                                          idx=slice(-self.length, None, None)))
            result[Indicator.VALUE] = talib.EMA(window, timeperiod=self.length)[-1]
        else:
            result[Indicator.VALUE] = np.nan

        self.add(timestamp=timestamp, data=result)


single_ds_list = ["APO", "BBANDS", "CMO", "DEMA", "EMA", "HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE",
                  "HT_TRENDLINE", "HT_TRENDMODE", "KAMA", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT"]

# ------------------------------------------------------------------------------
# /algotrader/technical/talib_wrapper_gen.py
# ------------------------------------------------------------------------------
"""

Generator of TALib Wrapper Class code

"""
from jinja2 import Template

# NOTE(review): the generated code targets an older Indicator API (on_update/update_all/self.input)
# and references ``length`` even when it is not among ``params``; regenerate before use.
indicatorTmp = Template("""
class {{IndicatorClass}}(Indicator):
    __slots__ = (
    {% for i in params %}'{{i}}'{% if not loop.last %},{% endif %}{% endfor %}
    )

    def __init__(self, input, input_key=None, {% for i in params %}{{i}}, {% endfor %} desc="{{description}}"):
        super({{IndicatorClass}}, self).__init__(Indicator.get_name({{IndicatorClass}}.__name__, input, input_key, length), input, input_key, desc)
        self.length = int(length)
        {% for p in params %}self.{{p}} = {{p}}
        {% endfor %}
        super({{IndicatorClass}}, self).update_all()

    def on_update(self, data):

        result = {}
        result['timestamp'] = data['timestamp']
        if self.input.size() >= self.length:
            value = talib.{{IndicatorClass}}(
                np.array(
                    self.input.get_by_idx(keys=self.input_keys,
                                          idx=slice(-self.length, None, None))), {% for p in params %} {{p}}=self.{{p}}{% if not loop.last %},{% endif %}{% endfor %})

            result[Indicator.VALUE] = value[-1]
        else:
            result[Indicator.VALUE] = np.nan

        self.add(result)
""")

single_ds_list = ["APO", "BBANDS", "CMO", "DEMA", "EMA", "HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE",
                  "HT_TRENDLINE", "HT_TRENDMODE", "KAMA", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT"]

if __name__ == "__main__":
    # Demo render only when run as a script, so importing the module has no side effects.
    print(indicatorTmp.render({"IndicatorClass": "APO",
                               "description": "apo test",
                               "params": ["fastperiod", "slowperiod", "matype"]}))

# ------------------------------------------------------------------------------
# /algotrader/trading/__init__.py: not inlined in this dump; source at
# https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/algotrader/trading/__init__.py
# ------------------------------------------------------------------------------

# ------------------------------------------------------------------------------
# /algotrader/trading/account.py
# ------------------------------------------------------------------------------
from typing import Dict

from algotrader import SimpleManager, Context, Startable, HasId
from algotrader.model.model_factory import ModelFactory
from algotrader.model.trade_data_pb2 import *
from algotrader.provider.datastore import PersistenceMode


class Account(Startable, HasId):
    """Runtime wrapper around an account-state protobuf message."""

    def __init__(self, acct_id: str, values: Dict[str, AccountValue] = None, state=None):
        # TODO load from DB
        self.state = state if state else ModelFactory.build_account_state(acct_id=acct_id, values=values)

    def on_acc_upd(self, account_update: AccountUpdate) -> None:
        """Apply every value carried by ``account_update`` to this account's state."""
        for update_value in account_update.values.values():
            ModelFactory.update_account_value(self.state.values[update_value.key], update_value.key,
                                              update_value.ccy_values)

    def id(self) -> str:
        return self.state.acct_id

    def _start(self, app_context: Context) -> None:
        # TODO
        pass


class AccountManager(SimpleManager):
    """Registry of accounts, optionally persisted through the configured data store."""

    def __init__(self):
        super(AccountManager, self).__init__()
        self.store = None
        # Set in _start; initialised here so add()/save_all() never hit a missing attribute.
        self.persist_mode = None

    def _start(self, app_context: Context) -> None:
        self.store = self.app_context.get_data_store()
        self.persist_mode = self.app_context.config.get_app_config("persistenceMode")
        self.load_all()

    def load_all(self):
        """Load every stored account state and register it."""
        if self.store:
            self.store.start(self.app_context)
            account_states = self.store.load_all('accounts')
            for account_state in account_states:
                # new_account() already registers the account; adding it again would save it twice.
                self.new_account(account_state.acct_id, state=account_state)

    def save_all(self):
        if self.store and self.persist_mode != PersistenceMode.Disable:
            for account in self.all_items():
                self.store.save_account(account)

    def add(self, account):
        """Register ``account`` and, in RealTime persistence mode, save it immediately."""
        super(AccountManager, self).add(account)
        if self.store and self.persist_mode == PersistenceMode.RealTime:
            self.store.save_account(account)

    def id(self):
        return "AccountManager"

    def new_account(self, acct_id, values=None, state=None):
        """Create, register and return a new Account."""
        account = Account(acct_id, values=values, state=state)
        self.add(account)
        return account

# ------------------------------------------------------------------------------
# /algotrader/trading/config.py
# ------------------------------------------------------------------------------
import json

import yaml
from typing import Dict

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper


class Config(object):
    """Nested-dict configuration addressed by dotted paths such as ``"Application.clockId"``.

    Configs passed to the constructor are merged in order, so later ones override earlier ones.
    """

    def __init__(self, *configs: Dict):
        self.config = {}
        for config in configs:
            merge(config, self.config)

    def get(self, paths, default=None):
        """Return the value at dotted ``paths``, or ``default`` if any segment is missing."""
        node = self.config
        for path in paths.split("."):
            # Only descend into dicts: ``in`` on a str leaf would be a substring test.
            if isinstance(node, dict) and path in node:
                node = node[path]
            else:
                return default
        return node

    def set(self, paths, value):
        """Store ``value`` at dotted ``paths``, creating intermediate dicts as needed.

        A non-dict value found on the way is replaced by an empty dict.
        """
        keys = paths.split(".")
        node = self.config
        for path in keys[:-1]:
            child = node.get(path)
            if not isinstance(child, dict):
                child = node[path] = {}
            node = child
        node[keys[-1]] = value

    def get_app_config(self, path: str, default=None):
        return self.get("Application.%s" % path, default=default)

    def get_strategy_config(self, id: str, path: str, default=None):
        return self.get("Strategy.%s.%s" % (id, path), default=default)

    def get_feed_config(self, id: str, path: str, default=None):
        return self.get("Feed.%s.%s" % (id, path), default=default)

    def get_broker_config(self, id: str, path: str, default=None):
        return self.get("Broker.%s.%s" % (id, path), default=default)

    def get_datastore_config(self, id: str, path: str, default=None):
        return self.get("DataStore.%s.%s" % (id, path), default=default)
57 | 58 | def merge(source: Dict, destination: Dict): 59 | for key, value in source.items(): 60 | if isinstance(value, dict): 61 | node = destination.setdefault(key, {}) 62 | merge(value, node) 63 | else: 64 | destination[key] = value 65 | 66 | 67 | def save_to_json(path: str, config: Dict) -> None: 68 | with open(path, 'w') as f: 69 | json.dumps(config, f) 70 | 71 | 72 | def save_to_yaml(path: str, config: Dict) -> None: 73 | with open(path, 'w') as f: 74 | yaml.dump(config, f, default_flow_style=False) 75 | 76 | 77 | def load_from_json(path: str = '../../config/backtest.json') -> Dict: 78 | with open(path, 'r') as f: 79 | return json.load(f) 80 | 81 | 82 | def load_from_yaml(path: str = '../../config/backtest.yaml') -> Dict: 83 | with open(path, 'r') as f: 84 | read_data = f.read() 85 | return yaml.load(read_data) 86 | -------------------------------------------------------------------------------- /algotrader/trading/context.py: -------------------------------------------------------------------------------- 1 | from algotrader import Context 2 | from algotrader.model.model_factory import ModelFactory 3 | from algotrader.provider import ProviderManager 4 | from algotrader.provider.broker import Broker 5 | from algotrader.provider.datastore import DataStore 6 | from algotrader.provider.feed import Feed 7 | from algotrader.strategy import StrategyManager 8 | from algotrader.trading.account import AccountManager 9 | from algotrader.trading.clock import Clock, RealTimeClock, SimulationClock 10 | from algotrader.trading.config import Config 11 | from algotrader.trading.event import EventBus 12 | from algotrader.trading.instrument_data import InstrumentDataManager 13 | from algotrader.trading.order import OrderManager 14 | from algotrader.trading.portfolio import Portfolio, PortfolioManager 15 | from algotrader.trading.ref_data import RefDataManager 16 | from algotrader.trading.sequence import SequenceManager 17 | 18 | 19 | class ApplicationContext(Context): 20 | def 
__init__(self, config: Config = None): 21 | super(ApplicationContext, self).__init__() 22 | 23 | self.config = config if config else Config() 24 | 25 | self.clock = self.add_startable(self.__get_clock()) 26 | self.provider_mgr = self.add_startable(ProviderManager()) 27 | 28 | self.seq_mgr = self.add_startable(SequenceManager()) 29 | 30 | self.inst_data_mgr = self.add_startable(InstrumentDataManager()) 31 | self.ref_data_mgr = self.add_startable(RefDataManager()) 32 | 33 | self.order_mgr = self.add_startable(OrderManager()) 34 | self.acct_mgr = self.add_startable(AccountManager()) 35 | self.portf_mgr = self.add_startable(PortfolioManager()) 36 | self.stg_mgr = self.add_startable(StrategyManager()) 37 | 38 | self.event_bus = EventBus() 39 | self.model_factory = ModelFactory 40 | 41 | def __get_clock(self) -> Clock: 42 | if self.config.get_app_config("clockId", Clock.Simulation) == Clock.RealTime: 43 | return RealTimeClock() 44 | return SimulationClock() 45 | 46 | def get_data_store(self) -> DataStore: 47 | return self.provider_mgr.get(self.config.get_app_config("dataStoreId")) 48 | 49 | def get_broker(self) -> Broker: 50 | return self.provider_mgr.get(self.config.get_app_config("brokerId")) 51 | 52 | def get_feed(self) -> Feed: 53 | return self.provider_mgr.get(self.config.get_app_config("feedId")) 54 | 55 | def get_portfolio(self) -> Portfolio: 56 | return self.portf_mgr.get_or_new_portfolio(self.config.get_app_config("dataStoreId"), 57 | self.config.get_app_config("portfolioInitialcash")) 58 | -------------------------------------------------------------------------------- /algotrader/trading/position.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from algotrader.model.market_data_pb2 import * 4 | from algotrader.model.model_factory import ModelFactory 5 | from algotrader.model.trade_data_pb2 import * 6 | from algotrader.trading.event import MarketDataEventHandler 7 | from algotrader.utils.market_data 
import get_quote_mid 8 | from algotrader.utils.model import add_to_list 9 | 10 | 11 | class HasPositions(MarketDataEventHandler): 12 | __metaclass__ = abc.ABCMeta 13 | 14 | def __init__(self, state): 15 | self.state = state 16 | 17 | def positions(self): 18 | return self.state.positions 19 | 20 | def has_position(self, inst_id: str) -> bool: 21 | return inst_id in self.state.positions 22 | 23 | def get_position(self, inst_id: str) -> Position: 24 | if inst_id not in self.state.positions: 25 | ModelFactory.add_position(self.state, inst_id=inst_id) 26 | position = self.state.positions[inst_id] 27 | return position 28 | 29 | def update_price(self, timestamp: int, inst_id: str, price: float) -> None: 30 | if inst_id in self.state.positions: 31 | position = self.state.positions[inst_id] 32 | position.last_price = price 33 | 34 | def add_position(self, inst_id: str, cl_id: str, cl_ord_id: str, qty: float) -> None: 35 | position = self.get_position(inst_id) 36 | order_position = self.__get_or_add_order(position=position, cl_id=cl_id, 37 | cl_ord_id=cl_ord_id) 38 | order_position.filled_qty = qty 39 | position.filled_qty += qty 40 | 41 | def add_order(self, inst_id: str, cl_id: str, cl_ord_id: str, ordered_qty: float) -> None: 42 | add_to_list(self.state.cl_ord_ids, [cl_ord_id]) 43 | 44 | position = self.get_position(inst_id) 45 | order_position = self.__get_or_add_order(position=position, cl_id=cl_id, 46 | cl_ord_id=cl_ord_id) 47 | order_position.ordered_qty = ordered_qty 48 | position.ordered_qty += ordered_qty 49 | 50 | def __get_or_add_order(self, position: Position, cl_id: str, cl_ord_id: str) -> OrderPosition: 51 | id = ModelFactory.build_cl_ord_id(cl_id, cl_ord_id) 52 | if id not in position.orders: 53 | ModelFactory.add_order_position(position, cl_id=cl_id, 54 | cl_ord_id=cl_ord_id, 55 | ordered_qty=0, filled_qty=0) 56 | return position.orders[id] 57 | 58 | def on_bar(self, bar: Bar) -> None: 59 | self.update_price(bar.timestamp, bar.inst_id, bar.close) 60 | 61 | 
def on_quote(self, quote: Quote) -> None: 62 | self.update_price(quote.timestamp, quote.inst_id, get_quote_mid(quote)) 63 | 64 | def on_trade(self, trade: Trade) -> None: 65 | self.update_price(trade.timestamp, trade.inst_id, trade.price) 66 | 67 | def position_filled_qty(self, inst_id: str) -> float: 68 | position = self.get_position(inst_id) 69 | return position.filled_qty 70 | 71 | def position_ordered_qty(self, inst_id: str) -> float: 72 | position = self.get_position(inst_id) 73 | return position.ordered_qty 74 | 75 | def position_value(self, inst_id: str) -> float: 76 | position = self.get_position(inst_id) 77 | return position.last_price * position.filled_qty 78 | 79 | def total_position_value(self) -> float: 80 | total = 0 81 | for inst_id, position in self.state.positions.items(): 82 | total += self.position_value(inst_id) 83 | return total 84 | 85 | def position_order_ids(self, inst_id: str): 86 | if not self.has_position(inst_id): 87 | return [] 88 | 89 | position = self.get_position(inst_id) 90 | return [ord_id for ord_id in position.orders.keys()] 91 | -------------------------------------------------------------------------------- /algotrader/trading/ref_data.py: -------------------------------------------------------------------------------- 1 | from algotrader import Manager, Context 2 | from algotrader.provider.datastore import PersistenceMode 3 | 4 | 5 | class RefDataManager(Manager): 6 | def __init__(self): 7 | super(RefDataManager, self).__init__() 8 | 9 | self._inst_dict = {} 10 | self._ccy_dict = {} 11 | self._exch_dict = {} 12 | self.store = None 13 | 14 | def _start(self, app_context: Context) -> None: 15 | self.store = self.app_context.get_data_store() 16 | self.persist_mode = self.app_context.config.get_app_config("persistenceMode") 17 | self.load_all() 18 | 19 | def _stop(self): 20 | self.save_all() 21 | self.reset() 22 | 23 | def load_all(self): 24 | if self.store: 25 | self.store.start(self.app_context) 26 | for inst in 
self.store.load_all('instruments'): 27 | self._inst_dict[inst.inst_id] = inst 28 | for ccy in self.store.load_all('currencies'): 29 | self._ccy_dict[ccy.ccy_id] = ccy 30 | for exch in self.store.load_all('exchanges'): 31 | self._exch_dict[exch.exch_id] = exch 32 | 33 | def save_all(self): 34 | if self.store and self.persist_mode != PersistenceMode.Disable: 35 | for inst in self._inst_dict.values(): 36 | self.store.save_instrument(inst) 37 | for ccy in self._ccy_dict.values(): 38 | self.store.save_currency(ccy) 39 | for exch in self._exch_dict.values(): 40 | self.store.save_exchange(exch) 41 | 42 | def reset(self): 43 | self._inst_dict = {} 44 | self._ccy_dict = {} 45 | self._exch_dict = {} 46 | 47 | # get all 48 | def get_all_insts(self): 49 | return self._inst_dict.values() 50 | 51 | def get_all_ccys(self): 52 | return self._ccy_dict.values() 53 | 54 | def get_all_exchs(self): 55 | return self._exch_dict.values() 56 | 57 | def get_inst(self, inst_id): 58 | return self._inst_dict.get(inst_id, None) 59 | 60 | def get_ccy(self, ccy_id): 61 | return self._ccy_dict.get(ccy_id, None) 62 | 63 | def get_exch(self, exch_id): 64 | return self._exch_dict.get(exch_id, None) 65 | 66 | def get_insts_by_ids(self, ids): 67 | ids = set(ids) 68 | return [self._inst_dict[id] for id in ids if id in self._inst_dict] 69 | 70 | def get_insts_by_symbols(self, symbols): 71 | symbols = set(symbols) 72 | return [inst for inst in self._inst_dict.values() if inst.symbol in symbols] 73 | 74 | def add_inst(self, inst): 75 | self._inst_dict[inst.inst_id] = inst 76 | if self.store and self.persist_mode == PersistenceMode.RealTime: 77 | self.store.save_instrument(inst) 78 | 79 | def add_ccy(self, ccy): 80 | self._ccy_dict[ccy.ccy_id] = ccy 81 | if self.store and self.persist_mode == PersistenceMode.RealTime: 82 | self.store.save_currency(ccy) 83 | 84 | def add_exch(self, exch): 85 | self._exch_dict[exch.exch_id] = exch 86 | if self.store and self.persist_mode == PersistenceMode.RealTime: 87 | 
self.store.save_exchange(exch) 88 | 89 | def id(self): 90 | raise "RefDataManager" 91 | -------------------------------------------------------------------------------- /algotrader/trading/sequence.py: -------------------------------------------------------------------------------- 1 | from algotrader import SimpleManager, Context 2 | from algotrader.provider.datastore import PersistenceMode 3 | 4 | 5 | class SequenceManager(SimpleManager): 6 | ID = "SequenceManager" 7 | 8 | def __init__(self): 9 | super(SequenceManager, self).__init__() 10 | self.store = None 11 | 12 | def _start(self, app_context: Context) -> None: 13 | self.store = self.app_context.get_data_store() 14 | self.persist_mode = self.app_context.config.get_app_config("persistenceMode") 15 | self.load_all() 16 | 17 | def load_all(self): 18 | if self.store: 19 | self.store.start(self.app_context) 20 | items = self.store.load_all('sequences') 21 | self.item_dict.update(items) 22 | 23 | def save_all(self): 24 | if self.store and self.persist_mode != PersistenceMode.Disable: 25 | for key, value in self.item_dict.items(): 26 | self.store.save_sequence(key, value) 27 | 28 | def get(self, id): 29 | return self.item_dict.get(id, None) 30 | 31 | def get_next_sequence(self, id): 32 | if id not in self.item_dict: 33 | self.add(id) 34 | current = self.item_dict[id] 35 | self.item_dict[id] += 1 36 | return current 37 | 38 | def add(self, id, initial=1): 39 | self.item_dict[id] = initial 40 | 41 | if self.store and self.persist_mode == PersistenceMode.RealTime: 42 | self.store.save_sequence(id, initial) 43 | 44 | def all_items(self): 45 | return self.item_dict 46 | 47 | def has_item(self, id): 48 | return id in self.item_dict 49 | 50 | def id(self): 51 | return self.ID 52 | -------------------------------------------------------------------------------- /algotrader/trading/subscription.py: -------------------------------------------------------------------------------- 1 | from algotrader.model.market_data_pb2 
import MarketDataSubscriptionRequest 2 | from algotrader.model.model_factory import ModelFactory 3 | from algotrader.utils.market_data import get_subscription_type, get_bar_size, get_bar_type 4 | 5 | 6 | class MarketDataSubscriber(object): 7 | def subscript_market_data(self, feed, instruments, subscription_types, from_date=None, to_date=None): 8 | for sub_req in self.build_subscription_requests(feed.id(), instruments, subscription_types, from_date, to_date): 9 | feed.subscribe_mktdata(sub_req) 10 | 11 | def build_subscription_requests(self, feed_id, instruments, subscription_types, from_date=None, to_date=None): 12 | reqs = [] 13 | for instrument in instruments: 14 | for subscription_type in subscription_types: 15 | attrs = subscription_type.split(".") 16 | md_type = get_subscription_type(attrs[0]) 17 | md_provider_id = attrs[1] 18 | bar_type = get_bar_type(attrs[2]) if md_type == MarketDataSubscriptionRequest.Bar else None 19 | bar_size = get_bar_size(attrs[3]) if md_type == MarketDataSubscriptionRequest.Bar else None 20 | 21 | reqs.append(ModelFactory.build_market_data_subscription_request(type=md_type, 22 | inst_id=instrument.inst_id, 23 | feed_id=feed_id, 24 | md_provider_id=md_provider_id, 25 | bar_type=bar_type, 26 | bar_size=bar_size, 27 | from_date=from_date, 28 | to_date=to_date)) 29 | return reqs 30 | -------------------------------------------------------------------------------- /algotrader/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/algotrader/utils/__init__.py -------------------------------------------------------------------------------- /algotrader/utils/data_series.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from typing import Dict, List 3 | from tzlocal import get_localzone 4 | 5 | 6 | def 
convert_series_idx_to_datetime(series: pd.Series) -> pd.Series: 7 | return pd.Series(series.values, 8 | index=pd.to_datetime(series.index, unit='ms').tz_localize('UTC') 9 | .tz_convert(get_localzone().zone)) 10 | 11 | 12 | def get_input_name(input): 13 | if hasattr(input, 'time_series') and input.time_series: 14 | return "%s" % input.time_series.series_id 15 | elif isinstance(input, str): 16 | return "%s" % input # str 17 | raise Exception("only str or DataSeries is supported") 18 | 19 | 20 | def convert_input(inputs=None) -> List[str]: 21 | result = [] 22 | if inputs: 23 | if not isinstance(inputs, list): 24 | inputs = [inputs] 25 | 26 | result = [get_input_name(input) for input in inputs] 27 | 28 | return result 29 | 30 | 31 | def convert_input_keys(inputs=None, input_keys=None) -> Dict[str, List[str]]: 32 | result = {} 33 | if inputs and input_keys: 34 | inputs = convert_input(inputs) 35 | input_keys = input_keys if input_keys else {} 36 | if isinstance(input_keys, dict): 37 | for k, v in input_keys.items(): 38 | if isinstance(v, list): 39 | result[k] = v 40 | elif isinstance(v, str): 41 | result[k] = [v] 42 | 43 | if isinstance(input_keys, list): 44 | for input in inputs: 45 | input_name = get_input_name(input) 46 | result[input_name] = input_keys 47 | 48 | if isinstance(input_keys, str): 49 | for input in inputs: 50 | input_name = get_input_name(input) 51 | result[input_name] = [input_keys] 52 | 53 | return result 54 | 55 | 56 | def build_series_id(name: str, inputs=None, input_keys=None, **kwargs): 57 | parts = [] 58 | inputs = convert_input(inputs) 59 | input_keys = convert_input_keys(inputs, input_keys) 60 | for input in inputs: 61 | input_name = get_input_name(input) 62 | if input_name in input_keys: 63 | keys = input_keys[input_name] 64 | parts.append('%s[%s]' % (input_name, ','.join(keys))) 65 | else: 66 | parts.append(input_name) 67 | 68 | if kwargs: 69 | for key, value in kwargs.items(): 70 | parts.append('%s=%s' % (key, value)) 71 | 72 | if parts: 73 
| return "%s(%s)" % (name, ','.join(parts)) 74 | else: 75 | 76 | return "%s()" % name # # def build_indicator(cls, inputs=None, input_keys=None, desc=None, time_series=None, **kwargs): 77 | 78 | 79 | # if isinstance(cls, str): 80 | # if cls in cls_cache: 81 | # cls = cls_cache[cls] 82 | # else: 83 | # cls_name = cls 84 | # cls = dynamic_import(cls_name) 85 | # cls_cache[cls_name] = cls 86 | # 87 | # if not time_series: 88 | # if inputs and not isinstance(inputs, list): 89 | # inputs = list(inputs) 90 | # series_id = build_series_id(cls.__name__, inputs, input_keys, **kwargs) 91 | # time_series = ModelFactory.build_time_series(series_id=series_id, 92 | # series_cls=get_full_cls_name(cls), 93 | # desc=desc, 94 | # inputs=inputs, input_keys=input_keys, **kwargs) 95 | # 96 | # return cls(time_series=time_series) 97 | 98 | 99 | def convert_to_list(items=None): 100 | if items and type(items) != set and type(items) != list: 101 | items = [items] 102 | return items 103 | -------------------------------------------------------------------------------- /algotrader/utils/date.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | epoch = datetime.datetime.fromtimestamp(0) 4 | 5 | 6 | def datetime_to_unixtimemillis(dt: datetime.datetime) -> int: 7 | return int((dt - epoch).total_seconds() * 1000) 8 | 9 | 10 | def unixtimemillis_to_datetime(timestamp: int) -> datetime.datetime: 11 | return datetime.datetime.fromtimestamp(timestamp / 1000.0) 12 | 13 | 14 | def datetime_to_timestamp(dt: datetime.datetime) -> int: 15 | return (dt - epoch).total_seconds() 16 | 17 | 18 | def timestamp_to_datetime(timestamp: int) -> datetime.datetime: 19 | return datetime.datetime.fromtimestamp(timestamp) 20 | 21 | 22 | def datestr_to_unixtimemillis(datestr: str) -> int: 23 | if not datestr: 24 | return None 25 | return date_to_unixtimemillis(datestr_to_date(datestr)) 26 | 27 | 28 | def datestr_to_date(datestr: str) -> datetime.date: 29 | if 
not datestr: 30 | return None 31 | datestr = str(datestr) 32 | return datetime.date(int(datestr[0:4]), int(datestr[4:6]), int(datestr[6:8])) 33 | 34 | 35 | def date_to_unixtimemillis(d: datetime.date) -> int: 36 | return int( 37 | (datetime.datetime.combine(d, datetime.datetime.min.time()) - epoch).total_seconds() * 1000) 38 | 39 | 40 | def unixtimemillis_to_date(timestamp: int) -> datetime.date: 41 | return datetime.datetime.fromtimestamp(timestamp / 1000).date() 42 | 43 | 44 | def date_to_timestamp(d: datetime.date) -> int: 45 | return (datetime.datetime.combine(d, datetime.datetime.min.time()) - epoch).total_seconds() 46 | 47 | 48 | def timestamp_to_date(timestamp: int) -> datetime.date: 49 | return datetime.datetime.fromtimestamp(timestamp).date() 50 | -------------------------------------------------------------------------------- /algotrader/utils/indicator.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from algotrader.utils.data_series import build_series_id 4 | from algotrader.utils.model import get_cls 5 | 6 | 7 | def parse_inner(inner_str, level): 8 | #print(level, inner_str) 9 | inputs = [] 10 | input_keys = {} 11 | kwargs = {} 12 | while "(" in inner_str: 13 | lidx = inner_str.find("(") 14 | ridx = inner_str.rfind(")") 15 | assert lidx > -1 16 | assert ridx > -1 17 | nested = inner_str[0:ridx] 18 | 19 | inner_str = inner_str[ridx+1:] 20 | 21 | 22 | for inner in inner_str.split(','): 23 | if '=' in inner: 24 | idx = inner.find("=") 25 | k = inner[0:idx] 26 | v = inner[idx + 1:] 27 | kwargs[k] = v 28 | elif '[' in inner: 29 | assert inner.endswith("]") 30 | idx = inner.find("[") 31 | input = inner[0:idx] 32 | keys = re.split('; |, ', inner[idx + 1: -1]) 33 | inputs.append(input) 34 | input_keys[input] = keys 35 | else: 36 | inputs.append(inner) 37 | 38 | return inputs, input_keys, kwargs 39 | 40 | 41 | def parse_series(inst_data_mgr, name): 42 | if not inst_data_mgr.has_series(name): 43 | lidx = 
name.find("(") 44 | assert name.endswith(")"), "invalid syntax, cannot parse %s" % name 45 | assert lidx > -1, "invalid syntax, cannot parse %s" % name 46 | cls_str = name[0:lidx] 47 | inner_str = name[lidx + 1:-1] 48 | 49 | inputs, input_keys, kwargs = parse_inner(inner_str) 50 | cls = globals()[cls_str] 51 | return cls(inputs=inputs, input_keys=input_keys, **kwargs) 52 | 53 | return inst_data_mgr.get_series(name) 54 | 55 | 56 | def get_or_create_indicator(inst_data_mgr, cls, inputs=None, input_keys=None, **kwargs): 57 | cls = get_cls(cls) 58 | name = build_series_id(cls.__name__, inputs=inputs, input_keys=input_keys, **kwargs) 59 | if not inst_data_mgr.has_series(name): 60 | obj = cls(inputs=inputs, input_keys=input_keys, **kwargs) 61 | inst_data_mgr.add_series(obj) 62 | return obj 63 | return inst_data_mgr.get_series(name, create_if_missing=False) 64 | -------------------------------------------------------------------------------- /algotrader/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | logger.addHandler(logging.StreamHandler()) 5 | logger.setLevel(logging.INFO) 6 | -------------------------------------------------------------------------------- /algotrader/utils/py2to3.py: -------------------------------------------------------------------------------- 1 | ## 2 | # import to make code compatibility with both python2 and python3 3 | ## 4 | 5 | 6 | try: 7 | range = xrange # Python 2 8 | except NameError: 9 | pass # Python 3 10 | -------------------------------------------------------------------------------- /algotrader/utils/ref_data.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import pandas as pd 4 | 5 | from algotrader.model.model_factory import ModelFactory 6 | 7 | 8 | def get_inst_symbol(self, inst, provider_id): 9 | if inst: 10 | return inst.alt_symbols[provider_id] if 
provider_id in inst.alt_symbols else inst.symbol 11 | return None 12 | 13 | 14 | def get_exch_id(self, exch, provider_id): 15 | if exch: 16 | return exch.alt_ids[provider_id] if provider_id in exch.alt_ids else exch.exch_id 17 | return None 18 | 19 | 20 | def load_inst_from_csv(data_store, inst_file=None): 21 | if inst_file: 22 | with open(inst_file) as csvfile: 23 | reader = csv.DictReader(csvfile) 24 | for row in reader: 25 | load_inst_from_row(data_store, row) 26 | 27 | 28 | def load_ccy_from_csv(data_store, ccy_file=None): 29 | if ccy_file: 30 | with open(ccy_file) as csvfile: 31 | reader = csv.DictReader(csvfile) 32 | for row in reader: 33 | load_ccy_from_row(data_store, row) 34 | 35 | 36 | def load_exch_from_csv(data_store, exch_file=None): 37 | if exch_file: 38 | with open(exch_file) as csvfile: 39 | reader = csv.DictReader(csvfile) 40 | for row in reader: 41 | load_exch_from_row(data_store, row) 42 | 43 | 44 | def load_inst_from_df(data_store, inst_df): 45 | for index, row in inst_df.iterrows(): 46 | load_inst_from_row(data_store, row) 47 | 48 | 49 | def load_ccy_from_df(data_store, ccy_df): 50 | for index, row in ccy_df.iterrows(): 51 | load_ccy_from_row(data_store, row) 52 | 53 | 54 | def load_exch_from_df(data_store, exch_df): 55 | for index, row in exch_df.iterrows(): 56 | load_exch_from_row(data_store, row) 57 | 58 | 59 | def load_inst_from_row(data_store, row): 60 | alt_symbols = {} 61 | if 'alt_symbols' in row and row['alt_symbols']: 62 | for item in row['alt_symbols'].split(";"): 63 | kv = item.split("=") 64 | alt_symbols[kv[0]] = kv[1] 65 | inst = ModelFactory.build_instrument(symbol=row['symbol'], 66 | type=row['type'], 67 | primary_exch_id=row['exch_id'], 68 | ccy_id=row['ccy_id'], 69 | name=row['name'], 70 | sector=row['sector'], 71 | industry=row['industry'], 72 | margin=row['margin'], 73 | alt_symbols=alt_symbols, 74 | underlying_ids=row['und_inst_id'], 75 | option_type=row['put_call'], 76 | strike=row['strike'], 77 | 
exp_date=row['expiry_date'], 78 | multiplier=row['factor']) 79 | data_store.save_instrument(inst) 80 | 81 | 82 | def load_ccy_from_row(data_store, row): 83 | ccy = ModelFactory.build_currency(ccy_id=row['ccy_id'], name=row['name']) 84 | data_store.save_currency(ccy) 85 | 86 | 87 | def load_exch_from_row(data_store, row): 88 | exch = ModelFactory.build_exchange(exch_id=row['exch_id'], name=row['name']) 89 | data_store.save_exchange(exch) 90 | 91 | 92 | def build_inst_dataframe_from_list(symbols, type='ETF', exch_id='NYSE', ccy_id='USD'): 93 | inst_df = pd.DataFrame({'name': symbols}) 94 | inst_df['type'] = type 95 | inst_df['symbol'] = inst_df['name'] 96 | inst_df['exch_id'] = exch_id 97 | inst_df['ccy_id'] = ccy_id 98 | inst_df['alt_symbol'] = '' 99 | inst_df['alt_exch_id'] = '' 100 | inst_df['sector'] = '' 101 | inst_df['industry'] = '' 102 | inst_df['put_call'] = '' 103 | inst_df['expiry_date'] = '' 104 | inst_df['und_inst_id'] = '' 105 | inst_df['factor'] = '' 106 | inst_df['strike'] = '' 107 | inst_df['margin'] = '' 108 | inst_df['inst_id'] = inst_df.index 109 | return inst_df 110 | -------------------------------------------------------------------------------- /algotrader/utils/sde_sim.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | 6 | def euler(drift, diffusion, x0, T, Tstep, Nsim): 7 | """ 8 | Simulate solution of Stochastic Differential Equation by discretization using 9 | Euler's method 10 | :param drift: lambda function with signature x, t 11 | :param diffusion: lambda function with signature x, t 12 | :param x0: initial position 13 | :param T: terminal time in double 14 | :param Tstep: number of time steps 15 | :param Nsim: number of path simulated 16 | :return: np array of simulated stochastic process 17 | """ 18 | dt = T / (Tstep - 1) 19 | dW = np.random.normal(0, math.sqrt(dt), Nsim * Tstep).reshape(Nsim, Tstep) 20 | 21 | x = np.zeros([Nsim, Tstep]) 22 | t = 
np.linspace(0, T, Tstep) 23 | x[:, 0] = x0 24 | 25 | for i in range(1, Tstep): 26 | x[:, i] = x[:, i - 1] + drift(x[:, i - 1], t[i - 1]) * dt + diffusion(x[:, i - 1], t[i - 1]) * dW[:, i - 1] 27 | return x 28 | 29 | 30 | def euler2d(drift0, drift1, diffusion0, diffusion1, rho, x0, y0, T, Tstep, Nsim): 31 | """ 32 | :param drift0: 33 | :param drift1: 34 | :param diffusion0: 35 | :param diffusion1: 36 | :param rho: 37 | :param x0: 38 | :param y0: 39 | :param T: 40 | :param Tstep: 41 | :param Nsim: 42 | :return: 43 | """ 44 | dt = T / (Tstep - 1) 45 | dW0 = np.random.normal(0, dt, Nsim * Tstep).reshape(Nsim, Tstep) 46 | dW1 = np.random.normal(0, dt, Nsim * Tstep).reshape(Nsim, Tstep) 47 | 48 | x = np.zeros([Nsim, Tstep]) 49 | y = np.zeros([Nsim, Tstep]) 50 | t = np.linspace(0, T, Tstep) 51 | 52 | x[:, 0] = x0 53 | y[:, 0] = y0 54 | 55 | for i in range(1, Tstep): 56 | dWy = rho * dW0[:, i - 1] + np.sqrt(1.0 - rho ** 2) * dW1[:, i - 1] 57 | x[:, i] = x[:, i - 1] + drift0(x[:, i - 1], y[:, i - 1], t[i - 1]) * dt \ 58 | + diffusion0(x[:, i - 1], y[:, i - 1], t[i - 1]) * dW0[:, i - 1] 59 | y[:, i] = y[:, i - 1] + drift1(x[:, i - 1], y[:, i - 1], t[i - 1]) * dt \ 60 | + diffusion1(x[:, i - 1], y[:, i - 1], t[i - 1]) * dWy[:, i - 1] 61 | 62 | return x, y 63 | -------------------------------------------------------------------------------- /algotrader/utils/trade_data.py: -------------------------------------------------------------------------------- 1 | from algotrader.model.trade_data_pb2 import * 2 | 3 | 4 | def is_buy(new_order_req: NewOrderRequest): 5 | return new_order_req.action == Buy 6 | 7 | 8 | def is_sell(new_order_req: NewOrderRequest): 9 | return new_order_req.action == Sell 10 | -------------------------------------------------------------------------------- /config/backtest.json: -------------------------------------------------------------------------------- 1 | { 2 | "Broker": { 3 | "IBBroker": { 4 | "clientId": 0, 5 | "host": "localhost", 6 | "nextOrderId": 
1, 7 | "account": 1, 8 | "port": 4001, 9 | "daemon": true, 10 | "nextRequestId": 1, 11 | "useGevent": false 12 | }, 13 | "Simulator": { 14 | "fillStrategy": "Default", 15 | "nextOrderId": 1, 16 | "commission": "Default", 17 | "nextExecId": 1 18 | } 19 | }, 20 | "Application": { 21 | "clock": "Simulation", 22 | "refDataMgr": "InMemory" 23 | }, 24 | "Feed": { 25 | "CSVFeed": { 26 | "path": "../data/tradedata" 27 | } 28 | }, 29 | "StgConfig": { 30 | "down2%Stg": { 31 | "qty": null 32 | } 33 | }, 34 | "Persistence": { 35 | "dataStore": "InMemoryDB", 36 | "persistenceMode": "Disable" 37 | }, 38 | "DataStore": { 39 | "MongoDB": { 40 | "username": null, 41 | "dbname": "algotrader", 42 | "host": "localhost", 43 | "password": null, 44 | "port": 27017 45 | }, 46 | "InMemoryDB": { 47 | "file": "algotrader_db.p" 48 | }, 49 | "CassandraDB": { 50 | "keyspace": "algotrader", 51 | "scriptPath": "../../../scripts/cassandra/algotrader.cql", 52 | "username": null, 53 | "contactPoints": [ 54 | "127.0.0.1" 55 | ], 56 | "password": null, 57 | "port": null 58 | } 59 | }, 60 | "Trading": { 61 | "stgId": null, 62 | "portfolioId": null, 63 | "feedId": null, 64 | "instrumentIds": null, 65 | "stgCls": null, 66 | "brokerId": null 67 | } 68 | } -------------------------------------------------------------------------------- /config/backtest.yaml: -------------------------------------------------------------------------------- 1 | Application: 2 | 3 | type: "BackTesting" 4 | 5 | clockId: "Simulation" 6 | 7 | dataStoreId: "InMemory" 8 | persistenceMode: "Disable" 9 | createDBAtStart : false 10 | deleteDBAtStop : false 11 | 12 | feedId: "CSV" 13 | brokerId: "Simulator" 14 | portfolioId: "test" 15 | 16 | fromDate : 20100101 17 | toDate : 20170101 18 | portfolioInitialcash : 100000 19 | plot : true 20 | 21 | 22 | DataStore: 23 | 24 | Cassandra: 25 | contactPoints: 26 | - "127.0.0.1" 27 | port: 28 | username: 29 | password: 30 | keyspace: "algotrader" 31 | scriptPath:
"../../../scripts/cassandra/algotrader.cql" 32 | 33 | Mongo: 34 | host: "localhost" 35 | port: 27107 36 | username: 37 | password: 38 | dbname: "algotrader" 39 | 40 | InMemory: 41 | file: "../../data/algotrader_db.p" 42 | instCSV: "../../data/refdata/instrument.csv" 43 | ccyCSV: "../../data/refdata/ccy.csv" 44 | exchCSV: "../../data/refdata/exch.csv" 45 | 46 | Feed: 47 | CSV: 48 | path: "/mnt/data/dev/workspaces/python-trading/data/tradedata" 49 | 50 | Broker: 51 | Simulator: 52 | commission: "Default" 53 | fillStrategy: "Default" 54 | nextOrderId: 1 55 | nextExecId: 1 56 | -------------------------------------------------------------------------------- /config/config.txt: -------------------------------------------------------------------------------- 1 | ApplicationConfig 2 | ref_data_mgr_type 3 | RefDataManager.InMemory 4 | RefDataManager.DB 5 | RefDataManager.Mock 6 | 7 | clock_type 8 | Clock.Simulation 9 | Clock.RealTime 10 | 11 | persistence_config 12 | ref_ds_id 13 | ref_persist_mode 14 | PersistenceMode.Disable 15 | PersistenceMode.Batch 16 | PersistenceMode.RealTime 17 | 18 | trade_ds_id 19 | trade_persist_mode 20 | 21 | ts_ds_id 22 | ts_persist_mode 23 | 24 | seq_ds_id 25 | seq_persist_mode 26 | 27 | 28 | provider_configs 29 | {provider_config.__class__ : provider_config} 30 | 31 | 32 | RealtimeMarketDataImporterConfig 33 | feed_id 34 | instrument_ids 35 | subscription_types 36 | 37 | 38 | HistoricalMarketDataImporterConfig 39 | feed_id 40 | instrument_ids 41 | subscription_types 42 | from_date 43 | to_date 44 | 45 | #### 46 | TradingConfig 47 | stg_id 48 | stg_cls 49 | stg_configs 50 | 51 | feed_id 52 | instrument_ids 53 | subscription_types 54 | 55 | portfolio_id : str 56 | broker_id : str 57 | 58 | 59 | LiveTradingConfig : TradingConfig 60 | 61 | 62 | 63 | BacktestingConfig : TradingConfig 64 | from_date : int 65 | to_date : int 66 | portfolio_initial_cash : float 67 | 68 | #### data store 69 | DataStoreConfig 70 | create_at_start : bool 71 | 
delete_at_stop : bool 72 | 73 | CassandraConfig 74 | contact_points = List[str] 75 | port : int 76 | username : str 77 | password : str 78 | keyspace : str = 'algotrader' 79 | cql_script_path : str = '../../../scripts/cassandra/algotrader.cql' 80 | 81 | 82 | MongoDBConfig 83 | host : str = 'localhost' 84 | port : int = 27017 85 | username : str 86 | password : str 87 | dbname : str = 'algotrader' 88 | 89 | 90 | InMemoryStoreConfig 91 | file : str = 'algotrader_db.p' 92 | 93 | ## Feed 94 | CSVFeedConfig 95 | path : str = '../data/tradedata' 96 | 97 | 98 | PandasMemoryDataFeedConfig 99 | dict_df : dataframe 100 | 101 | 102 | ## Broker 103 | IBConfig 104 | host : str = 'localhost' 105 | port : int = 4001 106 | client_id : int = 0 107 | account 108 | daemon 109 | use_gevent 110 | next_request_id : int = 1 111 | next_order_id 112 | 113 | 114 | SimulatorConfig 115 | commission_id = Commission.Default 116 | fill_strategy_id = FillStrategy.Default 117 | next_ord_id : int = 0 118 | next_exec_id : int = 0 -------------------------------------------------------------------------------- /config/data_import.yaml: -------------------------------------------------------------------------------- 1 | Application: 2 | 3 | type: "DataImport" 4 | 5 | clockId: "RealTime" 6 | 7 | dataStoreId: "Mongo" 8 | persistenceMode: "RealTime" 9 | createDBAtStart : false 10 | deleteDBAtStop : false 11 | 12 | fromDate : 20100101 13 | toDate : 20170101 14 | 15 | 16 | DataStore: 17 | 18 | Cassandra: 19 | contactPoints: 20 | - "127.0.0.1" 21 | port: 22 | username: 23 | password: 24 | keyspace: "algotrader" 25 | scriptPath: "../../../scripts/cassandra/algotrader.cql" 26 | 27 | Mongo: 28 | host: "localhost" 29 | port: 27017 30 | username: 31 | password: 32 | dbname: "algotrader" 33 | 34 | InMemory: 35 | file: "algotrader_db.p" 36 | 37 | Feed: 38 | CSV: 39 | path: "../data/tradedata" 40 | -------------------------------------------------------------------------------- /config/down2%.yaml: 
-------------------------------------------------------------------------------- 1 | Application: 2 | stgId: "down2%" 3 | stgCls: "algotrader.strategy.down_2pct_strategy.Down2PctStrategy" 4 | instrumentIds: 5 | - "SPY@NYSEARCA" 6 | subscriptionTypes: 7 | - "Bar.Yahoo.Time.D1" 8 | subscriptions: 9 | 10 | Strategy: 11 | down2%: 12 | qty: 1 13 | -------------------------------------------------------------------------------- /config/live_ib.yaml: -------------------------------------------------------------------------------- 1 | Application: 2 | 3 | type: "LiveTrading" 4 | 5 | clockId: "RealTime" 6 | 7 | dataStoreId: "Mongo" 8 | persistenceMode: "RealTime" 9 | createDBAtStart : false 10 | deleteDBAtStop : false 11 | 12 | feedId: "IB" 13 | brokerId: "IB" 14 | 15 | 16 | DataStore: 17 | 18 | Cassandra: 19 | contactPoints: 20 | - "127.0.0.1" 21 | port: 22 | username: 23 | password: 24 | keyspace: "algotrader" 25 | scriptPath: "../../../scripts/cassandra/algotrader.cql" 26 | 27 | Mongo: 28 | host: "localhost" 29 | port: 27017 30 | username: 31 | password: 32 | dbname: "algotrader" 33 | 34 | InMemory: 35 | file: "algotrader_db.p" 36 | 37 | Feed: 38 | CSV: 39 | path: "../data/tradedata" 40 | 41 | Broker: 42 | IB: 43 | host: "localhost" 44 | port: 4001 45 | clientId: 0 46 | account: 1 47 | daemon: true 48 | useGevent: false 49 | nextRequestId: 1 50 | nextOrderId: 1 -------------------------------------------------------------------------------- /data/refdata/ccy.csv: -------------------------------------------------------------------------------- 1 | ccy_id,name 2 | USD,US Dollar 3 | HKD,HK Dollar 4 | CNY,Chinese Yuan Renminbi 5 | RUR,Russian Ruble 6 | AUD,Australian Dollar 7 | NZD,New Zealand Dollar 8 | CAD,Canadian Dollar 9 | GBP,British Pound 10 | EUR,Euro 11 | JPY,Japanese Yen 12 | CHF,Swiss Franc 13 | SGD,Singapore Dollar 14 | KRW,Korean (South) Won -------------------------------------------------------------------------------- /data/refdata/exch.csv: 
-------------------------------------------------------------------------------- 1 | exch_id,name 2 | SEHK,Hong Kong Stock Exchange 3 | HKFE,Hong Kong Futures Exchange 4 | SEHKNTL,Shanghai-Hong Kong Stock Connect 5 | NSE,National Stock Exchange of India 6 | CHIXJ,CHI-X Japan 7 | OSE,Osaka Securities Exchange 8 | TSEJ,Tokyo Stock Exchange 9 | SGX,Singapore Exchange 10 | KSE,Korea Stock Exchange 11 | ASX,Australian Stock Exchange 12 | IDEAL,IDEAL FX 13 | IDEALPRO,IDEALPRO Metals 14 | LSE,London Stock Exchange 15 | SWX,Swiss Exchange 16 | FWB,Frankfurt Stock Exchange 17 | CFE,CBOE Futures Exchange 18 | ECBOT,CBOT 19 | CBOE,Chicago Board Options Exchange 20 | CHX,Chicago Stock Exchange 21 | GLOBEX,CME 22 | NYBOT,ICE Futures U.S. 23 | ICEUS,ICE Futures US 24 | ISE,ISE Options Exchange 25 | NASDAQ,NASDAQ 26 | AMEX,NYSE Amex 27 | ARCA,NYSE Arca 28 | PSE,NYSE Arca 29 | NYMEX,New York Mercantile Exchange 30 | NYSE,New York Stock Exchange 31 | TSE,Toronto Stock Exchange 32 | SMART,Smart Order Routing 33 | -------------------------------------------------------------------------------- /data/refdata/instrument.csv: -------------------------------------------------------------------------------- 1 | symbol,type,name,exch_id,ccy_id,alt_symbols,sector,industry,und_inst_id,expiry_date,factor,strike,put_call,margin 2 | EURUSD,CASH,EUR/USD,IDEALPRO,USD,IB=EUR 3 | USDJPY,CASH,USD/JPY,IDEALPRO,JPY,IB=USD 4 | GOOG,STK,Alphabet Inc.,NASDAQ,USD 5 | MSFT,STK,Microsoft Corporation,NASDAQ,USD 6 | AAPL,STK,Apple Inc.,NASDAQ,USD 7 | FB,STK,Facebook Inc.,NASDAQ,USD 8 | AMZN,STK,Amazon.com Inc.,NASDAQ,USD 9 | BABA,STK,ALIBABA GROUP HOLDING LTD,NYSE,USD 10 | VXX,STK,iPath S&P 500 VIX ST Futures ETN,NYSEARCA,USD 11 | SPY,ETF,SPDR S&P 500 ETF,NYSEARCA,USD 12 | XIV,ETF,VelocityShares Daily Inverse VIX ST ETN,NYSEARCA,USD 13 | HSI,IDX,HANG SENG INDEX,CBOE,HKD,Yahoo=^HSI 14 | SPX,IDX,S&P 500 Index,CBOE,USD,Yahoo=^GSPC;Google=.INX 15 | DJI,IDX,Dow Jones Industrial
Average,CBOE,USD,Yahoo=^DJI:Google=.DJI 16 | IXIC,IDX,NASDAQ Composite,NASDAQ,USD,Yahoo=^IXIC:Google=.IXIC 17 | VIX,IDX,VOLATILITY S&P 500,CBOE,USD,Yahoo=^VIX -------------------------------------------------------------------------------- /poc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/poc/__init__.py -------------------------------------------------------------------------------- /poc/cassandra_sample.py: -------------------------------------------------------------------------------- 1 | from cassandra.cluster import Cluster 2 | import os 3 | 4 | cluster = Cluster(contact_points=['127.0.0.1']) 5 | session = cluster.connect() 6 | session.set_keyspace("demo") 7 | 8 | 9 | cql_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../scripts/cassandra/algotrader.cql')) 10 | 11 | with open(cql_path) as cql_file: 12 | for stmt in cql_file.read().split(";"): 13 | if len(stmt.strip())>0: 14 | print stmt 15 | session.execute(stmt) 16 | # 17 | # session.execute(""" 18 | # 19 | # insert into users (lastname, age, city, email, firstname) values ('Jones', 35, 'Austin', 'bob@example.com', 'Bob') 20 | # 21 | # """) 22 | 23 | result = session.execute("select * from users where lastname='Jones' ")[0] 24 | print result.firstname, result.age -------------------------------------------------------------------------------- /poc/copy_test.py: -------------------------------------------------------------------------------- 1 | class C1(object): 2 | __slots__ = "s1"; 3 | 4 | 5 | class C2(C1): 6 | __slots__ = "s2"; 7 | 8 | 9 | class C3(C2): 10 | pass 11 | 12 | 13 | o1 = C1() 14 | o2 = C2() 15 | o3 = C3() 16 | 17 | print o1.__slots__ # prints s1 18 | print o2.__slots__ # prints s2 19 | print o3.__slots__ # prints s2 20 | 21 | o1.s1 = 11 22 | o2.s1 = 21 23 | o2.s2 = 22 24 | o3.s1 = 31 25 | o3.s2 = 32 26 | o3.a = 5 27 | 28 | import 
copy 29 | 30 | p3 = copy.copy(o3) 31 | 32 | print "p3=", p3.s1 33 | print "p3=", p3.s2 34 | print "p3=", p3.a 35 | -------------------------------------------------------------------------------- /poc/oms_client.py: -------------------------------------------------------------------------------- 1 | import zerorpc 2 | 3 | from algotrader.event.market_data import Bar, Quote, Trade 4 | from algotrader.event.order import NewOrderRequest, OrdAction, OrdType 5 | 6 | c = zerorpc.Client() 7 | c.connect("tcp://127.0.0.1:14242") 8 | 9 | bar = Bar(open=18, high=19, low=17, close=17.5, vol=1000) 10 | quote = Quote(bid=18, ask=19, bid_size=200, ask_size=500) 11 | trade = Trade(price=20, size=200) 12 | order = NewOrderRequest(ord_id=1, inst_id=1, action=OrdAction.BUY, type=OrdType.LIMIT, qty=1000, limit_price=18.5) 13 | 14 | print c.on_order(order) 15 | -------------------------------------------------------------------------------- /poc/oms_server.py: -------------------------------------------------------------------------------- 1 | import zerorpc 2 | 3 | from algotrader import Context 4 | from algotrader.trading.order import OrderManager 5 | from algotrader.utils.logging import logger 6 | 7 | 8 | class RemoteOrderManager(OrderManager): 9 | def __init__(self, address="tcp://0.0.0.0:14242"): 10 | super(RemoteOrderManager, self).__init__() 11 | self.__address = address 12 | 13 | def _start(self, app_context: Context) -> None: 14 | self.__server = zerorpc.Server(self) 15 | self.__server.bind(self.__address) 16 | logger.info("starting OMS") 17 | self.__server.run() 18 | 19 | def on_new_ord_req(self, order): 20 | logger.info("[%s] %s" % (self.__class__.__name__, order)) 21 | return order 22 | 23 | def id(self): 24 | return "RemoteOrderManager" 25 | -------------------------------------------------------------------------------- /poc/pandas.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 
| "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "" 22 | ] 23 | } 24 | ], 25 | "metadata": { 26 | "kernelspec": { 27 | "display_name": "Python 2", 28 | "language": "python", 29 | "name": "python2" 30 | }, 31 | "language_info": { 32 | "codemirror_mode": { 33 | "name": "ipython", 34 | "version": 2.0 35 | }, 36 | "file_extension": ".py", 37 | "mimetype": "text/x-python", 38 | "name": "python", 39 | "nbconvert_exporter": "python", 40 | "pygments_lexer": "ipython2", 41 | "version": "2.7.6" 42 | } 43 | }, 44 | "nbformat": 4, 45 | "nbformat_minor": 0 46 | } -------------------------------------------------------------------------------- /poc/pyfolio_playground.py: -------------------------------------------------------------------------------- 1 | import pyfolio as pf 2 | 3 | # stock_rets = pf.utils.get_symbol_rets('SPY') 4 | stock_rets = pf.utils.get_symbol_rets('FB') 5 | print type(stock_rets) 6 | 7 | from pandas_datareader import data as web 8 | 9 | px = web.get_data_yahoo('FB', start=None, end=None) 10 | rets = px[['Adj Close']].pct_change().dropna() 11 | rets.index = rets.index.tz_localize("UTC") 12 | rets.columns = ['FB'] 13 | print type(rets) 14 | 15 | rets = rets['FB'] 16 | print type(rets) 17 | print stock_rets 18 | 19 | print "###" 20 | print rets 21 | 22 | pf.create_returns_tear_sheet(stock_rets) 23 | 24 | import matplotlib.pyplot as plt 25 | 26 | plt.show() 27 | -------------------------------------------------------------------------------- /poc/rpy2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/poc/rpy2.py 
-------------------------------------------------------------------------------- /poc/rxpy_scheduling.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | import gevent 4 | from algotrader.event.market_data import Bar, BarSize 5 | from rx.concurrency.historicalscheduler import HistoricalScheduler 6 | 7 | from algotrader.trading.clock import RealTimeClock 8 | from algotrader.utils.date import unixtimemillis_to_datetime, datetime_to_unixtimemillis 9 | 10 | realtime_clock = RealTimeClock() 11 | 12 | 13 | class HistoricalScheduler2(HistoricalScheduler): 14 | def __init__(self, initial_clock=None, comparer=None): 15 | def compare_datetimes(a, b): 16 | return (a > b) - (a < b) 17 | 18 | clock = initial_clock or datetime.fromtimestamp(0) 19 | comparer = comparer or compare_datetimes 20 | super(HistoricalScheduler2, self).__init__(clock, comparer) 21 | 22 | def now(self): 23 | return self.clock 24 | 25 | @staticmethod 26 | def add(absolute, relative): 27 | 28 | if isinstance(relative, int): 29 | return absolute + timedelta(milliseconds=relative) 30 | elif isinstance(relative, float): 31 | return absolute + timedelta(seconds=relative) 32 | 33 | return absolute + relative 34 | 35 | def to_datetime_offset(self, absolute): 36 | return absolute 37 | 38 | def to_relative(self, timespan): 39 | return timespan 40 | 41 | 42 | starttime = datetime.now() 43 | scheduler1 = HistoricalScheduler2(initial_clock=starttime) 44 | from algotrader.trading.clock import RealTimeScheduler 45 | 46 | scheduler2 = RealTimeScheduler() 47 | endtime = [None] 48 | 49 | 50 | def action(*arg): 51 | print(unixtimemillis_to_datetime(realtime_clock.now())) 52 | 53 | 54 | from rx import Observable 55 | import time 56 | 57 | 58 | from gevent.greenlet import Greenlet 59 | 60 | 61 | class MyNoopGreenlet(Greenlet): 62 | def __init__(self, seconds): 63 | Greenlet.__init__(self) 64 | self.seconds = seconds 65 | 66 | def _run(self): 
67 | gevent.sleep(self.seconds) 68 | 69 | def __str__(self): 70 | return 'MyNoopGreenlet(%s)' % self.seconds 71 | 72 | 73 | current_ts = datetime_to_unixtimemillis(starttime) 74 | next_ts = Bar.get_next_bar_start_time(current_ts, BarSize.S5) 75 | diff = next_ts - current_ts 76 | # Observable.timer(int(diff), BarSize.S5 * 1000, scheduler2).subscribe(action) 77 | # scheduler1.advance_to(starttime) 78 | # scheduler2.schedule_absolute(datetime.utcnow() + timedelta(seconds=3), action, scheduler2.now) 79 | # print "1", scheduler1.now() 80 | # scheduler1.advance_to(starttime + timedelta(seconds=1)) 81 | # print "2", scheduler1.now() 82 | # scheduler1.advance_to(starttime + timedelta(seconds=2)) 83 | # print "3", scheduler1.now() 84 | # scheduler1.advance_to(starttime + timedelta(seconds=3)) 85 | # print "4", scheduler1.now() 86 | # scheduler1.advance_by(2000) 87 | # print "5", scheduler1.now() 88 | 89 | 90 | current_ts = datetime_to_unixtimemillis(starttime) 91 | next_ts = Bar.get_next_bar_start_time(current_ts, BarSize.S5) 92 | diff = next_ts - current_ts 93 | 94 | Observable.timer(int(diff), 1000, realtime_clock.scheduler).subscribe(on_next=action) 95 | 96 | time.sleep(10000) 97 | -------------------------------------------------------------------------------- /poc/tensor_flow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": { 18 | "collapsed": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "hello = tf.constant('Hello, TensorFlow!')" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "sess = tf.Session()" 34 | ] 35 | }, 36 | { 37 | 
"cell_type": "code", 38 | "execution_count": 4, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "Hello, TensorFlow!\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "print(sess.run(hello))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "metadata": { 59 | "collapsed": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "a = tf.constant(10)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 6, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "b = tf.constant(28)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": { 81 | "collapsed": false 82 | }, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "38\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "print(sess.run(a+b))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 8, 99 | "metadata": { 100 | "collapsed": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "from theano import function, config, shared, sandbox\n", 105 | "import theano.tensor as T\n", 106 | "import numpy\n", 107 | "import time" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 9, 113 | "metadata": { 114 | "collapsed": false 115 | }, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "[Elemwise{exp,no_inplace}()]\n", 122 | "Looping 1000 times took 3.615365 seconds\n", 123 | "Result is [ 1.23178032 1.61879341 1.52278065 ..., 2.20771815 2.29967753\n", 124 | " 1.62323285]\n", 125 | "Used the cpu\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "vlen = 10 * 30 * 768 # 10 x #cores x # threads per core\n", 131 | "iters = 1000\n", 132 | "\n", 133 | "rng = numpy.random.RandomState(22)\n", 134 | "x = shared(numpy.asarray(rng.rand(vlen), config.floatX))\n", 135 | "f 
= function([], T.exp(x))\n", 136 | "print(f.maker.fgraph.toposort())\n", 137 | "t0 = time.time()\n", 138 | "for i in range(iters):\n", 139 | " r = f()\n", 140 | "t1 = time.time()\n", 141 | "print(\"Looping %d times took %f seconds\" % (iters, t1 - t0))\n", 142 | "print(\"Result is %s\" % (r,))\n", 143 | "if numpy.any([isinstance(x.op, T.Elemwise) for x in f.maker.fgraph.toposort()]):\n", 144 | " print('Used the cpu')\n", 145 | "else:\n", 146 | " print('Used the gpu')" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "collapsed": true 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "Python 2", 164 | "language": "python", 165 | "name": "python2" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 2.0 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython2", 177 | "version": "2.7.12" 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 0 182 | } -------------------------------------------------------------------------------- /poc/theano_check1.py: -------------------------------------------------------------------------------- 1 | from theano import function, config, shared, sandbox 2 | import theano.tensor as T 3 | import numpy 4 | import time 5 | 6 | vlen = 10 * 30 * 768 # 10 x #cores x # threads per core 7 | iters = 1000 8 | 9 | rng = numpy.random.RandomState(22) 10 | x = shared(numpy.asarray(rng.rand(vlen), config.floatX)) 11 | f = function([], T.exp(x)) 12 | print(f.maker.fgraph.toposort()) 13 | t0 = time.time() 14 | for i in range(iters): 15 | r = f() 16 | t1 = time.time() 17 | print("Looping %d times took %f seconds" % (iters, t1 - t0)) 18 | print("Result is %s" % (r,)) 19 | if numpy.any([isinstance(x.op, T.Elemwise) for x in 
f.maker.fgraph.toposort()]): 20 | print('Used the cpu') 21 | else: 22 | print('Used the gpu') 23 | -------------------------------------------------------------------------------- /poc/time_test.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from algotrader.event.market_data import BarSize 4 | 5 | epoch = datetime.datetime.fromtimestamp(0) 6 | 7 | 8 | def get_bar_end_time(timestamp, bar_size): 9 | return get_bar_start_time(timestamp, bar_size) + bar_size * 1000 - 1 10 | 11 | 12 | def get_bar_start_time(timestamp, bar_size): 13 | if bar_size < BarSize.D1: 14 | return (int(timestamp / (bar_size * 1000)) * bar_size * 1000) 15 | else: 16 | dt = datetime.datetime.fromtimestamp(timestamp / 1000) 17 | dt = datetime.datetime(year=dt.year, month=dt.month, day=dt.day) 18 | next_ts = unix_time_millis(dt) 19 | return next_ts 20 | 21 | 22 | def unix_time_millis(dt): 23 | return int((dt - epoch).total_seconds() * 1000) 24 | 25 | 26 | def from_unix_time_millis(timestamp): 27 | return datetime.datetime.fromtimestamp(timestamp / 1000.0) 28 | 29 | 30 | dt = datetime.datetime.now() 31 | ts = unix_time_millis(dt) 32 | print ts, datetime.datetime.fromtimestamp(ts / 1000.0) 33 | 34 | bar_sizes = [ 35 | ("S1 ", BarSize.S1), 36 | ("S5 ", BarSize.S5), 37 | ("S15", BarSize.S15), 38 | ("S30", BarSize.S30), 39 | ("M1 ", BarSize.M1), 40 | ("M5 ", BarSize.M5), 41 | ("M15", BarSize.M15), 42 | ("M30", BarSize.M30), 43 | ("H1 ", BarSize.H1), 44 | ("D1 ", BarSize.D1) 45 | ] 46 | 47 | for key, value in bar_sizes: 48 | ts2_start = get_bar_start_time(ts, value) 49 | ts2_end = get_bar_end_time(ts, value) 50 | print key, ts2_start, from_unix_time_millis(ts2_start), ts2_end, from_unix_time_millis(ts2_end) 51 | -------------------------------------------------------------------------------- /poc/zerorpc_patch.py: -------------------------------------------------------------------------------- 1 | import msgpack 2 | from zerorpc.events 
import Event 3 | 4 | from poc.ser_deser import encode, decode 5 | 6 | 7 | def pack(self): 8 | return msgpack.Packer(default=encode).pack((self._header, self._name, self._args)) 9 | 10 | 11 | @staticmethod 12 | def unpack(blob): 13 | unpacker = msgpack.Unpacker(object_hook=decode) 14 | unpacker.feed(blob) 15 | unpacked_msg = unpacker.unpack() 16 | 17 | try: 18 | (header, name, args) = unpacked_msg 19 | except Exception as e: 20 | raise Exception('invalid msg format "{0}": {1}'.format( 21 | unpacked_msg, e)) 22 | 23 | # Backward compatibility 24 | if not isinstance(header, dict): 25 | header = {} 26 | 27 | return Event(name, args, None, header) 28 | 29 | 30 | Event.pack = pack 31 | Event.unpack = unpack 32 | -------------------------------------------------------------------------------- /proto/algotrader/model/market_data.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package algotrader.model; 4 | 5 | message Bar { 6 | enum Type { 7 | Time = 0; 8 | Tick = 1; 9 | Volume = 2; 10 | Dynamic = 3; 11 | } 12 | 13 | string inst_id = 1; 14 | string provider_id = 2; 15 | Type type = 3; 16 | int32 size = 4; 17 | int64 timestamp = 5; 18 | 19 | int64 utc_time = 6; 20 | int64 begin_time = 7; 21 | 22 | double open = 9; 23 | double high = 10; 24 | double low = 11; 25 | double close = 12; 26 | double vol = 13; 27 | double adj_close = 14; 28 | double open_interest = 15; 29 | } 30 | 31 | message Quote { 32 | 33 | string inst_id = 1; 34 | string provider_id = 2; 35 | int64 timestamp = 3; 36 | 37 | int64 utc_time = 4; 38 | 39 | double bid = 5; 40 | double bid_size = 6; 41 | double ask = 7; 42 | double ask_size = 8; 43 | } 44 | 45 | message Trade { 46 | 47 | string inst_id = 1; 48 | string provider_id = 2; 49 | int64 timestamp = 3; 50 | 51 | int64 utc_time = 4; 52 | 53 | double price = 5; 54 | double size = 6; 55 | } 56 | 57 | message MarketDepth { 58 | enum Side { 59 | Ask = 0; 60 | Bid = 1; 61 | } 62 | 63 | enum 
Operation { 64 | Insert = 0; 65 | Update = 1; 66 | Delete = 2; 67 | } 68 | 69 | string inst_id = 1; 70 | string provider_id = 2; 71 | int64 timestamp = 3; 72 | 73 | int64 utc_time = 4; 74 | 75 | string md_provider = 5; 76 | int64 position = 6; 77 | Operation operation = 7; 78 | Side side = 8; 79 | double price = 9; 80 | double size = 10; 81 | } 82 | 83 | 84 | message MarketDataSubscriptionRequest { 85 | enum MDType { 86 | Bar = 0; 87 | Trade = 1; 88 | Quote = 2; 89 | MarketDepth = 3; 90 | } 91 | MDType type = 1; 92 | string inst_id = 2; 93 | string feed_id = 3; 94 | string md_provider_id = 4; 95 | Bar.Type bar_type = 5; 96 | int32 bar_size = 6; 97 | 98 | int64 from_date = 7; 99 | int64 to_date = 8; 100 | } 101 | 102 | message BarAggregationRequest { 103 | enum InputType { 104 | Bar = 0; 105 | Trade = 1; 106 | Bid = 2; 107 | Ask = 3; 108 | BidAsk = 4; 109 | Middle = 5; 110 | Spread = 6; 111 | } 112 | 113 | string inst_id = 1; 114 | string provider_id = 2; 115 | InputType input_type = 3; 116 | int32 input_bar_size = 4; 117 | Bar.Type output_type = 5; 118 | int32 output_size = 6; 119 | 120 | } -------------------------------------------------------------------------------- /proto/algotrader/model/ref_data.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package algotrader.model; 4 | 5 | 6 | message Underlying { 7 | 8 | message Asset { 9 | string inst_id = 1; 10 | double weight = 2; 11 | } 12 | 13 | enum UnderlyingType { 14 | Single = 0; 15 | FixedWeightBasket = 1; 16 | WorstOfBasket = 2; 17 | BestOfBasket = 3; 18 | } 19 | 20 | UnderlyingType type = 1; 21 | repeated Asset assets = 2; 22 | 23 | } 24 | 25 | message Instrument { 26 | 27 | enum InstType { 28 | STK = 0; 29 | FUT = 1; 30 | OPT = 2; 31 | FOT = 3; 32 | IDX = 4; 33 | CASH = 5; 34 | ETF = 6; 35 | CBO = 7; 36 | } 37 | 38 | enum OptionType { 39 | Call = 0; 40 | Put = 1; 41 | } 42 | 43 | enum OptionStyle { 44 | European = 0; 45 | American = 1; 46 
| } 47 | 48 | //int64 inst_id = 1; // unique, generated by system 49 | string inst_id = 2; // uniqie symbol + '@' + primary_exch_id 50 | string symbol = 3; 51 | string name = 4; 52 | InstType type = 5; 53 | string primary_exch_id = 6; 54 | repeated string exch_ids = 7; 55 | string ccy_id = 8; 56 | string sector = 9; 57 | string industry = 10; 58 | double margin = 11; 59 | double tick_size = 12; 60 | 61 | //alt sym / ids 62 | map alt_symbols = 31; //map, e.g. IB -> 5 63 | map alt_ids = 32; //e.g. RIC -> 0005.HK 64 | map alt_sectors = 33; 65 | map alt_industries = 34; 66 | 67 | //derivatives 68 | Underlying underlying = 101; 69 | OptionType option_type = 102; 70 | OptionStyle option_style = 103; 71 | double strike = 104; 72 | int64 exp_date = 105; 73 | double multiplier = 106; 74 | 75 | } 76 | 77 | message Exchange { 78 | string exch_id = 1; //uniqie 79 | string name = 2; 80 | string country_id = 3; 81 | string trading_hours_id = 4; 82 | string holidays_id = 5; 83 | 84 | map alt_ids = 6; 85 | } 86 | 87 | message Country { 88 | string country_id = 1; //uniqie 89 | string name = 2; 90 | 91 | string holidays_id = 3; 92 | } 93 | 94 | message Currency { 95 | string ccy_id = 1; //uniqie 96 | string name = 2; 97 | } 98 | 99 | message HolidaySeries { 100 | message Holiday { 101 | int64 trading_date = 1; 102 | int64 start_date = 2; 103 | int64 start_time = 3; 104 | int64 end_date = 4; 105 | int64 end_time = 5; 106 | Type type = 6; 107 | string desc = 7; 108 | 109 | enum Type { 110 | FullDay = 0; 111 | LateOpen = 1; 112 | EarlyClose = 2; 113 | Replace = 3; 114 | Modify = 4; 115 | } 116 | } 117 | 118 | string holidays_id = 1; 119 | repeated Holiday holidays = 2; 120 | } 121 | 122 | message TradingHours { 123 | message Session { 124 | WeekDay start_weekdate = 1; 125 | int64 start_time = 2; 126 | WeekDay end_weekdate = 3; 127 | int64 end_time = 4; 128 | bool eod = 5; 129 | 130 | enum WeekDay { 131 | Sunday = 0; 132 | Monday = 1; 133 | Tuesday = 2; 134 | Wednesday = 3; 135 | 
Thursday = 4; 136 | Friday = 5; 137 | Saturday = 6; 138 | } 139 | } 140 | 141 | string trading_hours_id = 1; 142 | string timezone_id = 2; 143 | repeated Session sessions = 3; 144 | } 145 | 146 | message TimeZone { 147 | string timezone_id = 1; 148 | } -------------------------------------------------------------------------------- /proto/algotrader/model/time_series.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package algotrader.model; 4 | 5 | 6 | 7 | message TimeSeriesItem{ 8 | int64 timestamp = 1; 9 | map data = 2; 10 | } 11 | 12 | message TimeSeriesUpdateEvent{ 13 | string source = 1; 14 | TimeSeriesItem item= 2; 15 | } 16 | 17 | message TimeSeries{ 18 | message Input{ 19 | string source = 1; 20 | repeated string keys = 2; 21 | } 22 | 23 | string series_id = 1; 24 | string series_cls = 2; 25 | repeated string keys = 3; 26 | string desc = 4; 27 | repeated Input inputs = 5; 28 | string default_output_key = 6; 29 | double missing_value_replace = 7; 30 | int64 start_time = 8; 31 | int64 end_time = 9; 32 | repeated TimeSeriesItem items = 10; 33 | map configs = 11; 34 | } 35 | 36 | -------------------------------------------------------------------------------- /proto/algotrader/model/time_series2.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package algotrader.model2; 4 | 5 | enum DataType { 6 | DTFloat = 0; 7 | DTDouble = 1; 8 | DTInt32 = 2; 9 | DTInt64 = 3; 10 | DTBool = 4; 11 | DTString = 5; 12 | DTByteArray = 6; 13 | } 14 | 15 | message Series { 16 | 17 | 18 | string series_id = 1; 19 | string df_id = 2; 20 | string col_id = 3; 21 | string inst_id = 4; 22 | DataType dtype = 5; 23 | 24 | // int64 start_time = 5; 25 | // int64 end_time = 6; 26 | //int64 length = 7; 27 | 28 | repeated int64 index = 11; 29 | 30 | repeated float float_data = 20; 31 | repeated double double_data = 21; 32 | repeated int32 int32_data = 22; 33 
33 | repeated int64 int64_data = 23; 34 | repeated bool bool_data = 24; 35 | repeated string string_data = 25; 36 | repeated bytes bytes_data = 26; 37 | 38 | } 39 | 40 | 41 | message TimeSeriesItem { 42 | int64 timestamp = 1; 43 | map data = 2; 44 | } 45 | 46 | message TimeSeriesUpdateEvent { 47 | string source = 1; 48 | TimeSeriesItem item = 2; 49 | } 50 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | AlgoTrader 2 | =========================== 3 | 4 | AlgoTrader is an **event-driven algorithmic trading system**. 5 | 6 | There is a Python version and a Java version. Both share the same API. 7 | 8 | The Java version is designed to support ultra-low-latency trading and the FIX protocol. 9 | It can process 10M+ events per second on commodity PC hardware with an i7 CPU. 10 | 11 | The Python version supports quick modelling / testing. 12 | 13 | 14 | Main Features (Python version) 15 | ------------------------------ 16 | 17 | * Event driven. 18 | * Supports Market, Limit, Stop and StopLimit orders. 19 | * Supports CSV feeds (Yahoo format). 20 | * Supports backtesting using the Simulator, which supports different fill strategies, slippage configs and commission settings. 21 | * Supports live trading using Interactive Brokers. 22 | * Technical indicators and filters like SMA, RSI, Bollinger Bands. 23 | * Performance metrics (using pyfolio). 24 | 25 | 26 | TODO 27 | ---- 28 | 29 | * Persistence: save the portfolio, account and results into the DB, and load them back at system startup. 30 | * Bar Factory: aggregate bars from lower-time-frame bars or from quotes / trades. 31 | * Supports more CSV formats, e.g. Google Finance, Quandl and NinjaTrader. 32 | * Supports more data feeds, e.g. Cassandra, InfluxDB, KDB. 33 | * Save real-time data and persist it into various data stores (e.g. CSV, Cassandra, Influx, KDB). 34 | * Aggregated TimeSeries (multiple key-value pairs). 35 | * Event profiler.
36 | * TA-Lib integration, supporting more TA indicators. 37 | * HTML5 UI to view the account and portfolio, control strategies and view performance (real-time and historical). 38 | * Multiple currencies, timezones and trading sessions. 39 | * Trading context, which includes strategy config and data subscription info. 40 | * Supports Machine Learning libraries, e.g. Theano. 41 | * Supports Spark. 42 | * Supports parallel processing for optimization and backtesting. Results should be persisted into the DB and be viewable in the HTML5 UI. 43 | * Supports FIX workflow: refactor order into Order and OrderEvent (Order contains state, with different order events: NewOrderRequest, OrderCancelRequest, OrderCancelReplaceRequest, OrderStatusRequest; and execution events: ExecutionReport, OrderStatus). 44 | * Remote OMS (strategies can run separately and send order requests to a remote OrderServer). -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rx>=1.5.0 2 | gevent 3 | msgpack-python 4 | 5 | pandas>=0.18.1 6 | numpy>=1.11.1 7 | scipy>=0.18.0 8 | scikit-learn>=0.17.1 9 | 10 | swigibpy 11 | 12 | cassandra-driver 13 | cassandra-driver-dse 14 | influxdb 15 | qpython 16 | pymongo 17 | 18 | pykalman 19 | ta-lib 20 | jinja2 21 | pyfolio 22 | 23 | nose 24 | nose_parameterized -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/cassandra/algotrader.cql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS bars ( 2 | inst_id varchar, 3 | type int, 4 | size int, 5 |
begin_time bigint, 6 | timestamp bigint, 7 | open double, 8 | high double, 9 | low double, 10 | close double, 11 | vol bigint, 12 | adj_close double, 13 | PRIMARY KEY ((inst_id, type, size), timestamp) 14 | ); 15 | 16 | CREATE TABLE IF NOT EXISTS quotes( 17 | inst_id varchar, 18 | timestamp bigint, 19 | bid double, 20 | ask double, 21 | bid_size int, 22 | ask_size int, 23 | PRIMARY KEY (inst_id, timestamp) 24 | ); 25 | 26 | CREATE TABLE IF NOT EXISTS trades( 27 | inst_id varchar, 28 | timestamp bigint, 29 | price double, 30 | size int, 31 | PRIMARY KEY (inst_id, timestamp) 32 | ); 33 | 34 | CREATE TABLE IF NOT EXISTS market_depths( 35 | inst_id varchar, 36 | provider_id varchar, 37 | timestamp bigint, 38 | position bigint, 39 | operation int, 40 | side int, 41 | price double, 42 | size int, 43 | PRIMARY KEY ((inst_id, provider_id), timestamp) 44 | ); 45 | 46 | CREATE TABLE IF NOT EXISTS time_series( 47 | id varchar, 48 | data blob, 49 | PRIMARY KEY (id) 50 | ); 51 | 52 | CREATE TABLE IF NOT EXISTS instruments( 53 | inst_id varchar, 54 | name varchar, 55 | type varchar, 56 | symbol varchar, 57 | exch_id varchar, 58 | ccy_id varchar, 59 | alt_symbols map, 60 | alt_exch_id map, 61 | sector varchar, 62 | industry varchar, 63 | und_inst_id varchar, 64 | expiry_date timestamp, 65 | factor double, 66 | strike double, 67 | put_call varchar, 68 | margin double, 69 | PRIMARY KEY (inst_id) 70 | ); 71 | 72 | CREATE TABLE IF NOT EXISTS exchanges( 73 | exch_id varchar, 74 | code varchar, 75 | name varchar, 76 | PRIMARY KEY (exch_id) 77 | ); 78 | 79 | CREATE TABLE IF NOT EXISTS currencies( 80 | ccy_id varchar, 81 | name varchar, 82 | PRIMARY KEY (ccy_id) 83 | ); 84 | 85 | CREATE TABLE IF NOT EXISTS accounts( 86 | id varchar, 87 | data blob, 88 | PRIMARY KEY (id) 89 | ); 90 | 91 | 92 | CREATE TABLE IF NOT EXISTS portfolios( 93 | id varchar, 94 | data blob, 95 | PRIMARY KEY (id) 96 | ); 97 | 98 | 99 | CREATE TABLE IF NOT EXISTS orders( 100 | id varchar, 101 | data blob, 102 | 
PRIMARY KEY (id) 103 | ); 104 | 105 | 106 | CREATE TABLE IF NOT EXISTS configs( 107 | id varchar, 108 | data blob, 109 | PRIMARY KEY (id) 110 | ); 111 | 112 | CREATE TABLE IF NOT EXISTS strategies( 113 | id varchar, 114 | data blob, 115 | PRIMARY KEY (id) 116 | ); 117 | 118 | CREATE TABLE IF NOT EXISTS account_updates( 119 | id varchar, 120 | data blob, 121 | PRIMARY KEY (id) 122 | ); 123 | 124 | 125 | CREATE TABLE IF NOT EXISTS portfolio_updates( 126 | id varchar, 127 | data blob, 128 | PRIMARY KEY (id) 129 | ); 130 | 131 | 132 | CREATE TABLE IF NOT EXISTS new_order_reqs( 133 | id varchar, 134 | data blob, 135 | PRIMARY KEY (id) 136 | ); 137 | 138 | 139 | CREATE TABLE IF NOT EXISTS ord_cancel_reqs( 140 | id varchar, 141 | data blob, 142 | PRIMARY KEY (id) 143 | ); 144 | 145 | 146 | CREATE TABLE IF NOT EXISTS ord_replace_reqs( 147 | id varchar, 148 | data blob, 149 | PRIMARY KEY (id) 150 | ); 151 | 152 | 153 | CREATE TABLE IF NOT EXISTS exec_reports( 154 | id varchar, 155 | data blob, 156 | PRIMARY KEY (id) 157 | ); 158 | 159 | 160 | CREATE TABLE IF NOT EXISTS ord_status_upds( 161 | id varchar, 162 | data blob, 163 | PRIMARY KEY (id) 164 | ); 165 | CREATE TABLE IF NOT EXISTS sequences( 166 | id varchar, 167 | seq bigint, 168 | PRIMARY KEY (id) 169 | ); 170 | 171 | -------------------------------------------------------------------------------- /scripts/eoddata_symbol_importer.py: -------------------------------------------------------------------------------- 1 | from scripts.ib_inst_utils import init_ib, import_inst_from_ib, app_context 2 | import time 3 | 4 | file_name = '../data/refdata/eoddata/HKEX.txt' 5 | 6 | 7 | app_context = app_context() 8 | broker = init_ib(app_context) 9 | 10 | f = open(file_name) 11 | count = 0 12 | for line in f: 13 | count += 1 14 | if count == 1: 15 | continue 16 | idx = line.find('\t') 17 | if idx >0: 18 | symbol = int(line[0:idx]) 19 | desc = line[idx+1:-1] 20 | else: 21 | symbol = line 22 | desc =None 23 | 24 | #print "symbol=%s, 
desc=%s" % (symbol, desc) 25 | 26 | import_inst_from_ib(broker=broker, symbol=str(symbol), exchange='SEHK') 27 | 28 | 29 | 30 | for inst in app_context.ref_data_mgr.get_all_insts(): 31 | print inst -------------------------------------------------------------------------------- /scripts/event_calendar_downloader.py: -------------------------------------------------------------------------------- 1 | # This will later become a provider of event 2 | # TODO: move to provider 3 | import urllib2 4 | import datetime 5 | 6 | # orig_link = 'https://www.dailyfx.com/files/Calendar-10-16-2016.xls' 7 | template = 'https://www.dailyfx.com/files/Calendar-%s-%s-%s.xls' 8 | 9 | 10 | last_date = datetime.date(2016,10,16) 11 | missing_dates = [] 12 | 13 | 14 | curr_date = last_date 15 | 16 | 17 | 18 | 19 | def download(date): 20 | date_str = date.isoformat() 21 | url = template % (date_str[5:7], date_str[8:10], date_str[:4]) 22 | try: 23 | xls_file = urllib2.urlopen(url) 24 | print "downloaded %s" % date 25 | to_file = '/Volumes/Transcend/data/RawCalendar/%s.xls' % date 26 | with open(to_file, 'wb') as outfile: 27 | outfile.write(xls_file.read()) 28 | except: 29 | print "404, can't download for %s" % date 30 | missing_dates.append(date) 31 | 32 | 33 | while curr_date > datetime.date(2000,1,1): 34 | download(curr_date) 35 | curr_date = curr_date + datetime.timedelta(days=-7) 36 | -------------------------------------------------------------------------------- /scripts/gen_proto.sh: -------------------------------------------------------------------------------- 1 | protoc --python_out=../ --proto_path=../proto/ ../proto/algotrader/model/*.proto 2 | -------------------------------------------------------------------------------- /scripts/ib_inst_utils.py: -------------------------------------------------------------------------------- 1 | from gevent import monkey 2 | from gevent.event import AsyncResult 3 | 4 | monkey.patch_all() 5 | 6 | from algotrader.trading.config import Config, 
load_from_yaml 7 | from algotrader.trading.context import ApplicationContext 8 | from algotrader.utils.logging import logger 9 | # from algotrader.config.app import ApplicationConfig 10 | # from algotrader.config.broker import IBConfig 11 | # from algotrader.config.persistence import MongoDBConfig 12 | # from algotrader.config.persistence import PersistenceConfig 13 | from algotrader.provider.broker import Broker 14 | from algotrader.provider.datastore import PersistenceMode 15 | from algotrader.provider.datastore import DataStore 16 | from algotrader.trading.context import ApplicationContext 17 | from algotrader.trading.ref_data import RefDataManager 18 | from algotrader.trading.clock import Clock 19 | # from algotrader.utils import logger 20 | 21 | 22 | # def app_context(): 23 | # persistence_config = PersistenceConfig(None, 24 | # DataStore.Mongo, PersistenceMode.RealTime, 25 | # DataStore.Mongo, PersistenceMode.RealTime, 26 | # DataStore.Mongo, PersistenceMode.RealTime, 27 | # DataStore.Mongo, PersistenceMode.RealTime) 28 | # app_config = ApplicationConfig(id=None, ref_data_mgr_type=RefDataManager.DB, clock_type=Clock.RealTime, 29 | # persistence_config=persistence_config, 30 | # provider_configs=[MongoDBConfig(), IBConfig(client_id=2, use_gevent=True)]) 31 | # app_context = ApplicationContext(app_config=app_config) 32 | # 33 | # return app_context 34 | 35 | config = Config( 36 | load_from_yaml("../config/data_import.yaml")) 37 | 38 | app_context = ApplicationContext(config=config) 39 | 40 | def init_ib(app_context): 41 | 42 | app_context.start() 43 | broker = app_context.provider_mgr.get(Broker.IB) 44 | broker.start(app_context) 45 | return broker 46 | 47 | 48 | 49 | def import_inst_from_ib(broker, symbol, sec_type='STK', exchange=None, currency=None): 50 | try: 51 | result = AsyncResult() 52 | logger.info("importing symbol %s" % symbol) 53 | broker.reqContractDetails(symbol=symbol, sec_type=sec_type, exchange=exchange, currency=currency, callback=result) 54 | 
# broker.reqScannerSubscription(inst_type='STK', location_code='STK.US', scan_code='TOP_PERC_GAIN', above_vol=1000000, callback=callback) 55 | 56 | logger.info("done %s %s" % (symbol, result.get(timeout=3))) 57 | except Exception as e: 58 | logger.error("faile to import %s", symbol, e) 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /scripts/kdb/bar.q: -------------------------------------------------------------------------------- 1 | bar:([] 2 | / date:`date$(); 3 | / time:`time$(); 4 | instId:`int$(); 5 | size:`int$(); 6 | datetime:`long$(); 7 | open:`float$(); 8 | high:`float$(); 9 | low:`float$(); 10 | close:`float$(); 11 | vol:`int$(); 12 | openInt:`int$()) 13 | 14 | -------------------------------------------------------------------------------- /scripts/kdb/quote.q: -------------------------------------------------------------------------------- 1 | quote:([] 2 | / date:`date$(); 3 | / time:`time$(); 4 | instId:`int$(); 5 | datetime:`long$(); 6 | bid:`float$(); 7 | ask:`float$(); 8 | bsize:`int$(); 9 | asize:`int$()) 10 | 11 | -------------------------------------------------------------------------------- /scripts/kdb/trade.q: -------------------------------------------------------------------------------- 1 | trade:([] 2 | / date:`date$(); 3 | / time:`time$(); 4 | instId:`int$(); 5 | datetime:`long$(); 6 | price:`float$(); 7 | size:`int$()) -------------------------------------------------------------------------------- /scripts/netfronds_tickdata.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 10/21/16 3 | Author = jchan 4 | """ 5 | __author__ = 'jchan' 6 | -------------------------------------------------------------------------------- /scripts/start_mongo.sh: -------------------------------------------------------------------------------- 1 | mongod --dbpath /mnt/data/dev/data/mongo 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from algotrader.trading.config import Config 2 | 3 | test_override = { 4 | "Application": { 5 | "dataStoreId": "InMemory", 6 | "createDBAtStart": True, 7 | "deleteDBAtStop": False, 8 | "plot": False 9 | }, 10 | "DataStore": {"InMemory": 11 | { 12 | "file": "../data/algotrader_backtest_db.p", 13 | "instCSV": "../data/refdata/instrument.csv", 14 | "ccyCSV": "../data/refdata/ccy.csv", 15 | "exchCSV": "../data/refdata/exch.csv" 16 | } 17 | }, 18 | "Feed": {"CSV": 19 | {"path": "/mnt/data/dev/workspaces/python-trading/data/tradedata"} 20 | } 21 | } 22 | 23 | config = Config(test_override) 24 | 25 | empty_config = Config({ 26 | "Application": { 27 | "dataStoreId": "InMemory", 28 | "createDBAtStart": True, 29 | "deleteDBAtStop": False 30 | }, 31 | "DataStore": {"InMemory": 32 | { 33 | "file": "../data/algotrader_backtest_db.p", 34 | } 35 | }, 36 | "Feed": {"CSV": 37 | {"path": "/mnt/data/dev/workspaces/python-trading/data/tradedata"} 38 | } 39 | }) 40 | -------------------------------------------------------------------------------- /tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexcwyu/python-trading/a494f602411a3ebfdecae002a16a5ea93fc7a046/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /tests/integration_tests/test_persistence_mongo.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | from unittest import TestCase 3 | 4 | from algotrader.utils.protobuf_to_dict import * 5 | from tests.sample_factory import * 6 | 7 | 8 | class MongoPersistenceTest(TestCase): 9 | host = "localhost" 10 | port = 27017 11 | dbname = "test" 12 | client = None 13 | db 
= None 14 | 15 | @classmethod 16 | def setUpClass(cls): 17 | cls.client = MongoClient(host=cls.host, port=cls.port) 18 | cls.db = cls.client[cls.dbname] 19 | 20 | cls.tests = cls.db['tests'] 21 | cls.factory = SampleFactory() 22 | 23 | @classmethod 24 | def tearDownClass(cls): 25 | cls.client.drop_database(cls.dbname) 26 | 27 | def setUp(self): 28 | pass 29 | 30 | def tearDown(self): 31 | MongoPersistenceTest.tests.remove() 32 | 33 | def test_instrument(self): 34 | inst = MongoPersistenceTest.factory.sample_instrument() 35 | self.__test_persistence(Instrument, inst) 36 | 37 | def test_exchange(self): 38 | exchange = self.factory.sample_exchange() 39 | self.__test_persistence(Exchange, exchange) 40 | 41 | def test_currency(self): 42 | currency = self.factory.sample_currency() 43 | self.__test_persistence(Currency, currency) 44 | 45 | def test_country(self): 46 | country = self.factory.sample_country() 47 | self.__test_persistence(Country, country) 48 | 49 | def test_trading_holidays(self): 50 | trading_holiday = self.factory.sample_trading_holidays() 51 | self.__test_persistence(HolidaySeries, trading_holiday) 52 | 53 | def test_trading_hours(self): 54 | trading_hours = self.factory.sample_trading_hours() 55 | self.__test_persistence(TradingHours, trading_hours) 56 | 57 | def test_timezone(self): 58 | timezone = self.factory.sample_timezone() 59 | self.__test_persistence(TimeZone, timezone) 60 | 61 | def test_time_series(self): 62 | ds = self.factory.sample_time_series() 63 | self.__test_persistence(TimeSeries, ds) 64 | 65 | def test_bar(self): 66 | self.__test_persistence(Bar, self.factory.sample_bar()) 67 | 68 | def test_quote(self): 69 | self.__test_persistence(Quote, self.factory.sample_quote()) 70 | 71 | def test_trade(self): 72 | self.__test_persistence(Trade, self.factory.sample_trade()) 73 | 74 | def test_market_depth(self): 75 | self.__test_persistence(MarketDepth, self.factory.sample_market_depth()) 76 | 77 | def test_new_order_request(self): 78 | 
self.__test_persistence(NewOrderRequest, self.factory.sample_new_order_request()) 79 | 80 | def test_order_replace_request(self): 81 | self.__test_persistence(OrderReplaceRequest, self.factory.sample_order_replace_request()) 82 | 83 | def test_order_cancel_request(self): 84 | self.__test_persistence(OrderCancelRequest, self.factory.sample_order_cancel_request()) 85 | 86 | def test_order_status_update(self): 87 | self.__test_persistence(OrderStatusUpdate, self.factory.sample_order_status_update()) 88 | 89 | def test_execution_report(self): 90 | self.__test_persistence(ExecutionReport, self.factory.sample_execution_report()) 91 | 92 | def test_account_update(self): 93 | self.__test_persistence(AccountUpdate, self.factory.sample_account_update()) 94 | 95 | def test_portfolio_update(self): 96 | self.__test_persistence(PortfolioUpdate, self.factory.sample_portfolio_update()) 97 | 98 | def test_account_state(self): 99 | self.__test_persistence(AccountState, self.factory.sample_account_state()) 100 | 101 | def test_portfolio_state(self): 102 | self.__test_persistence(PortfolioState, self.factory.sample_portfolio_state()) 103 | 104 | def test_strategy_state(self): 105 | self.__test_persistence(StrategyState, self.factory.sample_strategy_state()) 106 | 107 | def test_order_state(self): 108 | self.__test_persistence(OrderState, self.factory.sample_order_state()) 109 | 110 | def test_sequence(self): 111 | self.__test_persistence(Sequence, self.factory.sample_sequence()) 112 | 113 | def __test_persistence(self, cls, obj): 114 | data = protobuf_to_dict(obj) 115 | MongoPersistenceTest.tests.update({'_id': 1}, data, upsert=True) 116 | result = MongoPersistenceTest.tests.find_one({"_id": 1}) 117 | del result['_id'] 118 | new_obj = dict_to_protobuf(cls, result) 119 | self.assertEqual(obj, new_obj) 120 | -------------------------------------------------------------------------------- /tests/test_broker_mgr.py: 
-------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.provider.broker import Broker 4 | from algotrader.provider import ProviderManager 5 | 6 | 7 | class BrokerManagerTest(TestCase): 8 | def test_reg(self): 9 | bm = ProviderManager() 10 | self.assertIsNotNone(bm.get(Broker.Simulator)) 11 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.trading.config import Config, load_from_yaml 4 | 5 | 6 | class ConfigTest(TestCase): 7 | def test_multiple(self): 8 | config = Config( 9 | load_from_yaml("../config/backtest.yaml"), 10 | load_from_yaml("../config/down2%.yaml")) 11 | self.assertEquals(1, config.get_strategy_config("down2%", "qty")) 12 | -------------------------------------------------------------------------------- /tests/test_data_series_utils.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.utils.data_series import * 4 | 5 | 6 | class DataSeriesUtilsTest(TestCase): 7 | def test_build_series_name(self): 8 | name = build_series_id("SMA") 9 | self.assertEqual("SMA()", name) 10 | 11 | name = build_series_id("SMA", inputs='Bar.HSI') 12 | self.assertEqual("SMA(Bar.HSI)", name) 13 | 14 | name = build_series_id("SMA", input_keys='Close') 15 | self.assertEqual("SMA()", name) 16 | 17 | name = build_series_id("SMA", length=10, vol=2) 18 | self.assertTrue(name in set(["SMA(length=10,vol=2)", "SMA(vol=2,length=10)"])) 19 | 20 | name = build_series_id("SMA", inputs='Bar.HSI', input_keys='Close', length=10) 21 | self.assertEqual("SMA(Bar.HSI[Close],length=10)", name) 22 | 23 | name = build_series_id("SMA", inputs='Bar.HSI', input_keys=['Close', 'Open'], length=10) 24 | 
self.assertEqual("SMA(Bar.HSI[Close,Open],length=10)", name) 25 | 26 | name = build_series_id("SMA", inputs=['Bar.HSI', 'Quote.SPX'], input_keys=['Close', 'Open'], length=1) 27 | 28 | self.assertEqual("SMA(Bar.HSI[Close,Open],Quote.SPX[Close,Open],length=1)", name) 29 | 30 | name = build_series_id("SMA", inputs='SMA(Bar.HSI[Close],length=10)', length=5) 31 | self.assertEqual("SMA(SMA(Bar.HSI[Close],length=10),length=5)", name) 32 | -------------------------------------------------------------------------------- /tests/test_feed.py: -------------------------------------------------------------------------------- 1 | from nose_parameterized import parameterized, param 2 | from unittest import TestCase 3 | 4 | from algotrader.trading.context import ApplicationContext 5 | from algotrader.trading.event import EventLogger 6 | from algotrader.utils.market_data import * 7 | from tests import config 8 | 9 | params = [ 10 | param('CSV', ['Bar.Yahoo.Time.D1']), 11 | param('PandasWeb', ['Bar.Google.Time.D1']), 12 | param('PandasWeb', ['Bar.Yahoo.Time.D1']) 13 | ] 14 | 15 | 16 | class FeedTest(TestCase): 17 | @parameterized.expand(params) 18 | def test_loaded_bar(self, feed_id, subscription_types): 19 | app_context = ApplicationContext(config=config) 20 | app_context.start() 21 | 22 | feed = app_context.provider_mgr.get(feed_id) 23 | feed.start(app_context) 24 | 25 | # logger.setLevel(logging.DEBUG) 26 | eventLogger = EventLogger() 27 | eventLogger.start(app_context) 28 | 29 | instruments = app_context.ref_data_mgr.get_insts_by_ids(["SPY@NYSEARCA"]) 30 | for sub_req in build_subscription_requests(feed_id, instruments, 31 | subscription_types, 32 | 20100101, 33 | 20170101): 34 | feed.subscribe_mktdata(sub_req) 35 | 36 | self.assertTrue(eventLogger.count[Bar] > 0) 37 | self.assertTrue(eventLogger.count[Trade] == 0) 38 | self.assertTrue(eventLogger.count[Quote] == 0) 39 | -------------------------------------------------------------------------------- /tests/test_in_memory_db.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | from unittest import TestCase 4 | 5 | from algotrader.model.model_factory import ModelFactory 6 | from algotrader.provider.datastore.inmemory import InMemoryDataStore 7 | from algotrader.trading.context import ApplicationContext 8 | from tests import empty_config 9 | 10 | 11 | class InMemoryDBTest(TestCase): 12 | def setUp(self): 13 | 14 | self.app_context = ApplicationContext(config=empty_config) 15 | self.app_context.start() 16 | 17 | self.db = InMemoryDataStore() 18 | self.db.start(self.app_context) 19 | 20 | def tearDown(self): 21 | self.db.remove_database() 22 | 23 | def test_save_and_load(self): 24 | inputs = [] 25 | for x in range(0, 10): 26 | data = sorted([random.randint(0, 100) for i in range(0, 4)]) 27 | bar = ModelFactory.build_bar(timestamp=x, inst_id="3", open=data[1], high=data[3], low=data[0], 28 | close=data[2], 29 | vol=random.randint(100, 1000)) 30 | inputs.append(bar) 31 | self.db.save_bar(bar) 32 | 33 | self.db.stop() 34 | 35 | self.db = InMemoryDataStore() 36 | self.db.start(self.app_context) 37 | 38 | bars = self.db.load_all('bars') 39 | bars = sorted(bars, key=lambda x: x.timestamp, reverse=False) 40 | self.assertEquals(10, len(bars)) 41 | 42 | for x in range(0, 10): 43 | self.assertEquals(inputs[x], bars[x]) 44 | -------------------------------------------------------------------------------- /tests/test_indicator.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.trading.context import ApplicationContext 4 | from algotrader.utils.indicator import parse_series, get_or_create_indicator 5 | from algotrader.technical.ma import SMA 6 | 7 | class IndicatorTest(TestCase): 8 | def setUp(self): 9 | self.app_context = ApplicationContext() 10 | 11 | def test_reuse(self): 12 | close = self.app_context.inst_data_mgr.get_series("bar") 13 | 
close.start(self.app_context) 14 | 15 | 16 | sma1 = get_or_create_indicator(self.app_context.inst_data_mgr, cls=SMA, inputs='bar', input_keys='close', 17 | length=3) 18 | sma1.start(self.app_context) 19 | 20 | sma2 = get_or_create_indicator(self.app_context.inst_data_mgr, cls=SMA, inputs='bar', input_keys='close', 21 | length=3) 22 | sma2.start(self.app_context) 23 | 24 | sma3 = get_or_create_indicator(self.app_context.inst_data_mgr, cls=SMA, inputs='bar', input_keys='close', 25 | length=10) 26 | sma3.start(self.app_context) 27 | 28 | self.assertEquals(sma1, sma2) 29 | self.assertNotEquals(sma2, sma3) 30 | self.assertNotEquals(sma1, sma3) 31 | 32 | sma4 = get_or_create_indicator(self.app_context.inst_data_mgr, cls=SMA, inputs=sma3, length=10) 33 | sma4.start(self.app_context) 34 | 35 | self.assertEquals(sma4.input_series[0], sma3) 36 | 37 | # def test_parse(self): 38 | # bar = parse_series(self.app_context.inst_data_mgr, "bar") 39 | # bar.start(self.app_context) 40 | # 41 | # sma1 = parse_series(self.app_context.inst_data_mgr, "SMA(bar[close],length=3)") 42 | # sma1.start(self.app_context) 43 | # 44 | # sma2 = parse_series(self.app_context.inst_data_mgr, "SMA(SMA(bar[close],length=3)[value],length=10)") 45 | # sma2.start(self.app_context) 46 | # 47 | # rsi = parse_series(self.app_context.inst_data_mgr, "RSI(SMA(SMA('bar',close,3),value,10),value,14, 9)") 48 | # rsi.start(self.app_context) 49 | # 50 | # self.assertEquals(sma1.input, bar) 51 | # self.assertEquals(3, sma1.length) 52 | # 53 | # self.assertEquals(sma2.input, sma1) 54 | # self.assertEquals(10, sma2.length) 55 | # 56 | # self.assertEquals(rsi.input, sma2) 57 | # self.assertEquals(14, rsi.length) 58 | # 59 | # def test_fail_parse(self): 60 | # with self.assertRaises(AssertionError): 61 | # parse_series(self.app_context.inst_data_mgr, "SMA('Bar.Close',3") 62 | # 63 | # with self.assertRaises(AssertionError): 64 | # parse_series(self.app_context.inst_data_mgr, "RSI(SMA(SMA('Bar.Close',3,10),14)") 65 | 
-------------------------------------------------------------------------------- /tests/test_instrument_data.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.model.model_factory import * 4 | from algotrader.trading.instrument_data import InstrumentDataManager 5 | 6 | 7 | class InstrumentDataTest(TestCase): 8 | def setUp(self): 9 | self.inst_data_mgr = InstrumentDataManager() 10 | 11 | def test_get_bar(self): 12 | bar = self.inst_data_mgr.get_bar(1) 13 | self.assertIsNone(bar) 14 | 15 | bar1 = ModelFactory.build_bar(timestamp=0, inst_id="1", open=20, high=21, low=19, close=20.5) 16 | self.inst_data_mgr.on_bar(bar1) 17 | bar = self.inst_data_mgr.get_bar("1") 18 | self.assertEqual(bar1, bar) 19 | 20 | def test_get_quote(self): 21 | quote = self.inst_data_mgr.get_quote("1") 22 | self.assertIsNone(quote) 23 | 24 | quote1 = ModelFactory.build_quote(timestamp=0, inst_id="1", bid=18, ask=19, bid_size=200, ask_size=500) 25 | self.inst_data_mgr.on_quote(quote1) 26 | quote = self.inst_data_mgr.get_quote("1") 27 | self.assertEqual(quote1, quote) 28 | 29 | def test_get_trade(self): 30 | trade = self.inst_data_mgr.get_trade("1") 31 | self.assertIsNone(trade) 32 | 33 | trade1 = ModelFactory.build_trade(timestamp=0, inst_id="1", price=20, size=200) 34 | self.inst_data_mgr.on_trade(trade1) 35 | trade = self.inst_data_mgr.get_trade("1") 36 | self.assertEqual(trade1, trade) 37 | 38 | def get_latest_price(self): 39 | price = self.inst_data_mgr.get_latest_price(1) 40 | self.assertIsNone(price) 41 | 42 | bar1 = ModelFactory.build_bar(timestamp=0, inst_id="1", open=20, high=21, low=19, close=20.5) 43 | self.inst_data_mgr.on_bar(bar1) 44 | price = self.inst_data_mgr.get_latest_price(1) 45 | self.assertEqual(20.5, price) 46 | 47 | bar1 = ModelFactory.build_bar(timestamp=0, inst_id="1", open=20, high=21, low=19, close=20.5, adj_close=22) 48 | self.inst_data_mgr.on_bar(bar1) 49 | price = 
self.inst_data_mgr.get_latest_price(1) 50 | self.assertEqual(22, price) 51 | 52 | quote1 = ModelFactory.build_quote(timestamp=0, inst_id="1", bid=18, ask=19, bid_size=200, ask_size=500) 53 | self.inst_data_mgr.on_quote(quote1) 54 | price = self.inst_data_mgr.get_latest_price(1) 55 | self.assertEqual(18.5, price) 56 | 57 | trade1 = ModelFactory.build_trade(timestamp=0, inst_id="1", price=20, size=200) 58 | self.inst_data_mgr.on_bar(trade1) 59 | price = self.inst_data_mgr.get_latest_price(1) 60 | self.assertEqual(20, price) 61 | -------------------------------------------------------------------------------- /tests/test_ma.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from unittest import TestCase 5 | 6 | from algotrader.technical.ma import SMA 7 | from algotrader.trading.context import ApplicationContext 8 | 9 | 10 | class MovingAverageTest(TestCase): 11 | def setUp(self): 12 | self.app_context = ApplicationContext() 13 | 14 | def test_name(self): 15 | bar = self.app_context.inst_data_mgr.get_series("bar") 16 | sma = SMA(inputs=bar, input_keys='close', length=3) 17 | self.assertEquals("SMA(bar[close],length=3)", sma.name) 18 | 19 | sma2 = SMA(inputs=sma, input_keys='value', length=10) 20 | self.assertEquals("SMA(SMA(bar[close],length=3)[value],length=10)", sma2.name) 21 | 22 | def test_empty_at_initialize(self): 23 | close = self.app_context.inst_data_mgr.get_series("bar") 24 | sma = SMA(inputs=close, input_keys='close', length=3) 25 | self.assertEquals(0, len(sma.get_data())) 26 | 27 | def test_nan_before_size(self): 28 | bar = self.app_context.inst_data_mgr.get_series("bar") 29 | bar.start(self.app_context) 30 | 31 | sma = SMA(inputs=bar, input_keys='close', length=3) 32 | sma.start(self.app_context) 33 | 34 | t1 = 1 35 | t2 = t1 + 3 36 | t3 = t2 + 3 37 | 38 | bar.add(timestamp=t1, data={"close": 2.0, "open": 0}) 39 | self.assertEquals([{'value': np.nan}], 40 | sma.get_data()) 41 
| 42 | bar.add(timestamp=t2, data={"close": 2.4, "open": 1.4}) 43 | self.assertEquals([{'value': np.nan}, 44 | {'value': np.nan}], 45 | sma.get_data()) 46 | 47 | bar.add(timestamp=t3, data={"close": 2.8, "open": 1.8}) 48 | self.assertEquals([{'value': np.nan}, 49 | {'value': np.nan}, 50 | {'value': 2.4}], 51 | sma.get_data()) 52 | 53 | def test_moving_average_calculation(self): 54 | bar = self.app_context.inst_data_mgr.get_series("bar") 55 | bar.start(self.app_context) 56 | 57 | sma = SMA(inputs=bar, input_keys='close', length=3) 58 | sma.start(self.app_context) 59 | 60 | t1 = 1 61 | t2 = t1 + 3 62 | t3 = t2 + 3 63 | t4 = t3 + 3 64 | t5 = t4 + 3 65 | 66 | bar.add(data={"timestamp": t1, "close": 2.0, "open": 0}) 67 | self.assertTrue(math.isnan(sma.now('value'))) 68 | 69 | bar.add(data={"timestamp": t2, "close": 2.4, "open": 1.4}) 70 | self.assertTrue(math.isnan(sma.now('value'))) 71 | 72 | bar.add(data={"timestamp": t3, "close": 2.8, "open": 1.8}) 73 | self.assertEquals(2.4, sma.now('value')) 74 | 75 | bar.add(data={"timestamp": t4, "close": 3.2, "open": 2.2}) 76 | self.assertEquals(2.8, sma.now('value')) 77 | 78 | bar.add(data={"timestamp": t5, "close": 3.6, "open": 2.6}) 79 | self.assertEquals(3.2, sma.now('value')) 80 | 81 | self.assertTrue(math.isnan(sma.get_by_idx(0, 'value'))) 82 | self.assertTrue(math.isnan(sma.get_by_idx(1, 'value'))) 83 | self.assertEquals(2.4, sma.get_by_idx(2, 'value')) 84 | self.assertEquals(2.8, sma.get_by_idx(3, 'value')) 85 | self.assertEquals(3.2, sma.get_by_idx(4, 'value')) 86 | 87 | self.assertTrue(math.isnan(sma.get_by_time(t1, 'value'))) 88 | self.assertTrue(math.isnan(sma.get_by_time(t2, 'value'))) 89 | self.assertEquals(2.4, sma.get_by_time(t3, 'value')) 90 | self.assertEquals(2.8, sma.get_by_time(t4, 'value')) 91 | self.assertEquals(3.2, sma.get_by_time(t5, 'value')) 92 | -------------------------------------------------------------------------------- /tests/test_market_data_processor.py: 
-------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.model.model_factory import ModelFactory 4 | from algotrader.model.trade_data_pb2 import * 5 | from algotrader.provider.broker.sim.data_processor import BarProcessor, TradeProcessor, QuoteProcessor 6 | from algotrader.provider.broker.sim.sim_config import SimConfig 7 | 8 | 9 | class MarketDataProcessorTest(TestCase): 10 | def test_bar_processor(self): 11 | config = SimConfig() 12 | processor = BarProcessor() 13 | 14 | order = ModelFactory.build_new_order_request(timestamp=0, cl_id='test', cl_ord_id="1", inst_id="1", action=Buy, 15 | type=Limit, 16 | qty=1000, limit_price=18.5) 17 | bar = ModelFactory.build_bar(timestamp=0, inst_id="1", open=18, high=19, low=17, close=17.5, vol=1000) 18 | 19 | self.assertEqual(17.5, processor.get_price(order, bar, config)) 20 | self.assertEqual(1000, processor.get_qty(order, bar, config)) 21 | 22 | config2 = SimConfig(fill_on_bar_mode=SimConfig.FillMode.NEXT_OPEN) 23 | self.assertEqual(18, processor.get_price(order, bar, config2)) 24 | self.assertEqual(1000, processor.get_qty(order, bar, config2)) 25 | 26 | def test_trader_processor(self): 27 | config = SimConfig() 28 | processor = TradeProcessor() 29 | 30 | order = ModelFactory.build_new_order_request(timestamp=0, cl_id='test', cl_ord_id="1", inst_id="1", action=Buy, 31 | type=Limit, 32 | qty=1000, limit_price=18.5) 33 | trade = ModelFactory.build_trade(timestamp=0, inst_id="1", price=20, size=200) 34 | 35 | self.assertEqual(20, processor.get_price(order, trade, config)) 36 | self.assertEqual(200, processor.get_qty(order, trade, config)) 37 | 38 | def test_quote_processor(self): 39 | config = SimConfig() 40 | processor = QuoteProcessor() 41 | 42 | order = ModelFactory.build_new_order_request(timestamp=0, cl_id='test', cl_ord_id="1", inst_id="1", action=Buy, 43 | type=Limit, 44 | qty=1000, limit_price=18.5) 45 | quote = 
ModelFactory.build_quote(timestamp=0, inst_id="1", bid=18, ask=19, bid_size=200, ask_size=500) 46 | 47 | self.assertEqual(19, processor.get_price(order, quote, config)) 48 | self.assertEqual(500, processor.get_qty(order, quote, config)) 49 | 50 | order2 = ModelFactory.build_new_order_request(timestamp=0, cl_id='test', cl_ord_id="2", inst_id="1", 51 | action=Sell, type=Limit, 52 | qty=1000, 53 | limit_price=18.5) 54 | self.assertEqual(18, processor.get_price(order2, quote, config)) 55 | self.assertEqual(200, processor.get_qty(order2, quote, config)) 56 | -------------------------------------------------------------------------------- /tests/test_model_factory.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.utils.model import get_model_id 4 | from tests.sample_factory import * 5 | 6 | 7 | class ModelFactoryTest(TestCase): 8 | factory = SampleFactory() 9 | 10 | def setUp(self): 11 | self.factory = SampleFactory() 12 | 13 | def test_instrument(self): 14 | inst = ModelFactoryTest.factory.sample_instrument() 15 | self.assertEqual("2800.HK@SEHK", get_model_id(inst)) 16 | 17 | def test_exchange(self): 18 | exchange = self.factory.sample_exchange() 19 | self.assertEqual("SEHK", get_model_id(exchange)) 20 | 21 | def test_currency(self): 22 | currency = self.factory.sample_currency() 23 | self.assertEqual("HKD", get_model_id(currency)) 24 | 25 | def test_country(self): 26 | country = self.factory.sample_country() 27 | self.assertEqual("US", get_model_id(country)) 28 | 29 | def test_trading_holidays(self): 30 | self.assertEqual("HK holiday", get_model_id(self.factory.sample_trading_holidays())) 31 | 32 | def test_trading_hours(self): 33 | self.assertEqual("SEHK_trdinghr", get_model_id(self.factory.sample_trading_hours())) 34 | 35 | def test_timezone(self): 36 | self.assertEqual("Venezuela Standard Time", get_model_id(self.factory.sample_timezone())) 37 | 38 | def 
test_time_series(self): 39 | ts = self.factory.sample_time_series() 40 | self.assertEqual("HSI.BAR.86400", get_model_id(ts)) 41 | self.assertTrue(len(list(ts.inputs)) == 1) 42 | input = ts.inputs[0] 43 | self.assertEqual("HSI.BAR.1", input.source) 44 | self.assertEqual(['close', 'open'], list(input.keys)) 45 | 46 | 47 | def test_bar(self): 48 | self.assertEqual("Bar.HSI@SEHK.0.86400.IB.12312", get_model_id(self.factory.sample_bar())) 49 | 50 | def test_quote(self): 51 | self.assertEqual("Quote.HSI@SEHK.IB.12312", get_model_id(self.factory.sample_quote())) 52 | 53 | def test_trade(self): 54 | self.assertEqual("Trade.HSI@SEHK.IB.12312", get_model_id(self.factory.sample_trade())) 55 | 56 | def test_market_depth(self): 57 | self.assertEqual("MarketDepth.HSI@SEHK.IB.12312", get_model_id(self.factory.sample_market_depth())) 58 | 59 | def test_new_order_request(self): 60 | self.assertEqual("BuyLowSellHigh.1", get_model_id(self.factory.sample_new_order_request())) 61 | 62 | def test_order_replace_request(self): 63 | self.assertEqual("BuyLowSellHigh.1", get_model_id(self.factory.sample_order_replace_request())) 64 | 65 | def test_order_cancel_request(self): 66 | self.assertEqual("BuyLowSellHigh.1", get_model_id(self.factory.sample_order_cancel_request())) 67 | 68 | def test_order_status_update(self): 69 | self.assertEqual("IB.event_123", get_model_id(self.factory.sample_order_status_update())) 70 | 71 | def test_execution_report(self): 72 | self.assertEqual("IB.event_123", get_model_id(self.factory.sample_execution_report())) 73 | 74 | def test_account_update(self): 75 | self.assertEqual("IB.e_123", get_model_id(self.factory.sample_account_update())) 76 | 77 | def test_portfolio_update(self): 78 | self.assertEqual("IB.e_456", get_model_id(self.factory.sample_portfolio_update())) 79 | 80 | def test_account_state(self): 81 | self.assertEqual("test_acct", get_model_id(self.factory.sample_account_state())) 82 | 83 | def test_portfolio_state(self): 84 | 
self.assertEqual("test_portf", get_model_id(self.factory.sample_portfolio_state())) 85 | 86 | def test_strategy_state(self): 87 | self.assertEqual("BLSH", get_model_id(self.factory.sample_strategy_state())) 88 | 89 | def test_order_state(self): 90 | self.assertEqual("BuyLowSellHigh.1", get_model_id(self.factory.sample_order_state())) 91 | 92 | def test_sequence(self): 93 | self.assertEqual("test_seq", get_model_id(self.factory.sample_sequence())) 94 | -------------------------------------------------------------------------------- /tests/test_model_utils.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.utils.model import * 4 | 5 | 6 | class ModelUtilsTest(TestCase): 7 | 8 | def test_get_full_cls_name(self): 9 | from algotrader.technical.ma import SMA 10 | ma = SMA(inputs="null") 11 | self.assertEqual("algotrader.technical.ma.SMA", get_full_cls_name(ma)) 12 | 13 | self.assertEqual("algotrader.technical.ma.SMA", get_full_cls_name(SMA)) 14 | 15 | 16 | def test_dynamic_import(self): 17 | bb = get_cls("algotrader.technical.ma.SMA")(inputs="null") 18 | 19 | self.assertEqual("algotrader.technical.ma.SMA", get_full_cls_name(bb)) 20 | -------------------------------------------------------------------------------- /tests/test_persistence_strategy.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.app.backtest_runner import BacktestRunner 4 | from algotrader.trading.config import Config, load_from_yaml 5 | from algotrader.trading.context import ApplicationContext 6 | from tests import test_override 7 | 8 | 9 | class StrategyPersistenceTest(TestCase): 10 | start_date = 19930101 11 | intrim_date = 20080101 12 | end_date = 20170101 13 | 14 | stg_override = { 15 | "Strategy": { 16 | "down2%": { 17 | "qty": 1000 18 | } 19 | } 20 | } 21 | 22 | def create_app_context(self, override): 23 | return 
ApplicationContext(config=Config( 24 | load_from_yaml("../config/backtest.yaml"), 25 | load_from_yaml("../config/down2%.yaml"), 26 | test_override, 27 | StrategyPersistenceTest.stg_override, 28 | override)) 29 | 30 | def execute(self, conf): 31 | context = self.create_app_context(conf) 32 | runner = BacktestRunner() 33 | 34 | runner.start(context) 35 | 36 | begin_result = runner.initial_result['total_equity'] 37 | end_result = runner.portfolio.get_result()['total_equity'] 38 | return begin_result, end_result 39 | 40 | def test_result(self): 41 | total_begin_result, total_end_result = self.execute( 42 | conf={ 43 | "Application": { 44 | "portfolioId": "test", 45 | "fromDate": StrategyPersistenceTest.start_date, 46 | "toDate": StrategyPersistenceTest.end_date, 47 | "deleteDBAtStop": True, 48 | "persistenceMode": "Disable" 49 | } 50 | }) 51 | 52 | part1_begin_result, part1_end_result = self.execute( 53 | conf={ 54 | "Application": { 55 | "portfolioId": "test1", 56 | "fromDate": StrategyPersistenceTest.start_date, 57 | "toDate": StrategyPersistenceTest.intrim_date, 58 | "createDBAtStart": True, 59 | "deleteDBAtStop": False, 60 | "persistenceMode": "Batch" 61 | } 62 | }) 63 | 64 | part2_begin_result, part2_end_result = self.execute( 65 | conf={ 66 | "Application": { 67 | "portfolioId": "test1", 68 | "fromDate": StrategyPersistenceTest.intrim_date, 69 | "toDate": StrategyPersistenceTest.end_date, 70 | "deleteDBAtStop": True, 71 | "persistenceMode": "Disable" 72 | } 73 | }) 74 | 75 | print("total begin = %s" % total_begin_result) 76 | print("total end = %s" % total_end_result) 77 | print("part1 begin = %s" % part1_begin_result) 78 | print("part1 end = %s" % part1_end_result) 79 | print("part2 begin = %s" % part2_begin_result) 80 | print("part2 end = %s" % part2_end_result) 81 | 82 | self.assertEqual(total_begin_result, part1_begin_result) 83 | self.assertEqual(part1_end_result, part2_begin_result) 84 | self.assertEqual(total_end_result, part2_end_result) 85 | 
-------------------------------------------------------------------------------- /tests/test_plot.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.chart.plotter import TimeSeriesPlot 4 | from algotrader.model.model_factory import ModelFactory 5 | from algotrader.model.time_series_pb2 import * 6 | from algotrader.trading.data_series import DataSeries 7 | from algotrader.utils.date import datestr_to_unixtimemillis 8 | 9 | 10 | class PlotTest(TestCase): 11 | values = [44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 12 | 45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00] 13 | 14 | factory = ModelFactory() 15 | 16 | def __create_plot(self): 17 | series = DataSeries(time_series=TimeSeries()) 18 | t = 20170101 19 | for idx, value in enumerate(self.values): 20 | ts = datestr_to_unixtimemillis(t) 21 | series.add(timestamp=ts, data={"value": value}) 22 | t = t + 1 23 | values = series.get_series(["value"]) 24 | plot = TimeSeriesPlot(values) 25 | 26 | return plot 27 | 28 | def test_plot(self): 29 | plot = self.__create_plot() 30 | self.assertIsNotNone(plot) 31 | -------------------------------------------------------------------------------- /tests/test_rolling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from unittest import TestCase 3 | 4 | from algotrader.technical.rolling_apply import StdDev 5 | from algotrader.trading.context import ApplicationContext 6 | 7 | 8 | # from algotrader.trading.instrument_data import inst_data_mgr 9 | 10 | 11 | class RollingApplyTest(TestCase): 12 | def setUp(self): 13 | self.app_context = ApplicationContext() 14 | 15 | def test_name(self): 16 | bar = self.app_context.inst_data_mgr.get_series("bar") 17 | stddev = StdDev(inputs=bar, input_keys='close', length=3) 18 | self.assertEquals("StdDev(bar[close],length=3)", stddev.name) 19 | 20 | def test_empty_at_initialize(self): 
21 | close = self.app_context.inst_data_mgr.get_series("bar") 22 | stddev = StdDev(inputs=close, input_keys='close', length=3) 23 | self.assertEquals(0, len(stddev.get_data())) 24 | 25 | def test_nan_before_size(self): 26 | bar = self.app_context.inst_data_mgr.get_series("bar") 27 | bar.start(self.app_context) 28 | 29 | stddev = StdDev(inputs=bar, input_keys='close', length=3) 30 | stddev.start(self.app_context) 31 | 32 | t1 = 1 33 | 34 | nextTime = lambda t: t + 3 35 | 36 | x = np.random.normal(0, 2.0, 3) 37 | ts = np.cumsum(x) + 100 38 | 39 | i = 0 40 | 41 | bar.add(timestamp=t1, data={"close": ts[i], "open": 0}) 42 | self.assertEquals([{'value': np.nan}], 43 | stddev.get_data()) 44 | 45 | t2 = nextTime(t1) 46 | i = i + 1 47 | bar.add(timestamp=t2, data={"close": ts[i], "open": 1.4}) 48 | self.assertEquals([{'value': np.nan}, 49 | {'value': np.nan}], 50 | stddev.get_data()) 51 | 52 | t3 = nextTime(t2) 53 | i = i + 1 54 | bar.add(timestamp=t3, data={"close": ts[i], "open": 1.8}) 55 | self.assertEquals([{'value': np.nan}, 56 | {'value': np.nan}, 57 | {'value': np.std(ts)}], 58 | stddev.get_data()) 59 | 60 | # def test_moving_average_calculation(self): 61 | # inst_data_mgr.clear() 62 | # bar = inst_data_mgr.get_series("bar") 63 | # sma = SMA(bar, input_key='close', length=3) 64 | # 65 | # t1 = datetime.datetime.now() 66 | # t2 = t1 + datetime.timedelta(0, 3) 67 | # t3 = t2 + datetime.timedelta(0, 3) 68 | # t4 = t3 + datetime.timedelta(0, 3) 69 | # t5 = t4 + datetime.timedelta(0, 3) 70 | # 71 | # bar.add({"timestamp": t1, "close": 2.0, "open": 0}) 72 | # self.assertTrue(math.isnan(sma.now('value'))) 73 | # 74 | # bar.add({"timestamp": t2, "close": 2.4, "open": 1.4}) 75 | # self.assertTrue(math.isnan(sma.now('value'))) 76 | # 77 | # bar.add({"timestamp": t3, "close": 2.8, "open": 1.8}) 78 | # self.assertEquals(2.4, sma.now('value')) 79 | # 80 | # bar.add({"timestamp": t4, "close": 3.2, "open": 2.2}) 81 | # self.assertEquals(2.8, sma.now('value')) 82 | # 83 | # 
bar.add({"timestamp": t5, "close": 3.6, "open": 2.6}) 84 | # self.assertEquals(3.2, sma.now('value')) 85 | # 86 | # self.assertTrue(math.isnan(sma.get_by_idx(0, 'value'))) 87 | # self.assertTrue(math.isnan(sma.get_by_idx(1, 'value'))) 88 | # self.assertEquals(2.4, sma.get_by_idx(2, 'value')) 89 | # self.assertEquals(2.8, sma.get_by_idx(3, 'value')) 90 | # self.assertEquals(3.2, sma.get_by_idx(4, 'value')) 91 | # 92 | # self.assertTrue(math.isnan(sma.get_by_time(t1, 'value'))) 93 | # self.assertTrue(math.isnan(sma.get_by_time(t2, 'value'))) 94 | # self.assertEquals(2.4, sma.get_by_time(t3, 'value')) 95 | # self.assertEquals(2.8, sma.get_by_time(t4, 'value')) 96 | # self.assertEquals(3.2, sma.get_by_time(t5, 'value')) 97 | -------------------------------------------------------------------------------- /tests/test_ser_deser.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.model.market_data_pb2 import * 4 | from algotrader.model.ref_data_pb2 import * 5 | from algotrader.model.time_series_pb2 import * 6 | from algotrader.model.trade_data_pb2 import * 7 | from algotrader.utils.protobuf_to_dict import * 8 | from tests.sample_factory import SampleFactory 9 | 10 | 11 | class SerializationTest(TestCase): 12 | def setUp(self): 13 | self.factory = SampleFactory() 14 | 15 | def test_instrument(self): 16 | inst = self.factory.sample_instrument() 17 | self.__test_serializaion(Instrument, inst) 18 | 19 | def test_exchange(self): 20 | exchange = self.factory.sample_exchange() 21 | self.__test_serializaion(Exchange, exchange) 22 | 23 | def test_currency(self): 24 | currency = self.factory.sample_currency() 25 | self.__test_serializaion(Currency, currency) 26 | 27 | def test_country(self): 28 | country = self.factory.sample_country() 29 | self.__test_serializaion(Country, country) 30 | 31 | def test_trading_holidays(self): 32 | trading_holiday = self.factory.sample_trading_holidays() 33 | 
self.__test_serializaion(HolidaySeries, trading_holiday) 34 | 35 | def test_trading_hours(self): 36 | trading_hours = self.factory.sample_trading_hours() 37 | self.__test_serializaion(TradingHours, trading_hours) 38 | 39 | def test_timezone(self): 40 | timezone = self.factory.sample_timezone() 41 | self.__test_serializaion(TimeZone, timezone) 42 | 43 | def test_time_series(self): 44 | ds = self.factory.sample_time_series() 45 | self.__test_serializaion(TimeSeries, ds) 46 | 47 | def test_bar(self): 48 | self.__test_serializaion(Bar, self.factory.sample_bar()) 49 | 50 | def test_quote(self): 51 | self.__test_serializaion(Quote, self.factory.sample_quote()) 52 | 53 | def test_trade(self): 54 | self.__test_serializaion(Trade, self.factory.sample_trade()) 55 | 56 | def test_market_depth(self): 57 | self.__test_serializaion(MarketDepth, self.factory.sample_market_depth()) 58 | 59 | def test_new_order_request(self): 60 | self.__test_serializaion(NewOrderRequest, self.factory.sample_new_order_request()) 61 | 62 | def test_order_replace_request(self): 63 | self.__test_serializaion(OrderReplaceRequest, self.factory.sample_order_replace_request()) 64 | 65 | def test_order_cancel_request(self): 66 | self.__test_serializaion(OrderCancelRequest, self.factory.sample_order_cancel_request()) 67 | 68 | def test_order_status_update(self): 69 | self.__test_serializaion(OrderStatusUpdate, self.factory.sample_order_status_update()) 70 | 71 | def test_execution_report(self): 72 | self.__test_serializaion(ExecutionReport, self.factory.sample_execution_report()) 73 | 74 | def test_account_update(self): 75 | self.__test_serializaion(AccountUpdate, self.factory.sample_account_update()) 76 | 77 | def test_portfolio_update(self): 78 | self.__test_serializaion(PortfolioUpdate, self.factory.sample_portfolio_update()) 79 | 80 | def test_account_state(self): 81 | self.__test_serializaion(AccountState, self.factory.sample_account_state()) 82 | 83 | def test_portfolio_state(self): 84 | 
self.__test_serializaion(PortfolioState, self.factory.sample_portfolio_state()) 85 | 86 | def test_strategy_state(self): 87 | self.__test_serializaion(StrategyState, self.factory.sample_strategy_state()) 88 | 89 | def test_order_state(self): 90 | self.__test_serializaion(OrderState, self.factory.sample_order_state()) 91 | 92 | def test_sequence(self): 93 | self.__test_serializaion(Sequence, self.factory.sample_sequence()) 94 | 95 | def __test_serializaion(self, cls, obj): 96 | #print(obj) 97 | 98 | obj2 = cls() 99 | obj2.ParseFromString(obj.SerializeToString()) 100 | self.assertEqual(obj, obj2) 101 | 102 | obj3 = dict_to_protobuf(cls, protobuf_to_dict(obj)) 103 | self.assertEqual(obj, obj3) 104 | -------------------------------------------------------------------------------- /tests/test_suite.py: -------------------------------------------------------------------------------- 1 | # add comment here 2 | import unittest 3 | 4 | from tests.test_bar import BarTest 5 | from tests.test_bar_aggregator import BarAggregatorTest 6 | from tests.test_broker import SimulatorTest 7 | from tests.test_broker_mgr import BrokerManagerTest 8 | from tests.test_clock import ClockTest 9 | #from tests.test_cmp_functional_backtest import TestCompareWithFunctionalBacktest 10 | from tests.test_data_series import DataSeriesTest 11 | from tests.test_in_memory_db import InMemoryDBTest 12 | from tests.test_indicator import IndicatorTest 13 | from tests.test_instrument_data import InstrumentDataTest 14 | from tests.test_ma import MovingAverageTest 15 | from tests.test_market_data_processor import MarketDataProcessorTest 16 | from tests.test_model_factory import ModelFactoryTest 17 | from tests.test_order import OrderTest 18 | from tests.test_order_handler import OrderHandlerTest 19 | #from tests.test_pipeline import PipelineTest 20 | #from tests.test_pipeline_pairwise import PairwiseTest 21 | from tests.test_portfolio import PortfolioTest 22 | from tests.test_position import PositionTest 23 | 
from tests.test_ref_data import RefDataTest 24 | from tests.test_rolling import RollingApplyTest 25 | from tests.test_ser_deser import SerializationTest 26 | from tests.test_persistence_strategy import StrategyPersistenceTest 27 | from tests.test_persistence_indicator import IndicatorPersistenceTest 28 | from tests.test_talib_wrapper import TALibSMATest 29 | from tests.test_feed import FeedTest 30 | from tests.test_plot import PlotTest 31 | 32 | def suite(): 33 | test_suite = unittest.TestSuite() 34 | test_suite.addTest(unittest.makeSuite(BarTest)) 35 | test_suite.addTest(unittest.makeSuite(BarAggregatorTest)) 36 | test_suite.addTest(unittest.makeSuite(SimulatorTest)) 37 | test_suite.addTest(unittest.makeSuite(BrokerManagerTest)) 38 | test_suite.addTest(unittest.makeSuite(ClockTest)) 39 | test_suite.addTest(unittest.makeSuite(DataSeriesTest)) 40 | test_suite.addTest(unittest.makeSuite(FeedTest)) 41 | test_suite.addTest(unittest.makeSuite(IndicatorTest)) 42 | test_suite.addTest(unittest.makeSuite(InstrumentDataTest)) 43 | test_suite.addTest(unittest.makeSuite(MovingAverageTest)) 44 | test_suite.addTest(unittest.makeSuite(MarketDataProcessorTest)) 45 | test_suite.addTest(unittest.makeSuite(ModelFactoryTest)) 46 | test_suite.addTest(unittest.makeSuite(OrderTest)) 47 | test_suite.addTest(unittest.makeSuite(OrderHandlerTest)) 48 | #test_suite.addTest(unittest.makeSuite(TestCompareWithFunctionalBacktest)) 49 | test_suite.addTest(unittest.makeSuite(InMemoryDBTest)) 50 | #test_suite.addTest(unittest.makeSuite(PersistenceTest)) 51 | #test_suite.addTest(unittest.makeSuite(PipelineTest)) 52 | #test_suite.addTest(unittest.makeSuite(PairwiseTest)) 53 | test_suite.addTest(unittest.makeSuite(PlotTest)) 54 | test_suite.addTest(unittest.makeSuite(PortfolioTest)) 55 | test_suite.addTest(unittest.makeSuite(PositionTest)) 56 | <<<<<<< HEAD 57 | test_suite.addTest(unittest.makeSuite(SerializerTest)) 58 | test_suite.addTest(unittest.makeSuite(TALibSMATest)) 59 | 
#test_suite.addTest(unittest.makeSuite(TestCompareWithFunctionalBacktest)) 60 | test_suite.addTest(unittest.makeSuite(InMemoryDBTest)) 61 | #test_suite.addTest(unittest.makeSuite(PersistenceTest)) 62 | #test_suite.addTest(unittest.makeSuite(StrategyPersistenceTest)) 63 | test_suite.addTest(unittest.makeSuite(PipelineTest)) 64 | test_suite.addTest(unittest.makeSuite(PairwiseTest)) 65 | ======= 66 | test_suite.addTest(unittest.makeSuite(RefDataTest)) 67 | >>>>>>> cc21e5ebd346d2b2956bbf45f11daba52e4086b1 68 | test_suite.addTest(unittest.makeSuite(RollingApplyTest)) 69 | test_suite.addTest(unittest.makeSuite(SerializationTest)) 70 | test_suite.addTest(unittest.makeSuite(IndicatorPersistenceTest)) 71 | test_suite.addTest(unittest.makeSuite(StrategyPersistenceTest)) 72 | test_suite.addTest(unittest.makeSuite(TALibSMATest)) 73 | return test_suite 74 | 75 | 76 | mySuit = suite() 77 | 78 | runner = unittest.TextTestRunner() 79 | runner.run(mySuit) 80 | # creating a new test suite 81 | newSuite = unittest.TestSuite() 82 | -------------------------------------------------------------------------------- /todo.txt: -------------------------------------------------------------------------------- 1 | ### Dataframe / Series 2 | - to and from pandas 3 | - to and from Protobuf 4 | 5 | - persist to and from DB 6 | - serialize and deserialize via network 7 | 8 | 9 | 10 | - register to globalcontext 11 | 12 | 13 | - offline (batch) vs online (realtime update) mode 14 | - DAG (batch mode, pull from child, recursively up to the root) 15 | - parent 16 | store lasttimestamp 17 | return value since timestamp (can be empty array) 18 | 19 | - child 20 | store parentId 21 | store last update timestamp & last result (conflated into a dict) from each parent 22 | store last updatetimestamp 23 | batch calculate and store the result. 
24 | 25 | when computing / evaluating, check the last computed timestamp 26 | 27 | 28 | 29 | - subscribe (realtime update, push from parent) 30 | - rxpy 31 | - publish 32 | - when a subscriber receives an update, update the parent timestamp 33 | 34 | 35 | 36 | - multi-index 37 | 38 | 39 | 40 | DataFrame 41 | - holding multiple series 42 | 43 | 44 | EvaluationContext (global) 45 | def evaluate(): 46 | - for each input: 47 | input[self.last_ts:] -> timeseries 48 | put the ts into a dict 49 | 50 | if the dict of timeseries is empty: 51 | no further update. 52 | 53 | if the dict of timeseries is non-empty: 54 | merge into a DF 55 | iterate the DF (backfill for na?) 56 | call the update method 57 | 58 | actually.... 59 | convert to pandas in the global context 60 | just use pandas to calculate the value 61 | convert back to series 62 | 63 | Expression 64 | def evaluate(): 65 | create an EvaluationContext 66 | slice for each parent since the last calculation into a series, add to EvaluationContext 67 | EvaluationContext.evaluate() 68 | call the compute() 69 | 70 | 71 | def on_event(timestamp: long, input: str, data: Dict): 72 | 73 | 74 | def compute(timestamp: long, input_data: Dict[str, double]): 75 | 76 | 77 | def publish(timestamp, output_data:Dict[str, double]): 78 | 79 | 80 | 81 | - slice for each parent 82 | 83 | def update(input, ): 84 | --------------------------------------------------------------------------------