├── .gitignore ├── LICENSE ├── README.md ├── agent ├── __init__.py ├── actor.py ├── critic.py └── sac.py ├── conda_env.yml ├── config ├── agent │ └── sac.yaml └── train.yaml ├── data ├── sac.csv └── sac.ipynb ├── figures └── dm_control.png ├── logger.py ├── replay_buffer.py ├── train.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | exp 4 | tb 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Denis Yarats 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Soft Actor-Critic (SAC) implementation in PyTorch 2 | 3 | This is a PyTorch implementation of Soft Actor-Critic (SAC) [[ArXiv]](https://arxiv.org/abs/1812.05905). 4 | 5 | If you use this code in your research project, please cite us as: 6 | ``` 7 | @misc{pytorch_sac, 8 | author = {Yarats, Denis and Kostrikov, Ilya}, 9 | title = {Soft Actor-Critic (SAC) implementation in PyTorch}, 10 | year = {2020}, 11 | publisher = {GitHub}, 12 | journal = {GitHub repository}, 13 | howpublished = {\url{https://github.com/denisyarats/pytorch_sac}}, 14 | } 15 | ``` 16 | 17 | ## Requirements 18 | We assume you have access to a GPU that can run CUDA 9.2. Then, the simplest way to install all required dependencies is to create an Anaconda environment and activate it: 19 | ``` 20 | conda env create -f conda_env.yml 21 | source activate pytorch_sac 22 | ``` 23 | 24 | ## Instructions 25 | To train a SAC agent on the `cheetah_run` task, run: 26 | ``` 27 | python train.py env=cheetah_run 28 | ``` 29 | This will produce an `exp` folder where all outputs are stored, including train/eval logs, TensorBoard blobs, and evaluation episode videos. You can attach TensorBoard to monitor training by running: 30 | ``` 31 | tensorboard --logdir exp 32 | ``` 33 | 34 | ## Results 35 | We extensively benchmark SAC against D4PG on the DM Control Suite. We plot the average performance of SAC over 3 seeds together with 95% confidence intervals. Importantly, we keep the hyperparameters fixed across all tasks. Note that the results for D4PG are reported after 10^8 steps and are taken from the original paper. 36 | ![Results](figures/dm_control.png) 37 | --------------------------------------------------------------------------------
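Besides TensorBoard, each run also writes plain CSV logs (`train.csv` and `eval.csv`) into its experiment directory. As a minimal sketch (assuming `pandas` and `matplotlib` are available, and using the column names produced by `logger.py`), you can plot the evaluation curve of a single run like this; the run directory below is hypothetical, so substitute the one created under `./exp`:

```
# Sketch: plot the evaluation return of one run from its eval.csv log.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('exp/2020.01.01/0000_sac_test_exp/eval.csv')  # hypothetical run dir
plt.plot(df['step'], df['episode_reward'])
plt.xlabel('environment step')
plt.ylabel('average episode reward')
plt.title('cheetah_run')
plt.show()
```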
/agent/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Agent(object): 5 | def reset(self): 6 | """For stateful agents this function performs resetting at the beginning of each episode.""" 7 | pass 8 | 9 | @abc.abstractmethod 10 | def train(self, training=True): 11 | """Sets the agent in either training or evaluation mode.""" 12 | 13 | @abc.abstractmethod 14 | def update(self, replay_buffer, logger, step): 15 | """Main function of the agent that performs learning.""" 16 | 17 | @abc.abstractmethod 18 | def act(self, obs, sample=False): 19 | """Issues an action given an observation.""" 20 | -------------------------------------------------------------------------------- /agent/actor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch import distributions as pyd 7 | 8 | import utils 9 | 10 | 11 | class TanhTransform(pyd.transforms.Transform): 12 | domain = pyd.constraints.real 13 | codomain = pyd.constraints.interval(-1.0, 1.0) 14 | bijective = True 15 | sign = +1 16 | 17 | def __init__(self, cache_size=1): 18 | super().__init__(cache_size=cache_size) 19 | 20 | @staticmethod 21 | def atanh(x): 22 | return 0.5 * (x.log1p() - (-x).log1p()) 23 | 24 | def __eq__(self, other): 25 | return isinstance(other, TanhTransform) 26 | 27 | def _call(self, x): 28 | return x.tanh() 29 | 30 | def _inverse(self, y): 31 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms; 32 | # one should use `cache_size=1` instead. 33 | return self.atanh(y) 34 | 35 | def log_abs_det_jacobian(self, x, y): 36 | # We use a formula that is more numerically stable, see details in the following link: 37 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 38 | return 2. * (math.log(2.) - x - F.softplus(-2. * x)) 39 | 40 | 41 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 42 | def __init__(self, loc, scale): 43 | self.loc = loc 44 | self.scale = scale 45 | 46 | self.base_dist = pyd.Normal(loc, scale) 47 | transforms = [TanhTransform()] 48 | super().__init__(self.base_dist, transforms) 49 | 50 | @property 51 | def mean(self): 52 | mu = self.loc 53 | for tr in self.transforms: 54 | mu = tr(mu) 55 | return mu 56 | 57 | 58 | class DiagGaussianActor(nn.Module): 59 | """torch.distributions implementation of a diagonal Gaussian policy.""" 60 | def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth, 61 | log_std_bounds): 62 | super().__init__() 63 | 64 | self.log_std_bounds = log_std_bounds 65 | self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim, 66 | hidden_depth) 67 | 68 | self.outputs = dict() 69 | self.apply(utils.weight_init) 70 | 71 | def forward(self, obs): 72 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 73 | 74 | # constrain log_std inside [log_std_min, log_std_max] 75 | log_std = torch.tanh(log_std) 76 | log_std_min, log_std_max = self.log_std_bounds 77 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 78 | 1) 79 | 80 | std = log_std.exp() 81 | 82 | self.outputs['mu'] = mu 83 | self.outputs['std'] = std 84 | 85 | dist = SquashedNormal(mu, std) 86 | return dist 87 | 88 | def log(self, logger, step): 89 | for k, v in self.outputs.items(): 90 | logger.log_histogram(f'train_actor/{k}_hist', v, step) 91 | 92 | for i, m in enumerate(self.trunk): 93 | if type(m) == nn.Linear: 94 | logger.log_param(f'train_actor/fc{i}', m, step) --------------------------------------------------------------------------------
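The `log_abs_det_jacobian` above relies on the identity `log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x))`, which avoids the catastrophic cancellation the naive form suffers for large `|x|`. A small standalone check (a sketch, not part of the repository; it only needs `torch`):

```
# Sketch: verify the numerically stable tanh log-det-Jacobian used in TanhTransform.
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-3., 3., steps=13)
stable = 2. * (math.log(2.) - x - F.softplus(-2. * x))
naive = torch.log(1. - torch.tanh(x).pow(2))
print(torch.allclose(stable, naive, atol=1e-4))  # True

# This is exactly the per-dimension correction that SquashedNormal.log_prob subtracts
# from the Gaussian log-density; agent/sac.py then sums it over action dimensions.
```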
/agent/critic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | import utils 7 | 8 | 9 | class DoubleQCritic(nn.Module): 10 | """Critic network that employs double Q-learning.""" 11 | def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth): 12 | super().__init__() 13 | 14 | self.Q1 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth) 15 | self.Q2 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth) 16 | 17 | self.outputs = dict() 18 | self.apply(utils.weight_init) 19 | 20 | def forward(self, obs, action): 21 | assert obs.size(0) == action.size(0) 22 | 23 | obs_action = torch.cat([obs, action], dim=-1) 24 | q1 = self.Q1(obs_action) 25 | q2 = self.Q2(obs_action) 26 | 27 | self.outputs['q1'] = q1 28 | self.outputs['q2'] = q2 29 | 30 | return q1, q2 31 | 32 | def log(self, logger, step): 33 | for k, v in self.outputs.items(): 34 | logger.log_histogram(f'train_critic/{k}_hist', v, step) 35 | 36 | assert len(self.Q1) == len(self.Q2) 37 | for i, (m1, m2) in enumerate(zip(self.Q1, self.Q2)): 38 | assert type(m1) == type(m2) 39 | if type(m1) is nn.Linear: 40 | logger.log_param(f'train_critic/q1_fc{i}', m1, step) 41 | logger.log_param(f'train_critic/q2_fc{i}', m2, step) 42 | -------------------------------------------------------------------------------- /agent/sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | from agent import Agent 8 | import utils 9 | 10 | import hydra 11 | 12 | 13 | class SACAgent(Agent): 14 | """SAC algorithm.""" 15 | def __init__(self, obs_dim, action_dim, action_range, device,
critic_cfg, 16 | actor_cfg, discount, init_temperature, alpha_lr, alpha_betas, 17 | actor_lr, actor_betas, actor_update_frequency, critic_lr, 18 | critic_betas, critic_tau, critic_target_update_frequency, 19 | batch_size, learnable_temperature): 20 | super().__init__() 21 | 22 | self.action_range = action_range 23 | self.device = torch.device(device) 24 | self.discount = discount 25 | self.critic_tau = critic_tau 26 | self.actor_update_frequency = actor_update_frequency 27 | self.critic_target_update_frequency = critic_target_update_frequency 28 | self.batch_size = batch_size 29 | self.learnable_temperature = learnable_temperature 30 | 31 | self.critic = hydra.utils.instantiate(critic_cfg).to(self.device) 32 | self.critic_target = hydra.utils.instantiate(critic_cfg).to( 33 | self.device) 34 | self.critic_target.load_state_dict(self.critic.state_dict()) 35 | 36 | self.actor = hydra.utils.instantiate(actor_cfg).to(self.device) 37 | 38 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device) 39 | self.log_alpha.requires_grad = True 40 | # set target entropy to -|A| 41 | self.target_entropy = -action_dim 42 | 43 | # optimizers 44 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 45 | lr=actor_lr, 46 | betas=actor_betas) 47 | 48 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 49 | lr=critic_lr, 50 | betas=critic_betas) 51 | 52 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], 53 | lr=alpha_lr, 54 | betas=alpha_betas) 55 | 56 | self.train() 57 | self.critic_target.train() 58 | 59 | def train(self, training=True): 60 | self.training = training 61 | self.actor.train(training) 62 | self.critic.train(training) 63 | 64 | @property 65 | def alpha(self): 66 | return self.log_alpha.exp() 67 | 68 | def act(self, obs, sample=False): 69 | obs = torch.FloatTensor(obs).to(self.device) 70 | obs = obs.unsqueeze(0) 71 | dist = self.actor(obs) 72 | action = dist.sample() if sample else dist.mean 73 | action = action.clamp(*self.action_range) 74 | assert action.ndim == 2 and action.shape[0] == 1 75 | return utils.to_np(action[0]) 76 | 77 | def update_critic(self, obs, action, reward, next_obs, not_done, logger, 78 | step): 79 | dist = self.actor(next_obs) 80 | next_action = dist.rsample() 81 | log_prob = dist.log_prob(next_action).sum(-1, keepdim=True) 82 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 83 | target_V = torch.min(target_Q1, 84 | target_Q2) - self.alpha.detach() * log_prob 85 | target_Q = reward + (not_done * self.discount * target_V) 86 | target_Q = target_Q.detach() 87 | 88 | # get current Q estimates 89 | current_Q1, current_Q2 = self.critic(obs, action) 90 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( 91 | current_Q2, target_Q) 92 | logger.log('train_critic/loss', critic_loss, step) 93 | 94 | # Optimize the critic 95 | self.critic_optimizer.zero_grad() 96 | critic_loss.backward() 97 | self.critic_optimizer.step() 98 | 99 | self.critic.log(logger, step) 100 | 101 | def update_actor_and_alpha(self, obs, logger, step): 102 | dist = self.actor(obs) 103 | action = dist.rsample() 104 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 105 | actor_Q1, actor_Q2 = self.critic(obs, action) 106 | 107 | actor_Q = torch.min(actor_Q1, actor_Q2) 108 | actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean() 109 | 110 | logger.log('train_actor/loss', actor_loss, step) 111 | logger.log('train_actor/target_entropy', self.target_entropy, step) 112 | logger.log('train_actor/entropy', -log_prob.mean(), 
step) 113 | 114 | # optimize the actor 115 | self.actor_optimizer.zero_grad() 116 | actor_loss.backward() 117 | self.actor_optimizer.step() 118 | 119 | self.actor.log(logger, step) 120 | 121 | if self.learnable_temperature: 122 | self.log_alpha_optimizer.zero_grad() 123 | alpha_loss = (self.alpha * 124 | (-log_prob - self.target_entropy).detach()).mean() 125 | logger.log('train_alpha/loss', alpha_loss, step) 126 | logger.log('train_alpha/value', self.alpha, step) 127 | alpha_loss.backward() 128 | self.log_alpha_optimizer.step() 129 | 130 | def update(self, replay_buffer, logger, step): 131 | obs, action, reward, next_obs, not_done, not_done_no_max = replay_buffer.sample( 132 | self.batch_size) 133 | 134 | logger.log('train/batch_reward', reward.mean(), step) 135 | 136 | self.update_critic(obs, action, reward, next_obs, not_done_no_max, 137 | logger, step) 138 | 139 | if step % self.actor_update_frequency == 0: 140 | self.update_actor_and_alpha(obs, logger, step) 141 | 142 | if step % self.critic_target_update_frequency == 0: 143 | utils.soft_update_params(self.critic, self.critic_target, 144 | self.critic_tau) 145 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_sac 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6 6 | - pytorch 7 | - cudatoolkit=9.2 8 | - absl-py 9 | - pyparsing 10 | - pip: 11 | - termcolor 12 | - git+git://github.com/deepmind/dm_control.git 13 | - git+git://github.com/denisyarats/dmc2gym.git 14 | - tb-nightly 15 | - imageio 16 | - imageio-ffmpeg 17 | - git+git://github.com/facebookresearch/hydra@0.11_branch 18 | -------------------------------------------------------------------------------- /config/agent/sac.yaml: -------------------------------------------------------------------------------- 1 | agent: 2 | name: sac 3 | class: agent.sac.SACAgent 4 | params: 5 | obs_dim: ??? # to be specified later 6 | action_dim: ??? # to be specified later 7 | action_range: ??? 
# to be specified later 8 | device: ${device} 9 | critic_cfg: ${double_q_critic} 10 | actor_cfg: ${diag_gaussian_actor} 11 | discount: 0.99 12 | init_temperature: 0.1 13 | alpha_lr: 1e-4 14 | alpha_betas: [0.9, 0.999] 15 | actor_lr: 1e-4 16 | actor_betas: [0.9, 0.999] 17 | actor_update_frequency: 1 18 | critic_lr: 1e-4 19 | critic_betas: [0.9, 0.999] 20 | critic_tau: 0.005 21 | critic_target_update_frequency: 2 22 | batch_size: 1024 23 | learnable_temperature: true 24 | 25 | double_q_critic: 26 | class: agent.critic.DoubleQCritic 27 | params: 28 | obs_dim: ${agent.params.obs_dim} 29 | action_dim: ${agent.params.action_dim} 30 | hidden_dim: 1024 31 | hidden_depth: 2 32 | 33 | diag_gaussian_actor: 34 | class: agent.actor.DiagGaussianActor 35 | params: 36 | obs_dim: ${agent.params.obs_dim} 37 | action_dim: ${agent.params.action_dim} 38 | hidden_depth: 2 39 | hidden_dim: 1024 40 | log_std_bounds: [-5, 2] -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: sac 3 | 4 | env: cheetah_run 5 | 6 | # this needs to be specified manually 7 | experiment: test_exp 8 | 9 | num_train_steps: 1e6 10 | replay_buffer_capacity: ${num_train_steps} 11 | 12 | num_seed_steps: 5000 13 | 14 | eval_frequency: 10000 15 | num_eval_episodes: 10 16 | 17 | device: cuda 18 | 19 | # logger 20 | log_frequency: 10000 21 | log_save_tb: true 22 | 23 | # video recorder 24 | save_video: true 25 | 26 | 27 | seed: 1 28 | 29 | 30 | # hydra configuration 31 | hydra: 32 | name: ${env} 33 | run: 34 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment} -------------------------------------------------------------------------------- /figures/dm_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/denisyarats/pytorch_sac/81c5b536d3a1c5616b2531e446450df412a064fb/figures/dm_control.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from collections import defaultdict 3 | import json 4 | import os 5 | import csv 6 | import shutil 7 | import torch 8 | import numpy as np 9 | from termcolor import colored 10 | 11 | COMMON_TRAIN_FORMAT = [ 12 | ('episode', 'E', 'int'), 13 | ('step', 'S', 'int'), 14 | ('episode_reward', 'R', 'float'), 15 | ('duration', 'D', 'time') 16 | ] 17 | 18 | COMMON_EVAL_FORMAT = [ 19 | ('episode', 'E', 'int'), 20 | ('step', 'S', 'int'), 21 | ('episode_reward', 'R', 'float') 22 | ] 23 | 24 | 25 | AGENT_TRAIN_FORMAT = { 26 | 'sac': [ 27 | ('batch_reward', 'BR', 'float'), 28 | ('actor_loss', 'ALOSS', 'float'), 29 | ('critic_loss', 'CLOSS', 'float'), 30 | ('alpha_loss', 'TLOSS', 'float'), 31 | ('alpha_value', 'TVAL', 'float'), 32 | ('actor_entropy', 'AENT', 'float') 33 | ] 34 | } 35 | 36 | 37 | class AverageMeter(object): 38 | def __init__(self): 39 | self._sum = 0 40 | self._count = 0 41 | 42 | def update(self, value, n=1): 43 | self._sum += value 44 | self._count += n 45 | 46 | def value(self): 47 | return self._sum / max(1, self._count) 48 | 49 | 50 | class MetersGroup(object): 51 | def __init__(self, file_name, formating): 52 | self._csv_file_name = self._prepare_file(file_name, 'csv') 53 | self._formating = formating 54 | self._meters = defaultdict(AverageMeter) 55 | self._csv_file = 
open(self._csv_file_name, 'w') 56 | self._csv_writer = None 57 | 58 | def _prepare_file(self, prefix, suffix): 59 | file_name = f'{prefix}.{suffix}' 60 | if os.path.exists(file_name): 61 | os.remove(file_name) 62 | return file_name 63 | 64 | def log(self, key, value, n=1): 65 | self._meters[key].update(value, n) 66 | 67 | def _prime_meters(self): 68 | data = dict() 69 | for key, meter in self._meters.items(): 70 | if key.startswith('train'): 71 | key = key[len('train') + 1:] 72 | else: 73 | key = key[len('eval') + 1:] 74 | key = key.replace('/', '_') 75 | data[key] = meter.value() 76 | return data 77 | 78 | def _dump_to_csv(self, data): 79 | if self._csv_writer is None: 80 | self._csv_writer = csv.DictWriter(self._csv_file, 81 | fieldnames=sorted(data.keys()), 82 | restval=0.0) 83 | self._csv_writer.writeheader() 84 | self._csv_writer.writerow(data) 85 | self._csv_file.flush() 86 | 87 | def _format(self, key, value, ty): 88 | if ty == 'int': 89 | value = int(value) 90 | return f'{key}: {value}' 91 | elif ty == 'float': 92 | return f'{key}: {value:.04f}' 93 | elif ty == 'time': 94 | return f'{key}: {value:04.1f} s' 95 | else: 96 | raise f'invalid format type: {ty}' 97 | 98 | def _dump_to_console(self, data, prefix): 99 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 100 | pieces = [f'| {prefix: <14}'] 101 | for key, disp_key, ty in self._formating: 102 | value = data.get(key, 0) 103 | pieces.append(self._format(disp_key, value, ty)) 104 | print(' | '.join(pieces)) 105 | 106 | def dump(self, step, prefix, save=True): 107 | if len(self._meters) == 0: 108 | return 109 | if save: 110 | data = self._prime_meters() 111 | data['step'] = step 112 | self._dump_to_csv(data) 113 | self._dump_to_console(data, prefix) 114 | self._meters.clear() 115 | 116 | 117 | class Logger(object): 118 | def __init__(self, 119 | log_dir, 120 | save_tb=False, 121 | log_frequency=10000, 122 | agent='sac'): 123 | self._log_dir = log_dir 124 | self._log_frequency = log_frequency 125 | if save_tb: 126 | tb_dir = os.path.join(log_dir, 'tb') 127 | if os.path.exists(tb_dir): 128 | try: 129 | shutil.rmtree(tb_dir) 130 | except: 131 | print("logger.py warning: Unable to remove tb directory") 132 | pass 133 | self._sw = SummaryWriter(tb_dir) 134 | else: 135 | self._sw = None 136 | # each agent has specific output format for training 137 | assert agent in AGENT_TRAIN_FORMAT 138 | train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent] 139 | self._train_mg = MetersGroup(os.path.join(log_dir, 'train'), 140 | formating=train_format) 141 | self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval'), 142 | formating=COMMON_EVAL_FORMAT) 143 | 144 | def _should_log(self, step, log_frequency): 145 | log_frequency = log_frequency or self._log_frequency 146 | return step % log_frequency == 0 147 | 148 | def _try_sw_log(self, key, value, step): 149 | if self._sw is not None: 150 | self._sw.add_scalar(key, value, step) 151 | 152 | def _try_sw_log_video(self, key, frames, step): 153 | if self._sw is not None: 154 | frames = torch.from_numpy(np.array(frames)) 155 | frames = frames.unsqueeze(0) 156 | self._sw.add_video(key, frames, step, fps=30) 157 | 158 | def _try_sw_log_histogram(self, key, histogram, step): 159 | if self._sw is not None: 160 | self._sw.add_histogram(key, histogram, step) 161 | 162 | def log(self, key, value, step, n=1, log_frequency=1): 163 | if not self._should_log(step, log_frequency): 164 | return 165 | assert key.startswith('train') or key.startswith('eval') 166 | if type(value) == 
torch.Tensor: 167 | value = value.item() 168 | self._try_sw_log(key, value / n, step) 169 | mg = self._train_mg if key.startswith('train') else self._eval_mg 170 | mg.log(key, value, n) 171 | 172 | def log_param(self, key, param, step, log_frequency=None): 173 | if not self._should_log(step, log_frequency): 174 | return 175 | self.log_histogram(key + '_w', param.weight.data, step) 176 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 177 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 178 | if hasattr(param, 'bias') and hasattr(param.bias, 'data'): 179 | self.log_histogram(key + '_b', param.bias.data, step) 180 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 181 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 182 | 183 | def log_video(self, key, frames, step, log_frequency=None): 184 | if not self._should_log(step, log_frequency): 185 | return 186 | assert key.startswith('train') or key.startswith('eval') 187 | self._try_sw_log_video(key, frames, step) 188 | 189 | def log_histogram(self, key, histogram, step, log_frequency=None): 190 | if not self._should_log(step, log_frequency): 191 | return 192 | assert key.startswith('train') or key.startswith('eval') 193 | self._try_sw_log_histogram(key, histogram, step) 194 | 195 | def dump(self, step, save=True, ty=None): 196 | if ty is None: 197 | self._train_mg.dump(step, 'train', save) 198 | self._eval_mg.dump(step, 'eval', save) 199 | elif ty == 'eval': 200 | self._eval_mg.dump(step, 'eval', save) 201 | elif ty == 'train': 202 | self._train_mg.dump(step, 'train', save) 203 | else: 204 | raise f'invalid log type: {ty}' 205 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | """Buffer to store environment transitions.""" 7 | def __init__(self, obs_shape, action_shape, capacity, device): 8 | self.capacity = capacity 9 | self.device = device 10 | 11 | # the proprioceptive obs is stored as float32, pixels obs as uint8 12 | obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8 13 | 14 | self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) 15 | self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) 16 | self.actions = np.empty((capacity, *action_shape), dtype=np.float32) 17 | self.rewards = np.empty((capacity, 1), dtype=np.float32) 18 | self.not_dones = np.empty((capacity, 1), dtype=np.float32) 19 | self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32) 20 | 21 | self.idx = 0 22 | self.last_save = 0 23 | self.full = False 24 | 25 | def __len__(self): 26 | return self.capacity if self.full else self.idx 27 | 28 | def add(self, obs, action, reward, next_obs, done, done_no_max): 29 | np.copyto(self.obses[self.idx], obs) 30 | np.copyto(self.actions[self.idx], action) 31 | np.copyto(self.rewards[self.idx], reward) 32 | np.copyto(self.next_obses[self.idx], next_obs) 33 | np.copyto(self.not_dones[self.idx], not done) 34 | np.copyto(self.not_dones_no_max[self.idx], not done_no_max) 35 | 36 | self.idx = (self.idx + 1) % self.capacity 37 | self.full = self.full or self.idx == 0 38 | 39 | def sample(self, batch_size): 40 | idxs = np.random.randint(0, 41 | self.capacity if self.full else self.idx, 42 | size=batch_size) 43 | 44 | obses = torch.as_tensor(self.obses[idxs], device=self.device).float() 45 | actions = 
torch.as_tensor(self.actions[idxs], device=self.device) 46 | rewards = torch.as_tensor(self.rewards[idxs], device=self.device) 47 | next_obses = torch.as_tensor(self.next_obses[idxs], 48 | device=self.device).float() 49 | not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device) 50 | not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], 51 | device=self.device) 52 | 53 | return obses, actions, rewards, next_obses, not_dones, not_dones_no_max -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import copy 7 | import math 8 | import os 9 | import sys 10 | import time 11 | import pickle as pkl 12 | 13 | from video import VideoRecorder 14 | from logger import Logger 15 | from replay_buffer import ReplayBuffer 16 | import utils 17 | 18 | import dmc2gym 19 | import hydra 20 | 21 | 22 | def make_env(cfg): 23 | """Helper function to create dm_control environment""" 24 | if cfg.env == 'ball_in_cup_catch': 25 | domain_name = 'ball_in_cup' 26 | task_name = 'catch' 27 | else: 28 | domain_name = cfg.env.split('_')[0] 29 | task_name = '_'.join(cfg.env.split('_')[1:]) 30 | 31 | env = dmc2gym.make(domain_name=domain_name, 32 | task_name=task_name, 33 | seed=cfg.seed, 34 | visualize_reward=True) 35 | env.seed(cfg.seed) 36 | assert env.action_space.low.min() >= -1 37 | assert env.action_space.high.max() <= 1 38 | 39 | return env 40 | 41 | 42 | class Workspace(object): 43 | def __init__(self, cfg): 44 | self.work_dir = os.getcwd() 45 | print(f'workspace: {self.work_dir}') 46 | 47 | self.cfg = cfg 48 | 49 | self.logger = Logger(self.work_dir, 50 | save_tb=cfg.log_save_tb, 51 | log_frequency=cfg.log_frequency, 52 | agent=cfg.agent.name) 53 | 54 | utils.set_seed_everywhere(cfg.seed) 55 | self.device = torch.device(cfg.device) 56 | self.env = utils.make_env(cfg) 57 | 58 | cfg.agent.params.obs_dim = self.env.observation_space.shape[0] 59 | cfg.agent.params.action_dim = self.env.action_space.shape[0] 60 | cfg.agent.params.action_range = [ 61 | float(self.env.action_space.low.min()), 62 | float(self.env.action_space.high.max()) 63 | ] 64 | self.agent = hydra.utils.instantiate(cfg.agent) 65 | 66 | self.replay_buffer = ReplayBuffer(self.env.observation_space.shape, 67 | self.env.action_space.shape, 68 | int(cfg.replay_buffer_capacity), 69 | self.device) 70 | 71 | self.video_recorder = VideoRecorder( 72 | self.work_dir if cfg.save_video else None) 73 | self.step = 0 74 | 75 | def evaluate(self): 76 | average_episode_reward = 0 77 | for episode in range(self.cfg.num_eval_episodes): 78 | obs = self.env.reset() 79 | self.agent.reset() 80 | self.video_recorder.init(enabled=(episode == 0)) 81 | done = False 82 | episode_reward = 0 83 | while not done: 84 | with utils.eval_mode(self.agent): 85 | action = self.agent.act(obs, sample=False) 86 | obs, reward, done, _ = self.env.step(action) 87 | self.video_recorder.record(self.env) 88 | episode_reward += reward 89 | 90 | average_episode_reward += episode_reward 91 | self.video_recorder.save(f'{self.step}.mp4') 92 | average_episode_reward /= self.cfg.num_eval_episodes 93 | self.logger.log('eval/episode_reward', average_episode_reward, 94 | self.step) 95 | self.logger.dump(self.step) 96 | 97 | def run(self): 98 | episode, episode_reward, done = 0, 0, True 99 | start_time = time.time() 100 | while self.step < 
self.cfg.num_train_steps: 101 | if done: 102 | if self.step > 0: 103 | self.logger.log('train/duration', 104 | time.time() - start_time, self.step) 105 | start_time = time.time() 106 | self.logger.dump( 107 | self.step, save=(self.step > self.cfg.num_seed_steps)) 108 | 109 | # evaluate agent periodically 110 | if self.step > 0 and self.step % self.cfg.eval_frequency == 0: 111 | self.logger.log('eval/episode', episode, self.step) 112 | self.evaluate() 113 | 114 | self.logger.log('train/episode_reward', episode_reward, 115 | self.step) 116 | 117 | obs = self.env.reset() 118 | self.agent.reset() 119 | done = False 120 | episode_reward = 0 121 | episode_step = 0 122 | episode += 1 123 | 124 | self.logger.log('train/episode', episode, self.step) 125 | 126 | # sample action for data collection 127 | if self.step < self.cfg.num_seed_steps: 128 | action = self.env.action_space.sample() 129 | else: 130 | with utils.eval_mode(self.agent): 131 | action = self.agent.act(obs, sample=True) 132 | 133 | # run training update 134 | if self.step >= self.cfg.num_seed_steps: 135 | self.agent.update(self.replay_buffer, self.logger, self.step) 136 | 137 | next_obs, reward, done, _ = self.env.step(action) 138 | 139 | # allow infinite bootstrap 140 | done = float(done) 141 | done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done 142 | episode_reward += reward 143 | 144 | self.replay_buffer.add(obs, action, reward, next_obs, done, 145 | done_no_max) 146 | 147 | obs = next_obs 148 | episode_step += 1 149 | self.step += 1 150 | 151 | 152 | @hydra.main(config_path='config/train.yaml', strict=True) 153 | def main(cfg): 154 | workspace = Workspace(cfg) 155 | workspace.run() 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch import distributions as pyd 5 | import torch.nn.functional as F 6 | import gym 7 | import os 8 | from collections import deque 9 | import random 10 | import math 11 | 12 | import dmc2gym 13 | 14 | 15 | def make_env(cfg): 16 | """Helper function to create dm_control environment""" 17 | if cfg.env == 'ball_in_cup_catch': 18 | domain_name = 'ball_in_cup' 19 | task_name = 'catch' 20 | else: 21 | domain_name = cfg.env.split('_')[0] 22 | task_name = '_'.join(cfg.env.split('_')[1:]) 23 | 24 | env = dmc2gym.make(domain_name=domain_name, 25 | task_name=task_name, 26 | seed=cfg.seed, 27 | visualize_reward=True) 28 | env.seed(cfg.seed) 29 | assert env.action_space.low.min() >= -1 30 | assert env.action_space.high.max() <= 1 31 | 32 | return env 33 | 34 | 35 | class eval_mode(object): 36 | def __init__(self, *models): 37 | self.models = models 38 | 39 | def __enter__(self): 40 | self.prev_states = [] 41 | for model in self.models: 42 | self.prev_states.append(model.training) 43 | model.train(False) 44 | 45 | def __exit__(self, *args): 46 | for model, state in zip(self.models, self.prev_states): 47 | model.train(state) 48 | return False 49 | 50 | 51 | class train_mode(object): 52 | def __init__(self, *models): 53 | self.models = models 54 | 55 | def __enter__(self): 56 | self.prev_states = [] 57 | for model in self.models: 58 | self.prev_states.append(model.training) 59 | model.train(True) 60 | 61 | def __exit__(self, *args): 62 | for model, state in zip(self.models, self.prev_states): 63 | model.train(state) 64 | return False 65 
| 66 | 67 | def soft_update_params(net, target_net, tau): 68 | for param, target_param in zip(net.parameters(), target_net.parameters()): 69 | target_param.data.copy_(tau * param.data + 70 | (1 - tau) * target_param.data) 71 | 72 | def set_seed_everywhere(seed): 73 | torch.manual_seed(seed) 74 | if torch.cuda.is_available(): 75 | torch.cuda.manual_seed_all(seed) 76 | np.random.seed(seed) 77 | random.seed(seed) 78 | 79 | 80 | def make_dir(*path_parts): 81 | dir_path = os.path.join(*path_parts) 82 | try: 83 | os.mkdir(dir_path) 84 | except OSError: 85 | pass 86 | return dir_path 87 | 88 | def weight_init(m): 89 | """Custom weight init for Conv2D and Linear layers.""" 90 | if isinstance(m, nn.Linear): 91 | nn.init.orthogonal_(m.weight.data) 92 | if hasattr(m.bias, 'data'): 93 | m.bias.data.fill_(0.0) 94 | 95 | 96 | class MLP(nn.Module): 97 | def __init__(self, 98 | input_dim, 99 | hidden_dim, 100 | output_dim, 101 | hidden_depth, 102 | output_mod=None): 103 | super().__init__() 104 | self.trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth, 105 | output_mod) 106 | self.apply(weight_init) 107 | 108 | def forward(self, x): 109 | return self.trunk(x) 110 | 111 | 112 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None): 113 | if hidden_depth == 0: 114 | mods = [nn.Linear(input_dim, output_dim)] 115 | else: 116 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] 117 | for i in range(hidden_depth - 1): 118 | mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] 119 | mods.append(nn.Linear(hidden_dim, output_dim)) 120 | if output_mod is not None: 121 | mods.append(output_mod) 122 | trunk = nn.Sequential(*mods) 123 | return trunk 124 | 125 | def to_np(t): 126 | if t is None: 127 | return None 128 | elif t.nelement() == 0: 129 | return np.array([]) 130 | else: 131 | return t.cpu().detach().numpy() 132 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import os 3 | import numpy as np 4 | import sys 5 | 6 | import utils 7 | 8 | class VideoRecorder(object): 9 | def __init__(self, root_dir, height=256, width=256, camera_id=0, fps=30): 10 | self.save_dir = utils.make_dir(root_dir, 'video') if root_dir else None 11 | self.height = height 12 | self.width = width 13 | self.camera_id = camera_id 14 | self.fps = fps 15 | self.frames = [] 16 | 17 | def init(self, enabled=True): 18 | self.frames = [] 19 | self.enabled = self.save_dir is not None and enabled 20 | 21 | def record(self, env): 22 | if self.enabled: 23 | frame = env.render(mode='rgb_array', 24 | height=self.height, 25 | width=self.width, 26 | camera_id=self.camera_id) 27 | self.frames.append(frame) 28 | 29 | def save(self, file_name): 30 | if self.enabled: 31 | path = os.path.join(self.save_dir, file_name) 32 | imageio.mimsave(path, self.frames, fps=self.fps) 33 | --------------------------------------------------------------------------------
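Taken together, `agent/sac.py` performs three coupled updates: a soft Bellman backup for the double-Q critic, a reparameterized policy-gradient step for the actor, and a temperature update that drives the policy entropy toward `-|A|`, plus Polyak averaging of the target critic. The condensed sketch below reproduces those computations on synthetic tensors using the repository's own modules; it is illustrative only, and assumes the conda environment from `conda_env.yml` is active (importing `utils` pulls in `dmc2gym`) and that it is run from the repository root:

```
# Sketch: the SAC losses from agent/sac.py, computed in isolation on fake data.
import torch
import torch.nn.functional as F

from agent.actor import DiagGaussianActor
from agent.critic import DoubleQCritic
import utils

obs_dim, action_dim, batch = 17, 6, 32
actor = DiagGaussianActor(obs_dim, action_dim, hidden_dim=64, hidden_depth=2,
                          log_std_bounds=[-5, 2])
critic = DoubleQCritic(obs_dim, action_dim, hidden_dim=64, hidden_depth=2)
critic_target = DoubleQCritic(obs_dim, action_dim, hidden_dim=64, hidden_depth=2)
critic_target.load_state_dict(critic.state_dict())
log_alpha = torch.zeros(1, requires_grad=True)
target_entropy = -action_dim
discount = 0.99

obs = torch.randn(batch, obs_dim)
action = torch.rand(batch, action_dim) * 2 - 1
reward = torch.randn(batch, 1)
next_obs = torch.randn(batch, obs_dim)
not_done = torch.ones(batch, 1)

# Critic: soft Bellman target uses the entropy-regularized value of the next state.
with torch.no_grad():
    next_dist = actor(next_obs)
    next_action = next_dist.rsample()
    next_log_prob = next_dist.log_prob(next_action).sum(-1, keepdim=True)
    target_Q1, target_Q2 = critic_target(next_obs, next_action)
    target_V = torch.min(target_Q1, target_Q2) - log_alpha.exp() * next_log_prob
    target_Q = reward + not_done * discount * target_V
current_Q1, current_Q2 = critic(obs, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

# Actor: maximize the minimum Q minus the temperature-weighted log-probability.
dist = actor(obs)
pi = dist.rsample()
log_prob = dist.log_prob(pi).sum(-1, keepdim=True)
actor_Q = torch.min(*critic(obs, pi))
actor_loss = (log_alpha.exp().detach() * log_prob - actor_Q).mean()

# Temperature: push the policy entropy toward -|A|.
alpha_loss = (log_alpha.exp() * (-log_prob - target_entropy).detach()).mean()

# Target network: Polyak averaging, as in utils.soft_update_params.
utils.soft_update_params(critic, critic_target, tau=0.005)
print(critic_loss.item(), actor_loss.item(), alpha_loss.item())
```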
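The replay buffer and the training loop share a small but important convention: a transition whose episode ended only because of the time limit stores `done_no_max = 0`, so the critic still bootstraps through time-limit terminations (the "infinite bootstrap" comment in `train.py`), while genuine terminations stop the backup. A short, fully standalone sketch of filling and sampling the buffer with fake transitions (the shapes are arbitrary; `replay_buffer.py` only needs `numpy` and `torch`):

```
# Sketch: filling and sampling the ReplayBuffer with fake transitions.
import numpy as np
import torch

from replay_buffer import ReplayBuffer

obs_shape, action_shape = (17,), (6,)
max_episode_steps = 1000
buffer = ReplayBuffer(obs_shape, action_shape, capacity=10000,
                      device=torch.device('cpu'))

obs = np.random.randn(*obs_shape).astype(np.float32)
for episode_step in range(5):
    action = np.random.uniform(-1, 1, size=action_shape).astype(np.float32)
    next_obs = np.random.randn(*obs_shape).astype(np.float32)
    reward, done = np.random.randn(), float(episode_step == 4)
    # Time-limit truncations keep bootstrapping: done_no_max is forced to 0
    # when the episode ends only because the step cap was reached.
    done_no_max = 0 if episode_step + 1 == max_episode_steps else done
    buffer.add(obs, action, reward, next_obs, done, done_no_max)
    obs = next_obs

obses, actions, rewards, next_obses, not_dones, not_dones_no_max = buffer.sample(4)
print(obses.shape, actions.shape, not_dones_no_max.squeeze(-1))
```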