├── README.md ├── util ├── __init__.py ├── etf_tmp.py ├── init_save_until_today.py ├── init_save_us_until_today.py ├── plot.py ├── fresh_until_today.py ├── fresh_us_until_today.py ├── scrapy_etf.py ├── ml_util.py ├── Feature_utils.py └── db.py ├── deal_data ├── __init__.py ├── db.pyc ├── forecast.pyc ├── FeatureUtils.pyc ├── db.py ├── get_r2.py ├── find_features_plot.py ├── deal_hs300.py ├── FeatureUtils.py └── forecast.py ├── back_test_system ├── .gitignore ├── self │ ├── ib_execution.py │ ├── event.pyc │ ├── performance.pyc │ ├── strategy.py │ ├── performance.py │ ├── exexcution.py │ ├── event.py │ ├── mac.py │ ├── backtest.py │ ├── data.py │ └── portfolio.py ├── data.pyc ├── event.pyc ├── backtest.pyc ├── strategy.pyc ├── execution.pyc ├── performance.pyc ├── portfolio.pyc ├── .idea │ ├── misc.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── modules.xml │ ├── back_test_system.iml │ └── workspace.xml ├── try.py ├── strategy.py ├── to_csv.py ├── performance.py ├── execution.py ├── mac.py ├── backtest.py ├── event.py ├── ib_execution.py ├── data.py └── portfolio.py ├── strategy_1 ├── mark └── simulation.py ├── strategy_2_alpace ├── mark ├── tryData.py └── alpace.py ├── .gitignore └── get_data ├── get_300_names.py ├── securities_master.sql ├── get_hist_data.py ├── find_mean_reversion.py ├── find_pairs.py └── clustering_hs300.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /deal_data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /back_test_system/.gitignore: 
-------------------------------------------------------------------------------- 1 | .pyc -------------------------------------------------------------------------------- /back_test_system/self/ib_execution.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /strategy_1/mark: -------------------------------------------------------------------------------- 1 | { 2 | 博云新材:002297 3 | } -------------------------------------------------------------------------------- /strategy_2_alpace/mark: -------------------------------------------------------------------------------- 1 | 羊驼策略 2 | 针对美股 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv.xz 2 | *.csv 3 | *.xz 4 | *.pyc -------------------------------------------------------------------------------- /deal_data/db.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/deal_data/db.pyc -------------------------------------------------------------------------------- /deal_data/forecast.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/deal_data/forecast.pyc -------------------------------------------------------------------------------- /back_test_system/data.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/data.pyc -------------------------------------------------------------------------------- /back_test_system/event.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/event.pyc -------------------------------------------------------------------------------- /deal_data/FeatureUtils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/deal_data/FeatureUtils.pyc -------------------------------------------------------------------------------- /back_test_system/backtest.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/backtest.pyc -------------------------------------------------------------------------------- /back_test_system/strategy.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/strategy.pyc -------------------------------------------------------------------------------- /back_test_system/execution.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/execution.pyc -------------------------------------------------------------------------------- /back_test_system/performance.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/performance.pyc -------------------------------------------------------------------------------- /back_test_system/portfolio.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/portfolio.pyc -------------------------------------------------------------------------------- /back_test_system/self/event.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/self/event.pyc -------------------------------------------------------------------------------- /back_test_system/self/performance.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayizhi/quantResearch_py/HEAD/back_test_system/self/performance.pyc -------------------------------------------------------------------------------- /back_test_system/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /util/etf_tmp.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import db 4 | 5 | 6 | data = pd.read_csv('etf.csv') 7 | 8 | data['Sector'] = 'ETF' 9 | 10 | data = data[['Symbol','Sector','Name']] 11 | 12 | db.save_us_into_db(data) -------------------------------------------------------------------------------- /back_test_system/try.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy 4 | import pandas 5 | import matplotlib 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | from abc import ABCMeta, abstractmethod 15 | 16 | 17 | -------------------------------------------------------------------------------- /back_test_system/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 7 | -------------------------------------------------------------------------------- /back_test_system/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /back_test_system/self/strategy.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from abc import ABCMeta,abstractmethod 3 | import datetime 4 | try: 5 | import Queue as queue 6 | except ImportError: 7 | import queue 8 | import numpy as np 9 | import pandas as pd 10 | from event import SignalEvent 11 | 12 | class Strategy(object): 13 | __metaclass = ABCMeta 14 | 15 | @abstractmethod 16 | def calculate_signals(self): 17 | raise NotImplementedError('Should implement calculate_signals()') -------------------------------------------------------------------------------- /back_test_system/.idea/back_test_system.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /util/init_save_until_today.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | from db import get_hs300_tickers,get_ticker_info_by_id,save_ticker_into_db,get_last_date 4 | import datetime 5 | 6 | 7 | if __name__ == '__main__': 8 | #hs300的id 9 | ticker_info = get_hs300_tickers() 10 | for i in range(len(ticker_info)): 11 | ticker = ticker_info[i] 12 | ticker_id = ticker[1] 13 | ticker_name = ticker[2] 14 | vendor_id = i 15 | 16 | if i < 190: 17 | continue 18 | 19 | #获取 20 | ticker_data = get_ticker_info_by_id(ticker_id,'') 21 | 22 | #存储 23 | save_ticker_into_db(ticker_id,ticker_data,vendor_id) -------------------------------------------------------------------------------- /back_test_system/self/performance.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import pandas as pd 4 | 5 | def create_sharpe_ratio(returns,period=252): 6 | return 
np.sqrt(period) * (np.mean(returns)) / np.std(returns) 7 | 8 | def create_drawdowns(pnl): 9 | hwm = [0] 10 | idx = pnl.index 11 | drawdown = pd.Series(index = idx) 12 | duration = pd.Series(index = idx) 13 | 14 | for t in range(1,len(idx)): 15 | hwm.append(max(hwm[t - 1],pnl[t])) 16 | drawdown[t] = (hwm[t] - pnl[t]) 17 | duration[t] = (0 if drawdown[t] == 0 else duration[t - 1] + 1) 18 | 19 | return drawdown,drawdown.max(),duration.max() 20 | -------------------------------------------------------------------------------- /strategy_2_alpace/tryData.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import datetime 4 | import sys 5 | sys.path.append('..') 6 | import util.db as db 7 | import util.plot as plot 8 | import pprint 9 | 10 | data = ['EWZS'] 11 | 12 | df = pd.DataFrame(data) 13 | 14 | my_judge = []; 15 | for i in range(df.shape[0]): 16 | # print df[i] 17 | ticker = df.loc[i] 18 | print ticker,'=========================' 19 | ticker_id = ticker[0] 20 | plot.plotCurrentMeanStd(ticker_id,400) 21 | ticker_judge = raw_input()#1:buy,0:no,2:interest 22 | my_judge.append((ticker_id,ticker_judge)) 23 | 24 | 25 | 26 | 27 | 28 | 29 | pprint.pprint(my_judge) 30 | 31 | 32 | -------------------------------------------------------------------------------- /back_test_system/self/exexcution.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from abc import ABCMeta,abstractmethod 4 | import datetime 5 | try: 6 | import Queue as queue 7 | except ImportError: 8 | import queue 9 | 10 | from event import FillEvent,OrderEvent 11 | 12 | class ExecutionHandler(object): 13 | __metaclass = ABCMeta 14 | 15 | @abstractmethod 16 | def execute_order(self,event): 17 | raise NotImplementedError('Should implement execute_order()') 18 | 19 | class SimulateExecutionHandler(ExecutionHandler): 20 | def __init__(self,events): 21 
| self.events = events 22 | 23 | def execute_order(self,event): 24 | if event.type == 'ORDER': 25 | fill_event = FillEvent( 26 | datetime.datetime.utcnow(),event.symbol,'ARCA',event.quantity,event.direction,None) 27 | self.events.put(fill_event) -------------------------------------------------------------------------------- /util/init_save_us_until_today.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import numpy as np 3 | import pandas as pd 4 | import datetime 5 | import sys 6 | import util.db as db 7 | 8 | 9 | if __name__ == '__main__': 10 | symbols = db.get_us_tickers(); 11 | start_date = datetime.datetime(2015,1,1) 12 | error_arr = []; 13 | for i in range(len(symbols)): 14 | symbol = symbols[i][1]; 15 | try: 16 | print '========= loading %s , %s ==========' % (i,symbol) 17 | ticker = db.get_us_ticker_by_id(symbol,start_date) 18 | print '========= loading success' 19 | 20 | db.save_us_ticker_into_db(symbol,ticker,i) 21 | print '+++++++++++++ save %s , %s success +++++++++++++++' % (i,symbol) 22 | except: 23 | error_arr.append(symbol) 24 | db.delete_symbol_from_db_by_id(symbol) 25 | print '------------- delete %s , %s success -------------' % (i,symbol) 26 | 27 | -------------------------------------------------------------------------------- /util/plot.py: -------------------------------------------------------------------------------- 1 | #coding utf-8 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import numpy as np 5 | import db 6 | 7 | 8 | 9 | 10 | def plotCurrentMeanStd(tickerId,days_rang=200): 11 | df = db.get_current_mean_std_df(tickerId,days_rang) 12 | cur = df['close'] 13 | theMin = cur.min(); 14 | theMax = cur.max(); 15 | mean_60 = df['ma_60'] 16 | emwa_60 = df['ewma_60'] 17 | std_60 = df['std_60'] 18 | 19 | plt.plot(cur,'r',lw=0.75,linestyle='-',label='cur') 20 | plt.plot(std_60,'p',lw=0.75,linestyle='-',label='std_60') 21 | 
plt.plot(mean_60,'b',lw=0.75,linestyle='-',label='mean_60') 22 | plt.plot(emwa_60,'g',lw=0.75,linestyle='-',label='emwa_60') 23 | plt.ylim(float(theMin) * 0.9, float(theMax) * 1.1) 24 | 25 | plt.legend(loc=4,prop={'size':2}) 26 | plt.setp(plt.gca().get_xticklabels(), rotation=30) 27 | plt.grid(True) 28 | 29 | plt.show() 30 | 31 | 32 | -------------------------------------------------------------------------------- /strategy_1/simulation.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import numpy as np 3 | import pandas as pd 4 | import datetime 5 | import sys 6 | sys.path.append('..') 7 | import util.db as db 8 | from util.Feature_utils import get_good_feature 9 | from util.ml_util import get_classification_r2,get_regression_r2 10 | 11 | 12 | if __name__ == '__main__': 13 | ticker_df = db.get_ticker_from_db_by_id('300133') 14 | ticker_index = ticker_df['index'] 15 | ticker_np = np.array(ticker_df[['open','high','low','close','volume']]) 16 | ticker_df = pd.DataFrame(ticker_np,index=ticker_index,columns=['open','high','low','close','volume'],dtype="float") 17 | start_date = datetime.datetime(2016,1,1) 18 | end_date = datetime.datetime.today() 19 | ticker_df = ticker_df[ticker_df.index >= start_date] 20 | ticker_df = ticker_df[ticker_df.index <= end_date] 21 | ticker_data = get_good_feature(ticker_df,10) 22 | ticker_data.dtype = '|S6' 23 | 24 | 25 | 26 | #get best ml 27 | get_regression_r2(ticker_data) 28 | 29 | -------------------------------------------------------------------------------- /util/fresh_until_today.py: -------------------------------------------------------------------------------- 1 | from db import get_hs300_tickers,get_ticker_info_by_id,save_ticker_into_db,get_last_date 2 | import datetime 3 | import tushare as ts 4 | import time 5 | 6 | 7 | if __name__ == '__main__': 8 | #hs300的id 9 | ticker_info = get_hs300_tickers() 10 | for i in range(len(ticker_info)): 11 | ticker = ticker_info[i] 
12 | ticker_id = ticker[1] 13 | ticker_name = ticker[2] 14 | vendor_id = i; 15 | 16 | print '--------------------- %s ---------------------' % vendor_id 17 | 18 | try: 19 | start_date = get_last_date(ticker_id)[0] 20 | start_date = str(start_date[0] + datetime.timedelta(days = 1))[0:10] 21 | except: 22 | start_date = '' 23 | 24 | 25 | ticker_data = get_ticker_info_by_id(ticker_id,start_date) 26 | 27 | print ('data_shape:' , ticker_data.shape) 28 | 29 | if ticker_data.shape[0] != 0: 30 | #存储 31 | try: 32 | save_ticker_into_db(ticker_id,ticker_data,vendor_id) 33 | except: 34 | print '数据有问题!' 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /back_test_system/strategy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # strategy.py 5 | 6 | from __future__ import print_function 7 | 8 | from abc import ABCMeta, abstractmethod 9 | import datetime 10 | try: 11 | import Queue as queue 12 | except ImportError: 13 | import queue 14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from event import SignalEvent 19 | 20 | 21 | class Strategy(object): 22 | """ 23 | Strategy is an abstract base class providing an interface for 24 | all subsequent (inherited) strategy handling objects. 25 | 26 | The goal of a (derived) Strategy object is to generate Signal 27 | objects for particular symbols based on the inputs of Bars 28 | (OHLCV) generated by a DataHandler object. 29 | 30 | This is designed to work both with historic and live data as 31 | the Strategy object is agnostic to where the data came from, 32 | since it obtains the bar tuples from a queue object. 33 | """ 34 | 35 | __metaclass__ = ABCMeta 36 | 37 | @abstractmethod 38 | def calculate_signals(self): 39 | """ 40 | Provides the mechanisms to calculate the list of signals. 
41 | """ 42 | raise NotImplementedError("Should implement calculate_signals()") 43 | -------------------------------------------------------------------------------- /get_data/get_300_names.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | #对沪深三百股票进行聚类,并画出关系图 3 | import tushare as ts 4 | import sqlalchemy as create_engine 5 | import MySQLdb as mdb 6 | import datetime 7 | 8 | 9 | 10 | #get hs300's names and insert them into database 11 | def get_hs300(con): 12 | now = datetime.datetime.utcnow() 13 | hs300 = ts.get_hs300s() 14 | column_str = """ticker, instrument, name, sector, currency, created_date, last_updated_date""" 15 | insert_str = ("%s, " * 7)[:-2] 16 | final_str = "INSERT INTO symbol (%s) VALUES (%s)" % (column_str, insert_str) 17 | symbols = [] 18 | 19 | for i in range(len(hs300)): 20 | t = hs300.ix[i] 21 | symbols.append( 22 | ( 23 | t['code'], 24 | 'stock', 25 | t['name'], 26 | '', 27 | 'RMB', 28 | now, 29 | now, 30 | ) 31 | ) 32 | cur = con.cursor() 33 | with con: 34 | cur = con.cursor() 35 | cur.executemany(final_str, symbols) 36 | print 'success insert hs300 into symbol!' 
37 | 38 | 39 | 40 | if __name__ == '__main__': 41 | db_host = 'localhost' 42 | db_user = 'root' 43 | db_password = '' 44 | db_name = 'securities_master' 45 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name) 46 | 47 | #get all 300 names and put them into database 48 | get_hs300(con); 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /back_test_system/to_csv.py: -------------------------------------------------------------------------------- 1 | import MySQLdb as mdb 2 | import csv 3 | import pandas as pd 4 | import numpy as np 5 | from pandas import DataFrame,Series 6 | import sys 7 | 8 | def get_tickers_from_db(): 9 | db_host = 'localhost' 10 | db_user = 'root' 11 | db_pd = '' 12 | db_name = 'securities_master' 13 | con = mdb.connect(host=db_host, user=db_user, passwd=db_pd, db=db_name) 14 | with con: 15 | cur = con.cursor() 16 | cur.execute('SELECT id,ticker FROM symbol'); 17 | return cur.fetchall() 18 | 19 | 20 | def get_one_ticker_by_id(ticker_id): 21 | db_host = 'localhost' 22 | db_user = 'root' 23 | db_pd = '' 24 | db_name = 'securities_master' 25 | con = mdb.connect(host=db_host, user=db_user, passwd=db_pd, db=db_name) 26 | with con: 27 | cur = con.cursor() 28 | cur.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume FROM daily_price WHERE symbol_id = %s' % ticker_id) 29 | return cur.fetchall() 30 | 31 | 32 | def render_csv(ticker_id): 33 | data = get_one_ticker_by_id(ticker_id) 34 | 35 | np_data = np.array(data) 36 | pd_data = DataFrame(np_data,columns = ['datetime', 'open', 'high', 'low', 'close', 'volume']) 37 | adj_close = Series([d[4] for d in data]) 38 | pd_data['adj_close'] = adj_close 39 | pd_data.to_csv('./data/%s.csv' % ticker_id) 40 | 41 | 42 | render_csv('600050') 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /util/fresh_us_until_today.py: 
-------------------------------------------------------------------------------- 1 | from db import get_us_tickers,get_us_ticker_from_db_by_id,save_us_ticker_into_db,get_us_last_date 2 | import datetime 3 | import db 4 | import tushare as ts 5 | import pandas as pd 6 | import time 7 | 8 | 9 | 10 | if __name__ == '__main__': 11 | #hs300的id 12 | ticker_info = get_us_tickers() 13 | for i in range(len(ticker_info)): 14 | ticker = ticker_info[i] 15 | ticker_id = ticker[1] 16 | ticker_name = ticker[2] 17 | vendor_id = i; 18 | 19 | 20 | print '--------------------- %s ---------------------' % vendor_id 21 | 22 | 23 | try: 24 | print '========= loading %s , %s ==========' % (i,ticker_id) 25 | start_date = get_us_last_date(ticker_id)[0][0] 26 | print '========= loading success ==========' 27 | start_date = start_date + datetime.timedelta(days = 1) 28 | 29 | except: 30 | start_date = '' 31 | 32 | print start_date , '============================================' 33 | 34 | # try: 35 | ticker_data = db.get_us_ticker_by_id(ticker_id,start_date) 36 | # except: 37 | # print 'get data fail...' 
38 | # ticker_data = pd.DataFrame([]) 39 | 40 | print ('data :' , ticker_data) 41 | 42 | if ticker_data.shape[0] != 0: 43 | #存储 44 | try: 45 | print '+++++++++++++ save %s , %s success +++++++++++++++' % (i,ticker_id) 46 | save_us_ticker_into_db(ticker_id,ticker_data,vendor_id) 47 | except: 48 | print(ticker_id, 'is error') 49 | 50 | print ticker_id,'==== finished =====' 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /deal_data/db.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import MySQLdb as mdb 3 | 4 | 5 | 6 | 7 | #get name 8 | def get_tickers_from_db(): 9 | db_host = 'localhost' 10 | db_user = 'root' 11 | db_password = '' 12 | db_name = 'securities_master' 13 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name) 14 | 15 | #get name form symbol; 16 | with con: 17 | cur = con.cursor() 18 | cur.execute('SELECT ticker,name FROM symbol') 19 | data = cur.fetchall() 20 | return data 21 | 22 | #获取当日成交量 23 | def get_day_volumn_33_66(day): 24 | db_host = 'localhost' 25 | db_user = 'root' 26 | db_password = '' 27 | db_name = 'securities_master' 28 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name) 29 | 30 | with con: 31 | cur = con.cursor() 32 | cur.execute('SELECT volume from daily_price where (price_date="%s")' % day) 33 | day_volume = cur.fetchall() 34 | day_volume = [int(day_volume[i][0]) for i in range(len(day_volume))] 35 | return day_volume 36 | 37 | 38 | #get data by tickerId 39 | def get_10_50_by_id(ticker_id): 40 | db_host = 'localhost' 41 | db_user = 'root' 42 | db_password = '' 43 | db_name = 'securities_master' 44 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name) 45 | 46 | with con: 47 | cur = con.cursor() 48 | cur.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume from daily_price where (symbol_id = %s) and (price_date BETWEEN 
"20150101" AND "20151231") and (high_price BETWEEN 10 and 50)' % ticker_id) 49 | daily_data = cur.fetchall() 50 | return daily_data 51 | 52 | 53 | -------------------------------------------------------------------------------- /back_test_system/self/event.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | class Event(object): 4 | pass 5 | 6 | class MarketEvent(Event): 7 | def __init__(self): 8 | self.type = 'MARKET' 9 | 10 | class SignalEvent(Event): 11 | def __init__(self,strategy_id,symbol,datetime,signal_type,strength): 12 | self.strategy_id = strategy_id 13 | self.type = 'SIGNAL' 14 | self.symbol = symbol 15 | self.datetime = datetime 16 | self.signal_type = signal_type 17 | self.strength = strength 18 | 19 | class OrderEvent(Event): 20 | def __init__(self,symbol,order_type,quantity,direction): 21 | self.type = 'ORDER' 22 | self.symbol = symbol 23 | self.order_type = order_type 24 | self.quantity = quantity 25 | self.direction = direction 26 | 27 | def print_order(self): 28 | print("Order: Symbol=%s,Type=%s,Quantity=%s,Direction=%s" %s (self.symbol,self.order_type,self.quantity,self.direction)) 29 | 30 | class FillEvent(Event): 31 | def __init__(self,timeindex,symbol,exchange,quantity,direction,fill_cost,commission=None): 32 | self.type = 'FILL' 33 | self.timeindex = timeindex 34 | self.symbol = symbol 35 | self.exchange = exchange 36 | self.quantity = quantity 37 | self.direction = direction 38 | self.fill_cost = fill_cost 39 | 40 | if commission is None: 41 | self.commission = self.calculate_ib_commission() 42 | else: 43 | self.commission = commission 44 | 45 | def calculate_ib_commission(self): 46 | full_cost = 1.3 47 | if self.quantity <= 500: 48 | full_cost = max(1.3,0.013 * self.quantity) 49 | else: 50 | full_cost = max(1.3,0.008 * self.quantity) 51 | return full_cost -------------------------------------------------------------------------------- 
/back_test_system/performance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # performance.py 5 | 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | 12 | def create_sharpe_ratio(returns, periods=252): 13 | """ 14 | Create the Sharpe ratio for the strategy, based on a 15 | benchmark of zero (i.e. no risk-free rate information). 16 | 17 | Parameters: 18 | returns - A pandas Series representing period percentage returns. 19 | periods - Daily (252), Hourly (252*6.5), Minutely(252*6.5*60) etc. 20 | """ 21 | return np.sqrt(periods) * (np.mean(returns)) / np.std(returns) 22 | 23 | 24 | def create_drawdowns(pnl): 25 | """ 26 | Calculate the largest peak-to-trough drawdown of the PnL curve 27 | as well as the duration of the drawdown. Requires that the 28 | pnl_returns is a pandas Series. 29 | 30 | Parameters: 31 | pnl - A pandas Series representing period percentage returns. 32 | 33 | Returns: 34 | drawdown, duration - Highest peak-to-trough drawdown and duration. 
35 | """ 36 | 37 | # Calculate the cumulative returns curve 38 | # and set up the High Water Mark 39 | hwm = [0] 40 | 41 | # Create the drawdown and duration series 42 | idx = pnl.index 43 | drawdown = pd.Series(index = idx) 44 | duration = pd.Series(index = idx) 45 | 46 | # Loop over the index range 47 | for t in range(1, len(idx)): 48 | hwm.append(max(hwm[t-1], pnl[t])) 49 | drawdown[t]= (hwm[t]-pnl[t]) 50 | duration[t]= (0 if drawdown[t] == 0 else duration[t-1]+1) 51 | return drawdown, drawdown.max(), duration.max() 52 | -------------------------------------------------------------------------------- /get_data/securities_master.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE `exchange` ( 2 | `id` int NOT NULL AUTO_INCREMENT, 3 | `abbrev` varchar(32) NOT NULL, 4 | `name` varchar(255) NOT NULL, 5 | `city` varchar(255) NULL, 6 | `country` varchar(255) NULL, 7 | `currency` varchar(64) NULL, 8 | `timezone_offset` time NULL, 9 | `created_date` datetime NOT NULL, 10 | `last_updated_date` datetime NOT NULL, 11 | PRIMARY KEY (`id`) 12 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 13 | 14 | CREATE TABLE `data_vendor` ( 15 | `id` int NOT NULL AUTO_INCREMENT, 16 | `name` varchar(64) NOT NULL, 17 | `website_url` varchar(255) NULL, 18 | `support_email` varchar(255) NULL, 19 | `created_date` datetime NOT NULL, 20 | `last_updated_date` datetime NOT NULL, 21 | PRIMARY KEY (`id`) 22 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 23 | 24 | CREATE TABLE `symbol` ( 25 | `id` int NOT NULL AUTO_INCREMENT, 26 | `exchange_id` int NULL, 27 | `ticker` varchar(32) NOT NULL, 28 | `instrument` varchar(64) NOT NULL, 29 | `name` varchar(255) NULL, 30 | `sector` varchar(255) NULL, 31 | `currency` varchar(32) NULL, 32 | `created_date` datetime NOT NULL, 33 | `last_updated_date` datetime NOT NULL, 34 | PRIMARY KEY (`id`), 35 | KEY `index_exchange_id` (`exchange_id`) 36 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT 
CHARSET=utf8; 37 | 38 | CREATE TABLE `daily_price` ( 39 | `id` int NOT NULL AUTO_INCREMENT, 40 | `data_vendor_id` int NOT NULL, 41 | `symbol_id` int NOT NULL, 42 | `price_date` datetime NOT NULL, 43 | `created_date` datetime NOT NULL, 44 | `last_updated_date` datetime NOT NULL, 45 | `open_price` decimal(19,4) NULL, 46 | `high_price` decimal(19,4) NULL, 47 | `low_price` decimal(19,4) NULL, 48 | `close_price` decimal(19,4) NULL, 49 | `adj_close_price` decimal(19,4) NULL, 50 | `volume` bigint NULL, 51 | PRIMARY KEY (`id`), 52 | KEY `index_data_vendor_id` (`data_vendor_id`), 53 | KEY `index_symbol_id` (`symbol_id`) 54 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; -------------------------------------------------------------------------------- /util/scrapy_etf.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import simplejson as json 3 | import pandas as pd 4 | 5 | # url = 'https://xueqiu.com/hq#exchange=US&plate=3_1_11&firstName=3&secondName=3_1&order=desc&orderby=marketcapital&page=1' 6 | url1 = 'https://xueqiu.com/snowman/login' 7 | 8 | data = { 9 | "remember_me":"true", 10 | "username":"13151998870", 11 | "password":"zhangyizhi112358" 12 | } 13 | headers = { 14 | "Accept":"application/json, text/javascript, */*; q=0.01", 15 | "Accept-Encoding":"gzip, deflate, sdch, br", 16 | "Accept-Language":"zh-CN,zh;q=0.8,en;q=0.6", 17 | "Cache-Control":"no-cache", 18 | "Connection":"keep-alive", 19 | "Host":"xueqiu.com", 20 | "Pragma":"no-cache", 21 | "Referer":"https://xueqiu.com/hq", 22 | "User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36", 23 | "X-Requested-With":"XMLHttpRequest", 24 | } 25 | 26 | res1 = requests.post(url1,data,headers=headers) 27 | 28 | target_data = [] 29 | pageNum = 1 30 | 31 | while True: 32 | url2 = 
'http://xueqiu.com/stock/cata/stocklist.json?page=%s&size=30&order=desc&orderby=marketCapital&exchange=US&plate=ETF&isdelay=1' % (pageNum) 33 | res2 = requests.get(url2, cookies=res1.cookies, headers=headers) 34 | stock_data = json.loads(res2.text)['stocks'] 35 | 36 | 37 | 38 | print 'pageNum: ', pageNum ,' =====================================' 39 | 40 | # print stock_data,'---------------------------' 41 | 42 | 43 | 44 | if len(stock_data) == 0 : 45 | break 46 | 47 | for i in range(len(stock_data)): 48 | stock = stock_data[i] 49 | stock_id = stock['code'].encode('utf-8') 50 | stock_name = stock['name'].encode('utf-8') 51 | print stock_id,stock_name 52 | target_data.append((stock_id,stock_name)) 53 | 54 | pageNum = pageNum + 1 55 | 56 | 57 | 58 | df = pd.DataFrame(target_data,columns=['Symbol','Name']) 59 | df.to_csv('etf.csv') 60 | 61 | 62 | 63 | print target_data 64 | 65 | -------------------------------------------------------------------------------- /back_test_system/self/mac.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | 4 | from backtest import Backtest 5 | from data import HistoricCSVDataHandler 6 | from event import SignalEvent 7 | from execution import SimulateExecutionHandler 8 | from portfolio import Portfolio 9 | from strategy import Strategy 10 | 11 | class MovingAverageCrossStrategy(Strategy): 12 | def __init__(self,bars,events,short_window=100,long_window=400): 13 | self.bars = bars 14 | self.symbol_list = self.bars.symbol_list 15 | self.events = events 16 | self.short_window = short_window 17 | self.long_window = long_window 18 | 19 | self.bought = self._calculate_initial_bought() 20 | 21 | def _calculate_initial_bought(self): 22 | bought = {} 23 | for s in self.symbol_list: 24 | bought[s] = 'OUT' 25 | return bought 26 | 27 | def calculate_signals(self,event): 28 | if event.type == 'MARKET': 29 | for symbol in self.symbol_list: 30 | bars = 
self.bars.get_latest_bars_values(symbol,'close',N=self.long_window) 31 | if bars is not None and bars != []: 32 | short_sma = np.mean(bars[-self.short_window:]) 33 | long_sma = np.mean(bars[-self.long_window:]) 34 | 35 | dt = self.bars.get_latest_bar_datetime(symbol) 36 | sig_dir = '' 37 | strength = 1.0 38 | strategy_id = 1 39 | 40 | if short_sma > long_sma and self.bought[symbol] == 'OUT': 41 | sig_dir = 'LONG' 42 | signal = SignalEvent(strategy_id,symbol,dt,sig_dir,strength) 43 | self.events.put(signal) 44 | self.bought[symbol] = 'LONG' 45 | elif short_sma < long_sma and self.bought[symbol] == 'LONG': 46 | sig_dir = 'EXIT' 47 | signal = SignalEvent( 48 | strategy_id,symbol,dt,sig_dir,strength) 49 | self.events.put(signal) 50 | self.bought[symbol] = 'OUT' 51 | 52 | if __name == '__main__': 53 | csv_dir = 'data' 54 | symbol_list = ['600050'] 55 | initial_capital = 100000.0 56 | start_date = datetime.datetime(1990,1,1,0,0,0) 57 | hearbeat = 0.0 58 | 59 | backtest = Backtest( 60 | csv_dir,symbol_list, 61 | initial_capital, 62 | hearbeat, 63 | start_date, 64 | HistoricCSVDataHandler, 65 | SimulateExecutionHandler, 66 | Portfolio, 67 | MovingAverageCrossStrategy 68 | ) 69 | backtest.simulate_trading() 70 | 71 | -------------------------------------------------------------------------------- /get_data/get_hist_data.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | #对沪深三百股票进行聚类,并画出关系图 3 | import tushare as ts 4 | import sqlalchemy as create_engine 5 | import MySQLdb as mdb 6 | import datetime 7 | 8 | 9 | 10 | def get_tickers_from_db(con): 11 | #get name form symbol; 12 | with con: 13 | cur = con.cursor() 14 | cur.execute('SELECT id,ticker FROM symbol') 15 | data = cur.fetchall() 16 | return [(d[0],d[1]) for d in data] 17 | 18 | def get_hist_data_from_tushare(ticker): 19 | start_date = '2000-1-1' 20 | end_date = str(datetime.today().timetuple()) 21 | 22 | get every detail data 23 | tData = 
ts.get_hist_data(ticker,start=start_date,end=end_date,retry_count=5,pause=1) 24 | return tData 25 | 26 | def into_db(tTicker,data,data_vendor_id,con): 27 | # Create the time now 28 | now = datetime.datetime.utcnow() 29 | # Create the insert strings 30 | column_str = """data_vendor_id, symbol_id, price_date, created_date, 31 | last_updated_date, open_price, high_price, low_price, 32 | close_price, volume, adj_close_price""" 33 | insert_str = ("%s, " * 11)[:-2] 34 | final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % (column_str, insert_str) 35 | daily_data = [] 36 | 37 | for i in range(len(data.index)): 38 | t_date = data.index[i] 39 | t_data = data.ix[t_date] 40 | daily_data.append( 41 | (data_vendor_id, tTicker, t_date, now, now,t_data['open'], t_data['high'] 42 | , t_data['low'], t_data['close'], t_data['volume'], 0) 43 | ) 44 | 45 | with con: 46 | cur = con.cursor() 47 | cur.executemany(final_str, daily_data) 48 | 49 | 50 | 51 | if __name__ == '__main__': 52 | db_host = 'localhost' 53 | db_user = 'root' 54 | db_password = '' 55 | db_name = 'securities_master' 56 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name) 57 | 58 | #get history data and put them into database 59 | tickers = get_tickers_from_db(con); 60 | 61 | #iterating every ticker and put theirs data into datbase 62 | for i in range(len(tickers)): 63 | t = tickers[i] 64 | tTicker = tickers[i][1] 65 | data = get_hist_data_from_tushare(tTicker) 66 | data_vendor_id = i 67 | print 'data_vendor_id : %s' % data_vendor_id 68 | print 'tTicker : %s' % tTicker 69 | print '%s of %s' % (i,len(tickers)) 70 | into_db(tTicker,data,data_vendor_id,con) 71 | 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /strategy_2_alpace/alpace.py: -------------------------------------------------------------------------------- 1 | 2 | #coding: utf-8 3 | import numpy as np 4 | import pandas as pd 5 | import datetime 6 | import sys 7 | 
sys.path.append('..') 8 | import util.db as db 9 | import pprint 10 | 11 | 12 | 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | day_range = 20 #计算周期 18 | round_days = 7 #执行周期 19 | average_days_70 = 6 * 10 #几日均线 20 | average_days_5 = 10 21 | average_days_30 = 20 22 | 23 | #找到volumn在33%-66%之间的股票池,20日平均交易量,并且 24 | stockers = db.get_us_middle33_volume(day_range,8,20) 25 | stocker_ids = stockers['id'] 26 | ticker_content = []; 27 | #需要计算的,方差,30周均线,计算周期趋势 28 | 29 | 30 | for i in range(len(stocker_ids)): 31 | stocker_id = stocker_ids[i] 32 | end_date = db.get_us_last_date(stocker_id)[0][0] 33 | start_date = end_date + datetime.timedelta(days = day_range * -1) 34 | stocker_data = db.get_us_ticker_from_db_by_id(stocker_id,start_date,end_date) 35 | profit = (stocker_data.loc[0].close - stocker_data.loc[len(stocker_data) - 1].close)/stocker_data.loc[len(stocker_data) - 1].close 36 | current_price = stocker_data.loc[0].close 37 | mean_price_70, std_price_70 = db.get_average_days_price_by_id(stocker_id,average_days_70) 38 | mean_price_30, std_price_30 = db.get_average_days_price_by_id(stocker_id,average_days_30) 39 | mean_price_5, std_price_5 = db.get_average_days_price_by_id(stocker_id,average_days_5) 40 | 41 | 42 | 43 | if (current_price > mean_price_30) and (abs(mean_price_30 - mean_price_70) < 1) : 44 | print (stocker_id,current_price,mean_price_30,mean_price_70,'===========================') 45 | ticker_content.append((stocker_id,current_price,profit,mean_price_5,mean_price_30,mean_price_70)) 46 | 47 | df = pd.DataFrame(ticker_content,columns=['id','price','profit','mean5','mean_price_30','mean_price_70']) 48 | df = df.sort(['profit'],ascending=False) 49 | df = df.reset_index() 50 | 51 | #选盈利前30%中的后40% 52 | # df_length_30per = int(df.shape[0] * 0.3) 53 | # print df_length_30per,df.shape,'------------' 54 | # best_30per = df[:df_length_30per] 55 | # df_length_40per = int(best_30per.shape[0] *0.8) 56 | # print df_length_40per,best_30per.shape,'=========' 57 | # 
best_30per_40per = best_30per[df_length_40per:] 58 | 59 | # print '~~~~~~~~~~~~~~~~~~~~~~~~~~' 60 | 61 | # pprint.pprint(best_30per_40per) 62 | # pprint.pprint(np.array(best_30per_40per['id'])) 63 | 64 | pprint.pprint(df) 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /back_test_system/execution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # execution.py 5 | 6 | from __future__ import print_function 7 | 8 | from abc import ABCMeta, abstractmethod 9 | import datetime 10 | try: 11 | import Queue as queue 12 | except ImportError: 13 | import queue 14 | 15 | from event import FillEvent, OrderEvent 16 | 17 | 18 | class ExecutionHandler(object): 19 | """ 20 | The ExecutionHandler abstract class handles the interaction 21 | between a set of order objects generated by a Portfolio and 22 | the ultimate set of Fill objects that actually occur in the 23 | market. 24 | 25 | The handlers can be used to subclass simulated brokerages 26 | or live brokerages, with identical interfaces. This allows 27 | strategies to be backtested in a very similar manner to the 28 | live trading engine. 29 | """ 30 | 31 | __metaclass__ = ABCMeta 32 | 33 | @abstractmethod 34 | def execute_order(self, event): 35 | """ 36 | Takes an Order event and executes it, producing 37 | a Fill event that gets placed onto the Events queue. 38 | 39 | Parameters: 40 | event - Contains an Event object with order information. 41 | """ 42 | raise NotImplementedError("Should implement execute_order()") 43 | 44 | 45 | class SimulatedExecutionHandler(ExecutionHandler): 46 | """ 47 | The simulated execution handler simply converts all order 48 | objects into their equivalent fill objects automatically 49 | without latency, slippage or fill-ratio issues. 
50 | 51 | This allows a straightforward "first go" test of any strategy, 52 | before implementation with a more sophisticated execution 53 | handler. 54 | """ 55 | 56 | def __init__(self, events): 57 | """ 58 | Initialises the handler, setting the event queues 59 | up internally. 60 | 61 | Parameters: 62 | events - The Queue of Event objects. 63 | """ 64 | self.events = events 65 | 66 | def execute_order(self, event): 67 | """ 68 | Simply converts Order objects into Fill objects naively, 69 | i.e. without any latency, slippage or fill ratio problems. 70 | 71 | Parameters: 72 | event - Contains an Event object with order information. 73 | """ 74 | if event.type == 'ORDER': 75 | fill_event = FillEvent( 76 | datetime.datetime.utcnow(), event.symbol, 77 | 'ARCA', event.quantity, event.direction, None 78 | ) 79 | self.events.put(fill_event) 80 | -------------------------------------------------------------------------------- /deal_data/get_r2.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import pprint 5 | import os 6 | path = os.getcwd() 7 | print path 8 | import FeatureUtils 9 | import forecast 10 | from pandas import Series,DataFrame 11 | 12 | #回归 13 | from sklearn.cross_validation import train_test_split 14 | from sklearn.metrics import classification_report 15 | from sklearn.svm import SVC 16 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars 17 | from sklearn.ensemble import RandomForestRegressor 18 | from sklearn.metrics import r2_score 19 | from sklearn import linear_model 20 | 21 | 22 | data = pd.read_csv(path + '/2/hs.csv',parse_dates=True,iterator=True) 23 | # data = pd.read_csv(path + '/2/hs.csv',parse_dates=True) 24 | data = pd.DataFrame(data.get_chunk(5000),dtype='|S6') 25 | # data = pd.DataFrame(data,dtype='|S6') 26 | 27 | 28 | 29 | def get_regression_r2(ticker_data): 30 | data_len = len(ticker_data) 31 | 
split_line = int(data_len * 0.8) 32 | X = ticker_data.drop('realY',1) 33 | y = ticker_data['realY'].dropna() 34 | 35 | X_train = X.ix[:split_line] 36 | X_test = X.ix[split_line:] 37 | y_train = y.ix[:split_line] 38 | y_test = y.ix[split_line:] 39 | 40 | models = [ 41 | # ('LR',LinearRegression()), 42 | # ('RidgeR',Ridge (alpha = 0.005)), 43 | # ('lasso',Lasso(alpha=0.00001)), 44 | # ('LassoLars',LassoLars(alpha=0.00001)), 45 | ('RandomForestRegression',RandomForestRegressor(1000))] 46 | 47 | best_r2 = ('',-10000000) 48 | 49 | for m in models: 50 | m[1].fit(np.array(X_train),np.array(y_train)) 51 | #因为index方面,pred出的其实是相当于往后挪了一位,跟原来的y_test是对不上的,所以x需要往前进一位 52 | #比较绕,所以从日期对应的方面去考虑 53 | pred = m[1].predict(X_test.shift(-1).fillna(0)) 54 | r2 = r2_score(y_test,pred) 55 | if r2 > best_r2[1]: 56 | best_r2 = (m[1],r2) 57 | # print "%s:\n%0.3f" % (m[0], r2_score(np.array(y_test),np.array(pred))) 58 | 59 | print 'the best regression is:',best_r2 60 | 61 | model = best_r2[0] 62 | model.fit(X_train, y_train) 63 | pred = model.predict(X_test.shift(-1).fillna(0)) 64 | pred_test = pd.Series(pred, index=y_test.index) 65 | 66 | fig = plt.figure() 67 | ax = fig.add_subplot(1,1,1) 68 | ax.plot(y_test,'r',lw=0.75,linestyle='-',label='realY') 69 | ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY') 70 | plt.legend(loc=2,prop={'size':9}) 71 | plt.setp(plt.gca().get_xticklabels(), rotation=30) 72 | plt.grid(True) 73 | plt.show() 74 | 75 | return best_r2 76 | 77 | 78 | X = data.drop(['realY','predictY'],1) 79 | y = data['realY'] 80 | 81 | #get importances of features 82 | # features = FeatureUtils.forestFindFeature(X,y,100)[:300] 83 | 84 | data = X.join(y) 85 | 86 | get_regression_r2(data) 87 | 88 | 89 | # print data -------------------------------------------------------------------------------- /deal_data/find_features_plot.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import pandas as pd 3 | import numpy as np 4 | from 
pandas import DataFrame,Series 5 | import matplotlib.pyplot as plt 6 | import tushare as ts 7 | import FeatureUtils 8 | from FeatureUtils import toDatetime 9 | from sklearn.cross_validation import train_test_split 10 | from sklearn.metrics import classification_report 11 | from sklearn.svm import SVC 12 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars 13 | from sklearn.metrics import r2_score 14 | from sklearn import linear_model 15 | import datetime 16 | 17 | 18 | 19 | #获取沪深300股指数据 20 | hs300 = DataFrame(ts.get_hist_data('hs300')) 21 | hs300 = toDatetime(hs300).dropna() 22 | 23 | #加入feature 24 | hs300 = FeatureUtils.CCI(hs300,10) 25 | hs300 = FeatureUtils.TL(hs300,10) 26 | hs300 = FeatureUtils.EVM(hs300,10) 27 | hs300 = FeatureUtils.SMA(hs300,10) 28 | hs300 = FeatureUtils.EWMA(hs300,10) 29 | hs300 = FeatureUtils.ROC(hs300,10) 30 | hs300 = FeatureUtils.ForceIndex(hs300,10) 31 | hs300 = FeatureUtils.BBANDS(hs300,10) 32 | hs300 = hs300.dropna() 33 | 34 | #归一化 35 | hs300_norm = (hs300 - hs300.mean())/(hs300.max() - hs300.min()) 36 | 37 | # # Build a classification task using 3 informative features 38 | X_real = DataFrame(hs300_norm.drop('close',1),dtype='|S6') 39 | X = X_real.shift(1).dropna() 40 | y_real = Series(hs300_norm['close'],dtype='|S6') 41 | y = y_real[:-1] 42 | 43 | 44 | #forest find feature 45 | features = FeatureUtils.forestFindFeature(X,y,100) 46 | 47 | 48 | X_real_F = DataFrame(hs300_norm[features[0:10]],dtype='|S6') 49 | X_F = X_real_F.shift(1).dropna() 50 | y_F = Series(y,dtype='|S6') 51 | 52 | 53 | 54 | d = datetime.datetime(2015,12,31) 55 | 56 | 57 | X_train = X_F[X.index < d] 58 | X_test = X_F[X.index >= d] 59 | y_train = y_F[y.index <= d] 60 | y_test = y_F[y.index > d] 61 | 62 | 63 | # Create the (parametrised) models 64 | print("Hit Rates/Confusion Matrices:\n") 65 | models = [ 66 | ('LR',LinearRegression()), 67 | ('RidgeR',Ridge (alpha = 0.005)), 68 | ('lasso',Lasso(alpha=0.0001)), 69 | 
('LassoLars',LassoLars(alpha=0.0001)) 70 | ] 71 | 72 | for m in models: 73 | m[1].fit(np.array(X_train),np.array(y_train)) 74 | pred = m[1].predict(X_test) 75 | print "%s:\n%0.3f" % (m[0], r2_score(y_test,pred)) 76 | 77 | 78 | model = Lasso(alpha=0.0001) 79 | model.fit(X_train, y_train) 80 | pred = model.predict(X_test) 81 | pred_test = pd.Series(pred, index=y_test.index) 82 | fig = plt.figure() 83 | ax = fig.add_subplot(1,1,1) 84 | ax.plot(y_test,'r',lw=0.75,linestyle='-',label='realY') 85 | ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY') 86 | plt.legend(loc=2,prop={'size':9}) 87 | plt.setp(plt.gca().get_xticklabels(), rotation=30) 88 | plt.grid(True) 89 | plt.show() -------------------------------------------------------------------------------- /get_data/find_mean_reversion.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pandas as pd 3 | import numpy as np 4 | from pandas import DataFrame,Series 5 | import MySQLdb as mdb 6 | import numpy as np 7 | import statsmodels.tsa.stattools as ts 8 | 9 | 10 | #get name 11 | def get_tickers_from_db(con): 12 | with con: 13 | cur = con.cursor() 14 | cur.execute('SELECT id,ticker,name FROM symbol') 15 | data = cur.fetchall() 16 | return [(d[0],d[1],d[2]) for d in data] 17 | 18 | #get data from 2010 to 2015 19 | def get_2010_2015(ticker_id,ticker_name,con): 20 | with con: 21 | cur = con.cursor() 22 | cur.execute('SELECT price_date,close_price from daily_price where (symbol_id = %s) and (price_date BETWEEN "20100101" AND "20151231")' % ticker_id) 23 | ticker_data = cur.fetchall() 24 | dates = np.array([d[0] for d in ticker_data]) 25 | t_data = np.array([d[1] for d in ticker_data]) 26 | ticker_data = np.array(t_data,dtype='float64') 27 | 28 | return ticker_data 29 | 30 | #hurst 31 | def hurst(ts): 32 | lags = range(2,100) 33 | tau = [np.sqrt(np.std(np.subtract(ts[lag:],ts[:-lag]))) for lag in lags] 34 | poly = np.polyfit(np.log(lags),np.log(tau),1) 35 
| return poly[0] * 2.0 36 | 37 | 38 | 39 | if __name__ == '__main__': 40 | #connect to db 41 | db_host = 'localhost' 42 | db_user = 'root' 43 | db_pass = '' 44 | db_name = 'securities_master' 45 | con = mdb.connect(db_host, db_user, db_pass, db_name) 46 | 47 | #get 300 names and id 48 | tickers = get_tickers_from_db(con) 49 | 50 | all_hurst_data = [] 51 | all_adf_data = [] 52 | 53 | #get data of 2010-2015 54 | for i in range(len(tickers)): 55 | ticker = tickers[i] 56 | ticker_id = ticker[1] 57 | ticker_name = ticker[2] 58 | ticker_data = get_2010_2015(ticker_id,ticker_name,con) 59 | 60 | 61 | #数量太少报maxlag should be < nobs 62 | if ticker_data.shape[0] < 100: 63 | continue 64 | 65 | print '========================================' 66 | print '========================================' 67 | print '========================================' 68 | #hust 69 | t_hurst = hurst(ticker_data) 70 | if t_hurst < 0.3: 71 | all_hurst_data.append((ticker_id,ticker_name)) 72 | print 'Hurst %s : %s' % (ticker_name,t_hurst) 73 | 74 | #adf test 75 | t_adf = ts.adfuller(ticker_data,1) 76 | if t_adf[0] < t_adf[4]['5%']: 77 | all_adf_data.append((ticker_id,ticker_name)) 78 | print 'ADF test %s : %s' % (ticker_name,t_adf) 79 | print '========================================' 80 | print '========================================' 81 | print '========================================' 82 | 83 | 84 | 85 | 86 | print all_hurst_data 87 | print all_adf_data 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /deal_data/deal_hs300.py: -------------------------------------------------------------------------------- 1 | # 股票价格的预测分类/回归模型 2 | # 1. price between 10-50 3 | # 2. 沪深三百内的 4 | # 3. 
average daily volume (ADV) in the middle 33 percentile 5 | 6 | import db 7 | from db import get_10_50_by_id,get_tickers_from_db,get_day_volumn_33_66 8 | import datetime 9 | import time 10 | import pandas as pd 11 | import numpy as np 12 | from pandas import DataFrame,Series 13 | import matplotlib.pyplot as plt 14 | import forecast 15 | 16 | 17 | def get_33_and_66_volumn(date_list): 18 | date_list = [str(date_list[i])[:10] for i in range(len(date_list))] 19 | date_obj = {} 20 | for i in range(len(date_list)): 21 | day = date_list[i] 22 | day_volume = np.mean(np.array(get_day_volumn_33_66(day))) 23 | if np.isnan(day_volume) : 24 | continue 25 | t33 = int(day_volume * 0.33) 26 | t66 = int(day_volume * 0.66) 27 | date_obj[day] = (t33,t66) 28 | print day,day_volume 29 | return date_obj 30 | 31 | 32 | tickers = get_tickers_from_db() 33 | new_date_list = pd.date_range('1/1/2015', '12/31/2015', freq='1D') 34 | daily_volumn = get_33_and_66_volumn(new_date_list) 35 | selected_data = [] 36 | best_performer = [] 37 | 38 | 39 | 40 | for i in range(len(tickers)): 41 | this_ticker = tickers[i] 42 | ticker_id = this_ticker[0] 43 | ticker_name = this_ticker[1] 44 | data = get_10_50_by_id(ticker_id); 45 | g_data = [] 46 | index_list = [] 47 | for i in range(len(data)): 48 | date = data[i][0] 49 | date = str(date)[0:10] 50 | volume = data[i][5] 51 | volume_range = range(daily_volumn[date][0],daily_volumn[date][1]) 52 | if volume > daily_volumn[date][0] and volume < daily_volumn[date][1]: 53 | g_data.append([data[i][1],data[i][2],data[i][3],data[i][4],data[i][5]]) 54 | index_list.append(data[i][0]) 55 | if len(g_data) < 50: 56 | continue 57 | t_ticker = DataFrame(g_data,index=index_list,dtype='float64',columns=['open','high','low','close','volume']) 58 | t_ticker = t_ticker.reindex(new_date_list,method='ffill').fillna(method='bfill') 59 | selected_data.append(t_ticker) 60 | #得到表现最好的5个feature下的数据 61 | t_ticker = forecast.get_good_feature(t_ticker) 62 | best_regression_r2 = 
forecast.get_regression_r2(t_ticker) 63 | best_classification_r2 = forecast.get_classification_r2(t_ticker) 64 | 65 | #取最好的前20个 66 | if best_regression_r2 > 0.85: 67 | this_ticker_node = (best_regression_r2,best_classification_r2,ticker_id) 68 | if len(best_performer) < 20 : 69 | best_performer.append(this_ticker_node) 70 | best_performer = sorted(best_performer) 71 | else: 72 | theLast_reg_r2 = best_performer[0][0] 73 | if best_regression_r2 > theLast_reg_r2 : 74 | best_performer = best_performer[1:] 75 | best_performer.append(this_ticker_node) 76 | 77 | print best_performer 78 | print '可利用的数据:',len(selected_data) 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /back_test_system/self/backtest.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import datetime 6 | import pprint 7 | try: 8 | import Queue as queue 9 | except ImportError: 10 | import queue 11 | import time 12 | 13 | class Backtest(object): 14 | def __init__(self,csv_dir,symbol_list,initial_capital,hearbeat,start_date,data_handler,execution_handler,portfolio,strategy): 15 | self.csv_dir = csv_dir 16 | self.symbol_list = symbol_list 17 | self.initial_capital = initial_capital 18 | self.hearbeat = hearbeat 19 | self.start_date = start_date 20 | 21 | self.data_handler_cls = data_handler 22 | self.execution_handler_cls = execution_handler 23 | self.portfolio_cls = portfolio 24 | self.strategy_cls = strategy 25 | 26 | self.events = queue.Queue() 27 | 28 | self.signals = 0 29 | self.orders = 0 30 | self.fills = 0 31 | self.num_strats = 1 32 | self._generate_trading_instances() 33 | 34 | def _generate_trading_instances(self): 35 | print("Creating DataHandler, Strategy, Portfolio and ExecutionHandler") 36 | 37 | self.data_handler = self.data_handler_cls(self.events,self.csv_dir,self.symbol_list) 38 | self.strategy = 
self.strategy_cls(self.data_handler,self.events) 39 | self.portfolio = self.portfolio_cls(self.data_handler,self.events,self.start_date,self.initial_capital) 40 | self.execution_handler = self.execution_handler_cls(self.events) 41 | 42 | def _run_backtest(self): 43 | i = 0 44 | while True: 45 | i += 1 46 | print (i) 47 | if self.data_handler.continue_backtest == True: 48 | self.data_handler.update_bars() 49 | else: 50 | break 51 | 52 | while True: 53 | try: 54 | event = self.events.get(False) 55 | except queue.Empty: 56 | break 57 | else: 58 | if event is not None: 59 | if event.type == 'MARKET': 60 | self.strategy.calculate_signals(event) 61 | self.Portfolio.update_timeindex(event) 62 | elif event.type == 'SIGNAL': 63 | self.signals += 1 64 | self.portfolio.update_signal(event) 65 | elif event.type == 'ORDER': 66 | self.orders += 1 67 | self.execution_handler.execute_order(event) 68 | elif event.type == 'FILL': 69 | self.fills += 1 70 | self.portfolio.updatee_fill(event) 71 | time.sleep(self.hearbeat) 72 | 73 | def _output_performance(self): 74 | self.Portfolio.create_equity_curve_dataframe() 75 | print ('Creating summary stats...') 76 | stats = self.Portfolio.output_summary_stats() 77 | 78 | print ('Creating equity curve...') 79 | print (self.Portfolio.equity_curve.tail(10)) 80 | pprint.pprint(stats) 81 | 82 | print ('Signal: %s' % self.signals) 83 | print ('Orders: %s' % self.orders) 84 | print ('Fills: %s' % self.fills) 85 | 86 | def simulate_trading(self): 87 | self._run_backtest() 88 | self._output_performance() -------------------------------------------------------------------------------- /util/ml_util.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import numpy as np 3 | import pandas as pd 4 | from pandas import DataFrame,Series 5 | import datetime 6 | import matplotlib.pyplot as plt 7 | import Feature_utils 8 | #回归 9 | from sklearn.cross_validation import train_test_split 10 | from 
sklearn.metrics import classification_report 11 | from sklearn.svm import SVC 12 | from sklearn.linear_model import Lasso,LinearRegression,Ridge,LassoLars 13 | from sklearn.metrics import r2_score 14 | from sklearn import linear_model 15 | #分类 16 | from sklearn.ensemble import RandomForestClassifier 17 | from sklearn.linear_model import LogisticRegression 18 | from sklearn.lda import LDA 19 | from sklearn.metrics import confusion_matrix 20 | from sklearn.qda import QDA 21 | from sklearn.svm import LinearSVC, SVC 22 | 23 | 24 | def get_regression_r2(ticker_data): 25 | data_len = len(ticker_data) 26 | split_line = int(data_len * 0.2) 27 | target_X = ticker_data.drop('close',1)[0:10] 28 | target_Y = ticker_data['close'][0:10] 29 | X = ticker_data.drop('close',1)[10:].dropna() 30 | y = ticker_data['close'].shift(10).dropna() 31 | 32 | 33 | # X = ticker_data.drop('close',1).shift(-10).dropna() 34 | # y = ticker_data['close'][:-10].dropna() 35 | 36 | # X = ticker_data.drop('close',1).dropna() 37 | # y = ticker_data['close'].dropna() 38 | 39 | X_test = X.ix[:split_line] 40 | X_train = X.ix[split_line:] 41 | y_test = y.ix[:split_line] 42 | y_train = y.ix[split_line:] 43 | 44 | print target_X 45 | 46 | models = [ 47 | ('LR',LinearRegression()), 48 | ('RidgeR',Ridge (alpha = 0.005)), 49 | ('lasso',Lasso(alpha=0.00001)), 50 | ('LassoLars',LassoLars(alpha=0.00001))] 51 | 52 | best_r2 = (models[0][1],0) 53 | for m in models: 54 | m[1].fit(np.array(X_train),np.array(y_train)) 55 | pred = m[1].predict(X_test.fillna(0)) 56 | r2 = r2_score(y_test,pred) 57 | if r2 > best_r2[1]: 58 | best_r2 = (m[1],r2) 59 | # print "%s:\n%0.3f" % (m[0], r2_score(np.array(y_test),np.array(pred))) 60 | 61 | print 'the best regression is:',best_r2 62 | 63 | model = best_r2[0] 64 | model.fit(X_train, y_train) 65 | pred = model.predict(target_X.fillna(0)) 66 | pred_test = pd.Series(pred, index=target_X.index) 67 | 68 | fig = plt.figure() 69 | ax = fig.add_subplot(1,1,1) 70 | 
ax.plot(target_Y,'r',lw=0.75,linestyle='-',label='realY') 71 | ax.plot(pred_test,'b',lw=0.75,linestyle='-',label='predY') 72 | 73 | plt.legend(loc=2,prop={'size':9}) 74 | plt.setp(plt.gca().get_xticklabels(), rotation=30) 75 | plt.grid(True) 76 | plt.show() 77 | 78 | return best_r2 79 | 80 | 81 | def get_classification_r2(ticker_data): 82 | 83 | 84 | data_len = len(ticker_data) 85 | split_line = int(data_len * 0.8) 86 | X = ticker_data.drop('close',1)[:-1] 87 | y = Series(ticker_data['close'].shift(-1).dropna(),dtype='|S6') 88 | 89 | X_train = X.ix[:split_line] 90 | X_test = X.ix[split_line:] 91 | y_train = y.ix[:split_line] 92 | y_test = y.ix[split_line:] 93 | 94 | 95 | models = [("LR", LogisticRegression()), 96 | ("LDA", LDA()), 97 | ("LSVC", LinearSVC()), 98 | ("RSVM", SVC( 99 | C=1000000.0, cache_size=200, class_weight=None, 100 | coef0=0.0, degree=3, gamma=0.0001, kernel='rbf', 101 | max_iter=-1, probability=False, random_state=None, 102 | shrinking=True, tol=0.001, verbose=False) 103 | ), 104 | ("RF", RandomForestClassifier( 105 | n_estimators=1000, criterion='gini', 106 | max_depth=None, min_samples_split=2, 107 | min_samples_leaf=1, max_features='auto', 108 | bootstrap=True, oob_score=False, n_jobs=1, 109 | random_state=None, verbose=0) 110 | )] 111 | 112 | best = (0,0) 113 | for m in models: 114 | m[1].fit(X_train, y_train) 115 | pred = m[1].predict(X_test) 116 | name = m[0] 117 | score = m[1].score(X_test, y_test) 118 | if score > best[1]: 119 | best = (name,score) 120 | print 'the best cluster is:' , best 121 | return best 122 | 123 | 124 | -------------------------------------------------------------------------------- /back_test_system/mac.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | 4 | from backtest import Backtest 5 | from data import HistoricCSVDataHandler 6 | from event import SignalEvent 7 | from execution import SimulatedExecutionHandler 8 | from portfolio 
import datetime
import numpy as np

from backtest import Backtest
from data import HistoricCSVDataHandler
from event import SignalEvent
from execution import SimulatedExecutionHandler
from portfolio import Portfolio
from strategy import Strategy


class MovingAverageCrossStrategy(Strategy):
    """
    Carries out a basic Moving Average Crossover strategy with a
    short/long simple weighted moving average. Default short/long
    windows are 100/400 periods respectively.
    """

    def __init__(self, bars, events, short_window=100, long_window=400):
        """
        Initialises the buy and hold strategy.

        Parameters:
        bars - The DataHandler object that provides bar information
        events - The Event Queue object.
        short_window - The short moving average lookback.
        long_window - The long moving average lookback.
        """
        self.bars = bars
        self.symbol_list = self.bars.symbol_list
        self.events = events
        self.short_window = short_window
        self.long_window = long_window

        # Set to True if a symbol is in the market
        self.bought = self._calculate_initial_bought()

    def _calculate_initial_bought(self):
        """
        Adds keys to the bought dictionary for all symbols
        and sets them to 'OUT'.
        """
        bought = {}
        for s in self.symbol_list:
            bought[s] = 'OUT'
        return bought

    def calculate_signals(self, event):
        """
        Generates a new set of signals based on the MAC
        SMA with the short window crossing the long window
        meaning a long entry and vice versa for a short entry.

        Parameters
        event - A MarketEvent object.
        """
        if event.type == 'MARKET':
            for symbol in self.symbol_list:
                bars = self.bars.get_latest_bars_values(symbol, "close", N=self.long_window)

                # FIX: `bars != []` on a numpy array does an element-wise
                # comparison (ambiguous truth value / deprecated behaviour);
                # test emptiness explicitly by length instead.
                if bars is not None and len(bars) > 0:
                    short_sma = np.mean(bars[-self.short_window:])
                    long_sma = np.mean(bars[-self.long_window:])

                    dt = self.bars.get_latest_bar_datetime(symbol)
                    sig_dir = ""
                    strength = 1.0
                    strategy_id = 1

                    if short_sma > long_sma and self.bought[symbol] == "OUT":
                        sig_dir = 'LONG'
                        signal = SignalEvent(strategy_id, symbol, dt, sig_dir, strength)
                        self.events.put(signal)
                        self.bought[symbol] = 'LONG'

                    elif short_sma < long_sma and self.bought[symbol] == "LONG":
                        sig_dir = 'EXIT'
                        signal = SignalEvent(strategy_id, symbol, dt, sig_dir, strength)
                        self.events.put(signal)
                        self.bought[symbol] = 'OUT'


if __name__ == "__main__":
    csv_dir = 'data'
    symbol_list = ['600050']
    initial_capital = 100000.0
    start_date = datetime.datetime(1990, 1, 1, 0, 0, 0)
    heartbeat = 0.0

    backtest = Backtest(csv_dir,
                        symbol_list,
                        initial_capital,
                        heartbeat,
                        start_date,
                        HistoricCSVDataHandler,
                        SimulatedExecutionHandler,
                        Portfolio,
                        MovingAverageCrossStrategy)

    backtest.simulate_trading()

# ---- deal_data/FeatureUtils.py ----
import pandas as pd

import numpy as np
from pandas import DataFrame, Series
import matplotlib.pyplot as plt
import tushare as ts
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier
# (the indicator functions of this file follow in the next chunk)
# CCI: Commodity Channel Index over an ndays window.
def CCI(data, ndays):
    """Append a 'CCI' column to `data` and return the joined DataFrame."""
    TP = (data['high'] + data['low'] + data['close']) / 3
    # FIX: pd.rolling_mean / pd.rolling_std were removed from pandas;
    # use the .rolling() accessor (same window semantics, ddof=1 std).
    CCI = pd.Series((TP - TP.rolling(ndays).mean()) / (0.015 * TP.rolling(ndays).std()),
                    name='CCI')
    data = data.join(CCI)
    return data


# timeLag: open-to-open change scaled by the ndays-period high/low range.
def TL(data, ndays):
    """Append a 'TL' column; requires a DatetimeIndex for the resample."""
    index = data.index
    pH = data['high'].resample(str(ndays) + 'D').max().reindex(index).fillna(method='bfill')
    # FIX: the period low was computed with .max() (copy/paste from pH);
    # the low of the window needs .min().
    pL = data['low'].resample(str(ndays) + 'D').min().reindex(index).fillna(method='bfill')
    pO = data['open'] - data['open'].shift(1)
    timeLag = pO / (pH - pL)
    timeLag.name = 'TL'
    data = data.join(timeLag)
    return data


# Ease of Movement
def EVM(data, ndays):
    """Append an 'EVM' column (ndays moving average of ease-of-movement)."""
    dm = ((data['high'] + data['low']) / 2) - ((data['high'].shift(1) + data['low'].shift(1)) / 2)
    br = (data['volume'] / 100000000) / ((data['high'] - data['low']))
    EVM = dm / br
    # FIX: pd.rolling_mean removed -> .rolling().mean()
    EVM_MA = pd.Series(EVM.rolling(ndays).mean(), name='EVM')
    data = data.join(EVM_MA)
    return data


# Simple Moving Average
def SMA(data, ndays):
    """Append an 'SMA' column (ndays simple moving average of close)."""
    # FIX: pd.rolling_mean removed -> .rolling().mean()
    SMA = pd.Series(data['close'].rolling(ndays).mean(), name='SMA')
    data = data.join(SMA)
    return data


# Exponentially-weighted Moving Average
def EWMA(data, ndays):
    """Append an 'EWMA_<ndays>' column (span=ndays EWMA of close)."""
    # FIX: pd.ewma removed -> .ewm(...).mean()
    EMA = pd.Series(data['close'].ewm(span=ndays, min_periods=ndays - 1).mean(),
                    name='EWMA_' + str(ndays))
    data = data.join(EMA)
    return data


# Rate of Change (ROC)
def ROC(data, n):
    """Append a 'Rate of Change' column: close.diff(n) / close.shift(n)."""
    N = data['close'].diff(n)
    D = data['close'].shift(n)
    ROC = pd.Series(N / D, name='Rate of Change')
    data = data.join(ROC)
    return data


# Force Index
def ForceIndex(data, ndays):
    """Append a 'ForceIndex' column: close.diff(ndays) * volume."""
    FI = pd.Series(data['close'].diff(ndays) * data['volume'], name='ForceIndex')
    data = data.join(FI)
    return data


# Compute the Bollinger Bands
def BBANDS(data, ndays):
    """Append 'Upper BollingerBand' / 'Lower BollingerBand' (MA +/- 2*SD)."""
    # FIX: pd.rolling_mean / pd.rolling_std removed -> .rolling() accessor.
    MA = pd.Series(data['close'].rolling(ndays).mean())
    SD = pd.Series(data['close'].rolling(ndays).std())
    b1 = MA + (2 * SD)
    B1 = pd.Series(b1, name='Upper BollingerBand')
    b2 = MA - (2 * SD)
    B2 = pd.Series(b2, name='Lower BollingerBand')
    data = data.join([B1, B2])
    return data


def plotTwoData(data1, data2):
    """Plot a price series above an indicator series in one figure."""
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(2, 1, 1)
    ax.set_xticklabels([])
    plt.plot(data1, lw=1)
    plt.title(str(data1.name) + 'Price Chart')
    plt.ylabel('Close Price')
    plt.grid(True)
    bx = fig.add_subplot(2, 1, 2)
    plt.plot(data2, 'k', lw=0.75, linestyle='-', label='CCI')
    plt.legend(loc=2, prop={'size': 9.5})
    plt.ylabel(str(data2.name) + 'values')
    plt.grid(True)
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.show()


def toDatetime(data):
    """Convert the frame's index to a DatetimeIndex in place and return it."""
    data.index = pd.to_datetime(data.index)
    return data


