├── README.md
├── util
├── __init__.py
├── etf_tmp.py
├── init_save_until_today.py
├── init_save_us_until_today.py
├── plot.py
├── fresh_until_today.py
├── fresh_us_until_today.py
├── scrapy_etf.py
├── ml_util.py
├── Feature_utils.py
└── db.py
├── deal_data
├── __init__.py
├── db.pyc
├── forecast.pyc
├── FeatureUtils.pyc
├── db.py
├── get_r2.py
├── find_features_plot.py
├── deal_hs300.py
├── FeatureUtils.py
└── forecast.py
├── back_test_system
├── .gitignore
├── self
│ ├── ib_execution.py
│ ├── event.pyc
│ ├── performance.pyc
│ ├── strategy.py
│ ├── performance.py
│ ├── exexcution.py
│ ├── event.py
│ ├── mac.py
│ ├── backtest.py
│ ├── data.py
│ └── portfolio.py
├── data.pyc
├── event.pyc
├── backtest.pyc
├── strategy.pyc
├── execution.pyc
├── performance.pyc
├── portfolio.pyc
├── .idea
│ ├── misc.xml
│ ├── inspectionProfiles
│ │ └── profiles_settings.xml
│ ├── modules.xml
│ ├── back_test_system.iml
│ └── workspace.xml
├── try.py
├── strategy.py
├── to_csv.py
├── performance.py
├── execution.py
├── mac.py
├── backtest.py
├── event.py
├── ib_execution.py
├── data.py
└── portfolio.py
├── strategy_1
├── mark
└── simulation.py
├── strategy_2_alpace
├── mark
├── tryData.py
└── alpace.py
├── .gitignore
└── get_data
├── get_300_names.py
├── securities_master.sql
├── get_hist_data.py
├── find_mean_reversion.py
├── find_pairs.py
└── clustering_hs300.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/deal_data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/back_test_system/.gitignore:
--------------------------------------------------------------------------------
1 | .pyc
--------------------------------------------------------------------------------
/back_test_system/self/ib_execution.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/strategy_1/mark:
--------------------------------------------------------------------------------
1 | {
2 | 博云新材:002297
3 | }
--------------------------------------------------------------------------------
/strategy_2_alpace/mark:
--------------------------------------------------------------------------------
1 | 羊驼策略
2 | 针对美股
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.csv.xz
2 | *.csv
3 | *.xz
4 | *.pyc
--------------------------------------------------------------------------------
/deal_data/db.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/deal_data/db.pyc
--------------------------------------------------------------------------------
/deal_data/forecast.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/deal_data/forecast.pyc
--------------------------------------------------------------------------------
/back_test_system/data.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/data.pyc
--------------------------------------------------------------------------------
/back_test_system/event.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/event.pyc
--------------------------------------------------------------------------------
/deal_data/FeatureUtils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/deal_data/FeatureUtils.pyc
--------------------------------------------------------------------------------
/back_test_system/backtest.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/backtest.pyc
--------------------------------------------------------------------------------
/back_test_system/strategy.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/strategy.pyc
--------------------------------------------------------------------------------
/back_test_system/execution.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/execution.pyc
--------------------------------------------------------------------------------
/back_test_system/performance.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/performance.pyc
--------------------------------------------------------------------------------
/back_test_system/portfolio.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/portfolio.pyc
--------------------------------------------------------------------------------
/back_test_system/self/event.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/self/event.pyc
--------------------------------------------------------------------------------
/back_test_system/self/performance.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/self/performance.pyc
--------------------------------------------------------------------------------
/back_test_system/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/util/etf_tmp.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import db
4 |
5 |
# Load the scraped ETF universe (produced by scrapy_etf.py), tag each row
# with sector 'ETF', and persist (Symbol, Sector, Name) to the US symbols DB.
data = pd.read_csv('etf.csv')

data['Sector'] = 'ETF'

data = data[['Symbol','Sector','Name']]

db.save_us_into_db(data)
--------------------------------------------------------------------------------
/back_test_system/try.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy
4 | import pandas
5 | import matplotlib
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | from abc import ABCMeta, abstractmethod
15 |
16 |
17 |
--------------------------------------------------------------------------------
/back_test_system/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/back_test_system/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/back_test_system/self/strategy.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from abc import ABCMeta,abstractmethod
3 | import datetime
4 | try:
5 | import Queue as queue
6 | except ImportError:
7 | import queue
8 | import numpy as np
9 | import pandas as pd
10 | from event import SignalEvent
11 |
class Strategy(object):
    """Abstract base class for signal-generating strategy objects.

    Concrete subclasses implement calculate_signals(), pushing SignalEvent
    objects onto the shared event queue.
    """

    # BUG FIX: original read ``__metaclass = ABCMeta`` (missing trailing
    # underscores), so ABCMeta was never applied and @abstractmethod was
    # not enforced under Python 2.
    __metaclass__ = ABCMeta

    @abstractmethod
    def calculate_signals(self):
        """Compute and queue trading signals; must be overridden."""
        raise NotImplementedError('Should implement calculate_signals()')
--------------------------------------------------------------------------------
/back_test_system/.idea/back_test_system.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/util/init_save_until_today.py:
--------------------------------------------------------------------------------
1 | #coding: utf-8
2 |
3 | from db import get_hs300_tickers,get_ticker_info_by_id,save_ticker_into_db,get_last_date
4 | import datetime
5 |
6 |
if __name__ == '__main__':
    # hs300 constituent ids (ticker rows: [?, id, name, ...])
    ticker_info = get_hs300_tickers()
    for i in range(len(ticker_info)):
        ticker = ticker_info[i]
        ticker_id = ticker[1]
        ticker_name = ticker[2]
        vendor_id = i

        # HACK: skip the first 190 tickers -- looks like a manual resume
        # point from an interrupted initial load; confirm before reuse.
        if i < 190:
            continue

        # fetch history (start date '' -- presumably full history; verify in db.py)
        ticker_data = get_ticker_info_by_id(ticker_id,'')

        # persist
        save_ticker_into_db(ticker_id,ticker_data,vendor_id)
--------------------------------------------------------------------------------
/back_test_system/self/performance.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import pandas as pd
4 |
def create_sharpe_ratio(returns, period=252):
    """Annualised Sharpe ratio of *returns* against a zero benchmark.

    period - observations per year (252 for daily bars).
    """
    mean_ret = np.mean(returns)
    vol = np.std(returns)
    return np.sqrt(period) * mean_ret / vol
7 |
def create_drawdowns(pnl):
    """Largest peak-to-trough drawdown of an equity curve and its duration.

    Parameters:
    pnl - pandas Series of the equity/PnL curve.

    Returns:
    (drawdown Series, max drawdown, max duration in bars).
    """
    # High-water mark seeded with 0 so an initially negative curve still
    # registers a drawdown relative to flat.
    hwm = [0]
    idx = pnl.index
    drawdown = pd.Series(index=idx, dtype='float64')
    duration = pd.Series(index=idx, dtype='float64')

    # BUG FIX: element 0 of both series was left as NaN, so
    # duration[t-1] + 1 propagated NaN whenever a drawdown began at t == 1.
    drawdown.iloc[0] = 0.0
    duration.iloc[0] = 0.0

    # Positional access (.iloc) keeps this correct for any index type
    # (datetime, int, ...) on modern pandas.
    for t in range(1, len(idx)):
        hwm.append(max(hwm[t - 1], pnl.iloc[t]))
        drawdown.iloc[t] = hwm[t] - pnl.iloc[t]
        duration.iloc[t] = 0 if drawdown.iloc[t] == 0 else duration.iloc[t - 1] + 1

    return drawdown, drawdown.max(), duration.max()
20 |
--------------------------------------------------------------------------------
/strategy_2_alpace/tryData.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import datetime
4 | import sys
5 | sys.path.append('..')
6 | import util.db as db
7 | import util.plot as plot
8 | import pprint
9 |
# Manual review loop (Python 2): for each ticker, show its mean/std chart,
# then record the verdict the reviewer types at the prompt.
data = ['EWZS']

df = pd.DataFrame(data)

my_judge = [];
for i in range(df.shape[0]):
    # print df[i]
    ticker = df.loc[i]
    print ticker,'========================='
    ticker_id = ticker[0]
    plot.plotCurrentMeanStd(ticker_id,400)
    ticker_judge = raw_input()#1:buy,0:no,2:interest
    my_judge.append((ticker_id,ticker_judge))






# Dump the collected (ticker_id, verdict) pairs for copy/paste.
pprint.pprint(my_judge)
30 |
31 |
32 |
--------------------------------------------------------------------------------
/back_test_system/self/exexcution.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from abc import ABCMeta,abstractmethod
4 | import datetime
5 | try:
6 | import Queue as queue
7 | except ImportError:
8 | import queue
9 |
10 | from event import FillEvent,OrderEvent
11 |
class ExecutionHandler(object):
    """Abstract base class for order execution handlers.

    Subclasses take OrderEvents and produce FillEvents.
    """

    # BUG FIX: original read ``__metaclass = ABCMeta`` (missing trailing
    # underscores), so the abstract-method machinery was never active.
    __metaclass__ = ABCMeta

    @abstractmethod
    def execute_order(self,event):
        """Execute *event* (an OrderEvent); must be overridden."""
        raise NotImplementedError('Should implement execute_order()')
18 |
class SimulateExecutionHandler(ExecutionHandler):
    """Naive simulated execution: every ORDER becomes an immediate fill on
    'ARCA' with no latency, slippage or fill-cost modelling."""

    def __init__(self, events):
        # Shared queue that resulting FillEvents are published onto.
        self.events = events

    def execute_order(self, event):
        """Convert an OrderEvent straight into a FillEvent (fill_cost=None)."""
        if event.type != 'ORDER':
            return
        fill = FillEvent(
            datetime.datetime.utcnow(),
            event.symbol,
            'ARCA',
            event.quantity,
            event.direction,
            None,
        )
        self.events.put(fill)
--------------------------------------------------------------------------------
/util/init_save_us_until_today.py:
--------------------------------------------------------------------------------
1 | #coding: utf-8
2 | import numpy as np
3 | import pandas as pd
4 | import datetime
5 | import sys
6 | import util.db as db
7 |
8 |
if __name__ == '__main__':
    # All US tickers currently registered in the database.
    symbols = db.get_us_tickers();
    start_date = datetime.datetime(2015,1,1)
    error_arr = [];
    for i in range(len(symbols)):
        symbol = symbols[i][1];
        try:
            print '========= loading %s , %s ==========' % (i,symbol)
            ticker = db.get_us_ticker_by_id(symbol,start_date)
            print '========= loading success'

            db.save_us_ticker_into_db(symbol,ticker,i)
            print '+++++++++++++ save %s , %s success +++++++++++++++' % (i,symbol)
        except:
            # NOTE(review): bare except -- any failure (network, parse, DB)
            # silently deletes the symbol from the database.
            error_arr.append(symbol)
            db.delete_symbol_from_db_by_id(symbol)
            print '------------- delete %s , %s success -------------' % (i,symbol)
26 |
27 |
--------------------------------------------------------------------------------
/util/plot.py:
--------------------------------------------------------------------------------
1 | #coding utf-8
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | import numpy as np
5 | import db
6 |
7 |
8 |
9 |
def plotCurrentMeanStd(tickerId, days_rang=200):
    """Plot the close price with its 60-period MA, EWMA and std-dev series
    for *tickerId* over the last *days_rang* days, then show the figure."""
    frame = db.get_current_mean_std_df(tickerId, days_rang)
    close = frame['close']

    # Same series/format pairs as before, driven by a table instead of
    # four separate calls (plot order preserved).
    for series, fmt, label in (
        (close, 'r', 'cur'),
        (frame['std_60'], 'p', 'std_60'),
        (frame['ma_60'], 'b', 'mean_60'),
        (frame['ewma_60'], 'g', 'emwa_60'),
    ):
        plt.plot(series, fmt, lw=0.75, linestyle='-', label=label)

    lo, hi = close.min(), close.max()
    plt.ylim(float(lo) * 0.9, float(hi) * 1.1)

    plt.legend(loc=4, prop={'size': 2})
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.grid(True)

    plt.show()
30 |
31 |
32 |
--------------------------------------------------------------------------------
/strategy_1/simulation.py:
--------------------------------------------------------------------------------
1 | #coding: utf-8
2 | import numpy as np
3 | import pandas as pd
4 | import datetime
5 | import sys
6 | sys.path.append('..')
7 | import util.db as db
8 | from util.Feature_utils import get_good_feature
9 | from util.ml_util import get_classification_r2,get_regression_r2
10 |
11 |
if __name__ == '__main__':
    # Load raw OHLCV history for ticker 300133 and rebuild it as a float
    # DataFrame indexed by date, restricted to 2016-01-01 .. today.
    ticker_df = db.get_ticker_from_db_by_id('300133')
    ticker_index = ticker_df['index']
    ticker_np = np.array(ticker_df[['open','high','low','close','volume']])
    ticker_df = pd.DataFrame(ticker_np,index=ticker_index,columns=['open','high','low','close','volume'],dtype="float")
    start_date = datetime.datetime(2016,1,1)
    end_date = datetime.datetime.today()
    ticker_df = ticker_df[ticker_df.index >= start_date]
    ticker_df = ticker_df[ticker_df.index <= end_date]
    ticker_data = get_good_feature(ticker_df,10)
    # NOTE(review): reinterpreting the array buffer as '|S6' byte strings
    # looks wrong -- confirm get_regression_r2 really expects this dtype.
    ticker_data.dtype = '|S6'



    #get best ml
    get_regression_r2(ticker_data)
28 |
29 |
--------------------------------------------------------------------------------
/util/fresh_until_today.py:
--------------------------------------------------------------------------------
1 | from db import get_hs300_tickers,get_ticker_info_by_id,save_ticker_into_db,get_last_date
2 | import datetime
3 | import tushare as ts
4 | import time
5 |
6 |
if __name__ == '__main__':
    # hs300 constituent ids
    ticker_info = get_hs300_tickers()
    for i in range(len(ticker_info)):
        ticker = ticker_info[i]
        ticker_id = ticker[1]
        ticker_name = ticker[2]
        vendor_id = i;

        print '--------------------- %s ---------------------' % vendor_id

        # Resume from the day after the last stored bar; fall back to ''
        # (presumably full history) when the ticker has no rows yet.
        # NOTE(review): the bare except also hides genuine DB errors.
        try:
            start_date = get_last_date(ticker_id)[0]
            start_date = str(start_date[0] + datetime.timedelta(days = 1))[0:10]
        except:
            start_date = ''


        ticker_data = get_ticker_info_by_id(ticker_id,start_date)

        print ('data_shape:' , ticker_data.shape)

        if ticker_data.shape[0] != 0:
            # persist (the message below is Chinese for "data is bad!")
            try:
                save_ticker_into_db(ticker_id,ticker_data,vendor_id)
            except:
                print '数据有问题!'
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/back_test_system/strategy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # strategy.py
5 |
6 | from __future__ import print_function
7 |
8 | from abc import ABCMeta, abstractmethod
9 | import datetime
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 |
15 | import numpy as np
16 | import pandas as pd
17 |
18 | from event import SignalEvent
19 |
20 |
class Strategy(object):
    """
    Abstract parent for every strategy object in the backtester.

    A concrete Strategy consumes market bars (OHLCV) supplied by a
    DataHandler -- historic or live, the source is irrelevant since bar
    tuples arrive via a queue -- and emits SignalEvent objects for
    individual symbols in response.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def calculate_signals(self):
        """Generate the list of signals; concrete strategies must override."""
        raise NotImplementedError("Should implement calculate_signals()")
43 |
--------------------------------------------------------------------------------
/get_data/get_300_names.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | #对沪深三百股票进行聚类,并画出关系图
3 | import tushare as ts
4 | import sqlalchemy as create_engine
5 | import MySQLdb as mdb
6 | import datetime
7 |
8 |
9 |
#get hs300's names and insert them into database
def get_hs300(con):
    """Fetch the current HS300 constituent list via tushare and insert the
    tickers/names into the `symbol` table of *con* (a MySQLdb connection)."""
    now = datetime.datetime.utcnow()
    hs300 = ts.get_hs300s()
    column_str = """ticker, instrument, name, sector, currency, created_date, last_updated_date"""
    # "%s, " * 7 then strip the trailing ", " -> seven placeholders.
    insert_str = ("%s, " * 7)[:-2]
    final_str = "INSERT INTO symbol (%s) VALUES (%s)" % (column_str, insert_str)
    symbols = []

    # One (ticker, 'stock', name, '', 'RMB', now, now) tuple per constituent.
    # NOTE(review): DataFrame.ix is removed in modern pandas -- use .iloc.
    for i in range(len(hs300)):
        t = hs300.ix[i]
        symbols.append(
            (
                t['code'],
                'stock',
                t['name'],
                '',
                'RMB',
                now,
                now,
            )
        )
    # NOTE(review): this cursor is dead code -- immediately shadowed inside
    # the with-block below.
    cur = con.cursor()
    with con:
        cur = con.cursor()
        cur.executemany(final_str, symbols)
        print 'success insert hs300 into symbol!'
37 |
38 |
39 |
if __name__ == '__main__':
    # Local dev credentials: root with an empty password.
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    #get all 300 names and put them into database
    get_hs300(con);
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/back_test_system/to_csv.py:
--------------------------------------------------------------------------------
1 | import MySQLdb as mdb
2 | import csv
3 | import pandas as pd
4 | import numpy as np
5 | from pandas import DataFrame,Series
6 | import sys
7 |
def get_tickers_from_db():
    """Return (id, ticker) rows for every symbol in securities_master."""
    con = mdb.connect(
        host='localhost',
        user='root',
        passwd='',
        db='securities_master',
    )
    with con:
        cur = con.cursor()
        cur.execute('SELECT id,ticker FROM symbol')
        return cur.fetchall()
18 |
19 |
def get_one_ticker_by_id(ticker_id):
    """Fetch all daily OHLCV rows for one symbol id from securities_master.

    Returns rows of (price_date, open, high, low, close, volume).
    """
    db_host = 'localhost'
    db_user = 'root'
    db_pd = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_pd, db=db_name)
    with con:
        cur = con.cursor()
        # SECURITY FIX: bind ticker_id as a query parameter instead of
        # %-interpolating it into the SQL string (injection-safe).
        cur.execute(
            'SELECT price_date,open_price,high_price,low_price,close_price,volume '
            'FROM daily_price WHERE symbol_id = %s',
            (ticker_id,),
        )
        return cur.fetchall()
30 |
31 |
def render_csv(ticker_id):
    """Dump one symbol's daily prices to ./data/<ticker_id>.csv.

    NOTE(review): adj_close is just a copy of close_price (row index 4);
    no corporate-action adjustment is applied.
    """
    data = get_one_ticker_by_id(ticker_id)

    np_data = np.array(data)
    pd_data = DataFrame(np_data,columns = ['datetime', 'open', 'high', 'low', 'close', 'volume'])
    adj_close = Series([d[4] for d in data])
    pd_data['adj_close'] = adj_close
    pd_data.to_csv('./data/%s.csv' % ticker_id)
40 |
41 |
42 | render_csv('600050')
43 |
44 |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/util/fresh_us_until_today.py:
--------------------------------------------------------------------------------
1 | from db import get_us_tickers,get_us_ticker_from_db_by_id,save_us_ticker_into_db,get_us_last_date
2 | import datetime
3 | import db
4 | import tushare as ts
5 | import pandas as pd
6 | import time
7 |
8 |
9 |
if __name__ == '__main__':
    # US tickers (the original comment said "hs300 ids" -- copy/paste leftover)
    ticker_info = get_us_tickers()
    for i in range(len(ticker_info)):
        ticker = ticker_info[i]
        ticker_id = ticker[1]
        ticker_name = ticker[2]
        vendor_id = i;


        print '--------------------- %s ---------------------' % vendor_id


        # Resume from the day after the last saved bar; '' presumably means
        # full history -- verify in db.py. The bare except hides DB errors.
        try:
            print '========= loading %s , %s ==========' % (i,ticker_id)
            start_date = get_us_last_date(ticker_id)[0][0]
            print '========= loading success =========='
            start_date = start_date + datetime.timedelta(days = 1)

        except:
            start_date = ''

        print start_date , '============================================'

        # try:
        ticker_data = db.get_us_ticker_by_id(ticker_id,start_date)
        # except:
        #     print 'get data fail...'
        #     ticker_data = pd.DataFrame([])

        print ('data :' , ticker_data)

        if ticker_data.shape[0] != 0:
            # persist
            try:
                print '+++++++++++++ save %s , %s success +++++++++++++++' % (i,ticker_id)
                save_us_ticker_into_db(ticker_id,ticker_data,vendor_id)
            except:
                print(ticker_id, 'is error')

        print ticker_id,'==== finished ====='
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/deal_data/db.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import MySQLdb as mdb
3 |
4 |
5 |
6 |
#get name
def get_tickers_from_db():
    """Return (ticker, name) for every row of the symbol table."""
    con = mdb.connect(
        host='localhost',
        user='root',
        passwd='',
        db='securities_master',
    )

    # Pull every ticker/name pair out of symbol.
    with con:
        cur = con.cursor()
        cur.execute('SELECT ticker,name FROM symbol')
        return cur.fetchall()
21 |
# Fetch the trading volume of every symbol for a given day.
def get_day_volumn_33_66(day):
    """Return a list of int volumes for all daily_price rows on *day*.

    Parameters:
    day - price_date value, e.g. '2015-01-05'.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    with con:
        cur = con.cursor()
        # SECURITY FIX: bind *day* as a query parameter instead of string
        # interpolation (the driver handles the quoting).
        cur.execute('SELECT volume from daily_price where (price_date=%s)', (day,))
        rows = cur.fetchall()
        return [int(r[0]) for r in rows]
