├── .gitignore ├── LICENSE ├── README.md ├── data ├── AAPL.csv └── MSFT.csv ├── env ├── StockTradingEnv.py └── __init__.py ├── main.py └── render ├── StockTradingGraph.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Adam King 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stock-Trading-Visualization 2 | 3 | A custom OpenAI gym environment for simulating stock trades on historical price data with live rendering. 
# -----------------------------------------------------------------------------
# /README.md (continued)
# -----------------------------------------------------------------------------
# In this article, we've added a simple, yet elegant visualization of the
# agent's trades using Matplotlib.
#
# If you'd like to learn more about how we created this visualization, check
# out the Medium article:
# https://medium.com/@notadamking/visualizing-stock-trading-agents-using-matplotlib-and-gym-584c992bc6d4
# -----------------------------------------------------------------------------
# /env/StockTradingEnv.py:
# -----------------------------------------------------------------------------
import random
import json
import gym
from gym import spaces
import pandas as pd
import numpy as np

from render.StockTradingGraph import StockTradingGraph

MAX_ACCOUNT_BALANCE = 2147483647
MAX_NUM_SHARES = 2147483647
MAX_SHARE_PRICE = 5000
MAX_OPEN_POSITIONS = 5
MAX_STEPS = 20000

INITIAL_ACCOUNT_BALANCE = 10000

LOOKBACK_WINDOW_SIZE = 40


def factor_pairs(val):
    """Return (i, val / i) for every divisor i of ``val`` with i <= sqrt(val).

    The second element of each pair is a float because true division is used.
    """
    return [(i, val / i) for i in range(1, int(val**0.5)+1) if val % i == 0]