def forestFindFeature(X, y, n):
    """Rank the columns of X by ExtraTrees feature importance.

    Returns the column names ordered from most to least important.
    """
    # Build a forest and compute the feature importances
    forest = ExtraTreesClassifier(n_estimators=n, random_state=0)
    forest.fit(X, y)
    importances = forest.feature_importances_
    indices = np.argsort(importances)[::-1]

    # Print the feature ranking
    print("Feature ranking:")

    x_columns = X.columns
    features = []
    for f in range(X.shape[1]):
        features.append(x_columns[int(indices[f])])
    return features

# ---- back_test_system/self/data.py ----
# (module header and the DataHandler classes follow in the next chunk)
# ---- back_test_system/self/data.py ----
# NOTE: the original file opens with `from __future__ import print_function`,
# which is only legal at the very top of a file; noted here, not repeated.
from abc import ABCMeta, abstractmethod

import datetime
import os, os.path

import numpy as np
import pandas as pd

from event import MarketEvent


class DataHandler(object):
    """Abstract interface every market-data handler must implement."""
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_latest_bar(self, symbol):
        raise NotImplementedError("Should implement get_latest_bar()")

    @abstractmethod
    def get_latest_bars(self, symbol, N=1):
        raise NotImplementedError("Should implement get_latest_bars()")

    @abstractmethod
    def get_latest_bar_datetime(self, symbol):
        raise NotImplementedError("Should implement get_latest_bar_datetime()")

    @abstractmethod
    def get_latest_bar_value(self, symbol, val_type):
        raise NotImplementedError("Should implement get_latest_bar_value()")

    @abstractmethod
    def get_latest_bars_value(self, symbol, val_type, N=1):
        # NOTE(review): declared singular here but implemented (and called)
        # as get_latest_bars_values in the concrete class — confirm naming.
        raise NotImplementedError("Should implement get_latest_bars_values()")

    # NOTE(review): update_bars lacks @abstractmethod in the original; kept
    # as-is so subclass instantiation behaviour does not change.
    def update_bars(self):
        raise NotImplementedError("Should implement update_bars()")


