├── .gitignore ├── LICENSE ├── README.md ├── algorithms ├── __init__.py ├── base_agent.py ├── bc_agent.py ├── dac_agent.py ├── dataset.py ├── ddpg_agent.py ├── expert_dataset.py ├── gail_agent.py ├── ppo_agent.py ├── rollouts.py └── sac_agent.py ├── config └── __init__.py ├── environments ├── __init__.py └── test_env.py ├── main.py ├── networks ├── __init__.py ├── actor_critic.py ├── discriminator.py ├── distributions.py ├── encoder.py └── utils.py ├── requirements.txt ├── run.py ├── trainer.py └── utils ├── __init__.py ├── gym_env.py ├── info_dict.py ├── logger.py ├── mpi.py ├── normalizer.py ├── pytorch.py ├── subproc_vec_env.py └── vec_env.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac OS 2 | .DS_Store 3 | *~ 4 | .python-version 5 | 6 | # Log 7 | log 8 | wandb 9 | 10 | # Mujoco 11 | MUJOCO_LOG.TXT 12 | 13 | # Vim 14 | .*.s[a-w][a-z] 15 | Session.vim 16 | 17 | ## VSCode 18 | .vscode 19 | 20 | # Data 21 | *.csv 22 | *.ini 23 | *.npy 24 | *.mp4 25 | *.zip 26 | *.hdf5 27 | *screenlog* 28 | 29 | # Python 30 | ## Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | ## C extensions 36 | #*.so 37 | 38 | ## Distribution / packaging 39 | .Python 40 | build/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | 45 | ## Rope project settings 46 | .ropeproject 47 | 48 | ## Jupyter notebook 49 | **/*.ipynb 50 | .ipynb_checkpoints 51 | 52 | ## flake8 53 | .flake8 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Youngwoon Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robot Learning Framework for Research 2 | 3 | 4 | ## RL algorithms 5 | * PPO 6 | * DDPG 7 | * TD3 8 | * SAC 9 | 10 | 11 | ## IL algorithms 12 | * BC 13 | * GAIL 14 | * DAC 15 | 16 | 17 | ## Directories 18 | * `run.py`: simply launches `main.py` 19 | * `main.py`: sets up experiment and runs training using `trainer.py` 20 | * `trainer.py`: contains training and evaluation code 21 | * `algorithms/`: implementation of all RL and IL algorithms 22 | * `config/`: hyper-parameters in `config/__init__.py` 23 | * `environments/`: registers environments (OpenAI Gym and Deepmind Control Suite) 24 | * `networks/`: implementation of networks, such as policy and value function 25 | * `utils/`: contains helper functions 26 | 27 | 28 | ## Prerequisites 29 | * Ubuntu 18.04 or above 30 | * Python 3.6 31 | * Mujoco 2.0 32 | 33 | 34 | ## Installation 35 | 36 | 1. Install mujoco 2.0 and add the following environment variables into `~/.bashrc` or `~/.zshrc` 37 | ```bash 38 | # download mujoco 2.0 39 | $ wget https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip 40 | $ unzip mujoco.zip -d ~/.mujoco 41 | $ cp -r ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 42 | 43 | # copy mujoco license key `mjkey.txt` to `~/.mujoco` 44 | 45 | # add mujoco to LD_LIBRARY_PATH 46 | $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin 47 | 48 | # for GPU rendering (replace 418 with your nvidia driver version or you can make a dummy directory /usr/lib/nvidia-000) 49 | $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia-418 50 | 51 | # only for a headless server 52 | $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so:/usr/lib/nvidia-418/libGL.so 53 | ``` 54 | 55 | 2. Install python dependencies 56 | ```bash 57 | $ sudo apt-get install cmake libopenmpi-dev libgl1-mesa-dev libgl1-mesa-glx libosmesa6-dev patchelf libglew-dev 58 | 59 | # software rendering 60 | $ sudo apt-get install libgl1-mesa-glx libosmesa6 patchelf 61 | 62 | # window rendering 63 | $ sudo apt-get install libglfw3 libglew2.0 64 | 65 | $ pip install -r requirements.txt 66 | ``` 67 | 68 | 69 | ## Usage 70 | 71 | ### PPO 72 | ```bash 73 | $ python -m run --run_prefix test --algo ppo --env "Hopper-v2" 74 | ``` 75 | 76 | ### DDPG 77 | ```bash 78 | $ python -m run --run_prefix test --algo ddpg --env "Hopper-v2" 79 | ``` 80 | 81 | ### TD3 82 | ```bash 83 | $ python -m run --run_prefix test --algo td3 --env "Hopper-v2" 84 | ``` 85 | 86 | ### SAC 87 | ```bash 88 | $ python -m run --run_prefix test --algo sac --env "Hopper-v2" 89 | ``` 90 | 91 | ### BC 92 | 1. Generate demo using PPO 93 | ```bash 94 | # train ppo expert agent 95 | $ python -m run --run_prefix test --algo ppo --env "Hopper-v2" 96 | # collect expert trajectories using ppo expert policy 97 | $ python -m run --run_prefix test --algo ppo --env "Hopper-v2" --is_train False --record_video False --record_demo True --num_eval 100 98 | # 100 trajectories are stored in log/Hopper-v2.ppo.test.123/demo/Hopper-v2.ppo.test.123_step_00001000000_100.pkl 99 | ``` 100 | 101 | 2. 
Run BC 102 | ```bash 103 | $ python -m run --run_prefix test --algo bc --env "Hopper-v2" --demo_path log/Hopper-v2.ppo.test.123/demo/Hopper-v2.ppo.test.123_step_00001000000_100.pkl 104 | ``` 105 | 106 | ### GAIL 107 | ```bash 108 | $ python -m run --run_prefix test --algo gail --env "Hopper-v2" --demo_path log/Hopper-v2.ppo.test.123/demo/Hopper-v2.ppo.test.123_step_00001000000_100.pkl 109 | 110 | # initialize with BC policy 111 | $ python -m run --run_prefix test --algo gail --env "Hopper-v2" --demo_path log/Hopper-v2.ppo.test.123/demo/Hopper-v2.ppo.test.123_step_00001000000_100.pkl --init_ckpt_path log/Hopper-v2.bc.test.123/ckpt_00000020.pt 112 | ``` 113 | 114 | 115 | ## To dos 116 | * BC intialization for all algorithms 117 | * Ray 118 | * HER 119 | * Skill coordination 120 | 121 | -------------------------------------------------------------------------------- /algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # RL algorithms 2 | from .sac_agent import SACAgent 3 | from .ppo_agent import PPOAgent 4 | from .ddpg_agent import DDPGAgent 5 | 6 | # IL algorithms 7 | from .bc_agent import BCAgent 8 | from .gail_agent import GAILAgent 9 | from .dac_agent import DACAgent 10 | 11 | 12 | RL_ALGOS = { 13 | "sac": SACAgent, 14 | "ppo": PPOAgent, 15 | "ddpg": DDPGAgent, 16 | "td3": DDPGAgent, 17 | } 18 | 19 | 20 | IL_ALGOS = { 21 | "bc": BCAgent, 22 | "gail": GAILAgent, 23 | "dac": DACAgent, 24 | } 25 | 26 | 27 | def get_agent_by_name(algo): 28 | """ 29 | Returns RL or IL agent. 30 | """ 31 | if algo in RL_ALGOS: 32 | return RL_ALGOS[algo] 33 | elif algo in IL_ALGOS: 34 | return IL_ALGOS[algo] 35 | else: 36 | raise ValueError("--algo %s is not supported" % algo) 37 | -------------------------------------------------------------------------------- /algorithms/base_agent.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from ..utils.normalizer import Normalizer 7 | from ..utils.pytorch import to_tensor, center_crop 8 | 9 | class BaseAgent(object): 10 | """ Base class for agents. """ 11 | 12 | def __init__(self, config, ob_space): 13 | self._config = config 14 | 15 | self._ob_norm = Normalizer( 16 | ob_space, default_clip_range=config.clip_range, clip_obs=config.clip_obs 17 | ) 18 | self._buffer = None 19 | 20 | def normalize(self, ob): 21 | """ Normalizes observations. """ 22 | if self._config.ob_norm: 23 | return self._ob_norm.normalize(ob) 24 | return ob 25 | 26 | def act(self, ob, is_train=True): 27 | """ Returns action and the actor's activation given an observation @ob. """ 28 | if hasattr(self, "_rl_agent"): 29 | return self._rl_agent.act(ob, is_train) 30 | 31 | ob = self.normalize(ob) 32 | 33 | ob = ob.copy() 34 | for k, v in ob.items(): 35 | if self._config.encoder_type == "cnn" and len(v.shape) == 3: 36 | ob[k] = center_crop(v, self._config.encoder_image_size) 37 | else: 38 | ob[k] = np.expand_dims(ob[k], axis=0) 39 | 40 | with torch.no_grad(): 41 | ob = to_tensor(ob, self._config.device) 42 | ac, activation, _, _ = self._actor.act(ob, deterministic=not is_train) 43 | 44 | for k in ac.keys(): 45 | ac[k] = ac[k].cpu().numpy().squeeze(0) 46 | activation[k] = activation[k].cpu().numpy().squeeze(0) 47 | 48 | return ac, activation 49 | 50 | def update_normalizer(self, obs=None): 51 | """ Updates normalizers. 
""" 52 | if self._config.ob_norm: 53 | if obs is None: 54 | for i in range(len(self._dataset)): 55 | self._ob_norm.update(self._dataset[i]["ob"]) 56 | self._ob_norm.recompute_stats() 57 | else: 58 | self._ob_norm.update(obs) 59 | self._ob_norm.recompute_stats() 60 | 61 | def store_episode(self, rollouts): 62 | """ Stores @rollouts to replay buffer. """ 63 | raise NotImplementedError() 64 | 65 | def is_off_policy(self): 66 | return self._buffer is not None 67 | 68 | def set_buffer(self, buffer): 69 | self._buffer = buffer 70 | 71 | def replay_buffer(self): 72 | return self._buffer.state_dict() 73 | 74 | def load_replay_buffer(self, state_dict): 75 | self._buffer.load_state_dict(state_dict) 76 | 77 | def set_reward_function(self, predict_reward): 78 | self._predict_reward = predict_reward 79 | 80 | def sync_networks(self): 81 | raise NotImplementedError() 82 | 83 | def train(self): 84 | raise NotImplementedError() 85 | 86 | def _soft_update_target_network(self, target, source, tau): 87 | for target_param, source_param in zip(target.parameters(), source.parameters()): 88 | target_param.data.copy_( 89 | (1 - tau) * source_param.data + tau * target_param.data 90 | ) 91 | 92 | def _copy_target_network(self, target, source): 93 | self._soft_update_target_network(target, source, 0) 94 | -------------------------------------------------------------------------------- /algorithms/bc_agent.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | from torch.optim.lr_scheduler import StepLR 9 | 10 | from .base_agent import BaseAgent 11 | from .expert_dataset import ExpertDataset 12 | from ..networks import Actor 13 | from ..utils.info_dict import Info 14 | from ..utils.logger import logger 15 | from ..utils.mpi import mpi_average 16 | from ..utils.pytorch import ( 17 | optimizer_cuda, 18 | count_parameters, 19 | compute_gradient_norm, 20 | compute_weight_norm, 21 | sync_networks, 22 | sync_grads, 23 | to_tensor, 24 | ) 25 | 26 | 27 | class BCAgent(BaseAgent): 28 | def __init__(self, config, ob_space, ac_space, env_ob_space): 29 | super().__init__(config, ob_space) 30 | 31 | self._ob_space = ob_space 32 | self._ac_space = ac_space 33 | 34 | self._epoch = 0 35 | 36 | self._actor = Actor(config, ob_space, ac_space, config.tanh_policy) 37 | self._network_cuda(config.device) 38 | self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.bc_lr) 39 | self._actor_lr_scheduler = StepLR( 40 | self._actor_optim, step_size=self._config.max_global_step // 5, gamma=0.5, 41 | ) 42 | 43 | if config.is_train: 44 | self._dataset = ExpertDataset( 45 | config.demo_path, 46 | config.demo_subsample_interval, 47 | ac_space, 48 | use_low_level=config.demo_low_level, 49 | sample_range_start=config.demo_sample_range_start, 50 | sample_range_end=config.demo_sample_range_end, 51 | ) 52 | 53 | if self._config.val_split != 0: 54 | dataset_size = len(self._dataset) 55 | indices = list(range(dataset_size)) 56 | split = int(np.floor((1 - self._config.val_split) * dataset_size)) 57 | train_indices, val_indices = indices[split:], indices[:split] 58 | train_sampler = SubsetRandomSampler(train_indices) 59 | val_sampler = SubsetRandomSampler(val_indices) 60 | self._train_loader = torch.utils.data.DataLoader( 61 | self._dataset, 62 | batch_size=self._config.batch_size, 63 | sampler=train_sampler, 64 | 
) 65 | self._val_loader = torch.utils.data.DataLoader( 66 | self._dataset, 67 | batch_size=self._config.batch_size, 68 | sampler=val_sampler, 69 | ) 70 | else: 71 | self._train_loader = torch.utils.data.DataLoader( 72 | self._dataset, batch_size=self._config.batch_size, shuffle=True 73 | ) 74 | 75 | self._log_creation() 76 | 77 | def _log_creation(self): 78 | if self._config.is_chef: 79 | logger.info("Creating a BC agent") 80 | logger.info("The actor has %d parameters", count_parameters(self._actor)) 81 | 82 | def state_dict(self): 83 | return { 84 | "actor_state_dict": self._actor.state_dict(), 85 | "actor_optim_state_dict": self._actor_optim.state_dict(), 86 | "ob_norm_state_dict": self._ob_norm.state_dict(), 87 | } 88 | 89 | def load_state_dict(self, ckpt): 90 | self._actor.load_state_dict(ckpt["actor_state_dict"]) 91 | self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 92 | self._network_cuda(self._config.device) 93 | 94 | self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"]) 95 | optimizer_cuda(self._actor_optim, self._config.device) 96 | 97 | def _network_cuda(self, device): 98 | self._actor.to(device) 99 | 100 | def sync_networks(self): 101 | sync_networks(self._actor) 102 | 103 | def train(self): 104 | train_info = Info() 105 | for transitions in self._train_loader: 106 | _train_info = self._update_network(transitions, train=True) 107 | train_info.add(_train_info) 108 | self._epoch += 1 109 | self._actor_lr_scheduler.step() 110 | 111 | train_info.add( 112 | { 113 | "actor_grad_norm": compute_gradient_norm(self._actor), 114 | "actor_weight_norm": compute_weight_norm(self._actor), 115 | } 116 | ) 117 | train_info = train_info.get_dict(only_scalar=True) 118 | logger.info("BC loss %f", train_info["actor_loss"]) 119 | return train_info 120 | 121 | def evaluate(self): 122 | if getattr(self, "_val_loader", None) is not None: 123 | eval_info = Info() 124 | for transitions in self._val_loader: 125 | _eval_info = self._update_network(transitions, train=False) 126 | eval_info.add(_eval_info) 127 | self._epoch += 1 128 | return eval_info.get_dict(only_scalar=True) 129 | logger.warning("No validation set available, make sure '--val_split' is set") 130 | return None 131 | 132 | def _update_network(self, transitions, train=True): 133 | info = Info() 134 | 135 | # pre-process observations 136 | o = transitions["ob"] 137 | o = self.normalize(o) 138 | 139 | # convert double tensor to float32 tensor 140 | _to_tensor = lambda x: to_tensor(x, self._config.device) 141 | o = _to_tensor(o) 142 | ac = _to_tensor(transitions["ac"]) 143 | if isinstance(ac, OrderedDict): 144 | ac = list(ac.values()) 145 | if len(ac[0].shape) == 1: 146 | ac = [x.unsqueeze(0) for x in ac] 147 | ac = torch.cat(ac, dim=-1) 148 | 149 | # the actor loss 150 | pred_ac, _ = self._actor(o) 151 | if isinstance(pred_ac, OrderedDict): 152 | pred_ac = list(pred_ac.values()) 153 | if len(pred_ac[0].shape) == 1: 154 | pred_ac = [x.unsqueeze(0) for x in pred_ac] 155 | pred_ac = torch.cat(pred_ac, dim=-1) 156 | 157 | diff = ac - pred_ac 158 | actor_loss = diff.pow(2).mean() 159 | info["actor_loss"] = actor_loss.cpu().item() 160 | info["pred_ac"] = pred_ac.cpu().detach() 161 | info["GT_ac"] = ac.cpu() 162 | diff = torch.sum(torch.abs(diff), axis=0).cpu() 163 | for i in range(diff.shape[0]): 164 | info["action" + str(i) + "_L1loss"] = diff[i].mean().item() 165 | 166 | if train: 167 | # update the actor 168 | self._actor_optim.zero_grad() 169 | actor_loss.backward() 170 | # torch.nn.utils.clip_grad_norm_(self._actor.parameters(), self._config.max_grad_norm)
171 | sync_grads(self._actor) 172 | self._actor_optim.step() 173 | 174 | return mpi_average(info.get_dict(only_scalar=True)) 175 | -------------------------------------------------------------------------------- /algorithms/dac_agent.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.autograd as autograd 9 | import torch.distributions 10 | from torch.optim.lr_scheduler import StepLR 11 | import gym.spaces 12 | 13 | from .base_agent import BaseAgent 14 | from .ddpg_agent import DDPGAgent 15 | from .sac_agent import SACAgent 16 | from .dataset import ReplayBuffer, ReplayBufferPerStep, RandomSampler 17 | from .expert_dataset import ExpertDataset 18 | from ..networks.discriminator import Discriminator 19 | from ..utils.info_dict import Info 20 | from ..utils.logger import logger 21 | from ..utils.mpi import mpi_average 22 | from ..utils.gym_env import spaces_to_shapes 23 | from ..utils.pytorch import ( 24 | optimizer_cuda, 25 | count_parameters, 26 | sync_networks, 27 | sync_grads, 28 | to_tensor, 29 | ) 30 | 31 | 32 | class DACAgent(BaseAgent): 33 | def __init__(self, config, ob_space, ac_space, env_ob_space): 34 | super().__init__(config, ob_space) 35 | 36 | self._ob_space = ob_space 37 | self._ac_space = ac_space 38 | 39 | if self._config.gail_rl_algo == "td3": 40 | self._rl_agent = DDPGAgent(config, ob_space, ac_space, env_ob_space) 41 | elif self._config.gail_rl_algo == "sac": 42 | self._rl_agent = SACAgent(config, ob_space, ac_space, env_ob_space) 43 | self._rl_agent.set_reward_function(self._predict_reward) 44 | 45 | # build up networks 46 | self._discriminator = Discriminator( 47 | config, ob_space, ac_space if not config.gail_no_action else None 48 | ) 49 | self._discriminator_loss = nn.BCEWithLogitsLoss() 50 | self._network_cuda(config.device) 51 | 52 | # build optimizers 53 | self._discriminator_optim = optim.Adam( 54 | self._discriminator.parameters(), lr=config.discriminator_lr 55 | ) 56 | 57 | # build learning rate scheduler 58 | self._discriminator_lr_scheduler = StepLR( 59 | self._discriminator_optim, 60 | step_size=self._config.max_global_step // 5, 61 | gamma=0.5, 62 | ) 63 | 64 | # expert dataset 65 | if config.is_train: 66 | self._dataset = ExpertDataset( 67 | config.demo_path, 68 | config.demo_subsample_interval, 69 | ac_space, 70 | use_low_level=config.demo_low_level, 71 | sample_range_start=config.demo_sample_range_start, 72 | sample_range_end=config.demo_sample_range_end, 73 | ) 74 | if self._config.absorbing_state: 75 | self._dataset.add_absorbing_states(ob_space, ac_space) 76 | self._data_loader = torch.utils.data.DataLoader( 77 | self._dataset, 78 | batch_size=self._config.batch_size, 79 | shuffle=True, 80 | drop_last=True, 81 | ) 82 | self._data_iter = iter(self._data_loader) 83 | 84 | # per-episode replay buffer 85 | sampler = RandomSampler(image_crop_size=config.encoder_image_size) 86 | buffer_keys = ["ob", "ob_next", "ac", "done", "done_mask", "rew"] 87 | self._buffer = ReplayBuffer( 88 | buffer_keys, config.buffer_size, sampler.sample_func 89 | ) 90 | 91 | # per-step replay buffer 92 | # shapes = { 93 | # "ob": spaces_to_shapes(env_ob_space), 94 | # "ob_next": spaces_to_shapes(env_ob_space), 95 | # "ac": spaces_to_shapes(ac_space), 96 | # "done": [1], 97 | # "done_mask": [1], 98 | # "rew": [1], 99 | # } 100 | # self._buffer = 
ReplayBufferPerStep( 101 | # shapes, 102 | # config.buffer_size, 103 | # config.encoder_image_size, 104 | # config.absorbing_state, 105 | # ) 106 | 107 | self._rl_agent.set_buffer(self._buffer) 108 | 109 | self._update_iter = 0 110 | 111 | self._log_creation() 112 | 113 | def _predict_reward(self, ob, ac): 114 | if self._config.gail_no_action: 115 | ac = None 116 | with torch.no_grad(): 117 | ret = self._discriminator(ob, ac) 118 | eps = 1e-10 119 | s = torch.sigmoid(ret) 120 | if self._config.gail_reward == "vanilla": 121 | reward = -(1 - s + eps).log() 122 | elif self._config.gail_reward == "gan": 123 | reward = (s + eps).log() - (1 - s + eps).log() 124 | elif self._config.gail_reward == "d": 125 | reward = ret 126 | return reward 127 | 128 | def predict_reward(self, ob, ac=None): 129 | ob = self.normalize(ob) 130 | ob = to_tensor(ob, self._config.device) 131 | if self._config.gail_no_action: 132 | ac = None 133 | if ac is not None: 134 | ac = to_tensor(ac, self._config.device) 135 | 136 | reward = self._predict_reward(ob, ac) 137 | return reward.cpu().item() 138 | 139 | def _log_creation(self): 140 | if self._config.is_chef: 141 | logger.info("Creating a DAC agent") 142 | logger.info( 143 | "The discriminator has %d parameters", 144 | count_parameters(self._discriminator), 145 | ) 146 | 147 | def store_episode(self, rollouts): 148 | self._rl_agent.store_episode(rollouts) 149 | 150 | def state_dict(self): 151 | return { 152 | "rl_agent": self._rl_agent.state_dict(), 153 | "discriminator_state_dict": self._discriminator.state_dict(), 154 | "discriminator_optim_state_dict": self._discriminator_optim.state_dict(), 155 | "ob_norm_state_dict": self._ob_norm.state_dict(), 156 | } 157 | 158 | def load_state_dict(self, ckpt): 159 | if "rl_agent" in ckpt: 160 | self._rl_agent.load_state_dict(ckpt["rl_agent"]) 161 | else: 162 | self._rl_agent.load_state_dict(ckpt) 163 | self._network_cuda(self._config.device) 164 | return 165 | 166 | self._discriminator.load_state_dict(ckpt["discriminator_state_dict"]) 167 | self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 168 | self._network_cuda(self._config.device) 169 | 170 | self._discriminator_optim.load_state_dict( 171 | ckpt["discriminator_optim_state_dict"] 172 | ) 173 | optimizer_cuda(self._discriminator_optim, self._config.device) 174 | 175 | def _network_cuda(self, device): 176 | self._discriminator.to(device) 177 | 178 | def sync_networks(self): 179 | self._rl_agent.sync_networks() 180 | sync_networks(self._discriminator) 181 | 182 | def train(self): 183 | train_info = Info() 184 | 185 | self._discriminator_lr_scheduler.step() 186 | 187 | if self._update_iter % self._config.discriminator_update_freq == 0: 188 | self._num_updates = 1 189 | for _ in range(self._num_updates): 190 | policy_data = self._buffer.sample(self._config.batch_size) 191 | try: 192 | expert_data = next(self._data_iter) 193 | except StopIteration: 194 | self._data_iter = iter(self._data_loader) 195 | expert_data = next(self._data_iter) 196 | _train_info = self._update_discriminator(policy_data, expert_data) 197 | train_info.add(_train_info) 198 | 199 | _train_info = self._rl_agent.train() 200 | train_info.add(_train_info) 201 | 202 | return train_info.get_dict(only_scalar=True) 203 | 204 | def _update_discriminator(self, policy_data, expert_data): 205 | info = Info() 206 | 207 | _to_tensor = lambda x: to_tensor(x, self._config.device) 208 | # pre-process observations 209 | p_o = policy_data["ob"] 210 | p_o = self.normalize(p_o) 211 | 212 | p_bs = len(policy_data["ac"]) 213 
| p_o = _to_tensor(p_o) 214 | if self._config.gail_no_action: 215 | p_ac = None 216 | else: 217 | p_ac = _to_tensor(policy_data["ac"]) 218 | 219 | e_o = expert_data["ob"] 220 | e_o = self.normalize(e_o) 221 | 222 | e_bs = len(expert_data["ac"]) 223 | e_o = _to_tensor(e_o) 224 | if self._config.gail_no_action: 225 | e_ac = None 226 | else: 227 | e_ac = _to_tensor(expert_data["ac"]) 228 | 229 | p_logit = self._discriminator(p_o, p_ac) 230 | e_logit = self._discriminator(e_o, e_ac) 231 | 232 | p_output = torch.sigmoid(p_logit) 233 | e_output = torch.sigmoid(e_logit) 234 | 235 | p_loss = self._discriminator_loss( 236 | p_logit, torch.zeros_like(p_logit).to(self._config.device) 237 | ) 238 | e_loss = self._discriminator_loss( 239 | e_logit, torch.ones_like(e_logit).to(self._config.device) 240 | ) 241 | 242 | logits = torch.cat([p_logit, e_logit], dim=0) 243 | entropy = torch.distributions.Bernoulli(logits=logits).entropy().mean() 244 | entropy_loss = -self._config.gail_entropy_loss_coeff * entropy 245 | 246 | grad_pen = self._compute_grad_pen(p_o, p_ac, e_o, e_ac) 247 | grad_pen_loss = self._config.gail_grad_penalty_coeff * grad_pen 248 | 249 | gail_loss = p_loss + e_loss + entropy_loss + grad_pen_loss 250 | 251 | # update the discriminator 252 | self._discriminator.zero_grad() 253 | gail_loss.backward() 254 | sync_grads(self._discriminator) 255 | self._discriminator_optim.step() 256 | 257 | info["gail_policy_output"] = p_output.mean().detach().cpu().item() 258 | info["gail_expert_output"] = e_output.mean().detach().cpu().item() 259 | info["gail_entropy"] = entropy.detach().cpu().item() 260 | info["gail_policy_loss"] = p_loss.detach().cpu().item() 261 | info["gail_expert_loss"] = e_loss.detach().cpu().item() 262 | info["gail_entropy_loss"] = entropy_loss.detach().cpu().item() 263 | info["gail_grad_pen"] = grad_pen.detach().cpu().item() 264 | info["gail_grad_loss"] = grad_pen_loss.detach().cpu().item() 265 | 266 | return mpi_average(info.get_dict(only_scalar=True)) 267 | 268 | def _compute_grad_pen(self, policy_ob, policy_ac, expert_ob, expert_ac): 269 | batch_size = self._config.batch_size 270 | alpha = torch.rand(batch_size, 1, device=self._config.device) 271 | 272 | def blend_dict(a, b, alpha): 273 | if isinstance(a, dict): 274 | return OrderedDict( 275 | [(k, blend_dict(a[k], b[k], alpha)) for k in a.keys()] 276 | ) 277 | elif isinstance(a, list): 278 | return [blend_dict(a[i], b[i], alpha) for i in range(len(a))] 279 | else: 280 | expanded_alpha = alpha.expand_as(a) 281 | ret = expanded_alpha * a + (1 - expanded_alpha) * b 282 | ret.requires_grad = True 283 | return ret 284 | 285 | interpolated_ob = blend_dict(policy_ob, expert_ob, alpha) 286 | inputs = list(interpolated_ob.values()) 287 | if policy_ac is not None: 288 | interpolated_ac = blend_dict(policy_ac, expert_ac, alpha) 289 | inputs = inputs + list(interpolated_ac.values()) 290 | else: 291 | interpolated_ac = None 292 | 293 | interpolated_logit = self._discriminator(interpolated_ob, interpolated_ac) 294 | ones = torch.ones(interpolated_logit.size(), device=self._config.device) 295 | 296 | grad = autograd.grad( 297 | outputs=interpolated_logit, 298 | inputs=inputs, 299 | grad_outputs=ones, 300 | create_graph=True, 301 | retain_graph=True, 302 | only_inputs=True, 303 | )[0] 304 | 305 | grad_pen = (grad.norm(2, dim=1) - 1).pow(2).mean() 306 | return grad_pen 307 | -------------------------------------------------------------------------------- /algorithms/dataset.py:
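Note: `_compute_grad_pen` in `algorithms/dac_agent.py` above applies a WGAN-GP-style gradient penalty to interpolations of policy and expert batches, generalized to dict-structured inputs via `blend_dict`. The snippet below is a minimal, self-contained sketch of the same penalty for flat tensors only; the toy discriminator, sizes, and names are illustrative assumptions, not part of this repository.

```python
import torch
import torch.nn as nn
from torch import autograd


def grad_penalty(disc: nn.Module, policy_x: torch.Tensor, expert_x: torch.Tensor) -> torch.Tensor:
    # blend policy and expert samples with a per-sample random factor
    alpha = torch.rand(policy_x.size(0), 1)
    interp = alpha * policy_x + (1 - alpha) * expert_x
    interp.requires_grad_(True)
    logit = disc(interp)
    # gradient of the discriminator output w.r.t. the interpolated input
    grad = autograd.grad(
        outputs=logit,
        inputs=interp,
        grad_outputs=torch.ones_like(logit),
        create_graph=True,
    )[0]
    # penalize deviation of each sample's gradient norm from 1
    return (grad.norm(2, dim=1) - 1).pow(2).mean()


# toy usage: a discriminator over a concatenated (ob, ac) vector of size 8
disc = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
pen = grad_penalty(disc, torch.randn(16, 8), torch.randn(16, 8))
```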
-------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from time import time 3 | 4 | import numpy as np 5 | 6 | from ..utils.pytorch import random_crop 7 | 8 | 9 | def make_buffer(shapes, buffer_size): 10 | buffer = {} 11 | for k, v in shapes.items(): 12 | if isinstance(v, dict): 13 | buffer[k] = make_buffer(v, buffer_size) 14 | else: 15 | if len(v) >= 3: 16 | buffer[k] = np.empty((buffer_size, *v), dtype=np.uint8) 17 | else: 18 | buffer[k] = np.empty((buffer_size, *v), dtype=np.float32) 19 | return buffer 20 | 21 | 22 | def add_rollout(buffer, rollout, idx: int): 23 | if isinstance(rollout, list): 24 | rollout = rollout[0] 25 | 26 | if isinstance(rollout, dict): 27 | for k in rollout.keys(): 28 | add_rollout(buffer[k], rollout[k], idx) 29 | else: 30 | np.copyto(buffer[idx], rollout) 31 | 32 | 33 | def get_batch(buffer: dict, idxs): 34 | batch = {} 35 | for k in buffer.keys(): 36 | if isinstance(buffer[k], dict): 37 | batch[k] = get_batch(buffer[k], idxs) 38 | else: 39 | batch[k] = buffer[k][idxs] 40 | return batch 41 | 42 | 43 | def augment_ob(batch, image_crop_size): 44 | for k, v in batch.items(): 45 | if isinstance(batch[k], dict): 46 | augment_ob(batch[k], image_crop_size) 47 | elif len(batch[k].shape) > 3: 48 | batch[k] = random_crop(batch[k], image_crop_size) 49 | 50 | 51 | class ReplayBufferPerStep(object): 52 | def __init__(self, shapes: dict, buffer_size: int, image_crop_size=84, absorbing_state=False): 53 | self._capacity = buffer_size 54 | 55 | if absorbing_state: 56 | shapes["ob"]["absorbing_state"] = [1] 57 | shapes["ob_next"]["absorbing_state"] = [1] 58 | 59 | self._shapes = shapes 60 | self._keys = list(shapes.keys()) 61 | self._image_crop_size = image_crop_size 62 | self._absorbing_state = absorbing_state 63 | 64 | self._buffer = make_buffer(shapes, buffer_size) 65 | self._idx = 0 66 | self._full = False 67 | 68 | def clear(self): 69 | self._idx = 0 70 | self._full = False 71 | 72 | # store the episode 73 | def store_episode(self, rollout): 74 | for k in self._keys: 75 | add_rollout(self._buffer[k], rollout[k], self._idx) 76 | 77 | self._idx = (self._idx + 1) % self._capacity 78 | self._full = self._full or self._idx == 0 79 | 80 | # sample the data from the replay buffer 81 | def sample(self, batch_size): 82 | idxs = np.random.randint( 83 | 0, self._capacity if self._full else self._idx, size=batch_size 84 | ) 85 | batch = get_batch(self._buffer, idxs) 86 | 87 | # apply random crop to image 88 | augment_ob(batch, self._image_crop_size) 89 | 90 | return batch 91 | 92 | def state_dict(self): 93 | return {"buffer": self._buffer, "idx": self._idx, "full": self._full} 94 | 95 | def load_state_dict(self, state_dict): 96 | self._buffer = state_dict["buffer"] 97 | self._idx = state_dict["idx"] 98 | self._full = state_dict["full"] 99 | 100 | 101 | class ReplayBuffer(object): 102 | def __init__(self, keys, buffer_size, sample_func): 103 | self._capacity = buffer_size 104 | self._sample_func = sample_func 105 | 106 | # create the buffer to store info 107 | self._keys = keys 108 | self.clear() 109 | 110 | def clear(self): 111 | self._idx = 0 112 | self._current_size = 0 113 | self._buffer = defaultdict(list) 114 | 115 | # store transitions 116 | def store_episode(self, rollout): 117 | # @rollout can be any length of transitions 118 | for k in self._keys: 119 | if self._current_size < self._capacity: 120 | self._buffer[k].append(rollout[k]) 121 | else: 122 | self._buffer[k][self._idx] = rollout[k] 123 | 124 | 
self._idx = (self._idx + 1) % self._capacity 125 | if self._current_size < self._capacity: 126 | self._current_size += 1 127 | 128 | # sample the data from the replay buffer 129 | def sample(self, batch_size): 130 | # sample transitions 131 | transitions = self._sample_func(self._buffer, batch_size) 132 | return transitions 133 | 134 | def state_dict(self): 135 | return self._buffer 136 | 137 | def load_state_dict(self, state_dict): 138 | self._buffer = state_dict 139 | self._current_size = len(self._buffer["ac"]) 140 | 141 | 142 | class ReplayBufferEpisode(object): 143 | def __init__(self, keys, buffer_size, sample_func): 144 | self._capacity = buffer_size 145 | self._sample_func = sample_func 146 | 147 | # create the buffer to store info 148 | self._keys = keys 149 | self.clear() 150 | 151 | def clear(self): 152 | self._idx = 0 153 | self._current_size = 0 154 | self._new_episode = True 155 | self._buffer = defaultdict(list) 156 | 157 | # store the episode 158 | def store_episode(self, rollout): 159 | if self._new_episode: 160 | self._new_episode = False 161 | for k in self._keys: 162 | if self._current_size < self._capacity: 163 | self._buffer[k].append(rollout[k]) 164 | else: 165 | self._buffer[k][self._idx] = rollout[k] 166 | else: 167 | for k in self._keys: 168 | self._buffer[k][self._idx].extend(rollout[k]) 169 | 170 | if rollout["done"][-1]: 171 | self._idx = (self._idx + 1) % self._capacity 172 | if self._current_size < self._capacity: 173 | self._current_size += 1 174 | self._new_episode = True 175 | 176 | # sample the data from the replay buffer 177 | def sample(self, batch_size): 178 | # sample transitions 179 | transitions = self._sample_func(self._buffer, batch_size) 180 | return transitions 181 | 182 | def state_dict(self): 183 | return self._buffer 184 | 185 | def load_state_dict(self, state_dict): 186 | self._buffer = state_dict 187 | self._current_size = len(self._buffer["ac"]) 188 | 189 | 190 | class RandomSampler(object): 191 | def __init__(self, image_crop_size=84): 192 | self._image_crop_size = image_crop_size 193 | 194 | def sample_func(self, episode_batch, batch_size_in_transitions): 195 | rollout_batch_size = len(episode_batch["ac"]) 196 | batch_size = batch_size_in_transitions 197 | 198 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 199 | t_samples = [ 200 | np.random.randint(len(episode_batch["ac"][episode_idx])) 201 | for episode_idx in episode_idxs 202 | ] 203 | 204 | transitions = {} 205 | for key in episode_batch.keys(): 206 | transitions[key] = [ 207 | episode_batch[key][episode_idx][t] 208 | for episode_idx, t in zip(episode_idxs, t_samples) 209 | ] 210 | 211 | transitions["ob_next"] = [ 212 | episode_batch["ob_next"][episode_idx][t] 213 | for episode_idx, t in zip(episode_idxs, t_samples) 214 | ] 215 | 216 | new_transitions = {} 217 | for k, v in transitions.items(): 218 | if isinstance(v[0], dict): 219 | sub_keys = v[0].keys() 220 | new_transitions[k] = { 221 | sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys 222 | } 223 | else: 224 | new_transitions[k] = np.stack(v) 225 | 226 | for k, v in new_transitions["ob"].items(): 227 | if len(v.shape) in [4, 5]: 228 | new_transitions["ob"][k] = random_crop(v, self._image_crop_size) 229 | 230 | for k, v in new_transitions["ob_next"].items(): 231 | if len(v.shape) in [4, 5]: 232 | new_transitions["ob_next"][k] = random_crop(v, self._image_crop_size) 233 | 234 | return new_transitions 235 | 236 | class SeqSampler(object): 237 | def __init__(self, seq_length, 
image_crop_size=84): 238 | self._seq_length = seq_length 239 | self._image_crop_size = image_crop_size 240 | 241 | def sample_func(self, episode_batch, batch_size_in_transitions): 242 | rollout_batch_size = len(episode_batch["ac"]) 243 | batch_size = batch_size_in_transitions 244 | 245 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 246 | t_samples = [ 247 | np.random.randint(len(episode_batch["ac"][episode_idx])) 248 | for episode_idx in episode_idxs 249 | ] 250 | 251 | transitions = {} 252 | for key in episode_batch.keys(): 253 | transitions[key] = [ 254 | episode_batch[key][episode_idx][t] 255 | for episode_idx, t in zip(episode_idxs, t_samples) 256 | ] 257 | 258 | transitions["ob_next"] = [ 259 | episode_batch["ob_next"][episode_idx][t] 260 | for episode_idx, t in zip(episode_idxs, t_samples) 261 | ] 262 | 263 | #Create a key that stores the specified future fixed length of sequences, pad last states if necessary 264 | 265 | print(episode_idxs) 266 | print(t_samples) 267 | 268 | #List of dictionaries is created here..., flatten it out? 269 | transitions["following_sequences"] = [ 270 | episode_batch["ob"][episode_idx][t:t+ self._seq_length] 271 | for episode_idx, t in zip(episode_idxs, t_samples) 272 | ] 273 | 274 | #something's wrong here... should use index episode_idx to episode_batch, not transitions 275 | 276 | # # Pad last states 277 | # for episode_idx in episode_idxs: 278 | # # curr_ep = episode_batch["ob"][episode_idx] 279 | # # curr_ep.extend(curr_ep[-1:] * (self._seq_length - len(curr_ep))) 280 | # 281 | # #all list should have 10 dictionaries now 282 | # if isinstance(transitions["following_sequences"][episode_idx], dict): 283 | # continue 284 | # transitions["following_sequences"][episode_idx].extend(transitions["following_sequences"][episode_idx][-1:] * (self._seq_length - len(transitions["following_sequences"][episode_idx]))) 285 | # 286 | # #turn transitions["following_sequences"] to a dictionary 287 | # fs_list = transitions["following_sequences"][episode_idx] 288 | # container = {} 289 | # container["ob"] = [] 290 | # for i in fs_list: 291 | # container["ob"].extend(i["ob"]) 292 | # container["ob"] = np.array(container["ob"]) 293 | # transitions["following_sequences"][episode_idx] = container 294 | 295 | # Pad last states 296 | for i in range(len(transitions["following_sequences"])): 297 | # curr_ep = episode_batch["ob"][episode_idx] 298 | # curr_ep.extend(curr_ep[-1:] * (self._seq_length - len(curr_ep))) 299 | 300 | #all list should have 10 dictionaries now 301 | if isinstance(transitions["following_sequences"][i], dict): 302 | continue 303 | transitions["following_sequences"][i].extend(transitions["following_sequences"][i][-1:] * (self._seq_length - len(transitions["following_sequences"][i]))) 304 | 305 | #turn transitions["following_sequences"] to a dictionary 306 | fs_list = transitions["following_sequences"][i] 307 | container = {} 308 | container["ob"] = [] 309 | for j in fs_list: 310 | container["ob"].extend(j["ob"]) 311 | container["ob"] = np.array(container["ob"]) 312 | transitions["following_sequences"][i] = container 313 | 314 | 315 | 316 | new_transitions = {} 317 | for k, v in transitions.items(): 318 | if isinstance(v[0], dict): 319 | sub_keys = v[0].keys() 320 | new_transitions[k] = { 321 | sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys 322 | } 323 | else: 324 | new_transitions[k] = np.stack(v) 325 | 326 | for k, v in new_transitions["ob"].items(): 327 | if len(v.shape) in [4, 5]: 328 | 
new_transitions["ob"][k] = random_crop(v, self._image_crop_size) 329 | 330 | for k, v in new_transitions["ob_next"].items(): 331 | if len(v.shape) in [4, 5]: 332 | new_transitions["ob_next"][k] = random_crop(v, self._image_crop_size) 333 | 334 | return new_transitions 335 | 336 | 337 | class HERSampler(object): 338 | def __init__(self, replay_strategy, replace_future, reward_func=None): 339 | self.replay_strategy = replay_strategy 340 | if self.replay_strategy == "future": 341 | self.future_p = replace_future 342 | else: 343 | self.future_p = 0 344 | self.reward_func = reward_func 345 | 346 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 347 | rollout_batch_size = len(episode_batch["ac"]) 348 | batch_size = batch_size_in_transitions 349 | 350 | # select which rollouts and which timesteps to be used 351 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 352 | t_samples = [ 353 | np.random.randint(len(episode_batch["ac"][episode_idx])) 354 | for episode_idx in episode_idxs 355 | ] 356 | 357 | transitions = {} 358 | for key in episode_batch.keys(): 359 | transitions[key] = [ 360 | episode_batch[key][episode_idx][t] 361 | for episode_idx, t in zip(episode_idxs, t_samples) 362 | ] 363 | 364 | transitions["ob_next"] = [ 365 | episode_batch["ob"][episode_idx][t + 1] 366 | for episode_idx, t in zip(episode_idxs, t_samples) 367 | ] 368 | transitions["r"] = np.zeros((batch_size,)) 369 | 370 | # hindsight experience replay 371 | for i, (episode_idx, t) in enumerate(zip(episode_idxs, t_samples)): 372 | replace_goal = np.random.uniform() < self.future_p 373 | if replace_goal: 374 | future_t = np.random.randint( 375 | t + 1, len(episode_batch["ac"][episode_idx]) + 1 376 | ) 377 | future_ag = episode_batch["ag"][episode_idx][future_t] 378 | if ( 379 | self.reward_func( 380 | episode_batch["ag"][episode_idx][t], future_ag, None 381 | ) 382 | < 0 383 | ): 384 | transitions["g"][i] = future_ag 385 | transitions["r"][i] = self.reward_func( 386 | episode_batch["ag"][episode_idx][t + 1], transitions["g"][i], None 387 | ) 388 | 389 | new_transitions = {} 390 | for k, v in transitions.items(): 391 | if isinstance(v[0], dict): 392 | sub_keys = v[0].keys() 393 | new_transitions[k] = { 394 | sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys 395 | } 396 | else: 397 | new_transitions[k] = np.stack(v) 398 | 399 | return new_transitions 400 | 401 | -------------------------------------------------------------------------------- /algorithms/ddpg_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import gym.spaces 7 | from torch.optim.lr_scheduler import StepLR 8 | 9 | from .base_agent import BaseAgent 10 | from .dataset import ReplayBuffer, RandomSampler 11 | from ..networks import Actor, Critic 12 | from ..utils.info_dict import Info 13 | from ..utils.logger import logger 14 | from ..utils.mpi import mpi_average 15 | from ..utils.pytorch import ( 16 | optimizer_cuda, 17 | count_parameters, 18 | compute_gradient_norm, 19 | compute_weight_norm, 20 | sync_networks, 21 | sync_grads, 22 | to_tensor, 23 | scale_dict_tensor, 24 | ) 25 | 26 | 27 | class DDPGAgent(BaseAgent): 28 | def __init__(self, config, ob_space, ac_space, env_ob_space): 29 | super().__init__(config, ob_space) 30 | 31 | self._ob_space = ob_space 32 | self._ac_space = ac_space 33 | 34 | # build up networks 35 | 
self._actor = Actor(config, ob_space, ac_space, config.tanh_policy) 36 | self._critic = Critic(config, ob_space, ac_space) 37 | 38 | # build up target networks 39 | self._actor_target = Actor(config, ob_space, ac_space, config.tanh_policy) 40 | self._critic_target = Critic(config, ob_space, ac_space) 41 | self._network_cuda(self._config.device) 42 | self._copy_target_network(self._actor_target, self._actor) 43 | self._copy_target_network(self._critic_target, self._critic) 44 | self._actor.encoder.copy_conv_weights_from(self._critic.encoder) 45 | self._actor_target.encoder.copy_conv_weights_from(self._critic_target.encoder) 46 | 47 | # build optimizers 48 | self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.actor_lr) 49 | self._critic_optim = optim.Adam(self._critic.parameters(), lr=config.critic_lr) 50 | 51 | # build learning rate scheduler 52 | self._actor_lr_scheduler = StepLR( 53 | self._actor_optim, step_size=self._config.max_global_step // 5, gamma=0.5, 54 | ) 55 | 56 | # per-episode replay buffer 57 | sampler = RandomSampler(image_crop_size=config.encoder_image_size) 58 | buffer_keys = ["ob", "ob_next", "ac", "done", "done_mask", "rew"] 59 | self._buffer = ReplayBuffer( 60 | buffer_keys, config.buffer_size, sampler.sample_func 61 | ) 62 | 63 | self._update_iter = 0 64 | self._predict_reward = None 65 | 66 | self._log_creation() 67 | 68 | def _log_creation(self): 69 | if self._config.is_chef: 70 | logger.info("Creating a DDPG agent") 71 | logger.info("The actor has %d parameters", count_parameters(self._actor)) 72 | logger.info("The critic has %d parameters", count_parameters(self._critic)) 73 | 74 | def act(self, ob, is_train=True): 75 | """ Returns action and the actor's activation given an observation @ob. """ 76 | ac, activation = super().act(ob, is_train=is_train) 77 | 78 | if not is_train: 79 | return ac, activation 80 | 81 | if self._config.epsilon_greedy: 82 | if np.random.uniform() < self._config.epsilon_greedy_eps: 83 | for k, v in self._ac_space.spaces.items(): 84 | ac[k] = v.sample() 85 | return ac, activation 86 | 87 | for k, v in self._ac_space.spaces.items(): 88 | if isinstance(v, gym.spaces.Box): 89 | ac[k] += self._config.policy_exploration_noise * np.random.randn( 90 | *ac[k].shape 91 | ) 92 | ac[k] = np.clip(ac[k], v.low, v.high) 93 | 94 | return ac, activation 95 | 96 | def store_episode(self, rollouts): 97 | self._buffer.store_episode(rollouts) 98 | 99 | def state_dict(self): 100 | return { 101 | "update_iter": self._update_iter, 102 | "actor_state_dict": self._actor.state_dict(), 103 | "critic_state_dict": self._critic.state_dict(), 104 | "actor_optim_state_dict": self._actor_optim.state_dict(), 105 | "critic_optim_state_dict": self._critic_optim.state_dict(), 106 | "ob_norm_state_dict": self._ob_norm.state_dict(), 107 | } 108 | 109 | def load_state_dict(self, ckpt): 110 | if "critic_state_dict" not in ckpt: 111 | missing = self._actor.load_state_dict( 112 | ckpt["actor_state_dict"], strict=False 113 | ) 114 | self._copy_target_network(self._actor_target, self._actor) 115 | self._network_cuda(self._config.device) 116 | return 117 | 118 | self._update_iter = ckpt["update_iter"] 119 | self._actor.load_state_dict(ckpt["actor_state_dict"]) 120 | self._critic.load_state_dict(ckpt["critic_state_dict"]) 121 | self._copy_target_network(self._actor_target, self._actor) 122 | self._copy_target_network(self._critic_target, self._critic) 123 | self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 124 | self._network_cuda(self._config.device) 125 | 126 | 
self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"]) 127 | self._critic_optim.load_state_dict(ckpt["critic_optim_state_dict"]) 128 | optimizer_cuda(self._actor_optim, self._config.device) 129 | optimizer_cuda(self._critic_optim, self._config.device) 130 | 131 | def _network_cuda(self, device): 132 | self._actor.to(device) 133 | self._critic.to(device) 134 | self._actor_target.to(device) 135 | self._critic_target.to(device) 136 | 137 | def sync_networks(self): 138 | sync_networks(self._actor) 139 | sync_networks(self._critic) 140 | sync_networks(self._actor_target) 141 | sync_networks(self._critic_target) 142 | 143 | def train(self): 144 | train_info = Info() 145 | 146 | self._num_updates = 1 147 | for _ in range(self._num_updates): 148 | self._actor_lr_scheduler.step() 149 | transitions = self._buffer.sample(self._config.batch_size) 150 | train_info.add(self._update_network(transitions)) 151 | 152 | return train_info.get_dict() 153 | 154 | def _update_actor(self, o, mask): 155 | info = Info() 156 | 157 | # the actor loss 158 | actions_real, _, _, _ = self._actor.act( 159 | o, return_log_prob=False, detach_conv=True 160 | ) 161 | 162 | q_pred = self._critic(o, actions_real, detach_conv=True) 163 | if self._config.critic_ensemble > 1: 164 | q_pred = q_pred[0] 165 | 166 | if self._config.absorbing_state: 167 | # do not update the actor for absorbing states 168 | a_mask = 1.0 - torch.clamp(-mask, min=0) # 0 absorbing, 1 done/not done 169 | actor_loss = -(q_pred * a_mask).sum() 170 | if a_mask.sum() > 1e-8: 171 | actor_loss /= a_mask.sum() 172 | else: 173 | actor_loss = -q_pred.mean() 174 | info["actor_loss"] = actor_loss.cpu().item() 175 | 176 | # update the actor 177 | self._actor_optim.zero_grad() 178 | actor_loss.backward() 179 | if self._config.max_grad_norm: 180 | torch.nn.utils.clip_grad_norm_( 181 | self._actor.parameters(), self._config.max_grad_norm 182 | ) 183 | sync_grads(self._actor) 184 | self._actor_optim.step() 185 | 186 | return info 187 | 188 | def _update_critic(self, o, ac, rew, o_next, mask): 189 | info = Info() 190 | 191 | # calculate the target Q value function 192 | with torch.no_grad(): 193 | actions_next, _, _, _ = self._actor_target.act( 194 | o_next, return_log_prob=False 195 | ) 196 | 197 | # TD3 adds noise to action 198 | if self._config.critic_ensemble > 1: 199 | for k in self._ac_space.spaces.keys(): 200 | noise = ( 201 | torch.randn_like(actions_next[k]) * self._config.policy_noise 202 | ).clamp( 203 | -self._config.policy_noise_clip, self._config.policy_noise_clip 204 | ) 205 | actions_next[k] = (actions_next[k] + noise).clamp(-1, 1) 206 | 207 | if self._config.absorbing_state: 208 | a_mask = torch.clamp(mask, min=0) # 0 absorbing/done, 1 not done 209 | masked_actions_next = scale_dict_tensor(actions_next, a_mask) 210 | q_next_values = self._critic_target(o_next, masked_actions_next) 211 | else: 212 | q_next_values = self._critic_target(o_next, actions_next) 213 | 214 | q_next_value = torch.min(*q_next_values) 215 | 216 | else: 217 | q_next_value = self._critic_target(o_next, actions_next) 218 | 219 | # For IL, use IL reward 220 | if self._predict_reward is not None: 221 | rew_il = self._predict_reward(o, ac) 222 | rew = ( 223 | 1 - self._config.gail_env_reward 224 | ) * rew_il + self._config.gail_env_reward * rew 225 | 226 | if self._config.absorbing_state: 227 | target_q_value = ( 228 | rew + self._config.rl_discount_factor * q_next_value 229 | ) 230 | else: 231 | target_q_value = ( 232 | rew + mask * self._config.rl_discount_factor * 
q_next_value 233 | ) 234 | 235 | # the q loss 236 | if self._config.critic_ensemble == 1: 237 | real_q_value = self._critic(o, ac) 238 | critic_loss = F.mse_loss(target_q_value, real_q_value) 239 | else: 240 | real_q_value1, real_q_value2 = self._critic(o, ac) 241 | critic1_loss = F.mse_loss(target_q_value, real_q_value1) 242 | critic2_loss = F.mse_loss(target_q_value, real_q_value2) 243 | critic_loss = critic1_loss + critic2_loss 244 | 245 | # update the critic 246 | self._critic_optim.zero_grad() 247 | critic_loss.backward() 248 | sync_grads(self._critic) 249 | self._critic_optim.step() 250 | 251 | info["min_target_q"] = target_q_value.min().cpu().item() 252 | info["target_q"] = target_q_value.mean().cpu().item() 253 | 254 | if self._config.critic_ensemble == 1: 255 | info["min_real1_q"] = real_q_value.min().cpu().item() 256 | info["real1_q"] = real_q_value.mean().cpu().item() 257 | info["critic1_loss"] = critic_loss.cpu().item() 258 | else: 259 | info["min_real1_q"] = real_q_value1.min().cpu().item() 260 | info["min_real2_q"] = real_q_value2.min().cpu().item() 261 | info["real1_q"] = real_q_value1.mean().cpu().item() 262 | info["real2_q"] = real_q_value2.mean().cpu().item() 263 | info["critic1_loss"] = critic1_loss.cpu().item() 264 | info["critic2_loss"] = critic2_loss.cpu().item() 265 | 266 | return info 267 | 268 | def _update_network(self, transitions): 269 | info = Info() 270 | 271 | # pre-process the observation 272 | o, o_next = transitions["ob"], transitions["ob_next"] 273 | o = self.normalize(o) 274 | o_next = self.normalize(o_next) 275 | bs = len(transitions["done"]) 276 | 277 | _to_tensor = lambda x: to_tensor(x, self._config.device) 278 | o = _to_tensor(o) 279 | o_next = _to_tensor(o_next) 280 | ac = _to_tensor(transitions["ac"]) 281 | mask = _to_tensor(transitions["done_mask"]).reshape(bs, 1) 282 | rew = _to_tensor(transitions["rew"]).reshape(bs, 1) 283 | 284 | self._update_iter += 1 285 | 286 | critic_train_info = self._update_critic(o, ac, rew, o_next, mask) 287 | info.add(critic_train_info) 288 | 289 | if ( 290 | self._update_iter % self._config.actor_update_freq == 0 291 | and self._update_iter > self._config.actor_update_delay 292 | ): 293 | actor_train_info = self._update_actor(o, mask) 294 | info.add(actor_train_info) 295 | 296 | if self._update_iter % self._config.critic_target_update_freq == 0: 297 | for i, fc in enumerate(self._critic.fcs): 298 | self._soft_update_target_network( 299 | self._critic_target.fcs[i], 300 | fc, 301 | self._config.critic_soft_update_weight, 302 | ) 303 | self._soft_update_target_network( 304 | self._critic_target.encoder, 305 | self._critic.encoder, 306 | self._config.encoder_soft_update_weight, 307 | ) 308 | 309 | if ( 310 | self._update_iter % self._config.actor_target_update_freq == 0 311 | and self._update_iter > self._config.actor_update_delay 312 | ): 313 | self._soft_update_target_network( 314 | self._actor_target.fc, 315 | self._actor.fc, 316 | self._config.actor_soft_update_weight, 317 | ) 318 | for k, fc in self._actor.fcs.items(): 319 | self._soft_update_target_network( 320 | self._actor_target.fcs[k], 321 | fc, 322 | self._config.actor_soft_update_weight, 323 | ) 324 | self._soft_update_target_network( 325 | self._actor_target.encoder, 326 | self._actor.encoder, 327 | self._config.encoder_soft_update_weight, 328 | ) 329 | 330 | return info.get_dict(only_scalar=True) 331 | -------------------------------------------------------------------------------- /algorithms/expert_dataset.py: 
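Note: the critic update in `algorithms/ddpg_agent.py` above combines clipped double-Q targets, target policy smoothing, and polyak-averaged target networks. The following is a minimal sketch of those three pieces for flat tensors; the function names and toy shapes are illustrative assumptions rather than the repository's API, and the `soft_update` convention mirrors `BaseAgent._soft_update_target_network` (where `tau` is the weight kept on the target parameters).

```python
import torch
import torch.nn as nn


def smoothed_target_action(pi_next, policy_noise=0.2, noise_clip=0.5):
    # TD3 target policy smoothing: add clipped Gaussian noise, keep actions in [-1, 1]
    noise = (torch.randn_like(pi_next) * policy_noise).clamp(-noise_clip, noise_clip)
    return (pi_next + noise).clamp(-1, 1)


def td3_target(rew, mask, q1_next, q2_next, discount=0.99):
    # clipped double-Q: bootstrap from the smaller of the two target critics
    return rew + mask * discount * torch.min(q1_next, q2_next)


def soft_update(target: nn.Module, source: nn.Module, tau):
    # tau is the weight kept on the target (e.g. 0.995 keeps the target slow-moving)
    for tp, sp in zip(target.parameters(), source.parameters()):
        tp.data.copy_((1 - tau) * sp.data + tau * tp.data)


# toy usage
rew, mask = torch.ones(4, 1), torch.ones(4, 1)
target_q = td3_target(rew, mask, torch.randn(4, 1), torch.randn(4, 1))
```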
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import glob 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import gym.spaces 9 | 10 | from ..utils.logger import logger 11 | from ..utils.gym_env import get_non_absorbing_state, get_absorbing_state, zero_value 12 | 13 | 14 | class ExpertDataset(Dataset): 15 | """ Dataset class for Imitation Learning. """ 16 | 17 | def __init__( 18 | self, 19 | path, 20 | subsample_interval=1, 21 | ac_space=None, 22 | train=True, 23 | transform=None, 24 | target_transform=None, 25 | download=False, 26 | use_low_level=False, 27 | sample_range_start=0.0, 28 | sample_range_end=1.0, 29 | ): 30 | self.train = train # training set or test set 31 | 32 | self._data = [] 33 | self._ac_space = ac_space 34 | 35 | assert ( 36 | path is not None 37 | ), "--demo_path should be set (e.g. demos/Sawyer_toy_table)" 38 | demo_files = self._get_demo_files(path) 39 | num_demos = 0 40 | 41 | # now load the picked numpy arrays 42 | for file_path in demo_files: 43 | with open(file_path, "rb") as f: 44 | demos = pickle.load(f) 45 | if not isinstance(demos, list): 46 | demos = [demos] 47 | 48 | for demo in demos: 49 | if len(demo["obs"]) != len(demo["actions"]) + 1: 50 | logger.error( 51 | "Mismatch in # of observations (%d) and actions (%d) (%s)", 52 | len(demo["obs"]), 53 | len(demo["actions"]), 54 | file_path, 55 | ) 56 | continue 57 | 58 | offset = np.random.randint(0, subsample_interval) 59 | num_demos += 1 60 | 61 | if use_low_level: 62 | length = len(demo["low_level_actions"]) 63 | start = int(length * sample_range_start) 64 | end = int(length * sample_range_end) 65 | for i in range(start + offset, end, subsample_interval): 66 | transition = { 67 | "ob": demo["low_level_obs"][i], 68 | "ob_next": demo["low_level_obs"][i + 1], 69 | } 70 | if isinstance(demo["low_level_actions"][i], dict): 71 | transition["ac"] = demo["low_level_actions"][i] 72 | else: 73 | transition["ac"] = gym.spaces.unflatten( 74 | ac_space, demo["low_level_actions"][i] 75 | ) 76 | 77 | transition["done"] = 1 if i + 1 == length else 0 78 | 79 | self._data.append(transition) 80 | 81 | continue 82 | 83 | length = len(demo["actions"]) 84 | start = int(length * sample_range_start) 85 | end = int(length * sample_range_end) 86 | for i in range(start + offset, end, subsample_interval): 87 | transition = { 88 | "ob": demo["obs"][i], 89 | "ob_next": demo["obs"][i + 1], 90 | } 91 | if isinstance(demo["actions"][i], dict): 92 | transition["ac"] = demo["actions"][i] 93 | else: 94 | transition["ac"] = gym.spaces.unflatten( 95 | ac_space, demo["actions"][i] 96 | ) 97 | if "rewards" in demo: 98 | transition["rew"] = demo["rewards"][i] 99 | if "dones" in demo: 100 | transition["done"] = int(demo["dones"][i]) 101 | else: 102 | transition["done"] = 1 if i + 1 == length else 0 103 | 104 | self._data.append(transition) 105 | 106 | logger.warn( 107 | "Load %d demonstrations with %d states from %d files", 108 | num_demos, 109 | len(self._data), 110 | len(demo_files), 111 | ) 112 | 113 | def add_absorbing_states(self, ob_space, ac_space): 114 | new_data = [] 115 | absorbing_state = get_absorbing_state(ob_space) 116 | absorbing_action = zero_value(ac_space, dtype=np.float32) 117 | for i in range(len(self._data)): 118 | transition = self._data[i].copy() 119 | transition["ob"] = get_non_absorbing_state(self._data[i]["ob"]) 120 | # learn reward for the last transition regardless of timeout (different from paper) 121 | if 
self._data[i]["done"]: 122 | transition["ob_next"] = absorbing_state 123 | transition["done_mask"] = 0 # -1 absorbing, 0 done, 1 not done 124 | else: 125 | transition["ob_next"] = get_non_absorbing_state( 126 | self._data[i]["ob_next"] 127 | ) 128 | transition["done_mask"] = 1 # -1 absorbing, 0 done, 1 not done 129 | new_data.append(transition) 130 | 131 | if self._data[i]["done"]: 132 | transition = { 133 | "ob": absorbing_state, 134 | "ob_next": absorbing_state, 135 | "ac": absorbing_action, 136 | # "rew": np.float64(0.0), 137 | "done": 0, 138 | "done_mask": -1, # -1 absorbing, 0 done, 1 not done 139 | } 140 | new_data.append(transition) 141 | 142 | self._data = new_data 143 | 144 | def _get_demo_files(self, demo_file_path): 145 | demos = [] 146 | if not demo_file_path.endswith(".pkl"): 147 | demo_file_path = demo_file_path + "*.pkl" 148 | for f in glob.glob(demo_file_path): 149 | if os.path.isfile(f): 150 | demos.append(f) 151 | return demos 152 | 153 | def __getitem__(self, index): 154 | """ 155 | Args: 156 | index (int): Index 157 | Returns: 158 | tuple: (ob, ac) where target is index of the target class. 159 | """ 160 | return self._data[index] 161 | 162 | def __len__(self): 163 | return len(self._data) 164 | -------------------------------------------------------------------------------- /algorithms/gail_agent.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.autograd as autograd 8 | import torch.distributions 9 | from torch.optim.lr_scheduler import StepLR 10 | 11 | from .base_agent import BaseAgent 12 | from .ppo_agent import PPOAgent 13 | from .dataset import ReplayBuffer, RandomSampler 14 | from .expert_dataset import ExpertDataset 15 | from ..networks.discriminator import Discriminator 16 | from ..utils.info_dict import Info 17 | from ..utils.logger import logger 18 | from ..utils.mpi import mpi_average 19 | from ..utils.pytorch import ( 20 | optimizer_cuda, 21 | count_parameters, 22 | sync_networks, 23 | sync_grads, 24 | to_tensor, 25 | ) 26 | 27 | 28 | class GAILAgent(BaseAgent): 29 | def __init__(self, config, ob_space, ac_space, env_ob_space): 30 | super().__init__(config, ob_space) 31 | 32 | self._ob_space = ob_space 33 | self._ac_space = ac_space 34 | 35 | if self._config.gail_rl_algo == "ppo": 36 | self._rl_agent = PPOAgent(config, ob_space, ac_space, env_ob_space) 37 | self._rl_agent.set_reward_function(self._predict_reward) 38 | 39 | # build up networks 40 | self._discriminator = Discriminator( 41 | config, ob_space, ac_space if not config.gail_no_action else None 42 | ) 43 | self._discriminator_loss = nn.BCEWithLogitsLoss() 44 | self._network_cuda(config.device) 45 | 46 | # build optimizers 47 | self._discriminator_optim = optim.Adam( 48 | self._discriminator.parameters(), lr=config.discriminator_lr 49 | ) 50 | 51 | # build learning rate scheduler 52 | self._discriminator_lr_scheduler = StepLR( 53 | self._discriminator_optim, 54 | step_size=self._config.max_global_step // self._config.rollout_length // 5, 55 | gamma=0.5, 56 | ) 57 | 58 | # expert dataset 59 | if config.is_train: 60 | self._dataset = ExpertDataset( 61 | config.demo_path, 62 | config.demo_subsample_interval, 63 | ac_space, 64 | use_low_level=config.demo_low_level, 65 | sample_range_start=config.demo_sample_range_start, 66 | sample_range_end=config.demo_sample_range_end, 67 | ) 68 | self._data_loader = 
torch.utils.data.DataLoader( 69 | self._dataset, 70 | batch_size=self._config.batch_size, 71 | shuffle=True, 72 | drop_last=True, 73 | ) 74 | self._data_iter = iter(self._data_loader) 75 | 76 | # policy dataset 77 | sampler = RandomSampler() 78 | self._buffer = ReplayBuffer( 79 | [ 80 | "ob", 81 | "ob_next", 82 | "ac", 83 | "done", 84 | "rew", 85 | "ret", 86 | "adv", 87 | "ac_before_activation", 88 | ], 89 | config.rollout_length, 90 | sampler.sample_func, 91 | ) 92 | 93 | self._rl_agent.set_buffer(self._buffer) 94 | 95 | # update observation normalizer with dataset 96 | self.update_normalizer() 97 | 98 | self._log_creation() 99 | 100 | def _predict_reward(self, ob, ac): 101 | if self._config.gail_no_action: 102 | ac = None 103 | with torch.no_grad(): 104 | ret = self._discriminator(ob, ac) 105 | eps = 1e-10 106 | s = torch.sigmoid(ret) 107 | if self._config.gail_reward == "vanilla": 108 | reward = -(1 - s + eps).log() 109 | elif self._config.gail_reward == "gan": 110 | reward = (s + eps).log() - (1 - s + eps).log() 111 | elif self._config.gail_reward == "d": 112 | reward = ret 113 | return reward 114 | 115 | def predict_reward(self, ob, ac=None): 116 | ob = self.normalize(ob) 117 | ob = to_tensor(ob, self._config.device) 118 | if self._config.gail_no_action: 119 | ac = None 120 | if ac is not None: 121 | ac = to_tensor(ac, self._config.device) 122 | 123 | reward = self._predict_reward(ob, ac) 124 | return reward.cpu().item() 125 | 126 | def _log_creation(self): 127 | if self._config.is_chef: 128 | logger.info("Creating a GAIL agent") 129 | logger.info( 130 | "The discriminator has %d parameters", 131 | count_parameters(self._discriminator), 132 | ) 133 | 134 | def store_episode(self, rollouts): 135 | self._rl_agent.store_episode(rollouts) 136 | 137 | def state_dict(self): 138 | return { 139 | "rl_agent": self._rl_agent.state_dict(), 140 | "discriminator_state_dict": self._discriminator.state_dict(), 141 | "discriminator_optim_state_dict": self._discriminator_optim.state_dict(), 142 | "ob_norm_state_dict": self._ob_norm.state_dict(), 143 | } 144 | 145 | def load_state_dict(self, ckpt): 146 | if "rl_agent" in ckpt: 147 | self._rl_agent.load_state_dict(ckpt["rl_agent"]) 148 | else: 149 | self._rl_agent.load_state_dict(ckpt) 150 | self._network_cuda(self._config.device) 151 | return 152 | 153 | self._discriminator.load_state_dict(ckpt["discriminator_state_dict"]) 154 | self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 155 | self._network_cuda(self._config.device) 156 | 157 | self._discriminator_optim.load_state_dict( 158 | ckpt["discriminator_optim_state_dict"] 159 | ) 160 | optimizer_cuda(self._discriminator_optim, self._config.device) 161 | 162 | def _network_cuda(self, device): 163 | self._discriminator.to(device) 164 | 165 | def sync_networks(self): 166 | self._rl_agent.sync_networks() 167 | sync_networks(self._discriminator) 168 | 169 | def update_normalizer(self, obs=None): 170 | """ Updates normalizers for discriminator and PPO agent. 
""" 171 | if self._config.ob_norm: 172 | if obs is None: 173 | data_loader = torch.utils.data.DataLoader( 174 | self._dataset, 175 | batch_size=self._config.batch_size, 176 | shuffle=False, 177 | drop_last=False, 178 | ) 179 | for obs in data_loader: 180 | super().update_normalizer(obs) 181 | self._rl_agent.update_normalizer(obs) 182 | else: 183 | super().update_normalizer(obs) 184 | self._rl_agent.update_normalizer(obs) 185 | 186 | def train(self): 187 | train_info = Info() 188 | 189 | self._discriminator_lr_scheduler.step() 190 | 191 | num_batches = ( 192 | self._config.rollout_length 193 | // self._config.batch_size 194 | // self._config.discriminator_update_freq 195 | ) 196 | assert num_batches > 0 197 | for _ in range(num_batches): 198 | policy_data = self._buffer.sample(self._config.batch_size) 199 | try: 200 | expert_data = next(self._data_iter) 201 | except StopIteration: 202 | self._data_iter = iter(self._data_loader) 203 | expert_data = next(self._data_iter) 204 | 205 | _train_info = self._update_discriminator(policy_data, expert_data) 206 | train_info.add(_train_info) 207 | 208 | _train_info = self._rl_agent.train() 209 | train_info.add(_train_info) 210 | 211 | for _ in range(num_batches): 212 | try: 213 | expert_data = next(self._data_iter) 214 | except StopIteration: 215 | self._data_iter = iter(self._data_loader) 216 | expert_data = next(self._data_iter) 217 | self.update_normalizer(expert_data["ob"]) 218 | 219 | return train_info.get_dict(only_scalar=True) 220 | 221 | def _update_discriminator(self, policy_data, expert_data): 222 | info = Info() 223 | 224 | _to_tensor = lambda x: to_tensor(x, self._config.device) 225 | # pre-process observations 226 | p_o = policy_data["ob"] 227 | p_o = self.normalize(p_o) 228 | p_o = _to_tensor(p_o) 229 | 230 | e_o = expert_data["ob"] 231 | e_o = self.normalize(e_o) 232 | e_o = _to_tensor(e_o) 233 | 234 | if self._config.gail_no_action: 235 | p_ac = None 236 | e_ac = None 237 | else: 238 | p_ac = _to_tensor(policy_data["ac"]) 239 | e_ac = _to_tensor(expert_data["ac"]) 240 | 241 | p_logit = self._discriminator(p_o, p_ac) 242 | e_logit = self._discriminator(e_o, e_ac) 243 | 244 | p_output = torch.sigmoid(p_logit) 245 | e_output = torch.sigmoid(e_logit) 246 | 247 | p_loss = self._discriminator_loss( 248 | p_logit, torch.zeros_like(p_logit).to(self._config.device) 249 | ) 250 | e_loss = self._discriminator_loss( 251 | e_logit, torch.ones_like(e_logit).to(self._config.device) 252 | ) 253 | 254 | logits = torch.cat([p_logit, e_logit], dim=0) 255 | entropy = torch.distributions.Bernoulli(logits=logits).entropy().mean() 256 | entropy_loss = -self._config.gail_entropy_loss_coeff * entropy 257 | 258 | grad_pen = self._compute_grad_pen(p_o, p_ac, e_o, e_ac) 259 | grad_pen_loss = self._config.gail_grad_penalty_coeff * grad_pen 260 | 261 | gail_loss = p_loss + e_loss + entropy_loss + grad_pen_loss 262 | 263 | # update the discriminator 264 | self._discriminator.zero_grad() 265 | gail_loss.backward() 266 | sync_grads(self._discriminator) 267 | self._discriminator_optim.step() 268 | 269 | info["gail_policy_output"] = p_output.mean().detach().cpu().item() 270 | info["gail_expert_output"] = e_output.mean().detach().cpu().item() 271 | info["gail_entropy"] = entropy.detach().cpu().item() 272 | info["gail_policy_loss"] = p_loss.detach().cpu().item() 273 | info["gail_expert_loss"] = e_loss.detach().cpu().item() 274 | info["gail_entropy_loss"] = entropy_loss.detach().cpu().item() 275 | info["gail_grad_pen"] = grad_pen.detach().cpu().item() 276 | 
info["gail_grad_loss"] = grad_pen_loss.detach().cpu().item() 277 | info["gail_loss"] = gail_loss.detach().cpu().item() 278 | 279 | return mpi_average(info.get_dict(only_scalar=True)) 280 | 281 | def _compute_grad_pen(self, policy_ob, policy_ac, expert_ob, expert_ac): 282 | batch_size = self._config.batch_size 283 | alpha = torch.rand(batch_size, 1, device=self._config.device) 284 | 285 | def blend_dict(a, b, alpha): 286 | if isinstance(a, dict): 287 | return OrderedDict( 288 | [(k, blend_dict(a[k], b[k], alpha)) for k in a.keys()] 289 | ) 290 | elif isinstance(a, list): 291 | return [blend_dict(a[i], b[i], alpha) for i in range(len(a))] 292 | else: 293 | expanded_alpha = alpha.expand_as(a) 294 | ret = expanded_alpha * a + (1 - expanded_alpha) * b 295 | ret.requires_grad = True 296 | return ret 297 | 298 | interpolated_ob = blend_dict(policy_ob, expert_ob, alpha) 299 | inputs = list(interpolated_ob.values()) 300 | if policy_ac is not None: 301 | interpolated_ac = blend_dict(policy_ac, expert_ac, alpha) 302 | inputs = inputs + list(interpolated_ob.values()) 303 | else: 304 | interpolated_ac = None 305 | 306 | interpolated_logit = self._discriminator(interpolated_ob, interpolated_ac) 307 | ones = torch.ones(interpolated_logit.size(), device=self._config.device) 308 | 309 | grad = autograd.grad( 310 | outputs=interpolated_logit, 311 | inputs=inputs, 312 | grad_outputs=ones, 313 | create_graph=True, 314 | retain_graph=True, 315 | only_inputs=True, 316 | )[0] 317 | 318 | grad_pen = (grad.norm(2, dim=1) - 1).pow(2).mean() 319 | return grad_pen 320 | -------------------------------------------------------------------------------- /algorithms/ppo_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.optim.lr_scheduler import StepLR 6 | 7 | from ..networks import Actor, Critic 8 | from ..utils.info_dict import Info 9 | from ..utils.logger import logger 10 | from ..utils.mpi import mpi_average 11 | from ..utils.pytorch import ( 12 | compute_gradient_norm, 13 | compute_weight_norm, 14 | count_parameters, 15 | obs2tensor, 16 | optimizer_cuda, 17 | sync_grads, 18 | sync_networks, 19 | to_tensor, 20 | center_crop_images 21 | ) 22 | from .base_agent import BaseAgent 23 | from .dataset import RandomSampler, ReplayBuffer 24 | 25 | 26 | class PPOAgent(BaseAgent): 27 | def __init__(self, config, ob_space, ac_space, env_ob_space): 28 | super().__init__(config, ob_space) 29 | 30 | self._ac_space = ac_space 31 | 32 | # build up networks 33 | self._actor = Actor(config, ob_space, ac_space, config.tanh_policy) 34 | self._old_actor = Actor(config, ob_space, ac_space, config.tanh_policy) 35 | self._critic = Critic(config, ob_space) 36 | self._network_cuda(config.device) 37 | 38 | self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.actor_lr) 39 | self._critic_optim = optim.Adam(self._critic.parameters(), lr=config.critic_lr) 40 | 41 | self._actor_lr_scheduler = StepLR( 42 | self._actor_optim, 43 | step_size=self._config.max_global_step // self._config.rollout_length // 5, 44 | gamma=0.5, 45 | ) 46 | self._critic_lr_scheduler = StepLR( 47 | self._critic_optim, 48 | step_size=self._config.max_global_step // self._config.rollout_length // 5, 49 | gamma=0.5, 50 | ) 51 | 52 | sampler = RandomSampler(image_crop_size=self._config.encoder_image_size) 53 | self._buffer = ReplayBuffer( 54 | [ 55 | "ob", 56 | "ob_next", 57 | "ac", 58 | "done", 59 | "rew", 60 | 
"ret", 61 | "adv", 62 | "ac_before_activation", 63 | ], 64 | config.rollout_length, 65 | sampler.sample_func, 66 | ) 67 | 68 | self._update_iter = 0 69 | 70 | self._log_creation() 71 | 72 | def _log_creation(self): 73 | if self._config.is_chef: 74 | logger.info("Creating a PPO agent") 75 | logger.info("The actor has %d parameters", count_parameters(self._actor)) 76 | logger.info("The critic has %d parameters", count_parameters(self._critic)) 77 | 78 | def store_episode(self, rollouts): 79 | self._compute_gae(rollouts) 80 | self._buffer.store_episode(rollouts) 81 | 82 | def _compute_gae(self, rollouts): 83 | T = len(rollouts["done"]) 84 | ob = rollouts["ob"] 85 | ob = self.normalize(ob) 86 | ob = obs2tensor(ob, self._config.device) 87 | for k, v in ob.items(): 88 | if self._config.encoder_type == "cnn" and len(v.shape) == 4: 89 | ob[k] = center_crop_images(v, self._config.encoder_image_size) 90 | 91 | ob_last = rollouts["ob_next"][-1:] 92 | ob_last = self.normalize(ob_last) 93 | ob_last = obs2tensor(ob_last, self._config.device) 94 | for k, v in ob_last.items(): 95 | if self._config.encoder_type == "cnn" and len(v.shape) == 4: 96 | ob_last[k] = center_crop_images(v, self._config.encoder_image_size) 97 | 98 | done = rollouts["done"] 99 | rew = rollouts["rew"] 100 | 101 | vpred = self._critic(ob).detach().cpu().numpy()[:, 0] 102 | vpred_last = self._critic(ob_last).detach().cpu().numpy()[:, 0] 103 | vpred = np.append(vpred, vpred_last) 104 | assert len(vpred) == T + 1 105 | 106 | if hasattr(self, "_predict_reward"): 107 | ob = rollouts["ob"] 108 | ob = self.normalize(ob) 109 | ob = obs2tensor(ob, self._config.device) 110 | ac = obs2tensor(rollouts["ac"], self._config.device) 111 | rew_il = self._predict_reward(ob, ac).cpu().numpy().squeeze() 112 | rew = (1 - self._config.gail_env_reward) * rew_il[ 113 | :T 114 | ] + self._config.gail_env_reward * np.array(rew) 115 | assert rew.shape == (T,) 116 | 117 | adv = np.empty((T,), "float32") 118 | lastgaelam = 0 119 | for t in reversed(range(T)): 120 | nonterminal = 1 - done[t] 121 | delta = ( 122 | rew[t] 123 | + self._config.rl_discount_factor * vpred[t + 1] * nonterminal 124 | - vpred[t] 125 | ) 126 | adv[t] = lastgaelam = ( 127 | delta 128 | + self._config.rl_discount_factor 129 | * self._config.gae_lambda 130 | * nonterminal 131 | * lastgaelam 132 | ) 133 | 134 | ret = adv + vpred[:-1] 135 | 136 | assert np.isfinite(adv).all() 137 | assert np.isfinite(ret).all() 138 | 139 | # update rollouts 140 | if self._config.advantage_norm: 141 | rollouts["adv"] = ((adv - adv.mean()) / (adv.std() + 1e-5)).tolist() 142 | else: 143 | rollouts["adv"] = adv.tolist() 144 | 145 | rollouts["ret"] = ret.tolist() 146 | 147 | def state_dict(self): 148 | return { 149 | "actor_state_dict": self._actor.state_dict(), 150 | "critic_state_dict": self._critic.state_dict(), 151 | "actor_optim_state_dict": self._actor_optim.state_dict(), 152 | "critic_optim_state_dict": self._critic_optim.state_dict(), 153 | "ob_norm_state_dict": self._ob_norm.state_dict(), 154 | } 155 | 156 | def load_state_dict(self, ckpt): 157 | if "critic_state_dict" not in ckpt: 158 | # BC initialization 159 | logger.warn("Load only actor from BC initialization") 160 | self._actor.load_state_dict(ckpt["actor_state_dict"], strict=False) 161 | self._network_cuda(self._config.device) 162 | self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 163 | return 164 | 165 | self._actor.load_state_dict(ckpt["actor_state_dict"]) 166 | self._critic.load_state_dict(ckpt["critic_state_dict"]) 167 | 
self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 168 | self._network_cuda(self._config.device) 169 | 170 | self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"]) 171 | self._critic_optim.load_state_dict(ckpt["critic_optim_state_dict"]) 172 | optimizer_cuda(self._actor_optim, self._config.device) 173 | optimizer_cuda(self._critic_optim, self._config.device) 174 | 175 | def _network_cuda(self, device): 176 | self._actor.to(device) 177 | self._old_actor.to(device) 178 | self._critic.to(device) 179 | 180 | def sync_networks(self): 181 | sync_networks(self._actor) 182 | sync_networks(self._critic) 183 | 184 | def train(self): 185 | train_info = Info() 186 | 187 | self._copy_target_network(self._old_actor, self._actor) 188 | 189 | num_batches = ( 190 | self._config.ppo_epoch 191 | * self._config.rollout_length 192 | // self._config.batch_size 193 | ) 194 | assert num_batches > 0 195 | 196 | for _ in range(num_batches): 197 | transitions = self._buffer.sample(self._config.batch_size) 198 | _train_info = self._update_network(transitions) 199 | train_info.add(_train_info) 200 | 201 | self._buffer.clear() 202 | 203 | self._actor_lr_scheduler.step() 204 | self._critic_lr_scheduler.step() 205 | 206 | logger.info( 207 | "Actor lr %f, Critic lr %f, PPO Clip Frac %f", 208 | self._actor_lr_scheduler.get_lr()[0], 209 | self._critic_lr_scheduler.get_lr()[0], 210 | np.mean(train_info["ppo_clip_frac"]) 211 | ) 212 | 213 | # slow! 214 | # train_info.add( 215 | # { 216 | # "actor_grad_norm": compute_gradient_norm(self._actor), 217 | # "actor_weight_norm": compute_weight_norm(self._actor), 218 | # "critic_grad_norm": compute_gradient_norm(self._critic), 219 | # "critic_weight_norm": compute_weight_norm(self._critic), 220 | # } 221 | # ) 222 | return mpi_average(train_info.get_dict(only_scalar=True)) 223 | 224 | def _update_actor(self, o, a_z, adv): 225 | info = Info() 226 | 227 | _, _, log_pi, ent = self._actor.act( 228 | o, activations=a_z, return_log_prob=True 229 | ) 230 | _, _, old_log_pi, _ = self._old_actor.act( 231 | o, activations=a_z, return_log_prob=True 232 | ) 233 | if old_log_pi.min() < -100: 234 | logger.error("sampling an action with a probability of 1e-100") 235 | import ipdb 236 | ipdb.set_trace() 237 | 238 | # the actor loss 239 | entropy_loss = -self._config.entropy_loss_coeff * ent.mean() 240 | ratio = torch.exp(log_pi - old_log_pi) 241 | surr1 = ratio * adv 242 | surr2 = ( 243 | torch.clamp(ratio, 1.0 - self._config.ppo_clip, 1.0 + self._config.ppo_clip) 244 | * adv 245 | ) 246 | actor_loss = -torch.min(surr1, surr2).mean() 247 | 248 | ppo_clip_frac = torch.gt(torch.abs(ratio - 1.0), self._config.ppo_clip).float().mean() 249 | 250 | if ( 251 | not np.isfinite(ratio.cpu().detach()).all() 252 | or not np.isfinite(adv.cpu().detach()).all() 253 | ): 254 | import ipdb 255 | 256 | ipdb.set_trace() 257 | info["ppo_clip_frac"] = ppo_clip_frac.cpu().item() 258 | info["entropy_loss"] = entropy_loss.cpu().item() 259 | info["actor_loss"] = actor_loss.cpu().item() 260 | actor_loss += entropy_loss 261 | 262 | # update the actor 263 | self._actor_optim.zero_grad() 264 | actor_loss.backward() 265 | if self._config.max_grad_norm: 266 | torch.nn.utils.clip_grad_norm_( 267 | self._actor.parameters(), self._config.max_grad_norm 268 | ) 269 | sync_grads(self._actor) 270 | self._actor_optim.step() 271 | 272 | # include info from policy 273 | info.add(self._actor.info) 274 | 275 | return info 276 | 277 | def _update_critic(self, o, ret): 278 | info = Info() 279 | 280 | # the q loss 281 | 
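# V(s) is regressed toward the GAE-based return (ret = adv + V_old(s) from
# _compute_gae), with the squared error scaled by --value_loss_coeff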
value_pred = self._critic(o) 282 | value_loss = self._config.value_loss_coeff * (ret - value_pred).pow(2).mean() 283 | 284 | # update the critic 285 | self._critic_optim.zero_grad() 286 | value_loss.backward() 287 | if self._config.max_grad_norm: 288 | torch.nn.utils.clip_grad_norm_( 289 | self._critic.parameters(), self._config.max_grad_norm 290 | ) 291 | sync_grads(self._critic) 292 | self._critic_optim.step() 293 | 294 | info["value_target"] = ret.mean().cpu().item() 295 | info["value_predicted"] = value_pred.mean().cpu().item() 296 | info["value_loss"] = value_loss.cpu().item() 297 | 298 | return info 299 | 300 | def _update_network(self, transitions): 301 | info = Info() 302 | 303 | # pre-process observations 304 | o = transitions["ob"] 305 | o = self.normalize(o) 306 | 307 | bs = len(transitions["done"]) 308 | _to_tensor = lambda x: to_tensor(x, self._config.device) 309 | o = _to_tensor(o) 310 | ac = _to_tensor(transitions["ac"]) 311 | a_z = _to_tensor(transitions["ac_before_activation"]) 312 | ret = _to_tensor(transitions["ret"]).reshape(bs, 1) 313 | adv = _to_tensor(transitions["adv"]).reshape(bs, 1) 314 | 315 | self._update_iter += 1 316 | 317 | critic_train_info = self._update_critic(o, ret) 318 | info.add(critic_train_info) 319 | 320 | if self._update_iter % self._config.actor_update_freq == 0: 321 | actor_train_info = self._update_actor(o, a_z, adv) 322 | info.add(actor_train_info) 323 | 324 | return info 325 | -------------------------------------------------------------------------------- /algorithms/rollouts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs rollouts (RolloutRunner class) and collects transitions using Rollout class. 3 | """ 4 | 5 | import random 6 | import pickle 7 | from collections import defaultdict 8 | 9 | import numpy as np 10 | import cv2 11 | 12 | from ..utils.logger import logger 13 | from ..utils.info_dict import Info 14 | from ..utils.gym_env import get_non_absorbing_state, zero_value 15 | 16 | 17 | class Rollout(object): 18 | """ 19 | Rollout storing an episode. 20 | """ 21 | 22 | def __init__(self): 23 | """ Initialize buffer. """ 24 | self._history = defaultdict(list) 25 | 26 | def add(self, data): 27 | """ Add a transition @data to rollout buffer. """ 28 | for key, value in data.items(): 29 | self._history[key].append(value) 30 | 31 | def get(self): 32 | """ Returns rollout buffer and clears buffer. """ 33 | batch = {} 34 | batch["ob"] = self._history["ob"] 35 | batch["ob_next"] = self._history["ob_next"] 36 | batch["ac"] = self._history["ac"] 37 | batch["ac_before_activation"] = self._history["ac_before_activation"] 38 | batch["done"] = self._history["done"] 39 | batch["done_mask"] = self._history["done_mask"] 40 | batch["rew"] = self._history["rew"] 41 | self._history = defaultdict(list) 42 | return batch 43 | 44 | 45 | class RolloutRunner(object): 46 | """ 47 | Run rollout given environment and policy. 48 | """ 49 | 50 | def __init__(self, config, env, env_eval, pi): 51 | """ 52 | Args: 53 | config: configurations for the environment. 54 | env: environment. 55 | pi: policy. 56 | """ 57 | 58 | self._config = config 59 | self._env = env 60 | self._env_eval = env_eval 61 | self._pi = pi 62 | 63 | def run( 64 | self, 65 | is_train=True, 66 | every_steps=None, 67 | every_episodes=None, 68 | log_prefix="", 69 | step=0, 70 | ): 71 | """ 72 | Collects trajectories and yield every @every_steps/@every_episodes. 73 | 74 | Args: 75 | is_train: whether rollout is for training or evaluation. 
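step: current global environment step; actions are sampled uniformly at random until it reaches --warm_up_steps.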
76 | every_steps: if not None, returns rollouts @every_steps 77 | every_episodes: if not None, returns rollouts @every_epiosdes 78 | log_prefix: log as @log_prefix rollout: %s 79 | """ 80 | if every_steps is None and every_episodes is None: 81 | raise ValueError("Both every_steps and every_episodes cannot be None") 82 | 83 | config = self._config 84 | device = config.device 85 | env = self._env if is_train else self._env_eval 86 | pi = self._pi 87 | il = hasattr(pi, "predict_reward") 88 | 89 | # initialize rollout buffer 90 | rollout = Rollout() 91 | reward_info = Info() 92 | ep_info = Info() 93 | episode = 0 94 | 95 | while True: 96 | done = False 97 | ep_len = 0 98 | ep_rew = 0 99 | ep_rew_rl = 0 100 | if il: 101 | ep_rew_il = 0 102 | ob = env.reset() 103 | 104 | # run rollout 105 | while not done: 106 | # sample action from policy 107 | if step < config.warm_up_steps: 108 | ac, ac_before_activation = env.action_space.sample(), 0 109 | else: 110 | ac, ac_before_activation = pi.act(ob, is_train=is_train) 111 | 112 | rollout.add( 113 | {"ob": ob, "ac": ac, "ac_before_activation": ac_before_activation} 114 | ) 115 | 116 | if il: 117 | reward_il = pi.predict_reward(ob, ac) 118 | 119 | # take a step 120 | ob, reward, done, info = env.step(ac) 121 | rollout.add({"ob_next": ob}) 122 | 123 | # replace reward 124 | if il: 125 | reward_rl = ( 126 | 1 - config.gail_env_reward 127 | ) * reward_il + config.gail_env_reward * reward 128 | else: 129 | reward_rl = reward 130 | 131 | rollout.add({"done": done, "rew": reward}) 132 | step += 1 133 | ep_len += 1 134 | ep_rew += reward 135 | ep_rew_rl += reward_rl 136 | if il: 137 | ep_rew_il += reward_il 138 | 139 | if done and ep_len < env.max_episode_steps: 140 | done_mask = 0 # -1 absorbing, 0 done, 1 not done 141 | else: 142 | done_mask = 1 143 | 144 | rollout.add( 145 | {"done_mask": done_mask} 146 | ) # -1 absorbing, 0 done, 1 not done 147 | 148 | reward_info.add(info) 149 | 150 | if config.absorbing_state and done_mask == 0: 151 | absorbing_state = env.get_absorbing_state() 152 | absorbing_action = zero_value(env.action_space) 153 | rollout._history["ob_next"][-1] = absorbing_state 154 | rollout.add( 155 | { 156 | "ob": absorbing_state, 157 | "ob_next": absorbing_state, 158 | "ac": absorbing_action, 159 | "ac_before_activation": absorbing_action, 160 | "rew": 0.0, 161 | "done": 0, 162 | "done_mask": -1, # -1 absorbing, 0 done, 1 not done 163 | } 164 | ) 165 | 166 | if every_steps is not None and step % every_steps == 0: 167 | yield rollout.get(), ep_info.get_dict(only_scalar=True) 168 | 169 | # compute average/sum of information 170 | ep_info.add({"len": ep_len, "rew": ep_rew, "rew_rl": ep_rew_rl}) 171 | if il: 172 | ep_info.add({"rew_il": ep_rew_il}) 173 | reward_info_dict = reward_info.get_dict(reduction="sum", only_scalar=True) 174 | ep_info.add(reward_info_dict) 175 | reward_info_dict.update({"len": ep_len, "rew": ep_rew, "rew_rl": ep_rew_rl}) 176 | if il: 177 | reward_info_dict.update({"rew_il": ep_rew_il}) 178 | 179 | logger.info( 180 | log_prefix + " rollout: %s", 181 | { 182 | k: v 183 | for k, v in reward_info_dict.items() 184 | if not "qpos" in k and np.isscalar(v) 185 | }, 186 | ) 187 | 188 | episode += 1 189 | if every_episodes is not None and episode % every_episodes == 0: 190 | yield rollout.get(), ep_info.get_dict(only_scalar=True) 191 | 192 | def run_episode(self, max_step=10000, is_train=True, record_video=False): 193 | """ 194 | Runs one episode and returns the rollout (mainly for evaluation). 
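Returns a (rollout, episode info, recorded frames) tuple, e.g.
rollout, ep_info, frames = runner.run_episode(is_train=False, record_video=True)
for a RolloutRunner instance runner.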
195 | 196 | Args: 197 | max_step: maximum number of steps of the rollout. 198 | is_train: whether rollout is for training or evaluation. 199 | record_video: record video of rollout if True. 200 | """ 201 | config = self._config 202 | device = config.device 203 | env = self._env if is_train else self._env_eval 204 | pi = self._pi 205 | il = hasattr(pi, "predict_reward") 206 | 207 | # initialize rollout buffer 208 | rollout = Rollout() 209 | reward_info = Info() 210 | 211 | done = False 212 | ep_len = 0 213 | ep_rew = 0 214 | ep_rew_rl = 0 215 | if il: 216 | ep_rew_il = 0 217 | 218 | ob = env.reset() 219 | 220 | self._record_frames = [] 221 | if record_video: 222 | self._store_frame(env, ep_len, ep_rew) 223 | 224 | # run rollout 225 | while not done and ep_len < max_step: 226 | # sample action from policy 227 | ac, ac_before_activation = pi.act(ob, is_train=is_train) 228 | rollout.add( 229 | {"ob": ob, "ac": ac, "ac_before_activation": ac_before_activation} 230 | ) 231 | 232 | if il: 233 | reward_il = pi.predict_reward(ob, ac) 234 | 235 | # take a step 236 | ob, reward, done, info = env.step(ac) 237 | 238 | # replace reward 239 | if il: 240 | reward_rl = ( 241 | 1 - config.gail_env_reward 242 | ) * reward_il + config.gail_env_reward * reward 243 | else: 244 | reward_rl = reward 245 | 246 | rollout.add({"done": done, "rew": reward}) 247 | ep_len += 1 248 | ep_rew += reward 249 | ep_rew_rl += reward_rl 250 | if il: 251 | ep_rew_il += reward_il 252 | 253 | reward_info.add(info) 254 | if record_video: 255 | frame_info = info.copy() 256 | if il: 257 | frame_info.update( 258 | { 259 | "ep_rew_il": ep_rew_il, 260 | "rew_il": reward_il, 261 | "rew_rl": reward_rl, 262 | } 263 | ) 264 | self._store_frame(env, ep_len, ep_rew, frame_info) 265 | 266 | # add last observation 267 | rollout.add({"ob": ob}) 268 | 269 | # compute average/sum of information 270 | ep_info = {"len": ep_len, "rew": ep_rew, "rew_rl": ep_rew_rl} 271 | if il: 272 | ep_info["rew_il"] = ep_rew_il 273 | ep_info.update(reward_info.get_dict(reduction="sum", only_scalar=True)) 274 | 275 | return rollout.get(), ep_info, self._record_frames 276 | 277 | def _store_frame(self, env, ep_len, ep_rew, info={}): 278 | """ Renders a frame and stores in @self._record_frames. 
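When --record_video_caption is set, the episode length, episode reward, and each
key-value pair in @info are drawn on the frame as a caption.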
""" 279 | color = (200, 200, 200) 280 | 281 | # render video frame 282 | frame = env.render("rgb_array") 283 | if len(frame.shape) == 4: 284 | frame = frame[0] 285 | if np.max(frame) <= 1.0: 286 | frame *= 255.0 287 | 288 | h, w = frame.shape[:2] 289 | if h < 500: 290 | h, w = 500, 500 291 | frame = cv2.resize(frame, (h, w)) 292 | frame = np.concatenate([frame, np.zeros((h, w, 3))], 0) 293 | scale = h / 500 294 | 295 | # add caption to video frame 296 | if self._config.record_video_caption: 297 | text = "{:4} {}".format(ep_len, ep_rew) 298 | font_size = 0.4 * scale 299 | thickness = 1 300 | offset = int(12 * scale) 301 | x, y = int(5 * scale), h + int(10 * scale) 302 | cv2.putText( 303 | frame, 304 | text, 305 | (x, y), 306 | cv2.FONT_HERSHEY_SIMPLEX, 307 | font_size, 308 | (255, 255, 0), 309 | thickness, 310 | cv2.LINE_AA, 311 | ) 312 | for i, k in enumerate(info.keys()): 313 | v = info[k] 314 | key_text = "{}: ".format(k) 315 | (key_width, _), _ = cv2.getTextSize( 316 | key_text, cv2.FONT_HERSHEY_SIMPLEX, font_size, thickness 317 | ) 318 | 319 | cv2.putText( 320 | frame, 321 | key_text, 322 | (x, y + offset * (i + 2)), 323 | cv2.FONT_HERSHEY_SIMPLEX, 324 | font_size, 325 | (66, 133, 244), 326 | thickness, 327 | cv2.LINE_AA, 328 | ) 329 | 330 | cv2.putText( 331 | frame, 332 | str(v), 333 | (x + key_width, y + offset * (i + 2)), 334 | cv2.FONT_HERSHEY_SIMPLEX, 335 | font_size, 336 | (255, 255, 255), 337 | thickness, 338 | cv2.LINE_AA, 339 | ) 340 | 341 | self._record_frames.append(frame) 342 | -------------------------------------------------------------------------------- /algorithms/sac_agent.py: -------------------------------------------------------------------------------- 1 | # SAC training code reference 2 | # https://github.com/vitchyr/rlkit/blob/master/rlkit/torch/sac/sac.py 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import gym.spaces 10 | 11 | from .base_agent import BaseAgent 12 | from .dataset import ReplayBuffer, RandomSampler, ReplayBufferPerStep 13 | from ..networks import Actor, Critic 14 | from ..utils.info_dict import Info 15 | from ..utils.logger import logger 16 | from ..utils.mpi import mpi_average, mpi_sum 17 | from ..utils.gym_env import spaces_to_shapes 18 | from ..utils.pytorch import ( 19 | optimizer_cuda, 20 | count_parameters, 21 | compute_gradient_norm, 22 | compute_weight_norm, 23 | sync_networks, 24 | sync_grads, 25 | to_tensor, 26 | ) 27 | 28 | 29 | class SACAgent(BaseAgent): 30 | def __init__(self, config, ob_space, ac_space, env_ob_space): 31 | super().__init__(config, ob_space) 32 | 33 | self._ob_space = ob_space 34 | self._ac_space = ac_space 35 | 36 | if config.target_entropy is not None: 37 | self._target_entropy = config.target_entropy 38 | else: 39 | self._target_entropy = -gym.spaces.flatdim(ac_space) 40 | self._log_alpha = torch.tensor( 41 | np.log(config.alpha_init_temperature), 42 | requires_grad=True, 43 | device=config.device, 44 | ) 45 | 46 | # build up networks 47 | self._actor = Actor(config, ob_space, ac_space, config.tanh_policy) 48 | self._critic = Critic(config, ob_space, ac_space) 49 | 50 | # build up target networks 51 | self._critic_target = Critic(config, ob_space, ac_space) 52 | self._network_cuda(config.device) 53 | self._copy_target_network(self._critic_target, self._critic) 54 | self._actor.encoder.copy_conv_weights_from(self._critic.encoder) 55 | 56 | # optimizers 57 | self._alpha_optim = optim.Adam( 58 | [self._log_alpha], 
lr=config.alpha_lr, betas=(0.5, 0.999) 59 | ) 60 | self._actor_optim = optim.Adam( 61 | self._actor.parameters(), lr=config.actor_lr, betas=(0.9, 0.999) 62 | ) 63 | self._critic_optim = optim.Adam( 64 | self._critic.parameters(), lr=config.critic_lr, betas=(0.9, 0.999) 65 | ) 66 | 67 | # per-episode replay buffer 68 | sampler = RandomSampler(image_crop_size=config.encoder_image_size) 69 | buffer_keys = ["ob", "ob_next", "ac", "done", "rew"] 70 | self._buffer = ReplayBuffer( 71 | buffer_keys, config.buffer_size, sampler.sample_func 72 | ) 73 | 74 | # per-step replay buffer 75 | # shapes = { 76 | # "ob": spaces_to_shapes(env_ob_space), 77 | # "ob_next": spaces_to_shapes(env_ob_space), 78 | # "ac": spaces_to_shapes(ac_space), 79 | # "done": [1], 80 | # "rew": [1], 81 | # } 82 | # self._buffer = ReplayBufferPerStep( 83 | # shapes, config.buffer_size, config.encoder_image_size 84 | # ) 85 | 86 | self._update_iter = 0 87 | 88 | self._log_creation() 89 | 90 | def _log_creation(self): 91 | if self._config.is_chef: 92 | logger.info("Creating a SAC agent") 93 | logger.info("The actor has %d parameters", count_parameters(self._actor)) 94 | logger.info("The critic has %d parameters", count_parameters(self._critic)) 95 | 96 | def store_episode(self, rollouts): 97 | self._num_updates = ( 98 | mpi_sum(len(rollouts["ac"])) 99 | // self._config.num_workers 100 | // self._config.actor_update_freq 101 | ) 102 | self._buffer.store_episode(rollouts) 103 | 104 | def state_dict(self): 105 | return { 106 | "log_alpha": self._log_alpha.cpu().detach().numpy(), 107 | "actor_state_dict": self._actor.state_dict(), 108 | "critic_state_dict": self._critic.state_dict(), 109 | "alpha_optim_state_dict": self._alpha_optim.state_dict(), 110 | "actor_optim_state_dict": self._actor_optim.state_dict(), 111 | "critic_optim_state_dict": self._critic_optim.state_dict(), 112 | "ob_norm_state_dict": self._ob_norm.state_dict(), 113 | } 114 | 115 | def load_state_dict(self, ckpt): 116 | if "log_alpha" not in ckpt: 117 | missing = self._actor.load_state_dict( 118 | ckpt["actor_state_dict"], strict=False 119 | ) 120 | for missing_key in missing.missing_keys: 121 | if "stds" not in missing_key: 122 | logger.warn("Missing key", missing_key) 123 | if len(missing.unexpected_keys) > 0: 124 | logger.warn("Unexpected keys", missing.unexpected_keys) 125 | self._network_cuda(self._config.device) 126 | return 127 | 128 | self._log_alpha.data = torch.tensor( 129 | ckpt["log_alpha"], requires_grad=True, device=self._config.device 130 | ) 131 | self._actor.load_state_dict(ckpt["actor_state_dict"]) 132 | self._critic.load_state_dict(ckpt["critic_state_dict"]) 133 | self._copy_target_network(self._critic_target, self._critic) 134 | self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"]) 135 | self._network_cuda(self._config.device) 136 | 137 | self._alpha_optim.load_state_dict(ckpt["alpha_optim_state_dict"]) 138 | self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"]) 139 | self._critic_optim.load_state_dict(ckpt["critic_optim_state_dict"]) 140 | optimizer_cuda(self._alpha_optim, self._config.device) 141 | optimizer_cuda(self._actor_optim, self._config.device) 142 | optimizer_cuda(self._critic_optim, self._config.device) 143 | 144 | def _network_cuda(self, device): 145 | self._actor.to(device) 146 | self._critic.to(device) 147 | self._critic_target.to(device) 148 | 149 | def sync_networks(self): 150 | sync_networks(self._actor) 151 | sync_networks(self._critic) 152 | 153 | def train(self): 154 | train_info = Info() 155 | 156 | 
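# NOTE: this overrides the update count computed in store_episode() so that each
# train() call performs a single gradient update on one sampled batch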
self._num_updates = 1 157 | for _ in range(self._num_updates): 158 | transitions = self._buffer.sample(self._config.batch_size) 159 | _train_info = self._update_network(transitions) 160 | train_info.add(_train_info) 161 | 162 | # slow! 163 | # train_info.add( 164 | # { 165 | # "actor_grad_norm": compute_gradient_norm(self._actor), 166 | # "actor_weight_norm": compute_weight_norm(self._actor), 167 | # "critic_grad_norm": compute_gradient_norm(self._critic), 168 | # "critic_weight_norm": compute_weight_norm(self._critic), 169 | # } 170 | # ) 171 | return mpi_average(train_info.get_dict(only_scalar=True)) 172 | 173 | def _update_actor_and_alpha(self, o): 174 | info = Info() 175 | 176 | actions_real, _, log_pi, _ = self._actor.act( 177 | o, return_log_prob=True, detach_conv=True 178 | ) 179 | alpha = self._log_alpha.exp() 180 | 181 | # the actor loss 182 | entropy_loss = (alpha.detach() * log_pi).mean() 183 | actor_loss = -torch.min(*self._critic(o, actions_real, detach_conv=True)).mean() 184 | info["entropy_alpha"] = alpha.cpu().item() 185 | info["entropy_loss"] = entropy_loss.cpu().item() 186 | info["actor_loss"] = actor_loss.cpu().item() 187 | actor_loss += entropy_loss 188 | 189 | # update the actor 190 | self._actor_optim.zero_grad() 191 | actor_loss.backward() 192 | sync_grads(self._actor) 193 | self._actor_optim.step() 194 | 195 | # update alpha 196 | alpha_loss = -(alpha * (log_pi + self._target_entropy).detach()).mean() 197 | self._alpha_optim.zero_grad() 198 | alpha_loss.backward() 199 | self._alpha_optim.step() 200 | 201 | return info 202 | 203 | def _update_critic(self, o, ac, rew, o_next, done): 204 | info = Info() 205 | 206 | # calculate the target Q value function 207 | with torch.no_grad(): 208 | alpha = self._log_alpha.exp().detach() 209 | actions_next, _, log_pi_next, _ = self._actor.act( 210 | o_next, return_log_prob=True 211 | ) 212 | q_next_value1, q_next_value2 = self._critic_target(o_next, actions_next) 213 | q_next_value = torch.min(q_next_value1, q_next_value2) - alpha * log_pi_next 214 | target_q_value = ( 215 | rew * self._config.reward_scale 216 | + (1 - done) * self._config.rl_discount_factor * q_next_value 217 | ) 218 | 219 | # the q loss 220 | real_q_value1, real_q_value2 = self._critic(o, ac) 221 | critic1_loss = F.mse_loss(target_q_value, real_q_value1) 222 | critic2_loss = F.mse_loss(target_q_value, real_q_value2) 223 | critic_loss = critic1_loss + critic2_loss 224 | 225 | # update the critic 226 | self._critic_optim.zero_grad() 227 | critic_loss.backward() 228 | sync_grads(self._critic) 229 | self._critic_optim.step() 230 | 231 | info["min_target_q"] = target_q_value.min().cpu().item() 232 | info["target_q"] = target_q_value.mean().cpu().item() 233 | info["min_real1_q"] = real_q_value1.min().cpu().item() 234 | info["min_real2_q"] = real_q_value2.min().cpu().item() 235 | info["real1_q"] = real_q_value1.mean().cpu().item() 236 | info["real2_q"] = real_q_value2.mean().cpu().item() 237 | info["critic1_loss"] = critic1_loss.cpu().item() 238 | info["critic2_loss"] = critic2_loss.cpu().item() 239 | 240 | return info 241 | 242 | def _update_network(self, transitions): 243 | info = Info() 244 | 245 | # pre-process observations 246 | o, o_next = transitions["ob"], transitions["ob_next"] 247 | o = self.normalize(o) 248 | o_next = self.normalize(o_next) 249 | 250 | bs = len(transitions["done"]) 251 | _to_tensor = lambda x: to_tensor(x, self._config.device) 252 | o = _to_tensor(o) 253 | o_next = _to_tensor(o_next) 254 | ac = _to_tensor(transitions["ac"]) 255 | done = 
_to_tensor(transitions["done"]).reshape(bs, 1).float() 256 | rew = _to_tensor(transitions["rew"]).reshape(bs, 1) 257 | 258 | self._update_iter += 1 259 | 260 | critic_train_info = self._update_critic(o, ac, rew, o_next, done) 261 | info.add(critic_train_info) 262 | 263 | if self._update_iter % self._config.actor_update_freq == 0: 264 | actor_train_info = self._update_actor_and_alpha(o) 265 | info.add(actor_train_info) 266 | 267 | if self._update_iter % self._config.critic_target_update_freq == 0: 268 | for i, fc in enumerate(self._critic.fcs): 269 | self._soft_update_target_network( 270 | self._critic_target.fcs[i], 271 | fc, 272 | self._config.critic_soft_update_weight, 273 | ) 274 | self._soft_update_target_network( 275 | self._critic_target.encoder, 276 | self._critic.encoder, 277 | self._config.encoder_soft_update_weight, 278 | ) 279 | 280 | return info.get_dict(only_scalar=True) 281 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | """ Define parameters for algorithms. """ 2 | 3 | import argparse 4 | 5 | 6 | def str2bool(v): 7 | return v.lower() == "true" 8 | 9 | 10 | def str2intlist(value): 11 | if not value: 12 | return value 13 | else: 14 | return [int(num) for num in value.split(",")] 15 | 16 | 17 | def str2list(value): 18 | if not value: 19 | return value 20 | else: 21 | return [num for num in value.split(",")] 22 | 23 | 24 | def create_parser(): 25 | """ 26 | Creates the argparser. Use this to add additional arguments 27 | to the parser later. 28 | """ 29 | parser = argparse.ArgumentParser( 30 | "Robot Learning Algorithms", 31 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 32 | ) 33 | 34 | # environment 35 | parser.add_argument( 36 | "--env", type=str, default="Hopper-v2", help="environment name", 37 | ) 38 | parser.add_argument("--seed", type=int, default=123) 39 | 40 | add_method_arguments(parser) 41 | 42 | return parser 43 | 44 | 45 | def add_method_arguments(parser): 46 | # algorithm 47 | parser.add_argument( 48 | "--algo", 49 | type=str, 50 | default="sac", 51 | choices=["sac", "ppo", "ddpg", "td3", "bc", "gail", "dac",], 52 | ) 53 | 54 | # training 55 | parser.add_argument("--is_train", type=str2bool, default=True) 56 | parser.add_argument("--resume", type=str2bool, default=True) 57 | parser.add_argument("--init_ckpt_path", type=str, default=None) 58 | parser.add_argument("--gpu", type=int, default=None) 59 | 60 | # evaluation 61 | parser.add_argument("--ckpt_num", type=int, default=None) 62 | parser.add_argument( 63 | "--num_eval", type=int, default=1, help="number of episodes for evaluation" 64 | ) 65 | 66 | # environment 67 | try: 68 | parser.add_argument("--screen_width", type=int, default=480) 69 | parser.add_argument("--screen_height", type=int, default=480) 70 | except: 71 | pass 72 | parser.add_argument("--action_repeat", type=int, default=1) 73 | 74 | # misc 75 | parser.add_argument("--run_prefix", type=str, default=None) 76 | parser.add_argument("--notes", type=str, default="") 77 | 78 | # log 79 | parser.add_argument("--average_info", type=str2bool, default=True) 80 | parser.add_argument("--log_interval", type=int, default=1) 81 | parser.add_argument("--evaluate_interval", type=int, default=10) 82 | parser.add_argument("--ckpt_interval", type=int, default=200) 83 | parser.add_argument("--log_root_dir", type=str, default="log") 84 | parser.add_argument( 85 | "--wandb", 86 | type=str2bool, 87 | default=False, 88 | 
help="set it True if you want to use wandb", 89 | ) 90 | parser.add_argument("--wandb_entity", type=str, default="clvr") 91 | parser.add_argument("--wandb_project", type=str, default="robot-learning") 92 | parser.add_argument("--record_video", type=str2bool, default=True) 93 | parser.add_argument("--record_video_caption", type=str2bool, default=True) 94 | try: 95 | parser.add_argument("--record_demo", type=str2bool, default=False) 96 | except: 97 | pass 98 | 99 | # observation normalization 100 | parser.add_argument("--ob_norm", type=str2bool, default=True) 101 | parser.add_argument("--max_ob_norm_step", type=int, default=int(1e8)) 102 | parser.add_argument( 103 | "--clip_obs", type=float, default=200, help="the clip range of observation" 104 | ) 105 | parser.add_argument( 106 | "--clip_range", 107 | type=float, 108 | default=10, 109 | help="the clip range after normalization of observation", 110 | ) 111 | 112 | parser.add_argument("--max_global_step", type=int, default=int(1e6)) 113 | parser.add_argument( 114 | "--batch_size", type=int, default=128, help="the sample batch size" 115 | ) 116 | 117 | add_policy_arguments(parser) 118 | 119 | # arguments specific to algorithms 120 | args, unparsed = parser.parse_known_args() 121 | if args.algo == "sac": 122 | add_sac_arguments(parser) 123 | 124 | elif args.algo == "ddpg": 125 | add_ddpg_arguments(parser) 126 | 127 | elif args.algo == "td3": 128 | add_td3_arguments(parser) 129 | 130 | elif args.algo == "ppo": 131 | add_ppo_arguments(parser) 132 | 133 | elif args.algo == "bc": 134 | add_il_arguments(parser) 135 | add_bc_arguments(parser) 136 | 137 | elif args.algo in ["gail", "gaifo", "gaifo-s"]: 138 | add_il_arguments(parser) 139 | add_gail_arguments(parser) 140 | 141 | elif args.algo in ["dac"]: 142 | add_il_arguments(parser) 143 | add_dac_arguments(parser) 144 | 145 | if args.algo in ["gail", "gaifo", "gaifo-s", "dac"]: 146 | args, unparsed = parser.parse_known_args() 147 | 148 | if args.gail_rl_algo == "ppo": 149 | add_ppo_arguments(parser) 150 | 151 | elif args.gail_rl_algo == "sac": 152 | add_sac_arguments(parser) 153 | 154 | elif args.gail_rl_algo == "td3": 155 | add_td3_arguments(parser) 156 | 157 | return parser 158 | 159 | 160 | def add_policy_arguments(parser): 161 | # network 162 | parser.add_argument("--policy_mlp_dim", type=str2intlist, default=[256, 256]) 163 | parser.add_argument("--critic_mlp_dim", type=str2intlist, default=[256, 256]) 164 | parser.add_argument("--critic_ensemble", type=int, default=1) 165 | parser.add_argument( 166 | "--policy_activation", type=str, default="relu", choices=["relu", "elu", "tanh"] 167 | ) 168 | parser.add_argument("--tanh_policy", type=str2bool, default=True) 169 | parser.add_argument("--gaussian_policy", type=str2bool, default=True) 170 | 171 | # encoder 172 | parser.add_argument( 173 | "--encoder_type", type=str, default="mlp", choices=["mlp", "cnn"] 174 | ) 175 | parser.add_argument("--encoder_image_size", type=int, default=84) 176 | parser.add_argument("--random_crop", type=str2bool, default=False) 177 | parser.add_argument("--encoder_conv_dim", type=int, default=32) 178 | parser.add_argument("--encoder_kernel_size", type=str2intlist, default=[3, 3, 3, 3]) 179 | parser.add_argument("--encoder_stride", type=str2intlist, default=[2, 1, 1, 1]) 180 | parser.add_argument("--encoder_conv_output_dim", type=int, default=50) 181 | parser.add_argument("--encoder_soft_update_weight", type=float, default=0.95) 182 | args, unparsed = parser.parse_known_args() 183 | if args.encoder_type == "cnn": 184 | 
parser.set_defaults(screen_width=100, screen_height=100) 185 | parser.set_defaults(policy_mlp_dim=[1024, 1024]) 186 | parser.set_defaults(critic_mlp_dim=[1024, 1024]) 187 | parser.add_argument("--asym_ac", type=str2bool, default=False) 188 | 189 | # actor-critic 190 | parser.add_argument( 191 | "--actor_lr", type=float, default=3e-4, help="the learning rate of the actor" 192 | ) 193 | parser.add_argument( 194 | "--critic_lr", type=float, default=3e-4, help="the learning rate of the critic" 195 | ) 196 | parser.add_argument( 197 | "--critic_soft_update_weight", 198 | type=float, 199 | default=0.995, 200 | help="the average coefficient", 201 | ) 202 | 203 | parser.add_argument("--log_std_min", type=float, default=-10) 204 | parser.add_argument("--log_std_max", type=float, default=2) 205 | 206 | # absorbing state 207 | parser.add_argument("--absorbing_state", type=str2bool, default=False) 208 | 209 | 210 | def add_rl_arguments(parser): 211 | parser.add_argument( 212 | "--rl_discount_factor", type=float, default=0.99, help="the discount factor" 213 | ) 214 | parser.add_argument("--warm_up_steps", type=int, default=0) 215 | 216 | 217 | def add_on_policy_arguments(parser): 218 | parser.add_argument("--rollout_length", type=int, default=2000) 219 | parser.add_argument("--gae_lambda", type=float, default=0.95) 220 | parser.add_argument("--advantage_norm", type=str2bool, default=True) 221 | 222 | 223 | def add_off_policy_arguments(parser): 224 | parser.add_argument( 225 | "--buffer_size", type=int, default=int(1e6), help="the size of the buffer" 226 | ) 227 | parser.set_defaults(warm_up_steps=1000) 228 | 229 | 230 | def add_sac_arguments(parser): 231 | add_rl_arguments(parser) 232 | add_off_policy_arguments(parser) 233 | 234 | parser.add_argument("--reward_scale", type=float, default=1.0, help="reward scale") 235 | parser.add_argument("--actor_update_freq", type=int, default=2) 236 | parser.add_argument("--critic_target_update_freq", type=int, default=2) 237 | parser.add_argument("--target_entropy", type=float, default=None) 238 | parser.add_argument("--alpha_init_temperature", type=float, default=0.1) 239 | parser.add_argument( 240 | "--alpha_lr", type=float, default=1e-4, help="the learning rate of the actor" 241 | ) 242 | parser.set_defaults(actor_lr=3e-4) 243 | parser.set_defaults(critic_lr=3e-4) 244 | parser.set_defaults(evaluate_interval=5000) 245 | parser.set_defaults(ckpt_interval=10000) 246 | parser.set_defaults(log_interval=500) 247 | parser.set_defaults(critic_soft_update_weight=0.99) 248 | parser.set_defaults(buffer_size=100000) 249 | parser.set_defaults(critic_ensemble=2) 250 | parser.set_defaults(ob_norm=False) 251 | 252 | 253 | def add_ppo_arguments(parser): 254 | add_rl_arguments(parser) 255 | add_on_policy_arguments(parser) 256 | 257 | parser.add_argument("--ppo_clip", type=float, default=0.2) 258 | parser.add_argument("--value_loss_coeff", type=float, default=0.5) 259 | parser.add_argument("--action_loss_coeff", type=float, default=1.0) 260 | parser.add_argument("--entropy_loss_coeff", type=float, default=1e-4) 261 | 262 | parser.add_argument("--ppo_epoch", type=int, default=5) 263 | parser.add_argument("--max_grad_norm", type=float, default=None) 264 | parser.add_argument("--actor_update_freq", type=int, default=1) 265 | parser.set_defaults(ob_norm=True) 266 | parser.set_defaults(evaluate_interval=20) 267 | parser.set_defaults(ckpt_interval=20) 268 | 269 | 270 | def add_ddpg_arguments(parser): 271 | add_rl_arguments(parser) 272 | add_off_policy_arguments(parser) 273 | 274 | 
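# DDPG/TD3-specific schedules: delayed and less frequent actor/target updates,
# soft target-network averaging, gradient clipping, and exploration noise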
parser.add_argument("--actor_update_delay", type=int, default=2000) 275 | parser.add_argument("--actor_update_freq", type=int, default=2) 276 | parser.add_argument("--actor_target_update_freq", type=int, default=2) 277 | parser.add_argument("--critic_target_update_freq", type=int, default=2) 278 | parser.add_argument( 279 | "--actor_soft_update_weight", 280 | type=float, 281 | default=0.995, 282 | help="the average coefficient", 283 | ) 284 | parser.set_defaults(critic_soft_update_weight=0.995) 285 | parser.add_argument("--max_grad_norm", type=float, default=40.0) 286 | 287 | # epsilon greedy 288 | parser.add_argument("--epsilon_greedy", type=str2bool, default=False) 289 | parser.add_argument("--epsilon_greedy_eps", type=float, default=0.3) 290 | parser.add_argument("--policy_exploration_noise", type=float, default=0.1) 291 | 292 | parser.set_defaults(gaussian_policy=False) 293 | parser.set_defaults(ob_norm=False) 294 | 295 | parser.set_defaults(evaluate_interval=10000) 296 | parser.set_defaults(ckpt_interval=50000) 297 | parser.set_defaults(log_interval=1000) 298 | 299 | 300 | def add_td3_arguments(parser): 301 | add_ddpg_arguments(parser) 302 | 303 | parser.set_defaults(critic_ensemble=2) 304 | 305 | parser.add_argument("--policy_noise", type=float, default=0.2) 306 | parser.add_argument("--policy_noise_clip", type=float, default=0.5) 307 | 308 | 309 | def add_il_arguments(parser): 310 | parser.add_argument("--demo_path", type=str, default=None, help="path to demos") 311 | parser.add_argument("--demo_low_level", type=str2bool, default=False, help="use low level actions for training") 312 | parser.add_argument( 313 | "--demo_subsample_interval", 314 | type=int, 315 | default=1, 316 | # default=20, # used in GAIL 317 | help="subsample interval of expert transitions", 318 | ) 319 | parser.add_argument( 320 | "--demo_sample_range_start", type=float, default=0.0, help="sample demo range" 321 | ) 322 | parser.add_argument( 323 | "--demo_sample_range_end", type=float, default=1.0, help="sample demo range" 324 | ) 325 | 326 | 327 | def add_bc_arguments(parser): 328 | parser.set_defaults(gaussian_policy=False) 329 | parser.set_defaults(max_global_step=100) 330 | parser.set_defaults(evaluate_interval=100) 331 | parser.set_defaults(ob_norm=False) 332 | parser.add_argument( 333 | "--bc_lr", type=float, default=1e-3, help="learning rate for bc" 334 | ) 335 | parser.add_argument( 336 | "--val_split", 337 | type=float, 338 | default=0, 339 | help="how much of dataset to leave for validation set", 340 | ) 341 | 342 | 343 | def add_gail_arguments(parser): 344 | parser.add_argument("--gail_entropy_loss_coeff", type=float, default=0.0) 345 | parser.add_argument( 346 | "--gail_reward", type=str, default="vanilla", choices=["vanilla", "gan", "d"] 347 | ) 348 | parser.add_argument("--discriminator_lr", type=float, default=1e-4) 349 | parser.add_argument("--discriminator_mlp_dim", type=str2intlist, default=[256, 256]) 350 | parser.add_argument( 351 | "--discriminator_activation", 352 | type=str, 353 | default="tanh", 354 | choices=["relu", "elu", "tanh"], 355 | ) 356 | parser.add_argument("--discriminator_update_freq", type=int, default=4) 357 | parser.add_argument("--gail_no_action", type=str2bool, default=False) 358 | parser.add_argument("--gail_env_reward", type=float, default=0.0) 359 | parser.add_argument("--gail_grad_penalty_coeff", type=float, default=10.0) 360 | 361 | parser.add_argument( 362 | "--gail_rl_algo", type=str, default="ppo", choices=["ppo", "sac", "td3"] 363 | ) 364 | 365 | 366 | def 
add_dac_arguments(parser): 367 | add_gail_arguments(parser) 368 | parser.set_defaults(gail_rl_algo="td3") 369 | parser.set_defaults(absorbing_state=True) 370 | parser.set_defaults(warm_up_steps=1000) 371 | parser.set_defaults(actor_lr=1e-3) 372 | parser.set_defaults(actor_update_delay=1000) 373 | parser.set_defaults(batch_size=100) 374 | parser.set_defaults(gail_reward="d") 375 | 376 | 377 | def argparser(): 378 | """ Directly parses the arguments. """ 379 | parser = create_parser() 380 | args, unparsed = parser.parse_known_args() 381 | 382 | return args, unparsed 383 | -------------------------------------------------------------------------------- /environments/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define all environments and provide helper functions to load environments. 3 | """ 4 | 5 | # OpenAI gym interface 6 | import gym 7 | import dmc2gym 8 | 9 | from ..utils.logger import logger 10 | from ..utils.gym_env import DictWrapper, FrameStackWrapper, GymWrapper, AbsorbingWrapper 11 | from ..utils.subproc_vec_env import SubprocVecEnv 12 | 13 | 14 | REGISTERED_ENVS = {} 15 | 16 | 17 | def register_env(target_class): 18 | REGISTERED_ENVS[target_class.__name__] = target_class 19 | 20 | 21 | def get_env(name): 22 | """ 23 | Gets the environment class given @name. 24 | """ 25 | if name not in REGISTERED_ENVS: 26 | logger.warn( 27 | "Unknown environment name: {}\nAvailable environments: {}".format( 28 | name, ", ".join(REGISTERED_ENVS) 29 | ) 30 | ) 31 | logger.warn("Instead, query gym environments") 32 | return None 33 | return REGISTERED_ENVS[name] 34 | 35 | 36 | def make_env(name, config=None): 37 | """ 38 | Creates a new environment instance with @name and @config. 39 | """ 40 | # get default config if not provided 41 | if config is None: 42 | from ..config import argparser 43 | 44 | config, unparsed = argparser() 45 | 46 | env = get_env(name) 47 | if env is None: 48 | return get_gym_env(name, config) 49 | 50 | return env(config) 51 | 52 | 53 | def get_gym_env(env_id, config): 54 | if env_id.startswith("dm"): 55 | # environment name of dm_control: dm.DOMAIN_NAME.TASK_NAME 56 | _, domain_name, task_name = env_id.split(".") 57 | env = dmc2gym.make( 58 | domain_name=domain_name, 59 | task_name=task_name, 60 | seed=config.seed, 61 | visualize_reward=False, 62 | from_pixels=(config.encoder_type == "cnn"), 63 | height=config.screen_height, 64 | width=config.screen_width, 65 | frame_skip=config.action_repeat, 66 | channels_first=True, 67 | ) 68 | else: 69 | env_kwargs = config.__dict__.copy() 70 | try: 71 | env = gym.make(env_id, **env_kwargs) 72 | except Exception as e: 73 | logger.warn("Failed to launch an environment with config.") 74 | logger.warn(e) 75 | logger.warn("Launch an environment without config.") 76 | env = gym.make(env_id) 77 | env.seed(config.seed) 78 | env = GymWrapper( 79 | env=env, 80 | from_pixels=(config.encoder_type == "cnn"), 81 | height=config.screen_height, 82 | width=config.screen_width, 83 | channels_first=True, 84 | frame_skip=config.action_repeat, 85 | return_state=(config.encoder_type == "cnn" and config.asym_ac) 86 | ) 87 | 88 | env = DictWrapper(env, return_state=(config.encoder_type == "cnn" and config.asym_ac)) 89 | if config.encoder_type == "cnn": 90 | env = FrameStackWrapper(env, frame_stack=3, return_state=(config.encoder_type == "cnn" and config.asym_ac)) 91 | if config.absorbing_state: 92 | env = AbsorbingWrapper(env) 93 | 94 | return env 95 | 96 | 97 | def make_vec_env(env_id, num_env, 
config=None, env_kwargs=None): 98 | """ 99 | Creates a wrapped SubprocVecEnv using OpenAI gym interface. 100 | Unity app will use the port number from @config.port to (@config.port + @num_env - 1). 101 | 102 | Code modified based on 103 | https://github.com/openai/baselines/blob/master/baselines/common/cmd_util.py 104 | 105 | Args: 106 | env_id: environment id registered in in `env/__init__.py`. 107 | num_env: number of environments to launch. 108 | config: general configuration for the environment. 109 | """ 110 | env_kwargs = env_kwargs or {} 111 | 112 | if config is not None: 113 | for key, value in config.__dict__.items(): 114 | env_kwargs[key] = value 115 | 116 | def make_thunk(rank): 117 | new_env_kwargs = env_kwargs.copy() 118 | if "port" in new_env_kwargs: 119 | new_env_kwargs["port"] = env_kwargs["port"] + rank 120 | new_env_kwargs["seed"] = env_kwargs["seed"] + rank 121 | return lambda: get_gym_env(env_id, new_env_kwargs) 122 | 123 | return SubprocVecEnv([make_thunk(i) for i in range(num_env)]) 124 | 125 | 126 | class EnvMeta(type): 127 | """ Meta class for registering environments. """ 128 | 129 | def __new__(meta, name, bases, class_dict): 130 | cls = super().__new__(meta, name, bases, class_dict) 131 | 132 | # List all environments that should not be registered here. 133 | _unregistered_envs = ["FurnitureEnv"] 134 | 135 | if cls.__name__ not in _unregistered_envs: 136 | register_env(cls) 137 | return cls 138 | -------------------------------------------------------------------------------- /environments/test_env.py: -------------------------------------------------------------------------------- 1 | from . import make_env 2 | from ..config import argparser 3 | 4 | 5 | config, unparsed = argparser() 6 | 7 | env = make_env(config.env, config) 8 | 9 | ob = env.reset() 10 | 11 | while True: 12 | ob, reward, done, info = env.step(env.action_space.sample()) 13 | print(reward) 14 | if done: 15 | break 16 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ Launch RL/IL training and evaluation. """ 2 | 3 | import sys 4 | import signal 5 | import os 6 | import json 7 | import logging 8 | 9 | import numpy as np 10 | import torch 11 | from six.moves import shlex_quote 12 | from mpi4py import MPI 13 | 14 | from .config import create_parser 15 | from .trainer import Trainer 16 | from .utils.logger import logger 17 | from .utils.mpi import mpi_sync 18 | 19 | 20 | np.set_printoptions(precision=3) 21 | np.set_printoptions(suppress=True) 22 | 23 | 24 | def run(parser=None): 25 | """ Runs Trainer. 
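Parses the config, assigns an MPI rank, seed, and device to each worker, creates
log/video/demo directories on the chef process, and then runs Trainer.train() or
Trainer.evaluate() depending on --is_train.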
""" 26 | if parser is None: 27 | parser = create_parser() 28 | 29 | config, unparsed = parser.parse_known_args() 30 | if len(unparsed): 31 | logger.error("Unparsed argument is detected:\n%s", unparsed) 32 | return 33 | 34 | rank = MPI.COMM_WORLD.Get_rank() 35 | config.rank = rank 36 | config.is_chef = rank == 0 37 | config.num_workers = MPI.COMM_WORLD.Get_size() 38 | set_log_path(config) 39 | 40 | config.seed = config.seed + rank 41 | if hasattr(config, "port"): 42 | config.port = config.port + rank * 2 # training env + evaluation env 43 | 44 | if config.is_chef: 45 | logger.warn("Run a base worker.") 46 | make_log_files(config) 47 | else: 48 | logger.warn("Run worker %d and disable logger.", config.rank) 49 | logger.setLevel(logging.CRITICAL) 50 | 51 | # syncronize all processes 52 | mpi_sync() 53 | 54 | def shutdown(signal, frame): 55 | logger.warn("Received signal %s: exiting", signal) 56 | sys.exit(128 + signal) 57 | 58 | signal.signal(signal.SIGHUP, shutdown) 59 | signal.signal(signal.SIGINT, shutdown) 60 | signal.signal(signal.SIGTERM, shutdown) 61 | 62 | # set global seed 63 | np.random.seed(config.seed) 64 | torch.manual_seed(config.seed) 65 | torch.cuda.manual_seed_all(config.seed) 66 | 67 | if config.gpu is not None: 68 | os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(config.gpu) 69 | assert torch.cuda.is_available() 70 | config.device = torch.device("cuda") 71 | else: 72 | config.device = torch.device("cpu") 73 | 74 | # build a trainer 75 | trainer = Trainer(config) 76 | if config.is_train: 77 | trainer.train() 78 | logger.info("Finish training") 79 | else: 80 | trainer.evaluate() 81 | logger.info("Finish evaluating") 82 | 83 | 84 | def set_log_path(config): 85 | """ 86 | Sets paths to log directories. 87 | """ 88 | config.run_name = "{}.{}.{}.{}".format( 89 | config.env, config.algo, config.run_prefix, config.seed 90 | ) 91 | config.log_dir = os.path.join(config.log_root_dir, config.run_name) 92 | config.record_dir = os.path.join(config.log_dir, "video") 93 | config.demo_dir = os.path.join(config.log_dir, "demo") 94 | 95 | 96 | def make_log_files(config): 97 | """ 98 | Sets up log directories and saves git diff and command line. 
99 | """ 100 | logger.info("Create log directory: %s", config.log_dir) 101 | os.makedirs(config.log_dir, exist_ok=config.resume or not config.is_train) 102 | 103 | logger.info("Create video directory: %s", config.record_dir) 104 | os.makedirs(config.record_dir, exist_ok=config.resume or not config.is_train) 105 | 106 | logger.info("Create demo directory: %s", config.demo_dir) 107 | os.makedirs(config.demo_dir, exist_ok=config.resume or not config.is_train) 108 | 109 | if config.is_train: 110 | # log git diff 111 | git_path = os.path.join(config.log_dir, "git.txt") 112 | cmd_path = os.path.join(config.log_dir, "cmd.sh") 113 | cmds = [ 114 | "echo `git rev-parse HEAD` >> {}".format(git_path), 115 | "git diff >> {}".format(git_path), 116 | "echo 'python -m rl {}' >> {}".format( 117 | " ".join([shlex_quote(arg) for arg in sys.argv[1:]]), cmd_path 118 | ), 119 | ] 120 | os.system("\n".join(cmds)) 121 | 122 | # log config 123 | param_path = os.path.join(config.log_dir, "params.json") 124 | logger.info("Store parameters in %s", param_path) 125 | with open(param_path, "w") as fp: 126 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 127 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor_critic import Actor, Critic 2 | 3 | 4 | def get_actor_critic(encoder_type, algo): 5 | actor = critic = None 6 | if encoder_type == "mlp": 7 | from .mlp_actor_critic import MlpActor, MlpCritic, NoisyMlpActor 8 | 9 | if algo == "ddpg": # add exploratory noise to actor 10 | actor = NoisyMlpActor 11 | elif algo in ["bc"]: 12 | return MlpActor, None 13 | else: 14 | actor = MlpActor 15 | return actor, MlpCritic 16 | 17 | elif encoder_type == "cnn": 18 | from .cnn_actor_critic import CnnActor, CnnCritic 19 | 20 | if algo in ["bc"]: 21 | return CnnActor, None 22 | else: 23 | actor = CnnActor 24 | return actor, CnnCritic 25 | 26 | else: 27 | raise ValueError("--encoder_type %s is not supported." 
% encoder_type) 28 | -------------------------------------------------------------------------------- /networks/actor_critic.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import gym.spaces 8 | 9 | from .distributions import ( 10 | FixedCategorical, 11 | FixedNormal, 12 | Identity, 13 | MixedDistribution, 14 | ) 15 | from .utils import MLP, flatten_ac 16 | from .encoder import Encoder 17 | from ..utils.pytorch import to_tensor 18 | from ..utils.logger import logger 19 | 20 | 21 | class Actor(nn.Module): 22 | def __init__(self, config, ob_space, ac_space, tanh_policy, encoder=None): 23 | super().__init__() 24 | self._config = config 25 | self._ac_space = ac_space 26 | self._activation_fn = getattr(F, config.policy_activation) 27 | self._tanh = tanh_policy 28 | self._gaussian = config.gaussian_policy 29 | 30 | if encoder: 31 | self.encoder = encoder 32 | else: 33 | self.encoder = Encoder(config, ob_space) 34 | 35 | self.fc = MLP( 36 | config, self.encoder.output_dim, config.policy_mlp_dim[-1], config.policy_mlp_dim[:-1] 37 | ) 38 | 39 | self.fcs = nn.ModuleDict() 40 | self._dists = {} 41 | for k, v in ac_space.spaces.items(): 42 | if isinstance(v, gym.spaces.Box): # and self._gaussian: # for convenience to transfer bc policy 43 | self.fcs.update( 44 | {k: MLP(config, config.policy_mlp_dim[-1], gym.spaces.flatdim(v) * 2)} 45 | ) 46 | else: 47 | self.fcs.update( 48 | {k: MLP(config, config.policy_mlp_dim[-1], gym.spaces.flatdim(v))} 49 | ) 50 | 51 | if isinstance(v, gym.spaces.Box): 52 | if self._gaussian: 53 | self._dists[k] = lambda m, s: FixedNormal(m, s) 54 | else: 55 | self._dists[k] = lambda m, s: Identity(m) 56 | else: 57 | self._dists[k] = lambda m, s: FixedCategorical(logits=m) 58 | 59 | @property 60 | def info(self): 61 | return {} 62 | 63 | def forward(self, ob: dict, detach_conv=False): 64 | out = self.encoder(ob, detach_conv=detach_conv) 65 | out = self._activation_fn(self.fc(out)) 66 | 67 | means, stds = OrderedDict(), OrderedDict() 68 | for k, v in self._ac_space.spaces.items(): 69 | if isinstance(v, gym.spaces.Box): # and self._gaussian: 70 | mean, log_std = self.fcs[k](out).chunk(2, dim=-1) 71 | log_std_min, log_std_max = self._config.log_std_min , self._config.log_std_max 72 | log_std = torch.tanh(log_std) 73 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1) 74 | std = log_std.exp() 75 | else: 76 | mean, std = self.fcs[k](out), None 77 | 78 | means[k] = mean 79 | stds[k] = std 80 | 81 | return means, stds 82 | 83 | def act(self, ob, deterministic=False, activations=None, return_log_prob=False, detach_conv=False): 84 | """ Samples action for rollout. """ 85 | means, stds = self.forward(ob, detach_conv=detach_conv) 86 | 87 | dists = OrderedDict() 88 | for k in means.keys(): 89 | dists[k] = self._dists[k](means[k], stds[k]) 90 | 91 | actions = OrderedDict() 92 | mixed_dist = MixedDistribution(dists) 93 | if activations is None: 94 | if deterministic: 95 | activations = mixed_dist.mode() 96 | else: 97 | activations = mixed_dist.rsample() 98 | 99 | if return_log_prob: 100 | log_probs = mixed_dist.log_probs(activations) 101 | 102 | for k, v in self._ac_space.spaces.items(): 103 | z = activations[k] 104 | if self._tanh and isinstance(v, gym.spaces.Box): 105 | action = torch.tanh(z) 106 | if return_log_prob: 107 | # follow the Appendix C. 
Enforcing Action Bounds 108 | log_det_jacobian = 2 * (np.log(2.0) - z - F.softplus(-2.0 * z)).sum( 109 | dim=-1, keepdim=True 110 | ) 111 | log_probs[k] = log_probs[k] - log_det_jacobian 112 | else: 113 | action = z 114 | 115 | actions[k] = action 116 | 117 | if return_log_prob: 118 | log_probs = torch.cat(list(log_probs.values()), -1).sum(-1, keepdim=True) 119 | entropy = mixed_dist.entropy() 120 | else: 121 | log_probs = None 122 | entropy = None 123 | 124 | return actions, activations, log_probs, entropy 125 | 126 | 127 | class Critic(nn.Module): 128 | def __init__(self, config, ob_space, ac_space=None, encoder=None): 129 | super().__init__() 130 | self._config = config 131 | 132 | if encoder: 133 | self.encoder = encoder 134 | else: 135 | self.encoder = Encoder(config, ob_space) 136 | 137 | input_dim = self.encoder.output_dim 138 | if ac_space is not None: 139 | input_dim += gym.spaces.flatdim(ac_space) 140 | 141 | self.fcs = nn.ModuleList() 142 | 143 | for _ in range(config.critic_ensemble): 144 | self.fcs.append(MLP(config, input_dim, 1, config.critic_mlp_dim)) 145 | 146 | def forward(self, ob, ac=None, detach_conv=False): 147 | out = self.encoder(ob, detach_conv=detach_conv) 148 | 149 | if ac is not None: 150 | out = torch.cat([out, flatten_ac(ac)], dim=-1) 151 | assert len(out.shape) == 2 152 | 153 | out = [fc(out) for fc in self.fcs] 154 | if len(out) == 1: 155 | return out[0] 156 | return out 157 | -------------------------------------------------------------------------------- /networks/discriminator.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import gym 8 | 9 | from .utils import MLP 10 | from ..utils.pytorch import to_tensor 11 | 12 | 13 | class Discriminator(nn.Module): 14 | def __init__(self, config, ob_space, ac_space=None): 15 | super().__init__() 16 | self._config = config 17 | self._no_action = ac_space == None 18 | 19 | input_dim = gym.spaces.flatdim(ob_space) 20 | if not self._no_action: 21 | input_dim += gym.spaces.flatdim(ac_space) 22 | 23 | self.fc = MLP( 24 | config, 25 | input_dim, 26 | 1, 27 | config.discriminator_mlp_dim, 28 | getattr(F, config.discriminator_activation), 29 | ) 30 | 31 | def forward(self, ob, ac=None): 32 | # flatten observation 33 | if isinstance(ob, OrderedDict) or isinstance(ob, dict): 34 | ob = list(ob.values()) 35 | if len(ob[0].shape) == 1: 36 | ob = [x.unsqueeze(0) for x in ob] 37 | ob = torch.cat(ob, dim=-1) 38 | 39 | if ac is not None: 40 | # flatten action 41 | if isinstance(ac, OrderedDict): 42 | ac = list(ac.values()) 43 | if len(ac[0].shape) == 1: 44 | ac = [x.unsqueeze(0) for x in ac] 45 | ac = torch.cat(ac, dim=-1) 46 | ob = torch.cat([ob, ac], dim=-1) 47 | 48 | out = self.fc(ob) 49 | return out 50 | -------------------------------------------------------------------------------- /networks/distributions.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributions 7 | 8 | 9 | # Categorical 10 | FixedCategorical = torch.distributions.Categorical 11 | 12 | old_sample = FixedCategorical.sample 13 | FixedCategorical.sample = lambda self: old_sample(self).unsqueeze(-1) 14 | 15 | log_prob_cat = FixedCategorical.log_prob 16 | # FixedCategorical.log_probs = lambda self, actions: 
log_prob_cat(self, actions.squeeze(-1)).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 17 | FixedCategorical.log_probs = lambda self, actions: log_prob_cat( 18 | self, actions.squeeze(-1) 19 | ).unsqueeze(-1) 20 | 21 | categorical_entropy = FixedCategorical.entropy 22 | FixedCategorical.entropy = lambda self: categorical_entropy(self) * 10.0 # scaling 23 | 24 | FixedCategorical.mode = lambda self: self.probs.argmax(dim=-1, keepdim=True) 25 | 26 | 27 | # Normal 28 | FixedNormal = torch.distributions.Normal 29 | 30 | normal_init = FixedNormal.__init__ 31 | FixedNormal.__init__ = lambda self, mean, std: normal_init( 32 | self, mean.double(), std.double() 33 | ) 34 | 35 | log_prob_normal = FixedNormal.log_prob 36 | FixedNormal.log_probs = ( 37 | lambda self, actions: log_prob_normal(self, actions.double()) 38 | .sum(-1, keepdim=True) 39 | .float() 40 | ) 41 | 42 | normal_entropy = FixedNormal.entropy 43 | FixedNormal.entropy = lambda self: normal_entropy(self).sum(-1).float() 44 | 45 | FixedNormal.mode = lambda self: self.mean.float() 46 | 47 | normal_sample = FixedNormal.sample 48 | FixedNormal.sample = lambda self: normal_sample(self).float() 49 | 50 | normal_rsample = FixedNormal.rsample 51 | FixedNormal.rsample = lambda self: normal_rsample(self).float() 52 | 53 | 54 | # Identity 55 | class Identity(object): 56 | def __init__(self, mean): 57 | self._mean = mean 58 | 59 | def mode(self): 60 | return self._mean 61 | 62 | def sample(self): 63 | return self._mean 64 | 65 | def rsample(self): 66 | return self._mean 67 | 68 | 69 | def init(module, weight_init, bias_init, gain=1): 70 | weight_init(module.weight.data, gain=gain) 71 | bias_init(module.bias.data) 72 | return module 73 | 74 | 75 | class AddBias(nn.Module): 76 | def __init__(self, bias): 77 | super().__init__() 78 | self._bias = nn.Parameter(bias.unsqueeze(1)) 79 | 80 | def forward(self, x): 81 | if x.dim() == 2: 82 | bias = self._bias.t().view(1, -1) 83 | else: 84 | bias = self._bias.t().view(1, -1, 1, 1) 85 | return x + bias 86 | 87 | 88 | class Categorical(nn.Module): 89 | def __init__(self): 90 | super().__init__() 91 | 92 | def forward(self, x): 93 | return FixedCategorical(logits=x) 94 | 95 | 96 | class DiagGaussian(nn.Module): 97 | def __init__(self, config): 98 | super().__init__() 99 | self.logstd = AddBias(torch.zeros(config.action_size)) 100 | self.config = config 101 | 102 | def forward(self, x): 103 | zeros = torch.zeros(x.size()).to(self.config.device) 104 | logstd = self.logstd(zeros) 105 | return FixedNormal(x, logstd.exp()) 106 | 107 | 108 | class MixedDistribution(nn.Module): 109 | def __init__(self, distributions): 110 | super().__init__() 111 | assert isinstance(distributions, OrderedDict) 112 | self.distributions = distributions 113 | 114 | def mode(self): 115 | return OrderedDict([(k, dist.mode()) for k, dist in self.distributions.items()]) 116 | 117 | def sample(self): 118 | return OrderedDict( 119 | [(k, dist.sample()) for k, dist in self.distributions.items()] 120 | ) 121 | 122 | def rsample(self): 123 | return OrderedDict( 124 | [(k, dist.rsample()) for k, dist in self.distributions.items()] 125 | ) 126 | 127 | def log_probs(self, x): 128 | assert isinstance(x, dict) 129 | return OrderedDict( 130 | [(k, dist.log_probs(x[k])) for k, dist in self.distributions.items()] 131 | ) 132 | 133 | def entropy(self): 134 | return sum([dist.entropy() for dist in self.distributions.values()]) 135 | -------------------------------------------------------------------------------- /networks/encoder.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Code reference: 3 | https://github.com/MishaLaskin/rad/blob/master/encoder.py 4 | """ 5 | 6 | import gym.spaces 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .utils import CNN, MLP, flatten_ac 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, config, ob_space): 15 | super().__init__() 16 | 17 | self._encoder_type = config.encoder_type 18 | self._ob_space = ob_space 19 | 20 | self.base = nn.ModuleDict() 21 | encoder_output_dim = 0 22 | for k, v in ob_space.spaces.items(): 23 | if len(v.shape) in [3, 4]: 24 | if self._encoder_type == "mlp": 25 | self.base[k] = None 26 | encoder_output_dim += gym.spaces.flatdim(v) 27 | else: 28 | if len(v.shape) == 3: 29 | image_dim = v.shape[0] 30 | elif len(v.shape) == 4: 31 | image_dim = v.shape[0] * v.shape[1] 32 | self.base[k] = CNN(config, image_dim) 33 | encoder_output_dim += self.base[k].output_dim 34 | elif len(v.shape) == 1: 35 | self.base[k] = None 36 | encoder_output_dim += gym.spaces.flatdim(v) 37 | else: 38 | raise ValueError("Check the shape of observation %s (%s)" % (k, v)) 39 | 40 | self.output_dim = encoder_output_dim 41 | 42 | def forward(self, ob, detach_conv=False): 43 | encoder_outputs = [] 44 | for k, v in ob.items(): 45 | if self.base[k] is not None: 46 | if isinstance(self.base[k], CNN): 47 | if v.max() > 1.0: 48 | v = v.float() / 255.0 49 | encoder_outputs.append( 50 | self.base[k](v, detach_conv=detach_conv) 51 | ) 52 | else: 53 | encoder_outputs.append(v.flatten(start_dim=1)) 54 | out = torch.cat(encoder_outputs, dim=-1) 55 | assert len(out.shape) == 2 56 | return out 57 | 58 | def copy_conv_weights_from(self, source): 59 | """ Tie convolutional layers """ 60 | for k in self.base.keys(): 61 | if self.base[k] is not None: 62 | self.base[k].copy_conv_weights_from(source.base[k]) 63 | -------------------------------------------------------------------------------- /networks/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CNN(nn.Module): 8 | def __init__(self, config, input_dim): 9 | super().__init__() 10 | 11 | self.convs = nn.ModuleList() 12 | d_prev = input_dim 13 | d = config.encoder_conv_dim 14 | w = config.encoder_image_size 15 | for k, s in zip(config.encoder_kernel_size, config.encoder_stride): 16 | self.convs.append(nn.Conv2d(d_prev, d, int(k), int(s))) 17 | w = int(np.floor((w - (int(k) - 1) - 1) / int(s) + 1)) 18 | d_prev = d 19 | 20 | print("Output of CNN (%d) = %d x %d x %d" % (w * w * d, w, w, d)) 21 | self.output_dim = config.encoder_conv_output_dim 22 | 23 | self.fc = nn.Linear(w * w * d, self.output_dim) 24 | self.ln = nn.LayerNorm(self.output_dim) 25 | 26 | self.apply(weight_init) 27 | 28 | def forward(self, ob, detach_conv=False): 29 | out = ob 30 | for conv in self.convs: 31 | out = F.relu(conv(out)) 32 | out = out.flatten(start_dim=1) 33 | 34 | if detach_conv: 35 | out = out.detach() 36 | 37 | out = self.fc(out) 38 | out = self.ln(out) 39 | out = F.tanh(out) 40 | 41 | return out 42 | 43 | # from https://github.com/MishaLaskin/rad/blob/master/encoder.py 44 | def copy_conv_weights_from(self, source): 45 | """Tie convolutional layers""" 46 | # only tie conv layers 47 | for i, conv in enumerate(self.convs): 48 | assert type(source.convs[i]) == type(conv) 49 | conv.weight = source.convs[i].weight 50 | conv.bias = source.convs[i].bias 51 | 52 | 
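# Hedged illustration (added note, not part of the original networks/utils.py): the CNN
# constructor above shrinks the spatial size with the no-padding convolution formula
# w_out = floor((w_in - k) / s + 1) for each (kernel, stride) pair taken from
# config.encoder_kernel_size / config.encoder_stride. The small helper below reproduces
# that arithmetic; the kernel/stride values in the example are assumed for illustration,
# not the repository's config defaults.
def _conv_out_size(w, kernels, strides):
    """Spatial size after a stack of valid (no-padding) convolutions."""
    for k, s in zip(kernels, strides):
        w = (w - int(k)) // int(s) + 1
    return w
# Example: _conv_out_size(84, [3, 3, 3, 3], [2, 1, 1, 1]) == 35, so the flattened conv
# output fed to self.fc would be 35 * 35 * config.encoder_conv_dim before the final
# Linear + LayerNorm + tanh.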
53 | # from https://github.com/denisyarats/drq/blob/master/utils.py#L62 54 | def weight_init(tensor): 55 | if isinstance(tensor, nn.Linear): 56 | nn.init.orthogonal_(tensor.weight.data) 57 | tensor.bias.data.fill_(0.0) 58 | elif isinstance(tensor, nn.Conv2d) or isinstance(tensor, nn.ConvTranspose2d): 59 | tensor.weight.data.fill_(0.0) 60 | tensor.bias.data.fill_(0.0) 61 | mid = tensor.weight.size(2) // 2 62 | gain = nn.init.calculate_gain("relu") 63 | nn.init.orthogonal_(tensor.weight.data[:, :, mid, mid], gain) 64 | # nn.init.orthogonal_(tensor.weight.data, gain) 65 | 66 | 67 | class MLP(nn.Module): 68 | def __init__( 69 | self, config, input_dim, output_dim, hid_dims=[], activation_fn=None, 70 | ): 71 | super().__init__() 72 | self.activation_fn = activation_fn 73 | if activation_fn is None: 74 | self.activation_fn = getattr(F, config.policy_activation) 75 | 76 | self.fcs = nn.ModuleList() 77 | prev_dim = input_dim 78 | for d in hid_dims + [output_dim]: 79 | self.fcs.append(nn.Linear(prev_dim, d)) 80 | prev_dim = d 81 | 82 | self.output_dim = output_dim 83 | self.apply(weight_init) 84 | 85 | def forward(self, ob): 86 | out = ob 87 | for fc in self.fcs[:-1]: 88 | out = self.activation_fn(fc(out)) 89 | out = self.fcs[-1](out) 90 | return out 91 | 92 | 93 | def flatten_ob(ob: dict, ac=None): 94 | """ 95 | Flattens the observation dictionary. The observation dictionary 96 | can either contain a single ob, or a batch of obs. 97 | Any images must be flattened to 1D tensors, but 98 | we must be careful to check if we are doing a single instance 99 | or batch before we flatten. 100 | 101 | Returns a list of dim [N x D] where N is batch size and D is sum of flattened 102 | dims of observations 103 | """ 104 | inp = [] 105 | images = [] 106 | single_ob = False 107 | for k, v in ob.items(): 108 | if k in ["camera_ob", "depth_ob", "segmentation_ob"]: 109 | images.append(v) 110 | else: 111 | if len(v.shape) == 1: 112 | single_ob = True 113 | inp.append(v) 114 | # concatenate images into 1D 115 | for image in images: 116 | if single_ob: 117 | img = torch.flatten(image) 118 | else: # batch of obs, flatten after bs dim 119 | img = torch.flatten(image, start_dim=1) 120 | inp.append(img) 121 | # now flatten into Nx1 tensors 122 | if single_ob: 123 | inp = [x.unsqueeze(0) for x in inp] 124 | 125 | if ac is not None: 126 | ac = list(ac.values()) 127 | if len(ac[0].shape) == 1: 128 | ac = [x.unsqueeze(0) for x in ac] 129 | inp.extend(ac) 130 | inp = torch.cat(inp, dim=-1) 131 | return inp 132 | 133 | 134 | def flatten_ac(ac: dict): 135 | ac = list(ac.values()) 136 | if len(ac[0].shape) == 1: 137 | ac = [x.unsqueeze(0) for x in ac] 138 | ac = torch.cat(ac, dim=-1) 139 | return ac 140 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | wandb 4 | colorlog 5 | tqdm 6 | h5py 7 | ipdb 8 | opencv-python 9 | moviepy 10 | mpi4py 11 | gym 12 | mujoco-py 13 | #torch 14 | #torchvision 15 | absl-py 16 | git+git://github.com/deepmind/dm_control.git 17 | git+git://github.com/1nadequacy/dmc2gym.git 18 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from main import run 2 | 3 | import torch 4 | 5 | 6 | if __name__ == "__main__": 7 | torch.multiprocessing.set_start_method('spawn') 8 | run() 9 | 
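Note: `run()` in `main.py` accepts an optional argparse parser, which is how a downstream project can register extra flags before training starts. The sketch below is illustrative only; the relative imports mirror the package-style imports used in `main.py`, and `--my_flag` is a hypothetical argument, not one defined in `config/__init__.py`.

```python
# Hedged sketch: launching training with an additional, project-specific argument.
from .config import create_parser
from .main import run


def run_with_extra_flags():
    parser = create_parser()
    parser.add_argument("--my_flag", type=int, default=0)  # hypothetical extra flag
    run(parser)  # run() parses known args from this parser and starts the Trainer
```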
-------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base code for RL/IL training. 3 | Collects rollouts and updates policy networks. 4 | """ 5 | 6 | import os 7 | import gzip 8 | import pickle 9 | import copy 10 | from time import time 11 | from collections import defaultdict, OrderedDict 12 | 13 | import torch 14 | import wandb 15 | import h5py 16 | import gym 17 | import numpy as np 18 | import moviepy.editor as mpy 19 | from tqdm import tqdm, trange 20 | 21 | from .algorithms import RL_ALGOS, IL_ALGOS, get_agent_by_name 22 | from .algorithms.rollouts import RolloutRunner 23 | from .utils.info_dict import Info 24 | from .utils.logger import logger 25 | from .utils.pytorch import get_ckpt_path, count_parameters 26 | from .utils.mpi import mpi_sum, mpi_average, mpi_gather_average 27 | from .environments import make_env 28 | 29 | 30 | class Trainer(object): 31 | """ 32 | Trainer class for SAC, PPO, DDPG, BC, and GAIL in PyTorch. 33 | """ 34 | 35 | def __init__(self, config): 36 | """ 37 | Initializes class with the configuration. 38 | """ 39 | self._config = config 40 | self._is_chef = config.is_chef 41 | self._is_rl = config.algo in RL_ALGOS 42 | self._average_info = config.average_info 43 | 44 | # create environment 45 | self._env = make_env(config.env, config) 46 | ob_space = env_ob_space = self._env.observation_space 47 | ac_space = self._env.action_space 48 | logger.info("Observation space: " + str(ob_space)) 49 | logger.info("Action space: " + str(ac_space)) 50 | 51 | config_eval = copy.copy(config) 52 | if hasattr(config_eval, "port"): 53 | config_eval.port += 1 54 | self._env_eval = make_env(config.env, config_eval) if self._is_chef else None 55 | 56 | # create a new observation space after data augmentation (random crop) 57 | if config.encoder_type == "cnn": 58 | assert ( 59 | not config.ob_norm 60 | ), "Turn off the observation norm (--ob_norm False) for pixel inputs" 61 | ob_space = gym.spaces.Dict(spaces=dict(ob_space.spaces)) 62 | for k in ob_space.spaces.keys(): 63 | if len(ob_space.spaces[k].shape) == 3: 64 | shape = [ 65 | ob_space.spaces[k].shape[0], 66 | config.encoder_image_size, 67 | config.encoder_image_size, 68 | ] 69 | ob_space.spaces[k] = gym.spaces.Box( 70 | low=0, high=255, shape=shape, dtype=np.uint8 71 | ) 72 | 73 | # build agent and networks for algorithm 74 | self._agent = get_agent_by_name(config.algo)( 75 | config, ob_space, ac_space, env_ob_space 76 | ) 77 | 78 | # build rollout runner 79 | self._runner = RolloutRunner(config, self._env, self._env_eval, self._agent) 80 | 81 | # setup log 82 | if self._is_chef and config.is_train: 83 | exclude = ["device"] 84 | if not config.wandb: 85 | os.environ["WANDB_MODE"] = "dryrun" 86 | 87 | wandb.init( 88 | resume=config.run_name, 89 | project=config.wandb_project, 90 | config={k: v for k, v in config.__dict__.items() if k not in exclude}, 91 | dir=config.log_dir, 92 | entity=config.wandb_entity, 93 | notes=config.notes, 94 | ) 95 | 96 | def _save_ckpt(self, ckpt_num, update_iter): 97 | """ 98 | Save checkpoint to log directory. 99 | 100 | Args: 101 | ckpt_num: number appended to checkpoint name. The number of 102 | environment step is used in this code. 103 | update_iter: number of policy update. It will be used for resuming training. 
104 | """ 105 | ckpt_path = os.path.join(self._config.log_dir, "ckpt_%09d.pt" % ckpt_num) 106 | state_dict = {"step": ckpt_num, "update_iter": update_iter} 107 | state_dict["agent"] = self._agent.state_dict() 108 | torch.save(state_dict, ckpt_path) 109 | logger.warn("Save checkpoint: %s", ckpt_path) 110 | 111 | if self._agent.is_off_policy(): 112 | replay_path = os.path.join( 113 | self._config.log_dir, "replay_%08d.pkl" % ckpt_num 114 | ) 115 | with gzip.open(replay_path, "wb") as f: 116 | replay_buffers = {"replay": self._agent.replay_buffer()} 117 | pickle.dump(replay_buffers, f) 118 | 119 | def _load_ckpt(self, ckpt_path, ckpt_num): 120 | """ 121 | Loads checkpoint with path @ckpt_path or index number @ckpt_num. If @ckpt_num is None, 122 | it loads and returns the checkpoint with the largest index number. 123 | """ 124 | if ckpt_path is None: 125 | ckpt_path, ckpt_num = get_ckpt_path(self._config.log_dir, ckpt_num) 126 | else: 127 | ckpt_num = int(ckpt_path.rsplit("_", 1)[-1].split(".")[0]) 128 | 129 | if ckpt_path is not None: 130 | logger.warn("Load checkpoint %s", ckpt_path) 131 | ckpt = torch.load(ckpt_path, map_location=self._config.device) 132 | self._agent.load_state_dict(ckpt["agent"]) 133 | 134 | if self._config.is_train and self._agent.is_off_policy(): 135 | replay_path = os.path.join( 136 | self._config.log_dir, "replay_%08d.pkl" % ckpt_num 137 | ) 138 | logger.warn("Load replay_buffer %s", replay_path) 139 | if os.path.exists(replay_path): 140 | with gzip.open(replay_path, "rb") as f: 141 | replay_buffers = pickle.load(f) 142 | self._agent.load_replay_buffer(replay_buffers["replay"]) 143 | else: 144 | logger.warn("Replay buffer not exists at %s", replay_path) 145 | 146 | if ( 147 | self._config.init_ckpt_path is not None 148 | and "bc" in self._config.init_ckpt_path 149 | ): 150 | return 0, 0 151 | else: 152 | return ckpt["step"], ckpt["update_iter"] 153 | logger.warn("Randomly initialize models") 154 | return 0, 0 155 | 156 | def _log_train(self, step, train_info, ep_info): 157 | """ 158 | Logs training and episode information to wandb. 159 | Args: 160 | step: the number of environment steps. 161 | train_info: training information to log, such as loss, gradient. 162 | ep_info: episode information to log, such as reward, episode time. 163 | """ 164 | for k, v in train_info.items(): 165 | if np.isscalar(v) or (hasattr(v, "shape") and np.prod(v.shape) == 1): 166 | wandb.log({"train_rl/%s" % k: v}, step=step) 167 | else: 168 | wandb.log({"train_rl/%s" % k: [wandb.Image(v)]}, step=step) 169 | 170 | for k, v in ep_info.items(): 171 | wandb.log({"train_ep/%s" % k: np.mean(v)}, step=step) 172 | wandb.log({"train_ep_max/%s" % k: np.max(v)}, step=step) 173 | 174 | def _log_test(self, step, ep_info): 175 | """ 176 | Logs episode information during testing to wandb. 177 | Args: 178 | step: the number of environment steps. 179 | ep_info: episode information to log, such as reward, episode time. 180 | """ 181 | if self._config.is_train: 182 | for k, v in ep_info.items(): 183 | if isinstance(v, wandb.Video): 184 | wandb.log({"test_ep/%s" % k: v}, step=step) 185 | elif isinstance(v, list) and isinstance(v[0], wandb.Video): 186 | for i, video in enumerate(v): 187 | wandb.log({"test_ep/%s_%d" % (k, i): video}, step=step) 188 | else: 189 | wandb.log({"test_ep/%s" % k: np.mean(v)}, step=step) 190 | 191 | def train(self): 192 | """ Trains an agent. 
""" 193 | config = self._config 194 | 195 | # load checkpoint 196 | step, update_iter = self._load_ckpt(config.init_ckpt_path, config.ckpt_num) 197 | 198 | # sync the networks across the cpus 199 | self._agent.sync_networks() 200 | 201 | logger.info("Start training at step=%d", step) 202 | if self._is_chef: 203 | pbar = tqdm( 204 | initial=update_iter, total=config.max_global_step, desc=config.run_name 205 | ) 206 | ep_info = Info() 207 | train_info = Info() 208 | 209 | # decide how many episodes or how long rollout to collect 210 | if self._config.algo == "bc": 211 | runner = None 212 | elif self._config.algo == "gail": 213 | runner = self._runner.run( 214 | every_steps=self._config.rollout_length, step=step 215 | ) 216 | elif self._config.algo == "ppo": 217 | runner = self._runner.run( 218 | every_steps=self._config.rollout_length, step=step 219 | ) 220 | elif self._config.algo in ["sac", "ddpg", "td3"]: 221 | runner = self._runner.run(every_steps=1, step=step) 222 | # runner = self._runner.run(every_episodes=1) 223 | elif self._config.algo == "dac": 224 | runner = self._runner.run(every_steps=1, step=step) 225 | 226 | st_time = time() 227 | st_step = step 228 | 229 | while runner and step < config.warm_up_steps: 230 | rollout, info = next(runner) 231 | self._agent.store_episode(rollout) 232 | step_per_batch = mpi_sum(len(rollout["ac"])) 233 | step += step_per_batch 234 | if runner and step < config.max_ob_norm_step: 235 | self._update_normalizer(rollout) 236 | if self._is_chef: 237 | pbar.update(step_per_batch) 238 | 239 | if self._config.algo == "bc" and self._config.ob_norm: 240 | self._agent.update_normalizer() 241 | 242 | while step < config.max_global_step: 243 | # collect rollouts 244 | if runner: 245 | rollout, info = next(runner) 246 | if self._average_info: 247 | info = mpi_gather_average(info) 248 | self._agent.store_episode(rollout) 249 | step_per_batch = mpi_sum(len(rollout["ac"])) 250 | else: 251 | step_per_batch = mpi_sum(1) 252 | info = {} 253 | 254 | # train an agent 255 | _train_info = self._agent.train() 256 | 257 | if runner and step < config.max_ob_norm_step: 258 | self._update_normalizer(rollout) 259 | 260 | step += step_per_batch 261 | update_iter += 1 262 | 263 | # log training and episode information or evaluate 264 | if self._is_chef: 265 | pbar.update(step_per_batch) 266 | ep_info.add(info) 267 | train_info.add(_train_info) 268 | 269 | if update_iter % config.log_interval == 0: 270 | train_info.add( 271 | { 272 | "sec": (time() - st_time) / config.log_interval, 273 | "steps_per_sec": (step - st_step) / (time() - st_time), 274 | "update_iter": update_iter, 275 | } 276 | ) 277 | st_time = time() 278 | st_step = step 279 | self._log_train(step, train_info.get_dict(), ep_info.get_dict()) 280 | ep_info = Info() 281 | train_info = Info() 282 | 283 | if update_iter % config.evaluate_interval == 1: 284 | logger.info("Evaluate at %d", update_iter) 285 | rollout, info = self._evaluate( 286 | step=step, record_video=config.record_video 287 | ) 288 | self._log_test(step, info) 289 | 290 | if update_iter % config.ckpt_interval == 0: 291 | self._save_ckpt(step, update_iter) 292 | 293 | self._save_ckpt(step, update_iter) 294 | logger.info("Reached %s steps. worker %d stopped.", step, config.rank) 295 | 296 | def _update_normalizer(self, rollout): 297 | """ Updates normalizer with @rollout. 
""" 298 | if self._config.ob_norm: 299 | self._agent.update_normalizer(rollout["ob"]) 300 | 301 | def _evaluate(self, step=None, record_video=False): 302 | """ 303 | Runs one rollout if in eval mode (@idx is not None). 304 | Runs num_record_samples rollouts if in train mode (@idx is None). 305 | 306 | Args: 307 | step: the number of environment steps. 308 | record_video: whether to record video or not. 309 | """ 310 | logger.info("Run %d evaluations at step=%d", self._config.num_eval, step) 311 | rollouts = [] 312 | info_history = Info() 313 | for i in range(self._config.num_eval): 314 | logger.warn("Evalute run %d", i + 1) 315 | rollout, info, frames = self._runner.run_episode( 316 | is_train=False, record_video=record_video 317 | ) 318 | rollouts.append(rollout) 319 | logger.info( 320 | "rollout: %s", {k: v for k, v in info.items() if not "qpos" in k} 321 | ) 322 | 323 | if record_video: 324 | ep_rew = info["rew"] 325 | ep_success = ( 326 | "s" 327 | if "episode_success" in info and info["episode_success"] 328 | else "f" 329 | ) 330 | fname = "{}_step_{:011d}_{}_r_{}_{}.mp4".format( 331 | self._config.env, step, i, ep_rew, ep_success, 332 | ) 333 | video_path = self._save_video(fname, frames) 334 | if self._config.is_train: 335 | info["video"] = wandb.Video(video_path, fps=15, format="mp4") 336 | 337 | info_history.add(info) 338 | 339 | return rollouts, info_history 340 | 341 | def evaluate(self): 342 | """ Evaluates an agent stored in chekpoint with @self._config.ckpt_num. """ 343 | step, update_iter = self._load_ckpt( 344 | self._config.init_ckpt_path, self._config.ckpt_num 345 | ) 346 | 347 | logger.info( 348 | "Run %d evaluations at step=%d, update_iter=%d", 349 | self._config.num_eval, 350 | step, 351 | update_iter, 352 | ) 353 | rollouts, info = self._evaluate( 354 | step=step, record_video=self._config.record_video 355 | ) 356 | 357 | info_stat = info.get_stat() 358 | os.makedirs("result", exist_ok=True) 359 | with h5py.File("result/{}.hdf5".format(self._config.run_name), "w") as hf: 360 | for k, v in info.items(): 361 | hf.create_dataset(k, data=info[k]) 362 | with open("result/{}.txt".format(self._config.run_name), "w") as f: 363 | for k, v in info_stat.items(): 364 | f.write("{}\t{:.03f} $\\pm$ {:.03f}\n".format(k, v[0], v[1])) 365 | 366 | 367 | if self._config.record_demo: 368 | new_rollouts = [] 369 | for rollout in rollouts: 370 | new_rollout = { 371 | "obs": rollout["ob"], 372 | "actions": rollout["ac"], 373 | "rewards": rollout["rew"], 374 | "dones": rollout["done"], 375 | } 376 | new_rollouts.append(new_rollout) 377 | 378 | fname = "{}_step_{:011d}_{}_trajs.pkl".format( 379 | self._config.run_name, step, self._config.num_eval, 380 | ) 381 | path = os.path.join(self._config.demo_dir, fname) 382 | logger.warn("[*] Generating demo: {}".format(path)) 383 | with open(path, "wb") as f: 384 | pickle.dump(new_rollouts, f) 385 | 386 | def _save_video(self, fname, frames, fps=15.0): 387 | """ Saves @frames into a video with file name @fname. 
""" 388 | path = os.path.join(self._config.record_dir, fname) 389 | logger.warn("[*] Generating video: {}".format(path)) 390 | 391 | def f(t): 392 | frame_length = len(frames) 393 | new_fps = 1.0 / (1.0 / fps + 1.0 / frame_length) 394 | idx = min(int(t * new_fps), frame_length - 1) 395 | return frames[idx] 396 | 397 | video = mpy.VideoClip(f, duration=len(frames) / fps + 2) 398 | 399 | video.write_videofile(path, fps, verbose=False) 400 | logger.warn("[*] Video saved: {}".format(path)) 401 | return path 402 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youngwoon/robot-learning/96af508abfca6aadb38d9c55f01602464fecf460/utils/__init__.py -------------------------------------------------------------------------------- /utils/gym_env.py: -------------------------------------------------------------------------------- 1 | from collections import deque, OrderedDict 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | def cat_spaces(spaces): 8 | if isinstance(spaces[0], gym.spaces.Box): 9 | out_space = gym.spaces.Box( 10 | low=np.concatenate([s.low for s in spaces]), 11 | high=np.concatenate([s.high for s in spaces]) 12 | ) 13 | elif isinstance(spaces[0], gym.spaces.Discrete): 14 | out_space = gym.spaces.Discrete(sum([s.n for s in spaces])) 15 | return out_space 16 | 17 | def stacked_space(space, k): 18 | if isinstance(space, gym.spaces.Box): 19 | space_stack = gym.spaces.Box( 20 | low=np.concatenate([space.low] * k, axis=0), 21 | high=np.concatenate([space.high] * k, axis=0), 22 | ) 23 | elif isinstance(space, gym.spaces.Discrete): 24 | space_stack = gym.spaces.Discrete(space.n * k) 25 | return space_stack 26 | 27 | 28 | def spaces_to_shapes(space): 29 | if isinstance(space, gym.spaces.Dict): 30 | return {k: spaces_to_shapes(v) for k, v in space.spaces.items()} 31 | elif isinstance(space, gym.spaces.Box): 32 | return space.shape 33 | elif isinstance(space, gym.spaces.Discrete): 34 | return [space.n] 35 | 36 | 37 | def zero_value(space, dtype=np.float64): 38 | if isinstance(space, gym.spaces.Dict): 39 | return OrderedDict( 40 | [(k, zero_value(space, dtype)) for k, space in space.spaces.items()] 41 | ) 42 | elif isinstance(space, gym.spaces.Box): 43 | return np.zeros(space.shape).astype(dtype) 44 | elif isinstance(space, gym.spaces.Discrete): 45 | return np.zeros(1).astype(dtype) 46 | 47 | 48 | def get_non_absorbing_state(ob): 49 | ob = ob.copy() 50 | ob["absorbing_state"] = np.array([0]) 51 | return ob 52 | 53 | 54 | def get_absorbing_state(space): 55 | ob = zero_value(space) 56 | ob["absorbing_state"] = np.array([1]) 57 | return ob 58 | 59 | 60 | class GymWrapper(gym.Wrapper): 61 | def __init__( 62 | self, 63 | env, 64 | from_pixels=False, 65 | height=100, 66 | width=100, 67 | camera_id=None, 68 | channels_first=True, 69 | frame_skip=1, 70 | return_state=False 71 | ): 72 | super().__init__(env) 73 | self._from_pixels = from_pixels 74 | self._height = height 75 | self._width = width 76 | self._camera_id = camera_id 77 | self._channels_first = channels_first 78 | self._frame_skip = frame_skip 79 | self.max_episode_steps = self.env._max_episode_steps // frame_skip 80 | self._return_state = return_state 81 | 82 | if from_pixels: 83 | shape = [3, height, width] if channels_first else [height, width, 3] 84 | self.observation_space = gym.spaces.Box( 85 | low=0, high=255, shape=shape, dtype=np.uint8 86 | ) 87 | else: 88 | 
self.observation_space = env.observation_space 89 | 90 | self.env_observation_space = env.observation_space 91 | 92 | def reset(self): 93 | ob = self.env.reset() 94 | 95 | if self._return_state: 96 | return self._get_obs(ob, reset=True), ob 97 | 98 | return self._get_obs(ob, reset=True) 99 | 100 | def step(self, ac): 101 | reward = 0 102 | for _ in range(self._frame_skip): 103 | ob, _reward, done, info = self.env.step(ac) 104 | reward += _reward 105 | if done: 106 | break 107 | if self._return_state: 108 | return (self._get_obs(ob), ob), reward, done, info 109 | 110 | return self._get_obs(ob), reward, done, info 111 | 112 | def _get_obs(self, ob, reset=False): 113 | if self._from_pixels: 114 | ob = self.render( 115 | mode="rgb_array", 116 | height=self._height, 117 | width=self._width, 118 | camera_id=self._camera_id, 119 | ) 120 | if reset: 121 | ob = self.render( 122 | mode="rgb_array", 123 | height=self._height, 124 | width=self._width, 125 | camera_id=self._camera_id, 126 | ) 127 | if self._channels_first: 128 | ob = ob.transpose(2, 0, 1).copy() 129 | return ob 130 | 131 | 132 | class DictWrapper(gym.Wrapper): 133 | def __init__(self, env, return_state=False): 134 | super().__init__(env) 135 | 136 | self._return_state = return_state 137 | 138 | self._is_ob_dict = isinstance(env.observation_space, gym.spaces.Dict) 139 | self._env_is_ob_dict = isinstance(env.env_observation_space, gym.spaces.Dict) 140 | if not self._is_ob_dict: 141 | self.observation_space = gym.spaces.Dict({"ob": env.observation_space}) 142 | else: 143 | self.observation_space = env.observation_space 144 | if not self._env_is_ob_dict: 145 | self.env_observation_space = gym.spaces.Dict({"state": env.env_observation_space}) 146 | else: 147 | self.env_observation_space = env.env_observation_space 148 | 149 | self._is_ac_dict = isinstance(env.action_space, gym.spaces.Dict) 150 | if not self._is_ac_dict: 151 | self.action_space = gym.spaces.Dict({"ac": env.action_space}) 152 | else: 153 | self.action_space = env.action_space 154 | 155 | def reset(self): 156 | ob = self.env.reset() 157 | return self._get_obs(ob) 158 | 159 | def step(self, ac): 160 | if not self._is_ac_dict: 161 | ac = ac["ac"] 162 | ob, reward, done, info = self.env.step(ac) 163 | return self._get_obs(ob), reward, done, info 164 | 165 | def _get_obs(self, ob): 166 | if not self._is_ob_dict: 167 | if self._return_state: 168 | if not self._env_is_ob_dict: 169 | ob = {"ob": ob[0], "state": {"state": ob[1]}} 170 | else: 171 | ob = {"ob": ob[0], "state": ob[1]} 172 | else: 173 | ob = {"ob": ob} 174 | return ob 175 | 176 | 177 | class FrameStackWrapper(gym.Wrapper): 178 | def __init__(self, env, frame_stack=3, return_state=False): 179 | super().__init__(env) 180 | 181 | # Both observation and action spaces must be gym.spaces.Dict. 
182 | assert isinstance(env.observation_space, gym.spaces.Dict), env.observation_space 183 | assert isinstance(env.action_space, gym.spaces.Dict), env.action_space 184 | 185 | self._frame_stack = frame_stack 186 | self._frames = deque([], maxlen=frame_stack) 187 | self._return_state = return_state 188 | self._state = None 189 | 190 | ob_space = [] 191 | for k, space in env.observation_space.spaces.items(): 192 | space_stack = stacked_space(space, frame_stack) 193 | ob_space.append((k, space_stack)) 194 | self.observation_space = gym.spaces.Dict(ob_space) 195 | 196 | self.env_observation_space = env.env_observation_space 197 | 198 | def reset(self): 199 | ob = self.env.reset() 200 | if self._return_state: 201 | self._state = ob.pop("state", None) 202 | for _ in range(self._frame_stack): 203 | self._frames.append(ob) 204 | return self._get_obs() 205 | 206 | def step(self, ac): 207 | ob, reward, done, info = self.env.step(ac) 208 | if self._return_state: 209 | self._state = ob.pop("state", None) 210 | self._frames.append(ob) 211 | return self._get_obs(), reward, done, info 212 | 213 | def _get_obs(self): 214 | frames = list(self._frames) 215 | obs = [] 216 | for k in self.env.observation_space.spaces.keys(): 217 | obs.append((k, np.concatenate([f[k] for f in frames], axis=0))) 218 | if self._return_state: 219 | obs.append(("state", self._state)) 220 | 221 | return OrderedDict(obs) 222 | 223 | 224 | class AbsorbingWrapper(gym.Wrapper): 225 | def __init__(self, env): 226 | super().__init__(env) 227 | ob_space = gym.spaces.Dict(spaces=dict(env.observation_space.spaces)) 228 | ob_space.spaces["absorbing_state"] = gym.spaces.Box( 229 | low=-1, high=1, shape=(1,), dtype=np.uint8 230 | ) 231 | self.observation_space = ob_space 232 | 233 | def reset(self): 234 | ob = self.env.reset() 235 | return self._get_obs(ob) 236 | 237 | def step(self, ac): 238 | ob, reward, done, info = self.env.step(ac) 239 | return self._get_obs(ob), reward, done, info 240 | 241 | def _get_obs(self, ob): 242 | return get_non_absorbing_state(ob) 243 | 244 | def get_absorbing_state(self): 245 | return get_absorbing_state(self.observation_space) 246 | -------------------------------------------------------------------------------- /utils/info_dict.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | 6 | class Info(object): 7 | def __init__(self, info=None): 8 | self._info = defaultdict(list) 9 | if info: 10 | self.add(info) 11 | 12 | def add(self, info): 13 | if info is None: 14 | return 15 | if isinstance(info, Info): 16 | for k, v in info._info.items(): 17 | self._info[k].extend(v) 18 | elif isinstance(info, dict): 19 | for k, v in info.items(): 20 | if isinstance(v, list): 21 | self._info[k].extend(v) 22 | else: 23 | self._info[k].append(v) 24 | else: 25 | raise ValueError("info should be dict or Info (%s)" % info) 26 | 27 | def clear(self): 28 | self._info = defaultdict(list) 29 | 30 | def get_dict(self, reduction="mean", only_scalar=False): 31 | ret = {} 32 | for k, v in self._info.items(): 33 | if np.isscalar(v): 34 | ret[k] = v 35 | elif isinstance(v[0], (int, float, bool, np.float32, np.int64, np.ndarray)): 36 | if "_mean" in k or reduction == "mean": 37 | ret[k] = np.mean(v) 38 | elif reduction == "sum": 39 | ret[k] = np.sum(v) 40 | elif not only_scalar: 41 | ret[k] = v 42 | self.clear() 43 | return ret 44 | 45 | def get_stat(self): 46 | ret = {} 47 | for k, v in self._info.items(): 48 | if np.isscalar(v): 49 | 
ret[k] = (v, 0) 50 | elif isinstance(v[0], (int, float, bool, np.float32, np.int64, np.ndarray)): 51 | ret[k] = (np.mean(v), np.std(v)) 52 | return ret 53 | 54 | def __getitem__(self, key): 55 | return self._info[key] 56 | 57 | def __setitem__(self, key, value): 58 | self._info[key].append(value) 59 | 60 | def items(self): 61 | return self._info.items() 62 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | 4 | import numpy as np 5 | import colorlog 6 | 7 | 8 | formatter = colorlog.ColoredFormatter( 9 | "%(log_color)s[%(asctime)s] %(message)s", 10 | datefmt=None, 11 | reset=True, 12 | log_colors={ 13 | "DEBUG": "cyan", 14 | "INFO": "white", 15 | "WARNING": "yellow", 16 | "ERROR": "red,bold", 17 | "CRITICAL": "red,bg_white", 18 | }, 19 | secondary_log_colors={}, 20 | style="%", 21 | ) 22 | 23 | logger = colorlog.getLogger("robot-learning") 24 | logger.setLevel(logging.DEBUG) 25 | logger.propagate = False 26 | 27 | # fh = logging.FileHandler('log') 28 | # fh.setLevel(logging.DEBUG) 29 | # fh.setFormatter(formatter) 30 | # logger.addHandler(fh) 31 | 32 | if not logger.handlers: 33 | ch = colorlog.StreamHandler() 34 | ch.setLevel(logging.DEBUG) 35 | ch.setFormatter(formatter) 36 | logger.addHandler(ch) 37 | 38 | 39 | class StopWatch(object): 40 | def __init__(self): 41 | self.start = {} 42 | self.times = {} 43 | 44 | def begin(self, name): 45 | self.start[name] = time.time() 46 | 47 | def end(self, name): 48 | if name not in self.times: 49 | self.times[name] = [] 50 | assert name in self.start, "%s cannot be found in Stop Watch" % name 51 | 52 | self.times[name].append(time.time() - self.start[name]) 53 | 54 | def display(self): 55 | print("----Times----") 56 | for name in self.times: 57 | print(name, np.mean(self.times[name])) 58 | 59 | self.times = {} 60 | -------------------------------------------------------------------------------- /utils/mpi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mpi4py import MPI 3 | 4 | from .info_dict import Info 5 | 6 | 7 | def mpi_gather_average(x): 8 | buf = MPI.COMM_WORLD.gather(x, root=0) 9 | if MPI.COMM_WORLD.rank == 0: 10 | info = Info() 11 | for data in buf: 12 | info.add(data) 13 | return info.get_dict() 14 | return None 15 | 16 | 17 | def _mpi_average(x): 18 | buf = np.zeros_like(x) 19 | MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM) 20 | buf /= MPI.COMM_WORLD.Get_size() 21 | return buf 22 | 23 | 24 | # Average across the cpu's data 25 | def mpi_average(x): 26 | if MPI.COMM_WORLD.Get_size() == 1: 27 | return x 28 | if isinstance(x, dict): 29 | keys = sorted(x.keys()) 30 | return {k: _mpi_average(np.array(x[k])) for k in keys} 31 | else: 32 | return _mpi_average(np.array(x)) 33 | 34 | 35 | def _mpi_sum(x): 36 | buf = np.zeros_like(x) 37 | MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM) 38 | return buf 39 | 40 | 41 | # Sum over the cpu's data 42 | def mpi_sum(x): 43 | if MPI.COMM_WORLD.Get_size() == 1: 44 | return x 45 | if isinstance(x, dict): 46 | keys = sorted(x.keys()) 47 | return {k: _mpi_sum(np.array(x[k])) for k in keys} 48 | else: 49 | return _mpi_sum(np.array(x)) 50 | 51 | 52 | # Syncronize all processes. 
53 | def mpi_sync(): 54 | mpi_sum(0) 55 | -------------------------------------------------------------------------------- /utils/normalizer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import gym.spaces 5 | 6 | from .mpi import mpi_average 7 | 8 | 9 | class SubNormalizer: 10 | def __init__(self, size, eps=1e-1, default_clip_range=np.inf, clip_obs=np.inf): 11 | if isinstance(size, list): 12 | self.size = size 13 | else: 14 | self.size = [size] 15 | self.eps = eps 16 | self.default_clip_range = default_clip_range 17 | self.clip_obs = clip_obs 18 | # some local information 19 | self.local_sum = np.zeros(self.size, np.float32) 20 | self.local_sumsq = np.zeros(self.size, np.float32) 21 | self.local_count = np.zeros(1, np.float32) 22 | # get the total sum sumsq and sum count 23 | self.total_sum = np.zeros(self.size, np.float32) 24 | self.total_sumsq = np.zeros(self.size, np.float32) 25 | self.total_count = np.ones(1, np.float32) 26 | # get the mean and std 27 | self.mean = np.zeros(self.size, np.float32) 28 | self.std = np.ones(self.size, np.float32) 29 | 30 | def _clip(self, v): 31 | return np.clip(v, -self.clip_obs, self.clip_obs) 32 | 33 | # update the parameters of the normalizer 34 | def update(self, v): 35 | v = self._clip(v) 36 | v = v.reshape([-1] + self.size) 37 | 38 | if not isinstance(v, np.ndarray): 39 | v = v.detach().numpy() 40 | # do the computing 41 | self.local_sum += v.sum(axis=0) 42 | self.local_sumsq += (np.square(v)).sum(axis=0) 43 | self.local_count[0] += v.shape[0] 44 | 45 | # sync the parameters across the cpus 46 | def sync(self, local_sum, local_sumsq, local_count): 47 | local_sum[...] = mpi_average(local_sum) 48 | local_sumsq[...] = mpi_average(local_sumsq) 49 | local_count[...] = mpi_average(local_count) 50 | return local_sum, local_sumsq, local_count 51 | 52 | def recompute_stats(self): 53 | local_count = self.local_count.copy() 54 | local_sum = self.local_sum.copy() 55 | local_sumsq = self.local_sumsq.copy() 56 | # reset 57 | self.local_count[...] = 0 58 | self.local_sum[...] = 0 59 | self.local_sumsq[...] 
= 0 60 | # synrc the stats 61 | sync_sum, sync_sumsq, sync_count = self.sync( 62 | local_sum, local_sumsq, local_count 63 | ) 64 | # update the total stuff 65 | self.total_sum += sync_sum 66 | self.total_sumsq += sync_sumsq 67 | self.total_count += sync_count 68 | # calculate the new mean and std 69 | self.mean = self.total_sum / self.total_count 70 | self.std = np.sqrt( 71 | np.maximum( 72 | np.square(self.eps), 73 | (self.total_sumsq / self.total_count) 74 | - np.square(self.total_sum / self.total_count), 75 | ) 76 | ) 77 | 78 | # normalize the observation 79 | def normalize(self, v, clip_range=None): 80 | v = self._clip(v) 81 | if clip_range is None: 82 | clip_range = self.default_clip_range 83 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range) 84 | 85 | def state_dict(self): 86 | return { 87 | "sum": self.total_sum, 88 | "sumsq": self.total_sumsq, 89 | "count": self.total_count, 90 | } 91 | 92 | def load_state_dict(self, state_dict): 93 | self.total_sum = state_dict["sum"] 94 | self.total_sumsq = state_dict["sumsq"] 95 | self.total_count = state_dict["count"] 96 | self.mean = self.total_sum / self.total_count 97 | self.std = np.sqrt( 98 | np.maximum( 99 | np.square(self.eps), 100 | (self.total_sumsq / self.total_count) 101 | - np.square(self.total_sum / self.total_count), 102 | ) 103 | ) 104 | 105 | 106 | class Normalizer: 107 | def __init__(self, shape, eps=1e-1, default_clip_range=np.inf, clip_obs=np.inf): 108 | if isinstance(shape, gym.spaces.Dict): 109 | self._shape = {k: list(v.shape) for k, v in shape.spaces.items()} 110 | elif isinstance(shape, dict): 111 | self._shape = shape 112 | else: 113 | self._shape = {"": shape} 114 | print("New ob_norm with shape", self._shape) 115 | 116 | self._keys = sorted(self._shape.keys()) 117 | 118 | self.sub_norm = {} 119 | for key in self._keys: 120 | self.sub_norm[key] = SubNormalizer( 121 | self._shape[key], eps, default_clip_range, clip_obs 122 | ) 123 | 124 | # update the parameters of the normalizer 125 | def update(self, v): 126 | if isinstance(v, list): 127 | if isinstance(v[0], dict): 128 | v = OrderedDict( 129 | [(k, np.asarray([x[k] for x in v])) for k in self._keys] 130 | ) 131 | else: 132 | v = np.asarray(v) 133 | 134 | if isinstance(v, dict): 135 | for k, v_ in v.items(): 136 | if k in self._keys: 137 | self.sub_norm[k].update(v_) 138 | else: 139 | self.sub_norm[""].update(v) 140 | 141 | def recompute_stats(self): 142 | for k in self._keys: 143 | self.sub_norm[k].recompute_stats() 144 | 145 | # normalize the observation 146 | def _normalize(self, v, clip_range=None): 147 | if not isinstance(v, dict): 148 | return self.sub_norm[""].normalize(v, clip_range) 149 | return OrderedDict( 150 | [ 151 | (k, self.sub_norm[k].normalize(v_, clip_range)) 152 | for k, v_ in v.items() 153 | if k in self._keys 154 | ] 155 | ) 156 | 157 | def normalize(self, v, clip_range=None): 158 | if isinstance(v, list): 159 | return [self._normalize(x, clip_range) for x in v] 160 | else: 161 | return self._normalize(v, clip_range) 162 | 163 | def state_dict(self): 164 | return OrderedDict([(k, self.sub_norm[k].state_dict()) for k in self._keys]) 165 | 166 | def load_state_dict(self, state_dict): 167 | for k in self._keys: 168 | self.sub_norm[k].load_state_dict(state_dict[k]) 169 | -------------------------------------------------------------------------------- /utils/pytorch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from glob import glob 4 | from collections import 
OrderedDict, defaultdict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import torchvision.utils as vutils 10 | import torchvision.transforms.functional as TF 11 | import PIL.Image 12 | from mpi4py import MPI 13 | 14 | 15 | # Note! This is l2 square, not l2 16 | def l2(a, b): 17 | return torch.pow(torch.abs(a - b), 2).sum(dim=1) 18 | 19 | 20 | # required when we load optimizer from a checkpoint 21 | def optimizer_cuda(optimizer, device): 22 | for state in optimizer.state.values(): 23 | for k, v in state.items(): 24 | if isinstance(v, torch.Tensor): 25 | state[k] = v.to(device) 26 | 27 | 28 | def get_ckpt_path(base_dir, ckpt_num): 29 | if ckpt_num is None: 30 | return get_recent_ckpt_path(base_dir) 31 | files = glob(os.path.join(base_dir, "*.pt")) 32 | for f in files: 33 | if "ckpt_%08d.pt" % ckpt_num in f: 34 | return f, ckpt_num 35 | raise Exception("Did not find ckpt_%s.pt" % ckpt_num) 36 | 37 | 38 | def get_recent_ckpt_path(base_dir): 39 | files = glob(os.path.join(base_dir, "*.pt")) 40 | files.sort() 41 | if len(files) == 0: 42 | return None, None 43 | max_step = max([f.rsplit("_", 1)[-1].split(".")[0] for f in files]) 44 | paths = [f for f in files if max_step in f] 45 | if len(paths) == 1: 46 | return paths[0], int(max_step) 47 | else: 48 | raise Exception("Multiple most recent ckpts %s" % paths) 49 | 50 | 51 | def image_grid(image, n=4): 52 | return vutils.make_grid(image[:n], nrow=n).cpu().detach().numpy() 53 | 54 | 55 | def count_parameters(model): 56 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 57 | 58 | 59 | def slice_tensor(input, indices): 60 | ret = {} 61 | for k, v in input.items(): 62 | ret[k] = v[indices] 63 | return ret 64 | 65 | 66 | def average_gradients(model): 67 | size = float(dist.get_world_size()) 68 | for p in model.parameters(): 69 | if p.grad is not None: 70 | dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM) 71 | p.grad.data /= size 72 | 73 | 74 | def ensure_shared_grads(model, shared_model): 75 | """for A3C""" 76 | for param, shared_param in zip(model.parameters(), shared_model.parameters()): 77 | if shared_param.grad is not None: 78 | return 79 | shared_param._grad = param.grad 80 | 81 | 82 | def compute_gradient_norm(model): 83 | grad_norm = 0 84 | for p in model.parameters(): 85 | if p.grad is not None: 86 | grad_norm += (p.grad.data ** 2).sum().item() 87 | return grad_norm 88 | 89 | 90 | def compute_weight_norm(model): 91 | weight_norm = 0 92 | for p in model.parameters(): 93 | if p.data is not None: 94 | weight_norm += (p.data ** 2).sum().item() 95 | return weight_norm 96 | 97 | 98 | def compute_weight_sum(model): 99 | weight_sum = 0 100 | for p in model.parameters(): 101 | if p.data is not None: 102 | weight_sum += p.data.abs().sum().item() 103 | return weight_sum 104 | 105 | 106 | # sync_networks across the different cores 107 | def sync_networks(network): 108 | """ 109 | netowrk is the network you want to sync 110 | """ 111 | comm = MPI.COMM_WORLD 112 | if comm.Get_size() == 1: 113 | return 114 | flat_params, params_shape = _get_flat_params(network) 115 | comm.Bcast(flat_params, root=0) 116 | # set the flat params back to the network 117 | _set_flat_params(network, params_shape, flat_params) 118 | 119 | 120 | # get the flat params from the network 121 | def _get_flat_params(network): 122 | param_shape = {} 123 | flat_params = None 124 | for key_name, value in network.named_parameters(): 125 | param_shape[key_name] = value.cpu().detach().numpy().shape 126 | if flat_params is None: 127 | 
flat_params = value.cpu().detach().numpy().flatten() 128 | else: 129 | flat_params = np.append(flat_params, value.cpu().detach().numpy().flatten()) 130 | return flat_params, param_shape 131 | 132 | 133 | # set the params from the network 134 | def _set_flat_params(network, params_shape, params): 135 | pointer = 0 136 | if hasattr(network, "_config"): 137 | device = network._config.device 138 | else: 139 | device = torch.device("cpu") 140 | 141 | for key_name, values in network.named_parameters(): 142 | # get the length of the parameters 143 | len_param = np.prod(params_shape[key_name]) 144 | copy_params = params[pointer : pointer + len_param].reshape( 145 | params_shape[key_name] 146 | ) 147 | copy_params = torch.tensor(copy_params).to(device) 148 | # copy the params 149 | values.data.copy_(copy_params.data) 150 | # update the pointer 151 | pointer += len_param 152 | 153 | 154 | # sync gradients across the different cores 155 | def sync_grads(network): 156 | comm = MPI.COMM_WORLD 157 | if comm.Get_size() == 1: 158 | return 159 | flat_grads, grads_shape = _get_flat_grads(network) 160 | global_grads = np.zeros_like(flat_grads) 161 | comm.Allreduce(flat_grads, global_grads, op=MPI.SUM) 162 | _set_flat_grads(network, grads_shape, global_grads) 163 | 164 | 165 | def _set_flat_grads(network, grads_shape, flat_grads): 166 | pointer = 0 167 | if hasattr(network, "_config"): 168 | device = network._config.device 169 | else: 170 | device = torch.device("cpu") 171 | 172 | for key_name, value in network.named_parameters(): 173 | if key_name in grads_shape: 174 | len_grads = np.prod(grads_shape[key_name]) 175 | copy_grads = flat_grads[pointer : pointer + len_grads].reshape( 176 | grads_shape[key_name] 177 | ) 178 | copy_grads = torch.tensor(copy_grads).to(device) 179 | # copy the grads 180 | value.grad.data.copy_(copy_grads.data) 181 | pointer += len_grads 182 | 183 | 184 | def _get_flat_grads(network): 185 | grads_shape = {} 186 | flat_grads = None 187 | for key_name, value in network.named_parameters(): 188 | try: 189 | grads_shape[key_name] = value.grad.data.cpu().numpy().shape 190 | except: 191 | print("Cannot get grad of tensor {}".format(key_name)) 192 | continue 193 | 194 | if flat_grads is None: 195 | flat_grads = value.grad.data.cpu().numpy().flatten() 196 | else: 197 | flat_grads = np.append(flat_grads, value.grad.data.cpu().numpy().flatten()) 198 | return flat_grads, grads_shape 199 | 200 | 201 | def fig2tensor(draw_func): 202 | def decorate(*args, **kwargs): 203 | tmp = io.BytesIO() 204 | fig = draw_func(*args, **kwargs) 205 | fig.savefig(tmp, dpi=88) 206 | tmp.seek(0) 207 | fig.clf() 208 | return TF.to_tensor(PIL.Image.open(tmp)) 209 | 210 | return decorate 211 | 212 | 213 | def tensor2np(t): 214 | if isinstance(t, torch.Tensor): 215 | return t.clone().detach().cpu().numpy() 216 | else: 217 | return t 218 | 219 | 220 | def tensor2img(tensor): 221 | if len(tensor.shape) == 4: 222 | assert tensor.shape[0] == 1 223 | tensor = tensor.squeeze(0) 224 | img = tensor.permute(1, 2, 0).detach().cpu().numpy() 225 | import cv2 226 | 227 | cv2.imwrite("tensor.png", img) 228 | 229 | 230 | def obs2tensor(obs, device): 231 | if isinstance(obs, list): 232 | obs = list2dict(obs) 233 | 234 | return OrderedDict( 235 | [ 236 | (k, torch.tensor(np.stack(v), dtype=torch.float32).to(device)) 237 | for k, v in obs.items() 238 | ] 239 | ) 240 | 241 | 242 | # transfer a numpy array into a tensor 243 | def to_tensor(x, device): 244 | if isinstance(x, dict): 245 | return OrderedDict( 246 | [(k, torch.as_tensor(v, 
device=device).float()) for k, v in x.items()] 247 | ) 248 | if isinstance(x, list): 249 | return [torch.as_tensor(v, device=device).float() for v in x] 250 | return torch.as_tensor(x, device=device).float() 251 | 252 | 253 | def list2dict(rollout): 254 | ret = OrderedDict() 255 | for k in rollout[0].keys(): 256 | ret[k] = [] 257 | for transition in rollout: 258 | for k, v in transition.items(): 259 | ret[k].append(v) 260 | return ret 261 | 262 | 263 | def scale_dict_tensor(tensor, scalar): 264 | if isinstance(tensor, dict): 265 | return OrderedDict( 266 | [(k, scale_dict_tensor(tensor[k], scalar)) for k in tensor.keys()] 267 | ) 268 | elif isinstance(tensor, list): 269 | return [scale_dict_tensor(tensor[i], scalar) for i in range(len(tensor))] 270 | else: 271 | return tensor * scalar 272 | 273 | 274 | # From softlearning repo 275 | def flatten(unflattened, parent_key="", separator="/"): 276 | items = [] 277 | for k, v in unflattened.items(): 278 | if separator in k: 279 | raise ValueError("Found separator ({}) in key ({})".format(separator, k)) 280 | new_key = parent_key + separator + k if parent_key else k 281 | if isinstance(v, collections.abc.MutableMapping) and v: 282 | items.extend(flatten(v, new_key, separator=separator).items()) 283 | else: 284 | items.append((new_key, v)) 285 | 286 | return OrderedDict(items) 287 | 288 | 289 | # From softlearning repo 290 | def unflatten(flattened, separator="."): 291 | result = {} 292 | for key, value in flattened.items(): 293 | parts = key.split(separator) 294 | d = result 295 | for part in parts[:-1]: 296 | if part not in d: 297 | d[part] = {} 298 | d = d[part] 299 | d[parts[-1]] = value 300 | 301 | return result 302 | 303 | 304 | # from https://github.com/MishaLaskin/rad/blob/master/utils.py 305 | def center_crop(img, out=84): 306 | """ 307 | args: 308 | img: np.array shape (C,H,W) 309 | out: output size (e.g. 84) 310 | returns np.array shape (1,C,H,W) 311 | """ 312 | h, w = img.shape[1:] 313 | new_h, new_w = out, out 314 | 315 | top = (h - new_h) // 2 316 | left = (w - new_w) // 2 317 | 318 | img = img[:, top : top + new_h, left : left + new_w] 319 | img = np.expand_dims(img, axis=0) 320 | return img 321 | 322 | # from https://github.com/MishaLaskin/rad/blob/master/utils.py 323 | def center_crop_images(image, out=84): 324 | """ 325 | args: 326 | image: np.array shape (B,C,H,W) 327 | out: output size (e.g. 84) 328 | returns np.array shape (B,C,H,W) 329 | """ 330 | h, w = image.shape[2:] 331 | new_h, new_w = out, out 332 | 333 | top = (h - new_h) // 2 334 | left = (w - new_w) // 2 335 | 336 | image = image[:, :, top:top + new_h, left:left + new_w] 337 | return image 338 | 339 | 340 | # from https://github.com/MishaLaskin/rad/blob/master/data_augs.py 341 | def random_crop(imgs, out=84): 342 | """ 343 | args: 344 | imgs: np.array shape (B,C,H,W) 345 | out: output size (e.g. 84) 346 | returns np.array 347 | """ 348 | b, c, h, w = imgs.shape 349 | crop_max = h - out + 1 350 | w1 = np.random.randint(0, crop_max, b) 351 | h1 = np.random.randint(0, crop_max, b) 352 | cropped = np.empty((b, c, out, out), dtype=imgs.dtype) 353 | for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)): 354 | cropped[i] = img[:, h11 : h11 + out, w11 : w11 + out] 355 | return cropped 356 | -------------------------------------------------------------------------------- /utils/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions to make a vector environment.
3 | 4 | Code adapted from 5 | https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py 6 | """ 7 | 8 | import multiprocessing as mp 9 | 10 | import numpy as np 11 | 12 | from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars 13 | 14 | 15 | def worker(remote, parent_remote, env_fn_wrappers): 16 | def step_env(env, action): 17 | ob, reward, done, info = env.step(action) 18 | if done: 19 | ob = env.reset() 20 | return ob, reward, done, info 21 | 22 | parent_remote.close() 23 | envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x] 24 | try: 25 | while True: 26 | cmd, data = remote.recv() 27 | if cmd == "step": 28 | remote.send([step_env(env, action) for env, action in zip(envs, data)]) 29 | elif cmd == "reset": 30 | remote.send([env.reset() for env in envs]) 31 | elif cmd == "render": 32 | remote.send([env.render(mode="rgb_array") for env in envs]) 33 | elif cmd == "close": 34 | remote.close() 35 | break 36 | elif cmd == "get_spaces_spec": 37 | remote.send( 38 | CloudpickleWrapper( 39 | (envs[0].observation_space, envs[0].action_space, envs[0].spec) 40 | ) 41 | ) 42 | else: 43 | raise NotImplementedError 44 | except KeyboardInterrupt: 45 | print("SubprocVecEnv worker: got KeyboardInterrupt") 46 | finally: 47 | for env in envs: 48 | env.close() 49 | 50 | 51 | class SubprocVecEnv(VecEnv): 52 | """ 53 | VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes. 54 | Recommended to use when num_envs > 1 and step() can be a bottleneck. 55 | """ 56 | 57 | def __init__(self, env_fns, spaces=None, context="spawn", in_series=1): 58 | """ 59 | Arguments: 60 | 61 | env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable 62 | in_series: number of environments to run in series in a single process 63 | (e.g.
when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series) 64 | """ 65 | self.waiting = False 66 | self.closed = False 67 | self.in_series = in_series 68 | nenvs = len(env_fns) 69 | assert ( 70 | nenvs % in_series == 0 71 | ), "Number of envs must be divisible by number of envs to run in series" 72 | self.nremotes = nenvs // in_series 73 | env_fns = np.array_split(env_fns, self.nremotes) 74 | ctx = mp.get_context(context) 75 | self.remotes, self.work_remotes = zip( 76 | *[ctx.Pipe() for _ in range(self.nremotes)] 77 | ) 78 | self.ps = [ 79 | ctx.Process( 80 | target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)) 81 | ) 82 | for (work_remote, remote, env_fn) in zip( 83 | self.work_remotes, self.remotes, env_fns 84 | ) 85 | ] 86 | for p in self.ps: 87 | p.daemon = ( 88 | True # if the main process crashes, we should not cause things to hang 89 | ) 90 | with clear_mpi_env_vars(): 91 | p.start() 92 | for remote in self.work_remotes: 93 | remote.close() 94 | 95 | self.remotes[0].send(("get_spaces_spec", None)) 96 | observation_space, action_space, self.spec = self.remotes[0].recv().x 97 | self.viewer = None 98 | VecEnv.__init__(self, nenvs, observation_space, action_space) 99 | 100 | def step_async(self, actions): 101 | self._assert_not_closed() 102 | actions = np.array_split(actions, self.nremotes) 103 | for remote, action in zip(self.remotes, actions): 104 | remote.send(("step", action)) 105 | self.waiting = True 106 | 107 | def step_wait(self): 108 | self._assert_not_closed() 109 | results = [remote.recv() for remote in self.remotes] 110 | results = _flatten_list(results) 111 | self.waiting = False 112 | obs, rews, dones, infos = zip(*results) 113 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos 114 | 115 | def reset(self): 116 | self._assert_not_closed() 117 | for remote in self.remotes: 118 | remote.send(("reset", None)) 119 | obs = [remote.recv() for remote in self.remotes] 120 | obs = _flatten_list(obs) 121 | return _flatten_obs(obs) 122 | 123 | def close_extras(self): 124 | self.closed = True 125 | if self.waiting: 126 | for remote in self.remotes: 127 | remote.recv() 128 | for remote in self.remotes: 129 | remote.send(("close", None)) 130 | for p in self.ps: 131 | p.join() 132 | 133 | def get_images(self): 134 | self._assert_not_closed() 135 | for pipe in self.remotes: 136 | pipe.send(("render", None)) 137 | imgs = [pipe.recv() for pipe in self.remotes] 138 | imgs = _flatten_list(imgs) 139 | return imgs 140 | 141 | def _assert_not_closed(self): 142 | assert ( 143 | not self.closed 144 | ), "Trying to operate on a SubprocVecEnv after calling close()" 145 | 146 | def __del__(self): 147 | if not self.closed: 148 | self.close() 149 | 150 | 151 | def _flatten_obs(obs): 152 | assert isinstance(obs, (list, tuple)) 153 | assert len(obs) > 0 154 | 155 | if isinstance(obs[0], dict): 156 | keys = obs[0].keys() 157 | return {k: np.stack([o[k] for o in obs]) for k in keys} 158 | else: 159 | return np.stack(obs) 160 | 161 | 162 | def _flatten_list(l): 163 | assert isinstance(l, (list, tuple)) 164 | assert len(l) > 0 165 | assert all([len(l_) > 0 for l_ in l]) 166 | 167 | return [l__ for l_ in l for l__ in l_] 168 | -------------------------------------------------------------------------------- /utils/vec_env.py: -------------------------------------------------------------------------------- 1 | """ VecEnv from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_env.py """ 2 | 3 | import contextlib 4 
| import os 5 | from abc import ABC, abstractmethod 6 | 7 | import numpy as np 8 | 9 | 10 | def tile_images(img_nhwc): 11 | """ 12 | Tile N images into one big PxQ image 13 | (P,Q) are chosen to be as close as possible, and if N 14 | is square, then P=Q. 15 | input: img_nhwc, list or array of images, ndim=4 once turned into array 16 | n = batch index, h = height, w = width, c = channel 17 | returns: 18 | bigim_HWc, ndarray with ndim=3 19 | """ 20 | img_nhwc = np.asarray(img_nhwc) 21 | N, h, w, c = img_nhwc.shape 22 | H = int(np.ceil(np.sqrt(N))) 23 | W = int(np.ceil(float(N) / H)) 24 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(N, H * W)]) 25 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 26 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 27 | img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c) 28 | return img_Hh_Ww_c 29 | 30 | 31 | class AlreadySteppingError(Exception): 32 | """ 33 | Raised when an asynchronous step is running while 34 | step_async() is called again. 35 | """ 36 | 37 | def __init__(self): 38 | msg = "already running an async step" 39 | Exception.__init__(self, msg) 40 | 41 | 42 | class NotSteppingError(Exception): 43 | """ 44 | Raised when an asynchronous step is not running but 45 | step_wait() is called. 46 | """ 47 | 48 | def __init__(self): 49 | msg = "not running an async step" 50 | Exception.__init__(self, msg) 51 | 52 | 53 | class VecEnv(ABC): 54 | """ 55 | An abstract asynchronous, vectorized environment. 56 | Used to batch data from multiple copies of an environment, so that 57 | each observation becomes a batch of observations, and the expected action is a batch of actions to 58 | be applied per-environment. 59 | """ 60 | 61 | closed = False 62 | viewer = None 63 | 64 | metadata = {"render.modes": ["human", "rgb_array"]} 65 | 66 | def __init__(self, num_envs, observation_space, action_space): 67 | self.num_envs = num_envs 68 | self.observation_space = observation_space 69 | self.action_space = action_space 70 | 71 | @abstractmethod 72 | def reset(self): 73 | """ 74 | Reset all the environments and return an array of 75 | observations, or a dict of observation arrays. 76 | 77 | If step_async is still doing work, that work will 78 | be cancelled and step_wait() should not be called 79 | until step_async() is invoked again. 80 | """ 81 | pass 82 | 83 | @abstractmethod 84 | def step_async(self, actions): 85 | """ 86 | Tell all the environments to start taking a step 87 | with the given actions. 88 | Call step_wait() to get the results of the step. 89 | 90 | You should not call this if a step_async run is 91 | already pending. 92 | """ 93 | pass 94 | 95 | @abstractmethod 96 | def step_wait(self): 97 | """ 98 | Wait for the step taken with step_async(). 99 | 100 | Returns (obs, rews, dones, infos): 101 | - obs: an array of observations, or a dict of 102 | arrays of observations. 103 | - rews: an array of rewards 104 | - dones: an array of "episode done" booleans 105 | - infos: a sequence of info objects 106 | """ 107 | pass 108 | 109 | def close_extras(self): 110 | """ 111 | Clean up the extra resources, beyond what's in this base class. 112 | Only runs when not self.closed. 113 | """ 114 | pass 115 | 116 | def close(self): 117 | if self.closed: 118 | return 119 | if self.viewer is not None: 120 | self.viewer.close() 121 | self.close_extras() 122 | self.closed = True 123 | 124 | def step(self, actions): 125 | """ 126 | Step the environments synchronously. 127 | 128 | This is available for backwards compatibility.
129 | """ 130 | self.step_async(actions) 131 | return self.step_wait() 132 | 133 | def render(self, mode="human"): 134 | imgs = self.get_images() 135 | bigimg = tile_images(imgs) 136 | if mode == "human": 137 | self.get_viewer().imshow(bigimg) 138 | return self.get_viewer().isopen 139 | elif mode == "rgb_array": 140 | return bigimg 141 | else: 142 | raise NotImplementedError 143 | 144 | def get_images(self): 145 | """ 146 | Return RGB images from each environment 147 | """ 148 | raise NotImplementedError 149 | 150 | @property 151 | def unwrapped(self): 152 | if isinstance(self, VecEnvWrapper): 153 | return self.venv.unwrapped 154 | else: 155 | return self 156 | 157 | def get_viewer(self): 158 | if self.viewer is None: 159 | from gym.envs.classic_control import rendering 160 | 161 | self.viewer = rendering.SimpleImageViewer() 162 | return self.viewer 163 | 164 | 165 | class VecEnvWrapper(VecEnv): 166 | """ 167 | An environment wrapper that applies to an entire batch 168 | of environments at once. 169 | """ 170 | 171 | def __init__(self, venv, observation_space=None, action_space=None): 172 | self.venv = venv 173 | super().__init__( 174 | num_envs=venv.num_envs, 175 | observation_space=observation_space or venv.observation_space, 176 | action_space=action_space or venv.action_space, 177 | ) 178 | 179 | def step_async(self, actions): 180 | self.venv.step_async(actions) 181 | 182 | @abstractmethod 183 | def reset(self): 184 | pass 185 | 186 | @abstractmethod 187 | def step_wait(self): 188 | pass 189 | 190 | def close(self): 191 | return self.venv.close() 192 | 193 | def render(self, mode="human"): 194 | return self.venv.render(mode=mode) 195 | 196 | def get_images(self): 197 | return self.venv.get_images() 198 | 199 | def __getattr__(self, name): 200 | if name.startswith("_"): 201 | raise AttributeError( 202 | "attempted to get missing private attribute '{}'".format(name) 203 | ) 204 | return getattr(self.venv, name) 205 | 206 | 207 | class VecEnvObservationWrapper(VecEnvWrapper): 208 | @abstractmethod 209 | def process(self, obs): 210 | pass 211 | 212 | def reset(self): 213 | obs = self.venv.reset() 214 | return self.process(obs) 215 | 216 | def step_wait(self): 217 | obs, rews, dones, infos = self.venv.step_wait() 218 | return self.process(obs), rews, dones, infos 219 | 220 | 221 | class CloudpickleWrapper(object): 222 | """ 223 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 224 | """ 225 | 226 | def __init__(self, x): 227 | self.x = x 228 | 229 | def __getstate__(self): 230 | import cloudpickle 231 | 232 | return cloudpickle.dumps(self.x) 233 | 234 | def __setstate__(self, ob): 235 | import pickle 236 | 237 | self.x = pickle.loads(ob) 238 | 239 | 240 | @contextlib.contextmanager 241 | def clear_mpi_env_vars(): 242 | """ 243 | from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. 244 | This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing 245 | Processes. 
246 | """ 247 | removed_environment = {} 248 | for k, v in list(os.environ.items()): 249 | for prefix in ["OMPI_", "PMI_"]: 250 | if k.startswith(prefix): 251 | removed_environment[k] = v 252 | del os.environ[k] 253 | try: 254 | yield 255 | finally: 256 | os.environ.update(removed_environment) 257 | --------------------------------------------------------------------------------