├── .gitignore ├── DMC_image ├── LICENSE ├── agent │ ├── aps.py │ ├── apt_dreamer.py │ ├── choreo.py │ ├── cic.py │ ├── diayn.py │ ├── dreamer.py │ ├── dreamer_utils.py │ ├── icm.py │ ├── lbs.py │ ├── lsd.py │ ├── peac.py │ ├── peac_diayn.py │ ├── peac_lbs.py │ ├── plan2explore.py │ ├── random_dreamer.py │ ├── rnd.py │ ├── skill_utils.py │ └── spectral_utils.py ├── configs │ ├── agent │ │ ├── aps_dreamer.yaml │ │ ├── apt_dreamer.yaml │ │ ├── choreo.yaml │ │ ├── cic_dreamer.yaml │ │ ├── diayn_dreamer.yaml │ │ ├── dreamer.yaml │ │ ├── icm_dreamer.yaml │ │ ├── lbs_dreamer.yaml │ │ ├── lsd.yaml │ │ ├── peac_diayn.yaml │ │ ├── peac_lbs.yaml │ │ ├── plan2explore.yaml │ │ ├── random_dreamer.yaml │ │ └── rnd_dreamer.yaml │ ├── dmc_pixels.yaml │ ├── dmc_states.yaml │ └── dreamer.yaml ├── custom_dmc_tasks │ ├── __init__.py │ ├── jaco.py │ ├── quadruped.py │ ├── quadruped.xml │ ├── walker.py │ └── walker.xml ├── dmc_benchmark.py ├── dreamer_finetune.py ├── dreamer_finetune.yaml ├── dreamer_pretrain.py ├── dreamer_pretrain.yaml ├── dreamer_replay.py ├── envs.py ├── logger.py ├── train_finetune.sh └── utils.py ├── DMC_state ├── LICENSE ├── README.md ├── agent │ ├── aps.py │ ├── becl.py │ ├── cic.py │ ├── ddpg.py │ ├── diayn.py │ ├── disagreement.py │ ├── icm.py │ ├── lbs.py │ ├── peac.py │ ├── proto.py │ ├── rnd.py │ └── smm.py ├── configs │ └── agent │ │ ├── aps.yaml │ │ ├── becl.yaml │ │ ├── cic.yaml │ │ ├── ddpg.yaml │ │ ├── diayn.yaml │ │ ├── disagreement.yaml │ │ ├── icm.yaml │ │ ├── lbs.yaml │ │ ├── peac.yaml │ │ ├── proto.yaml │ │ ├── rnd.yaml │ │ └── smm.yaml ├── custom_dmc_tasks │ ├── __init__.py │ ├── jaco.py │ ├── quadruped.py │ ├── quadruped.xml │ ├── walker.py │ └── walker.xml ├── dmc.py ├── dmc_benchmark.py ├── finetune.py ├── finetune.yaml ├── finetune_ddpg.sh ├── logger.py ├── pretrain.py ├── pretrain.yaml ├── replay_buffer.py ├── train_finetune.sh ├── utils.py └── video.py ├── LICENSE ├── README.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | DMC_state/exp_local 2 | DMC_state/pretrained_models 3 | DMC_image/exp_local 4 | DMC_image/pretrained_models 5 | 6 | __pycache__/ 7 | **/__pycache__/ -------------------------------------------------------------------------------- /DMC_image/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Pietro Mazzaglia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /DMC_image/agent/apt_dreamer.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import utils 8 | from agent.dreamer import DreamerAgent, stop_gradient 9 | import agent.dreamer_utils as common 10 | 11 | 12 | class APTDreamerAgent(DreamerAgent): 13 | def __init__(self, knn_rms, knn_k, knn_avg, knn_clip, **kwargs): 14 | super().__init__(**kwargs) 15 | self.reward_free = True 16 | 17 | # particle-based entropy 18 | rms = utils.RMS(self.device) 19 | self.pbe = utils.PBE(rms, knn_clip, knn_k, knn_avg, knn_rms, 20 | self.device) 21 | 22 | self.requires_grad_(requires_grad=False) 23 | 24 | def compute_intr_reward(self, seq): 25 | rep = stop_gradient(seq['deter']) 26 | B, T, _ = rep.shape 27 | rep = rep.reshape(B*T, -1) 28 | reward = self.pbe(rep, cdist=True) 29 | reward = reward.reshape(B, T, 1) 30 | return reward 31 | 32 | def update(self, data, step): 33 | metrics = {} 34 | B, T, _ = data['action'].shape 35 | 36 | state, outputs, mets = self.wm.update(data, state=None) 37 | metrics.update(mets) 38 | start = outputs['post'] 39 | start = {k: stop_gradient(v) for k,v in start.items()} 40 | if self.reward_free: 41 | reward_fn = lambda seq: self.compute_intr_reward(seq) 42 | else: 43 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode() 44 | metrics.update(self._task_behavior.update( 45 | self.wm, start, data['is_terminal'], reward_fn)) 46 | return state, metrics -------------------------------------------------------------------------------- /DMC_image/agent/icm.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import utils 8 | from agent.dreamer import DreamerAgent, stop_gradient 9 | import agent.dreamer_utils as common 10 | 11 | 12 | class ICM(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim): 14 | super().__init__() 15 | self.forward_net = nn.Sequential( 16 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), 17 | nn.Linear(hidden_dim, obs_dim)) 18 | 19 | self.backward_net = nn.Sequential(nn.Linear(2 * obs_dim, hidden_dim), 20 | nn.ReLU(), 21 | nn.Linear(hidden_dim, action_dim), 22 | nn.Tanh()) 23 | 24 | self.apply(utils.weight_init) 25 | 26 | def forward(self, obs, action, next_obs): 27 | assert obs.shape[0] == next_obs.shape[0] 28 | assert obs.shape[0] == action.shape[0] 29 | 30 | next_obs_hat = self.forward_net(torch.cat([obs, action], dim=-1)) 31 | action_hat = self.backward_net(torch.cat([obs, next_obs], dim=-1)) 32 | 33 | forward_error = torch.norm(next_obs - next_obs_hat, 34 | dim=-1, 35 | p=2, 36 | keepdim=True) 37 | backward_error = torch.norm(action - action_hat, 38 | dim=-1, 39 | p=2, 40 | keepdim=True) 41 | 42 | return forward_error, backward_error 43 | 44 | 45 | class ICMDreamerAgent(DreamerAgent): 46 | def __init__(self, icm_scale, **kwargs): 47 | super().__init__(**kwargs) 48 | in_dim = self.wm.inp_size 49 | pred_dim = self.wm.embed_dim 50 | self.hidden_dim = pred_dim 51 | self.reward_free = True 52 | self.icm_scale = icm_scale 53 | 54 | self.icm = ICM(pred_dim, self.act_dim, 55 | self.hidden_dim).to(self.device) 56 | 57 | # optimizers 58 | self.icm_opt = common.Optimizer('icm', self.icm.parameters(), **self.cfg.model_opt, 
use_amp=self._use_amp) 59 | 60 | self.icm.train() 61 | self.requires_grad_(requires_grad=False) 62 | 63 | def update_icm(self, obs, action, next_obs, step): 64 | metrics = dict() 65 | 66 | forward_error, backward_error = self.icm(obs, action, next_obs) 67 | 68 | loss = forward_error.mean() # + backward_error.mean() 69 | 70 | metrics.update(self.icm_opt(loss, self.icm.parameters())) 71 | 72 | metrics['icm_loss'] = loss.item() 73 | 74 | return metrics 75 | 76 | def compute_intr_reward(self, obs, action, next_obs, step): 77 | forward_error, _ = self.icm(obs, action, next_obs) 78 | 79 | reward = forward_error * self.icm_scale 80 | reward = torch.log(reward + 1.0) 81 | return reward 82 | 83 | def update(self, data, step): 84 | metrics = {} 85 | B, T, _ = data['action'].shape 86 | 87 | if self.reward_free: 88 | T = T-1 89 | temp_data = self.wm.preprocess(data) 90 | embed = self.wm.encoder(temp_data) 91 | inp = stop_gradient(embed[:, :-1]).reshape(B*T, -1) 92 | action = data['action'][:, 1:].reshape(B*T, -1) 93 | out = stop_gradient(embed[:, 1:]).reshape(B*T, -1) 94 | with common.RequiresGrad(self.icm): 95 | with torch.cuda.amp.autocast(enabled=self._use_amp): 96 | metrics.update( 97 | self.update_icm(inp, action, out, step)) 98 | 99 | with torch.no_grad(): 100 | intr_reward = self.compute_intr_reward(inp, action, out, step).reshape(B, T, 1) 101 | 102 | data['reward'][:, 0] = 1 103 | data['reward'][:, 1:] = intr_reward 104 | 105 | state, outputs, mets = self.wm.update(data, state=None) 106 | metrics.update(mets) 107 | start = outputs['post'] 108 | start = {k: stop_gradient(v) for k, v in start.items()} 109 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode() 110 | metrics.update(self._task_behavior.update( 111 | self.wm, start, data['is_terminal'], reward_fn)) 112 | return state, metrics -------------------------------------------------------------------------------- /DMC_image/agent/lbs.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.distributions as D 9 | 10 | import utils 11 | from agent.dreamer import DreamerAgent, WorldModel, stop_gradient 12 | import agent.dreamer_utils as common 13 | 14 | 15 | class LBSDreamerAgent(DreamerAgent): 16 | def __init__(self, **kwargs): 17 | super().__init__(**kwargs) 18 | 19 | self.reward_free = True 20 | 21 | # LBS 22 | self.lbs = common.MLP(self.wm.inp_size, (1,), **self.cfg.reward_head).to(self.device) 23 | self.lbs_opt = common.Optimizer('lbs', self.lbs.parameters(), **self.cfg.model_opt, use_amp=self._use_amp) 24 | self.lbs.train() 25 | 26 | self.requires_grad_(requires_grad=False) 27 | 28 | def update_lbs(self, outs): 29 | metrics = dict() 30 | B, T, _ = outs['feat'].shape 31 | feat, kl = outs['feat'].detach(), outs['kl'].detach() 32 | feat = feat.reshape(B*T, -1) 33 | kl = kl.reshape(B*T, -1) 34 | 35 | loss = -self.lbs(feat).log_prob(kl).mean() 36 | metrics.update(self.lbs_opt(loss, self.lbs.parameters())) 37 | metrics['lbs_loss'] = loss.item() 38 | return metrics 39 | 40 | def update(self, data, step): 41 | metrics = {} 42 | state, outputs, mets = self.wm.update(data, state=None) 43 | metrics.update(mets) 44 | start = outputs['post'] 45 | start = {k: stop_gradient(v) for k, v in start.items()} 46 | 47 | if self.reward_free: 48 | with common.RequiresGrad(self.lbs): 49 | with torch.cuda.amp.autocast(enabled=self._use_amp): 50 | 
metrics.update( 51 | self.update_lbs(outputs)) 52 | reward_fn = lambda seq: self.lbs(seq['feat']).mean 53 | else: 54 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode() 55 | 56 | metrics.update(self._task_behavior.update( 57 | self.wm, start, data['is_terminal'], reward_fn)) 58 | return state, metrics -------------------------------------------------------------------------------- /DMC_image/agent/peac_lbs.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.distributions as D 9 | 10 | import utils 11 | from agent.peac import PEACAgent, stop_gradient 12 | import agent.dreamer_utils as common 13 | 14 | 15 | class PEAC_LBSAgent(PEACAgent): 16 | def __init__(self, beta=1.0, **kwargs): 17 | super().__init__(**kwargs) 18 | 19 | self.reward_free = True 20 | self.beta = beta 21 | print("beta:", self.beta) 22 | 23 | # LBS 24 | # feat + context -> predict kl 25 | self.lbs = common.MLP(self.wm.inp_size+self.task_number, (1,), 26 | **self.cfg.reward_head).to(self.device) 27 | self.lbs_opt = common.Optimizer('lbs', self.lbs.parameters(), **self.cfg.model_opt, use_amp=self._use_amp) 28 | self.lbs.train() 29 | 30 | self.requires_grad_(requires_grad=False) 31 | 32 | def update_lbs(self, outs): 33 | metrics = dict() 34 | B, T, _ = outs['feat'].shape 35 | feat, kl = outs['feat'].detach(), outs['kl'].detach() 36 | feat = feat.reshape(B * T, -1) 37 | kl = kl.reshape(B * T, -1) 38 | context = F.softmax(self.wm.task_model(feat), dim=-1).detach() 39 | 40 | loss = -self.lbs(torch.cat([feat, context], dim=-1)).log_prob(kl).mean() 41 | metrics.update(self.lbs_opt(loss, self.lbs.parameters())) 42 | metrics['lbs_loss'] = loss.item() 43 | return metrics 44 | 45 | def update(self, data, step): 46 | metrics = {} 47 | state, outputs, mets = self.wm.update(data, state=None) 48 | metrics.update(mets) 49 | start = outputs['post'] 50 | start['embodiment_id'] = data['embodiment_id'] 51 | start['context'] = outputs['context'] 52 | start = {k: stop_gradient(v) for k, v in start.items()} 53 | 54 | if self.reward_free: 55 | with common.RequiresGrad(self.lbs): 56 | with torch.cuda.amp.autocast(enabled=self._use_amp): 57 | metrics.update(self.update_lbs(outputs)) 58 | reward_fn = lambda seq: self.compute_intr_reward(seq) + \ 59 | self.beta * self.compute_task_reward(seq) 60 | else: 61 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean # .mode() 62 | 63 | metrics.update(self._task_behavior.update( 64 | self.wm, start, data['is_terminal'], reward_fn)) 65 | return state, metrics 66 | 67 | def compute_intr_reward(self, seq): 68 | context = F.softmax(self.wm.task_model(seq['feat']), dim=-1) 69 | return self.lbs(torch.cat([seq['feat'], context], dim=-1)).mean 70 | 71 | def compute_task_reward(self, seq): 72 | # print('we use calculated reward') 73 | B, T, _ = seq['feat'].shape 74 | task_pred = self.wm.task_model(seq['feat']) 75 | task_truth = seq['embodiment_id'].repeat(B, 1, 1).to(dtype=torch.int64) 76 | # print(task_pred.shape) # 16, 2500, task_number 77 | # print(seq['action'].shape) # 16, 2500, _ 78 | # print(task_truth.shape) # 16, 2500, 1 79 | task_pred = F.log_softmax(task_pred, dim=2) 80 | task_rew = task_pred.reshape(B * T, -1)[torch.arange(B * T), task_truth.reshape(-1)] 81 | task_rew = -task_rew.reshape(B, T, 1) 82 | 83 | # print(intr_rew.shape) # 16, 2500, 1 84 | return task_rew 85 | 
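
Note on the snippet above: the shape juggling in `compute_task_reward` is the core of the embodiment reward in `peac_lbs.py` — for every imagined step it returns the negative log-probability that the world model's task head assigns to the true embodiment id (the classifier's cross-entropy at that step). Below is a minimal, self-contained sketch of just that indexing, not part of the repository; all sizes and tensors are made up for illustration, and `task_logits`/`true_id` stand in for `self.wm.task_model(seq['feat'])` and `seq['embodiment_id']`.

import torch
import torch.nn.functional as F

B, T, K = 4, 7, 3                      # imagined batch, horizon, number of embodiments (made-up sizes)
task_logits = torch.randn(B, T, K)     # stand-in for self.wm.task_model(seq['feat'])
true_id = torch.randint(K, (1, T, 1))  # stand-in for seq['embodiment_id']

# Broadcast the true ids across the batch, as in compute_task_reward.
task_truth = true_id.repeat(B, 1, 1).to(dtype=torch.int64)          # (B, T, 1)
log_probs = F.log_softmax(task_logits, dim=2)                       # (B, T, K)

# Pick, for every (batch, step) pair, the log-probability of the true embodiment.
picked = log_probs.reshape(B * T, -1)[torch.arange(B * T), task_truth.reshape(-1)]
task_rew = -picked.reshape(B, T, 1)                                 # negative log-likelihood, shape (B, T, 1)

# Sanity check: the flatten-and-index trick is equivalent to a gather along the class axis.
assert torch.allclose(task_rew, -log_probs.gather(2, task_truth))
print(task_rew.shape)                                               # torch.Size([4, 7, 1])
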
-------------------------------------------------------------------------------- /DMC_image/agent/plan2explore.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import utils 8 | from agent.dreamer import DreamerAgent, stop_gradient 9 | import agent.dreamer_utils as common 10 | 11 | 12 | class Disagreement(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None): 14 | super().__init__() 15 | if pred_dim is None: pred_dim = obs_dim 16 | self.ensemble = nn.ModuleList([ 17 | nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim), 18 | nn.ReLU(), nn.Linear(hidden_dim, pred_dim)) 19 | for _ in range(n_models) 20 | ]) 21 | 22 | def forward(self, obs, action, next_obs): 23 | #import ipdb; ipdb.set_trace() 24 | assert obs.shape[0] == next_obs.shape[0] 25 | assert obs.shape[0] == action.shape[0] 26 | 27 | errors = [] 28 | for model in self.ensemble: 29 | next_obs_hat = model(torch.cat([obs, action], dim=-1)) 30 | model_error = torch.norm(next_obs - next_obs_hat, 31 | dim=-1, 32 | p=2, 33 | keepdim=True) 34 | errors.append(model_error) 35 | 36 | return torch.cat(errors, dim=1) 37 | 38 | def get_disagreement(self, obs, action): 39 | assert obs.shape[0] == action.shape[0] 40 | 41 | preds = [] 42 | for model in self.ensemble: 43 | next_obs_hat = model(torch.cat([obs, action], dim=-1)) 44 | preds.append(next_obs_hat) 45 | preds = torch.stack(preds, dim=0) 46 | return torch.var(preds, dim=0).mean(dim=-1) 47 | 48 | 49 | class Plan2Explore(DreamerAgent): 50 | def __init__(self, **kwargs): 51 | super().__init__(**kwargs) 52 | in_dim = self.wm.inp_size 53 | pred_dim = self.wm.embed_dim 54 | self.hidden_dim = pred_dim 55 | self.reward_free = True 56 | 57 | self.disagreement = Disagreement(in_dim, self.act_dim, 58 | self.hidden_dim, pred_dim=pred_dim).to(self.device) 59 | 60 | # optimizers 61 | self.disagreement_opt = common.Optimizer('disagreement', self.disagreement.parameters(), **self.cfg.model_opt, use_amp=self._use_amp) 62 | self.disagreement.train() 63 | self.requires_grad_(requires_grad=False) 64 | 65 | def update_disagreement(self, obs, action, next_obs, step): 66 | metrics = dict() 67 | 68 | error = self.disagreement(obs, action, next_obs) 69 | 70 | loss = error.mean() 71 | 72 | metrics.update(self.disagreement_opt(loss, self.disagreement.parameters())) 73 | 74 | metrics['disagreement_loss'] = loss.item() 75 | 76 | return metrics 77 | 78 | def compute_intr_reward(self, seq): 79 | obs, action = seq['feat'][:-1], stop_gradient(seq['action'][1:]) 80 | intr_rew = torch.zeros(list(seq['action'].shape[:-1]) + [1], device=self.device) 81 | if len(action.shape) > 2: 82 | B, T, _ = action.shape 83 | obs = obs.reshape(B*T, -1) 84 | action = action.reshape(B*T, -1) 85 | reward = self.disagreement.get_disagreement(obs, action).reshape(B, T, 1) 86 | else: 87 | reward = self.disagreement.get_disagreement(obs, action).unsqueeze(-1) 88 | intr_rew[1:] = reward 89 | return intr_rew 90 | 91 | def update(self, data, step): 92 | metrics = {} 93 | B, T, _ = data['action'].shape 94 | state, outputs, mets = self.wm.update(data, state=None) 95 | metrics.update(mets) 96 | start = outputs['post'] 97 | start = {k: stop_gradient(v) for k,v in start.items()} 98 | if self.reward_free: 99 | T = T-1 100 | inp = stop_gradient(outputs['feat'][:, :-1]).reshape(B*T, -1) 101 | action = data['action'][:, 1:].reshape(B*T, -1) 102 | out = 
stop_gradient(outputs['embed'][:, 1:]).reshape(B*T, -1) 103 | with common.RequiresGrad(self.disagreement): 104 | with torch.cuda.amp.autocast(enabled=self._use_amp): 105 | metrics.update( 106 | self.update_disagreement(inp, action, out, step)) 107 | metrics.update(self._task_behavior.update( 108 | self.wm, start, data['is_terminal'], reward_fn=self.compute_intr_reward)) 109 | else: 110 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode() 111 | metrics.update(self._task_behavior.update( 112 | self.wm, start, data['is_terminal'], reward_fn)) 113 | return state, metrics -------------------------------------------------------------------------------- /DMC_image/agent/random_dreamer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | import utils 5 | from collections import OrderedDict 6 | import numpy as np 7 | 8 | from agent.dreamer import DreamerAgent, WorldModel, stop_gradient 9 | import agent.dreamer_utils as common 10 | 11 | Module = nn.Module 12 | 13 | class RandomDreamerAgent(DreamerAgent): 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | def act(self, obs, meta, step, eval_mode, state): 18 | return torch.zeros(self.act_spec.shape).uniform_(-1.0, 1.0).numpy(), None 19 | 20 | def update(self, data, step): 21 | metrics = {} 22 | state, outputs, mets = self.wm.update(data, state=None) 23 | metrics.update(mets) 24 | return state, metrics -------------------------------------------------------------------------------- /DMC_image/agent/rnd.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | from agent.dreamer import DreamerAgent, stop_gradient 11 | import agent.dreamer_utils as common 12 | 13 | 14 | class RND(nn.Module): 15 | def __init__(self, 16 | obs_dim, 17 | hidden_dim, 18 | rnd_rep_dim, 19 | encoder, 20 | aug, 21 | obs_shape, 22 | obs_type, 23 | clip_val=5.): 24 | super().__init__() 25 | self.clip_val = clip_val 26 | self.aug = aug 27 | 28 | if obs_type == "pixels": 29 | self.normalize_obs = nn.BatchNorm2d(obs_shape[0], affine=False) 30 | else: 31 | self.normalize_obs = nn.BatchNorm1d(obs_shape[0], affine=False) 32 | 33 | self.predictor = nn.Sequential(encoder, nn.Linear(obs_dim, hidden_dim), 34 | nn.ReLU(), 35 | nn.Linear(hidden_dim, hidden_dim), 36 | nn.ReLU(), 37 | nn.Linear(hidden_dim, rnd_rep_dim)) 38 | self.target = nn.Sequential(copy.deepcopy(encoder), 39 | nn.Linear(obs_dim, hidden_dim), nn.ReLU(), 40 | nn.Linear(hidden_dim, hidden_dim), 41 | nn.ReLU(), 42 | nn.Linear(hidden_dim, rnd_rep_dim)) 43 | 44 | for param in self.target.parameters(): 45 | param.requires_grad = False 46 | 47 | self.apply(utils.weight_init) 48 | 49 | def forward(self, obs): 50 | if type(obs) == dict: 51 | img = obs['observation'] 52 | img = self.aug(img) 53 | img = self.normalize_obs(img) 54 | img = torch.clamp(img, -self.clip_val, self.clip_val) 55 | obs['observation'] = img 56 | else: 57 | obs = self.aug(obs) 58 | obs = self.normalize_obs(obs) 59 | obs = torch.clamp(obs, -self.clip_val, self.clip_val) 60 | prediction, target = self.predictor(obs), self.target(obs) 61 | prediction_error = torch.square(target.detach() - prediction).mean( 62 | dim=-1, keepdim=True) 63 | return prediction_error 64 | 65 | 66 | class RNDDreamerAgent(DreamerAgent): 67 | def __init__(self, 
rnd_rep_dim, rnd_scale, **kwargs): 68 | super().__init__(**kwargs) 69 | 70 | self.reward_free = True 71 | self.rnd_scale = rnd_scale 72 | 73 | self.obs_dim = self.wm.embed_dim 74 | self.hidden_dim = self.wm.embed_dim 75 | self.aug = nn.Identity() 76 | self.obs_shape = (3,64,64) 77 | self.obs_type = self.cfg.obs_type 78 | 79 | encoder = copy.deepcopy(self.wm.encoder) 80 | 81 | self.rnd = RND(self.obs_dim, self.hidden_dim, rnd_rep_dim, 82 | encoder, self.aug, self.obs_shape, 83 | self.obs_type).to(self.device) 84 | self.intrinsic_reward_rms = utils.RMS(device=self.device) 85 | 86 | # optimizers 87 | self.rnd_opt = common.Optimizer('rnd', self.rnd.parameters(), **self.cfg.model_opt, use_amp=self._use_amp) 88 | 89 | self.rnd.train() 90 | self.requires_grad_(requires_grad=False) 91 | 92 | def update_rnd(self, obs, step): 93 | metrics = dict() 94 | 95 | prediction_error = self.rnd(obs) 96 | 97 | loss = prediction_error.mean() 98 | 99 | metrics.update(self.rnd_opt(loss, self.rnd.parameters())) 100 | 101 | metrics['rnd_loss'] = loss.item() 102 | 103 | return metrics 104 | 105 | def compute_intr_reward(self, obs): 106 | prediction_error = self.rnd(obs) 107 | _, intr_reward_var = self.intrinsic_reward_rms(prediction_error) 108 | reward = self.rnd_scale * prediction_error / ( 109 | torch.sqrt(intr_reward_var) + 1e-8) 110 | return reward 111 | 112 | def update(self, data, step): 113 | metrics = {} 114 | B, T, _ = data['action'].shape 115 | obs_shape = data['observation'].shape[2:] 116 | 117 | if self.reward_free: 118 | temp_data = self.wm.preprocess(data) 119 | temp_data['observation'] = temp_data['observation'].reshape(B*T, *obs_shape) 120 | with common.RequiresGrad(self.rnd): 121 | with torch.cuda.amp.autocast(enabled=self._use_amp): 122 | metrics.update(self.update_rnd(temp_data, step)) 123 | 124 | with torch.no_grad(): 125 | intr_reward = self.compute_intr_reward(temp_data).reshape(B, T, 1) 126 | 127 | data['reward'] = intr_reward 128 | 129 | state, outputs, mets = self.wm.update(data, state=None) 130 | metrics.update(mets) 131 | start = outputs['post'] 132 | start = {k: stop_gradient(v) for k,v in start.items()} 133 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode() 134 | metrics.update(self._task_behavior.update( 135 | self.wm, start, data['is_terminal'], reward_fn)) 136 | return state, metrics -------------------------------------------------------------------------------- /DMC_image/configs/agent/aps_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.aps.APSDreamerAgent 3 | name: aps_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 
7 | grad_heads: [decoder] 8 | # Note: it's important to keep momentum = 1.00, otherwise SF won't work 9 | reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} 10 | actor_ent: 0.0 11 | 12 | skill_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} # {momentum: 0.95, scale: 1.0, eps: 1e-8} 13 | skill_actor_ent: 0.0 14 | 15 | skill_dim: 5 16 | update_skill_every_step: 50 17 | 18 | knn_rms: true 19 | knn_k: 12 20 | knn_avg: true 21 | knn_clip: 0.0001 22 | num_init_frames: 4000 # set to ${num_train_frames} to disable finetune policy parameters 23 | lstsq_batch_size: 4096 -------------------------------------------------------------------------------- /DMC_image/configs/agent/apt_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.apt_dreamer.APTDreamerAgent 3 | name: apt_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0 10 | 11 | knn_rms: false 12 | knn_k: 12 13 | knn_avg: true 14 | knn_clip: 0.0 -------------------------------------------------------------------------------- /DMC_image/configs/agent/choreo.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.choreo.ChoreoAgent 3 | name: choreo 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | 9 | # Exploration 10 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 11 | actor_ent: 0 12 | 13 | # Skills 14 | skill_dim: 64 15 | skill_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} 16 | skill_actor_ent: 0 17 | code_dim: 16 18 | code_resampling: True 19 | resample_every: 200 20 | 21 | # Adaptation 22 | num_init_frames: 4000 23 | update_skill_every_step: 125 24 | freeze_skills: False 25 | 26 | # PBE 27 | knn_rms: false 28 | knn_k: 30 29 | knn_avg: true 30 | knn_clip: 0.0001 31 | 32 | task_number: 1 -------------------------------------------------------------------------------- /DMC_image/configs/agent/cic_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.cic.CICAgent 3 | name: cic_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | # Note: it's important to keep momentum = 1.00, otherwise SF won't work 9 | reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} 10 | actor_ent: 0.0 11 | 12 | skill_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} # {momentum: 0.95, scale: 1.0, eps: 1e-8} 13 | skill_actor_ent: 0.0 14 | 15 | skill_dim: 5 16 | update_skill_every_step: 50 17 | 18 | #knn_rms: true 19 | #knn_k: 12 20 | #knn_avg: true 21 | #knn_clip: 0.0001 22 | knn_k: 16 23 | knn_rms: true 24 | knn_avg: true 25 | knn_clip: 0.0005 26 | num_init_frames: 4000 # set to ${num_train_frames} to disable finetune policy parameters 27 | lstsq_batch_size: 4096 28 | project_skill: True 29 | temp: 0.5 -------------------------------------------------------------------------------- /DMC_image/configs/agent/diayn_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.diayn.DIAYNDreamerAgent 3 | name: diayn_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 
7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0.1 10 | 11 | diayn_scale: 1.0 12 | update_skill_every_step: 50 13 | 14 | skill_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 15 | skill_actor_ent: 0.1 16 | skill_dim: 50 17 | 18 | num_init_frames: 4000 19 | 20 | reward_type: 0 21 | -------------------------------------------------------------------------------- /DMC_image/configs/agent/dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.dreamer.DreamerAgent 3 | name: dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 9 | actor_ent: 1e-4 -------------------------------------------------------------------------------- /DMC_image/configs/agent/icm_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.icm.ICMDreamerAgent 3 | name: icm_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0 10 | icm_scale: 1.0 -------------------------------------------------------------------------------- /DMC_image/configs/agent/lbs_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.lbs.LBSDreamerAgent 3 | name: lbs_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0 10 | 11 | reward_type: 0 -------------------------------------------------------------------------------- /DMC_image/configs/agent/lsd.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.lsd.LSDDreamerAgent 3 | name: lsd_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0.1 10 | 11 | lsd_scale: 1.0 12 | update_skill_every_step: 50 13 | 14 | skill_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 15 | skill_actor_ent: 0.1 16 | skill_dim: 50 17 | 18 | num_init_frames: 4000 19 | -------------------------------------------------------------------------------- /DMC_image/configs/agent/peac_diayn.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.peac_diayn.PEACDIAYNAgent 3 | name: peac_diayn 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 9 | actor_ent: 1e-4 10 | 11 | diayn_scale: 1.0 12 | update_skill_every_step: 50 13 | task_scale: 1.0 14 | 15 | context_skill_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 16 | context_skill_actor_ent: 0.1 17 | skill_dim: 50 18 | diayn_hidden: 50 19 | 20 | num_init_frames: 4000 21 | 22 | freeze_skills: False 23 | 24 | task_number: 1 -------------------------------------------------------------------------------- /DMC_image/configs/agent/peac_lbs.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.peac_lbs.PEAC_LBSAgent 3 | name: peac_lbs 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 
7 | grad_heads: [decoder] 8 | #reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 9 | #actor_ent: 1e-4 10 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 11 | actor_ent: 0 12 | 13 | context_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 14 | #context_actor_ent: 0.1 15 | context_actor_ent: 1e-4 16 | 17 | num_init_frames: 4000 18 | 19 | task_number: 1 -------------------------------------------------------------------------------- /DMC_image/configs/agent/plan2explore.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.plan2explore.Plan2Explore 3 | name: plan2explore 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0 -------------------------------------------------------------------------------- /DMC_image/configs/agent/random_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.random_dreamer.RandomDreamerAgent 3 | name: random_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 9 | actor_ent: 1e-4 -------------------------------------------------------------------------------- /DMC_image/configs/agent/rnd_dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.rnd.RNDDreamerAgent 3 | name: rnd_dreamer 4 | cfg: ??? 5 | obs_space: ??? 6 | act_spec: ??? 7 | grad_heads: [decoder] 8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 9 | actor_ent: 0 10 | 11 | rnd_scale: 1.0 12 | rnd_rep_dim: 512 -------------------------------------------------------------------------------- /DMC_image/configs/dmc_pixels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | obs_type: pixels 3 | action_repeat: 2 4 | encoder: {mlp_keys: '$^', cnn_keys: 'observation', norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} # act: elu 5 | decoder: {mlp_keys: '$^', cnn_keys: 'observation', norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} # act: elu 6 | replay.capacity: 1e6 -------------------------------------------------------------------------------- /DMC_image/configs/dmc_states.yaml: -------------------------------------------------------------------------------- 1 | obs_type: states 2 | action_repeat: 1 3 | encoder: {mlp_keys: 'observation', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} 4 | decoder: {mlp_keys: 'observation', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} 5 | replay.capacity: 2e6 -------------------------------------------------------------------------------- /DMC_image/configs/dreamer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Dreamer defaults 4 | pred_discount: False 5 | rssm: {ensemble: 1, hidden: 200, deter: 200, stoch: 32, discrete: 32, norm: none, std_act: sigmoid2, min_std: 0.1} # act: elu, 6 | reward_head: {layers: 4, units: 400, norm: none, dist: mse} # act: elu 7 | # we add task head here 8 | task_head: {layers: 4, units: 200, norm: none, dist: mse} # act: elu 9 | kl: {free: 1.0, forward: False, balance: 0.8, free_avg: True} 10 | 
loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0} 11 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 12 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False} 13 | 14 | actor: {layers: 4, units: 400, norm: none, dist: trunc_normal, min_std: 0.1 } # act: elu 15 | critic: {layers: 4, units: 400, norm: none, dist: mse} # act: elu, 16 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 17 | critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 18 | discount: 0.99 19 | discount_lambda: 0.95 20 | actor_grad: dynamics 21 | slow_target: True 22 | slow_target_update: 100 23 | slow_target_fraction: 1 24 | slow_baseline: True 25 | 26 | clip_rewards: identity 27 | 28 | batch_size: 50 29 | batch_length: 50 30 | imag_horizon: 15 31 | eval_state_mean: False 32 | 33 | precision: 16 34 | train_every_actions: 10 35 | # -------------------------------------------------------------------------------- /DMC_image/custom_dmc_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from custom_dmc_tasks import walker 2 | from custom_dmc_tasks import quadruped 3 | from custom_dmc_tasks import jaco 4 | 5 | 6 | def make(domain, task, 7 | task_kwargs=None, 8 | environment_kwargs=None, 9 | visualize_reward=False, 10 | mass=1.0): 11 | 12 | if domain == 'walker': 13 | return walker.make(task, 14 | task_kwargs=task_kwargs, 15 | environment_kwargs=environment_kwargs, 16 | visualize_reward=visualize_reward) 17 | elif domain == 'quadruped': 18 | return quadruped.make(task, 19 | task_kwargs=task_kwargs, 20 | environment_kwargs=environment_kwargs, 21 | visualize_reward=visualize_reward) 22 | else: 23 | raise f'{task} not found' 24 | 25 | 26 | def make_jaco(task, obs_type, seed, img_size,): 27 | return jaco.make(task, obs_type, seed, img_size,) -------------------------------------------------------------------------------- /DMC_image/custom_dmc_tasks/jaco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """A task where the goal is to move the hand close to a target prop or site.""" 17 | 18 | import collections 19 | 20 | from dm_control import composer 21 | from dm_control.composer import initializers 22 | from dm_control.composer.observation import observable 23 | from dm_control.composer.variation import distributions 24 | from dm_control.entities import props 25 | from dm_control.manipulation.shared import arenas 26 | from dm_control.manipulation.shared import cameras 27 | from dm_control.manipulation.shared import constants 28 | from dm_control.manipulation.shared import observations 29 | from dm_control.manipulation.shared import registry 30 | from dm_control.manipulation.shared import robots 31 | from dm_control.manipulation.shared import tags 32 | from dm_control.manipulation.shared import workspaces 33 | from dm_control.utils import rewards 34 | import numpy as np 35 | 36 | _ReachWorkspace = collections.namedtuple( 37 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 38 | 39 | # Ensures that the props are not touching the table before settling. 40 | _PROP_Z_OFFSET = 0.001 41 | 42 | _DUPLO_WORKSPACE = _ReachWorkspace( 43 | target_bbox=workspaces.BoundingBox( 44 | lower=(-0.1, -0.1, _PROP_Z_OFFSET), 45 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 46 | tcp_bbox=workspaces.BoundingBox( 47 | lower=(-0.1, -0.1, 0.2), 48 | upper=(0.1, 0.1, 0.4)), 49 | arm_offset=robots.ARM_OFFSET) 50 | 51 | _SITE_WORKSPACE = _ReachWorkspace( 52 | target_bbox=workspaces.BoundingBox( 53 | lower=(-0.2, -0.2, 0.02), 54 | upper=(0.2, 0.2, 0.4)), 55 | tcp_bbox=workspaces.BoundingBox( 56 | lower=(-0.2, -0.2, 0.02), 57 | upper=(0.2, 0.2, 0.4)), 58 | arm_offset=robots.ARM_OFFSET) 59 | 60 | _TARGET_RADIUS = 0.05 61 | _TIME_LIMIT = 10 62 | 63 | TASKS = { 64 | 'reach_top_left': workspaces.BoundingBox( 65 | lower=(-0.09, 0.09, _PROP_Z_OFFSET), 66 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)), 67 | 'reach_top_right': workspaces.BoundingBox( 68 | lower=(0.09, 0.09, _PROP_Z_OFFSET), 69 | upper=(0.09, 0.09, _PROP_Z_OFFSET)), 70 | 'reach_bottom_left': workspaces.BoundingBox( 71 | lower=(-0.09, -0.09, _PROP_Z_OFFSET), 72 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)), 73 | 'reach_bottom_right': workspaces.BoundingBox( 74 | lower=(0.09, -0.09, _PROP_Z_OFFSET), 75 | upper=(0.09, -0.09, _PROP_Z_OFFSET)), 76 | } 77 | 78 | 79 | def make(task_id, obs_type, seed, img_size=84, ): 80 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 81 | obs_settings = obs_settings._replace(camera=obs_settings[-1]._replace(width=img_size)) 82 | obs_settings = obs_settings._replace(camera=obs_settings[-1]._replace(height=img_size)) 83 | if obs_type == 'states': 84 | global _TIME_LIMIT 85 | _TIME_LIMIT = 10.04 86 | # Note: Adding this fixes the problem of having 249 steps with action repeat = 1 87 | task = _reach(task_id, obs_settings=obs_settings, use_site=False) 88 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed) 89 | 90 | 91 | class MTReach(composer.Task): 92 | """Bring the hand close to a target prop or site.""" 93 | 94 | def __init__( 95 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep): 96 | """Initializes a new `Reach` task. 97 | 98 | Args: 99 | arena: `composer.Entity` instance. 100 | arm: `robot_base.RobotArm` instance. 101 | hand: `robot_base.RobotHand` instance. 
102 | prop: `composer.Entity` instance specifying the prop to reach to, or None 103 | in which case the target is a fixed site whose position is specified by 104 | the workspace. 105 | obs_settings: `observations.ObservationSettings` instance. 106 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 107 | control_timestep: Float specifying the control timestep in seconds. 108 | """ 109 | self._task_id = task_id 110 | self._arena = arena 111 | self._arm = arm 112 | self._hand = hand 113 | self._arm.attach(self._hand) 114 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 115 | self.control_timestep = control_timestep 116 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 117 | self._hand, self._arm, 118 | position=distributions.Uniform(*workspace.tcp_bbox), 119 | quaternion=workspaces.DOWN_QUATERNION) 120 | 121 | # Add custom camera observable. 122 | self._task_observables = cameras.add_camera_observables( 123 | arena, obs_settings, cameras.FRONT_CLOSE) 124 | 125 | target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 126 | self._prop = prop 127 | if prop: 128 | # The prop itself is used to visualize the target location. 129 | self._make_target_site(parent_entity=prop, visible=False) 130 | self._target = self._arena.add_free_entity(prop) 131 | self._prop_placer = initializers.PropPlacer( 132 | props=[prop], 133 | position=target_pos_distribution, 134 | quaternion=workspaces.uniform_z_rotation, 135 | settle_physics=True) 136 | else: 137 | self._target = self._make_target_site(parent_entity=arena, visible=True) 138 | self._target_placer = target_pos_distribution 139 | 140 | # Add sites for visualizing the prop and target bounding boxes. 141 | workspaces.add_bbox_site( 142 | body=self.root_entity.mjcf_model.worldbody, 143 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper, 144 | rgba=constants.GREEN, name='tcp_spawn_area') 145 | workspaces.add_bbox_site( 146 | body=self.root_entity.mjcf_model.worldbody, 147 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper, 148 | rgba=constants.BLUE, name='target_spawn_area') 149 | 150 | def _make_target_site(self, parent_entity, visible): 151 | return workspaces.add_target_site( 152 | body=parent_entity.mjcf_model.worldbody, 153 | radius=_TARGET_RADIUS, visible=visible, 154 | rgba=constants.RED, name='target_site') 155 | 156 | @property 157 | def root_entity(self): 158 | return self._arena 159 | 160 | @property 161 | def arm(self): 162 | return self._arm 163 | 164 | @property 165 | def hand(self): 166 | return self._hand 167 | 168 | @property 169 | def task_observables(self): 170 | return self._task_observables 171 | 172 | def get_reward(self, physics): 173 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 174 | target_pos = physics.bind(self._target).xpos 175 | distance = np.linalg.norm(hand_pos - target_pos) 176 | return rewards.tolerance( 177 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS) 178 | 179 | def initialize_episode(self, physics, random_state): 180 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 181 | self._tcp_initializer(physics, random_state) 182 | if self._prop: 183 | self._prop_placer(physics, random_state) 184 | else: 185 | physics.bind(self._target).pos = ( 186 | self._target_placer(random_state=random_state)) 187 | 188 | 189 | def _reach(task_id, obs_settings, use_site): 190 | """Configure and instantiate a `Reach` task. 
191 | 192 | Args: 193 | obs_settings: An `observations.ObservationSettings` instance. 194 | use_site: Boolean, if True then the target will be a fixed site, otherwise 195 | it will be a moveable Duplo brick. 196 | 197 | Returns: 198 | An instance of `reach.Reach`. 199 | """ 200 | arena = arenas.Standard() 201 | arm = robots.make_arm(obs_settings=obs_settings) 202 | hand = robots.make_hand(obs_settings=obs_settings) 203 | if use_site: 204 | workspace = _SITE_WORKSPACE 205 | prop = None 206 | else: 207 | workspace = _DUPLO_WORKSPACE 208 | prop = props.Duplo(observable_options=observations.make_options( 209 | obs_settings, observations.FREEPROP_OBSERVABLES)) 210 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop, 211 | obs_settings=obs_settings, 212 | workspace=workspace, 213 | control_timestep=constants.CONTROL_TIMESTEP) 214 | return task 215 | -------------------------------------------------------------------------------- /DMC_image/custom_dmc_tasks/walker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Planar Walker Domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | from dm_control import suite 33 | 34 | _DEFAULT_TIME_LIMIT = 25 35 | _CONTROL_TIMESTEP = .025 36 | 37 | # Minimal height of torso over foot above which stand reward is 1. 38 | _STAND_HEIGHT = 1.2 39 | 40 | # Horizontal speeds (meters/second) above which move reward is 1. 
41 | _WALK_SPEED = 1 42 | _RUN_SPEED = 8 43 | _SPIN_SPEED = 5 44 | 45 | SUITE = containers.TaggedTasks() 46 | 47 | 48 | def make(task, 49 | task_kwargs=None, 50 | environment_kwargs=None, 51 | visualize_reward=False): 52 | task_kwargs = task_kwargs or {} 53 | if environment_kwargs is not None: 54 | task_kwargs = task_kwargs.copy() 55 | task_kwargs['environment_kwargs'] = environment_kwargs 56 | env = SUITE[task](**task_kwargs) 57 | env.task.visualize_reward = visualize_reward 58 | return env 59 | 60 | 61 | def get_model_and_assets(): 62 | """Returns a tuple containing the model XML string and a dict of assets.""" 63 | root_dir = os.path.dirname(os.path.dirname(__file__)) 64 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks', 65 | 'walker.xml')) 66 | return xml, common.ASSETS 67 | 68 | 69 | @SUITE.add('benchmarking') 70 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 71 | random=None, 72 | environment_kwargs=None): 73 | """Returns the Run task.""" 74 | physics = Physics.from_xml_string(*get_model_and_assets()) 75 | task = PlanarWalker(move_speed=_RUN_SPEED, 76 | forward=True, 77 | flip=True, 78 | random=random) 79 | environment_kwargs = environment_kwargs or {} 80 | return control.Environment(physics, 81 | task, 82 | time_limit=time_limit, 83 | control_timestep=_CONTROL_TIMESTEP, 84 | **environment_kwargs) 85 | 86 | 87 | class Physics(mujoco.Physics): 88 | """Physics simulation with additional features for the Walker domain.""" 89 | def torso_upright(self): 90 | """Returns projection from z-axes of torso to the z-axes of world.""" 91 | return self.named.data.xmat['torso', 'zz'] 92 | 93 | def torso_height(self): 94 | """Returns the height of the torso.""" 95 | return self.named.data.xpos['torso', 'z'] 96 | 97 | def horizontal_velocity(self): 98 | """Returns the horizontal velocity of the center-of-mass.""" 99 | return self.named.data.sensordata['torso_subtreelinvel'][0] 100 | 101 | def orientations(self): 102 | """Returns planar orientations of all bodies.""" 103 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel() 104 | 105 | def angmomentum(self): 106 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 107 | return self.named.data.subtree_angmom['torso'][1] 108 | 109 | 110 | class PlanarWalker(base.Task): 111 | """A planar walker task.""" 112 | def __init__(self, move_speed, forward=True, flip=False, random=None): 113 | """Initializes an instance of `PlanarWalker`. 114 | 115 | Args: 116 | move_speed: A float. If this value is zero, reward is given simply for 117 | standing up. Otherwise this specifies a target horizontal velocity for 118 | the walking task. 119 | random: Optional, either a `numpy.random.RandomState` instance, an 120 | integer seed for creating a new `RandomState`, or None to select a seed 121 | automatically (default). 122 | """ 123 | self._move_speed = move_speed 124 | self._forward = 1 if forward else -1 125 | self._flip = flip 126 | super(PlanarWalker, self).__init__(random=random) 127 | 128 | def initialize_episode(self, physics): 129 | """Sets the state of the environment at the start of each episode. 130 | 131 | In 'standing' mode, use initial orientation and small velocities. 132 | In 'random' mode, randomize joint angles and let fall to the floor. 133 | 134 | Args: 135 | physics: An instance of `Physics`. 
136 | 137 | """ 138 | randomizers.randomize_limited_and_rotational_joints( 139 | physics, self.random) 140 | super(PlanarWalker, self).initialize_episode(physics) 141 | 142 | def get_observation(self, physics): 143 | """Returns an observation of body orientations, height and velocites.""" 144 | obs = collections.OrderedDict() 145 | obs['orientations'] = physics.orientations() 146 | obs['height'] = physics.torso_height() 147 | obs['velocity'] = physics.velocity() 148 | return obs 149 | 150 | def get_reward(self, physics): 151 | """Returns a reward to the agent.""" 152 | standing = rewards.tolerance(physics.torso_height(), 153 | bounds=(_STAND_HEIGHT, float('inf')), 154 | margin=_STAND_HEIGHT / 2) 155 | upright = (1 + physics.torso_upright()) / 2 156 | stand_reward = (3 * standing + upright) / 4 157 | 158 | if self._flip: 159 | move_reward = rewards.tolerance(self._forward * 160 | physics.angmomentum(), 161 | bounds=(_SPIN_SPEED, float('inf')), 162 | margin=_SPIN_SPEED, 163 | value_at_margin=0, 164 | sigmoid='linear') 165 | else: 166 | move_reward = rewards.tolerance( 167 | self._forward * physics.horizontal_velocity(), 168 | bounds=(self._move_speed, float('inf')), 169 | margin=self._move_speed / 2, 170 | value_at_margin=0.5, 171 | sigmoid='linear') 172 | 173 | return stand_reward * (5 * move_reward + 1) / 6 174 | -------------------------------------------------------------------------------- /DMC_image/custom_dmc_tasks/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /DMC_image/dmc_benchmark.py: -------------------------------------------------------------------------------- 1 | DOMAINS = [ 2 | 'walker', 3 | 'quadruped', 4 | 'jaco', 5 | ] 6 | 7 | WALKER_TASKS = [ 8 | 'walker_stand', 9 | 'walker_walk', 10 | 'walker_run', 11 | 'walker_flip', 12 | ] 13 | 14 | QUADRUPED_TASKS = [ 15 | 'quadruped_walk', 16 | 'quadruped_run', 17 | 'quadruped_stand', 18 | 'quadruped_jump', 19 | ] 20 | 21 | JACO_TASKS = [ 22 | 'jaco_reach_top_left', 23 | 'jaco_reach_top_right', 24 | 'jaco_reach_bottom_left', 25 | 'jaco_reach_bottom_right', 26 | ] 27 | 28 | TASKS = WALKER_TASKS + QUADRUPED_TASKS + JACO_TASKS 29 | 30 | parameter_1 = ['0.4', '0.8', '1.0', '1.4'] 31 | parameter_1_eval = ['0.6', '1.2'] 32 | 33 | parameter_2 = ['0.2', '0.6', '1.0', '1.4', '1.8'] 34 | parameter_2_eval = ['0.4', '0.8', '1.2', '1.6'] 35 | 36 | PRETRAIN_TASKS = { 37 | 'walker': 'walker_stand', 38 | 'jaco': 'jaco_reach_top_left', 39 | 'quadruped': 'quadruped_walk', 40 | 'walker_mass': ['walker_stand~mass~' + para for para in parameter_2], 41 | 'quadruped_mass': ['quadruped_stand~mass~' + para for para in parameter_1], 42 | 'quadruped_damping': ['quadruped_stand~damping~' + para for para in parameter_2], 43 | } 44 | 45 | FINETUNE_TASKS = { 46 | 'walker_stand_mass': ['walker_stand~mass~' + para for para in parameter_2], 47 | 'walker_stand_mass_eval': ['walker_stand~mass~' + para for para in parameter_2_eval], 48 | 'walker_walk_mass': ['walker_walk~mass~' + para for para in parameter_2], 49 | 'walker_walk_mass_eval': ['walker_walk~mass~' + para for para in parameter_2_eval], 50 | 'walker_run_mass': ['walker_run~mass~' + para for para in parameter_2], 51 | 'walker_run_mass_eval': ['walker_run~mass~' + para for para in parameter_2_eval], 52 | 'walker_flip_mass': ['walker_flip~mass~' + para for para in parameter_2], 53 | 'walker_flip_mass_eval': ['walker_flip~mass~' + para for para in 
parameter_2_eval], 54 | 55 | 'quadruped_stand_mass': ['quadruped_stand~mass~' + para for para in parameter_1], 56 | 'quadruped_stand_mass_eval': ['quadruped_stand~mass~' + para for para in parameter_1_eval], 57 | 'quadruped_walk_mass': ['quadruped_walk~mass~' + para for para in parameter_1], 58 | 'quadruped_walk_mass_eval': ['quadruped_walk~mass~' + para for para in parameter_1_eval], 59 | 'quadruped_run_mass': ['quadruped_run~mass~' + para for para in parameter_1], 60 | 'quadruped_run_mass_eval': ['quadruped_run~mass~' + para for para in parameter_1_eval], 61 | 'quadruped_jump_mass': ['quadruped_jump~mass~' + para for para in parameter_1], 62 | 'quadruped_jump_mass_eval': ['quadruped_jump~mass~' + para for para in parameter_1_eval], 63 | 64 | 'quadruped_stand_damping': ['quadruped_stand~damping~' + para for para in parameter_2], 65 | 'quadruped_stand_damping_eval': ['quadruped_stand~damping~' + para for para in parameter_2_eval], 66 | 'quadruped_walk_damping': ['quadruped_walk~damping~' + para for para in parameter_2], 67 | 'quadruped_walk_damping_eval': ['quadruped_walk~damping~' + para for para in parameter_2_eval], 68 | 'quadruped_run_damping': ['quadruped_run~damping~' + para for para in parameter_2], 69 | 'quadruped_run_damping_eval': ['quadruped_run~damping~' + para for para in parameter_2_eval], 70 | 'quadruped_jump_damping': ['quadruped_jump~damping~' + para for para in parameter_2], 71 | 'quadruped_jump_damping_eval': ['quadruped_jump~damping~' + para for para in parameter_2_eval], 72 | } 73 | -------------------------------------------------------------------------------- /DMC_image/dreamer_finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - configs/dreamer 3 | - configs/agent: dreamer 4 | - configs: ${configs} 5 | - override hydra/launcher: submitit_local 6 | 7 | # mode 8 | reward_free: false 9 | # task settings 10 | #task: walker_stand 11 | task: none 12 | domain: walker_mass 13 | finetune_domain: walker_stand_mass 14 | # train settings 15 | num_train_frames: 100010 16 | num_seed_frames: 4000 17 | # eval 18 | eval_every_frames: 10000 19 | num_eval_episodes: 10 20 | # pretrained 21 | snapshot_ts: 100000 22 | snapshot_base_dir: ./pretrained_models 23 | custom_snap_dir: none 24 | # replay buffer 25 | replay_buffer_size: 1000000 26 | replay_buffer_num_workers: 4 27 | # misc 28 | seed: 1 29 | device: cuda 30 | save_video: false 31 | save_train_video: false 32 | save_eval_episodes: false 33 | use_tb: true 34 | use_wandb: true 35 | # experiment 36 | experiment: ft 37 | project_name: ??? 
38 | 39 | # log settings 40 | log_every_frames: 1000 41 | recon_every_frames: 100000000 # edit for debug 42 | 43 | # planning 44 | mpc: false 45 | mpc_opt: { iterations : 12, num_samples : 512, num_elites : 64, mixture_coef : 0.05, min_std : 0.1, temperature : 0.5, momentum : 0.1, horizon : 5, use_value: true } 46 | 47 | # Pretrained network reuse 48 | init_critic: false 49 | init_actor: true 50 | init_task: 1.0 51 | 52 | # Fine-tuning ablation 53 | # we have saved the last model 54 | # save_ft_model: true 55 | save_ft_model: false 56 | 57 | # Dreamer FT 58 | grad_heads: [decoder, reward] 59 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 60 | actor_ent: 1e-4 61 | 62 | hydra: 63 | run: 64 | dir: ./exp_local/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed} 65 | sweep: 66 | dir: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment} 67 | subdir: ${hydra.job.num} 68 | launcher: 69 | timeout_min: 4300 70 | cpus_per_task: 10 71 | gpus_per_node: 1 72 | tasks_per_node: 1 73 | mem_gb: 160 74 | nodes: 1 75 | submitit_folder: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment}/.slurm -------------------------------------------------------------------------------- /DMC_image/dreamer_pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - configs/dreamer 3 | - configs/agent: dreamer 4 | - configs: ${configs} 5 | - override hydra/launcher: submitit_local 6 | 7 | # mode 8 | reward_free: true 9 | # task settings 10 | task: none 11 | domain: walker # primal task will be inferred in runtime 12 | # train settings 13 | num_train_frames: 2000010 14 | num_seed_frames: 4000 15 | # eval 16 | eval_every_frames: 10000 17 | num_eval_episodes: 10 18 | # snapshot 19 | snapshots: [100000, 500000, 1000000, 2000000] 20 | snapshot_dir: ../../../../../pretrained_models/${obs_type}/${domain}/${agent.name}/${seed} 21 | # replay buffer 22 | replay_buffer_size: 1000000 23 | replay_buffer_num_workers: 4 24 | # misc 25 | seed: 1 26 | device: cuda 27 | save_video: false 28 | save_train_video: true 29 | use_tb: true 30 | use_wandb: true 31 | 32 | # experiment 33 | experiment: pt 34 | project_name: ??? 
35 | 36 | # log settings 37 | log_every_frames: 1000 38 | recon_every_frames: 100000000 # edit for debug 39 | 40 | 41 | hydra: 42 | run: 43 | dir: ./exp_local/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M%S}_${seed} 44 | sweep: 45 | dir: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M}_${seed}_${experiment} 46 | subdir: ${hydra.job.num} 47 | launcher: 48 | timeout_min: 4300 49 | cpus_per_task: 10 50 | gpus_per_node: 1 51 | tasks_per_node: 1 52 | mem_gb: 160 53 | nodes: 1 54 | submitit_folder: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%H%M}_${seed}_${experiment}/.slurm 55 | -------------------------------------------------------------------------------- /DMC_image/dreamer_replay.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import collections 6 | import datetime 7 | import io 8 | import pathlib 9 | import uuid 10 | 11 | import numpy as np 12 | import random 13 | from torch.utils.data import IterableDataset, DataLoader 14 | import torch 15 | import utils 16 | 17 | 18 | class ReplayBuffer(IterableDataset): 19 | 20 | def __init__( 21 | self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0, 22 | prioritize_ends=False, device='cpu', load_first=False): 23 | self._directory = pathlib.Path(directory).expanduser() 24 | self._directory.mkdir(parents=True, exist_ok=True) 25 | self._capacity = capacity 26 | self._ongoing = ongoing 27 | self._minlen = minlen 28 | self._maxlen = maxlen 29 | self._prioritize_ends = prioritize_ends 30 | self._random = np.random.RandomState() 31 | # filename -> key -> value_sequence 32 | self._complete_eps = load_episodes(self._directory, capacity, minlen, load_first=load_first) 33 | # worker -> key -> value_sequence 34 | self._ongoing_eps = collections.defaultdict( 35 | lambda: collections.defaultdict(list)) 36 | self._total_episodes, self._total_steps = count_episodes(directory) 37 | self._loaded_episodes = len(self._complete_eps) 38 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values()) 39 | self._length = length 40 | self._data_specs = data_specs 41 | self._meta_specs = meta_specs 42 | self.device = device 43 | try: 44 | assert self._minlen <= self._length <= self._maxlen 45 | except: 46 | print("Incosistency between min/max/length in the replay buffer. 
Defaulting to (length): ", length) 47 | self._minlen = self._maxlen = self._length = length 48 | 49 | def __len__(self): 50 | return self._total_steps 51 | 52 | @property 53 | def stats(self): 54 | return { 55 | 'total_steps': self._total_steps, 56 | 'total_episodes': self._total_episodes, 57 | 'loaded_steps': self._loaded_steps, 58 | 'loaded_episodes': self._loaded_episodes, 59 | } 60 | 61 | def add(self, time_step, meta, worker=0, env_id=None): 62 | episode = self._ongoing_eps[worker] 63 | for spec in self._data_specs: 64 | value = time_step[spec.name] 65 | if np.isscalar(value): 66 | value = np.full(spec.shape, value, spec.dtype) 67 | assert spec.shape == value.shape and spec.dtype == value.dtype 68 | episode[spec.name].append(value) 69 | for spec in self._meta_specs: 70 | value = meta[spec.name] 71 | if np.isscalar(value): 72 | value = np.full(spec.shape, value, spec.dtype) 73 | assert spec.shape == value.shape and spec.dtype == value.dtype 74 | episode[spec.name].append(value) 75 | if type(time_step) == dict: 76 | if time_step['is_last']: 77 | self.add_episode(episode, env_id=env_id) 78 | episode.clear() 79 | else: 80 | if bool(dreamer_obs['is_last']): 81 | self.add_episode(episode, env_id=env_id) 82 | episode.clear() 83 | 84 | def add_episode(self, episode, env_id=None): 85 | length = eplen(episode) 86 | if length < self._minlen: 87 | print(f'Skipping short episode of length {length}.') 88 | return 89 | self._total_steps += length 90 | self._loaded_steps += length 91 | self._total_episodes += 1 92 | self._loaded_episodes += 1 93 | episode = {key: convert(value) for key, value in episode.items()} 94 | filename = save_episode(self._directory, episode, env_id=env_id) 95 | self._complete_eps[str(filename)] = episode 96 | self._enforce_limit() 97 | 98 | def __iter__(self): 99 | sequence = self._sample_sequence() 100 | while True: 101 | chunk = collections.defaultdict(list) 102 | added = 0 103 | while added < self._length: 104 | needed = self._length - added 105 | adding = {k: v[:needed] for k, v in sequence.items()} 106 | sequence = {k: v[needed:] for k, v in sequence.items()} 107 | for key, value in adding.items(): 108 | chunk[key].append(value) 109 | added += len(adding['action']) 110 | if len(sequence['action']) < 1: 111 | sequence = self._sample_sequence() 112 | chunk = {k: np.concatenate(v) for k, v in chunk.items()} 113 | chunk['is_terminal'] = chunk['discount'] == 0 114 | chunk = {k: torch.as_tensor(np.copy(v), device=self.device) for k, v in chunk.items()} 115 | yield chunk 116 | 117 | def _sample_sequence(self): 118 | episodes = list(self._complete_eps.values()) 119 | if self._ongoing: 120 | episodes += [ 121 | x for x in self._ongoing_eps.values() 122 | if eplen(x) >= self._minlen] 123 | episode = self._random.choice(episodes) 124 | total = len(episode['action']) 125 | length = total 126 | if self._maxlen: 127 | length = min(length, self._maxlen) 128 | # Randomize length to avoid all chunks ending at the same time in case the 129 | # episodes are all of the same length. 
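# Worked example with assumed settings minlen=10 and maxlen=50 on a 100-step
# episode: length = min(100, 50) = 50, then subtracting randint(10) in [0, 9]
# leaves 41..50 (the max() below still clamps it to at least minlen);
# upper = total - length + 1 valid start indices (plus minlen slack when
# prioritize_ends is set), and the final min(..., total - length) guarantees
# the sampled chunk stays inside the episode.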
130 | length -= np.random.randint(self._minlen) 131 | length = max(self._minlen, length) 132 | upper = total - length + 1 133 | if self._prioritize_ends: 134 | upper += self._minlen 135 | index = min(self._random.randint(upper), total - length) 136 | sequence = { 137 | k: convert(v[index: index + length]) 138 | for k, v in episode.items() if not k.startswith('log_')} 139 | # np.bool -> bool 140 | sequence['is_first'] = np.zeros(len(sequence['action']), bool) 141 | sequence['is_first'][0] = True 142 | if self._maxlen: 143 | assert self._minlen <= len(sequence['action']) <= self._maxlen 144 | return sequence 145 | 146 | def _enforce_limit(self): 147 | if not self._capacity: 148 | return 149 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity: 150 | # Relying on Python preserving the insertion order of dicts. 151 | oldest, episode = next(iter(self._complete_eps.items())) 152 | self._loaded_steps -= eplen(episode) 153 | self._loaded_episodes -= 1 154 | del self._complete_eps[oldest] 155 | 156 | 157 | def count_episodes(directory): 158 | filenames = list(directory.glob('*.npz')) 159 | num_episodes = len(filenames) 160 | if len(filenames) > 0 and "-" in str(filenames[0]): 161 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames) 162 | else: 163 | num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames) 164 | return num_episodes, num_steps 165 | 166 | 167 | @utils.retry 168 | def save_episode(directory, episode, env_id=None): 169 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 170 | identifier = str(uuid.uuid4().hex) 171 | length = eplen(episode) 172 | if env_id is None: 173 | filename = directory / f'{timestamp}-{identifier}-{length}.npz' 174 | else: 175 | filename = directory / f'{env_id}-{timestamp}-{identifier}-{length}.npz' 176 | with io.BytesIO() as f1: 177 | np.savez_compressed(f1, **episode) 178 | f1.seek(0) 179 | with filename.open('wb') as f2: 180 | f2.write(f1.read()) 181 | return filename 182 | 183 | 184 | def load_episodes(directory, capacity=None, minlen=1, load_first=False): 185 | # The returned directory from filenames to episodes is guaranteed to be in 186 | # temporally sorted order. 
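# Episode files are named `{timestamp}-{uuid}-{length}.npz`, or
# `{env_id}-{timestamp}-{uuid}-{length}.npz` when an env_id is given (see
# save_episode above); older underscore-separated names are also handled.
# The trailing field is parsed as the episode length, and when a capacity is
# set, episodes are counted from the newest file backwards until the step
# budget is met, unless load_first=True, which keeps the oldest files instead.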
187 | filenames = sorted(directory.glob('*.npz')) 188 | if capacity: 189 | num_steps = 0 190 | num_episodes = 0 191 | ordered_filenames = filenames if load_first else reversed(filenames) 192 | for filename in ordered_filenames: 193 | if "-" in str(filename): 194 | length = int(str(filename).split('-')[-1][:-4]) 195 | else: 196 | length = int(str(filename).split('_')[-1][:-4]) 197 | num_steps += length 198 | num_episodes += 1 199 | if num_steps >= capacity: 200 | break 201 | if load_first: 202 | filenames = filenames[:num_episodes] 203 | else: 204 | filenames = filenames[-num_episodes:] 205 | episodes = {} 206 | for filename in filenames: 207 | try: 208 | with filename.open('rb') as f: 209 | episode = np.load(f) 210 | episode = {k: episode[k] for k in episode.keys()} 211 | except Exception as e: 212 | print(f'Could not load episode {str(filename)}: {e}') 213 | continue 214 | episodes[str(filename)] = episode 215 | return episodes 216 | 217 | 218 | def convert(value): 219 | value = np.array(value) 220 | if np.issubdtype(value.dtype, np.floating): 221 | return value.astype(np.float32) 222 | elif np.issubdtype(value.dtype, np.signedinteger): 223 | return value.astype(np.int32) 224 | elif np.issubdtype(value.dtype, np.uint8): 225 | return value.astype(np.uint8) 226 | return value 227 | 228 | 229 | def eplen(episode): 230 | return len(episode['action']) - 1 231 | 232 | 233 | def _worker_init_fn(worker_id): 234 | seed = np.random.get_state()[1][0] + worker_id 235 | np.random.seed(seed) 236 | random.seed(seed) 237 | 238 | 239 | def make_replay_loader(buffer, batch_size, num_workers): 240 | return DataLoader(buffer, 241 | batch_size=batch_size, 242 | drop_last=True, 243 | ) 244 | -------------------------------------------------------------------------------- /DMC_image/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import wandb 9 | from termcolor import colored 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 14 | ('episode_reward', 'R', 'float'), 15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time'),] 16 | 17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 19 | ('episode_reward', 'R', 'float'), 20 | ('total_time', 'T', 'time'), 21 | ('episode_train_reward', 'TR', 'float'), 22 | ('episode_eval_reward', 'ER', 'float')] 23 | 24 | 25 | class AverageMeter(object): 26 | def __init__(self): 27 | self._sum = 0 28 | self._count = 0 29 | 30 | def update(self, value, n=1): 31 | self._sum += value 32 | self._count += n 33 | 34 | def value(self): 35 | return self._sum / max(1, self._count) 36 | 37 | 38 | class MetersGroup(object): 39 | def __init__(self, csv_file_name, formating, use_wandb): 40 | self._csv_file_name = csv_file_name 41 | self._formating = formating 42 | self._meters = defaultdict(AverageMeter) 43 | self._csv_file = None 44 | self._csv_writer = None 45 | self.use_wandb = use_wandb 46 | 47 | def log(self, key, value, n=1): 48 | self._meters[key].update(value, n) 49 | 50 | def _prime_meters(self): 51 | data = dict() 52 | for key, meter in self._meters.items(): 53 | if key.startswith('train'): 54 | key = key[len('train') + 1:] 55 | else: 56 | key = key[len('eval') + 1:] 57 | key = 
key.replace('/', '_') 58 | data[key] = meter.value() 59 | return data 60 | 61 | def _remove_old_entries(self, data): 62 | rows = [] 63 | with self._csv_file_name.open('r') as f: 64 | reader = csv.DictReader(f) 65 | for row in reader: 66 | if 'episode' in row: 67 | # BUGFIX: covers weird cases where CSV are badly written 68 | if row['episode'] == '': 69 | rows.append(row) 70 | continue 71 | if type(row['episode']) == type(None): 72 | continue 73 | if float(row['episode']) >= data['episode']: 74 | break 75 | rows.append(row) 76 | with self._csv_file_name.open('w') as f: 77 | # To handle CSV that have more keys than new data 78 | keys = set(data.keys()) 79 | if len(rows) > 0: keys = keys | set(row.keys()) 80 | keys = sorted(list(keys)) 81 | # 82 | writer = csv.DictWriter(f, 83 | fieldnames=keys, 84 | restval=0.0) 85 | writer.writeheader() 86 | for row in rows: 87 | writer.writerow(row) 88 | 89 | def _dump_to_csv(self, data): 90 | if self._csv_writer is None: 91 | should_write_header = True 92 | if self._csv_file_name.exists(): 93 | self._remove_old_entries(data) 94 | should_write_header = False 95 | 96 | self._csv_file = self._csv_file_name.open('a') 97 | self._csv_writer = csv.DictWriter(self._csv_file, 98 | fieldnames=sorted(data.keys()), 99 | restval=0.0) 100 | if should_write_header: 101 | self._csv_writer.writeheader() 102 | 103 | # To handle components that start training later 104 | # (restval covers only when data has less keys than the CSV) 105 | if self._csv_writer.fieldnames != sorted(data.keys()) and \ 106 | len(self._csv_writer.fieldnames) < len(data.keys()): 107 | self._csv_file.close() 108 | self._csv_file = self._csv_file_name.open('r') 109 | dict_reader = csv.DictReader(self._csv_file) 110 | rows = [row for row in dict_reader] 111 | self._csv_file.close() 112 | self._csv_file = self._csv_file_name.open('w') 113 | self._csv_writer = csv.DictWriter(self._csv_file, 114 | fieldnames=sorted(data.keys()), 115 | restval=0.0) 116 | self._csv_writer.writeheader() 117 | for row in rows: 118 | self._csv_writer.writerow(row) 119 | 120 | self._csv_writer.writerow(data) 121 | self._csv_file.flush() 122 | 123 | def _format(self, key, value, ty): 124 | if ty == 'int': 125 | value = int(value) 126 | return f'{key}: {value}' 127 | elif ty == 'float': 128 | return f'{key}: {value:.04f}' 129 | elif ty == 'time': 130 | value = str(datetime.timedelta(seconds=int(value))) 131 | return f'{key}: {value}' 132 | elif ty == 'str': 133 | return f'{key}: {value}' 134 | else: 135 | raise f'invalid format type: {ty}' 136 | 137 | def _dump_to_console(self, data, prefix): 138 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 139 | pieces = [f'| {prefix: <14}'] 140 | for key, disp_key, ty in self._formating: 141 | value = data.get(key, 0) 142 | pieces.append(self._format(disp_key, value, ty)) 143 | print(' | '.join(pieces)) 144 | 145 | def _dump_to_wandb(self, data): 146 | wandb.log(data) 147 | 148 | def dump(self, step, prefix): 149 | if len(self._meters) == 0: 150 | return 151 | data = self._prime_meters() 152 | data['frame'] = step 153 | if self.use_wandb: 154 | wandb_data = {prefix + '/' + key: val for key, val in data.items()} 155 | self._dump_to_wandb(data=wandb_data) 156 | self._dump_to_csv(data) 157 | self._dump_to_console(data, prefix) 158 | self._meters.clear() 159 | 160 | 161 | class Logger(object): 162 | def __init__(self, log_dir, use_tb, use_wandb): 163 | self._log_dir = log_dir 164 | self._train_mg = MetersGroup(log_dir / 'train.csv', 165 | formating=COMMON_TRAIN_FORMAT, 166 
| use_wandb=use_wandb) 167 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 168 | formating=COMMON_EVAL_FORMAT, 169 | use_wandb=use_wandb) 170 | if use_tb: 171 | self._sw = SummaryWriter(str(log_dir / 'tb')) 172 | else: 173 | self._sw = None 174 | self.use_wandb = use_wandb 175 | 176 | def _try_sw_log(self, key, value, step): 177 | if self._sw is not None: 178 | self._sw.add_scalar(key, value, step) 179 | 180 | def log(self, key, value, step): 181 | assert key.startswith('train') or key.startswith('eval') 182 | if type(value) == torch.Tensor: 183 | value = value.item() 184 | # print(key) 185 | self._try_sw_log(key, value, step) 186 | mg = self._train_mg if key.startswith('train') else self._eval_mg 187 | mg.log(key, value) 188 | 189 | def log_metrics(self, metrics, step, ty): 190 | for key, value in metrics.items(): 191 | self.log(f'{ty}/{key}', value, step) 192 | 193 | def dump(self, step, ty=None): 194 | if ty is None or ty == 'eval': 195 | self._eval_mg.dump(step, 'eval') 196 | if ty is None or ty == 'train': 197 | self._train_mg.dump(step, 'train') 198 | 199 | def log_and_dump_ctx(self, step, ty): 200 | return LogAndDumpCtx(self, step, ty) 201 | 202 | def log_video(self, data, step): 203 | if self._sw is not None: 204 | for k, v in data.items(): 205 | self._sw.add_video(k, v, global_step=step, fps=15) 206 | if self.use_wandb: 207 | for k, v in data.items(): 208 | v = np.uint8(v.cpu() * 255) 209 | wandb.log({k: wandb.Video(v, fps=15, format="gif")}) 210 | 211 | 212 | class LogAndDumpCtx: 213 | def __init__(self, logger, step, ty): 214 | self._logger = logger 215 | self._step = step 216 | self._ty = ty 217 | 218 | def __enter__(self): 219 | return self 220 | 221 | def __call__(self, key, value): 222 | self._logger.log(f'{self._ty}/{key}', value, self._step) 223 | 224 | def __exit__(self, *args): 225 | self._logger.dump(self._step, self._ty) 226 | -------------------------------------------------------------------------------- /DMC_image/train_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ALGO=$1 4 | DOMAIN=$2 # walker_mass, quadruped_mass, quadruped_damping 5 | GPU_ID=$3 6 | 7 | if [ "$DOMAIN" == "walker_mass" ] 8 | then 9 | ALL_TASKS=("walker_stand_mass" "walker_walk_mass" "walker_run_mass" "walker_flip_mass") 10 | elif [ "$DOMAIN" == "quadruped_mass" ] 11 | then 12 | ALL_TASKS=("quadruped_stand_mass" "quadruped_walk_mass" "quadruped_run_mass" "quadruped_jump_mass") 13 | elif [ "$DOMAIN" == "quadruped_damping" ] 14 | then 15 | ALL_TASKS=("quadruped_stand_damping" "quadruped_walk_damping" "quadruped_run_damping" "quadruped_jump_damping") 16 | else 17 | ALL_TASKS=() 18 | echo "No matching tasks, you can only take DOMAIN as walker_mass, quadruped_mass, or quadruped_damping" 19 | exit 0 20 | fi 21 | 22 | echo "Experiments started." 23 | for seed in $(seq 0 2) 24 | do 25 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID} 26 | python dreamer_pretrain.py configs=dmc_pixels configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID} 27 | for string in "${ALL_TASKS[@]}" 28 | do 29 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID} 30 | python dreamer_finetune.py configs=dmc_pixels configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID} snapshot_ts=2000000 finetune_domain=$string 31 | done 32 | done 33 | echo "Experiments ended." 34 | 35 | # e.g. 
36 | # ./finetune.sh peac_lbs walker_mass 0 -------------------------------------------------------------------------------- /DMC_state/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /DMC_state/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CEURL/c213c2936aa2394e645c069e9c8333ee6bd0ce0a/DMC_state/README.md -------------------------------------------------------------------------------- /DMC_state/agent/becl.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | from collections import OrderedDict 4 | 5 | import hydra 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from dm_env import specs 12 | 13 | import utils 14 | from agent.ddpg import DDPGAgent 15 | 16 | 17 | class BECL(nn.Module): 18 | def __init__(self, tau_dim, feature_dim, hidden_dim): 19 | super().__init__() 20 | 21 | self.embed = nn.Sequential(nn.Linear(tau_dim, hidden_dim), 22 | nn.ReLU(), 23 | nn.Linear(hidden_dim, hidden_dim), 24 | nn.ReLU(), 25 | nn.Linear(hidden_dim, feature_dim)) 26 | 27 | self.project_head = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 28 | nn.ReLU(), 29 | nn.Linear(hidden_dim, feature_dim)) 30 | self.apply(utils.weight_init) 31 | 32 | def forward(self, tau): 33 | features = self.embed(tau) 34 | features = self.project_head(features) 35 | return features 36 | 37 | 38 | class BECLAgent(DDPGAgent): 39 | def __init__(self, update_skill_every_step, skill_dim, 40 | update_encoder, contrastive_update_rate, temperature, skill, **kwargs): 41 | self.skill_dim = skill_dim 42 | self.update_skill_every_step = update_skill_every_step 43 | self.update_encoder = update_encoder 44 | self.contrastive_update_rate = contrastive_update_rate 45 | self.temperature = temperature 46 | # specify skill in fine-tuning stage if needed 47 | self.skill = int(skill) if skill >= 0 else np.random.choice(self.skill_dim) 48 | # increase obs shape to include skill dim 49 | kwargs["meta_dim"] = self.skill_dim 50 | self.batch_size = kwargs['batch_size'] 51 | # create actor and critic 52 | super().__init__(**kwargs) 53 | 54 | 
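# The BECL embedding network sees only the raw observation, hence the
# obs_dim - skill_dim input size (the one-hot skill that update() concatenates
# to observations is excluded); it maps states to skill_dim features that the
# InfoNCE-style objective below pulls together for same-skill samples.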
# net 55 | self.becl = BECL(self.obs_dim - self.skill_dim, 56 | self.skill_dim, 57 | kwargs['hidden_dim']).to(kwargs['device']) 58 | 59 | # optimizers 60 | self.becl_opt = torch.optim.Adam(self.becl.parameters(), lr=self.lr) 61 | 62 | self.becl.train() 63 | 64 | def get_meta_specs(self): 65 | return specs.Array((self.skill_dim,), np.float32, 'skill'), 66 | 67 | def init_meta(self): 68 | skill = np.zeros(self.skill_dim).astype(np.float32) 69 | if not self.reward_free: 70 | skill[self.skill] = 1.0 71 | else: 72 | skill[np.random.choice(self.skill_dim)] = 1.0 73 | meta = OrderedDict() 74 | meta['skill'] = skill 75 | return meta 76 | 77 | def update_meta(self, meta, global_step, time_step, finetune=False): 78 | if global_step % self.update_skill_every_step == 0: 79 | return self.init_meta() 80 | return meta 81 | 82 | def update_contrastive(self, state, skills): 83 | metrics = dict() 84 | features = self.becl(state) 85 | logits = self.compute_info_nce_loss(features, skills) 86 | loss = logits.mean() 87 | 88 | self.becl_opt.zero_grad() 89 | if self.encoder_opt is not None: 90 | self.encoder_opt.zero_grad(set_to_none=True) 91 | loss.backward() 92 | self.becl_opt.step() 93 | if self.encoder_opt is not None: 94 | self.encoder_opt.step() 95 | 96 | if self.use_tb or self.use_wandb: 97 | metrics['contrastive_loss'] = loss.item() 98 | 99 | return metrics 100 | 101 | def compute_intr_reward(self, skills, state, metrics): 102 | 103 | # compute contrastive reward 104 | features = self.becl(state) 105 | contrastive_reward = torch.exp(-self.compute_info_nce_loss(features, skills)) 106 | 107 | intr_reward = contrastive_reward 108 | if self.use_tb or self.use_wandb: 109 | metrics['contrastive_reward'] = contrastive_reward.mean().item() 110 | 111 | return intr_reward 112 | 113 | def compute_info_nce_loss(self, features, skills): 114 | # features: (b,c), skills :(b, skill_dim) 115 | # label positives samples 116 | labels = torch.argmax(skills, dim=-1) # (b, 1) 117 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).long() # (b,b) 118 | labels = labels.to(self.device) 119 | 120 | features = F.normalize(features, dim=1) # (b,c) 121 | similarity_matrix = torch.matmul(features, features.T) # (b,b) 122 | 123 | # discard the main diagonal from both: labels and similarities matrix 124 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device) 125 | labels = labels[~mask].view(labels.shape[0], -1) # (b,b-1) 126 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) # (b,b-1) 127 | 128 | similarity_matrix = similarity_matrix / self.temperature 129 | similarity_matrix -= torch.max(similarity_matrix, 1)[0][:, None] 130 | similarity_matrix = torch.exp(similarity_matrix) 131 | 132 | pick_one_positive_sample_idx = torch.argmax(labels, dim=-1, keepdim=True) 133 | pick_one_positive_sample_idx = torch.zeros_like(labels).scatter_(-1, pick_one_positive_sample_idx, 1) 134 | 135 | positives = torch.sum(similarity_matrix * pick_one_positive_sample_idx, dim=-1, keepdim=True) # (b,1) 136 | negatives = torch.sum(similarity_matrix, dim=-1, keepdim=True) # (b,1) 137 | eps = torch.as_tensor(1e-6) 138 | loss = -torch.log(positives / (negatives + eps) + eps) # (b,1) 139 | 140 | return loss 141 | 142 | def update(self, replay_iter, step): 143 | metrics = dict() 144 | 145 | if step % self.update_every_steps != 0: 146 | return metrics 147 | 148 | if self.reward_free: 149 | 150 | batch = next(replay_iter) 151 | obs, action, reward, discount, next_obs, task_id, skill = utils.to_torch(batch, 
self.device) 152 | obs = self.aug_and_encode(obs) 153 | next_obs = self.aug_and_encode(next_obs) 154 | 155 | metrics.update(self.update_contrastive(next_obs, skill)) 156 | 157 | for _ in range(self.contrastive_update_rate - 1): 158 | batch = next(replay_iter) 159 | obs, action, reward, discount, next_obs, task_id, skill = utils.to_torch(batch, self.device) 160 | obs = self.aug_and_encode(obs) 161 | next_obs = self.aug_and_encode(next_obs) 162 | 163 | metrics.update(self.update_contrastive(next_obs, skill)) 164 | 165 | with torch.no_grad(): 166 | intr_reward = self.compute_intr_reward(skill, next_obs, metrics) 167 | 168 | if self.use_tb or self.use_wandb: 169 | metrics['intr_reward'] = intr_reward.mean().item() 170 | 171 | reward = intr_reward 172 | else: 173 | batch = next(replay_iter) 174 | 175 | obs, action, extr_reward, discount, next_obs, task_id, skill = utils.to_torch( 176 | batch, self.device) 177 | obs = self.aug_and_encode(obs) 178 | next_obs = self.aug_and_encode(next_obs) 179 | reward = extr_reward 180 | 181 | if self.use_tb or self.use_wandb: 182 | metrics['batch_reward'] = reward.mean().item() 183 | 184 | if not self.update_encoder: 185 | obs = obs.detach() 186 | next_obs = next_obs.detach() 187 | 188 | # extend observations with skill 189 | obs = torch.cat([obs, skill], dim=1) 190 | next_obs = torch.cat([next_obs, skill], dim=1) 191 | 192 | # update critic 193 | metrics.update( 194 | self.update_critic(obs.detach(), action, reward, discount, 195 | next_obs.detach(), step)) 196 | 197 | # update actor 198 | metrics.update(self.update_actor(obs.detach(), step)) 199 | 200 | # update critic target 201 | utils.soft_update_params(self.critic, self.critic_target, 202 | self.critic_target_tau) 203 | 204 | return metrics -------------------------------------------------------------------------------- /DMC_state/agent/cic.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from dm_env import specs 7 | import math 8 | from collections import OrderedDict 9 | 10 | import utils 11 | 12 | from agent.ddpg import DDPGAgent 13 | 14 | 15 | class CIC(nn.Module): 16 | def __init__(self, obs_dim, skill_dim, hidden_dim, project_skill): 17 | super().__init__() 18 | self.obs_dim = obs_dim 19 | self.skill_dim = skill_dim 20 | 21 | self.state_net = nn.Sequential(nn.Linear(self.obs_dim, hidden_dim), nn.ReLU(), 22 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 23 | nn.Linear(hidden_dim, self.skill_dim)) 24 | 25 | self.next_state_net = nn.Sequential(nn.Linear(self.obs_dim, hidden_dim), nn.ReLU(), 26 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 27 | nn.Linear(hidden_dim, self.skill_dim)) 28 | 29 | self.pred_net = nn.Sequential(nn.Linear(2 * self.skill_dim, hidden_dim), nn.ReLU(), 30 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 31 | nn.Linear(hidden_dim, self.skill_dim)) 32 | 33 | if project_skill: 34 | self.skill_net = nn.Sequential(nn.Linear(self.skill_dim, hidden_dim), nn.ReLU(), 35 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 36 | nn.Linear(hidden_dim, self.skill_dim)) 37 | else: 38 | self.skill_net = nn.Identity() 39 | 40 | self.apply(utils.weight_init) 41 | 42 | def forward(self, state, next_state, skill): 43 | assert len(state.size()) == len(next_state.size()) 44 | state = self.state_net(state) 45 | next_state = self.state_net(next_state) 46 | query = self.skill_net(skill) 47 | key = self.pred_net(torch.cat([state, next_state], 1)) 48 | 
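# query is the (optionally projected) skill embedding and key is a prediction
# built from the encoded (state, next_state) transition; compute_cpc_loss
# below normalizes both and contrasts them with a temperature-scaled
# InfoNCE-style objective, so matching skill/transition pairs score highest.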
return query, key 49 | 50 | 51 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | 53 | 54 | class RMS(object): 55 | def __init__(self, epsilon=1e-4, shape=(1,), device='cpu'): 56 | self.M = torch.zeros(shape).to(device) 57 | self.S = torch.ones(shape).to(device) 58 | self.n = epsilon 59 | 60 | def __call__(self, x): 61 | bs = x.size(0) 62 | delta = torch.mean(x, dim=0) - self.M 63 | new_M = self.M + delta * bs / (self.n + bs) 64 | new_S = (self.S * self.n + torch.var(x, dim=0) * bs + (delta ** 2) * self.n * bs / (self.n + bs)) / ( 65 | self.n + bs) 66 | 67 | self.M = new_M 68 | self.S = new_S 69 | self.n += bs 70 | 71 | return self.M, self.S 72 | 73 | 74 | class APTArgs: 75 | def __init__(self, knn_k=16, knn_avg=True, rms=True, knn_clip=0.0005, ): 76 | self.knn_k = knn_k 77 | self.knn_avg = knn_avg 78 | self.rms = rms 79 | self.knn_clip = knn_clip 80 | 81 | 82 | def compute_apt_reward(source, target, args, device, rms): 83 | b1, b2 = source.size(0), target.size(0) 84 | # (b1, 1, c) - (1, b2, c) -> (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2) 85 | sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) - target[None, :, :].view(1, b2, -1), dim=-1, p=2) 86 | reward, _ = sim_matrix.topk(args.knn_k, dim=1, largest=False, sorted=True) # (b1, k) 87 | 88 | if not args.knn_avg: # only keep k-th nearest neighbor 89 | reward = reward[:, -1] 90 | reward = reward.reshape(-1, 1) # (b1, 1) 91 | if args.rms: 92 | moving_mean, moving_std = rms(reward) 93 | reward = reward / moving_std 94 | reward = torch.max(reward - args.knn_clip, torch.zeros_like(reward).to(device)) # (b1, ) 95 | else: # average over all k nearest neighbors 96 | reward = reward.reshape(-1, 1) # (b1 * k, 1) 97 | if args.rms: 98 | moving_mean, moving_std = rms(reward) 99 | reward = reward / moving_std 100 | reward = torch.max(reward - args.knn_clip, torch.zeros_like(reward).to(device)) 101 | reward = reward.reshape((b1, args.knn_k)) # (b1, k) 102 | reward = reward.mean(dim=1) # (b1,) 103 | reward = torch.log(reward + 1.0) 104 | return reward 105 | 106 | 107 | class CICAgent(DDPGAgent): 108 | # Contrastive Intrinsic Control (CIC) 109 | def __init__(self, update_skill_every_step, skill_dim, scale, 110 | project_skill, rew_type, update_rep, temp, **kwargs): 111 | self.temp = temp 112 | self.skill_dim = skill_dim 113 | self.update_skill_every_step = update_skill_every_step 114 | self.scale = scale 115 | self.project_skill = project_skill 116 | self.rew_type = rew_type 117 | self.update_rep = update_rep 118 | kwargs["meta_dim"] = self.skill_dim 119 | # create actor and critic 120 | self.device = kwargs['device'] 121 | 122 | super().__init__(**kwargs) 123 | self.rms = RMS(device=self.device) 124 | 125 | # create cic first 126 | self.cic = CIC(self.obs_dim - skill_dim, skill_dim, 127 | kwargs['hidden_dim'], project_skill).to(self.device) 128 | 129 | # optimizers 130 | self.cic_optimizer = torch.optim.Adam(self.cic.parameters(), 131 | lr=self.lr) 132 | 133 | self.cic.train() 134 | 135 | def get_meta_specs(self): 136 | return (specs.Array((self.skill_dim,), np.float32, 'skill'),) 137 | 138 | def init_meta(self): 139 | if not self.reward_free: 140 | # selects mean skill of 0.5 (to select skill automatically use CEM or Grid Sweep 141 | # procedures described in the CIC paper) 142 | skill = np.ones(self.skill_dim).astype(np.float32) * 0.5 143 | else: 144 | skill = np.random.uniform(0, 1, self.skill_dim).astype(np.float32) 145 | meta = OrderedDict() 146 | meta['skill'] = skill 147 | return meta 148 | 149 | def 
update_meta(self, meta, step, time_step): 150 | if step % self.update_skill_every_step == 0: 151 | return self.init_meta() 152 | return meta 153 | 154 | def compute_cpc_loss(self, obs, next_obs, skill): 155 | temperature = self.temp 156 | eps = 1e-6 157 | query, key = self.cic.forward(obs, next_obs, skill) 158 | query = F.normalize(query, dim=1) 159 | key = F.normalize(key, dim=1) 160 | cov = torch.mm(query, key.T) # (b,b) 161 | sim = torch.exp(cov / temperature) 162 | neg = sim.sum(dim=-1) # (b,) 163 | row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device) 164 | neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 165 | 166 | pos = torch.exp(torch.sum(query * key, dim=-1) / temperature) # (b,) 167 | loss = -torch.log(pos / (neg + eps)) # (b,) 168 | return loss, cov / temperature 169 | 170 | def update_cic(self, obs, skill, next_obs, step): 171 | metrics = dict() 172 | 173 | loss, logits = self.compute_cpc_loss(obs, next_obs, skill) 174 | loss = loss.mean() 175 | self.cic_optimizer.zero_grad() 176 | loss.backward() 177 | self.cic_optimizer.step() 178 | 179 | if self.use_tb or self.use_wandb: 180 | metrics['cic_loss'] = loss.item() 181 | metrics['cic_logits'] = logits.norm() 182 | 183 | return metrics 184 | 185 | def compute_intr_reward(self, obs, skill, next_obs, step): 186 | 187 | with torch.no_grad(): 188 | loss, logits = self.compute_cpc_loss(obs, next_obs, skill) 189 | 190 | reward = loss 191 | reward = reward.clone().detach().unsqueeze(-1) 192 | 193 | return reward * self.scale 194 | 195 | @torch.no_grad() 196 | def compute_apt_reward(self, obs, next_obs): 197 | args = APTArgs() 198 | source = self.cic.state_net(obs) 199 | target = self.cic.state_net(next_obs) 200 | reward = compute_apt_reward(source, target, args, device=self.device, 201 | rms=self.rms) # (b,) 202 | return reward.unsqueeze(-1) # (b,1) 203 | 204 | def update(self, replay_iter, step): 205 | metrics = dict() 206 | 207 | if step % self.update_every_steps != 0: 208 | return metrics 209 | 210 | batch = next(replay_iter) 211 | 212 | obs, action, extr_reward, discount, next_obs, task_id, skill = utils.to_torch( 213 | batch, self.device) 214 | 215 | with torch.no_grad(): 216 | obs = self.aug_and_encode(obs) 217 | 218 | next_obs = self.aug_and_encode(next_obs) 219 | 220 | if self.reward_free: 221 | if self.update_rep: 222 | metrics.update(self.update_cic(obs, skill, next_obs, step)) 223 | 224 | intr_reward = self.compute_apt_reward(next_obs, next_obs) 225 | 226 | reward = intr_reward 227 | else: 228 | reward = extr_reward 229 | 230 | if self.use_tb or self.use_wandb: 231 | if self.reward_free: 232 | metrics['extr_reward'] = extr_reward.mean().item() 233 | # metrics['intr_reward'] = apt_reward.mean().item() 234 | metrics['batch_reward'] = reward.mean().item() 235 | 236 | # extend observations with skill 237 | obs = torch.cat([obs, skill], dim=1) 238 | next_obs = torch.cat([next_obs, skill], dim=1) 239 | 240 | # update critic 241 | metrics.update( 242 | self.update_critic(obs, action, reward, discount, next_obs, step)) 243 | 244 | # update actor 245 | metrics.update(self.update_actor(obs, step)) 246 | 247 | # update critic target 248 | utils.soft_update_params(self.critic, self.critic_target, 249 | self.critic_target_tau) 250 | 251 | return metrics -------------------------------------------------------------------------------- /DMC_state/agent/diayn.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import 
OrderedDict 3 | 4 | import hydra 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from dm_env import specs 10 | 11 | import utils 12 | from agent.ddpg import DDPGAgent 13 | 14 | 15 | class DIAYN(nn.Module): 16 | def __init__(self, obs_dim, skill_dim, hidden_dim): 17 | super().__init__() 18 | self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim), 19 | nn.ReLU(), 20 | nn.Linear(hidden_dim, hidden_dim), 21 | nn.ReLU(), 22 | nn.Linear(hidden_dim, skill_dim)) 23 | 24 | self.apply(utils.weight_init) 25 | 26 | def forward(self, obs): 27 | skill_pred = self.skill_pred_net(obs) 28 | return skill_pred 29 | 30 | 31 | class DIAYNAgent(DDPGAgent): 32 | def __init__(self, update_skill_every_step, skill_dim, diayn_scale, 33 | update_encoder, **kwargs): 34 | self.skill_dim = skill_dim 35 | self.update_skill_every_step = update_skill_every_step 36 | self.diayn_scale = diayn_scale 37 | self.update_encoder = update_encoder 38 | # increase obs shape to include skill dim 39 | kwargs["meta_dim"] = self.skill_dim 40 | 41 | # create actor and critic 42 | super().__init__(**kwargs) 43 | 44 | # create diayn 45 | self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim, 46 | kwargs['hidden_dim']).to(kwargs['device']) 47 | 48 | # loss criterion 49 | self.diayn_criterion = nn.CrossEntropyLoss() 50 | # optimizers 51 | self.diayn_opt = torch.optim.Adam(self.diayn.parameters(), lr=self.lr) 52 | 53 | self.diayn.train() 54 | 55 | def get_meta_specs(self): 56 | return (specs.Array((self.skill_dim,), np.float32, 'skill'),) 57 | 58 | def init_meta(self): 59 | skill = np.zeros(self.skill_dim, dtype=np.float32) 60 | skill[np.random.choice(self.skill_dim)] = 1.0 61 | meta = OrderedDict() 62 | meta['skill'] = skill 63 | return meta 64 | 65 | def update_meta(self, meta, global_step, time_step): 66 | if global_step % self.update_skill_every_step == 0: 67 | return self.init_meta() 68 | return meta 69 | 70 | def update_diayn(self, skill, next_obs, step): 71 | metrics = dict() 72 | 73 | loss, df_accuracy = self.compute_diayn_loss(next_obs, skill) 74 | 75 | self.diayn_opt.zero_grad() 76 | if self.encoder_opt is not None: 77 | self.encoder_opt.zero_grad(set_to_none=True) 78 | loss.backward() 79 | self.diayn_opt.step() 80 | if self.encoder_opt is not None: 81 | self.encoder_opt.step() 82 | 83 | if self.use_tb or self.use_wandb: 84 | metrics['diayn_loss'] = loss.item() 85 | metrics['diayn_acc'] = df_accuracy 86 | 87 | return metrics 88 | 89 | def compute_intr_reward(self, skill, next_obs, step): 90 | z_hat = torch.argmax(skill, dim=1) 91 | d_pred = self.diayn(next_obs) 92 | d_pred_log_softmax = F.log_softmax(d_pred, dim=1) 93 | _, pred_z = torch.max(d_pred_log_softmax, dim=1, keepdim=True) 94 | reward = d_pred_log_softmax[torch.arange(d_pred.shape[0]), 95 | z_hat] - math.log(1 / self.skill_dim) 96 | reward = reward.reshape(-1, 1) 97 | 98 | return reward * self.diayn_scale 99 | 100 | def compute_diayn_loss(self, next_state, skill): 101 | """ 102 | DF Loss 103 | """ 104 | z_hat = torch.argmax(skill, dim=1) 105 | d_pred = self.diayn(next_state) 106 | d_pred_log_softmax = F.log_softmax(d_pred, dim=1) 107 | _, pred_z = torch.max(d_pred_log_softmax, dim=1, keepdim=True) 108 | d_loss = self.diayn_criterion(d_pred, z_hat) 109 | df_accuracy = torch.sum( 110 | torch.eq(z_hat, 111 | pred_z.reshape(1, 112 | list( 113 | pred_z.size())[0])[0])).float() / list( 114 | pred_z.size())[0] 115 | return d_loss, df_accuracy 116 | 117 | def update(self, replay_iter, step): 118 | 
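# Reward-free pre-training: fit the skill discriminator on next_obs, then use
# log q(z | s') - log(1 / skill_dim) as the intrinsic reward (see
# compute_intr_reward above); during fine-tuning the extrinsic reward is used.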
metrics = dict() 119 | 120 | if step % self.update_every_steps != 0: 121 | return metrics 122 | 123 | batch = next(replay_iter) 124 | 125 | obs, action, extr_reward, discount, next_obs, task_id, skill = utils.to_torch( 126 | batch, self.device) 127 | 128 | # augment and encode 129 | obs = self.aug_and_encode(obs) 130 | next_obs = self.aug_and_encode(next_obs) 131 | 132 | if self.reward_free: 133 | metrics.update(self.update_diayn(skill, next_obs, step)) 134 | 135 | with torch.no_grad(): 136 | intr_reward = self.compute_intr_reward(skill, next_obs, step) 137 | 138 | if self.use_tb or self.use_wandb: 139 | metrics['intr_reward'] = intr_reward.mean().item() 140 | reward = intr_reward 141 | else: 142 | reward = extr_reward 143 | 144 | if self.use_tb or self.use_wandb: 145 | metrics['extr_reward'] = extr_reward.mean().item() 146 | metrics['batch_reward'] = reward.mean().item() 147 | 148 | if not self.update_encoder: 149 | obs = obs.detach() 150 | next_obs = next_obs.detach() 151 | 152 | # extend observations with skill 153 | obs = torch.cat([obs, skill], dim=1) 154 | next_obs = torch.cat([next_obs, skill], dim=1) 155 | 156 | # update critic 157 | metrics.update( 158 | self.update_critic(obs.detach(), action, reward, discount, 159 | next_obs.detach(), step)) 160 | 161 | # update actor 162 | metrics.update(self.update_actor(obs.detach(), step)) 163 | 164 | # update critic target 165 | utils.soft_update_params(self.critic, self.critic_target, 166 | self.critic_target_tau) 167 | 168 | return metrics 169 | -------------------------------------------------------------------------------- /DMC_state/agent/disagreement.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import utils 8 | from agent.ddpg import DDPGAgent 9 | 10 | 11 | class Disagreement(nn.Module): 12 | def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5): 13 | super().__init__() 14 | self.ensemble = nn.ModuleList([ 15 | nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim), 16 | nn.ReLU(), nn.Linear(hidden_dim, obs_dim)) 17 | for _ in range(n_models) 18 | ]) 19 | 20 | def forward(self, obs, action, next_obs): 21 | #import ipdb; ipdb.set_trace() 22 | assert obs.shape[0] == next_obs.shape[0] 23 | assert obs.shape[0] == action.shape[0] 24 | 25 | errors = [] 26 | for model in self.ensemble: 27 | next_obs_hat = model(torch.cat([obs, action], dim=-1)) 28 | model_error = torch.norm(next_obs - next_obs_hat, 29 | dim=-1, 30 | p=2, 31 | keepdim=True) 32 | errors.append(model_error) 33 | 34 | return torch.cat(errors, dim=1) 35 | 36 | def get_disagreement(self, obs, action, next_obs): 37 | assert obs.shape[0] == next_obs.shape[0] 38 | assert obs.shape[0] == action.shape[0] 39 | 40 | preds = [] 41 | for model in self.ensemble: 42 | next_obs_hat = model(torch.cat([obs, action], dim=-1)) 43 | preds.append(next_obs_hat) 44 | preds = torch.stack(preds, dim=0) 45 | return torch.var(preds, dim=0).mean(dim=-1) 46 | 47 | 48 | class DisagreementAgent(DDPGAgent): 49 | def __init__(self, update_encoder, **kwargs): 50 | super().__init__(**kwargs) 51 | self.update_encoder = update_encoder 52 | 53 | self.disagreement = Disagreement(self.obs_dim, self.action_dim, 54 | self.hidden_dim).to(self.device) 55 | 56 | # optimizers 57 | self.disagreement_opt = torch.optim.Adam( 58 | self.disagreement.parameters(), lr=self.lr) 59 | 60 | self.disagreement.train() 61 | 62 | def update_disagreement(self, 
obs, action, next_obs, step): 63 | metrics = dict() 64 | 65 | error = self.disagreement(obs, action, next_obs) 66 | 67 | loss = error.mean() 68 | 69 | self.disagreement_opt.zero_grad(set_to_none=True) 70 | if self.encoder_opt is not None: 71 | self.encoder_opt.zero_grad(set_to_none=True) 72 | loss.backward() 73 | self.disagreement_opt.step() 74 | if self.encoder_opt is not None: 75 | self.encoder_opt.step() 76 | 77 | if self.use_tb or self.use_wandb: 78 | metrics['disagreement_loss'] = loss.item() 79 | 80 | return metrics 81 | 82 | def compute_intr_reward(self, obs, action, next_obs, step): 83 | reward = self.disagreement.get_disagreement(obs, action, 84 | next_obs).unsqueeze(1) 85 | return reward 86 | 87 | def update(self, replay_iter, step): 88 | metrics = dict() 89 | 90 | if step % self.update_every_steps != 0: 91 | return metrics 92 | 93 | batch = next(replay_iter) 94 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch( 95 | batch, self.device) 96 | 97 | # augment and encode 98 | obs = self.aug_and_encode(obs) 99 | with torch.no_grad(): 100 | next_obs = self.aug_and_encode(next_obs) 101 | 102 | if self.reward_free: 103 | metrics.update( 104 | self.update_disagreement(obs, action, next_obs, step)) 105 | 106 | with torch.no_grad(): 107 | intr_reward = self.compute_intr_reward(obs, action, next_obs, 108 | step) 109 | 110 | if self.use_tb or self.use_wandb: 111 | metrics['intr_reward'] = intr_reward.mean().item() 112 | reward = intr_reward 113 | else: 114 | reward = extr_reward 115 | 116 | if self.use_tb or self.use_wandb: 117 | metrics['extr_reward'] = extr_reward.mean().item() 118 | metrics['batch_reward'] = reward.mean().item() 119 | 120 | if not self.update_encoder: 121 | obs = obs.detach() 122 | next_obs = next_obs.detach() 123 | 124 | # update critic 125 | metrics.update( 126 | self.update_critic(obs.detach(), action, reward, discount, 127 | next_obs.detach(), step)) 128 | 129 | # update actor 130 | metrics.update(self.update_actor(obs.detach(), step)) 131 | 132 | # update critic target 133 | utils.soft_update_params(self.critic, self.critic_target, 134 | self.critic_target_tau) 135 | 136 | return metrics 137 | -------------------------------------------------------------------------------- /DMC_state/agent/icm.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import utils 8 | from agent.ddpg import DDPGAgent 9 | 10 | 11 | class ICM(nn.Module): 12 | def __init__(self, obs_dim, action_dim, hidden_dim): 13 | super().__init__() 14 | 15 | self.forward_net = nn.Sequential( 16 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), 17 | nn.Linear(hidden_dim, obs_dim)) 18 | 19 | self.backward_net = nn.Sequential(nn.Linear(2 * obs_dim, hidden_dim), 20 | nn.ReLU(), 21 | nn.Linear(hidden_dim, action_dim), 22 | nn.Tanh()) 23 | 24 | self.apply(utils.weight_init) 25 | 26 | def forward(self, obs, action, next_obs): 27 | assert obs.shape[0] == next_obs.shape[0] 28 | assert obs.shape[0] == action.shape[0] 29 | 30 | next_obs_hat = self.forward_net(torch.cat([obs, action], dim=-1)) 31 | action_hat = self.backward_net(torch.cat([obs, next_obs], dim=-1)) 32 | 33 | forward_error = torch.norm(next_obs - next_obs_hat, 34 | dim=-1, 35 | p=2, 36 | keepdim=True) 37 | backward_error = torch.norm(action - action_hat, 38 | dim=-1, 39 | p=2, 40 | keepdim=True) 41 | 42 | return forward_error, backward_error 43 | 44 | 45 | class 
ICMAgent(DDPGAgent): 46 | def __init__(self, icm_scale, update_encoder, **kwargs): 47 | super().__init__(**kwargs) 48 | self.icm_scale = icm_scale 49 | self.update_encoder = update_encoder 50 | 51 | self.icm = ICM(self.obs_dim, self.action_dim, 52 | self.hidden_dim).to(self.device) 53 | 54 | # optimizers 55 | self.icm_opt = torch.optim.Adam(self.icm.parameters(), lr=self.lr) 56 | 57 | self.icm.train() 58 | 59 | def update_icm(self, obs, action, next_obs, step): 60 | metrics = dict() 61 | 62 | forward_error, backward_error = self.icm(obs, action, next_obs) 63 | 64 | loss = forward_error.mean() + backward_error.mean() 65 | 66 | self.icm_opt.zero_grad(set_to_none=True) 67 | if self.encoder_opt is not None: 68 | self.encoder_opt.zero_grad(set_to_none=True) 69 | loss.backward() 70 | self.icm_opt.step() 71 | if self.encoder_opt is not None: 72 | self.encoder_opt.step() 73 | 74 | if self.use_tb or self.use_wandb: 75 | metrics['icm_loss'] = loss.item() 76 | 77 | return metrics 78 | 79 | def compute_intr_reward(self, obs, action, next_obs, step): 80 | forward_error, _ = self.icm(obs, action, next_obs) 81 | 82 | reward = forward_error * self.icm_scale 83 | reward = torch.log(reward + 1.0) 84 | return reward 85 | 86 | def update(self, replay_iter, step): 87 | metrics = dict() 88 | 89 | if step % self.update_every_steps != 0: 90 | return metrics 91 | 92 | batch = next(replay_iter) 93 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch( 94 | batch, self.device) 95 | 96 | # augment and encode 97 | obs = self.aug_and_encode(obs) 98 | with torch.no_grad(): 99 | next_obs = self.aug_and_encode(next_obs) 100 | 101 | if self.reward_free: 102 | metrics.update(self.update_icm(obs, action, next_obs, step)) 103 | 104 | with torch.no_grad(): 105 | intr_reward = self.compute_intr_reward(obs, action, next_obs, 106 | step) 107 | 108 | if self.use_tb or self.use_wandb: 109 | metrics['intr_reward'] = intr_reward.mean().item() 110 | reward = intr_reward 111 | else: 112 | reward = extr_reward 113 | 114 | if self.use_tb or self.use_wandb: 115 | metrics['extr_reward'] = extr_reward.mean().item() 116 | metrics['batch_reward'] = reward.mean().item() 117 | 118 | if not self.update_encoder: 119 | obs = obs.detach() 120 | next_obs = next_obs.detach() 121 | 122 | # update critic 123 | metrics.update( 124 | self.update_critic(obs.detach(), action, reward, discount, 125 | next_obs.detach(), step)) 126 | 127 | # update actor 128 | metrics.update(self.update_actor(obs.detach(), step)) 129 | 130 | # update critic target 131 | utils.soft_update_params(self.critic, self.critic_target, 132 | self.critic_target_tau) 133 | 134 | return metrics 135 | -------------------------------------------------------------------------------- /DMC_state/agent/lbs.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributions as D 7 | 8 | import utils 9 | from agent.ddpg import DDPGAgent 10 | 11 | 12 | # s_t, a_t -> z_t+1 13 | # s_t, a_t, s_t+1 -> z_t+1 14 | # z_t+1 -> s_t+1 15 | class LBS(nn.Module): 16 | def __init__(self, obs_dim, action_dim, hidden_dim): 17 | super().__init__() 18 | 19 | self.pri_forward_net = nn.Sequential( 20 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), 21 | nn.Linear(hidden_dim, hidden_dim)) 22 | 23 | self.pos_forward_net = nn.Sequential( 24 | nn.Linear(2 * obs_dim + action_dim, hidden_dim), nn.ReLU(), 25 | 
nn.Linear(hidden_dim, hidden_dim)) 26 | 27 | self.reconstruction_net = nn.Sequential( 28 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 29 | nn.Linear(hidden_dim, obs_dim)) 30 | 31 | self.apply(utils.weight_init) 32 | 33 | def forward(self, obs, action, next_obs): 34 | assert obs.shape[0] == next_obs.shape[0] 35 | assert obs.shape[0] == action.shape[0] 36 | 37 | pri_z = self.pri_forward_net(torch.cat([obs, action], dim=-1)) 38 | pos_z = self.pos_forward_net(torch.cat([obs, action, next_obs], dim=-1)) 39 | 40 | reco_s = self.reconstruction_net(pos_z) 41 | 42 | pri_z = D.Independent(D.Normal(pri_z, 1.0), 1) 43 | pos_z = D.Independent(D.Normal(pos_z, 1.0), 1) 44 | reco_s = D.Independent(D.Normal(reco_s, 1.0), 1) 45 | 46 | kl_div = D.kl_divergence(pos_z, pri_z) 47 | 48 | reco_error = -reco_s.log_prob(next_obs).mean() 49 | kl_error = kl_div.mean() 50 | 51 | return kl_error, reco_error, kl_div.detach() 52 | 53 | 54 | class LBS_PRED(nn.Module): 55 | def __init__(self, obs_dim, action_dim, hidden_dim): 56 | super().__init__() 57 | 58 | self.pred_net = nn.Sequential( 59 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), 60 | nn.Linear(hidden_dim, 1)) 61 | 62 | self.apply(utils.weight_init) 63 | 64 | def forward(self, obs, action, next_obs): 65 | assert obs.shape[0] == next_obs.shape[0] 66 | assert obs.shape[0] == action.shape[0] 67 | 68 | pred_kl = self.pred_net(torch.cat([obs, action], dim=-1)) 69 | pred_kl = D.Independent(D.Normal(pred_kl, 1.0), 1) 70 | 71 | return pred_kl 72 | 73 | 74 | class LBSAgent(DDPGAgent): 75 | def __init__(self, lbs_scale, update_encoder, **kwargs): 76 | super().__init__(**kwargs) 77 | self.lbs_scale = lbs_scale 78 | self.update_encoder = update_encoder 79 | 80 | self.lbs = LBS(self.obs_dim, self.action_dim, 81 | self.hidden_dim).to(self.device) 82 | # optimizers 83 | self.lbs_opt = torch.optim.Adam(self.lbs.parameters(), lr=self.lr) 84 | 85 | self.lbs_pred = LBS_PRED(self.obs_dim, self.action_dim, 86 | self.hidden_dim).to(self.device) 87 | self.lbs_pred_opt = torch.optim.Adam(self.lbs_pred.parameters(), lr=self.lr) 88 | 89 | self.lbs.train() 90 | self.lbs_pred.train() 91 | 92 | def update_lbs(self, obs, action, next_obs, step): 93 | metrics = dict() 94 | 95 | kl_error, reco_error, kl_div = self.lbs(obs, action, next_obs) 96 | 97 | lbs_loss = kl_error.mean() + reco_error.mean() 98 | 99 | self.lbs_opt.zero_grad(set_to_none=True) 100 | if self.encoder_opt is not None: 101 | self.encoder_opt.zero_grad(set_to_none=True) 102 | lbs_loss.backward() 103 | self.lbs_opt.step() 104 | if self.encoder_opt is not None: 105 | self.encoder_opt.step() 106 | 107 | kl_pred = self.lbs_pred(obs, action, next_obs) 108 | lbs_pred_loss = -kl_pred.log_prob(kl_div.detach()).mean() 109 | self.lbs_pred_opt.zero_grad(set_to_none=True) 110 | if self.encoder_opt is not None: 111 | self.encoder_opt.zero_grad(set_to_none=True) 112 | lbs_pred_loss.backward() 113 | self.lbs_pred_opt.step() 114 | if self.encoder_opt is not None: 115 | self.encoder_opt.step() 116 | 117 | if self.use_tb or self.use_wandb: 118 | metrics['lbs_loss'] = lbs_loss.item() 119 | metrics['lbs_pred_loss'] = lbs_pred_loss.item() 120 | 121 | return metrics 122 | 123 | def compute_intr_reward(self, obs, action, next_obs, step): 124 | kl_pred = self.lbs_pred(obs, action, next_obs) 125 | 126 | reward = kl_pred.mean * self.lbs_scale 127 | return reward 128 | 129 | def update(self, replay_iter, step): 130 | metrics = dict() 131 | 132 | if step % self.update_every_steps != 0: 133 | return metrics 134 | 135 | batch = next(replay_iter) 
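# Replay batches carry (obs, action, extr_reward, discount, next_obs, task_id),
# unpacked just below. In the reward-free phase the LBS latent model and the
# KL predictor are updated first, and the predicted KL (the "surprise" of the
# transition) scaled by lbs_scale replaces the extrinsic reward.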
136 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch( 137 | batch, self.device) 138 | 139 | # augment and encode 140 | obs = self.aug_and_encode(obs) 141 | with torch.no_grad(): 142 | next_obs = self.aug_and_encode(next_obs) 143 | 144 | if self.reward_free: 145 | metrics.update(self.update_lbs(obs, action, next_obs, step)) 146 | 147 | with torch.no_grad(): 148 | intr_reward = self.compute_intr_reward(obs, action, next_obs, 149 | step) 150 | 151 | if self.use_tb or self.use_wandb: 152 | metrics['intr_reward'] = intr_reward.mean().item() 153 | reward = intr_reward 154 | else: 155 | reward = extr_reward 156 | 157 | if self.use_tb or self.use_wandb: 158 | metrics['extr_reward'] = extr_reward.mean().item() 159 | metrics['batch_reward'] = reward.mean().item() 160 | 161 | if not self.update_encoder: 162 | obs = obs.detach() 163 | next_obs = next_obs.detach() 164 | 165 | # update critic 166 | metrics.update( 167 | self.update_critic(obs.detach(), action, reward, discount, 168 | next_obs.detach(), step)) 169 | 170 | # update actor 171 | metrics.update(self.update_actor(obs.detach(), step)) 172 | 173 | # update critic target 174 | utils.soft_update_params(self.critic, self.critic_target, 175 | self.critic_target_tau) 176 | 177 | return metrics 178 | -------------------------------------------------------------------------------- /DMC_state/agent/peac.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import GRUCell 7 | 8 | import utils 9 | from agent.ddpg import DDPGAgent 10 | 11 | 12 | class ContextModel(nn.Module): 13 | def __init__(self, obs_dim, act_dim, context_dim, hidden_dim, device='cuda'): 14 | super().__init__() 15 | 16 | self.device = device 17 | self.hidden_dim = hidden_dim 18 | self.model = GRUCell(obs_dim+act_dim, hidden_dim) 19 | self.context_head = nn.Sequential(nn.ReLU(), 20 | nn.Linear(hidden_dim, context_dim)) 21 | 22 | self.apply(utils.weight_init) 23 | 24 | def forward(self, obs, acts, hidden=None): 25 | if hidden is None: 26 | hidden = torch.zeros((obs.shape[0], self.hidden_dim), device=self.device) 27 | for i in range(obs.shape[1]): 28 | hidden = self.model(torch.cat([obs[:, i], acts[:, i]], dim=-1), 29 | hidden) 30 | context_pred = self.context_head(hidden) 31 | return context_pred 32 | 33 | class PEACAgent(DDPGAgent): 34 | def __init__(self, update_encoder, 35 | context_dim, **kwargs): 36 | super().__init__(**kwargs) 37 | self.update_encoder = update_encoder 38 | self.context_dim = context_dim 39 | print('context dim:', self.context_dim) 40 | 41 | self.task_model = ContextModel(self.obs_dim, self.action_dim, self.context_dim, 42 | self.hidden_dim, device=self.device).to(self.device) 43 | 44 | # optimizers 45 | self.task_opt = torch.optim.Adam(self.task_model.parameters(), lr=self.lr) 46 | 47 | self.task_model.train() 48 | 49 | def update_task_model(self, obs, action, next_obs, embodiment_id, pre_obs, pre_acts): 50 | metrics = dict() 51 | task_pred = self.task_model(pre_obs, pre_acts) 52 | # print(task_pred.shape) 53 | # print(torch.sum(embodiment_id)) 54 | loss = F.cross_entropy(task_pred, embodiment_id.reshape(-1)) 55 | 56 | self.task_opt.zero_grad(set_to_none=True) 57 | if self.encoder_opt is not None: 58 | self.encoder_opt.zero_grad(set_to_none=True) 59 | loss.backward() 60 | self.task_opt.step() 61 | if self.encoder_opt is not None: 62 | self.encoder_opt.step() 63 | 64 | if 
self.use_tb or self.use_wandb: 65 | metrics['task_loss'] = loss.item() 66 | 67 | return metrics 68 | 69 | def compute_intr_reward(self, obs, action, next_obs, embodiment_id, pre_obs, pre_acts): 70 | B, _ = action.shape 71 | task_pred = self.task_model(pre_obs, pre_acts) # B, task_num 72 | # calculate the task model predict prob 73 | task_pred = F.log_softmax(task_pred, dim=1) 74 | intr_rew = task_pred[torch.arange(B), embodiment_id.reshape(-1)] # B 75 | task_rew = -intr_rew.reshape(B, 1) 76 | return task_rew 77 | 78 | def update(self, replay_iter, step): 79 | metrics = dict() 80 | 81 | if step % self.update_every_steps != 0: 82 | return metrics 83 | 84 | batch = next(replay_iter) 85 | obs, action, extr_reward, discount, next_obs, embodiment_id, his_o, his_a = \ 86 | utils.to_torch(batch, self.device) 87 | 88 | # augment and encode 89 | obs = self.aug_and_encode(obs) 90 | with torch.no_grad(): 91 | next_obs = self.aug_and_encode(next_obs) 92 | 93 | if self.reward_free: 94 | with torch.no_grad(): 95 | intr_reward = self.compute_intr_reward(obs, action, next_obs, embodiment_id, 96 | his_o, his_a) 97 | 98 | if self.use_tb or self.use_wandb: 99 | metrics['intr_reward'] = intr_reward.mean().item() 100 | reward = intr_reward 101 | else: 102 | reward = extr_reward 103 | 104 | if self.use_tb or self.use_wandb: 105 | metrics['extr_reward'] = extr_reward.mean().item() 106 | metrics['batch_reward'] = reward.mean().item() 107 | 108 | if not self.update_encoder: 109 | obs = obs.detach() 110 | next_obs = next_obs.detach() 111 | 112 | metrics.update(self.update_task_model(obs.detach(), action, next_obs, embodiment_id, 113 | his_o, his_a)) 114 | 115 | # update critic 116 | metrics.update( 117 | self.update_critic(obs.detach(), action, reward, discount, 118 | next_obs.detach(), step)) 119 | 120 | # update actor 121 | metrics.update(self.update_actor(obs.detach(), step)) 122 | 123 | # update critic target 124 | utils.soft_update_params(self.critic, self.critic_target, 125 | self.critic_target_tau) 126 | 127 | return metrics 128 | -------------------------------------------------------------------------------- /DMC_state/agent/proto.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import distributions as pyd 8 | from torch import jit 9 | 10 | import utils 11 | from agent.ddpg import DDPGAgent 12 | 13 | 14 | # @jit.script 15 | def sinkhorn_knopp(Q): 16 | Q -= Q.max() 17 | Q = torch.exp(Q).T 18 | Q /= Q.sum() 19 | 20 | r = torch.ones(Q.shape[0], device=Q.device) / Q.shape[0] 21 | c = torch.ones(Q.shape[1], device=Q.device) / Q.shape[1] 22 | for it in range(3): 23 | u = Q.sum(dim=1) 24 | u = r / u 25 | Q *= u.unsqueeze(dim=1) 26 | Q *= (c / Q.sum(dim=0)).unsqueeze(dim=0) 27 | Q = Q / Q.sum(dim=0, keepdim=True) 28 | return Q.T 29 | 30 | 31 | class Projector(nn.Module): 32 | def __init__(self, pred_dim, proj_dim): 33 | super().__init__() 34 | 35 | self.trunk = nn.Sequential(nn.Linear(pred_dim, proj_dim), nn.ReLU(), 36 | nn.Linear(proj_dim, pred_dim)) 37 | 38 | self.apply(utils.weight_init) 39 | 40 | def forward(self, x): 41 | return self.trunk(x) 42 | 43 | 44 | class ProtoAgent(DDPGAgent): 45 | def __init__(self, pred_dim, proj_dim, queue_size, num_protos, tau, 46 | encoder_target_tau, topk, update_encoder, **kwargs): 47 | super().__init__(**kwargs) 48 | self.tau = tau 49 | self.encoder_target_tau = encoder_target_tau 50 | 
self.topk = topk 51 | self.num_protos = num_protos 52 | self.update_encoder = update_encoder 53 | 54 | # models 55 | self.encoder_target = deepcopy(self.encoder) 56 | 57 | self.predictor = nn.Linear(self.obs_dim, pred_dim).to(self.device) 58 | self.predictor.apply(utils.weight_init) 59 | self.predictor_target = deepcopy(self.predictor) 60 | 61 | self.projector = Projector(pred_dim, proj_dim).to(self.device) 62 | self.projector.apply(utils.weight_init) 63 | 64 | # prototypes 65 | self.protos = nn.Linear(pred_dim, num_protos, 66 | bias=False).to(self.device) 67 | self.protos.apply(utils.weight_init) 68 | 69 | # candidate queue 70 | self.queue = torch.zeros(queue_size, pred_dim, device=self.device) 71 | self.queue_ptr = 0 72 | 73 | # optimizers 74 | self.proto_opt = torch.optim.Adam(utils.chain( 75 | self.encoder.parameters(), self.predictor.parameters(), 76 | self.projector.parameters(), self.protos.parameters()), 77 | lr=self.lr) 78 | 79 | self.predictor.train() 80 | self.projector.train() 81 | self.protos.train() 82 | 83 | def init_from(self, other): 84 | # copy parameters over 85 | utils.hard_update_params(other.encoder, self.encoder) 86 | utils.hard_update_params(other.actor, self.actor) 87 | utils.hard_update_params(other.predictor, self.predictor) 88 | utils.hard_update_params(other.projector, self.projector) 89 | utils.hard_update_params(other.protos, self.protos) 90 | if self.init_critic: 91 | utils.hard_update_params(other.critic, self.critic) 92 | 93 | def normalize_protos(self): 94 | C = self.protos.weight.data.clone() 95 | C = F.normalize(C, dim=1, p=2) 96 | self.protos.weight.data.copy_(C) 97 | 98 | def compute_intr_reward(self, obs, step): 99 | self.normalize_protos() 100 | # find a candidate for each prototype 101 | with torch.no_grad(): 102 | z = self.encoder(obs) 103 | z = self.predictor(z) 104 | z = F.normalize(z, dim=1, p=2) 105 | scores = self.protos(z).T 106 | prob = F.softmax(scores, dim=1) 107 | candidates = pyd.Categorical(prob).sample() 108 | 109 | # enqueue candidates 110 | ptr = self.queue_ptr 111 | self.queue[ptr:ptr + self.num_protos] = z[candidates] 112 | self.queue_ptr = (ptr + self.num_protos) % self.queue.shape[0] 113 | 114 | # compute distances between the batch and the queue of candidates 115 | z_to_q = torch.norm(z[:, None, :] - self.queue[None, :, :], dim=2, p=2) 116 | all_dists, _ = torch.topk(z_to_q, self.topk, dim=1, largest=False) 117 | dist = all_dists[:, -1:] 118 | reward = dist 119 | return reward 120 | 121 | def update_proto(self, obs, next_obs, step): 122 | metrics = dict() 123 | 124 | # normalize prototypes 125 | self.normalize_protos() 126 | 127 | # online network 128 | s = self.encoder(obs) 129 | s = self.predictor(s) 130 | s = self.projector(s) 131 | s = F.normalize(s, dim=1, p=2) 132 | scores_s = self.protos(s) 133 | log_p_s = F.log_softmax(scores_s / self.tau, dim=1) 134 | 135 | # target network 136 | with torch.no_grad(): 137 | t = self.encoder_target(next_obs) 138 | t = self.predictor_target(t) 139 | t = F.normalize(t, dim=1, p=2) 140 | scores_t = self.protos(t) 141 | q_t = sinkhorn_knopp(scores_t / self.tau) 142 | 143 | # loss 144 | loss = -(q_t * log_p_s).sum(dim=1).mean() 145 | if self.use_tb or self.use_wandb: 146 | metrics['repr_loss'] = loss.item() 147 | self.proto_opt.zero_grad(set_to_none=True) 148 | loss.backward() 149 | self.proto_opt.step() 150 | 151 | return metrics 152 | 153 | def update(self, replay_iter, step): 154 | metrics = dict() 155 | 156 | if step % self.update_every_steps != 0: 157 | return metrics 158 | 159 | batch 
= next(replay_iter) 160 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch( 161 | batch, self.device) 162 | 163 | # augment and encode 164 | with torch.no_grad(): 165 | obs = self.aug(obs) 166 | next_obs = self.aug(next_obs) 167 | 168 | if self.reward_free: 169 | metrics.update(self.update_proto(obs, next_obs, step)) 170 | 171 | with torch.no_grad(): 172 | intr_reward = self.compute_intr_reward(next_obs, step) 173 | 174 | if self.use_tb or self.use_wandb: 175 | metrics['intr_reward'] = intr_reward.mean().item() 176 | reward = intr_reward 177 | else: 178 | reward = extr_reward 179 | 180 | if self.use_tb or self.use_wandb: 181 | metrics['extr_reward'] = extr_reward.mean().item() 182 | metrics['batch_reward'] = reward.mean().item() 183 | 184 | obs = self.encoder(obs) 185 | next_obs = self.encoder(next_obs) 186 | 187 | if not self.update_encoder: 188 | obs = obs.detach() 189 | next_obs = next_obs.detach() 190 | 191 | # update critic 192 | metrics.update( 193 | self.update_critic(obs.detach(), action, reward, discount, 194 | next_obs.detach(), step)) 195 | 196 | # update actor 197 | metrics.update(self.update_actor(obs.detach(), step)) 198 | 199 | # update critic target 200 | utils.soft_update_params(self.encoder, self.encoder_target, 201 | self.encoder_target_tau) 202 | utils.soft_update_params(self.predictor, self.predictor_target, 203 | self.encoder_target_tau) 204 | utils.soft_update_params(self.critic, self.critic_target, 205 | self.critic_target_tau) 206 | 207 | return metrics 208 | -------------------------------------------------------------------------------- /DMC_state/agent/rnd.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | from agent.ddpg import DDPGAgent 11 | 12 | 13 | class RND(nn.Module): 14 | def __init__(self, 15 | obs_dim, 16 | hidden_dim, 17 | rnd_rep_dim, 18 | encoder, 19 | aug, 20 | obs_shape, 21 | obs_type, 22 | clip_val=5.): 23 | super().__init__() 24 | self.clip_val = clip_val 25 | self.aug = aug 26 | 27 | if obs_type == "pixels": 28 | self.normalize_obs = nn.BatchNorm2d(obs_shape[0], affine=False) 29 | else: 30 | self.normalize_obs = nn.BatchNorm1d(obs_shape[0], affine=False) 31 | 32 | self.predictor = nn.Sequential(encoder, nn.Linear(obs_dim, hidden_dim), 33 | nn.ReLU(), 34 | nn.Linear(hidden_dim, hidden_dim), 35 | nn.ReLU(), 36 | nn.Linear(hidden_dim, rnd_rep_dim)) 37 | self.target = nn.Sequential(copy.deepcopy(encoder), 38 | nn.Linear(obs_dim, hidden_dim), nn.ReLU(), 39 | nn.Linear(hidden_dim, hidden_dim), 40 | nn.ReLU(), 41 | nn.Linear(hidden_dim, rnd_rep_dim)) 42 | 43 | for param in self.target.parameters(): 44 | param.requires_grad = False 45 | 46 | self.apply(utils.weight_init) 47 | 48 | def forward(self, obs): 49 | obs = self.aug(obs) 50 | obs = self.normalize_obs(obs) 51 | obs = torch.clamp(obs, -self.clip_val, self.clip_val) 52 | prediction, target = self.predictor(obs), self.target(obs) 53 | prediction_error = torch.square(target.detach() - prediction).mean( 54 | dim=-1, keepdim=True) 55 | return prediction_error 56 | 57 | 58 | class RNDAgent(DDPGAgent): 59 | def __init__(self, rnd_rep_dim, update_encoder, rnd_scale=1., **kwargs): 60 | super().__init__(**kwargs) 61 | self.rnd_scale = rnd_scale 62 | self.update_encoder = update_encoder 63 | 64 | self.rnd = RND(self.obs_dim, self.hidden_dim, rnd_rep_dim, 65 | self.encoder, self.aug, 
self.obs_shape, 66 | self.obs_type).to(self.device) 67 | self.intrinsic_reward_rms = utils.RMS(device=self.device) 68 | 69 | # optimizers 70 | self.rnd_opt = torch.optim.Adam(self.rnd.parameters(), lr=self.lr) 71 | 72 | self.rnd.train() 73 | 74 | def update_rnd(self, obs, step): 75 | metrics = dict() 76 | 77 | prediction_error = self.rnd(obs) 78 | 79 | loss = prediction_error.mean() 80 | 81 | self.rnd_opt.zero_grad(set_to_none=True) 82 | if self.encoder_opt is not None: 83 | self.encoder_opt.zero_grad(set_to_none=True) 84 | loss.backward() 85 | self.rnd_opt.step() 86 | if self.encoder_opt is not None: 87 | self.encoder_opt.step() 88 | 89 | if self.use_tb or self.use_wandb: 90 | metrics['rnd_loss'] = loss.item() 91 | 92 | return metrics 93 | 94 | def compute_intr_reward(self, obs, step): 95 | prediction_error = self.rnd(obs) 96 | _, intr_reward_var = self.intrinsic_reward_rms(prediction_error) 97 | reward = self.rnd_scale * prediction_error / ( 98 | torch.sqrt(intr_reward_var) + 1e-8) 99 | return reward 100 | 101 | def update(self, replay_iter, step): 102 | metrics = dict() 103 | 104 | if step % self.update_every_steps != 0: 105 | return metrics 106 | 107 | batch = next(replay_iter) 108 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch( 109 | batch, self.device) 110 | 111 | # update RND first 112 | if self.reward_free: 113 | # note: one difference is that the RND module is updated off policy 114 | metrics.update(self.update_rnd(obs, step)) 115 | 116 | with torch.no_grad(): 117 | intr_reward = self.compute_intr_reward(obs, step) 118 | 119 | if self.use_tb or self.use_wandb: 120 | metrics['intr_reward'] = intr_reward.mean().item() 121 | reward = intr_reward 122 | else: 123 | reward = extr_reward 124 | 125 | # augment and encode 126 | obs = self.aug_and_encode(obs) 127 | with torch.no_grad(): 128 | next_obs = self.aug_and_encode(next_obs) 129 | 130 | if self.use_tb or self.use_wandb: 131 | metrics['extr_reward'] = extr_reward.mean().item() 132 | metrics['batch_reward'] = reward.mean().item() 133 | 134 | metrics['pred_error_mean'] = self.intrinsic_reward_rms.M 135 | metrics['pred_error_std'] = torch.sqrt(self.intrinsic_reward_rms.S) 136 | 137 | if not self.update_encoder: 138 | obs = obs.detach() 139 | next_obs = next_obs.detach() 140 | 141 | # update critic 142 | metrics.update( 143 | self.update_critic(obs.detach(), action, reward, discount, 144 | next_obs.detach(), step)) 145 | 146 | # update actor 147 | metrics.update(self.update_actor(obs.detach(), step)) 148 | 149 | # update critic target 150 | utils.soft_update_params(self.critic, self.critic_target, 151 | self.critic_target_tau) 152 | 153 | return metrics 154 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/aps.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.aps.APSAgent 3 | name: aps 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? 
# to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | sf_dim: 10 20 | update_task_every_step: 5 21 | nstep: 3 22 | batch_size: 1024 23 | init_critic: true 24 | knn_rms: true 25 | knn_k: 12 26 | knn_avg: true 27 | knn_clip: 0.0001 28 | num_init_steps: 4096 # set to ${num_train_frames} to disable finetune policy parameters 29 | lstsq_batch_size: 4096 30 | update_encoder: ${update_encoder} 31 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/becl.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.becl.BECLAgent 3 | name: becl 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | skill_dim: 16 20 | update_skill_every_step: 50 21 | nstep: 3 22 | batch_size: 1024 23 | init_critic: true 24 | update_encoder: ${update_encoder} 25 | 26 | # extra hyperparameter 27 | contrastive_update_rate: 3 28 | temperature: 0.5 29 | 30 | # skill finetuning ablation 31 | skill: -1 -------------------------------------------------------------------------------- /DMC_state/configs/agent/cic.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.cic.CICAgent 3 | name: cic 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: 2000 15 | hidden_dim: 1024 16 | feature_dim: 1024 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | skill_dim: 64 20 | scale: 1.0 21 | update_skill_every_step: 50 22 | nstep: 3 23 | batch_size: 1024 24 | project_skill: true 25 | init_critic: true 26 | rew_type: og 27 | update_rep: true 28 | temp: 0.5 -------------------------------------------------------------------------------- /DMC_state/configs/agent/ddpg.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.ddpg.DDPGAgent 3 | name: ddpg 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | nstep: 3 20 | batch_size: 1024 # 256 for pixels 21 | init_critic: true 22 | update_encoder: ${update_encoder} 23 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/diayn.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.diayn.DIAYNAgent 3 | name: diayn 4 | reward_free: ${reward_free} 5 | obs_type: ??? 
# to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | skill_dim: 16 20 | diayn_scale: 1.0 21 | update_skill_every_step: 50 22 | nstep: 3 23 | batch_size: 1024 24 | init_critic: true 25 | update_encoder: ${update_encoder} 26 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/disagreement.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.disagreement.DisagreementAgent 3 | name: disagreement 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | nstep: 3 20 | batch_size: 1024 21 | init_critic: true 22 | update_encoder: ${update_encoder} 23 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/icm.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.icm.ICMAgent 3 | name: icm 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | icm_scale: 1.0 20 | nstep: 3 21 | batch_size: 1024 22 | init_critic: true 23 | update_encoder: ${update_encoder} 24 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/lbs.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.lbs.LBSAgent 3 | name: lbs 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | nstep: 3 20 | batch_size: 1024 21 | init_critic: true 22 | update_encoder: ${update_encoder} 23 | 24 | lbs_scale: 1.0 -------------------------------------------------------------------------------- /DMC_state/configs/agent/peac.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.peac.PEACAgent 3 | name: peac 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? 
# to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | nstep: 3 20 | batch_size: 1024 21 | init_critic: true 22 | update_encoder: ${update_encoder} 23 | 24 | his_o_a: 10 25 | context_dim: 1 -------------------------------------------------------------------------------- /DMC_state/configs/agent/proto.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.proto.ProtoAgent 3 | name: proto 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | nstep: 3 20 | batch_size: 1024 21 | init_critic: true 22 | pred_dim: 128 23 | proj_dim: 512 24 | num_protos: 512 25 | tau: 0.1 26 | topk: 3 27 | queue_size: 2048 28 | encoder_target_tau: 0.05 29 | update_encoder: ${update_encoder} 30 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/rnd.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.rnd.RNDAgent 3 | name: rnd 4 | reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | rnd_rep_dim: 512 18 | stddev_schedule: 0.2 19 | stddev_clip: 0.3 20 | rnd_scale: 1.0 21 | nstep: 3 22 | batch_size: 1024 23 | init_critic: true 24 | update_encoder: ${update_encoder} 25 | -------------------------------------------------------------------------------- /DMC_state/configs/agent/smm.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.smm.SMMAgent 3 | name: smm 4 | 5 | # z params 6 | z_dim: 4 # default in codebase is 4 7 | 8 | # z discriminator params 9 | sp_lr: 1e-3 10 | 11 | # vae params 12 | vae_lr: 1e-2 13 | vae_beta: 0.5 14 | 15 | # reward params 16 | state_ent_coef: 1.0 17 | latent_ent_coef: 1.0 18 | latent_cond_ent_coef: 1.0 19 | 20 | # DDPG params 21 | reward_free: ${reward_free} 22 | obs_type: ??? # to be specified later 23 | obs_shape: ??? # to be specified later 24 | action_shape: ??? # to be specified later 25 | device: ${device} 26 | lr: 1e-4 27 | critic_target_tau: 0.01 28 | update_every_steps: 2 29 | use_tb: ${use_tb} 30 | use_wandb: ${use_wandb} 31 | num_expl_steps: ??? 
# to be specified later 32 | hidden_dim: 1024 33 | feature_dim: 50 34 | stddev_schedule: 0.2 35 | stddev_clip: 0.3 36 | nstep: 3 37 | batch_size: 1024 38 | init_critic: true 39 | update_encoder: ${update_encoder} -------------------------------------------------------------------------------- /DMC_state/custom_dmc_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from custom_dmc_tasks import walker 2 | from custom_dmc_tasks import quadruped 3 | from custom_dmc_tasks import jaco 4 | 5 | 6 | def make(domain, task, 7 | task_kwargs=None, 8 | environment_kwargs=None, 9 | visualize_reward=False): 10 | 11 | if domain == 'walker': 12 | return walker.make(task, 13 | task_kwargs=task_kwargs, 14 | environment_kwargs=environment_kwargs, 15 | visualize_reward=visualize_reward) 16 | elif domain == 'quadruped': 17 | return quadruped.make(task, 18 | task_kwargs=task_kwargs, 19 | environment_kwargs=environment_kwargs, 20 | visualize_reward=visualize_reward) 21 | else: 22 | raise f'{task} not found' 23 | 24 | assert None 25 | 26 | 27 | def make_jaco(task, obs_type, seed): 28 | return jaco.make(task, obs_type, seed) -------------------------------------------------------------------------------- /DMC_state/custom_dmc_tasks/jaco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """A task where the goal is to move the hand close to a target prop or site.""" 17 | 18 | import collections 19 | 20 | from dm_control import composer 21 | from dm_control.composer import initializers 22 | from dm_control.composer.observation import observable 23 | from dm_control.composer.variation import distributions 24 | from dm_control.entities import props 25 | from dm_control.manipulation.shared import arenas 26 | from dm_control.manipulation.shared import cameras 27 | from dm_control.manipulation.shared import constants 28 | from dm_control.manipulation.shared import observations 29 | from dm_control.manipulation.shared import registry 30 | from dm_control.manipulation.shared import robots 31 | from dm_control.manipulation.shared import tags 32 | from dm_control.manipulation.shared import workspaces 33 | from dm_control.utils import rewards 34 | import numpy as np 35 | 36 | _ReachWorkspace = collections.namedtuple( 37 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 38 | 39 | # Ensures that the props are not touching the table before settling. 
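# The TASKS table below reuses _PROP_Z_OFFSET and pins each target to a single
# point by giving the bounding box identical lower and upper corners. A hedged
# usage sketch (make and TASKS are defined later in this file):
#   env = make('reach_top_left', obs_type='states', seed=0)
#   timestep = env.reset()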
40 | _PROP_Z_OFFSET = 0.001 41 | 42 | _DUPLO_WORKSPACE = _ReachWorkspace( 43 | target_bbox=workspaces.BoundingBox( 44 | lower=(-0.1, -0.1, _PROP_Z_OFFSET), 45 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 46 | tcp_bbox=workspaces.BoundingBox( 47 | lower=(-0.1, -0.1, 0.2), 48 | upper=(0.1, 0.1, 0.4)), 49 | arm_offset=robots.ARM_OFFSET) 50 | 51 | _SITE_WORKSPACE = _ReachWorkspace( 52 | target_bbox=workspaces.BoundingBox( 53 | lower=(-0.2, -0.2, 0.02), 54 | upper=(0.2, 0.2, 0.4)), 55 | tcp_bbox=workspaces.BoundingBox( 56 | lower=(-0.2, -0.2, 0.02), 57 | upper=(0.2, 0.2, 0.4)), 58 | arm_offset=robots.ARM_OFFSET) 59 | 60 | _TARGET_RADIUS = 0.05 61 | _TIME_LIMIT = 10. 62 | 63 | TASKS = { 64 | 'reach_top_left': workspaces.BoundingBox( 65 | lower=(-0.09, 0.09, _PROP_Z_OFFSET), 66 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)), 67 | 'reach_top_right': workspaces.BoundingBox( 68 | lower=(0.09, 0.09, _PROP_Z_OFFSET), 69 | upper=(0.09, 0.09, _PROP_Z_OFFSET)), 70 | 'reach_bottom_left': workspaces.BoundingBox( 71 | lower=(-0.09, -0.09, _PROP_Z_OFFSET), 72 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)), 73 | 'reach_bottom_right': workspaces.BoundingBox( 74 | lower=(0.09, -0.09, _PROP_Z_OFFSET), 75 | upper=(0.09, -0.09, _PROP_Z_OFFSET)), 76 | } 77 | 78 | 79 | def make(task_id, obs_type, seed): 80 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 81 | if obs_type == 'states': 82 | global _TIME_LIMIT 83 | _TIME_LIMIT = 10.04 84 | # Note: Adding this fixes the problem of having 249 steps with action repeat = 1 85 | task = _reach(task_id, obs_settings=obs_settings, use_site=False) 86 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed) 87 | 88 | 89 | class MTReach(composer.Task): 90 | """Bring the hand close to a target prop or site.""" 91 | 92 | def __init__( 93 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep): 94 | """Initializes a new `Reach` task. 95 | 96 | Args: 97 | arena: `composer.Entity` instance. 98 | arm: `robot_base.RobotArm` instance. 99 | hand: `robot_base.RobotHand` instance. 100 | prop: `composer.Entity` instance specifying the prop to reach to, or None 101 | in which case the target is a fixed site whose position is specified by 102 | the workspace. 103 | obs_settings: `observations.ObservationSettings` instance. 104 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 105 | control_timestep: Float specifying the control timestep in seconds. 106 | """ 107 | self._arena = arena 108 | self._arm = arm 109 | self._hand = hand 110 | self._arm.attach(self._hand) 111 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 112 | self.control_timestep = control_timestep 113 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 114 | self._hand, self._arm, 115 | position=distributions.Uniform(*workspace.tcp_bbox), 116 | quaternion=workspaces.DOWN_QUATERNION) 117 | 118 | # Add custom camera observable. 119 | self._task_observables = cameras.add_camera_observables( 120 | arena, obs_settings, cameras.FRONT_CLOSE) 121 | 122 | target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 123 | self._prop = prop 124 | if prop: 125 | # The prop itself is used to visualize the target location. 
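# With a prop, the target site is hidden and a PropPlacer settles the Duplo at
# the sampled position; with use_site=True (prop is None) a visible site is
# placed directly from target_pos_distribution in initialize_episode.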
126 | self._make_target_site(parent_entity=prop, visible=False) 127 | self._target = self._arena.add_free_entity(prop) 128 | self._prop_placer = initializers.PropPlacer( 129 | props=[prop], 130 | position=target_pos_distribution, 131 | quaternion=workspaces.uniform_z_rotation, 132 | settle_physics=True) 133 | else: 134 | self._target = self._make_target_site(parent_entity=arena, visible=True) 135 | self._target_placer = target_pos_distribution 136 | 137 | obs = observable.MJCFFeature('pos', self._target) 138 | obs.configure(**obs_settings.prop_pose._asdict()) 139 | self._task_observables['target_position'] = obs 140 | 141 | # Add sites for visualizing the prop and target bounding boxes. 142 | workspaces.add_bbox_site( 143 | body=self.root_entity.mjcf_model.worldbody, 144 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper, 145 | rgba=constants.GREEN, name='tcp_spawn_area') 146 | workspaces.add_bbox_site( 147 | body=self.root_entity.mjcf_model.worldbody, 148 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper, 149 | rgba=constants.BLUE, name='target_spawn_area') 150 | 151 | def _make_target_site(self, parent_entity, visible): 152 | return workspaces.add_target_site( 153 | body=parent_entity.mjcf_model.worldbody, 154 | radius=_TARGET_RADIUS, visible=visible, 155 | rgba=constants.RED, name='target_site') 156 | 157 | @property 158 | def root_entity(self): 159 | return self._arena 160 | 161 | @property 162 | def arm(self): 163 | return self._arm 164 | 165 | @property 166 | def hand(self): 167 | return self._hand 168 | 169 | @property 170 | def task_observables(self): 171 | return self._task_observables 172 | 173 | def get_reward(self, physics): 174 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 175 | target_pos = physics.bind(self._target).xpos 176 | distance = np.linalg.norm(hand_pos - target_pos) 177 | return rewards.tolerance( 178 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS) 179 | 180 | def initialize_episode(self, physics, random_state): 181 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 182 | self._tcp_initializer(physics, random_state) 183 | if self._prop: 184 | self._prop_placer(physics, random_state) 185 | else: 186 | physics.bind(self._target).pos = ( 187 | self._target_placer(random_state=random_state)) 188 | 189 | 190 | def _reach(task_id, obs_settings, use_site): 191 | """Configure and instantiate a `Reach` task. 192 | 193 | Args: 194 | obs_settings: An `observations.ObservationSettings` instance. 195 | use_site: Boolean, if True then the target will be a fixed site, otherwise 196 | it will be a moveable Duplo brick. 197 | 198 | Returns: 199 | An instance of `reach.Reach`. 
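    Note: `task_id` selects one of the fixed target locations defined in
    `TASKS` (e.g. 'reach_top_left').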
200 | """ 201 | arena = arenas.Standard() 202 | arm = robots.make_arm(obs_settings=obs_settings) 203 | hand = robots.make_hand(obs_settings=obs_settings) 204 | if use_site: 205 | workspace = _SITE_WORKSPACE 206 | prop = None 207 | else: 208 | workspace = _DUPLO_WORKSPACE 209 | prop = props.Duplo(observable_options=observations.make_options( 210 | obs_settings, observations.FREEPROP_OBSERVABLES)) 211 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop, 212 | obs_settings=obs_settings, 213 | workspace=workspace, 214 | control_timestep=constants.CONTROL_TIMESTEP) 215 | return task 216 | -------------------------------------------------------------------------------- /DMC_state/custom_dmc_tasks/walker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Planar Walker Domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | from dm_control import suite 33 | 34 | _DEFAULT_TIME_LIMIT = 25 35 | _CONTROL_TIMESTEP = .025 36 | 37 | # Minimal height of torso over foot above which stand reward is 1. 38 | _STAND_HEIGHT = 1.2 39 | 40 | # Horizontal speeds (meters/second) above which move reward is 1. 
41 | _WALK_SPEED = 1 42 | _RUN_SPEED = 8 43 | _SPIN_SPEED = 5 44 | 45 | SUITE = containers.TaggedTasks() 46 | 47 | def make(task, 48 | task_kwargs=None, 49 | environment_kwargs=None, 50 | visualize_reward=False): 51 | task_kwargs = task_kwargs or {} 52 | if environment_kwargs is not None: 53 | task_kwargs = task_kwargs.copy() 54 | task_kwargs['environment_kwargs'] = environment_kwargs 55 | env = SUITE[task](**task_kwargs) 56 | env.task.visualize_reward = visualize_reward 57 | return env 58 | 59 | def get_model_and_assets(): 60 | """Returns a tuple containing the model XML string and a dict of assets.""" 61 | root_dir = os.path.dirname(os.path.dirname(__file__)) 62 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks', 63 | 'walker.xml')) 64 | return xml, common.ASSETS 65 | 66 | 67 | 68 | 69 | 70 | 71 | @SUITE.add('benchmarking') 72 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 73 | random=None, 74 | environment_kwargs=None): 75 | """Returns the Run task.""" 76 | physics = Physics.from_xml_string(*get_model_and_assets()) 77 | task = PlanarWalker(move_speed=_RUN_SPEED, 78 | forward=True, 79 | flip=True, 80 | random=random) 81 | environment_kwargs = environment_kwargs or {} 82 | return control.Environment(physics, 83 | task, 84 | time_limit=time_limit, 85 | control_timestep=_CONTROL_TIMESTEP, 86 | **environment_kwargs) 87 | 88 | 89 | class Physics(mujoco.Physics): 90 | """Physics simulation with additional features for the Walker domain.""" 91 | def torso_upright(self): 92 | """Returns projection from z-axes of torso to the z-axes of world.""" 93 | return self.named.data.xmat['torso', 'zz'] 94 | 95 | def torso_height(self): 96 | """Returns the height of the torso.""" 97 | return self.named.data.xpos['torso', 'z'] 98 | 99 | def horizontal_velocity(self): 100 | """Returns the horizontal velocity of the center-of-mass.""" 101 | return self.named.data.sensordata['torso_subtreelinvel'][0] 102 | 103 | def orientations(self): 104 | """Returns planar orientations of all bodies.""" 105 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel() 106 | 107 | def angmomentum(self): 108 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 109 | return self.named.data.subtree_angmom['torso'][1] 110 | 111 | 112 | class PlanarWalker(base.Task): 113 | """A planar walker task.""" 114 | def __init__(self, move_speed, forward=True, flip=False, random=None): 115 | """Initializes an instance of `PlanarWalker`. 116 | 117 | Args: 118 | move_speed: A float. If this value is zero, reward is given simply for 119 | standing up. Otherwise this specifies a target horizontal velocity for 120 | the walking task. 121 | random: Optional, either a `numpy.random.RandomState` instance, an 122 | integer seed for creating a new `RandomState`, or None to select a seed 123 | automatically (default). 124 | """ 125 | self._move_speed = move_speed 126 | self._forward = 1 if forward else -1 127 | self._flip = flip 128 | super(PlanarWalker, self).__init__(random=random) 129 | 130 | def initialize_episode(self, physics): 131 | """Sets the state of the environment at the start of each episode. 132 | 133 | In 'standing' mode, use initial orientation and small velocities. 134 | In 'random' mode, randomize joint angles and let fall to the floor. 135 | 136 | Args: 137 | physics: An instance of `Physics`. 
138 | 139 | """ 140 | randomizers.randomize_limited_and_rotational_joints( 141 | physics, self.random) 142 | super(PlanarWalker, self).initialize_episode(physics) 143 | 144 | def get_observation(self, physics): 145 | """Returns an observation of body orientations, height and velocites.""" 146 | obs = collections.OrderedDict() 147 | obs['orientations'] = physics.orientations() 148 | obs['height'] = physics.torso_height() 149 | obs['velocity'] = physics.velocity() 150 | return obs 151 | 152 | def get_reward(self, physics): 153 | """Returns a reward to the agent.""" 154 | standing = rewards.tolerance(physics.torso_height(), 155 | bounds=(_STAND_HEIGHT, float('inf')), 156 | margin=_STAND_HEIGHT / 2) 157 | upright = (1 + physics.torso_upright()) / 2 158 | stand_reward = (3 * standing + upright) / 4 159 | 160 | if self._flip: 161 | move_reward = rewards.tolerance(self._forward * 162 | physics.angmomentum(), 163 | bounds=(_SPIN_SPEED, float('inf')), 164 | margin=_SPIN_SPEED, 165 | value_at_margin=0, 166 | sigmoid='linear') 167 | else: 168 | move_reward = rewards.tolerance( 169 | self._forward * physics.horizontal_velocity(), 170 | bounds=(self._move_speed, float('inf')), 171 | margin=self._move_speed / 2, 172 | value_at_margin=0.5, 173 | sigmoid='linear') 174 | 175 | return stand_reward * (5 * move_reward + 1) / 6 176 | -------------------------------------------------------------------------------- /DMC_state/custom_dmc_tasks/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /DMC_state/dmc_benchmark.py: -------------------------------------------------------------------------------- 1 | DOMAINS = [ 2 | 'walker', 3 | 'quadruped', 4 | 'jaco', 5 | ] 6 | 7 | WALKER_TASKS = [ 8 | 'walker_stand', 9 | 'walker_walk', 10 | 'walker_run', 11 | 'walker_flip', 12 | ] 13 | 14 | QUADRUPED_TASKS = [ 15 | 'quadruped_walk', 16 | 'quadruped_run', 17 | 'quadruped_stand', 18 | 'quadruped_jump', 19 | ] 20 | 21 | JACO_TASKS = [ 22 | 'jaco_reach_top_left', 23 | 'jaco_reach_top_right', 24 | 'jaco_reach_bottom_left', 25 | 'jaco_reach_bottom_right', 26 | ] 27 | 28 | TASKS = WALKER_TASKS + QUADRUPED_TASKS + JACO_TASKS 29 | 30 | parameter_1 = ['0.2', '0.6', '1.0', '1.4', '1.8'] 31 | parameter_1_eval = ['0.4', '0.8', '1.2', '1.6'] 32 | 33 | parameter_2 = ['0.4', '0.8', '1.0', '1.4'] 34 | parameter_2_eval = ['0.6', '1.2'] 35 | 36 | PRETRAIN_TASKS = { 37 | 'walker': 'walker_stand', 38 | 'jaco': 'jaco_reach_top_left', 39 | 'quadruped': 'quadruped_walk', 40 | 'walker_mass': ['walker_stand~mass~' + para for para in parameter_1], 41 | 'quadruped_mass': ['quadruped_stand~mass~' + para for para in parameter_2], 42 | 'quadruped_damping': ['quadruped_stand~damping~' + para for para in parameter_1], 43 | } 44 | 45 | FINETUNE_TASKS = { 46 | 'walker_stand_mass': ['walker_stand~mass~' + para for para in parameter_1], 47 | 'walker_stand_mass_eval': ['walker_stand~mass~' + para for para in parameter_1_eval], 48 | 'walker_walk_mass': ['walker_walk~mass~' + para for para in parameter_1], 49 | 'walker_walk_mass_eval': ['walker_walk~mass~' + para for para in parameter_1_eval], 50 | 'walker_run_mass': ['walker_run~mass~' + para for para in parameter_1], 51 | 'walker_run_mass_eval': ['walker_run~mass~' + para for para in parameter_1_eval], 52 | 'walker_flip_mass': ['walker_flip~mass~' + para for para in parameter_1], 53 | 'walker_flip_mass_eval': ['walker_flip~mass~' + para for para in 
parameter_1_eval], 54 | 55 | 'quadruped_stand_mass': ['quadruped_stand~mass~' + para for para in parameter_2], 56 | 'quadruped_stand_mass_eval': ['quadruped_stand~mass~' + para for para in parameter_2_eval], 57 | 'quadruped_walk_mass': ['quadruped_walk~mass~' + para for para in parameter_2], 58 | 'quadruped_walk_mass_eval': ['quadruped_walk~mass~' + para for para in parameter_2_eval], 59 | 'quadruped_run_mass': ['quadruped_run~mass~' + para for para in parameter_2], 60 | 'quadruped_run_mass_eval': ['quadruped_run~mass~' + para for para in parameter_2_eval], 61 | 'quadruped_jump_mass': ['quadruped_jump~mass~' + para for para in parameter_2], 62 | 'quadruped_jump_mass_eval': ['quadruped_jump~mass~' + para for para in parameter_2_eval], 63 | 64 | 'quadruped_stand_damping': ['quadruped_stand~damping~' + para for para in parameter_1], 65 | 'quadruped_stand_damping_eval': ['quadruped_stand~damping~' + para for para in parameter_1_eval], 66 | 'quadruped_walk_damping': ['quadruped_walk~damping~' + para for para in parameter_1], 67 | 'quadruped_walk_damping_eval': ['quadruped_walk~damping~' + para for para in parameter_1_eval], 68 | 'quadruped_run_damping': ['quadruped_run~damping~' + para for para in parameter_1], 69 | 'quadruped_run_damping_eval': ['quadruped_run~damping~' + para for para in parameter_1_eval], 70 | 'quadruped_jump_damping': ['quadruped_jump~damping~' + para for para in parameter_1], 71 | 'quadruped_jump_damping_eval': ['quadruped_jump~damping~' + para for para in parameter_1_eval], 72 | } 73 | -------------------------------------------------------------------------------- /DMC_state/finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - configs/agent: ddpg 3 | - override hydra/launcher: submitit_local 4 | 5 | # mode 6 | reward_free: false 7 | # task settings 8 | task: none 9 | domain: walker_mass 10 | finetune_domain: walker_stand_mass 11 | obs_type: states # [states, pixels] 12 | frame_stack: 3 # only works if obs_type=pixels 13 | action_repeat: 1 # set to 2 for pixels 14 | discount: 0.99 15 | # train settings 16 | num_train_frames: 100010 17 | num_seed_frames: 4000 18 | # eval 19 | eval_every_frames: 10000 20 | num_eval_episodes: 10 21 | # pretrained 22 | snapshot_ts: 100000 23 | snapshot_base_dir: ../../../../../../pretrained_models 24 | # replay buffer 25 | replay_buffer_size: 1000000 26 | replay_buffer_num_workers: 4 27 | batch_size: ${agent.batch_size} 28 | nstep: ${agent.nstep} 29 | update_encoder: false # can be either true or false depending if we want to fine-tune encoder 30 | # misc 31 | seed: 1 32 | device: cuda 33 | save_video: true 34 | save_train_video: false 35 | use_tb: false 36 | use_wandb: false 37 | # experiment 38 | experiment: exp 39 | 40 | init_task: 1.0 41 | 42 | hydra: 43 | run: 44 | dir: ./exp_local/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed} 45 | sweep: 46 | dir: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment} 47 | subdir: ${hydra.job.num} 48 | launcher: 49 | timeout_min: 4300 50 | cpus_per_task: 10 51 | gpus_per_node: 1 52 | tasks_per_node: 1 53 | mem_gb: 160 54 | nodes: 1 55 | submitit_folder: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment}/.slurm -------------------------------------------------------------------------------- /DMC_state/finetune_ddpg.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DOMAIN=$1 # walker_mass, quadruped_mass, quadruped_damping 3 | GPU_ID=$2 4 | FINETUNE_TASK=$3 5 | 6 | echo "Experiments started." 7 | for seed in $(seq 0 9) 8 | do 9 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID} 10 | python finetune.py configs/agent=ddpg domain=${DOMAIN} seed=${seed} device=cuda:${GPU_ID} snapshot_ts=0 finetune_domain=${FINETUNE_TASK} num_train_frames=2000010 11 | done 12 | echo "Experiments ended." 13 | 14 | # e.g. 15 | # ./finetune_ddpg.sh walker_mass 0 walker_stand_mass -------------------------------------------------------------------------------- /DMC_state/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import wandb 9 | from termcolor import colored 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 14 | ('episode_reward', 'R', 'float'), 15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')] 16 | 17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 19 | ('episode_reward', 'R', 'float'), 20 | ('total_time', 'T', 'time'), 21 | ('episode_train_reward', 'TR', 'float'), 22 | ('episode_eval_reward', 'ER', 'float')] 23 | 24 | 25 | class AverageMeter(object): 26 | def __init__(self): 27 | self._sum = 0 28 | self._count = 0 29 | 30 | def update(self, value, n=1): 31 | self._sum += value 32 | self._count += n 33 | 34 | def value(self): 35 | return self._sum / max(1, self._count) 36 | 37 | 38 | class MetersGroup(object): 39 | def __init__(self, csv_file_name, formating, use_wandb): 40 | self._csv_file_name = csv_file_name 41 | self._formating = formating 42 | self._meters = defaultdict(AverageMeter) 43 | self._csv_file = None 44 | self._csv_writer = None 45 | self.use_wandb = use_wandb 46 | 47 | def log(self, key, value, n=1): 48 | self._meters[key].update(value, n) 49 | 50 | def _prime_meters(self): 51 | data = dict() 52 | for key, meter in self._meters.items(): 53 | if key.startswith('train'): 54 | key = key[len('train') + 1:] 55 | else: 56 | key = key[len('eval') + 1:] 57 | key = key.replace('/', '_') 58 | data[key] = meter.value() 59 | return data 60 | 61 | def _remove_old_entries(self, data): 62 | rows = [] 63 | with self._csv_file_name.open('r') as f: 64 | reader = csv.DictReader(f) 65 | for row in reader: 66 | if float(row['episode']) >= data['episode']: 67 | break 68 | rows.append(row) 69 | with self._csv_file_name.open('w') as f: 70 | writer = csv.DictWriter(f, 71 | fieldnames=sorted(data.keys()), 72 | restval=0.0) 73 | writer.writeheader() 74 | for row in rows: 75 | writer.writerow(row) 76 | 77 | def _dump_to_csv(self, data): 78 | if self._csv_writer is None: 79 | should_write_header = True 80 | if self._csv_file_name.exists(): 81 | self._remove_old_entries(data) 82 | should_write_header = False 83 | 84 | self._csv_file = self._csv_file_name.open('a') 85 | self._csv_writer = csv.DictWriter(self._csv_file, 86 | fieldnames=sorted(data.keys()), 87 | restval=0.0) 88 | if should_write_header: 89 | self._csv_writer.writeheader() 90 | 91 | self._csv_writer.writerow(data) 92 | self._csv_file.flush() 93 | 94 | def _format(self, key, value, ty): 95 | if ty == 'int': 96 | 
value = int(value) 97 | return f'{key}: {value}' 98 | elif ty == 'float': 99 | return f'{key}: {value:.04f}' 100 | elif ty == 'time': 101 | value = str(datetime.timedelta(seconds=int(value))) 102 | return f'{key}: {value}' 103 | else: 104 | raise f'invalid format type: {ty}' 105 | 106 | def _dump_to_console(self, data, prefix): 107 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 108 | pieces = [f'| {prefix: <14}'] 109 | for key, disp_key, ty in self._formating: 110 | value = data.get(key, 0) 111 | pieces.append(self._format(disp_key, value, ty)) 112 | print(' | '.join(pieces)) 113 | 114 | def _dump_to_wandb(self, data): 115 | wandb.log(data) 116 | 117 | def dump(self, step, prefix): 118 | if len(self._meters) == 0: 119 | return 120 | data = self._prime_meters() 121 | data['frame'] = step 122 | if self.use_wandb: 123 | wandb_data = {prefix + '/' + key: val for key, val in data.items()} 124 | self._dump_to_wandb(data=wandb_data) 125 | self._dump_to_csv(data) 126 | self._dump_to_console(data, prefix) 127 | self._meters.clear() 128 | 129 | 130 | class Logger(object): 131 | def __init__(self, log_dir, use_tb, use_wandb): 132 | self._log_dir = log_dir 133 | self._train_mg = MetersGroup(log_dir / 'train.csv', 134 | formating=COMMON_TRAIN_FORMAT, 135 | use_wandb=use_wandb) 136 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 137 | formating=COMMON_EVAL_FORMAT, 138 | use_wandb=use_wandb) 139 | if use_tb: 140 | self._sw = SummaryWriter(str(log_dir / 'tb')) 141 | else: 142 | self._sw = None 143 | self.use_wandb = use_wandb 144 | 145 | def _try_sw_log(self, key, value, step): 146 | if self._sw is not None: 147 | self._sw.add_scalar(key, value, step) 148 | 149 | def log(self, key, value, step): 150 | assert key.startswith('train') or key.startswith('eval') 151 | if type(value) == torch.Tensor: 152 | value = value.item() 153 | self._try_sw_log(key, value, step) 154 | mg = self._train_mg if key.startswith('train') else self._eval_mg 155 | mg.log(key, value) 156 | 157 | def log_metrics(self, metrics, step, ty): 158 | for key, value in metrics.items(): 159 | self.log(f'{ty}/{key}', value, step) 160 | 161 | def dump(self, step, ty=None): 162 | if ty is None or ty == 'eval': 163 | self._eval_mg.dump(step, 'eval') 164 | if ty is None or ty == 'train': 165 | self._train_mg.dump(step, 'train') 166 | 167 | def log_and_dump_ctx(self, step, ty): 168 | return LogAndDumpCtx(self, step, ty) 169 | 170 | 171 | class LogAndDumpCtx: 172 | def __init__(self, logger, step, ty): 173 | self._logger = logger 174 | self._step = step 175 | self._ty = ty 176 | 177 | def __enter__(self): 178 | return self 179 | 180 | def __call__(self, key, value): 181 | self._logger.log(f'{self._ty}/{key}', value, self._step) 182 | 183 | def __exit__(self, *args): 184 | self._logger.dump(self._step, self._ty) 185 | -------------------------------------------------------------------------------- /DMC_state/pretrain.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | 10 | from pathlib import Path 11 | 12 | import hydra 13 | import numpy as np 14 | import torch 15 | import wandb 16 | from dm_env import specs 17 | 18 | import dmc 19 | import utils 20 | from logger import Logger 21 | from replay_buffer import ReplayBufferStorage, make_replay_loader 22 | from video import TrainVideoRecorder, 
VideoRecorder 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | from dmc_benchmark import PRETRAIN_TASKS 27 | 28 | 29 | def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg): 30 | cfg.obs_type = obs_type 31 | cfg.obs_shape = obs_spec.shape 32 | cfg.action_shape = action_spec.shape 33 | cfg.num_expl_steps = num_expl_steps 34 | return hydra.utils.instantiate(cfg) 35 | 36 | 37 | class Workspace: 38 | def __init__(self, cfg): 39 | self.work_dir = Path.cwd() 40 | print(f'workspace: {self.work_dir}') 41 | 42 | self.cfg = cfg 43 | utils.set_seed_everywhere(cfg.seed) 44 | self.device = torch.device(cfg.device) 45 | 46 | # create logger 47 | if cfg.use_wandb: 48 | exp_name = '_'.join([ 49 | cfg.experiment, cfg.agent.name, cfg.domain, cfg.obs_type, 50 | str(cfg.seed) 51 | ]) 52 | wandb.init(project="urlb", group=cfg.agent.name, name=exp_name) 53 | 54 | self.logger = Logger(self.work_dir, 55 | use_tb=cfg.use_tb, 56 | use_wandb=cfg.use_wandb) 57 | 58 | # create envs 59 | if cfg.task != 'none': 60 | # single task 61 | tasks = [cfg.task] 62 | else: 63 | # pre-define multi-task 64 | tasks = PRETRAIN_TASKS[self.cfg.domain] 65 | frame_stack = 1 66 | img_size = 64 67 | 68 | self.train_envs = [dmc.make(task, cfg.obs_type, cfg.frame_stack, 69 | cfg.action_repeat, cfg.seed, ) 70 | for task in tasks] 71 | self.tasks_name = tasks 72 | self.train_envs_number = len(self.train_envs) 73 | self.current_train_id = 0 74 | self.eval_env = [dmc.make(task, cfg.obs_type, cfg.frame_stack, 75 | cfg.action_repeat, cfg.seed, ) 76 | for task in tasks] 77 | 78 | # create agent 79 | if 'peac' in cfg.agent.name or 'context' in cfg.agent.name: 80 | cfg.agent['context_dim'] = self.train_envs_number 81 | 82 | self.agent = make_agent(cfg.obs_type, 83 | self.train_envs[0].observation_spec(), 84 | self.train_envs[0].action_spec(), 85 | cfg.num_seed_frames // cfg.action_repeat, 86 | cfg.agent) 87 | 88 | # get meta specs 89 | meta_specs = self.agent.get_meta_specs() 90 | # create replay buffer 91 | data_specs = (self.train_envs[0].observation_spec(), 92 | self.train_envs[0].action_spec(), 93 | specs.Array((1,), np.float32, 'reward'), 94 | specs.Array((1,), np.float32, 'discount'), 95 | specs.Array((1,), np.int64, 'embodiment_id'),) 96 | 97 | # create data storage 98 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs, 99 | self.work_dir / 'buffer') 100 | 101 | # create replay buffer 102 | his_o_a = cfg.agent.get('his_o_a', 0) 103 | print('history o and a:', his_o_a) 104 | self.replay_loader = make_replay_loader(self.replay_storage, 105 | cfg.replay_buffer_size, 106 | cfg.batch_size, 107 | cfg.replay_buffer_num_workers, 108 | False, cfg.nstep, cfg.discount, 109 | his_o_a=his_o_a) 110 | self._replay_iter = None 111 | 112 | # create video recorders 113 | self.video_recorder = VideoRecorder( 114 | self.work_dir if cfg.save_video else None, 115 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2, 116 | use_wandb=self.cfg.use_wandb) 117 | self.train_video_recorder = TrainVideoRecorder( 118 | self.work_dir if cfg.save_train_video else None, 119 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2, 120 | use_wandb=self.cfg.use_wandb) 121 | 122 | self.timer = utils.Timer() 123 | self._global_step = 0 124 | self._global_episode = 0 125 | 126 | @property 127 | def global_step(self): 128 | return self._global_step 129 | 130 | @property 131 | def global_episode(self): 132 | return self._global_episode 133 | 134 | @property 135 | def global_frame(self): 136 | return self.global_step * 
self.cfg.action_repeat 137 | 138 | @property 139 | def replay_iter(self): 140 | if self._replay_iter is None: 141 | self._replay_iter = iter(self.replay_loader) 142 | return self._replay_iter 143 | 144 | def eval(self): 145 | # we do not eval in the pre-train stage 146 | pass 147 | 148 | def train(self): 149 | # predicates 150 | train_until_step = utils.Until(self.cfg.num_train_frames, 151 | self.cfg.action_repeat) 152 | seed_until_step = utils.Until(self.cfg.num_seed_frames, 153 | self.cfg.action_repeat) 154 | eval_every_step = utils.Every(self.cfg.eval_every_frames, 155 | self.cfg.action_repeat) 156 | 157 | episode_step, episode_reward = 0, 0 158 | time_step = self.train_envs[self.current_train_id].reset() 159 | if hasattr(self.agent, "init_context"): 160 | self.agent.init_context() 161 | time_step['embodiment_id'] = self.current_train_id 162 | print('current task is', self.tasks_name[self.current_train_id]) 163 | meta = self.agent.init_meta() 164 | self.replay_storage.add(time_step, meta) 165 | # self.train_video_recorder.init(time_step.observation) 166 | self.train_video_recorder.init(time_step['observation']) 167 | metrics = None 168 | while train_until_step(self.global_step): 169 | # if time_step.last(): 170 | if time_step['is_last']: 171 | self._global_episode += 1 172 | self.train_video_recorder.save(f'{self.global_frame}.mp4') 173 | # wait until all the metrics schema is populated 174 | if metrics is not None: 175 | # log stats 176 | elapsed_time, total_time = self.timer.reset() 177 | episode_frame = episode_step * self.cfg.action_repeat 178 | with self.logger.log_and_dump_ctx(self.global_frame, 179 | ty='train') as log: 180 | log('fps', episode_frame / elapsed_time) 181 | log('total_time', total_time) 182 | log('episode_reward', episode_reward) 183 | log('episode_length', episode_frame) 184 | log('episode', self.global_episode) 185 | log('buffer_size', len(self.replay_storage)) 186 | log('step', self.global_step) 187 | 188 | # reset env 189 | self.current_train_id = (self.current_train_id + 1) % self.train_envs_number 190 | time_step = self.train_envs[self.current_train_id].reset() 191 | if hasattr(self.agent, "init_context"): 192 | self.agent.init_context() 193 | print('current task is', self.tasks_name[self.current_train_id]) 194 | time_step['embodiment_id'] = self.current_train_id 195 | 196 | meta = self.agent.init_meta() 197 | self.replay_storage.add(time_step, meta) 198 | # self.train_video_recorder.init(time_step.observation) 199 | self.train_video_recorder.init(time_step['observation']) 200 | # try to save snapshot 201 | if self.global_frame in self.cfg.snapshots: 202 | self.save_snapshot() 203 | episode_step = 0 204 | episode_reward = 0 205 | 206 | # try to evaluate 207 | if eval_every_step(self.global_step): 208 | self.logger.log('eval_total_time', self.timer.total_time(), 209 | self.global_frame) 210 | self.eval() 211 | 212 | meta = self.agent.update_meta(meta, self.global_step, time_step) 213 | # sample action 214 | with torch.no_grad(), utils.eval_mode(self.agent): 215 | action = self.agent.act(time_step['observation'], 216 | meta, 217 | self.global_step, 218 | eval_mode=False) 219 | # action = self.agent.act(time_step.observation, 220 | # meta, 221 | # self.global_step, 222 | # eval_mode=False) 223 | 224 | # try to update the agent 225 | if not seed_until_step(self.global_step): 226 | metrics = self.agent.update(self.replay_iter, self.global_step) 227 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 228 | 229 | # take env step 230 | time_step = 
self.train_envs[self.current_train_id].step(action) 231 | time_step['embodiment_id'] = self.current_train_id 232 | # episode_reward += time_step.reward 233 | episode_reward += time_step['reward'] 234 | self.replay_storage.add(time_step, meta) 235 | self.train_video_recorder.record(time_step['observation']) 236 | # self.train_video_recorder.record(time_step.observation) 237 | episode_step += 1 238 | self._global_step += 1 239 | 240 | def save_snapshot(self): 241 | snapshot_dir = self.work_dir / Path(self.cfg.snapshot_dir) 242 | snapshot_dir.mkdir(exist_ok=True, parents=True) 243 | snapshot = snapshot_dir / f'snapshot_{self.global_frame}.pt' 244 | keys_to_save = ['agent', '_global_step', '_global_episode'] 245 | payload = {k: self.__dict__[k] for k in keys_to_save} 246 | with snapshot.open('wb') as f: 247 | torch.save(payload, f) 248 | 249 | 250 | @hydra.main(config_path='.', config_name='pretrain') 251 | def main(cfg): 252 | from pretrain import Workspace as W 253 | root_dir = Path.cwd() 254 | workspace = W(cfg) 255 | snapshot = root_dir / 'snapshot.pt' 256 | if snapshot.exists(): 257 | print(f'resuming: {snapshot}') 258 | workspace.load_snapshot() 259 | workspace.train() 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /DMC_state/pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - configs/agent: ddpg 3 | - override hydra/launcher: submitit_local 4 | 5 | # mode 6 | reward_free: true 7 | # task settings 8 | task: none 9 | domain: walker # primary task will be inferred at runtime 10 | obs_type: states # [states, pixels] 11 | frame_stack: 3 # only works if obs_type=pixels 12 | action_repeat: 1 # set to 2 for pixels 13 | discount: 0.99 14 | # train settings 15 | num_train_frames: 2000010 16 | num_seed_frames: 4000 17 | # eval 18 | eval_every_frames: 10000 19 | num_eval_episodes: 10 20 | # snapshot 21 | snapshots: [100000, 500000, 1000000, 2000000] 22 | snapshot_dir: ../../../../../pretrained_models/${obs_type}/${domain}/${agent.name}/${seed} 23 | # replay buffer 24 | replay_buffer_size: 1000000 25 | #replay_buffer_num_workers: 4 26 | replay_buffer_num_workers: 4 27 | batch_size: ${agent.batch_size} 28 | nstep: ${agent.nstep} 29 | update_encoder: true # should always be true for pre-training 30 | # misc 31 | seed: 1 32 | device: cuda 33 | save_video: true 34 | save_train_video: false 35 | use_tb: false 36 | use_wandb: false 37 | # experiment 38 | experiment: exp 39 | 40 | 41 | hydra: 42 | run: 43 | dir: ./exp_local/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M%S}_${seed} 44 | sweep: 45 | dir: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M}_${seed}_${experiment} 46 | subdir: ${hydra.job.num} 47 | launcher: 48 | timeout_min: 4300 49 | cpus_per_task: 10 50 | gpus_per_node: 1 51 | tasks_per_node: 1 52 | mem_gb: 160 53 | nodes: 1 54 | submitit_folder: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%H%M}_${seed}_${experiment}/.slurm 55 | -------------------------------------------------------------------------------- /DMC_state/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import IterableDataset 11 | 12 | 13 | def episode_len(episode): 14 | # subtract 1
because the dummy first transition 15 | return next(iter(episode.values())).shape[0] - 1 16 | 17 | 18 | def save_episode(episode, fn): 19 | with io.BytesIO() as bs: 20 | np.savez_compressed(bs, **episode) 21 | bs.seek(0) 22 | with fn.open('wb') as f: 23 | f.write(bs.read()) 24 | 25 | 26 | def load_episode(fn): 27 | with fn.open('rb') as f: 28 | episode = np.load(f) 29 | episode = {k: episode[k] for k in episode.keys()} 30 | return episode 31 | 32 | 33 | class ReplayBufferStorage: 34 | def __init__(self, data_specs, meta_specs, replay_dir): 35 | self._data_specs = data_specs 36 | self._meta_specs = meta_specs 37 | self._replay_dir = replay_dir 38 | replay_dir.mkdir(exist_ok=True) 39 | self._current_episode = defaultdict(list) 40 | self._preload() 41 | 42 | def __len__(self): 43 | return self._num_transitions 44 | 45 | def add(self, time_step, meta): 46 | for key, value in meta.items(): 47 | self._current_episode[key].append(value) 48 | for spec in self._data_specs: 49 | value = time_step[spec.name] 50 | if np.isscalar(value): 51 | value = np.full(spec.shape, value, spec.dtype) 52 | assert spec.shape == value.shape and spec.dtype == value.dtype 53 | self._current_episode[spec.name].append(value) 54 | # if time_step.last(): 55 | if time_step['is_last']: 56 | episode = dict() 57 | for spec in self._data_specs: 58 | value = self._current_episode[spec.name] 59 | episode[spec.name] = np.array(value, spec.dtype) 60 | for spec in self._meta_specs: 61 | value = self._current_episode[spec.name] 62 | episode[spec.name] = np.array(value, spec.dtype) 63 | self._current_episode = defaultdict(list) 64 | self._store_episode(episode) 65 | 66 | def _preload(self): 67 | self._num_episodes = 0 68 | self._num_transitions = 0 69 | for fn in self._replay_dir.glob('*.npz'): 70 | _, _, eps_len = fn.stem.split('_') 71 | self._num_episodes += 1 72 | self._num_transitions += int(eps_len) 73 | 74 | def _store_episode(self, episode): 75 | eps_idx = self._num_episodes 76 | eps_len = episode_len(episode) 77 | self._num_episodes += 1 78 | self._num_transitions += eps_len 79 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 80 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 81 | save_episode(episode, self._replay_dir / eps_fn) 82 | 83 | 84 | class ReplayBuffer(IterableDataset): 85 | def __init__(self, storage, max_size, num_workers, nstep, discount, 86 | fetch_every, save_snapshot, his_o_a=0): 87 | self._storage = storage 88 | self._size = 0 89 | self._max_size = max_size 90 | self._num_workers = max(1, num_workers) 91 | self._episode_fns = [] 92 | self._episodes = dict() 93 | self._nstep = nstep 94 | self._discount = discount 95 | self._fetch_every = fetch_every 96 | self._samples_since_last_fetch = fetch_every 97 | self._save_snapshot = save_snapshot 98 | self._his_o_a = his_o_a 99 | 100 | def _sample_episode(self): 101 | # print(len(self._episode_fns)) 102 | eps_fn = random.choice(self._episode_fns) 103 | # print('eps_fn embodiment id', eps_fn, self._episodes[eps_fn]['embodiment_id']) 104 | return self._episodes[eps_fn] 105 | 106 | def _store_episode(self, eps_fn): 107 | try: 108 | episode = load_episode(eps_fn) 109 | except: 110 | return False 111 | eps_len = episode_len(episode) 112 | while eps_len + self._size > self._max_size: 113 | early_eps_fn = self._episode_fns.pop(0) 114 | early_eps = self._episodes.pop(early_eps_fn) 115 | self._size -= episode_len(early_eps) 116 | early_eps_fn.unlink(missing_ok=True) 117 | self._episode_fns.append(eps_fn) 118 | self._episode_fns.sort() 119 | self._episodes[eps_fn] = episode 
120 | self._size += eps_len 121 | 122 | if not self._save_snapshot: 123 | eps_fn.unlink(missing_ok=True) 124 | return True 125 | 126 | def _try_fetch(self): 127 | if self._samples_since_last_fetch < self._fetch_every: 128 | return 129 | self._samples_since_last_fetch = 0 130 | try: 131 | worker_id = torch.utils.data.get_worker_info().id 132 | except: 133 | worker_id = 0 134 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True) 135 | fetched_size = 0 136 | # print(eps_fns) 137 | for eps_fn in eps_fns: 138 | # print('hhh', eps_fn.stem) 139 | # print(eps_fn.stem.split('_')[1:]) 140 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 141 | if eps_idx % self._num_workers != worker_id: 142 | continue 143 | if eps_fn in self._episodes.keys(): 144 | break 145 | if fetched_size + eps_len > self._max_size: 146 | break 147 | fetched_size += eps_len 148 | if not self._store_episode(eps_fn): 149 | break 150 | 151 | def _sample(self): 152 | try: 153 | self._try_fetch() 154 | except: 155 | traceback.print_exc() 156 | self._samples_since_last_fetch += 1 157 | episode = self._sample_episode() 158 | # add +1 for the first dummy transition 159 | idx = np.random.randint(0, episode_len(episode) - self._his_o_a - 160 | self._nstep + 1) + 1 + self._his_o_a 161 | # print(idx) 162 | meta = [] 163 | for spec in self._storage._meta_specs: 164 | meta.append(episode[spec.name][idx - 1]) 165 | obs = episode['observation'][idx - 1] 166 | action = episode['action'][idx] 167 | next_obs = episode['observation'][idx + self._nstep - 1] 168 | reward = np.zeros_like(episode['reward'][idx]) 169 | discount = np.ones_like(episode['discount'][idx]) 170 | embodiment_id = episode['embodiment_id'][idx] 171 | for i in range(self._nstep): 172 | step_reward = episode['reward'][idx + i] 173 | reward += discount * step_reward 174 | discount *= episode['discount'][idx + i] * self._discount 175 | # print(obs.shape) 176 | # print('lalala', episode['embodiment_id'], embodiment_id) 177 | if self._his_o_a == 0: 178 | return (obs, action, reward, discount, next_obs, embodiment_id, *meta) 179 | his_o = episode['observation'][idx - self._his_o_a - 1: idx - 1] 180 | his_a = episode['action'][idx - self._his_o_a: idx] 181 | return (obs, action, reward, discount, next_obs, embodiment_id, his_o, his_a, *meta) 182 | 183 | def __iter__(self): 184 | while True: 185 | yield self._sample() 186 | 187 | 188 | def _worker_init_fn(worker_id): 189 | seed = np.random.get_state()[1][0] + worker_id 190 | np.random.seed(seed) 191 | random.seed(seed) 192 | 193 | 194 | def make_replay_loader(storage, max_size, batch_size, num_workers, 195 | save_snapshot, nstep, discount, his_o_a=0): 196 | max_size_per_worker = max_size // max(1, num_workers) 197 | 198 | iterable = ReplayBuffer(storage, 199 | max_size_per_worker, 200 | num_workers, 201 | nstep, 202 | discount, 203 | fetch_every=1000, 204 | save_snapshot=save_snapshot, 205 | his_o_a=his_o_a,) 206 | 207 | loader = torch.utils.data.DataLoader(iterable, 208 | batch_size=batch_size, 209 | num_workers=num_workers, 210 | pin_memory=True, 211 | worker_init_fn=_worker_init_fn) 212 | return loader 213 | -------------------------------------------------------------------------------- /DMC_state/train_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ALGO=$1 3 | DOMAIN=$2 # walker_mass, quadruped_mass, quadruped_damping 4 | GPU_ID=$3 5 | 6 | 7 | if [ "$DOMAIN" == "walker_mass" ] 8 | then 9 | ALL_TASKS=("walker_stand_mass" 
"walker_walk_mass" "walker_run_mass" "walker_flip_mass") 10 | elif [ "$DOMAIN" == "quadruped_mass" ] 11 | then 12 | ALL_TASKS=("quadruped_stand_mass" "quadruped_walk_mass" "quadruped_run_mass" "quadruped_jump_mass") 13 | elif [ "$DOMAIN" == "quadruped_damping" ] 14 | then 15 | ALL_TASKS=("quadruped_stand_damping" "quadruped_walk_damping" "quadruped_run_damping" "quadruped_jump_damping") 16 | else 17 | ALL_TASKS=() 18 | echo "No matching tasks" 19 | exit 0 20 | fi 21 | 22 | echo "Experiments started." 23 | for seed in $(seq 0 9) 24 | do 25 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID} 26 | python pretrain.py configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID} 27 | for string in "${ALL_TASKS[@]}" 28 | do 29 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID} 30 | python finetune.py configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID} snapshot_ts=2000000 finetune_domain=$string 31 | done 32 | done 33 | echo "Experiments ended." 34 | 35 | # e.g. 36 | # ./train.sh peac walker_mass 0 -------------------------------------------------------------------------------- /DMC_state/video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import wandb 5 | 6 | 7 | class VideoRecorder: 8 | def __init__(self, 9 | root_dir, 10 | render_size=256, 11 | fps=20, 12 | camera_id=0, 13 | use_wandb=False): 14 | if root_dir is not None: 15 | self.save_dir = root_dir / 'eval_video' 16 | self.save_dir.mkdir(exist_ok=True) 17 | else: 18 | self.save_dir = None 19 | 20 | self.render_size = render_size 21 | self.fps = fps 22 | self.frames = [] 23 | self.camera_id = camera_id 24 | self.use_wandb = use_wandb 25 | 26 | def init(self, env, enabled=True): 27 | self.frames = [] 28 | self.enabled = self.save_dir is not None and enabled 29 | self.record(env) 30 | 31 | def record(self, env): 32 | if self.enabled: 33 | if hasattr(env, 'physics'): 34 | frame = env.physics.render(height=self.render_size, 35 | width=self.render_size, 36 | camera_id=self.camera_id) 37 | else: 38 | frame = env.render() 39 | self.frames.append(frame) 40 | 41 | def log_to_wandb(self): 42 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 43 | fps, skip = 6, 8 44 | wandb.log({ 45 | 'eval/video': 46 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 47 | }) 48 | 49 | def save(self, file_name): 50 | if self.enabled: 51 | if self.use_wandb: 52 | self.log_to_wandb() 53 | path = self.save_dir / file_name 54 | imageio.mimsave(str(path), self.frames, fps=self.fps) 55 | 56 | 57 | class TrainVideoRecorder: 58 | def __init__(self, 59 | root_dir, 60 | render_size=256, 61 | fps=20, 62 | camera_id=0, 63 | use_wandb=False): 64 | if root_dir is not None: 65 | self.save_dir = root_dir / 'train_video' 66 | self.save_dir.mkdir(exist_ok=True) 67 | else: 68 | self.save_dir = None 69 | 70 | self.render_size = render_size 71 | self.fps = fps 72 | self.frames = [] 73 | self.camera_id = camera_id 74 | self.use_wandb = use_wandb 75 | 76 | def init(self, obs, enabled=True): 77 | self.frames = [] 78 | self.enabled = self.save_dir is not None and enabled 79 | self.record(obs) 80 | 81 | def record(self, obs): 82 | if self.enabled: 83 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 84 | dsize=(self.render_size, self.render_size), 85 | interpolation=cv2.INTER_CUBIC) 86 | self.frames.append(frame) 87 | 88 | def log_to_wandb(self): 89 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 90 | fps, skip = 6, 8 91 | wandb.log({ 92 | 
'train/video': 93 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 94 | }) 95 | 96 | def save(self, file_name): 97 | if self.enabled: 98 | if self.use_wandb: 99 | self.log_to_wandb() 100 | path = self.save_dir / file_name 101 | imageio.mimsave(str(path), self.frames, fps=self.fps) 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 TSAIL group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CEURL 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2405.14073-b31b1b.svg)](https://arxiv.org/abs/2405.14073) [![Project Page](https://img.shields.io/badge/project-page-blue)](https://yingchengyang.github.io/ceurl) 4 | 5 | This is the official implementation of "PEAC: Unsupervised Pre-training for Cross-Embodiment Reinforcement Learning" (NeurIPS 2024). 6 | 7 | ## State-based DMC & Image-based DMC 8 | 9 | ### Installation 10 | 11 | The code is based on [URLB](https://github.com/rll-research/url_benchmark). 12 | 13 | You can create an Anaconda environment and install all required dependencies by running 14 | ```sh 15 | conda create -n ceurl python=3.8 16 | conda activate ceurl 17 | pip install -r requirements.txt 18 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 19 | ``` 20 | 21 | ### Instructions 22 | 23 | The simplest way to try PEAC on the three embodiment distributions of state-based DMC is to run 24 | ```sh 25 | cd DMC_state 26 | chmod +x train_finetune.sh 27 | 28 | ./train_finetune.sh peac walker_mass 0 29 | ./train_finetune.sh peac quadruped_mass 0 30 | ./train_finetune.sh peac quadruped_damping 0 31 | ``` 32 | 33 | The simplest way to try PEAC on the three embodiment distributions of image-based DMC is to run 34 | ```sh 35 | cd DMC_image 36 | chmod +x train_finetune.sh 37 | 38 | ./train_finetune.sh peac_lbs walker_mass 0 39 | ./train_finetune.sh peac_lbs quadruped_mass 0 40 | ./train_finetune.sh peac_lbs quadruped_damping 0 41 | 42 | ./train_finetune.sh peac_diayn walker_mass 0 43 | ./train_finetune.sh peac_diayn quadruped_mass 0 44 | ./train_finetune.sh peac_diayn quadruped_damping 0 45 | ``` 46 | 47 | ## Citation 48 | 49 | If you
find this work helpful, please cite our paper. 50 | 51 | ``` 52 | @article{ying2024peac, 53 | title={PEAC: Unsupervised Pre-training for Cross-Embodiment Reinforcement Learning}, 54 | author={Ying, Chengyang and Hao, Zhongkai and Zhou, Xinning and Xu, Xuezhou and Su, Hang and Zhang, Xingxing and Zhu, Jun}, 55 | journal={arXiv preprint arXiv:2405.14073}, 56 | year={2024} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.1.0 2 | hydra-submitit-launcher==1.1.5 3 | termcolor==1.1.0 4 | wandb==0.11.1 5 | gym==0.26.2 6 | protobuf==3.20.0 7 | dm_env==1.6 8 | dm_control==1.0.14 9 | tensorboard==2.0.2 10 | numpy==1.19.2 11 | torchaudio==0.8.0 12 | pandas==1.3.0 13 | mujoco==2.3.7 14 | opencv-python 15 | imageio==2.9.0 16 | imageio-ffmpeg==0.4.4 --------------------------------------------------------------------------------