class HistoricCSVDataHandler(DataHandler):
    """DataHandler that replays '<symbol>.csv' files bar by bar."""

    def __init__(self, events, csv_dir, symbol_list):
        """
        Parameters:
        events - The Event Queue to push MarketEvents onto.
        csv_dir - Directory containing one '<symbol>.csv' per symbol.
        symbol_list - Symbols to load.
        """
        self.events = events
        self.csv_dir = csv_dir
        self.symbol_list = symbol_list

        self.symbol_data = {}         # symbol -> row iterator
        self.latest_symbol_data = {}  # symbol -> list of bars seen so far
        self.continue_backtest = True
        self.bar_index = 0

        self._open_convert_csv_files()

    def _open_convert_csv_files(self):
        """Load every CSV, align all symbols on a combined index, pad gaps."""
        comb_index = None
        for s in self.symbol_list:
            # FIX: DataFrame.sort() was removed from pandas; sort_index() is
            # the equivalent for ordering by the datetime index.
            self.symbol_data[s] = pd.io.parsers.read_csv(
                os.path.join(self.csv_dir, '%s.csv' % s),
                header=0, index_col=0, parse_dates=True,
                names=['datetime', 'open', 'high', 'low', 'close', 'volume', 'adj_close']
            ).sort_index()
            if comb_index is None:
                comb_index = self.symbol_data[s].index
            else:
                # FIX: Index.union returns a new index — the original discarded
                # the result, so symbols after the first never extended the
                # combined calendar.
                comb_index = comb_index.union(self.symbol_data[s].index)
            self.latest_symbol_data[s] = []

        for s in self.symbol_list:
            self.symbol_data[s] = self.symbol_data[s].reindex(index=comb_index, method='pad').iterrows()

    def _get_new_bar(self, symbol):
        """Yield the next bar (datetime, row) for `symbol`."""
        for b in self.symbol_data[symbol]:
            yield b

    def get_latest_bar(self, symbol):
        """Return the most recent bar seen for `symbol`."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-1]

    def get_latest_bars(self, symbol, N=1):
        """Return the last N bars seen for `symbol`."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print('That symbol is not available in the historical data set.')
            raise
        else:
            return bars_list[-N:]

    def get_latest_bar_datetime(self, symbol):
        """Return the datetime of the most recent bar for `symbol`."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return bars_list[-1][0]

    def get_latest_bar_value(self, symbol, val_type):
        """Return one field (e.g. 'close') of the most recent bar."""
        try:
            bars_list = self.latest_symbol_data[symbol]
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return getattr(bars_list[-1][1], val_type)

    def get_latest_bars_values(self, symbol, val_type, N=1):
        """Return a numpy array of one field over the last N bars."""
        try:
            bars_list = self.get_latest_bars(symbol, N)
        except KeyError:
            print("That symbol is not available in the historical data set.")
            raise
        else:
            return np.array([getattr(b[1], val_type) for b in bars_list])

    def update_bars(self):
        """Advance every symbol by one bar and emit a MarketEvent."""
        for s in self.symbol_list:
            try:
                bar = next(self._get_new_bar(s))
            except StopIteration:
                self.continue_backtest = False
            else:
                if bar is not None:
                    self.latest_symbol_data[s].append(bar)
        # assumes the event is emitted once per update, after all symbols
        # advanced — indentation is ambiguous in this dump; TODO confirm
        self.events.put(MarketEvent())

# ---- deal_data/forecast.py ----
# (module header continues in the next chunk)
# ---- deal_data/forecast.py ----
# coding:utf-8
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
import datetime
import matplotlib.pyplot as plt
import FeatureUtils
# regression
# NOTE(review): sklearn.cross_validation / sklearn.lda / sklearn.qda were
# removed in modern scikit-learn (model_selection / discriminant_analysis).
from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report
from sklearn.svm import SVC
from sklearn.linear_model import Lasso, LinearRegression, Ridge, LassoLars
from sklearn.metrics import r2_score
from sklearn import linear_model
# classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.lda import LDA
from sklearn.metrics import confusion_matrix
from sklearn.qda import QDA
from sklearn.svm import LinearSVC, SVC


def get_good_feature(ticker_data):
    """Build indicator features, normalize, and keep the 11 most important.

    Returns the normalized frame restricted to those features plus 'close'.
    """
    ticker_data = FeatureUtils.CCI(ticker_data, 10)
    ticker_data = FeatureUtils.TL(ticker_data, 10)
    ticker_data = FeatureUtils.EVM(ticker_data, 10)
    ticker_data = FeatureUtils.SMA(ticker_data, 10)
    ticker_data = FeatureUtils.EWMA(ticker_data, 10)
    ticker_data = FeatureUtils.ROC(ticker_data, 10)
    ticker_data = FeatureUtils.ForceIndex(ticker_data, 10)
    ticker_data = FeatureUtils.BBANDS(ticker_data, 10)
    ticker_data = ticker_data.dropna()
    # normalisation (mean-centred min-max scaling)
    ticker_data = (ticker_data - ticker_data.mean()) / (ticker_data.max() - ticker_data.min())
    # today's features...
    X = DataFrame(ticker_data.drop('close', 1).fillna(0), dtype='float64')[:-1]
    # ...labelled with the NEXT day's close (shift(-1) pulls tomorrow onto today's row)
    y = Series(ticker_data['close'].shift(-1).dropna(), dtype='|S6')
    # forest find the best 11 features
    features = FeatureUtils.forestFindFeature(X, y, 100)[:11]

    ticker_data = ticker_data[features].join(ticker_data['close'])
    return ticker_data


def get_regression_r2(ticker_data):
    """Fit several linear regressors, plot the best, return (name, r2)."""
    data_len = len(ticker_data)
    split_line = int(data_len * 0.8)
    X = ticker_data.drop('close', 1)[:-1]
    y = ticker_data['close'].shift(-1).dropna()

    # FIX: .ix was removed from pandas; these splits are positional -> .iloc.
    X_train = X.iloc[:split_line]
    X_test = X.iloc[split_line:]
    y_train = y.iloc[:split_line]
    y_test = y.iloc[split_line:]

    models = [
        ('LR', LinearRegression()),
        ('RidgeR', Ridge(alpha=0.005)),
        ('lasso', Lasso(alpha=0.00001)),
        ('LassoLars', LassoLars(alpha=0.00001))]

    # FIX: the original kept only (name, r2) in best_r2 and then called
    # best_r2[0].fit(...) — i.e. it tried to fit a *string* (or the int 0),
    # crashing whenever that code ran. Track the winning estimator object
    # alongside its name and score.
    best_name, best_model, best_score = models[0][0], models[0][1], 0
    for name, model in models:
        model.fit(np.array(X_train), np.array(y_train))
        # The prediction index is effectively one day behind y_test, so the
        # test features are advanced by one row before predicting.
        pred = model.predict(X_test.shift(-1).fillna(0))
        r2 = r2_score(y_test, pred)
        if r2 > best_score:
            best_name, best_model, best_score = name, model, r2

    best_r2 = (best_name, best_score)
    print('the best regression is:', best_r2)

    model = best_model
    model.fit(X_train, y_train)
    pred = model.predict(X_test.shift(-1).fillna(0))
    pred_test = pd.Series(pred, index=y_test.index)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(y_test, 'r', lw=0.75, linestyle='-', label='realY')
    ax.plot(pred_test, 'b', lw=0.75, linestyle='-', label='predY')
    plt.legend(loc=2, prop={'size': 9})
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.grid(True)
    plt.show()

    return best_r2


def get_classification_r2(ticker_data):
    """Fit several classifiers on next-day close labels; return (name, score)."""
    data_len = len(ticker_data)
    split_line = int(data_len * 0.8)
    X = ticker_data.drop('close', 1)[:-1]
    y = Series(ticker_data['close'].shift(-1).dropna(), dtype='|S6')

    # FIX: .ix removed from pandas -> positional .iloc.
    X_train = X.iloc[:split_line]
    X_test = X.iloc[split_line:]
    y_train = y.iloc[:split_line]
    y_test = y.iloc[split_line:]

    models = [("LR", LogisticRegression()),
              ("LDA", LDA()),
              # ("QDA", QDA()),
              ("LSVC", LinearSVC()),
              ("RSVM", SVC(
                  C=1000000.0, cache_size=200, class_weight=None,
                  coef0=0.0, degree=3, gamma=0.0001, kernel='rbf',
                  max_iter=-1, probability=False, random_state=None,
                  shrinking=True, tol=0.001, verbose=False)
               ),
              ("RF", RandomForestClassifier(
                  n_estimators=1000, criterion='gini',
                  max_depth=None, min_samples_split=2,
                  min_samples_leaf=1, max_features='auto',
                  bootstrap=True, oob_score=False, n_jobs=1,
                  random_state=None, verbose=0)
               )]

    best = (0, 0)
    for name, model in models:
        model.fit(X_train, y_train)
        model.predict(X_test)  # predictions computed in original; only score used
        score = model.score(X_test, y_test)
        if score > best[1]:
            best = (name, score)
    print('the best cluster is:', best)
    return best

# ---- util/Feature_utils.py ----
# coding: utf-8
import pandas as pd
import numpy as np
from pandas import DataFrame, Series
import matplotlib.pyplot as plt
import tushare as ts
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier


# CCI: Commodity Channel Index
def CCI(data, ndays):
    """Append a 'CCI' column over an ndays window."""
    TP = (data['high'] + data['low'] + data['close']) / 3
    # FIX: pd.rolling_mean / pd.rolling_std removed -> .rolling() accessor.
    CCI = pd.Series((TP - TP.rolling(ndays).mean()) / (0.015 * TP.rolling(ndays).std()),
                    name='CCI')
    data = data.join(CCI)
    return data


# timeLag
def TL(data, ndays):
    """Append a 'TL' column; requires a DatetimeIndex for the resample."""
    index = data.index
    pH = data['high'].resample(str(ndays) + 'D').max().reindex(index).fillna(method='bfill')
    # FIX: period low computed with .max() (copy/paste from pH) -> .min().
    pL = data['low'].resample(str(ndays) + 'D').min().reindex(index).fillna(method='bfill')
    pO = data['open'] - data['open'].shift(1)
    timeLag = pO / (pH - pL)
    timeLag.name = 'TL'
    data = data.join(timeLag)
    return data


# Ease of Movement
def EVM(data, ndays):
    """Append an 'EVM' column (ndays moving average of ease-of-movement)."""
    dm = ((data['high'] + data['low']) / 2) - ((data['high'].shift(1) + data['low'].shift(1)) / 2)
    br = (data['volume'] / 100000000) / ((data['high'] - data['low']))
    EVM = dm / br
    # FIX: pd.rolling_mean removed -> .rolling().mean()
    EVM_MA = pd.Series(EVM.rolling(ndays).mean(), name='EVM')
    data = data.join(EVM_MA)
    return data


# Simple Moving Average
def SMA(data, ndays):
    """Append an 'SMA' column (ndays simple moving average of close)."""
    # FIX: pd.rolling_mean removed -> .rolling().mean()
    SMA = pd.Series(data['close'].rolling(ndays).mean(), name='SMA')
    data = data.join(SMA)
    return data

# (EWMA and the remaining helpers of this file follow in the next chunk)
# Exponentially-weighted Moving Average
def EWMA(data, ndays):
    """Append an 'EWMA_<ndays>' column (span=ndays EWMA of close)."""
    # FIX: pd.ewma was removed from pandas -> .ewm(...).mean()
    EMA = pd.Series(data['close'].ewm(span=ndays, min_periods=ndays - 1).mean(),
                    name='EWMA_' + str(ndays))
    data = data.join(EMA)
    return data


# Rate of Change (ROC)
def ROC(data, n):
    """Append a 'Rate of Change' column: close.diff(n) / close.shift(n)."""
    N = data['close'].diff(n)
    D = data['close'].shift(n)
    ROC = pd.Series(N / D, name='Rate of Change')
    data = data.join(ROC)
    return data


# Force Index
def ForceIndex(data, ndays):
    """Append a 'ForceIndex' column: close.diff(ndays) * volume."""
    FI = pd.Series(data['close'].diff(ndays) * data['volume'], name='ForceIndex')
    data = data.join(FI)
    return data


# Compute the Bollinger Bands
def BBANDS(data, ndays):
    """Append 'Upper BollingerBand' / 'Lower BollingerBand' (MA +/- 2*SD)."""
    # FIX: pd.rolling_mean / pd.rolling_std removed -> .rolling() accessor.
    MA = pd.Series(data['close'].rolling(ndays).mean())
    SD = pd.Series(data['close'].rolling(ndays).std())
    b1 = MA + (2 * SD)
    B1 = pd.Series(b1, name='Upper BollingerBand')
    b2 = MA - (2 * SD)
    B2 = pd.Series(b2, name='Lower BollingerBand')
    data = data.join([B1, B2])
    return data


def plotTwoData(data1, data2):
    """Plot a price series above an indicator series in one figure."""
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(2, 1, 1)
    ax.set_xticklabels([])
    plt.plot(data1, lw=1)
    plt.title(str(data1.name) + 'Price Chart')
    plt.ylabel('Close Price')
    plt.grid(True)
    bx = fig.add_subplot(2, 1, 2)
    plt.plot(data2, 'k', lw=0.75, linestyle='-', label='CCI')
    plt.legend(loc=2, prop={'size': 9.5})
    plt.ylabel(str(data2.name) + 'values')
    plt.grid(True)
    plt.setp(plt.gca().get_xticklabels(), rotation=30)
    plt.show()


def toDatetime(data):
    """Convert the frame's index to a DatetimeIndex in place and return it."""
    data.index = pd.to_datetime(data.index)
    return data


