├── LICENSE
├── README.md
├── baselines
│   └── baselines
│       └── ddpg
│           ├── README.md
│           ├── __pycache__
│           │   ├── __init__.cpython-36.pyc
│           │   ├── ddpg.cpython-36.pyc
│           │   ├── ddpg_learner.cpython-36.pyc
│           │   ├── memory.cpython-36.pyc
│           │   ├── models.cpython-36.pyc
│           │   └── noise.cpython-36.pyc
│           ├── ddpg.py
│           ├── ddpg_learner.py
│           ├── memory.py
│           ├── models.py
│           ├── noise.py
│           └── test_smoke.py
├── compare.py
├── data
│   ├── AAPL.csv
│   ├── AXP.csv
│   ├── BA.csv
│   ├── CAT.csv
│   ├── CSCO.csv
│   ├── CVX.csv
│   ├── DIS.csv
│   ├── GE.csv
│   ├── GLD.csv
│   ├── GS.csv
│   ├── HD.csv
│   ├── IBM.csv
│   ├── INTC.csv
│   ├── JNJ.csv
│   ├── JPM.csv
│   ├── KO.csv
│   ├── MCD.csv
│   ├── MMM.csv
│   ├── MRK.csv
│   ├── MSFT.csv
│   ├── NKE.csv
│   ├── PFE.csv
│   ├── PG.csv
│   ├── QQQ.csv
│   ├── SHV.csv
│   ├── SHY.csv
│   ├── SPY.csv
│   ├── UNH.csv
│   ├── UTX.csv
│   ├── VZ.csv
│   ├── WMT.csv
│   ├── XOM.csv
│   ├── ^DJI.csv
│   ├── ^GSPC.csv
│   ├── ^IXIC.csv
│   ├── ^RUT.csv
│   ├── ^TNX.csv
│   ├── ^TYX.csv
│   ├── ^VIX.csv
│   ├── ddpg_input_states.csv
│   └── ddpg_stock_price.csv
├── data_preprocessing.py
├── feature_select.py
├── gym
│   └── envs
│       ├── StarTrader
│       │   ├── StarTrade_env.py
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── StarTrade_env.cpython-36.pyc
│       │   │   ├── StarTrade_test_env.cpython-36.pyc
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   └── intelliTrade_env.cpython-36.pyc
│       │   └── data
│       │       ├── AAPL.csv
│       │       ├── AXP.csv
│       │       ├── BA.csv
│       │       ├── CAT.csv
│       │       ├── CSCO.csv
│       │       ├── CVX.csv
│       │       ├── DIS.csv
│       │       ├── GE.csv
│       │       ├── GS.csv
│       │       ├── HD.csv
│       │       ├── IBM.csv
│       │       ├── INTC.csv
│       │       ├── JNJ.csv
│       │       ├── JPM.csv
│       │       ├── KO.csv
│       │       ├── MCD.csv
│       │       ├── MMM.csv
│       │       ├── MRK.csv
│       │       ├── MSFT.csv
│       │       ├── NKE.csv
│       │       ├── PFE.csv
│       │       ├── PG.csv
│       │       ├── UNH.csv
│       │       ├── UTX.csv
│       │       ├── VZ.csv
│       │       ├── WMT.csv
│       │       └── XOM.csv
│       ├── StarTraderTest
│       │   ├── StarTrade_test_env.py
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── StarTrade_env.cpython-36.pyc
│       │   │   ├── StarTrade_test_env.cpython-36.pyc
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   └── intelliTrade_env.cpython-36.pyc
│       │   └── data
│       │       ├── AAPL.csv
│       │       ├── AXP.csv
│       │       ├── BA.csv
│       │       ├── CAT.csv
│       │       ├── CSCO.csv
│       │       ├── CVX.csv
│       │       ├── DIS.csv
│       │       ├── GE.csv
│       │       ├── GS.csv
│       │       ├── HD.csv
│       │       ├── IBM.csv
│       │       ├── INTC.csv
│       │       ├── JNJ.csv
│       │       ├── JPM.csv
│       │       ├── KO.csv
│       │       ├── MCD.csv
│       │       ├── MMM.csv
│       │       ├── MRK.csv
│       │       ├── MSFT.csv
│       │       ├── NKE.csv
│       │       ├── PFE.csv
│       │       ├── PG.csv
│       │       ├── UNH.csv
│       │       ├── UTX.csv
│       │       ├── VZ.csv
│       │       ├── WMT.csv
│       │       └── XOM.csv
│       └── __init__.py
├── model
│   ├── DDPG_trained_model_0
│   ├── DDPG_trained_model_1
│   ├── DDPG_trained_model_2
│   ├── DDPG_trained_model_3
│   ├── DDPG_trained_model_4
│   ├── DDPG_trained_model_5
│   ├── DDPG_trained_model_6
│   ├── DDPG_trained_model_7
│   ├── DDPG_trained_model_8
│   ├── DDPG_trained_model_9
│   └── best_lstm_model.h5
├── results
│   ├── model_1.jpg
│   ├── model_10.jpg
│   ├── model_10_rerun.jpg
│   ├── model_11.jpg
│   ├── model_11_rerun.jpg
│   ├── model_12.jpg
│   ├── model_13.jpg
│   ├── model_14.jpg
│   ├── model_15.jpg
│   ├── model_16.jpg
│   ├── model_17.jpg
│   ├── model_18.jpg
│   ├── model_2.jpg
│   ├── model_3.jpg
│   ├── model_4.jpg
│   ├── model_5.jpg
│   ├── model_6.jpg
│   ├── model_7.jpg
│   ├── model_8.jpg
│   └── model_9.jpg
├── run.py
├── test_iteration_1.gif
├── test_result
│   ├── all_strategies_cum_returns.csv
│   ├── all_strategies_returns.csv
│   ├── asset_evolution_test_1.png
│   ├── cummulative_returns_test_1.png
│   ├── efficient_frontier.png
│   ├── kpi_backtest.csv
│   ├── kpi_test_1.csv
│   ├── portfolios_return.png
│   ├── portfolios_returns.png
│   ├── portfolios_risk.png
│   ├── price_prediction_LSTM.png
│   ├── trading_book_backtest.csv
│   └── trading_book_test_1.csv
├── train_iterations_9.gif
└── train_result
    ├── asset_evolution_train_1.png
    ├── asset_evolution_train_10.png
    ├── asset_evolution_train_11.png
    ├── asset_evolution_train_2.png
    ├── asset_evolution_train_3.png
    ├── asset_evolution_train_4.png
    ├── asset_evolution_train_5.png
    ├── asset_evolution_train_6.png
    ├── asset_evolution_train_7.png
    ├── asset_evolution_train_8.png
    ├── asset_evolution_train_9.png
    ├── cummulative_returns_train_1.png
    ├── cummulative_returns_train_10.png
    ├── cummulative_returns_train_2.png
    ├── cummulative_returns_train_3.png
    ├── cummulative_returns_train_4.png
    ├── cummulative_returns_train_5.png
    ├── cummulative_returns_train_6.png
    ├── cummulative_returns_train_7.png
    ├── cummulative_returns_train_8.png
    ├── cummulative_returns_train_9.png
    ├── kpi_train_1.csv
    ├── kpi_train_10.csv
    ├── kpi_train_2.csv
    ├── kpi_train_3.csv
    ├── kpi_train_4.csv
    ├── kpi_train_5.csv
    ├── kpi_train_6.csv
    ├── kpi_train_7.csv
    ├── kpi_train_8.csv
    ├── kpi_train_9.csv
    ├── trading_book_train_1.csv
    ├── trading_book_train_10.csv
    ├── trading_book_train_11.csv
    ├── trading_book_train_2.csv
    ├── trading_book_train_3.csv
    ├── trading_book_train_4.csv
    ├── trading_book_train_5.csv
    ├── trading_book_train_6.csv
    ├── trading_book_train_7.csv
    ├── trading_book_train_8.csv
    └── trading_book_train_9.csv
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Jiew Wan Tan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (Image References)
2 |
3 | [image1]: https://github.com/jiewwantan/StarTrader/blob/master/train_iterations_9.gif "Training iterations"
4 | [image2]: https://github.com/jiewwantan/StarTrader/blob/master/test_iteration_1.gif "Testing trained model with one iteration"
5 | [image3]: https://github.com/jiewwantan/StarTrader/blob/master/test_result/portfolios_return.png "Trading strategy performance returns comparison"
6 | [image4]: https://github.com/jiewwantan/StarTrader/blob/master/test_result/portfolios_risk.png "Trading strategy performance risk comparison"
7 |
8 | # **StarTrader:** Intelligent Trading Agent Development with Deep Reinforcement Learning
9 |
10 | ### Introduction
11 |
12 | This project sets out to create an intelligent trading agent and a trading environment that provides an ideal learning ground. A real-world trading environment is complex, with stock, related-instrument, macroeconomic, news and possibly alternative data in consideration. An effective agent must derive efficient representations of the environment from high-dimensional input and generalize past experience to new situations. The project adopts a deep reinforcement learning algorithm, deep deterministic policy gradient (DDPG), to trade a portfolio of five stocks. Different reward systems and hyperparameters were tried. Its performance is compared to models built with a recurrent neural network, modern portfolio theory, a simple buy-and-hold strategy and the benchmark DJIA index. The agent and environment are then evaluated to deliberate on possible improvements and the agent's potential to beat professional human traders, just like DeepMind's Alpha series of intelligent game-playing agents.
13 |
14 | The trading agent learns and trades in an [OpenAI Gym](https://gym.openai.com/) environment. Two Gym environments are created to serve the purpose, one for training (StarTrader-v0) and another for testing (StarTraderTest-v0). Both versions of StarTrader utilize OpenAI Baselines' implementation of deep deterministic policy gradient (DDPG).
15 |
16 | A portfolio of five stocks (out of 27 Dow Jones Industrial Average stocks) is selected based on a non-correlation factor. StarTrader trades these five non-correlated stocks, learning to maximize total assets (portfolio value + current account balance) as its goal. During the trading process, StarTrader-v0 also optimizes the portfolio by deciding how many units of each of the five stocks to trade.
17 |
18 | Based on the non-correlation factor, a portfolio optimization algorithm has chosen the following five stocks to trade (a sketch of such a selection follows the list):
19 |
20 | 1. American Express
21 | 2. Wal Mart
22 | 3. UnitedHealth Group
23 | 4. Apple
24 | 5. Verizon Communications
25 |
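As an illustration, a least-correlated subset can be picked greedily from the stocks' daily-return correlation matrix. This is only a sketch of the idea, not the exact routine in feature_select.py; `prices` (a DataFrame of daily closes indexed by date) and `pick_non_correlated` are hypothetical names:

```python
import pandas as pd

def pick_non_correlated(prices: pd.DataFrame, n: int = 5) -> list:
    """Greedily pick n tickers whose daily returns are least correlated."""
    corr = prices.pct_change().dropna().corr().abs()
    # Seed with the ticker having the lowest average correlation to the rest.
    chosen = [corr.mean().idxmin()]
    while len(chosen) < n:
        remaining = corr.index.difference(chosen)
        # Add the ticker least correlated, on average, with those already chosen.
        chosen.append(corr.loc[remaining, chosen].mean(axis=1).idxmin())
    return chosen
```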
26 | The preprocessing function creates technical features derived from each stock's OHLCV data. On average, roughly 6-8 time series are derived for each stock.
27 |
28 | Apart from stock data, context data is also used to aid learning:
29 |
30 | 1. S&P 500 index
31 | 2. Dow Jones Industrial Average index
32 | 3. NASDAQ Composite index
33 | 4. Russell 2000 index
34 | 5. SPDR S&P 500 ETF
35 | 6. Invesco QQQ Trust
36 | 7. CBOE Volatility Index
37 | 8. SPDR Gold Shares
38 | 9. Treasury Yield 30 Years
39 | 10. CBOE Interest Rate 10 Year T Note
40 | 11. iShares 1-3 Year Treasury Bond ETF
41 | 12. iShares Short Treasury Bond ETF
42 |
43 | Similarly, technical features derived from the above context data's OHLCV data are created. All data preprocessing is handled by two modules:
44 | 1. data_preprocessing.py
45 | 2. feature_select.py
46 |
47 | The preprocessed data are then fed directly to StarTrader's trading environment: class StarTradingEnv.
48 |
49 | The feature selection module (feature_select.py) selects about 6-8 features per instrument out of 41 OHLCV and derived technical features. In total there are 121 features (the count may vary across machines, as the algorithm is not seeded), of which about 36 are stock features and the rest are context features.
50 |
51 | When trading is executed, the 121 features, along with total assets, current asset holdings and unrealized profit and loss, form the complete state space for the agent to trade and learn in. The state space is designed to let the agent sense the instantaneous environment as well as how its interactions with the environment affect future state spaces. In other words, the trading agent bears the fruits and consequences of its own actions.
52 |
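For illustration only, the state vector described above could be assembled as follows; the names (`features`, `holdings`, `total_asset`, `unrealized_pnl`) are placeholders rather than the exact variables used in StarTrade_env.py:

```python
import numpy as np

features = np.zeros(121)   # the 121 selected stock + context features for the day
holdings = np.zeros(5)     # units currently held in each of the five stocks
total_asset = 100000.0     # portfolio value + current account balance
unrealized_pnl = 0.0       # profit and loss on open positions

# One observation handed to the DDPG agent at each step.
state = np.concatenate([features, holdings, [total_asset, unrealized_pnl]])
```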
53 | ### Training agent on 9 iterations
54 | ![Training iterations][image1]
55 |
56 | ### Testing agent on one iteration
57 | No learning or model refinement, purely testing the trained model.
58 | The trading agent survived the major market correction in 2018 with a Sharpe ratio of 1.13.
59 |
60 | ![Testing trained model with one iteration][image2]
61 |
62 | ### Compare agent's performance with other trading strategies
63 | DDPG is the best performer in terms of cumulative returns. However, with a much less volatile ride, the RNN-LSTM model has the better risk-adjusted return: the highest Sharpe ratio (1.88) and Sortino ratio (3.06). Both the RNN-LSTM and DRL-DDPG trading strategies have trading costs incorporated: commission (based on Interactive Brokers' fees) and slippage (modelled by Zipline and based on each stock's daily volume), since there are many transactions during the trading window. The buy-and-hold strategies' trading costs are omitted since their stocks are only transacted once.
64 | DDPG's reward system should be modified to yield a higher risk-adjusted return.
65 | For a fair comparison, the LSTM model uses the same training data and a similar backtester as the DDPG model. (The ratios can be reproduced from a returns series as sketched below.)
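A minimal sketch of how such ratios are computed from a daily returns series, assuming a 252-day trading year and a zero risk-free rate (not the exact KPI code in this repository):

```python
import numpy as np

def sharpe(daily_returns, periods=252):
    # Annualized Sharpe ratio (risk-free rate taken as zero for simplicity).
    return np.sqrt(periods) * daily_returns.mean() / daily_returns.std()

def sortino(daily_returns, periods=252):
    # Like Sharpe, but only downside volatility counts as risk.
    downside_std = daily_returns[daily_returns < 0].std()
    return np.sqrt(periods) * daily_returns.mean() / downside_std
```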
66 |
67 | ![Trading strategy performance returns comparison][image3]
68 | ![Trading strategy performance risk comparison][image4]
69 |
70 |
71 | ## Prerequisites
72 |
73 | Python 3.6 or Anaconda with Python 3.6 environment
74 | Python packages: pandas, numpy, matplotlib, statsmodels, sklearn, tensorflow
75 |
76 | The code was written on a Linux machine and has been tested on two operating systems:
77 | Linux Ubuntu 16.04 & Windows 10 Pro
78 |
79 |
80 | ## Installation instructions:
81 |
82 | 1. Installation of system packages CMake, OpenMPI on Mac
83 |
84 | ```brew install cmake openmpi```
85 |
86 | 2. Activate the environment and install Gym under it
87 |
88 | ```pip install gym```
89 |
90 | 3. Download the official Baselines package
91 |
92 | Clone the repo:
93 |
94 | ```
95 | git clone https://github.com/openai/baselines.git
96 |
97 | cd baselines
98 |
99 | pip install -e .
100 | ```
101 |
102 | 4. Install Tensorflow
103 |
104 | There are several ways of installing Tensorflow; this page provides a good description of how it can be done, taking system OS, Python version and GPU availability into consideration.
105 |
106 | https://www.tensorflow.org/install/
107 |
108 | In short, after environment activation, Tensorflow can be installed with these commands:
109 |
110 | Tensorflow for CPU:
111 | ```pip3 install --upgrade tensorflow```
112 |
113 | Tensorflow for GPU:
114 | ```pip3 install --upgrade tensorflow-gpu```
115 |
116 | Installing Tensorflow GPU allows faster training if your machine has nVidia GPU(s) built in.
117 | However, the Tensorflow GPU version requires the right cuDNN and CUDA to be installed; these pages provide instructions to ensure the right versions are installed:
118 |
119 | [Ubuntu](https://www.tensorflow.org/install/install_linux)
120 |
121 | [MacOS](https://www.tensorflow.org/install/install_mac) (Tensorflow 1.2 no longer provides GPU support for MacOS)
122 |
123 | [Windows](https://www.tensorflow.org/install/install_windows)
124 |
125 | 5. Place the StarTrader and StarTraderTest folders in this repository into your machine's OpenAI Gym environments folder:
126 |
127 | gym/envs/
128 |
129 | 6. Replace the ```__init__.py``` file in the following folder with the ```__init__.py``` provided in this repository:
130 |
131 | ```gym/envs/__init__.py```
132 |
133 | 7. Place run.py from this repository's baselines folder into the folder where you want to execute it, for example:
134 |
135 | From Gym's installation:
136 | ```baselines/baselines/run.py```
137 |
138 | To:
139 | ```run.py```
140 |
141 | 8. Place 'data' folder to the folder where run.py resides
142 |
143 | ```/data/```
144 |
145 | 9. Replace ddpg.py in your Baselines installation with the ddpg.py in this repository:
146 |
147 | In your machine's Baselines installation:
148 | ```baselines/baselines/ddpg/ddpg.py```
149 |
150 | replaced by the ddpg.py in this repository:
151 | ```baselines/baselines/ddpg/ddpg.py```
152 |
153 | 10. Replace ddpg_learner.py in your Baselines installation with the ddpg_learner.py in this repository:
154 |
155 | In your machine's Baselines installation:
156 | ```baselines/baselines/ddpg/ddpg_learner.py```
157 |
158 | replaced by the ddpg_learner.py in this repository:
159 | ```baselines/baselines/ddpg/ddpg_learner.py```
160 |
161 | 11. Place feature_select.py and data_preprocessing.py in this repository into the same folder as run.py
162 |
163 | 12. Place the following folders in this repository into the folder where your run.py resides
164 |
165 | ```/test_result/```
166 | ```/train_result/```
167 | ```/model/```
168 |
169 | You do not need to include the folders' contents; they will be generated when the program executes. If contents are included, they will be replaced when the program executes.
170 |
171 | 13. Under the folder where run.py resides, enter the following commands:
172 |
173 | To train agent:
174 | ```python -m run --alg=ddpg --env=StarTrader-v0 --network=mlp --num_timesteps=2e4```
175 |
176 | To test agent:
177 | ```python -m run --alg=ddpg --env=StarTraderTest-v0 --network=mlp --num_timesteps=2e3 --load_path='./model/DDPG_trained_model_8'```
178 |
179 | If you have trained a better model, replace ```DDPG_trained_model_8``` with your new model.
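Note: with the defaults in this repository's ddpg.py (nb_epoch_cycles=20, nb_rollout_steps=100), learn() derives nb_epochs = num_timesteps // (nb_epoch_cycles * nb_rollout_steps), so ```--num_timesteps=2e4``` gives 20000 // 2000 = 10 training epochs, with one checkpoint saved per epoch (DDPG_trained_model_0 through DDPG_trained_model_9).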
180 |
181 | After training and testing the agent successfully, pick the first DDPG trading book for the test run, which is saved as ./test_result/trading_book_test_1.csv, or modify the filename in compare.py.
182 | Then compare the agent's performance with the benchmark index and other trading strategies:
183 |
184 | ```python compare.py```
185 |
186 | ## Special instructions:
187 | 1. Depending on machine configuration, the following installations may be necessary:
188 |
189 | ```pip3 install -U numpy```
190 | ```pip3 install opencv-python```
191 | ```pip3 install mujoco-py==0.5.7```
192 | ```pip3 install lockfile```
193 |
194 | 2. The technical analysis library TA-Lib may be tricky to install on some machines. The following page is a handy guide:
195 | https://goldenjumper.wordpress.com/tag/ta-lib/
196 |
197 | Graphviz, which is required to plot the XGBoost tree diagram, can be installed with the following commands:
198 | Windows:
199 | ```conda install python-graphviz```
200 | Mac/Linux:
201 | ```conda install graphviz```
202 |
203 |
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/README.md:
--------------------------------------------------------------------------------
1 | # DDPG
2 |
3 | - Original paper: https://arxiv.org/abs/1509.02971
4 | - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/
5 | - `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
6 |
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/__pycache__/ddpg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/ddpg.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/__pycache__/ddpg_learner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/ddpg_learner.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/__pycache__/memory.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/memory.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/__pycache__/noise.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/noise.cpython-36.pyc
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/ddpg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from collections import deque
4 | import pickle
5 |
6 | from baselines.ddpg.ddpg_learner import DDPG
7 | from baselines.ddpg.models import Actor, Critic
8 | from baselines.ddpg.memory import Memory
9 | from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
10 | from baselines.common import set_global_seeds
11 | import baselines.common.tf_util as U
12 |
13 | from baselines import logger
14 | import numpy as np
15 |
16 | try:
17 | from mpi4py import MPI
18 | except ImportError:
19 | MPI = None
20 |
21 | def learn(network, env,
22 | seed=None,
23 | total_timesteps=None,
24 | nb_epochs=None, # with default settings, perform 1M steps total
25 | nb_epoch_cycles=20,
26 | nb_rollout_steps=100,
27 | reward_scale=1.0,
28 | render=False,
29 | render_eval=False,
30 | noise_type='adaptive-param_0.2',
31 | normalize_returns=False,
32 | normalize_observations=True,
33 | critic_l2_reg=1e-2,
34 | actor_lr=1e-4,
35 | critic_lr=3e-4,
36 | popart=False,
37 | gamma=0.99,
38 | clip_norm=None,
39 | nb_train_steps=50, # per epoch cycle and MPI worker,
40 | nb_eval_steps=100,
41 | batch_size=128, # per MPI worker
42 | tau=0.01,
43 | eval_env=None,
44 | param_noise_adaption_interval=50,
45 | load_path = None,
46 | save_path = './model/',
47 | **network_kwargs):
48 |
49 |     print("Save PATH: {}".format(save_path))
50 |     print("Load PATH: {}".format(load_path))
51 | set_global_seeds(seed)
52 |
53 | if total_timesteps is not None:
54 | assert nb_epochs is None
55 | nb_epochs = int(total_timesteps) // (nb_epoch_cycles * nb_rollout_steps)
56 | else:
57 | nb_epochs = 500
58 |
59 | if MPI is not None:
60 | rank = MPI.COMM_WORLD.Get_rank()
61 | else:
62 | rank = 0
63 |
64 | nb_actions = env.action_space.shape[-1]
65 | assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
66 |
67 | memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
68 | critic = Critic(network=network, **network_kwargs)
69 | actor = Actor(nb_actions, network=network, **network_kwargs)
70 |
71 | action_noise = None
72 | param_noise = None
73 | if noise_type is not None:
74 | for current_noise_type in noise_type.split(','):
75 | current_noise_type = current_noise_type.strip()
76 | if current_noise_type == 'none':
77 | pass
78 | elif 'adaptive-param' in current_noise_type:
79 | _, stddev = current_noise_type.split('_')
80 | param_noise = AdaptiveParamNoiseSpec(initial_stddev=float(stddev), desired_action_stddev=float(stddev))
81 | elif 'normal' in current_noise_type:
82 | _, stddev = current_noise_type.split('_')
83 | action_noise = NormalActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
84 | elif 'ou' in current_noise_type:
85 | _, stddev = current_noise_type.split('_')
86 | action_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions))
87 | else:
88 | raise RuntimeError('unknown noise type "{}"'.format(current_noise_type))
89 |
90 | max_action = env.action_space.high
91 | logger.info('scaling actions by {} before executing in env'.format(max_action))
92 |
93 | agent = DDPG(actor, critic, memory, env.observation_space.shape, env.action_space.shape,
94 | gamma=gamma, tau=tau, normalize_returns=normalize_returns, normalize_observations=normalize_observations,
95 | batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
96 | actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
97 | reward_scale=reward_scale)
98 | logger.info('Using agent with the following configuration:')
99 | logger.info(str(agent.__dict__.items()))
100 |
101 | eval_episode_rewards_history = deque(maxlen=100)
102 | episode_rewards_history = deque(maxlen=100)
103 | sess = U.get_session()
104 | # Prepare everything.
105 | agent.initialize(sess)
106 | checkpoint_num = 0
107 | if load_path is not None:
108 | print("Loading model from ", load_path)
109 | agent.load(load_path)
110 | # Get the model number and assign a newer number, so that training done after loading this model will be saved
111 | # with a new number that reflects the model has been updated.
112 | checkpoint_num = int(os.path.split(load_path)[1].split("_",3)[-1]) + 1
113 | sess.graph.finalize()
114 |
115 | agent.reset()
116 |
117 | obs = env.reset()
118 | if eval_env is not None:
119 | eval_obs = eval_env.reset()
120 | nenvs = obs.shape[0]
121 |
122 | episode_reward = np.zeros(nenvs, dtype = np.float32) #vector
123 | episode_step = np.zeros(nenvs, dtype = int) # vector
124 | episodes = 0 #scalar
125 | t = 0 # scalar
126 |
127 | epoch = 0
128 |
129 |
130 |
131 | start_time = time.time()
132 |
133 | epoch_episode_rewards = []
134 | epoch_episode_steps = []
135 | epoch_actions = []
136 | epoch_qs = []
137 | epoch_episodes = 0
138 | if load_path is None:
139 | os.makedirs(save_path, exist_ok=True)
140 | for epoch in range(nb_epochs):
141 | for cycle in range(nb_epoch_cycles):
142 | # Perform rollouts.
143 | if nenvs > 1:
144 | # if simulating multiple envs in parallel, impossible to reset agent at the end of the episode in each
145 | # of the environments, so resetting here instead
146 | agent.reset()
147 | for t_rollout in range(nb_rollout_steps):
148 | # Predict next action.
149 | action, q, _, _ = agent.step(obs, apply_noise=True, compute_Q=True)
150 |
151 | # Execute next action.
152 | if rank == 0 and render:
153 | env.render()
154 |
155 | # max_action is of dimension A, whereas action is dimension (nenvs, A) - the multiplication gets broadcasted to the batch
156 | new_obs, r, done, info = env.step(max_action * action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
157 | # note these outputs are batched from vecenv
158 |
159 | t += 1
160 | if rank == 0 and render:
161 | env.render()
162 | episode_reward += r
163 | episode_step += 1
164 |
165 | # Book-keeping.
166 | epoch_actions.append(action)
167 | epoch_qs.append(q)
168 | agent.store_transition(obs, action, r, new_obs, done) #the batched data will be unrolled in memory.py's append.
169 |
170 | obs = new_obs
171 |
172 | for d in range(len(done)):
173 | if done[d]:
174 | # Episode done.
175 | epoch_episode_rewards.append(episode_reward[d])
176 | episode_rewards_history.append(episode_reward[d])
177 | epoch_episode_steps.append(episode_step[d])
178 | episode_reward[d] = 0.
179 | episode_step[d] = 0
180 | epoch_episodes += 1
181 | episodes += 1
182 | if nenvs == 1:
183 | agent.reset()
184 |
185 |
186 |
187 | # Train.
188 | epoch_actor_losses = []
189 | epoch_critic_losses = []
190 | epoch_adaptive_distances = []
191 | for t_train in range(nb_train_steps):
192 | # Adapt param noise, if necessary.
193 | if memory.nb_entries >= batch_size and t_train % param_noise_adaption_interval == 0:
194 | distance = agent.adapt_param_noise()
195 | epoch_adaptive_distances.append(distance)
196 |
197 | cl, al = agent.train()
198 | epoch_critic_losses.append(cl)
199 | epoch_actor_losses.append(al)
200 | agent.update_target_net()
201 |
202 | # Evaluate.
203 | eval_episode_rewards = []
204 | eval_qs = []
205 | if eval_env is not None:
206 | nenvs_eval = eval_obs.shape[0]
207 | eval_episode_reward = np.zeros(nenvs_eval, dtype = np.float32)
208 | for t_rollout in range(nb_eval_steps):
209 | eval_action, eval_q, _, _ = agent.step(eval_obs, apply_noise=False, compute_Q=True)
210 | eval_obs, eval_r, eval_done, eval_info = eval_env.step(max_action * eval_action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
211 | if render_eval:
212 | eval_env.render()
213 | eval_episode_reward += eval_r
214 |
215 | eval_qs.append(eval_q)
216 | for d in range(len(eval_done)):
217 | if eval_done[d]:
218 | eval_episode_rewards.append(eval_episode_reward[d])
219 | eval_episode_rewards_history.append(eval_episode_reward[d])
220 | eval_episode_reward[d] = 0.0
221 |
222 | if MPI is not None:
223 | mpi_size = MPI.COMM_WORLD.Get_size()
224 | else:
225 | mpi_size = 1
226 |
227 | # Log stats.
228 | # XXX shouldn't call np.mean on variable length lists
229 | duration = time.time() - start_time
230 | stats = agent.get_stats()
231 | combined_stats = stats.copy()
232 | combined_stats['rollout/return'] = np.mean(epoch_episode_rewards)
233 | combined_stats['rollout/return_history'] = np.mean(episode_rewards_history)
234 | combined_stats['rollout/episode_steps'] = np.mean(epoch_episode_steps)
235 | combined_stats['rollout/actions_mean'] = np.mean(epoch_actions)
236 | combined_stats['rollout/Q_mean'] = np.mean(epoch_qs)
237 | combined_stats['train/loss_actor'] = np.mean(epoch_actor_losses)
238 | combined_stats['train/loss_critic'] = np.mean(epoch_critic_losses)
239 | combined_stats['train/param_noise_distance'] = np.mean(epoch_adaptive_distances)
240 | combined_stats['total/duration'] = duration
241 | combined_stats['total/steps_per_second'] = float(t) / float(duration)
242 | combined_stats['total/episodes'] = episodes
243 | combined_stats['rollout/episodes'] = epoch_episodes
244 | combined_stats['rollout/actions_std'] = np.std(epoch_actions)
245 | # Evaluation statistics.
246 | if eval_env is not None:
247 | combined_stats['eval/return'] = eval_episode_rewards
248 | combined_stats['eval/return_history'] = np.mean(eval_episode_rewards_history)
249 | combined_stats['eval/Q'] = eval_qs
250 | combined_stats['eval/episodes'] = len(eval_episode_rewards)
251 | def as_scalar(x):
252 | if isinstance(x, np.ndarray):
253 | assert x.size == 1
254 | return x[0]
255 | elif np.isscalar(x):
256 | return x
257 | else:
258 | raise ValueError('expected scalar, got %s'%x)
259 |
260 | combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
261 | if MPI is not None:
262 | combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)
263 |
264 | combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
265 |
266 | # Total statistics.
267 | combined_stats['total/epochs'] = epoch + 1
268 | combined_stats['total/steps'] = t
269 |
270 | for key in sorted(combined_stats.keys()):
271 | logger.record_tabular(key, combined_stats[key])
272 |
273 | if rank == 0:
274 | logger.dump_tabular()
275 | logger.info('')
276 | logdir = logger.get_dir()
277 | if rank == 0 and logdir:
278 | if hasattr(env, 'get_state'):
279 | with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as f:
280 | pickle.dump(env.get_state(), f)
281 | if eval_env and hasattr(eval_env, 'get_state'):
282 | with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
283 | pickle.dump(eval_env.get_state(), f)
284 |
285 | savepath = os.path.join(save_path, "DDPG_trained_model_" + str(epoch + checkpoint_num))
286 | print('Model saved to ', savepath)
287 | print('\n')
288 | agent.save(savepath)
289 |
290 |
291 | return agent
292 |
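# Note (illustrative, not part of the original file): the noise_type argument
# above is a comma-separated spec parsed in the loop near the top of learn().
# For example, 'adaptive-param_0.2,normal_0.1' enables parameter-space noise
# with stddev 0.2 plus additive Gaussian action noise with stddev 0.1, while
# 'ou_0.2' selects Ornstein-Uhlenbeck action noise and 'none' disables noise.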
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/ddpg_learner.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 | from functools import reduce
3 |
4 | import functools
5 | import numpy as np
6 | import tensorflow as tf
7 | import tensorflow.contrib as tc
8 |
9 | from baselines import logger
10 | from baselines.common.mpi_adam import MpiAdam
11 | import baselines.common.tf_util as U
12 | from baselines.common.mpi_running_mean_std import RunningMeanStd
13 | from baselines.common.tf_util import save_variables, load_variables
14 | try:
15 | from mpi4py import MPI
16 | except ImportError:
17 | MPI = None
18 |
19 | def normalize(x, stats):
20 | if stats is None:
21 | return x
22 | return (x - stats.mean) / stats.std
23 |
24 |
25 | def denormalize(x, stats):
26 | if stats is None:
27 | return x
28 | return x * stats.std + stats.mean
29 |
30 | def reduce_std(x, axis=None, keepdims=False):
31 | return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims))
32 |
33 | def reduce_var(x, axis=None, keepdims=False):
34 | m = tf.reduce_mean(x, axis=axis, keepdims=True)
35 | devs_squared = tf.square(x - m)
36 | return tf.reduce_mean(devs_squared, axis=axis, keepdims=keepdims)
37 |
38 | def get_target_updates(vars, target_vars, tau):
39 | logger.info('setting up target updates ...')
40 | soft_updates = []
41 | init_updates = []
42 | assert len(vars) == len(target_vars)
43 | for var, target_var in zip(vars, target_vars):
44 | logger.info(' {} <- {}'.format(target_var.name, var.name))
45 | init_updates.append(tf.assign(target_var, var))
46 | soft_updates.append(tf.assign(target_var, (1. - tau) * target_var + tau * var))
47 | assert len(init_updates) == len(vars)
48 | assert len(soft_updates) == len(vars)
49 | return tf.group(*init_updates), tf.group(*soft_updates)
50 |
51 |
52 | def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
53 | assert len(actor.vars) == len(perturbed_actor.vars)
54 | assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)
55 |
56 | updates = []
57 | for var, perturbed_var in zip(actor.vars, perturbed_actor.vars):
58 | if var in actor.perturbable_vars:
59 | logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
60 | updates.append(tf.assign(perturbed_var, var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev)))
61 | else:
62 | logger.info(' {} <- {}'.format(perturbed_var.name, var.name))
63 | updates.append(tf.assign(perturbed_var, var))
64 | assert len(updates) == len(actor.vars)
65 | return tf.group(*updates)
66 |
67 |
68 | class DDPG(object):
69 | def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
70 | gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
71 | batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
72 | critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
73 | # Inputs.
74 | self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
75 | self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1')
76 | self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
77 | self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
78 | self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions')
79 | self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
80 | self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')
81 |
82 | # Parameters.
83 | self.gamma = gamma
84 | self.tau = tau
85 | self.memory = memory
86 | self.normalize_observations = normalize_observations
87 | self.normalize_returns = normalize_returns
88 | self.action_noise = action_noise
89 | self.param_noise = param_noise
90 | self.action_range = action_range
91 | self.return_range = return_range
92 | self.observation_range = observation_range
93 | self.critic = critic
94 | self.actor = actor
95 | self.actor_lr = actor_lr
96 | self.critic_lr = critic_lr
97 | self.clip_norm = clip_norm
98 | self.enable_popart = enable_popart
99 | self.reward_scale = reward_scale
100 | self.batch_size = batch_size
101 | self.stats_sample = None
102 | self.critic_l2_reg = critic_l2_reg
103 | self.save = None
104 | self.load = None
105 |
106 | # Observation normalization.
107 | if self.normalize_observations:
108 | with tf.variable_scope('obs_rms'):
109 | self.obs_rms = RunningMeanStd(shape=observation_shape)
110 | else:
111 | self.obs_rms = None
112 | normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
113 | self.observation_range[0], self.observation_range[1])
114 | normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
115 | self.observation_range[0], self.observation_range[1])
116 |
117 | # Return normalization.
118 | if self.normalize_returns:
119 | with tf.variable_scope('ret_rms'):
120 | self.ret_rms = RunningMeanStd()
121 | else:
122 | self.ret_rms = None
123 |
124 | # Create target networks.
125 | target_actor = copy(actor)
126 | target_actor.name = 'target_actor'
127 | self.target_actor = target_actor
128 | target_critic = copy(critic)
129 | target_critic.name = 'target_critic'
130 | self.target_critic = target_critic
131 |
132 | # Create networks and core TF parts that are shared across setup parts.
133 | self.actor_tf = actor(normalized_obs0)
134 | self.normalized_critic_tf = critic(normalized_obs0, self.actions)
135 | self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
136 | self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
137 | self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
138 | Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
139 | self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1
140 |
141 | # Set up parts.
142 | if self.param_noise is not None:
143 | self.setup_param_noise(normalized_obs0)
144 | self.setup_actor_optimizer()
145 | self.setup_critic_optimizer()
146 | if self.normalize_returns and self.enable_popart:
147 | self.setup_popart()
148 | self.setup_stats()
149 | self.setup_target_network_updates()
150 |
151 | self.initial_state = None # recurrent architectures not supported yet
152 |
153 | def setup_target_network_updates(self):
154 | actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau)
155 | critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau)
156 | self.target_init_updates = [actor_init_updates, critic_init_updates]
157 | self.target_soft_updates = [actor_soft_updates, critic_soft_updates]
158 |
159 | def setup_param_noise(self, normalized_obs0):
160 | assert self.param_noise is not None
161 |
162 | # Configure perturbed actor.
163 | param_noise_actor = copy(self.actor)
164 | param_noise_actor.name = 'param_noise_actor'
165 | self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
166 | logger.info('setting up param noise')
167 | self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)
168 |
169 | # Configure separate copy for stddev adoption.
170 | adaptive_param_noise_actor = copy(self.actor)
171 | adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
172 | adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
173 | self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
174 | self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))
175 |
176 | def setup_actor_optimizer(self):
177 | logger.info('setting up actor optimizer')
178 | self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)
179 | actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars]
180 | actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
181 | logger.info(' actor shapes: {}'.format(actor_shapes))
182 | logger.info(' actor params: {}'.format(actor_nb_params))
183 | self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm)
184 | self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
185 | beta1=0.9, beta2=0.999, epsilon=1e-08)
186 |
187 | def setup_critic_optimizer(self):
188 | logger.info('setting up critic optimizer')
189 | normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
190 | self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
191 | if self.critic_l2_reg > 0.:
192 | critic_reg_vars = [var for var in self.critic.trainable_vars if var.name.endswith('/w:0') and 'output' not in var.name]
193 | for var in critic_reg_vars:
194 | logger.info(' regularizing: {}'.format(var.name))
195 | logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
196 | critic_reg = tc.layers.apply_regularization(
197 | tc.layers.l2_regularizer(self.critic_l2_reg),
198 | weights_list=critic_reg_vars
199 | )
200 | self.critic_loss += critic_reg
201 | critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars]
202 | critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
203 | logger.info(' critic shapes: {}'.format(critic_shapes))
204 | logger.info(' critic params: {}'.format(critic_nb_params))
205 | self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm)
206 | self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
207 | beta1=0.9, beta2=0.999, epsilon=1e-08)
208 |
209 | def setup_popart(self):
210 | # See https://arxiv.org/pdf/1602.07714.pdf for details.
211 | self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
212 | new_std = self.ret_rms.std
213 | self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
214 | new_mean = self.ret_rms.mean
215 |
216 | self.renormalize_Q_outputs_op = []
217 | for vs in [self.critic.output_vars, self.target_critic.output_vars]:
218 | assert len(vs) == 2
219 | M, b = vs
220 | assert 'kernel' in M.name
221 | assert 'bias' in b.name
222 | assert M.get_shape()[-1] == 1
223 | assert b.get_shape()[-1] == 1
224 | self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
225 | self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]
226 |
227 | def setup_stats(self):
228 | ops = []
229 | names = []
230 |
231 | if self.normalize_returns:
232 | ops += [self.ret_rms.mean, self.ret_rms.std]
233 | names += ['ret_rms_mean', 'ret_rms_std']
234 |
235 | if self.normalize_observations:
236 | ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
237 | names += ['obs_rms_mean', 'obs_rms_std']
238 |
239 | ops += [tf.reduce_mean(self.critic_tf)]
240 | names += ['reference_Q_mean']
241 | ops += [reduce_std(self.critic_tf)]
242 | names += ['reference_Q_std']
243 |
244 | ops += [tf.reduce_mean(self.critic_with_actor_tf)]
245 | names += ['reference_actor_Q_mean']
246 | ops += [reduce_std(self.critic_with_actor_tf)]
247 | names += ['reference_actor_Q_std']
248 |
249 | ops += [tf.reduce_mean(self.actor_tf)]
250 | names += ['reference_action_mean']
251 | ops += [reduce_std(self.actor_tf)]
252 | names += ['reference_action_std']
253 |
254 | if self.param_noise:
255 | ops += [tf.reduce_mean(self.perturbed_actor_tf)]
256 | names += ['reference_perturbed_action_mean']
257 | ops += [reduce_std(self.perturbed_actor_tf)]
258 | names += ['reference_perturbed_action_std']
259 |
260 | self.stats_ops = ops
261 | self.stats_names = names
262 |
263 | def step(self, obs, apply_noise=True, compute_Q=True):
264 | if self.param_noise is not None and apply_noise:
265 | actor_tf = self.perturbed_actor_tf
266 | else:
267 | actor_tf = self.actor_tf
268 | feed_dict = {self.obs0: U.adjust_shape(self.obs0, [obs])}
269 | if compute_Q:
270 | action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
271 | else:
272 | action = self.sess.run(actor_tf, feed_dict=feed_dict)
273 | q = None
274 |
275 | if self.action_noise is not None and apply_noise:
276 | noise = self.action_noise()
277 | assert noise.shape == action[0].shape
278 | action += noise
279 | action = np.clip(action, self.action_range[0], self.action_range[1])
280 |
281 |
282 | return action, q, None, None
283 |
284 | def store_transition(self, obs0, action, reward, obs1, terminal1):
285 | reward *= self.reward_scale
286 |
287 | B = obs0.shape[0]
288 | for b in range(B):
289 | self.memory.append(obs0[b], action[b], reward[b], obs1[b], terminal1[b])
290 | if self.normalize_observations:
291 | self.obs_rms.update(np.array([obs0[b]]))
292 |
293 | def train(self):
294 | # Get a batch.
295 | batch = self.memory.sample(batch_size=self.batch_size)
296 |
297 | if self.normalize_returns and self.enable_popart:
298 | old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict={
299 | self.obs1: batch['obs1'],
300 | self.rewards: batch['rewards'],
301 | self.terminals1: batch['terminals1'].astype('float32'),
302 | })
303 | self.ret_rms.update(target_Q.flatten())
304 | self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
305 | self.old_std : np.array([old_std]),
306 | self.old_mean : np.array([old_mean]),
307 | })
308 |
309 | # Run sanity check. Disabled by default since it slows down things considerably.
310 | # print('running sanity check')
311 | # target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
312 | # self.obs1: batch['obs1'],
313 | # self.rewards: batch['rewards'],
314 | # self.terminals1: batch['terminals1'].astype('float32'),
315 | # })
316 | # print(target_Q_new, target_Q, new_mean, new_std)
317 | # assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
318 | else:
319 | target_Q = self.sess.run(self.target_Q, feed_dict={
320 | self.obs1: batch['obs1'],
321 | self.rewards: batch['rewards'],
322 | self.terminals1: batch['terminals1'].astype('float32'),
323 | })
324 |
325 | # Get all gradients and perform a synced update.
326 | ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
327 | actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
328 | self.obs0: batch['obs0'],
329 | self.actions: batch['actions'],
330 | self.critic_target: target_Q,
331 | })
332 | self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
333 | self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)
334 |
335 | return critic_loss, actor_loss
336 |
337 | def initialize(self, sess):
338 | self.sess = sess
339 | self.sess.run(tf.global_variables_initializer())
340 | self.save = functools.partial(save_variables, sess=self.sess)
341 |         self.load = functools.partial(load_variables, sess=self.sess)
342 | self.actor_optimizer.sync()
343 | self.critic_optimizer.sync()
344 | self.sess.run(self.target_init_updates)
345 |
346 | def update_target_net(self):
347 | self.sess.run(self.target_soft_updates)
348 |
349 | def get_stats(self):
350 | if self.stats_sample is None:
351 | # Get a sample and keep that fixed for all further computations.
352 | # This allows us to estimate the change in value for the same set of inputs.
353 | self.stats_sample = self.memory.sample(batch_size=self.batch_size)
354 | values = self.sess.run(self.stats_ops, feed_dict={
355 | self.obs0: self.stats_sample['obs0'],
356 | self.actions: self.stats_sample['actions'],
357 | })
358 |
359 | names = self.stats_names[:]
360 | assert len(names) == len(values)
361 | stats = dict(zip(names, values))
362 |
363 | if self.param_noise is not None:
364 | stats = {**stats, **self.param_noise.get_stats()}
365 |
366 | return stats
367 |
368 | def adapt_param_noise(self):
369 | try:
370 | from mpi4py import MPI
371 | except ImportError:
372 | MPI = None
373 |
374 | if self.param_noise is None:
375 | return 0.
376 |
377 | # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
378 | batch = self.memory.sample(batch_size=self.batch_size)
379 | self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
380 | self.param_noise_stddev: self.param_noise.current_stddev,
381 | })
382 | distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
383 | self.obs0: batch['obs0'],
384 | self.param_noise_stddev: self.param_noise.current_stddev,
385 | })
386 |
387 |         if MPI is not None:
388 |             mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
389 |         else:
390 |             mean_distance = distance
391 |
392 |         self.param_noise.adapt(mean_distance)
393 |         return mean_distance
394 |
395 |     def reset(self):
396 |         # Reset internal state after an episode is complete.
397 |         if self.action_noise is not None:
398 |             self.action_noise.reset()
399 |         if self.param_noise is not None:
400 |             self.sess.run(self.perturb_policy_ops, feed_dict={
401 |                 self.param_noise_stddev: self.param_noise.current_stddev,
402 |             })
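# Note (illustrative, not part of the original file): get_target_updates above
# implements the soft update  target <- (1 - tau) * target + tau * online.
# With the default tau = 0.001 each call moves the target network about 0.1%
# of the way toward the online weights, keeping the critic's bootstrapped
# targets slowly moving and training stable.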
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/memory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class RingBuffer(object):
5 | def __init__(self, maxlen, shape, dtype='float32'):
6 | self.maxlen = maxlen
7 | self.start = 0
8 | self.length = 0
9 | self.data = np.zeros((maxlen,) + shape).astype(dtype)
10 |
11 | def __len__(self):
12 | return self.length
13 |
14 | def __getitem__(self, idx):
15 | if idx < 0 or idx >= self.length:
16 | raise KeyError()
17 | return self.data[(self.start + idx) % self.maxlen]
18 |
19 | def get_batch(self, idxs):
20 | return self.data[(self.start + idxs) % self.maxlen]
21 |
22 | def append(self, v):
23 | if self.length < self.maxlen:
24 | # We have space, simply increase the length.
25 | self.length += 1
26 | elif self.length == self.maxlen:
27 | # No space, "remove" the first item.
28 | self.start = (self.start + 1) % self.maxlen
29 | else:
30 | # This should never happen.
31 | raise RuntimeError()
32 | self.data[(self.start + self.length - 1) % self.maxlen] = v
33 |
34 |
35 | def array_min2d(x):
36 | x = np.array(x)
37 | if x.ndim >= 2:
38 | return x
39 | return x.reshape(-1, 1)
40 |
41 |
42 | class Memory(object):
43 | def __init__(self, limit, action_shape, observation_shape):
44 | self.limit = limit
45 |
46 | self.observations0 = RingBuffer(limit, shape=observation_shape)
47 | self.actions = RingBuffer(limit, shape=action_shape)
48 | self.rewards = RingBuffer(limit, shape=(1,))
49 | self.terminals1 = RingBuffer(limit, shape=(1,))
50 | self.observations1 = RingBuffer(limit, shape=observation_shape)
51 |
52 | def sample(self, batch_size):
53 |         # Draw such that we always have a succeeding element.
54 | batch_idxs = np.random.randint(self.nb_entries - 2, size=batch_size)
55 |
56 | obs0_batch = self.observations0.get_batch(batch_idxs)
57 | obs1_batch = self.observations1.get_batch(batch_idxs)
58 | action_batch = self.actions.get_batch(batch_idxs)
59 | reward_batch = self.rewards.get_batch(batch_idxs)
60 | terminal1_batch = self.terminals1.get_batch(batch_idxs)
61 |
62 | result = {
63 | 'obs0': array_min2d(obs0_batch),
64 | 'obs1': array_min2d(obs1_batch),
65 | 'rewards': array_min2d(reward_batch),
66 | 'actions': array_min2d(action_batch),
67 | 'terminals1': array_min2d(terminal1_batch),
68 | }
69 | return result
70 |
71 | def append(self, obs0, action, reward, obs1, terminal1, training=True):
72 | if not training:
73 | return
74 |
75 | self.observations0.append(obs0)
76 | self.actions.append(action)
77 | self.rewards.append(reward)
78 | self.observations1.append(obs1)
79 | self.terminals1.append(terminal1)
80 |
81 | @property
82 | def nb_entries(self):
83 | return len(self.observations0)
84 |
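# Illustrative usage (not part of the original file): RingBuffer overwrites its
# oldest entry once maxlen is reached, so Memory keeps only the newest
# transitions. A minimal check of that eviction behaviour:
#
#   buf = RingBuffer(maxlen=3, shape=())
#   for v in [1, 2, 3, 4]:
#       buf.append(v)
#   assert len(buf) == 3 and buf[0] == 2   # the oldest value (1) was evicted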
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from baselines.common.models import get_network_builder
3 |
4 |
5 | class Model(object):
6 | def __init__(self, name, network='mlp', **network_kwargs):
7 | self.name = name
8 | self.network_builder = get_network_builder(network)(**network_kwargs)
9 |
10 | @property
11 | def vars(self):
12 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
13 |
14 | @property
15 | def trainable_vars(self):
16 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
17 |
18 | @property
19 | def perturbable_vars(self):
20 | return [var for var in self.trainable_vars if 'LayerNorm' not in var.name]
21 |
22 |
23 | class Actor(Model):
24 | def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs):
25 | super().__init__(name=name, network=network, **network_kwargs)
26 | self.nb_actions = nb_actions
27 |
28 | def __call__(self, obs, reuse=False):
29 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
30 | x = self.network_builder(obs)
31 | x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
32 | x = tf.nn.tanh(x)
33 | return x
34 |
35 |
36 | class Critic(Model):
37 | def __init__(self, name='critic', network='mlp', **network_kwargs):
38 | super().__init__(name=name, network=network, **network_kwargs)
39 | self.layer_norm = True
40 |
41 | def __call__(self, obs, action, reuse=False):
42 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
43 | x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated
44 | x = self.network_builder(x)
45 | x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3), name='output')
46 | return x
47 |
48 | @property
49 | def output_vars(self):
50 | output_vars = [var for var in self.trainable_vars if 'output' in var.name]
51 | return output_vars
52 |
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/noise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class AdaptiveParamNoiseSpec(object):
5 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01):
6 | self.initial_stddev = initial_stddev
7 | self.desired_action_stddev = desired_action_stddev
8 | self.adoption_coefficient = adoption_coefficient
9 |
10 | self.current_stddev = initial_stddev
11 |
12 | def adapt(self, distance):
13 | if distance > self.desired_action_stddev:
14 | # Decrease stddev.
15 | self.current_stddev /= self.adoption_coefficient
16 | else:
17 | # Increase stddev.
18 | self.current_stddev *= self.adoption_coefficient
19 |
20 | def get_stats(self):
21 | stats = {
22 | 'param_noise_stddev': self.current_stddev,
23 | }
24 | return stats
25 |
26 | def __repr__(self):
27 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})'
28 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient)
29 |
30 |
31 | class ActionNoise(object):
32 | def reset(self):
33 | pass
34 |
35 |
36 | class NormalActionNoise(ActionNoise):
37 | def __init__(self, mu, sigma):
38 | self.mu = mu
39 | self.sigma = sigma
40 |
41 | def __call__(self):
42 | return np.random.normal(self.mu, self.sigma)
43 |
44 | def __repr__(self):
45 | return 'NormalActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma)
46 |
47 |
48 | # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
49 | class OrnsteinUhlenbeckActionNoise(ActionNoise):
50 | def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None):
51 | self.theta = theta
52 | self.mu = mu
53 | self.sigma = sigma
54 | self.dt = dt
55 | self.x0 = x0
56 | self.reset()
57 |
58 | def __call__(self):
59 | x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
60 | self.x_prev = x
61 | return x
62 |
63 | def reset(self):
64 | self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu)
65 |
66 | def __repr__(self):
67 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma)
68 |
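# Illustrative usage (not part of the original file): __call__ above discretizes
# dx = theta * (mu - x) * dt + sigma * sqrt(dt) * dW, so successive samples are
# temporally correlated and decay back toward mu.
#
#   ou = OrnsteinUhlenbeckActionNoise(mu=np.zeros(2), sigma=0.2 * np.ones(2))
#   a = ou()      # correlated with the previous sample via x_prev
#   ou.reset()    # start a fresh trajectory from x0 (or zeros)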
--------------------------------------------------------------------------------
/baselines/baselines/ddpg/test_smoke.py:
--------------------------------------------------------------------------------
1 | from baselines.run import main as M
2 |
3 | def _run(argstr):
4 | M(('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
5 |
6 | def test_popart():
7 | _run('--normalize_returns=True --popart=True')
8 |
9 | def test_noise_normal():
10 | _run('--noise_type=normal_0.1')
11 |
12 | def test_noise_ou():
13 | _run('--noise_type=ou_0.1')
14 |
15 | def test_noise_adaptive():
16 | _run('--noise_type=adaptive-param_0.2,normal_0.1')
17 |
18 |
--------------------------------------------------------------------------------
/compare.py:
--------------------------------------------------------------------------------
1 | # --------------------------- IMPORT LIBRARIES -------------------------
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | from datetime import datetime
6 | import data_preprocessing as dp
7 | from sklearn.preprocessing import MinMaxScaler
8 |
9 | import keras
10 | from keras.models import Sequential
11 | from keras.layers.recurrent import LSTM
12 | from keras.callbacks import ModelCheckpoint, EarlyStopping
13 | from keras.models import load_model
14 | from keras.layers import Dense, Dropout
15 |
16 | # ------------------------- GLOBAL PARAMETERS -------------------------
17 | # Start and end period of historical data in question
18 | START_TRAIN = datetime(2008, 12, 31)
19 | END_TRAIN = datetime(2017, 2, 12)
20 | START_TEST = datetime(2017, 2, 12)
21 | END_TEST = datetime(2019, 2, 22)
22 |
23 | STARTING_ACC_BALANCE = 100000
24 | NUMBER_NON_CORR_STOCKS = 5
25 | # Number of epochs with no improvement before training is stopped.
26 | PATIENCE = 30
27 |
28 | # Pools of stocks to trade
29 | DJI = ['MMM', 'AXP', 'AAPL', 'BA', 'CAT', 'CVX', 'CSCO', 'KO', 'DIS', 'XOM', 'GE', 'GS', 'HD', 'IBM', 'INTC', 'JNJ',
30 | 'JPM', 'MCD', 'MRK', 'MSFT', 'NKE', 'PFE', 'PG', 'UTX', 'UNH', 'VZ', 'WMT']
31 |
32 | DJI_N = ['3M', 'American Express', 'Apple', 'Boeing', 'Caterpillar', 'Chevron', 'Cisco Systems', 'Coca-Cola', 'Disney'
33 | , 'ExxonMobil', 'General Electric', 'Goldman Sachs', 'Home Depot', 'IBM', 'Intel', 'Johnson & Johnson',
34 | 'JPMorgan Chase', 'McDonalds', 'Merck', 'Microsoft', 'NIKE', 'Pfizer', 'Procter & Gamble',
35 | 'United Technologies', 'UnitedHealth Group', 'Verizon Communications', 'Wal Mart']
36 |
37 | # Market and macroeconomic data to be used as context data
38 | CONTEXT_DATA = ['^GSPC', '^DJI', '^IXIC', '^RUT', 'SPY', 'QQQ', '^VIX', 'GLD', '^TYX', '^TNX', 'SHY', 'SHV']
39 |
40 | # --------------------------------- CLASSES ------------------------------------
41 | class Trading:
42 | def __init__(self, recovered_data_lstm, portfolio_stock_price, portfolio_stock_volume, test_set, non_corr_stocks):
43 | self.test_set = test_set
44 | self.ncs = non_corr_stocks
45 | self.stock_price = portfolio_stock_price
46 | self.stock_volume = portfolio_stock_volume
47 | self.generate_signals(recovered_data_lstm)
48 |
49 | def generate_signals(self, predicted_tomorrow_close):
50 | """
51 |         Generate trade signals from the predictions of the LSTM model.
52 |         :param predicted_tomorrow_close: dataframe of predicted next-day closing prices
53 | :return:
54 | """
55 |
56 | predicted_tomorrow_close.columns = self.stock_price.columns
57 | predicted_next_day_returns = (predicted_tomorrow_close / predicted_tomorrow_close.shift(1) - 1).dropna()
58 | next_day_returns = (self.stock_price / self.stock_price.shift(1) - 1).dropna()
59 | signals = pd.DataFrame(index=predicted_tomorrow_close.index, columns=self.stock_price.columns)
60 |
61 | for s in self.stock_price.columns:
62 | for d in next_day_returns.index:
63 | if predicted_tomorrow_close[s].loc[d] > self.stock_price[s].loc[d] and next_day_returns[s].loc[
64 | d] > 0 and predicted_next_day_returns[s].loc[d] > 0:
65 | signals[s].loc[d] = 2
66 | elif predicted_tomorrow_close[s].loc[d] < self.stock_price[s].loc[d] and next_day_returns[s].loc[
67 | d] < 0 and predicted_next_day_returns[s].loc[d] < 0:
68 | signals[s].loc[d] = -2
69 | elif predicted_tomorrow_close[s].loc[d] > self.stock_price[s].loc[d]:
70 | signals[s].loc[d] = 2
71 | elif next_day_returns[s].loc[d] > 0:
72 | signals[s].loc[d] = 1
73 | elif next_day_returns[s].loc[d] < 0:
74 | signals[s].loc[d] = -1
75 | elif predicted_next_day_returns[s].loc[d] > 0:
76 | signals[s].loc[d] = 2
77 | elif predicted_next_day_returns[s].loc[d] < 0:
78 | signals[s].loc[d] = -1
79 | else:
80 | signals[s].loc[d] = 0
81 |         signals.loc[self.stock_price.index[0]] = [0] * len(self.stock_price.columns)  # force a no-trade first day
82 | self.signals = signals
83 |
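    # Reading aid (added comment, not original code): signals take values in
    # {-2, -1, 0, 1, 2}. Positive means buy, negative means sell; magnitude 2
    # is assigned when the price-level or return prediction itself points
    # up/down, magnitude 1 when only a single return signal does. 0 means
    # hold, and the first trading day is forced to all-zero signals.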
84 | def _sell(self, stock, sig, day):
85 | """
86 | Perform and record sell transactions.
87 | """
88 |
89 | # Get the index of the stock
90 | idx = self.ncs.index(stock)
91 |
92 | # Only need to sell the unit recommended by the trading agent, not necessarily all stock unit.
93 |         num_share = min(abs(int(sig)), self.state[idx + 2])  # holdings start at state[2]
94 | commission = dp.Trading.commission(num_share, self.stock_price.loc[day][stock])
95 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage
96 | transacted_price = dp.Trading.slippage_price(self.stock_price.loc[day][stock], -num_share,
97 | self.stock_volume.loc[day][stock])
98 |
99 | # If there is existing stock holding
100 |         if self.state[idx + 2] > 0:
101 |             # Only need to sell the units recommended by the trading agent, not necessarily the whole holding.
102 |             # Update account balance after transaction
103 |             self.state[0] += (transacted_price * num_share) - commission
104 |             # Update stock holding
105 |             self.state[idx + 2] -= num_share
106 |             # Reset the recorded buy price to 0.0 once the holding is fully sold
107 |             if self.state[idx + 2] == 0.0:
108 | self.buy_price[idx] = 0.0
109 |
110 | else:
111 | pass
112 |
113 | def _buy(self, stock, sig, day):
114 | """
115 | Perform and record buy transactions.
116 | """
117 |
118 | idx = self.ncs.index(stock)
119 | # Calculate the maximum possible number of stock unit the current cash can buy
120 | available_unit = self.state[0] // self.stock_price.loc[day][stock]
121 | num_share = min(available_unit, int(sig))
122 | # Deduct the traded amount from account balance. If available balance is not enough to purchase stock unit
123 | # recommended by trading agent's action, just use what is left.
124 | commission = dp.Trading.commission(num_share, self.stock_price.loc[day][stock])
125 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage
126 | transacted_price = dp.Trading.slippage_price(self.stock_price.loc[day][stock], num_share,
127 | self.stock_volume.loc[day][stock])
128 | # Revise number of share to trade if account balance does not have enough
129 | if (self.state[0] - commission) < transacted_price * num_share:
130 | num_share = (self.state[0] - commission) // transacted_price
131 | self.state[0] -= (transacted_price * num_share) + commission
132 |
133 | # If there are existing stock holding already, calculate the average buy price
134 | if self.state[idx + 2] > 0.0:
135 | existing_unit = self.state[idx + 2]
136 | previous_buy_price = self.buy_price[idx]
137 |             additional_unit = num_share  # use the balance-adjusted share count computed above
138 | new_holding = existing_unit + additional_unit
139 | self.buy_price[idx] = ((existing_unit * previous_buy_price) + (
140 | self.stock_price.loc[day][stock] * additional_unit)) / new_holding
141 | # if there is no existing stock holding, simply record the current buy price
142 | elif self.state[idx + 2] == 0.0:
143 | self.buy_price[idx] = self.stock_price.loc[day][stock]
144 |
145 | # Update stock holding at its index
146 |         self.state[idx + 2] += num_share
147 |
148 | def execute_trading(self, non_corr_stocks):
149 | """
150 |         This function performs long-only trades for the LSTM model.
151 | """
152 |
153 | # The money in the trading account
154 | self.acc_balance = [STARTING_ACC_BALANCE]
155 | self.total_asset = self.acc_balance
156 | self.portfolio_asset = [0.0]
157 | self.buy_price = np.zeros((1, len(non_corr_stocks))).flatten()
158 | # Unrealized profit and loss
159 | self.unrealized_pnl = [0.0]
160 | # The value of all-stock holdings
161 | self.portfolio_value = 0.0
162 |
163 | # The state of the trading environment, defined by account balance, unrealized profit and loss, relevant
164 | # stock technical data & current stock holdings
165 | self.state = self.acc_balance + self.unrealized_pnl + [0 for i in range(len(non_corr_stocks))]
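        # Layout sketch (added comment): state[0] is the cash balance,
        # state[1] the unrealized P&L and state[2:] the per-stock holdings,
        # which is why the _sell/_buy helpers index holdings at idx + 2.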
166 |
167 | # Slide through the timeline
168 | for d in self.test_set.index[:-1]:
169 |
170 | signals = self.signals.loc[d]
171 |
172 | # Get the stocks to be sold
173 | sell_stocks = signals[signals < 0].sort_values(ascending=True)
174 | # Get the stocks to be bought
175 | buy_stocks = signals[signals > 0].sort_values(ascending=True)
176 |
177 | for idx, sig in enumerate(sell_stocks):
178 | self._sell(sell_stocks.index[idx], sig, d)
179 |
180 | for idx, sig in enumerate(buy_stocks):
181 | self._buy(buy_stocks.index[idx], sig, d)
182 |
183 | self.unrealized_pnl = np.sum(np.array(self.stock_price.loc[d] - self.buy_price) * np.array(
184 | self.state[2:]))
185 |
186 | # Current state space
187 | self.state = [self.state[0]] + [self.unrealized_pnl] + list(self.state[2:])
188 | # Portfolio value is the current stock prices multiply with their respective holdings
189 | portfolio_value = sum(np.array(self.stock_price.loc[d]) * np.array(self.state[2:]))
190 | # Total asset = account balance + portfolio value
191 | total_asset_ending = self.state[0] + portfolio_value
192 |
193 | # Update account balance statement
194 | self.acc_balance = np.append(self.acc_balance, self.state[0])
195 |
196 | # Update portfolio value statement
197 | self.portfolio_asset = np.append(self.portfolio_asset, portfolio_value)
198 |
199 | # Update total asset statement
200 | self.total_asset = np.append(self.total_asset, total_asset_ending)
201 |
202 | trading_book = pd.DataFrame(index=self.test_set.index,
203 | columns=["Cash balance", "Portfolio value", "Total asset", "Returns", "CumReturns"])
204 | trading_book["Cash balance"] = self.acc_balance
205 | trading_book["Portfolio value"] = self.portfolio_asset
206 | trading_book["Total asset"] = self.total_asset
207 | trading_book["Returns"] = trading_book["Total asset"] / trading_book["Total asset"].shift(1) - 1
208 | trading_book["CumReturns"] = trading_book["Returns"].add(1).cumprod().fillna(1)
209 | trading_book.to_csv('./test_result/trading_book_backtest.csv')
210 |
211 | kpi = dp.MathCalc.calc_kpi(trading_book)
212 | kpi.to_csv('./test_result/kpi_backtest.csv')
213 |
214 | print("\n")
215 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
216 | print(
217 | " KPI of RNN-LSTM modelled trading strategy for a portfolio of {} non-correlated stocks".format(
218 | NUMBER_NON_CORR_STOCKS))
219 | print(kpi)
220 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
221 |
222 | return trading_book, kpi
223 |
224 |
225 | class Data_ScaleSplit:
226 | """
227 |     This class preprocesses data for the LSTM model.
228 | """
229 |
230 | def __init__(self, X, selected_stocks_price, train_portion):
231 | self.X = X
232 | self.stock_price = selected_stocks_price
233 | self.generate_labels()
234 | self.scale_data()
235 | self.split_data(train_portion)
236 |
237 |
238 | def generate_labels(self):
239 | """
240 | Generate label data for tomorrow's prediction.
241 | """
242 | self.Y = self.stock_price.shift(-1)
243 | self.Y.columns = [c + '_Y' for c in self.Y.columns]
244 |
245 | def scale_data(self):
246 | """
247 |         Scale the X and Y data with a MinMax scaler.
248 |         The scaler is fitted on the train set only and applied to both sets, to avoid look-ahead bias.
249 | """
250 | self.XY = pd.concat([self.X, self.Y], axis=1).dropna()
251 | train_set = self.XY.loc[START_TRAIN:END_TRAIN]
252 | test_set = self.XY.loc[START_TEST:END_TEST]
253 | # MinMax scaling
254 | minmaxed_scaler = MinMaxScaler(feature_range=(0, 1))
255 | self.minmaxed = minmaxed_scaler.fit(train_set)
256 | train_set_matrix = minmaxed_scaler.transform(train_set)
257 | test_set_matrix = minmaxed_scaler.transform(test_set)
258 | self.train_set_matrix_df = pd.DataFrame(train_set_matrix, index=train_set.index, columns=train_set.columns)
259 | self.test_set_matrix_df = pd.DataFrame(test_set_matrix, index=test_set.index, columns=test_set.columns)
260 | self.XY = pd.concat([self.train_set_matrix_df, self.test_set_matrix_df], axis=0)
261 |
262 | # print ("Train set shape: ", train_set_matrix.shape)
263 | # print ("Test set shape: ", test_set_matrix.shape)
264 |
265 | def split_data(self, train_portion):
266 | """
267 |         Split into train and test sets at the given cut-off row (the length of the training period).
268 | """
269 | df_values = self.XY.values
270 | # split into train and test sets
271 |
272 | train = df_values[:int(train_portion), :]
273 | test = df_values[int(train_portion):, :]
274 | # split into input and outputs
275 | train_X, self.train_y = train[:, :-5], train[:, -5:]
276 | test_X, self.test_y = test[:, :-5], test[:, -5:]
277 | # reshape input to be 3D [samples, timesteps, features]
278 | self.train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1]))
279 | self.test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1]))
280 | print("\n")
281 | print("Dataset shapes >")
282 | print("Train feature data shape:", self.train_X.shape)
283 | print("Train label data shape:", self.train_y.shape)
284 | print("Test feature data shape:", self.test_X.shape)
285 | print("Test label data shape:", self.test_y.shape)
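        # Shape note (added comment): each sample carries a single timestep,
        # so train_X has shape (n_train_rows, 1, n_features) and train_y has
        # shape (n_train_rows, 5) -- one next-day close label per portfolio stock.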
286 |
287 | def get_prediction(self, model_lstm):
288 | """
289 | Get the model prediction, inverse transform scaling to get back to original price and
290 | reassemble the full XY dataframe.
291 | """
292 | # Get the model to predict test_y
293 |
294 | predicted_y_lstm = model_lstm.predict(self.test_X, batch_size=None, verbose=0, steps=None)
295 | # Get the model to generate train_y
296 | trained_y_lstm = model_lstm.predict(self.train_X, batch_size=None, verbose=0, steps=None)
297 |
298 | # combine the model generated train_y and test_y to create the full_y
299 | y_lstm = pd.DataFrame(data=np.vstack((trained_y_lstm, predicted_y_lstm)),
300 | columns=[c + '_LSTM' for c in self.XY.columns[-5:]], index=self.XY.index)
301 |
302 | # Combine the original full length y with model generated y
303 | lstm_y_df = pd.concat([self.XY[self.XY.columns[-5:]], y_lstm], axis=1)
304 | # Get the full length XY data with the length of model generated y
305 | lstm_df = self.XY.loc[lstm_y_df.index]
306 | # Replace the full length XY data's Y with the model generated Y
307 | lstm_df[lstm_df.columns[-5:]] = lstm_y_df[lstm_y_df.columns[-5:]]
308 | # Inverse transform it to get back the original data, the model generated y would be transformed to reveal its true predicted value
309 | recovered_data_lstm = self.minmaxed.inverse_transform(lstm_df)
310 | # Create a dataframe from it
311 | self.recovered_data_lstm = pd.DataFrame(data=recovered_data_lstm, columns=self.XY.columns, index=lstm_df.index)
312 | return self.recovered_data_lstm
313 |
314 | def get_train_test_set(self):
315 | """
316 | Get the split X and y data.
317 | """
318 | return self.train_X, self.train_y, self.test_X, self.test_y
319 |
320 | def get_all_data(self):
321 | """
322 | Get the full XY data and the original stock price.
323 | """
324 | return self.XY, self.stock_price
325 |
326 | class Model:
327 | """
328 |     This class contains all the functions required to build an LSTM or CNN model.
329 |     It also offers an option to load a pre-built model.
330 | """
331 | @staticmethod
332 | def train_model(model, train_X, train_y, model_type):
333 | """
334 | Try to load a pre-built model.
335 |         Otherwise fit a new model with the training data. Once training is done, save the model.
336 | """
337 | es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=PATIENCE)
338 | if model_type == "LSTM":
339 | batch_size = 4
340 | mc = ModelCheckpoint('./model/best_lstm_model.h5', monitor='val_loss', save_weights_only=False,
341 | mode='min', verbose=1, save_best_only=True)
342 | try:
343 | model = load_model('./model/best_lstm_model.h5')
344 | print("\n")
345 | print("Loading pre-saved model ...")
346 |             except Exception:
347 | print("\n")
348 | print("No pre-saved model, training new model.")
349 | pass
350 | elif model_type == "CNN":
351 | batch_size = 8
352 |             mc = ModelCheckpoint('./model/best_cnn_model.h5', monitor='val_loss', save_weights_only=False,
353 | mode='min', verbose=1, save_best_only=True)
354 | try:
355 | model = load_model('./model/best_cnn_model.h5')
356 | print("\n")
357 | print("Loading pre-saved model ...")
358 |             except Exception:
359 | print("\n")
360 | print("No pre-saved model, training new model.")
361 | pass
362 | # fit network
363 | history = model.fit(
364 | train_X,
365 | train_y,
366 | epochs=500,
367 | batch_size=batch_size,
368 | validation_split=0.2,
369 | verbose=2,
370 | shuffle=True,
371 | # callbacks=[es, mc, tb, LearningRateTracker()])
372 | callbacks=[es, mc])
373 |
374 | if model_type == "LSTM":
375 | model.save('./model/best_lstm_model.h5')
376 | elif model_type == "CNN":
377 | model.save('./model/best_cnn_model.h5')
378 |
379 | return history, model
380 |
381 | @staticmethod
382 | def plot_training(history,nn):
383 | """
384 | Plot the historical training loss.
385 | """
386 | # plot history
387 | plt.plot(history.history['loss'], label='train')
388 | plt.plot(history.history['val_loss'], label='test')
389 | plt.legend()
390 | plt.title('Training loss history for {} model'.format(nn))
391 | plt.savefig('./train_result/training_loss_history_{}.png'.format(nn))
392 | plt.show()
393 |
394 | @staticmethod
395 | def build_rnn_model(train_X):
396 | """
397 | Build the RNN model architecture.
398 | """
399 | # design network
400 | print("\n")
401 | print("RNN LSTM model architecture >")
402 | model = Sequential()
403 | model.add(LSTM(128, kernel_initializer='random_uniform',
404 | bias_initializer='zeros', return_sequences=True,
405 | recurrent_dropout=0.2,
406 | input_shape=(train_X.shape[1], train_X.shape[2])))
407 | model.add(Dropout(0.5))
408 | model.add(LSTM(64, kernel_initializer='random_uniform',
409 | return_sequences=True,
410 | # bias_regularizer=regularizers.l2(0.01),
411 | # kernel_regularizer=regularizers.l1_l2(l1=0.01,l2=0.01),
412 | # activity_regularizer=regularizers.l2(0.01),
413 | bias_initializer='zeros'))
414 | model.add(Dropout(0.5))
415 | model.add(LSTM(64, kernel_initializer='random_uniform',
416 | # bias_regularizer=regularizers.l2(0.01),
417 | # kernel_regularizer=regularizers.l1_l2(l1=0.01,l2=0.01),
418 | # activity_regularizer=regularizers.l2(0.01),
419 | bias_initializer='zeros'))
420 | model.add(Dropout(0.5))
421 | model.add(Dense(5))
422 | # optimizer = keras.optimizers.RMSprop(lr=0.25, rho=0.9, epsilon=1e-0)
423 | # optimizer = keras.optimizers.Adagrad(lr=0.0001, epsilon=1e-08, decay=0.00002)
424 | # optimizer = keras.optimizers.Adam(lr=0.0001)
425 | # optimizer = keras.optimizers.Nadam(lr=0.0002, beta_1=0.9, beta_2=0.999, schedule_decay=0.004)
426 | # optimizer = keras.optimizers.Adamax(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0)
427 | optimizer = keras.optimizers.Adadelta(lr=0.2, rho=0.95, epsilon=None, decay=0.00001)
428 |
429 | model.compile(loss='mae', optimizer=optimizer, metrics=['mse', 'mae'])
430 | model.summary()
431 | print("\n")
432 | return model
433 |
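    # Shape walk-through (added comment, sizes taken from the layers above):
    # input (batch, 1, n_features) -> LSTM(128, return_sequences=True)
    # -> (batch, 1, 128) -> LSTM(64, return_sequences=True) -> (batch, 1, 64)
    # -> LSTM(64) -> (batch, 64) -> Dense(5) -> (batch, 5), i.e. one
    # predicted next-day close per portfolio stock, trained on MAE loss
    # with the Adadelta optimizer.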
434 |
435 |
436 | # ------------------------------ Main Program ---------------------------------
437 |
438 | def main():
439 | print("\n")
440 |     print("######################### This program compares the performance of trading strategies ############################")
441 | print("\n")
442 | print( "1. Simple Buy and hold strategy of a portfolio with {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
443 | print( "2. Sharpe ratio optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
444 | print( "3. Minimum variance optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
445 |     print( "4. Simple Buy and hold strategy of the DJIA index")
446 |     print( "5. RNN-LSTM and DDPG (StarTrader) traded portfolios of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
447 |
448 | print("\n")
449 |
450 | print("Starting to pre-process data for trading environment construction ... ")
451 | # Data Preprocessing
452 | dataset = dp.DataRetrieval()
453 | dow_stocks_train, dow_stocks_test = dataset.get_all()
454 | train_portion = len(dow_stocks_train)
455 | dow_stock_volume = dataset.components_df_v[DJI]
456 | portfolios = dp.Trading(dow_stocks_train, dow_stocks_test, dow_stock_volume.loc[START_TEST:END_TEST])
457 | _, _, non_corr_stocks = portfolios.find_non_correlate_stocks(NUMBER_NON_CORR_STOCKS)
458 | non_corr_stocks_data = dataset.get_adj_close(non_corr_stocks)
459 | print("\n")
460 |     print("Based on non-correlation preference, {} stocks are selected for portfolio construction:".format(NUMBER_NON_CORR_STOCKS))
461 |
462 | for stock in non_corr_stocks:
463 | print(DJI_N[DJI.index(stock)])
464 | print("\n")
465 |
466 | sharpe_portfolio, min_variance_portfolio = portfolios.find_efficient_frontier(non_corr_stocks_data, non_corr_stocks)
467 | print("Risk-averse portfolio with low variance:")
468 | print(min_variance_portfolio.T)
469 | print("High return portfolio with high Sharpe ratio")
470 | print(sharpe_portfolio.T)
471 | dow_stocks = pd.concat([dow_stocks_train, dow_stocks_test], axis=0)
472 |
473 | test_values_buyhold, test_returns_buyhold, test_kpi_buyhold = \
474 | portfolios.diversified_trade(non_corr_stocks, dow_stocks.loc[START_TEST:END_TEST][non_corr_stocks])
475 | print("\n")
476 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
477 | print(" KPI of a simple buy and hold strategy for a portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
478 | print("------------------------------------------------------------------------------------")
479 | print(test_kpi_buyhold)
480 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
481 |
482 |
483 | test_values_sharpe_optimized_buyhold, test_returns_sharpe_optimized_buyhold, test_kpi_sharpe_optimized_buyhold =\
484 | portfolios.optimized_diversified_trade(non_corr_stocks, sharpe_portfolio, dow_stocks.loc[START_TEST:END_TEST][non_corr_stocks])
485 | print("\n")
486 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
487 | print(" KPI of a simple buy and hold strategy for a Sharpe ratio optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
488 | print("------------------------------------------------------------------------------------")
489 | print(test_kpi_sharpe_optimized_buyhold)
490 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
491 |
492 | test_values_minvar_optimized_buyhold, test_returns_minvar_optimized_buyhold, test_kpi_minvar_optimized_buyhold = \
493 | portfolios.optimized_diversified_trade(non_corr_stocks, min_variance_portfolio, dow_stocks.loc[START_TEST:END_TEST][non_corr_stocks])
494 | print("\n")
495 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
496 | print(" KPI of a simple buy and hold strategy for a Minimum variance optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
497 | print("------------------------------------------------------------------------------------")
498 | print(test_kpi_minvar_optimized_buyhold)
499 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
500 |
501 | plot = dp.UserDisplay()
502 | test_returns = dp.MathCalc.assemble_returns(test_returns_buyhold['Returns'],
503 | test_returns_sharpe_optimized_buyhold['Returns'],
504 | test_returns_minvar_optimized_buyhold['Returns'])
505 | test_cum_returns = dp.MathCalc.assemble_cum_returns(test_returns_buyhold['CumReturns'],
506 | test_returns_sharpe_optimized_buyhold['CumReturns'],
507 | test_returns_minvar_optimized_buyhold['CumReturns'])
508 |
509 | print("\n")
510 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
511 | print("Buy and hold strategies computation completed. Now creating prediction model using RNN LSTM architecture")
512 | print("--------------------------------------------------------------------------------------------------------")
513 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
514 |
515 |
516 |     # Use the feature data preprocessed by StarTrader so that both models train on the same data, for a fair comparison
517 | input_states = pd.read_csv("./data/ddpg_input_states.csv", index_col='Date', parse_dates=True)
518 | scale_split = Data_ScaleSplit(input_states, dow_stocks[non_corr_stocks], train_portion)
519 | train_X, train_y, test_X, test_y = scale_split.get_train_test_set()
520 |
521 | modelling = Model
522 | model_lstm = modelling.build_rnn_model(train_X)
523 | history_lstm, model_lstm = modelling.train_model(model_lstm, train_X, train_y, "LSTM")
524 |     print("Model training uses early stopping: training stops after {} epochs with no improvement.".format(PATIENCE))
525 | modelling.plot_training(history_lstm, "LSTM")
526 |     print("Training completed, generating predictions with the trained RNN model >")
527 | recovered_data_lstm = scale_split.get_prediction(model_lstm)
528 | plot.plot_prediction(dow_stocks[non_corr_stocks].loc[recovered_data_lstm.index], recovered_data_lstm[recovered_data_lstm.columns[-5:]] , len(train_X), "LSTM")
529 |
530 | # Get the original stock price with the prediction length
531 | original_portfolio_stock_price = dow_stocks[non_corr_stocks].loc[recovered_data_lstm.index]
532 | # Get the predicted stock price with the prediction length
533 | predicted_portfolio_stock_price = recovered_data_lstm[recovered_data_lstm.columns[-5:]]
534 |     print("Backtesting the RNN-LSTM model now")
535 |     # Run the backtest; the backtester is similar to the one used by StarTrader
536 | backtest = Trading(predicted_portfolio_stock_price, original_portfolio_stock_price, dow_stock_volume[non_corr_stocks].loc[recovered_data_lstm.index], dow_stocks_test[non_corr_stocks], non_corr_stocks)
537 | trading_book, kpi = backtest.execute_trading(non_corr_stocks)
538 | # Load backtest result for StarTrader using DDPG as learning algorithm
539 | ddpg_backtest = pd.read_csv('./test_result/trading_book_test_1.csv', index_col='Unnamed: 0', parse_dates=True)
540 | print("Backtesting completed, plotting comparison of trading models")
541 |     # Compare performance across all trading strategies
542 | djia_daily = dataset._get_daily_data(CONTEXT_DATA[1]).loc[START_TEST:END_TEST]['Close']
543 | #print(djia_daily)
544 | all_benchmark_returns = test_returns
545 | all_benchmark_returns['DJIA'] = dp.MathCalc.calc_return(djia_daily)
546 | all_benchmark_returns['RNN LSTM'] = trading_book['Returns']
547 | all_benchmark_returns['DDPG'] = ddpg_backtest['Returns']
548 | all_benchmark_returns.to_csv('./test_result/all_strategies_returns.csv')
549 | plot.plot_portfolio_risk(all_benchmark_returns)
550 |
551 | all_benchmark_cum_returns = test_cum_returns
552 | all_benchmark_cum_returns['DJIA'] = all_benchmark_returns['DJIA'].add(1).cumprod().fillna(1)
553 | all_benchmark_cum_returns['RNN LSTM'] = trading_book['CumReturns']
554 | all_benchmark_cum_returns['DDPG'] = ddpg_backtest['CumReturns']
555 | all_benchmark_cum_returns.to_csv('./test_result/all_strategies_cum_returns.csv')
556 | plot.plot_portfolio_return(all_benchmark_cum_returns)
557 |
558 |
559 | if __name__ == '__main__':
560 | main()
--------------------------------------------------------------------------------
/feature_select.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------- IMPORT LIBRARIES -------------------------------------------
2 | import pandas as pd
3 | import numpy as np
4 | import lightgbm as lgb
5 | import gc
6 | from itertools import chain
7 | from sklearn.cluster import KMeans
8 | from mpl_toolkits.mplot3d import Axes3D
9 | import matplotlib.pyplot as plt, seaborn as sns  # seaborn is used by plot_collinear below
10 | from mpl_finance import candlestick_ohlc
11 | import copy
12 | from matplotlib.dates import (DateFormatter, WeekdayLocator, DayLocator, MONDAY)
13 | from sklearn.model_selection import train_test_split
14 | import warnings
15 | warnings.filterwarnings("ignore")
16 |
17 | # ------------------------------------------------ CLASSES --------------------------------------------
18 |
19 | class FeatureSelector():
20 | """
21 | Courtesy of William Koehrsen from Feature Labs
22 | Class for performing feature selection for machine learning or data preprocessing.
23 |
24 | Implements five different methods to identify features for removal
25 |
26 | 1. Find columns with a missing percentage greater than a specified threshold
27 | 2. Find columns with a single unique value
28 | 3. Find collinear variables with a correlation greater than a specified correlation coefficient
29 | 4. Find features with 0.0 feature importance from a gradient boosting machine (gbm)
30 | 5. Find low importance features that do not contribute to a specified cumulative feature importance from the gbm
31 |
32 | Parameters
33 | --------
34 | data : dataframe
35 | A dataset with observations in the rows and features in the columns
36 |
37 | labels : array or series, default = None
38 | Array of labels for training the machine learning model to find feature importances. These can be either binary labels
39 | (if task is 'classification') or continuous targets (if task is 'regression').
40 | If no labels are provided, then the feature importance based methods are not available.
41 |
42 | Attributes
43 | --------
44 |
45 | ops : dict
46 | Dictionary of operations run and features identified for removal
47 |
48 | missing_stats : dataframe
49 | The fraction of missing values for all features
50 |
51 | record_missing : dataframe
52 | The fraction of missing values for features with missing fraction above threshold
53 |
54 | unique_stats : dataframe
55 | Number of unique values for all features
56 |
57 | record_single_unique : dataframe
58 | Records the features that have a single unique value
59 |
60 | corr_matrix : dataframe
61 | All correlations between all features in the data
62 |
63 | record_collinear : dataframe
64 | Records the pairs of collinear variables with a correlation coefficient above the threshold
65 |
66 | feature_importances : dataframe
67 | All feature importances from the gradient boosting machine
68 |
69 | record_zero_importance : dataframe
70 | Records the zero importance features in the data according to the gbm
71 |
72 | record_low_importance : dataframe
73 | Records the lowest importance features not needed to reach the threshold of cumulative importance according to the gbm
74 |
75 |
76 | Notes
77 | --------
78 |
79 | - All 5 operations can be run with the `identify_all` method.
80 | - If using feature importances, one-hot encoding is used for categorical variables which creates new columns
81 |
82 | """
83 |
84 | def __init__(self, data, labels=None):
85 |
86 | # Dataset and optional training labels
87 | self.data = data
88 | self.labels = labels
89 |
90 | if labels is None:
91 | print('No labels provided. Feature importance based methods are not available.')
92 |
93 | self.base_features = list(data.columns)
94 | self.one_hot_features = None
95 |
96 | # Dataframes recording information about features to remove
97 | self.record_missing = None
98 | self.record_single_unique = None
99 | self.record_collinear = None
100 | self.record_zero_importance = None
101 | self.record_low_importance = None
102 |
103 | self.missing_stats = None
104 | self.unique_stats = None
105 | self.corr_matrix = None
106 | self.feature_importances = None
107 |
108 | # Dictionary to hold removal operations
109 | self.ops = {}
110 |
111 | self.one_hot_correlated = False
112 |
113 | def identify_missing(self, missing_threshold):
114 | """Find the features with a fraction of missing values above `missing_threshold`"""
115 |
116 | self.missing_threshold = missing_threshold
117 |
118 | # Calculate the fraction of missing in each column
119 | missing_series = self.data.isnull().sum() / self.data.shape[0]
120 | self.missing_stats = pd.DataFrame(missing_series).rename(columns={'index': 'feature', 0: 'missing_fraction'})
121 |
122 | # Sort with highest number of missing values on top
123 | self.missing_stats = self.missing_stats.sort_values('missing_fraction', ascending=False)
124 |
125 | # Find the columns with a missing percentage above the threshold
126 | record_missing = pd.DataFrame(missing_series[missing_series > missing_threshold]).reset_index().rename(columns=
127 | {'index': 'feature', 0: 'missing_fraction'})
128 |
129 | to_drop = list(record_missing['feature'])
130 |
131 | self.record_missing = record_missing
132 | self.ops['missing'] = to_drop
133 |
134 | print('%d features with greater than %0.2f missing values.\n' % (
135 | len(self.ops['missing']), self.missing_threshold))
136 |
137 | def identify_single_unique(self):
138 | """Finds features with only a single unique value. NaNs do not count as a unique value. """
139 |
140 | # Calculate the unique counts in each column
141 | unique_counts = self.data.nunique()
142 | self.unique_stats = pd.DataFrame(unique_counts).rename(columns={'index': 'feature', 0: 'nunique'})
143 | self.unique_stats = self.unique_stats.sort_values('nunique', ascending=True)
144 |
145 | # Find the columns with only one unique count
146 | record_single_unique = pd.DataFrame(unique_counts[unique_counts == 1]).reset_index().rename(
147 | columns={'index': 'feature',
148 | 0: 'nunique'})
149 |
150 | to_drop = list(record_single_unique['feature'])
151 |
152 | self.record_single_unique = record_single_unique
153 | self.ops['single_unique'] = to_drop
154 |
155 | print('%d features with a single unique value.\n' % len(self.ops['single_unique']))
156 |
157 | def identify_collinear(self, correlation_threshold, one_hot=False):
158 | """
159 | Finds collinear features based on the correlation coefficient between features.
160 |         For each pair of features with a correlation coefficient greater than `correlation_threshold`,
161 | only one of the pair is identified for removal.
162 |
163 | Using code adapted from: https://chrisalbon.com/machine_learning/feature_selection/drop_highly_correlated_features/
164 |
165 | Parameters
166 | --------
167 |
168 | correlation_threshold : float between 0 and 1
169 |             Value of the Pearson correlation coefficient for identifying collinear features
170 |
171 | one_hot : boolean, default = False
172 | Whether to one-hot encode the features before calculating the correlation coefficients
173 |
174 | """
175 |
176 | self.correlation_threshold = correlation_threshold
177 | self.one_hot_correlated = one_hot
178 |
179 | # Calculate the correlations between every column
180 | if one_hot:
181 |
182 | # One hot encoding
183 | features = pd.get_dummies(self.data)
184 | self.one_hot_features = [column for column in features.columns if column not in self.base_features]
185 |
186 | # Add one hot encoded data to original data
187 | self.data_all = pd.concat([features[self.one_hot_features], self.data], axis=1)
188 |
189 | corr_matrix = pd.get_dummies(features).corr()
190 |
191 | else:
192 | corr_matrix = self.data.corr()
193 |
194 | self.corr_matrix = corr_matrix
195 |
196 | # Extract the upper triangle of the correlation matrix
197 | upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))
198 |
199 | # Select the features with correlations above the threshold
200 | # Need to use the absolute value
201 | to_drop = [column for column in upper.columns if any(upper[column].abs() > correlation_threshold)]
202 |
203 | # Dataframe to hold correlated pairs
204 | record_collinear = pd.DataFrame(columns=['drop_feature', 'corr_feature', 'corr_value'])
205 |
206 | # Iterate through the columns to drop to record pairs of correlated features
207 | for column in to_drop:
208 | # Find the correlated features
209 | corr_features = list(upper.index[upper[column].abs() > correlation_threshold])
210 |
211 | # Find the correlated values
212 | corr_values = list(upper[column][upper[column].abs() > correlation_threshold])
213 | drop_features = [column for _ in range(len(corr_features))]
214 |
215 | # Record the information (need a temp df for now)
216 | temp_df = pd.DataFrame.from_dict({'drop_feature': drop_features,
217 | 'corr_feature': corr_features,
218 | 'corr_value': corr_values})
219 |
220 | # Add to dataframe
221 | record_collinear = record_collinear.append(temp_df, ignore_index=True)
222 |
223 | self.record_collinear = record_collinear
224 | self.ops['collinear'] = to_drop
225 |
226 | print('%d features with a correlation magnitude greater than %0.2f.\n' % (
227 | len(self.ops['collinear']), self.correlation_threshold))
228 |
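    # Worked micro-example (added comment): with two features A and B,
    # corr(A, B) = 0.99 and a threshold of 0.95, the k=1 upper triangle keeps
    # the (A, B) entry exactly once, so only the later column B is marked for
    # removal rather than both members of the pair.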
229 | def identify_zero_importance(self, task, eval_metric=None,
230 | n_iterations=10, early_stopping=True):
231 | """
232 |
233 | Identify the features with zero importance according to a gradient boosting machine.
234 | The gbm can be trained with early stopping using a validation set to prevent overfitting.
235 | The feature importances are averaged over `n_iterations` to reduce variance.
236 |
237 | Uses the LightGBM implementation (http://lightgbm.readthedocs.io/en/latest/index.html)
238 |
239 | Parameters
240 | --------
241 |
242 | eval_metric : string
243 | Evaluation metric to use for the gradient boosting machine for early stopping. Must be
244 | provided if `early_stopping` is True
245 |
246 | task : string
247 | The machine learning task, either 'classification' or 'regression'
248 |
249 | n_iterations : int, default = 10
250 | Number of iterations to train the gradient boosting machine
251 |
252 | early_stopping : boolean, default = True
253 | Whether or not to use early stopping with a validation set when training
254 |
255 |
256 | Notes
257 | --------
258 |
259 | - Features are one-hot encoded to handle the categorical variables before training.
260 | - The gbm is not optimized for any particular task and might need some hyperparameter tuning
261 | - Feature importances, including zero importance features, can change across runs
262 |
263 | """
264 |
265 | if early_stopping and eval_metric is None:
266 | raise ValueError("""eval metric must be provided with early stopping. Examples include "auc" for classification or
267 | "l2" for regression.""")
268 |
269 | if self.labels is None:
270 | raise ValueError("No training labels provided.")
271 |
272 | # One hot encoding
273 | features = pd.get_dummies(self.data)
274 | self.one_hot_features = [column for column in features.columns if column not in self.base_features]
275 |
276 | # Add one hot encoded data to original data
277 | self.data_all = pd.concat([features[self.one_hot_features], self.data], axis=1)
278 |
279 | # Extract feature names
280 | feature_names = list(features.columns)
281 |
282 | # Convert to np array
283 | features = np.array(features)
284 | labels = np.array(self.labels).reshape((-1,))
285 |
286 | # Empty array for feature importances
287 | feature_importance_values = np.zeros(len(feature_names))
288 |
289 | print('Training Gradient Boosting Model\n')
290 |
291 | # Iterate through each fold
292 | for _ in range(n_iterations):
293 |
294 | if task == 'classification':
295 | model = lgb.LGBMClassifier(n_estimators=1000, learning_rate=0.05, verbose=-1)
296 |
297 | elif task == 'regression':
298 | model = lgb.LGBMRegressor(n_estimators=1000, learning_rate=0.05, verbose=-1)
299 |
300 | else:
301 | raise ValueError('Task must be either "classification" or "regression"')
302 |
303 | # If training using early stopping need a validation set
304 | if early_stopping:
305 |
306 | train_features, valid_features, train_labels, valid_labels = train_test_split(features, labels,
307 | test_size=0.15)
308 | # Train the model with early stopping
309 | model.fit(train_features, train_labels, eval_metric=eval_metric,
310 | eval_set=[(valid_features, valid_labels)],
311 | early_stopping_rounds=100, verbose=0)
312 |
313 | # Clean up memory
314 | gc.enable()
315 | del train_features, train_labels, valid_features, valid_labels
316 | gc.collect()
317 |
318 | else:
319 | model.fit(features, labels)
320 |
321 | # Record the feature importances
322 | feature_importance_values += model.feature_importances_ / n_iterations
323 |
324 | feature_importances = pd.DataFrame({'feature': feature_names, 'importance': feature_importance_values})
325 |
326 | # Sort features according to importance
327 | feature_importances = feature_importances.sort_values('importance', ascending=False).reset_index(drop=True)
328 |
329 | # Normalize the feature importances to add up to one
330 | feature_importances['normalized_importance'] = feature_importances['importance'] / feature_importances[
331 | 'importance'].sum()
332 | feature_importances['cumulative_importance'] = np.cumsum(feature_importances['normalized_importance'])
333 |
334 | # Extract the features with zero importance
335 | record_zero_importance = feature_importances[feature_importances['importance'] == 0.0]
336 |
337 | to_drop = list(record_zero_importance['feature'])
338 |
339 | self.feature_importances = feature_importances
340 | self.record_zero_importance = record_zero_importance
341 | self.ops['zero_importance'] = to_drop
342 |
343 | print('\n%d features with zero importance after one-hot encoding.\n' % len(self.ops['zero_importance']))
344 |
345 | def identify_low_importance(self, cumulative_importance):
346 | """
347 | Finds the lowest importance features not needed to account for `cumulative_importance` fraction
348 | of the total feature importance from the gradient boosting machine. As an example, if cumulative
349 | importance is set to 0.95, this will retain only the most important features needed to
350 | reach 95% of the total feature importance. The identified features are those not needed.
351 |
352 | Parameters
353 | --------
354 | cumulative_importance : float between 0 and 1
355 | The fraction of cumulative importance to account for
356 |
357 | """
358 |
359 | self.cumulative_importance = cumulative_importance
360 |
361 | # The feature importances need to be calculated before running
362 | if self.feature_importances is None:
363 | raise NotImplementedError("""Feature importances have not yet been determined.
364 | Call the `identify_zero_importance` method first.""")
365 |
366 | # Make sure most important features are on top
367 | self.feature_importances = self.feature_importances.sort_values('cumulative_importance')
368 |
369 | # Identify the features not needed to reach the cumulative_importance
370 | record_low_importance = self.feature_importances[
371 | self.feature_importances['cumulative_importance'] > cumulative_importance]
372 |
373 | to_drop = list(record_low_importance['feature'])
374 |
375 | self.record_low_importance = record_low_importance
376 | self.ops['low_importance'] = to_drop
377 |
378 | print('%d features required for cumulative importance of %0.2f after one hot encoding.' % (
379 | len(self.feature_importances) -
380 | len(self.record_low_importance), self.cumulative_importance))
381 | print('%d features do not contribute to cumulative importance of %0.2f.\n' % (len(self.ops['low_importance']),
382 | self.cumulative_importance))
383 |
384 | def identify_all(self, selection_params):
385 | """
386 | Use all five of the methods to identify features to remove.
387 |
388 | Parameters
389 | --------
390 |
391 | selection_params : dict
392 |             Parameters to use in the five feature selection methods.
393 | Params must contain the keys ['missing_threshold', 'correlation_threshold', 'eval_metric', 'task', 'cumulative_importance']
394 |
395 | """
396 |
397 | # Check for all required parameters
398 | for param in ['missing_threshold', 'correlation_threshold', 'eval_metric', 'task', 'cumulative_importance']:
399 | if param not in selection_params.keys():
400 | raise ValueError('%s is a required parameter for this method.' % param)
401 |
402 | # Implement each of the five methods
403 | self.identify_missing(selection_params['missing_threshold'])
404 | self.identify_single_unique()
405 | self.identify_collinear(selection_params['correlation_threshold'])
406 | self.identify_zero_importance(task=selection_params['task'], eval_metric=selection_params['eval_metric'])
407 | self.identify_low_importance(selection_params['cumulative_importance'])
408 |
409 | # Find the number of features identified to drop
410 | self.all_identified = set(list(chain(*list(self.ops.values()))))
411 | self.n_identified = len(self.all_identified)
412 |
413 | print('%d total features out of %d identified for removal after one-hot encoding.\n' % (self.n_identified,
414 | self.data_all.shape[1]))
415 |
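    # Usage sketch (added comment; the parameter values are illustrative
    # assumptions, only the required keys are documented above):
    #   fs = FeatureSelector(data=X, labels=y)
    #   fs.identify_all(selection_params={'missing_threshold': 0.6,
    #                                     'correlation_threshold': 0.98,
    #                                     'task': 'regression',
    #                                     'eval_metric': 'l2',
    #                                     'cumulative_importance': 0.99})
    #   X_reduced = fs.remove(methods='all', keep_one_hot=False)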
416 | def check_removal(self, keep_one_hot=True):
417 |
418 | """Check the identified features before removal. Returns a list of the unique features identified."""
419 |
420 | self.all_identified = set(list(chain(*list(self.ops.values()))))
421 | print('Total of %d features identified for removal' % len(self.all_identified))
422 |
423 | if not keep_one_hot:
424 | if self.one_hot_features is None:
425 | print('Data has not been one-hot encoded')
426 | else:
427 | one_hot_to_remove = [x for x in self.one_hot_features if x not in self.all_identified]
428 | print('%d additional one-hot features can be removed' % len(one_hot_to_remove))
429 |
430 | return list(self.all_identified)
431 |
432 | def remove(self, methods, keep_one_hot=True):
433 | """
434 | Remove the features from the data according to the specified methods.
435 |
436 | Parameters
437 | --------
438 | methods : 'all' or list of methods
439 | If methods == 'all', any methods that have identified features will be used
440 | Otherwise, only the specified methods will be used.
441 | Can be one of ['missing', 'single_unique', 'collinear', 'zero_importance', 'low_importance']
442 | keep_one_hot : boolean, default = True
443 | Whether or not to keep one-hot encoded features
444 |
445 | Return
446 | --------
447 | data : dataframe
448 | Dataframe with identified features removed
449 |
450 |
451 | Notes
452 | --------
453 | - If feature importances are used, the one-hot encoded columns will be added to the data (and then may be removed)
454 | - Check the features that will be removed before transforming data!
455 |
456 | """
457 |
458 | features_to_drop = []
459 |
460 | if methods == 'all':
461 |
462 | # Need to use one-hot encoded data as well
463 | data = self.data_all
464 |
465 | print('{} methods have been run\n'.format(list(self.ops.keys())))
466 |
467 | # Find the unique features to drop
468 | features_to_drop = set(list(chain(*list(self.ops.values()))))
469 |
470 | else:
471 | # Need to use one-hot encoded data as well
472 | if 'zero_importance' in methods or 'low_importance' in methods or self.one_hot_correlated:
473 | data = self.data_all
474 |
475 | else:
476 | data = self.data
477 |
478 | # Iterate through the specified methods
479 | for method in methods:
480 |
481 | # Check to make sure the method has been run
482 | if method not in self.ops.keys():
483 | raise NotImplementedError('%s method has not been run' % method)
484 |
485 | # Append the features identified for removal
486 | else:
487 | features_to_drop.append(self.ops[method])
488 |
489 | # Find the unique features to drop
490 | features_to_drop = set(list(chain(*features_to_drop)))
491 |
492 | features_to_drop = list(features_to_drop)
493 |
494 | if not keep_one_hot:
495 |
496 | if self.one_hot_features is None:
497 | print('Data has not been one-hot encoded')
498 | else:
499 |
500 | features_to_drop = list(set(features_to_drop) | set(self.one_hot_features))
501 |
502 | # Remove the features and return the data
503 | data = data.drop(features_to_drop, axis=1)
504 | #data = data.drop(columns=features_to_drop)
505 | self.removed_features = features_to_drop
506 |
507 | if not keep_one_hot:
508 | print('Removed %d features including one-hot features.' % len(features_to_drop))
509 | else:
510 | print('Removed %d features.' % len(features_to_drop))
511 |
512 | return data
513 |
514 | def plot_missing(self):
515 | """Histogram of missing fraction in each feature"""
516 | if self.record_missing is None:
517 | raise NotImplementedError("Missing values have not been calculated. Run `identify_missing`")
518 |
519 | self.reset_plot()
520 |
521 | # Histogram of missing values
522 | plt.style.use('seaborn-white')
523 | plt.figure(figsize=(7, 5))
524 | plt.hist(self.missing_stats['missing_fraction'], bins=np.linspace(0, 1, 11), edgecolor='k', color='red',
525 | linewidth=1.5)
526 | plt.xticks(np.linspace(0, 1, 11));
527 | plt.xlabel('Missing Fraction', size=14);
528 | plt.ylabel('Count of Features', size=14);
529 | plt.title("Fraction of Missing Values Histogram", size=16);
530 |
531 | def plot_unique(self):
532 | """Histogram of number of unique values in each feature"""
533 | if self.record_single_unique is None:
534 | raise NotImplementedError('Unique values have not been calculated. Run `identify_single_unique`')
535 |
536 | self.reset_plot()
537 |
538 | # Histogram of number of unique values
539 | self.unique_stats.plot.hist(edgecolor='k', figsize=(7, 5))
540 | plt.ylabel('Frequency', size=14);
541 | plt.xlabel('Unique Values', size=14);
542 | plt.title('Number of Unique Values Histogram', size=16);
543 |
544 | def plot_collinear(self, plot_all=False):
545 | """
546 | Heatmap of the correlation values. If plot_all = True plots all the correlations otherwise
547 | plots only those features that have a correlation above the threshold
548 |
549 | Notes
550 | --------
551 | - Not all of the plotted correlations are above the threshold because this plots
552 |         all the variables that have been identified as having even one correlation above the threshold
553 | - The features on the x-axis are those that will be removed. The features on the y-axis
554 | are the correlated features with those on the x-axis
555 |
556 | Code adapted from https://seaborn.pydata.org/examples/many_pairwise_correlations.html
557 | """
558 |
559 | if self.record_collinear is None:
560 |             raise NotImplementedError('Collinear features have not been identified. Run `identify_collinear`.')
561 |
562 | if plot_all:
563 | corr_matrix_plot = self.corr_matrix
564 | title = 'All Correlations'
565 |
566 | else:
567 | # Identify the correlations that were above the threshold
568 | # columns (x-axis) are features to drop and rows (y_axis) are correlated pairs
569 | corr_matrix_plot = self.corr_matrix.loc[list(set(self.record_collinear['corr_feature'])),
570 | list(set(self.record_collinear['drop_feature']))]
571 |
572 | title = "Correlations Above Threshold"
573 |
574 | f, ax = plt.subplots(figsize=(10, 8))
575 |
576 | # Diverging colormap
577 | cmap = sns.diverging_palette(220, 10, as_cmap=True)
578 |
579 | # Draw the heatmap with a color bar
580 | sns.heatmap(corr_matrix_plot, cmap=cmap, center=0,
581 | linewidths=.25, cbar_kws={"shrink": 0.6})
582 |
583 | # Set the ylabels
584 | ax.set_yticks([x + 0.5 for x in list(range(corr_matrix_plot.shape[0]))])
585 | ax.set_yticklabels(list(corr_matrix_plot.index), size=int(160 / corr_matrix_plot.shape[0]));
586 |
587 | # Set the xlabels
588 | ax.set_xticks([x + 0.5 for x in list(range(corr_matrix_plot.shape[1]))])
589 | ax.set_xticklabels(list(corr_matrix_plot.columns), size=int(160 / corr_matrix_plot.shape[1]));
590 | plt.title(title, size=14)
591 |
592 | def plot_feature_importances(self, plot_n=15, threshold=None):
593 | """
594 | Plots `plot_n` most important features and the cumulative importance of features.
595 | If `threshold` is provided, prints the number of features needed to reach `threshold` cumulative importance.
596 |
597 | Parameters
598 | --------
599 |
600 | plot_n : int, default = 15
601 |             Number of most important features to plot. Defaults to 15 or the maximum number of features, whichever is smaller
602 |
603 | threshold : float, between 0 and 1 default = None
604 | Threshold for printing information about cumulative importances
605 |
606 | """
607 |
608 | if self.record_zero_importance is None:
609 |             raise NotImplementedError('Feature importances have not been determined. Run `identify_zero_importance`')
610 |
611 | # Need to adjust number of features if greater than the features in the data
612 | if plot_n > self.feature_importances.shape[0]:
613 | plot_n = self.feature_importances.shape[0] - 1
614 |
615 | self.reset_plot()
616 |
617 | # Make a horizontal bar chart of feature importances
618 | plt.figure(figsize=(10, 6))
619 | ax = plt.subplot()
620 |
621 | # Need to reverse the index to plot most important on top
622 | # There might be a more efficient method to accomplish this
623 | ax.barh(list(reversed(list(self.feature_importances.index[:plot_n]))),
624 | self.feature_importances['normalized_importance'][:plot_n],
625 | align='center', edgecolor='k')
626 |
627 | # Set the yticks and labels
628 | ax.set_yticks(list(reversed(list(self.feature_importances.index[:plot_n]))))
629 | ax.set_yticklabels(self.feature_importances['feature'][:plot_n], size=12)
630 |
631 | # Plot labeling
632 | plt.xlabel('Normalized Importance', size=16);
633 | plt.title('Feature Importances', size=18)
634 | plt.show()
635 |
636 | # Cumulative importance plot
637 | plt.figure(figsize=(6, 4))
638 | plt.plot(list(range(1, len(self.feature_importances) + 1)), self.feature_importances['cumulative_importance'],
639 | 'r-')
640 | plt.xlabel('Number of Features', size=14);
641 | plt.ylabel('Cumulative Importance', size=14);
642 | plt.title('Cumulative Feature Importance', size=16);
643 |
644 | if threshold:
645 | # Index of minimum number of features needed for cumulative importance threshold
646 | # np.where returns the index so need to add 1 to have correct number
647 | importance_index = np.min(np.where(self.feature_importances['cumulative_importance'] > threshold))
648 | plt.vlines(x=importance_index + 1, ymin=0, ymax=1, linestyles='--', colors='blue')
649 | plt.show();
650 |
651 | print('%d features required for %0.2f of cumulative importance' % (importance_index + 1, threshold))
652 |
653 | def reset_plot(self):
654 |         plt.rcParams.update(plt.rcParamsDefault)  # rebinding plt.rcParams would not reset matplotlib's live config
--------------------------------------------------------------------------------
/gym/envs/StarTrader/StarTrade_env.py:
--------------------------------------------------------------------------------
1 | # --------------------------- IMPORT LIBRARIES -------------------------
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | from datetime import datetime, timedelta
6 | from gym.utils import seeding
7 | import gym
8 | from gym import spaces
9 | import data_preprocessing as dp
10 |
11 | # ------------------------- GLOBAL PARAMETERS -------------------------
12 | # Start and end period of historical data in question
13 | START_TRAIN = datetime(2008, 12, 31)
14 | END_TRAIN = datetime(2017, 2, 12)
15 | START_TEST = datetime(2017, 2, 12)
16 | END_TEST = datetime(2019, 2, 22)
17 |
18 | STARTING_ACC_BALANCE = 100000
19 | NUMBER_NON_CORR_STOCKS = 5
20 | MAX_TRADE = 10
21 | TRAIN_RATIO = 0.8
22 | PRICE_IMPACT = 0.1
23 |
24 | # Pools of stocks to trade
25 | DJI = ['MMM', 'AXP', 'AAPL', 'BA', 'CAT', 'CVX', 'CSCO', 'KO', 'DIS', 'XOM', 'GE', 'GS', 'HD', 'IBM', 'INTC', 'JNJ',
26 | 'JPM', 'MCD', 'MRK', 'MSFT', 'NKE', 'PFE', 'PG', 'UTX', 'UNH', 'VZ', 'WMT']
27 |
28 | DJI_N = ['3M','American Express', 'Apple','Boeing','Caterpillar','Chevron','Cisco Systems','Coca-Cola','Disney'
29 | ,'ExxonMobil','General Electric','Goldman Sachs','Home Depot','IBM','Intel','Johnson & Johnson',
30 | 'JPMorgan Chase','McDonalds','Merck','Microsoft','NIKE','Pfizer','Procter & Gamble',
31 | 'United Technologies','UnitedHealth Group','Verizon Communications','Wal Mart']
32 |
33 | #Market and macroeconomic data to be used as context data
34 | CONTEXT_DATA = ['^GSPC', '^DJI', '^IXIC', '^RUT', 'SPY', 'QQQ', '^VIX', 'GLD', '^TYX', '^TNX' , 'SHY', 'SHV']
35 |
36 |
37 | # ------------------------------ PREPROCESSING ---------------------------------
38 | print ("\n")
39 | print ("############################## Welcome to the playground of Star Trader!! ###################################")
40 | print ("\n")
41 | print ("Hello, I am Star, I am learning to trade like a human. In this playground, I trade stocks and optimize my portfolio.")
42 | print ("\n")
43 |
44 | print ("Starting to pre-process data for trading environment construction ... ")
45 | # Data Preprocessing
46 | dataset = dp.DataRetrieval()
47 | dow_stocks_train, dow_stocks_test = dataset.get_all()
48 | dow_stock_volume = dataset.components_df_v[DJI]
49 | portfolios = dp.Trading(dow_stocks_train, dow_stocks_test, dow_stock_volume.loc[START_TEST:END_TEST])
50 | _, _, non_corr_stocks = portfolios.find_non_correlate_stocks(NUMBER_NON_CORR_STOCKS)
51 | feature_df = dataset.get_feature_dataframe (non_corr_stocks)
52 | context_df = dataset.get_feature_dataframe (CONTEXT_DATA)
53 |
54 | # With context data
55 | input_states = pd.concat([context_df, feature_df], axis=1)
56 | input_states.to_csv('./data/ddpg_input_states.csv')
57 | # Without context data
58 | #input_states = feature_df
59 | feature_length = len(input_states.columns)
60 | data_length = len(input_states)
61 | stock_price = dataset.components_df_o[non_corr_stocks]
62 | stock_volume = dataset.components_df_v[non_corr_stocks]
63 | stock_price.to_csv('./data/ddpg_stock_price.csv')
64 |
65 | print("\n")
66 | print("Based on non-correlation preference, {} stocks are selected for portfolio construction:".format(NUMBER_NON_CORR_STOCKS))
67 |
68 | for stock in non_corr_stocks:
69 | print(DJI_N[DJI.index(stock)])
70 | print("\n")
71 |
72 | print("Pre-processing and stock selection complete, trading starts now ...")
73 | print("_______________________________________________________________________________________________________________")
74 |
75 |
76 | # ------------------------------ CLASSES ---------------------------------
77 |
78 | class StarTradingEnv(gym.Env):
79 | metadata = {'render.modes': ['human']}
80 |
81 | def __init__(self, day = START_TRAIN):
82 |
83 | """
84 | Initialize the trading environment; starting values for the trading parameters are defined here.
85 | """
86 | self.iteration = 0
87 | self.day = day
88 | self.done = False
89 |
90 |
91 | # The trading agent's action per stock, with low and high bounding the number of shares allowed to sell or buy,
92 | # defined using Gym's Box action space function
93 | self.action_space = spaces.Box(low=-MAX_TRADE, high=MAX_TRADE, shape=(NUMBER_NON_CORR_STOCKS,), dtype=np.int8)
94 |
95 | # State layout: [account balance] + [unrealized profit/loss] + [feature values (36)] + [holdings of the 5 portfolio stocks]
96 | self.full_feature_length = 2 + feature_length
97 | print("full length", self.full_feature_length)
98 | self.observation_space = spaces.Box(low=0, high=np.inf, shape = (self.full_feature_length + NUMBER_NON_CORR_STOCKS,))
99 |
100 | # Slide the timeline window day by day, skipping non-trading days for which no data is available
101 | wrong_day = True
102 | add_day = 0
103 | while wrong_day:
104 | try:
105 | temp_date = self.day + timedelta(days=add_day)
106 | self.data = input_states.loc[temp_date]
107 | self.day = temp_date
108 | wrong_day = False
109 | except KeyError: # date not in the index: non-trading day
110 | add_day += 1
111 |
112 | self.timeline = [self.day]
113 | # The money in the trading account
114 | self.acc_balance = [STARTING_ACC_BALANCE]
115 | self.total_asset = self.acc_balance
116 | self.portfolio_asset = [0.0]
117 | self.buy_price = np.zeros((1,NUMBER_NON_CORR_STOCKS)).flatten()
118 | # Unrealized profit and loss
119 | self.unrealized_pnl = [0.0]
120 | # The value of all stock holdings
121 | self.portfolio_value = 0.0
122 |
123 |
124 | # The state of the trading environment, defined by account balance, unrealized profit and loss, relevant
125 | # stock technical data & current stock holdings
126 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)]
127 |
128 | # The current reward of the agent.
129 | self.reward = 0
130 | self._seed()
131 | self.reset()
132 |
133 | def _sell(self, idx, action):
134 | """
135 | Perform and record sell transactions. Commissions and slippage are taken into account.
136 | """
137 |
138 | # Sell only the number of shares recommended by the trading agent, capped by the current holding.
139 | num_share = min(abs(int(action)) , self.state[idx + self.full_feature_length])
140 | commission = dp.Trading.commission(num_share, stock_price.loc[self.day][idx])
141 | # Calculate the slipped price; at the maximum trade size of 10 shares there is hardly any slippage
142 | transacted_price = dp.Trading.slippage_price(stock_price.loc[self.day][idx], -num_share, stock_volume.loc[self.day][idx])
143 |
144 | # If there is existing stock holding
145 | if self.state[idx + self.full_feature_length] > 0:
146 |
147 | # Update account balance after transaction
148 | self.state[0] += (transacted_price * num_share) - commission
149 | # Update stock holding
150 | self.state[idx + self.full_feature_length] -= num_share
151 | # Reset transacted buy price record to 0.0 if there is no more stock holding
152 | if self.state[idx + self.full_feature_length] == 0.0:
153 | self.buy_price[idx] = 0.0
154 | else:
155 | pass
156 |
157 | def _buy(self, idx, action):
158 | """
159 | Perform and record buy transactions. Commissions and slippage are taken into account.
160 | """
161 |
162 | # Calculate the maximum number of shares the current cash balance can buy
163 | available_unit = self.state[0] // stock_price.loc[self.day][idx]
164 | num_share = min(available_unit, int(action))
165 | # Deduct the traded amount from the account balance. If the balance cannot cover the number of shares
166 | # recommended by the agent's action, buy as many shares as the remaining balance allows.
167 | commission = dp.Trading.commission(num_share, stock_price.loc[self.day][idx])
168 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage
169 | transacted_price = dp.Trading.slippage_price(stock_price.loc[self.day][idx], num_share,
170 | stock_volume.loc[self.day][idx])
171 |
172 | # Revise the number of shares to trade if the account balance cannot cover the order
173 | if (self.state[0] - commission) < transacted_price * num_share:
174 | num_share = max(0, (self.state[0] - commission) // transacted_price)
175 | self.state[0] -= ( transacted_price * num_share ) + commission
176 |
177 |
178 | # If there is an existing holding already, calculate the new average buy price
179 | if self.state[idx + self.full_feature_length] > 0.0:
180 | existing_unit = self.state[idx + self.full_feature_length]
181 | previous_buy_price = self.buy_price[idx]
182 | additional_unit = num_share # use the possibly revised share count actually purchased
183 | new_holding = existing_unit + additional_unit
184 | self.buy_price[idx] = ((existing_unit * previous_buy_price ) + (transacted_price * additional_unit))/ new_holding
185 | # if there is no existing stock holding, simply record the current buy price
186 | elif self.state[idx + self.full_feature_length] == 0.0:
187 | self.buy_price[idx] = transacted_price
188 |
189 | # Update stock holding at its index
190 | self.state[idx + self.full_feature_length] += num_share
191 |
192 |
193 | def step(self, actions):
194 | """
195 | Perform one step of an episode: execute trades, advance the date, and compute the reward.
196 | """
197 |
198 | # Episode ends when timestep reaches the last day in feature data
199 | self.done = self.day >= END_TRAIN
200 | # Uncomment below to run a quick test
201 | #self.done = self.day >= START_TRAIN + timedelta(days=10)
202 |
203 | # If it is the last step, plot trading performance
204 | if self.done:
205 | print("@@@@@@@@@@@@@@@@@")
206 | print("Iteration", self.iteration-1)
207 | # Construct trading book and save to a spreadsheet for analysis
208 | trading_book = pd.DataFrame(index=self.timeline, columns=["Cash balance", "Portfolio value", "Total asset", "Returns", "CumReturns"])
209 | trading_book["Cash balance"] = self.acc_balance
210 | trading_book["Portfolio value"] = self.portfolio_asset
211 | trading_book["Total asset"] = self.total_asset
212 | trading_book["Returns"] = trading_book["Total asset"] / trading_book["Total asset"].shift(1) - 1
213 | trading_book["CumReturns"] = trading_book["Returns"].add(1).cumprod().fillna(1)
214 | trading_book.to_csv('./train_result/trading_book_train_{}.csv'.format(self.iteration-1))
215 |
216 | kpi = dp.MathCalc.calc_kpi(trading_book)
217 | kpi.to_csv('./train_result/kpi_train_{}.csv'.format(self.iteration-1))
218 | print("===============================================================================================")
219 | print(kpi)
220 | print("===============================================================================================")
221 |
222 | # Visualize results
223 | plt.plot(trading_book.index, trading_book["Cash balance"], 'g', label='Account cash balance',alpha=0.8,)
224 | plt.plot(trading_book.index, trading_book["Portfolio value"], 'r', label='Portfolio value',alpha=1,lw=1.5)
225 | plt.plot(trading_book.index, trading_book["Total asset"], 'b', label='Total asset',alpha=0.6,lw=3)
226 | plt.xlabel('Timeline', fontsize=12)
227 | plt.ylabel('Value', fontsize=12)
228 | plt.title('Portfolio value + account fund evolution @ train iteration {}'.format(self.iteration-1), fontsize=13)
229 | plt.tight_layout()
230 | plt.legend()
231 | plt.savefig('./train_result/asset_evolution_train_{}.png'.format(self.iteration-1))
232 | plt.show()
233 | plt.close()
234 |
235 | plt.plot(trading_book.index, trading_book["CumReturns"], 'g', label='Cumulative returns')
236 | plt.xlabel('Timeline', fontsize=12)
237 | plt.ylabel('Returns', fontsize=12)
238 | plt.title('Cumulative returns @ train iteration {}'.format(self.iteration-1), fontsize=13)
239 | plt.tight_layout()
240 | plt.legend()
241 | plt.savefig('./train_result/cummulative_returns_train_{}.png'.format(self.iteration-1))
242 | plt.show()
243 | plt.close()
244 |
245 | return self.state, self.reward, self.done, {}
246 |
247 | else:
248 | # Portfolio value is the current holdings multiplied by their respective current stock prices
249 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:]))
250 | # Total asset is account balance + portfolio value
251 | total_asset_starting = self.state[0] + portfolio_value
252 |
253 | # Sort the trade orders by action size: among sells, stocks with more units to sell are
254 | # transacted first; among buys, stocks with more units to buy are transacted first
255 | sorted_actions = np.argsort(actions)
256 | # Get the stocks to be sold
257 | sell_stocks = sorted_actions[:np.where(actions < 0)[0].shape[0]]
258 |
259 | # Alternatively, sell with static order
260 | #sell_stocks = np.where(actions < 0)[0].flatten()
261 | #np.random.shuffle(sell_stocks)
262 | for stock_idx in sell_stocks:
263 | self._sell(stock_idx, actions[stock_idx])
264 |
265 | # Get the stocks to be bought
266 | buy_stocks = sorted_actions[::-1][:np.where(actions > 0)[0].shape[0]]
267 | # Alternatively, buy with static order
268 | #buy_stocks = np.where(actions > 0)[0].flatten()
269 | #np.random.shuffle(buy_stocks)
270 | for stock_idx in buy_stocks:
271 | self._buy(stock_idx, actions[stock_idx])
272 |
273 | # Advance the date, skipping non-trading days for which no data is available
274 | self.day += timedelta(days=1)
275 | wrong_day = True
276 | add_day = 0
277 | while wrong_day:
278 | try:
279 | temp_date = self.day + timedelta(days=add_day)
280 | self.data = input_states.loc[temp_date]
281 | self.day = temp_date
282 | wrong_day = False
283 | except KeyError: # date not in the index: non-trading day
284 | add_day += 1
285 |
286 | # Calculate unrealized profit and loss for existing stock holdings
287 | self.unrealized_pnl = np.sum(np.array(stock_price.loc[self.day] - self.buy_price) * np.array(
288 | self.state[self.full_feature_length:]))
289 |
290 | # Current state space
291 | self.state = [self.state[0]] + [self.unrealized_pnl] + self.data.values.tolist() + list(self.state[self.full_feature_length:])
292 | # Portfolio value is the current stock prices multiplied by their respective holdings
293 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:]))
294 | # Total asset = account balance + portfolio value
295 | total_asset_ending = self.state[0] + portfolio_value
296 |
297 | # Update account balance statement
298 | self.acc_balance = np.append(self.acc_balance, self.state[0])
299 |
300 | # Update portfolio value statement
301 | self.portfolio_asset = np.append(self.portfolio_asset, portfolio_value)
302 |
303 | # Update total asset statement
304 | self.total_asset = np.append(self.total_asset, total_asset_ending)
305 |
306 | # Update timeline
307 | self.timeline = np.append(self.timeline, self.day)
308 |
309 | # Once the agent has traded long enough, hold it accountable for the gain-to-pain and lake ratios
310 | if len(self.total_asset) > 9:
311 | returns = dp.MathCalc.calc_return(pd.Series(self.total_asset))
312 |
313 | self.reward = total_asset_ending - total_asset_starting \
314 | + (100*dp.MathCalc.calc_gain_to_pain(returns))\
315 | - (500 * dp.MathCalc.calc_lake_ratio(pd.Series(returns).add(1).cumprod().fillna(1)))
316 | #+ (50 * dp.MathCalc.sharpe_ratio(pd.Series(returns)))
317 |
318 | # If the agent has not traded long enough, it only bears the day's change in total asset value
319 | else:
320 | self.reward = total_asset_ending - total_asset_starting
321 |
322 | return self.state, self.reward, self.done, {}
323 |
324 | def reset(self):
325 | """
326 | Reset the environment once an episode ends.
327 |
328 | """
329 | self.acc_balance = [STARTING_ACC_BALANCE]
330 | self.total_asset = self.acc_balance
331 | self.portfolio_asset = [0.0]
332 | self.buy_price = np.zeros((1, NUMBER_NON_CORR_STOCKS)).flatten()
333 | self.unrealized_pnl = [0]
334 | self.day = START_TRAIN
335 | self.portfolio_value = 0.0
336 |
337 | wrong_day = True
338 | add_day = 0
339 | while wrong_day:
340 | try:
341 | temp_date = self.day + timedelta(days=add_day)
342 | self.data = input_states.loc[temp_date]
343 | self.day = temp_date
344 | wrong_day = False
345 | except KeyError: # date not in the index: non-trading day
346 | add_day += 1
347 |
348 | self.timeline = [self.day]
349 |
350 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)]
351 | self.iteration += 1
352 |
353 | return self.state
354 |
355 | def render(self, mode='human'):
356 | """
357 | Render the environment with current state.
358 |
359 | """
360 | return self.state
361 |
362 | def _seed(self, seed=None):
363 | """
364 | Seed the random number generator.
365 | """
366 | self.np_random, seed = seeding.np_random(seed)
367 | return [seed]
368 |
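
The reward above adds two risk-shaping terms to the day's change in total assets: a gain-to-pain bonus scaled by 100 and a lake-ratio penalty scaled by 500, both computed by dp.MathCalc in data_preprocessing.py. As a rough guide to what is being rewarded and penalized, a minimal sketch of the conventional definitions (Schwager's gain-to-pain ratio, Seykota's lake ratio), not necessarily the repository's exact formulas:

    import numpy as np
    import pandas as pd

    def gain_to_pain(returns: pd.Series) -> float:
        # Sum of all returns divided by the absolute sum of the losing returns
        pain = abs(returns[returns < 0].sum())
        return returns.sum() / pain if pain > 0 else np.inf

    def lake_ratio(cum_returns: pd.Series) -> float:
        # "Water" is the gap between the running peak and the equity curve;
        # "earth" is the area under the curve itself
        peaks = cum_returns.cummax()
        water = (peaks - cum_returns).sum()
        earth = cum_returns.sum()
        return water / earth if earth > 0 else 0.0

A higher gain-to-pain ratio raises the reward, while time spent below a previous equity peak (a deeper "lake") lowers it.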
--------------------------------------------------------------------------------
/gym/envs/StarTrader/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.StarTrader.StarTrade_env import StarTradingEnv
2 |
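
This import exposes StarTradingEnv at the package level so that the `StarTrader-v0` registration in gym/envs/__init__.py below can resolve its entry_point ('gym.envs.StarTrader:StarTradingEnv'). A minimal rollout sketch; constructing the environment triggers the module-level preprocessing above, so the data/ directory must be in place:

    import gym

    env = gym.make('StarTrader-v0')
    state = env.reset()
    done = False
    while not done:
        # Replace the random sample with the DDPG policy's action in practice
        action = env.action_space.sample()
        state, reward, done, info = env.step(action)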
--------------------------------------------------------------------------------
/gym/envs/StarTrader/__pycache__/StarTrade_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/StarTrade_env.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTrader/__pycache__/StarTrade_test_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/StarTrade_test_env.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTrader/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTrader/__pycache__/intelliTrade_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/intelliTrade_env.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTraderTest/StarTrade_test_env.py:
--------------------------------------------------------------------------------
1 | # --------------------------- IMPORT LIBRARIES -------------------------
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | from datetime import datetime, timedelta
6 | from gym.utils import seeding
7 | import gym
8 | from gym import spaces
9 | import data_preprocessing as dp
10 |
11 | # ------------------------- GLOBAL PARAMETERS -------------------------
12 | # Start and end period of historical data in question
13 | START_TRAIN = datetime(2008, 12, 31)
14 | END_TRAIN = datetime(2017, 2, 12)
15 | START_TEST = datetime(2017, 2, 12)
16 | END_TEST = datetime(2019, 2, 22)
17 |
18 | STARTING_ACC_BALANCE = 100000
19 | NUMBER_NON_CORR_STOCKS = 5
20 | MAX_TRADE = 10
21 | TRAIN_RATIO = 0.8
22 |
23 | # Pools of stocks to trade
24 | DJI = ['MMM', 'AXP', 'AAPL', 'BA', 'CAT', 'CVX', 'CSCO', 'KO', 'DIS', 'XOM', 'GE', 'GS', 'HD', 'IBM', 'INTC', 'JNJ',
25 | 'JPM', 'MCD', 'MRK', 'MSFT', 'NKE', 'PFE', 'PG', 'UTX', 'UNH', 'VZ', 'WMT']
26 |
27 | DJI_N = ['3M','American Express', 'Apple','Boeing','Caterpillar','Chevron','Cisco Systems','Coca-Cola','Disney'
28 | ,'ExxonMobil','General Electric','Goldman Sachs','Home Depot','IBM','Intel','Johnson & Johnson',
29 | 'JPMorgan Chase','McDonalds','Merck','Microsoft','NIKE','Pfizer','Procter & Gamble',
30 | 'United Technologies','UnitedHealth Group','Verizon Communications','Wal Mart']
31 |
32 | # Market and macroeconomic data to be used as context data
33 | CONTEXT_DATA = ['^GSPC', '^DJI', '^IXIC', '^RUT', 'SPY', 'QQQ', '^VIX', 'GLD', '^TYX', '^TNX' , 'SHY', 'SHV']
34 |
35 |
36 | # ------------------------------ PREPROCESSING ---------------------------------
37 | print ("\n")
38 | print ("############################## Welcome to the playground of Star Trader!! ###################################")
39 | print ("\n")
40 | print ("Hello, I am Star, I am learning to trade like a human. In this playground, I trade stocks and optimize my portfolio.")
41 | print ("\n")
42 |
43 | print ("Starting to pre-process data for trading environment construction ... ")
44 | # Data Preprocessing
45 | dataset = dp.DataRetrieval()
46 | dow_stocks_train, dow_stocks_test = dataset.get_all()
47 | dow_stock_volume = dataset.components_df_v[DJI]
48 | portfolios = dp.Trading(dow_stocks_train, dow_stocks_test, dow_stock_volume.loc[START_TEST:END_TEST])
49 | _, _, non_corr_stocks = portfolios.find_non_correlate_stocks(NUMBER_NON_CORR_STOCKS)
50 |
51 | # Reuse the preprocessed feature data of the training environment, StarTrader-v0
52 | input_states = pd.read_csv("./data/ddpg_input_states.csv", index_col='Date', parse_dates=True)
53 | feature_length = len(input_states.columns)
54 | data_length = len(input_states)
55 | stock_price = dataset.components_df_o[non_corr_stocks]
56 |
57 | print("feature length", feature_length)
58 | print("\n")
59 | print("Base on non-correlation, {} stocks are selected for portfolio construction:".format(NUMBER_NON_CORR_STOCKS))
60 |
61 | for stock in non_corr_stocks:
62 | print(DJI_N[DJI.index(stock)])
63 | print("\n")
64 |
65 | print("Pre-processing and stock selection complete, trading starts now ...")
66 | print("_______________________________________________________________________________________________________________")
67 |
68 |
69 | # ------------------------------ CLASSES ---------------------------------
70 |
71 | class StarTradingTestEnv(gym.Env):
72 | metadata = {'render.modes': ['human']}
73 |
74 | def __init__(self, day = START_TEST):
75 |
76 | """
77 | Initialize the trading environment; starting values for the trading parameters are defined here.
78 |
79 | """
80 | self.iteration = 0
81 | self.day = day
82 | self.done = False
83 |
84 |
85 | # The trading agent's action per stock, with low and high bounding the number of shares allowed to trade,
86 | # defined using Gym's Box action space function
87 | self.action_space = spaces.Box(low=-MAX_TRADE, high=MAX_TRADE, shape=(NUMBER_NON_CORR_STOCKS,), dtype=np.int8)
88 |
89 | # State layout: [account balance] + [unrealized profit/loss] + [feature values (36)] + [holdings of the 5 portfolio stocks]
90 | self.full_feature_length = 2 + feature_length
91 | print("full length", self.full_feature_length)
92 | self.observation_space = spaces.Box(low=0, high=np.inf, shape = (self.full_feature_length + NUMBER_NON_CORR_STOCKS,))
93 |
94 | # Slide the timeline window day by day, skipping non-trading days for which no data is available
95 | wrong_day = True
96 | add_day = 0
97 | while wrong_day:
98 | try:
99 | temp_date = self.day + timedelta(days=add_day)
100 | self.data = input_states.loc[temp_date]
101 | self.day = temp_date
102 | wrong_day = False
103 | except KeyError: # date not in the index: non-trading day
104 | add_day += 1
105 |
106 | self.timeline = [self.day]
107 | # The money in the trading account
108 | self.acc_balance = [STARTING_ACC_BALANCE]
109 | self.total_asset = self.acc_balance
110 | self.portfolio_asset = [0.0]
111 | self.buy_price = np.zeros((1,NUMBER_NON_CORR_STOCKS)).flatten()
112 | # Unrealized profit and loss
113 | self.unrealized_pnl = [0.0]
114 | # The value of all stock holdings
115 | self.portfolio_value = 0.0
116 |
117 |
118 | # The state of the trading environment, defined by account balance, unrealized profit and loss, relevant
119 | # stock technical data & current stock holdings
120 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)]
121 |
122 | # The current reward of the agent.
123 | self.reward = 0
124 | self._seed()
125 | self.reset()
126 |
127 | def _sell(self, idx, action):
128 |
129 | # If there is existing stock holding
130 | if self.state[idx + self.full_feature_length] > 0:
131 | # Sell only the number of shares recommended by the trading agent, capped by the current holding.
132 | # Update account balance after transaction
133 | self.state[0] += stock_price.loc[self.day][idx] * min(abs(int(action)), self.state[idx + self.full_feature_length])
134 | # Update stock holding
135 | self.state[idx + self.full_feature_length] -= min(abs(int(action)), self.state[idx + self.full_feature_length])
136 | # Reset transacted buy price record to 0.0 if there is no more stock holding
137 | if self.state[idx + self.full_feature_length] == 0.0:
138 | self.buy_price[idx] = 0.0
139 | else:
140 | pass
141 |
142 | def _buy(self, idx, action):
143 |
144 | # Calculate the maximum number of shares the current cash balance can buy
145 | available_unit = self.state[0] // stock_price.loc[self.day][idx]
146 |
147 | # Deduct the traded amount from the account balance. If the balance cannot cover the number of shares
148 | # recommended by the agent's action, buy as many shares as the remaining balance allows.
149 | self.state[0] -= stock_price.loc[self.day][idx] * min(available_unit, int(action))
150 |
151 |
152 | # If there is an existing holding already, calculate the new average buy price
153 | if self.state[idx + self.full_feature_length] > 0.0:
154 | existing_unit = self.state[idx + self.full_feature_length]
155 | previous_buy_price = self.buy_price[idx]
156 | additional_unit = min(available_unit, int(action))
157 | new_holding = existing_unit + additional_unit
158 | self.buy_price[idx] = ((existing_unit * previous_buy_price ) + (stock_price.loc[self.day][idx] * additional_unit))/ new_holding
159 | # if there is no existing stock holding, simply record the current buy price
160 | elif self.state[idx + self.full_feature_length] == 0.0:
161 | self.buy_price[idx] = stock_price.loc[self.day][idx]
162 |
163 | # Update stock holding at its index
164 | self.state[idx + self.full_feature_length] += min(available_unit, int(action))
165 |
166 |
167 | def step(self, actions):
168 |
169 | # Episode ends when timestep reaches the last day in feature data
170 | self.done = self.day >= END_TEST
171 |
172 | # If it is the last step, plot trading performance
173 | if self.done:
174 | print("@@@@@@@@@@@@@@@@@")
175 | print("Iteration", self.iteration-1)
176 | # Construct the trading book and save it to a spreadsheet for analysis
177 | trading_book = pd.DataFrame(index=self.timeline, columns=["Cash balance", "Portfolio value", "Total asset", "Returns", "CumReturns"])
178 | trading_book["Cash balance"] = self.acc_balance
179 | trading_book["Portfolio value"] = self.portfolio_asset
180 | trading_book["Total asset"] = self.total_asset
181 | trading_book["Returns"] = trading_book["Total asset"] / trading_book["Total asset"].shift(1) - 1
182 | trading_book["CumReturns"] = trading_book["Returns"].add(1).cumprod().fillna(1)
183 | trading_book.to_csv('./test_result/trading_book_test_{}.csv'.format(self.iteration-1))
184 |
185 | kpi = dp.MathCalc.calc_kpi(trading_book)
186 | kpi.to_csv('./test_result/kpi_test_{}.csv'.format(self.iteration-1))
187 | print("===============================================================================================")
188 | print(kpi)
189 | print("===============================================================================================")
190 |
191 | # Visualize results
192 | plt.plot(trading_book.index, trading_book["Cash balance"], 'g', label='Account cash balance',alpha=0.8,)
193 | plt.plot(trading_book.index, trading_book["Portfolio value"], 'r', label='Portfolio value',alpha=1,lw=1.5)
194 | plt.plot(trading_book.index, trading_book["Total asset"], 'b', label='Total asset',alpha=0.6,lw=3)
195 | plt.xlabel('Timeline', fontsize=12)
196 | plt.ylabel('Value', fontsize=12)
197 | plt.title('Portfolio value + Account fund evolution @ test iteration {}'.format(self.iteration-1), fontsize=13)
198 | plt.legend()
199 | plt.savefig('./test_result/asset_evolution_test_{}.png'.format(self.iteration-1))
200 | plt.show()
201 | plt.close()
202 |
203 | plt.plot(trading_book.index, trading_book["CumReturns"], 'g', label='Cumulative returns')
204 | plt.xlabel('Timeline', fontsize=12)
205 | plt.ylabel('Returns', fontsize=12)
206 | plt.title('Cumulative returns @ test iteration {}'.format(self.iteration-1), fontsize=13)
207 | plt.legend()
208 | plt.savefig('./test_result/cummulative_returns_test_{}.png'.format(self.iteration-1))
209 | plt.show()
210 | plt.close()
211 |
212 | return self.state, self.reward, self.done, {}
213 |
214 | else:
215 | # Portfolio value is the current holdings multiplied by their respective current stock prices
216 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:]))
217 | # Total asset is account balance + portfolio value
218 | total_asset_starting = self.state[0] + portfolio_value
219 |
220 | # Sort the trade orders by action size: among sells, stocks with more units to sell are
221 | # transacted first; among buys, stocks with more units to buy are transacted first
222 | sorted_actions = np.argsort(actions)
223 | # Get the stocks to be sold
224 | sell_stocks = sorted_actions[:np.where(actions < 0)[0].shape[0]]
225 |
226 | # Alternatively, sell with static order
227 | # sell_stocks = np.where(actions < 0)[0].flatten()
228 | # np.random.shuffle(sell_stocks)
229 | for stock_idx in sell_stocks:
230 | self._sell(stock_idx, actions[stock_idx])
231 |
232 | # Get the stocks to be bought
233 | buy_stocks = sorted_actions[::-1][:np.where(actions > 0)[0].shape[0]]
234 | # Alternatively, buy with static order
235 | # buy_stocks = np.where(actions > 0)[0].flatten()
236 | # np.random.shuffle(buy_stocks)
237 | for stock_idx in buy_stocks:
238 | self._buy(stock_idx, actions[stock_idx])
239 |
240 | # Advance the date, skipping non-trading days for which no data is available
241 | self.day += timedelta(days=1)
242 | wrong_day = True
243 | add_day = 0
244 | while wrong_day:
245 | try:
246 | temp_date = self.day + timedelta(days=add_day)
247 | self.data = input_states.loc[temp_date]
248 | self.day = temp_date
249 | wrong_day = False
250 | except KeyError: # date not in the index: non-trading day
251 | add_day += 1
252 |
253 | # Calculate unrealized profit and loss for existing stock holdings
254 | self.unrealized_pnl = np.sum(np.array(stock_price.loc[self.day] - self.buy_price) * np.array(
255 | self.state[self.full_feature_length:]))
256 |
257 | # Current state space
258 | self.state = [self.state[0]] + [self.unrealized_pnl] + self.data.values.tolist() + list(
259 | self.state[self.full_feature_length:])
260 | # Portfolio value is the current stock prices multiplied by their respective holdings
261 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:]))
262 | # Total asset = account balance + portfolio value
263 | total_asset_ending = self.state[0] + portfolio_value
264 |
265 | # Update account balance statement
266 | self.acc_balance = np.append(self.acc_balance, self.state[0])
267 |
268 | # Update portfolio value statement
269 | self.portfolio_asset = np.append(self.portfolio_asset, portfolio_value)
270 |
271 | # Update total asset statement
272 | self.total_asset = np.append(self.total_asset, total_asset_ending)
273 |
274 | # Update timeline
275 | self.timeline = np.append(self.timeline, self.day)
276 |
277 | # Once the agent has traded long enough, hold it accountable for the gain-to-pain and lake ratios
278 | if len(self.total_asset) > 9:
279 | returns = dp.MathCalc.calc_return(pd.Series(self.total_asset))
280 | self.reward = total_asset_ending - total_asset_starting + (100 * dp.MathCalc.calc_gain_to_pain(returns)) \
281 | - (500 * dp.MathCalc.calc_lake_ratio(pd.Series(returns).add(1).cumprod().fillna(1)))
282 |
283 | # If the agent has not traded long enough, it only bears the day's change in total asset value
284 | else:
285 | self.reward = total_asset_ending - total_asset_starting
286 |
287 | return self.state, self.reward, self.done, {}
288 |
289 | def reset(self):
290 | """
291 | Reset the environment once an episode ends.
292 |
293 | """
294 | self.acc_balance = [STARTING_ACC_BALANCE]
295 | self.total_asset = self.acc_balance
296 | self.portfolio_asset = [0.0]
297 | self.buy_price = np.zeros((1, NUMBER_NON_CORR_STOCKS)).flatten()
298 | self.unrealized_pnl = [0]
299 | self.day = START_TEST
300 | self.portfolio_value = 0.0
301 |
302 | wrong_day = True
303 | add_day = 0
304 | while wrong_day:
305 | try:
306 | temp_date = self.day + timedelta(days=add_day)
307 | self.data = input_states.loc[temp_date]
308 | self.day = temp_date
309 | wrong_day = False
310 | except KeyError: # date not in the index: non-trading day
311 | add_day += 1
312 |
313 | self.timeline = [self.day]
314 |
315 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)]
316 | self.iteration += 1
317 | print("\n")
318 |
319 |
320 | return self.state
321 |
322 | def render(self, mode='human'):
323 | return self.state
324 |
325 | def _seed(self, seed=None):
326 | self.np_random, seed = seeding.np_random(seed)
327 | return [seed]
328 |
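
Note that this test environment transacts at the raw quoted price: unlike the training environment, its _sell and _buy omit the dp.Trading.commission and dp.Trading.slippage_price adjustments. For orientation, a minimal sketch of a proportional-commission plus volume-based price-impact model of the kind the training environment calls; the parameters and formulas here are illustrative assumptions, and the actual implementations live in data_preprocessing.py:

    def commission(num_share, price, rate=0.001, minimum=1.0):
        # Proportional fee with a minimum charge (illustrative parameters)
        return max(minimum, num_share * price * rate)

    def slippage_price(price, num_share, day_volume, impact=0.1):
        # Move the fill price against the trade in proportion to its share of
        # daily volume; num_share is negative for sells, so sells fill lower
        return price * (1 + impact * num_share / day_volume)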
--------------------------------------------------------------------------------
/gym/envs/StarTraderTest/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.StarTraderTest.StarTrade_test_env import StarTradingTestEnv
2 |
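
Both environments advance the clock with a try/except probe of input_states.loc until a trading day is found. An equivalent lookup, sketched here, asks the sorted DatetimeIndex directly for the first trading day on or after a given date:

    def next_trading_day(input_states, day):
        # Position of the first index entry on or after `day`
        # (assumes `day` is not past the final date in the index)
        pos = input_states.index.searchsorted(day)
        return input_states.index[pos]

    # e.g. self.day = next_trading_day(input_states, self.day)
    #      self.data = input_states.loc[self.day]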
--------------------------------------------------------------------------------
/gym/envs/StarTraderTest/__pycache__/StarTrade_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/StarTrade_env.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTraderTest/__pycache__/StarTrade_test_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/StarTrade_test_env.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTraderTest/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/StarTraderTest/__pycache__/intelliTrade_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/intelliTrade_env.cpython-36.pyc
--------------------------------------------------------------------------------
/gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import registry, register, make, spec
2 |
3 | register(
4 | id='StarTrader-v0',
5 | entry_point='gym.envs.StarTrader:StarTradingEnv',
6 | )
7 |
8 | register(
9 | id='StarTraderTest-v0',
10 | entry_point='gym.envs.StarTraderTest:StarTradingTestEnv',
11 | )
12 |
13 | # Algorithmic
14 | # ----------------------------------------
15 |
16 |
17 |
18 | register(
19 | id='Copy-v0',
20 | entry_point='gym.envs.algorithmic:CopyEnv',
21 | max_episode_steps=200,
22 | reward_threshold=25.0,
23 | )
24 |
25 | register(
26 | id='RepeatCopy-v0',
27 | entry_point='gym.envs.algorithmic:RepeatCopyEnv',
28 | max_episode_steps=200,
29 | reward_threshold=75.0,
30 | )
31 |
32 | register(
33 | id='ReversedAddition-v0',
34 | entry_point='gym.envs.algorithmic:ReversedAdditionEnv',
35 | kwargs={'rows' : 2},
36 | max_episode_steps=200,
37 | reward_threshold=25.0,
38 | )
39 |
40 | register(
41 | id='ReversedAddition3-v0',
42 | entry_point='gym.envs.algorithmic:ReversedAdditionEnv',
43 | kwargs={'rows' : 3},
44 | max_episode_steps=200,
45 | reward_threshold=25.0,
46 | )
47 |
48 | register(
49 | id='DuplicatedInput-v0',
50 | entry_point='gym.envs.algorithmic:DuplicatedInputEnv',
51 | max_episode_steps=200,
52 | reward_threshold=9.0,
53 | )
54 |
55 | register(
56 | id='Reverse-v0',
57 | entry_point='gym.envs.algorithmic:ReverseEnv',
58 | max_episode_steps=200,
59 | reward_threshold=25.0,
60 | )
61 |
62 | # Classic
63 | # ----------------------------------------
64 |
65 | register(
66 | id='CartPole-v0',
67 | entry_point='gym.envs.classic_control:CartPoleEnv',
68 | max_episode_steps=200,
69 | reward_threshold=195.0,
70 | )
71 |
72 | register(
73 | id='CartPole-v1',
74 | entry_point='gym.envs.classic_control:CartPoleEnv',
75 | max_episode_steps=500,
76 | reward_threshold=475.0,
77 | )
78 |
79 | register(
80 | id='MountainCar-v0',
81 | entry_point='gym.envs.classic_control:MountainCarEnv',
82 | max_episode_steps=200,
83 | reward_threshold=-110.0,
84 | )
85 |
86 | register(
87 | id='MountainCarContinuous-v0',
88 | entry_point='gym.envs.classic_control:Continuous_MountainCarEnv',
89 | max_episode_steps=999,
90 | reward_threshold=90.0,
91 | )
92 |
93 | register(
94 | id='Pendulum-v0',
95 | entry_point='gym.envs.classic_control:PendulumEnv',
96 | max_episode_steps=200,
97 | )
98 |
99 | register(
100 | id='Acrobot-v1',
101 | entry_point='gym.envs.classic_control:AcrobotEnv',
102 | max_episode_steps=500,
103 | )
104 |
105 | # Box2d
106 | # ----------------------------------------
107 |
108 | register(
109 | id='LunarLander-v2',
110 | entry_point='gym.envs.box2d:LunarLander',
111 | max_episode_steps=1000,
112 | reward_threshold=200,
113 | )
114 |
115 | register(
116 | id='LunarLanderContinuous-v2',
117 | entry_point='gym.envs.box2d:LunarLanderContinuous',
118 | max_episode_steps=1000,
119 | reward_threshold=200,
120 | )
121 |
122 | register(
123 | id='BipedalWalker-v2',
124 | entry_point='gym.envs.box2d:BipedalWalker',
125 | max_episode_steps=1600,
126 | reward_threshold=300,
127 | )
128 |
129 | register(
130 | id='BipedalWalkerHardcore-v2',
131 | entry_point='gym.envs.box2d:BipedalWalkerHardcore',
132 | max_episode_steps=2000,
133 | reward_threshold=300,
134 | )
135 |
136 | register(
137 | id='CarRacing-v0',
138 | entry_point='gym.envs.box2d:CarRacing',
139 | max_episode_steps=1000,
140 | reward_threshold=900,
141 | )
142 |
143 | # Toy Text
144 | # ----------------------------------------
145 |
146 | register(
147 | id='Blackjack-v0',
148 | entry_point='gym.envs.toy_text:BlackjackEnv',
149 | )
150 |
151 | register(
152 | id='KellyCoinflip-v0',
153 | entry_point='gym.envs.toy_text:KellyCoinflipEnv',
154 | reward_threshold=246.61,
155 | )
156 | register(
157 | id='KellyCoinflipGeneralized-v0',
158 | entry_point='gym.envs.toy_text:KellyCoinflipGeneralizedEnv',
159 | )
160 |
161 | register(
162 | id='FrozenLake-v0',
163 | entry_point='gym.envs.toy_text:FrozenLakeEnv',
164 | kwargs={'map_name' : '4x4'},
165 | max_episode_steps=100,
166 | reward_threshold=0.78, # optimum = .8196
167 | )
168 |
169 | register(
170 | id='FrozenLake8x8-v0',
171 | entry_point='gym.envs.toy_text:FrozenLakeEnv',
172 | kwargs={'map_name' : '8x8'},
173 | max_episode_steps=200,
174 | reward_threshold=0.99, # optimum = 1
175 | )
176 |
177 | register(
178 | id='CliffWalking-v0',
179 | entry_point='gym.envs.toy_text:CliffWalkingEnv',
180 | )
181 |
182 | register(
183 | id='NChain-v0',
184 | entry_point='gym.envs.toy_text:NChainEnv',
185 | max_episode_steps=1000,
186 | )
187 |
188 | register(
189 | id='Roulette-v0',
190 | entry_point='gym.envs.toy_text:RouletteEnv',
191 | max_episode_steps=100,
192 | )
193 |
194 | register(
195 | id='Taxi-v2',
196 | entry_point='gym.envs.toy_text.taxi:TaxiEnv',
197 | reward_threshold=8, # optimum = 8.46
198 | max_episode_steps=200,
199 | )
200 |
201 | register(
202 | id='GuessingGame-v0',
203 | entry_point='gym.envs.toy_text.guessing_game:GuessingGame',
204 | max_episode_steps=200,
205 | )
206 |
207 | register(
208 | id='HotterColder-v0',
209 | entry_point='gym.envs.toy_text.hotter_colder:HotterColder',
210 | max_episode_steps=200,
211 | )
212 |
213 | # Mujoco
214 | # ----------------------------------------
215 |
216 | # 2D
217 |
218 | register(
219 | id='Reacher-v2',
220 | entry_point='gym.envs.mujoco:ReacherEnv',
221 | max_episode_steps=50,
222 | reward_threshold=-3.75,
223 | )
224 |
225 | register(
226 | id='Pusher-v2',
227 | entry_point='gym.envs.mujoco:PusherEnv',
228 | max_episode_steps=100,
229 | reward_threshold=0.0,
230 | )
231 |
232 | register(
233 | id='Thrower-v2',
234 | entry_point='gym.envs.mujoco:ThrowerEnv',
235 | max_episode_steps=100,
236 | reward_threshold=0.0,
237 | )
238 |
239 | register(
240 | id='Striker-v2',
241 | entry_point='gym.envs.mujoco:StrikerEnv',
242 | max_episode_steps=100,
243 | reward_threshold=0.0,
244 | )
245 |
246 | register(
247 | id='InvertedPendulum-v2',
248 | entry_point='gym.envs.mujoco:InvertedPendulumEnv',
249 | max_episode_steps=1000,
250 | reward_threshold=950.0,
251 | )
252 |
253 | register(
254 | id='InvertedDoublePendulum-v2',
255 | entry_point='gym.envs.mujoco:InvertedDoublePendulumEnv',
256 | max_episode_steps=1000,
257 | reward_threshold=9100.0,
258 | )
259 |
260 | register(
261 | id='HalfCheetah-v2',
262 | entry_point='gym.envs.mujoco:HalfCheetahEnv',
263 | max_episode_steps=1000,
264 | reward_threshold=4800.0,
265 | )
266 |
267 | register(
268 | id='Hopper-v2',
269 | entry_point='gym.envs.mujoco:HopperEnv',
270 | max_episode_steps=1000,
271 | reward_threshold=3800.0,
272 | )
273 |
274 | register(
275 | id='Swimmer-v2',
276 | entry_point='gym.envs.mujoco:SwimmerEnv',
277 | max_episode_steps=1000,
278 | reward_threshold=360.0,
279 | )
280 |
281 | register(
282 | id='Walker2d-v2',
283 | max_episode_steps=1000,
284 | entry_point='gym.envs.mujoco:Walker2dEnv',
285 | )
286 |
287 | register(
288 | id='Ant-v2',
289 | entry_point='gym.envs.mujoco:AntEnv',
290 | max_episode_steps=1000,
291 | reward_threshold=6000.0,
292 | )
293 |
294 | register(
295 | id='Humanoid-v2',
296 | entry_point='gym.envs.mujoco:HumanoidEnv',
297 | max_episode_steps=1000,
298 | )
299 |
300 | register(
301 | id='HumanoidStandup-v2',
302 | entry_point='gym.envs.mujoco:HumanoidStandupEnv',
303 | max_episode_steps=1000,
304 | )
305 |
306 | # Robotics
307 | # ----------------------------------------
308 |
309 | def _merge(a, b):
310 | a.update(b)
311 | return a
312 |
313 | for reward_type in ['sparse', 'dense']:
314 | suffix = 'Dense' if reward_type == 'dense' else ''
315 | kwargs = {
316 | 'reward_type': reward_type,
317 | }
318 |
319 | # Fetch
320 | register(
321 | id='FetchSlide{}-v1'.format(suffix),
322 | entry_point='gym.envs.robotics:FetchSlideEnv',
323 | kwargs=kwargs,
324 | max_episode_steps=50,
325 | )
326 |
327 | register(
328 | id='FetchPickAndPlace{}-v1'.format(suffix),
329 | entry_point='gym.envs.robotics:FetchPickAndPlaceEnv',
330 | kwargs=kwargs,
331 | max_episode_steps=50,
332 | )
333 |
334 | register(
335 | id='FetchReach{}-v1'.format(suffix),
336 | entry_point='gym.envs.robotics:FetchReachEnv',
337 | kwargs=kwargs,
338 | max_episode_steps=50,
339 | )
340 |
341 | register(
342 | id='FetchPush{}-v1'.format(suffix),
343 | entry_point='gym.envs.robotics:FetchPushEnv',
344 | kwargs=kwargs,
345 | max_episode_steps=50,
346 | )
347 |
348 | # Hand
349 | register(
350 | id='HandReach{}-v0'.format(suffix),
351 | entry_point='gym.envs.robotics:HandReachEnv',
352 | kwargs=kwargs,
353 | max_episode_steps=50,
354 | )
355 |
356 | register(
357 | id='HandManipulateBlockRotateZ{}-v0'.format(suffix),
358 | entry_point='gym.envs.robotics:HandBlockEnv',
359 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'z'}, kwargs),
360 | max_episode_steps=100,
361 | )
362 |
363 | register(
364 | id='HandManipulateBlockRotateParallel{}-v0'.format(suffix),
365 | entry_point='gym.envs.robotics:HandBlockEnv',
366 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'parallel'}, kwargs),
367 | max_episode_steps=100,
368 | )
369 |
370 | register(
371 | id='HandManipulateBlockRotateXYZ{}-v0'.format(suffix),
372 | entry_point='gym.envs.robotics:HandBlockEnv',
373 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'xyz'}, kwargs),
374 | max_episode_steps=100,
375 | )
376 |
377 | register(
378 | id='HandManipulateBlockFull{}-v0'.format(suffix),
379 | entry_point='gym.envs.robotics:HandBlockEnv',
380 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs),
381 | max_episode_steps=100,
382 | )
383 |
384 | # Alias for "Full"
385 | register(
386 | id='HandManipulateBlock{}-v0'.format(suffix),
387 | entry_point='gym.envs.robotics:HandBlockEnv',
388 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs),
389 | max_episode_steps=100,
390 | )
391 |
392 | register(
393 | id='HandManipulateEggRotate{}-v0'.format(suffix),
394 | entry_point='gym.envs.robotics:HandEggEnv',
395 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'xyz'}, kwargs),
396 | max_episode_steps=100,
397 | )
398 |
399 | register(
400 | id='HandManipulateEggFull{}-v0'.format(suffix),
401 | entry_point='gym.envs.robotics:HandEggEnv',
402 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs),
403 | max_episode_steps=100,
404 | )
405 |
406 | # Alias for "Full"
407 | register(
408 | id='HandManipulateEgg{}-v0'.format(suffix),
409 | entry_point='gym.envs.robotics:HandEggEnv',
410 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs),
411 | max_episode_steps=100,
412 | )
413 |
414 | register(
415 | id='HandManipulatePenRotate{}-v0'.format(suffix),
416 | entry_point='gym.envs.robotics:HandPenEnv',
417 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'xyz'}, kwargs),
418 | max_episode_steps=100,
419 | )
420 |
421 | register(
422 | id='HandManipulatePenFull{}-v0'.format(suffix),
423 | entry_point='gym.envs.robotics:HandPenEnv',
424 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs),
425 | max_episode_steps=100,
426 | )
427 |
428 | # Alias for "Full"
429 | register(
430 | id='HandManipulatePen{}-v0'.format(suffix),
431 | entry_point='gym.envs.robotics:HandPenEnv',
432 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs),
433 | max_episode_steps=100,
434 | )
435 |
436 | # Atari
437 | # ----------------------------------------
438 |
439 | # # print ', '.join(["'{}'".format(name.split('.')[0]) for name in atari_py.list_games()])
440 | for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis',
441 | 'bank_heist', 'battle_zone', 'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout', 'carnival',
442 | 'centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk',
443 | 'elevator_action', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar',
444 | 'hero', 'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull', 'kung_fu_master',
445 | 'montezuma_revenge', 'ms_pacman', 'name_this_game', 'phoenix', 'pitfall', 'pong', 'pooyan',
446 | 'private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing',
447 | 'solaris', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down',
448 | 'venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge', 'zaxxon']:
449 | for obs_type in ['image', 'ram']:
450 | # space_invaders should yield SpaceInvaders-v0 and SpaceInvaders-ram-v0
451 | name = ''.join([g.capitalize() for g in game.split('_')])
452 | if obs_type == 'ram':
453 | name = '{}-ram'.format(name)
454 |
455 | nondeterministic = False
456 | if game == 'elevator_action' and obs_type == 'ram':
457 | # ElevatorAction-ram-v0 seems to yield slightly
458 | # non-deterministic observations about 10% of the time. We
459 | # should track this down eventually, but for now we just
460 | # mark it as nondeterministic.
461 | nondeterministic = True
462 |
463 | register(
464 | id='{}-v0'.format(name),
465 | entry_point='gym.envs.atari:AtariEnv',
466 | kwargs={'game': game, 'obs_type': obs_type, 'repeat_action_probability': 0.25},
467 | max_episode_steps=10000,
468 | nondeterministic=nondeterministic,
469 | )
470 |
471 | register(
472 | id='{}-v4'.format(name),
473 | entry_point='gym.envs.atari:AtariEnv',
474 | kwargs={'game': game, 'obs_type': obs_type},
475 | max_episode_steps=100000,
476 | nondeterministic=nondeterministic,
477 | )
478 |
479 | # Standard Deterministic (as in the original DeepMind paper)
480 | if game == 'space_invaders':
481 | frameskip = 3
482 | else:
483 | frameskip = 4
484 |
485 | # Use a deterministic frame skip.
486 | register(
487 | id='{}Deterministic-v0'.format(name),
488 | entry_point='gym.envs.atari:AtariEnv',
489 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip, 'repeat_action_probability': 0.25},
490 | max_episode_steps=100000,
491 | nondeterministic=nondeterministic,
492 | )
493 |
494 | register(
495 | id='{}Deterministic-v4'.format(name),
496 | entry_point='gym.envs.atari:AtariEnv',
497 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip},
498 | max_episode_steps=100000,
499 | nondeterministic=nondeterministic,
500 | )
501 |
502 | register(
503 | id='{}NoFrameskip-v0'.format(name),
504 | entry_point='gym.envs.atari:AtariEnv',
505 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1, 'repeat_action_probability': 0.25}, # A frameskip of 1 means we get every frame
506 | max_episode_steps=frameskip * 100000,
507 | nondeterministic=nondeterministic,
508 | )
509 |
510 | # No frameskip. (Atari has no entropy source, so these are
511 | # deterministic environments.)
512 | register(
513 | id='{}NoFrameskip-v4'.format(name),
514 | entry_point='gym.envs.atari:AtariEnv',
515 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1}, # A frameskip of 1 means we get every frame
516 | max_episode_steps=frameskip * 100000,
517 | nondeterministic=nondeterministic,
518 | )
519 |
520 |
521 |
522 |
523 |
524 | # Unit test
525 | # ---------
526 |
527 | register(
528 | id='CubeCrash-v0',
529 | entry_point='gym.envs.unittest:CubeCrash',
530 | reward_threshold=0.9,
531 | )
532 | register(
533 | id='CubeCrashSparse-v0',
534 | entry_point='gym.envs.unittest:CubeCrashSparse',
535 | reward_threshold=0.9,
536 | )
537 | register(
538 | id='CubeCrashScreenBecomesBlack-v0',
539 | entry_point='gym.envs.unittest:CubeCrashScreenBecomesBlack',
540 | reward_threshold=0.9,
541 | )
542 |
543 | register(
544 | id='MemorizeDigits-v0',
545 | entry_point='gym.envs.unittest:MemorizeDigits',
546 | reward_threshold=20,
547 | )
548 |
549 |
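
The Atari block above derives several environment IDs from each snake_case game name. A quick sketch of the naming rule for one game:

    game = 'space_invaders'
    name = ''.join(g.capitalize() for g in game.split('_'))  # 'SpaceInvaders'
    # Variants registered above include:
    #   SpaceInvaders-v0, SpaceInvaders-ram-v0,
    #   SpaceInvadersDeterministic-v0/-v4, SpaceInvadersNoFrameskip-v0/-v4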
--------------------------------------------------------------------------------
/model/DDPG_trained_model_0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_0
--------------------------------------------------------------------------------
/model/DDPG_trained_model_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_1
--------------------------------------------------------------------------------
/model/DDPG_trained_model_2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_2
--------------------------------------------------------------------------------
/model/DDPG_trained_model_3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_3
--------------------------------------------------------------------------------
/model/DDPG_trained_model_4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_4
--------------------------------------------------------------------------------
/model/DDPG_trained_model_5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_5
--------------------------------------------------------------------------------
/model/DDPG_trained_model_6:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_6
--------------------------------------------------------------------------------
/model/DDPG_trained_model_7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_7
--------------------------------------------------------------------------------
/model/DDPG_trained_model_8:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_8
--------------------------------------------------------------------------------
/model/DDPG_trained_model_9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_9
--------------------------------------------------------------------------------
/model/best_lstm_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/best_lstm_model.h5
--------------------------------------------------------------------------------
/results/model_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_1.jpg
--------------------------------------------------------------------------------
/results/model_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_10.jpg
--------------------------------------------------------------------------------
/results/model_10_rerun.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_10_rerun.jpg
--------------------------------------------------------------------------------
/results/model_11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_11.jpg
--------------------------------------------------------------------------------
/results/model_11_rerun.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_11_rerun.jpg
--------------------------------------------------------------------------------
/results/model_12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_12.jpg
--------------------------------------------------------------------------------
/results/model_13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_13.jpg
--------------------------------------------------------------------------------
/results/model_14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_14.jpg
--------------------------------------------------------------------------------
/results/model_15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_15.jpg
--------------------------------------------------------------------------------
/results/model_16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_16.jpg
--------------------------------------------------------------------------------
/results/model_17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_17.jpg
--------------------------------------------------------------------------------
/results/model_18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_18.jpg
--------------------------------------------------------------------------------
/results/model_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_2.jpg
--------------------------------------------------------------------------------
/results/model_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_3.jpg
--------------------------------------------------------------------------------
/results/model_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_4.jpg
--------------------------------------------------------------------------------
/results/model_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_5.jpg
--------------------------------------------------------------------------------
/results/model_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_6.jpg
--------------------------------------------------------------------------------
/results/model_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_7.jpg
--------------------------------------------------------------------------------
/results/model_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_8.jpg
--------------------------------------------------------------------------------
/results/model_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_9.jpg
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import multiprocessing
3 | import os.path as osp
4 | import gym
5 | from collections import defaultdict
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
10 | from baselines.common.vec_env.vec_frame_stack import VecFrameStack
11 | from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
12 | from baselines.common.tf_util import get_session
13 | from baselines import logger
14 | from importlib import import_module
15 |
16 | from baselines.common.vec_env.vec_normalize import VecNormalize
17 |
18 | try:
19 | from mpi4py import MPI
20 | except ImportError:
21 | MPI = None
22 |
23 | try:
24 | import pybullet_envs
25 | except ImportError:
26 | pybullet_envs = None
27 |
28 | try:
29 | import roboschool
30 | except ImportError:
31 | roboschool = None
32 |
33 | _game_envs = defaultdict(set)
34 | for env in gym.envs.registry.all():
35 | # TODO: solve this with regexes
36 | env_type = env._entry_point.split(':')[0].split('.')[-1]
37 | _game_envs[env_type].add(env.id)
38 |
39 | # Reading benchmark names directly from retro requires
40 | # importing retro here, and for some reason that crashes TensorFlow
41 | # on Ubuntu.
42 | _game_envs['retro'] = {
43 | 'BubbleBobble-Nes',
44 | 'SuperMarioBros-Nes',
45 | 'TwinBee3PokoPokoDaimaou-Nes',
46 | 'SpaceHarrier-Nes',
47 | 'SonicTheHedgehog-Genesis',
48 | 'Vectorman-Genesis',
49 | 'FinalFight-Snes',
50 | 'SpaceInvaders-Snes',
51 | }
52 |
53 |
54 | def train(args, extra_args):
55 | env_type, env_id = get_env_type(args.env)
56 | print('env_type: {}'.format(env_type))
57 |
58 | total_timesteps = int(args.num_timesteps)
59 | seed = args.seed
60 |
61 | learn = get_learn_function(args.alg)
62 | alg_kwargs = get_learn_function_defaults(args.alg, env_type)
63 | alg_kwargs.update(extra_args)
64 |
65 | env = build_env(args)
66 | if args.save_video_interval != 0:
67 | env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"),
68 | record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
69 |
70 | if args.network:
71 | alg_kwargs['network'] = args.network
72 | else:
73 | if alg_kwargs.get('network') is None:
74 | alg_kwargs['network'] = get_default_network(env_type)
75 |
76 | print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
77 |
78 | model = learn(
79 | env=env,
80 | seed=seed,
81 | total_timesteps=total_timesteps,
82 | **alg_kwargs
83 | )
84 |
85 | return model, env
86 |
87 |
88 | def build_env(args):
89 | ncpu = multiprocessing.cpu_count()
90 | if sys.platform == 'darwin': ncpu //= 2
91 | nenv = args.num_env or ncpu
92 | alg = args.alg
93 | seed = args.seed
94 |
95 | env_type, env_id = get_env_type(args.env)
96 |
97 | if env_type in {'atari', 'retro'}:
98 | if alg == 'deepq':
99 | env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
100 | elif alg == 'trpo_mpi':
101 | env = make_env(env_id, env_type, seed=seed)
102 | else:
103 | frame_stack_size = 4
104 | env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
105 | env = VecFrameStack(env, frame_stack_size)
106 |
107 | else:
108 | config = tf.ConfigProto(allow_soft_placement=True,
109 | intra_op_parallelism_threads=1,
110 | inter_op_parallelism_threads=1)
111 | config.gpu_options.allow_growth = True
112 | get_session(config=config)
113 |
114 | flatten_dict_observations = alg not in {'her'}
115 | env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)
116 |
117 | if env_type == 'mujoco':
118 | env = VecNormalize(env)
119 |
120 | return env
121 |
122 |
123 | def get_env_type(env_id):
124 | # Re-parse the gym registry, since we could have new envs since last time.
125 | for env in gym.envs.registry.all():
126 | env_type = env._entry_point.split(':')[0].split('.')[-1]
127 | _game_envs[env_type].add(env.id) # This is a set so add is idempotent
128 |
129 | if env_id in _game_envs.keys():
130 | env_type = env_id
131 |         env_id = [g for g in _game_envs[env_type]][0]  # sets are unordered: picks an arbitrary env of this type
132 | else:
133 | env_type = None
134 | for g, e in _game_envs.items():
135 | if env_id in e:
136 | env_type = g
137 | break
138 |     assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, _game_envs.keys())
139 |
140 | return env_type, env_id
141 |
142 |
143 | def get_default_network(env_type):
144 | if env_type in {'atari', 'retro'}:
145 | return 'cnn'
146 | else:
147 | return 'mlp'
148 |
149 | def get_alg_module(alg, submodule=None):
150 | submodule = submodule or alg
151 | try:
152 | # first try to import the alg module from baselines
153 | alg_module = import_module('.'.join(['baselines', alg, submodule]))
154 | except ImportError:
155 | # then from rl_algs
156 | alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
157 |
158 | return alg_module
159 |
160 |
161 | def get_learn_function(alg):
162 | return get_alg_module(alg).learn
163 |
164 |
165 | def get_learn_function_defaults(alg, env_type):
166 | try:
167 | alg_defaults = get_alg_module(alg, 'defaults')
168 | kwargs = getattr(alg_defaults, env_type)()
169 | except (ImportError, AttributeError):
170 | kwargs = {}
171 | return kwargs
172 |
173 |
174 |
175 | def parse_cmdline_kwargs(args):
176 | '''
177 |     Convert a list of '='-separated command-line arguments to a dictionary, evaluating Python objects when possible.
178 | '''
179 | def parse(v):
180 |
181 | assert isinstance(v, str)
182 | try:
183 | return eval(v)
184 | except (NameError, SyntaxError):
185 | return v
186 |
187 |     return {k: parse(v) for k, v in parse_unknown_args(args).items()}
188 |
189 |
190 |
191 | def main(args):
192 | # configure logger, disable logging in child MPI processes (with rank > 0)
193 |
194 | arg_parser = common_arg_parser()
195 | args, unknown_args = arg_parser.parse_known_args(args)
196 | extra_args = parse_cmdline_kwargs(unknown_args)
197 |
198 | if args.extra_import is not None:
199 | import_module(args.extra_import)
200 |
201 | if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
202 | rank = 0
203 | logger.configure()
204 | else:
205 | logger.configure(format_strs=[])
206 | rank = MPI.COMM_WORLD.Get_rank()
207 |
208 |     # Train the model with the parsed arguments
209 | model, env = train(args, extra_args)
210 | env.close()
211 |
212 | if args.save_path is not None and rank == 0:
213 | save_path = osp.expanduser(args.save_path)
214 | model.save(save_path)
215 |         # saver = tf.train.Saver()  # unused: model.save() above already persists the model
216 |
217 |         #logger.info("saving the trained model")
218 |         #start_time_save = time.time()
219 |         #saver.save(sess, save_path + "ddpg_test_model")  # note: `sess` was never defined in this scope
220 |         #logger.info('runtime saving: {}s'.format(time.time() - start_time_save))
221 |
222 |     # If --play is set, run the trained model in the environment
223 | if args.play:
224 | logger.log("Running trained model")
225 | env = build_env(args)
226 | obs = env.reset()
227 |
228 | state = model.initial_state if hasattr(model, 'initial_state') else None
229 | dones = np.zeros((1,))
230 |
231 | while True:
232 | if state is not None:
233 |                 actions, _, state, _ = model.step(obs, S=state, M=dones)
234 | else:
235 | actions, _, _, _ = model.step(obs)
236 |
237 | obs, _, done, _ = env.step(actions)
238 | env.render()
239 | done = done.any() if isinstance(done, np.ndarray) else done
240 |
241 | if done:
242 | obs = env.reset()
243 |
244 | env.close()
245 |
246 | return model
247 |
248 | if __name__ == '__main__':
249 | main(sys.argv)
250 |
--------------------------------------------------------------------------------
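A note on run.py above: it is the standard OpenAI baselines entry point. It parses the flags known to common_arg_parser(), forwards any extra '--key=value' pairs to the chosen algorithm's learn() function through parse_cmdline_kwargs(), resolves --alg=<name> to baselines.<name>.<name>.learn in get_alg_module(), builds the (possibly vectorized and normalized) environment, and replays the trained model when --play is given. A minimal invocation sketch follows; the env id 'StarTrader-v0' is an assumption inferred from the gym/envs/StarTrader package in this repo, and all flags shown are standard baselines flags:

# Shell form (a small --num_timesteps keeps this a quick smoke test):
#   python run.py --alg=ddpg --env=StarTrader-v0 --num_timesteps=2e4

# Programmatic form: main() returns the trained model. Unknown flags such as
# --actor_lr are evaluated and passed as keyword arguments to learn();
# actor_lr is a parameter of upstream baselines DDPG, shown here illustratively.
import run

model = run.main(['--alg=ddpg', '--env=StarTrader-v0',
                  '--num_timesteps=1e4', '--actor_lr=1e-4'])

Since main() only returns after training completes (and loops indefinitely under --play), keep timestep counts small when experimenting.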
/test_iteration_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_iteration_1.gif
--------------------------------------------------------------------------------
/test_result/asset_evolution_test_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/asset_evolution_test_1.png
--------------------------------------------------------------------------------
/test_result/cummulative_returns_test_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/cummulative_returns_test_1.png
--------------------------------------------------------------------------------
/test_result/efficient_frontier.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/efficient_frontier.png
--------------------------------------------------------------------------------
/test_result/kpi_backtest.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,0.384797258876,49.2156862745098,3.44328458684,-1.28,-2.4,0.0178496358163,1.48968700183,0.577476288178,0.881134359686
3 |
--------------------------------------------------------------------------------
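The KPI columns in kpi_backtest.csv above, and in the kpi_* files that follow, are standard portfolio statistics, apparently quoted in percent where applicable ('dd' abbreviates drawdown; 'Pos months pct' is the share of months with a positive return). The exact formulas this repo uses are not visible in this dump; the sketch below uses textbook definitions, and the function name compute_kpis, the 252-day annualization, and the zero risk-free rate are all assumptions:

import numpy as np
import pandas as pd

def compute_kpis(daily_returns: pd.Series, periods_per_year: int = 252) -> dict:
    """Textbook portfolio KPIs from daily returns (index must be a DatetimeIndex)."""
    monthly = (1 + daily_returns).resample('M').prod() - 1   # compound daily into monthly returns
    equity = (1 + daily_returns).cumprod()                   # equity curve, starting at 1.0
    peak = equity.cummax()
    drawdown = equity / peak - 1                             # running peak-to-trough loss
    downside = daily_returns[daily_returns < 0]
    return {
        'Avg. monthly return': 100 * monthly.mean(),
        'Pos months pct': 100 * (monthly > 0).mean(),
        'Avg yearly return': 100 * ((1 + monthly.mean()) ** 12 - 1),
        'Max monthly dd': 100 * monthly.min(),               # worst single calendar month
        'Max drawdown': 100 * drawdown.min(),
        # Seykota's lake ratio: area "underwater" relative to area under the equity curve.
        'Lake ratio': float((peak - equity).sum() / equity.sum()),
        # Schwager's gain-to-pain: sum of monthly returns over the total of losing months.
        'Gain to Pain': float(monthly.sum() / monthly[monthly < 0].abs().sum()),
        # Risk-free rate omitted for simplicity in both ratios below.
        'Sharpe ratio': float(np.sqrt(periods_per_year) * daily_returns.mean() / daily_returns.std()),
        'Sortino ratio': float(np.sqrt(periods_per_year) * daily_returns.mean() / downside.std()),
    }

Read with these definitions, the backtest row above says: roughly a 0.38% average monthly return, positive in about 49% of months, a worst peak-to-trough drawdown of 2.4%, and a Sharpe ratio near 0.58.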
/test_result/kpi_test_1.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,1.41869638196,57.84313725490197,3.8704579038,-2,-4.09,0.0395631134021,1.23491340742,0.94625848311,1.3355405784
3 |
--------------------------------------------------------------------------------
/test_result/portfolios_return.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/portfolios_return.png
--------------------------------------------------------------------------------
/test_result/portfolios_returns.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/portfolios_returns.png
--------------------------------------------------------------------------------
/test_result/portfolios_risk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/portfolios_risk.png
--------------------------------------------------------------------------------
/test_result/price_prediction_LSTM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/price_prediction_LSTM.png
--------------------------------------------------------------------------------
/train_iterations_9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_iterations_9.gif
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_1.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_10.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_11.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_2.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_3.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_4.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_5.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_6.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_7.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_8.png
--------------------------------------------------------------------------------
/train_result/asset_evolution_train_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_9.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_1.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_10.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_2.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_3.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_4.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_5.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_6.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_7.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_8.png
--------------------------------------------------------------------------------
/train_result/cummulative_returns_train_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_9.png
--------------------------------------------------------------------------------
/train_result/kpi_train_1.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,0.868741890631,53.47358121330724,10.2170228228,-2.17,-6.69,0.0940450425508,0.792966158356,0.558347516033,0.798649743083
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_10.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,1.32611976461,54.17156286721504,15.6440160667,-2.08,-6.38,0.0355723668945,0.204786346071,0.928934411628,0.173474724563
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_2.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,0.640332831824,51.2720156555773,7.02923228398,-2.51,-3.41,0.0517834512147,0.640076187631,0.428556037997,0.621309289651
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_3.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,0.418111638812,51.56555772994129,4.74311658088,-2.09,-8.16,0.107625342005,0.314120450446,0.192896687976,0.267347484606
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_4.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,0.806692096574,52.6908023483366,8.99711662931,-1.81,-8.23,0.0686801277596,0.843390655755,0.552771717154,0.796880009409
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_5.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,1.06979591874,54.109589041095894,11.9608325657,-1.84,-7.67,0.0846979767109,1.05175799254,0.693293494501,1.0105019238
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_6.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,1.1403025973,54.59882583170255,12.8482034616,-1.74,-6.88,0.0697130237698,1.23040299427,0.77590085894,1.12930357835
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_7.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,1.53304467069,55.283757338551865,17.2352833671,-1.69,-8.2,0.0704577793934,1.49385458682,0.916896826069,1.34057729725
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_8.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,2.2381635965,54.01174168297456,26.6304669876,-7.47,-3.71,0.0352518826338,2.49573425564,1.2167193459,1.81690763343
3 |
--------------------------------------------------------------------------------
/train_result/kpi_train_9.csv:
--------------------------------------------------------------------------------
1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio
2 | KPI,1.90387965587,52.98434442270059,22.3442318826,-8.2,-4.4,0.0419270101037,1.73680487363,0.962348814721,1.4271890396
3 |
--------------------------------------------------------------------------------