├── .gitignore ├── LICENSE ├── README.md ├── agents ├── drqv2.py └── taco.py ├── cfgs ├── agent │ ├── drqv2.yaml │ └── taco.yaml ├── config.yaml └── task │ ├── acrobot_swingup.yaml │ ├── acrobot_swingup_sparse.yaml │ ├── cartpole_balance.yaml │ ├── cartpole_balance_sparse.yaml │ ├── cartpole_swingup.yaml │ ├── cartpole_swingup_sparse.yaml │ ├── cheetah_run.yaml │ ├── cup_catch.yaml │ ├── easy.yaml │ ├── finger_spin.yaml │ ├── finger_turn_easy.yaml │ ├── finger_turn_hard.yaml │ ├── hard.yaml │ ├── hopper_hop.yaml │ ├── hopper_stand.yaml │ ├── manipulator_bring_ball.yaml │ ├── medium.yaml │ ├── pendulum_swingup.yaml │ ├── quadruped_run.yaml │ ├── quadruped_walk.yaml │ ├── reach_duplo.yaml │ ├── reacher_easy.yaml │ ├── reacher_hard.yaml │ ├── walker_run.yaml │ ├── walker_stand.yaml │ └── walker_walk.yaml ├── dmc.py ├── environment.yml ├── logger.py ├── media ├── dmc.gif ├── overview.png └── taco.jpg ├── replay_buffer.py ├── train.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | exp_local/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ruijie Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TACO: Temporal Action-driven Contrastive Learning 2 | 3 | Original PyTorch implementation of **TACO** from 4 | 5 | [TACO: Temporal Latent Action-Driven Contrastive Loss for Visual Reinforcement Learning](https://arxiv.org/pdf/2306.13229.pdf) by 6 | 7 | [Ruijie Zheng](https://ruijiezheng.com), [Xiyao Wang](https://si0wang.github.io)\*, [Yanchao Sun](https://ycsun2017.github.io)\*, [Shuang Ma](https://www.shuangma.me)\*, [Jieyu Zhao](https://jyzhao.net)\*, [Huazhe Xu](http://hxu.rocks)\*, [Hal Daumé III](http://users.umiacs.umd.edu/~hal/)\*, [Furong Huang](https://furong-huang.com)\* 8 | 9 | 10 |

11 |

12 | [[Paper]](https://arxiv.org/pdf/2306.13229.pdf) [Website] 13 |

14 | 15 | 16 | ## Method 17 | 18 | **TACO** is a simple yet powerful temporal contrastive learning approach that facilitates the concurrent acquisition of latent state and action representations for agents. **TACO** simultaneously learns a state and an action representation by optimizing the mutual information between representations of current states paired with action sequences and representations of the corresponding future states. 19 | 20 |
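Concretely, this is an InfoNCE-style objective: within a batch, each projected (state, action-sequence) embedding is scored against every projected future-state embedding through a learned bilinear map, and the matching pairs on the diagonal are the positives. The snippet below is a minimal, self-contained sketch of that step, mirroring `TACO.compute_logits` in `agents/taco.py`; the function name and tensor shapes are illustrative only.

```
import torch
import torch.nn.functional as F

def taco_infonce_loss(z_sa, z_next, W):
    """z_sa: (B, D) projected (state, action-sequence) embeddings.
    z_next: (B, D) projected future-state embeddings.
    W: (D, D) learnable bilinear weight (TACO.W in agents/taco.py).
    Positives are the diagonal of the (B, B) similarity matrix."""
    Wz = torch.matmul(W, z_next.T)                         # (D, B)
    logits = torch.matmul(z_sa, Wz)                        # (B, B)
    logits = logits - logits.max(1, keepdim=True).values   # stabilize before softmax
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)

# Illustrative shapes: a batch of 8 samples with 50-dimensional projections.
loss = taco_infonce_loss(torch.randn(8, 50), torch.randn(8, 50), torch.randn(50, 50))
```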

21 | 22 |

23 | 24 | 25 | ## Citation 26 | 27 | If you use our method or code in your research, please consider citing the paper as follows: 28 | 29 | ``` 30 | @inproceedings{ 31 | zheng2023taco, 32 | title={\${\textbackslash}texttt\{{TACO}\}\$: Temporal Latent Action-Driven Contrastive Loss for Visual Reinforcement Learning}, 33 | author={Ruijie Zheng and Xiyao Wang and Yanchao Sun and Shuang Ma and Jieyu Zhao and Huazhe Xu and Hal Daumé III and Furong Huang}, 34 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 35 | year={2023}, 36 | url={https://openreview.net/forum?id=ezCsMOy1w9} 37 | } 38 | 39 | ``` 40 | 41 | ## Instructions 42 | 43 | Assuming that you already have [MuJoCo](http://www.mujoco.org) installed, install dependencies using `conda`: 44 | 45 | ``` 46 | conda env create -f environment.yml 47 | conda activate taco 48 | ``` 49 | 50 | After installing dependencies, you can train a **TACO** agent by calling (using quadruped_run as an example): 51 | 52 | ``` 53 | CUDA_VISIBLE_DEVICES=X python train.py agent=taco task=quadruped_run exp_name=${EXP_NAME} 54 | ``` 55 | 56 | To train a **DrQ-v2** agent: 57 | ``` 58 | CUDA_VISIBLE_DEVICES=X python train.py agent=drqv2 task=quadruped_run exp_name=${EXP_NAME} 59 | ``` 60 | 61 | Evaluation videos and model weights can be saved with arguments `save_video=True` and `save_model=True`. Refer to the `cfgs` directory for a full list of options and default hyperparameters. 62 | 63 | 64 | ## Acknowledgement 65 | TACO is licensed under the MIT license. MuJoCo and DeepMind Control Suite are licensed under the Apache 2.0 license. We would like to thank DrQ-v2 authors for open-sourcing the [DrQv2](https://github.com/facebookresearch/drqv2) codebase. Our implementation builds on top of their repository. 
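As a closing note on the Instructions above: the `agent` config block is a standard Hydra `_target_` spec, so the same agent that `train.py` builds can be constructed directly for quick inspection outside the training loop. The sketch below mirrors `make_agent` in `train.py`, with the `${...}` interpolations replaced by the defaults from `cfgs/config.yaml` and `cfgs/agent/taco.yaml`; it assumes the repository root is on `PYTHONPATH`, and the 12-dimensional `action_shape` (the DeepMind Control quadruped) is only an illustrative choice.

```
import hydra
from omegaconf import OmegaConf

# Hand-rolled equivalent of cfgs/agent/taco.yaml with interpolations resolved.
agent_cfg = OmegaConf.create(dict(
    _target_='agents.taco.TACOAgent',
    obs_shape=(9, 84, 84),     # frame_stack=3 stacked RGB frames at 84x84
    action_shape=(12,),        # illustrative: quadruped action dimension
    device='cpu',              # use 'cuda' as in cfgs/config.yaml when a GPU is available
    lr=1e-4, encoder_lr=1e-4, feature_dim=50, hidden_dim=1024,
    critic_target_tau=0.01, num_expl_steps=2000, update_every_steps=2,
    stddev_schedule='linear(1.0,0.1,500000)', stddev_clip=0.3, use_tb=False,
    curl=True, reward=True, multistep=3, latent_a_dim='none',
))
agent = hydra.utils.instantiate(agent_cfg)  # the same call train.py makes via make_agent
```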
66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /agents/drqv2.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import utils 7 | import itertools 8 | 9 | 10 | class RandomShiftsAug(nn.Module): 11 | def __init__(self, pad): 12 | super().__init__() 13 | self.pad = pad 14 | 15 | def forward(self, x): 16 | n, c, h, w = x.size() 17 | assert h == w 18 | padding = tuple([self.pad] * 4) 19 | x = F.pad(x, padding, 'replicate') 20 | eps = 1.0 / (h + 2 * self.pad) 21 | arange = torch.linspace(-1.0 + eps, 22 | 1.0 - eps, 23 | h + 2 * self.pad, 24 | device=x.device, 25 | dtype=x.dtype)[:h] 26 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 27 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 28 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 29 | 30 | shift = torch.randint(0, 31 | 2 * self.pad + 1, 32 | size=(n, 1, 1, 2), 33 | device=x.device, 34 | dtype=x.dtype) 35 | shift *= 2.0 / (h + 2 * self.pad) 36 | 37 | grid = base_grid + shift 38 | return F.grid_sample(x, 39 | grid, 40 | padding_mode='zeros', 41 | align_corners=False) 42 | 43 | class Encoder(nn.Module): 44 | def __init__(self, obs_shape, feature_dim): 45 | super().__init__() 46 | 47 | assert len(obs_shape) == 3 48 | self.repr_dim = 32 * 35 * 35 49 | 50 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2), 51 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 52 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 53 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 54 | nn.ReLU()) 55 | 56 | self.apply(utils.weight_init) 57 | 58 | def forward(self, obs): 59 | obs = obs / 255.0 - 0.5 60 | h = self.convnet(obs) 61 | h = h.view(h.shape[0], -1) 62 | return h 63 | 64 | 65 | 66 | class Actor(nn.Module): 67 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 68 | super().__init__() 69 | 70 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 71 | nn.LayerNorm(feature_dim), nn.Tanh()) 72 | 73 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 74 | nn.ReLU(inplace=True), 75 | nn.Linear(hidden_dim, hidden_dim), 76 | nn.ReLU(inplace=True), 77 | nn.Linear(hidden_dim, action_shape[0])) 78 | 79 | self.apply(utils.weight_init) 80 | 81 | def forward(self, obs, std): 82 | h = self.trunk(obs) 83 | mu = self.policy(h) 84 | mu = torch.tanh(mu) 85 | std = torch.ones_like(mu) * std 86 | 87 | dist = utils.TruncatedNormal(mu, std) 88 | return dist 89 | 90 | 91 | class Critic(nn.Module): 92 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 93 | super().__init__() 94 | 95 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 96 | nn.LayerNorm(feature_dim), nn.Tanh()) 97 | 98 | self.Q1 = nn.Sequential( 99 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 100 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 101 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 102 | 103 | self.Q2 = nn.Sequential( 104 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 105 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 106 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 107 | 108 | self.apply(utils.weight_init) 109 | 110 | def forward(self, obs, action): 111 | h = self.trunk(obs) 112 | h_action = torch.cat([h, action], dim=-1) 113 | q1 = self.Q1(h_action) 114 | q2 = self.Q2(h_action) 115 | 116 | return q1, q2 117 | 118 | 119 | class 
DrQV2Agent: 120 | def __init__(self, obs_shape, action_shape, device, lr, feature_dim, 121 | hidden_dim, critic_target_tau, num_expl_steps, 122 | update_every_steps, stddev_schedule, stddev_clip, use_tb): 123 | self.device = device 124 | self.critic_target_tau = critic_target_tau 125 | self.update_every_steps = update_every_steps 126 | self.use_tb = use_tb 127 | self.num_expl_steps = num_expl_steps 128 | self.stddev_schedule = stddev_schedule 129 | self.stddev_clip = stddev_clip 130 | 131 | self.encoder = Encoder(obs_shape, feature_dim).to(device) 132 | self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim, 133 | hidden_dim).to(device) 134 | 135 | self.critic = Critic(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device) 136 | self.critic_target = Critic(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device) 137 | self.critic_target.load_state_dict(self.critic.state_dict()) 138 | 139 | # optimizers 140 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr) 141 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 142 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 143 | 144 | self.cross_entropy_loss = nn.CrossEntropyLoss() 145 | self.mse_loss = nn.MSELoss() 146 | 147 | # data augmentation 148 | self.aug = RandomShiftsAug(pad=4) 149 | 150 | self.train() 151 | self.critic_target.train() 152 | 153 | def train(self, training=True): 154 | self.training = training 155 | self.encoder.train(training) 156 | self.actor.train(training) 157 | self.critic.train(training) 158 | 159 | def act(self, obs, step, eval_mode): 160 | obs = torch.as_tensor(obs, device=self.device) 161 | obs = self.encoder(obs.unsqueeze(0)) 162 | stddev = utils.schedule(self.stddev_schedule, step) 163 | dist = self.actor(obs, stddev) 164 | if eval_mode: 165 | action = dist.mean 166 | else: 167 | action = dist.sample(clip=None) 168 | if step < self.num_expl_steps: 169 | action.uniform_(-1.0, 1.0) 170 | return action.cpu().numpy()[0] 171 | 172 | def update_critic(self, obs, action, reward, discount, next_obs, step): 173 | metrics = dict() 174 | 175 | with torch.no_grad(): 176 | stddev = utils.schedule(self.stddev_schedule, step) 177 | dist = self.actor(next_obs, stddev) 178 | next_action = dist.sample(clip=self.stddev_clip) 179 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 180 | target_V = torch.min(target_Q1, target_Q2) 181 | target_Q = reward + (discount * target_V) 182 | 183 | Q1, Q2 = self.critic(obs, action) 184 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 185 | 186 | if self.use_tb: 187 | metrics['critic_target_q'] = target_Q.mean().item() 188 | metrics['critic_q1'] = Q1.mean().item() 189 | metrics['critic_q2'] = Q2.mean().item() 190 | metrics['critic_loss'] = critic_loss.item() 191 | 192 | # optimize encoder and critic 193 | self.encoder_opt.zero_grad(set_to_none=True) 194 | self.critic_opt.zero_grad(set_to_none=True) 195 | critic_loss.backward() 196 | self.critic_opt.step() 197 | self.encoder_opt.step() 198 | 199 | return metrics 200 | 201 | def update_actor(self, obs, step): 202 | metrics = dict() 203 | 204 | stddev = utils.schedule(self.stddev_schedule, step) 205 | dist = self.actor(obs, stddev) 206 | action = dist.sample(clip=self.stddev_clip) 207 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 208 | Q1, Q2 = self.critic(obs, action) 209 | Q = torch.min(Q1, Q2) 210 | 211 | actor_loss = -Q.mean() 212 | 213 | # optimize actor 214 | 
self.actor_opt.zero_grad(set_to_none=True) 215 | actor_loss.backward() 216 | self.actor_opt.step() 217 | 218 | if self.use_tb: 219 | metrics['actor_loss'] = actor_loss.item() 220 | metrics['actor_logprob'] = log_prob.mean().item() 221 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() 222 | 223 | return metrics 224 | 225 | 226 | 227 | def update(self, replay_iter, step): 228 | metrics = dict() 229 | if step % self.update_every_steps != 0: 230 | return metrics 231 | 232 | batch = next(replay_iter) 233 | obs, action, action_seq, reward, discount, next_obs, r_next_obs = utils.to_torch( 234 | batch, self.device) 235 | 236 | # augment 237 | obs_en = self.aug(obs.float()) 238 | next_obs_en = self.aug(next_obs.float()) 239 | # encode 240 | obs_en = self.encoder(obs_en) 241 | with torch.no_grad(): 242 | next_obs_en = self.encoder(next_obs_en) 243 | 244 | if self.use_tb: 245 | metrics['batch_reward'] = reward.mean().item() 246 | 247 | # update critic 248 | metrics.update( 249 | self.update_critic(obs_en, action, reward, discount, next_obs_en, step)) 250 | 251 | # update actor 252 | metrics.update(self.update_actor(obs_en.detach(), step)) 253 | 254 | # update critic target 255 | utils.soft_update_params(self.critic, self.critic_target, 256 | self.critic_target_tau) 257 | 258 | return metrics 259 | -------------------------------------------------------------------------------- /agents/taco.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import utils 7 | import itertools 8 | 9 | 10 | class RandomShiftsAug(nn.Module): 11 | def __init__(self, pad): 12 | super().__init__() 13 | self.pad = pad 14 | 15 | def forward(self, x): 16 | n, c, h, w = x.size() 17 | assert h == w 18 | padding = tuple([self.pad] * 4) 19 | x = F.pad(x, padding, 'replicate') 20 | eps = 1.0 / (h + 2 * self.pad) 21 | arange = torch.linspace(-1.0 + eps, 22 | 1.0 - eps, 23 | h + 2 * self.pad, 24 | device=x.device, 25 | dtype=x.dtype)[:h] 26 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 27 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 28 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 29 | 30 | shift = torch.randint(0, 31 | 2 * self.pad + 1, 32 | size=(n, 1, 1, 2), 33 | device=x.device, 34 | dtype=x.dtype) 35 | shift *= 2.0 / (h + 2 * self.pad) 36 | 37 | grid = base_grid + shift 38 | return F.grid_sample(x, 39 | grid, 40 | padding_mode='zeros', 41 | align_corners=False) 42 | 43 | class Encoder(nn.Module): 44 | def __init__(self, obs_shape, feature_dim): 45 | super().__init__() 46 | 47 | assert len(obs_shape) == 3 48 | self.repr_dim = 32 * 35 * 35 49 | 50 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2), 51 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 52 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 53 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 54 | nn.ReLU()) 55 | 56 | self.apply(utils.weight_init) 57 | 58 | def forward(self, obs): 59 | obs = obs / 255.0 - 0.5 60 | h = self.convnet(obs) 61 | h = h.view(h.shape[0], -1) 62 | return h 63 | 64 | class TACO(nn.Module): 65 | """ 66 | TACO Constrastive loss 67 | """ 68 | 69 | def __init__(self, repr_dim, feature_dim, action_shape, latent_a_dim, hidden_dim, act_tok, encoder, multistep, device): 70 | super(TACO, self).__init__() 71 | 72 | self.multistep = multistep 73 | self.encoder = encoder 74 | self.device = device 75 | 76 | a_dim = action_shape[0] 77 | 78 | 
self.proj_sa = nn.Sequential( 79 | nn.Linear(feature_dim + latent_a_dim*multistep, hidden_dim), 80 | nn.ReLU(inplace=True), 81 | nn.Linear(hidden_dim, feature_dim) 82 | ) 83 | 84 | self.act_tok = act_tok 85 | 86 | self.proj_s = nn.Sequential(nn.Linear(repr_dim, feature_dim), 87 | nn.LayerNorm(feature_dim), nn.Tanh()) 88 | 89 | self.reward = nn.Sequential( 90 | nn.Linear(feature_dim+latent_a_dim*multistep, hidden_dim), 91 | nn.ReLU(inplace=True), 92 | nn.Linear(hidden_dim, 1) 93 | ) 94 | 95 | self.W = nn.Parameter(torch.rand(feature_dim, feature_dim)) 96 | self.apply(utils.weight_init) 97 | 98 | def encode(self, x, ema=False): 99 | """ 100 | Encoder: z_t = e(x_t) 101 | :param x: x_t, x y coordinates 102 | :return: z_t, value in r2 103 | """ 104 | if ema: 105 | with torch.no_grad(): 106 | z_out = self.proj_s(self.encoder(x)) 107 | else: 108 | z_out = self.proj_s(self.encoder(x)) 109 | return z_out 110 | 111 | def project_sa(self, s, a): 112 | x = torch.concat([s,a], dim=-1) 113 | return self.proj_sa(x) 114 | 115 | def compute_logits(self, z_a, z_pos): 116 | """ 117 | - compute (B,B) matrix z_a (W z_pos.T) 118 | - positives are all diagonal elements 119 | - negatives are all other elements 120 | - to compute loss use multiclass cross entropy with identity matrix for labels 121 | """ 122 | 123 | Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B) 124 | logits = torch.matmul(z_a, Wz) # (B,B) 125 | logits = logits - torch.max(logits, 1)[0][:, None] 126 | return logits 127 | 128 | 129 | class Actor(nn.Module): 130 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 131 | super().__init__() 132 | 133 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 134 | nn.LayerNorm(feature_dim), nn.Tanh()) 135 | 136 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 137 | nn.ReLU(inplace=True), 138 | nn.Linear(hidden_dim, hidden_dim), 139 | nn.ReLU(inplace=True), 140 | nn.Linear(hidden_dim, action_shape[0])) 141 | 142 | self.apply(utils.weight_init) 143 | 144 | def forward(self, obs, std): 145 | h = self.trunk(obs) 146 | mu = self.policy(h) 147 | mu = torch.tanh(mu) 148 | std = torch.ones_like(mu) * std 149 | 150 | dist = utils.TruncatedNormal(mu, std) 151 | return dist 152 | 153 | 154 | class Critic(nn.Module): 155 | def __init__(self, repr_dim, latent_a_dim, feature_dim, hidden_dim): 156 | super().__init__() 157 | 158 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 159 | nn.LayerNorm(feature_dim), nn.Tanh()) 160 | 161 | self.Q1 = nn.Sequential( 162 | nn.Linear(feature_dim + latent_a_dim, hidden_dim), 163 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 164 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 165 | 166 | self.Q2 = nn.Sequential( 167 | nn.Linear(feature_dim + latent_a_dim, hidden_dim), 168 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 169 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 170 | 171 | self.apply(utils.weight_init) 172 | 173 | def forward(self, obs, action, act_tok=None): 174 | if act_tok is not None: 175 | action = act_tok(action) 176 | h = self.trunk(obs) 177 | h_action = torch.cat([h, action], dim=-1) 178 | q1 = self.Q1(h_action) 179 | q2 = self.Q2(h_action) 180 | 181 | return q1, q2 182 | 183 | 184 | class TACOAgent: 185 | def __init__(self, obs_shape, action_shape, device, lr, encoder_lr, feature_dim, 186 | hidden_dim, critic_target_tau, num_expl_steps, 187 | update_every_steps, stddev_schedule, stddev_clip, use_tb, 188 | reward, multistep, latent_a_dim, curl): 189 | self.device = device 190 | 
self.critic_target_tau = critic_target_tau 191 | self.update_every_steps = update_every_steps 192 | self.use_tb = use_tb 193 | self.num_expl_steps = num_expl_steps 194 | self.stddev_schedule = stddev_schedule 195 | self.stddev_clip = stddev_clip 196 | 197 | self.reward = reward 198 | self.multistep = multistep 199 | self.curl = curl 200 | 201 | ### A heuristics to choose the dimensionality of latent actions 202 | if latent_a_dim == 'none': 203 | latent_a_dim = int(action_shape[0]*1.25)+1 204 | ### Create action embeddings 205 | self.act_tok = utils.ActionEncoding(action_shape[0], latent_a_dim, multistep) 206 | self.encoder = Encoder(obs_shape, feature_dim).to(device) 207 | 208 | self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim, 209 | hidden_dim).to(device) 210 | self.critic = Critic(self.encoder.repr_dim, latent_a_dim, feature_dim, 211 | hidden_dim).to(device) 212 | self.critic_target = Critic(self.encoder.repr_dim, latent_a_dim, 213 | feature_dim, hidden_dim).to(device) 214 | self.critic_target.load_state_dict(self.critic.state_dict()) 215 | self.TACO = TACO(self.encoder.repr_dim, feature_dim, action_shape, latent_a_dim, hidden_dim, self.act_tok, self.encoder, multistep, device).to(device) 216 | 217 | ### State & Action Encoders 218 | parameters = itertools.chain(self.encoder.parameters(), 219 | self.act_tok.parameters(), 220 | ) 221 | self.encoder_opt = torch.optim.Adam(parameters, lr=encoder_lr) 222 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 223 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 224 | self.taco_opt = torch.optim.Adam(self.TACO.parameters(), lr=encoder_lr) 225 | 226 | self.cross_entropy_loss = nn.CrossEntropyLoss() 227 | 228 | # data augmentation 229 | self.aug = RandomShiftsAug(pad=4) 230 | 231 | self.train() 232 | self.critic_target.train() 233 | 234 | def train(self, training=True): 235 | self.training = training 236 | self.encoder.train(training) 237 | self.actor.train(training) 238 | self.critic.train(training) 239 | self.TACO.train() 240 | 241 | def act(self, obs, step, eval_mode): 242 | obs = torch.as_tensor(obs, device=self.device) 243 | obs = self.encoder(obs.unsqueeze(0)) 244 | stddev = utils.schedule(self.stddev_schedule, step) 245 | dist = self.actor(obs, stddev) 246 | if eval_mode: 247 | action = dist.mean 248 | else: 249 | action = dist.sample(clip=None) 250 | if step < self.num_expl_steps: 251 | action.uniform_(-1.0, 1.0) 252 | return action.cpu().numpy()[0] 253 | 254 | def update_critic(self, obs, action, reward, discount, next_obs, step): 255 | metrics = dict() 256 | 257 | with torch.no_grad(): 258 | stddev = utils.schedule(self.stddev_schedule, step) 259 | dist = self.actor(next_obs, stddev) 260 | next_action = dist.sample(clip=self.stddev_clip) 261 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action, self.act_tok) 262 | target_V = torch.min(target_Q1, target_Q2) 263 | target_Q = reward + (discount * target_V) 264 | 265 | Q1, Q2 = self.critic(obs, action, self.act_tok) 266 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 267 | 268 | if self.use_tb: 269 | metrics['critic_target_q'] = target_Q.mean().item() 270 | metrics['critic_q1'] = Q1.mean().item() 271 | metrics['critic_q2'] = Q2.mean().item() 272 | metrics['critic_loss'] = critic_loss.item() 273 | 274 | # optimize encoder and critic 275 | self.encoder_opt.zero_grad(set_to_none=True) 276 | self.critic_opt.zero_grad(set_to_none=True) 277 | critic_loss.backward() 278 | self.critic_opt.step() 279 | 
self.encoder_opt.step() 280 | 281 | return metrics 282 | 283 | def update_actor(self, obs, step): 284 | metrics = dict() 285 | 286 | stddev = utils.schedule(self.stddev_schedule, step) 287 | dist = self.actor(obs, stddev) 288 | action = dist.sample(clip=self.stddev_clip) 289 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 290 | Q1, Q2 = self.critic(obs, action, self.act_tok) 291 | Q = torch.min(Q1, Q2) 292 | 293 | actor_loss = -Q.mean() 294 | 295 | # optimize actor 296 | self.actor_opt.zero_grad(set_to_none=True) 297 | actor_loss.backward() 298 | self.actor_opt.step() 299 | 300 | if self.use_tb: 301 | metrics['actor_loss'] = actor_loss.item() 302 | metrics['actor_logprob'] = log_prob.mean().item() 303 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() 304 | 305 | return metrics 306 | 307 | def update_taco(self, obs, action, action_seq, next_obs, reward): 308 | metrics = dict() 309 | 310 | obs_anchor = self.aug(obs.float()) 311 | obs_pos = self.aug(obs.float()) 312 | z_a = self.TACO.encode(obs_anchor) 313 | z_pos = self.TACO.encode(obs_pos, ema=True) 314 | ### Compute CURL loss 315 | if self.curl: 316 | logits = self.TACO.compute_logits(z_a, z_pos) 317 | labels = torch.arange(logits.shape[0]).long().to(self.device) 318 | curl_loss = self.cross_entropy_loss(logits, labels) 319 | else: 320 | curl_loss = torch.tensor(0.) 321 | 322 | ### Compute action encodings 323 | action_en = self.TACO.act_tok(action, seq=False) 324 | action_seq_en = self.TACO.act_tok(action_seq, seq=True) 325 | 326 | ### Compute reward prediction loss 327 | if self.reward: 328 | reward_pred = self.TACO.reward(torch.concat([z_a, action_seq_en], dim=-1)) 329 | reward_loss = F.mse_loss(reward_pred, reward) 330 | else: 331 | reward_loss = torch.tensor(0.) 332 | 333 | ### Compute TACO loss 334 | next_z = self.TACO.encode(self.aug(next_obs.float()), ema=True) 335 | curr_za = self.TACO.project_sa(z_a, action_seq_en) 336 | logits = self.TACO.compute_logits(curr_za, next_z) 337 | labels = torch.arange(logits.shape[0]).long().to(self.device) 338 | taco_loss = self.cross_entropy_loss(logits, labels) 339 | 340 | self.taco_opt.zero_grad() 341 | (taco_loss + curl_loss + reward_loss).backward() 342 | self.taco_opt.step() 343 | if self.use_tb: 344 | metrics['reward_loss'] = reward_loss.item() 345 | metrics['curl_loss'] = curl_loss.item() 346 | metrics['taco_loss'] = taco_loss.item() 347 | return metrics 348 | 349 | 350 | 351 | def update(self, replay_iter, step): 352 | metrics = dict() 353 | if step % self.update_every_steps != 0: 354 | return metrics 355 | 356 | batch = next(replay_iter) 357 | obs, action, action_seq, reward, discount, next_obs, r_next_obs = utils.to_torch( 358 | batch, self.device) 359 | 360 | # augment 361 | obs_en = self.aug(obs.float()) 362 | next_obs_en = self.aug(next_obs.float()) 363 | # encode 364 | obs_en = self.encoder(obs_en) 365 | with torch.no_grad(): 366 | next_obs_en = self.encoder(next_obs_en) 367 | 368 | if self.use_tb: 369 | metrics['batch_reward'] = reward.mean().item() 370 | 371 | # update critic 372 | metrics.update( 373 | self.update_critic(obs_en, action, reward, discount, next_obs_en, step)) 374 | 375 | # update actor 376 | metrics.update(self.update_actor(obs_en.detach(), step)) 377 | 378 | # update critic target 379 | utils.soft_update_params(self.critic, self.critic_target, 380 | self.critic_target_tau) 381 | 382 | metrics.update(self.update_taco(obs, action, action_seq, r_next_obs, reward)) 383 | 384 | return metrics 385 | 
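A note on the hard-coded `repr_dim = 32 * 35 * 35` used by the `Encoder` in both `agents/drqv2.py` and `agents/taco.py`: it follows from the 84×84 frames set up in `dmc.py` together with `frame_stack: 3` (three stacked RGB frames, i.e. 9 input channels). A quick standalone check of those shapes, with an arbitrary batch size of 4:

```
import torch
import torch.nn as nn

# Same conv stack as Encoder in agents/drqv2.py and agents/taco.py.
convnet = nn.Sequential(
    nn.Conv2d(9, 32, 3, stride=2), nn.ReLU(),   # 84 -> 41
    nn.Conv2d(32, 32, 3, stride=1), nn.ReLU(),  # 41 -> 39
    nn.Conv2d(32, 32, 3, stride=1), nn.ReLU(),  # 39 -> 37
    nn.Conv2d(32, 32, 3, stride=1), nn.ReLU(),  # 37 -> 35
)

obs = torch.zeros(4, 9, 84, 84)                 # frame_stack=3 RGB frames at 84x84
flat = convnet(obs).flatten(start_dim=1)
assert flat.shape[1] == 32 * 35 * 35            # repr_dim == 39200
```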
-------------------------------------------------------------------------------- /cfgs/agent/drqv2.yaml: -------------------------------------------------------------------------------- 1 | agent: 2 | _target_: agents.drqv2.DrQV2Agent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? # to be specified later 5 | device: ${device} 6 | lr: ${lr} 7 | critic_target_tau: 0.01 8 | update_every_steps: 2 9 | use_tb: ${use_tb} 10 | num_expl_steps: 2000 11 | hidden_dim: 1024 12 | feature_dim: ${feature_dim} 13 | stddev_schedule: ${stddev_schedule} 14 | stddev_clip: 0.3 15 | 16 | batch_size: 256 -------------------------------------------------------------------------------- /cfgs/agent/taco.yaml: -------------------------------------------------------------------------------- 1 | agent: 2 | _target_: agents.taco.TACOAgent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? # to be specified later 5 | device: ${device} 6 | lr: ${lr} 7 | encoder_lr: ${encoder_lr} 8 | critic_target_tau: 0.01 9 | update_every_steps: 2 10 | use_tb: ${use_tb} 11 | num_expl_steps: 2000 12 | hidden_dim: 1024 13 | feature_dim: ${feature_dim} 14 | stddev_schedule: ${stddev_schedule} 15 | stddev_clip: 0.3 16 | curl: ${curl} 17 | reward: ${reward} 18 | multistep: ${multistep} 19 | latent_a_dim: ${latent_a_dim} 20 | 21 | ### TACO parameters 22 | curl: true 23 | reward: true 24 | multistep: 3 25 | latent_a_dim: none 26 | batch_size: 1024 -------------------------------------------------------------------------------- /cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task@_global_: quadruped_walk 4 | - agent@_global_: taco 5 | - override hydra/launcher: submitit_local 6 | 7 | domain: dmc 8 | # task settings 9 | frame_stack: 3 10 | action_repeat: 2 11 | discount: 0.99 12 | # train settings 13 | num_seed_frames: 4000 14 | # eval 15 | eval_every_frames: 10000 16 | num_eval_episodes: 10 17 | # snapshot 18 | save_snapshot: true 19 | # replay buffer 20 | replay_buffer_size: 1000000 21 | replay_buffer_num_workers: 4 22 | nstep: 3 23 | batch_size: 256 24 | # misc 25 | seed: 1 26 | device: cuda 27 | save_video: false 28 | save_train_video: false 29 | use_tb: false 30 | # experiment 31 | experiment: exp 32 | # agent 33 | lr: 1e-4 34 | encoder_lr: 1e-4 35 | feature_dim: 50 36 | exp_name: default 37 | 38 | 39 | hydra: 40 | run: 41 | dir: ./exp_local/${exp_name} 42 | sweep: 43 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment} 44 | subdir: ${hydra.job.num} 45 | launcher: 46 | timeout_min: 4300 47 | cpus_per_task: 10 48 | gpus_per_node: 1 49 | tasks_per_node: 1 50 | mem_gb: 160 51 | nodes: 1 52 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm -------------------------------------------------------------------------------- /cfgs/task/acrobot_swingup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: acrobot_swingup 6 | -------------------------------------------------------------------------------- /cfgs/task/acrobot_swingup_sparse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: acrobot_swingup_sparse 6 | -------------------------------------------------------------------------------- /cfgs/task/cartpole_balance.yaml: -------------------------------------------------------------------------------- 
1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cartpole_balance 6 | -------------------------------------------------------------------------------- /cfgs/task/cartpole_balance_sparse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cartpole_balance_sparse 6 | -------------------------------------------------------------------------------- /cfgs/task/cartpole_swingup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cartpole_swingup 6 | -------------------------------------------------------------------------------- /cfgs/task/cartpole_swingup_sparse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: cartpole_swingup_sparse 6 | -------------------------------------------------------------------------------- /cfgs/task/cheetah_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: cheetah_run 6 | -------------------------------------------------------------------------------- /cfgs/task/cup_catch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cup_catch 6 | -------------------------------------------------------------------------------- /cfgs/task/easy.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 1100000 2 | stddev_schedule: 'linear(1.0,0.1,100000)' -------------------------------------------------------------------------------- /cfgs/task/finger_spin.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: finger_spin 6 | -------------------------------------------------------------------------------- /cfgs/task/finger_turn_easy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: finger_turn_easy -------------------------------------------------------------------------------- /cfgs/task/finger_turn_hard.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: finger_turn_hard 6 | -------------------------------------------------------------------------------- /cfgs/task/hard.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 30100000 2 | stddev_schedule: 'linear(1.0,0.1,2000000)' 3 | -------------------------------------------------------------------------------- /cfgs/task/hopper_hop.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: hopper_hop 6 | -------------------------------------------------------------------------------- /cfgs/task/hopper_stand.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: hopper_stand 6 | -------------------------------------------------------------------------------- /cfgs/task/manipulator_bring_ball.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hard 3 | 
- _self_ 4 | 5 | task_name: manipulator_bring_ball 6 | -------------------------------------------------------------------------------- /cfgs/task/medium.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 3100000 2 | stddev_schedule: 'linear(1.0,0.1,500000)' 3 | -------------------------------------------------------------------------------- /cfgs/task/pendulum_swingup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: pendulum_swingup 6 | -------------------------------------------------------------------------------- /cfgs/task/quadruped_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: quadruped_run 6 | replay_buffer_size: 100000 -------------------------------------------------------------------------------- /cfgs/task/quadruped_walk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: quadruped_walk 6 | -------------------------------------------------------------------------------- /cfgs/task/reach_duplo.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: reach_duplo 6 | -------------------------------------------------------------------------------- /cfgs/task/reacher_easy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: reacher_easy 6 | -------------------------------------------------------------------------------- /cfgs/task/reacher_hard.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: reacher_hard 6 | -------------------------------------------------------------------------------- /cfgs/task/walker_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: walker_run 6 | nstep: 1 7 | batch_size: 512 -------------------------------------------------------------------------------- /cfgs/task/walker_stand.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: walker_stand 6 | nstep: 1 7 | batch_size: 512 8 | -------------------------------------------------------------------------------- /cfgs/task/walker_walk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: walker_walk 6 | nstep: 1 7 | batch_size: 512 8 | -------------------------------------------------------------------------------- /dmc.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Any, NamedTuple 3 | 4 | import dm_env 5 | import numpy as np 6 | from dm_control import manipulation, suite 7 | from dm_control.suite.wrappers import action_scale, pixels 8 | from dm_env import StepType, specs 9 | 10 | 11 | class ExtendedTimeStep(NamedTuple): 12 | step_type: Any 13 | reward: Any 14 | discount: Any 15 | observation: Any 16 | action: Any 17 | 18 | def first(self): 19 | return self.step_type == StepType.FIRST 20 | 21 | def mid(self): 22 | return 
self.step_type == StepType.MID 23 | 24 | def last(self): 25 | return self.step_type == StepType.LAST 26 | 27 | def __getitem__(self, attr): 28 | if isinstance(attr, str): 29 | return getattr(self, attr) 30 | else: 31 | return tuple.__getitem__(self, attr) 32 | 33 | 34 | class ActionRepeatWrapper(dm_env.Environment): 35 | def __init__(self, env, num_repeats): 36 | self._env = env 37 | self._num_repeats = num_repeats 38 | 39 | def step(self, action): 40 | reward = 0.0 41 | discount = 1.0 42 | for i in range(self._num_repeats): 43 | time_step = self._env.step(action) 44 | reward += (time_step.reward or 0.0) * discount 45 | discount *= time_step.discount 46 | if time_step.last(): 47 | break 48 | 49 | return time_step._replace(reward=reward, discount=discount) 50 | 51 | def observation_spec(self): 52 | return self._env.observation_spec() 53 | 54 | def action_spec(self): 55 | return self._env.action_spec() 56 | 57 | def reset(self): 58 | return self._env.reset() 59 | 60 | def __getattr__(self, name): 61 | return getattr(self._env, name) 62 | 63 | 64 | class FrameStackWrapper(dm_env.Environment): 65 | def __init__(self, env, num_frames, pixels_key='pixels'): 66 | self._env = env 67 | self._num_frames = num_frames 68 | self._frames = deque([], maxlen=num_frames) 69 | self._pixels_key = pixels_key 70 | 71 | wrapped_obs_spec = env.observation_spec() 72 | assert pixels_key in wrapped_obs_spec 73 | 74 | pixels_shape = wrapped_obs_spec[pixels_key].shape 75 | # remove batch dim 76 | if len(pixels_shape) == 4: 77 | pixels_shape = pixels_shape[1:] 78 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 79 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 80 | dtype=np.uint8, 81 | minimum=0, 82 | maximum=255, 83 | name='observation') 84 | 85 | def _transform_observation(self, time_step): 86 | assert len(self._frames) == self._num_frames 87 | obs = np.concatenate(list(self._frames), axis=0) 88 | return time_step._replace(observation=obs) 89 | 90 | def _extract_pixels(self, time_step): 91 | pixels = time_step.observation[self._pixels_key] 92 | # remove batch dim 93 | if len(pixels.shape) == 4: 94 | pixels = pixels[0] 95 | return pixels.transpose(2, 0, 1).copy() 96 | 97 | def reset(self): 98 | time_step = self._env.reset() 99 | pixels = self._extract_pixels(time_step) 100 | for _ in range(self._num_frames): 101 | self._frames.append(pixels) 102 | return self._transform_observation(time_step) 103 | 104 | def step(self, action): 105 | time_step = self._env.step(action) 106 | pixels = self._extract_pixels(time_step) 107 | self._frames.append(pixels) 108 | return self._transform_observation(time_step) 109 | 110 | def observation_spec(self): 111 | return self._obs_spec 112 | 113 | def action_spec(self): 114 | return self._env.action_spec() 115 | 116 | def __getattr__(self, name): 117 | return getattr(self._env, name) 118 | 119 | 120 | class ActionDTypeWrapper(dm_env.Environment): 121 | def __init__(self, env, dtype): 122 | self._env = env 123 | wrapped_action_spec = env.action_spec() 124 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, 125 | dtype, 126 | wrapped_action_spec.minimum, 127 | wrapped_action_spec.maximum, 128 | 'action') 129 | 130 | def step(self, action): 131 | action = action.astype(self._env.action_spec().dtype) 132 | return self._env.step(action) 133 | 134 | def observation_spec(self): 135 | return self._env.observation_spec() 136 | 137 | def action_spec(self): 138 | return self._action_spec 139 | 140 | def reset(self): 141 | return self._env.reset() 142 | 143 
| def __getattr__(self, name): 144 | return getattr(self._env, name) 145 | 146 | 147 | class ExtendedTimeStepWrapper(dm_env.Environment): 148 | def __init__(self, env): 149 | self._env = env 150 | 151 | def reset(self): 152 | time_step = self._env.reset() 153 | return self._augment_time_step(time_step) 154 | 155 | def step(self, action): 156 | time_step = self._env.step(action) 157 | return self._augment_time_step(time_step, action) 158 | 159 | def _augment_time_step(self, time_step, action=None): 160 | if action is None: 161 | action_spec = self.action_spec() 162 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 163 | return ExtendedTimeStep(observation=time_step.observation, 164 | step_type=time_step.step_type, 165 | action=action, 166 | reward=time_step.reward or 0.0, 167 | discount=time_step.discount or 1.0) 168 | 169 | def observation_spec(self): 170 | return self._env.observation_spec() 171 | 172 | def action_spec(self): 173 | return self._env.action_spec() 174 | 175 | def __getattr__(self, name): 176 | return getattr(self._env, name) 177 | 178 | 179 | def make(name, frame_stack, action_repeat, seed): 180 | domain, task = name.split('_', 1) 181 | # overwrite cup to ball_in_cup 182 | domain = dict(cup='ball_in_cup').get(domain, domain) 183 | # make sure reward is not visualized 184 | if (domain, task) in suite.ALL_TASKS: 185 | env = suite.load(domain, 186 | task, 187 | task_kwargs={'random': seed}, 188 | visualize_reward=False) 189 | pixels_key = 'pixels' 190 | else: 191 | name = f'{domain}_{task}_vision' 192 | env = manipulation.load(name, seed=seed) 193 | pixels_key = 'front_close' 194 | # add wrappers 195 | env = ActionDTypeWrapper(env, np.float32) 196 | env = ActionRepeatWrapper(env, action_repeat) 197 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 198 | # add renderings for clasical tasks 199 | if (domain, task) in suite.ALL_TASKS: 200 | # zoom in camera for quadruped 201 | camera_id = dict(quadruped=2).get(domain, 0) 202 | render_kwargs = dict(height=84, width=84, camera_id=camera_id) 203 | env = pixels.Wrapper(env, 204 | pixels_only=True, 205 | render_kwargs=render_kwargs) 206 | # stack several frames 207 | env = FrameStackWrapper(env, frame_stack, pixels_key) 208 | env = ExtendedTimeStepWrapper(env) 209 | return env 210 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: taco 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.8 6 | - pip=21.1.3 7 | - numpy=1.19.2 8 | - absl-py=0.13.0 9 | - pyparsing=2.4.7 10 | - jupyterlab=3.0.14 11 | - scikit-image=0.18.1 12 | - nvidia::cudatoolkit=11.1 13 | - pytorch::pytorch 14 | - pytorch::torchvision 15 | - pytorch::torchaudio 16 | - pip: 17 | - chardet 18 | - gym 19 | - termcolor==1.1.0 20 | - dm_control 21 | - tb-nightly 22 | - imageio==2.9.0 23 | - imageio-ffmpeg==0.4.4 24 | - hydra-core==1.1.0 25 | - hydra-submitit-launcher==1.1.5 26 | - pandas==1.3.0 27 | - ipdb==0.13.9 28 | - yapf==0.31.0 29 | - mujoco_py==2.0.2.13 30 | - sklearn==0.0 31 | - matplotlib==3.4.2 32 | - opencv-python==4.5.3.56 33 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from termcolor import colored 9 | from 
torch.utils.tensorboard import SummaryWriter 10 | 11 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 12 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 13 | ('episode_reward', 'R', 'float'), 14 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'), 15 | ('total_time', 'T', 'time')] 16 | 17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 19 | ('episode_reward', 'R', 'float'), 20 | ('total_time', 'T', 'time')] 21 | 22 | 23 | class AverageMeter(object): 24 | def __init__(self): 25 | self._sum = 0 26 | self._count = 0 27 | 28 | def update(self, value, n=1): 29 | self._sum += value 30 | self._count += n 31 | 32 | def value(self): 33 | return self._sum / max(1, self._count) 34 | 35 | 36 | class MetersGroup(object): 37 | def __init__(self, csv_file_name, formating): 38 | self._csv_file_name = csv_file_name 39 | self._formating = formating 40 | self._meters = defaultdict(AverageMeter) 41 | self._csv_file = None 42 | self._csv_writer = None 43 | 44 | def log(self, key, value, n=1): 45 | self._meters[key].update(value, n) 46 | 47 | def _prime_meters(self): 48 | data = dict() 49 | for key, meter in self._meters.items(): 50 | if key.startswith('train'): 51 | key = key[len('train') + 1:] 52 | else: 53 | key = key[len('eval') + 1:] 54 | key = key.replace('/', '_') 55 | data[key] = meter.value() 56 | return data 57 | 58 | def _remove_old_entries(self, data): 59 | rows = [] 60 | with self._csv_file_name.open('r') as f: 61 | reader = csv.DictReader(f) 62 | for row in reader: 63 | if float(row['episode']) >= data['episode']: 64 | break 65 | rows.append(row) 66 | with self._csv_file_name.open('w') as f: 67 | writer = csv.DictWriter(f, 68 | fieldnames=sorted(data.keys()), 69 | restval=0.0) 70 | writer.writeheader() 71 | for row in rows: 72 | writer.writerow(row) 73 | 74 | def _dump_to_csv(self, data): 75 | if self._csv_writer is None: 76 | should_write_header = True 77 | if self._csv_file_name.exists(): 78 | self._remove_old_entries(data) 79 | should_write_header = False 80 | 81 | self._csv_file = self._csv_file_name.open('a') 82 | self._csv_writer = csv.DictWriter(self._csv_file, 83 | fieldnames=sorted(data.keys()), 84 | restval=0.0) 85 | if should_write_header: 86 | self._csv_writer.writeheader() 87 | 88 | self._csv_writer.writerow(data) 89 | self._csv_file.flush() 90 | 91 | def _format(self, key, value, ty): 92 | if ty == 'int': 93 | value = int(value) 94 | return f'{key}: {value}' 95 | elif ty == 'float': 96 | return f'{key}: {value:.04f}' 97 | elif ty == 'time': 98 | value = str(datetime.timedelta(seconds=int(value))) 99 | return f'{key}: {value}' 100 | else: 101 | raise f'invalid format type: {ty}' 102 | 103 | def _dump_to_console(self, data, prefix): 104 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 105 | pieces = [f'| {prefix: <14}'] 106 | for key, disp_key, ty in self._formating: 107 | value = data.get(key, 0) 108 | pieces.append(self._format(disp_key, value, ty)) 109 | print(' | '.join(pieces), flush=True) 110 | 111 | def dump(self, step, prefix): 112 | if len(self._meters) == 0: 113 | return 114 | data = self._prime_meters() 115 | data['frame'] = step 116 | self._dump_to_csv(data) 117 | self._dump_to_console(data, prefix) 118 | self._meters.clear() 119 | 120 | 121 | class Logger(object): 122 | def __init__(self, log_dir, use_tb): 123 | self._log_dir = log_dir 124 | self._train_mg = MetersGroup(log_dir / 'train.csv', 125 | formating=COMMON_TRAIN_FORMAT) 126 | 
self._eval_mg = MetersGroup(log_dir / 'eval.csv', 127 | formating=COMMON_EVAL_FORMAT) 128 | if use_tb: 129 | self._sw = SummaryWriter(str(log_dir / 'tb')) 130 | else: 131 | self._sw = None 132 | 133 | def _try_sw_log(self, key, value, step): 134 | if self._sw is not None: 135 | self._sw.add_scalar(key, value, step) 136 | 137 | def log(self, key, value, step): 138 | assert key.startswith('train') or key.startswith('eval') 139 | if type(value) == torch.Tensor: 140 | value = value.item() 141 | self._try_sw_log(key, value, step) 142 | mg = self._train_mg if key.startswith('train') else self._eval_mg 143 | mg.log(key, value) 144 | 145 | def log_metrics(self, metrics, step, ty): 146 | for key, value in metrics.items(): 147 | self.log(f'{ty}/{key}', value, step) 148 | 149 | def dump(self, step, ty=None): 150 | if ty is None or ty == 'eval': 151 | self._eval_mg.dump(step, 'eval') 152 | if ty is None or ty == 'train': 153 | self._train_mg.dump(step, 'train') 154 | 155 | def log_and_dump_ctx(self, step, ty): 156 | return LogAndDumpCtx(self, step, ty) 157 | 158 | 159 | class LogAndDumpCtx: 160 | def __init__(self, logger, step, ty): 161 | self._logger = logger 162 | self._step = step 163 | self._ty = ty 164 | 165 | def __enter__(self): 166 | return self 167 | 168 | def __call__(self, key, value): 169 | self._logger.log(f'{self._ty}/{key}', value, self._step) 170 | 171 | def __exit__(self, *args): 172 | self._logger.dump(self._step, self._ty) -------------------------------------------------------------------------------- /media/dmc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankZheng2022/TACO/84c38e34f4f9dfd2b059fb6d1356757e8d40712e/media/dmc.gif -------------------------------------------------------------------------------- /media/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankZheng2022/TACO/84c38e34f4f9dfd2b059fb6d1356757e8d40712e/media/overview.png -------------------------------------------------------------------------------- /media/taco.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrankZheng2022/TACO/84c38e34f4f9dfd2b059fb6d1356757e8d40712e/media/taco.jpg -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import IterableDataset 11 | 12 | 13 | def episode_len(episode): 14 | # subtract -1 because the dummy first transition 15 | return next(iter(episode.values())).shape[0] - 1 16 | 17 | 18 | def save_episode(episode, fn): 19 | with io.BytesIO() as bs: 20 | np.savez_compressed(bs, **episode) 21 | bs.seek(0) 22 | with fn.open('wb') as f: 23 | f.write(bs.read()) 24 | 25 | 26 | def load_episode(fn): 27 | with fn.open('rb') as f: 28 | episode = np.load(f) 29 | episode = {k: episode[k] for k in episode.keys()} 30 | return episode 31 | 32 | 33 | class ReplayBufferStorage: 34 | def __init__(self, data_specs, replay_dir): 35 | self._data_specs = data_specs 36 | self._replay_dir = replay_dir 37 | replay_dir.mkdir(exist_ok=True) 38 | self._current_episode = defaultdict(list) 39 | self._preload() 40 | 41 | def __len__(self): 42 | return 
self._num_transitions 43 | 44 | def add(self, time_step): 45 | for spec in self._data_specs: 46 | value = time_step[spec.name] 47 | if np.isscalar(value): 48 | value = np.full(spec.shape, value, spec.dtype) 49 | assert spec.shape == value.shape and spec.dtype == value.dtype 50 | self._current_episode[spec.name].append(value) 51 | if time_step.last(): 52 | episode = dict() 53 | for spec in self._data_specs: 54 | value = self._current_episode[spec.name] 55 | episode[spec.name] = np.array(value, spec.dtype) 56 | self._current_episode = defaultdict(list) 57 | self._store_episode(episode) 58 | 59 | def _preload(self): 60 | self._num_episodes = 0 61 | self._num_transitions = 0 62 | for fn in self._replay_dir.glob('*.npz'): 63 | _, _, eps_len = fn.stem.split('_') 64 | self._num_episodes += 1 65 | self._num_transitions += int(eps_len) 66 | 67 | def _store_episode(self, episode): 68 | eps_idx = self._num_episodes 69 | eps_len = episode_len(episode) 70 | self._num_episodes += 1 71 | self._num_transitions += eps_len 72 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 73 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 74 | save_episode(episode, self._replay_dir / eps_fn) 75 | 76 | 77 | class ReplayBuffer(IterableDataset): 78 | def __init__(self, replay_dir, max_size, num_workers, nstep, multistep, 79 | discount, fetch_every, save_snapshot): 80 | self._replay_dir = replay_dir 81 | self._size = 0 82 | self._max_size = max_size 83 | self._num_workers = max(1, num_workers) 84 | self._episode_fns = [] 85 | self._episodes = dict() 86 | self._nstep = nstep 87 | self._discount = discount 88 | self._fetch_every = fetch_every 89 | self._samples_since_last_fetch = fetch_every 90 | self._save_snapshot = save_snapshot 91 | self._multistep = multistep 92 | print('Loading Data into CPU Memory') 93 | self._preload() 94 | 95 | def _sample_episode(self): 96 | eps_fn = random.choice(self._episode_fns) 97 | return self._episodes[eps_fn] 98 | 99 | def _store_episode(self, eps_fn): 100 | try: 101 | episode = load_episode(eps_fn) 102 | except: 103 | return False 104 | eps_len = episode_len(episode) 105 | while eps_len + self._size > self._max_size: 106 | early_eps_fn = self._episode_fns.pop(0) 107 | early_eps = self._episodes.pop(early_eps_fn) 108 | self._size -= episode_len(early_eps) 109 | early_eps_fn.unlink(missing_ok=True) 110 | self._episode_fns.append(eps_fn) 111 | self._episode_fns.sort() 112 | self._episodes[eps_fn] = episode 113 | self._size += eps_len 114 | 115 | if not self._save_snapshot: 116 | eps_fn.unlink(missing_ok=True) 117 | return True 118 | 119 | def _try_fetch(self): 120 | if self._samples_since_last_fetch < self._fetch_every: 121 | return 122 | self._samples_since_last_fetch = 0 123 | try: 124 | worker_id = torch.utils.data.get_worker_info().id 125 | except: 126 | worker_id = 0 127 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True) 128 | fetched_size = 0 129 | for eps_fn in eps_fns: 130 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 131 | if eps_idx % self._num_workers != worker_id: 132 | continue 133 | if eps_fn in self._episodes.keys(): 134 | break 135 | if fetched_size + eps_len > self._max_size: 136 | break 137 | fetched_size += eps_len 138 | if not self._store_episode(eps_fn): 139 | break 140 | 141 | def _preload(self): 142 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True) 143 | for eps_fn in eps_fns: 144 | self._store_episode(eps_fn) 145 | 146 | def _sample(self): 147 | try: 148 | self._try_fetch() 149 | except: 150 | traceback.print_exc() 151 | 
self._samples_since_last_fetch += 1 152 | episode = self._sample_episode() 153 | # add +1 for the first dummy transition 154 | n_step = max(self._nstep, self._multistep) 155 | idx = np.random.randint(0, episode_len(episode) - n_step + 1) + 1 156 | obs = episode['observation'][idx - 1] 157 | r_next_obs = episode['observation'][idx + self._multistep - 1] 158 | action = episode['action'][idx] 159 | action_seq = np.concatenate([episode['action'][idx+i][None, :] for i in range(self._multistep)]) 160 | next_obs = episode['observation'][idx + self._nstep - 1] 161 | reward = np.zeros_like(episode['reward'][idx]) 162 | discount = np.ones_like(episode['discount'][idx]) 163 | for i in range(self._nstep): 164 | step_reward = episode['reward'][idx + i] 165 | reward += discount * step_reward 166 | discount *= episode['discount'][idx + i] * self._discount 167 | return (obs, action, action_seq, reward, discount, next_obs, r_next_obs) 168 | 169 | def __iter__(self): 170 | while True: 171 | yield self._sample() 172 | 173 | 174 | def _worker_init_fn(worker_id): 175 | seed = np.random.get_state()[1][0] + worker_id 176 | np.random.seed(seed) 177 | random.seed(seed) 178 | 179 | 180 | def make_replay_loader(replay_dir, max_size, batch_size, num_workers, 181 | save_snapshot, nstep, multistep, discount): 182 | max_size_per_worker = max_size // max(1, num_workers) 183 | 184 | iterable = ReplayBuffer(replay_dir, 185 | max_size_per_worker, 186 | num_workers, 187 | nstep, 188 | multistep, 189 | discount, 190 | fetch_every=1000, 191 | save_snapshot=save_snapshot) 192 | 193 | loader = torch.utils.data.DataLoader(iterable, 194 | batch_size=batch_size, 195 | num_workers=num_workers, 196 | pin_memory=True, 197 | worker_init_fn=_worker_init_fn) 198 | return loader 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore', category=DeprecationWarning) 3 | 4 | import os 5 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 6 | os.environ['MUJOCO_GL'] = 'egl' 7 | 8 | from pathlib import Path 9 | 10 | import hydra 11 | import numpy as np 12 | import torch 13 | from dm_env import specs 14 | 15 | import dmc 16 | import utils 17 | from logger import Logger 18 | from replay_buffer import ReplayBufferStorage, make_replay_loader 19 | from video import TrainVideoRecorder, VideoRecorder 20 | 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def make_agent(obs_spec, action_spec, cfg): 25 | cfg.obs_shape = obs_spec.shape 26 | cfg.action_shape = action_spec.shape 27 | return hydra.utils.instantiate(cfg) 28 | 29 | 30 | class Workspace: 31 | def __init__(self, cfg): 32 | self.work_dir = Path.cwd() 33 | print(f'workspace: {self.work_dir}') 34 | 35 | self.cfg = cfg 36 | utils.set_seed_everywhere(cfg.seed) 37 | self.device = torch.device(cfg.device) 38 | self.setup() 39 | 40 | self.agent = make_agent(self.train_env.observation_spec(), 41 | self.train_env.action_spec(), 42 | self.cfg.agent) 43 | self.timer = utils.Timer() 44 | self._global_step = 0 45 | self._global_episode = 0 46 | 47 | def setup(self): 48 | # create logger 49 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 50 | self.train_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 51 | self.cfg.action_repeat, self.cfg.seed) 52 | self.eval_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 53 | self.cfg.action_repeat, self.cfg.seed) 54 | # 
create replay buffer 55 | data_specs = (self.train_env.observation_spec(), 56 | self.train_env.action_spec(), 57 | specs.Array((1,), np.float32, 'reward'), 58 | specs.Array((1,), np.float32, 'discount')) 59 | 60 | self.replay_storage = ReplayBufferStorage(data_specs, 61 | self.work_dir / 'buffer') 62 | 63 | self.replay_loader = make_replay_loader( 64 | self.work_dir / 'buffer', self.cfg.replay_buffer_size, 65 | self.cfg.batch_size, self.cfg.replay_buffer_num_workers, 66 | self.cfg.save_snapshot, self.cfg.nstep, self.cfg.multistep, self.cfg.discount) 67 | self._replay_iter = None 68 | 69 | self.video_recorder = VideoRecorder( 70 | self.work_dir if self.cfg.save_video else None) 71 | self.train_video_recorder = TrainVideoRecorder( 72 | self.work_dir if self.cfg.save_train_video else None) 73 | 74 | 75 | @property 76 | def global_step(self): 77 | return self._global_step 78 | 79 | @property 80 | def global_episode(self): 81 | return self._global_episode 82 | 83 | @property 84 | def global_frame(self): 85 | return self.global_step * self.cfg.action_repeat 86 | 87 | @property 88 | def replay_iter(self): 89 | if self._replay_iter is None: 90 | self._replay_iter = iter(self.replay_loader) 91 | return self._replay_iter 92 | 93 | def eval(self): 94 | step, episode, total_reward, success = 0, 0, 0, 0 95 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 96 | 97 | while eval_until_episode(episode): 98 | time_step = self.eval_env.reset() 99 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 100 | while not time_step.last(): 101 | with torch.no_grad(), utils.eval_mode(self.agent): 102 | action = self.agent.act(time_step.observation, 103 | self.global_step, 104 | eval_mode=True) 105 | time_step = self.eval_env.step(action) 106 | self.video_recorder.record(self.eval_env) 107 | total_reward += time_step.reward 108 | step += 1 109 | 110 | episode += 1 111 | self.video_recorder.save(f'{self.global_frame}.mp4') 112 | 113 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 114 | log('episode_reward', total_reward / episode) 115 | log('episode_length', step * self.cfg.action_repeat / episode) 116 | log('episode', self.global_episode) 117 | log('step', self.global_step) 118 | 119 | def train(self): 120 | # predicates 121 | train_until_step = utils.Until(self.cfg.num_train_frames, 122 | self.cfg.action_repeat) 123 | seed_until_step = utils.Until(self.cfg.num_seed_frames, 124 | self.cfg.action_repeat) 125 | eval_every_step = utils.Every(self.cfg.eval_every_frames, 126 | self.cfg.action_repeat) 127 | 128 | episode_step, episode_reward = 0, 0 129 | time_step = self.train_env.reset() 130 | self.replay_storage.add(time_step) 131 | self.train_video_recorder.init(time_step.observation) 132 | metrics = None 133 | while train_until_step(self.global_step): 134 | if time_step.last(): 135 | self._global_episode += 1 136 | self.train_video_recorder.save(f'{self.global_frame}.mp4') 137 | # wait until all the metrics schema is populated 138 | if metrics is not None: 139 | # log stats 140 | elapsed_time, total_time = self.timer.reset() 141 | episode_frame = episode_step * self.cfg.action_repeat 142 | with self.logger.log_and_dump_ctx(self.global_frame, 143 | ty='train') as log: 144 | log('fps', episode_frame / elapsed_time) 145 | log('total_time', total_time) 146 | log('episode_reward', episode_reward) 147 | log('episode_length', episode_frame) 148 | log('episode', self.global_episode) 149 | log('buffer_size', len(self.replay_storage)) 150 | log('step', self.global_step) 151 | 
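                # End-of-episode bookkeeping: reset the environment, push the
                # first (dummy) time step into the replay storage so the stored
                # observation/action/reward arrays stay aligned, optionally save
                # a snapshot, and clear the per-episode counters.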
152 | # reset env 153 | time_step = self.train_env.reset() 154 | self.replay_storage.add(time_step) 155 | self.train_video_recorder.init(time_step.observation) 156 | # try to save snapshot 157 | if self.cfg.save_snapshot: 158 | self.save_snapshot() 159 | episode_step = 0 160 | episode_reward = 0 161 | 162 | # try to evaluate 163 | if eval_every_step(self.global_step): 164 | self.logger.log('eval_total_time', self.timer.total_time(), 165 | self.global_frame) 166 | self.eval() 167 | 168 | # sample action 169 | with torch.no_grad(), utils.eval_mode(self.agent): 170 | action = self.agent.act(time_step.observation, 171 | self.global_step, 172 | eval_mode=False) 173 | 174 | # try to update the agent 175 | if not seed_until_step(self.global_step): 176 | metrics = self.agent.update(self.replay_iter, self.global_step) 177 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 178 | 179 | # take env step 180 | time_step = self.train_env.step(action) 181 | episode_reward += time_step.reward 182 | self.replay_storage.add(time_step) 183 | self.train_video_recorder.record(time_step.observation) 184 | episode_step += 1 185 | self._global_step += 1 186 | 187 | def save_snapshot(self): 188 | snapshot = self.work_dir / 'snapshot.pt' 189 | keys_to_save = ['agent', 'timer', '_global_step', '_global_episode'] 190 | payload = {k: self.__dict__[k] for k in keys_to_save} 191 | with snapshot.open('wb') as f: 192 | torch.save(payload, f) 193 | 194 | def load_snapshot(self): 195 | snapshot = self.work_dir / 'snapshot.pt' 196 | with snapshot.open('rb') as f: 197 | payload = torch.load(f) 198 | for k, v in payload.items(): 199 | self.__dict__[k] = v 200 | 201 | 202 | @hydra.main(config_path='cfgs', config_name='config') 203 | def main(cfg): 204 | from train import Workspace as W 205 | root_dir = Path.cwd() 206 | workspace = W(cfg) 207 | snapshot = root_dir / 'snapshot.pt' 208 | if snapshot.exists(): 209 | print(f'resuming: {snapshot}') 210 | workspace.load_snapshot() 211 | workspace.train() 212 | 213 | 214 | if __name__ == '__main__': 215 | main() 216 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from omegaconf import OmegaConf 10 | from torch import distributions as pyd 11 | from torch.distributions.utils import _standard_normal 12 | 13 | ### input shape: (batch_size, length, action_dim) 14 | ### output shape: (batch_size, action_dim) 15 | class ActionEncoding(nn.Module): 16 | def __init__(self, action_dim, latent_action_dim, multistep): 17 | super().__init__() 18 | self.action_dim = action_dim 19 | self.action_tokenizer = nn.Sequential( 20 | nn.Linear(action_dim, 64), nn.Tanh(), 21 | nn.Linear(64, latent_action_dim) 22 | ) 23 | self.action_seq_tokenizer = nn.Sequential( 24 | nn.Linear(latent_action_dim*multistep, latent_action_dim*multistep), 25 | nn.LayerNorm(latent_action_dim*multistep), nn.Tanh() 26 | ) 27 | self.apply(weight_init) 28 | 29 | def forward(self, action, seq=False): 30 | if seq: 31 | batch_size = action.shape[0] 32 | action = self.action_tokenizer(action) #(batch_size, length_action_dim) 33 | action = action.reshape(batch_size, -1) 34 | return self.action_seq_tokenizer(action) 35 | else: 36 | return self.action_tokenizer(action) 37 | 38 | 39 | 40 | class eval_mode: 41 | def __init__(self, *models): 42 | 
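        # Keep references to the wrapped models; train.py uses this context
        # manager as `with utils.eval_mode(self.agent): ...` so that action
        # selection runs with training disabled, and __exit__ restores each
        # model's previous training flag afterwards.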
self.models = models 43 | 44 | def __enter__(self): 45 | self.prev_states = [] 46 | for model in self.models: 47 | self.prev_states.append(model.training) 48 | model.train(False) 49 | 50 | def __exit__(self, *args): 51 | for model, state in zip(self.models, self.prev_states): 52 | model.train(state) 53 | return False 54 | 55 | 56 | def set_seed_everywhere(seed): 57 | torch.manual_seed(seed) 58 | if torch.cuda.is_available(): 59 | torch.cuda.manual_seed_all(seed) 60 | np.random.seed(seed) 61 | random.seed(seed) 62 | 63 | 64 | def soft_update_params(net, target_net, tau): 65 | for param, target_param in zip(net.parameters(), target_net.parameters()): 66 | target_param.data.copy_(tau * param.data + 67 | (1 - tau) * target_param.data) 68 | 69 | def expectile_loss(diff, expectile=0.8): 70 | weight = torch.where(diff > 0, expectile, (1 - expectile)) 71 | return weight * (diff**2) 72 | 73 | 74 | def to_torch(xs, device): 75 | return tuple(torch.as_tensor(x, device=device) for x in xs) 76 | 77 | def encode_multiple(encoder, xs, detach_lst): 78 | length = [x.shape[0] for x in xs] 79 | xs, xs_lst = torch.cat(xs, dim=0), [] 80 | xs = encoder(xs) 81 | start = 0 82 | for i in range(len(detach_lst)): 83 | x = xs[start:start+length[i], :] 84 | if detach_lst[i]: 85 | x = x.detach() 86 | xs_lst.append(x) 87 | start += length[i] 88 | return xs_lst 89 | 90 | 91 | 92 | def weight_init(m): 93 | if isinstance(m, nn.Linear): 94 | nn.init.orthogonal_(m.weight.data) 95 | if hasattr(m.bias, 'data'): 96 | m.bias.data.fill_(0.0) 97 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 98 | gain = nn.init.calculate_gain('relu') 99 | nn.init.orthogonal_(m.weight.data, gain) 100 | if hasattr(m.bias, 'data'): 101 | m.bias.data.fill_(0.0) 102 | 103 | 104 | class Until: 105 | def __init__(self, until, action_repeat=1): 106 | self._until = until 107 | self._action_repeat = action_repeat 108 | 109 | def __call__(self, step): 110 | if self._until is None: 111 | return True 112 | until = self._until // self._action_repeat 113 | return step < until 114 | 115 | 116 | class Every: 117 | def __init__(self, every, action_repeat=1): 118 | self._every = every 119 | self._action_repeat = action_repeat 120 | 121 | def __call__(self, step): 122 | if self._every is None: 123 | return False 124 | every = self._every // self._action_repeat 125 | if step % every == 0: 126 | return True 127 | return False 128 | 129 | 130 | class Timer: 131 | def __init__(self): 132 | self._start_time = time.time() 133 | self._last_time = time.time() 134 | 135 | def reset(self): 136 | elapsed_time = time.time() - self._last_time 137 | self._last_time = time.time() 138 | total_time = time.time() - self._start_time 139 | return elapsed_time, total_time 140 | 141 | def total_time(self): 142 | return time.time() - self._start_time 143 | 144 | 145 | class TruncatedNormal(pyd.Normal): 146 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 147 | super().__init__(loc, scale, validate_args=False) 148 | self.low = low 149 | self.high = high 150 | self.eps = eps 151 | 152 | def _clamp(self, x): 153 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 154 | x = x - x.detach() + clamped_x.detach() 155 | return x 156 | 157 | def sample(self, clip=None, sample_shape=torch.Size()): 158 | shape = self._extended_shape(sample_shape) 159 | eps = _standard_normal(shape, 160 | dtype=self.loc.dtype, 161 | device=self.loc.device) 162 | eps *= self.scale 163 | if clip is not None: 164 | eps = torch.clamp(eps, -clip, clip) 165 | x = 
self.loc + eps 166 | return self._clamp(x) 167 | 168 | 169 | def schedule(schdl, step): 170 | try: 171 | return float(schdl) 172 | except ValueError: 173 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 174 | if match: 175 | init, final, duration = [float(g) for g in match.groups()] 176 | mix = np.clip(step / duration, 0.0, 1.0) 177 | return (1.0 - mix) * init + mix * final 178 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 179 | if match: 180 | init, final1, duration1, final2, duration2 = [ 181 | float(g) for g in match.groups() 182 | ] 183 | if step <= duration1: 184 | mix = np.clip(step / duration1, 0.0, 1.0) 185 | return (1.0 - mix) * init + mix * final1 186 | else: 187 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 188 | return (1.0 - mix) * final1 + mix * final2 189 | raise NotImplementedError(schdl) 190 | 191 | 192 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | 5 | 6 | class VideoRecorder: 7 | def __init__(self, root_dir, render_size=256, fps=20): 8 | if root_dir is not None: 9 | self.save_dir = root_dir / 'eval_video' 10 | self.save_dir.mkdir(exist_ok=True) 11 | else: 12 | self.save_dir = None 13 | 14 | self.render_size = render_size 15 | self.fps = fps 16 | self.frames = [] 17 | 18 | def init(self, env, enabled=True): 19 | self.frames = [] 20 | self.enabled = self.save_dir is not None and enabled 21 | self.record(env) 22 | 23 | def record(self, env): 24 | if self.enabled: 25 | if hasattr(env, 'physics'): 26 | frame = env.physics.render(height=self.render_size, 27 | width=self.render_size, 28 | camera_id=0) 29 | else: 30 | frame = env.render() 31 | self.frames.append(frame) 32 | 33 | def save(self, file_name): 34 | if self.enabled: 35 | path = self.save_dir / file_name 36 | imageio.mimsave(str(path), self.frames, fps=self.fps) 37 | 38 | 39 | class TrainVideoRecorder: 40 | def __init__(self, root_dir, render_size=256, fps=20): 41 | if root_dir is not None: 42 | self.save_dir = root_dir / 'train_video' 43 | self.save_dir.mkdir(exist_ok=True) 44 | else: 45 | self.save_dir = None 46 | 47 | self.render_size = render_size 48 | self.fps = fps 49 | self.frames = [] 50 | 51 | def init(self, obs, enabled=True): 52 | self.frames = [] 53 | self.enabled = self.save_dir is not None and enabled 54 | self.record(obs) 55 | 56 | def record(self, obs): 57 | if self.enabled: 58 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 59 | dsize=(self.render_size, self.render_size), 60 | interpolation=cv2.INTER_CUBIC) 61 | self.frames.append(frame) 62 | 63 | def save(self, file_name): 64 | if self.enabled: 65 | path = self.save_dir / file_name 66 | imageio.mimsave(str(path), self.frames, fps=self.fps) 67 | --------------------------------------------------------------------------------
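The `schedule` helper in utils.py above accepts either a plain float or the string forms `linear(init,final,duration)` and `step_linear(init,final1,duration1,final2,duration2)`, interpolated on the current training step. A minimal sketch of how a linear spec evaluates, assuming utils.py is importable from the repo root; the spec string and step values below are illustrative, not taken from the provided configs:

    from utils import schedule

    spec = 'linear(1.0,0.1,500000)'          # illustrative decaying schedule
    for step in (0, 250_000, 500_000, 1_000_000):
        print(step, schedule(spec, step))    # -> 1.0, 0.55, 0.1, 0.1 (clamped past duration)

Keeping schedules as strings in the YAML configs lets a single constant or a decaying schedule be swapped without code changes, since `schedule` first tries `float(schdl)` and only falls back to parsing the `linear`/`step_linear` forms.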