├── linear_code ├── replearn │ ├── __init__.py │ ├── features.py │ ├── mountaincar.py │ ├── loadunload.py │ ├── rollout.py │ └── learn.py ├── setup.py ├── README.md ├── LICENSE └── experiments │ ├── plot.py │ ├── run_loadunload.py │ └── run_mountaincar.py ├── mujoco_code ├── agents │ ├── __init__.py │ └── alm.py ├── workspaces │ ├── __init__.py │ ├── distracted_env.py │ ├── common.py │ └── mujoco_workspace.py ├── cfgs │ ├── config.yaml │ └── agent │ │ └── alm.yaml ├── utils │ ├── __init__.py │ ├── system.py │ ├── torch_utils.py │ ├── replay_buffer.py │ ├── env.py │ └── logger.py ├── requirements.txt ├── LICENSE ├── train.py ├── README.md ├── vis_tool.py └── models.py ├── .gitignore ├── minigrid_code ├── requirements.txt ├── LICENSE ├── vis_tool.py ├── main.py ├── models.py ├── README.md ├── run.py ├── logger.py ├── r2d2replaybuffer.py └── agent.py └── readme.md /linear_code/replearn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mujoco_code/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .alm import AlmAgent 2 | -------------------------------------------------------------------------------- /mujoco_code/workspaces/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import make_agent, make_env 2 | from .mujoco_workspace import MujocoWorkspace 3 | -------------------------------------------------------------------------------- /mujoco_code/cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - agent@_global_: 'alm' 4 | - override hydra/hydra_logging: disabled 5 | - override hydra/job_logging: disabled -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .vscode/ 3 | __pycache__/ 4 | .ipynb_checkpoints 5 | *.pyc 6 | .cache/ 7 | 8 | 9 | # logs, a soft link to disk 10 | logs 11 | logs/ 12 | data/ 13 | debug/ 14 | plts/ 15 | results/ 16 | 17 | # slurm 18 | *.out 19 | -------------------------------------------------------------------------------- /mujoco_code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import linear_schedule, register_mbpo_environments, save_frames_as_gif 2 | from .replay_buffer import ReplayMemory 3 | from .torch_utils import ( 4 | weight_init, 5 | soft_update, 6 | hard_update, 7 | get_parameters, 8 | FreezeParameters, 9 | TruncatedNormal, 10 | Dirac, 11 | ) 12 | -------------------------------------------------------------------------------- /linear_code/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | install_requires = [ 4 | "jax", 5 | "optax", 6 | "numpy", 7 | "pandas", 8 | "dm-haiku", 9 | "dm-env", 10 | "rlax", 11 | "chex", 12 | "absl-py", 13 | ] 14 | 15 | setup( 16 | name="linear-representation-learning", 17 | version="", 18 | packages=["replearn"], 19 | install_requires=install_requires, 20 | url="", 21 | license="", 22 | author="Clement Gehring", 23 | author_email="", 24 | description="", 25 | ) 26 | -------------------------------------------------------------------------------- /mujoco_code/utils/system.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import datetime 4 | import dateutil.tz 5 | 6 | 7 | def reproduce(seed): 8 | """ 9 | This only fixes the randomness of numpy and Python's random module 10 | To fix the environment's randomness, please use 11 | env.seed(seed) 12 | env.action_space.np_random.seed(seed) 13 | We have added these in our training script 14 | """ 15 | assert seed >= 0 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | 19 | 20 | def now_str(): 21 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 22 | return now.strftime( 23 | "%Y-%m-%d-%H-%M-%S" 24 | ) # may cause collisions; append the PID to disambiguate 25 | -------------------------------------------------------------------------------- /mujoco_code/requirements.txt: -------------------------------------------------------------------------------- 1 | antlr4-python3-runtime==4.9.3 2 | certifi==2022.6.15 3 | charset-normalizer==2.1.0 4 | click==8.1.3 5 | cloudpickle==1.3.0 6 | Cython==0.29.30 7 | docker-pycreds==0.4.0 8 | future==0.18.2 9 | gitdb==4.0.9 10 | GitPython==3.1.27 11 | glfw==2.5.3 12 | gym==0.17.2 13 | hydra-core==1.2.0 14 | idna==3.3 15 | imageio==2.19.3 16 | importlib-resources==5.8.0 17 | joblib==1.1.0 18 | lockfile==0.12.2 19 | matplotlib==3.5.3 20 | numpy 21 | omegaconf==2.2.2 22 | packaging==21.3 23 | pathtools==0.1.2 24 | promise==2.3 25 | protobuf==3.20.1 26 | psutil==5.9.1 27 | pycparser==2.21 28 | pyglet==1.5.0 29 | pyparsing==3.0.9 30 | PyYAML==6.0 31 | requests==2.28.1 32 | rliable==1.0.8 33 | scikit-learn 34 | scipy 35 | sentry-sdk 36 | setproctitle 37 | shortuuid 38 | six 39 | smmap 40 | threadpoolctl 41 | torch 42 | wandb 43 | -------------------------------------------------------------------------------- /linear_code/README.md: -------------------------------------------------------------------------------- 1 | # Code for Illustrating Theorem 3 2 | 3 | Code contributors: [Clement Gehring](https://people.csail.mit.edu/gehring/) (main), [Tianwei Ni](https://twni2016.github.io/). 4 | 5 | ## Installation 6 | ```bash 7 | pip install -e . 8 | ``` 9 | 10 | ## Running 11 | 12 | ```bash 13 | python experiments/run_mountaincar.py 14 | python experiments/run_loadunload.py 15 | ``` 16 | For each environment, this runs every ZP target option ("Online", "Detached", "EMA") for 100 seeds and writes the resulting data to the `results/` folder. 17 | 18 | ## Plotting 19 | 20 | ```bash 21 | python experiments/plot.py \ 22 | --results_path results/mountaincar.pkl --figure_title "Mountain Car" 23 | python experiments/plot.py \ 24 | --results_path results/loadunload.pkl --with_legend --figure_title "Load-Unload" 25 | ``` 26 | 27 | This generates Figure 2 in the paper.
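To sanity-check the generated data without re-running the plotting script, the sketch below loads one of the pickles and reports, for each ZP target option, the median absolute cosine similarity between the two learned feature columns at the last logged step — the same collapse measure that `experiments/plot.py` visualizes. It only relies on the keys that the run scripts and `plot.py` read and write (`params`, `step`, `seed`, `use_stop_gradient`); any other fields in each log entry come from `replearn/learn.py`.

```python
# Minimal sketch: inspect results/loadunload.pkl produced by run_loadunload.py.
import pickle
import numpy as np

with open("results/loadunload.pkl", "rb") as f:
    results = pickle.load(f)  # a flat list of per-step log dicts

def abs_cosine(params):
    # absolute cosine similarity between the two learned feature columns
    x, y = params[:, 0], params[:, 1]
    return abs(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))

last_step = max(log["step"] for log in results)
for option in ("Online", "Detached", "EMA"):
    finals = [
        abs_cosine(log["params"])
        for log in results
        if log["use_stop_gradient"] == option and log["step"] == last_step
    ]
    print(f"{option}: median |cos| over {len(finals)} seeds = {np.median(finals):.2e}")
```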
28 | 29 | 30 | -------------------------------------------------------------------------------- /mujoco_code/workspaces/distracted_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | class DistractedWrapper(gym.Wrapper): 6 | def __init__(self, env, distraction: int, scale: float = 1.0) -> None: 7 | super().__init__(env) 8 | assert distraction > 0 9 | self.d = distraction 10 | self.scale = scale 11 | self.observation_space = gym.spaces.Box( 12 | low=-np.inf, high=np.inf, shape=self.reset().shape, dtype=np.float32 13 | ) 14 | 15 | def _get_distract_obs(self): 16 | return self.scale * np.random.normal(size=(self.d,)) 17 | 18 | def reset(self): 19 | obs = self.env.reset() 20 | return np.concatenate([obs, self._get_distract_obs()]) 21 | 22 | def step(self, action): 23 | obs, rew, done, info = self.env.step(action) 24 | return np.concatenate([obs, self._get_distract_obs()]), rew, done, info 25 | -------------------------------------------------------------------------------- /minigrid_code/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | astunparse==1.6.3 3 | cached-property==1.5.2 4 | cachetools==5.2.0 5 | certifi==2022.6.15 6 | charset-normalizer==2.0.12 7 | cloudpickle==2.1.0 8 | colorama==0.4.6 9 | cycler==0.11.0 10 | decorator==4.4.2 11 | flatbuffers==2.0 12 | gast==0.4.0 13 | google-auth==2.8.0 14 | google-auth-oauthlib==0.4.6 15 | google-pasta==0.2.0 16 | grpcio==1.38.1 17 | gymnasium 18 | minigrid==2.3.0 19 | gym-notices==0.0.7 20 | idna==3.3 21 | importlib-metadata==4.11.4 22 | kiwisolver==1.3.1 23 | libclang==14.0.1 24 | Markdown==3.3.7 25 | MarkupSafe==2.1.2 26 | matplotlib==3.4.2 27 | numpy==1.21.4 28 | oauthlib==3.2.0 29 | packaging==23.0 30 | Pillow==8.4.0 31 | proglog==0.1.10 32 | protobuf==3.19.4 33 | pyasn1==0.4.8 34 | pyasn1-modules==0.2.8 35 | pyparsing==3.0.9 36 | python-dateutil==2.8.2 37 | requests==2.28.0 38 | requests-oauthlib==1.3.1 39 | rliable==1.0.8 40 | rsa==4.8 41 | six==1.16.0 42 | termcolor==1.1.0 43 | torch==1.12.1 44 | torchvision==0.13.1 45 | tqdm==4.64.1 46 | typing_extensions==4.2.0 47 | urllib3==1.26.9 48 | Werkzeug==2.1.2 49 | wrapt==1.12.1 50 | zipp==3.8.0 51 | -------------------------------------------------------------------------------- /linear_code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Clement Gehring, Tianwei Ni 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /minigrid_code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Erfan Seyedsalehi, Tianwei Ni 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mujoco_code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tianwei Ni 4 | Copyright (c) 2022 Raj Ghugare 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /mujoco_code/cfgs/agent/alm.yaml: -------------------------------------------------------------------------------- 1 | #common 2 | agent: 'alm' 3 | device: 'cuda' 4 | seed: 1 5 | 6 | #benchmark 7 | benchmark: 'gym' 8 | id: 'Humanoid-v2' # ("HalfCheetah-v2", "Humanoid-v2", "Ant-v2", "Walker2d-v2", "Hopper-v2") 9 | distraction: 0 10 | scale: 0.1 11 | 12 | #data 13 | num_train_steps: 500000 14 | explore_steps: 5000 15 | max_episode_steps: 1000 16 | env_buffer_size: 100000 # humanoid-v2 will be automatically changed to 1e6 17 | batch_size: 512 18 | seq_len: 1 19 | 20 | #key hparams 21 | algo: td3 # {null, td3, alm-3, alm-1, alm-no-model, alm-no-model-ours} 22 | aux: rkl # {rkl, l2, none} 23 | aux_optim: ema # {ema, detach, online, none} 24 | aux_coef: 1.0 25 | disable_svg: true 26 | disable_reward: true 27 | freeze_critic: true 28 | online_encoder_actorcritic: true 29 | 30 | #learning 31 | gamma: 0.99 32 | tau: 0.005 33 | target_update_interval: 1 34 | lambda_cost: 0.1 35 | lr: {'encoder' : 0.0001, 'model' : 0.0001, 'reward' : 0.0001, 'critic' : 0.0001, 'actor' : 0.0001} 36 | max_grad_norm: 100.0 37 | 38 | #exploration 39 | expl_start: 1.0 40 | expl_end: 0.1 41 | expl_duration: 100000 42 | stddev_clip: 0.3 43 | 44 | #hidden_dims and layers 45 | latent_dims: 50 46 | hidden_dims: 512 47 | model_hidden_dims: 1024 48 | 49 | #bias evaluation 50 | eval_bias: False 51 | eval_bias_interval: 500 52 | 53 | #evaluation 54 | eval_episode_interval: 5000 55 | num_eval_episodes: 10 56 | 57 | #logging 58 | debug: false 59 | save_dir: "logs" 60 | log_interval: 500 61 | 62 | #saving 63 | save_snapshot: False 64 | save_snapshot_interval: 50000 65 | 66 | hydra: 67 | output_subdir: null 68 | run: 69 | dir: . 
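Since `train.py` loads this file through Hydra (`cfgs/config.yaml` pulls it in via the `agent@_global_` group), any field above can be overridden on the command line as `key=value` pairs, exactly as in the examples in `mujoco_code/README.md`. The sketch below mirrors one of those README settings on a different environment and seed; the particular value combination is illustrative only:

```bash
# override a few fields of cfgs/agent/alm.yaml on the command line
python train.py id=Walker2d-v2 algo=ours aux=l2 aux_optim=ema aux_coef=v-10.0 seed=3 debug=true
```

Setting `debug=true` switches the save directory to `debug/` and echoes the logs to stdout, and a negative `seed` makes `train.py` fall back to the process ID (both handled in `train.py`).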
-------------------------------------------------------------------------------- /mujoco_code/workspaces/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import utils 3 | from utils import logger 4 | 5 | 6 | def make_agent(env, device, cfg): 7 | if cfg.agent == "alm": 8 | from agents.alm import AlmAgent 9 | 10 | num_states = np.prod(env.observation_space.shape) 11 | num_actions = np.prod(env.action_space.shape) 12 | action_low = env.action_space.low[0] 13 | action_high = env.action_space.high[0] 14 | 15 | if cfg.id == "Humanoid-v2": 16 | cfg.env_buffer_size = 1000000 17 | buffer_size = min(cfg.env_buffer_size, cfg.num_train_steps) 18 | 19 | agent = AlmAgent( 20 | device, 21 | action_low, 22 | action_high, 23 | num_states, 24 | num_actions, 25 | buffer_size, 26 | cfg, 27 | ) 28 | 29 | else: 30 | raise NotImplementedError 31 | 32 | return agent 33 | 34 | 35 | def make_env(cfg): 36 | if cfg.benchmark == "gym": 37 | import gym 38 | 39 | if cfg.id == "T-Ant-v2" or cfg.id == "T-Humanoid-v2": 40 | utils.register_mbpo_environments() 41 | 42 | def get_env(cfg): 43 | env = gym.make(cfg.id) 44 | 45 | if cfg.distraction > 0: 46 | from workspaces.distracted_env import DistractedWrapper 47 | 48 | env = DistractedWrapper( 49 | env, 50 | distraction=cfg.distraction, 51 | scale=cfg.scale, 52 | ) 53 | 54 | env = gym.wrappers.RecordEpisodeStatistics(env) 55 | env.seed(seed=cfg.seed) 56 | env.observation_space.seed(cfg.seed) 57 | env.action_space.seed(cfg.seed) 58 | logger.log(env.observation_space.shape, env.action_space) 59 | return env 60 | 61 | return get_env(cfg), get_env(cfg) 62 | 63 | else: 64 | raise NotImplementedError 65 | -------------------------------------------------------------------------------- /mujoco_code/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | pid = str(os.getpid()) 5 | if "SLURM_JOB_ID" in os.environ: 6 | jobid = str(os.environ["SLURM_JOB_ID"]) 7 | else: 8 | jobid = pid 9 | 10 | from utils import logger, system 11 | from omegaconf import DictConfig, OmegaConf 12 | import hydra 13 | import warnings 14 | 15 | warnings.simplefilter("ignore", UserWarning) 16 | 17 | 18 | @hydra.main(config_path="cfgs", config_name="config", version_base=None) 19 | def main(cfg: DictConfig): 20 | if cfg.benchmark == "gym": 21 | from workspaces.mujoco_workspace import MujocoWorkspace as W 22 | else: 23 | raise NotImplementedError 24 | env_id = cfg.id 25 | if cfg.distraction > 0: 26 | env_id += f"-d{cfg.distraction}" 27 | 28 | if cfg.seed < 0: 29 | cfg.seed = int(pid) # to avoid conflict within a job which has same datetime 30 | 31 | run_name = f"{system.now_str()}+{jobid}-{pid}" 32 | format_strs = ["csv"] 33 | if cfg.debug: 34 | cfg.save_dir = "debug" 35 | format_strs.extend(["stdout", "log"]) # logger.log 36 | 37 | log_path = os.path.join(cfg.save_dir, env_id, run_name) 38 | logger.configure(dir=log_path, format_strs=format_strs, precision=4) 39 | 40 | existing_variants = { 41 | "alm-3": (False, False, False, False, "rkl", "ema", "v-1.0", 3), 42 | "alm-1": (False, False, False, False, "rkl", "ema", "v-1.0", 1), 43 | "alm-no-model": (False, True, False, False, "rkl", "ema", "v-1.0", 1), 44 | "alm-0": (True, True, False, False, "rkl", "ema", "v-1.0", 1), 45 | "alm-0-ours": (True, True, True, True, "rkl", "ema", "v-1.0", 1), 46 | "td3": (True, True, True, True, None, None, "v-0.0", 1), 47 | } 48 | 49 | if cfg.algo in existing_variants: 50 | ( 51 | 
cfg.disable_reward, 52 | cfg.disable_svg, 53 | cfg.freeze_critic, 54 | cfg.online_encoder_actorcritic, 55 | cfg.aux, 56 | cfg.aux_optim, 57 | cfg.aux_coef, 58 | cfg.seq_len, 59 | ) = existing_variants[cfg.algo] 60 | elif cfg.algo == "ours": 61 | ( 62 | cfg.disable_reward, 63 | cfg.disable_svg, 64 | cfg.freeze_critic, 65 | cfg.online_encoder_actorcritic, 66 | cfg.seq_len, 67 | ) = (True, True, True, True, 1) 68 | else: 69 | raise ValueError(cfg.algo) 70 | 71 | # write config to a yml 72 | with open(os.path.join(log_path, "flags.yml"), "w") as f: 73 | OmegaConf.save(cfg, f) 74 | 75 | workspace = W(cfg) 76 | workspace.train() 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /mujoco_code/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions as td 4 | from typing import Iterable 5 | 6 | # NN weight utils 7 | 8 | 9 | def weight_init(m): 10 | if isinstance(m, nn.Linear): 11 | nn.init.orthogonal_(m.weight.data) 12 | if hasattr(m.bias, "data"): 13 | m.bias.data.fill_(0.0) 14 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 15 | gain = nn.init.calculate_gain("relu") 16 | nn.init.orthogonal_(m.weight.data, gain) 17 | if hasattr(m.bias, "data"): 18 | m.bias.data.fill_(0.0) 19 | 20 | 21 | def soft_update(target, source, tau): 22 | for target_param, param in zip(target.parameters(), source.parameters()): 23 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 24 | 25 | 26 | def hard_update(target, source): 27 | for target_param, param in zip(target.parameters(), source.parameters()): 28 | target_param.data.copy_(param.data) 29 | 30 | 31 | # NN module utils 32 | 33 | 34 | def get_parameters(modules: Iterable[nn.Module]): 35 | model_parameters = [] 36 | for module in modules: 37 | model_parameters += list(module.parameters()) 38 | return model_parameters 39 | 40 | 41 | class FreezeParameters: 42 | def __init__(self, modules: Iterable[nn.Module]): 43 | self.modules = modules 44 | self.param_states = [p.requires_grad for p in get_parameters(self.modules)] 45 | 46 | def __enter__(self): 47 | for param in get_parameters(self.modules): 48 | param.requires_grad = False 49 | 50 | def __exit__(self, exc_type, exc_val, exc_tb): 51 | for i, param in enumerate(get_parameters(self.modules)): 52 | param.requires_grad = self.param_states[i] 53 | 54 | 55 | # torch dist utils 56 | 57 | 58 | class TruncatedNormal(td.Normal): 59 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 60 | super().__init__(loc, scale, validate_args=False) 61 | self.low = low 62 | self.high = high 63 | self.eps = eps 64 | 65 | def _clamp(self, x): 66 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 67 | x = x - x.detach() + clamped_x.detach() 68 | return x 69 | 70 | def sample(self, clip=None, sample_shape=torch.Size()): 71 | shape = self._extended_shape(sample_shape) 72 | eps = td.utils._standard_normal( 73 | shape, dtype=self.loc.dtype, device=self.loc.device 74 | ) 75 | eps *= self.scale 76 | if clip is not None: 77 | eps = torch.clamp(eps, -clip, clip) 78 | x = self.loc + eps 79 | return self._clamp(x) 80 | 81 | 82 | class Dirac: 83 | def __init__(self, loc): 84 | self.loc = loc # a tensor 85 | 86 | def sample(self): 87 | return self.loc.detach() 88 | 89 | def rsample(self): 90 | return self.loc 91 | 
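A minimal usage sketch for the target-network and gradient-freezing helpers defined above (run from inside `mujoco_code/`, where `utils/__init__.py` re-exports them; the `nn.Linear` modules and the random input tensor are toy placeholders, not the agent's actual networks). `FreezeParameters` temporarily sets `requires_grad=False` on the wrapped modules and restores the original flags on exit, while `soft_update` performs the Polyak averaging used for target networks:

```python
import torch
import torch.nn as nn

from utils import FreezeParameters, hard_update, soft_update

# toy stand-ins for a critic and its target network
critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
hard_update(critic_target, critic)  # start from identical weights

actor_out = torch.randn(8, 4, requires_grad=True)

# actor-style loss: gradients flow into actor_out, but not into the critic
with FreezeParameters([critic]):
    actor_loss = -critic(actor_out).mean()
actor_loss.backward()
assert critic.weight.grad is None and actor_out.grad is not None

# Polyak-average the target network (tau as in cfgs/agent/alm.yaml)
soft_update(critic_target, critic, tau=0.005)
```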
-------------------------------------------------------------------------------- /mujoco_code/utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ReplayMemory: 5 | def __init__(self, buffer_limit, obs_size, action_size, obs_dtype): 6 | print("buffer limit is = ", buffer_limit) 7 | self.buffer_limit = buffer_limit 8 | self.observation = np.empty((buffer_limit, obs_size), dtype=obs_dtype) 9 | self.next_observation = np.empty((buffer_limit, obs_size), dtype=obs_dtype) 10 | self.action = np.empty((buffer_limit, action_size), dtype=np.float32) 11 | self.reward = np.empty((buffer_limit,), dtype=np.float32) 12 | self.terminal = np.empty((buffer_limit,), dtype=bool) 13 | self.idx = 0 14 | self.full = False 15 | 16 | def push(self, transition): 17 | state, action, reward, next_state, done = transition 18 | self.observation[self.idx] = state 19 | self.next_observation[self.idx] = next_state 20 | self.action[self.idx] = action 21 | self.reward[self.idx] = reward 22 | self.terminal[self.idx] = done 23 | self.idx = (self.idx + 1) % self.buffer_limit 24 | self.full = self.full or self.idx == 0 25 | 26 | def sample(self, n): 27 | idxes = np.random.randint( 28 | 0, self.buffer_limit if self.full else self.idx, size=n 29 | ) 30 | return ( 31 | self.observation[idxes], 32 | self.action[idxes], 33 | self.reward[idxes], 34 | self.next_observation[idxes], 35 | self.terminal[idxes], 36 | ) 37 | 38 | def sample_seq(self, seq_len, batch_size): 39 | n = batch_size 40 | l = seq_len 41 | obs, act, rew, next_obs, term = self._retrieve_batch( 42 | np.asarray([self._sample_idx(l) for _ in range(n)]), n, l 43 | ) 44 | return obs, act, rew, next_obs, term 45 | 46 | def sample_probe_data(self, data_size): 47 | idxes = np.random.randint( 48 | 0, self.buffer_limit if self.full else self.idx, size=data_size 49 | ) 50 | return self.observation[idxes] 51 | 52 | def _sample_idx(self, L): 53 | valid_idx = False 54 | while not valid_idx: 55 | idx = np.random.randint(0, self.buffer_limit if self.full else self.idx - L) 56 | idxs = np.arange(idx, idx + L) % self.buffer_limit 57 | valid_idx = (not self.idx in idxs[1:]) and ( 58 | not self.terminal[idxs[:-1]].any() 59 | ) 60 | return idxs 61 | 62 | def _retrieve_batch(self, idxs, n, l): 63 | vec_idxs = idxs.transpose().reshape(-1) 64 | return ( 65 | self.observation[vec_idxs].reshape(l, n, -1), 66 | self.action[vec_idxs].reshape(l, n, -1), 67 | self.reward[vec_idxs].reshape(l, n), 68 | self.next_observation[vec_idxs].reshape(l, n, -1), 69 | self.terminal[vec_idxs].reshape(l, n), 70 | ) 71 | 72 | def __len__(self): 73 | return self.buffer_limit if self.full else self.idx + 1 74 | -------------------------------------------------------------------------------- /linear_code/replearn/features.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import itertools 3 | from typing import Any, NamedTuple, Protocol, TypeVar 4 | 5 | import numpy as np 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | import dm_env 11 | 12 | T = TypeVar("T", bound=Any, covariant=True) 13 | PyTree = Any 14 | Array = Any 15 | 16 | 17 | class Encoder(Protocol[T]): 18 | def apply(self, observation: PyTree) -> T: 19 | raise NotImplementedError 20 | 21 | 22 | def _as_iterable(bound, max_repeat=None): 23 | if not isinstance(bound, abc.Iterable): 24 | bound = itertools.repeat(bound, max_repeat) 25 | return bound 26 | 27 | 28 | def uniform_centers(env: 
dm_env.Environment, centers_per_dim: int): 29 | obs_spec = env.observation_spec() 30 | assert len(obs_spec.shape) == 1, "Only rank 1 observations are supported." 31 | 32 | minimums = _as_iterable(obs_spec.minimum, obs_spec.shape[0]) 33 | maximums = _as_iterable(obs_spec.maximum, obs_spec.shape[0]) 34 | 35 | centers = np.meshgrid( 36 | *[ 37 | np.linspace(lb, ub, centers_per_dim, endpoint=True) 38 | for lb, ub in zip(minimums, maximums) 39 | ] 40 | ) 41 | centers = np.stack([x.flatten() for x in centers], axis=0) 42 | 43 | return centers 44 | 45 | 46 | def normalized_scales(env: dm_env.Environment, scale: float): 47 | obs_spec = env.observation_spec() 48 | assert len(obs_spec.shape) == 1, "Only rank 1 observations are supported." 49 | 50 | minimums = list(_as_iterable(obs_spec.minimum, obs_spec.shape[0])) 51 | maximums = list(_as_iterable(obs_spec.maximum, obs_spec.shape[0])) 52 | span = np.subtract(maximums, minimums) 53 | 54 | return span * scale 55 | 56 | 57 | class RBFEncoder(NamedTuple): 58 | """A feature encoding using radial basis functions.""" 59 | 60 | centers: Array 61 | scales: Array 62 | normalized: bool 63 | 64 | def apply(self, inputs): 65 | diff = (inputs[..., None] - self.centers) / self.scales[..., None] 66 | neg_dist = -jnp.sum(diff**2, axis=-2) 67 | if self.normalized: 68 | return jax.nn.softmax(neg_dist) 69 | else: 70 | return jnp.exp(neg_dist) 71 | 72 | 73 | class OneHot(NamedTuple): 74 | """A one-hot encoder.""" 75 | 76 | dim: int 77 | 78 | def apply(self, inputs): 79 | if inputs.shape[-1] == 1: 80 | inputs = jnp.squeeze(inputs, -1) 81 | return jax.nn.one_hot(inputs, self.dim) 82 | 83 | 84 | class TruncatedHistoryEncoder(NamedTuple): 85 | horizon: int 86 | 87 | def apply(self, observations, actions): 88 | observations = jnp.pad(observations, [(self.horizon - 1, 0), (0, 0)]) 89 | actions = jnp.pad(actions, [(self.horizon - 1, 0), (0, 0)]) 90 | stacked_obs = [ 91 | self.index_to_history(i, observations, actions) 92 | for i in range(0, observations.shape[0] - self.horizon + 1) 93 | ] 94 | return jnp.stack(stacked_obs) 95 | 96 | def index_to_history(self, index, observations, actions): 97 | return jnp.concatenate( 98 | ( 99 | observations[index : index + self.horizon].reshape(-1), 100 | actions[index : index + self.horizon - 1].reshape(-1), 101 | ) 102 | ) 103 | -------------------------------------------------------------------------------- /linear_code/experiments/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | from absl import app 11 | from absl import flags 12 | 13 | if __name__ == "__main__": 14 | flags.DEFINE_string( 15 | "results_path", "./results/mountaincar.pkl", "Path to the saved results." 
16 | ) 17 | flags.DEFINE_string("output_path", None, "Filename for the saved figure.") 18 | flags.DEFINE_string("figure_title", None, "Title to add to the figure.") 19 | flags.DEFINE_boolean("with_legend", False, "If flag present, add a legend.") 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | sns.set_style("whitegrid", {"grid.linestyle": "--"}) 24 | plt.rcParams["figure.dpi"] = 300 25 | plt.rcParams["figure.figsize"] = (4, 3) 26 | plt.rcParams["axes.labelsize"] = 15 27 | plt.rcParams["axes.titlesize"] = 15 28 | plt.rcParams["xtick.labelsize"] = 14 29 | plt.rcParams["ytick.labelsize"] = 14 30 | plt.rcParams["legend.fontsize"] = 13 31 | plt.rcParams["axes.grid"] = True 32 | plt.rcParams["legend.loc"] = "best" 33 | plt.rcParams["lines.linewidth"] = 1.5 34 | plt.rcParams["axes.formatter.useoffset"] = False 35 | plt.rcParams["axes.formatter.offset_threshold"] = 1 36 | # plt.rcParams["font.size"] = 8 37 | plt.rcParams["font.family"] = "serif" 38 | plt.rcParams["font.serif"] = ["Liberation Serif"] 39 | plt.rcParams["text.usetex"] = True 40 | 41 | 42 | def cosine_similarity(x, y): 43 | x = x / np.linalg.norm(x) 44 | y = y / np.linalg.norm(y) 45 | return np.dot(x, y) 46 | 47 | 48 | def main(argv): 49 | if len(argv) > 1: 50 | raise app.UsageError("Too many command-line arguments.") 51 | 52 | with open(FLAGS.results_path, "rb") as f: 53 | results = pickle.load(f) 54 | 55 | X_AXIS_LABEL = "Iterations" 56 | Y_AXIS_LABEL = " abs. cosine similarity" 57 | 58 | for log in results: 59 | params = log.pop("params") 60 | log[Y_AXIS_LABEL] = np.abs(cosine_similarity(params[:, 0], params[:, 1])) 61 | log["Iterations"] = log["step"] 62 | 63 | data = pd.DataFrame.from_records(results) 64 | 65 | ax = sns.lineplot( 66 | data=data, 67 | x=X_AXIS_LABEL, 68 | y=Y_AXIS_LABEL, 69 | units="seed", 70 | hue="use_stop_gradient", 71 | estimator=None, # show all seeds 72 | legend=False, 73 | alpha=0.2, 74 | ) 75 | 76 | ax = sns.lineplot( 77 | data=data, 78 | x=X_AXIS_LABEL, 79 | y=Y_AXIS_LABEL, 80 | hue="use_stop_gradient", 81 | estimator=np.median, 82 | ax=ax, 83 | linewidth=2.0, 84 | palette=sns.color_palette("dark", 3, desat=0.9), 85 | ) 86 | plt.yscale("log") 87 | plt.yticks([10 ** (2 * i - 8) for i in range(4)]) 88 | plt.xticks(np.arange(0, 501, step=100)) 89 | plt.ylim(1e-10, 1) 90 | 91 | if FLAGS.with_legend: 92 | ax.legend(framealpha=0.2) # must use the returned ans 93 | else: 94 | ax.legend().set_visible(False) 95 | 96 | if FLAGS.figure_title: 97 | plt.title(FLAGS.figure_title) 98 | 99 | filename = FLAGS.output_path 100 | if filename is None: 101 | filename = os.path.splitext(FLAGS.results_path)[0] + ".pdf" 102 | plt.savefig(filename, bbox_inches="tight") 103 | 104 | 105 | if __name__ == "__main__": 106 | app.run(main) 107 | -------------------------------------------------------------------------------- /minigrid_code/vis_tool.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | pd.options.mode.chained_assignment = ( 4 | None # ignore all warnings like SettingWithCopyWarning 5 | ) 6 | 7 | import numpy as np 8 | from typing import Callable 9 | import os, glob 10 | import shutil 11 | import json 12 | 13 | 14 | def walk_through( 15 | path: str, 16 | metric: str, 17 | query_fn: Callable, 18 | start: int, 19 | end: int, 20 | steps: int, 21 | window: int, 22 | cutoff: float = 0.95, 23 | extrapolate: bool = False, 24 | delete: bool = False, 25 | ): 26 | def isnan(number): 27 | return np.isnan(float(number)) 28 | 29 | def smooth(df): 30 | try: 31 | df = 
df.dropna(subset=[metric]) # remove NaN rows 32 | except KeyError: 33 | print("!!key error csv", run) 34 | if delete: 35 | shutil.rmtree(run) 36 | print("deleted") 37 | return None 38 | 39 | if isnan(df["env_steps"].iloc[-1]) or df["env_steps"].iloc[-1] < cutoff * end: 40 | # an incomplete run 41 | print("!!incomplete csv", run, df["env_steps"].iloc[-1], end=" ") 42 | if delete: 43 | shutil.rmtree(run) 44 | print("deleted") 45 | else: 46 | print("\n") 47 | return None 48 | 49 | # smooth by moving average 50 | df[metric] = df[metric].rolling(window=window, min_periods=1).mean() 51 | 52 | # update the columns with interpolated values and aligned steps 53 | aligned_step = np.linspace(start, end, steps).astype(np.int32) 54 | if not extrapolate: 55 | ## we only do interpolation, not extrapolation 56 | aligned_step = aligned_step[aligned_step <= df["env_steps"].iloc[-1]] 57 | aligned_value = np.interp(aligned_step, df["env_steps"], df[metric]) 58 | 59 | # enlarge or reduce to same number of rows 60 | print(run, df.shape[0], df["env_steps"].iloc[-1]) 61 | 62 | extracted_df = pd.DataFrame( 63 | data={ 64 | "env_steps": aligned_step, 65 | metric: aligned_value, 66 | } 67 | ) 68 | 69 | return extracted_df 70 | 71 | dfs = [] 72 | i = 0 73 | 74 | runs = sorted(glob.glob(os.path.join(path, "*"))) 75 | 76 | for run in runs: 77 | with open(os.path.join(run, "config.json")) as f: 78 | flags = json.load(f) 79 | 80 | if not query_fn(flags): 81 | continue 82 | 83 | csv_path = os.path.join(run, "progress.csv") 84 | try: 85 | df = pd.read_csv(open(csv_path)) 86 | except pd.errors.EmptyDataError: 87 | print("!!empty csv", run) 88 | if delete: 89 | shutil.rmtree(run) 90 | print("deleted") 91 | continue 92 | 93 | df = smooth(df) 94 | if df is None: 95 | continue 96 | i += 1 97 | 98 | # concat flags (dot) 99 | pd_flags = pd.json_normalize(flags) 100 | df_flag = pd.concat([pd_flags] * df.shape[0], axis=0) # repeat rows 101 | df_flag.index = df.index # copy index 102 | df = pd.concat([df, df_flag], axis=1) 103 | dfs.append(df) 104 | 105 | print("\n in total:", i) 106 | dfs = pd.concat(dfs, ignore_index=True) 107 | return dfs 108 | -------------------------------------------------------------------------------- /linear_code/replearn/mountaincar.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import dm_env 4 | from dm_env import specs 5 | 6 | import numpy as np 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | from replearn import rollout 12 | 13 | 14 | class MountainCarPolicy(rollout.EpsilonGreedyPolicy): 15 | def preferences(self, observations): 16 | # prefer action along the current velocity 17 | return jax.lax.cond( 18 | observations[1] < 0, 19 | lambda _: jnp.array([1.0, 0.0, 0.0]), 20 | lambda _: jnp.array([0.0, 0.0, 1.0]), 21 | None, 22 | ) 23 | 24 | 25 | class MountainCar(dm_env.Environment): 26 | """Implementation of the Mountain Car domain. 27 | 28 | Moore, Andrew William. "Efficient memory-based learning for robot control." (1990). 29 | 30 | Default parameters use values presented in Example 10.1 by Sutton & Barto (2018): 31 | Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 32 | 2018. 
33 | 34 | """ 35 | 36 | def __init__( 37 | self, 38 | seed: Optional[int] = None, 39 | min_pos: float = -1.2, 40 | max_pos: float = 0.6, 41 | min_init_pos: float = -0.6, 42 | max_init_pos: float = -0.4, 43 | max_speed: float = 0.07, 44 | goal_pos: float = 0.5, 45 | force: float = 0.001, 46 | gravity: float = 0.0025, 47 | ): 48 | self._min_pos = min_pos 49 | self._max_pos = max_pos 50 | self._min_init_pos = min_init_pos 51 | self._max_init_pos = max_init_pos 52 | self._max_speed = max_speed 53 | self._goal_pos = goal_pos 54 | self._force = force 55 | self._gravity = gravity 56 | 57 | self._rng = np.random.default_rng(seed) 58 | self._position = 0.0 59 | self._velocity = 0.0 60 | 61 | def _observation(self): 62 | return np.array([self._position, self._velocity], np.float32) 63 | 64 | def reset(self): 65 | self._position = self._rng.uniform(self._min_init_pos, self._max_init_pos) 66 | self._velocity = 0.0 67 | return dm_env.restart(self._observation()) 68 | 69 | def step(self, action): 70 | """Step the environment 71 | 72 | :param action: 0, 1, 2 correspond to actions left, idle, right, respectively. 73 | :return: the next timestep 74 | """ 75 | next_vel = ( 76 | self._velocity 77 | + self._force * (action - 1) 78 | - self._gravity * np.cos(self._position * 3) 79 | ) 80 | self._velocity = np.clip(next_vel, -self._max_speed, self._max_speed) 81 | 82 | self._position = np.clip( 83 | self._position + next_vel, self._min_pos, self._max_pos 84 | ) 85 | 86 | reward = -1 87 | obs = self._observation() 88 | 89 | if self._position >= self._goal_pos: 90 | return dm_env.termination(reward=0.0, observation=obs) 91 | 92 | return dm_env.transition(reward=reward, observation=obs) 93 | 94 | def observation_spec(self): 95 | return specs.BoundedArray( 96 | shape=(2,), 97 | dtype=np.float32, 98 | minimum=[self._min_pos, -self._max_speed], 99 | maximum=[self._max_pos, self._max_speed], 100 | ) 101 | 102 | def action_spec(self): 103 | """Actions 0, 1, 2 correspond to actions left, idle, right, respectively.""" 104 | return specs.DiscreteArray(3, name="action") 105 | -------------------------------------------------------------------------------- /minigrid_code/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | pid = str(os.getpid()) 5 | if "SLURM_JOB_ID" in os.environ: 6 | jobid = str(os.environ["SLURM_JOB_ID"]) 7 | else: 8 | jobid = pid 9 | 10 | import argparse 11 | import json 12 | from run import run_exp 13 | import logger 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument("--debug", type=bool, default=False) 20 | parser.add_argument("--seed", type=int, default=1) 21 | parser.add_argument("--env_name", type=str) 22 | 23 | # Training 24 | parser.add_argument("--cuda", type=bool, default=True) 25 | parser.add_argument("--num_steps", type=int, default=4000000) 26 | ## Freq 27 | parser.add_argument("--logging_freq", type=int, default=10000) 28 | parser.add_argument("--target_update_interval", type=int, default=1) 29 | parser.add_argument("--rl_update_every_n_steps", type=int, default=10) 30 | parser.add_argument("--rl_updates_per_step", type=int, default=1) 31 | parser.add_argument("--model_updates_per_step", type=int, default=1) 32 | parser.add_argument("--random_actions_until", type=int, default=0) 33 | 34 | # Buffer 35 | parser.add_argument("--replay_size", type=int, default=400000) 36 | parser.add_argument("--batch_size", type=int, default=256) 37 | ## Len 38 | 
parser.add_argument("--burn_in_len", type=int, default=50) 39 | parser.add_argument("--learning_obs_len", type=int, default=10) 40 | parser.add_argument("--forward_len", type=int, default=5) 41 | 42 | # Representation learning 43 | parser.add_argument("--aux", type=str, default="None") 44 | parser.add_argument("--aux_optim", type=str, default="None") 45 | parser.add_argument("--aux_coef", type=float, default=0.5) 46 | parser.add_argument("--aux_lr", type=float, default=1e-3) 47 | parser.add_argument("--AIS_state_size", type=int, default=128) 48 | 49 | # RL (DDQN) 50 | parser.add_argument("--hidden_size", type=int, default=128) 51 | parser.add_argument("--gamma", type=float, default=0.99) 52 | parser.add_argument("--tau", type=float, default=0.005) 53 | parser.add_argument("--rl_lr", type=float, default=1e-3) 54 | parser.add_argument("--TD_loss", type=str, default="mse") 55 | ## Exploration 56 | parser.add_argument("--EPS_start", type=float, default=1.0) 57 | parser.add_argument("--EPS_end", type=float, default=0.05) 58 | parser.add_argument("--EPS_decay", type=int, default=400000) 59 | parser.add_argument("--EPS_decay_type", type=str, default="exponential") 60 | parser.add_argument("--test_epsilon", type=float, default=0.0) 61 | 62 | args = parser.parse_args() 63 | 64 | # convert to dictionary 65 | params = vars(args) 66 | print(params) 67 | 68 | ################################## 69 | ### CREATE DIRECTORY FOR LOGGING 70 | ################################## 71 | format_strs = ["csv"] 72 | save_dir = "logs" 73 | if args.debug: 74 | save_dir = "debug" 75 | format_strs.extend(["stdout", "log"]) # logger.log 76 | 77 | unique_id = time.strftime("%Y-%m-%d-%H:%M:%S") + "_" + jobid + "-" + pid 78 | logdir = os.path.join(save_dir, args.env_name, unique_id) 79 | params["logdir"] = logdir 80 | logger.configure(dir=logdir, format_strs=format_strs) 81 | 82 | config_path = os.path.join(logdir, "config.json") 83 | with open(config_path, "w") as fp: 84 | json.dump(params, fp, indent=4) 85 | 86 | ################### 87 | ### RUN TRAINING 88 | ################### 89 | 90 | if params["seed"] < 0: 91 | params["seed"] = int( 92 | pid 93 | ) # to avoid conflict within a job which has same datetime 94 | run_exp(params) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /linear_code/replearn/loadunload.py: -------------------------------------------------------------------------------- 1 | from typing import Any, NamedTuple, Optional 2 | 3 | import dm_env 4 | from dm_env import specs 5 | 6 | import numpy as np 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | from replearn import rollout 12 | 13 | 14 | class LoadUnloadState(NamedTuple): 15 | position: int 16 | loaded: bool 17 | 18 | 19 | class LoadUnloadPolicyState(NamedTuple): 20 | key: Any 21 | last_action: int 22 | 23 | 24 | class LoadUnloadPolicy: 25 | def __init__(self, key, switch_prob=0.2): 26 | self._state = LoadUnloadPolicyState(key, 0) 27 | self._swtich_prob = switch_prob 28 | 29 | def _sample_and_next_state(state, observations): 30 | key, a_key = jax.random.split(state.key) 31 | action = self.sample( 32 | LoadUnloadPolicyState(a_key, state.last_action), observations 33 | ) 34 | return LoadUnloadPolicyState(key, action), action 35 | 36 | self.sample_and_next_state = jax.jit(_sample_and_next_state) 37 | 38 | def sample(self, state, observations): 39 | def _maybe_switch_action(): 40 | return jax.lax.cond( 41 | jax.random.uniform(state.key) < self._swtich_prob, 
42 | lambda: (state.last_action + 1) % 2, 43 | lambda: state.last_action, 44 | ) 45 | 46 | # if at unload go right, if at load go left, else keep previous action with random switch 47 | return jax.lax.switch( 48 | observations[0], [lambda: 1, _maybe_switch_action, lambda: 0] 49 | ) 50 | 51 | def stateful_sample(self, observations): 52 | self._state, action = self.sample_and_next_state(self._state, observations) 53 | return action 54 | 55 | def set_rng_key(self, key): 56 | self._state = LoadUnloadPolicyState(key, self._state.last_action) 57 | 58 | 59 | class LoadUnload(dm_env.Environment): 60 | def __init__(self, seed: Optional[int] = None, load_position: Optional[int] = 6): 61 | self._state = LoadUnloadState(0, False) 62 | self._load_position = load_position 63 | self._rng = np.random.default_rng(seed) 64 | 65 | def _observation(self): 66 | if self._state.position == 0: 67 | obs = 0 68 | elif self._state.position == self._load_position: 69 | obs = 2 70 | else: 71 | obs = 1 72 | return np.array([obs], np.int32) 73 | 74 | def reset(self): 75 | new_pos = self._rng.integers(self._load_position, size=1)[0] 76 | self._state = LoadUnloadState( 77 | new_pos, 78 | new_pos == self._load_position, 79 | ) 80 | return dm_env.restart(self._observation()) 81 | 82 | def step(self, action): 83 | pos, loaded = self._state 84 | if action == 0: 85 | new_pos = max((pos - 1, 0)) 86 | elif action == 1: 87 | new_pos = min((pos + 1, self._load_position)) 88 | else: 89 | raise ValueError(f"Unrecognized action '{action}'") 90 | 91 | reward = -1 92 | if new_pos == 0: 93 | if loaded: 94 | reward = 100 95 | loaded = False 96 | elif new_pos == self._load_position: 97 | loaded = True 98 | 99 | self._state = LoadUnloadState(new_pos, loaded) 100 | obs = self._observation() 101 | 102 | return dm_env.transition(reward=reward, observation=obs) 103 | 104 | def observation_spec(self): 105 | return specs.DiscreteArray(3, name="observation") 106 | 107 | def action_spec(self): 108 | """Actions 0, 1 correspond to actions left, right, respectively.""" 109 | return specs.DiscreteArray(2, name="action") 110 | -------------------------------------------------------------------------------- /mujoco_code/README.md: -------------------------------------------------------------------------------- 1 | # Code for State Representation Learning in Standard and Distracting MDPs (Section 5.1 & 5.2) 2 | 3 | Code contributor: [Tianwei Ni](https://twni2016.github.io/). 4 | 5 | ## Installation 6 | 7 | We use python 3.7+ and list the basic requirements in `requirements.txt`. 8 | 9 | ## Key Flags 10 | The configuration file is `cfgs/agent/alm.yaml`. 
11 | 12 | Environments: 13 | - **Standard MuJoCo** (Section 5.1): set `id` from "HalfCheetah-v2", "Humanoid-v2", "Ant-v2", "Walker2d-v2", "Hopper-v2" 14 | - **Distracted MuJoCo** (Section 5.2): set `distraction=128` and `scale=1.0` for 128 distractors with standard Gaussian noises 15 | 16 | Compared algorithms: 17 | - ALM(3): set `algo=alm-3` 18 | - ALM-no-model: set `algo=alm-no-model` 19 | - ALM(0): set `algo=alm-0` 20 | 21 | Our minimalist $\phi_L$: set `algo=ours` and 22 | - `aux`: select from `fkl, rkl, l2` 23 | - `aux_optim`: select from `ema, detach, online` 24 | 25 | Learning $\phi_O$ and $\phi_{Q^*}$: set `algo=ours` and 26 | - `aux`: select from `op-l2, op-kl, null` 27 | 28 | ## Examples 29 | 30 | To reproduce original ALM(3) on Humanoid-v2: 31 | ```bash 32 | python train.py id=Humanoid-v2 algo=alm-3 33 | ``` 34 | 35 | To reproduce our minimalist $\phi_L$ with l2 objective and EMA targets on Ant-v2: 36 | ```bash 37 | python train.py id=Ant-v2 algo=ours aux=l2 aux_optim=ema aux_coef=v-10.0 38 | ``` 39 | 40 | To reproduce our minimalist $\phi_L$ with reverse KL objective and online targets on Ant-v2: 41 | ```bash 42 | python train.py id=Ant-v2 algo=ours aux=rkl aux_optim=online aux_coef=v-1.0 43 | ``` 44 | 45 | To reproduce learning $\phi_O$ with forward KL objective on distracted HalfCheetah-v2 with 256 distractors: 46 | ```bash 47 | python train.py id=HalfCheetah-v2 distraction=256 scale=1.0 algo=ours aux=op-kl aux_optim=null aux_coef=v-1.0 48 | ``` 49 | 50 | To reproduce learning $\phi_{Q^*}$ on distracted HalfCheetah-v2 with 256 distractors: 51 | ```bash 52 | python train.py id=HalfCheetah-v2 distraction=256 scale=1.0 algo=ours aux=null aux_optim=null aux_coef=v-0.0 53 | ``` 54 | 55 | You will see the logging and executed config files in `logs/` folder. 56 | 57 | ## Logged Results and Plotting 58 | 59 | The log files used in our paper is provided at [Google Drive](https://drive.google.com/file/d/1KaxHySEX3xNCfqUyMsPM2sLzo96SQZd5/view?usp=sharing) (~2.8GB; maybe redundant with some unpublished results). You can download and unzip it to this folder and name it as `logs`. 60 | 61 | We use the [`vis.ipynb`](https://github.com/twni2016/self-predictive-rl/blob/main/mujoco_code/vis.ipynb) for generating plots in our paper. 62 | Below are the commands to generate specific figures in the paper. 63 | 64 | ### Plotting Standard MuJoCo results 65 | In Part 1, in product(), choose `[0, ]` in `distraction` 66 | 67 | - Figure 3: learning curves. Choose 68 | ```python 69 | metric, y_label, sci_axis = "return", "episode return", "both" 70 | tag = "" 71 | hue = "aux" 72 | style = None 73 | # in query_fn(), select the line "return flags["algo"] == "alm-3"" 74 | ``` 75 | - Figure 4, 11, and 12: ablation on ZP targets. Choose 76 | ```python 77 | metric, y_label, sci_axis = "return", "episode return", "both" 78 | # metric, y_label, sci_axis = "rank-2", "matrix rank", "x" 79 | # metric, y_label, sci_axis = "l2", "ZP loss", "x" 80 | tag = "l2" # "fkl", "rkl" 81 | hue = "aux_optim" 82 | style = None 83 | # in query_fn(), select the line "return False" 84 | ``` 85 | - Figure 10: ablation on ALM variants. 
Choose 86 | ```python 87 | metric, y_label, sci_axis = "return", "episode return", "both" 88 | tag = "ablate-" 89 | hue = "aux" 90 | style = None 91 | # in query_fn(), select the line "return flags["algo"] in ["alm-3", "alm-no-model", "alm-no-model-1"]" 92 | ``` 93 | 94 | ### Plotting Distracting MuJoCo results 95 | 96 | First, in query_fn(), select the line `return False`; in product(), choose `[2**4, 2**5, 2**6, 2**7, 2**8]` in `distraction`. 97 | 98 | - Figure 13: learning curves. In Part 1, 99 | ```python 100 | plt.rcParams["axes.titlesize"] = 11 # for distractors 101 | metric, y_label, sci_axis = "return", "episode return", "both" 102 | tag = "" 103 | hue = "aux" 104 | style = None 105 | ``` 106 | - Figure 5, aggregated plots. In Part 2, 107 | ```python 108 | metric, y_label, sci_axis = "return", "episode return", "y" 109 | hue = "aux" 110 | ``` 111 | 112 | ## Acknowledgement 113 | Our codebase has been largely build on Raj's codebase [ALM](https://github.com/RajGhugare19/alm). 114 | -------------------------------------------------------------------------------- /linear_code/replearn/rollout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, Sequence, Tuple 2 | import functools 3 | 4 | import numpy as np 5 | 6 | import dm_env 7 | from dm_env import specs 8 | 9 | import chex 10 | import jax 11 | import rlax 12 | from rlax._src import distributions 13 | 14 | 15 | class Policy(Protocol): 16 | def sample(self, key, observations): 17 | raise NotImplementedError 18 | 19 | def stateful_sample(self, observations): 20 | raise NotImplementedError 21 | 22 | def set_rng_key(self, key): 23 | raise NotImplementedError 24 | 25 | 26 | class DiscretePolicy(Policy, Protocol): 27 | action_distribution: distributions.DiscreteDistribution 28 | 29 | def probs(self, observations): 30 | raise NotImplementedError 31 | 32 | 33 | class EpsilonGreedyPolicy(DiscretePolicy): 34 | def __init__(self, key, epsilon): 35 | self._key = key 36 | self.action_distribution = rlax.epsilon_greedy(epsilon) 37 | 38 | def _sample_and_split(key, observations): 39 | key, a_key = jax.random.split(key) 40 | return key, self.sample(a_key, observations) 41 | 42 | self.sample_and_split = jax.jit(_sample_and_split) 43 | 44 | def preferences(self, observations): 45 | raise NotImplementedError 46 | 47 | def sample(self, key, observations): 48 | return self.action_distribution.sample(key, self.preferences(observations)) 49 | 50 | def probs(self, observations): 51 | return self.action_distribution.probs(self.preferences(observations)) 52 | 53 | def set_rng_key(self, key): 54 | self._key = key 55 | 56 | def stateful_sample(self, observations): 57 | self._key, action = self.sample_and_split(self._key, observations) 58 | return action 59 | 60 | 61 | def generate_trajectory( 62 | key: chex.PRNGKey, 63 | env: dm_env.Environment, 64 | policy: Policy, 65 | max_steps: Optional[int] = None, 66 | ) -> Tuple[Sequence[dm_env.TimeStep], Sequence]: 67 | t = 0 68 | timestep = env.reset() 69 | trajectory = [timestep] 70 | actions = [] 71 | 72 | policy.set_rng_key(key) 73 | 74 | while (not timestep.last()) and (max_steps is None or t < max_steps): 75 | action = policy.stateful_sample(timestep.observation) 76 | timestep = env.step(action) 77 | 78 | t += 1 79 | trajectory.append(timestep) 80 | actions.append(action) 81 | 82 | return trajectory, actions 83 | 84 | 85 | def traj_to_observation_array(trajectory): 86 | _, _, _, observations = zip(*trajectory) 87 | return 
np.array(observations) 88 | 89 | 90 | def rollout_dataset( 91 | key, 92 | *, 93 | env_cls, 94 | policy, 95 | history_encoder, 96 | act_encoder, 97 | obs_encoder, 98 | max_traj_length: int, 99 | num_traj: Optional[int] = None, 100 | num_steps: Optional[int] = None 101 | ): 102 | if num_traj == num_steps: 103 | raise ValueError( 104 | ( 105 | "Either `num_traj` or `num_steps` is required. Providing a value for both is not " 106 | "supported." 107 | ) 108 | ) 109 | 110 | env_seed, policy_seed = np.random.SeedSequence(key).spawn(2) 111 | policy_key = policy_seed.generate_state(2) 112 | 113 | env = env_cls(seed=env_seed.generate_state(1)[0]) 114 | data = [] 115 | 116 | traj_count = 0 117 | step_count = 0 118 | while (num_traj is None or traj_count < num_traj) and ( 119 | num_steps is None or step_count < num_steps 120 | ): 121 | traj_len_limit = max_traj_length 122 | if num_steps is not None: 123 | traj_len_limit = min((traj_len_limit, num_steps - step_count)) 124 | 125 | traj_key, policy_key = jax.random.split(policy_key) 126 | traj, actions = generate_trajectory( 127 | traj_key, env, policy, max_steps=traj_len_limit 128 | ) 129 | 130 | actions = act_encoder.apply(np.array(actions)) 131 | observations = traj_to_observation_array(traj) 132 | observations = obs_encoder.apply(observations) 133 | history = history_encoder.apply(observations, actions) 134 | 135 | data.append((history[:-1], actions, history[1:])) 136 | 137 | traj_count += 1 138 | step_count += data[-1][0].shape[0] 139 | 140 | states, actions, next_states = [np.concatenate(x) for x in zip(*data)] 141 | return states, actions, next_states 142 | -------------------------------------------------------------------------------- /mujoco_code/vis_tool.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | pd.options.mode.chained_assignment = ( 4 | None # ignore all warnings like SettingWithCopyWarning 5 | ) 6 | 7 | import numpy as np 8 | from typing import Callable 9 | import os, glob 10 | import shutil 11 | import yaml 12 | from ml_collections.config_dict import ConfigDict 13 | 14 | 15 | def _flatten_dict(input_dict, parent_key="", sep="."): 16 | """ 17 | Based on https://github.com/google/flax/blob/main/flax/metrics/tensorboard.py 18 | Flattens and simplifies dict such that it can be used by hparams. 19 | Args: 20 | input_dict: Input dict, e.g., from ConfigDict. 21 | parent_key: String used in recursion. 22 | sep: String used to separate parent and child keys. 23 | Returns: 24 | Flattened dict. 25 | """ 26 | items = [] 27 | for k, v in input_dict.items(): 28 | new_key = parent_key + sep + k if parent_key else k 29 | 30 | # Valid types according to https://github.com/tensorflow/tensorboard/blob/1204566da5437af55109f7a4af18f9f8b7c4f864/tensorboard/plugins/hparams/summary_v2.py 31 | valid_types = ( 32 | bool, 33 | int, 34 | float, 35 | str, 36 | np.bool_, 37 | np.integer, 38 | np.floating, 39 | np.character, 40 | ) 41 | 42 | if isinstance(v, dict) or isinstance(v, ConfigDict): 43 | # Recursively flatten the dict. 
44 | if isinstance(v, ConfigDict): 45 | v = v.to_dict() 46 | items.extend(_flatten_dict(v, new_key, sep=sep).items()) 47 | continue 48 | elif not isinstance(v, valid_types): 49 | # Cast any incompatible values as strings such that they can be handled by hparams 50 | v = str(v) 51 | items.append((new_key, v)) 52 | return dict(items) 53 | 54 | 55 | def walk_through( 56 | path: str, 57 | metric: str, 58 | query_fn: Callable, 59 | start: int, 60 | end: int, 61 | steps: int, 62 | window: int, 63 | delete: bool = False, 64 | ): 65 | def isnan(number): 66 | return np.isnan(float(number)) 67 | 68 | def smooth(df): 69 | try: 70 | df = df.dropna(subset=[metric]) # remove NaN rows 71 | except KeyError: 72 | print("!!key error csv", run) 73 | return None 74 | 75 | if isnan(df["env_steps"].iloc[-1]) or df["env_steps"].iloc[-1] < 0.9 * end: 76 | # an incomplete run 77 | print("!!incomplete csv", run, df["env_steps"].iloc[-1], end=" ") 78 | if delete: 79 | shutil.rmtree(run) 80 | print("deleted") 81 | else: 82 | print("\n") 83 | return None 84 | 85 | # smooth by moving average 86 | df[metric] = df[metric].rolling(window=window, min_periods=1).mean() 87 | 88 | # update the columns with interpolated values and aligned steps 89 | aligned_step = np.linspace(start, end, steps).astype(np.int32) 90 | aligned_value = np.interp(aligned_step, df["env_steps"], df[metric]) 91 | 92 | # enlarge or reduce to same number of rows 93 | print(run, df.shape[0], df["env_steps"].iloc[-1]) 94 | 95 | extracted_df = pd.DataFrame( 96 | data={ 97 | "env_steps": aligned_step, 98 | metric: aligned_value, 99 | } 100 | ) 101 | 102 | return extracted_df 103 | 104 | dfs = [] 105 | i = 0 106 | 107 | runs = sorted(glob.glob(os.path.join(path, "*"))) 108 | 109 | for run in runs: 110 | with open(os.path.join(run, "flags.yml")) as f: 111 | flags = yaml.safe_load(f) 112 | 113 | if not query_fn(flags): 114 | continue 115 | 116 | csv_path = os.path.join(run, "progress.csv") 117 | try: 118 | df = pd.read_csv(open(csv_path)) 119 | except pd.errors.EmptyDataError: 120 | print("!!empty csv", run) 121 | continue 122 | 123 | df = smooth(df) 124 | if df is None: 125 | continue 126 | i += 1 127 | 128 | # concat flags (dot) 129 | flags["logdir"] = run 130 | pd_flags = pd.json_normalize(_flatten_dict(flags)) 131 | df_flag = pd.concat([pd_flags] * df.shape[0], axis=0) # repeat rows 132 | df_flag.index = df.index # copy index 133 | df = pd.concat([df, df_flag], axis=1) 134 | dfs.append(df) 135 | # print(flags) 136 | 137 | print("\n in total:", i) 138 | dfs = pd.concat(dfs, ignore_index=True) 139 | return dfs 140 | -------------------------------------------------------------------------------- /linear_code/experiments/run_loadunload.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 4 | 5 | from typing import NamedTuple, Optional 6 | 7 | import pickle 8 | 9 | from absl import app 10 | from absl import flags 11 | 12 | import numpy as np 13 | import pandas as pd 14 | 15 | import jax 16 | 17 | import optax 18 | 19 | from replearn import learn 20 | from replearn import loadunload 21 | from replearn import features 22 | from replearn import rollout 23 | 24 | 25 | if __name__ == "__main__": 26 | flags.DEFINE_string("results_dir", "./results", "Directory used to log results.") 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | # for i in {1..100}; do echo $(od -A n -t u -N 4 /dev/urandom | tr -d ' \n'); done 31 | SEEDS = [ 32 | 1843764854, 33 | 2953627709, 34 | 
2065696022, 35 | 3917761262, 36 | 1592259201, 37 | 684225940, 38 | 3906814804, 39 | 855070892, 40 | 79122374, 41 | 3353291887, 42 | 2001425368, 43 | 3649968832, 44 | 1137905990, 45 | 1274987999, 46 | 1821861786, 47 | 3683081310, 48 | 1087561443, 49 | 310321600, 50 | 1055175732, 51 | 121547637, 52 | 4044866360, 53 | 182248956, 54 | 4039913460, 55 | 462825347, 56 | 3727679027, 57 | 3215526904, 58 | 2431647752, 59 | 2379353061, 60 | 2323226982, 61 | 3725743208, 62 | 2031918674, 63 | 3762025650, 64 | 425696606, 65 | 805171965, 66 | 2503275839, 67 | 2277247045, 68 | 2109158367, 69 | 181242376, 70 | 3956246306, 71 | 2755595351, 72 | 4187306323, 73 | 1007867152, 74 | 1242463926, 75 | 4129796788, 76 | 1099410125, 77 | 209730990, 78 | 64549074, 79 | 712869140, 80 | 3522339780, 81 | 3428373530, 82 | 2464126123, 83 | 3456720685, 84 | 503202288, 85 | 518482939, 86 | 862737849, 87 | 2403136178, 88 | 159923561, 89 | 2839661397, 90 | 2140359683, 91 | 2108678269, 92 | 1984270380, 93 | 678399733, 94 | 358224968, 95 | 4124224329, 96 | 3459659839, 97 | 3008777333, 98 | 1884818714, 99 | 2158764360, 100 | 3267115782, 101 | 1498615144, 102 | 729227282, 103 | 356343867, 104 | 3273136234, 105 | 600066107, 106 | 3613546418, 107 | 1637623759, 108 | 1043304407, 109 | 2854775057, 110 | 2055801193, 111 | 1136497228, 112 | 1506477464, 113 | 3358102518, 114 | 3061257360, 115 | 3644648965, 116 | 3559804656, 117 | 3972212350, 118 | 963994575, 119 | 1947277982, 120 | 4279881374, 121 | 156505623, 122 | 2226220832, 123 | 4149186854, 124 | 1200204930, 125 | 1194710917, 126 | 3409768682, 127 | 2569256998, 128 | 3500793866, 129 | 606886659, 130 | 2720394887, 131 | 2696168484, 132 | ] 133 | 134 | 135 | def main(argv): 136 | if len(argv) > 1: 137 | raise app.UsageError("Too many command-line arguments.") 138 | 139 | print(f"Using result dir: {FLAGS.results_dir}") 140 | 141 | action_encoder = features.OneHot(2) 142 | observation_encoder = features.OneHot(3) 143 | history_encoder = features.TruncatedHistoryEncoder(20) 144 | 145 | encoder = learn.create_latent_encoder(2) 146 | optimizer = optax.sgd(1e-2) 147 | 148 | results = [] 149 | for i, seed in enumerate(SEEDS): 150 | print(f"Starting run {i} with seed={seed}") 151 | data_seed, train_seed = np.random.SeedSequence(seed).spawn(2) 152 | s_t, a_t, s_tp1 = rollout.rollout_dataset( 153 | data_seed.generate_state(2), 154 | env_cls=loadunload.LoadUnload, 155 | policy=loadunload.LoadUnloadPolicy(None), 156 | history_encoder=history_encoder, 157 | act_encoder=action_encoder, 158 | obs_encoder=observation_encoder, 159 | max_traj_length=200, 160 | num_traj=10, 161 | ) 162 | for use_stop_gradient in ["Online", "Detached", "EMA"]: 163 | logs = learn.train( 164 | key=train_seed.generate_state(2), 165 | optimizer=optimizer, 166 | encoder=encoder, 167 | states=s_t, 168 | actions=a_t, 169 | next_states=s_tp1, 170 | num_steps=500, 171 | log_n_steps=10, 172 | use_stop_gradient=use_stop_gradient, 173 | ) 174 | for log in logs: 175 | log.update({"seed": seed, "use_stop_gradient": use_stop_gradient}) 176 | results = results + logs 177 | 178 | with open(os.path.join(FLAGS.results_dir, "loadunload.pkl"), "wb") as f: 179 | pickle.dump(results, f) 180 | 181 | 182 | if __name__ == "__main__": 183 | app.run(main) 184 | -------------------------------------------------------------------------------- /minigrid_code/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from 
torch.nn.utils.rnn import pack_padded_sequence 5 | 6 | 7 | class SeqEncoder(nn.Module): 8 | """ 9 | rho in AIS, phi in RL literature. 10 | Deterministic model z = phi(h) 11 | """ 12 | 13 | def __init__(self, num_obs, num_actions, AIS_state_size): 14 | super(SeqEncoder, self).__init__() 15 | input_ndims = ( 16 | num_obs + num_actions + 1 17 | ) # including reward, but it is uninformative 18 | self.AIS_state_size = AIS_state_size 19 | self.fc1 = nn.Linear(input_ndims, AIS_state_size) 20 | self.fc2 = nn.Linear(AIS_state_size, AIS_state_size) 21 | self.lstm = nn.LSTM(AIS_state_size, AIS_state_size, batch_first=True) 22 | 23 | self.apply(weights_init_) 24 | 25 | def get_initial_hidden(self, batch_size, device): # TODO: 26 | return ( 27 | torch.zeros(1, batch_size, self.AIS_state_size).to(device), 28 | torch.zeros(1, batch_size, self.AIS_state_size).to(device), 29 | ) 30 | 31 | def forward( 32 | self, 33 | x, 34 | batch_size, 35 | hidden, 36 | device, 37 | batch_lengths, 38 | pack_sequence=True, 39 | ): 40 | if hidden == None: 41 | hidden = self.get_initial_hidden(batch_size, device) 42 | x = F.elu(self.fc1(x)) 43 | x = F.elu(self.fc2(x)) 44 | if pack_sequence is True: 45 | x = pack_padded_sequence( 46 | x, batch_lengths, batch_first=True, enforce_sorted=False 47 | ) 48 | # print('packed',x.data.shape) 49 | x, hidden = self.lstm(x, hidden) 50 | return x, hidden 51 | 52 | 53 | class LatentModel(nn.Module): 54 | """ 55 | psi in AIS, P_theta in RL. 56 | Deterministic latent transition models. 57 | E[o' | z, a] or E[z' | z, a], depends on num_obs 58 | """ 59 | 60 | def __init__(self, num_obs, num_actions, AIS_state_size): 61 | super(LatentModel, self).__init__() 62 | input_ndims = AIS_state_size + num_actions 63 | self.fc1_d = nn.Linear(input_ndims, AIS_state_size // 2) 64 | self.fc2_d = nn.Linear(AIS_state_size // 2, num_obs) 65 | 66 | self.apply(weights_init_) 67 | 68 | def forward(self, x): 69 | x_d = F.elu(self.fc1_d(x)) 70 | obs = self.fc2_d(x_d) 71 | return obs 72 | 73 | 74 | class AISModel(nn.Module): 75 | """ 76 | psi in AIS, P_theta in RL. 77 | Deterministic transition and reward models. 
78 | E[o' | z, a] or E[z' | z, a] AND E[r | z, a] 79 | """ 80 | 81 | def __init__(self, num_obs, num_actions, AIS_state_size): 82 | super(AISModel, self).__init__() 83 | input_ndims = AIS_state_size + num_actions 84 | self.fc1_d = nn.Linear(input_ndims, AIS_state_size // 2) 85 | self.fc2_d = nn.Linear(AIS_state_size // 2, num_obs) 86 | self.fc1_r = nn.Linear(input_ndims, AIS_state_size // 2) 87 | self.fc2_r = nn.Linear(AIS_state_size // 2, 1) 88 | 89 | self.apply(weights_init_) 90 | 91 | def forward(self, x): 92 | x_d = F.elu(self.fc1_d(x)) 93 | obs = self.fc2_d(x_d) 94 | x_r = F.elu(self.fc1_r(x)) 95 | rew = self.fc2_r(x_r) 96 | return obs, rew 97 | 98 | 99 | class QNetwork_discrete(nn.Module): 100 | def __init__(self, num_inputs, num_actions, hidden_dim): 101 | super(QNetwork_discrete, self).__init__() 102 | 103 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 104 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 105 | self.linear3 = nn.Linear(hidden_dim, num_actions) 106 | 107 | self.apply(weights_init_) 108 | 109 | def forward(self, state): 110 | x1 = F.elu(self.linear1(state)) 111 | x1 = F.elu(self.linear2(x1)) 112 | x1 = self.linear3(x1) 113 | 114 | return x1 115 | 116 | 117 | def convert_int_to_onehot(value, num_values): 118 | onehot = torch.zeros(num_values) 119 | if value >= 0: # ignore negative index 120 | onehot[int(value)] = 1.0 121 | return onehot 122 | 123 | 124 | def weights_init_(m, gain=1): 125 | if isinstance(m, nn.Linear): 126 | torch.nn.init.xavier_uniform_(m.weight, gain=gain) 127 | torch.nn.init.constant_(m.bias, 0) 128 | 129 | 130 | def soft_update(target, source, tau): 131 | for target_param, param in zip(target.parameters(), source.parameters()): 132 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 133 | 134 | 135 | def hard_update(target, source): 136 | for target_param, param in zip(target.parameters(), source.parameters()): 137 | target_param.data.copy_(param.data) 138 | -------------------------------------------------------------------------------- /linear_code/replearn/learn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import numpy as np 7 | from copy import deepcopy 8 | import haiku as hk 9 | import optax 10 | 11 | 12 | class Parameters(NamedTuple): 13 | encoder: Any 14 | transition: Any 15 | 16 | 17 | def create_latent_encoder(latent_size): 18 | def encode_latent_state(state): 19 | return hk.Linear( 20 | latent_size, 21 | with_bias=False, 22 | w_init=hk.initializers.Orthogonal(), 23 | )(state) 24 | 25 | return hk.transform(encode_latent_state) 26 | 27 | 28 | def apply_transition(trans_matrix, latent_state, action): 29 | state_action = jnp.concatenate((latent_state, action)) 30 | return jnp.dot(trans_matrix.T, state_action) 31 | 32 | 33 | def train( 34 | key, 35 | optimizer, 36 | encoder, 37 | states, 38 | actions, 39 | next_states, 40 | use_stop_gradient: str, 41 | num_steps, 42 | log_n_steps, 43 | ): 44 | def solve_transition_params(encoder_params, target_encoder_params): 45 | z_t = encoder.apply(encoder_params, None, states) 46 | za_t = jnp.concatenate((z_t, actions), axis=1) 47 | z_tp1 = encoder.apply(target_encoder_params, None, next_states) 48 | opt_trans_params, *_ = jnp.linalg.lstsq(za_t, z_tp1) 49 | return opt_trans_params 50 | 51 | def loss(encoder_params, target_encoder_params, s_t, a_t, s_tp1): 52 | batch_encoder_apply = jax.vmap(encoder.apply, (None, None, 0)) 53 | z_t = 
batch_encoder_apply(encoder_params, None, s_t) 54 | 55 | if use_stop_gradient == "Detached": 56 | z_tp1 = batch_encoder_apply( 57 | jax.lax.stop_gradient(encoder_params), None, s_tp1 58 | ) 59 | opt_trans_params = jax.lax.stop_gradient( 60 | solve_transition_params(encoder_params, encoder_params) 61 | ) 62 | elif use_stop_gradient == "Online": 63 | z_tp1 = batch_encoder_apply(encoder_params, None, s_tp1) 64 | opt_trans_params = jax.lax.stop_gradient( 65 | solve_transition_params(encoder_params, encoder_params) 66 | ) 67 | else: # EMA 68 | z_tp1 = batch_encoder_apply(target_encoder_params, None, s_tp1) 69 | opt_trans_params = jax.lax.stop_gradient( 70 | solve_transition_params(encoder_params, target_encoder_params) 71 | ) 72 | 73 | estimated_z_tp1 = jax.vmap(apply_transition, (None, 0, 0))( 74 | opt_trans_params, z_t, a_t 75 | ) 76 | error = estimated_z_tp1 - z_tp1 77 | 78 | return 0.5 * jnp.mean(jnp.sum(error**2, axis=-1)) 79 | 80 | @jax.jit 81 | def step(params, target_params, opt_state, s_t, a_t, s_tp1): 82 | loss_value, grads = jax.value_and_grad(loss)( 83 | params, target_params, s_t, a_t, s_tp1 84 | ) 85 | updates, opt_state = optimizer.update(grads, opt_state, params) 86 | params = optax.apply_updates(params, updates) 87 | target_params = target_update(target_params, params) 88 | 89 | return params, target_params, opt_state, loss_value 90 | 91 | def target_update( 92 | target_params, 93 | new_params, 94 | tau: float = 0.005, 95 | ): 96 | return jax.tree_util.tree_map( 97 | lambda p, tp: p * tau + tp * (1 - tau), new_params, target_params 98 | ) 99 | 100 | assert use_stop_gradient in ["Online", "Detached", "EMA"] 101 | 102 | key, init_key = jax.random.split(key) 103 | encoder_params = encoder.init(init_key, states[0]) 104 | target_encoder_params = deepcopy(encoder_params) 105 | opt_state = optimizer.init(encoder_params) 106 | 107 | loss_value = jax.jit(loss)( 108 | encoder_params, target_encoder_params, states, actions, next_states 109 | ) 110 | 111 | logs = [] 112 | for i in range(num_steps): 113 | if i % log_n_steps == 0: 114 | params = encoder_params["linear"]["w"] 115 | logs.append( 116 | { 117 | "step": i, 118 | "loss": float(loss_value), 119 | "params": np.array(params), 120 | } 121 | ) 122 | 123 | encoder_params, target_encoder_params, opt_state, loss_value = step( 124 | params=encoder_params, 125 | target_params=target_encoder_params, 126 | opt_state=opt_state, 127 | s_t=states, 128 | a_t=actions, 129 | s_tp1=next_states, 130 | ) 131 | 132 | params = encoder_params["linear"]["w"] 133 | logs.append( 134 | { 135 | "step": num_steps, 136 | "loss": float(loss_value), 137 | "params": np.array(params), 138 | } 139 | ) 140 | 141 | return logs 142 | -------------------------------------------------------------------------------- /linear_code/experiments/run_mountaincar.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 4 | 5 | from typing import NamedTuple, Optional 6 | 7 | import pickle 8 | 9 | from absl import app 10 | from absl import flags 11 | 12 | import numpy as np 13 | import pandas as pd 14 | 15 | import jax 16 | 17 | import optax 18 | 19 | from replearn import learn 20 | from replearn import mountaincar 21 | from replearn import features 22 | from replearn import rollout 23 | 24 | 25 | if __name__ == "__main__": 26 | flags.DEFINE_string("results_dir", "./results", "Directory used to log results.") 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | # for i in {1..20}; do echo 
$(od -A n -t u -N 4 /dev/urandom | tr -d ' \n'); done 31 | SEEDS = [ 32 | 1843764854, 33 | 2953627709, 34 | 2065696022, 35 | 3917761262, 36 | 1592259201, 37 | 684225940, 38 | 3906814804, 39 | 855070892, 40 | 79122374, 41 | 3353291887, 42 | 2001425368, 43 | 3649968832, 44 | 1137905990, 45 | 1274987999, 46 | 1821861786, 47 | 3683081310, 48 | 1087561443, 49 | 310321600, 50 | 1055175732, 51 | 121547637, 52 | 4044866360, 53 | 182248956, 54 | 4039913460, 55 | 462825347, 56 | 3727679027, 57 | 3215526904, 58 | 2431647752, 59 | 2379353061, 60 | 2323226982, 61 | 3725743208, 62 | 2031918674, 63 | 3762025650, 64 | 425696606, 65 | 805171965, 66 | 2503275839, 67 | 2277247045, 68 | 2109158367, 69 | 181242376, 70 | 3956246306, 71 | 2755595351, 72 | 4187306323, 73 | 1007867152, 74 | 1242463926, 75 | 4129796788, 76 | 1099410125, 77 | 209730990, 78 | 64549074, 79 | 712869140, 80 | 3522339780, 81 | 3428373530, 82 | 2464126123, 83 | 3456720685, 84 | 503202288, 85 | 518482939, 86 | 862737849, 87 | 2403136178, 88 | 159923561, 89 | 2839661397, 90 | 2140359683, 91 | 2108678269, 92 | 1984270380, 93 | 678399733, 94 | 358224968, 95 | 4124224329, 96 | 3459659839, 97 | 3008777333, 98 | 1884818714, 99 | 2158764360, 100 | 3267115782, 101 | 1498615144, 102 | 729227282, 103 | 356343867, 104 | 3273136234, 105 | 600066107, 106 | 3613546418, 107 | 1637623759, 108 | 1043304407, 109 | 2854775057, 110 | 2055801193, 111 | 1136497228, 112 | 1506477464, 113 | 3358102518, 114 | 3061257360, 115 | 3644648965, 116 | 3559804656, 117 | 3972212350, 118 | 963994575, 119 | 1947277982, 120 | 4279881374, 121 | 156505623, 122 | 2226220832, 123 | 4149186854, 124 | 1200204930, 125 | 1194710917, 126 | 3409768682, 127 | 2569256998, 128 | 3500793866, 129 | 606886659, 130 | 2720394887, 131 | 2696168484, 132 | ] 133 | 134 | 135 | def main(argv): 136 | if len(argv) > 1: 137 | raise app.UsageError("Too many command-line arguments.") 138 | 139 | print(f"Using result dir: {FLAGS.results_dir}") 140 | 141 | env = mountaincar.MountainCar() 142 | obs_encoder = features.RBFEncoder( 143 | centers=features.uniform_centers(env, 10), 144 | scales=features.normalized_scales(env, 0.15), 145 | normalized=False, 146 | ) 147 | action_encoder = features.OneHot(3) 148 | history_encoder = features.TruncatedHistoryEncoder(1) 149 | 150 | encoder = learn.create_latent_encoder(2) 151 | optimizer = optax.sgd(1e-2) 152 | 153 | results = [] 154 | for i, seed in enumerate(SEEDS): 155 | print(f"Starting run {i} with seed={seed}") 156 | data_seed, train_seed = np.random.SeedSequence(seed).spawn(2) 157 | s_t, a_t, s_tp1 = rollout.rollout_dataset( 158 | data_seed.generate_state(2), 159 | env_cls=mountaincar.MountainCar, 160 | policy=mountaincar.MountainCarPolicy(None, 0.1), 161 | history_encoder=history_encoder, 162 | act_encoder=action_encoder, 163 | obs_encoder=obs_encoder, 164 | max_traj_length=200, 165 | num_traj=10, 166 | ) 167 | for use_stop_gradient in ["Online", "Detached", "EMA"]: 168 | logs = learn.train( 169 | key=train_seed.generate_state(2), 170 | optimizer=optimizer, 171 | encoder=encoder, 172 | states=s_t, 173 | actions=a_t, 174 | next_states=s_tp1, 175 | num_steps=500, 176 | log_n_steps=10, 177 | use_stop_gradient=use_stop_gradient, 178 | ) 179 | for log in logs: 180 | log.update({"seed": seed, "use_stop_gradient": use_stop_gradient}) 181 | results = results + logs 182 | 183 | with open(os.path.join(FLAGS.results_dir, "mountaincar.pkl"), "wb") as f: 184 | pickle.dump(results, f) 185 | 186 | 187 | if __name__ == "__main__": 188 | app.run(main) 189 | 
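Both runners above serialize their logs as a flat list of dicts, one entry per logged step, with keys `step`, `loss`, `params`, `seed`, and `use_stop_gradient` (see `learn.train` and the runner loop). Beyond `experiments/plot.py`, a minimal post-processing sketch along the following lines can load such a pickle and summarize a simple collapse indicator; the smallest singular value of the encoder weight matrix is an illustrative choice here, not necessarily the exact quantity plotted in the paper.

```python
# Illustrative post-processing sketch (assumes run_mountaincar.py has been run):
# load results/mountaincar.pkl and summarize a simple collapse indicator
# per ZP-target variant ("Online", "Detached", "EMA").
import pickle

import numpy as np
import pandas as pd

with open("results/mountaincar.pkl", "rb") as f:
    logs = pickle.load(f)  # flat list of dicts produced by learn.train + the runner

df = pd.DataFrame(
    [
        {
            "step": log["step"],
            "loss": log["loss"],
            "seed": log["seed"],
            "target": log["use_stop_gradient"],
            # smallest singular value of the (num_features x latent_dim) encoder matrix;
            # values near zero suggest a (near-)collapsed representation
            "min_sv": np.linalg.svd(log["params"], compute_uv=False).min(),
        }
        for log in logs
    ]
)

# average over seeds, per training step and ZP-target variant
summary = df.groupby(["target", "step"])["min_sv"].mean().unstack(0)
print(summary.tail())
```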
-------------------------------------------------------------------------------- /minigrid_code/README.md: -------------------------------------------------------------------------------- 1 | # Code for History Representation Learning in Sparse-Reward POMDPs (Section 5.3) 2 | 3 | Code contributors: [Erfan Seyedsalehi](https://openreview.net/profile?id=~Erfan_Seyedsalehi2) (main), [Tianwei Ni](https://twni2016.github.io/). 4 | 5 | Benchmark: [MiniGrid](https://minigrid.farama.org/environments/minigrid/ 6 | ) benchmark composed of 20 tasks, featuring sparse rewards and partial observability. 7 | 8 | Baseline: a single-thread version of [R2D2](https://openreview.net/forum?id=r1lyTjAqYX), named as R2D2 below. 9 | 10 | 11 | ## Installation 12 | 13 | We use python 3.7+ and list the basic requirements in `requirements.txt`. 14 | 15 | 16 | ## Examples 17 | 18 | To reproduce R2D2 in SimpleCrossingS9N1: 19 | ```bash 20 | python main.py --num_steps 4000000 --env_name MiniGrid-SimpleCrossingS9N1-v0 \ 21 | --aux None 22 | ``` 23 | 24 | To reproduce our minimalist algorithm (end-to-end ZP with EMA target) in SimpleCrossingS9N1: 25 | ```bash 26 | python main.py --num_steps 4000000 --env_name MiniGrid-SimpleCrossingS9N1-v0 \ 27 | --aux ZP --aux_coef 1.0 --aux_optim ema 28 | ``` 29 | 30 | To reproduce the minimalist algorithm (end-to-end OP) in SimpleCrossingS9N1: 31 | ```bash 32 | python main.py --num_steps 4000000 --env_name MiniGrid-SimpleCrossingS9N1-v0 \ 33 | --aux OP --aux_coef 0.01 34 | ``` 35 | 36 | To reproduce the phased algorithm (RP + ZP with EMA) in SimpleCrossingS9N1: 37 | ```bash 38 | python main.py --num_steps 4000000 --env_name MiniGrid-SimpleCrossingS9N1-v0 \ 39 | --aux AIS-P2 --aux_coef 1.0 40 | ``` 41 | 42 | To reproduce the phased algorithm (RP + OP) in SimpleCrossingS9N1: 43 | ```bash 44 | python main.py --num_steps 4000000 --env_name MiniGrid-SimpleCrossingS9N1-v0 \ 45 | --aux AIS --aux_coef 1.0 46 | ``` 47 | 48 | ## Logged Results and Plotting 49 | 50 | The log files used in our paper is provided at [Google Drive](https://drive.google.com/file/d/1abVEBh7hrk9kdPjzsENR30Tih80iU5Qb/view?usp=sharing) (~125MB). You can download and unzip it to this folder and name it as `logs`. 51 | 52 | We use the [`vis.ipynb`](https://github.com/twni2016/self-predictive-rl/blob/main/minigrid_code/vis.ipynb) for generating plots in our paper. Below are the commands to generate specific figures in the paper: 53 | - Figure 14: individual-task episode return curves; Figure 6a: aggregated episode return curves. In Part 1 (for Figure 14) or Part 2 (for Figure 6a), choose 54 | ```python 55 | metric, y_label, sci_axis = "return", "episode return", "x" 56 | tag = "" 57 | hue = "aux" 58 | style = "modular" 59 | ``` 60 | - Figure 15: individual-task matrix rank curves; Figure 6b: aggregated matrix rank curves. In Part 1 (for Figure 15) or Part 2 (for Figure 6b), choose 61 | ```python 62 | metric, y_label, sci_axis = "rank-3", "matrix rank", "x" 63 | tag = "ZP" 64 | hue = "aux_optim" 65 | style = None 66 | ``` 67 | Keep the values of the other variables and run the whole part to generate all the plots. Please be patient: it takes 1min to generate one individual-task plot and takes <20min to generate one aggregated plot. 68 | 69 | ## Flags 70 | 71 | This program accepts the following command line arguments: 72 | 73 | | Option | Description | 74 | | --------------- | ----------- | 75 | | `--aux` | This specifies whether model-learning is done or not. 'AIS' is for model learning (RQL-AIS). 
'None' is for no model learning (ND-R2D2). | 76 | | `--env_name` | The environment name. | 77 | | `--batch_size` | The batch size used for AIS updates and reinforcement learning updates. This specifies the number of samples drawn from the buffer. Each trajectory has a fixed length (`learning_obs_len`). | 78 | | `--hidden_size` | The number of neurons in the hidden layers of the Q network. | 79 | | `--gamma` | Discount factor. | 80 | | `--AIS_state_size` | The size of the hidden vector and the output of the LSTM used as the state representation for the POMDP. | 81 | | `--rl_lr` | The learning rate used for updating the Q networks and the LSTMs (for ND-R2D2), or only the Q network (for RQL-AIS). | 82 | | `--aux_lr` | The learning rate used for updating the AIS components (for RQL-AIS). | 83 | | `--num_steps` | Total number of training steps taken in the environment. | 84 | | `--target_update_interval` | This specifies the environment-step interval after which the target Q network (and the target LSTM in the case of ND-R2D2) is updated. | 85 | | `--replay_size` | This specifies the number of episodes that are stored in the replay memory. After the replay buffer is filled, new experience episodes will overwrite the least recent episodes in the buffer. | 86 | | `--aux_coef` | The hyperparameter that specifies how the reward learning loss and the next-observation prediction loss are averaged in the AIS learning phase. | 87 | | `--logging_freq` | The frequency, in environment steps, at which we evaluate the agent, log the results, and save the neural network parameters to disk. | 88 | | `--rl_update_every_n_steps` | This specifies the frequency, in environment steps, at which we perform reinforcement learning updates. | 89 | | `--EPS_start` | This specifies the start value for the epsilon hyperparameter used in Q-learning for exploration. | 90 | | `--EPS_decay` | This specifies the decay rate for the epsilon hyperparameter used in Q-learning for exploration. | 91 | | `--EPS_end` | This specifies the end value for the epsilon hyperparameter used in Q-learning for exploration. | 92 | | `--burn_in_len` | Length of the preceding burn-in sequence saved with each sample in the R2D2 buffer. | 93 | | `--learning_obs_len` | Sequence length of R2D2 samples. | 94 | | `--forward_len` | The multi-step Q-learning length. | 95 | | `--test_epsilon` | Epsilon value used at test time. Default is 0. | 96 | 97 | ## Acknowledgement 98 | 99 | Our codebase has been largely built on Erfan's codebase [RQL-AIS](https://github.com/esalehi1996/POMDP_RL). 
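The flags table above documents `--EPS_start`, `--EPS_decay`, and `--EPS_end`, but the exact epsilon-greedy schedule lives in `agent.py` and is not shown here. For orientation only, a common exponential decay consistent with these three flags looks like the sketch below; treat it as an assumption about typical usage, not the repository's exact formula.

```python
import math

def epsilon_by_step(step: int, eps_start: float, eps_end: float, eps_decay: float) -> float:
    """Typical exponential epsilon-greedy schedule (illustrative; not the repository's exact code)."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / eps_decay)

# e.g. eps_start=1.0, eps_end=0.05, eps_decay=1e5: epsilon decays from 1.0 toward 0.05
# as the environment step count grows, matching the roles of the three flags above.
```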
100 | -------------------------------------------------------------------------------- /minigrid_code/run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from agent import Agent 3 | from r2d2replaybuffer import r2d2_ReplayMemory 4 | import torch 5 | import gymnasium as gym 6 | import time 7 | import logger 8 | 9 | 10 | def run_exp(args): 11 | ## Create env 12 | assert "MiniGrid-" in args["env_name"] 13 | env = gym.make(args["env_name"]) 14 | test_env = gym.make(args["env_name"]) 15 | 16 | obs_dim = np.prod(env.observation_space["image"].shape) # only use image obs 17 | act_dim = env.action_space.n # discrete action 18 | logger.log( 19 | env.observation_space["image"], 20 | f"obs_dim={obs_dim} act_dim={act_dim} max_steps={env.max_steps}", 21 | ) 22 | 23 | ## Initialize agent and buffer 24 | agent = Agent(env, args) 25 | 26 | memory = r2d2_ReplayMemory(args["replay_size"], obs_dim, act_dim, args) 27 | 28 | ## Training 29 | seed = args["seed"] 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | env.reset(seed=seed) 33 | test_env.reset(seed=seed + 1) 34 | memory.reset(seed) 35 | 36 | total_numsteps = 0 37 | while total_numsteps <= args["num_steps"]: 38 | hidden_p = agent.get_initial_hidden() 39 | action = -1 # placeholder 40 | reward = 0 41 | state = env.reset()[0]["image"].astype(np.float32).reshape(-1) 42 | 43 | ep_hiddens = [hidden_p] # z[-1] 44 | ep_actions = [action] # a[-1] 45 | ep_rewards = [reward] # r[-1] 46 | ep_states = [state] # o[0] 47 | 48 | while True: 49 | if total_numsteps % args["logging_freq"] == 0: 50 | if total_numsteps > 0: # except the first evaluation 51 | FPS = running_metrics["length"] / (time.time() - time_now) 52 | # average the metrics 53 | running_metrics = { 54 | k: v / k_episode for k, v in running_metrics.items() 55 | } 56 | running_losses = { 57 | k: v / k_updates for k, v in running_losses.items() 58 | } 59 | log_and_test( 60 | test_env, 61 | agent, 62 | total_numsteps, 63 | running_metrics if total_numsteps > 0 else None, 64 | running_losses if total_numsteps > 0 else None, 65 | FPS if total_numsteps > 0 else None, 66 | ) 67 | ## running metrics 68 | k_episode = 0 # num of env episodes 69 | k_updates = 0 # num of agent updates 70 | running_metrics = { 71 | k: 0.0 72 | for k in [ 73 | "return", 74 | "length", 75 | "success", 76 | ] 77 | } 78 | running_losses = {} 79 | time_now = time.time() 80 | 81 | if total_numsteps < args["random_actions_until"]: # never used 82 | action = env.action_space.sample() 83 | else: 84 | action, hidden_p = agent.select_action( 85 | state, 86 | action, 87 | reward, 88 | hidden_p, 89 | EPS_up=True, 90 | evaluate=False, 91 | ) 92 | 93 | next_state, reward, terminated, truncated, _ = env.step(action) # Step 94 | state = next_state["image"].astype(np.float32).reshape(-1) 95 | 96 | ep_hiddens.append(hidden_p) # z[t] 97 | ep_actions.append(action) # a[t] 98 | ep_rewards.append(reward) # r[t] 99 | ep_states.append(state) # o[t+1] 100 | 101 | running_metrics["return"] += reward 102 | running_metrics["length"] += 1 103 | 104 | if ( 105 | len(memory) > args["batch_size"] 106 | and total_numsteps % args["rl_update_every_n_steps"] == 0 107 | ): 108 | losses = agent.update_parameters( 109 | memory, args["batch_size"], args["rl_updates_per_step"] 110 | ) 111 | k_updates += 1 112 | if running_losses == {}: 113 | running_losses = losses 114 | else: 115 | running_losses = { 116 | k: running_losses[k] + v for k, v in losses.items() 117 | } 118 | 119 | total_numsteps += 1 120 | 
121 | if terminated or truncated: 122 | break 123 | 124 | # Append transition to memory 125 | memory.push(ep_states, ep_actions, ep_rewards, ep_hiddens) 126 | 127 | k_episode += 1 128 | running_metrics["success"] += int(reward > 0.0) # terminal reward 129 | 130 | 131 | def log_and_test( 132 | env, 133 | agent, 134 | total_numsteps, 135 | running_metrics, 136 | running_losses, 137 | FPS, 138 | ): 139 | logger.record_step("env_steps", total_numsteps) 140 | if total_numsteps > 0: 141 | for k, v in running_metrics.items(): 142 | logger.record_tabular("train/" + k, v) 143 | for k, v in running_losses.items(): 144 | logger.record_tabular(k, v) 145 | logger.record_tabular("FPS", FPS) 146 | 147 | metrics = { 148 | k: 0.0 149 | for k in [ 150 | "return", 151 | "length", 152 | "success", 153 | ] 154 | } 155 | episodes = 10 156 | for _ in range(episodes): 157 | hidden_p = agent.get_initial_hidden() 158 | action = -1 # placeholder 159 | reward = 0 160 | state = env.reset()[0]["image"].astype(np.float32).reshape(-1) 161 | 162 | while True: 163 | action, hidden_p = agent.select_action( 164 | state, action, reward, hidden_p, EPS_up=False, evaluate=True 165 | ) 166 | next_state, reward, terminated, truncated, _ = env.step(action) 167 | metrics["return"] += reward 168 | metrics["length"] += 1 169 | 170 | state = next_state["image"].astype(np.float32).reshape(-1) 171 | 172 | if terminated or truncated: 173 | break 174 | 175 | metrics["success"] += int(reward > 0.0) 176 | 177 | metrics = {k: metrics[k] / episodes for k in metrics.keys()} 178 | for k, v in metrics.items(): 179 | logger.record_tabular(k, v) 180 | logger.dump_tabular() 181 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Bridging State and History Representations: Understanding Self-Predictive RL 2 | This is the official code for the paper 3 | 4 | ["Bridging State and History Representations: Understanding Self-Predictive RL"](https://arxiv.org/abs/2401.08898), **ICLR 2024** 5 | 6 | by [Tianwei Ni](https://twni2016.github.io/), [Benjamin Eysenbach](https://ben-eysenbach.github.io/), [Erfan Seyedsalehi](https://openreview.net/profile?id=~Erfan_Seyedsalehi2), [Michel Ma](https://scholar.google.com/citations?user=capMFX8AAAAJ&hl=en), [Clement Gehring](https://people.csail.mit.edu/gehring/), [Aditya Mahajan](https://cim.mcgill.ca/~adityam/), and [Pierre-Luc Bacon](http://pierrelucbacon.com/). 7 | 8 | ## TLDR: A Minimal Augmentation for Model-Free RL Loss 🚀 9 | 10 | In this work, we demonstrate a *principled, minimal, and effective* design, as reflected in the following pseudocode: 11 | 12 | ```python 13 | def total_loss(hist, act, next_obs, rew): 14 | """ 15 | Compute the total loss for learning one of the three abstractions. 16 | 17 | Args: Batch of transition data (h, a, o', r). 
18 | hist h: (B, T, O+A), act a: (B, A), next_obs o': (B, O), rew r: (B, 1) 19 | """ 20 | 21 | # Encode current history into a latent state 22 | h_enc = Encoder(hist) # z: (B, Z) 23 | next_hist = torch.cat([hist, torch.cat([act, next_obs], dim=-1)], dim=1) # h': (B, T+1, O+A) 24 | # Encode next history into a latent state using an EMA encoder 25 | next_h_enc_tar = Encoder_Target(next_hist) # z': (B, Z) 26 | 27 | # Model-free RL loss in the latent state space (e.g., TD3, R2D2) 28 | rl_loss = RL_loss(h_enc, act, next_h_enc_tar, rew) # (z, a, z', r) 29 | 30 | if [learning Q^*-irrelevance representations]: # model-free RL 31 | return rl_loss 32 | elif [learning self-predictive representations]: # l2 loss with EMA ZP target 33 | zp_loss = ((Latent_Model(h_enc, act) - next_h_enc_tar)**2).sum(-1).mean() 34 | return rl_loss + coef * zp_loss 35 | elif [learning observation-predictive representations]: # l2 loss 36 | op_loss = ((Observ_Model(h_enc, act) - next_obs)**2).sum(-1).mean() 37 | return rl_loss + coef * op_loss 38 | ``` 39 | 40 | ## Background 🔍 41 | 42 | In deep RL, numerous representation learning methods have been proposed, ranging from *state representations* for MDPs to *history representations* for POMDPs. However, these methods often involve different learning objectives and training techniques, making it challenging for RL practitioners to select the most suitable approach for their specific problems. 43 | 44 | This work unifies various representation learning methods by analyzing their objectives and ideal abstractions. Surprisingly, these methods are connected by a **self-predictive** condition, termed the **ZP condition**: *the latent state generated by the encoder can be used to predict the next latent state*. We summarize three abstractions learned by these methods and provide examples of popular instances: 45 | 46 | 1. **$Q^*$-irrelevance abstraction**: purely maximizes returns. Examples: [model-free RL (cleanrl)](https://github.com/vwxyzjn/cleanrl), [recurrent model-free RL](https://github.com/twni2016/pomdp-baselines). 47 | 2. **Self-predictive abstraction**: involves the self-predictive (ZP) and reward-prediction (RP) conditions. Examples: [SPR](https://github.com/mila-iqia/spr), [DBC](https://github.com/facebookresearch/deep_bisim4control), [TD-MPC](https://github.com/nicklashansen/tdmpc), [EfficientZero](https://github.com/YeWR/EfficientZero). 48 | 3. **Observation-predictive abstraction**: involves the observation-predictive (OP) and reward-prediction (RP) conditions. Examples: [Dreamer](https://github.com/danijar/dreamerv3), [SLAC](https://github.com/alexlee-gk/slac), [SAC-AE](https://github.com/denisyarats/pytorch_sac_ae). 49 | 50 | ## Using Our Minimalist Algorithm as Your Baseline 🔧 51 | 52 | In our paper, we establish how the ZP condition connects the three abstractions. Crucially, we investigate the training objectives for learning ZP, including widely-used $\ell_2$, $\cos$, and KL divergences, along with the *stop-gradient* operator to prevent representational collapse. 53 | 54 | These analyses lead to the development of **our minimalist algorithm** for learning self-predictive abstraction. We provide the code as **a baseline** for future research, believing it to be: 55 | - **Principled in representation learning**: targets each of the three abstractions. 56 | - **Minimal in algorithmic design**: uses single auxiliary task for representation learning (just one extra loss), and model-free policy optimization (no planning). 
57 | - **Effective in practice**: our implementation of self-predictive representations outperforms $Q^*$-irrelevance abstraction (the model-free baseline), and is more robust to distractions than observation-predictive representations. 58 | 59 | ## Code Implementation 🗂️ 60 | 61 | - [`mujoco_code/`](https://github.com/twni2016/self-predictive-rl/tree/main/mujoco_code): contains the code on standard MDPs (Section 5.1) and distracting MDPs (Section 5.2) using [MuJoCo](https://gymnasium.farama.org/environments/mujoco/) simulators. 62 | - [`minigrid_code/`](https://github.com/twni2016/self-predictive-rl/tree/main/minigrid_code): contains the code on sparse-reward POMDPs (Section 5.3) using [MiniGrid](https://minigrid.farama.org/index.html) environments. 63 | - [`linear_code/`](https://github.com/twni2016/self-predictive-rl/tree/main/linear_code): contains the code for illustrating our theorem on stop-gradient to prevent collapse (Section 4.2). 64 | 65 | ## Our Recommendations for Practitioners 📋 66 | 67 | Here we restate our preliminary recommendations from our paper (Section 6): 68 | 69 | - Analyze your task first. For example, in noisy or distracting tasks, consider using self-predictive representations. In sparse-reward tasks, consider using observation-predictive representations. In deterministic tasks, choose the deterministic $\ell_2$ objectives for representation learning. 70 | - Use our minimalist algorithm as your baseline. Our algorithm allows for an independent evaluation of representation learning and policy optimization effects. Start with end-to-end learning and model-free RL for policy optimization. 71 | - Implementation tips. For our minimalist algorithm, we recommend adopting the $\ell_2$ objective with EMA ZP targets first. When tackling POMDPs, start with recurrent networks as the encoder. 72 | 73 | ## Questions❓ 74 | 75 | If you have any questions, please raise an issue (preferred) or send an email to Tianwei (tianwei.ni@mila.quebec). 
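The pseudocode and implementation tips above rely on an EMA (target) encoder to produce the ZP target `z'`. For completeness, a minimal PyTorch sketch of the Polyak/EMA step is given below; it mirrors `soft_update` in `minigrid_code/models.py` and `target_update` (with `tau=0.005`) in `linear_code/replearn/learn.py`, and the surrounding usage comments are placeholders rather than the released training loop.

```python
import copy

import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
    """Polyak/EMA update: target <- tau * online + (1 - tau) * target."""
    for tgt_param, src_param in zip(target.parameters(), online.parameters()):
        tgt_param.data.mul_(1.0 - tau).add_(tau * src_param.data)

# Usage sketch (placeholder names):
# encoder_target = copy.deepcopy(encoder)           # initialize the EMA encoder as a frozen copy
# ...
# loss.backward(); optimizer.step()                 # gradient step on the online encoder
# ema_update(encoder_target, encoder, tau=0.005)    # then let the EMA encoder slowly track it
```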
76 | 77 | ## Citation 78 | 79 | ```bibtex 80 | @inproceedings{ni2024bridging, 81 | title={Bridging State and History Representations: Understanding Self-Predictive RL}, 82 | author={Ni, Tianwei and Eysenbach, Benjamin and Seyedsalehi, Erfan and Ma, Michel and Gehring, Clement and Mahajan, Aditya and Bacon, Pierre-Luc}, 83 | booktitle={The Twelfth International Conference on Learning Representations}, 84 | year={2024} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /mujoco_code/utils/env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.envs.mujoco import mujoco_env 3 | from gym import utils 4 | import matplotlib.pyplot as plt 5 | from matplotlib import animation 6 | import numpy as np 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore", category=DeprecationWarning) 10 | 11 | 12 | def linear_schedule(start_sigma: float, end_sigma: float, duration: int, t: int): 13 | return end_sigma + (1 - min(t / duration, 1)) * (start_sigma - end_sigma) 14 | 15 | 16 | # saving frames 17 | 18 | 19 | def save_frames_as_gif(frames, path="./", filename="gym_animation.gif"): 20 | # Mess with this to change frame size 21 | plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72) 22 | 23 | patch = plt.imshow(frames[0]) 24 | plt.axis("off") 25 | 26 | def animate(i): 27 | patch.set_data(frames[i]) 28 | 29 | anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50) 30 | anim.save(path + filename, writer="imagemagick", fps=60) 31 | 32 | 33 | # MBPO environments 34 | 35 | MBPO_ENVIRONMENT_SPECS = ( 36 | { 37 | "id": "T-Ant-v2", 38 | "entry_point": (f"utils.env:AntTruncatedObsEnv"), 39 | "max_episode_steps": 1000, 40 | }, 41 | { 42 | "id": "T-Humanoid-v2", 43 | "entry_point": (f"utils.env:HumanoidTruncatedObsEnv"), 44 | "max_episode_steps": 1000, 45 | }, 46 | ) 47 | 48 | 49 | def _register_environments(register, specs): 50 | for env in specs: 51 | register(**env) 52 | 53 | gym_ids = tuple(environment_spec["id"] for environment_spec in specs) 54 | return gym_ids 55 | 56 | 57 | def register_mbpo_environments(): 58 | _register_environments(gym.register, MBPO_ENVIRONMENT_SPECS) 59 | 60 | 61 | def mass_center(model, sim): 62 | mass = np.expand_dims(model.body_mass, 1) 63 | xpos = sim.data.xipos 64 | return (np.sum(mass * xpos, 0) / np.sum(mass))[0] 65 | 66 | 67 | class HumanoidTruncatedObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 68 | """ 69 | COM inertia (cinert), COM velocity (cvel), actuator forces (qfrc_actuator), 70 | and external forces (cfrc_ext) are removed from the observation. 
71 | Otherwise identical to Humanoid-v2 from 72 | https://github.com/openai/gym/blob/master/gym/envs/mujoco/humanoid.py 73 | """ 74 | 75 | def __init__(self): 76 | mujoco_env.MujocoEnv.__init__(self, "humanoid.xml", 5) 77 | utils.EzPickle.__init__(self) 78 | 79 | def _get_obs(self): 80 | data = self.sim.data 81 | return np.concatenate( 82 | [ 83 | data.qpos.flat[2:], 84 | data.qvel.flat, 85 | # data.cinert.flat, 86 | # data.cvel.flat, 87 | # data.qfrc_actuator.flat, 88 | # data.cfrc_ext.flat 89 | ] 90 | ) 91 | 92 | def step(self, a): 93 | pos_before = mass_center(self.model, self.sim) 94 | self.do_simulation(a, self.frame_skip) 95 | pos_after = mass_center(self.model, self.sim) 96 | alive_bonus = 5.0 97 | data = self.sim.data 98 | lin_vel_cost = 0.25 * (pos_after - pos_before) / self.model.opt.timestep 99 | quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum() 100 | quad_impact_cost = 0.5e-6 * np.square(data.cfrc_ext).sum() 101 | quad_impact_cost = min(quad_impact_cost, 10) 102 | reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus 103 | qpos = self.sim.data.qpos 104 | done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) 105 | return ( 106 | self._get_obs(), 107 | reward, 108 | done, 109 | dict( 110 | reward_linvel=lin_vel_cost, 111 | reward_quadctrl=-quad_ctrl_cost, 112 | reward_alive=alive_bonus, 113 | reward_impact=-quad_impact_cost, 114 | ), 115 | ) 116 | 117 | def reset_model(self): 118 | c = 0.01 119 | self.set_state( 120 | self.init_qpos + self.np_random.uniform(low=-c, high=c, size=self.model.nq), 121 | self.init_qvel 122 | + self.np_random.uniform( 123 | low=-c, 124 | high=c, 125 | size=self.model.nv, 126 | ), 127 | ) 128 | return self._get_obs() 129 | 130 | def viewer_setup(self): 131 | self.viewer.cam.trackbodyid = 1 132 | self.viewer.cam.distance = self.model.stat.extent * 1.0 133 | self.viewer.cam.lookat[2] = 2.0 134 | self.viewer.cam.elevation = -20 135 | 136 | 137 | class AntTruncatedObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 138 | """ 139 | External forces (sim.data.cfrc_ext) are removed from the observation. 
140 | Otherwise identical to Ant-v2 from 141 | https://github.com/openai/gym/blob/master/gym/envs/mujoco/ant.py 142 | """ 143 | 144 | def __init__(self): 145 | mujoco_env.MujocoEnv.__init__(self, "ant.xml", 5) 146 | utils.EzPickle.__init__(self) 147 | 148 | def step(self, a): 149 | xposbefore = self.get_body_com("torso")[0] 150 | self.do_simulation(a, self.frame_skip) 151 | xposafter = self.get_body_com("torso")[0] 152 | forward_reward = (xposafter - xposbefore) / self.dt 153 | ctrl_cost = 0.5 * np.square(a).sum() 154 | contact_cost = ( 155 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 156 | ) 157 | survive_reward = 1.0 158 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 159 | state = self.state_vector() 160 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 161 | done = not notdone 162 | ob = self._get_obs() 163 | return ( 164 | ob, 165 | reward, 166 | done, 167 | dict( 168 | reward_forward=forward_reward, 169 | reward_ctrl=-ctrl_cost, 170 | reward_contact=-contact_cost, 171 | reward_survive=survive_reward, 172 | ), 173 | ) 174 | 175 | def _get_obs(self): 176 | return np.concatenate( 177 | [ 178 | self.sim.data.qpos.flat[2:], 179 | self.sim.data.qvel.flat, 180 | # np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 181 | ] 182 | ) 183 | 184 | def reset_model(self): 185 | qpos = self.init_qpos + self.np_random.uniform( 186 | size=self.model.nq, low=-0.1, high=0.1 187 | ) 188 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 189 | self.set_state(qpos, qvel) 190 | return self._get_obs() 191 | 192 | def viewer_setup(self): 193 | self.viewer.cam.distance = self.model.stat.extent * 0.5 194 | 195 | 196 | class GymActionRepeatWrapper(gym.Wrapper): 197 | def __init__(self, env, num_repeats): 198 | assert "-v4" in env.unwrapped.spec.id 199 | super().__init__(env) 200 | self._env = env 201 | self._num_repeats = num_repeats 202 | 203 | def step(self, action): 204 | reward = 0.0 205 | notdone = True 206 | for i in range(self._num_repeats): 207 | state, rew, done, info = self._env.step(action) 208 | notdone = not done 209 | reward += (rew) * (notdone) 210 | notdone *= notdone 211 | if done: 212 | break 213 | return state, reward, done, info 214 | -------------------------------------------------------------------------------- /mujoco_code/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as td 5 | import numpy as np 6 | import utils 7 | from utils import logger 8 | 9 | 10 | class DetEncoder(nn.Module): 11 | def __init__(self, input_shape, hidden_dims, latent_dims): 12 | super().__init__() 13 | self.latent_dims = latent_dims 14 | self.encoder = nn.Sequential( 15 | nn.Linear(input_shape, hidden_dims), 16 | nn.ELU(), 17 | nn.Linear(hidden_dims, hidden_dims), 18 | nn.ELU(), 19 | nn.Linear(hidden_dims, latent_dims), 20 | ) 21 | 22 | self.apply(utils.weight_init) 23 | logger.log(self) 24 | 25 | def forward(self, x): 26 | loc = self.encoder(x) 27 | return utils.Dirac(loc) 28 | 29 | 30 | class StoEncoder(nn.Module): 31 | def __init__(self, input_shape, hidden_dims, latent_dims): 32 | super().__init__() 33 | self.latent_dims = latent_dims 34 | self.encoder = nn.Sequential( 35 | nn.Linear(input_shape, hidden_dims), 36 | nn.ELU(), 37 | nn.Linear(hidden_dims, hidden_dims), 38 | nn.ELU(), 39 | nn.Linear(hidden_dims, 2 * latent_dims), 40 | ) 41 | 42 | self.std_min = 0.1 43 | 
self.std_max = 10.0 44 | self.apply(utils.weight_init) 45 | logger.log(self) 46 | 47 | def forward(self, x): 48 | x = self.encoder(x) 49 | mean, std = torch.chunk(x, 2, -1) 50 | mean = 30 * torch.tanh(mean / 30) # [-30, 30] 51 | std = self.std_max - F.softplus(self.std_max - std) 52 | std = self.std_min + F.softplus(std - self.std_min) # (-std_min, std_max) 53 | return td.independent.Independent(td.Normal(mean, std), 1) 54 | 55 | 56 | class DetModel(nn.Module): 57 | def __init__( 58 | self, latent_dims, action_dims, hidden_dims, num_layers=2, obs_dims=None 59 | ): 60 | super().__init__() 61 | self.latent_dims = latent_dims 62 | self.action_dims = action_dims 63 | self.hidden_dims = hidden_dims 64 | self.num_layers = num_layers 65 | self.obs_dims = obs_dims 66 | 67 | self.model = self._build_model() 68 | self.apply(utils.weight_init) 69 | logger.log(self) 70 | 71 | def _build_model(self): 72 | model = [nn.Linear(self.action_dims + self.latent_dims, self.hidden_dims)] 73 | model += [nn.ELU()] 74 | for i in range(self.num_layers - 1): 75 | model += [nn.Linear(self.hidden_dims, self.hidden_dims)] 76 | model += [nn.ELU()] 77 | model += [ 78 | nn.Linear( 79 | self.hidden_dims, 80 | self.latent_dims if self.obs_dims is None else self.obs_dims, 81 | ) 82 | ] 83 | return nn.Sequential(*model) 84 | 85 | def forward(self, z, action): 86 | x = torch.cat([z, action], axis=-1) 87 | loc = self.model(x) 88 | return utils.Dirac(loc) 89 | 90 | 91 | class StoModel(nn.Module): 92 | def __init__( 93 | self, latent_dims, action_dims, hidden_dims, num_layers=2, obs_dims=None 94 | ): 95 | super().__init__() 96 | self.latent_dims = latent_dims 97 | self.action_dims = action_dims 98 | self.hidden_dims = hidden_dims 99 | self.num_layers = num_layers 100 | self.obs_dims = obs_dims 101 | 102 | self.std_min = 0.1 103 | self.std_max = 10.0 104 | self.model = self._build_model() 105 | self.apply(utils.weight_init) 106 | logger.log(self) 107 | 108 | def _build_model(self): 109 | model = [nn.Linear(self.action_dims + self.latent_dims, self.hidden_dims)] 110 | model += [nn.ELU()] 111 | for i in range(self.num_layers - 1): 112 | model += [nn.Linear(self.hidden_dims, self.hidden_dims)] 113 | model += [nn.ELU()] 114 | model += [ 115 | nn.Linear( 116 | self.hidden_dims, 117 | 2 * self.latent_dims if self.obs_dims is None else 2 * self.obs_dims, 118 | ) 119 | ] 120 | return nn.Sequential(*model) 121 | 122 | def forward(self, z, action): 123 | x = torch.cat([z, action], axis=-1) 124 | x = self.model(x) 125 | mean, std = torch.chunk(x, 2, -1) 126 | mean = 30 * torch.tanh(mean / 30) 127 | std = self.std_max - F.softplus(self.std_max - std) 128 | std = self.std_min + F.softplus(std - self.std_min) 129 | return td.independent.Independent(td.Normal(mean, std), 1) 130 | 131 | 132 | class RewardPrior(nn.Module): 133 | def __init__(self, latent_dims, hidden_dims, action_dims): 134 | super().__init__() 135 | self.reward = nn.Sequential( 136 | nn.Linear(latent_dims + action_dims, hidden_dims), 137 | nn.LayerNorm(hidden_dims), 138 | nn.Tanh(), 139 | nn.Linear(hidden_dims, hidden_dims), 140 | nn.ELU(), 141 | nn.Linear(hidden_dims, 1), 142 | ) 143 | self.apply(utils.weight_init) 144 | logger.log(self) 145 | 146 | def forward(self, z, a): 147 | z_a = torch.cat([z, a], -1) 148 | reward = self.reward(z_a) 149 | return reward 150 | 151 | 152 | class Discriminator(nn.Module): 153 | def __init__(self, latent_dims, hidden_dims, action_dims): 154 | super().__init__() 155 | self.classifier = nn.Sequential( 156 | nn.Linear(2 * latent_dims + 
action_dims, hidden_dims), 157 | nn.LayerNorm(hidden_dims), 158 | nn.Tanh(), 159 | nn.Linear(hidden_dims, hidden_dims), 160 | nn.ELU(), 161 | nn.Linear(hidden_dims, 2), 162 | ) 163 | self.apply(utils.weight_init) 164 | logger.log(self) 165 | 166 | def forward(self, z, a, z_next): 167 | x = torch.cat([z, a, z_next], -1) 168 | logits = self.classifier(x) 169 | return logits 170 | 171 | def get_reward(self, z, a, z_next): 172 | x = torch.cat([z, a, z_next], -1) 173 | logits = self.classifier(x) 174 | reward = torch.sub(logits[..., 1], logits[..., 0]) 175 | return reward.unsqueeze(-1) 176 | 177 | 178 | class Critic(nn.Module): 179 | def __init__(self, latent_dims, hidden_dims, action_shape): 180 | super().__init__() 181 | self.Q1 = nn.Sequential( 182 | nn.Linear(latent_dims + action_shape, hidden_dims), 183 | nn.LayerNorm(hidden_dims), 184 | nn.Tanh(), 185 | nn.Linear(hidden_dims, hidden_dims), 186 | nn.ELU(), 187 | nn.Linear(hidden_dims, 1), 188 | ) 189 | 190 | self.Q2 = nn.Sequential( 191 | nn.Linear(latent_dims + action_shape, hidden_dims), 192 | nn.LayerNorm(hidden_dims), 193 | nn.Tanh(), 194 | nn.Linear(hidden_dims, hidden_dims), 195 | nn.ELU(), 196 | nn.Linear(hidden_dims, 1), 197 | ) 198 | 199 | self.apply(utils.weight_init) 200 | logger.log(self) 201 | 202 | def forward(self, x, a): 203 | x_a = torch.cat([x, a], -1) 204 | q1 = self.Q1(x_a) 205 | q2 = self.Q2(x_a) 206 | return q1, q2 207 | 208 | 209 | class Actor(nn.Module): 210 | def __init__(self, input_shape, hidden_dims, output_shape, low, high): 211 | super().__init__() 212 | self.low = low 213 | self.high = high 214 | self.fc1 = nn.Linear(input_shape, hidden_dims) 215 | self.fc2 = nn.Linear(hidden_dims, hidden_dims) 216 | self.mean = nn.Linear(hidden_dims, output_shape) 217 | self.apply(utils.weight_init) 218 | logger.log(self) 219 | 220 | def forward(self, x, std): 221 | x = F.elu(self.fc1(x)) 222 | x = F.elu(self.fc2(x)) 223 | mean = torch.tanh(self.mean(x)) 224 | std = torch.ones_like(mean) * std 225 | dist = utils.TruncatedNormal(mean, std, self.low, self.high) 226 | return dist 227 | 228 | 229 | class StoActor(nn.Module): 230 | def __init__(self, input_shape, hidden_dims, output_shape, low, high): 231 | super().__init__() 232 | self.low = low 233 | self.high = high 234 | self.fc1 = nn.Linear(input_shape, hidden_dims) 235 | self.fc2 = nn.Linear(hidden_dims, hidden_dims) 236 | self.fc3 = nn.Linear(hidden_dims, 2 * output_shape) 237 | self.std_min = np.exp(-5) 238 | self.std_max = np.exp(2) 239 | self.apply(utils.weight_init) 240 | logger.log(self) 241 | 242 | def forward(self, x): 243 | x = F.elu(self.fc1(x)) 244 | x = F.elu(self.fc2(x)) 245 | x = self.fc3(x) 246 | mean, std = torch.chunk(x, 2, -1) 247 | mean = torch.tanh(mean) 248 | std = self.std_max - F.softplus(self.std_max - std) 249 | std = self.std_min + F.softplus(std - self.std_min) 250 | dist = utils.TruncatedNormal(mean, std, self.low, self.high) 251 | return dist 252 | -------------------------------------------------------------------------------- /mujoco_code/workspaces/mujoco_workspace.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import time 4 | import numpy as np 5 | 6 | from pathlib import Path 7 | from utils.env import save_frames_as_gif 8 | from utils import logger 9 | from workspaces.common import make_agent, make_env 10 | 11 | 12 | class MujocoWorkspace: 13 | def __init__(self, cfg): 14 | self.work_dir = Path.cwd() 15 | self.cfg = cfg 16 | if self.cfg.save_snapshot: 17 | 
self.checkpoint_path = self.work_dir / "checkpoints" 18 | self.checkpoint_path.mkdir(exist_ok=True) 19 | self.device = torch.device(cfg.device) 20 | self.set_seed() 21 | self.train_env, self.eval_env = make_env(self.cfg) 22 | self.agent = make_agent(self.train_env, self.device, self.cfg) 23 | self._train_step = 0 24 | self._train_episode = 0 25 | self._best_eval_returns = -np.inf 26 | 27 | def set_seed(self): 28 | random.seed(self.cfg.seed) 29 | np.random.seed(self.cfg.seed) 30 | torch.manual_seed(self.cfg.seed) 31 | torch.cuda.manual_seed_all(self.cfg.seed) 32 | 33 | def train(self): 34 | self._explore() 35 | self._eval() 36 | 37 | state, done, episode_start_time = self.train_env.reset(), False, time.time() 38 | 39 | for _ in range(1, self.cfg.num_train_steps - self.cfg.explore_steps + 1): 40 | action = self.agent.get_action(state, self._train_step) 41 | next_state, reward, done, info = self.train_env.step(action) 42 | self._train_step += 1 43 | 44 | self.agent.env_buffer.push( 45 | ( 46 | state, 47 | action, 48 | reward, 49 | next_state, 50 | False if info.get("TimeLimit.truncated", False) else done, 51 | ) 52 | ) 53 | 54 | self.agent.update(self._train_step) 55 | 56 | if (self._train_step) % self.cfg.eval_episode_interval == 0: 57 | self._eval() 58 | 59 | if ( 60 | self.cfg.save_snapshot 61 | and (self._train_step) % self.cfg.save_snapshot_interval == 0 62 | ): 63 | self.save_snapshot() 64 | 65 | if done: 66 | self._train_episode += 1 67 | print( 68 | "TRAIN Episode: {}, total numsteps: {}, return: {}".format( 69 | self._train_episode, 70 | self._train_step, 71 | round(info["episode"]["r"], 2), 72 | ) 73 | ) 74 | episode_metrics = dict() 75 | episode_metrics["train/length"] = info["episode"]["l"] 76 | episode_metrics["train/return"] = info["episode"]["r"] 77 | episode_metrics["FPS"] = info["episode"]["l"] / ( 78 | time.time() - episode_start_time 79 | ) 80 | # episode_metrics["env_buffer_length"] = len(self.agent.env_buffer) 81 | logger.record_step("env_steps", self._train_step) 82 | for k, v in episode_metrics.items(): 83 | logger.record_tabular(k, v) 84 | logger.dump_tabular() 85 | 86 | state, done, episode_start_time = ( 87 | self.train_env.reset(), 88 | False, 89 | time.time(), 90 | ) 91 | else: 92 | state = next_state 93 | 94 | self.train_env.close() 95 | 96 | def _explore(self): 97 | state, done = self.train_env.reset(), False 98 | 99 | for _ in range(1, self.cfg.explore_steps): 100 | action = self.train_env.action_space.sample() 101 | next_state, reward, done, info = self.train_env.step(action) 102 | self.agent.env_buffer.push( 103 | ( 104 | state, 105 | action, 106 | reward, 107 | next_state, 108 | False if info.get("TimeLimit.truncated", False) else done, 109 | ) 110 | ) 111 | 112 | if done: 113 | state, done = self.train_env.reset(), False 114 | else: 115 | state = next_state 116 | 117 | def _eval(self): 118 | returns = 0 119 | steps = 0 120 | for _ in range(self.cfg.num_eval_episodes): 121 | done = False 122 | state = self.eval_env.reset() 123 | while not done: 124 | action = self.agent.get_action(state, self._train_step, eval=True) 125 | next_state, _, done, info = self.eval_env.step(action) 126 | state = next_state 127 | 128 | returns += info["episode"]["r"] 129 | steps += info["episode"]["l"] 130 | 131 | print( 132 | "EVAL Episode: {}, total numsteps: {}, return: {}".format( 133 | self._train_episode, 134 | self._train_step, 135 | round(info["episode"]["r"], 2), 136 | ) 137 | ) 138 | 139 | eval_metrics = dict() 140 | eval_metrics["return"] = returns / 
self.cfg.num_eval_episodes 141 | eval_metrics["length"] = steps / self.cfg.num_eval_episodes 142 | 143 | if ( 144 | self.cfg.save_snapshot 145 | and returns / self.cfg.num_eval_episodes >= self._best_eval_returns 146 | ): 147 | self.save_snapshot(best=True) 148 | self._best_eval_returns = returns / self.cfg.num_eval_episodes 149 | 150 | logger.record_step("env_steps", self._train_step) 151 | for k, v in eval_metrics.items(): 152 | logger.record_tabular(k, v) 153 | logger.dump_tabular() 154 | 155 | def _render_episodes(self, record): 156 | frames = [] 157 | done = False 158 | state = self.eval_env.reset() 159 | while not done: 160 | action = self.agent.get_action(state, self._train_step, True) 161 | next_state, _, done, info = self.eval_env.step(action) 162 | self.eval_env.render() 163 | state = next_state 164 | if record: 165 | save_frames_as_gif(frames) 166 | print( 167 | "Episode: {}, episode steps: {}, episode returns: {}".format( 168 | i, info["episode"]["l"], round(info["episode"]["r"], 2) 169 | ) 170 | ) 171 | 172 | def _eval_bias(self): 173 | final_mc_list, final_obs_list, final_act_list = self._mc_returns() 174 | final_mc_norm_list = np.abs(final_mc_list.copy()) 175 | final_mc_norm_list[final_mc_norm_list < 10] = 10 176 | 177 | obs_tensor = torch.FloatTensor(final_obs_list).to(self.device) 178 | acts_tensor = torch.FloatTensor(final_act_list).to(self.device) 179 | lower_bound = self.agent.get_lower_bound(obs_tensor, acts_tensor) 180 | 181 | bias = final_mc_list - lower_bound 182 | normalized_bias_per_state = bias / final_mc_norm_list 183 | 184 | # metrics = dict() 185 | # metrics["mean_bias"] = np.mean(bias) 186 | # metrics["std_bias"] = np.std(bias) 187 | # metrics["mean_normalised_bias"] = np.mean(normalized_bias_per_state) 188 | # metrics["std_normalised_bias"] = np.std(normalized_bias_per_state) 189 | 190 | def _mc_returns(self): 191 | final_mc_list = np.zeros(0) 192 | final_obs_list = [] 193 | final_act_list = [] 194 | n_mc_eval = 1000 195 | n_mc_cutoff = 350 196 | 197 | while final_mc_list.shape[0] < n_mc_eval: 198 | o = self.eval_env.reset() 199 | reward_list, obs_list, act_list = [], [], [] 200 | r, d, ep_ret, ep_len = 0, False, 0, 0 201 | 202 | while not d: 203 | a = self.agent.get_action(o, self._train_step, True) 204 | obs_list.append(o) 205 | act_list.append(a) 206 | o, r, d, _ = self.eval_env.step(a) 207 | ep_ret += r 208 | ep_len += 1 209 | reward_list.append(r) 210 | 211 | discounted_return_list = np.zeros(ep_len) 212 | for i_step in range(ep_len - 1, -1, -1): 213 | if i_step == ep_len - 1: 214 | discounted_return_list[i_step] = reward_list[i_step] 215 | else: 216 | discounted_return_list[i_step] = ( 217 | reward_list[i_step] 218 | + self.cfg.gamma * discounted_return_list[i_step + 1] 219 | ) 220 | 221 | final_mc_list = np.concatenate( 222 | (final_mc_list, discounted_return_list[:n_mc_cutoff]) 223 | ) 224 | final_obs_list += obs_list[:n_mc_cutoff] 225 | final_act_list += act_list[:n_mc_cutoff] 226 | 227 | return final_mc_list, np.array(final_obs_list), np.array(final_act_list) 228 | 229 | def save_snapshot(self, best=False): 230 | if best: 231 | snapshot = Path(self.checkpoint_path) / "best.pt" 232 | else: 233 | snapshot = Path(self.checkpoint_path) / Path(str(self._train_step) + ".pt") 234 | save_dict = self.agent.get_save_dict() 235 | torch.save(save_dict, snapshot) 236 | -------------------------------------------------------------------------------- /mujoco_code/utils/logger.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Usage: 3 | 4 | from utils import logger 5 | format_strs=[ 6 | "stdout", # print 7 | "csv", # progress.csv 8 | "tensorboard", 9 | "log", # experiment.log 10 | ] 11 | logger.configure(dir=log_path, format_strs=format_strs) 12 | 13 | # during training 14 | logger.record_step("env_steps", i) 15 | for k, v in metrics.items(): 16 | logger.record_tabular(f"training/{k}", v) 17 | logger.dump_tabular() 18 | 19 | it will record all the data to format_strs you specified, 20 | with tensorboard add scalar("training/{k}", v, i) 21 | 22 | """ 23 | import os 24 | import sys 25 | import os.path as osp 26 | import time 27 | import datetime 28 | import dateutil.tz 29 | from collections import OrderedDict 30 | 31 | try: # py3.10 32 | from collections.abc import Set 33 | except ImportError: # py3.7 34 | from collections import Set 35 | 36 | import numpy as np 37 | 38 | 39 | LOG_OUTPUT_FORMATS = ["stdout", "log", "csv"] 40 | 41 | DEBUG = 10 42 | INFO = 20 43 | WARN = 30 44 | ERROR = 40 45 | 46 | DISABLED = 50 47 | 48 | 49 | class OrderedSet(Set): 50 | # https://stackoverflow.com/a/10006674/9072850 51 | def __init__(self, iterable=()): 52 | self.d = OrderedDict.fromkeys(iterable) 53 | 54 | def __len__(self): 55 | return len(self.d) 56 | 57 | def __contains__(self, element): 58 | return element in self.d 59 | 60 | def __iter__(self): 61 | return iter(self.d) 62 | 63 | 64 | class KVWriter(object): 65 | def writekvs(self, kvs): 66 | raise NotImplementedError 67 | 68 | 69 | class SeqWriter(object): 70 | def writeseq(self, seq): 71 | raise NotImplementedError 72 | 73 | 74 | def put_in_middle(str1, str2): 75 | # Put str1 in str2 76 | n = len(str1) 77 | m = len(str2) 78 | if n <= m: 79 | return str2 80 | else: 81 | start = (n - m) // 2 82 | return str1[:start] + str2 + str1[start + m :] 83 | 84 | 85 | class HumanOutputFormat(KVWriter, SeqWriter): 86 | def __init__(self, filename_or_file): 87 | if isinstance(filename_or_file, str): 88 | self.file = open(filename_or_file, "wt") 89 | self.own_file = True 90 | else: 91 | assert hasattr(filename_or_file, "read"), ( 92 | "expected file or str, got %s" % filename_or_file 93 | ) 94 | self.file = filename_or_file 95 | self.own_file = False 96 | 97 | def writekvs(self, kvs): 98 | # Create strings for printing 99 | key2str = {} 100 | for key, val in sorted(kvs.items()): 101 | if isinstance(val, float): 102 | valstr = "%-8.3g" % (val,) 103 | else: 104 | valstr = str(val) 105 | key2str[self._truncate(key)] = self._truncate(valstr) 106 | 107 | # Find max widths 108 | if len(key2str) == 0: 109 | print("WARNING: tried to write empty key-value dict") 110 | return 111 | else: 112 | keywidth = max(map(len, key2str.keys())) 113 | valwidth = max(map(len, key2str.values())) 114 | 115 | # Write out the data 116 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 117 | timestamp = now.strftime("%Y-%m-%d %H:%M:%S.%f %Z") 118 | 119 | dashes = "-" * (keywidth + valwidth + 7) 120 | dashes_time = put_in_middle(dashes, timestamp) 121 | lines = [dashes_time] 122 | for key, val in sorted(key2str.items()): 123 | lines.append( 124 | "| %s%s | %s%s |" 125 | % ( 126 | key, 127 | " " * (keywidth - len(key)), 128 | val, 129 | " " * (valwidth - len(val)), 130 | ) 131 | ) 132 | lines.append(dashes) 133 | self.file.write("\n".join(lines) + "\n") 134 | 135 | # Flush the output to the file 136 | self.file.flush() 137 | 138 | def _truncate(self, s): 139 | return s[:30] + "..." 
if len(s) > 33 else s 140 | 141 | def writeseq(self, seq): 142 | for arg in seq: 143 | self.file.write(arg + " ") 144 | self.file.write("\n") 145 | self.file.flush() 146 | 147 | def close(self): 148 | if self.own_file: 149 | self.file.close() 150 | 151 | 152 | class CSVOutputFormat(KVWriter): 153 | def __init__(self, filename): 154 | self.file = open(filename, "w+t") 155 | self.keys = [] 156 | self.sep = "," 157 | 158 | def writekvs(self, kvs): 159 | # Add our current row to the history 160 | extra_keys = list(OrderedSet(kvs.keys()) - OrderedSet(self.keys)) 161 | if extra_keys: 162 | self.keys.extend(extra_keys) 163 | self.file.seek(0) 164 | lines = self.file.readlines() 165 | self.file.seek(0) 166 | for i, k in enumerate(self.keys): 167 | if i > 0: 168 | self.file.write(",") 169 | self.file.write(k) 170 | self.file.write("\n") 171 | for line in lines[1:]: 172 | self.file.write(line[:-1]) 173 | self.file.write(self.sep * len(extra_keys)) 174 | self.file.write("\n") 175 | for i, k in enumerate(self.keys): 176 | if i > 0: 177 | self.file.write(",") 178 | v = kvs.get(k) 179 | if v is not None: 180 | self.file.write(str(v)) 181 | self.file.write("\n") 182 | self.file.flush() 183 | 184 | def close(self): 185 | self.file.close() 186 | 187 | 188 | def make_output_format(format, ev_dir, log_suffix=""): 189 | os.makedirs(ev_dir, exist_ok=True) 190 | if format == "stdout": 191 | return HumanOutputFormat(sys.stdout) 192 | elif format == "log": 193 | return HumanOutputFormat(osp.join(ev_dir, "experiment%s.log" % log_suffix)) 194 | elif format == "csv": 195 | return CSVOutputFormat(osp.join(ev_dir, "progress.csv")) 196 | else: 197 | raise ValueError("Unknown format specified: %s" % (format,)) 198 | 199 | 200 | # ================================================================ 201 | # API 202 | # ================================================================ 203 | 204 | 205 | def logkv(key, val): 206 | """ 207 | Log a value of some diagnostic 208 | Call this once for each diagnostic quantity, each iteration 209 | If called many times, last value will be used. 210 | """ 211 | Logger.CURRENT.logkv(key, val) 212 | 213 | 214 | def logkv_mean(key, val): 215 | """ 216 | The same as logkv(), but if called many times, values averaged. 217 | """ 218 | Logger.CURRENT.logkv_mean(key, val) 219 | 220 | 221 | def logkvs(d): 222 | """ 223 | Log a dictionary of key-value pairs 224 | """ 225 | for k, v in d.items(): 226 | logkv(k, v) 227 | 228 | 229 | def set_tb_step(key, step): 230 | """ 231 | record step for tensorboard 232 | """ 233 | Logger.CURRENT.set_tb_step(key, step) 234 | 235 | 236 | def add_figure(*args): 237 | """ 238 | add_figure for tensorboard 239 | """ 240 | Logger.CURRENT.add_figure(*args) 241 | 242 | 243 | def dumpkvs(): 244 | """ 245 | Write all of the diagnostics from the current iteration 246 | 247 | level: int. (see logger.py docs) If the global logger level is higher than 248 | the level argument here, don't print to stdout. 249 | """ 250 | Logger.CURRENT.dumpkvs() 251 | 252 | 253 | def getkvs(): 254 | return Logger.CURRENT.name2val 255 | 256 | 257 | def log(*args, level=INFO): 258 | """ 259 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 
260 | """ 261 | Logger.CURRENT.log(*args, level=level) 262 | 263 | 264 | def debug(*args): 265 | log(*args, level=DEBUG) 266 | 267 | 268 | def info(*args): 269 | log(*args, level=INFO) 270 | 271 | 272 | def warn(*args): 273 | log(*args, level=WARN) 274 | 275 | 276 | def error(*args): 277 | log(*args, level=ERROR) 278 | 279 | 280 | def set_level(level): 281 | """ 282 | Set logging threshold on current logger. 283 | """ 284 | Logger.CURRENT.set_level(level) 285 | 286 | 287 | def get_dir(): 288 | """ 289 | Get directory that log files are being written to. 290 | will be None if there is no output directory (i.e., if you didn't call start) 291 | """ 292 | return Logger.CURRENT.get_dir() 293 | 294 | 295 | record_tabular = logkv 296 | record_step = set_tb_step 297 | dump_tabular = dumpkvs 298 | 299 | 300 | class ProfileKV: 301 | """ 302 | Usage: 303 | with logger.ProfileKV("interesting_scope"): 304 | code 305 | """ 306 | 307 | def __init__(self, n): 308 | self.n = "wait_" + n 309 | 310 | def __enter__(self): 311 | self.t1 = time.time() 312 | 313 | def __exit__(self, type, value, traceback): 314 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1 315 | 316 | 317 | def profile(n): 318 | """ 319 | Usage: 320 | @profile("my_func") 321 | def my_func(): code 322 | """ 323 | 324 | def decorator_with_name(func): 325 | def func_wrapper(*args, **kwargs): 326 | with ProfileKV(n): 327 | return func(*args, **kwargs) 328 | 329 | return func_wrapper 330 | 331 | return decorator_with_name 332 | 333 | 334 | # ================================================================ 335 | # Backend 336 | # ================================================================ 337 | 338 | 339 | class Logger(object): 340 | DEFAULT = None # A logger with no output files. (See right below class definition) 341 | # So that you can still log to the terminal without setting up any output files 342 | CURRENT = None # Current logger being used by the free functions above 343 | 344 | def __init__(self, dir, output_formats, precision=None): 345 | self.name2val = OrderedDict() 346 | self.level = INFO 347 | self.dir = dir 348 | self.output_formats = output_formats 349 | self.precision = precision # float 350 | 351 | # Logging API, forwarded 352 | # ---------------------------------------- 353 | def logkv(self, key, val): 354 | if isinstance(val, np.ndarray): 355 | val = val.item() 356 | if self.precision is not None and isinstance(val, float): 357 | self.name2val[key] = round(val, self.precision) 358 | else: 359 | self.name2val[key] = val 360 | 361 | def add_figure(self, *args): 362 | pass 363 | 364 | def set_tb_step(self, key, step): 365 | self.logkv(key, step) # also record 366 | 367 | def dumpkvs(self): 368 | if self.level == DISABLED: 369 | return 370 | for fmt in self.output_formats: 371 | if isinstance(fmt, KVWriter): 372 | fmt.writekvs(self.name2val) 373 | self.name2val.clear() 374 | 375 | def log(self, *args, level=INFO): 376 | if self.level <= level: 377 | self._do_log(args) 378 | 379 | # Configuration 380 | # ---------------------------------------- 381 | def set_level(self, level): 382 | self.level = level 383 | 384 | def get_dir(self): 385 | return self.dir 386 | 387 | def close(self): 388 | for fmt in self.output_formats: 389 | fmt.close() 390 | 391 | # Misc 392 | # ---------------------------------------- 393 | def _do_log(self, args): 394 | for fmt in self.output_formats: 395 | if isinstance(fmt, SeqWriter): 396 | fmt.writeseq(map(str, args)) 397 | 398 | 399 | Logger.DEFAULT = Logger.CURRENT = Logger( 400 | dir=None, 
output_formats=[HumanOutputFormat(sys.stdout)] 401 | ) 402 | 403 | 404 | def configure(dir, format_strs=LOG_OUTPUT_FORMATS, log_suffix="", precision=4): 405 | assert isinstance(dir, str) 406 | os.makedirs(dir, exist_ok=True) 407 | 408 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 409 | 410 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, precision=precision) 411 | log("*" * 10, "\nLogging to %s" % dir, "\n" + "*" * 10) 412 | -------------------------------------------------------------------------------- /minigrid_code/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | 4 | from utils import logger 5 | format_strs=[ 6 | "stdout", # print 7 | "csv", # progress.csv 8 | "tensorboard", 9 | "log", # experiment.log 10 | ] 11 | logger.configure(dir=log_path, format_strs=format_strs) 12 | 13 | # during training 14 | logger.record_step("env_steps", i) 15 | for k, v in metrics.items(): 16 | logger.record_tabular(f"training/{k}", v) 17 | logger.dump_tabular() 18 | 19 | it will record all the data to format_strs you specified, 20 | with tensorboard add scalar("training/{k}", v, i) 21 | 22 | """ 23 | import os 24 | import sys 25 | import os.path as osp 26 | import time 27 | import datetime 28 | import dateutil.tz 29 | from collections import OrderedDict 30 | 31 | try: # py3.10 32 | from collections.abc import Set 33 | except ImportError: # py3.7 34 | from collections import Set 35 | 36 | import numpy as np 37 | 38 | try: 39 | from tensorboardX import SummaryWriter 40 | except: 41 | from torch.utils.tensorboard import SummaryWriter 42 | 43 | LOG_OUTPUT_FORMATS = ["stdout", "log", "csv", "tensorboard"] 44 | 45 | DEBUG = 10 46 | INFO = 20 47 | WARN = 30 48 | ERROR = 40 49 | 50 | DISABLED = 50 51 | 52 | 53 | class OrderedSet(Set): 54 | # https://stackoverflow.com/a/10006674/9072850 55 | def __init__(self, iterable=()): 56 | self.d = OrderedDict.fromkeys(iterable) 57 | 58 | def __len__(self): 59 | return len(self.d) 60 | 61 | def __contains__(self, element): 62 | return element in self.d 63 | 64 | def __iter__(self): 65 | return iter(self.d) 66 | 67 | 68 | class KVWriter(object): 69 | def writekvs(self, kvs): 70 | raise NotImplementedError 71 | 72 | 73 | class SeqWriter(object): 74 | def writeseq(self, seq): 75 | raise NotImplementedError 76 | 77 | 78 | def put_in_middle(str1, str2): 79 | # Put str1 in str2 80 | n = len(str1) 81 | m = len(str2) 82 | if n <= m: 83 | return str2 84 | else: 85 | start = (n - m) // 2 86 | return str1[:start] + str2 + str1[start + m :] 87 | 88 | 89 | class HumanOutputFormat(KVWriter, SeqWriter): 90 | def __init__(self, filename_or_file): 91 | if isinstance(filename_or_file, str): 92 | self.file = open(filename_or_file, "wt") 93 | self.own_file = True 94 | else: 95 | assert hasattr(filename_or_file, "read"), ( 96 | "expected file or str, got %s" % filename_or_file 97 | ) 98 | self.file = filename_or_file 99 | self.own_file = False 100 | 101 | def writekvs(self, kvs): 102 | # Create strings for printing 103 | key2str = {} 104 | for key, val in sorted(kvs.items()): 105 | if isinstance(val, float): 106 | valstr = "%-8.3g" % (val,) 107 | else: 108 | valstr = str(val) 109 | key2str[self._truncate(key)] = self._truncate(valstr) 110 | 111 | # Find max widths 112 | if len(key2str) == 0: 113 | print("WARNING: tried to write empty key-value dict") 114 | return 115 | else: 116 | keywidth = max(map(len, key2str.keys())) 117 | valwidth = max(map(len, key2str.values())) 118 | 119 
| # Write out the data
120 |         now = datetime.datetime.now(dateutil.tz.tzlocal())
121 |         timestamp = now.strftime("%Y-%m-%d %H:%M:%S.%f %Z")
122 | 
123 |         dashes = "-" * (keywidth + valwidth + 7)
124 |         dashes_time = put_in_middle(dashes, timestamp)
125 |         lines = [dashes_time]
126 |         for key, val in sorted(key2str.items()):
127 |             lines.append(
128 |                 "| %s%s | %s%s |"
129 |                 % (
130 |                     key,
131 |                     " " * (keywidth - len(key)),
132 |                     val,
133 |                     " " * (valwidth - len(val)),
134 |                 )
135 |             )
136 |         lines.append(dashes)
137 |         self.file.write("\n".join(lines) + "\n")
138 | 
139 |         # Flush the output to the file
140 |         self.file.flush()
141 | 
142 |     def _truncate(self, s):
143 |         return s[:30] + "..." if len(s) > 33 else s
144 | 
145 |     def writeseq(self, seq):
146 |         for arg in seq:
147 |             self.file.write(arg + " ")
148 |         self.file.write("\n")
149 |         self.file.flush()
150 | 
151 |     def close(self):
152 |         if self.own_file:
153 |             self.file.close()
154 | 
155 | 
156 | class CSVOutputFormat(KVWriter):
157 |     def __init__(self, filename):
158 |         self.file = open(filename, "w+t")
159 |         self.keys = []
160 |         self.sep = ","
161 | 
162 |     def writekvs(self, kvs):
163 |         # Add our current row to the history
164 |         extra_keys = list(OrderedSet(kvs.keys()) - OrderedSet(self.keys))
165 |         if extra_keys:
166 |             self.keys.extend(extra_keys)
167 |             self.file.seek(0)
168 |             lines = self.file.readlines()
169 |             self.file.seek(0)
170 |             for i, k in enumerate(self.keys):
171 |                 if i > 0:
172 |                     self.file.write(",")
173 |                 self.file.write(k)
174 |             self.file.write("\n")
175 |             for line in lines[1:]:
176 |                 self.file.write(line[:-1])
177 |                 self.file.write(self.sep * len(extra_keys))
178 |                 self.file.write("\n")
179 |         for i, k in enumerate(self.keys):
180 |             if i > 0:
181 |                 self.file.write(",")
182 |             v = kvs.get(k)
183 |             if v is not None:
184 |                 self.file.write(str(v))
185 |         self.file.write("\n")
186 |         self.file.flush()
187 | 
188 |     def close(self):
189 |         self.file.close()
190 | 
191 | 
192 | class TensorBoardOutputFormat(KVWriter):
193 |     """
194 |     Dumps key/value pairs into TensorBoard's numeric format.
195 |     """
196 | 
197 |     def __init__(self, dir):
198 |         os.makedirs(dir, exist_ok=True)
199 |         self.step = 0
200 |         self.writer = SummaryWriter(dir)
201 | 
202 |     def writekvs(self, kvs):
203 |         for k, v in kvs.items():
204 |             self.writer.add_scalar(k, v, self.step)
205 | 
206 |         self.writer.flush()
207 | 
208 |     def add_figure(self, tag, figure):
209 |         self.writer.add_figure(tag, figure, self.step)
210 | 
211 |     def set_step(self, step: int):
212 |         self.step = step
213 | 
214 |     def close(self):
215 |         if self.writer:
216 |             self.writer.close()
217 |             self.writer = None
218 | 
219 | 
220 | def make_output_format(format, ev_dir, log_suffix=""):
221 |     os.makedirs(ev_dir, exist_ok=True)
222 |     if format == "stdout":
223 |         return HumanOutputFormat(sys.stdout)
224 |     elif format == "log":
225 |         return HumanOutputFormat(osp.join(ev_dir, "experiment%s.log" % log_suffix))
226 |     elif format == "csv":
227 |         return CSVOutputFormat(osp.join(ev_dir, "progress.csv"))
228 |     elif format == "tensorboard":
229 |         return TensorBoardOutputFormat(ev_dir)
230 |     else:
231 |         raise ValueError("Unknown format specified: %s" % (format,))
232 | 
233 | 
234 | # ================================================================
235 | # API
236 | # ================================================================
237 | 
238 | 
239 | def logkv(key, val):
240 |     """
241 |     Log a value of some diagnostic
242 |     Call this once for each diagnostic quantity, each iteration
243 |     If called many times, last value will be used.
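    A minimal illustration with the module-level aliases defined below
    (key names and values are hypothetical):

        logger.record_step("env_steps", 1000)           # set_tb_step: records the step and sets the tensorboard global step
        logger.record_tabular("training/q_loss", 0.25)  # logkv
        logger.dump_tabular()                            # dumpkvs: writes the buffered row to every configured format, then clears it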
244 | """ 245 | Logger.CURRENT.logkv(key, val) 246 | 247 | 248 | def logkv_mean(key, val): 249 | """ 250 | The same as logkv(), but if called many times, values averaged. 251 | """ 252 | Logger.CURRENT.logkv_mean(key, val) 253 | 254 | 255 | def logkvs(d): 256 | """ 257 | Log a dictionary of key-value pairs 258 | """ 259 | for k, v in d.items(): 260 | logkv(k, v) 261 | 262 | 263 | def set_tb_step(key, step): 264 | """ 265 | record step for tensorboard 266 | """ 267 | Logger.CURRENT.set_tb_step(key, step) 268 | 269 | 270 | def add_figure(*args): 271 | """ 272 | add_figure for tensorboard 273 | """ 274 | Logger.CURRENT.add_figure(*args) 275 | 276 | 277 | def dumpkvs(): 278 | """ 279 | Write all of the diagnostics from the current iteration 280 | 281 | level: int. (see logger.py docs) If the global logger level is higher than 282 | the level argument here, don't print to stdout. 283 | """ 284 | Logger.CURRENT.dumpkvs() 285 | 286 | 287 | def getkvs(): 288 | return Logger.CURRENT.name2val 289 | 290 | 291 | def log(*args, level=INFO): 292 | """ 293 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 294 | """ 295 | Logger.CURRENT.log(*args, level=level) 296 | 297 | 298 | def debug(*args): 299 | log(*args, level=DEBUG) 300 | 301 | 302 | def info(*args): 303 | log(*args, level=INFO) 304 | 305 | 306 | def warn(*args): 307 | log(*args, level=WARN) 308 | 309 | 310 | def error(*args): 311 | log(*args, level=ERROR) 312 | 313 | 314 | def set_level(level): 315 | """ 316 | Set logging threshold on current logger. 317 | """ 318 | Logger.CURRENT.set_level(level) 319 | 320 | 321 | def get_dir(): 322 | """ 323 | Get directory that log files are being written to. 324 | will be None if there is no output directory (i.e., if you didn't call start) 325 | """ 326 | return Logger.CURRENT.get_dir() 327 | 328 | 329 | record_tabular = logkv 330 | record_step = set_tb_step 331 | dump_tabular = dumpkvs 332 | 333 | 334 | class ProfileKV: 335 | """ 336 | Usage: 337 | with logger.ProfileKV("interesting_scope"): 338 | code 339 | """ 340 | 341 | def __init__(self, n): 342 | self.n = "wait_" + n 343 | 344 | def __enter__(self): 345 | self.t1 = time.time() 346 | 347 | def __exit__(self, type, value, traceback): 348 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1 349 | 350 | 351 | def profile(n): 352 | """ 353 | Usage: 354 | @profile("my_func") 355 | def my_func(): code 356 | """ 357 | 358 | def decorator_with_name(func): 359 | def func_wrapper(*args, **kwargs): 360 | with ProfileKV(n): 361 | return func(*args, **kwargs) 362 | 363 | return func_wrapper 364 | 365 | return decorator_with_name 366 | 367 | 368 | # ================================================================ 369 | # Backend 370 | # ================================================================ 371 | 372 | 373 | class Logger(object): 374 | DEFAULT = None # A logger with no output files. 
(See right below class definition) 375 | # So that you can still log to the terminal without setting up any output files 376 | CURRENT = None # Current logger being used by the free functions above 377 | 378 | def __init__(self, dir, output_formats, precision=None): 379 | self.name2val = OrderedDict() 380 | self.level = INFO 381 | self.dir = dir 382 | self.output_formats = output_formats 383 | self.precision = precision # float 384 | 385 | # Logging API, forwarded 386 | # ---------------------------------------- 387 | def logkv(self, key, val): 388 | if isinstance(val, np.ndarray): 389 | val = val.item() 390 | if self.precision is not None and isinstance(val, float): 391 | self.name2val[key] = round(val, self.precision) 392 | else: 393 | self.name2val[key] = val 394 | 395 | def add_figure(self, *args): 396 | for fmt in self.output_formats: 397 | if isinstance(fmt, TensorBoardOutputFormat): 398 | fmt.add_figure(*args) 399 | 400 | def set_tb_step(self, key, step): 401 | self.logkv(key, step) # also record 402 | 403 | for fmt in self.output_formats: 404 | if isinstance(fmt, TensorBoardOutputFormat): 405 | fmt.set_step(step) 406 | 407 | def dumpkvs(self): 408 | if self.level == DISABLED: 409 | return 410 | for fmt in self.output_formats: 411 | if isinstance(fmt, KVWriter): 412 | fmt.writekvs(self.name2val) 413 | self.name2val.clear() 414 | 415 | def log(self, *args, level=INFO): 416 | if self.level <= level: 417 | self._do_log(args) 418 | 419 | # Configuration 420 | # ---------------------------------------- 421 | def set_level(self, level): 422 | self.level = level 423 | 424 | def get_dir(self): 425 | return self.dir 426 | 427 | def close(self): 428 | for fmt in self.output_formats: 429 | fmt.close() 430 | 431 | # Misc 432 | # ---------------------------------------- 433 | def _do_log(self, args): 434 | for fmt in self.output_formats: 435 | if isinstance(fmt, SeqWriter): 436 | fmt.writeseq(map(str, args)) 437 | 438 | 439 | Logger.DEFAULT = Logger.CURRENT = Logger( 440 | dir=None, output_formats=[HumanOutputFormat(sys.stdout)] 441 | ) 442 | 443 | 444 | def configure(dir, format_strs=LOG_OUTPUT_FORMATS, log_suffix="", precision=4): 445 | assert isinstance(dir, str) 446 | os.makedirs(dir, exist_ok=True) 447 | 448 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 449 | 450 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, precision=precision) 451 | log("*" * 10, "\nLogging to %s" % dir, "\n" + "*" * 10) 452 | -------------------------------------------------------------------------------- /minigrid_code/r2d2replaybuffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class r2d2_ReplayMemory: 8 | def __init__(self, capacity, obs_dim, act_dim, args): 9 | self.capacity = capacity 10 | self.args = args 11 | self.obs_dim = obs_dim 12 | self.act_dim = act_dim 13 | self.gamma = args["gamma"] 14 | self.burn_in_len = args["burn_in_len"] # H = 50 15 | self.learning_obs_len = args["learning_obs_len"] # L = 10 16 | self.forward_len = args["forward_len"] # N = 5 17 | self.AIS_state_size = args["AIS_state_size"] 18 | self.batch_size = args["batch_size"] 19 | 20 | self.buffer_hidden = ( 21 | np.zeros([self.capacity, self.AIS_state_size], dtype=np.float32), 22 | np.zeros([self.capacity, self.AIS_state_size], dtype=np.float32), 23 | ) 24 | self.buffer_burn_in_len = np.zeros([self.capacity], dtype=np.int32) 25 | 
self.buffer_burn_in_history = np.zeros( 26 | [self.capacity, self.burn_in_len, self.obs_dim + self.act_dim + 1], 27 | dtype=np.float32, 28 | ) 29 | 30 | self.buffer_learning_len = np.zeros([self.capacity], dtype=np.int32) 31 | self.buffer_learning_history = np.zeros( 32 | [ 33 | self.capacity, 34 | self.learning_obs_len + self.forward_len, 35 | self.obs_dim + self.act_dim + 1, 36 | ], 37 | dtype=np.float32, 38 | ) 39 | 40 | self.buffer_learn_forward_len = np.zeros([self.capacity], dtype=np.int32) 41 | self.buffer_forward_idx = np.zeros( 42 | [self.capacity, self.learning_obs_len], dtype=np.int32 43 | ) 44 | self.buffer_current_act = np.zeros( 45 | [self.capacity, self.learning_obs_len], dtype=np.int32 46 | ) 47 | self.buffer_next_obs = np.zeros( 48 | [self.capacity, self.learning_obs_len, self.obs_dim], 49 | dtype=np.float32, 50 | ) 51 | self.buffer_rewards = np.zeros( 52 | [self.capacity, self.learning_obs_len], dtype=np.float32 53 | ) 54 | self.buffer_model_target_rewards = np.zeros( 55 | [self.capacity, self.learning_obs_len], dtype=np.float32 56 | ) 57 | 58 | self.buffer_final_flag = np.zeros( 59 | [self.capacity, self.learning_obs_len], dtype=np.int32 60 | ) 61 | self.buffer_model_final_flag = np.zeros( 62 | [self.capacity, self.learning_obs_len], dtype=np.int32 63 | ) 64 | self.buffer_gammas = np.zeros( 65 | [self.capacity, self.learning_obs_len], dtype=np.float32 66 | ) 67 | 68 | self.position_r2d2 = 0 69 | 70 | self.full = False 71 | 72 | def save_buffer(self, dir, seed): 73 | import os 74 | 75 | path = os.path.join(dir, "Seed_" + str(seed) + "_replaybuffer.pt") 76 | 77 | torch.save( 78 | { 79 | "buffer_burn_in_history": self.buffer_burn_in_history, 80 | "buffer_learning_history": self.buffer_learning_history, 81 | "buffer_current_act": self.buffer_current_act, 82 | "buffer_next_obs": self.buffer_next_obs, 83 | "buffer_rewards": self.buffer_rewards, 84 | "buffer_model_target_rewards": self.buffer_model_target_rewards, 85 | "buffer_burn_in_len": self.buffer_burn_in_len, 86 | "buffer_forward_idx": self.buffer_forward_idx, 87 | "buffer_learning_len": self.buffer_learning_len, 88 | "buffer_hidden_1": self.buffer_hidden[0], 89 | "buffer_hidden_2": self.buffer_hidden[1], 90 | "buffer_final_flag": self.buffer_final_flag, 91 | "buffer_model_final_flag": self.buffer_model_final_flag, 92 | "buffer_gammas": self.buffer_gammas, 93 | }, 94 | path, 95 | ) 96 | 97 | def reset(self, seed): 98 | random.seed(seed) 99 | 100 | self.position_r2d2 = 0 101 | self.full = False 102 | 103 | def push(self, ep_states, ep_actions, ep_rewards, ep_hiddens): 104 | """ 105 | add an entire episode to the buffer 106 | """ 107 | ep_states = np.array(ep_states) # (T+1, O) 108 | ep_actions = np.array(ep_actions) # (T+1,) 109 | ep_rewards = np.array(ep_rewards) # (T+1,) 110 | ep_hiddens = np.array( 111 | [ 112 | [ 113 | ep_hidden[0].cpu().numpy().flatten(), 114 | ep_hidden[1].cpu().numpy().flatten(), 115 | ] 116 | for ep_hidden in ep_hiddens 117 | ] 118 | ) # (T+1, 2, Z) assume LSTM 119 | 120 | # Prepare raw data 121 | ls_prev_rewards = ep_rewards[:-1] # (T) 122 | ls_curr_rewards = ep_rewards[1:] # (T) 123 | 124 | ls_curr_actions = F.one_hot( 125 | torch.LongTensor(ep_actions[1:]), num_classes=self.act_dim 126 | ) 127 | ls_curr_actions = ls_curr_actions.numpy().astype(np.int32) # (T, A) 128 | ls_prev_actions = np.concatenate( 129 | [np.zeros((1, self.act_dim), dtype=np.int32), ls_curr_actions[:-1]], axis=0 130 | ) # (T, A) 131 | 132 | ls_curr_obs = ep_states[:-1] # (T, O) 133 | ls_next_obs = ep_states[1:] # (T, O) 134 | 
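        # A worked sketch of the segmentation below (illustrative numbers: the episode
        # length T is hypothetical; H, L, N follow the values noted in __init__:
        # burn_in_len H = 50, learning_obs_len L = 10, forward_len N = 5).
        # With T = 23 transitions, segments start at x in range(0, T, L) = [0, 10, 20]:
        #   burn-in slice   ls_*[max(0, x - H) : x]  -> lengths 0, 10, 20  (buffer_burn_in_len)
        #   learning slice  ls_*[x : x + L + N]      -> lengths 15, 13, 3  (buffer_learn_forward_len)
        #   TD/AIS slice    ls_*[x : x + L]          -> lengths 10, 10, 3  (buffer_learning_len)
        # Early segments have a truncated burn-in and the last segment a truncated
        # learning window, which is why these per-segment lengths are stored alongside the data.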
ls_hiddens = ep_hiddens[:-1] # (T, 2, Z) 135 | T = len(ls_curr_obs) 136 | 137 | ### Prepare burn-in history: early items are shorter than burn_in_len 138 | hidden_list = [ 139 | ls_hiddens[max(0, x - self.burn_in_len)] 140 | for x in range(0, T, self.learning_obs_len) 141 | ] 142 | burn_in_act_list = [ 143 | ls_prev_actions[max(0, x - self.burn_in_len) : x] 144 | for x in range(0, T, self.learning_obs_len) 145 | ] 146 | burn_in_r_list = [ 147 | ls_prev_rewards[max(0, x - self.burn_in_len) : x] 148 | for x in range(0, T, self.learning_obs_len) 149 | ] 150 | burn_in_obs_list = [ 151 | ls_curr_obs[max(0, x - self.burn_in_len) : x] 152 | for x in range(0, T, self.learning_obs_len) 153 | ] 154 | 155 | ### Prepare learning data: late items are shorter than self.learning_obs_len + self.forward_len 156 | ### They do not include the terminal tuple of action, reward, next_obs 157 | learning_act_list = [ 158 | ls_prev_actions[x : x + self.learning_obs_len + self.forward_len] 159 | for x in range(0, T, self.learning_obs_len) 160 | ] 161 | learning_r_list = [ 162 | ls_prev_rewards[x : x + self.learning_obs_len + self.forward_len] 163 | for x in range(0, T, self.learning_obs_len) 164 | ] 165 | learning_obs_list = [ 166 | ls_curr_obs[x : x + self.learning_obs_len + self.forward_len] 167 | for x in range(0, T, self.learning_obs_len) 168 | ] 169 | 170 | ### Prepare TD and AIS data (this include terminal action, obs, and reward) 171 | current_act_list = [ 172 | ls_curr_actions[x : x + self.learning_obs_len] 173 | for x in range(0, T, self.learning_obs_len) 174 | ] 175 | next_obs_list = [ 176 | ls_next_obs[x : x + self.learning_obs_len] 177 | for x in range(0, T, self.learning_obs_len) 178 | ] # for one-step prediction, instead of forward_len-step prediction in prior work 179 | ep_rewards_list = [ 180 | ls_curr_rewards[x : x + self.learning_obs_len] 181 | for x in range(0, T, self.learning_obs_len) 182 | ] # for one-step prediction, instead of forward_len-step prediction in prior work 183 | 184 | ep_rewards_ = ls_curr_rewards[:-1] 185 | discounted_sum = [ 186 | [ 187 | sum_rewards(ep_rewards_[x + y : x + y + self.forward_len], self.gamma) 188 | if x + y != len(ep_rewards_) 189 | else ls_curr_rewards[x + y] 190 | for y in range(0, min(self.learning_obs_len, T - x)) 191 | ] 192 | for x in range(0, T, self.learning_obs_len) 193 | ] 194 | 195 | ### Store into the buffer 196 | for i in range(len(hidden_list)): 197 | # store burn-in 198 | self.buffer_burn_in_len[self.position_r2d2] = len(burn_in_obs_list[i]) 199 | self.buffer_hidden[0][self.position_r2d2, :] = hidden_list[i][0] 200 | self.buffer_hidden[1][self.position_r2d2, :] = hidden_list[i][1] 201 | if len(burn_in_obs_list[i]) != 0: 202 | self.buffer_burn_in_history[ 203 | self.position_r2d2, : len(burn_in_act_list[i]), : 204 | ] = np.concatenate( 205 | ( 206 | burn_in_obs_list[i], 207 | burn_in_act_list[i], 208 | burn_in_r_list[i].reshape(-1, 1), 209 | ), 210 | axis=-1, 211 | ) 212 | 213 | # store learn data 214 | self.buffer_learn_forward_len[self.position_r2d2] = len( 215 | learning_act_list[i] 216 | ) 217 | self.buffer_learning_history[ 218 | self.position_r2d2, : len(learning_act_list[i]), : 219 | ] = np.concatenate( 220 | ( 221 | learning_obs_list[i], 222 | learning_act_list[i], 223 | learning_r_list[i].reshape(-1, 1), 224 | ), 225 | axis=-1, 226 | ) 227 | 228 | # store TD and AIS data 229 | self.buffer_current_act[ 230 | self.position_r2d2, : len(current_act_list[i]) 231 | ] = np.argmax(current_act_list[i], axis=-1) 232 | self.buffer_next_obs[ 233 | 
self.position_r2d2, : len(next_obs_list[i]), : 234 | ] = next_obs_list[i] 235 | self.buffer_model_target_rewards[ 236 | self.position_r2d2, : len(ep_rewards_list[i]) 237 | ] = ep_rewards_list[i] 238 | 239 | self.buffer_rewards[ 240 | self.position_r2d2, : len(discounted_sum[i]) 241 | ] = np.array(discounted_sum[i]) 242 | self.buffer_learning_len[self.position_r2d2] = len(discounted_sum[i]) 243 | self.buffer_forward_idx[ 244 | self.position_r2d2, : len(discounted_sum[i]) 245 | ] = np.array( 246 | [ 247 | min(j + self.forward_len, len(learning_obs_list[i]) - 1) 248 | for j in range(len(discounted_sum[i])) 249 | ] 250 | ) 251 | 252 | # NOTE: assume all dones are terminated, which is okay in minigrid tasks 253 | # where the timeout reward is exact 0.0, 254 | # and the training code is hard to adapt to timeout scenarios, as it 255 | self.buffer_final_flag[ 256 | self.position_r2d2, : len(discounted_sum[i]) 257 | ] = np.array( 258 | [ 259 | int(i * self.learning_obs_len + j < T - 1) 260 | for j in range(len(discounted_sum[i])) 261 | ] 262 | ) 263 | self.buffer_model_final_flag[ 264 | self.position_r2d2, : len(discounted_sum[i]) 265 | ] = np.array( 266 | [ 267 | int(i * self.learning_obs_len + j <= T - 1) 268 | for j in range(len(discounted_sum[i])) 269 | ] 270 | ) # this flag includes terminal step, used for reward prediction only 271 | 272 | self.buffer_gammas[self.position_r2d2, : len(discounted_sum[i])] = np.array( 273 | [ 274 | self.gamma 275 | ** (min(j + self.forward_len, len(learning_obs_list[i]) - 1) - j) 276 | for j in range(len(discounted_sum[i])) 277 | ] 278 | ) 279 | 280 | if self.full is False and self.position_r2d2 + 1 == self.capacity: 281 | self.full = True 282 | self.position_r2d2 = (self.position_r2d2 + 1) % self.capacity 283 | 284 | def sample(self, batch_size): 285 | tmp = self.position_r2d2 286 | if self.full: 287 | tmp = self.capacity 288 | idx = np.random.choice(tmp, batch_size, replace=False) 289 | 290 | batch_burn_in_hist = self.buffer_burn_in_history[idx, :, :] 291 | batch_learn_hist = self.buffer_learning_history[idx, :, :] 292 | batch_rewards = self.buffer_rewards[idx, :] 293 | batch_burn_in_len = self.buffer_burn_in_len[idx] 294 | batch_forward_idx = self.buffer_forward_idx[idx, :] 295 | batch_final_flag = self.buffer_final_flag[idx, :] 296 | batch_learn_len = self.buffer_learning_len[idx] 297 | batch_hidden = (self.buffer_hidden[0][idx], self.buffer_hidden[1][idx]) 298 | batch_current_act = self.buffer_current_act[idx, :] 299 | batch_learn_forward_len = self.buffer_learn_forward_len[idx] 300 | batch_next_obs = self.buffer_next_obs[idx] 301 | batch_model_target_reward = self.buffer_model_target_rewards[idx] 302 | batch_model_final_flag = self.buffer_model_final_flag[idx, :] 303 | batch_gammas = self.buffer_gammas[idx] 304 | 305 | return ( 306 | batch_burn_in_hist, 307 | batch_learn_hist, 308 | batch_rewards, 309 | batch_learn_len, 310 | batch_forward_idx, 311 | batch_final_flag, 312 | batch_current_act, 313 | batch_hidden, 314 | batch_burn_in_len, 315 | batch_learn_forward_len, 316 | batch_next_obs, 317 | batch_model_target_reward, 318 | batch_model_final_flag, 319 | batch_gammas, 320 | ) 321 | 322 | def __len__(self): 323 | if self.full: 324 | return self.capacity 325 | else: 326 | return self.position_r2d2 327 | 328 | 329 | def sum_rewards(reward_list, gamma): 330 | ls = [reward_list[i] * gamma**i for i in range(0, len(reward_list))] 331 | return sum(ls) 332 | -------------------------------------------------------------------------------- /minigrid_code/agent.py: 
-------------------------------------------------------------------------------- 1 | from models import * 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.optim import Adam 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | import math 8 | import random 9 | import logger 10 | 11 | 12 | class Agent(object): 13 | def __init__(self, env, args): 14 | self.args = args 15 | self.device = torch.device("cuda" if args["cuda"] else "cpu") 16 | self.obs_dim = np.prod(env.observation_space["image"].shape) # flatten 17 | self.act_dim = env.action_space.n 18 | self.gamma = args["gamma"] 19 | self.tau = args["tau"] 20 | self.target_update_interval = args["target_update_interval"] 21 | 22 | self.aux = args["aux"] 23 | assert self.aux in ["None", "ZP", "OP", "AIS", "AIS-P2"] 24 | self.AIS_state_size = args["AIS_state_size"] 25 | 26 | self.encoder = SeqEncoder(self.obs_dim, self.act_dim, self.AIS_state_size).to( 27 | self.device 28 | ) 29 | self.encoder_target = SeqEncoder( 30 | self.obs_dim, self.act_dim, self.AIS_state_size 31 | ).to(self.device) 32 | hard_update(self.encoder_target, self.encoder) 33 | 34 | self.critic = QNetwork_discrete( 35 | self.AIS_state_size, self.act_dim, args["hidden_size"] 36 | ).to(device=self.device) 37 | self.critic_target = QNetwork_discrete( 38 | self.AIS_state_size, self.act_dim, args["hidden_size"] 39 | ).to(self.device) 40 | hard_update(self.critic_target, self.critic) 41 | 42 | if self.aux in ["AIS", "AIS-P2"]: # modular 43 | self.optim = Adam( 44 | self.critic.parameters(), 45 | lr=args["rl_lr"], 46 | ) 47 | else: # end-to-end 48 | self.optim = Adam( 49 | list(self.encoder.parameters()) + list(self.critic.parameters()), 50 | lr=args["rl_lr"], 51 | ) 52 | 53 | if self.aux in ["AIS", "AIS-P2"]: 54 | self.model = AISModel( 55 | self.obs_dim if self.aux == "AIS" else self.AIS_state_size, 56 | self.act_dim, 57 | self.AIS_state_size, 58 | ).to(self.device) 59 | self.AIS_optim = Adam( 60 | list(self.encoder.parameters()) + list(self.model.parameters()), 61 | lr=args["aux_lr"], 62 | ) 63 | elif self.aux == "None": # model-free R2D2 64 | self.model = None 65 | else: 66 | self.model = LatentModel( 67 | self.obs_dim if self.aux == "OP" else self.AIS_state_size, 68 | self.act_dim, 69 | self.AIS_state_size, 70 | ).to(self.device) 71 | self.AIS_optim = Adam( 72 | self.model.parameters(), 73 | lr=args["aux_lr"], 74 | ) 75 | 76 | logger.log(self.encoder, self.model, self.critic) 77 | 78 | self.aux_optim = args["aux_optim"] 79 | self.aux_coef = args["aux_coef"] 80 | assert self.aux_optim in ["None", "ema", "detach", "online"] 81 | assert self.aux_coef >= 0.0 82 | 83 | self.update_to_q = 0 84 | self.eps_greedy_parameters = { 85 | "EPS_START": args["EPS_start"], 86 | "EPS_END": args["EPS_end"], 87 | "EPS_DECAY": args["EPS_decay"], 88 | } 89 | self.env_steps = 0 90 | 91 | self.get_initial_hidden = lambda: self.encoder.get_initial_hidden( 92 | 1, self.device 93 | ) 94 | 95 | @torch.no_grad() 96 | def select_action( 97 | self, state, action, reward, hidden_p, EPS_up: bool, evaluate: bool 98 | ): 99 | action = convert_int_to_onehot(action, self.act_dim) 100 | reward = torch.Tensor([reward]) 101 | state = torch.Tensor(state) 102 | rho_input = torch.cat((state, action, reward)).reshape(1, 1, -1).to(self.device) 103 | 104 | ais_z, hidden_p = self.encoder( 105 | rho_input, 106 | batch_size=1, 107 | hidden=hidden_p, 108 | device=self.device, 109 | batch_lengths=[], 110 | pack_sequence=False, 111 | ) 112 | if evaluate is False 
and EPS_up: 113 | self.env_steps += 1 114 | 115 | if self.args["EPS_decay_type"] == "exponential": 116 | eps_threshold = self.eps_greedy_parameters["EPS_END"] + ( 117 | self.eps_greedy_parameters["EPS_START"] 118 | - self.eps_greedy_parameters["EPS_END"] 119 | ) * math.exp( 120 | -1.0 * self.env_steps / self.eps_greedy_parameters["EPS_DECAY"] 121 | ) 122 | elif self.args["EPS_decay_type"] == "linear": 123 | eps_threshold = self.eps_greedy_parameters["EPS_START"] + ( 124 | self.eps_greedy_parameters["EPS_END"] 125 | - self.eps_greedy_parameters["EPS_START"] 126 | ) * (self.env_steps / self.eps_greedy_parameters["EPS_DECAY"]) 127 | eps_threshold = max(eps_threshold, self.eps_greedy_parameters["EPS_END"]) 128 | 129 | sample = random.random() 130 | if (sample < eps_threshold and evaluate is False) or ( 131 | sample < self.args["test_epsilon"] and evaluate is True 132 | ): 133 | return random.randrange(self.act_dim), hidden_p 134 | 135 | qf = self.critic(ais_z)[0, 0] 136 | greedy_action = torch.argmax(qf).item() 137 | 138 | return greedy_action, hidden_p 139 | 140 | def update_parameters(self, memory, batch_size: int, updates: int): 141 | for _ in range(updates): 142 | self.update_to_q += 1 143 | metrics = self.single_update(memory, batch_size) 144 | 145 | if self.update_to_q % self.target_update_interval == 0: 146 | # We change the hard update to soft update 147 | soft_update(self.encoder_target, self.encoder, self.tau) 148 | soft_update(self.critic_target, self.critic, self.tau) 149 | 150 | return metrics 151 | 152 | def report_rank(self, z_batch, metrics: dict): 153 | from torch.linalg import matrix_rank 154 | 155 | rank3 = matrix_rank(z_batch, atol=1e-3, rtol=1e-3) 156 | rank2 = matrix_rank(z_batch, atol=1e-2, rtol=1e-2) 157 | rank1 = matrix_rank(z_batch, atol=1e-1, rtol=1e-1) 158 | metrics["rank-3"] = rank3.item() 159 | metrics["rank-2"] = rank2.item() 160 | metrics["rank-1"] = rank1.item() 161 | 162 | def single_update(self, memory, batch_size: int): 163 | """ 164 | H: burn-in len 165 | L: forward len 166 | N: TD(n) 167 | """ 168 | metrics = {} 169 | losses = 0.0 170 | 171 | # 1. Sample a batch of data in numpy arrays 172 | ( 173 | batch_burn_in_hist, # (B, H, O+A+1) 174 | batch_learn_hist, # (B, L+N, O+A+1) 175 | batch_rewards, # (B, L) 176 | batch_learn_len, # (B,) in [1, L] 177 | batch_forward_idx, # (B, L) in [N, L+N-1] and {0,