├── README.md ├── requirements.txt ├── src ├── oprl │ ├── algos │ │ ├── ddpg.py │ │ ├── nn.py │ │ ├── sac.py │ │ ├── td3.py │ │ ├── tqc.py │ │ └── utils.py │ ├── configs │ │ ├── d3pg.py │ │ ├── ddpg.py │ │ ├── sac.py │ │ ├── td3.py │ │ ├── tqc.py │ │ └── utils.py │ ├── distrib │ │ └── distrib_runner.py │ ├── distrib_train.py │ ├── env.py │ ├── trainers │ │ ├── base_trainer.py │ │ ├── buffers │ │ │ └── episodic_buffer.py │ │ └── safe_trainer.py │ └── utils │ │ ├── config.py │ │ ├── logger.py │ │ ├── run_training.py │ │ └── utils.py └── setup.py └── tests └── functional ├── requirements.txt └── src ├── test_env.py └── test_rl_algos.py /README.md: -------------------------------------------------------------------------------- 1 | oprl_logo 2 | 3 | # OPRL 4 | 5 | A Modular Library for Off-Policy Reinforcement Learning with a focus on SafeRL and distributed computing. Benchmarking results are available on the associated homepage: [Homepage](https://schatty.github.io/oprl/) 6 | 7 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 8 | 9 | 10 | # Disclaimer 11 | The project is under active renovation. For the old code with the D4PG algorithm working with multiprocessing queues and `mujoco_py`, please refer to the branch `d4pg_legacy`. 12 | 13 | ### Roadmap 🏗 14 | - [x] Switching to `mujoco 3.1.1` 15 | - [x] Replacing multiprocessing queues with RabbitMQ for distributed RL 16 | - [x] Baselines with DDPG, TQC for `dm_control` for 1M steps 17 | - [x] Tests 18 | - [x] Support for SafetyGymnasium 19 | - [ ] Style and readability improvements 20 | - [ ] Baselines with distributed algorithms for `dm_control` 21 | - [ ] D4PG logic on top of TQC 22 | 23 | # Installation 24 | 25 | ``` 26 | pip install -r requirements.txt 27 | cd src && pip install -e . 28 | ``` 29 | 30 | To work with [SafetyGymnasium](https://github.com/PKU-Alignment/safety-gymnasium), install it manually: 31 | ``` 32 | git clone https://github.com/PKU-Alignment/safety-gymnasium 33 | cd safety-gymnasium && pip install -e . 34 | ``` 35 | 36 | # Usage 37 | 38 | To run DDPG in a single process: 39 | ``` 40 | python src/oprl/configs/ddpg.py --env walker-walk 41 | ``` 42 | 43 | To run distributed DDPG: 44 | 45 | Run RabbitMQ: 46 | ``` 47 | docker run -it --rm --name rabbitmq -p 5672:5672 -p 15672:15672 rabbitmq:3.12-management 48 | ``` 49 | 50 | Run training: 51 | ``` 52 | python src/oprl/configs/d3pg.py --env walker-walk 53 | ``` 54 | 55 | ## Tests 56 | 57 | ``` 58 | cd src && pip install -e . 59 | cd .. 
&& pip install -r tests/functional/requirements.txt 60 | python -m pytest tests 61 | ``` 62 | 63 | ## Results 64 | 65 | Results for single process DDPG and TQC: 66 | ![ddpg_tqc_eval](https://github.com/schatty/d4pg-pytorch/assets/23639048/f2c32f62-63b4-4a66-a636-4ce0ea1522f6) 67 | 68 | ## Acknowledgements 69 | * DDPG and TD3 code is based on the official TD3 implementation: [sfujim/TD3](https://github.com/sfujim/TD3) 70 | * TQC code is based on the official TQC implementation: [SamsungLabs/tqc](https://github.com/SamsungLabs/tqc) 71 | * SafetyGymnasium: [PKU-Alignment/safety-gymnasium](https://github.com/PKU-Alignment/safety-gymnasium) 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.2 2 | tensorboard==2.15.1 3 | packaging==23.2 4 | dm-control==1.0.16 5 | mujoco==3.1.3 6 | -------------------------------------------------------------------------------- /src/oprl/algos/ddpg.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict 3 | 4 | import numpy as np 5 | import numpy.typing as npt 6 | import torch as t 7 | from torch import nn 8 | 9 | from oprl.algos.nn import Critic, DeterministicPolicy 10 | from oprl.algos.utils import disable_gradient 11 | from oprl.utils.logger import Logger, StdLogger 12 | 13 | 14 | class DDPG: 15 | def __init__( 16 | self, 17 | state_dim: int, 18 | action_dim: int, 19 | max_action: float = 1, 20 | expl_noise: float = 0.1, 21 | discount: float = 0.99, 22 | tau: float = 5e-3, 23 | batch_size: int = 256, 24 | device: str = "cpu", 25 | logger: Logger = StdLogger(), 26 | ): 27 | self._expl_noise = expl_noise 28 | self._action_dim = action_dim 29 | self._discount = discount 30 | self._tau = tau 31 | self._batch_size = batch_size 32 | self._max_action = max_action 33 | self._device = device 34 | self._logger = logger 35 | 36 | self.actor = DeterministicPolicy( 37 | state_dim=state_dim, 38 | action_dim=action_dim, 39 | hidden_units=(256, 256), 40 | hidden_activation=nn.ReLU(inplace=True), 41 | ).to(device) 42 | self.actor_target = deepcopy(self.actor) 43 | disable_gradient(self.actor_target) 44 | self.optim_actor = t.optim.Adam(self.actor.parameters(), lr=3e-4) 45 | 46 | self.critic = Critic(state_dim, action_dim).to(device) 47 | self.critic_target = deepcopy(self.critic) 48 | disable_gradient(self.critic_target) 49 | self.optim_critic = t.optim.Adam(self.critic.parameters(), lr=3e-4) 50 | 51 | def update( 52 | self, 53 | state: t.Tensor, 54 | action: t.Tensor, 55 | reward: t.Tensor, 56 | done: t.Tensor, 57 | next_state: t.Tensor, 58 | ): 59 | self._update_critic(state, action, reward, done, next_state) 60 | self._update_actor(state) 61 | 62 | # Update the frozen target models 63 | for param, target_param in zip( 64 | self.critic.parameters(), self.critic_target.parameters() 65 | ): 66 | target_param.data.copy_( 67 | self._tau * param.data + (1 - self._tau) * target_param.data 68 | ) 69 | 70 | for param, target_param in zip( 71 | self.actor.parameters(), self.actor_target.parameters() 72 | ): 73 | target_param.data.copy_( 74 | self._tau * param.data + (1 - self._tau) * target_param.data 75 | ) 76 | 77 | def _update_critic( 78 | self, 79 | state: t.Tensor, 80 | action: t.Tensor, 81 | reward: t.Tensor, 82 | done: t.Tensor, 83 | next_state: t.Tensor, 84 | ) -> None: 85 | target_Q = self.critic_target(next_state, 
self.actor_target(next_state)) 86 | target_Q = reward + (1.0 - done) * self._discount * target_Q.detach() 87 | current_Q = self.critic(state, action) 88 | 89 | critic_loss = (current_Q - target_Q).pow(2).mean() 90 | 91 | self.optim_critic.zero_grad() 92 | critic_loss.backward() 93 | self.optim_critic.step() 94 | 95 | def _update_actor(self, state: t.Tensor) -> None: 96 | actor_loss = -self.critic(state, self.actor(state)).mean() 97 | 98 | self.optim_actor.zero_grad() 99 | actor_loss.backward() 100 | self.optim_actor.step() 101 | 102 | def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: 103 | state = t.tensor(state, device=self._device).unsqueeze_(0) 104 | with t.no_grad(): 105 | action = self.actor(state).cpu() 106 | return action.numpy().flatten() 107 | 108 | # TODO: remove explore from algo to agent completely 109 | def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: 110 | state = t.tensor(state, device=self._device).unsqueeze_(0) 111 | 112 | with t.no_grad(): 113 | noise = ( 114 | t.randn(self._action_dim) * self._max_action * self._expl_noise 115 | ).to(self._device) 116 | action = self.actor(state) + noise 117 | 118 | a = action.cpu().numpy()[0] 119 | return np.clip(a, -self._max_action, self._max_action) 120 | 121 | def get_policy_state_dict(self) -> Dict[str, Any]: 122 | return self.actor.state_dict() 123 | 124 | @property 125 | def logger(self) -> Logger: 126 | return self._logger 127 | -------------------------------------------------------------------------------- /src/oprl/algos/nn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | import torch as t 4 | import torch.nn as nn 5 | from torch.distributions import Distribution, Normal 6 | from torch.nn.functional import logsigmoid 7 | 8 | from oprl.algos.utils import initialize_weight 9 | 10 | LOG_STD_MIN_MAX = (-20, 2) 11 | 12 | 13 | class Critic(nn.Module): 14 | def __init__( 15 | self, 16 | state_dim: int, 17 | action_dim: int, 18 | hidden_units: tuple[int, ...] = (256, 256), 19 | hidden_activation: nn.Module = nn.ReLU(inplace=True), 20 | ): 21 | super().__init__() 22 | 23 | self.q1 = MLP( 24 | input_dim=state_dim + action_dim, 25 | output_dim=1, 26 | hidden_units=hidden_units, 27 | hidden_activation=hidden_activation, 28 | ) 29 | 30 | def forward(self, states: t.Tensor, actions: t.Tensor): 31 | x = t.cat([states, actions], dim=-1) 32 | return self.q1(x) 33 | 34 | def Q1(self, states: t.Tensor, actions: t.Tensor) -> t.Tensor: 35 | x = t.cat([states, actions], dim=-1) 36 | return self.q1(x) 37 | 38 | 39 | class DoubleCritic(nn.Module): 40 | def __init__( 41 | self, 42 | state_dim: int, 43 | action_dim: int, 44 | hidden_units: tuple[int, ...] 
= (256, 256), 45 | hidden_activation: nn.Module = nn.ReLU(inplace=True), 46 | ): 47 | super().__init__() 48 | 49 | self.q1 = MLP( 50 | input_dim=state_dim + action_dim, 51 | output_dim=1, 52 | hidden_units=hidden_units, 53 | hidden_activation=hidden_activation, 54 | ) 55 | 56 | self.q2 = MLP( 57 | input_dim=state_dim + action_dim, 58 | output_dim=1, 59 | hidden_units=hidden_units, 60 | hidden_activation=hidden_activation, 61 | ) 62 | 63 | def forward(self, states: t.Tensor, actions: t.Tensor) -> tuple[t.Tensor, t.Tensor]: 64 | x = t.cat([states, actions], dim=-1) 65 | return self.q1(x), self.q2(x) 66 | 67 | def Q1(self, states: t.Tensor, actions: t.Tensor) -> t.Tensor: 68 | x = t.cat([states, actions], dim=-1) 69 | return self.q1(x) 70 | 71 | 72 | class MLP(nn.Module): 73 | def __init__( 74 | self, 75 | input_dim: int, 76 | output_dim: int, 77 | hidden_units: tuple[int, ...] = (64, 64), 78 | hidden_activation: nn.Module = nn.Tanh(), 79 | output_activation: nn.Module = nn.Identity(), 80 | ): 81 | super().__init__() 82 | 83 | layers = [] 84 | units = input_dim 85 | for next_units in hidden_units: 86 | layers.append(nn.Linear(units, next_units)) 87 | layers.append(hidden_activation) 88 | units = next_units 89 | layers.append(nn.Linear(units, output_dim)) 90 | layers.append(output_activation) 91 | 92 | self.nn = nn.Sequential(*layers) 93 | 94 | def forward(self, x: t.Tensor) -> t.Tensor: 95 | return self.nn(x) 96 | 97 | 98 | class DeterministicPolicy(nn.Module): 99 | def __init__( 100 | self, 101 | state_dim: int, 102 | action_dim: int, 103 | hidden_units: tuple[int, ...] = (256, 256), 104 | hidden_activation=nn.ReLU(inplace=True), 105 | max_action: float = 1.0, 106 | expl_noise: float = 0.1, 107 | device: str = "cpu", 108 | ): 109 | super().__init__() 110 | 111 | self.mlp = MLP( 112 | input_dim=state_dim, 113 | output_dim=action_dim, 114 | hidden_units=hidden_units, 115 | hidden_activation=hidden_activation, 116 | ).apply(initialize_weight) 117 | 118 | self._device = device 119 | self._action_shape = action_dim 120 | self._max_action = max_action 121 | self._expl_noise = expl_noise 122 | 123 | def forward(self, states: t.Tensor) -> t.Tensor: 124 | return t.tanh(self.mlp(states)) 125 | 126 | def exploit(self, state: npt.ArrayLike) -> npt.NDArray: 127 | state = t.tensor(state).unsqueeze_(0).to(self._device) 128 | return self.forward(state).cpu().numpy().flatten() 129 | 130 | def explore(self, state: npt.ArrayLike) -> npt.NDArray: 131 | state = t.tensor(state, device=self._device).unsqueeze_(0) 132 | 133 | with t.no_grad(): 134 | noise = (t.randn(self._action_shape) * self._expl_noise).to(self._device) 135 | action = self.mlp(state) + noise 136 | 137 | a = action.cpu().numpy()[0] 138 | return np.clip(a, -self._max_action, self._max_action) 139 | 140 | 141 | class GaussianActor(nn.Module): 142 | def __init__(self, state_dim, action_dim, hidden_units, hidden_activation): 143 | super().__init__() 144 | self.action_dim = action_dim 145 | self.net = MLP( 146 | state_dim, 2 * action_dim, hidden_units, hidden_activation=hidden_activation 147 | ) 148 | 149 | def forward(self, obs: t.Tensor) -> tuple[t.Tensor, t.Tensor | None]: 150 | mean, log_std = self.net(obs).split([self.action_dim, self.action_dim], dim=1) 151 | log_std = log_std.clamp(*LOG_STD_MIN_MAX) 152 | 153 | if self.training: 154 | std = t.exp(log_std) 155 | tanh_normal = TanhNormal(mean, std, self.device) 156 | action, pre_tanh = tanh_normal.rsample() 157 | log_prob = tanh_normal.log_prob(pre_tanh) 158 | log_prob = log_prob.sum(dim=1, 
keepdim=True) 159 | else: # deterministic eval without log_prob computation 160 | action = t.tanh(mean) 161 | log_prob = None 162 | return action, log_prob 163 | 164 | @property 165 | def device(self): 166 | return next(self.parameters()).device 167 | 168 | 169 | class TanhNormal(Distribution): 170 | def __init__(self, normal_mean: t.Tensor, normal_std: t.Tensor, device: str): 171 | super().__init__() 172 | self.normal_mean = normal_mean 173 | self.normal_std = normal_std 174 | self.standard_normal = Normal( 175 | t.zeros_like(self.normal_mean, device=device), 176 | t.ones_like(self.normal_std, device=device), 177 | ) 178 | self.normal = Normal(normal_mean, normal_std) 179 | 180 | def log_prob(self, pre_tanh: t.Tensor) -> t.Tensor: 181 | log_det = 2 * np.log(2) + logsigmoid(2 * pre_tanh) + logsigmoid(-2 * pre_tanh) 182 | result = self.normal.log_prob(pre_tanh) - log_det 183 | return result 184 | 185 | def rsample(self) -> tuple[t.Tensor, t.Tensor]: 186 | pretanh = self.normal_mean + self.normal_std * self.standard_normal.sample() 187 | return t.tanh(pretanh), pretanh 188 | -------------------------------------------------------------------------------- /src/oprl/algos/sac.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | import torch as t 6 | from torch import nn 7 | from torch.optim import Adam 8 | 9 | from oprl.algos.nn import DoubleCritic, GaussianActor 10 | from oprl.algos.utils import disable_gradient, soft_update 11 | from oprl.utils.logger import Logger, StdLogger 12 | 13 | 14 | class SAC: 15 | def __init__( 16 | self, 17 | state_dim: int, 18 | action_dim: int, 19 | batch_size: int = 256, 20 | tune_alpha: bool = False, 21 | gamma: float = 0.99, 22 | lr_actor: float = 3e-4, 23 | lr_critic: float = 3e-4, 24 | lr_alpha: float = 1e-3, 25 | alpha_init: float = 0.2, 26 | target_update_coef: float = 5e-3, 27 | device: str = "cpu", 28 | log_every: int = 5000, 29 | logger: Logger = StdLogger(), 30 | ): 31 | self._update_step = 0 32 | self._state_dim = state_dim 33 | self._action_dim = action_dim 34 | self._device = device 35 | self._batch_size = batch_size 36 | self._gamma = gamma 37 | self._tune_alpha = tune_alpha 38 | self._discount = gamma 39 | self._target_update_coef = target_update_coef 40 | self._log_every = log_every 41 | 42 | self.actor = GaussianActor( 43 | state_dim=self._state_dim, 44 | action_dim=action_dim, 45 | hidden_units=(256, 256), 46 | hidden_activation=nn.ReLU(inplace=True), 47 | ).to(device) 48 | 49 | self.critic = DoubleCritic( 50 | state_dim=self._state_dim, 51 | action_dim=self._action_dim, 52 | hidden_units=(256, 256), 53 | hidden_activation=nn.ReLU(inplace=True), 54 | ).to(device) 55 | 56 | self.critic_target = deepcopy(self.critic).to(device).eval() 57 | disable_gradient(self.critic_target) 58 | 59 | self.optim_actor = Adam(self.actor.parameters(), lr=lr_actor) 60 | self.optim_critic = Adam(self.critic.parameters(), lr=lr_critic) 61 | 62 | self._alpha = alpha_init 63 | if self._tune_alpha: 64 | self.log_alpha = t.tensor( 65 | np.log(self._alpha), device=device, requires_grad=True 66 | ) 67 | self.optim_alpha = t.optim.Adam([self.log_alpha], lr=lr_alpha) 68 | self._target_entropy = -float(action_dim) 69 | 70 | self._logger = logger 71 | 72 | def update( 73 | self, 74 | state: t.Tensor, 75 | action: t.Tensor, 76 | reward: t.Tensor, 77 | done: t.Tensor, 78 | next_state: t.Tensor, 79 | ) -> None: 80 | self.update_critic(state, action, reward, 
done, next_state) 81 | self.update_actor(state) 82 | soft_update(self.critic_target, self.critic, self._target_update_coef) 83 | 84 | self._update_step += 1 85 | 86 | def update_critic( 87 | self, 88 | states: t.Tensor, 89 | actions: t.Tensor, 90 | rewards: t.Tensor, 91 | dones: t.Tensor, 92 | next_states: t.Tensor, 93 | ) -> None: 94 | q1, q2 = self.critic(states, actions) 95 | with t.no_grad(): 96 | next_actions, log_pis = self.actor(next_states) 97 | q1_next, q2_next = self.critic_target(next_states, next_actions) 98 | q_next = t.min(q1_next, q2_next) - self._alpha * log_pis 99 | 100 | q_target = rewards + (1.0 - dones) * self._discount * q_next 101 | 102 | td_error1 = (q1 - q_target).pow(2).mean() 103 | td_error2 = (q2 - q_target).pow(2).mean() 104 | loss_critic = td_error1 + td_error2 105 | 106 | self.optim_critic.zero_grad() 107 | loss_critic.backward() 108 | self.optim_critic.step() 109 | 110 | if self._update_step % self._log_every == 0: 111 | self._logger.log_scalars( 112 | { 113 | "algo/q1": q1.detach().mean().cpu(), 114 | "algo/q_target": q_target.mean().cpu(), 115 | "algo/abs_q_err": (q1 - q_target).detach().mean().cpu(), 116 | "algo/critic_loss": loss_critic.item(), 117 | }, 118 | self._update_step, 119 | ) 120 | 121 | def update_actor(self, state: t.Tensor) -> None: 122 | actions, log_pi = self.actor(state) 123 | qs1, qs2 = self.critic(state, actions) 124 | loss_actor = self._alpha * log_pi.mean() - t.min(qs1, qs2).mean() 125 | 126 | self.optim_actor.zero_grad() 127 | loss_actor.backward() 128 | self.optim_actor.step() 129 | 130 | if self._tune_alpha: 131 | loss_alpha = -self.log_alpha * ( 132 | self._target_entropy + log_pi.detach_().mean() 133 | ) 134 | 135 | self.optim_alpha.zero_grad() 136 | loss_alpha.backward() 137 | self.optim_alpha.step() 138 | with t.no_grad(): 139 | self._alpha = self.log_alpha.exp().item() 140 | 141 | if self._update_step % self._log_every == 0: 142 | if self._tune_alpha: 143 | self._logger.log_scalar( 144 | "algo/loss_alpha", loss_alpha.item(), self._update_step 145 | ) 146 | self._logger.log_scalars( 147 | { 148 | "algo/loss_actor": loss_actor.item(), 149 | "algo/alpha": self._alpha, 150 | "algo/log_pi": log_pi.cpu().mean(), 151 | }, 152 | self._update_step, 153 | ) 154 | 155 | def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: 156 | state = t.tensor(state, device=self._device).unsqueeze_(0) 157 | with t.no_grad(): 158 | action, _ = self.actor(state) 159 | return action.cpu().numpy()[0] 160 | 161 | def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: 162 | self.actor.eval() 163 | action = self.explore(state) 164 | self.actor.train() 165 | return action 166 | -------------------------------------------------------------------------------- /src/oprl/algos/td3.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | import torch as t 6 | from torch import nn 7 | from torch.optim import Adam 8 | 9 | from oprl.algos.nn import DeterministicPolicy, DoubleCritic 10 | from oprl.algos.utils import disable_gradient, soft_update 11 | from oprl.utils.logger import Logger, StdLogger 12 | 13 | 14 | class TD3: 15 | def __init__( 16 | self, 17 | state_dim: int, 18 | action_dim: int, 19 | batch_size: int = 256, 20 | policy_noise: float = 0.2, 21 | expl_noise: float = 0.1, 22 | noise_clip: float = 0.5, 23 | policy_freq: int = 2, 24 | discount: float = 0.99, 25 | lr_actor: float = 3e-4, 26 | lr_critic: float = 3e-4, 27 | max_action: 
float = 1.0, 28 | tau: float = 5e-3, 29 | log_every: int = 5000, 30 | device="cpu", 31 | logger: Logger = StdLogger(), 32 | ): 33 | self._aciton_dim = action_dim 34 | self._expl_noise = expl_noise 35 | self._batch_size = batch_size 36 | self._discount = discount 37 | self._tau = tau 38 | self._policy_noise = policy_noise 39 | self._noise_clip = noise_clip 40 | self._policy_freq = policy_freq 41 | self._max_action = max_action 42 | self._device = device 43 | self._logger = logger 44 | 45 | self._log_every = log_every 46 | self._update_step = 0 47 | 48 | self.actor = DeterministicPolicy( 49 | state_dim=state_dim, 50 | action_dim=action_dim, 51 | hidden_units=(256, 256), 52 | hidden_activation=nn.ReLU(inplace=True), 53 | ).to(self._device) 54 | self.actor_target = deepcopy(self.actor).to(self._device).eval() 55 | disable_gradient(self.actor_target) 56 | 57 | self.optim_actor = Adam(self.actor.parameters(), lr=lr_actor) 58 | 59 | self.critic = DoubleCritic( 60 | state_dim=state_dim, 61 | action_dim=action_dim, 62 | hidden_units=(256, 256), 63 | hidden_activation=nn.ReLU(inplace=True), 64 | ).to(self._device) 65 | self.critic_target = deepcopy(self.critic).to(self._device).eval() 66 | disable_gradient(self.critic_target) 67 | 68 | self.optim_critic = Adam(self.critic.parameters(), lr=lr_critic) 69 | 70 | def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: 71 | state = t.tensor(state, device=self._device).unsqueeze_(0) 72 | with t.no_grad(): 73 | action = self.actor(state) 74 | return action.cpu().numpy().flatten() 75 | 76 | def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: 77 | state = t.tensor(state, device=self._device).unsqueeze_(0) 78 | noise = (t.randn(self._aciton_dim) * self._max_action * self._expl_noise).to( 79 | self._device 80 | ) 81 | 82 | with t.no_grad(): 83 | action = self.actor(state) + noise 84 | 85 | a = action.cpu().numpy()[0] 86 | return np.clip(a, -self._max_action, self._max_action) 87 | 88 | def update(self, state: t.Tensor, action, reward, done, next_state) -> None: 89 | self._update_critic(state, action, reward, done, next_state) 90 | 91 | if self._update_step % self._policy_freq == 0: 92 | self._update_actor(state) 93 | soft_update(self.critic_target, self.critic, self._tau) 94 | soft_update(self.actor_target, self.actor, self._tau) 95 | 96 | self._update_step += 1 97 | 98 | def _update_critic( 99 | self, 100 | state: t.Tensor, 101 | action: t.Tensor, 102 | reward: t.Tensor, 103 | done: t.Tensor, 104 | next_state: t.Tensor, 105 | ) -> None: 106 | q1, q2 = self.critic(state, action) 107 | 108 | with t.no_grad(): 109 | noise = (t.randn_like(action) * self._policy_noise).clamp( 110 | -self._noise_clip, self._noise_clip 111 | ) 112 | 113 | next_actions = self.actor_target(next_state) + noise 114 | next_actions = next_actions.clamp(-self._max_action, self._max_action) 115 | 116 | q1_next, q2_next = self.critic_target(next_state, next_actions) 117 | q_next = t.min(q1_next, q2_next) 118 | 119 | q_target = reward + (1.0 - done) * self._discount * q_next 120 | 121 | td_error1 = (q1 - q_target).pow(2).mean() 122 | td_error2 = (q2 - q_target).pow(2).mean() 123 | loss_critic = td_error1 + td_error2 124 | 125 | self.optim_critic.zero_grad() 126 | loss_critic.backward() 127 | self.optim_critic.step() 128 | 129 | if self._update_step % self._log_every == 0: 130 | self._logger.log_scalar( 131 | "algo/q1", q1.detach().mean().cpu(), self._update_step 132 | ) 133 | self._logger.log_scalar( 134 | "algo/q_target", q_target.mean().cpu(), self._update_step 135 | ) 136 | 
self._logger.log_scalar( 137 | "algo/abs_q_err", 138 | (q1 - q_target).detach().mean().cpu(), 139 | self._update_step, 140 | ) 141 | self._logger.log_scalar( 142 | "algo/critic_loss", loss_critic.item(), self._update_step 143 | ) 144 | 145 | def _update_actor(self, state: t.Tensor) -> None: 146 | actions = self.actor(state) 147 | qs1 = self.critic.Q1(state, actions) 148 | loss_actor = -qs1.mean() 149 | 150 | self.optim_actor.zero_grad() 151 | loss_actor.backward() 152 | self.optim_actor.step() 153 | 154 | if self._update_step % self._log_every == 0: 155 | self._logger.log_scalar( 156 | "algo/loss_actor", loss_actor.item(), self._update_step 157 | ) 158 | -------------------------------------------------------------------------------- /src/oprl/algos/tqc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | import torch as t 6 | import torch.nn as nn 7 | 8 | from oprl.algos.nn import MLP, GaussianActor 9 | from oprl.utils.logger import Logger, StdLogger 10 | 11 | 12 | def quantile_huber_loss_f( 13 | quantiles: t.Tensor, samples: t.Tensor, device: str 14 | ) -> t.Tensor: 15 | """ 16 | Args: 17 | quantiles: [batch, n_nets, n_quantiles]. 18 | samples: [batch, n_nets * n_quantiles - top_quantiles_to_drop]. 19 | 20 | Returns: 21 | loss as a torch value. 22 | """ 23 | pairwise_delta = ( 24 | samples[:, None, None, :] - quantiles[:, :, :, None] 25 | ) # batch x nets x quantiles x samples 26 | abs_pairwise_delta = t.abs(pairwise_delta) 27 | huber_loss = t.where( 28 | abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5 29 | ) 30 | 31 | n_quantiles = quantiles.shape[2] 32 | tau = ( 33 | t.arange(n_quantiles, device=device).float() / n_quantiles + 1 / 2 / n_quantiles 34 | ) 35 | loss = ( 36 | t.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss 37 | ).mean() 38 | return loss 39 | 40 | 41 | class QuantileQritic(nn.Module): 42 | def __init__(self, state_dim: int, action_dim: int, n_quantiles: int, n_nets: int): 43 | super().__init__() 44 | self.nets = [] 45 | self.n_quantiles = n_quantiles 46 | self.n_nets = n_nets 47 | for i in range(n_nets): 48 | net = MLP( 49 | state_dim + action_dim, 50 | n_quantiles, 51 | (512, 512, 512), 52 | hidden_activation=nn.ReLU(), 53 | ) 54 | self.add_module(f"qf{i}", net) 55 | self.nets.append(net) 56 | 57 | def forward(self, state: t.Tensor, action: t.Tensor) -> t.Tensor: 58 | sa = t.cat((state, action), dim=1) 59 | quantiles = t.stack(tuple(net(sa) for net in self.nets), dim=1) 60 | return quantiles 61 | 62 | 63 | class TQC: 64 | def __init__( 65 | self, 66 | state_dim: int, 67 | action_dim: int, 68 | discount: float = 0.99, 69 | tau: float = 0.005, 70 | top_quantiles_to_drop: int = 2, 71 | n_quantiles: int = 25, 72 | n_nets: int = 5, 73 | log_every: int = 5000, 74 | device: str = "cpu", 75 | logger: Logger = StdLogger(), 76 | ): 77 | self._discount = discount 78 | self._tau = tau 79 | self._top_quantiles_to_drop = top_quantiles_to_drop 80 | self._target_entropy = -np.prod(action_dim).item() 81 | self._device = device 82 | self._update_step = 0 83 | self._log_every = log_every 84 | self._logger = logger 85 | 86 | self.actor = GaussianActor( 87 | state_dim, 88 | action_dim, 89 | hidden_units=(256, 256), 90 | hidden_activation=nn.ReLU(), 91 | ).to(device) 92 | self.critic = QuantileQritic(state_dim, action_dim, n_quantiles, n_nets).to( 93 | device 94 | ) 95 | self.critic_target = copy.deepcopy(self.critic) 96 | 
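# Note (added comment): the entropy temperature alpha is optimized in log space and initialized to 0.2;
# alpha_loss in update() tunes it so that the policy entropy approaches the target of -action_dim.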
self.log_alpha = t.tensor(np.log(0.2), requires_grad=True, device=device) 97 | self._quantiles_total = self.critic.n_quantiles * self.critic.n_nets 98 | 99 | # TODO: check hyperparams 100 | self.actor_optimizer = t.optim.Adam(self.actor.parameters(), lr=3e-4) 101 | self.critic_optimizer = t.optim.Adam(self.critic.parameters(), lr=3e-4) 102 | self.alpha_optimizer = t.optim.Adam([self.log_alpha], lr=3e-4) 103 | 104 | def update( 105 | self, 106 | state: t.Tensor, 107 | action: t.Tensor, 108 | reward: t.Tensor, 109 | done: t.Tensor, 110 | next_state: t.Tensor, 111 | ): 112 | batch_size = state.shape[0] 113 | 114 | alpha = t.exp(self.log_alpha) 115 | 116 | # --- Q loss --- 117 | with t.no_grad(): 118 | # get policy action 119 | new_next_action, next_log_pi = self.actor(next_state) 120 | 121 | # compute and cut quantiles at the next state 122 | next_z = self.critic_target( 123 | next_state, new_next_action 124 | ) # batch x nets x quantiles 125 | sorted_z, _ = t.sort(next_z.reshape(batch_size, -1)) 126 | sorted_z_part = sorted_z[ 127 | :, : self._quantiles_total - self._top_quantiles_to_drop 128 | ] 129 | 130 | # compute target 131 | target = reward + (1 - done) * self._discount * ( 132 | sorted_z_part - alpha * next_log_pi 133 | ) 134 | 135 | cur_z = self.critic(state, action) 136 | critic_loss = quantile_huber_loss_f(cur_z, target, self._device) 137 | 138 | self.critic_optimizer.zero_grad() 139 | critic_loss.backward() 140 | self.critic_optimizer.step() 141 | 142 | for param, target_param in zip( 143 | self.critic.parameters(), self.critic_target.parameters() 144 | ): 145 | target_param.data.copy_( 146 | self._tau * param.data + (1 - self._tau) * target_param.data 147 | ) 148 | 149 | # --- Policy and alpha loss --- 150 | new_action, log_pi = self.actor(state) 151 | alpha_loss = -self.log_alpha * (log_pi + self._target_entropy).detach().mean() 152 | actor_loss = ( 153 | alpha * log_pi 154 | - self.critic(state, new_action).mean(2).mean(1, keepdim=True) 155 | ).mean() 156 | 157 | # --- Update --- 158 | 159 | self.actor_optimizer.zero_grad() 160 | actor_loss.backward() 161 | self.actor_optimizer.step() 162 | 163 | self.alpha_optimizer.zero_grad() 164 | alpha_loss.backward() 165 | self.alpha_optimizer.step() 166 | 167 | if self._update_step % self._log_every == 0: 168 | self._logger.log_scalars( 169 | { 170 | "algo/critic_loss": critic_loss.item(), 171 | "algo/actor_loss": actor_loss.item(), 172 | "algo/alpha_loss": alpha_loss.item(), 173 | }, 174 | self._update_step, 175 | ) 176 | 177 | self._update_step += 1 178 | 179 | def explore(self, state: npt.ArrayLike) -> npt.ArrayLike: 180 | state = t.tensor(state, device=self._device).unsqueeze_(0) 181 | with t.no_grad(): 182 | action, _ = self.actor(state) 183 | return action.cpu().numpy()[0] 184 | 185 | def exploit(self, state: npt.ArrayLike) -> npt.ArrayLike: 186 | self.actor.eval() 187 | action = self.explore(state) 188 | self.actor.train() 189 | return action 190 | -------------------------------------------------------------------------------- /src/oprl/algos/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Clamp(nn.Module): 6 | def forward(self, log_stds): 7 | return log_stds.clamp_(-20, 2) 8 | 9 | 10 | def initialize_weight(m, gain=nn.init.calculate_gain("relu")): 11 | # Initialize linear layers with the orthogonal initialization. 
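# Orthogonal initialization makes the rows of the weight matrix orthonormal (scaled by `gain`),
# which helps keep activation and gradient magnitudes stable at the start of training.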
12 | if isinstance(m, nn.Linear): 13 | nn.init.orthogonal_(m.weight.data, gain) 14 | m.bias.data.fill_(0.0) 15 | 16 | # Initialize conv layers with the delta-orthogonal initialization. 17 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 18 | assert m.weight.size(2) == m.weight.size(3) 19 | m.weight.data.fill_(0.0) 20 | m.bias.data.fill_(0.0) 21 | mid = m.weight.size(2) // 2 22 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) 23 | 24 | 25 | def soft_update(target, source, tau): 26 | """Update target network using Polyak-Ruppert Averaging.""" 27 | with torch.no_grad(): 28 | for tgt, src in zip(target.parameters(), source.parameters()): 29 | tgt.data.mul_(1.0 - tau) 30 | tgt.data.add_(tau * src.data) 31 | 32 | 33 | def disable_gradient(network): 34 | """Disable gradient calculations of the network.""" 35 | for param in network.parameters(): 36 | param.requires_grad = False 37 | -------------------------------------------------------------------------------- /src/oprl/configs/d3pg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from multiprocessing import Process 4 | 5 | import torch.nn as nn 6 | 7 | from oprl.algos.ddpg import DDPG, DeterministicPolicy 8 | from oprl.configs.utils import create_logdir 9 | from oprl.distrib.distrib_runner import env_worker, policy_update_worker 10 | from oprl.utils.utils import set_logging 11 | 12 | set_logging(logging.INFO) 13 | from oprl.env import make_env as _make_env 14 | from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer 15 | from oprl.utils.logger import FileLogger, Logger 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="Run training") 20 | parser.add_argument("--config", type=str, help="Path to the config file.") 21 | parser.add_argument( 22 | "--env", type=str, default="cartpole-balance", help="Name of the environment." 23 | ) 24 | parser.add_argument( 25 | "--device", type=str, default="cpu", help="Device to perform training on." 
26 | ) 27 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 28 | return parser.parse_args() 29 | 30 | 31 | # -------- Distrib params ----------- 32 | 33 | ENV_WORKERS = 4 34 | N_EPISODES = 50 # 500 # Number of episodes each env worker would perform 35 | 36 | # ----------------------------------- 37 | 38 | args = parse_args() 39 | 40 | 41 | def make_env(seed: int): 42 | return _make_env(args.env, seed=seed) 43 | 44 | 45 | env = make_env(seed=0) 46 | STATE_DIM = env.observation_space.shape[0] 47 | ACTION_DIM = env.action_space.shape[0] 48 | logging.info(f"Env state {STATE_DIM}\tEnv action {ACTION_DIM}") 49 | 50 | 51 | log_dir = create_logdir(logdir="logs", algo="D3PG", env=args.env, seed=args.seed) 52 | logging.info(f"LOG_DIR: {log_dir}") 53 | 54 | 55 | def make_logger(seed: int) -> Logger: 56 | log_dir = create_logdir(logdir="logs", algo="D3PG", env=args.env, seed=seed) 57 | # TODO: add here actual config 58 | return FileLogger(log_dir, {}) 59 | 60 | 61 | def make_policy(): 62 | return DeterministicPolicy( 63 | state_dim=STATE_DIM, 64 | action_dim=ACTION_DIM, 65 | hidden_units=(256, 256), 66 | hidden_activation=nn.ReLU(inplace=True), 67 | device=args.device, 68 | ) 69 | 70 | 71 | def make_buffer(): 72 | return EpisodicReplayBuffer( 73 | buffer_size=int(1_000_000), 74 | state_dim=STATE_DIM, 75 | action_dim=ACTION_DIM, 76 | device=args.device, 77 | gamma=0.99, 78 | ) 79 | 80 | 81 | def make_algo(): 82 | logger = make_logger(args.seed) 83 | 84 | algo = DDPG( 85 | state_dim=STATE_DIM, 86 | action_dim=ACTION_DIM, 87 | device=args.device, 88 | logger=logger, 89 | ) 90 | return algo 91 | 92 | 93 | if __name__ == "__main__": 94 | processes = [] 95 | 96 | for i_env in range(ENV_WORKERS): 97 | processes.append( 98 | Process(target=env_worker, args=(make_env, make_policy, N_EPISODES, i_env)) 99 | ) 100 | processes.append( 101 | Process( 102 | target=policy_update_worker, 103 | args=(make_algo, make_env, make_buffer, ENV_WORKERS), 104 | ) 105 | ) 106 | 107 | for p in processes: 108 | p.start() 109 | 110 | for p in processes: 111 | p.join() 112 | 113 | logging.info("Training OK.") 114 | -------------------------------------------------------------------------------- /src/oprl/configs/ddpg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from oprl.algos.ddpg import DDPG 4 | from oprl.configs.utils import create_logdir, parse_args 5 | from oprl.utils.utils import set_logging 6 | 7 | set_logging(logging.INFO) 8 | from oprl.env import make_env as _make_env 9 | from oprl.utils.logger import FileLogger, Logger 10 | from oprl.utils.run_training import run_training 11 | 12 | args = parse_args() 13 | 14 | 15 | def make_env(seed: int): 16 | return _make_env(args.env, seed=seed) 17 | 18 | 19 | env = make_env(seed=0) 20 | STATE_DIM: int = env.observation_space.shape[0] 21 | ACTION_DIM: int = env.action_space.shape[0] 22 | 23 | 24 | # -------- Config params ----------- 25 | 26 | config = { 27 | "state_dim": STATE_DIM, 28 | "action_dim": ACTION_DIM, 29 | "num_steps": int(1_000_000), 30 | "eval_every": 2500, 31 | "device": args.device, 32 | "save_buffer": False, 33 | "visualise_every": 0, 34 | "estimate_q_every": 5000, 35 | "log_every": 2500, 36 | } 37 | 38 | # ----------------------------------- 39 | 40 | 41 | def make_algo(logger): 42 | return DDPG( 43 | state_dim=STATE_DIM, 44 | action_dim=ACTION_DIM, 45 | device=args.device, 46 | logger=logger, 47 | ) 48 | 49 | 50 | def make_logger(seed: int) -> Logger: 51 | global config 52 | 
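# create_logdir() places the run under logs/DDPG/DDPG-env_<env>-seed_<seed>-<timestamp> (see oprl/configs/utils.py).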
log_dir = create_logdir(logdir="logs", algo="DDPG", env=args.env, seed=seed) 53 | return FileLogger(log_dir, config) 54 | 55 | 56 | if __name__ == "__main__": 57 | args = parse_args() 58 | run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) 59 | -------------------------------------------------------------------------------- /src/oprl/configs/sac.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from oprl.algos.sac import SAC 4 | from oprl.configs.utils import create_logdir, parse_args 5 | from oprl.utils.utils import set_logging 6 | 7 | set_logging(logging.INFO) 8 | from oprl.env import make_env as _make_env 9 | from oprl.utils.logger import FileLogger, Logger 10 | from oprl.utils.run_training import run_training 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | args = parse_args() 15 | 16 | 17 | def make_env(seed: int): 18 | return _make_env(args.env, seed=seed) 19 | 20 | 21 | env = make_env(seed=0) 22 | STATE_DIM: int = env.observation_space.shape[0] 23 | ACTION_DIM: int = env.action_space.shape[0] 24 | 25 | 26 | # -------- Config params ----------- 27 | 28 | config = { 29 | "state_dim": STATE_DIM, 30 | "action_dim": ACTION_DIM, 31 | "num_steps": int(1_000_000), 32 | "eval_every": 2500, 33 | "device": args.device, 34 | "save_buffer": False, 35 | "visualise_every": 0, 36 | "estimate_q_every": 5000, 37 | "log_every": 1000, 38 | } 39 | 40 | # ----------------------------------- 41 | 42 | 43 | def make_algo(logger): 44 | return SAC( 45 | state_dim=STATE_DIM, 46 | action_dim=ACTION_DIM, 47 | device=args.device, 48 | logger=logger, 49 | ) 50 | 51 | 52 | def make_logger(seed: int) -> Logger: 53 | global config 54 | log_dir = create_logdir(logdir="logs", algo="SAC", env=args.env, seed=seed) 55 | return FileLogger(log_dir, config) 56 | 57 | 58 | if __name__ == "__main__": 59 | args = parse_args() 60 | run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) 61 | -------------------------------------------------------------------------------- /src/oprl/configs/td3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from oprl.algos.td3 import TD3 4 | from oprl.configs.utils import create_logdir, parse_args 5 | from oprl.utils.utils import set_logging 6 | 7 | set_logging(logging.INFO) 8 | from oprl.env import make_env as _make_env 9 | from oprl.utils.logger import FileLogger, Logger 10 | from oprl.utils.run_training import run_training 11 | 12 | args = parse_args() 13 | 14 | 15 | def make_env(seed: int): 16 | return _make_env(args.env, seed=seed) 17 | 18 | 19 | env = make_env(seed=0) 20 | STATE_DIM: int = env.observation_space.shape[0] 21 | ACTION_DIM: int = env.action_space.shape[0] 22 | 23 | 24 | # -------- Config params ----------- 25 | 26 | config = { 27 | "state_dim": STATE_DIM, 28 | "action_dim": ACTION_DIM, 29 | "num_steps": int(1_000_000), 30 | "eval_every": 2500, 31 | "device": args.device, 32 | "save_buffer": False, 33 | "visualise_every": 0, 34 | "estimate_q_every": 5000, 35 | "log_every": 2500, 36 | } 37 | 38 | # ----------------------------------- 39 | 40 | 41 | def make_algo(logger): 42 | return TD3( 43 | state_dim=STATE_DIM, 44 | action_dim=ACTION_DIM, 45 | device=args.device, 46 | logger=logger, 47 | ) 48 | 49 | 50 | def make_logger(seed: int) -> Logger: 51 | global config 52 | 53 | log_dir = create_logdir(logdir="logs", algo="TD3", env=args.env, seed=seed) 54 | return FileLogger(log_dir, 
config) 55 | 56 | 57 | if __name__ == "__main__": 58 | args = parse_args() 59 | run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) 60 | -------------------------------------------------------------------------------- /src/oprl/configs/tqc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from oprl.algos.tqc import TQC 4 | from oprl.configs.utils import create_logdir, parse_args 5 | from oprl.utils.utils import set_logging 6 | 7 | set_logging(logging.INFO) 8 | from oprl.env import make_env as _make_env 9 | from oprl.utils.logger import FileLogger, Logger 10 | from oprl.utils.run_training import run_training 11 | 12 | args = parse_args() 13 | 14 | 15 | def make_env(seed: int): 16 | return _make_env(args.env, seed=seed) 17 | 18 | 19 | env = make_env(seed=0) 20 | STATE_DIM: int = env.observation_space.shape[0] 21 | ACTION_DIM: int = env.action_space.shape[0] 22 | 23 | 24 | # -------- Config params ----------- 25 | 26 | config = { 27 | "state_dim": STATE_DIM, 28 | "action_dim": ACTION_DIM, 29 | "num_steps": int(1_000_000), 30 | "eval_every": 2500, 31 | "device": args.device, 32 | "save_buffer": False, 33 | "visualise_every": 0, 34 | "estimate_q_every": 0, # TODO: Here is the unsupported logic 35 | "log_every": 2500, 36 | } 37 | 38 | # ----------------------------------- 39 | 40 | 41 | def make_algo(logger: Logger): 42 | return TQC( 43 | state_dim=STATE_DIM, 44 | action_dim=ACTION_DIM, 45 | device=args.device, 46 | logger=logger, 47 | ) 48 | 49 | 50 | def make_logger(seed: int) -> Logger: 51 | global config 52 | log_dir = create_logdir(logdir="logs", algo="TQC", env=args.env, seed=seed) 53 | return FileLogger(log_dir, config) 54 | 55 | 56 | if __name__ == "__main__": 57 | args = parse_args() 58 | run_training(make_algo, make_env, make_logger, config, args.seeds, args.start_seed) 59 | -------------------------------------------------------------------------------- /src/oprl/configs/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from datetime import datetime 5 | 6 | 7 | def parse_args() -> argparse.Namespace: 8 | parser = argparse.ArgumentParser(description="Run training") 9 | parser.add_argument("--config", type=str, help="Path to the config file.") 10 | parser.add_argument( 11 | "--env", type=str, default="cartpole-balance", help="Name of the environment." 12 | ) 13 | parser.add_argument( 14 | "--seeds", 15 | type=int, 16 | default=1, 17 | help="Number of parallel processes launched with different random seeds.", 18 | ) 19 | parser.add_argument( 20 | "--start_seed", 21 | type=int, 22 | default=0, 23 | help="Number of the first seed. Following seeds will be incremented from it.", 24 | ) 25 | parser.add_argument( 26 | "--device", type=str, default="cpu", help="Device to perform training on." 
27 | ) 28 | return parser.parse_args() 29 | 30 | 31 | def create_logdir(logdir: str, algo: str, env: str, seed: int) -> str: 32 | dt = datetime.now().strftime("%Y_%m_%d_%Hh%Mm") 33 | log_dir = os.path.join(logdir, algo, f"{algo}-env_{env}-seed_{seed}-{dt}") 34 | logging.info(f"LOGDIR: {log_dir}") 35 | return log_dir 36 | -------------------------------------------------------------------------------- /src/oprl/distrib/distrib_runner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import time 4 | from itertools import count 5 | from multiprocessing import Process 6 | 7 | import numpy as np 8 | import pika 9 | import torch 10 | 11 | 12 | class Queue: 13 | def __init__(self, name: str, host: str = "localhost"): 14 | self._name = name 15 | 16 | connection = pika.BlockingConnection(pika.ConnectionParameters(host=host)) 17 | self.channel = connection.channel() 18 | self.channel.queue_declare(queue=name) 19 | 20 | def push(self, data) -> None: 21 | self.channel.basic_publish(exchange="", routing_key=self._name, body=data) 22 | 23 | def pop(self) -> bytes | None: 24 | method_frame, header_frame, body = self.channel.basic_get(queue=self._name) 25 | if method_frame: 26 | self.channel.basic_ack(method_frame.delivery_tag) 27 | return body 28 | return None 29 | 30 | 31 | def env_worker(make_env, make_policy, n_episodes, id_worker): 32 | env = make_env(seed=0) 33 | logging.info("Env created.") 34 | 35 | policy = make_policy() 36 | logging.info("Policy created.") 37 | 38 | q_env = Queue(f"env_{id_worker}") 39 | q_policy = Queue(f"policy_{id_worker}") 40 | logging.info("Queue created.") 41 | 42 | episodes = [] 43 | 44 | total_env_step = 0 45 | # TODO: Move parameter to config 46 | start_steps = 1000 47 | for i_ep in range(n_episodes): 48 | if i_ep % 10 == 0: 49 | logging.info(f"AGENT {id_worker} EPISODE {i_ep}") 50 | 51 | episode = [] 52 | state, _ = env.reset() 53 | # TODO: Move parameter to config 54 | for env_step in range(1000): 55 | if total_env_step <= start_steps: 56 | action = env.sample_action() 57 | else: 58 | action = policy.explore(state) 59 | 60 | next_state, reward, terminated, truncated, _ = env.step(action) 61 | episode.append([state, action, reward, terminated, next_state]) 62 | 63 | if terminated or truncated: 64 | break 65 | state = next_state 66 | total_env_step += 1 67 | 68 | q_env.push(pickle.dumps(episode)) 69 | 70 | while True: 71 | data = q_policy.pop() 72 | if data is None: 73 | logging.info("Waiting for the policy..") 74 | time.sleep(2.0) 75 | continue 76 | 77 | policy.load_state_dict(pickle.loads(data)) 78 | break 79 | 80 | logging.info("Episode by env worker is done.") 81 | 82 | 83 | def policy_update_worker(make_algo, make_env_test, make_buffer, n_workers): 84 | algo = make_algo() 85 | logging.info("Algo created.") 86 | buffer = make_buffer() 87 | logging.info("Buffer created.") 88 | 89 | q_envs = [] 90 | q_policies = [] 91 | for i_env in range(n_workers): 92 | q_envs.append(Queue(f"env_{i_env}")) 93 | q_policies.append(Queue(f"policy_{i_env}")) 94 | logging.info("Learner queue created.") 95 | 96 | batch_size = 128 97 | 98 | logging.info("Warming up the learner...") 99 | time.sleep(2.0) 100 | 101 | for i_epoch in count(0): 102 | logging.info(f"Epoch: {i_epoch}") 103 | n_waits = 0 104 | for i_env in range(n_workers): 105 | while True: 106 | data = q_envs[i_env].pop() 107 | if data: 108 | episode = pickle.loads(data) 109 | buffer.add_episode(episode) 110 | break 111 | else: 112 | logging.info("Waiting 
for the env data...") 113 | # TODO: not optimal wait for each queue 114 | time.sleep(1) 115 | n_waits += 1 116 | if n_waits == 10: 117 | logging.info("Learner tired to wait, exiting...") 118 | return 119 | continue 120 | 121 | # TODO: Remove hardcoded value 122 | if i_epoch > 16: 123 | for i in range(1000 * 4): 124 | batch = buffer.sample(batch_size) 125 | algo.update(*batch) 126 | if i % int(1000) == 0: 127 | logging.info(f"\tUpdating {i}") 128 | 129 | policy_state_dict = algo.get_policy_state_dict() 130 | 131 | policy_serialized = pickle.dumps(policy_state_dict) 132 | for i_env in range(n_workers): 133 | q_policies[i_env].push(policy_serialized) 134 | 135 | if True: 136 | mean_reward = evaluate(algo, make_env_test) 137 | algo.logger.log_scalar("trainer/ep_reward", mean_reward, i_epoch) 138 | 139 | logging.info("Update by policy update worker done.") 140 | 141 | 142 | def evaluate(algo, make_env_test, num_eval_episodes: int = 5, seed: int = 0): 143 | returns = [] 144 | for i_ep in range(num_eval_episodes): 145 | env_test = make_env_test(seed * 100 + i_ep) 146 | state, _ = env_test.reset() 147 | 148 | episode_return = 0.0 149 | terminated, truncated = False, False 150 | 151 | while not (terminated or truncated): 152 | action = algo.exploit(state) 153 | state, reward, terminated, truncated, _ = env_test.step(action) 154 | episode_return += reward 155 | 156 | returns.append(episode_return) 157 | 158 | mean_return = np.mean(returns) 159 | return mean_return 160 | -------------------------------------------------------------------------------- /src/oprl/distrib_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from datetime import datetime 5 | from multiprocessing import Process 6 | 7 | import torch 8 | import torch.nn as nn 9 | from algos.ddpg import DDPG, DeterministicPolicy 10 | from distrib.distrib_runner import env_worker, policy_update_worker 11 | from env import DMControlEnv, make_env 12 | from trainers.buffers.episodic_buffer import EpisodicReplayBuffer 13 | from utils.logger import Logger 14 | 15 | print("Imports ok.") 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="Run training") 20 | parser.add_argument("--config", type=str, help="Path to the config file.") 21 | parser.add_argument( 22 | "--env", type=str, default="cartpole-balance", help="Name of the environment." 23 | ) 24 | parser.add_argument( 25 | "--device", type=str, default="cpu", help="Device to perform training on." 26 | ) 27 | return parser.parse_args() 28 | 29 | 30 | args = parse_args() 31 | 32 | 33 | def make_env(seed: int): 34 | """ 35 | Args: 36 | name: Environment name. 
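seed: Random seed; the environment name itself comes from the --env CLI argument (args.env).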
37 | """ 38 | return DMControlEnv(args.env, seed=seed) 39 | 40 | 41 | env = make_env(seed=0) 42 | 43 | STATE_SHAPE = env.observation_space.shape 44 | ACTION_SHAPE = env.action_space.shape 45 | print("STATE ACTION SHAPE: ", STATE_SHAPE, ACTION_SHAPE) 46 | 47 | 48 | def make_policy(): 49 | return DeterministicPolicy( 50 | state_dim=STATE_SHAPE, 51 | action_dim=ACTION_SHAPE, 52 | hidden_units=[256, 256], 53 | hidden_activation=nn.ReLU(inplace=True), 54 | ) 55 | 56 | 57 | def make_buffer(): 58 | buffer = EpisodicReplayBuffer( 59 | buffer_size=int(100_000), 60 | state_shape=STATE_SHAPE, 61 | action_shape=ACTION_SHAPE, 62 | device="cpu", 63 | gamma=0.99, 64 | ) 65 | return buffer 66 | 67 | 68 | def make_algo(): 69 | time = datetime.now().strftime("%Y-%m-%d_%H_%M") 70 | log_dir = os.path.join("logs_debug", "DDPG", f"DDPG-env_ENV-seedSEED-{time}") 71 | print("LOGDIR: ", log_dir) 72 | logger = Logger(log_dir, {}) 73 | 74 | algo = DDPG( 75 | state_dim=STATE_SHAPE, 76 | action_dim=ACTION_SHAPE, 77 | device="cpu", 78 | seed=0, 79 | logger=logger, 80 | ) 81 | return algo 82 | 83 | 84 | if __name__ == "__main__": 85 | ENV_WORKERS = 2 86 | 87 | seed = 0 88 | 89 | processes = [] 90 | 91 | for i_env in range(ENV_WORKERS): 92 | processes.append( 93 | Process(target=env_worker, args=(make_env, make_policy, i_env)) 94 | ) 95 | processes.append( 96 | Process( 97 | target=policy_update_worker, 98 | args=(make_algo, make_env, make_buffer, ENV_WORKERS), 99 | ) 100 | ) 101 | 102 | for p in processes: 103 | p.start() 104 | 105 | for p in processes: 106 | p.join() 107 | 108 | print("OK.") 109 | -------------------------------------------------------------------------------- /src/oprl/env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections import OrderedDict 3 | from typing import Any 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | from dm_control import suite 8 | 9 | 10 | class BaseEnv(ABC): 11 | @abstractmethod 12 | def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: 13 | pass 14 | 15 | @abstractmethod 16 | def step( 17 | self, action: npt.ArrayLike 18 | ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: 19 | pass 20 | 21 | @abstractmethod 22 | def sample_action(self) -> npt.ArrayLike: 23 | pass 24 | 25 | @property 26 | def env_family(self) -> str: 27 | return "" 28 | 29 | 30 | class DummyEnv(BaseEnv): 31 | def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: 32 | return np.array([]), {} 33 | 34 | def step( 35 | self, action: npt.ArrayLike 36 | ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: 37 | return np.array([]), np.array([]), False, False, {} 38 | 39 | def sample_action(self) -> npt.ArrayLike: 40 | return np.array([]) 41 | 42 | @property 43 | def env_family(self) -> str: 44 | return "" 45 | 46 | 47 | class SafetyGym(BaseEnv): 48 | def __init__(self, env_name: str, seed: int): 49 | import safety_gymnasium as gym 50 | 51 | self._env = gym.make(env_name) 52 | self._seed = seed 53 | 54 | def step( 55 | self, action: npt.ArrayLike 56 | ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: 57 | obs, reward, cost, terminated, truncated, info = self._env.step(action) 58 | info["cost"] = cost 59 | return obs, reward, terminated, truncated, info 60 | 61 | def reset(self) -> tuple[npt.ArrayLike, dict[str, Any]]: 62 | obs, info = self._env.reset(seed=self._seed) 63 | self._env.step(self._env.action_space.sample()) 64 | return obs, info 65 | 66 | def 
sample_action(self): 67 | return self._env.action_space.sample() 68 | 69 | @property 70 | def observation_space(self): 71 | return self._env.observation_space 72 | 73 | @property 74 | def action_space(self): 75 | return self._env.action_space 76 | 77 | @property 78 | def env_family(self) -> str: 79 | return "safety_gymnasium" 80 | 81 | 82 | class DMControlEnv(BaseEnv): 83 | def __init__(self, env: str, seed: int): 84 | domain, task = env.split("-") 85 | self.random_state = np.random.RandomState(seed) 86 | self.env = suite.load(domain, task, task_kwargs={"random": self.random_state}) 87 | 88 | self._render_width = 200 89 | self._render_height = 200 90 | self._camera_id = 0 91 | 92 | def reset(self, *args, **kwargs) -> tuple[npt.ArrayLike, dict[str, Any]]: 93 | obs = self._flat_obs(self.env.reset().observation) 94 | return obs, {} 95 | 96 | def step( 97 | self, action: npt.ArrayLike 98 | ) -> tuple[npt.ArrayLike, npt.ArrayLike, bool, bool, dict[str, Any]]: 99 | time_step = self.env.step(action) 100 | obs = self._flat_obs(time_step.observation) 101 | 102 | terminated = False 103 | truncated = self.env._step_count >= self.env._step_limit 104 | 105 | return obs, time_step.reward, terminated, truncated, {} 106 | 107 | def sample_action(self) -> npt.ArrayLike: 108 | spec = self.env.action_spec() 109 | action = self.random_state.uniform(spec.minimum, spec.maximum, spec.shape) 110 | return action 111 | 112 | @property 113 | def observation_space(self) -> npt.ArrayLike: 114 | return np.zeros( 115 | sum(int(np.prod(v.shape)) for v in self.env.observation_spec().values()) 116 | ) 117 | 118 | @property 119 | def action_space(self) -> npt.ArrayLike: 120 | return np.zeros(self.env.action_spec().shape[0]) 121 | 122 | def render(self) -> npt.ArrayLike: 123 | """ 124 | returned shape: [1, W, H, C] 125 | """ 126 | img = self.env.physics.render( 127 | camera_id=self._camera_id, 128 | height=self._render_width, 129 | width=self._render_width, 130 | ) 131 | img = img.astype(np.uint8) 132 | return np.expand_dims(img, 0) 133 | 134 | def _flat_obs(self, obs: OrderedDict) -> npt.ArrayLike: 135 | obs_flatten = [] 136 | for _, o in obs.items(): 137 | if len(o.shape) == 0: 138 | obs_flatten.append(np.array([o])) 139 | elif len(o.shape) == 2 and o.shape[1] > 1: 140 | obs_flatten.append(o.flatten()) 141 | else: 142 | obs_flatten.append(o) 143 | return np.concatenate(obs_flatten, dtype="float32") 144 | 145 | @property 146 | def env_family(self) -> str: 147 | return "dm_control" 148 | 149 | 150 | ENV_MAPPER = { 151 | "dm_control": set( 152 | [ 153 | "acrobot-swingup", 154 | "ball_in_cup-catch", 155 | "cartpole-balance", 156 | "cartpole-swingup", 157 | "cheetah-run", 158 | "finger-spin", 159 | "finger-turn_easy", 160 | "finger-turn_hard", 161 | "fish-upright", 162 | "fish-swim", 163 | "hopper-stand", 164 | "hopper-hop", 165 | "humanoid-stand", 166 | "humanoid-walk", 167 | "humanoid-run", 168 | "pendulum-swingup", 169 | "point_mass-easy", 170 | "reacher-easy", 171 | "reacher-hard", 172 | "swimmer-swimmer6", 173 | "swimmer-swimmer15", 174 | "walker-stand", 175 | "walker-walk", 176 | "walker-run", 177 | ] 178 | ), 179 | "safety_gymnasium": set( 180 | [ 181 | "SafetyPointGoal1-v0", 182 | "SafetyPointGoal2-v0", 183 | "SafetyPointButton1-v0", 184 | "SafetyPointButton2-v0", 185 | "SafetyPointPush1-v0", 186 | "SafetyPointPush2-v0", 187 | "SafetyPointCircle1-v0", 188 | "SafetyPointCircle2-v0", 189 | "SafetyCarGoal1-v0", 190 | "SafetyCarGoal2-v0", 191 | "SafetyCarButton1-v0", 192 | "SafetyCarButton2-v0", 193 | "SafetyCarPush1-v0", 
194 | "SafetyCarPush2-v0", 195 | "SafetyCarCircle1-v0", 196 | "SafetyCarCircle2-v0", 197 | "SafetyAntGoal1-v0", 198 | "SafetyAntGoal2-v0", 199 | "SafetyAntButton1-v0", 200 | "SafetyAntButton2-v0", 201 | "SafetyAntPush1-v0", 202 | "SafetyAntPush2-v0", 203 | "SafetyAntCircle1-v0", 204 | "SafetyAntCircle2-v0", 205 | "SafetyDoggoGoal1-v0", 206 | "SafetyDoggoGoal2-v0", 207 | "SafetyDoggoButton1-v0", 208 | "SafetyDoggoButton2-v0", 209 | "SafetyDoggoPush1-v0", 210 | "SafetyDoggoPush2-v0", 211 | "SafetyDoggoCircle1-v0", 212 | "SafetyDoggoCircle2-v0", 213 | ] 214 | ), 215 | } 216 | 217 | 218 | def make_env(name: str, seed: int): 219 | """ 220 | Args: 221 | name: Environment name. 222 | """ 223 | for env_type, env_set in ENV_MAPPER.items(): 224 | if name in env_set: 225 | if env_type == "dm_control": 226 | return DMControlEnv(name, seed=seed) 227 | elif env_type == "safety_gymnasium": 228 | return SafetyGym(name, seed=seed) 229 | else: 230 | raise ValueError(f"Unsupported environment: {name}") 231 | -------------------------------------------------------------------------------- /src/oprl/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from oprl.env import BaseEnv 7 | from oprl.trainers.buffers.episodic_buffer import EpisodicReplayBuffer 8 | from oprl.utils.logger import Logger, StdLogger 9 | 10 | 11 | class BaseTrainer: 12 | def __init__( 13 | self, 14 | state_dim: int, 15 | action_dim: int, 16 | env: BaseEnv, 17 | make_env_test: Callable[[int], BaseEnv], 18 | algo: Any | None = None, 19 | buffer_size: int = int(1e6), 20 | gamma: float = 0.99, 21 | num_steps: int = int(1e6), 22 | start_steps: int = int(10e3), 23 | batch_size: int = 128, 24 | eval_interval: int = int(2e3), 25 | num_eval_episodes: int = 10, 26 | save_buffer_every: int = 0, 27 | visualise_every: int = 0, 28 | estimate_q_every: int = 0, 29 | stdout_log_every: int = int(1e5), 30 | device: str = "cpu", 31 | seed: int = 0, 32 | logger: Logger = StdLogger(), 33 | ): 34 | """ 35 | Args: 36 | state_dim: Dimension of the observation. 37 | action_dim: Dimension of the action. 38 | env: Enviornment object. 39 | make_env_test: Environment object for evaluation. 40 | algo: Codename for the algo (SAC). 41 | buffer_size: Buffer size in transitions. 42 | gamma: Discount factor. 43 | num_step: Number of env steps to train. 44 | start_steps: Number of environment steps not to perform training at the beginning. 45 | batch_size: Batch-size. 46 | eval_interval: Number of env step after which perform evaluation. 47 | save_buffer_every: Number of env steps after which save replay buffer. 48 | visualise_every: Number of env steps after which perform vizualisation. 49 | device: Name of the device. 50 | stdout_log_every: Number of evn steps after which log info to stdout. 51 | seed: Random seed. 52 | logger: Logger instance. 
53 | """ 54 | self._env = env 55 | self._make_env_test = make_env_test 56 | self._algo = algo 57 | self._gamma = gamma 58 | self._device = device 59 | self._save_buffer_every = save_buffer_every 60 | self._visualize_every = visualise_every 61 | self._estimate_q_every = estimate_q_every 62 | self._stdout_log_every = stdout_log_every 63 | self._logger = logger 64 | self.seed = seed 65 | 66 | self.buffer = EpisodicReplayBuffer( 67 | buffer_size=buffer_size, 68 | state_dim=state_dim, 69 | action_dim=action_dim, 70 | device=device, 71 | gamma=gamma, 72 | ) 73 | 74 | self.batch_size = batch_size 75 | self.num_steps = num_steps 76 | self.start_steps = start_steps 77 | self.eval_interval = eval_interval 78 | self.num_eval_episodes = num_eval_episodes 79 | 80 | def train(self): 81 | ep_step = 0 82 | state, _ = self._env.reset() 83 | 84 | for env_step in range(self.num_steps + 1): 85 | ep_step += 1 86 | if env_step <= self.start_steps: 87 | action = self._env.sample_action() 88 | else: 89 | action = self._algo.explore(state) 90 | next_state, reward, terminated, truncated, _ = self._env.step(action) 91 | 92 | self.buffer.append( 93 | state, action, reward, terminated, episode_done=terminated or truncated 94 | ) 95 | if terminated or truncated: 96 | next_state, _ = self._env.reset() 97 | ep_step = 0 98 | state = next_state 99 | 100 | if len(self.buffer) < self.batch_size: 101 | continue 102 | 103 | batch = self.buffer.sample(self.batch_size) 104 | self._algo.update(*batch) 105 | 106 | self._eval_routine(env_step, batch) 107 | self._visualize(env_step) 108 | self._save_buffer(env_step) 109 | self._log_stdout(env_step, batch) 110 | 111 | def _eval_routine(self, env_step: int, batch): 112 | if env_step % self.eval_interval == 0: 113 | self._log_evaluation(env_step) 114 | 115 | self._logger.log_scalar("trainer/avg_reward", batch[2].mean(), env_step) 116 | self._logger.log_scalar( 117 | "trainer/buffer_transitions", len(self.buffer), env_step 118 | ) 119 | self._logger.log_scalar( 120 | "trainer/buffer_episodes", self.buffer.num_episodes, env_step 121 | ) 122 | self._logger.log_scalar( 123 | "trainer/buffer_last_ep_len", 124 | self.buffer.get_last_ep_len(), 125 | env_step, 126 | ) 127 | 128 | def _log_evaluation(self, env_step: int): 129 | returns = [] 130 | for i_ep in range(self.num_eval_episodes): 131 | env_test = self._make_env_test(seed=self.seed + i_ep) 132 | state, _ = env_test.reset() 133 | 134 | episode_return = 0.0 135 | terminated, truncated = False, False 136 | 137 | while not (terminated or truncated): 138 | action = self._algo.exploit(state) 139 | state, reward, terminated, truncated, _ = env_test.step(action) 140 | episode_return += reward 141 | 142 | returns.append(episode_return) 143 | 144 | mean_return = np.mean(returns) 145 | self._logger.log_scalar("trainer/ep_reward", mean_return, env_step) 146 | 147 | def _visualize(self, env_step: int): 148 | if self._visualize_every > 0 and env_step % self._visualize_every == 0: 149 | imgs = self.visualise_policy() # [T, W, H, C] 150 | if imgs is not None: 151 | self._logger.log_video("eval_policy", imgs, env_step) 152 | 153 | def _save_buffer(self, env_step: int): 154 | if self._save_buffer_every > 0 and env_step % self._save_buffer_every == 0: 155 | self.buffer.save(f"{self.log_dir}/buffers/buffer_step_{env_step}.pickle") 156 | 157 | def _estimate_q(self, env_step: int): 158 | if self._estimate_q_every > 0 and env_step % self._estimate_q_every == 0: 159 | q_true = self.estimate_true_q() 160 | q_critic = self.estimate_critic_q() 161 | if q_true is 
not None: 162 | self._logger.log_scalar("trainer/Q-estimate", q_true, env_step) 163 | self._logger.log_scalar("trainer/Q-critic", q_critic, env_step) 164 | self._logger.log_scalar( 165 | "trainer/Q_asb_diff", q_critic - q_true, env_step 166 | ) 167 | 168 | def _log_stdout(self, env_step: int, batch): 169 | if env_step % self._stdout_log_every == 0: 170 | perc = int(env_step / self.num_steps * 100) 171 | print( 172 | f"Env step {env_step:8d} ({perc:2d}%) Avg Reward {batch[2].mean():10.3f}" 173 | ) 174 | 175 | def visualise_policy(self): 176 | """ 177 | returned shape: [N, C, W, H] 178 | """ 179 | env = self._make_env_test(seed=self.seed) 180 | try: 181 | imgs = [] 182 | state, _ = env.reset() 183 | done = False 184 | while not done: 185 | img = env.render() 186 | imgs.append(img) 187 | action = self._algo.exploit(state) 188 | state, _, terminated, truncated, _ = env.step(action) 189 | done = terminated or truncated 190 | return np.concatenate(imgs) 191 | except Exception as e: 192 | print(f"Failed to visualise a policy: {e}") 193 | return None 194 | 195 | def estimate_true_q(self, eval_episodes: int = 10) -> float | None: 196 | try: 197 | qs = [] 198 | for i_eval in range(eval_episodes): 199 | env = self._make_env_test(seed=self.seed * 100 + i_eval) 200 | print("Before reset etimate q") 201 | state, _ = env.reset() 202 | 203 | q = 0 204 | s_i = 1 205 | while True: 206 | action = self._algo.exploit(state) 207 | state, r, terminated, truncated, _ = env.step(action) 208 | q += r * self._gamma**s_i 209 | s_i += 1 210 | if terminated or truncated: 211 | break 212 | 213 | qs.append(q) 214 | 215 | return np.mean(qs, dtype=float) 216 | except Exception as e: 217 | print(f"Failed to estimate Q-value: {e}") 218 | return None 219 | 220 | def estimate_critic_q(self, num_episodes: int = 10) -> float: 221 | qs = [] 222 | for i_eval in range(num_episodes): 223 | env = self._make_env_test(seed=self.seed * 100 + i_eval) 224 | 225 | state, _ = env.reset() 226 | action = self._algo.exploit(state) 227 | 228 | state = torch.tensor(state).unsqueeze(0).float().to(self._device) 229 | action = torch.tensor(action).unsqueeze(0).float().to(self._device) 230 | 231 | q = self._algo.critic(state, action) 232 | # TODO: TQC is not supported by this logic, need to update 233 | if isinstance(q, tuple): 234 | q = q[0] 235 | q = q.item() 236 | qs.append(q) 237 | 238 | return np.mean(qs, dtype=float) 239 | 240 | 241 | def run_training(make_algo, make_env, make_logger, config: dict[str, Any], seed: int): 242 | env = make_env(seed=seed) 243 | logger = make_logger(seed) 244 | 245 | trainer = BaseTrainer( 246 | state_dim=config["state_shape"], 247 | action_dim=config["action_shape"], 248 | env=env, 249 | make_env_test=make_env, 250 | algo=make_algo(logger, seed), 251 | num_steps=config["num_steps"], 252 | eval_interval=config["eval_every"], 253 | device=config["device"], 254 | save_buffer_every=config["save_buffer"], 255 | visualise_every=config["visualise_every"], 256 | estimate_q_every=config["estimate_q_every"], 257 | stdout_log_every=config["log_every"], 258 | seed=seed, 259 | logger=logger, 260 | ) 261 | 262 | trainer.train() 263 | -------------------------------------------------------------------------------- /src/oprl/trainers/buffers/episodic_buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class EpisodicReplayBuffer: 9 | def __init__( 10 | self, 11 | buffer_size: int, 12 | state_dim: int, 13 
| action_dim: int, 14 | device: str, 15 | gamma: float, 16 | max_episode_len: int = 1000, 17 | dtype=torch.float, 18 | ): 19 | """ 20 | Args: 21 | buffer_size: Max number of transitions in buffer. 22 | state_dim: Dimension of the state. 23 | action_dim: Dimension of the action. 24 | device: Device to place buffer. 25 | gamma: Discount factor for N-step. 26 | max_episode_len: Max length of the episode to store. 27 | dtype: Data type. 28 | """ 29 | self.buffer_size = buffer_size 30 | self.max_episodes = buffer_size // max_episode_len 31 | self.max_episode_len = max_episode_len 32 | self.state_dim = state_dim 33 | self.action_dim = action_dim 34 | self.device = device 35 | self.gamma = gamma 36 | 37 | self.ep_pointer = 0 38 | self.cur_episodes = 1 39 | self.cur_size = 0 40 | 41 | self.actions = torch.empty( 42 | (self.max_episodes, max_episode_len, action_dim), 43 | dtype=dtype, 44 | device=device, 45 | ) 46 | self.rewards = torch.empty( 47 | (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device 48 | ) 49 | self.dones = torch.empty( 50 | (self.max_episodes, max_episode_len, 1), dtype=dtype, device=device 51 | ) 52 | self.states = torch.empty( 53 | (self.max_episodes, max_episode_len + 1, state_dim), 54 | dtype=dtype, 55 | device=device, 56 | ) 57 | self.ep_lens = [0] * self.max_episodes 58 | 59 | self.actions_for_std = torch.empty( 60 | (100, action_dim), dtype=dtype, device=device 61 | ) 62 | self.actions_for_std_cnt = 0 63 | 64 | # TODO: rename to add 65 | def append(self, state, action, reward, done, episode_done=None): 66 | """ 67 | Args: 68 | state: state. 69 | action: action. 70 | reward: reward. 71 | done: done only if episode ends naturally. 72 | episode_done: done that can be set to True if time limit is reached. 73 | """ 74 | self.states[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_( 75 | torch.from_numpy(state) 76 | ) 77 | self.actions[self.ep_pointer, self.ep_lens[self.ep_pointer]].copy_( 78 | torch.from_numpy(action) 79 | ) 80 | self.rewards[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(reward) 81 | self.dones[self.ep_pointer, self.ep_lens[self.ep_pointer]] = float(done) 82 | 83 | self.actions_for_std[self.actions_for_std_cnt % 100].copy_( 84 | torch.from_numpy(action) 85 | ) 86 | self.actions_for_std_cnt += 1 87 | 88 | self.ep_lens[self.ep_pointer] += 1 89 | self.cur_size = min(self.cur_size + 1, self.buffer_size) 90 | if episode_done: 91 | self._inc_episode() 92 | 93 | def _inc_episode(self): 94 | self.ep_pointer = (self.ep_pointer + 1) % self.max_episodes 95 | self.cur_episodes = min(self.cur_episodes + 1, self.max_episodes) 96 | self.cur_size -= self.ep_lens[self.ep_pointer] 97 | self.ep_lens[self.ep_pointer] = 0 98 | 99 | def add_episode(self, episode): 100 | for s, a, r, d, s_ in episode: 101 | self.append(s, a, r, d, episode_done=d) 102 | if d: 103 | break 104 | else: 105 | self._inc_episode() 106 | 107 | def _inds_to_episodic(self, inds): 108 | start_inds = np.cumsum([0] + self.ep_lens[: self.cur_episodes - 1]) 109 | end_inds = start_inds + np.array(self.ep_lens[: self.cur_episodes]) 110 | ep_inds = np.argmin( 111 | inds.reshape(-1, 1) >= np.tile(end_inds, (len(inds), 1)), axis=1 112 | ) 113 | step_inds = inds - start_inds[ep_inds] 114 | 115 | return ep_inds, step_inds 116 | 117 | def sample(self, batch_size): 118 | inds = np.random.randint(low=0, high=self.cur_size, size=batch_size) 119 | ep_inds, step_inds = self._inds_to_episodic(inds) 120 | 121 | return ( 122 | self.states[ep_inds, step_inds], 123 | self.actions[ep_inds, step_inds], 124 
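            # Sampled batch layout: (state, action, reward, done, next_state);
            # BaseTrainer consumes it positionally via `algo.update(*batch)`.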
| self.rewards[ep_inds, step_inds], 125 | self.dones[ep_inds, step_inds], 126 | self.states[ep_inds, step_inds + 1], 127 | ) 128 | 129 | def save(self, path: str): 130 | """ 131 | Args: 132 | path: Path to pickle file. 133 | """ 134 | dirname = os.path.dirname(path) 135 | if not os.path.exists(dirname): 136 | os.makedirs(dirname) 137 | 138 | data = { 139 | "states": self.states.cpu(), 140 | "actions": self.actions.cpu(), 141 | "rewards": self.rewards.cpu(), 142 | "dones": self.dones.cpu(), 143 | "ep_lens": self.ep_lens, 144 | } 145 | try: 146 | with open(path, "wb") as f: 147 | pickle.dump(data, f) 148 | print(f"Replay buffer saved to {path}") 149 | except Exception as e: 150 | print(f"Failed to save replay buffer: {e}") 151 | 152 | def __len__(self): 153 | return self.cur_size 154 | 155 | @property 156 | def num_episodes(self): 157 | return self.cur_episodes 158 | 159 | def get_last_ep_len(self): 160 | return self.ep_lens[self.ep_pointer] 161 | -------------------------------------------------------------------------------- /src/oprl/trainers/safe_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import numpy as np 4 | 5 | from oprl.env import BaseEnv 6 | from oprl.trainers.base_trainer import BaseTrainer 7 | from oprl.utils.logger import Logger, StdLogger 8 | 9 | 10 | class SafeTrainer(BaseTrainer): 11 | def __init__( 12 | self, 13 | state_dim: int, 14 | action_dim: int, 15 | env: BaseEnv, 16 | make_env_test: Callable[[int], BaseEnv], 17 | algo: Any | None = None, 18 | buffer_size: int = int(1e6), 19 | gamma: float = 0.99, 20 | num_steps=int(1e6), 21 | start_steps: int = int(10e3), 22 | batch_size: int = 128, 23 | eval_interval: int = int(2e3), 24 | num_eval_episodes: int = 10, 25 | save_buffer_every: int = 0, 26 | visualise_every: int = 0, 27 | estimate_q_every: int = 0, 28 | stdout_log_every: int = int(1e5), 29 | device: str = "cpu", 30 | seed: int = 0, 31 | logger: Logger = StdLogger(), 32 | ): 33 | """ 34 | Args: 35 | state_dim: Dimension of the observation. 36 | action_dim: Dimension of the action. 37 | env: Enviornment object. 38 | make_env_test: Environment object for evaluation. 39 | algo: Codename for the algo (SAC). 40 | buffer_size: Buffer size in transitions. 41 | gamma: Discount factor. 42 | num_step: Number of env steps to train. 43 | start_steps: Number of environment steps not to perform training at the beginning. 44 | batch_size: Batch-size. 45 | eval_interval: Number of env step after which perform evaluation. 46 | save_buffer_every: Number of env steps after which save replay buffer. 47 | visualise_every: Number of env steps after which perform vizualisation. 48 | stdout_log_every: Number of evn steps after which log info to stdout. 49 | device: Name of the device. 50 | seed: Random seed. 51 | logger: Logger instance. 
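        Note:
            Compared to BaseTrainer, the training loop below also accumulates
            `info["cost"]` returned by the environment and logs the mean episode
            cost during evaluation, so the environment wrapper is expected to
            expose a `cost` entry in `info` (the SafetyGym wrapper in `oprl.env`
            is assumed to do this). Also note that `train` passes the sampled
            batch as a single tuple, `algo.update(batch)`, whereas BaseTrainer
            unpacks it with `algo.update(*batch)`; the algo paired with
            SafeTrainer is assumed to accept the tuple form.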
52 | """ 53 | super().__init__( 54 | state_dim=state_dim, 55 | action_dim=action_dim, 56 | env=env, 57 | make_env_test=make_env_test, 58 | algo=algo, 59 | buffer_size=buffer_size, 60 | gamma=gamma, 61 | device=device, 62 | num_steps=num_steps, 63 | start_steps=start_steps, 64 | batch_size=batch_size, 65 | eval_interval=eval_interval, 66 | num_eval_episodes=num_eval_episodes, 67 | save_buffer_every=save_buffer_every, 68 | visualise_every=visualise_every, 69 | estimate_q_every=estimate_q_every, 70 | stdout_log_every=stdout_log_every, 71 | seed=seed, 72 | logger=logger, 73 | ) 74 | 75 | def train(self): 76 | ep_step = 0 77 | state, _ = self._env.reset() 78 | total_cost = 0 79 | 80 | for env_step in range(self.num_steps + 1): 81 | ep_step += 1 82 | if env_step <= self.start_steps: 83 | action = self._env.sample_action() 84 | else: 85 | action = self._algo.explore(state) 86 | next_state, reward, terminated, truncated, info = self._env.step(action) 87 | total_cost += info["cost"] 88 | 89 | self.buffer.append( 90 | state, action, reward, terminated, episode_done=terminated or truncated 91 | ) 92 | if terminated or truncated: 93 | next_state, _ = self._env.reset() 94 | ep_step = 0 95 | state = next_state 96 | 97 | if len(self.buffer) < self.batch_size: 98 | continue 99 | batch = self.buffer.sample(self.batch_size) 100 | self._algo.update(batch) 101 | 102 | self._eval_routine(env_step, batch) 103 | self._visualize(env_step) 104 | self._save_buffer(env_step) 105 | self._log_stdout(env_step, batch) 106 | 107 | self._logger.log_scalar("trainer/total_cost", total_cost, self.num_steps) 108 | 109 | def _log_evaluation(self, env_step: int): 110 | returns = [] 111 | costs = [] 112 | for i_ep in range(self.num_eval_episodes): 113 | env_test = self._make_env_test(seed=self.seed + i_ep) 114 | state, _ = env_test.reset() 115 | 116 | episode_return = 0 117 | episode_cost = 0 118 | terminated, truncated = False, False 119 | 120 | while not (terminated or truncated): 121 | action = self._algo.exploit(state) 122 | state, reward, terminated, truncated, info = env_test.step(action) 123 | episode_return += reward 124 | episode_cost += info["cost"] 125 | 126 | returns.append(episode_return) 127 | costs.append(episode_cost) 128 | 129 | self._logger.log_scalar( 130 | "trainer/ep_reward", np.mean(returns, dtype=float), env_step 131 | ) 132 | self._logger.log_scalar( 133 | "trainer/ep_cost", np.mean(costs, dtype=float), env_step 134 | ) 135 | -------------------------------------------------------------------------------- /src/oprl/utils/config.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import sys 3 | 4 | 5 | def load_config(path: str): 6 | spec = importlib.util.spec_from_file_location("config", path) 7 | config = importlib.util.module_from_spec(spec) 8 | sys.modules["config"] = config 9 | spec.loader.exec_module(config) 10 | return config 11 | -------------------------------------------------------------------------------- /src/oprl/utils/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | from abc import ABC, abstractmethod 6 | from typing import Any 7 | 8 | import numpy as np 9 | from torch.utils.tensorboard.writer import SummaryWriter 10 | 11 | 12 | def copy_exp_dir(log_dir: str) -> None: 13 | cur_dir = os.path.join(os.getcwd(), "src") 14 | dest_dir = os.path.join(log_dir, "src") 15 | shutil.copytree(cur_dir, dest_dir) 16 | 
logging.info(f"Source copied into {dest_dir}") 17 | 18 | 19 | def save_json_config(config: dict[str, Any], path: str): 20 | with open(path, "w") as f: 21 | json.dump(config, f) 22 | 23 | 24 | class Logger(ABC): 25 | 26 | def log_scalars(self, values: dict[str, float], step: int): 27 | """ 28 | Args: 29 | values: Dict with tag -> value to log. 30 | step: Iter step. 31 | """ 32 | (self.log_scalar(k, v, step) for k, v in values.items()) 33 | 34 | @abstractmethod 35 | def log_scalar(self, tag: str, value: float, step: int): 36 | logging.info(f"{tag}\t{value}\tat step {step}") 37 | 38 | @abstractmethod 39 | def log_video(self, tag: str, imgs, step: int): 40 | logging.warning("Skipping logging video in STDOUT logger") 41 | 42 | 43 | class StdLogger(Logger): 44 | def __init__(self, *args, **kwargs): 45 | pass 46 | 47 | def log_scalar(self, tag: str, value: float, step: int): 48 | logging.info(f"{tag}\t{value}\tat step {step}") 49 | 50 | def log_video(self, *args, **kwargs): 51 | logging.warning("Skipping logging video in STDOUT logger") 52 | 53 | 54 | class FileLogger(Logger): 55 | def __init__(self, logdir: str, config: dict[str, Any]): 56 | self.writer = SummaryWriter(logdir) 57 | 58 | self._log_dir = logdir 59 | 60 | logging.info(f"Source code is copied to {logdir}") 61 | copy_exp_dir(logdir) 62 | save_json_config(config, os.path.join(logdir, "config.json")) 63 | 64 | def log_scalar(self, tag: str, value: float, step: int) -> None: 65 | self.writer.add_scalar(tag, value, step) 66 | self._log_scalar_to_file(tag, value, step) 67 | 68 | def log_video(self, tag: str, imgs, step: int) -> None: 69 | os.makedirs(os.path.join(self._log_dir, "images")) 70 | fn = os.path.join(self._log_dir, "images", f"{tag}_step_{step}.npz") 71 | with open(fn, "wb") as f: 72 | np.save(f, imgs) 73 | 74 | def _log_scalar_to_file(self, tag: str, val: float, step: int) -> None: 75 | fn = os.path.join(self._log_dir, f"{tag}.log") 76 | os.makedirs(os.path.dirname(fn), exist_ok=True) 77 | with open(fn, "a") as f: 78 | f.write(f"{step} {val}\n") 79 | -------------------------------------------------------------------------------- /src/oprl/utils/run_training.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from multiprocessing import Process 3 | 4 | from oprl.trainers.base_trainer import BaseTrainer 5 | from oprl.trainers.safe_trainer import SafeTrainer 6 | from oprl.utils.utils import set_seed 7 | 8 | 9 | def run_training( 10 | make_algo, make_env, make_logger, config, seeds: int = 1, start_seed: int = 0 11 | ): 12 | if seeds == 1: 13 | _run_training_func(make_algo, make_env, make_logger, config, 0) 14 | else: 15 | processes = [] 16 | for seed in range(start_seed, start_seed + seeds): 17 | processes.append( 18 | Process( 19 | target=_run_training_func, 20 | args=(make_algo, make_env, make_logger, config, seed), 21 | ) 22 | ) 23 | 24 | for i, p in enumerate(processes): 25 | p.start() 26 | logging.info(f"Starting process {i}...") 27 | 28 | for p in processes: 29 | p.join() 30 | 31 | logging.info("Training OK.") 32 | 33 | 34 | def _run_training_func(make_algo, make_env, make_logger, config, seed: int): 35 | set_seed(seed) 36 | env = make_env(seed=seed) 37 | logger = make_logger(seed) 38 | 39 | if env.env_family == "dm_control": 40 | trainer_class = BaseTrainer 41 | elif env.env_family == "safety_gymnasium": 42 | trainer_class = SafeTrainer 43 | else: 44 | raise ValueError(f"Unsupported env family: {env.env_family}") 45 | 46 | trainer = trainer_class( 47 | 
state_dim=config["state_dim"], 48 | action_dim=config["action_dim"], 49 | env=env, 50 | make_env_test=make_env, 51 | algo=make_algo(logger), 52 | num_steps=config["num_steps"], 53 | eval_interval=config["eval_every"], 54 | device=config["device"], 55 | save_buffer_every=config["save_buffer"], 56 | visualise_every=config["visualise_every"], 57 | estimate_q_every=config["estimate_q_every"], 58 | stdout_log_every=config["log_every"], 59 | seed=seed, 60 | logger=logger, 61 | ) 62 | 63 | trainer.train() 64 | -------------------------------------------------------------------------------- /src/oprl/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import shutil 5 | import sys 6 | from glob import glob 7 | 8 | import imageio 9 | import numpy as np 10 | import torch as t 11 | 12 | 13 | class OUNoise(object): 14 | def __init__( 15 | self, 16 | dim, 17 | low, 18 | high, 19 | mu=0.0, 20 | theta=0.15, 21 | max_sigma=0.3, 22 | min_sigma=0.3, 23 | decay_period=10_000, 24 | ): 25 | self.mu = mu 26 | self.theta = theta 27 | self.sigma = max_sigma 28 | self.max_sigma = max_sigma 29 | self.min_sigma = min_sigma 30 | self.decay_period = decay_period 31 | self.action_dim = dim 32 | self.low = low 33 | self.high = high 34 | 35 | def reset(self): 36 | self.state = np.ones(self.action_dim) * self.mu 37 | 38 | def evolve_state(self): 39 | x = self.state 40 | dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim) 41 | self.state = x + dx 42 | return self.state 43 | 44 | def get_action(self, action, t=0): 45 | ou_state = self.evolve_state() 46 | self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min( 47 | 1.0, t / self.decay_period 48 | ) 49 | action = action.cpu().detach().numpy() 50 | return np.clip(action + ou_state, self.low, self.high) 51 | 52 | 53 | def make_gif(source_dir, output): 54 | """ 55 | Make gif file from set of .jpeg images. 
56 | Args: 57 | source_dir (str): path with .jpeg images 58 | output (str): path to the output .gif file 59 | Returns: None 60 | """ 61 | batch_sort = lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) 62 | image_paths = sorted(glob(os.path.join(source_dir, "*.png")), key=batch_sort) 63 | 64 | images = [] 65 | for filename in image_paths: 66 | images.append(imageio.imread(filename)) 67 | imageio.mimsave(output, images) 68 | 69 | 70 | def empty_torch_queue(q): 71 | while True: 72 | try: 73 | o = q.get_nowait() 74 | del o 75 | except: 76 | break 77 | q.close() 78 | 79 | 80 | def copy_exp_dir(log_dir: str) -> None: 81 | cur_dir = os.path.join(os.getcwd(), "src") 82 | dest_dir = os.path.join(log_dir, "src") 83 | shutil.copy(cur_dir, dest_dir) 84 | logging.info(f"Source copied into {dest_dir}") 85 | 86 | 87 | def set_seed(seed: int) -> None: 88 | random.seed(seed) 89 | np.random.seed(seed) 90 | t.manual_seed(seed) 91 | 92 | 93 | def set_logging(level: int): 94 | logging.basicConfig( 95 | level=level, 96 | format="%(asctime)s | %(filename)s:%(lineno)d\t %(levelname)s - %(message)s", 97 | stream=sys.stdout, 98 | ) 99 | -------------------------------------------------------------------------------- /src/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="oprl", 5 | version="1.0", 6 | author="Igor Kuznetsov", 7 | description="Reinforcement Learning Off-policy Algorithms with PyTorch", 8 | long_description="todo", 9 | url="todo", 10 | keywords="reinforcement, learning, off-policy", 11 | python_requires=">=3.7", 12 | # packages=find_packages(include=['exampleproject', 'exampleproject.*']), 13 | # install_requires=[ 14 | # 'PyYAML', 15 | # 'pandas==0.23.3', 16 | # 'numpy>=1.14.5', 17 | # 'matplotlib>=2.2.0,, 18 | # 'jupyter' 19 | # ], 20 | ) 21 | -------------------------------------------------------------------------------- /tests/functional/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==8.0.1 2 | torch==2.1.2 3 | tensorboard==2.15.1 4 | packaging==23.2 5 | dm-control==1.0.16 6 | mujoco==3.1.3 7 | -------------------------------------------------------------------------------- /tests/functional/src/test_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from oprl.env import DMControlEnv 4 | 5 | dm_control_envs = [ 6 | "acrobot-swingup", 7 | "ball_in_cup-catch", 8 | "cartpole-balance", 9 | "cartpole-swingup", 10 | "cheetah-run", 11 | "finger-spin", 12 | "finger-turn_easy", 13 | "finger-turn_hard", 14 | "fish-upright", 15 | "fish-swim", 16 | "hopper-stand", 17 | "hopper-hop", 18 | "humanoid-stand", 19 | "humanoid-walk", 20 | "humanoid-run", 21 | "pendulum-swingup", 22 | "point_mass-easy", 23 | "reacher-easy", 24 | "reacher-hard", 25 | "swimmer-swimmer6", 26 | "swimmer-swimmer15", 27 | "walker-stand", 28 | "walker-walk", 29 | "walker-run", 30 | ] 31 | 32 | 33 | @pytest.mark.parametrize("env_name", dm_control_envs) 34 | def test_dm_control_envs(env_name: str): 35 | env = DMControlEnv(env_name, seed=0) 36 | obs, info = env.reset() 37 | assert obs.shape[0] == env.observation_space.shape[0] 38 | assert isinstance(info, dict), "Info is expected to be a dict" 39 | 40 | rand_action = env.sample_action() 41 | assert rand_action.ndim == 1 42 | 43 | next_obs, reward, terminated, truncated, info = env.step(rand_action) 44 | assert next_obs.ndim == 1, "Expected 1-dimensional array as 
observation" 45 | assert isinstance(reward, float), "Reward is epxected to be a single float value" 46 | assert isinstance( 47 | terminated, bool 48 | ), "Terminated is expected to be a single bool value" 49 | assert isinstance( 50 | truncated, bool 51 | ), "Truncated is expected to be a single bool value" 52 | assert isinstance(info, dict), "Info is expected to be a dict" 53 | -------------------------------------------------------------------------------- /tests/functional/src/test_rl_algos.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from oprl.algos.ddpg import DDPG 5 | from oprl.algos.sac import SAC 6 | from oprl.algos.td3 import TD3 7 | from oprl.algos.tqc import TQC 8 | from oprl.env import DMControlEnv 9 | 10 | rl_algo_classes = [DDPG, SAC, TD3, TQC] 11 | 12 | 13 | @pytest.mark.parametrize("algo_class", rl_algo_classes) 14 | def test_rl_algo_run(algo_class): 15 | env = DMControlEnv("walker-walk", seed=0) 16 | obs, _ = env.reset(env.sample_action()) 17 | 18 | algo = algo_class( 19 | state_dim=env.observation_space.shape[0], 20 | action_dim=env.action_space.shape[0], 21 | ) 22 | action = algo.exploit(obs) 23 | assert action.ndim == 1 24 | 25 | action = algo.explore(obs) 26 | assert action.ndim == 1 27 | 28 | _batch_size = 8 29 | batch_obs = torch.randn(_batch_size, env.observation_space.shape[0]) 30 | batch_actions = torch.clamp( 31 | torch.randn(_batch_size, env.action_space.shape[0]), -1, 1 32 | ) 33 | batch_rewards = torch.randn(_batch_size, 1) 34 | batch_dones = torch.randint(2, (_batch_size, 1)) 35 | algo.update(batch_obs, batch_actions, batch_rewards, batch_dones, batch_obs) 36 | --------------------------------------------------------------------------------