├── .gitignore
├── LICENSE
├── README.md
├── configs.yaml
├── dreamer.py
├── exploration.py
├── figs
│   ├── bank_heist_logs.png
│   ├── dreamer-test.png
│   ├── freeway_logs.png
│   └── walker_walk_logs.png
├── models.py
├── networks.py
├── requirements.txt
├── tools.py
└── wrappers.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | run.sh 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jaesik Yoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dreamer-torch 2 | 3 | PyTorch version of Dreamer, which follows the original TensorFlow 2 implementation (https://github.com/danijar/dreamerv2/tree/faf9e4c606e735f32c8b971a3092877265e49cc2). 4 | 5 | Due to limited compute resources, it was compared against the official TensorFlow 2 implementation on 10 tasks (5 DMC / 5 Atari tasks), as shown below. 6 | 7 | ![alt text](https://github.com/jsikyoon/dreamer-torch/blob/main/figs/dreamer-test.png?raw=true) 8 | 9 | As the figure shows, this implementation matches Dreamer on almost all tasks, compared both with the results reported in the paper and with those obtained by running the official code. 10 | 11 | For freeway, learning is slower than the original Dreamer, possibly because of the random seed; I am still investigating the cause. 12 | 13 | The logs below are from TensorBoard; you can also browse them online via [Tensorboard logs](https://tensorboard.dev/experiment/QsJYF8DaTaaLiPJFvqp02Q/#scalars).
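To inspect a run without TensorBoard, you can also read the `metrics.jsonl` file that `tools.Logger` writes into the log directory; pandas and matplotlib are already listed in `requirements.txt`. A minimal sketch, assuming a finished run whose log directory is `./logdir/dmc_walker_walk` (the path is only an example):

```python
# Plot the evaluation return curve from a run's metrics.jsonl (written by tools.Logger).
import pathlib

import matplotlib.pyplot as plt
import pandas as pd

logdir = pathlib.Path('./logdir/dmc_walker_walk')  # example path, adjust to your run
df = pd.read_json(logdir / 'metrics.jsonl', lines=True)
evals = df.dropna(subset=['eval_return'])  # rows logged after each evaluation phase
plt.plot(evals['step'], evals['eval_return'])
plt.xlabel('environment steps')
plt.ylabel('eval_return')
plt.show()
```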
14 | 15 | - DMC walker walk task 16 | 17 | ![alt text](https://github.com/jsikyoon/dreamer-torch/blob/main/figs/walker_walk_logs.png?raw=true) 18 | 19 | 20 | - Atari bank heist task 21 | 22 | ![alt text](https://github.com/jsikyoon/dreamer-torch/blob/main/figs/bank_heist_logs.png?raw=true) 23 | 24 | 25 | - Atari freeway task 26 | 27 | ![alt text](https://github.com/jsikyoon/dreamer-torch/blob/main/figs/freeway_logs.png?raw=true) 28 | 29 | ## How to use 30 | The required packages are listed in [requirements.txt](https://github.com/jsikyoon/dreamer-torch/blob/main/requirements.txt). 31 | 32 | The command line is exactly the same as for the official code; see [the official instructions](https://github.com/danijar/dreamerv2/tree/faf9e4c606e735f32c8b971a3092877265e49cc2#instructions). For example, `python3 dreamer.py --configs defaults dmc --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk` trains on the DMC walker walk task. 33 | 34 | ## Contact 35 | Any feedback is welcome! Please open an issue on this repository or send an email to Jaesik Yoon (jaesik.yoon.kr@gmail.com). 36 | -------------------------------------------------------------------------------- /configs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | logdir: null 4 | traindir: null 5 | evaldir: null 6 | offline_traindir: '' 7 | offline_evaldir: '' 8 | seed: 0 9 | steps: 1e7 10 | eval_every: 1e4 11 | log_every: 1e4 12 | reset_every: 0 13 | #gpu_growth: True 14 | device: 'cuda:0' 15 | precision: 16 16 | debug: False 17 | expl_gifs: False 18 | 19 | # Environment 20 | task: 'dmc_walker_walk' 21 | size: [64, 64] 22 | envs: 1 23 | action_repeat: 2 24 | time_limit: 1000 25 | grayscale: False 26 | prefill: 2500 27 | eval_noise: 0.0 28 | clip_rewards: 'identity' 29 | 30 | # Model 31 | dyn_cell: 'gru' 32 | dyn_hidden: 200 33 | dyn_deter: 200 34 | dyn_stoch: 50 35 | dyn_discrete: 0 36 | dyn_input_layers: 1 37 | dyn_output_layers: 1 38 | dyn_rec_depth: 1 39 | dyn_shared: False 40 | dyn_mean_act: 'none' 41 | dyn_std_act: 'sigmoid2' 42 | dyn_min_std: 0.1 43 | dyn_temp_post: True 44 | grad_heads: ['image', 'reward'] 45 | units: 400 46 | reward_layers: 2 47 | discount_layers: 3 48 | value_layers: 3 49 | actor_layers: 4 50 | act: 'ELU' 51 | cnn_depth: 32 52 | encoder_kernels: [4, 4, 4, 4] 53 | decoder_kernels: [5, 5, 6, 6] 54 | decoder_thin: True 55 | value_head: 'normal' 56 | kl_scale: '1.0' 57 | kl_balance: '0.8' 58 | kl_free: '1.0' 59 | kl_forward: False 60 | pred_discount: False 61 | discount_scale: 1.0 62 | reward_scale: 1.0 63 | weight_decay: 0.0 64 | 65 | # Training 66 | batch_size: 50 67 | batch_length: 50 68 | train_every: 5 69 | train_steps: 1 70 | pretrain: 100 71 | model_lr: 3e-4 72 | value_lr: 8e-5 73 | actor_lr: 8e-5 74 | opt_eps: 1e-5 75 | grad_clip: 100 76 | value_grad_clip: 100 77 | actor_grad_clip: 100 78 | dataset_size: 0 79 | oversample_ends: False 80 | slow_value_target: True 81 | slow_actor_target: True 82 | slow_target_update: 100 83 | slow_target_fraction: 1 84 | opt: 'adam' 85 | 86 | # Behavior.
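  # The settings below configure the actor-critic learned in imagination
  # (models.ImagBehavior): lambda-returns over imag_horizon steps are built from
  # discount and discount_lambda, and imag_gradient chooses how the actor is
  # trained ('dynamics' backpropagates through the imagined rollout, 'reinforce'
  # uses the score-function estimator, 'both' blends the two with
  # imag_gradient_mix).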
87 | discount: 0.99 88 | discount_lambda: 0.95 89 | imag_horizon: 15 90 | imag_gradient: 'dynamics' 91 | imag_gradient_mix: '0.1' 92 | imag_sample: True 93 | actor_dist: 'trunc_normal' 94 | actor_entropy: '1e-4' 95 | actor_state_entropy: 0.0 96 | actor_init_std: 1.0 97 | actor_min_std: 0.1 98 | actor_disc: 5 99 | actor_temp: 0.1 100 | actor_outscale: 0.0 101 | expl_amount: 0.0 102 | eval_state_mean: False 103 | collect_dyn_sample: True 104 | behavior_stop_grad: True 105 | value_decay: 0.0 106 | future_entropy: False 107 | 108 | # Exploration 109 | expl_behavior: 'greedy' 110 | expl_until: 0 111 | expl_extr_scale: 0.0 112 | expl_intr_scale: 1.0 113 | disag_target: 'stoch' 114 | disag_log: True 115 | disag_models: 10 116 | disag_offset: 1 117 | disag_layers: 4 118 | disag_units: 400 119 | disag_action_cond: False 120 | 121 | 122 | dmlab: 123 | 124 | # General 125 | task: 'dmlab_rooms_watermaze' 126 | steps: 2e8 127 | eval_every: 1e5 128 | log_every: 1e4 129 | prefill: 50000 130 | dataset_size: 2e6 131 | pretrain: 0 132 | 133 | # Environment 134 | time_limit: 108000 # 30 minutes of game play. 135 | #grayscale: True 136 | action_repeat: 4 137 | eval_noise: 0.0 138 | train_every: 16 139 | train_steps: 1 140 | clip_rewards: 'tanh' 141 | 142 | # Model 143 | grad_heads: ['image', 'reward', 'discount'] 144 | dyn_cell: 'gru_layer_norm' 145 | pred_discount: True 146 | cnn_depth: 48 147 | dyn_deter: 600 148 | dyn_hidden: 600 149 | dyn_stoch: 32 150 | dyn_discrete: 32 151 | reward_layers: 4 152 | discount_layers: 4 153 | value_layers: 4 154 | actor_layers: 4 155 | 156 | # Behavior 157 | actor_dist: 'onehot' 158 | actor_entropy: 'linear(3e-3,3e-4,2.5e6)' 159 | expl_amount: 0.0 160 | discount: 0.999 161 | imag_gradient: 'both' 162 | imag_gradient_mix: 'linear(0.1,0,2.5e6)' 163 | 164 | # Training 165 | discount_scale: 5.0 166 | reward_scale: 1 167 | weight_decay: 1e-6 168 | model_lr: 2e-4 169 | kl_scale: 0.1 170 | kl_free: 0.0 171 | actor_lr: 4e-5 172 | value_lr: 1e-4 173 | oversample_ends: True 174 | 175 | 176 | atari: 177 | 178 | # General 179 | task: 'atari_pong' 180 | steps: 2e8 181 | eval_every: 1e5 182 | log_every: 1e4 183 | prefill: 50000 184 | dataset_size: 2e6 185 | pretrain: 0 186 | 187 | # Environment 188 | time_limit: 108000 # 30 minutes of game play. 
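  # time_limit is counted in raw frames (60 FPS * 60 s * 30 min = 108000);
  # dreamer.py divides it by action_repeat before applying the TimeLimit
  # wrapper, so with action_repeat: 4 an episode is capped at 27000 agent steps.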
189 | grayscale: True 190 | action_repeat: 4 191 | eval_noise: 0.0 192 | train_every: 16 193 | train_steps: 1 194 | clip_rewards: 'tanh' 195 | 196 | # Model 197 | grad_heads: ['image', 'reward', 'discount'] 198 | dyn_cell: 'gru_layer_norm' 199 | pred_discount: True 200 | cnn_depth: 48 201 | dyn_deter: 600 202 | dyn_hidden: 600 203 | dyn_stoch: 32 204 | dyn_discrete: 32 205 | reward_layers: 4 206 | discount_layers: 4 207 | value_layers: 4 208 | actor_layers: 4 209 | 210 | # Behavior 211 | actor_dist: 'onehot' 212 | actor_entropy: 'linear(3e-3,3e-4,2.5e6)' 213 | expl_amount: 0.0 214 | discount: 0.999 215 | imag_gradient: 'both' 216 | imag_gradient_mix: 'linear(0.1,0,2.5e6)' 217 | 218 | # Training 219 | discount_scale: 5.0 220 | reward_scale: 1 221 | weight_decay: 1e-6 222 | model_lr: 2e-4 223 | kl_scale: 0.1 224 | kl_free: 0.0 225 | actor_lr: 4e-5 226 | value_lr: 1e-4 227 | oversample_ends: True 228 | 229 | dmc: 230 | 231 | # General 232 | task: 'dmc_walker_walk' 233 | steps: 1e7 234 | eval_every: 1e4 235 | log_every: 1e4 236 | prefill: 2500 237 | dataset_size: 0 238 | pretrain: 100 239 | 240 | # Environment 241 | time_limit: 1000 242 | action_repeat: 2 243 | train_every: 5 244 | train_steps: 1 245 | 246 | # Model 247 | grad_heads: ['image', 'reward'] 248 | dyn_cell: 'gru_layer_norm' 249 | pred_discount: False 250 | cnn_depth: 32 251 | dyn_deter: 200 252 | dyn_stoch: 50 253 | dyn_discrete: 0 254 | reward_layers: 2 255 | discount_layers: 3 256 | value_layers: 3 257 | actor_layers: 4 258 | 259 | # Behavior 260 | actor_dist: 'trunc_normal' 261 | expl_amount: 0.0 262 | actor_entropy: '1e-4' 263 | discount: 0.99 264 | imag_gradient: 'dynamics' 265 | imag_gradient_mix: 1.0 266 | 267 | # Training 268 | reward_scale: 2 269 | weight_decay: 0.0 270 | model_lr: 3e-4 271 | value_lr: 8e-5 272 | actor_lr: 8e-5 273 | opt_eps: 1e-5 274 | kl_free: '1.0' 275 | kl_scale: '1.0' 276 | 277 | debug: 278 | 279 | debug: True 280 | pretrain: 1 281 | prefill: 1 282 | train_steps: 1 283 | batch_size: 10 284 | batch_length: 20 285 | 286 | -------------------------------------------------------------------------------- /dreamer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import functools 4 | import os 5 | import pathlib 6 | import sys 7 | import warnings 8 | 9 | os.environ['MUJOCO_GL'] = 'egl' 10 | 11 | import numpy as np 12 | import ruamel.yaml as yaml 13 | 14 | sys.path.append(str(pathlib.Path(__file__).parent)) 15 | 16 | import exploration as expl 17 | import models 18 | import tools 19 | import wrappers 20 | 21 | import torch 22 | from torch import nn 23 | from torch import distributions as torchd 24 | to_np = lambda x: x.detach().cpu().numpy() 25 | 26 | 27 | class Dreamer(nn.Module): 28 | 29 | def __init__(self, config, logger, dataset): 30 | super(Dreamer, self).__init__() 31 | self._config = config 32 | self._logger = logger 33 | self._should_log = tools.Every(config.log_every) 34 | self._should_train = tools.Every(config.train_every) 35 | self._should_pretrain = tools.Once() 36 | self._should_reset = tools.Every(config.reset_every) 37 | self._should_expl = tools.Until(int( 38 | config.expl_until / config.action_repeat)) 39 | self._metrics = {} 40 | self._step = count_steps(config.traindir) 41 | # Schedules. 
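    # actor_entropy, actor_state_entropy and imag_gradient_mix may be plain
    # constants or schedule strings such as 'linear(3e-3,3e-4,2.5e6)' in
    # configs.yaml; wrapping them in lambdas lets tools.schedule re-evaluate
    # them against the current self._step every time they are called.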
42 | config.actor_entropy = ( 43 | lambda x=config.actor_entropy: tools.schedule(x, self._step)) 44 | config.actor_state_entropy = ( 45 | lambda x=config.actor_state_entropy: tools.schedule(x, self._step)) 46 | config.imag_gradient_mix = ( 47 | lambda x=config.imag_gradient_mix: tools.schedule(x, self._step)) 48 | self._dataset = dataset 49 | self._wm = models.WorldModel(self._step, config) 50 | self._task_behavior = models.ImagBehavior( 51 | config, self._wm, config.behavior_stop_grad) 52 | reward = lambda f, s, a: self._wm.heads['reward'](f).mean 53 | self._expl_behavior = dict( 54 | greedy=lambda: self._task_behavior, 55 | random=lambda: expl.Random(config), 56 | plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward), 57 | )[config.expl_behavior]() 58 | 59 | def __call__(self, obs, reset, state=None, reward=None, training=True): 60 | step = self._step 61 | if self._should_reset(step): 62 | state = None 63 | if state is not None and reset.any(): 64 | mask = 1 - reset 65 | for key in state[0].keys(): 66 | for i in range(state[0][key].shape[0]): 67 | state[0][key][i] *= mask[i] 68 | for i in range(len(state[1])): 69 | state[1][i] *= mask[i] 70 | if training and self._should_train(step): 71 | steps = ( 72 | self._config.pretrain if self._should_pretrain() 73 | else self._config.train_steps) 74 | for _ in range(steps): 75 | self._train(next(self._dataset)) 76 | if self._should_log(step): 77 | for name, values in self._metrics.items(): 78 | self._logger.scalar(name, float(np.mean(values))) 79 | self._metrics[name] = [] 80 | openl = self._wm.video_pred(next(self._dataset)) 81 | self._logger.video('train_openl', to_np(openl)) 82 | self._logger.write(fps=True) 83 | 84 | policy_output, state = self._policy(obs, state, training) 85 | 86 | if training: 87 | self._step += len(reset) 88 | self._logger.step = self._config.action_repeat * self._step 89 | return policy_output, state 90 | 91 | def _policy(self, obs, state, training): 92 | if state is None: 93 | batch_size = len(obs['image']) 94 | latent = self._wm.dynamics.initial(len(obs['image'])) 95 | action = torch.zeros((batch_size, self._config.num_actions)).to(self._config.device) 96 | else: 97 | latent, action = state 98 | embed = self._wm.encoder(self._wm.preprocess(obs)) 99 | latent, _ = self._wm.dynamics.obs_step( 100 | latent, action, embed, self._config.collect_dyn_sample) 101 | if self._config.eval_state_mean: 102 | latent['stoch'] = latent['mean'] 103 | feat = self._wm.dynamics.get_feat(latent) 104 | if not training: 105 | actor = self._task_behavior.actor(feat) 106 | action = actor.mode() 107 | elif self._should_expl(self._step): 108 | actor = self._expl_behavior.actor(feat) 109 | action = actor.sample() 110 | else: 111 | actor = self._task_behavior.actor(feat) 112 | action = actor.sample() 113 | logprob = actor.log_prob(action) 114 | latent = {k: v.detach() for k, v in latent.items()} 115 | action = action.detach() 116 | if self._config.actor_dist == 'onehot_gumble': 117 | action = torch.one_hot(torch.argmax(action, dim=-1), self._config.num_actions) 118 | action = self._exploration(action, training) 119 | policy_output = {'action': action, 'logprob': logprob} 120 | state = (latent, action) 121 | return policy_output, state 122 | 123 | def _exploration(self, action, training): 124 | amount = self._config.expl_amount if training else self._config.eval_noise 125 | if amount == 0: 126 | return action 127 | if 'onehot' in self._config.actor_dist: 128 | probs = amount / self._config.num_actions + (1 - amount) * action 129 | return 
tools.OneHotDist(probs=probs).sample() 130 | else: 131 | return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1) 132 | raise NotImplementedError(self._config.action_noise) 133 | 134 | def _train(self, data): 135 | metrics = {} 136 | post, context, mets = self._wm._train(data) 137 | metrics.update(mets) 138 | start = post 139 | if self._config.pred_discount: # Last step could be terminal. 140 | start = {k: v[:, :-1] for k, v in post.items()} 141 | context = {k: v[:, :-1] for k, v in context.items()} 142 | reward = lambda f, s, a: self._wm.heads['reward']( 143 | self._wm.dynamics.get_feat(s)).mode() 144 | metrics.update(self._task_behavior._train(start, reward)[-1]) 145 | if self._config.expl_behavior != 'greedy': 146 | if self._config.pred_discount: 147 | data = {k: v[:, :-1] for k, v in data.items()} 148 | mets = self._expl_behavior.train(start, context, data)[-1] 149 | metrics.update({'expl_' + key: value for key, value in mets.items()}) 150 | for name, value in metrics.items(): 151 | if not name in self._metrics.keys(): 152 | self._metrics[name] = [value] 153 | else: 154 | self._metrics[name].append(value) 155 | 156 | 157 | def count_steps(folder): 158 | return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in folder.glob('*.npz')) 159 | 160 | 161 | def make_dataset(episodes, config): 162 | generator = tools.sample_episodes( 163 | episodes, config.batch_length, config.oversample_ends) 164 | dataset = tools.from_generator(generator, config.batch_size) 165 | return dataset 166 | 167 | 168 | def make_env(config, logger, mode, train_eps, eval_eps): 169 | suite, task = config.task.split('_', 1) 170 | if suite == 'dmc': 171 | env = wrappers.DeepMindControl(task, config.action_repeat, config.size) 172 | env = wrappers.NormalizeActions(env) 173 | elif suite == 'atari': 174 | env = wrappers.Atari( 175 | task, config.action_repeat, config.size, 176 | grayscale=config.grayscale, 177 | life_done=False and ('train' in mode), 178 | sticky_actions=True, 179 | all_actions=True) 180 | env = wrappers.OneHotAction(env) 181 | elif suite == 'dmlab': 182 | env = wrappers.DeepMindLabyrinth( 183 | task, 184 | mode if 'train' in mode else 'test', 185 | config.action_repeat) 186 | env = wrappers.OneHotAction(env) 187 | else: 188 | raise NotImplementedError(suite) 189 | env = wrappers.TimeLimit(env, config.time_limit) 190 | env = wrappers.SelectAction(env, key='action') 191 | if (mode == 'train') or (mode == 'eval'): 192 | callbacks = [functools.partial( 193 | process_episode, config, logger, mode, train_eps, eval_eps)] 194 | env = wrappers.CollectDataset(env, callbacks) 195 | env = wrappers.RewardObs(env) 196 | return env 197 | 198 | 199 | def process_episode(config, logger, mode, train_eps, eval_eps, episode): 200 | directory = dict(train=config.traindir, eval=config.evaldir)[mode] 201 | cache = dict(train=train_eps, eval=eval_eps)[mode] 202 | filename = tools.save_episodes(directory, [episode])[0] 203 | length = len(episode['reward']) - 1 204 | score = float(episode['reward'].astype(np.float64).sum()) 205 | video = episode['image'] 206 | if mode == 'eval': 207 | cache.clear() 208 | if mode == 'train' and config.dataset_size: 209 | total = 0 210 | for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])): 211 | if total <= config.dataset_size - length: 212 | total += len(ep['reward']) - 1 213 | else: 214 | del cache[key] 215 | logger.scalar('dataset_size', total + length) 216 | cache[str(filename)] = episode 217 | print(f'{mode.title()} episode has {length} steps and return {score:.1f}.') 
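  # The scalars below are written under the mode prefix, so they appear as
  # train_return / eval_return, train_length / eval_length, etc. in
  # TensorBoard and in metrics.jsonl.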
218 | logger.scalar(f'{mode}_return', score) 219 | logger.scalar(f'{mode}_length', length) 220 | logger.scalar(f'{mode}_episodes', len(cache)) 221 | if mode == 'eval' or config.expl_gifs: 222 | logger.video(f'{mode}_policy', video[None]) 223 | logger.write() 224 | 225 | 226 | def main(config): 227 | logdir = pathlib.Path(config.logdir).expanduser() 228 | config.traindir = config.traindir or logdir / 'train_eps' 229 | config.evaldir = config.evaldir or logdir / 'eval_eps' 230 | config.steps //= config.action_repeat 231 | config.eval_every //= config.action_repeat 232 | config.log_every //= config.action_repeat 233 | config.time_limit //= config.action_repeat 234 | config.act = getattr(torch.nn, config.act) 235 | 236 | print('Logdir', logdir) 237 | logdir.mkdir(parents=True, exist_ok=True) 238 | config.traindir.mkdir(parents=True, exist_ok=True) 239 | config.evaldir.mkdir(parents=True, exist_ok=True) 240 | step = count_steps(config.traindir) 241 | logger = tools.Logger(logdir, config.action_repeat * step) 242 | 243 | print('Create envs.') 244 | if config.offline_traindir: 245 | directory = config.offline_traindir.format(**vars(config)) 246 | else: 247 | directory = config.traindir 248 | train_eps = tools.load_episodes(directory, limit=config.dataset_size) 249 | if config.offline_evaldir: 250 | directory = config.offline_evaldir.format(**vars(config)) 251 | else: 252 | directory = config.evaldir 253 | eval_eps = tools.load_episodes(directory, limit=1) 254 | make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps) 255 | train_envs = [make('train') for _ in range(config.envs)] 256 | eval_envs = [make('eval') for _ in range(config.envs)] 257 | acts = train_envs[0].action_space 258 | config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0] 259 | 260 | if not config.offline_traindir: 261 | prefill = max(0, config.prefill - count_steps(config.traindir)) 262 | print(f'Prefill dataset ({prefill} steps).') 263 | if hasattr(acts, 'discrete'): 264 | random_actor = tools.OneHotDist(torch.zeros_like(torch.Tensor(acts.low))[None]) 265 | else: 266 | random_actor = torchd.independent.Independent( 267 | torchd.uniform.Uniform(torch.Tensor(acts.low)[None], 268 | torch.Tensor(acts.high)[None]), 1) 269 | def random_agent(o, d, s, r): 270 | action = random_actor.sample() 271 | logprob = random_actor.log_prob(action) 272 | return {'action': action, 'logprob': logprob}, None 273 | tools.simulate(random_agent, train_envs, prefill) 274 | tools.simulate(random_agent, eval_envs, episodes=1) 275 | logger.step = config.action_repeat * count_steps(config.traindir) 276 | 277 | print('Simulate agent.') 278 | train_dataset = make_dataset(train_eps, config) 279 | eval_dataset = make_dataset(eval_eps, config) 280 | agent = Dreamer(config, logger, train_dataset).to(config.device) 281 | agent.requires_grad_(requires_grad=False) 282 | if (logdir / 'latest_model.pt').exists(): 283 | agent.load_state_dict(torch.load(logdir / 'latest_model.pt')) 284 | agent._should_pretrain._once = False 285 | 286 | state = None 287 | while agent._step < config.steps: 288 | logger.write() 289 | print('Start evaluation.') 290 | video_pred = agent._wm.video_pred(next(eval_dataset)) 291 | logger.video('eval_openl', to_np(video_pred)) 292 | eval_policy = functools.partial(agent, training=False) 293 | tools.simulate(eval_policy, eval_envs, episodes=1) 294 | print('Start training.') 295 | state = tools.simulate(agent, train_envs, config.eval_every, state=state) 296 | torch.save(agent.state_dict(), logdir / 'latest_model.pt') 297 
| for env in train_envs + eval_envs: 298 | try: 299 | env.close() 300 | except Exception: 301 | pass 302 | 303 | 304 | if __name__ == '__main__': 305 | parser = argparse.ArgumentParser() 306 | parser.add_argument('--configs', nargs='+', required=True) 307 | args, remaining = parser.parse_known_args() 308 | configs = yaml.safe_load( 309 | (pathlib.Path(sys.argv[0]).parent / 'configs.yaml').read_text()) 310 | defaults = {} 311 | for name in args.configs: 312 | defaults.update(configs[name]) 313 | parser = argparse.ArgumentParser() 314 | for key, value in sorted(defaults.items(), key=lambda x: x[0]): 315 | arg_type = tools.args_type(value) 316 | parser.add_argument(f'--{key}', type=arg_type, default=arg_type(value)) 317 | main(parser.parse_args(remaining)) 318 | -------------------------------------------------------------------------------- /exploration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import distributions as torchd 4 | 5 | import models 6 | import networks 7 | import tools 8 | 9 | 10 | class Random(nn.Module): 11 | 12 | def __init__(self, config): 13 | self._config = config 14 | 15 | def actor(self, feat): 16 | shape = feat.shape[:-1] + [self._config.num_actions] 17 | if self._config.actor_dist == 'onehot': 18 | return tools.OneHotDist(torch.zeros(shape)) 19 | else: 20 | ones = torch.ones(shape) 21 | return tools.ContDist(torchd.uniform.Uniform(-ones, ones)) 22 | 23 | def train(self, start, context): 24 | return None, {} 25 | 26 | 27 | #class Plan2Explore(tools.Module): 28 | class Plan2Explore(nn.Module): 29 | 30 | def __init__(self, config, world_model, reward=None): 31 | self._config = config 32 | self._reward = reward 33 | self._behavior = models.ImagBehavior(config, world_model) 34 | self.actor = self._behavior.actor 35 | stoch_size = config.dyn_stoch 36 | if config.dyn_discrete: 37 | stoch_size *= config.dyn_discrete 38 | size = { 39 | 'embed': 32 * config.cnn_depth, 40 | 'stoch': stoch_size, 41 | 'deter': config.dyn_deter, 42 | 'feat': config.dyn_stoch + config.dyn_deter, 43 | }[self._config.disag_target] 44 | kw = dict( 45 | inp_dim=config.dyn_stoch, # pytorch version 46 | shape=size, layers=config.disag_layers, units=config.disag_units, 47 | act=config.act) 48 | self._networks = [ 49 | networks.DenseHead(**kw) for _ in range(config.disag_models)] 50 | self._opt = tools.optimizer(config.opt, self.parameters(), 51 | config.model_lr, config.opt_eps, config.weight_decay) 52 | #self._opt = tools.Optimizer( 53 | # 'ensemble', config.model_lr, config.opt_eps, config.grad_clip, 54 | # config.weight_decay, opt=config.opt) 55 | 56 | def train(self, start, context, data): 57 | metrics = {} 58 | stoch = start['stoch'] 59 | if self._config.dyn_discrete: 60 | stoch = tf.reshape( 61 | stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1])) 62 | target = { 63 | 'embed': context['embed'], 64 | 'stoch': stoch, 65 | 'deter': start['deter'], 66 | 'feat': context['feat'], 67 | }[self._config.disag_target] 68 | inputs = context['feat'] 69 | if self._config.disag_action_cond: 70 | inputs = tf.concat([inputs, data['action']], -1) 71 | metrics.update(self._train_ensemble(inputs, target)) 72 | metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1]) 73 | return None, metrics 74 | 75 | def _intrinsic_reward(self, feat, state, action): 76 | inputs = feat 77 | if self._config.disag_action_cond: 78 | inputs = tf.concat([inputs, action], -1) 79 | preds = [head(inputs, 
tf.float32).mean() for head in self._networks] 80 | disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1) 81 | if self._config.disag_log: 82 | disag = tf.math.log(disag) 83 | reward = self._config.expl_intr_scale * disag 84 | if self._config.expl_extr_scale: 85 | reward += tf.cast(self._config.expl_extr_scale * self._reward( 86 | feat, state, action), tf.float32) 87 | return reward 88 | 89 | def _train_ensemble(self, inputs, targets): 90 | if self._config.disag_offset: 91 | targets = targets[:, self._config.disag_offset:] 92 | inputs = inputs[:, :-self._config.disag_offset] 93 | targets = tf.stop_gradient(targets) 94 | inputs = tf.stop_gradient(inputs) 95 | with tf.GradientTape() as tape: 96 | preds = [head(inputs) for head in self._networks] 97 | likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds] 98 | loss = -tf.cast(tf.reduce_sum(likes), tf.float32) 99 | metrics = self._opt(tape, loss, self._networks) 100 | return metrics 101 | -------------------------------------------------------------------------------- /figs/bank_heist_logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/dreamer-torch/7c2331acd4fa6196d140943e977f23fb177398b3/figs/bank_heist_logs.png -------------------------------------------------------------------------------- /figs/dreamer-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/dreamer-torch/7c2331acd4fa6196d140943e977f23fb177398b3/figs/dreamer-test.png -------------------------------------------------------------------------------- /figs/freeway_logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/dreamer-torch/7c2331acd4fa6196d140943e977f23fb177398b3/figs/freeway_logs.png -------------------------------------------------------------------------------- /figs/walker_walk_logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsikyoon/dreamer-torch/7c2331acd4fa6196d140943e977f23fb177398b3/figs/walker_walk_logs.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from PIL import ImageColor, Image, ImageDraw, ImageFont 5 | 6 | import networks 7 | import tools 8 | to_np = lambda x: x.detach().cpu().numpy() 9 | 10 | 11 | class WorldModel(nn.Module): 12 | 13 | def __init__(self, step, config): 14 | super(WorldModel, self).__init__() 15 | self._step = step 16 | self._use_amp = True if config.precision==16 else False 17 | self._config = config 18 | self.encoder = networks.ConvEncoder(config.grayscale, 19 | config.cnn_depth, config.act, config.encoder_kernels) 20 | if config.size[0] == 64 and config.size[1] == 64: 21 | embed_size = 2 ** (len(config.encoder_kernels)-1) * config.cnn_depth 22 | embed_size *= 2 * 2 23 | else: 24 | raise NotImplemented(f"{config.size} is not applicable now") 25 | self.dynamics = networks.RSSM( 26 | config.dyn_stoch, config.dyn_deter, config.dyn_hidden, 27 | config.dyn_input_layers, config.dyn_output_layers, 28 | config.dyn_rec_depth, config.dyn_shared, config.dyn_discrete, 29 | config.act, config.dyn_mean_act, config.dyn_std_act, 30 | config.dyn_temp_post, config.dyn_min_std, config.dyn_cell, 31 | config.num_actions, embed_size, 
config.device) 32 | self.heads = nn.ModuleDict() 33 | channels = (1 if config.grayscale else 3) 34 | shape = (channels,) + config.size 35 | if config.dyn_discrete: 36 | feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter 37 | else: 38 | feat_size = config.dyn_stoch + config.dyn_deter 39 | self.heads['image'] = networks.ConvDecoder( 40 | feat_size, # pytorch version 41 | config.cnn_depth, config.act, shape, config.decoder_kernels, 42 | config.decoder_thin) 43 | self.heads['reward'] = networks.DenseHead( 44 | feat_size, # pytorch version 45 | [], config.reward_layers, config.units, config.act) 46 | if config.pred_discount: 47 | self.heads['discount'] = networks.DenseHead( 48 | feat_size, # pytorch version 49 | [], config.discount_layers, config.units, config.act, dist='binary') 50 | for name in config.grad_heads: 51 | assert name in self.heads, name 52 | self._model_opt = tools.Optimizer( 53 | 'model', self.parameters(), config.model_lr, config.opt_eps, config.grad_clip, 54 | config.weight_decay, opt=config.opt, 55 | use_amp=self._use_amp) 56 | self._scales = dict( 57 | reward=config.reward_scale, discount=config.discount_scale) 58 | 59 | def _train(self, data): 60 | data = self.preprocess(data) 61 | 62 | with tools.RequiresGrad(self): 63 | with torch.cuda.amp.autocast(self._use_amp): 64 | embed = self.encoder(data) 65 | post, prior = self.dynamics.observe(embed, data['action']) 66 | kl_balance = tools.schedule(self._config.kl_balance, self._step) 67 | kl_free = tools.schedule(self._config.kl_free, self._step) 68 | kl_scale = tools.schedule(self._config.kl_scale, self._step) 69 | kl_loss, kl_value = self.dynamics.kl_loss( 70 | post, prior, self._config.kl_forward, kl_balance, kl_free, kl_scale) 71 | losses = {} 72 | likes = {} 73 | for name, head in self.heads.items(): 74 | grad_head = (name in self._config.grad_heads) 75 | feat = self.dynamics.get_feat(post) 76 | feat = feat if grad_head else feat.detach() 77 | pred = head(feat) 78 | like = pred.log_prob(data[name]) 79 | likes[name] = like 80 | losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) 81 | model_loss = sum(losses.values()) + kl_loss 82 | metrics = self._model_opt(model_loss, self.parameters()) 83 | 84 | metrics.update({f'{name}_loss': to_np(loss) for name, loss in losses.items()}) 85 | metrics['kl_balance'] = kl_balance 86 | metrics['kl_free'] = kl_free 87 | metrics['kl_scale'] = kl_scale 88 | metrics['kl'] = to_np(torch.mean(kl_value)) 89 | with torch.cuda.amp.autocast(self._use_amp): 90 | metrics['prior_ent'] = to_np(torch.mean(self.dynamics.get_dist(prior).entropy())) 91 | metrics['post_ent'] = to_np(torch.mean(self.dynamics.get_dist(post).entropy())) 92 | context = dict( 93 | embed=embed, feat=self.dynamics.get_feat(post), 94 | kl=kl_value, postent=self.dynamics.get_dist(post).entropy()) 95 | post = {k: v.detach() for k, v in post.items()} 96 | return post, context, metrics 97 | 98 | def preprocess(self, obs): 99 | obs = obs.copy() 100 | obs['image'] = torch.Tensor(obs['image']) / 255.0 - 0.5 101 | if self._config.clip_rewards == 'tanh': 102 | obs['reward'] = torch.tanh(torch.Tensor(obs['reward'])).unsqueeze(-1) 103 | elif self._config.clip_rewards == 'identity': 104 | obs['reward'] = torch.Tensor(obs['reward']).unsqueeze(-1) 105 | else: 106 | raise NotImplemented(f'{self._config.clip_rewards} is not implemented') 107 | if 'discount' in obs: 108 | obs['discount'] *= self._config.discount 109 | obs['discount'] = torch.Tensor(obs['discount']).unsqueeze(-1) 110 | obs = {k: 
torch.Tensor(v).to(self._config.device) for k, v in obs.items()} 111 | return obs 112 | 113 | def video_pred(self, data): 114 | data = self.preprocess(data) 115 | truth = data['image'][:6] + 0.5 116 | embed = self.encoder(data) 117 | 118 | states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5]) 119 | recon = self.heads['image']( 120 | self.dynamics.get_feat(states)).mode()[:6] 121 | reward_post = self.heads['reward']( 122 | self.dynamics.get_feat(states)).mode()[:6] 123 | init = {k: v[:, -1] for k, v in states.items()} 124 | prior = self.dynamics.imagine(data['action'][:6, 5:], init) 125 | openl = self.heads['image'](self.dynamics.get_feat(prior)).mode() 126 | reward_prior = self.heads['reward'](self.dynamics.get_feat(prior)).mode() 127 | model = torch.cat([recon[:, :5] + 0.5, openl + 0.5], 1) 128 | error = (model - truth + 1) / 2 129 | 130 | return torch.cat([truth, model, error], 2) 131 | 132 | 133 | class ImagBehavior(nn.Module): 134 | 135 | def __init__(self, config, world_model, stop_grad_actor=True, reward=None): 136 | super(ImagBehavior, self).__init__() 137 | self._use_amp = True if config.precision==16 else False 138 | self._config = config 139 | self._world_model = world_model 140 | self._stop_grad_actor = stop_grad_actor 141 | self._reward = reward 142 | if config.dyn_discrete: 143 | feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter 144 | else: 145 | feat_size = config.dyn_stoch + config.dyn_deter 146 | self.actor = networks.ActionHead( 147 | feat_size, # pytorch version 148 | config.num_actions, config.actor_layers, config.units, config.act, 149 | config.actor_dist, config.actor_init_std, config.actor_min_std, 150 | config.actor_dist, config.actor_temp, config.actor_outscale) 151 | self.value = networks.DenseHead( 152 | feat_size, # pytorch version 153 | [], config.value_layers, config.units, config.act, 154 | config.value_head) 155 | if config.slow_value_target or config.slow_actor_target: 156 | self._slow_value = networks.DenseHead( 157 | feat_size, # pytorch version 158 | [], config.value_layers, config.units, config.act) 159 | self._updates = 0 160 | kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) 161 | self._actor_opt = tools.Optimizer( 162 | 'actor', self.actor.parameters(), config.actor_lr, config.opt_eps, config.actor_grad_clip, 163 | **kw) 164 | self._value_opt = tools.Optimizer( 165 | 'value', self.value.parameters(), config.value_lr, config.opt_eps, config.value_grad_clip, 166 | **kw) 167 | 168 | def _train( 169 | self, start, objective=None, action=None, reward=None, imagine=None, tape=None, repeats=None): 170 | objective = objective or self._reward 171 | self._update_slow_target() 172 | metrics = {} 173 | 174 | with tools.RequiresGrad(self.actor): 175 | with torch.cuda.amp.autocast(self._use_amp): 176 | imag_feat, imag_state, imag_action = self._imagine( 177 | start, self.actor, self._config.imag_horizon, repeats) 178 | reward = objective(imag_feat, imag_state, imag_action) 179 | actor_ent = self.actor(imag_feat).entropy() 180 | state_ent = self._world_model.dynamics.get_dist( 181 | imag_state).entropy() 182 | target, weights = self._compute_target( 183 | imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, 184 | self._config.slow_actor_target) 185 | actor_loss, mets = self._compute_actor_loss( 186 | imag_feat, imag_state, imag_action, target, actor_ent, state_ent, 187 | weights) 188 | metrics.update(mets) 189 | if self._config.slow_value_target != self._config.slow_actor_target: 190 | 
target, weights = self._compute_target( 191 | imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, 192 | self._config.slow_value_target) 193 | value_input = imag_feat 194 | 195 | with tools.RequiresGrad(self.value): 196 | with torch.cuda.amp.autocast(self._use_amp): 197 | value = self.value(value_input[:-1].detach()) 198 | target = torch.stack(target, dim=1) 199 | value_loss = -value.log_prob(target.detach()) 200 | if self._config.value_decay: 201 | value_loss += self._config.value_decay * value.mode() 202 | value_loss = torch.mean(weights[:-1] * value_loss[:,:,None]) 203 | 204 | metrics['reward_mean'] = to_np(torch.mean(reward)) 205 | metrics['reward_std'] = to_np(torch.std(reward)) 206 | metrics['actor_ent'] = to_np(torch.mean(actor_ent)) 207 | with tools.RequiresGrad(self): 208 | metrics.update(self._actor_opt(actor_loss, self.actor.parameters())) 209 | metrics.update(self._value_opt(value_loss, self.value.parameters())) 210 | return imag_feat, imag_state, imag_action, weights, metrics 211 | 212 | def _imagine(self, start, policy, horizon, repeats=None): 213 | dynamics = self._world_model.dynamics 214 | if repeats: 215 | raise NotImplemented("repeats is not implemented in this version") 216 | flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) 217 | start = {k: flatten(v) for k, v in start.items()} 218 | def step(prev, _): 219 | state, _, _ = prev 220 | feat = dynamics.get_feat(state) 221 | inp = feat.detach() if self._stop_grad_actor else feat 222 | action = policy(inp).sample() 223 | succ = dynamics.img_step(state, action, sample=self._config.imag_sample) 224 | return succ, feat, action 225 | feat = 0 * dynamics.get_feat(start) 226 | action = policy(feat).mode() 227 | succ, feats, actions = tools.static_scan( 228 | step, [torch.arange(horizon)], (start, feat, action)) 229 | states = {k: torch.cat([ 230 | start[k][None], v[:-1]], 0) for k, v in succ.items()} 231 | if repeats: 232 | raise NotImplemented("repeats is not implemented in this version") 233 | 234 | return feats, states, actions 235 | 236 | def _compute_target( 237 | self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, 238 | slow): 239 | if 'discount' in self._world_model.heads: 240 | inp = self._world_model.dynamics.get_feat(imag_state) 241 | discount = self._world_model.heads['discount'](inp).mean 242 | else: 243 | discount = self._config.discount * torch.ones_like(reward) 244 | if self._config.future_entropy and self._config.actor_entropy() > 0: 245 | reward += self._config.actor_entropy() * actor_ent 246 | if self._config.future_entropy and self._config.actor_state_entropy() > 0: 247 | reward += self._config.actor_state_entropy() * state_ent 248 | if slow: 249 | value = self._slow_value(imag_feat).mode() 250 | else: 251 | value = self.value(imag_feat).mode() 252 | target = tools.lambda_return( 253 | reward[:-1], value[:-1], discount[:-1], 254 | bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0) 255 | weights = torch.cumprod( 256 | torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach() 257 | return target, weights 258 | 259 | def _compute_actor_loss( 260 | self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, 261 | weights): 262 | metrics = {} 263 | inp = imag_feat.detach() if self._stop_grad_actor else imag_feat 264 | policy = self.actor(inp) 265 | actor_ent = policy.entropy() 266 | target = torch.stack(target, dim=1) 267 | if self._config.imag_gradient == 'dynamics': 268 | actor_target = target 269 | elif self._config.imag_gradient 
== 'reinforce': 270 | actor_target = policy.log_prob(imag_action)[:-1][:, :, None] * ( 271 | target - self.value(imag_feat[:-1]).mode()).detach() 272 | elif self._config.imag_gradient == 'both': 273 | actor_target = policy.log_prob(imag_action)[:-1][:, :, None] * ( 274 | target - self.value(imag_feat[:-1]).mode()).detach() 275 | mix = self._config.imag_gradient_mix() 276 | actor_target = mix * target + (1 - mix) * actor_target 277 | metrics['imag_gradient_mix'] = mix 278 | else: 279 | raise NotImplementedError(self._config.imag_gradient) 280 | if not self._config.future_entropy and (self._config.actor_entropy() > 0): 281 | actor_target += self._config.actor_entropy() * actor_ent[:-1][:,:,None] 282 | if not self._config.future_entropy and (self._config.actor_state_entropy() > 0): 283 | actor_target += self._config.actor_state_entropy() * state_ent[:-1] 284 | actor_loss = -torch.mean(weights[:-1] * actor_target) 285 | return actor_loss, metrics 286 | 287 | def _update_slow_target(self): 288 | if self._config.slow_value_target or self._config.slow_actor_target: 289 | if self._updates % self._config.slow_target_update == 0: 290 | mix = self._config.slow_target_fraction 291 | for s, d in zip(self.value.parameters(), self._slow_value.parameters()): 292 | d.data = mix * s.data + (1 - mix) * d.data 293 | self._updates += 1 294 | 295 | 296 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch import distributions as torchd 7 | 8 | import tools 9 | 10 | 11 | class RSSM(nn.Module): 12 | 13 | def __init__( 14 | self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1, 15 | rec_depth=1, shared=False, discrete=False, act=nn.ELU, 16 | mean_act='none', std_act='softplus', temp_post=True, min_std=0.1, 17 | cell='gru', 18 | num_actions=None, embed = None, device=None): 19 | super(RSSM, self).__init__() 20 | self._stoch = stoch 21 | self._deter = deter 22 | self._hidden = hidden 23 | self._min_std = min_std 24 | self._layers_input = layers_input 25 | self._layers_output = layers_output 26 | self._rec_depth = rec_depth 27 | self._shared = shared 28 | self._discrete = discrete 29 | self._act = act 30 | self._mean_act = mean_act 31 | self._std_act = std_act 32 | self._temp_post = temp_post 33 | self._embed = embed 34 | self._device = device 35 | 36 | inp_layers = [] 37 | if self._discrete: 38 | inp_dim = self._stoch * self._discrete + num_actions 39 | else: 40 | inp_dim = self._stoch + num_actions 41 | if self._shared: 42 | inp_dim += self._embed 43 | for i in range(self._layers_input): 44 | inp_layers.append(nn.Linear(inp_dim, self._hidden)) 45 | inp_layers.append(self._act()) 46 | if i == 0: 47 | inp_dim = self._hidden 48 | self._inp_layers = nn.Sequential(*inp_layers) 49 | 50 | if cell == 'gru': 51 | self._cell = GRUCell(self._hidden, self._deter) 52 | elif cell == 'gru_layer_norm': 53 | self._cell = GRUCell(self._hidden, self._deter, norm=True) 54 | else: 55 | raise NotImplementedError(cell) 56 | 57 | img_out_layers = [] 58 | inp_dim = self._deter 59 | for i in range(self._layers_output): 60 | img_out_layers.append(nn.Linear(inp_dim, self._hidden)) 61 | img_out_layers.append(self._act()) 62 | if i == 0: 63 | inp_dim = self._hidden 64 | self._img_out_layers = nn.Sequential(*img_out_layers) 65 | 66 | obs_out_layers = [] 67 | if self._temp_post: 68 | 
inp_dim = self._deter + self._embed 69 | else: 70 | inp_dim = self._embed 71 | for i in range(self._layers_output): 72 | obs_out_layers.append(nn.Linear(inp_dim, self._hidden)) 73 | obs_out_layers.append(self._act()) 74 | if i == 0: 75 | inp_dim = self._hidden 76 | self._obs_out_layers = nn.Sequential(*obs_out_layers) 77 | 78 | if self._discrete: 79 | self._ims_stat_layer = nn.Linear(self._hidden, self._stoch*self._discrete) 80 | self._obs_stat_layer = nn.Linear(self._hidden, self._stoch*self._discrete) 81 | else: 82 | self._ims_stat_layer = nn.Linear(self._hidden, 2*self._stoch) 83 | self._obs_stat_layer = nn.Linear(self._hidden, 2*self._stoch) 84 | 85 | def initial(self, batch_size): 86 | deter = torch.zeros(batch_size, self._deter).to(self._device) 87 | if self._discrete: 88 | state = dict( 89 | logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device), 90 | stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device), 91 | deter=deter) 92 | else: 93 | state = dict( 94 | mean=torch.zeros([batch_size, self._stoch]).to(self._device), 95 | std=torch.zeros([batch_size, self._stoch]).to(self._device), 96 | stoch=torch.zeros([batch_size, self._stoch]).to(self._device), 97 | deter=deter) 98 | return state 99 | 100 | def observe(self, embed, action, state=None): 101 | swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) 102 | if state is None: 103 | state = self.initial(action.shape[0]) 104 | embed, action = swap(embed), swap(action) 105 | post, prior = tools.static_scan( 106 | lambda prev_state, prev_act, embed: self.obs_step( 107 | prev_state[0], prev_act, embed), 108 | (action, embed), (state, state)) 109 | post = {k: swap(v) for k, v in post.items()} 110 | prior = {k: swap(v) for k, v in prior.items()} 111 | return post, prior 112 | 113 | def imagine(self, action, state=None): 114 | swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) 115 | if state is None: 116 | state = self.initial(action.shape[0]) 117 | assert isinstance(state, dict), state 118 | action = action 119 | action = swap(action) 120 | prior = tools.static_scan(self.img_step, [action], state) 121 | prior = prior[0] 122 | prior = {k: swap(v) for k, v in prior.items()} 123 | return prior 124 | 125 | def get_feat(self, state): 126 | stoch = state['stoch'] 127 | if self._discrete: 128 | shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete] 129 | stoch = stoch.reshape(shape) 130 | return torch.cat([stoch, state['deter']], -1) 131 | 132 | def get_dist(self, state, dtype=None): 133 | if self._discrete: 134 | logit = state['logit'] 135 | dist = torchd.independent.Independent(tools.OneHotDist(logit), 1) 136 | else: 137 | mean, std = state['mean'], state['std'] 138 | dist = tools.ContDist(torchd.independent.Independent( 139 | torchd.normal.Normal(mean, std), 1)) 140 | return dist 141 | 142 | def obs_step(self, prev_state, prev_action, embed, sample=True): 143 | prior = self.img_step(prev_state, prev_action, None, sample) 144 | if self._shared: 145 | post = self.img_step(prev_state, prev_action, embed, sample) 146 | else: 147 | if self._temp_post: 148 | x = torch.cat([prior['deter'], embed], -1) 149 | else: 150 | x = embed 151 | x = self._obs_out_layers(x) 152 | stats = self._suff_stats_layer('obs', x) 153 | if sample: 154 | stoch = self.get_dist(stats).sample() 155 | else: 156 | stoch = self.get_dist(stats).mode() 157 | post = {'stoch': stoch, 'deter': prior['deter'], **stats} 158 | return post, prior 159 | 160 | def img_step(self, prev_state, prev_action, 
embed=None, sample=True): 161 | prev_stoch = prev_state['stoch'] 162 | if self._discrete: 163 | shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] 164 | prev_stoch = prev_stoch.reshape(shape) 165 | if self._shared: 166 | if embed is None: 167 | shape = list(prev_action.shape[:-1]) + [self._embed] 168 | embed = torch.zeros(shape) 169 | x = torch.cat([prev_stoch, prev_action, embed], -1) 170 | else: 171 | x = torch.cat([prev_stoch, prev_action], -1) 172 | x = self._inp_layers(x) 173 | for _ in range(self._rec_depth): # rec depth is not correctly implemented 174 | deter = prev_state['deter'] 175 | x, deter = self._cell(x, [deter]) 176 | deter = deter[0] # Keras wraps the state in a list. 177 | x = self._img_out_layers(x) 178 | stats = self._suff_stats_layer('ims', x) 179 | if sample: 180 | stoch = self.get_dist(stats).sample() 181 | else: 182 | stoch = self.get_dist(stats).mode() 183 | prior = {'stoch': stoch, 'deter': deter, **stats} 184 | return prior 185 | 186 | def _suff_stats_layer(self, name, x): 187 | if self._discrete: 188 | if name == 'ims': 189 | x = self._ims_stat_layer(x) 190 | elif name == 'obs': 191 | x = self._obs_stat_layer(x) 192 | else: 193 | raise NotImplementedError 194 | logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete]) 195 | return {'logit': logit} 196 | else: 197 | if name == 'ims': 198 | x = self._ims_stat_layer(x) 199 | elif name == 'obs': 200 | x = self._obs_stat_layer(x) 201 | else: 202 | raise NotImplementedError 203 | mean, std = torch.split(x, [self._stoch]*2, -1) 204 | mean = { 205 | 'none': lambda: mean, 206 | 'tanh5': lambda: 5.0 * torch.tanh(mean / 5.0), 207 | }[self._mean_act]() 208 | std = { 209 | 'softplus': lambda: torch.softplus(std), 210 | 'abs': lambda: torch.abs(std + 1), 211 | 'sigmoid': lambda: torch.sigmoid(std), 212 | 'sigmoid2': lambda: 2 * torch.sigmoid(std / 2), 213 | }[self._std_act]() 214 | std = std + self._min_std 215 | return {'mean': mean, 'std': std} 216 | 217 | def kl_loss(self, post, prior, forward, balance, free, scale): 218 | kld = torchd.kl.kl_divergence 219 | dist = lambda x: self.get_dist(x) 220 | sg = lambda x: {k: v.detach() for k, v in x.items()} 221 | lhs, rhs = (prior, post) if forward else (post, prior) 222 | mix = balance if forward else (1 - balance) 223 | if balance == 0.5: 224 | value = kld(dist(lhs) if self._discrete else dist(lhs)._dist, 225 | dist(rhs) if self._discrete else dist(rhs)._dist) 226 | loss = torch.mean(torch.maximum(value, free)) 227 | else: 228 | value_lhs = value = kld(dist(lhs) if self._discrete else dist(lhs)._dist, 229 | dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist) 230 | value_rhs = kld(dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist, 231 | dist(rhs) if self._discrete else dist(rhs)._dist) 232 | loss_lhs = torch.maximum(torch.mean(value_lhs), torch.Tensor([free])[0]) 233 | loss_rhs = torch.maximum(torch.mean(value_rhs), torch.Tensor([free])[0]) 234 | loss = mix * loss_lhs + (1 - mix) * loss_rhs 235 | loss *= scale 236 | return loss, value 237 | 238 | 239 | class ConvEncoder(nn.Module): 240 | 241 | def __init__(self, grayscale=False, 242 | depth=32, act=nn.ReLU, kernels=(4, 4, 4, 4)): 243 | super(ConvEncoder, self).__init__() 244 | self._act = act 245 | self._depth = depth 246 | self._kernels = kernels 247 | 248 | layers = [] 249 | for i, kernel in enumerate(self._kernels): 250 | if i == 0: 251 | if grayscale: 252 | inp_dim = 1 253 | else: 254 | inp_dim = 3 255 | else: 256 | inp_dim = 2 ** (i-1) * self._depth 257 | depth = 2 ** i * 
self._depth 258 | layers.append(nn.Conv2d(inp_dim, depth, kernel, 2)) 259 | layers.append(act()) 260 | self.layers = nn.Sequential(*layers) 261 | 262 | def __call__(self, obs): 263 | x = obs['image'].reshape((-1,) + tuple(obs['image'].shape[-3:])) 264 | x = x.permute(0, 3, 1, 2) 265 | x = self.layers(x) 266 | x = x.reshape([x.shape[0], np.prod(x.shape[1:])]) 267 | shape = list(obs['image'].shape[:-3]) + [x.shape[-1]] 268 | return x.reshape(shape) 269 | 270 | 271 | class ConvDecoder(nn.Module): 272 | 273 | def __init__( 274 | self, inp_depth, 275 | depth=32, act=nn.ReLU, shape=(3, 64, 64), kernels=(5, 5, 6, 6), 276 | thin=True): 277 | super(ConvDecoder, self).__init__() 278 | self._inp_depth = inp_depth 279 | self._act = act 280 | self._depth = depth 281 | self._shape = shape 282 | self._kernels = kernels 283 | self._thin = thin 284 | 285 | if self._thin: 286 | self._linear_layer = nn.Linear(inp_depth, 32 * self._depth) 287 | else: 288 | self._linear_layer = nn.Linear(inp_depth, 128 * self._depth) 289 | inp_dim = 32 * self._depth 290 | 291 | cnnt_layers = [] 292 | for i, kernel in enumerate(self._kernels): 293 | depth = 2 ** (len(self._kernels) - i - 2) * self._depth 294 | act = self._act 295 | if i == len(self._kernels) - 1: 296 | #depth = self._shape[-1] 297 | depth = self._shape[0] 298 | act = None 299 | if i != 0: 300 | inp_dim = 2 ** (len(self._kernels) - (i-1) - 2) * self._depth 301 | cnnt_layers.append(nn.ConvTranspose2d(inp_dim, depth, kernel, 2)) 302 | if act is not None: 303 | cnnt_layers.append(act()) 304 | self._cnnt_layers = nn.Sequential(*cnnt_layers) 305 | 306 | def __call__(self, features, dtype=None): 307 | if self._thin: 308 | x = self._linear_layer(features) 309 | x = x.reshape([-1, 1, 1, 32 * self._depth]) 310 | x = x.permute(0,3,1,2) 311 | else: 312 | x = self._linear_layer(features) 313 | x = x.reshape([-1, 2, 2, 32 * self._depth]) 314 | x = x.permute(0,3,1,2) 315 | x = self._cnnt_layers(x) 316 | mean = x.reshape(features.shape[:-1] + self._shape) 317 | mean = mean.permute(0, 1, 3, 4, 2) 318 | return tools.ContDist(torchd.independent.Independent( 319 | torchd.normal.Normal(mean, 1), len(self._shape))) 320 | 321 | 322 | class DenseHead(nn.Module): 323 | 324 | def __init__( 325 | self, inp_dim, 326 | shape, layers, units, act=nn.ELU, dist='normal', std=1.0): 327 | super(DenseHead, self).__init__() 328 | self._shape = (shape,) if isinstance(shape, int) else shape 329 | if len(self._shape) == 0: 330 | self._shape = (1,) 331 | self._layers = layers 332 | self._units = units 333 | self._act = act 334 | self._dist = dist 335 | self._std = std 336 | 337 | mean_layers = [] 338 | for index in range(self._layers): 339 | mean_layers.append(nn.Linear(inp_dim, self._units)) 340 | mean_layers.append(act()) 341 | if index == 0: 342 | inp_dim = self._units 343 | mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape))) 344 | self._mean_layers = nn.Sequential(*mean_layers) 345 | 346 | if self._std == 'learned': 347 | self._std_layer = nn.Linear(self._units, np.prod(self._shape)) 348 | 349 | def __call__(self, features, dtype=None): 350 | x = features 351 | mean = self._mean_layers(x) 352 | if self._std == 'learned': 353 | std = self._std_layer(x) 354 | std = torch.softplus(std) + 0.01 355 | else: 356 | std = self._std 357 | if self._dist == 'normal': 358 | return tools.ContDist(torchd.independent.Independent( 359 | torchd.normal.Normal(mean, std), len(self._shape))) 360 | if self._dist == 'huber': 361 | return tools.ContDist(torchd.independent.Independent( 362 | 
tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape))) 363 | if self._dist == 'binary': 364 | return tools.Bernoulli(torchd.independent.Independent( 365 | torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) 366 | raise NotImplementedError(self._dist) 367 | 368 | 369 | class ActionHead(nn.Module): 370 | 371 | def __init__( 372 | self, inp_dim, size, layers, units, act=nn.ELU, dist='trunc_normal', 373 | init_std=0.0, min_std=0.1, action_disc=5, temp=0.1, outscale=0): 374 | super(ActionHead, self).__init__() 375 | self._size = size 376 | self._layers = layers 377 | self._units = units 378 | self._dist = dist 379 | self._act = act 380 | self._min_std = min_std 381 | self._init_std = init_std 382 | self._action_disc = action_disc 383 | self._temp = temp() if callable(temp) else temp 384 | self._outscale = outscale 385 | 386 | pre_layers = [] 387 | for index in range(self._layers): 388 | pre_layers.append(nn.Linear(inp_dim, self._units)) 389 | pre_layers.append(act()) 390 | if index == 0: 391 | inp_dim = self._units 392 | self._pre_layers = nn.Sequential(*pre_layers) 393 | 394 | if self._dist in ['tanh_normal','tanh_normal_5','normal','trunc_normal']: 395 | self._dist_layer = nn.Linear(self._units, 2 * self._size) 396 | elif self._dist in ['normal_1','onehot','onehot_gumbel']: 397 | self._dist_layer = nn.Linear(self._units, self._size) 398 | 399 | def __call__(self, features, dtype=None): 400 | x = features 401 | x = self._pre_layers(x) 402 | if self._dist == 'tanh_normal': 403 | x = self._dist_layer(x) 404 | mean, std = torch.split(x, 2, -1) 405 | mean = torch.tanh(mean) 406 | std = F.softplus(std + self._init_std) + self._min_std 407 | dist = torchd.normal.Normal(mean, std) 408 | dist = torchd.transformed_distribution.TransformedDistribution( 409 | dist, tools.TanhBijector()) 410 | dist = torchd.independent.Independent(dist, 1) 411 | dist = tools.SampleDist(dist) 412 | elif self._dist == 'tanh_normal_5': 413 | x = self._dist_layer(x) 414 | mean, std = torch.split(x, 2, -1) 415 | mean = 5 * torch.tanh(mean / 5) 416 | std = F.softplus(std + 5) + 5 417 | dist = torchd.normal.Normal(mean, std) 418 | dist = torchd.transformed_distribution.TransformedDistribution( 419 | dist, tools.TanhBijector()) 420 | dist = torchd.independent.Independent(dist, 1) 421 | dist = tools.SampleDist(dist) 422 | elif self._dist == 'normal': 423 | x = self._dist_layer(x) 424 | mean, std = torch.split(x, 2, -1) 425 | std = F.softplus(std + self._init_std) + self._min_std 426 | dist = torchd.normal.Normal(mean, std) 427 | dist = tools.ContDist(torchd.independent.Independent(dist, 1)) 428 | elif self._dist == 'normal_1': 429 | x = self._dist_layer(x) 430 | dist = torchd.normal.Normal(mean, 1) 431 | dist = tools.ContDist(torchd.independent.Independent(dist, 1)) 432 | elif self._dist == 'trunc_normal': 433 | x = self._dist_layer(x) 434 | mean, std = torch.split(x, [self._size]*2, -1) 435 | mean = torch.tanh(mean) 436 | std = 2 * torch.sigmoid(std / 2) + self._min_std 437 | dist = tools.SafeTruncatedNormal(mean, std, -1, 1) 438 | dist = tools.ContDist(torchd.independent.Independent(dist, 1)) 439 | elif self._dist == 'onehot': 440 | x = self._dist_layer(x) 441 | dist = tools.OneHotDist(x) 442 | elif self._dist == 'onehot_gumble': 443 | x = self._dist_layer(x) 444 | temp = self._temp 445 | dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1/temp)) 446 | else: 447 | raise NotImplementedError(self._dist) 448 | return dist 449 | 450 | 451 | class GRUCell(nn.Module): 452 | 453 | def __init__(self, inp_size, 454 | size, 
norm=False, act=torch.tanh, update_bias=-1): 455 | super(GRUCell, self).__init__() 456 | self._inp_size = inp_size 457 | self._size = size 458 | self._act = act 459 | self._norm = norm 460 | self._update_bias = update_bias 461 | self._layer = nn.Linear(inp_size+size, 3*size, 462 | bias=norm is not None) 463 | if norm: 464 | self._norm = nn.LayerNorm(3*size) 465 | 466 | @property 467 | def state_size(self): 468 | return self._size 469 | 470 | def forward(self, inputs, state): 471 | state = state[0] # Keras wraps the state in a list. 472 | parts = self._layer(torch.cat([inputs, state], -1)) 473 | if self._norm: 474 | parts = self._norm(parts) 475 | reset, cand, update = torch.split(parts, [self._size]*3, -1) 476 | reset = torch.sigmoid(reset) 477 | cand = self._act(reset * cand) 478 | update = torch.sigmoid(update + self._update_bias) 479 | output = update * cand + (1 - update) * state 480 | return output, [output] 481 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | numpy==1.20.1 3 | torchvision==0.9.1 4 | tensorboard==2.5.0 5 | pandas==1.2.4 6 | matplotlib==3.4.1 7 | ruamel.yaml==0.17.4 8 | gym[atari]==0.18.0 9 | moviepy==1.0.3 10 | einops==0.3.0 11 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import json 4 | import pathlib 5 | import pickle 6 | import re 7 | import time 8 | import uuid 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | from torch import distributions as torchd 16 | from torch.utils.data import Dataset 17 | from torch.utils.tensorboard import SummaryWriter 18 | 19 | 20 | class RequiresGrad: 21 | 22 | def __init__(self, model): 23 | self._model = model 24 | 25 | def __enter__(self): 26 | self._model.requires_grad_(requires_grad=True) 27 | 28 | def __exit__(self, *args): 29 | self._model.requires_grad_(requires_grad=False) 30 | 31 | 32 | class TimeRecording: 33 | 34 | def __init__(self, comment): 35 | self._comment = comment 36 | 37 | def __enter__(self): 38 | self._st = torch.cuda.Event(enable_timing=True) 39 | self._nd = torch.cuda.Event(enable_timing=True) 40 | self._st.record() 41 | 42 | def __exit__(self, *args): 43 | self._nd.record() 44 | torch.cuda.synchronize() 45 | print(self._comment, self._st.elapsed_time(self._nd)/1000) 46 | 47 | 48 | class Logger: 49 | 50 | def __init__(self, logdir, step): 51 | self._logdir = logdir 52 | self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000) 53 | self._last_step = None 54 | self._last_time = None 55 | self._scalars = {} 56 | self._images = {} 57 | self._videos = {} 58 | self.step = step 59 | 60 | def scalar(self, name, value): 61 | self._scalars[name] = float(value) 62 | 63 | def image(self, name, value): 64 | self._images[name] = np.array(value) 65 | 66 | def video(self, name, value): 67 | self._videos[name] = np.array(value) 68 | 69 | def write(self, fps=False): 70 | scalars = list(self._scalars.items()) 71 | if fps: 72 | scalars.append(('fps', self._compute_fps(self.step))) 73 | print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars)) 74 | with (self._logdir / 'metrics.jsonl').open('a') as f: 75 | f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n') 76 | for name, value in scalars: 77 | 
self._writer.add_scalar('scalars/' + name, value, self.step) 78 | for name, value in self._images.items(): 79 | self._writer.add_image(name, value, self.step) 80 | for name, value in self._videos.items(): 81 | name = name if isinstance(name, str) else name.decode('utf-8') 82 | if np.issubdtype(value.dtype, np.floating): 83 | value = np.clip(255 * value, 0, 255).astype(np.uint8) 84 | B, T, H, W, C = value.shape 85 | value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W)) 86 | self._writer.add_video(name, value, self.step, 16) 87 | 88 | self._writer.flush() 89 | self._scalars = {} 90 | self._images = {} 91 | self._videos = {} 92 | 93 | def _compute_fps(self, step): 94 | if self._last_step is None: 95 | self._last_time = time.time() 96 | self._last_step = step 97 | return 0 98 | steps = step - self._last_step 99 | duration = time.time() - self._last_time 100 | self._last_time += duration 101 | self._last_step = step 102 | return steps / duration 103 | 104 | def offline_scalar(self, name, value, step): 105 | self._writer.add_scalar('scalars/'+name, value, step) 106 | 107 | def offline_video(self, name, value, step): 108 | if np.issubdtype(value.dtype, np.floating): 109 | value = np.clip(255 * value, 0, 255).astype(np.uint8) 110 | B, T, H, W, C = value.shape 111 | value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W)) 112 | self._writer.add_video(name, value, step, 16) 113 | 114 | 115 | def simulate(agent, envs, steps=0, episodes=0, state=None): 116 | # Initialize or unpack simulation state. 117 | if state is None: 118 | step, episode = 0, 0 119 | done = np.ones(len(envs), np.bool) 120 | length = np.zeros(len(envs), np.int32) 121 | obs = [None] * len(envs) 122 | agent_state = None 123 | reward = [0]*len(envs) 124 | else: 125 | step, episode, done, length, obs, agent_state, reward = state 126 | while (steps and step < steps) or (episodes and episode < episodes): 127 | # Reset envs if necessary. 128 | if done.any(): 129 | indices = [index for index, d in enumerate(done) if d] 130 | results = [envs[i].reset() for i in indices] 131 | for index, result in zip(indices, results): 132 | obs[index] = result 133 | reward = [reward[i]*(1-done[i]) for i in range(len(envs))] 134 | # Step agents. 135 | obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} 136 | action, agent_state = agent(obs, done, agent_state, reward) 137 | if isinstance(action, dict): 138 | action = [ 139 | {k: np.array(action[k][i].detach().cpu()) for k in action} 140 | for i in range(len(envs))] 141 | else: 142 | action = np.array(action) 143 | assert len(action) == len(envs) 144 | # Step envs. 
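    # Each result is an (obs, reward, done, info) tuple; note that `step` below is
    # only credited with the lengths of episodes that finish (done * length), so
    # partial episodes do not count toward the requested step budget.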
145 | results = [e.step(a) for e, a in zip(envs, action)] 146 | obs, reward, done = zip(*[p[:3] for p in results]) 147 | obs = list(obs) 148 | reward = list(reward) 149 | done = np.stack(done) 150 | episode += int(done.sum()) 151 | length += 1 152 | step += (done * length).sum() 153 | length *= (1 - done) 154 | 155 | return (step - steps, episode - episodes, done, length, obs, agent_state, reward) 156 | 157 | 158 | def save_episodes(directory, episodes): 159 | directory = pathlib.Path(directory).expanduser() 160 | directory.mkdir(parents=True, exist_ok=True) 161 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 162 | filenames = [] 163 | for episode in episodes: 164 | identifier = str(uuid.uuid4().hex) 165 | length = len(episode['reward']) 166 | filename = directory / f'{timestamp}-{identifier}-{length}.npz' 167 | with io.BytesIO() as f1: 168 | np.savez_compressed(f1, **episode) 169 | f1.seek(0) 170 | with filename.open('wb') as f2: 171 | f2.write(f1.read()) 172 | filenames.append(filename) 173 | return filenames 174 | 175 | 176 | def from_generator(generator, batch_size): 177 | while True: 178 | batch = [] 179 | for _ in range(batch_size): 180 | batch.append(next(generator)) 181 | data = {} 182 | for key in batch[0].keys(): 183 | data[key] = [] 184 | for i in range(batch_size): 185 | data[key].append(batch[i][key]) 186 | data[key] = np.stack(data[key], 0) 187 | yield data 188 | 189 | 190 | def sample_episodes(episodes, length=None, balance=False, seed=0): 191 | random = np.random.RandomState(seed) 192 | while True: 193 | episode = random.choice(list(episodes.values())) 194 | if length: 195 | total = len(next(iter(episode.values()))) 196 | available = total - length 197 | if available < 1: 198 | print(f'Skipped short episode of length {available}.') 199 | continue 200 | if balance: 201 | index = min(random.randint(0, total), available) 202 | else: 203 | index = int(random.randint(0, available + 1)) 204 | episode = {k: v[index: index + length] for k, v in episode.items()} 205 | yield episode 206 | 207 | 208 | def load_episodes(directory, limit=None, reverse=True): 209 | directory = pathlib.Path(directory).expanduser() 210 | episodes = {} 211 | total = 0 212 | if reverse: 213 | for filename in reversed(sorted(directory.glob('*.npz'))): 214 | try: 215 | with filename.open('rb') as f: 216 | episode = np.load(f) 217 | episode = {k: episode[k] for k in episode.keys()} 218 | except Exception as e: 219 | print(f'Could not load episode: {e}') 220 | continue 221 | episodes[str(filename)] = episode 222 | total += len(episode['reward']) - 1 223 | if limit and total >= limit: 224 | break 225 | else: 226 | for filename in sorted(directory.glob('*.npz')): 227 | try: 228 | with filename.open('rb') as f: 229 | episode = np.load(f) 230 | episode = {k: episode[k] for k in episode.keys()} 231 | except Exception as e: 232 | print(f'Could not load episode: {e}') 233 | continue 234 | episodes[str(filename)] = episode 235 | total += len(episode['reward']) - 1 236 | if limit and total >= limit: 237 | break 238 | return episodes 239 | 240 | 241 | class SampleDist: 242 | 243 | def __init__(self, dist, samples=100): 244 | self._dist = dist 245 | self._samples = samples 246 | 247 | @property 248 | def name(self): 249 | return 'SampleDist' 250 | 251 | def __getattr__(self, name): 252 | return getattr(self._dist, name) 253 | 254 | def mean(self): 255 | samples = self._dist.sample(self._samples) 256 | return torch.mean(samples, 0) 257 | 258 | def mode(self): 259 | sample = self._dist.sample(self._samples) 
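    # Approximate the mode as the drawn sample with the highest log-probability
    # under the wrapped distribution.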
260 | logprob = self._dist.log_prob(sample) 261 | return sample[torch.argmax(logprob)][0] 262 | 263 | def entropy(self): 264 | sample = self._dist.sample(self._samples) 265 | logprob = self.log_prob(sample) 266 | return -torch.mean(logprob, 0) 267 | 268 | 269 | class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): 270 | 271 | def __init__(self, logits=None, probs=None): 272 | super().__init__(logits=logits, probs=probs) 273 | 274 | def mode(self): 275 | _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) 276 | return _mode.detach() + super().logits - super().logits.detach() 277 | 278 | def sample(self, sample_shape=(), seed=None): 279 | if seed is not None: 280 | raise ValueError('need to check') 281 | sample = super().sample(sample_shape) 282 | probs = super().probs 283 | while len(probs.shape) < len(sample.shape): 284 | probs = probs[None] 285 | sample += probs - probs.detach() 286 | return sample 287 | 288 | 289 | class ContDist: 290 | 291 | def __init__(self, dist=None): 292 | super().__init__() 293 | self._dist = dist 294 | self.mean = dist.mean 295 | 296 | def __getattr__(self, name): 297 | return getattr(self._dist, name) 298 | 299 | def entropy(self): 300 | return self._dist.entropy() 301 | 302 | def mode(self): 303 | return self._dist.mean 304 | 305 | def sample(self, sample_shape=()): 306 | return self._dist.rsample(sample_shape) 307 | 308 | def log_prob(self, x): 309 | return self._dist.log_prob(x) 310 | 311 | 312 | class Bernoulli: 313 | 314 | def __init__(self, dist=None): 315 | super().__init__() 316 | self._dist = dist 317 | self.mean = dist.mean 318 | 319 | def __getattr__(self, name): 320 | return getattr(self._dist, name) 321 | 322 | def entropy(self): 323 | return self._dist.entropy() 324 | 325 | def mode(self): 326 | _mode = torch.round(self._dist.mean) 327 | return _mode.detach() +self._dist.mean - self._dist.mean.detach() 328 | 329 | def sample(self, sample_shape=()): 330 | return self._dist.rsample(sample_shape) 331 | 332 | def log_prob(self, x): 333 | _logits = self._dist.base_dist.logits 334 | log_probs0 = -F.softplus(_logits) 335 | log_probs1 = -F.softplus(-_logits) 336 | 337 | return log_probs0 * (1-x) + log_probs1 * x 338 | 339 | 340 | class UnnormalizedHuber(torchd.normal.Normal): 341 | 342 | def __init__(self, loc, scale, threshold=1, **kwargs): 343 | super().__init__(loc, scale, **kwargs) 344 | self._threshold = threshold 345 | 346 | def log_prob(self, event): 347 | return -(torch.sqrt( 348 | (event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) 349 | 350 | def mode(self): 351 | return self.mean 352 | 353 | 354 | class SafeTruncatedNormal(torchd.normal.Normal): 355 | 356 | def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): 357 | super().__init__(loc, scale) 358 | self._low = low 359 | self._high = high 360 | self._clip = clip 361 | self._mult = mult 362 | 363 | def sample(self, sample_shape): 364 | event = super().sample(sample_shape) 365 | if self._clip: 366 | clipped = torch.clip(event, self._low + self._clip, 367 | self._high - self._clip) 368 | event = event - event.detach() + clipped.detach() 369 | if self._mult: 370 | event *= self._mult 371 | return event 372 | 373 | 374 | class TanhBijector(torchd.Transform): 375 | 376 | def __init__(self, validate_args=False, name='tanh'): 377 | super().__init__() 378 | 379 | def _forward(self, x): 380 | return torch.tanh(x) 381 | 382 | def _inverse(self, y): 383 | y = torch.where( 384 | (torch.abs(y) <= 1.), 385 | torch.clamp(y, -0.99999997, 
0.99999997), y) 386 | y = torch.atanh(y) 387 | return y 388 | 389 | def _forward_log_det_jacobian(self, x): 390 | log2 = torch.math.log(2.0) 391 | return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) 392 | 393 | 394 | def static_scan_for_lambda_return(fn, inputs, start): 395 | last = start 396 | indices = range(inputs[0].shape[0]) 397 | indices = reversed(indices) 398 | flag = True 399 | for index in indices: 400 | inp = lambda x: (_input[x] for _input in inputs) 401 | last = fn(last, *inp(index)) 402 | if flag: 403 | outputs = last 404 | flag = False 405 | else: 406 | outputs = torch.cat([outputs, last], dim=-1) 407 | outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1]) 408 | outputs = torch.unbind(outputs, dim=0) 409 | return outputs 410 | 411 | 412 | def lambda_return( 413 | reward, value, pcont, bootstrap, lambda_, axis): 414 | # Setting lambda=1 gives a discounted Monte Carlo return. 415 | # Setting lambda=0 gives a fixed 1-step return. 416 | #assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) 417 | assert len(reward.shape) == len(value.shape), (reward.shape, value.shape) 418 | if isinstance(pcont, (int, float)): 419 | pcont = pcont * torch.ones_like(reward) 420 | dims = list(range(len(reward.shape))) 421 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] 422 | if axis != 0: 423 | reward = reward.permute(dims) 424 | value = value.permute(dims) 425 | pcont = pcont.permute(dims) 426 | if bootstrap is None: 427 | bootstrap = torch.zeros_like(value[-1]) 428 | next_values = torch.cat([value[1:], bootstrap[None]], 0) 429 | inputs = reward + pcont * next_values * (1 - lambda_) 430 | #returns = static_scan( 431 | # lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, 432 | # (inputs, pcont), bootstrap, reverse=True) 433 | # reimplement to optimize performance 434 | returns = static_scan_for_lambda_return( 435 | lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, 436 | (inputs, pcont), bootstrap) 437 | if axis != 0: 438 | returns = returns.permute(dims) 439 | return returns 440 | 441 | 442 | class Optimizer(): 443 | 444 | def __init__( 445 | self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*', 446 | opt='adam', use_amp=False): 447 | assert 0 <= wd < 1 448 | assert not clip or 1 <= clip 449 | self._name = name 450 | self._parameters = parameters 451 | self._clip = clip 452 | self._wd = wd 453 | self._wd_pattern = wd_pattern 454 | self._opt = { 455 | 'adam': lambda: torch.optim.Adam(parameters, 456 | lr=lr, 457 | eps=eps), 458 | 'nadam': lambda: NotImplemented( 459 | f'{config.opt} is not implemented'), 460 | 'adamax': lambda: torch.optim.Adamax(parameters, 461 | lr=lr, 462 | eps=eps), 463 | 'sgd': lambda: torch.optim.SGD(parameters, 464 | lr=lr), 465 | 'momentum': lambda: torch.optim.SGD(parameters, 466 | lr=lr, 467 | momentum=0.9), 468 | }[opt]() 469 | self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 470 | 471 | def __call__(self, loss, params, retain_graph=False): 472 | assert len(loss.shape) == 0, loss.shape 473 | metrics = {} 474 | metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy() 475 | self._scaler.scale(loss).backward() 476 | self._scaler.unscale_(self._opt) 477 | #loss.backward(retain_graph=retain_graph) 478 | norm = torch.nn.utils.clip_grad_norm_(params, self._clip) 479 | if self._wd: 480 | self._apply_weight_decay(params) 481 | self._scaler.step(self._opt) 482 | self._scaler.update() 483 | #self._opt.step() 484 | self._opt.zero_grad() 485 | metrics[f'{self._name}_grad_norm'] = norm.item() 486 | 
return metrics 487 | 488 | def _apply_weight_decay(self, varibs): 489 | nontrivial = (self._wd_pattern != r'.*') 490 | if nontrivial: 491 | raise NotImplementedError 492 | for var in varibs: 493 | var.data = (1 - self._wd) * var.data 494 | 495 | 496 | def args_type(default): 497 | def parse_string(x): 498 | if default is None: 499 | return x 500 | if isinstance(default, bool): 501 | return bool(['False', 'True'].index(x)) 502 | if isinstance(default, int): 503 | return float(x) if ('e' in x or '.' in x) else int(x) 504 | if isinstance(default, (list, tuple)): 505 | return tuple(args_type(default[0])(y) for y in x.split(',')) 506 | return type(default)(x) 507 | def parse_object(x): 508 | if isinstance(default, (list, tuple)): 509 | return tuple(x) 510 | return x 511 | return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x) 512 | 513 | 514 | def static_scan(fn, inputs, start): 515 | last = start 516 | indices = range(inputs[0].shape[0]) 517 | flag = True 518 | for index in indices: 519 | inp = lambda x: (_input[x] for _input in inputs) 520 | last = fn(last, *inp(index)) 521 | if flag: 522 | if type(last) == type({}): 523 | outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} 524 | else: 525 | outputs = [] 526 | for _last in last: 527 | if type(_last) == type({}): 528 | outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) 529 | else: 530 | outputs.append(_last.clone().unsqueeze(0)) 531 | flag = False 532 | else: 533 | if type(last) == type({}): 534 | for key in last.keys(): 535 | outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) 536 | else: 537 | for j in range(len(outputs)): 538 | if type(last[j]) == type({}): 539 | for key in last[j].keys(): 540 | outputs[j][key] = torch.cat([outputs[j][key], 541 | last[j][key].unsqueeze(0)], dim=0) 542 | else: 543 | outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) 544 | if type(last) == type({}): 545 | outputs = [outputs] 546 | return outputs 547 | 548 | 549 | # Original version 550 | #def static_scan2(fn, inputs, start, reverse=False): 551 | # last = start 552 | # outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))] 553 | # indices = range(inputs[0].shape[0]) 554 | # if reverse: 555 | # indices = reversed(indices) 556 | # for index in indices: 557 | # inp = lambda x: (_input[x] for _input in inputs) 558 | # last = fn(last, *inp(index)) 559 | # [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)] 560 | # if reverse: 561 | # outputs = [list(reversed(x)) for x in outputs] 562 | # res = [[]] * len(outputs) 563 | # for i in range(len(outputs)): 564 | # if type(outputs[i][0]) == type({}): 565 | # _res = {} 566 | # for key in outputs[i][0].keys(): 567 | # _res[key] = [] 568 | # for j in range(len(outputs[i])): 569 | # _res[key].append(outputs[i][j][key]) 570 | # #_res[key] = torch.stack(_res[key], 0) 571 | # _res[key] = faster_stack(_res[key], 0) 572 | # else: 573 | # _res = outputs[i] 574 | # #_res = torch.stack(_res, 0) 575 | # _res = faster_stack(_res, 0) 576 | # res[i] = _res 577 | # return res 578 | 579 | 580 | class Every: 581 | 582 | def __init__(self, every): 583 | self._every = every 584 | self._last = None 585 | 586 | def __call__(self, step): 587 | if not self._every: 588 | return False 589 | if self._last is None: 590 | self._last = step 591 | return True 592 | if step >= self._last + self._every: 593 | self._last += self._every 594 | return True 595 | return False 596 | 597 | 598 | class 
Once: 599 | 600 | def __init__(self): 601 | self._once = True 602 | 603 | def __call__(self): 604 | if self._once: 605 | self._once = False 606 | return True 607 | return False 608 | 609 | 610 | class Until: 611 | 612 | def __init__(self, until): 613 | self._until = until 614 | 615 | def __call__(self, step): 616 | if not self._until: 617 | return True 618 | return step < self._until 619 | 620 | 621 | def schedule(string, step): 622 | try: 623 | return float(string) 624 | except ValueError: 625 | match = re.match(r'linear\((.+),(.+),(.+)\)', string) 626 | if match: 627 | initial, final, duration = [float(group) for group in match.groups()] 628 | mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0] 629 | return (1 - mix) * initial + mix * final 630 | match = re.match(r'warmup\((.+),(.+)\)', string) 631 | if match: 632 | warmup, value = [float(group) for group in match.groups()] 633 | scale = torch.clip(step / warmup, 0, 1) 634 | return scale * value 635 | match = re.match(r'exp\((.+),(.+),(.+)\)', string) 636 | if match: 637 | initial, final, halflife = [float(group) for group in match.groups()] 638 | return (initial - final) * 0.5 ** (step / halflife) + final 639 | match = re.match(r'horizon\((.+),(.+),(.+)\)', string) 640 | if match: 641 | initial, final, duration = [float(group) for group in match.groups()] 642 | mix = torch.clip(step / duration, 0, 1) 643 | horizon = (1 - mix) * initial + mix * final 644 | return 1 - 1 / horizon 645 | raise NotImplementedError(string) 646 | -------------------------------------------------------------------------------- /wrappers.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class DeepMindLabyrinth(object): 8 | 9 | ACTION_SET_DEFAULT = ( 10 | (0, 0, 0, 1, 0, 0, 0), # Forward 11 | (0, 0, 0, -1, 0, 0, 0), # Backward 12 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 13 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 14 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 15 | (20, 0, 0, 0, 0, 0, 0), # Look Right 16 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward 17 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward 18 | (0, 0, 0, 0, 1, 0, 0), # Fire 19 | ) 20 | 21 | ACTION_SET_MEDIUM = ( 22 | (0, 0, 0, 1, 0, 0, 0), # Forward 23 | (0, 0, 0, -1, 0, 0, 0), # Backward 24 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 25 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 26 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 27 | (20, 0, 0, 0, 0, 0, 0), # Look Right 28 | (0, 0, 0, 0, 0, 0, 0), # Idle. 
29 | ) 30 | 31 | ACTION_SET_SMALL = ( 32 | (0, 0, 0, 1, 0, 0, 0), # Forward 33 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 34 | (20, 0, 0, 0, 0, 0, 0), # Look Right 35 | ) 36 | 37 | def __init__( 38 | self, level, mode, action_repeat=4, render_size=(64, 64), 39 | action_set=ACTION_SET_DEFAULT, level_cache=None, seed=None, 40 | runfiles_path=None): 41 | assert mode in ('train', 'test') 42 | import deepmind_lab 43 | if runfiles_path: 44 | print('Setting DMLab runfiles path:', runfiles_path) 45 | deepmind_lab.set_runfiles_path(runfiles_path) 46 | self._config = {} 47 | self._config['width'] = render_size[0] 48 | self._config['height'] = render_size[1] 49 | self._config['logLevel'] = 'WARN' 50 | if mode == 'test': 51 | self._config['allowHoldOutLevels'] = 'true' 52 | self._config['mixerSeed'] = 0x600D5EED 53 | self._action_repeat = action_repeat 54 | self._random = np.random.RandomState(seed) 55 | self._env = deepmind_lab.Lab( 56 | level='contributed/dmlab30/'+level, 57 | observations=['RGB_INTERLEAVED'], 58 | config={k: str(v) for k, v in self._config.items()}, 59 | level_cache=level_cache) 60 | self._action_set = action_set 61 | self._last_image = None 62 | self._done = True 63 | 64 | @property 65 | def observation_space(self): 66 | shape = (self._config['height'], self._config['width'], 3) 67 | space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) 68 | return gym.spaces.Dict({'image': space}) 69 | 70 | @property 71 | def action_space(self): 72 | return gym.spaces.Discrete(len(self._action_set)) 73 | 74 | def reset(self): 75 | self._done = False 76 | self._env.reset(seed=self._random.randint(0, 2 ** 31 - 1)) 77 | obs = self._get_obs() 78 | return obs 79 | 80 | def step(self, action): 81 | raw_action = np.array(self._action_set[action], np.intc) 82 | reward = self._env.step(raw_action, num_steps=self._action_repeat) 83 | self._done = not self._env.is_running() 84 | obs = self._get_obs() 85 | return obs, reward, self._done, {} 86 | 87 | def render(self, *args, **kwargs): 88 | if kwargs.get('mode', 'rgb_array') != 'rgb_array': 89 | raise ValueError("Only render mode 'rgb_array' is supported.") 90 | del args # Unused 91 | del kwargs # Unused 92 | return self._last_image 93 | 94 | def close(self): 95 | self._env.close() 96 | 97 | def _get_obs(self): 98 | if self._done: 99 | image = 0 * self._last_image 100 | else: 101 | image = self._env.observations()['RGB_INTERLEAVED'] 102 | self._last_image = image 103 | return {'image': image} 104 | 105 | 106 | 107 | class DeepMindControl: 108 | 109 | def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): 110 | domain, task = name.split('_', 1) 111 | if domain == 'cup': # Only domain with multiple words. 
112 | domain = 'ball_in_cup' 113 | if isinstance(domain, str): 114 | from dm_control import suite 115 | self._env = suite.load(domain, task) 116 | else: 117 | assert task is None 118 | self._env = domain() 119 | self._action_repeat = action_repeat 120 | self._size = size 121 | if camera is None: 122 | camera = dict(quadruped=2).get(domain, 0) 123 | self._camera = camera 124 | 125 | @property 126 | def observation_space(self): 127 | spaces = {} 128 | for key, value in self._env.observation_spec().items(): 129 | spaces[key] = gym.spaces.Box( 130 | -np.inf, np.inf, value.shape, dtype=np.float32) 131 | spaces['image'] = gym.spaces.Box( 132 | 0, 255, self._size + (3,), dtype=np.uint8) 133 | return gym.spaces.Dict(spaces) 134 | 135 | @property 136 | def action_space(self): 137 | spec = self._env.action_spec() 138 | return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) 139 | 140 | def step(self, action): 141 | assert np.isfinite(action).all(), action 142 | reward = 0 143 | for _ in range(self._action_repeat): 144 | time_step = self._env.step(action) 145 | reward += time_step.reward or 0 146 | if time_step.last(): 147 | break 148 | obs = dict(time_step.observation) 149 | obs['image'] = self.render() 150 | done = time_step.last() 151 | info = {'discount': np.array(time_step.discount, np.float32)} 152 | return obs, reward, done, info 153 | 154 | def reset(self): 155 | time_step = self._env.reset() 156 | obs = dict(time_step.observation) 157 | obs['image'] = self.render() 158 | return obs 159 | 160 | def render(self, *args, **kwargs): 161 | if kwargs.get('mode', 'rgb_array') != 'rgb_array': 162 | raise ValueError("Only render mode 'rgb_array' is supported.") 163 | return self._env.physics.render(*self._size, camera_id=self._camera) 164 | 165 | 166 | class Atari: 167 | 168 | LOCK = threading.Lock() 169 | 170 | def __init__( 171 | self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, 172 | life_done=False, sticky_actions=True, all_actions=False): 173 | assert size[0] == size[1] 174 | import gym.wrappers 175 | import gym.envs.atari 176 | if name == 'james_bond': 177 | name = 'jamesbond' 178 | with self.LOCK: 179 | env = gym.envs.atari.AtariEnv( 180 | game=name, obs_type='image', frameskip=1, 181 | repeat_action_probability=0.25 if sticky_actions else 0.0, 182 | full_action_space=all_actions) 183 | # Avoid unnecessary rendering in inner env. 184 | env._get_obs = lambda: None 185 | # Tell wrapper that the inner env has no action repeat. 
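    # gym's AtariPreprocessing only accepts a frame skip greater than one for
    # environments whose spec id contains 'NoFrameskip', hence the dummy spec below.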
186 | env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') 187 | env = gym.wrappers.AtariPreprocessing( 188 | env, noops, action_repeat, size[0], life_done, grayscale) 189 | self._env = env 190 | self._grayscale = grayscale 191 | 192 | @property 193 | def observation_space(self): 194 | return gym.spaces.Dict({ 195 | 'image': self._env.observation_space, 196 | 'ram': gym.spaces.Box(0, 255, (128,), np.uint8), 197 | }) 198 | 199 | @property 200 | def action_space(self): 201 | return self._env.action_space 202 | 203 | def close(self): 204 | return self._env.close() 205 | 206 | def reset(self): 207 | with self.LOCK: 208 | image = self._env.reset() 209 | if self._grayscale: 210 | image = image[..., None] 211 | obs = {'image': image, 'ram': self._env.env._get_ram()} 212 | return obs 213 | 214 | def step(self, action): 215 | image, reward, done, info = self._env.step(action) 216 | if self._grayscale: 217 | image = image[..., None] 218 | obs = {'image': image, 'ram': self._env.env._get_ram()} 219 | return obs, reward, done, info 220 | 221 | def render(self, mode): 222 | return self._env.render(mode) 223 | 224 | 225 | class CollectDataset: 226 | 227 | def __init__(self, env, callbacks=None, precision=32): 228 | self._env = env 229 | self._callbacks = callbacks or () 230 | self._precision = precision 231 | self._episode = None 232 | 233 | def __getattr__(self, name): 234 | return getattr(self._env, name) 235 | 236 | def step(self, action): 237 | obs, reward, done, info = self._env.step(action) 238 | obs = {k: self._convert(v) for k, v in obs.items()} 239 | transition = obs.copy() 240 | if isinstance(action, dict): 241 | transition.update(action) 242 | else: 243 | transition['action'] = action 244 | transition['reward'] = reward 245 | transition['discount'] = info.get('discount', np.array(1 - float(done))) 246 | self._episode.append(transition) 247 | if done: 248 | for key, value in self._episode[1].items(): 249 | if key not in self._episode[0]: 250 | self._episode[0][key] = 0 * value 251 | episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} 252 | episode = {k: self._convert(v) for k, v in episode.items()} 253 | info['episode'] = episode 254 | for callback in self._callbacks: 255 | callback(episode) 256 | return obs, reward, done, info 257 | 258 | def reset(self): 259 | obs = self._env.reset() 260 | transition = obs.copy() 261 | # Missing keys will be filled with a zeroed out version of the first 262 | # transition, because we do not know what action information the agent will 263 | # pass yet. 264 | transition['reward'] = 0.0 265 | transition['discount'] = 1.0 266 | self._episode = [transition] 267 | return obs 268 | 269 | def _convert(self, value): 270 | value = np.array(value) 271 | if np.issubdtype(value.dtype, np.floating): 272 | dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] 273 | elif np.issubdtype(value.dtype, np.signedinteger): 274 | dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] 275 | elif np.issubdtype(value.dtype, np.uint8): 276 | dtype = np.uint8 277 | else: 278 | raise NotImplementedError(value.dtype) 279 | return value.astype(dtype) 280 | 281 | 282 | class TimeLimit: 283 | 284 | def __init__(self, env, duration): 285 | self._env = env 286 | self._duration = duration 287 | self._step = None 288 | 289 | def __getattr__(self, name): 290 | return getattr(self._env, name) 291 | 292 | def step(self, action): 293 | assert self._step is not None, 'Must reset environment.' 
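    # If the duration is exceeded below, the episode is cut short with done=True
    # while discount stays 1.0, so the timeout is not mistaken for a real terminal state.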
294 | obs, reward, done, info = self._env.step(action) 295 | self._step += 1 296 | if self._step >= self._duration: 297 | done = True 298 | if 'discount' not in info: 299 | info['discount'] = np.array(1.0).astype(np.float32) 300 | self._step = None 301 | return obs, reward, done, info 302 | 303 | def reset(self): 304 | self._step = 0 305 | return self._env.reset() 306 | 307 | 308 | class NormalizeActions: 309 | 310 | def __init__(self, env): 311 | self._env = env 312 | self._mask = np.logical_and( 313 | np.isfinite(env.action_space.low), 314 | np.isfinite(env.action_space.high)) 315 | self._low = np.where(self._mask, env.action_space.low, -1) 316 | self._high = np.where(self._mask, env.action_space.high, 1) 317 | 318 | def __getattr__(self, name): 319 | return getattr(self._env, name) 320 | 321 | @property 322 | def action_space(self): 323 | low = np.where(self._mask, -np.ones_like(self._low), self._low) 324 | high = np.where(self._mask, np.ones_like(self._low), self._high) 325 | return gym.spaces.Box(low, high, dtype=np.float32) 326 | 327 | def step(self, action): 328 | original = (action + 1) / 2 * (self._high - self._low) + self._low 329 | original = np.where(self._mask, original, action) 330 | return self._env.step(original) 331 | 332 | 333 | class OneHotAction: 334 | 335 | def __init__(self, env): 336 | assert isinstance(env.action_space, gym.spaces.Discrete) 337 | self._env = env 338 | self._random = np.random.RandomState() 339 | 340 | def __getattr__(self, name): 341 | return getattr(self._env, name) 342 | 343 | @property 344 | def action_space(self): 345 | shape = (self._env.action_space.n,) 346 | space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) 347 | space.sample = self._sample_action 348 | space.discrete = True 349 | return space 350 | 351 | def step(self, action): 352 | index = np.argmax(action).astype(int) 353 | reference = np.zeros_like(action) 354 | reference[index] = 1 355 | if not np.allclose(reference, action): 356 | raise ValueError(f'Invalid one-hot action:\n{action}') 357 | return self._env.step(index) 358 | 359 | def reset(self): 360 | return self._env.reset() 361 | 362 | def _sample_action(self): 363 | actions = self._env.action_space.n 364 | index = self._random.randint(0, actions) 365 | reference = np.zeros(actions, dtype=np.float32) 366 | reference[index] = 1.0 367 | return reference 368 | 369 | 370 | class RewardObs: 371 | 372 | def __init__(self, env): 373 | self._env = env 374 | 375 | def __getattr__(self, name): 376 | return getattr(self._env, name) 377 | 378 | @property 379 | def observation_space(self): 380 | spaces = self._env.observation_space.spaces 381 | assert 'reward' not in spaces 382 | spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) 383 | return gym.spaces.Dict(spaces) 384 | 385 | def step(self, action): 386 | obs, reward, done, info = self._env.step(action) 387 | obs['reward'] = reward 388 | return obs, reward, done, info 389 | 390 | def reset(self): 391 | obs = self._env.reset() 392 | obs['reward'] = 0.0 393 | return obs 394 | 395 | 396 | class SelectAction: 397 | 398 | def __init__(self, env, key): 399 | self._env = env 400 | self._key = key 401 | 402 | def __getattr__(self, name): 403 | return getattr(self._env, name) 404 | 405 | def step(self, action): 406 | return self._env.step(action[self._key]) 407 | --------------------------------------------------------------------------------
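A minimal sketch of how the wrappers above might be stacked around a DeepMind Control task, assuming dm_control is installed; the `make_env` helper, the task name, the callbacks, and the time limit are illustrative placeholders rather than the exact composition used by dreamer.py:

import wrappers


def make_env(callbacks=()):
  # Base DMC task; 'walker_walk' splits into domain 'walker' and task 'walk'.
  env = wrappers.DeepMindControl('walker_walk', action_repeat=2, size=(64, 64))
  # Map the agent's [-1, 1] actions back to the env's native action bounds.
  env = wrappers.NormalizeActions(env)
  # End episodes after a fixed number of (action-repeated) control steps.
  env = wrappers.TimeLimit(env, duration=500)
  # Record transitions and pass each finished episode to the callbacks.
  env = wrappers.CollectDataset(env, callbacks=callbacks)
  # Expose the previous reward as part of the observation dict.
  env = wrappers.RewardObs(env)
  return env


env = make_env(callbacks=[lambda ep: print(len(ep['reward']), 'transitions')])
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())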