class StockTradingEnv(gym.Env):
    """A stock trading environment for OpenAI gym.

    Each step the agent buys, sells or holds a single stock whose daily
    prices (rescaled by ``Adjusted_Close / Close``) are taken from ``df``.
    """
    metadata = {'render.modes': ['live', 'file', 'none']}
    visualization = None

    def __init__(self, df):
        """Create the environment.

        Args:
            df: DataFrame with at least 'Open', 'High', 'Low', 'Close',
                'Adjusted_Close' and 'Volume' columns (and 'Date' when
                rendering live), one row per trading day in time order.
                It is copied, never modified.
        """
        super(StockTradingEnv, self).__init__()

        self.df = self._adjust_prices(df)
        self.reward_range = (0, MAX_ACCOUNT_BALANCE)

        # Actions are [action_type, amount]:
        #   action_type < 1      -> buy `amount` fraction of the cash balance
        #   1 <= action_type < 2 -> sell `amount` fraction of the held shares
        #   otherwise            -> hold
        self.action_space = spaces.Box(
            low=np.array([0, 0]), high=np.array([3, 1]), dtype=np.float16)

        # 5 rows (Open, High, Low, Close, Volume) x (LOOKBACK_WINDOW_SIZE + 1)
        # columns of price history, plus one extra column of account state.
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(5, LOOKBACK_WINDOW_SIZE + 2), dtype=np.float16)

    def _adjust_prices(self, df):
        """Return a copy of ``df`` with Open/High/Low/Close scaled by Adjusted_Close / Close.

        The caller's frame is left untouched. The index is reset to 0..n-1 so
        that ``current_step`` is both the row label used here (``.loc``) and
        the row position used by StockTradingGraph (``.values[step]``).
        """
        df = df.copy()
        df.reset_index(drop=True, inplace=True)

        # Computed once, before 'Close' itself is overwritten below.
        adjust_ratio = df['Adjusted_Close'] / df['Close']

        for column in ('Open', 'High', 'Low', 'Close'):
            df[column] = df[column] * adjust_ratio

        return df

    def _next_observation(self):
        """Build the observation for ``self.current_step``.

        Returns:
            Array of shape (5, LOOKBACK_WINDOW_SIZE + 2). Columns 0..LOOKBACK
            hold Open/High/Low/Close (divided by MAX_SHARE_PRICE) and Volume
            (divided by MAX_NUM_SHARES) for steps
            current_step - LOOKBACK_WINDOW_SIZE .. current_step; steps outside
            the data are zero. The last column holds scaled account state.
        """
        frame = np.zeros((5, LOOKBACK_WINDOW_SIZE + 1))

        # Only look backwards so the agent never sees future prices. Each
        # column is pinned to a step: column j <-> step first_step + j.
        first_step = self.current_step - LOOKBACK_WINDOW_SIZE
        start = max(first_step, 0)
        window = self.df.iloc[start:self.current_step + 1]
        offset = start - first_step

        scales = np.array([MAX_SHARE_PRICE] * 4 + [MAX_NUM_SHARES], dtype=float)
        values = window[['Open', 'High', 'Low', 'Close', 'Volume']].values.astype(float) / scales

        # The original np.put(frame, [0, 4], ...) wrote only two scalars into
        # the frame; assign the whole window instead.
        frame[:, offset:offset + len(window)] = values.T

        # Append additional data and scale each value towards 0-1
        obs = np.append(frame, [
            [self.balance / MAX_ACCOUNT_BALANCE],
            [self.max_net_worth / MAX_ACCOUNT_BALANCE],
            [self.shares_held / MAX_NUM_SHARES],
            [self.cost_basis / MAX_SHARE_PRICE],
            [self.total_sales_value / (MAX_NUM_SHARES * MAX_SHARE_PRICE)],
        ], axis=1)

        return obs

    def _take_action(self, action):
        """Execute ``action`` at a random price between today's Open and Close.

        ``action`` is [action_type, amount] (see ``__init__``). ``amount`` is
        clipped to [0, 1] so out-of-range policy output cannot trade a
        negative number of shares. Updates balance, holdings, cost basis,
        net worth and the ``trades`` log.
        """
        current_price = random.uniform(
            self.df.loc[self.current_step, "Open"], self.df.loc[self.current_step, "Close"])

        action_type = action[0]
        amount = min(max(float(action[1]), 0.0), 1.0)

        if action_type < 1:
            # Buy amount % of balance in shares
            total_possible = int(self.balance / current_price)
            shares_bought = int(total_possible * amount)

            # Skipping empty buys is not just cosmetic: with no shares held and
            # none bought, the cost-basis average below would be 0 / 0.
            if shares_bought > 0:
                prev_cost = self.cost_basis * self.shares_held
                additional_cost = shares_bought * current_price

                self.balance -= additional_cost
                self.cost_basis = (
                    prev_cost + additional_cost) / (self.shares_held + shares_bought)
                self.shares_held += shares_bought

                self.trades.append({'step': self.current_step,
                                    'shares': shares_bought, 'total': additional_cost,
                                    'type': "buy"})

        elif action_type < 2:
            # Sell amount % of shares held
            shares_sold = int(self.shares_held * amount)

            if shares_sold > 0:
                proceeds = shares_sold * current_price

                self.balance += proceeds
                self.shares_held -= shares_sold
                self.total_shares_sold += shares_sold
                self.total_sales_value += proceeds

                self.trades.append({'step': self.current_step,
                                    'shares': shares_sold, 'total': proceeds,
                                    'type': "sell"})

        self.net_worth = self.balance + self.shares_held * current_price

        if self.net_worth > self.max_net_worth:
            self.max_net_worth = self.net_worth

        if self.shares_held == 0:
            self.cost_basis = 0

    def step(self, action):
        """Apply ``action``, advance one row and return (obs, reward, done, info)."""
        self._take_action(action)

        self.current_step += 1

        delay_modifier = (self.current_step / MAX_STEPS)

        reward = self.balance * delay_modifier + self.current_step
        # The episode ends on bankruptcy or once every row has been traded.
        done = self.net_worth <= 0 or self.current_step >= len(self.df)

        obs = self._next_observation()

        return obs, reward, done, {}

    def reset(self):
        """Reset the account and step counter; return the first observation."""
        self.balance = INITIAL_ACCOUNT_BALANCE
        self.net_worth = INITIAL_ACCOUNT_BALANCE
        self.max_net_worth = INITIAL_ACCOUNT_BALANCE
        self.shares_held = 0
        self.cost_basis = 0
        self.total_shares_sold = 0
        self.total_sales_value = 0
        self.current_step = 0
        self.trades = []

        return self._next_observation()

    def _render_to_file(self, filename='render.txt'):
        """Append a plain-text summary of the current account state to ``filename``."""
        profit = self.net_worth - INITIAL_ACCOUNT_BALANCE

        with open(filename, 'a+') as file:
            file.write(f'Step: {self.current_step}\n')
            file.write(f'Balance: {self.balance}\n')
            file.write(
                f'Shares held: {self.shares_held} (Total sold: {self.total_shares_sold})\n')
            file.write(
                f'Avg cost for held shares: {self.cost_basis} (Total sales value: {self.total_sales_value})\n')
            file.write(
                f'Net worth: {self.net_worth} (Max net worth: {self.max_net_worth})\n')
            file.write(f'Profit: {profit}\n\n')

    def render(self, mode='live', **kwargs):
        """Render the environment.

        Args:
            mode: 'file' appends a text summary to ``kwargs['filename']``
                (default 'render.txt'); 'live' draws a StockTradingGraph
                titled ``kwargs['title']``; any other mode does nothing.
        """
        if mode == 'file':
            self._render_to_file(kwargs.get('filename', 'render.txt'))

        elif mode == 'live':
            if self.visualization is None:
                self.visualization = StockTradingGraph(
                    self.df, kwargs.get('title', None))

            # Draw only once a full look-back window of history exists.
            if self.current_step > LOOKBACK_WINDOW_SIZE:
                self.visualization.render(
                    self.current_step, self.net_worth, self.trades, window_size=LOOKBACK_WINDOW_SIZE)

    def close(self):
        """Close the live visualization, if one was opened."""
        if self.visualization is not None:
            self.visualization.close()
            self.visualization = None
