├── LICENSE ├── README.md ├── baselines └── baselines │ └── ddpg │ ├── README.md │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── ddpg.cpython-36.pyc │ ├── ddpg_learner.cpython-36.pyc │ ├── memory.cpython-36.pyc │ ├── models.cpython-36.pyc │ └── noise.cpython-36.pyc │ ├── ddpg.py │ ├── ddpg_learner.py │ ├── memory.py │ ├── models.py │ ├── noise.py │ └── test_smoke.py ├── compare.py ├── data ├── AAPL.csv ├── AXP.csv ├── BA.csv ├── CAT.csv ├── CSCO.csv ├── CVX.csv ├── DIS.csv ├── GE.csv ├── GLD.csv ├── GS.csv ├── HD.csv ├── IBM.csv ├── INTC.csv ├── JNJ.csv ├── JPM.csv ├── KO.csv ├── MCD.csv ├── MMM.csv ├── MRK.csv ├── MSFT.csv ├── NKE.csv ├── PFE.csv ├── PG.csv ├── QQQ.csv ├── SHV.csv ├── SHY.csv ├── SPY.csv ├── UNH.csv ├── UTX.csv ├── VZ.csv ├── WMT.csv ├── XOM.csv ├── ^DJI.csv ├── ^GSPC.csv ├── ^IXIC.csv ├── ^RUT.csv ├── ^TNX.csv ├── ^TYX.csv ├── ^VIX.csv ├── ddpg_input_states.csv └── ddpg_stock_price.csv ├── data_preprocessing.py ├── feature_select.py ├── gym └── envs │ ├── StarTrader │ ├── StarTrade_env.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── StarTrade_env.cpython-36.pyc │ │ ├── StarTrade_test_env.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── intelliTrade_env.cpython-36.pyc │ └── data │ │ ├── AAPL.csv │ │ ├── AXP.csv │ │ ├── BA.csv │ │ ├── CAT.csv │ │ ├── CSCO.csv │ │ ├── CVX.csv │ │ ├── DIS.csv │ │ ├── GE.csv │ │ ├── GS.csv │ │ ├── HD.csv │ │ ├── IBM.csv │ │ ├── INTC.csv │ │ ├── JNJ.csv │ │ ├── JPM.csv │ │ ├── KO.csv │ │ ├── MCD.csv │ │ ├── MMM.csv │ │ ├── MRK.csv │ │ ├── MSFT.csv │ │ ├── NKE.csv │ │ ├── PFE.csv │ │ ├── PG.csv │ │ ├── UNH.csv │ │ ├── UTX.csv │ │ ├── VZ.csv │ │ ├── WMT.csv │ │ └── XOM.csv │ ├── StarTraderTest │ ├── StarTrade_test_env.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── StarTrade_env.cpython-36.pyc │ │ ├── StarTrade_test_env.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── intelliTrade_env.cpython-36.pyc │ └── data │ │ ├── AAPL.csv │ │ ├── AXP.csv │ │ ├── BA.csv │ │ ├── CAT.csv │ │ ├── CSCO.csv │ │ ├── CVX.csv │ │ ├── DIS.csv │ │ ├── GE.csv │ │ ├── GS.csv │ │ ├── HD.csv │ │ ├── IBM.csv │ │ ├── INTC.csv │ │ ├── JNJ.csv │ │ ├── JPM.csv │ │ ├── KO.csv │ │ ├── MCD.csv │ │ ├── MMM.csv │ │ ├── MRK.csv │ │ ├── MSFT.csv │ │ ├── NKE.csv │ │ ├── PFE.csv │ │ ├── PG.csv │ │ ├── UNH.csv │ │ ├── UTX.csv │ │ ├── VZ.csv │ │ ├── WMT.csv │ │ └── XOM.csv │ └── __init__.py ├── model ├── DDPG_trained_model_0 ├── DDPG_trained_model_1 ├── DDPG_trained_model_2 ├── DDPG_trained_model_3 ├── DDPG_trained_model_4 ├── DDPG_trained_model_5 ├── DDPG_trained_model_6 ├── DDPG_trained_model_7 ├── DDPG_trained_model_8 ├── DDPG_trained_model_9 └── best_lstm_model.h5 ├── results ├── model_1.jpg ├── model_10.jpg ├── model_10_rerun.jpg ├── model_11.jpg ├── model_11_rerun.jpg ├── model_12.jpg ├── model_13.jpg ├── model_14.jpg ├── model_15.jpg ├── model_16.jpg ├── model_17.jpg ├── model_18.jpg ├── model_2.jpg ├── model_3.jpg ├── model_4.jpg ├── model_5.jpg ├── model_6.jpg ├── model_7.jpg ├── model_8.jpg └── model_9.jpg ├── run.py ├── test_iteration_1.gif ├── test_result ├── all_strategies_cum_returns.csv ├── all_strategies_returns.csv ├── asset_evolution_test_1.png ├── cummulative_returns_test_1.png ├── efficient_frontier.png ├── kpi_backtest.csv ├── kpi_test_1.csv ├── portfolios_return.png ├── portfolios_returns.png ├── portfolios_risk.png ├── price_prediction_LSTM.png ├── trading_book_backtest.csv └── trading_book_test_1.csv ├── train_iterations_9.gif └── train_result ├── asset_evolution_train_1.png ├── asset_evolution_train_10.png ├── asset_evolution_train_11.png ├── 
asset_evolution_train_2.png ├── asset_evolution_train_3.png ├── asset_evolution_train_4.png ├── asset_evolution_train_5.png ├── asset_evolution_train_6.png ├── asset_evolution_train_7.png ├── asset_evolution_train_8.png ├── asset_evolution_train_9.png ├── cummulative_returns_train_1.png ├── cummulative_returns_train_10.png ├── cummulative_returns_train_2.png ├── cummulative_returns_train_3.png ├── cummulative_returns_train_4.png ├── cummulative_returns_train_5.png ├── cummulative_returns_train_6.png ├── cummulative_returns_train_7.png ├── cummulative_returns_train_8.png ├── cummulative_returns_train_9.png ├── kpi_train_1.csv ├── kpi_train_10.csv ├── kpi_train_2.csv ├── kpi_train_3.csv ├── kpi_train_4.csv ├── kpi_train_5.csv ├── kpi_train_6.csv ├── kpi_train_7.csv ├── kpi_train_8.csv ├── kpi_train_9.csv ├── trading_book_train_1.csv ├── trading_book_train_10.csv ├── trading_book_train_11.csv ├── trading_book_train_2.csv ├── trading_book_train_3.csv ├── trading_book_train_4.csv ├── trading_book_train_5.csv ├── trading_book_train_6.csv ├── trading_book_train_7.csv ├── trading_book_train_8.csv └── trading_book_train_9.csv /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jiew Wan Tan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [//]: # (Image References) 2 | 3 | [image1]: https://github.com/jiewwantan/StarTrader/blob/master/train_iterations_9.gif "Training iterations" 4 | [image2]: https://github.com/jiewwantan/StarTrader/blob/master/test_iteration_1.gif "Testing trained model with one iteration" 5 | [image3]: https://github.com/jiewwantan/StarTrader/blob/master/test_result/portfolios_return.png "Trading strategy performance returns comparison" 6 | [image4]: https://github.com/jiewwantan/StarTrader/blob/master/test_result/portfolios_risk.png "Trading strategy performance risk comparison" 7 | 8 | # **StarTrader:**
Intelligent Trading Agent Development
with Deep Reinforcement Learning
9 | 
10 | ### Introduction
11 | 
12 | This project sets out to create an intelligent trading agent and a trading environment that provides an ideal learning ground. A real-world trading environment is complex, with stocks, related instruments, macroeconomic indicators, news and possibly alternative data to consider. An effective agent must derive efficient representations of the environment from high-dimensional input and generalize past experience to new situations. The project adopts a deep reinforcement learning algorithm, deep deterministic policy gradient (DDPG), to trade a portfolio of five stocks. Different reward systems and hyperparameters were tried. Its performance is compared to models built with a recurrent neural network, modern portfolio theory, a simple buy-and-hold strategy and the benchmark DJIA index. The agent and environment are then evaluated to deliberate on possible improvements and on the agent's potential to beat professional human traders, much like DeepMind's Alpha series of intelligent game-playing agents.
13 | 
14 | The trading agent learns and trades in an [OpenAI Gym](https://gym.openai.com/) environment. Two Gym environments are created to serve the purpose: one for training (StarTrader-v0) and one for testing (StarTraderTest-v0). Both versions of StarTrader utilize the OpenAI Baselines implementation of deep deterministic policy gradient (DDPG).
15 | 
16 | A portfolio of five stocks (out of 27 Dow Jones Industrial Average stocks) is selected based on a non-correlation criterion. StarTrader trades these five non-correlated stocks with the goal of learning to maximize total assets (portfolio value + current account balance). During the trading process, StarTrader-v0 also optimizes the portfolio by deciding how many units of each of the five stocks to trade.
17 | 
18 | Based on the non-correlation criterion, a portfolio optimization algorithm has chosen the following five stocks to trade:
19 | 
20 | 1. American Express
21 | 2. Wal Mart
22 | 3. UnitedHealth Group
23 | 4. Apple
24 | 5. Verizon Communications
25 | 
26 | The preprocessing function creates technical features derived from each stock's OHLCV data. On average, roughly 6-8 time series are derived for each stock.
27 | 
28 | Apart from stock data, context data is also used to aid learning:
29 | 
30 | 1. S&P 500 index
31 | 2. Dow Jones Industrial Average index
32 | 3. NASDAQ Composite index
33 | 4. Russell 2000 index
34 | 5. SPDR S&P 500 ETF
35 | 6. Invesco QQQ Trust
36 | 7. CBOE Volatility Index
37 | 8. SPDR Gold Shares
38 | 9. Treasury Yield 30 Years
39 | 10. CBOE Interest Rate 10 Year T Note
40 | 11. iShares 1-3 Year Treasury Bond ETF
41 | 12. iShares Short Treasury Bond ETF
42 | 
43 | Similarly, technical features are derived from the above context data's OHLCV data. All data preprocessing is handled by two modules:
44 | 1. data_preprocessing.py
45 | 2. feature_select.py
46 | 
47 | The preprocessed data are then fed directly to StarTrader's trading environment: class StarTradingEnv.
48 | 
49 | The feature selection module (feature_select.py) selects about 6-8 features out of the 41 OHLCV and technical features for each instrument. In total, there are 121 features (the number may vary across machines, as the algorithm is not seeded), of which about 36 are stock features and the rest are context features.
50 | 
51 | When trading is executed, the 121 features, along with total assets, current asset holdings and unrealized profit and loss, form the complete state space for the agent to trade and learn from; a minimal sketch of how such a state vector might be assembled is shown below.
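As an illustration only (the actual layout is defined in StarTrade_env.py, which is not reproduced here), a state vector of this kind could be assembled as follows; all names are hypothetical:

```python
import numpy as np

def build_state(features, balance, holdings, prices, buy_prices):
    """Hypothetical sketch: concatenate market/context features with account info.

    features   : np.ndarray of ~121 selected stock + context features
    balance    : float, current cash balance
    holdings   : np.ndarray, units held for each of the 5 stocks
    prices     : np.ndarray, latest close price for each stock
    buy_prices : np.ndarray, average entry price for each stock
    """
    total_asset = balance + float(np.dot(holdings, prices))
    unrealized_pnl = (prices - buy_prices) * holdings  # per-stock open P&L
    return np.concatenate([features, [total_asset], holdings, unrealized_pnl])
```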
The state space is designed to let the agent get a sense of the instantaneous environment, in addition to how its interactions with the environment affect the future state space. In other words, the trading agent bears the fruits and consequences of its own actions.
52 | 
53 | ### Training agent on 9 iterations
54 | ![Training iterations][image1]
55 | 
56 | ### Testing agent on one iteration
57 | No learning or model refinement takes place here; this purely tests the trained model.
58 | The trading agent survived the major market correction in 2018 with a 1.13 Sharpe ratio (see the sketch below for how such risk-adjusted metrics are commonly computed).
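A minimal sketch of the Sharpe and Sortino calculations quoted here and in the comparison below, assuming daily returns and a 252-day annualization convention (the project's exact convention is not shown in this excerpt):

```python
import numpy as np
import pandas as pd

def sharpe_ratio(daily_returns, risk_free=0.0, periods=252):
    """Annualized Sharpe ratio from a series of daily returns."""
    excess = daily_returns - risk_free / periods
    return np.sqrt(periods) * excess.mean() / excess.std()

def sortino_ratio(daily_returns, risk_free=0.0, periods=252):
    """Annualized Sortino ratio: only downside volatility is penalized."""
    excess = daily_returns - risk_free / periods
    return np.sqrt(periods) * excess.mean() / excess[excess < 0].std()

# Example with the saved test results (the 'DDPG' column name is an assumption):
# returns = pd.read_csv('test_result/all_strategies_returns.csv', index_col=0)
# print(sharpe_ratio(returns['DDPG']), sortino_ratio(returns['DDPG']))
```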
59 | 
60 | ![Testing trained model with one iteration][image2]
61 | 
62 | ### Compare agent's performance with other trading strategies
63 | DDPG is the best performer in terms of cumulative returns. However, with a much less volatile ride, the RNN-LSTM model has the better risk-adjusted return: the highest Sharpe ratio (1.88) and Sortino ratio (3.06). Both the RNN-LSTM and DRL-DDPG trading strategies incorporate trading costs, namely commission (based on Interactive Brokers' fees) and slippage (modelled by Zipline and based on the stock's daily volume), since there are many transactions during the trading window. The buy-and-hold strategies' trading costs are omitted, since their stocks are only transacted once.
64 | DDPG's reward system should be modified to yield a higher risk-adjusted return.
65 | For a fair comparison, the LSTM model uses the same training data and a similar backtester as the DDPG model.
66 | 
67 | ![Trading strategy performance returns comparison][image3]
68 | ![Trading strategy performance risk comparison][image4]
69 | 
70 | 
71 | ## Prerequisites
72 | 
73 | Python 3.6 or Anaconda with a Python 3.6 environment
74 | Python packages: pandas, numpy, matplotlib, statsmodels, sklearn, tensorflow
75 | 
76 | The code was written on a Linux machine and has been tested on two operating systems:
77 | Linux Ubuntu 16.04 & Windows 10 Pro
78 | 
79 | 
80 | ## Installation instructions:
81 | 
82 | 1. Install the system packages CMake and OpenMPI (on Mac):
83 | 
84 | ```brew install cmake openmpi```
85 | 
86 | 2. Activate your environment and install Gym under it:
87 | 
88 | ```pip install gym```
89 | 
90 | 3. Download the official Baselines package
91 | 
92 | Clone the repo:
93 | 
94 | ```
95 | git clone https://github.com/openai/baselines.git
96 | 
97 | cd baselines
98 | 
99 | pip install -e .
100 | ```
101 | 
102 | 4. Install Tensorflow
103 | 
104 | There are several ways of installing Tensorflow; this page provides a good description of how it can be done, with system OS, Python version and GPU availability taken into consideration:
105 | 
106 | https://www.tensorflow.org/install/
107 | 
108 | In short, after environment activation, Tensorflow can be installed with these commands:
109 | 
110 | Tensorflow for CPU:
111 | ```pip3 install --upgrade tensorflow``` 112 | 113 | Tensorflow for GPU:
114 | ```pip3 install --upgrade tensorflow-gpu```
115 | 
116 | Installing Tensorflow GPU allows faster training if your machine has built-in NVIDIA GPU(s).
117 | However, the Tensorflow GPU version requires the installation of matching cuDNN and CUDA versions; these pages provide instructions to ensure the right versions are installed:
118 | 
119 | [Ubuntu](https://www.tensorflow.org/install/install_linux)
120 | 
121 | [MacOS](https://www.tensorflow.org/install/install_mac) (Tensorflow 1.2 no longer provides GPU support for MacOS)
122 | 
123 | [Windows](https://www.tensorflow.org/install/install_windows)
124 | 
125 | 5. Place the StarTrader and StarTraderTest folders from this repository into your machine's OpenAI Gym environments folder:
126 | 
127 | gym/envs/
128 | 
129 | 6. Replace the ```__init__.py``` file in the following folder with the ```__init__.py``` provided in this repository:
130 | 
131 | ```gym/envs/__init__.py```
132 | 
133 | 7. Place run.py from the baselines folder into the folder where you want to execute run.py, for example:
134 | 
135 | From the Baselines installation:
136 | ```baselines/baselines/run.py``` 137 | 138 | To:
139 | ```run.py```
140 | 
141 | 8. Place the 'data' folder in the folder where run.py resides:
142 | 
143 | ```/data/```
144 | 
145 | 9. Replace ddpg.py from the Baselines installation with the ddpg.py in this repository:
146 | 
147 | In your machine's Baselines installation:
148 | ```baselines/baselines/ddpg/ddpg.py```
149 | 
150 | replace it with the ddpg.py in this repository:
151 | ```baselines/baselines/ddpg/ddpg.py```
152 | 
153 | 10. Replace ddpg_learner.py from the Baselines installation with the ddpg_learner.py in this repository:
154 | 
155 | In your machine's Baselines installation:
156 | ```baselines/baselines/ddpg/ddpg_learner.py```
157 | 
158 | replace it with the ddpg_learner.py in this repository:
159 | ```baselines/baselines/ddpg/ddpg_learner.py```
160 | 
161 | 11. Place feature_select.py and data_preprocessing.py from this repository into the same folder as run.py
162 | 
163 | 12. Place the following folders from this repository into the folder where your run.py resides:
164 | 
165 | ```/test_result/```
166 | ```/train_result/```
167 | ```/model/```
168 | 
169 | You do not need to include the folders' contents; they will be generated when the program executes. If contents are included, they will be overwritten when the program executes.
170 | 
171 | 13. From the folder where run.py resides, enter the following commands:
172 | 
173 | To train the agent:
174 | ```python -m run --alg=ddpg --env=StarTrader-v0 --network=mlp --num_timesteps=2e4```
175 | 
176 | To test the agent:
177 | ```python -m run --alg=ddpg --env=StarTraderTest-v0 --network=mlp --num_timesteps=2e3 --load_path='./model/DDPG_trained_model_8'```
178 | 
179 | If you have trained a better model, replace ```DDPG_trained_model_8``` with your new model.
180 | 
181 | After training and testing the agent successfully, pick the DDPG trading book for the first test run, which is saved as ./test_result/trading_book_test_1.csv, or modify the filename in compare.py.
182 | To compare the agent's performance with the benchmark index and other trading strategies:
183 | 
184 | ```python compare.py```
185 | 
186 | ## Special instructions:
187 | 1. Depending on your machine configuration, the following installations may be necessary:
188 | 
189 | ```pip3 install -U numpy```
190 | ```pip3 install opencv-python```
191 | ```pip3 install mujoco-py==0.5.7```
192 | ```pip3 install lockfile```
193 | 
194 | 2. The technical analysis library TA-Lib may be tricky to install on some machines. The following page is a handy guide:
195 | https://goldenjumper.wordpress.com/tag/ta-lib/
196 | 
197 | 3. Graphviz, which is required to plot the XGBoost tree diagram, can be installed with the following commands:
198 | Windows:
199 | ```conda install python-graphviz```
200 | Mac/Linux:
201 | ```conda install graphviz```
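As a quick post-installation sanity check, the custom environments can be instantiated directly (a minimal sketch; it assumes steps 5-6 were completed so the environments are registered with Gym, and that the 'data' folder from step 8 is in the working directory):

```python
import gym

# Both IDs are registered by the __init__.py replaced in step 6.
for env_id in ('StarTrader-v0', 'StarTraderTest-v0'):
    env = gym.make(env_id)
    print(env_id, env.observation_space, env.action_space)
    env.close()
```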
202 | 203 | -------------------------------------------------------------------------------- /baselines/baselines/ddpg/README.md: -------------------------------------------------------------------------------- 1 | # DDPG 2 | 3 | - Original paper: https://arxiv.org/abs/1509.02971 4 | - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/ 5 | - `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options. 6 | -------------------------------------------------------------------------------- /baselines/baselines/ddpg/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/baselines/ddpg/__pycache__/ddpg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/ddpg.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/baselines/ddpg/__pycache__/ddpg_learner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/ddpg_learner.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/baselines/ddpg/__pycache__/memory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/memory.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/baselines/ddpg/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/baselines/ddpg/__pycache__/noise.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/baselines/baselines/ddpg/__pycache__/noise.cpython-36.pyc -------------------------------------------------------------------------------- /baselines/baselines/ddpg/ddpg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import deque 4 | import pickle 5 | 6 | from baselines.ddpg.ddpg_learner import DDPG 7 | from baselines.ddpg.models import Actor, Critic 8 | from baselines.ddpg.memory import Memory 9 | from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise 10 | from baselines.common import set_global_seeds 11 | import baselines.common.tf_util as U 12 | 13 | from baselines import logger 14 | import numpy as np 15 | 16 | 
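# mpi4py is optional: if it is not installed, MPI is set to None below and training falls back to a single process.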
try: 17 | from mpi4py import MPI 18 | except ImportError: 19 | MPI = None 20 | 21 | def learn(network, env, 22 | seed=None, 23 | total_timesteps=None, 24 | nb_epochs=None, # with default settings, perform 1M steps total 25 | nb_epoch_cycles=20, 26 | nb_rollout_steps=100, 27 | reward_scale=1.0, 28 | render=False, 29 | render_eval=False, 30 | noise_type='adaptive-param_0.2', 31 | normalize_returns=False, 32 | normalize_observations=True, 33 | critic_l2_reg=1e-2, 34 | actor_lr=1e-4, 35 | critic_lr=3e-4, 36 | popart=False, 37 | gamma=0.99, 38 | clip_norm=None, 39 | nb_train_steps=50, # per epoch cycle and MPI worker, 40 | nb_eval_steps=100, 41 | batch_size=128, # per MPI worker 42 | tau=0.01, 43 | eval_env=None, 44 | param_noise_adaption_interval=50, 45 | load_path = None, 46 | save_path = './model/', 47 | **network_kwargs): 48 | 49 | print("Save PATH;{}".format(save_path)) 50 | print("Load PATH;{}".format(load_path)) 51 | set_global_seeds(seed) 52 | 53 | if total_timesteps is not None: 54 | assert nb_epochs is None 55 | nb_epochs = int(total_timesteps) // (nb_epoch_cycles * nb_rollout_steps) 56 | else: 57 | nb_epochs = 500 58 | 59 | if MPI is not None: 60 | rank = MPI.COMM_WORLD.Get_rank() 61 | else: 62 | rank = 0 63 | 64 | nb_actions = env.action_space.shape[-1] 65 | assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions. 66 | 67 | memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape) 68 | critic = Critic(network=network, **network_kwargs) 69 | actor = Actor(nb_actions, network=network, **network_kwargs) 70 | 71 | action_noise = None 72 | param_noise = None 73 | if noise_type is not None: 74 | for current_noise_type in noise_type.split(','): 75 | current_noise_type = current_noise_type.strip() 76 | if current_noise_type == 'none': 77 | pass 78 | elif 'adaptive-param' in current_noise_type: 79 | _, stddev = current_noise_type.split('_') 80 | param_noise = AdaptiveParamNoiseSpec(initial_stddev=float(stddev), desired_action_stddev=float(stddev)) 81 | elif 'normal' in current_noise_type: 82 | _, stddev = current_noise_type.split('_') 83 | action_noise = NormalActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions)) 84 | elif 'ou' in current_noise_type: 85 | _, stddev = current_noise_type.split('_') 86 | action_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(nb_actions), sigma=float(stddev) * np.ones(nb_actions)) 87 | else: 88 | raise RuntimeError('unknown noise type "{}"'.format(current_noise_type)) 89 | 90 | max_action = env.action_space.high 91 | logger.info('scaling actions by {} before executing in env'.format(max_action)) 92 | 93 | agent = DDPG(actor, critic, memory, env.observation_space.shape, env.action_space.shape, 94 | gamma=gamma, tau=tau, normalize_returns=normalize_returns, normalize_observations=normalize_observations, 95 | batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg, 96 | actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm, 97 | reward_scale=reward_scale) 98 | logger.info('Using agent with the following configuration:') 99 | logger.info(str(agent.__dict__.items())) 100 | 101 | eval_episode_rewards_history = deque(maxlen=100) 102 | episode_rewards_history = deque(maxlen=100) 103 | sess = U.get_session() 104 | # Prepare everything. 
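# initialize() runs the TF variable initializer, syncs the MpiAdam optimizers across workers
# and copies the online networks' weights into the target networks.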
105 | agent.initialize(sess) 106 | checkpoint_num = 0 107 | if load_path is not None: 108 | print("Loading model from ", load_path) 109 | agent.load(load_path) 110 | # Get the model number and assign a newer number, so that training done after loading this model will be saved 111 | # with a new number that reflects the model has been updated. 112 | checkpoint_num = int(os.path.split(load_path)[1].split("_",3)[-1]) + 1 113 | sess.graph.finalize() 114 | 115 | agent.reset() 116 | 117 | obs = env.reset() 118 | if eval_env is not None: 119 | eval_obs = eval_env.reset() 120 | nenvs = obs.shape[0] 121 | 122 | episode_reward = np.zeros(nenvs, dtype = np.float32) #vector 123 | episode_step = np.zeros(nenvs, dtype = int) # vector 124 | episodes = 0 #scalar 125 | t = 0 # scalar 126 | 127 | epoch = 0 128 | 129 | 130 | 131 | start_time = time.time() 132 | 133 | epoch_episode_rewards = [] 134 | epoch_episode_steps = [] 135 | epoch_actions = [] 136 | epoch_qs = [] 137 | epoch_episodes = 0 138 | if load_path is None: 139 | os.makedirs(save_path, exist_ok=True) 140 | for epoch in range(nb_epochs): 141 | for cycle in range(nb_epoch_cycles): 142 | # Perform rollouts. 143 | if nenvs > 1: 144 | # if simulating multiple envs in parallel, impossible to reset agent at the end of the episode in each 145 | # of the environments, so resetting here instead 146 | agent.reset() 147 | for t_rollout in range(nb_rollout_steps): 148 | # Predict next action. 149 | action, q, _, _ = agent.step(obs, apply_noise=True, compute_Q=True) 150 | 151 | # Execute next action. 152 | if rank == 0 and render: 153 | env.render() 154 | 155 | # max_action is of dimension A, whereas action is dimension (nenvs, A) - the multiplication gets broadcasted to the batch 156 | new_obs, r, done, info = env.step(max_action * action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1]) 157 | # note these outputs are batched from vecenv 158 | 159 | t += 1 160 | if rank == 0 and render: 161 | env.render() 162 | episode_reward += r 163 | episode_step += 1 164 | 165 | # Book-keeping. 166 | epoch_actions.append(action) 167 | epoch_qs.append(q) 168 | agent.store_transition(obs, action, r, new_obs, done) #the batched data will be unrolled in memory.py's append. 169 | 170 | obs = new_obs 171 | 172 | for d in range(len(done)): 173 | if done[d]: 174 | # Episode done. 175 | epoch_episode_rewards.append(episode_reward[d]) 176 | episode_rewards_history.append(episode_reward[d]) 177 | epoch_episode_steps.append(episode_step[d]) 178 | episode_reward[d] = 0. 179 | episode_step[d] = 0 180 | epoch_episodes += 1 181 | episodes += 1 182 | if nenvs == 1: 183 | agent.reset() 184 | 185 | 186 | 187 | # Train. 188 | epoch_actor_losses = [] 189 | epoch_critic_losses = [] 190 | epoch_adaptive_distances = [] 191 | for t_train in range(nb_train_steps): 192 | # Adapt param noise, if necessary. 193 | if memory.nb_entries >= batch_size and t_train % param_noise_adaption_interval == 0: 194 | distance = agent.adapt_param_noise() 195 | epoch_adaptive_distances.append(distance) 196 | 197 | cl, al = agent.train() 198 | epoch_critic_losses.append(cl) 199 | epoch_actor_losses.append(al) 200 | agent.update_target_net() 201 | 202 | # Evaluate. 
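# Evaluation rollouts run the policy with apply_noise=False, so they measure the
# deterministic policy rather than the exploration behaviour.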
203 | eval_episode_rewards = [] 204 | eval_qs = [] 205 | if eval_env is not None: 206 | nenvs_eval = eval_obs.shape[0] 207 | eval_episode_reward = np.zeros(nenvs_eval, dtype = np.float32) 208 | for t_rollout in range(nb_eval_steps): 209 | eval_action, eval_q, _, _ = agent.step(eval_obs, apply_noise=False, compute_Q=True) 210 | eval_obs, eval_r, eval_done, eval_info = eval_env.step(max_action * eval_action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1]) 211 | if render_eval: 212 | eval_env.render() 213 | eval_episode_reward += eval_r 214 | 215 | eval_qs.append(eval_q) 216 | for d in range(len(eval_done)): 217 | if eval_done[d]: 218 | eval_episode_rewards.append(eval_episode_reward[d]) 219 | eval_episode_rewards_history.append(eval_episode_reward[d]) 220 | eval_episode_reward[d] = 0.0 221 | 222 | if MPI is not None: 223 | mpi_size = MPI.COMM_WORLD.Get_size() 224 | else: 225 | mpi_size = 1 226 | 227 | # Log stats. 228 | # XXX shouldn't call np.mean on variable length lists 229 | duration = time.time() - start_time 230 | stats = agent.get_stats() 231 | combined_stats = stats.copy() 232 | combined_stats['rollout/return'] = np.mean(epoch_episode_rewards) 233 | combined_stats['rollout/return_history'] = np.mean(episode_rewards_history) 234 | combined_stats['rollout/episode_steps'] = np.mean(epoch_episode_steps) 235 | combined_stats['rollout/actions_mean'] = np.mean(epoch_actions) 236 | combined_stats['rollout/Q_mean'] = np.mean(epoch_qs) 237 | combined_stats['train/loss_actor'] = np.mean(epoch_actor_losses) 238 | combined_stats['train/loss_critic'] = np.mean(epoch_critic_losses) 239 | combined_stats['train/param_noise_distance'] = np.mean(epoch_adaptive_distances) 240 | combined_stats['total/duration'] = duration 241 | combined_stats['total/steps_per_second'] = float(t) / float(duration) 242 | combined_stats['total/episodes'] = episodes 243 | combined_stats['rollout/episodes'] = epoch_episodes 244 | combined_stats['rollout/actions_std'] = np.std(epoch_actions) 245 | # Evaluation statistics. 246 | if eval_env is not None: 247 | combined_stats['eval/return'] = eval_episode_rewards 248 | combined_stats['eval/return_history'] = np.mean(eval_episode_rewards_history) 249 | combined_stats['eval/Q'] = eval_qs 250 | combined_stats['eval/episodes'] = len(eval_episode_rewards) 251 | def as_scalar(x): 252 | if isinstance(x, np.ndarray): 253 | assert x.size == 1 254 | return x[0] 255 | elif np.isscalar(x): 256 | return x 257 | else: 258 | raise ValueError('expected scalar, got %s'%x) 259 | 260 | combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]) 261 | if MPI is not None: 262 | combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums) 263 | 264 | combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)} 265 | 266 | # Total statistics. 
267 | combined_stats['total/epochs'] = epoch + 1 268 | combined_stats['total/steps'] = t 269 | 270 | for key in sorted(combined_stats.keys()): 271 | logger.record_tabular(key, combined_stats[key]) 272 | 273 | if rank == 0: 274 | logger.dump_tabular() 275 | logger.info('') 276 | logdir = logger.get_dir() 277 | if rank == 0 and logdir: 278 | if hasattr(env, 'get_state'): 279 | with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as f: 280 | pickle.dump(env.get_state(), f) 281 | if eval_env and hasattr(eval_env, 'get_state'): 282 | with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f: 283 | pickle.dump(eval_env.get_state(), f) 284 | 285 | savepath = os.path.join(save_path, "DDPG_trained_model_" + str(epoch + checkpoint_num)) 286 | print('Model saved to ', savepath) 287 | print('\n') 288 | agent.save(savepath) 289 | 290 | 291 | return agent 292 | -------------------------------------------------------------------------------- /baselines/baselines/ddpg/ddpg_learner.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from functools import reduce 3 | 4 | import functools 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow.contrib as tc 8 | 9 | from baselines import logger 10 | from baselines.common.mpi_adam import MpiAdam 11 | import baselines.common.tf_util as U 12 | from baselines.common.mpi_running_mean_std import RunningMeanStd 13 | from baselines.common.tf_util import save_variables, load_variables 14 | try: 15 | from mpi4py import MPI 16 | except ImportError: 17 | MPI = None 18 | 19 | def normalize(x, stats): 20 | if stats is None: 21 | return x 22 | return (x - stats.mean) / stats.std 23 | 24 | 25 | def denormalize(x, stats): 26 | if stats is None: 27 | return x 28 | return x * stats.std + stats.mean 29 | 30 | def reduce_std(x, axis=None, keepdims=False): 31 | return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims)) 32 | 33 | def reduce_var(x, axis=None, keepdims=False): 34 | m = tf.reduce_mean(x, axis=axis, keepdims=True) 35 | devs_squared = tf.square(x - m) 36 | return tf.reduce_mean(devs_squared, axis=axis, keepdims=keepdims) 37 | 38 | def get_target_updates(vars, target_vars, tau): 39 | logger.info('setting up target updates ...') 40 | soft_updates = [] 41 | init_updates = [] 42 | assert len(vars) == len(target_vars) 43 | for var, target_var in zip(vars, target_vars): 44 | logger.info(' {} <- {}'.format(target_var.name, var.name)) 45 | init_updates.append(tf.assign(target_var, var)) 46 | soft_updates.append(tf.assign(target_var, (1. 
- tau) * target_var + tau * var)) 47 | assert len(init_updates) == len(vars) 48 | assert len(soft_updates) == len(vars) 49 | return tf.group(*init_updates), tf.group(*soft_updates) 50 | 51 | 52 | def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev): 53 | assert len(actor.vars) == len(perturbed_actor.vars) 54 | assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars) 55 | 56 | updates = [] 57 | for var, perturbed_var in zip(actor.vars, perturbed_actor.vars): 58 | if var in actor.perturbable_vars: 59 | logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name)) 60 | updates.append(tf.assign(perturbed_var, var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev))) 61 | else: 62 | logger.info(' {} <- {}'.format(perturbed_var.name, var.name)) 63 | updates.append(tf.assign(perturbed_var, var)) 64 | assert len(updates) == len(actor.vars) 65 | return tf.group(*updates) 66 | 67 | 68 | class DDPG(object): 69 | def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None, 70 | gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True, 71 | batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf), 72 | critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.): 73 | # Inputs. 74 | self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0') 75 | self.obs1 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs1') 76 | self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1') 77 | self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards') 78 | self.actions = tf.placeholder(tf.float32, shape=(None,) + action_shape, name='actions') 79 | self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target') 80 | self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev') 81 | 82 | # Parameters. 83 | self.gamma = gamma 84 | self.tau = tau 85 | self.memory = memory 86 | self.normalize_observations = normalize_observations 87 | self.normalize_returns = normalize_returns 88 | self.action_noise = action_noise 89 | self.param_noise = param_noise 90 | self.action_range = action_range 91 | self.return_range = return_range 92 | self.observation_range = observation_range 93 | self.critic = critic 94 | self.actor = actor 95 | self.actor_lr = actor_lr 96 | self.critic_lr = critic_lr 97 | self.clip_norm = clip_norm 98 | self.enable_popart = enable_popart 99 | self.reward_scale = reward_scale 100 | self.batch_size = batch_size 101 | self.stats_sample = None 102 | self.critic_l2_reg = critic_l2_reg 103 | self.save = None 104 | self.load = None 105 | 106 | # Observation normalization. 107 | if self.normalize_observations: 108 | with tf.variable_scope('obs_rms'): 109 | self.obs_rms = RunningMeanStd(shape=observation_shape) 110 | else: 111 | self.obs_rms = None 112 | normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms), 113 | self.observation_range[0], self.observation_range[1]) 114 | normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms), 115 | self.observation_range[0], self.observation_range[1]) 116 | 117 | # Return normalization. 118 | if self.normalize_returns: 119 | with tf.variable_scope('ret_rms'): 120 | self.ret_rms = RunningMeanStd() 121 | else: 122 | self.ret_rms = None 123 | 124 | # Create target networks. 
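# Target networks are slowly-updated copies of the actor and critic used to compute stable
# TD targets. They track the online networks via the soft update built in get_target_updates:
#   target <- (1 - tau) * target + tau * online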
125 | target_actor = copy(actor) 126 | target_actor.name = 'target_actor' 127 | self.target_actor = target_actor 128 | target_critic = copy(critic) 129 | target_critic.name = 'target_critic' 130 | self.target_critic = target_critic 131 | 132 | # Create networks and core TF parts that are shared across setup parts. 133 | self.actor_tf = actor(normalized_obs0) 134 | self.normalized_critic_tf = critic(normalized_obs0, self.actions) 135 | self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms) 136 | self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True) 137 | self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms) 138 | Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms) 139 | self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1 140 | 141 | # Set up parts. 142 | if self.param_noise is not None: 143 | self.setup_param_noise(normalized_obs0) 144 | self.setup_actor_optimizer() 145 | self.setup_critic_optimizer() 146 | if self.normalize_returns and self.enable_popart: 147 | self.setup_popart() 148 | self.setup_stats() 149 | self.setup_target_network_updates() 150 | 151 | self.initial_state = None # recurrent architectures not supported yet 152 | 153 | def setup_target_network_updates(self): 154 | actor_init_updates, actor_soft_updates = get_target_updates(self.actor.vars, self.target_actor.vars, self.tau) 155 | critic_init_updates, critic_soft_updates = get_target_updates(self.critic.vars, self.target_critic.vars, self.tau) 156 | self.target_init_updates = [actor_init_updates, critic_init_updates] 157 | self.target_soft_updates = [actor_soft_updates, critic_soft_updates] 158 | 159 | def setup_param_noise(self, normalized_obs0): 160 | assert self.param_noise is not None 161 | 162 | # Configure perturbed actor. 163 | param_noise_actor = copy(self.actor) 164 | param_noise_actor.name = 'param_noise_actor' 165 | self.perturbed_actor_tf = param_noise_actor(normalized_obs0) 166 | logger.info('setting up param noise') 167 | self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev) 168 | 169 | # Configure separate copy for stddev adoption. 
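# This second perturbed copy is never used for acting; it only measures how far parameter
# noise with the current stddev moves the policy in action space (adaptive_policy_distance),
# which AdaptiveParamNoiseSpec then uses to grow or shrink the stddev.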
170 | adaptive_param_noise_actor = copy(self.actor) 171 | adaptive_param_noise_actor.name = 'adaptive_param_noise_actor' 172 | adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0) 173 | self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev) 174 | self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf))) 175 | 176 | def setup_actor_optimizer(self): 177 | logger.info('setting up actor optimizer') 178 | self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf) 179 | actor_shapes = [var.get_shape().as_list() for var in self.actor.trainable_vars] 180 | actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes]) 181 | logger.info(' actor shapes: {}'.format(actor_shapes)) 182 | logger.info(' actor params: {}'.format(actor_nb_params)) 183 | self.actor_grads = U.flatgrad(self.actor_loss, self.actor.trainable_vars, clip_norm=self.clip_norm) 184 | self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars, 185 | beta1=0.9, beta2=0.999, epsilon=1e-08) 186 | 187 | def setup_critic_optimizer(self): 188 | logger.info('setting up critic optimizer') 189 | normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1]) 190 | self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf)) 191 | if self.critic_l2_reg > 0.: 192 | critic_reg_vars = [var for var in self.critic.trainable_vars if var.name.endswith('/w:0') and 'output' not in var.name] 193 | for var in critic_reg_vars: 194 | logger.info(' regularizing: {}'.format(var.name)) 195 | logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg)) 196 | critic_reg = tc.layers.apply_regularization( 197 | tc.layers.l2_regularizer(self.critic_l2_reg), 198 | weights_list=critic_reg_vars 199 | ) 200 | self.critic_loss += critic_reg 201 | critic_shapes = [var.get_shape().as_list() for var in self.critic.trainable_vars] 202 | critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes]) 203 | logger.info(' critic shapes: {}'.format(critic_shapes)) 204 | logger.info(' critic params: {}'.format(critic_nb_params)) 205 | self.critic_grads = U.flatgrad(self.critic_loss, self.critic.trainable_vars, clip_norm=self.clip_norm) 206 | self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars, 207 | beta1=0.9, beta2=0.999, epsilon=1e-08) 208 | 209 | def setup_popart(self): 210 | # See https://arxiv.org/pdf/1602.07714.pdf for details. 
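# Pop-Art: whenever the return statistics (ret_rms mean/std) change, the critic's linear
# output layer (kernel M, bias b) is rescaled so the denormalized Q predictions are
# preserved exactly:
#   M' = M * old_std / new_std
#   b' = (b * old_std + old_mean - new_mean) / new_std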
211 | self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std') 212 | new_std = self.ret_rms.std 213 | self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean') 214 | new_mean = self.ret_rms.mean 215 | 216 | self.renormalize_Q_outputs_op = [] 217 | for vs in [self.critic.output_vars, self.target_critic.output_vars]: 218 | assert len(vs) == 2 219 | M, b = vs 220 | assert 'kernel' in M.name 221 | assert 'bias' in b.name 222 | assert M.get_shape()[-1] == 1 223 | assert b.get_shape()[-1] == 1 224 | self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)] 225 | self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)] 226 | 227 | def setup_stats(self): 228 | ops = [] 229 | names = [] 230 | 231 | if self.normalize_returns: 232 | ops += [self.ret_rms.mean, self.ret_rms.std] 233 | names += ['ret_rms_mean', 'ret_rms_std'] 234 | 235 | if self.normalize_observations: 236 | ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)] 237 | names += ['obs_rms_mean', 'obs_rms_std'] 238 | 239 | ops += [tf.reduce_mean(self.critic_tf)] 240 | names += ['reference_Q_mean'] 241 | ops += [reduce_std(self.critic_tf)] 242 | names += ['reference_Q_std'] 243 | 244 | ops += [tf.reduce_mean(self.critic_with_actor_tf)] 245 | names += ['reference_actor_Q_mean'] 246 | ops += [reduce_std(self.critic_with_actor_tf)] 247 | names += ['reference_actor_Q_std'] 248 | 249 | ops += [tf.reduce_mean(self.actor_tf)] 250 | names += ['reference_action_mean'] 251 | ops += [reduce_std(self.actor_tf)] 252 | names += ['reference_action_std'] 253 | 254 | if self.param_noise: 255 | ops += [tf.reduce_mean(self.perturbed_actor_tf)] 256 | names += ['reference_perturbed_action_mean'] 257 | ops += [reduce_std(self.perturbed_actor_tf)] 258 | names += ['reference_perturbed_action_std'] 259 | 260 | self.stats_ops = ops 261 | self.stats_names = names 262 | 263 | def step(self, obs, apply_noise=True, compute_Q=True): 264 | if self.param_noise is not None and apply_noise: 265 | actor_tf = self.perturbed_actor_tf 266 | else: 267 | actor_tf = self.actor_tf 268 | feed_dict = {self.obs0: U.adjust_shape(self.obs0, [obs])} 269 | if compute_Q: 270 | action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict) 271 | else: 272 | action = self.sess.run(actor_tf, feed_dict=feed_dict) 273 | q = None 274 | 275 | if self.action_noise is not None and apply_noise: 276 | noise = self.action_noise() 277 | assert noise.shape == action[0].shape 278 | action += noise 279 | action = np.clip(action, self.action_range[0], self.action_range[1]) 280 | 281 | 282 | return action, q, None, None 283 | 284 | def store_transition(self, obs0, action, reward, obs1, terminal1): 285 | reward *= self.reward_scale 286 | 287 | B = obs0.shape[0] 288 | for b in range(B): 289 | self.memory.append(obs0[b], action[b], reward[b], obs1[b], terminal1[b]) 290 | if self.normalize_observations: 291 | self.obs_rms.update(np.array([obs0[b]])) 292 | 293 | def train(self): 294 | # Get a batch. 
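# One DDPG update: build TD targets y = r + gamma * (1 - terminal) * Q'(s', mu'(s')) from the
# target networks, then take a synchronized gradient step that minimizes the critic's squared
# TD error and maximizes Q(s, mu(s)) for the actor (the actor loss is -mean Q).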
295 | batch = self.memory.sample(batch_size=self.batch_size)
296 | 
297 | if self.normalize_returns and self.enable_popart:
298 | old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict={
299 | self.obs1: batch['obs1'],
300 | self.rewards: batch['rewards'],
301 | self.terminals1: batch['terminals1'].astype('float32'),
302 | })
303 | self.ret_rms.update(target_Q.flatten())
304 | self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
305 | self.old_std : np.array([old_std]),
306 | self.old_mean : np.array([old_mean]),
307 | })
308 | 
309 | # Run sanity check. Disabled by default since it slows down things considerably.
310 | # print('running sanity check')
311 | # target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
312 | # self.obs1: batch['obs1'],
313 | # self.rewards: batch['rewards'],
314 | # self.terminals1: batch['terminals1'].astype('float32'),
315 | # })
316 | # print(target_Q_new, target_Q, new_mean, new_std)
317 | # assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
318 | else:
319 | target_Q = self.sess.run(self.target_Q, feed_dict={
320 | self.obs1: batch['obs1'],
321 | self.rewards: batch['rewards'],
322 | self.terminals1: batch['terminals1'].astype('float32'),
323 | })
324 | 
325 | # Get all gradients and perform a synced update.
326 | ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
327 | actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={
328 | self.obs0: batch['obs0'],
329 | self.actions: batch['actions'],
330 | self.critic_target: target_Q,
331 | })
332 | self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
333 | self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)
334 | 
335 | return critic_loss, actor_loss
336 | 
337 | def initialize(self, sess):
338 | self.sess = sess
339 | self.sess.run(tf.global_variables_initializer())
340 | self.save = functools.partial(save_variables, sess=self.sess)
341 | self.load = functools.partial(load_variables, sess=self.sess)
342 | self.actor_optimizer.sync()
343 | self.critic_optimizer.sync()
344 | self.sess.run(self.target_init_updates)
345 | 
346 | def update_target_net(self):
347 | self.sess.run(self.target_soft_updates)
348 | 
349 | def get_stats(self):
350 | if self.stats_sample is None:
351 | # Get a sample and keep that fixed for all further computations.
352 | # This allows us to estimate the change in value for the same set of inputs.
353 | self.stats_sample = self.memory.sample(batch_size=self.batch_size)
354 | values = self.sess.run(self.stats_ops, feed_dict={
355 | self.obs0: self.stats_sample['obs0'],
356 | self.actions: self.stats_sample['actions'],
357 | })
358 | 
359 | names = self.stats_names[:]
360 | assert len(names) == len(values)
361 | stats = dict(zip(names, values))
362 | 
363 | if self.param_noise is not None:
364 | stats = {**stats, **self.param_noise.get_stats()}
365 | 
366 | return stats
367 | 
368 | def adapt_param_noise(self):
369 | try:
370 | from mpi4py import MPI
371 | except ImportError:
372 | MPI = None
373 | 
374 | if self.param_noise is None:
375 | return 0.
376 | 
377 | # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
378 | batch = self.memory.sample(batch_size=self.batch_size)
379 | self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
380 | self.param_noise_stddev: self.param_noise.current_stddev,
381 | })
382 | distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
383 | self.obs0: batch['obs0'],
384 | self.param_noise_stddev: self.param_noise.current_stddev,
385 | })
386 | 
387 | if MPI is not None:
388 | mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
389 | else:
390 | mean_distance = distance
391 | 
392 | self.param_noise.adapt(mean_distance)
393 | return mean_distance
394 | 
395 | def reset(self):
396 | # Reset internal state after an episode is complete.
397 | if self.action_noise is not None:
398 | self.action_noise.reset()
399 | if self.param_noise is not None:
400 | self.sess.run(self.perturb_policy_ops, feed_dict={
401 | self.param_noise_stddev: self.param_noise.current_stddev,
402 | }) -------------------------------------------------------------------------------- /baselines/baselines/ddpg/memory.py: -------------------------------------------------------------------------------- 
1 | import numpy as np
2 | 
3 | 
4 | class RingBuffer(object):
5 | def __init__(self, maxlen, shape, dtype='float32'):
6 | self.maxlen = maxlen
7 | self.start = 0
8 | self.length = 0
9 | self.data = np.zeros((maxlen,) + shape).astype(dtype)
10 | 
11 | def __len__(self):
12 | return self.length
13 | 
14 | def __getitem__(self, idx):
15 | if idx < 0 or idx >= self.length:
16 | raise KeyError()
17 | return self.data[(self.start + idx) % self.maxlen]
18 | 
19 | def get_batch(self, idxs):
20 | return self.data[(self.start + idxs) % self.maxlen]
21 | 
22 | def append(self, v):
23 | if self.length < self.maxlen:
24 | # We have space, simply increase the length.
25 | self.length += 1
26 | elif self.length == self.maxlen:
27 | # No space, "remove" the first item.
28 | self.start = (self.start + 1) % self.maxlen
29 | else:
30 | # This should never happen.
31 | raise RuntimeError()
32 | self.data[(self.start + self.length - 1) % self.maxlen] = v
33 | 
34 | 
35 | def array_min2d(x):
36 | x = np.array(x)
37 | if x.ndim >= 2:
38 | return x
39 | return x.reshape(-1, 1)
40 | 
41 | 
42 | class Memory(object):
43 | def __init__(self, limit, action_shape, observation_shape):
44 | self.limit = limit
45 | 
46 | self.observations0 = RingBuffer(limit, shape=observation_shape)
47 | self.actions = RingBuffer(limit, shape=action_shape)
48 | self.rewards = RingBuffer(limit, shape=(1,))
49 | self.terminals1 = RingBuffer(limit, shape=(1,))
50 | self.observations1 = RingBuffer(limit, shape=observation_shape)
51 | 
52 | def sample(self, batch_size):
53 | # Draw such that we always have a succeeding element.
54 | batch_idxs = np.random.randint(self.nb_entries - 2, size=batch_size) 55 | 56 | obs0_batch = self.observations0.get_batch(batch_idxs) 57 | obs1_batch = self.observations1.get_batch(batch_idxs) 58 | action_batch = self.actions.get_batch(batch_idxs) 59 | reward_batch = self.rewards.get_batch(batch_idxs) 60 | terminal1_batch = self.terminals1.get_batch(batch_idxs) 61 | 62 | result = { 63 | 'obs0': array_min2d(obs0_batch), 64 | 'obs1': array_min2d(obs1_batch), 65 | 'rewards': array_min2d(reward_batch), 66 | 'actions': array_min2d(action_batch), 67 | 'terminals1': array_min2d(terminal1_batch), 68 | } 69 | return result 70 | 71 | def append(self, obs0, action, reward, obs1, terminal1, training=True): 72 | if not training: 73 | return 74 | 75 | self.observations0.append(obs0) 76 | self.actions.append(action) 77 | self.rewards.append(reward) 78 | self.observations1.append(obs1) 79 | self.terminals1.append(terminal1) 80 | 81 | @property 82 | def nb_entries(self): 83 | return len(self.observations0) 84 | -------------------------------------------------------------------------------- /baselines/baselines/ddpg/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.common.models import get_network_builder 3 | 4 | 5 | class Model(object): 6 | def __init__(self, name, network='mlp', **network_kwargs): 7 | self.name = name 8 | self.network_builder = get_network_builder(network)(**network_kwargs) 9 | 10 | @property 11 | def vars(self): 12 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 13 | 14 | @property 15 | def trainable_vars(self): 16 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 17 | 18 | @property 19 | def perturbable_vars(self): 20 | return [var for var in self.trainable_vars if 'LayerNorm' not in var.name] 21 | 22 | 23 | class Actor(Model): 24 | def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs): 25 | super().__init__(name=name, network=network, **network_kwargs) 26 | self.nb_actions = nb_actions 27 | 28 | def __call__(self, obs, reuse=False): 29 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 30 | x = self.network_builder(obs) 31 | x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3)) 32 | x = tf.nn.tanh(x) 33 | return x 34 | 35 | 36 | class Critic(Model): 37 | def __init__(self, name='critic', network='mlp', **network_kwargs): 38 | super().__init__(name=name, network=network, **network_kwargs) 39 | self.layer_norm = True 40 | 41 | def __call__(self, obs, action, reuse=False): 42 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 43 | x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated 44 | x = self.network_builder(x) 45 | x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3), name='output') 46 | return x 47 | 48 | @property 49 | def output_vars(self): 50 | output_vars = [var for var in self.trainable_vars if 'output' in var.name] 51 | return output_vars 52 | -------------------------------------------------------------------------------- /baselines/baselines/ddpg/noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AdaptiveParamNoiseSpec(object): 5 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01): 6 | self.initial_stddev = 
initial_stddev 7 | self.desired_action_stddev = desired_action_stddev 8 | self.adoption_coefficient = adoption_coefficient 9 | 10 | self.current_stddev = initial_stddev 11 | 12 | def adapt(self, distance): 13 | if distance > self.desired_action_stddev: 14 | # Decrease stddev. 15 | self.current_stddev /= self.adoption_coefficient 16 | else: 17 | # Increase stddev. 18 | self.current_stddev *= self.adoption_coefficient 19 | 20 | def get_stats(self): 21 | stats = { 22 | 'param_noise_stddev': self.current_stddev, 23 | } 24 | return stats 25 | 26 | def __repr__(self): 27 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})' 28 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient) 29 | 30 | 31 | class ActionNoise(object): 32 | def reset(self): 33 | pass 34 | 35 | 36 | class NormalActionNoise(ActionNoise): 37 | def __init__(self, mu, sigma): 38 | self.mu = mu 39 | self.sigma = sigma 40 | 41 | def __call__(self): 42 | return np.random.normal(self.mu, self.sigma) 43 | 44 | def __repr__(self): 45 | return 'NormalActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma) 46 | 47 | 48 | # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 49 | class OrnsteinUhlenbeckActionNoise(ActionNoise): 50 | def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None): 51 | self.theta = theta 52 | self.mu = mu 53 | self.sigma = sigma 54 | self.dt = dt 55 | self.x0 = x0 56 | self.reset() 57 | 58 | def __call__(self): 59 | x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape) 60 | self.x_prev = x 61 | return x 62 | 63 | def reset(self): 64 | self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu) 65 | 66 | def __repr__(self): 67 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma) 68 | -------------------------------------------------------------------------------- /baselines/baselines/ddpg/test_smoke.py: -------------------------------------------------------------------------------- 1 | from baselines.run import main as M 2 | 3 | def _run(argstr): 4 | M(('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' ')) 5 | 6 | def test_popart(): 7 | _run('--normalize_returns=True --popart=True') 8 | 9 | def test_noise_normal(): 10 | _run('--noise_type=normal_0.1') 11 | 12 | def test_noise_ou(): 13 | _run('--noise_type=ou_0.1') 14 | 15 | def test_noise_adaptive(): 16 | _run('--noise_type=adaptive-param_0.2,normal_0.1') 17 | 18 | -------------------------------------------------------------------------------- /compare.py: -------------------------------------------------------------------------------- 1 | # --------------------------- IMPORT LIBRARIES ------------------------- 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | from datetime import datetime 6 | import data_preprocessing as dp 7 | from sklearn.preprocessing import MinMaxScaler 8 | 9 | import keras 10 | from keras.models import Sequential 11 | from keras.layers.recurrent import LSTM 12 | from keras.callbacks import ModelCheckpoint, EarlyStopping 13 | from keras.models import load_model 14 | from keras.layers import Dense, Dropout 15 | 16 | # ------------------------- GLOBAL PARAMETERS ------------------------- 17 | # Start and end period of historical data in question 18 | START_TRAIN = datetime(2008, 12, 31) 19 | END_TRAIN = datetime(2017, 2, 12) 
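# Note: the test window starts on the same date the training window ends, so the two periods are contiguous with no gap.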
20 | START_TEST = datetime(2017, 2, 12)
21 | END_TEST = datetime(2019, 2, 22)
22 |
23 | STARTING_ACC_BALANCE = 100000
24 | NUMBER_NON_CORR_STOCKS = 5
25 | # Number of epochs with no improvement before training is stopped.
26 | PATIENCE = 30
27 |
28 | # Pools of stocks to trade
29 | DJI = ['MMM', 'AXP', 'AAPL', 'BA', 'CAT', 'CVX', 'CSCO', 'KO', 'DIS', 'XOM', 'GE', 'GS', 'HD', 'IBM', 'INTC', 'JNJ',
30 | 'JPM', 'MCD', 'MRK', 'MSFT', 'NKE', 'PFE', 'PG', 'UTX', 'UNH', 'VZ', 'WMT']
31 |
32 | DJI_N = ['3M', 'American Express', 'Apple', 'Boeing', 'Caterpillar', 'Chevron', 'Cisco Systems', 'Coca-Cola', 'Disney'
33 | , 'ExxonMobil', 'General Electric', 'Goldman Sachs', 'Home Depot', 'IBM', 'Intel', 'Johnson & Johnson',
34 | 'JPMorgan Chase', 'McDonalds', 'Merck', 'Microsoft', 'NIKE', 'Pfizer', 'Procter & Gamble',
35 | 'United Technologies', 'UnitedHealth Group', 'Verizon Communications', 'Wal Mart']
36 |
37 | # Market and macroeconomic data to be used as context data
38 | CONTEXT_DATA = ['^GSPC', '^DJI', '^IXIC', '^RUT', 'SPY', 'QQQ', '^VIX', 'GLD', '^TYX', '^TNX', 'SHY', 'SHV']
39 |
40 | # --------------------------------- CLASSES ------------------------------------
41 | class Trading:
42 | def __init__(self, recovered_data_lstm, portfolio_stock_price, portfolio_stock_volume, test_set, non_corr_stocks):
43 | self.test_set = test_set
44 | self.ncs = non_corr_stocks
45 | self.stock_price = portfolio_stock_price
46 | self.stock_volume = portfolio_stock_volume
47 | self.generate_signals(recovered_data_lstm)
48 |
49 | def generate_signals(self, predicted_tomorrow_close):
50 | """
51 | Generate trade signals from the predictions of the LSTM model.
52 | :param predicted_tomorrow_close: DataFrame of predicted next-day closing prices, one column per portfolio stock
53 | :return: None; the generated signals are stored in self.signals
54 | """
55 |
56 | predicted_tomorrow_close.columns = self.stock_price.columns
57 | predicted_next_day_returns = (predicted_tomorrow_close / predicted_tomorrow_close.shift(1) - 1).dropna()
58 | next_day_returns = (self.stock_price / self.stock_price.shift(1) - 1).dropna()
59 | signals = pd.DataFrame(index=predicted_tomorrow_close.index, columns=self.stock_price.columns)
60 |
61 | for s in self.stock_price.columns:
62 | for d in next_day_returns.index:
63 | if predicted_tomorrow_close[s].loc[d] > self.stock_price[s].loc[d] and next_day_returns[s].loc[
64 | d] > 0 and predicted_next_day_returns[s].loc[d] > 0:
65 | signals[s].loc[d] = 2
66 | elif predicted_tomorrow_close[s].loc[d] < self.stock_price[s].loc[d] and next_day_returns[s].loc[
67 | d] < 0 and predicted_next_day_returns[s].loc[d] < 0:
68 | signals[s].loc[d] = -2
69 | elif predicted_tomorrow_close[s].loc[d] > self.stock_price[s].loc[d]:
70 | signals[s].loc[d] = 2
71 | elif next_day_returns[s].loc[d] > 0:
72 | signals[s].loc[d] = 1
73 | elif next_day_returns[s].loc[d] < 0:
74 | signals[s].loc[d] = -1
75 | elif predicted_next_day_returns[s].loc[d] > 0:
76 | signals[s].loc[d] = 2
77 | elif predicted_next_day_returns[s].loc[d] < 0:
78 | signals[s].loc[d] = -1
79 | else:
80 | signals[s].loc[d] = 0
81 | signals.loc[self.stock_price.index[0]] = [0, 0, 0, 0, 0]
82 | self.signals = signals
83 |
84 | def _sell(self, stock, sig, day):
85 | """
86 | Perform and record sell transactions.
87 | """
88 |
89 | # Get the index of the stock
90 | idx = self.ncs.index(stock)
91 |
92 | # Only sell the units recommended by the trading agent, not necessarily all units held. Holdings start at state index 2 (index 0 is cash, index 1 is unrealized P&L).
93 | num_share = min(abs(int(sig)), self.state[idx + 2])
94 | commission = dp.Trading.commission(num_share, self.stock_price.loc[day][stock])
95 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage
96 | transacted_price = dp.Trading.slippage_price(self.stock_price.loc[day][stock], -num_share,
97 | self.stock_volume.loc[day][stock])
98 |
99 | # If there is an existing stock holding
100 | if self.state[idx + 2] > 0:
101 | # Only sell the units recommended by the trading agent, not necessarily all units held.
102 | # Update account balance after transaction
103 | self.state[0] += (transacted_price * num_share) - commission
104 | # Update stock holding
105 | self.state[idx + 2] -= num_share
106 | # Reset transacted buy price record to 0.0 if there is no more stock holding
107 | if self.state[idx + 2] == 0.0:
108 | self.buy_price[idx] = 0.0
109 |
110 | else:
111 | pass
112 |
113 | def _buy(self, stock, sig, day):
114 | """
115 | Perform and record buy transactions.
116 | """
117 |
118 | idx = self.ncs.index(stock)
119 | # Calculate the maximum possible number of stock units the current cash can buy
120 | available_unit = self.state[0] // self.stock_price.loc[day][stock]
121 | num_share = min(available_unit, int(sig))
122 | # Deduct the traded amount from the account balance. If the available balance is not enough to purchase the stock units
123 | # recommended by the trading agent's action, just use what is left.
124 | commission = dp.Trading.commission(num_share, self.stock_price.loc[day][stock])
125 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage
126 | transacted_price = dp.Trading.slippage_price(self.stock_price.loc[day][stock], num_share,
127 | self.stock_volume.loc[day][stock])
128 | # Revise the number of shares to trade if the account balance does not have enough
129 | if (self.state[0] - commission) < transacted_price * num_share:
130 | num_share = (self.state[0] - commission) // transacted_price
131 | self.state[0] -= (transacted_price * num_share) + commission
132 |
133 | # If there is an existing stock holding already, calculate the average buy price
134 | if self.state[idx + 2] > 0.0:
135 | existing_unit = self.state[idx + 2]
136 | previous_buy_price = self.buy_price[idx]
137 | additional_unit = num_share
138 | new_holding = existing_unit + additional_unit
139 | self.buy_price[idx] = ((existing_unit * previous_buy_price) + (
140 | self.stock_price.loc[day][stock] * additional_unit)) / new_holding
141 | # if there is no existing stock holding, simply record the current buy price
142 | elif self.state[idx + 2] == 0.0:
143 | self.buy_price[idx] = self.stock_price.loc[day][stock]
144 |
145 | # Update stock holding at its index
146 | self.state[idx + 2] += num_share
147 |
148 | def execute_trading(self, non_corr_stocks):
149 | """
150 | This function performs long-only trades for the LSTM model.
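For illustration, the scheme produced by generate_signals above can be summarised as: +/-2 when the predicted close and the realized and predicted returns all agree (strong buy/sell), +/-1 on weaker single-indicator evidence, and 0 otherwise. A self-contained toy version of the same branch order (toy_signal is a hypothetical helper, not part of this module):

def toy_signal(pred_close, close, realized_ret, pred_ret):
    # Mirrors the branch order used in generate_signals above.
    if pred_close > close and realized_ret > 0 and pred_ret > 0:
        return 2
    if pred_close < close and realized_ret < 0 and pred_ret < 0:
        return -2
    if pred_close > close:
        return 2
    if realized_ret > 0:
        return 1
    if realized_ret < 0:
        return -1
    if pred_ret > 0:
        return 2
    if pred_ret < 0:
        return -1
    return 0

assert toy_signal(103.0, 100.0, 0.004, 0.008) == 2    # all indicators agree -> strong buy
assert toy_signal(98.0, 100.0, -0.004, -0.008) == -2  # all indicators agree -> strong sell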
151 | """ 152 | 153 | # The money in the trading account 154 | self.acc_balance = [STARTING_ACC_BALANCE] 155 | self.total_asset = self.acc_balance 156 | self.portfolio_asset = [0.0] 157 | self.buy_price = np.zeros((1, len(non_corr_stocks))).flatten() 158 | # Unrealized profit and loss 159 | self.unrealized_pnl = [0.0] 160 | # The value of all-stock holdings 161 | self.portfolio_value = 0.0 162 | 163 | # The state of the trading environment, defined by account balance, unrealized profit and loss, relevant 164 | # stock technical data & current stock holdings 165 | self.state = self.acc_balance + self.unrealized_pnl + [0 for i in range(len(non_corr_stocks))] 166 | 167 | # Slide through the timeline 168 | for d in self.test_set.index[:-1]: 169 | 170 | signals = self.signals.loc[d] 171 | 172 | # Get the stocks to be sold 173 | sell_stocks = signals[signals < 0].sort_values(ascending=True) 174 | # Get the stocks to be bought 175 | buy_stocks = signals[signals > 0].sort_values(ascending=True) 176 | 177 | for idx, sig in enumerate(sell_stocks): 178 | self._sell(sell_stocks.index[idx], sig, d) 179 | 180 | for idx, sig in enumerate(buy_stocks): 181 | self._buy(buy_stocks.index[idx], sig, d) 182 | 183 | self.unrealized_pnl = np.sum(np.array(self.stock_price.loc[d] - self.buy_price) * np.array( 184 | self.state[2:])) 185 | 186 | # Current state space 187 | self.state = [self.state[0]] + [self.unrealized_pnl] + list(self.state[2:]) 188 | # Portfolio value is the current stock prices multiply with their respective holdings 189 | portfolio_value = sum(np.array(self.stock_price.loc[d]) * np.array(self.state[2:])) 190 | # Total asset = account balance + portfolio value 191 | total_asset_ending = self.state[0] + portfolio_value 192 | 193 | # Update account balance statement 194 | self.acc_balance = np.append(self.acc_balance, self.state[0]) 195 | 196 | # Update portfolio value statement 197 | self.portfolio_asset = np.append(self.portfolio_asset, portfolio_value) 198 | 199 | # Update total asset statement 200 | self.total_asset = np.append(self.total_asset, total_asset_ending) 201 | 202 | trading_book = pd.DataFrame(index=self.test_set.index, 203 | columns=["Cash balance", "Portfolio value", "Total asset", "Returns", "CumReturns"]) 204 | trading_book["Cash balance"] = self.acc_balance 205 | trading_book["Portfolio value"] = self.portfolio_asset 206 | trading_book["Total asset"] = self.total_asset 207 | trading_book["Returns"] = trading_book["Total asset"] / trading_book["Total asset"].shift(1) - 1 208 | trading_book["CumReturns"] = trading_book["Returns"].add(1).cumprod().fillna(1) 209 | trading_book.to_csv('./test_result/trading_book_backtest.csv') 210 | 211 | kpi = dp.MathCalc.calc_kpi(trading_book) 212 | kpi.to_csv('./test_result/kpi_backtest.csv') 213 | 214 | print("\n") 215 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") 216 | print( 217 | " KPI of RNN-LSTM modelled trading strategy for a portfolio of {} non-correlated stocks".format( 218 | NUMBER_NON_CORR_STOCKS)) 219 | print(kpi) 220 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv") 221 | 222 | return trading_book, kpi 223 | 224 | 225 | class Data_ScaleSplit: 226 | """ 227 | This class preprosses data for the LSTM model. 
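For example, the labels are next-day closes produced by shift(-1); a quick self-contained check of that convention (the values are illustrative):

import pandas as pd

close = pd.Series([100.0, 101.0, 99.0], name='AAPL')
label = close.shift(-1)  # label[t] is the close at t+1; the last row becomes NaN and is dropped later
print(label.tolist())    # [101.0, 99.0, nan]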
228 | """ 229 | 230 | def __init__(self, X, selected_stocks_price, train_portion): 231 | self.X = X 232 | self.stock_price = selected_stocks_price 233 | self.generate_labels() 234 | self.scale_data() 235 | self.split_data(train_portion) 236 | 237 | 238 | def generate_labels(self): 239 | """ 240 | Generate label data for tomorrow's prediction. 241 | """ 242 | self.Y = self.stock_price.shift(-1) 243 | self.Y.columns = [c + '_Y' for c in self.Y.columns] 244 | 245 | def scale_data(self): 246 | """ 247 | Scale the X and Y data with minimax scaller. 248 | The scaling is done separately for the train and test set to avoid look ahead bias. 249 | """ 250 | self.XY = pd.concat([self.X, self.Y], axis=1).dropna() 251 | train_set = self.XY.loc[START_TRAIN:END_TRAIN] 252 | test_set = self.XY.loc[START_TEST:END_TEST] 253 | # MinMax scaling 254 | minmaxed_scaler = MinMaxScaler(feature_range=(0, 1)) 255 | self.minmaxed = minmaxed_scaler.fit(train_set) 256 | train_set_matrix = minmaxed_scaler.transform(train_set) 257 | test_set_matrix = minmaxed_scaler.transform(test_set) 258 | self.train_set_matrix_df = pd.DataFrame(train_set_matrix, index=train_set.index, columns=train_set.columns) 259 | self.test_set_matrix_df = pd.DataFrame(test_set_matrix, index=test_set.index, columns=test_set.columns) 260 | self.XY = pd.concat([self.train_set_matrix_df, self.test_set_matrix_df], axis=0) 261 | 262 | # print ("Train set shape: ", train_set_matrix.shape) 263 | # print ("Test set shape: ", test_set_matrix.shape) 264 | 265 | def split_data(self, train_portion): 266 | """ 267 | Perform train test split with cut off date defined. 268 | """ 269 | df_values = self.XY.values 270 | # split into train and test sets 271 | 272 | train = df_values[:int(train_portion), :] 273 | test = df_values[int(train_portion):, :] 274 | # split into input and outputs 275 | train_X, self.train_y = train[:, :-5], train[:, -5:] 276 | test_X, self.test_y = test[:, :-5], test[:, -5:] 277 | # reshape input to be 3D [samples, timesteps, features] 278 | self.train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1])) 279 | self.test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1])) 280 | print("\n") 281 | print("Dataset shapes >") 282 | print("Train feature data shape:", self.train_X.shape) 283 | print("Train label data shape:", self.train_y.shape) 284 | print("Test feature data shape:", self.test_X.shape) 285 | print("Test label data shape:", self.test_y.shape) 286 | 287 | def get_prediction(self, model_lstm): 288 | """ 289 | Get the model prediction, inverse transform scaling to get back to original price and 290 | reassemble the full XY dataframe. 
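The recovery step below works because MinMaxScaler inverts column-wise: splicing the model's scaled outputs into the last five columns and inverse-transforming the whole matrix leaves the feature columns untouched while decoding the predictions back into price units. A minimal sketch of the same pattern with stand-in shapes (not this project's data):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

xy = np.random.RandomState(1).rand(10, 8)          # say, 3 feature columns + 5 label columns
scaler = MinMaxScaler().fit(xy)
xy_scaled = scaler.transform(xy)

pred_scaled = np.clip(xy_scaled[:, -5:] + 0.01, 0, 1)  # pretend model output, scaled
spliced = xy_scaled.copy()
spliced[:, -5:] = pred_scaled                      # replace only the Y columns
recovered = scaler.inverse_transform(spliced)      # last 5 columns are now in price units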
291 | """ 292 | # Get the model to predict test_y 293 | 294 | predicted_y_lstm = model_lstm.predict(self.test_X, batch_size=None, verbose=0, steps=None) 295 | # Get the model to generate train_y 296 | trained_y_lstm = model_lstm.predict(self.train_X, batch_size=None, verbose=0, steps=None) 297 | 298 | # combine the model generated train_y and test_y to create the full_y 299 | y_lstm = pd.DataFrame(data=np.vstack((trained_y_lstm, predicted_y_lstm)), 300 | columns=[c + '_LSTM' for c in self.XY.columns[-5:]], index=self.XY.index) 301 | 302 | # Combine the original full length y with model generated y 303 | lstm_y_df = pd.concat([self.XY[self.XY.columns[-5:]], y_lstm], axis=1) 304 | # Get the full length XY data with the length of model generated y 305 | lstm_df = self.XY.loc[lstm_y_df.index] 306 | # Replace the full length XY data's Y with the model generated Y 307 | lstm_df[lstm_df.columns[-5:]] = lstm_y_df[lstm_y_df.columns[-5:]] 308 | # Inverse transform it to get back the original data, the model generated y would be transformed to reveal its true predicted value 309 | recovered_data_lstm = self.minmaxed.inverse_transform(lstm_df) 310 | # Create a dataframe from it 311 | self.recovered_data_lstm = pd.DataFrame(data=recovered_data_lstm, columns=self.XY.columns, index=lstm_df.index) 312 | return self.recovered_data_lstm 313 | 314 | def get_train_test_set(self): 315 | """ 316 | Get the split X and y data. 317 | """ 318 | return self.train_X, self.train_y, self.test_X, self.test_y 319 | 320 | def get_all_data(self): 321 | """ 322 | Get the full XY data and the original stock price. 323 | """ 324 | return self.XY, self.stock_price 325 | 326 | class Model: 327 | """ 328 | This class contains all the functions required to build a LSTM or LSTM-CNN model 329 | It also offer an option to load a pre-built model. 330 | """ 331 | @staticmethod 332 | def train_model(model, train_X, train_y, model_type): 333 | """ 334 | Try to load a pre-built model. 335 | Otherwise fit a new mode with the training data. Once training is done, save the model. 
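Note that the data feeding this step was scaled with statistics fitted on the training window only (see Data_ScaleSplit.scale_data above); fitting the scaler on the full history would leak future price extremes into training. The convention in miniature, with synthetic data:

import numpy as np
from sklearn.preprocessing import MinMaxScaler

rng = np.random.RandomState(0)
prices = rng.uniform(50, 150, size=(100, 3))
train, test = prices[:80], prices[80:]

scaler = MinMaxScaler(feature_range=(0, 1)).fit(train)  # train-only statistics
train_scaled = scaler.transform(train)
test_scaled = scaler.transform(test)  # may fall outside [0, 1]; that is expected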
336 | """ 337 | es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=PATIENCE) 338 | if model_type == "LSTM": 339 | batch_size = 4 340 | mc = ModelCheckpoint('./model/best_lstm_model.h5', monitor='val_loss', save_weights_only=False, 341 | mode='min', verbose=1, save_best_only=True) 342 | try: 343 | model = load_model('./model/best_lstm_model.h5') 344 | print("\n") 345 | print("Loading pre-saved model ...") 346 | except: 347 | print("\n") 348 | print("No pre-saved model, training new model.") 349 | pass 350 | elif model_type == "CNN": 351 | batch_size = 8 352 | mc = ModelCheckpoint('./model/best_cnn_model.h5'.format(symbol), monitor='val_loss', save_weights_only=False, 353 | mode='min', verbose=1, save_best_only=True) 354 | try: 355 | model = load_model('./model/best_cnn_model.h5') 356 | print("\n") 357 | print("Loading pre-saved model ...") 358 | except: 359 | print("\n") 360 | print("No pre-saved model, training new model.") 361 | pass 362 | # fit network 363 | history = model.fit( 364 | train_X, 365 | train_y, 366 | epochs=500, 367 | batch_size=batch_size, 368 | validation_split=0.2, 369 | verbose=2, 370 | shuffle=True, 371 | # callbacks=[es, mc, tb, LearningRateTracker()]) 372 | callbacks=[es, mc]) 373 | 374 | if model_type == "LSTM": 375 | model.save('./model/best_lstm_model.h5') 376 | elif model_type == "CNN": 377 | model.save('./model/best_cnn_model.h5') 378 | 379 | return history, model 380 | 381 | @staticmethod 382 | def plot_training(history,nn): 383 | """ 384 | Plot the historical training loss. 385 | """ 386 | # plot history 387 | plt.plot(history.history['loss'], label='train') 388 | plt.plot(history.history['val_loss'], label='test') 389 | plt.legend() 390 | plt.title('Training loss history for {} model'.format(nn)) 391 | plt.savefig('./train_result/training_loss_history_{}.png'.format(nn)) 392 | plt.show() 393 | 394 | @staticmethod 395 | def build_rnn_model(train_X): 396 | """ 397 | Build the RNN model architecture. 
398 | """
399 | # design network
400 | print("\n")
401 | print("RNN LSTM model architecture >")
402 | model = Sequential()
403 | model.add(LSTM(128, kernel_initializer='random_uniform',
404 | bias_initializer='zeros', return_sequences=True,
405 | recurrent_dropout=0.2,
406 | input_shape=(train_X.shape[1], train_X.shape[2])))
407 | model.add(Dropout(0.5))
408 | model.add(LSTM(64, kernel_initializer='random_uniform',
409 | return_sequences=True,
410 | # bias_regularizer=regularizers.l2(0.01),
411 | # kernel_regularizer=regularizers.l1_l2(l1=0.01,l2=0.01),
412 | # activity_regularizer=regularizers.l2(0.01),
413 | bias_initializer='zeros'))
414 | model.add(Dropout(0.5))
415 | model.add(LSTM(64, kernel_initializer='random_uniform',
416 | # bias_regularizer=regularizers.l2(0.01),
417 | # kernel_regularizer=regularizers.l1_l2(l1=0.01,l2=0.01),
418 | # activity_regularizer=regularizers.l2(0.01),
419 | bias_initializer='zeros'))
420 | model.add(Dropout(0.5))
421 | model.add(Dense(5))
422 | # optimizer = keras.optimizers.RMSprop(lr=0.25, rho=0.9, epsilon=1e-0)
423 | # optimizer = keras.optimizers.Adagrad(lr=0.0001, epsilon=1e-08, decay=0.00002)
424 | # optimizer = keras.optimizers.Adam(lr=0.0001)
425 | # optimizer = keras.optimizers.Nadam(lr=0.0002, beta_1=0.9, beta_2=0.999, schedule_decay=0.004)
426 | # optimizer = keras.optimizers.Adamax(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0)
427 | optimizer = keras.optimizers.Adadelta(lr=0.2, rho=0.95, epsilon=None, decay=0.00001)
428 |
429 | model.compile(loss='mae', optimizer=optimizer, metrics=['mse', 'mae'])
430 | model.summary()
431 | print("\n")
432 | return model
433 |
434 |
435 |
436 | # ------------------------------ Main Program ---------------------------------
437 |
438 | def main():
439 | print("\n")
440 | print("######################### This program compares the performance of trading strategies ############################")
441 | print("\n")
442 | print( "1. Simple Buy and hold strategy of a portfolio with {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
443 | print( "2. Sharpe ratio optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
444 | print( "3. Minimum variance optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS))
445 | print( "4. Trading strategy driven by RNN-LSTM price prediction ")
446 | print( "5. Trading strategy learned by StarTrader's DDPG agent ")
447 |
448 | print("\n")
449 |
450 | print("Starting to pre-process data for trading environment construction ... 
") 451 | # Data Preprocessing 452 | dataset = dp.DataRetrieval() 453 | dow_stocks_train, dow_stocks_test = dataset.get_all() 454 | train_portion = len(dow_stocks_train) 455 | dow_stock_volume = dataset.components_df_v[DJI] 456 | portfolios = dp.Trading(dow_stocks_train, dow_stocks_test, dow_stock_volume.loc[START_TEST:END_TEST]) 457 | _, _, non_corr_stocks = portfolios.find_non_correlate_stocks(NUMBER_NON_CORR_STOCKS) 458 | non_corr_stocks_data = dataset.get_adj_close(non_corr_stocks) 459 | print("\n") 460 | print("Base on non-correlation preference, {} stocks are selected for portfolio construction:".format(NUMBER_NON_CORR_STOCKS)) 461 | 462 | for stock in non_corr_stocks: 463 | print(DJI_N[DJI.index(stock)]) 464 | print("\n") 465 | 466 | sharpe_portfolio, min_variance_portfolio = portfolios.find_efficient_frontier(non_corr_stocks_data, non_corr_stocks) 467 | print("Risk-averse portfolio with low variance:") 468 | print(min_variance_portfolio.T) 469 | print("High return portfolio with high Sharpe ratio") 470 | print(sharpe_portfolio.T) 471 | dow_stocks = pd.concat([dow_stocks_train, dow_stocks_test], axis=0) 472 | 473 | test_values_buyhold, test_returns_buyhold, test_kpi_buyhold = \ 474 | portfolios.diversified_trade(non_corr_stocks, dow_stocks.loc[START_TEST:END_TEST][non_corr_stocks]) 475 | print("\n") 476 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") 477 | print(" KPI of a simple buy and hold strategy for a portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS)) 478 | print("------------------------------------------------------------------------------------") 479 | print(test_kpi_buyhold) 480 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv") 481 | 482 | 483 | test_values_sharpe_optimized_buyhold, test_returns_sharpe_optimized_buyhold, test_kpi_sharpe_optimized_buyhold =\ 484 | portfolios.optimized_diversified_trade(non_corr_stocks, sharpe_portfolio, dow_stocks.loc[START_TEST:END_TEST][non_corr_stocks]) 485 | print("\n") 486 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") 487 | print(" KPI of a simple buy and hold strategy for a Sharpe ratio optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS)) 488 | print("------------------------------------------------------------------------------------") 489 | print(test_kpi_sharpe_optimized_buyhold) 490 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv") 491 | 492 | test_values_minvar_optimized_buyhold, test_returns_minvar_optimized_buyhold, test_kpi_minvar_optimized_buyhold = \ 493 | portfolios.optimized_diversified_trade(non_corr_stocks, min_variance_portfolio, dow_stocks.loc[START_TEST:END_TEST][non_corr_stocks]) 494 | print("\n") 495 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^") 496 | print(" KPI of a simple buy and hold strategy for a Minimum variance optimized portfolio of {} non-correlated stocks".format(NUMBER_NON_CORR_STOCKS)) 497 | print("------------------------------------------------------------------------------------") 498 | print(test_kpi_minvar_optimized_buyhold) 499 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv") 500 | 501 | plot = dp.UserDisplay() 502 | test_returns = dp.MathCalc.assemble_returns(test_returns_buyhold['Returns'], 503 | test_returns_sharpe_optimized_buyhold['Returns'], 504 | 
test_returns_minvar_optimized_buyhold['Returns'])
505 | test_cum_returns = dp.MathCalc.assemble_cum_returns(test_returns_buyhold['CumReturns'],
506 | test_returns_sharpe_optimized_buyhold['CumReturns'],
507 | test_returns_minvar_optimized_buyhold['CumReturns'])
508 |
509 | print("\n")
510 | print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
511 | print("Buy and hold strategies computation completed. Now creating prediction model using RNN LSTM architecture")
512 | print("--------------------------------------------------------------------------------------------------------")
513 | print("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
514 |
515 |
516 | # Use feature data preprocessed by StarTrader, so that both models use the same training data for a fair comparison
517 | input_states = pd.read_csv("./data/ddpg_input_states.csv", index_col='Date', parse_dates=True)
518 | scale_split = Data_ScaleSplit(input_states, dow_stocks[non_corr_stocks], train_portion)
519 | train_X, train_y, test_X, test_y = scale_split.get_train_test_set()
520 |
521 | modelling = Model
522 | model_lstm = modelling.build_rnn_model(train_X)
523 | history_lstm, model_lstm = modelling.train_model(model_lstm, train_X, train_y, "LSTM")
524 | print("RNN model loaded, now training the model again; training will stop after {} epochs with no improvement".format(PATIENCE))
525 | modelling.plot_training(history_lstm, "LSTM")
526 | print("Training completed, generating predictions using the trained RNN model >")
527 | recovered_data_lstm = scale_split.get_prediction(model_lstm)
528 | plot.plot_prediction(dow_stocks[non_corr_stocks].loc[recovered_data_lstm.index], recovered_data_lstm[recovered_data_lstm.columns[-5:]], len(train_X), "LSTM")
529 |
530 | # Get the original stock price with the prediction length
531 | original_portfolio_stock_price = dow_stocks[non_corr_stocks].loc[recovered_data_lstm.index]
532 | # Get the predicted stock price with the prediction length
533 | predicted_portfolio_stock_price = recovered_data_lstm[recovered_data_lstm.columns[-5:]]
534 | print("Backtesting the RNN-LSTM model now")
535 | # Run backtest; the backtester is similar to the one used by StarTrader
536 | backtest = Trading(predicted_portfolio_stock_price, original_portfolio_stock_price, dow_stock_volume[non_corr_stocks].loc[recovered_data_lstm.index], dow_stocks_test[non_corr_stocks], non_corr_stocks)
537 | trading_book, kpi = backtest.execute_trading(non_corr_stocks)
538 | # Load backtest result for StarTrader using DDPG as learning algorithm
539 | ddpg_backtest = pd.read_csv('./test_result/trading_book_test_1.csv', index_col='Unnamed: 0', parse_dates=True)
540 | print("Backtesting completed, plotting comparison of trading models")
541 | # Compare performance of all trading strategies
542 | djia_daily = dataset._get_daily_data(CONTEXT_DATA[1]).loc[START_TEST:END_TEST]['Close']
543 | #print(djia_daily)
544 | all_benchmark_returns = test_returns
545 | all_benchmark_returns['DJIA'] = dp.MathCalc.calc_return(djia_daily)
546 | all_benchmark_returns['RNN LSTM'] = trading_book['Returns']
547 | all_benchmark_returns['DDPG'] = ddpg_backtest['Returns']
548 | all_benchmark_returns.to_csv('./test_result/all_strategies_returns.csv')
549 | plot.plot_portfolio_risk(all_benchmark_returns)
550 |
551 | all_benchmark_cum_returns = test_cum_returns
552 | all_benchmark_cum_returns['DJIA'] = all_benchmark_returns['DJIA'].add(1).cumprod().fillna(1)
553 | 
all_benchmark_cum_returns['RNN LSTM'] = trading_book['CumReturns'] 554 | all_benchmark_cum_returns['DDPG'] = ddpg_backtest['CumReturns'] 555 | all_benchmark_cum_returns.to_csv('./test_result/all_strategies_cum_returns.csv') 556 | plot.plot_portfolio_return(all_benchmark_cum_returns) 557 | 558 | 559 | if __name__ == '__main__': 560 | main() -------------------------------------------------------------------------------- /feature_select.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------- IMPORT LIBRARIES ------------------------------------------- 2 | import pandas as pd 3 | import numpy as np 4 | import lightgbm as lgb 5 | import gc 6 | from itertools import chain 7 | from sklearn.cluster import KMeans 8 | from mpl_toolkits.mplot3d import Axes3D 9 | import matplotlib.pyplot as plt 10 | from mpl_finance import candlestick_ohlc 11 | import copy 12 | from matplotlib.dates import (DateFormatter, WeekdayLocator, DayLocator, MONDAY) 13 | from sklearn.model_selection import train_test_split 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | # ------------------------------------------------ CLASSES -------------------------------------------- 18 | 19 | class FeatureSelector(): 20 | """ 21 | Courtesy of William Koehrsen from Feature Labs 22 | Class for performing feature selection for machine learning or data preprocessing. 23 | 24 | Implements five different methods to identify features for removal 25 | 26 | 1. Find columns with a missing percentage greater than a specified threshold 27 | 2. Find columns with a single unique value 28 | 3. Find collinear variables with a correlation greater than a specified correlation coefficient 29 | 4. Find features with 0.0 feature importance from a gradient boosting machine (gbm) 30 | 5. Find low importance features that do not contribute to a specified cumulative feature importance from the gbm 31 | 32 | Parameters 33 | -------- 34 | data : dataframe 35 | A dataset with observations in the rows and features in the columns 36 | 37 | labels : array or series, default = None 38 | Array of labels for training the machine learning model to find feature importances. These can be either binary labels 39 | (if task is 'classification') or continuous targets (if task is 'regression'). 40 | If no labels are provided, then the feature importance based methods are not available. 
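Example (typical end-to-end use of the identification methods in this class; the DataFrame, labels and thresholds below are placeholders, not values used elsewhere in this project):

import numpy as np
import pandas as pd

X = pd.DataFrame(np.random.rand(200, 6), columns=list('abcdef'))
y = pd.Series(X['a'] * 2.0 + np.random.rand(200))  # synthetic regression target

fs = FeatureSelector(data=X, labels=y)
fs.identify_missing(missing_threshold=0.6)
fs.identify_single_unique()
fs.identify_collinear(correlation_threshold=0.98)
print(fs.ops)  # dict of {method: [features flagged for removal]}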
41 | 42 | Attributes 43 | -------- 44 | 45 | ops : dict 46 | Dictionary of operations run and features identified for removal 47 | 48 | missing_stats : dataframe 49 | The fraction of missing values for all features 50 | 51 | record_missing : dataframe 52 | The fraction of missing values for features with missing fraction above threshold 53 | 54 | unique_stats : dataframe 55 | Number of unique values for all features 56 | 57 | record_single_unique : dataframe 58 | Records the features that have a single unique value 59 | 60 | corr_matrix : dataframe 61 | All correlations between all features in the data 62 | 63 | record_collinear : dataframe 64 | Records the pairs of collinear variables with a correlation coefficient above the threshold 65 | 66 | feature_importances : dataframe 67 | All feature importances from the gradient boosting machine 68 | 69 | record_zero_importance : dataframe 70 | Records the zero importance features in the data according to the gbm 71 | 72 | record_low_importance : dataframe 73 | Records the lowest importance features not needed to reach the threshold of cumulative importance according to the gbm 74 | 75 | 76 | Notes 77 | -------- 78 | 79 | - All 5 operations can be run with the `identify_all` method. 80 | - If using feature importances, one-hot encoding is used for categorical variables which creates new columns 81 | 82 | """ 83 | 84 | def __init__(self, data, labels=None): 85 | 86 | # Dataset and optional training labels 87 | self.data = data 88 | self.labels = labels 89 | 90 | if labels is None: 91 | print('No labels provided. Feature importance based methods are not available.') 92 | 93 | self.base_features = list(data.columns) 94 | self.one_hot_features = None 95 | 96 | # Dataframes recording information about features to remove 97 | self.record_missing = None 98 | self.record_single_unique = None 99 | self.record_collinear = None 100 | self.record_zero_importance = None 101 | self.record_low_importance = None 102 | 103 | self.missing_stats = None 104 | self.unique_stats = None 105 | self.corr_matrix = None 106 | self.feature_importances = None 107 | 108 | # Dictionary to hold removal operations 109 | self.ops = {} 110 | 111 | self.one_hot_correlated = False 112 | 113 | def identify_missing(self, missing_threshold): 114 | """Find the features with a fraction of missing values above `missing_threshold`""" 115 | 116 | self.missing_threshold = missing_threshold 117 | 118 | # Calculate the fraction of missing in each column 119 | missing_series = self.data.isnull().sum() / self.data.shape[0] 120 | self.missing_stats = pd.DataFrame(missing_series).rename(columns={'index': 'feature', 0: 'missing_fraction'}) 121 | 122 | # Sort with highest number of missing values on top 123 | self.missing_stats = self.missing_stats.sort_values('missing_fraction', ascending=False) 124 | 125 | # Find the columns with a missing percentage above the threshold 126 | record_missing = pd.DataFrame(missing_series[missing_series > missing_threshold]).reset_index().rename(columns= 127 | {'index': 'feature', 0: 'missing_fraction'}) 128 | 129 | to_drop = list(record_missing['feature']) 130 | 131 | self.record_missing = record_missing 132 | self.ops['missing'] = to_drop 133 | 134 | print('%d features with greater than %0.2f missing values.\n' % ( 135 | len(self.ops['missing']), self.missing_threshold)) 136 | 137 | def identify_single_unique(self): 138 | """Finds features with only a single unique value. NaNs do not count as a unique value. 
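For example, pandas' nunique() drops NaNs by default, which is what makes the statement above hold:

import numpy as np
import pandas as pd

s = pd.Series([7.0, 7.0, np.nan])
print(s.nunique())              # 1 -> flagged as single-unique despite the NaN
print(s.nunique(dropna=False))  # 2 -> NaN would count only with dropna=False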
""" 139 | 140 | # Calculate the unique counts in each column 141 | unique_counts = self.data.nunique() 142 | self.unique_stats = pd.DataFrame(unique_counts).rename(columns={'index': 'feature', 0: 'nunique'}) 143 | self.unique_stats = self.unique_stats.sort_values('nunique', ascending=True) 144 | 145 | # Find the columns with only one unique count 146 | record_single_unique = pd.DataFrame(unique_counts[unique_counts == 1]).reset_index().rename( 147 | columns={'index': 'feature', 148 | 0: 'nunique'}) 149 | 150 | to_drop = list(record_single_unique['feature']) 151 | 152 | self.record_single_unique = record_single_unique 153 | self.ops['single_unique'] = to_drop 154 | 155 | print('%d features with a single unique value.\n' % len(self.ops['single_unique'])) 156 | 157 | def identify_collinear(self, correlation_threshold, one_hot=False): 158 | """ 159 | Finds collinear features based on the correlation coefficient between features. 160 | For each pair of features with a correlation coefficient greather than `correlation_threshold`, 161 | only one of the pair is identified for removal. 162 | 163 | Using code adapted from: https://chrisalbon.com/machine_learning/feature_selection/drop_highly_correlated_features/ 164 | 165 | Parameters 166 | -------- 167 | 168 | correlation_threshold : float between 0 and 1 169 | Value of the Pearson correlation cofficient for identifying correlation features 170 | 171 | one_hot : boolean, default = False 172 | Whether to one-hot encode the features before calculating the correlation coefficients 173 | 174 | """ 175 | 176 | self.correlation_threshold = correlation_threshold 177 | self.one_hot_correlated = one_hot 178 | 179 | # Calculate the correlations between every column 180 | if one_hot: 181 | 182 | # One hot encoding 183 | features = pd.get_dummies(self.data) 184 | self.one_hot_features = [column for column in features.columns if column not in self.base_features] 185 | 186 | # Add one hot encoded data to original data 187 | self.data_all = pd.concat([features[self.one_hot_features], self.data], axis=1) 188 | 189 | corr_matrix = pd.get_dummies(features).corr() 190 | 191 | else: 192 | corr_matrix = self.data.corr() 193 | 194 | self.corr_matrix = corr_matrix 195 | 196 | # Extract the upper triangle of the correlation matrix 197 | upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool)) 198 | 199 | # Select the features with correlations above the threshold 200 | # Need to use the absolute value 201 | to_drop = [column for column in upper.columns if any(upper[column].abs() > correlation_threshold)] 202 | 203 | # Dataframe to hold correlated pairs 204 | record_collinear = pd.DataFrame(columns=['drop_feature', 'corr_feature', 'corr_value']) 205 | 206 | # Iterate through the columns to drop to record pairs of correlated features 207 | for column in to_drop: 208 | # Find the correlated features 209 | corr_features = list(upper.index[upper[column].abs() > correlation_threshold]) 210 | 211 | # Find the correlated values 212 | corr_values = list(upper[column][upper[column].abs() > correlation_threshold]) 213 | drop_features = [column for _ in range(len(corr_features))] 214 | 215 | # Record the information (need a temp df for now) 216 | temp_df = pd.DataFrame.from_dict({'drop_feature': drop_features, 217 | 'corr_feature': corr_features, 218 | 'corr_value': corr_values}) 219 | 220 | # Add to dataframe 221 | record_collinear = record_collinear.append(temp_df, ignore_index=True) 222 | 223 | self.record_collinear = record_collinear 224 | 
self.ops['collinear'] = to_drop 225 | 226 | print('%d features with a correlation magnitude greater than %0.2f.\n' % ( 227 | len(self.ops['collinear']), self.correlation_threshold)) 228 | 229 | def identify_zero_importance(self, task, eval_metric=None, 230 | n_iterations=10, early_stopping=True): 231 | """ 232 | 233 | Identify the features with zero importance according to a gradient boosting machine. 234 | The gbm can be trained with early stopping using a validation set to prevent overfitting. 235 | The feature importances are averaged over `n_iterations` to reduce variance. 236 | 237 | Uses the LightGBM implementation (http://lightgbm.readthedocs.io/en/latest/index.html) 238 | 239 | Parameters 240 | -------- 241 | 242 | eval_metric : string 243 | Evaluation metric to use for the gradient boosting machine for early stopping. Must be 244 | provided if `early_stopping` is True 245 | 246 | task : string 247 | The machine learning task, either 'classification' or 'regression' 248 | 249 | n_iterations : int, default = 10 250 | Number of iterations to train the gradient boosting machine 251 | 252 | early_stopping : boolean, default = True 253 | Whether or not to use early stopping with a validation set when training 254 | 255 | 256 | Notes 257 | -------- 258 | 259 | - Features are one-hot encoded to handle the categorical variables before training. 260 | - The gbm is not optimized for any particular task and might need some hyperparameter tuning 261 | - Feature importances, including zero importance features, can change across runs 262 | 263 | """ 264 | 265 | if early_stopping and eval_metric is None: 266 | raise ValueError("""eval metric must be provided with early stopping. Examples include "auc" for classification or 267 | "l2" for regression.""") 268 | 269 | if self.labels is None: 270 | raise ValueError("No training labels provided.") 271 | 272 | # One hot encoding 273 | features = pd.get_dummies(self.data) 274 | self.one_hot_features = [column for column in features.columns if column not in self.base_features] 275 | 276 | # Add one hot encoded data to original data 277 | self.data_all = pd.concat([features[self.one_hot_features], self.data], axis=1) 278 | 279 | # Extract feature names 280 | feature_names = list(features.columns) 281 | 282 | # Convert to np array 283 | features = np.array(features) 284 | labels = np.array(self.labels).reshape((-1,)) 285 | 286 | # Empty array for feature importances 287 | feature_importance_values = np.zeros(len(feature_names)) 288 | 289 | print('Training Gradient Boosting Model\n') 290 | 291 | # Iterate through each fold 292 | for _ in range(n_iterations): 293 | 294 | if task == 'classification': 295 | model = lgb.LGBMClassifier(n_estimators=1000, learning_rate=0.05, verbose=-1) 296 | 297 | elif task == 'regression': 298 | model = lgb.LGBMRegressor(n_estimators=1000, learning_rate=0.05, verbose=-1) 299 | 300 | else: 301 | raise ValueError('Task must be either "classification" or "regression"') 302 | 303 | # If training using early stopping need a validation set 304 | if early_stopping: 305 | 306 | train_features, valid_features, train_labels, valid_labels = train_test_split(features, labels, 307 | test_size=0.15) 308 | # Train the model with early stopping 309 | model.fit(train_features, train_labels, eval_metric=eval_metric, 310 | eval_set=[(valid_features, valid_labels)], 311 | early_stopping_rounds=100, verbose=0) 312 | 313 | # Clean up memory 314 | gc.enable() 315 | del train_features, train_labels, valid_features, valid_labels 316 | gc.collect() 317 
| 318 | else: 319 | model.fit(features, labels) 320 | 321 | # Record the feature importances 322 | feature_importance_values += model.feature_importances_ / n_iterations 323 | 324 | feature_importances = pd.DataFrame({'feature': feature_names, 'importance': feature_importance_values}) 325 | 326 | # Sort features according to importance 327 | feature_importances = feature_importances.sort_values('importance', ascending=False).reset_index(drop=True) 328 | 329 | # Normalize the feature importances to add up to one 330 | feature_importances['normalized_importance'] = feature_importances['importance'] / feature_importances[ 331 | 'importance'].sum() 332 | feature_importances['cumulative_importance'] = np.cumsum(feature_importances['normalized_importance']) 333 | 334 | # Extract the features with zero importance 335 | record_zero_importance = feature_importances[feature_importances['importance'] == 0.0] 336 | 337 | to_drop = list(record_zero_importance['feature']) 338 | 339 | self.feature_importances = feature_importances 340 | self.record_zero_importance = record_zero_importance 341 | self.ops['zero_importance'] = to_drop 342 | 343 | print('\n%d features with zero importance after one-hot encoding.\n' % len(self.ops['zero_importance'])) 344 | 345 | def identify_low_importance(self, cumulative_importance): 346 | """ 347 | Finds the lowest importance features not needed to account for `cumulative_importance` fraction 348 | of the total feature importance from the gradient boosting machine. As an example, if cumulative 349 | importance is set to 0.95, this will retain only the most important features needed to 350 | reach 95% of the total feature importance. The identified features are those not needed. 351 | 352 | Parameters 353 | -------- 354 | cumulative_importance : float between 0 and 1 355 | The fraction of cumulative importance to account for 356 | 357 | """ 358 | 359 | self.cumulative_importance = cumulative_importance 360 | 361 | # The feature importances need to be calculated before running 362 | if self.feature_importances is None: 363 | raise NotImplementedError("""Feature importances have not yet been determined. 364 | Call the `identify_zero_importance` method first.""") 365 | 366 | # Make sure most important features are on top 367 | self.feature_importances = self.feature_importances.sort_values('cumulative_importance') 368 | 369 | # Identify the features not needed to reach the cumulative_importance 370 | record_low_importance = self.feature_importances[ 371 | self.feature_importances['cumulative_importance'] > cumulative_importance] 372 | 373 | to_drop = list(record_low_importance['feature']) 374 | 375 | self.record_low_importance = record_low_importance 376 | self.ops['low_importance'] = to_drop 377 | 378 | print('%d features required for cumulative importance of %0.2f after one hot encoding.' % ( 379 | len(self.feature_importances) - 380 | len(self.record_low_importance), self.cumulative_importance)) 381 | print('%d features do not contribute to cumulative importance of %0.2f.\n' % (len(self.ops['low_importance']), 382 | self.cumulative_importance)) 383 | 384 | def identify_all(self, selection_params): 385 | """ 386 | Use all five of the methods to identify features to remove. 387 | 388 | Parameters 389 | -------- 390 | 391 | selection_params : dict 392 | Parameters to use in the five feature selection methhods. 
393 | Params must contain the keys ['missing_threshold', 'correlation_threshold', 'eval_metric', 'task', 'cumulative_importance'] 394 | 395 | """ 396 | 397 | # Check for all required parameters 398 | for param in ['missing_threshold', 'correlation_threshold', 'eval_metric', 'task', 'cumulative_importance']: 399 | if param not in selection_params.keys(): 400 | raise ValueError('%s is a required parameter for this method.' % param) 401 | 402 | # Implement each of the five methods 403 | self.identify_missing(selection_params['missing_threshold']) 404 | self.identify_single_unique() 405 | self.identify_collinear(selection_params['correlation_threshold']) 406 | self.identify_zero_importance(task=selection_params['task'], eval_metric=selection_params['eval_metric']) 407 | self.identify_low_importance(selection_params['cumulative_importance']) 408 | 409 | # Find the number of features identified to drop 410 | self.all_identified = set(list(chain(*list(self.ops.values())))) 411 | self.n_identified = len(self.all_identified) 412 | 413 | print('%d total features out of %d identified for removal after one-hot encoding.\n' % (self.n_identified, 414 | self.data_all.shape[1])) 415 | 416 | def check_removal(self, keep_one_hot=True): 417 | 418 | """Check the identified features before removal. Returns a list of the unique features identified.""" 419 | 420 | self.all_identified = set(list(chain(*list(self.ops.values())))) 421 | print('Total of %d features identified for removal' % len(self.all_identified)) 422 | 423 | if not keep_one_hot: 424 | if self.one_hot_features is None: 425 | print('Data has not been one-hot encoded') 426 | else: 427 | one_hot_to_remove = [x for x in self.one_hot_features if x not in self.all_identified] 428 | print('%d additional one-hot features can be removed' % len(one_hot_to_remove)) 429 | 430 | return list(self.all_identified) 431 | 432 | def remove(self, methods, keep_one_hot=True): 433 | """ 434 | Remove the features from the data according to the specified methods. 435 | 436 | Parameters 437 | -------- 438 | methods : 'all' or list of methods 439 | If methods == 'all', any methods that have identified features will be used 440 | Otherwise, only the specified methods will be used. 441 | Can be one of ['missing', 'single_unique', 'collinear', 'zero_importance', 'low_importance'] 442 | keep_one_hot : boolean, default = True 443 | Whether or not to keep one-hot encoded features 444 | 445 | Return 446 | -------- 447 | data : dataframe 448 | Dataframe with identified features removed 449 | 450 | 451 | Notes 452 | -------- 453 | - If feature importances are used, the one-hot encoded columns will be added to the data (and then may be removed) 454 | - Check the features that will be removed before transforming data! 
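Reusing a FeatureSelector instance fs, a typical call supplies all five thresholds at once and then drops everything that was flagged; the values below are placeholders, not ones used by this project:

selection_params = {'missing_threshold': 0.6,
                    'correlation_threshold': 0.98,
                    'task': 'regression',
                    'eval_metric': 'l2',
                    'cumulative_importance': 0.99}
fs.identify_all(selection_params)
reduced_data = fs.remove(methods='all', keep_one_hot=False)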
455 | 456 | """ 457 | 458 | features_to_drop = [] 459 | 460 | if methods == 'all': 461 | 462 | # Need to use one-hot encoded data as well 463 | data = self.data_all 464 | 465 | print('{} methods have been run\n'.format(list(self.ops.keys()))) 466 | 467 | # Find the unique features to drop 468 | features_to_drop = set(list(chain(*list(self.ops.values())))) 469 | 470 | else: 471 | # Need to use one-hot encoded data as well 472 | if 'zero_importance' in methods or 'low_importance' in methods or self.one_hot_correlated: 473 | data = self.data_all 474 | 475 | else: 476 | data = self.data 477 | 478 | # Iterate through the specified methods 479 | for method in methods: 480 | 481 | # Check to make sure the method has been run 482 | if method not in self.ops.keys(): 483 | raise NotImplementedError('%s method has not been run' % method) 484 | 485 | # Append the features identified for removal 486 | else: 487 | features_to_drop.append(self.ops[method]) 488 | 489 | # Find the unique features to drop 490 | features_to_drop = set(list(chain(*features_to_drop))) 491 | 492 | features_to_drop = list(features_to_drop) 493 | 494 | if not keep_one_hot: 495 | 496 | if self.one_hot_features is None: 497 | print('Data has not been one-hot encoded') 498 | else: 499 | 500 | features_to_drop = list(set(features_to_drop) | set(self.one_hot_features)) 501 | 502 | # Remove the features and return the data 503 | data = data.drop(features_to_drop, axis=1) 504 | #data = data.drop(columns=features_to_drop) 505 | self.removed_features = features_to_drop 506 | 507 | if not keep_one_hot: 508 | print('Removed %d features including one-hot features.' % len(features_to_drop)) 509 | else: 510 | print('Removed %d features.' % len(features_to_drop)) 511 | 512 | return data 513 | 514 | def plot_missing(self): 515 | """Histogram of missing fraction in each feature""" 516 | if self.record_missing is None: 517 | raise NotImplementedError("Missing values have not been calculated. Run `identify_missing`") 518 | 519 | self.reset_plot() 520 | 521 | # Histogram of missing values 522 | plt.style.use('seaborn-white') 523 | plt.figure(figsize=(7, 5)) 524 | plt.hist(self.missing_stats['missing_fraction'], bins=np.linspace(0, 1, 11), edgecolor='k', color='red', 525 | linewidth=1.5) 526 | plt.xticks(np.linspace(0, 1, 11)); 527 | plt.xlabel('Missing Fraction', size=14); 528 | plt.ylabel('Count of Features', size=14); 529 | plt.title("Fraction of Missing Values Histogram", size=16); 530 | 531 | def plot_unique(self): 532 | """Histogram of number of unique values in each feature""" 533 | if self.record_single_unique is None: 534 | raise NotImplementedError('Unique values have not been calculated. Run `identify_single_unique`') 535 | 536 | self.reset_plot() 537 | 538 | # Histogram of number of unique values 539 | self.unique_stats.plot.hist(edgecolor='k', figsize=(7, 5)) 540 | plt.ylabel('Frequency', size=14); 541 | plt.xlabel('Unique Values', size=14); 542 | plt.title('Number of Unique Values Histogram', size=16); 543 | 544 | def plot_collinear(self, plot_all=False): 545 | """ 546 | Heatmap of the correlation values. If plot_all = True plots all the correlations otherwise 547 | plots only those features that have a correlation above the threshold 548 | 549 | Notes 550 | -------- 551 | - Not all of the plotted correlations are above the threshold because this plots 552 | all the variables that have been idenfitied as having even one correlation above the threshold 553 | - The features on the x-axis are those that will be removed. 
The features on the y-axis 554 | are the correlated features with those on the x-axis 555 | 556 | Code adapted from https://seaborn.pydata.org/examples/many_pairwise_correlations.html 557 | """ 558 | 559 | if self.record_collinear is None: 560 | raise NotImplementedError('Collinear features have not been idenfitied. Run `identify_collinear`.') 561 | 562 | if plot_all: 563 | corr_matrix_plot = self.corr_matrix 564 | title = 'All Correlations' 565 | 566 | else: 567 | # Identify the correlations that were above the threshold 568 | # columns (x-axis) are features to drop and rows (y_axis) are correlated pairs 569 | corr_matrix_plot = self.corr_matrix.loc[list(set(self.record_collinear['corr_feature'])), 570 | list(set(self.record_collinear['drop_feature']))] 571 | 572 | title = "Correlations Above Threshold" 573 | 574 | f, ax = plt.subplots(figsize=(10, 8)) 575 | 576 | # Diverging colormap 577 | cmap = sns.diverging_palette(220, 10, as_cmap=True) 578 | 579 | # Draw the heatmap with a color bar 580 | sns.heatmap(corr_matrix_plot, cmap=cmap, center=0, 581 | linewidths=.25, cbar_kws={"shrink": 0.6}) 582 | 583 | # Set the ylabels 584 | ax.set_yticks([x + 0.5 for x in list(range(corr_matrix_plot.shape[0]))]) 585 | ax.set_yticklabels(list(corr_matrix_plot.index), size=int(160 / corr_matrix_plot.shape[0])); 586 | 587 | # Set the xlabels 588 | ax.set_xticks([x + 0.5 for x in list(range(corr_matrix_plot.shape[1]))]) 589 | ax.set_xticklabels(list(corr_matrix_plot.columns), size=int(160 / corr_matrix_plot.shape[1])); 590 | plt.title(title, size=14) 591 | 592 | def plot_feature_importances(self, plot_n=15, threshold=None): 593 | """ 594 | Plots `plot_n` most important features and the cumulative importance of features. 595 | If `threshold` is provided, prints the number of features needed to reach `threshold` cumulative importance. 596 | 597 | Parameters 598 | -------- 599 | 600 | plot_n : int, default = 15 601 | Number of most important features to plot. Defaults to 15 or the maximum number of features whichever is smaller 602 | 603 | threshold : float, between 0 and 1 default = None 604 | Threshold for printing information about cumulative importances 605 | 606 | """ 607 | 608 | if self.record_zero_importance is None: 609 | raise NotImplementedError('Feature importances have not been determined. 
Run `idenfity_zero_importance`') 610 | 611 | # Need to adjust number of features if greater than the features in the data 612 | if plot_n > self.feature_importances.shape[0]: 613 | plot_n = self.feature_importances.shape[0] - 1 614 | 615 | self.reset_plot() 616 | 617 | # Make a horizontal bar chart of feature importances 618 | plt.figure(figsize=(10, 6)) 619 | ax = plt.subplot() 620 | 621 | # Need to reverse the index to plot most important on top 622 | # There might be a more efficient method to accomplish this 623 | ax.barh(list(reversed(list(self.feature_importances.index[:plot_n]))), 624 | self.feature_importances['normalized_importance'][:plot_n], 625 | align='center', edgecolor='k') 626 | 627 | # Set the yticks and labels 628 | ax.set_yticks(list(reversed(list(self.feature_importances.index[:plot_n])))) 629 | ax.set_yticklabels(self.feature_importances['feature'][:plot_n], size=12) 630 | 631 | # Plot labeling 632 | plt.xlabel('Normalized Importance', size=16); 633 | plt.title('Feature Importances', size=18) 634 | plt.show() 635 | 636 | # Cumulative importance plot 637 | plt.figure(figsize=(6, 4)) 638 | plt.plot(list(range(1, len(self.feature_importances) + 1)), self.feature_importances['cumulative_importance'], 639 | 'r-') 640 | plt.xlabel('Number of Features', size=14); 641 | plt.ylabel('Cumulative Importance', size=14); 642 | plt.title('Cumulative Feature Importance', size=16); 643 | 644 | if threshold: 645 | # Index of minimum number of features needed for cumulative importance threshold 646 | # np.where returns the index so need to add 1 to have correct number 647 | importance_index = np.min(np.where(self.feature_importances['cumulative_importance'] > threshold)) 648 | plt.vlines(x=importance_index + 1, ymin=0, ymax=1, linestyles='--', colors='blue') 649 | plt.show(); 650 | 651 | print('%d features required for %0.2f of cumulative importance' % (importance_index + 1, threshold)) 652 | 653 | def reset_plot(self): 654 | plt.rcParams = plt.rcParamsDefault -------------------------------------------------------------------------------- /gym/envs/StarTrader/StarTrade_env.py: -------------------------------------------------------------------------------- 1 | # --------------------------- IMPORT LIBRARIES ------------------------- 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | from datetime import datetime, timedelta 6 | from gym.utils import seeding 7 | import gym 8 | from gym import spaces 9 | import data_preprocessing as dp 10 | 11 | # ------------------------- GLOBAL PARAMETERS ------------------------- 12 | # Start and end period of historical data in question 13 | START_TRAIN = datetime(2008, 12, 31) 14 | END_TRAIN = datetime(2017, 2, 12) 15 | START_TEST = datetime(2017, 2, 12) 16 | END_TEST = datetime(2019, 2, 22) 17 | 18 | STARTING_ACC_BALANCE = 100000 19 | NUMBER_NON_CORR_STOCKS = 5 20 | MAX_TRADE = 10 21 | TRAIN_RATIO = 0.8 22 | PRICE_IMPACT = 0.1 23 | 24 | # Pools of stocks to trade 25 | DJI = ['MMM', 'AXP', 'AAPL', 'BA', 'CAT', 'CVX', 'CSCO', 'KO', 'DIS', 'XOM', 'GE', 'GS', 'HD', 'IBM', 'INTC', 'JNJ', 26 | 'JPM', 'MCD', 'MRK', 'MSFT', 'NKE', 'PFE', 'PG', 'UTX', 'UNH', 'VZ', 'WMT'] 27 | 28 | DJI_N = ['3M','American Express', 'Apple','Boeing','Caterpillar','Chevron','Cisco Systems','Coca-Cola','Disney' 29 | ,'ExxonMobil','General Electric','Goldman Sachs','Home Depot','IBM','Intel','Johnson & Johnson', 30 | 'JPMorgan Chase','McDonalds','Merck','Microsoft','NIKE','Pfizer','Procter & Gamble', 31 | 'United Technologies','UnitedHealth 
Group','Verizon Communications','Wal Mart']
32 |
33 | # Market and macroeconomic data to be used as context data
34 | CONTEXT_DATA = ['^GSPC', '^DJI', '^IXIC', '^RUT', 'SPY', 'QQQ', '^VIX', 'GLD', '^TYX', '^TNX', 'SHY', 'SHV']
35 |
36 |
37 | # ------------------------------ PREPROCESSING ---------------------------------
38 | print ("\n")
39 | print ("############################## Welcome to the playground of Star Trader!! ###################################")
40 | print ("\n")
41 | print ("Hello, I am Star, I am learning to trade like a human. In this playground, I trade stocks and optimize my portfolio.")
42 | print ("\n")
43 |
44 | print ("Starting to pre-process data for trading environment construction ... ")
45 | # Data Preprocessing
46 | dataset = dp.DataRetrieval()
47 | dow_stocks_train, dow_stocks_test = dataset.get_all()
48 | dow_stock_volume = dataset.components_df_v[DJI]
49 | portfolios = dp.Trading(dow_stocks_train, dow_stocks_test, dow_stock_volume.loc[START_TEST:END_TEST])
50 | _, _, non_corr_stocks = portfolios.find_non_correlate_stocks(NUMBER_NON_CORR_STOCKS)
51 | feature_df = dataset.get_feature_dataframe(non_corr_stocks)
52 | context_df = dataset.get_feature_dataframe(CONTEXT_DATA)
53 |
54 | # With context data
55 | input_states = pd.concat([context_df, feature_df], axis=1)
56 | input_states.to_csv('./data/ddpg_input_states.csv')
57 | # Without context data
58 | #input_states = feature_df
59 | feature_length = len(input_states.columns)
60 | data_length = len(input_states)
61 | stock_price = dataset.components_df_o[non_corr_stocks]
62 | stock_volume = dataset.components_df_v[non_corr_stocks]
63 | stock_price.to_csv('./data/ddpg_stock_price.csv')
64 |
65 | print("\n")
66 | print("Based on non-correlation preference, {} stocks are selected for portfolio construction:".format(NUMBER_NON_CORR_STOCKS))
67 |
68 | for stock in non_corr_stocks:
69 | print(DJI_N[DJI.index(stock)])
70 | print("\n")
71 |
72 | print("Pre-processing and stock selection complete, trading starts now ...")
73 | print("_______________________________________________________________________________________________________________")
74 |
75 |
76 | # ------------------------------ CLASSES ---------------------------------
77 |
78 | class StarTradingEnv(gym.Env):
79 | metadata = {'render.modes': ['human']}
80 |
81 | def __init__(self, day = START_TRAIN):
82 |
83 | """
84 | Initialize the trading environment; starting values for the trading parameters are defined here.
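Once the environment is registered with Gym (the id string below is a guess based on the directory layout; the authoritative name lives in gym/envs/__init__.py), it is driven through the standard reset/step loop, each action being a vector of signed share counts bounded by MAX_TRADE:

import gym
import numpy as np

env = gym.make('StarTrader-v0')  # hypothetical id; check gym/envs/__init__.py for the real one
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # 5 signed share counts in [-MAX_TRADE, MAX_TRADE]
    state, reward, done, info = env.step(action)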
85 | """ 86 | self.iteration = 0 87 | self.day = day 88 | self.done = False 89 | 90 | 91 | # trading agent's action with low and high as the maximum number of stocks allowed to sell or buy, 92 | # defined using Gym's Box action space function 93 | self.action_space = spaces.Box(low = -MAX_TRADE, high = MAX_TRADE,shape = (NUMBER_NON_CORR_STOCKS,),dtype=np.int8) 94 | 95 | # [account balance]+[unrealized profit/loss] +[number of features, 36]+[portfolio stock of 5 stocks holdings] 96 | self.full_feature_length = 2 + feature_length 97 | print("full length", self.full_feature_length) 98 | self.observation_space = spaces.Box(low=0, high=np.inf, shape = (self.full_feature_length + NUMBER_NON_CORR_STOCKS,)) 99 | 100 | # Sliding the timeline window day-by-day, skipping the non-trading day as data is not available 101 | wrong_day = True 102 | add_day = 0 103 | while wrong_day: 104 | try: 105 | temp_date = self.day + timedelta(days=add_day) 106 | self.data = input_states.loc[temp_date] 107 | self.day = temp_date 108 | wrong_day = False 109 | except: 110 | add_day += 1 111 | 112 | self.timeline = [self.day] 113 | # The money in the trading account 114 | self.acc_balance = [STARTING_ACC_BALANCE] 115 | self.total_asset = self.acc_balance 116 | self.portfolio_asset = [0.0] 117 | self.buy_price = np.zeros((1,NUMBER_NON_CORR_STOCKS)).flatten() 118 | # Unrealized profit and loss 119 | self.unrealized_pnl = [0.0] 120 | #The value of all-stock holdings 121 | self.portfolio_value = 0.0 122 | 123 | 124 | # The state of the trading environment, defined by account balance, unrealized profit and loss, relevant 125 | # stock technical data & current stock holdings 126 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)] 127 | 128 | # The current reward of the agent. 129 | self.reward = 0 130 | self._seed() 131 | self.reset() 132 | 133 | def _sell(self, idx, action): 134 | """ 135 | Perform and record sell transactions. Commissions and slippage are taken into account. 136 | """ 137 | 138 | # Only need to sell the unit recommended by the trading agent, not necessarily all stock unit. 139 | num_share = min(abs(int(action)) , self.state[idx + self.full_feature_length]) 140 | commission = dp.Trading.commission(num_share, stock_price.loc[self.day][idx]) 141 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage 142 | transacted_price = dp.Trading.slippage_price(stock_price.loc[self.day][idx], -num_share, stock_volume.loc[self.day][idx]) 143 | 144 | # If there is existing stock holding 145 | if self.state[idx + self.full_feature_length] > 0: 146 | 147 | # Update account balance after transaction 148 | self.state[0] += (transacted_price * num_share) - commission 149 | # Update stock holding 150 | self.state[idx + self.full_feature_length] -= num_share 151 | # Reset transacted buy price record to 0.0 if there is no more stock holding 152 | if self.state[idx + self.full_feature_length] == 0.0: 153 | self.buy_price[idx] = 0.0 154 | else: 155 | pass 156 | 157 | def _buy(self, idx, action): 158 | """ 159 | Perform and record buy transactions. Commissions and slippage are taken into account. 160 | """ 161 | 162 | # Calculate the maximum possible number of stock unit the current cash can buy 163 | available_unit = self.state[0] // stock_price.loc[self.day][idx] 164 | num_share = min(available_unit, int(action)) 165 | # Deduct the traded amount from account balance. 
If available balance is not enough to purchase stock unit 166 | # recommended by trading agent's action, just use what is left. 167 | commission = dp.Trading.commission(num_share, stock_price.loc[self.day][idx]) 168 | # Calculate slipped price. Though, at max trading volume of 10 shares, there's hardly any slippage 169 | transacted_price = dp.Trading.slippage_price(stock_price.loc[self.day][idx], num_share, 170 | stock_volume.loc[self.day][idx]) 171 | 172 | # Revise number of share to trade if account balance does not have enough 173 | if (self.state[0] - commission) < transacted_price * num_share: 174 | num_share = (self.state[0] - commission) // transacted_price 175 | self.state[0] -= ( transacted_price * num_share ) + commission 176 | 177 | 178 | # If there are existing stock holding already, calculate the average buy price 179 | if self.state[idx + self.full_feature_length] > 0.0: 180 | existing_unit = self.state[idx + self.full_feature_length] 181 | previous_buy_price = self.buy_price[idx] 182 | additional_unit = min(available_unit, int(action)) 183 | new_holding = existing_unit + additional_unit 184 | self.buy_price[idx] = ((existing_unit * previous_buy_price ) + (transacted_price * additional_unit))/ new_holding 185 | # if there is no existing stock holding, simply record the current buy price 186 | elif self.state[idx + self.full_feature_length] == 0.0: 187 | self.buy_price[idx] = transacted_price 188 | 189 | # Update stock holding at its index 190 | self.state[idx + self.full_feature_length] += num_share 191 | 192 | 193 | def step(self, actions): 194 | """ 195 | The step of an episode. Perform all activities of an episode. 196 | """ 197 | 198 | # Episode ends when timestep reaches the last day in feature data 199 | self.done = self.day >= END_TRAIN 200 | # Uncomment below to run a quick test 201 | #self.done = self.day >= START_TRAIN + timedelta(days=10) 202 | 203 | # If it is the last step, plot trading performance 204 | if self.done: 205 | print("@@@@@@@@@@@@@@@@@") 206 | print("Iteration", self.iteration-1) 207 | # Construct trading book and save to a spreadsheet for analysis 208 | trading_book = pd.DataFrame(index=self.timeline, columns=["Cash balance", "Portfolio value", "Total asset", "Returns", "Cum Returns"]) 209 | trading_book["Cash balance"] = self.acc_balance 210 | trading_book["Portfolio value"] = self.portfolio_asset 211 | trading_book["Total asset"] = self.total_asset 212 | trading_book["Returns"] = trading_book["Total asset"] / trading_book["Total asset"].shift(1) - 1 213 | trading_book["CumReturns"] = trading_book["Returns"].add(1).cumprod().fillna(1) 214 | trading_book.to_csv('./train_result/trading_book_train_{}.csv'.format(self.iteration-1)) 215 | 216 | kpi = dp.MathCalc.calc_kpi(trading_book) 217 | kpi.to_csv('./train_result/kpi_train_{}.csv'.format(self.iteration-1)) 218 | print("===============================================================================================") 219 | print(kpi) 220 | print("===============================================================================================") 221 | 222 | # Visualize results 223 | plt.plot(trading_book.index, trading_book["Cash balance"], 'g', label='Account cash balance',alpha=0.8,) 224 | plt.plot(trading_book.index, trading_book["Portfolio value"], 'r', label='Portfolio value',alpha=1,lw=1.5) 225 | plt.plot(trading_book.index, trading_book["Total asset"], 'b', label='Total asset',alpha=0.6,lw=3) 226 | plt.xlabel('Timeline', fontsize=12) 227 | plt.ylabel('Value', fontsize=12) 228 | 
plt.title('Portfolio value + account fund evolution @ train iteration {}'.format(self.iteration-1), fontsize=13) 229 | plt.tight_layout() 230 | plt.legend() 231 | plt.savefig('./train_result/asset_evolution_train_{}.png'.format(self.iteration-1)) 232 | plt.show() 233 | plt.close() 234 | 235 | plt.plot(trading_book.index, trading_book["CumReturns"], 'g', label='Cummulative returns') 236 | plt.xlabel('Timeline', fontsize=12) 237 | plt.ylabel('Returns', fontsize=12) 238 | plt.title('Cummulative returns @ train iteration {}'.format(self.iteration-1), fontsize=13) 239 | plt.tight_layout() 240 | plt.legend() 241 | plt.savefig('./train_result/cummulative_returns_train_{}.png'.format(self.iteration-1)) 242 | plt.show() 243 | plt.close() 244 | 245 | return self.state, self.reward, self.done, {} 246 | 247 | else: 248 | # Portfolio value is current holdings multiply with current respective stock prices 249 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:])) 250 | # Total asset is account balance + portfolio value 251 | total_asset_starting = self.state[0] + portfolio_value 252 | 253 | # Sort the trade order in increasing order, stocks with higher number of units to sell will be 254 | # transacted first. Stocks with lesser number of units to buy will be transacted first 255 | sorted_actions = np.argsort(actions) 256 | # Get the stocks to be sold 257 | sell_stocks = sorted_actions[:np.where(actions < 0)[0].shape[0]] 258 | 259 | # Alternatively, sell with static order 260 | #sell_stocks = np.where(actions < 0)[0].flatten() 261 | #np.random.shuffle(sell_stocks) 262 | for stock_idx in sell_stocks: 263 | self._sell(stock_idx, actions[stock_idx]) 264 | 265 | # Get the stocks to be bought 266 | buy_stocks = sorted_actions[::-1][:np.where(actions > 0)[0].shape[0]] 267 | # Alternatively, buy with static order 268 | #buy_stocks = np.where(actions > 0)[0].flatten() 269 | #np.random.shuffle(buy_stocks) 270 | for stock_idx in buy_stocks: 271 | self._buy(stock_idx, actions[stock_idx]) 272 | 273 | # Update date and skip some date since not every day is trading day 274 | self.day += timedelta(days=1) 275 | wrong_day = True 276 | add_day = 0 277 | while wrong_day: 278 | try: 279 | temp_date = self.day + timedelta(days=add_day) 280 | self.data = input_states.loc[temp_date] 281 | self.day = temp_date 282 | wrong_day = False 283 | except: 284 | add_day += 1 285 | 286 | # Calculate unrealized profit and loss for existing stock holdings 287 | self.unrealized_pnl = np.sum(np.array(stock_price.loc[self.day] - self.buy_price) * np.array( 288 | self.state[self.full_feature_length:])) 289 | 290 | # Current state space 291 | self.state = [self.state[0]] + [self.unrealized_pnl] + self.data.values.tolist() + list(self.state[self.full_feature_length:]) 292 | # Portfolio value is the current stock prices multiply with their respective holdings 293 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:])) 294 | # Total asset = account balance + portfolio value 295 | total_asset_ending = self.state[0] + portfolio_value 296 | 297 | # Update account balance statement 298 | self.acc_balance = np.append(self.acc_balance, self.state[0]) 299 | 300 | # Update portfolio value statement 301 | self.portfolio_asset = np.append(self.portfolio_asset, portfolio_value) 302 | 303 | # Update total asset statement 304 | self.total_asset = np.append(self.total_asset, total_asset_ending) 305 | 306 | # Update timeline 307 | self.timeline = 
np.append(self.timeline, self.day) 308 | 309 | # Get the agent to consider gain-to-pain or lake ratio and be responsible for it if it has traded long enough 310 | if len(self.total_asset) > 9: 311 | returns = dp.MathCalc.calc_return(pd.Series(self.total_asset)) 312 | 313 | self.reward = total_asset_ending - total_asset_starting \ 314 | + (100*dp.MathCalc.calc_gain_to_pain(returns))\ 315 | - (500 * dp.MathCalc.calc_lake_ratio(pd.Series(returns).add(1).cumprod().fillna(1))) 316 | #+ (50 * dp.MathCalc.sharpe_ratio(pd.Series(returns))) 317 | 318 | # If agent has not traded long enough, it only has to bear total asset difference at the end of the day 319 | else: 320 | self.reward = total_asset_ending - total_asset_starting 321 | 322 | return self.state, self.reward, self.done, {} 323 | 324 | def reset(self): 325 | """ 326 | Reset the environment once an episode end. 327 | 328 | """ 329 | self.acc_balance = [STARTING_ACC_BALANCE] 330 | self.total_asset = self.acc_balance 331 | self.portfolio_asset = [0.0] 332 | self.buy_price = np.zeros((1, NUMBER_NON_CORR_STOCKS)).flatten() 333 | self.unrealized_pnl = [0] 334 | self.day = START_TRAIN 335 | self.portfolio_value = 0.0 336 | 337 | wrong_day = True 338 | add_day = 0 339 | while wrong_day: 340 | try: 341 | temp_date = self.day + timedelta(days=add_day) 342 | self.data = input_states.loc[temp_date] 343 | self.day = temp_date 344 | wrong_day = False 345 | except: 346 | add_day += 1 347 | 348 | self.timeline = [self.day] 349 | 350 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)] 351 | self.iteration += 1 352 | 353 | return self.state 354 | 355 | def render(self, mode='human'): 356 | """ 357 | Render the environment with current state. 358 | 359 | """ 360 | return self.state 361 | 362 | def _seed(self, seed=None): 363 | """ 364 | Seed the iteration. 
365 | """ 366 | self.np_random, seed = seeding.np_random(seed) 367 | return [seed] 368 | -------------------------------------------------------------------------------- /gym/envs/StarTrader/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.StarTrader.StarTrade_env import StarTradingEnv 2 | -------------------------------------------------------------------------------- /gym/envs/StarTrader/__pycache__/StarTrade_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/StarTrade_env.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTrader/__pycache__/StarTrade_test_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/StarTrade_test_env.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTrader/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTrader/__pycache__/intelliTrade_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTrader/__pycache__/intelliTrade_env.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTraderTest/StarTrade_test_env.py: -------------------------------------------------------------------------------- 1 | # --------------------------- IMPORT LIBRARIES ------------------------- 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | from datetime import datetime, timedelta 6 | from gym.utils import seeding 7 | import gym 8 | from gym import spaces 9 | import data_preprocessing as dp 10 | 11 | # ------------------------- GLOBAL PARAMETERS ------------------------- 12 | # Start and end period of historical data in question 13 | START_TRAIN = datetime(2008, 12, 31) 14 | END_TRAIN = datetime(2017, 2, 12) 15 | START_TEST = datetime(2017, 2, 12) 16 | END_TEST = datetime(2019, 2, 22) 17 | 18 | STARTING_ACC_BALANCE = 100000 19 | NUMBER_NON_CORR_STOCKS = 5 20 | MAX_TRADE = 10 21 | TRAIN_RATIO = 0.8 22 | 23 | # Pools of stocks to trade 24 | DJI = ['MMM', 'AXP', 'AAPL', 'BA', 'CAT', 'CVX', 'CSCO', 'KO', 'DIS', 'XOM', 'GE', 'GS', 'HD', 'IBM', 'INTC', 'JNJ', 25 | 'JPM', 'MCD', 'MRK', 'MSFT', 'NKE', 'PFE', 'PG', 'UTX', 'UNH', 'VZ', 'WMT'] 26 | 27 | DJI_N = ['3M','American Express', 'Apple','Boeing','Caterpillar','Chevron','Cisco Systems','Coca-Cola','Disney' 28 | ,'ExxonMobil','General Electric','Goldman Sachs','Home Depot','IBM','Intel','Johnson & Johnson', 29 | 'JPMorgan Chase','McDonalds','Merck','Microsoft','NIKE','Pfizer','Procter & Gamble', 30 | 'United Technologies','UnitedHealth Group','Verizon Communications','Wal Mart'] 31 | 32 | #Market and macroeconomic data to be used as context data 33 | 
CONTEXT_DATA = ['^GSPC', '^DJI', '^IXIC', '^RUT', 'SPY', 'QQQ', '^VIX', 'GLD', '^TYX', '^TNX' , 'SHY', 'SHV'] 34 | 35 | 36 | # ------------------------------ PREPROCESSING --------------------------------- 37 | print ("\n") 38 | print ("############################## Welcome to the playground of Star Trader!! ###################################") 39 | print ("\n") 40 | print ("Hello, I am Star, I am learning to trade like a human. In this playground, I trade stocks and optimize my portfolio.") 41 | print ("\n") 42 | 43 | print ("Starting to pre-process data for trading environment construction ... ") 44 | # Data Preprocessing 45 | dataset = dp.DataRetrieval() 46 | dow_stocks_train, dow_stocks_test = dataset.get_all() 47 | dow_stock_volume = dataset.components_df_v[DJI] 48 | portfolios = dp.Trading(dow_stocks_train, dow_stocks_test, dow_stock_volume.loc[START_TEST:END_TEST]) 49 | _, _, non_corr_stocks = portfolios.find_non_correlate_stocks(NUMBER_NON_CORR_STOCKS) 50 | 51 | # Reuse the training agent, StarTrader-v0 's preprocessed feature data 52 | input_states = pd.read_csv("./data/ddpg_input_states.csv", index_col='Date', parse_dates=True) 53 | feature_length = len(input_states.columns) 54 | data_length = len(input_states) 55 | stock_price = dataset.components_df_o[non_corr_stocks] 56 | 57 | print("feature length", feature_length) 58 | print("\n") 59 | print("Base on non-correlation, {} stocks are selected for portfolio construction:".format(NUMBER_NON_CORR_STOCKS)) 60 | 61 | for stock in non_corr_stocks: 62 | print(DJI_N[DJI.index(stock)]) 63 | print("\n") 64 | 65 | print("Pre-processing and stock selection complete, trading starts now ...") 66 | print("_______________________________________________________________________________________________________________") 67 | 68 | 69 | # ------------------------------ CLASSES --------------------------------- 70 | 71 | class StarTradingTestEnv(gym.Env): 72 | metadata = {'render.modes': ['human']} 73 | 74 | def __init__(self, day = START_TEST): 75 | 76 | """ 77 | Initializing the trading environment, many trading parameters starting values are defined here 78 | 79 | """ 80 | self.iteration = 0 81 | self.day = day 82 | self.done = False 83 | 84 | 85 | # trading agent's action with low and high as the maximum number of stocks allowed to trade, 86 | # defined using Gym's Box action space function 87 | self.action_space = spaces.Box(low = -MAX_TRADE, high = MAX_TRADE,shape = (NUMBER_NON_CORR_STOCKS,),dtype=np.int8) 88 | 89 | # [account balance]+[unrealized profit/loss] +[number of features, 36]+[portfolio stock of 5 stocks holdings] 90 | self.full_feature_length = 2 + feature_length 91 | print("full length", self.full_feature_length) 92 | self.observation_space = spaces.Box(low=0, high=np.inf, shape = (self.full_feature_length + NUMBER_NON_CORR_STOCKS,)) 93 | 94 | # Sliding the timeline window day-by-day, skipping the non-trading day as data is not available 95 | wrong_day = True 96 | add_day = 0 97 | while wrong_day: 98 | try: 99 | temp_date = self.day + timedelta(days=add_day) 100 | self.data = input_states.loc[temp_date] 101 | self.day = temp_date 102 | wrong_day = False 103 | except: 104 | add_day += 1 105 | 106 | self.timeline = [self.day] 107 | # The money in the trading account 108 | self.acc_balance = [STARTING_ACC_BALANCE] 109 | self.total_asset = self.acc_balance 110 | self.portfolio_asset = [0.0] 111 | self.buy_price = np.zeros((1,NUMBER_NON_CORR_STOCKS)).flatten() 112 | # Unrealized profit and loss 113 | self.unrealized_pnl = [0.0] 114 | 
#The value of all-stock holdings 115 | self.portfolio_value = 0.0 116 | 117 | 118 | # The state of the trading environment, defined by account balance, unrealized profit and loss, relevant 119 | # stock technical data & current stock holdings 120 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)] 121 | 122 | # The current reward of the agent. 123 | self.reward = 0 124 | self._seed() 125 | self.reset() 126 | 127 | def _sell(self, idx, action): 128 | 129 | # If there is existing stock holding 130 | if self.state[idx + self.full_feature_length] > 0: 131 | # Only need to sell the unit recommended by the trading agent, not necessarily all stock unit. 132 | # Update account balance after transaction 133 | self.state[0] += stock_price.loc[self.day][idx] * min(abs(int(action)), self.state[idx + self.full_feature_length]) 134 | # Update stock holding 135 | self.state[idx + self.full_feature_length] -= min(abs(int(action)), self.state[idx + self.full_feature_length]) 136 | # Reset transacted buy price record to 0.0 if there is no more stock holding 137 | if self.state[idx + self.full_feature_length] == 0.0: 138 | self.buy_price[idx] = 0.0 139 | else: 140 | pass 141 | 142 | def _buy(self, idx, action): 143 | 144 | # Calculate the maximum possible number of stock unit the current cash can buy 145 | available_unit = self.state[0] // stock_price.loc[self.day][idx] 146 | 147 | # Deduct the traded amount from account balance. If available balance is not enough to purchase stock unit 148 | # recommended by trading agent's action, just use what is left. 149 | self.state[0] -= stock_price.loc[self.day][idx] * min(available_unit, int(action)) 150 | 151 | 152 | # If there are existing stock holding already, calculate the average buy price 153 | if self.state[idx + self.full_feature_length] > 0.0: 154 | existing_unit = self.state[idx + self.full_feature_length] 155 | previous_buy_price = self.buy_price[idx] 156 | additional_unit = min(available_unit, int(action)) 157 | new_holding = existing_unit + additional_unit 158 | self.buy_price[idx] = ((existing_unit * previous_buy_price ) + (stock_price.loc[self.day][idx] * additional_unit))/ new_holding 159 | # if there is no existing stock holding, simply record the current buy price 160 | elif self.state[idx + self.full_feature_length] == 0.0: 161 | self.buy_price[idx] = stock_price.loc[self.day][idx] 162 | 163 | # Update stock holding at its index 164 | self.state[idx + self.full_feature_length] += min(available_unit, int(action)) 165 | 166 | 167 | def step(self, actions): 168 | 169 | # Episode ends when timestep reaches the last day in feature data 170 | self.done = self.day >= END_TEST 171 | 172 | # If it is the last step, plot trading performance 173 | if self.done: 174 | print("@@@@@@@@@@@@@@@@@") 175 | print("Iteration", self.iteration-1) 176 | # Construct trading and save to a spreadsheet for analysis 177 | trading_book = pd.DataFrame(index=self.timeline, columns=["Cash balance", "Portfolio value", "Total asset", "Returns", "Cum Returns"]) 178 | trading_book["Cash balance"] = self.acc_balance 179 | trading_book["Portfolio value"] = self.portfolio_asset 180 | trading_book["Total asset"] = self.total_asset 181 | trading_book["Returns"] = trading_book["Total asset"] / trading_book["Total asset"].shift(1) - 1 182 | trading_book["CumReturns"] = trading_book["Returns"].add(1).cumprod().fillna(1) 183 | trading_book.to_csv('./test_result/trading_book_test_{}.csv'.format(self.iteration-1)) 184 
| 185 | kpi = dp.MathCalc.calc_kpi(trading_book) 186 | kpi.to_csv('./test_result/kpi_test_{}.csv'.format(self.iteration-1)) 187 | print("===============================================================================================") 188 | print(kpi) 189 | print("===============================================================================================") 190 | 191 | # Visualize results 192 | plt.plot(trading_book.index, trading_book["Cash balance"], 'g', label='Account cash balance',alpha=0.8,) 193 | plt.plot(trading_book.index, trading_book["Portfolio value"], 'r', label='Portfolio value',alpha=1,lw=1.5) 194 | plt.plot(trading_book.index, trading_book["Total asset"], 'b', label='Total asset',alpha=0.6,lw=3) 195 | plt.xlabel('Timeline', fontsize=12) 196 | plt.ylabel('Value', fontsize=12) 197 | plt.title('Portfolio value + Account fund evolution @ test iteration {}'.format(self.iteration-1), fontsize=13) 198 | plt.legend() 199 | plt.savefig('./test_result/asset_evolution_test_{}.png'.format(self.iteration-1)) 200 | plt.show() 201 | plt.close() 202 | 203 | plt.plot(trading_book.index, trading_book["CumReturns"], 'g', label='Cummulative returns') 204 | plt.xlabel('Timeline', fontsize=12) 205 | plt.ylabel('Returns', fontsize=12) 206 | plt.title('Cummulative returns @ test iteration {}'.format(self.iteration-1), fontsize=13) 207 | plt.legend() 208 | plt.savefig('./test_result/cummulative_returns_test_{}.png'.format(self.iteration-1)) 209 | plt.show() 210 | plt.close() 211 | 212 | return self.state, self.reward, self.done, {} 213 | 214 | else: 215 | # Portfolio value is current holdings multiply with current respective stock prices 216 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:])) 217 | # Total asset is account balance + portfolio value 218 | total_asset_starting = self.state[0] + portfolio_value 219 | 220 | # Sort the trade order in increasing order, stocks with higher number of units to sell will be 221 | # transacted first. 
Stocks with lesser number of units to buy will be transacted first 222 | sorted_actions = np.argsort(actions) 223 | # Get the stocks to be sold 224 | sell_stocks = sorted_actions[:np.where(actions < 0)[0].shape[0]] 225 | 226 | # Alternatively, sell with static order 227 | # sell_stocks = np.where(actions < 0)[0].flatten() 228 | # np.random.shuffle(sell_stocks) 229 | for stock_idx in sell_stocks: 230 | self._sell(stock_idx, actions[stock_idx]) 231 | 232 | # Get the stocks to be bought 233 | buy_stocks = sorted_actions[::-1][:np.where(actions > 0)[0].shape[0]] 234 | # Alternatively, buy with static order 235 | # buy_stocks = np.where(actions > 0)[0].flatten() 236 | # np.random.shuffle(buy_stocks) 237 | for stock_idx in buy_stocks: 238 | self._buy(stock_idx, actions[stock_idx]) 239 | 240 | # Update date and skip some date since not every day is trading day 241 | self.day += timedelta(days=1) 242 | wrong_day = True 243 | add_day = 0 244 | while wrong_day: 245 | try: 246 | temp_date = self.day + timedelta(days=add_day) 247 | self.data = input_states.loc[temp_date] 248 | self.day = temp_date 249 | wrong_day = False 250 | except: 251 | add_day += 1 252 | 253 | # Calculate unrealized profit and loss for existing stock holdings 254 | self.unrealized_pnl = np.sum(np.array(stock_price.loc[self.day] - self.buy_price) * np.array( 255 | self.state[self.full_feature_length:])) 256 | 257 | # Current state space 258 | self.state = [self.state[0]] + [self.unrealized_pnl] + self.data.values.tolist() + list( 259 | self.state[self.full_feature_length:]) 260 | # Portfolio value is the current stock prices multiply with their respective holdings 261 | portfolio_value = sum(np.array(stock_price.loc[self.day]) * np.array(self.state[self.full_feature_length:])) 262 | # Total asset = account balance + portfolio value 263 | total_asset_ending = self.state[0] + portfolio_value 264 | 265 | # Update account balance statement 266 | self.acc_balance = np.append(self.acc_balance, self.state[0]) 267 | 268 | # Update portfolio value statement 269 | self.portfolio_asset = np.append(self.portfolio_asset, portfolio_value) 270 | 271 | # Update total asset statement 272 | self.total_asset = np.append(self.total_asset, total_asset_ending) 273 | 274 | # Update timeline 275 | self.timeline = np.append(self.timeline, self.day) 276 | 277 | # Get the agent to consider gain-to-pain or lake ratio and be responsible for it if it has traded long enough 278 | if len(self.total_asset) > 9: 279 | returns = dp.MathCalc.calc_return(pd.Series(self.total_asset)) 280 | self.reward = total_asset_ending - total_asset_starting + (100 * dp.MathCalc.calc_gain_to_pain(returns)) \ 281 | - (500 * dp.MathCalc.calc_lake_ratio(pd.Series(returns).add(1).cumprod().fillna(1))) 282 | 283 | # If agent has not traded long enough, it only has to bear total asset difference at the end of the day 284 | else: 285 | self.reward = total_asset_ending - total_asset_starting 286 | 287 | return self.state, self.reward, self.done, {} 288 | 289 | def reset(self): 290 | """ 291 | Reset the environment once an episode end 292 | 293 | """ 294 | self.acc_balance = [STARTING_ACC_BALANCE] 295 | self.total_asset = self.acc_balance 296 | self.portfolio_asset = [0.0] 297 | self.buy_price = np.zeros((1, NUMBER_NON_CORR_STOCKS)).flatten() 298 | self.unrealized_pnl = [0] 299 | self.day = START_TEST 300 | self.portfolio_value = 0.0 301 | 302 | wrong_day = True 303 | add_day = 0 304 | while wrong_day: 305 | try: 306 | temp_date = self.day + timedelta(days=add_day) 307 | self.data = 
input_states.loc[temp_date] 308 | self.day = temp_date 309 | wrong_day = False 310 | except: 311 | add_day += 1 312 | 313 | self.timeline = [self.day] 314 | 315 | self.state = self.acc_balance + self.unrealized_pnl + self.data.values.tolist() + [0 for i in range(NUMBER_NON_CORR_STOCKS)] 316 | self.iteration += 1 317 | print("\n") 318 | 319 | 320 | return self.state 321 | 322 | def render(self, mode='human'): 323 | return self.state 324 | 325 | def _seed(self, seed=None): 326 | self.np_random, seed = seeding.np_random(seed) 327 | return [seed] 328 | -------------------------------------------------------------------------------- /gym/envs/StarTraderTest/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.StarTraderTest.StarTrade_test_env import StarTradingTestEnv 2 | -------------------------------------------------------------------------------- /gym/envs/StarTraderTest/__pycache__/StarTrade_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/StarTrade_env.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTraderTest/__pycache__/StarTrade_test_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/StarTrade_test_env.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTraderTest/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/StarTraderTest/__pycache__/intelliTrade_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/gym/envs/StarTraderTest/__pycache__/intelliTrade_env.cpython-36.pyc -------------------------------------------------------------------------------- /gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import registry, register, make, spec 2 | 3 | register( 4 | id='StarTrader-v0', 5 | entry_point='gym.envs.StarTrader:StarTradingEnv', 6 | ) 7 | 8 | register( 9 | id='StarTraderTest-v0', 10 | entry_point='gym.envs.StarTraderTest:StarTradingTestEnv', 11 | ) 12 | 13 | # Algorithmic 14 | # ---------------------------------------- 15 | 16 | 17 | 18 | register( 19 | id='Copy-v0', 20 | entry_point='gym.envs.algorithmic:CopyEnv', 21 | max_episode_steps=200, 22 | reward_threshold=25.0, 23 | ) 24 | 25 | register( 26 | id='RepeatCopy-v0', 27 | entry_point='gym.envs.algorithmic:RepeatCopyEnv', 28 | max_episode_steps=200, 29 | reward_threshold=75.0, 30 | ) 31 | 32 | register( 33 | id='ReversedAddition-v0', 34 | entry_point='gym.envs.algorithmic:ReversedAdditionEnv', 35 | kwargs={'rows' : 2}, 36 | max_episode_steps=200, 37 | reward_threshold=25.0, 38 | ) 39 | 40 | register( 41 | id='ReversedAddition3-v0', 42 | 
entry_point='gym.envs.algorithmic:ReversedAdditionEnv', 43 | kwargs={'rows' : 3}, 44 | max_episode_steps=200, 45 | reward_threshold=25.0, 46 | ) 47 | 48 | register( 49 | id='DuplicatedInput-v0', 50 | entry_point='gym.envs.algorithmic:DuplicatedInputEnv', 51 | max_episode_steps=200, 52 | reward_threshold=9.0, 53 | ) 54 | 55 | register( 56 | id='Reverse-v0', 57 | entry_point='gym.envs.algorithmic:ReverseEnv', 58 | max_episode_steps=200, 59 | reward_threshold=25.0, 60 | ) 61 | 62 | # Classic 63 | # ---------------------------------------- 64 | 65 | register( 66 | id='CartPole-v0', 67 | entry_point='gym.envs.classic_control:CartPoleEnv', 68 | max_episode_steps=200, 69 | reward_threshold=195.0, 70 | ) 71 | 72 | register( 73 | id='CartPole-v1', 74 | entry_point='gym.envs.classic_control:CartPoleEnv', 75 | max_episode_steps=500, 76 | reward_threshold=475.0, 77 | ) 78 | 79 | register( 80 | id='MountainCar-v0', 81 | entry_point='gym.envs.classic_control:MountainCarEnv', 82 | max_episode_steps=200, 83 | reward_threshold=-110.0, 84 | ) 85 | 86 | register( 87 | id='MountainCarContinuous-v0', 88 | entry_point='gym.envs.classic_control:Continuous_MountainCarEnv', 89 | max_episode_steps=999, 90 | reward_threshold=90.0, 91 | ) 92 | 93 | register( 94 | id='Pendulum-v0', 95 | entry_point='gym.envs.classic_control:PendulumEnv', 96 | max_episode_steps=200, 97 | ) 98 | 99 | register( 100 | id='Acrobot-v1', 101 | entry_point='gym.envs.classic_control:AcrobotEnv', 102 | max_episode_steps=500, 103 | ) 104 | 105 | # Box2d 106 | # ---------------------------------------- 107 | 108 | register( 109 | id='LunarLander-v2', 110 | entry_point='gym.envs.box2d:LunarLander', 111 | max_episode_steps=1000, 112 | reward_threshold=200, 113 | ) 114 | 115 | register( 116 | id='LunarLanderContinuous-v2', 117 | entry_point='gym.envs.box2d:LunarLanderContinuous', 118 | max_episode_steps=1000, 119 | reward_threshold=200, 120 | ) 121 | 122 | register( 123 | id='BipedalWalker-v2', 124 | entry_point='gym.envs.box2d:BipedalWalker', 125 | max_episode_steps=1600, 126 | reward_threshold=300, 127 | ) 128 | 129 | register( 130 | id='BipedalWalkerHardcore-v2', 131 | entry_point='gym.envs.box2d:BipedalWalkerHardcore', 132 | max_episode_steps=2000, 133 | reward_threshold=300, 134 | ) 135 | 136 | register( 137 | id='CarRacing-v0', 138 | entry_point='gym.envs.box2d:CarRacing', 139 | max_episode_steps=1000, 140 | reward_threshold=900, 141 | ) 142 | 143 | # Toy Text 144 | # ---------------------------------------- 145 | 146 | register( 147 | id='Blackjack-v0', 148 | entry_point='gym.envs.toy_text:BlackjackEnv', 149 | ) 150 | 151 | register( 152 | id='KellyCoinflip-v0', 153 | entry_point='gym.envs.toy_text:KellyCoinflipEnv', 154 | reward_threshold=246.61, 155 | ) 156 | register( 157 | id='KellyCoinflipGeneralized-v0', 158 | entry_point='gym.envs.toy_text:KellyCoinflipGeneralizedEnv', 159 | ) 160 | 161 | register( 162 | id='FrozenLake-v0', 163 | entry_point='gym.envs.toy_text:FrozenLakeEnv', 164 | kwargs={'map_name' : '4x4'}, 165 | max_episode_steps=100, 166 | reward_threshold=0.78, # optimum = .8196 167 | ) 168 | 169 | register( 170 | id='FrozenLake8x8-v0', 171 | entry_point='gym.envs.toy_text:FrozenLakeEnv', 172 | kwargs={'map_name' : '8x8'}, 173 | max_episode_steps=200, 174 | reward_threshold=0.99, # optimum = 1 175 | ) 176 | 177 | register( 178 | id='CliffWalking-v0', 179 | entry_point='gym.envs.toy_text:CliffWalkingEnv', 180 | ) 181 | 182 | register( 183 | id='NChain-v0', 184 | entry_point='gym.envs.toy_text:NChainEnv', 185 | 
max_episode_steps=1000, 186 | ) 187 | 188 | register( 189 | id='Roulette-v0', 190 | entry_point='gym.envs.toy_text:RouletteEnv', 191 | max_episode_steps=100, 192 | ) 193 | 194 | register( 195 | id='Taxi-v2', 196 | entry_point='gym.envs.toy_text.taxi:TaxiEnv', 197 | reward_threshold=8, # optimum = 8.46 198 | max_episode_steps=200, 199 | ) 200 | 201 | register( 202 | id='GuessingGame-v0', 203 | entry_point='gym.envs.toy_text.guessing_game:GuessingGame', 204 | max_episode_steps=200, 205 | ) 206 | 207 | register( 208 | id='HotterColder-v0', 209 | entry_point='gym.envs.toy_text.hotter_colder:HotterColder', 210 | max_episode_steps=200, 211 | ) 212 | 213 | # Mujoco 214 | # ---------------------------------------- 215 | 216 | # 2D 217 | 218 | register( 219 | id='Reacher-v2', 220 | entry_point='gym.envs.mujoco:ReacherEnv', 221 | max_episode_steps=50, 222 | reward_threshold=-3.75, 223 | ) 224 | 225 | register( 226 | id='Pusher-v2', 227 | entry_point='gym.envs.mujoco:PusherEnv', 228 | max_episode_steps=100, 229 | reward_threshold=0.0, 230 | ) 231 | 232 | register( 233 | id='Thrower-v2', 234 | entry_point='gym.envs.mujoco:ThrowerEnv', 235 | max_episode_steps=100, 236 | reward_threshold=0.0, 237 | ) 238 | 239 | register( 240 | id='Striker-v2', 241 | entry_point='gym.envs.mujoco:StrikerEnv', 242 | max_episode_steps=100, 243 | reward_threshold=0.0, 244 | ) 245 | 246 | register( 247 | id='InvertedPendulum-v2', 248 | entry_point='gym.envs.mujoco:InvertedPendulumEnv', 249 | max_episode_steps=1000, 250 | reward_threshold=950.0, 251 | ) 252 | 253 | register( 254 | id='InvertedDoublePendulum-v2', 255 | entry_point='gym.envs.mujoco:InvertedDoublePendulumEnv', 256 | max_episode_steps=1000, 257 | reward_threshold=9100.0, 258 | ) 259 | 260 | register( 261 | id='HalfCheetah-v2', 262 | entry_point='gym.envs.mujoco:HalfCheetahEnv', 263 | max_episode_steps=1000, 264 | reward_threshold=4800.0, 265 | ) 266 | 267 | register( 268 | id='Hopper-v2', 269 | entry_point='gym.envs.mujoco:HopperEnv', 270 | max_episode_steps=1000, 271 | reward_threshold=3800.0, 272 | ) 273 | 274 | register( 275 | id='Swimmer-v2', 276 | entry_point='gym.envs.mujoco:SwimmerEnv', 277 | max_episode_steps=1000, 278 | reward_threshold=360.0, 279 | ) 280 | 281 | register( 282 | id='Walker2d-v2', 283 | max_episode_steps=1000, 284 | entry_point='gym.envs.mujoco:Walker2dEnv', 285 | ) 286 | 287 | register( 288 | id='Ant-v2', 289 | entry_point='gym.envs.mujoco:AntEnv', 290 | max_episode_steps=1000, 291 | reward_threshold=6000.0, 292 | ) 293 | 294 | register( 295 | id='Humanoid-v2', 296 | entry_point='gym.envs.mujoco:HumanoidEnv', 297 | max_episode_steps=1000, 298 | ) 299 | 300 | register( 301 | id='HumanoidStandup-v2', 302 | entry_point='gym.envs.mujoco:HumanoidStandupEnv', 303 | max_episode_steps=1000, 304 | ) 305 | 306 | # Robotics 307 | # ---------------------------------------- 308 | 309 | def _merge(a, b): 310 | a.update(b) 311 | return a 312 | 313 | for reward_type in ['sparse', 'dense']: 314 | suffix = 'Dense' if reward_type == 'dense' else '' 315 | kwargs = { 316 | 'reward_type': reward_type, 317 | } 318 | 319 | # Fetch 320 | register( 321 | id='FetchSlide{}-v1'.format(suffix), 322 | entry_point='gym.envs.robotics:FetchSlideEnv', 323 | kwargs=kwargs, 324 | max_episode_steps=50, 325 | ) 326 | 327 | register( 328 | id='FetchPickAndPlace{}-v1'.format(suffix), 329 | entry_point='gym.envs.robotics:FetchPickAndPlaceEnv', 330 | kwargs=kwargs, 331 | max_episode_steps=50, 332 | ) 333 | 334 | register( 335 | id='FetchReach{}-v1'.format(suffix), 336 | 
entry_point='gym.envs.robotics:FetchReachEnv', 337 | kwargs=kwargs, 338 | max_episode_steps=50, 339 | ) 340 | 341 | register( 342 | id='FetchPush{}-v1'.format(suffix), 343 | entry_point='gym.envs.robotics:FetchPushEnv', 344 | kwargs=kwargs, 345 | max_episode_steps=50, 346 | ) 347 | 348 | # Hand 349 | register( 350 | id='HandReach{}-v0'.format(suffix), 351 | entry_point='gym.envs.robotics:HandReachEnv', 352 | kwargs=kwargs, 353 | max_episode_steps=50, 354 | ) 355 | 356 | register( 357 | id='HandManipulateBlockRotateZ{}-v0'.format(suffix), 358 | entry_point='gym.envs.robotics:HandBlockEnv', 359 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'z'}, kwargs), 360 | max_episode_steps=100, 361 | ) 362 | 363 | register( 364 | id='HandManipulateBlockRotateParallel{}-v0'.format(suffix), 365 | entry_point='gym.envs.robotics:HandBlockEnv', 366 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'parallel'}, kwargs), 367 | max_episode_steps=100, 368 | ) 369 | 370 | register( 371 | id='HandManipulateBlockRotateXYZ{}-v0'.format(suffix), 372 | entry_point='gym.envs.robotics:HandBlockEnv', 373 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'xyz'}, kwargs), 374 | max_episode_steps=100, 375 | ) 376 | 377 | register( 378 | id='HandManipulateBlockFull{}-v0'.format(suffix), 379 | entry_point='gym.envs.robotics:HandBlockEnv', 380 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs), 381 | max_episode_steps=100, 382 | ) 383 | 384 | # Alias for "Full" 385 | register( 386 | id='HandManipulateBlock{}-v0'.format(suffix), 387 | entry_point='gym.envs.robotics:HandBlockEnv', 388 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs), 389 | max_episode_steps=100, 390 | ) 391 | 392 | register( 393 | id='HandManipulateEggRotate{}-v0'.format(suffix), 394 | entry_point='gym.envs.robotics:HandEggEnv', 395 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'xyz'}, kwargs), 396 | max_episode_steps=100, 397 | ) 398 | 399 | register( 400 | id='HandManipulateEggFull{}-v0'.format(suffix), 401 | entry_point='gym.envs.robotics:HandEggEnv', 402 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs), 403 | max_episode_steps=100, 404 | ) 405 | 406 | # Alias for "Full" 407 | register( 408 | id='HandManipulateEgg{}-v0'.format(suffix), 409 | entry_point='gym.envs.robotics:HandEggEnv', 410 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs), 411 | max_episode_steps=100, 412 | ) 413 | 414 | register( 415 | id='HandManipulatePenRotate{}-v0'.format(suffix), 416 | entry_point='gym.envs.robotics:HandPenEnv', 417 | kwargs=_merge({'target_position': 'ignore', 'target_rotation': 'xyz'}, kwargs), 418 | max_episode_steps=100, 419 | ) 420 | 421 | register( 422 | id='HandManipulatePenFull{}-v0'.format(suffix), 423 | entry_point='gym.envs.robotics:HandPenEnv', 424 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs), 425 | max_episode_steps=100, 426 | ) 427 | 428 | # Alias for "Full" 429 | register( 430 | id='HandManipulatePen{}-v0'.format(suffix), 431 | entry_point='gym.envs.robotics:HandPenEnv', 432 | kwargs=_merge({'target_position': 'random', 'target_rotation': 'xyz'}, kwargs), 433 | max_episode_steps=100, 434 | ) 435 | 436 | # Atari 437 | # ---------------------------------------- 438 | 439 | # # print ', '.join(["'{}'".format(name.split('.')[0]) for name in atari_py.list_games()]) 440 | for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 
'asteroids', 'atlantis', 441 | 'bank_heist', 'battle_zone', 'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout', 'carnival', 442 | 'centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk', 443 | 'elevator_action', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar', 444 | 'hero', 'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull', 'kung_fu_master', 445 | 'montezuma_revenge', 'ms_pacman', 'name_this_game', 'phoenix', 'pitfall', 'pong', 'pooyan', 446 | 'private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing', 447 | 'solaris', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down', 448 | 'venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge', 'zaxxon']: 449 | for obs_type in ['image', 'ram']: 450 | # space_invaders should yield SpaceInvaders-v0 and SpaceInvaders-ram-v0 451 | name = ''.join([g.capitalize() for g in game.split('_')]) 452 | if obs_type == 'ram': 453 | name = '{}-ram'.format(name) 454 | 455 | nondeterministic = False 456 | if game == 'elevator_action' and obs_type == 'ram': 457 | # ElevatorAction-ram-v0 seems to yield slightly 458 | # non-deterministic observations about 10% of the time. We 459 | # should track this down eventually, but for now we just 460 | # mark it as nondeterministic. 461 | nondeterministic = True 462 | 463 | register( 464 | id='{}-v0'.format(name), 465 | entry_point='gym.envs.atari:AtariEnv', 466 | kwargs={'game': game, 'obs_type': obs_type, 'repeat_action_probability': 0.25}, 467 | max_episode_steps=10000, 468 | nondeterministic=nondeterministic, 469 | ) 470 | 471 | register( 472 | id='{}-v4'.format(name), 473 | entry_point='gym.envs.atari:AtariEnv', 474 | kwargs={'game': game, 'obs_type': obs_type}, 475 | max_episode_steps=100000, 476 | nondeterministic=nondeterministic, 477 | ) 478 | 479 | # Standard Deterministic (as in the original DeepMind paper) 480 | if game == 'space_invaders': 481 | frameskip = 3 482 | else: 483 | frameskip = 4 484 | 485 | # Use a deterministic frame skip. 486 | register( 487 | id='{}Deterministic-v0'.format(name), 488 | entry_point='gym.envs.atari:AtariEnv', 489 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip, 'repeat_action_probability': 0.25}, 490 | max_episode_steps=100000, 491 | nondeterministic=nondeterministic, 492 | ) 493 | 494 | register( 495 | id='{}Deterministic-v4'.format(name), 496 | entry_point='gym.envs.atari:AtariEnv', 497 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip}, 498 | max_episode_steps=100000, 499 | nondeterministic=nondeterministic, 500 | ) 501 | 502 | register( 503 | id='{}NoFrameskip-v0'.format(name), 504 | entry_point='gym.envs.atari:AtariEnv', 505 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1, 'repeat_action_probability': 0.25}, # A frameskip of 1 means we get every frame 506 | max_episode_steps=frameskip * 100000, 507 | nondeterministic=nondeterministic, 508 | ) 509 | 510 | # No frameskip. (Atari has no entropy source, so these are 511 | # deterministic environments.) 
512 | register( 513 | id='{}NoFrameskip-v4'.format(name), 514 | entry_point='gym.envs.atari:AtariEnv', 515 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1}, # A frameskip of 1 means we get every frame 516 | max_episode_steps=frameskip * 100000, 517 | nondeterministic=nondeterministic, 518 | ) 519 | 520 | 521 | 522 | 523 | 524 | # Unit test 525 | # --------- 526 | 527 | register( 528 | id='CubeCrash-v0', 529 | entry_point='gym.envs.unittest:CubeCrash', 530 | reward_threshold=0.9, 531 | ) 532 | register( 533 | id='CubeCrashSparse-v0', 534 | entry_point='gym.envs.unittest:CubeCrashSparse', 535 | reward_threshold=0.9, 536 | ) 537 | register( 538 | id='CubeCrashScreenBecomesBlack-v0', 539 | entry_point='gym.envs.unittest:CubeCrashScreenBecomesBlack', 540 | reward_threshold=0.9, 541 | ) 542 | 543 | register( 544 | id='MemorizeDigits-v0', 545 | entry_point='gym.envs.unittest:MemorizeDigits', 546 | reward_threshold=20, 547 | ) 548 | 549 | -------------------------------------------------------------------------------- /model/DDPG_trained_model_0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_0 -------------------------------------------------------------------------------- /model/DDPG_trained_model_1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_1 -------------------------------------------------------------------------------- /model/DDPG_trained_model_2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_2 -------------------------------------------------------------------------------- /model/DDPG_trained_model_3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_3 -------------------------------------------------------------------------------- /model/DDPG_trained_model_4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_4 -------------------------------------------------------------------------------- /model/DDPG_trained_model_5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_5 -------------------------------------------------------------------------------- /model/DDPG_trained_model_6: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_6 -------------------------------------------------------------------------------- /model/DDPG_trained_model_7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_7 
-------------------------------------------------------------------------------- /model/DDPG_trained_model_8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_8 -------------------------------------------------------------------------------- /model/DDPG_trained_model_9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/DDPG_trained_model_9 -------------------------------------------------------------------------------- /model/best_lstm_model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/model/best_lstm_model.h5 -------------------------------------------------------------------------------- /results/model_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_1.jpg -------------------------------------------------------------------------------- /results/model_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_10.jpg -------------------------------------------------------------------------------- /results/model_10_rerun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_10_rerun.jpg -------------------------------------------------------------------------------- /results/model_11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_11.jpg -------------------------------------------------------------------------------- /results/model_11_rerun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_11_rerun.jpg -------------------------------------------------------------------------------- /results/model_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_12.jpg -------------------------------------------------------------------------------- /results/model_13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_13.jpg -------------------------------------------------------------------------------- /results/model_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_14.jpg -------------------------------------------------------------------------------- /results/model_15.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_15.jpg -------------------------------------------------------------------------------- /results/model_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_16.jpg -------------------------------------------------------------------------------- /results/model_17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_17.jpg -------------------------------------------------------------------------------- /results/model_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_18.jpg -------------------------------------------------------------------------------- /results/model_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_2.jpg -------------------------------------------------------------------------------- /results/model_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_3.jpg -------------------------------------------------------------------------------- /results/model_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_4.jpg -------------------------------------------------------------------------------- /results/model_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_5.jpg -------------------------------------------------------------------------------- /results/model_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_6.jpg -------------------------------------------------------------------------------- /results/model_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_7.jpg -------------------------------------------------------------------------------- /results/model_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_8.jpg -------------------------------------------------------------------------------- /results/model_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/results/model_9.jpg 
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import multiprocessing 3 | import os.path as osp 4 | import gym 5 | from collections import defaultdict 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder 10 | from baselines.common.vec_env.vec_frame_stack import VecFrameStack 11 | from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env 12 | from baselines.common.tf_util import get_session 13 | from baselines import logger 14 | from importlib import import_module 15 | 16 | from baselines.common.vec_env.vec_normalize import VecNormalize 17 | 18 | try: 19 | from mpi4py import MPI 20 | except ImportError: 21 | MPI = None 22 | 23 | try: 24 | import pybullet_envs 25 | except ImportError: 26 | pybullet_envs = None 27 | 28 | try: 29 | import roboschool 30 | except ImportError: 31 | roboschool = None 32 | 33 | _game_envs = defaultdict(set) 34 | for env in gym.envs.registry.all(): 35 | # TODO: solve this with regexes 36 | env_type = env._entry_point.split(':')[0].split('.')[-1] 37 | _game_envs[env_type].add(env.id) 38 | 39 | # reading benchmark names directly from retro requires 40 | # importing retro here, and for some reason that crashes tensorflow 41 | # in ubuntu 42 | _game_envs['retro'] = { 43 | 'BubbleBobble-Nes', 44 | 'SuperMarioBros-Nes', 45 | 'TwinBee3PokoPokoDaimaou-Nes', 46 | 'SpaceHarrier-Nes', 47 | 'SonicTheHedgehog-Genesis', 48 | 'Vectorman-Genesis', 49 | 'FinalFight-Snes', 50 | 'SpaceInvaders-Snes', 51 | } 52 | 53 | 54 | def train(args, extra_args): 55 | env_type, env_id = get_env_type(args.env) 56 | print('env_type: {}'.format(env_type)) 57 | 58 | total_timesteps = int(args.num_timesteps) 59 | seed = args.seed 60 | 61 | learn = get_learn_function(args.alg) 62 | alg_kwargs = get_learn_function_defaults(args.alg, env_type) 63 | alg_kwargs.update(extra_args) 64 | 65 | env = build_env(args) 66 | if args.save_video_interval != 0: 67 | env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), 68 | record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length) 69 | 70 | if args.network: 71 | alg_kwargs['network'] = args.network 72 | else: 73 | if alg_kwargs.get('network') is None: 74 | alg_kwargs['network'] = get_default_network(env_type) 75 | 76 | print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs)) 77 | 78 | model = learn( 79 | env=env, 80 | seed=seed, 81 | total_timesteps=total_timesteps, 82 | **alg_kwargs 83 | ) 84 | 85 | return model, env 86 | 87 | 88 | def build_env(args): 89 | ncpu = multiprocessing.cpu_count() 90 | if sys.platform == 'darwin': ncpu //= 2 91 | nenv = args.num_env or ncpu 92 | alg = args.alg 93 | seed = args.seed 94 | 95 | env_type, env_id = get_env_type(args.env) 96 | 97 | if env_type in {'atari', 'retro'}: 98 | if alg == 'deepq': 99 | env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True}) 100 | elif alg == 'trpo_mpi': 101 | env = make_env(env_id, env_type, seed=seed) 102 | else: 103 | frame_stack_size = 4 104 | env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale) 105 | env = VecFrameStack(env, frame_stack_size) 106 | 107 | else: 108 | config = tf.ConfigProto(allow_soft_placement=True, 109 | intra_op_parallelism_threads=1, 110 | 
123 | def get_env_type(env_id): 124 | # Re-parse the gym registry, since new envs may have been registered since the last call. 125 | for env in gym.envs.registry.all(): 126 | env_type = env._entry_point.split(':')[0].split('.')[-1] 127 | _game_envs[env_type].add(env.id) # This is a set so add is idempotent 128 | 129 | if env_id in _game_envs.keys(): 130 | env_type = env_id 131 | env_id = [g for g in _game_envs[env_type]][0] # pick an arbitrary id of this env type 132 | else: 133 | env_type = None 134 | for g, e in _game_envs.items(): 135 | if env_id in e: 136 | env_type = g 137 | break 138 | assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, _game_envs.keys()) 139 | 140 | return env_type, env_id 141 | 142 | 143 | def get_default_network(env_type): 144 | if env_type in {'atari', 'retro'}: 145 | return 'cnn' 146 | else: 147 | return 'mlp' 148 | 149 | def get_alg_module(alg, submodule=None): 150 | submodule = submodule or alg 151 | try: 152 | # first try to import the alg module from baselines 153 | alg_module = import_module('.'.join(['baselines', alg, submodule])) 154 | except ImportError: 155 | # then fall back to rl_algs 156 | alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule])) 157 | 158 | return alg_module 159 | 160 | 161 | def get_learn_function(alg): 162 | return get_alg_module(alg).learn 163 | 164 | 165 | def get_learn_function_defaults(alg, env_type): 166 | try: 167 | alg_defaults = get_alg_module(alg, 'defaults') 168 | kwargs = getattr(alg_defaults, env_type)() 169 | except (ImportError, AttributeError): 170 | kwargs = {} 171 | return kwargs 172 | 173 | 174 | 175 | def parse_cmdline_kwargs(args): 176 | ''' 177 | Convert a list of '='-separated command-line arguments to a dictionary, evaluating Python literals where possible. 178 | ''' 179 | def parse(v): 180 | 181 | assert isinstance(v, str) 182 | try: 183 | return eval(v) # values come from the local command line only 184 | except (NameError, SyntaxError): 185 | return v 186 | 187 | return {k: parse(v) for k, v in parse_unknown_args(args).items()} 188 | 189 | 190 | 191 | def main(args): 192 | # Configure the logger; disable logging in child MPI processes (with rank > 0) 193 | 194 | arg_parser = common_arg_parser() 195 | args, unknown_args = arg_parser.parse_known_args(args) 196 | extra_args = parse_cmdline_kwargs(unknown_args) 197 | 198 | if args.extra_import is not None: 199 | import_module(args.extra_import) 200 | 201 | if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: 202 | rank = 0 203 | logger.configure() 204 | else: 205 | logger.configure(format_strs=[]) 206 | rank = MPI.COMM_WORLD.Get_rank() 207 | 208 | # Training always runs first, with the parsed arguments: 209 | model, env = train(args, extra_args) 210 | env.close() 211 | 212 | if args.save_path is not None and rank == 0: 213 | save_path = osp.expanduser(args.save_path) 214 | model.save(save_path) # model.save() persists the trained TensorFlow variables 215 | 216 | # If --play is set, run the trained model in the environment 217 | if args.play: 218 | logger.log("Running trained model") 219 | env = build_env(args) 220 | obs = env.reset() 221 | 222 | state = model.initial_state if hasattr(model, 'initial_state') else None 223 | dones = np.zeros((1,)) 224 | 225 | while True: 226 | if state is not None: 227 | actions, _, state, _ = model.step(obs, S=state, M=dones) 228 | else: 229 | actions, _, _, _ = model.step(obs) 230 | 231 | obs, _, done, _ = env.step(actions) 232 | env.render() 233 | done = done.any() if isinstance(done, np.ndarray) else done 234 | 235 | if done: 236 | obs = env.reset() 237 | 238 | env.close() 239 | 240 | return model 241 | 242 | if __name__ == '__main__': 243 | main(sys.argv) 244 | 
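For orientation, a minimal sketch of how run.py might be driven for this project. The flags are the standard ones exposed by baselines' common_arg_parser; the environment ids (StarTrader-v0, StarTraderTest-v0) and the timestep counts are illustrative assumptions based on the environments registered under gym/envs, not values confirmed by this file:

from run import main

# Train DDPG on the StarTrader environment and save the weights.
main(['--alg=ddpg', '--env=StarTrader-v0', '--network=mlp',
      '--num_timesteps=2e4', '--save_path=./model/ddpg_startrader'])

# Evaluate on the held-out test environment; --play replays the
# trained agent after the (here deliberately tiny) training run.
main(['--alg=ddpg', '--env=StarTraderTest-v0', '--network=mlp',
      '--num_timesteps=1', '--play'])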
-------------------------------------------------------------------------------- /test_iteration_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_iteration_1.gif -------------------------------------------------------------------------------- /test_result/asset_evolution_test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/asset_evolution_test_1.png -------------------------------------------------------------------------------- /test_result/cummulative_returns_test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/cummulative_returns_test_1.png -------------------------------------------------------------------------------- /test_result/efficient_frontier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/efficient_frontier.png -------------------------------------------------------------------------------- /test_result/kpi_backtest.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,0.384797258876,49.2156862745098,3.44328458684,-1.28,-2.4,0.0178496358163,1.48968700183,0.577476288178,0.881134359686 3 | -------------------------------------------------------------------------------- /test_result/kpi_test_1.csv: -------------------------------------------------------------------------------- 1 | ,Avg. 
monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,1.41869638196,57.84313725490197,3.8704579038,-2,-4.09,0.0395631134021,1.23491340742,0.94625848311,1.3355405784 3 | -------------------------------------------------------------------------------- /test_result/portfolios_return.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/portfolios_return.png -------------------------------------------------------------------------------- /test_result/portfolios_returns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/portfolios_returns.png -------------------------------------------------------------------------------- /test_result/portfolios_risk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/portfolios_risk.png -------------------------------------------------------------------------------- /test_result/price_prediction_LSTM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/test_result/price_prediction_LSTM.png -------------------------------------------------------------------------------- /train_iterations_9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_iterations_9.gif -------------------------------------------------------------------------------- /train_result/asset_evolution_train_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_1.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_10.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_11.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_2.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_3.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_4.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_5.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_6.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_7.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_8.png -------------------------------------------------------------------------------- /train_result/asset_evolution_train_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/asset_evolution_train_9.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_1.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_10.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_2.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_3.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_4.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_5.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_6.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_7.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_8.png -------------------------------------------------------------------------------- /train_result/cummulative_returns_train_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiewwantan/StarTrader/b81d5c37f9e06e740c61d9783750bb16fffdd33e/train_result/cummulative_returns_train_9.png -------------------------------------------------------------------------------- /train_result/kpi_train_1.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,0.868741890631,53.47358121330724,10.2170228228,-2.17,-6.69,0.0940450425508,0.792966158356,0.558347516033,0.798649743083 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_10.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,1.32611976461,54.17156286721504,15.6440160667,-2.08,-6.38,0.0355723668945,0.204786346071,0.928934411628,0.173474724563 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_2.csv: -------------------------------------------------------------------------------- 1 | ,Avg. 
monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,0.640332831824,51.2720156555773,7.02923228398,-2.51,-3.41,0.0517834512147,0.640076187631,0.428556037997,0.621309289651 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_3.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,0.418111638812,51.56555772994129,4.74311658088,-2.09,-8.16,0.107625342005,0.314120450446,0.192896687976,0.267347484606 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_4.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,0.806692096574,52.6908023483366,8.99711662931,-1.81,-8.23,0.0686801277596,0.843390655755,0.552771717154,0.796880009409 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_5.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,1.06979591874,54.109589041095894,11.9608325657,-1.84,-7.67,0.0846979767109,1.05175799254,0.693293494501,1.0105019238 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_6.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,1.1403025973,54.59882583170255,12.8482034616,-1.74,-6.88,0.0697130237698,1.23040299427,0.77590085894,1.12930357835 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_7.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,1.53304467069,55.283757338551865,17.2352833671,-1.69,-8.2,0.0704577793934,1.49385458682,0.916896826069,1.34057729725 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_8.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,2.2381635965,54.01174168297456,26.6304669876,-7.47,-3.71,0.0352518826338,2.49573425564,1.2167193459,1.81690763343 3 | -------------------------------------------------------------------------------- /train_result/kpi_train_9.csv: -------------------------------------------------------------------------------- 1 | ,Avg. monthly return,Pos months pct,Avg yearly return,Max monthly dd,Max drawdown,Lake ratio,Gain to Pain,Sharpe ratio,Sortino ratio 2 | KPI,1.90387965587,52.98434442270059,22.3442318826,-8.2,-4.4,0.0419270101037,1.73680487363,0.962348814721,1.4271890396 3 | --------------------------------------------------------------------------------
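All kpi_*.csv files above share one schema: a header of metric names plus a single row labelled KPI. A minimal pandas sketch for reading a metric back out (the relative path assumes the repository root as the working directory):

import pandas as pd

# Each kpi_*.csv holds one data row, indexed by the label 'KPI'.
kpi = pd.read_csv('test_result/kpi_test_1.csv', index_col=0)
print(kpi.loc['KPI', 'Sharpe ratio'])  # 0.94625848311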