├── sac ├── __init__.py └── sac_continuos_action.py ├── utils ├── __init__.py ├── nets.py ├── utils.py └── buffer.py ├── dynamics ├── __init__.py ├── util.py └── probabilistic_ensemble.py ├── README.md ├── run.py ├── LICENSE └── .gitignore /sac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAC+ 2 | 3 | Soft Actor-Critic implementation with SOTA model-free extension (REDQ) and SOTA model-based extension (MBPO). 4 | 5 | Algorithms implemented: 6 | 7 | * [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) 8 | * [Model-Based Policy Optimization (MBPO)](https://arxiv.org/pdf/1906.08253.pdf) 9 | * [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf) 10 | 11 | ## Acknowledgements 12 | 13 | I was inspired by the following repositories: 14 | 15 | * https://github.com/vwxyzjn/cleanrl 16 | * https://github.com/DLR-RM/stable-baselines3 17 | * https://github.com/kzl/lifelong_rl -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import numpy as np 4 | import torch as th 5 | from sac.sac_continuos_action import SAC 6 | SEED = 0 7 | random.seed(SEED) 8 | np.random.seed(SEED) 9 | th.manual_seed(SEED) 10 | 11 | 12 | def run(): 13 | 14 | env = gym.make('Hopper-v2') 15 | env.seed(SEED) 16 | 17 | agent = SAC(env, 18 | gradient_updates=20, 19 | num_q_nets=2, 20 | m_sample=None, 21 | buffer_size=int(4e5), 22 | mbpo=False, 23 | experiment_name=f'sac-hopper-{SEED}', 24 | log=True, 25 | wandb=True) 26 | 27 | agent.learn(total_timesteps=175000) 28 | agent.save() 29 | 30 | if __name__ == '__main__': 31 | run() 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Lucas Alegre 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/nets.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type 2 | from torch import nn 3 | 4 | def create_mlp(input_dim: int, output_dim: int, net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU) -> nn.Sequential: 5 | """ 6 | Create a multi-layer perceptron (MLP), i.e. a stack of 7 | fully-connected layers, each followed by an activation function. 8 | 9 | :param input_dim: Dimension of the input vector 10 | :param output_dim: Dimension of the final output layer; 11 | if it is not positive, no final linear layer is added 12 | :param net_arch: Architecture of the neural net. 13 | It represents the number of units per layer. 14 | The length of this list is the number of layers. 15 | :param activation_fn: The activation function 16 | to use after each layer. 17 | :return: An nn.Sequential container with the created 18 | layers 19 | """ 20 | assert len(net_arch) > 0 21 | modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] 22 | 23 | for idx in range(len(net_arch) - 1): 24 | modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) 25 | modules.append(activation_fn()) 26 | 27 | if output_dim > 0: 28 | last_layer_dim = net_arch[-1] 29 | modules.append(nn.Linear(last_layer_dim, output_dim)) 30 | 31 | return nn.Sequential(*modules) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | from torch import nn 5 | from moviepy.editor import ImageSequenceClip 6 | 7 | def layer_init(layer, method='xavier', weight_gain=1, bias_const=0): 8 | if isinstance(layer, nn.Linear): 9 | if method == "xavier": 10 | th.nn.init.xavier_uniform_(layer.weight, gain=weight_gain) 11 | elif method == "orthogonal": 12 | th.nn.init.orthogonal_(layer.weight, gain=weight_gain) 13 | th.nn.init.constant_(layer.bias, bias_const) 14 | 15 | # Code courtesy of JPH: https://github.com/jparkerholder 16 | def make_gif(policy, env, step_count, state_filter, maxsteps=1000): 17 | envname = env.spec.id 18 | gif_name = '_'.join([envname, str(step_count)]) 19 | state = env.reset() 20 | done = False 21 | steps = [] 22 | rewards = [] 23 | t = 0 24 | while (not done) & (t < maxsteps): 25 | s = env.render('rgb_array') 26 | steps.append(s) 27 | action = policy.get_action(state, state_filter=state_filter, deterministic=True) 28 | action = np.clip(action, env.action_space.low[0], env.action_space.high[0]) 29 | action = action.reshape(len(action), ) 30 | state, reward, done, _ = env.step(action) 31 | rewards.append(reward) 32 | t += 1 33 | print('Final reward:', np.sum(rewards)) 34 | clip = ImageSequenceClip(steps, fps=30) 35 | if not os.path.isdir('gifs'): 36 | os.makedirs('gifs') 37 | clip.write_gif('gifs/{}.gif'.format(gif_name), fps=30) -------------------------------------------------------------------------------- /utils/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | 5 | class ReplayBuffer: 6 | 7 | def __init__(self, obs_dim, action_dim, rew_dim=1, max_size=100000): 8 | self.max_size = 
max_size 9 | self.ptr, self.size, = 0, 0 10 | 11 | self.obs = np.zeros((max_size, obs_dim), dtype=np.float32) 12 | self.next_obs = np.zeros((max_size, obs_dim), dtype=np.float32) 13 | self.actions = np.zeros((max_size, action_dim), dtype=np.float32) 14 | self.rewards = np.zeros((max_size, rew_dim), dtype=np.float32) 15 | self.dones = np.zeros((max_size, 1), dtype=np.float32) 16 | 17 | def add(self, obs, action, reward, next_obs, done): 18 | self.obs[self.ptr] = np.array(obs).copy() 19 | self.next_obs[self.ptr] = np.array(next_obs).copy() 20 | self.actions[self.ptr] = np.array(action).copy() 21 | self.rewards[self.ptr] = np.array(reward).copy() 22 | self.dones[self.ptr] = np.array(done).copy() 23 | self.ptr = (self.ptr + 1) % self.max_size 24 | self.size = min(self.size + 1, self.max_size) 25 | 26 | def sample(self, batch_size, replace=True, to_tensor=False, device=None): 27 | inds = np.random.choice(self.size, batch_size, replace=replace) 28 | experience_tuples = (self.obs[inds], self.actions[inds], self.rewards[inds], self.next_obs[inds], self.dones[inds]) 29 | if to_tensor: 30 | return tuple(map(lambda x: th.tensor(x).to(device), experience_tuples)) 31 | else: 32 | return experience_tuples 33 | 34 | def sample_obs(self, batch_size, replace=True, to_tensor=False, device=None): 35 | inds = np.random.choice(self.size, batch_size, replace=replace) 36 | if to_tensor: 37 | return th.tensor(self.obs[inds]).to(device) 38 | else: 39 | return self.obs[inds] 40 | 41 | def get_all_data(self): 42 | inds = np.arange(self.size) 43 | return self.obs[inds], self.actions[inds], self.rewards[inds], self.next_obs[inds], self.dones[inds] 44 | 45 | def __len__(self): 46 | return self.size -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mine 132 | .vscode 133 | /weights 134 | /wandb 135 | /runs 136 | /old_logs -------------------------------------------------------------------------------- /dynamics/util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | 5 | def termination_fn_false(obs, act, next_obs): 6 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 7 | done = np.array([False]).repeat(len(obs)) 8 | done = done[:,None] 9 | return done 10 | 11 | def termination_fn_hopper(obs, act, next_obs): 12 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 13 | height = next_obs[:, 0] 14 | angle = next_obs[:, 1] 15 | not_done = np.isfinite(next_obs).all(axis=-1) \ 16 | * np.abs(next_obs[:,1:] < 100).all(axis=-1) \ 17 | * (height > .7) \ 18 | * (np.abs(angle) < .2) 19 | done = ~not_done 20 | done = done[:,None] 21 | return done 22 | 23 | class FakeEnv: 24 | 25 | def __init__(self, model, env_id=None): 26 | self.model = model 27 | if env_id == 'Hopper-v2': 28 | self.termination_func = termination_fn_hopper 29 | elif env_id == 'HalfCheetah-v2': 30 | self.termination_func = termination_fn_false 31 | else: 32 | raise NotImplementedError 33 | 34 | def _get_logprob(self, x, means, variances): 35 | ''' 36 | x : [ batch_size, obs_dim + 1 ] 37 | means : [ num_models, batch_size, obs_dim + 1 ] 38 | vars : [ num_models, batch_size, obs_dim + 1 ] 39 | ''' 40 | k = x.shape[-1] 41 | 42 | ## [ num_networks, batch_size ] 43 | log_prob = -1/2 * (k * np.log(2*np.pi) + np.log(variances).sum(-1) + (np.power(x-means, 2)/variances).sum(-1)) 44 | 45 | ## [ batch_size ] 46 | prob = np.exp(log_prob).sum(0) 47 | 48 | ## [ batch_size ] 49 | log_prob = np.log(prob) 50 | 51 | #stds = np.std(means,0).mean(-1) 52 | #var_mean = np.var(means, axis=0, ddof=1).mean(axis=-1) 53 | maxes = np.max(np.linalg.norm(variances, axis=-1), axis=0) 54 | 55 | return log_prob, maxes 56 | 57 | def step(self, obs, act, deterministic=False): 58 | assert len(obs.shape) == len(act.shape) 59 | """ if len(obs.shape) == 1: 60 | obs = obs[None] 61 | act = act[None] 62 | return_single = True 63 | else: 64 | return_single = False """ 65 | 66 | inputs = th.cat((obs, act), dim=-1).float().to(self.model.device) 67 | with th.no_grad(): 68 | samples, ensemble_model_means, ensemble_model_logvars = self.model(inputs, deterministic=False, return_dist=True) 69 | obs = obs.detach().cpu().numpy() 70 | samples = samples.detach().cpu().numpy() 71 | #ensemble_model_means = ensemble_model_means.detach().cpu().numpy() 72 | #ensemble_model_logvars = ensemble_model_logvars.detach().cpu().numpy() 73 | 
#ensemble_model_vars = np.exp(ensemble_model_logvars) 74 | 75 | #ensemble_model_means[:,:,1:] += obs 76 | samples[:,:,1:] += obs 77 | #ensemble_model_stds = np.sqrt(ensemble_model_vars) 78 | 79 | #### choose one model from ensemble 80 | num_models, batch_size, _ = ensemble_model_means.shape 81 | model_inds = np.random.choice(self.model.elites, size=batch_size) 82 | batch_inds = np.arange(0, batch_size) 83 | samples = samples[model_inds, batch_inds] 84 | #model_means = ensemble_model_means[model_inds, batch_inds] 85 | #model_stds = ensemble_model_stds[model_inds, batch_inds] 86 | 87 | #log_prob, dev = self._get_logprob(samples, ensemble_model_means, ensemble_model_vars) 88 | 89 | rewards, next_obs = samples[:,:1], samples[:,1:] 90 | terminals = self.termination_func(obs, act, next_obs) 91 | 92 | #batch_size = model_means.shape[0] 93 | #return_means = np.concatenate((model_means[:,:1], terminals, model_means[:,1:]), axis=-1) 94 | #return_stds = np.concatenate((model_stds[:,:1], np.zeros((batch_size,1)), model_stds[:,1:]), axis=-1) 95 | 96 | """ if return_single: 97 | next_obs = next_obs[0] 98 | return_means = return_means[0] 99 | return_stds = return_stds[0] 100 | rewards = rewards[0] 101 | terminals = terminals[0] """ 102 | 103 | #info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev} 104 | return next_obs, rewards, terminals, {} 105 | -------------------------------------------------------------------------------- /dynamics/probabilistic_ensemble.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as th 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import pickle 7 | 8 | #TODO: 9 | # - Better to predict logvar or logstd? 10 | # - Learn logvar or keep it constant? 11 | # - Holdout loss: best ratio? save best checkpoint in epoch? individual improvement? 
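# Note: the EnsembleLayer below vectorizes an ensemble of independent linear layers as one batched matmul:
# W has shape (ensemble_size, input_dim, output_dim) and b has shape (ensemble_size, 1, output_dim), so a 3D
# input of shape (ensemble_size, batch_size, input_dim) is transformed by every ensemble member via x @ W + b.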
12 | 13 | class EnsembleLayer(nn.Module): 14 | 15 | def __init__(self, ensemble_size, input_dim, output_dim): 16 | super().__init__() 17 | self.W = nn.Parameter(th.empty((ensemble_size, input_dim, output_dim)), requires_grad=True).float() 18 | nn.init.xavier_uniform_(self.W, gain=nn.init.calculate_gain('relu')) 19 | self.b = nn.Parameter(th.zeros((ensemble_size, 1, output_dim)), requires_grad=True).float() 20 | 21 | def forward(self, x): 22 | # assumes x is 3D: (ensemble_size, batch_size, dimension) 23 | return x @ self.W + self.b 24 | 25 | class ProbabilisticEnsemble(nn.Module): 26 | 27 | def __init__(self, input_dim, output_dim, ensemble_size=5, arch=(200,200,200,200), activation=F.relu, learning_rate=0.001, num_elites=2, device='auto'): 28 | super().__init__() 29 | 30 | self.ensemble_size = ensemble_size 31 | self.input_dim = input_dim 32 | self.output_dim = output_dim * 2 # mean and std 33 | self.activation = activation 34 | self.arch = arch 35 | self.num_elites = num_elites 36 | self.elites = [i for i in range(self.ensemble_size)] 37 | 38 | self.layers = nn.ModuleList() 39 | in_size = input_dim 40 | for hidden_size in self.arch: 41 | self.layers.append(EnsembleLayer(ensemble_size, in_size, hidden_size)) 42 | in_size = hidden_size 43 | self.layers.append(EnsembleLayer(ensemble_size, self.arch[-1], self.output_dim)) 44 | 45 | self.inputs_mu = nn.Parameter(th.zeros(input_dim), requires_grad=False).float() 46 | self.inputs_sigma = nn.Parameter(th.zeros(input_dim), requires_grad=False).float() 47 | 48 | self.max_logvar = nn.Parameter(th.ones(1, output_dim, dtype=th.float32) / 2.0).float() 49 | self.min_logvar = nn.Parameter(-th.ones(1, output_dim, dtype=th.float32) * 10.0).float() 50 | 51 | self.decays = [0.000025, 0.00005, 0.000075, 0.000075, 0.0001] 52 | self.optim = th.optim.Adam([{'params': self.layers[i].parameters(), 'weight_decay': self.decays[i]} for i in range(len(self.layers))] + 53 | [{'params': self.max_logvar}, {'params': self.min_logvar}], lr=learning_rate) 54 | if device == 'auto': 55 | self.device = th.device('cuda') if th.cuda.is_available() else th.device('cpu') 56 | else: 57 | self.device = device 58 | self.to(self.device) 59 | 60 | def forward(self, input, deterministic=False, return_dist=False): 61 | dim = len(input.shape) 62 | # input normalization 63 | h = (input - self.inputs_mu) / self.inputs_sigma 64 | # repeat h to make amenable to parallelization 65 | # if dim = 3, then we probably already did this somewhere else (e.g. 
bootstrapping in training optimization) 66 | if dim < 3: 67 | h = h.unsqueeze(0) 68 | if dim == 1: 69 | h = h.unsqueeze(0) 70 | h = h.repeat(self.ensemble_size, 1, 1) 71 | 72 | for layer in self.layers[:-1]: 73 | h = layer(h) 74 | h = self.activation(h) 75 | output = self.layers[-1](h) 76 | 77 | # if original dim was 1D, squeeze the extra created layer 78 | if dim == 1: 79 | output = output.squeeze(1) # output is (ensemble_size, output_size) 80 | 81 | mean, logvar = th.chunk(output, 2, dim=-1) 82 | 83 | # Variance clamping to prevent poor numerical predictions 84 | logvar = self.max_logvar - F.softplus(self.max_logvar - logvar) 85 | logvar = self.min_logvar + F.softplus(logvar - self.min_logvar) 86 | 87 | if deterministic: 88 | if return_dist: 89 | return mean, logvar 90 | else: 91 | return mean 92 | else: 93 | std = th.sqrt(th.exp(logvar)) 94 | samples = mean + std * th.randn(std.shape, device=std.device) 95 | if return_dist: 96 | return samples, mean, logvar 97 | else: 98 | return samples 99 | 100 | def compute_loss(self, x, y): 101 | mean, logvar = self.forward(x, deterministic=True, return_dist=True) 102 | inv_var = th.exp(-logvar) 103 | 104 | if len(y.shape) < 3: 105 | y = y.unsqueeze(0).repeat(self.ensemble_size, 1, 1) 106 | 107 | mse_losses = (th.square(mean - y) * inv_var).mean(-1).mean(-1) 108 | var_losses = logvar.mean(-1).mean(-1) 109 | total_losses = (mse_losses + var_losses).sum() 110 | total_losses += 0.01*self.max_logvar.sum() - 0.01*self.min_logvar.sum() 111 | return total_losses 112 | 113 | def compute_mse_losses(self, x, y): 114 | mean = self.forward(x, deterministic=True, return_dist=False) 115 | if len(y.shape) < 3: 116 | y = y.unsqueeze(0).repeat(self.ensemble_size, 1, 1) 117 | mse_losses = (mean - y)**2 118 | return mse_losses.mean(-1).mean(-1) 119 | 120 | def save(self, path): 121 | save_dir = 'weights/' 122 | if not os.path.isdir(save_dir): 123 | os.makedirs(save_dir) 124 | th.save({'ensemble_state_dict': self.state_dict(), 125 | 'ensemble_optimizer_state_dict': self.optim.state_dict()}, path + '.tar') 126 | 127 | def load(self, path): 128 | params = th.load(path) 129 | self.load_state_dict(params['ensemble_state_dict']) 130 | self.optim.load_state_dict(params['ensemble_optimizer_state_dict']) 131 | 132 | def fit_input_stats(self, data): 133 | mu = np.mean(data, axis=0, keepdims=True) 134 | sigma = np.std(data, axis=0, keepdims=True) 135 | sigma[sigma < 1e-12] = 1.0 136 | self.inputs_mu.data = th.from_numpy(mu).to(self.device).float() # Can I ommit .data? 
137 | self.inputs_sigma.data = th.from_numpy(sigma).to(self.device).float() 138 | 139 | def train_ensemble(self, X, Y, batch_size=256, holdout_ratio=0.1, max_holdout_size=5000, max_epochs_no_improvement=5, max_epochs=200): 140 | self.fit_input_stats(X) 141 | 142 | num_holdout = min(int(X.shape[0] * holdout_ratio), max_holdout_size) 143 | permutation = np.random.permutation(X.shape[0]) 144 | inputs, holdout_inputs = X[permutation[num_holdout:]], X[permutation[:num_holdout]] 145 | targets, holdout_targets = Y[permutation[num_holdout:]], Y[permutation[:num_holdout]] 146 | holdout_inputs = th.from_numpy(holdout_inputs).to(self.device).float() 147 | holdout_targets = th.from_numpy(holdout_targets).to(self.device).float() 148 | 149 | idxs = np.random.randint(inputs.shape[0], size=[self.ensemble_size, inputs.shape[0]]) 150 | num_batches = int(np.ceil(idxs.shape[-1] / batch_size)) 151 | 152 | def shuffle_rows(arr): 153 | idxs = np.argsort(np.random.uniform(size=arr.shape), axis=-1) 154 | return arr[np.arange(arr.shape[0])[:, None], idxs] 155 | 156 | num_epochs_no_improvement = 0 157 | epoch = 0 158 | best_holdout_losses = [float('inf') for _ in range(self.ensemble_size)] 159 | while num_epochs_no_improvement < max_epochs_no_improvement and epoch < max_epochs: 160 | self.train() 161 | for batch_num in range(num_batches): 162 | batch_idxs = idxs[:, batch_num * batch_size : (batch_num + 1) * batch_size] 163 | batch_x, batch_y = inputs[batch_idxs], targets[batch_idxs] 164 | batch_x, batch_y = th.from_numpy(batch_x).to(self.device).float(), th.from_numpy(batch_y).to(self.device).float() 165 | 166 | loss = self.compute_loss(batch_x, batch_y) 167 | self.optim.zero_grad() 168 | loss.backward() 169 | self.optim.step() 170 | 171 | idxs = shuffle_rows(idxs) 172 | 173 | self.eval() 174 | with th.no_grad(): 175 | holdout_losses = self.compute_mse_losses(holdout_inputs, holdout_targets) 176 | holdout_losses = [l.item() for l in holdout_losses] 177 | #print('Epoch:', epoch, 'Holdout losses:', [l.item() for l in holdout_losses]) 178 | 179 | self.elites = np.argsort(holdout_losses)[:self.num_elites] 180 | 181 | improved = False 182 | for i in range(self.ensemble_size): 183 | if epoch == 0 or (best_holdout_losses[i] - holdout_losses[i]) / (best_holdout_losses[i]) > 0.01: 184 | best_holdout_losses[i] = holdout_losses[i] 185 | num_epochs_no_improvement = 0 186 | improved = True 187 | if not improved: 188 | num_epochs_no_improvement += 1 189 | 190 | epoch += 1 191 | 192 | print('Epoch:', epoch, 'Holdout losses:', ', '.join(["%.4f"%hl for hl in holdout_losses])) 193 | return np.mean(holdout_losses) 194 | 195 | if __name__ == '__main__': 196 | 197 | with open('/home/lucas/Desktop/drl-cd/weights/drlcd-cheetah-ns-paper1data0', 'rb') as f: 198 | memory = pickle.load(f) 199 | X, Y = memory.to_train_batch() 200 | 201 | model = ProbabilisticEnsemble(X.shape[1], Y.shape[1]) 202 | model.train_ensemble(X, Y, max_epochs=200) -------------------------------------------------------------------------------- /sac/sac_continuos_action.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Union 3 | from utils.buffer import ReplayBuffer 4 | import torch as th 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.distributions.categorical import Categorical 9 | from torch.distributions.normal import Normal 10 | from torch.utils.tensorboard import SummaryWriter 11 | import numpy as np 12 | import time 
13 | from utils.nets import create_mlp 14 | from utils.utils import layer_init 15 | from dynamics.util import FakeEnv 16 | from dynamics.probabilistic_ensemble import ProbabilisticEnsemble 17 | 18 | LOG_STD_MAX = 2 19 | LOG_STD_MIN = -20 20 | 21 | 22 | class Policy(nn.Module): 23 | def __init__(self, input_dim, output_dim, action_space, net_arch=[256,256]): 24 | super(Policy, self).__init__() 25 | self.latent_pi = create_mlp(input_dim, -1, net_arch) 26 | self.mean = nn.Linear(net_arch[-1], output_dim) 27 | self.logstd = nn.Linear(net_arch[-1], output_dim) 28 | self.action_low = th.FloatTensor(action_space.low) 29 | self.action_high = th.FloatTensor(action_space.high) 30 | self.apply(layer_init) 31 | 32 | def action_dist(self, obs): 33 | h = self.latent_pi(obs) 34 | mean = self.mean(h) 35 | log_std = self.logstd(h) 36 | log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) 37 | return mean, log_std 38 | 39 | def scale_action(self, action): 40 | return 2.0 * ((action - self.action_low) / (self.action_high - self.action_low)) - 1.0 41 | 42 | def unscale_action(self, scaled_action): 43 | return self.action_low + (0.5 * (scaled_action + 1.0) * (self.action_high - self.action_low)) 44 | 45 | def forward(self, obs, deterministic=False): 46 | mean, log_std = self.action_dist(obs) 47 | if deterministic: 48 | return th.tanh(mean) 49 | normal = Normal(mean, log_std.exp()) 50 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 51 | action = th.tanh(x_t) 52 | return action 53 | 54 | def action_log_prob(self, obs): 55 | mean, log_std = self.action_dist(obs) 56 | normal = Normal(mean, log_std.exp()) 57 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 58 | action = th.tanh(x_t) 59 | log_prob = normal.log_prob(x_t).sum(dim=1) 60 | log_prob -= th.log((1 - action.pow(2)) + 1e-6).sum(dim=1) 61 | return action, log_prob 62 | 63 | def to(self, device): 64 | self.action_low = self.action_low.to(device) 65 | self.action_high = self.action_high.to(device) 66 | return super(Policy, self).to(device) 67 | 68 | class SoftQNetwork(nn.Module): 69 | def __init__(self, input_dim, net_arch=[256,256]): 70 | super(SoftQNetwork, self).__init__() 71 | self.net = create_mlp(input_dim, 1, net_arch) 72 | self.apply(layer_init) 73 | 74 | def forward(self, input): 75 | q_value = self.net(input) 76 | return q_value 77 | 78 | 79 | class SAC: 80 | # TODO: 81 | # scale action 82 | # load save 83 | def __init__(self, 84 | env, 85 | learning_rate: float = 3e-4, 86 | tau: float = 0.005, 87 | buffer_size: int = 1e6, 88 | alpha: Union[float, str] = 'auto', 89 | net_arch: List = [256, 256], 90 | batch_size: int = 256, 91 | num_q_nets: int = 2, 92 | m_sample: int = None, # None == SAC, 2 == REDQ 93 | learning_starts: int = 100, 94 | gradient_updates: int = 1, 95 | gamma: float = 0.99, 96 | mbpo: bool = False, 97 | dynamics_rollout_len: int = 1, 98 | rollout_dynamics_starts: int = 5000, 99 | real_ratio: float = 0.05, 100 | project_name: str = 'sac', 101 | experiment_name: Optional[str] = None, 102 | log: bool = True, 103 | wandb: bool = True, 104 | device: Union[th.device, str] = 'auto'): 105 | 106 | self.env = env 107 | self.observation_dim = self.env.observation_space.shape[0] 108 | self.action_dim = self.env.action_space.shape[0] 109 | self.learning_rate = learning_rate 110 | self.tau = tau 111 | self.gamma = gamma 112 | self.buffer_size = buffer_size 113 | self.num_q_nets = num_q_nets 114 | self.m_sample = m_sample 115 | self.net_arch = net_arch 116 | self.learning_starts = 
learning_starts 117 | self.batch_size = batch_size 118 | self.gradient_updates = gradient_updates 119 | self.device = th.device('cuda' if th.cuda.is_available() else 'cpu') if device == 'auto' else device 120 | self.replay_buffer = ReplayBuffer(self.observation_dim, self.action_dim, max_size=buffer_size) 121 | 122 | self.q_nets = [SoftQNetwork(self.observation_dim+self.action_dim, net_arch=net_arch).to(self.device) for _ in range(num_q_nets)] 123 | self.target_q_nets = [SoftQNetwork(self.observation_dim+self.action_dim, net_arch=net_arch).to(self.device) for _ in range(num_q_nets)] 124 | for q_net, target_q_net in zip(self.q_nets, self.target_q_nets): 125 | target_q_net.load_state_dict(q_net.state_dict()) 126 | for param in target_q_net.parameters(): 127 | param.requires_grad = False 128 | 129 | self.policy = Policy(self.observation_dim, self.action_dim, self.env.action_space, net_arch=net_arch).to(self.device) 130 | 131 | self.target_entropy = -th.prod(th.Tensor(self.env.action_space.shape)).item() 132 | if alpha == 'auto': 133 | self.log_alpha = th.zeros(1, requires_grad=True, device=self.device) 134 | self.alpha = self.log_alpha.exp().item() 135 | self.alpha_optim = optim.Adam([self.log_alpha], lr=self.learning_rate) 136 | else: 137 | self.alpha_optim = None 138 | self.alpha = alpha 139 | 140 | q_net_params = [] 141 | for q_net in self.q_nets: 142 | q_net_params += list(q_net.parameters()) 143 | self.q_optim = optim.Adam(q_net_params, lr=self.learning_rate) 144 | self.policy_optim = optim.Adam(list(self.policy.parameters()), lr=self.learning_rate) 145 | 146 | self.mbpo = mbpo 147 | if self.mbpo: 148 | self.dynamics = ProbabilisticEnsemble(input_dim=self.observation_dim + self.action_dim, 149 | output_dim=self.observation_dim + 1, 150 | device=self.device) 151 | self.dynamics_buffer = ReplayBuffer(self.observation_dim, 152 | self.action_dim, 153 | max_size=400000) 154 | self.dynamics_rollout_len = dynamics_rollout_len 155 | self.rollout_dynamics_starts = rollout_dynamics_starts 156 | self.real_ratio = real_ratio 157 | 158 | self.experiment_name = experiment_name if experiment_name is not None else f"sac_{int(time.time())}" 159 | self.log = log 160 | if self.log: 161 | self.writer = SummaryWriter(f"runs/{self.experiment_name}") 162 | if wandb: 163 | import wandb 164 | wandb.init(project=project_name, sync_tensorboard=True, config=self.get_config(), name=self.experiment_name, monitor_gym=True, save_code=True) 165 | self.writer = SummaryWriter(f"/tmp/{self.experiment_name}") 166 | 167 | def get_config(self): 168 | return {'env_id': self.env.unwrapped.spec.id, 169 | 'learning_rate': self.learning_rate, 170 | 'num_q_nets': self.num_q_nets, 171 | 'batch_size': self.batch_size, 172 | 'tau': self.tau, 173 | 'gamma': self.gamma, 174 | 'net_arch': self.net_arch, 175 | 'gradient_updates': self.gradient_updates, 176 | 'm_sample': self.m_sample, 177 | 'buffer_size': self.buffer_size, 178 | 'learning_starts': self.learning_starts, 179 | 'mbpo': self.mbpo, 180 | 'dynamics_rollout_len': self.dynamics_rollout_len} 181 | 182 | def save(self, save_replay_buffer=True): 183 | save_dir = 'weights/' 184 | if not os.path.isdir(save_dir): 185 | os.makedirs(save_dir) 186 | 187 | saved_params = {'policy_state_dict': self.policy.state_dict(), 188 | 'policy_optimizer_state_dict': self.policy_optim.state_dict(), 189 | 'log_alpha': self.log_alpha, 190 | 'alpha_optimizer_state_dict': self.alpha_optim.state_dict()} 191 | for i, (q_net, target_q_net) in enumerate(zip(self.q_nets, self.target_q_nets)): 192 | 
saved_params['q_net_'+str(i)+'_state_dict'] = q_net.state_dict() 193 | saved_params['target_q_net_'+str(i)+'_state_dict'] = target_q_net.state_dict() 194 | saved_params['q_nets_optimizer_state_dict'] = self.q_optim.state_dict() 195 | 196 | if save_replay_buffer: 197 | saved_params['replay_buffer'] = self.replay_buffer 198 | 199 | th.save(saved_params, save_dir + "/" + self.experiment_name + '.tar') 200 | 201 | def load(self, path, load_replay_buffer=True): 202 | params = th.load(path) 203 | self.policy.load_state_dict(params['policy_state_dict']) 204 | self.policy_optim.load_state_dict(params['policy_optimizer_state_dict']) 205 | self.log_alpha = params['log_alpha'] 206 | self.alpha_optim.load_state_dict(params['alpha_optimizer_state_dict']) 207 | for i, (q_net, target_q_net) in enumerate(zip(self.q_nets, self.target_q_nets)): 208 | q_net.load_state_dict(params['q_net_'+str(i)+'_state_dict']) 209 | target_q_net.load_state_dict(params['target_q_net_'+str(i)+'_state_dict']) 210 | self.q_optim.load_state_dict(params['q_nets_optimizer_state_dict']) 211 | if load_replay_buffer and 'replay_buffer' in params: 212 | self.replay_buffer = params['replay_buffer'] 213 | 214 | def sample_batch_experiences(self): 215 | if not self.mbpo or self.num_timesteps < self.rollout_dynamics_starts: 216 | return self.replay_buffer.sample(self.batch_size, to_tensor=True, device=self.device) 217 | else: 218 | num_real_samples = int(self.batch_size * self.real_ratio) # fraction of real environment samples (default 5%) 219 | s_obs, s_actions, s_rewards, s_next_obs, s_dones = self.replay_buffer.sample(num_real_samples, to_tensor=True, device=self.device) 220 | m_obs, m_actions, m_rewards, m_next_obs, m_dones = self.dynamics_buffer.sample(self.batch_size-num_real_samples, to_tensor=True, device=self.device) 221 | experience_tuples = (th.cat([s_obs, m_obs], dim=0), 222 | th.cat([s_actions, m_actions], dim=0), 223 | th.cat([s_rewards, m_rewards], dim=0), 224 | th.cat([s_next_obs, m_next_obs], dim=0), 225 | th.cat([s_dones, m_dones], dim=0)) 226 | return experience_tuples 227 | 228 | def rollout_dynamics(self): 229 | # MBPO Planning 230 | with th.no_grad(): 231 | for _ in range(4): # 4 samples of 25000 instead of 1 of 100000 to not allocate all gpu memory 232 | obs = self.replay_buffer.sample_obs(25000, to_tensor=True, device=self.device) 233 | fake_env = FakeEnv(self.dynamics, self.env.unwrapped.spec.id) 234 | for plan_step in range(self.dynamics_rollout_len): 235 | actions = self.policy(obs, deterministic=False) 236 | 237 | next_obs_pred, r_pred, dones, info = fake_env.step(obs, actions) 238 | obs, actions = obs.detach().cpu().numpy(), actions.detach().cpu().numpy() 239 | 240 | for i in range(len(obs)): 241 | self.dynamics_buffer.add(obs[i], actions[i], r_pred[i], next_obs_pred[i], dones[i]) 242 | 243 | nonterm_mask = ~dones.squeeze(-1) 244 | if nonterm_mask.sum() == 0: 245 | break 246 | 247 | obs = th.tensor(next_obs_pred[nonterm_mask], dtype=th.float32, device=self.device) # back to a tensor for the next rollout step 248 | 249 | @property 250 | def dynamics_train_freq(self): 251 | if self.num_timesteps < 100000: 252 | return 250 253 | else: 254 | return 1000 255 | 256 | def train(self): 257 | for _ in range(self.gradient_updates): 258 | s_obs, s_actions, s_rewards, s_next_obs, s_dones = self.sample_batch_experiences() 259 | 260 | with th.no_grad(): 261 | next_actions, log_probs = self.policy.action_log_prob(s_next_obs) 262 | q_input = th.cat([s_next_obs, next_actions], dim=1) 263 | if self.m_sample is not None: # REDQ sampling 264 | q_targets = th.cat([q_target(q_input) for q_target in np.random.choice(self.target_q_nets, self.m_sample, 
replace=False)], dim=1) 265 | else: 266 | q_targets = th.cat([q_target(q_input) for q_target in self.target_q_nets], dim=1) 267 | 268 | target_q, _ = th.min(q_targets, dim=1, keepdim=True) 269 | target_q -= self.alpha * log_probs.reshape(-1, 1) 270 | target_q = s_rewards + (1 - s_dones) * self.gamma * target_q 271 | 272 | sa = th.cat([s_obs, s_actions], dim=1) 273 | q_values = [q_net(sa) for q_net in self.q_nets] 274 | critic_loss = (1/self.num_q_nets) * sum([F.mse_loss(q_value, target_q) for q_value in q_values]) 275 | 276 | self.q_optim.zero_grad() 277 | critic_loss.backward() 278 | self.q_optim.step() 279 | 280 | # Polyak update 281 | for q_net, target_q_net in zip(self.q_nets, self.target_q_nets): 282 | for param, target_param in zip(q_net.parameters(), target_q_net.parameters()): 283 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 284 | 285 | # Policy update 286 | actions, log_pi = self.policy.action_log_prob(s_obs) 287 | sa = th.cat([s_obs, actions], dim=1) 288 | q_values_pi = th.cat([q_net(sa) for q_net in self.q_nets], dim=1) 289 | if self.m_sample is not None: 290 | min_q_value_pi = th.mean(q_values_pi, dim=1, keepdim=True) 291 | else: 292 | min_q_value_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) 293 | policy_loss = (self.alpha * log_pi - min_q_value_pi).mean() 294 | 295 | self.policy_optim.zero_grad() 296 | policy_loss.backward() 297 | self.policy_optim.step() 298 | 299 | # Automatic temperature learning 300 | if self.alpha_optim is not None: 301 | alpha_loss = (-self.log_alpha * (log_pi.detach() + self.target_entropy)).mean() 302 | self.alpha_optim.zero_grad() 303 | alpha_loss.backward() 304 | self.alpha_optim.step() 305 | self.alpha = self.log_alpha.exp().item() 306 | 307 | # Log losses 308 | if self.log and self.num_timesteps % 100 == 0: 309 | self.writer.add_scalar("losses/critic_loss", critic_loss.item(), self.num_timesteps) 310 | self.writer.add_scalar("losses/policy_loss", policy_loss.item(), self.num_timesteps) 311 | self.writer.add_scalar("losses/alpha", self.alpha, self.num_timesteps) 312 | if self.alpha_optim is not None: 313 | self.writer.add_scalar("losses/alpha_loss", alpha_loss.item(), self.num_timesteps) 314 | 315 | def learn(self, total_timesteps): 316 | episode_reward = 0.0 317 | num_episodes = 0 318 | obs, done = self.env.reset(), False 319 | self.num_timesteps = 0 320 | for step in range(1, total_timesteps+1): 321 | self.num_timesteps += 1 322 | 323 | if step < self.learning_starts: 324 | action = self.env.action_space.sample() 325 | else: 326 | with th.no_grad(): 327 | action = self.policy(th.tensor(obs).float().to(self.device)).detach().cpu().numpy() 328 | 329 | next_obs, reward, done, info = self.env.step(action) 330 | 331 | terminal = done if 'TimeLimit.truncated' not in info else not info['TimeLimit.truncated'] 332 | self.replay_buffer.add(obs, action, reward, next_obs, terminal) 333 | 334 | if step >= self.learning_starts: 335 | if self.mbpo: 336 | if self.num_timesteps % self.dynamics_train_freq == 0: 337 | m_obs, m_actions, m_rewards, m_next_obs, m_dones = self.replay_buffer.get_all_data() 338 | X = np.hstack((m_obs, m_actions)) 339 | Y = np.hstack((m_rewards, m_next_obs - m_obs)) 340 | mean_holdout_loss = self.dynamics.train_ensemble(X, Y) 341 | self.writer.add_scalar("dynamics/mean_holdout_loss", mean_holdout_loss, self.num_timesteps) 342 | 343 | if self.num_timesteps >= self.rollout_dynamics_starts and self.num_timesteps % 250 == 0: 344 | self.rollout_dynamics() 345 | 346 | self.train() 347 | 348 | 
episode_reward += reward 349 | if done: 350 | obs, done = self.env.reset(), False 351 | num_episodes += 1 352 | 353 | if num_episodes % 10 == 0: 354 | print(f"Episode: {num_episodes} Step: {step}, Ep. Reward: {episode_reward}") 355 | if self.log: 356 | self.writer.add_scalar("metrics/episode_reward", episode_reward, self.num_timesteps) 357 | 358 | episode_reward = 0.0 359 | else: 360 | obs = next_obs 361 | 362 | if self.log: 363 | self.writer.close() 364 | self.env.close() --------------------------------------------------------------------------------
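The three algorithms listed in the README map directly onto `SAC` constructor flags: `m_sample=None` runs vanilla SAC, `m_sample=2` together with a larger `num_q_nets` and more `gradient_updates` gives REDQ-style randomized target sampling, and `mbpo=True` enables the model-based rollouts generated by `rollout_dynamics()`. Below is a minimal sketch of the three configurations, assuming the same `Hopper-v2` setup as `run.py`; the hyperparameter values (e.g. `num_q_nets=10`, `dynamics_rollout_len=1`) are illustrative, not tuned settings from this repository.

import gym
from sac.sac_continuos_action import SAC

env = gym.make('Hopper-v2')

# Vanilla SAC: 2 Q-networks, target = min over all of them, 1 gradient update per env step.
sac_agent = SAC(env, num_q_nets=2, m_sample=None, gradient_updates=1, mbpo=False)

# REDQ-style: larger Q-ensemble, targets taken from a random subset of size m_sample,
# and a high update-to-data ratio (illustrative values).
redq_agent = SAC(env, num_q_nets=10, m_sample=2, gradient_updates=20, mbpo=False)

# MBPO-style: learn a probabilistic dynamics ensemble and train mostly on short
# model rollouts branched from real states (illustrative values).
mbpo_agent = SAC(env, mbpo=True, dynamics_rollout_len=1,
                 rollout_dynamics_starts=5000, real_ratio=0.05)

sac_agent.learn(total_timesteps=175000)
sac_agent.save()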