36 |
37 |
#get data by tickerId
def get_10_50_by_id(ticker_id):
    """Return 2015 daily OHLCV rows for *ticker_id* whose high is in [10, 50]."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    with con:
        cur = con.cursor()
        # SECURITY FIX: parameterize symbol_id rather than %-formatting it in.
        cur.execute(
            'SELECT price_date,open_price,high_price,low_price,close_price,volume '
            'from daily_price where (symbol_id = %s) '
            'and (price_date BETWEEN "20150101" AND "20151231") '
            'and (high_price BETWEEN 10 and 50)',
            (ticker_id,),
        )
        return cur.fetchall()
51 |
52 |
53 |
--------------------------------------------------------------------------------
/back_test_system/self/event.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
class Event(object):
    """Base class interface for all events handled by the backtest queue."""
    pass
5 |
class MarketEvent(Event):
    """Signals the arrival of a new market data bar."""
    def __init__(self):
        self.type = 'MARKET'
9 |
class SignalEvent(Event):
    """Carries a strategy's trading signal for one symbol to the portfolio."""

    def __init__(self, strategy_id, symbol, datetime, signal_type, strength):
        # Identifies which strategy produced the signal.
        self.strategy_id = strategy_id
        self.type = 'SIGNAL'
        self.symbol = symbol
        # Timestamp of the bar the signal was generated on.
        self.datetime = datetime
        self.signal_type = signal_type
        # Scaling factor applied when sizing the resulting order.
        self.strength = strength
18 |
class OrderEvent(Event):
    """Represents an order sent to the execution system for one symbol."""

    def __init__(self,symbol,order_type,quantity,direction):
        self.type = 'ORDER'
        self.symbol = symbol
        self.order_type = order_type
        self.quantity = quantity
        self.direction = direction

    def print_order(self):
        """Print the order's parameters.

        BUG FIX: the original read ``"..." %s (...)`` -- a stray ``s`` after
        the ``%`` operator, which is a SyntaxError.
        """
        print("Order: Symbol=%s,Type=%s,Quantity=%s,Direction=%s" % (self.symbol,self.order_type,self.quantity,self.direction))
29 |
class FillEvent(Event):
    """Encapsulates a completed order fill, including the commission charged."""

    def __init__(self, timeindex, symbol, exchange, quantity, direction,
                 fill_cost, commission=None):
        self.type = 'FILL'
        self.timeindex = timeindex
        self.symbol = symbol
        self.exchange = exchange
        self.quantity = quantity
        self.direction = direction
        self.fill_cost = fill_cost

        # Fall back to the Interactive Brokers fee schedule when the caller
        # does not supply an explicit commission.
        self.commission = (
            self.calculate_ib_commission() if commission is None else commission
        )

    def calculate_ib_commission(self):
        """IB-style commission: 1.3 minimum, per-share rate above that."""
        rate = 0.013 if self.quantity <= 500 else 0.008
        return max(1.3, rate * self.quantity)
--------------------------------------------------------------------------------
/back_test_system/performance.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # performance.py
5 |
6 | from __future__ import print_function
7 |
8 | import numpy as np
9 | import pandas as pd
10 |
11 |
def create_sharpe_ratio(returns, periods=252):
    """
    Sharpe ratio of *returns* against a zero benchmark (i.e. no
    risk-free rate information), annualised by *periods*.

    Parameters:
    returns - A pandas Series representing period percentage returns.
    periods - Daily (252), Hourly (252*6.5), Minutely(252*6.5*60) etc.
    """
    avg = np.mean(returns)
    sigma = np.std(returns)
    return np.sqrt(periods) * avg / sigma
22 |
23 |
def create_drawdowns(pnl):
    """
    Calculate the largest peak-to-trough drawdown of the PnL curve
    as well as the duration of the drawdown. Requires that the
    pnl_returns is a pandas Series.

    Parameters:
    pnl - A pandas Series representing period percentage returns.

    Returns:
    drawdown, duration - Highest peak-to-trough drawdown and duration.
    """

    # Calculate the cumulative returns curve
    # and set up the High Water Mark (seeded with 0).
    hwm = [0]

    # Create the drawdown and duration series.
    idx = pnl.index
    drawdown = pd.Series(index=idx, dtype='float64')
    duration = pd.Series(index=idx, dtype='float64')

    # BUG FIX: element 0 of both series was previously left as NaN, which
    # made duration[t-1] + 1 propagate NaN through the whole duration
    # series whenever a drawdown began at t == 1.
    drawdown.iloc[0] = 0.0
    duration.iloc[0] = 0.0

    # Loop over the index range; .iloc keeps access positional regardless
    # of the index type (datetime, int, ...) on modern pandas.
    for t in range(1, len(idx)):
        hwm.append(max(hwm[t - 1], pnl.iloc[t]))
        drawdown.iloc[t] = hwm[t] - pnl.iloc[t]
        duration.iloc[t] = 0 if drawdown.iloc[t] == 0 else duration.iloc[t - 1] + 1
    return drawdown, drawdown.max(), duration.max()