def forestFindFeature(X, y, n):
    """Rank the columns of X by ExtraTrees feature importance.

    Returns the column names ordered from most to least important.
    """
    # Build a forest and compute the feature importances
    forest = ExtraTreesClassifier(n_estimators=n, random_state=0)
    forest.fit(X, y)
    importances = forest.feature_importances_
    std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
    indices = np.argsort(importances)[::-1]

    # Print the feature ranking
    print("Feature ranking:")

    x_columns = X.columns
    features = []
    for f in range(X.shape[1]):
        features.append(x_columns[int(indices[f])])
        # FIX: py2 print statement -> print() function, consistent with the
        # repo's modules that use `from __future__ import print_function`.
        print(f, indices[f], x_columns[int(indices[f])], '===========', importances[indices[f]])
    return features


# Find the best n features (找到最好的几个feature)
def get_good_feature(ticker_data, n):
    """Build indicator features, normalize, keep the n most important."""
    ticker_data = CCI(ticker_data, 10)
    ticker_data = TL(ticker_data, 10)
    ticker_data = EVM(ticker_data, 10)
    ticker_data = SMA(ticker_data, 10)
    ticker_data = EWMA(ticker_data, 10)
    ticker_data = ROC(ticker_data, 10)
    ticker_data = ForceIndex(ticker_data, 10)
    ticker_data = BBANDS(ticker_data, 10)
    print(ticker_data)

    ticker_data = ticker_data.fillna(method='ffill').fillna(method='bfill').dropna()

    # normalisation (mean-centred min-max scaling)
    ticker_data = (ticker_data - ticker_data.mean()) / (ticker_data.max() - ticker_data.min())
    # NOTE(review): unlike forecast.get_good_feature, the label here is the
    # SAME day's close (no shift) — confirm this is intended.
    X = DataFrame(ticker_data.drop('close', 1).fillna(0), dtype='float64')
    y = Series(ticker_data['close'].dropna(), dtype='|S6')

    # forest find the best n features
    features = forestFindFeature(X, y, 100)[:n]

    ticker_data = ticker_data[features].join(ticker_data['close'])
    return ticker_data

