├── .gitignore
├── setup.py
├── LICENSE
├── trpo
│   ├── plotting.py
│   ├── value.py
│   ├── utils.py
│   ├── policy.py
│   ├── archive.py
│   └── train.py
├── notebooks
│   ├── plotting.py
│   └── env_dimension_sizes.ipynb
└── README.md
/.gitignore: -------------------------------------------------------------------------------- 1 | # logs and checkpoints 2 | .idea/ 3 | *.log 4 | notebooks/__pycache__/ 5 | notebooks/.ipynb_checkpoints/ 6 | src/__pycache__/ 7 | src/log-files/ 8 | src/.ipynb_checkpoints/ 9 | tmp/ 10 | doc/ 11 | .mlt/ 12 | .mlt 13 | 14 | 15 | # generated files 16 | exportToHTML/ 17 | 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """A setuptools based setup module. 2 | 3 | See: 4 | https://packaging.python.org/guides/distributing-packages-using-setuptools/ 5 | https://github.com/pypa/sampleproject 6 | """ 7 | 8 | from setuptools import setup, find_packages 9 | from os import path 10 | 11 | here = path.abspath(path.dirname(__file__)) 12 | 13 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 14 | long_description = f.read() 15 | 16 | setup( 17 | name='trpo', 18 | version='1.0.0', 19 | description='Trust Region Policy Optimization with Generalized Advantage Estimation.', 20 | packages=find_packages(exclude=['contrib', 'docs', 'tests']), 21 | python_requires='>=3.6', install_requires=['tensorflow', 'numpy', 'pybullet', 'gym', 'scipy'] 22 | ) 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 pat-coady 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /trpo/plotting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Short Plotting Routine to Plot Pandas Dataframes by Column Label 3 | 4 | 1. Takes list of dataframes to compare multiple trials 5 | 2. Takes list of y-variables to combine on 1 plot 6 | 3. Legend location and y-axis limits can be customized 7 | 8 | Written by Patrick Coady (pat-coady.github.io) 9 | """ 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def df_plot(dfs, x, ys, ylim=None, legend_loc='best'): 14 | """ Plot y vs.
x curves from pandas dataframe(s) 15 | 16 | Args: 17 | dfs: list of pandas dataframes 18 | x: str column label for x variable 19 | ys: list of str column labels for y variable(s) 20 | ylim: tuple to override automatic y-axis limits 21 | legend_loc: str to override automatic legend placement: 22 | 'upper left', 'lower left', 'lower right', 'right', 23 | 'center left', 'center right', 'lower center', 24 | 'upper center', and 'center' 25 | """ 26 | if ylim is not None: 27 | plt.ylim(ylim) 28 | for df, name in dfs: 29 | if '_' in name: 30 | name = name.split('_')[1] 31 | for y in ys: 32 | plt.plot(df[x], df[y], linewidth=1, 33 | label=name + ' ' + y.replace('_', '')) 34 | plt.xlabel(x.replace('_', '')) 35 | plt.legend(loc=legend_loc) 36 | plt.show() 37 | -------------------------------------------------------------------------------- /notebooks/plotting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Short Plotting Routine to Plot Pandas Dataframes by Column Label 3 | 4 | 1. Takes list of dataframes to compare multiple trials 5 | 2. Takes list of y-variables to combine on 1 plot 6 | 3. Legend location and y-axis limits can be customized 7 | 8 | Written by Patrick Coady (pat-coady.github.io) 9 | """ 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def df_plot(dfs, x, ys, ylim=None, xlim=None, legend_loc='best'): 14 | """ Plot y vs. x curves from pandas dataframe(s) 15 | 16 | Args: 17 | dfs: list of pandas dataframes 18 | x: str column label for x variable 19 | ys: list of str column labels for y variable(s) 20 | ylim: tuple to override automatic y-axis limits 21 | xlim: tuple to override automatic x-axis limits 22 | legend_loc: str to override automatic legend placement: 23 | 'upper left', 'lower left', 'lower right', 'right', 24 | 'center left', 'center right', 'lower center', 25 | 'upper center', and 'center' 26 | """ 27 | if ylim is not None: 28 | plt.ylim(ylim) 29 | if xlim is not None: 30 | plt.xlim(xlim) 31 | for df, name in dfs: 32 | if '_' in name: 33 | name = name.split('_')[1] 34 | for y in ys: 35 | plt.plot(df[x], df[y], linewidth=1, 36 | label=name + ' ' + y.replace('_', '')) 37 | plt.xlabel(x.replace('_', '')) 38 | plt.legend(loc=legend_loc) 39 | plt.show() 40 | -------------------------------------------------------------------------------- /trpo/value.py: -------------------------------------------------------------------------------- 1 | """ 2 | State-Value Function 3 | 4 | Written by Patrick Coady (pat-coady.github.io) 5 | """ 6 | from tensorflow.keras import Model 7 | from tensorflow.keras.layers import Input, Dense 8 | from tensorflow.keras.optimizers import Adam 9 | 10 | import numpy as np 11 | 12 | 13 | class NNValueFunction(object): 14 | """ NN-based state-value function """ 15 | def __init__(self, obs_dim, hid1_mult): 16 | """ 17 | Args: 18 | obs_dim: number of dimensions in observation vector (int) 19 | hid1_mult: size of first hidden layer, multiplier of obs_dim 20 | """ 21 | self.replay_buffer_x = None 22 | self.replay_buffer_y = None 23 | self.obs_dim = obs_dim 24 | self.hid1_mult = hid1_mult 25 | self.epochs = 10 26 | self.lr = None # learning rate set in _build_model() 27 | self.model = self._build_model() 28 | 29 | def _build_model(self): 30 | """ Construct Keras model, including loss function and optimizer """ 31 | obs = Input(shape=(self.obs_dim,), dtype='float32') 32 | # hid1 layer size is 10x obs_dim, hid3 size is 5, and hid2 is geometric mean 33 | hid1_units = self.obs_dim * self.hid1_mult
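# Illustrative worked example of this sizing heuristic (assumes the default
# hid1_mult of 10 and Hopper's obs_dim of 12, i.e. 11 observations plus the
# time-step feature added in train.py): hid1_units = 120, hid3_units = 5,
# hid2_units = int(sqrt(120 * 5)) = 24, and lr = 1e-2 / sqrt(24) ~= 2.0e-3.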
34 | hid3_units = 5 # 5 chosen empirically on 'Hopper-v1' 35 | hid2_units = int(np.sqrt(hid1_units * hid3_units)) 36 | # heuristic to set learning rate based on NN size (tuned on 'Hopper-v1') 37 | self.lr = 1e-2 / np.sqrt(hid2_units) # 1e-2 empirically determined 38 | print('Value Params -- h1: {}, h2: {}, h3: {}, lr: {:.3g}' 39 | .format(hid1_units, hid2_units, hid3_units, self.lr)) 40 | y = Dense(hid1_units, activation='tanh')(obs) 41 | y = Dense(hid2_units, activation='tanh')(y) 42 | y = Dense(hid3_units, activation='tanh')(y) 43 | y = Dense(1)(y) 44 | model = Model(inputs=obs, outputs=y) 45 | optimizer = Adam(self.lr) 46 | model.compile(optimizer=optimizer, loss='mse') 47 | 48 | return model 49 | 50 | def fit(self, x, y, logger): 51 | """ Fit model to current data batch + previous data batch 52 | 53 | Args: 54 | x: features 55 | y: target 56 | logger: logger to save training loss and % explained variance 57 | """ 58 | num_batches = max(x.shape[0] // 256, 1) 59 | batch_size = x.shape[0] // num_batches 60 | y_hat = self.model.predict(x) # check explained variance prior to update 61 | old_exp_var = 1 - np.var(y - y_hat)/np.var(y) 62 | if self.replay_buffer_x is None: 63 | x_train, y_train = x, y 64 | else: 65 | x_train = np.concatenate([x, self.replay_buffer_x]) 66 | y_train = np.concatenate([y, self.replay_buffer_y]) 67 | self.replay_buffer_x = x 68 | self.replay_buffer_y = y 69 | self.model.fit(x_train, y_train, epochs=self.epochs, batch_size=batch_size, 70 | shuffle=True, verbose=0) 71 | y_hat = self.model.predict(x) 72 | loss = np.mean(np.square(y_hat - y)) # explained variance after update 73 | exp_var = 1 - np.var(y - y_hat) / np.var(y) # diagnose over-fitting of val func 74 | 75 | logger.log({'ValFuncLoss': loss, 76 | 'ExplainedVarNew': exp_var, 77 | 'ExplainedVarOld': old_exp_var}) 78 | 79 | def predict(self, x): 80 | """ Predict method """ 81 | return self.model.predict(x) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Trust Region Policy Optimization with Generalized Advantage Estimation 2 | 3 | By Patrick Coady: [Learning Artificial Intelligence](https://learningai.io/) 4 | 5 | ### Summary 6 | 7 | **NOTE:** The code has been refactored to use TensorFlow 2.0 and PyBullet (instead of MuJoCo). See the `tf1_mujoco` branch for old version. 8 | 9 | The project's original goal was to use the same algorithm to "solve" [10 MuJoCo robotic control environments](https://gym.openai.com/envs/#mujoco). And, specifically, to achieve this without hand-tuning the hyperparameters (network sizes, learning rates, and TRPO settings) for each environment. This is challenging because the environments range from a simple cart pole problem with a single control input to a humanoid with 17 controlled joints and 44 observed variables. The project was successful, nabbing top spots on almost all of the AI Gym MuJoCo leaderboards. 10 | 11 | With the release of TensorFlow 2.0, I decided to dust off this project and upgrade the code. And, while I was at it, I moved from the paid MuJoCo simulator to the free PyBullet simulator. 
12 | 13 | Here are the key points: 14 | 15 | * Trust Region Policy Optimization \[1\] \[2\] 16 | * Value function approximated with 3 hidden-layer NN (tanh activations): 17 | * hid1 size = obs_dim x 10 18 | * hid2 size = geometric mean of hid1 and hid3 sizes 19 | * hid3 size = 5 20 | * Policy is a multi-variate Gaussian parameterized by a 3 hidden-layer NN (tanh activations): 21 | * hid1 size = obs_dim x 10 22 | * hid2 size = geometric mean of hid1 and hid3 sizes 23 | * hid3 size = action_dim x 10 24 | * Diagonal covariance matrix variables are separately trained 25 | * Generalized Advantage Estimation (gamma = 0.995, lambda = 0.98) \[3\] \[4\] 26 | * ADAM optimizer used for both neural networks 27 | * The policy is evaluated for 20 episodes between updates, except: 28 | * 50 episodes for Reacher 29 | * 5 episodes for Swimmer 30 | * 5 episodes for HalfCheetah 31 | * 5 episodes for HumanoidStandup 32 | * Value function is trained on current batch + previous batch 33 | * KL loss factor and ADAM learning rate are dynamically adjusted during training 34 | * Policy and Value NNs built with TensorFlow 35 | 36 | ### PyBullet Gym Environments 37 | 38 | ``` 39 | HumanoidDeepMimicBulletEnv-v1 40 | CartPoleBulletEnv-v1 41 | MinitaurBulletEnv-v0 42 | MinitaurBulletDuckEnv-v0 43 | RacecarBulletEnv-v0 44 | RacecarZedBulletEnv-v0 45 | KukaBulletEnv-v0 46 | KukaCamBulletEnv-v0 47 | InvertedPendulumBulletEnv-v0 48 | InvertedDoublePendulumBulletEnv-v0 49 | InvertedPendulumSwingupBulletEnv-v0 50 | ReacherBulletEnv-v0 51 | PusherBulletEnv-v0 52 | ThrowerBulletEnv-v0 53 | StrikerBulletEnv-v0 54 | Walker2DBulletEnv-v0 55 | HalfCheetahBulletEnv-v0 56 | AntBulletEnv-v0 57 | HopperBulletEnv-v0 58 | HumanoidBulletEnv-v0 59 | HumanoidFlagrunBulletEnv-v0 60 | HumanoidFlagrunHarderBulletEnv-v0 61 | ``` 62 | 63 | ### Using 64 | 65 | I ran quick checks on three of the above environments and successfully stabilized a double-inverted pendulum and taught the "half cheetah" to run. 66 | 67 | ``` 68 | python train.py InvertedPendulumBulletEnv-v0 69 | python train.py InvertedDoublePendulumBulletEnv-v0 -n 5000 70 | python train.py HalfCheetahBulletEnv-v0 -n 5000 -b 5 71 | ``` 72 | 73 | ### Videos 74 | 75 | During training, videos are periodically saved automatically to the /tmp folder. These can be enjoyable to view, and also instructive. 76 | 77 | ### Dependencies 78 | 79 | * Python 3.6 80 | * The Usual Suspects: numpy, matplotlib, scipy 81 | * TensorFlow 2.x 82 | * Open AI Gym: [installation instructions](https://gym.openai.com/docs) 83 | * [pybullet](https://pypi.org/project/pybullet/) physics simulator 84 | 85 | ### References 86 | 87 | 1. [Trust Region Policy Optimization](https://arxiv.org/pdf/1502.05477.pdf) (Schulman et al., 2016) 88 | 2. [Emergence of Locomotion Behaviours in Rich Environments](https://arxiv.org/pdf/1707.02286.pdf) (Heess et al., 2017) 89 | 3. [High-Dimensional Continuous Control Using Generalized Advantage Estimation](https://arxiv.org/pdf/1506.02438.pdf) (Schulman et al., 2016) 90 | 4. 
[GitHub Repository with several helpful implementation ideas](https://github.com/joschu/modular_rl) (Schulman) 91 | -------------------------------------------------------------------------------- /notebooks/env_dimension_sizes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import gym" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": { 19 | "collapsed": true 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "def init_gym(env_name):\n", 24 | " \"\"\"\n", 25 | "\n", 26 | " :param env_name: str, OpenAI Gym environment name\n", 27 | " :return: 3-tuple\n", 28 | " env: ai gym environment\n", 29 | " obs_dim: observation dimensions\n", 30 | " act_dim: action dimensions\n", 31 | " \"\"\"\n", 32 | " env = gym.make(env_name)\n", 33 | " obs_dim = env.observation_space.shape[0]\n", 34 | " act_dim = env.action_space.shape[0]\n", 35 | " obs_dim += 1\n", 36 | "\n", 37 | " return env, obs_dim, act_dim" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "env_names = ['InvertedPendulum-v1', 'InvertedDoublePendulum-v1', 'Reacher-v1',\n", 49 | " 'HalfCheetah-v1', 'Swimmer-v1', 'Hopper-v1', 'Walker2d-v1', \n", 50 | " 'Ant-v1', 'Humanoid-v1', 'HumanoidStandup-v1']" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stderr", 60 | "output_type": "stream", 61 | "text": [ 62 | "[2017-07-22 06:35:06,733] Making new env: InvertedPendulum-v1\n", 63 | "[2017-07-22 06:35:06,974] Making new env: InvertedDoublePendulum-v1\n", 64 | "[2017-07-22 06:35:06,979] Making new env: Reacher-v1\n", 65 | "[2017-07-22 06:35:06,988] Making new env: HalfCheetah-v1\n", 66 | "[2017-07-22 06:35:06,997] Making new env: Swimmer-v1\n", 67 | "[2017-07-22 06:35:07,003] Making new env: Hopper-v1\n", 68 | "[2017-07-22 06:35:07,012] Making new env: Walker2d-v1\n", 69 | "[2017-07-22 06:35:07,019] Making new env: Ant-v1\n", 70 | "[2017-07-22 06:35:07,030] Making new env: Humanoid-v1\n", 71 | "[2017-07-22 06:35:07,037] Making new env: HumanoidStandup-v1\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "dims = []\n", 77 | "for env_name in env_names:\n", 78 | " env, obs_dim, act_dim = init_gym(env_name)\n", 79 | " dims.append((obs_dim, act_dim))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "InvertedPendulum-v1: obs_dim: 5, act_dim: 1\n", 92 | "InvertedDoublePendulum-v1: obs_dim: 12, act_dim: 1\n", 93 | "Reacher-v1: obs_dim: 12, act_dim: 2\n", 94 | "HalfCheetah-v1: obs_dim: 18, act_dim: 6\n", 95 | "Swimmer-v1: obs_dim: 9, act_dim: 2\n", 96 | "Hopper-v1: obs_dim: 12, act_dim: 3\n", 97 | "Walker2d-v1: obs_dim: 18, act_dim: 6\n", 98 | "Ant-v1: obs_dim: 112, act_dim: 8\n", 99 | "Humanoid-v1: obs_dim: 377, act_dim: 17\n", 100 | "HumanoidStandup-v1: obs_dim: 377, act_dim: 17\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "names_dims = list(zip(env_names, dims))\n", 106 | "for name_dim in names_dims:\n", 107 | " print('{}: obs_dim: {}, act_dim: {}'.format(name_dim[0], name_dim[1][0], name_dim[1][1]))" 108 | ] 109 | }, 110 | { 111 | 
"cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "collapsed": true 115 | }, 116 | "outputs": [], 117 | "source": [] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3", 123 | "language": "python", 124 | "name": "python3" 125 | }, 126 | "language_info": { 127 | "codemirror_mode": { 128 | "name": "ipython", 129 | "version": 3 130 | }, 131 | "file_extension": ".py", 132 | "mimetype": "text/x-python", 133 | "name": "python", 134 | "nbconvert_exporter": "python", 135 | "pygments_lexer": "ipython3", 136 | "version": "3.5.2" 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 2 141 | } 142 | -------------------------------------------------------------------------------- /trpo/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging and Data Scaling Utilities 3 | 4 | Written by Patrick Coady (pat-coady.github.io) 5 | """ 6 | import numpy as np 7 | import os 8 | import shutil 9 | import glob 10 | import csv 11 | 12 | 13 | class Scaler(object): 14 | """ Generate scale and offset based on running mean and stddev along axis=0 15 | 16 | offset = running mean 17 | scale = 1 / (stddev + 0.1) / 3 (i.e. 3x stddev = +/- 1.0) 18 | """ 19 | 20 | def __init__(self, obs_dim): 21 | """ 22 | Args: 23 | obs_dim: dimension of axis=1 24 | """ 25 | self.vars = np.zeros(obs_dim) 26 | self.means = np.zeros(obs_dim) 27 | self.m = 0 28 | self.n = 0 29 | self.first_pass = True 30 | 31 | def update(self, x): 32 | """ Update running mean and variance (this is an exact method) 33 | Args: 34 | x: NumPy array, shape = (N, obs_dim) 35 | 36 | see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled- 37 | variance-of-two-groups-given-known-group-variances-mean 38 | """ 39 | if self.first_pass: 40 | self.means = np.mean(x, axis=0) 41 | self.vars = np.var(x, axis=0) 42 | self.m = x.shape[0] 43 | self.first_pass = False 44 | else: 45 | n = x.shape[0] 46 | new_data_var = np.var(x, axis=0) 47 | new_data_mean = np.mean(x, axis=0) 48 | new_data_mean_sq = np.square(new_data_mean) 49 | new_means = ((self.means * self.m) + (new_data_mean * n)) / (self.m + n) 50 | self.vars = (((self.m * (self.vars + np.square(self.means))) + 51 | (n * (new_data_var + new_data_mean_sq))) / (self.m + n) - 52 | np.square(new_means)) 53 | self.vars = np.maximum(0.0, self.vars) # occasionally goes negative, clip 54 | self.means = new_means 55 | self.m += n 56 | 57 | def get(self): 58 | """ returns 2-tuple: (scale, offset) """ 59 | return 1/(np.sqrt(self.vars) + 0.1)/3, self.means 60 | 61 | 62 | class Logger(object): 63 | """ Simple training logger: saves to file and optionally prints to stdout """ 64 | def __init__(self, logname, now): 65 | """ 66 | Args: 67 | logname: name for log (e.g. 'Hopper-v1') 68 | now: unique sub-directory name (e.g. 
date/time string) 69 | """ 70 | path = os.path.join('log-files', logname, now) 71 | os.makedirs(path) 72 | filenames = glob.glob('*.py') # put copy of all python files in log_dir 73 | for filename in filenames: # for reference 74 | shutil.copy(filename, path) 75 | path = os.path.join(path, 'log.csv') 76 | 77 | self.write_header = True 78 | self.log_entry = {} 79 | self.f = open(path, 'w') 80 | self.writer = None # DictWriter created with first call to write() method 81 | 82 | def write(self, display=True): 83 | """ Write 1 log entry to file, and optionally to stdout 84 | Log fields preceded by '_' will not be printed to stdout 85 | 86 | Args: 87 | display: boolean, print to stdout 88 | """ 89 | if display: 90 | self.disp(self.log_entry) 91 | if self.write_header: 92 | fieldnames = [x for x in self.log_entry.keys()] 93 | self.writer = csv.DictWriter(self.f, fieldnames=fieldnames) 94 | self.writer.writeheader() 95 | self.write_header = False 96 | self.writer.writerow(self.log_entry) 97 | self.log_entry = {} 98 | 99 | @staticmethod 100 | def disp(log): 101 | """Print metrics to stdout""" 102 | log_keys = [k for k in log.keys()] 103 | log_keys.sort() 104 | print('***** Episode {}, Mean R = {:.1f} *****'.format(log['_Episode'], 105 | log['_MeanReward'])) 106 | for key in log_keys: 107 | if key[0] != '_': # don't display log items with leading '_' 108 | print('{:s}: {:.3g}'.format(key, log[key])) 109 | print('\n') 110 | 111 | def log(self, items): 112 | """ Update fields in log (does not write to file, used to collect updates. 113 | 114 | Args: 115 | items: dictionary of items to update 116 | """ 117 | self.log_entry.update(items) 118 | 119 | def close(self): 120 | """ Close log file - log cannot be written after this """ 121 | self.f.close() 122 | -------------------------------------------------------------------------------- /trpo/policy.py: -------------------------------------------------------------------------------- 1 | """ 2 | NN Policy with KL Divergence Constraint 3 | 4 | Written by Patrick Coady (pat-coady.github.io) 5 | """ 6 | import tensorflow.keras.backend as K 7 | from tensorflow.keras import Model 8 | from tensorflow.keras.layers import Dense, Layer 9 | from tensorflow.keras.optimizers import Adam 10 | import numpy as np 11 | 12 | 13 | class Policy(object): 14 | def __init__(self, obs_dim, act_dim, kl_targ, hid1_mult, init_logvar): 15 | """ 16 | Args: 17 | obs_dim: num observation dimensions (int) 18 | act_dim: num action dimensions (int) 19 | kl_targ: target KL divergence between pi_old and pi_new 20 | hid1_mult: size of first hidden layer, multiplier of obs_dim 21 | init_logvar: natural log of initial policy variance 22 | """ 23 | self.beta = 1.0 # dynamically adjusted D_KL loss multiplier 24 | eta = 50 # multiplier for D_KL-kl_targ hinge-squared loss 25 | self.kl_targ = kl_targ 26 | self.epochs = 20 27 | self.lr_multiplier = 1.0 # dynamically adjust lr when D_KL out of control 28 | self.trpo = TRPO(obs_dim, act_dim, hid1_mult, kl_targ, init_logvar, eta) 29 | self.policy = self.trpo.get_layer('policy_nn') 30 | self.lr = self.policy.get_lr() # lr calculated based on size of PolicyNN 31 | self.trpo.compile(optimizer=Adam(self.lr * self.lr_multiplier)) 32 | self.logprob_calc = LogProb() 33 | 34 | def sample(self, obs): 35 | """Draw sample from policy.""" 36 | act_means, act_logvars = self.policy(obs) 37 | act_stddevs = np.exp(act_logvars / 2) 38 | 39 | return np.random.normal(act_means, act_stddevs).astype(np.float32) 40 | 41 | def update(self, observes, actions, advantages, 
logger): 42 | """ Update policy based on observations, actions and advantages 43 | 44 | Args: 45 | observes: observations, shape = (N, obs_dim) 46 | actions: actions, shape = (N, act_dim) 47 | advantages: advantages, shape = (N,) 48 | logger: Logger object, see utils.py 49 | """ 50 | K.set_value(self.trpo.optimizer.lr, self.lr * self.lr_multiplier) 51 | K.set_value(self.trpo.beta, self.beta) 52 | old_means, old_logvars = self.policy(observes) 53 | old_means = old_means.numpy() 54 | old_logvars = old_logvars.numpy() 55 | old_logp = self.logprob_calc([actions, old_means, old_logvars]) 56 | old_logp = old_logp.numpy() 57 | loss, kl, entropy = 0, 0, 0 58 | for e in range(self.epochs): 59 | loss = self.trpo.train_on_batch([observes, actions, advantages, 60 | old_means, old_logvars, old_logp]) 61 | kl, entropy = self.trpo.predict_on_batch([observes, actions, advantages, 62 | old_means, old_logvars, old_logp]) 63 | kl, entropy = np.mean(kl), np.mean(entropy) 64 | if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly 65 | break 66 | # TODO: too many "magic numbers" in next 8 lines of code, need to clean up 67 | if kl > self.kl_targ * 2: # servo beta to reach D_KL target 68 | self.beta = np.minimum(35, 1.5 * self.beta) # max clip beta 69 | if self.beta > 30 and self.lr_multiplier > 0.1: 70 | self.lr_multiplier /= 1.5 71 | elif kl < self.kl_targ / 2: 72 | self.beta = np.maximum(1 / 35, self.beta / 1.5) # min clip beta 73 | if self.beta < (1 / 30) and self.lr_multiplier < 10: 74 | self.lr_multiplier *= 1.5 75 | 76 | logger.log({'PolicyLoss': loss, 77 | 'PolicyEntropy': entropy, 78 | 'KL': kl, 79 | 'Beta': self.beta, 80 | '_lr_multiplier': self.lr_multiplier}) 81 | 82 | 83 | class PolicyNN(Layer): 84 | """ Neural net for policy approximation function. 85 | 86 | Policy parameterized by Gaussian means and variances. NN outputs mean 87 | action based on observation. Trainable variables hold log-variances 88 | for each action dimension (i.e. variances not determined by NN). 89 | """ 90 | def __init__(self, obs_dim, act_dim, hid1_mult, init_logvar, **kwargs): 91 | super(PolicyNN, self).__init__(**kwargs) 92 | self.batch_sz = None 93 | self.init_logvar = init_logvar 94 | hid1_units = obs_dim * hid1_mult 95 | hid3_units = act_dim * 10 # 10 empirically determined 96 | hid2_units = int(np.sqrt(hid1_units * hid3_units)) 97 | self.lr = 9e-4 / np.sqrt(hid2_units) # 9e-4 empirically determined 98 | # heuristic to set learning rate based on NN size (tuned on 'Hopper-v1') 99 | self.dense1 = Dense(hid1_units, activation='tanh', input_shape=(obs_dim,)) 100 | self.dense2 = Dense(hid2_units, activation='tanh', input_shape=(hid1_units,)) 101 | self.dense3 = Dense(hid3_units, activation='tanh', input_shape=(hid2_units,)) 102 | self.dense4 = Dense(act_dim, input_shape=(hid3_units,)) 103 | # logvar_speed increases learning rate for log-variances. 104 | # heuristic sets logvar_speed based on network size. 
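# Illustrative note on the heuristic below: the log-variances are stored as
# logvar_speed identical rows that call() sums along axis 0, so every row
# receives the same gradient and the effective log-variance learns roughly
# logvar_speed times faster than a single trainable row would.
# E.g. for act_dim = 3 (Hopper): hid3_units = 30, logvar_speed = (10 * 30) // 48 = 6.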
105 | logvar_speed = (10 * hid3_units) // 48 106 | self.logvars = self.add_weight(shape=(logvar_speed, act_dim), 107 | trainable=True, initializer='zeros') 108 | print('Policy Params -- h1: {}, h2: {}, h3: {}, lr: {:.3g}, logvar_speed: {}' 109 | .format(hid1_units, hid2_units, hid3_units, self.lr, logvar_speed)) 110 | 111 | def build(self, input_shape): 112 | self.batch_sz = input_shape[0] 113 | 114 | def call(self, inputs, **kwargs): 115 | y = self.dense1(inputs) 116 | y = self.dense2(y) 117 | y = self.dense3(y) 118 | means = self.dense4(y) 119 | logvars = K.sum(self.logvars, axis=0, keepdims=True) + self.init_logvar 120 | logvars = K.tile(logvars, (self.batch_sz, 1)) 121 | 122 | return [means, logvars] 123 | 124 | def get_lr(self): 125 | return self.lr 126 | 127 | 128 | class KLEntropy(Layer): 129 | """ 130 | Layer calculates: 131 | 1. KL divergence between old and new distributions 132 | 2. Entropy of present policy 133 | 134 | https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback.E2.80.93Leibler_divergence 135 | https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Entropy 136 | """ 137 | def __init__(self, **kwargs): 138 | super(KLEntropy, self).__init__(**kwargs) 139 | self.act_dim = None 140 | 141 | def build(self, input_shape): 142 | self.act_dim = input_shape[0][1] 143 | 144 | def call(self, inputs, **kwargs): 145 | old_means, old_logvars, new_means, new_logvars = inputs 146 | log_det_cov_old = K.sum(old_logvars, axis=-1, keepdims=True) 147 | log_det_cov_new = K.sum(new_logvars, axis=-1, keepdims=True) 148 | trace_old_new = K.sum(K.exp(old_logvars - new_logvars), axis=-1, keepdims=True) 149 | kl = 0.5 * (log_det_cov_new - log_det_cov_old + trace_old_new + 150 | K.sum(K.square(new_means - old_means) / 151 | K.exp(new_logvars), axis=-1, keepdims=True) - 152 | np.float32(self.act_dim)) 153 | entropy = 0.5 * (np.float32(self.act_dim) * (np.log(2 * np.pi) + 1.0) + 154 | K.sum(new_logvars, axis=-1, keepdims=True)) 155 | 156 | return [kl, entropy] 157 | 158 | 159 | class LogProb(Layer): 160 | """Layer calculates log probabilities of a batch of actions.""" 161 | def __init__(self, **kwargs): 162 | super(LogProb, self).__init__(**kwargs) 163 | 164 | def call(self, inputs, **kwargs): 165 | actions, act_means, act_logvars = inputs 166 | logp = -0.5 * K.sum(act_logvars, axis=-1, keepdims=True) 167 | logp += -0.5 * K.sum(K.square(actions - act_means) / K.exp(act_logvars), 168 | axis=-1, keepdims=True) 169 | 170 | return logp 171 | 172 | 173 | class TRPO(Model): 174 | def __init__(self, obs_dim, act_dim, hid1_mult, kl_targ, init_logvar, eta, **kwargs): 175 | super(TRPO, self).__init__(**kwargs) 176 | self.kl_targ = kl_targ 177 | self.eta = eta 178 | self.beta = self.add_weight('beta', initializer='zeros', trainable=False) 179 | self.policy = PolicyNN(obs_dim, act_dim, hid1_mult, init_logvar) 180 | self.logprob = LogProb() 181 | self.kl_entropy = KLEntropy() 182 | 183 | def call(self, inputs): 184 | obs, act, adv, old_means, old_logvars, old_logp = inputs 185 | new_means, new_logvars = self.policy(obs) 186 | new_logp = self.logprob([act, new_means, new_logvars]) 187 | kl, entropy = self.kl_entropy([old_means, old_logvars, 188 | new_means, new_logvars]) 189 | loss1 = -K.mean(adv * K.exp(new_logp - old_logp)) 190 | loss2 = K.mean(self.beta * kl) 191 | # TODO - Take mean before or after hinge loss? 
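# Clarifying summary of the three loss terms (loss1/loss2 above, loss3 below):
# loss1 is the importance-sampled policy-gradient surrogate
# -E[advantage * exp(logp_new - logp_old)]; loss2 is the beta-weighted KL
# penalty, with beta servoed toward kl_targ in Policy.update(); loss3 is an
# eta-weighted (eta = 50) squared hinge that activates only when the mean KL
# exceeds 2 * kl_targ, penalizing large policy jumps.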
192 | loss3 = self.eta * K.square(K.maximum(0.0, K.mean(kl) - 2.0 * self.kl_targ)) 193 | self.add_loss(loss1 + loss2 + loss3) 194 | 195 | return [kl, entropy] 196 | -------------------------------------------------------------------------------- /trpo/archive.py: -------------------------------------------------------------------------------- 1 | """ 2 | Archive of Procedures Not Used in Final Implementation 3 | 4 | Written by Patrick Coady (pat-coady.github.io) 5 | """ 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | class ConstantScaler(object): 11 | """ Dumb scaler, scale and offset set at initialization """ 12 | def __init__(self, obs_dim, scale=1.0, offset=0.0): 13 | self.scale = np.ones(obs_dim) * scale 14 | self.offset = np.zeros(obs_dim) + offset 15 | 16 | def update(self, x): 17 | pass # no updates for constant scaler 18 | 19 | def get(self): 20 | """ returns 2-tuple: (scale, offset) """ 21 | return self.scale, self.offset 22 | 23 | 24 | class LinearValueFunction(object): 25 | """Simple linear regression value function, uses linear and squared features. 26 | 27 | Mostly copied from: https://github.com/joschu/modular_rl 28 | """ 29 | def __init__(self): 30 | self.coef = None 31 | 32 | def fit(self, x, y, logger): 33 | """ Fit model - (i.e. solve normal equations) 34 | 35 | Args: 36 | x: features 37 | y: target 38 | logger: logger to save training loss and % explained variance 39 | """ 40 | y_hat = self.predict(x) 41 | old_exp_var = 1-np.var(y-y_hat)/np.var(y) 42 | xp = self.preproc(x) 43 | a = xp.T.dot(xp) 44 | nfeats = xp.shape[1] 45 | a[np.arange(nfeats), np.arange(nfeats)] += 1e-3 # a little ridge regression 46 | b = xp.T.dot(y) 47 | self.coef = np.linalg.solve(a, b) 48 | y_hat = self.predict(x) 49 | loss = np.mean(np.square(y_hat-y)) 50 | exp_var = 1-np.var(y-y_hat)/np.var(y) 51 | 52 | logger.log({'LinValFuncLoss': loss, 53 | 'LinExplainedVarNew': exp_var, 54 | 'LinExplainedVarOld': old_exp_var}) 55 | 56 | def predict(self, x): 57 | """ Predict method, predict zeros if model untrained """ 58 | if self.coef is None: 59 | return np.zeros(x.shape[0]) 60 | else: 61 | return self.preproc(x).dot(self.coef) 62 | 63 | @staticmethod 64 | def preproc(X): 65 | """ Adds squared features and bias term """ 66 | 67 | return np.concatenate([np.ones([X.shape[0], 1]), X, np.square(X)/2.0], axis=1) 68 | 69 | 70 | def add_advantage(trajectories): 71 | """ Adds estimated advantage to all time steps of all trajectories 72 | 73 | Args: 74 | trajectories: as returned by run_policy(), must include 'values' 75 | key from add_value(). 
76 | 77 | Returns: 78 | None (mutates trajectories dictionary to add 'advantages') 79 | """ 80 | for trajectory in trajectories: 81 | trajectory['advantages'] = trajectory['disc_sum_rew'] - trajectory['values'] 82 | 83 | 84 | class PolicyWithVariance(object): 85 | """ Neural Net output means AND variance (had poor performance) """ 86 | def __init__(self, obs_dim, act_dim, kl_targ=0.003): 87 | self.beta = 1.0 88 | self.kl_targ = kl_targ 89 | self._build_graph(obs_dim, act_dim) 90 | self._init_session() 91 | 92 | def _build_graph(self, obs_dim, act_dim): 93 | """ Build TensorFlow graph""" 94 | self.g = tf.Graph() 95 | with self.g.as_default(): 96 | self._placeholders(obs_dim, act_dim) 97 | self._policy_nn(obs_dim, act_dim) 98 | self._logprob(act_dim) 99 | self._kl_entropy(act_dim) 100 | self._sample(act_dim) 101 | self._loss_train_op() 102 | self.init = tf.global_variables_initializer() 103 | 104 | def _placeholders(self, obs_dim, act_dim): 105 | """ Input placeholders""" 106 | self.obs_ph = tf.placeholder(tf.float32, (None, obs_dim), 'obs') 107 | self.act_ph = tf.placeholder(tf.float32, (None, act_dim), 'act') 108 | self.advantages_ph = tf.placeholder(tf.float32, (None,), 'advantages') 109 | self.beta_ph = tf.placeholder(tf.float32, (), 'beta') 110 | self.eta_ph = tf.placeholder(tf.float32, (), 'eta') 111 | self.old_log_vars_ph = tf.placeholder(tf.float32, (None, act_dim,), 'old_log_vars') 112 | self.old_means_ph = tf.placeholder(tf.float32, (None, act_dim), 'old_means') 113 | 114 | def _policy_nn(self, obs_dim, act_dim): 115 | """ Neural net for policy approximation function """ 116 | out = tf.layers.dense(self.obs_ph, 200, tf.tanh, 117 | kernel_initializer=tf.random_normal_initializer( 118 | stddev=np.sqrt(1 / obs_dim)), 119 | name="h1") 120 | out = tf.layers.dense(out, 100, tf.tanh, 121 | kernel_initializer=tf.random_normal_initializer( 122 | stddev=np.sqrt(1 / 200)), 123 | name="h2") 124 | out = tf.layers.dense(out, 50, tf.tanh, 125 | kernel_initializer=tf.random_normal_initializer( 126 | stddev=np.sqrt(1 / 100)), 127 | name="h3") 128 | self.means = tf.layers.dense(out, act_dim, 129 | kernel_initializer=tf.random_normal_initializer( 130 | stddev=np.sqrt(1 / 50)), 131 | name="means") 132 | self.log_vars = tf.layers.dense(out, act_dim, 133 | kernel_initializer=tf.random_normal_initializer( 134 | stddev=np.sqrt(1 / 50)), 135 | name="log_vars") 136 | 137 | def _logprob(self, act_dim): 138 | """ Log probabilities of batch of states, actions""" 139 | logp = -0.5 * (np.log(np.sqrt(2.0 * np.pi)) * act_dim) 140 | # logp += -0.5 * tf.reduce_sum(self.log_vars, axis=1) 141 | logp += -0.5 * tf.reduce_sum(self.log_vars) 142 | logp += -0.5 * tf.reduce_sum(tf.square(self.act_ph - self.means) / 143 | tf.exp(self.log_vars), axis=1) 144 | self.logp = logp 145 | 146 | logp_old = -0.5 * (np.log(np.sqrt(2.0 * np.pi)) * act_dim) 147 | logp_old += -0.5 * tf.reduce_sum(self.old_log_vars_ph, axis=1) 148 | logp_old += -0.5 * tf.reduce_sum(tf.square(self.act_ph - self.old_means_ph) / 149 | tf.exp(self.old_log_vars_ph), axis=1) 150 | self.logp_old = logp_old 151 | 152 | def _kl_entropy(self, act_dim): 153 | """ 154 | Add KL divergence between old and new distributions 155 | Add entropy of present policy given states and actions 156 | """ 157 | log_det_cov_old = tf.reduce_sum(self.old_log_vars_ph, axis=1) 158 | log_det_cov_new = tf.reduce_sum(self.log_vars, axis=1) 159 | tr_old_new = tf.reduce_sum(tf.exp(self.old_log_vars_ph - self.log_vars), axis=1) 160 | 161 | self.kl = 0.5 * tf.reduce_mean(log_det_cov_new - 
log_det_cov_old + tr_old_new + 162 | tf.reduce_sum(tf.square(self.means - self.old_means_ph) / 163 | tf.exp(self.log_vars), axis=1) - act_dim) 164 | 165 | self.entropy = 0.5 * (act_dim * (np.log(2 * np.pi) + 1) + 166 | tf.reduce_mean(tf.reduce_sum(self.log_vars, axis=1))) 167 | 168 | def _sample(self, act_dim): 169 | """ Sample from distribution, given observation""" 170 | self.sampled_act = (self.means + 171 | tf.exp(self.log_vars / 2.0) * tf.random_normal(shape=(act_dim,))) 172 | 173 | def _loss_train_op(self): 174 | # TODO: use reduce_mean or reduce_sum? 175 | loss1 = -tf.reduce_mean(self.advantages_ph * 176 | tf.exp(self.logp - self.logp_old)) 177 | loss2 = tf.reduce_mean(self.beta_ph * self.kl) 178 | loss3 = self.eta_ph * tf.square(tf.maximum(0.0, self.kl - 2.0 * self.kl_targ)) 179 | self.loss = loss1 + loss2 + loss3 180 | # optimizer = tf.train.AdamOptimizer(0.00003) 181 | optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9, use_nesterov=True) 182 | self.train_op = optimizer.minimize(self.loss) 183 | 184 | def _init_session(self): 185 | """Launch TensorFlow session and initialize variables""" 186 | self.sess = tf.Session(graph=self.g) 187 | self.sess.run(self.init) 188 | 189 | def sample(self, obs): 190 | """Draw sample from policy distribution""" 191 | feed_dict = {self.obs_ph: obs} 192 | 193 | return self.sess.run(self.sampled_act, feed_dict=feed_dict) 194 | 195 | def update(self, observes, actions, advantages, logger, epochs=20): 196 | feed_dict = {self.obs_ph: observes, 197 | self.act_ph: actions, 198 | self.advantages_ph: advantages, 199 | self.beta_ph: self.beta, 200 | self.eta_ph: 100} 201 | old_means_np, old_log_vars_np = self.sess.run([self.means, self.log_vars], 202 | feed_dict) 203 | feed_dict[self.old_log_vars_ph] = old_log_vars_np 204 | feed_dict[self.old_means_ph] = old_means_np 205 | for e in range(epochs): 206 | self.sess.run(self.train_op, feed_dict) 207 | loss, kl, entropy = self.sess.run([self.loss, self.kl, self.entropy], feed_dict) 208 | if kl > self.kl_targ * 4: 209 | break 210 | if kl > self.kl_targ * 2: 211 | self.beta *= 1.5 212 | elif kl < self.kl_targ / 2: 213 | self.beta /= 1.5 214 | 215 | logger.log({'PolicyLoss': loss, 216 | 'PolicyEntropy': entropy, 217 | 'KL': kl, 218 | 'Beta': self.beta}) 219 | 220 | def close_sess(self): 221 | self.sess.close() -------------------------------------------------------------------------------- /trpo/train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | """ 3 | TRPO: Trust Region Policy Optimization 4 | 5 | Written by Patrick Coady (pat-coady.github.io) 6 | 7 | See these papers for details: 8 | 9 | TRPO / PPO: 10 | https://arxiv.org/pdf/1502.05477.pdf (Schulman et al., 2016) 11 | 12 | Distributed PPO: 13 | https://arxiv.org/abs/1707.02286 (Heess et al., 2017) 14 | 15 | Generalized Advantage Estimation: 16 | https://arxiv.org/pdf/1506.02438.pdf 17 | 18 | And, also, this GitHub repo which was helpful to me during 19 | implementation: 20 | https://github.com/joschu/modular_rl 21 | 22 | This implementation learns policies for continuous environments 23 | in the OpenAI Gym (https://gym.openai.com/). Testing was focused on 24 | the MuJoCo control tasks. 
25 | """ 26 | import gym 27 | import pybullet 28 | import pybullet_envs 29 | import numpy as np 30 | from gym import wrappers 31 | from policy import Policy 32 | from value import NNValueFunction 33 | import scipy.signal 34 | from utils import Logger, Scaler 35 | from datetime import datetime 36 | import os 37 | import argparse 38 | import signal 39 | 40 | 41 | class GracefulKiller: 42 | """Gracefully exit program on CTRL-C.""" 43 | def __init__(self): 44 | self.kill_now = False 45 | signal.signal(signal.SIGINT, self.exit_gracefully) 46 | signal.signal(signal.SIGTERM, self.exit_gracefully) 47 | 48 | def exit_gracefully(self, signum, frame): 49 | self.kill_now = True 50 | 51 | 52 | def init_gym(env_name): 53 | """ 54 | Initialize gym environment, return dimension of observation 55 | and action spaces. 56 | 57 | Args: 58 | env_name: str environment name (e.g. "Humanoid-v1") 59 | 60 | Returns: 3-tuple 61 | gym environment (object) 62 | number of observation dimensions (int) 63 | number of action dimensions (int) 64 | """ 65 | env = gym.make(env_name) 66 | obs_dim = env.observation_space.shape[0] 67 | act_dim = env.action_space.shape[0] 68 | 69 | return env, obs_dim, act_dim 70 | 71 | 72 | def run_episode(env, policy, scaler, animate=False): 73 | """Run single episode with option to animate. 74 | 75 | Args: 76 | env: ai gym environment 77 | policy: policy object with sample() method 78 | scaler: scaler object, used to scale/offset each observation dimension 79 | to a similar range 80 | animate: boolean, True uses env.render() method to animate episode 81 | 82 | Returns: 4-tuple of NumPy arrays 83 | observes: shape = (episode len, obs_dim) 84 | actions: shape = (episode len, act_dim) 85 | rewards: shape = (episode len,) 86 | unscaled_obs: useful for training scaler, shape = (episode len, obs_dim) 87 | """ 88 | obs = env.reset() 89 | observes, actions, rewards, unscaled_obs = [], [], [], [] 90 | done = False 91 | step = 0.0 92 | scale, offset = scaler.get() 93 | scale[-1] = 1.0 # don't scale time step feature 94 | offset[-1] = 0.0 # don't offset time step feature 95 | while not done: 96 | if animate: 97 | env.render() 98 | obs = np.concatenate([obs, [step]]) # add time step feature 99 | obs = obs.astype(np.float32).reshape((1, -1)) 100 | unscaled_obs.append(obs) 101 | obs = np.float32((obs - offset) * scale) # center and scale observations 102 | observes.append(obs) 103 | action = policy.sample(obs) 104 | actions.append(action) 105 | obs, reward, done, _ = env.step(action.flatten()) 106 | rewards.append(reward) 107 | step += 1e-3 # increment time step feature 108 | 109 | return (np.concatenate(observes), np.concatenate(actions), 110 | np.array(rewards, dtype=np.float32), np.concatenate(unscaled_obs)) 111 | 112 | 113 | def run_policy(env, policy, scaler, logger, episodes): 114 | """ Run policy and collect trajectory data for the requested number of episodes 115 | 116 | Args: 117 | env: ai gym environment 118 | policy: policy object with sample() method 119 | scaler: scaler object, used to scale/offset each observation dimension 120 | to a similar range 121 | logger: logger object, used to save stats from episodes 122 | episodes: total episodes to run 123 | 124 | Returns: list of trajectory dictionaries, list length = number of episodes 125 | 'observes' : NumPy array of states from episode 126 | 'actions' : NumPy array of actions from episode 127 | 'rewards' : NumPy array of (un-discounted) rewards from episode 128 | 'unscaled_obs' : NumPy array of unscaled observations, used to update the scaler 129 | """ 130 | 
total_steps = 0 131 | trajectories = [] 132 | for e in range(episodes): 133 | observes, actions, rewards, unscaled_obs = run_episode(env, policy, scaler) 134 | # print(observes.shape) 135 | # print(actions.shape) 136 | # print(rewards.shape) 137 | # print(unscaled_obs.shape) 138 | # print(observes.dtype) 139 | # print(actions.dtype) 140 | # print(rewards.dtype) 141 | # print(unscaled_obs.dtype) 142 | total_steps += observes.shape[0] 143 | trajectory = {'observes': observes, 144 | 'actions': actions, 145 | 'rewards': rewards, 146 | 'unscaled_obs': unscaled_obs} 147 | trajectories.append(trajectory) 148 | unscaled = np.concatenate([t['unscaled_obs'] for t in trajectories]) 149 | scaler.update(unscaled) # update running statistics for scaling observations 150 | logger.log({'_MeanReward': np.mean([t['rewards'].sum() for t in trajectories]), 151 | 'Steps': total_steps}) 152 | 153 | return trajectories 154 | 155 | 156 | def discount(x, gamma): 157 | """ Calculate discounted forward sum of a sequence at each point """ 158 | return scipy.signal.lfilter([1.0], [1.0, -gamma], x[::-1])[::-1] 159 | 160 | 161 | def add_disc_sum_rew(trajectories, gamma): 162 | """ Adds discounted sum of rewards to all time steps of all trajectories 163 | 164 | Args: 165 | trajectories: as returned by run_policy() 166 | gamma: discount 167 | 168 | Returns: 169 | None (mutates trajectories dictionary to add 'disc_sum_rew') 170 | """ 171 | for trajectory in trajectories: 172 | if gamma < 0.999: # don't scale for gamma ~= 1 173 | rewards = trajectory['rewards'] * (1 - gamma) 174 | else: 175 | rewards = trajectory['rewards'] 176 | disc_sum_rew = discount(rewards, gamma) 177 | trajectory['disc_sum_rew'] = disc_sum_rew 178 | 179 | 180 | def add_value(trajectories, val_func): 181 | """ Adds estimated value to all time steps of all trajectories 182 | 183 | Args: 184 | trajectories: as returned by run_policy() 185 | val_func: object with predict() method, takes observations 186 | and returns predicted state value 187 | 188 | Returns: 189 | None (mutates trajectories dictionary to add 'values') 190 | """ 191 | for trajectory in trajectories: 192 | observes = trajectory['observes'] 193 | values = val_func.predict(observes) 194 | trajectory['values'] = values.flatten() 195 | 196 | 197 | def add_gae(trajectories, gamma, lam): 198 | """ Add generalized advantage estimator. 199 | https://arxiv.org/pdf/1506.02438.pdf 200 | 201 | Args: 202 | trajectories: as returned by run_policy(), must include 'values' 203 | key from add_value(). 204 | gamma: reward discount 205 | lam: lambda (see paper). 
206 | lam=0 : use TD residuals 207 | lam=1 : A = Sum Discounted Rewards - V_hat(s) 208 | 209 | Returns: 210 | None (mutates trajectories dictionary to add 'advantages') 211 | """ 212 | for trajectory in trajectories: 213 | if gamma < 0.999: # don't scale for gamma ~= 1 214 | rewards = trajectory['rewards'] * (1 - gamma) 215 | else: 216 | rewards = trajectory['rewards'] 217 | values = trajectory['values'] 218 | # temporal differences 219 | tds = rewards - values + np.append(values[1:] * gamma, 0) 220 | advantages = discount(tds, gamma * lam) 221 | trajectory['advantages'] = advantages 222 | 223 | 224 | def build_train_set(trajectories): 225 | """ 226 | 227 | Args: 228 | trajectories: trajectories after processing by add_disc_sum_rew(), 229 | add_value(), and add_gae() 230 | 231 | Returns: 4-tuple of NumPy arrays 232 | observes: shape = (N, obs_dim) 233 | actions: shape = (N, act_dim) 234 | advantages: shape = (N,) 235 | disc_sum_rew: shape = (N,) 236 | """ 237 | observes = np.concatenate([t['observes'] for t in trajectories]) 238 | actions = np.concatenate([t['actions'] for t in trajectories]) 239 | disc_sum_rew = np.concatenate([t['disc_sum_rew'] for t in trajectories]) 240 | advantages = np.concatenate([t['advantages'] for t in trajectories]) 241 | # normalize advantages 242 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-6) 243 | 244 | return observes, actions, advantages, disc_sum_rew 245 | 246 | 247 | def log_batch_stats(observes, actions, advantages, disc_sum_rew, logger, episode): 248 | """ Log various batch statistics """ 249 | logger.log({'_mean_obs': np.mean(observes), 250 | '_min_obs': np.min(observes), 251 | '_max_obs': np.max(observes), 252 | '_std_obs': np.mean(np.var(observes, axis=0)), 253 | '_mean_act': np.mean(actions), 254 | '_min_act': np.min(actions), 255 | '_max_act': np.max(actions), 256 | '_std_act': np.mean(np.var(actions, axis=0)), 257 | '_mean_adv': np.mean(advantages), 258 | '_min_adv': np.min(advantages), 259 | '_max_adv': np.max(advantages), 260 | '_std_adv': np.var(advantages), 261 | '_mean_discrew': np.mean(disc_sum_rew), 262 | '_min_discrew': np.min(disc_sum_rew), 263 | '_max_discrew': np.max(disc_sum_rew), 264 | '_std_discrew': np.var(disc_sum_rew), 265 | '_Episode': episode 266 | }) 267 | 268 | 269 | def main(env_name, num_episodes, gamma, lam, kl_targ, batch_size, hid1_mult, init_logvar): 270 | """ Main training loop 271 | 272 | Args: 273 | env_name: OpenAI Gym environment name, e.g. 
'Hopper-v1' 274 | num_episodes: maximum number of episodes to run 275 | gamma: reward discount factor (float) 276 | lam: lambda from Generalized Advantage Estimate 277 | kl_targ: D_KL target for policy update [D_KL(pi_old || pi_new) 278 | batch_size: number of episodes per policy training batch 279 | hid1_mult: hid1 size for policy and value_f (multiplier of obs dimension) 280 | init_logvar: natural log of initial policy variance 281 | """ 282 | pybullet.connect(pybullet.DIRECT) 283 | killer = GracefulKiller() 284 | env, obs_dim, act_dim = init_gym(env_name) 285 | obs_dim += 1 # add 1 to obs dimension for time step feature (see run_episode()) 286 | now = datetime.utcnow().strftime("%b-%d_%H:%M:%S") # create unique directories 287 | logger = Logger(logname=env_name, now=now) 288 | aigym_path = os.path.join('/tmp', env_name, now) 289 | env = wrappers.Monitor(env, aigym_path, force=True) 290 | scaler = Scaler(obs_dim) 291 | val_func = NNValueFunction(obs_dim, hid1_mult) 292 | policy = Policy(obs_dim, act_dim, kl_targ, hid1_mult, init_logvar) 293 | # run a few episodes of untrained policy to initialize scaler: 294 | run_policy(env, policy, scaler, logger, episodes=5) 295 | episode = 0 296 | while episode < num_episodes: 297 | trajectories = run_policy(env, policy, scaler, logger, episodes=batch_size) 298 | episode += len(trajectories) 299 | add_value(trajectories, val_func) # add estimated values to episodes 300 | add_disc_sum_rew(trajectories, gamma) # calculated discounted sum of Rs 301 | add_gae(trajectories, gamma, lam) # calculate advantage 302 | # concatenate all episodes into single NumPy arrays 303 | observes, actions, advantages, disc_sum_rew = build_train_set(trajectories) 304 | # add various stats to training log: 305 | log_batch_stats(observes, actions, advantages, disc_sum_rew, logger, episode) 306 | policy.update(observes, actions, advantages, logger) # update policy 307 | val_func.fit(observes, disc_sum_rew, logger) # update value function 308 | logger.write(display=True) # write logger results to file and stdout 309 | if killer.kill_now: 310 | if input('Terminate training (y/[n])? ') == 'y': 311 | break 312 | killer.kill_now = False 313 | logger.close() 314 | 315 | 316 | if __name__ == "__main__": 317 | parser = argparse.ArgumentParser(description=('Train policy on OpenAI Gym environment ' 318 | 'using Proximal Policy Optimizer')) 319 | parser.add_argument('env_name', type=str, help='OpenAI Gym (PyBullet) environment name') 320 | parser.add_argument('-n', '--num_episodes', type=int, help='Number of episodes to run', 321 | default=1000) 322 | parser.add_argument('-g', '--gamma', type=float, help='Discount factor', default=0.995) 323 | parser.add_argument('-l', '--lam', type=float, help='Lambda for Generalized Advantage Estimation', 324 | default=0.98) 325 | parser.add_argument('-k', '--kl_targ', type=float, help='D_KL target value', 326 | default=0.003) 327 | parser.add_argument('-b', '--batch_size', type=int, 328 | help='Number of episodes per training batch', 329 | default=20) 330 | parser.add_argument('-m', '--hid1_mult', type=int, 331 | help='Size of first hidden layer for value and policy NNs' 332 | '(integer multiplier of observation dimension)', 333 | default=10) 334 | parser.add_argument('-v', '--init_logvar', type=float, 335 | help='Initial policy log-variance (natural log of variance)', 336 | default=-1.0) 337 | 338 | args = parser.parse_args() 339 | main(**vars(args)) 340 | --------------------------------------------------------------------------------