├── .gitignore
├── QLearner.py
├── README.md
├── StrategyLearner.py
├── screencapture-evernote-client-web-2018-11-30-19_44_58.jpg
├── testStrategy.py
├── util.py
└── yahoo_finance_data
├── GOOG.csv
└── SPY.csv
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | /data/
3 | /trained_result/
4 | /feedback/
5 | /example/
6 | /references/
7 | /training_dir/
8 | /mc3p4_qlearning_trader/
9 | /mc3p2_qlearning_robot/
10 | mc3p2_qlearning_robot/
11 | mc3p4_qlearning_trader/
12 |
--------------------------------------------------------------------------------
/QLearner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Qlearner implementation by Wenchen Li
5 |
6 | Qlearner:
7 |
8 | dyna-Q:
9 | we start with straight, regular
10 | Q-Learning here and then we add three new components.
11 | The three components are, we update models of T and R, then we hallucinate an experience
12 | and update our Q table.
13 | Now we may repeat this many times, in fact maybe hundreds of times, until we're happy.
14 | Usually, it's 1 or 200 here.
15 | Once we've completed those, we then return back up to the top and continue our interaction
16 | with the real world.
17 | The reason Dyna-Q is useful is that these experiences with the real world are potentially very
18 | expensive and these hallucinations can be very cheap. And when we iterate doing many of
19 | them, we update our Q table much more quickly.
20 | """
21 |
22 | # Copyright (C) 2017 Wenchen Li
23 | #
24 | # This program is free software: you can redistribute it and/or modify
25 | # it under the terms of the GNU Lesser General Public License as published by
26 | # the Free Software Foundation, either version 3 of the License, or
27 | # (at your option) any later version.
28 | #
29 | # This program is distributed in the hope that it will be useful,
30 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
31 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
32 | # GNU Lesser General Public License for more details.
33 | #
34 | # You should have received a copy of the GNU Lesser General Public License
 35 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
36 |
37 | import numpy as np
38 | import random as rand
39 | from util import save, load
40 |
41 |
42 | class QLearner(object):
43 | """
44 | The constructor QLearner() reserve space for keeping track of Q[s, a] for the number of states and actions.
45 | It initialize Q[] with all zeros.
46 | num_states integer, the number of states to consider
47 | num_actions integer, the number of actions available.
48 | alpha float, the learning rate used in the update rule. Should range between 0.0 and 1.0 with 0.2 as a typical value.
49 | gamma float, the discount rate used in the update rule. Should range between 0.0 and 1.0 with 0.9 as a typical value.
50 | rar float, random action rate: the probability of selecting a random action at each step. Should range between 0.0 (no random actions) to 1.0 (always random action) with 0.5 as a typical value.
51 | radr float, random action decay rate, after each update, rar = rar * radr. Ranges between 0.0 (immediate decay to 0) and 1.0 (no decay). Typically 0.99.
52 | dyna integer, conduct this number of dyna updates for each regular update. When Dyna is used, 200 is a typical value.
53 | verbose boolean, if True, your class is allowed to print debugging statements, if False, all printing is prohibited.
54 | """
55 |
56 | def __init__(self,
57 | num_states,
58 | num_actions,
59 | alpha=0.2,
60 | alpha_decay=.9999,
61 | gamma=0.9,
62 | rar=0.5,
63 | radr=0.99,
64 | dyna=0,
65 | verbose=False):
66 |
67 | self.verbose = verbose
68 | self.num_actions = num_actions
69 | self.num_states = num_states
70 | self.s = 0
71 | self.a = 0
72 | self.q_table = np.zeros((num_states, num_actions))
73 | self.rar = rar
74 | self.radr = radr
75 | self.state_rar = np.zeros(num_states)
76 | self.state_rar += rar
77 | self.alpha = alpha
78 | self.alpha_decay = alpha_decay
79 | self.gamma = gamma
80 |
81 | self.dyna = dyna
82 | self.dyna_lr = .8
83 | self.dyna_init = dyna
84 | self.t_table_increment_unit = 1.0 / 100
85 | self.t_table = np.ones(
86 | (num_actions * num_states, num_states)) * self.t_table_increment_unit # state transition,
87 | self.r_table = np.zeros((num_states, num_actions))
88 |
89 | self.last_state = rand.randint(0, self.num_states - 1)
90 | self.last_action = None
91 |
92 | def decay_alpha(self):
93 | """
94 | decay the learning rate alpha in the value iteration
95 | """
96 | self.alpha *= self.alpha_decay
97 |
98 | def querysetstate(self, s):
99 | """
100 | @summary: Update the state without updating the Q-table
101 | @detail: A special version of the query method that sets the state to s,
102 | and returns an integer action according to the same rules as query()
103 | (including choosing a random action sometimes), but it does not execute
104 | an update to the Q-table. It also does not update rar. There are two main
105 | uses for this method: 1) To set the initial state, and 2) when using a
106 | learned policy, but not updating it.
107 | @param s:int, The new state
108 | @returns:int, The selected action
109 | """
110 |
111 | # exploitation only
112 | action = np.argmax(self.q_table[s])
113 | assert action < self.num_actions
114 |
115 | if self.verbose: print "s =", s
116 | # update state and action
117 | self.last_state = s
118 | self.last_action = action
119 | return action
120 |
121 | def query(self, s_prime, r):
122 | """
123 |
124 | @summary: Update the Q table and return an action
125 | @detail: the core method of the Q-learner. It keeps track
126 | of the last state s and the last action a, then uses the new information
127 | s_prime and r to update the Q table. The learning instance, or experience
128 | tuple, is <s, a, s_prime, r>. query() returns an integer, which is
129 | the next action to take. Note that it chooses a random action with
130 | probability rar, and that it updates rar according to the decay
131 | rate radr at each step. During exploration and exploitation each state has its own
132 | rar (random action rate). Details on the arguments:
133 | @param s_prime: int, the new state.
134 | @param r: float, a real-valued immediate reward.
135 | @returns: int, the selected action
136 | """
137 |
138 | # exploration vs exploitation
139 | if rand.uniform(0, 1) <= self.state_rar[self.last_state]: # exploration
140 | action = rand.randint(0, self.num_actions - 1)
141 | self.state_rar[self.last_state] *= self.radr
142 | else: # exploitation
143 | action = np.argmax(self.q_table[s_prime])
144 | assert action < self.num_actions
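# Bellman update on the real experience <last_state, last_action, s_prime, r>:
#   Q[s, a] <- (1 - alpha) * Q[s, a] + alpha * (r + gamma * max_a' Q[s_prime, a'])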
145 | self.q_table[self.last_state][self.last_action] = (1.0 - self.alpha) * self.q_table[self.last_state][
146 | self.last_action] + self.alpha * (
147 | r + self.gamma * self.q_table[s_prime][np.argmax(self.q_table[s_prime])]) # bellman
148 |
149 | if self.verbose: print "s =", s_prime, "a =", action, "r =", r
150 |
151 | # update state and action in Qlearn
152 | self.last_state = s_prime
153 | self.last_action = action
154 |
155 | # Dyna (combine model-free Q-learning with a learned model of T and R)
156 | ## dyna: update the T and R tables, then hallucinate experiences
157 | while self.dyna > 0:
158 | # update T
159 | transition_index = self.num_states * self.last_action + self.last_state
160 | self.t_table[transition_index][s_prime] += self.t_table_increment_unit
161 |
162 | # update R
163 | self.r_table[self.last_state][self.last_action] = (1 - self.dyna_lr) * self.r_table[self.last_state][
164 | self.last_action] + self.dyna_lr * r
165 |
166 | ## dyna hallucinate
167 | dyna_state = rand.randint(0, self.num_states - 1)
168 | dyna_action = rand.randint(0, self.num_actions - 1)
169 | transition_index = self.num_states * dyna_action + dyna_state
170 | transition_prob = self.t_table[transition_index] / np.sum(self.t_table[transition_index])
171 | state_infer_from_t_table = np.random.choice(range(self.num_states), 1, p=transition_prob)[0]
172 | dyna_r = self.r_table[dyna_state][dyna_action]
173 |
174 | ## dyna update Q table
175 | self.q_table[dyna_state][dyna_action] = (1.0 - self.alpha) * self.q_table[dyna_state][
176 | dyna_action] + self.alpha * (dyna_r + self.gamma * self.q_table[state_infer_from_t_table][
177 | np.argmax(self.q_table[state_infer_from_t_table])])
178 | self.dyna -= 1
179 | # end of dyna
180 | self.dyna = self.dyna_init
181 |
182 | return action
183 |
184 | def save_model(self, table_name="q_learner_tables.pkl"):
185 | """
186 | save trained q learner, aka the q table (and t table and r table if dyna is included)
187 | :param table_name:saved table name
188 | """
189 | tables = [self.q_table, self.t_table, self.r_table]
190 | save(tables, table_name)
191 |
192 | def load_model(self, table_name="q_learner_tables.pkl"):
193 | """
194 | load trained q learner, aka the q table (and t table and r table if dyna is included)
195 | :param table_name:saved table name
196 | """
197 | [self.q_table, self.t_table, self.r_table] = load(table_name)
198 |
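if __name__ == "__main__":
    # Minimal usage sketch (a smoke test on a toy problem, not part of the trader):
    # a 10-state corridor where action 1 moves right, action 0 moves left, and
    # reaching state 9 pays a reward of 1.0. After training, the greedy action
    # from state 0 should usually be 1 (move right).
    learner = QLearner(num_states=10, num_actions=2, rar=0.5, radr=0.99, dyna=0, verbose=False)
    for episode in xrange(500):
        s = 0
        a = learner.querysetstate(s)
        for step in xrange(100):
            s = min(s + 1, 9) if a == 1 else max(s - 1, 0)
            r = 1.0 if s == 9 else -0.01
            a = learner.query(s, r)
            if s == 9:
                break
    print "greedy action from state 0:", np.argmax(learner.q_table[0])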
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Q-learning trader
2 | A technical Q-learning trader on SPY, using MACD as the input feature.
3 |
4 | ## run
5 | `python testStrategy.py`
6 |
7 | ## structure
 8 | QLearner.py: an independent tabular (Dyna-)Q-learner.
 9 |
10 | StrategyLearner.py: builds upon QLearner.py to learn the trading strategy.
11 |
12 | testStrategy.py: trains and tests the StrategyLearner.
13 |
14 | util.py: helper functions for data loading, technical indicators, and plotting.
15 |
16 |
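## example
A minimal programmatic sketch of how the learner is driven (this mirrors what `testStrategy.py` does; the dates and starting value are the defaults used there, and the `save_dir` name is arbitrary):

```python
import datetime as dt
import StrategyLearner as sl

learner = sl.StrategyLearner(verbose=False, save_dir="demo")

# train on one period ...
train_return = learner.addEvidence(symbol="SPY", sd=dt.datetime(2006, 4, 30),
                                   ed=dt.datetime(2007, 1, 3), sv=10000)

# ... then test the learned policy on the following period
trades, test_return = learner.testPolicy(symbol="SPY", sd=dt.datetime(2007, 1, 3),
                                         ed=dt.datetime(2007, 12, 31), sv=10000)
```
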
17 | ## experiment result
18 |
19 | Please see [this link](https://www.evernote.com/shard/s120/sh/00a11079-5e8d-4243-87e5-5daaf8565836/ebafa9649081c29dc98d27cb7c3a8a18)
20 |
21 | or the Evernote screenshot below:
22 | ![experiment result](screencapture-evernote-client-web-2018-11-30-19_44_58.jpg)
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/StrategyLearner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
  4 | Q-trader strategy learner. States are currently discretized with k-means. Implementation 2017 Wenchen Li.
5 | """
6 | # Copyright (C) 2017 Wenchen Li
7 | #
8 | # This program is free software: you can redistribute it and/or modify
9 | # it under the terms of the GNU Lesser General Public License as published by
10 | # the Free Software Foundation, either version 3 of the License, or
11 | # (at your option) any later version.
12 | #
13 | # This program is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU Lesser General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU Lesser General Public License
 19 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
20 |
21 | import os
22 | import datetime as dt
23 | import QLearner as ql
24 | import pandas as pd
25 | import util as ut
26 | from sklearn.cluster import KMeans
27 | from sklearn.neighbors import KernelDensity
28 | import pickle
29 | import numpy as np
30 |
31 | CASH = "Cash"
32 | STOCK = "Stock"
33 |
34 | TRAIN_DIR = "./training_dir/"
35 | kmeans_model_save_name = 'kmeans_model.pkl'
36 | dyna_q_trader_model_save_name = "q_learner_tables.pkl"
37 | training_record_save_name = "records.pkl"
38 |
39 |
40 | class RawTradeFeatures(object):
41 | """
 42 | get raw trade features such as the adjusted close price and trading volume
43 | """
44 |
45 | def __init__(self, symbol, sd, ed, ):
46 | self.syms = [symbol]
47 | self.dates = pd.date_range(sd, ed)
48 |
49 | def get_adj_price(self):
50 | prices_all = ut.get_data(self.syms, self.dates) # automatically adds SPY
51 | prices = prices_all[self.syms] # only portfolio symbols
52 | prices_SPY = prices_all['SPY'] # only SPY, for comparison later
53 |
54 | return prices, prices_SPY
55 |
56 | def get_vol(self):
57 | ## example use with new colname
58 | volume_all = ut.get_data(self.syms, self.dates, colname="Volume") # automatically adds SPY
59 | volume = volume_all[self.syms] # only portfolio symbols
60 | volume_SPY = volume_all['SPY'] # only SPY, for comparison later
61 |
62 | return volume, volume_SPY
63 |
64 |
65 | class StrategyLearner(object):
66 | """
67 | For the policy learning part:
68 |
69 | Select several technical features, and compute their values for the training data
70 | Discretize the values of the features
71 | Instantiate a Q-learner
72 | For each day in the training data:
73 | Compute the current state (including holding)
74 | Compute the reward for the last action
75 | Query the learner with the current state and reward to get an action
76 | Implement the action the learner returned (BUY, SELL, NOTHING), and update portfolio value
77 | """
78 |
79 | # constructor
80 | def __init__(self, verbose=True, save_dir=""):
81 | self.verbose = verbose
82 | self.Nothing = 0
83 | self.Buy = 1
84 | self.Sell = 2
 85 | self.num_holding_state = 3  # 0 for flat, 1 for long, 2 for short
86 | self.num_feature_state = 100
87 | self.num_state = self.num_holding_state * self.num_feature_state
88 | self.num_action = 3
89 | self.shares_to_buy_or_sell = 100
 90 | self.current_holding_state = 0  # holding state: 0 flat, 1 long, 2 short
91 | self.last_r = 0.0
92 | self.portfolio = {}
93 | self.num_epoch = 1000
94 | self.stop_threshold = 1.0
95 | self.epsilon = .01
 96 | self.learner_dyna_iter = 0  # number of Dyna updates per step (e.g. 200)
97 | # keep records of each epoch and within each epoch the transaction
98 | self.records = [] # element for each epoch is (current_state,action,value, reward)
99 |
100 | self.negative_return_punish_factor = 1.3
101 | # save paths
102 | self.save_dir = save_dir
103 | self.current_working_dir = TRAIN_DIR + save_dir + "/"
104 | if not os.path.exists(self.current_working_dir):
105 | os.makedirs(self.current_working_dir)
106 |
107 | def get_current_portfolio_values(self, today_stock_price):
108 | return self.portfolio[CASH] + self.portfolio[STOCK] * today_stock_price
109 |
110 | def init_portfolio(self, sv,):
111 | self.portfolio[CASH] = float(sv)
112 | self.portfolio[STOCK] = 0
113 |
114 | def get_raw_data(self, symbol, sd, ed):
115 | # record in panda format
116 | dates = pd.date_range(sd, ed)
117 | prices_all = ut.get_data([symbol], dates) # automatically adds SPY
118 | trades = prices_all[[symbol, ]] # only portfolio symbols
119 | portfolio_value = prices_all[[symbol, ]]
120 |
121 | # Select several technical features, and compute their values for the training data
122 |
123 | tf = RawTradeFeatures(symbol, sd, ed)
124 | price, price_SPY = tf.get_adj_price()
125 |
126 | return dates,prices_all, trades, portfolio_value, tf, price, price_SPY
127 |
128 | def get_benchmark(self,symbol, price_SPY, prices_all,sv ):
129 | """
130 | get buy and hold benchmark result
131 | """
132 | # buy and hold the benchmark
133 | benchmark_price = price_SPY.as_matrix()
134 | benchmark_values = prices_all[[symbol, ]] # only portfolio symbols
135 | benchmark_num_stock = int(sv / benchmark_price[0])
136 | cash = sv - float(benchmark_num_stock * benchmark_price[0])
137 | for i, p in enumerate(benchmark_price):
138 | benchmark_values.values[i, :] = cash + p * benchmark_num_stock
139 |
140 | return benchmark_values
141 |
142 | def get_derived_data(self,symbol, prices_all,sd,ed):
143 | """
144 | get derived data macd and bollinger band given the raw data
145 | """
146 | # macd
147 | macd = ut.norm(ut.get_macd(prices_all[symbol])["MACD"].as_matrix())
148 |
149 | # bollinger band
150 | bb = ut.Bollinger_Bands_given_sym_dates([symbol], sd, ed)
151 | bb_rm = ut.norm(bb['rolling_mean'].as_matrix())
152 | bb_ub = ut.norm(bb['upper_band'].as_matrix())
153 | bb_lb = ut.norm(bb['lower_band'].as_matrix())
154 |
155 | return macd, bb,bb_rm, bb_ub, bb_lb
156 |
157 | def get_finalized_input(self,price, symbol,macd):
158 | """
159 | reformat the finalized inputs, price and macd, for the model
160 | """
161 | # input to model or later process
162 | l = price[symbol].as_matrix()
163 | price_array = l.copy()
164 |
165 | x = macd.reshape((-1, 1))
166 |
167 | return price_array, x
168 |
169 | def perform_actions_update_trades_value(self,action,price_array,trades,i):
170 | """
171 | given an action, perform it and record the resulting trade value.
172 | """
173 | if action == 0: # do nothing
174 | if self.verbose: print "do nothing"
175 | elif action == 1: # buy
176 | if self.current_holding_state == 0 or self.current_holding_state == 2: # holding nothing or short
177 | self.portfolio[CASH] -= self.shares_to_buy_or_sell * price_array[i]
178 | self.portfolio[STOCK] += self.shares_to_buy_or_sell
179 | elif self.current_holding_state == 1: # long
180 | if self.verbose: print "buy but long already, nothing to do"
181 |
182 | else: # action sell
183 | if self.current_holding_state == 0 or self.current_holding_state == 1: # holding nothing or long
184 | self.portfolio[CASH] += self.shares_to_buy_or_sell * price_array[i]
185 | self.portfolio[STOCK] -= self.shares_to_buy_or_sell
186 | elif self.current_holding_state == 2: # short
187 | if self.verbose: print "sell but short already, nothing to do"
188 |
189 | assert np.abs(self.portfolio[STOCK]) <= self.shares_to_buy_or_sell
190 | # update self.holding state
191 | if self.portfolio[STOCK] == self.shares_to_buy_or_sell:
192 | self.current_holding_state = 1
193 | elif self.portfolio[STOCK] == -self.shares_to_buy_or_sell:
194 | self.current_holding_state = 2
195 | elif self.portfolio[STOCK] == 0:
196 | self.current_holding_state = 0
197 | else:
198 | if self.verbose: print self.portfolio, "current portfolio is not valid"
199 |
200 | trades.values[i, :] = self.shares_to_buy_or_sell
201 | if action == 0:
202 | trades.values[i, :] *= 0
203 | elif action == 1:
204 | trades.values[i, :] *= 1
205 | else:
206 | trades.values[i, :] *= -1
207 |
208 | return trades
209 |
210 | def save_plot_and_model(self,benchmark_values, portfolio_value,trades,save_plot=False):
211 | """
212 | save trained model and plot result if save_plot set to True
213 | """
214 | # save transaction and portfolio image and training records
215 | benchmark_values = benchmark_values.rename(columns={'SPY': "benchmark"})
216 | portfolio_value = portfolio_value.rename(columns={'SPY': "q-learn-trader"})
217 | if save_plot:
218 | p_value_all = portfolio_value.join(benchmark_values)
219 | ut.plot_data(trades, title="transactions_train", ylabel="amount", save_image=True,
220 | save_dir=self.current_working_dir)
221 | ut.plot_data(p_value_all, title="portfolio value_train", ylabel="USD", save_image=True,
222 | save_dir=self.current_working_dir)
223 | self.learner.save_model(table_name=self.current_working_dir + dyna_q_trader_model_save_name)
224 |
225 | def addEvidence(self, symbol="SPY",
226 | sd=dt.datetime(2008, 1, 1),
227 | ed=dt.datetime(2009, 1, 1),
228 | sv=10000):
229 | """
230 | train q learner
231 | :param symbol:security symbol
232 | :param sd: start date
233 | :param ed: end date
234 | :param sv: starting fund value
235 | :return: return of the trade
236 | """
237 |
238 | self.init_portfolio(sv)
239 |
240 | dates, prices_all, trades, portfolio_value, tf, price, price_SPY = self.get_raw_data(symbol, sd, ed)
241 |
242 | benchmark_values = self.get_benchmark( symbol, price_SPY, prices_all, sv)
243 |
244 | macd, bb, bb_rm, bb_ub, bb_lb = self.get_derived_data( symbol, prices_all, sd, ed)
245 |
246 | price_array, x = self.get_finalized_input( price, symbol, macd)
247 |
248 | # discretize
249 | kmeans_model = KMeans(n_clusters=self.num_feature_state, random_state=0, )
250 | kmeans = kmeans_model.fit(x)
251 | pickle.dump(kmeans_model, open(self.current_working_dir + kmeans_model_save_name, 'wb'))
252 | feature_states = kmeans.labels_
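# feature_states[i] is the k-means cluster id (0 .. num_feature_state - 1)
# assigned to day i's normalized MACD value; it serves as the discretized feature state.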
253 |
254 | # Instantiate a Q-learner
255 | self.learner = ql.QLearner(num_states=self.num_state, num_actions=self.num_action, rar=.5, alpha=.001,
256 | alpha_decay=.99, dyna=self.learner_dyna_iter)
257 | self.last_cumulated_return = 0.0
258 | self.cumulated_return_indicator = 0
259 |
260 | # value iteration
261 | for k in xrange(self.num_epoch):
262 | for i, s in enumerate(feature_states):
263 | if i == len(feature_states) - 1:
264 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[-1])
265 | continue  # skip the last day: computing the reward needs the next day's price
266 | # compute the current state(include holding)
267 | current_holding_state = self.current_holding_state
268 | current_feature_state = s
269 | current_state = self.num_feature_state * current_holding_state + current_feature_state
270 |
271 | # compute the reward for the last action
272 | r = self.last_r
273 | if r < 0: # punish negative reward
274 | r *= self.negative_return_punish_factor
275 | # Query the learner with the current state and reward to get an action
276 | action = self.learner.query(current_state, r)
277 |
278 | # Implement the action the learner returned (BUY, SELL, NOTHING), and update portfolio value
279 | last_portfolio_value = self.get_current_portfolio_values(price_array[i]) # sum(self.portfolio_values)
280 |
281 | trades = self.perform_actions_update_trades_value(action, price_array, trades, i)
282 |
283 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[i])
284 | self.last_r = (self.get_current_portfolio_values(price_array[i + 1]) - last_portfolio_value) / float(sv)
285 | if self.verbose: print self.last_r
286 |
287 | if self.verbose: print "epoch", k, " current cumulated return:", self.get_current_portfolio_values(
288 | price_array[-1]) / float(sv) - 1.0, "portfolio:", self.portfolio
289 | self.cumulated_return_indicator = self.last_cumulated_return / (
290 | self.get_current_portfolio_values(price_array[-1]) / float(sv))
291 | self.last_cumulated_return = self.get_current_portfolio_values(price_array[-1]) / float(sv) - 1
292 |
293 | if self.num_epoch - 1 != k:
294 | # reset portfolio
295 | self.portfolio[CASH] = sv
296 | self.portfolio[STOCK] = 0
297 | self.current_holding_state = 0
298 | self.last_r = 0.0
299 | # decay alpha
300 | self.learner.decay_alpha()
301 |
302 | self.save_plot_and_model(benchmark_values, portfolio_value, trades)
303 |
304 | trade_return = self.get_current_portfolio_values(price_array[-1]) / sv - 1.0
305 | return trade_return
306 |
307 | # this method should use the existing policy and test it against new data
308 |
309 |
310 | def testPolicy(self, symbol="IBM",
311 | sd=dt.datetime(2009, 1, 1),
312 | ed=dt.datetime(2010, 1, 1),
313 | sv=10000):
314 | """
315 | test q learner
316 | :param symbol:security symbol
317 | :param sd: start date
318 | :param ed: end date
319 | :param sv: starting fund value
320 | :return: trades: record of the trades, trade_return: return of the trade
321 | """
322 | self.init_portfolio(sv)
323 |
324 | dates, prices_all, trades, portfolio_value, tf, price, price_SPY = self.get_raw_data(symbol, sd, ed)
325 |
326 | benchmark_values = self.get_benchmark(symbol, price_SPY, prices_all, sv)
327 |
328 | macd, bb, bb_rm, bb_ub, bb_lb = self.get_derived_data( symbol, prices_all, sd, ed)
329 |
330 | price_array, x = self.get_finalized_input( price, symbol, macd)
331 |
332 | # kmeans model load
333 | kmeans_model = pickle.load(open(self.current_working_dir + kmeans_model_save_name, 'rb'))
334 | kmeans = kmeans_model.predict(x)
335 | feature_states = kmeans
336 |
337 | # Instantiate a Q-learner
338 | self.learner = ql.QLearner(num_states=self.num_state, num_actions=self.num_action)
339 | self.last_cumulated_return = 0.0
340 | self.cumulated_return_indicator = 0
341 | # load trained q table(if dyna,t_table and r_table)
342 | self.learner.load_model(table_name=self.current_working_dir + dyna_q_trader_model_save_name)
343 |
344 | # value iteration
345 | for i, s in enumerate(feature_states):
346 | if i == len(feature_states) - 1:
347 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[-1])
348 | continue  # skip the last day: computing the reward needs the next day's price
349 |
350 | # compute the current state(include holding)
351 | current_holding_state = self.current_holding_state
352 | current_feature_state = s
353 | current_state = self.num_feature_state * current_holding_state + current_feature_state
354 |
355 | # Query the learner with the current state
356 | action = self.learner.querysetstate(current_state)
357 |
358 | # Implement the action the learner returned (BUY, SELL, NOTHING), and update portfolio value
359 | last_portfolio_value = self.get_current_portfolio_values(price_array[i]) # sum(self.portfolio_values)
360 |
361 | trades = self.perform_actions_update_trades_value(action, price_array, trades, i)
362 |
363 | portfolio_value.values[i, :] = self.get_current_portfolio_values(price_array[i])
364 |
365 | self.last_r = (self.get_current_portfolio_values(price_array[i + 1]) - last_portfolio_value) / float(sv)
366 | if self.verbose: print self.last_r
367 |
368 | self.last_cumulated_return = self.get_current_portfolio_values(price_array[-1]) / float(sv) - 1
369 |
370 | self.save_plot_and_model(benchmark_values, portfolio_value, trades)
371 |
372 | trade_return = self.get_current_portfolio_values(price_array[-1]) / sv - 1.0
373 | if self.verbose: print "cumulated return=:", trade_return
374 | return trades, trade_return
375 |
--------------------------------------------------------------------------------
/screencapture-evernote-client-web-2018-11-30-19_44_58.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenchenLi/q-learning-trader/7f34453567914d2c980f540f030c228b50f574f4/screencapture-evernote-client-web-2018-11-30-19_44_58.jpg
--------------------------------------------------------------------------------
/testStrategy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Test a Strategy Learner. (c) 2017 Wenchen Li
5 | """
6 | # Copyright (C) 2017 Wenchen Li
7 | #
8 | # This program is free software: you can redistribute it and/or modify
9 | # it under the terms of the GNU Lesser General Public License as published by
10 | # the Free Software Foundation, either version 3 of the License, or
11 | # (at your option) any later version.
12 | #
13 | # This program is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU Lesser General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU Lesser General Public License
 19 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
20 |
21 | import pandas as pd
22 | import datetime as dt
23 | import util as ut
24 | import StrategyLearner as sl
25 |
26 | def test_code(train_q=None,out_q=None,verb = True, test_only=False,save_dir=".",):
27 |
28 | # instantiate the strategy learner
29 | learner = sl.StrategyLearner(verbose = verb,save_dir=save_dir)
30 |
31 | if not test_only:
32 | # set parameters for training the learner
33 | sym = "SPY"
34 | stdate =dt.datetime(2006,4,30)
35 | enddate =dt.datetime(2007,1,3) # just a few days for "shake out"
36 |
37 | # train the learner
38 | trade_return_train = learner.addEvidence(symbol = sym, sd = stdate,
39 | ed = enddate, sv = 10000)
40 | if train_q:
41 | train_q.put(trade_return_train)
42 | # set parameters for testing
43 | sym = "SPY"
44 | stdate =dt.datetime(2007,1,3)
45 | enddate =dt.datetime(2007,12,31)
46 |
47 | # get some data for reference
48 | syms=[sym]
49 | dates = pd.date_range(stdate, enddate)
50 | prices_all = ut.get_data(syms, dates) # automatically adds SPY
51 | prices = prices_all[syms] # only portfolio symbols
52 | # if verb: print prices
53 |
54 | # test the learner
55 | df_trades,trade_return = learner.testPolicy(symbol = sym, sd = stdate, \
56 | ed = enddate, sv = 10000)
57 | if out_q:
58 | out_q.put(trade_return)
59 | # a few sanity checks
60 | # df_trades should be a single column DataFrame (not a series)
 61 | # including only the values 100, 0, -100 (the learner trades 100 shares at a time)
 62 | if not isinstance(df_trades, pd.DataFrame):
63 | print "Returned result is not a DataFrame"
64 | if prices.shape != df_trades.shape:
65 | print "Returned result is not the right shape"
66 |
67 | if __name__=="__main__":
68 | from multiprocessing import Process,Queue
69 | import numpy as np
70 |
71 | train_q = Queue()
72 | out_q = Queue()
73 | train_resultdict = {}
74 | test_resultdict = {}
75 | total_num_simulation_left = 1000
76 |
77 | while total_num_simulation_left>0:
78 |
79 | nprocs_each_iter = 10
80 | procs = []
81 | current_training_ids = []
82 | for i in xrange(nprocs_each_iter):
83 | training_id = total_num_simulation_left - i
84 | current_training_ids.append(training_id)
85 | proc = Process(target=test_code,args=(train_q,out_q,False,False,str(training_id),))
86 | procs.append(proc)
87 | proc.start()
88 |
89 | for training_id in current_training_ids:
90 | train_resultdict[training_id] = train_q.get()
91 | test_resultdict[training_id]= out_q.get()
92 |
93 | for p in procs:
94 | p.join()
95 |
96 | total_num_simulation_left -= nprocs_each_iter
97 |
98 | print "train:"
99 | print "each strategy return:", train_resultdict.values()
100 | print "mean return:", np.average(train_resultdict.values())
101 | print "std return:", np.std(train_resultdict.values())
102 |
103 | mean = np.average(test_resultdict.values())
104 | std = np.std(test_resultdict.values())
105 | yearly_risk_free_rate = .05 # https://www.treasury.gov/resource-center/data-chart-center/interest-rates/Pages/TextView.aspx?data=yieldYear&year=2007
106 | print "test:"
107 | print "each strategy return:",test_resultdict.values()
108 | print "mean return:",mean
109 | print "std return:",std
110 | print "max, min return:",max(test_resultdict.values()), min(test_resultdict.values())
111 | print "sharpe ratio:", (mean - yearly_risk_free_rate)/std
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """ Utility code."""
4 |
5 | # Copyright (C) 2017 Wenchen Li
6 | #
7 | # This program is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU Lesser General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # This program is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU Lesser General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU Lesser General Public License
 18 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | import os
21 | import pandas as pd
22 | import matplotlib.pyplot as plt
23 | import pickle
24 | import datetime as dt
25 | import numpy as np
26 |
27 |
28 | def symbol_to_path(symbol, base_dir=os.path.join(".", "yahoo_finance_data")):
29 | """Return CSV file path given ticker symbol."""
30 | return os.path.join(base_dir, "{}.csv".format(str(symbol)))
31 |
32 |
33 | def get_data(symbols, dates, addSPY=True, colname = 'Adj Close'):
34 | """Read stock data (adjusted close) for given symbols from CSV files."""
35 | df = pd.DataFrame(index=dates)
36 | if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
37 | symbols = ['SPY'] + symbols
38 |
39 | for symbol in symbols:
40 | df_temp = pd.read_csv(symbol_to_path(symbol), index_col='Date',
41 | parse_dates=True, usecols=['Date', colname], na_values=['nan'])
42 | df_temp = df_temp.rename(columns={colname: symbol})
43 | df = df.join(df_temp)
44 | if symbol == 'SPY': # drop dates SPY did not trade
45 | df = df.dropna(subset=["SPY"])
46 |
47 | return df
48 |
49 |
50 | def plot_data(df, title="Stock prices", xlabel="Date", ylabel="Price",save_image=False,save_dir="./"):
51 | """Plot stock prices with a custom title and meaningful axis labels."""
52 | ax = df.plot(title=title, fontsize=12)
53 | ax.set_xlabel(xlabel)
54 | ax.set_ylabel(ylabel)
55 | if not save_image:
56 | plt.show()
57 | else:
58 | plt.savefig(save_dir+title)
59 | plt.close()
60 |
61 | def get_macd(group):
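# MACD: the difference between a fast (12-day) and a slow (26-day) exponential
# moving average of the price series. moving_average() below is a helper that is
# currently unused by moving_average_convergence().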
62 |
63 | def moving_average(group, n=9):
 64 | sma = group.rolling(n).mean()
65 | return sma
66 |
67 | def moving_average_convergence(group, nslow=26, nfast=12):
68 | emaslow = pd.Series.ewm(group, span=nslow, min_periods=1).mean()
69 | emafast = pd.Series.ewm(group, span=nfast, min_periods=1).mean()
70 | result = pd.DataFrame({'MACD': emafast - emaslow, 'emaSlw': emaslow, 'emaFst': emafast})
71 | return result
72 |
73 | return moving_average_convergence(group)
74 |
75 |
76 | def Bollinger_Bands(stock_price, window_size, num_of_std):
77 | rolling_mean = stock_price.rolling(window=window_size).mean()
78 | rolling_std = stock_price.rolling(window=window_size).std()
79 | upper_band = rolling_mean + (rolling_std * num_of_std)
80 | lower_band = rolling_mean - (rolling_std * num_of_std)
81 |
82 | return rolling_mean, upper_band, lower_band
83 |
84 |
85 | def Bollinger_Bands_given_sym_dates(sym, start_date,end_date,window_size=20, num_of_std=2):
86 |
87 | dates = pd.date_range(start_date - dt.timedelta(window_size*2-10), end_date) #TODO think better choose nan dates
88 | stock_price = get_data(sym, dates)
89 |
90 | rolling_mean, upper_band, lower_band = Bollinger_Bands(stock_price["SPY"], window_size,num_of_std)
91 | retrive_dates = pd.date_range(start_date, end_date)
92 | result = pd.DataFrame({'rolling_mean': rolling_mean, 'upper_band': upper_band, 'lower_band': lower_band},index=dates)
93 | result = result.dropna()
94 | return result
95 |
96 |
97 | def momentum(sym,start_date,end_date,window_size=10):
98 | dates = pd.date_range(start_date - dt.timedelta(window_size), end_date) # TODO think better choose nan dates
99 | stock_price = get_data(sym, dates)
100 |
101 | # momentum over trading days:
102 | #   momentum[t] = price[t] / price[t - window_size] - 1
103 | result = stock_price / stock_price.shift(window_size) - 1.0
104 | result = result.dropna()
105 | return result
106 |
107 |
108 | def norm(l):
109 | l = np.array(l)
110 | return (l - l.min()) / (l.max() - l.min())
111 |
112 | def save(obj, file_path):
113 | with open(file_path, "wb") as handle:
114 | pickle.dump(obj, handle)
115 |
116 | def load(file_path):
117 | with open(file_path,"rb") as handle:
118 | obj = pickle.load(handle)
119 | return obj
120 |
121 |
122 |
123 | if __name__=="__main__":
124 | #plot test
125 | sym = "GOOG"
126 | stdate = dt.datetime(2007, 1, 3)
127 | enddate = dt.datetime(2007, 12, 31)
128 | syms = [sym]
129 | dates = pd.date_range(stdate, enddate)
130 | prices_all = get_data(syms, dates) # automatically adds SPY
131 | print prices_all
132 | # plot_data(prices_all)
133 |
134 | # test macd
135 | # record in panda format
136 | stdate = dt.datetime(2007, 1, 3)
137 | enddate = dt.datetime(2007, 12, 31)
138 | sym = ["GOOG"]
139 | dates = pd.date_range(stdate, enddate)
140 | prices_all = get_data(sym, dates) # automatically adds SPY
141 | print get_macd(prices_all["SPY"])["MACD"].as_matrix()
142 |
143 | # test Bollinger band #TODO retrieve the first missing window data
144 |
145 | # print Bollinger_Bands(prices_all["SPY"], 20, 2)
146 | bb = Bollinger_Bands_given_sym_dates(sym,stdate,enddate)
147 | print bb
148 |
149 | # momentum
150 | m = momentum(sym, stdate, enddate)
151 | print m
--------------------------------------------------------------------------------