# ---- back_test_system/backtest.py ----
#!/usr/bin/python
# -*- coding: utf-8 -*-

# backtest.py
# NOTE: the original opens with `from __future__ import print_function`,
# only legal at the very top of a file; noted here, not repeated.

import datetime
import pprint
try:
    import Queue as queue
except ImportError:
    import queue
import time

# (class Backtest begins in the next chunk)
21 | """ 22 | 23 | def __init__( 24 | self, csv_dir, symbol_list, initial_capital, 25 | heartbeat, start_date, data_handler, 26 | execution_handler, portfolio, strategy 27 | ): 28 | """ 29 | Initialises the backtest. 30 | 31 | Parameters: 32 | csv_dir - The hard root to the CSV data directory. 33 | symbol_list - The list of symbol strings. 34 | intial_capital - The starting capital for the portfolio. 35 | heartbeat - Backtest "heartbeat" in seconds 36 | start_date - The start datetime of the strategy. 37 | data_handler - (Class) Handles the market data feed. 38 | execution_handler - (Class) Handles the orders/fills for trades. 39 | portfolio - (Class) Keeps track of portfolio current and prior positions. 40 | strategy - (Class) Generates signals based on market data. 41 | """ 42 | self.csv_dir = csv_dir 43 | self.symbol_list = symbol_list 44 | self.initial_capital = initial_capital 45 | self.heartbeat = heartbeat 46 | self.start_date = start_date 47 | 48 | self.data_handler_cls = data_handler 49 | self.execution_handler_cls = execution_handler 50 | self.portfolio_cls = portfolio 51 | self.strategy_cls = strategy 52 | 53 | self.events = queue.Queue() 54 | 55 | self.signals = 0 56 | self.orders = 0 57 | self.fills = 0 58 | self.num_strats = 1 59 | 60 | self._generate_trading_instances() 61 | 62 | def _generate_trading_instances(self): 63 | """ 64 | Generates the trading instance objects from 65 | their class types. 66 | """ 67 | print( 68 | "Creating DataHandler, Strategy, Portfolio and ExecutionHandler" 69 | ) 70 | self.data_handler = self.data_handler_cls(self.events, self.csv_dir, self.symbol_list) 71 | self.strategy = self.strategy_cls(self.data_handler, self.events) 72 | self.portfolio = self.portfolio_cls(self.data_handler, self.events, self.start_date, 73 | self.initial_capital) 74 | self.execution_handler = self.execution_handler_cls(self.events) 75 | 76 | def _run_backtest(self): 77 | """ 78 | Executes the backtest. 
79 | """ 80 | i = 0 81 | while True: 82 | i += 1 83 | print(i) 84 | # Update the market bars 85 | if self.data_handler.continue_backtest == True: 86 | self.data_handler.update_bars() 87 | else: 88 | break 89 | 90 | # Handle the events 91 | while True: 92 | try: 93 | event = self.events.get(False) 94 | except queue.Empty: 95 | break 96 | else: 97 | if event is not None: 98 | if event.type == 'MARKET': 99 | self.strategy.calculate_signals(event) 100 | self.portfolio.update_timeindex(event) 101 | 102 | elif event.type == 'SIGNAL': 103 | self.signals += 1 104 | self.portfolio.update_signal(event) 105 | 106 | elif event.type == 'ORDER': 107 | self.orders += 1 108 | self.execution_handler.execute_order(event) 109 | 110 | elif event.type == 'FILL': 111 | self.fills += 1 112 | self.portfolio.update_fill(event) 113 | 114 | time.sleep(self.heartbeat) 115 | 116 | def _output_performance(self): 117 | """ 118 | Outputs the strategy performance from the backtest. 119 | """ 120 | self.portfolio.create_equity_curve_dataframe() 121 | 122 | print("Creating summary stats...") 123 | stats = self.portfolio.output_summary_stats() 124 | 125 | print("Creating equity curve...") 126 | print(self.portfolio.equity_curve.tail(10)) 127 | pprint.pprint(stats) 128 | 129 | print("Signals: %s" % self.signals) 130 | print("Orders: %s" % self.orders) 131 | print("Fills: %s" % self.fills) 132 | 133 | def simulate_trading(self): 134 | """ 135 | Simulates the backtest and outputs portfolio performance. 
136 | """ 137 | self._run_backtest() 138 | self._output_performance() 139 | -------------------------------------------------------------------------------- /get_data/find_pairs.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import matplotlib.pyplot as plt 3 | import MySQLdb as mdb 4 | import datetime 5 | import numpy as np 6 | from pandas import Series,DataFrame 7 | import matplotlib.dates as mdates 8 | import pandas as pd 9 | from matplotlib.collections import LineCollection 10 | from sklearn import cluster, covariance, manifold 11 | from pandas.stats.api import ols 12 | import statsmodels.tsa.stattools as ts 13 | import sys 14 | reload(sys) # Python2.5 初始化后会删除 sys.setdefaultencoding 这个方法,我们需要重新载入 15 | sys.setdefaultencoding('utf-8') 16 | 17 | #找到有协整关系的pairs 18 | 19 | def plot_price_series(df, ts1, ts2): 20 | months = mdates.MonthLocator() # every month 21 | fig, ax = plt.subplots() 22 | ax.plot(df.index, df[ts1], label=ts1) 23 | ax.plot(df.index, df[ts2], label=ts2) 24 | ax.xaxis.set_major_locator(months) 25 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) 26 | ax.set_xlim(datetime.datetime(2016, 5, 1), datetime.datetime(2016, 10, 1)) 27 | ax.grid(True) 28 | fig.autofmt_xdate() 29 | 30 | plt.xlabel('Month/Year') 31 | plt.ylabel('Price ($)') 32 | plt.title('%s and %s Daily Prices' % (ts1, ts2)) 33 | plt.legend() 34 | plt.show() 35 | 36 | def plot_scatter_series(df, ts1, ts2): 37 | plt.xlabel('%s Price ($)' % ts1) 38 | plt.ylabel('%s Price ($)' % ts2) 39 | plt.title('%s and %s Price Scatterplot' % (ts1, ts2)) 40 | plt.scatter(df[ts1], df[ts2]) 41 | plt.show() 42 | 43 | 44 | #get name 45 | def get_tickers_from_db(con): 46 | #get name form symbol; 47 | with con: 48 | cur = con.cursor() 49 | cur.execute('SELECT id,ticker,name FROM symbol') 50 | data = cur.fetchall() 51 | return [(d[0],d[1],d[2]) for d in data] 52 | 53 | #get daily data of 300 54 | def 
get_daily_data_from_db(ticker,ticker_name,ticker_id,date_list,con): 55 | with con: 56 | cur = con.cursor() 57 | cur.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume from daily_price where symbol_id = %s' % ticker_id) 58 | daily_data = cur.fetchall() 59 | dates = np.array([d[0] for d in daily_data]) 60 | #此处以受益为基准 61 | # open_price = np.array([d[2] for d in daily_data],dtype='float64') 62 | close_price = np.array([d[4] for d in daily_data],dtype='float64') 63 | # var_price = close_price - open_price 64 | # daily_data = DataFrame(var_price,index=dates,columns=[ticker_id]) 65 | 66 | daily_data = DataFrame(close_price,index=dates,columns=[ticker_name]) 67 | daily_data = daily_data.reindex(date_list,method='ffill') 68 | return daily_data 69 | 70 | 71 | #dealing data with two-pair way to calculating 72 | def deal_data(whole_data): 73 | finall_pair = [] 74 | print '正在计算... ' 75 | 76 | for i in range(len(whole_data)): 77 | for r in range(i,len(whole_data)): 78 | 79 | if i == r: 80 | continue 81 | 82 | 83 | d1 = whole_data[i].fillna(method='pad').fillna(0) 84 | d1_name = d1.columns[0] 85 | d2 = whole_data[r].fillna(method='pad').fillna(0) 86 | d2_name = d2.columns[0] 87 | df = pd.concat([d1,d2],axis=1) 88 | 89 | res = ols(y=d1[d1_name], x=d2[d2_name]) 90 | beta_hr = res.beta.x 91 | df["res"] = df[d1_name] - beta_hr*df[d2_name] 92 | cadf = ts.adfuller(df["res"]) 93 | 94 | #judge比较cadf那俩值的大小 95 | cadf1 = cadf[0] 96 | cadf2 = cadf[4]['5%'] 97 | #这样更明显 98 | if cadf1 - cadf2 < -4 : 99 | print '得到结果=======>>>>>>',i,r,cadf1,cadf2,d1_name,d2_name 100 | finall_pair.append((i,r)) 101 | test(i,r) 102 | 103 | return finall_pair 104 | 105 | #处理最终pair数据 106 | def deal_with_fianl_data(whole,pairs): 107 | names = []; 108 | final = []; 109 | for i in range(len(whole)): 110 | name = whole[i].columns[0] 111 | names.append(name) 112 | for r in range(len(pairs)): 113 | t = pairs[r] 114 | name1 = names[t[0]] 115 | name2 = names[t[1]] 116 | final.append((name1,name2)) 
117 | return final 118 | 119 | def test(d1,d2): 120 | d1 = whole_data[d1].fillna(method='pad').fillna(0) 121 | d1_name = d1.columns[0] 122 | d2 = whole_data[d2].fillna(method='pad').fillna(0) 123 | d2_name = d2.columns[0] 124 | df = pd.concat([d1,d2],axis=1) 125 | plot_price_series(df, d1.columns[0], d2.columns[0]) 126 | plot_scatter_series(df,d1.columns[0], d2.columns[0]) 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | db_host = 'localhost' 132 | db_user = 'root' 133 | db_password = '' 134 | db_name = 'securities_master' 135 | con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name) 136 | date_list = pd.date_range('5/1/2016', '10/1/2016', freq='1D') 137 | tickers = get_tickers_from_db(con) 138 | whole_data = [] 139 | 140 | for i in range(len(tickers)): 141 | ticker = tickers[i] 142 | ticker_name = ticker[2] 143 | ticker_id = ticker[1] 144 | daily_data = get_daily_data_from_db(ticker,ticker_name,ticker_id,date_list,con) 145 | whole_data.append(daily_data) 146 | 147 | #dealing data with two-pair way to calculating 148 | final_pair = deal_data(whole_data) 149 | 150 | #整理最终结果,因为finall-pair返回的是序号 151 | pairs = deal_with_fianl_data(whole_data,final_pair) 152 | 153 | 154 | print pairs 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /get_data/clustering_hs300.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import matplotlib.pyplot as plt 3 | import MySQLdb as mdb 4 | import datetime 5 | import numpy as np 6 | from pandas import Series,DataFrame 7 | import pandas as pd 8 | from matplotlib.collections import LineCollection 9 | from sklearn import cluster, covariance, manifold 10 | import sys 11 | reload(sys) # Python2.5 初始化后会删除 sys.setdefaultencoding 这个方法,我们需要重新载入 12 | sys.setdefaultencoding('utf-8') 13 | 14 | 15 | 16 | 17 | #get name 18 | def get_tickers_from_db(con): 19 | #get name form symbol; 20 | with con: 21 | cur = 
# (continuation of get_data/clustering_hs300.py)

def get_tickers_from_db(con):
    """Return [(id, ticker, name), ...] rows from the symbol table."""
    with con:
        cur = con.cursor()
        cur.execute('SELECT id,ticker,name FROM symbol')
        data = cur.fetchall()
        return [(d[0], d[1], d[2]) for d in data]


def get_timeline_from_db(ticker, ticker_name, ticker_id, date_list, con):
    """Load one symbol's daily close-minus-open variation aligned to date_list."""
    with con:
        cur = con.cursor()
        cur.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume from daily_price where symbol_id = %s' % ticker_id)
        daily_data = cur.fetchall()
        dates = np.array([d[0] for d in daily_data])
        open_price = np.array([d[2] for d in daily_data], dtype='float64')
        close_price = np.array([d[4] for d in daily_data], dtype='float64')
        var_price = close_price - open_price
        daily_data = DataFrame(var_price, index=dates, columns=[ticker_name])
        daily_data = daily_data.reindex(date_list, method='ffill')

        return daily_data


