├── README.md ├── ppt ├── 机器学习和量化交易实战 Lecture 01.pdf ├── 机器学习和量化交易实战 Lecture 02.pdf ├── 机器学习和量化交易实战 Lecture 03.pdf ├── 机器学习和量化交易实战 Lecture 04.pdf ├── 机器学习和量化交易实战 Lecture 06.pdf ├── 机器学习和量化交易实战 Lecture 07.pptx ├── 机器学习和量化交易实战 Lecture 09.pptx └── 机器学习和量化交易实战 Lecture 10.pptx └── 代码 ├── lecture_code 03 ├── cadf.py ├── insert_symbols.py ├── price_retrieval.py ├── quandl_data.py ├── quantitative.sql ├── retrieving_data.py └── securities_master.sql ├── lecture_code 04 └── code for lecture 4.ipynb ├── lecture_code 05 ├── BB.py ├── CCI.py ├── FI.py ├── MA.py ├── ROC.py ├── evm.py ├── forecast.py └── grid_search.py └── lecture_code 08 ├── backtest.py ├── event.py ├── mac.py └── portfolio.py /README.md: -------------------------------------------------------------------------------- 1 | ### 七月算法 - 量化交易课程 - 机器学习和量化交易课程 2 | ### youtube 链接:https://www.youtube.com/playlist?list=PLwTxjmW4U1YTKgEh9R8n66EcU7NHADR9y&disable_polymer=true 3 | -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 01.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 01.pdf -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 02.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 02.pdf -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 03.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 03.pdf -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 04.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 04.pdf -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 06.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 06.pdf -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 07.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 07.pptx -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 09.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 09.pptx -------------------------------------------------------------------------------- /ppt/机器学习和量化交易实战 Lecture 10.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/runrunbear/ml_lianghuajiaoyi_2017/71d78ecc8e0949e97798f0d6ba8e9b87eea1277f/ppt/机器学习和量化交易实战 Lecture 10.pptx -------------------------------------------------------------------------------- /代码/lecture_code 03/cadf.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import matplotlib.dates as mdates 5 | import pandas as pd 6 | import pandas.io.data as web 7 | import pprint 8 | import statsmodels.tsa.stattools as ts 9 | 10 | from pandas.stats.api import ols 11 | 12 | 13 | def plot_price_series(df, ts1, ts2): 14 | months = mdates.MonthLocator() # every month 15 | fig, ax = plt.subplots() 16 | ax.plot(df.index, df[ts1], label=ts1) 17 | ax.plot(df.index, df[ts2], label=ts2) 18 | ax.xaxis.set_major_locator(months) 19 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) 20 | ax.set_xlim(datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1)) 21 | ax.grid(True) 22 | fig.autofmt_xdate() 23 | 24 | plt.xlabel('Month/Year') 25 | plt.ylabel('Price ($)') 26 | plt.title('%s and %s Daily Prices' % (ts1, ts2)) 27 | plt.legend() 28 | plt.show() 29 | 30 | def plot_scatter_series(df, ts1, ts2): 31 | plt.xlabel('%s Price ($)' % ts1) 32 | plt.ylabel('%s Price ($)' % ts2) 33 | plt.title('%s and %s Price Scatterplot' % (ts1, ts2)) 34 | plt.scatter(df[ts1], df[ts2]) 35 | plt.show() 36 | 37 | def plot_residuals(df): 38 | months = mdates.MonthLocator() # every month 39 | fig, ax = plt.subplots() 40 | ax.plot(df.index, df["res"], label="Residuals") 41 | ax.xaxis.set_major_locator(months) 42 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) 43 | ax.set_xlim(datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1)) 44 | ax.grid(True) 45 | fig.autofmt_xdate() 46 | 47 | plt.xlabel('Month/Year') 48 | plt.ylabel('Price ($)') 49 | plt.title('Residual Plot') 50 | plt.legend() 51 | 52 | plt.plot(df["res"]) 53 | plt.show() 54 | 55 | if __name__ == "__main__": 56 | start = datetime.datetime(2012, 1, 1) 57 | end = datetime.datetime(2013, 1, 1) 58 | 59 | arex = web.DataReader("AREX", "yahoo", start, end) 60 | wll = web.DataReader("WLL", "yahoo", start, end) 61 | 62 | df = pd.DataFrame(index=arex.index) 63 | df["AREX"] = arex["Adj Close"] 64 | df["WLL"] = wll["Adj Close"] 65 | 66 | # Plot the two time series 67 | plot_price_series(df, "AREX", "WLL") 68 | 69 | # Display a scatter plot of the two time series 70 | plot_scatter_series(df, "AREX", "WLL") 71 | 72 | # Calculate optimal hedge ratio "beta" 73 | res = ols(y=df['WLL'], x=df["AREX"]) 74 | beta_hr = res.beta.x 75 | 76 | # Calculate the residuals of the linear combination 77 | df["res"] = df["WLL"] - beta_hr*df["AREX"] 78 | 79 | # Plot the residuals 80 | plot_residuals(df) 81 | 82 | # Calculate and output the CADF test on the residuals 83 | cadf = ts.adfuller(df["res"]) 84 | pprint.pprint(cadf) -------------------------------------------------------------------------------- /代码/lecture_code 03/insert_symbols.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # insert_symbols.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | from math import ceil 10 | 11 | import bs4 12 | import MySQLdb as mdb 13 | import requests 14 | 15 | 16 | def obtain_parse_wiki_snp500(): 17 | """ 18 | Download and parse the Wikipedia list of S&P500 19 | constituents using requests and BeautifulSoup. 20 | 21 | Returns a list of tuples for to add to MySQL. 22 | """ 23 | # Stores the current time, for the created_at record 24 | now = datetime.datetime.utcnow() 25 | 26 | # Use requests and BeautifulSoup to download the 27 | # list of S&P500 companies and obtain the symbol table 28 | response = requests.get( 29 | "http://en.wikipedia.org/wiki/List_of_S%26P_500_companies" 30 | ) 31 | soup = bs4.BeautifulSoup(response.text) 32 | 33 | # This selects the first table, using CSS Selector syntax 34 | # and then ignores the header row ([1:]) 35 | symbolslist = soup.select('table')[0].select('tr')[1:] 36 | 37 | # Obtain the symbol information for each 38 | # row in the S&P500 constituent table 39 | symbols = [] 40 | for i, symbol in enumerate(symbolslist): 41 | tds = symbol.select('td') 42 | symbols.append( 43 | ( 44 | tds[0].select('a')[0].text, # Ticker 45 | 'stock', 46 | tds[1].select('a')[0].text, # Name 47 | tds[3].text, # Sector 48 | 'USD', now, now 49 | ) 50 | ) 51 | return symbols 52 | 53 | 54 | def insert_snp500_symbols(symbols): 55 | """ 56 | Insert the S&P500 symbols into the MySQL database. 57 | """ 58 | # Connect to the MySQL instance 59 | db_host = 'localhost' 60 | db_user = 'sec_user' 61 | db_pass = 'password' 62 | db_name = 'securities_master' 63 | con = mdb.connect( 64 | host=db_host, user=db_user, passwd=db_pass, db=db_name 65 | ) 66 | 67 | # Create the insert strings 68 | column_str = """ticker, instrument, name, sector, 69 | currency, created_date, last_updated_date 70 | """ 71 | insert_str = ("%s, " * 7)[:-2] 72 | final_str = "INSERT INTO symbol (%s) VALUES (%s)" % \ 73 | (column_str, insert_str) 74 | 75 | # Using the MySQL connection, carry out 76 | # an INSERT INTO for every symbol 77 | with con: 78 | cur = con.cursor() 79 | cur.executemany(final_str, symbols) 80 | 81 | 82 | if __name__ == "__main__": 83 | symbols = obtain_parse_wiki_snp500() 84 | insert_snp500_symbols(symbols) 85 | print("%s symbols were successfully added." % len(symbols)) 86 | -------------------------------------------------------------------------------- /代码/lecture_code 03/price_retrieval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # price_retrieval.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | import warnings 10 | 11 | import MySQLdb as mdb 12 | import requests 13 | 14 | 15 | # Obtain a database connection to the MySQL instance 16 | db_host = 'localhost' 17 | db_user = 'sec_user' 18 | db_pass = 'password' 19 | db_name = 'securities_master' 20 | con = mdb.connect(db_host, db_user, db_pass, db_name) 21 | 22 | 23 | def obtain_list_of_db_tickers(): 24 | """ 25 | Obtains a list of the ticker symbols in the database. 26 | """ 27 | with con: 28 | cur = con.cursor() 29 | cur.execute("SELECT id, ticker FROM symbol") 30 | data = cur.fetchall() 31 | return [(d[0], d[1]) for d in data] 32 | 33 | 34 | def get_daily_historic_data_yahoo( 35 | ticker, start_date=(2000,1,1), 36 | end_date=datetime.date.today().timetuple()[0:3] 37 | ): 38 | """ 39 | Obtains data from Yahoo Finance returns and a list of tuples. 40 | 41 | ticker: Yahoo Finance ticker symbol, e.g. "GOOG" for Google, Inc. 42 | start_date: Start date in (YYYY, M, D) format 43 | end_date: End date in (YYYY, M, D) format 44 | """ 45 | # Construct the Yahoo URL with the correct integer query parameters 46 | # for start and end dates. Note that some parameters are zero-based! 47 | ticker_tup = ( 48 | ticker, start_date[1]-1, start_date[2], 49 | start_date[0], end_date[1]-1, end_date[2], 50 | end_date[0] 51 | ) 52 | yahoo_url = "http://ichart.finance.yahoo.com/table.csv" 53 | yahoo_url += "?s=%s&a=%s&b=%s&c=%s&d=%s&e=%s&f=%s" 54 | yahoo_url = yahoo_url % ticker_tup 55 | 56 | # Try connecting to Yahoo Finance and obtaining the data 57 | # On failure, print an error message. 58 | try: 59 | yf_data = requests.get(yahoo_url).text.split("\n")[1:-1] 60 | prices = [] 61 | for y in yf_data: 62 | p = y.strip().split(',') 63 | prices.append( 64 | (datetime.datetime.strptime(p[0], '%Y-%m-%d'), 65 | p[1], p[2], p[3], p[4], p[5], p[6]) 66 | ) 67 | except Exception as e: 68 | print("Could not download Yahoo data: %s" % e) 69 | return prices 70 | 71 | 72 | def insert_daily_data_into_db( 73 | data_vendor_id, symbol_id, daily_data 74 | ): 75 | """ 76 | Takes a list of tuples of daily data and adds it to the 77 | MySQL database. Appends the vendor ID and symbol ID to the data. 78 | 79 | daily_data: List of tuples of the OHLC data (with 80 | adj_close and volume) 81 | """ 82 | # Create the time now 83 | now = datetime.datetime.utcnow() 84 | 85 | # Amend the data to include the vendor ID and symbol ID 86 | daily_data = [ 87 | (data_vendor_id, symbol_id, d[0], now, now, 88 | d[1], d[2], d[3], d[4], d[5], d[6]) 89 | for d in daily_data 90 | ] 91 | 92 | # Create the insert strings 93 | column_str = """data_vendor_id, symbol_id, price_date, created_date, 94 | last_updated_date, open_price, high_price, low_price, 95 | close_price, volume, adj_close_price""" 96 | insert_str = ("%s, " * 11)[:-2] 97 | final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % \ 98 | (column_str, insert_str) 99 | 100 | # Using the MySQL connection, carry out an INSERT INTO for every symbol 101 | with con: 102 | cur = con.cursor() 103 | cur.executemany(final_str, daily_data) 104 | 105 | 106 | if __name__ == "__main__": 107 | # This ignores the warnings regarding Data Truncation 108 | # from the Yahoo precision to Decimal(19,4) datatypes 109 | warnings.filterwarnings('ignore') 110 | 111 | # Loop over the tickers and insert the daily historical 112 | # data into the database 113 | tickers = obtain_list_of_db_tickers() 114 | lentickers = len(tickers) 115 | for i, t in enumerate(tickers): 116 | print( 117 | "Adding data for %s: %s out of %s" % 118 | (t[1], i+1, lentickers) 119 | ) 120 | yf_data = get_daily_historic_data_yahoo(t[1]) 121 | insert_daily_data_into_db('1', t[0], yf_data) 122 | print("Successfully added Yahoo Finance pricing data to DB.") 123 | -------------------------------------------------------------------------------- /代码/lecture_code 03/quandl_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # quandl_data.py 5 | 6 | from __future__ import print_function 7 | 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | import requests 11 | 12 | 13 | def construct_futures_symbols( 14 | symbol, start_year=2010, end_year=2014 15 | ): 16 | """ 17 | Constructs a list of futures contract codes 18 | for a particular symbol and timeframe. 19 | """ 20 | futures = [] 21 | # March, June, September and 22 | # December delivery codes 23 | months = 'HMUZ' 24 | for y in range(start_year, end_year+1): 25 | for m in months: 26 | futures.append("%s%s%s" % (symbol, m, y)) 27 | return futures 28 | 29 | 30 | def download_contract_from_quandl(contract, dl_dir): 31 | """ 32 | Download an individual futures contract from Quandl and then 33 | store it to disk in the 'dl_dir' directory. An auth_token is 34 | required, which is obtained from the Quandl upon sign-up. 35 | """ 36 | # Construct the API call from the contract and auth_token 37 | api_call = "http://www.quandl.com/api/v1/datasets/" 38 | api_call += "OFDP/FUTURE_%s.csv" % contract 39 | # If you wish to add an auth token for more downloads, simply 40 | # comment the following line and replace MY_AUTH_TOKEN with 41 | # your auth token in the line below 42 | params = "?sort_order=asc" 43 | #params = "?auth_token=MY_AUTH_TOKEN&sort_order=asc" 44 | full_url = "%s%s" % (api_call, params) 45 | 46 | # Download the data from Quandl 47 | data = requests.get(full_url).text 48 | 49 | # Store the data to disk 50 | fc = open('%s/%s.csv' % (dl_dir, contract), 'w') 51 | fc.write(data) 52 | fc.close() 53 | 54 | 55 | def download_historical_contracts( 56 | symbol, dl_dir, start_year=2010, end_year=2014 57 | ): 58 | """ 59 | Downloads all futures contracts for a specified symbol 60 | between a start_year and an end_year. 61 | """ 62 | contracts = construct_futures_symbols( 63 | symbol, start_year, end_year 64 | ) 65 | for c in contracts: 66 | print("Downloading contract: %s" % c) 67 | download_contract_from_quandl(c, dl_dir) 68 | 69 | 70 | if __name__ == "__main__": 71 | symbol = 'ES' 72 | 73 | # Make sure you've created this 74 | # relative directory beforehand 75 | dl_dir = 'quandl/futures/ES' 76 | 77 | # Create the start and end years 78 | start_year = 2010 79 | end_year = 2014 80 | 81 | # Download the contracts into the directory 82 | download_historical_contracts( 83 | symbol, dl_dir, start_year, end_year 84 | ) 85 | 86 | # Open up a single contract via read_csv 87 | # and plot the settle price 88 | es = pd.io.parsers.read_csv( 89 | "%s/ESH2010.csv" % dl_dir, index_col="Date" 90 | ) 91 | es["Settle"].plot() 92 | plt.show() 93 | -------------------------------------------------------------------------------- /代码/lecture_code 03/quantitative.sql: -------------------------------------------------------------------------------- 1 | -- Errors encountered generating script 2 | -- Select items in the error list to the left 3 | 4 | CREATE TABLE `exchange` ( 5 | `id` int NOT NULL AUTO_INCREMENT, 6 | `abbrev` varchar(32) NOT NULL, 7 | `name` varchar(255) NOT NULL, 8 | `city` varchar(255) NULL, 9 | `country` varchar(255) NULL, 10 | `currency` varchar(64) NULL, 11 | `timezone_offset` time NULL, 12 | `created_date` datetime NOT NULL, 13 | `last_updated_date` datetime NOT NULL, 14 | PRIMARY KEY (`id`) 15 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 16 | 17 | CREATE TABLE `data_vendor` ( 18 | `id` int NOT NULL AUTO_INCREMENT, 19 | `name` varchar(64) NOT NULL, 20 | `website_url` varchar(255) NULL, 21 | `support_email` varchar(255) NULL, 22 | `created_date` datetime NOT NULL, 23 | `last_updated_date` datetime NOT NULL, 24 | PRIMARY KEY (`id`) 25 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 26 | 27 | CREATE TABLE `symbol` ( 28 | `id` int NOT NULL AUTO_INCREMENT, 29 | `exchange_id` int NULL, 30 | `ticker` varchar(32) NOT NULL, 31 | `instrument` varchar(64) NOT NULL, 32 | `name` varchar(255) NULL, 33 | `sector` varchar(255) NULL, 34 | `currency` varchar(32) NULL, 35 | `created_date` datetime NOT NULL, 36 | `last_updated_date` datetime NOT NULL, 37 | PRIMARY KEY (`id`), 38 | KEY `index_exchange_id` (`exchange_id`) 39 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 40 | 41 | CREATE TABLE `daily_price` ( 42 | `id` int NOT NULL AUTO_INCREMENT, 43 | `data_vendor_id` int NOT NULL, 44 | `symbol_id` int NOT NULL, 45 | `price_date` datetime NOT NULL, 46 | `created_date` datetime NOT NULL, 47 | `last_updated_date` datetime NOT NULL, 48 | `open_price` decimal(19,4) NULL, 49 | `high_price` decimal(19,4) NULL, 50 | `low_price` decimal(19,4) NULL, 51 | `close_price` decimal(19,4) NULL, 52 | `adj_close_price` decimal(19,4) NULL, 53 | `volume` bigint NULL, 54 | PRIMARY KEY (`id`), 55 | KEY `index_data_vendor_id` (`data_vendor_id`), 56 | KEY `index_symbol_id` (`symbol_id`) 57 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 58 | -------------------------------------------------------------------------------- /代码/lecture_code 03/retrieving_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # retrieving_data.py 5 | 6 | from __future__ import print_function 7 | 8 | import pandas as pd 9 | import MySQLdb as mdb 10 | 11 | 12 | if __name__ == "__main__": 13 | # Connect to the MySQL instance 14 | db_host = 'localhost' 15 | db_user = 'sec_user' 16 | db_pass = 'password' 17 | db_name = 'securities_master' 18 | con = mdb.connect(db_host, db_user, db_pass, db_name) 19 | 20 | # Select all of the historic Google adjusted close data 21 | sql = """SELECT dp.price_date, dp.adj_close_price 22 | FROM symbol AS sym 23 | INNER JOIN daily_price AS dp 24 | ON dp.symbol_id = sym.id 25 | WHERE sym.ticker = 'GOOG' 26 | ORDER BY dp.price_date ASC;""" 27 | 28 | # Create a pandas dataframe from the SQL query 29 | goog = pd.read_sql_query(sql, con=con, index_col='price_date') 30 | 31 | # Output the dataframe tail 32 | print(goog.tail()) 33 | -------------------------------------------------------------------------------- /代码/lecture_code 03/securities_master.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE `exchange` ( 2 | `id` int NOT NULL AUTO_INCREMENT, 3 | `abbrev` varchar(32) NOT NULL, 4 | `name` varchar(255) NOT NULL, 5 | `city` varchar(255) NULL, 6 | `country` varchar(255) NULL, 7 | `currency` varchar(64) NULL, 8 | `timezone_offset` time NULL, 9 | `created_date` datetime NOT NULL, 10 | `last_updated_date` datetime NOT NULL, 11 | PRIMARY KEY (`id`) 12 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 13 | 14 | CREATE TABLE `data_vendor` ( 15 | `id` int NOT NULL AUTO_INCREMENT, 16 | `name` varchar(64) NOT NULL, 17 | `website_url` varchar(255) NULL, 18 | `support_email` varchar(255) NULL, 19 | `created_date` datetime NOT NULL, 20 | `last_updated_date` datetime NOT NULL, 21 | PRIMARY KEY (`id`) 22 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 23 | 24 | CREATE TABLE `symbol` ( 25 | `id` int NOT NULL AUTO_INCREMENT, 26 | `exchange_id` int NULL, 27 | `ticker` varchar(32) NOT NULL, 28 | `instrument` varchar(64) NOT NULL, 29 | `name` varchar(255) NULL, 30 | `sector` varchar(255) NULL, 31 | `currency` varchar(32) NULL, 32 | `created_date` datetime NOT NULL, 33 | `last_updated_date` datetime NOT NULL, 34 | PRIMARY KEY (`id`), 35 | KEY `index_exchange_id` (`exchange_id`) 36 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; 37 | 38 | CREATE TABLE `daily_price` ( 39 | `id` int NOT NULL AUTO_INCREMENT, 40 | `data_vendor_id` int NOT NULL, 41 | `symbol_id` int NOT NULL, 42 | `price_date` datetime NOT NULL, 43 | `created_date` datetime NOT NULL, 44 | `last_updated_date` datetime NOT NULL, 45 | `open_price` decimal(19,4) NULL, 46 | `high_price` decimal(19,4) NULL, 47 | `low_price` decimal(19,4) NULL, 48 | `close_price` decimal(19,4) NULL, 49 | `adj_close_price` decimal(19,4) NULL, 50 | `volume` bigint NULL, 51 | PRIMARY KEY (`id`), 52 | KEY `index_data_vendor_id` (`data_vendor_id`), 53 | KEY `index_symbol_id` (`symbol_id`) 54 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; -------------------------------------------------------------------------------- /代码/lecture_code 05/BB.py: -------------------------------------------------------------------------------- 1 | ################ Bollinger Bands ############################# 2 | 3 | # Load the necessary packages and modules 4 | import pandas as pd 5 | import pandas.io.data as web 6 | 7 | # Compute the Bollinger Bands 8 | def BBANDS(data, ndays): 9 | 10 | MA = pd.Series(pd.rolling_mean(data['Close'], ndays)) 11 | SD = pd.Series(pd.rolling_std(data['Close'], ndays)) 12 | 13 | b1 = MA + (2 * SD) 14 | B1 = pd.Series(b1, name = 'Upper BollingerBand') 15 | data = data.join(B1) 16 | 17 | b2 = MA - (2 * SD) 18 | B2 = pd.Series(b2, name = 'Lower BollingerBand') 19 | data = data.join(B2) 20 | 21 | return data 22 | 23 | # Retrieve the Nifty data from Yahoo finance: 24 | data = web.DataReader('^NSEI',data_source='yahoo',start='1/1/2010', end='1/1/2016') 25 | data = pd.DataFrame(data) 26 | 27 | # Compute the Bollinger Bands for NIFTY using the 50-day Moving average 28 | n = 50 29 | NIFTY_BBANDS = BBANDS(data, n) 30 | print(NIFTY_BBANDS) -------------------------------------------------------------------------------- /代码/lecture_code 05/CCI.py: -------------------------------------------------------------------------------- 1 | # Load the necessary packages and modules 2 | import pandas as pd 3 | import pandas.io.data as web 4 | import matplotlib.pyplot as plt 5 | # Commodity Channel Index 6 | def CCI(data, ndays): 7 | TP = (data['High'] + data['Low'] + data['Close']) / 3 8 | CCI = pd.Series((TP - pd.rolling_mean(TP, ndays)) / (0.015 * pd.rolling_std(TP, ndays)), 9 | name = 'CCI') 10 | data = data.join(CCI) 11 | return data 12 |   13 | # Retrieve the Nifty data from Yahoo finance: 14 | data = web.DataReader('^NSEI',data_source='yahoo',start='1/1/2014', end='1/1/2016') 15 | data = pd.DataFrame(data) 16 |   17 | # Compute the Commodity Channel Index(CCI) for NIFTY based on the 20-day Moving average 18 | n = 20 19 | NIFTY_CCI = CCI(data, n) 20 | CCI = NIFTY_CCI['CCI'] 21 |   22 | # Plotting the Price Series chart and the Commodity Channel index below 23 | fig = plt.figure(figsize=(7,5)) 24 | ax = fig.add_subplot(2, 1, 1) 25 | ax.set_xticklabels([]) 26 | plt.plot(data['Close'],lw=1) 27 | plt.title('NSE Price Chart') 28 | plt.ylabel('Close Price') 29 | plt.grid(True) 30 | bx = fig.add_subplot(2, 1, 2) 31 | plt.plot(CCI,'k',lw=0.75,linestyle='-',label='CCI') 32 | plt.legend(loc=2,prop={'size':9.5}) 33 | plt.ylabel('CCI values') 34 | plt.grid(True) 35 | plt.setp(plt.gca().get_xticklabels(), rotation=30) -------------------------------------------------------------------------------- /代码/lecture_code 05/FI.py: -------------------------------------------------------------------------------- 1 | ################# Force Index ######################################################## 2 | 3 | # Load the necessary packages and modules 4 | import pandas as pd 5 | import pandas.io.data as web 6 | 7 | # Force Index 8 | def ForceIndex(data, ndays): 9 | FI = pd.Series(data['Close'].diff(ndays) * data['Volume'], name = 'ForceIndex') 10 | data = data.join(FI) 11 | return data 12 | 13 | 14 | # Retrieve the Apple data from Yahoo finance: 15 | data = web.DataReader('AAPL',data_source='yahoo',start='1/1/2010', end='1/1/2016') 16 | data = pd.DataFrame(data) 17 | 18 | # Compute the Force Index for Apple 19 | n = 1 20 | AAPL_ForceIndex = ForceIndex(data,n) 21 | print(AAPL_ForceIndex) -------------------------------------------------------------------------------- /代码/lecture_code 05/MA.py: -------------------------------------------------------------------------------- 1 | # Moving Averages Code 2 | 3 | # Load the necessary packages and modules 4 | import pandas as pd 5 | import pandas.io.data as web 6 | import matplotlib.pyplot as plt 7 | 8 | # Simple Moving Average 9 | def SMA(data, ndays): 10 | SMA = pd.Series(pd.rolling_mean(data['Close'], ndays), name = 'SMA') 11 | data = data.join(SMA) 12 | return data 13 | 14 | # Exponentially-weighted Moving Average 15 | def EWMA(data, ndays): 16 | EMA = pd.Series(pd.ewma(data['Close'], span = ndays, min_periods = ndays - 1), 17 | name = 'EWMA_' + str(ndays)) 18 | data = data.join(EMA) 19 | return data 20 | 21 | # Retrieve the Nifty data from Yahoo finance: 22 | data = web.DataReader('^NSEI',data_source='yahoo',start='1/1/2013', end='1/1/2016') 23 | data = pd.DataFrame(data) 24 | close = data['Close'] 25 | 26 | # Compute the 50-day SMA for NIFTY 27 | n = 50 28 | SMA_NIFTY = SMA(data,n) 29 | SMA_NIFTY = SMA_NIFTY.dropna() 30 | SMA = SMA_NIFTY['SMA'] 31 | 32 | # Compute the 200-day EWMA for NIFTY 33 | ew = 200 34 | EWMA_NIFTY = EWMA(data,ew) 35 | EWMA_NIFTY = EWMA_NIFTY.dropna() 36 | EWMA = EWMA_NIFTY['EWMA_200'] 37 | 38 | # Plotting the NIFTY Price Series chart and Moving Averages below 39 | plt.figure(figsize=(9,5)) 40 | plt.plot(data['Close'],lw=1, label='NSE Prices') 41 | plt.plot(SMA,'g',lw=1, label='50-day SMA (green)') 42 | plt.plot(EWMA,'r', lw=1, label='200-day EWMA (red)') 43 | plt.legend(loc=2,prop={'size':11}) 44 | plt.grid(True) 45 | plt.setp(plt.gca().get_xticklabels(), rotation=30) -------------------------------------------------------------------------------- /代码/lecture_code 05/ROC.py: -------------------------------------------------------------------------------- 1 | # Rate of Change code 2 | 3 | # Load the necessary packages and modules 4 | import pandas as pd 5 | import pandas.io.data as web 6 | import matplotlib.pyplot as plt 7 | 8 | # Rate of Change (ROC) 9 | def ROC(data,n): 10 | N = data['Close'].diff(n) 11 | D = data['Close'].shift(n) 12 | ROC = pd.Series(N/D,name='Rate of Change') 13 | data = data.join(ROC) 14 | return data 15 | 16 | # Retrieve the NIFTY data from Yahoo finance: 17 | data = web.DataReader('^NSEI',data_source='yahoo',start='6/1/2015',end='1/1/2016') 18 | data = pd.DataFrame(data) 19 | 20 | # Compute the 5-period Rate of Change for NIFTY 21 | n = 5 22 | NIFTY_ROC = ROC(data,n) 23 | ROC = NIFTY_ROC['Rate of Change'] 24 | 25 | # Plotting the Price Series chart and the Ease Of Movement below 26 | fig = plt.figure(figsize=(7,5)) 27 | ax = fig.add_subplot(2, 1, 1) 28 | ax.set_xticklabels([]) 29 | plt.plot(data['Close'],lw=1) 30 | plt.title('NSE Price Chart') 31 | plt.ylabel('Close Price') 32 | plt.grid(True) 33 | bx = fig.add_subplot(2, 1, 2) 34 | plt.plot(ROC,'k',lw=0.75,linestyle='-',label='ROC') 35 | plt.legend(loc=2,prop={'size':9}) 36 | plt.ylabel('ROC values') 37 | plt.grid(True) 38 | plt.setp(plt.gca().get_xticklabels(), rotation=30) -------------------------------------------------------------------------------- /代码/lecture_code 05/evm.py: -------------------------------------------------------------------------------- 1 | # Load the necessary packages and modules 2 | import pandas as pd 3 | import pandas.io.data as web 4 | import matplotlib.pyplot as plt 5 | 6 | # Ease of Movement 7 | def EVM(data, ndays): 8 | dm = ((data['High'] + data['Low'])/2) - ((data['High'].shift(1) + data['Low'].shift(1))/2) 9 | br = (data['Volume'] / 100000000) / ((data['High'] - data['Low'])) 10 | EVM = dm / br 11 | EVM_MA = pd.Series(pd.rolling_mean(EVM, ndays), name = 'EVM') 12 | data = data.join(EVM_MA) 13 | return data 14 | 15 | # Retrieve the AAPL data from Yahoo finance: 16 | data = web.DataReader('AAPL',data_source='yahoo',start='1/1/2015', end='1/1/2016') 17 | data = pd.DataFrame(data) 18 | 19 | # Compute the 14-day Ease of Movement for AAPL 20 | n = 14 21 | AAPL_EVM = EVM(data, n) 22 | EVM = AAPL_EVM['EVM'] 23 | 24 | # Plotting the Price Series chart and the Ease Of Movement below 25 | fig = plt.figure(figsize=(7,5)) 26 | ax = fig.add_subplot(2, 1, 1) 27 | ax.set_xticklabels([]) 28 | plt.plot(data['Close'],lw=1) 29 | plt.title('AAPL Price Chart') 30 | plt.ylabel('Close Price') 31 | plt.grid(True) 32 | bx = fig.add_subplot(2, 1, 2) 33 | plt.plot(EVM,'k',lw=0.75,linestyle='-',label='EVM(14)') 34 | plt.legend(loc=2,prop={'size':9}) 35 | plt.ylabel('EVM values') 36 | plt.grid(True) 37 | plt.setp(plt.gca().get_xticklabels(), rotation=30) -------------------------------------------------------------------------------- /代码/lecture_code 05/forecast.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # forecast.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | import numpy as np 10 | import pandas as pd 11 | import sklearn 12 | 13 | from pandas.io.data import DataReader 14 | from sklearn.ensemble import RandomForestClassifier 15 | from sklearn.linear_model import LogisticRegression 16 | from sklearn.lda import LDA 17 | from sklearn.metrics import confusion_matrix 18 | from sklearn.qda import QDA 19 | from sklearn.svm import LinearSVC, SVC 20 | 21 | 22 | def create_lagged_series(symbol, start_date, end_date, lags=5): 23 | """ 24 | This creates a pandas DataFrame that stores the 25 | percentage returns of the adjusted closing value of 26 | a stock obtained from Yahoo Finance, along with a 27 | number of lagged returns from the prior trading days 28 | (lags defaults to 5 days). Trading volume, as well as 29 | the Direction from the previous day, are also included. 30 | """ 31 | 32 | # Obtain stock information from Yahoo Finance 33 | ts = DataReader( 34 | symbol, "yahoo", 35 | start_date-datetime.timedelta(days=365), 36 | end_date 37 | ) 38 | 39 | # Create the new lagged DataFrame 40 | tslag = pd.DataFrame(index=ts.index) 41 | tslag["Today"] = ts["Adj Close"] 42 | tslag["Volume"] = ts["Volume"] 43 | 44 | # Create the shifted lag series of prior trading period close values 45 | for i in range(0, lags): 46 | tslag["Lag%s" % str(i+1)] = ts["Adj Close"].shift(i+1) 47 | 48 | # Create the returns DataFrame 49 | tsret = pd.DataFrame(index=tslag.index) 50 | tsret["Volume"] = tslag["Volume"] 51 | tsret["Today"] = tslag["Today"].pct_change()*100.0 52 | 53 | # If any of the values of percentage returns equal zero, set them to 54 | # a small number (stops issues with QDA model in scikit-learn) 55 | for i,x in enumerate(tsret["Today"]): 56 | if (abs(x) < 0.0001): 57 | tsret["Today"][i] = 0.0001 58 | 59 | # Create the lagged percentage returns columns 60 | for i in range(0, lags): 61 | tsret["Lag%s" % str(i+1)] = \ 62 | tslag["Lag%s" % str(i+1)].pct_change()*100.0 63 | 64 | # Create the "Direction" column (+1 or -1) indicating an up/down day 65 | tsret["Direction"] = np.sign(tsret["Today"]) 66 | tsret = tsret[tsret.index >= start_date] 67 | 68 | return tsret 69 | 70 | 71 | if __name__ == "__main__": 72 | # Create a lagged series of the S&P500 US stock market index 73 | snpret = create_lagged_series( 74 | "^GSPC", datetime.datetime(2001,1,10), 75 | datetime.datetime(2005,12,31), lags=5 76 | ) 77 | 78 | # Use the prior two days of returns as predictor 79 | # values, with direction as the response 80 | X = snpret[["Lag1","Lag2"]] 81 | y = snpret["Direction"] 82 | 83 | # The test data is split into two parts: Before and after 1st Jan 2005. 84 | start_test = datetime.datetime(2005,1,1) 85 | 86 | # Create training and test sets 87 | X_train = X[X.index < start_test] 88 | X_test = X[X.index >= start_test] 89 | y_train = y[y.index < start_test] 90 | y_test = y[y.index >= start_test] 91 | 92 | # Create the (parametrised) models 93 | print("Hit Rates/Confusion Matrices:\n") 94 | models = [("LR", LogisticRegression()), 95 | ("LDA", LDA()), 96 | ("QDA", QDA()), 97 | ("LSVC", LinearSVC()), 98 | ("RSVM", SVC( 99 | C=1000000.0, cache_size=200, class_weight=None, 100 | coef0=0.0, degree=3, gamma=0.0001, kernel='rbf', 101 | max_iter=-1, probability=False, random_state=None, 102 | shrinking=True, tol=0.001, verbose=False) 103 | ), 104 | ("RF", RandomForestClassifier( 105 | n_estimators=1000, criterion='gini', 106 | max_depth=None, min_samples_split=2, 107 | min_samples_leaf=1, max_features='auto', 108 | bootstrap=True, oob_score=False, n_jobs=1, 109 | random_state=None, verbose=0) 110 | )] 111 | 112 | # Iterate through the models 113 | for m in models: 114 | 115 | # Train each of the models on the training set 116 | m[1].fit(X_train, y_train) 117 | 118 | # Make an array of predictions on the test set 119 | pred = m[1].predict(X_test) 120 | 121 | # Output the hit-rate and the confusion matrix for each model 122 | print("%s:\n%0.3f" % (m[0], m[1].score(X_test, y_test))) 123 | print("%s\n" % confusion_matrix(pred, y_test)) -------------------------------------------------------------------------------- /代码/lecture_code 05/grid_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # grid_search.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | 10 | import sklearn 11 | from sklearn import cross_validation 12 | from sklearn.cross_validation import train_test_split 13 | from sklearn.grid_search import GridSearchCV 14 | from sklearn.metrics import classification_report 15 | from sklearn.svm import SVC 16 | 17 | from create_lagged_series import create_lagged_series 18 | 19 | 20 | if __name__ == "__main__": 21 | # Create a lagged series of the S&P500 US stock market index 22 | snpret = create_lagged_series( 23 | "^GSPC", datetime.datetime(2001,1,10), 24 | datetime.datetime(2005,12,31), lags=5 25 | ) 26 | 27 | # Use the prior two days of returns as predictor 28 | # values, with direction as the response 29 | X = snpret[["Lag1","Lag2"]] 30 | y = snpret["Direction"] 31 | 32 | # Train/test split 33 | X_train, X_test, y_train, y_test = train_test_split( 34 | X, y, test_size=0.5, random_state=42 35 | ) 36 | 37 | # Set the parameters by cross-validation 38 | tuned_parameters = [ 39 | {'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [1, 10, 100, 1000]} 40 | ] 41 | 42 | # Perform the grid search on the tuned parameters 43 | model = GridSearchCV(SVC(C=1), tuned_parameters, cv=10) 44 | model.fit(X_train, y_train) 45 | 46 | print("Optimised parameters found on training set:") 47 | print(model.best_estimator_, "\n") 48 | 49 | print("Grid scores calculated on training set:") 50 | for params, mean_score, scores in model.grid_scores_: 51 | print("%0.3f for %r" % (mean_score, params)) 52 | -------------------------------------------------------------------------------- /代码/lecture_code 08/backtest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # backtest.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | import pprint 10 | try: 11 | import Queue as queue 12 | except ImportError: 13 | import queue 14 | import time 15 | 16 | 17 | class Backtest(object): 18 | """ 19 | Enscapsulates the settings and components for carrying out 20 | an event-driven backtest. 21 | """ 22 | 23 | def __init__( 24 | self, csv_dir, symbol_list, initial_capital, 25 | heartbeat, start_date, data_handler, 26 | execution_handler, portfolio, strategy 27 | ): 28 | """ 29 | Initialises the backtest. 30 | 31 | Parameters: 32 | csv_dir - The hard root to the CSV data directory. 33 | symbol_list - The list of symbol strings. 34 | intial_capital - The starting capital for the portfolio. 35 | heartbeat - Backtest "heartbeat" in seconds 36 | start_date - The start datetime of the strategy. 37 | data_handler - (Class) Handles the market data feed. 38 | execution_handler - (Class) Handles the orders/fills for trades. 39 | portfolio - (Class) Keeps track of portfolio current and prior positions. 40 | strategy - (Class) Generates signals based on market data. 41 | """ 42 | self.csv_dir = csv_dir 43 | self.symbol_list = symbol_list 44 | self.initial_capital = initial_capital 45 | self.heartbeat = heartbeat 46 | self.start_date = start_date 47 | 48 | self.data_handler_cls = data_handler 49 | self.execution_handler_cls = execution_handler 50 | self.portfolio_cls = portfolio 51 | self.strategy_cls = strategy 52 | 53 | self.events = queue.Queue() 54 | 55 | self.signals = 0 56 | self.orders = 0 57 | self.fills = 0 58 | self.num_strats = 1 59 | 60 | self._generate_trading_instances() 61 | 62 | def _generate_trading_instances(self): 63 | """ 64 | Generates the trading instance objects from 65 | their class types. 66 | """ 67 | print( 68 | "Creating DataHandler, Strategy, Portfolio and ExecutionHandler" 69 | ) 70 | self.data_handler = self.data_handler_cls(self.events, self.csv_dir, self.symbol_list) 71 | self.strategy = self.strategy_cls(self.data_handler, self.events) 72 | self.portfolio = self.portfolio_cls(self.data_handler, self.events, self.start_date, 73 | self.initial_capital) 74 | self.execution_handler = self.execution_handler_cls(self.events) 75 | 76 | def _run_backtest(self): 77 | """ 78 | Executes the backtest. 79 | """ 80 | i = 0 81 | while True: 82 | i += 1 83 | print(i) 84 | # Update the market bars 85 | if self.data_handler.continue_backtest == True: 86 | self.data_handler.update_bars() 87 | else: 88 | break 89 | 90 | # Handle the events 91 | while True: 92 | try: 93 | event = self.events.get(False) 94 | except queue.Empty: 95 | break 96 | else: 97 | if event is not None: 98 | if event.type == 'MARKET': 99 | self.strategy.calculate_signals(event) 100 | self.portfolio.update_timeindex(event) 101 | 102 | elif event.type == 'SIGNAL': 103 | self.signals += 1 104 | self.portfolio.update_signal(event) 105 | 106 | elif event.type == 'ORDER': 107 | self.orders += 1 108 | self.execution_handler.execute_order(event) 109 | 110 | elif event.type == 'FILL': 111 | self.fills += 1 112 | self.portfolio.update_fill(event) 113 | 114 | time.sleep(self.heartbeat) 115 | 116 | def _output_performance(self): 117 | """ 118 | Outputs the strategy performance from the backtest. 119 | """ 120 | self.portfolio.create_equity_curve_dataframe() 121 | 122 | print("Creating summary stats...") 123 | stats = self.portfolio.output_summary_stats() 124 | 125 | print("Creating equity curve...") 126 | print(self.portfolio.equity_curve.tail(10)) 127 | pprint.pprint(stats) 128 | 129 | print("Signals: %s" % self.signals) 130 | print("Orders: %s" % self.orders) 131 | print("Fills: %s" % self.fills) 132 | 133 | def simulate_trading(self): 134 | """ 135 | Simulates the backtest and outputs portfolio performance. 136 | """ 137 | self._run_backtest() 138 | self._output_performance() 139 | -------------------------------------------------------------------------------- /代码/lecture_code 08/event.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # event.py 5 | 6 | from __future__ import print_function 7 | 8 | 9 | class Event(object): 10 | """ 11 | Event is base class providing an interface for all subsequent 12 | (inherited) events, that will trigger further events in the 13 | trading infrastructure. 14 | """ 15 | pass 16 | 17 | 18 | class MarketEvent(Event): 19 | """ 20 | Handles the event of receiving a new market update with 21 | corresponding bars. 22 | """ 23 | 24 | def __init__(self): 25 | """ 26 | Initialises the MarketEvent. 27 | """ 28 | self.type = 'MARKET' 29 | 30 | 31 | class SignalEvent(Event): 32 | """ 33 | Handles the event of sending a Signal from a Strategy object. 34 | This is received by a Portfolio object and acted upon. 35 | """ 36 | 37 | def __init__(self, strategy_id, symbol, datetime, signal_type, strength): 38 | """ 39 | Initialises the SignalEvent. 40 | 41 | Parameters: 42 | strategy_id - The unique ID of the strategy sending the signal. 43 | symbol - The ticker symbol, e.g. 'GOOG'. 44 | datetime - The timestamp at which the signal was generated. 45 | signal_type - 'LONG' or 'SHORT'. 46 | strength - An adjustment factor "suggestion" used to scale 47 | quantity at the portfolio level. Useful for pairs strategies. 48 | """ 49 | self.strategy_id = strategy_id 50 | self.type = 'SIGNAL' 51 | self.symbol = symbol 52 | self.datetime = datetime 53 | self.signal_type = signal_type 54 | self.strength = strength 55 | 56 | 57 | class OrderEvent(Event): 58 | """ 59 | Handles the event of sending an Order to an execution system. 60 | The order contains a symbol (e.g. GOOG), a type (market or limit), 61 | quantity and a direction. 62 | """ 63 | 64 | def __init__(self, symbol, order_type, quantity, direction): 65 | """ 66 | Initialises the order type, setting whether it is 67 | a Market order ('MKT') or Limit order ('LMT'), has 68 | a quantity (integral) and its direction ('BUY' or 69 | 'SELL'). 70 | 71 | TODO: Must handle error checking here to obtain 72 | rational orders (i.e. no negative quantities etc). 73 | 74 | Parameters: 75 | symbol - The instrument to trade. 76 | order_type - 'MKT' or 'LMT' for Market or Limit. 77 | quantity - Non-negative integer for quantity. 78 | direction - 'BUY' or 'SELL' for long or short. 79 | """ 80 | self.type = 'ORDER' 81 | self.symbol = symbol 82 | self.order_type = order_type 83 | self.quantity = quantity 84 | self.direction = direction 85 | 86 | def print_order(self): 87 | """ 88 | Outputs the values within the Order. 89 | """ 90 | print( 91 | "Order: Symbol=%s, Type=%s, Quantity=%s, Direction=%s" % 92 | (self.symbol, self.order_type, self.quantity, self.direction) 93 | ) 94 | 95 | 96 | class FillEvent(Event): 97 | """ 98 | Encapsulates the notion of a Filled Order, as returned 99 | from a brokerage. Stores the quantity of an instrument 100 | actually filled and at what price. In addition, stores 101 | the commission of the trade from the brokerage. 102 | 103 | TODO: Currently does not support filling positions at 104 | different prices. This will be simulated by averaging 105 | the cost. 106 | """ 107 | 108 | def __init__(self, timeindex, symbol, exchange, quantity, 109 | direction, fill_cost, commission=None): 110 | """ 111 | Initialises the FillEvent object. Sets the symbol, exchange, 112 | quantity, direction, cost of fill and an optional 113 | commission. 114 | 115 | If commission is not provided, the Fill object will 116 | calculate it based on the trade size and Interactive 117 | Brokers fees. 118 | 119 | Parameters: 120 | timeindex - The bar-resolution when the order was filled. 121 | symbol - The instrument which was filled. 122 | exchange - The exchange where the order was filled. 123 | quantity - The filled quantity. 124 | direction - The direction of fill ('BUY' or 'SELL') 125 | fill_cost - The holdings value in dollars. 126 | commission - An optional commission sent from IB. 127 | """ 128 | self.type = 'FILL' 129 | self.timeindex = timeindex 130 | self.symbol = symbol 131 | self.exchange = exchange 132 | self.quantity = quantity 133 | self.direction = direction 134 | self.fill_cost = fill_cost 135 | 136 | # Calculate commission 137 | if commission is None: 138 | self.commission = self.calculate_ib_commission() 139 | else: 140 | self.commission = commission 141 | 142 | def calculate_ib_commission(self): 143 | """ 144 | Calculates the fees of trading based on an Interactive 145 | Brokers fee structure for API, in USD. 146 | 147 | This does not include exchange or ECN fees. 148 | 149 | Based on "US API Directed Orders": 150 | https://www.interactivebrokers.com/en/index.php?f=commission&p=stocks2 151 | """ 152 | full_cost = 1.3 153 | if self.quantity <= 500: 154 | full_cost = max(1.3, 0.013 * self.quantity) 155 | else: # Greater than 500 156 | full_cost = max(1.3, 0.008 * self.quantity) 157 | return full_cost 158 | -------------------------------------------------------------------------------- /代码/lecture_code 08/mac.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | 4 | from backtest import Backtest 5 | from data import HistoricCSVDataHandler 6 | from event import SignalEvent 7 | from execution import SimulatedExecutionHandler 8 | from portfolio import Portfolio 9 | from strategy import Strategy 10 | 11 | 12 | class MovingAverageCrossStrategy(Strategy): 13 | """ 14 | Carries out a basic Moving Average Crossover strategy with a 15 | short/long simple weighted moving average. Default short/long 16 | windows are 100/400 periods respectively. 17 | """ 18 | 19 | def __init__(self, bars, events, short_window=100, long_window=400): 20 | """ 21 | Initialises the buy and hold strategy. 22 | 23 | Parameters: 24 | bars - The DataHandler object that provides bar information 25 | events - The Event Queue object. 26 | short_window - The short moving average lookback. 27 | long_window - The long moving average lookback. 28 | """ 29 | self.bars = bars 30 | self.symbol_list = self.bars.symbol_list 31 | self.events = events 32 | self.short_window = short_window 33 | self.long_window = long_window 34 | 35 | # Set to True if a symbol is in the market 36 | self.bought = self._calculate_initial_bought() 37 | 38 | def _calculate_initial_bought(self): 39 | """ 40 | Adds keys to the bought dictionary for all symbols 41 | and sets them to 'OUT'. 42 | """ 43 | bought = {} 44 | for s in self.symbol_list: 45 | bought[s] = 'OUT' 46 | return bought 47 | 48 | def calculate_signals(self, event): 49 | """ 50 | Generates a new set of signals based on the MAC 51 | SMA with the short window crossing the long window 52 | meaning a long entry and vice versa for a short entry. 53 | 54 | Parameters 55 | event - A MarketEvent object. 56 | """ 57 | if event.type == 'MARKET': 58 | for symbol in self.symbol_list: 59 | bars = self.bars.get_latest_bars_values(symbol, "close", N=self.long_window) 60 | 61 | if bars is not None and bars != []: 62 | short_sma = np.mean(bars[-self.short_window:]) 63 | long_sma = np.mean(bars[-self.long_window:]) 64 | 65 | dt = self.bars.get_latest_bar_datetime(symbol) 66 | sig_dir = "" 67 | strength = 1.0 68 | strategy_id = 1 69 | 70 | if short_sma > long_sma and self.bought[symbol] == "OUT": 71 | sig_dir = 'LONG' 72 | signal = SignalEvent(strategy_id, symbol, dt, sig_dir, strength) 73 | self.events.put(signal) 74 | self.bought[symbol] = 'LONG' 75 | 76 | elif short_sma < long_sma and self.bought[symbol] == "LONG": 77 | sig_dir = 'EXIT' 78 | signal = SignalEvent(strategy_id, symbol, dt, sig_dir, strength) 79 | self.events.put(signal) 80 | self.bought[symbol] = 'OUT' 81 | 82 | 83 | if __name__ == "__main__": 84 | csv_dir = REPLACE_WITH_YOUR_CSV_DIR_HERE 85 | symbol_list = ['AAPL'] 86 | initial_capital = 100000.0 87 | start_date = datetime.datetime(1990,1,1,0,0,0) 88 | heartbeat = 0.0 89 | 90 | backtest = Backtest(csv_dir, 91 | symbol_list, 92 | initial_capital, 93 | heartbeat, 94 | start_date, 95 | HistoricCSVDataHandler, 96 | SimulatedExecutionHandler, 97 | Portfolio, 98 | MovingAverageCrossStrategy) 99 | 100 | backtest.simulate_trading() 101 | -------------------------------------------------------------------------------- /代码/lecture_code 08/portfolio.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | # portfolio.py 5 | 6 | from __future__ import print_function 7 | 8 | import datetime 9 | from math import floor 10 | try: 11 | import Queue as queue 12 | except ImportError: 13 | import queue 14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from event import FillEvent, OrderEvent 19 | from performance import create_sharpe_ratio, create_drawdowns 20 | 21 | 22 | class Portfolio(object): 23 | """ 24 | The Portfolio class handles the positions and market 25 | value of all instruments at a resolution of a "bar", 26 | i.e. secondly, minutely, 5-min, 30-min, 60 min or EOD. 27 | 28 | The positions DataFrame stores a time-index of the 29 | quantity of positions held. 30 | 31 | The holdings DataFrame stores the cash and total market 32 | holdings value of each symbol for a particular 33 | time-index, as well as the percentage change in 34 | portfolio total across bars. 35 | """ 36 | 37 | def __init__(self, bars, events, start_date, initial_capital=100000.0): 38 | """ 39 | Initialises the portfolio with bars and an event queue. 40 | Also includes a starting datetime index and initial capital 41 | (USD unless otherwise stated). 42 | 43 | Parameters: 44 | bars - The DataHandler object with current market data. 45 | events - The Event Queue object. 46 | start_date - The start date (bar) of the portfolio. 47 | initial_capital - The starting capital in USD. 48 | """ 49 | self.bars = bars 50 | self.events = events 51 | self.symbol_list = self.bars.symbol_list 52 | self.start_date = start_date 53 | self.initial_capital = initial_capital 54 | 55 | self.all_positions = self.construct_all_positions() 56 | self.current_positions = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 57 | 58 | self.all_holdings = self.construct_all_holdings() 59 | self.current_holdings = self.construct_current_holdings() 60 | 61 | def construct_all_positions(self): 62 | """ 63 | Constructs the positions list using the start_date 64 | to determine when the time index will begin. 65 | """ 66 | d = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 67 | d['datetime'] = self.start_date 68 | return [d] 69 | 70 | def construct_all_holdings(self): 71 | """ 72 | Constructs the holdings list using the start_date 73 | to determine when the time index will begin. 74 | """ 75 | d = dict( (k,v) for k, v in [(s, 0.0) for s in self.symbol_list] ) 76 | d['datetime'] = self.start_date 77 | d['cash'] = self.initial_capital 78 | d['commission'] = 0.0 79 | d['total'] = self.initial_capital 80 | return [d] 81 | 82 | def construct_current_holdings(self): 83 | """ 84 | This constructs the dictionary which will hold the instantaneous 85 | value of the portfolio across all symbols. 86 | """ 87 | d = dict( (k,v) for k, v in [(s, 0.0) for s in self.symbol_list] ) 88 | d['cash'] = self.initial_capital 89 | d['commission'] = 0.0 90 | d['total'] = self.initial_capital 91 | return d 92 | 93 | def update_timeindex(self, event): 94 | """ 95 | Adds a new record to the positions matrix for the current 96 | market data bar. This reflects the PREVIOUS bar, i.e. all 97 | current market data at this stage is known (OHLCV). 98 | 99 | Makes use of a MarketEvent from the events queue. 100 | """ 101 | latest_datetime = self.bars.get_latest_bar_datetime(self.symbol_list[0]) 102 | 103 | # Update positions 104 | # ================ 105 | dp = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 106 | dp['datetime'] = latest_datetime 107 | 108 | for s in self.symbol_list: 109 | dp[s] = self.current_positions[s] 110 | 111 | # Append the current positions 112 | self.all_positions.append(dp) 113 | 114 | # Update holdings 115 | # =============== 116 | dh = dict( (k,v) for k, v in [(s, 0) for s in self.symbol_list] ) 117 | dh['datetime'] = latest_datetime 118 | dh['cash'] = self.current_holdings['cash'] 119 | dh['commission'] = self.current_holdings['commission'] 120 | dh['total'] = self.current_holdings['cash'] 121 | 122 | for s in self.symbol_list: 123 | # Approximation to the real value 124 | market_value = self.current_positions[s] * \ 125 | self.bars.get_latest_bar_value(s, "adj_close") 126 | dh[s] = market_value 127 | dh['total'] += market_value 128 | 129 | # Append the current holdings 130 | self.all_holdings.append(dh) 131 | 132 | # ====================== 133 | # FILL/POSITION HANDLING 134 | # ====================== 135 | 136 | def update_positions_from_fill(self, fill): 137 | """ 138 | Takes a Fill object and updates the position matrix to 139 | reflect the new position. 140 | 141 | Parameters: 142 | fill - The Fill object to update the positions with. 143 | """ 144 | # Check whether the fill is a buy or sell 145 | fill_dir = 0 146 | if fill.direction == 'BUY': 147 | fill_dir = 1 148 | if fill.direction == 'SELL': 149 | fill_dir = -1 150 | 151 | # Update positions list with new quantities 152 | self.current_positions[fill.symbol] += fill_dir*fill.quantity 153 | 154 | def update_holdings_from_fill(self, fill): 155 | """ 156 | Takes a Fill object and updates the holdings matrix to 157 | reflect the holdings value. 158 | 159 | Parameters: 160 | fill - The Fill object to update the holdings with. 161 | """ 162 | # Check whether the fill is a buy or sell 163 | fill_dir = 0 164 | if fill.direction == 'BUY': 165 | fill_dir = 1 166 | if fill.direction == 'SELL': 167 | fill_dir = -1 168 | 169 | # Update holdings list with new quantities 170 | fill_cost = self.bars.get_latest_bar_value( 171 | fill.symbol, "adj_close" 172 | ) 173 | cost = fill_dir * fill_cost * fill.quantity 174 | self.current_holdings[fill.symbol] += cost 175 | self.current_holdings['commission'] += fill.commission 176 | self.current_holdings['cash'] -= (cost + fill.commission) 177 | self.current_holdings['total'] -= (cost + fill.commission) 178 | 179 | def update_fill(self, event): 180 | """ 181 | Updates the portfolio current positions and holdings 182 | from a FillEvent. 183 | """ 184 | if event.type == 'FILL': 185 | self.update_positions_from_fill(event) 186 | self.update_holdings_from_fill(event) 187 | 188 | def generate_naive_order(self, signal): 189 | """ 190 | Simply files an Order object as a constant quantity 191 | sizing of the signal object, without risk management or 192 | position sizing considerations. 193 | 194 | Parameters: 195 | signal - The tuple containing Signal information. 196 | """ 197 | order = None 198 | 199 | symbol = signal.symbol 200 | direction = signal.signal_type 201 | strength = signal.strength 202 | 203 | mkt_quantity = 100 204 | cur_quantity = self.current_positions[symbol] 205 | order_type = 'MKT' 206 | 207 | if direction == 'LONG' and cur_quantity == 0: 208 | order = OrderEvent(symbol, order_type, mkt_quantity, 'BUY') 209 | if direction == 'SHORT' and cur_quantity == 0: 210 | order = OrderEvent(symbol, order_type, mkt_quantity, 'SELL') 211 | 212 | if direction == 'EXIT' and cur_quantity > 0: 213 | order = OrderEvent(symbol, order_type, abs(cur_quantity), 'SELL') 214 | if direction == 'EXIT' and cur_quantity < 0: 215 | order = OrderEvent(symbol, order_type, abs(cur_quantity), 'BUY') 216 | return order 217 | 218 | def update_signal(self, event): 219 | """ 220 | Acts on a SignalEvent to generate new orders 221 | based on the portfolio logic. 222 | """ 223 | if event.type == 'SIGNAL': 224 | order_event = self.generate_naive_order(event) 225 | self.events.put(order_event) 226 | 227 | # ======================== 228 | # POST-BACKTEST STATISTICS 229 | # ======================== 230 | 231 | def create_equity_curve_dataframe(self): 232 | """ 233 | Creates a pandas DataFrame from the all_holdings 234 | list of dictionaries. 235 | """ 236 | curve = pd.DataFrame(self.all_holdings) 237 | curve.set_index('datetime', inplace=True) 238 | curve['returns'] = curve['total'].pct_change() 239 | curve['equity_curve'] = (1.0+curve['returns']).cumprod() 240 | self.equity_curve = curve 241 | 242 | def output_summary_stats(self): 243 | """ 244 | Creates a list of summary statistics for the portfolio. 245 | """ 246 | total_return = self.equity_curve['equity_curve'][-1] 247 | returns = self.equity_curve['returns'] 248 | pnl = self.equity_curve['equity_curve'] 249 | 250 | sharpe_ratio = create_sharpe_ratio(returns, periods=252*60*6.5) 251 | drawdown, max_dd, dd_duration = create_drawdowns(pnl) 252 | self.equity_curve['drawdown'] = drawdown 253 | 254 | stats = [("Total Return", "%0.2f%%" % ((total_return - 1.0) * 100.0)), 255 | ("Sharpe Ratio", "%0.2f" % sharpe_ratio), 256 | ("Max Drawdown", "%0.2f%%" % (max_dd * 100.0)), 257 | ("Drawdown Duration", "%d" % dd_duration)] 258 | 259 | self.equity_curve.to_csv('equity.csv') 260 | return stats 261 | --------------------------------------------------------------------------------