52 |
--------------------------------------------------------------------------------
/get_data/securities_master.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE `exchange` (
2 | `id` int NOT NULL AUTO_INCREMENT,
3 | `abbrev` varchar(32) NOT NULL,
4 | `name` varchar(255) NOT NULL,
5 | `city` varchar(255) NULL,
6 | `country` varchar(255) NULL,
7 | `currency` varchar(64) NULL,
8 | `timezone_offset` time NULL,
9 | `created_date` datetime NOT NULL,
10 | `last_updated_date` datetime NOT NULL,
11 | PRIMARY KEY (`id`)
12 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
13 |
14 | CREATE TABLE `data_vendor` (
15 | `id` int NOT NULL AUTO_INCREMENT,
16 | `name` varchar(64) NOT NULL,
17 | `website_url` varchar(255) NULL,
18 | `support_email` varchar(255) NULL,
19 | `created_date` datetime NOT NULL,
20 | `last_updated_date` datetime NOT NULL,
21 | PRIMARY KEY (`id`)
22 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
23 |
24 | CREATE TABLE `symbol` (
25 | `id` int NOT NULL AUTO_INCREMENT,
26 | `exchange_id` int NULL,
27 | `ticker` varchar(32) NOT NULL,
28 | `instrument` varchar(64) NOT NULL,
29 | `name` varchar(255) NULL,
30 | `sector` varchar(255) NULL,
31 | `currency` varchar(32) NULL,
32 | `created_date` datetime NOT NULL,
33 | `last_updated_date` datetime NOT NULL,
34 | PRIMARY KEY (`id`),
35 | KEY `index_exchange_id` (`exchange_id`)
36 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
37 |
38 | CREATE TABLE `daily_price` (
39 | `id` int NOT NULL AUTO_INCREMENT,
40 | `data_vendor_id` int NOT NULL,
41 | `symbol_id` int NOT NULL,
42 | `price_date` datetime NOT NULL,
43 | `created_date` datetime NOT NULL,
44 | `last_updated_date` datetime NOT NULL,
45 | `open_price` decimal(19,4) NULL,
46 | `high_price` decimal(19,4) NULL,
47 | `low_price` decimal(19,4) NULL,
48 | `close_price` decimal(19,4) NULL,
49 | `adj_close_price` decimal(19,4) NULL,
50 | `volume` bigint NULL,
51 | PRIMARY KEY (`id`),
52 | KEY `index_data_vendor_id` (`data_vendor_id`),
53 | KEY `index_symbol_id` (`symbol_id`)
54 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
--------------------------------------------------------------------------------
/util/scrapy_etf.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import simplejson as json
3 | import pandas as pd
4 |
# url = 'https://xueqiu.com/hq#exchange=US&plate=3_1_11&firstName=3&secondName=3_1&order=desc&orderby=marketcapital&page=1'
url1 = 'https://xueqiu.com/snowman/login'

# SECURITY: real account credentials are hard-coded and committed here --
# rotate this password and load secrets from env vars / a config file.
data = {
    "remember_me":"true",
    "username":"13151998870",
    "password":"zhangyizhi112358"
}
# Browser-like headers so xueqiu accepts the requests.
headers = {
    "Accept":"application/json, text/javascript, */*; q=0.01",
    "Accept-Encoding":"gzip, deflate, sdch, br",
    "Accept-Language":"zh-CN,zh;q=0.8,en;q=0.6",
    "Cache-Control":"no-cache",
    "Connection":"keep-alive",
    "Host":"xueqiu.com",
    "Pragma":"no-cache",
    "Referer":"https://xueqiu.com/hq",
    "User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36",
    "X-Requested-With":"XMLHttpRequest",
}

# Log in first; the session cookies authorize the stock-list endpoint.
res1 = requests.post(url1,data,headers=headers)

target_data = []
pageNum = 1

# Page through the US ETF listing (30 per page) until an empty page.
while True:
    url2 = 'http://xueqiu.com/stock/cata/stocklist.json?page=%s&size=30&order=desc&orderby=marketCapital&exchange=US&plate=ETF&isdelay=1' % (pageNum)
    res2 = requests.get(url2, cookies=res1.cookies, headers=headers)
    stock_data = json.loads(res2.text)['stocks']



    print 'pageNum: ', pageNum ,' ====================================='

    # print stock_data,'---------------------------'



    if len(stock_data) == 0 :
        break

    for i in range(len(stock_data)):
        stock = stock_data[i]
        stock_id = stock['code'].encode('utf-8')
        stock_name = stock['name'].encode('utf-8')
        print stock_id,stock_name
        target_data.append((stock_id,stock_name))

    pageNum = pageNum + 1



# Persist the scraped (Symbol, Name) pairs for etf_tmp.py to consume.
df = pd.DataFrame(target_data,columns=['Symbol','Name'])
df.to_csv('etf.csv')



print target_data
64 |
65 |
--------------------------------------------------------------------------------
/back_test_system/self/mac.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import numpy as np
3 |
4 | from backtest import Backtest
5 | from data import HistoricCSVDataHandler
6 | from event import SignalEvent
7 | from execution import SimulateExecutionHandler
8 | from portfolio import Portfolio
9 | from strategy import Strategy
10 |
class MovingAverageCrossStrategy(Strategy):
    """
    Basic moving-average crossover strategy: go LONG when the short SMA
    crosses above the long SMA and EXIT when it crosses back below.
    """

    def __init__(self,bars,events,short_window=100,long_window=400):
        """
        bars - the DataHandler providing bar data.
        events - the shared Event queue.
        short_window / long_window - SMA lookbacks, in bars.
        """
        self.bars = bars
        self.symbol_list = self.bars.symbol_list
        self.events = events
        self.short_window = short_window
        self.long_window = long_window

        # Per-symbol position state: 'OUT' or 'LONG'.
        self.bought = self._calculate_initial_bought()

    def _calculate_initial_bought(self):
        """Start every symbol out of the market."""
        bought = {}
        for s in self.symbol_list:
            bought[s] = 'OUT'
        return bought

    def calculate_signals(self,event):
        """On each MARKET event, emit LONG/EXIT SignalEvents on SMA crosses."""
        if event.type == 'MARKET':
            for symbol in self.symbol_list:
                bars = self.bars.get_latest_bars_values(symbol,'close',N=self.long_window)
                # BUG FIX: `bars != []` is ambiguous when the data handler
                # returns a numpy array (elementwise comparison); use an
                # explicit emptiness check instead.
                if bars is not None and len(bars) > 0:
                    short_sma = np.mean(bars[-self.short_window:])
                    long_sma = np.mean(bars[-self.long_window:])

                    dt = self.bars.get_latest_bar_datetime(symbol)
                    sig_dir = ''
                    strength = 1.0
                    strategy_id = 1

                    if short_sma > long_sma and self.bought[symbol] == 'OUT':
                        sig_dir = 'LONG'
                        signal = SignalEvent(strategy_id,symbol,dt,sig_dir,strength)
                        self.events.put(signal)
                        self.bought[symbol] = 'LONG'
                    elif short_sma < long_sma and self.bought[symbol] == 'LONG':
                        sig_dir = 'EXIT'
                        signal = SignalEvent(strategy_id,symbol,dt,sig_dir,strength)
                        self.events.put(signal)
                        self.bought[symbol] = 'OUT'
51 |
# BUG FIX: the guard was `if __name == '__main__':`, which raises NameError
# when the script runs -- the dunder is `__name__`.
if __name__ == '__main__':
    csv_dir = 'data'
    symbol_list = ['600050']
    initial_capital = 100000.0
    start_date = datetime.datetime(1990,1,1,0,0,0)
    hearbeat = 0.0  # (sic) spelling matches the Backtest parameter

    backtest = Backtest(
        csv_dir,symbol_list,
        initial_capital,
        hearbeat,
        start_date,
        HistoricCSVDataHandler,
        SimulateExecutionHandler,
        Portfolio,
        MovingAverageCrossStrategy
    )
    backtest.simulate_trading()
71 |
--------------------------------------------------------------------------------
/get_data/get_hist_data.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | #对沪深三百股票进行聚类,并画出关系图
3 | import tushare as ts
4 | import sqlalchemy as create_engine
5 | import MySQLdb as mdb
6 | import datetime
7 |
8 |
9 |
def get_tickers_from_db(con):
    """Return every (id, ticker) pair stored in the symbol table."""
    with con:
        cursor = con.cursor()
        cursor.execute('SELECT id,ticker FROM symbol')
        rows = cursor.fetchall()
        return [(row[0], row[1]) for row in rows]
17 |
def get_hist_data_from_tushare(ticker):
    """Fetch daily history for *ticker* from tushare (2000-01-01 to today).

    Fixes two defects in the original:
    * ``get every detail data`` was a bare statement (a SyntaxError) -- now a
      comment;
    * ``datetime.today()`` is invalid under ``import datetime`` and the
      ``timetuple()`` repr is not a date string; tushare expects 'YYYY-MM-DD'.
    """
    start_date = '2000-1-1'
    end_date = datetime.datetime.today().strftime('%Y-%m-%d')

    # get every detail data
    tData = ts.get_hist_data(ticker,start=start_date,end=end_date,retry_count=5,pause=1)
    return tData
25 |
def into_db(tTicker,data,data_vendor_id,con):
    """Bulk-insert one ticker's daily OHLCV rows into the daily_price table."""
    # Single timestamp shared by created/updated columns.
    now = datetime.datetime.utcnow()
    # Assemble "INSERT INTO daily_price (...) VALUES (%s, ..., %s)".
    column_str = """data_vendor_id, symbol_id, price_date, created_date,
                 last_updated_date, open_price, high_price, low_price,
                 close_price, volume, adj_close_price"""
    insert_str = ", ".join(["%s"] * 11)
    final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % (column_str, insert_str)

    # One tuple per trading day, in index order (adj close is not available
    # from this source, so it is stored as 0).
    daily_data = []
    for t_date in data.index:
        bar = data.ix[t_date]
        daily_data.append(
            (data_vendor_id, tTicker, t_date, now, now, bar['open'], bar['high'],
             bar['low'], bar['close'], bar['volume'], 0)
        )

    with con:
        cur = con.cursor()
        cur.executemany(final_str, daily_data)
49 |
50 |
51 | if __name__ == '__main__':
52 | db_host = 'localhost'
53 | db_user = 'root'
54 | db_password = ''
55 | db_name = 'securities_master'
56 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
57 |
58 | #get history data and put them into database
59 | tickers = get_tickers_from_db(con);
60 |
61 | #iterating every ticker and put theirs data into datbase
62 | for i in range(len(tickers)):
63 | t = tickers[i]
64 | tTicker = tickers[i][1]
65 | data = get_hist_data_from_tushare(tTicker)
66 | data_vendor_id = i
67 | print 'data_vendor_id : %s' % data_vendor_id
68 | print 'tTicker : %s' % tTicker
69 | print '%s of %s' % (i,len(tickers))
70 | into_db(tTicker,data,data_vendor_id,con)
71 |
72 |
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/strategy_2_alpace/alpace.py:
--------------------------------------------------------------------------------
1 |
2 | #coding: utf-8
3 | import numpy as np
4 | import pandas as pd
5 | import datetime
6 | import sys
7 | sys.path.append('..')
8 | import util.db as db
9 | import pprint
10 |
11 |
12 |
13 |
14 |
if __name__ == '__main__':

    day_range = 20             # lookback window (days) for the profit calc
    round_days = 7             # execution cycle (days); currently unused
    average_days_70 = 6 * 10   # long moving-average window
    average_days_5 = 10        # short moving-average window
    average_days_30 = 20       # medium moving-average window

    # Stock pool whose 20-day average volume falls in the 33%-66% band.
    stockers = db.get_us_middle33_volume(day_range, 8, 20)
    stocker_ids = stockers['id']
    ticker_content = []

    # For each candidate compute the lookback profit and three MA levels.
    for i in range(len(stocker_ids)):
        stocker_id = stocker_ids[i]
        end_date = db.get_us_last_date(stocker_id)[0][0]
        start_date = end_date + datetime.timedelta(days=-day_range)
        stocker_data = db.get_us_ticker_from_db_by_id(stocker_id, start_date, end_date)
        last_row = len(stocker_data) - 1
        # NOTE(review): row 0 is treated as the most recent bar and the last
        # row as the oldest -- confirm the ordering returned by the db helper.
        profit = (stocker_data.loc[0].close - stocker_data.loc[last_row].close) / stocker_data.loc[last_row].close
        current_price = stocker_data.loc[0].close
        mean_price_70, std_price_70 = db.get_average_days_price_by_id(stocker_id, average_days_70)
        mean_price_30, std_price_30 = db.get_average_days_price_by_id(stocker_id, average_days_30)
        mean_price_5, std_price_5 = db.get_average_days_price_by_id(stocker_id, average_days_5)

        # Keep names trading above the 30-day MA while the 30- and 70-day
        # MAs are close together (flat longer-term trend).
        if (current_price > mean_price_30) and (abs(mean_price_30 - mean_price_70) < 1):
            print (stocker_id,current_price,mean_price_30,mean_price_70,'===========================')
            ticker_content.append((stocker_id, current_price, profit, mean_price_5, mean_price_30, mean_price_70))

    df = pd.DataFrame(ticker_content, columns=['id','price','profit','mean5','mean_price_30','mean_price_70'])
    # BUG FIX: DataFrame.sort() was removed from pandas; sort_values() is the
    # supported replacement (available since pandas 0.17).
    df = df.sort_values(['profit'], ascending=False)
    df = df.reset_index()

    pprint.pprint(df)
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/back_test_system/execution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # execution.py
5 |
6 | from __future__ import print_function
7 |
8 | from abc import ABCMeta, abstractmethod
9 | import datetime
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 |
15 | from event import FillEvent, OrderEvent
16 |
17 |
class ExecutionHandler(object):
    """
    Abstract base for order execution.

    An ExecutionHandler turns Order events produced by a Portfolio into the
    Fill events that actually occur in the market.  Simulated and live
    brokerage handlers implement the same interface, so a strategy can be
    backtested and traded live with near-identical code.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def execute_order(self, event):
        """
        Execute *event* (an Order event) and enqueue the resulting Fill
        event on the Events queue.

        Parameters:
        event - an Event object carrying the order information.
        """
        raise NotImplementedError("Should implement execute_order()")
43 |
44 |
class SimulatedExecutionHandler(ExecutionHandler):
    """
    Naive simulated execution: every ORDER event is converted straight into
    an equivalent fill with no latency, slippage or partial-fill modelling.

    Useful as a first-pass test of a strategy before wiring in a more
    realistic execution handler.
    """

    def __init__(self, events):
        """
        Parameters:
        events - the Queue of Event objects that fills are pushed onto.
        """
        self.events = events

    def execute_order(self, event):
        """
        Convert an ORDER event into an immediate ARCA FillEvent; any other
        event type is ignored.

        Parameters:
        event - an Event object carrying the order information.
        """
        if event.type != 'ORDER':
            return
        fill_event = FillEvent(
            datetime.datetime.utcnow(), event.symbol,
            'ARCA', event.quantity, event.direction, None
        )
        self.events.put(fill_event)
80 |
--------------------------------------------------------------------------------
/deal_data/get_r2.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import pprint
5 | import os
path = os.getcwd()  # working directory; the csv path below is built from it
print path  # Python 2 print statement
8 | import FeatureUtils
9 | import forecast
10 | from pandas import Series,DataFrame
11 |
12 | #回归
13 | from sklearn.cross_validation import train_test_split
14 | from sklearn.metrics import classification_report
15 | from sklearn.svm import SVC
16 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars
17 | from sklearn.ensemble import RandomForestRegressor
18 | from sklearn.metrics import r2_score
19 | from sklearn import linear_model
20 |
21 |
# Stream the feature csv and keep only the first 5000 rows; values are read
# as fixed-width byte strings ('|S6'), matching the other scripts here.
data = pd.read_csv(path + '/2/hs.csv',parse_dates=True,iterator=True)
# data = pd.read_csv(path + '/2/hs.csv',parse_dates=True)
data = pd.DataFrame(data.get_chunk(5000),dtype='|S6')
# data = pd.DataFrame(data,dtype='|S6')
26 |
27 |
28 |
def get_regression_r2(ticker_data):
    """Fit candidate regressors on an 80/20 positional split of *ticker_data*,
    plot the best model's predictions against the held-out 'realY', and
    return the winning (model, r2) pair.

    ticker_data -- DataFrame with a 'realY' column plus feature columns.
    """
    data_len = len(ticker_data)
    # First 80% of rows train, the remainder test (positional .ix split).
    split_line = int(data_len * 0.8)
    X = ticker_data.drop('realY',1)
    y = ticker_data['realY'].dropna()

    X_train = X.ix[:split_line]
    X_test = X.ix[split_line:]
    y_train = y.ix[:split_line]
    y_test = y.ix[split_line:]

    # Only the random forest is enabled; the linear models are kept around
    # for reference.
    models = [
        # ('LR',LinearRegression()),
        # ('RidgeR',Ridge (alpha = 0.005)),
        # ('lasso',Lasso(alpha=0.00001)),
        # ('LassoLars',LassoLars(alpha=0.00001)),
        ('RandomForestRegression',RandomForestRegressor(1000))]

    # (model, r2) of the best scorer so far; seeded very low so the first
    # candidate always replaces it.
    best_r2 = ('',-10000000)

    for m in models:
        m[1].fit(np.array(X_train),np.array(y_train))
        # The prediction index is effectively shifted one step relative to
        # y_test, so X is advanced one position before predicting; easiest
        # to reason about in terms of the underlying dates.
        pred = m[1].predict(X_test.shift(-1).fillna(0))
        r2 = r2_score(y_test,pred)
        if r2 > best_r2[1]:
            best_r2 = (m[1],r2)
        # print "%s:\n%0.3f" % (m[0], r2_score(np.array(y_test),np.array(pred)))

    print 'the best regression is:',best_r2

    # Refit the winner and plot predicted vs. actual on the test window.
    model = best_r2[0]
    model.fit(X_train, y_train)
    pred = model.predict(X_test.shift(-1).fillna(0))
    pred_test = pd.Series(pred, index=y_test.index)

    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.plot(y_test,'r',lw=0.75,linestyle='-',label='realY')
    ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY')
    plt.legend(loc=2,prop={'size':9})
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.grid(True)
    plt.show()

    return best_r2
76 |
77 |
# Split the loaded frame into features and target ('predictY' is a stale
# prediction column and is excluded from the features).
X = data.drop(['realY','predictY'],1)
y = data['realY']

#get importances of features
# features = FeatureUtils.forestFindFeature(X,y,100)[:300]

data = X.join(y)

get_regression_r2(data)
87 |
88 |
89 | # print data
--------------------------------------------------------------------------------
/deal_data/find_features_plot.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import pandas as pd
3 | import numpy as np
4 | from pandas import DataFrame,Series
5 | import matplotlib.pyplot as plt
6 | import tushare as ts
7 | import FeatureUtils
8 | from FeatureUtils import toDatetime
9 | from sklearn.cross_validation import train_test_split
10 | from sklearn.metrics import classification_report
11 | from sklearn.svm import SVC
12 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars
13 | from sklearn.metrics import r2_score
14 | from sklearn import linear_model
15 | import datetime
16 |
17 |
18 |
19 | #获取沪深300股指数据
20 | hs300 = DataFrame(ts.get_hist_data('hs300'))
21 | hs300 = toDatetime(hs300).dropna()
22 |
23 | #加入feature
24 | hs300 = FeatureUtils.CCI(hs300,10)
25 | hs300 = FeatureUtils.TL(hs300,10)
26 | hs300 = FeatureUtils.EVM(hs300,10)
27 | hs300 = FeatureUtils.SMA(hs300,10)
28 | hs300 = FeatureUtils.EWMA(hs300,10)
29 | hs300 = FeatureUtils.ROC(hs300,10)
30 | hs300 = FeatureUtils.ForceIndex(hs300,10)
31 | hs300 = FeatureUtils.BBANDS(hs300,10)
32 | hs300 = hs300.dropna()
33 |
34 | #归一化
35 | hs300_norm = (hs300 - hs300.mean())/(hs300.max() - hs300.min())
36 |
37 | # # Build a classification task using 3 informative features
38 | X_real = DataFrame(hs300_norm.drop('close',1),dtype='|S6')
39 | X = X_real.shift(1).dropna()
40 | y_real = Series(hs300_norm['close'],dtype='|S6')
41 | y = y_real[:-1]
42 |
43 |
44 | #forest find feature
45 | features = FeatureUtils.forestFindFeature(X,y,100)
46 |
47 |
48 | X_real_F = DataFrame(hs300_norm[features[0:10]],dtype='|S6')
49 | X_F = X_real_F.shift(1).dropna()
50 | y_F = Series(y,dtype='|S6')
51 |
52 |
53 |
54 | d = datetime.datetime(2015,12,31)
55 |
56 |
57 | X_train = X_F[X.index < d]
58 | X_test = X_F[X.index >= d]
59 | y_train = y_F[y.index <= d]
60 | y_test = y_F[y.index > d]
61 |
62 |
63 | # Create the (parametrised) models
64 | print("Hit Rates/Confusion Matrices:\n")
65 | models = [
66 | ('LR',LinearRegression()),
67 | ('RidgeR',Ridge (alpha = 0.005)),
68 | ('lasso',Lasso(alpha=0.0001)),
69 | ('LassoLars',LassoLars(alpha=0.0001))
70 | ]
71 |
72 | for m in models:
73 | m[1].fit(np.array(X_train),np.array(y_train))
74 | pred = m[1].predict(X_test)
75 | print "%s:\n%0.3f" % (m[0], r2_score(y_test,pred))
76 |
77 |
78 | model = Lasso(alpha=0.0001)
79 | model.fit(X_train, y_train)
80 | pred = model.predict(X_test)
81 | pred_test = pd.Series(pred, index=y_test.index)
82 | fig = plt.figure()
83 | ax = fig.add_subplot(1,1,1)
84 | ax.plot(y_test,'r',lw=0.75,linestyle='-',label='realY')
85 | ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY')
86 | plt.legend(loc=2,prop={'size':9})
87 | plt.setp(plt.gca().get_xticklabels(), rotation=30)
88 | plt.grid(True)
89 | plt.show()
--------------------------------------------------------------------------------
/get_data/find_mean_reversion.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import pandas as pd
3 | import numpy as np
4 | from pandas import DataFrame,Series
5 | import MySQLdb as mdb
6 | import numpy as np
7 | import statsmodels.tsa.stattools as ts
8 |
9 |
10 | #get name
def get_tickers_from_db(con):
    """Return every (id, ticker, name) row from the symbol table."""
    with con:
        cursor = con.cursor()
        cursor.execute('SELECT id,ticker,name FROM symbol')
        rows = cursor.fetchall()
        return [(row[0], row[1], row[2]) for row in rows]
17 |
18 | #get data from 2010 to 2015
def get_2010_2015(ticker_id,ticker_name,con):
    """Return the 2010-2015 daily close prices for *ticker_id* as float64.

    ticker_name is unused here but kept for call-site compatibility.
    """
    with con:
        cur = con.cursor()
        # SECURITY FIX: interpolate via the DB driver instead of `%` string
        # formatting.  (The original also computed an unused `dates` array.)
        cur.execute('SELECT price_date,close_price from daily_price where (symbol_id = %s) and (price_date BETWEEN "20100101" AND "20151231")', (ticker_id,))
        rows = cur.fetchall()
        ticker_data = np.array([row[1] for row in rows], dtype='float64')

    return ticker_data
29 |
30 | #hurst
def hurst(ts):
    """Estimate the Hurst exponent of series *ts*.

    Fits log(sqrt(std of lagged differences)) against log(lag) for lags
    2..99; twice the slope approximates the Hurst exponent (<0.5 suggests
    mean reversion, >0.5 trending).
    """
    lag_values = range(2, 100)
    scale = [np.sqrt(np.std(np.subtract(ts[lag:], ts[:-lag]))) for lag in lag_values]
    slope = np.polyfit(np.log(lag_values), np.log(scale), 1)[0]
    return slope * 2.0
36 |
37 |
38 |
if __name__ == '__main__':
    #connect to db
    db_host = 'localhost'
    db_user = 'root'
    db_pass = ''
    db_name = 'securities_master'
    con = mdb.connect(db_host, db_user, db_pass, db_name)

    #get 300 names and id
    tickers = get_tickers_from_db(con)

    # Tickers passing each mean-reversion screen.
    all_hurst_data = []
    all_adf_data = []

    #get data of 2010-2015
    for i in range(len(tickers)):
        ticker = tickers[i]
        # NOTE(review): ticker[1] is the ticker *symbol*, not the numeric id;
        # elsewhere in this repo daily_price.symbol_id appears to store the
        # symbol string, so this is kept -- confirm against the schema.
        ticker_id = ticker[1]
        ticker_name = ticker[2]
        ticker_data = get_2010_2015(ticker_id,ticker_name,con)


        # adfuller raises "maxlag should be < nobs" on short series
        if ticker_data.shape[0] < 100:
            continue

        print '========================================'
        print '========================================'
        print '========================================'
        # Hurst exponent well below 0.5 suggests mean reversion.
        t_hurst = hurst(ticker_data)
        if t_hurst < 0.3:
            all_hurst_data.append((ticker_id,ticker_name))
            print 'Hurst %s : %s' % (ticker_name,t_hurst)

        # ADF test: reject the unit-root null at the 5% level.
        t_adf = ts.adfuller(ticker_data,1)
        if t_adf[0] < t_adf[4]['5%']:
            all_adf_data.append((ticker_id,ticker_name))
            print 'ADF test %s : %s' % (ticker_name,t_adf)
        print '========================================'
        print '========================================'
        print '========================================'




    print all_hurst_data
    print all_adf_data
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/deal_data/deal_hs300.py:
--------------------------------------------------------------------------------
1 | # 股票价格的预测分类/回归模型
2 | # 1. price between 10-50
3 | # 2. 沪深三百内的
4 | # 3. average daily volume (ADV) in the middle 33 percentile
5 |
6 | import db
7 | from db import get_10_50_by_id,get_tickers_from_db,get_day_volumn_33_66
8 | import datetime
9 | import time
10 | import pandas as pd
11 | import numpy as np
12 | from pandas import DataFrame,Series
13 | import matplotlib.pyplot as plt
14 | import forecast
15 |
16 |
def get_33_and_66_volumn(date_list):
    """For each date, return (0.33*mean_volume, 0.66*mean_volume) as a dict
    keyed by 'YYYY-MM-DD' strings.

    NOTE(review): these are fractions of the cross-sectional *mean* volume,
    not the 33rd/66th percentiles the module header asks for -- confirm
    which was intended.
    """
    # Normalise timestamps to 'YYYY-MM-DD' strings.
    date_list = [str(date_list[i])[:10] for i in range(len(date_list))]
    date_obj = {}
    for i in range(len(date_list)):
        day = date_list[i]
        day_volume = np.mean(np.array(get_day_volumn_33_66(day)))
        # Skip non-trading days (no volume rows -> NaN mean).
        if np.isnan(day_volume) :
            continue
        t33 = int(day_volume * 0.33)
        t66 = int(day_volume * 0.66)
        date_obj[day] = (t33,t66)
        print day,day_volume
    return date_obj
30 |
31 |
32 | tickers = get_tickers_from_db()
33 | new_date_list = pd.date_range('1/1/2015', '12/31/2015', freq='1D')
34 | daily_volumn = get_33_and_66_volumn(new_date_list)
35 | selected_data = []
36 | best_performer = []
37 |
38 |
39 |
40 | for i in range(len(tickers)):
41 | this_ticker = tickers[i]
42 | ticker_id = this_ticker[0]
43 | ticker_name = this_ticker[1]
44 | data = get_10_50_by_id(ticker_id);
45 | g_data = []
46 | index_list = []
47 | for i in range(len(data)):
48 | date = data[i][0]
49 | date = str(date)[0:10]
50 | volume = data[i][5]
51 | volume_range = range(daily_volumn[date][0],daily_volumn[date][1])
52 | if volume > daily_volumn[date][0] and volume < daily_volumn[date][1]:
53 | g_data.append([data[i][1],data[i][2],data[i][3],data[i][4],data[i][5]])
54 | index_list.append(data[i][0])
55 | if len(g_data) < 50:
56 | continue
57 | t_ticker = DataFrame(g_data,index=index_list,dtype='float64',columns=['open','high','low','close','volume'])
58 | t_ticker = t_ticker.reindex(new_date_list,method='ffill').fillna(method='bfill')
59 | selected_data.append(t_ticker)
60 | #得到表现最好的5个feature下的数据
61 | t_ticker = forecast.get_good_feature(t_ticker)
62 | best_regression_r2 = forecast.get_regression_r2(t_ticker)
63 | best_classification_r2 = forecast.get_classification_r2(t_ticker)
64 |
65 | #取最好的前20个
66 | if best_regression_r2 > 0.85:
67 | this_ticker_node = (best_regression_r2,best_classification_r2,ticker_id)
68 | if len(best_performer) < 20 :
69 | best_performer.append(this_ticker_node)
70 | best_performer = sorted(best_performer)
71 | else:
72 | theLast_reg_r2 = best_performer[0][0]
73 | if best_regression_r2 > theLast_reg_r2 :
74 | best_performer = best_performer[1:]
75 | best_performer.append(this_ticker_node)
76 |
77 | print best_performer
78 | print '可利用的数据:',len(selected_data)
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/back_test_system/self/backtest.py:
--------------------------------------------------------------------------------
1 | #coding: utf-8
2 |
3 | from __future__ import print_function
4 |
5 | import datetime
6 | import pprint
7 | try:
8 | import Queue as queue
9 | except ImportError:
10 | import queue
11 | import time
12 |
class Backtest(object):
    """
    Event-driven backtest driver: wires the data handler, strategy,
    portfolio and execution handler together and pumps the event queue.
    """

    def __init__(self,csv_dir,symbol_list,initial_capital,hearbeat,start_date,data_handler,execution_handler,portfolio,strategy):
        """
        csv_dir / symbol_list - where the bar csvs live and which symbols to load.
        initial_capital - starting equity.
        hearbeat - seconds to sleep per outer loop iteration (sic: heartbeat).
        start_date - first simulated datetime.
        data_handler / execution_handler / portfolio / strategy - the four
        component *classes* (not instances) making up the system.
        """
        self.csv_dir = csv_dir
        self.symbol_list = symbol_list
        self.initial_capital = initial_capital
        self.hearbeat = hearbeat
        self.start_date = start_date

        self.data_handler_cls = data_handler
        self.execution_handler_cls = execution_handler
        self.portfolio_cls = portfolio
        self.strategy_cls = strategy

        self.events = queue.Queue()

        # Counters reported at the end of the run.
        self.signals = 0
        self.orders = 0
        self.fills = 0
        self.num_strats = 1
        self._generate_trading_instances()

    def _generate_trading_instances(self):
        """Instantiate the four collaborating components."""
        print("Creating DataHandler, Strategy, Portfolio and ExecutionHandler")

        self.data_handler = self.data_handler_cls(self.events,self.csv_dir,self.symbol_list)
        self.strategy = self.strategy_cls(self.data_handler,self.events)
        self.portfolio = self.portfolio_cls(self.data_handler,self.events,self.start_date,self.initial_capital)
        self.execution_handler = self.execution_handler_cls(self.events)

    def _run_backtest(self):
        """Outer loop pushes new bars; inner loop drains and dispatches events."""
        i = 0
        while True:
            i += 1
            print (i)
            if self.data_handler.continue_backtest == True:
                self.data_handler.update_bars()
            else:
                break

            while True:
                try:
                    event = self.events.get(False)
                except queue.Empty:
                    break
                else:
                    if event is not None:
                        if event.type == 'MARKET':
                            self.strategy.calculate_signals(event)
                            # BUG FIX: was `self.Portfolio` (AttributeError);
                            # the instance attribute is `self.portfolio`.
                            self.portfolio.update_timeindex(event)
                        elif event.type == 'SIGNAL':
                            self.signals += 1
                            self.portfolio.update_signal(event)
                        elif event.type == 'ORDER':
                            self.orders += 1
                            self.execution_handler.execute_order(event)
                        elif event.type == 'FILL':
                            self.fills += 1
                            # BUG FIX: `updatee_fill` was a typo for the
                            # portfolio's update_fill().
                            self.portfolio.update_fill(event)
            time.sleep(self.hearbeat)

    def _output_performance(self):
        """Print the tail of the equity curve and the summary statistics."""
        # BUG FIX: `self.Portfolio` -> `self.portfolio` throughout.
        self.portfolio.create_equity_curve_dataframe()
        print ('Creating summary stats...')
        stats = self.portfolio.output_summary_stats()

        print ('Creating equity curve...')
        print (self.portfolio.equity_curve.tail(10))
        pprint.pprint(stats)

        print ('Signal: %s' % self.signals)
        print ('Orders: %s' % self.orders)
        print ('Fills: %s' % self.fills)

    def simulate_trading(self):
        """Run the event loop, then report performance."""
        self._run_backtest()
        self._output_performance()
--------------------------------------------------------------------------------
/util/ml_util.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import numpy as np
3 | import pandas as pd
4 | from pandas import DataFrame,Series
5 | import datetime
6 | import matplotlib.pyplot as plt
7 | import Feature_utils
8 | #回归
9 | from sklearn.cross_validation import train_test_split
10 | from sklearn.metrics import classification_report
11 | from sklearn.svm import SVC
12 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars
13 | from sklearn.metrics import r2_score
14 | from sklearn import linear_model
15 | #分类
16 | from sklearn.ensemble import RandomForestClassifier
17 | from sklearn.linear_model import LogisticRegression
18 | from sklearn.lda import LDA
19 | from sklearn.metrics import confusion_matrix
20 | from sklearn.qda import QDA
21 | from sklearn.svm import LinearSVC, SVC
22 |
23 |
def get_regression_r2(ticker_data):
    """Fit several linear regressors to predict 'close', pick the best by r2
    on a held-out slice, plot the winner's predictions for the 10 first rows,
    and return the best (model, r2) pair.

    NOTE(review): X takes rows [10:] while y is close.shift(10).dropna();
    the exact temporal alignment between the two is subtle -- confirm it
    before reuse (the commented-out alternatives below suggest it was being
    experimented with).
    """
    data_len = len(ticker_data)
    # First 20% of positional rows become the test slice
    # (assumes the newest rows come first -- TODO confirm).
    split_line = int(data_len * 0.2)
    # Hold out the first 10 rows for the final plot.
    target_X = ticker_data.drop('close',1)[0:10]
    target_Y = ticker_data['close'][0:10]
    X = ticker_data.drop('close',1)[10:].dropna()
    y = ticker_data['close'].shift(10).dropna()


    # X = ticker_data.drop('close',1).shift(-10).dropna()
    # y = ticker_data['close'][:-10].dropna()

    # X = ticker_data.drop('close',1).dropna()
    # y = ticker_data['close'].dropna()

    X_test = X.ix[:split_line]
    X_train = X.ix[split_line:]
    y_test = y.ix[:split_line]
    y_train = y.ix[split_line:]

    print target_X

    models = [
        ('LR',LinearRegression()),
        ('RidgeR',Ridge (alpha = 0.005)),
        ('lasso',Lasso(alpha=0.00001)),
        ('LassoLars',LassoLars(alpha=0.00001))]

    # (model, r2) of the best scorer so far.
    best_r2 = (models[0][1],0)
    for m in models:
        m[1].fit(np.array(X_train),np.array(y_train))
        pred = m[1].predict(X_test.fillna(0))
        r2 = r2_score(y_test,pred)
        if r2 > best_r2[1]:
            best_r2 = (m[1],r2)
        # print "%s:\n%0.3f" % (m[0], r2_score(np.array(y_test),np.array(pred)))

    print 'the best regression is:',best_r2

    # Refit the winner on the training slice and plot its predictions for
    # the held-out target rows.
    model = best_r2[0]
    model.fit(X_train, y_train)
    pred = model.predict(target_X.fillna(0))
    pred_test = pd.Series(pred, index=target_X.index)

    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.plot(target_Y,'r',lw=0.75,linestyle='-',label='realY')
    ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY')

    plt.legend(loc=2,prop={'size':9})
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.grid(True)
    plt.show()

    return best_r2
79 |
80 |
def get_classification_r2(ticker_data):
    """Train several classifiers to predict next-day 'close' (encoded as a
    '|S6' byte-string label) and return (name, accuracy) of the best scorer
    on the positional 80/20 split.

    NOTE(review): treating a shifted float close as a class label makes
    nearly every label unique, so accuracy will usually be near zero --
    confirm the intended target encoding.
    """
    data_len = len(ticker_data)
    # First 80% of rows train, the remainder test (positional .ix split).
    split_line = int(data_len * 0.8)
    # Predict tomorrow's close from today's features.
    X = ticker_data.drop('close',1)[:-1]
    y = Series(ticker_data['close'].shift(-1).dropna(),dtype='|S6')

    X_train = X.ix[:split_line]
    X_test = X.ix[split_line:]
    y_train = y.ix[:split_line]
    y_test = y.ix[split_line:]


    models = [("LR", LogisticRegression()),
              ("LDA", LDA()),
              ("LSVC", LinearSVC()),
              ("RSVM", SVC(
              C=1000000.0, cache_size=200, class_weight=None,
              coef0=0.0, degree=3, gamma=0.0001, kernel='rbf',
              max_iter=-1, probability=False, random_state=None,
              shrinking=True, tol=0.001, verbose=False)
              ),
              ("RF", RandomForestClassifier(
              n_estimators=1000, criterion='gini',
              max_depth=None, min_samples_split=2,
              min_samples_leaf=1, max_features='auto',
              bootstrap=True, oob_score=False, n_jobs=1,
              random_state=None, verbose=0)
              )]

    # (name, accuracy) of the best classifier so far.
    best = (0,0)
    for m in models:
        m[1].fit(X_train, y_train)
        pred = m[1].predict(X_test)
        name = m[0]
        score = m[1].score(X_test, y_test)
        if score > best[1]:
            best = (name,score)
    print 'the best cluster is:' , best
    return best
122 |
123 |
124 |
--------------------------------------------------------------------------------
/back_test_system/mac.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import numpy as np
3 |
4 | from backtest import Backtest
5 | from data import HistoricCSVDataHandler
6 | from event import SignalEvent
7 | from execution import SimulatedExecutionHandler
8 | from portfolio import Portfolio
9 | from strategy import Strategy
10 |
11 |
class MovingAverageCrossStrategy(Strategy):
    """
    Classic SMA crossover strategy: go long a symbol when its short-window
    simple moving average rises above the long-window one, and exit when it
    falls back below.  Default windows are 100/400 periods.
    """

    def __init__(self, bars, events, short_window=100, long_window=400):
        """
        Parameters:
        bars - the DataHandler supplying bar information.
        events - the shared Event Queue object.
        short_window / long_window - SMA lookbacks, in bars.
        """
        self.bars = bars
        self.symbol_list = self.bars.symbol_list
        self.events = events
        self.short_window = short_window
        self.long_window = long_window

        # Tracks whether each symbol is currently held ('LONG') or not ('OUT').
        self.bought = self._calculate_initial_bought()

    def _calculate_initial_bought(self):
        """Every symbol starts out of the market."""
        return dict((s, 'OUT') for s in self.symbol_list)

    def calculate_signals(self, event):
        """
        On each MARKET event, compare the two SMAs per symbol and emit a
        LONG SignalEvent on an upward cross or an EXIT on a downward cross.

        Parameters
        event - a MarketEvent object.
        """
        if event.type != 'MARKET':
            return
        for symbol in self.symbol_list:
            bars = self.bars.get_latest_bars_values(symbol, "close", N=self.long_window)

            if bars is not None and bars != []:
                sma_short = np.mean(bars[-self.short_window:])
                sma_long = np.mean(bars[-self.long_window:])

                bar_dt = self.bars.get_latest_bar_datetime(symbol)
                strength = 1.0
                strategy_id = 1

                if sma_short > sma_long and self.bought[symbol] == "OUT":
                    sig_dir = 'LONG'
                    self.events.put(SignalEvent(strategy_id, symbol, bar_dt, sig_dir, strength))
                    self.bought[symbol] = 'LONG'

                elif sma_short < sma_long and self.bought[symbol] == "LONG":
                    sig_dir = 'EXIT'
                    self.events.put(SignalEvent(strategy_id, symbol, bar_dt, sig_dir, strength))
                    self.bought[symbol] = 'OUT'
81 |
82 |
if __name__ == "__main__":
    # Wire up a single-symbol SMA-crossover backtest over the bundled csvs.
    csv_dir = 'data'
    symbol_list = ['600050']
    initial_capital = 100000.0
    start_date = datetime.datetime(1990,1,1,0,0,0)
    heartbeat = 0.0

    backtest = Backtest(
        csv_dir, symbol_list, initial_capital, heartbeat, start_date,
        HistoricCSVDataHandler, SimulatedExecutionHandler, Portfolio,
        MovingAverageCrossStrategy)

    backtest.simulate_trading()
101 |
--------------------------------------------------------------------------------
/deal_data/FeatureUtils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | import numpy as np
4 | from pandas import DataFrame,Series
5 | import matplotlib.pyplot as plt
6 | import tushare as ts
7 | from sklearn.datasets import make_classification
8 | from sklearn.ensemble import ExtraTreesClassifier
9 |
10 |
11 | #CCI
def CCI(data,ndays):
    """Append an ndays Commodity Channel Index column ('CCI') to *data*.

    Uses the typical price (high+low+close)/3 and Lambert's 0.015 scaling
    constant.
    """
    TP = (data['high'] + data['low'] + data['close'])/3
    # FIX: pd.rolling_mean / pd.rolling_std were removed from pandas; the
    # Series.rolling API (pandas >= 0.18) computes the same values.
    CCI = pd.Series((TP - TP.rolling(ndays).mean()) / (0.015 * TP.rolling(ndays).std()),name='CCI')
    data = data.join(CCI)
    return data
17 |
18 | #timeLag
def TL(data,ndays):
    """Append a time-lag momentum column 'TL' to *data*: today's open change
    divided by the ndays-period high-low range (back-filled to the original
    index).
    """
    index = data.index
    # BUG FIX: the period low was computed with .max(); use .min() so that
    # pH - pL is the true high-low range.  .bfill() replaces the deprecated
    # fillna(method='bfill').
    pH = data['high'].resample(str(ndays) + 'D').max().reindex(index).bfill()
    pL = data['low'].resample(str(ndays) + 'D').min().reindex(index).bfill()
    pO = data['open'] - data['open'].shift(1)
    timeLag = pO/(pH - pL)
    timeLag.name = 'TL'
    data = data.join(timeLag)
    return data
28 |
29 |
# Ease of Movement
def EVM(data, ndays):
    """Append an 'EVM' column: `ndays` rolling mean of the Ease of
    Movement indicator (midpoint move / volume-scaled box ratio).

    Requires 'high', 'low' and 'volume' columns.
    """
    dm = ((data['high'] + data['low'])/2) - ((data['high'].shift(1) + data['low'].shift(1))/2)
    br = (data['volume']/100000000)/((data['high'] - data['low']))
    EVM = dm/br
    # pd.rolling_mean was removed from pandas; rolling().mean() is identical.
    EVM_MA = pd.Series(EVM.rolling(ndays).mean(), name='EVM')
    data = data.join(EVM_MA)
    return data
38 |
# Simple Moving Average
def SMA(data, ndays):
    """Append an 'SMA' column: `ndays` simple moving average of 'close'."""
    # pd.rolling_mean was removed from pandas; rolling().mean() is identical.
    SMA = pd.Series(data['close'].rolling(ndays).mean(), name = 'SMA')
    data = data.join(SMA)
    return data
44 |
# Exponentially-weighted Moving Average
def EWMA(data, ndays):
    """Append an 'EWMA_<ndays>' column: span-`ndays` exponentially
    weighted moving average of 'close' (needs ndays-1 prior values)."""
    # pd.ewma was removed from pandas; Series.ewm(...).mean() with the
    # same span/min_periods (default adjust=True) is identical.
    EMA = pd.Series(data['close'].ewm(span = ndays, min_periods = ndays - 1).mean(),
                    name = 'EWMA_' + str(ndays))
    data = data.join(EMA)
    return data
51 |
52 |
# Rate of Change (ROC)
def ROC(data, n):
    """Append a 'Rate of Change' column: n-period change in 'close'
    relative to the close n periods ago."""
    previous_close = data['close'].shift(n)
    change = data['close'].diff(n)
    roc_series = pd.Series(change / previous_close, name='Rate of Change')
    return data.join(roc_series)
60 |
# Force Index
def ForceIndex(data, ndays):
    """Append a 'ForceIndex' column: ndays close change times volume."""
    force = (data['close'].diff(ndays) * data['volume']).rename('ForceIndex')
    return data.join(force)
66 |
# Compute the Bollinger Bands
def BBANDS(data, ndays):
    """Append 'Upper BollingerBand' and 'Lower BollingerBand' columns:
    the `ndays` rolling mean of 'close' +/- two rolling standard
    deviations."""
    # pd.rolling_mean/pd.rolling_std were removed from pandas;
    # rolling().mean()/.std() compute identical values.
    MA = pd.Series(data['close'].rolling(ndays).mean())
    SD = pd.Series(data['close'].rolling(ndays).std())
    b1 = MA + (2 * SD)
    B1 = pd.Series(b1, name = 'Upper BollingerBand')
    b2 = MA - (2 * SD)
    B2 = pd.Series(b2, name = 'Lower BollingerBand')
    data = data.join([B1,B2])
    return data
77 |
78 |
79 |
def plotTwoData(data1,data2):
    """Plot two aligned series in stacked subplots: *data1* (price) on
    top, *data2* (indicator values) below. Blocks until the window is
    closed; pyplot's implicit current-axes state drives the calls, so
    their order matters.
    """
    fig = plt.figure(figsize=(7,5))
    ax = fig.add_subplot(2, 1, 1)
    # hide the top panel's x labels; the bottom panel carries the dates
    ax.set_xticklabels([])
    plt.plot(data1,lw=1)
    plt.title(str(data1.name) + 'Price Chart')
    plt.ylabel('Close Price')
    plt.grid(True)
    bx = fig.add_subplot(2, 1, 2)
    # NOTE(review): legend label is hard-coded to 'CCI' even though any
    # indicator series may be passed in
    plt.plot(data2,'k',lw=0.75,linestyle='-',label='CCI')
    plt.legend(loc=2,prop={'size':9.5})
    plt.ylabel(str(data2.name) + 'values')
    plt.grid(True)
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.show()
95 |
96 |
97 |
def toDatetime(data):
    """Convert the frame's index to a DatetimeIndex in place; returns the
    same frame for chaining."""
    converted = pd.to_datetime(data.index)
    data.index = converted
    return data
101 |
102 |
103 |
def forestFindFeature(X,y,n):
    """Rank X's columns by extra-trees feature importance, best first.

    Parameters:
    X - DataFrame of candidate feature columns.
    y - class labels aligned with X's rows.
    n - number of trees in the ensemble.

    Returns the list of column names sorted by descending importance.
    """
    forest = ExtraTreesClassifier(n_estimators=n,random_state=0)
    forest.fit(X, y)
    importances = forest.feature_importances_
    # per-tree spread of the importances (computed for parity; unused)
    std = np.std([tree.feature_importances_ for tree in forest.estimators_],axis=0)
    ranking = np.argsort(importances)[::-1]

    print("Feature ranking:")

    column_names = X.columns
    features = [column_names[int(pos)] for pos in ranking]
    return features
121 |
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/back_test_system/self/data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from abc import ABCMeta,abstractmethod
3 |
4 | import datetime
5 | import os,os.path
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from event import MarketEvent
11 |
class DataHandler(object):
    """Abstract interface every market-data handler must implement so the
    backtest can consume live or historical bars interchangeably."""
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_latest_bar(self,symbol):
        raise NotImplementedError("Should implement get_latest_bar()")

    @abstractmethod
    def get_latest_bars(self,symbol,N=1):
        raise NotImplementedError("Should implement get_latest_bars()")

    @abstractmethod
    def get_latest_bar_datetime(self,symbol):
        raise NotImplementedError("Should implement get_latest_bar_datetime()")

    @abstractmethod
    def get_latest_bar_value(self,symbol,val_type):
        raise NotImplementedError("Should implement get_latest_bar_value()")

    # BUG FIX: renamed from get_latest_bars_value so the abstract name
    # matches both its own error message and the concrete implementation
    # in HistoricCSVDataHandler (get_latest_bars_values).
    @abstractmethod
    def get_latest_bars_values(self,symbol,val_type,N=1):
        raise NotImplementedError("Should implement get_latest_bars_values()")

    # BUG FIX: was the only interface method missing @abstractmethod.
    @abstractmethod
    def update_bars(self):
        raise NotImplementedError("Should implement update_bars()")
36 |
37 |
class HistoricCSVDataHandler(DataHandler):
    """DataHandler that replays historical bars from one CSV file per
    symbol ('<symbol>.csv' inside csv_dir), delivering them to the
    backtest one bar at a time via update_bars()."""

    def __init__(self,events,csv_dir,symbol_list):
        """
        Parameters:
        events - shared event queue; a MarketEvent is put on it for
                 every round of new bars.
        csv_dir - directory containing the '<symbol>.csv' files.
        symbol_list - list of symbol strings to load.
        """
        self.events = events
        self.csv_dir = csv_dir
        self.symbol_list = symbol_list

        self.symbol_data = {}          # symbol -> iterator of (datetime, row)
        self.latest_symbol_data = {}   # symbol -> bars delivered so far
        self.continue_backtest = True
        self.bar_index = 0

        self._open_convert_csv_files()

    def _open_convert_csv_files(self):
        """Load each symbol's CSV, union all date indexes, and re-index
        every frame onto the combined timeline (forward-filling gaps) so
        all symbols advance in lock-step."""
        comb_index = None
        for s in self.symbol_list:
            # DataFrame.sort() was removed from pandas; sort_index()
            # orders by the parsed datetime index.
            self.symbol_data[s] = pd.read_csv(
                os.path.join(self.csv_dir,'%s.csv' % s),
                header=0,index_col=0,parse_dates=True,
                names=['datetime','open','high','low','close','volume','adj_close']
            ).sort_index()
            if comb_index is None:
                comb_index = self.symbol_data[s].index
            else:
                # BUG FIX: Index.union returns a new Index; the original
                # discarded the result, so only the first symbol's dates
                # were ever used for the combined timeline.
                comb_index = comb_index.union(self.symbol_data[s].index)
            self.latest_symbol_data[s] = []

        for s in self.symbol_list:
            self.symbol_data[s] = self.symbol_data[s].reindex(index=comb_index,method='pad').iterrows()

    def _get_new_bar(self,symbol):
        """Yield the remaining bars for symbol as (datetime, row) tuples."""
        for b in self.symbol_data[symbol]:
            yield b

    def get_latest_bar(self,symbol):
        """Return the most recent (datetime, row) bar for symbol."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-1]

    def get_latest_bars(self,symbol,N=1):
        """Return the last N bars (fewer if less have been delivered)."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print('That symbol is not available in the historical data set.')
            raise
        else:
            return bars_list[-N:]

    def get_latest_bar_datetime(self,symbol):
        """Return the datetime of the most recent bar."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-1][0]

    def get_latest_bar_value(self,symbol,val_type):
        """Return one field (e.g. 'close') of the most recent bar."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return getattr(bars_list[-1][1],val_type)

    def get_latest_bars_values(self,symbol,val_type,N=1):
        """Return an ndarray of one field from each of the last N bars."""
        try:
            bars_list = self.get_latest_bars(symbol,N)
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return np.array([getattr(b[1],val_type) for b in bars_list])

    def update_bars(self):
        """Push the next bar of every symbol into latest_symbol_data and
        emit a MarketEvent; flips continue_backtest off once any symbol's
        data is exhausted."""
        for s in self.symbol_list:
            try:
                bar = next(self._get_new_bar(s))
            except StopIteration:
                self.continue_backtest = False
            else:
                if bar is not None:
                    self.latest_symbol_data[s].append(bar)
        self.events.put(MarketEvent())
--------------------------------------------------------------------------------
/deal_data/forecast.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import numpy as np
3 | import pandas as pd
4 | from pandas import DataFrame,Series
5 | import datetime
6 | import matplotlib.pyplot as plt
7 | import FeatureUtils
8 | #回归
9 | from sklearn.cross_validation import train_test_split
10 | from sklearn.metrics import classification_report
11 | from sklearn.svm import SVC
12 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars
13 | from sklearn.metrics import r2_score
14 | from sklearn import linear_model
15 | #分类
16 | from sklearn.ensemble import RandomForestClassifier
17 | from sklearn.linear_model import LogisticRegression
18 | from sklearn.lda import LDA
19 | from sklearn.metrics import confusion_matrix
20 | from sklearn.qda import QDA
21 | from sklearn.svm import LinearSVC, SVC
22 |
23 |
24 |
def get_good_feature(ticker_data, n=11):
    """Enrich ticker_data with technical indicators and keep the `n`
    most important feature columns plus 'close'.

    Parameters:
    ticker_data - DataFrame with open/high/low/close/volume columns and
                  a DatetimeIndex (required by the TL indicator).
    n - number of top-ranked features to keep; defaults to 11, the
        count the original hard-coded (now parameterised for
        consistency with util/Feature_utils.get_good_feature).
    """
    ticker_data = FeatureUtils.CCI(ticker_data,10)
    ticker_data = FeatureUtils.TL(ticker_data,10)
    ticker_data = FeatureUtils.EVM(ticker_data,10)
    ticker_data = FeatureUtils.SMA(ticker_data,10)
    ticker_data = FeatureUtils.EWMA(ticker_data,10)
    ticker_data = FeatureUtils.ROC(ticker_data,10)
    ticker_data = FeatureUtils.ForceIndex(ticker_data,10)
    ticker_data = FeatureUtils.BBANDS(ticker_data,10)
    ticker_data = ticker_data.dropna()
    # normalise every column to a comparable scale
    ticker_data = (ticker_data - ticker_data.mean())/(ticker_data.max() - ticker_data.min())
    # features of day t ...
    X = DataFrame(ticker_data.drop('close',1).fillna(0),dtype='float64')[:-1]
    # ... paired with the close of day t+1 (shift(-1) aligns the next
    # day's close onto today's row)
    y = Series(ticker_data['close'].shift(-1).dropna(),dtype='|S6')
    # let an extra-trees forest rank the features; keep the top n
    features = FeatureUtils.forestFindFeature(X,y,100)[:n]

    ticker_data = ticker_data[features].join(ticker_data['close'])
    return ticker_data
46 |
def get_regression_r2(ticker_data):
    """Fit several linear models to predict the next day's close and
    return (model_name, r2) for the best performer.

    ticker_data must contain a 'close' column plus feature columns; the
    last 20% of rows form the held-out test set. Also plots the best
    model's prediction against the held-out series (blocking).
    """
    data_len = len(ticker_data)
    split_line = int(data_len * 0.8)
    X = ticker_data.drop('close',1)[:-1]
    y = ticker_data['close'].shift(-1).dropna()

    X_train = X.ix[:split_line]
    X_test = X.ix[split_line:]
    y_train = y.ix[:split_line]
    y_test = y.ix[split_line:]

    models = [
        ('LR',LinearRegression()),
        ('RidgeR',Ridge (alpha = 0.005)),
        ('lasso',Lasso(alpha=0.00001)),
        ('LassoLars',LassoLars(alpha=0.00001))]

    best_r2 = (0,0)
    best_model = None
    for m in models:
        m[1].fit(np.array(X_train),np.array(y_train))
        # predictions land one day later than y_test's index, so X is
        # advanced one row to line the dates up
        pred = m[1].predict(X_test.shift(-1).fillna(0))
        r2 = r2_score(y_test,pred)
        if r2 > best_r2[1]:
            best_r2 = (m[0],r2)
            best_model = m[1]

    # robustness: if nothing beat r2 = 0, fall back to the first model
    if best_model is None:
        best_model = models[0][1]
        best_r2 = (models[0][0], best_r2[1])

    print('the best regression is: %s' % (best_r2,))

    # BUG FIX: the original assigned best_r2[0] (the model *name*
    # string) to `model` and then called .fit() on it, which raised
    # AttributeError; the fitted estimator object is kept instead.
    model = best_model
    model.fit(X_train, y_train)
    pred = model.predict(X_test.shift(-1).fillna(0))
    pred_test = pd.Series(pred, index=y_test.index)

    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.plot(y_test,'r',lw=0.75,linestyle='-',label='realY')
    ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY')
    plt.legend(loc=2,prop={'size':9})
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.grid(True)
    plt.show()

    return best_r2
92 |
def get_classification_r2(ticker_data):
    """Train several classifiers to predict the next day's close (cast
    to fixed-width '|S6' strings, i.e. discretised price levels) and
    return (model_name, score) for the best one on the held-out last
    20% of rows. Python 2 only (print statement below).
    """

    data_len = len(ticker_data)
    split_line = int(data_len * 0.8)
    X = ticker_data.drop('close',1)[:-1]
    # shift(-1) aligns the next day's close with today's feature row
    y = Series(ticker_data['close'].shift(-1).dropna(),dtype='|S6')

    X_train = X.ix[:split_line]
    X_test = X.ix[split_line:]
    y_train = y.ix[:split_line]
    y_test = y.ix[split_line:]


    models = [("LR", LogisticRegression()),
        ("LDA", LDA()),
        # ("QDA", QDA()),
        ("LSVC", LinearSVC()),
        ("RSVM", SVC(
        C=1000000.0, cache_size=200, class_weight=None,
        coef0=0.0, degree=3, gamma=0.0001, kernel='rbf',
        max_iter=-1, probability=False, random_state=None,
        shrinking=True, tol=0.001, verbose=False)
        ),
        ("RF", RandomForestClassifier(
        n_estimators=1000, criterion='gini',
        max_depth=None, min_samples_split=2,
        min_samples_leaf=1, max_features='auto',
        bootstrap=True, oob_score=False, n_jobs=1,
        random_state=None, verbose=0)
        )]

    best = (0,0)
    # keep whichever model scores highest on the held-out set
    for m in models:
        m[1].fit(X_train, y_train)
        pred = m[1].predict(X_test)
        name = m[0]
        score = m[1].score(X_test, y_test)
        if score > best[1]:
            best = (name,score)
    print 'the best cluster is:' , best
    return best
135 |
136 |
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/util/Feature_utils.py:
--------------------------------------------------------------------------------
1 | #coding: utf-8
2 | import pandas as pd
3 | import numpy as np
4 | from pandas import DataFrame,Series
5 | import matplotlib.pyplot as plt
6 | import tushare as ts
7 | from sklearn.datasets import make_classification
8 | from sklearn.ensemble import ExtraTreesClassifier
9 |
10 |
# Commodity Channel Index
def CCI(data, ndays):
    """Append a 'CCI' column (Commodity Channel Index over `ndays` bars).

    Requires 'high', 'low' and 'close' columns; returns the frame with
    the extra column joined on.
    """
    TP = (data['high'] + data['low'] + data['close'])/3
    # pd.rolling_mean/pd.rolling_std were removed from pandas; the
    # Series.rolling() equivalents below compute identical values.
    CCI = pd.Series((TP - TP.rolling(ndays).mean()) / (0.015 * TP.rolling(ndays).std()),
                    name='CCI')
    data = data.join(CCI)
    return data
17 |
# timeLag: one-day open change relative to the ndays-period high/low range
def TL(data, ndays):
    """Append a 'TL' column: open-to-open change divided by the
    resampled `ndays`-day high/low range (back-filled to daily).
    Requires a DatetimeIndex (resample).
    """
    index = data.index
    pH = data['high'].resample(str(ndays) + 'D').max().reindex(index).fillna(method='bfill')
    # BUG FIX: the period low must be the minimum; the original copied
    # .max() from the pH line, making (pH - pL) an incorrect range.
    pL = data['low'].resample(str(ndays) + 'D').min().reindex(index).fillna(method='bfill')
    pO = data['open'] - data['open'].shift(1)
    timeLag = pO/(pH - pL)
    timeLag.name = 'TL'
    data = data.join(timeLag)
    return data
28 |
29 |
# Ease of Movement
def EVM(data, ndays):
    """Append an 'EVM' column: `ndays` rolling mean of the Ease of
    Movement indicator. Requires 'high', 'low' and 'volume' columns."""
    dm = ((data['high'] + data['low'])/2) - ((data['high'].shift(1) + data['low'].shift(1))/2)
    br = (data['volume']/100000000)/((data['high'] - data['low']))
    EVM = dm/br
    # pd.rolling_mean was removed from pandas; rolling().mean() is identical.
    EVM_MA = pd.Series(EVM.rolling(ndays).mean(), name='EVM')
    data = data.join(EVM_MA)
    return data
38 |
# Simple Moving Average
def SMA(data, ndays):
    """Append an 'SMA' column: `ndays` simple moving average of 'close'."""
    # pd.rolling_mean was removed from pandas; rolling().mean() is identical.
    SMA = pd.Series(data['close'].rolling(ndays).mean(), name = 'SMA')
    data = data.join(SMA)
    return data
44 |
# Exponentially-weighted Moving Average
def EWMA(data, ndays):
    """Append an 'EWMA_<ndays>' column: span-`ndays` exponentially
    weighted moving average of 'close' (needs ndays-1 prior values)."""
    # pd.ewma was removed from pandas; Series.ewm(...).mean() with the
    # same span/min_periods (default adjust=True) is identical.
    EMA = pd.Series(data['close'].ewm(span = ndays, min_periods = ndays - 1).mean(),
                    name = 'EWMA_' + str(ndays))
    data = data.join(EMA)
    return data
51 |
52 |
# Rate of Change (ROC)
def ROC(data, n):
    """Append a 'Rate of Change' column: n-period change in 'close'
    relative to the close n periods ago."""
    previous_close = data['close'].shift(n)
    change = data['close'].diff(n)
    roc_series = pd.Series(change / previous_close, name='Rate of Change')
    return data.join(roc_series)
60 |
# Force Index
def ForceIndex(data, ndays):
    """Append a 'ForceIndex' column: ndays close change times volume."""
    force = (data['close'].diff(ndays) * data['volume']).rename('ForceIndex')
    return data.join(force)
66 |
# Compute the Bollinger Bands
def BBANDS(data, ndays):
    """Append 'Upper BollingerBand' and 'Lower BollingerBand' columns:
    the `ndays` rolling mean of 'close' +/- two rolling standard
    deviations."""
    # pd.rolling_mean/pd.rolling_std were removed from pandas;
    # rolling().mean()/.std() compute identical values.
    MA = pd.Series(data['close'].rolling(ndays).mean())
    SD = pd.Series(data['close'].rolling(ndays).std())
    b1 = MA + (2 * SD)
    B1 = pd.Series(b1, name = 'Upper BollingerBand')
    b2 = MA - (2 * SD)
    B2 = pd.Series(b2, name = 'Lower BollingerBand')
    data = data.join([B1,B2])
    return data
77 |
78 |
79 |
def plotTwoData(data1,data2):
    """Plot two aligned series in stacked subplots: *data1* (price) on
    top, *data2* (indicator values) below. Blocks until the window is
    closed; pyplot's implicit current-axes state drives the calls, so
    their order matters.
    """
    fig = plt.figure(figsize=(7,5))
    ax = fig.add_subplot(2, 1, 1)
    # hide the top panel's x labels; the bottom panel carries the dates
    ax.set_xticklabels([])
    plt.plot(data1,lw=1)
    plt.title(str(data1.name) + 'Price Chart')
    plt.ylabel('Close Price')
    plt.grid(True)
    bx = fig.add_subplot(2, 1, 2)
    # NOTE(review): legend label is hard-coded to 'CCI' even though any
    # indicator series may be passed in
    plt.plot(data2,'k',lw=0.75,linestyle='-',label='CCI')
    plt.legend(loc=2,prop={'size':9.5})
    plt.ylabel(str(data2.name) + 'values')
    plt.grid(True)
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.show()
95 |
96 |
97 |
def toDatetime(data):
    """Convert the frame's index to a DatetimeIndex in place; returns the
    same frame for chaining."""
    converted = pd.to_datetime(data.index)
    data.index = converted
    return data
101 |
102 |
103 |
def forestFindFeature(X,y,n):
    """Rank X's columns by extra-trees feature importance, best first.

    Parameters:
    X - DataFrame of candidate feature columns.
    y - class labels aligned with X's rows.
    n - number of trees in the ensemble.

    Returns the list of column names sorted by descending importance.
    Python 2 only (print statement in the loop).
    """
    # Build a forest and compute the feature importances
    forest = ExtraTreesClassifier(n_estimators=n,random_state=0)
    forest.fit(X, y)
    importances = forest.feature_importances_
    # per-tree spread of the importances (computed but never used)
    std = np.std([tree.feature_importances_ for tree in forest.estimators_],axis=0)
    indices = np.argsort(importances)[::-1]

    #Print the feature ranking
    print("Feature ranking:")

    x_columns = X.columns
    features = []
    for f in range(X.shape[1]):
        features.append(x_columns[int(indices[f])])
        # rank, column position, column name and its importance score
        print f,indices[f],x_columns[int(indices[f])],'===========', importances[indices[f]]
    return features
121 |
122 |
# keep only the best few features
def get_good_feature(ticker_data,n):
    """Augment ticker_data with technical indicators, normalise, and
    keep the n most important feature columns (plus 'close') as ranked
    by an extra-trees forest. Python 2 only (print statement).
    """
    ticker_data = CCI(ticker_data,10)
    ticker_data = TL(ticker_data,10)
    ticker_data = EVM(ticker_data,10)
    ticker_data = SMA(ticker_data,10)
    ticker_data = EWMA(ticker_data,10)
    ticker_data = ROC(ticker_data,10)
    ticker_data = ForceIndex(ticker_data,10)
    ticker_data = BBANDS(ticker_data,10)
    # debug dump of the enriched frame
    print ticker_data

    ticker_data = ticker_data.fillna(method='ffill').fillna(method='bfill').dropna()

    # normalise every column to a comparable scale
    ticker_data = (ticker_data - ticker_data.mean())/(ticker_data.max() - ticker_data.min())
    #get today and next day
    X = DataFrame(ticker_data.drop('close',1).fillna(0),dtype='float64')
    # NOTE(review): unlike deal_data/forecast.py, y is NOT shifted by one
    # day here even though the original comment claimed next-day
    # alignment — confirm the intended target.
    y = Series(ticker_data['close'].dropna(),dtype='|S6')

    # forest ranks the features; keep the top n
    features = forestFindFeature(X,y,100)[:n]

    ticker_data = ticker_data[features].join(ticker_data['close'])
    return ticker_data
149 |
150 |
151 |
152 |
153 |
--------------------------------------------------------------------------------
/back_test_system/backtest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # backtest.py
5 |
6 | from __future__ import print_function
7 |
8 | import datetime
9 | import pprint
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 | import time
15 |
16 |
class Backtest(object):
    """
    Encapsulates the settings and components for carrying out
    an event-driven backtest.
    """

    def __init__(
        self, csv_dir, symbol_list, initial_capital,
        heartbeat, start_date, data_handler,
        execution_handler, portfolio, strategy
    ):
        """
        Initialises the backtest.

        Parameters:
        csv_dir - The hard root to the CSV data directory.
        symbol_list - The list of symbol strings.
        initial_capital - The starting capital for the portfolio.
        heartbeat - Backtest "heartbeat" in seconds
        start_date - The start datetime of the strategy.
        data_handler - (Class) Handles the market data feed.
        execution_handler - (Class) Handles the orders/fills for trades.
        portfolio - (Class) Keeps track of portfolio current and prior positions.
        strategy - (Class) Generates signals based on market data.
        """
        self.csv_dir = csv_dir
        self.symbol_list = symbol_list
        self.initial_capital = initial_capital
        self.heartbeat = heartbeat
        self.start_date = start_date

        self.data_handler_cls = data_handler
        self.execution_handler_cls = execution_handler
        self.portfolio_cls = portfolio
        self.strategy_cls = strategy

        self.events = queue.Queue()

        # event counters reported at the end of the run
        self.signals = 0
        self.orders = 0
        self.fills = 0
        self.num_strats = 1

        self._generate_trading_instances()

    def _generate_trading_instances(self):
        """
        Generates the trading instance objects from
        their class types.
        """
        print(
            "Creating DataHandler, Strategy, Portfolio and ExecutionHandler"
        )
        self.data_handler = self.data_handler_cls(self.events, self.csv_dir, self.symbol_list)
        self.strategy = self.strategy_cls(self.data_handler, self.events)
        self.portfolio = self.portfolio_cls(self.data_handler, self.events, self.start_date,
                                            self.initial_capital)
        self.execution_handler = self.execution_handler_cls(self.events)

    def _run_backtest(self):
        """
        Executes the backtest: pump one round of market bars, then
        drain the event queue, dispatching each event to the component
        that handles it, until the data handler is exhausted.
        (The per-iteration debug print of a bar counter was removed.)
        """
        while True:
            # Update the market bars
            if self.data_handler.continue_backtest:
                self.data_handler.update_bars()
            else:
                break

            # Handle the events produced by this round of bars
            while True:
                try:
                    event = self.events.get(False)
                except queue.Empty:
                    break
                else:
                    if event is not None:
                        if event.type == 'MARKET':
                            self.strategy.calculate_signals(event)
                            self.portfolio.update_timeindex(event)

                        elif event.type == 'SIGNAL':
                            self.signals += 1
                            self.portfolio.update_signal(event)

                        elif event.type == 'ORDER':
                            self.orders += 1
                            self.execution_handler.execute_order(event)

                        elif event.type == 'FILL':
                            self.fills += 1
                            self.portfolio.update_fill(event)

            time.sleep(self.heartbeat)

    def _output_performance(self):
        """
        Outputs the strategy performance from the backtest.
        """
        self.portfolio.create_equity_curve_dataframe()

        print("Creating summary stats...")
        stats = self.portfolio.output_summary_stats()

        print("Creating equity curve...")
        print(self.portfolio.equity_curve.tail(10))
        pprint.pprint(stats)

        print("Signals: %s" % self.signals)
        print("Orders: %s" % self.orders)
        print("Fills: %s" % self.fills)

    def simulate_trading(self):
        """
        Simulates the backtest and outputs portfolio performance.
        """
        self._run_backtest()
        self._output_performance()
--------------------------------------------------------------------------------
/get_data/find_pairs.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import matplotlib.pyplot as plt
3 | import MySQLdb as mdb
4 | import datetime
5 | import numpy as np
6 | from pandas import Series,DataFrame
7 | import matplotlib.dates as mdates
8 | import pandas as pd
9 | from matplotlib.collections import LineCollection
10 | from sklearn import cluster, covariance, manifold
11 | from pandas.stats.api import ols
12 | import statsmodels.tsa.stattools as ts
13 | import sys
14 | reload(sys) # Python2.5 初始化后会删除 sys.setdefaultencoding 这个方法,我们需要重新载入
15 | sys.setdefaultencoding('utf-8')
16 |
17 | #找到有协整关系的pairs
18 |
def plot_price_series(df, ts1, ts2):
    """Plot the two price columns ts1 and ts2 of df on one axis with
    monthly date ticks. The x axis is fixed to May-Oct 2016, matching
    the sample window loaded in __main__. Blocks until closed.
    """
    months = mdates.MonthLocator()  # every month
    fig, ax = plt.subplots()
    ax.plot(df.index, df[ts1], label=ts1)
    ax.plot(df.index, df[ts2], label=ts2)
    ax.xaxis.set_major_locator(months)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
    ax.set_xlim(datetime.datetime(2016, 5, 1), datetime.datetime(2016, 10, 1))
    ax.grid(True)
    fig.autofmt_xdate()

    plt.xlabel('Month/Year')
    plt.ylabel('Price ($)')
    plt.title('%s and %s Daily Prices' % (ts1, ts2))
    plt.legend()
    plt.show()
35 |
def plot_scatter_series(df, ts1, ts2):
    """Scatter-plot df[ts1] against df[ts2]. Blocks until closed."""
    plt.xlabel('%s Price ($)' % ts1)
    plt.ylabel('%s Price ($)' % ts2)
    plt.title('%s and %s Price Scatterplot' % (ts1, ts2))
    plt.scatter(df[ts1], df[ts2])
    plt.show()
42 |
43 |
# fetch every symbol's identifying columns
def get_tickers_from_db(con):
    """Return all (id, ticker, name) rows from the symbol table."""
    with con:
        cursor = con.cursor()
        cursor.execute('SELECT id,ticker,name FROM symbol')
        rows = cursor.fetchall()
        return [(row[0], row[1], row[2]) for row in rows]
52 |
# load one symbol's daily close series from the DB
def get_daily_data_from_db(ticker, ticker_name, ticker_id, date_list, con):
    """Return a one-column DataFrame (named ticker_name) of close
    prices for symbol_id == ticker_id, forward-filled onto date_list.
    The close price is used as the reference series for the pair test.
    """
    with con:
        cursor = con.cursor()
        cursor.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume from daily_price where symbol_id = %s' % ticker_id)
        rows = cursor.fetchall()
        price_dates = np.array([row[0] for row in rows])
        closes = np.array([row[4] for row in rows], dtype='float64')
        frame = DataFrame(closes, index=price_dates, columns=[ticker_name])
        frame = frame.reindex(date_list, method='ffill')
        return frame
69 |
70 |
#dealing data with two-pair way to calculating
def deal_data(whole_data):
    """Test every unordered pair of price series for cointegration.

    For each pair: regress one close series on the other (OLS), form
    the residual spread, and run an ADF test on it; pairs whose ADF
    statistic beats the 5% critical value by more than 4 are kept and
    plotted via test().

    Returns a list of (i, r) index pairs into whole_data.
    NOTE(review): relies on pandas.stats.api.ols (removed from modern
    pandas) and on test() reading the module-level `whole_data` global;
    Python 2 only (print statements).
    """
    finall_pair = []
    print '正在计算... '

    for i in range(len(whole_data)):
        for r in range(i,len(whole_data)):

            if i == r:
                continue


            d1 = whole_data[i].fillna(method='pad').fillna(0)
            d1_name = d1.columns[0]
            d2 = whole_data[r].fillna(method='pad').fillna(0)
            d2_name = d2.columns[0]
            df = pd.concat([d1,d2],axis=1)

            # hedge ratio is the OLS beta of series 1 on series 2
            res = ols(y=d1[d1_name], x=d2[d2_name])
            beta_hr = res.beta.x
            df["res"] = df[d1_name] - beta_hr*df[d2_name]
            cadf = ts.adfuller(df["res"])

            # compare the ADF statistic against the 5% critical value
            cadf1 = cadf[0]
            cadf2 = cadf[4]['5%']
            # a margin of 4 makes the genuinely cointegrated pairs stand out
            if cadf1 - cadf2 < -4 :
                print '得到结果=======>>>>>>',i,r,cadf1,cadf2,d1_name,d2_name
                finall_pair.append((i,r))
                test(i,r)

    return finall_pair
104 |
# map the final index pairs back to readable ticker names
def deal_with_fianl_data(whole, pairs):
    """Given the list of frames `whole` and (i, j) index pairs, return
    the corresponding (name_i, name_j) tuples using each frame's single
    column name."""
    names = [frame.columns[0] for frame in whole]
    return [(names[i], names[j]) for i, j in pairs]
118 |
def test(d1,d2):
    """Plot the price and scatter charts for the pair at indices d1, d2.

    NOTE(review): the parameters are immediately rebound — d1/d2 arrive
    as integer indices into the module-level `whole_data` global, which
    must exist before this is called.
    """
    d1 = whole_data[d1].fillna(method='pad').fillna(0)
    d1_name = d1.columns[0]
    d2 = whole_data[d2].fillna(method='pad').fillna(0)
    d2_name = d2.columns[0]
    df = pd.concat([d1,d2],axis=1)
    plot_price_series(df, d1.columns[0], d2.columns[0])
    plot_scatter_series(df,d1.columns[0], d2.columns[0])
127 |
128 |
129 |
if __name__ == '__main__':
    # local MySQL connection settings for the securities master database
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    date_list = pd.date_range('5/1/2016', '10/1/2016', freq='1D')
    tickers = get_tickers_from_db(con)
    whole_data = []

    for i in range(len(tickers)):
        ticker = tickers[i]
        ticker_name = ticker[2]
        # NOTE(review): ticker[1] is the ticker code, ticker[0] the table
        # id — confirm daily_price.symbol_id is keyed by the code here.
        ticker_id = ticker[1]
        daily_data = get_daily_data_from_db(ticker,ticker_name,ticker_id,date_list,con)
        whole_data.append(daily_data)

    #dealing data with two-pair way to calculating
    final_pair = deal_data(whole_data)

    # map the index pairs returned by deal_data back to ticker names
    pairs = deal_with_fianl_data(whole_data,final_pair)


    print pairs
155 |
156 |
157 |
158 |
--------------------------------------------------------------------------------
/get_data/clustering_hs300.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import matplotlib.pyplot as plt
3 | import MySQLdb as mdb
4 | import datetime
5 | import numpy as np
6 | from pandas import Series,DataFrame
7 | import pandas as pd
8 | from matplotlib.collections import LineCollection
9 | from sklearn import cluster, covariance, manifold
10 | import sys
11 | reload(sys) # Python2.5 初始化后会删除 sys.setdefaultencoding 这个方法,我们需要重新载入
12 | sys.setdefaultencoding('utf-8')
13 |
14 |
15 |
16 |
# fetch every symbol's identifying columns
def get_tickers_from_db(con):
    """Return all (id, ticker, name) rows from the symbol table."""
    with con:
        cursor = con.cursor()
        cursor.execute('SELECT id,ticker,name FROM symbol')
        rows = cursor.fetchall()
        return [(row[0], row[1], row[2]) for row in rows]
25 |
26 |
#read dataset from db
def get_timeline_from_db(ticker,ticker_name,ticker_id,date_list,con):
    """Return a one-column DataFrame (named ticker_name) of the daily
    close-minus-open change for symbol_id == ticker_id, forward-filled
    onto date_list.
    """
    with con:
        cur = con.cursor()
        cur.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume from daily_price where symbol_id = %s' % ticker_id)
        daily_data = cur.fetchall()
        dates = np.array([d[0] for d in daily_data])
        # BUG FIX: open_price is column index 1 of the SELECT; the
        # original read d[2], which is high_price, so var_price was
        # close-minus-high instead of the intended close-minus-open.
        open_price = np.array([d[1] for d in daily_data],dtype='float64')
        close_price = np.array([d[4] for d in daily_data],dtype='float64')
        var_price = close_price - open_price
        daily_data = DataFrame(var_price,index=dates,columns=[ticker_name])
        daily_data = daily_data.reindex(date_list,method='ffill')

        return daily_data
41 |
#deal data
def deal_with_data(whole_data):
    """Concatenate the per-ticker frames side by side, forward-fill
    gaps, and keep only the last 150 rows (the original author trimmed
    the series because the downstream clustering was too slow).
    """
    # concat all single-column frames into one wide frame
    final = pd.concat(whole_data,axis=1)
    # forward-fill missing quotes
    final = final.fillna(method='ffill')

    # .ix was removed from pandas; iloc gives the same positional slice
    final = final.iloc[-150:]

    return final
54 |
#clustering
def cluster_data(data):
    """Cluster the tickers by the sparse covariance of their daily
    changes (GraphLassoCV + affinity propagation), print the cluster
    membership, and draw a 2-D embedding with partial-correlation
    edges. Blocks until the plot window is closed.

    data - DataFrame: rows are dates, one column per ticker.
    NOTE(review): Python 2 only (the str.decode/.encode below) and uses
    APIs removed from modern sklearn/matplotlib (GraphLassoCV,
    plt.cm.spectral).
    """
    names = data.columns
    edge_model = covariance.GraphLassoCV()
    data = np.array(data)

    # rows become tickers' time series; normalise each series' variance
    X = data.copy().T
    X /= X.std(axis=0)


    edge_model.fit(X)
    _, labels = cluster.affinity_propagation(edge_model.covariance_)
    n_labels = labels.max()


    for i in range(n_labels + 1):
        print('Cluster %i: %s' % ((i + 1), ', '.join(names[labels == i])))


    #Visualization
    node_position_model = manifold.LocallyLinearEmbedding(n_components=2, eigen_solver='dense', n_neighbors=6)
    embedding = node_position_model.fit_transform(X.T).T
    plt.figure(1, facecolor='w', figsize=(10, 8))
    plt.clf()
    ax = plt.axes([0., 0., 1., 1.])
    plt.axis('off')

    # Display a graph of the partial correlations
    partial_correlations = edge_model.precision_.copy()
    d = 1 / np.sqrt(np.diag(partial_correlations))
    partial_correlations *= d
    partial_correlations *= d[:, np.newaxis]
    non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)

    # Plot the nodes using the coordinates of our embedding
    plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels,cmap=plt.cm.spectral)

    # Plot the edges
    start_idx, end_idx = np.where(non_zero)
    #a sequence of (*line0*, *line1*, *line2*), where::
    #    linen = (x0, y0), (x1, y1), ... (xm, ym)
    segments = [[embedding[:, start], embedding[:, stop]] for start, stop in zip(start_idx, end_idx)]
    values = np.abs(partial_correlations[non_zero])
    lc = LineCollection(segments,zorder=0, cmap=plt.cm.hot_r,norm=plt.Normalize(0, .7 * values.max()))
    lc.set_array(values)
    lc.set_linewidths(15 * values)
    ax.add_collection(lc)

    # Add a label to each node. The challenge here is that we want to
    # position the labels to avoid overlap with other labels
    for index, (name, label, (x, y)) in enumerate(zip(names, labels, embedding.T)):
        # Python 2 round-trip to force a utf-8 byte string for plt.text
        name = str(name).decode('utf-8').encode('utf-8')
        dx = x - embedding[0]
        dx[index] = 1
        dy = y - embedding[1]
        dy[index] = 1
        # nudge each label away from its nearest neighbour
        this_dx = dx[np.argmin(np.abs(dy))]
        this_dy = dy[np.argmin(np.abs(dx))]
        if this_dx > 0:
            horizontalalignment = 'left'
            x = x + .002
        else:
            horizontalalignment = 'right'
            x = x - .002
        if this_dy > 0:
            verticalalignment = 'bottom'
            y = y + .002
        else:
            verticalalignment = 'top'
            y = y - .002
        plt.text(x, y, name , size=10,horizontalalignment=horizontalalignment,verticalalignment=verticalalignment,bbox=dict(facecolor='w',edgecolor=plt.cm.spectral(label / float(n_labels)),alpha=.6))

    plt.xlim(embedding[0].min() - .15 * embedding[0].ptp(),
             embedding[0].max() + .10 * embedding[0].ptp(),)
    plt.ylim(embedding[1].min() - .03 * embedding[1].ptp(),
             embedding[1].max() + .03 * embedding[1].ptp())
    plt.show()
132 |
133 |
134 |
135 |
if __name__ == '__main__':
    # local MySQL connection settings for the securities master database
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    whole_data = []
    #all tickers
    tickers = get_tickers_from_db(con)
    # date range shortened because the full history made the
    # clustering step too slow
    date_list = pd.date_range('1/1/2016', '10/1/2016', freq='1D')

    for i in range(len(tickers)):
        ticker = tickers[i]
        ticker_name = ticker[2]
        # NOTE(review): ticker[1] is the ticker code, ticker[0] the table
        # id — confirm daily_price.symbol_id is keyed by the code here.
        ticker_id = ticker[1]
        daily_data = get_timeline_from_db(ticker,ticker_name,ticker_id,date_list,con)
        whole_data.append(daily_data)

    final_data = deal_with_data(whole_data)
    # cluster data
    cluster_data(final_data)
158 |
159 |
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/back_test_system/event.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # event.py
5 |
6 | from __future__ import print_function
7 |
8 |
class Event(object):
    """
    Base type for every event that flows through the trading
    infrastructure; concrete subclasses trigger the subsequent
    processing stages.
    """
    pass
16 |
17 |
class MarketEvent(Event):
    """
    Signals that a fresh market update (new bars) has arrived.
    """

    def __init__(self):
        """Create the event and tag it with the 'MARKET' type."""
        self.type = 'MARKET'
29 |
30 |
class SignalEvent(Event):
    """
    Carries a trading signal produced by a Strategy object so that a
    Portfolio object can size and route an order from it.
    """

    def __init__(self, strategy_id, symbol, datetime, signal_type, strength):
        """
        Store the signal details.

        Parameters:
        strategy_id - Unique ID of the strategy that generated the signal.
        symbol - Ticker symbol, e.g. 'GOOG'.
        datetime - Timestamp at which the signal was generated.
        signal_type - 'LONG' or 'SHORT'.
        strength - Scaling hint applied to quantity at the portfolio
            level; useful for pairs strategies.
        """
        self.type = 'SIGNAL'
        self.strategy_id = strategy_id
        self.symbol = symbol
        self.datetime = datetime
        self.signal_type = signal_type
        self.strength = strength
55 |
56 |
class OrderEvent(Event):
    """
    Represents an order to be sent to an execution system: a symbol
    (e.g. GOOG), an order type (market or limit), a quantity and a
    direction.
    """

    def __init__(self, symbol, order_type, quantity, direction):
        """
        Store the order details.

        TODO: validate inputs here (e.g. reject negative quantities).

        Parameters:
        symbol - The instrument to trade.
        order_type - 'MKT' or 'LMT' for Market or Limit.
        quantity - Non-negative integer quantity.
        direction - 'BUY' or 'SELL' for long or short.
        """
        self.type = 'ORDER'
        self.symbol = symbol
        self.order_type = order_type
        self.quantity = quantity
        self.direction = direction

    def print_order(self):
        """Print a one-line, human-readable summary of the order."""
        details = (self.symbol, self.order_type, self.quantity, self.direction)
        print(
            "Order: Symbol=%s, Type=%s, Quantity=%s, Direction=%s" % details
        )
94 |
95 |
class FillEvent(Event):
    """
    Describes an order that has been filled by a brokerage: the
    quantity actually transacted, the cost, and the commission
    charged for the trade.

    TODO: filling at multiple prices is not supported; this would be
    simulated by averaging the cost.
    """

    def __init__(self, timeindex, symbol, exchange, quantity,
                 direction, fill_cost, commission=None):
        """
        Store the fill details; when no commission is supplied it is
        derived from the trade size using Interactive Brokers fees.

        Parameters:
        timeindex - Bar-resolution timestamp of the fill.
        symbol - The instrument which was filled.
        exchange - The exchange where the order was filled.
        quantity - The filled quantity.
        direction - 'BUY' or 'SELL'.
        fill_cost - The holdings value in dollars.
        commission - Optional commission sent from IB.
        """
        self.type = 'FILL'
        self.timeindex = timeindex
        self.symbol = symbol
        self.exchange = exchange
        self.quantity = quantity
        self.direction = direction
        self.fill_cost = fill_cost

        # Fall back to the IB fee schedule when none was provided.
        self.commission = (
            self.calculate_ib_commission() if commission is None else commission
        )

    def calculate_ib_commission(self):
        """
        Approximate Interactive Brokers API fees in USD (exchange and
        ECN fees excluded).

        Based on "US API Directed Orders":
        https://www.interactivebrokers.com/en/index.php?f=commission&p=stocks2
        """
        per_share = 0.013 if self.quantity <= 500 else 0.008
        return max(1.3, per_share * self.quantity)
158 |
--------------------------------------------------------------------------------
/back_test_system/self/portfolio.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import datetime
4 | from math import floor
5 | try:
6 | import Queue as queue
7 | except ImportError:
8 | import queue
9 | import numpy as np
10 | import pandas as pd
11 |
12 | from event import FillEvent,OrderEvent
13 | from performance import create_sharpe_ratio,create_drawdowns
14 |
15 |
class Portfolio(object):
    """
    Tracks the positions and market value of all instruments at the
    resolution of a bar, updated from Market, Signal and Fill events.

    ``all_positions`` / ``all_holdings`` keep a per-bar history while
    ``current_positions`` / ``current_holdings`` hold the live snapshot.
    """

    def __init__(self, bars, events, start_date, intial_capital=100000.0):
        """
        Parameters:
        bars - The DataHandler object with current market data.
        events - The Event Queue object. (Fixed: the original stored the
            undefined name ``events`` while the parameter was called
            ``event``, which raised NameError on every construction.)
        start_date - The start date (bar) of the portfolio.
        intial_capital - Starting capital in USD (misspelled parameter
            name kept for backward compatibility with keyword callers).
        """
        self.bars = bars
        self.events = events
        self.symbol_list = self.bars.symbol_list
        self.start_date = start_date
        self.intial_capital = intial_capital
        self.all_positions = self.construct_all_positions()
        self.current_positions = dict((s, 0) for s in self.symbol_list)
        self.all_holdings = self.construct_all_holdings()
        self.current_holdings = self.construct_current_holdings()

    def construct_all_positions(self):
        """Seed the positions history with an all-zero entry at start_date."""
        d = dict((s, 0) for s in self.symbol_list)
        d['datetime'] = self.start_date
        return [d]

    def construct_all_holdings(self):
        """Seed the holdings history with a cash-only entry at start_date."""
        d = dict((s, 0.0) for s in self.symbol_list)
        d['datetime'] = self.start_date
        d['cash'] = self.intial_capital
        d['commission'] = 0.0
        d['total'] = self.intial_capital
        return [d]

    def construct_current_holdings(self):
        """Return the instantaneous (un-timestamped) holdings snapshot."""
        d = dict((s, 0.0) for s in self.symbol_list)
        d['cash'] = self.intial_capital
        d['commission'] = 0.0
        d['total'] = self.intial_capital
        return d

    def update_timeindex(self, event):
        """
        Append a new record to the positions and holdings histories for
        the current market bar (reflecting the PREVIOUS bar's state).
        """
        latest_datetime = self.bars.get_latest_bar_datetime(self.symbol_list[0])

        # Update positions
        # ================
        dp = dict((s, 0) for s in self.symbol_list)
        dp['datetime'] = latest_datetime
        for s in self.symbol_list:
            dp[s] = self.current_positions[s]
        self.all_positions.append(dp)

        # Update holdings
        # ===============
        # Fixed: the original called the undefined name ``divt``.
        dh = dict((s, 0) for s in self.symbol_list)
        dh['datetime'] = latest_datetime
        dh['cash'] = self.current_holdings['cash']
        dh['commission'] = self.current_holdings['commission']
        dh['total'] = self.current_holdings['cash']

        for s in self.symbol_list:
            # Approximation: value the position at the latest close price.
            market_value = self.current_positions[s] * self.bars.get_latest_bar_value(s, 'close')
            dh[s] = market_value
            dh['total'] += market_value

        self.all_holdings.append(dh)

    # =======================
    # FILL/POSITION HANDLING
    # =======================

    def update_positions_from_fill(self, fill):
        """Apply a Fill's signed quantity to the current positions."""
        fill_dir = 0
        if fill.direction == 'BUY':
            fill_dir = 1
        if fill.direction == 'SELL':
            fill_dir = -1
        self.current_positions[fill.symbol] += fill_dir * fill.quantity

    def update_holdings_from_fill(self, fill):
        """
        Apply a Fill's cost and commission to the current holdings.
        (Fixed: the original tested the undefined name ``fill_direction``.)
        """
        fill_dir = 0
        if fill.direction == 'BUY':
            fill_dir = 1
        if fill.direction == 'SELL':
            fill_dir = -1
        # Approximate the fill price with the latest adjusted close.
        fill_cost = self.bars.get_latest_bar_value(fill.symbol, 'adj_close')
        cost = fill_dir * fill_cost * fill.quantity
        self.current_holdings[fill.symbol] += cost
        self.current_holdings['commission'] += fill.commission
        self.current_holdings['cash'] -= (cost + fill.commission)
        self.current_holdings['total'] -= (cost + fill.commission)

    def update_fill(self, event):
        """Route a FillEvent into the position and holdings updates."""
        if event.type == 'FILL':
            self.update_positions_from_fill(event)
            self.update_holdings_from_fill(event)

    def generate_naive_order(self, signal):
        """
        Create a fixed-size (100 share) market order from a SignalEvent,
        with no risk management or position sizing.
        """
        order = None

        symbol = signal.symbol
        direction = signal.signal_type
        strength = signal.strength

        mkt_quantity = 100
        cur_quantity = self.current_positions[symbol]
        order_type = 'MKT'

        if direction == 'LONG' and cur_quantity == 0:
            order = OrderEvent(symbol, order_type, mkt_quantity, 'BUY')
        if direction == 'SHORT' and cur_quantity == 0:
            order = OrderEvent(symbol, order_type, mkt_quantity, 'SELL')
        if direction == 'EXIT' and cur_quantity > 0:
            order = OrderEvent(symbol, order_type, abs(cur_quantity), 'SELL')
        if direction == 'EXIT' and cur_quantity < 0:
            order = OrderEvent(symbol, order_type, abs(cur_quantity), 'BUY')

        return order

    def update_signal(self, event):
        """Turn a SignalEvent into a naive order and queue it."""
        if event.type == 'SIGNAL':
            order_event = self.generate_naive_order(event)
            self.events.put(order_event)

    # ========================
    # POST-BACKTEST STATISTICS
    # ========================

    def create_equity_curve_dataframe(self):
        """
        Build the equity-curve DataFrame from the holdings history.
        (Fixed: the compounding factor is 1 + returns, not 1 * returns,
        matching back_test_system/portfolio.py.)
        """
        curve = pd.DataFrame(self.all_holdings)
        curve.set_index('datetime', inplace=True)
        curve['returns'] = curve['total'].pct_change()
        curve['equity_curve'] = (1.0 + curve['returns']).cumprod()
        self.equity_curve = curve

    def output_summary_stats(self):
        """
        Compute summary statistics (total return, Sharpe, drawdown) and
        dump the equity curve to 'equity.csv'.
        """
        # .iloc avoids a label-based lookup of -1 on a datetime index.
        total_return = self.equity_curve['equity_curve'].iloc[-1]
        returns = self.equity_curve['returns']
        pnl = self.equity_curve['equity_curve']

        print (total_return,returns,pnl)

        sharpe_ratio = create_sharpe_ratio(returns, periods=252 * 60 * 6.5)
        drawdown, max_dd, dd_duration = create_drawdowns(pnl)
        self.equity_curve['drawdown'] = drawdown

        stats = [('Total Return', '%0.2f%%' % ((total_return - 1.0) * 100.0)),
                 ('Sharpe Ratio', '%0.2f' % sharpe_ratio),
                 ('Max Drawdown', '%0.2f%%' % (max_dd * 100.0)),
                 ('Drawdown Duration', '%d' % dd_duration)]

        self.equity_curve.to_csv('equity.csv')

        return stats
162 |
163 |
--------------------------------------------------------------------------------
/back_test_system/ib_execution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ib_execution.py
5 |
6 | from __future__ import print_function
7 |
8 | import datetime
9 | import time
10 |
11 | from ib.ext.Contract import Contract
12 | from ib.ext.Order import Order
13 | from ib.opt import ibConnection, message
14 |
15 | from event import FillEvent, OrderEvent
16 | from execution import ExecutionHandler
17 |
18 |
class IBExecutionHandler(ExecutionHandler):
    """
    Handles order execution via the Interactive Brokers
    API, for use against accounts when trading live
    directly.
    """

    def __init__(
        self, events, order_routing="SMART", currency="USD"
    ):
        """
        Initialises the IBExecutionHandler instance.

        Parameters:
        events - The Queue of Event objects.
        order_routing - IB routing destination (default "SMART").
        currency - Currency used for all contracts.
        """
        self.events = events
        self.order_routing = order_routing
        self.currency = currency
        self.fill_dict = {}

        self.tws_conn = self.create_tws_connection()
        self.order_id = self.create_initial_order_id()
        self.register_handlers()

    def _error_handler(self, msg):
        """Handles the capturing of error messages"""
        # Currently no error handling.
        print("Server Error: %s" % msg)

    def _reply_handler(self, msg):
        """Handles of server replies"""
        # Handle open order orderId processing.
        # Fixed: dict.has_key() was removed in Python 3; use `in`.
        if msg.typeName == "openOrder" and \
                msg.orderId == self.order_id and \
                msg.orderId not in self.fill_dict:
            self.create_fill_dict_entry(msg)
        # Handle Fills
        if msg.typeName == "orderStatus" and \
                msg.status == "Filled" and \
                self.fill_dict[msg.orderId]["filled"] == False:
            self.create_fill(msg)
        print("Server Response: %s, %s\n" % (msg.typeName, msg))

    def create_tws_connection(self):
        """
        Connect to the Trader Workstation (TWS) running on the
        usual port of 7496, with a clientId of 100.
        The clientId is chosen by us and we will need
        separate IDs for both the execution connection and
        market data connection, if the latter is used elsewhere.
        """
        tws_conn = ibConnection()
        tws_conn.connect()
        return tws_conn

    def create_initial_order_id(self):
        """
        Creates the initial order ID used for Interactive
        Brokers to keep track of submitted orders.
        """
        # There is scope for more logic here, but we
        # will use "1" as the default for now.
        return 1

    def register_handlers(self):
        """
        Register the error and server reply
        message handling functions.
        """
        # Assign the error handling function defined above
        # to the TWS connection
        self.tws_conn.register(self._error_handler, 'Error')

        # Assign all of the server reply messages to the
        # reply_handler function defined above
        self.tws_conn.registerAll(self._reply_handler)

    def create_contract(self, symbol, sec_type, exch, prim_exch, curr):
        """Create a Contract object defining what will
        be purchased, at which exchange and in which currency.

        symbol - The ticker symbol for the contract
        sec_type - The security type for the contract ('STK' is 'stock')
        exch - The exchange to carry out the contract on
        prim_exch - The primary exchange to carry out the contract on
        curr - The currency in which to purchase the contract"""
        contract = Contract()
        contract.m_symbol = symbol
        contract.m_secType = sec_type
        contract.m_exchange = exch
        contract.m_primaryExch = prim_exch
        contract.m_currency = curr
        return contract

    def create_order(self, order_type, quantity, action):
        """Create an Order object (Market/Limit) to go long/short.

        order_type - 'MKT', 'LMT' for Market or Limit orders
        quantity - Integral number of assets to order
        action - 'BUY' or 'SELL'"""
        order = Order()
        order.m_orderType = order_type
        order.m_totalQuantity = quantity
        order.m_action = action
        return order

    def create_fill_dict_entry(self, msg):
        """
        Creates an entry in the Fill Dictionary that lists
        orderIds and provides security information. This is
        needed for the event-driven behaviour of the IB
        server message behaviour.
        """
        self.fill_dict[msg.orderId] = {
            "symbol": msg.contract.m_symbol,
            "exchange": msg.contract.m_exchange,
            "direction": msg.order.m_action,
            "filled": False
        }

    def create_fill(self, msg):
        """
        Handles the creation of the FillEvent that will be
        placed onto the events queue subsequent to an order
        being filled.
        """
        fd = self.fill_dict[msg.orderId]

        # Prepare the fill data
        symbol = fd["symbol"]
        exchange = fd["exchange"]
        filled = msg.filled
        direction = fd["direction"]
        fill_cost = msg.avgFillPrice

        # Create a fill event object
        fill = FillEvent(
            datetime.datetime.utcnow(), symbol,
            exchange, filled, direction, fill_cost
        )

        # Make sure that multiple messages don't create
        # additional fills.
        self.fill_dict[msg.orderId]["filled"] = True

        # Place the fill event onto the event queue.
        # Fixed: the original put the undefined name ``fill_event``,
        # raising NameError on every fill.
        self.events.put(fill)

    def execute_order(self, event):
        """
        Creates the necessary InteractiveBrokers order object
        and submits it to IB via their API.

        The results are then queried in order to generate a
        corresponding Fill object, which is placed back on
        the event queue.

        Parameters:
        event - Contains an Event object with order information.
        """
        if event.type == 'ORDER':
            # Prepare the parameters for the asset order
            asset = event.symbol
            asset_type = "STK"
            order_type = event.order_type
            quantity = event.quantity
            direction = event.direction

            # Create the Interactive Brokers contract via the
            # passed Order event
            ib_contract = self.create_contract(
                asset, asset_type, self.order_routing,
                self.order_routing, self.currency
            )

            # Create the Interactive Brokers order via the
            # passed Order event
            ib_order = self.create_order(
                order_type, quantity, direction
            )

            # Use the connection to the send the order to IB
            self.tws_conn.placeOrder(
                self.order_id, ib_contract, ib_order
            )

            # NOTE: This following line is crucial.
            # It ensures the order goes through!
            time.sleep(1)

            # Increment the order ID for this session
            self.order_id += 1
212 |
--------------------------------------------------------------------------------
/back_test_system/data.py:
--------------------------------------------------------------------------------
1 |
2 | #!/usr/bin/python
3 | # -*- coding: utf-8 -*-
4 |
5 | # data.py
6 |
7 | from __future__ import print_function
8 |
9 | from abc import ABCMeta, abstractmethod
10 | import datetime
11 | import os, os.path
12 |
13 | import numpy as np
14 | import pandas as pd
15 |
16 | from event import MarketEvent
17 |
18 |
class DataHandler(object):
    """
    Abstract base class for every data handler, live or historic.

    Derived handlers emit generated bars (OHLCVI) for each requested
    symbol, feeding the backtester exactly as a live market feed
    would; the rest of the backtesting suite therefore treats
    historic and live data identically.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def get_latest_bar(self, symbol):
        """Return the most recently updated bar for the symbol."""
        raise NotImplementedError("Should implement get_latest_bar()")

    @abstractmethod
    def get_latest_bars(self, symbol, N=1):
        """Return the last N updated bars for the symbol."""
        raise NotImplementedError("Should implement get_latest_bars()")

    @abstractmethod
    def get_latest_bar_datetime(self, symbol):
        """Return a Python datetime for the symbol's last bar."""
        raise NotImplementedError("Should implement get_latest_bar_datetime()")

    @abstractmethod
    def get_latest_bar_value(self, symbol, val_type):
        """Return one of Open/High/Low/Close/Volume/OI from the last bar."""
        raise NotImplementedError("Should implement get_latest_bar_value()")

    @abstractmethod
    def get_latest_bars_values(self, symbol, val_type, N=1):
        """Return one bar field from the last N bars (N-k if fewer exist)."""
        raise NotImplementedError("Should implement get_latest_bars_values()")

    @abstractmethod
    def update_bars(self):
        """
        Push the latest bar for each symbol onto the bars queue as an
        OHLCVI tuple: (datetime, open, high, low, close, volume,
        open interest).
        """
        raise NotImplementedError("Should implement update_bars()")
79 |
80 |
class HistoricCSVDataHandler(DataHandler):
    """
    HistoricCSVDataHandler is designed to read CSV files for
    each requested symbol from disk and provide an interface
    to obtain the "latest" bar in a manner identical to a live
    trading interface.
    """

    def __init__(self, events, csv_dir, symbol_list):
        """
        Initialises the historic data handler by requesting
        the location of the CSV files and a list of symbols.

        It will be assumed that all files are of the form
        'symbol.csv', where symbol is a string in the list.

        Parameters:
        events - The Event Queue.
        csv_dir - Absolute directory path to the CSV files.
        symbol_list - A list of symbol strings.
        """
        self.events = events
        self.csv_dir = csv_dir
        self.symbol_list = symbol_list

        self.symbol_data = {}
        self.latest_symbol_data = {}
        self.continue_backtest = True
        self.bar_index = 0

        self._open_convert_csv_files()

    def _open_convert_csv_files(self):
        """
        Opens the CSV files from the data directory, converting
        them into pandas DataFrames within a symbol dictionary.

        For this handler it will be assumed that the data is
        taken from Yahoo. Thus its format will be respected.
        """
        comb_index = None
        for s in self.symbol_list:
            # Load the CSV file with no header information, indexed on date.
            # Fixed: DataFrame.sort() was removed from pandas; sort_index()
            # is the equivalent for sorting on the datetime index.
            self.symbol_data[s] = pd.read_csv(
                os.path.join(self.csv_dir, '%s.csv' % s),
                header=0, index_col=0, parse_dates=True,
                names=[
                    'datetime', 'open', 'high',
                    'low', 'close', 'volume', 'adj_close'
                ]
            ).sort_index()

            # Combine the index to pad forward values.
            # Fixed: Index.union returns a NEW index; the original
            # discarded the result, so the indexes were never combined.
            if comb_index is None:
                comb_index = self.symbol_data[s].index
            else:
                comb_index = comb_index.union(self.symbol_data[s].index)

            # Set the latest symbol_data to an empty list
            self.latest_symbol_data[s] = []

        # Reindex the dataframes onto the combined index, forward-filling
        # gaps, and expose each as a row iterator.
        for s in self.symbol_list:
            self.symbol_data[s] = self.symbol_data[s].\
                reindex(index=comb_index, method='pad').iterrows()

    def _get_new_bar(self, symbol):
        """
        Returns the latest bar from the data feed.
        """
        for b in self.symbol_data[symbol]:
            yield b

    def get_latest_bar(self, symbol):
        """
        Returns the last bar from the latest_symbol list.
        """
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-1]

    def get_latest_bars(self, symbol, N=1):
        """
        Returns the last N bars from the latest_symbol list,
        or N-k if less available.
        """
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-N:]

    def get_latest_bar_datetime(self, symbol):
        """
        Returns a Python datetime object for the last bar.
        """
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-1][0]

    def get_latest_bar_value(self, symbol, val_type):
        """
        Returns one of the Open, High, Low, Close, Volume or OI
        values from the pandas Bar series object.
        """
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return getattr(bars_list[-1][1], val_type)

    def get_latest_bars_values(self, symbol, val_type, N=1):
        """
        Returns the last N bar values from the
        latest_symbol list, or N-k if less available.
        """
        try:
            bars_list = self.get_latest_bars(symbol, N)
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return np.array([getattr(b[1], val_type) for b in bars_list])

    def update_bars(self):
        """
        Pushes the latest bar to the latest_symbol_data structure
        for all symbols in the symbol list.
        """
        for s in self.symbol_list:
            try:
                bar = next(self._get_new_bar(s))
            except StopIteration:
                self.continue_backtest = False
            else:
                if bar is not None:
                    self.latest_symbol_data[s].append(bar)
        self.events.put(MarketEvent())
233 |
234 |
--------------------------------------------------------------------------------
/back_test_system/portfolio.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # portfolio.py
5 |
6 | from __future__ import print_function
7 |
8 | import datetime
9 | from math import floor
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 |
15 | import numpy as np
16 | import pandas as pd
17 |
18 | from event import FillEvent, OrderEvent
19 | from performance import create_sharpe_ratio, create_drawdowns
20 |
21 |
class Portfolio(object):
    """
    The Portfolio class handles the positions and market
    value of all instruments at a resolution of a "bar",
    i.e. secondly, minutely, 5-min, 30-min, 60 min or EOD.

    The positions DataFrame stores a time-index of the
    quantity of positions held.

    The holdings DataFrame stores the cash and total market
    holdings value of each symbol for a particular
    time-index, as well as the percentage change in
    portfolio total across bars.
    """

    def __init__(self, bars, events, start_date, initial_capital=100000.0):
        """
        Initialises the portfolio with bars and an event queue.
        Also includes a starting datetime index and initial capital
        (USD unless otherwise stated).

        Parameters:
        bars - The DataHandler object with current market data.
        events - The Event Queue object.
        start_date - The start date (bar) of the portfolio.
        initial_capital - The starting capital in USD.
        """
        self.bars = bars
        self.events = events
        self.symbol_list = self.bars.symbol_list
        self.start_date = start_date
        self.initial_capital = initial_capital
        self.all_positions = self.construct_all_positions()
        self.current_positions = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] )

        self.all_holdings = self.construct_all_holdings()
        self.current_holdings = self.construct_current_holdings()

    def construct_all_positions(self):
        """
        Constructs the positions list using the start_date
        to determine when the time index will begin.
        """
        d = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] )
        d['datetime'] = self.start_date
        return [d]

    def construct_all_holdings(self):
        """
        Constructs the holdings list using the start_date
        to determine when the time index will begin.
        """
        d = dict( (k,v) for k, v in [(s, 0.0) for s in self.symbol_list] )
        d['datetime'] = self.start_date
        d['cash'] = self.initial_capital
        d['commission'] = 0.0
        d['total'] = self.initial_capital
        return [d]

    def construct_current_holdings(self):
        """
        This constructs the dictionary which will hold the instantaneous
        value of the portfolio across all symbols.
        """
        d = dict( (k,v) for k, v in [(s, 0.0) for s in self.symbol_list] )
        d['cash'] = self.initial_capital
        d['commission'] = 0.0
        d['total'] = self.initial_capital
        return d

    def update_timeindex(self, event):
        """
        Adds a new record to the positions matrix for the current
        market data bar. This reflects the PREVIOUS bar, i.e. all
        current market data at this stage is known (OHLCV).

        Makes use of a MarketEvent from the events queue.
        """
        latest_datetime = self.bars.get_latest_bar_datetime(self.symbol_list[0])

        # Update positions
        # ================
        dp = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] )
        dp['datetime'] = latest_datetime

        for s in self.symbol_list:
            dp[s] = self.current_positions[s]

        # Append the current positions
        self.all_positions.append(dp)

        # Update holdings
        # ===============
        dh = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] )
        dh['datetime'] = latest_datetime
        dh['cash'] = self.current_holdings['cash']
        dh['commission'] = self.current_holdings['commission']
        dh['total'] = self.current_holdings['cash']

        for s in self.symbol_list:
            # Approximation to the real value
            market_value = self.current_positions[s] * \
                self.bars.get_latest_bar_value(s, "close")
            dh[s] = market_value
            dh['total'] += market_value

        # Append the current holdings
        self.all_holdings.append(dh)

    # ======================
    # FILL/POSITION HANDLING
    # ======================

    def update_positions_from_fill(self, fill):
        """
        Takes a Fill object and updates the position matrix to
        reflect the new position.

        Parameters:
        fill - The Fill object to update the positions with.
        """
        # Check whether the fill is a buy or sell
        fill_dir = 0
        if fill.direction == 'BUY':
            fill_dir = 1
        if fill.direction == 'SELL':
            fill_dir = -1

        # Update positions list with new quantities
        self.current_positions[fill.symbol] += fill_dir*fill.quantity

    def update_holdings_from_fill(self, fill):
        """
        Takes a Fill object and updates the holdings matrix to
        reflect the holdings value.

        Parameters:
        fill - The Fill object to update the holdings with.
        """
        # Check whether the fill is a buy or sell
        fill_dir = 0
        if fill.direction == 'BUY':
            fill_dir = 1
        if fill.direction == 'SELL':
            fill_dir = -1

        # Update holdings list with new quantities
        fill_cost = self.bars.get_latest_bar_value(
            fill.symbol, "adj_close"
        )
        cost = fill_dir * fill_cost * fill.quantity
        self.current_holdings[fill.symbol] += cost
        self.current_holdings['commission'] += fill.commission
        self.current_holdings['cash'] -= (cost + fill.commission)
        self.current_holdings['total'] -= (cost + fill.commission)

    def update_fill(self, event):
        """
        Updates the portfolio current positions and holdings
        from a FillEvent.
        """
        if event.type == 'FILL':
            self.update_positions_from_fill(event)
            self.update_holdings_from_fill(event)

    def generate_naive_order(self, signal):
        """
        Simply files an Order object as a constant quantity
        sizing of the signal object, without risk management or
        position sizing considerations.

        Parameters:
        signal - The tuple containing Signal information.
        """
        order = None

        symbol = signal.symbol
        direction = signal.signal_type
        strength = signal.strength

        mkt_quantity = 100
        cur_quantity = self.current_positions[symbol]
        order_type = 'MKT'

        if direction == 'LONG' and cur_quantity == 0:
            order = OrderEvent(symbol, order_type, mkt_quantity, 'BUY')
        if direction == 'SHORT' and cur_quantity == 0:
            order = OrderEvent(symbol, order_type, mkt_quantity, 'SELL')
        if direction == 'EXIT' and cur_quantity > 0:
            order = OrderEvent(symbol, order_type, abs(cur_quantity), 'SELL')
        if direction == 'EXIT' and cur_quantity < 0:
            order = OrderEvent(symbol, order_type, abs(cur_quantity), 'BUY')
        return order

    def update_signal(self, event):
        """
        Acts on a SignalEvent to generate new orders
        based on the portfolio logic.
        """
        if event.type == 'SIGNAL':
            order_event = self.generate_naive_order(event)
            self.events.put(order_event)

    # ========================
    # POST-BACKTEST STATISTICS
    # ========================

    def create_equity_curve_dataframe(self):
        """
        Creates a pandas DataFrame from the all_holdings
        list of dictionaries.
        """
        curve = pd.DataFrame(self.all_holdings)
        curve.set_index('datetime', inplace=True)
        curve['returns'] = curve['total'].pct_change()
        curve['equity_curve'] = (1.0+curve['returns']).cumprod()
        self.equity_curve = curve

    def output_summary_stats(self):
        """
        Creates a list of summary statistics for the portfolio.
        """
        # Fixed: series[-1] is a LABEL lookup and raises KeyError on a
        # datetime index in modern pandas; .iloc[-1] is positional.
        total_return = self.equity_curve['equity_curve'].iloc[-1]
        returns = self.equity_curve['returns']
        pnl = self.equity_curve['equity_curve']

        print (total_return,returns,pnl)


        sharpe_ratio = create_sharpe_ratio(returns, periods=252*60*6.5)
        drawdown, max_dd, dd_duration = create_drawdowns(pnl)
        self.equity_curve['drawdown'] = drawdown

        stats = [("Total Return", "%0.2f%%" % ((total_return - 1.0) * 100.0)),
                 ("Sharpe Ratio", "%0.2f" % sharpe_ratio),
                 ("Max Drawdown", "%0.2f%%" % (max_dd * 100.0)),
                 ("Drawdown Duration", "%d" % dd_duration)]

        self.equity_curve.to_csv('equity.csv')
        return stats