def deal_with_data(whole_data):
    """Concat per-symbol series, forward-fill gaps, keep the last 150 rows."""
    # concat
    final = pd.concat(whole_data, axis=1)
    # fix data
    final = final.fillna(method='ffill')

    # trimmed to 150 rows because the clustering below is slow
    # FIX: .ix was removed from pandas; this slice is positional -> .iloc.
    final = final.iloc[-150:]

    return final


def cluster_data(data):
    """Affinity-propagation clustering of the symbols plus a 2-D embedding plot."""
    names = data.columns
    # NOTE(review): GraphLassoCV was renamed GraphicalLassoCV in modern
    # scikit-learn, and plt.cm.spectral was removed from matplotlib — this
    # function runs only on the old library versions this repo targets.
    edge_model = covariance.GraphLassoCV()
    data = np.array(data)

    # standardise each series before fitting the sparse covariance model
    X = data.copy().T
    X /= X.std(axis=0)

    edge_model.fit(X)
    _, labels = cluster.affinity_propagation(edge_model.covariance_)
    n_labels = labels.max()

    for i in range(n_labels + 1):
        print('Cluster %i: %s' % ((i + 1), ', '.join(names[labels == i])))

    # Visualization: 2-D embedding of the symbols for plotting
    node_position_model = manifold.LocallyLinearEmbedding(n_components=2, eigen_solver='dense', n_neighbors=6)
    embedding = node_position_model.fit_transform(X.T).T
    plt.figure(1, facecolor='w', figsize=(10, 8))
    plt.clf()
    ax = plt.axes([0., 0., 1., 1.])
    plt.axis('off')

    # Display a graph of the partial correlations
    partial_correlations = edge_model.precision_.copy()
    d = 1 / np.sqrt(np.diag(partial_correlations))
    partial_correlations *= d
    partial_correlations *= d[:, np.newaxis]
    non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)

    # Plot the nodes using the coordinates of our embedding
    plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels, cmap=plt.cm.spectral)

    # Plot the edges
    start_idx, end_idx = np.where(non_zero)
    # a sequence of (*line0*, *line1*, *line2*), where::
    #     linen = (x0, y0), (x1, y1), ... (xm, ym)
    segments = [[embedding[:, start], embedding[:, stop]] for start, stop in zip(start_idx, end_idx)]
    values = np.abs(partial_correlations[non_zero])
    lc = LineCollection(segments, zorder=0, cmap=plt.cm.hot_r, norm=plt.Normalize(0, .7 * values.max()))
    lc.set_array(values)
    lc.set_linewidths(15 * values)
    ax.add_collection(lc)

    # Add a label to each node. The challenge here is that we want to
    # position the labels to avoid overlap with other labels
    for index, (name, label, (x, y)) in enumerate(zip(names, labels, embedding.T)):
        # NOTE(review): Python-2-only bytes round-trip; a no-op under py2.
        name = str(name).decode('utf-8').encode('utf-8')
        dx = x - embedding[0]
        dx[index] = 1
        dy = y - embedding[1]
        dy[index] = 1
        this_dx = dx[np.argmin(np.abs(dy))]
        this_dy = dy[np.argmin(np.abs(dx))]
        if this_dx > 0:
            horizontalalignment = 'left'
            x = x + .002
        else:
            horizontalalignment = 'right'
            x = x - .002
        if this_dy > 0:
            verticalalignment = 'bottom'
            y = y + .002
        else:
            verticalalignment = 'top'
            y = y - .002
        plt.text(x, y, name, size=10, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment, bbox=dict(facecolor='w', edgecolor=plt.cm.spectral(label / float(n_labels)), alpha=.6))

    plt.xlim(embedding[0].min() - .15 * embedding[0].ptp(),
             embedding[0].max() + .10 * embedding[0].ptp(),)
    plt.ylim(embedding[1].min() - .03 * embedding[1].ptp(),
             embedding[1].max() + .03 * embedding[1].ptp())
    plt.show()


if __name__ == '__main__':
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'securities_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    whole_data = []
    # all tickers
    tickers = get_tickers_from_db(con)
    # shortened date range because the computation is slow
    date_list = pd.date_range('1/1/2016', '10/1/2016', freq='1D')

    for i in range(len(tickers)):
        ticker = tickers[i]
        ticker_name = ticker[2]
        ticker_id = ticker[1]
        daily_data = get_timeline_from_db(ticker, ticker_name, ticker_id, date_list, con)
        whole_data.append(daily_data)

    final_data = deal_with_data(whole_data)
    # cluster data
    cluster_data(final_data)

# ---- back_test_system/event.py ----
#!/usr/bin/python
# -*- coding: utf-8 -*-

# event.py
# NOTE: the original opens with `from __future__ import print_function`,
# only legal at the very top of a file; noted here, not repeated.


class Event(object):
    """
    Event is base class providing an interface for all subsequent
    (inherited) events, that will trigger further events in the
    trading infrastructure.
    """
    pass


class MarketEvent(Event):
    """
    Handles the event of receiving a new market update with
    corresponding bars.
    """

    def __init__(self):
        """
        Initialises the MarketEvent.
        """
        self.type = 'MARKET'


class SignalEvent(Event):
    """
    Handles the event of sending a Signal from a Strategy object.
    This is received by a Portfolio object and acted upon.
    """

    def __init__(self, strategy_id, symbol, datetime, signal_type, strength):
        """
        Initialises the SignalEvent.
        """
        # NOTE(review): the remainder of this constructor lies beyond this
        # chunk of the dump and is not reproduced here.
