├── LICENSE.md
├── README.md
├── conjugate_gradients.py
├── main.py
├── models.py
├── replay_memory.py
├── running_state.py
├── trpo.py
└── utils.py

/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Ilya Kostrikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch implementation of TRPO

Try my implementation of [PPO](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/) (a newer, better variant of TRPO), unless you need TRPO for some specific reason.

This is a PyTorch implementation of ["Trust Region Policy Optimization (TRPO)"](https://arxiv.org/abs/1502.05477).

The code is mostly ported from the [original implementation by John Schulman](https://github.com/joschu/modular_rl). In contrast to [another implementation of TRPO in PyTorch](https://github.com/mjacar/pytorch-trpo), this implementation uses an exact Hessian-vector product instead of a finite-differences approximation (see the sketch at the end of this README).

## Contributions

Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request.

## Usage

```
python main.py --env-name "Reacher-v1"
```

## Recommended batch sizes (`--batch-size`)

InvertedPendulum-v1: 5000

Reacher-v1, InvertedDoublePendulum-v1: 15000

HalfCheetah-v1, Hopper-v1, Swimmer-v1, Walker2d-v1: 25000

Ant-v1, Humanoid-v1: 50000

## Results

More or less similar to the original code. Coming soon.

## Todo

- [ ] Plots.
- [ ] Collect data in multiple threads.
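
## Exact Hessian-vector products

The Fisher/Hessian-vector products used by the conjugate-gradient solver in `trpo.py` are computed by differentiating the mean KL divergence twice (double backpropagation), not by finite differences. A minimal, self-contained sketch of the idea; the toy function and variable names below are illustrative only, not part of this repo:

```python
import torch

# Toy scalar function standing in for the mean KL divergence.
params = torch.randn(5, requires_grad=True)
f = (params ** 2).sum()

# Arbitrary vector to multiply with the Hessian of f.
v = torch.randn(5)

# First backward pass: gradient of f, kept in the graph (create_graph=True).
(grad,) = torch.autograd.grad(f, params, create_graph=True)

# Second backward pass on (grad . v): exact Hessian-vector product H v.
(hvp,) = torch.autograd.grad((grad * v).sum(), params)

print(hvp)  # equals 2 * v for this toy f, since H = 2 * I
```

`Fvp` in `trpo.py` does the same with the policy's KL divergence and adds a damping term to the product.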
--------------------------------------------------------------------------------
/conjugate_gradients.py:
--------------------------------------------------------------------------------
import torch


def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b - Avp(x)
    p = r
    rdotr = torch.dot(r, r)

    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x


def flat_grad_from(net, grad_grad=False):
    grads = []
    for param in net.parameters():
        if grad_grad:
            grads.append(param.grad.grad.view(-1))
        else:
            grads.append(param.grad.view(-1))

    flat_grad = torch.cat(grads)
    return flat_grad
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
from itertools import count

import gym
import numpy as np
import scipy.optimize

import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from trpo import trpo_step
from utils import *

torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

torch.set_default_tensor_type('torch.DoubleTensor')

parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.995, metavar='G',
                    help='discount factor (default: 0.995)')
parser.add_argument('--env-name', default="Reacher-v1", metavar='G',
                    help='name of the environment to run')
parser.add_argument('--tau', type=float, default=0.97, metavar='G',
                    help='GAE parameter lambda (default: 0.97)')
parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G',
                    help='l2 regularization of the value function (default: 1e-3)')
parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G',
                    help='max KL divergence per update (default: 1e-2)')
parser.add_argument('--damping', type=float, default=1e-1, metavar='G',
                    help='damping added to the Fisher-vector product (default: 1e-1)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--batch-size', type=int, default=15000, metavar='N',
                    help='number of environment steps per update (default: 15000)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='interval between training status logs (default: 1)')
args = parser.parse_args()

env = gym.make(args.env_name)

num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

env.seed(args.seed)
torch.manual_seed(args.seed)

policy_net = Policy(num_inputs, num_actions)
value_net = Value(num_inputs)


def select_action(state):
    state = torch.from_numpy(state).unsqueeze(0)
    action_mean, _, action_std = policy_net(Variable(state))
    action = torch.normal(action_mean, action_std)
    return action


def update_params(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
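    # Each field of `batch` is a tuple with one entry per collected time step.
    # The reversed loop below computes discounted returns, TD residuals and
    # GAE advantages:
    #   delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
    #   A_t     = delta_t + gamma * tau * A_{t+1} * mask_t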
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))

    returns = torch.Tensor(actions.size(0), 1)
    deltas = torch.Tensor(actions.size(0), 1)
    advantages = torch.Tensor(actions.size(0), 1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
    def get_value_loss(flat_params):
        set_flat_params_to(value_net, torch.Tensor(flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
    set_flat_params_to(value_net, torch.Tensor(flat_params))

    advantages = (advantages - advantages.mean()) / advantages.std()

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net(Variable(states))

        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()

    def get_kl():
        mean1, log_std1, std1 = policy_net(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)


running_state = ZFilter((num_inputs,), clip=5)
running_reward = ZFilter((1,), demean=False, clip=10)

for i_episode in count(1):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < args.batch_size:
        state = env.reset()
        state = running_state(state)

        reward_sum = 0
        for t in range(10000):  # Don't infinite loop while learning
            action = select_action(state)
            action = action.data[0].numpy()
            next_state, reward, done, _ = env.step(action)
            reward_sum += reward

            next_state = running_state(next_state)

            mask = 1
            if done:
                mask = 0

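            # mask = 0 marks a terminal step, so returns and GAE advantages are
            # not bootstrapped across episode boundaries in update_params().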
            memory.push(state, np.array([action]), mask, next_state, reward)

            if args.render:
                env.render()
            if done:
                break

            state = next_state
        num_steps += (t + 1)  # t + 1 environment steps were taken in this episode
        num_episodes += 1
        reward_batch += reward_sum

    reward_batch /= num_episodes
    batch = memory.sample()
    update_params(batch)

    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
            i_episode, reward_sum, reward_batch))
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.autograd as autograd
import torch.nn as nn


class Policy(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)

        self.action_mean = nn.Linear(64, num_outputs)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)

        self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs))

        self.saved_actions = []
        self.rewards = []
        self.final_value = 0

    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))

        action_mean = self.action_mean(x)
        action_log_std = self.action_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)

        return action_mean, action_log_std, action_std


class Value(nn.Module):
    def __init__(self, num_inputs):
        super(Value, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)
        self.value_head = nn.Linear(64, 1)
        self.value_head.weight.data.mul_(0.1)
        self.value_head.bias.data.mul_(0.0)

    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))

        state_values = self.value_head(x)
        return state_values
--------------------------------------------------------------------------------
/replay_memory.py:
--------------------------------------------------------------------------------
import random
from collections import namedtuple

# Taken from
# https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb

Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state',
                                       'reward'))


class Memory(object):
    def __init__(self):
        self.memory = []

    def push(self, *args):
        """Saves a transition."""
        self.memory.append(Transition(*args))

    def sample(self):
        return Transition(*zip(*self.memory))

    def __len__(self):
        return len(self.memory)
--------------------------------------------------------------------------------
/running_state.py:
--------------------------------------------------------------------------------
from collections import deque

import numpy as np


# from https://github.com/joschu/modular_rl
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
    def __init__(self, shape):
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)

    def push(self, x):
        x = np.asarray(x)
        assert x.shape == self._M.shape
        self._n += 1
        if self._n == 1:
            self._M[...] = x
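        # Later samples update the running mean and (unnormalized) variance
        # incrementally, following Welford's online algorithm (johndcook.com link above).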
        else:
            oldM = self._M.copy()
            self._M[...] = oldM + (x - oldM) / self._n
            self._S[...] = self._S + (x - oldM) * (x - self._M)

    @property
    def n(self):
        return self._n

    @property
    def mean(self):
        return self._M

    @property
    def var(self):
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

    @property
    def std(self):
        return np.sqrt(self.var)

    @property
    def shape(self):
        return self._M.shape


class ZFilter:
    """
    y = (x-mean)/std
    using running estimates of mean,std
    """

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.demean = demean
        self.destd = destd
        self.clip = clip

        self.rs = RunningStat(shape)

    def __call__(self, x, update=True):
        if update:
            self.rs.push(x)
        if self.demean:
            x = x - self.rs.mean
        if self.destd:
            x = x / (self.rs.std + 1e-8)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x

    def output_shape(self, input_space):
        return input_space.shape
--------------------------------------------------------------------------------
/trpo.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
from torch.autograd import Variable
from utils import *


def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x


def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    fval = f(True).data
    print("fval before", fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5 ** np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())

        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            print("fval after", newfval.item())
            return True, xnew
    return False, x


def trpo_step(model, get_loss, get_kl, max_kl, damping):
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        kl = get_kl()
        kl = kl.mean()

        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

        return flat_grad_grad_kl + v * damping

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]

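    # shs is (1/2) s^T H s for the search direction s, so lm = sqrt(shs / max_kl)
    # rescales the step so that its quadratic KL estimate equals max_kl;
    # neggdotstepdir / lm (below) is the predicted linear improvement of the
    # surrogate loss for the full step, used by the backtracking line search.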
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))

    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)

    return loss
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import math

import numpy as np

import torch


def normal_entropy(std):
    var = std.pow(2)
    entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi)
    return entropy.sum(1, keepdim=True)


def normal_log_density(x, mean, log_std, std):
    var = std.pow(2)
    log_density = -(x - mean).pow(2) / (
        2 * var) - 0.5 * math.log(2 * math.pi) - log_std
    return log_density.sum(1, keepdim=True)


def get_flat_params_from(model):
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))

    flat_params = torch.cat(params)
    return flat_params


def set_flat_params_to(model, flat_params):
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(
            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size


def get_flat_grad_from(net, grad_grad=False):
    grads = []
    for param in net.parameters():
        if grad_grad:
            grads.append(param.grad.grad.view(-1))
        else:
            grads.append(param.grad.view(-1))

    flat_grad = torch.cat(grads)
    return flat_grad
--------------------------------------------------------------------------------