262 |
--------------------------------------------------------------------------------
/back_test_system/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | true
46 | DEFINITION_ORDER
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 | 1482761166539
98 |
99 |
100 | 1482761166539
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
--------------------------------------------------------------------------------
/util/db.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import tushare as ts
3 | import MySQLdb as mdb
4 | import datetime
5 | import time
6 | import numpy as np
7 | import pandas as pd
8 | import sys
9 | import pandas as pd
10 | import pandas_datareader.data as web
11 |
12 |
13 |
14 |
# Fetch the HS300 constituents from tushare and store them in the DB.
def save_hs300_into_db():
    """Fetch the HS300 index constituents via tushare and insert one
    row per ticker into the `symbol` table of `ticker_master`.

    Side effects: opens a MySQL connection and commits an INSERT batch.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    now = datetime.datetime.utcnow()
    hs300 = ts.get_hs300s()
    column_str = """ticker, instrument, name, sector, currency, created_date, last_updated_date"""
    insert_str = ("%s, " * 7)[:-2]
    final_str = "INSERT INTO symbol (%s) VALUES (%s)" % (column_str, insert_str)

    # One tuple per constituent; sector is unknown from this source.
    # (.iterrows replaces the removed DataFrame.ix positional lookup.)
    symbols = [
        (row['code'], 'stock', row['name'], '', 'RMB', now, now)
        for _, row in hs300.iterrows()
    ]

    with con:
        cur = con.cursor()
        cur.executemany(final_str, symbols)
        print('success insert hs300 into symbol!')
47 |
def save_us_into_db(symbols):
    """Insert US tickers into the `symbol` table of `us_ticker_master`.

    Parameters:
    symbols - DataFrame with at least 'Symbol', 'Name' and 'Sector'
              columns (as produced by get_us_ticker_name_from_csv).

    Side effects: opens a MySQL connection and commits an INSERT batch.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    now = datetime.datetime.utcnow()
    column_str = """ticker, instrument, name, sector, currency, created_date, last_updated_date"""
    insert_str = ("%s, " * 7)[:-2]
    final_str = "INSERT INTO symbol (%s) VALUES (%s)" % (column_str, insert_str)

    # One tuple per ticker row.
    # (.iterrows replaces the removed DataFrame.ix positional lookup.)
    symbols_content = [
        (row['Symbol'], 'stock', row['Name'], row['Sector'], 'USD', now, now)
        for _, row in symbols.iterrows()
    ]

    with con:
        cur = con.cursor()
        cur.executemany(final_str, symbols_content)
        print('success insert us_ticker into symbol!')