40 | 41 | Parameters: 42 | strategy_id - The unique ID of the strategy sending the signal. 43 | symbol - The ticker symbol, e.g. 'GOOG'. 44 | datetime - The timestamp at which the signal was generated. 45 | signal_type - 'LONG' or 'SHORT'. 46 | strength - An adjustment factor "suggestion" used to scale 47 | quantity at the portfolio level. Useful for pairs strategies. 48 | """ 49 | self.strategy_id = strategy_id 50 | self.type = 'SIGNAL' 51 | self.symbol = symbol 52 | self.datetime = datetime 53 | self.signal_type = signal_type 54 | self.strength = strength 55 | 56 | 57 | class OrderEvent(Event): 58 | """ 59 | Handles the event of sending an Order to an execution system. 60 | The order contains a symbol (e.g. GOOG), a type (market or limit), 61 | quantity and a direction. 62 | """ 63 | 64 | def __init__(self, symbol, order_type, quantity, direction): 65 | """ 66 | Initialises the order type, setting whether it is 67 | a Market order ('MKT') or Limit order ('LMT'), has 68 | a quantity (integral) and its direction ('BUY' or 69 | 'SELL'). 70 | 71 | TODO: Must handle error checking here to obtain 72 | rational orders (i.e. no negative quantities etc). 73 | 74 | Parameters: 75 | symbol - The instrument to trade. 76 | order_type - 'MKT' or 'LMT' for Market or Limit. 77 | quantity - Non-negative integer for quantity. 78 | direction - 'BUY' or 'SELL' for long or short. 79 | """ 80 | self.type = 'ORDER' 81 | self.symbol = symbol 82 | self.order_type = order_type 83 | self.quantity = quantity 84 | self.direction = direction 85 | 86 | def print_order(self): 87 | """ 88 | Outputs the values within the Order. 89 | """ 90 | print( 91 | "Order: Symbol=%s, Type=%s, Quantity=%s, Direction=%s" % 92 | (self.symbol, self.order_type, self.quantity, self.direction) 93 | ) 94 | 95 | 96 | class FillEvent(Event): 97 | """ 98 | Encapsulates the notion of a Filled Order, as returned 99 | from a brokerage. Stores the quantity of an instrument 100 | actually filled and at what price. 
In addition, stores 101 | the commission of the trade from the brokerage. 102 | 103 | TODO: Currently does not support filling positions at 104 | different prices. This will be simulated by averaging 105 | the cost. 106 | """ 107 | 108 | def __init__(self, timeindex, symbol, exchange, quantity, 109 | direction, fill_cost, commission=None): 110 | """ 111 | Initialises the FillEvent object. Sets the symbol, exchange, 112 | quantity, direction, cost of fill and an optional 113 | commission. 114 | 115 | If commission is not provided, the Fill object will 116 | calculate it based on the trade size and Interactive 117 | Brokers fees. 118 | 119 | Parameters: 120 | timeindex - The bar-resolution when the order was filled. 121 | symbol - The instrument which was filled. 122 | exchange - The exchange where the order was filled. 123 | quantity - The filled quantity. 124 | direction - The direction of fill ('BUY' or 'SELL') 125 | fill_cost - The holdings value in dollars. 126 | commission - An optional commission sent from IB. 127 | """ 128 | self.type = 'FILL' 129 | self.timeindex = timeindex 130 | self.symbol = symbol 131 | self.exchange = exchange 132 | self.quantity = quantity 133 | self.direction = direction 134 | self.fill_cost = fill_cost 135 | 136 | # Calculate commission 137 | if commission is None: 138 | self.commission = self.calculate_ib_commission() 139 | else: 140 | self.commission = commission 141 | 142 | def calculate_ib_commission(self): 143 | """ 144 | Calculates the fees of trading based on an Interactive 145 | Brokers fee structure for API, in USD. 146 | 147 | This does not include exchange or ECN fees. 
148 | 149 | Based on "US API Directed Orders": 150 | https://www.interactivebrokers.com/en/index.php?f=commission&p=stocks2 151 | """ 152 | full_cost = 1.3 153 | if self.quantity <= 500: 154 | full_cost = max(1.3, 0.013 * self.quantity) 155 | else: # Greater than 500 156 | full_cost = max(1.3, 0.008 * self.quantity) 157 | return full_cost 158 | -------------------------------------------------------------------------------- /back_test_system/self/portfolio.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import datetime 4 | from math import floor 5 | try: 6 | import Queue as queue 7 | except ImportError: 8 | import queue 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from event import FillEvent,OrderEvent 13 | from performance import create_sharpe_ratio,create_drawdowns 14 | 15 | 16 | class Portfolio(object): 17 | def __init__(self,bars,event,start_date,intial_capital=100000.0): 18 | self.bars = bars 19 | self.events = events 20 | self.symbol_list = self.bars.symbol_list 21 | self.start_date = start_date 22 | self.intial_capital = intial_capital 23 | self.all_positions = self.construct_all_positions() 24 | self.current_positions = dict((k,v) for k,v in [(s,0) for s in self.symbol_list]) 25 | self.all_holdings = self.construct_all_holdings() 26 | self.current_holdings = self.construct_current_holdings() 27 | 28 | def construct_all_positions(self): 29 | d = dict((k,v) for k,v in [(s,0) for s in self.symbol_list]) 30 | d['datetime'] = self.start_date 31 | return [d] 32 | 33 | def construct_all_holdings(self): 34 | d = dict((k,v) for k,v in [(s,0.0) for s in self.symbol_list]) 35 | d['datetime'] = self.start_date 36 | d['cash'] = self.intial_capital 37 | d['commission'] = 0.0 38 | d['total'] = self.intial_capital 39 | return [d] 40 | 41 | def construct_current_holdings(self): 42 | d = dict((k,v) for k,v in [(s,0.0) for s in self.symbol_list]) 43 | d['cash'] = 
self.intial_capital 44 | d['commission'] = 0.0 45 | d['total'] = self.intial_capital 46 | return d 47 | 48 | def update_timeindex(self,event): 49 | latest_datetime = self.bars.get_latest_bar_datetime(self.symbol_list[0]) 50 | 51 | # Update positions 52 | # ================ 53 | dp = dict((k,v) for k,v in [(s,0) for s in self.symbol_list]) 54 | dp['datetime'] = latest_datetime 55 | 56 | for s in self.symbol_list: 57 | dp[s] = self.current_positions[s] 58 | 59 | self.all_positions.append(dp) 60 | 61 | # Update holdings 62 | # =============== 63 | dh = divt((k,v) for k,v in [(s,0) for s in self.symbol_list]) 64 | 65 | dh['datetime'] = latest_datetime 66 | dh['cash'] = self.current_holdings['cash'] 67 | dh['commission'] = self.current_holdings['commission'] 68 | dh['total'] = self.current_holdings['cash'] 69 | 70 | for s in self.symbol_list: 71 | market_value = self.current_positions[s] * self.bars.get_latest_bar_value(s,'close') 72 | dh[s] = market_value 73 | dh['total'] += market_value 74 | 75 | self.all_holdings.append(dh) 76 | 77 | # ======================= 78 | # FILL/POSITION HANDLING 79 | # ======================= 80 | 81 | def update_positions_from_fill(self,fill): 82 | fill_dir = 0 83 | if fill.direction == 'BUY': 84 | fill_dir = 1 85 | if fill.direction == 'SELL': 86 | fill_dir = -1 87 | self.current_positions[fill.symbol] += fill_dir * fill.quantity 88 | 89 | def update_holdings_from_fill(self,fill): 90 | fill_dir = 0 91 | if fill_direction == 'BUY': 92 | fill_dir = 1 93 | if fill.direction == 'SELL': 94 | fill_dir = -1 95 | fill_cost = self.bars.get_latest_bar_value(fill.symbol,'adj_close') 96 | cost = fill_dir * fill_cost * fill.quantity 97 | self.current_holdings[fill.symbol] += cost 98 | self.current_holdings['commission'] += fill.commission 99 | self.current_holdings['cash'] -= (cost + fill.commission) 100 | self.current_holdings['total'] -= (cost + fill.commission) 101 | 102 | def update_fill(self,event): 103 | if event.type == 'FILL': 104 | 
self.update_positions_from_fill(event) 105 | self.update_holdings_from_fill(event) 106 | 107 | def generate_naive_order(self,signal): 108 | order = None 109 | 110 | symbol = signal.symbol 111 | direction = signal.signal_type 112 | strength = signal.strength 113 | 114 | mkt_quantity = 100 115 | cur_quantity = self.current_positions[symbol] 116 | order_type = 'MKT' 117 | 118 | if direction == 'LONG' and cur_quantity == 0: 119 | order = OrderEvent(symbol,order_type,mkt_quantity,'BUY') 120 | if direction == 'SHORT' and cur_quantity == 0: 121 | order = OrderEvent(symbol,order_type,mkt_quantity,'SELL') 122 | if direction == 'EXIT' and cur_quantity > 0: 123 | order = OrderEvent(symbol,order_type,abs(cur_quantity),'SELL') 124 | if direction == 'EXIT' and cur_quantity < 0: 125 | order = OrderEvent(symbol,order_type,abs(cur_quantity),'BUY') 126 | 127 | return order 128 | 129 | def update_signal(self,event): 130 | if event.type == 'SIGNAL': 131 | order_event = self.generate_naive_order(event) 132 | self.events.put(order_event) 133 | 134 | 135 | # ======================== 136 | # POST-BACKTEST STATISTICS 137 | # ======================== 138 | 139 | def create_equity_curve_dataframe(self): 140 | curve = pd.DataFrame(self.all_holdings) 141 | curve.set_index('datetime',inplace=True) 142 | curve['returns'] = curve['total'].pct_change() 143 | curve['equity_curve'] = (1.0 * curve['returns']).cumprod() 144 | self.equity_curve = curve 145 | 146 | def output_summary_stats(self): 147 | total_return = self.equity_curve['equity_curve'][-1] 148 | returns = self.equity_curve['returns'] 149 | pnl = self.equity_curve['equity_curve'] 150 | 151 | print (total_return,returns,pnl) 152 | 153 | sharpe_ratio = create_sharpe_ratio(returns,periods=252 * 60 * 6.5) 154 | drawdown,max_dd,dd_duration = create_drawdowns(pnl) 155 | self.equity_curve['drawdown'] = drawdown 156 | 157 | stats = [('Total Return','%0.2f%%' % ((total_return - 1.0) * 100.0)),('Sharpe Ratio','%0.2f' % sharpe_ratio),('Max 
Drawdown','%0.2f%%' % (max_dd * 100.0)),('Drawdown Duration','%d' % dd_duration)] 158 | 159 | self.equity_curve.to_csv('equity.csv') 160 | 161 | return stats 162 | 163 | -------------------------------------------------------------------------------- /back_test_system/ib_execution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ib_execution.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | import time 10 | 11 | from ib.ext.Contract import Contract 12 | from ib.ext.Order import Order 13 | from ib.opt import ibConnection, message 14 | 15 | from event import FillEvent, OrderEvent 16 | from execution import ExecutionHandler 17 | 18 | 19 | class IBExecutionHandler(ExecutionHandler): 20 | """ 21 | Handles order execution via the Interactive Brokers 22 | API, for use against accounts when trading live 23 | directly. 24 | """ 25 | 26 | def __init__( 27 | self, events, order_routing="SMART", currency="USD" 28 | ): 29 | """ 30 | Initialises the IBExecutionHandler instance. 31 | 32 | Parameters: 33 | events - The Queue of Event objects. 34 | """ 35 | self.events = events 36 | self.order_routing = order_routing 37 | self.currency = currency 38 | self.fill_dict = {} 39 | 40 | self.tws_conn = self.create_tws_connection() 41 | self.order_id = self.create_initial_order_id() 42 | self.register_handlers() 43 | 44 | def _error_handler(self, msg): 45 | """Handles the capturing of error messages""" 46 | # Currently no error handling. 
47 | print("Server Error: %s" % msg) 48 | 49 | def _reply_handler(self, msg): 50 | """Handles of server replies""" 51 | # Handle open order orderId processing 52 | if msg.typeName == "openOrder" and \ 53 | msg.orderId == self.order_id and \ 54 | not self.fill_dict.has_key(msg.orderId): 55 | self.create_fill_dict_entry(msg) 56 | # Handle Fills 57 | if msg.typeName == "orderStatus" and \ 58 | msg.status == "Filled" and \ 59 | self.fill_dict[msg.orderId]["filled"] == False: 60 | self.create_fill(msg) 61 | print("Server Response: %s, %s\n" % (msg.typeName, msg)) 62 | 63 | def create_tws_connection(self): 64 | """ 65 | Connect to the Trader Workstation (TWS) running on the 66 | usual port of 7496, with a clientId of 100. 67 | The clientId is chosen by us and we will need 68 | separate IDs for both the execution connection and 69 | market data connection, if the latter is used elsewhere. 70 | """ 71 | tws_conn = ibConnection() 72 | tws_conn.connect() 73 | return tws_conn 74 | 75 | def create_initial_order_id(self): 76 | """ 77 | Creates the initial order ID used for Interactive 78 | Brokers to keep track of submitted orders. 79 | """ 80 | # There is scope for more logic here, but we 81 | # will use "1" as the default for now. 82 | return 1 83 | 84 | def register_handlers(self): 85 | """ 86 | Register the error and server reply 87 | message handling functions. 88 | """ 89 | # Assign the error handling function defined above 90 | # to the TWS connection 91 | self.tws_conn.register(self._error_handler, 'Error') 92 | 93 | # Assign all of the server reply messages to the 94 | # reply_handler function defined above 95 | self.tws_conn.registerAll(self._reply_handler) 96 | 97 | def create_contract(self, symbol, sec_type, exch, prim_exch, curr): 98 | """Create a Contract object defining what will 99 | be purchased, at which exchange and in which currency. 
100 | 101 | symbol - The ticker symbol for the contract 102 | sec_type - The security type for the contract ('STK' is 'stock') 103 | exch - The exchange to carry out the contract on 104 | prim_exch - The primary exchange to carry out the contract on 105 | curr - The currency in which to purchase the contract""" 106 | contract = Contract() 107 | contract.m_symbol = symbol 108 | contract.m_secType = sec_type 109 | contract.m_exchange = exch 110 | contract.m_primaryExch = prim_exch 111 | contract.m_currency = curr 112 | return contract 113 | 114 | def create_order(self, order_type, quantity, action): 115 | """Create an Order object (Market/Limit) to go long/short. 116 | 117 | order_type - 'MKT', 'LMT' for Market or Limit orders 118 | quantity - Integral number of assets to order 119 | action - 'BUY' or 'SELL'""" 120 | order = Order() 121 | order.m_orderType = order_type 122 | order.m_totalQuantity = quantity 123 | order.m_action = action 124 | return order 125 | 126 | def create_fill_dict_entry(self, msg): 127 | """ 128 | Creates an entry in the Fill Dictionary that lists 129 | orderIds and provides security information. This is 130 | needed for the event-driven behaviour of the IB 131 | server message behaviour. 132 | """ 133 | self.fill_dict[msg.orderId] = { 134 | "symbol": msg.contract.m_symbol, 135 | "exchange": msg.contract.m_exchange, 136 | "direction": msg.order.m_action, 137 | "filled": False 138 | } 139 | 140 | def create_fill(self, msg): 141 | """ 142 | Handles the creation of the FillEvent that will be 143 | placed onto the events queue subsequent to an order 144 | being filled. 
145 | """ 146 | fd = self.fill_dict[msg.orderId] 147 | 148 | # Prepare the fill data 149 | symbol = fd["symbol"] 150 | exchange = fd["exchange"] 151 | filled = msg.filled 152 | direction = fd["direction"] 153 | fill_cost = msg.avgFillPrice 154 | 155 | # Create a fill event object 156 | fill = FillEvent( 157 | datetime.datetime.utcnow(), symbol, 158 | exchange, filled, direction, fill_cost 159 | ) 160 | 161 | # Make sure that multiple messages don't create 162 | # additional fills. 163 | self.fill_dict[msg.orderId]["filled"] = True 164 | 165 | # Place the fill event onto the event queue 166 | self.events.put(fill_event) 167 | 168 | def execute_order(self, event): 169 | """ 170 | Creates the necessary InteractiveBrokers order object 171 | and submits it to IB via their API. 172 | 173 | The results are then queried in order to generate a 174 | corresponding Fill object, which is placed back on 175 | the event queue. 176 | 177 | Parameters: 178 | event - Contains an Event object with order information. 179 | """ 180 | if event.type == 'ORDER': 181 | # Prepare the parameters for the asset order 182 | asset = event.symbol 183 | asset_type = "STK" 184 | order_type = event.order_type 185 | quantity = event.quantity 186 | direction = event.direction 187 | 188 | # Create the Interactive Brokers contract via the 189 | # passed Order event 190 | ib_contract = self.create_contract( 191 | asset, asset_type, self.order_routing, 192 | self.order_routing, self.currency 193 | ) 194 | 195 | # Create the Interactive Brokers order via the 196 | # passed Order event 197 | ib_order = self.create_order( 198 | order_type, quantity, direction 199 | ) 200 | 201 | # Use the connection to the send the order to IB 202 | self.tws_conn.placeOrder( 203 | self.order_id, ib_contract, ib_order 204 | ) 205 | 206 | # NOTE: This following line is crucial. 207 | # It ensures the order goes through! 
208 | time.sleep(1) 209 | 210 | # Increment the order ID for this session 211 | self.order_id += 1 212 | -------------------------------------------------------------------------------- /back_test_system/data.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/python 3 | # -*- coding: utf-8 -*- 4 | 5 | # data.py 6 | 7 | from __future__ import print_function 8 | 9 | from abc import ABCMeta, abstractmethod 10 | import datetime 11 | import os, os.path 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | from event import MarketEvent 17 | 18 | 19 | class DataHandler(object): 20 | """ 21 | DataHandler is an abstract base class providing an interface for 22 | all subsequent (inherited) data handlers (both live and historic). 23 | 24 | The goal of a (derived) DataHandler object is to output a generated 25 | set of bars (OHLCVI) for each symbol requested. 26 | 27 | This will replicate how a live strategy would function as current 28 | market data would be sent "down the pipe". Thus a historic and live 29 | system will be treated identically by the rest of the backtesting suite. 30 | """ 31 | 32 | __metaclass__ = ABCMeta 33 | 34 | @abstractmethod 35 | def get_latest_bar(self, symbol): 36 | """ 37 | Returns the last bar updated. 38 | """ 39 | raise NotImplementedError("Should implement get_latest_bar()") 40 | 41 | @abstractmethod 42 | def get_latest_bars(self, symbol, N=1): 43 | """ 44 | Returns the last N bars updated. 45 | """ 46 | raise NotImplementedError("Should implement get_latest_bars()") 47 | 48 | @abstractmethod 49 | def get_latest_bar_datetime(self, symbol): 50 | """ 51 | Returns a Python datetime object for the last bar. 52 | """ 53 | raise NotImplementedError("Should implement get_latest_bar_datetime()") 54 | 55 | @abstractmethod 56 | def get_latest_bar_value(self, symbol, val_type): 57 | """ 58 | Returns one of the Open, High, Low, Close, Volume or OI 59 | from the last bar. 
60 | """ 61 | raise NotImplementedError("Should implement get_latest_bar_value()") 62 | 63 | @abstractmethod 64 | def get_latest_bars_values(self, symbol, val_type, N=1): 65 | """ 66 | Returns the last N bar values from the 67 | latest_symbol list, or N-k if less available. 68 | """ 69 | raise NotImplementedError("Should implement get_latest_bars_values()") 70 | 71 | @abstractmethod 72 | def update_bars(self): 73 | """ 74 | Pushes the latest bars to the bars_queue for each symbol 75 | in a tuple OHLCVI format: (datetime, open, high, low, 76 | close, volume, open interest). 77 | """ 78 | raise NotImplementedError("Should implement update_bars()") 79 | 80 | 81 | class HistoricCSVDataHandler(DataHandler): 82 | """ 83 | HistoricCSVDataHandler is designed to read CSV files for 84 | each requested symbol from disk and provide an interface 85 | to obtain the "latest" bar in a manner identical to a live 86 | trading interface. 87 | """ 88 | 89 | def __init__(self, events, csv_dir, symbol_list): 90 | """ 91 | Initialises the historic data handler by requesting 92 | the location of the CSV files and a list of symbols. 93 | 94 | It will be assumed that all files are of the form 95 | 'symbol.csv', where symbol is a string in the list. 96 | 97 | Parameters: 98 | events - The Event Queue. 99 | csv_dir - Absolute directory path to the CSV files. 100 | symbol_list - A list of symbol strings. 101 | """ 102 | self.events = events 103 | self.csv_dir = csv_dir 104 | self.symbol_list = symbol_list 105 | 106 | self.symbol_data = {} 107 | self.latest_symbol_data = {} 108 | self.continue_backtest = True 109 | self.bar_index = 0 110 | 111 | self._open_convert_csv_files() 112 | 113 | def _open_convert_csv_files(self): 114 | """ 115 | Opens the CSV files from the data directory, converting 116 | them into pandas DataFrames within a symbol dictionary. 117 | 118 | For this handler it will be assumed that the data is 119 | taken from Yahoo. Thus its format will be respected. 
120 | """ 121 | comb_index = None 122 | for s in self.symbol_list: 123 | # Load the CSV file with no header information, indexed on date 124 | self.symbol_data[s] = pd.io.parsers.read_csv( 125 | os.path.join(self.csv_dir, '%s.csv' % s), 126 | header=0, index_col=0, parse_dates=True, 127 | names=[ 128 | 'datetime', 'open', 'high', 129 | 'low', 'close', 'volume', 'adj_close' 130 | ] 131 | ).sort() 132 | 133 | # Combine the index to pad forward values 134 | if comb_index is None: 135 | comb_index = self.symbol_data[s].index 136 | else: 137 | comb_index.union(self.symbol_data[s].index) 138 | 139 | # Set the latest symbol_data to None 140 | self.latest_symbol_data[s] = [] 141 | 142 | # Reindex the dataframes 143 | for s in self.symbol_list: 144 | self.symbol_data[s] = self.symbol_data[s].\ 145 | reindex(index=comb_index, method='pad').iterrows() 146 | 147 | 148 | 149 | def _get_new_bar(self, symbol): 150 | """ 151 | Returns the latest bar from the data feed. 152 | """ 153 | for b in self.symbol_data[symbol]: 154 | yield b 155 | 156 | def get_latest_bar(self, symbol): 157 | """ 158 | Returns the last bar from the latest_symbol list. 159 | """ 160 | try: 161 | bars_list = self.latest_symbol_data[symbol] 162 | except KeyError: 163 | print("That symbol is not available in the historical data set.") 164 | raise 165 | else: 166 | return bars_list[-1] 167 | 168 | def get_latest_bars(self, symbol, N=1): 169 | """ 170 | Returns the last N bars from the latest_symbol list, 171 | or N-k if less available. 172 | """ 173 | try: 174 | bars_list = self.latest_symbol_data[symbol] 175 | except KeyError: 176 | print("That symbol is not available in the historical data set.") 177 | raise 178 | else: 179 | return bars_list[-N:] 180 | 181 | def get_latest_bar_datetime(self, symbol): 182 | """ 183 | Returns a Python datetime object for the last bar. 
184 | """ 185 | try: 186 | bars_list = self.latest_symbol_data[symbol] 187 | except KeyError: 188 | print("That symbol is not available in the historical data set.") 189 | raise 190 | else: 191 | return bars_list[-1][0] 192 | 193 | def get_latest_bar_value(self, symbol, val_type): 194 | """ 195 | Returns one of the Open, High, Low, Close, Volume or OI 196 | values from the pandas Bar series object. 197 | """ 198 | try: 199 | bars_list = self.latest_symbol_data[symbol] 200 | except KeyError: 201 | print("That symbol is not available in the historical data set.") 202 | raise 203 | else: 204 | return getattr(bars_list[-1][1], val_type) 205 | 206 | def get_latest_bars_values(self, symbol, val_type, N=1): 207 | """ 208 | Returns the last N bar values from the 209 | latest_symbol list, or N-k if less available. 210 | """ 211 | try: 212 | bars_list = self.get_latest_bars(symbol, N) 213 | except KeyError: 214 | print("That symbol is not available in the historical data set.") 215 | raise 216 | else: 217 | return np.array([getattr(b[1], val_type) for b in bars_list]) 218 | 219 | def update_bars(self): 220 | """ 221 | Pushes the latest bar to the latest_symbol_data structure 222 | for all symbols in the symbol list. 
223 | """ 224 | for s in self.symbol_list: 225 | try: 226 | bar = next(self._get_new_bar(s)) 227 | except StopIteration: 228 | self.continue_backtest = False 229 | else: 230 | if bar is not None: 231 | self.latest_symbol_data[s].append(bar) 232 | self.events.put(MarketEvent()) 233 | 234 | -------------------------------------------------------------------------------- /back_test_system/portfolio.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # portfolio.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | from math import floor 10 | try: 11 | import Queue as queue 12 | except ImportError: 13 | import queue 14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from event import FillEvent, OrderEvent 19 | from performance import create_sharpe_ratio, create_drawdowns 20 | 21 | 22 | class Portfolio(object): 23 | """ 24 | The Portfolio class handles the positions and market 25 | value of all instruments at a resolution of a "bar", 26 | i.e. secondly, minutely, 5-min, 30-min, 60 min or EOD. 27 | 28 | The positions DataFrame stores a time-index of the 29 | quantity of positions held. 30 | 31 | The holdings DataFrame stores the cash and total market 32 | holdings value of each symbol for a particular 33 | time-index, as well as the percentage change in 34 | portfolio total across bars. 35 | """ 36 | 37 | def __init__(self, bars, events, start_date, initial_capital=100000.0): 38 | """ 39 | Initialises the portfolio with bars and an event queue. 40 | Also includes a starting datetime index and initial capital 41 | (USD unless otherwise stated). 42 | 43 | Parameters: 44 | bars - The DataHandler object with current market data. 45 | events - The Event Queue object. 46 | start_date - The start date (bar) of the portfolio. 47 | initial_capital - The starting capital in USD. 
48 | """ 49 | self.bars = bars 50 | self.events = events 51 | self.symbol_list = self.bars.symbol_list 52 | self.start_date = start_date 53 | self.initial_capital = initial_capital 54 | self.all_positions = self.construct_all_positions() 55 | self.current_positions = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 56 | 57 | self.all_holdings = self.construct_all_holdings() 58 | self.current_holdings = self.construct_current_holdings() 59 | 60 | def construct_all_positions(self): 61 | """ 62 | Constructs the positions list using the start_date 63 | to determine when the time index will begin. 64 | """ 65 | d = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 66 | d['datetime'] = self.start_date 67 | return [d] 68 | 69 | def construct_all_holdings(self): 70 | """ 71 | Constructs the holdings list using the start_date 72 | to determine when the time index will begin. 73 | """ 74 | d = dict( (k,v) for k, v in [(s, 0.0) for s in self.symbol_list] ) 75 | d['datetime'] = self.start_date 76 | d['cash'] = self.initial_capital 77 | d['commission'] = 0.0 78 | d['total'] = self.initial_capital 79 | return [d] 80 | 81 | def construct_current_holdings(self): 82 | """ 83 | This constructs the dictionary which will hold the instantaneous 84 | value of the portfolio across all symbols. 85 | """ 86 | d = dict( (k,v) for k, v in [(s, 0.0) for s in self.symbol_list] ) 87 | d['cash'] = self.initial_capital 88 | d['commission'] = 0.0 89 | d['total'] = self.initial_capital 90 | return d 91 | 92 | def update_timeindex(self, event): 93 | """ 94 | Adds a new record to the positions matrix for the current 95 | market data bar. This reflects the PREVIOUS bar, i.e. all 96 | current market data at this stage is known (OHLCV). 97 | 98 | Makes use of a MarketEvent from the events queue. 
99 | """ 100 | latest_datetime = self.bars.get_latest_bar_datetime(self.symbol_list[0]) 101 | 102 | # Update positions 103 | # ================ 104 | dp = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 105 | dp['datetime'] = latest_datetime 106 | 107 | for s in self.symbol_list: 108 | dp[s] = self.current_positions[s] 109 | 110 | # Append the current positions 111 | self.all_positions.append(dp) 112 | 113 | # Update holdings 114 | # =============== 115 | dh = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 116 | dh['datetime'] = latest_datetime 117 | dh['cash'] = self.current_holdings['cash'] 118 | dh['commission'] = self.current_holdings['commission'] 119 | dh['total'] = self.current_holdings['cash'] 120 | 121 | for s in self.symbol_list: 122 | # Approximation to the real value 123 | market_value = self.current_positions[s] * \ 124 | self.bars.get_latest_bar_value(s, "close") 125 | dh[s] = market_value 126 | dh['total'] += market_value 127 | 128 | # Append the current holdings 129 | self.all_holdings.append(dh) 130 | 131 | # ====================== 132 | # FILL/POSITION HANDLING 133 | # ====================== 134 | 135 | def update_positions_from_fill(self, fill): 136 | """ 137 | Takes a Fill object and updates the position matrix to 138 | reflect the new position. 139 | 140 | Parameters: 141 | fill - The Fill object to update the positions with. 142 | """ 143 | # Check whether the fill is a buy or sell 144 | fill_dir = 0 145 | if fill.direction == 'BUY': 146 | fill_dir = 1 147 | if fill.direction == 'SELL': 148 | fill_dir = -1 149 | 150 | # Update positions list with new quantities 151 | self.current_positions[fill.symbol] += fill_dir*fill.quantity 152 | 153 | def update_holdings_from_fill(self, fill): 154 | """ 155 | Takes a Fill object and updates the holdings matrix to 156 | reflect the holdings value. 157 | 158 | Parameters: 159 | fill - The Fill object to update the holdings with. 
160 | """ 161 | # Check whether the fill is a buy or sell 162 | fill_dir = 0 163 | if fill.direction == 'BUY': 164 | fill_dir = 1 165 | if fill.direction == 'SELL': 166 | fill_dir = -1 167 | 168 | # Update holdings list with new quantities 169 | fill_cost = self.bars.get_latest_bar_value( 170 | fill.symbol, "adj_close" 171 | ) 172 | cost = fill_dir * fill_cost * fill.quantity 173 | self.current_holdings[fill.symbol] += cost 174 | self.current_holdings['commission'] += fill.commission 175 | self.current_holdings['cash'] -= (cost + fill.commission) 176 | self.current_holdings['total'] -= (cost + fill.commission) 177 | 178 | def update_fill(self, event): 179 | """ 180 | Updates the portfolio current positions and holdings 181 | from a FillEvent. 182 | """ 183 | if event.type == 'FILL': 184 | self.update_positions_from_fill(event) 185 | self.update_holdings_from_fill(event) 186 | 187 | def generate_naive_order(self, signal): 188 | """ 189 | Simply files an Order object as a constant quantity 190 | sizing of the signal object, without risk management or 191 | position sizing considerations. 192 | 193 | Parameters: 194 | signal - The tuple containing Signal information. 
195 | """ 196 | order = None 197 | 198 | symbol = signal.symbol 199 | direction = signal.signal_type 200 | strength = signal.strength 201 | 202 | mkt_quantity = 100 203 | cur_quantity = self.current_positions[symbol] 204 | order_type = 'MKT' 205 | 206 | if direction == 'LONG' and cur_quantity == 0: 207 | order = OrderEvent(symbol, order_type, mkt_quantity, 'BUY') 208 | if direction == 'SHORT' and cur_quantity == 0: 209 | order = OrderEvent(symbol, order_type, mkt_quantity, 'SELL') 210 | if direction == 'EXIT' and cur_quantity > 0: 211 | order = OrderEvent(symbol, order_type, abs(cur_quantity), 'SELL') 212 | if direction == 'EXIT' and cur_quantity < 0: 213 | order = OrderEvent(symbol, order_type, abs(cur_quantity), 'BUY') 214 | return order 215 | 216 | def update_signal(self, event): 217 | """ 218 | Acts on a SignalEvent to generate new orders 219 | based on the portfolio logic. 220 | """ 221 | if event.type == 'SIGNAL': 222 | order_event = self.generate_naive_order(event) 223 | self.events.put(order_event) 224 | 225 | # ======================== 226 | # POST-BACKTEST STATISTICS 227 | # ======================== 228 | 229 | def create_equity_curve_dataframe(self): 230 | """ 231 | Creates a pandas DataFrame from the all_holdings 232 | list of dictionaries. 233 | """ 234 | curve = pd.DataFrame(self.all_holdings) 235 | curve.set_index('datetime', inplace=True) 236 | curve['returns'] = curve['total'].pct_change() 237 | curve['equity_curve'] = (1.0+curve['returns']).cumprod() 238 | self.equity_curve = curve 239 | 240 | def output_summary_stats(self): 241 | """ 242 | Creates a list of summary statistics for the portfolio. 
243 | """ 244 | total_return = self.equity_curve['equity_curve'][-1] 245 | returns = self.equity_curve['returns'] 246 | pnl = self.equity_curve['equity_curve'] 247 | 248 | print (total_return,returns,pnl) 249 | 250 | 251 | sharpe_ratio = create_sharpe_ratio(returns, periods=252*60*6.5) 252 | drawdown, max_dd, dd_duration = create_drawdowns(pnl) 253 | self.equity_curve['drawdown'] = drawdown 254 | 255 | stats = [("Total Return", "%0.2f%%" % ((total_return - 1.0) * 100.0)), 256 | ("Sharpe Ratio", "%0.2f" % sharpe_ratio), 257 | ("Max Drawdown", "%0.2f%%" % (max_dd * 100.0)), 258 | ("Drawdown Duration", "%d" % dd_duration)] 259 | 260 | self.equity_curve.to_csv('equity.csv') 261 | return stats 262 | -------------------------------------------------------------------------------- /back_test_system/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | true 46 | DEFINITION_ORDER 47 | 48 | 49 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 77 | 78 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 94 | 95 | 96 | 97 | 1482761166539 98 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 131 | 132 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /util/db.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tushare as ts 3 | import MySQLdb as mdb 4 | import datetime 
import time
import numpy as np
import pandas as pd
import sys
import pandas as pd  # NOTE(review): duplicate of the import above; kept to avoid touching file-level imports blindly
import pandas_datareader.data as web




# Fetch the HS300 constituents from tushare and store them locally.
def save_hs300_into_db():
    """Fetch the HS300 index constituents via tushare and insert them into
    the `symbol` table of the local `ticker_master` MySQL database.

    Sector is stored empty and currency fixed to 'RMB'; created/updated
    timestamps are the current UTC time.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    now = datetime.datetime.utcnow()
    hs300 = ts.get_hs300s()
    column_str = """ticker, instrument, name, sector, currency, created_date, last_updated_date"""
    insert_str = ("%s, " * 7)[:-2]
    final_str = "INSERT INTO symbol (%s) VALUES (%s)" % (column_str, insert_str)
    symbols = []

    for i in range(len(hs300)):
        t = hs300.iloc[i]  # .iloc replaces the deprecated .ix (positional row access)
        symbols.append(
            (
                t['code'],
                'stock',
                t['name'],
                '',     # sector is not provided by get_hs300s()
                'RMB',
                now,
                now,
            )
        )
    # `with con:` commits on success / rolls back on error (MySQLdb connection
    # context manager); the stray pre-`with` cursor was redundant and removed.
    with con:
        cur = con.cursor()
        cur.executemany(final_str, symbols)
    print('success insert hs300 into symbol!')