# -----------------------------------------------------------------------------
# /env/__init__.py:
# -----------------------------------------------------------------------------
# https://raw.githubusercontent.com/notadamking/Stock-Trading-Visualization/39ed1d4dc4ce734853f76a3256ed6de5ee963192/env/__init__.py
# -----------------------------------------------------------------------------
# /main.py:
# -----------------------------------------------------------------------------
import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

from env.StockTradingEnv import StockTradingEnv

import pandas as pd


def main():
    """Train a PPO2 agent briefly on MSFT prices, then replay it with live rendering."""
    # Sort chronologically and renumber rows 0..n-1: the environment looks rows
    # up by step number, so the index must follow the sorted order.
    df = pd.read_csv('./data/MSFT.csv')
    df = df.sort_values('Date').reset_index(drop=True)

    # The algorithms require a vectorized environment to run
    env = DummyVecEnv([lambda: StockTradingEnv(df)])

    model = PPO2(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=50)

    obs = env.reset()
    for _ in range(len(df['Date'])):
        action, _states = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        env.render(title="MSFT")


if __name__ == "__main__":
    main()
# -----------------------------------------------------------------------------
# /render/StockTradingGraph.py:
# -----------------------------------------------------------------------------
from datetime import datetime

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib import style

# finance module is no longer part of matplotlib
# see: https://github.com/matplotlib/mpl_finance
from mpl_finance import candlestick_ochl as candlestick

style.use('dark_background')

VOLUME_CHART_HEIGHT = 0.33

UP_COLOR = '#27A59A'
DOWN_COLOR = '#EF534F'
UP_TEXT_COLOR = '#73D3CC'
DOWN_TEXT_COLOR = '#DC2C27'