78 |
# Fetch the 300 previously stored tickers.
def get_hs300_tickers():
    """Return every stored HS300 symbol as a list of
    (id, ticker, name) tuples."""
    con = mdb.connect(host='localhost', user='root', passwd='',
                      db='ticker_master')
    with con:
        cur = con.cursor()
        cur.execute('SELECT id,ticker,name FROM symbol')
        rows = cur.fetchall()
    return [(row[0], row[1], row[2]) for row in rows]
91 |
# Fetch US ticker ids, used to iterate, download and store quotes.
def get_us_tickers():
    """Return every stored US symbol as a list of
    (id, ticker, name) tuples."""
    con = mdb.connect(host='localhost', user='root', passwd='',
                      db='us_ticker_master')
    with con:
        cur = con.cursor()
        cur.execute('SELECT id,ticker,name FROM symbol')
        rows = cur.fetchall()
    return [(row[0], row[1], row[2]) for row in rows]
104 |
# Download a ticker's history from tushare by its code.
def get_ticker_info_by_id(ticker_id, start_date, end_date=None):
    """Download daily history for `ticker_id` via tushare.

    Parameters:
    ticker_id  - stock code string.
    start_date - 'YYYY-MM-DD'; pass '' to fall back to the listing date.
    end_date   - 'YYYY-MM-DD'; defaults to today's date. (The old default
                 str(datetime.date.today()) was evaluated once at import
                 time and went stale in long-running processes.)

    Returns the DataFrame produced by ts.get_h_data.
    """
    if end_date is None:
        end_date = str(datetime.date.today())
    # No start date given: use the listing date (timeToMarket, YYYYMMDD).
    if start_date == '':
        df = ts.get_stock_basics()
        raw = str(df.loc[ticker_id]['timeToMarket'])  # .loc replaces removed .ix
        start_date = '%s-%s-%s' % (raw[0:4], raw[4:6], raw[6:8])

    print('======= loading:%s to %s , %s ========' % (start_date, end_date, ticker_id))
    ticker_data = ts.get_h_data(ticker_id, start=start_date, end=end_date, retry_count=50, pause=1)
    print('======= loading success =======')
    return ticker_data
