├── README.md ├── agent └── agent.py ├── data ├── ^GSPC.csv └── ^GSPC_2011.csv ├── evaluate.py ├── functions.py ├── images ├── AAPL_2015.png ├── AAPL_2016.png ├── BABA_2014.png ├── BABA_2015.png ├── GOOG_8_2017.png ├── ^GSPC_2014.png └── ^GSPC_2015.png └── train.py /README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | This is the code for [this](https://www.youtube.com/watch?v=05NqKJ0v7EE) video on YouTube by Siraj Raval. The author of this code is [edwardhdlu](https://github.com/edwardhdlu/q-trader). It is an implementation of Q-learning applied to (short-term) stock trading. The model uses n-day windows of closing prices to determine whether the best action to take at a given time is to buy, sell or sit. 4 | 5 | As a result of the short-term state representation, the model is not very good at making decisions over long-term trends, but is quite good at predicting peaks and troughs. 6 | 7 | ## Results 8 | 9 | Some examples of results on test sets: 10 | 11 | ![^GSPC 2015](https://github.com/edwardhdlu/q-trader/blob/master/images/^GSPC_2015.png) 12 | S&P 500, 2015. Profit of $431.04. 13 | 14 | ![BABA_2015](https://github.com/edwardhdlu/q-trader/blob/master/images/BABA_2015.png) 15 | Alibaba Group Holding Ltd, 2015. Loss of $351.59. 16 | 17 | ![AAPL 2016](https://github.com/edwardhdlu/q-trader/blob/master/images/AAPL_2016.png) 18 | Apple, Inc, 2016. Profit of $162.73. 19 | 20 | ![GOOG_8_2017](https://github.com/edwardhdlu/q-trader/blob/master/images/GOOG_8_2017.png) 21 | Google, Inc, August 2017. Profit of $19.37. 22 | 23 | ## Running the Code 24 | 25 | To train the model, download training and test CSV files from [Yahoo!
Finance](https://ca.finance.yahoo.com/quote/%5EGSPC/history?p=%5EGSPC) into `data/`, then run: 26 | ``` 27 | mkdir models 28 | python train.py ^GSPC 10 1000 29 | ``` 30 | 31 | Once training finishes (a minimum of 200 episodes is needed for results), evaluate the model: 32 | ``` 33 | python evaluate.py ^GSPC_2011 model_ep1000 34 | ``` 35 | 36 | ## References 37 | 38 | [Deep Q-Learning with Keras and Gym](https://keon.io/deep-q-learning/) - Q-learning overview and Agent skeleton code 39 | -------------------------------------------------------------------------------- /agent/agent.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import Sequential 3 | from keras.models import load_model 4 | from keras.layers import Dense 5 | from keras.optimizers import Adam 6 | 7 | import numpy as np 8 | import random 9 | from collections import deque 10 | 11 | class Agent: 12 | def __init__(self, state_size, is_eval=False, model_name=""): 13 | self.state_size = state_size # normalized previous days 14 | self.action_size = 3 # sit, buy, sell 15 | self.memory = deque(maxlen=1000) 16 | self.inventory = [] 17 | self.model_name = model_name 18 | self.is_eval = is_eval 19 | 20 | self.gamma = 0.95 21 | self.epsilon = 1.0 22 | self.epsilon_min = 0.01 23 | self.epsilon_decay = 0.995 24 | 25 | self.model = load_model("models/" + model_name) if is_eval else self._model() 26 | 27 | def _model(self): 28 | model = Sequential() 29 | model.add(Dense(units=64, input_dim=self.state_size, activation="relu")) 30 | model.add(Dense(units=32, activation="relu")) 31 | model.add(Dense(units=8, activation="relu")) 32 | model.add(Dense(self.action_size, activation="linear")) 33 | model.compile(loss="mse", optimizer=Adam(lr=0.001)) 34 | 35 | return model 36 | 37 | def act(self, state): 38 | if not self.is_eval and random.random() <= self.epsilon: 39 | return random.randrange(self.action_size) 40 | 41 | options = self.model.predict(state) 42 | return np.argmax(options[0]) 43 | 44 | 
def expReplay(self, batch_size): 45 | mini_batch = [] 46 | l = len(self.memory) 47 | for i in range(l - batch_size, l): # the batch_size most recent experiences 48 | mini_batch.append(self.memory[i]) 49 | 50 | for state, action, reward, next_state, done in mini_batch: 51 | target = reward 52 | if not done: 53 | target = reward + self.gamma * np.amax(self.model.predict(next_state)[0]) 54 | 55 | target_f = self.model.predict(state) 56 | target_f[0][action] = target 57 | self.model.fit(state, target_f, epochs=1, verbose=0) 58 | 59 | if self.epsilon > self.epsilon_min: 60 | self.epsilon *= self.epsilon_decay 61 | -------------------------------------------------------------------------------- /data/^GSPC_2011.csv: -------------------------------------------------------------------------------- 1 | Date,Open,High,Low,Close,Adj Close,Volume 2 | 2011-01-03,1257.619995,1276.170044,1257.619995,1271.869995,1271.869995,4286670000 3 | 2011-01-04,1272.949951,1274.119995,1262.660034,1270.199951,1270.199951,4796420000 4 | 2011-01-05,1268.780029,1277.630005,1265.359985,1276.560059,1276.560059,4764920000 5 | 2011-01-06,1276.290039,1278.170044,1270.430054,1273.849976,1273.849976,4844100000 6 | 2011-01-07,1274.410034,1276.829956,1261.699951,1271.500000,1271.500000,4963110000 7 | 2011-01-10,1270.839966,1271.520020,1262.180054,1269.750000,1269.750000,4036450000 8 | 2011-01-11,1272.579956,1277.250000,1269.619995,1274.479980,1274.479980,4050750000 9 | 2011-01-12,1275.650024,1286.869995,1275.650024,1285.959961,1285.959961,4226940000 10 | 2011-01-13,1285.780029,1286.699951,1280.469971,1283.760010,1283.760010,4310840000 11 | 2011-01-14,1282.900024,1293.239990,1281.239990,1293.239990,1293.239990,4661590000 12 | 2011-01-18,1293.219971,1296.060059,1290.160034,1295.020020,1295.020020,5284990000 13 | 2011-01-19,1294.520020,1294.599976,1278.920044,1281.920044,1281.920044,4743710000 14 | 2011-01-20,1280.849976,1283.349976,1271.260010,1280.260010,1280.260010,4935320000 15 | 
2011-01-21,1283.630005,1291.209961,1282.069946,1283.349976,1283.349976,4935320000 16 | 2011-01-24,1283.290039,1291.930054,1282.469971,1290.839966,1290.839966,3902470000 17 | 2011-01-25,1288.170044,1291.260010,1281.069946,1291.180054,1291.180054,4595380000 18 | 2011-01-26,1291.969971,1299.739990,1291.969971,1296.630005,1296.630005,4730980000 19 | 2011-01-27,1297.510010,1301.290039,1294.410034,1299.540039,1299.540039,4309190000 20 | 2011-01-28,1299.630005,1302.670044,1275.099976,1276.339966,1276.339966,5618630000 21 | 2011-01-31,1276.500000,1287.170044,1276.500000,1286.119995,1286.119995,4167160000 22 | 2011-02-01,1289.140015,1308.859985,1289.140015,1307.589966,1307.589966,5164500000 23 | 2011-02-02,1305.910034,1307.609985,1302.619995,1304.030029,1304.030029,4098260000 24 | 2011-02-03,1302.770020,1308.599976,1294.829956,1307.099976,1307.099976,4370990000 25 | 2011-02-04,1307.010010,1311.000000,1301.670044,1310.869995,1310.869995,3925950000 26 | 2011-02-07,1311.849976,1322.849976,1311.849976,1319.050049,1319.050049,3902270000 27 | 2011-02-08,1318.760010,1324.869995,1316.030029,1324.569946,1324.569946,3881530000 28 | 2011-02-09,1322.479980,1324.540039,1314.890015,1320.880005,1320.880005,3922240000 29 | 2011-02-10,1318.130005,1322.780029,1311.739990,1321.869995,1321.869995,4184610000 30 | 2011-02-11,1318.660034,1330.790039,1316.079956,1329.150024,1329.150024,4219300000 31 | 2011-02-14,1328.729980,1332.959961,1326.900024,1332.319946,1332.319946,3567040000 32 | 2011-02-15,1330.430054,1330.430054,1324.609985,1328.010010,1328.010010,3926860000 33 | 2011-02-16,1329.510010,1337.609985,1329.510010,1336.319946,1336.319946,1966450000 34 | 2011-02-17,1334.369995,1341.500000,1331.000000,1340.430054,1340.430054,1966450000 35 | 2011-02-18,1340.380005,1344.069946,1338.119995,1343.010010,1343.010010,1162310000 36 | 2011-02-22,1338.910034,1338.910034,1312.329956,1315.439941,1315.439941,1322780000 37 | 2011-02-23,1315.439941,1317.910034,1299.550049,1307.400024,1307.400024,1330340000 38 
| 2011-02-24,1307.089966,1310.910034,1294.260010,1306.099976,1306.099976,1222900000 39 | 2011-02-25,1307.339966,1320.609985,1307.339966,1319.880005,1319.880005,3836030000 40 | 2011-02-28,1321.609985,1329.380005,1320.550049,1327.219971,1327.219971,1252850000 41 | 2011-03-01,1328.640015,1332.089966,1306.140015,1306.329956,1306.329956,1180420000 42 | 2011-03-02,1305.469971,1314.189941,1302.579956,1308.439941,1308.439941,1025000000 43 | 2011-03-03,1312.369995,1332.280029,1312.369995,1330.969971,1330.969971,4340470000 44 | 2011-03-04,1330.729980,1331.079956,1312.589966,1321.150024,1321.150024,4223740000 45 | 2011-03-07,1322.719971,1327.680054,1303.989990,1310.130005,1310.130005,3964730000 46 | 2011-03-08,1311.050049,1325.739990,1306.859985,1321.819946,1321.819946,4531420000 47 | 2011-03-09,1319.920044,1323.209961,1312.270020,1320.020020,1320.020020,3709520000 48 | 2011-03-10,1315.719971,1315.719971,1294.209961,1295.109985,1295.109985,4723020000 49 | 2011-03-11,1293.430054,1308.349976,1291.989990,1304.280029,1304.280029,3740400000 50 | 2011-03-14,1301.189941,1301.189941,1286.369995,1296.390015,1296.390015,4050370000 51 | 2011-03-15,1288.459961,1288.459961,1261.119995,1281.869995,1281.869995,5201400000 52 | 2011-03-16,1279.459961,1280.910034,1249.050049,1256.880005,1256.880005,5833000000 53 | 2011-03-17,1261.609985,1278.880005,1261.609985,1273.719971,1273.719971,4134950000 54 | 2011-03-18,1276.709961,1288.880005,1276.180054,1279.209961,1279.209961,4685500000 55 | 2011-03-21,1281.650024,1300.579956,1281.650024,1298.380005,1298.380005,4223730000 56 | 2011-03-22,1298.290039,1299.349976,1292.699951,1293.770020,1293.770020,3576550000 57 | 2011-03-23,1292.189941,1300.510010,1284.050049,1297.540039,1297.540039,3842350000 58 | 2011-03-24,1300.609985,1311.339966,1297.739990,1309.660034,1309.660034,4223740000 59 | 2011-03-25,1311.800049,1319.180054,1310.150024,1313.800049,1313.800049,4223740000 60 | 2011-03-28,1315.449951,1319.739990,1310.189941,1310.189941,1310.189941,3215170000 
61 | 2011-03-29,1309.369995,1319.449951,1305.260010,1319.439941,1319.439941,3482580000 62 | 2011-03-30,1321.890015,1331.739990,1321.890015,1328.260010,1328.260010,3809570000 63 | 2011-03-31,1327.439941,1329.770020,1325.030029,1325.829956,1325.829956,3566270000 64 | 2011-04-01,1329.479980,1337.849976,1328.890015,1332.410034,1332.410034,4223740000 65 | 2011-04-04,1333.560059,1336.739990,1329.099976,1332.869995,1332.869995,4223740000 66 | 2011-04-05,1332.030029,1338.209961,1330.030029,1332.630005,1332.630005,3852280000 67 | 2011-04-06,1335.939941,1339.380005,1331.089966,1335.540039,1335.540039,4223740000 68 | 2011-04-07,1334.819946,1338.800049,1326.560059,1333.510010,1333.510010,4005600000 69 | 2011-04-08,1336.160034,1339.459961,1322.939941,1328.170044,1328.170044,3582810000 70 | 2011-04-11,1329.010010,1333.770020,1321.060059,1324.459961,1324.459961,3478970000 71 | 2011-04-12,1321.959961,1321.959961,1309.510010,1314.160034,1314.160034,4275490000 72 | 2011-04-13,1314.030029,1321.349976,1309.189941,1314.410034,1314.410034,3850860000 73 | 2011-04-14,1311.130005,1316.790039,1302.420044,1314.520020,1314.520020,3872630000 74 | 2011-04-15,1314.540039,1322.880005,1313.680054,1319.680054,1319.680054,4223740000 75 | 2011-04-18,1313.349976,1313.349976,1294.699951,1305.140015,1305.140015,4223740000 76 | 2011-04-19,1305.989990,1312.699951,1303.969971,1312.619995,1312.619995,3886300000 77 | 2011-04-20,1319.119995,1332.660034,1319.119995,1330.359985,1330.359985,4236280000 78 | 2011-04-21,1333.229980,1337.489990,1332.829956,1337.380005,1337.380005,3587240000 79 | 2011-04-25,1337.140015,1337.550049,1331.469971,1335.250000,1335.250000,2142130000 80 | 2011-04-26,1336.750000,1349.550049,1336.750000,1347.239990,1347.239990,3908060000 81 | 2011-04-27,1348.430054,1357.489990,1344.250000,1355.660034,1355.660034,4051570000 82 | 2011-04-28,1353.859985,1361.709961,1353.599976,1360.479980,1360.479980,4036820000 83 | 
2011-04-29,1360.140015,1364.560059,1358.689941,1363.609985,1363.609985,3479070000 84 | 2011-05-02,1365.209961,1370.579956,1358.589966,1361.219971,1361.219971,3846250000 85 | 2011-05-03,1359.760010,1360.839966,1349.520020,1356.619995,1356.619995,4223740000 86 | 2011-05-04,1355.900024,1355.900024,1341.500000,1347.319946,1347.319946,4223740000 87 | 2011-05-05,1344.160034,1348.000000,1329.170044,1335.099976,1335.099976,3846250000 88 | 2011-05-06,1340.239990,1354.359985,1335.579956,1340.199951,1340.199951,4223740000 89 | 2011-05-09,1340.199951,1349.439941,1338.640015,1346.290039,1346.290039,4265250000 90 | 2011-05-10,1348.339966,1359.439941,1348.339966,1357.160034,1357.160034,4223740000 91 | 2011-05-11,1354.510010,1354.510010,1336.359985,1342.079956,1342.079956,3846250000 92 | 2011-05-12,1339.390015,1351.050049,1332.030029,1348.650024,1348.650024,3777210000 93 | 2011-05-13,1348.689941,1350.469971,1333.359985,1337.770020,1337.770020,3426660000 94 | 2011-05-16,1334.770020,1343.329956,1327.319946,1329.469971,1329.469971,3846250000 95 | 2011-05-17,1326.099976,1330.420044,1318.510010,1328.979980,1328.979980,4053970000 96 | 2011-05-18,1328.540039,1341.819946,1326.589966,1340.680054,1340.680054,3922030000 97 | 2011-05-19,1342.400024,1346.819946,1336.359985,1343.599976,1343.599976,3626110000 98 | 2011-05-20,1342.000000,1342.000000,1330.670044,1333.270020,1333.270020,4066020000 99 | 2011-05-23,1333.069946,1333.069946,1312.880005,1317.369995,1317.369995,3255580000 100 | 2011-05-24,1317.699951,1323.719971,1313.869995,1316.280029,1316.280029,3846250000 101 | 2011-05-25,1316.359985,1325.859985,1311.800049,1320.469971,1320.469971,4109670000 102 | 2011-05-26,1320.640015,1328.510010,1314.410034,1325.689941,1325.689941,3259470000 103 | 2011-05-27,1325.689941,1334.619995,1325.689941,1331.099976,1331.099976,3124560000 104 | 2011-05-31,1331.099976,1345.199951,1331.099976,1345.199951,1345.199951,4696240000 105 | 
2011-06-01,1345.199951,1345.199951,1313.709961,1314.550049,1314.550049,4241090000 106 | 2011-06-02,1314.550049,1318.030029,1305.609985,1312.939941,1312.939941,3762170000 107 | 2011-06-03,1312.939941,1312.939941,1297.900024,1300.160034,1300.160034,3505030000 108 | 2011-06-06,1300.260010,1300.260010,1284.719971,1286.170044,1286.170044,3555980000 109 | 2011-06-07,1286.310059,1296.219971,1284.739990,1284.939941,1284.939941,3846250000 110 | 2011-06-08,1284.630005,1287.040039,1277.420044,1279.560059,1279.560059,3970810000 111 | 2011-06-09,1279.630005,1294.540039,1279.630005,1289.000000,1289.000000,3332510000 112 | 2011-06-10,1288.599976,1288.599976,1268.280029,1270.979980,1270.979980,3846250000 113 | 2011-06-13,1271.310059,1277.040039,1265.640015,1271.829956,1271.829956,4132520000 114 | 2011-06-14,1272.219971,1292.500000,1272.219971,1287.869995,1287.869995,3500280000 115 | 2011-06-15,1287.869995,1287.869995,1261.900024,1265.420044,1265.420044,4070500000 116 | 2011-06-16,1265.530029,1274.109985,1258.069946,1267.640015,1267.640015,3846250000 117 | 2011-06-17,1268.579956,1279.819946,1267.400024,1271.500000,1271.500000,4916460000 118 | 2011-06-20,1271.500000,1280.420044,1267.560059,1278.359985,1278.359985,3464660000 119 | 2011-06-21,1278.400024,1297.619995,1278.400024,1295.520020,1295.520020,4056150000 120 | 2011-06-22,1295.479980,1298.609985,1286.790039,1287.140015,1287.140015,3718420000 121 | 2011-06-23,1286.599976,1286.599976,1262.869995,1283.500000,1283.500000,4983450000 122 | 2011-06-24,1283.040039,1283.930054,1267.239990,1268.449951,1268.449951,3665340000 123 | 2011-06-27,1268.439941,1284.910034,1267.530029,1280.099976,1280.099976,3479070000 124 | 2011-06-28,1280.209961,1296.800049,1280.209961,1296.670044,1296.670044,3681500000 125 | 2011-06-29,1296.849976,1309.209961,1296.849976,1307.410034,1307.410034,4347540000 126 | 2011-06-30,1307.640015,1321.969971,1307.640015,1320.640015,1320.640015,4200500000 127 | 
2011-07-01,1320.640015,1341.010010,1318.180054,1339.670044,1339.670044,3796930000 128 | 2011-07-05,1339.589966,1340.890015,1334.300049,1337.880005,1337.880005,3722320000 129 | 2011-07-06,1337.560059,1340.939941,1330.920044,1339.219971,1339.219971,3564190000 130 | 2011-07-07,1339.619995,1356.479980,1339.619995,1353.219971,1353.219971,4069530000 131 | 2011-07-08,1352.390015,1352.390015,1333.709961,1343.800049,1343.800049,3594360000 132 | 2011-07-11,1343.310059,1343.310059,1316.420044,1319.489990,1319.489990,3879130000 133 | 2011-07-12,1319.609985,1327.170044,1313.329956,1313.640015,1313.640015,4227890000 134 | 2011-07-13,1314.449951,1331.479980,1314.449951,1317.719971,1317.719971,4060080000 135 | 2011-07-14,1317.739990,1326.880005,1306.510010,1308.869995,1308.869995,4358570000 136 | 2011-07-15,1308.869995,1317.699951,1307.520020,1316.140015,1316.140015,4242760000 137 | 2011-07-18,1315.939941,1315.939941,1295.920044,1305.439941,1305.439941,4118160000 138 | 2011-07-19,1307.069946,1328.140015,1307.069946,1326.729980,1326.729980,4304600000 139 | 2011-07-20,1328.660034,1330.430054,1323.650024,1325.839966,1325.839966,3767420000 140 | 2011-07-21,1325.650024,1347.000000,1325.650024,1343.800049,1343.800049,4837430000 141 | 2011-07-22,1343.800049,1346.099976,1336.949951,1345.020020,1345.020020,3522830000 142 | 2011-07-25,1344.319946,1344.319946,1331.089966,1337.430054,1337.430054,3536890000 143 | 2011-07-26,1337.390015,1338.510010,1329.589966,1331.939941,1331.939941,4007050000 144 | 2011-07-27,1331.910034,1331.910034,1303.489990,1304.890015,1304.890015,3479040000 145 | 2011-07-28,1304.839966,1316.319946,1299.160034,1300.670044,1300.670044,4951800000 146 | 2011-07-29,1300.119995,1304.160034,1282.859985,1292.280029,1292.280029,5061190000 147 | 2011-08-01,1292.589966,1307.380005,1274.729980,1286.939941,1286.939941,4967390000 148 | 2011-08-02,1286.560059,1286.560059,1254.030029,1254.050049,1254.050049,5206290000 149 | 
2011-08-03,1254.250000,1261.199951,1234.560059,1260.339966,1260.339966,6446940000 150 | 2011-08-04,1260.229980,1260.229980,1199.540039,1200.069946,1200.069946,4266530000 151 | 2011-08-05,1200.280029,1218.109985,1168.089966,1199.380005,1199.380005,5454590000 152 | 2011-08-08,1198.479980,1198.479980,1119.280029,1119.459961,1119.459961,2615150000 153 | 2011-08-09,1120.229980,1172.880005,1101.540039,1172.530029,1172.530029,2366660000 154 | 2011-08-10,1171.770020,1171.770020,1118.010010,1120.760010,1120.760010,5018070000 155 | 2011-08-11,1121.300049,1186.290039,1121.300049,1172.640015,1172.640015,3685050000 156 | 2011-08-12,1172.869995,1189.040039,1170.739990,1178.810059,1178.810059,5640380000 157 | 2011-08-15,1178.859985,1204.489990,1178.859985,1204.489990,1204.489990,4272850000 158 | 2011-08-16,1204.219971,1204.219971,1180.530029,1192.760010,1192.760010,5071600000 159 | 2011-08-17,1192.890015,1208.469971,1184.359985,1193.890015,1193.890015,4388340000 160 | 2011-08-18,1189.619995,1189.619995,1131.030029,1140.650024,1140.650024,3234810000 161 | 2011-08-19,1140.469971,1154.540039,1122.050049,1123.530029,1123.530029,5167560000 162 | 2011-08-22,1123.550049,1145.489990,1121.089966,1123.819946,1123.819946,5436260000 163 | 2011-08-23,1124.359985,1162.349976,1124.359985,1162.349976,1162.349976,5013170000 164 | 2011-08-24,1162.160034,1178.560059,1156.300049,1177.599976,1177.599976,5315310000 165 | 2011-08-25,1176.689941,1190.680054,1155.469971,1159.270020,1159.270020,5748420000 166 | 2011-08-26,1158.849976,1181.229980,1135.910034,1176.800049,1176.800049,5035320000 167 | 2011-08-29,1177.910034,1210.280029,1177.910034,1210.079956,1210.079956,4228070000 168 | 2011-08-30,1209.760010,1220.099976,1195.770020,1212.920044,1212.920044,4572570000 169 | 2011-08-31,1213.000000,1230.709961,1209.349976,1218.890015,1218.890015,5267840000 170 | 2011-09-01,1219.119995,1229.290039,1203.849976,1204.420044,1204.420044,4780410000 171 | 
2011-09-02,1203.900024,1203.900024,1170.560059,1173.969971,1173.969971,4401740000 172 | 2011-09-06,1173.969971,1173.969971,1140.130005,1165.239990,1165.239990,5103980000 173 | 2011-09-07,1165.849976,1198.619995,1165.849976,1198.619995,1198.619995,4441040000 174 | 2011-09-08,1197.979980,1204.400024,1183.339966,1185.900024,1185.900024,4465170000 175 | 2011-09-09,1185.369995,1185.369995,1148.369995,1154.229980,1154.229980,4586370000 176 | 2011-09-12,1153.500000,1162.520020,1136.069946,1162.270020,1162.270020,5168550000 177 | 2011-09-13,1162.589966,1176.410034,1157.439941,1172.869995,1172.869995,4681370000 178 | 2011-09-14,1173.319946,1202.380005,1162.729980,1188.680054,1188.680054,4986740000 179 | 2011-09-15,1189.439941,1209.109985,1189.439941,1209.109985,1209.109985,4479730000 180 | 2011-09-16,1209.209961,1220.060059,1204.459961,1216.010010,1216.010010,5248890000 181 | 2011-09-19,1214.989990,1214.989990,1188.359985,1204.089966,1204.089966,4254190000 182 | 2011-09-20,1204.500000,1220.390015,1201.290039,1202.089966,1202.089966,4315610000 183 | 2011-09-21,1203.630005,1206.300049,1166.209961,1166.760010,1166.760010,4728550000 184 | 2011-09-22,1164.550049,1164.550049,1114.219971,1129.560059,1129.560059,6703140000 185 | 2011-09-23,1128.819946,1141.719971,1121.359985,1136.430054,1136.430054,5639930000 186 | 2011-09-26,1136.910034,1164.189941,1131.069946,1162.949951,1162.949951,4762830000 187 | 2011-09-27,1163.319946,1195.859985,1163.319946,1175.380005,1175.380005,5548130000 188 | 2011-09-28,1175.390015,1184.709961,1150.400024,1151.060059,1151.060059,4787920000 189 | 2011-09-29,1151.739990,1175.869995,1139.930054,1160.400024,1160.400024,5285740000 190 | 2011-09-30,1159.930054,1159.930054,1131.339966,1131.420044,1131.420044,4416790000 191 | 2011-10-03,1131.209961,1138.989990,1098.920044,1099.229980,1099.229980,5670340000 192 | 2011-10-04,1097.420044,1125.119995,1074.770020,1123.949951,1123.949951,3714670000 193 | 
2011-10-05,1124.030029,1146.069946,1115.680054,1144.030029,1144.030029,2510620000 194 | 2011-10-06,1144.109985,1165.550049,1134.949951,1164.969971,1164.969971,5098330000 195 | 2011-10-07,1165.030029,1171.400024,1150.260010,1155.459961,1155.459961,5580380000 196 | 2011-10-10,1158.150024,1194.910034,1158.150024,1194.890015,1194.890015,4446800000 197 | 2011-10-11,1194.599976,1199.239990,1187.300049,1195.540039,1195.540039,4424500000 198 | 2011-10-12,1196.189941,1220.250000,1196.189941,1207.250000,1207.250000,5355360000 199 | 2011-10-13,1206.959961,1207.459961,1190.579956,1203.660034,1203.660034,4436270000 200 | 2011-10-14,1205.650024,1224.609985,1205.650024,1224.579956,1224.579956,4116690000 201 | 2011-10-17,1224.469971,1224.469971,1198.550049,1200.859985,1200.859985,4300700000 202 | 2011-10-18,1200.750000,1233.099976,1191.479980,1225.380005,1225.380005,4840170000 203 | 2011-10-19,1223.459961,1229.640015,1206.310059,1209.880005,1209.880005,4846390000 204 | 2011-10-20,1209.920044,1219.530029,1197.339966,1215.390015,1215.390015,4870290000 205 | 2011-10-21,1215.390015,1239.030029,1215.390015,1238.250000,1238.250000,4980770000 206 | 2011-10-24,1238.719971,1256.550049,1238.719971,1254.189941,1254.189941,4309380000 207 | 2011-10-25,1254.189941,1254.189941,1226.790039,1229.050049,1229.050049,4473970000 208 | 2011-10-26,1229.170044,1246.280029,1221.060059,1242.000000,1242.000000,4873530000 209 | 2011-10-27,1243.969971,1292.660034,1243.969971,1284.589966,1284.589966,6367610000 210 | 2011-10-28,1284.390015,1287.079956,1277.010010,1285.089966,1285.089966,4536690000 211 | 2011-10-31,1284.959961,1284.959961,1253.160034,1253.300049,1253.300049,4310210000 212 | 2011-11-01,1251.000000,1251.000000,1215.420044,1218.280029,1218.280029,5645540000 213 | 2011-11-02,1219.619995,1242.479980,1219.619995,1237.900024,1237.900024,4110530000 214 | 2011-11-03,1238.250000,1263.209961,1234.810059,1261.150024,1261.150024,4849140000 215 | 
2011-11-04,1260.819946,1260.819946,1238.920044,1253.229980,1253.229980,3830650000 216 | 2011-11-07,1253.209961,1261.699951,1240.750000,1261.119995,1261.119995,3429740000 217 | 2011-11-08,1261.119995,1277.550049,1254.989990,1275.920044,1275.920044,3908490000 218 | 2011-11-09,1275.180054,1275.180054,1226.640015,1229.099976,1229.099976,4659740000 219 | 2011-11-10,1229.589966,1246.219971,1227.699951,1239.699951,1239.699951,4002760000 220 | 2011-11-11,1240.119995,1266.979980,1240.119995,1263.849976,1263.849976,3370180000 221 | 2011-11-14,1263.849976,1263.849976,1246.680054,1251.780029,1251.780029,3219680000 222 | 2011-11-15,1251.699951,1264.250000,1244.339966,1257.810059,1257.810059,3599300000 223 | 2011-11-16,1257.810059,1259.609985,1235.670044,1236.910034,1236.910034,4085010000 224 | 2011-11-17,1236.560059,1237.729980,1209.430054,1216.130005,1216.130005,4596450000 225 | 2011-11-18,1216.189941,1223.510010,1211.359985,1215.650024,1215.650024,3827610000 226 | 2011-11-21,1215.619995,1215.619995,1183.160034,1192.979980,1192.979980,4050070000 227 | 2011-11-22,1192.979980,1196.810059,1181.650024,1188.040039,1188.040039,3911710000 228 | 2011-11-23,1187.479980,1187.479980,1161.790039,1161.790039,1161.790039,3798940000 229 | 2011-11-25,1161.410034,1172.660034,1158.660034,1158.670044,1158.670044,1664200000 230 | 2011-11-28,1158.670044,1197.349976,1158.670044,1192.550049,1192.550049,3920750000 231 | 2011-11-29,1192.560059,1203.670044,1191.800049,1195.189941,1195.189941,3992650000 232 | 2011-11-30,1196.719971,1247.109985,1196.719971,1246.959961,1246.959961,5801910000 233 | 2011-12-01,1246.910034,1251.089966,1239.729980,1244.579956,1244.579956,3818680000 234 | 2011-12-02,1246.030029,1260.079956,1243.349976,1244.280029,1244.280029,4144310000 235 | 2011-12-05,1244.329956,1266.729980,1244.329956,1257.079956,1257.079956,4148060000 236 | 2011-12-06,1257.189941,1266.030029,1253.030029,1258.469971,1258.469971,3734230000 237 | 
2011-12-07,1258.140015,1267.060059,1244.800049,1261.010010,1261.010010,4160540000 238 | 2011-12-08,1260.869995,1260.869995,1231.469971,1234.349976,1234.349976,4298370000 239 | 2011-12-09,1234.479980,1258.250000,1234.479980,1255.189941,1255.189941,3830610000 240 | 2011-12-12,1255.050049,1255.050049,1227.250000,1236.469971,1236.469971,3600570000 241 | 2011-12-13,1236.829956,1249.859985,1219.430054,1225.729980,1225.729980,4121570000 242 | 2011-12-14,1225.729980,1225.729980,1209.469971,1211.819946,1211.819946,4298290000 243 | 2011-12-15,1212.119995,1225.599976,1212.119995,1215.750000,1215.750000,3810340000 244 | 2011-12-16,1216.089966,1231.040039,1215.199951,1219.660034,1219.660034,5345800000 245 | 2011-12-19,1219.739990,1224.569946,1202.369995,1205.349976,1205.349976,3659820000 246 | 2011-12-20,1205.719971,1242.819946,1205.719971,1241.300049,1241.300049,4055590000 247 | 2011-12-21,1241.250000,1245.089966,1229.510010,1243.719971,1243.719971,2959020000 248 | 2011-12-22,1243.719971,1255.219971,1243.719971,1254.000000,1254.000000,3492250000 249 | 2011-12-23,1254.000000,1265.420044,1254.000000,1265.329956,1265.329956,2233830000 250 | 2011-12-27,1265.020020,1269.369995,1262.300049,1265.430054,1265.430054,2130590000 251 | 2011-12-28,1265.380005,1265.849976,1248.640015,1249.640015,1249.640015,2349980000 252 | 2011-12-29,1249.750000,1263.540039,1249.750000,1263.020020,1263.020020,2278130000 253 | 2011-12-30,1262.819946,1264.119995,1257.459961,1257.599976,1257.599976,2271850000 254 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import load_model 3 | 4 | from agent.agent import Agent 5 | from functions import * 6 | import sys 7 | 8 | if len(sys.argv) != 3: 9 | print("Usage: python evaluate.py [stock] [model]") 10 | exit() 11 | 12 | stock_name, model_name = sys.argv[1], sys.argv[2] 13 | model = 
load_model("models/" + model_name) 14 | window_size = model.layers[0].input.shape.as_list()[1] 15 | 16 | agent = Agent(window_size, True, model_name) 17 | data = getStockDataVec(stock_name) 18 | l = len(data) - 1 19 | batch_size = 32 20 | 21 | state = getState(data, 0, window_size + 1) 22 | total_profit = 0 23 | agent.inventory = [] 24 | 25 | for t in range(l): 26 | action = agent.act(state) 27 | 28 | # sit 29 | next_state = getState(data, t + 1, window_size + 1) 30 | reward = 0 31 | 32 | if action == 1: # buy 33 | agent.inventory.append(data[t]) 34 | print("Buy: " + formatPrice(data[t])) 35 | 36 | elif action == 2 and len(agent.inventory) > 0: # sell 37 | bought_price = agent.inventory.pop(0) 38 | reward = max(data[t] - bought_price, 0) 39 | total_profit += data[t] - bought_price 40 | print("Sell: " + formatPrice(data[t]) + " | Profit: " + formatPrice(data[t] - bought_price)) 41 | 42 | done = True if t == l - 1 else False 43 | agent.memory.append((state, action, reward, next_state, done)) 44 | state = next_state 45 | 46 | if done: 47 | print("--------------------------------") 48 | print(stock_name + " Total Profit: " + formatPrice(total_profit)) 49 | print("--------------------------------") -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | # returns the price formatted as a dollar string 5 | def formatPrice(n): 6 | return ("-$" if n < 0 else "$") + "{0:.2f}".format(abs(n)) 7 | 8 | # returns the vector of closing prices read from data/<key>.csv 9 | def getStockDataVec(key): 10 | vec = [] 11 | lines = open("data/" + key + ".csv", "r").read().splitlines() 12 | 13 | for line in lines[1:]: 14 | vec.append(float(line.split(",")[4])) 15 | 16 | return vec 17 | 18 | # returns the sigmoid 19 | def sigmoid(x): 20 | return 1 / (1 + math.exp(-x)) 21 | 22 | # returns an n-day state representation ending at time t 23 | def 
getState(data, t, n): 24 | d = t - n + 1 25 | block = data[d:t + 1] if d >= 0 else -d * [data[0]] + data[0:t + 1] # pad with t0 26 | res = [] 27 | for i in range(n - 1): 28 | res.append(sigmoid(block[i + 1] - block[i])) 29 | 30 | return np.array([res]) 31 | -------------------------------------------------------------------------------- /images/AAPL_2015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/AAPL_2015.png -------------------------------------------------------------------------------- /images/AAPL_2016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/AAPL_2016.png -------------------------------------------------------------------------------- /images/BABA_2014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/BABA_2014.png -------------------------------------------------------------------------------- /images/BABA_2015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/BABA_2015.png -------------------------------------------------------------------------------- /images/GOOG_8_2017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/GOOG_8_2017.png 
-------------------------------------------------------------------------------- /images/^GSPC_2014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/^GSPC_2014.png -------------------------------------------------------------------------------- /images/^GSPC_2015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llSourcell/Reinforcement_Learning_for_Stock_Prediction/0b87478668e0bc7dddc833e2adabc0948d318471/images/^GSPC_2015.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from agent.agent import Agent 2 | from functions import * 3 | import sys 4 | 5 | if len(sys.argv) != 4: 6 | print("Usage: python train.py [stock] [window] [episodes]") 7 | exit() 8 | 9 | stock_name, window_size, episode_count = sys.argv[1], int(sys.argv[2]), int(sys.argv[3]) 10 | 11 | agent = Agent(window_size) 12 | data = getStockDataVec(stock_name) 13 | l = len(data) - 1 14 | batch_size = 32 15 | 16 | for e in range(episode_count + 1): 17 | print("Episode " + str(e) + "/" + str(episode_count)) 18 | state = getState(data, 0, window_size + 1) 19 | 20 | total_profit = 0 21 | agent.inventory = [] 22 | 23 | for t in range(l): 24 | action = agent.act(state) 25 | 26 | # sit 27 | next_state = getState(data, t + 1, window_size + 1) 28 | reward = 0 29 | 30 | if action == 1: # buy 31 | agent.inventory.append(data[t]) 32 | print("Buy: " + formatPrice(data[t])) 33 | 34 | elif action == 2 and len(agent.inventory) > 0: # sell 35 | bought_price = agent.inventory.pop(0) 36 | reward = max(data[t] - bought_price, 0) 37 | total_profit += data[t] - bought_price 38 | print("Sell: " + formatPrice(data[t]) + " | Profit: " + formatPrice(data[t] - 
bought_price)) 39 | 40 | done = True if t == l - 1 else False 41 | agent.memory.append((state, action, reward, next_state, done)) 42 | state = next_state 43 | 44 | if done: 45 | print("--------------------------------") 46 | print("Total Profit: " + formatPrice(total_profit)) 47 | print("--------------------------------") 48 | 49 | if len(agent.memory) > batch_size: 50 | agent.expReplay(batch_size) 51 | 52 | if e % 10 == 0: 53 | agent.model.save("models/model_ep" + str(e)) 54 | --------------------------------------------------------------------------------
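The state construction in `getState` (functions.py) and the update target in `expReplay` (agent/agent.py) can be tried without Keras. The following is a minimal, standard-library sketch and not part of the repository: the names `stable_sigmoid`, `window_state` and `q_target` are hypothetical, and `stable_sigmoid` is a swapped-in variant of the repository's `sigmoid`, which calls `math.exp(-x)` directly and would raise `OverflowError` for a price drop larger than roughly 709. The sample prices are the first four closes of `data/^GSPC_2011.csv`, rounded to two decimals.

```python
import math


def stable_sigmoid(x):
    """Logistic function that avoids OverflowError for large |x| (assumed variant)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0 here, so exp(x) cannot overflow
    return z / (1.0 + z)


def window_state(prices, t, n):
    """Return n - 1 squashed day-over-day differences for the window ending at t.

    Follows the logic of getState: when fewer than n prices exist up to
    index t, the window is left-padded with the first price.
    """
    start = t - n + 1
    if start >= 0:
        block = prices[start:t + 1]
    else:
        block = [prices[0]] * (-start) + prices[:t + 1]
    return [stable_sigmoid(block[i + 1] - block[i]) for i in range(n - 1)]


def q_target(reward, next_q_values, done, gamma=0.95):
    """One-step Q-learning target: r if terminal, else r + gamma * max Q(s', a')."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


if __name__ == "__main__":
    closes = [1271.87, 1270.20, 1276.56, 1273.85]
    # Window of n = 4 prices ending at t = 1: two padded days, then one real change.
    print(window_state(closes, t=1, n=4))
    # Target for a non-terminal step with made-up Q-values for sit, buy, sell.
    print(q_target(1.5, [0.2, 0.9, -0.3], done=False))
```

Padded days contribute a difference of zero, so their entries are exactly 0.5; a falling price maps below 0.5 and a rising price above it. In train.py and evaluate.py the window is requested as `window_size + 1` prices, which yields `window_size` features for the network input.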