├── LICENSE.txt ├── __init__.py ├── cookbook ├── calculatePnl.py ├── calculateSpread.py ├── connectToTWS.py ├── cython │ ├── mean_c.pyx │ ├── mean_py.py │ ├── setup.py │ └── testSpeed.py ├── downloadVixFutures.py ├── getDataFromYahooFinance.py ├── guiqwt_CurveDialog.py ├── ib_logQuotes.py ├── ib_placeOrder.py ├── ib_streamQuotes.py ├── reconstructVXX │ ├── downloadVixFutures.py │ ├── reconstructVXX.py │ └── vix_futures.csv ├── runConsoleUntilInterrupt.py ├── scales.py └── workingWithDatesAndTime.py ├── createDistribution.py ├── dist ├── make.bat ├── setup.py └── tradingWithPython │ └── __init__.py ├── historicDataDownloader ├── historicDataDownloader.py ├── testData.py └── timeKeeper.py ├── lib ├── __init__.py ├── cboe.py ├── classes.py ├── csvDatabase.py ├── eventSystem.py ├── extra.py ├── functions.py ├── interactiveBrokers │ ├── __init__.py │ ├── extra.py │ ├── histData.py │ ├── logger.py │ └── tickLogger.py ├── interactivebrokers.py ├── logger.py ├── qtpandas.py ├── vixFutures.py ├── widgets.py └── yahooFinance.py ├── nautilus └── nautilus.py ├── sandbox ├── dataModels.py ├── guiWithDatabase.py ├── spreadCalculations.py └── spreadGroup.py └── spreadApp ├── gold_stocks.csv ├── makeDist.py └── spreadScanner.pyw /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2010-2013 Jev Kuznetsov 2 | All Rights Reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 
13 | 14 | The name of Jev Kuznetsov may not be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | POSSIBILITY OF SUCH DAMAGE. 28 | 29 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | __docformat__ = 'restructuredtext' from datetime import datetime import numpy as np from lib.classes import * from lib.functions import * from lib.csvDatabase import HistDataCsv -------------------------------------------------------------------------------- /cookbook/calculatePnl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Nov 21 19:03:09 2011 4 | 5 | @author: jev 6 | """ 7 | import numpy as np 8 | from pandas import * 9 | 10 | dataFile = 'D:\\Development\\tradingWithPython\\cookbook\\tritonDaily.csv' 11 | 12 | spy = DataFrame.from_csv(dataFile)['SPY'] 13 | 14 | pos = Series(np.zeros(spy.shape[0]),index=spy.index) 15 | pos[10:20] = 10 16 | 17 | #pos.plot() 18 | 19 | 20 | d = {'price':spy, 'pos':pos} 21 | df = DataFrame(d) 22 | 
df['port'] = df['price']*df['pos'] 23 | 24 | df.to_csv('pnl.csv',index_label='dates') 25 | # = DataFrame(d) 26 | 27 | # test data frame 28 | #idx = spy.index[:10] 29 | idx = DateRange('1/1/2000', periods=10) 30 | data = np.random.rand(10,2) 31 | df = DataFrame(data=data,index=idx,columns =['a','b']) 32 | df.to_csv('foo.csv') 33 | 34 | 35 | df = read_csv('foo.csv',index_col=0, parse_dates=True) 36 | df.to_csv('bar.csv') -------------------------------------------------------------------------------- /cookbook/calculateSpread.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 28 okt 2011 3 | 4 | @author: jev 5 | ''' 6 | 7 | from tradingWithPython import Symbol, estimateBeta, Spread 8 | from tradingWithPython.lib import yahooFinance 9 | from pandas import DataFrame 10 | import numpy as np 11 | 12 | 13 | startDate = (2010,1,1) 14 | # create two timeseries. data for SPY goes much further back 15 | # than data of VXX 16 | 17 | 18 | 19 | symbolX = Symbol('SPY') 20 | symbolY = Symbol('IWM') 21 | 22 | 23 | symbolX.downloadHistData(startDate) 24 | symbolY.downloadHistData(startDate) 25 | 26 | 27 | 28 | s = Spread(symbolX,symbolY) 29 | 30 | 31 | -------------------------------------------------------------------------------- /cookbook/connectToTWS.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from ib.ext.Contract import Contract 5 | from ib.ext.ExecutionFilter import ExecutionFilter 6 | from ib.opt import ibConnection, message 7 | from time import sleep 8 | 9 | # print all messages from TWS 10 | def watcher(msg): 11 | print '[watcher]',msg 12 | 13 | def dummyHandler(msg): 14 | pass 15 | 16 | # show Bid and Ask quotes 17 | def my_BidAsk(msg): 18 | print 'bid_ask' 19 | print msg 20 | 21 | if msg.field == 1: 22 | print '%s: bid: %s' % (contractTuple[0], msg.price) 23 | elif msg.field == 2: 24 | print '%s: ask: %s' % (contractTuple[0], msg.price) 25 | 26 | def my_BidAsk2(msg): 27 | print 'Handler 2' 28 | print msg 29 | 30 | 31 | def portfolioHandler(msg): 32 | print msg 33 | print msg.contract.m_symbol 34 | 35 | def makeStkContract(contractTuple): 36 | newContract = Contract() 37 | newContract.m_symbol = contractTuple[0] 38 | newContract.m_secType = contractTuple[1] 39 | newContract.m_exchange = contractTuple[2] 40 | newContract.m_currency = contractTuple[3] 41 | 42 | print 'Contract Values:%s,%s,%s,%s:' % contractTuple 43 | return newContract 44 | 45 | def testMarketData(): 46 | tickId = 1 47 | 48 | # Note: Option quotes will give an error if they aren't shown in TWS 49 | contractTuple = ('SPY', 'STK', 'SMART', 'USD') 50 | 51 | stkContract = makeStkContract(contractTuple) 52 | print '* * * * REQUESTING MARKET DATA * * * *' 53 | con.reqMktData(tickId, stkContract, '', False) 54 | sleep(3) 55 | print '* * * * CANCELING MARKET DATA * * * *' 56 | con.cancelMktData(tickId) 57 | 58 | def testExecutions(): 59 | print 'testing executions' 60 | f = ExecutionFilter() 61 | #f.m_clientId = 101 62 | f.m_time = '20110901-00:00:00' 63 | f.m_symbol = 'SPY' 64 | f.m_secType = 'STK' 65 | f.m_exchange = 'SMART' 66 | #f.m_side = 'BUY' 67 | 68 | con.reqExecutions(f) 69 | 70 | 71 | sleep(2) 72 | 73 | def testAccountUpdates(): 74 | con.reqAccountUpdates(True,'') 75 | 76 | def testHistoricData(con): 77 | print 'Testing 
historic data' 78 | 79 | contractTuple = ('SPY', 'STK', 'SMART', 'USD') 80 | contract = makeStkContract(contractTuple) 81 | 82 | con.reqHistoricalData(1,contract,'20120803 22:00:00','1800 S','1 secs','TRADES',1,2) 83 | sleep(2) 84 | 85 | 86 | def showMessageTypes(): 87 | # show available messages 88 | m = message.registry.keys() 89 | m.sort() 90 | print 'Available message types\n-------------------------' 91 | for msgType in m: 92 | print msgType 93 | 94 | if __name__ == '__main__': 95 | 96 | 97 | showMessageTypes() 98 | 99 | con = ibConnection() 100 | con.registerAll(watcher) # show all messages 101 | con.register(portfolioHandler,message.UpdatePortfolio) 102 | #con.register(watcher,(message.tickPrice,)) 103 | con.connect() 104 | 105 | testHistoricData(con) 106 | 107 | sleep(1) 108 | #testMarketData() 109 | #testExecutions() 110 | #testAccountUpdates() 111 | 112 | con.disconnect() 113 | sleep(2) 114 | print 'All done!' 115 | -------------------------------------------------------------------------------- /cookbook/cython/mean_c.pyx: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test file for mean calculation 4 | """ 5 | 6 | #cython: boundscheck=False 7 | #cython: wraparound=False 8 | 9 | from __future__ import division 10 | import numpy as np 11 | 12 | cimport numpy as np 13 | ctypedef np.float32_t dtype_t 14 | 15 | def mean(np.ndarray[dtype_t, ndim=2] data): 16 | 17 | 18 | cdef unsigned int row, col, i 19 | cdef dtype_t val 20 | 21 | cdef np.ndarray[dtype_t,ndim=1] s = np.zeros(data.shape[1], dtype=np.float32) 22 | 23 | for row in range(data.shape[0]): 24 | for col in range(data.shape[1]): 25 | s[col]+=data[row,col] 26 | 27 | for row in xrange(s.shape[0]): 28 | s[row] = s[row]/data.shape[0] 29 | 30 | return s 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /cookbook/cython/mean_py.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test file for mean calculation 4 | """ 5 | 6 | from __future__ import division 7 | import numpy as np 8 | 9 | def mean(data): 10 | 11 | s = np.zeros(data.shape[1]) 12 | 13 | for row in range(data.shape[0]): 14 | for col in range(data.shape[1]): 15 | s[col]+=data[row,col] 16 | 17 | for i in range(s.shape[0]): 18 | s[i] = s[i]/data.shape[0] 19 | 20 | 21 | return s 22 | 23 | 24 | -------------------------------------------------------------------------------- /cookbook/cython/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Distutils import build_ext 4 | 5 | import numpy 6 | 7 | ext = Extension("mean_c", ["mean_c.pyx"], 8 | include_dirs = [numpy.get_include()]) 9 | 10 | setup(ext_modules=[ext], 11 | cmdclass = {'build_ext': build_ext}) -------------------------------------------------------------------------------- /cookbook/cython/testSpeed.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 25 20:43:49 2012 4 | 5 | @author: jev 6 | """ 7 | 8 | 9 | import numpy as np 10 | import mean_py 11 | import mean_c 12 | import time 13 | 14 | a = np.random.rand(128,100000).astype(np.float32) 15 | 16 | start = time.clock() 17 | #m1 = a.sum(axis=1) 18 | #m1 = m1/a.shape[0] 19 | m1 = a.mean(axis=0) 20 | print 'Done in %.3f s' % (time.clock()-start) 21 | print m1.shape 22 | 23 | start = time.clock() 24 | m2 = mean_c.mean(a) 25 | print 'Done in %.3f s' % (time.clock()-start) 26 | print m2.shape 27 | #clf() 28 | #plot(m1,'b-x') 29 | #plot(m2,'r-o') 30 | #plot(a.mean(axis=0),'g') -------------------------------------------------------------------------------- /cookbook/downloadVixFutures.py: 
-------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Name: download CBOE futures 3 | # Purpose: get VIX futures data from CBOE and save to user directory 4 | # 5 | # 6 | # Created: 15-10-2011 7 | # Copyright: (c) Jev Kuznetsov 2011 8 | # Licence: GPL v2 9 | #------------------------------------------------------------------------------- 10 | #!/usr/bin/env python 11 | 12 | 13 | 14 | from urllib import urlretrieve 15 | import os 16 | 17 | m_codes = ['F','G','H','J','K','M','N','Q','U','V','X','Z'] #month codes of the futures 18 | dataDir = os.getenv("USERPROFILE")+'\\twpData\\vixFutures' # data directory 19 | 20 | def saveVixFutureData(year,month, path): 21 | ''' Get future from CBOE and save to file ''' 22 | fName = "CFE_{0}{1}_VX.csv".format(m_codes[month],str(year)[-2:]) 23 | urlStr = "http://cfe.cboe.com/Publish/ScheduledTask/MktData/datahouse/{0}".format(fName) 24 | 25 | try: 26 | urlretrieve(urlStr,path+'\\'+fName) 27 | except Exception as e: 28 | print e 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | 34 | if not os.path.exists(dataDir): 35 | os.makedirs(dataDir) 36 | 37 | for year in range(2004,2012): 38 | for month in range(12): 39 | print 'Getting data for {0}/{1}'.format(year,month) 40 | saveVixFutureData(year,month,dataDir) 41 | 42 | print 'Data was saved to {0}'.format(dataDir) -------------------------------------------------------------------------------- /cookbook/getDataFromYahooFinance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Oct 16 18:37:23 2011 4 | 5 | @author: jev 6 | """ 7 | 8 | 9 | from urllib import urlretrieve 10 | from urllib2 import urlopen 11 | from pandas import Index, DataFrame 12 | from datetime import datetime 13 | import matplotlib.pyplot as plt 14 | 15 | sDate = (2005,1,1) 16 | eDate = (2011,10,1) 17 | 18 | symbol 
= 'SPY' 19 | 20 | fName = symbol+'.csv' 21 | 22 | try: # try to load saved csv file, otherwise get from the net 23 | fid = open(fName) 24 | lines = fid.readlines() 25 | fid.close() 26 | print 'Loaded from ' , fName 27 | except Exception as e: 28 | print e 29 | urlStr = 'http://ichart.finance.yahoo.com/table.csv?s={0}&a={1}&b={2}&c={3}&d={4}&e={5}&f={6}'.\ 30 | format(symbol.upper(),sDate[1]-1,sDate[2],sDate[0],eDate[1]-1,eDate[2],eDate[0]) 31 | print 'Downloading from ', urlStr 32 | urlretrieve(urlStr,symbol+'.csv') 33 | lines = urlopen(urlStr).readlines() 34 | 35 | 36 | dates = [] 37 | data = [[] for i in range(6)] 38 | #high 39 | 40 | # header : Date,Open,High,Low,Close,Volume,Adj Close 41 | for line in lines[1:]: 42 | fields = line.rstrip().split(',') 43 | dates.append(datetime.strptime( fields[0],'%Y-%m-%d')) 44 | for i,field in enumerate(fields[1:]): 45 | data[i].append(float(field)) 46 | 47 | idx = Index(dates) 48 | data = dict(zip(['open','high','low','close','volume','adj_close'],data)) 49 | 50 | # create a pandas dataframe structure 51 | df = DataFrame(data,index=idx).sort() 52 | 53 | df.plot(secondary_y=['volume']) 54 | 55 | 56 | -------------------------------------------------------------------------------- /cookbook/guiqwt_CurveDialog.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Dec 14 19:47:02 2011 4 | 5 | @author: jev 6 | """ 7 | 8 | 9 | from PyQt4.QtGui import * 10 | from PyQt4.QtCore import * 11 | 12 | from guiqwt.plot import CurveDialog 13 | from guiqwt.builder import make 14 | import sys 15 | 16 | import numpy as np 17 | 18 | class MainForm(QDialog): 19 | def __init__(self,parent=None): 20 | super(MainForm,self).__init__(parent) 21 | self.resize(200,200) 22 | 23 | but = QPushButton() 24 | but.setText('Create plot') 25 | self.connect(but,SIGNAL('clicked()'),self.testFcn) 26 | 27 | lay = QVBoxLayout() 28 | lay.addWidget(but) 29 | self.setLayout(lay) 30 | 
31 | def testFcn(self): 32 | x = np.linspace(0, 100, 1000) 33 | 34 | y = (np.random.rand(len(x))-0.5).cumsum() 35 | 36 | curve = make.curve(x, y, "ab", "b") 37 | range = make.range(0, 5) 38 | 39 | disp2 = make.computations(range, "TL", 40 | [(curve, "min=%.5f", lambda x,y: y.min()), 41 | (curve, "max=%.5f", lambda x,y: y.max()), 42 | (curve, "avg=%.5f", lambda x,y: y.mean())]) 43 | legend = make.legend("TR") 44 | items = [ curve, range, disp2, legend] 45 | 46 | win = CurveDialog(edit=False, toolbar=True, parent=self) 47 | plot = win.get_plot() 48 | for item in items: 49 | plot.add_item(item) 50 | win.show() 51 | 52 | 53 | 54 | 55 | if __name__ == "__main__": 56 | app = QApplication(sys.argv) 57 | form = MainForm() 58 | form.show() 59 | app.exec_() 60 | -------------------------------------------------------------------------------- /cookbook/ib_logQuotes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on May 5, 2013 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | 6 | Program to log tick events to a file 7 | 8 | example usage: 9 | > python ib_logQuotes.py SPY,VXX,XLE 10 | 11 | start with -v option to show all incoming events 12 | 13 | 14 | ''' 15 | 16 | import argparse # command line argument parser 17 | import datetime as dt # date and time functions 18 | import time # time module for timestamping 19 | import os # used to create directories 20 | import sys # used to print a dot to a terminal without new line 21 | 22 | #--------ibpy imports ---------------------- 23 | from ib.ext.Contract import Contract 24 | from ib.opt import ibConnection, message 25 | 26 | 27 | # tick type definitions, see IB api manual 28 | priceTicks = {1:'bid',2:'ask',4:'last',6:'high',7:'low',9:'close', 14:'open'} 29 | sizeTicks = {0:'bid',3:'ask',5:'last',8:'volume'} 30 | 31 | class TickLogger(object): 32 | ''' class for handling incoming ticks and saving them to file 33 | will create a subdirectory 'tickLogs' if needed and start 
logging 34 | to a file with current timestamp in its name. 35 | All timestamps in the file are in seconds relative to start of logging 36 | 37 | ''' 38 | def __init__(self,tws, subscriptions): 39 | ''' init class, register handlers ''' 40 | 41 | tws.register(self._priceHandler,message.TickPrice) 42 | tws.register(self._sizeHandler,message.TickSize) 43 | 44 | self.subscriptions = subscriptions 45 | 46 | # save starting time of logging. All times will be in seconds relative 47 | # to this moment 48 | self._startTime = time.time() 49 | 50 | # create data directory if it does not exist 51 | if not os.path.exists('tickLogs'): os.mkdir('tickLogs') 52 | 53 | # open data file for writing 54 | fileName = 'tickLogs\\tickLog_%s.csv' % dt.datetime.now().strftime('%H_%M_%S') 55 | print 'Logging ticks to ' , fileName 56 | self.dataFile = open(fileName,'w') 57 | 58 | 59 | def _priceHandler(self,msg): 60 | ''' price tick handler ''' 61 | data = [self.subscriptions[msg.tickerId].m_symbol,'price',priceTicks[msg.field],msg.price] # data, second field is price tick type 62 | self._writeData(data) 63 | 64 | def _sizeHandler(self,msg): 65 | ''' size tick handler ''' 66 | data = [self.subscriptions[msg.tickerId].m_symbol,'size',sizeTicks[msg.field],msg.size] 67 | self._writeData(data) 68 | 69 | def _writeData(self,data): 70 | ''' write data to log file while adding a timestamp ''' 71 | timestamp = '%.3f' % (time.time()-self._startTime) # 1 ms resolution 72 | dataLine = ','.join(str(bit) for bit in [timestamp]+data) + '\n' 73 | self.dataFile.write(dataLine) 74 | 75 | def flush(self): 76 | ''' commits data to file''' 77 | self.dataFile.flush() 78 | 79 | def close(self): 80 | '''close file in a neat manner ''' 81 | print 'Closing data file' 82 | self.dataFile.close() 83 | 84 | 85 | def printMessage(msg): 86 | ''' function to print all incoming messages from TWS ''' 87 | print '[msg]:', msg 88 | 89 | 90 | def createContract(symbol): 91 | ''' create contract object ''' 92 | c = Contract() 93 
| c.m_symbol = symbol 94 | c.m_secType= "STK" 95 | c.m_exchange = "SMART" 96 | c.m_currency = "USD" 97 | 98 | return c 99 | 100 | #--------------main script------------------ 101 | 102 | if __name__ == '__main__': 103 | 104 | #-----------parse command line arguments 105 | parser = argparse.ArgumentParser(description='Log ticks for a set of stocks') 106 | 107 | 108 | parser.add_argument("symbols",help = 'symbols separated by coma: SPY,VXX') 109 | parser.add_argument("-v", "--verbose", help="show all incoming messages", 110 | action="store_true") 111 | 112 | args = parser.parse_args() 113 | 114 | symbols = args.symbols.strip().split(',') 115 | print 'Logging ticks for:',symbols 116 | 117 | 118 | #---create subscriptions dictionary. Keys are subscription ids 119 | subscriptions = {} 120 | for idx, symbol in enumerate(symbols): 121 | subscriptions[idx+1] = createContract(symbol) 122 | 123 | tws = ibConnection() 124 | logger = TickLogger(tws,subscriptions) 125 | 126 | # print all messages to the screen if verbose option is chosen 127 | if args.verbose: 128 | print 'Starting in verbose mode' 129 | tws.registerAll(printMessage) 130 | 131 | 132 | tws.connect() 133 | 134 | #-------subscribe to data 135 | for subId, c in subscriptions.iteritems(): 136 | tws.reqMktData(subId,c,"",False) 137 | 138 | #------start a loop that must be interrupted with Ctrl-C 139 | print 'Press Ctr-C to stop loop' 140 | 141 | try: 142 | while True: 143 | time.sleep(2) # wait a little 144 | logger.flush() # commit data to file 145 | sys.stdout.write('.') # print a dot to the screen 146 | 147 | 148 | except KeyboardInterrupt: 149 | print 'Interrupted with Ctrl-c' 150 | 151 | logger.close() 152 | tws.disconnect() 153 | print 'All done' 154 | -------------------------------------------------------------------------------- /cookbook/ib_placeOrder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Demonstrate order submission with 
ibpy 4 | """ 5 | 6 | from time import sleep 7 | 8 | from ib.ext.Contract import Contract 9 | from ib.opt import ibConnection 10 | from ib.ext.Order import Order 11 | 12 | def createContract(symbol): 13 | ''' create contract object ''' 14 | c = Contract() 15 | c.m_symbol = symbol 16 | c.m_secType= "STK" 17 | c.m_exchange = "SMART" 18 | c.m_currency = "USD" 19 | 20 | return c 21 | 22 | def createOrder(orderId,shares,limit = None, transmit=0): 23 | ''' 24 | create order object 25 | 26 | Parameters 27 | ----------- 28 | orderId : The order Id. You must specify a unique value. 29 | When the order status returns, it will be identified by this tag. 30 | This tag is also used when canceling the order. 31 | 32 | shares: number of shares to buy or sell. Negative for sell order. 33 | limit : price limit, None for MKT order 34 | transmit: transmit immideatelly from tws 35 | ''' 36 | 37 | action = {-1:'SELL',1:'BUY'} 38 | 39 | o = Order() 40 | 41 | o.m_orderId = orderId 42 | o.m_action = action[cmp(shares,0)] 43 | o.m_totalQuantity = abs(shares) 44 | o.m_transmit = transmit 45 | 46 | if limit is not None: 47 | o.m_orderType = 'LMT' 48 | o.m_lmtPrice = limit 49 | else: 50 | o.m_orderType = 'MKT' 51 | 52 | return o 53 | 54 | class MessageHandler(object): 55 | ''' class for handling incoming messages ''' 56 | 57 | def __init__(self,tws): 58 | ''' create class, provide ibConnection object as parameter ''' 59 | self.nextValidOrderId = None 60 | 61 | tws.registerAll(self.debugHandler) 62 | tws.register(self.nextValidIdHandler,'NextValidId') 63 | 64 | 65 | def nextValidIdHandler(self,msg): 66 | ''' handles NextValidId messages ''' 67 | self.nextValidOrderId = msg.orderId 68 | 69 | def debugHandler(self,msg): 70 | """ function to print messages """ 71 | print msg 72 | 73 | 74 | 75 | #-----------Main script----------------- 76 | 77 | tws = ibConnection() # create connection object 78 | handler = MessageHandler(tws) # message handling class 79 | 80 | tws.connect() # connect to API 81 | 
82 | sleep(1) # wait for nextOrderId to come in 83 | 84 | orderId = handler.nextValidOrderId # numeric order id, must be unique. 85 | print 'Placing order with id ', orderId 86 | 87 | contract = createContract('SPY') 88 | order = createOrder(orderId,shares=5, transmit=0) # create order 89 | 90 | 91 | tws.placeOrder(orderId, contract, order) # place order 92 | 93 | sleep(5) 94 | 95 | print 'Cancelling order ' 96 | tws.cancelOrder(orderId) # cancel it. 97 | 98 | print 'All done' 99 | 100 | tws.disconnect() -------------------------------------------------------------------------------- /cookbook/ib_streamQuotes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | 6 | Demonstration of how to stream quotes from IB. 7 | This script will subscribe to SPY and stream quotes to the sreen for 10 seconds. 8 | 9 | ''' 10 | 11 | from time import sleep 12 | from ib.ext.Contract import Contract 13 | from ib.opt import ibConnection, message 14 | 15 | 16 | 17 | def price_tick_handler(msg): 18 | """ function to handle price ticks """ 19 | print msg 20 | 21 | 22 | #--------------main script------------------ 23 | 24 | tws = ibConnection() # create connection object 25 | tws.register(price_tick_handler, message.TickPrice) # register handler 26 | tws.connect() # connect to API 27 | 28 | #-------create contract and subscribe to data 29 | c = Contract() 30 | c.m_symbol = "SPY" 31 | c.m_secType= "STK" 32 | c.m_exchange = "SMART" 33 | c.m_currency = "USD" 34 | 35 | tws.reqMktData(1,c,"",False) # request market data 36 | 37 | #-------print data for a couple of seconds, then close 38 | sleep(10) 39 | 40 | print 'All done' 41 | 42 | tws.disconnect() -------------------------------------------------------------------------------- /cookbook/reconstructVXX/downloadVixFutures.py: -------------------------------------------------------------------------------- 1 | 
#------------------------------------------------------------------------------- 2 | # Name: download CBOE futures 3 | # Purpose: get VIX futures data from CBOE, process data to a single file 4 | # 5 | # 6 | # Created: 15-10-2011 7 | # Copyright: (c) Jev Kuznetsov 2011 8 | # Licence: BSD 9 | #------------------------------------------------------------------------------- 10 | #!/usr/bin/env python 11 | 12 | 13 | 14 | from urllib import urlretrieve 15 | import os 16 | from pandas import * 17 | import datetime 18 | import numpy as np 19 | 20 | m_codes = ['F','G','H','J','K','M','N','Q','U','V','X','Z'] #month codes of the futures 21 | codes = dict(zip(m_codes,range(1,len(m_codes)+1))) 22 | 23 | #dataDir = os.path.dirname(__file__)+'/data' 24 | 25 | 26 | dataDir = os.path.expanduser('~')+'/twpData/vixFutures' 27 | print 'Data directory: ', dataDir 28 | 29 | 30 | 31 | def saveVixFutureData(year,month, path, forceDownload=False): 32 | ''' Get future from CBOE and save to file ''' 33 | fName = "CFE_{0}{1}_VX.csv".format(m_codes[month],str(year)[-2:]) 34 | if os.path.exists(path+'\\'+fName) or forceDownload: 35 | print 'File already downloaded, skipping' 36 | return 37 | 38 | urlStr = "http://cfe.cboe.com/Publish/ScheduledTask/MktData/datahouse/{0}".format(fName) 39 | print 'Getting: %s' % urlStr 40 | try: 41 | urlretrieve(urlStr,path+'\\'+fName) 42 | except Exception as e: 43 | print e 44 | 45 | def buildDataTable(dataDir): 46 | """ create single data sheet """ 47 | files = os.listdir(dataDir) 48 | 49 | data = {} 50 | for fName in files: 51 | print 'Processing: ', fName 52 | try: 53 | df = DataFrame.from_csv(dataDir+'/'+fName) 54 | 55 | 56 | code = fName.split('.')[0].split('_')[1] 57 | month = '%02d' % codes[code[0]] 58 | year = '20'+code[1:] 59 | newCode = year+'_'+month 60 | data[newCode] = df 61 | except Exception as e: 62 | print 'Could not process:', e 63 | 64 | 65 | full = DataFrame() 66 | for k,df in data.iteritems(): 67 | s = df['Settle'] 68 | s.name = k 69 | 
s[s<5] = np.nan 70 | if len(s.dropna())>0: 71 | full = full.join(s,how='outer') 72 | else: 73 | print s.name, ': Empty dataset.' 74 | 75 | full[full<5]=np.nan 76 | full = full[sorted(full.columns)] 77 | 78 | # use only data after this date 79 | startDate = datetime.datetime(2008,1,1) 80 | 81 | idx = full.index >= startDate 82 | full = full.ix[idx,:] 83 | 84 | #full.plot(ax=gca()) 85 | fName = os.path.expanduser('~')+'/twpData/vix_futures.csv' 86 | print 'Saving to ', fName 87 | full.to_csv(fName) 88 | 89 | 90 | if __name__ == '__main__': 91 | 92 | if not os.path.exists(dataDir): 93 | print 'creating data directory %s' % dataDir 94 | os.makedirs(dataDir) 95 | 96 | for year in range(2008,2013): 97 | for month in range(12): 98 | print 'Getting data for {0}/{1}'.format(year,month+1) 99 | saveVixFutureData(year,month,dataDir) 100 | 101 | print 'Raw wata was saved to {0}'.format(dataDir) 102 | 103 | buildDataTable(dataDir) -------------------------------------------------------------------------------- /cookbook/reconstructVXX/reconstructVXX.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Reconstructing VXX from futures data 4 | 5 | author: Jev Kuznetsov 6 | 7 | License : BSD 8 | """ 9 | from __future__ import division 10 | from pandas import * 11 | import numpy as np 12 | import os 13 | 14 | class Future(object): 15 | """ vix future class, used to keep data structures simple """ 16 | def __init__(self,series,code=None): 17 | """ code is optional, example '2010_01' """ 18 | self.series = series.dropna() # price data 19 | self.settleDate = self.series.index[-1] 20 | self.dt = len(self.series) # roll period (this is default, should be recalculated) 21 | self.code = code # string code 'YYYY_MM' 22 | 23 | def monthNr(self): 24 | """ get month nr from the future code """ 25 | return int(self.code.split('_')[1]) 26 | 27 | def dr(self,date): 28 | """ days remaining before settlement, on a given date 
""" 29 | return(sum(self.series.index>date)) 30 | 31 | 32 | def price(self,date): 33 | """ price on a date """ 34 | return self.series.get_value(date) 35 | 36 | 37 | def returns(df): 38 | """ daily return """ 39 | return (df/df.shift(1)-1) 40 | 41 | 42 | def recounstructVXX(): 43 | """ 44 | calculate VXX returns 45 | needs a previously preprocessed file vix_futures.csv 46 | """ 47 | dataDir = os.path.expanduser('~')+'/twpData' 48 | X = DataFrame.from_csv(dataDir+'/vix_futures.csv') # raw data table 49 | 50 | # build end dates list & futures classes 51 | futures = [] 52 | codes = X.columns 53 | endDates = [] 54 | for code in codes: 55 | f = Future(X[code],code=code) 56 | print code,':', f.settleDate 57 | endDates.append(f.settleDate) 58 | futures.append(f) 59 | 60 | endDates = np.array(endDates) 61 | 62 | # set roll period of each future 63 | for i in range(1,len(futures)): 64 | futures[i].dt = futures[i].dr(futures[i-1].settleDate) 65 | 66 | 67 | # Y is the result table 68 | idx = X.index 69 | Y = DataFrame(index=idx, columns=['first','second','days_left','w1','w2', 70 | 'ret','30days_avg']) 71 | 72 | # W is the weight matrix 73 | W = DataFrame(data = np.zeros(X.values.shape),index=idx,columns = X.columns) 74 | 75 | # for VXX calculation see http://www.ipathetn.com/static/pdf/vix-prospectus.pdf 76 | # page PS-20 77 | for date in idx: 78 | i =np.nonzero(endDates>=date)[0][0] # find first not exprired future 79 | first = futures[i] # first month futures class 80 | second = futures[i+1] # second month futures class 81 | 82 | dr = first.dr(date) # number of remaining dates in the first futures contract 83 | dt = first.dt #number of business days in roll period 84 | 85 | W.set_value(date,codes[i],100*dr/dt) 86 | W.set_value(date,codes[i+1],100*(dt-dr)/dt) 87 | 88 | # this is all just debug info 89 | p1 = first.price(date) 90 | p2 = second.price(date) 91 | w1 = 100*dr/dt 92 | w2 = 100*(dt-dr)/dt 93 | 94 | Y.set_value(date,'first',p1) 95 | Y.set_value(date,'second',p2) 96 
| Y.set_value(date,'days_left',first.dr(date)) 97 | Y.set_value(date,'w1',w1) 98 | Y.set_value(date,'w2',w2) 99 | 100 | 101 | Y.set_value(date,'30days_avg',(p1*w1+p2*w2)/100) 102 | 103 | valCurr = (X*W.shift(1)).sum(axis=1) # value on day N 104 | valYest = (X.shift(1)*W.shift(1)).sum(axis=1) # value on day N-1 105 | Y['ret'] = valCurr/valYest-1 # index return on day N 106 | 107 | return Y 108 | 109 | 110 | 111 | 112 | 113 | ##-------------------Main script--------------------------- 114 | if __name__=="__main__": 115 | Y = recounstructVXX() 116 | 117 | print Y.head(30)# 118 | Y.to_csv('reconstructedVXX.csv') 119 | 120 | -------------------------------------------------------------------------------- /cookbook/runConsoleUntilInterrupt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | example on how to run a console script until interrupted by keyboard 4 | 5 | @author: jev 6 | """ 7 | 8 | 9 | from time import sleep 10 | counter = 0 11 | 12 | print 'Press Ctr-C to stop loop' 13 | 14 | try: 15 | while True: 16 | print counter 17 | counter += 1 18 | sleep(1) 19 | 20 | except KeyboardInterrupt: 21 | print 'All done' 22 | 23 | -------------------------------------------------------------------------------- /cookbook/scales.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ############################################################################# 4 | ## 5 | ## This file was adapted from Taurus, a Tango User Interface Library 6 | ## 7 | ## http://www.tango-controls.org/static/taurus/latest/doc/html/index.html 8 | ## 9 | ## Copyright 2011 CELLS / ALBA Synchrotron, Bellaterra, Spain 10 | ## 11 | ## Taurus is free software: you can redistribute it and/or modify 12 | ## it under the terms of the GNU Lesser General Public License as published by 13 | ## the Free Software Foundation, either version 3 of the License, or 14 | ## (at your 
option) any later version. 15 | ## 16 | ## Taurus is distributed in the hope that it will be useful, 17 | ## but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | ## GNU Lesser General Public License for more details. 20 | ## 21 | ## You should have received a copy of the GNU Lesser General Public License 22 | ## along with Taurus. If not, see . 23 | ## 24 | ############################################################################# 25 | 26 | """ 27 | scales.py: Custom scales 28 | """ 29 | __all__=["DateTimeScaleEngine", "DeltaTimeScaleEngine", "FixedLabelsScaleEngine", 30 | "FancyScaleDraw", "TaurusTimeScaleDraw", "DeltaTimeScaleDraw", 31 | "FixedLabelsScaleDraw"] 32 | 33 | import numpy 34 | from datetime import datetime, timedelta 35 | from time import mktime 36 | from PyQt4 import Qt, Qwt5 37 | 38 | 39 | def _getDefaultAxisLabelsAlignment(axis, rotation): 40 | '''return a "smart" alignment for the axis labels depending on the axis 41 | and the label rotation 42 | 43 | :param axis: (Qwt5.QwtPlot.Axis) the axis 44 | :param rotation: (float) The rotation (in degrees, clockwise-positive) 45 | 46 | :return: (Qt.Alignment) an alignment 47 | ''' 48 | if axis == Qwt5.QwtPlot.xBottom: 49 | if rotation == 0 : return Qt.Qt.AlignHCenter|Qt.Qt.AlignBottom 50 | elif rotation < 0: return Qt.Qt.AlignLeft|Qt.Qt.AlignBottom 51 | else: return Qt.Qt.AlignRight|Qt.Qt.AlignBottom 52 | elif axis == Qwt5.QwtPlot.yLeft: 53 | if rotation == 0 : return Qt.Qt.AlignLeft|Qt.Qt.AlignVCenter 54 | elif rotation < 0: return Qt.Qt.AlignLeft|Qt.Qt.AlignBottom 55 | else: return Qt.Qt.AlignLeft|Qt.Qt.AlignTop 56 | elif axis == Qwt5.QwtPlot.yRight: 57 | if rotation == 0 : return Qt.Qt.AlignRight|Qt.Qt.AlignVCenter 58 | elif rotation < 0: return Qt.Qt.AlignRight|Qt.Qt.AlignTop 59 | else: return Qt.Qt.AlignRight|Qt.Qt.AlignBottom 60 | elif axis == Qwt5.QwtPlot.xTop: 61 | if rotation == 0 : return 
Qt.Qt.AlignHCenter|Qt.Qt.AlignTop 62 | elif rotation < 0: return Qt.Qt.AlignLeft|Qt.Qt.AlignTop 63 | else: return Qt.Qt.AlignRight|Qt.Qt.AlignTop 64 | 65 | class FancyScaleDraw(Qwt5.QwtScaleDraw): 66 | 67 | '''This is a scaleDraw with a tuneable palette and label formats''' 68 | def __init__(self, format = None, palette = None): 69 | Qwt5.QwtScaleDraw.__init__(self) 70 | self._labelFormat = format 71 | self._palette = palette 72 | 73 | def setPalette(self, palette): 74 | '''pass a QPalette or None to use default''' 75 | self._palette = palette 76 | 77 | def getPalette(self): 78 | return self._palette 79 | 80 | def setLabelFormat(self, format): 81 | '''pass a format string (e.g. "%g") or None to use default (it uses the locale)''' 82 | self._labelFormat = format 83 | self.invalidateCache() #to force repainting of the labels 84 | 85 | def getLabelFormat(self): 86 | '''pass a format string (e.g. "%g") or None to use default (it uses the locale)''' 87 | return self._labelFormat 88 | 89 | def label(self, val): 90 | if str(self._labelFormat) == "": return Qwt5.QwtText() 91 | if self._labelFormat is None: 92 | return Qwt5.QwtScaleDraw.label(self, val) 93 | else: 94 | return Qwt5.QwtText(self._labelFormat%val) 95 | 96 | def draw(self, painter, palette): 97 | if self._palette is None: 98 | Qwt5.QwtScaleDraw.draw(self, painter, palette) 99 | else: 100 | Qwt5.QwtScaleDraw.draw(self, painter, self._palette) 101 | 102 | 103 | class DateTimeScaleEngine(Qwt5.QwtLinearScaleEngine): 104 | def __init__(self, scaleDraw=None): 105 | Qwt5.QwtLinearScaleEngine.__init__(self) 106 | self.setScaleDraw(scaleDraw) 107 | 108 | def setScaleDraw(self, scaleDraw): 109 | self._scaleDraw = scaleDraw 110 | 111 | def scaleDraw(self): 112 | return self._scaleDraw 113 | 114 | def divideScale(self, x1, x2, maxMajSteps, maxMinSteps, stepSize): 115 | ''' Reimplements Qwt5.QwtLinearScaleEngine.divideScale 116 | 117 | **Important**: The stepSize parameter is **ignored**. 
118 | 119 | :return: (Qwt5.QwtScaleDiv) a scale division whose ticks are aligned with 120 | the natural time units ''' 121 | 122 | #if stepSize != 0: 123 | # scaleDiv = Qwt5.QwtLinearScaleEngine.divideScale(self, x1, x2, maxMajSteps, maxMinSteps, stepSize) 124 | # scaleDiv.datetimeLabelFormat = "%Y/%m/%d %H:%M%S.%f" 125 | # return scaleDiv 126 | 127 | interval = Qwt5.QwtDoubleInterval(x1, x2).normalized() 128 | if interval.width() <= 0: 129 | return Qwt5.QwtScaleDiv() 130 | 131 | dt1=datetime.fromtimestamp(interval.minValue()) 132 | dt2=datetime.fromtimestamp(interval.maxValue()) 133 | 134 | if dt1.year<1900 or dt2.year>9999 : #limits in time.mktime and datetime 135 | return Qwt5.QwtScaleDiv() 136 | 137 | majticks = [] 138 | medticks = [] 139 | minticks = [] 140 | 141 | dx=interval.width() 142 | 143 | if dx > 63072001: # = 3600s*24*(365+366) = 2 years (counting a leap year) 144 | format = "%Y" 145 | for y in range(dt1.year+1,dt2.year): 146 | dt = datetime(year=y, month=1, day=1) 147 | majticks.append(mktime(dt.timetuple())) 148 | 149 | elif dx > 5270400: # = 3600s*24*61 = 61 days 150 | format = "%Y %b" 151 | d = timedelta(days=31) 152 | dt = dt1.replace(day=1, hour=0, minute=0, second=0, microsecond=0)+d 153 | while(dt 172800: # 3600s24*2 = 2 days 159 | format = "%b/%d" 160 | d = timedelta(days=1) 161 | dt = dt1.replace(hour=0, minute=0, second=0, microsecond=0) + d 162 | while(dt 7200: # 3600s*2 = 2hours 167 | format = "%b/%d-%Hh" 168 | d = timedelta(hours=1) 169 | dt = dt1.replace(minute=0, second=0, microsecond=0) + d 170 | while(dt 1200: # 60s*20 =20 minutes 175 | format = "%H:%M" 176 | d = timedelta(minutes=10) 177 | dt = dt1.replace(minute=(dt1.minute//10)*10, second=0, microsecond=0) + d 178 | while(dt 120: # =60s*2 = 2 minutes 183 | format = "%H:%M" 184 | d = timedelta(minutes=1) 185 | dt = dt1.replace(second=0, microsecond=0) + d 186 | while(dt 20: # 20 s 191 | format = "%H:%M:%S" 192 | d = timedelta(seconds=10) 193 | dt = 
dt1.replace(second=(dt1.second//10)*10, microsecond=0) + d 194 | while(dt 2: # 2s 199 | format = "%H:%M:%S" 200 | majticks=range(int(x1)+1, int(x2)) 201 | 202 | else: #less than 2s (show microseconds) 203 | scaleDiv = Qwt5.QwtLinearScaleEngine.divideScale(self, x1, x2, maxMajSteps, maxMinSteps, stepSize) 204 | self.scaleDraw().setDatetimeLabelFormat("%S.%f") 205 | return scaleDiv 206 | 207 | #make sure to comply with maxMajTicks 208 | L= len(majticks) 209 | if L > maxMajSteps: 210 | majticks = majticks[::int(numpy.ceil(float(L)/maxMajSteps))] 211 | 212 | scaleDiv = Qwt5.QwtScaleDiv(interval, minticks, medticks, majticks) 213 | self.scaleDraw().setDatetimeLabelFormat(format) 214 | if x1>x2: 215 | scaleDiv.invert() 216 | 217 | ##START DEBUG 218 | #print "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 219 | #for tk in scaleDiv.ticks(scaleDiv.MajorTick): 220 | # print datetime.fromtimestamp(tk).isoformat() 221 | #print "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 222 | ##END DEBUG 223 | 224 | return scaleDiv 225 | 226 | @staticmethod 227 | def getDefaultAxisLabelsAlignment(axis, rotation): 228 | '''return a "smart" alignment for the axis labels depending on the axis 229 | and the label rotation 230 | 231 | :param axis: (Qwt5.QwtPlot.Axis) the axis 232 | :param rotation: (float) The rotation (in degrees, clockwise-positive) 233 | 234 | :return: (Qt.Alignment) an alignment 235 | ''' 236 | return _getDefaultAxisLabelsAlignment(axis, rotation) 237 | 238 | @staticmethod 239 | def enableInAxis(plot, axis, scaleDraw =None, rotation=None): 240 | '''convenience method that will enable this engine in the given 241 | axis. Note that it changes the ScaleDraw as well. 242 | 243 | :param plot: (Qwt5.QwtPlot) the plot to change 244 | :param axis: (Qwt5.QwtPlot.Axis) the id of the axis 245 | :param scaleDraw: (Qwt5.QwtScaleDraw) Scale draw to use. 
If None given, 246 | the current ScaleDraw for the plot will be used if 247 | possible, and a :class:`TaurusTimeScaleDraw` will be set if not 248 | :param rotation: (float or None) The rotation of the labels (in degrees, clockwise-positive) 249 | ''' 250 | if scaleDraw is None: 251 | scaleDraw = plot.axisScaleDraw(axis) 252 | if not isinstance(scaleDraw, TaurusTimeScaleDraw): 253 | scaleDraw = TaurusTimeScaleDraw() 254 | plot.setAxisScaleDraw(axis, scaleDraw) 255 | plot.setAxisScaleEngine(axis, DateTimeScaleEngine(scaleDraw)) 256 | if rotation is not None: 257 | alignment = DateTimeScaleEngine.getDefaultAxisLabelsAlignment(axis, rotation) 258 | plot.setAxisLabelRotation(axis, rotation) 259 | plot.setAxisLabelAlignment(axis, alignment) 260 | 261 | @staticmethod 262 | def disableInAxis(plot, axis, scaleDraw=None, scaleEngine=None): 263 | '''convenience method that will disable this engine in the given 264 | axis. Note that it changes the ScaleDraw as well. 265 | 266 | :param plot: (Qwt5.QwtPlot) the plot to change 267 | :param axis: (Qwt5.QwtPlot.Axis) the id of the axis 268 | :param scaleDraw: (Qwt5.QwtScaleDraw) Scale draw to use. If None given, 269 | a :class:`FancyScaleDraw` will be set 270 | :param scaleEngine: (Qwt5.QwtScaleEngine) Scale draw to use. 
If None given, 271 | a :class:`Qwt5.QwtLinearScaleEngine` will be set 272 | ''' 273 | if scaleDraw is None: 274 | scaleDraw=FancyScaleDraw() 275 | if scaleEngine is None: 276 | scaleEngine = Qwt5.QwtLinearScaleEngine() 277 | plot.setAxisScaleEngine(axis, scaleEngine) 278 | plot.setAxisScaleDraw(axis, scaleDraw) 279 | 280 | 281 | class TaurusTimeScaleDraw(FancyScaleDraw): 282 | 283 | def __init__(self, *args): 284 | FancyScaleDraw.__init__(self, *args) 285 | 286 | def setDatetimeLabelFormat(self, format): 287 | self._datetimeLabelFormat = format 288 | 289 | def datetimeLabelFormat(self): 290 | return self._datetimeLabelFormat 291 | 292 | def label(self, val): 293 | if str(self._labelFormat) == "": return Qwt5.QwtText() 294 | # From val to a string with time 295 | t = datetime.fromtimestamp(val) 296 | try: #If the scaleDiv was created by a DateTimeScaleEngine it has a _datetimeLabelFormat 297 | s = t.strftime(self._datetimeLabelFormat) 298 | except AttributeError, e: 299 | print "Warning: cannot get the datetime label format (Are you using a DateTimeScaleEngine?)" 300 | s = t.isoformat(' ') 301 | return Qwt5.QwtText(s) 302 | 303 | 304 | class DeltaTimeScaleEngine(Qwt5.QwtLinearScaleEngine): 305 | def __init__(self, scaleDraw=None): 306 | Qwt5.QwtLinearScaleEngine.__init__(self) 307 | self.setScaleDraw(scaleDraw) 308 | 309 | def setScaleDraw(self, scaleDraw): 310 | self._scaleDraw = scaleDraw 311 | 312 | def scaleDraw(self): 313 | return self._scaleDraw 314 | 315 | def divideScale(self, x1, x2, maxMajSteps, maxMinSteps, stepSize): 316 | ''' Reimplements Qwt5.QwtLinearScaleEngine.divideScale 317 | 318 | :return: (Qwt5.QwtScaleDiv) a scale division whose ticks are aligned with 319 | the natural delta time units ''' 320 | interval = Qwt5.QwtDoubleInterval(x1, x2).normalized() 321 | if interval.width() <= 0: 322 | return Qwt5.QwtScaleDiv() 323 | d_range = interval.width() 324 | if d_range < 2: # 2s 325 | return Qwt5.QwtLinearScaleEngine.divideScale(self, x1, x2, 
maxMajSteps, maxMinSteps, stepSize) 326 | elif d_range < 20: # 20 s 327 | s = 1 328 | elif d_range < 120: # =60s*2 = 2 minutes 329 | s = 10 330 | elif d_range < 1200: # 60s*20 =20 minutes 331 | s = 60 332 | elif d_range < 7200: # 3600s*2 = 2 hours 333 | s = 600 334 | elif d_range < 172800: # 3600s24*2 = 2 days 335 | s = 3600 336 | else: 337 | s = 86400 #1 day 338 | #calculate a step size that respects the base step (s) and also enforces the maxMajSteps 339 | stepSize = s * int(numpy.ceil(float(d_range//s)/maxMajSteps)) 340 | return Qwt5.QwtLinearScaleEngine.divideScale(self, x1, x2, maxMajSteps, maxMinSteps, stepSize) 341 | 342 | @staticmethod 343 | def getDefaultAxisLabelsAlignment(axis, rotation): 344 | '''return a "smart" alignment for the axis labels depending on the axis 345 | and the label rotation 346 | 347 | :param axis: (Qwt5.QwtPlot.Axis) the axis 348 | :param rotation: (float) The rotation (in degrees, clockwise-positive) 349 | 350 | :return: (Qt.Alignment) an alignment 351 | ''' 352 | return _getDefaultAxisLabelsAlignment(axis, rotation) 353 | 354 | @staticmethod 355 | def enableInAxis(plot, axis, scaleDraw =None, rotation=None): 356 | '''convenience method that will enable this engine in the given 357 | axis. Note that it changes the ScaleDraw as well. 358 | 359 | :param plot: (Qwt5.QwtPlot) the plot to change 360 | :param axis: (Qwt5.QwtPlot.Axis) the id of the axis 361 | :param scaleDraw: (Qwt5.QwtScaleDraw) Scale draw to use. 
If None given, 362 | the current ScaleDraw for the plot will be used if 363 | possible, and a :class:`TaurusTimeScaleDraw` will be set if not 364 | :param rotation: (float or None) The rotation of the labels (in degrees, clockwise-positive) 365 | ''' 366 | if scaleDraw is None: 367 | scaleDraw = plot.axisScaleDraw(axis) 368 | if not isinstance(scaleDraw, DeltaTimeScaleDraw): 369 | scaleDraw = DeltaTimeScaleDraw() 370 | plot.setAxisScaleDraw(axis, scaleDraw) 371 | plot.setAxisScaleEngine(axis, DeltaTimeScaleEngine(scaleDraw)) 372 | if rotation is not None: 373 | alignment = DeltaTimeScaleEngine.getDefaultAxisLabelsAlignment(axis, rotation) 374 | plot.setAxisLabelRotation(axis, rotation) 375 | plot.setAxisLabelAlignment(axis, alignment) 376 | 377 | @staticmethod 378 | def disableInAxis(plot, axis, scaleDraw=None, scaleEngine=None): 379 | '''convenience method that will disable this engine in the given 380 | axis. Note that it changes the ScaleDraw as well. 381 | 382 | :param plot: (Qwt5.QwtPlot) the plot to change 383 | :param axis: (Qwt5.QwtPlot.Axis) the id of the axis 384 | :param scaleDraw: (Qwt5.QwtScaleDraw) Scale draw to use. If None given, 385 | a :class:`FancyScaleDraw` will be set 386 | :param scaleEngine: (Qwt5.QwtScaleEngine) Scale draw to use. 
If None given, 387 | a :class:`Qwt5.QwtLinearScaleEngine` will be set 388 | ''' 389 | if scaleDraw is None: 390 | scaleDraw=FancyScaleDraw() 391 | if scaleEngine is None: 392 | scaleEngine = Qwt5.QwtLinearScaleEngine() 393 | plot.setAxisScaleEngine(axis, scaleEngine) 394 | plot.setAxisScaleDraw(axis, scaleDraw) 395 | 396 | 397 | class DeltaTimeScaleDraw(FancyScaleDraw): 398 | 399 | def __init__(self, *args): 400 | FancyScaleDraw.__init__(self, *args) 401 | 402 | def label(self, val): 403 | if val >= 0: 404 | s = "+%s"%str(timedelta(seconds=val)) 405 | else: 406 | s = "-%s"%str(timedelta(seconds=-val)) 407 | return Qwt5.QwtText(s) 408 | 409 | 410 | 411 | class FixedLabelsScaleEngine(Qwt5.QwtLinearScaleEngine): 412 | def __init__(self, positions): 413 | '''labels is a sequence of (pos,label) tuples where pos is the point 414 | at wich to draw the label and label is given as a python string (or QwtText)''' 415 | Qwt5.QwtScaleEngine.__init__(self) 416 | self._positions = positions 417 | #self.setAttribute(self.Floating,True) 418 | 419 | def divideScale(self, x1, x2, maxMajSteps, maxMinSteps, stepSize=0.0): 420 | div = Qwt5.QwtScaleDiv(x1, x2, self._positions, [], []) 421 | div.setTicks(Qwt5.QwtScaleDiv.MajorTick, self._positions) 422 | return div 423 | 424 | @staticmethod 425 | def enableInAxis(plot, axis, scaleDraw =None): 426 | '''convenience method that will enable this engine in the given 427 | axis. Note that it changes the ScaleDraw as well. 428 | 429 | :param plot: (Qwt5.QwtPlot) the plot to change 430 | :param axis: (Qwt5.QwtPlot.Axis) the id of the axis 431 | :param scaleDraw: (Qwt5.QwtScaleDraw) Scale draw to use. 
If None given, 432 | the current ScaleDraw for the plot will be used if 433 | possible, and a :class:`FixedLabelsScaleDraw` will be set if not 434 | ''' 435 | if scaleDraw is None: 436 | scaleDraw = plot.axisScaleDraw(axis) 437 | if not isinstance(scaleDraw, FixedLabelsScaleDraw): 438 | scaleDraw = FixedLabelsScaleDraw() 439 | plot.setAxisScaleDraw(axis, scaleDraw) 440 | plot.setAxisScaleEngine(axis, FixedLabelsScaleEngine(scaleDraw)) 441 | 442 | @staticmethod 443 | def disableInAxis(plot, axis, scaleDraw=None, scaleEngine=None): 444 | '''convenience method that will disable this engine in the given 445 | axis. Note that it changes the ScaleDraw as well. 446 | 447 | :param plot: (Qwt5.QwtPlot) the plot to change 448 | :param axis: (Qwt5.QwtPlot.Axis) the id of the axis 449 | :param scaleDraw: (Qwt5.QwtScaleDraw) Scale draw to use. If None given, 450 | a :class:`FancyScaleDraw` will be set 451 | :param scaleEngine: (Qwt5.QwtScaleEngine) Scale draw to use. If None given, 452 | a :class:`Qwt5.QwtLinearScaleEngine` will be set 453 | ''' 454 | if scaleDraw is None: 455 | scaleDraw=FancyScaleDraw() 456 | if scaleEngine is None: 457 | scaleEngine = Qwt5.QwtLinearScaleEngine() 458 | plot.setAxisScaleEngine(axis, scaleEngine) 459 | plot.setAxisScaleDraw(axis, scaleDraw) 460 | 461 | 462 | class FixedLabelsScaleDraw(FancyScaleDraw): 463 | def __init__(self, positions, labels): 464 | '''This is a custom ScaleDraw that shows labels at given positions (and nowhere else) 465 | positions is a sequence of points for which labels are defined. 
466 | labels is a sequence strings (or QwtText) 467 | Note that the lengths of positions and labels must match''' 468 | 469 | if len(positions) != len(labels): 470 | raise ValueError('lengths of positions and labels do not match') 471 | 472 | FancyScaleDraw.__init__(self) 473 | self._positions = positions 474 | self._labels = labels 475 | #self._positionsarray = numpy.array(self._positions) #this is stored just in case 476 | 477 | def label(self, val): 478 | try: 479 | index = self._positions.index(val) #try to find an exact match 480 | except: 481 | index = None #It won't show any label 482 | #use the index of the closest position 483 | #index = (numpy.abs(self._positionsarray - val)).argmin() 484 | if index is not None: 485 | return Qwt5.QwtText(self._labels[index]) 486 | else: Qwt5.QwtText() 487 | 488 | 489 | -------------------------------------------------------------------------------- /cookbook/workingWithDatesAndTime.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Oct 16 17:45:02 2011 4 | 5 | @author: jev 6 | """ 7 | 8 | import time 9 | import datetime as dt 10 | from pandas import * 11 | from pandas.core import datetools 12 | 13 | 14 | 15 | 16 | 17 | # basic functions 18 | print 'Epoch start: %s' % time.asctime(time.gmtime(0)) 19 | print 'Seconds from epoch: %.2f' % time.time() 20 | 21 | today = dt.date.today() 22 | print type(today) 23 | print 'Today is %s' % today.strftime('%Y.%m.%d') 24 | 25 | # parse datetime 26 | d = dt.datetime.strptime('20120803 21:59:59',"%Y%m%d %H:%M:%S") 27 | 28 | 29 | 30 | 31 | # time deltas 32 | someDate = dt.date(2011,8,1) 33 | delta = today - someDate 34 | print 'Delta :', delta 35 | 36 | # calculate difference in dates 37 | delta = dt.timedelta(days=20) 38 | print 'Today-delta=', today-delta 39 | 40 | 41 | t = dt.datetime(*time.strptime('3/30/2004',"%m/%d/%Y")[0:5]) 42 | # the '*' operator unpacks the tuple, producing the argument 
list. 43 | print t 44 | 45 | 46 | # print every 3d wednesday of the month 47 | for month in xrange(1,13): 48 | t = dt.date(2013,month,1)+datetools.relativedelta(months=1) 49 | 50 | 51 | offset = datetools.Week(weekday=4) 52 | if t.weekday()<>4: 53 | t_new = t+3*offset 54 | else: 55 | t_new = t+2*offset 56 | 57 | t_new = t_new-datetools.relativedelta(days=30) 58 | print t_new.strftime("%B, %d %Y (%A)") 59 | 60 | #rng = DateRange(t, t+datetools.YearEnd()) 61 | #print rng 62 | 63 | # create a range of times 64 | start = dt.datetime(2012,8,1)+datetools.relativedelta(hours=9,minutes=30) 65 | end = dt.datetime(2012,8,1)+datetools.relativedelta(hours=22) 66 | 67 | rng = date_range(start,end,freq='30min') 68 | for r in rng: print r.strftime("%Y%m%d %H:%M:%S") -------------------------------------------------------------------------------- /createDistribution.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os, shutil 5 | 6 | def copyFiles(sourceDir,targetDir,includes=None): 7 | 8 | if not os.path.exists(targetDir): os.makedirs(targetDir) 9 | 10 | 11 | for f in os.listdir(sourceDir): 12 | 13 | base,ext = os.path.splitext(f) 14 | 15 | if (ext =='.py') and (True if not includes else (base in includes)): 16 | s = os.path.join(sourceDir, f) 17 | d = os.path.join(targetDir, f) 18 | print s, '->', d 19 | shutil.copyfile(s,d) 20 | 21 | 22 | sourceDir = 'lib' 23 | targetDir = 'dist\\tradingWithPython\\lib' 24 | 25 | 26 | 27 | includes = ['__init__','cboe','csvDatabase','functions','yahooFinance','extra'] 28 | 29 | print '-----------lib files---------' 30 | copyFiles(sourceDir,targetDir,includes) 31 | print '-----------IB files----------' 32 | copyFiles(sourceDir+'\\interactiveBrokers',targetDir+'\\interactiveBrokers') 33 | 34 | -------------------------------------------------------------------------------- /dist/make.bat: 
-------------------------------------------------------------------------------- 1 | rd /s /q build 2 | python setup.py bdist_wininst 3 | python setup.py sdist -------------------------------------------------------------------------------- /dist/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | import tradingWithPython as twp 4 | 5 | 6 | setup(name = "tradingWithPython", 7 | version = twp.__version__, 8 | description = "A collection of functions and classes for Quantitative trading", 9 | author = "Jev Kuznetsov", 10 | author_email = "jev.kuznetsov@gmail.com", 11 | url = "http://www.tradingwithpython.com/", 12 | packages=["tradingWithPython","tradingWithPython\\lib","tradingWithPython\\lib\\interactiveBrokers"] 13 | ) -------------------------------------------------------------------------------- /dist/tradingWithPython/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | __version__ = '0.0.6' 3 | 4 | 5 | from lib.functions import pos2pnl, tradeBracket, estimateBeta, sharpe, drawdown, plotCorrelationMatrix -------------------------------------------------------------------------------- /historicDataDownloader/historicDataDownloader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 4 aug. 
2012 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | 6 | a module for downloading historic data from IB 7 | 8 | ''' 9 | import ib 10 | import pandas 11 | from ib.ext.Contract import Contract 12 | from ib.opt import ibConnection, message 13 | from time import sleep 14 | import tradingWithPython.lib.logger as logger 15 | from pandas import DataFrame, Index 16 | import datetime as dt 17 | from timeKeeper import TimeKeeper 18 | import time 19 | 20 | timeFormat = "%Y%m%d %H:%M:%S" 21 | 22 | class DataHandler(object): 23 | ''' handles incoming messages ''' 24 | def __init__(self,tws): 25 | self._log = logger.getLogger('DH') 26 | tws.register(self.msgHandler,message.HistoricalData) 27 | self.reset() 28 | 29 | def reset(self): 30 | self._log.debug('Resetting data') 31 | self.dataReady = False 32 | self._timestamp = [] 33 | self._data = {'open':[],'high':[],'low':[],'close':[],'volume':[],'count':[],'WAP':[]} 34 | 35 | def msgHandler(self,msg): 36 | #print '[msg]', msg 37 | 38 | if msg.date[:8] == 'finished': 39 | self._log.debug('Data recieved') 40 | self.dataReady = True 41 | return 42 | 43 | self._timestamp.append(dt.datetime.strptime(msg.date,timeFormat)) 44 | for k in self._data.keys(): 45 | self._data[k].append(getattr(msg, k)) 46 | 47 | @property 48 | def data(self): 49 | ''' return downloaded data as a DataFrame ''' 50 | df = DataFrame(data=self._data,index=Index(self._timestamp)) 51 | return df 52 | 53 | 54 | class Downloader(object): 55 | def __init__(self,debug=False): 56 | self._log = logger.getLogger('DLD') 57 | self._log.debug('Initializing data dwonloader. 
Pandas version={0}, ibpy version:{1}'.format(pandas.__version__,ib.version)) 58 | 59 | self.tws = ibConnection() 60 | self._dataHandler = DataHandler(self.tws) 61 | 62 | if debug: 63 | self.tws.registerAll(self._debugHandler) 64 | self.tws.unregister(self._debugHandler,message.HistoricalData) 65 | 66 | self._log.debug('Connecting to tws') 67 | self.tws.connect() 68 | 69 | self._timeKeeper = TimeKeeper() # keep track of past requests 70 | self._reqId = 1 # current request id 71 | 72 | 73 | def _debugHandler(self,msg): 74 | print '[debug]', msg 75 | 76 | 77 | def requestData(self,contract,endDateTime,durationStr='1800 S',barSizeSetting='1 secs',whatToShow='TRADES',useRTH=1,formatDate=1): 78 | self._log.debug('Requesting data for %s end time %s.' % (contract.m_symbol,endDateTime)) 79 | 80 | 81 | while self._timeKeeper.nrRequests(timeSpan=600) > 59: 82 | print 'Too many requests done. Waiting... ' 83 | time.sleep(1) 84 | 85 | self._timeKeeper.addRequest() 86 | self._dataHandler.reset() 87 | self.tws.reqHistoricalData(self._reqId,contract,endDateTime,durationStr,barSizeSetting,whatToShow,useRTH,formatDate) 88 | self._reqId+=1 89 | 90 | #wait for data 91 | startTime = time.time() 92 | timeout = 3 93 | while not self._dataHandler.dataReady and (time.time()-startTime < timeout): 94 | sleep(2) 95 | 96 | if not self._dataHandler.dataReady: 97 | self._log.error('Data timeout') 98 | 99 | print self._dataHandler.data 100 | 101 | return self._dataHandler.data 102 | 103 | def getIntradayData(self,contract, dateTuple ): 104 | ''' get full day data on 1-s interval 105 | date: a tuple of (yyyy,mm,dd) 106 | ''' 107 | 108 | openTime = dt.datetime(*dateTuple)+dt.timedelta(hours=16) 109 | closeTime = dt.datetime(*dateTuple)+dt.timedelta(hours=22) 110 | 111 | timeRange = pandas.date_range(openTime,closeTime,freq='30min') 112 | 113 | datasets = [] 114 | 115 | for t in timeRange: 116 | datasets.append(self.requestData(contract,t.strftime(timeFormat))) 117 | 118 | return 
pandas.concat(datasets) 119 | 120 | 121 | def disconnect(self): 122 | self.tws.disconnect() 123 | 124 | 125 | if __name__=='__main__': 126 | 127 | dl = Downloader(debug=True) 128 | 129 | c = Contract() 130 | c.m_symbol = 'SPY' 131 | c.m_secType = 'STK' 132 | c.m_exchange = 'SMART' 133 | c.m_currency = 'USD' 134 | df = dl.getIntradayData(c, (2012,8,6)) 135 | df.to_csv('test.csv') 136 | 137 | # df = dl.requestData(c, '20120803 22:00:00') 138 | # df.to_csv('test1.csv') 139 | # df = dl.requestData(c, '20120803 21:30:00') 140 | # df.to_csv('test2.csv') 141 | 142 | dl.disconnect() 143 | 144 | 145 | 146 | print 'Done.' -------------------------------------------------------------------------------- /historicDataDownloader/testData.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Aug 05 22:06:13 2012 4 | 5 | @author: jev 6 | """ 7 | import numpy as np 8 | from pandas import * 9 | from matplotlib.pyplot import * 10 | 11 | 12 | 13 | #df1 = DataFrame.from_csv('test1.csv').astype(np.dtype('f4')) 14 | #df2 = DataFrame.from_csv('test2.csv').astype(np.dtype('f4')) 15 | #df = DataFrame([df1,df2]) 16 | df = DataFrame.from_csv('test.csv').astype(np.dtype('f4')) 17 | 18 | close('all') 19 | clf() 20 | ax1=subplot(2,1,1) 21 | df[['high','low','WAP']].plot(grid=True,ax=gca()) 22 | 23 | 24 | subplot(2,1,2,sharex=ax1) 25 | df[['count','volume']].plot(ax=gca()) -------------------------------------------------------------------------------- /historicDataDownloader/timeKeeper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | used to check timing constraints 4 | 5 | @author: jev 6 | """ 7 | 8 | 9 | import os 10 | import datetime as dt 11 | import tradingWithPython.lib.logger as logger 12 | 13 | class TimeKeeper(object): 14 | def __init__(self): 15 | self._log = logger.getLogger('TK') 16 | dataDir = 
os.path.expanduser('~')+'/twpData' 17 | 18 | if not os.path.exists(dataDir): 19 | os.mkdir(dataDir) 20 | 21 | self._timeFormat = "%Y%m%d %H:%M:%S" 22 | self.dataFile = os.path.normpath(os.path.join(dataDir,'requests.txt')) 23 | self._log.debug('Data file: {0}'.format(self.dataFile)) 24 | 25 | def addRequest(self): 26 | ''' adds a timestamp of current request''' 27 | with open(self.dataFile,'a') as f: 28 | f.write(dt.datetime.now().strftime(self._timeFormat)+'\n') 29 | 30 | 31 | def nrRequests(self,timeSpan=600): 32 | ''' return number of requests in past timespan (s) ''' 33 | delta = dt.timedelta(seconds=timeSpan) 34 | now = dt.datetime.now() 35 | requests = 0 36 | 37 | with open(self.dataFile,'r') as f: 38 | lines = f.readlines() 39 | 40 | for line in lines: 41 | if now-dt.datetime.strptime(line.strip(),self._timeFormat) < delta: 42 | requests+=1 43 | 44 | if requests==0: # erase all contents if no requests are relevant 45 | open(self.dataFile,'w').close() 46 | 47 | self._log.debug('past requests: {0}'.format(requests)) 48 | return requests 49 | 50 | if __name__=='__main__': 51 | print 'testing timeKeeper' 52 | 53 | tk = TimeKeeper() 54 | tk.addRequest() 55 | print tk.nrRequests() -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmorgan/trading-with-python/c7d91a418030316cc9d06f318ade134fde016dee/lib/__init__.py -------------------------------------------------------------------------------- /lib/cboe.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | toolset working with cboe data 4 | 5 | @author: Jev Kuznetsov 6 | Licence: BSD 7 | """ 8 | import datetime 9 | from datetime import datetime, date 10 | import urllib2 11 | from pandas import DataFrame, Index, DateRange 12 | from pandas.core import datetools 13 | import numpy as np 
14 | 15 | 16 | def monthCode(month): 17 | """ 18 | perform month->code and back conversion 19 | 20 | Input: either month nr (int) or month code (str) 21 | Returns: code or month nr 22 | 23 | """ 24 | codes = ('F','G','H','J','K','M','N','Q','U','V','X','Z') 25 | 26 | if isinstance(month,int): 27 | return codes[month-1] 28 | elif isinstance(month,str): 29 | return codes.index(month)+1 30 | else: 31 | raise ValueError('Function accepts int or str') 32 | 33 | 34 | def vixExpiration(year,month): 35 | """ 36 | expriration date of a VX future 37 | """ 38 | t = datetime(year,month,1)+datetools.relativedelta(months=1) 39 | 40 | 41 | offset = datetools.Week(weekday=4) 42 | if t.weekday()<>4: 43 | t_new = t+3*offset 44 | else: 45 | t_new = t+2*offset 46 | 47 | t_exp = t_new-datetools.relativedelta(days=30) 48 | return t_exp 49 | 50 | def getPutCallRatio(): 51 | """ download current Put/Call ratio""" 52 | urlStr = 'http://www.cboe.com/publish/ScheduledTask/MktData/datahouse/totalpc.csv' 53 | 54 | try: 55 | lines = urllib2.urlopen(urlStr).readlines() 56 | except Exception, e: 57 | s = "Failed to download:\n{0}".format(e); 58 | print s 59 | 60 | headerLine = 2 61 | 62 | header = lines[headerLine].strip().split(',') 63 | 64 | data = [[] for i in range(len(header))] 65 | 66 | for line in lines[(headerLine+1):]: 67 | fields = line.rstrip().split(',') 68 | data[0].append(datetime.strptime(fields[0],'%m/%d/%Y')) 69 | for i,field in enumerate(fields[1:]): 70 | data[i+1].append(float(field)) 71 | 72 | 73 | return DataFrame(dict(zip(header[1:],data[1:])), index = Index(data[0])) 74 | 75 | 76 | def getHistoricData(symbol): 77 | ''' get historic data from CBOE 78 | symbol: VIX or VXV 79 | return dataframe 80 | ''' 81 | print 'Downloading %s' % symbol 82 | urls = {'VIX':'http://www.cboe.com/publish/ScheduledTask/MktData/datahouse/vixcurrent.csv', \ 83 | 'VXV':'http://www.cboe.com/publish/scheduledtask/mktdata/datahouse/vxvdailyprices.csv' } 84 | 85 | startLine = {'VIX':2,'VXV':3} 86 | 87 
| urlStr = urls[symbol] 88 | 89 | try: 90 | lines = urllib2.urlopen(urlStr).readlines() 91 | except Exception, e: 92 | s = "Failed to download:\n{0}".format(e); 93 | print s 94 | 95 | header = ['open','high','low','close'] 96 | dates = [] 97 | data = [[] for i in range(len(header))] 98 | 99 | 100 | for line in lines[startLine[symbol]:]: 101 | fields = line.rstrip().split(',') 102 | try: 103 | dates.append(datetime.strptime( fields[0],'%m/%d/%Y')) 104 | for i,field in enumerate(fields[1:]): 105 | data[i].append(float(field)) 106 | except ValueError as e: 107 | print 'Catched error:' , e 108 | print 'Line:', line 109 | 110 | 111 | 112 | 113 | return DataFrame(dict(zip(header,data)),index=Index(dates)).sort() 114 | 115 | 116 | #---------------------classes-------------------------------------------- 117 | class VixFuture(object): 118 | """ 119 | Class for easy handling of futures data. 120 | """ 121 | 122 | def __init__(self,year,month): 123 | self.year = year 124 | self.month = month 125 | 126 | def expirationDate(self): 127 | return vixExpiration(self.year,self.month) 128 | 129 | def daysLeft(self,date): 130 | """ business days to expiration date """ 131 | r = DateRange(date,self.expirationDate()) 132 | return len(r) 133 | 134 | def __repr__(self): 135 | return 'VX future [%i-%i %s] Exprires: %s' % (self.year,self.month,monthCode(self.month), 136 | self.expirationDate()) 137 | #-------------------test functions--------------------------------------- 138 | def testDownload(): 139 | vix = getHistoricData('VIX') 140 | vxv = getHistoricData('VXV') 141 | vix.plot() 142 | vxv.plot() 143 | 144 | def testExpiration(): 145 | for month in xrange(1,13): 146 | d = vixExpiration(2011,month) 147 | print d.strftime("%B, %d %Y (%A)") 148 | 149 | 150 | 151 | if __name__ == '__main__': 152 | 153 | #testExpiration() 154 | v = VixFuture(2011,11) 155 | print v 156 | 157 | print v.daysLeft(datetime(2011,11,10)) 158 | 159 | 
-------------------------------------------------------------------------------- /lib/classes.py: -------------------------------------------------------------------------------- 1 | """ 2 | worker classes 3 | 4 | @author: Jev Kuznetsov 5 | Licence: GPL v2 6 | """ 7 | 8 | __docformat__ = 'restructuredtext' 9 | 10 | import os 11 | import logger as logger 12 | from yahooFinance import getHistoricData 13 | from functions import estimateBeta, returns, rank 14 | from datetime import date 15 | from pandas import DataFrame, Series 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | 19 | class Symbol(object): 20 | ''' 21 | Symbol class, the foundation of Trading With Python library, 22 | This class acts as an interface to Yahoo data, Interactive Brokers etc 23 | ''' 24 | def __init__(self,name): 25 | self.name = name 26 | self.log = logger.getLogger(self.name) 27 | self.log.debug('class created.') 28 | 29 | self.dataDir = os.getenv("USERPROFILE")+'\\twpData\\symbols\\'+self.name 30 | self.log.debug('Data dir:'+self.dataDir) 31 | self.ohlc = None # historic OHLC data 32 | 33 | def downloadHistData(self, startDate=(2010,1,1),endDate=date.today().timetuple()[:3],\ 34 | source = 'yahoo'): 35 | ''' 36 | get historical OHLC data from a data source (yahoo is default) 37 | startDate and endDate are tuples in form (d,m,y) 38 | ''' 39 | self.log.debug('Getting OHLC data') 40 | self.ohlc = getHistoricData(self.name,startDate,endDate) 41 | 42 | 43 | def histData(self,column='adj_close'): 44 | ''' 45 | Return a column of historic data. 
46 | 47 | Returns 48 | ------------- 49 | df : DataFrame 50 | ''' 51 | s = self.ohlc[column] 52 | return DataFrame(s.values,s.index,[self.name]) 53 | 54 | @property 55 | def dayReturns(self): 56 | ''' close-close returns ''' 57 | return (self.ohlc['adj_close']/self.ohlc['adj_close'].shift(1)-1) 58 | #return DataFrame(s.values,s.index,[self.name]) 59 | 60 | class Portfolio(object): 61 | def __init__(self,histPrice,name=''): 62 | """ 63 | Constructor 64 | 65 | Parameters 66 | ---------- 67 | histPrice : historic price 68 | 69 | """ 70 | self.histPrice = histPrice 71 | self.params = DataFrame(index=self.symbols) 72 | self.params['capital'] = 100*np.ones(self.histPrice.shape[1],dtype=np.float) 73 | self.params['last'] = self.histPrice.tail(1).T.ix[:,0] 74 | self.params['shares'] = self.params['capital']/self.params['last'] 75 | self.name= name 76 | 77 | 78 | def setHistPrice(self,histPrice): 79 | self.histPrice = histPrice 80 | 81 | def setShares(self,shares): 82 | """ set number of shares, adjust capital 83 | shares: list, np array or Series 84 | """ 85 | 86 | if len(shares) != self.histPrice.shape[1]: 87 | raise AttributeError('Wrong size of shares vector.') 88 | self.params['shares'] = shares 89 | self.params['capital'] = self.params['shares']*self.params['last'] 90 | 91 | def setCapital(self,capital): 92 | """ Set target captial, adjust number of shares """ 93 | if len(capital) != self.histPrice.shape[1]: 94 | raise AttributeError('Wrong size of shares vector.') 95 | self.params['capital'] = capital 96 | self.params['shares'] = self.params['capital']/self.params['last'] 97 | 98 | 99 | def calculateStatistics(self,other=None): 100 | ''' calculate spread statistics, save internally ''' 101 | res = {} 102 | res['micro'] = rank(self.returns[-1],self.returns) 103 | res['macro'] = rank(self.value[-1], self.value) 104 | 105 | res['last'] = self.value[-1] 106 | 107 | if other is not None: 108 | res['corr'] = self.returns.corr(returns(other)) 109 | 110 | return 
Series(res,name=self.name) 111 | 112 | @property 113 | def symbols(self): 114 | return self.histPrice.columns.tolist() 115 | 116 | @property 117 | def returns(self): 118 | return (returns(self.histPrice)*self.params['capital']).sum(axis=1) 119 | 120 | @property 121 | def value(self): 122 | return (self.histPrice*self.params['shares']).sum(axis=1) 123 | 124 | def __repr__(self): 125 | return ("Portfolio %s \n" % self.name ) + str(self.params) 126 | #return ('Spread %s :' % self.name ) + str.join(',', 127 | # ['%s*%.2f' % t for t in zip(self.symbols,self.capital)]) 128 | 129 | 130 | 131 | 132 | 133 | class Spread(object): 134 | ''' 135 | Spread class, used to build a spread out of two symbols. 136 | ''' 137 | def __init__(self,symbols, bet = 100, histClose=None, beta = None): 138 | """ symbols : ['XYZ','SPY'] . first one is primary , second one is hedge """ 139 | self.symbols = symbols 140 | self.histClose = histClose 141 | if self.histClose is None: 142 | self._getYahooData() 143 | 144 | self.params = DataFrame(index=self.symbols) 145 | if beta is None: 146 | self.beta =self._estimateBeta() 147 | else: 148 | self.beta = beta 149 | 150 | 151 | self.params['capital'] = Series({symbols[0]:bet, symbols[1]:-bet/self.beta}) 152 | self.params['lastClose'] = self.histClose.tail(1).T.ix[:,0] 153 | self.params['last'] = self.params['lastClose'] 154 | self.params['shares'] = (self.params['capital']/self.params['last']) 155 | 156 | 157 | self._calculate() 158 | def _calculate(self): 159 | """ internal calculations """ 160 | self.params['change'] = (self.params['last']-self.params['lastClose'])*self.params['shares'] 161 | self.params['mktValue'] = self.params['shares']*self.params['last'] 162 | 163 | def setLast(self,last): 164 | """ set current price, perform internal recalculation """ 165 | self.params['last'] = last 166 | self._calculate() 167 | 168 | def setShares(self,shares): 169 | """ set target shares, adjust capital """ 170 | self.params['shares'] = shares 171 | 
self.params['capital'] = self.params['last']*self.params['shares'] 172 | 173 | def _getYahooData(self, startDate=(2007,1,1)): 174 | """ fetch historic data """ 175 | data = {} 176 | for symbol in self.symbols: 177 | print 'Downloading %s' % symbol 178 | data[symbol]=(getHistoricData(symbol,startDate)['adj_close'] ) 179 | 180 | self.histClose = DataFrame(data).dropna() 181 | 182 | 183 | def _estimateBeta(self): 184 | return estimateBeta(self.histClose[self.symbols[1]],self.histClose[self.symbols[0]]) 185 | 186 | def __repr__(self): 187 | 188 | header = '-'*10+self.name+'-'*10 189 | return header+'\n'+str(self.params)+'\n' 190 | 191 | @property 192 | def change(self): 193 | return (returns(self.histClose)*self.params['capital']).sum(axis=1) 194 | 195 | @property 196 | def value(self): 197 | """ historic market value of the spread """ 198 | return (self.histClose*self.params['shares']).sum(axis=1) 199 | 200 | @property 201 | def name(self): 202 | return str.join('_',self.symbols) 203 | 204 | def calculateStatistics(self): 205 | ''' calculate spread statistics ''' 206 | res = {} 207 | res['micro'] = rank(self.params['change'].sum(),self.change) 208 | res['macro'] = rank(self.params['mktValue'].sum(), self.value) 209 | res['last'] = self.params['mktValue'].sum() 210 | 211 | return Series(res,name=self.name) 212 | 213 | 214 | 215 | 216 | #-----------plotting functions------------------- 217 | def plot(self, figure=None): 218 | 219 | if figure is None: 220 | figure = plt.gcf() 221 | 222 | figure.clear() 223 | 224 | ax1 = plt.subplot(2,1,1) 225 | self.value.plot(ax=ax1, style = 'o-') 226 | p = self.params.T 227 | plt.title('Spread %.2f (\$ %.2f) %s vs %.2f (\$%.2f) %s ' %(p.ix['shares',0],p.ix['capital',0], p.columns[0], 228 | p.ix['shares',1],p.ix['capital',1],p.columns[1])) 229 | 230 | 231 | ax2 = plt.subplot(2,1,2,sharex = ax1) 232 | (self.change).plot(ax=ax2, style= 'o-') 233 | plt.title('daily change') 234 | plt.ylabel('$ change') 235 | 236 | # ax3 = 
plt.subplot(3,1,3,sharex = ax1) 237 | # self.histClose.plot(ax=ax3) 238 | # plt.title('Price movements') 239 | # plt.ylabel('$') 240 | 241 | 242 | if __name__=='__main__': 243 | 244 | 245 | s = Spread(['SPY','IWM']) 246 | 247 | -------------------------------------------------------------------------------- /lib/csvDatabase.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | intraday data handlers in csv format. 4 | 5 | @author: jev 6 | """ 7 | 8 | from __future__ import division 9 | 10 | from pandas import * 11 | import datetime as dt 12 | import os 13 | from extra import ProgressBar 14 | 15 | dateFormat = "%Y%m%d" # date format for converting filenames to dates 16 | dateTimeFormat = "%Y%m%d %H:%M:%S" 17 | 18 | def fileName2date(fName): 19 | '''convert filename to date''' 20 | name = os.path.splitext(fName)[0] 21 | return dt.datetime.strptime(name.split('_')[1],dateFormat).date() 22 | 23 | def parseDateTime(dateTimeStr): 24 | return dt.datetime.strptime(dateTimeStr,dateTimeFormat) 25 | 26 | def loadCsv(fName): 27 | ''' load DataFrame from csv file ''' 28 | with open(fName,'r') as f: 29 | lines = f.readlines() 30 | 31 | dates= [] 32 | header = [h.strip() for h in lines[0].strip().split(',')[1:]] 33 | data = [[] for i in range(len(header))] 34 | 35 | 36 | for line in lines[1:]: 37 | fields = line.rstrip().split(',') 38 | dates.append(parseDateTime(fields[0])) 39 | for i,field in enumerate(fields[1:]): 40 | data[i].append(float(field)) 41 | 42 | return DataFrame(data=dict(zip(header,data)),index=Index(dates)) 43 | 44 | 45 | class HistDataCsv(object): 46 | '''class for working with historic database in .csv format''' 47 | def __init__(self,symbol,dbDir): 48 | self.symbol = symbol 49 | self.dbDir = os.path.normpath(os.path.join(dbDir,symbol)) 50 | 51 | if not os.path.exists(self.dbDir): 52 | print 'Creating data directory ', self.dbDir 53 | os.mkdir(self.dbDir) 54 | 55 | self.dates = [] 56 | 57 | 
for fName in os.listdir(self.dbDir): 58 | self.dates.append(fileName2date(fName)) 59 | 60 | 61 | def saveData(self,date, df,lowerCaseColumns=True): 62 | ''' add data to database''' 63 | 64 | if lowerCaseColumns: # this should provide consistency to column names. All lowercase 65 | df.columns = [ c.lower() for c in df.columns] 66 | 67 | s = self.symbol+'_'+date.strftime(dateFormat)+'.csv' # file name 68 | dest = os.path.join(self.dbDir,s) # full path destination 69 | print 'Saving data to: ', dest 70 | df.to_csv(dest) 71 | 72 | def loadDate(self,date): 73 | ''' load data ''' 74 | s = self.symbol+'_'+date.strftime(dateFormat)+'.csv' # file name 75 | 76 | df = DataFrame.from_csv(os.path.join(self.dbDir,s)) 77 | cols = [col.strip() for col in df.columns.tolist()] 78 | df.columns = cols 79 | #df = loadCsv(os.path.join(self.dbDir,s)) 80 | 81 | return df 82 | 83 | def loadDates(self,dates): 84 | ''' load multiple dates, concantenating to one DataFrame ''' 85 | tmp =[] 86 | print 'Loading multiple dates for ' , self.symbol 87 | p = ProgressBar(len(dates)) 88 | 89 | for i,date in enumerate(dates): 90 | tmp.append(self.loadDate(date)) 91 | p.animate(i+1) 92 | 93 | print '' 94 | return concat(tmp) 95 | 96 | 97 | def createOHLC(self): 98 | ''' create ohlc from intraday data''' 99 | ohlc = DataFrame(index=self.dates, columns=['open','high','low','close']) 100 | 101 | for date in self.dates: 102 | 103 | print 'Processing', date 104 | try: 105 | df = self.loadDate(date) 106 | 107 | ohlc.set_value(date,'open',df['open'][0]) 108 | ohlc.set_value(date,'high',df['wap'].max()) 109 | ohlc.set_value(date,'low', df['wap'].min()) 110 | ohlc.set_value(date,'close',df['close'][-1]) 111 | 112 | except Exception as e: 113 | print 'Could not convert:', e 114 | 115 | return ohlc 116 | 117 | def __repr__(self): 118 | return '{symbol} dataset with {nrDates} days of data'.format(symbol=self.symbol, nrDates=len(self.dates)) 119 | 120 | class HistDatabase(object): 121 | ''' class working with 
multiple symbols at once ''' 122 | def __init__(self, dataDir): 123 | 124 | # get symbols from directory names 125 | symbols = [] 126 | for l in os.listdir(dataDir): 127 | if os.path.isdir(os.path.join(dataDir,l)): 128 | symbols.append(l) 129 | 130 | #build dataset 131 | self.csv = {} # dict of HistDataCsv halndlers 132 | 133 | for symbol in symbols: 134 | self.csv[symbol] = HistDataCsv(symbol,dataDir) 135 | 136 | 137 | def loadDates(self,dates=None): 138 | ''' 139 | get data for all symbols as wide panel 140 | provide a dates list. If no dates list is provided, common dates are used. 141 | ''' 142 | if dates is None: dates=self.commonDates 143 | 144 | tmp = {} 145 | 146 | 147 | for k,v in self.csv.iteritems(): 148 | tmp[k] = v.loadDates(dates) 149 | 150 | return WidePanel(tmp) 151 | 152 | def toHDF(self,dataFile,dates=None): 153 | ''' write wide panel data to a hdfstore file ''' 154 | 155 | if dates is None: dates=self.commonDates 156 | store = HDFStore(dataFile) 157 | wp = self.loadDates(dates) 158 | 159 | store['data'] = wp 160 | store.close() 161 | 162 | 163 | 164 | 165 | 166 | @property 167 | def commonDates(self): 168 | ''' return dates common for all symbols ''' 169 | t = [v.dates for v in self.csv.itervalues()] # get all dates in a list 170 | 171 | d = list(set(t[0]).intersection(*t[1:])) 172 | return sorted(d) 173 | 174 | 175 | def __repr__(self): 176 | s = '-----Hist CSV Database-----\n' 177 | for k,v in self.csv.iteritems(): 178 | s+= (str(v)+'\n') 179 | return s 180 | 181 | 182 | #-------------------- 183 | 184 | if __name__=='__main__': 185 | 186 | dbDir =os.path.normpath('D:/data/30sec') 187 | vxx = HistDataCsv('VXX',dbDir) 188 | spy = HistDataCsv('SPY',dbDir) 189 | # 190 | date = dt.date(2012,8,31) 191 | print date 192 | # 193 | pair = DataFrame({'SPY':spy.loadDate(date)['close'],'VXX':vxx.loadDate(date)['close']}) 194 | 195 | print pair.tail() -------------------------------------------------------------------------------- /lib/eventSystem.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Created on 26 dec. 2011 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | 6 | sender-reciever pattern. 7 | 8 | ''' 9 | 10 | import logger as logger 11 | import types 12 | 13 | class Sender(object): 14 | """ 15 | Sender -> dispatches messages to interested callables 16 | """ 17 | def __init__(self): 18 | self.listeners = {} 19 | self.logger = logger.getLogger() 20 | 21 | 22 | def register(self,listener,events=None): 23 | """ 24 | register a listener function 25 | 26 | Parameters 27 | ----------- 28 | listener : external listener function 29 | events : tuple or list of relevant events (default=None) 30 | """ 31 | if events is not None and type(events) not in (types.TupleType,types.ListType): 32 | events = (events,) 33 | 34 | self.listeners[listener] = events 35 | 36 | def dispatch(self,event=None, msg=None): 37 | """notify listeners """ 38 | for listener,events in self.listeners.items(): 39 | if events is None or event is None or event in events: 40 | try: 41 | listener(self,event,msg) 42 | except (Exception,): 43 | self.unregister(listener) 44 | errmsg = "Exception in message dispatch: Handler '{0}' unregistered for event '{1}' ".format(listener.func_name,event) 45 | self.logger.exception(errmsg) 46 | 47 | def unregister(self,listener): 48 | """ unregister listener function """ 49 | del self.listeners[listener] 50 | 51 | #---------------test functions-------------- 52 | 53 | class ExampleListener(object): 54 | def __init__(self,name=None): 55 | self.name = name 56 | 57 | def method(self,sender,event,msg=None): 58 | print "[{0}] got event {1} with message {2}".format(self.name,event,msg) 59 | 60 | 61 | if __name__=="__main__": 62 | print 'demonstrating event system' 63 | 64 | 65 | alice = Sender() 66 | bob = ExampleListener('bob') 67 | charlie = ExampleListener('charlie') 68 | dave = ExampleListener('dave') 69 | 70 | 71 | # add subscribers to messages from alice 72 | 
alice.register(bob.method,events='event1') # listen to 'event1' 73 | alice.register(charlie.method,events ='event2') # listen to 'event2' 74 | alice.register(dave.method) # listen to all events 75 | 76 | # dispatch some events 77 | alice.dispatch(event='event1') 78 | alice.dispatch(event='event2',msg=[1,2,3]) 79 | alice.dispatch(msg='attention to all') 80 | 81 | print 'Done.' 82 | 83 | 84 | -------------------------------------------------------------------------------- /lib/extra.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Apr 28, 2013 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | ''' 6 | from __future__ import print_function 7 | import sys 8 | 9 | class ProgressBar: 10 | def __init__(self, iterations): 11 | self.iterations = iterations 12 | self.prog_bar = '[]' 13 | self.fill_char = '*' 14 | self.width = 50 15 | self.__update_amount(0) 16 | 17 | def animate(self, iteration): 18 | print('\r', self, end='') 19 | sys.stdout.flush() 20 | self.update_iteration(iteration + 1) 21 | 22 | def update_iteration(self, elapsed_iter): 23 | self.__update_amount((elapsed_iter / float(self.iterations)) * 100.0) 24 | self.prog_bar += ' %d of %s complete' % (elapsed_iter, self.iterations) 25 | 26 | def __update_amount(self, new_amount): 27 | percent_done = int(round((new_amount / 100.0) * 100.0)) 28 | all_full = self.width - 2 29 | num_hashes = int(round((percent_done / 100.0) * all_full)) 30 | self.prog_bar = '[' + self.fill_char * num_hashes + ' ' * (all_full - num_hashes) + ']' 31 | pct_place = (len(self.prog_bar) // 2) - len(str(percent_done)) 32 | pct_string = '%d%%' % percent_done 33 | self.prog_bar = self.prog_bar[0:pct_place] + \ 34 | (pct_string + self.prog_bar[pct_place + len(pct_string):]) 35 | 36 | def __str__(self): 37 | return str(self.prog_bar) -------------------------------------------------------------------------------- /lib/functions.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | twp support functions 4 | 5 | @author: Jev Kuznetsov 6 | Licence: GPL v2 7 | """ 8 | 9 | from scipy import polyfit, polyval 10 | import datetime as dt 11 | #from datetime import datetime, date 12 | from pandas import DataFrame, Index, Series 13 | import csv 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | 17 | def plotCorrelationMatrix(price, thresh = None): 18 | ''' plot a correlation matrix as a heatmap image 19 | inputs: 20 | price: prices DataFrame 21 | thresh: correlation threshold to use for checking, default None 22 | 23 | ''' 24 | symbols = price.columns.tolist() 25 | R = price.pct_change() 26 | 27 | 28 | correlationMatrix = R.corr() 29 | 30 | if thresh is not None: 31 | correlationMatrix = correlationMatrix > thresh 32 | 33 | plt.imshow(abs(correlationMatrix.values),interpolation='none') 34 | plt.xticks(range(len(symbols)),symbols) 35 | plt.yticks(range(len(symbols)),symbols) 36 | plt.colorbar() 37 | plt.title('Correlation matrix') 38 | 39 | return correlationMatrix 40 | 41 | 42 | def pca(A): 43 | """ performs principal components analysis 44 | (PCA) on the n-by-p DataFrame A 45 | Rows of A correspond to observations, columns to variables. 
46 | 47 | Returns : 48 | coeff : principal components, column-wise 49 | transform: A in principal component space 50 | latent : eigenvalues 51 | 52 | """ 53 | # computing eigenvalues and eigenvectors of covariance matrix 54 | M = (A - A.mean()).T # subtract the mean (along columns) 55 | [latent,coeff] = np.linalg.eig(np.cov(M)) # attention:not always sorted 56 | 57 | idx = np.argsort(latent) # sort eigenvalues 58 | idx = idx[::-1] # in ascending order 59 | 60 | coeff = coeff[:,idx] 61 | latent = latent[idx] 62 | 63 | score = np.dot(coeff.T,A.T) # projection of the data in the new space 64 | 65 | transform = DataFrame(index = A.index, data = score.T) 66 | 67 | return coeff,transform,latent 68 | 69 | 70 | 71 | def pos2pnl(price,position , ibTransactionCost=False ): 72 | """ 73 | calculate pnl based on price and position 74 | Inputs: 75 | --------- 76 | price: series or dataframe of price 77 | position: number of shares at each time. Column names must be same as in price 78 | ibTransactionCost: use bundled Interactive Brokers transaction cost of 0.005$/share 79 | 80 | Returns a portfolio DataFrame 81 | """ 82 | 83 | delta=position.diff() 84 | port = DataFrame(index=price.index) 85 | 86 | if isinstance(price,Series): # no need to sum along 1 for series 87 | port['cash'] = (-delta*price).cumsum() 88 | port['stock'] = (position*price) 89 | 90 | else: # dealing with DataFrame here 91 | port['cash'] = (-delta*price).sum(axis=1).cumsum() 92 | port['stock'] = (position*price).sum(axis=1) 93 | 94 | 95 | 96 | if ibTransactionCost: 97 | tc = -0.005*position.diff().abs() # basic transaction cost 98 | tc[(tc>-1) & (tc<0)] = -1 # everything under 1$ will be ceil'd to 1$ 99 | tc = tc.sum(axis=1) 100 | port['tc'] = tc.cumsum() 101 | else: 102 | port['tc'] = 0. 
103 | 104 | port['total'] = port['stock']+port['cash']+port['tc'] 105 | 106 | 107 | 108 | return port 109 | 110 | def tradeBracket(price,entryBar,maxTradeLength,bracket): 111 | ''' 112 | trade a symmetrical bracket on price series, return price delta and exit bar # 113 | Input 114 | ------ 115 | price : series of price values 116 | entryBar: entry bar number 117 | maxTradeLength : max trade duration in bars 118 | bracket : allowed price deviation 119 | 120 | 121 | ''' 122 | 123 | lastBar = min(entryBar+maxTradeLength,len(price)-1) 124 | p = price[entryBar:lastBar]-price[entryBar] 125 | 126 | idxOutOfBound = np.nonzero(abs(p)>bracket) # find indices where price comes out of bracket 127 | if idxOutOfBound[0].any(): # found match 128 | priceDelta = p[idxOutOfBound[0][0]] 129 | exitBar = idxOutOfBound[0][0]+entryBar 130 | else: # all in bracket, exiting based on time 131 | priceDelta = p[-1] 132 | exitBar = lastBar 133 | 134 | return priceDelta, exitBar 135 | 136 | 137 | def estimateBeta(priceY,priceX,algo = 'standard'): 138 | ''' 139 | estimate stock Y vs stock X beta using iterative linear 140 | regression. 
Outliers outside 3 sigma boundary are filtered out 141 | 142 | Parameters 143 | -------- 144 | priceX : price series of x (usually market) 145 | priceY : price series of y (estimate beta of this price) 146 | 147 | Returns 148 | -------- 149 | beta : stockY beta relative to stock X 150 | ''' 151 | 152 | X = DataFrame({'x':priceX,'y':priceY}) 153 | 154 | if algo=='returns': 155 | ret = (X/X.shift(1)-1).dropna().values 156 | 157 | #print len(ret) 158 | 159 | x = ret[:,0] 160 | y = ret[:,1] 161 | 162 | iteration = 1 163 | nrOutliers = 1 164 | while iteration < 10 and nrOutliers > 0 : 165 | (a,b) = polyfit(x,y,1) 166 | yf = polyval([a,b],x) 167 | #plot(x,y,'x',x,yf,'r-') 168 | err = yf-y 169 | idxOutlier = abs(err) > 3*np.std(err) 170 | nrOutliers =sum(idxOutlier) 171 | beta = a 172 | #print 'Iteration: %i beta: %.2f outliers: %i' % (iteration,beta, nrOutliers) 173 | x = x[~idxOutlier] 174 | y = y[~idxOutlier] 175 | iteration += 1 176 | 177 | elif algo=='log': 178 | x = np.log(X['x']) 179 | y = np.log(X['y']) 180 | (a,b) = polyfit(x,y,1) 181 | beta = a 182 | 183 | elif algo=='standard': 184 | ret =np.log(X).diff().dropna() 185 | beta = ret['x'].cov(ret['y'])/ret['x'].var() 186 | 187 | 188 | 189 | else: 190 | raise TypeError("unknown algorithm type, use 'standard', 'log' or 'returns'") 191 | 192 | return beta 193 | 194 | def rank(current,past): 195 | ''' calculate a relative rank 0..1 for a value against series ''' 196 | return (current>past).sum()/float(past.count()) 197 | 198 | 199 | def returns(df): 200 | return (df/df.shift(1)-1) 201 | 202 | def logReturns(df): 203 | t = np.log(df) 204 | return t-t.shift(1) 205 | 206 | def dateTimeToDate(idx): 207 | ''' convert datetime index to date ''' 208 | dates = [] 209 | for dtm in idx: 210 | dates.append(dtm.date()) 211 | return dates 212 | 213 | 214 | 215 | def readBiggerScreener(fName): 216 | ''' import data from Bigger Capital screener ''' 217 | with open(fName,'rb') as f: 218 | reader = csv.reader(f) 219 | rows = [row for 
row in reader] 220 | 221 | header = rows[0] 222 | data = [[] for i in range(len(header))] 223 | 224 | for row in rows[1:]: 225 | for i,elm in enumerate(row): 226 | try: 227 | data[i].append(float(elm)) 228 | except Exception: 229 | data[i].append(str(elm)) 230 | 231 | 232 | 233 | return DataFrame(dict(zip(header,data)),index=Index(range(len(data[0]))))[header] 234 | 235 | def sharpe(pnl): 236 | return np.sqrt(250)*pnl.mean()/pnl.std() 237 | 238 | 239 | def drawdown(pnl): 240 | """ 241 | calculate max drawdown and duration 242 | 243 | Input: 244 | pnl, in $ 245 | Returns: 246 | drawdown : vector of drawdwon values 247 | duration : vector of drawdown duration 248 | 249 | 250 | """ 251 | cumret = pnl.cumsum() 252 | 253 | highwatermark = [0] 254 | 255 | idx = pnl.index 256 | drawdown = Series(index = idx) 257 | drawdowndur = Series(index = idx) 258 | 259 | for t in range(1, len(idx)) : 260 | highwatermark.append(max(highwatermark[t-1], cumret[t])) 261 | drawdown[t]= (highwatermark[t]-cumret[t]) 262 | drawdowndur[t]= (0 if drawdown[t] == 0 else drawdowndur[t-1]+1) 263 | 264 | return drawdown, drawdowndur 265 | 266 | 267 | def profitRatio(pnl): 268 | ''' 269 | calculate profit ratio as sum(pnl)/drawdown 270 | Input: pnl - daily pnl, Series or DataFrame 271 | ''' 272 | def processVector(pnl): # process a single column 273 | s = pnl.fillna(0) 274 | dd = drawdown(s)[0] 275 | p = s.sum()/dd.max() 276 | return p 277 | 278 | if isinstance(pnl,Series): 279 | return processVector(pnl) 280 | 281 | elif isinstance(pnl,DataFrame): 282 | 283 | p = Series(index = pnl.columns) 284 | 285 | for col in pnl.columns: 286 | p[col] = processVector(pnl[col]) 287 | 288 | return p 289 | else: 290 | raise TypeError("Input must be DataFrame or Series, not "+str(type(pnl))) 291 | 292 | 293 | 294 | 295 | def candlestick(df,width=0.5, colorup='b', colordown='r'): 296 | ''' plot a candlestick chart of a dataframe ''' 297 | 298 | O = df['open'].values 299 | H = df['high'].values 300 | L = 
df['low'].values 301 | C = df['close'].values 302 | 303 | fig = plt.gcf() 304 | ax = plt.axes() 305 | #ax.hold(True) 306 | 307 | X = df.index 308 | 309 | 310 | #plot high and low 311 | ax.bar(X,height=H-L,bottom=L,width=0.1,color='k') 312 | 313 | idxUp = C>O 314 | ax.bar(X[idxUp],height=(C-O)[idxUp],bottom=O[idxUp],width=width,color=colorup) 315 | 316 | idxDown = C<=O 317 | ax.bar(X[idxDown],height=(O-C)[idxDown],bottom=C[idxDown],width=width,color=colordown) 318 | 319 | try: 320 | fig.autofmt_xdate() 321 | except Exception: # pragma: no cover 322 | pass 323 | 324 | 325 | ax.grid(True) 326 | 327 | #ax.bar(x,height=H-L,bottom=L,width=0.01,color='k') 328 | 329 | def datetime2matlab(t): 330 | ''' convert datetime timestamp to matlab numeric timestamp ''' 331 | mdn = t + dt.timedelta(days = 366) 332 | frac = (t-dt.datetime(t.year,t.month,t.day,0,0,0)).seconds / (24.0 * 60.0 * 60.0) 333 | return mdn.toordinal() + frac 334 | 335 | 336 | if __name__ == '__main__': 337 | df = DataFrame({'open':[1,2,3],'high':[5,6,7],'low':[-2,-1,0],'close':[2,1,4]}) 338 | plt.clf() 339 | candlestick(df) -------------------------------------------------------------------------------- /lib/interactiveBrokers/__init__.py: -------------------------------------------------------------------------------- 1 | from extra import createContract 2 | from tickLogger import logTicks 3 | 4 | from extra import * -------------------------------------------------------------------------------- /lib/interactiveBrokers/extra.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on May 8, 2013 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | 6 | convenience functions for interactiveBrokers module 7 | 8 | ''' 9 | from ib.ext.Contract import Contract 10 | 11 | 12 | priceTicks = {1:'bid',2:'ask',4:'last',6:'high',7:'low',9:'close', 14:'open'} 13 | timeFormat = "%Y%m%d %H:%M:%S" 14 | dateFormat = "%Y%m%d" 15 | 16 | 17 | def 
createContract(symbol,secType='STK',exchange='SMART',currency='USD'): 18 | ''' create contract object ''' 19 | c = Contract() 20 | c.m_symbol = symbol 21 | c.m_secType= secType 22 | c.m_exchange = exchange 23 | c.m_currency = currency 24 | 25 | return c -------------------------------------------------------------------------------- /lib/interactiveBrokers/histData.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on May 8, 2013 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | 6 | Module for downloading historic data from IB 7 | 8 | ''' 9 | 10 | import ib 11 | import pandas as pd 12 | from ib.ext.Contract import Contract 13 | from ib.opt import ibConnection, message 14 | 15 | import logger as logger 16 | 17 | from pandas import DataFrame, Index 18 | 19 | import os 20 | import datetime as dt 21 | import time 22 | from time import sleep 23 | from extra import timeFormat, dateFormat 24 | 25 | class Downloader(object): 26 | def __init__(self,debug=False): 27 | self._log = logger.getLogger('DLD') 28 | self._log.debug('Initializing data dwonloader. Pandas version={0}, ibpy version:{1}'.format(pd.__version__,ib.version)) 29 | 30 | self.tws = ibConnection() 31 | self._dataHandler = _HistDataHandler(self.tws) 32 | 33 | if debug: 34 | self.tws.registerAll(self._debugHandler) 35 | self.tws.unregister(self._debugHandler,message.HistoricalData) 36 | 37 | self._log.debug('Connecting to tws') 38 | self.tws.connect() 39 | 40 | self._timeKeeper = TimeKeeper() # keep track of past requests 41 | self._reqId = 1 # current request id 42 | 43 | 44 | def _debugHandler(self,msg): 45 | print '[debug]', msg 46 | 47 | 48 | def requestData(self,contract,endDateTime,durationStr='1 D',barSizeSetting='30 secs',whatToShow='TRADES',useRTH=1,formatDate=1): 49 | 50 | if isinstance(endDateTime,dt.datetime): # convert to string 51 | endDateTime = endDateTime.strftime(timeFormat) 52 | 53 | 54 | self._log.debug('Requesting data for %s end time %s.' 
% (contract.m_symbol,endDateTime)) 55 | 56 | 57 | while self._timeKeeper.nrRequests(timeSpan=600) > 59: 58 | print 'Too many requests done. Waiting... ' 59 | time.sleep(10) 60 | 61 | self._timeKeeper.addRequest() 62 | self._dataHandler.reset() 63 | self.tws.reqHistoricalData(self._reqId,contract,endDateTime,durationStr,barSizeSetting,whatToShow,useRTH,formatDate) 64 | self._reqId+=1 65 | 66 | #wait for data 67 | startTime = time.time() 68 | timeout = 3 69 | while not self._dataHandler.dataReady and (time.time()-startTime < timeout): 70 | sleep(2) 71 | 72 | if not self._dataHandler.dataReady: 73 | self._log.error('Data timeout') 74 | 75 | print self._dataHandler.data 76 | 77 | return self._dataHandler.data 78 | 79 | # def getIntradayData(self,contract, dateTuple ): 80 | # ''' get full day data on 1-s interval 81 | # date: a tuple of (yyyy,mm,dd) 82 | # ''' 83 | # 84 | # openTime = dt.datetime(*dateTuple)+dt.timedelta(hours=16) 85 | # closeTime = dt.datetime(*dateTuple)+dt.timedelta(hours=22) 86 | # 87 | # timeRange = pd.date_range(openTime,closeTime,freq='30min') 88 | # 89 | # datasets = [] 90 | # 91 | # for t in timeRange: 92 | # datasets.append(self.requestData(contract,t.strftime(timeFormat))) 93 | # 94 | # return pd.concat(datasets) 95 | 96 | 97 | def disconnect(self): 98 | self.tws.disconnect() 99 | 100 | class _HistDataHandler(object): 101 | ''' handles incoming messages ''' 102 | def __init__(self,tws): 103 | self._log = logger.getLogger('DH') 104 | tws.register(self.msgHandler,message.HistoricalData) 105 | self.reset() 106 | 107 | def reset(self): 108 | self._log.debug('Resetting data') 109 | self.dataReady = False 110 | self._timestamp = [] 111 | self._data = {'open':[],'high':[],'low':[],'close':[],'volume':[],'count':[],'WAP':[]} 112 | 113 | def msgHandler(self,msg): 114 | #print '[msg]', msg 115 | 116 | if msg.date[:8] == 'finished': 117 | self._log.debug('Data recieved') 118 | self.dataReady = True 119 | return 120 | 121 | if len(msg.date) > 8: 122 | 
self._timestamp.append(dt.datetime.strptime(msg.date,timeFormat)) 123 | else: 124 | self._timestamp.append(dt.datetime.strptime(msg.date,dateFormat)) 125 | 126 | 127 | for k in self._data.keys(): 128 | self._data[k].append(getattr(msg, k)) 129 | 130 | @property 131 | def data(self): 132 | ''' return downloaded data as a DataFrame ''' 133 | df = DataFrame(data=self._data,index=Index(self._timestamp)) 134 | return df 135 | 136 | 137 | 138 | class TimeKeeper(object): 139 | ''' 140 | class for keeping track of previous requests, to satify the IB requirements 141 | (max 60 requests / 10 min) 142 | 143 | each time a requiest is made, a timestamp is added to a txt file in the user dir. 144 | 145 | ''' 146 | 147 | def __init__(self): 148 | self._log = logger.getLogger('TK') 149 | dataDir = os.path.expanduser('~')+'/twpData' 150 | 151 | if not os.path.exists(dataDir): 152 | os.mkdir(dataDir) 153 | 154 | self._timeFormat = "%Y%m%d %H:%M:%S" 155 | self.dataFile = os.path.normpath(os.path.join(dataDir,'requests.txt')) 156 | self._log.debug('Data file: {0}'.format(self.dataFile)) 157 | 158 | def addRequest(self): 159 | ''' adds a timestamp of current request''' 160 | with open(self.dataFile,'a') as f: 161 | f.write(dt.datetime.now().strftime(self._timeFormat)+'\n') 162 | 163 | 164 | def nrRequests(self,timeSpan=600): 165 | ''' return number of requests in past timespan (s) ''' 166 | delta = dt.timedelta(seconds=timeSpan) 167 | now = dt.datetime.now() 168 | requests = 0 169 | 170 | with open(self.dataFile,'r') as f: 171 | lines = f.readlines() 172 | 173 | for line in lines: 174 | if now-dt.datetime.strptime(line.strip(),self._timeFormat) < delta: 175 | requests+=1 176 | 177 | if requests==0: # erase all contents if no requests are relevant 178 | open(self.dataFile,'w').close() 179 | 180 | self._log.debug('past requests: {0}'.format(requests)) 181 | return requests 182 | 183 | 184 | -------------------------------------------------------------------------------- 
# /lib/interactiveBrokers/logger.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

##
# Defines logging formats and logger instance
##

import logging
import os

##
# Default log message formatting string.
format = '%(asctime)s %(levelname)s [%(name)s] %(message)s'

##
# Default log date formatting string.
datefmt = '%d-%b-%y %H:%M:%S'


def _levelFromEnv(default):
    """
    Return the log level configured in the TWP_LOGLEVEL environment variable.

    Accepts a numeric level ('10') or a level name ('DEBUG', 'info').
    A non-numeric value used to crash the import with ValueError.
    Returns default when the variable is unset; raises ValueError for an
    unrecognized value.
    """
    value = os.environ.get('TWP_LOGLEVEL')
    if value is None:
        return default
    try:
        return int(value)
    except ValueError:
        named = getattr(logging, value.strip().upper(), None)
        if isinstance(named, int):
            return named
        raise ValueError('Invalid TWP_LOGLEVEL value: %r' % (value,))

##
# Default log level.  Set TWP_LOGLEVEL environment variable (a number or a
# level name such as INFO) to change this default.
level = _levelFromEnv(logging.DEBUG)


def getLogger(name='twp', level=level, format=format,
              datefmt=datefmt):
    """ Configures logging and returns a named logger.

    @param name logger name
    @param level logging level for the root configuration
    @param format format string for log messages
    @param datefmt format string for log dates
    @return logging.Logger instance for `name`

    Note: logging.basicConfig only configures the root logger the first time it
    has no handlers, so level/format/datefmt of later calls have no effect.
    """
    #print 'Loglevel:' , level
    logging.basicConfig(level=level, format=format, datefmt=datefmt)
    return logging.getLogger(name)
# --------------------------------------------------------------------------------
# /lib/interactiveBrokers/tickLogger.py:
# --------------------------------------------------------------------------------
'''
Created on May 5, 2013
Copyright: Jev Kuznetsov
License: BSD

Program to log tick events to a file

example usage:
    > python ib_logQuotes.py SPY,VXX,XLE

start with -v option to show all incoming events


'''

import argparse # command line argument parser
import datetime as dt # date and time functions
import time # time module for timestamping
import os # used to create directories
import sys # used to print a dot to a terminal without new line

#--------ibpy imports ----------------------
from extra import createContract
from ib.opt import ibConnection, message 25 | from ib.ext.Contract import Contract 26 | 27 | 28 | # tick type definitions, see IB api manual 29 | priceTicks = {1:'bid',2:'ask',4:'last',6:'high',7:'low',9:'close', 14:'open'} 30 | sizeTicks = {0:'bid',3:'ask',5:'last',8:'volume'} 31 | 32 | class TickLogger(object): 33 | ''' class for handling incoming ticks and saving them to file 34 | will create a subdirectory 'tickLogs' if needed and start logging 35 | to a file with current timestamp in its name. 36 | All timestamps in the file are in seconds relative to start of logging 37 | 38 | ''' 39 | def __init__(self,tws, subscriptions): 40 | ''' init class, register handlers ''' 41 | 42 | tws.register(self._priceHandler,message.TickPrice) 43 | tws.register(self._sizeHandler,message.TickSize) 44 | 45 | self.subscriptions = subscriptions 46 | 47 | # save starting time of logging. All times will be in seconds relative 48 | # to this moment 49 | self._startTime = time.time() 50 | 51 | # create data directory if it does not exist 52 | if not os.path.exists('tickLogs'): os.mkdir('tickLogs') 53 | 54 | # open data file for writing 55 | fileName = 'tickLogs\\tickLog_%s.csv' % dt.datetime.now().strftime('%H_%M_%S') 56 | print 'Logging ticks to ' , fileName 57 | self.dataFile = open(fileName,'w') 58 | 59 | 60 | def _priceHandler(self,msg): 61 | ''' price tick handler ''' 62 | data = [self.subscriptions[msg.tickerId].m_symbol,'price',priceTicks[msg.field],msg.price] # data, second field is price tick type 63 | self._writeData(data) 64 | 65 | def _sizeHandler(self,msg): 66 | ''' size tick handler ''' 67 | data = [self.subscriptions[msg.tickerId].m_symbol,'size',sizeTicks[msg.field],msg.size] 68 | self._writeData(data) 69 | 70 | def _writeData(self,data): 71 | ''' write data to log file while adding a timestamp ''' 72 | timestamp = '%.3f' % (time.time()-self._startTime) # 1 ms resolution 73 | dataLine = ','.join(str(bit) for bit in [timestamp]+data) + '\n' 74 | 
self.dataFile.write(dataLine) 75 | 76 | def flush(self): 77 | ''' commits data to file''' 78 | self.dataFile.flush() 79 | 80 | def close(self): 81 | '''close file in a neat manner ''' 82 | print 'Closing data file' 83 | self.dataFile.close() 84 | 85 | 86 | def printMessage(msg): 87 | ''' function to print all incoming messages from TWS ''' 88 | print '[msg]:', msg 89 | 90 | 91 | 92 | 93 | def logTicks(contracts,verbose=False): 94 | ''' 95 | log ticks from IB to a csv file 96 | 97 | Parameters 98 | ---------- 99 | contracts : ib.ext.Contract objects 100 | verbose : print out all tick events 101 | ''' 102 | # check for correct input 103 | assert isinstance(contracts,(list,Contract)) ,'Wrong input, should be a Contract or list of contracts' 104 | 105 | #---create subscriptions dictionary. Keys are subscription ids 106 | subscriptions = {} 107 | try: 108 | for idx, c in enumerate(contracts): 109 | subscriptions[idx+1] = c 110 | except TypeError: # not iterable, one contract provided 111 | subscriptions[1] = contracts 112 | 113 | tws = ibConnection() 114 | logger = TickLogger(tws,subscriptions) 115 | 116 | if verbose: tws.registerAll(printMessage) 117 | 118 | tws.connect() 119 | 120 | #-------subscribe to data 121 | for subId, c in subscriptions.iteritems(): 122 | assert isinstance(c,Contract) , 'Need a Contract object to subscribe' 123 | tws.reqMktData(subId,c,"",False) 124 | 125 | #------start a loop that must be interrupted with Ctrl-C 126 | print 'Press Ctr-C to stop loop' 127 | 128 | try: 129 | while True: 130 | time.sleep(2) # wait a little 131 | logger.flush() # commit data to file 132 | sys.stdout.write('.') # print a dot to the screen 133 | 134 | 135 | except KeyboardInterrupt: 136 | print 'Interrupted with Ctrl-c' 137 | 138 | logger.close() 139 | tws.disconnect() 140 | print 'All done' 141 | 142 | #--------------main script------------------ 143 | 144 | if __name__ == '__main__': 145 | 146 | #-----------parse command line arguments 147 | parser = 
argparse.ArgumentParser(description='Log ticks for a set of stocks') 148 | 149 | 150 | parser.add_argument("symbols",help = 'symbols separated by coma: SPY,VXX') 151 | parser.add_argument("-v", "--verbose", help="show all incoming messages", 152 | action="store_true") 153 | 154 | args = parser.parse_args() 155 | 156 | symbols = args.symbols.strip().split(',') 157 | print 'Logging ticks for:',symbols 158 | 159 | contracts = [createContract(symbol) for symbol in symbols] 160 | 161 | logTicks(contracts, verbose=args.verbose) 162 | 163 | 164 | 165 | 166 | -------------------------------------------------------------------------------- /lib/interactivebrokers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright: Jev Kuznetsov 3 | Licence: BSD 4 | 5 | Interface to interactive brokers together with gui widgets 6 | 7 | ''' 8 | import sys 9 | #import os 10 | from time import sleep 11 | from PyQt4.QtCore import (SIGNAL,SLOT) 12 | from PyQt4.QtGui import (QApplication,QFileDialog,QDialog,QVBoxLayout,QHBoxLayout,QDialogButtonBox, 13 | QTableView, QPushButton,QWidget,QLabel,QLineEdit,QGridLayout,QHeaderView) 14 | 15 | import ib 16 | from ib.ext.Contract import Contract 17 | from ib.opt import ibConnection, message 18 | from ib.ext.Order import Order 19 | 20 | import logger as logger 21 | from qtpandas import DataFrameModel, TableView 22 | from eventSystem import Sender 23 | import numpy as np 24 | 25 | import pandas 26 | from pandas import DataFrame, Index 27 | from datetime import datetime 28 | import os 29 | import datetime as dt 30 | import time 31 | 32 | priceTicks = {1:'bid',2:'ask',4:'last',6:'high',7:'low',9:'close', 14:'open'} 33 | timeFormat = "%Y%m%d %H:%M:%S" 34 | dateFormat = "%Y%m%d" 35 | 36 | def createContract(symbol, secType='STK', exchange='SMART',currency='USD'): 37 | ''' contract factory function ''' 38 | contract = Contract() 39 | contract.m_symbol = symbol 40 | contract.m_secType = secType 41 | 
contract.m_exchange = exchange 42 | contract.m_currency = currency 43 | 44 | return contract 45 | 46 | def _str2datetime(s): 47 | """ convert string to datetime """ 48 | return datetime.strptime( s,'%Y%m%d') 49 | 50 | 51 | def readActivityFlex(fName): 52 | """ 53 | parse trade log in a csv file produced by IB 'Activity Flex Query' 54 | the file should contain these columns: 55 | ['Symbol','TradeDate','Quantity','TradePrice','IBCommission'] 56 | 57 | Returns: 58 | A DataFrame with parsed trade data 59 | 60 | """ 61 | import csv 62 | 63 | rows = [] 64 | 65 | with open(fName,'rb') as f: 66 | reader = csv.reader(f) 67 | for row in reader: 68 | rows.append(row) 69 | 70 | header = ['TradeDate','Symbol','Quantity','TradePrice','IBCommission'] 71 | 72 | types =dict(zip(header,[ _str2datetime,str , int, float, float])) 73 | idx = dict(zip(header,[rows[0].index(h) for h in header ] )) 74 | data = dict(zip(header,[[] for h in header])) 75 | 76 | for row in rows[1:]: 77 | print row 78 | for col in header: 79 | val = types[col](row[idx[col]]) 80 | data[col].append(val) 81 | 82 | return DataFrame(data)[header].sort(column = 'TradeDate') 83 | 84 | 85 | 86 | class Subscriptions(DataFrameModel, Sender): 87 | ''' a data table containing price & subscription data ''' 88 | def __init__(self, tws=None): 89 | 90 | super(Subscriptions,self).__init__() 91 | self.df = DataFrame() # this property holds the data in a table format 92 | 93 | self._nextId = 1 94 | self._id2symbol = {} # id-> symbol lookup dict 95 | self._header = ['id','position','bid','ask','last'] # columns of the _data table 96 | 97 | # register callbacks 98 | if tws is not None: 99 | tws.register(self.priceHandler,message.TickPrice) 100 | tws.register(self.accountHandler,message.UpdatePortfolio) 101 | 102 | 103 | 104 | def add(self,symbol, subId = None): 105 | ''' 106 | Add a subscription to data table 107 | return : subscription id 108 | 109 | ''' 110 | if subId is None: 111 | subId = self._nextId 112 | 113 | data = 
dict(zip(self._header,[subId,0,np.nan,np.nan,np.nan])) 114 | row = DataFrame(data, index = Index([symbol])) 115 | 116 | self.df = self.df.append(row[self._header]) # append data and set correct column order 117 | 118 | self._nextId = subId+1 119 | self._rebuildIndex() 120 | 121 | self.emit(SIGNAL("layoutChanged()")) 122 | 123 | return subId 124 | 125 | def priceHandler(self,msg): 126 | ''' handler function for price updates. register this with ibConnection class ''' 127 | 128 | if priceTicks[msg.field] not in self._header: # do nothing for ticks that are not in _data table 129 | return 130 | 131 | self.df[priceTicks[msg.field]][self._id2symbol[msg.tickerId]]=msg.price 132 | 133 | #notify viewer 134 | col = self._header.index(priceTicks[msg.field]) 135 | row = self.df.index.tolist().index(self._id2symbol[msg.tickerId]) 136 | 137 | idx = self.createIndex(row,col) 138 | self.emit(SIGNAL("dataChanged(QModelIndex,QModelIndex)"),idx, idx) 139 | 140 | def accountHandler(self,msg): 141 | if msg.contract.m_symbol in self.df.index.tolist(): 142 | self.df['position'][msg.contract.m_symbol]=msg.position 143 | 144 | def _rebuildIndex(self): 145 | ''' udate lookup dictionary id-> symbol ''' 146 | symbols = self.df.index.tolist() 147 | ids = self.df['id'].values.tolist() 148 | self._id2symbol = dict(zip(ids,symbols)) 149 | 150 | 151 | def __repr__(self): 152 | return str(self.df) 153 | 154 | 155 | 156 | class Broker(object): 157 | ''' 158 | Broker class acts as a wrapper around ibConnection 159 | from ibPy. It tracks current subscriptions and provides 160 | data models to viewiers . 161 | ''' 162 | def __init__(self, name='broker'): 163 | ''' initialize broker class 164 | @param dbFile: sqlite database file. will be created if it does not exist 165 | ''' 166 | self.name = name 167 | self.log = logger.getLogger(self.name) 168 | 169 | self.log.debug('Initializing broker. 
Pandas version={0}'.format(pandas.__version__)) 170 | self.contracts = {} # a dict to keep track of subscribed contracts 171 | 172 | self.tws = ibConnection() # tws interface 173 | self.nextValidOrderId = None 174 | 175 | self.dataModel = Subscriptions(self.tws) # data container 176 | 177 | self.tws.registerAll(self.defaultHandler) 178 | #self.tws.register(self.debugHandler,message.TickPrice) 179 | self.tws.register(self.nextValidIdHandler,'NextValidId') 180 | self.log.debug('Connecting to tws') 181 | self.tws.connect() 182 | 183 | self.tws.reqAccountUpdates(True,'') 184 | 185 | def subscribeStk(self,symbol, secType='STK', exchange='SMART',currency='USD'): 186 | ''' subscribe to stock data ''' 187 | self.log.debug('Subscribing to '+symbol) 188 | # if symbol in self.data.symbols: 189 | # print 'Already subscribed to {0}'.format(symbol) 190 | # return 191 | 192 | c = Contract() 193 | c.m_symbol = symbol 194 | c.m_secType = secType 195 | c.m_exchange = exchange 196 | c.m_currency = currency 197 | 198 | subId = self.dataModel.add(symbol) 199 | self.tws.reqMktData(subId,c,'',False) 200 | 201 | self.contracts[symbol]=c 202 | 203 | return subId 204 | 205 | @property 206 | def data(self): 207 | return self.dataModel.df 208 | 209 | 210 | def placeOrder(self,symbol,shares,limit=None,exchange='SMART', transmit=0): 211 | ''' place an order on already subscribed contract ''' 212 | 213 | 214 | if symbol not in self.contracts.keys(): 215 | self.log.error("Can't place order, not subscribed to %s" % symbol) 216 | return 217 | 218 | action = {-1:'SELL',1:'BUY'} 219 | 220 | o= Order() 221 | o.m_orderId = self.getOrderId() 222 | o.m_action = action[cmp(shares,0)] 223 | o.m_totalQuantity = abs(shares) 224 | o.m_transmit = transmit 225 | 226 | if limit is not None: 227 | o.m_orderType = 'LMT' 228 | o.m_lmtPrice = limit 229 | 230 | self.log.debug('Placing %s order for %i %s (id=%i)' % (o.m_action,o.m_totalQuantity,symbol,o.m_orderId)) 231 | 232 | 233 | 
self.tws.placeOrder(o.m_orderId,self.contracts[symbol],o) 234 | 235 | 236 | 237 | def getOrderId(self): 238 | self.nextValidOrderId+=1 239 | return self.nextValidOrderId-1 240 | 241 | def unsubscribeStk(self,symbol): 242 | self.log.debug('Function not implemented') 243 | 244 | def disconnect(self): 245 | self.tws.disconnect() 246 | 247 | def __del__(self): 248 | '''destructor, clean up ''' 249 | print 'Broker is cleaning up after itself.' 250 | self.tws.disconnect() 251 | 252 | def debugHandler(self,msg): 253 | print msg 254 | 255 | def defaultHandler(self,msg): 256 | ''' default message handler ''' 257 | #print msg.typeName 258 | if msg.typeName == 'Error': 259 | self.log.error(msg) 260 | 261 | 262 | def nextValidIdHandler(self,msg): 263 | self.nextValidOrderId = msg.orderId 264 | self.log.debug( 'Next valid order id:{0}'.format(self.nextValidOrderId)) 265 | 266 | def saveData(self, fname): 267 | ''' save current dataframe to csv ''' 268 | self.log.debug("Saving data to {0}".format(fname)) 269 | self.dataModel.df.to_csv(fname) 270 | 271 | # def __getattr__(self, name): 272 | # """ x.__getattr__('name') <==> x.name 273 | # an easy way to call ibConnection methods 274 | # @return named attribute from instance tws 275 | # """ 276 | # return getattr(self.tws, name) 277 | 278 | 279 | 280 | class _HistDataHandler(object): 281 | ''' handles incoming messages ''' 282 | def __init__(self,tws): 283 | self._log = logger.getLogger('DH') 284 | tws.register(self.msgHandler,message.HistoricalData) 285 | self.reset() 286 | 287 | def reset(self): 288 | self._log.debug('Resetting data') 289 | self.dataReady = False 290 | self._timestamp = [] 291 | self._data = {'open':[],'high':[],'low':[],'close':[],'volume':[],'count':[],'WAP':[]} 292 | 293 | def msgHandler(self,msg): 294 | #print '[msg]', msg 295 | 296 | if msg.date[:8] == 'finished': 297 | self._log.debug('Data recieved') 298 | self.dataReady = True 299 | return 300 | 301 | if len(msg.date) > 8: 302 | 
self._timestamp.append(dt.datetime.strptime(msg.date,timeFormat)) 303 | else: 304 | self._timestamp.append(dt.datetime.strptime(msg.date,dateFormat)) 305 | 306 | 307 | for k in self._data.keys(): 308 | self._data[k].append(getattr(msg, k)) 309 | 310 | @property 311 | def data(self): 312 | ''' return downloaded data as a DataFrame ''' 313 | df = DataFrame(data=self._data,index=Index(self._timestamp)) 314 | return df 315 | 316 | 317 | class Downloader(object): 318 | def __init__(self,debug=False): 319 | self._log = logger.getLogger('DLD') 320 | self._log.debug('Initializing data dwonloader. Pandas version={0}, ibpy version:{1}'.format(pandas.__version__,ib.version)) 321 | 322 | self.tws = ibConnection() 323 | self._dataHandler = _HistDataHandler(self.tws) 324 | 325 | if debug: 326 | self.tws.registerAll(self._debugHandler) 327 | self.tws.unregister(self._debugHandler,message.HistoricalData) 328 | 329 | self._log.debug('Connecting to tws') 330 | self.tws.connect() 331 | 332 | self._timeKeeper = TimeKeeper() # keep track of past requests 333 | self._reqId = 1 # current request id 334 | 335 | 336 | def _debugHandler(self,msg): 337 | print '[debug]', msg 338 | 339 | 340 | def requestData(self,contract,endDateTime,durationStr='1 D',barSizeSetting='30 secs',whatToShow='TRADES',useRTH=1,formatDate=1): 341 | self._log.debug('Requesting data for %s end time %s.' % (contract.m_symbol,endDateTime)) 342 | 343 | 344 | while self._timeKeeper.nrRequests(timeSpan=600) > 59: 345 | print 'Too many requests done. Waiting... 
' 346 | time.sleep(10) 347 | 348 | self._timeKeeper.addRequest() 349 | self._dataHandler.reset() 350 | self.tws.reqHistoricalData(self._reqId,contract,endDateTime,durationStr,barSizeSetting,whatToShow,useRTH,formatDate) 351 | self._reqId+=1 352 | 353 | #wait for data 354 | startTime = time.time() 355 | timeout = 3 356 | while not self._dataHandler.dataReady and (time.time()-startTime < timeout): 357 | sleep(2) 358 | 359 | if not self._dataHandler.dataReady: 360 | self._log.error('Data timeout') 361 | 362 | print self._dataHandler.data 363 | 364 | return self._dataHandler.data 365 | 366 | def getIntradayData(self,contract, dateTuple ): 367 | ''' get full day data on 1-s interval 368 | date: a tuple of (yyyy,mm,dd) 369 | ''' 370 | 371 | openTime = dt.datetime(*dateTuple)+dt.timedelta(hours=16) 372 | closeTime = dt.datetime(*dateTuple)+dt.timedelta(hours=22) 373 | 374 | timeRange = pandas.date_range(openTime,closeTime,freq='30min') 375 | 376 | datasets = [] 377 | 378 | for t in timeRange: 379 | datasets.append(self.requestData(contract,t.strftime(timeFormat))) 380 | 381 | return pandas.concat(datasets) 382 | 383 | 384 | def disconnect(self): 385 | self.tws.disconnect() 386 | 387 | 388 | class TimeKeeper(object): 389 | def __init__(self): 390 | self._log = logger.getLogger('TK') 391 | dataDir = os.path.expanduser('~')+'/twpData' 392 | 393 | if not os.path.exists(dataDir): 394 | os.mkdir(dataDir) 395 | 396 | self._timeFormat = "%Y%m%d %H:%M:%S" 397 | self.dataFile = os.path.normpath(os.path.join(dataDir,'requests.txt')) 398 | self._log.debug('Data file: {0}'.format(self.dataFile)) 399 | 400 | def addRequest(self): 401 | ''' adds a timestamp of current request''' 402 | with open(self.dataFile,'a') as f: 403 | f.write(dt.datetime.now().strftime(self._timeFormat)+'\n') 404 | 405 | 406 | def nrRequests(self,timeSpan=600): 407 | ''' return number of requests in past timespan (s) ''' 408 | delta = dt.timedelta(seconds=timeSpan) 409 | now = dt.datetime.now() 410 | requests = 0 
411 | 412 | with open(self.dataFile,'r') as f: 413 | lines = f.readlines() 414 | 415 | for line in lines: 416 | if now-dt.datetime.strptime(line.strip(),self._timeFormat) < delta: 417 | requests+=1 418 | 419 | if requests==0: # erase all contents if no requests are relevant 420 | open(self.dataFile,'w').close() 421 | 422 | self._log.debug('past requests: {0}'.format(requests)) 423 | return requests 424 | 425 | 426 | 427 | #---------------test functions----------------- 428 | 429 | def dummyHandler(msg): 430 | print msg 431 | 432 | def testConnection(): 433 | ''' a simple test to check working of streaming prices etc ''' 434 | tws = ibConnection() 435 | tws.registerAll(dummyHandler) 436 | 437 | tws.connect() 438 | 439 | c = createContract('SPY') 440 | tws.reqMktData(1,c,'',False) 441 | sleep(3) 442 | 443 | print 'testConnection done.' 444 | 445 | 446 | 447 | def testSubscriptions(): 448 | s = Subscriptions() 449 | s.add('SPY') 450 | #s.add('XLE') 451 | 452 | print s 453 | 454 | def testBroker(): 455 | b = Broker() 456 | sleep(2) 457 | b.subscribeStk('SPY') 458 | b.subscribeStk('XLE') 459 | b.subscribeStk('GOOG') 460 | 461 | b.placeOrder('ABC', 125, 55.1) 462 | sleep(3) 463 | return b 464 | 465 | #---------------------GUI stuff-------------------------------------------- 466 | class AddSubscriptionDlg(QDialog): 467 | def __init__(self,parent=None): 468 | super(AddSubscriptionDlg,self).__init__(parent) 469 | symbolLabel = QLabel('Symbol') 470 | self.symbolEdit = QLineEdit() 471 | secTypeLabel = QLabel('secType') 472 | self.secTypeEdit = QLineEdit('STK') 473 | exchangeLabel = QLabel('exchange') 474 | self.exchangeEdit = QLineEdit('SMART') 475 | currencyLabel = QLabel('currency') 476 | self.currencyEdit = QLineEdit('USD') 477 | 478 | buttonBox = QDialogButtonBox(QDialogButtonBox.Ok| 479 | QDialogButtonBox.Cancel) 480 | 481 | lay = QGridLayout() 482 | lay.addWidget(symbolLabel,0,0) 483 | lay.addWidget(self.symbolEdit,0,1) 484 | lay.addWidget(secTypeLabel,1,0) 485 | 
lay.addWidget(self.secTypeEdit,1,1) 486 | lay.addWidget(exchangeLabel,2,0) 487 | lay.addWidget(self.exchangeEdit,2,1) 488 | lay.addWidget(currencyLabel,3,0) 489 | lay.addWidget(self.currencyEdit,3,1) 490 | 491 | lay.addWidget(buttonBox,4,0,1,2) 492 | self.setLayout(lay) 493 | 494 | self.connect(buttonBox, SIGNAL("accepted()"), 495 | self, SLOT("accept()")) 496 | self.connect(buttonBox, SIGNAL("rejected()"), 497 | self, SLOT("reject()")) 498 | self.setWindowTitle("Add subscription") 499 | 500 | class BrokerWidget(QWidget): 501 | def __init__(self,broker,parent = None ): 502 | super(BrokerWidget,self).__init__(parent) 503 | 504 | self.broker = broker 505 | 506 | self.dataTable = TableView() 507 | self.dataTable.setModel(self.broker.dataModel) 508 | self.dataTable.horizontalHeader().setResizeMode(QHeaderView.Stretch) 509 | #self.dataTable.resizeColumnsToContents() 510 | dataLabel = QLabel('Price Data') 511 | dataLabel.setBuddy(self.dataTable) 512 | 513 | dataLayout = QVBoxLayout() 514 | 515 | dataLayout.addWidget(dataLabel) 516 | dataLayout.addWidget(self.dataTable) 517 | 518 | 519 | addButton = QPushButton("&Add Symbol") 520 | saveDataButton = QPushButton("&Save Data") 521 | #deleteButton = QPushButton("&Delete") 522 | 523 | buttonLayout = QVBoxLayout() 524 | buttonLayout.addWidget(addButton) 525 | buttonLayout.addWidget(saveDataButton) 526 | buttonLayout.addStretch() 527 | 528 | layout = QHBoxLayout() 529 | layout.addLayout(dataLayout) 530 | layout.addLayout(buttonLayout) 531 | self.setLayout(layout) 532 | 533 | self.connect(addButton,SIGNAL('clicked()'),self.addSubscription) 534 | self.connect(saveDataButton,SIGNAL('clicked()'),self.saveData) 535 | #self.connect(deleteButton,SIGNAL('clicked()'),self.deleteSubscription) 536 | 537 | def addSubscription(self): 538 | dialog = AddSubscriptionDlg(self) 539 | if dialog.exec_(): 540 | self.broker.subscribeStk(str(dialog.symbolEdit.text()),str( dialog.secTypeEdit.text()), 541 | 
str(dialog.exchangeEdit.text()),str(dialog.currencyEdit.text())) 542 | 543 | def saveData(self): 544 | ''' save data to a .csv file ''' 545 | fname =unicode(QFileDialog.getSaveFileName( self, caption="Save data to csv",filter = '*.csv')) 546 | if fname: 547 | self.broker.saveData(fname) 548 | 549 | 550 | # def deleteSubscription(self): 551 | # pass 552 | 553 | 554 | class Form(QDialog): 555 | def __init__(self, parent=None): 556 | super(Form, self).__init__(parent) 557 | self.resize(640,480) 558 | self.setWindowTitle('Broker test') 559 | 560 | self.broker = Broker() 561 | 562 | self.broker.subscribeStk('SPY') 563 | self.broker.subscribeStk('XLE') 564 | self.broker.subscribeStk('GOOG') 565 | 566 | brokerWidget = BrokerWidget(self.broker,self) 567 | lay = QVBoxLayout() 568 | lay.addWidget(brokerWidget) 569 | self.setLayout(lay) 570 | 571 | def startGui(): 572 | app = QApplication(sys.argv) 573 | form = Form() 574 | form.show() 575 | app.exec_() 576 | 577 | if __name__ == "__main__": 578 | import ib 579 | print 'iby version:' , ib.version 580 | testConnection() 581 | #testBroker() 582 | #testSubscriptions() 583 | print message.messageTypeNames() 584 | #startGui() 585 | print 'All done' 586 | 587 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ## 5 | # Defines logging formats and logger instance 6 | ## 7 | 8 | import logging 9 | import os 10 | 11 | ## 12 | # Default log message formatting string. 13 | format = '%(asctime)s %(levelname)s [%(name)s] %(message)s' 14 | 15 | ## 16 | # Default log date formatting string. 17 | datefmt = '%d-%b-%y %H:%M:%S' 18 | 19 | ## 20 | # Default log level. Set TWP_LOGLEVEL environment variable to 21 | # change this default. 
22 | level = int(os.environ.get('TWP_LOGLEVEL', logging.DEBUG)) 23 | 24 | 25 | def getLogger(name='twp', level=level, format=format, 26 | datefmt=datefmt): 27 | """ Configures and returns a logging instance. 28 | 29 | @param name ignored 30 | @param level logging level 31 | @param format format string for log messages 32 | @param datefmt format string for log dates 33 | @return logging instance (the module) 34 | """ 35 | #print 'Loglevel:' , level 36 | logging.basicConfig(level=level, format=format, datefmt=datefmt) 37 | return logging.getLogger(name) 38 | -------------------------------------------------------------------------------- /lib/qtpandas.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Easy integration of DataFrame into pyqt framework 3 | 4 | Copyright: Jev Kuznetsov 5 | Licence: BSD 6 | 7 | ''' 8 | from PyQt4.QtCore import (QAbstractTableModel,Qt,QVariant,QModelIndex,SIGNAL) 9 | from PyQt4.QtGui import (QApplication,QDialog,QVBoxLayout, QTableView, 10 | QWidget,QTableWidget, QHeaderView, QFont,QMenu,QAbstractItemView) 11 | 12 | from pandas import DataFrame, Index 13 | 14 | 15 | 16 | 17 | class DataFrameModel(QAbstractTableModel): 18 | ''' data model for a DataFrame class ''' 19 | def __init__(self,parent=None): 20 | super(DataFrameModel,self).__init__(parent) 21 | self.df = DataFrame() 22 | self.columnFormat = {} # format columns 23 | 24 | def setFormat(self,fmt): 25 | """ 26 | set string formatting for the output 27 | example : format = {'close':"%.2f"} 28 | """ 29 | 30 | self.columnFormat = fmt 31 | 32 | def setDataFrame(self,dataFrame): 33 | self.df = dataFrame 34 | self.signalUpdate() 35 | 36 | def signalUpdate(self): 37 | ''' tell viewers to update their data (this is full update, not efficient)''' 38 | self.layoutChanged.emit() 39 | 40 | def __repr__(self): 41 | return str(self.df) 42 | 43 | #------------- table display functions ----------------- 44 | def 
headerData(self,section,orientation,role=Qt.DisplayRole): 45 | if role != Qt.DisplayRole: 46 | return QVariant() 47 | 48 | if orientation == Qt.Horizontal: 49 | try: 50 | return self.df.columns.tolist()[section] 51 | except (IndexError, ): 52 | return QVariant() 53 | elif orientation == Qt.Vertical: 54 | try: 55 | #return self.df.index.tolist() 56 | return str(self.df.index.tolist()[section]) 57 | except (IndexError, ): 58 | return QVariant() 59 | 60 | def data(self, index, role=Qt.DisplayRole): 61 | if role != Qt.DisplayRole: 62 | return QVariant() 63 | 64 | if not index.isValid(): 65 | return QVariant() 66 | 67 | col = self.df.ix[:,index.column()] # get a column slice first to get the right data type 68 | elm = col[index.row()] 69 | #elm = self.df.ix[index.row(),index.column()] 70 | 71 | if self.df.columns[index.column()] in self.columnFormat.keys(): 72 | return QVariant(self.columnFormat[self.df.columns[index.column()]] % elm ) 73 | else: 74 | return QVariant(str(elm)) 75 | 76 | def sort(self,nCol,order): 77 | 78 | self.layoutAboutToBeChanged.emit() 79 | if order == Qt.AscendingOrder: 80 | self.df = self.df.sort(column=self.df.columns[nCol], ascending=True) 81 | elif order == Qt.DescendingOrder: 82 | self.df = self.df.sort(column=self.df.columns[nCol], ascending=False) 83 | 84 | self.layoutChanged.emit() 85 | 86 | 87 | 88 | def rowCount(self, index=QModelIndex()): 89 | return self.df.shape[0] 90 | 91 | def columnCount(self, index=QModelIndex()): 92 | return self.df.shape[1] 93 | 94 | 95 | class TableView(QTableView): 96 | """ extended table view """ 97 | def __init__(self,name='TableView1', parent=None): 98 | super(TableView,self).__init__(parent) 99 | self.name = name 100 | self.setSelectionBehavior(QAbstractItemView.SelectRows) 101 | 102 | def contextMenuEvent(self, event): 103 | menu = QMenu(self) 104 | 105 | Action = menu.addAction("print selected rows") 106 | Action.triggered.connect(self.printName) 107 | 108 | menu.exec_(event.globalPos()) 109 | 110 | def 
printName(self): 111 | print "Action triggered from " + self.name 112 | 113 | print 'Selected rows:' 114 | for idx in self.selectionModel().selectedRows(): 115 | print idx.row() 116 | 117 | 118 | class DataFrameWidget(QWidget): 119 | ''' a simple widget for using DataFrames in a gui ''' 120 | def __init__(self,name='DataFrameTable1', parent=None): 121 | super(DataFrameWidget,self).__init__(parent) 122 | self.name = name 123 | 124 | self.dataModel = DataFrameModel() 125 | self.dataModel.setDataFrame(DataFrame()) 126 | 127 | self.dataTable = QTableView() 128 | self.dataTable.setSelectionBehavior(QAbstractItemView.SelectRows) 129 | self.dataTable.setSortingEnabled(True) 130 | 131 | self.dataTable.setModel(self.dataModel) 132 | self.dataModel.signalUpdate() 133 | 134 | #self.dataTable.setFont(QFont("Courier New", 8)) 135 | 136 | layout = QVBoxLayout() 137 | layout.addWidget(self.dataTable) 138 | self.setLayout(layout) 139 | 140 | 141 | 142 | def setFormat(self,fmt): 143 | """ set non-default string formatting for a column """ 144 | for colName, f in fmt.iteritems(): 145 | self.dataModel.columnFormat[colName]=f 146 | 147 | def fitColumns(self): 148 | self.dataTable.horizontalHeader().setResizeMode(QHeaderView.Stretch) 149 | 150 | def setDataFrame(self,df): 151 | self.dataModel.setDataFrame(df) 152 | 153 | 154 | def resizeColumnsToContents(self): 155 | self.dataTable.resizeColumnsToContents() 156 | 157 | #-----------------stand alone test code 158 | 159 | def testDf(): 160 | ''' creates test dataframe ''' 161 | data = {'int':[1,2,3],'float':[1./3,2.5,3.5],'string':['a','b','c'],'nan':[np.nan,np.nan,np.nan]} 162 | return DataFrame(data, index=Index(['AAA','BBB','CCC']))[['int','float','string','nan']] 163 | 164 | 165 | class Form(QDialog): 166 | def __init__(self,parent=None): 167 | super(Form,self).__init__(parent) 168 | 169 | df = testDf() # make up some data 170 | widget = DataFrameWidget(parent=self) 171 | widget.setDataFrame(df) 172 | 
#widget.resizeColumnsToContents() 173 | widget.fitColumns() 174 | widget.setFormat({'float': '%.2f'}) 175 | 176 | 177 | layout = QVBoxLayout() 178 | layout.addWidget(widget) 179 | self.setLayout(layout) 180 | 181 | if __name__=='__main__': 182 | import sys 183 | import numpy as np 184 | 185 | app = QApplication(sys.argv) 186 | form = Form() 187 | form.show() 188 | app.exec_() 189 | 190 | 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /lib/vixFutures.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | set of tools for working with VIX futures 4 | 5 | @author: Jev Kuznetsov 6 | Licence: GPL v2 7 | """ 8 | 9 | import datetime as dt 10 | from pandas import * 11 | import os 12 | import urllib2 13 | #from csvDatabase import HistDataCsv 14 | 15 | m_codes = dict(zip(range(1,13),['F','G','H','J','K','M','N','Q','U','V','X','Z'])) #month codes of the futures 16 | monthToCode = dict(zip(range(1,len(m_codes)+1),m_codes)) 17 | 18 | 19 | def getCboeData(year,month): 20 | ''' download data from cboe ''' 21 | fName = "CFE_{0}{1}_VX.csv".format(m_codes[month],str(year)[-2:]) 22 | urlStr = "http://cfe.cboe.com/Publish/ScheduledTask/MktData/datahouse/{0}".format(fName) 23 | 24 | try: 25 | lines = urllib2.urlopen(urlStr).readlines() 26 | except Exception, e: 27 | s = "Failed to download:\n{0}".format(e); 28 | print s 29 | 30 | # first column is date, second is future , skip these 31 | header = lines[0].strip().split(',')[2:] 32 | 33 | dates = [] 34 | data = [[] for i in range(len(header))] 35 | 36 | 37 | 38 | for line in lines[1:]: 39 | fields = line.strip().split(',') 40 | dates.append(datetime.strptime( fields[0],'%m/%d/%Y')) 41 | for i,field in enumerate(fields[2:]): 42 | data[i].append(float(field)) 43 | 44 | data = dict(zip(header,data)) 45 | 46 | df = DataFrame(data=data, index=Index(dates)) 47 | 48 | return df 49 | 50 | class Future(object): 51 | 
''' vix future class ''' 52 | def __init__(self,year,month): 53 | self.year = year 54 | self.month = month 55 | self.expiration = self._calculateExpirationDate() 56 | self.cboeData = None # daily cboe data 57 | self.intradayDb = None # intraday database (csv) 58 | 59 | def _calculateExpirationDate(self): 60 | ''' calculate expiration date of the future, (not 100% reliable) ''' 61 | t = dt.date(self.year,self.month,1)+datetools.relativedelta(months=1) 62 | offset = datetools.Week(weekday=4) 63 | if t.weekday()<>4: 64 | t_new = t+3*offset 65 | else: 66 | t_new = t+2*offset 67 | 68 | t_new = t_new-datetools.relativedelta(days=30) 69 | return t_new 70 | 71 | 72 | def getCboeData(self, dataDir=None, forceUpdate=False): 73 | ''' download interday CBOE data 74 | specify dataDir to save data to csv. 75 | data will not be downloaded if csv file is already present. 76 | This can be overridden with setting forceUpdate to True 77 | ''' 78 | 79 | 80 | if dataDir is not None: 81 | fileFound = os.path.exists(self._csvFilename(dataDir)) 82 | 83 | if forceUpdate or not fileFound: 84 | self.cboeData = getCboeData(self.year, self.month) 85 | self.to_csv(dataDir) 86 | else: 87 | self.cboeData = DataFrame.from_csv(self._csvFilename(dataDir)) 88 | 89 | else: 90 | self.cboeData = getCboeData(self.year, self.month) 91 | 92 | 93 | return self.cboeData 94 | 95 | def updateIntradayDb(self,dbDir): 96 | #self.intradayDb = 97 | pass 98 | 99 | def to_csv(self,dataDir): 100 | ''' save to csv in given dir. 
Filename is automatically generated ''' 101 | self.cboeData.to_csv(self._csvFilename(dataDir)) 102 | 103 | 104 | @property 105 | def dates(self): 106 | ''' trading days derived from cboe data ''' 107 | if self.cboeData is not None: 108 | dates = [d.date() for d in self.cboeData.index] 109 | else: 110 | dates = None 111 | 112 | return dates 113 | 114 | def _csvFilename(self,dataDir): 115 | fName = "VIX_future_%i_%i.csv" % (self.year, self.month) 116 | return os.path.join(dataDir,fName) 117 | 118 | def __repr__(self): 119 | s = 'Vix future [%i-%i (%s)] exp: %s\n' % (self.year, self.month,monthToCode[self.month], self.expiration.strftime("%B, %d %Y (%A)")) 120 | s+= 'Cboe data: %i days'% len(self.cboeData) if self.cboeData is not None else 'No data downloaded yet' 121 | return s 122 | 123 | 124 | 125 | if __name__ == '__main__': 126 | print 'testing vix futures' 127 | 128 | year = 2012 129 | month = 12 130 | 131 | 132 | f = Future(year,month) 133 | f.getCboeData() 134 | print f 135 | 136 | -------------------------------------------------------------------------------- /lib/widgets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A collection of widgets for gui building 4 | 5 | Copyright: Jev Kuznetsov 6 | License: BSD 7 | """ 8 | 9 | from __future__ import division 10 | 11 | import sys 12 | from PyQt4.QtCore import * 13 | from PyQt4.QtGui import * 14 | 15 | 16 | import numpy as np 17 | from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas 18 | from matplotlib.backends.backend_qt4agg import NavigationToolbar2QTAgg as NavigationToolbar 19 | from matplotlib.figure import Figure 20 | import matplotlib.pyplot as plt 21 | 22 | 23 | class MatplotlibWidget(QWidget): 24 | def __init__(self,parent=None,grid=True): 25 | QWidget.__init__(self,parent) 26 | 27 | self.grid = grid 28 | 29 | 30 | self.fig = Figure() 31 | self.canvas =FigureCanvas(self.fig) 32 | 
self.canvas.setParent(self) 33 | self.canvas.mpl_connect('button_press_event', self.onPick) # bind pick event 34 | 35 | 36 | #self.axes = self.fig.add_subplot(111) 37 | margins = [0.05,0.1,0.9,0.8] 38 | self.axes = self.fig.add_axes(margins) 39 | self.toolbar = NavigationToolbar(self.canvas,self) 40 | 41 | 42 | #self.initFigure() 43 | 44 | layout = QVBoxLayout() 45 | layout.addWidget(self.toolbar) 46 | layout.addWidget(self.canvas) 47 | 48 | self.setLayout(layout) 49 | 50 | def onPick(self,event): 51 | print 'Pick event' 52 | print 'you pressed', event.button, event.xdata, event.ydata 53 | 54 | def update(self): 55 | self.canvas.draw() 56 | 57 | def plot(self,*args,**kwargs): 58 | self.axes.plot(*args,**kwargs) 59 | self.axes.grid(self.grid) 60 | self.update() 61 | 62 | def clear(self): 63 | self.axes.clear() 64 | 65 | def initFigure(self): 66 | self.axes.grid(True) 67 | x = np.linspace(-1,1) 68 | y = x**2 69 | self.axes.plot(x,y,'o-') 70 | 71 | 72 | class PlotWindow(QMainWindow): 73 | ''' a stand-alone window with embedded matplotlib widget ''' 74 | def __init__(self,parent=None): 75 | super(PlotWindow,self).__init__(parent) 76 | self.setAttribute(Qt.WA_DeleteOnClose) 77 | self.mplWidget = MatplotlibWidget() 78 | self.setCentralWidget(self.mplWidget) 79 | 80 | def plot(self,dataFrame): 81 | ''' plot dataframe ''' 82 | dataFrame.plot(ax=self.mplWidget.axes) 83 | 84 | def getAxes(self): 85 | return self.mplWidget.axes 86 | 87 | def getFigure(self): 88 | return self.mplWidget.fig 89 | 90 | def update(self): 91 | self.mplWidget.update() 92 | 93 | class MainForm(QMainWindow): 94 | def __init__(self, parent=None): 95 | QMainWindow.__init__(self, parent) 96 | self.setWindowTitle('Demo: PyQt with matplotlib') 97 | 98 | self.plot = MatplotlibWidget() 99 | self.setCentralWidget(self.plot) 100 | 101 | self.plot.clear() 102 | self.plot.plot(np.random.rand(10),'x-') 103 | 104 | 105 | #--------------------- 106 | if __name__=='__main__': 107 | app = QApplication(sys.argv) 108 | 
form = MainForm() 109 | form.show() 110 | app.exec_() -------------------------------------------------------------------------------- /lib/yahooFinance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Author: Jev Kuznetsov 4 | # License: BSD 5 | 6 | 7 | """ 8 | Toolset working with yahoo finance data 9 | 10 | This module includes functions for easy access to YahooFinance data 11 | 12 | Functions 13 | ---------- 14 | - `getHistoricData` get historic data for a single symbol 15 | - `getQuote` get current quote for a symbol 16 | - `getScreenerSymbols` load symbols from a yahoo stock screener file 17 | 18 | Classes 19 | --------- 20 | - `HistData` a class for working with multiple symbols 21 | 22 | 23 | 24 | """ 25 | 26 | 27 | from datetime import datetime, date 28 | import urllib2 29 | from pandas import DataFrame, Index, HDFStore, WidePanel 30 | import numpy as np 31 | import os 32 | from extra import ProgressBar 33 | 34 | class HistData(object): 35 | ''' a class for working with yahoo finance data ''' 36 | def __init__(self, autoAdjust=True): 37 | 38 | self.startDate = (2008,1,1) 39 | self.autoAdjust=autoAdjust 40 | self.wp = WidePanel() 41 | 42 | 43 | def load(self,dataFile): 44 | """load data from HDF""" 45 | if os.path.exists(dataFile): 46 | store = HDFStore(dataFile) 47 | symbols = store.keys() 48 | data = dict(zip(symbols,[store[symbol] for symbol in symbols])) 49 | self.wp = WidePanel(data) 50 | store.close() 51 | else: 52 | raise IOError('Data file does not exist') 53 | 54 | 55 | def save(self,dataFile): 56 | """ save data to HDF""" 57 | print 'Saving data to', dataFile 58 | store = HDFStore(dataFile) 59 | for symbol in self.wp.items: 60 | store[symbol] = self.wp[symbol] 61 | 62 | store.close() 63 | 64 | 65 | 66 | def downloadData(self,symbols='all'): 67 | ''' get data from yahoo ''' 68 | 69 | if symbols == 'all': 70 | symbols = self.symbols 71 | 72 | #store = HDFStore(self.dataFile) 
73 | p = ProgressBar(len(symbols)) 74 | 75 | for idx,symbol in enumerate(symbols): 76 | 77 | try: 78 | df = getHistoricData(symbol,self.startDate,verbose=False) 79 | if self.autoAdjust: 80 | df = _adjust(df,removeOrig=True) 81 | 82 | if len(self.symbols)==0: 83 | self.wp = WidePanel({symbol:df}) 84 | else: 85 | self.wp[symbol] = df 86 | 87 | except Exception,e: 88 | print e 89 | p.animate(idx+1) 90 | 91 | def getDataFrame(self,field='close'): 92 | ''' return a slice on wide panel for a given field ''' 93 | return self.wp.minor_xs(field) 94 | 95 | 96 | @property 97 | def symbols(self): 98 | return self.wp.items.tolist() 99 | 100 | 101 | def __repr__(self): 102 | return str(self.wp) 103 | 104 | 105 | def getQuote(symbols): 106 | ''' get current yahoo quote 107 | 108 | 109 | , return a DataFrame ''' 110 | 111 | if not isinstance(symbols,list): 112 | symbols = [symbols] 113 | # for codes see: http://www.gummy-stuff.org/Yahoo-data.htm 114 | codes = {'symbol':'s','last':'l1','change_pct':'p2','PE':'r','time':'t1','short_ratio':'s7'} 115 | request = str.join('',codes.values()) 116 | header = codes.keys() 117 | 118 | data = dict(zip(codes.keys(),[[] for i in range(len(codes))])) 119 | 120 | urlStr = 'http://finance.yahoo.com/d/quotes.csv?s=%s&f=%s' % (str.join('+',symbols), request) 121 | 122 | try: 123 | lines = urllib2.urlopen(urlStr).readlines() 124 | except Exception, e: 125 | s = "Failed to download:\n{0}".format(e); 126 | print s 127 | 128 | for line in lines: 129 | fields = line.strip().split(',') 130 | #print fields 131 | for i,field in enumerate(fields): 132 | if field[0] == '"': 133 | data[header[i]].append( field.strip('"')) 134 | else: 135 | try: 136 | data[header[i]].append(float(field)) 137 | except ValueError: 138 | data[header[i]].append(np.nan) 139 | 140 | idx = data.pop('symbol') 141 | 142 | return DataFrame(data,index=idx) 143 | 144 | def _historicDataUrll(symbol, sDate=(1990,1,1),eDate=date.today().timetuple()[0:3]): 145 | """ 146 | generate url 147 | 
148 | symbol: Yahoo finanance symbol 149 | sDate: start date (y,m,d) 150 | eDate: end date (y,m,d) 151 | """ 152 | 153 | urlStr = 'http://ichart.finance.yahoo.com/table.csv?s={0}&a={1}&b={2}&c={3}&d={4}&e={5}&f={6}'.\ 154 | format(symbol.upper(),sDate[1]-1,sDate[2],sDate[0],eDate[1]-1,eDate[2],eDate[0]) 155 | 156 | return urlStr 157 | 158 | def getHistoricData(symbol, sDate=(1990,1,1),eDate=date.today().timetuple()[0:3],verbose=True): 159 | """ 160 | get data from Yahoo finance and return pandas dataframe 161 | 162 | symbol: Yahoo finanance symbol 163 | sDate: start date (y,m,d) 164 | eDate: end date (y,m,d) 165 | """ 166 | 167 | urlStr = 'http://ichart.finance.yahoo.com/table.csv?s={0}&a={1}&b={2}&c={3}&d={4}&e={5}&f={6}'.\ 168 | format(symbol.upper(),sDate[1]-1,sDate[2],sDate[0],eDate[1]-1,eDate[2],eDate[0]) 169 | 170 | 171 | try: 172 | lines = urllib2.urlopen(urlStr).readlines() 173 | except Exception, e: 174 | s = "Failed to download:\n{0}".format(e); 175 | print s 176 | 177 | dates = [] 178 | data = [[] for i in range(6)] 179 | #high 180 | 181 | # header : Date,Open,High,Low,Close,Volume,Adj Close 182 | for line in lines[1:]: 183 | #print line 184 | fields = line.rstrip().split(',') 185 | dates.append(datetime.strptime( fields[0],'%Y-%m-%d')) 186 | for i,field in enumerate(fields[1:]): 187 | data[i].append(float(field)) 188 | 189 | idx = Index(dates) 190 | data = dict(zip(['open','high','low','close','volume','adj_close'],data)) 191 | 192 | # create a pandas dataframe structure 193 | df = DataFrame(data,index=idx).sort() 194 | 195 | if verbose: 196 | print 'Got %i days of data' % len(df) 197 | 198 | return df 199 | 200 | def _adjust(df, removeOrig=False): 201 | ''' 202 | _adjustust hist data based on adj_close field 203 | ''' 204 | c = df['close']/df['adj_close'] 205 | 206 | df['adj_open'] = df['open']/c 207 | df['adj_high'] = df['high']/c 208 | df['adj_low'] = df['low']/c 209 | 210 | if removeOrig: 211 | df=df.drop(['open','close','high','low'],axis=1) 212 | 
renames = dict(zip(['adj_open','adj_close','adj_high','adj_low'],['open','close','high','low'])) 213 | df=df.rename(columns=renames) 214 | 215 | return df 216 | 217 | def getScreenerSymbols(fileName): 218 | ''' read symbols from a .csv saved by yahoo stock screener ''' 219 | 220 | with open(fileName,'r') as fid: 221 | lines = fid.readlines() 222 | 223 | symbols = [] 224 | for line in lines[3:]: 225 | fields = line.strip().split(',') 226 | field = fields[0].strip() 227 | if len(field) > 0: 228 | symbols.append(field) 229 | return symbols 230 | 231 | -------------------------------------------------------------------------------- /nautilus/nautilus.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 26 dec. 2011 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | ''' 6 | 7 | 8 | 9 | from PyQt4.QtCore import * 10 | from PyQt4.QtGui import * 11 | 12 | from ib.ext.Contract import Contract 13 | from ib.opt import ibConnection 14 | from ib.ext.Order import Order 15 | 16 | import tradingWithPython.lib.logger as logger 17 | from tradingWithPython.lib.eventSystem import Sender, ExampleListener 18 | import tradingWithPython.lib.qtpandas as qtpandas 19 | import numpy as np 20 | 21 | import pandas 22 | 23 | 24 | priceTicks = {1:'bid',2:'ask',4:'last',6:'high',7:'low',9:'close', 14:'open'} 25 | 26 | 27 | class PriceListener(qtpandas.DataFrameModel): 28 | def __init__(self): 29 | super(PriceListener,self).__init__() 30 | self._header = ['position','bid','ask','last'] 31 | 32 | def addSymbol(self,symbol): 33 | data = dict(zip(self._header,[0,np.nan,np.nan,np.nan])) 34 | row = pandas.DataFrame(data, index = pandas.Index([symbol])) 35 | self.df = self.df.append(row[self._header]) # append data and set correct column order 36 | 37 | 38 | def priceHandler(self,sender,event,msg=None): 39 | 40 | if msg['symbol'] not in self.df.index: 41 | self.addSymbol(msg['symbol']) 42 | 43 | if msg['type'] in self._header: 44 | 
self.df.ix[msg['symbol'],msg['type']] = msg['price'] 45 | self.signalUpdate() 46 | #print self.df 47 | 48 | 49 | 50 | class Broker(Sender): 51 | def __init__(self, name = "broker"): 52 | super(Broker,self).__init__() 53 | 54 | self.name = name 55 | self.log = logger.getLogger(self.name) 56 | 57 | self.log.debug('Initializing broker. Pandas version={0}'.format(pandas.__version__)) 58 | self.contracts = {} # a dict to keep track of subscribed contracts 59 | self._id2symbol = {} # id-> symbol dict 60 | self.tws = None 61 | self._nextId = 1 # tws subscription id 62 | self.nextValidOrderId = None 63 | 64 | 65 | 66 | def connect(self): 67 | """ connect to tws """ 68 | self.tws = ibConnection() # tws interface 69 | self.tws.registerAll(self._defaultHandler) 70 | self.tws.register(self._nextValidIdHandler,'NextValidId') 71 | self.log.debug('Connecting to tws') 72 | self.tws.connect() 73 | 74 | self.tws.reqAccountUpdates(True,'') 75 | self.tws.register(self._priceHandler,'TickPrice') 76 | 77 | def subscribeStk(self,symbol, secType='STK', exchange='SMART',currency='USD'): 78 | ''' subscribe to stock data ''' 79 | self.log.debug('Subscribing to '+symbol) 80 | c = Contract() 81 | c.m_symbol = symbol 82 | c.m_secType = secType 83 | c.m_exchange = exchange 84 | c.m_currency = currency 85 | 86 | subId = self._nextId 87 | self._nextId += 1 88 | 89 | self.tws.reqMktData(subId,c,'',False) 90 | self._id2symbol[subId] = c.m_symbol 91 | self.contracts[symbol]=c 92 | 93 | 94 | def disconnect(self): 95 | self.tws.disconnect() 96 | #------event handlers-------------------- 97 | 98 | def _defaultHandler(self,msg): 99 | ''' default message handler ''' 100 | #print msg.typeName 101 | if msg.typeName == 'Error': 102 | self.log.error(msg) 103 | 104 | 105 | def _nextValidIdHandler(self,msg): 106 | self.nextValidOrderId = msg.orderId 107 | self.log.debug( 'Next valid order id:{0}'.format(self.nextValidOrderId)) 108 | 109 | def _priceHandler(self,msg): 110 | #translate to meaningful messages 111 
| message = {'symbol':self._id2symbol[msg.tickerId], 112 | 'price':msg.price, 113 | 'type':priceTicks[msg.field]} 114 | self.dispatch('price',message) 115 | 116 | 117 | #-----------------GUI elements------------------------- 118 | 119 | class TableView(QTableView): 120 | """ extended table view """ 121 | def __init__(self,name='TableView1', parent=None): 122 | super(TableView,self).__init__(parent) 123 | self.name = name 124 | self.setSelectionBehavior(QAbstractItemView.SelectRows) 125 | 126 | def contextMenuEvent(self, event): 127 | menu = QMenu(self) 128 | 129 | Action = menu.addAction("print selected rows") 130 | Action.triggered.connect(self.printName) 131 | 132 | menu.exec_(event.globalPos()) 133 | 134 | def printName(self): 135 | print "Action triggered from " + self.name 136 | 137 | print 'Selected :' 138 | for idx in self.selectionModel().selectedRows(): 139 | print self.model().df.ix[idx.row(),:] 140 | 141 | 142 | 143 | class Form(QDialog): 144 | def __init__(self,parent=None): 145 | super(Form,self).__init__(parent) 146 | 147 | self.broker = Broker() 148 | self.price = PriceListener() 149 | 150 | self.broker.connect() 151 | 152 | symbols = ['SPY','XLE','QQQ','VXX','XIV'] 153 | for symbol in symbols: 154 | self.broker.subscribeStk(symbol) 155 | 156 | self.broker.register(self.price.priceHandler, 'price') 157 | 158 | 159 | widget = TableView(parent=self) 160 | widget.setModel(self.price) 161 | widget.horizontalHeader().setResizeMode(QHeaderView.Stretch) 162 | 163 | layout = QVBoxLayout() 164 | layout.addWidget(widget) 165 | self.setLayout(layout) 166 | 167 | def __del__(self): 168 | print 'Disconnecting.' 169 | self.broker.disconnect() 170 | 171 | if __name__=="__main__": 172 | print "Running nautilus" 173 | 174 | import sys 175 | app = QApplication(sys.argv) 176 | form = Form() 177 | form.show() 178 | app.exec_() 179 | print "All done." 
-------------------------------------------------------------------------------- /sandbox/dataModels.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | """ 5 | 6 | 7 | 8 | import sqlite3 as lite 9 | import sys 10 | from PyQt4.QtCore import (QAbstractTableModel,Qt,QVariant,QModelIndex, SIGNAL) 11 | from PyQt4.QtGui import (QApplication,QDialog,QVBoxLayout, QTableView, QWidget, QHeaderView) 12 | 13 | 14 | #import sys, os 15 | 16 | def initDb(dbName): 17 | '''reset database ''' 18 | print 'Resetting ' , dbName 19 | con = lite.connect(dbName) 20 | cur = con.cursor() 21 | 22 | cur.execute("DROP TABLE IF EXISTS tbl_symbols") 23 | cur.execute("""CREATE TABLE tbl_symbols ( 24 | id INTEGER PRIMARY KEY AUTOINCREMENT, 25 | symbol TEXT, 26 | secType TEXT DEFAULT 'STK', 27 | currency TEXT DEFAULT 'USD', 28 | exchange TEXT DEFAULT 'SMART', 29 | active BOOLEAN DEFAULT 1)""") 30 | 31 | 32 | symbols = ('AAA','BBB','CCC') 33 | for symbol in symbols: 34 | cur.execute("INSERT INTO tbl_symbols (symbol) VALUES(?) 
",(symbol,)) 35 | 36 | con.commit() 37 | 38 | 39 | 40 | class SqliteTableModel(QAbstractTableModel): 41 | ''' base class for interfacing to sqlite db''' 42 | def __init__(self,dbConnection,tableName,parent=None): 43 | super(SqliteTableModel,self).__init__(parent) 44 | self._con = dbConnection 45 | self._table = tableName 46 | self._data = [] 47 | self._header = [] 48 | self._reload() 49 | 50 | 51 | 52 | def _reload(self): 53 | ' reload all data' 54 | q = "SELECT * FROM %s " % self._table 55 | cur = self._con.cursor() 56 | cur.execute(q) 57 | 58 | self._data = [] 59 | for row in cur: 60 | self._data.append([]) 61 | curr = self._data[-1] 62 | for elm in row: 63 | curr.append(elm) 64 | 65 | cur.execute("PRAGMA table_info(%s)" % self._table) 66 | for c in cur: 67 | self._header.append(c[1]) 68 | 69 | def __repr__(self): 70 | return str(self._header)+'\n'+str.join('\n',[ str(row) for row in self._data]) 71 | 72 | 73 | 74 | 75 | class Symbols(QAbstractTableModel): 76 | ''' class for managing a group of spreads through sqlite ''' 77 | def __init__(self,dbConnection,tableName): 78 | 79 | self.tblName = tableName # name of the database table 80 | self.con = dbConnection 81 | 82 | self.cur = self.con.cursor() 83 | self.data = [] 84 | 85 | def sql(self,query): 86 | cur = self.con.cursor() 87 | cur.execute(query) 88 | return cur.fetchall() 89 | 90 | 91 | def initDb(self): 92 | self.cur.execute("DROP TABLE IF EXISTS tbl_symbols") 93 | self.cur.execute("""CREATE TABLE tbl_symbols ( 94 | id INTEGER PRIMARY KEY AUTOINCREMENT, 95 | symbol TEXT, 96 | secType TEXT DEFAULT 'STK', 97 | currency TEXT DEFAULT 'USD', 98 | exchange TEXT DEFAULT 'SMART', 99 | active BOOLEAN DEFAULT 1)""") 100 | 101 | 102 | self.con.commit() 103 | 104 | 105 | def addSymbol(self, symbol): 106 | t = (symbol,) 107 | self.cur.execute("INSERT INTO tbl_symbols (symbol) VALUES(?) 
",t) 108 | 109 | 110 | def load(self): 111 | ''' reload full table from db ''' 112 | q = "SELECT * FROM tbl_symbols " 113 | self.cur.execute(q) 114 | 115 | self.data = [] 116 | for row in self.cur: 117 | self.data.append([]) 118 | curr = self.data[-1] 119 | for elm in row: 120 | curr.append(elm) 121 | 122 | 123 | 124 | def printTable(self): 125 | self.load() 126 | print '-'*10 127 | print self.data 128 | 129 | def _testFcn(self): 130 | self.sql("insert into tbl_symbols ") 131 | 132 | def showTables(self): 133 | self.cur.execute("select name from sqlite_master where type='table' ") 134 | res = self.cur.fetchall() 135 | for row in res: 136 | print row[0] 137 | 138 | def __del__(self): 139 | self.con.close() 140 | 141 | 142 | #----------------test code 143 | class Form(QDialog): 144 | def __init__(self, parent=None): 145 | super(Form, self).__init__(parent) 146 | self.resize(640,480) 147 | self.setWindowTitle('Model test') 148 | 149 | model = SqliteTableModel(con,'tbl_symbols') 150 | table = QTableView() 151 | table.setModel(model) 152 | 153 | lay = QVBoxLayout() 154 | lay.addWidget(table) 155 | self.setLayout(lay) 156 | 157 | def startGui(): 158 | app = QApplication(sys.argv) 159 | form = Form() 160 | form.show() 161 | app.exec_() 162 | 163 | def testModel(): 164 | ' simple model test, without gui' 165 | m = SqliteTableModel(con,'tbl_symbols') 166 | print m 167 | 168 | if __name__=='__main__': 169 | 170 | dbName = 'testDb' 171 | 172 | #initDb(dbName) 173 | con = lite.connect(dbName) 174 | cur = con.cursor() 175 | cur.execute("select name from sqlite_master where type='table' ") 176 | 177 | for row in cur: 178 | print row[0] 179 | 180 | 181 | testModel() 182 | #startGui() 183 | 184 | 185 | -------------------------------------------------------------------------------- /sandbox/guiWithDatabase.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import division 3 | from __future__ import print_function 4 | from 
__future__ import unicode_literals 5 | from future_builtins import * 6 | 7 | import os 8 | import sys 9 | from PyQt4.QtCore import * 10 | from PyQt4.QtGui import * 11 | from PyQt4.QtSql import (QSqlDatabase, QSqlQuery, QSqlRelation, 12 | QSqlRelationalDelegate, QSqlRelationalTableModel, QSqlTableModel) 13 | 14 | 15 | 16 | import sqlite3 as db 17 | 18 | dbName = 'test.db' 19 | 20 | 21 | def initDb(): 22 | '''reset database ''' 23 | con = db.connect(dbName) 24 | cur = con.cursor() 25 | 26 | cur.execute("DROP TABLE IF EXISTS tbl_symbols") 27 | cur.execute("""CREATE TABLE tbl_symbols ( 28 | id INTEGER PRIMARY KEY AUTOINCREMENT, 29 | symbol TEXT, 30 | secType TEXT DEFAULT 'STK', 31 | currency TEXT DEFAULT 'USD', 32 | exchange TEXT DEFAULT 'SMART', 33 | active BOOLEAN DEFAULT 1)""") 34 | 35 | 36 | symbols = ('AAA','BBB','CCC') 37 | for symbol in symbols: 38 | cur.execute("INSERT INTO tbl_symbols (symbol) VALUES(?) ",(symbol,)) 39 | 40 | con.commit() 41 | 42 | 43 | #----------------- 44 | class MainForm(QDialog): 45 | def __init__(self): 46 | super(MainForm, self).__init__() 47 | 48 | self.model = QSqlTableModel(self) 49 | self.model.setTable('tbl_symbols') 50 | self.model.select() 51 | 52 | self.view = QTableView() 53 | self.view.setModel(self.model) 54 | self.view.horizontalHeader().setResizeMode(QHeaderView.Stretch) 55 | 56 | addButton = QPushButton("&Add") 57 | deleteButton = QPushButton("&Delete") 58 | 59 | buttonLayout = QVBoxLayout() 60 | buttonLayout.addWidget(addButton) 61 | buttonLayout.addWidget(deleteButton) 62 | buttonLayout.addStretch() 63 | 64 | lay = QHBoxLayout() 65 | lay.addWidget(self.view) 66 | lay.addLayout(buttonLayout) 67 | self.setLayout(lay) 68 | 69 | self.connect(addButton, SIGNAL("clicked()"), self.addRecord) 70 | self.connect(deleteButton, SIGNAL("clicked()"), self.deleteRecord) 71 | 72 | 73 | def addRecord(self): 74 | row = self.model.rowCount() 75 | self.model.insertRow(row) 76 | index = self.model.index(row, 1) 77 | 
self.view.setCurrentIndex(index) 78 | self.view.edit(index) 79 | 80 | def deleteRecord(self): 81 | index = self.view.currentIndex() 82 | if not index.isValid(): 83 | return 84 | #QSqlDatabase.database().transaction() 85 | record = self.model.record(index.row()) 86 | self.model.removeRow(index.row()) 87 | self.model.submitAll() 88 | 89 | def main(): 90 | app = QApplication(sys.argv) 91 | db = QSqlDatabase.addDatabase("QSQLITE") 92 | db.setDatabaseName(dbName) 93 | if not db.open(): 94 | QMessageBox.warning(None, "Asset Manager", 95 | QString("Database Error: %1") 96 | .arg(db.lastError().text())) 97 | sys.exit(1) 98 | 99 | form = MainForm() 100 | form.show() 101 | app.exec_() 102 | del form 103 | del db 104 | #initDb() 105 | main() -------------------------------------------------------------------------------- /sandbox/spreadCalculations.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 28 okt 2011 3 | 4 | @author: jev 5 | ''' 6 | 7 | from tradingWithPython import estimateBeta, Spread, returns, Portfolio, readBiggerScreener 8 | from tradingWithPython.lib import yahooFinance 9 | from pandas import DataFrame, Series 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import os 13 | 14 | 15 | 16 | symbols = ['SPY','IWM'] 17 | y = yahooFinance.HistData('temp.csv') 18 | y.startDate = (2007,1,1) 19 | 20 | df = y.loadSymbols(symbols,forceDownload=False) 21 | #df = y.downloadData(symbols) 22 | 23 | res = readBiggerScreener('CointPairs.csv') 24 | 25 | #---check with spread scanner 26 | #sp = DataFrame(index=symbols) 27 | # 28 | #sp['last'] = df.ix[-1,:] 29 | #sp['targetCapital'] = Series({'SPY':100,'IWM':-100}) 30 | #sp['targetShares'] = sp['targetCapital']/sp['last'] 31 | #print sp 32 | 33 | #The dollar-neutral ratio is about 1 * IWM - 1.7 * IWM. 
You will get the spread = zero (or probably very near zero) 34 | 35 | 36 | #s = Spread(symbols, histClose = df) 37 | #print s 38 | 39 | #s.value.plot() 40 | 41 | #print 'beta (returns)', estimateBeta(df[symbols[0]],df[symbols[1]],algo='returns') 42 | #print 'beta (log)', estimateBeta(df[symbols[0]],df[symbols[1]],algo='log') 43 | #print 'beta (standard)', estimateBeta(df[symbols[0]],df[symbols[1]],algo='standard') 44 | 45 | #p = Portfolio(df) 46 | #p.setShares([1, -1.7]) 47 | #p.value.plot() 48 | 49 | 50 | quote = yahooFinance.getQuote(symbols) 51 | print quote 52 | 53 | 54 | s = Spread(symbols,histClose=df, estimateBeta = False) 55 | s.setLast(quote['last']) 56 | 57 | s.setShares(Series({'SPY':1,'IWM':-1.7})) 58 | print s 59 | #s.value.plot() 60 | #s.plot() 61 | fig = figure(2) 62 | s.plot() 63 | 64 | 65 | -------------------------------------------------------------------------------- /sandbox/spreadGroup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Dec 09 18:41:08 2011 4 | 5 | @author: jev 6 | """ 7 | 8 | 9 | 10 | import sqlite3 as db 11 | #import sys, os 12 | 13 | 14 | class Symbols(object): 15 | ''' class for managing a group of spreads through sqlite ''' 16 | def __init__(self,fName='spreads.db'): 17 | self.con = db.connect(fName) 18 | self.cur = self.con.cursor() 19 | 20 | def sql(self,query): 21 | cur = self.con.cursor() 22 | cur.execute(query) 23 | return cur.fetchall() 24 | 25 | 26 | def initDb(self): 27 | self.cur.execute("DROP TABLE IF EXISTS tbl_symbols") 28 | self.cur.execute("""CREATE TABLE tbl_symbols ( 29 | id INTEGER PRIMARY KEY AUTOINCREMENT, 30 | symbol TEXT, 31 | secType TEXT DEFAULT 'STK', 32 | currency TEXT DEFAULT 'USD', 33 | exchange TEXT DEFAULT 'SMART', 34 | active BOOLEAN DEFAULT 1)""") 35 | 36 | 37 | self.con.commit() 38 | 39 | 40 | def addSymbol(self, symbol): 41 | t = (symbol,) 42 | self.cur.execute("INSERT INTO tbl_symbols (symbol) 
VALUES(?) ",t) 43 | 44 | def printTable(self,table): 45 | 46 | q = "SELECT * FROM "+table # insecure, but ? does not work here 47 | self.cur.execute(q) 48 | print '-'*10+table+"-"*10 49 | for row in self.cur: 50 | print row 51 | 52 | def _testFcn(self): 53 | self.sql("insert into tbl_symbols ") 54 | 55 | def showTables(self): 56 | self.cur.execute("select name from sqlite_master where type='table' ") 57 | res = self.cur.fetchall() 58 | for row in res: 59 | print row[0] 60 | 61 | def __del__(self): 62 | self.con.close() 63 | 64 | 65 | 66 | 67 | 68 | if __name__=='__main__': 69 | g = Symbols() 70 | g.initDb() 71 | g.showTables() 72 | g.addSymbol('SPY') 73 | g.addSymbol('XYZ') 74 | g.printTable('tbl_symbols') 75 | -------------------------------------------------------------------------------- /spreadApp/gold_stocks.csv: -------------------------------------------------------------------------------- 1 | Screener Results: 2 | 3 | Ticker,Company Name,Last Trade,Trade Time,Mkt Cap,Return On Equity,Return On Assets,Forward PE 4 | AU , AngloGold Ashanti , 44,25 , 1:30pm , 85,484B , 26,85 , 10,694 , 7,68 5 | ABX , Barrick Gold Corp , 47,66 , 1:30pm , 47,650B , 19,136 , 9,884 , 8,29 6 | GG , Goldcorp Incorpor , 47,82 , 1:30pm , 38,721B , 9,357 , 5,266 , 15,92 7 | NEM , Newmont Mining Co , 64,77 , 1:30pm , 32,049B , 19,569 , 9,911 , 10,96 8 | KGC , Kinross Gold Corp , 12,62 , 1:30pm , 14,356B , 5,149 , 4,694 , 10,44 9 | AUY , Yamana Gold, Inc. , 15,43 , 1:30pm , 11,506B , 7,85 , 5,275 , 12,55 10 | GFI , Gold Fields Ltd. , 15,67 , 1:30pm , 11,332B , 0 , 0 , 0 11 | BVN , Compania Mina Bue , 37,39 , 1:30pm , 9,514B , 32,741 , 13,33 , 9,11 12 | EGO , Eldorado Gold Cor , 15,73 , 1:30pm , 8,671B , 9,3 , 7,446 , 15,21 13 | AEM , Agnico-Eagle Mine , 41,33 , 1:30pm , 7,000B , 3,283 , 2,603 , 13,05 14 | IAG , Iamgold Corporati , 18,23 , 1:30pm , 6,852B , 15,331 , 10,41 , 12,73 15 | HMY , Harmony Gold Mini , 12,85 , 1:30pm , 5,527B , 0 , 0 , 11,57 16 | AUQ , AuRico Gold Inc. 
, 8,92 , 1:30pm , 2,511B , 13,228 , 9,076 , 7,44 17 | JAG , Jaguar Mining Inc , 6,26 , 1:30pm , 528,4M , -28,931 , 0,948 , 0 18 | UXG , U.S. Gold Corpora , 3,45 , 1:30pm , 482,1M , -25,537 , -10,547 , 0 19 | 20 | 21 | -------------------------------------------------------------------------------- /spreadApp/makeDist.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | import py2exe 3 | 4 | manifest_template = ''' 5 | 6 | 7 | 13 | %(prog)s Program 14 | 15 | 16 | 24 | 25 | 26 | 27 | ''' 28 | 29 | RT_MANIFEST = 24 30 | import matplotlib 31 | 32 | 33 | opts = { 34 | 'py2exe': { 35 | "compressed": 1, 36 | "bundle_files" : 3, 37 | "includes" : ["sip", 38 | "matplotlib.backends", 39 | "matplotlib.backends.backend_qt4agg", 40 | "pylab", "numpy", 41 | "matplotlib.backends.backend_tkagg"], 42 | 'excludes': ['_gtkagg', '_tkagg', '_agg2', 43 | '_cairo', '_cocoaagg', 44 | '_fltkagg', '_gtk', '_gtkcairo', ], 45 | 'dll_excludes': ['libgdk-win32-2.0-0.dll', 46 | 'libgobject-2.0-0.dll'] 47 | 48 | } 49 | } 50 | 51 | 52 | 53 | setup(name="triton", 54 | version = "0.1", 55 | scripts=["spreadScanner.pyw"], 56 | windows=[{"script": "spreadScanner.pyw"}], 57 | options=opts, 58 | data_files=matplotlib.get_py2exe_datafiles(), 59 | other_resources = [(RT_MANIFEST, 1, manifest_template % dict(prog="spreadDetective"))], 60 | zipfile = None) -------------------------------------------------------------------------------- /spreadApp/spreadScanner.pyw: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 12 dec. 
2011 3 | Copyright: Jev Kuznetsov 4 | License: BSD 5 | ''' 6 | 7 | import sys, os 8 | 9 | __version__ = "0.1.0" 10 | 11 | from PyQt4.QtCore import (Qt, SIGNAL) 12 | from PyQt4.QtGui import * 13 | import platform 14 | 15 | import widgets.ui_symbolChooser 16 | from tradingWithPython.lib.yahooFinance import getScreenerSymbols 17 | import qrc_resources 18 | from tradingWithPython import readBiggerScreener 19 | from tradingWithPython.lib.qtpandas import DataFrameWidget, DataFrameModel 20 | from tradingWithPython.lib.widgets import PlotWindow 21 | from tradingWithPython.lib.classes import Spread 22 | from tradingWithPython.lib.functions import returns 23 | 24 | from widgets.spread import SpreadWidget 25 | 26 | import numpy as np 27 | import matplotlib.pyplot as plt 28 | from pandas import DataFrame,Index,Series 29 | 30 | #---------globals 31 | dataFile = 'yahooData.csv' 32 | dataStartDate = (2010,1,1) 33 | #--------classes 34 | 35 | class SymbolChooser(QWidget,widgets.ui_symbolChooser.Ui_Form): 36 | ''' symbol chooser widget ''' 37 | def __init__(self,parent=None): 38 | super(SymbolChooser,self).__init__(parent) 39 | self.setupUi(self) 40 | 41 | def symbols(self): 42 | symbols = [] 43 | for i in range(self.listSymbols.count()): 44 | symbols.append(str(self.listSymbols.item(i).text())) 45 | return symbols 46 | 47 | 48 | 49 | 50 | class SpreadViewModel(DataFrameModel): 51 | """ modified version of the model to hack around sorting issue""" 52 | def __init__(self,parent=None): 53 | super(SpreadViewModel,self).__init__(parent=None) 54 | 55 | def sort(self,nCol,order): 56 | 57 | self.layoutAboutToBeChanged.emit() 58 | 59 | col = self.df[self.df.columns[nCol]].values.tolist() 60 | 61 | #create indexed list, sort, rebuild complete data frame 8( 62 | di = [(d,i) for i,d in enumerate(col)] 63 | 64 | if order == Qt.AscendingOrder: 65 | idx = [i for d,i in sorted(di)] 66 | elif order == Qt.DescendingOrder: 67 | idx = [i for d,i in sorted(di,reverse=True)] 68 | 69 | data = 
self.df.values[idx,:] 70 | cols = self.df.columns 71 | # rebuild the whole thing 72 | self.df = DataFrame(data=data, columns=cols, index = Index(range(len(idx)))) 73 | 74 | self.layoutChanged.emit() 75 | 76 | 77 | 78 | 79 | class SpreadView(QTableView): 80 | """ extended table view """ 81 | def __init__(self,name='TableView1', parent=None): 82 | super(SpreadView,self).__init__(parent) 83 | self.name = name 84 | self.setSelectionBehavior(QAbstractItemView.SelectRows) 85 | 86 | 87 | def contextMenuEvent(self, event): 88 | menu = QMenu(self) 89 | 90 | Action = menu.addAction("Show spread") 91 | Action.triggered.connect(self.showSpread) 92 | 93 | menu.exec_(event.globalPos()) 94 | 95 | def showSpread(self): 96 | """ open a spread window """ 97 | for idx in self.selectionModel().selectedRows(): 98 | row = self.selectionModel().model().df.ix[idx.row(),:] 99 | symbols = [row['StockA'],row['StockB']] 100 | spread = Spread(symbols) 101 | spread.setShares(Series({row['StockA']:1,row['StockB']:-row['Price Ratio']})) 102 | spreadWindow = SpreadWindow(self) 103 | spreadWindow.setSpread(spread) 104 | 105 | spreadWindow.show() 106 | 107 | 108 | class SpreadWindow(QMainWindow): 109 | def __init__(self,parent=None): 110 | super(SpreadWindow,self).__init__(parent) 111 | self.resize(640,600) 112 | self.setWindowTitle('Spread test') 113 | 114 | self.widget = SpreadWidget(self) 115 | self.setCentralWidget(self.widget) 116 | 117 | self.spread = None 118 | 119 | def setSpread(self,spread): 120 | 121 | self.spread = spread 122 | self.setWindowTitle(spread.name) 123 | self.widget.setSpread(self.spread) 124 | 125 | class BiggerSpreads(QWidget): 126 | """ class for working with spreads from screener """ 127 | def __init__(self, parent=None): 128 | super(QWidget,self).__init__(parent) 129 | self.name = 'bigger spreads' 130 | 131 | self.df = DataFrame() # main data container 132 | self.dataModel = SpreadViewModel() 133 | self.dataModel.setDataFrame(self.df) 134 | 135 | self.dataTable = 
SpreadView() 136 | self.dataTable.setSortingEnabled(True) 137 | 138 | self.dataTable.setModel(self.dataModel) 139 | self.dataModel.signalUpdate() 140 | 141 | 142 | layout = QVBoxLayout() 143 | layout.addWidget(self.dataTable) 144 | self.setLayout(layout) 145 | 146 | def loadSpreads(self,fName): 147 | self.df = readBiggerScreener(fName) 148 | self.dataModel.setDataFrame(self.df) 149 | #self.dataTable.resizeColumnsToContents() 150 | self.dataTable.horizontalHeader().setResizeMode(QHeaderView.Stretch) 151 | 152 | 153 | class MainWindow(QMainWindow): 154 | def __init__(self, parent=None): 155 | super(MainWindow, self).__init__(parent) 156 | 157 | self.setWindowTitle('Spread Detective [alpha]') 158 | 159 | # general vars 160 | self.filename = None 161 | self.actions = {} # actions list 162 | 163 | #fill central area 164 | self.dataTable = BiggerSpreads() 165 | self.setCentralWidget(self.dataTable) 166 | 167 | #create actions 168 | self.actions['loadScreener'] = self.createAction("Load symbols",self.loadScreenerSymbols,icon="fileopen") 169 | self.actions['helpAbout'] = self.createAction("About",self.helpAbout) 170 | #self.actions['test'] = self.createAction("Test",self._testFcn) 171 | 172 | #set app menu 173 | self.createMenu() 174 | self.createToolbars() 175 | 176 | #quick init 177 | self._quickInit() 178 | self.resize(800,600) 179 | 180 | def _quickInit(self): 181 | testFile = 'CointPairs_test.csv' 182 | if os.path.exists(testFile): 183 | self.dataTable.loadSpreads(testFile) 184 | 185 | def createMenu(self): 186 | menu = self.menuBar() 187 | menu.addMenu("File").addAction(self.actions['loadScreener']) 188 | menu.addMenu("Help").addAction(self.actions['helpAbout']) 189 | 190 | def createToolbars(self): 191 | t = self.addToolBar("File") 192 | t.setObjectName("FileToolBar") 193 | t.addAction(self.actions['loadScreener']) 194 | 195 | def loadScreenerSymbols(self, fName = None): 196 | ' load symbols from yahoo screener csv' 197 | 198 | if fName is None: 199 | formats = 
['*.csv'] 200 | path = (os.path.dirname(self.filename) 201 | if self.filename is not None else ".") 202 | 203 | fName = unicode(QFileDialog.getOpenFileName(self,"Open screener results",path, 204 | "CSV files ({0})".format(" ".join(formats)))) 205 | 206 | if fName: 207 | self.dataTable.loadSpreads(fName) 208 | 209 | def plotHistData(self): 210 | ''' plot internal historic data ''' 211 | plt = PlotWindow(self) 212 | plt.plot(self.dataTable.histData.df) 213 | plt.show() 214 | 215 | def createAction(self, text, slot=None, shortcut=None, icon=None, 216 | tip=None, checkable=False, signal="triggered()"): 217 | action = QAction(text, self) 218 | if icon is not None: 219 | action.setIcon(QIcon(":/{0}.png".format(icon))) 220 | if shortcut is not None: 221 | action.setShortcut(shortcut) 222 | if tip is not None: 223 | action.setToolTip(tip) 224 | action.setStatusTip(tip) 225 | if slot is not None: 226 | self.connect(action, SIGNAL(signal), slot) 227 | if checkable: 228 | action.setCheckable(True) 229 | return action 230 | 231 | 232 | def helpAbout(self): 233 | QMessageBox.about(self, "Spread Detective - About", 234 | """Spread Detective v {0} 235 |

Copyright © 2011 Jev Kuznetsov. 236 | All rights reserved. 237 |

238 | Copyright © 2008-2011 AQR Capital Management, LLC All rights reserved. 239 |

240 | Copyright © 2011 Wes McKinney and pandas developers All rights reserved. 241 | 242 |

243 | 244 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS 245 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 246 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 247 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 248 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 249 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 250 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 251 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 252 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 253 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 254 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 255 | 256 | 257 |

Python {1} """.format(__version__, platform.python_version())) 258 | 259 | 260 | 261 | def _testFcn(self): 262 | print 'test function' 263 | self.dataTable.dataModel.signalUpdate() 264 | 265 | 266 | 267 | def main(): 268 | app = QApplication(sys.argv) 269 | form = MainWindow() 270 | form.show() 271 | app.exec_() 272 | 273 | 274 | main() 275 | --------------------------------------------------------------------------------