121 |
# Download US stock data.
def get_us_ticker_by_id(ticker_id, start_date, end_date=None):
    """Download daily data for a US ticker from Yahoo via
    pandas-datareader.

    end_date defaults to today's date, evaluated at call time (the old
    default was frozen at import time). The previous body also built
    unused hard-coded 2010/2013 datetimes and printed those instead of
    the actual arguments; both fixed.
    """
    if end_date is None:
        end_date = datetime.date.today()
    print(start_date, end_date, ticker_id, '-----')
    return web.DataReader(ticker_id, 'yahoo', start_date, end_date)
129 |
# If the download failed, remove the ticker from the symbol table.
def delete_symbol_from_db_by_id(id):
    """Delete the `symbol` row whose ticker equals `id`.

    Uses a parameterized query: the previous string-interpolated SQL
    was injectable via the ticker value.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        print("DELETE FROM symbol where ticker='%s'" % id)
        cur.execute("DELETE FROM symbol where ticker=%s", (id,))
141 |
142 |
# Read the most recent stored price date for a symbol.
def get_last_date(ticker_id):
    """Return all price_date rows for `ticker_id`, newest first, as the
    raw cursor.fetchall() tuple-of-tuples (row 0 is the latest date).

    Uses a parameterized query instead of string interpolation.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute("SELECT price_date FROM daily_price WHERE symbol_id=%s ORDER BY price_date DESC",
                    (ticker_id,))
        date = cur.fetchall()
        return date