def save_us_into_db(symbols):
    """Insert US tickers into the `symbol` table of `us_ticker_master`.

    Parameters:
    symbols - DataFrame with at least 'Symbol', 'Name' and 'Sector' columns
              (e.g. as returned by get_us_ticker_name_from_csv).
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    now = datetime.datetime.utcnow()
    column_str = """ticker, instrument, name, sector, currency, created_date, last_updated_date"""
    insert_str = ("%s, " * 7)[:-2]
    final_str = "INSERT INTO symbol (%s) VALUES (%s)" % (column_str, insert_str)
    symbols_content = []

    for i in range(len(symbols)):
        t = symbols.iloc[i]  # .iloc replaces the deprecated .ix
        symbols_content.append(
            (
                t['Symbol'],
                'stock',
                t['Name'],
                t['Sector'],
                'USD',
                now,
                now,
            )
        )
    with con:
        cur = con.cursor()
        cur.executemany(final_str, symbols_content)
    print('success insert us_ticker into symbol!')

# Load the tickers previously stored for the HS300 universe.
def get_hs300_tickers():
    """Return [(id, ticker, name), ...] for every row in ticker_master.symbol."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute('SELECT id,ticker,name FROM symbol')
        data = cur.fetchall()
    return [(d[0], d[1], d[2]) for d in data]

# Load the US ticker ids used to iterate, fetch and store price data.
def get_us_tickers():
    """Return [(id, ticker, name), ...] for every row in us_ticker_master.symbol."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute('SELECT id,ticker,name FROM symbol')
        data = cur.fetchall()
    return [(d[0], d[1], d[2]) for d in data]

# Download history for a single A-share ticker via tushare.
def get_ticker_info_by_id(ticker_id, start_date, end_date=None):
    """Download daily history for one A-share ticker via tushare.

    Parameters:
    ticker_id  - tushare stock code.
    start_date - 'YYYY-MM-DD'; if '' the ticker's IPO date (timeToMarket,
                 YYYYMMDD) from ts.get_stock_basics() is used instead.
    end_date   - 'YYYY-MM-DD'; defaults to today, evaluated at call time
                 (previously the default was frozen at import time).
    """
    if end_date is None:
        end_date = str(datetime.date.today())
    # If no start date was passed, fall back to the listing date.
    if start_date == '':
        df = ts.get_stock_basics()
        raw = str(df.loc[ticker_id]['timeToMarket'])  # .loc replaces deprecated .ix (label lookup)
        start_date = raw[0:4] + '-' + raw[4:6] + '-' + raw[6:8]

    print('======= loading:%s to %s , %s ========' % (start_date, end_date, ticker_id))
    ticker_data = ts.get_h_data(ticker_id, start=start_date, end=end_date, retry_count=50, pause=1)
    print('======= loading success =======')
    return ticker_data

# Download US market data from Yahoo via pandas_datareader.
def get_us_ticker_by_id(ticker_id, start_date, end_date=None):
    """Return the Yahoo daily OHLCV DataFrame for `ticker_id` between
    start_date and end_date (end_date defaults to today at call time)."""
    if end_date is None:
        end_date = datetime.date.today()
    # Previously this printed hard-coded 2010/2013 dates that were never used;
    # log the dates actually requested instead.
    print(start_date, end_date, ticker_id, '-----')
    data = web.DataReader(ticker_id, 'yahoo', start_date, end_date)
    return data

# Remove a symbol when its download failed.
def delete_symbol_from_db_by_id(id):
    """Delete the row with the given ticker from us_ticker_master.symbol."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        print("DELETE FROM symbol where ticker='%s'" % id)
        # Parameterized to avoid SQL injection through the ticker string.
        cur.execute("DELETE FROM symbol where ticker=%s", (id,))


# Most recent stored price date for an A-share symbol.
def get_last_date(ticker_id):
    """Return price_date rows for `ticker_id`, newest first (tuple of tuples)."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute("SELECT price_date FROM daily_price WHERE symbol_id=%s ORDER BY price_date DESC", (ticker_id,))
        date = cur.fetchall()
    return date

# Oldest stored price date for a US symbol.
def get_us_oldest_date(ticker_id):
    """Return price_date rows for `ticker_id`, oldest first (tuple of tuples)."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute("SELECT price_date FROM daily_price WHERE symbol_id=%s ORDER BY price_date", (ticker_id,))
        date = cur.fetchall()
    return date

# Most recent stored price date for a US symbol.
def get_us_last_date(ticker_id):
    """Return price_date rows for `ticker_id`, newest first (tuple of tuples)."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        cur.execute("SELECT price_date FROM daily_price WHERE symbol_id=%s ORDER BY price_date DESC", (ticker_id,))
        date = cur.fetchall()
    return date

# Persist A-share daily bars into ticker_master.daily_price.
def save_ticker_into_db(ticker_id, ticker, vendor_id):
    """Bulk-insert a tushare OHLCV DataFrame (indexed by date, lowercase
    columns incl. 'amount') into daily_price for symbol `ticker_id`."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    # Create the time now
    now = datetime.datetime.utcnow()
    # Create the insert strings
    column_str = """data_vendor_id, symbol_id, price_date, created_date,
    last_updated_date, open_price, high_price, low_price,
    close_price, volume, amount"""
    insert_str = ("%s, " * 11)[:-2]
    final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % (column_str, insert_str)
    daily_data = []

    for i in range(len(ticker.index)):
        t_date = ticker.index[i]
        t_data = ticker.loc[t_date]  # .loc replaces deprecated .ix (label lookup by date)
        daily_data.append(
            (vendor_id, ticker_id, t_date, now, now, t_data['open'], t_data['high'],
             t_data['low'], t_data['close'], t_data['volume'], t_data['amount'])
        )

    with con:
        cur = con.cursor()
        cur.executemany(final_str, daily_data)

# Persist US daily bars into us_ticker_master.daily_price.
def save_us_ticker_into_db(ticker_id, ticker, vendor_id):
    """Bulk-insert a Yahoo-style OHLCV DataFrame (capitalized columns incl.
    'Adj Close') into daily_price for symbol `ticker_id`."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    # Create the time now
    now = datetime.datetime.utcnow()
    # Create the insert strings
    column_str = """data_vendor_id, symbol_id, price_date, created_date,
    last_updated_date, open_price, high_price, low_price,
    close_price, volume, adj_close_price"""
    insert_str = ("%s, " * 11)[:-2]
    final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % (column_str, insert_str)
    daily_data = []

    for i in range(len(ticker.index)):
        t_date = ticker.index[i]
        t_data = ticker.loc[t_date]  # .loc replaces deprecated .ix
        daily_data.append(
            (vendor_id, ticker_id, t_date, now, now, t_data['Open'], t_data['High'],
             t_data['Low'], t_data['Close'], t_data['Volume'], t_data['Adj Close'])
        )

    with con:
        cur = con.cursor()
        cur.executemany(final_str, daily_data)
    print('success insert new into db!')




# Fetch stored A-share daily data.
def get_ticker_from_db_by_id(ticker_id):
    """Return all daily bars for `ticker_id` from ticker_master.daily_price,
    newest first, as a DataFrame with columns
    ['index', 'open', 'high', 'low', 'close', 'volume']."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)
    with con:
        cur = con.cursor()
        # Parameterized query instead of %-interpolated SQL.
        cur.execute('SELECT price_date,open_price,high_price,low_price,close_price,volume from daily_price where symbol_id = %s ORDER BY price_date DESC', (ticker_id,))
        daily_data = cur.fetchall()
    daily_data_np = np.array(daily_data)
    daily_data_df = pd.DataFrame(daily_data_np, columns=['index', 'open', 'high', 'low', 'close', 'volume'])

    return daily_data_df


# Fetch stored US daily data.
def get_us_ticker_from_db_by_id(ticker_id, start_date, end_date=None):
    """Return daily bars for `ticker_id` between start_date and end_date
    (inclusive; end_date defaults to today at call time), newest first.
    The 'close' column holds the adjusted close price."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    if end_date is None:
        end_date = datetime.date.today()
    start_date = str(start_date)[:10]
    end_date = str(end_date)[:10]

    with con:
        cur = con.cursor()
        cur.execute('SELECT price_date,open_price,high_price,low_price,adj_close_price,volume FROM daily_price WHERE (price_date BETWEEN %s and %s) AND (symbol_id=%s) ORDER BY price_date DESC', (start_date, end_date, ticker_id))
        daily_data = cur.fetchall()
    daily_data_np = np.array(daily_data)
    daily_data_df = pd.DataFrame(daily_data_np, columns=['index', 'open', 'high', 'low', 'close', 'volume'])
    return daily_data_df

# Read US ticker metadata from a CSV listing.
def get_us_ticker_name_from_csv(filename):
    """Return the ['Symbol', 'Name', 'Sector', 'MarketCap'] columns of the
    given CSV file as a DataFrame."""
    data = pd.read_csv(filename)[['Symbol', 'Name', 'Sector', 'MarketCap']]
    return data


# US tickers whose average daily volume sits in the middle 33% band,
# restricted to an average price between low_price and high_price.
def get_us_middle33_volume(delay_days, low_price, high_price):
    """Scan every stored US ticker, average its volume and close over the last
    `delay_days` calendar days (ending at its most recent stored date), keep
    tickers whose truncated average price lies in [low_price, high_price),
    then return the middle third by volume as a re-indexed DataFrame with
    columns ['id', 'volume', 'price']."""
    tickers = get_us_tickers()
    rows = []
    length = len(tickers)
    print('================ is calculating =================')
    for i in range(length):
        ticker_id = tickers[i][1]

        # Window: `delay_days` calendar days back from the latest stored date.
        end_date = get_us_last_date(ticker_id)[0][0]
        start_date = end_date + datetime.timedelta(days=delay_days * -1)
        ticker_data = get_us_ticker_from_db_by_id(ticker_id, start_date, end_date)
        days_mean_volume = ticker_data['volume'].mean()
        days_mean_daily_price = ticker_data['close'].mean()
        print('========== %s of %s , %s , %s==========' % (i, length, ticker_id, days_mean_volume))
        # Same semantics as the original `int(price) in range(low, high)`,
        # without building a range object per ticker.
        if int(low_price) <= int(days_mean_daily_price) < int(high_price):
            rows.append([ticker_id, days_mean_volume, days_mean_daily_price])

    # Build the frame once instead of DataFrame.append inside the loop
    # (append copies the whole frame each iteration — O(n^2)).
    df = pd.DataFrame(rows, columns=['id', 'volume', 'price'])
    df = df.sort_values(by="volume")  # DataFrame.sort(columns=...) was removed from pandas
    df_len = len(df)
    df = df[int(df_len * 0.33): int(df_len * 0.66)]
    df.index = range(len(df))

    return df


# Mean and standard deviation of recent adjusted closes.
def get_average_days_price_by_id(ticker_id, average_days=7 * 30):
    """Return (mean, std) of the adjusted close over the last `average_days`
    calendar days; (0, 0) when no rows are stored for the window."""
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    end_date = datetime.date.today()
    start_date = end_date + datetime.timedelta(days=int(average_days) * -1)
    start_date = str(start_date)[:10]
    end_date = str(end_date)[:10]
    with con:
        cur = con.cursor()
        cur.execute('SELECT adj_close_price FROM daily_price WHERE (price_date BETWEEN %s and %s) AND (symbol_id=%s) ORDER BY price_date DESC', (start_date, end_date, ticker_id))
        daily_data = cur.fetchall()
    daily_data_np = np.array(daily_data)
    # Forward- then back-fill gaps so mean/std see no missing values.
    daily_data_df = pd.DataFrame(daily_data_np, columns=['close']).fillna(method="pad").fillna(method="bfill")
    if len(daily_data_df) == 0:
        mean = 0
        std = 0
    else:
        # Cast both statistics to float (previously only std was cast,
        # leaving mean computed on the raw DB decimal objects).
        mean = daily_data_df['close'].astype('float').mean()
        std = daily_data_df['close'].astype('float').std()

    return mean, std


# Rolling mean/EWMA/std frame for one ticker.
def get_current_mean_std_df(ticker_id, days_range=200, cal_range=60, end_date=None):
    """Return a DataFrame of adjusted closes with rolling statistics columns
    'ma_60', 'ewma_60' and 'std_60' (names kept for caller compatibility even
    when cal_range != 60), trimmed of the first `cal_range` warm-up rows.

    end_date defaults to today at call time; previously the argument was
    silently overwritten and the trim was hard-coded to 60.
    """
    db_host = 'localhost'
    db_user = 'root'
    db_password = ''
    db_name = 'us_ticker_master'
    con = mdb.connect(host=db_host, user=db_user, passwd=db_password, db=db_name)

    if end_date is None:
        end_date = datetime.date.today()
    start_date = end_date + datetime.timedelta(days=int(days_range + cal_range) * -1)
    start_date = str(start_date)[:10]
    end_date = str(end_date)[:10]

    with con:
        cur = con.cursor()
        cur.execute('SELECT adj_close_price FROM daily_price WHERE (price_date BETWEEN %s and %s) AND (symbol_id=%s) ORDER BY price_date DESC', (start_date, end_date, ticker_id))
        daily_data = cur.fetchall()
    daily_data_np = np.array(daily_data)
    daily_data_df = pd.DataFrame(daily_data_np, columns=['close'])
    # Cast to float for the window computations (DB returns decimals).
    close = daily_data_df['close'].astype('float')
    # NOTE(review): rows arrive newest-first (ORDER BY ... DESC), so these
    # windows run over reverse-chronological data — confirm that is intended.
    # pd.rolling_mean / pd.ewma / pd.rolling_std were removed from pandas;
    # pd.ewma's second positional argument was `com`, hence ewm(com=...).
    daily_data_df['ma_60'] = close.rolling(cal_range).mean()
    daily_data_df['ewma_60'] = close.ewm(com=cal_range).mean()
    daily_data_df['std_60'] = close.rolling(cal_range).std()

    return daily_data_df[cal_range:]