├── LICENSE
├── README.md
├── agents
│   ├── activations.py
│   ├── buffers.py
│   ├── crossq_cem.py
│   ├── normalise.py
│   ├── ppo_baseline.py
│   ├── sac_baseline.py
│   ├── sac_crossq.py
│   ├── sac_crossq_bro.py
│   ├── sac_droq.py
│   └── utils.py
├── hypertune.py
├── learn_simple.py
├── learn_vectorised.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 modelbased
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Mini RL Lab
 2 | 
 3 | ### Easy Agent Experiments for Beginners
 4 | 
 5 | I wrote this set of scripts to help me research and experiment with the latest concepts in RL, as well as a way to learn Python and PyTorch.
 6 | 
 7 | It is a setup and workflow that works well for me to debug and experiment with concepts like agent algorithms, world models, planning, plasticity, transformers etc, and other beginners might find it a useful starting point for their own experiments.
 8 | 
 9 | The focus of Mini RL Lab is on continuous control gym-like environments aimed at physical systems or expensive-to-sample simulations - my personal research interest, and it makes the problem and architecture space tractable.
10 | 
11 | The basis is CleanRL's PPO and SAC agents [https://github.com/vwxyzjn/cleanrl], which I modified to:
12 | 
13 | 1. Separate the environment rollout and logging from the agent code. CleanRL's single-file approach is great but I find this arrangement easier for experiments
14 | 2. Simplify the code and improve performance where possible
15 | 3. Use different specialised training scripts
16 | 4. Include algorithms + variants as baselines with which to compare
17 | 
18 | ### Benefits
19 | 
20 | 1. Agents based on established, tried and tested baselines from CleanRL
21 | 2. Agents are structured for easy experimentation, whilst staying ~"one file"
22 | 3. Various performance considerations such as minimising cpu<>gpu syncs, data transfers from the buffer, etc.
23 |    1. Helpful for those of us limited to one workstation and a midrange GPU
24 | 4. Inline comments document design choices and link to source papers
25 | 5. Learn scripts implement a lot of best practices I discovered as I went, from minor (data logging structure) to major (multiprocessing allows running a number of parallel agents with different seeds, essential in RL)
26 | 
27 | ### Prerequisites
28 | 
29 | * PyTorch 2 (though 1.x will work with small changes)
30 | * NumPy (1.25, though older versions should work)
31 | * TensorBoard
32 | * Gymnasium[Box2D] and/or Gymnasium[MuJoCo] (https://gymnasium.farama.org)
33 |   * Or other gym-compatible environment of choice
34 | * Bayesian Optimisation (https://github.com/bayesian-optimization/BayesianOptimization)
35 | 
36 | ### Quickstart
37 | 
38 | Test a change quickly for major errors:
39 | 
40 | `python learn_simple.py`
41 | 
42 | Training run with multiple random seeds, logging to TensorBoard:
43 | 
44 | `python learn_simple.py --log --seed 8 --name "testing X new feature"`
45 | 
46 | Run a vectorised environment with CUDA and log:
47 | 
48 | `python learn_vectorised.py --log --cuda`
49 | 
50 | Use Bayesian optimisation to optimise hyperparameter(s):
51 | 
52 | `python hypertune.py`
53 | 
54 | ### Usage Notes
55 | 
56 | * ppo_baseline
57 |   * Based on CleanRL's continuous PPO agent
58 |   * Simplified for easy modification
59 |   * Improved samples per second through small optimisations
60 | 
61 | * sac_baseline
62 |   * Based on CleanRL's continuous SAC agent
63 |   * Simplified for easy modification
64 |   * Removed CUDA <> CPU synchronisations for better performance
65 |   * Variants: DroQ and CrossQ
66 |     * CrossQ in particular is great, see the paper: https://arxiv.org/abs/1902.05605
67 | 
68 | * Novel agents
69 |   * crossq_cem
70 |     * Based on sac_crossq with the actor replaced by a cross-entropy-method optimiser
71 |     * Inspired by QT-Opt https://arxiv.org/abs/1806.10293 and TD-MPC2 https://arxiv.org/abs/2310.16828
72 |     * Question: Can QT-Opt's performance improve to match TD-MPC2 using CrossQ's improvements to the Q functions?
73 | * Result: maybe, but CEM actor is so compute intensive it is not clear this is a direction worth pursuing 74 | * WIP, could be improved 75 | 76 | * **sac_crossq_bro** 77 | * Inspired by various papers showing that SAC can be improved with (a) more compute (b) regularisation (c) simple design changes 78 | * **Promising results**, first agent in Mini RL Lab that seems to reliably solve WalkerHardcore in <500k steps 79 | * WIP, not tuned or optimised yet 80 | 81 | * learn_simple.py 82 | * Multiple training runs in parallel using multiprocessing (the processes have independent agents and environments) 83 | * Few assumptions about environments, more easily compatible with rl envs approximating the open ai gym api 84 | * Easy to edit and modify, design choices in comments 85 | * Use case: test performance of new feature on multiple environments with many random seeds in parallel 86 | 87 | * learn_vectorised.py 88 | * No multiprocessing, runs a single process 89 | * PPO seems to really need different hyperparameters when vectorised 90 | * Use case: check performance when vectorised 91 | 92 | * hypertune.py 93 | * uses bayesian optimisation to tune selected hyperparameters 94 | * uses multiprocessing to run multiple evaluations in parallel 95 | * implements a median pruner to stop badly performing runs early 96 | * Use case: optimise a new hyperparameter 97 | * Hyperparameters in RL https://arxiv.org/abs/2306.01324 is a good reference for this -------------------------------------------------------------------------------- /agents/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ''' 5 | Non-linear activations 6 | Example usage: self.activation = torch.jit.script(LiSHT()) 7 | ''' 8 | 9 | # Used for normalisation in Dreamer v3 https://arxiv.org/abs/2301.04104 10 | # Added scale and bias learnable parameters like layer_norm 11 | # A useful stat.stackexchnage comment on the derivative of symlog https://stats.stackexchange.com/questions/605641/why-isnt-symmetric-log1x-used-as-neural-network-activation-function 12 | # TODO: receive shape like Rational to allow for (b,c,l) inputs; optional learnable like layer_norm 13 | class SymLog(nn.Module): 14 | def __init__(self, dim): 15 | super().__init__() 16 | # Learnable scale and bias like layer_norm 17 | self.scale = nn.Parameter(torch.ones(dim)) 18 | self.bias = nn.Parameter(torch.zeros(dim)) 19 | 20 | def forward(self, x): 21 | x = (torch.sign(x) * torch.log(torch.abs(x) + 1.0)) * self.scale + self.bias 22 | return x 23 | 24 | 25 | # LiSHT activation 26 | # https://arxiv.org/abs/1901.05894 27 | class LiSHT(nn.Module): 28 | def __init__(self) -> None: 29 | super().__init__() 30 | 31 | def forward(self, x): 32 | return x * torch.tanh(x) 33 | 34 | 35 | # OFN activation https://arxiv.org/abs/2403.05996 36 | # Unit ball normalisation to prevent gradient explosion 37 | # "all values are strictly between 0 and 1 and the gradients will be tangent to the unit sphere" 38 | class OFN(nn.Module): 39 | def __init__(self) -> None: 40 | super().__init__() 41 | 42 | def forward(self, x): 43 | return x / torch.linalg.vector_norm(x) 44 | 45 | # From Adaptive Rational Activations to Boost Deep Reinforcement Learning (v5, 2024) 46 | # rational function f(x) = Q(x) / P(x) where P and Q are polynomials and 47 | # (some) polynomial parameters are learnable 48 | # https://arxiv.org/abs/2102.09407v5 49 | class Rational(nn.Module): 50 | def __init__(self, num_dims=2, 
preset='lrelu') -> None: 51 | super().__init__() 52 | # m > n allows rationals to implicitly make use of residual connections 53 | m = 6 54 | n = 4 55 | 56 | # f(x) = Sum(P(x)) / 1 + Sum(Q(x)) where P and Q are polynomials 57 | # P(x) = ax^j , Q(x) = bx^k where a and b are learnable 58 | # whilst j and k are fixed integer power coefficients 59 | a, b = self._presets(preset) 60 | self.a = a 61 | self.b = b 62 | self.j = torch.arange(m) 63 | self.k = torch.arange(n) + 1 64 | 65 | # input could be dim = 2 (batch, channels) or 66 | # dim = 3 (batch, channels, length) for example 67 | # num_dims = x.dim() # if doing this dynamically 68 | p_shape = self.a.shape + (1,) * (num_dims) 69 | q_shape = self.b.shape + (1,) * (num_dims) 70 | 71 | # Reshape to broadcast with x 72 | self.abz = self.a.view(p_shape) 73 | self.jbz = self.j.view(p_shape) 74 | self.bbz = self.b.view(q_shape) 75 | self.kbz = self.k.view(q_shape) 76 | 77 | # Ensure they are learnable and/or moves to correct device 78 | self.ab = nn.Parameter(self.abz) 79 | self.bb = nn.Parameter(self.bbz) 80 | self.register_buffer('jb', self.jbz) 81 | self.register_buffer('kb', self.kbz) 82 | 83 | 84 | # 45% faster or more than reference implementation depending on sizes 85 | def forward(self, x: torch.Tensor) -> torch.Tensor: 86 | # Element-wise operation with broadcasting 87 | p = torch.sum(self.ab * x.pow(self.jb), dim=0) 88 | q = torch.sum((self.bb * x.pow(self.kb)).abs(), dim=0) + 1.0 89 | return p/q 90 | 91 | 92 | # From paper's own code base: https://github.com/ml-research/rational_activations/blob/master/rational/torch/rationals.py 93 | # See also here for eq. details: https://arxiv.org/abs/1907.06732 94 | # https://rational-activations.readthedocs.io/en/latest/index.html 95 | def reference(self, x): 96 | # Rational_PYTORCH_A_F 97 | # P(X) / Q(X) = a_0 + a_1 * X + ... + a_n * X^n / 1 + | b_1 * X | + | b_2 * X^2| + ... 
+ | b_m * X ^m| 98 | 99 | weight_numerator = self.a 100 | weight_denominator = self.b 101 | 102 | z = x.view(-1) 103 | len_num, len_deno = len(weight_numerator), len(weight_denominator) 104 | # xps = torch.vander(z, max(len_num, len_deno), increasing=True) 105 | xps = self._get_xps(z, len_num, len_deno) 106 | numerator = xps.mul(weight_numerator).sum(1) 107 | expanded_dw = torch.cat([torch.tensor([1.]), weight_denominator, torch.zeros(len_num - len_deno - 1)]) 108 | denominator = xps.mul(expanded_dw).abs().sum(1) 109 | return numerator.div(denominator).view(x.shape) 110 | 111 | 112 | def _get_xps(self, z, len_numerator, len_denominator): 113 | xps = list() 114 | xps.append(z) 115 | for _ in range(max(len_numerator, len_denominator) - 2): 116 | xps.append(xps[-1].mul(z)) 117 | xps.insert(0, torch.ones_like(z)) 118 | return torch.stack(xps, 1) 119 | 120 | 121 | # https://github.com/ml-research/rational_activations/blob/master/rational/rationals_config.json 122 | def _presets(self, name): 123 | # presets assume m=6 and n=4 124 | 125 | if name == 'lrelu': 126 | # leaky_relu upperbound=3, lowerbound=-3 127 | num = torch.tensor([0.029792778657264946, 0.6183735264987601, 2.323309062531321, 3.051936237265109, 1.4854203263828845, 0.2510244961111299]) 128 | den = torch.tensor([-1.1419548357285474,4.393159974992486,0.8714712309957245, 0.34719662339598834]) 129 | elif name == 'tanh': 130 | # tanh, ub=3, lb=-3 131 | num = torch.tensor([-1.0804622559204184e-08,1.0003008043819048,-2.5878199375289335e-08,0.09632129918392647,3.4775841628196104e-09,0.0004255709234726337]) 132 | den = torch.tensor([-0.0013027181209176277,0.428349017422072,1.4524304083061898e-09,0.010796648111337176]) 133 | elif name == 'sigmoid': 134 | # sigmoid, ub=3, lb=-3 135 | num = torch.tensor([0.4999992534599381,0.25002157564685185,0.14061924500301096,0.049420492431596394,0.00876714851885483,0.0006442412789159799]) 136 | den = torch.tensor([2.1694506382753683e-09,0.28122766100417684,1.0123620714203357e-05,0.017531988049946]) 137 | elif name == 'gelu': 138 | # gelu, ub=3, lb=-3 139 | num = torch.tensor([-0.0012423594497499122,0.5080497063245629,0.41586363182937475,0.13022718688035761,0.024355900098993424,0.00290283948155535]) 140 | den = torch.tensor([-0.06675015696494944,0.17927646217001553,0.03746682605496631,1.6561610853276082e-10]) 141 | elif name == 'swish': 142 | # swish, ub=3, lb=-3 143 | num = torch.tensor([3.054879741161051e-07,0.5000007853744493,0.24999783422824703,0.05326628273219478,0.005803034571292244,0.0002751961022402342]) 144 | den = torch.tensor([-4.111554955950634e-06,0.10652899335007572,-1.2690007399796238e-06,0.0005502331264140556]) 145 | else: 146 | print("Error: No such preset for rational activations") 147 | exit() 148 | 149 | return num, den -------------------------------------------------------------------------------- /agents/buffers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ''' 4 | Replay buffers for agents 5 | – On specific device, so no transfers or syncs when on GPU 6 | – pytorch only, no other dependencies 7 | – easily extensible by adding functions 8 | 9 | Additionally ReplayBufferSAC features 10 | – shared memory for use with multiprocessing 11 | – ability to return a trace (trace is created, no duplicate data is stored in the buffer) 12 | 13 | ''' 14 | 15 | class ReplayBufferPPO(): 16 | """ 17 | Call reset or flush at the end of a update 18 | Must call .store_transition() to advance pointer 19 | """ 20 | def __init__(self, obs_dim, 
action_dim, num_steps, num_env, device): 21 | 22 | self.obs_dim = obs_dim 23 | self.action_dim = action_dim 24 | 25 | # Storage setup 26 | self.obs = torch.zeros((num_steps, num_env) + (obs_dim,), device=device, requires_grad=False) 27 | self.actions = torch.zeros((num_steps, num_env) + (action_dim,), device=device, requires_grad=False) 28 | self.logprobs = torch.zeros((num_steps, num_env), device=device, requires_grad=False) 29 | self.rewards = torch.zeros((num_steps, num_env), device=device) 30 | self.dones = torch.zeros((num_steps, num_env), device=device, dtype=torch.int8, requires_grad=False) 31 | 32 | # Size in steps 33 | self.max_size = num_steps 34 | self.ptr = 0 35 | 36 | def store_choice(self, obs, action, logprob): 37 | self.obs[self.ptr] = obs 38 | self.actions[self.ptr] = action 39 | self.logprobs[self.ptr] = logprob 40 | 41 | def store_transition(self, reward, done): 42 | self.rewards[self.ptr] = reward 43 | self.dones[self.ptr] = done 44 | self.ptr += 1 45 | 46 | def get_ppo_update(self): 47 | s = int(self.ptr) 48 | 49 | # Flatten the batch 50 | b_obs = self.obs[0:s].reshape((-1,) + (self.obs_dim,)) 51 | b_actions = self.actions[0:s].reshape((-1,) + (self.action_dim,)) 52 | b_logprobs = self.logprobs[0:s].reshape(-1) 53 | return b_obs, b_actions, b_logprobs 54 | 55 | def get_gae(self): 56 | s = int(self.ptr) 57 | 58 | # Don't flatten for GAE 59 | b_obs = self.obs[0:s] 60 | b_rewards = self.rewards[0:s] 61 | b_dones = self.dones[0:s] 62 | return b_obs, b_rewards, b_dones 63 | 64 | def get_obs(self): 65 | s = int(self.ptr) 66 | return self.obs[0:s].reshape((-1,) + (self.obs_dim,)) 67 | 68 | def reset(self): 69 | self.ptr = 0 70 | 71 | def flush(self): 72 | self.obs.zero_() 73 | self.actions.zero_() 74 | self.logprobs.zero_() 75 | self.rewards.zero_() 76 | self.dones.zero_() 77 | self.ptr = 0 78 | 79 | 80 | class ReplayBufferSAC(): 81 | """ 82 | Circular buffer with reset() and flush() 83 | Must call .store_transition() to advance pointer 84 | """ 85 | def __init__(self, obs_dim, action_dim, max_size, num_env, device): 86 | 87 | self.obs_dim = obs_dim 88 | self.action_dim = action_dim 89 | self.num_env = num_env 90 | self.device = device 91 | 92 | # Use .share_memory_() to allow multiprocessing processes access to the same buffer data 93 | self.obs = torch.zeros((max_size, num_env) + (obs_dim,), device=device, requires_grad=False) 94 | self.actions = torch.zeros((max_size, num_env) + (action_dim,), device=device, requires_grad=False) 95 | self.rewards = torch.zeros((max_size, num_env), device=device, requires_grad=False) 96 | self.dones = torch.zeros((max_size, num_env), device=device, dtype=torch.int8, requires_grad=False) 97 | self.ep_num = torch.zeros((max_size, num_env), device=device, dtype=torch.int32, requires_grad=False) 98 | 99 | # Counters and bookeeping. 
Tensors so that adding .share_memory() enables multiprocessing shared memory support 100 | self.max_size = torch.tensor(min(int(1e6),max_size), dtype=torch.int32, device=device, requires_grad=False) 101 | self.ptr = torch.tensor(0, dtype=torch.int32, device=device, requires_grad=False) 102 | self.size = torch.tensor(0, dtype=torch.int32, device=device, requires_grad=False) 103 | self.ep_count = torch.zeros((num_env), dtype=torch.int32, device=device, requires_grad=False) 104 | 105 | 106 | def store_choice(self, obs, action): 107 | self.obs[self.ptr] = obs # o0 –> B 108 | self.actions[self.ptr] = action # a0 –> B 109 | return 110 | 111 | 112 | def store_transition(self, reward, done): 113 | self.rewards[self.ptr] = reward # r0 -> B 114 | self.dones[self.ptr] = done # d0 -> B 115 | self.ep_count += done # episode count increments when an episode finishes 116 | 117 | # In-place operations maintains .shared_memory() if that's in use 118 | self.ptr.add_(1) # t0 -> t1 119 | self.ptr %= self.max_size 120 | self.size.add_(1) 121 | self.size.clamp_(min=torch.zeros_like(self.size, device=self.device), max=self.max_size) 122 | self.ep_num[self.ptr] = self.ep_count # store episode number 123 | return 124 | 125 | 126 | def get_trace(self, idx, length): 127 | # Create a trace: start at idx, of length = length 128 | trace = torch.arange(idx, idx-length, step=-1, device=self.device) # start at idx and count backward 129 | window = self.ep_num[trace, :] # section of buffer we will get and mask 130 | idx_ep = self.ep_num[idx] # episode number at idx requested 131 | mask = (window == idx_ep) # mask out episodes != ep_num at idx 132 | obs = self.obs[trace, :, :] # get the trace we want 133 | obs[~mask, :] = 0.0 # apply mask and zero out any data from different episodes 134 | 135 | # Now a trace of actions_prev, the actions that resulted in these obs 136 | trace -= 1 # actions leading to obs 137 | window = self.ep_num[trace, :] # section of buffer we will get and mask 138 | idx_ep = self.ep_num[idx] # episode number at idx requested 139 | mask = (window == idx_ep) # mask out episodes != ep_num at idx 140 | actions_prev = self.actions[trace, :, :] # get the trace we want 141 | actions_prev[~mask, :] = 0.0 # apply mask and zero out any data from different episodes 142 | 143 | # return (batch, channels, length) used by convnets. 
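        # obs and actions_prev are (length, num_env, dim) at this point; permute(1, 2, 0)
        # reorders them to (num_env, dim, length), i.e. (batch, channels, length),
        # so the step at the requested idx ends up at position 0 of the length dim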
144 | return obs.permute(1, 2, 0), actions_prev.permute(1, 2, 0) # newest data in trace at (:, :, 0) 145 | 146 | 147 | def sample_trace(self, batch_size, length): 148 | ''' 149 | Returns a batch for obs and actions of trace lenght with zeros where data is not from the same episode 150 | b_actions are the actions that caused the b_obs, for critic training 151 | b_obs_next are the next obs also for actor critic training 152 | b_actions_next are for critic training and provides trace for new next action from actor(b_obs_next) 153 | b_rewards and b_dones are len=1, not provided as a trace 154 | ''' 155 | assert batch_size % self.num_env == 0, 'batch_size must be divisible by num_env' 156 | 157 | #TODO: Can almost certainly remove some ops 158 | def trace_and_mask(samples): 159 | 160 | # Create traces starting at each index 161 | inds = samples.unsqueeze(1).repeat(1,length) # make a 2d array of the indices 162 | count = torch.arange(0, -length, step=-1, device=self.device) # prepare a trace, same for all inds 163 | inds += count.unsqueeze(0) # inds now 2d array with a number of traces 164 | inds_trace = inds.view(-1) # reshape back into a 1d array of traces 165 | 166 | # Prepare window, which is series of traces from the buffer 167 | window = self.ep_num[inds_trace] # get episode number at that buffer position 168 | window = window.transpose(dim0=1,dim1=0).flatten() # we want a 1d vector with envs ordered sequentially, not interleaved 169 | 170 | inds_ep = self.ep_num[samples] # get episode number at that buffer position. samples are just the start of the trace, not the whole trace 171 | inds_ep = inds_ep.transpose(dim0=1, dim1=0) # correcting so each env's data will be sequential and not interleaved 172 | inds_ep = inds_ep.reshape(batch_size, -1) # correcting so each env's data will be sequential and not interleaved 173 | inds_ep = inds_ep.repeat(1,length).flatten() # now copy the episode number across the whole trace , flatten into a 1D vector 174 | 175 | # Create the mask to remove data from different episodes in the trace 176 | mask = (window - inds_ep) == 0 # data from correct episodes will match == 0 # 177 | mask = mask.view(batch_size,length) # we want shape (batch_size, length) 178 | 179 | return mask, inds_trace 180 | 181 | # Sample random indices 182 | end = self.size - 1 # allow for obs_next 183 | start = 1 # allow for actions_prev 184 | samples = torch.randint(start, end, (batch_size // self.num_env,), device=self.device) #BUG: causes cuda<>cpu sync ? 
185 | 186 | # Make a mask and trace 187 | mask , inds_trace = trace_and_mask(samples) # ordinary samples 188 | mask_next, inds_trace_next = trace_and_mask(samples+1) # obs_next samples 189 | mask_prev, inds_trace_prev = trace_and_mask(samples-1) # action_prev samples 190 | 191 | # Get the obs and action data, re-arrange and mask 192 | b_obs = self.obs[inds_trace] 193 | b_obs.transpose_(dim0=1, dim1=0) 194 | b_obs = b_obs.reshape((batch_size, length) + (self.obs_dim,)) 195 | b_obs[~mask] = 0.0 196 | b_obs = b_obs.permute(0, 2, 1) # shape (batch, channels, length) 197 | 198 | b_actions = self.actions[inds_trace] 199 | b_actions.transpose_(dim0=1, dim1=0) 200 | b_actions = b_actions.reshape((-1,length) + (self.action_dim,)) 201 | b_actions[~mask] = 0.0 202 | b_actions = b_actions.permute(0, 2, 1) 203 | 204 | # Now do obs_next 205 | b_obs_next = self.obs[inds_trace_next] 206 | b_obs_next.transpose_(dim0=1, dim1=0) 207 | b_obs_next = b_obs_next.reshape((-1,length) + (self.obs_dim,)) 208 | b_obs_next[~mask_next] = 0.0 209 | b_obs_next = b_obs_next.permute(0, 2, 1) 210 | 211 | # Now do action_next 212 | b_actions_next = self.actions[inds_trace_next] 213 | b_actions_next.transpose_(dim0=1, dim1=0) 214 | b_actions_next = b_actions_next.reshape((-1,length) + (self.action_dim,)) 215 | b_actions_next[~mask_next] = 0.0 216 | b_actions_next = b_actions_next.permute(0, 2, 1) 217 | 218 | # Now do action_prev 219 | b_actions_prev = self.actions[inds_trace_prev] 220 | b_actions_prev.transpose_(dim0=1, dim1=0) 221 | b_actions_prev = b_actions_prev.reshape((-1,length) + (self.action_dim,)) 222 | b_actions_prev[~mask_prev] = 0.0 223 | b_actions_prev = b_actions_prev.permute(0, 2, 1) 224 | 225 | # No trace for these two but re-arranging needed for correct sequencing with obs and action 226 | b_rewards = self.rewards[samples].transpose(dim0=1,dim1=0).reshape(-1) 227 | b_dones = self.dones[samples].transpose(dim0=1,dim1=0).reshape(-1) 228 | 229 | # return osb and actions with shape (batch, channels, length) or (batch) for rewards and dones 230 | return b_obs, b_actions, b_obs_next, b_actions_next, b_actions_prev, b_rewards, b_dones # newest data in trace at (:, :, 0) 231 | 232 | 233 | def sample(self, batch_size, ere_bias=False): 234 | '''' For training the critic. 
Actions returned are the ones taken to cause the Observation ''' 235 | 236 | # Faster, but no uniqueness guarantee like inds = torch.randperm(s) 237 | end = self.size - 1 # allow for obs_next 238 | 239 | # BUG: conditionals probably causing cuda<>cpu syncs, optimise at some point 240 | # Emphasise Recent Experience bias https://arxiv.org/abs/1906.04009 241 | if ere_bias: 242 | start = torch.randint(0, end // 2, (1,), device=self.device) 243 | else: 244 | start = 0 # no bias, uniform sampling of buffer 245 | 246 | samples = torch.randint(start, end, (batch_size // self.num_env,), device=self.device) # on correct device to avoid cuda-cpu synchronisation 247 | 248 | # When biased account for circular buffer 249 | if ere_bias: 250 | if self.ptr < self.size: samples = (samples + self.ptr) % self.size # will force a cuda <> cpu sync 251 | 252 | # Flatten the batch (global_step, env_num, channels) -> (b_step, channels) 253 | b_obs = self.obs[samples].reshape((-1,) + (self.obs_dim,)) 254 | b_actions = self.actions[samples].reshape((-1,) + (self.action_dim,)) 255 | b_obs_next = self.obs[samples + 1].reshape((-1,) + (self.obs_dim,)) # Sometimes obs_next will be step zero of next episode, but ok for SAC 256 | b_rewards = self.rewards[samples].reshape(-1) 257 | b_dones = self.dones[samples].reshape(-1) 258 | return b_obs, b_actions, b_obs_next, b_rewards, b_dones 259 | 260 | def plasticity_data(self, batch_size): 261 | inds = torch.randint(0, self.size, (batch_size // self.num_env,), device=self.device) 262 | b_obs = self.obs[inds].reshape((-1,) + (self.obs_dim,)) 263 | b_actions = self.actions[inds].reshape((-1,) + (self.action_dim,)) 264 | return b_obs, b_actions 265 | 266 | def reset(self): 267 | self.ptr = 0 268 | self.size = 0 269 | return 270 | 271 | def flush(self): 272 | self.obs.zero_() 273 | self.actions.zero_() 274 | self.rewards.zero_() 275 | self.dones.zero_() 276 | self.ep_num.zero_() 277 | self.ptr = 0 278 | self.size = 0 279 | return -------------------------------------------------------------------------------- /agents/crossq_cem.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from types import SimpleNamespace 6 | from .utils import avg_weight_magnitude, count_dead_units, dormant_ratio 7 | from .buffers import ReplayBufferSAC 8 | 9 | ''' 10 | Based on https://github.com/vwxyzjn/cleanrl 11 | - SAC agent https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py 12 | - CrossQ extensiopns to SAC, replacing actor with cross entropy method action selection 13 | 14 | CrossQ extensions to SAC influenced by: 15 | - CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity https://arxiv.org/abs/1902.05605 16 | - v3 of the arxiv paper (2023) 17 | 18 | Cross Entropy Method Action Optimisation influnced by 19 | - QT-Opt: Scalable Deep Reinforcement Learning for Vision-Based Robotic Manipulation https://arxiv.org/abs/1806.10293 20 | – TD-MPC2: Scalable, Robust World Models for Continuous Control https://arxiv.org/abs/2310.16828 21 | - iCEM: https://arxiv.org/abs/2008.06389 22 | 23 | Plasticity metrics and regularisation influenced by the following papers: 24 | - Maintaining Plasticity in Continual Learning via Regenerative Regularization https://arxiv.org/abs/2308.11958 25 | - Loss of Plasticity in Deep Continual Learning https://arxiv.org/abs/2306.13812 26 | - Sample-Efficient Reinforcement 
Learning by Breaking the Replay Ratio Barrier https://openreview.net/forum?id=OpC-9aBBVJe 27 | – Bigger, Better, Faster: Human-level Atari with human-level efficiency https://arxiv.org/abs/2305.19452 28 | ''' 29 | 30 | class SoftQNetwork(nn.Module): 31 | def __init__(self, obs_dim, action_dim, hidden_dim=256): 32 | super().__init__() 33 | 34 | self.mlp = nn.Sequential( 35 | nn.Linear(obs_dim + action_dim, hidden_dim, bias=False), # batchnorm has bias, this one is redundant 36 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 37 | nn.ReLU(), 38 | nn.Linear(hidden_dim,hidden_dim, bias=False), 39 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 40 | nn.ReLU(), 41 | nn.Linear(hidden_dim, 1) 42 | ) 43 | 44 | # init_model(self.mlp, init_method='xavier_uniform_') # CrossQ no mention of init 45 | 46 | def forward(self, x, a): 47 | x = torch.cat([x, a], 1) 48 | x = self.mlp(x) 49 | return x 50 | 51 | 52 | class Agent: 53 | def __init__(self, 54 | env_spec, 55 | buffer_size = int(1e6), 56 | num_env = 1, 57 | device = 'cpu', 58 | seed = 42, 59 | rr = 1, # RR = 1 for CrossQ 60 | q_lr = 1e-3, # CrossQ learning rates 61 | ): 62 | 63 | # Make global 64 | self.name = "crossq_cem" # name for logging 65 | self.obs_dim = env_spec['obs_dim'] # environment inputs for agent 66 | self.action_dim = env_spec['act_dim'] # agent outputs to environment 67 | self.act_max = env_spec['act_max'] # action range, scalar or vector 68 | self.act_min = env_spec['act_min'] # action range, scalar or vector 69 | self.device = device # gpu or cpu 70 | 71 | self.action_scale = torch.tensor(((self.act_max - self.act_min) * 0.5), device=device) 72 | self.action_bias = torch.tensor(((self.act_max + self.act_min) * 0.5), device=device) 73 | self.act_min = torch.tensor(self.act_min, device=device) 74 | self.act_max = torch.tensor(self.act_max, device=device) 75 | 76 | # All seeds default to 42 77 | torch.manual_seed(torch.tensor(seed)) 78 | torch.backends.cudnn.deterministic = True 79 | # torch.cuda.set_sync_debug_mode(1) # Set to 1 to receive warnings 80 | # torch.set_float32_matmul_precision("high") # "high" is 11% faster, but can reduce learning performance in certain envs 81 | 82 | # Hyperparameters 83 | hyperparameters = { 84 | "gamma" : 0.99, # (def: 0.99) Discount factor 85 | "q_lr" : q_lr, # (def: 1e-3) Q learning rate 86 | "learn_start" : int(5e3), # (def: 5e3) Start updating policies after this many global steps 87 | "batch_size" : 256, # (def: 256) Batch size of sample from replay buffer 88 | "dead_hurdle" : 0.001, # (def: 0.001) units with greater variation in output over one batch of data than this are not dead in plasticity terms 89 | "q_hidden_dim" : 512, # (def: 2048) CrossQ with 512 wide Qf did just as well, but with a little more variance 90 | "replay_ratio" : round(rr), 91 | "adam_betas" : (0.5, 0.999), # CrossQ 92 | } 93 | self.h = SimpleNamespace(**hyperparameters) 94 | 95 | # Loggin & debugging 96 | self.qf1_a_values = torch.tensor([0.0]) 97 | self.qf2_a_values = torch.tensor([0.0]) 98 | self.qf1_loss = 0 99 | self.qf2_loss = 0 100 | self.qf_loss = 0 101 | self.actor_loss = 0 102 | self.alpha_loss = 0 103 | self.actor_avg_wgt_mag = 0 # average weight magnitude as per https://arxiv.org/abs/2306.13812 104 | self.qf1_avg_wgt_mag = 0 105 | self.qf2_avg_wgt_mag = 0 106 | self.actor_dead_pct = 0 # dead units as per https://arxiv.org/abs/2306.13812 107 | self.qf1_dead_pct = 0 108 | self.qf2_dead_pct = 0 109 | self.qf1_dormant_ratio = 0 # DrM: Dormant Ratio Minimisation https://arxiv.org/abs/2310.19668 110 | 
self.qf2_dormant_ratio = 0 111 | self.actor_dormant_ratio = 0 112 | 113 | # Instantiate actor and Q networks, optimisers 114 | # CrossQ uses Adam but experience with AdamW is better 115 | self.qf1 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 116 | self.qf2 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 117 | self.q_optim = torch.optim.AdamW(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.h.q_lr, betas=self.h.adam_betas) 118 | 119 | # Storage setup 120 | self.rb = ReplayBufferSAC(self.obs_dim, self.action_dim, buffer_size, num_env, device=self.device) 121 | self.global_step = 0 122 | self.action_prev = None 123 | 124 | # CUDA timers for the update process 125 | self.chronos_start = torch.cuda.Event(enable_timing=True) 126 | self.chronos_end = torch.cuda.Event(enable_timing=True) 127 | 128 | 129 | def ce_action_solver(self, obs, explore=False): 130 | with torch.no_grad(): 131 | 132 | # Hyperparameters for CEM 133 | samples = 32 * self.action_dim # TODO: Starting wider and reducing sample numbers likely more efficient and effective 134 | topk = (samples // 8) 135 | iterations = 6 136 | converged_stdd = 0.025 137 | explore_std = 0.5 # TODO: Should reduce on schedule or as learning improves 138 | stdd_min = torch.tensor(0.01, device=self.device) 139 | 140 | batch_size = obs.shape[0] # input shape assumes (batch, channels) 141 | self.qf1.eval() # don't mess with crossq batchnorm layers during inference 142 | self.qf2.eval() 143 | 144 | 145 | # large stdd on first action otherwise smaller stdd to converge faster (assuming solution will be nearby) 146 | if self.action_prev is None: 147 | iter1_stdd = 1.0 148 | top_mean = torch.tile(torch.zeros(self.action_dim, device=self.device) + self.action_bias, (batch_size, 1)) 149 | else: 150 | iter1_stdd = 0.25 151 | top_mean = torch.tile(self.action_prev, (batch_size, 1)) 152 | 153 | # Important that we don't start optimising outside action range 154 | top_mean.clamp_(self.act_min, self.act_max) 155 | 156 | # action scaling is applied to stdd so it covers the action range correctly 157 | top_stdd = torch.ones((batch_size, self.action_dim), device=self.device) * iter1_stdd * self.action_scale # (batch, action_dim) 158 | 159 | for i in range(iterations): 160 | 161 | # Expand obs batch by tile of size samples 162 | obs_tiled = torch.tile(obs,(samples,1)) # (samples*batch_dim, channels) 163 | 164 | # Sample actions from normal distribution in as many batches as there are in obs 165 | #TODO: actions range should be clamped, make sure applying clamp does not break things 166 | actions_tiled = torch.normal(mean=torch.tile(top_mean, (samples,1)), std=torch.tile(top_stdd, (samples,1))) # (samples*batch_dim, action_dim) 167 | actions_tiled.clamp_(self.act_min, self.act_max) 168 | 169 | # Value each sampled action and get top (elite) indices 170 | v1 = self.qf1(obs_tiled, actions_tiled) # (samples*batch_size) 171 | v2 = self.qf2(obs_tiled, actions_tiled) # (samples*batch_size) 172 | v = torch.min(v1, v2).view(samples,batch_size) # (samples, batch_size) 173 | _, k = torch.topk(v, k=topk, dim=0) # (topk, batch_size) 174 | 175 | # Get mean and standard deviation of the top (elite) actions 176 | actions_view = actions_tiled.view(samples,batch_size, -1) # (samples, batch_size, action_dim) 177 | top_actions = actions_view[k,range(actions_view.shape[1]),:] # (topk, batch_size, action_dim) 178 | 179 | top_stdd, top_mean = torch.std_mean(top_actions, dim=0) # (batch_size, action_dim) 180 | 
top_stdd.clamp_min_(stdd_min) # clamp stdd to a min (also normal(std≤0) crashes 181 | 182 | # Break when cpnmverged 183 | if (torch.mean(top_stdd) < converged_stdd): 184 | break 185 | 186 | # Store it for next time if acting. Will be reset on new episode 187 | if explore: 188 | self.action_prev = top_mean # important to store before applying noise 189 | top_mean += torch.normal(mean=top_mean, std=explore_std) 190 | 191 | self.qf1.train() 192 | self.qf2.train() 193 | 194 | return top_mean 195 | 196 | 197 | def choose_action(self, obs): 198 | # Random uniform actions before learn_start can speed up training over using the agent's inital randomness. 199 | if self.global_step < self.h.learn_start: 200 | # actions are rand_uniform of shape (obs_batch_size, action_dim) 201 | action = (torch.rand((obs.size(0), self.action_dim), device=self.device) - 0.5) * 2.0 # rand_uniform -1..+1 202 | action = action * self.action_scale + self.action_bias # apply scale and bias 203 | else: 204 | action = self.ce_action_solver(obs, explore=True) # output is scaled and biased 205 | 206 | self.rb.store_choice(obs, action) 207 | 208 | return action 209 | 210 | def store_transition(self, reward, done): 211 | self.rb.store_transition(reward, done) 212 | self.global_step += 1 213 | 214 | if done: 215 | self.action_prev = None 216 | 217 | 218 | def update(self): 219 | 220 | ''' Call every step from learning script, agent decides if it is time to update ''' 221 | 222 | # Bookeeping 223 | updated = False 224 | chronos_total = 0.0 225 | 226 | if self.global_step > self.h.learn_start: 227 | updated = True 228 | self.chronos_start.record() 229 | 230 | for replay in range(0, self.h.replay_ratio): 231 | 232 | b_obs, b_actions, b_obs_next, b_rewards, b_dones = self.rb.sample(self.h.batch_size) 233 | 234 | with torch.no_grad(): 235 | next_state_actions = self.ce_action_solver(b_obs_next) 236 | 237 | bb_obs = torch.cat((b_obs, b_obs_next), dim=0) 238 | bb_acts = torch.cat((b_actions, next_state_actions), dim=0) 239 | 240 | bb_q1 = self.qf1(bb_obs, bb_acts) 241 | bb_q2 = self.qf2(bb_obs, bb_acts) 242 | 243 | b_q1, b_q1_next = torch.chunk(bb_q1, chunks=2, dim=0) 244 | b_q2, b_q2_next = torch.chunk(bb_q2, chunks=2, dim=0) 245 | self.qf1_a_values = b_q1 # mean of this is used in logging 246 | self.qf2_a_values = b_q2 # mean of this is used in logging 247 | 248 | min_q_next = torch.min(b_q1_next, b_q2_next) 249 | next_q_value = b_rewards.flatten() + (1 - b_dones.flatten()) * self.h.gamma * (min_q_next).view(-1) 250 | torch.detach_(next_q_value) # no gradients through here 251 | 252 | self.qf1_loss = F.mse_loss(b_q1.flatten(), next_q_value) 253 | self.qf2_loss = F.mse_loss(b_q2.flatten(), next_q_value) 254 | self.qf_loss = self.qf1_loss + self.qf2_loss 255 | 256 | self.q_optim.zero_grad() 257 | self.qf_loss.backward() 258 | self.q_optim.step() 259 | 260 | # Plasticity metrics occasionally 261 | if self.global_step % 2048 == 0 or self.global_step == self.h.learn_start: 262 | self.qf1_avg_wgt_mag = avg_weight_magnitude(self.qf1) 263 | self.qf2_avg_wgt_mag = avg_weight_magnitude(self.qf2) 264 | 265 | b_obs, b_actions = self.rb.plasticity_data(2048) # a representative sample 266 | _, _, self.qf1_dead_pct = count_dead_units(self.qf1, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 267 | _, _, self.qf2_dead_pct = count_dead_units(self.qf2, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 268 | 269 | self.qf1_dormant_ratio = dormant_ratio(self.qf1, in1=b_obs, in2=b_actions) 270 | self.qf2_dormant_ratio = 
dormant_ratio(self.qf2, in1=b_obs, in2=b_actions) 271 | 272 | # Record end time, wait for all cuda threads to sync and calc time in seconds 273 | self.chronos_end.record() 274 | torch.cuda.synchronize() 275 | chronos_total = (self.chronos_start.elapsed_time(self.chronos_end) * 0.001) 276 | 277 | return updated, chronos_total 278 | 279 | 280 | def save(self, path='./checkpoints/'): 281 | 282 | path = path + self.name + '.pt' 283 | models = {} 284 | 285 | models['Q1'] = self.qf1.state_dict() 286 | models['Q2'] = self.qf2.state_dict() 287 | 288 | torch.save(models, path) 289 | 290 | def load(self, path='./checkpoints/'): 291 | path = path + self.name + '.pt' 292 | models_file = torch.load(path) 293 | 294 | self.qf1.load_state_dict(models_file['Q1']) 295 | self.qf2.load_state_dict(models_file['Q2']) -------------------------------------------------------------------------------- /agents/normalise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import time 4 | 5 | """ 6 | Utils for normalizing observations 7 | based on https://github.com/openai/gym/blob/master/gym/wrappers/normalize.py 8 | """ 9 | 10 | # From gymnasium normalise wrapper 11 | class Normalise(): 12 | """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. 13 | Note: 14 | The normalization depends on past trajectories 15 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 16 | """ 17 | 18 | def __init__(self, input_dim, epsilon=1e-8): 19 | """ epsilon: A stability parameter that is used when scaling the observations """ 20 | self.epsilon = epsilon 21 | 22 | """Tracks the mean, variance and count of values.""" 23 | self.mean = np.zeros(input_dim, "float64") 24 | self.var = np.ones(input_dim, "float64") 25 | self.count = 1e-4 26 | 27 | def update(self, x): 28 | """Updates the mean, var and count using the previous mean, var, count and batch values.""" 29 | 30 | batch_mean = np.mean(x, axis=0) 31 | batch_var = np.var(x, axis=0) 32 | batch_count = x.shape[0] 33 | 34 | delta = batch_mean - self.mean 35 | tot_count = self.count + batch_count 36 | 37 | new_mean = self.mean + delta * batch_count / tot_count 38 | m_a = self.var * self.count 39 | m_b = batch_var * batch_count 40 | M2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count 41 | new_var = M2 / tot_count 42 | new_count = tot_count 43 | 44 | self.mean = new_mean 45 | self.var = new_var 46 | self.count = new_count 47 | 48 | def new(self, input): 49 | """Normalises new data""" 50 | input = np.array(input) 51 | input = np.expand_dims(input, axis=0) # new fix 52 | self.update(input) 53 | norm_obs = (input - self.mean) / np.sqrt(self.var + self.epsilon) 54 | norm_obs = np.squeeze(norm_obs) # new fix 55 | return np.float32(norm_obs) 56 | 57 | 58 | # Same as Normalise but slightly fewer calls, improving samples per second performance 59 | class NormaliseFast(): 60 | def __init__(self, input_dim, epsilon=1e-8): 61 | self.epsilon = epsilon 62 | 63 | self.mean = np.zeros(input_dim, "float64") 64 | self.var = np.ones(input_dim, "float64") 65 | self.count = 1e-4 66 | 67 | def new(self, input): 68 | input = np.array(input) 69 | 70 | delta = input - self.mean 71 | tot_count = self.count + 1 72 | 73 | M2 = (self.var * self.count) + np.square(delta) * self.count / tot_count 74 | 75 | self.mean = self.mean + delta / tot_count 76 | self.var = M2 / tot_count 77 | self.count += 1 78 | 79 | norm_obs = (input - self.mean) / 
np.sqrt(self.var + self.epsilon) 80 | return np.float32(norm_obs) 81 | 82 | 83 | # Same as Normalise, replacing numpy with torch 84 | class NormaliseTorch(): 85 | def __init__(self, input_dim, epsilon=1e-8, device='cpu'): 86 | self.device = device 87 | 88 | self.epsilon = th.tensor(epsilon, device=device) 89 | 90 | self.mean = th.zeros(1, input_dim, dtype=th.float64, device=device) 91 | self.var = th.ones(1, input_dim, dtype=th.float64, device=device) 92 | self.count = th.tensor(1e-4, device=device) 93 | 94 | def new(self, input): 95 | delta = input - self.mean 96 | tot_count = self.count + th.ones(1, device=self.device) 97 | 98 | M2 = (self.var * self.count) + th.square(delta) * self.count / tot_count 99 | 100 | self.mean = self.mean + delta / tot_count 101 | self.var = M2 / tot_count 102 | self.count = self.count + th.ones(1, device=self.device) 103 | 104 | norm_obs = (input - self.mean) / th.sqrt(self.var + self.epsilon) 105 | return norm_obs.to(dtype=th.float32) 106 | 107 | 108 | # GPT4 torchscript re-write of NormaliseTorch for better performance 109 | class NormaliseTorchScript(th.nn.Module): 110 | def __init__(self, input_dim, epsilon=1e-8, device='cpu'): 111 | super(NormaliseTorchScript, self).__init__() 112 | self.device = device 113 | self.input_dim = input_dim 114 | self.epsilon = epsilon 115 | 116 | self.mean = th.zeros(1, input_dim, dtype=th.float64, device=device) 117 | self.var = th.ones(1, input_dim, dtype=th.float64, device=device) 118 | self.count = th.ones(1, device=device) * 1e-4 # ensure count has shape [1] 119 | self.one = th.ones(1, device=self.device) 120 | 121 | @th.jit.export 122 | def new(self, input): 123 | delta = input - self.mean 124 | 125 | tot_count = self.count + self.one 126 | M2 = (self.var * self.count) + (delta ** 2) * self.count / tot_count 127 | 128 | self.mean = self.mean + delta / tot_count 129 | self.var = M2 / tot_count 130 | self.count = self.count + self.one 131 | 132 | norm_obs = (input - self.mean) / th.sqrt(self.var + self.epsilon) 133 | return norm_obs.to(dtype=th.float32) 134 | 135 | 136 | 137 | # For performance comparison 138 | @th.jit.script 139 | def symlog(x): 140 | # Element-wise symlog mapping 141 | return th.sign(x) * th.log(th.abs(x) + 1.0) 142 | 143 | 144 | def test_normalisations(): 145 | DEVICE = 'cpu' 146 | old = Normalise(2) 147 | new = NormaliseFast(2) 148 | tor = th.jit.script(NormaliseTorchScript(2, device=DEVICE)) 149 | 150 | # are they identical results? 
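    # the loop below feeds the same point to each implementation and prints the
    # element-wise differences, which should be ~0 if the implementations agree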
151 | for i in range(100): 152 | o = old.new((1, i**2)) 153 | n = new.new((1, i**2)) 154 | data = th.tensor((1, i**2), device=DEVICE) 155 | t = tor.new(data) 156 | print("Deltas: ",o - n," ",o - t.cpu().detach().numpy()) 157 | 158 | # Compare performance 159 | cycles = int(100e3) 160 | data = np.ndarray((1,2)) 161 | 162 | start_time = time.time() 163 | for i in range(cycles): 164 | x = old.new(data) 165 | norm1_time = (time.time() - start_time) 166 | print("done 1") 167 | 168 | start_time = time.time() 169 | for i in range(cycles): 170 | x = new.new(data) 171 | norm2_time = (time.time() - start_time) 172 | print("done 2") 173 | 174 | data = th.tensor((1, 2), device=DEVICE) 175 | 176 | start_time = time.time() 177 | for i in range(cycles): 178 | x = tor.new(data) 179 | norm3_time = (time.time() - start_time) 180 | print("done 3") 181 | 182 | start_time = time.time() 183 | for i in range(cycles): 184 | x = symlog(data) 185 | norm4_time = (time.time() - start_time) 186 | print("done 4") 187 | 188 | print("Seconds per (100k) ops for old: ", norm1_time, "new: ", norm2_time, "torch: ", norm3_time, "symlog: ", norm4_time) 189 | 190 | ########################## 191 | if __name__ == '__main__': 192 | test_normalisations() -------------------------------------------------------------------------------- /agents/ppo_baseline.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from types import SimpleNamespace 7 | from torch.distributions.normal import Normal 8 | from .normalise import NormaliseTorchScript 9 | from .utils import init_layer, symlog, avg_weight_magnitude, count_dead_units 10 | from .buffers import ReplayBufferPPO 11 | 12 | ''' 13 | Based on https://github.com/vwxyzjn/cleanrl 14 | PPO agent with RPO option https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action.py 15 | Note: CleanRL version uses gym.wrappers.NormalizeReward(), which materially improves performance in some environments. 16 | ''' 17 | 18 | 19 | class Actor(nn.Module): 20 | def __init__(self, obs_dim, action_dim, hidden_dim=64, rpo_alpha=0.0, device='cpu'): 21 | super(Actor, self).__init__() 22 | ''' agent input assumes shape: (batch, channels)''' 23 | 24 | self.device = device 25 | 26 | # Actor network 27 | self.fc0 = init_layer(nn.Linear(obs_dim, hidden_dim)) 28 | self.fc1 = init_layer(nn.Linear(hidden_dim, hidden_dim)) 29 | self.out = init_layer(nn.Linear(hidden_dim, action_dim), std=0.01) # Last layer init near zero (C57, https://arxiv.org/abs/2006.05990) 30 | self.nonlin = nn.Tanh() # tanh preferred (C55, https://arxiv.org/abs/2006.05990) 31 | 32 | # ReZero for deep networks https://arxiv.org/abs/2003.04887 33 | # Worse performance for such small nns 34 | # self.r0 = nn.Parameter(torch.zeros(1)) 35 | 36 | # Adds stochasticity to action https://arxiv.org/abs/2212.07536 37 | # "rpo_alpha=0.5 –> better than PPO 93% of environments, rpo_alpha=0.01 –> better in 100%" 38 | # https://docs.cleanrl.dev/rl-algorithms/rpo/#implementation-details 39 | self.rpo_alpha = rpo_alpha 40 | 41 | # Actor logstd (initial standard dev = 0.5 https://arxiv.org/abs/2006.05990) 42 | self.logstd = nn.Parameter(torch.ones(1, action_dim) * math.log(0.5)) 43 | 44 | def forward(self, obs, action=None): 45 | x = self.nonlin(self.fc0(obs)) 46 | x = self.nonlin(self.fc1(x)) 47 | action_mean = self.out(x) 48 | 49 | # expand to match shape of action_mean (e.g. 
batch dim) 50 | action_logstd = self.logstd.expand_as(action_mean) 51 | action_std = torch.exp(action_logstd) 52 | 53 | if action is None: 54 | probs = Normal(action_mean, action_std, validate_args=False) 55 | action = probs.rsample() 56 | else: # RPO option 57 | z = (torch.rand((action_mean.shape), device=self.device) - 0.5) * 2.0 * self.rpo_alpha # z = -rpo..+rpo 58 | action_mean = action_mean + z 59 | probs = Normal(action_mean, action_std, validate_args=False) 60 | 61 | log_prob = probs.log_prob(action).sum(1) 62 | entropy = probs.entropy().sum(1) # important: is there a lenght dim to consider in sum axis? 63 | 64 | # Consider environment action ranges, clip where appropriate or use tanh for ±1 if necessary 65 | return action, log_prob, entropy 66 | 67 | 68 | class Critic(nn.Module): 69 | def __init__(self, obs_dim, hidden_dim=64): 70 | super(Critic, self).__init__() 71 | 72 | # Critic network (C47, independent critic performs better https://arxiv.org/abs/2006.05990) 73 | # Wider than actor preferred for critic (~4x) (https://arxiv.org/abs/2006.05990) 74 | self.fc0 = init_layer(nn.Linear(obs_dim, hidden_dim)) 75 | self.fc1 = init_layer(nn.Linear(hidden_dim, hidden_dim)) 76 | self.out = init_layer(nn.Linear(hidden_dim, 1), std=1.0) # Last layer init near one (C57, https://arxiv.org/abs/2006.05990) 77 | self.nonlin = nn.Tanh() 78 | 79 | # ReZero for deep networks https://arxiv.org/abs/2003.04887 80 | # Worse performance when networks are small 81 | # self.r0 = nn.Parameter(torch.zeros(1)) 82 | 83 | def forward(self, obs): 84 | x = self.nonlin(self.fc0(obs)) 85 | x = self.nonlin(self.fc1(x)) 86 | v = self.out(x) 87 | return v 88 | 89 | 90 | class Agent: 91 | def __init__(self, env_spec, buffer_size, num_env=1, device='cpu', seed=42): 92 | 93 | # Make global 94 | self.name = "ppo_baseline" # name for logging 95 | self.obs_dim = env_spec['obs_dim'] # environment inputs for agent 96 | self.action_dim = env_spec['act_dim'] # agent outputs to environment 97 | self.act_max = env_spec['act_max'] # action range, scalar or vector 98 | self.act_min = env_spec['act_min'] # action range, scalar or vector 99 | self.device = device # gpu or cpu 100 | 101 | # All seeds default to 42 102 | torch.manual_seed(torch.tensor(seed)) 103 | torch.backends.cudnn.deterministic = True 104 | torch.cuda.set_sync_debug_mode(1) # Set to 1 to receive warnings 105 | 106 | 107 | # Hyperparameters 108 | hyperparameters = { 109 | "eps_clip" : 0.2, # (def: 0.2) clip parameter for PPO 110 | "gamma" : 0.99, # (def: 0.99) Key parameter should be tuned for each environment https://arxiv.org/abs/2006.05990 (C20) 111 | "gae_lambda" : 0.95, # (def: 0.95) the lambda for the general advantage estimation 112 | "clip_coef" : 0.2, # (def: 0.2) try 0.1 to 0.5 depending on environment (https://arxiv.org/abs/2006.05990) 113 | "ent_coef" : 0.0, # (def: 0.0) coefficient of the entropy. 0.01 is better for WalkerHardcore. 114 | "max_grad_norm" : 0.5, # (def: 0.5) the maximum norm for the gradient clipping 115 | "max_kl" : None, # (def: 0.02) skip actor minibatch update if target exceeded. 
approx_kl generally < 0.02 when algo is working well 116 | "adam_lr" : 0.0003, # (def: 0.0003) Adam optimiser learning rate 0.0003 "safe default" but tuning recommeneded https://arxiv.org/abs/2006.05990 117 | "adam_eps" : 1e-5, # (def: 1e-7) Adam optimiser epsilon 118 | "weight_decay" : 0.0, # (def: 0.0) AdamW weight decay for regularisation (AdamW >> Adam) 119 | "norm_adv" : False, # (def: False) Normalise advantage of each batch (note not minibatch like CleanRL, lost source) 120 | "rpo_alpha" : 0.0, # (def: 0.0) In Box2D and Mujoco Gym environments a value of 0.5 was found to be worse. Perhaps due to differences between this and CleanRL's version. 121 | "gae_recalc" : False, # (def: False) recalculate GAE in each update epoch 122 | "update_epochs" : 10, # (def: 10) the K epochs to update the policy 123 | "mb_size" : 64, # (def: 64) the size of mini batches. CleanRL multiplies this by num_envs when vectorised 124 | "update_step" : 2048, # (def: 2048) perform update after this many environmental steps 125 | "dead_hurdle" : 0.001, # (def: 0.001) units with greater variation in output over one batch of data than this are not dead in plasticity terms 126 | "a_hidden_dim" : 64, # (def: 64) actor's hidden layers dim 127 | "c_hidden_dim" : 64, # (def: 64) critic's hidden layers dim 128 | } 129 | self.h = SimpleNamespace(**hyperparameters) 130 | 131 | # Loggin & debugging 132 | self.approx_kl = 0 # estimated kl divergence 133 | self.clipfracs = 0 # fraction of training data that triggered clipping objective 134 | self.p_loss = 0 # policy/actor loss 135 | self.v_loss = 0 # value/critic loss 136 | self.entropy_loss = 0 # provides entropy bonus through ent_coeff parameter 137 | self.ppo_updates = 0 # number of minibatch updates performed 138 | self.actor_grad_norm = 0 # actor model gradient norm 139 | self.critic_grad_norm = 0 # critic model gradient norm 140 | self.actor_avg_wgt_mag = 0 # average weight magnitude of model parameters 141 | self.critic_avg_wgt_mag = 0 # average weight magnitude of model parameters 142 | self.actor_dead_pct = 0 # percentage of units which are dead by some threshold 143 | self.critic_dead_pct = 0 # percentage of units which are dead by some threshold 144 | self.chronos_total = 0.0 # time taken for training update (use torch.cuda.Event() for more accurate gpu timing) 145 | 146 | # Buffer only needs to be as large as update_step, so replay_buffer is redundand and kept for api compatibility 147 | self.rb = ReplayBufferPPO(self.obs_dim, self.action_dim, self.h.update_step, num_env, device=self.device) 148 | self.global_step = 0 149 | 150 | # Normalise state observations. 
Use symlog for reward "normalisation" – experimentally best results on Box2D and Mujoco 151 | # https://arxiv.org/pdf/2006.05990.pdf (C64) and https://arxiv.org/pdf/2005.12729.pdf 152 | self.normalise_observations = torch.jit.script(NormaliseTorchScript(self.obs_dim, num_env, device=self.device)) 153 | # self.normalise_rewards = torch.jit.script(NormaliseTorchScript(1, num_env, device=self.device)) 154 | 155 | # Instantiate actor and critic networks, same optimiser parameters 156 | self.actor = Actor(self.obs_dim, self.action_dim, hidden_dim=self.h.a_hidden_dim, rpo_alpha=self.h.rpo_alpha, device=device).to(self.device) 157 | self.optim_a = torch.optim.AdamW(self.actor.parameters(), lr=self.h.adam_lr, eps=self.h.adam_eps, weight_decay=self.h.weight_decay) 158 | 159 | self.critic = Critic(self.obs_dim, hidden_dim=self.h.c_hidden_dim).to(self.device) 160 | self.optim_c = torch.optim.AdamW(self.critic.parameters(), lr=self.h.adam_lr, eps=self.h.adam_eps, weight_decay=self.h.weight_decay) 161 | 162 | 163 | # Values from environments must be pytorch tensors of shape (batch, channels) 164 | def choose_action(self, obs): 165 | obs = self.normalise_observations.new(obs) # normalise better than symlog (or none) for obs 166 | 167 | with torch.no_grad(): 168 | action, logprob, _ = self.actor(obs) 169 | 170 | self.rb.store_choice(obs, action, logprob) 171 | return action # return shape is also (batch, channels) 172 | 173 | def store_transition(self, reward, done): 174 | reward = symlog(reward) #symlog better than Normalise (or None) for rewards 175 | self.rb.store_transition(reward, done) 176 | self.global_step += 1 177 | 178 | # Generalised advantage estimation 179 | def gae(self): 180 | b_obs, b_rewards, b_dones = self.rb.get_gae() 181 | b_size = self.rb.ptr 182 | 183 | with torch.no_grad(): 184 | b_values = self.critic(b_obs).squeeze(2) # latest critic values 185 | next_value = b_values[b_size - 1].reshape(1,-1) 186 | 187 | b_advantages = torch.zeros_like(b_rewards, device=self.device) 188 | lastgaelam = 0 189 | for t in reversed(range(b_size)): 190 | if t == b_size - 1: 191 | nextnonterminal = 1.0 - b_dones[b_size - 1] 192 | nextvalues = next_value 193 | else: 194 | nextnonterminal = 1.0 - b_dones[t + 1] 195 | nextvalues = b_values[t + 1] 196 | delta = b_rewards[t] + self.h.gamma * nextvalues * nextnonterminal - b_values[t] 197 | b_advantages[t] = lastgaelam = delta + self.h.gamma * self.h.gae_lambda * nextnonterminal * lastgaelam 198 | b_returns = b_advantages + b_values 199 | 200 | # Flatten on return 201 | return b_returns.reshape(-1), b_advantages.reshape(-1), b_values.reshape(-1) 202 | 203 | # Optimize actor/policy and critic/value networks 204 | def update(self): 205 | 206 | # Bookeeping 207 | updated = False 208 | 209 | if self.global_step % self.h.update_step == 0 and self.global_step != 0: 210 | 211 | updated = True 212 | chronos_start = time.time() 213 | 214 | b_obs, b_actions, b_logprobs = self.rb.get_ppo_update() 215 | batch_end = self.rb.ptr - 1 # index to last element 216 | 217 | clipfracs = torch.zeros(0, device=self.device) 218 | self.ppo_updates = 0 219 | for epoch in range(self.h.update_epochs): 220 | 221 | # Update GAE once or in each epoch for fresh advantages (https://arxiv.org/abs/2006.05990) 222 | if (self.h.gae_recalc) or (epoch == 0): 223 | b_returns, b_advantages, b_values = self.gae() 224 | if self.h.norm_adv: 225 | b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8) 226 | 227 | b_inds = torch.randperm(batch_end, device=self.device) # 
shuffled indices of the batch 228 | for start in range(0, batch_end, self.h.mb_size): 229 | end = min(start + self.h.mb_size, batch_end) 230 | mb_inds = b_inds[start:end] 231 | 232 | # Get minibatch set 233 | mb_obs = b_obs[mb_inds] 234 | mb_actions = b_actions[mb_inds] 235 | mb_advantages = b_advantages[mb_inds] 236 | 237 | # From latest policy 238 | _, newlogprob, entropy = self.actor(mb_obs, action=mb_actions) 239 | 240 | # Ratio for policy loss, logratio for estimating kl 241 | logratio = newlogprob - b_logprobs[mb_inds] 242 | ratio = logratio.exp() 243 | 244 | # Debugging & info 245 | with torch.no_grad(): 246 | # approx Kullback-Leibler divergence, usually < 0.02 when policy not changing too quickly 247 | self.approx_kl = ((ratio - 1) - logratio).mean() 248 | 249 | # fraction of training data that triggered clipped objective 250 | frac = torch.gt(torch.abs(ratio - 1.0), self.h.clip_coef) 251 | clipfracs = torch.cat([clipfracs, frac]) 252 | 253 | # PPO policy loss 254 | p_loss1 = -mb_advantages * ratio 255 | p_loss2 = -mb_advantages * torch.clamp(ratio, 1 - self.h.clip_coef, 1 + self.h.clip_coef) 256 | self.entropy_loss = entropy.mean() 257 | self.p_loss = torch.max(p_loss1, p_loss2).mean() - self.h.ent_coef * self.entropy_loss 258 | 259 | # Value loss 260 | mb_newvalues = self.critic(mb_obs).view(-1) 261 | self.v_loss = F.mse_loss(mb_newvalues, b_returns[mb_inds]) 262 | 263 | # Skip this minibatch update just before applying .step() if max kl exceeded 264 | if self.h.max_kl is not None: 265 | if self.approx_kl > self.h.max_kl: break 266 | 267 | # Update actor model 268 | self.optim_a.zero_grad() 269 | self.p_loss.backward() 270 | self.actor_grad_norm = nn.utils.clip_grad_norm_(self.actor.parameters(), self.h.max_grad_norm) 271 | self.optim_a.step() 272 | 273 | # Update critic model 274 | self.optim_c.zero_grad() 275 | self.v_loss.backward() 276 | self.critic_grad_norm = nn.utils.clip_grad_norm_(self.critic.parameters(), self.h.max_grad_norm) 277 | self.optim_c.step() 278 | 279 | self.ppo_updates += 1 280 | 281 | # the explained variance for the value function 282 | y_pred, y_true = b_values, b_returns 283 | var_y = torch.var(y_true) 284 | self.explained_var = 1 - torch.var(y_true - y_pred) / var_y 285 | 286 | # the fraction of the training data that triggered the clipped objective 287 | self.clipfracs = torch.mean(clipfracs) 288 | 289 | # Loss of Plasticity in Deep Continual Learning: https://arxiv.org/abs/2306.13812 290 | self.actor_avg_wgt_mag = avg_weight_magnitude(self.actor) 291 | self.critic_avg_wgt_mag = avg_weight_magnitude(self.critic) 292 | 293 | b_obs = self.rb.get_obs() 294 | _, _, self.actor_dead_pct = count_dead_units(self.actor, in1=b_obs, threshold=self.h.dead_hurdle) 295 | _, _, self.critic_dead_pct = count_dead_units(self.critic, in1=b_obs, threshold=self.h.dead_hurdle) 296 | 297 | # reset replay buffer and calc time taken 298 | self.rb.flush() 299 | self.chronos_total = time.time() - chronos_start 300 | 301 | return updated, self.chronos_total 302 | 303 | def save(self, path='./checkpoints/'): 304 | 305 | path = path + self.name + '.pt' 306 | models = {} 307 | 308 | models['Critic'] = self.critic.state_dict() 309 | models['Actor'] = self.actor.state_dict() 310 | 311 | torch.save(models, path) 312 | 313 | def load(self, path='./checkpoints/'): 314 | path = path + self.name + '.pt' 315 | models_file = torch.load(path) 316 | 317 | self.critic.load_state_dict(models_file['Critic']) 318 | self.actor.load_state_dict(models_file['Actor']) 
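A minimal sketch of how a learn script can drive the `Agent` above (illustrative only: the constructor signature sits earlier in the file, and the `env_spec` keys, the Pendulum-v1 choice and the action clamp are assumptions for this example; the real entry points are `learn_simple.py` / `learn_vectorised.py`):

```python
import torch
import gymnasium as gym
# from agents.ppo_baseline import Agent   # assumed import path for this sketch

env = gym.make("Pendulum-v1")
env_spec = {"obs_dim": env.observation_space.shape[0],   # assumed keys, mirroring the SAC agents
            "act_dim": env.action_space.shape[0]}
agent = Agent(env_spec, num_env=1, device="cpu")

obs, _ = env.reset(seed=42)
obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)            # (batch, channels)
for _ in range(10_000):
    action = agent.choose_action(obs)                                # also stored in the buffer
    a = action.squeeze(0).clamp(-2.0, 2.0).numpy()                   # clamp to Pendulum's bounds (sketch only)
    next_obs, reward, terminated, truncated, _ = env.step(a)
    agent.store_transition(torch.tensor([[reward]], dtype=torch.float32),
                           torch.tensor([[float(terminated or truncated)]]))
    agent.update()                                                   # agent decides when to train
    if terminated or truncated:
        next_obs, _ = env.reset()
    obs = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0)
```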
-------------------------------------------------------------------------------- /agents/sac_baseline.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from types import SimpleNamespace 6 | from .utils import avg_weight_magnitude, count_dead_units 7 | from .buffers import ReplayBufferSAC 8 | 9 | ''' 10 | Based on https://github.com/vwxyzjn/cleanrl 11 | SAC agent https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py 12 | ''' 13 | 14 | 15 | class SoftQNetwork(nn.Module): 16 | def __init__(self, obs_dim, action_dim, hidden_dim=256): 17 | super().__init__() 18 | 19 | self.mlp = nn.Sequential( 20 | nn.Linear(obs_dim + action_dim, hidden_dim), 21 | nn.ReLU(), 22 | nn.Linear(hidden_dim,hidden_dim), 23 | nn.ReLU(), 24 | nn.Linear(hidden_dim, 1) 25 | ) 26 | 27 | def forward(self, x, a): 28 | x = torch.cat([x, a], 1) 29 | x = self.mlp(x) 30 | return x 31 | 32 | 33 | class Actor(nn.Module): 34 | def __init__(self, env_spec, hidden_dim=256): 35 | super().__init__() 36 | obs_dim = env_spec['obs_dim'] 37 | act_dim = env_spec['act_dim'] 38 | 39 | self.fc0 = nn.Linear(obs_dim, hidden_dim) 40 | self.fc1 = nn.Linear(hidden_dim, hidden_dim) 41 | self.fc_mean = nn.Linear(hidden_dim, act_dim) 42 | self.fc_logstd = nn.Linear(hidden_dim, act_dim) 43 | self.nonlin = nn.ReLU() # If the non-linearity has state (e.g. trainable parameter) we need one per use 44 | 45 | action_high = nn.Parameter(torch.tensor(env_spec['act_max']), requires_grad=False) 46 | action_low = nn.Parameter(torch.tensor(env_spec['act_min']), requires_grad=False) 47 | self.action_scale = nn.Parameter(torch.tensor((action_high - action_low) * 0.5), requires_grad=False) 48 | self.action_bias = nn.Parameter(torch.tensor((action_high + action_low) * 0.5), requires_grad=False) 49 | self.log_std_max = 2 50 | self.log_std_min = -5 51 | 52 | def forward(self, x): 53 | x = self.nonlin(self.fc0(x)) 54 | x = self.nonlin(self.fc1(x)) 55 | mean = self.fc_mean(x) 56 | log_std = self.fc_logstd(x) 57 | log_std = torch.tanh(log_std) 58 | log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) # From SpinUp / Denis Yarats 59 | return mean, log_std 60 | 61 | def get_action(self, x): 62 | mean, log_std = self(x) 63 | std = log_std.exp() 64 | normal = torch.distributions.Normal(mean, std, validate_args=False) # validation forces a cuda<>cpu sync 65 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 66 | y_t = torch.tanh(x_t) 67 | action = y_t * self.action_scale + self.action_bias 68 | log_prob = normal.log_prob(x_t) 69 | 70 | # Enforcing Action Bound 71 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) 72 | log_prob = log_prob.sum(1, keepdim=True) 73 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 74 | return action, log_prob, mean 75 | 76 | 77 | class Agent: 78 | def __init__(self, env_spec, buffer_size=int(1e6), num_env=1, device='cpu', seed=42): 79 | 80 | # Make global 81 | self.name = "sac_baseline" # name for logging 82 | self.obs_dim = env_spec['obs_dim'] # environment inputs for agent 83 | self.action_dim = env_spec['act_dim'] # agent outputs to environment 84 | self.act_max = env_spec['act_max'] # action range, scalar or vector 85 | self.act_min = env_spec['act_min'] # action range, scalar or vector 86 | self.device = device # gpu or cpu 87 | 88 | # All seeds default to 42 89 | 
torch.manual_seed(torch.tensor(seed)) 90 | torch.backends.cudnn.deterministic = True 91 | # torch.cuda.set_sync_debug_mode(1) # Set to 1 to receive warnings 92 | 93 | # Hyperparameters 94 | hyperparameters = { 95 | "gamma" : 0.99, # (def: 0.99) Discount factor 96 | "q_lr" : 1e-3, # (def: 1e-3) Q learning rate 97 | "actor_lr" : 3e-4, # (def: 3e-4) Policy learning rate 98 | "learn_start" : int(5e3), # (def: 5e3) Start updating policies after this many global steps 99 | "batch_size" : 256, # (def: 256) Batch size of sample from replay buffer 100 | "policy_freq" : 2, # (def: 2) the frequency of training policy (delayed) 101 | "target_net_freq" : 1, # (def: 1) Denis Yarats' implementation delays this by 2 102 | "tau" : 0.005, # (def: 0.005) target smoothing coefficient 103 | "dead_hurdle" : 0.01, # (def: 0.01) units with greater variation in output over one batch of data than this are not dead in plasticity terms 104 | "a_hidden_dim" : 256, # (def: 256) size of actor's hidden layer(s) 105 | "q_hidden_dim" : 256, # (def: 256) size of Q's hidden layer(s) 106 | } 107 | self.h = SimpleNamespace(**hyperparameters) 108 | 109 | # Loggin & debugging 110 | self.qf1_a_values = torch.tensor([0.0]) 111 | self.qf2_a_values = torch.tensor([0.0]) 112 | self.qf1_loss = 0 113 | self.qf2_loss = 0 114 | self.qf_loss = 0 115 | self.actor_loss = 0 116 | self.alpha_loss = 0 117 | self.actor_avg_wgt_mag = 0 # average weight magnitude of model parameters 118 | self.qf1_avg_wgt_mag = 0 # average weight magnitude of model parameters 119 | self.qf2_avg_wgt_mag = 0 # average weight magnitude of model parameters 120 | self.actor_dead_pct = 0 # percentage of units which are dead by some threshold 121 | self.qf1_dead_pct = 0 # percentage of units which are dead by some threshold 122 | self.qf2_dead_pct = 0 # percentage of units which are dead by some threshold 123 | 124 | # Instantiate actor and Q networks, optimisers 125 | # AdamW (may have) resulted in more stable training than Adam 126 | self.qf1 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 127 | self.qf2 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 128 | self.qf1_target = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 129 | self.qf2_target = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 130 | self.q_optim = torch.optim.Adam(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.h.q_lr) 131 | self.qf1_target.load_state_dict(self.qf1.state_dict()) 132 | self.qf2_target.load_state_dict(self.qf2.state_dict()) 133 | 134 | self.actor = Actor(env_spec, self.h.a_hidden_dim).to(device) 135 | self.actor_optim = torch.optim.Adam(list(self.actor.parameters()), lr=self.h.actor_lr) 136 | 137 | # Use automatic entropy tuning 138 | self.target_entropy = -torch.prod(torch.Tensor((self.action_dim,)).to(device)).item() 139 | self.log_alpha = torch.zeros(1, requires_grad=True, device=device) 140 | self.alpha = self.log_alpha.exp().item() 141 | self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.h.q_lr) 142 | 143 | # Storage setup 144 | self.rb = ReplayBufferSAC(self.obs_dim, self.action_dim, buffer_size, num_env, device=self.device) 145 | self.global_step = 0 146 | 147 | 148 | def choose_action(self, obs): 149 | # Random uniform actions before learn_start can speed up training over using the agent's inital randomness. 
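Aside: the bound-enforcing term in `Actor.get_action` above is the standard change-of-variables correction for a tanh-squashed, affinely rescaled Gaussian. A quick standalone check of that identity against PyTorch's composed transforms (toy shapes and stand-in scale/bias, not part of the agent):

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform, AffineTransform

torch.manual_seed(0)
mean, std = torch.zeros(1, 3), torch.ones(1, 3)
scale, bias = 2.0, 0.0                         # stand-ins for action_scale / action_bias

base = Normal(mean, std)
x = base.rsample()
y = torch.tanh(x)
a = y * scale + bias

# Correction exactly as written in get_action()
log_prob = base.log_prob(x) - torch.log(scale * (1 - y.pow(2)) + 1e-6)
log_prob = log_prob.sum(1, keepdim=True)

# Same density built from composed transforms
ref = TransformedDistribution(base, [TanhTransform(), AffineTransform(loc=bias, scale=scale)])
print(log_prob.item(), ref.log_prob(a).sum(1, keepdim=True).item())   # agree up to the 1e-6 epsilon
```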
150 | if self.global_step < self.h.learn_start: 151 | # actions are rand_uniform of shape (obs_batch_size, action_dim) 152 | action = (torch.rand((obs.size(0), self.action_dim), device=self.device) - 0.5) * 2.0 # rand_uniform -1..+1 153 | action = action * self.actor.action_scale + self.actor.action_bias # apply scale and bias 154 | else: 155 | with torch.no_grad(): 156 | action, _, _ = self.actor.get_action(obs) 157 | self.rb.store_choice(obs, action) 158 | return action 159 | 160 | def store_transition(self, reward, done): 161 | self.rb.store_transition(reward, done) 162 | self.global_step += 1 163 | 164 | 165 | def update(self): 166 | 167 | ''' Call every step from learning script, agent decides if it is time to update ''' 168 | 169 | # Bookeeping 170 | updated = False 171 | chronos_total = 0.0 172 | 173 | if self.global_step > self.h.learn_start: 174 | updated = True 175 | chronos_start = time.time() 176 | 177 | b_obs, b_actions, b_obs_next, b_rewards, b_dones = self.rb.sample(self.h.batch_size) 178 | 179 | with torch.no_grad(): 180 | next_state_actions, next_state_log_pi, _ = self.actor.get_action(b_obs_next) 181 | qf1_next_target = self.qf1_target(b_obs_next, next_state_actions) 182 | qf2_next_target = self.qf2_target(b_obs_next, next_state_actions) 183 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi 184 | next_q_value = b_rewards.flatten() + (1 - b_dones.flatten()) * self.h.gamma * (min_qf_next_target).view(-1) 185 | 186 | self.qf1_a_values = self.qf1(b_obs, b_actions).view(-1) 187 | self.qf2_a_values = self.qf2(b_obs, b_actions).view(-1) 188 | self.qf1_loss = F.mse_loss(self.qf1_a_values, next_q_value) 189 | self.qf2_loss = F.mse_loss(self.qf2_a_values, next_q_value) 190 | self.qf_loss = self.qf1_loss + self.qf2_loss 191 | 192 | self.q_optim.zero_grad() 193 | self.qf_loss.backward() 194 | self.q_optim.step() 195 | 196 | # update actor network and alpha parameter 197 | if self.global_step % self.h.policy_freq == 0: # TD 3 Delayed update support 198 | for _ in range(self.h.policy_freq): # compensate for the delay by doing 'actor_update_interval' instead of 1 199 | pi, log_pi, _ = self.actor.get_action(b_obs) 200 | qf1_pi = self.qf1(b_obs, pi) 201 | qf2_pi = self.qf2(b_obs, pi) 202 | min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1) 203 | self.actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() 204 | 205 | self.actor_optim.zero_grad() 206 | self.actor_loss.backward() 207 | self.actor_optim.step() 208 | 209 | # Autotune alpha 210 | with torch.no_grad(): 211 | _, log_pi, _ = self.actor.get_action(b_obs) 212 | self.alpha_loss = (-self.log_alpha.exp() * (log_pi + self.target_entropy)).mean() 213 | 214 | self.alpha_optim.zero_grad() 215 | self.alpha_loss.backward() 216 | self.alpha_optim.step() 217 | self.alpha = self.log_alpha.exp().detach().clone() 218 | 219 | # update the target networks 220 | if self.global_step % self.h.target_net_freq == 0: 221 | for param, target_param in zip(self.qf1.parameters(), self.qf1_target.parameters()): 222 | target_param.data.copy_(self.h.tau * param.data + (1 - self.h.tau) * target_param.data) 223 | for param, target_param in zip(self.qf2.parameters(), self.qf2_target.parameters()): 224 | target_param.data.copy_(self.h.tau * param.data + (1 - self.h.tau) * target_param.data) 225 | 226 | # Plasticity metrics occasionally 227 | if self.global_step % 2048 == 0 or self.global_step == self.h.learn_start: 228 | self.actor_avg_wgt_mag = avg_weight_magnitude(self.actor) 229 | self.qf1_avg_wgt_mag = 
avg_weight_magnitude(self.qf1) 230 | self.qf2_avg_wgt_mag = avg_weight_magnitude(self.qf2) 231 | 232 | b_obs, b_actions = self.rb.plasticity_data(2048) # a representative sample 233 | _, _, self.actor_dead_pct = count_dead_units(self.actor, in1=b_obs, threshold=self.h.dead_hurdle) 234 | _, _, self.qf1_dead_pct = count_dead_units(self.qf1, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 235 | _, _, self.qf2_dead_pct = count_dead_units(self.qf2, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 236 | 237 | chronos_total = time.time() - chronos_start 238 | 239 | return updated, chronos_total 240 | -------------------------------------------------------------------------------- /agents/sac_crossq.py: -------------------------------------------------------------------------------- 1 | import time, math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from types import SimpleNamespace 6 | from .utils import avg_weight_magnitude, count_dead_units, dormant_ratio 7 | from .buffers import ReplayBufferSAC 8 | 9 | ''' 10 | Based on https://github.com/vwxyzjn/cleanrl 11 | SAC agent https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py 12 | 13 | CrossQ extensions to SAC influenced by: 14 | - CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity https://arxiv.org/abs/1902.05605 15 | - v3 of the arxiv paper (2023) 16 | 17 | Plasticity metrics and regularisation influnced by the following papers: 18 | - Maintaining Plasticity in Continual Learning via Regenerative Regularization https://arxiv.org/abs/2308.11958 19 | - Loss of Plasticity in Deep Continual Learning https://arxiv.org/abs/2306.13812 20 | - Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier https://openreview.net/forum?id=OpC-9aBBVJe 21 | – Bigger, Better, Faster: Human-level Atari with human-level efficiency https://arxiv.org/abs/2305.19452 22 | ''' 23 | 24 | 25 | class SoftQNetwork(nn.Module): 26 | def __init__(self, obs_dim, action_dim, hidden_dim=256): 27 | super().__init__() 28 | 29 | self.mlp = nn.Sequential( 30 | nn.Linear(obs_dim + action_dim, hidden_dim, bias=False), # batchnorm has bias, this one is redundant 31 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 32 | nn.ReLU(), 33 | nn.Linear(hidden_dim,hidden_dim, bias=False), 34 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 35 | nn.ReLU(), 36 | nn.Linear(hidden_dim, 1) 37 | ) 38 | 39 | # init_model(self.mlp, init_method='xavier_uniform_') # CrossQ no mention of init 40 | 41 | def forward(self, x, a): 42 | x = torch.cat([x, a], 1) 43 | x = self.mlp(x) 44 | return x 45 | 46 | 47 | class Actor(nn.Module): 48 | def __init__(self, env_spec, hidden_dim=256): 49 | super().__init__() 50 | 51 | obs_dim = env_spec['obs_dim'] 52 | act_dim = env_spec['act_dim'] 53 | 54 | self.mlp = nn.Sequential( 55 | nn.Linear(obs_dim, hidden_dim, bias=False), 56 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 57 | nn.ReLU(), 58 | nn.Linear(hidden_dim,hidden_dim, bias=False), 59 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 60 | nn.ReLU(), 61 | ) 62 | self.fc_mean = nn.Linear(hidden_dim, act_dim) 63 | self.fc_logstd = nn.Linear(hidden_dim, act_dim) 64 | 65 | action_high = nn.Parameter(torch.tensor(env_spec['act_max']), requires_grad=False) 66 | action_low = nn.Parameter(torch.tensor(env_spec['act_min']), requires_grad=False) 67 | self.action_scale = nn.Parameter((action_high - action_low) * 0.5, requires_grad=False) 68 | self.action_bias = 
nn.Parameter((action_high + action_low) * 0.5, requires_grad=False) 69 | self.log_std_max = 2 70 | self.log_std_min = -5 71 | 72 | def forward(self, x): 73 | x = self.mlp(x) 74 | mean = self.fc_mean(x) 75 | log_std = self.fc_logstd(x) 76 | log_std = torch.tanh(log_std) 77 | log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) # From SpinUp / Denis Yarats 78 | return mean, log_std 79 | 80 | def get_action(self, x): 81 | mean, log_std = self(x) 82 | std = log_std.exp() 83 | normal = torch.distributions.Normal(mean, std, validate_args=False) # validation forces a cuda<>cpu sync 84 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 85 | y_t = torch.tanh(x_t) 86 | action = y_t * self.action_scale + self.action_bias 87 | log_prob = normal.log_prob(x_t) 88 | 89 | # Enforcing Action Bound 90 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) 91 | log_prob = log_prob.sum(1, keepdim=True) 92 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 93 | return action, log_prob, mean 94 | 95 | 96 | class Agent: 97 | def __init__(self, 98 | env_spec, 99 | buffer_size = int(1e6), 100 | num_env = 1, 101 | device = 'cpu', 102 | seed = 42, 103 | rr = 1, # RR = 1 for CrossQ 104 | q_lr = 1e-3, # CrossQ learning rates 105 | actor_lr = 1e-3, 106 | alpha_lr = 1e-3, 107 | ): 108 | 109 | # Make global 110 | self.name = "sac_crossq" # name for logging 111 | self.obs_dim = env_spec['obs_dim'] # environment inputs for agent 112 | self.action_dim = env_spec['act_dim'] # agent outputs to environment 113 | self.act_max = env_spec['act_max'] # action range, scalar or vector 114 | self.act_min = env_spec['act_min'] # action range, scalar or vector 115 | self.device = device # gpu or cpu 116 | 117 | # All seeds default to 42 118 | torch.manual_seed(torch.tensor(seed)) 119 | torch.backends.cudnn.deterministic = True 120 | # torch.cuda.set_sync_debug_mode(1) # Set to 1 to receive warnings 121 | # torch.set_float32_matmul_precision("high") # "high" is 11% faster, but can reduce learning performance in certain envs 122 | 123 | # Hyperparameters 124 | hyperparameters = { 125 | "gamma" : 0.99, # (def: 0.99) Discount factor 126 | "q_lr" : q_lr, # (def: 1e-3) Q learning rate 127 | "a_lr" : actor_lr, # (def: 1e-3) Policy learning rate 128 | "alpha_lr" : alpha_lr, # (def: 1e-3) alpha auto entropoty tuning learning rate 129 | "learn_start" : int(5e3), # (def: 5e3) Start updating policies after this many global steps 130 | "batch_size" : 256, # (def: 256) Batch size of sample from replay buffer 131 | "policy_freq" : 3, # (def: 3) CrossQ 132 | "dead_hurdle" : 0.001, # (def: 0.001) units with greater variation in output over one batch of data than this are not dead in plasticity terms 133 | "a_hidden_dim" : 256, # (def: 256) size of actor's hidden layer(s) 134 | "q_hidden_dim" : 2048, # (def: 2048) CrossQ with 512 wide Qf did just as well, but with a little more variance 135 | "replay_ratio" : round(rr), 136 | "adam_betas" : (0.5, 0.999), # CrossQ 137 | } 138 | self.h = SimpleNamespace(**hyperparameters) 139 | 140 | # Loggin & debugging 141 | self.qf1_a_values = torch.tensor([0.0]) 142 | self.qf2_a_values = torch.tensor([0.0]) 143 | self.qf1_loss = 0 144 | self.qf2_loss = 0 145 | self.qf_loss = 0 146 | self.actor_loss = 0 147 | self.alpha_loss = 0 148 | self.actor_avg_wgt_mag = 0 # average weight magnitude as per https://arxiv.org/abs/2306.13812 149 | self.qf1_avg_wgt_mag = 0 150 | self.qf2_avg_wgt_mag = 0 151 | self.actor_dead_pct = 0 # dead units 
as per https://arxiv.org/abs/2306.13812 152 | self.qf1_dead_pct = 0 153 | self.qf2_dead_pct = 0 154 | self.qf1_dormant_ratio = 0 # DrM: Dormant Ratio Minimisation https://arxiv.org/abs/2310.19668 155 | self.qf2_dormant_ratio = 0 156 | self.actor_dormant_ratio = 0 157 | 158 | 159 | # Instantiate actor and Q networks, optimisers 160 | # CrossQ uses Adam but experience with AdamW is better 161 | self.qf1 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 162 | self.qf2 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 163 | self.q_optim = torch.optim.AdamW(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.h.q_lr, betas=self.h.adam_betas) 164 | 165 | self.actor = Actor(env_spec, self.h.a_hidden_dim).to(device) 166 | self.actor_optim = torch.optim.AdamW(list(self.actor.parameters()), lr=self.h.a_lr) 167 | # init_model(self.actor, init_method='xavier_uniform_') # CrossQ no mention of init 168 | 169 | # Use automatic entropy tuning 170 | self.target_entropy = -(torch.prod(torch.Tensor((self.action_dim,))).to(device)).item() 171 | self.log_alpha = torch.tensor((math.log(0.1)), requires_grad=True, device=device) 172 | self.alpha = self.log_alpha.exp().item() 173 | self.alpha_optim = torch.optim.AdamW([self.log_alpha], lr=self.h.alpha_lr) 174 | 175 | # Storage setup 176 | self.rb = ReplayBufferSAC(self.obs_dim, self.action_dim, buffer_size, num_env, device=self.device) 177 | self.global_step = 0 178 | 179 | # CUDA timers for the update process 180 | self.chronos_start = torch.cuda.Event(enable_timing=True) 181 | self.chronos_end = torch.cuda.Event(enable_timing=True) 182 | 183 | 184 | def choose_action(self, obs): 185 | # Random uniform actions before learn_start can speed up training over using the agent's inital randomness. 
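The `actor.eval()` / `actor.train()` pair below exists because the CrossQ actor contains BatchNorm: in train mode BN normalises with per-batch statistics and updates its running estimates, which is wrong (and can even error) for the tiny rollout batches seen here. A toy illustration, with arbitrary dimensions:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4, momentum=0.01)          # same momentum convention as the CrossQ layers
replay_batch = torch.randn(256, 4) * 3 + 1
bn.train()
_ = bn(replay_batch)                           # training: uses batch stats, nudges running stats

single_obs = torch.randn(1, 4)                 # typical rollout input: one env, one step
bn.eval()
safe = bn(single_obs)                          # uses frozen running_mean / running_var, stats untouched
bn.train()
# bn(single_obs)                               # in train mode this raises
#                                              # "Expected more than 1 value per channel when training"
#                                              # and, for small batches, would corrupt the running stats
```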
186 | if self.global_step < self.h.learn_start: 187 | # actions are rand_uniform of shape (obs_batch_size, action_dim) 188 | action = (torch.rand((obs.size(0), self.action_dim), device=self.device) - 0.5) * 2.0 # rand_uniform -1..+1 189 | action = action * self.actor.action_scale + self.actor.action_bias # apply scale and bias 190 | else: 191 | with torch.no_grad(): 192 | self.actor.eval() # prevent changes to batchnorm layers 193 | action, _, _ = self.actor.get_action(obs) 194 | self.actor.train() 195 | self.rb.store_choice(obs, action) 196 | return action 197 | 198 | def store_transition(self, reward, done): 199 | self.rb.store_transition(reward, done) 200 | self.global_step += 1 201 | 202 | 203 | def update(self): 204 | 205 | ''' Call every step from learning script, agent decides if it is time to update ''' 206 | 207 | # Bookeeping 208 | updated = False 209 | chronos_total = 0.0 210 | 211 | if self.global_step > self.h.learn_start: 212 | updated = True 213 | self.chronos_start.record() 214 | 215 | for replay in range(0, self.h.replay_ratio): 216 | 217 | b_obs, b_actions, b_obs_next, b_rewards, b_dones = self.rb.sample(self.h.batch_size) 218 | 219 | with torch.no_grad(): 220 | self.actor.eval() 221 | next_state_actions, next_state_log_pi, _ = self.actor.get_action(b_obs_next) 222 | self.actor.train() 223 | 224 | bb_obs = torch.cat((b_obs, b_obs_next), dim=0) 225 | bb_acts = torch.cat((b_actions, next_state_actions), dim=0) 226 | 227 | bb_q1 = self.qf1(bb_obs, bb_acts) 228 | bb_q2 = self.qf2(bb_obs, bb_acts) 229 | 230 | b_q1, b_q1_next = torch.chunk(bb_q1, chunks=2, dim=0) 231 | b_q2, b_q2_next = torch.chunk(bb_q2, chunks=2, dim=0) 232 | self.qf1_a_values = b_q1 # mean of this is used in logging 233 | self.qf2_a_values = b_q2 # mean of this is used in logging 234 | 235 | min_q_next = torch.min(b_q1_next, b_q2_next) - self.alpha * next_state_log_pi 236 | next_q_value = b_rewards.flatten() + (1 - b_dones.flatten()) * self.h.gamma * (min_q_next).view(-1) 237 | torch.detach_(next_q_value) # no gradients through here 238 | 239 | self.qf1_loss = F.mse_loss(b_q1.flatten(), next_q_value) 240 | self.qf2_loss = F.mse_loss(b_q2.flatten(), next_q_value) 241 | self.qf_loss = self.qf1_loss + self.qf2_loss 242 | 243 | self.q_optim.zero_grad() 244 | self.qf_loss.backward() 245 | self.q_optim.step() 246 | 247 | # Replay Ratio does not apply to actor nor to alpha as per DroQ 248 | 249 | # Update actor network and alpha parameter 250 | if self.global_step % self.h.policy_freq == 0: # TD 3 Delayed update support 251 | pi, log_pi, _ = self.actor.get_action(b_obs) 252 | 253 | self.qf1.eval() 254 | self.qf2.eval() 255 | qf1_pi = self.qf1(b_obs, pi) 256 | qf2_pi = self.qf2(b_obs, pi) 257 | self.qf1.train() 258 | self.qf2.train() 259 | 260 | min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1) 261 | self.actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() 262 | 263 | self.actor_optim.zero_grad() 264 | self.actor_loss.backward() 265 | self.actor_optim.step() 266 | 267 | # Autotune alpha 268 | with torch.no_grad(): 269 | self.actor.eval() 270 | _, log_pi, _ = self.actor.get_action(b_obs) 271 | self.actor.train() 272 | self.alpha_loss = (-self.log_alpha.exp() * (log_pi + self.target_entropy)).mean() 273 | 274 | self.alpha_optim.zero_grad() 275 | self.alpha_loss.backward() 276 | self.alpha_optim.step() 277 | self.alpha = self.log_alpha.exp().detach().clone() 278 | 279 | # Plasticity metrics occasionally 280 | if self.global_step % 2048 == 0 or self.global_step == self.h.learn_start: 281 | self.actor_avg_wgt_mag = 
avg_weight_magnitude(self.actor) 282 | self.qf1_avg_wgt_mag = avg_weight_magnitude(self.qf1) 283 | self.qf2_avg_wgt_mag = avg_weight_magnitude(self.qf2) 284 | 285 | b_obs, b_actions = self.rb.plasticity_data(2048) # a representative sample 286 | _, _, self.qf1_dead_pct = count_dead_units(self.qf1, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 287 | _, _, self.qf2_dead_pct = count_dead_units(self.qf2, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 288 | _, _, self.actor_dead_pct = count_dead_units(self.actor, in1=b_obs, threshold=self.h.dead_hurdle) 289 | 290 | self.qf1_dormant_ratio = dormant_ratio(self.qf1, in1=b_obs, in2=b_actions) 291 | self.qf2_dormant_ratio = dormant_ratio(self.qf2, in1=b_obs, in2=b_actions) 292 | self.actor_dormant_ratio = dormant_ratio(self.actor, in1=b_obs) 293 | 294 | # Record end time, wait for all cuda threads to sync and calc time in seconds 295 | self.chronos_end.record() 296 | torch.cuda.synchronize() 297 | chronos_total = (self.chronos_start.elapsed_time(self.chronos_end) * 0.001) 298 | 299 | return updated, chronos_total 300 | 301 | 302 | def save(self, path='./checkpoints/'): 303 | 304 | path = path + self.name + '.pt' 305 | models = {} 306 | 307 | models['Q1'] = self.qf1.state_dict() 308 | models['Q2'] = self.qf2.state_dict() 309 | models['Actor'] = self.actor.state_dict() 310 | 311 | torch.save(models, path) 312 | 313 | def load(self, path='./checkpoints/'): 314 | path = path + self.name + '.pt' 315 | models_file = torch.load(path) 316 | 317 | self.qf1.load_state_dict(models_file['Q1']) 318 | self.qf2.load_state_dict(models_file['Q2']) 319 | self.actor.load_state_dict(models_file['Actor']) -------------------------------------------------------------------------------- /agents/sac_crossq_bro.py: -------------------------------------------------------------------------------- 1 | import time, math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from types import SimpleNamespace 6 | from .utils import avg_weight_magnitude, count_dead_units, dormant_ratio 7 | from .buffers import ReplayBufferSAC 8 | # from torchinfo import summary 9 | 10 | ''' 11 | Based on https://github.com/vwxyzjn/cleanrl 12 | SAC agent https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py 13 | 14 | CrossQ extensions to SAC influenced by: 15 | - CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity https://arxiv.org/abs/1902.05605 16 | - v3 of the arxiv paper (2023) 17 | 18 | Plasticity metrics and regularisation influnced by the following papers: 19 | - Maintaining Plasticity in Continual Learning via Regenerative Regularization https://arxiv.org/abs/2308.11958 20 | - Loss of Plasticity in Deep Continual Learning https://arxiv.org/abs/2306.13812 21 | - Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier https://openreview.net/forum?id=OpC-9aBBVJe 22 | – Bigger, Better, Faster: Human-level Atari with human-level efficiency https://arxiv.org/abs/2305.19452 23 | 24 | Performance optimisations inspired by BRO and friends: 25 | - BRO: Bigger, Regularized, Optimistic: scaling for compute and sample-efficient continuous control https://arxiv.org/abs/2405.16158 26 | – SR-SAC: Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier https://openreview.net/forum?id=OpC-9aBBVJe 27 | - Overestimation, Overfitting, and Plasticity https://arxiv.org/abs/2403.00514 28 | ''' 29 | 30 | # Using CrossQ structure with batch 
normalisation 31 | class CrossQBlock(nn.Module): 32 | def __init__(self, in_dim, out_dim): 33 | super().__init__() 34 | 35 | self.cqb = nn.Sequential( 36 | nn.Linear(in_dim, out_dim, bias=False), # batchnorm has bias, this one is redundant 37 | nn.BatchNorm1d(out_dim, momentum=0.01), # CrossQ momentum 38 | nn.LeakyReLU(), # LeakyReLU = similar performance to ReLU but better plasticiy metrics 39 | nn.Linear(out_dim, out_dim, bias=False), # batchnorm has bias, this one is redundant 40 | nn.BatchNorm1d(out_dim, momentum=0.01), # CrossQ momentum 41 | nn.LeakyReLU(), 42 | ) 43 | 44 | def forward(self, x): 45 | return self.cqb(x) 46 | 47 | 48 | class SoftQNetwork(nn.Module): 49 | def __init__(self, obs_dim, action_dim, hidden_dim=256): 50 | super().__init__() 51 | 52 | # ~5m param model with depth and skip connections (BRO sweet spot size ~5m parma) 53 | self.b1 = CrossQBlock(obs_dim + action_dim, hidden_dim) 54 | self.b2 = CrossQBlock(hidden_dim, hidden_dim) 55 | self.b3 = CrossQBlock(hidden_dim, hidden_dim) 56 | self.b4 = nn.Linear(hidden_dim, 1) 57 | 58 | 59 | def forward(self, o, a=None): 60 | 61 | if a is None: 62 | x = o 63 | else: 64 | x = torch.cat([o, a], 1) 65 | 66 | x1 = self.b1(x) 67 | x2 = self.b2(x1) + x1 # add residuals 68 | x3 = self.b3(x2) + x2 69 | x4 = self.b4(x3) 70 | 71 | return x4 72 | 73 | 74 | class Actor(nn.Module): 75 | def __init__(self, env_spec, hidden_dim=256): 76 | super().__init__() 77 | 78 | obs_dim = env_spec['obs_dim'] 79 | act_dim = env_spec['act_dim'] 80 | 81 | self.mlp = nn.Sequential( 82 | nn.Linear(obs_dim, hidden_dim, bias=False), 83 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 84 | nn.LeakyReLU(), # LeakyReLU = similar performance to ReLU but better plasticiy metrics 85 | nn.Linear(hidden_dim,hidden_dim, bias=False), 86 | nn.BatchNorm1d(hidden_dim, momentum=0.01), #CrossQ 87 | nn.LeakyReLU(), 88 | ) 89 | self.fc_mean = nn.Linear(hidden_dim, act_dim) 90 | self.fc_logstd = nn.Linear(hidden_dim, act_dim) 91 | 92 | action_high = nn.Parameter(torch.tensor(env_spec['act_max']), requires_grad=False) 93 | action_low = nn.Parameter(torch.tensor(env_spec['act_min']), requires_grad=False) 94 | self.action_scale = nn.Parameter((action_high - action_low) * 0.5, requires_grad=False) 95 | self.action_bias = nn.Parameter((action_high + action_low) * 0.5, requires_grad=False) 96 | self.log_std_max = 2 97 | self.log_std_min = -5 98 | 99 | def forward(self, x): 100 | x = self.mlp(x) 101 | mean = self.fc_mean(x) 102 | log_std = self.fc_logstd(x) 103 | log_std = torch.tanh(log_std) 104 | log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) # From SpinUp / Denis Yarats 105 | return mean, log_std 106 | 107 | def get_action(self, x): 108 | mean, log_std = self(x) 109 | std = log_std.exp() 110 | normal = torch.distributions.Normal(mean, std, validate_args=False) # validation forces a cuda<>cpu sync 111 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 112 | y_t = torch.tanh(x_t) 113 | action = y_t * self.action_scale + self.action_bias 114 | log_prob = normal.log_prob(x_t) 115 | 116 | # Enforcing Action Bound 117 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) 118 | log_prob = log_prob.sum(1, keepdim=True) 119 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 120 | return action, log_prob, mean 121 | 122 | 123 | class Agent: 124 | def __init__(self, 125 | env_spec, 126 | buffer_size = int(1e6), 127 | num_env = 1, 128 | device = 'cpu', 129 | seed = 42, 130 | rr = 2, # BRO 
Fast RR=2, BRO Default RR = 15 131 | q_lr = 1e-3, # BRO learning rates seem better than CrossQ LRs 132 | actor_lr = 1e-3, 133 | alpha_lr = 1e-3, 134 | ): 135 | 136 | # Make global 137 | self.name = "sac_crossq_bro" # name for logging 138 | self.obs_dim = env_spec['obs_dim'] # environment inputs for agent 139 | self.action_dim = env_spec['act_dim'] # agent outputs to environment 140 | self.act_max = env_spec['act_max'] # action range, scalar or vector 141 | self.act_min = env_spec['act_min'] # action range, scalar or vector 142 | self.device = device # gpu or cpu 143 | 144 | # All seeds default to 42 145 | torch.manual_seed(torch.tensor(seed)) 146 | torch.backends.cudnn.deterministic = True 147 | # torch.cuda.set_sync_debug_mode(1) # Set to 1 to receive warnings 148 | # torch.set_float32_matmul_precision("high") # "high" is 11% faster, but can reduce learning performance in certain envs 149 | 150 | # Hyperparameters 151 | hyperparameters = { 152 | "gamma" : 0.99, # (def: 0.99) Discount factor 153 | "q_lr" : q_lr, # (def: 1e-3) Q learning rate 154 | "a_lr" : actor_lr, # (def: 1e-3) Policy learning rate 155 | "alpha_lr" : alpha_lr, # (def: 1e-3) alpha auto entropoty tuning learning rate 156 | "learn_start" : int(5e3), # (def: 5e3) Start updating policies after this many global steps 157 | "batch_size" : 256, # (def: 256) Batch size of sample from replay buffer 158 | "policy_freq" : 3, # (def: 3) CrossQ 159 | "dead_hurdle" : 0.001, # (def: 0.001) units with greater variation in output over one batch of data than this are not dead in plasticity terms 160 | "a_hidden_dim" : 256, # (def: 256) size of actor's hidden layer(s) 161 | "q_hidden_dim" : 1024, # (def: 2048) CrossQ with 512 wide Qf did just as well, but with a little more variance 162 | "replay_ratio" : round(rr), 163 | "adam_betas" : (0.5, 0.999), # CrossQ 164 | } 165 | self.h = SimpleNamespace(**hyperparameters) 166 | 167 | # Logging & debugging 168 | self.qf1_a_values = torch.tensor([0.0]) # average values 169 | self.qf2_a_values = torch.tensor([0.0]) 170 | self.qf1_loss = 0 171 | self.qf2_loss = 0 172 | self.qf_loss = 0 173 | self.actor_loss = 0 174 | self.alpha_loss = 0 175 | self.actor_avg_wgt_mag = 0 # average weight magnitude as per https://arxiv.org/abs/2306.13812 176 | self.qf1_avg_wgt_mag = 0 177 | self.qf2_avg_wgt_mag = 0 178 | self.actor_dead_pct = 0 # dead units as per https://arxiv.org/abs/2306.13812 179 | self.qf1_dead_pct = 0 180 | self.qf2_dead_pct = 0 181 | self.qf1_dormant_ratio = 0 # DrM: Dormant Ratio Minimisation https://arxiv.org/abs/2310.19668 182 | self.qf2_dormant_ratio = 0 183 | self.actor_dormant_ratio = 0 184 | 185 | 186 | # Instantiate actor and Q networks, optimisers 187 | # CrossQ uses Adam but experience with AdamW is better (BRO uses AdamW) 188 | self.qf1 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 189 | self.qf2 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 190 | self.q_optim = torch.optim.AdamW(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.h.q_lr, betas=self.h.adam_betas) 191 | 192 | # Check model size and architecture 193 | # summary(self.qf1, input_size=(1, self.obs_dim+self.action_dim)) 194 | # exit() 195 | 196 | self.actor = Actor(env_spec, self.h.a_hidden_dim).to(device) 197 | self.actor_optim = torch.optim.AdamW(list(self.actor.parameters()), lr=self.h.a_lr) 198 | 199 | # Use automatic entropy tuning 200 | self.target_entropy = torch.tensor(-self.action_dim, device=device, requires_grad=False) 201 | 
self.log_alpha = torch.tensor((math.log(0.1)), requires_grad=True, device=device) 202 | self.alpha = self.log_alpha.exp().item() 203 | self.alpha_optim = torch.optim.AdamW([self.log_alpha], lr=self.h.alpha_lr) 204 | 205 | # Storage setup 206 | self.rb = ReplayBufferSAC(self.obs_dim, self.action_dim, buffer_size, num_env, device=self.device) 207 | self.global_step = 0 208 | 209 | # CUDA timers for the update process 210 | self.chronos_start = torch.cuda.Event(enable_timing=True) 211 | self.chronos_end = torch.cuda.Event(enable_timing=True) 212 | 213 | 214 | def choose_action(self, obs): 215 | # Random uniform actions before learn_start can speed up training over using the agent's inital randomness. 216 | if self.global_step < self.h.learn_start: 217 | # actions are rand_uniform of shape (obs_batch_size, action_dim) 218 | action = (torch.rand((obs.size(0), self.action_dim), device=self.device) - 0.5) * 2.0 # rand_uniform -1..+1 219 | action = action * self.actor.action_scale + self.actor.action_bias # apply scale and bias 220 | else: 221 | with torch.no_grad(): 222 | self.actor.eval() # prevent changes to batchnorm layers 223 | action, _, _ = self.actor.get_action(obs) 224 | self.actor.train() 225 | self.rb.store_choice(obs, action) 226 | return action 227 | 228 | def store_transition(self, reward, done): 229 | self.rb.store_transition(reward, done) 230 | self.global_step += 1 231 | 232 | 233 | def update(self): 234 | 235 | ''' Call every step from learning script, agent decides if it is time to update ''' 236 | 237 | # Bookeeping 238 | updated = False 239 | chronos_total = 0.0 240 | 241 | if self.global_step > self.h.learn_start: 242 | updated = True 243 | self.chronos_start.record() 244 | 245 | for replay in range(0, self.h.replay_ratio): 246 | 247 | b_obs, b_actions, b_obs_next, b_rewards, b_dones = self.rb.sample(self.h.batch_size) 248 | 249 | with torch.no_grad(): 250 | self.actor.eval() 251 | next_state_actions, next_state_log_pi, _ = self.actor.get_action(b_obs_next) 252 | self.actor.train() 253 | 254 | bb_obs = torch.cat((b_obs, b_obs_next), dim=0) 255 | bb_acts = torch.cat((b_actions, next_state_actions), dim=0) 256 | 257 | bb_q1 = self.qf1(bb_obs, bb_acts) 258 | bb_q2 = self.qf2(bb_obs, bb_acts) 259 | 260 | b_q1, b_q1_next = torch.chunk(bb_q1, chunks=2, dim=0) 261 | b_q2, b_q2_next = torch.chunk(bb_q2, chunks=2, dim=0) 262 | self.qf1_a_values = b_q1 # mean of this is used in logging 263 | self.qf2_a_values = b_q2 # mean of this is used in logging 264 | 265 | # BRO uses mean() instead of min() for "exploration optimism" 266 | # min_q_next = torch.min(b_q1_next, b_q2_next) - self.alpha * next_state_log_pi 267 | min_q_next = ((b_q1_next + b_q2_next) * 0.5) - self.alpha * next_state_log_pi 268 | 269 | next_q_value = b_rewards.flatten() + (1 - b_dones.flatten()) * self.h.gamma * (min_q_next).view(-1) 270 | torch.detach_(next_q_value) # no gradients through here 271 | 272 | self.qf1_loss = F.mse_loss(b_q1.flatten(), next_q_value) 273 | self.qf2_loss = F.mse_loss(b_q2.flatten(), next_q_value) 274 | self.qf_loss = self.qf1_loss + self.qf2_loss 275 | 276 | self.q_optim.zero_grad() 277 | self.qf_loss.backward() 278 | self.q_optim.step() 279 | 280 | # BRO and SR-SAC seem to include actor in RR (unclear: alpha?) 
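For reference, the only change the BRO-style "exploration optimism" above makes to the target is swapping the pessimistic min for the mean of the two critics. A toy comparison with hypothetical critic outputs:

```python
import torch

# Hypothetical next-state Q values from the two critics for a batch of 4
b_q1_next = torch.tensor([1.0, 2.0, 3.0, 4.0])
b_q2_next = torch.tensor([2.0, 1.5, 3.5, 3.0])
alpha = 0.2
next_state_log_pi = torch.full((4,), -1.0)

pessimistic = torch.min(b_q1_next, b_q2_next) - alpha * next_state_log_pi      # classic clipped double Q
optimistic  = (b_q1_next + b_q2_next) * 0.5   - alpha * next_state_log_pi      # BRO-style mean
print(pessimistic)   # tensor([1.2000, 1.7000, 3.2000, 3.2000])
print(optimistic)    # tensor([1.7000, 1.9500, 3.4500, 3.7000])
# the mean target is never below the min target, hence "optimism"
```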
281 | 282 | # Update actor network and alpha parameter 283 | if self.global_step % self.h.policy_freq == 0: # TD 3 Delayed update support 284 | pi, log_pi, _ = self.actor.get_action(b_obs) 285 | 286 | self.qf1.eval() 287 | self.qf2.eval() 288 | qf1_pi = self.qf1(b_obs, pi) 289 | qf2_pi = self.qf2(b_obs, pi) 290 | self.qf1.train() 291 | self.qf2.train() 292 | 293 | # Not clear mean() is better than min() here though BRO seems to use mean() to actor training too 294 | min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1) 295 | self.actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() 296 | 297 | self.actor_optim.zero_grad() 298 | self.actor_loss.backward() 299 | self.actor_optim.step() 300 | 301 | # Autotune alpha 302 | with torch.no_grad(): 303 | self.actor.eval() 304 | _, log_pi, _ = self.actor.get_action(b_obs) 305 | self.actor.train() 306 | self.alpha_loss = (-self.log_alpha.exp() * (log_pi + self.target_entropy)).mean() 307 | 308 | self.alpha_optim.zero_grad() 309 | self.alpha_loss.backward() 310 | self.alpha_optim.step() 311 | self.alpha = self.log_alpha.exp().detach().clone() 312 | 313 | # Plasticity metrics occasionally 314 | if self.global_step % 2048 == 0 or self.global_step == self.h.learn_start: 315 | self.actor_avg_wgt_mag = avg_weight_magnitude(self.actor) 316 | self.qf1_avg_wgt_mag = avg_weight_magnitude(self.qf1) 317 | self.qf2_avg_wgt_mag = avg_weight_magnitude(self.qf2) 318 | 319 | b_obs, b_actions = self.rb.plasticity_data(2048) # a representative sample 320 | _, _, self.qf1_dead_pct = count_dead_units(self.qf1, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 321 | _, _, self.qf2_dead_pct = count_dead_units(self.qf2, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 322 | _, _, self.actor_dead_pct = count_dead_units(self.actor, in1=b_obs, threshold=self.h.dead_hurdle) 323 | 324 | self.qf1_dormant_ratio = dormant_ratio(self.qf1, in1=b_obs, in2=b_actions) 325 | self.qf2_dormant_ratio = dormant_ratio(self.qf2, in1=b_obs, in2=b_actions) 326 | self.actor_dormant_ratio = dormant_ratio(self.actor, in1=b_obs) 327 | 328 | # Record end time, wait for all cuda threads to sync and calc time in seconds 329 | self.chronos_end.record() 330 | torch.cuda.synchronize() 331 | chronos_total = (self.chronos_start.elapsed_time(self.chronos_end) * 0.001) 332 | 333 | return updated, chronos_total 334 | 335 | 336 | def save(self, path='./checkpoints/'): 337 | 338 | path = path + self.name + '.pt' 339 | models = {} 340 | 341 | models['Q1'] = self.qf1.state_dict() 342 | models['Q2'] = self.qf2.state_dict() 343 | models['Actor'] = self.actor.state_dict() 344 | 345 | torch.save(models, path) 346 | 347 | def load(self, path='./checkpoints/'): 348 | path = path + self.name + '.pt' 349 | models_file = torch.load(path) 350 | 351 | self.qf1.load_state_dict(models_file['Q1']) 352 | self.qf2.load_state_dict(models_file['Q2']) 353 | self.actor.load_state_dict(models_file['Actor']) -------------------------------------------------------------------------------- /agents/sac_droq.py: -------------------------------------------------------------------------------- 1 | import time, math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from types import SimpleNamespace 6 | from .utils import avg_weight_magnitude, count_dead_units, init_model 7 | from .buffers import ReplayBufferSAC 8 | 9 | ''' 10 | Based on https://github.com/vwxyzjn/cleanrl 11 | SAC agent https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py 12 | 13 | DroQ extensions to SAC 
based on: 14 | - Original paper: https://openreview.net/forum?id=xCVJMsPv3RT 15 | – Berkely implemetation on a quadruped (Walk In the Park (WITP)): https://arxiv.org/abs/2208.07860 16 | 17 | Plasticity metrics and regularisation influnced by the following papers: 18 | - Maintaining Plasticity in Continual Learning via Regenerative Regularization https://arxiv.org/abs/2308.11958 19 | - Loss of Plasticity in Deep Continual Learning https://arxiv.org/abs/2306.13812 20 | - Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier https://openreview.net/forum?id=OpC-9aBBVJe 21 | – Bigger, Better, Faster: Human-level Atari with human-level efficiency https://arxiv.org/abs/2305.19452 22 | ''' 23 | 24 | 25 | class SoftQNetwork(nn.Module): 26 | def __init__(self, obs_dim, action_dim, hidden_dim=256): 27 | super().__init__() 28 | 29 | self.mlp = nn.Sequential( 30 | nn.Linear(obs_dim + action_dim, hidden_dim), 31 | nn.Dropout(0.01), # Droq uses 0.01 for dropout 32 | nn.LayerNorm(hidden_dim), 33 | nn.ReLU(), 34 | nn.Linear(hidden_dim,hidden_dim), 35 | nn.Dropout(0.01), 36 | nn.LayerNorm(hidden_dim), 37 | nn.ReLU(), 38 | nn.Linear(hidden_dim, 1) 39 | ) 40 | 41 | # WITP and DroQ implementations both use xavier_uniform initialisations 42 | # Correct initialisation is important for stable training as replay ratio increases 43 | init_model(self.mlp, init_method='xavier_uniform_') 44 | 45 | def forward(self, x, a): 46 | x = torch.cat([x, a], 1) 47 | x = self.mlp(x) 48 | return x 49 | 50 | 51 | class Actor(nn.Module): 52 | def __init__(self, env_spec, hidden_dim=256): 53 | super().__init__() 54 | obs_dim = env_spec['obs_dim'] 55 | act_dim = env_spec['act_dim'] 56 | 57 | self.fc0 = nn.Linear(obs_dim, hidden_dim) 58 | self.fc1 = nn.Linear(hidden_dim, hidden_dim) 59 | self.fc_mean = nn.Linear(hidden_dim, act_dim) 60 | self.fc_logstd = nn.Linear(hidden_dim, act_dim) 61 | self.nonlin = nn.ReLU() # If the non-linearity has state (e.g. 
trainable parameter) we need one per use 62 | 63 | action_high = nn.Parameter(torch.tensor(env_spec['act_max']), requires_grad=False) 64 | action_low = nn.Parameter(torch.tensor(env_spec['act_min']), requires_grad=False) 65 | self.action_scale = nn.Parameter((action_high - action_low) * 0.5, requires_grad=False) 66 | self.action_bias = nn.Parameter((action_high + action_low) * 0.5, requires_grad=False) 67 | self.log_std_max = 2 68 | self.log_std_min = -5 69 | 70 | def forward(self, x): 71 | x = self.nonlin(self.fc0(x)) 72 | x = self.nonlin(self.fc1(x)) 73 | mean = self.fc_mean(x) 74 | log_std = self.fc_logstd(x) 75 | log_std = torch.tanh(log_std) 76 | log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) # From SpinUp / Denis Yarats 77 | return mean, log_std 78 | 79 | def get_action(self, x): 80 | mean, log_std = self(x) 81 | std = log_std.exp() 82 | normal = torch.distributions.Normal(mean, std, validate_args=False) # validation forces a cuda<>cpu sync 83 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 84 | y_t = torch.tanh(x_t) 85 | action = y_t * self.action_scale + self.action_bias 86 | log_prob = normal.log_prob(x_t) 87 | 88 | # Enforcing Action Bound 89 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) 90 | log_prob = log_prob.sum(1, keepdim=True) 91 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 92 | return action, log_prob, mean 93 | 94 | 95 | class Agent: 96 | def __init__(self, 97 | env_spec, 98 | buffer_size=int(1e6), 99 | num_env=1, 100 | device='cpu', 101 | seed=42, 102 | rr=20, # DroQ paper 103 | q_lr=1e-3, 104 | actor_lr=3e-4, 105 | alpha_lr=1e-3, 106 | ): 107 | 108 | # Make global 109 | self.name = "sac_droq" # name for logging 110 | self.obs_dim = env_spec['obs_dim'] # environment inputs for agent 111 | self.action_dim = env_spec['act_dim'] # agent outputs to environment 112 | self.act_max = env_spec['act_max'] # action range, scalar or vector 113 | self.act_min = env_spec['act_min'] # action range, scalar or vector 114 | self.device = device # gpu or cpu 115 | 116 | # All seeds default to 42 117 | torch.manual_seed(torch.tensor(seed)) 118 | torch.backends.cudnn.deterministic = True 119 | # torch.cuda.set_sync_debug_mode(1) # Set to 1 to receive warnings 120 | 121 | # Hyperparameters 122 | hyperparameters = { 123 | "gamma" : 0.99, # (def: 0.99) Discount factor 124 | "q_lr" : q_lr, # (def: 1e-3) Q learning rate 125 | "a_lr" : actor_lr, # (def: 3e-4) Policy learning rate 126 | "alpha_lr" : alpha_lr, # (def: 1e-3) alpha auto entropoty tuning learning rate 127 | "learn_start" : int(5e3), # (def: 5e3) Start updating policies after this many global steps 128 | "batch_size" : 256, # (def: 256) Batch size of sample from replay buffer 129 | "policy_freq" : 2, # (def: 2) the frequency of training policy (delayed) 130 | "target_net_freq" : 1, # (def: 1) Denis Yarats' implementation delays this by 2 131 | "tau" : 0.005, # (def: 0.005) target smoothing coefficient 132 | "dead_hurdle" : 0.01, # (def: 0.01) units with greater variation in output over one batch of data than this are not dead in plasticity terms 133 | "a_hidden_dim" : 256, # (def: 256) size of actor's hidden layer(s) 134 | "q_hidden_dim" : 256, # (def: 256) size of Q's hidden layer(s) 135 | "q_max_grad_norm" : 1000.0, # (def: 1000) qf maximum norm for the gradient clipping 136 | "a_max_grad_norm" : 1000.0, # (def: 1000) actor and alpha maximum norm for the gradient clipping 137 | "replay_ratio" : round(rr), 138 | } 139 | 
self.h = SimpleNamespace(**hyperparameters) 140 | 141 | # Loggin & debugging 142 | self.qf1_a_values = torch.tensor([0.0]) 143 | self.qf2_a_values = torch.tensor([0.0]) 144 | self.qf1_loss = 0 145 | self.qf2_loss = 0 146 | self.qf_loss = 0 147 | self.actor_loss = 0 148 | self.alpha_loss = 0 149 | self.actor_avg_wgt_mag = 0 # average weight magnitude of model parameters 150 | self.qf1_avg_wgt_mag = 0 # average weight magnitude of model parameters 151 | self.qf2_avg_wgt_mag = 0 # average weight magnitude of model parameters 152 | self.actor_dead_pct = 0 # percentage of units which are dead by some threshold 153 | self.qf1_dead_pct = 0 # percentage of units which are dead by some threshold 154 | self.qf2_dead_pct = 0 # percentage of units which are dead by some threshold 155 | 156 | # Instantiate actor and Q networks, optimisers 157 | # AdamW (may have) resulted in more stable training than Adam 158 | self.qf1 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 159 | self.qf2 = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 160 | self.qf1_target = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 161 | self.qf2_target = SoftQNetwork(self.obs_dim, self.action_dim, self.h.q_hidden_dim).to(device) 162 | self.q_optim = torch.optim.Adam(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.h.q_lr) 163 | self.qf1_target.load_state_dict(self.qf1.state_dict()) 164 | self.qf2_target.load_state_dict(self.qf2.state_dict()) 165 | 166 | self.actor = Actor(env_spec, self.h.a_hidden_dim).to(device) 167 | self.actor_optim = torch.optim.Adam(list(self.actor.parameters()), lr=self.h.a_lr) 168 | init_model(self.actor, init_method='xavier_uniform_') # Correct initialisation essential for stable training with increasing replay ratio 169 | 170 | # Use automatic entropy tuning 171 | # WITP initialises alpha = 0.1, this seems to help stabilise training 172 | self.target_entropy = -torch.prod(torch.Tensor((self.action_dim,)).to(device)).item() 173 | self.log_alpha = torch.tensor((math.log(0.1)), requires_grad=True, device=device) 174 | self.alpha = self.log_alpha.exp().item() 175 | self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.h.alpha_lr) 176 | 177 | # Storage setup 178 | self.rb = ReplayBufferSAC(self.obs_dim, self.action_dim, buffer_size, num_env, device=self.device) 179 | self.global_step = 0 180 | 181 | 182 | def choose_action(self, obs): 183 | # Random uniform actions before learn_start can speed up training over using the agent's inital randomness. 
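A small worked example of the warm-up branch below, assuming a hypothetical 2-D action space bounded by act_min = -2 and act_max = +2 (so scale = 2, bias = 0): uniform noise in [-1, 1) is pushed through the same scale/bias the actor uses, so early exploration covers the whole action box rather than clustering around the untrained actor's outputs.

```python
import torch

action_scale = torch.tensor([2.0, 2.0])        # (act_max - act_min) / 2, hypothetical bounds
action_bias  = torch.tensor([0.0, 0.0])        # (act_max + act_min) / 2

n_envs = 5
u = (torch.rand((n_envs, 2)) - 0.5) * 2.0      # uniform in [-1, 1), shape (batch, action_dim)
action = u * action_scale + action_bias        # uniform over the full [-2, 2) box
print(action.min().item(), action.max().item())
```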
184 | if self.global_step < self.h.learn_start: 185 | # actions are rand_uniform of shape (obs_batch_size, action_dim) 186 | action = (torch.rand((obs.size(0), self.action_dim), device=self.device) - 0.5) * 2.0 # rand_uniform -1..+1 187 | action = action * self.actor.action_scale + self.actor.action_bias # apply scale and bias 188 | else: 189 | with torch.no_grad(): 190 | action, _, _ = self.actor.get_action(obs) 191 | self.rb.store_choice(obs, action) 192 | return action 193 | 194 | def store_transition(self, reward, done): 195 | self.rb.store_transition(reward, done) 196 | self.global_step += 1 197 | 198 | 199 | def update(self): 200 | 201 | ''' Call every step from learning script, agent decides if it is time to update ''' 202 | 203 | # Bookeeping 204 | updated = False 205 | chronos_total = 0.0 206 | 207 | if self.global_step > self.h.learn_start: 208 | updated = True 209 | chronos_start = time.time() 210 | 211 | for replay in range(0,self.h.replay_ratio): 212 | 213 | b_obs, b_actions, b_obs_next, b_rewards, b_dones = self.rb.sample(self.h.batch_size) 214 | 215 | with torch.no_grad(): 216 | next_state_actions, next_state_log_pi, _ = self.actor.get_action(b_obs_next) 217 | qf1_next_target = self.qf1_target(b_obs_next, next_state_actions) 218 | qf2_next_target = self.qf2_target(b_obs_next, next_state_actions) 219 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi 220 | next_q_value = b_rewards.flatten() + (1 - b_dones.flatten()) * self.h.gamma * (min_qf_next_target).view(-1) 221 | 222 | self.qf1_a_values = self.qf1(b_obs, b_actions).view(-1) 223 | self.qf2_a_values = self.qf2(b_obs, b_actions).view(-1) 224 | self.qf1_loss = F.mse_loss(self.qf1_a_values, next_q_value) 225 | self.qf2_loss = F.mse_loss(self.qf2_a_values, next_q_value) 226 | self.qf_loss = self.qf1_loss + self.qf2_loss 227 | 228 | self.q_optim.zero_grad() 229 | self.qf_loss.backward() 230 | nn.utils.clip_grad_norm_(self.qf1.parameters(), self.h.q_max_grad_norm) # Almost certainly unecessary 231 | nn.utils.clip_grad_norm_(self.qf2.parameters(), self.h.q_max_grad_norm) 232 | self.q_optim.step() 233 | 234 | # update the target networks within repay loop 235 | if self.global_step % self.h.target_net_freq == 0: 236 | for param, target_param in zip(self.qf1.parameters(), self.qf1_target.parameters()): 237 | target_param.data.copy_(self.h.tau * param.data + (1.0 - self.h.tau) * target_param.data) 238 | for param, target_param in zip(self.qf2.parameters(), self.qf2_target.parameters()): 239 | target_param.data.copy_(self.h.tau * param.data + (1.0 - self.h.tau) * target_param.data) 240 | 241 | # Replay Ratio does not apply to actor nor to alpha 242 | # update actor network and alpha parameter 243 | b_obs, _, _, _, _ = self.rb.sample(self.h.batch_size) 244 | 245 | if self.global_step % self.h.policy_freq == 0: # TD 3 Delayed update support 246 | for _ in range(self.h.policy_freq): # compensate for the delay by doing 'actor_update_interval' instead of 1 247 | pi, log_pi, _ = self.actor.get_action(b_obs) 248 | 249 | self.qf1.eval() 250 | self.qf2.eval() 251 | qf1_pi = self.qf1(b_obs, pi) 252 | qf2_pi = self.qf2(b_obs, pi) 253 | self.qf1.train() 254 | self.qf2.train() 255 | 256 | min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1) 257 | self.actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() 258 | 259 | self.actor_optim.zero_grad() 260 | self.actor_loss.backward() 261 | nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.h.a_max_grad_norm) # almost certainly unecessary 262 | 
self.actor_optim.step() 263 | 264 | # Autotune alpha 265 | with torch.no_grad(): 266 | _, log_pi, _ = self.actor.get_action(b_obs) 267 | self.alpha_loss = (-self.log_alpha.exp() * (log_pi + self.target_entropy)).mean() 268 | 269 | self.alpha_optim.zero_grad() 270 | self.alpha_loss.backward() 271 | nn.utils.clip_grad_norm_([self.log_alpha], self.h.a_max_grad_norm) # almost certainly unecessary 272 | self.alpha_optim.step() 273 | self.alpha = self.log_alpha.exp().detach().clone() 274 | 275 | # Plasticity metrics occasionally 276 | if self.global_step % 2048 == 0 or self.global_step == self.h.learn_start: 277 | self.actor_avg_wgt_mag = avg_weight_magnitude(self.actor) 278 | self.qf1_avg_wgt_mag = avg_weight_magnitude(self.qf1) 279 | self.qf2_avg_wgt_mag = avg_weight_magnitude(self.qf2) 280 | 281 | b_obs, b_actions = self.rb.plasticity_data(2048) # a representative sample 282 | _, _, self.actor_dead_pct = count_dead_units(self.actor, in1=b_obs, threshold=self.h.dead_hurdle) 283 | _, _, self.qf1_dead_pct = count_dead_units(self.qf1, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 284 | _, _, self.qf2_dead_pct = count_dead_units(self.qf2, in1=b_obs, in2=b_actions, threshold=self.h.dead_hurdle) 285 | 286 | chronos_total = time.time() - chronos_start 287 | 288 | return updated, chronos_total 289 | -------------------------------------------------------------------------------- /agents/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | ''' Utility functions for agents ''' 6 | 7 | # symlog() and symexp() are alternatives to avg/mean running input normalisation 8 | # Dreamer V3 https://arxiv.org/abs/2301.04104 9 | @torch.jit.script 10 | def symlog(x): 11 | return torch.sign(x) * torch.log(torch.abs(x)+1) 12 | 13 | @torch.jit.script 14 | def symexp(x): 15 | return torch.sign(x) * torch.exp(torch.abs(x)-1) 16 | 17 | 18 | def init_layer(layer, std=torch.sqrt(torch.tensor(2)), bias_const=0.0): 19 | torch.nn.init.orthogonal_(layer.weight, std) 20 | if bias_const is None: 21 | layer.bias = None 22 | else: 23 | torch.nn.init.constant_(layer.bias, bias_const) 24 | return layer 25 | 26 | 27 | def init_model(model, init_method): 28 | """ 29 | Initialize the weights and biases of a PyTorch model's linear layers using the specified initialization method. 30 | 31 | Args: 32 | model (nn.Module): The PyTorch model to initialize. 33 | init_method (str): The initialization method to apply to the model's linear layers. 34 | Should be one of 'xavier_uniform_', 'kaiming_uniform_', or 'orthogonal_'. 
35 | 36 | Returns: 37 | None 38 | """ 39 | # iterate over the model's linear layers and initialize their weights and biases 40 | for layer in model.modules(): 41 | if isinstance(layer, nn.Linear): 42 | # initialize the weights and biases using the specified method 43 | if init_method == 'xavier_uniform_': 44 | nn.init.xavier_uniform_(layer.weight) 45 | if layer.bias is not None: nn.init.zeros_(layer.bias) 46 | elif init_method == 'kaiming_uniform_': 47 | nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='linear') 48 | if layer.bias is not None: nn.init.zeros_(layer.bias) 49 | elif init_method == 'orthogonal_': 50 | nn.init.orthogonal_(layer.weight) 51 | if layer.bias is not None: nn.init.zeros_(layer.bias) 52 | 53 | 54 | # Credit GPT4 55 | def avg_weight_magnitude(model): 56 | """ Average weight magnitude of a models parameters """ 57 | total_weight_sum = 0.0 58 | total_weight_count = 0 59 | 60 | for param in model.parameters(): 61 | total_weight_sum += torch.sum(torch.abs(param)) 62 | total_weight_count += param.numel() 63 | 64 | return total_weight_sum / total_weight_count 65 | 66 | 67 | # GPT4 helpful as ever 68 | def count_dead_units(model, in1, in2=None, threshold=0.01): 69 | """ Counts dead units based on observed range of activation for various types """ 70 | 71 | activations = [] 72 | 73 | def hook_fn(module, input, output): 74 | activations.append(output) 75 | 76 | # Register hooks for all layers 77 | hooks = [] 78 | for layer in model.modules(): 79 | if isinstance(layer, (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.LeakyReLU, nn.ELU, nn.SiLU)): 80 | hooks.append(layer.register_forward_hook(hook_fn)) 81 | 82 | # Initialize max and min values for each layer's activations 83 | max_values = [] 84 | min_values = [] 85 | 86 | model.eval() # dont mess with the model if it has e.g. 
batchnorm, dropout etc 87 | if in2 is None: 88 | with torch.no_grad(): 89 | model(in1) 90 | else: 91 | with torch.no_grad(): 92 | model(in1, in2) 93 | model.train() # put it back how it was 94 | 95 | # Directly compute the max and min values for each activation unit 96 | max_values = [act.max(dim=0)[0] for act in activations] 97 | min_values = [act.min(dim=0)[0] for act in activations] 98 | 99 | # Count dead units based on observed activation range 100 | dead_units_count = 0 101 | total_units_count = 0 102 | 103 | for max_val, min_val in zip(max_values, min_values): 104 | dead_range = (max_val - min_val) < threshold 105 | dead_units_count += dead_range.sum() 106 | total_units_count += max_val.numel() # Since max_val and min_val have the same size, we can use either for counting 107 | 108 | # Clean up hooks 109 | for hook in hooks: 110 | hook.remove() 111 | 112 | dead_percentage = (dead_units_count / total_units_count) * 100.0 113 | return dead_units_count, total_units_count, dead_percentage 114 | 115 | 116 | # Credit GPT4: Class to store initial parameters and compute L2 Init loss 117 | # https://arxiv.org/abs/2308.11958 118 | class L2NormRegularizer(): 119 | def __init__(self, model, lambda_reg, device): 120 | self.device = device 121 | self.lambda_reg = lambda_reg 122 | self.initial_params = torch.empty(0, device=device) 123 | 124 | # A tensor of model parameters 125 | for par in model.parameters(): 126 | self.initial_params = torch.cat((self.initial_params, par.view(-1).detach().clone()), dim=0) 127 | 128 | # Add this loss to the model's usual loss 129 | def __call__(self, model): 130 | current_params = torch.empty(0, device=self.device) 131 | 132 | for par in model.parameters(): 133 | current_params = torch.cat((current_params, par.view(-1)), dim=0) 134 | 135 | l2_loss = torch.linalg.vector_norm((current_params - self.initial_params), ord=2) 136 | return self.lambda_reg * l2_loss 137 | 138 | 139 | # DrM: Dormant Ratio Minimisation https://arxiv.org/abs/2310.19668 140 | # https://github.com/XuGW-Kevin/DrM/blob/main/utils.py 141 | # "a sharp decline in the dormant ratio of an agent’s policy network serves as an intrinsic indicator of the agent executing meaningful actions for exploration" 142 | # see dormant neurons here also: https://arxiv.org/abs/2302.12902 143 | # to get dormant neurons from start of training like paper needed 1. nn.RuLE(inplace=True) & 2. no BatchNorm() 144 | class LinearOutputHook: 145 | def __init__(self): 146 | self.outputs = [] 147 | 148 | def __call__(self, module, module_in, module_out): 149 | self.outputs.append(module_out) 150 | 151 | 152 | def dormant_ratio(model, in1, in2=None, percentage=0.10): 153 | hooks = [] 154 | hook_handlers = [] 155 | total_neurons = 0 156 | dormant_neurons = 0 157 | 158 | for _, module in model.named_modules(): 159 | if isinstance(module, nn.Linear): 160 | hook = LinearOutputHook() 161 | hooks.append(hook) 162 | hook_handlers.append(module.register_forward_hook(hook)) 163 | 164 | model.eval() # dont mess with the model if it has e.g. 
batchnorm, dropout etc 165 | if in2 is None: 166 | with torch.no_grad(): 167 | model(in1) 168 | else: 169 | with torch.no_grad(): 170 | model(in1, in2) 171 | model.train() # put it back how it was 172 | 173 | for module, hook in zip((module for module in model.modules() if isinstance(module, nn.Linear)), hooks): 174 | with torch.no_grad(): 175 | for output_data in hook.outputs: 176 | mean_output = output_data.abs().mean(0) 177 | avg_neuron_output = mean_output.mean() 178 | dormant_indices = (mean_output < avg_neuron_output * percentage).nonzero(as_tuple=True)[0] 179 | total_neurons += module.weight.shape[0] 180 | dormant_neurons += len(dormant_indices) 181 | 182 | for hook in hooks: 183 | hook.outputs.clear() 184 | 185 | for hook_handler in hook_handlers: 186 | hook_handler.remove() 187 | 188 | return dormant_neurons / total_neurons 189 | 190 | # https://sites.google.com/view/rl-vcse 191 | # https://github.com/kingdy2002/VCSE/blob/main/VCSE_SAC/vcse.py 192 | ''' 193 | Usage: 194 | 1. Add another pair of critics to SAC 195 | 2. One pair trains without receiving vcse intrinsic reward. Use this pair to feed value to vcse 196 | 3. Other pair trains with vcse intrinsic reward (add to usual reward). use this pair to train actor. 197 | ''' 198 | class VCSE(object): 199 | """particle-based entropy based on knn normalized by running mean """ 200 | def __init__(self, knn_k=12, beta=0.1, device='cpu'): 201 | self.knn_k = knn_k 202 | self.device = device 203 | self.beta = beta # Tuning factor from paper 204 | 205 | def __call__(self, state, value): 206 | # value => [b1 , 1] 207 | # state => [b1 , c] 208 | # z => [b1, c+1] 209 | # [b1] => [b1,b1] 210 | ds = state.size(1) 211 | source = target = state 212 | b1, b2 = source.size(0), target.size(0) 213 | # (b1, 1, c+1) - (1, b2, c+1) -> (b1, 1, c+1) - (1, b2, c+1) -> (b1, b2, c+1) -> (b1, b2) 214 | sim_matrix_s = torch.norm(source[:, None, :].view(b1, 1, -1) - target[None, :, :].view(1, b2, -1), dim=-1, p=2) 215 | 216 | source = target = value 217 | # (b1, 1, 1) - (1, b2, 1) -> (b1, 1, 1) - (1, b2, 1) -> (b1, b2, 1) -> (b1, b2) 218 | sim_matrix_v = torch.norm(source[:, None, :].view(b1, 1, -1) - target[None, :, :].view(1, b2, -1), dim=-1, p=2) 219 | 220 | sim_matrix = torch.max(torch.cat((sim_matrix_s.unsqueeze(-1),sim_matrix_v.unsqueeze(-1)),dim=-1),dim=-1)[0] 221 | eps, index = sim_matrix.topk(self.knn_k, dim=1, largest=False, sorted=True) # (b1, k) 222 | 223 | state_norm, index = sim_matrix_s.topk(self.knn_k, dim=1, largest=False, sorted=True) # (b1, k) 224 | value_norm, index = sim_matrix_v.topk(self.knn_k, dim=1, largest=False, sorted=True) # (b1, k) 225 | 226 | eps = eps[:, -1] #k-th nearest distance 227 | eps = eps.reshape(-1, 1) # (b1, 1) 228 | 229 | state_norm = state_norm[:, -1] #k-th nearest distance 230 | state_norm = state_norm.reshape(-1, 1) # (b1, 1) 231 | 232 | value_norm = value_norm[:, -1] #k-th nearest distance 233 | value_norm = value_norm.reshape(-1, 1) # (b1, 1) 234 | 235 | sim_matrix_v = sim_matrix_v < eps 236 | n_v = torch.sum(sim_matrix_v,dim=1,keepdim = True) # (b1,1) 237 | 238 | sim_matrix_s = sim_matrix_s < eps 239 | n_s = torch.sum(sim_matrix_s,dim=1,keepdim = True) # (b1,1) 240 | reward = torch.digamma((n_v+1).to(torch.float)) / ds + torch.log(eps * 2 + 0.00001) 241 | return reward * self.beta, n_v,n_s, eps, state_norm, value_norm 242 | 243 | 244 | 245 | class LowPassSinglePole(): 246 | ''' Classic simple filter ''' 247 | def __init__(self, decay=0.6, vec_dim=None): 248 | self.b = 1.0 - decay 249 | 250 | if vec_dim is None: 
251 | self.y = 0.0 252 | else: 253 | self.y = np.zeros(vec_dim) 254 | 255 | def filter(self, x): 256 | ''' filter a scalar or a vector of independent channels ''' 257 | self.y += self.b * (x - self.y) 258 | return self.y -------------------------------------------------------------------------------- /hypertune.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | import os, time, math 4 | import torch as th 5 | from torch.utils.tensorboard import SummaryWriter 6 | from multiprocessing import Process, Queue 7 | from bayes_opt import BayesianOptimization, UtilityFunction 8 | from utils import log_scalars 9 | from agents.sac_crossq_trace import Agent 10 | 11 | ''' 12 | Hyperparameter tuning using bayesian optimisation 13 | - Bayes Opt from https://github.com/fmfn/BayesianOptimization 14 | – Simultenous multiprocessing with max_workers and configurable cpu and cuda workers 15 | - Median pruner, based on https://araffin.github.io/post/hyperparam-tuning/ 16 | ''' 17 | 18 | 19 | class Pruner(): 20 | ''' Assumes target_metric is reward related therefore maximise to positive is better ''' 21 | def __init__(self): 22 | self.window = 0.1 # allows this far below the median before pruning 23 | self.target_metrics_history = {} 24 | 25 | def decide(self, target_metric, prune_chkpoint_idx): 26 | 27 | # Ensure there is a history for the current checkpoint 28 | if prune_chkpoint_idx not in self.target_metrics_history: 29 | self.target_metrics_history[prune_chkpoint_idx] = [] 30 | 31 | # Only add this target_metric to the list if it is above the existing median, or the median will decline with time 32 | metrics = self.target_metrics_history[prune_chkpoint_idx] 33 | if len(metrics) == 0 or target_metric > np.median(metrics): 34 | self.target_metrics_history[prune_chkpoint_idx].append(target_metric) 35 | 36 | # Prune if below (median - x%) for this checkpoint - less aggressive pruning, since RL is noisy 37 | prune = len(metrics) > 0 and target_metric < (np.median(metrics) - np.median(np.abs(metrics)) * self.window) 38 | 39 | # return prune # returns true if process should be pruned, false if process keeps going 40 | return False # returns true if process should be pruned, false if process keeps going 41 | 42 | 43 | def run_env(cmd_queue, sample_queue, res_queue, current_time, environment, env_name, random_seed, device, process_num, prune_chkpoints): 44 | 45 | # Consider target metric carefully, depends on env and impacts pruning considerations 46 | # Scoreboard is cleared between pruning checkpoints 47 | # target_metric = lambda: np.median(scoreboard) 48 | # target_metric = lambda: np.mean(scoreboard) 49 | target_metric = lambda: np.mean(np.sort(scoreboard)[int(len(scoreboard)*0.25):int(len(scoreboard)*0.75)]) # interquartile mean (iqm) 50 | 51 | print("\nSTARTED WORKER PROCESS: ", process_num) 52 | print("TIME: ", current_time, "ENV: ", env_name, "DEVICE: ", device) 53 | 54 | 55 | # Receive and run through sample points 56 | while True: 57 | 58 | # Wait here for next point to sample 59 | sample_point = sample_queue.get(block=True) 60 | next_point = sample_point["next_sample"] 61 | sample_num = sample_point["sample_num"] 62 | r_seed_mixed = random_seed + sample_num # because bayes-optim can pick duplicate sample points, ensure each is actually different 63 | 64 | # Unknown reason cannot use sub-dictionary directly, have to put in new variable 65 | run_steps = environment['run_steps'] 66 | assert run_steps % prune_chkpoints 
== 0, "run_steps must be an exact multiple of prune_chkpoints" 67 | 68 | # Create environment 69 | env = gym.make(env_name) 70 | env_spec = { 71 | 'act_dim' : env.action_space.shape[0], 72 | 'obs_dim' : env.observation_space.shape[0], 73 | 'act_max' : env.action_space.high, 74 | 'act_min' : env.action_space.low, 75 | } 76 | obs, info = env.reset(seed=r_seed_mixed) 77 | score = np.zeros(1) # total rewards for episode 78 | scoreboard = np.zeros(1) # array of scores 79 | 80 | # Create agent with pre-initialisation hyperparameters 81 | # Remember math.pow(10, p) for parameters using log ranges in hyperparameter_bounds 82 | agent = Agent(env_spec, 83 | buffer_size = run_steps, 84 | device = device, 85 | seed = random_seed, 86 | # rr = next_point['replay_ratio'], 87 | # q_lr = math.pow(10, next_point['q_lr']), 88 | # actor_lr = math.pow(10, next_point['actor_lr']), 89 | # alpha_lr = math.pow(10, next_point['alpha_lr']), 90 | ) 91 | 92 | # Hyperparameters post agent initialisation, modified after instantiation 93 | # agent.h.gamma = next_point['gamma'] 94 | # agent.h.gae_lambda = next_point['gae_lambda'] 95 | # agent.h.clip_coef = next_point['clip_coef'] 96 | # agent.h.ent_coef = next_point['ent_coef'] 97 | # agent.h.vf_coef = next_point['vf_coef'] 98 | # agent.h.max_grad_norm = next_point['max_grad_norm'] 99 | # agent.h.max_kl = next_point['max_kl'] 100 | # agent.h.update_epochs = int(next_point['update_epochs']) 101 | # agent.h.mb_size = int(next_point['mb_size']) 102 | 103 | # New log for this environment 104 | writer = SummaryWriter(f"runs/hypertune/{current_time}/{env_name}/{agent.name}/{str(sample_num)}") 105 | writer.add_text("hyperparameters","seed: " + str(r_seed_mixed) + "\n\n" + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(agent.h).items()])),) 106 | 107 | # Main loop: run environment for run_steps steps 108 | sps_timer = time.time() 109 | update_step_counter = 0 110 | prune_chkpoint = 0 111 | for step in range(run_steps): 112 | 113 | # Step the environment and collect observations, shape: (batch, channels) 114 | obs_th = th.tensor(obs, device=device, dtype=th.float32).unsqueeze(0) 115 | action_th = agent.choose_action(obs_th) 116 | action_np = action_th.cpu().squeeze().numpy() 117 | obs, reward, terminated, truncated, info = env.step(action_np) 118 | done_np = (terminated or truncated) 119 | done_th = th.tensor(done_np, device=device, dtype=th.bool).unsqueeze(0) 120 | reward_th = th.tensor(reward, device=device, dtype=th.float32).unsqueeze(0).unsqueeze(0) 121 | agent.store_transition(reward_th, done_th) 122 | 123 | # Episodic score 124 | score += reward 125 | 126 | # Track episodic score 127 | if done_np: 128 | scoreboard = np.append(scoreboard, score) 129 | obs, info = env.reset(seed=r_seed_mixed) 130 | writer.add_scalar(f"score/{env_name}", score, step) 131 | writer.add_scalar(f"score/target_metric", target_metric(), step) 132 | score = np.zeros(1) 133 | 134 | # Call at every step, agent decides if an update is due 135 | updated, update_time = agent.update() 136 | 137 | # Track samples per second 138 | if step % 2048 == 0 and step != 0: 139 | sps = 2048 / (time.time() - sps_timer) 140 | sps_timer = time.time() 141 | writer.add_scalar("perf/SPS", sps, step) 142 | print(env_name,'\t',sample_num, '\tStep: ',step, '\tSPS: %0.1f' % sps, '\tTarget Metric: %0.1f' % target_metric()) 143 | 144 | # Log agent update metrics, but not too often 145 | update_step_counter += 1 146 | if updated and update_step_counter >= 2048: 147 | writer.add_scalar("perf/Update", 
update_time, step) 148 | log_scalars(writer, agent, step) 149 | update_step_counter = 0 150 | 151 | # send interim result back for pruning control except at 0 and final steps 152 | if (step % (run_steps // prune_chkpoints) == 0 and step != 0 and step != (run_steps - 1)): 153 | 154 | result = {'process_num': process_num, 'target_metric': target_metric(), 'sampled_point': next_point, 'prune_chkpoint':prune_chkpoint} 155 | res_queue.put(result) 156 | 157 | # Block until a cmd addressed to this process has been received 158 | while True: 159 | cmd = cmd_queue.get() 160 | if cmd['process_num'] == process_num: 161 | break 162 | 163 | prune_chkpoint += 1 164 | 165 | # Break from running steps if we're pruned, log and go restart with new sample point 166 | if (cmd['process_num'] == process_num) and (cmd['break'] == True): 167 | writer.add_hparams(hparam_dict=next_point, metric_dict={'score/target metric': target_metric()}, run_name=str(sample_num)) 168 | print(env_name,'\t',sample_num, '\tStep: ',step, '\tSPS: %0.1f' % sps, '\tTarget Metric: %0.1f' % target_metric(), '\tPRUNED at ', prune_chkpoint) 169 | break 170 | 171 | # reset scoreboard at checkpoint, after pruner has decided 172 | scoreboard = np.zeros(1) 173 | 174 | # Got to the end without being pruned, log hyperparameters 175 | if step == (run_steps - 1): 176 | writer.add_hparams(hparam_dict=next_point, metric_dict={'score/target metric': target_metric()}, run_name=str(sample_num)) 177 | 178 | # Return target metric to optimiser, indicate this is not a pruning chkpoint 179 | result = {'process_num': process_num, 'target_metric': target_metric(), 'sampled_point': next_point, 'prune_chkpoint':prune_chkpoints + 1} 180 | res_queue.put(result) 181 | 182 | # Tidy up 183 | writer.close() 184 | env.close() 185 | del agent 186 | 187 | 188 | def main(): 189 | 190 | # System parameters 191 | random_seed = 42 # default: 42 192 | num_workers_cpu = 0 # number of workers in cpu, cpu can be faster sometimes (e.g. smaller model sizes & cpu environment) 193 | num_workers_cuda = 8 # number of workers in cuda 194 | max_workers = num_workers_cpu + num_workers_cuda # Max number of workers 195 | prune_chkpoints = 4 # should be divisor of run_steps; this many pruning section per run_steps 196 | os.nice(10) # don't hog the system 197 | th.set_num_threads(2) # usually no faster with more, but parallel runs are more efficient when th.threads=1 (depending on hardware!) 
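# Worked example of the checkpointing arithmetic above, assuming the BipedalWalkerHardcore-v3 setting below:
# with prune_chkpoints = 4 and run_steps = 1e6, each worker reports its target metric back to the optimiser
# every 1e6 / 4 = 250k steps; run_steps must divide evenly by prune_chkpoints or the worker's assert will fail.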
198 | np.random.seed(random_seed) # also given to agent 199 | 200 | # Select only one environment unless reward/target_metric has been scaled 201 | # https://gymnasium.farama.org 202 | environments = {} 203 | 204 | # Box 2D (action range -1..+1) 205 | # environments['LunarLanderContinuous-v2'] = {'run_steps': int(120e3)} 206 | # environments['BipedalWalker-v3'] = {'run_steps': int(400e3)} 207 | environments['BipedalWalkerHardcore-v3'] = {'run_steps': int( 1e6)} 208 | 209 | # Mujoco (action range -1..+1 except Humanoid which is -0.4..+0.4) 210 | # environments['Ant-v4'] = {'run_steps': int(20e3)} 211 | # environments['HalfCheetah-v4'] = {'run_steps': int(400e3)} 212 | # environments['Walker2d-v4'] = {'run_steps': int(400e3)} 213 | # environments['Humanoid-v4'] = {'run_steps': int(200e3)} 214 | 215 | # Start time of all runs 216 | current_time = time.strftime('%j_%H:%M') 217 | 218 | # Register agent hyperparameters to optimise 219 | # Consider math.log10(x) & math.pow(10, x) for large OOM ranges 220 | hyperparameter_bounds = { 221 | # 'gamma' : (0.8, 0.9997), 222 | # 'gae_lambda' : (0.9, 1.0), 223 | # 'clip_coef' : (0.1, 0.3), 224 | # 'ent_coef' : (0.0, 0.01), 225 | # 'vf_coef' : (0.5, 1.0), 226 | # 'max_grad_norm' : (0.1, 1.0), 227 | # 'max_kl' : (0.003, 0.03), 228 | # 'adam_lr' : (0.000005, 0.003), 229 | # 'update_epochs' : (1, 10), # integer 230 | # 'mb_size' : (8, 128), # integer 231 | # 'trace_length' : (4, 512), # integer 232 | # 'latent_dim' : (4, 128) # integer 233 | # 'topk_pct' : (0.001, 0.05), 234 | # 'min_similarity' : (0.5, 1.0), 235 | # 'pivotal_reward' : (0.01, 2.0), 236 | # 'weight_decay' : (0.0, 0.1), 237 | # 'l2init_lambda_q' : (math.log10(1e-5), math.log10(1e-2)), 238 | # 'l2init_lambda_a' : (math.log10(1e-5), math.log10(1e-2)), 239 | 'replay_ratio' : (1, 8), 240 | # 'q_lr' : (math.log10(5e-4), math.log10(3e-3)), 241 | # 'alpha_lr' : (math.log10(5e-4), math.log10(3e-3)), 242 | # 'actor_lr' : (math.log10(6e-5), math.log10(9e-4)), 243 | } 244 | 245 | # Setup bayesian optmiser. RL is noisy, probing duplicate points is valid. Can crash if not allowed. 246 | # Utility function EI is preferred for robustness with noisy objective functions whilst exploring landscape 247 | optimiser = BayesianOptimization(f=None, pbounds=hyperparameter_bounds, allow_duplicate_points=True, random_state=random_seed) 248 | utility = UtilityFunction(kind="ucb") 249 | # utility = UtilityFunction(kind="ei") # hypertune the optim-hyper-parameters! 
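# For reference, the optimiser is driven by a plain suggest -> evaluate -> register ("ask/tell") loop.
# A minimal single-process sketch, where train_and_score() is a hypothetical stand-in for the worker
# processes spawned below (log-ranged bounds are mapped back with math.pow(10, x) inside the worker):
#
#   for _ in range(num_trials):
#       next_point = optimiser.suggest(utility)                   # dict keyed by hyperparameter_bounds
#       target = train_and_score(rr=next_point['replay_ratio'])   # e.g. IQM of episodic scores
#       optimiser.register(params=next_point, target=target)      # feed the result back to the optimiser
#
# The code below runs this same loop across worker processes via sample_queue/res_queue, adding median pruning.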
250 | 251 | # Process and queues bookeeping 252 | processes = [] 253 | sample_queue = Queue() # Send sampling points to workers 254 | res_queue = Queue() # Receive results from workers 255 | cmd_queue = Queue() 256 | sample_count = 0 # bayes-optim sample point number 257 | pruner = Pruner() 258 | 259 | # Start processes for listed environments 260 | # Assumes if different envs that metric is normalised/scaled appropriately 261 | for _ in range(max_workers): 262 | for env_name in environments.keys(): 263 | 264 | # Allocate workers to device 265 | if len(processes) < num_workers_cpu: 266 | device = 'cpu' 267 | else: 268 | device = 'cuda' 269 | 270 | # Stop creating workers when we reach max_workers 271 | if len(processes) >= max_workers: 272 | break 273 | 274 | # Each worker is a multiprocessing process connected back here with queues 275 | p = Process(target=run_env, 276 | kwargs={ 277 | 'cmd_queue': cmd_queue, 278 | 'sample_queue': sample_queue, 279 | 'res_queue': res_queue, 280 | 'current_time': current_time, 281 | 'environment': environments[env_name], 282 | 'env_name': env_name, 283 | 'random_seed': random_seed, 284 | 'device': device, 285 | 'process_num': len(processes), 286 | 'prune_chkpoints': prune_chkpoints, 287 | }) 288 | p.start() 289 | processes.append(p) 290 | 291 | # Fetching multiple suggestions before any results 292 | # are registered generates different suggestions to start with 293 | next_point = optimiser.suggest(utility) 294 | sample_point = {"next_sample":next_point, "sample_num":sample_count} 295 | sample_queue.put(sample_point) 296 | sample_count += 1 297 | 298 | # Optimise forevermore (or ctl-c) 299 | while True: 300 | try: 301 | result = res_queue.get(block=True) # wait until a result is in 302 | 303 | if result['prune_chkpoint'] <= prune_chkpoints: 304 | 305 | prune_this_run = pruner.decide(result['target_metric'], result['prune_chkpoint']) 306 | 307 | if prune_this_run == True: 308 | cmd_queue.put({"process_num":result['process_num'], "break":True}) 309 | 310 | optimiser.register(params=result['sampled_point'], target=result['target_metric']) 311 | 312 | next_point = optimiser.suggest(utility) 313 | sample_point = {"next_sample":next_point, "sample_num":sample_count} 314 | sample_queue.put(sample_point) 315 | sample_count += 1 316 | else: 317 | cmd_queue.put({"process_num":result['process_num'], "break":False}) 318 | 319 | else: 320 | optimiser.register(params=result['sampled_point'], target=result['target_metric']) 321 | 322 | print("\nOPTIMISER SAMPLE: ", sample_count,' ',result['sampled_point'], "TARGET METRIC: ", result['target_metric'], "\n") 323 | 324 | next_point = optimiser.suggest(utility) 325 | sample_point = {"next_sample":next_point, "sample_num":sample_count} 326 | sample_queue.put(sample_point) 327 | sample_count += 1 328 | 329 | except KeyboardInterrupt: 330 | print("Stopping on keyboard interrupt") 331 | break 332 | 333 | # finished, terminate the processes 334 | for p in processes: 335 | p.join() 336 | exit() 337 | 338 | ########################## 339 | if __name__ == '__main__': 340 | main() -------------------------------------------------------------------------------- /learn_simple.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | import os, time, argparse, random 4 | import torch as th 5 | from torch.utils.tensorboard import SummaryWriter 6 | from multiprocessing import Process 7 | import multiprocessing as mp 8 | from utils import log_scalars, Colour 9 
| from agents.cross_opt import Agent 10 | 11 | ''' 12 | Simple script to test and tune continuous agents 13 | Makes few assumptions about the env API, so should be widely compatible 14 | 15 | Features 16 | - Simultaneous multiple environments using multiprocessing 17 | - Multiple random seeds for statistical aggregation 18 | - Tensorboard logging and progress printing to terminal 19 | - Select options as CLI arguments 20 | ''' 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Train agent', 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 25 | 26 | parser.add_argument('--log', action='store_true', help=('Enables Tensorboard logging')) 27 | parser.add_argument('--name', type=str, help=('Name or describe this run in the logs')) 28 | parser.add_argument('--seeds', type=int, default=1, help=("Number of random seeds per environment")) 29 | parser.add_argument('--acc', action='store_true', help=('Use an accelerator')) 30 | parser.add_argument('--baseline', action='store_true', help=('Saves logs in baseline sub-folder for archiving')) 31 | parser.add_argument('--save', action='store_true', help=('Enables saving checkpoints')) 32 | parser.add_argument('--load', action='store_true', help=('Loads previously saved agent')) 33 | parser.add_argument('--quiet', action='store_true', help=('Suppresses per-episode score printing')) 34 | 35 | 36 | # TODO: As needed 37 | # parser.add_argument('--gui', action='store_true', help=('Enables visualisation')) 38 | 39 | return parser.parse_args() 40 | 41 | def run_env(current_time, env, env_name, random_seed, device, args, log_dir): 42 | 43 | # Configure torch in the spawned child process 44 | th.set_num_threads(2) 45 | th.manual_seed(random_seed) 46 | th.backends.cudnn.deterministic = True 47 | 48 | # For some reason the sub-dictionary cannot be used directly, so copy it into a local variable 49 | run_steps = env['run_steps'] 50 | 51 | # Create environment 52 | env = gym.make(env_name) 53 | env_spec = { 54 | 'act_dim' : env.action_space.shape[0], 55 | 'obs_dim' : env.observation_space.shape[0], 56 | 'act_max' : env.action_space.high, 57 | 'act_min' : env.action_space.low, 58 | } 59 | obs, info = env.reset(seed=random_seed) # using the gym/gymnasium convention 60 | score = np.zeros(1) # episodic score 61 | scoreboard = np.zeros(1) # a list of episodic scores 62 | 63 | # Create agent 64 | agent = Agent(env_spec, buffer_size=run_steps, device=device, seed=random_seed) 65 | if args.load: 66 | agent.load() 67 | 68 | print('\n>>> RUNNING ENV: ', env_name, "WITH AGENT: ", agent.name) 69 | print('ACTIONS DIM: ', env_spec['act_dim'], ' OBS DIM: ', env_spec['obs_dim'], "\n") 70 | print(agent.h, '\n') 71 | 72 | # New log for this environment 73 | if args.log: 74 | writer = SummaryWriter(f"{log_dir}/{current_time}/{env_name}/{agent.name}_seed:{random_seed}") 75 | writer.add_text("hyperparameters",str(args.name) + "\n\n" + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(agent.h).items()])),) 76 | 77 | sps_timer = time.time() 78 | update_log_counter = 0 # avoid GB size logs when agent updates frequently 79 | updated = False 80 | 81 | for step in range(run_steps): 82 | 83 | # Step the environment and collect observations, shape: (batch, channels) as tensors 84 | obs_th = th.tensor(obs, device=device, dtype=th.float32).unsqueeze(0) 85 | action_th = agent.choose_action(obs_th) 86 | action_np = action_th.cpu().squeeze().to(dtype=th.float32).numpy() 87 | obs, reward, terminated, truncated, info = env.step(action_np) 88 | done = (terminated 
or truncated) 89 | done_th = th.tensor(done, device=device, dtype=th.int8).unsqueeze(0) 90 | reward_th = th.tensor(reward, device=device, dtype=th.float32).unsqueeze(0).unsqueeze(0) 91 | agent.store_transition(reward_th, done_th) 92 | 93 | # Episodic score 94 | score += reward 95 | 96 | # Track episodic score 97 | if done: 98 | scoreboard = np.append(scoreboard, score) 99 | obs, info = env.reset(seed=random_seed) 100 | if args.log: 101 | writer.add_scalar(f"score/{env_name}", score, step) 102 | if (not args.quiet) and (not args.log): 103 | print(env_name,'-',random_seed, 'Step: ', step, 'Score: %0.1f' % score[0]) 104 | score = np.zeros(1) 105 | 106 | # Call at every step, agent decides if an update if due 107 | updated, update_time = agent.update() 108 | 109 | # If saving checkpoints do it at x% total steps intervals 110 | if args.save and step % int(run_steps * 0.25) == 0: 111 | agent.save() 112 | 113 | # Track samples per second 114 | if step % 1024 == 0 and step != 0: 115 | sps = 1024 / (time.time() - sps_timer) 116 | sps_timer = time.time() 117 | 118 | # ups is indicative, since it's just 1/time_taken, if updates are coupled serially to env interaction 119 | ups = update_time and 1.0 / update_time or 0.0 # return zeros when dividing by zero 120 | 121 | if args.log: 122 | writer.add_scalar("perf/SPS", sps, step) 123 | # Median of latest 25% of scores 124 | print(env_name,'-',random_seed, 'Step: ',step, 'SPS: %0.1f' % sps, 'UPS: %0.1f' % ups, 'Scoreboard Median: %0.1f' % np.median(scoreboard[int(len(scoreboard) * 0.75): len(scoreboard)])) 125 | else: 126 | print(Colour.BLUE,env_name,'-',random_seed, 'Step: ',step, 'SPS: %0.1f' % sps, 'UPS: %0.1f' % ups, Colour.END) 127 | 128 | # Log agent update metrics, but not too often 129 | update_log_counter += 1 130 | if updated and args.log and update_log_counter >= 1024: 131 | writer.add_scalar("perf/Update", update_time, step) 132 | log_scalars(writer, agent, step) 133 | update_log_counter = 0 134 | 135 | # Tidy up when done 136 | print(env_name,'-',random_seed, 'FINISHED. MEDIAN OF FINAL 25pct of SCORES: %0.1f' % np.median(scoreboard[int(len(scoreboard) * 0.75): len(scoreboard)])) 137 | if args.log: writer.close() 138 | env.close() 139 | del agent 140 | 141 | 142 | def main(): 143 | 144 | # System parameters 145 | args = parse_args() 146 | random_seed = 42 # default: 42 147 | max_processes = 2 # small models (e.g. ppo) can have lots in cpu for greater total SPS 148 | th.set_num_threads(2) # threads per process, often 1 is most efficient when using cpu with seed > 1 149 | os.nice(10) # don't hog the system 150 | np.random.seed(random_seed) 151 | random.seed(random_seed) 152 | np.random.seed(random_seed) 153 | th.manual_seed(random_seed) 154 | th.backends.cudnn.deterministic = True 155 | np.set_printoptions(precision=3) 156 | mp.set_start_method('spawn') 157 | 158 | if args.acc: 159 | if th.cuda.is_available(): 160 | device = 'cuda' # cuda is faster when training is costly (large model or dataset) 161 | else: 162 | if th.backends.mps.is_available(): 163 | device = 'mps' # Recent apple macs 164 | else: 165 | print("No accelerator available, exiting.") 166 | exit() 167 | else: 168 | device = 'cpu' # cpu is often faster for PPO (small models) 169 | 170 | if not args.baseline: 171 | log_dir = 'runs' # tensorboard logging dir 172 | else: 173 | log_dir = 'runs/baseline' # baselined agents archive for future comparisons 174 | 175 | # Comment out undesired environments. 
Simultaneous processes: (num_envs * num_seeds) ≤ max_processes 176 | # https://gymnasium.farama.org 177 | environments = {} 178 | 179 | # Box 2D (action range -1..+1) 180 | environments['LunarLanderContinuous-v2'] = {'run_steps': int(120e3)} 181 | environments['BipedalWalker-v3'] = {'run_steps': int(400e3)} 182 | # environments['BipedalWalkerHardcore-v3'] = {'run_steps': int(400e3)} 183 | 184 | # Mujoco (action range -1..+1 except Humanoid which is -0.4..+0.4) 185 | environments['Ant-v4'] = {'run_steps': int(400e3)} 186 | environments['HalfCheetah-v4'] = {'run_steps': int(400e3)} 187 | environments['Walker2d-v4'] = {'run_steps': int(400e3)} 188 | environments['Humanoid-v4'] = {'run_steps': int(400e3)} 189 | environments['HumanoidStandup-v4'] = {'run_steps': int(400e3)} 190 | 191 | # start time of all runs 192 | current_time = time.strftime('%j_%H:%M') 193 | start_runs = time.time() 194 | 195 | processes = [] 196 | total_sleep = 0.0 197 | 198 | print("\nEXPERIMENT NAME: ", args.name) 199 | if args.log: 200 | print(f'>>> TENSORBOARD -> ENABLED IN /{log_dir}/{current_time}') 201 | 202 | # Run a process for each seed of each env 203 | for seed in range(args.seeds): 204 | for env_name in environments.keys(): 205 | 206 | p = Process(target=run_env, 207 | kwargs={ 208 | 'current_time': current_time, 209 | 'env': environments[env_name], 210 | 'env_name': env_name, 211 | 'random_seed': random_seed + seed, 212 | 'device': device, 213 | 'args': args, 214 | 'log_dir': log_dir, 215 | }) 216 | p.start() 217 | processes.append(p) 218 | 219 | # Check if the maximum number of processes is reached 220 | if len(processes) >= max_processes: 221 | # Wait for any of the processes to finish 222 | print(f"\n Running {len(processes)} environments") 223 | print("Waiting for a process to complete before starting next process") 224 | while len(processes) >= max_processes: 225 | for proc in processes: 226 | if not proc.is_alive(): 227 | processes.remove(proc) 228 | time.sleep(0.1) 229 | total_sleep += 0.1 230 | 231 | # Wait for all processes to finish before continuing 232 | for p in processes: 233 | p.join() 234 | 235 | print("\n") 236 | print("Completed runs in %0.3f" % (time.time() - start_runs), "secs") 237 | print("Completed runs in %0.3f" % ((time.time() - start_runs) / 3600), "hours") 238 | print("Log dir: /",log_dir,'/',current_time,'/') 239 | print("\n--name:", args.name) 240 | print("Slept this long waiting for max_processes: %.03f" % total_sleep, ' secs') 241 | 242 | 243 | ########################## 244 | if __name__ == '__main__': 245 | main() 246 | -------------------------------------------------------------------------------- /learn_vectorised.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import torch as th 3 | import numpy as np 4 | import os, time, argparse 5 | from torch.utils.tensorboard import SummaryWriter 6 | from utils import log_scalars 7 | from agents.ppo_baseline import Agent 8 | 9 | ''' 10 | Vectorised gym script to test and tune continous agents 11 | - Uses gym.vector.AsyncVectorEnv 12 | ''' 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='Train agent', 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | 18 | parser.add_argument('--log', action='store_true', help=('Enables Tensorboard logging')) 19 | parser.add_argument('--name', type=str, help=('Name or describe this run in the logs')) 20 | parser.add_argument('--cuda', action='store_true', help=('Use CUDA')) 21 | 
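# --vecs sets how many copies of the environment gym.vector.AsyncVectorEnv runs in parallel;
# observations, rewards and dones then arrive batched with a leading num_vecs dimension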
parser.add_argument('--vecs', type=int, default=1, help=("Number of vectorised environments")) 22 | 23 | # These features are not implemented yet 24 | # parser.add_argument('--gui', action='store_true', help=('Enables visualisation')) 25 | # parser.add_argument('--save', action='store_true', help=('Enables saving checkpoints')) 26 | # parser.add_argument('--load', action='store_true', help=('Loads previously saved agent')) 27 | 28 | return parser.parse_args() 29 | 30 | def main(): 31 | args = parse_args() 32 | random_seed = 42 33 | num_vecs = args.vecs 34 | os.nice(10) 35 | np.random.seed(random_seed) 36 | th.set_num_threads(1) 37 | np.set_printoptions(precision=0) 38 | if args.cuda: 39 | device = 'cuda' # cuda can be faster when training is costly (large model or dataset) 40 | else: 41 | device = 'cpu' # cpu is very often fastest 42 | log_dir = 'runs_vec' # tensorboard logging dir 43 | 44 | # https://gymnasium.farama.org 45 | # comment out undesired environments 46 | environments = {} 47 | environments['LunarLanderContinuous-v2'] = {'run_steps': int(120e3)} 48 | # environments['BipedalWalker-v3'] = {'run_steps': int(400e3)} 49 | # environments['BipedalWalkerHardcore-v3'] = {'run_steps': int(1e6)} 50 | 51 | # start time of all runs 52 | current_time = time.strftime('%j_%H:%M') 53 | 54 | # Run though all environments in list sequentially, each one vectortised num_vecs times 55 | for env_name in environments.keys(): 56 | 57 | # Create environment 58 | envs = [lambda: gym.make(env_name) for i in range(num_vecs)] 59 | vec_env = gym.vector.AsyncVectorEnv(envs) 60 | env_spec = { 61 | 'act_dim' : vec_env.single_action_space.shape[0], 62 | 'obs_dim' : vec_env.single_observation_space.shape[0], 63 | 'act_max' : vec_env.action_space.high, 64 | 'act_min' : vec_env.action_space.low, 65 | } 66 | obs, info = vec_env.reset(seed=random_seed) 67 | score = np.zeros(num_vecs) 68 | score_sum = 0 69 | done_sum = 0 70 | global_step = 0 71 | 72 | print('>>> RUNNING ENV -> ', env_name, " NUM VECS: ",num_vecs) 73 | print('ACTIONS DIM: ', env_spec['act_dim'], ' OBS DIM: ', env_spec['obs_dim']) 74 | 75 | # Create agent 76 | agent = Agent(env_spec, buffer_size=environments[env_name]['run_steps'], num_env=num_vecs, device=device) 77 | print('>>> RUNNING AGENT -> ', agent.name) 78 | 79 | # New log for this environment 80 | if args.log: 81 | print('>>> TENSORBOARD -> ENABLED') 82 | writer = SummaryWriter(f"{log_dir}/{current_time}/{env_name}/{agent.name}") 83 | # if args.name != None: 84 | # writer.add_text(tag=f"score/{env_name}", text_string=args.name, global_step=0) 85 | # writer.add_text(tag=f"score/{env_name}", text_string=str(args.name) + " | "+ str(agent.h), global_step=0) 86 | writer.add_text("hyperparameters",str(args.name) + "\n\n" + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(agent.h).items()])),) 87 | 88 | 89 | sps_timer = time.time() 90 | update_step_counter = 0 # agents that update every step can create GBs of log data 91 | for step in (range(1, environments[env_name]['run_steps'])): 92 | 93 | # Agent receives tensor and sends back tensor shape (batch, channels) 94 | action = agent.choose_action(th.tensor(obs).to(device)) 95 | obs, reward, terminated, truncated, info = vec_env.step(action.cpu().numpy()) 96 | done = np.logical_or(terminated, truncated) 97 | agent.store_transition(th.tensor(reward).to(device), th.tensor(done).to(device)) 98 | global_step += 1 * num_vecs 99 | 100 | # Episodic score 101 | score += reward 102 | 103 | # Track episodic score 104 | if 
done.any(): 105 | score_sum += np.sum(score * done) 106 | done_sum += np.sum(done) 107 | # print(f"Steps: {step}, Score: {score * done}") 108 | if args.log: 109 | score_log = score * done # zero out non-completed episodes 110 | # print("logging score: ", score_log) 111 | for i in score_log: 112 | if np.abs(i) > 0.0: 113 | writer.add_scalar(f"score/{env_name}", i, step) 114 | score = score * (1 - done) # reset to zero done episodes 115 | 116 | # Call at every step, agent decides if an update if due 117 | updated, update_time = agent.update() 118 | 119 | # Track samples per second 120 | if step % 1024 == 0 and step != 0: 121 | sps = (step * num_vecs) / (time.time() - sps_timer) 122 | 123 | # Update on progress 124 | print('\033[4m','On global step: ',global_step, 'Updated in: %0.3f' % update_time, 'secs', 'SPS: %0.1f' % sps,'\033[0m') 125 | if done_sum > 0: 126 | print(f"Done episodes: {done_sum}, Mean Score: %0.0f" % (score_sum / done_sum), "\n") 127 | score_sum, done_sum = 0, 0 128 | 129 | if args.log: 130 | writer.add_scalar("perf/SPS", sps, step) 131 | 132 | # Log agent update metrics 133 | update_step_counter += 1 134 | if updated and args.log and update_step_counter >= 1024: 135 | writer.add_scalar("perf/Update", update_time, step) 136 | log_scalars(writer, agent, step) 137 | update_step_counter = 0 138 | 139 | # Tidy up for running next environment 140 | if args.log: writer.close() 141 | vec_env.close() 142 | del agent 143 | print("[--name] was set to: ", args.name) 144 | 145 | 146 | ########################## 147 | if __name__ == '__main__': 148 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ''' Utility functions for learning scripts ''' 2 | 3 | def log_scalars(writer, agent, step): 4 | 5 | # Agent (PPO) 6 | if hasattr(agent, "approx_kl") : writer.add_scalar("PPO/KL", agent.approx_kl, step) 7 | if hasattr(agent, "p_loss") : writer.add_scalar("PPO/policy loss", agent.p_loss, step) 8 | if hasattr(agent, "v_loss") : writer.add_scalar("PPO/value loss", agent.v_loss, step) 9 | if hasattr(agent, "entropy_loss") : writer.add_scalar("PPO/entropy loss", agent.entropy_loss, step) 10 | if hasattr(agent, "loss") : writer.add_scalar("PPO/total loss", agent.loss, step) 11 | if hasattr(agent, "clipfracs") : writer.add_scalar("PPO/clipfrac", agent.clipfracs, step) 12 | if hasattr(agent, "explained_var") : writer.add_scalar("PPO/explained var", agent.explained_var, step) 13 | if hasattr(agent, "ppo_updates") : writer.add_scalar("PPO/minibatch updates", agent.ppo_updates, step) 14 | if hasattr(agent, "grad_norm") : writer.add_scalar("PPO/gradient norm", agent.grad_norm, step) 15 | if hasattr(agent, "actor_grad_norm") : writer.add_scalar("PPO/actor_gradient norm", agent.actor_grad_norm, step) 16 | if hasattr(agent, "critic_grad_norm"): writer.add_scalar("PPO/critic_gradient norm", agent.critic_grad_norm, step) 17 | if hasattr(agent, "adv_rtn_corr") : writer.add_scalar("PPO/critic adv rtn correlation", agent.adv_rtn_corr, step) 18 | if hasattr(agent, "actor") : 19 | if hasattr(agent.actor, "logstd") : writer.add_scalar("PPO/actor action std", agent.actor.logstd.exp().mean(), step) 20 | 21 | # Agent (SAC) 22 | if hasattr(agent, "qf1_a_values"): writer.add_scalar("SAC/qf1 values", agent.qf1_a_values.mean(), step) 23 | if hasattr(agent, "qf2_a_values"): writer.add_scalar("SAC/qf2 values", agent.qf2_a_values.mean(), step) 24 | if hasattr(agent, "qf1_loss") : 
writer.add_scalar("SAC/qf1 loss", agent.qf1_loss, step) 25 | if hasattr(agent, "qf2_loss") : writer.add_scalar("SAC/qf2 loss", agent.qf2_loss, step) 26 | if hasattr(agent, "qf_loss") : writer.add_scalar("SAC/qf loss", agent.qf_loss * 0.5, step) 27 | if hasattr(agent, "actor_loss") : writer.add_scalar("SAC/actor loss", agent.actor_loss, step) 28 | if hasattr(agent, "alpha_loss") : writer.add_scalar("SAC/alpha loss", agent.alpha_loss, step) 29 | if hasattr(agent, "alpha") : writer.add_scalar("SAC/alpha", agent.alpha, step) 30 | 31 | # World Model / Other Model 32 | if hasattr(agent, "world_loss") : writer.add_scalar("world/model loss", agent.world_loss, step) 33 | if hasattr(agent, "world_epochs_cnt") : writer.add_scalar("world/model epoch count", agent.world_epochs_cnt, step) 34 | if hasattr(agent, "world_kl_loss") : writer.add_scalar("world/model kl div loss", agent.world_kl_loss, step) 35 | if hasattr(agent, "world_nan_loss_cnt"): writer.add_scalar("world/nan losses", agent.world_nan_loss_cnt, step) 36 | if hasattr(agent, "world_mbtrain_pct") : writer.add_scalar("world/mbatch % trained", agent.world_mbtrain_pct, step) 37 | if hasattr(agent, "idm_loss") : writer.add_scalar("world/inverse dynamics loss", agent.idm_loss, step) 38 | if hasattr(agent, "idm_epochs_cnt") : writer.add_scalar("world/inverse dynamics epochs", agent.idm_epochs_cnt, step) 39 | if hasattr(agent, "spvd_kl_div") : writer.add_scalar("world/supervised kl div", agent.spvd_kl_div, step) 40 | if hasattr(agent, "spvd_epochs_cnt") : writer.add_scalar("world/supervised epochs", agent.spvd_epochs_cnt, step) 41 | if hasattr(agent, "wm_loss") : writer.add_scalar("world/supervised loss", agent.wm_loss, step) 42 | if hasattr(agent, "obs_loss") : writer.add_scalar("world/supervised obs loss", agent.obs_loss, step) 43 | if hasattr(agent, "r_loss") : writer.add_scalar("world/supervised r loss", agent.r_loss, step) 44 | if hasattr(agent, "d_loss") : writer.add_scalar("world/supervised d loss", agent.d_loss, step) 45 | if hasattr(agent, "wm_grad_norm") : writer.add_scalar("world/gradient norm", agent.wm_grad_norm, step) 46 | if hasattr(agent, "env_mean_ed") : writer.add_scalar("world/mean env euclidean distance", agent.env_mean_ed, step) 47 | if hasattr(agent, "pred_mean_ed") : writer.add_scalar("world/mean pred euclidean distance", agent.pred_mean_ed, step) 48 | 49 | # Plasticity 50 | if hasattr(agent, "actor_avg_wgt_mag") : writer.add_scalar("plasticity/actor avg weight magnitude", agent.actor_avg_wgt_mag, step) 51 | if hasattr(agent, "critic_avg_wgt_mag"): writer.add_scalar("plasticity/critic avg weight magnitude", agent.critic_avg_wgt_mag, step) 52 | if hasattr(agent, "qf1_avg_wgt_mag") : writer.add_scalar("plasticity/qf1 avg weight magnitude", agent.qf1_avg_wgt_mag, step) 53 | if hasattr(agent, "qf2_avg_wgt_mag") : writer.add_scalar("plasticity/qf2 avg weight magnitude", agent.qf2_avg_wgt_mag, step) 54 | if hasattr(agent, "wm_avg_wgt_mag") : writer.add_scalar("plasticity/wm avg weight magnitude", agent.wm_avg_wgt_mag, step) 55 | if hasattr(agent, "actor_dead_pct") : writer.add_scalar("plasticity/actor dead units %", agent.actor_dead_pct, step) 56 | if hasattr(agent, "critic_dead_pct") : writer.add_scalar("plasticity/critic dead units %", agent.critic_dead_pct, step) 57 | if hasattr(agent, "qf1_dead_pct") : writer.add_scalar("plasticity/qf1 dead units %", agent.qf1_dead_pct, step) 58 | if hasattr(agent, "qf2_dead_pct") : writer.add_scalar("plasticity/qf2 dead units %", agent.qf2_dead_pct, step) 59 | if hasattr(agent, "wm_dead_pct") : 
writer.add_scalar("plasticity/wm dead units %", agent.wm_dead_pct, step) 60 | 61 | if hasattr(agent, "qf1_dormant_ratio") : writer.add_scalar("plasticity/qf1_dormant_ratio", agent.qf1_dormant_ratio, step) 62 | if hasattr(agent, "qf2_dormant_ratio") : writer.add_scalar("plasticity/qf2_dormant_ratio", agent.qf2_dormant_ratio, step) 63 | if hasattr(agent, "actor_dormant_ratio"): writer.add_scalar("plasticity/actor_dormant_ratio", agent.actor_dormant_ratio, step) 64 | 65 | 66 | class Colour: 67 | PURPLE = '\033[95m' 68 | CYAN = '\033[96m' 69 | DARKCYAN = '\033[36m' 70 | BLUE = '\033[94m' 71 | GREEN = '\033[92m' 72 | YELLOW = '\033[93m' 73 | RED = '\033[91m' 74 | BOLD = '\033[1m' 75 | UNDERLINE = '\033[4m' 76 | END = '\033[0m' --------------------------------------------------------------------------------
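# Usage sketch (assumes `writer`, `agent` and `step` exist as in the learn scripts):
# log_scalars() only logs attributes an agent actually exposes, so a custom agent opts in to a metric
# simply by setting it (e.g. self.qf1_loss = ...); nothing needs to change in this file.
#
#   log_scalars(writer, agent, step)                      # logs whichever metrics the agent defines
#   print(Colour.GREEN + 'score improved' + Colour.END)   # ANSI colour codes for terminal output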