155 |
# Read the oldest stored price date for a US symbol.
def get_us_oldest_date(ticker_id):
    """Return all price_date rows for `ticker_id`, oldest first, as the
    raw cursor.fetchall() tuple-of-tuples (row 0 is the oldest date).

    Uses a parameterized query instead of string interpolation.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute("SELECT price_date FROM daily_price WHERE symbol_id=%s ORDER BY price_date",
                    (ticker_id,))
        date = cur.fetchall()
        return date
168 |
# Read the most recent stored price date for a US symbol.
def get_us_last_date(ticker_id):
    """Return all price_date rows for `ticker_id`, newest first, as the
    raw cursor.fetchall() tuple-of-tuples (row 0 is the latest date).

    Uses a parameterized query instead of string interpolation.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute("SELECT price_date FROM daily_price WHERE symbol_id=%s ORDER BY price_date DESC",
                    (ticker_id,))
        date = cur.fetchall()
        return date
181 |
# Store a downloaded ticker history into the database.
def save_ticker_into_db(ticker_id, ticker, vendor_id):
    """Insert one ticker's daily rows into `daily_price` (ticker_master).

    Parameters:
    ticker_id - the symbol table id for this stock.
    ticker    - DataFrame indexed by date with 'open', 'high', 'low',
                'close', 'volume' and 'amount' columns (tushare format).
    vendor_id - data_vendor_id to record for these rows.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    # Create the time now
    now = datetime.datetime.utcnow()
    # Create the insert strings
    column_str = """data_vendor_id, symbol_id, price_date, created_date,
                 last_updated_date, open_price, high_price, low_price,
                 close_price, volume, amount"""
    insert_str = ("%s, " * 11)[:-2]
    final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % (column_str, insert_str)

    # One tuple per trading day.
    # (.iterrows replaces the removed DataFrame.ix label lookup.)
    daily_data = [
        (vendor_id, ticker_id, t_date, now, now,
         row['open'], row['high'], row['low'], row['close'],
         row['volume'], row['amount'])
        for t_date, row in ticker.iterrows()
    ]

    with con:
        cur = con.cursor()
        cur.executemany(final_str, daily_data)
210 |
# Store a downloaded US ticker history into the database.
def save_us_ticker_into_db(ticker_id, ticker, vendor_id):
    """Insert one US ticker's daily rows into `daily_price`
    (us_ticker_master).

    Parameters:
    ticker_id - the symbol table id for this stock.
    ticker    - DataFrame indexed by date with 'Open', 'High', 'Low',
                'Close', 'Volume' and 'Adj Close' columns (Yahoo format).
    vendor_id - data_vendor_id to record for these rows.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    # Create the time now
    now = datetime.datetime.utcnow()
    # Create the insert strings
    column_str = """data_vendor_id, symbol_id, price_date, created_date,
                 last_updated_date, open_price, high_price, low_price,
                 close_price, volume, adj_close_price"""
    insert_str = ("%s, " * 11)[:-2]
    final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % (column_str, insert_str)

    # One tuple per trading day.
    # (.iterrows replaces the removed DataFrame.ix label lookup.)
    daily_data = [
        (vendor_id, ticker_id, t_date, now, now,
         row['Open'], row['High'], row['Low'], row['Close'],
         row['Volume'], row['Adj Close'])
        for t_date, row in ticker.iterrows()
    ]

    with con:
        cur = con.cursor()
        cur.executemany(final_str, daily_data)
        print('success insert new into db!')