def date2num(date):
    """Convert a 'YYYY-MM-DD' string to a Matplotlib date number.

    Replaces ``mdates.strpdate2num``, which was deprecated and later removed
    from Matplotlib, and avoids building a new converter on every call.
    """
    return mdates.date2num(datetime.strptime(date, '%Y-%m-%d'))


class StockTradingGraph:
    """A stock trading visualization using matplotlib made to render OpenAI gym environments"""

    def __init__(self, df, title=None):
        """Open a non-blocking figure with net-worth, price and volume axes.

        Args:
            df: DataFrame with 'Date' ('YYYY-MM-DD' strings), 'Open', 'High',
                'Low', 'Close' and 'Volume' columns, indexed positionally by step.
            title: Optional figure title.
        """
        self.df = df
        self.net_worths = np.zeros(len(df['Date']))

        # Keep a handle on the figure so close() shuts this one, not
        # whichever figure happens to be current.
        self.fig = plt.figure()
        self.fig.suptitle(title)

        # Top 2 rows of a 6-row grid: net worth
        self.net_worth_ax = plt.subplot2grid(
            (6, 1), (0, 0), rowspan=2, colspan=1)

        # Remaining 4 rows: shared price/volume axis. (The previous rowspan=8
        # ran past the 6-row grid and was effectively clamped to these 4.)
        self.price_ax = plt.subplot2grid(
            (6, 1), (2, 0), rowspan=4, colspan=1, sharex=self.net_worth_ax)

        # Create a new axis for volume which shares its x-axis with price
        self.volume_ax = self.price_ax.twinx()

        # Add padding to make graph easier to view
        plt.subplots_adjust(left=0.11, bottom=0.24,
                            right=0.90, top=0.90, wspace=0.2, hspace=0)

        # Show the graph without blocking the rest of the program
        plt.show(block=False)

    def _render_net_worth(self, current_step, net_worth, step_range, dates):
        """Redraw the net-worth line for ``step_range`` and annotate ``net_worth``."""
        self.net_worth_ax.clear()

        # Equivalent to the deprecated plot_date(): mark x as dates, then plot.
        self.net_worth_ax.xaxis_date()
        self.net_worth_ax.plot(
            dates, self.net_worths[step_range], '-', label='Net Worth')

        # Show legend, which uses the label we defined for the plot above
        legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={'size': 8})
        legend.get_frame().set_alpha(0.4)

        last_date = date2num(self.df['Date'].values[current_step])
        last_net_worth = self.net_worths[current_step]

        # Annotate the current net worth on the net worth graph
        self.net_worth_ax.annotate('{0:.2f}'.format(net_worth), (last_date, last_net_worth),
                                   xytext=(last_date, last_net_worth),
                                   bbox=dict(boxstyle='round',
                                             fc='w', ec='k', lw=1),
                                   color="black",
                                   fontsize="small")

        # Add space above and below min/max net worth. Zeros mark steps that
        # were never rendered; if nothing is nonzero (e.g. net worth hit 0),
        # leave the autoscaled limits instead of calling min() on nothing.
        recorded = self.net_worths[np.nonzero(self.net_worths)]
        if recorded.size:
            self.net_worth_ax.set_ylim(
                recorded.min() / 1.25, self.net_worths.max() * 1.25)

    def _render_price(self, current_step, net_worth, dates, step_range):
        """Redraw OCHL candlesticks for ``step_range`` and annotate the latest close."""
        self.price_ax.clear()

        # Format data for OCHL candlestick graph: (time, open, close, high, low)
        candlesticks = zip(dates,
                           self.df['Open'].values[step_range], self.df['Close'].values[step_range],
                           self.df['High'].values[step_range], self.df['Low'].values[step_range])

        # Plot price using candlestick graph from mpl_finance
        candlestick(self.price_ax, candlesticks, width=1,
                    colorup=UP_COLOR, colordown=DOWN_COLOR)

        last_date = date2num(self.df['Date'].values[current_step])
        last_close = self.df['Close'].values[current_step]
        last_high = self.df['High'].values[current_step]

        # Print the current price to the price axis
        self.price_ax.annotate('{0:.2f}'.format(last_close), (last_date, last_close),
                               xytext=(last_date, last_high),
                               bbox=dict(boxstyle='round',
                                         fc='w', ec='k', lw=1),
                               color="black",
                               fontsize="small")

        # Shift price axis up to give volume chart space
        ylim = self.price_ax.get_ylim()
        self.price_ax.set_ylim(ylim[0] - (ylim[1] - ylim[0])
                               * VOLUME_CHART_HEIGHT, ylim[1])

    def _render_volume(self, current_step, net_worth, dates, step_range):
        """Redraw volume bars for ``step_range``, coloured by daily price direction."""
        self.volume_ax.clear()

        volume = np.array(self.df['Volume'].values[step_range])
        opens = self.df['Open'].values[step_range]
        closes = self.df['Close'].values[step_range]

        # Match the candlestick colouring (close >= open is "up") so days with
        # Open == Close still get a volume bar.
        pos = closes >= opens
        neg = ~pos

        # Color volume bars based on price direction on that date
        self.volume_ax.bar(dates[pos], volume[pos], color=UP_COLOR,
                           alpha=0.4, width=1, align='center')
        self.volume_ax.bar(dates[neg], volume[neg], color=DOWN_COLOR,
                           alpha=0.4, width=1, align='center')

        # Cap volume axis height below price chart and hide ticks
        self.volume_ax.set_ylim(0, max(volume) / VOLUME_CHART_HEIGHT)
        self.volume_ax.yaxis.set_ticks([])

    def _render_trades(self, current_step, trades, step_range):
        """Mark each trade inside ``step_range``: buys below the low, sells above the high."""
        for trade in trades:
            if trade['step'] in step_range:
                date = date2num(self.df['Date'].values[trade['step']])
                high = self.df['High'].values[trade['step']]
                low = self.df['Low'].values[trade['step']]

                if trade['type'] == 'buy':
                    high_low = low
                    color = UP_TEXT_COLOR
                else:
                    high_low = high
                    color = DOWN_TEXT_COLOR

                total = '{0:.2f}'.format(trade['total'])

                # Annotate the trade's total value at its date on the price axis
                self.price_ax.annotate(f'${total}', (date, high_low),
                                       xytext=(date, high_low),
                                       color=color,
                                       fontsize=8,
                                       arrowprops=(dict(color=color)))

    def render(self, current_step, net_worth, trades, window_size=40):
        """Record ``net_worth`` for ``current_step`` and redraw the last ``window_size`` steps."""
        self.net_worths[current_step] = net_worth

        window_start = max(current_step - window_size, 0)
        step_range = range(window_start, current_step + 1)

        # Format dates as timestamps, necessary for candlestick graph
        dates = np.array([date2num(x)
                          for x in self.df['Date'].values[step_range]])

        self._render_net_worth(current_step, net_worth, step_range, dates)
        self._render_price(current_step, net_worth, dates, step_range)
        self._render_volume(current_step, net_worth, dates, step_range)
        self._render_trades(current_step, trades, step_range)

        # Label the auto-located ticks with their own dates. The previous
        # set_xticklabels(list) paired fixed labels with whatever ticks the
        # locator chose, so the displayed dates did not match their positions.
        self.price_ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        plt.setp(self.price_ax.get_xticklabels(), rotation=45,
                 horizontalalignment='right')

        # Hide duplicate net worth date labels
        plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

        # Necessary to view frames before they are unrendered
        plt.pause(0.001)

    def close(self):
        """Close this visualization's figure."""
        plt.close(self.fig)
# -----------------------------------------------------------------------------
# /render/__init__.py:
https://raw.githubusercontent.com/notadamking/Stock-Trading-Visualization/39ed1d4dc4ce734853f76a3256ed6de5ee963192/render/__init__.py --------------------------------------------------------------------------------