├── .gitignore
├── LICENSE.md
├── README.md
├── main.py
├── models.py
├── replay_memory.py
├── running_state.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------

/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Ilya Kostrikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# PyTorch implementation of PPO

NOTE: This repository is not maintained. I recommend using the implementation [here](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr) instead; it is much more full-featured and better tested.

This is a PyTorch implementation of [Proximal Policy Optimization](https://arxiv.org/abs/1707.06347).

The code is mostly ported from the [OpenAI baselines implementation](https://github.com/openai/baselines), but it currently does not optimize each batch over several epochs. I will add this soon.

## Usage

```
python main.py --env-name Walker2d-v1
```
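
All hyperparameters are exposed as command-line flags in `main.py`. For example, to train the combined actor-critic network with an explicit seed and clipping coefficient (the values shown are the defaults), something like the following should work:

```
python main.py --env-name Walker2d-v1 --use-joint-pol-val --seed 543 \
    --batch-size 5000 --clip-epsilon 0.2 --gamma 0.995 --tau 0.97
```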

## Contributions

Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request.

## Todo

- [ ] Add multiple epochs per batch
- [ ] Compare results against the baselines code
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import argparse
import sys
import math
from collections import namedtuple
from itertools import count

import gym
import numpy as np
import scipy.optimize
from gym import wrappers

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from torch.autograd import Variable

from models import Policy, Value, ActorCritic
from replay_memory import Memory
from running_state import ZFilter

# from utils import *

torch.set_default_tensor_type('torch.DoubleTensor')
PI = torch.DoubleTensor([math.pi])

parser = argparse.ArgumentParser(description='PyTorch PPO example')
parser.add_argument('--gamma', type=float, default=0.995, metavar='G',
                    help='discount factor (default: 0.995)')
parser.add_argument('--env-name', default="Reacher-v1", metavar='G',
                    help='name of the environment to run')
parser.add_argument('--tau', type=float, default=0.97, metavar='G',
                    help='GAE parameter lambda (default: 0.97)')
# parser.add_argument('--l2_reg', type=float, default=1e-3, metavar='G',
#                     help='l2 regularization regression (default: 1e-3)')
# parser.add_argument('--max_kl', type=float, default=1e-2, metavar='G',
#                     help='max kl value (default: 1e-2)')
# parser.add_argument('--damping', type=float, default=1e-1, metavar='G',
#                     help='damping (default: 1e-1)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--batch-size', type=int, default=5000, metavar='N',
                    help='batch size (default: 5000)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='interval between training status logs (default: 1)')
parser.add_argument('--entropy-coeff', type=float, default=0.0, metavar='N',
                    help='coefficient for entropy cost')
parser.add_argument('--clip-epsilon', type=float, default=0.2, metavar='N',
                    help='clipping parameter epsilon for the PPO surrogate objective')
parser.add_argument('--use-joint-pol-val', action='store_true',
                    help='whether to use combined policy and value nets')
args = parser.parse_args()

env = gym.make(args.env_name)

num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

env.seed(args.seed)
torch.manual_seed(args.seed)

if args.use_joint_pol_val:
    ac_net = ActorCritic(num_inputs, num_actions)
    opt_ac = optim.Adam(ac_net.parameters(), lr=0.001)
else:
    policy_net = Policy(num_inputs, num_actions)
    value_net = Value(num_inputs)
    opt_policy = optim.Adam(policy_net.parameters(), lr=0.001)
    opt_value = optim.Adam(value_net.parameters(), lr=0.001)

def select_action(state):
    state = torch.from_numpy(state).unsqueeze(0)
    action_mean, _, action_std = policy_net(Variable(state))
    action = torch.normal(action_mean, action_std)
    return action

def select_action_actor_critic(state):
    state = torch.from_numpy(state).unsqueeze(0)
    action_mean, _, action_std, v = ac_net(Variable(state))
    action = torch.normal(action_mean, action_std)
    return action

def normal_log_density(x, mean, log_std, std):
    var = std.pow(2)
    log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * torch.log(2 * Variable(PI)) - log_std
    return log_density.sum(1)
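
# Both update_params_actor_critic() and update_params() below estimate returns
# and advantages with the same backward recursion over the sampled batch
# (generalized advantage estimation, with gamma = args.gamma and
# lambda = args.tau; masks[i] is 0 when step i ended an episode, so nothing is
# bootstrapped across episode boundaries):
#
#     returns[i]    = rewards[i] + gamma * returns[i + 1] * masks[i]
#     deltas[i]     = rewards[i] + gamma * values[i + 1] * masks[i] - values[i]
#     advantages[i] = deltas[i] + gamma * lambda * advantages[i + 1] * masks[i]
#
# The advantages are then standardized before entering the clipped surrogate.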
def update_params_actor_critic(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    action_means, action_log_stds, action_stds, values = ac_net(Variable(states))

    returns = torch.Tensor(actions.size(0), 1)
    deltas = torch.Tensor(actions.size(0), 1)
    advantages = torch.Tensor(actions.size(0), 1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]
        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # kloldnew = policy_net.kl_old_new() # oldpi.pd.kl(pi.pd)
    # ent = policy_net.entropy() #pi.pd.entropy()
    # meankl = torch.reduce_mean(kloldnew)
    # meanent = torch.reduce_mean(ent)
    # pol_entpen = (-args.entropy_coeff) * meanent

    action_var = Variable(actions)
    # compute probs from actions above
    log_prob_cur = normal_log_density(action_var, action_means, action_log_stds, action_stds)

    action_means_old, action_log_stds_old, action_stds_old, values_old = ac_net(Variable(states), old=True)
    log_prob_old = normal_log_density(action_var, action_means_old, action_log_stds_old, action_stds_old)

    # backup params after computing probs but before updating new params
    ac_net.backup()

    advantages = (advantages - advantages.mean()) / advantages.std()
    advantages_var = Variable(advantages)

    opt_ac.zero_grad()
    ratio = torch.exp(log_prob_cur - log_prob_old)  # pnew / pold
    surr1 = ratio * advantages_var[:, 0]
    surr2 = torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon) * advantages_var[:, 0]
    policy_surr = -torch.min(surr1, surr2).mean()

    # clipped value loss: keep the new value prediction within clip_epsilon of the old one
    vf_loss1 = (values - targets).pow(2.)
    vpredclipped = values_old + torch.clamp(values - values_old, -args.clip_epsilon, args.clip_epsilon)
    vf_loss2 = (vpredclipped - targets).pow(2.)
    vf_loss = 0.5 * torch.max(vf_loss1, vf_loss2).mean()

    total_loss = policy_surr + vf_loss
    total_loss.backward()
    torch.nn.utils.clip_grad_norm(ac_net.parameters(), 40)
    opt_ac.step()
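

# For reference, the clipped surrogate that is computed inline in
# update_params_actor_critic() above and update_params() below can be written
# as a standalone helper. This sketch is purely illustrative and is not called
# anywhere in the training loop.
def ppo_clipped_surrogate(log_prob_cur, log_prob_old, advantages, clip_epsilon):
    """Negated PPO objective: -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]."""
    ratio = torch.exp(log_prob_cur - log_prob_old)  # pi_new(a|s) / pi_old(a|s)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    return -torch.min(surr1, surr2).mean()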


def update_params(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))

    returns = torch.Tensor(actions.size(0), 1)
    deltas = torch.Tensor(actions.size(0), 1)
    advantages = torch.Tensor(actions.size(0), 1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]
        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    opt_value.zero_grad()
    value_loss = (values - targets).pow(2.).mean()
    value_loss.backward()
    opt_value.step()

    # kloldnew = policy_net.kl_old_new() # oldpi.pd.kl(pi.pd)
    # ent = policy_net.entropy() #pi.pd.entropy()
    # meankl = torch.reduce_mean(kloldnew)
    # meanent = torch.reduce_mean(ent)
    # pol_entpen = (-args.entropy_coeff) * meanent

    action_var = Variable(actions)

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    log_prob_cur = normal_log_density(action_var, action_means, action_log_stds, action_stds)

    action_means_old, action_log_stds_old, action_stds_old = policy_net(Variable(states), old=True)
    log_prob_old = normal_log_density(action_var, action_means_old, action_log_stds_old, action_stds_old)

    # backup params after computing probs but before updating new params
    policy_net.backup()

    advantages = (advantages - advantages.mean()) / advantages.std()
    advantages_var = Variable(advantages)

    opt_policy.zero_grad()
    ratio = torch.exp(log_prob_cur - log_prob_old)  # pnew / pold
    surr1 = ratio * advantages_var[:, 0]
    surr2 = torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon) * advantages_var[:, 0]
    policy_surr = -torch.min(surr1, surr2).mean()
    policy_surr.backward()
    torch.nn.utils.clip_grad_norm(policy_net.parameters(), 40)
    opt_policy.step()

running_state = ZFilter((num_inputs,), clip=5)
running_reward = ZFilter((1,), demean=False, clip=10)
episode_lengths = []
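
# Main training loop: each iteration collects at least args.batch_size
# environment steps (observations are normalized online by running_state),
# then performs a single PPO update on the collected batch.
# Note: running_reward and episode_lengths are defined but not used below.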
for i_episode in count(1):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < args.batch_size:
        state = env.reset()
        state = running_state(state)

        reward_sum = 0
        for t in range(10000):  # Don't infinite loop while learning
            if args.use_joint_pol_val:
                action = select_action_actor_critic(state)
            else:
                action = select_action(state)
            action = action.data[0].numpy()
            next_state, reward, done, _ = env.step(action)
            reward_sum += reward

            next_state = running_state(next_state)

            mask = 1
            if done:
                mask = 0

            memory.push(state, np.array([action]), mask, next_state, reward)

            if args.render:
                env.render()
            if done:
                break

            state = next_state
        num_steps += (t + 1)  # the episode ran for t + 1 steps
        num_episodes += 1
        reward_batch += reward_sum

    reward_batch /= num_episodes
    batch = memory.sample()
    if args.use_joint_pol_val:
        update_params_actor_critic(batch)
    else:
        update_params(batch)

    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
            i_episode, reward_sum, reward_batch))
--------------------------------------------------------------------------------

/models.py:
--------------------------------------------------------------------------------
import copy
import numpy as np

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F

def square(a):
    return torch.pow(a, 2.)

class ActorCritic(nn.Module):

    def __init__(self, num_inputs, num_outputs, hidden=64):
        super(ActorCritic, self).__init__()
        self.affine1 = nn.Linear(num_inputs, hidden)
        self.affine2 = nn.Linear(hidden, hidden)

        self.action_mean = nn.Linear(hidden, num_outputs)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)
        self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs))

        self.value_head = nn.Linear(hidden, 1)

        self.module_list_current = [self.affine1, self.affine2, self.action_mean, self.action_log_std, self.value_head]
        self.module_list_old = [None] * len(self.module_list_current)
        self.backup()

    def backup(self):
        for i in range(len(self.module_list_current)):
            self.module_list_old[i] = copy.deepcopy(self.module_list_current[i])

    def forward(self, x, old=False):
        if old:
            x = F.tanh(self.module_list_old[0](x))
            x = F.tanh(self.module_list_old[1](x))

            action_mean = self.module_list_old[2](x)
            action_log_std = self.module_list_old[3].expand_as(action_mean)
            action_std = torch.exp(action_log_std)

            value = self.module_list_old[4](x)
        else:
            x = F.tanh(self.affine1(x))
            x = F.tanh(self.affine2(x))

            action_mean = self.action_mean(x)
            action_log_std = self.action_log_std.expand_as(action_mean)
            action_std = torch.exp(action_log_std)

            value = self.value_head(x)

        return action_mean, action_log_std, action_std, value
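

# Both ActorCritic above and Policy below keep a deep copy of their parameters
# in module_list_old: forward(old=True) evaluates that snapshot, and backup()
# refreshes it. In main.py the old log-probabilities are computed from the
# snapshot before backup() is called, so the snapshot plays the role of the
# "old" policy pi_old in the PPO probability ratio.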
class Policy(nn.Module):

    def __init__(self, num_inputs, num_outputs):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)
        self.action_mean = nn.Linear(64, num_outputs)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)
        self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs))
        self.module_list_current = [self.affine1, self.affine2, self.action_mean, self.action_log_std]

        self.module_list_old = [None] * len(self.module_list_current)
        self.backup()

    def backup(self):
        for i in range(len(self.module_list_current)):
            self.module_list_old[i] = copy.deepcopy(self.module_list_current[i])

    def kl_div_p_q(self, p_mean, p_std, q_mean, q_std):
        """KL divergence D_{KL}[p(x)||q(x)] for a fully factorized Gaussian."""
        # Note: this helper and kl_old_new()/entropy() below are not called by
        # main.py (the corresponding lines there are commented out).
        # print (type(p_mean), type(p_std), type(q_mean), type(q_std))
        # q_mean = Variable(torch.DoubleTensor([q_mean])).expand_as(p_mean)
        # q_std = Variable(torch.DoubleTensor([q_std])).expand_as(p_std)
        eps = 1e-8  # small constant for numerical stability (was previously undefined)
        numerator = square(p_mean - q_mean) + \
            square(p_std) - square(q_std)  # .expand_as(p_std)
        denominator = 2. * square(q_std) + eps
        return torch.sum(numerator / denominator + torch.log(q_std) - torch.log(p_std))

    def kl_old_new(self):
        """Gives the KL divergence from the old params to the new params."""
        kl_div = self.kl_div_p_q(self.module_list_old[-2], self.module_list_old[-1], self.action_mean, self.action_log_std)
        return kl_div

    def entropy(self):
        """Gives the entropy of the current Gaussian policy."""
        ent = torch.sum(self.action_log_std + 0.5 * np.log(2.0 * np.pi * np.e))
        return ent

    def forward(self, x, old=False):
        if old:
            x = F.tanh(self.module_list_old[0](x))
            x = F.tanh(self.module_list_old[1](x))

            action_mean = self.module_list_old[2](x)
            action_log_std = self.module_list_old[3].expand_as(action_mean)
            action_std = torch.exp(action_log_std)
        else:
            x = F.tanh(self.affine1(x))
            x = F.tanh(self.affine2(x))

            action_mean = self.action_mean(x)
            action_log_std = self.action_log_std.expand_as(action_mean)
            action_std = torch.exp(action_log_std)

        return action_mean, action_log_std, action_std


class Value(nn.Module):
    def __init__(self, num_inputs):
        super(Value, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)
        self.value_head = nn.Linear(64, 1)
        self.value_head.weight.data.mul_(0.1)
        self.value_head.bias.data.mul_(0.0)

    def forward(self, x):
        x = F.tanh(self.affine1(x))
        x = F.tanh(self.affine2(x))

        state_values = self.value_head(x)
        return state_values
--------------------------------------------------------------------------------

/replay_memory.py:
--------------------------------------------------------------------------------
import random
from collections import namedtuple

# Taken from
# https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb

Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state',
                                       'reward'))

class Memory(object):
    def __init__(self):
        self.memory = []

    def push(self, state, action, mask, next_state, reward):
        """Saves a transition."""
        self.memory.append(Transition(state, action, mask, next_state, reward))

    def sample(self):
        return Transition(*zip(*self.memory))

    def __len__(self):
        return len(self.memory)
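

# Example usage, mirroring how main.py fills and drains the buffer once per
# training iteration (illustrative only):
#
#     memory = Memory()
#     memory.push(state, np.array([action]), mask, next_state, reward)
#     batch = memory.sample()  # a single Transition whose fields are tuples,
#                              # e.g. batch.reward holds every stored reward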
--------------------------------------------------------------------------------

/running_state.py:
--------------------------------------------------------------------------------
from collections import deque

import numpy as np


# from https://github.com/joschu/modular_rl
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
    def __init__(self, shape):
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)

    def push(self, x):
        x = np.asarray(x)
        # print ("push shape: ", x.shape)
        assert x.shape == self._M.shape
        self._n += 1
        if self._n == 1:
            self._M[...] = x
        else:
            oldM = self._M.copy()
            self._M[...] = oldM + (x - oldM) / self._n
            self._S[...] = self._S + (x - oldM) * (x - self._M)

    @property
    def n(self):
        return self._n

    @property
    def mean(self):
        return self._M

    @property
    def var(self):
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

    @property
    def std(self):
        return np.sqrt(self.var)

    @property
    def shape(self):
        return self._M.shape


class ZFilter:
    """
    y = (x - mean) / std
    using running estimates of mean and std
    """

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.demean = demean
        self.destd = destd
        self.clip = clip

        # print ("Zfilter shape: ", shape)
        self.rs = RunningStat(shape)

    def __call__(self, x, update=True):
        if update: self.rs.push(x)
        if self.demean:
            x = x - self.rs.mean
        if self.destd:
            x = x / (self.rs.std + 1e-8)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x

    def output_shape(self, input_space):
        return input_space.shape
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import torch


def normal_entropy(std):
    var = std.pow(2)
    entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi)
    return entropy.sum(1)


def normal_log_density(x, mean, log_std, std):
    var = std.pow(2)
    log_density = -(x - mean).pow(2) / (
        2 * var) - 0.5 * math.log(2 * math.pi) - log_std
    return log_density.sum(1)


def get_flat_params_from(model):
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))

    flat_params = torch.cat(params)
    return flat_params


def set_flat_params_to(model, flat_params):
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(
            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size


def get_flat_grad_from(net, grad_grad=False):
    grads = []
    for param in net.parameters():
        if grad_grad:
            grads.append(param.grad.grad.view(-1))
        else:
            grads.append(param.grad.view(-1))

    flat_grad = torch.cat(grads)
    return flat_grad
--------------------------------------------------------------------------------