240 |
241 |
242 |
243 |
# Load a ticker's daily prices from the database.
def get_ticker_from_db_by_id(ticker_id):
    """Return all stored daily prices for `ticker_id`, newest first, as a
    DataFrame with columns ['index','open','high','low','close','volume']
    ('index' holds the price date).

    Uses a parameterized query instead of string interpolation.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute(
            'SELECT price_date,open_price,high_price,low_price,close_price,volume '
            'from daily_price where symbol_id = %s ORDER BY price_date DESC',
            (ticker_id,)
        )
        daily_data = cur.fetchall()
        daily_data_np = np.array(daily_data)
        daily_data_df = pd.DataFrame(daily_data_np, columns=['index', 'open', 'high', 'low', 'close', 'volume'])

    return daily_data_df
259 |
260 |
# Load a US ticker's daily prices from the database.
def get_us_ticker_from_db_by_id(ticker_id, start_date, end_date=None):
    """Return daily prices for `ticker_id` between start_date and
    end_date, newest first, as a DataFrame with columns
    ['index','open','high','low','close','volume'] ('close' holds the
    adjusted close; 'index' the price date).

    end_date defaults to today's date, evaluated at call time (the old
    default was frozen at import time). Uses a parameterized query
    instead of string interpolation.
    """
    if end_date is None:
        end_date = datetime.date.today()
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    # Normalise dates to 'YYYY-MM-DD' strings.
    start_date = str(start_date)[:10]
    end_date = str(end_date)[:10]

    with con:
        cur = con.cursor()
        cur.execute(
            'SELECT price_date,open_price,high_price,low_price,adj_close_price,volume '
            'FROM daily_price WHERE (price_date BETWEEN %s and %s) AND (symbol_id=%s) '
            'ORDER BY price_date DESC',
            (start_date, end_date, ticker_id)
        )
        daily_data = cur.fetchall()
        daily_data_np = np.array(daily_data)
        daily_data_df = pd.DataFrame(daily_data_np, columns=['index', 'open', 'high', 'low', 'close', 'volume'])
    return daily_data_df
279 |
# Read the US ticker names from a csv file.
def get_us_ticker_name_from_csv(filename):
    """Return the Symbol/Name/Sector/MarketCap columns of `filename`
    as a DataFrame."""
    wanted = ['Symbol', 'Name', 'Sector', 'MarketCap']
    frame = pd.read_csv(filename)
    return frame[wanted]
285 |
286 |
# US tickers whose mean daily volume is in the middle 33%, computed from
# the latest stored date, restricted to a price band (e.g. 10 to 30).
def get_us_middle33_volume(delay_days, low_price, high_price):
    """Screen US tickers by recent liquidity and price.

    For each stored ticker, average volume and close over the last
    `delay_days` days (ending at its latest stored date); keep tickers
    whose integer mean price lies in [low_price, high_price); return the
    middle third by mean volume as a DataFrame with columns
    ['id','volume','price'] and a fresh integer index.

    Notes on the rewrite: `df.sort(columns=...)` and per-row
    `DataFrame.append` were removed from pandas (and the append loop was
    quadratic); rows are now collected in a list and the frame is built
    once, then sorted with sort_values.
    """
    tickers = get_us_tickers()
    length = len(tickers)
    rows = []
    print('================ is calculating =================')
    for i, ticker in enumerate(tickers):
        ticker_id = ticker[1]

        # Window: last `delay_days` days up to the latest stored date.
        end_date = get_us_last_date(ticker_id)[0][0]
        start_date = end_date + datetime.timedelta(days=delay_days * -1)
        ticker_data = get_us_ticker_from_db_by_id(ticker_id, start_date, end_date)
        days_mean_volume = ticker_data['volume'].mean()
        days_mean_daily_price = ticker_data['close'].mean()
        print('========== %s of %s , %s , %s==========' % (i, length, ticker_id, days_mean_volume))
        # Keep only tickers whose integer mean price is inside the band
        # (same semantics as the old `in range(low, high)` membership).
        if int(low_price) <= int(days_mean_daily_price) < int(high_price):
            rows.append([ticker_id, days_mean_volume, days_mean_daily_price])

    df = pd.DataFrame(rows, columns=['id', 'volume', 'price'])
    # Middle third by average volume.
    df = df.sort_values(by='volume')
    df_len = len(df)
    df = df[int(df_len * 0.33):int(df_len * 0.66)]
    df.index = range(len(df))

    return df
321 |
322 |
# Compute the mean and standard deviation of recent prices.
def get_average_days_price_by_id(ticker_id, average_days=7 * 30):
    """Return (mean, std) of the adjusted close over the last
    `average_days` calendar days, or (0, 0) when no rows are stored.

    Gaps are forward- then backward-filled before the statistics are
    taken. Uses a parameterized query instead of string interpolation.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    end_date = datetime.date.today()
    start_date = end_date + datetime.timedelta(days=int(average_days) * -1)
    start_date = str(start_date)[:10]
    end_date = str(end_date)[:10]
    with con:
        cur = con.cursor()
        cur.execute(
            'SELECT adj_close_price FROM daily_price '
            'WHERE (price_date BETWEEN %s and %s) AND (symbol_id=%s) '
            'ORDER BY price_date DESC',
            (start_date, end_date, ticker_id)
        )
        daily_data = cur.fetchall()
        daily_data_np = np.array(daily_data)
        # .ffill()/.bfill() replace the removed fillna(method=...) form.
        daily_data_df = pd.DataFrame(daily_data_np, columns=['close']).ffill().bfill()
        if len(daily_data_df) == 0:
            mean = 0
            std = 0
        else:
            mean = daily_data_df['close'].mean()
            std = daily_data_df['close'].astype('float').std()

    return mean, std
349 |
350 |
# Current price series for one stock, with rolling mean/std columns.
def get_current_mean_std_df(ticker_id, days_range=200, cal_range=60, end_date=None):
    """Return recent adjusted closes (newest first) with rolling mean,
    exponentially-weighted mean and rolling std computed over
    `cal_range` rows, with the warm-up rows dropped.

    Fixes in this revision:
    - `end_date` is honoured (it was previously overwritten with today);
      it defaults to today's date at call time.
    - pd.rolling_mean / pd.ewma / pd.rolling_std were removed from
      pandas; the .rolling()/.ewm() API is used instead. pd.ewma's
      second positional argument was `com`, so ewm(com=cal_range)
      preserves the original weighting.
    - the warm-up slice uses `cal_range` instead of the literal 60.
    - parameterized SQL instead of string interpolation.

    NOTE: the output columns keep their historical names 'ma_60',
    'ewma_60', 'std_60' for backward compatibility even when
    cal_range != 60.
    """
    if end_date is None:
        end_date = datetime.date.today()
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    start_date = end_date + datetime.timedelta(days=int(days_range + cal_range) * -1)
    start_date = str(start_date)[:10]
    end_date = str(end_date)[:10]

    with con:
        cur = con.cursor()
        cur.execute(
            'SELECT adj_close_price FROM daily_price '
            'WHERE (price_date BETWEEN %s and %s) AND (symbol_id=%s) '
            'ORDER BY price_date DESC',
            (start_date, end_date, ticker_id)
        )
        daily_data = cur.fetchall()
        daily_data_np = np.array(daily_data)
        daily_data_df = pd.DataFrame(daily_data_np, columns=['close'])
        # DB decimals arrive as objects; rolling ops need a numeric dtype.
        close = daily_data_df['close'].astype('float')
        daily_data_df['ma_60'] = close.rolling(cal_range).mean()
        daily_data_df['ewma_60'] = close.ewm(com=cal_range).mean()
        daily_data_df['std_60'] = close.rolling(cal_range).std()

    # Drop the warm-up rows that lack a full rolling window.
    return daily_data_df[cal_range:]
375 |
--------------------------------------------------------------------------------