├── .gitignore ├── QLearner.py ├── README.md ├── StrategyLearner.py ├── screencapture-evernote-client-web-2018-11-30-19_44_58.jpg ├── testStrategy.py ├── util.py └── yahoo_finance_data ├── GOOG.csv └── SPY.csv /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /data/ 3 | /trained_result/ 4 | /feedback/ 5 | /example/ 6 | /references/ 7 | /training_dir/ 8 | /mc3p4_qlearning_trader/ 9 | /mc3p2_qlearning_robot/ 10 | mc3p2_qlearning_robot/ 11 | mc3p4_qlearning_trader/ 12 | -------------------------------------------------------------------------------- /QLearner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Qlearner implementation by Wenchen Li 5 | 6 | Qlearner: 7 | tabular Q-learning with a per-state random action rate (rar) and a decaying learning rate (alpha). 8 | dyna-Q: 9 | Dyna-Q starts with straight, regular 10 | Q-learning and then adds three new components: 11 | we update models of T and R from each real interaction, we hallucinate an experience 12 | from those models, and we update our Q table with the hallucinated experience. 13 | The hallucinate-and-update step is repeated many times for each real interaction, 14 | typically 100 or 200 iterations. 15 | Once those are completed, we return back up to the top and continue our interaction 16 | with the real world. 17 | The reason Dyna-Q is useful is that experiences with the real world are potentially very 18 | expensive while hallucinations are very cheap, so iterating over many of 19 | them updates our Q table much more quickly. 20 | """ 21 | 22 | # Copyright (C) 2017 Wenchen Li 23 | # 24 | # This program is free software: you can redistribute it and/or modify 25 | # it under the terms of the GNU Lesser General Public License as published by 26 | # the Free Software Foundation, either version 3 of the License, or 27 | # (at your option) any later version. 28 | # 29 | # This program is distributed in the hope that it will be useful, 30 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 31 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 32 | # GNU Lesser General Public License for more details. 33 | # 34 | # You should have received a copy of the GNU Lesser General Public License 35 | # along with this program. If not, see <http://www.gnu.org/licenses/>. 36 | 37 | import numpy as np 38 | import random as rand 39 | from util import save, load 40 | 41 | 42 | class QLearner(object): 43 | """ 44 | The constructor QLearner() reserves space for keeping track of Q[s, a] for the given numbers of states and actions. 45 | It initializes Q[] with all zeros. 46 | num_states integer, the number of states to consider 47 | num_actions integer, the number of actions available. 48 | alpha float, the learning rate used in the update rule. Should range between 0.0 and 1.0 with 0.2 as a typical value. alpha_decay float, multiplicative decay applied to alpha on each call to decay_alpha(), typically 0.9999. 49 | gamma float, the discount rate used in the update rule. Should range between 0.0 and 1.0 with 0.9 as a typical value. 50 | rar float, random action rate: the probability of selecting a random action at each step. Should range between 0.0 (no random actions) and 1.0 (always random action) with 0.5 as a typical value. 51 | radr float, random action decay rate, after each update, rar = rar * radr. Ranges between 0.0 (immediate decay to 0) and 1.0 (no decay). Typically 0.99. 52 | dyna integer, conduct this number of dyna updates for each regular update. When Dyna is used, 200 is a typical value. 53 | verbose boolean, if True, your class is allowed to print debugging statements, if False, all printing is prohibited.
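    Example (illustrative only; the argument values below are arbitrary):
        learner = QLearner(num_states=300, num_actions=3, alpha=0.2, gamma=0.9,
                           rar=0.5, radr=0.99, dyna=200, verbose=False)
        a = learner.querysetstate(0)         # set the initial state and get the first action
        a = learner.query(s_prime=5, r=0.1)  # learn from (s, a, s', r) and get the next action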
54 | """ 55 | 56 | def __init__(self, 57 | num_states, 58 | num_actions, 59 | alpha=0.2, 60 | alpha_decay=.9999, 61 | gamma=0.9, 62 | rar=0.5, 63 | radr=0.99, 64 | dyna=0, 65 | verbose=False): 66 | 67 | self.verbose = verbose 68 | self.num_actions = num_actions 69 | self.num_states = num_states 70 | self.s = 0 71 | self.a = 0 72 | self.q_table = np.zeros((num_states, num_actions)) 73 | self.rar = rar 74 | self.radr = radr 75 | self.state_rar = np.zeros(num_states) 76 | self.state_rar += rar 77 | self.alpha = alpha 78 | self.alpha_decay = alpha_decay 79 | self.gamma = gamma 80 | 81 | self.dyna = dyna 82 | self.dyna_lr = .8 83 | self.dyna_init = dyna 84 | self.t_table_increment_unit = 1.0 / 100 85 | self.t_table = np.ones( 86 | (num_actions * num_states, num_states)) * self.t_table_increment_unit # state transition, 87 | self.r_table = np.zeros((num_states, num_actions)) 88 | 89 | self.last_state = rand.randint(0, self.num_states - 1) 90 | self.last_action = None 91 | 92 | def decay_alpha(self): 93 | """ 94 | decay the learning rate alpha in the value iteration 95 | """ 96 | self.alpha *= self.alpha_decay 97 | 98 | def querysetstate(self, s): 99 | """ 100 | @summary: Update the state without updating the Q-table 101 | @detail: A special version of the query method that sets the state to s, 102 | and returns an integer action according to the same rules as query() 103 | (including choosing a random action sometimes), but it does not execute 104 | an update to the Q-table. It also does not update rar. There are two main 105 | uses for this method: 1) To set the initial state, and 2) when using a 106 | learned policy, but not updating it. 107 | @param s:int, The new state 108 | @returns:int, The selected action 109 | """ 110 | 111 | # exploitation only 112 | action = np.argmax(self.q_table[s]) 113 | assert action < self.num_actions 114 | 115 | if self.verbose: print "s =", s 116 | # update state and action 117 | self.last_state = s 118 | self.last_action = action 119 | return action 120 | 121 | def query(self, s_prime, r): 122 | """ 123 | 124 | @summary: Update the Q table and return an action 125 | @detail: the core method of the Q-Learner. It keep track 126 | of the last state s and the last action a, then use the new information 127 | s_prime and r to update the Q table. The learning instance, or experience 128 | tuple is . query() return an integer, which is 129 | the next action to take. Note that it choose a random action with 130 | probability rar, and that it update rar according to the decay 131 | rate radr at each step. During exploration and exploitation each state has its own 132 | rar[random action rate],Details on the arguments: 133 | @param s_prime: int, the new state. 134 | @param r :float, a real valued immediate reward. 
135 | @returns:int, The selected action 136 | """ 137 | 138 | # exploration vs exploitation 139 | if rand.uniform(0, 1) <= self.state_rar[self.last_state]: # exploration 140 | action = rand.randint(0, self.num_actions - 1) 141 | self.state_rar[self.last_state] *= self.radr 142 | else: # exploitation 143 | action = np.argmax(self.q_table[s_prime]) 144 | assert action < self.num_actions 145 | self.q_table[self.last_state][self.last_action] = (1.0 - self.alpha) * self.q_table[self.last_state][ 146 | self.last_action] + self.alpha * ( 147 | r + self.gamma * self.q_table[s_prime][np.argmax(self.q_table[s_prime])]) # bellman 148 | 149 | if self.verbose: print "s =", s_prime, "a =", action, "r =", r 150 | 151 | # dyna (combine model-free learning with model learning) 152 | ## dyna: update the T and R models with the real experience, before last_state/last_action are overwritten 153 | if self.dyna > 0: 154 | # update T: count the observed transition (last_state, last_action) -> s_prime 155 | transition_index = self.num_states * self.last_action + self.last_state 156 | self.t_table[transition_index][s_prime] += self.t_table_increment_unit 157 | # update R: running average of the reward received for (last_state, last_action) 158 | self.r_table[self.last_state][self.last_action] = (1 - self.dyna_lr) * self.r_table[self.last_state][ 159 | self.last_action] + self.dyna_lr * r 160 | 161 | # update state and action in Qlearn 162 | self.last_state = s_prime 163 | self.last_action = action 164 | 165 | ## dyna hallucinate experiences from the learned T and R models 166 | while self.dyna > 0: 167 | dyna_state = rand.randint(0, self.num_states - 1) 168 | dyna_action = rand.randint(0, self.num_actions - 1) 169 | transition_index = self.num_states * dyna_action + dyna_state 170 | transition_prob = self.t_table[transition_index] / np.sum(self.t_table[transition_index]) 171 | state_infer_from_t_table = np.random.choice(range(self.num_states), 1, p=transition_prob)[0] 172 | dyna_r = self.r_table[dyna_state][dyna_action] 173 | 174 | ## dyna update Q table 175 | self.q_table[dyna_state][dyna_action] = (1.0 - self.alpha) * self.q_table[dyna_state][ 176 | dyna_action] + self.alpha * (dyna_r + self.gamma * self.q_table[state_infer_from_t_table][ 177 | np.argmax(self.q_table[state_infer_from_t_table])]) 178 | self.dyna -= 1 179 | # end of dyna 180 | self.dyna = self.dyna_init 181 | 182 | return action 183 | 184 | def save_model(self, table_name="q_learner_tables.pkl"): 185 | """ 186 | save the trained q learner, i.e. the q table (plus the t table and r table used by dyna) 187 | :param table_name: saved table name 188 | """ 189 | tables = [self.q_table, self.t_table, self.r_table] 190 | save(tables, table_name) 191 | 192 | def load_model(self, table_name="q_learner_tables.pkl"): 193 | """ 194 | load a trained q learner, i.e. the q table (plus the t table and r table used by dyna) 195 | :param table_name: saved table name 196 | """ 197 | [self.q_table, self.t_table, self.r_table] = load(table_name) 198 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Q-learning trader 2 | A technical Q-learning trader on SPY driven by MACD features. 3 | 4 | ## run 5 | `python testStrategy.py` 6 | 7 | ## structure 8 | QLearner.py: an independent tabular (Dyna-)Q-learner.
9 | 10 | StrategyLearner.py: Builds upon QLearner.py to learn the trading strategy 11 | 12 | testStrategy.py: trains and tests the StrategyLearner 13 | 14 | util.py: helper functions for data loading, indicators, plotting and model persistence 15 | 16 | 17 | ## experiment results 18 | 19 | please see [this link](https://www.evernote.com/shard/s120/sh/00a11079-5e8d-4243-87e5-5daaf8565836/ebafa9649081c29dc98d27cb7c3a8a18) 20 | 21 | or the evernote screenshot below: 22 | ![experiments result](screencapture-evernote-client-web-2018-11-30-19_44_58.jpg) 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /StrategyLearner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Q-trader strategy learner. States are currently discretized with k-means. Implementation 2017 Wenchen Li 5 | """ 6 | # Copyright (C) 2017 Wenchen Li 7 | # 8 | # This program is free software: you can redistribute it and/or modify 9 | # it under the terms of the GNU Lesser General Public License as published by 10 | # the Free Software Foundation, either version 3 of the License, or 11 | # (at your option) any later version. 12 | # 13 | # This program is distributed in the hope that it will be useful, 14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | # GNU Lesser General Public License for more details. 17 | # 18 | # You should have received a copy of the GNU Lesser General Public License 19 | # along with this program. If not, see <http://www.gnu.org/licenses/>. 20 | 21 | import os 22 | import datetime as dt 23 | import QLearner as ql 24 | import pandas as pd 25 | import util as ut 26 | from sklearn.cluster import KMeans 27 | from sklearn.neighbors import KernelDensity 28 | import pickle 29 | import numpy as np 30 | 31 | CASH = "Cash" 32 | STOCK = "Stock" 33 | 34 | TRAIN_DIR = "./training_dir/" 35 | kmeans_model_save_name = 'kmeans_model.pkl' 36 | dyna_q_trader_model_save_name = "q_learner_tables.pkl" 37 | training_record_save_name = "records.pkl" 38 | 39 | 40 | class RawTradeFeatures(object): 41 | """ 42 | get raw trade features such as the adjusted close price and the trading volume 43 | """ 44 | 45 | def __init__(self, symbol, sd, ed, ): 46 | self.syms = [symbol] 47 | self.dates = pd.date_range(sd, ed) 48 | 49 | def get_adj_price(self): 50 | prices_all = ut.get_data(self.syms, self.dates) # automatically adds SPY 51 | prices = prices_all[self.syms] # only portfolio symbols 52 | prices_SPY = prices_all['SPY'] # only SPY, for comparison later 53 | 54 | return prices, prices_SPY 55 | 56 | def get_vol(self): 57 | ## example use with new colname 58 | volume_all = ut.get_data(self.syms, self.dates, colname="Volume") # automatically adds SPY 59 | volume = volume_all[self.syms] # only portfolio symbols 60 | volume_SPY = volume_all['SPY'] # only SPY, for comparison later 61 | 62 | return volume, volume_SPY 63 | 64 | 65 | class StrategyLearner(object): 66 | """ 67 | For the policy learning part: 68 | 69 | Select several technical features, and compute their values for the training data 70 | Discretize the values of the features 71 | Instantiate a Q-learner 72 | For each day in the training data: 73 | Compute the current state (including holding) 74 | Compute the reward for the last action 75 | Query the learner with the current state and reward to get an action 76 | Implement the action the learner returned (BUY, SELL, NOTHING), and update portfolio value 77 | """ 78 | 79 | # constructor 80 | def
__init__(self, verbose=True, save_dir=""): 81 | self.verbose = verbose 82 | self.Nothing = 0 83 | self.Buy = 1 84 | self.Sell = 2 85 | self.num_holding_state = 3 # 0 for 0, 1 for long, 2 for short 86 | self.num_feature_state = 100 87 | self.num_state = self.num_holding_state * self.num_feature_state 88 | self.num_action = 3 89 | self.shares_to_buy_or_sell = 100 90 | self.current_holding_state = 0 # prepare holding state, long, short, 0 91 | self.last_r = 0.0 92 | self.portfolio = {} 93 | self.num_epoch = 1000 94 | self.stop_threshold = 1.0 95 | self.epsilon = .01 96 | self.learner_dyan_iter = 0 # 200 97 | # keep records of each epoch and within each epoch the transaction 98 | self.records = [] # element for each epoch is (current_state,action,value, reward) 99 | 100 | self.negative_return_punish_factor = 1.3 101 | # save paths 102 | self.save_dir = save_dir 103 | self.current_working_dir = TRAIN_DIR + save_dir + "/" 104 | if not os.path.exists(self.current_working_dir): 105 | os.makedirs(self.current_working_dir) 106 | 107 | def get_current_portfolio_values(self, today_stock_price): 108 | return self.portfolio[CASH] + self.portfolio[STOCK] * today_stock_price 109 | 110 | def init_portfolio(self, sv,): 111 | self.portfolio[CASH] = float(sv) 112 | self.portfolio[STOCK] = 0 113 | 114 | def get_raw_data(self, symbol, sd, ed): 115 | # record in panda format 116 | dates = pd.date_range(sd, ed) 117 | prices_all = ut.get_data([symbol], dates) # automatically adds SPY 118 | trades = prices_all[[symbol, ]] # only portfolio symbols 119 | portfolio_value = prices_all[[symbol, ]] 120 | 121 | # Select several technical features, and compute their values for the training data 122 | 123 | tf = RawTradeFeatures(symbol, sd, ed) 124 | price, price_SPY = tf.get_adj_price() 125 | 126 | return dates,prices_all, trades, portfolio_value, tf, price, price_SPY 127 | 128 | def get_benchmark(self,symbol, price_SPY, prices_all,sv ): 129 | """ 130 | get buy and hold benchmark result 131 | """ 132 | # buy and hold the benchmark 133 | benchmark_price = price_SPY.as_matrix() 134 | benchmark_values = prices_all[[symbol, ]] # only portfolio symbols 135 | benchmark_num_stock = int(sv / benchmark_price[0]) 136 | cash = sv - float(benchmark_num_stock * benchmark_price[0]) 137 | for i, p in enumerate(benchmark_price): 138 | benchmark_values.values[i, :] = cash + p * benchmark_num_stock 139 | 140 | return benchmark_values 141 | 142 | def get_derived_data(self,symbol, prices_all,sd,ed): 143 | """ 144 | get derived data macd and bollinger band given the raw data 145 | """ 146 | # macd 147 | macd = ut.norm(ut.get_macd(prices_all[symbol])["MACD"].as_matrix()) 148 | 149 | # bollinger band 150 | bb = ut.Bollinger_Bands_given_sym_dates([symbol], sd, ed) 151 | bb_rm = ut.norm(bb['rolling_mean'].as_matrix()) 152 | bb_ub = ut.norm(bb['upper_band'].as_matrix()) 153 | bb_lb = ut.norm(bb['lower_band'].as_matrix()) 154 | 155 | return macd, bb,bb_rm, bb_ub, bb_lb 156 | 157 | def get_finalized_input(self,price, symbol,macd): 158 | """ 159 | reformatted the finalized input macd and price 160 | """ 161 | # input to model or later process 162 | l = price[symbol].as_matrix() 163 | price_array = l.copy() 164 | 165 | x = macd.reshape((-1, 1)) 166 | 167 | return price_array, x 168 | 169 | def perform_actions_update_trades_value(self,action,price_array,trades,i): 170 | """ 171 | given action perform the action and record the trades value. 
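        Action encoding (set in __init__): 0 = do nothing, 1 = buy, 2 = sell. A buy or sell moves the
        position by self.shares_to_buy_or_sell (100) shares unless the portfolio is already long/short,
        and the holding state is then re-derived as 0 = flat, 1 = long, 2 = short. The trades row for
        day i records +100, 0 or -100 accordingly.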
172 | """ 173 | if action == 0: # do nothing 174 | if self.verbose: print "do nothing" 175 | elif action == 1: # buy 176 | if self.current_holding_state == 0 or self.current_holding_state == 2: # holding nothing or short 177 | self.portfolio[CASH] -= self.shares_to_buy_or_sell * price_array[i] 178 | self.portfolio[STOCK] += self.shares_to_buy_or_sell 179 | elif self.current_holding_state == 1: # long 180 | if self.verbose: print "buy but long already, nothing to do" 181 | 182 | else: # action sell 183 | if self.current_holding_state == 0 or self.current_holding_state == 1: # holding nothing or long 184 | self.portfolio[CASH] += self.shares_to_buy_or_sell * price_array[i] 185 | self.portfolio[STOCK] -= self.shares_to_buy_or_sell 186 | elif self.current_holding_state == 2: # short 187 | if self.verbose: print "sell but short already, nothing to do" 188 | 189 | assert np.abs(self.portfolio[STOCK]) <= self.shares_to_buy_or_sell 190 | # update self.holding state 191 | if self.portfolio[STOCK] == self.shares_to_buy_or_sell: 192 | self.current_holding_state = 1 193 | elif self.portfolio[STOCK] == -self.shares_to_buy_or_sell: 194 | self.current_holding_state = 2 195 | elif self.portfolio[STOCK] == 0: 196 | self.current_holding_state = 0 197 | else: 198 | if self.verbose: print self.portfolio, "current portfolio is not valid" 199 | 200 | trades.values[i, :] = self.shares_to_buy_or_sell 201 | if action == 0: 202 | trades.values[i, :] *= 0 203 | elif action == 1: 204 | trades.values[i, :] *= 1 205 | else: 206 | trades.values[i, :] *= -1 207 | 208 | return trades 209 | 210 | def save_plot_and_model(self,benchmark_values, portfolio_value,trades,save_plot=False): 211 | """ 212 | save trained model and plot result if save_plot set to True 213 | """ 214 | # save transaction and portfolio image and training records 215 | benchmark_values = benchmark_values.rename(columns={'SPY': "benchmark"}) 216 | portfolio_value = portfolio_value.rename(columns={'SPY': "q-learn-trader"}) 217 | if save_plot: 218 | p_value_all = portfolio_value.join(benchmark_values) 219 | ut.plot_data(trades, title="transactions_train", ylabel="amount", save_image=True, 220 | save_dir=self.current_working_dir) 221 | ut.plot_data(p_value_all, title="portfolio value_train", ylabel="USD", save_image=True, 222 | save_dir=self.current_working_dir) 223 | self.learner.save_model(table_name=self.current_working_dir + dyna_q_trader_model_save_name) 224 | 225 | def addEvidence(self, symbol="SPY", 226 | sd=dt.datetime(2008, 1, 1), 227 | ed=dt.datetime(2009, 1, 1), 228 | sv=10000): 229 | """ 230 | train q learner 231 | :param symbol:security symbol 232 | :param sd: start date 233 | :param ed: end data 234 | :param sv: start fund value 235 | :return: return of the trade 236 | """ 237 | 238 | self.init_portfolio(sv) 239 | 240 | dates, prices_all, trades, portfolio_value, tf, price, price_SPY = self.get_raw_data(symbol, sd, ed) 241 | 242 | benchmark_values = self.get_benchmark( symbol, price_SPY, prices_all, sv) 243 | 244 | macd, bb, bb_rm, bb_ub, bb_lb = self.get_derived_data( symbol, prices_all, sd, ed) 245 | 246 | price_array, x = self.get_finalized_input( price, symbol, macd) 247 | 248 | # discretize 249 | kmeans_model = KMeans(n_clusters=self.num_feature_state, random_state=0, ) 250 | kmeans = kmeans_model.fit(x) 251 | pickle.dump(kmeans_model, open(self.current_working_dir + kmeans_model_save_name, 'wb')) 252 | feature_states = kmeans.labels_ 253 | 254 | # Instantiate a Q-learner 255 | self.learner = ql.QLearner(num_states=self.num_state, 
num_actions=self.num_action, rar=.5, alpha=.001, 256 | alpha_decay=.99, dyna=self.learner_dyan_iter) 257 | self.last_cumulated_return = 0.0 258 | self.cumulated_return_indicator = 0 259 | 260 | # value iteration 261 | for k in xrange(self.num_epoch): 262 | for i, s in enumerate(feature_states): 263 | if i == len(feature_states) - 1: 264 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[-1]) 265 | continue # skip the last day; computing the reward needs the next day's price 266 | # compute the current state (including holding) 267 | current_holding_state = self.current_holding_state 268 | current_feature_state = s 269 | current_state = self.num_feature_state * current_holding_state + current_feature_state 270 | 271 | # compute the reward for the last action 272 | r = self.last_r 273 | if r < 0: # penalize negative returns more heavily 274 | r *= self.negative_return_punish_factor 275 | # Query the learner with the current state and reward to get an action 276 | action = self.learner.query(current_state, r) 277 | 278 | # Implement the action the learner returned (BUY, SELL, NOTHING), and update portfolio value 279 | last_portfolio_value = self.get_current_portfolio_values(price_array[i]) # sum(self.portfolio_values) 280 | 281 | trades = self.perform_actions_update_trades_value(action, price_array, trades, i) 282 | 283 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[i]) 284 | self.last_r = (self.get_current_portfolio_values(price_array[i + 1]) - last_portfolio_value) / float(sv) 285 | if self.verbose: print self.last_r 286 | 287 | if self.verbose: print "epoch", k, " current cumulated return:", self.get_current_portfolio_values( 288 | price_array[-1]) / float(sv) - 1.0, "portfolio:", self.portfolio 289 | self.cumulated_return_indicator = self.last_cumulated_return / ( 290 | self.get_current_portfolio_values(price_array[-1]) / float(sv)) 291 | self.last_cumulated_return = self.get_current_portfolio_values(price_array[-1]) / float(sv) - 1 292 | 293 | if self.num_epoch - 1 != k: 294 | # reset portfolio 295 | self.portfolio[CASH] = sv 296 | self.portfolio[STOCK] = 0 297 | self.current_holding_state = 0 298 | self.last_r = 0.0 299 | # decay alpha 300 | self.learner.decay_alpha() 301 | 302 | self.save_plot_and_model(benchmark_values, portfolio_value, trades) 303 | 304 | trade_return = self.get_current_portfolio_values(price_array[-1]) / sv - 1.0 305 | return trade_return 306 | 307 | # this method should use the existing policy and test it against new data 308 | 309 | 310 | def testPolicy(self, symbol="IBM", 311 | sd=dt.datetime(2009, 1, 1), 312 | ed=dt.datetime(2010, 1, 1), 313 | sv=10000): 314 | """ 315 | test the trained q learner 316 | :param symbol:security symbol 317 | :param sd: start date 318 | :param ed: end date 319 | :param sv: starting fund value 320 | :return: trades: record of the trades, trade_return: return of the trade 321 | """ 322 | self.init_portfolio(sv) 323 | 324 | dates, prices_all, trades, portfolio_value, tf, price, price_SPY = self.get_raw_data(symbol, sd, ed) 325 | 326 | benchmark_values = self.get_benchmark(symbol, price_SPY, prices_all, sv) 327 | 328 | macd, bb, bb_rm, bb_ub, bb_lb = self.get_derived_data( symbol, prices_all, sd, ed) 329 | 330 | price_array, x = self.get_finalized_input( price, symbol, macd) 331 | 332 | # kmeans model load 333 | kmeans_model = pickle.load(open(self.current_working_dir + kmeans_model_save_name, 'rb')) 334 | kmeans = kmeans_model.predict(x) 335 | feature_states = kmeans 336 | 337 | # Instantiate a Q-learner 338 | self.learner =
ql.QLearner(num_states=self.num_state, num_actions=self.num_action) 339 | self.last_cumulated_return = 0.0 340 | self.cumulated_return_indicator = 0 341 | # load trained q table(if dyna,t_table and r_table) 342 | self.learner.load_model(table_name=self.current_working_dir + dyna_q_trader_model_save_name) 343 | 344 | # value iteration 345 | for i, s in enumerate(feature_states): 346 | if i == len(feature_states) - 1: 347 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[-1]) 348 | continue # skip last because 2nd day price 349 | 350 | # compute the current state(include holding) 351 | current_holding_state = self.current_holding_state 352 | current_feature_state = s 353 | current_state = self.num_feature_state * current_holding_state + current_feature_state 354 | 355 | # Query the learner with the current state 356 | action = self.learner.querysetstate(current_state) 357 | 358 | # Implement the action the learner returned (BUY, SELL, NOTHING), and update portfolio value 359 | last_portfolio_value = self.get_current_portfolio_values(price_array[i]) # sum(self.portfolio_values) 360 | 361 | trades = self.perform_actions_update_trades_value(action, price_array, trades, i) 362 | 363 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[i]) 364 | 365 | self.last_r = (self.get_current_portfolio_values(price_array[i + 1]) - last_portfolio_value) / float(sv) 366 | if self.verbose: print self.last_r 367 | 368 | self.last_cumulated_return = self.get_current_portfolio_values(price_array[-1]) / float(sv) - 1 369 | 370 | self.save_plot_and_model(benchmark_values, portfolio_value, trades) 371 | 372 | trade_return = self.get_current_portfolio_values(price_array[-1]) / sv - 1.0 373 | if self.verbose: print "cumulated return=:", trade_return 374 | return trades, trade_return 375 | -------------------------------------------------------------------------------- /screencapture-evernote-client-web-2018-11-30-19_44_58.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenchenLi/q-learning-trader/7f34453567914d2c980f540f030c228b50f574f4/screencapture-evernote-client-web-2018-11-30-19_44_58.jpg -------------------------------------------------------------------------------- /testStrategy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Test a Strategy Learner. (c) 2017 Wenchen Li 5 | """ 6 | # Copyright (C) 2017 Wenchen Li 7 | # 8 | # This program is free software: you can redistribute it and/or modify 9 | # it under the terms of the GNU Lesser General Public License as published by 10 | # the Free Software Foundation, either version 3 of the License, or 11 | # (at your option) any later version. 12 | # 13 | # This program is distributed in the hope that it will be useful, 14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | # GNU Lesser General Public License for more details. 17 | # 18 | # You should have received a copy of the GNU Lesser General Public License 19 | # along with this program. If not, see . 
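# Running this file as a script (see the __main__ block below) launches test_code() in batches of
# 10 worker processes until 1000 independent train/test simulations have completed, then prints the
# per-run returns plus mean/std for train and test (and min/max and a Sharpe ratio for the test runs).
# For a single sequential run, something like the following works:
#
#   from testStrategy import test_code
#   test_code(verb=True, save_dir="0")  # save_dir: per-run subfolder created under ./training_dir/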
20 | 21 | import pandas as pd 22 | import datetime as dt 23 | import util as ut 24 | import StrategyLearner as sl 25 | 26 | def test_code(train_q=None,out_q=None,verb = True, test_only=False,save_dir=".",): 27 | 28 | # instantiate the strategy learner 29 | learner = sl.StrategyLearner(verbose = verb,save_dir=save_dir) 30 | 31 | if not test_only: 32 | # set parameters for training the learner 33 | sym = "SPY" 34 | stdate =dt.datetime(2006,4,30) 35 | enddate =dt.datetime(2007,1,3) # roughly eight months of training data 36 | 37 | # train the learner 38 | trade_return_train = learner.addEvidence(symbol = sym, sd = stdate, 39 | ed = enddate, sv = 10000) 40 | if train_q: 41 | train_q.put(trade_return_train) 42 | # set parameters for testing 43 | sym = "SPY" 44 | stdate =dt.datetime(2007,1,3) 45 | enddate =dt.datetime(2007,12,31) 46 | 47 | # get some data for reference 48 | syms=[sym] 49 | dates = pd.date_range(stdate, enddate) 50 | prices_all = ut.get_data(syms, dates) # automatically adds SPY 51 | prices = prices_all[syms] # only portfolio symbols 52 | # if verb: print prices 53 | 54 | # test the learner 55 | df_trades,trade_return = learner.testPolicy(symbol = sym, sd = stdate, \ 56 | ed = enddate, sv = 10000) 57 | if out_q: 58 | out_q.put(trade_return) 59 | # a few sanity checks 60 | # df_trades should be a single column DataFrame (not a series) 61 | # including only the values 100, 0, -100 (shares_to_buy_or_sell in StrategyLearner) 62 | if isinstance(df_trades, pd.DataFrame) == False: 63 | print "Returned result is not a DataFrame" 64 | if prices.shape != df_trades.shape: 65 | print "Returned result is not the right shape" 66 | 67 | if __name__=="__main__": 68 | from multiprocessing import Process,Queue 69 | import numpy as np 70 | 71 | train_q = Queue() 72 | out_q = Queue() 73 | train_resultdict = {} 74 | test_resultdict = {} 75 | total_num_simulation_left = 1000 76 | 77 | while total_num_simulation_left>0: 78 | 79 | nprocs_each_iter = 10 80 | procs = [] 81 | current_training_ids = [] 82 | for i in xrange(nprocs_each_iter): 83 | training_id = total_num_simulation_left - i 84 | current_training_ids.append(training_id) 85 | proc = Process(target=test_code,args=(train_q,out_q,False,False,str(training_id),)) 86 | procs.append(proc) 87 | proc.start() 88 | 89 | for training_id in current_training_ids: 90 | train_resultdict[training_id] = train_q.get() 91 | test_resultdict[training_id]= out_q.get() 92 | 93 | for p in procs: 94 | p.join() 95 | 96 | total_num_simulation_left -= nprocs_each_iter 97 | 98 | print "train:" 99 | print "each strategy return:", train_resultdict.values() 100 | print "mean return:", np.average(train_resultdict.values()) 101 | print "std return:", np.std(train_resultdict.values()) 102 | 103 | mean = np.average(test_resultdict.values()) 104 | std = np.std(test_resultdict.values()) 105 | yearly_risk_free_rate = .05 # https://www.treasury.gov/resource-center/data-chart-center/interest-rates/Pages/TextView.aspx?data=yieldYear&year=2007 106 | print "test:" 107 | print "each strategy return:",test_resultdict.values() 108 | print "mean return:",mean 109 | print "std return:",std 110 | print "max, min return:",max(test_resultdict.values()), min(test_resultdict.values()) 111 | print "sharpe ratio:", (mean - yearly_risk_free_rate)/std 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ Utility code.""" 4 | 5 | # Copyright (C) 2017 Wenchen Li 6 | # 7 |
# This program is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU Lesser General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # This program is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU Lesser General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU Lesser General Public License 18 | # along with this program. If not, see . 19 | 20 | import os 21 | import pandas as pd 22 | import matplotlib.pyplot as plt 23 | import pickle 24 | import datetime as dt 25 | import numpy as np 26 | 27 | 28 | def symbol_to_path(symbol, base_dir=os.path.join(".", "yahoo_finance_data")): 29 | """Return CSV file path given ticker symbol.""" 30 | return os.path.join(base_dir, "{}.csv".format(str(symbol))) 31 | 32 | 33 | def get_data(symbols, dates, addSPY=True, colname = 'Adj Close'): 34 | """Read stock data (adjusted close) for given symbols from CSV files.""" 35 | df = pd.DataFrame(index=dates) 36 | if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent 37 | symbols = ['SPY'] + symbols 38 | 39 | for symbol in symbols: 40 | df_temp = pd.read_csv(symbol_to_path(symbol), index_col='Date', 41 | parse_dates=True, usecols=['Date', colname], na_values=['nan']) 42 | df_temp = df_temp.rename(columns={colname: symbol}) 43 | df = df.join(df_temp) 44 | if symbol == 'SPY': # drop dates SPY did not trade 45 | df = df.dropna(subset=["SPY"]) 46 | 47 | return df 48 | 49 | 50 | def plot_data(df, title="Stock prices", xlabel="Date", ylabel="Price",save_image=False,save_dir="./"): 51 | """Plot stock prices with a custom title and meaningful axis labels.""" 52 | ax = df.plot(title=title, fontsize=12) 53 | ax.set_xlabel(xlabel) 54 | ax.set_ylabel(ylabel) 55 | if not save_image: 56 | plt.show() 57 | else: 58 | plt.savefig(save_dir+title) 59 | plt.close() 60 | 61 | def get_macd(group): 62 | 63 | def moving_average(group, n=9): 64 | sma = pd.rolling_mean(group, n) 65 | return sma 66 | 67 | def moving_average_convergence(group, nslow=26, nfast=12): 68 | emaslow = pd.Series.ewm(group, span=nslow, min_periods=1).mean() 69 | emafast = pd.Series.ewm(group, span=nfast, min_periods=1).mean() 70 | result = pd.DataFrame({'MACD': emafast - emaslow, 'emaSlw': emaslow, 'emaFst': emafast}) 71 | return result 72 | 73 | return moving_average_convergence(group) 74 | 75 | 76 | def Bollinger_Bands(stock_price, window_size, num_of_std): 77 | rolling_mean = stock_price.rolling(window=window_size).mean() 78 | rolling_std = stock_price.rolling(window=window_size).std() 79 | upper_band = rolling_mean + (rolling_std * num_of_std) 80 | lower_band = rolling_mean - (rolling_std * num_of_std) 81 | 82 | return rolling_mean, upper_band, lower_band 83 | 84 | 85 | def Bollinger_Bands_given_sym_dates(sym, start_date,end_date,window_size=20, num_of_std=2): 86 | 87 | dates = pd.date_range(start_date - dt.timedelta(window_size*2-10), end_date) #TODO think better choose nan dates 88 | stock_price = get_data(sym, dates) 89 | 90 | rolling_mean, upper_band, lower_band = Bollinger_Bands(stock_price["SPY"], window_size,num_of_std) 91 | retrive_dates = pd.date_range(start_date, end_date) 92 | result = pd.DataFrame({'rolling_mean': rolling_mean, 'upper_band': upper_band, 'lower_band': lower_band},index=dates) 93 | result 
= result.dropna() 94 | return result 95 | 96 | 97 | def momentum(sym,start_date,end_date,window_size=10): 98 | dates = pd.date_range(start_date - dt.timedelta(window_size * 2), end_date) # TODO: choose the NaN padding dates more carefully 99 | stock_price = get_data(sym, dates) 100 | 101 | # momentum: today's price relative to the price window_size trading days earlier 102 | # (uses the same SPY-column convention as Bollinger_Bands_given_sym_dates above) 103 | mom = stock_price["SPY"] / stock_price["SPY"].shift(window_size) - 1.0 104 | result = pd.DataFrame({'momentum': mom}, index=dates).dropna() 105 | return result 106 | 107 | 108 | def norm(l): 109 | l = np.array(l) 110 | return (l - l.min()) / (l.max() - l.min()) 111 | 112 | def save(object,file_path): 113 | with open(file_path,"wb") as handle: 114 | pickle.dump(object,handle) 115 | 116 | def load(file_path): 117 | with open(file_path,"rb") as handle: 118 | obj = pickle.load(handle) 119 | return obj 120 | 121 | 122 | 123 | if __name__=="__main__": 124 | #plot test 125 | sym = "GOOG" 126 | stdate = dt.datetime(2007, 1, 3) 127 | enddate = dt.datetime(2007, 12, 31) 128 | syms = [sym] 129 | dates = pd.date_range(stdate, enddate) 130 | prices_all = get_data(syms, dates) # automatically adds SPY 131 | print prices_all 132 | # plot_data(prices_all) 133 | 134 | # test macd 135 | # record in panda format 136 | stdate = dt.datetime(2007, 1, 3) 137 | enddate = dt.datetime(2007, 12, 31) 138 | sym = ["GOOG"] 139 | dates = pd.date_range(stdate, enddate) 140 | prices_all = get_data(sym, dates) # automatically adds SPY 141 | print get_macd(prices_all["SPY"])["MACD"].as_matrix() 142 | 143 | # test Bollinger band #TODO retrieve the first missing window data 144 | 145 | # print Bollinger_Bands(prices_all["SPY"], 20, 2) 146 | bb = Bollinger_Bands_given_sym_dates(sym,stdate,enddate) 147 | print bb 148 | 149 | # momentum 150 | m = momentum(sym, stdate, enddate) 151 | print m --------------------------------------------------------------------------------