├── .gitignore ├── README.md ├── deep_control ├── __init__.py ├── aac.py ├── adv_estimator.py ├── augmentations.py ├── awac.py ├── critic_searchers.py ├── ddpg.py ├── discor.py ├── envs.py ├── grac.py ├── nets.py ├── redq.py ├── replay.py ├── run.py ├── sac.py ├── sac_aug.py ├── sbc.py ├── sunrise.py ├── td3.py ├── tsr_caql.py └── utils.py ├── examples ├── basic_control │ ├── ddpg_gym.py │ ├── sac_gym.py │ ├── sunrise_gym.py │ └── td3_gym.py ├── d4rl │ ├── awac.py │ └── sbc.py ├── dmc │ ├── ddpg_dmc.py │ ├── discor_dmc.py │ ├── grac_dmc.py │ ├── redq_dmc.py │ ├── sac_aug_dmc.py │ ├── sac_dmc.py │ ├── sunrise_dmc.py │ └── td3_dmc.py └── dmcr │ └── sac_aug_dmcr.py ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | deep_control/saves/ 3 | 4 | deep_control/__pycache__/ 5 | 6 | __pycache__/ 7 | 8 | .eggs/ 9 | 10 | deep_control.egg-info/ 11 | 12 | saves/ 13 | 14 | examples/dc_saves 15 | 16 | 17 | deep_control/dc_saves/ 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Control 2 | ## Simple PyTorch Implementations of Deep RL Algorithms for Continuous Control Research 3 | 4 | This repository contains re-implementations of Deep RL algorithms for continuous action spaces. Some highlights: 5 | 6 | 1) Code is readable, and written to be easy to modify for future research. Many popular Deep RL frameworks are highly modular, which can make it confusing to identify the changes in a new method. Aside from universal components like the replay buffer, network architectures, etc., each implementation in this repo is contained in a single file. 7 | 2) Train and test on different environments (for generalization research). 8 | 3) Built-in Tensorboard logging, parameter saving. 9 | 4) Support for offline (batch) RL. 10 | 5) Quick setup for benchmarks like Gym MuJoCo, Atari, PyBullet, and DeepMind Control Suite. 11 | 12 | ### What's included? 13 | 14 | #### Deep Deterministic Policy Gradient (DDPG) 15 | Paper: [Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971), Lillicrap et al., 2015. 16 | 17 | Description: A baseline model-free, off-policy actor-critic method that forms the template for many of the other algorithms here. 18 | 19 | Code: `deep_control.ddpg` (*with extra comments for an intro to deep actor-critics*) 20 | Examples: `examples/basic_control/ddpg_gym.py` 21 | 22 | #### Twin Delayed DDPG (TD3) 23 | Paper: [Addressing Function Approximation Error in Actor-Critic Methods](https://arxiv.org/abs/1802.09477), Fujimoto et al., 2018. 24 | 25 | Description: Builds off of DDPG and makes several changes to improve the critic's learning and performance (Clipped Double Q Learning, Target Smoothing, Actor Delay). Also includes the TD regularization term from "[TD-Regularized Actor-Critic Methods](https://arxiv.org/abs/1812.08288)." 26 | 27 | Code: `deep_control.td3` 28 | Examples: `examples/basic_control/td3_gym.py` 29 | 30 | Other References: [author's implementation](https://github.com/sfujim/TD3) 31 | 32 | #### Soft Actor Critic (SAC) 33 | Paper: [Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor](https://arxiv.org/abs/1801.01290), Haarnoja et al., 2018. 34 | 35 | Description: Samples actions from a stochastic actor rather than relying on added exploration noise during training.
Uses a TD3-like double critic system. We *do* implement the learnable entropy coefficient approach described in the [follow-up paper](https://arxiv.org/abs/1812.05905). This version also supports the self-regularized critic updates from GRAC (see below). 36 | 37 | Code: `deep_control.sac` 38 | Examples: `examples/dmc/sac_dmc.py`, `examples/basic_control/sac_gym.py` 39 | 40 | Other References: [Yarats and Kostrikov's implementation](https://github.com/denisyarats/pytorch_sac), [author's implementation](https://github.com/haarnoja/sac). 41 | 42 | #### Pixel SAC with Data Augmentation (SAC+AUG) 43 | Paper: [Measuring Visual Generalization in Continuous Control from Pixels](https://arxiv.org/abs/2010.06740), Grigsby and Qi, 2020. 44 | 45 | Description: This is a pixel-specific version of SAC with a few tricks/hyperparameter settings to improve performance. We include many different data augmentation techniques, including those used in [RAD](https://arxiv.org/abs/2004.14990), [DrQ](https://arxiv.org/abs/2004.13649) and [Network Randomization](https://arxiv.org/abs/1910.05396). The DrQ augmentation is turned on by default, and has a huge impact on performance. 46 | 47 | *Please Note: If you are interested in control from images, these features are implemented much more thoroughly in another repo: [jakegrigsby/super_sac](https://github.com/jakegrigsby/super_sac)* 48 | 49 | Code: `deep_control.sac_aug` 50 | Examples: `examples/dmcr/sac_aug_dmcr.py` 51 | 52 | Other References: [SAC+AE code](https://github.com/denisyarats/pytorch_sac_ae), [RAD Procgen code](https://github.com/pokaxpoka/rad_procgen), [DrQ](https://github.com/denisyarats/drq) 53 | 54 | #### Self-Guided and Self-Regularized Actor-Critic (GRAC) 55 | Paper: [GRAC: Self-Regularized Actor-Critic](https://arxiv.org/abs/2009.08973), Shao et al., 2020. 56 | 57 | Description: GRAC is a combination of a stochastic policy with TD3-like stability improvements and CEM-based action selection like you'd see in Qt-Opt or CAQL. 58 | 59 | Code: `deep_control.grac` 60 | Examples: `examples/dmc/grac_dmc.py` 61 | 62 | Other References: [author's implementation](https://github.com/stanford-iprl-lab/GRAC) 63 | 64 | #### Randomized Ensemble Double Q-Learning (REDQ) 65 | Paper: [Randomized Ensemble Double Q-Learning: Learning Fast Without a Model](https://openreview.net/forum?id=AY8zfZm0tDd), Chen et al., 2021. 66 | 67 | Description: Extends the double Q trick to random subsets of a larger critic ensemble. Reduced Q function bias allows for a much higher replay ratio. REDQ is sample efficient but slow (compared to other model-free methods). We implement the SAC version. 68 | 69 | Code: `deep_control.redq` 70 | Examples: `examples/dmc/redq_dmc.py` 71 | 72 | #### Distributional Correction (DisCor) 73 | Paper: [DisCor: Corrective Feedback in Reinforcement Learning via Distribution Correction](https://arxiv.org/abs/2003.07305), Kumar et al., 2020. 74 | 75 | Description: Reduces the effect of inaccurate target values propagating through the Q-function by learning to estimate the target networks' inaccuracies and adjusting the TD error accordingly. Implemented on top of standard SAC. 76 | 77 | Code: `deep_control.discor` 78 | Examples: `examples/dmc/discor_dmc.py` 79 | 80 | #### Simple Unified Framework for Ensemble Learning (SUNRISE) 81 | Paper: [SUNRISE: A Simple Unified Framework for Ensemble Learning in Deep Reinforcement Learning](https://arxiv.org/abs/2007.04938), Lee et al., 2020. 82 | 83 | Description: Extends SAC using an ensemble of actors and critics.
Adds UCB-based exploration, ensembled inference, and a simpler weighted Bellman backup. This version does not use the replay buffer masks from the original. 84 | 85 | Code: `deep_control.sunrise` 86 | Examples: `examples/dmc/sunrise_dmc.py` 87 | 88 | #### Stochastic Behavioral Cloning (SBC) 89 | 90 | Description: A simple approach to offline RL that trains the actor network to emulate the action choices of the demonstration dataset. Uses the stochastic actor from SAC and some basic ensembling to make this a reasonable baseline. 91 | 92 | Code: `deep_control.sbc` 93 | Examples: `examples/d4rl/sbc.py` 94 | 95 | #### Advantage Weighted Actor Critic (AWAC) and Critic Regularized Regression (CRR) 96 | Paper: [Accelerating Online Reinforcement Learning with Offline Datasets](https://arxiv.org/abs/2006.09359), Nair et al., 2020. & [Critic Regularized Regression](https://arxiv.org/abs/2006.15134), Wang et al., 2020. 97 | 98 | Description: TD3 with a stochastic policy and a modified actor update that makes better use of offline experience before finetuning in the online environment. The current implementation is a mix between AWAC and CRR. We allow for online finetuning and use standard critic networks as in AWAC, but add the binary advantage filter and max/mean advantage estimates from CRR. The `actor_per` experience prioritization trick is discussed in [A Closer Look at Advantage-Filtered Behavioral Cloning 99 | in High-Noise Datasets](https://arxiv.org/abs/2110.04698), Grigsby and Qi, 2021. 100 | 101 | Code: `deep_control.awac` 102 | Examples: `examples/d4rl/awac.py` 103 | 104 | #### Automatic Actor Critic (AAC) 105 | Paper: [Towards Automatic Actor-Critic Solutions to Continuous Control](https://arxiv.org/abs/2106.08918), Grigsby et al., 2021. 106 | 107 | Description: AAC uses a genetic algorithm to automatically tune the hyperparameters of SAC. A population of SAC agents is trained in parallel with a shared replay buffer and several design decisions that reduce hyperparameter sensitivity while (mostly) preserving sample efficiency. Please refer to the paper for more details. **This is the official author implementation.** 108 | 109 | Code: `deep_control.aac` 110 | 111 | 112 | ### Installation 113 | ```bash 114 | git clone https://github.com/jakegrigsby/deep_control.git 115 | cd deep_control 116 | pip install -e . 117 | ``` 118 | 119 | ### Examples 120 | See the `examples` folder for a look at how to train agents in environments like the DeepMind Control Suite and OpenAI Gym. 121 | 122 | -------------------------------------------------------------------------------- /deep_control/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | 5 | from . 
import ( 6 | ddpg, 7 | sac, 8 | sac_aug, 9 | td3, 10 | grac, 11 | redq, 12 | tsr_caql, 13 | discor, 14 | sunrise, 15 | sbc, 16 | awac, 17 | aac, 18 | ) 19 | -------------------------------------------------------------------------------- /deep_control/adv_estimator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class AdvantageEstimator(nn.Module): 6 | def __init__( 7 | self, actor, critics, popart=False, method="mean", ensembling="mean", n=4 8 | ): 9 | super().__init__() 10 | assert method in ["mean", "max"] 11 | assert ensembling in ["min", "mean"] 12 | self.actor = actor 13 | self.critics = critics 14 | self.method = method 15 | self.ensembling = ensembling 16 | self.val_s = None 17 | self.popart = popart 18 | self._n = n 19 | 20 | def pop(self, q, s, a): 21 | if self.popart: 22 | return self.popart(q(s, a)) 23 | else: 24 | return q(s, a) 25 | 26 | def get_hparams(self): 27 | return {"adv_method": self.method, "adv_ensembling_method": self.ensembling} 28 | 29 | def estimate_value(self, state): 30 | # get an action distribution from the policy 31 | act_dist = self.actor(state) 32 | actions = [act_dist.sample() for _ in range(self._n)] 33 | 34 | # get the q value for each of the n actions 35 | qs = [] 36 | for act in actions: 37 | q_preds = torch.stack( 38 | [self.pop(critic, state, act) for critic in self.critics], dim=0 39 | ) 40 | if self.ensembling == "min": 41 | q_preds = q_preds.min(0).values 42 | elif self.ensembling == "mean": 43 | q_preds = q_preds.mean(0) 44 | qs.append(q_preds) 45 | 46 | if self.method == "mean": 47 | # V(s) = E_{a ~ \pi(s)} [Q(s, a)] 48 | value = torch.stack(qs, dim=0).mean(0) 49 | elif self.method == "max": 50 | # Optimistic value estimate: V(s) = max_{a1, a2, a3, ..., aN}(Q(s, a)) 51 | value = torch.stack(qs, dim=0).max(0).values 52 | self.val_s = value 53 | return value 54 | 55 | def forward(self, state, action, use_computed_val=False): 56 | with torch.no_grad(): 57 | q_preds = torch.stack( 58 | [self.pop(critic, state, action) for critic in self.critics], dim=0 59 | ) 60 | if self.ensembling == "min": 61 | q_preds = q_preds.min(0).values 62 | elif self.ensembling == "mean": 63 | q_preds = q_preds.mean(0) 64 | # reuse the expensive value computation if it has already been done 65 | if use_computed_val: 66 | assert self.val_s is not None 67 | else: 68 | # do the value computation 69 | self.estimate_value(state) 70 | # A(s, a) = Q(s, a) - V(s) 71 | adv = q_preds - self.val_s 72 | return adv 73 | 74 | 75 | class AdvEstimatorFilter(nn.Module): 76 | def __init__(self, adv_estimator, filter_type="binary", beta=1.0): 77 | super().__init__() 78 | self.adv_estimator = adv_estimator 79 | self.filter_type = filter_type 80 | self.beta = beta 81 | self._norm_a2 = 0.5 82 | 83 | def get_hparams(self): 84 | return {"filter_type": self.filter_type, "filter_beta": self.beta} 85 | 86 | def forward(self, s, a, step_num=None): 87 | adv = self.adv_estimator(s, a) 88 | if self.filter_type == "exp": 89 | filter_val = (self.beta * adv.clamp(-5.0, 5.0)).exp() 90 | elif self.filter_type == "binary": 91 | filter_val = (adv >= 0.0).float() 92 | elif self.filter_type == "exp_norm": 93 | self._norm_a2 += 1e-5 * (adv.mean() ** 2 - self._norm_a2) 94 | norm_adv = adv / ((self._norm_a2).sqrt() + 1e-5) 95 | filter_val = (self.beta * norm_adv).exp() 96 | elif self.filter_type == "softmax": 97 | batch_size = s.shape[0] 98 | filter_val = batch_size * F.softmax(self.beta * adv, dim=0) 99 | elif self.filter_type == 
"identity": 100 | filter_val = torch.ones_like(adv) 101 | else: 102 | raise ValueError(f"Unrecognized filter type '{self.filter_type}'") 103 | # final clip for numerical stability (only applies to exp filters) 104 | return filter_val.clamp(-100.0, 100.0) 105 | -------------------------------------------------------------------------------- /deep_control/awac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import os 5 | from itertools import chain 6 | 7 | import numpy as np 8 | import tensorboardX 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.distributions as pyd 12 | import tqdm 13 | 14 | from . import envs, nets, replay, run, utils, device, sac 15 | from deep_control.adv_estimator import AdvantageEstimator, AdvEstimatorFilter 16 | 17 | 18 | class AWACAgent(sac.SACAgent): 19 | def __init__( 20 | self, 21 | obs_space_size, 22 | act_space_size, 23 | log_std_low, 24 | log_std_high, 25 | actor_net_cls=nets.StochasticActor, 26 | critic_net_cls=nets.BigCritic, 27 | hidden_size=1024, 28 | ): 29 | super().__init__( 30 | obs_space_size, 31 | act_space_size, 32 | log_std_low, 33 | log_std_high, 34 | actor_net_cls, 35 | critic_net_cls, 36 | hidden_size=hidden_size, 37 | ) 38 | self.actor.dist_impl = "pyd" 39 | 40 | 41 | def awac( 42 | agent, 43 | buffer, 44 | train_env, 45 | test_env, 46 | num_steps_offline=25_000, 47 | num_steps_online=500_000, 48 | gradient_updates_per_step=1, 49 | transitions_per_online_step=1, 50 | max_episode_steps=100_000, 51 | actor_per=True, 52 | batch_size=1024, 53 | tau=0.005, 54 | beta=1.0, 55 | crr_function="binary", 56 | adv_method="mean", 57 | adv_method_n=4, 58 | actor_lr=1e-4, 59 | critic_lr=1e-4, 60 | gamma=0.99, 61 | eval_interval=5000, 62 | eval_episodes=10, 63 | warmup_steps=1000, 64 | actor_clip=None, 65 | critic_clip=None, 66 | actor_l2=0.0, 67 | critic_l2=0.0, 68 | target_delay=2, 69 | actor_delay=1, 70 | save_interval=100_000, 71 | name="awac_run", 72 | render=False, 73 | save_to_disk=True, 74 | log_to_disk=True, 75 | verbosity=0, 76 | infinite_bootstrap=True, 77 | **kwargs, 78 | ): 79 | 80 | if save_to_disk or log_to_disk: 81 | save_dir = utils.make_process_dirs(name) 82 | if log_to_disk: 83 | # create tb writer, save hparams 84 | writer = tensorboardX.SummaryWriter(save_dir) 85 | writer.add_hparams(locals(), {}) 86 | 87 | ########### 88 | ## SETUP ## 89 | ########### 90 | agent.to(device) 91 | agent.train() 92 | # initialize target networks 93 | target_agent = copy.deepcopy(agent) 94 | target_agent.to(device) 95 | utils.hard_update(target_agent.critic1, agent.critic1) 96 | utils.hard_update(target_agent.critic2, agent.critic2) 97 | target_agent.train() 98 | # set up optimizers 99 | critic_optimizer = torch.optim.Adam( 100 | chain( 101 | agent.critic1.parameters(), 102 | agent.critic2.parameters(), 103 | ), 104 | lr=critic_lr, 105 | weight_decay=critic_l2, 106 | betas=(0.9, 0.999), 107 | ) 108 | actor_optimizer = torch.optim.Adam( 109 | agent.actor.parameters(), 110 | lr=actor_lr, 111 | weight_decay=actor_l2, 112 | betas=(0.9, 0.999), 113 | ) 114 | # set up adv filter 115 | adv_estimator = AdvantageEstimator( 116 | agent.actor, [agent.critic1, agent.critic2], method=adv_method, n=adv_method_n 117 | ) 118 | adv_filter = AdvEstimatorFilter(adv_estimator, crr_function, beta=beta) 119 | 120 | ################### 121 | ## TRAINING LOOP ## 122 | ################### 123 | 124 | total_steps = num_steps_offline + num_steps_online 125 | steps_iter = 
range(total_steps) 126 | if verbosity: 127 | steps_iter = tqdm.tqdm(steps_iter) 128 | 129 | done = True 130 | for step in steps_iter: 131 | 132 | if step > num_steps_offline: 133 | # collect online experience 134 | for _ in range(transitions_per_online_step): 135 | if done: 136 | state = train_env.reset() 137 | steps_this_ep = 0 138 | done = False 139 | action = agent.sample_action(state) 140 | next_state, reward, done, info = train_env.step(action) 141 | if infinite_bootstrap: 142 | # allow infinite bootstrapping 143 | if steps_this_ep + 1 == max_episode_steps: 144 | done = False 145 | buffer.push(state, action, reward, next_state, done) 146 | state = next_state 147 | steps_this_ep += 1 148 | if steps_this_ep >= max_episode_steps: 149 | done = True 150 | 151 | for _ in range(gradient_updates_per_step): 152 | learn_awac( 153 | buffer=buffer, 154 | target_agent=target_agent, 155 | agent=agent, 156 | adv_filter=adv_filter, 157 | actor_optimizer=actor_optimizer, 158 | critic_optimizer=critic_optimizer, 159 | batch_size=batch_size, 160 | gamma=gamma, 161 | critic_clip=critic_clip, 162 | actor_clip=actor_clip, 163 | update_policy=step % actor_delay == 0, 164 | actor_per=actor_per, 165 | ) 166 | 167 | # move target model towards training model 168 | if step % target_delay == 0: 169 | utils.soft_update(target_agent.critic1, agent.critic1, tau) 170 | utils.soft_update(target_agent.critic2, agent.critic2, tau) 171 | 172 | if (step % eval_interval == 0) or (step == total_steps - 1): 173 | mean_return = run.evaluate_agent( 174 | agent, test_env, eval_episodes, max_episode_steps, render 175 | ) 176 | if log_to_disk: 177 | writer.add_scalar( 178 | "return", mean_return, step * transitions_per_online_step 179 | ) 180 | 181 | if step % save_interval == 0 and save_to_disk: 182 | agent.save(save_dir) 183 | 184 | if save_to_disk: 185 | agent.save(save_dir) 186 | 187 | return agent 188 | 189 | 190 | def learn_awac( 191 | buffer, 192 | target_agent, 193 | agent, 194 | adv_filter, 195 | actor_optimizer, 196 | critic_optimizer, 197 | batch_size, 198 | gamma, 199 | critic_clip, 200 | actor_clip, 201 | update_policy=True, 202 | actor_per=True, 203 | ): 204 | if actor_per: 205 | assert isinstance(buffer, replay.PrioritizedReplayBuffer) 206 | # sample with priorities for actor update 207 | actor_batch, *_ = buffer.sample(batch_size) 208 | # critic samples uniformly to find new high-adv experience 209 | critic_batch, priority_idxs = buffer.sample_uniform(batch_size) 210 | else: 211 | batch = buffer.sample(batch_size) 212 | actor_batch = batch 213 | critic_batch = batch 214 | 215 | agent.train() 216 | ################### 217 | ## CRITIC UPDATE ## 218 | ################### 219 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = critic_batch 220 | state_batch = state_batch.to(device) 221 | next_state_batch = next_state_batch.to(device) 222 | action_batch = action_batch.to(device) 223 | reward_batch = reward_batch.to(device) 224 | done_batch = done_batch.to(device) 225 | 226 | with torch.no_grad(): 227 | action_dist_s1 = agent.actor(next_state_batch) 228 | action_s1 = action_dist_s1.rsample() 229 | logp_a1 = action_dist_s1.log_prob(action_s1).sum(-1, keepdim=True) 230 | target_action_value_s1 = torch.min( 231 | target_agent.critic1(next_state_batch, action_s1), 232 | target_agent.critic2(next_state_batch, action_s1), 233 | ) 234 | td_target = reward_batch + gamma * (1.0 - done_batch) * target_action_value_s1 235 | 236 | # update critics 237 | agent_critic1_pred = agent.critic1(state_batch, 
action_batch) 238 | agent_critic2_pred = agent.critic2(state_batch, action_batch) 239 | td_error1 = td_target - agent_critic1_pred 240 | td_error2 = td_target - agent_critic2_pred 241 | critic_loss = 0.5 * (td_error1 ** 2 + td_error2 ** 2) 242 | critic_loss = critic_loss.mean() 243 | critic_optimizer.zero_grad() 244 | critic_loss.backward() 245 | if critic_clip: 246 | torch.nn.utils.clip_grad_norm_( 247 | chain(agent.critic1.parameters(), agent.critic2.parameters()), critic_clip 248 | ) 249 | critic_optimizer.step() 250 | 251 | if actor_per: 252 | with torch.no_grad(): 253 | adv = adv_filter.adv_estimator(state_batch, action_batch) 254 | new_priorities = (F.relu(adv) + 1e-5).cpu().detach().squeeze(1).numpy() 255 | buffer.update_priorities(priority_idxs, new_priorities) 256 | 257 | if update_policy: 258 | ################## 259 | ## ACTOR UPDATE ## 260 | ################## 261 | state_batch, *_ = actor_batch 262 | state_batch = state_batch.to(device) 263 | 264 | dist = agent.actor(state_batch) 265 | actions = dist.sample() 266 | logp_a = dist.log_prob(actions).sum(-1, keepdim=True) 267 | with torch.no_grad(): 268 | filtered_adv = adv_filter(state_batch, actions) 269 | actor_loss = -(logp_a * filtered_adv).mean() 270 | 271 | actor_optimizer.zero_grad() 272 | actor_loss.backward() 273 | if actor_clip: 274 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip) 275 | actor_optimizer.step() 276 | 277 | 278 | def add_args(parser): 279 | parser.add_argument( 280 | "--num_steps_offline", 281 | type=int, 282 | default=500_000, 283 | help="Number of steps of offline learning", 284 | ) 285 | parser.add_argument( 286 | "--num_steps_online", 287 | type=int, 288 | default=50_000, 289 | help="Number of steps of online learning", 290 | ) 291 | parser.add_argument( 292 | "--transitions_per_online_step", 293 | type=int, 294 | default=1, 295 | help="env transitions per training step. 
Defaults to 1, but will need to \ 296 | be set higher for repaly ratios < 1", 297 | ) 298 | parser.add_argument( 299 | "--max_episode_steps", 300 | type=int, 301 | default=100000, 302 | help="maximum steps per episode", 303 | ) 304 | parser.add_argument( 305 | "--batch_size", type=int, default=1024, help="training batch size" 306 | ) 307 | parser.add_argument( 308 | "--tau", type=float, default=0.005, help="for model parameter % update" 309 | ) 310 | parser.add_argument( 311 | "--actor_lr", type=float, default=1e-4, help="actor learning rate" 312 | ) 313 | parser.add_argument( 314 | "--critic_lr", type=float, default=1e-4, help="critic learning rate" 315 | ) 316 | parser.add_argument( 317 | "--gamma", type=float, default=0.99, help="gamma, the discount factor" 318 | ) 319 | parser.add_argument( 320 | "--buffer_size", type=int, default=1_000_000, help="replay buffer size" 321 | ) 322 | parser.add_argument( 323 | "--eval_interval", 324 | type=int, 325 | default=5000, 326 | help="how often to test the agent without exploration (in episodes)", 327 | ) 328 | parser.add_argument( 329 | "--eval_episodes", 330 | type=int, 331 | default=10, 332 | help="how many episodes to run for when testing", 333 | ) 334 | parser.add_argument( 335 | "--render", 336 | action="store_true", 337 | help="flag to enable env rendering during training", 338 | ) 339 | parser.add_argument( 340 | "--actor_clip", 341 | type=float, 342 | default=None, 343 | help="gradient clipping for actor updates", 344 | ) 345 | parser.add_argument( 346 | "--critic_clip", 347 | type=float, 348 | default=None, 349 | help="gradient clipping for critic updates", 350 | ) 351 | parser.add_argument( 352 | "--name", type=str, default="awac_run", help="dir name for saves" 353 | ) 354 | parser.add_argument( 355 | "--actor_l2", 356 | type=float, 357 | default=0.0, 358 | help="L2 regularization coeff for actor network", 359 | ) 360 | parser.add_argument( 361 | "--critic_l2", 362 | type=float, 363 | default=0.0, 364 | help="L2 regularization coeff for critic network", 365 | ) 366 | parser.add_argument( 367 | "--target_delay", 368 | type=int, 369 | default=2, 370 | help="How many steps to go between target network updates", 371 | ) 372 | parser.add_argument( 373 | "--actor_delay", 374 | type=int, 375 | default=1, 376 | help="How many steps to go between actor updates", 377 | ) 378 | parser.add_argument( 379 | "--save_interval", 380 | type=int, 381 | default=100_000, 382 | help="How many steps to go between saving the agent params to disk", 383 | ) 384 | parser.add_argument( 385 | "--verbosity", 386 | type=int, 387 | default=1, 388 | help="verbosity > 0 displays a progress bar during training", 389 | ) 390 | parser.add_argument( 391 | "--gradient_updates_per_step", 392 | type=int, 393 | default=1, 394 | help="how many gradient updates to make per training step", 395 | ) 396 | parser.add_argument( 397 | "--skip_save_to_disk", 398 | action="store_true", 399 | help="flag to skip saving agent params to disk during training", 400 | ) 401 | parser.add_argument( 402 | "--skip_log_to_disk", 403 | action="store_true", 404 | help="flag to skip saving agent performance logs to disk during training", 405 | ) 406 | parser.add_argument( 407 | "--log_std_low", 408 | type=float, 409 | default=-10, 410 | help="Lower bound for log std of action distribution.", 411 | ) 412 | parser.add_argument( 413 | "--log_std_high", 414 | type=float, 415 | default=2, 416 | help="Upper bound for log std of action distribution.", 417 | ) 418 | parser.add_argument( 419 | "--beta", 420 | 
type=float, 421 | default=1.0, 422 | help="Lambda variable from AWAC actor update and Beta from CRR", 423 | ) 424 | parser.add_argument( 425 | "--crr_function", 426 | type=str, 427 | default="binary", 428 | choices=["binary", "exp", "exp_norm"], 429 | help="Approach for adjusting advantage weights", 430 | ) 431 | parser.add_argument( 432 | "--adv_method", 433 | type=str, 434 | default="mean", 435 | help="Approach for estimating the advantage function. Choices include {'max', 'mean'}.", 436 | ) 437 | parser.add_argument( 438 | "--adv_method_n", 439 | type=int, 440 | default=4, 441 | help="How many actions to sample from the policy when estimating the advantage. CRR uses 4.", 442 | ) 443 | -------------------------------------------------------------------------------- /deep_control/critic_searchers.py: -------------------------------------------------------------------------------- 1 | import deep_control as dc 2 | 3 | import torch 4 | import numpy as np 5 | 6 | """ 7 | This is code is from https://github.com/stanford-iprl-lab/GRAC/blob/master/ES.py 8 | """ 9 | 10 | 11 | class _CEM: 12 | def __init__( 13 | self, 14 | num_params, 15 | mu_init=None, 16 | batch_size=256, 17 | sigma_init=1e-3, 18 | clip=0.5, 19 | pop_size=256, 20 | damp=1e-3, 21 | damp_limit=1e-5, 22 | parents=None, 23 | elitism=True, 24 | device=dc.device, 25 | ): 26 | 27 | # misc 28 | self.num_params = num_params 29 | self.batch_size = batch_size 30 | self.device = device 31 | # distribution parameters 32 | if mu_init is None: 33 | self.mu = torch.zeros([self.batch_size, self.num_params], device=device) 34 | else: 35 | self.mu = mu_init.clone() 36 | self.sigma = sigma_init 37 | self.damp = damp 38 | self.damp_limit = damp_limit 39 | self.tau = 0.95 40 | self.cov = self.sigma * torch.ones( 41 | [self.batch_size, self.num_params], device=device 42 | ) 43 | self.clip = clip 44 | 45 | # elite stuff 46 | self.elitism = elitism 47 | self.elite = torch.sqrt(torch.tensor(self.sigma, device=device)) * torch.rand( 48 | self.batch_size, self.num_params, device=device 49 | ) 50 | self.elite_score = None 51 | 52 | # sampling stuff 53 | self.pop_size = pop_size 54 | if parents is None or parents <= 0: 55 | self.parents = pop_size // 2 56 | else: 57 | self.parents = parents 58 | self.weights = torch.FloatTensor( 59 | [np.log((self.parents + 1) / i) for i in range(1, self.parents + 1)] 60 | ).to(device) 61 | self.weights /= self.weights.sum() 62 | 63 | def ask(self, pop_size): 64 | """ 65 | Returns a list of candidates parameters 66 | """ 67 | epsilon = torch.randn( 68 | self.batch_size, pop_size, self.num_params, device=self.device 69 | ) 70 | inds = self.mu.unsqueeze(1) + ( 71 | epsilon * torch.sqrt(self.cov).unsqueeze(1) 72 | ).clamp(-self.clip, self.clip) 73 | if self.elitism: 74 | inds[:, -1] = self.elite 75 | return inds 76 | 77 | def tell(self, solutions, scores): 78 | """ 79 | Updates the distribution 80 | returns the best solution 81 | """ 82 | scores = scores.clone().squeeze() 83 | scores *= -1 84 | if len(scores.shape) == 1: 85 | scores = scores[None, :] 86 | _, idx_sorted = torch.sort(scores, dim=1) 87 | 88 | old_mu = self.mu.clone() 89 | self.damp = self.damp * self.tau + (1 - self.tau) * self.damp_limit 90 | idx_sorted = idx_sorted[:, : self.parents] 91 | top_solutions = torch.gather( 92 | solutions, 93 | 1, 94 | idx_sorted.unsqueeze(2).expand(*idx_sorted.shape, solutions.shape[-1]), 95 | ) 96 | self.mu = self.weights @ top_solutions 97 | z = top_solutions - old_mu.unsqueeze(1) 98 | self.cov = 1 / self.parents * self.weights @ 
(z * z) + self.damp * torch.ones( 99 | [self.batch_size, self.num_params], device=self.device 100 | ) 101 | 102 | self.elite = top_solutions[:, 0, :] 103 | 104 | return top_solutions[:, 0, :] 105 | 106 | def get_distrib_params(self): 107 | """ 108 | Returns the parameters of the distrubtion: 109 | the mean and sigma 110 | """ 111 | return self.mu.clone(), self.cov.clone() 112 | 113 | 114 | class CEM: 115 | def __init__( 116 | self, 117 | action_dim, 118 | max_action, 119 | batch_size=256, 120 | sigma_init=1e-3, 121 | clip=0.5, 122 | pop_size=25, 123 | damp=0.1, 124 | damp_limit=0.05, 125 | parents=5, 126 | device=dc.device, 127 | ): 128 | self.sigma_init = sigma_init 129 | self.clip = clip 130 | self.pop_size = pop_size 131 | self.damp = damp 132 | self.damp_limit = damp_limit 133 | self.parents = parents 134 | self.action_dim = action_dim 135 | self.batch_size = batch_size 136 | self.max_action = max_action 137 | self.device = device 138 | 139 | def search( 140 | self, state, action_init, critic, batch_size=None, n_iter=2, action_bound=True 141 | ): 142 | if batch_size is None: 143 | batch_size = self.batch_size 144 | cem = _CEM( 145 | self.action_dim, 146 | action_init, 147 | batch_size, 148 | self.sigma_init, 149 | self.clip, 150 | self.pop_size, 151 | self.damp, 152 | self.damp_limit, 153 | self.parents, 154 | device=self.device, 155 | ) 156 | with torch.no_grad(): 157 | for iter in range(n_iter): 158 | actions = cem.ask(self.pop_size) 159 | if action_bound: 160 | actions = actions.clamp(-self.max_action, self.max_action) 161 | actions_temp = actions.clone().view(self.pop_size * batch_size, -1) 162 | Qs = critic( 163 | state.unsqueeze(1) 164 | .repeat(1, self.pop_size, 1) 165 | .view(self.pop_size * batch_size, -1), 166 | actions_temp, 167 | ).view(batch_size, self.pop_size) 168 | best_action = cem.tell(actions, Qs) 169 | if iter == n_iter - 1: 170 | best_Q = critic(state, best_action) 171 | ori_Q = critic(state, action_init) 172 | 173 | action_index = (best_Q < ori_Q).squeeze() 174 | best_action[action_index] = action_init[action_index] 175 | 176 | return best_action 177 | -------------------------------------------------------------------------------- /deep_control/ddpg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | 5 | import numpy as np 6 | import tensorboardX 7 | import torch 8 | import torch.nn.functional as F 9 | import tqdm 10 | 11 | from . 
import envs, nets, replay, run, utils 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | class DDPGAgent: 17 | def __init__( 18 | self, 19 | obs_space_size, 20 | action_space_size, 21 | actor_net_cls=nets.BaselineActor, 22 | critic_net_cls=nets.BaselineCritic, 23 | hidden_size=256, 24 | ): 25 | self.actor = actor_net_cls( 26 | obs_space_size, action_space_size, hidden_size=hidden_size 27 | ) 28 | self.critic = critic_net_cls( 29 | obs_space_size, action_space_size, hidden_size=hidden_size 30 | ) 31 | 32 | def to(self, device): 33 | self.actor = self.actor.to(device) 34 | self.critic = self.critic.to(device) 35 | 36 | def eval(self): 37 | self.actor.eval() 38 | self.critic.eval() 39 | 40 | def train(self): 41 | self.actor.train() 42 | self.critic.train() 43 | 44 | def save(self, path): 45 | actor_path = os.path.join(path, "actor.pt") 46 | critic_path = os.path.join(path, "critic.pt") 47 | torch.save(self.actor.state_dict(), actor_path) 48 | torch.save(self.critic.state_dict(), critic_path) 49 | 50 | def load(self, path): 51 | actor_path = os.path.join(path, "actor.pt") 52 | critic_path = os.path.join(path, "critic.pt") 53 | self.actor.load_state_dict(torch.load(actor_path)) 54 | self.critic.load_state_dict(torch.load(critic_path)) 55 | 56 | def forward(self, state): 57 | state = self.process_state(state) 58 | self.actor.eval() 59 | with torch.no_grad(): 60 | action = self.actor(state) 61 | self.actor.train() 62 | return np.squeeze(action.cpu().numpy(), 0) 63 | 64 | def process_state(self, state): 65 | return torch.from_numpy(np.expand_dims(state, 0).astype(np.float32)).to( 66 | utils.device 67 | ) 68 | 69 | 70 | def ddpg( 71 | agent, 72 | train_env, 73 | test_env, 74 | buffer, 75 | num_steps=1_000_000, 76 | transitions_per_step=1, 77 | max_episode_steps=100_000, 78 | batch_size=256, 79 | tau=0.005, 80 | actor_lr=1e-4, 81 | critic_lr=1e-3, 82 | gamma=0.99, 83 | sigma_start=0.2, 84 | sigma_final=0.1, 85 | sigma_anneal=100_000, 86 | theta=0.15, 87 | eval_interval=5000, 88 | eval_episodes=10, 89 | warmup_steps=1000, 90 | render=False, 91 | actor_clip=None, 92 | critic_clip=None, 93 | name="ddpg_run", 94 | actor_l2=0.0, 95 | critic_l2=0.0, 96 | save_interval=100_000, 97 | log_to_disk=True, 98 | save_to_disk=True, 99 | verbosity=0, 100 | gradient_updates_per_step=1, 101 | infinite_bootstrap=True, 102 | **_, 103 | ): 104 | """ 105 | Train `agent` on `train_env` with the Deep Deterministic Policy Gradient algorithm, 106 | and evaluate on `test_env`. 107 | 108 | Reference: https://arxiv.org/abs/1509.02971 109 | """ 110 | if save_to_disk or log_to_disk: 111 | # create save directory for this run 112 | save_dir = utils.make_process_dirs(name) 113 | if log_to_disk: 114 | # create tb writer, save hparams 115 | writer = tensorboardX.SummaryWriter(save_dir) 116 | writer.add_hparams(locals(), {}) 117 | 118 | agent.to(device) 119 | 120 | # initialize target networks 121 | target_agent = copy.deepcopy(agent) 122 | target_agent.to(device) 123 | utils.hard_update(target_agent.actor, agent.actor) 124 | utils.hard_update(target_agent.critic, agent.critic) 125 | 126 | # Ornstein-Uhlenbeck is a controlled random walk used 127 | # to introduce noise for exploration. The DDPG paper 128 | # picks it over the simpler gaussian noise alternative, 129 | # but later work has shown this is an unnecessary detail. 
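    # (For reference, and assuming the usual discrete-time form of the process,
    # the noise evolves roughly as x_next = x + theta * (0 - x) + sigma * noise,
    # where sigma is annealed from sigma_start to sigma_final over sigma_anneal
    # steps so that exploration shrinks as training progresses.)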
130 | random_process = utils.OrnsteinUhlenbeckProcess( 131 | theta=theta, 132 | size=train_env.action_space.shape, 133 | sigma=sigma_start, 134 | sigma_min=sigma_final, 135 | n_steps_annealing=sigma_anneal, 136 | ) 137 | 138 | critic_optimizer = torch.optim.Adam( 139 | agent.critic.parameters(), lr=critic_lr, weight_decay=critic_l2 140 | ) 141 | actor_optimizer = torch.optim.Adam( 142 | agent.actor.parameters(), lr=actor_lr, weight_decay=actor_l2 143 | ) 144 | 145 | # the replay buffer is filled with a few thousand transitions by 146 | # sampling from a uniform random policy, so that learning can begin 147 | # from a buffer that is >> the batch size. 148 | run.warmup_buffer(buffer, train_env, warmup_steps, max_episode_steps) 149 | 150 | done = True 151 | 152 | steps_iter = range(num_steps) 153 | if verbosity: 154 | # fancy progress bar 155 | steps_iter = tqdm.tqdm(steps_iter) 156 | 157 | for step in steps_iter: 158 | for _ in range(transitions_per_step): 159 | # collect experience from the environment, sampling from 160 | # the current policy (with added noise for exploration) 161 | if done: 162 | # reset the environment 163 | state = train_env.reset() 164 | random_process.reset_states() 165 | steps_this_ep = 0 166 | done = False 167 | action = agent.forward(state) 168 | noisy_action = run.exploration_noise(action, random_process) 169 | next_state, reward, done, info = train_env.step(noisy_action) 170 | if infinite_bootstrap: 171 | # allow infinite bootstrapping. Many envs terminate 172 | # (done = True) after an arbitrary number of steps 173 | # to let the agent reset and avoid getting stuck in 174 | # a failed position. infinite bootstrapping prevents 175 | # this from impacting our Q function calculation. This 176 | # can be harmful in edge cases where the environment really 177 | # would have ended (task failed) regardless of the step limit, 178 | # and makes no difference if the environment is not set up 179 | # to enforce a limit by itself (but many common benchmarks are). 180 | if steps_this_ep + 1 == max_episode_steps: 181 | done = False 182 | # add this transition to the replay buffer 183 | buffer.push(state, noisy_action, reward, next_state, done) 184 | state = next_state 185 | steps_this_ep += 1 186 | if steps_this_ep >= max_episode_steps: 187 | # enforce max step limit from the agent's perspective 188 | done = True 189 | 190 | for _ in range(gradient_updates_per_step): 191 | # update the actor and critics using the replay buffer 192 | learn( 193 | buffer=buffer, 194 | target_agent=target_agent, 195 | agent=agent, 196 | actor_optimizer=actor_optimizer, 197 | critic_optimizer=critic_optimizer, 198 | batch_size=batch_size, 199 | gamma=gamma, 200 | critic_clip=critic_clip, 201 | actor_clip=actor_clip, 202 | ) 203 | 204 | # move target models towards the online models 205 | # CC algorithms typically use a moving average rather 206 | # than the full copy of a DQN. 
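        # soft_update is assumed to perform the usual polyak average on each
        # parameter: target <- tau * online + (1 - tau) * target, so the default
        # tau = 0.005 keeps the targets trailing the online networks slowly.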
207 | utils.soft_update(target_agent.actor, agent.actor, tau) 208 | utils.soft_update(target_agent.critic, agent.critic, tau) 209 | 210 | if step % eval_interval == 0 or step == num_steps - 1: 211 | mean_return = run.evaluate_agent( 212 | agent, test_env, eval_episodes, max_episode_steps, render 213 | ) 214 | if log_to_disk: 215 | writer.add_scalar("return", mean_return, step * transitions_per_step) 216 | if step % save_interval == 0 and save_to_disk: 217 | agent.save(save_dir) 218 | 219 | if save_to_disk: 220 | agent.save(save_dir) 221 | return agent 222 | 223 | 224 | def learn( 225 | buffer, 226 | target_agent, 227 | agent, 228 | actor_optimizer, 229 | critic_optimizer, 230 | batch_size, 231 | gamma, 232 | critic_clip, 233 | actor_clip, 234 | ): 235 | """ 236 | DDPG inner optimization loop. The simplest deep 237 | actor critic update. 238 | """ 239 | # support for prioritized experience replay is 240 | # included in almost every algorithm in this repo. however, 241 | # it is somewhat rarely used in recent work because of its 242 | # extra hyperparameters and implementation complexity. 243 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 244 | if per: 245 | batch, imp_weights, priority_idxs = buffer.sample(batch_size) 246 | imp_weights = imp_weights.to(device) 247 | else: 248 | batch = buffer.sample(batch_size) 249 | 250 | # send transitions to the gpu 251 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = batch 252 | state_batch = state_batch.to(device) 253 | next_state_batch = next_state_batch.to(device) 254 | action_batch = action_batch.to(device) 255 | reward_batch = reward_batch.to(device) 256 | done_batch = done_batch.to(device) 257 | 258 | ################### 259 | ## Critic Update ## 260 | ################### 261 | 262 | # compute target values 263 | with torch.no_grad(): 264 | target_action_s1 = target_agent.actor(next_state_batch) 265 | target_action_value_s1 = target_agent.critic(next_state_batch, target_action_s1) 266 | # bootstrapped estimate of Q(s, a) based on reward and target network 267 | td_target = reward_batch + gamma * (1.0 - done_batch) * target_action_value_s1 268 | 269 | # compute mean squared bellman error (MSE(Q(s, a), td_target)) 270 | agent_critic_pred = agent.critic(state_batch, action_batch) 271 | td_error = td_target - agent_critic_pred 272 | if per: 273 | critic_loss = (imp_weights * 0.5 * (td_error ** 2)).mean() 274 | else: 275 | critic_loss = 0.5 * (td_error ** 2).mean() 276 | critic_optimizer.zero_grad() 277 | # gradient descent step on critic network 278 | critic_loss.backward() 279 | if critic_clip: 280 | torch.nn.utils.clip_grad_norm_(agent.critic.parameters(), critic_clip) 281 | critic_optimizer.step() 282 | 283 | ################## 284 | ## Actor Update ## 285 | ################## 286 | 287 | # actor's objective is to maximize (or minimize the negative of) 288 | # the expectation of the critic's opinion of its action choices 289 | agent_actions = agent.actor(state_batch) 290 | actor_loss = -agent.critic(state_batch, agent_actions).mean() 291 | actor_optimizer.zero_grad() 292 | # gradient descent step on actor network 293 | actor_loss.backward() 294 | if actor_clip: 295 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip) 296 | actor_optimizer.step() 297 | 298 | if per: 299 | # update prioritized replay distribution 300 | new_priorities = (abs(td_error) + 1e-5).cpu().detach().squeeze(1).numpy() 301 | buffer.update_priorities(priority_idxs, new_priorities) 302 | 303 | 304 | def add_args(parser): 305 
| parser.add_argument( 306 | "--num_steps", type=int, default=1000000, help="number of training steps" 307 | ) 308 | parser.add_argument( 309 | "--transitions_per_step", 310 | type=int, 311 | default=1, 312 | help="number of env steps per training step", 313 | ) 314 | parser.add_argument( 315 | "--max_episode_steps", 316 | type=int, 317 | default=100000, 318 | help="maximum steps per episode", 319 | ) 320 | parser.add_argument( 321 | "--batch_size", type=int, default=256, help="training batch size" 322 | ) 323 | parser.add_argument( 324 | "--tau", 325 | type=float, 326 | default=0.005, 327 | help="controls the speed that the target networks converge to the online networks", 328 | ) 329 | parser.add_argument( 330 | "--actor_lr", type=float, default=1e-4, help="actor network learning rate" 331 | ) 332 | parser.add_argument( 333 | "--critic_lr", type=float, default=1e-3, help="critic network learning rate" 334 | ) 335 | parser.add_argument( 336 | "--gamma", 337 | type=float, 338 | default=0.99, 339 | help="gamma, the MDP discount factor that determines emphasis on long-term rewards", 340 | ) 341 | parser.add_argument( 342 | "--sigma_final", 343 | type=float, 344 | default=0.1, 345 | help="final sigma value for Ornstein Uhlenbeck exploration process", 346 | ) 347 | parser.add_argument( 348 | "--sigma_anneal", 349 | type=float, 350 | default=100_000, 351 | help="How many steps to anneal sigma over.", 352 | ) 353 | parser.add_argument( 354 | "--sigma_start", 355 | type=float, 356 | default=0.2, 357 | help="sigma for Ornstein Uhlenbeck exploration process", 358 | ) 359 | parser.add_argument( 360 | "--theta", 361 | type=float, 362 | default=0.15, 363 | help="theta for Ornstein Uhlenbeck exploration process", 364 | ) 365 | parser.add_argument( 366 | "--eval_interval", 367 | type=int, 368 | default=5000, 369 | help="how often to test the agent without exploration (in steps)", 370 | ) 371 | parser.add_argument( 372 | "--eval_episodes", 373 | type=int, 374 | default=10, 375 | help="how many episodes to run for when testing. results are averaged over this many episodes", 376 | ) 377 | parser.add_argument( 378 | "--warmup_steps", 379 | type=int, 380 | default=1000, 381 | help="how many random steps to take before learning begins", 382 | ) 383 | parser.add_argument( 384 | "--render", 385 | action="store_true", 386 | help="render the environment during training. can slow training significantly", 387 | ) 388 | parser.add_argument( 389 | "--actor_clip", 390 | type=float, 391 | default=None, 392 | help="clip actor gradients based on this norm. less commonly used in actor critic algs than DQN", 393 | ) 394 | parser.add_argument( 395 | "--critic_clip", 396 | type=float, 397 | default=None, 398 | help="clip critic gradients based on this norm. less commonly used in actor critic algs than DQN", 399 | ) 400 | parser.add_argument( 401 | "--name", 402 | type=str, 403 | default="ddpg_run", 404 | help="we will save the results of this training run in a directory called dc_saves/{this name}", 405 | ) 406 | parser.add_argument( 407 | "--actor_l2", 408 | type=float, 409 | default=0.0, 410 | help="actor network L2 regularization coeff. Typically not helpful in single-environment settings", 411 | ) 412 | parser.add_argument( 413 | "--critic_l2", 414 | type=float, 415 | default=0.0, 416 | help="critic network L2 regularization coeff. 
Typically not helpful in single-environment settings", 417 | ) 418 | parser.add_argument( 419 | "--save_interval", 420 | type=int, 421 | default=100_000, 422 | help="how often (in terms of steps) to save the network weights to disk", 423 | ) 424 | parser.add_argument( 425 | "--verbosity", 426 | type=int, 427 | default=1, 428 | help="set to 0 for quiet mode (limit printing to std out). 1 shows a progress bar", 429 | ) 430 | parser.add_argument( 431 | "--skip_save_to_disk", 432 | action="store_true", 433 | help="do not save the agent weights to disk during this training run", 434 | ) 435 | parser.add_argument( 436 | "--skip_log_to_disk", 437 | action="store_true", 438 | help="do not write results to tensorboard during this training run", 439 | ) 440 | parser.add_argument( 441 | "--gradient_updates_per_step", 442 | type=int, 443 | default=1, 444 | help="learning updates per training step (aka replay ratio denominator)", 445 | ) 446 | parser.add_argument( 447 | "--prioritized_replay", 448 | action="store_true", 449 | help="Flag to enable prioritized experience replay", 450 | ) 451 | parser.add_argument( 452 | "--buffer_size", 453 | type=int, 454 | default=1_000_000, 455 | help="Maximum size of the replay buffer before oldest transitions are overwritten. Note that the default deep_control buffer allocates all of this space at the start of training to fail fast when there won't be enough space.", 456 | ) 457 | -------------------------------------------------------------------------------- /deep_control/envs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from collections import deque 4 | import math 5 | 6 | import gym 7 | import numpy as np 8 | 9 | 10 | class PersistenceAwareWrapper(gym.Wrapper): 11 | def __init__(self, env, k=1, return_history=True, discount=1.0): 12 | super().__init__(env) 13 | self.k = k 14 | self.return_history = return_history 15 | ob_shape = (self.observation_space.shape[0] + 1,) 16 | ob_low = np.concatenate(([1], self.observation_space.low)) 17 | ob_high = np.concatenate(([float("inf")], self.observation_space.high)) 18 | self.observation_space = gym.spaces.Box( 19 | low=ob_low, high=ob_high, shape=ob_shape 20 | ) 21 | self.discount = discount 22 | 23 | def step(self, action): 24 | reward_history = np.zeros((self.k,)) 25 | done = False 26 | for step in range(self.k): 27 | if not done: 28 | next_state, reward, done, _ = self.env.step(action) 29 | reward_history[step] = reward * self.discount ** step 30 | else: 31 | reward_history[step] = 0 32 | if self.return_history: 33 | return self.obs(next_state), reward_history, done, {} 34 | else: 35 | return self.obs(next_state), reward_history.sum(), done, {} 36 | 37 | def obs(self, state): 38 | return np.concatenate([np.array((self.k,)), state], axis=0) 39 | 40 | def reset(self): 41 | return self.obs(self.env.reset()) 42 | 43 | def set_k(self, k): 44 | assert k > 0, "attempted to set action repeat <= 0" 45 | self.k = k 46 | 47 | 48 | class ActionRepeatOutputWrapper(gym.Wrapper): 49 | def __init__(self, env, repeat_multiplier=8): 50 | super().__init__(env) 51 | self.action_space = gym.spaces.Box( 52 | -1.0, 1.0, shape=(1 + self.env.action_space.shape[0],) 53 | ) 54 | self.repeat_multiplier = repeat_multiplier / 2.0 55 | 56 | def step(self, action): 57 | repeat_action = max(math.floor((action[0] + 1.0) * self.repeat_multiplier), 1) 58 | main_action = action[1:] 59 | total_reward = 0 60 | for _ in range(repeat_action): 61 | next_state, reward, done, _ = 
self.env.step(main_action) 62 | total_reward += reward 63 | return next_state, total_reward, done, {} 64 | 65 | 66 | class ChannelsFirstWrapper(gym.ObservationWrapper): 67 | """ 68 | Some pixel-based gym environments use a (Height, Width, Channel) image format. 69 | This wrapper rolls those axes to (Channel, Height, Width) to work with pytorch 70 | Conv2D layers. 71 | """ 72 | 73 | def __init__(self, env): 74 | super().__init__(env) 75 | self.observation_space.shape = ( 76 | env.observation_space.shape[-1], 77 | ) + env.observation_space.shape[:-1] 78 | 79 | def observation(self, frame): 80 | frame = np.transpose(frame, (2, 0, 1)) 81 | return np.ascontiguousarray(frame) 82 | 83 | 84 | class NormalizeObservationSpace(gym.ObservationWrapper): 85 | def __init__(self, env, obs_mean, obs_std): 86 | super().__init__(env) 87 | self.mean = obs_mean 88 | self.std = obs_std + 1e-5 89 | 90 | def observation(self, x): 91 | return (x - self.mean) / self.std 92 | 93 | 94 | class NormalizeContinuousActionSpace(gym.ActionWrapper): 95 | def __init__(self, env): 96 | super().__init__(env) 97 | self._true_action_space = env.action_space 98 | self.action_space = gym.spaces.Box( 99 | low=-1.0, 100 | high=1.0, 101 | shape=self._true_action_space.shape, 102 | dtype=np.float32, 103 | ) 104 | 105 | def action(self, action): 106 | true_delta = self._true_action_space.high - self._true_action_space.low 107 | norm_delta = self.action_space.high - self.action_space.low 108 | action = (action - self.action_space.low) / norm_delta 109 | action = action * true_delta + self._true_action_space.low 110 | return action 111 | 112 | 113 | def robosuite_action_adjustment(robosuite_env, verbose=False): 114 | if verbose: 115 | action_space = robosuite_env.action_space 116 | high = action_space.high 117 | same_high = np.all(high == high[0]) 118 | low = action_space.low 119 | same_low = np.all(low == low[0]) 120 | shape = action_space.shape[0] 121 | print("RoboSuite Action Space Report:") 122 | if same_high and same_low: 123 | print(f"Uniformly Bounded Action Space in [{low[0]}, {high[0]}]^{shape}") 124 | else: 125 | print(f"Non-uniform Bounded Action Space with elements = {zip(low, high)}") 126 | print("\nAttempting to normalize action space using dc.envs.Normalize...\n") 127 | env = NormalizeContinuousActionSpace(robosuite_env) 128 | if verbose: 129 | action_space = env.action_space 130 | high = action_space.high 131 | same_high = np.all(high == high[0]) 132 | low = action_space.low 133 | same_low = np.all(low == low[0]) 134 | shape = action_space.shape[0] 135 | print("Normalized RoboSuite Action Space Report:") 136 | if same_high and same_low: 137 | print(f"Uniformly Bounded Action Space in [{low[0]}, {high[0]}]^{shape}") 138 | else: 139 | print(f"Non-uniform Bounded Action Space with elements = {zip(low, high)}") 140 | return env 141 | 142 | 143 | class FlattenObsWrapper(gym.ObservationWrapper): 144 | """ 145 | Simple wrapper that flattens an image observation 146 | into a state vector when CNNs are overkill. 
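    For example, a hypothetical (84, 84, 3) image observation space becomes a
    flat (84 * 84 * 3,) = (21168,) vector space, matching obs.flatten() below.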
147 | """ 148 | 149 | def __init__(self, env): 150 | super().__init__(env) 151 | self.observation_space.shape = (np.prod(env.observation_space.shape),) 152 | 153 | def observation(self, obs): 154 | return obs.flatten() 155 | 156 | 157 | class ConcatObsWrapper(gym.ObservationWrapper): 158 | def __init__(self, env): 159 | super().__init__(env) 160 | obs_space_shape = sum(x.shape[0] for x in self.observation_space) 161 | self.observation_space.shape = (obs_space_shape,) 162 | 163 | def observation(self, obs): 164 | return np.concatenate(obs, axis=0) 165 | 166 | 167 | def highway_env(env_id): 168 | """ 169 | Convenience function to turn all the highway_env 170 | environments into continuous control tasks. 171 | 172 | highway_env: https://highway-env.readthedocs.io/en/latest/index.html 173 | """ 174 | import gym 175 | import highway_env 176 | 177 | env = gym.make(env_id) 178 | env.configure({"action": {"type": "ContinuousAction"}}) 179 | env.reset() 180 | env = NormalizeContinuousActionSpace(env) 181 | env = FlattenObsWrapper(env) 182 | return env 183 | 184 | 185 | class DiscreteActionWrapper(gym.ActionWrapper): 186 | """ 187 | This is intended to let the action be any scalar 188 | (float or int) or np array (float or int) of size 1. 189 | 190 | floats are cast to ints using python's standard rounding. 191 | """ 192 | 193 | def __init__(self, env): 194 | super().__init__(env) 195 | self.action_space.shape = (env.action_space.n,) 196 | 197 | def action(self, action): 198 | if isinstance(action, np.ndarray): 199 | if len(action.shape) > 0: 200 | action = action[0] 201 | return int(action) 202 | 203 | 204 | class FrameStack(gym.Wrapper): 205 | def __init__(self, env, num_stack): 206 | gym.Wrapper.__init__(self, env) 207 | self._k = num_stack 208 | self._frames = deque([], maxlen=num_stack) 209 | shp = env.observation_space.shape 210 | self.observation_space = gym.spaces.Box( 211 | low=0, 212 | high=1, 213 | shape=((shp[0] * num_stack,) + shp[1:]), 214 | dtype=env.observation_space.dtype, 215 | ) 216 | 217 | def reset(self): 218 | obs = self.env.reset() 219 | for _ in range(self._k): 220 | self._frames.append(obs) 221 | return self._get_obs() 222 | 223 | def step(self, action): 224 | obs, reward, done, info = self.env.step(action) 225 | self._frames.append(obs) 226 | return self._get_obs(), reward, done, info 227 | 228 | def _get_obs(self): 229 | assert len(self._frames) == self._k 230 | return np.concatenate(list(self._frames), axis=0) 231 | 232 | 233 | class GoalBasedWrapper(gym.ObservationWrapper): 234 | """ 235 | Some goal-based envs (like the Gym Robotics suite) use dictionary observations 236 | with one entry for the current state and another to describe the goal. This 237 | wrapper concatenates those into a single vector so it can be used just like 238 | any other env. 
239 | """ 240 | 241 | def __init__(self, env): 242 | super().__init__(env) 243 | self.observation_space.shape = ( 244 | env.observation_space["observation"].shape[0] 245 | + env.observation_space["desired_goal"].shape[0], 246 | ) 247 | 248 | def observation(self, obs_dict): 249 | return self._flatten_obs(obs_dict) 250 | 251 | def _flatten_obs(self, obs_dict): 252 | return np.concatenate((obs_dict["observation"], obs_dict["desired_goal"])) 253 | 254 | 255 | def add_gym_args(parser): 256 | """ 257 | Add a --env_id cl flag to an argparser 258 | """ 259 | parser.add_argument("--env_id", type=str, default="Pendulum-v0") 260 | parser.add_argument("--seed", type=int, default=123) 261 | 262 | 263 | def load_gym(env_id="CartPole-v1", seed=None, normalize_action_space=True, **_): 264 | """ 265 | Load an environment from OpenAI gym (or pybullet_gym, if installed) 266 | """ 267 | # optional pybullet import 268 | try: 269 | import pybullet 270 | import pybullet_envs 271 | except ImportError: 272 | pass 273 | env = gym.make(env_id) 274 | if normalize_action_space and isinstance(env.action_space, gym.spaces.Box): 275 | env = NormalizeContinuousActionSpace(env) 276 | if seed is None: 277 | seed = random.randint(1, 100000) 278 | env.seed(seed) 279 | return env 280 | 281 | 282 | def add_dmc_args(parser): 283 | """ 284 | Add cl flags associated with the deepmind control suite to a parser 285 | """ 286 | parser.add_argument("--domain_name", type=str, default="fish") 287 | parser.add_argument("--task_name", type=str, default="swim") 288 | parser.add_argument( 289 | "--from_pixels", action="store_true", help="Use image observations" 290 | ) 291 | parser.add_argument("--height", type=int, default=84) 292 | parser.add_argument("--width", type=int, default=84) 293 | parser.add_argument("--camera_id", type=int, default=0) 294 | parser.add_argument("--frame_skip", type=int, default=1) 295 | parser.add_argument("--frame_stack", type=int, default=3) 296 | parser.add_argument("--channels_last", action="store_true") 297 | parser.add_argument("--rgb", action="store_true") 298 | parser.add_argument("--seed", type=int, default=231) 299 | 300 | 301 | def add_atari_args(parser): 302 | parser.add_argument("--game_id", type=str, default="Boxing-v0") 303 | parser.add_argument("--noop_max", type=int, default=30) 304 | parser.add_argument("--frame_skip", type=int, default=1) 305 | parser.add_argument("--screen_size", type=int, default=84) 306 | parser.add_argument("--terminal_on_life_loss", action="store_true") 307 | parser.add_argument("--rgb", action="store_true") 308 | parser.add_argument("--normalize", action="store_true") 309 | parser.add_argument("--frame_stack", type=int, default=4) 310 | parser.add_argument("--seed", type=int, default=231) 311 | 312 | 313 | def load_atari( 314 | game_id, 315 | seed=None, 316 | noop_max=30, 317 | frame_skip=1, 318 | screen_size=84, 319 | terminal_on_life_loss=False, 320 | rgb=False, 321 | normalize=False, 322 | frame_stack=4, 323 | clip_reward=True, 324 | **_, 325 | ): 326 | """ 327 | Load a game from the Atari benchmark, with the usual settings 328 | 329 | Note that the simplest game ids (e.g. Boxing-v0) come with frame 330 | skipping by default, and you'll get an error if the frame_skp arg > 1. 331 | Use `BoxingNoFrameskip-v0` with frame_skip > 1. 
332 | """ 333 | env = gym.make(game_id) 334 | if seed is None: 335 | seed = random.randint(1, 100000) 336 | env.seed(seed) 337 | env = gym.wrappers.AtariPreprocessing( 338 | env, 339 | noop_max=noop_max, 340 | frame_skip=frame_skip, 341 | screen_size=screen_size, 342 | terminal_on_life_loss=terminal_on_life_loss, 343 | grayscale_obs=False, # use GrayScale wrapper instead... 344 | scale_obs=normalize, 345 | ) 346 | if not rgb: 347 | env = gym.wrappers.GrayScaleObservation(env, keep_dim=True) 348 | if clip_reward: 349 | env = ClipReward(env) 350 | env = ChannelsFirstWrapper(env) 351 | env = FrameStack(env, num_stack=frame_stack) 352 | env = DiscreteActionWrapper(env) 353 | return env 354 | 355 | 356 | class ClipReward(gym.RewardWrapper): 357 | def __init__(self, env, low=-1.0, high=1.0): 358 | super().__init__(env) 359 | self._clip_low = low 360 | self._clip_high = high 361 | 362 | def reward(self, rew): 363 | return max(min(rew, self._clip_high), self._clip_low) 364 | 365 | 366 | class ScaleReward(gym.RewardWrapper): 367 | def __init__(self, env, scale=1.0): 368 | super().__init__(env) 369 | self.scale = scale 370 | 371 | def reward(self, rew): 372 | return self.scale * rew 373 | 374 | 375 | class DeltaReward(gym.RewardWrapper): 376 | def __init__(self, env): 377 | super().__init__(env) 378 | self._old_rew = 0 379 | 380 | def reward(self, rew): 381 | delta_rew = rew - self._old_rew 382 | self._old_rew = rew 383 | return delta_rew 384 | 385 | 386 | def load_dmc( 387 | domain_name, 388 | task_name, 389 | seed=None, 390 | from_pixels=False, 391 | frame_stack=1, 392 | height=84, 393 | width=84, 394 | camera_id=0, 395 | frame_skip=1, 396 | channels_last=False, 397 | rgb=False, 398 | **_, 399 | ): 400 | """ 401 | Load a task from the deepmind control suite. 402 | 403 | Uses dmc2gym (https://github.com/denisyarats/dmc2gym) 404 | 405 | Note that setting seed=None (the default) picks a random seed 406 | """ 407 | import dmc2gym 408 | 409 | if seed is None: 410 | seed = random.randint(1, 100000) 411 | env = dmc2gym.make( 412 | domain_name=domain_name, 413 | task_name=task_name, 414 | from_pixels=from_pixels, 415 | height=height, 416 | width=width, 417 | camera_id=camera_id, 418 | visualize_reward=False, 419 | frame_skip=frame_skip, 420 | channels_first=not channels_last 421 | if rgb 422 | else False, # if we're using RGB, set the channel order here 423 | ) 424 | if not rgb and from_pixels: 425 | env = gym.wrappers.GrayScaleObservation(env, keep_dim=True) 426 | env = ChannelsFirstWrapper(env) 427 | if from_pixels: 428 | env = FrameStack(env, num_stack=frame_stack) 429 | return env 430 | -------------------------------------------------------------------------------- /deep_control/grac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import os 5 | from itertools import chain 6 | 7 | import numpy as np 8 | import tensorboardX 9 | import torch 10 | import torch.nn.functional as F 11 | import tqdm 12 | 13 | from . 
import envs, nets, replay, run, utils, critic_searchers, sac 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | class GRACAgent(sac.SACAgent): 19 | def __init__( 20 | self, 21 | obs_space_size, 22 | act_space_size, 23 | log_std_low, 24 | log_std_high, 25 | actor_net_cls=nets.StochasticActor, 26 | critic_net_cls=nets.BigCritic, 27 | hidden_size=1024, 28 | ): 29 | super().__init__( 30 | obs_space_size=obs_space_size, 31 | act_space_size=act_space_size, 32 | log_std_low=log_std_low, 33 | log_std_high=log_std_high, 34 | actor_net_cls=actor_net_cls, 35 | hidden_size=hidden_size, 36 | ) 37 | self.cem = critic_searchers.CEM(act_space_size, max_action=1.0) 38 | 39 | 40 | def grac( 41 | agent, 42 | buffer, 43 | train_env, 44 | test_env, 45 | num_steps=1_000_000, 46 | transitions_per_step=1, 47 | max_critic_updates_per_step=20, 48 | critic_target_improvement_init=0.7, 49 | critic_target_improvement_final=0.9, 50 | gamma=0.99, 51 | batch_size=512, 52 | actor_lr=1e-4, 53 | critic_lr=1e-4, 54 | eval_interval=5000, 55 | eval_episodes=10, 56 | warmup_steps=1000, 57 | actor_clip=None, 58 | critic_clip=None, 59 | name="grac_run", 60 | max_episode_steps=100_000, 61 | render=False, 62 | save_interval=100_000, 63 | verbosity=0, 64 | critic_l2=0.0, 65 | actor_l2=0.0, 66 | log_to_disk=True, 67 | save_to_disk=True, 68 | infinite_bootstrap=True, 69 | **kwargs, 70 | ): 71 | """ 72 | Train `agent` on `train_env` using GRAC, and evaluate on `test_env`. 73 | 74 | GRAC: Self-Guided and Self-Regularized Actor-Critic (https://sites.google.com/view/gracdrl) 75 | 76 | GRAC is a combination of a stochastic policy with 77 | TD3-like stability improvements and CEM-based action selection 78 | like you'd find in Qt-Opt or CAQL. 79 | 80 | This is a pretty faithful reimplementation of the authors' version 81 | (https://github.com/stanford-iprl-lab/GRAC/blob/master/GRAC.py), with 82 | a couple differences: 83 | 84 | 1) The default agent architecture and batch size are chosen for a fair 85 | comparison with popular SAC settings (meaning they are larger). The 86 | agent's action distribution is also implemented in a way that is more 87 | like SAC (outputting log_std of a tanh squashed normal distribution) 88 | 2) The agent never collects experience with actions selected by CEM. 89 | """ 90 | if save_to_disk or log_to_disk: 91 | save_dir = utils.make_process_dirs(name) 92 | if log_to_disk: 93 | # create tb writer, save hparams 94 | writer = tensorboardX.SummaryWriter(save_dir) 95 | writer.add_hparams(locals(), {}) 96 | 97 | # no target networks! 
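# (unlike TD3/SAC there are no slow-moving target copies; the critic update below instead regularizes Q(s', a') toward its own pre-update predictions, y_1 and y_2)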
98 | agent.to(device) 99 | agent.cem.batch_size = batch_size 100 | agent.train() 101 | 102 | # the critic target improvement ratio is annealed during training 103 | critic_target_imp_slope = ( 104 | critic_target_improvement_final - critic_target_improvement_init 105 | ) / num_steps 106 | current_target_imp = lambda step: min( 107 | critic_target_improvement_init + critic_target_imp_slope * step, 108 | critic_target_improvement_final, 109 | ) 110 | 111 | # set up optimizers 112 | critic_optimizer = torch.optim.Adam( 113 | chain( 114 | agent.critic1.parameters(), 115 | agent.critic2.parameters(), 116 | ), 117 | lr=critic_lr, 118 | weight_decay=critic_l2, 119 | betas=(0.9, 0.999), 120 | ) 121 | actor_optimizer = torch.optim.Adam( 122 | agent.actor.parameters(), 123 | lr=actor_lr, 124 | weight_decay=actor_l2, 125 | betas=(0.9, 0.999), 126 | ) 127 | 128 | # warmup the replay buffer with random actions 129 | run.warmup_buffer(buffer, train_env, warmup_steps, max_episode_steps) 130 | 131 | steps_iter = range(num_steps) 132 | if verbosity: 133 | steps_iter = tqdm.tqdm(steps_iter) 134 | 135 | done = True 136 | for step in steps_iter: 137 | # collect experience 138 | for _ in range(transitions_per_step): 139 | if done: 140 | state = train_env.reset() 141 | steps_this_ep = 0 142 | done = False 143 | action = agent.sample_action(state) 144 | next_state, reward, done, info = train_env.step(action) 145 | if infinite_bootstrap: 146 | # allow infinite bootstrapping 147 | if steps_this_ep + 1 == max_episode_steps: 148 | done = False 149 | buffer.push(state, action, reward, next_state, done) 150 | state = next_state 151 | steps_this_ep += 1 152 | if steps_this_ep >= max_episode_steps: 153 | done = True 154 | 155 | learn( 156 | buffer=buffer, 157 | agent=agent, 158 | actor_optimizer=actor_optimizer, 159 | critic_optimizer=critic_optimizer, 160 | critic_target_improvement=current_target_imp(step), 161 | max_critic_updates_per_step=max_critic_updates_per_step, 162 | batch_size=batch_size, 163 | gamma=gamma, 164 | critic_clip=critic_clip, 165 | actor_clip=actor_clip, 166 | ) 167 | 168 | if step % eval_interval == 0 or step == num_steps - 1: 169 | mean_return = run.evaluate_agent( 170 | agent, test_env, eval_episodes, max_episode_steps, render 171 | ) 172 | if log_to_disk: 173 | writer.add_scalar("return", mean_return, step * transitions_per_step) 174 | 175 | if step % save_interval == 0 and save_to_disk: 176 | agent.save(save_dir) 177 | 178 | if save_to_disk: 179 | agent.save(save_dir) 180 | return agent 181 | 182 | 183 | def learn( 184 | buffer, 185 | agent, 186 | actor_optimizer, 187 | critic_optimizer, 188 | critic_target_improvement, 189 | max_critic_updates_per_step, 190 | batch_size, 191 | gamma, 192 | critic_clip, 193 | actor_clip, 194 | ): 195 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 196 | if per: 197 | batch, imp_weights, priority_idxs = buffer.sample(batch_size) 198 | imp_weights = imp_weights.to(device) 199 | else: 200 | batch = buffer.sample(batch_size) 201 | 202 | # prepare transitions for models 203 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = batch 204 | state_batch = state_batch.to(device) 205 | next_state_batch = next_state_batch.to(device) 206 | action_batch = action_batch.to(device) 207 | reward_batch = reward_batch.to(device) 208 | done_batch = done_batch.to(device) 209 | 210 | agent.train() 211 | 212 | ################### 213 | ## CRITIC UPDATE ## 214 | ################### 215 | 216 | with torch.no_grad(): 217 | # sample an action as normal 
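# (the next block builds GRAC's "max-min" TD target: take the min over both critics for the sampled action and for the CEM-searched action, then the max of those two values)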
218 | action_dist_s1 = agent.actor(next_state_batch) 219 | action_s1 = action_dist_s1.sample().clamp(-1.0, 1.0) 220 | # use CEM to find a higher value action 221 | cem_action_s1 = agent.cem.search(next_state_batch, action_s1, agent.critic2) 222 | 223 | # clipped double q learning using both the agent and CEM actions 224 | clip_double_q_a1 = torch.min( 225 | agent.critic1(next_state_batch, action_s1), 226 | agent.critic2(next_state_batch, action_s1), 227 | ) 228 | 229 | clip_double_q_cema1 = torch.min( 230 | agent.critic1(next_state_batch, cem_action_s1), 231 | agent.critic2(next_state_batch, cem_action_s1), 232 | ) 233 | 234 | # best_action_s1 = argmax_a(clip_double_q_a1, clip_double_q_cema1) 235 | better_action_mask = (clip_double_q_cema1 >= clip_double_q_a1).squeeze(1) 236 | best_action_s1 = action_s1.clone() 237 | best_action_s1[better_action_mask] = cem_action_s1[better_action_mask] 238 | 239 | # critic opinions of best actions that were found 240 | y_1 = agent.critic1(next_state_batch, best_action_s1) 241 | y_2 = agent.critic2(next_state_batch, best_action_s1) 242 | 243 | # "max min double q learning" 244 | max_min_s1_value = torch.max(clip_double_q_a1, clip_double_q_cema1) 245 | td_target = reward_batch + gamma * (1.0 - done_batch) * max_min_s1_value 246 | 247 | # update critics 248 | critic_loss_initial = None 249 | for critic_update in range(max_critic_updates_per_step): 250 | # standard bellman error 251 | a_critic1_pred = agent.critic1(state_batch, action_batch) 252 | a_critic2_pred = agent.critic2(state_batch, action_batch) 253 | td_error1 = td_target - a_critic1_pred 254 | td_error2 = td_target - a_critic2_pred 255 | 256 | # constraints that discourage large changes in Q(s_{t+1}, a_{t+1}) 257 | a1_critic1_pred = agent.critic1(next_state_batch, best_action_s1) 258 | a1_critic2_pred = agent.critic2(next_state_batch, best_action_s1) 259 | a1_constraint1 = y_1 - a1_critic1_pred 260 | a1_constraint2 = y_2 - a1_critic2_pred 261 | 262 | elementwise_critic_loss = ( 263 | (td_error1 ** 2) 264 | + (td_error2 ** 2) 265 | + (a1_constraint1 ** 2) 266 | + (a1_constraint2 ** 2) 267 | ) 268 | if per: 269 | elementwise_critic_loss *= imp_weights 270 | critic_loss = 0.5 * elementwise_critic_loss.mean() 271 | critic_optimizer.zero_grad() 272 | critic_loss.backward() 273 | if critic_clip: 274 | torch.nn.utils.clip_grad_norm_( 275 | chain(agent.critic1.parameters(), agent.critic2.parameters()), 276 | critic_clip, 277 | ) 278 | critic_optimizer.step() 279 | if critic_update == 0: 280 | critic_loss_initial = critic_loss 281 | elif critic_loss <= critic_target_improvement * critic_loss_initial: 282 | break 283 | 284 | ################## 285 | ## ACTOR UPDATE ## 286 | ################## 287 | 288 | # get agent's actions in this state 289 | dist = agent.actor(state_batch) 290 | agent_actions = dist.rsample().clamp(-1.0, 1.0) 291 | agent_action_value = agent.critic1(state_batch, agent_actions) 292 | 293 | # find higher-value actions with CEM 294 | cem_actions = agent.cem.search(state_batch, agent_actions, agent.critic1) 295 | logp_cema = dist.log_prob(cem_actions).sum(-1, keepdim=True) 296 | cem_action_value = agent.critic1(state_batch, cem_actions) 297 | 298 | # how much better are the CEM actions than the agent's? 299 | # clipped for rare cases where CEM actually finds a worse action...
300 | cem_advantage = F.relu(cem_action_value - agent_action_value).detach() 301 | # cem_adv_coeff = 1 / |A| ; best guess here is that this is meant 302 | # to balance the \sum_{i}_{log\pi(cem_action)_{i}}, which can get large 303 | # early in training when CEM tends to find very unlikely actions 304 | cem_adv_coeff = 1.0 / agent_actions.shape[1] 305 | 306 | actor_loss = -( 307 | agent_action_value + cem_adv_coeff * cem_advantage * logp_cema 308 | ).mean() 309 | actor_optimizer.zero_grad() 310 | actor_loss.backward() 311 | if actor_clip: 312 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip) 313 | actor_optimizer.step() 314 | 315 | if per: 316 | new_priorities = (abs(td_error1) + 1e-5).cpu().detach().squeeze(1).numpy() 317 | buffer.update_priorities(priority_idxs, new_priorities) 318 | 319 | 320 | def add_args(parser): 321 | parser.add_argument( 322 | "--num_steps", type=int, default=10 ** 6, help="number of steps in training" 323 | ) 324 | parser.add_argument( 325 | "--transitions_per_step", 326 | type=int, 327 | default=1, 328 | help="env transitions per training step. Defaults to 1, but will need to \ 329 | be set higher for repaly ratios < 1", 330 | ) 331 | parser.add_argument( 332 | "--max_episode_steps", 333 | type=int, 334 | default=100000, 335 | help="maximum steps per episode", 336 | ) 337 | parser.add_argument( 338 | "--batch_size", type=int, default=512, help="training batch size" 339 | ) 340 | parser.add_argument( 341 | "--actor_lr", type=float, default=1e-4, help="actor learning rate" 342 | ) 343 | parser.add_argument( 344 | "--critic_lr", type=float, default=1e-4, help="critic learning rate" 345 | ) 346 | parser.add_argument( 347 | "--gamma", type=float, default=0.99, help="gamma, the discount factor" 348 | ) 349 | parser.add_argument( 350 | "--buffer_size", type=int, default=1_000_000, help="replay buffer size" 351 | ) 352 | parser.add_argument( 353 | "--eval_interval", 354 | type=int, 355 | default=5000, 356 | help="how often to test the agent without exploration (in episodes)", 357 | ) 358 | parser.add_argument( 359 | "--eval_episodes", 360 | type=int, 361 | default=10, 362 | help="how many episodes to run for when testing", 363 | ) 364 | parser.add_argument( 365 | "--warmup_steps", type=int, default=1000, help="warmup length, in steps" 366 | ) 367 | parser.add_argument( 368 | "--render", 369 | action="store_true", 370 | help="flag to enable env rendering during training", 371 | ) 372 | parser.add_argument( 373 | "--actor_clip", 374 | type=float, 375 | default=None, 376 | help="gradient clipping for actor updates", 377 | ) 378 | parser.add_argument( 379 | "--critic_clip", 380 | type=float, 381 | default=None, 382 | help="gradient clipping for critic updates", 383 | ) 384 | parser.add_argument( 385 | "--name", type=str, default="grac_run", help="dir name for saves" 386 | ) 387 | parser.add_argument( 388 | "--actor_l2", 389 | type=float, 390 | default=0.0, 391 | help="L2 regularization coeff for actor network", 392 | ) 393 | parser.add_argument( 394 | "--critic_l2", 395 | type=float, 396 | default=0.0, 397 | help="L2 regularization coeff for critic network", 398 | ) 399 | parser.add_argument( 400 | "--save_interval", 401 | type=int, 402 | default=100_000, 403 | help="How many steps to go between saving the agent params to disk", 404 | ) 405 | parser.add_argument( 406 | "--verbosity", 407 | type=int, 408 | default=1, 409 | help="verbosity > 0 displays a progress bar during training", 410 | ) 411 | parser.add_argument( 412 | "--max_critic_updates_per_step", 
413 | type=int, 414 | default=10, 415 | help="Max critic updates to make per step. The GRAC paper calls this K", 416 | ) 417 | parser.add_argument( 418 | "--prioritized_replay", 419 | action="store_true", 420 | help="flag that enables use of prioritized experience replay", 421 | ) 422 | parser.add_argument( 423 | "--skip_save_to_disk", 424 | action="store_true", 425 | help="flag to skip saving agent params to disk during training", 426 | ) 427 | parser.add_argument( 428 | "--skip_log_to_disk", 429 | action="store_true", 430 | help="flag to skip saving agent performance logs to disk during training", 431 | ) 432 | parser.add_argument( 433 | "--log_std_low", 434 | type=float, 435 | default=-10, 436 | help="Lower bound for log std of action distribution.", 437 | ) 438 | parser.add_argument( 439 | "--log_std_high", 440 | type=float, 441 | default=2, 442 | help="Upper bound for log std of action distribution.", 443 | ) 444 | parser.add_argument( 445 | "--critic_target_improvement_init", 446 | type=float, 447 | default=0.7, 448 | help="Stop critic updates when loss drops by this factor. The GRAC paper calls this alpha", 449 | ) 450 | parser.add_argument( 451 | "--critic_target_improvement_final", 452 | type=float, 453 | default=0.9, 454 | help="Stop critic updates when loss drops by this factor. The GRAC paper calls this alpha", 455 | ) 456 | -------------------------------------------------------------------------------- /deep_control/nets.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import distributions as pyd 7 | from torch import nn 8 | 9 | from . import utils 10 | 11 | 12 | def weight_init(m): 13 | if isinstance(m, nn.Linear): 14 | nn.init.orthogonal_(m.weight.data) 15 | m.bias.data.fill_(0.0) 16 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 17 | # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf 18 | assert m.weight.size(2) == m.weight.size(3) 19 | m.weight.data.fill_(0.0) 20 | m.bias.data.fill_(0.0) 21 | mid = m.weight.size(2) // 2 22 | gain = nn.init.calculate_gain("relu") 23 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) 24 | 25 | 26 | class BigPixelEncoder(nn.Module): 27 | def __init__(self, obs_shape, out_dim=50): 28 | super().__init__() 29 | channels = obs_shape[0] 30 | self.conv1 = nn.Conv2d(channels, 32, kernel_size=3, stride=2) 31 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1) 32 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1) 33 | self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1) 34 | 35 | output_height, output_width = utils.compute_conv_output( 36 | obs_shape[1:], kernel_size=(3, 3), stride=(2, 2) 37 | ) 38 | for _ in range(3): 39 | output_height, output_width = utils.compute_conv_output( 40 | (output_height, output_width), kernel_size=(3, 3), stride=(1, 1) 41 | ) 42 | 43 | self.fc = nn.Linear(output_height * output_width * 32, out_dim) 44 | self.ln = nn.LayerNorm(out_dim) 45 | self.apply(weight_init) 46 | 47 | def forward(self, obs): 48 | obs /= 255.0 49 | x = F.relu(self.conv1(obs)) 50 | x = F.relu(self.conv2(x)) 51 | x = F.relu(self.conv3(x)) 52 | x = F.relu(self.conv4(x)) 53 | x = x.view(x.size(0), -1) 54 | x = self.fc(x) 55 | x = self.ln(x) 56 | state = torch.tanh(x) 57 | return state 58 | 59 | 60 | class SmallPixelEncoder(nn.Module): 61 | def __init__(self, obs_shape, out_dim=50): 62 | super().__init__() 63 | channels = obs_shape[0] 64 | self.conv1 
= nn.Conv2d(channels, 32, kernel_size=8, stride=4) 65 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 66 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 67 | 68 | output_height, output_width = utils.compute_conv_output( 69 | obs_shape[1:], kernel_size=(8, 8), stride=(4, 4) 70 | ) 71 | 72 | output_height, output_width = utils.compute_conv_output( 73 | (output_height, output_width), kernel_size=(4, 4), stride=(2, 2) 74 | ) 75 | 76 | output_height, output_width = utils.compute_conv_output( 77 | (output_height, output_width), kernel_size=(3, 3), stride=(1, 1) 78 | ) 79 | 80 | self.fc = nn.Linear(output_height * output_width * 64, out_dim) 81 | self.apply(weight_init) 82 | 83 | def forward(self, obs): 84 | obs /= 255.0 85 | x = F.relu(self.conv1(obs)) 86 | x = F.relu(self.conv2(x)) 87 | x = F.relu(self.conv3(x)) 88 | x = x.view(x.size(0), -1) 89 | state = self.fc(x) 90 | return state 91 | 92 | 93 | class StochasticActor(nn.Module): 94 | def __init__( 95 | self, 96 | state_space_size, 97 | act_space_size, 98 | log_std_low=-10.0, 99 | log_std_high=2.0, 100 | hidden_size=1024, 101 | dist_impl="pyd", 102 | ): 103 | super().__init__() 104 | assert dist_impl in ["pyd", "beta"] 105 | self.fc1 = nn.Linear(state_space_size, hidden_size) 106 | self.fc2 = nn.Linear(hidden_size, hidden_size) 107 | self.fc3 = nn.Linear(hidden_size, 2 * act_space_size) 108 | self.log_std_low = log_std_low 109 | self.log_std_high = log_std_high 110 | self.apply(weight_init) 111 | self.dist_impl = dist_impl 112 | 113 | def forward(self, state): 114 | x = F.relu(self.fc1(state)) 115 | x = F.relu(self.fc2(x)) 116 | out = self.fc3(x) 117 | mu, log_std = out.chunk(2, dim=1) 118 | if self.dist_impl == "pyd": 119 | log_std = torch.tanh(log_std) 120 | log_std = self.log_std_low + 0.5 * ( 121 | self.log_std_high - self.log_std_low 122 | ) * (log_std + 1) 123 | std = log_std.exp() 124 | dist = SquashedNormal(mu, std) 125 | elif self.dist_impl == "beta": 126 | out = 1.0 + F.softplus(out) 127 | alpha, beta = out.chunk(2, dim=1) 128 | dist = BetaDist(alpha, beta) 129 | return dist 130 | 131 | 132 | class BigCritic(nn.Module): 133 | def __init__(self, state_space_size, act_space_size, hidden_size=1024): 134 | super().__init__() 135 | self.fc1 = nn.Linear(state_space_size + act_space_size, hidden_size) 136 | self.fc2 = nn.Linear(hidden_size, hidden_size) 137 | self.fc3 = nn.Linear(hidden_size, 1) 138 | 139 | self.apply(weight_init) 140 | 141 | def forward(self, state, action): 142 | x = F.relu(self.fc1(torch.cat((state, action), dim=1))) 143 | x = F.relu(self.fc2(x)) 144 | out = self.fc3(x) 145 | return out 146 | 147 | 148 | class BaselineActor(nn.Module): 149 | def __init__(self, state_size, action_size, hidden_size=400): 150 | super().__init__() 151 | self.fc1 = nn.Linear(state_size, hidden_size) 152 | self.fc2 = nn.Linear(hidden_size, hidden_size) 153 | self.out = nn.Linear(hidden_size, action_size) 154 | 155 | def forward(self, state): 156 | x = F.relu(self.fc1(state)) 157 | x = F.relu(self.fc2(x)) 158 | act = torch.tanh(self.out(x)) 159 | return act 160 | 161 | 162 | class BaselineCritic(nn.Module): 163 | def __init__(self, state_size, action_size, hidden_size=400): 164 | super().__init__() 165 | self.fc1 = nn.Linear(state_size + action_size, hidden_size) 166 | self.fc2 = nn.Linear(hidden_size, hidden_size) 167 | self.out = nn.Linear(hidden_size, 1) 168 | 169 | def forward(self, state, action): 170 | x = torch.cat((state, action), dim=1) 171 | x = F.relu(self.fc1(x)) 172 | x = F.relu(self.fc2(x)) 173 | val = 
self.out(x) 174 | return val 175 | 176 | 177 | class BetaDist(pyd.transformed_distribution.TransformedDistribution): 178 | class _BetaDistTransform(pyd.transforms.Transform): 179 | domain = pyd.constraints.real 180 | codomain = pyd.constraints.interval(-1.0, 1.0) 181 | 182 | def __init__(self, cache_size=1): 183 | super().__init__(cache_size=cache_size) 184 | 185 | def __eq__(self, other): 186 | return isinstance(other, _BetaDistTransform) 187 | 188 | def _inverse(self, y): 189 | return (y.clamp(-0.99, 0.99) + 1.0) / 2.0 190 | 191 | def _call(self, x): 192 | return (2.0 * x) - 1.0 193 | 194 | def log_abs_det_jacobian(self, x, y): 195 | # return log det jacobian |dy/dx| given input and output 196 | return torch.Tensor([math.log(2.0)]).to(x.device) 197 | 198 | def __init__(self, alpha, beta): 199 | self.base_dist = pyd.beta.Beta(alpha, beta) 200 | transforms = [self._BetaDistTransform()] 201 | super().__init__(self.base_dist, transforms) 202 | 203 | @property 204 | def mean(self): 205 | mu = self.base_dist.mean 206 | for tr in self.transforms: 207 | mu = tr(mu) 208 | return mu 209 | 210 | 211 | """ 212 | Credit for actor distribution code: https://github.com/denisyarats/pytorch_sac/blob/master/agent/actor.py 213 | """ 214 | 215 | 216 | class TanhTransform(pyd.transforms.Transform): 217 | domain = pyd.constraints.real 218 | codomain = pyd.constraints.interval(-1.0, 1.0) 219 | bijective = True 220 | sign = +1 221 | 222 | def __init__(self, cache_size=1): 223 | super().__init__(cache_size=cache_size) 224 | 225 | @staticmethod 226 | def atanh(x): 227 | return 0.5 * (x.log1p() - (-x).log1p()) 228 | 229 | def __eq__(self, other): 230 | return isinstance(other, TanhTransform) 231 | 232 | def _call(self, x): 233 | return x.tanh() 234 | 235 | def _inverse(self, y): 236 | return self.atanh(y.clamp(-0.99, 0.99)) 237 | 238 | def log_abs_det_jacobian(self, x, y): 239 | return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x)) 240 | 241 | 242 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 243 | def __init__(self, loc, scale): 244 | self.loc = loc 245 | self.scale = scale 246 | 247 | self.base_dist = pyd.Normal(loc, scale) 248 | transforms = [TanhTransform()] 249 | super().__init__(self.base_dist, transforms) 250 | 251 | @property 252 | def mean(self): 253 | mu = self.loc 254 | for tr in self.transforms: 255 | mu = tr(mu) 256 | return mu 257 | 258 | 259 | class GracBaselineActor(nn.Module): 260 | def __init__(self, obs_size, action_size): 261 | super().__init__() 262 | self.fc1 = nn.Linear(obs_size, 400) 263 | self.fc2 = nn.Linear(400, 300) 264 | self.fc_mean = nn.Linear(300, action_size) 265 | self.fc_std = nn.Linear(300, action_size) 266 | 267 | def forward(self, state, stochastic=False): 268 | x = F.relu(self.fc1(state)) 269 | x = F.relu(self.fc2(x)) 270 | mean = torch.tanh(self.fc_mean(x)) 271 | std = F.softplus(self.fc_std(x)) + 1e-3 272 | dist = pyd.Normal(mean, std) 273 | return dist 274 | -------------------------------------------------------------------------------- /deep_control/redq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import os 5 | from itertools import chain 6 | import random 7 | 8 | import numpy as np 9 | import tensorboardX 10 | import torch 11 | import torch.nn.functional as F 12 | import tqdm 13 | 14 | from . 
import envs, nets, replay, run, utils, device 15 | 16 | 17 | class REDQAgent: 18 | def __init__( 19 | self, 20 | obs_space_size, 21 | act_space_size, 22 | log_std_low, 23 | log_std_high, 24 | critic_ensemble_size=10, 25 | actor_network_cls=nets.StochasticActor, 26 | critic_network_cls=nets.BigCritic, 27 | hidden_size=1024, 28 | ): 29 | self.actor = actor_network_cls( 30 | obs_space_size, 31 | act_space_size, 32 | log_std_low, 33 | log_std_high, 34 | dist_impl="pyd", 35 | hidden_size=hidden_size, 36 | ) 37 | self.critics = [ 38 | critic_network_cls(obs_space_size, act_space_size, hidden_size=hidden_size) 39 | for _ in range(critic_ensemble_size) 40 | ] 41 | 42 | def to(self, device): 43 | self.actor = self.actor.to(device) 44 | for i, critic in enumerate(self.critics): 45 | self.critics[i] = critic.to(device) 46 | 47 | def eval(self): 48 | self.actor.eval() 49 | for critic in self.critics: 50 | critic.eval() 51 | 52 | def train(self): 53 | self.actor.train() 54 | for critic in self.critics: 55 | critic.train() 56 | 57 | def save(self, path): 58 | actor_path = os.path.join(path, "actor.pt") 59 | torch.save(self.actor.state_dict(), actor_path) 60 | for i, critic in enumerate(self.critics): 61 | critic_path = os.path.join(path, f"critic{i}.pt") 62 | torch.save(critic.state_dict(), critic_path) 63 | 64 | def load(self, path): 65 | actor_path = os.path.join(path, "actor.pt") 66 | self.actor.load_state_dict(torch.load(actor_path)) 67 | for i, critic in enumerate(self.critics): 68 | critic_path = os.path.join(path, f"critic{i}.pt") 69 | critic.load_state_dict(torch.load(critic_path)) 70 | 71 | def forward(self, state, from_cpu=True): 72 | if from_cpu: 73 | state = self.process_state(state) 74 | self.actor.eval() 75 | with torch.no_grad(): 76 | act_dist = self.actor.forward(state) 77 | act = act_dist.mean 78 | self.actor.train() 79 | if from_cpu: 80 | act = self.process_act(act) 81 | return act 82 | 83 | def sample_action(self, state, from_cpu=True): 84 | if from_cpu: 85 | state = self.process_state(state) 86 | self.actor.eval() 87 | with torch.no_grad(): 88 | act_dist = self.actor.forward(state) 89 | act = act_dist.sample() 90 | self.actor.train() 91 | if from_cpu: 92 | act = self.process_act(act) 93 | return act 94 | 95 | def process_state(self, state): 96 | return torch.from_numpy(np.expand_dims(state, 0).astype(np.float32)).to( 97 | utils.device 98 | ) 99 | 100 | def process_act(self, act): 101 | return np.squeeze(act.clamp(-1.0, 1.0).cpu().numpy(), 0) 102 | 103 | 104 | def redq( 105 | agent, 106 | buffer, 107 | train_env, 108 | test_env, 109 | num_steps=1_000_000, 110 | transitions_per_step=1, 111 | max_episode_steps=100_000, 112 | batch_size=512, 113 | tau=0.005, 114 | actor_lr=3e-4, 115 | critic_lr=3e-4, 116 | alpha_lr=1e-4, 117 | gamma=0.99, 118 | eval_interval=5000, 119 | eval_episodes=10, 120 | warmup_steps=1000, 121 | actor_clip=None, 122 | critic_clip=None, 123 | actor_l2=0.0, 124 | critic_l2=0.0, 125 | target_delay=2, 126 | save_interval=100_000, 127 | name="redq_run", 128 | render=False, 129 | save_to_disk=True, 130 | log_to_disk=True, 131 | verbosity=0, 132 | actor_updates_per_step=1, 133 | critic_updates_per_step=20, 134 | random_ensemble_size=2, 135 | init_alpha=0.1, 136 | infinite_bootstrap=True, 137 | **kwargs, 138 | ): 139 | """ 140 | "Randomized Ensembled Dobule Q-Learning: Learning Fast Without a Model", Chen et al., 2020 141 | 142 | REDQ is an extension of the clipped double Q learning trick. 
To create the 143 | target value, we sample M critic networks from an ensemble of size N. This 144 | reduces the overestimation bias of the critics, and also allows us to use 145 | much higher replay ratios (actor_updates_per_step or critic_updates_per_step 146 | >> transitions_per_step). This makes REDQ very sample efficient, but really 147 | hurts wall clock time relative to SAC/TD3. REDQ's sample efficiency makes 148 | MBPO a more fair comparison, in which case it would be considered fast. 149 | REDQ can be applied to just about any actor-critic algorithm; we implement 150 | it on SAC here. 151 | """ 152 | assert len(agent.critics) >= random_ensemble_size 153 | 154 | if save_to_disk or log_to_disk: 155 | save_dir = utils.make_process_dirs(name) 156 | if log_to_disk: 157 | # create tb writer, save hparams 158 | writer = tensorboardX.SummaryWriter(save_dir) 159 | writer.add_hparams(locals(), {}) 160 | 161 | ########### 162 | ## SETUP ## 163 | ########### 164 | agent.to(device) 165 | agent.train() 166 | # initialize target networks 167 | target_agent = copy.deepcopy(agent) 168 | target_agent.to(device) 169 | for target_critic, agent_critic in zip(target_agent.critics, agent.critics): 170 | utils.hard_update(target_critic, agent_critic) 171 | target_agent.train() 172 | 173 | # set up optimizers 174 | critic_optimizer = torch.optim.Adam( 175 | chain(*(critic.parameters() for critic in agent.critics)), 176 | lr=critic_lr, 177 | weight_decay=critic_l2, 178 | betas=(0.9, 0.999), 179 | ) 180 | actor_optimizer = torch.optim.Adam( 181 | agent.actor.parameters(), 182 | lr=actor_lr, 183 | weight_decay=actor_l2, 184 | betas=(0.9, 0.999), 185 | ) 186 | log_alpha = torch.Tensor([math.log(init_alpha)]).to(device) 187 | log_alpha.requires_grad = True 188 | log_alpha_optimizer = torch.optim.Adam([log_alpha], lr=alpha_lr, betas=(0.5, 0.999)) 189 | target_entropy = -train_env.action_space.shape[0] 190 | 191 | ################### 192 | ## TRAINING LOOP ## 193 | ################### 194 | # warmup the replay buffer with random actions 195 | run.warmup_buffer(buffer, train_env, warmup_steps, max_episode_steps) 196 | done = True 197 | steps_iter = range(num_steps) 198 | if verbosity: 199 | steps_iter = tqdm.tqdm(steps_iter) 200 | for step in steps_iter: 201 | # collect experience 202 | for _ in range(transitions_per_step): 203 | if done: 204 | state = train_env.reset() 205 | steps_this_ep = 0 206 | done = False 207 | action = agent.sample_action(state) 208 | next_state, reward, done, info = train_env.step(action) 209 | if infinite_bootstrap and steps_this_ep + 1 == max_episode_steps: 210 | # allow infinite bootstrapping 211 | done = False 212 | buffer.push(state, action, reward, next_state, done) 213 | state = next_state 214 | steps_this_ep += 1 215 | if steps_this_ep >= max_episode_steps: 216 | done = True 217 | 218 | # critic update 219 | for _ in range(critic_updates_per_step): 220 | learn_critics( 221 | buffer=buffer, 222 | target_agent=target_agent, 223 | agent=agent, 224 | critic_optimizer=critic_optimizer, 225 | log_alpha=log_alpha, 226 | batch_size=batch_size, 227 | gamma=gamma, 228 | critic_clip=critic_clip, 229 | random_ensemble_size=random_ensemble_size, 230 | ) 231 | 232 | # actor update 233 | for _ in range(actor_updates_per_step): 234 | learn_actor( 235 | buffer=buffer, 236 | agent=agent, 237 | actor_optimizer=actor_optimizer, 238 | log_alpha_optimizer=log_alpha_optimizer, 239 | target_entropy=target_entropy, 240 | batch_size=batch_size, 241 | log_alpha=log_alpha, 242 | gamma=gamma, 243 | 
actor_clip=actor_clip, 244 | ) 245 | 246 | # move target model towards training model 247 | if step % target_delay == 0: 248 | for (target_critic, agent_critic) in zip( 249 | target_agent.critics, agent.critics 250 | ): 251 | utils.soft_update(target_critic, agent_critic, tau) 252 | 253 | if (step % eval_interval == 0) or (step == num_steps - 1): 254 | mean_return = run.evaluate_agent( 255 | agent, test_env, eval_episodes, max_episode_steps, render 256 | ) 257 | if log_to_disk: 258 | writer.add_scalar("return", mean_return, step * transitions_per_step) 259 | 260 | if step % save_interval == 0 and save_to_disk: 261 | agent.save(save_dir) 262 | 263 | if save_to_disk: 264 | agent.save(save_dir) 265 | return agent 266 | 267 | 268 | def learn_critics( 269 | buffer, 270 | target_agent, 271 | agent, 272 | critic_optimizer, 273 | batch_size, 274 | log_alpha, 275 | gamma, 276 | critic_clip, 277 | random_ensemble_size, 278 | ): 279 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 280 | if per: 281 | batch, imp_weights, priority_idxs = buffer.sample(batch_size) 282 | imp_weights = imp_weights.to(device) 283 | else: 284 | batch = buffer.sample(batch_size) 285 | 286 | # prepare transitions for models 287 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = batch 288 | state_batch = state_batch.to(device) 289 | next_state_batch = next_state_batch.to(device) 290 | action_batch = action_batch.to(device) 291 | reward_batch = reward_batch.to(device) 292 | done_batch = done_batch.to(device) 293 | 294 | agent.train() 295 | 296 | ################### 297 | ## CRITIC UPDATE ## 298 | ################### 299 | alpha = torch.exp(log_alpha) 300 | with torch.no_grad(): 301 | action_dist_s1 = agent.actor(next_state_batch) 302 | action_s1 = action_dist_s1.rsample() 303 | logp_a1 = action_dist_s1.log_prob(action_s1).sum(-1, keepdim=True) 304 | 305 | target_critic_ensemble = random.sample( 306 | target_agent.critics, random_ensemble_size 307 | ) 308 | target_critic_ensemble_preds = ( 309 | critic(next_state_batch, action_s1) for critic in target_critic_ensemble 310 | ) 311 | target_action_value_s1 = torch.min(*target_critic_ensemble_preds) 312 | td_target = reward_batch + gamma * (1.0 - done_batch) * ( 313 | target_action_value_s1 - (alpha * logp_a1) 314 | ) 315 | 316 | # update critics 317 | critic_loss = 0.0 318 | for i, critic in enumerate(agent.critics): 319 | agent_critic_pred = critic(state_batch, action_batch) 320 | td_error = td_target - agent_critic_pred 321 | critic_loss += 0.5 * (td_error ** 2) 322 | if per: 323 | critic_loss *= imp_weights 324 | critic_loss = critic_loss.mean() 325 | critic_optimizer.zero_grad() 326 | critic_loss.backward() 327 | if critic_clip: 328 | torch.nn.utils.clip_grad_norm_( 329 | chain(*(critic.parameters() for critic in agent.critics)), 330 | critic_clip, 331 | ) 332 | critic_optimizer.step() 333 | 334 | if per: 335 | # just using td error of the last critic here, although an average is probably better 336 | new_priorities = (abs(td_error) + 1e-5).cpu().detach().squeeze(1).numpy() 337 | buffer.update_priorities(priority_idxs, new_priorities) 338 | 339 | 340 | def learn_actor( 341 | buffer, 342 | agent, 343 | actor_optimizer, 344 | log_alpha_optimizer, 345 | target_entropy, 346 | batch_size, 347 | log_alpha, 348 | gamma, 349 | actor_clip, 350 | ): 351 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 352 | if per: 353 | batch, *_ = buffer.sample(batch_size) 354 | # (importance weights aren't used in the actor update) 355 | else: 356 | batch =
buffer.sample(batch_size) 357 | 358 | # prepare transitions for models 359 | state_batch, *_ = batch 360 | state_batch = state_batch.to(device) 361 | 362 | agent.train() 363 | alpha = torch.exp(log_alpha) 364 | 365 | ################## 366 | ## ACTOR UPDATE ## 367 | ################## 368 | dist = agent.actor(state_batch) 369 | agent_actions = dist.rsample() 370 | logp_a = dist.log_prob(agent_actions).sum(-1, keepdim=True) 371 | stacked_preds = torch.stack( 372 | [critic(state_batch, agent_actions) for critic in agent.critics], dim=0 373 | ) 374 | mean_critic_pred = torch.mean(stacked_preds, dim=0) 375 | actor_loss = -(mean_critic_pred - (alpha.detach() * logp_a)).mean() 376 | actor_optimizer.zero_grad() 377 | actor_loss.backward() 378 | if actor_clip: 379 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip) 380 | actor_optimizer.step() 381 | 382 | ################## 383 | ## ALPHA UPDATE ## 384 | ################## 385 | alpha_loss = (-alpha * (logp_a + target_entropy).detach()).mean() 386 | log_alpha_optimizer.zero_grad() 387 | alpha_loss.backward() 388 | log_alpha_optimizer.step() 389 | 390 | 391 | def add_args(parser): 392 | parser.add_argument( 393 | "--num_steps", type=int, default=10 ** 6, help="Number of steps in training" 394 | ) 395 | parser.add_argument( 396 | "--transitions_per_step", 397 | type=int, 398 | default=1, 399 | help="env transitions per training step. Defaults to 1, but will need to \ 400 | be set higher for repaly ratios < 1", 401 | ) 402 | parser.add_argument( 403 | "--max_episode_steps", 404 | type=int, 405 | default=100000, 406 | help="maximum steps per episode", 407 | ) 408 | parser.add_argument( 409 | "--batch_size", type=int, default=512, help="training batch size" 410 | ) 411 | parser.add_argument( 412 | "--tau", type=float, default=0.005, help="for model parameter % update" 413 | ) 414 | parser.add_argument( 415 | "--actor_lr", type=float, default=3e-4, help="actor learning rate" 416 | ) 417 | parser.add_argument( 418 | "--critic_lr", type=float, default=3e-4, help="critic learning rate" 419 | ) 420 | parser.add_argument( 421 | "--gamma", type=float, default=0.99, help="gamma, the discount factor" 422 | ) 423 | parser.add_argument( 424 | "--init_alpha", 425 | type=float, 426 | default=0.1, 427 | help="initial entropy regularization coefficeint.", 428 | ) 429 | parser.add_argument( 430 | "--alpha_lr", 431 | type=float, 432 | default=1e-4, 433 | help="alpha (entropy regularization coefficeint) learning rate", 434 | ) 435 | parser.add_argument( 436 | "--buffer_size", type=int, default=1_000_000, help="replay buffer size" 437 | ) 438 | parser.add_argument( 439 | "--eval_interval", 440 | type=int, 441 | default=5000, 442 | help="how often to test the agent without exploration (in episodes)", 443 | ) 444 | parser.add_argument( 445 | "--eval_episodes", 446 | type=int, 447 | default=10, 448 | help="how many episodes to run for when testing", 449 | ) 450 | parser.add_argument( 451 | "--warmup_steps", type=int, default=1000, help="warmup length, in steps" 452 | ) 453 | parser.add_argument( 454 | "--render", 455 | action="store_true", 456 | help="flag to enable env rendering during training", 457 | ) 458 | parser.add_argument( 459 | "--actor_clip", 460 | type=float, 461 | default=None, 462 | help="gradient clipping for actor updates", 463 | ) 464 | parser.add_argument( 465 | "--critic_clip", 466 | type=float, 467 | default=None, 468 | help="gradient clipping for critic updates", 469 | ) 470 | parser.add_argument( 471 | "--name", type=str, 
default="redq_run", help="dir name for saves" 472 | ) 473 | parser.add_argument( 474 | "--actor_l2", 475 | type=float, 476 | default=0.0, 477 | help="L2 regularization coeff for actor network", 478 | ) 479 | parser.add_argument( 480 | "--critic_l2", 481 | type=float, 482 | default=0.0, 483 | help="L2 regularization coeff for critic network", 484 | ) 485 | parser.add_argument( 486 | "--target_delay", 487 | type=int, 488 | default=2, 489 | help="How many training steps to go between target network updates", 490 | ) 491 | parser.add_argument( 492 | "--save_interval", 493 | type=int, 494 | default=100_000, 495 | help="How many steps to go between saving the agent params to disk", 496 | ) 497 | parser.add_argument( 498 | "--verbosity", 499 | type=int, 500 | default=1, 501 | help="verbosity > 0 displays a progress bar during training", 502 | ) 503 | parser.add_argument( 504 | "--critic_updates_per_step", 505 | type=int, 506 | default=20, 507 | help="how many critic gradient updates to make per training step. The REDQ paper calls this variable G.", 508 | ) 509 | parser.add_argument( 510 | "--actor_updates_per_step", 511 | type=int, 512 | default=1, 513 | help="how many actor gradient updates to make per training step", 514 | ) 515 | parser.add_argument( 516 | "--prioritized_replay", 517 | action="store_true", 518 | help="flag that enables use of prioritized experience replay", 519 | ) 520 | parser.add_argument( 521 | "--skip_save_to_disk", 522 | action="store_true", 523 | help="flag to skip saving agent params to disk during training", 524 | ) 525 | parser.add_argument( 526 | "--skip_log_to_disk", 527 | action="store_true", 528 | help="flag to skip saving agent performance logs to disk during training", 529 | ) 530 | parser.add_argument( 531 | "--log_std_low", 532 | type=float, 533 | default=-10, 534 | help="Lower bound for log std of action distribution.", 535 | ) 536 | parser.add_argument( 537 | "--log_std_high", 538 | type=float, 539 | default=2, 540 | help="Upper bound for log std of action distribution.", 541 | ) 542 | parser.add_argument( 543 | "--random_ensemble_size", 544 | type=int, 545 | default=2, 546 | help="How many random critic networks to use per TD target computation. The REDQ paper calls this variable M", 547 | ) 548 | parser.add_argument( 549 | "--critic_ensemble_size", 550 | type=int, 551 | default=10, 552 | help="How many critic networks to sample from on each TD target computation. This it the total size of the critic ensemble. The REDQ paper calls this variable N", 553 | ) 554 | -------------------------------------------------------------------------------- /deep_control/replay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def unique(sorted_array): 6 | """ 7 | More efficient implementation of np.unique for sorted arrays 8 | :param sorted_array: (np.ndarray) 9 | :return:(np.ndarray) sorted_array without duplicate elements 10 | """ 11 | if len(sorted_array) == 1: 12 | return sorted_array 13 | left = sorted_array[:-1] 14 | right = sorted_array[1:] 15 | uniques = np.append(right != left, True) 16 | return sorted_array[uniques] 17 | 18 | 19 | class SegmentTree: 20 | def __init__(self, capacity, operation, neutral_element): 21 | """ 22 | Build a Segment Tree data structure. 23 | https://en.wikipedia.org/wiki/Segment_tree 24 | Can be used as regular array that supports Index arrays, but with two 25 | important differences: 26 | a) setting item's value is slightly slower. 
27 | It is O(lg capacity) instead of O(1). 28 | b) user has access to an efficient ( O(log segment size) ) 29 | `reduce` operation which reduces `operation` over 30 | a contiguous subsequence of items in the array. 31 | :param capacity: (int) Total size of the array - must be a power of two. 32 | :param operation: (lambda (Any, Any): Any) operation for combining elements (eg. sum, max) must form a 33 | mathematical group together with the set of possible values for array elements (i.e. be associative) 34 | :param neutral_element: (Any) neutral element for the operation above. eg. float('-inf') for max and 0 for sum. 35 | """ 36 | assert ( 37 | capacity > 0 and capacity & (capacity - 1) == 0 38 | ), "capacity must be positive and a power of 2." 39 | self._capacity = capacity 40 | self._value = [neutral_element for _ in range(2 * capacity)] 41 | self._operation = operation 42 | self.neutral_element = neutral_element 43 | 44 | def _reduce_helper(self, start, end, node, node_start, node_end): 45 | if start == node_start and end == node_end: 46 | return self._value[node] 47 | mid = (node_start + node_end) // 2 48 | if end <= mid: 49 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 50 | else: 51 | if mid + 1 <= start: 52 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 53 | else: 54 | return self._operation( 55 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 56 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end), 57 | ) 58 | 59 | def reduce(self, start=0, end=None): 60 | """ 61 | Returns result of applying `self.operation` 62 | to a contiguous subsequence of the array. 63 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 64 | :param start: (int) beginning of the subsequence 65 | :param end: (int) end of the subsequences 66 | :return: (Any) result of reducing self.operation over the specified range of array elements. 67 | """ 68 | if end is None: 69 | end = self._capacity 70 | if end < 0: 71 | end += self._capacity 72 | end -= 1 73 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 74 | 75 | def __setitem__(self, idx, val): 76 | # indexes of the leaf 77 | idxs = idx + self._capacity 78 | self._value[idxs] = val 79 | if isinstance(idxs, int): 80 | idxs = np.array([idxs]) 81 | # go up one level in the tree and remove duplicate indexes 82 | idxs = unique(idxs // 2) 83 | while len(idxs) > 1 or idxs[0] > 0: 84 | # as long as there are non-zero indexes, update the corresponding values 85 | self._value[idxs] = self._operation( 86 | self._value[2 * idxs], self._value[2 * idxs + 1] 87 | ) 88 | # go up one level in the tree and remove duplicate indexes 89 | idxs = unique(idxs // 2) 90 | 91 | def __getitem__(self, idx): 92 | assert np.max(idx) < self._capacity 93 | assert 0 <= np.min(idx) 94 | return self._value[self._capacity + idx] 95 | 96 | 97 | class SumSegmentTree(SegmentTree): 98 | def __init__(self, capacity): 99 | super(SumSegmentTree, self).__init__( 100 | capacity=capacity, operation=np.add, neutral_element=0.0 101 | ) 102 | self._value = np.array(self._value) 103 | 104 | def sum(self, start=0, end=None): 105 | """ 106 | Returns arr[start] + ... 
+ arr[end] 107 | :param start: (int) start position of the reduction (must be >= 0) 108 | :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1) 109 | :return: (Any) reduction of SumSegmentTree 110 | """ 111 | return super(SumSegmentTree, self).reduce(start, end) 112 | 113 | def find_prefixsum_idx(self, prefixsum): 114 | """ 115 | Find the highest index `i` in the array such that 116 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum for each entry in prefixsum 117 | if array values are probabilities, this function 118 | allows to sample indexes according to the discrete 119 | probability efficiently. 120 | :param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix 121 | :return: (np.ndarray) highest indexes satisfying the prefixsum constraint 122 | """ 123 | if isinstance(prefixsum, float): 124 | prefixsum = np.array([prefixsum]) 125 | assert 0 <= np.min(prefixsum) 126 | assert np.max(prefixsum) <= self.sum() + 1e-5 127 | assert isinstance(prefixsum[0], float) 128 | 129 | idx = np.ones(len(prefixsum), dtype=int) 130 | cont = np.ones(len(prefixsum), dtype=bool) 131 | 132 | while np.any(cont): # while not all nodes are leafs 133 | idx[cont] = 2 * idx[cont] 134 | prefixsum_new = np.where( 135 | self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum 136 | ) 137 | # prepare update of prefixsum for all right children 138 | idx = np.where( 139 | np.logical_or(self._value[idx] > prefixsum, np.logical_not(cont)), 140 | idx, 141 | idx + 1, 142 | ) 143 | # Select child node for non-leaf nodes 144 | prefixsum = prefixsum_new 145 | # update prefixsum 146 | cont = idx < self._capacity 147 | # collect leafs 148 | return idx - self._capacity 149 | 150 | 151 | class MinSegmentTree(SegmentTree): 152 | def __init__(self, capacity): 153 | super(MinSegmentTree, self).__init__( 154 | capacity=capacity, operation=np.minimum, neutral_element=float("inf") 155 | ) 156 | self._value = np.array(self._value) 157 | 158 | def min(self, start=0, end=None): 159 | """ 160 | Returns min(arr[start], ..., arr[end]) 161 | :param start: (int) start position of the reduction (must be >= 0) 162 | :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1) 163 | :return: (Any) reduction of MinSegmentTree 164 | """ 165 | return super(MinSegmentTree, self).reduce(start, end) 166 | 167 | 168 | class ReplayBufferStorage: 169 | def __init__(self, size, obs_shape, act_shape, obs_dtype=torch.float32): 170 | self.s_dtype = obs_dtype 171 | 172 | # buffer arrays 173 | self.s_stack = torch.zeros((size,) + obs_shape, dtype=self.s_dtype) 174 | self.action_stack = torch.zeros((size,) + act_shape, dtype=torch.float32) 175 | self.reward_stack = torch.zeros((size, 1), dtype=torch.float32) 176 | self.s1_stack = torch.zeros((size,) + obs_shape, dtype=self.s_dtype) 177 | self.done_stack = torch.zeros((size, 1), dtype=torch.int) 178 | 179 | self.obs_shape = obs_shape 180 | self.size = size 181 | self._next_idx = 0 182 | self._max_filled = 0 183 | 184 | def __len__(self): 185 | return self._max_filled 186 | 187 | def add(self, s, a, r, s_1, d): 188 | # this buffer supports batched experience 189 | if len(s.shape) > len(self.obs_shape): 190 | # there must be a batch dimension 191 | num_samples = len(s) 192 | else: 193 | num_samples = 1 194 | r, d = [r], [d] 195 | 196 | if not isinstance(s, torch.Tensor): 197 | # convert states to numpy (checking for LazyFrames) 198 | if not isinstance(s, np.ndarray): 199 | s = np.asarray(s) 200 | if 
not isinstance(s_1, np.ndarray): 201 | s_1 = np.asarray(s_1) 202 | 203 | # convert to torch tensors 204 | s = torch.from_numpy(s) 205 | a = torch.from_numpy(a).float() 206 | r = torch.Tensor(r).float() 207 | s_1 = torch.from_numpy(s_1) 208 | d = torch.Tensor(d).int() 209 | 210 | # make sure tensors are floats not doubles 211 | if self.s_dtype is torch.float32: 212 | s = s.float() 213 | s_1 = s_1.float() 214 | 215 | else: 216 | # move to cpu 217 | s = s.cpu() 218 | a = a.cpu() 219 | r = r.cpu() 220 | s_1 = s_1.cpu() 221 | d = d.int().cpu() 222 | 223 | # Store at end of buffer. Wrap around if past end. 224 | R = np.arange(self._next_idx, self._next_idx + num_samples) % self.size 225 | self.s_stack[R] = s 226 | self.action_stack[R] = a 227 | self.reward_stack[R] = r 228 | self.s1_stack[R] = s_1 229 | self.done_stack[R] = d 230 | # Advance index. 231 | self._max_filled = min( 232 | max(self._next_idx + num_samples, self._max_filled), self.size 233 | ) 234 | self._next_idx = (self._next_idx + num_samples) % self.size 235 | return R 236 | 237 | def __getitem__(self, indices): 238 | try: 239 | iter(indices) 240 | except TypeError: 241 | raise IndexError( 242 | "ReplayBufferStorage getitem called with indices object that is not iterable" 243 | ) 244 | 245 | # converting states and actions to float here instead of inside the learning loop 246 | # of each agent seems fine for now. 247 | state = self.s_stack[indices].float() 248 | action = self.action_stack[indices].float() 249 | reward = self.reward_stack[indices] 250 | next_state = self.s1_stack[indices].float() 251 | done = self.done_stack[indices] 252 | return (state, action, reward, next_state, done) 253 | 254 | def __setitem__(self, indices, experience): 255 | s, a, r, s1, d = experience 256 | self.s_stack[indices] = s.float() 257 | self.action_stack[indices] = a.float() 258 | self.reward_stack[indices] = r 259 | self.s1_stack[indices] = s1.float() 260 | self.done_stack[indices] = d 261 | 262 | def get_all_transitions(self): 263 | return ( 264 | self.s_stack[: self._max_filled], 265 | self.action_stack[: self._max_filled], 266 | self.reward_stack[: self._max_filled], 267 | self.s1_stack[: self._max_filled], 268 | self.done_stack[: self._max_filled], 269 | ) 270 | 271 | 272 | class ReplayBuffer: 273 | def __init__(self, size, state_shape=None, action_shape=None, state_dtype=float): 274 | self._maxsize = size 275 | self.state_shape = state_shape 276 | self.state_dtype = self._convert_dtype(state_dtype) 277 | self.action_shape = action_shape 278 | self._storage = None 279 | assert self.state_shape, "Must provide shape of state space to ReplayBuffer" 280 | assert self.action_shape, "Must provide shape of action space to ReplayBuffer" 281 | 282 | def _convert_dtype(self, dtype): 283 | if dtype in [int, np.uint8, torch.uint8]: 284 | return torch.uint8 285 | elif dtype in [float, np.float32, np.float64, torch.float32, torch.float64]: 286 | return torch.float32 287 | elif dtype in ["int32", np.int32]: 288 | return torch.int32 289 | else: 290 | raise ValueError(f"Unrecognized replay buffer dtype: {dtype}") 291 | 292 | def __len__(self): 293 | return len(self._storage) if self._storage is not None else 0 294 | 295 | def push(self, state, action, reward, next_state, done): 296 | if self._storage is None: 297 | self._storage = ReplayBufferStorage( 298 | self._maxsize, 299 | obs_shape=self.state_shape, 300 | act_shape=self.action_shape, 301 | obs_dtype=self.state_dtype, 302 | ) 303 | return self._storage.add(state, action, reward, next_state, done) 304 |
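# NOTE: the storage above is created lazily on the first push() using the shapes/dtype passed to the constructor; sample() below draws uniformly from the filled portion of that storage, while PrioritizedReplayBuffer (further down) overrides it with proportional prioritization.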
305 | def sample(self, batch_size, get_idxs=False): 306 | random_idxs = torch.randint(len(self._storage), (batch_size,)) 307 | if get_idxs: 308 | return self._storage[random_idxs], random_idxs.cpu().numpy() 309 | else: 310 | return self._storage[random_idxs] 311 | 312 | def get_all_transitions(self): 313 | return self._storage.get_all_transitions() 314 | 315 | def load_experience(self, s, a, r, s1, d): 316 | assert ( 317 | s.shape[0] <= self._maxsize 318 | ), "Experience dataset is larger than the buffer." 319 | if len(r.shape) < 2: 320 | r = np.expand_dims(r, 1) 321 | if len(d.shape) < 2: 322 | d = np.expand_dims(d, 1) 323 | self.push(s, a, r, s1, d) 324 | 325 | 326 | class PrioritizedReplayBuffer(ReplayBuffer): 327 | def __init__( 328 | self, size, state_shape, action_shape, state_dtype=float, alpha=0.6, beta=1.0 329 | ): 330 | super(PrioritizedReplayBuffer, self).__init__( 331 | size, state_shape, action_shape, state_dtype 332 | ) 333 | assert alpha >= 0 334 | self.alpha = alpha 335 | self.beta = beta 336 | 337 | it_capacity = 1 338 | while it_capacity < size: 339 | it_capacity *= 2 340 | 341 | self._it_sum = SumSegmentTree(it_capacity) 342 | self._it_min = MinSegmentTree(it_capacity) 343 | self._max_priority = 1.0 344 | 345 | def push(self, s, a, r, s_1, d, priorities=None): 346 | R = super().push(s, a, r, s_1, d) 347 | if priorities is None: 348 | priorities = self._max_priority 349 | self._it_sum[R] = priorities ** self.alpha 350 | self._it_min[R] = priorities ** self.alpha 351 | 352 | def _sample_proportional(self, batch_size): 353 | mass = [] 354 | total = self._it_sum.sum(0, len(self._storage) - 1) 355 | mass = np.random.random(size=batch_size) * total 356 | idx = self._it_sum.find_prefixsum_idx(mass) 357 | return idx 358 | 359 | def sample(self, batch_size): 360 | idxes = self._sample_proportional(batch_size) 361 | p_min = self._it_min.min() / self._it_sum.sum() 362 | max_weight = (p_min * len(self._storage)) ** (-self.beta) 363 | p_sample = self._it_sum[idxes] / self._it_sum.sum() 364 | weights = (p_sample * len(self._storage)) ** (-self.beta) / max_weight 365 | return self._storage[idxes], torch.from_numpy(weights), idxes 366 | 367 | def sample_uniform(self, batch_size): 368 | return super().sample(batch_size, get_idxs=True) 369 | 370 | def update_priorities(self, idxes, priorities): 371 | assert len(idxes) == len(priorities) 372 | assert np.min(priorities) > 0 373 | assert np.min(idxes) >= 0 374 | assert np.max(idxes) < len(self._storage) 375 | self._it_sum[idxes] = priorities ** self.alpha 376 | self._it_min[idxes] = priorities ** self.alpha 377 | self._max_priority = max(self._max_priority, np.max(priorities)) 378 | 379 | 380 | class MultiPriorityBuffer(ReplayBuffer): 381 | def __init__( 382 | self, 383 | size, 384 | trees, 385 | state_shape, 386 | action_shape, 387 | state_dtype=float, 388 | alpha=0.6, 389 | beta=1.0, 390 | ): 391 | super(MultiPriorityBuffer, self).__init__( 392 | size, state_shape, action_shape, state_dtype 393 | ) 394 | assert alpha >= 0 395 | self.alpha = alpha 396 | self.beta = beta 397 | 398 | it_capacity = 1 399 | while it_capacity < size: 400 | it_capacity *= 2 401 | 402 | self.sum_trees = [SumSegmentTree(it_capacity) for _ in range(trees)] 403 | self.min_trees = [MinSegmentTree(it_capacity) for _ in range(trees)] 404 | self._max_priority = 1.0 405 | 406 | def push(self, s, a, r, s_1, d, priorities=None): 407 | R = super().push(s, a, r, s_1, d) 408 | if priorities is None: 409 | priorities = self._max_priority 410 | 411 | for sum_tree in 
self.sum_trees: 412 | sum_tree[R] = priorities ** self.alpha 413 | for min_tree in self.min_trees: 414 | min_tree[R] = priorities ** self.alpha 415 | 416 | def _sample_proportional(self, batch_size, tree_num): 417 | mass = [] 418 | total = self.sum_trees[tree_num].sum(0, len(self._storage) - 1) 419 | mass = np.random.random(size=batch_size) * total 420 | idx = self.sum_trees[tree_num].find_prefixsum_idx(mass) 421 | return idx 422 | 423 | def sample(self, batch_size, tree_num): 424 | idxes = self._sample_proportional(batch_size, tree_num) 425 | p_min = self.min_trees[tree_num].min() / self.sum_trees[tree_num].sum() 426 | max_weight = (p_min * len(self._storage)) ** (-self.beta) 427 | p_sample = self.sum_trees[tree_num][idxes] / self.sum_trees[tree_num].sum() 428 | weights = (p_sample * len(self._storage)) ** (-self.beta) / max_weight 429 | return self._storage[idxes], torch.from_numpy(weights), idxes 430 | 431 | def sample_uniform(self, batch_size): 432 | return super().sample(batch_size, get_idxs=True) 433 | 434 | def update_priorities(self, idxes, priorities, tree_num): 435 | assert len(idxes) == len(priorities) 436 | assert np.min(priorities) > 0 437 | assert np.min(idxes) >= 0 438 | assert np.max(idxes) < len(self._storage) 439 | self.sum_trees[tree_num][idxes] = priorities ** self.alpha 440 | self.min_trees[tree_num][idxes] = priorities ** self.alpha 441 | self._max_priority = max(self._max_priority, np.max(priorities)) 442 | -------------------------------------------------------------------------------- /deep_control/run.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | 5 | from . import envs, utils 6 | 7 | 8 | def run_env(agent, env, episodes, max_steps, render=False, verbosity=1, discount=1.0): 9 | episode_return_history = [] 10 | if render: 11 | env.render("rgb_array") 12 | for episode in range(episodes): 13 | episode_return = 0.0 14 | state = env.reset() 15 | done, info = False, {} 16 | for step_num in range(max_steps): 17 | if done: 18 | break 19 | action = agent.forward(state) 20 | state, reward, done, info = env.step(action) 21 | if render: 22 | env.render("rgb_array") 23 | episode_return += reward * (discount ** step_num) 24 | if verbosity: 25 | print(f"Episode {episode}:: {episode_return}") 26 | episode_return_history.append(episode_return) 27 | return torch.tensor(episode_return_history) 28 | 29 | 30 | def exploration_noise(action, random_process): 31 | return np.clip(action + random_process.sample(), -1.0, 1.0) 32 | 33 | 34 | def evaluate_agent( 35 | agent, env, eval_episodes, max_episode_steps, render=False, verbosity=0 36 | ): 37 | agent.eval() 38 | returns = run_env( 39 | agent, env, eval_episodes, max_episode_steps, render, verbosity=verbosity 40 | ) 41 | agent.train() 42 | mean_return = returns.mean() 43 | return mean_return 44 | 45 | 46 | def collect_experience_by_steps( 47 | agent, 48 | env, 49 | buffer, 50 | num_steps, 51 | current_state=None, 52 | current_done=None, 53 | steps_this_ep=None, 54 | max_rollout_length=None, 55 | random_process=None, 56 | ): 57 | if current_state is None: 58 | state = env.reset() 59 | else: 60 | state = current_state 61 | if current_done is None: 62 | done = False 63 | else: 64 | done = current_done 65 | if steps_this_ep is None: 66 | steps_this_ep = 0 67 | for step in range(num_steps): 68 | if done: 69 | state = env.reset() 70 | steps_this_ep = 0 71 | 72 | # collect a new transition 73 | action = agent.collection_forward(state) 74 | if 
random_process is not None: 75 | action = exploration_noise(action, random_process) 76 | next_state, reward, done, info = env.step(action) 77 | buffer.push(state, action, reward, next_state, done) 78 | state = next_state 79 | 80 | steps_this_ep += 1 81 | if max_rollout_length and steps_this_ep >= max_rollout_length: 82 | done = True 83 | return state, done, steps_this_ep 84 | 85 | 86 | def collect_experience_by_rollouts( 87 | agent, 88 | env, 89 | buffer, 90 | num_rollouts, 91 | max_rollout_length, 92 | random_process=None, 93 | ): 94 | for rollout in range(num_rollouts): 95 | state = env.reset() 96 | done = False 97 | step_num = 0 98 | while not done: 99 | action = agent.collection_forward(state) 100 | if random_process is not None: 101 | action = exploration_noise( 102 | action, random_process 103 | ) 104 | next_state, reward, done, info = env.step(action) 105 | buffer.push(state, action, reward, next_state, done) 106 | state = next_state 107 | step_num += 1 108 | if step_num >= max_rollout_length: 109 | done = True 110 | 111 | 112 | def warmup_buffer(buffer, env, warmup_steps, max_episode_steps): 113 | # use warm-up steps to add random transitions to the buffer 114 | state = env.reset() 115 | done = False 116 | steps_this_ep = 0 117 | for _ in range(warmup_steps): 118 | if done: 119 | state = env.reset() 120 | steps_this_ep = 0 121 | done = False 122 | rand_action = env.action_space.sample() 123 | if not isinstance(rand_action, np.ndarray): 124 | rand_action = np.array(float(rand_action)) 125 | next_state, reward, done, info = env.step(rand_action) 126 | buffer.push(state, rand_action, reward, next_state, done) 127 | state = next_state 128 | steps_this_ep += 1 129 | if steps_this_ep >= max_episode_steps: 130 | done = True 131 | 132 | 133 | if __name__ == "__main__": 134 | import argparse; parser = argparse.ArgumentParser() 135 | parser.add_argument("--render", type=int, default=1) 136 | parser.add_argument("--env", type=str) 137 | parser.add_argument("--episodes", type=int, default=10) 138 | parser.add_argument("--save", type=str) 139 | parser.add_argument("--algo", type=str) 140 | parser.add_argument("--max_steps", type=int, default=300) 141 | args = parser.parse_args() 142 | 143 | agent, env = envs.load_env(args.env, args.algo) 144 | agent.load(args.save) 145 | run_env(agent, env, args.episodes, args.max_steps, args.render, verbosity=1) 146 | -------------------------------------------------------------------------------- /deep_control/sac_aug.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import os 5 | from itertools import chain 6 | 7 | import gym 8 | import numpy as np 9 | import tensorboardX 10 | import torch 11 | import torch.nn.functional as F 12 | import tqdm 13 | 14 | from deep_control import envs, nets, replay, run, sac, utils 15 | from deep_control.augmentations import AugmentationSequence, DrqAug 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | class PixelSACAgent(sac.SACAgent): 21 | def __init__(self, obs_shape, act_space_size, log_std_low, log_std_high): 22 | self.encoder = nets.BigPixelEncoder(obs_shape, out_dim=50) 23 | self.actor = nets.StochasticActor(50, act_space_size, log_std_low, log_std_high) 24 | self.critic1 = nets.BigCritic(50, act_space_size) 25 | self.critic2 = nets.BigCritic(50, act_space_size) 26 | self.log_std_low = log_std_low 27 | self.log_std_high = log_std_high 28 | 29 | def 
forward(self, obs): 30 | # eval forward (don't sample from distribution) 31 | obs = self.process_state(obs) 32 | self.encoder.eval() 33 | self.actor.eval() 34 | with torch.no_grad(): 35 | state_rep = self.encoder.forward(obs) 36 | act_dist = self.actor.forward(state_rep) 37 | act = act_dist.mean 38 | self.encoder.train() 39 | self.actor.train() 40 | return self.process_act(act) 41 | 42 | def sample_action(self, obs): 43 | obs = self.process_state(obs) 44 | self.encoder.eval() 45 | self.actor.eval() 46 | with torch.no_grad(): 47 | state_rep = self.encoder.forward(obs) 48 | act_dist = self.actor.forward(state_rep) 49 | act = act_dist.sample() 50 | self.encoder.train() 51 | self.actor.train() 52 | act = self.process_act(act) 53 | return act 54 | 55 | def to(self, device): 56 | self.encoder = self.encoder.to(device) 57 | super().to(device) 58 | 59 | def eval(self): 60 | self.encoder.eval() 61 | super().eval() 62 | 63 | def train(self): 64 | self.encoder.train() 65 | super().train() 66 | 67 | def save(self, path): 68 | encoder_path = os.path.join(path, "encoder.pt") 69 | torch.save(self.encoder.state_dict(), encoder_path) 70 | super().save(path) 71 | 72 | def load(self, path): 73 | encoder_path = os.path.join(path, "encoder.pt") 74 | self.encoder.load_state_dict(torch.load(encoder_path)) 75 | super().load(path) 76 | 77 | 78 | def sac_aug( 79 | agent, 80 | buffer, 81 | train_env, 82 | test_env, 83 | augmenter, 84 | num_steps=250_000, 85 | transitions_per_step=1, 86 | max_episode_steps=100_000, 87 | batch_size=256, 88 | mlp_tau=0.01, 89 | encoder_tau=0.05, 90 | actor_lr=1e-3, 91 | critic_lr=1e-3, 92 | encoder_lr=1e-3, 93 | alpha_lr=1e-4, 94 | gamma=0.99, 95 | eval_interval=10_000, 96 | test_eval_episodes=10, 97 | train_eval_episodes=0, 98 | warmup_steps=1000, 99 | actor_clip=None, 100 | critic_clip=None, 101 | actor_l2=0.0, 102 | critic_l2=0.0, 103 | encoder_l2=0.0, 104 | delay=2, 105 | save_interval=10_000, 106 | name="sac_aug_run", 107 | render=False, 108 | save_to_disk=True, 109 | log_to_disk=True, 110 | verbosity=0, 111 | gradient_updates_per_step=1, 112 | init_alpha=0.1, 113 | feature_matching_imp=0.0, 114 | aug_mix=1.0, 115 | infinite_bootstrap=True, 116 | **kwargs, 117 | ): 118 | if save_to_disk or log_to_disk: 119 | save_dir = utils.make_process_dirs(name) 120 | # create tb writer, save hparams 121 | if log_to_disk: 122 | writer = tensorboardX.SummaryWriter(save_dir) 123 | writer.add_hparams(locals(), {}) 124 | 125 | agent.to(device) 126 | agent.train() 127 | 128 | # initialize target networks (target actor isn't used in SAC) 129 | target_agent = copy.deepcopy(agent) 130 | target_agent.to(device) 131 | utils.hard_update(target_agent.critic1, agent.critic1) 132 | utils.hard_update(target_agent.critic2, agent.critic2) 133 | utils.hard_update(target_agent.encoder, agent.encoder) 134 | target_agent.train() 135 | 136 | # create network optimizers 137 | critic_optimizer = torch.optim.Adam( 138 | chain( 139 | agent.critic1.parameters(), 140 | agent.critic2.parameters(), 141 | ), 142 | lr=critic_lr, 143 | weight_decay=critic_l2, 144 | betas=(0.9, 0.999), 145 | ) 146 | encoder_optimizer = torch.optim.Adam( 147 | agent.encoder.parameters(), 148 | lr=encoder_lr, 149 | weight_decay=encoder_l2, 150 | betas=(0.9, 0.999), 151 | ) 152 | actor_optimizer = torch.optim.Adam( 153 | agent.actor.parameters(), 154 | lr=actor_lr, 155 | weight_decay=actor_l2, 156 | betas=(0.9, 0.999), 157 | ) 158 | 159 | # initialize learnable alpha param 160 | log_alpha = torch.Tensor([math.log(init_alpha)]).to(device) 161 | 
log_alpha.requires_grad = True 162 | log_alpha_optimizer = torch.optim.Adam([log_alpha], lr=alpha_lr, betas=(0.5, 0.999)) 163 | 164 | target_entropy = -train_env.action_space.shape[0] 165 | 166 | run.warmup_buffer(buffer, train_env, warmup_steps, max_episode_steps) 167 | 168 | done = True 169 | steps_this_ep = 0 170 | 171 | steps_iter = range(num_steps) 172 | if verbosity: 173 | steps_iter = tqdm.tqdm(steps_iter) 174 | 175 | for step in steps_iter: 176 | for _ in range(transitions_per_step): 177 | if done: 178 | obs = train_env.reset() 179 | steps_this_ep = 0 180 | done = False 181 | # batch the actions 182 | action = agent.sample_action(obs) 183 | next_obs, reward, done, info = train_env.step(action) 184 | if infinite_bootstrap and steps_this_ep + 1 == max_episode_steps: 185 | # allow infinite bootstrapping 186 | done = False 187 | buffer.push(obs, action, reward, next_obs, done) 188 | obs = next_obs 189 | steps_this_ep += 1 190 | if steps_this_ep >= max_episode_steps: 191 | done = True 192 | 193 | update_policy = step % delay == 0 194 | for _ in range(gradient_updates_per_step): 195 | learn_from_pixels( 196 | buffer=buffer, 197 | target_agent=target_agent, 198 | agent=agent, 199 | actor_optimizer=actor_optimizer, 200 | critic_optimizer=critic_optimizer, 201 | encoder_optimizer=encoder_optimizer, 202 | log_alpha=log_alpha, 203 | log_alpha_optimizer=log_alpha_optimizer, 204 | target_entropy=target_entropy, 205 | batch_size=batch_size, 206 | gamma=gamma, 207 | critic_clip=critic_clip, 208 | actor_clip=actor_clip, 209 | update_policy=update_policy, 210 | augmenter=augmenter, 211 | feature_matching_imp=feature_matching_imp, 212 | aug_mix=aug_mix, 213 | ) 214 | 215 | # move target model towards training model 216 | if update_policy: 217 | utils.soft_update(target_agent.critic1, agent.critic1, mlp_tau) 218 | utils.soft_update(target_agent.critic2, agent.critic2, mlp_tau) 219 | utils.soft_update(target_agent.encoder, agent.encoder, encoder_tau) 220 | 221 | if step % eval_interval == 0 or step == num_steps - 1: 222 | mean_test_return = run.evaluate_agent( 223 | agent, test_env, test_eval_episodes, max_episode_steps, render 224 | ) 225 | mean_train_return = run.evaluate_agent( 226 | agent, train_env, train_eval_episodes, max_episode_steps, render 227 | ) 228 | if log_to_disk: 229 | writer.add_scalar( 230 | "performance/test_return", 231 | mean_test_return, 232 | step * transitions_per_step, 233 | ) 234 | writer.add_scalar( 235 | "performance/train_return", 236 | mean_train_return, 237 | step * transitions_per_step, 238 | ) 239 | 240 | if step % save_interval == 0 and save_to_disk: 241 | agent.save(save_dir) 242 | 243 | if save_to_disk: 244 | agent.save(save_dir) 245 | return agent 246 | 247 | 248 | def learn_from_pixels( 249 | buffer, 250 | target_agent, 251 | agent, 252 | actor_optimizer, 253 | critic_optimizer, 254 | encoder_optimizer, 255 | log_alpha_optimizer, 256 | target_entropy, 257 | log_alpha, 258 | augmenter, 259 | batch_size=128, 260 | gamma=0.99, 261 | critic_clip=None, 262 | actor_clip=None, 263 | update_policy=True, 264 | feature_matching_imp=1.0, 265 | aug_mix=0.75, 266 | ): 267 | 268 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 269 | if per: 270 | batch, imp_weights, priority_idxs = buffer.sample(batch_size) 271 | imp_weights = imp_weights.to(device) 272 | else: 273 | batch = buffer.sample(batch_size) 274 | 275 | # sample unaugmented transitions from the buffer 276 | og_obs_batch, action_batch, reward_batch, og_next_obs_batch, done_batch = batch 277 | og_obs_batch = 
og_obs_batch.to(device) 278 | og_next_obs_batch = og_next_obs_batch.to(device) 279 | # at this point, the obs batches are float32s [0., 255.] on the gpu 280 | 281 | # created an augmented version of each transition 282 | # the augmenter applies a random transition to each batch index, 283 | # but keep the random params consistent between obs and next_obs batches 284 | aug_obs_batch, aug_next_obs_batch = augmenter(og_obs_batch, og_next_obs_batch) 285 | 286 | # mix the augmented versions in with the standard 287 | # no need to shuffle because the replay buffer handles that 288 | aug_mix_idx = int(batch_size * aug_mix) 289 | obs_batch = og_obs_batch.clone() 290 | obs_batch[:aug_mix_idx] = aug_obs_batch[:aug_mix_idx] 291 | next_obs_batch = og_next_obs_batch.clone() 292 | next_obs_batch[:aug_mix_idx] = aug_next_obs_batch[:aug_mix_idx] 293 | 294 | action_batch = action_batch.to(device) 295 | reward_batch = reward_batch.to(device) 296 | done_batch = done_batch.to(device) 297 | 298 | alpha = torch.exp(log_alpha) 299 | 300 | with torch.no_grad(): 301 | # create critic targets (clipped double Q learning) 302 | next_state_rep = target_agent.encoder(next_obs_batch) 303 | action_dist_s1 = agent.actor(next_state_rep) 304 | action_s1 = action_dist_s1.rsample() 305 | logp_a1 = action_dist_s1.log_prob(action_s1).sum(-1, keepdim=True) 306 | 307 | target_action_value_s1 = torch.min( 308 | target_agent.critic1(next_state_rep, action_s1), 309 | target_agent.critic2(next_state_rep, action_s1), 310 | ) 311 | td_target = reward_batch + gamma * (1.0 - done_batch) * ( 312 | target_action_value_s1 - (alpha * logp_a1) 313 | ) 314 | 315 | # update critics with Bellman MSE 316 | state_rep = agent.encoder(obs_batch) 317 | agent_critic1_pred = agent.critic1(state_rep, action_batch) 318 | td_error1 = td_target - agent_critic1_pred 319 | if per: 320 | critic1_loss = (imp_weights * 0.5 * (td_error1 ** 2)).mean() 321 | else: 322 | critic1_loss = 0.5 * (td_error1 ** 2).mean() 323 | 324 | agent_critic2_pred = agent.critic2(state_rep, action_batch) 325 | td_error2 = td_target - agent_critic2_pred 326 | if per: 327 | critic2_loss = (imp_weights * 0.5 * (td_error2 ** 2)).mean() 328 | else: 329 | critic2_loss = 0.5 * (td_error2 ** 2).mean() 330 | 331 | # optional feature matching loss to make state_rep invariant to augs 332 | if feature_matching_imp > 0.0: 333 | aug_rep = agent.encoder(aug_obs_batch) 334 | with torch.no_grad(): 335 | og_rep = agent.encoder(og_obs_batch) 336 | fm_loss = torch.norm(aug_rep - og_rep) 337 | else: 338 | fm_loss = 0.0 339 | 340 | critic_loss = critic1_loss + critic2_loss + feature_matching_imp * fm_loss 341 | 342 | critic_optimizer.zero_grad() 343 | encoder_optimizer.zero_grad() 344 | critic_loss.backward() 345 | if critic_clip: 346 | torch.nn.utils.clip_grad_norm_( 347 | chain( 348 | agent.critic1.parameters(), 349 | agent.critic2.parameters(), 350 | ), 351 | critic_clip, 352 | ) 353 | critic_optimizer.step() 354 | encoder_optimizer.step() 355 | 356 | if update_policy: 357 | # actor update 358 | dist = agent.actor(state_rep.detach()) 359 | agent_actions = dist.rsample() 360 | logp_a = dist.log_prob(agent_actions).sum(-1, keepdim=True) 361 | 362 | actor_loss = -( 363 | torch.min( 364 | agent.critic1(state_rep.detach(), agent_actions), 365 | agent.critic2(state_rep.detach(), agent_actions), 366 | ) 367 | - (alpha.detach() * logp_a) 368 | ).mean() 369 | 370 | actor_optimizer.zero_grad() 371 | actor_loss.backward() 372 | if actor_clip: 373 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), 
actor_clip) 374 | actor_optimizer.step() 375 | 376 | # alpha update 377 | alpha_loss = (-alpha * (logp_a + target_entropy).detach()).mean() 378 | log_alpha_optimizer.zero_grad() 379 | alpha_loss.backward() 380 | log_alpha_optimizer.step() 381 | 382 | if per: 383 | new_priorities = (abs(td_error1) + 1e-5).cpu().data.squeeze(1).numpy() 384 | buffer.update_priorities(priority_idxs, new_priorities) 385 | 386 | 387 | def add_args(parser): 388 | parser.add_argument( 389 | "--num_steps", 390 | type=int, 391 | default=250_000, 392 | help="Number of training steps.", 393 | ) 394 | parser.add_argument( 395 | "--transitions_per_step", 396 | type=int, 397 | default=1, 398 | help="Env transitions per training step. Defaults to 1, but will need to \ 399 | be set higher for repaly ratios < 1", 400 | ) 401 | parser.add_argument( 402 | "--max_episode_steps_start", 403 | type=int, 404 | default=1000, 405 | help="Maximum steps per episode", 406 | ) 407 | parser.add_argument( 408 | "--max_episode_steps_final", 409 | type=int, 410 | default=1000, 411 | help="Maximum steps per episode", 412 | ) 413 | parser.add_argument( 414 | "--max_episode_steps_anneal", 415 | type=float, 416 | default=0.4, 417 | help="Maximum steps per episode", 418 | ) 419 | parser.add_argument( 420 | "--batch_size", type=int, default=256, help="Training batch size" 421 | ) 422 | parser.add_argument( 423 | "--mlp_tau", 424 | type=float, 425 | default=0.01, 426 | help="Determines how quickly the target agent's critic networks params catch up to the trained agent.", 427 | ) 428 | parser.add_argument( 429 | "--encoder_tau", 430 | type=float, 431 | default=0.05, 432 | help="Determines how quickly the target agent's encoder network params catch up to the trained agent. This is typically set higher than mlp_tau because the encoder is used in both actor and critic updates.", 433 | ) 434 | parser.add_argument( 435 | "--actor_lr", 436 | type=float, 437 | default=1e-3, 438 | help="Actor network learning rate", 439 | ) 440 | parser.add_argument( 441 | "--critic_lr", 442 | type=float, 443 | default=1e-3, 444 | help="Critic networks' learning rate", 445 | ) 446 | parser.add_argument( 447 | "--gamma", 448 | type=float, 449 | default=0.99, 450 | help="POMDP discount factor", 451 | ) 452 | parser.add_argument( 453 | "--init_alpha", 454 | type=float, 455 | default=0.1, 456 | help="Initial entropy regularization coefficeint.", 457 | ) 458 | parser.add_argument( 459 | "--alpha_lr", 460 | type=float, 461 | default=1e-4, 462 | help="Alpha (entropy regularization coefficeint) learning rate", 463 | ) 464 | parser.add_argument( 465 | "--buffer_size", 466 | type=int, 467 | default=100_000, 468 | help="Replay buffer maximum capacity. Note that image observations can take up a lot of memory, especially when using frame stacking. 
The buffer allocates a large tensor of zeros to fail fast if it will not have enough memory to complete the training run.", 469 | ) 470 | parser.add_argument( 471 | "--eval_interval", 472 | type=int, 473 | default=10_000, 474 | help="How often to test the agent without exploration (in training steps)", 475 | ) 476 | parser.add_argument( 477 | "--test_eval_episodes", 478 | type=int, 479 | default=10, 480 | help="How many episodes to run for when evaluating on the testing set", 481 | ) 482 | parser.add_argument( 483 | "--train_eval_episodes", 484 | type=int, 485 | default=10, 486 | help="How many episodes to run for when evaluating on the training set", 487 | ) 488 | parser.add_argument( 489 | "--warmup_steps", 490 | type=int, 491 | default=1000, 492 | help="Number of uniform random actions to take at the beginning of training", 493 | ) 494 | parser.add_argument( 495 | "--render", 496 | action="store_true", 497 | help="Flag to enable env rendering during training", 498 | ) 499 | parser.add_argument( 500 | "--actor_clip", 501 | type=float, 502 | default=None, 503 | help="Gradient clipping for actor updates", 504 | ) 505 | parser.add_argument( 506 | "--critic_clip", 507 | type=float, 508 | default=None, 509 | help="Gradient clipping for critic updates", 510 | ) 511 | parser.add_argument( 512 | "--name", 513 | type=str, 514 | default="pixel_sac_run", 515 | help="Dir name for saves, (look in ./dc_saves/{name})", 516 | ) 517 | parser.add_argument( 518 | "--actor_l2", 519 | type=float, 520 | default=0.0, 521 | help="L2 regularization coeff for actor network", 522 | ) 523 | parser.add_argument( 524 | "--critic_l2", 525 | type=float, 526 | default=0.0, 527 | help="L2 regularization coeff for critic network", 528 | ) 529 | parser.add_argument( 530 | "--delay", 531 | type=int, 532 | default=2, 533 | help="How many steps to go between actor and target agent updates", 534 | ) 535 | parser.add_argument( 536 | "--save_interval", 537 | type=int, 538 | default=10_000, 539 | help="How many steps to go between saving the agent params to disk", 540 | ) 541 | parser.add_argument( 542 | "--verbosity", 543 | type=int, 544 | default=1, 545 | help="Verbosity > 0 displays a progress bar during training", 546 | ) 547 | parser.add_argument( 548 | "--gradient_updates_per_step", 549 | type=int, 550 | default=1, 551 | help="How many gradient updates to make per training step", 552 | ) 553 | parser.add_argument( 554 | "--prioritized_replay", 555 | action="store_true", 556 | help="Flag that enables use of prioritized experience replay", 557 | ) 558 | parser.add_argument( 559 | "--skip_save_to_disk", 560 | action="store_true", 561 | help="Flag to skip saving agent params to disk during training", 562 | ) 563 | parser.add_argument( 564 | "--skip_log_to_disk", 565 | action="store_true", 566 | help="Flag to skip saving agent performance logs to disk during training", 567 | ) 568 | parser.add_argument( 569 | "--feature_matching_imp", 570 | type=float, 571 | default=0.0, 572 | help="Coefficient for feature matching loss", 573 | ) 574 | parser.add_argument( 575 | "--encoder_lr", 576 | type=float, 577 | default=1e-3, 578 | help="Learning rate for the encoder network", 579 | ) 580 | parser.add_argument( 581 | "--encoder_l2", 582 | type=float, 583 | default=0.0, 584 | help="Weight decay coefficient for pixel encoder network", 585 | ) 586 | parser.add_argument( 587 | "--aug_mix", 588 | type=float, 589 | default=1.0, 590 | help="Fraction of each update batch that is made up of augmented samples", 591 | ) 592 | parser.add_argument( 593 
| "--log_std_low", 594 | type=int, 595 | default=-10, 596 | help="Lower bound for log std of action distribution.", 597 | ) 598 | parser.add_argument( 599 | "--log_std_high", 600 | type=int, 601 | default=2, 602 | help="Upper bound for log std of action distribution.", 603 | ) 604 | parser.add_argument( 605 | "--augmentations", 606 | type=str, 607 | default="[DrqAug]", 608 | help="Sequence of image data augmentations to perform during training. e.g [ColorJitterAug,DrQAug]", 609 | ) 610 | -------------------------------------------------------------------------------- /deep_control/sbc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from itertools import chain 4 | import random 5 | 6 | import numpy as np 7 | import tensorboardX 8 | import torch 9 | import tqdm 10 | 11 | from . import envs, nets, replay, run, utils, device 12 | 13 | 14 | class SBCAgent: 15 | def __init__( 16 | self, 17 | obs_space_size, 18 | act_space_size, 19 | log_std_low, 20 | log_std_high, 21 | ensemble_size=5, 22 | actor_net_cls=nets.StochasticActor, 23 | hidden_size=1024, 24 | beta_dist=False, 25 | ): 26 | self.actors = [ 27 | actor_net_cls( 28 | obs_space_size, 29 | act_space_size, 30 | log_std_low, 31 | log_std_high, 32 | dist_impl="beta" if beta_dist else "pyd", 33 | hidden_size=hidden_size, 34 | ) 35 | for _ in range(ensemble_size) 36 | ] 37 | 38 | def to(self, device): 39 | for i, actor in enumerate(self.actors): 40 | self.actors[i] = actor.to(device) 41 | 42 | def eval(self): 43 | for actor in self.actors: 44 | actor.eval() 45 | 46 | def train(self): 47 | for actor in self.actors: 48 | actor.train() 49 | 50 | def save(self, path): 51 | for i, actor in enumerate(self.actors): 52 | actor_path = os.path.join(path, f"actor{i}.pt") 53 | torch.save(actor.state_dict(), actor_path) 54 | 55 | def load(self, path): 56 | for i, actor in enumerate(self.actors): 57 | actor_path = os.path.join(path, f"actor{i}.pt") 58 | actor.load_state_dict(torch.load(actor_path)) 59 | 60 | def forward(self, state, from_cpu=True): 61 | # evaluation forward: 62 | # take the average of the mean of each 63 | # actor's distribution. 64 | if from_cpu: 65 | state = self.process_state(state) 66 | self.eval() 67 | with torch.no_grad(): 68 | act = torch.stack( 69 | [actor.forward(state).mean for actor in self.actors], dim=0 70 | ).mean(0) 71 | self.train() 72 | if from_cpu: 73 | act = self.process_act(act) 74 | return act 75 | 76 | def process_state(self, state): 77 | return torch.from_numpy(state).unsqueeze(0).float().to(device) 78 | 79 | def process_act(self, act): 80 | return act.clamp(-1.0, 1.0).cpu().squeeze(0).numpy() 81 | 82 | 83 | def sbc( 84 | agent, 85 | buffer, 86 | test_env, 87 | num_steps_offline=1_000_000, 88 | batch_size=256, 89 | log_prob_clip=None, 90 | max_episode_steps=100_000, 91 | actor_lr=1e-4, 92 | eval_interval=5000, 93 | eval_episodes=10, 94 | actor_clip=None, 95 | actor_l2=0.0, 96 | save_interval=100_000, 97 | name="sbc_run", 98 | render=False, 99 | save_to_disk=True, 100 | log_to_disk=True, 101 | verbosity=0, 102 | **kwargs, 103 | ): 104 | """ 105 | Stochastic Behavioral Cloning 106 | 107 | A simple approach to offline RL that learns to emulate the 108 | behavior dataset in a supervised way. Uses the stochastic actor 109 | from SAC, and adds some basic ensembling to improve performance 110 | and make this a reasonable baseline. 
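Each ensemble member is trained on an independently sampled batch per step (see learn_sbc below), and at evaluation time SBCAgent.forward averages the mean of every member's action distribution.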
111 | 112 | For examples of how to set up and run offline RL in deep_control, 113 | see examples/d4rl/sbc.py 114 | """ 115 | if save_to_disk or log_to_disk: 116 | save_dir = utils.make_process_dirs(name) 117 | if log_to_disk: 118 | writer = tensorboardX.SummaryWriter(save_dir) 119 | writer.add_hparams(locals(), {}) 120 | 121 | ########### 122 | ## SETUP ## 123 | ########### 124 | agent.to(device) 125 | agent.train() 126 | 127 | actor_optimizer = torch.optim.Adam( 128 | chain(*(actor.parameters() for actor in agent.actors)), 129 | lr=actor_lr, 130 | weight_decay=actor_l2, 131 | betas=(0.9, 0.999), 132 | ) 133 | 134 | ################### 135 | ## TRAINING LOOP ## 136 | ################### 137 | 138 | steps_iter = range(num_steps_offline) 139 | if verbosity: 140 | steps_iter = tqdm.tqdm(steps_iter) 141 | for step in steps_iter: 142 | learn_sbc( 143 | buffer=buffer, 144 | agent=agent, 145 | batch_size=batch_size, 146 | actor_optimizer=actor_optimizer, 147 | actor_clip=actor_clip, 148 | log_prob_clip=log_prob_clip, 149 | ) 150 | 151 | if (step % eval_interval == 0) or (step == num_steps_offline - 1): 152 | mean_return = run.evaluate_agent( 153 | agent, test_env, eval_episodes, max_episode_steps, render 154 | ) 155 | if log_to_disk: 156 | writer.add_scalar("return", mean_return, step) 157 | 158 | if step % save_interval == 0 and save_to_disk: 159 | agent.save(save_dir) 160 | 161 | if save_to_disk: 162 | agent.save(save_dir) 163 | 164 | return agent 165 | 166 | 167 | def learn_sbc( 168 | buffer, 169 | agent, 170 | batch_size, 171 | actor_optimizer, 172 | actor_clip, 173 | log_prob_clip, 174 | ): 175 | agent.train() 176 | 177 | ############################# 178 | ## SUPERVISED ACTOR UPDATE ## 179 | ############################# 180 | 181 | actor_loss = 0.0 182 | for actor in agent.actors: 183 | # sample a fresh batch of data to keep the ensemble unique 184 | batch = buffer.sample(batch_size) 185 | state_batch, action_batch, *_ = batch 186 | state_batch = state_batch.to(device) 187 | action_batch = action_batch.to(device) 188 | 189 | # maximize the probability that the agent takes the demonstration's action in this state 190 | dist = actor(state_batch) 191 | logp_demo_act = dist.log_prob(action_batch).sum(-1, keepdim=True) 192 | if log_prob_clip: 193 | logp_demo_act = logp_demo_act.clamp(-log_prob_clip, log_prob_clip) 194 | actor_loss += -logp_demo_act.mean() 195 | 196 | # actor gradient step 197 | actor_optimizer.zero_grad() 198 | actor_loss.backward() 199 | if actor_clip: 200 | torch.nn.utils.clip_grad_norm_( 201 | chain(*(actor.parameters() for actor in agent.actors)), actor_clip 202 | ) 203 | actor_optimizer.step() 204 | 205 | 206 | def add_args(parser): 207 | parser.add_argument( 208 | "--num_steps_offline", 209 | type=int, 210 | default=10 ** 6, 211 | help="Number of offline training steps", 212 | ) 213 | parser.add_argument( 214 | "--max_episode_steps", 215 | type=int, 216 | default=100000, 217 | help="maximum steps per episode", 218 | ) 219 | parser.add_argument( 220 | "--batch_size", type=int, default=256, help="training batch size" 221 | ) 222 | parser.add_argument( 223 | "--actor_lr", type=float, default=3e-4, help="actor learning rate" 224 | ) 225 | parser.add_argument( 226 | "--eval_interval", 227 | type=int, 228 | default=5000, 229 | help="how often to test the agent without exploration (in episodes)", 230 | ) 231 | parser.add_argument( 232 | "--eval_episodes", 233 | type=int, 234 | default=10, 235 | help="how many episodes to run for when testing", 236 | ) 237 | parser.add_argument( 
238 | "--render", 239 | action="store_true", 240 | help="flag to enable env rendering during training", 241 | ) 242 | parser.add_argument( 243 | "--actor_clip", 244 | type=float, 245 | default=None, 246 | help="gradient clipping for actor updates", 247 | ) 248 | parser.add_argument( 249 | "--name", type=str, default="redq_run", help="dir name for saves" 250 | ) 251 | parser.add_argument( 252 | "--actor_l2", 253 | type=float, 254 | default=0.0, 255 | help="L2 regularization coeff for actor network", 256 | ) 257 | parser.add_argument( 258 | "--save_interval", 259 | type=int, 260 | default=100_000, 261 | help="How many steps to go between saving the agent params to disk", 262 | ) 263 | parser.add_argument( 264 | "--verbosity", 265 | type=int, 266 | default=1, 267 | help="verbosity > 0 displays a progress bar during training", 268 | ) 269 | parser.add_argument( 270 | "--skip_save_to_disk", 271 | action="store_true", 272 | help="flag to skip saving agent params to disk during training", 273 | ) 274 | parser.add_argument( 275 | "--skip_log_to_disk", 276 | action="store_true", 277 | help="flag to skip saving agent performance logs to disk during training", 278 | ) 279 | parser.add_argument( 280 | "--log_std_low", 281 | type=float, 282 | default=-10, 283 | help="Lower bound for log std of action distribution.", 284 | ) 285 | parser.add_argument( 286 | "--log_std_high", 287 | type=float, 288 | default=2, 289 | help="Upper bound for log std of action distribution.", 290 | ) 291 | parser.add_argument( 292 | "--ensemble_size", 293 | type=int, 294 | default=5, 295 | help="actor ensemble size", 296 | ) 297 | parser.add_argument( 298 | "--hidden_size", 299 | type=int, 300 | default=1024, 301 | help="actor network hidden dim", 302 | ) 303 | -------------------------------------------------------------------------------- /deep_control/td3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | from itertools import chain 5 | 6 | import numpy as np 7 | import tensorboardX 8 | import torch 9 | import torch.nn.functional as F 10 | import tqdm 11 | 12 | from . 
import envs, nets, replay, run, utils 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class TD3Agent: 18 | def __init__( 19 | self, 20 | obs_space_size, 21 | act_space_size, 22 | actor_net_cls=nets.BaselineActor, 23 | critic_net_cls=nets.BigCritic, 24 | hidden_size=256, 25 | ): 26 | self.actor = actor_net_cls( 27 | obs_space_size, act_space_size, hidden_size=hidden_size 28 | ) 29 | self.critic1 = critic_net_cls( 30 | obs_space_size, act_space_size, hidden_size=hidden_size 31 | ) 32 | self.critic2 = critic_net_cls( 33 | obs_space_size, act_space_size, hidden_size=hidden_size 34 | ) 35 | 36 | def to(self, device): 37 | self.actor = self.actor.to(device) 38 | self.critic1 = self.critic1.to(device) 39 | self.critic2 = self.critic2.to(device) 40 | 41 | def eval(self): 42 | self.actor.eval() 43 | self.critic1.eval() 44 | self.critic2.eval() 45 | 46 | def train(self): 47 | self.actor.train() 48 | self.critic1.train() 49 | self.critic2.train() 50 | 51 | def save(self, path): 52 | actor_path = os.path.join(path, "actor.pt") 53 | critic1_path = os.path.join(path, "critic1.pt") 54 | critic2_path = os.path.join(path, "critic2.pt") 55 | torch.save(self.actor.state_dict(), actor_path) 56 | torch.save(self.critic1.state_dict(), critic1_path) 57 | torch.save(self.critic2.state_dict(), critic2_path) 58 | 59 | def load(self, path): 60 | actor_path = os.path.join(path, "actor.pt") 61 | critic1_path = os.path.join(path, "critic1.pt") 62 | critic2_path = os.path.join(path, "critic2.pt") 63 | self.actor.load_state_dict(torch.load(actor_path)) 64 | self.critic1.load_state_dict(torch.load(critic1_path)) 65 | self.critic2.load_state_dict(torch.load(critic2_path)) 66 | 67 | def forward(self, state, from_cpu=True): 68 | if from_cpu: 69 | state = self.process_state(state) 70 | self.actor.eval() 71 | with torch.no_grad(): 72 | action = self.actor(state) 73 | self.actor.train() 74 | if from_cpu: 75 | action = self.process_act(action) 76 | return action 77 | 78 | def process_state(self, state): 79 | return torch.from_numpy(np.expand_dims(state, 0).astype(np.float32)).to( 80 | utils.device 81 | ) 82 | 83 | def process_act(self, act): 84 | return np.squeeze(act.cpu().numpy(), 0) 85 | 86 | 87 | def td3( 88 | agent, 89 | train_env, 90 | test_env, 91 | buffer, 92 | num_steps=1_000_000, 93 | transitions_per_step=1, 94 | max_episode_steps=100_000, 95 | batch_size=256, 96 | tau=0.005, 97 | actor_lr=1e-4, 98 | critic_lr=1e-3, 99 | gamma=0.99, 100 | sigma_start=0.2, 101 | sigma_final=0.1, 102 | sigma_anneal=100_000, 103 | theta=0.15, 104 | eval_interval=5000, 105 | eval_episodes=10, 106 | warmup_steps=1000, 107 | actor_clip=None, 108 | critic_clip=None, 109 | actor_l2=0.0, 110 | critic_l2=0.0, 111 | delay=2, 112 | target_noise_scale=0.2, 113 | save_interval=100_000, 114 | c=0.5, 115 | name="td3_run", 116 | render=False, 117 | save_to_disk=True, 118 | log_to_disk=True, 119 | verbosity=0, 120 | gradient_updates_per_step=1, 121 | td_reg_coeff=0.0, 122 | td_reg_coeff_decay=0.9999, 123 | infinite_bootstrap=True, 124 | **_, 125 | ): 126 | """ 127 | Train `agent` on `train_env` with Twin Delayed Deep Deterministic Policy 128 | Gradient algorithm, and evaluate on `test_env`. 
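The learn() update below follows the paper's three components: clipped double Q-learning (the torch.min over both target critics), target policy smoothing (Gaussian noise scaled by target_noise_scale and clipped to +/- c), and delayed actor/target updates (the delay argument). The optional td_reg_coeff term adds the TD-regularization penalty to the actor loss.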
129 | 130 | Reference: https://arxiv.org/abs/1802.09477 131 | """ 132 | if save_to_disk or log_to_disk: 133 | save_dir = utils.make_process_dirs(name) 134 | if log_to_disk: 135 | # create tb writer, save hparams 136 | writer = tensorboardX.SummaryWriter(save_dir) 137 | writer.add_hparams(locals(), {}) 138 | 139 | agent.to(device) 140 | 141 | # initialize target networks 142 | target_agent = copy.deepcopy(agent) 143 | target_agent.to(device) 144 | utils.hard_update(target_agent.actor, agent.actor) 145 | utils.hard_update(target_agent.critic1, agent.critic1) 146 | utils.hard_update(target_agent.critic2, agent.critic2) 147 | 148 | random_process = utils.OrnsteinUhlenbeckProcess( 149 | size=train_env.action_space.shape, 150 | sigma=sigma_start, 151 | sigma_min=sigma_final, 152 | n_steps_annealing=sigma_anneal, 153 | theta=theta, 154 | ) 155 | 156 | # set up optimizers 157 | critic_optimizer = torch.optim.Adam( 158 | chain( 159 | agent.critic1.parameters(), 160 | agent.critic2.parameters(), 161 | ), 162 | lr=critic_lr, 163 | weight_decay=critic_l2, 164 | betas=(0.9, 0.999), 165 | ) 166 | actor_optimizer = torch.optim.Adam( 167 | agent.actor.parameters(), lr=actor_lr, weight_decay=actor_l2 168 | ) 169 | 170 | run.warmup_buffer(buffer, train_env, warmup_steps, max_episode_steps) 171 | 172 | done = True 173 | 174 | steps_iter = range(num_steps) 175 | if verbosity: 176 | steps_iter = tqdm.tqdm(steps_iter) 177 | 178 | for step in steps_iter: 179 | for _ in range(transitions_per_step): 180 | if done: 181 | state = train_env.reset() 182 | random_process.reset_states() 183 | steps_this_ep = 0 184 | done = False 185 | action = agent.forward(state) 186 | noisy_action = run.exploration_noise(action, random_process) 187 | next_state, reward, done, info = train_env.step(noisy_action) 188 | if infinite_bootstrap: 189 | # allow infinite bootstrapping 190 | if steps_this_ep + 1 == max_episode_steps: 191 | done = False 192 | buffer.push(state, noisy_action, reward, next_state, done) 193 | state = next_state 194 | steps_this_ep += 1 195 | if steps_this_ep >= max_episode_steps: 196 | done = True 197 | 198 | update_policy = step % delay == 0 199 | for _ in range(gradient_updates_per_step): 200 | learn( 201 | buffer=buffer, 202 | target_agent=target_agent, 203 | agent=agent, 204 | actor_optimizer=actor_optimizer, 205 | critic_optimizer=critic_optimizer, 206 | batch_size=batch_size, 207 | target_noise_scale=target_noise_scale, 208 | c=c, 209 | gamma=gamma, 210 | critic_clip=critic_clip, 211 | actor_clip=actor_clip, 212 | td_reg_coeff=td_reg_coeff, 213 | update_policy=update_policy, 214 | ) 215 | 216 | # move target model towards training model 217 | if update_policy: 218 | utils.soft_update(target_agent.actor, agent.actor, tau) 219 | # original td3 impl only updates critic targets with the actor... 
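# (soft_update applies Polyak averaging: target_param <- (1 - tau) * target_param + tau * param; see utils.soft_update)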
220 | utils.soft_update(target_agent.critic1, agent.critic1, tau) 221 | utils.soft_update(target_agent.critic2, agent.critic2, tau) 222 | 223 | # decay td regularization 224 | td_reg_coeff *= td_reg_coeff_decay 225 | 226 | if step % eval_interval == 0 or step == num_steps - 1: 227 | mean_return = run.evaluate_agent( 228 | agent, test_env, eval_episodes, max_episode_steps, render 229 | ) 230 | if log_to_disk: 231 | writer.add_scalar("return", mean_return, step * transitions_per_step) 232 | 233 | if step % save_interval == 0 and save_to_disk: 234 | agent.save(save_dir) 235 | 236 | if save_to_disk: 237 | agent.save(save_dir) 238 | return agent 239 | 240 | 241 | def learn( 242 | buffer, 243 | target_agent, 244 | agent, 245 | actor_optimizer, 246 | critic_optimizer, 247 | batch_size, 248 | target_noise_scale, 249 | c, 250 | gamma, 251 | critic_clip, 252 | actor_clip, 253 | td_reg_coeff, 254 | update_policy=True, 255 | ): 256 | 257 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 258 | if per: 259 | batch, imp_weights, priority_idxs = buffer.sample(batch_size) 260 | imp_weights = imp_weights.to(device) 261 | else: 262 | batch = buffer.sample(batch_size) 263 | 264 | # prepare transitions for models 265 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = batch 266 | state_batch = state_batch.to(device) 267 | next_state_batch = next_state_batch.to(device) 268 | action_batch = action_batch.to(device) 269 | reward_batch = reward_batch.to(device) 270 | done_batch = done_batch.to(device) 271 | 272 | agent.train() 273 | 274 | with torch.no_grad(): 275 | # create critic targets (clipped double Q learning) 276 | target_action_s1 = target_agent.actor(next_state_batch) 277 | target_noise = torch.clamp( 278 | target_noise_scale * torch.randn(*target_action_s1.shape).to(device), -c, c 279 | ) 280 | # target smoothing 281 | target_action_s1 = torch.clamp( 282 | target_action_s1 + target_noise, 283 | -1.0, 284 | 1.0, 285 | ) 286 | target_action_value_s1 = torch.min( 287 | target_agent.critic1(next_state_batch, target_action_s1), 288 | target_agent.critic2(next_state_batch, target_action_s1), 289 | ) 290 | td_target = reward_batch + gamma * (1.0 - done_batch) * target_action_value_s1 291 | 292 | # update critics 293 | agent_critic1_pred = agent.critic1(state_batch, action_batch) 294 | td_error1 = td_target - agent_critic1_pred 295 | if per: 296 | critic1_loss = (imp_weights * 0.5 * (td_error1 ** 2)).mean() 297 | else: 298 | critic1_loss = 0.5 * (td_error1 ** 2).mean() 299 | agent_critic2_pred = agent.critic2(state_batch, action_batch) 300 | td_error2 = td_target - agent_critic2_pred 301 | if per: 302 | critic2_loss = (imp_weights * 0.5 * (td_error2 ** 2)).mean() 303 | else: 304 | critic2_loss = 0.5 * (td_error2 ** 2).mean() 305 | critic_loss = critic1_loss + critic2_loss 306 | critic_optimizer.zero_grad() 307 | critic_loss.backward() 308 | if critic_clip: 309 | torch.nn.utils.clip_grad_norm_( 310 | chain(agent.critic1.parameters(), agent.critic2.parameters()), critic_clip 311 | ) 312 | critic_optimizer.step() 313 | 314 | if update_policy: 315 | # actor update 316 | agent_actions = agent.actor(state_batch) 317 | actor_loss = -( 318 | agent.critic1(state_batch, agent_actions).mean() 319 | - td_reg_coeff * critic_loss.detach() 320 | ) 321 | actor_optimizer.zero_grad() 322 | actor_loss.backward() 323 | if actor_clip: 324 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip) 325 | actor_optimizer.step() 326 | 327 | if per: 328 | new_priorities = (abs(td_error1) + 
1e-5).cpu().detach().squeeze(1).numpy() 329 | buffer.update_priorities(priority_idxs, new_priorities) 330 | 331 | 332 | def add_args(parser): 333 | parser.add_argument( 334 | "--num_steps", 335 | type=int, 336 | default=10 ** 6, 337 | help="number of training steps", 338 | ) 339 | parser.add_argument( 340 | "--transitions_per_step", 341 | type=int, 342 | default=1, 343 | help="env transitions collected per training step. Defaults to 1, in which case we're training for num_steps total env steps. But when looking for replay ratios < 1, this value will need to be set higher.", 344 | ) 345 | parser.add_argument( 346 | "--max_episode_steps", 347 | type=int, 348 | default=100000, 349 | help="maximum steps per episode", 350 | ) 351 | parser.add_argument( 352 | "--batch_size", type=int, default=256, help="training batch size" 353 | ) 354 | parser.add_argument( 355 | "--tau", type=float, default=0.005, help="for model parameter % update" 356 | ) 357 | parser.add_argument( 358 | "--actor_lr", type=float, default=1e-4, help="actor learning rate" 359 | ) 360 | parser.add_argument( 361 | "--critic_lr", type=float, default=1e-3, help="critic learning rate" 362 | ) 363 | parser.add_argument( 364 | "--gamma", type=float, default=0.99, help="gamma, the discount factor" 365 | ) 366 | parser.add_argument("--sigma_final", type=float, default=0.1) 367 | parser.add_argument( 368 | "--sigma_anneal", 369 | type=float, 370 | default=100_000, 371 | help="How many steps to anneal sigma over.", 372 | ) 373 | parser.add_argument( 374 | "--theta", 375 | type=float, 376 | default=0.15, 377 | help="theta for Ornstein Uhlenbeck process computation", 378 | ) 379 | parser.add_argument( 380 | "--sigma_start", 381 | type=float, 382 | default=0.2, 383 | help="sigma for Ornstein Uhlenbeck process computation", 384 | ) 385 | parser.add_argument( 386 | "--eval_interval", 387 | type=int, 388 | default=5000, 389 | help="how often to test the agent without exploration (in episodes)", 390 | ) 391 | parser.add_argument( 392 | "--eval_episodes", 393 | type=int, 394 | default=10, 395 | help="how many episodes to run for when testing", 396 | ) 397 | parser.add_argument( 398 | "--warmup_steps", type=int, default=1000, help="warmup length, in steps" 399 | ) 400 | parser.add_argument("--render", action="store_true") 401 | parser.add_argument("--actor_clip", type=float, default=None) 402 | parser.add_argument("--critic_clip", type=float, default=None) 403 | parser.add_argument("--name", type=str, default="td3_run") 404 | parser.add_argument("--actor_l2", type=float, default=0.0) 405 | parser.add_argument("--critic_l2", type=float, default=0.0) 406 | parser.add_argument("--delay", type=int, default=2) 407 | parser.add_argument("--target_noise_scale", type=float, default=0.2) 408 | parser.add_argument("--save_interval", type=int, default=100_000) 409 | parser.add_argument("--c", type=float, default=0.5) 410 | parser.add_argument("--verbosity", type=int, default=1) 411 | parser.add_argument("--gradient_updates_per_step", type=int, default=1) 412 | parser.add_argument("--prioritized_replay", action="store_true") 413 | parser.add_argument("--buffer_size", type=int, default=1_000_000) 414 | parser.add_argument("--skip_save_to_disk", action="store_true") 415 | parser.add_argument("--skip_log_to_disk", action="store_true") 416 | parser.add_argument("--td_reg_coeff", type=float, default=0.0) 417 | parser.add_argument("--td_reg_coeff_decay", type=float, default=0.9999) 418 | 
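A minimal usage sketch of the TD3 pieces above, patterned on `examples/basic_control/td3_gym.py`. The `TD3Agent`, `ReplayBuffer`, and `td3()` signatures come from this file and `replay.py`; the environment id, seed, and the `dc.td3` / `dc.replay` module paths are assumptions for illustration:

import deep_control as dc

train_env = dc.envs.load_gym("Pendulum-v0", 231)  # hypothetical env id and seed
test_env = dc.envs.load_gym("Pendulum-v0", 231)

agent = dc.td3.TD3Agent(
    train_env.observation_space.shape[0], train_env.action_space.shape[0]
)
buffer = dc.replay.ReplayBuffer(
    size=1_000_000,
    state_shape=train_env.observation_space.shape,
    state_dtype=float,
    action_shape=train_env.action_space.shape,
)
dc.td3.td3(
    agent=agent,
    train_env=train_env,
    test_env=test_env,
    buffer=buffer,
    num_steps=100_000,
    max_episode_steps=200,
)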
-------------------------------------------------------------------------------- /deep_control/tsr_caql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import os 5 | from itertools import chain 6 | 7 | import numpy as np 8 | import tensorboardX 9 | import torch 10 | import torch.nn.functional as F 11 | import tqdm 12 | 13 | from . import envs, nets, replay, run, utils, critic_searchers, grac 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | TSR_CAQLAgent = grac.GRACAgent 18 | 19 | 20 | def tsr_caql( 21 | agent, 22 | buffer, 23 | train_env, 24 | test_env, 25 | num_steps=1_000_000, 26 | transitions_per_step=1, 27 | max_critic_updates_per_step=20, 28 | critic_target_improvement_init=0.7, 29 | critic_target_improvement_final=0.9, 30 | gamma=0.99, 31 | batch_size=512, 32 | actor_lr=1e-4, 33 | critic_lr=1e-4, 34 | eval_interval=5000, 35 | eval_episodes=10, 36 | warmup_steps=1000, 37 | actor_clip=None, 38 | critic_clip=None, 39 | name="tsr_caql_run", 40 | max_episode_steps=100_000, 41 | render=False, 42 | save_interval=100_000, 43 | verbosity=0, 44 | critic_l2=0.0, 45 | actor_l2=0.0, 46 | log_to_disk=True, 47 | save_to_disk=True, 48 | debug_logs=False, 49 | infinite_bootstrap=True, 50 | **kwargs, 51 | ): 52 | if save_to_disk or log_to_disk: 53 | save_dir = utils.make_process_dirs(name) 54 | if log_to_disk: 55 | # create tb writer, save hparams 56 | writer = tensorboardX.SummaryWriter(save_dir) 57 | writer.add_hparams(locals(), {}) 58 | 59 | # no target networks! 60 | agent.to(device) 61 | agent.cem.batch_size = batch_size 62 | agent.train() 63 | 64 | # the critic target improvement ratio is annealed during training 65 | critic_target_imp_slope = ( 66 | critic_target_improvement_final - critic_target_improvement_init 67 | ) / num_steps 68 | current_target_imp = lambda step: min( 69 | critic_target_improvement_init + critic_target_imp_slope * step, 70 | critic_target_improvement_final, 71 | ) 72 | 73 | # set up optimizers 74 | critic_optimizer = torch.optim.Adam( 75 | chain( 76 | agent.critic1.parameters(), 77 | agent.critic2.parameters(), 78 | ), 79 | lr=critic_lr, 80 | weight_decay=critic_l2, 81 | betas=(0.9, 0.999), 82 | ) 83 | actor_optimizer = torch.optim.Adam( 84 | agent.actor.parameters(), 85 | lr=actor_lr, 86 | weight_decay=actor_l2, 87 | betas=(0.9, 0.999), 88 | ) 89 | 90 | # warmup the replay buffer with random actions 91 | run.warmup_buffer(buffer, train_env, warmup_steps, max_episode_steps) 92 | 93 | steps_iter = range(num_steps) 94 | if verbosity: 95 | steps_iter = tqdm.tqdm(steps_iter) 96 | 97 | done = True 98 | for step in steps_iter: 99 | # collect experience 100 | for _ in range(transitions_per_step): 101 | if done: 102 | state = train_env.reset() 103 | steps_this_ep = 0 104 | done = False 105 | action = agent.sample_action(state) 106 | next_state, reward, done, info = train_env.step(action) 107 | if infinite_bootstrap and (steps_this_ep + 1 == max_episode_steps): 108 | # allow infinite bootstrapping 109 | done = False 110 | buffer.push(state, action, reward, next_state, done) 111 | state = next_state 112 | steps_this_ep += 1 113 | if steps_this_ep >= max_episode_steps: 114 | done = True 115 | 116 | learning_info = learn( 117 | buffer=buffer, 118 | agent=agent, 119 | actor_optimizer=actor_optimizer, 120 | critic_optimizer=critic_optimizer, 121 | target_entropy=-train_env.action_space.shape[0], 122 | critic_target_improvement=current_target_imp(step), 
123 | max_critic_updates_per_step=max_critic_updates_per_step, 124 | batch_size=batch_size, 125 | gamma=gamma, 126 | critic_clip=critic_clip, 127 | actor_clip=actor_clip, 128 | ) 129 | 130 | if debug_logs: 131 | for key, val in learning_info.items(): 132 | writer.add_scalar(key, val.item(), step * transitions_per_step) 133 | 134 | if step % eval_interval == 0 or step == num_steps - 1: 135 | mean_return = run.evaluate_agent( 136 | agent, test_env, eval_episodes, max_episode_steps, render 137 | ) 138 | if log_to_disk: 139 | writer.add_scalar("return", mean_return, step * transitions_per_step) 140 | 141 | if step % save_interval == 0 and save_to_disk: 142 | agent.save(save_dir) 143 | 144 | if save_to_disk: 145 | agent.save(save_dir) 146 | return agent 147 | 148 | 149 | def learn( 150 | buffer, 151 | agent, 152 | actor_optimizer, 153 | critic_optimizer, 154 | target_entropy, 155 | critic_target_improvement, 156 | max_critic_updates_per_step, 157 | batch_size, 158 | gamma, 159 | critic_clip, 160 | actor_clip, 161 | ): 162 | per = isinstance(buffer, replay.PrioritizedReplayBuffer) 163 | if per: 164 | batch, imp_weights, priority_idxs = buffer.sample(batch_size) 165 | imp_weights = imp_weights.to(device) 166 | else: 167 | batch = buffer.sample(batch_size) 168 | 169 | # prepare transitions for models 170 | state_batch, action_batch, reward_batch, next_state_batch, done_batch = batch 171 | state_batch = state_batch.to(device) 172 | next_state_batch = next_state_batch.to(device) 173 | action_batch = action_batch.to(device) 174 | reward_batch = reward_batch.to(device) 175 | done_batch = done_batch.to(device) 176 | 177 | agent.train() 178 | 179 | def min_and_argmin(x, y, x_args, y_args): 180 | min_ = torch.min(x, y) 181 | use_x_mask = (x <= y).squeeze(1) 182 | argmin = y_args.clone() 183 | argmin[use_x_mask] = x_args[use_x_mask] 184 | return min_, argmin 185 | 186 | def max_and_argmax(x, y, x_args, y_args): 187 | max_ = torch.max(x, y) 188 | use_x_mask = (x >= y).squeeze(1) 189 | argmax = y_args.clone() 190 | argmax[use_x_mask] = x_args[use_x_mask] 191 | return max_, argmax 192 | 193 | ################### 194 | ## CRITIC UPDATE ## 195 | ################### 196 | with torch.no_grad(): 197 | # sample an action as normal 198 | action_dist_s1 = agent.actor(next_state_batch) 199 | action_s1 = action_dist_s1.sample() 200 | action_value_s1_q1 = agent.critic1(next_state_batch, action_s1) 201 | action_value_s1_q2 = agent.critic2(next_state_batch, action_s1) 202 | 203 | # use CEM to find a higher value action 204 | cem_actions_s1_q1 = agent.cem.search(next_state_batch, action_s1, agent.critic1) 205 | cem_action_value_s1_q1 = agent.critic1(next_state_batch, cem_actions_s1_q1) 206 | cem_actions_s1_q2 = agent.cem.search(next_state_batch, action_s1, agent.critic2) 207 | cem_action_value_s1_q2 = agent.critic2(next_state_batch, cem_actions_s1_q2) 208 | best_q1, best_actions_q1 = max_and_argmax( 209 | action_value_s1_q1, cem_action_value_s1_q1, action_s1, cem_actions_s1_q1 210 | ) 211 | best_q2, best_actions_q2 = max_and_argmax( 212 | action_value_s1_q2, cem_action_value_s1_q2, action_s1, cem_actions_s1_q2 213 | ) 214 | clipped_double_q_s1, final_actions_s1 = min_and_argmin( 215 | best_q1, best_q2, best_actions_q1, best_actions_q2 216 | ) 217 | td_target = reward_batch + gamma * (1.0 - done_batch) * clipped_double_q_s1 218 | y1 = agent.critic1(next_state_batch, final_actions_s1) 219 | y2 = agent.critic2(next_state_batch, final_actions_s1) 220 | 221 | learning_info = { 222 | "td_target": td_target.mean(), 223 | 
"clip_double_q_s1_mean": clipped_double_q_s1.mean(), 224 | } 225 | 226 | # update critics 227 | critic_loss_initial = None 228 | for critic_update in range(max_critic_updates_per_step): 229 | # standard bellman error 230 | a_critic1_pred = agent.critic1(state_batch, action_batch) 231 | a_critic2_pred = agent.critic2(state_batch, action_batch) 232 | td_error1 = td_target - a_critic1_pred 233 | td_error2 = td_target - a_critic2_pred 234 | 235 | # constraints that discourage large changes in Q(s_{t+1}, a_{t+1}), 236 | a1_critic1_pred = agent.critic1(next_state_batch, final_actions_s1) 237 | a1_critic2_pred = agent.critic2(next_state_batch, final_actions_s1) 238 | a1_constraint1 = y1 - a1_critic1_pred 239 | a1_constraint2 = y2 - a1_critic2_pred 240 | 241 | elementwise_critic_loss = ( 242 | (td_error1 ** 2) 243 | + (td_error2 ** 2) 244 | + (a1_constraint1 ** 2) 245 | + (a1_constraint2 ** 2) 246 | ) 247 | if per: 248 | elementwise_loss *= imp_weights 249 | critic_loss = 0.5 * elementwise_critic_loss.mean() 250 | critic_optimizer.zero_grad() 251 | critic_loss.backward() 252 | if critic_clip: 253 | torch.nn.utils.clip_grad_norm_( 254 | chain(agent.critic1.parameters(), agent.critic2.parameters()), 255 | critic_clip, 256 | ) 257 | critic_optimizer.step() 258 | if critic_update == 0: 259 | critic_loss_initial = critic_loss 260 | elif critic_loss <= critic_target_improvement * critic_loss_initial: 261 | break 262 | 263 | ################## 264 | ## ACTOR UPDATE ## 265 | ################## 266 | # get agent's actions in this state 267 | dist = agent.actor(state_batch) 268 | agent_actions = dist.rsample() 269 | logp_a = dist.log_prob(agent_actions).sum(-1, keepdim=True) 270 | with torch.no_grad(): 271 | agent_action_value = torch.min( 272 | agent.critic1(state_batch, agent_actions), 273 | agent.critic2(state_batch, agent_actions), 274 | ) 275 | # find higher-value actions with CEM 276 | cem_actions_q1 = agent.cem.search(state_batch, agent_actions, agent.critic1) 277 | cem_action_value_q1 = agent.critic1(state_batch, cem_actions_q1) 278 | cem_actions_q2 = agent.cem.search(state_batch, agent_actions, agent.critic2) 279 | cem_action_value_q2 = agent.critic2(state_batch, cem_actions_q2) 280 | cem_action_value, cem_actions = min_and_argmin( 281 | cem_action_value_q1, cem_action_value_q2, cem_actions_q1, cem_actions_q2 282 | ) 283 | logp_cema = dist.log_prob(cem_actions).sum(-1, keepdim=True) 284 | 285 | # how much better are the CEM actions than the agent's? 286 | # clipped for rare cases where CEM actually finds a worse action... 
287 | cem_advantage = F.relu(cem_action_value - agent_action_value).detach() 288 | actor_loss = -(cem_advantage * logp_cema).mean() 289 | actor_optimizer.zero_grad() 290 | actor_loss.backward() 291 | if actor_clip: 292 | torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip) 293 | actor_optimizer.step() 294 | learning_info.update( 295 | { 296 | "cem_adv": cem_advantage.mean(), 297 | "actor_loss": actor_loss, 298 | "logp_a": logp_a.mean(), 299 | "logp_cema": logp_cema.mean(), 300 | "agent_action_value": agent_action_value.mean(), 301 | "cem_action_value": cem_action_value.mean(), 302 | } 303 | ) 304 | 305 | if per: 306 | new_priorities = (abs(td_error1) + 1e-5).cpu().detach().squeeze(1).numpy() 307 | buffer.update_priorities(priority_idxs, new_priorities) 308 | 309 | return learning_info 310 | 311 | 312 | def add_args(parser): 313 | parser.add_argument( 314 | "--num_steps", type=int, default=10 ** 6, help="number of steps in training" 315 | ) 316 | parser.add_argument( 317 | "--transitions_per_step", 318 | type=int, 319 | default=1, 320 | help="env transitions per training step. Defaults to 1, but will need to \ 321 | be set higher for repaly ratios < 1", 322 | ) 323 | parser.add_argument( 324 | "--max_episode_steps", 325 | type=int, 326 | default=100000, 327 | help="maximum steps per episode", 328 | ) 329 | parser.add_argument( 330 | "--batch_size", type=int, default=512, help="training batch size" 331 | ) 332 | parser.add_argument( 333 | "--actor_lr", type=float, default=1e-4, help="actor learning rate" 334 | ) 335 | parser.add_argument( 336 | "--critic_lr", type=float, default=1e-4, help="critic learning rate" 337 | ) 338 | parser.add_argument( 339 | "--gamma", type=float, default=0.99, help="gamma, the discount factor" 340 | ) 341 | parser.add_argument( 342 | "--buffer_size", type=int, default=1_000_000, help="replay buffer size" 343 | ) 344 | parser.add_argument( 345 | "--eval_interval", 346 | type=int, 347 | default=5000, 348 | help="how often to test the agent without exploration (in episodes)", 349 | ) 350 | parser.add_argument( 351 | "--eval_episodes", 352 | type=int, 353 | default=10, 354 | help="how many episodes to run for when testing", 355 | ) 356 | parser.add_argument( 357 | "--warmup_steps", type=int, default=1000, help="warmup length, in steps" 358 | ) 359 | parser.add_argument( 360 | "--render", 361 | action="store_true", 362 | help="flag to enable env rendering during training", 363 | ) 364 | parser.add_argument( 365 | "--actor_clip", 366 | type=float, 367 | default=None, 368 | help="gradient clipping for actor updates", 369 | ) 370 | parser.add_argument( 371 | "--critic_clip", 372 | type=float, 373 | default=None, 374 | help="gradient clipping for critic updates", 375 | ) 376 | parser.add_argument( 377 | "--name", type=str, default="tsr_caql_run", help="dir name for saves" 378 | ) 379 | parser.add_argument( 380 | "--actor_l2", 381 | type=float, 382 | default=0.0, 383 | help="L2 regularization coeff for actor network", 384 | ) 385 | parser.add_argument( 386 | "--critic_l2", 387 | type=float, 388 | default=0.0, 389 | help="L2 regularization coeff for critic network", 390 | ) 391 | parser.add_argument( 392 | "--save_interval", 393 | type=int, 394 | default=100_000, 395 | help="How many steps to go between saving the agent params to disk", 396 | ) 397 | parser.add_argument( 398 | "--verbosity", 399 | type=int, 400 | default=1, 401 | help="verbosity > 0 displays a progress bar during training", 402 | ) 403 | parser.add_argument( 404 | 
"--max_critic_updates_per_step", 405 | type=int, 406 | default=10, 407 | help="Max critic updates to make per step. The GRAC paper calls this K", 408 | ) 409 | parser.add_argument( 410 | "--prioritized_replay", 411 | action="store_true", 412 | help="flag that enables use of prioritized experience replay", 413 | ) 414 | parser.add_argument( 415 | "--skip_save_to_disk", 416 | action="store_true", 417 | help="flag to skip saving agent params to disk during training", 418 | ) 419 | parser.add_argument( 420 | "--skip_log_to_disk", 421 | action="store_true", 422 | help="flag to skip saving agent performance logs to disk during training", 423 | ) 424 | parser.add_argument( 425 | "--log_std_low", 426 | type=float, 427 | default=-10, 428 | help="Lower bound for log std of action distribution.", 429 | ) 430 | parser.add_argument( 431 | "--log_std_high", 432 | type=float, 433 | default=2, 434 | help="Upper bound for log std of action distribution.", 435 | ) 436 | parser.add_argument( 437 | "--critic_target_improvement_init", 438 | type=float, 439 | default=0.7, 440 | help="Stop critic updates when loss drops by this factor. The GRAC paper calls this alpha", 441 | ) 442 | parser.add_argument( 443 | "--critic_target_improvement_final", 444 | type=float, 445 | default=0.9, 446 | help="Stop critic updates when loss drops by this factor. The GRAC paper calls this alpha", 447 | ) 448 | parser.add_argument( 449 | "--debug_logs", 450 | action="store_true", 451 | ) 452 | -------------------------------------------------------------------------------- /deep_control/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from collections import namedtuple 5 | 6 | import gym 7 | import numpy as np 8 | import torch 9 | 10 | from . import run 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | def clean_hparams_dict(hparams_dict): 16 | return {key: val for key, val in hparams_dict.items() if val} 17 | 18 | 19 | def get_grad_norm(model): 20 | total_norm = 0.0 21 | for p in model.parameters(): 22 | try: 23 | param = p.grad.data 24 | except AttributeError: 25 | continue 26 | else: 27 | param_norm = param.norm(2) 28 | total_norm += param_norm.item() ** 2 29 | total_norm = total_norm ** (1.0 / 2) 30 | return total_norm 31 | 32 | 33 | def torch_and_pad(x): 34 | if not isinstance(x, np.ndarray): 35 | x = np.array(x) 36 | return torch.from_numpy(x.astype(np.float32)).unsqueeze(0) 37 | 38 | 39 | def mean(lst): 40 | return float(sum(lst)) / len(lst) 41 | 42 | 43 | def make_process_dirs(run_name, base_path="dc_saves"): 44 | base_dir = os.path.join(base_path, run_name) 45 | i = 0 46 | while os.path.exists(base_dir + f"_{i}"): 47 | i += 1 48 | base_dir += f"_{i}" 49 | os.makedirs(base_dir) 50 | return base_dir 51 | 52 | 53 | def compute_conv_output( 54 | inp_shape, kernel_size, padding=(0, 0), dilation=(1, 1), stride=(1, 1) 55 | ): 56 | """ 57 | Compute the shape of the output of a torch Conv2d layer using 58 | the formula from the docs. 59 | 60 | every argument is a tuple corresponding to (height, width), e.g. 
kernel_size=(3, 4) 61 | """ 62 | height_out = math.floor( 63 | ( 64 | (inp_shape[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) 65 | / stride[0] 66 | ) 67 | + 1 68 | ) 69 | width_out = math.floor( 70 | ( 71 | (inp_shape[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) 72 | / stride[1] 73 | ) 74 | + 1 75 | ) 76 | return height_out, width_out 77 | 78 | 79 | def soft_update(target, source, tau): 80 | for target_param, param in zip(target.parameters(), source.parameters()): 81 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 82 | 83 | 84 | def hard_update(target, source): 85 | for target_param, param in zip(target.parameters(), source.parameters()): 86 | target_param.data.copy_(param.data) 87 | 88 | 89 | """ This is all from: https://github.com/matthiasplappert/keras-rl/blob/master/rl/random.py """ 90 | 91 | 92 | class AnnealedGaussianProcess: 93 | def __init__(self, mu, sigma, sigma_min, n_steps_annealing): 94 | self.mu = mu 95 | self.sigma = sigma 96 | self.n_steps = 0 97 | 98 | if sigma_min is not None: 99 | self.m = -float(sigma - sigma_min) / float(n_steps_annealing) 100 | self.c = sigma 101 | self.sigma_min = sigma_min 102 | else: 103 | self.m = 0.0 104 | self.c = sigma 105 | self.sigma_min = sigma 106 | 107 | @property 108 | def current_sigma(self): 109 | sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c) 110 | return sigma 111 | 112 | 113 | class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess): 114 | def __init__( 115 | self, 116 | theta, 117 | mu=0.0, 118 | sigma=1.0, 119 | dt=1e-2, 120 | x0=None, 121 | size=1, 122 | sigma_min=None, 123 | n_steps_annealing=1000, 124 | ): 125 | super(OrnsteinUhlenbeckProcess, self).__init__( 126 | mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing 127 | ) 128 | self.theta = theta 129 | self.mu = mu 130 | self.dt = dt 131 | self.x0 = x0 132 | self.size = size 133 | self.reset_states() 134 | 135 | def sample(self): 136 | x = ( 137 | self.x_prev 138 | + self.theta * (self.mu - self.x_prev) * self.dt 139 | + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size) 140 | ) 141 | self.x_prev = x 142 | self.n_steps += 1 143 | return x 144 | 145 | def reset_states(self): 146 | self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size) 147 | 148 | 149 | class GaussianExplorationNoise: 150 | def __init__(self, size, start_scale=1.0, final_scale=0.1, steps_annealed=1000): 151 | assert start_scale >= final_scale 152 | self.size = size 153 | self.start_scale = start_scale 154 | self.final_scale = final_scale 155 | self.steps_annealed = steps_annealed 156 | self._current_scale = start_scale 157 | self._scale_slope = (start_scale - final_scale) / steps_annealed 158 | 159 | def sample(self): 160 | noise = self._current_scale * torch.randn(*self.size) 161 | self._current_scale = max( 162 | self._current_scale - self._scale_slope, self.final_scale 163 | ) 164 | return noise.numpy() 165 | 166 | def reset_states(self): 167 | pass 168 | -------------------------------------------------------------------------------- /examples/basic_control/ddpg_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_gym_ddpg(args): 7 | # same training and testing seed 8 | train_env = dc.envs.load_gym(args.env_id, args.seed) 9 | test_env = dc.envs.load_gym(args.env_id, args.seed) 10 | 11 | state_space = train_env.observation_space 12 | action_space = 
train_env.action_space 13 | 14 | # create agent 15 | agent = dc.ddpg.DDPGAgent(state_space.shape[0], action_space.shape[0]) 16 | 17 | # create replay buffer 18 | if args.prioritized_replay: 19 | buffer_type = dc.replay.PrioritizedReplayBuffer 20 | else: 21 | buffer_type = dc.replay.ReplayBuffer 22 | buffer = buffer_type( 23 | args.buffer_size, 24 | state_shape=state_space.shape, 25 | state_dtype=float, 26 | action_shape=action_space.shape, 27 | ) 28 | 29 | # run ddpg 30 | dc.ddpg.ddpg(agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args)) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | dc.envs.add_gym_args(parser) 36 | dc.ddpg.add_args(parser) 37 | args = parser.parse_args() 38 | train_gym_ddpg(args) 39 | -------------------------------------------------------------------------------- /examples/basic_control/sac_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_gym_sac(args): 7 | # same training and testing seed 8 | train_env = dc.envs.load_gym(args.env_id, args.seed) 9 | test_env = dc.envs.load_gym(args.env_id, args.seed) 10 | 11 | state_space = train_env.observation_space 12 | action_space = train_env.action_space 13 | 14 | # create agent 15 | agent = dc.sac.SACAgent( 16 | state_space.shape[0], action_space.shape[0], args.log_std_low, args.log_std_high, 17 | ) 18 | 19 | # create replay buffer 20 | if args.prioritized_replay: 21 | buffer_type = dc.replay.PrioritizedReplayBuffer 22 | else: 23 | buffer_type = dc.replay.ReplayBuffer 24 | buffer = buffer_type( 25 | args.buffer_size, 26 | state_shape=state_space.shape, 27 | state_dtype=float, 28 | action_shape=action_space.shape, 29 | ) 30 | 31 | # run sac 32 | dc.sac.sac( 33 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | dc.envs.add_gym_args(parser) 40 | dc.sac.add_args(parser) 41 | args = parser.parse_args() 42 | args.max_episode_steps = 1000 43 | train_gym_sac(args) 44 | -------------------------------------------------------------------------------- /examples/basic_control/sunrise_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_gym_sunrise(args): 7 | train_env = dc.envs.load_gym(**vars(args)) 8 | test_env = dc.envs.load_gym(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.sunrise.SunriseAgent( 15 | obs_shape[0], 16 | action_shape[0], 17 | args.log_std_low, 18 | args.log_std_high, 19 | args.ensemble_size, 20 | args.ucb_bonus, 21 | ) 22 | 23 | # select a replay buffer 24 | if args.prioritized_replay: 25 | buffer_t = dc.replay.PrioritizedReplayBuffer 26 | else: 27 | buffer_t = dc.replay.ReplayBuffer 28 | buffer = buffer_t( 29 | args.buffer_size, 30 | state_dtype=float, 31 | state_shape=train_env.observation_space.shape, 32 | action_shape=train_env.action_space.shape, 33 | ) 34 | 35 | agent = dc.sunrise.sunrise( 36 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | dc.envs.add_gym_args(parser) 43 | # add sunrise-related cl args 44 | dc.sunrise.add_args(parser) 45 
| args = parser.parse_args() 46 | train_gym_sunrise(args) 47 | -------------------------------------------------------------------------------- /examples/basic_control/td3_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_gym_td3(args): 7 | # same training and testing seed 8 | train_env = dc.envs.load_gym(args.env_id, args.seed) 9 | test_env = dc.envs.load_gym(args.env_id, args.seed) 10 | 11 | state_space = train_env.observation_space 12 | action_space = train_env.action_space 13 | 14 | # create agent 15 | agent = dc.td3.TD3Agent(state_space.shape[0], action_space.shape[0]) 16 | 17 | # create replay buffer 18 | if args.prioritized_replay: 19 | buffer_type = dc.replay.PrioritizedReplayBuffer 20 | else: 21 | buffer_type = dc.replay.ReplayBuffer 22 | buffer = buffer_type( 23 | args.buffer_size, 24 | state_shape=state_space.shape, 25 | state_dtype=float, 26 | action_shape=action_space.shape, 27 | ) 28 | 29 | # run td3 30 | dc.td3.td3(agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args)) 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | dc.envs.add_gym_args(parser) 35 | dc.td3.add_args(parser) 36 | args = parser.parse_args() 37 | train_gym_td3(args) 38 | -------------------------------------------------------------------------------- /examples/d4rl/awac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | import d4rl 5 | import numpy as np 6 | 7 | import deep_control as dc 8 | 9 | 10 | def train_d4rl_awac(args): 11 | train_env, test_env = gym.make(args.env_id), gym.make(args.env_id) 12 | test_env.seed(args.seed) 13 | train_env.seed(args.seed) 14 | state_space = test_env.observation_space 15 | action_space = test_env.action_space 16 | 17 | # create agent 18 | agent = dc.awac.AWACAgent( 19 | state_space.shape[0], 20 | action_space.shape[0], 21 | args.log_std_low, 22 | args.log_std_high, 23 | ) 24 | 25 | # get offline datset 26 | dset = d4rl.qlearning_dataset(test_env) 27 | dset_size = dset["observations"].shape[0] 28 | # create replay buffer 29 | buffer = dc.replay.PrioritizedReplayBuffer( 30 | size=dset_size, 31 | state_shape=state_space.shape, 32 | state_dtype=float, 33 | action_shape=action_space.shape, 34 | ) 35 | buffer.load_experience( 36 | dset["observations"], 37 | dset["actions"], 38 | dset["rewards"], 39 | dset["next_observations"], 40 | dset["terminals"], 41 | ) 42 | 43 | # run awac 44 | dc.awac.awac( 45 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser() 51 | dc.envs.add_gym_args(parser) 52 | dc.awac.add_args(parser) 53 | args = parser.parse_args() 54 | train_d4rl_awac(args) 55 | -------------------------------------------------------------------------------- /examples/d4rl/sbc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | import d4rl 5 | import numpy as np 6 | 7 | import deep_control as dc 8 | 9 | 10 | def train_d4rl_sbc(args): 11 | test_env = gym.make(args.env_id) 12 | test_env.seed(args.seed) 13 | state_space = test_env.observation_space 14 | action_space = test_env.action_space 15 | 16 | # create agent 17 | agent = dc.sbc.SBCAgent( 18 | state_space.shape[0], 19 | action_space.shape[0], 20 | args.log_std_low, 21 | args.log_std_high, 22 | 
) 23 | 24 | # get offline datset 25 | dset = d4rl.qlearning_dataset(test_env) 26 | dset_size = dset["observations"].shape[0] 27 | # create replay buffer 28 | buffer = dc.replay.ReplayBuffer( 29 | size=dset_size, 30 | state_shape=state_space.shape, 31 | state_dtype=float, 32 | action_shape=action_space.shape, 33 | ) 34 | buffer.load_experience( 35 | dset["observations"], 36 | dset["actions"], 37 | dset["rewards"], 38 | dset["next_observations"], 39 | dset["terminals"], 40 | ) 41 | 42 | # run sbc 43 | dc.sbc.sbc(agent=agent, test_env=test_env, buffer=buffer, **vars(args)) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | dc.envs.add_gym_args(parser) 49 | dc.sbc.add_args(parser) 50 | args = parser.parse_args() 51 | train_d4rl_sbc(args) 52 | -------------------------------------------------------------------------------- /examples/dmc/ddpg_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_ddpg(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.ddpg.DDPGAgent(obs_shape[0], action_shape[0]) 15 | 16 | # select a replay buffer 17 | if args.prioritized_replay: 18 | buffer_t = dc.replay.PrioritizedReplayBuffer 19 | else: 20 | buffer_t = dc.replay.ReplayBuffer 21 | buffer = buffer_t( 22 | args.buffer_size, 23 | state_dtype=float, 24 | state_shape=train_env.observation_space.shape, 25 | action_shape=train_env.action_space.shape, 26 | ) 27 | 28 | agent = dc.ddpg.ddpg( 29 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | # add dmc-related cl args 36 | dc.envs.add_dmc_args(parser) 37 | # add sac-related cl args 38 | dc.ddpg.add_args(parser) 39 | args = parser.parse_args() 40 | args.from_pixels = False 41 | args.max_episode_steps = 1000 42 | train_dmc_ddpg(args) 43 | -------------------------------------------------------------------------------- /examples/dmc/discor_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_discor(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.discor.DisCorAgent( 15 | obs_shape[0], action_shape[0], args.log_std_low, args.log_std_high 16 | ) 17 | 18 | buffer = dc.replay.ReplayBuffer( 19 | args.buffer_size, 20 | state_dtype=float, 21 | state_shape=train_env.observation_space.shape, 22 | action_shape=train_env.action_space.shape, 23 | ) 24 | 25 | agent = dc.discor.discor( 26 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | # add dmc-related cl args 33 | dc.envs.add_dmc_args(parser) 34 | # add discor-related cl args 35 | dc.discor.add_args(parser) 36 | args = parser.parse_args() 37 | args.from_pixels = False 38 | args.max_episode_steps = 1000 39 | train_dmc_discor(args) 40 | 
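The d4rl examples above both funnel an offline dataset into the replay buffer through `load_experience`, and the same pattern should work for any dataset stored as plain NumPy arrays. A minimal sketch follows (the `.npz` file and its array names are hypothetical; only the buffer constructor and the five-array call signature come from the examples):

```python
import numpy as np

import deep_control as dc

# hypothetical offline dataset saved as numpy arrays
data = np.load("my_offline_dataset.npz")
num_transitions = data["observations"].shape[0]

# size the buffer to hold the full dataset, mirroring the d4rl examples above
buffer = dc.replay.ReplayBuffer(
    size=num_transitions,
    state_shape=data["observations"].shape[1:],
    state_dtype=float,
    action_shape=data["actions"].shape[1:],
)

# same five-array interface used by examples/d4rl/awac.py and examples/d4rl/sbc.py
buffer.load_experience(
    data["observations"],
    data["actions"],
    data["rewards"],
    data["next_observations"],
    data["terminals"],
)
```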
-------------------------------------------------------------------------------- /examples/dmc/grac_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_grac(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.grac.GRACAgent( 15 | obs_shape[0], action_shape[0], args.log_std_low, args.log_std_high 16 | ) 17 | 18 | # select a replay buffer 19 | if args.prioritized_replay: 20 | buffer_t = dc.replay.PrioritizedReplayBuffer 21 | else: 22 | buffer_t = dc.replay.ReplayBuffer 23 | buffer = buffer_t( 24 | args.buffer_size, 25 | state_dtype=float, 26 | state_shape=train_env.observation_space.shape, 27 | action_shape=train_env.action_space.shape, 28 | ) 29 | 30 | agent = dc.grac.grac( 31 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 32 | ) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | # add dmc-related cl args 38 | dc.envs.add_dmc_args(parser) 39 | # add sgrac-related cl args 40 | dc.grac.add_args(parser) 41 | args = parser.parse_args() 42 | args.from_pixels = False 43 | args.max_episode_steps = 1000 44 | train_dmc_grac(args) 45 | -------------------------------------------------------------------------------- /examples/dmc/redq_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_redq(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.redq.REDQAgent( 15 | obs_shape[0], 16 | action_shape[0], 17 | args.log_std_low, 18 | args.log_std_high, 19 | args.critic_ensemble_size, 20 | ) 21 | 22 | # select a replay buffer 23 | if args.prioritized_replay: 24 | buffer_t = dc.replay.PrioritizedReplayBuffer 25 | else: 26 | buffer_t = dc.replay.ReplayBuffer 27 | buffer = buffer_t( 28 | args.buffer_size, 29 | state_dtype=float, 30 | state_shape=train_env.observation_space.shape, 31 | action_shape=train_env.action_space.shape, 32 | ) 33 | 34 | agent = dc.redq.redq( 35 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | # add dmc-related cl args 42 | dc.envs.add_dmc_args(parser) 43 | # add sac-related cl args 44 | dc.redq.add_args(parser) 45 | args = parser.parse_args() 46 | args.from_pixels = False 47 | args.max_episode_steps = 1000 48 | train_dmc_redq(args) 49 | -------------------------------------------------------------------------------- /examples/dmc/sac_aug_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | from deep_control.augmentations import * 5 | 6 | 7 | def train_dmc_sac_aug(args): 8 | train_env = dc.envs.load_dmc(**vars(args)) 9 | test_env = dc.envs.load_dmc(**vars(args)) 10 | 11 | obs_shape = train_env.observation_space.shape 12 | action_shape = train_env.action_space.shape 13 | max_action = train_env.action_space.high[0] 14 | 15 | 
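    # args.augmentations is a string that is eval'd into a list of augmentation classes
    # (imported above via `from deep_control.augmentations import *`); each class is
    # constructed with the training batch size and then chained together by the
    # AugmentationSequence built on the next line.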
augmentation_lst = [aug(args.batch_size) for aug in eval(args.augmentations)] 16 | augmenter = dc.augmentations.AugmentationSequence(augmentation_lst) 17 | 18 | agent = dc.sac_aug.PixelSACAgent( 19 | obs_shape, action_shape[0], args.log_std_low, args.log_std_high 20 | ) 21 | 22 | # select a replay buffer 23 | if args.prioritized_replay: 24 | buffer_t = dc.replay.PrioritizedReplayBuffer 25 | else: 26 | buffer_t = dc.replay.ReplayBuffer 27 | buffer = buffer_t( 28 | args.buffer_size, 29 | state_dtype=int, 30 | state_shape=train_env.observation_space.shape, 31 | action_shape=train_env.action_space.shape, 32 | ) 33 | 34 | agent = dc.sac_aug.sac_aug( 35 | agent=agent, 36 | train_env=train_env, 37 | test_env=test_env, 38 | buffer=buffer, 39 | augmenter=augmenter, 40 | **vars(args) 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | # add dmc-related cl args 47 | dc.envs.add_dmc_args(parser) 48 | # add sac-related cl args 49 | dc.sac_aug.add_args(parser) 50 | args = parser.parse_args() 51 | args.from_pixels = True 52 | args.rgb = True 53 | args.max_episode_steps = (1000 + args.frame_skip - 1) // args.frame_skip 54 | train_dmc_sac_aug(args) 55 | -------------------------------------------------------------------------------- /examples/dmc/sac_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_sac(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.sac.SACAgent( 15 | obs_shape[0], action_shape[0], args.log_std_low, args.log_std_high 16 | ) 17 | 18 | # select a replay buffer 19 | if args.prioritized_replay: 20 | buffer_t = dc.replay.PrioritizedReplayBuffer 21 | else: 22 | buffer_t = dc.replay.ReplayBuffer 23 | buffer = buffer_t( 24 | args.buffer_size, 25 | state_dtype=float, 26 | state_shape=train_env.observation_space.shape, 27 | action_shape=train_env.action_space.shape, 28 | ) 29 | 30 | agent = dc.sac.sac( 31 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 32 | ) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | # add dmc-related cl args 38 | dc.envs.add_dmc_args(parser) 39 | # add sac-related cl args 40 | dc.sac.add_args(parser) 41 | args = parser.parse_args() 42 | args.from_pixels = False 43 | args.max_episode_steps = 1000 44 | train_dmc_sac(args) 45 | -------------------------------------------------------------------------------- /examples/dmc/sunrise_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_sunrise(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.sunrise.SunriseAgent( 15 | obs_shape[0], 16 | action_shape[0], 17 | args.log_std_low, 18 | args.log_std_high, 19 | args.ensemble_size, 20 | args.ucb_bonus, 21 | ) 22 | 23 | # select a replay buffer 24 | if args.prioritized_replay: 25 | buffer_t = dc.replay.PrioritizedReplayBuffer 26 | else: 27 | buffer_t = dc.replay.ReplayBuffer 28 | buffer = buffer_t( 29 | 
args.buffer_size, 30 | state_dtype=float, 31 | state_shape=train_env.observation_space.shape, 32 | action_shape=train_env.action_space.shape, 33 | ) 34 | 35 | agent = dc.sunrise.sunrise( 36 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | # add dmc-related cl args 43 | dc.envs.add_dmc_args(parser) 44 | # add sunrise-related cl args 45 | dc.sunrise.add_args(parser) 46 | args = parser.parse_args() 47 | args.from_pixels = False 48 | args.max_episode_steps = 1000 49 | train_dmc_sunrise(args) 50 | -------------------------------------------------------------------------------- /examples/dmc/td3_dmc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import deep_control as dc 4 | 5 | 6 | def train_dmc_td3(args): 7 | train_env = dc.envs.load_dmc(**vars(args)) 8 | test_env = dc.envs.load_dmc(**vars(args)) 9 | 10 | obs_shape = train_env.observation_space.shape 11 | action_shape = train_env.action_space.shape 12 | max_action = train_env.action_space.high[0] 13 | 14 | agent = dc.td3.TD3Agent(obs_shape[0], action_shape[0]) 15 | 16 | # select a replay buffer 17 | if args.prioritized_replay: 18 | buffer_t = dc.replay.PrioritizedReplayBuffer 19 | else: 20 | buffer_t = dc.replay.ReplayBuffer 21 | buffer = buffer_t( 22 | args.buffer_size, 23 | state_dtype=float, 24 | state_shape=train_env.observation_space.shape, 25 | action_shape=train_env.action_space.shape, 26 | ) 27 | 28 | agent = dc.td3.td3( 29 | agent=agent, train_env=train_env, test_env=test_env, buffer=buffer, **vars(args) 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | # add dmc-related cl args 36 | dc.envs.add_dmc_args(parser) 37 | # add td3-related cl args 38 | dc.td3.add_args(parser) 39 | args = parser.parse_args() 40 | args.from_pixels = False 41 | args.max_episode_steps = 1000 42 | train_dmc_td3(args) 43 | -------------------------------------------------------------------------------- /examples/dmcr/sac_aug_dmcr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import copy 4 | import sys 5 | 6 | import deep_control as dc 7 | 8 | from deep_control.augmentations import * 9 | 10 | from dmc_remastered.benchmarks import * 11 | 12 | 13 | def train_dmcr_sac(args): 14 | benchmark_kwargs = { 15 | "domain": args.domain, 16 | "task": args.task, 17 | "num_levels": args.num_levels, 18 | "frame_stack": args.frame_stack, 19 | "height": 84, 20 | "width": 84, 21 | "frame_skip": args.frame_skip, 22 | "channels_last": False, 23 | } 24 | 25 | if args.benchmark == "visual_generalization": 26 | benchmark = visual_generalization 27 | args.train_eval_episodes = 100 28 | args.test_eval_episodes = 100 29 | elif args.benchmark == "visual_sim2real": 30 | benchmark = visual_sim2real 31 | args.train_eval_episodes = 100 32 | args.test_eval_episodes = 10 33 | elif args.benchmark == "classic": 34 | benchmark = classic 35 | del benchmark_kwargs["num_levels"] 36 | benchmark_kwargs["visual_seed"] = args.visual_seed 37 | args.train_eval_episodes = 0 38 | args.test_eval_episodes = 10 39 | elif args.benchmark == "control": 40 | seed = random.randint(0, 10000) 41 | train_env = dc.envs.load_dmc( 42 | domain_name=args.domain, 43 | task_name=args.task, 44 | from_pixels=True, 45 | rgb=True, 46 | frame_stack=args.frame_stack, 47 | frame_skip=args.frame_skip, 48 | seed=seed, 49 |
) 50 | test_env = dc.envs.load_dmc( 51 | domain_name=args.domain, 52 | task_name=args.task, 53 | from_pixels=True, 54 | rgb=True, 55 | frame_stack=args.frame_stack, 56 | frame_skip=args.frame_skip, 57 | seed=seed, 58 | ) 59 | args.train_eval_episodes = 0 60 | args.test_eval_episodes = 10 61 | 62 | if args.benchmark != "control": 63 | train_env, test_env = benchmark(**benchmark_kwargs) 64 | 65 | obs_shape = train_env.observation_space.shape 66 | action_shape = train_env.action_space.shape 67 | max_action = train_env.action_space.high[0] 68 | 69 | augmentation_lst = [aug(args.batch_size) for aug in eval(args.augmentations)] 70 | augmenter = AugmentationSequence(augmentation_lst) 71 | 72 | agent = dc.sac_aug.PixelSACAgent( 73 | obs_shape, action_shape[0], args.log_std_low, args.log_std_high 74 | ) 75 | 76 | # select a replay buffer 77 | if args.prioritized_replay: 78 | buffer_t = dc.replay.PrioritizedReplayBuffer 79 | else: 80 | buffer_t = dc.replay.ReplayBuffer 81 | buffer = buffer_t( 82 | args.buffer_size, 83 | state_dtype=int, 84 | state_shape=train_env.observation_space.shape, 85 | action_shape=train_env.action_space.shape, 86 | ) 87 | 88 | dc.sac_aug.sac_aug( 89 | agent=agent, 90 | train_env=train_env, 91 | test_env=test_env, 92 | buffer=buffer, 93 | augmenter=augmenter, 94 | **vars(args), 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--domain", type=str, default="walker") 101 | parser.add_argument("--task", type=str, default="walk") 102 | parser.add_argument("--benchmark", type=str, default="classic") 103 | parser.add_argument("--visual_seed", type=int, default=0) 104 | parser.add_argument("--num_levels", type=int, default=1_000_000) 105 | parser.add_argument("--frame_stack", type=int, default=3) 106 | parser.add_argument("--frame_skip", type=int, default=2) 107 | dc.sac_aug.add_args(parser) # sac+aug related args 108 | args = parser.parse_args() 109 | 110 | # auto-adjust the max episode steps to compensate for the frame skipping. 111 | # dmc (and dmcr) automatically reset after 1k steps, but this allows for 112 | # infinite bootstrapping (used by CURL and SAC-AE) 113 | args.max_episode_steps = (1000 + args.frame_skip - 1) // args.frame_skip 114 | train_dmcr_sac(args) 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | numpy 3 | tensorboardX 4 | torch 5 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose 6 | python_files = tests/*.py 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="deep_control", 5 | version="0.0.1", 6 | setup_requires=["pytest-runner"], 7 | tests_require=["pytest"], 8 | description="Deep Reinforcement Learning for Continuous Control Tasks", 9 | author="Jake Grigsby", 10 | author_email="jcg6dn@virginia.edu", 11 | license="MIT", 12 | packages=find_packages(), 13 | ) 14 | --------------------------------------------------------------------------------