├── .gitignore
├── DMC_image
│ ├── LICENSE
│ ├── agent
│ │ ├── aps.py
│ │ ├── apt_dreamer.py
│ │ ├── choreo.py
│ │ ├── cic.py
│ │ ├── diayn.py
│ │ ├── dreamer.py
│ │ ├── dreamer_utils.py
│ │ ├── icm.py
│ │ ├── lbs.py
│ │ ├── lsd.py
│ │ ├── peac.py
│ │ ├── peac_diayn.py
│ │ ├── peac_lbs.py
│ │ ├── plan2explore.py
│ │ ├── random_dreamer.py
│ │ ├── rnd.py
│ │ ├── skill_utils.py
│ │ └── spectral_utils.py
│ ├── configs
│ │ ├── agent
│ │ │ ├── aps_dreamer.yaml
│ │ │ ├── apt_dreamer.yaml
│ │ │ ├── choreo.yaml
│ │ │ ├── cic_dreamer.yaml
│ │ │ ├── diayn_dreamer.yaml
│ │ │ ├── dreamer.yaml
│ │ │ ├── icm_dreamer.yaml
│ │ │ ├── lbs_dreamer.yaml
│ │ │ ├── lsd.yaml
│ │ │ ├── peac_diayn.yaml
│ │ │ ├── peac_lbs.yaml
│ │ │ ├── plan2explore.yaml
│ │ │ ├── random_dreamer.yaml
│ │ │ └── rnd_dreamer.yaml
│ │ ├── dmc_pixels.yaml
│ │ ├── dmc_states.yaml
│ │ └── dreamer.yaml
│ ├── custom_dmc_tasks
│ │ ├── __init__.py
│ │ ├── jaco.py
│ │ ├── quadruped.py
│ │ ├── quadruped.xml
│ │ ├── walker.py
│ │ └── walker.xml
│ ├── dmc_benchmark.py
│ ├── dreamer_finetune.py
│ ├── dreamer_finetune.yaml
│ ├── dreamer_pretrain.py
│ ├── dreamer_pretrain.yaml
│ ├── dreamer_replay.py
│ ├── envs.py
│ ├── logger.py
│ ├── train_finetune.sh
│ └── utils.py
├── DMC_state
│ ├── LICENSE
│ ├── README.md
│ ├── agent
│ │ ├── aps.py
│ │ ├── becl.py
│ │ ├── cic.py
│ │ ├── ddpg.py
│ │ ├── diayn.py
│ │ ├── disagreement.py
│ │ ├── icm.py
│ │ ├── lbs.py
│ │ ├── peac.py
│ │ ├── proto.py
│ │ ├── rnd.py
│ │ └── smm.py
│ ├── configs
│ │ └── agent
│ │   ├── aps.yaml
│ │   ├── becl.yaml
│ │   ├── cic.yaml
│ │   ├── ddpg.yaml
│ │   ├── diayn.yaml
│ │   ├── disagreement.yaml
│ │   ├── icm.yaml
│ │   ├── lbs.yaml
│ │   ├── peac.yaml
│ │   ├── proto.yaml
│ │   ├── rnd.yaml
│ │   └── smm.yaml
│ ├── custom_dmc_tasks
│ │ ├── __init__.py
│ │ ├── jaco.py
│ │ ├── quadruped.py
│ │ ├── quadruped.xml
│ │ ├── walker.py
│ │ └── walker.xml
│ ├── dmc.py
│ ├── dmc_benchmark.py
│ ├── finetune.py
│ ├── finetune.yaml
│ ├── finetune_ddpg.sh
│ ├── logger.py
│ ├── pretrain.py
│ ├── pretrain.yaml
│ ├── replay_buffer.py
│ ├── train_finetune.sh
│ ├── utils.py
│ └── video.py
├── LICENSE
├── README.md
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | DMC_state/exp_local
2 | DMC_state/pretrained_models
3 | DMC_image/exp_local
4 | DMC_image/pretrained_models
5 |
6 | __pycache__/
7 | **/__pycache__/
--------------------------------------------------------------------------------
/DMC_image/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Pietro Mazzaglia
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/DMC_image/agent/apt_dreamer.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import utils
8 | from agent.dreamer import DreamerAgent, stop_gradient
9 | import agent.dreamer_utils as common
10 |
11 |
12 | class APTDreamerAgent(DreamerAgent):
13 | def __init__(self, knn_rms, knn_k, knn_avg, knn_clip, **kwargs):
14 | super().__init__(**kwargs)
15 | self.reward_free = True
16 |
17 | # particle-based entropy
18 | rms = utils.RMS(self.device)
19 | self.pbe = utils.PBE(rms, knn_clip, knn_k, knn_avg, knn_rms,
20 | self.device)
21 |
22 | self.requires_grad_(requires_grad=False)
23 |
24 | def compute_intr_reward(self, seq):
25 | rep = stop_gradient(seq['deter'])
26 | B, T, _ = rep.shape
27 | rep = rep.reshape(B*T, -1)
28 | reward = self.pbe(rep, cdist=True)
29 | reward = reward.reshape(B, T, 1)
30 | return reward
31 |
32 | def update(self, data, step):
33 | metrics = {}
34 | B, T, _ = data['action'].shape
35 |
36 | state, outputs, mets = self.wm.update(data, state=None)
37 | metrics.update(mets)
38 | start = outputs['post']
39 | start = {k: stop_gradient(v) for k,v in start.items()}
40 | if self.reward_free:
41 | reward_fn = lambda seq: self.compute_intr_reward(seq)
42 | else:
43 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode()
44 | metrics.update(self._task_behavior.update(
45 | self.wm, start, data['is_terminal'], reward_fn))
46 | return state, metrics
--------------------------------------------------------------------------------
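
Note: `compute_intr_reward` above delegates to `utils.PBE` (not shown in this section), a particle-based entropy estimator over the deterministic latent states. A minimal sketch of the underlying k-nearest-neighbour idea, without the running-statistics normalization that `knn_rms`/`knn_avg`/`knn_clip` configure:

import torch

def knn_entropy_reward(rep, k=12):
    # rep: (N, D) batch of latent states. Each particle is rewarded with
    # log(1 + mean distance to its k nearest neighbours), a simple
    # particle-based estimate of state entropy.
    dists = torch.cdist(rep, rep, p=2)                    # (N, N) pairwise L2 distances
    knn, _ = dists.topk(k + 1, dim=1, largest=False)      # includes the zero self-distance
    return torch.log(1.0 + knn[:, 1:].mean(dim=1, keepdim=True))  # (N, 1)

# Shapes as used in compute_intr_reward: B*T particles of the deterministic state.
reward = knn_entropy_reward(torch.randn(16 * 50, 200), k=12)
print(reward.shape)  # torch.Size([800, 1])
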
/DMC_image/agent/icm.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import utils
8 | from agent.dreamer import DreamerAgent, stop_gradient
9 | import agent.dreamer_utils as common
10 |
11 |
12 | class ICM(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim):
14 | super().__init__()
15 | self.forward_net = nn.Sequential(
16 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
17 | nn.Linear(hidden_dim, obs_dim))
18 |
19 | self.backward_net = nn.Sequential(nn.Linear(2 * obs_dim, hidden_dim),
20 | nn.ReLU(),
21 | nn.Linear(hidden_dim, action_dim),
22 | nn.Tanh())
23 |
24 | self.apply(utils.weight_init)
25 |
26 | def forward(self, obs, action, next_obs):
27 | assert obs.shape[0] == next_obs.shape[0]
28 | assert obs.shape[0] == action.shape[0]
29 |
30 | next_obs_hat = self.forward_net(torch.cat([obs, action], dim=-1))
31 | action_hat = self.backward_net(torch.cat([obs, next_obs], dim=-1))
32 |
33 | forward_error = torch.norm(next_obs - next_obs_hat,
34 | dim=-1,
35 | p=2,
36 | keepdim=True)
37 | backward_error = torch.norm(action - action_hat,
38 | dim=-1,
39 | p=2,
40 | keepdim=True)
41 |
42 | return forward_error, backward_error
43 |
44 |
45 | class ICMDreamerAgent(DreamerAgent):
46 | def __init__(self, icm_scale, **kwargs):
47 | super().__init__(**kwargs)
48 | in_dim = self.wm.inp_size
49 | pred_dim = self.wm.embed_dim
50 | self.hidden_dim = pred_dim
51 | self.reward_free = True
52 | self.icm_scale = icm_scale
53 |
54 | self.icm = ICM(pred_dim, self.act_dim,
55 | self.hidden_dim).to(self.device)
56 |
57 | # optimizers
58 | self.icm_opt = common.Optimizer('icm', self.icm.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
59 |
60 | self.icm.train()
61 | self.requires_grad_(requires_grad=False)
62 |
63 | def update_icm(self, obs, action, next_obs, step):
64 | metrics = dict()
65 |
66 | forward_error, backward_error = self.icm(obs, action, next_obs)
67 |
68 | loss = forward_error.mean() # + backward_error.mean()
69 |
70 | metrics.update(self.icm_opt(loss, self.icm.parameters()))
71 |
72 | metrics['icm_loss'] = loss.item()
73 |
74 | return metrics
75 |
76 | def compute_intr_reward(self, obs, action, next_obs, step):
77 | forward_error, _ = self.icm(obs, action, next_obs)
78 |
79 | reward = forward_error * self.icm_scale
80 | reward = torch.log(reward + 1.0)
81 | return reward
82 |
83 | def update(self, data, step):
84 | metrics = {}
85 | B, T, _ = data['action'].shape
86 |
87 | if self.reward_free:
88 | T = T-1
89 | temp_data = self.wm.preprocess(data)
90 | embed = self.wm.encoder(temp_data)
91 | inp = stop_gradient(embed[:, :-1]).reshape(B*T, -1)
92 | action = data['action'][:, 1:].reshape(B*T, -1)
93 | out = stop_gradient(embed[:, 1:]).reshape(B*T, -1)
94 | with common.RequiresGrad(self.icm):
95 | with torch.cuda.amp.autocast(enabled=self._use_amp):
96 | metrics.update(
97 | self.update_icm(inp, action, out, step))
98 |
99 | with torch.no_grad():
100 | intr_reward = self.compute_intr_reward(inp, action, out, step).reshape(B, T, 1)
101 |
102 | data['reward'][:, 0] = 1
103 | data['reward'][:, 1:] = intr_reward
104 |
105 | state, outputs, mets = self.wm.update(data, state=None)
106 | metrics.update(mets)
107 | start = outputs['post']
108 | start = {k: stop_gradient(v) for k, v in start.items()}
109 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode()
110 | metrics.update(self._task_behavior.update(
111 | self.wm, start, data['is_terminal'], reward_fn))
112 | return state, metrics
--------------------------------------------------------------------------------
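
Note: the `ICM` module above is self-contained apart from `utils.weight_init`. A quick shape check with dummy tensors, a sketch assuming the DMC_image root is the working directory and its dependencies are installed (the dimensions are placeholders, not the repo's actual `embed_dim`/`act_dim`):

import torch
from agent.icm import ICM  # assumption: run from the DMC_image root

obs_dim, action_dim, hidden_dim = 1024, 6, 1024   # placeholder sizes
icm = ICM(obs_dim, action_dim, hidden_dim)

obs = torch.randn(32, obs_dim)         # encoder embedding of o_t
action = torch.randn(32, action_dim)   # a_t
next_obs = torch.randn(32, obs_dim)    # encoder embedding of o_{t+1}

forward_error, backward_error = icm(obs, action, next_obs)
print(forward_error.shape, backward_error.shape)   # torch.Size([32, 1]) twice
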
/DMC_image/agent/lbs.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.distributions as D
9 |
10 | import utils
11 | from agent.dreamer import DreamerAgent, WorldModel, stop_gradient
12 | import agent.dreamer_utils as common
13 |
14 |
15 | class LBSDreamerAgent(DreamerAgent):
16 | def __init__(self, **kwargs):
17 | super().__init__(**kwargs)
18 |
19 | self.reward_free = True
20 |
21 | # LBS
22 | self.lbs = common.MLP(self.wm.inp_size, (1,), **self.cfg.reward_head).to(self.device)
23 | self.lbs_opt = common.Optimizer('lbs', self.lbs.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
24 | self.lbs.train()
25 |
26 | self.requires_grad_(requires_grad=False)
27 |
28 | def update_lbs(self, outs):
29 | metrics = dict()
30 | B, T, _ = outs['feat'].shape
31 | feat, kl = outs['feat'].detach(), outs['kl'].detach()
32 | feat = feat.reshape(B*T, -1)
33 | kl = kl.reshape(B*T, -1)
34 |
35 | loss = -self.lbs(feat).log_prob(kl).mean()
36 | metrics.update(self.lbs_opt(loss, self.lbs.parameters()))
37 | metrics['lbs_loss'] = loss.item()
38 | return metrics
39 |
40 | def update(self, data, step):
41 | metrics = {}
42 | state, outputs, mets = self.wm.update(data, state=None)
43 | metrics.update(mets)
44 | start = outputs['post']
45 | start = {k: stop_gradient(v) for k, v in start.items()}
46 |
47 | if self.reward_free:
48 | with common.RequiresGrad(self.lbs):
49 | with torch.cuda.amp.autocast(enabled=self._use_amp):
50 | metrics.update(
51 | self.update_lbs(outputs))
52 | reward_fn = lambda seq: self.lbs(seq['feat']).mean
53 | else:
54 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode()
55 |
56 | metrics.update(self._task_behavior.update(
57 | self.wm, start, data['is_terminal'], reward_fn))
58 | return state, metrics
--------------------------------------------------------------------------------
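
Note: the LBS head is trained to predict the world model's per-step posterior/prior KL from the model features; during imagination, where that KL is unavailable, the head's prediction serves as the intrinsic reward. A stripped-down sketch of that regression, using a plain MSE objective in place of the repo's `common.MLP` head with `dist: mse` (whose negative log-likelihood reduces to a squared error up to constants):

import torch
import torch.nn as nn

feat_dim = 1224                        # placeholder standing in for wm.inp_size
lbs_head = nn.Sequential(nn.Linear(feat_dim, 400), nn.ELU(), nn.Linear(400, 1))
opt = torch.optim.Adam(lbs_head.parameters(), lr=3e-4)

# One regression step on (feature, KL) pairs produced by the world-model update.
feat = torch.randn(50 * 50, feat_dim)      # B*T flattened, detached features
kl = torch.rand(50 * 50, 1)                # per-step posterior/prior KL, detached
loss = (lbs_head(feat) - kl).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

# Inside imagination, the head's prediction plays the role of the reward function.
intr_reward = lbs_head(torch.randn(16, 15, feat_dim))   # (B, horizon, 1)
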
/DMC_image/agent/peac_lbs.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.distributions as D
9 |
10 | import utils
11 | from agent.peac import PEACAgent, stop_gradient
12 | import agent.dreamer_utils as common
13 |
14 |
15 | class PEAC_LBSAgent(PEACAgent):
16 | def __init__(self, beta=1.0, **kwargs):
17 | super().__init__(**kwargs)
18 |
19 | self.reward_free = True
20 | self.beta = beta
21 | print("beta:", self.beta)
22 |
23 | # LBS
24 | # feat + context -> predict kl
25 | self.lbs = common.MLP(self.wm.inp_size+self.task_number, (1,),
26 | **self.cfg.reward_head).to(self.device)
27 | self.lbs_opt = common.Optimizer('lbs', self.lbs.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
28 | self.lbs.train()
29 |
30 | self.requires_grad_(requires_grad=False)
31 |
32 | def update_lbs(self, outs):
33 | metrics = dict()
34 | B, T, _ = outs['feat'].shape
35 | feat, kl = outs['feat'].detach(), outs['kl'].detach()
36 | feat = feat.reshape(B * T, -1)
37 | kl = kl.reshape(B * T, -1)
38 | context = F.softmax(self.wm.task_model(feat), dim=-1).detach()
39 |
40 | loss = -self.lbs(torch.cat([feat, context], dim=-1)).log_prob(kl).mean()
41 | metrics.update(self.lbs_opt(loss, self.lbs.parameters()))
42 | metrics['lbs_loss'] = loss.item()
43 | return metrics
44 |
45 | def update(self, data, step):
46 | metrics = {}
47 | state, outputs, mets = self.wm.update(data, state=None)
48 | metrics.update(mets)
49 | start = outputs['post']
50 | start['embodiment_id'] = data['embodiment_id']
51 | start['context'] = outputs['context']
52 | start = {k: stop_gradient(v) for k, v in start.items()}
53 |
54 | if self.reward_free:
55 | with common.RequiresGrad(self.lbs):
56 | with torch.cuda.amp.autocast(enabled=self._use_amp):
57 | metrics.update(self.update_lbs(outputs))
58 | reward_fn = lambda seq: self.compute_intr_reward(seq) + \
59 | self.beta * self.compute_task_reward(seq)
60 | else:
61 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean # .mode()
62 |
63 | metrics.update(self._task_behavior.update(
64 | self.wm, start, data['is_terminal'], reward_fn))
65 | return state, metrics
66 |
67 | def compute_intr_reward(self, seq):
68 | context = F.softmax(self.wm.task_model(seq['feat']), dim=-1)
69 | return self.lbs(torch.cat([seq['feat'], context], dim=-1)).mean
70 |
71 | def compute_task_reward(self, seq):
72 | # print('we use calculated reward')
73 | B, T, _ = seq['feat'].shape
74 | task_pred = self.wm.task_model(seq['feat'])
75 | task_truth = seq['embodiment_id'].repeat(B, 1, 1).to(dtype=torch.int64)
76 | # print(task_pred.shape) # 16, 2500, task_number
77 | # print(seq['action'].shape) # 16, 2500, _
78 | # print(task_truth.shape) # 16, 2500, 1
79 | task_pred = F.log_softmax(task_pred, dim=2)
80 | task_rew = task_pred.reshape(B * T, -1)[torch.arange(B * T), task_truth.reshape(-1)]
81 | task_rew = -task_rew.reshape(B, T, 1)
82 |
83 | # print(intr_rew.shape) # 16, 2500, 1
84 | return task_rew
85 |
--------------------------------------------------------------------------------
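
Note: `compute_task_reward` above is the negative log-likelihood of the true `embodiment_id` under the world model's `task_model` classifier, extracted with a flat-index lookup; the reward is large where the classifier assigns low probability to the true embodiment. A small, self-contained illustration of that indexing:

import torch
import torch.nn.functional as F

B, T, n_embodiments = 2, 3, 4
logits = torch.randn(B, T, n_embodiments)        # stands in for task_model(seq['feat'])
ids = torch.randint(n_embodiments, (B, T, 1))    # ground-truth embodiment ids

log_probs = F.log_softmax(logits, dim=2)
picked = log_probs.reshape(B * T, -1)[torch.arange(B * T), ids.reshape(-1)]
task_rew = -picked.reshape(B, T, 1)              # large when the classifier is uncertain/wrong

# Equivalent, and sometimes clearer: gather along the class dimension.
assert torch.allclose(task_rew, -log_probs.gather(2, ids))
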
/DMC_image/agent/plan2explore.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import utils
8 | from agent.dreamer import DreamerAgent, stop_gradient
9 | import agent.dreamer_utils as common
10 |
11 |
12 | class Disagreement(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None):
14 | super().__init__()
15 | if pred_dim is None: pred_dim = obs_dim
16 | self.ensemble = nn.ModuleList([
17 | nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim),
18 | nn.ReLU(), nn.Linear(hidden_dim, pred_dim))
19 | for _ in range(n_models)
20 | ])
21 |
22 | def forward(self, obs, action, next_obs):
23 | #import ipdb; ipdb.set_trace()
24 | assert obs.shape[0] == next_obs.shape[0]
25 | assert obs.shape[0] == action.shape[0]
26 |
27 | errors = []
28 | for model in self.ensemble:
29 | next_obs_hat = model(torch.cat([obs, action], dim=-1))
30 | model_error = torch.norm(next_obs - next_obs_hat,
31 | dim=-1,
32 | p=2,
33 | keepdim=True)
34 | errors.append(model_error)
35 |
36 | return torch.cat(errors, dim=1)
37 |
38 | def get_disagreement(self, obs, action):
39 | assert obs.shape[0] == action.shape[0]
40 |
41 | preds = []
42 | for model in self.ensemble:
43 | next_obs_hat = model(torch.cat([obs, action], dim=-1))
44 | preds.append(next_obs_hat)
45 | preds = torch.stack(preds, dim=0)
46 | return torch.var(preds, dim=0).mean(dim=-1)
47 |
48 |
49 | class Plan2Explore(DreamerAgent):
50 | def __init__(self, **kwargs):
51 | super().__init__(**kwargs)
52 | in_dim = self.wm.inp_size
53 | pred_dim = self.wm.embed_dim
54 | self.hidden_dim = pred_dim
55 | self.reward_free = True
56 |
57 | self.disagreement = Disagreement(in_dim, self.act_dim,
58 | self.hidden_dim, pred_dim=pred_dim).to(self.device)
59 |
60 | # optimizers
61 | self.disagreement_opt = common.Optimizer('disagreement', self.disagreement.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
62 | self.disagreement.train()
63 | self.requires_grad_(requires_grad=False)
64 |
65 | def update_disagreement(self, obs, action, next_obs, step):
66 | metrics = dict()
67 |
68 | error = self.disagreement(obs, action, next_obs)
69 |
70 | loss = error.mean()
71 |
72 | metrics.update(self.disagreement_opt(loss, self.disagreement.parameters()))
73 |
74 | metrics['disagreement_loss'] = loss.item()
75 |
76 | return metrics
77 |
78 | def compute_intr_reward(self, seq):
79 | obs, action = seq['feat'][:-1], stop_gradient(seq['action'][1:])
80 | intr_rew = torch.zeros(list(seq['action'].shape[:-1]) + [1], device=self.device)
81 | if len(action.shape) > 2:
82 | B, T, _ = action.shape
83 | obs = obs.reshape(B*T, -1)
84 | action = action.reshape(B*T, -1)
85 | reward = self.disagreement.get_disagreement(obs, action).reshape(B, T, 1)
86 | else:
87 | reward = self.disagreement.get_disagreement(obs, action).unsqueeze(-1)
88 | intr_rew[1:] = reward
89 | return intr_rew
90 |
91 | def update(self, data, step):
92 | metrics = {}
93 | B, T, _ = data['action'].shape
94 | state, outputs, mets = self.wm.update(data, state=None)
95 | metrics.update(mets)
96 | start = outputs['post']
97 | start = {k: stop_gradient(v) for k,v in start.items()}
98 | if self.reward_free:
99 | T = T-1
100 | inp = stop_gradient(outputs['feat'][:, :-1]).reshape(B*T, -1)
101 | action = data['action'][:, 1:].reshape(B*T, -1)
102 | out = stop_gradient(outputs['embed'][:, 1:]).reshape(B*T, -1)
103 | with common.RequiresGrad(self.disagreement):
104 | with torch.cuda.amp.autocast(enabled=self._use_amp):
105 | metrics.update(
106 | self.update_disagreement(inp, action, out, step))
107 | metrics.update(self._task_behavior.update(
108 | self.wm, start, data['is_terminal'], reward_fn=self.compute_intr_reward))
109 | else:
110 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode()
111 | metrics.update(self._task_behavior.update(
112 | self.wm, start, data['is_terminal'], reward_fn))
113 | return state, metrics
--------------------------------------------------------------------------------
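
Note: `Disagreement.get_disagreement` scores a state-action pair by the variance of the ensemble's next-embedding predictions, which `compute_intr_reward` then uses as the exploration bonus. A quick shape check with dummy tensors, assuming the class above is importable from the DMC_image root (placeholder dimensions):

import torch
from agent.plan2explore import Disagreement  # assumption: DMC_image root on the path

obs_dim, action_dim, hidden_dim, pred_dim = 1224, 6, 1024, 1024   # placeholders
net = Disagreement(obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=pred_dim)

obs = torch.randn(128, obs_dim)
action = torch.randn(128, action_dim)
next_obs = torch.randn(128, pred_dim)

errors = net(obs, action, next_obs)        # (128, 5): per-model prediction errors (training loss)
intr = net.get_disagreement(obs, action)   # (128,): ensemble variance, the intrinsic reward
print(errors.shape, intr.shape)
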
/DMC_image/agent/random_dreamer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | import utils
5 | from collections import OrderedDict
6 | import numpy as np
7 |
8 | from agent.dreamer import DreamerAgent, WorldModel, stop_gradient
9 | import agent.dreamer_utils as common
10 |
11 | Module = nn.Module
12 |
13 | class RandomDreamerAgent(DreamerAgent):
14 | def __init__(self, **kwargs):
15 | super().__init__(**kwargs)
16 |
17 | def act(self, obs, meta, step, eval_mode, state):
18 | return torch.zeros(self.act_spec.shape).uniform_(-1.0, 1.0).numpy(), None
19 |
20 | def update(self, data, step):
21 | metrics = {}
22 | state, outputs, mets = self.wm.update(data, state=None)
23 | metrics.update(mets)
24 | return state, metrics
--------------------------------------------------------------------------------
/DMC_image/agent/rnd.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import utils
10 | from agent.dreamer import DreamerAgent, stop_gradient
11 | import agent.dreamer_utils as common
12 |
13 |
14 | class RND(nn.Module):
15 | def __init__(self,
16 | obs_dim,
17 | hidden_dim,
18 | rnd_rep_dim,
19 | encoder,
20 | aug,
21 | obs_shape,
22 | obs_type,
23 | clip_val=5.):
24 | super().__init__()
25 | self.clip_val = clip_val
26 | self.aug = aug
27 |
28 | if obs_type == "pixels":
29 | self.normalize_obs = nn.BatchNorm2d(obs_shape[0], affine=False)
30 | else:
31 | self.normalize_obs = nn.BatchNorm1d(obs_shape[0], affine=False)
32 |
33 | self.predictor = nn.Sequential(encoder, nn.Linear(obs_dim, hidden_dim),
34 | nn.ReLU(),
35 | nn.Linear(hidden_dim, hidden_dim),
36 | nn.ReLU(),
37 | nn.Linear(hidden_dim, rnd_rep_dim))
38 | self.target = nn.Sequential(copy.deepcopy(encoder),
39 | nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
40 | nn.Linear(hidden_dim, hidden_dim),
41 | nn.ReLU(),
42 | nn.Linear(hidden_dim, rnd_rep_dim))
43 |
44 | for param in self.target.parameters():
45 | param.requires_grad = False
46 |
47 | self.apply(utils.weight_init)
48 |
49 | def forward(self, obs):
50 | if type(obs) == dict:
51 | img = obs['observation']
52 | img = self.aug(img)
53 | img = self.normalize_obs(img)
54 | img = torch.clamp(img, -self.clip_val, self.clip_val)
55 | obs['observation'] = img
56 | else:
57 | obs = self.aug(obs)
58 | obs = self.normalize_obs(obs)
59 | obs = torch.clamp(obs, -self.clip_val, self.clip_val)
60 | prediction, target = self.predictor(obs), self.target(obs)
61 | prediction_error = torch.square(target.detach() - prediction).mean(
62 | dim=-1, keepdim=True)
63 | return prediction_error
64 |
65 |
66 | class RNDDreamerAgent(DreamerAgent):
67 | def __init__(self, rnd_rep_dim, rnd_scale, **kwargs):
68 | super().__init__(**kwargs)
69 |
70 | self.reward_free = True
71 | self.rnd_scale = rnd_scale
72 |
73 | self.obs_dim = self.wm.embed_dim
74 | self.hidden_dim = self.wm.embed_dim
75 | self.aug = nn.Identity()
76 | self.obs_shape = (3,64,64)
77 | self.obs_type = self.cfg.obs_type
78 |
79 | encoder = copy.deepcopy(self.wm.encoder)
80 |
81 | self.rnd = RND(self.obs_dim, self.hidden_dim, rnd_rep_dim,
82 | encoder, self.aug, self.obs_shape,
83 | self.obs_type).to(self.device)
84 | self.intrinsic_reward_rms = utils.RMS(device=self.device)
85 |
86 | # optimizers
87 | self.rnd_opt = common.Optimizer('rnd', self.rnd.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
88 |
89 | self.rnd.train()
90 | self.requires_grad_(requires_grad=False)
91 |
92 | def update_rnd(self, obs, step):
93 | metrics = dict()
94 |
95 | prediction_error = self.rnd(obs)
96 |
97 | loss = prediction_error.mean()
98 |
99 | metrics.update(self.rnd_opt(loss, self.rnd.parameters()))
100 |
101 | metrics['rnd_loss'] = loss.item()
102 |
103 | return metrics
104 |
105 | def compute_intr_reward(self, obs):
106 | prediction_error = self.rnd(obs)
107 | _, intr_reward_var = self.intrinsic_reward_rms(prediction_error)
108 | reward = self.rnd_scale * prediction_error / (
109 | torch.sqrt(intr_reward_var) + 1e-8)
110 | return reward
111 |
112 | def update(self, data, step):
113 | metrics = {}
114 | B, T, _ = data['action'].shape
115 | obs_shape = data['observation'].shape[2:]
116 |
117 | if self.reward_free:
118 | temp_data = self.wm.preprocess(data)
119 | temp_data['observation'] = temp_data['observation'].reshape(B*T, *obs_shape)
120 | with common.RequiresGrad(self.rnd):
121 | with torch.cuda.amp.autocast(enabled=self._use_amp):
122 | metrics.update(self.update_rnd(temp_data, step))
123 |
124 | with torch.no_grad():
125 | intr_reward = self.compute_intr_reward(temp_data).reshape(B, T, 1)
126 |
127 | data['reward'] = intr_reward
128 |
129 | state, outputs, mets = self.wm.update(data, state=None)
130 | metrics.update(mets)
131 | start = outputs['post']
132 | start = {k: stop_gradient(v) for k,v in start.items()}
133 | reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean #.mode()
134 | metrics.update(self._task_behavior.update(
135 | self.wm, start, data['is_terminal'], reward_fn))
136 | return state, metrics
--------------------------------------------------------------------------------
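
Note: RND rewards states where a trained predictor fails to match a frozen, randomly initialized target network; the class above routes observations through a copy of the Dreamer encoder and a BatchNorm normalizer. A bare-bones sketch of the same idea on flat observations, for intuition only (not the repo's module):

import torch
import torch.nn as nn

obs_dim, rep_dim = 24, 512
target = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, rep_dim))
predictor = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, rep_dim))
for p in target.parameters():
    p.requires_grad = False                # the target network stays random forever

opt = torch.optim.Adam(predictor.parameters(), lr=1e-4)

obs = torch.randn(64, obs_dim)
novelty = (target(obs).detach() - predictor(obs)).pow(2).mean(dim=-1, keepdim=True)
loss = novelty.mean()                      # train the predictor to shrink the error
opt.zero_grad()
loss.backward()
opt.step()
# Before normalization, `novelty` is the per-state intrinsic reward.
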
/DMC_image/configs/agent/aps_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.aps.APSDreamerAgent
3 | name: aps_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | # Note: it's important to keep momentum = 1.00, otherwise SF won't work
9 | reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8}
10 | actor_ent: 0.0
11 |
12 | skill_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} # {momentum: 0.95, scale: 1.0, eps: 1e-8}
13 | skill_actor_ent: 0.0
14 |
15 | skill_dim: 5
16 | update_skill_every_step: 50
17 |
18 | knn_rms: true
19 | knn_k: 12
20 | knn_avg: true
21 | knn_clip: 0.0001
22 | num_init_frames: 4000 # set to ${num_train_frames} to disable finetune policy parameters
23 | lstsq_batch_size: 4096
--------------------------------------------------------------------------------
/DMC_image/configs/agent/apt_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.apt_dreamer.APTDreamerAgent
3 | name: apt_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0
10 |
11 | knn_rms: false
12 | knn_k: 12
13 | knn_avg: true
14 | knn_clip: 0.0
--------------------------------------------------------------------------------
/DMC_image/configs/agent/choreo.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.choreo.ChoreoAgent
3 | name: choreo
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 |
9 | # Exploration
10 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
11 | actor_ent: 0
12 |
13 | # Skills
14 | skill_dim: 64
15 | skill_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8}
16 | skill_actor_ent: 0
17 | code_dim: 16
18 | code_resampling: True
19 | resample_every: 200
20 |
21 | # Adaptation
22 | num_init_frames: 4000
23 | update_skill_every_step: 125
24 | freeze_skills: False
25 |
26 | # PBE
27 | knn_rms: false
28 | knn_k: 30
29 | knn_avg: true
30 | knn_clip: 0.0001
31 |
32 | task_number: 1
--------------------------------------------------------------------------------
/DMC_image/configs/agent/cic_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.cic.CICAgent
3 | name: cic_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | # Note: it's important to keep momentum = 1.00, otherwise SF won't work
9 | reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8}
10 | actor_ent: 0.0
11 |
12 | skill_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8} # {momentum: 0.95, scale: 1.0, eps: 1e-8}
13 | skill_actor_ent: 0.0
14 |
15 | skill_dim: 5
16 | update_skill_every_step: 50
17 |
18 | #knn_rms: true
19 | #knn_k: 12
20 | #knn_avg: true
21 | #knn_clip: 0.0001
22 | knn_k: 16
23 | knn_rms: true
24 | knn_avg: true
25 | knn_clip: 0.0005
26 | num_init_frames: 4000 # set to ${num_train_frames} to disable finetune policy parameters
27 | lstsq_batch_size: 4096
28 | project_skill: True
29 | temp: 0.5
--------------------------------------------------------------------------------
/DMC_image/configs/agent/diayn_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.diayn.DIAYNDreamerAgent
3 | name: diayn_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0.1
10 |
11 | diayn_scale: 1.0
12 | update_skill_every_step: 50
13 |
14 | skill_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
15 | skill_actor_ent: 0.1
16 | skill_dim: 50
17 |
18 | num_init_frames: 4000
19 |
20 | reward_type: 0
21 |
--------------------------------------------------------------------------------
/DMC_image/configs/agent/dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.dreamer.DreamerAgent
3 | name: dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
9 | actor_ent: 1e-4
--------------------------------------------------------------------------------
/DMC_image/configs/agent/icm_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.icm.ICMDreamerAgent
3 | name: icm_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0
10 | icm_scale: 1.0
--------------------------------------------------------------------------------
/DMC_image/configs/agent/lbs_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.lbs.LBSDreamerAgent
3 | name: lbs_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0
10 |
11 | reward_type: 0
--------------------------------------------------------------------------------
/DMC_image/configs/agent/lsd.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.lsd.LSDDreamerAgent
3 | name: lsd_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0.1
10 |
11 | lsd_scale: 1.0
12 | update_skill_every_step: 50
13 |
14 | skill_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
15 | skill_actor_ent: 0.1
16 | skill_dim: 50
17 |
18 | num_init_frames: 4000
19 |
--------------------------------------------------------------------------------
/DMC_image/configs/agent/peac_diayn.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.peac_diayn.PEACDIAYNAgent
3 | name: peac_diayn
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
9 | actor_ent: 1e-4
10 |
11 | diayn_scale: 1.0
12 | update_skill_every_step: 50
13 | task_scale: 1.0
14 |
15 | context_skill_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
16 | context_skill_actor_ent: 0.1
17 | skill_dim: 50
18 | diayn_hidden: 50
19 |
20 | num_init_frames: 4000
21 |
22 | freeze_skills: False
23 |
24 | task_number: 1
--------------------------------------------------------------------------------
/DMC_image/configs/agent/peac_lbs.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.peac_lbs.PEAC_LBSAgent
3 | name: peac_lbs
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | #reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
9 | #actor_ent: 1e-4
10 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
11 | actor_ent: 0
12 |
13 | context_reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
14 | #context_actor_ent: 0.1
15 | context_actor_ent: 1e-4
16 |
17 | num_init_frames: 4000
18 |
19 | task_number: 1
--------------------------------------------------------------------------------
/DMC_image/configs/agent/plan2explore.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.plan2explore.Plan2Explore
3 | name: plan2explore
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0
--------------------------------------------------------------------------------
/DMC_image/configs/agent/random_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.random_dreamer.RandomDreamerAgent
3 | name: random_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
9 | actor_ent: 1e-4
--------------------------------------------------------------------------------
/DMC_image/configs/agent/rnd_dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.rnd.RNDDreamerAgent
3 | name: rnd_dreamer
4 | cfg: ???
5 | obs_space: ???
6 | act_spec: ???
7 | grad_heads: [decoder]
8 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
9 | actor_ent: 0
10 |
11 | rnd_scale: 1.0
12 | rnd_rep_dim: 512
--------------------------------------------------------------------------------
/DMC_image/configs/dmc_pixels.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | obs_type: pixels
3 | action_repeat: 2
4 | encoder: {mlp_keys: '$^', cnn_keys: 'observation', norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} # act: elu
5 | decoder: {mlp_keys: '$^', cnn_keys: 'observation', norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} # act: elu
6 | replay.capacity: 1e6
--------------------------------------------------------------------------------
/DMC_image/configs/dmc_states.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | obs_type: states
3 | action_repeat: 1
4 | encoder: {mlp_keys: 'observation', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]}
5 | decoder: {mlp_keys: 'observation', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]}
6 | replay.capacity: 2e6
--------------------------------------------------------------------------------
/DMC_image/configs/dreamer.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Dreamer defaults
4 | pred_discount: False
5 | rssm: {ensemble: 1, hidden: 200, deter: 200, stoch: 32, discrete: 32, norm: none, std_act: sigmoid2, min_std: 0.1} # act: elu,
6 | reward_head: {layers: 4, units: 400, norm: none, dist: mse} # act: elu
7 | # we add task head here
8 | task_head: {layers: 4, units: 200, norm: none, dist: mse} # act: elu
9 | kl: {free: 1.0, forward: False, balance: 0.8, free_avg: True}
10 | loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0}
11 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6}
12 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False}
13 |
14 | actor: {layers: 4, units: 400, norm: none, dist: trunc_normal, min_std: 0.1 } # act: elu
15 | critic: {layers: 4, units: 400, norm: none, dist: mse} # act: elu,
16 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
17 | critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
18 | discount: 0.99
19 | discount_lambda: 0.95
20 | actor_grad: dynamics
21 | slow_target: True
22 | slow_target_update: 100
23 | slow_target_fraction: 1
24 | slow_baseline: True
25 |
26 | clip_rewards: identity
27 |
28 | batch_size: 50
29 | batch_length: 50
30 | imag_horizon: 15
31 | eval_state_mean: False
32 |
33 | precision: 16
34 | train_every_actions: 10
35 | #
--------------------------------------------------------------------------------
/DMC_image/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from custom_dmc_tasks import walker
2 | from custom_dmc_tasks import quadruped
3 | from custom_dmc_tasks import jaco
4 |
5 |
6 | def make(domain, task,
7 | task_kwargs=None,
8 | environment_kwargs=None,
9 | visualize_reward=False,
10 | mass=1.0):
11 |
12 | if domain == 'walker':
13 | return walker.make(task,
14 | task_kwargs=task_kwargs,
15 | environment_kwargs=environment_kwargs,
16 | visualize_reward=visualize_reward)
17 | elif domain == 'quadruped':
18 | return quadruped.make(task,
19 | task_kwargs=task_kwargs,
20 | environment_kwargs=environment_kwargs,
21 | visualize_reward=visualize_reward)
22 | else:
23 | raise ValueError(f'{domain} not found')
24 |
25 |
26 | def make_jaco(task, obs_type, seed, img_size,):
27 | return jaco.make(task, obs_type, seed, img_size,)
--------------------------------------------------------------------------------
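
Note: the factory functions above follow dm_control conventions. A usage sketch, assuming dm_control is installed and the package is imported from the DMC_image root (the `~mass~` parameterization seen in `dmc_benchmark.py` is handled elsewhere, e.g. in `envs.py`, which is outside this section):

from custom_dmc_tasks import make, make_jaco

env = make('walker', 'flip')     # dm_control Environment for the custom flip task
timestep = env.reset()
timestep = env.step(env.action_spec().generate_value())

jaco_env = make_jaco('reach_top_left', obs_type='pixels', seed=1, img_size=64)
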
/DMC_image/custom_dmc_tasks/jaco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """A task where the goal is to move the hand close to a target prop or site."""
17 |
18 | import collections
19 |
20 | from dm_control import composer
21 | from dm_control.composer import initializers
22 | from dm_control.composer.observation import observable
23 | from dm_control.composer.variation import distributions
24 | from dm_control.entities import props
25 | from dm_control.manipulation.shared import arenas
26 | from dm_control.manipulation.shared import cameras
27 | from dm_control.manipulation.shared import constants
28 | from dm_control.manipulation.shared import observations
29 | from dm_control.manipulation.shared import registry
30 | from dm_control.manipulation.shared import robots
31 | from dm_control.manipulation.shared import tags
32 | from dm_control.manipulation.shared import workspaces
33 | from dm_control.utils import rewards
34 | import numpy as np
35 |
36 | _ReachWorkspace = collections.namedtuple(
37 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
38 |
39 | # Ensures that the props are not touching the table before settling.
40 | _PROP_Z_OFFSET = 0.001
41 |
42 | _DUPLO_WORKSPACE = _ReachWorkspace(
43 | target_bbox=workspaces.BoundingBox(
44 | lower=(-0.1, -0.1, _PROP_Z_OFFSET),
45 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
46 | tcp_bbox=workspaces.BoundingBox(
47 | lower=(-0.1, -0.1, 0.2),
48 | upper=(0.1, 0.1, 0.4)),
49 | arm_offset=robots.ARM_OFFSET)
50 |
51 | _SITE_WORKSPACE = _ReachWorkspace(
52 | target_bbox=workspaces.BoundingBox(
53 | lower=(-0.2, -0.2, 0.02),
54 | upper=(0.2, 0.2, 0.4)),
55 | tcp_bbox=workspaces.BoundingBox(
56 | lower=(-0.2, -0.2, 0.02),
57 | upper=(0.2, 0.2, 0.4)),
58 | arm_offset=robots.ARM_OFFSET)
59 |
60 | _TARGET_RADIUS = 0.05
61 | _TIME_LIMIT = 10
62 |
63 | TASKS = {
64 | 'reach_top_left': workspaces.BoundingBox(
65 | lower=(-0.09, 0.09, _PROP_Z_OFFSET),
66 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)),
67 | 'reach_top_right': workspaces.BoundingBox(
68 | lower=(0.09, 0.09, _PROP_Z_OFFSET),
69 | upper=(0.09, 0.09, _PROP_Z_OFFSET)),
70 | 'reach_bottom_left': workspaces.BoundingBox(
71 | lower=(-0.09, -0.09, _PROP_Z_OFFSET),
72 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)),
73 | 'reach_bottom_right': workspaces.BoundingBox(
74 | lower=(0.09, -0.09, _PROP_Z_OFFSET),
75 | upper=(0.09, -0.09, _PROP_Z_OFFSET)),
76 | }
77 |
78 |
79 | def make(task_id, obs_type, seed, img_size=84, ):
80 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
81 | obs_settings = obs_settings._replace(camera=obs_settings[-1]._replace(width=img_size))
82 | obs_settings = obs_settings._replace(camera=obs_settings[-1]._replace(height=img_size))
83 | if obs_type == 'states':
84 | global _TIME_LIMIT
85 | _TIME_LIMIT = 10.04
86 | # Note: Adding this fixes the problem of having 249 steps with action repeat = 1
87 | task = _reach(task_id, obs_settings=obs_settings, use_site=False)
88 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed)
89 |
90 |
91 | class MTReach(composer.Task):
92 | """Bring the hand close to a target prop or site."""
93 |
94 | def __init__(
95 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
96 | """Initializes a new `Reach` task.
97 |
98 | Args:
99 | arena: `composer.Entity` instance.
100 | arm: `robot_base.RobotArm` instance.
101 | hand: `robot_base.RobotHand` instance.
102 | prop: `composer.Entity` instance specifying the prop to reach to, or None
103 | in which case the target is a fixed site whose position is specified by
104 | the workspace.
105 | obs_settings: `observations.ObservationSettings` instance.
106 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
107 | control_timestep: Float specifying the control timestep in seconds.
108 | """
109 | self._task_id = task_id
110 | self._arena = arena
111 | self._arm = arm
112 | self._hand = hand
113 | self._arm.attach(self._hand)
114 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
115 | self.control_timestep = control_timestep
116 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
117 | self._hand, self._arm,
118 | position=distributions.Uniform(*workspace.tcp_bbox),
119 | quaternion=workspaces.DOWN_QUATERNION)
120 |
121 | # Add custom camera observable.
122 | self._task_observables = cameras.add_camera_observables(
123 | arena, obs_settings, cameras.FRONT_CLOSE)
124 |
125 | target_pos_distribution = distributions.Uniform(*TASKS[task_id])
126 | self._prop = prop
127 | if prop:
128 | # The prop itself is used to visualize the target location.
129 | self._make_target_site(parent_entity=prop, visible=False)
130 | self._target = self._arena.add_free_entity(prop)
131 | self._prop_placer = initializers.PropPlacer(
132 | props=[prop],
133 | position=target_pos_distribution,
134 | quaternion=workspaces.uniform_z_rotation,
135 | settle_physics=True)
136 | else:
137 | self._target = self._make_target_site(parent_entity=arena, visible=True)
138 | self._target_placer = target_pos_distribution
139 |
140 | # Add sites for visualizing the prop and target bounding boxes.
141 | workspaces.add_bbox_site(
142 | body=self.root_entity.mjcf_model.worldbody,
143 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
144 | rgba=constants.GREEN, name='tcp_spawn_area')
145 | workspaces.add_bbox_site(
146 | body=self.root_entity.mjcf_model.worldbody,
147 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
148 | rgba=constants.BLUE, name='target_spawn_area')
149 |
150 | def _make_target_site(self, parent_entity, visible):
151 | return workspaces.add_target_site(
152 | body=parent_entity.mjcf_model.worldbody,
153 | radius=_TARGET_RADIUS, visible=visible,
154 | rgba=constants.RED, name='target_site')
155 |
156 | @property
157 | def root_entity(self):
158 | return self._arena
159 |
160 | @property
161 | def arm(self):
162 | return self._arm
163 |
164 | @property
165 | def hand(self):
166 | return self._hand
167 |
168 | @property
169 | def task_observables(self):
170 | return self._task_observables
171 |
172 | def get_reward(self, physics):
173 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
174 | target_pos = physics.bind(self._target).xpos
175 | distance = np.linalg.norm(hand_pos - target_pos)
176 | return rewards.tolerance(
177 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)
178 |
179 | def initialize_episode(self, physics, random_state):
180 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
181 | self._tcp_initializer(physics, random_state)
182 | if self._prop:
183 | self._prop_placer(physics, random_state)
184 | else:
185 | physics.bind(self._target).pos = (
186 | self._target_placer(random_state=random_state))
187 |
188 |
189 | def _reach(task_id, obs_settings, use_site):
190 | """Configure and instantiate a `Reach` task.
191 |
192 | Args:
193 | obs_settings: An `observations.ObservationSettings` instance.
194 | use_site: Boolean, if True then the target will be a fixed site, otherwise
195 | it will be a moveable Duplo brick.
196 |
197 | Returns:
198 | An instance of `reach.Reach`.
199 | """
200 | arena = arenas.Standard()
201 | arm = robots.make_arm(obs_settings=obs_settings)
202 | hand = robots.make_hand(obs_settings=obs_settings)
203 | if use_site:
204 | workspace = _SITE_WORKSPACE
205 | prop = None
206 | else:
207 | workspace = _DUPLO_WORKSPACE
208 | prop = props.Duplo(observable_options=observations.make_options(
209 | obs_settings, observations.FREEPROP_OBSERVABLES))
210 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop,
211 | obs_settings=obs_settings,
212 | workspace=workspace,
213 | control_timestep=constants.CONTROL_TIMESTEP)
214 | return task
215 |
--------------------------------------------------------------------------------
/DMC_image/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Planar Walker Domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | from dm_control import suite
33 |
34 | _DEFAULT_TIME_LIMIT = 25
35 | _CONTROL_TIMESTEP = .025
36 |
37 | # Minimal height of torso over foot above which stand reward is 1.
38 | _STAND_HEIGHT = 1.2
39 |
40 | # Horizontal speeds (meters/second) above which move reward is 1.
41 | _WALK_SPEED = 1
42 | _RUN_SPEED = 8
43 | _SPIN_SPEED = 5
44 |
45 | SUITE = containers.TaggedTasks()
46 |
47 |
48 | def make(task,
49 | task_kwargs=None,
50 | environment_kwargs=None,
51 | visualize_reward=False):
52 | task_kwargs = task_kwargs or {}
53 | if environment_kwargs is not None:
54 | task_kwargs = task_kwargs.copy()
55 | task_kwargs['environment_kwargs'] = environment_kwargs
56 | env = SUITE[task](**task_kwargs)
57 | env.task.visualize_reward = visualize_reward
58 | return env
59 |
60 |
61 | def get_model_and_assets():
62 | """Returns a tuple containing the model XML string and a dict of assets."""
63 | root_dir = os.path.dirname(os.path.dirname(__file__))
64 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks',
65 | 'walker.xml'))
66 | return xml, common.ASSETS
67 |
68 |
69 | @SUITE.add('benchmarking')
70 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
71 | random=None,
72 | environment_kwargs=None):
73 | """Returns the Run task."""
74 | physics = Physics.from_xml_string(*get_model_and_assets())
75 | task = PlanarWalker(move_speed=_RUN_SPEED,
76 | forward=True,
77 | flip=True,
78 | random=random)
79 | environment_kwargs = environment_kwargs or {}
80 | return control.Environment(physics,
81 | task,
82 | time_limit=time_limit,
83 | control_timestep=_CONTROL_TIMESTEP,
84 | **environment_kwargs)
85 |
86 |
87 | class Physics(mujoco.Physics):
88 | """Physics simulation with additional features for the Walker domain."""
89 | def torso_upright(self):
90 | """Returns projection from z-axes of torso to the z-axes of world."""
91 | return self.named.data.xmat['torso', 'zz']
92 |
93 | def torso_height(self):
94 | """Returns the height of the torso."""
95 | return self.named.data.xpos['torso', 'z']
96 |
97 | def horizontal_velocity(self):
98 | """Returns the horizontal velocity of the center-of-mass."""
99 | return self.named.data.sensordata['torso_subtreelinvel'][0]
100 |
101 | def orientations(self):
102 | """Returns planar orientations of all bodies."""
103 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
104 |
105 | def angmomentum(self):
106 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
107 | return self.named.data.subtree_angmom['torso'][1]
108 |
109 |
110 | class PlanarWalker(base.Task):
111 | """A planar walker task."""
112 | def __init__(self, move_speed, forward=True, flip=False, random=None):
113 | """Initializes an instance of `PlanarWalker`.
114 |
115 | Args:
116 | move_speed: A float. If this value is zero, reward is given simply for
117 | standing up. Otherwise this specifies a target horizontal velocity for
118 | the walking task.
119 | random: Optional, either a `numpy.random.RandomState` instance, an
120 | integer seed for creating a new `RandomState`, or None to select a seed
121 | automatically (default).
122 | """
123 | self._move_speed = move_speed
124 | self._forward = 1 if forward else -1
125 | self._flip = flip
126 | super(PlanarWalker, self).__init__(random=random)
127 |
128 | def initialize_episode(self, physics):
129 | """Sets the state of the environment at the start of each episode.
130 |
131 | In 'standing' mode, use initial orientation and small velocities.
132 | In 'random' mode, randomize joint angles and let fall to the floor.
133 |
134 | Args:
135 | physics: An instance of `Physics`.
136 |
137 | """
138 | randomizers.randomize_limited_and_rotational_joints(
139 | physics, self.random)
140 | super(PlanarWalker, self).initialize_episode(physics)
141 |
142 | def get_observation(self, physics):
143 | """Returns an observation of body orientations, height and velocites."""
144 | obs = collections.OrderedDict()
145 | obs['orientations'] = physics.orientations()
146 | obs['height'] = physics.torso_height()
147 | obs['velocity'] = physics.velocity()
148 | return obs
149 |
150 | def get_reward(self, physics):
151 | """Returns a reward to the agent."""
152 | standing = rewards.tolerance(physics.torso_height(),
153 | bounds=(_STAND_HEIGHT, float('inf')),
154 | margin=_STAND_HEIGHT / 2)
155 | upright = (1 + physics.torso_upright()) / 2
156 | stand_reward = (3 * standing + upright) / 4
157 |
158 | if self._flip:
159 | move_reward = rewards.tolerance(self._forward *
160 | physics.angmomentum(),
161 | bounds=(_SPIN_SPEED, float('inf')),
162 | margin=_SPIN_SPEED,
163 | value_at_margin=0,
164 | sigmoid='linear')
165 | else:
166 | move_reward = rewards.tolerance(
167 | self._forward * physics.horizontal_velocity(),
168 | bounds=(self._move_speed, float('inf')),
169 | margin=self._move_speed / 2,
170 | value_at_margin=0.5,
171 | sigmoid='linear')
172 |
173 | return stand_reward * (5 * move_reward + 1) / 6
174 |
--------------------------------------------------------------------------------
/DMC_image/custom_dmc_tasks/walker.xml:
--------------------------------------------------------------------------------
[walker.xml: MuJoCo model markup not preserved in this export]
--------------------------------------------------------------------------------
/DMC_image/dmc_benchmark.py:
--------------------------------------------------------------------------------
1 | DOMAINS = [
2 | 'walker',
3 | 'quadruped',
4 | 'jaco',
5 | ]
6 |
7 | WALKER_TASKS = [
8 | 'walker_stand',
9 | 'walker_walk',
10 | 'walker_run',
11 | 'walker_flip',
12 | ]
13 |
14 | QUADRUPED_TASKS = [
15 | 'quadruped_walk',
16 | 'quadruped_run',
17 | 'quadruped_stand',
18 | 'quadruped_jump',
19 | ]
20 |
21 | JACO_TASKS = [
22 | 'jaco_reach_top_left',
23 | 'jaco_reach_top_right',
24 | 'jaco_reach_bottom_left',
25 | 'jaco_reach_bottom_right',
26 | ]
27 |
28 | TASKS = WALKER_TASKS + QUADRUPED_TASKS + JACO_TASKS
29 |
30 | parameter_1 = ['0.4', '0.8', '1.0', '1.4']
31 | parameter_1_eval = ['0.6', '1.2']
32 |
33 | parameter_2 = ['0.2', '0.6', '1.0', '1.4', '1.8']
34 | parameter_2_eval = ['0.4', '0.8', '1.2', '1.6']
35 |
36 | PRETRAIN_TASKS = {
37 | 'walker': 'walker_stand',
38 | 'jaco': 'jaco_reach_top_left',
39 | 'quadruped': 'quadruped_walk',
40 | 'walker_mass': ['walker_stand~mass~' + para for para in parameter_2],
41 | 'quadruped_mass': ['quadruped_stand~mass~' + para for para in parameter_1],
42 | 'quadruped_damping': ['quadruped_stand~damping~' + para for para in parameter_2],
43 | }
44 |
45 | FINETUNE_TASKS = {
46 | 'walker_stand_mass': ['walker_stand~mass~' + para for para in parameter_2],
47 | 'walker_stand_mass_eval': ['walker_stand~mass~' + para for para in parameter_2_eval],
48 | 'walker_walk_mass': ['walker_walk~mass~' + para for para in parameter_2],
49 | 'walker_walk_mass_eval': ['walker_walk~mass~' + para for para in parameter_2_eval],
50 | 'walker_run_mass': ['walker_run~mass~' + para for para in parameter_2],
51 | 'walker_run_mass_eval': ['walker_run~mass~' + para for para in parameter_2_eval],
52 | 'walker_flip_mass': ['walker_flip~mass~' + para for para in parameter_2],
53 | 'walker_flip_mass_eval': ['walker_flip~mass~' + para for para in parameter_2_eval],
54 |
55 | 'quadruped_stand_mass': ['quadruped_stand~mass~' + para for para in parameter_1],
56 | 'quadruped_stand_mass_eval': ['quadruped_stand~mass~' + para for para in parameter_1_eval],
57 | 'quadruped_walk_mass': ['quadruped_walk~mass~' + para for para in parameter_1],
58 | 'quadruped_walk_mass_eval': ['quadruped_walk~mass~' + para for para in parameter_1_eval],
59 | 'quadruped_run_mass': ['quadruped_run~mass~' + para for para in parameter_1],
60 | 'quadruped_run_mass_eval': ['quadruped_run~mass~' + para for para in parameter_1_eval],
61 | 'quadruped_jump_mass': ['quadruped_jump~mass~' + para for para in parameter_1],
62 | 'quadruped_jump_mass_eval': ['quadruped_jump~mass~' + para for para in parameter_1_eval],
63 |
64 | 'quadruped_stand_damping': ['quadruped_stand~damping~' + para for para in parameter_2],
65 | 'quadruped_stand_damping_eval': ['quadruped_stand~damping~' + para for para in parameter_2_eval],
66 | 'quadruped_walk_damping': ['quadruped_walk~damping~' + para for para in parameter_2],
67 | 'quadruped_walk_damping_eval': ['quadruped_walk~damping~' + para for para in parameter_2_eval],
68 | 'quadruped_run_damping': ['quadruped_run~damping~' + para for para in parameter_2],
69 | 'quadruped_run_damping_eval': ['quadruped_run~damping~' + para for para in parameter_2_eval],
70 | 'quadruped_jump_damping': ['quadruped_jump~damping~' + para for para in parameter_2],
71 | 'quadruped_jump_damping_eval': ['quadruped_jump~damping~' + para for para in parameter_2_eval],
72 | }
73 |
--------------------------------------------------------------------------------
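
Note: embodiment variations are encoded directly in the task string as `<domain>_<task>~<parameter>~<scale>`; splitting on `~` recovers the pieces (the actual parsing happens in the environment-construction code, e.g. `envs.py`, which is outside this section):

name = 'quadruped_run~damping~1.8'
base_task, param, value = name.split('~')
print(base_task, param, float(value))   # quadruped_run damping 1.8
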
/DMC_image/dreamer_finetune.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - configs/dreamer
3 | - configs/agent: dreamer
4 | - configs: ${configs}
5 | - override hydra/launcher: submitit_local
6 |
7 | # mode
8 | reward_free: false
9 | # task settings
10 | #task: walker_stand
11 | task: none
12 | domain: walker_mass
13 | finetune_domain: walker_stand_mass
14 | # train settings
15 | num_train_frames: 100010
16 | num_seed_frames: 4000
17 | # eval
18 | eval_every_frames: 10000
19 | num_eval_episodes: 10
20 | # pretrained
21 | snapshot_ts: 100000
22 | snapshot_base_dir: ./pretrained_models
23 | custom_snap_dir: none
24 | # replay buffer
25 | replay_buffer_size: 1000000
26 | replay_buffer_num_workers: 4
27 | # misc
28 | seed: 1
29 | device: cuda
30 | save_video: false
31 | save_train_video: false
32 | save_eval_episodes: false
33 | use_tb: true
34 | use_wandb: true
35 | # experiment
36 | experiment: ft
37 | project_name: ???
38 |
39 | # log settings
40 | log_every_frames: 1000
41 | recon_every_frames: 100000000 # edit for debug
42 |
43 | # planning
44 | mpc: false
45 | mpc_opt: { iterations : 12, num_samples : 512, num_elites : 64, mixture_coef : 0.05, min_std : 0.1, temperature : 0.5, momentum : 0.1, horizon : 5, use_value: true }
46 |
47 | # Pretrained network reuse
48 | init_critic: false
49 | init_actor: true
50 | init_task: 1.0
51 |
52 | # Fine-tuning ablation
53 | # we have saved the last model
54 | # save_ft_model: true
55 | save_ft_model: false
56 |
57 | # Dreamer FT
58 | grad_heads: [decoder, reward]
59 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
60 | actor_ent: 1e-4
61 |
62 | hydra:
63 | run:
64 | dir: ./exp_local/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}
65 | sweep:
66 | dir: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment}
67 | subdir: ${hydra.job.num}
68 | launcher:
69 | timeout_min: 4300
70 | cpus_per_task: 10
71 | gpus_per_node: 1
72 | tasks_per_node: 1
73 | mem_gb: 160
74 | nodes: 1
75 | submitit_folder: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment}/.slurm
--------------------------------------------------------------------------------
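
Note: the `mpc_opt` block above only takes effect when `mpc: true`; its entries are the usual knobs of a sampling-based planner (population size, elite fraction, iteration count, horizon, and so on) and are consumed by the agent's planning code rather than by Hydra itself.
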
/DMC_image/dreamer_pretrain.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - configs/dreamer
3 | - configs/agent: dreamer
4 | - configs: ${configs}
5 | - override hydra/launcher: submitit_local
6 |
7 | # mode
8 | reward_free: true
9 | # task settings
10 | task: none
11 | domain: walker # the primal task is inferred at runtime
12 | # train settings
13 | num_train_frames: 2000010
14 | num_seed_frames: 4000
15 | # eval
16 | eval_every_frames: 10000
17 | num_eval_episodes: 10
18 | # snapshot
19 | snapshots: [100000, 500000, 1000000, 2000000]
20 | snapshot_dir: ../../../../../pretrained_models/${obs_type}/${domain}/${agent.name}/${seed}
21 | # replay buffer
22 | replay_buffer_size: 1000000
23 | replay_buffer_num_workers: 4
24 | # misc
25 | seed: 1
26 | device: cuda
27 | save_video: false
28 | save_train_video: true
29 | use_tb: true
30 | use_wandb: true
31 |
32 | # experiment
33 | experiment: pt
34 | project_name: ???
35 |
36 | # log settings
37 | log_every_frames: 1000
38 | recon_every_frames: 100000000 # effectively disabled at this scale; lower it when debugging reconstructions
39 |
40 |
41 | hydra:
42 | run:
43 | dir: ./exp_local/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M%S}_${seed}
44 | sweep:
45 | dir: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M}_${seed}_${experiment}
46 | subdir: ${hydra.job.num}
47 | launcher:
48 | timeout_min: 4300
49 | cpus_per_task: 10
50 | gpus_per_node: 1
51 | tasks_per_node: 1
52 | mem_gb: 160
53 | nodes: 1
54 |     submitit_folder: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M}_${seed}_${experiment}/.slurm
55 |
--------------------------------------------------------------------------------
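
Note: the `snapshot_dir` above climbs five directories because Hydra switches the working directory to the run dir under `exp_local/`. A minimal sketch of how the `${...}` interpolations resolve, using OmegaConf (Hydra's config backend) with hypothetical values for `obs_type`, `domain`, `agent.name`, and `seed`; the `${now:...}` resolver is Hydra-specific and omitted here.

from omegaconf import OmegaConf

# Hypothetical values standing in for the composed Hydra config.
cfg = OmegaConf.create({
    "obs_type": "pixels",
    "domain": "walker_mass",
    "seed": 1,
    "agent": {"name": "peac_lbs"},
    "snapshot_dir": "../../../../../pretrained_models/${obs_type}/${domain}/${agent.name}/${seed}",
})

# Interpolations are resolved on access.
print(cfg.snapshot_dir)
# ../../../../../pretrained_models/pixels/walker_mass/peac_lbs/1
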
/DMC_image/dreamer_replay.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings('ignore')
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import collections
6 | import datetime
7 | import io
8 | import pathlib
9 | import uuid
10 |
11 | import numpy as np
12 | import random
13 | from torch.utils.data import IterableDataset, DataLoader
14 | import torch
15 | import utils
16 |
17 |
18 | class ReplayBuffer(IterableDataset):
19 |
20 | def __init__(
21 | self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0,
22 | prioritize_ends=False, device='cpu', load_first=False):
23 | self._directory = pathlib.Path(directory).expanduser()
24 | self._directory.mkdir(parents=True, exist_ok=True)
25 | self._capacity = capacity
26 | self._ongoing = ongoing
27 | self._minlen = minlen
28 | self._maxlen = maxlen
29 | self._prioritize_ends = prioritize_ends
30 | self._random = np.random.RandomState()
31 | # filename -> key -> value_sequence
32 | self._complete_eps = load_episodes(self._directory, capacity, minlen, load_first=load_first)
33 | # worker -> key -> value_sequence
34 | self._ongoing_eps = collections.defaultdict(
35 | lambda: collections.defaultdict(list))
36 | self._total_episodes, self._total_steps = count_episodes(directory)
37 | self._loaded_episodes = len(self._complete_eps)
38 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values())
39 | self._length = length
40 | self._data_specs = data_specs
41 | self._meta_specs = meta_specs
42 | self.device = device
43 | try:
44 | assert self._minlen <= self._length <= self._maxlen
45 |         except AssertionError:
46 |             print(f"Inconsistent min/max/length settings in the replay buffer. Defaulting all of them to length={length}.")
47 | self._minlen = self._maxlen = self._length = length
48 |
49 | def __len__(self):
50 | return self._total_steps
51 |
52 | @property
53 | def stats(self):
54 | return {
55 | 'total_steps': self._total_steps,
56 | 'total_episodes': self._total_episodes,
57 | 'loaded_steps': self._loaded_steps,
58 | 'loaded_episodes': self._loaded_episodes,
59 | }
60 |
61 | def add(self, time_step, meta, worker=0, env_id=None):
62 | episode = self._ongoing_eps[worker]
63 | for spec in self._data_specs:
64 | value = time_step[spec.name]
65 | if np.isscalar(value):
66 | value = np.full(spec.shape, value, spec.dtype)
67 | assert spec.shape == value.shape and spec.dtype == value.dtype
68 | episode[spec.name].append(value)
69 | for spec in self._meta_specs:
70 | value = meta[spec.name]
71 | if np.isscalar(value):
72 | value = np.full(spec.shape, value, spec.dtype)
73 | assert spec.shape == value.shape and spec.dtype == value.dtype
74 | episode[spec.name].append(value)
75 | if type(time_step) == dict:
76 | if time_step['is_last']:
77 | self.add_episode(episode, env_id=env_id)
78 | episode.clear()
79 | else:
80 |             if time_step.last():  # non-dict case: dm_env-style TimeStep
81 | self.add_episode(episode, env_id=env_id)
82 | episode.clear()
83 |
84 | def add_episode(self, episode, env_id=None):
85 | length = eplen(episode)
86 | if length < self._minlen:
87 | print(f'Skipping short episode of length {length}.')
88 | return
89 | self._total_steps += length
90 | self._loaded_steps += length
91 | self._total_episodes += 1
92 | self._loaded_episodes += 1
93 | episode = {key: convert(value) for key, value in episode.items()}
94 | filename = save_episode(self._directory, episode, env_id=env_id)
95 | self._complete_eps[str(filename)] = episode
96 | self._enforce_limit()
97 |
98 | def __iter__(self):
99 | sequence = self._sample_sequence()
100 | while True:
101 | chunk = collections.defaultdict(list)
102 | added = 0
103 | while added < self._length:
104 | needed = self._length - added
105 | adding = {k: v[:needed] for k, v in sequence.items()}
106 | sequence = {k: v[needed:] for k, v in sequence.items()}
107 | for key, value in adding.items():
108 | chunk[key].append(value)
109 | added += len(adding['action'])
110 | if len(sequence['action']) < 1:
111 | sequence = self._sample_sequence()
112 | chunk = {k: np.concatenate(v) for k, v in chunk.items()}
113 | chunk['is_terminal'] = chunk['discount'] == 0
114 | chunk = {k: torch.as_tensor(np.copy(v), device=self.device) for k, v in chunk.items()}
115 | yield chunk
116 |
117 | def _sample_sequence(self):
118 | episodes = list(self._complete_eps.values())
119 | if self._ongoing:
120 | episodes += [
121 | x for x in self._ongoing_eps.values()
122 | if eplen(x) >= self._minlen]
123 | episode = self._random.choice(episodes)
124 | total = len(episode['action'])
125 | length = total
126 | if self._maxlen:
127 | length = min(length, self._maxlen)
128 | # Randomize length to avoid all chunks ending at the same time in case the
129 | # episodes are all of the same length.
130 | length -= np.random.randint(self._minlen)
131 | length = max(self._minlen, length)
132 | upper = total - length + 1
133 | if self._prioritize_ends:
134 | upper += self._minlen
135 | index = min(self._random.randint(upper), total - length)
136 | sequence = {
137 | k: convert(v[index: index + length])
138 | for k, v in episode.items() if not k.startswith('log_')}
139 | # np.bool -> bool
140 | sequence['is_first'] = np.zeros(len(sequence['action']), bool)
141 | sequence['is_first'][0] = True
142 | if self._maxlen:
143 | assert self._minlen <= len(sequence['action']) <= self._maxlen
144 | return sequence
145 |
146 | def _enforce_limit(self):
147 | if not self._capacity:
148 | return
149 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity:
150 | # Relying on Python preserving the insertion order of dicts.
151 | oldest, episode = next(iter(self._complete_eps.items()))
152 | self._loaded_steps -= eplen(episode)
153 | self._loaded_episodes -= 1
154 | del self._complete_eps[oldest]
155 |
156 |
157 | def count_episodes(directory):
158 | filenames = list(directory.glob('*.npz'))
159 | num_episodes = len(filenames)
160 | if len(filenames) > 0 and "-" in str(filenames[0]):
161 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames)
162 | else:
163 | num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames)
164 | return num_episodes, num_steps
165 |
166 |
167 | @utils.retry
168 | def save_episode(directory, episode, env_id=None):
169 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
170 | identifier = str(uuid.uuid4().hex)
171 | length = eplen(episode)
172 | if env_id is None:
173 | filename = directory / f'{timestamp}-{identifier}-{length}.npz'
174 | else:
175 | filename = directory / f'{env_id}-{timestamp}-{identifier}-{length}.npz'
176 | with io.BytesIO() as f1:
177 | np.savez_compressed(f1, **episode)
178 | f1.seek(0)
179 | with filename.open('wb') as f2:
180 | f2.write(f1.read())
181 | return filename
182 |
183 |
184 | def load_episodes(directory, capacity=None, minlen=1, load_first=False):
185 |     # The returned dictionary from filenames to episodes is guaranteed to be
186 |     # in temporally sorted order.
187 | filenames = sorted(directory.glob('*.npz'))
188 | if capacity:
189 | num_steps = 0
190 | num_episodes = 0
191 | ordered_filenames = filenames if load_first else reversed(filenames)
192 | for filename in ordered_filenames:
193 | if "-" in str(filename):
194 | length = int(str(filename).split('-')[-1][:-4])
195 | else:
196 | length = int(str(filename).split('_')[-1][:-4])
197 | num_steps += length
198 | num_episodes += 1
199 | if num_steps >= capacity:
200 | break
201 | if load_first:
202 | filenames = filenames[:num_episodes]
203 | else:
204 | filenames = filenames[-num_episodes:]
205 | episodes = {}
206 | for filename in filenames:
207 | try:
208 | with filename.open('rb') as f:
209 | episode = np.load(f)
210 | episode = {k: episode[k] for k in episode.keys()}
211 | except Exception as e:
212 | print(f'Could not load episode {str(filename)}: {e}')
213 | continue
214 | episodes[str(filename)] = episode
215 | return episodes
216 |
217 |
218 | def convert(value):
219 | value = np.array(value)
220 | if np.issubdtype(value.dtype, np.floating):
221 | return value.astype(np.float32)
222 | elif np.issubdtype(value.dtype, np.signedinteger):
223 | return value.astype(np.int32)
224 | elif np.issubdtype(value.dtype, np.uint8):
225 | return value.astype(np.uint8)
226 | return value
227 |
228 |
229 | def eplen(episode):
230 | return len(episode['action']) - 1
231 |
232 |
233 | def _worker_init_fn(worker_id):
234 | seed = np.random.get_state()[1][0] + worker_id
235 | np.random.seed(seed)
236 | random.seed(seed)
237 |
238 |
239 | def make_replay_loader(buffer, batch_size, num_workers):
240 | return DataLoader(buffer,
241 | batch_size=batch_size,
242 | drop_last=True,
243 | )
244 |
--------------------------------------------------------------------------------
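
Note: a self-contained sketch of the on-disk episode format handled by `save_episode` and `load_episodes` above: one compressed `.npz` per episode, with the episode length (`len(action) - 1`) encoded at the end of the filename. Array shapes and the temporary directory are illustrative only.

import datetime
import pathlib
import tempfile
import uuid

import numpy as np

directory = pathlib.Path(tempfile.mkdtemp())

# A toy episode with 11 recorded steps, i.e. eplen = len(action) - 1 = 10.
episode = {
    'action': np.zeros((11, 6), np.float32),
    'reward': np.zeros((11, 1), np.float32),
    'discount': np.ones((11, 1), np.float32),
}
length = len(episode['action']) - 1

timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
identifier = uuid.uuid4().hex
filename = directory / f'{timestamp}-{identifier}-{length}.npz'
np.savez_compressed(filename, **episode)

# The loader recovers the length from the filename and the arrays from the file.
assert int(str(filename).split('-')[-1][:-4]) == 10
with filename.open('rb') as f:
    loaded = np.load(f)
    loaded = {k: loaded[k] for k in loaded.keys()}
assert loaded['action'].shape == (11, 6)
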
/DMC_image/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | import wandb
9 | from termcolor import colored
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
14 | ('episode_reward', 'R', 'float'),
15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time'),]
16 |
17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
19 | ('episode_reward', 'R', 'float'),
20 | ('total_time', 'T', 'time'),
21 | ('episode_train_reward', 'TR', 'float'),
22 | ('episode_eval_reward', 'ER', 'float')]
23 |
24 |
25 | class AverageMeter(object):
26 | def __init__(self):
27 | self._sum = 0
28 | self._count = 0
29 |
30 | def update(self, value, n=1):
31 | self._sum += value
32 | self._count += n
33 |
34 | def value(self):
35 | return self._sum / max(1, self._count)
36 |
37 |
38 | class MetersGroup(object):
39 | def __init__(self, csv_file_name, formating, use_wandb):
40 | self._csv_file_name = csv_file_name
41 | self._formating = formating
42 | self._meters = defaultdict(AverageMeter)
43 | self._csv_file = None
44 | self._csv_writer = None
45 | self.use_wandb = use_wandb
46 |
47 | def log(self, key, value, n=1):
48 | self._meters[key].update(value, n)
49 |
50 | def _prime_meters(self):
51 | data = dict()
52 | for key, meter in self._meters.items():
53 | if key.startswith('train'):
54 | key = key[len('train') + 1:]
55 | else:
56 | key = key[len('eval') + 1:]
57 | key = key.replace('/', '_')
58 | data[key] = meter.value()
59 | return data
60 |
61 | def _remove_old_entries(self, data):
62 | rows = []
63 | with self._csv_file_name.open('r') as f:
64 | reader = csv.DictReader(f)
65 | for row in reader:
66 | if 'episode' in row:
67 |                     # BUGFIX: covers cases where the CSV is badly written
68 | if row['episode'] == '':
69 | rows.append(row)
70 | continue
71 |                     if row['episode'] is None:
72 | continue
73 | if float(row['episode']) >= data['episode']:
74 | break
75 | rows.append(row)
76 | with self._csv_file_name.open('w') as f:
77 |             # To handle CSVs that have more keys than the new data
78 |             keys = set(data.keys())
79 |             if len(rows) > 0: keys = keys | set(rows[0].keys())
80 | keys = sorted(list(keys))
81 | #
82 | writer = csv.DictWriter(f,
83 | fieldnames=keys,
84 | restval=0.0)
85 | writer.writeheader()
86 | for row in rows:
87 | writer.writerow(row)
88 |
89 | def _dump_to_csv(self, data):
90 | if self._csv_writer is None:
91 | should_write_header = True
92 | if self._csv_file_name.exists():
93 | self._remove_old_entries(data)
94 | should_write_header = False
95 |
96 | self._csv_file = self._csv_file_name.open('a')
97 | self._csv_writer = csv.DictWriter(self._csv_file,
98 | fieldnames=sorted(data.keys()),
99 | restval=0.0)
100 | if should_write_header:
101 | self._csv_writer.writeheader()
102 |
103 | # To handle components that start training later
104 |         # (restval only covers the case where data has fewer keys than the CSV)
105 | if self._csv_writer.fieldnames != sorted(data.keys()) and \
106 | len(self._csv_writer.fieldnames) < len(data.keys()):
107 | self._csv_file.close()
108 | self._csv_file = self._csv_file_name.open('r')
109 | dict_reader = csv.DictReader(self._csv_file)
110 | rows = [row for row in dict_reader]
111 | self._csv_file.close()
112 | self._csv_file = self._csv_file_name.open('w')
113 | self._csv_writer = csv.DictWriter(self._csv_file,
114 | fieldnames=sorted(data.keys()),
115 | restval=0.0)
116 | self._csv_writer.writeheader()
117 | for row in rows:
118 | self._csv_writer.writerow(row)
119 |
120 | self._csv_writer.writerow(data)
121 | self._csv_file.flush()
122 |
123 | def _format(self, key, value, ty):
124 | if ty == 'int':
125 | value = int(value)
126 | return f'{key}: {value}'
127 | elif ty == 'float':
128 | return f'{key}: {value:.04f}'
129 | elif ty == 'time':
130 | value = str(datetime.timedelta(seconds=int(value)))
131 | return f'{key}: {value}'
132 | elif ty == 'str':
133 | return f'{key}: {value}'
134 | else:
135 |             raise ValueError(f'invalid format type: {ty}')
136 |
137 | def _dump_to_console(self, data, prefix):
138 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
139 | pieces = [f'| {prefix: <14}']
140 | for key, disp_key, ty in self._formating:
141 | value = data.get(key, 0)
142 | pieces.append(self._format(disp_key, value, ty))
143 | print(' | '.join(pieces))
144 |
145 | def _dump_to_wandb(self, data):
146 | wandb.log(data)
147 |
148 | def dump(self, step, prefix):
149 | if len(self._meters) == 0:
150 | return
151 | data = self._prime_meters()
152 | data['frame'] = step
153 | if self.use_wandb:
154 | wandb_data = {prefix + '/' + key: val for key, val in data.items()}
155 | self._dump_to_wandb(data=wandb_data)
156 | self._dump_to_csv(data)
157 | self._dump_to_console(data, prefix)
158 | self._meters.clear()
159 |
160 |
161 | class Logger(object):
162 | def __init__(self, log_dir, use_tb, use_wandb):
163 | self._log_dir = log_dir
164 | self._train_mg = MetersGroup(log_dir / 'train.csv',
165 | formating=COMMON_TRAIN_FORMAT,
166 | use_wandb=use_wandb)
167 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
168 | formating=COMMON_EVAL_FORMAT,
169 | use_wandb=use_wandb)
170 | if use_tb:
171 | self._sw = SummaryWriter(str(log_dir / 'tb'))
172 | else:
173 | self._sw = None
174 | self.use_wandb = use_wandb
175 |
176 | def _try_sw_log(self, key, value, step):
177 | if self._sw is not None:
178 | self._sw.add_scalar(key, value, step)
179 |
180 | def log(self, key, value, step):
181 | assert key.startswith('train') or key.startswith('eval')
182 | if type(value) == torch.Tensor:
183 | value = value.item()
184 | # print(key)
185 | self._try_sw_log(key, value, step)
186 | mg = self._train_mg if key.startswith('train') else self._eval_mg
187 | mg.log(key, value)
188 |
189 | def log_metrics(self, metrics, step, ty):
190 | for key, value in metrics.items():
191 | self.log(f'{ty}/{key}', value, step)
192 |
193 | def dump(self, step, ty=None):
194 | if ty is None or ty == 'eval':
195 | self._eval_mg.dump(step, 'eval')
196 | if ty is None or ty == 'train':
197 | self._train_mg.dump(step, 'train')
198 |
199 | def log_and_dump_ctx(self, step, ty):
200 | return LogAndDumpCtx(self, step, ty)
201 |
202 | def log_video(self, data, step):
203 | if self._sw is not None:
204 | for k, v in data.items():
205 | self._sw.add_video(k, v, global_step=step, fps=15)
206 | if self.use_wandb:
207 | for k, v in data.items():
208 | v = np.uint8(v.cpu() * 255)
209 | wandb.log({k: wandb.Video(v, fps=15, format="gif")})
210 |
211 |
212 | class LogAndDumpCtx:
213 | def __init__(self, logger, step, ty):
214 | self._logger = logger
215 | self._step = step
216 | self._ty = ty
217 |
218 | def __enter__(self):
219 | return self
220 |
221 | def __call__(self, key, value):
222 | self._logger.log(f'{self._ty}/{key}', value, self._step)
223 |
224 | def __exit__(self, *args):
225 | self._logger.dump(self._step, self._ty)
226 |
--------------------------------------------------------------------------------
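
Note: a minimal usage sketch for the `Logger` above, assuming torch, torchvision, tensorboard, and wandb are installed (they are imported at module load even when disabled). The log directory and metric values are hypothetical.

import pathlib
import tempfile

from logger import Logger  # the module above

log_dir = pathlib.Path(tempfile.mkdtemp())
logger = Logger(log_dir, use_tb=False, use_wandb=False)

# Keys must start with 'train' or 'eval'; the prefix picks the meter group.
logger.log('train/episode_reward', 123.4, step=1000)
logger.log('train/episode', 1, step=1000)
logger.dump(step=1000, ty='train')  # prints one row and appends to train.csv

# Or log and dump within one scope.
with logger.log_and_dump_ctx(2000, ty='eval') as log:
    log('episode_reward', 200.0)
    log('episode', 2)
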
/DMC_image/train_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ALGO=$1
4 | DOMAIN=$2 # walker_mass, quadruped_mass, quadruped_damping
5 | GPU_ID=$3
6 |
7 | if [ "$DOMAIN" == "walker_mass" ]
8 | then
9 | ALL_TASKS=("walker_stand_mass" "walker_walk_mass" "walker_run_mass" "walker_flip_mass")
10 | elif [ "$DOMAIN" == "quadruped_mass" ]
11 | then
12 | ALL_TASKS=("quadruped_stand_mass" "quadruped_walk_mass" "quadruped_run_mass" "quadruped_jump_mass")
13 | elif [ "$DOMAIN" == "quadruped_damping" ]
14 | then
15 | ALL_TASKS=("quadruped_stand_damping" "quadruped_walk_damping" "quadruped_run_damping" "quadruped_jump_damping")
16 | else
17 | ALL_TASKS=()
18 |     echo "Unknown DOMAIN '${DOMAIN}': DOMAIN must be walker_mass, quadruped_mass, or quadruped_damping"
19 |     exit 1
20 | fi
21 |
22 | echo "Experiments started."
23 | for seed in $(seq 0 2)
24 | do
25 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID}
26 | python dreamer_pretrain.py configs=dmc_pixels configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID}
27 | for string in "${ALL_TASKS[@]}"
28 | do
29 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID}
30 | python dreamer_finetune.py configs=dmc_pixels configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID} snapshot_ts=2000000 finetune_domain=$string
31 | done
32 | done
33 | echo "Experiments ended."
34 |
35 | # e.g.
36 | # ./train_finetune.sh peac_lbs walker_mass 0
--------------------------------------------------------------------------------
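
Note: the same pretrain-then-finetune sweep expressed as a Python sketch with subprocess, for setups where bash is inconvenient. The algorithm, domain, and GPU id are example arguments; the commands mirror the script above and launch full training runs.

import os
import subprocess

algo, domain, gpu_id = 'peac_lbs', 'walker_mass', '0'  # example arguments
tasks = {
    'walker_mass': ['walker_stand_mass', 'walker_walk_mass', 'walker_run_mass', 'walker_flip_mass'],
    'quadruped_mass': ['quadruped_stand_mass', 'quadruped_walk_mass', 'quadruped_run_mass', 'quadruped_jump_mass'],
    'quadruped_damping': ['quadruped_stand_damping', 'quadruped_walk_damping', 'quadruped_run_damping', 'quadruped_jump_damping'],
}[domain]

env = dict(os.environ, MUJOCO_EGL_DEVICE_ID=gpu_id)
for seed in range(3):
    subprocess.run(['python', 'dreamer_pretrain.py', 'configs=dmc_pixels', f'configs/agent={algo}',
                    f'domain={domain}', f'seed={seed}', f'device=cuda:{gpu_id}'], env=env, check=True)
    for task in tasks:
        subprocess.run(['python', 'dreamer_finetune.py', 'configs=dmc_pixels', f'configs/agent={algo}',
                        f'domain={domain}', f'seed={seed}', f'device=cuda:{gpu_id}',
                        'snapshot_ts=2000000', f'finetune_domain={task}'], env=env, check=True)
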
/DMC_state/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/DMC_state/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/CEURL/c213c2936aa2394e645c069e9c8333ee6bd0ce0a/DMC_state/README.md
--------------------------------------------------------------------------------
/DMC_state/agent/becl.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import math
3 | from collections import OrderedDict
4 |
5 | import hydra
6 | import random
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from dm_env import specs
12 |
13 | import utils
14 | from agent.ddpg import DDPGAgent
15 |
16 |
17 | class BECL(nn.Module):
18 | def __init__(self, tau_dim, feature_dim, hidden_dim):
19 | super().__init__()
20 |
21 | self.embed = nn.Sequential(nn.Linear(tau_dim, hidden_dim),
22 | nn.ReLU(),
23 | nn.Linear(hidden_dim, hidden_dim),
24 | nn.ReLU(),
25 | nn.Linear(hidden_dim, feature_dim))
26 |
27 | self.project_head = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
28 | nn.ReLU(),
29 | nn.Linear(hidden_dim, feature_dim))
30 | self.apply(utils.weight_init)
31 |
32 | def forward(self, tau):
33 | features = self.embed(tau)
34 | features = self.project_head(features)
35 | return features
36 |
37 |
38 | class BECLAgent(DDPGAgent):
39 | def __init__(self, update_skill_every_step, skill_dim,
40 | update_encoder, contrastive_update_rate, temperature, skill, **kwargs):
41 | self.skill_dim = skill_dim
42 | self.update_skill_every_step = update_skill_every_step
43 | self.update_encoder = update_encoder
44 | self.contrastive_update_rate = contrastive_update_rate
45 | self.temperature = temperature
46 | # specify skill in fine-tuning stage if needed
47 | self.skill = int(skill) if skill >= 0 else np.random.choice(self.skill_dim)
48 | # increase obs shape to include skill dim
49 | kwargs["meta_dim"] = self.skill_dim
50 | self.batch_size = kwargs['batch_size']
51 | # create actor and critic
52 | super().__init__(**kwargs)
53 |
54 | # net
55 | self.becl = BECL(self.obs_dim - self.skill_dim,
56 | self.skill_dim,
57 | kwargs['hidden_dim']).to(kwargs['device'])
58 |
59 | # optimizers
60 | self.becl_opt = torch.optim.Adam(self.becl.parameters(), lr=self.lr)
61 |
62 | self.becl.train()
63 |
64 | def get_meta_specs(self):
65 | return specs.Array((self.skill_dim,), np.float32, 'skill'),
66 |
67 | def init_meta(self):
68 | skill = np.zeros(self.skill_dim).astype(np.float32)
69 | if not self.reward_free:
70 | skill[self.skill] = 1.0
71 | else:
72 | skill[np.random.choice(self.skill_dim)] = 1.0
73 | meta = OrderedDict()
74 | meta['skill'] = skill
75 | return meta
76 |
77 | def update_meta(self, meta, global_step, time_step, finetune=False):
78 | if global_step % self.update_skill_every_step == 0:
79 | return self.init_meta()
80 | return meta
81 |
82 | def update_contrastive(self, state, skills):
83 | metrics = dict()
84 | features = self.becl(state)
85 | logits = self.compute_info_nce_loss(features, skills)
86 | loss = logits.mean()
87 |
88 | self.becl_opt.zero_grad()
89 | if self.encoder_opt is not None:
90 | self.encoder_opt.zero_grad(set_to_none=True)
91 | loss.backward()
92 | self.becl_opt.step()
93 | if self.encoder_opt is not None:
94 | self.encoder_opt.step()
95 |
96 | if self.use_tb or self.use_wandb:
97 | metrics['contrastive_loss'] = loss.item()
98 |
99 | return metrics
100 |
101 | def compute_intr_reward(self, skills, state, metrics):
102 |
103 | # compute contrastive reward
104 | features = self.becl(state)
105 | contrastive_reward = torch.exp(-self.compute_info_nce_loss(features, skills))
106 |
107 | intr_reward = contrastive_reward
108 | if self.use_tb or self.use_wandb:
109 | metrics['contrastive_reward'] = contrastive_reward.mean().item()
110 |
111 | return intr_reward
112 |
113 | def compute_info_nce_loss(self, features, skills):
114 | # features: (b,c), skills :(b, skill_dim)
115 |         # label positive samples
116 |         labels = torch.argmax(skills, dim=-1)  # (b,)
117 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).long() # (b,b)
118 | labels = labels.to(self.device)
119 |
120 | features = F.normalize(features, dim=1) # (b,c)
121 | similarity_matrix = torch.matmul(features, features.T) # (b,b)
122 |
123 |         # discard the main diagonal from both the labels and the similarity matrix
124 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
125 | labels = labels[~mask].view(labels.shape[0], -1) # (b,b-1)
126 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) # (b,b-1)
127 |
128 | similarity_matrix = similarity_matrix / self.temperature
129 | similarity_matrix -= torch.max(similarity_matrix, 1)[0][:, None]
130 | similarity_matrix = torch.exp(similarity_matrix)
131 |
132 | pick_one_positive_sample_idx = torch.argmax(labels, dim=-1, keepdim=True)
133 | pick_one_positive_sample_idx = torch.zeros_like(labels).scatter_(-1, pick_one_positive_sample_idx, 1)
134 |
135 | positives = torch.sum(similarity_matrix * pick_one_positive_sample_idx, dim=-1, keepdim=True) # (b,1)
136 | negatives = torch.sum(similarity_matrix, dim=-1, keepdim=True) # (b,1)
137 | eps = torch.as_tensor(1e-6)
138 | loss = -torch.log(positives / (negatives + eps) + eps) # (b,1)
139 |
140 | return loss
141 |
142 | def update(self, replay_iter, step):
143 | metrics = dict()
144 |
145 | if step % self.update_every_steps != 0:
146 | return metrics
147 |
148 | if self.reward_free:
149 |
150 | batch = next(replay_iter)
151 | obs, action, reward, discount, next_obs, task_id, skill = utils.to_torch(batch, self.device)
152 | obs = self.aug_and_encode(obs)
153 | next_obs = self.aug_and_encode(next_obs)
154 |
155 | metrics.update(self.update_contrastive(next_obs, skill))
156 |
157 | for _ in range(self.contrastive_update_rate - 1):
158 | batch = next(replay_iter)
159 | obs, action, reward, discount, next_obs, task_id, skill = utils.to_torch(batch, self.device)
160 | obs = self.aug_and_encode(obs)
161 | next_obs = self.aug_and_encode(next_obs)
162 |
163 | metrics.update(self.update_contrastive(next_obs, skill))
164 |
165 | with torch.no_grad():
166 | intr_reward = self.compute_intr_reward(skill, next_obs, metrics)
167 |
168 | if self.use_tb or self.use_wandb:
169 | metrics['intr_reward'] = intr_reward.mean().item()
170 |
171 | reward = intr_reward
172 | else:
173 | batch = next(replay_iter)
174 |
175 | obs, action, extr_reward, discount, next_obs, task_id, skill = utils.to_torch(
176 | batch, self.device)
177 | obs = self.aug_and_encode(obs)
178 | next_obs = self.aug_and_encode(next_obs)
179 | reward = extr_reward
180 |
181 | if self.use_tb or self.use_wandb:
182 | metrics['batch_reward'] = reward.mean().item()
183 |
184 | if not self.update_encoder:
185 | obs = obs.detach()
186 | next_obs = next_obs.detach()
187 |
188 | # extend observations with skill
189 | obs = torch.cat([obs, skill], dim=1)
190 | next_obs = torch.cat([next_obs, skill], dim=1)
191 |
192 | # update critic
193 | metrics.update(
194 | self.update_critic(obs.detach(), action, reward, discount,
195 | next_obs.detach(), step))
196 |
197 | # update actor
198 | metrics.update(self.update_actor(obs.detach(), step))
199 |
200 | # update critic target
201 | utils.soft_update_params(self.critic, self.critic_target,
202 | self.critic_target_tau)
203 |
204 | return metrics
--------------------------------------------------------------------------------
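
Note: a self-contained toy run of the masked InfoNCE computation in `compute_info_nce_loss` above, on random features and one-hot skills. Batch size, feature width, skill count, and temperature are illustrative; samples sharing a skill id are treated as positives, exactly as in the method.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, c, skill_dim, temperature = 8, 16, 4, 0.5

features = torch.randn(b, c)
skills = F.one_hot(torch.randint(skill_dim, (b,)), skill_dim).float()

labels = torch.argmax(skills, dim=-1)                         # (b,)
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).long()  # (b, b), 1 where skills match

features = F.normalize(features, dim=1)
sim = features @ features.T                                   # cosine similarities

mask = torch.eye(b, dtype=torch.bool)                         # drop self-pairs
labels = labels[~mask].view(b, -1)
sim = sim[~mask].view(b, -1) / temperature
sim = torch.exp(sim - sim.max(dim=1, keepdim=True).values)    # stabilised exponentials

pos_idx = torch.zeros_like(labels).scatter_(-1, labels.argmax(-1, keepdim=True), 1)
positives = (sim * pos_idx).sum(-1, keepdim=True)
negatives = sim.sum(-1, keepdim=True)
loss = -torch.log(positives / (negatives + 1e-6) + 1e-6)      # (b, 1)
print(loss.mean())
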
/DMC_state/agent/cic.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from dm_env import specs
7 | import math
8 | from collections import OrderedDict
9 |
10 | import utils
11 |
12 | from agent.ddpg import DDPGAgent
13 |
14 |
15 | class CIC(nn.Module):
16 | def __init__(self, obs_dim, skill_dim, hidden_dim, project_skill):
17 | super().__init__()
18 | self.obs_dim = obs_dim
19 | self.skill_dim = skill_dim
20 |
21 | self.state_net = nn.Sequential(nn.Linear(self.obs_dim, hidden_dim), nn.ReLU(),
22 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
23 | nn.Linear(hidden_dim, self.skill_dim))
24 |
25 | self.next_state_net = nn.Sequential(nn.Linear(self.obs_dim, hidden_dim), nn.ReLU(),
26 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
27 | nn.Linear(hidden_dim, self.skill_dim))
28 |
29 | self.pred_net = nn.Sequential(nn.Linear(2 * self.skill_dim, hidden_dim), nn.ReLU(),
30 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
31 | nn.Linear(hidden_dim, self.skill_dim))
32 |
33 | if project_skill:
34 | self.skill_net = nn.Sequential(nn.Linear(self.skill_dim, hidden_dim), nn.ReLU(),
35 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
36 | nn.Linear(hidden_dim, self.skill_dim))
37 | else:
38 | self.skill_net = nn.Identity()
39 |
40 | self.apply(utils.weight_init)
41 |
42 | def forward(self, state, next_state, skill):
43 | assert len(state.size()) == len(next_state.size())
44 | state = self.state_net(state)
45 | next_state = self.state_net(next_state)
46 | query = self.skill_net(skill)
47 | key = self.pred_net(torch.cat([state, next_state], 1))
48 | return query, key
49 |
50 |
51 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52 |
53 |
54 | class RMS(object):
55 | def __init__(self, epsilon=1e-4, shape=(1,), device='cpu'):
56 | self.M = torch.zeros(shape).to(device)
57 | self.S = torch.ones(shape).to(device)
58 | self.n = epsilon
59 |
60 | def __call__(self, x):
61 | bs = x.size(0)
62 | delta = torch.mean(x, dim=0) - self.M
63 | new_M = self.M + delta * bs / (self.n + bs)
64 | new_S = (self.S * self.n + torch.var(x, dim=0) * bs + (delta ** 2) * self.n * bs / (self.n + bs)) / (
65 | self.n + bs)
66 |
67 | self.M = new_M
68 | self.S = new_S
69 | self.n += bs
70 |
71 | return self.M, self.S
72 |
73 |
74 | class APTArgs:
75 | def __init__(self, knn_k=16, knn_avg=True, rms=True, knn_clip=0.0005, ):
76 | self.knn_k = knn_k
77 | self.knn_avg = knn_avg
78 | self.rms = rms
79 | self.knn_clip = knn_clip
80 |
81 |
82 | def compute_apt_reward(source, target, args, device, rms):
83 | b1, b2 = source.size(0), target.size(0)
84 | # (b1, 1, c) - (1, b2, c) -> (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2)
85 | sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) - target[None, :, :].view(1, b2, -1), dim=-1, p=2)
86 | reward, _ = sim_matrix.topk(args.knn_k, dim=1, largest=False, sorted=True) # (b1, k)
87 |
88 | if not args.knn_avg: # only keep k-th nearest neighbor
89 | reward = reward[:, -1]
90 | reward = reward.reshape(-1, 1) # (b1, 1)
91 | if args.rms:
92 | moving_mean, moving_std = rms(reward)
93 | reward = reward / moving_std
94 | reward = torch.max(reward - args.knn_clip, torch.zeros_like(reward).to(device)) # (b1, )
95 | else: # average over all k nearest neighbors
96 | reward = reward.reshape(-1, 1) # (b1 * k, 1)
97 | if args.rms:
98 | moving_mean, moving_std = rms(reward)
99 | reward = reward / moving_std
100 | reward = torch.max(reward - args.knn_clip, torch.zeros_like(reward).to(device))
101 | reward = reward.reshape((b1, args.knn_k)) # (b1, k)
102 | reward = reward.mean(dim=1) # (b1,)
103 | reward = torch.log(reward + 1.0)
104 | return reward
105 |
106 |
107 | class CICAgent(DDPGAgent):
108 | # Contrastive Intrinsic Control (CIC)
109 | def __init__(self, update_skill_every_step, skill_dim, scale,
110 | project_skill, rew_type, update_rep, temp, **kwargs):
111 | self.temp = temp
112 | self.skill_dim = skill_dim
113 | self.update_skill_every_step = update_skill_every_step
114 | self.scale = scale
115 | self.project_skill = project_skill
116 | self.rew_type = rew_type
117 | self.update_rep = update_rep
118 | kwargs["meta_dim"] = self.skill_dim
119 | # create actor and critic
120 | self.device = kwargs['device']
121 |
122 | super().__init__(**kwargs)
123 | self.rms = RMS(device=self.device)
124 |
125 | # create cic first
126 | self.cic = CIC(self.obs_dim - skill_dim, skill_dim,
127 | kwargs['hidden_dim'], project_skill).to(self.device)
128 |
129 | # optimizers
130 | self.cic_optimizer = torch.optim.Adam(self.cic.parameters(),
131 | lr=self.lr)
132 |
133 | self.cic.train()
134 |
135 | def get_meta_specs(self):
136 | return (specs.Array((self.skill_dim,), np.float32, 'skill'),)
137 |
138 | def init_meta(self):
139 | if not self.reward_free:
140 | # selects mean skill of 0.5 (to select skill automatically use CEM or Grid Sweep
141 | # procedures described in the CIC paper)
142 | skill = np.ones(self.skill_dim).astype(np.float32) * 0.5
143 | else:
144 | skill = np.random.uniform(0, 1, self.skill_dim).astype(np.float32)
145 | meta = OrderedDict()
146 | meta['skill'] = skill
147 | return meta
148 |
149 | def update_meta(self, meta, step, time_step):
150 | if step % self.update_skill_every_step == 0:
151 | return self.init_meta()
152 | return meta
153 |
154 | def compute_cpc_loss(self, obs, next_obs, skill):
155 | temperature = self.temp
156 | eps = 1e-6
157 | query, key = self.cic.forward(obs, next_obs, skill)
158 | query = F.normalize(query, dim=1)
159 | key = F.normalize(key, dim=1)
160 | cov = torch.mm(query, key.T) # (b,b)
161 | sim = torch.exp(cov / temperature)
162 | neg = sim.sum(dim=-1) # (b,)
163 | row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device)
164 | neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability
165 |
166 | pos = torch.exp(torch.sum(query * key, dim=-1) / temperature) # (b,)
167 | loss = -torch.log(pos / (neg + eps)) # (b,)
168 | return loss, cov / temperature
169 |
170 | def update_cic(self, obs, skill, next_obs, step):
171 | metrics = dict()
172 |
173 | loss, logits = self.compute_cpc_loss(obs, next_obs, skill)
174 | loss = loss.mean()
175 | self.cic_optimizer.zero_grad()
176 | loss.backward()
177 | self.cic_optimizer.step()
178 |
179 | if self.use_tb or self.use_wandb:
180 | metrics['cic_loss'] = loss.item()
181 | metrics['cic_logits'] = logits.norm()
182 |
183 | return metrics
184 |
185 | def compute_intr_reward(self, obs, skill, next_obs, step):
186 |
187 | with torch.no_grad():
188 | loss, logits = self.compute_cpc_loss(obs, next_obs, skill)
189 |
190 | reward = loss
191 | reward = reward.clone().detach().unsqueeze(-1)
192 |
193 | return reward * self.scale
194 |
195 | @torch.no_grad()
196 | def compute_apt_reward(self, obs, next_obs):
197 | args = APTArgs()
198 | source = self.cic.state_net(obs)
199 | target = self.cic.state_net(next_obs)
200 | reward = compute_apt_reward(source, target, args, device=self.device,
201 | rms=self.rms) # (b,)
202 | return reward.unsqueeze(-1) # (b,1)
203 |
204 | def update(self, replay_iter, step):
205 | metrics = dict()
206 |
207 | if step % self.update_every_steps != 0:
208 | return metrics
209 |
210 | batch = next(replay_iter)
211 |
212 | obs, action, extr_reward, discount, next_obs, task_id, skill = utils.to_torch(
213 | batch, self.device)
214 |
215 | with torch.no_grad():
216 | obs = self.aug_and_encode(obs)
217 |
218 | next_obs = self.aug_and_encode(next_obs)
219 |
220 | if self.reward_free:
221 | if self.update_rep:
222 | metrics.update(self.update_cic(obs, skill, next_obs, step))
223 |
224 | intr_reward = self.compute_apt_reward(next_obs, next_obs)
225 |
226 | reward = intr_reward
227 | else:
228 | reward = extr_reward
229 |
230 | if self.use_tb or self.use_wandb:
231 | if self.reward_free:
232 | metrics['extr_reward'] = extr_reward.mean().item()
233 | # metrics['intr_reward'] = apt_reward.mean().item()
234 | metrics['batch_reward'] = reward.mean().item()
235 |
236 | # extend observations with skill
237 | obs = torch.cat([obs, skill], dim=1)
238 | next_obs = torch.cat([next_obs, skill], dim=1)
239 |
240 | # update critic
241 | metrics.update(
242 | self.update_critic(obs, action, reward, discount, next_obs, step))
243 |
244 | # update actor
245 | metrics.update(self.update_actor(obs, step))
246 |
247 | # update critic target
248 | utils.soft_update_params(self.critic, self.critic_target,
249 | self.critic_target_tau)
250 |
251 | return metrics
--------------------------------------------------------------------------------
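
Note: a self-contained toy version of the particle-based k-NN reward in `compute_apt_reward` above (the `knn_avg` branch, without the running-std normalisation). Batch size, feature width, and `knn_k` are illustrative.

import torch

torch.manual_seed(0)
b, c, knn_k, knn_clip = 32, 8, 16, 0.0005

source = torch.randn(b, c)
target = torch.randn(b, c)

dists = torch.cdist(source, target, p=2)          # (b, b) pairwise L2 distances
knn, _ = dists.topk(knn_k, dim=1, largest=False)  # (b, k) distances to the nearest neighbours

reward = torch.clamp(knn - knn_clip, min=0.0).mean(dim=1)  # average over the k neighbours
reward = torch.log(reward + 1.0)                           # (b,)
print(reward.mean())
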
/DMC_state/agent/diayn.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 |
4 | import hydra
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from dm_env import specs
10 |
11 | import utils
12 | from agent.ddpg import DDPGAgent
13 |
14 |
15 | class DIAYN(nn.Module):
16 | def __init__(self, obs_dim, skill_dim, hidden_dim):
17 | super().__init__()
18 | self.skill_pred_net = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
19 | nn.ReLU(),
20 | nn.Linear(hidden_dim, hidden_dim),
21 | nn.ReLU(),
22 | nn.Linear(hidden_dim, skill_dim))
23 |
24 | self.apply(utils.weight_init)
25 |
26 | def forward(self, obs):
27 | skill_pred = self.skill_pred_net(obs)
28 | return skill_pred
29 |
30 |
31 | class DIAYNAgent(DDPGAgent):
32 | def __init__(self, update_skill_every_step, skill_dim, diayn_scale,
33 | update_encoder, **kwargs):
34 | self.skill_dim = skill_dim
35 | self.update_skill_every_step = update_skill_every_step
36 | self.diayn_scale = diayn_scale
37 | self.update_encoder = update_encoder
38 | # increase obs shape to include skill dim
39 | kwargs["meta_dim"] = self.skill_dim
40 |
41 | # create actor and critic
42 | super().__init__(**kwargs)
43 |
44 | # create diayn
45 | self.diayn = DIAYN(self.obs_dim - self.skill_dim, self.skill_dim,
46 | kwargs['hidden_dim']).to(kwargs['device'])
47 |
48 | # loss criterion
49 | self.diayn_criterion = nn.CrossEntropyLoss()
50 | # optimizers
51 | self.diayn_opt = torch.optim.Adam(self.diayn.parameters(), lr=self.lr)
52 |
53 | self.diayn.train()
54 |
55 | def get_meta_specs(self):
56 | return (specs.Array((self.skill_dim,), np.float32, 'skill'),)
57 |
58 | def init_meta(self):
59 | skill = np.zeros(self.skill_dim, dtype=np.float32)
60 | skill[np.random.choice(self.skill_dim)] = 1.0
61 | meta = OrderedDict()
62 | meta['skill'] = skill
63 | return meta
64 |
65 | def update_meta(self, meta, global_step, time_step):
66 | if global_step % self.update_skill_every_step == 0:
67 | return self.init_meta()
68 | return meta
69 |
70 | def update_diayn(self, skill, next_obs, step):
71 | metrics = dict()
72 |
73 | loss, df_accuracy = self.compute_diayn_loss(next_obs, skill)
74 |
75 | self.diayn_opt.zero_grad()
76 | if self.encoder_opt is not None:
77 | self.encoder_opt.zero_grad(set_to_none=True)
78 | loss.backward()
79 | self.diayn_opt.step()
80 | if self.encoder_opt is not None:
81 | self.encoder_opt.step()
82 |
83 | if self.use_tb or self.use_wandb:
84 | metrics['diayn_loss'] = loss.item()
85 | metrics['diayn_acc'] = df_accuracy
86 |
87 | return metrics
88 |
89 | def compute_intr_reward(self, skill, next_obs, step):
90 | z_hat = torch.argmax(skill, dim=1)
91 | d_pred = self.diayn(next_obs)
92 | d_pred_log_softmax = F.log_softmax(d_pred, dim=1)
93 | _, pred_z = torch.max(d_pred_log_softmax, dim=1, keepdim=True)
94 | reward = d_pred_log_softmax[torch.arange(d_pred.shape[0]),
95 | z_hat] - math.log(1 / self.skill_dim)
96 | reward = reward.reshape(-1, 1)
97 |
98 | return reward * self.diayn_scale
99 |
100 | def compute_diayn_loss(self, next_state, skill):
101 | """
102 | DF Loss
103 | """
104 | z_hat = torch.argmax(skill, dim=1)
105 | d_pred = self.diayn(next_state)
106 | d_pred_log_softmax = F.log_softmax(d_pred, dim=1)
107 | _, pred_z = torch.max(d_pred_log_softmax, dim=1, keepdim=True)
108 | d_loss = self.diayn_criterion(d_pred, z_hat)
109 | df_accuracy = torch.sum(
110 | torch.eq(z_hat,
111 | pred_z.reshape(1,
112 | list(
113 | pred_z.size())[0])[0])).float() / list(
114 | pred_z.size())[0]
115 | return d_loss, df_accuracy
116 |
117 | def update(self, replay_iter, step):
118 | metrics = dict()
119 |
120 | if step % self.update_every_steps != 0:
121 | return metrics
122 |
123 | batch = next(replay_iter)
124 |
125 | obs, action, extr_reward, discount, next_obs, task_id, skill = utils.to_torch(
126 | batch, self.device)
127 |
128 | # augment and encode
129 | obs = self.aug_and_encode(obs)
130 | next_obs = self.aug_and_encode(next_obs)
131 |
132 | if self.reward_free:
133 | metrics.update(self.update_diayn(skill, next_obs, step))
134 |
135 | with torch.no_grad():
136 | intr_reward = self.compute_intr_reward(skill, next_obs, step)
137 |
138 | if self.use_tb or self.use_wandb:
139 | metrics['intr_reward'] = intr_reward.mean().item()
140 | reward = intr_reward
141 | else:
142 | reward = extr_reward
143 |
144 | if self.use_tb or self.use_wandb:
145 | metrics['extr_reward'] = extr_reward.mean().item()
146 | metrics['batch_reward'] = reward.mean().item()
147 |
148 | if not self.update_encoder:
149 | obs = obs.detach()
150 | next_obs = next_obs.detach()
151 |
152 | # extend observations with skill
153 | obs = torch.cat([obs, skill], dim=1)
154 | next_obs = torch.cat([next_obs, skill], dim=1)
155 |
156 | # update critic
157 | metrics.update(
158 | self.update_critic(obs.detach(), action, reward, discount,
159 | next_obs.detach(), step))
160 |
161 | # update actor
162 | metrics.update(self.update_actor(obs.detach(), step))
163 |
164 | # update critic target
165 | utils.soft_update_params(self.critic, self.critic_target,
166 | self.critic_target_tau)
167 |
168 | return metrics
169 |
--------------------------------------------------------------------------------
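
Note: a toy computation of the DIAYN intrinsic reward from `compute_intr_reward` above, r = log q(z|s) - log(1/skill_dim), with random logits standing in for the discriminator output `self.diayn(next_obs)`.

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, skill_dim = 4, 16

skill = F.one_hot(torch.randint(skill_dim, (b,)), skill_dim).float()
d_pred = torch.randn(b, skill_dim)  # stand-in for the discriminator logits

z_hat = torch.argmax(skill, dim=1)
log_q = F.log_softmax(d_pred, dim=1)
reward = log_q[torch.arange(b), z_hat] - math.log(1 / skill_dim)
print(reward.reshape(-1, 1))  # positive when the skill is recognisable from the state
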
/DMC_state/agent/disagreement.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import utils
8 | from agent.ddpg import DDPGAgent
9 |
10 |
11 | class Disagreement(nn.Module):
12 | def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5):
13 | super().__init__()
14 | self.ensemble = nn.ModuleList([
15 | nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim),
16 | nn.ReLU(), nn.Linear(hidden_dim, obs_dim))
17 | for _ in range(n_models)
18 | ])
19 |
20 | def forward(self, obs, action, next_obs):
21 | #import ipdb; ipdb.set_trace()
22 | assert obs.shape[0] == next_obs.shape[0]
23 | assert obs.shape[0] == action.shape[0]
24 |
25 | errors = []
26 | for model in self.ensemble:
27 | next_obs_hat = model(torch.cat([obs, action], dim=-1))
28 | model_error = torch.norm(next_obs - next_obs_hat,
29 | dim=-1,
30 | p=2,
31 | keepdim=True)
32 | errors.append(model_error)
33 |
34 | return torch.cat(errors, dim=1)
35 |
36 | def get_disagreement(self, obs, action, next_obs):
37 | assert obs.shape[0] == next_obs.shape[0]
38 | assert obs.shape[0] == action.shape[0]
39 |
40 | preds = []
41 | for model in self.ensemble:
42 | next_obs_hat = model(torch.cat([obs, action], dim=-1))
43 | preds.append(next_obs_hat)
44 | preds = torch.stack(preds, dim=0)
45 | return torch.var(preds, dim=0).mean(dim=-1)
46 |
47 |
48 | class DisagreementAgent(DDPGAgent):
49 | def __init__(self, update_encoder, **kwargs):
50 | super().__init__(**kwargs)
51 | self.update_encoder = update_encoder
52 |
53 | self.disagreement = Disagreement(self.obs_dim, self.action_dim,
54 | self.hidden_dim).to(self.device)
55 |
56 | # optimizers
57 | self.disagreement_opt = torch.optim.Adam(
58 | self.disagreement.parameters(), lr=self.lr)
59 |
60 | self.disagreement.train()
61 |
62 | def update_disagreement(self, obs, action, next_obs, step):
63 | metrics = dict()
64 |
65 | error = self.disagreement(obs, action, next_obs)
66 |
67 | loss = error.mean()
68 |
69 | self.disagreement_opt.zero_grad(set_to_none=True)
70 | if self.encoder_opt is not None:
71 | self.encoder_opt.zero_grad(set_to_none=True)
72 | loss.backward()
73 | self.disagreement_opt.step()
74 | if self.encoder_opt is not None:
75 | self.encoder_opt.step()
76 |
77 | if self.use_tb or self.use_wandb:
78 | metrics['disagreement_loss'] = loss.item()
79 |
80 | return metrics
81 |
82 | def compute_intr_reward(self, obs, action, next_obs, step):
83 | reward = self.disagreement.get_disagreement(obs, action,
84 | next_obs).unsqueeze(1)
85 | return reward
86 |
87 | def update(self, replay_iter, step):
88 | metrics = dict()
89 |
90 | if step % self.update_every_steps != 0:
91 | return metrics
92 |
93 | batch = next(replay_iter)
94 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch(
95 | batch, self.device)
96 |
97 | # augment and encode
98 | obs = self.aug_and_encode(obs)
99 | with torch.no_grad():
100 | next_obs = self.aug_and_encode(next_obs)
101 |
102 | if self.reward_free:
103 | metrics.update(
104 | self.update_disagreement(obs, action, next_obs, step))
105 |
106 | with torch.no_grad():
107 | intr_reward = self.compute_intr_reward(obs, action, next_obs,
108 | step)
109 |
110 | if self.use_tb or self.use_wandb:
111 | metrics['intr_reward'] = intr_reward.mean().item()
112 | reward = intr_reward
113 | else:
114 | reward = extr_reward
115 |
116 | if self.use_tb or self.use_wandb:
117 | metrics['extr_reward'] = extr_reward.mean().item()
118 | metrics['batch_reward'] = reward.mean().item()
119 |
120 | if not self.update_encoder:
121 | obs = obs.detach()
122 | next_obs = next_obs.detach()
123 |
124 | # update critic
125 | metrics.update(
126 | self.update_critic(obs.detach(), action, reward, discount,
127 | next_obs.detach(), step))
128 |
129 | # update actor
130 | metrics.update(self.update_actor(obs.detach(), step))
131 |
132 | # update critic target
133 | utils.soft_update_params(self.critic, self.critic_target,
134 | self.critic_target_tau)
135 |
136 | return metrics
137 |
--------------------------------------------------------------------------------
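
Note: a self-contained toy run of the ensemble-disagreement reward computed by `get_disagreement` above: a small forward-model ensemble is evaluated on random inputs and the per-dimension variance across models is averaged. All dimensions are illustrative.

import torch
import torch.nn as nn

torch.manual_seed(0)
obs_dim, action_dim, hidden_dim, n_models, b = 6, 2, 32, 5, 10

ensemble = nn.ModuleList([
    nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
                  nn.Linear(hidden_dim, obs_dim))
    for _ in range(n_models)
])

obs, action = torch.randn(b, obs_dim), torch.randn(b, action_dim)
preds = torch.stack([m(torch.cat([obs, action], dim=-1)) for m in ensemble])  # (n_models, b, obs_dim)
reward = preds.var(dim=0).mean(dim=-1)  # (b,), larger where the models disagree
print(reward)
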
/DMC_state/agent/icm.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import utils
8 | from agent.ddpg import DDPGAgent
9 |
10 |
11 | class ICM(nn.Module):
12 | def __init__(self, obs_dim, action_dim, hidden_dim):
13 | super().__init__()
14 |
15 | self.forward_net = nn.Sequential(
16 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
17 | nn.Linear(hidden_dim, obs_dim))
18 |
19 | self.backward_net = nn.Sequential(nn.Linear(2 * obs_dim, hidden_dim),
20 | nn.ReLU(),
21 | nn.Linear(hidden_dim, action_dim),
22 | nn.Tanh())
23 |
24 | self.apply(utils.weight_init)
25 |
26 | def forward(self, obs, action, next_obs):
27 | assert obs.shape[0] == next_obs.shape[0]
28 | assert obs.shape[0] == action.shape[0]
29 |
30 | next_obs_hat = self.forward_net(torch.cat([obs, action], dim=-1))
31 | action_hat = self.backward_net(torch.cat([obs, next_obs], dim=-1))
32 |
33 | forward_error = torch.norm(next_obs - next_obs_hat,
34 | dim=-1,
35 | p=2,
36 | keepdim=True)
37 | backward_error = torch.norm(action - action_hat,
38 | dim=-1,
39 | p=2,
40 | keepdim=True)
41 |
42 | return forward_error, backward_error
43 |
44 |
45 | class ICMAgent(DDPGAgent):
46 | def __init__(self, icm_scale, update_encoder, **kwargs):
47 | super().__init__(**kwargs)
48 | self.icm_scale = icm_scale
49 | self.update_encoder = update_encoder
50 |
51 | self.icm = ICM(self.obs_dim, self.action_dim,
52 | self.hidden_dim).to(self.device)
53 |
54 | # optimizers
55 | self.icm_opt = torch.optim.Adam(self.icm.parameters(), lr=self.lr)
56 |
57 | self.icm.train()
58 |
59 | def update_icm(self, obs, action, next_obs, step):
60 | metrics = dict()
61 |
62 | forward_error, backward_error = self.icm(obs, action, next_obs)
63 |
64 | loss = forward_error.mean() + backward_error.mean()
65 |
66 | self.icm_opt.zero_grad(set_to_none=True)
67 | if self.encoder_opt is not None:
68 | self.encoder_opt.zero_grad(set_to_none=True)
69 | loss.backward()
70 | self.icm_opt.step()
71 | if self.encoder_opt is not None:
72 | self.encoder_opt.step()
73 |
74 | if self.use_tb or self.use_wandb:
75 | metrics['icm_loss'] = loss.item()
76 |
77 | return metrics
78 |
79 | def compute_intr_reward(self, obs, action, next_obs, step):
80 | forward_error, _ = self.icm(obs, action, next_obs)
81 |
82 | reward = forward_error * self.icm_scale
83 | reward = torch.log(reward + 1.0)
84 | return reward
85 |
86 | def update(self, replay_iter, step):
87 | metrics = dict()
88 |
89 | if step % self.update_every_steps != 0:
90 | return metrics
91 |
92 | batch = next(replay_iter)
93 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch(
94 | batch, self.device)
95 |
96 | # augment and encode
97 | obs = self.aug_and_encode(obs)
98 | with torch.no_grad():
99 | next_obs = self.aug_and_encode(next_obs)
100 |
101 | if self.reward_free:
102 | metrics.update(self.update_icm(obs, action, next_obs, step))
103 |
104 | with torch.no_grad():
105 | intr_reward = self.compute_intr_reward(obs, action, next_obs,
106 | step)
107 |
108 | if self.use_tb or self.use_wandb:
109 | metrics['intr_reward'] = intr_reward.mean().item()
110 | reward = intr_reward
111 | else:
112 | reward = extr_reward
113 |
114 | if self.use_tb or self.use_wandb:
115 | metrics['extr_reward'] = extr_reward.mean().item()
116 | metrics['batch_reward'] = reward.mean().item()
117 |
118 | if not self.update_encoder:
119 | obs = obs.detach()
120 | next_obs = next_obs.detach()
121 |
122 | # update critic
123 | metrics.update(
124 | self.update_critic(obs.detach(), action, reward, discount,
125 | next_obs.detach(), step))
126 |
127 | # update actor
128 | metrics.update(self.update_actor(obs.detach(), step))
129 |
130 | # update critic target
131 | utils.soft_update_params(self.critic, self.critic_target,
132 | self.critic_target_tau)
133 |
134 | return metrics
135 |
--------------------------------------------------------------------------------
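
Note: a self-contained toy version of the ICM intrinsic reward from `compute_intr_reward` above: the forward-model prediction error, scaled and passed through log(1 + x). The scale and network sizes are illustrative; the real agent trains the forward model on replayed transitions first.

import torch
import torch.nn as nn

torch.manual_seed(0)
obs_dim, action_dim, hidden_dim, b, icm_scale = 6, 2, 32, 10, 1.0

forward_net = nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
                            nn.Linear(hidden_dim, obs_dim))

obs = torch.randn(b, obs_dim)
action = torch.randn(b, action_dim)
next_obs = torch.randn(b, obs_dim)

next_obs_hat = forward_net(torch.cat([obs, action], dim=-1))
forward_error = torch.norm(next_obs - next_obs_hat, dim=-1, p=2, keepdim=True)

reward = torch.log(forward_error * icm_scale + 1.0)  # (b, 1), larger when the model is surprised
print(reward.squeeze(-1))
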
/DMC_state/agent/lbs.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.distributions as D
7 |
8 | import utils
9 | from agent.ddpg import DDPGAgent
10 |
11 |
12 | # s_t, a_t -> z_t+1
13 | # s_t, a_t, s_t+1 -> z_t+1
14 | # z_t+1 -> s_t+1
15 | class LBS(nn.Module):
16 | def __init__(self, obs_dim, action_dim, hidden_dim):
17 | super().__init__()
18 |
19 | self.pri_forward_net = nn.Sequential(
20 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
21 | nn.Linear(hidden_dim, hidden_dim))
22 |
23 | self.pos_forward_net = nn.Sequential(
24 | nn.Linear(2 * obs_dim + action_dim, hidden_dim), nn.ReLU(),
25 | nn.Linear(hidden_dim, hidden_dim))
26 |
27 | self.reconstruction_net = nn.Sequential(
28 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
29 | nn.Linear(hidden_dim, obs_dim))
30 |
31 | self.apply(utils.weight_init)
32 |
33 | def forward(self, obs, action, next_obs):
34 | assert obs.shape[0] == next_obs.shape[0]
35 | assert obs.shape[0] == action.shape[0]
36 |
37 | pri_z = self.pri_forward_net(torch.cat([obs, action], dim=-1))
38 | pos_z = self.pos_forward_net(torch.cat([obs, action, next_obs], dim=-1))
39 |
40 | reco_s = self.reconstruction_net(pos_z)
41 |
42 | pri_z = D.Independent(D.Normal(pri_z, 1.0), 1)
43 | pos_z = D.Independent(D.Normal(pos_z, 1.0), 1)
44 | reco_s = D.Independent(D.Normal(reco_s, 1.0), 1)
45 |
46 | kl_div = D.kl_divergence(pos_z, pri_z)
47 |
48 | reco_error = -reco_s.log_prob(next_obs).mean()
49 | kl_error = kl_div.mean()
50 |
51 | return kl_error, reco_error, kl_div.detach()
52 |
53 |
54 | class LBS_PRED(nn.Module):
55 | def __init__(self, obs_dim, action_dim, hidden_dim):
56 | super().__init__()
57 |
58 | self.pred_net = nn.Sequential(
59 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(),
60 | nn.Linear(hidden_dim, 1))
61 |
62 | self.apply(utils.weight_init)
63 |
64 | def forward(self, obs, action, next_obs):
65 | assert obs.shape[0] == next_obs.shape[0]
66 | assert obs.shape[0] == action.shape[0]
67 |
68 | pred_kl = self.pred_net(torch.cat([obs, action], dim=-1))
69 | pred_kl = D.Independent(D.Normal(pred_kl, 1.0), 1)
70 |
71 | return pred_kl
72 |
73 |
74 | class LBSAgent(DDPGAgent):
75 | def __init__(self, lbs_scale, update_encoder, **kwargs):
76 | super().__init__(**kwargs)
77 | self.lbs_scale = lbs_scale
78 | self.update_encoder = update_encoder
79 |
80 | self.lbs = LBS(self.obs_dim, self.action_dim,
81 | self.hidden_dim).to(self.device)
82 | # optimizers
83 | self.lbs_opt = torch.optim.Adam(self.lbs.parameters(), lr=self.lr)
84 |
85 | self.lbs_pred = LBS_PRED(self.obs_dim, self.action_dim,
86 | self.hidden_dim).to(self.device)
87 | self.lbs_pred_opt = torch.optim.Adam(self.lbs_pred.parameters(), lr=self.lr)
88 |
89 | self.lbs.train()
90 | self.lbs_pred.train()
91 |
92 | def update_lbs(self, obs, action, next_obs, step):
93 | metrics = dict()
94 |
95 | kl_error, reco_error, kl_div = self.lbs(obs, action, next_obs)
96 |
97 | lbs_loss = kl_error.mean() + reco_error.mean()
98 |
99 | self.lbs_opt.zero_grad(set_to_none=True)
100 | if self.encoder_opt is not None:
101 | self.encoder_opt.zero_grad(set_to_none=True)
102 | lbs_loss.backward()
103 | self.lbs_opt.step()
104 | if self.encoder_opt is not None:
105 | self.encoder_opt.step()
106 |
107 | kl_pred = self.lbs_pred(obs, action, next_obs)
108 | lbs_pred_loss = -kl_pred.log_prob(kl_div.detach()).mean()
109 | self.lbs_pred_opt.zero_grad(set_to_none=True)
110 | if self.encoder_opt is not None:
111 | self.encoder_opt.zero_grad(set_to_none=True)
112 | lbs_pred_loss.backward()
113 | self.lbs_pred_opt.step()
114 | if self.encoder_opt is not None:
115 | self.encoder_opt.step()
116 |
117 | if self.use_tb or self.use_wandb:
118 | metrics['lbs_loss'] = lbs_loss.item()
119 | metrics['lbs_pred_loss'] = lbs_pred_loss.item()
120 |
121 | return metrics
122 |
123 | def compute_intr_reward(self, obs, action, next_obs, step):
124 | kl_pred = self.lbs_pred(obs, action, next_obs)
125 |
126 | reward = kl_pred.mean * self.lbs_scale
127 | return reward
128 |
129 | def update(self, replay_iter, step):
130 | metrics = dict()
131 |
132 | if step % self.update_every_steps != 0:
133 | return metrics
134 |
135 | batch = next(replay_iter)
136 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch(
137 | batch, self.device)
138 |
139 | # augment and encode
140 | obs = self.aug_and_encode(obs)
141 | with torch.no_grad():
142 | next_obs = self.aug_and_encode(next_obs)
143 |
144 | if self.reward_free:
145 | metrics.update(self.update_lbs(obs, action, next_obs, step))
146 |
147 | with torch.no_grad():
148 | intr_reward = self.compute_intr_reward(obs, action, next_obs,
149 | step)
150 |
151 | if self.use_tb or self.use_wandb:
152 | metrics['intr_reward'] = intr_reward.mean().item()
153 | reward = intr_reward
154 | else:
155 | reward = extr_reward
156 |
157 | if self.use_tb or self.use_wandb:
158 | metrics['extr_reward'] = extr_reward.mean().item()
159 | metrics['batch_reward'] = reward.mean().item()
160 |
161 | if not self.update_encoder:
162 | obs = obs.detach()
163 | next_obs = next_obs.detach()
164 |
165 | # update critic
166 | metrics.update(
167 | self.update_critic(obs.detach(), action, reward, discount,
168 | next_obs.detach(), step))
169 |
170 | # update actor
171 | metrics.update(self.update_actor(obs.detach(), step))
172 |
173 | # update critic target
174 | utils.soft_update_params(self.critic, self.critic_target,
175 | self.critic_target_tau)
176 |
177 | return metrics
178 |
--------------------------------------------------------------------------------
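
Note: the LBS reward above is the predicted KL divergence between the posterior and prior latent distributions. A self-contained toy of that KL term with torch.distributions, matching the unit-variance Gaussians built inside the LBS module; the latent means are random stand-ins for the two forward networks.

import torch
import torch.distributions as D

torch.manual_seed(0)
b, hidden_dim = 4, 8

pri_z = torch.randn(b, hidden_dim)  # stand-in for pri_forward_net(obs, action)
pos_z = torch.randn(b, hidden_dim)  # stand-in for pos_forward_net(obs, action, next_obs)

prior = D.Independent(D.Normal(pri_z, 1.0), 1)
posterior = D.Independent(D.Normal(pos_z, 1.0), 1)

kl_div = D.kl_divergence(posterior, prior)  # (b,), the quantity LBS_PRED is trained to predict
print(kl_div)
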
/DMC_state/agent/peac.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import GRUCell
7 |
8 | import utils
9 | from agent.ddpg import DDPGAgent
10 |
11 |
12 | class ContextModel(nn.Module):
13 | def __init__(self, obs_dim, act_dim, context_dim, hidden_dim, device='cuda'):
14 | super().__init__()
15 |
16 | self.device = device
17 | self.hidden_dim = hidden_dim
18 | self.model = GRUCell(obs_dim+act_dim, hidden_dim)
19 | self.context_head = nn.Sequential(nn.ReLU(),
20 | nn.Linear(hidden_dim, context_dim))
21 |
22 | self.apply(utils.weight_init)
23 |
24 | def forward(self, obs, acts, hidden=None):
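   |         # Roll a GRU over the (obs, action) history; the final hidden state is
   |         # mapped to logits over the possible contexts (embodiments).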
25 | if hidden is None:
26 | hidden = torch.zeros((obs.shape[0], self.hidden_dim), device=self.device)
27 | for i in range(obs.shape[1]):
28 | hidden = self.model(torch.cat([obs[:, i], acts[:, i]], dim=-1),
29 | hidden)
30 | context_pred = self.context_head(hidden)
31 | return context_pred
32 |
33 | class PEACAgent(DDPGAgent):
34 | def __init__(self, update_encoder,
35 | context_dim, **kwargs):
36 | super().__init__(**kwargs)
37 | self.update_encoder = update_encoder
38 | self.context_dim = context_dim
39 | print('context dim:', self.context_dim)
40 |
41 | self.task_model = ContextModel(self.obs_dim, self.action_dim, self.context_dim,
42 | self.hidden_dim, device=self.device).to(self.device)
43 |
44 | # optimizers
45 | self.task_opt = torch.optim.Adam(self.task_model.parameters(), lr=self.lr)
46 |
47 | self.task_model.train()
48 |
49 | def update_task_model(self, obs, action, next_obs, embodiment_id, pre_obs, pre_acts):
50 | metrics = dict()
51 | task_pred = self.task_model(pre_obs, pre_acts)
52 | # print(task_pred.shape)
53 | # print(torch.sum(embodiment_id))
54 | loss = F.cross_entropy(task_pred, embodiment_id.reshape(-1))
55 |
56 | self.task_opt.zero_grad(set_to_none=True)
57 | if self.encoder_opt is not None:
58 | self.encoder_opt.zero_grad(set_to_none=True)
59 | loss.backward()
60 | self.task_opt.step()
61 | if self.encoder_opt is not None:
62 | self.encoder_opt.step()
63 |
64 | if self.use_tb or self.use_wandb:
65 | metrics['task_loss'] = loss.item()
66 |
67 | return metrics
68 |
69 | def compute_intr_reward(self, obs, action, next_obs, embodiment_id, pre_obs, pre_acts):
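   |         # Intrinsic reward is the negative log-likelihood of the true embodiment id
   |         # under the context model: transitions that make the embodiment hard to
   |         # identify receive higher reward.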
70 | B, _ = action.shape
71 | task_pred = self.task_model(pre_obs, pre_acts) # B, task_num
72 |         # log-probabilities of each embodiment under the context model
73 | task_pred = F.log_softmax(task_pred, dim=1)
74 | intr_rew = task_pred[torch.arange(B), embodiment_id.reshape(-1)] # B
75 | task_rew = -intr_rew.reshape(B, 1)
76 | return task_rew
77 |
78 | def update(self, replay_iter, step):
79 | metrics = dict()
80 |
81 | if step % self.update_every_steps != 0:
82 | return metrics
83 |
84 | batch = next(replay_iter)
85 | obs, action, extr_reward, discount, next_obs, embodiment_id, his_o, his_a = \
86 | utils.to_torch(batch, self.device)
87 |
88 | # augment and encode
89 | obs = self.aug_and_encode(obs)
90 | with torch.no_grad():
91 | next_obs = self.aug_and_encode(next_obs)
92 |
93 | if self.reward_free:
94 | with torch.no_grad():
95 | intr_reward = self.compute_intr_reward(obs, action, next_obs, embodiment_id,
96 | his_o, his_a)
97 |
98 | if self.use_tb or self.use_wandb:
99 | metrics['intr_reward'] = intr_reward.mean().item()
100 | reward = intr_reward
101 | else:
102 | reward = extr_reward
103 |
104 | if self.use_tb or self.use_wandb:
105 | metrics['extr_reward'] = extr_reward.mean().item()
106 | metrics['batch_reward'] = reward.mean().item()
107 |
108 | if not self.update_encoder:
109 | obs = obs.detach()
110 | next_obs = next_obs.detach()
111 |
112 | metrics.update(self.update_task_model(obs.detach(), action, next_obs, embodiment_id,
113 | his_o, his_a))
114 |
115 | # update critic
116 | metrics.update(
117 | self.update_critic(obs.detach(), action, reward, discount,
118 | next_obs.detach(), step))
119 |
120 | # update actor
121 | metrics.update(self.update_actor(obs.detach(), step))
122 |
123 | # update critic target
124 | utils.soft_update_params(self.critic, self.critic_target,
125 | self.critic_target_tau)
126 |
127 | return metrics
128 |
--------------------------------------------------------------------------------
/DMC_state/agent/proto.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import distributions as pyd
8 | from torch import jit
9 |
10 | import utils
11 | from agent.ddpg import DDPGAgent
12 |
13 |
14 | # @jit.script
15 | def sinkhorn_knopp(Q):
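   |     # Sinkhorn-Knopp normalisation: alternately rescale rows and columns so the
   |     # soft assignment matrix has (approximately) uniform marginals, giving
   |     # balanced targets over the prototypes.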
16 | Q -= Q.max()
17 | Q = torch.exp(Q).T
18 | Q /= Q.sum()
19 |
20 | r = torch.ones(Q.shape[0], device=Q.device) / Q.shape[0]
21 | c = torch.ones(Q.shape[1], device=Q.device) / Q.shape[1]
22 | for it in range(3):
23 | u = Q.sum(dim=1)
24 | u = r / u
25 | Q *= u.unsqueeze(dim=1)
26 | Q *= (c / Q.sum(dim=0)).unsqueeze(dim=0)
27 | Q = Q / Q.sum(dim=0, keepdim=True)
28 | return Q.T
29 |
30 |
31 | class Projector(nn.Module):
32 | def __init__(self, pred_dim, proj_dim):
33 | super().__init__()
34 |
35 | self.trunk = nn.Sequential(nn.Linear(pred_dim, proj_dim), nn.ReLU(),
36 | nn.Linear(proj_dim, pred_dim))
37 |
38 | self.apply(utils.weight_init)
39 |
40 | def forward(self, x):
41 | return self.trunk(x)
42 |
43 |
44 | class ProtoAgent(DDPGAgent):
45 | def __init__(self, pred_dim, proj_dim, queue_size, num_protos, tau,
46 | encoder_target_tau, topk, update_encoder, **kwargs):
47 | super().__init__(**kwargs)
48 | self.tau = tau
49 | self.encoder_target_tau = encoder_target_tau
50 | self.topk = topk
51 | self.num_protos = num_protos
52 | self.update_encoder = update_encoder
53 |
54 | # models
55 | self.encoder_target = deepcopy(self.encoder)
56 |
57 | self.predictor = nn.Linear(self.obs_dim, pred_dim).to(self.device)
58 | self.predictor.apply(utils.weight_init)
59 | self.predictor_target = deepcopy(self.predictor)
60 |
61 | self.projector = Projector(pred_dim, proj_dim).to(self.device)
62 | self.projector.apply(utils.weight_init)
63 |
64 | # prototypes
65 | self.protos = nn.Linear(pred_dim, num_protos,
66 | bias=False).to(self.device)
67 | self.protos.apply(utils.weight_init)
68 |
69 | # candidate queue
70 | self.queue = torch.zeros(queue_size, pred_dim, device=self.device)
71 | self.queue_ptr = 0
72 |
73 | # optimizers
74 | self.proto_opt = torch.optim.Adam(utils.chain(
75 | self.encoder.parameters(), self.predictor.parameters(),
76 | self.projector.parameters(), self.protos.parameters()),
77 | lr=self.lr)
78 |
79 | self.predictor.train()
80 | self.projector.train()
81 | self.protos.train()
82 |
83 | def init_from(self, other):
84 | # copy parameters over
85 | utils.hard_update_params(other.encoder, self.encoder)
86 | utils.hard_update_params(other.actor, self.actor)
87 | utils.hard_update_params(other.predictor, self.predictor)
88 | utils.hard_update_params(other.projector, self.projector)
89 | utils.hard_update_params(other.protos, self.protos)
90 | if self.init_critic:
91 | utils.hard_update_params(other.critic, self.critic)
92 |
93 | def normalize_protos(self):
94 | C = self.protos.weight.data.clone()
95 | C = F.normalize(C, dim=1, p=2)
96 | self.protos.weight.data.copy_(C)
97 |
98 | def compute_intr_reward(self, obs, step):
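   |         # Intrinsic reward is the distance to the topk-th nearest candidate
   |         # embedding in the queue, rewarding states far from previously seen ones.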
99 | self.normalize_protos()
100 | # find a candidate for each prototype
101 | with torch.no_grad():
102 | z = self.encoder(obs)
103 | z = self.predictor(z)
104 | z = F.normalize(z, dim=1, p=2)
105 | scores = self.protos(z).T
106 | prob = F.softmax(scores, dim=1)
107 | candidates = pyd.Categorical(prob).sample()
108 |
109 | # enqueue candidates
110 | ptr = self.queue_ptr
111 | self.queue[ptr:ptr + self.num_protos] = z[candidates]
112 | self.queue_ptr = (ptr + self.num_protos) % self.queue.shape[0]
113 |
114 | # compute distances between the batch and the queue of candidates
115 | z_to_q = torch.norm(z[:, None, :] - self.queue[None, :, :], dim=2, p=2)
116 | all_dists, _ = torch.topk(z_to_q, self.topk, dim=1, largest=False)
117 | dist = all_dists[:, -1:]
118 | reward = dist
119 | return reward
120 |
121 | def update_proto(self, obs, next_obs, step):
122 | metrics = dict()
123 |
124 | # normalize prototypes
125 | self.normalize_protos()
126 |
127 | # online network
128 | s = self.encoder(obs)
129 | s = self.predictor(s)
130 | s = self.projector(s)
131 | s = F.normalize(s, dim=1, p=2)
132 | scores_s = self.protos(s)
133 | log_p_s = F.log_softmax(scores_s / self.tau, dim=1)
134 |
135 | # target network
136 | with torch.no_grad():
137 | t = self.encoder_target(next_obs)
138 | t = self.predictor_target(t)
139 | t = F.normalize(t, dim=1, p=2)
140 | scores_t = self.protos(t)
141 | q_t = sinkhorn_knopp(scores_t / self.tau)
142 |
143 | # loss
144 | loss = -(q_t * log_p_s).sum(dim=1).mean()
145 | if self.use_tb or self.use_wandb:
146 | metrics['repr_loss'] = loss.item()
147 | self.proto_opt.zero_grad(set_to_none=True)
148 | loss.backward()
149 | self.proto_opt.step()
150 |
151 | return metrics
152 |
153 | def update(self, replay_iter, step):
154 | metrics = dict()
155 |
156 | if step % self.update_every_steps != 0:
157 | return metrics
158 |
159 | batch = next(replay_iter)
160 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch(
161 | batch, self.device)
162 |
163 | # augment and encode
164 | with torch.no_grad():
165 | obs = self.aug(obs)
166 | next_obs = self.aug(next_obs)
167 |
168 | if self.reward_free:
169 | metrics.update(self.update_proto(obs, next_obs, step))
170 |
171 | with torch.no_grad():
172 | intr_reward = self.compute_intr_reward(next_obs, step)
173 |
174 | if self.use_tb or self.use_wandb:
175 | metrics['intr_reward'] = intr_reward.mean().item()
176 | reward = intr_reward
177 | else:
178 | reward = extr_reward
179 |
180 | if self.use_tb or self.use_wandb:
181 | metrics['extr_reward'] = extr_reward.mean().item()
182 | metrics['batch_reward'] = reward.mean().item()
183 |
184 | obs = self.encoder(obs)
185 | next_obs = self.encoder(next_obs)
186 |
187 | if not self.update_encoder:
188 | obs = obs.detach()
189 | next_obs = next_obs.detach()
190 |
191 | # update critic
192 | metrics.update(
193 | self.update_critic(obs.detach(), action, reward, discount,
194 | next_obs.detach(), step))
195 |
196 | # update actor
197 | metrics.update(self.update_actor(obs.detach(), step))
198 |
199 | # update critic target
200 | utils.soft_update_params(self.encoder, self.encoder_target,
201 | self.encoder_target_tau)
202 | utils.soft_update_params(self.predictor, self.predictor_target,
203 | self.encoder_target_tau)
204 | utils.soft_update_params(self.critic, self.critic_target,
205 | self.critic_target_tau)
206 |
207 | return metrics
208 |
--------------------------------------------------------------------------------
/DMC_state/agent/rnd.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import utils
10 | from agent.ddpg import DDPGAgent
11 |
12 |
13 | class RND(nn.Module):
14 | def __init__(self,
15 | obs_dim,
16 | hidden_dim,
17 | rnd_rep_dim,
18 | encoder,
19 | aug,
20 | obs_shape,
21 | obs_type,
22 | clip_val=5.):
23 | super().__init__()
24 | self.clip_val = clip_val
25 | self.aug = aug
26 |
27 | if obs_type == "pixels":
28 | self.normalize_obs = nn.BatchNorm2d(obs_shape[0], affine=False)
29 | else:
30 | self.normalize_obs = nn.BatchNorm1d(obs_shape[0], affine=False)
31 |
32 | self.predictor = nn.Sequential(encoder, nn.Linear(obs_dim, hidden_dim),
33 | nn.ReLU(),
34 | nn.Linear(hidden_dim, hidden_dim),
35 | nn.ReLU(),
36 | nn.Linear(hidden_dim, rnd_rep_dim))
37 | self.target = nn.Sequential(copy.deepcopy(encoder),
38 | nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
39 | nn.Linear(hidden_dim, hidden_dim),
40 | nn.ReLU(),
41 | nn.Linear(hidden_dim, rnd_rep_dim))
42 |
43 | for param in self.target.parameters():
44 | param.requires_grad = False
45 |
46 | self.apply(utils.weight_init)
47 |
48 | def forward(self, obs):
49 | obs = self.aug(obs)
50 | obs = self.normalize_obs(obs)
51 | obs = torch.clamp(obs, -self.clip_val, self.clip_val)
52 | prediction, target = self.predictor(obs), self.target(obs)
53 | prediction_error = torch.square(target.detach() - prediction).mean(
54 | dim=-1, keepdim=True)
55 | return prediction_error
56 |
57 |
58 | class RNDAgent(DDPGAgent):
59 | def __init__(self, rnd_rep_dim, update_encoder, rnd_scale=1., **kwargs):
60 | super().__init__(**kwargs)
61 | self.rnd_scale = rnd_scale
62 | self.update_encoder = update_encoder
63 |
64 | self.rnd = RND(self.obs_dim, self.hidden_dim, rnd_rep_dim,
65 | self.encoder, self.aug, self.obs_shape,
66 | self.obs_type).to(self.device)
67 | self.intrinsic_reward_rms = utils.RMS(device=self.device)
68 |
69 | # optimizers
70 | self.rnd_opt = torch.optim.Adam(self.rnd.parameters(), lr=self.lr)
71 |
72 | self.rnd.train()
73 |
74 | def update_rnd(self, obs, step):
75 | metrics = dict()
76 |
77 | prediction_error = self.rnd(obs)
78 |
79 | loss = prediction_error.mean()
80 |
81 | self.rnd_opt.zero_grad(set_to_none=True)
82 | if self.encoder_opt is not None:
83 | self.encoder_opt.zero_grad(set_to_none=True)
84 | loss.backward()
85 | self.rnd_opt.step()
86 | if self.encoder_opt is not None:
87 | self.encoder_opt.step()
88 |
89 | if self.use_tb or self.use_wandb:
90 | metrics['rnd_loss'] = loss.item()
91 |
92 | return metrics
93 |
94 | def compute_intr_reward(self, obs, step):
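   |         # Intrinsic reward is the predictor-target error, normalised by a running
   |         # estimate (RMS) of its standard deviation and scaled by rnd_scale.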
95 | prediction_error = self.rnd(obs)
96 | _, intr_reward_var = self.intrinsic_reward_rms(prediction_error)
97 | reward = self.rnd_scale * prediction_error / (
98 | torch.sqrt(intr_reward_var) + 1e-8)
99 | return reward
100 |
101 | def update(self, replay_iter, step):
102 | metrics = dict()
103 |
104 | if step % self.update_every_steps != 0:
105 | return metrics
106 |
107 | batch = next(replay_iter)
108 | obs, action, extr_reward, discount, next_obs, task_id = utils.to_torch(
109 | batch, self.device)
110 |
111 | # update RND first
112 | if self.reward_free:
113 | # note: one difference is that the RND module is updated off policy
114 | metrics.update(self.update_rnd(obs, step))
115 |
116 | with torch.no_grad():
117 | intr_reward = self.compute_intr_reward(obs, step)
118 |
119 | if self.use_tb or self.use_wandb:
120 | metrics['intr_reward'] = intr_reward.mean().item()
121 | reward = intr_reward
122 | else:
123 | reward = extr_reward
124 |
125 | # augment and encode
126 | obs = self.aug_and_encode(obs)
127 | with torch.no_grad():
128 | next_obs = self.aug_and_encode(next_obs)
129 |
130 | if self.use_tb or self.use_wandb:
131 | metrics['extr_reward'] = extr_reward.mean().item()
132 | metrics['batch_reward'] = reward.mean().item()
133 |
134 | metrics['pred_error_mean'] = self.intrinsic_reward_rms.M
135 | metrics['pred_error_std'] = torch.sqrt(self.intrinsic_reward_rms.S)
136 |
137 | if not self.update_encoder:
138 | obs = obs.detach()
139 | next_obs = next_obs.detach()
140 |
141 | # update critic
142 | metrics.update(
143 | self.update_critic(obs.detach(), action, reward, discount,
144 | next_obs.detach(), step))
145 |
146 | # update actor
147 | metrics.update(self.update_actor(obs.detach(), step))
148 |
149 | # update critic target
150 | utils.soft_update_params(self.critic, self.critic_target,
151 | self.critic_target_tau)
152 |
153 | return metrics
154 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/aps.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.aps.APSAgent
3 | name: aps
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | sf_dim: 10
20 | update_task_every_step: 5
21 | nstep: 3
22 | batch_size: 1024
23 | init_critic: true
24 | knn_rms: true
25 | knn_k: 12
26 | knn_avg: true
27 | knn_clip: 0.0001
28 | num_init_steps: 4096 # set to ${num_train_frames} to disable finetune policy parameters
29 | lstsq_batch_size: 4096
30 | update_encoder: ${update_encoder}
31 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/becl.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.becl.BECLAgent
3 | name: becl
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | skill_dim: 16
20 | update_skill_every_step: 50
21 | nstep: 3
22 | batch_size: 1024
23 | init_critic: true
24 | update_encoder: ${update_encoder}
25 |
26 | # extra hyperparameter
27 | contrastive_update_rate: 3
28 | temperature: 0.5
29 |
30 | # skill finetuning ablation
31 | skill: -1
--------------------------------------------------------------------------------
/DMC_state/configs/agent/cic.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.cic.CICAgent
3 | name: cic
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: 2000
15 | hidden_dim: 1024
16 | feature_dim: 1024
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | skill_dim: 64
20 | scale: 1.0
21 | update_skill_every_step: 50
22 | nstep: 3
23 | batch_size: 1024
24 | project_skill: true
25 | init_critic: true
26 | rew_type: og
27 | update_rep: true
28 | temp: 0.5
--------------------------------------------------------------------------------
/DMC_state/configs/agent/ddpg.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.ddpg.DDPGAgent
3 | name: ddpg
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024 # 256 for pixels
21 | init_critic: true
22 | update_encoder: ${update_encoder}
23 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/diayn.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.diayn.DIAYNAgent
3 | name: diayn
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | skill_dim: 16
20 | diayn_scale: 1.0
21 | update_skill_every_step: 50
22 | nstep: 3
23 | batch_size: 1024
24 | init_critic: true
25 | update_encoder: ${update_encoder}
26 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/disagreement.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.disagreement.DisagreementAgent
3 | name: disagreement
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024
21 | init_critic: true
22 | update_encoder: ${update_encoder}
23 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/icm.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.icm.ICMAgent
3 | name: icm
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | icm_scale: 1.0
20 | nstep: 3
21 | batch_size: 1024
22 | init_critic: true
23 | update_encoder: ${update_encoder}
24 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/lbs.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.lbs.LBSAgent
3 | name: lbs
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024
21 | init_critic: true
22 | update_encoder: ${update_encoder}
23 |
24 | lbs_scale: 1.0
--------------------------------------------------------------------------------
/DMC_state/configs/agent/peac.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.peac.PEACAgent
3 | name: peac
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024
21 | init_critic: true
22 | update_encoder: ${update_encoder}
23 |
24 | his_o_a: 10  # length of the (obs, action) history fed to the context model
25 | context_dim: 1  # placeholder; overwritten with the number of embodiments in pretrain.py
--------------------------------------------------------------------------------
/DMC_state/configs/agent/proto.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.proto.ProtoAgent
3 | name: proto
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024
21 | init_critic: true
22 | pred_dim: 128
23 | proj_dim: 512
24 | num_protos: 512
25 | tau: 0.1
26 | topk: 3
27 | queue_size: 2048
28 | encoder_target_tau: 0.05
29 | update_encoder: ${update_encoder}
30 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/rnd.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.rnd.RNDAgent
3 | name: rnd
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | rnd_rep_dim: 512
18 | stddev_schedule: 0.2
19 | stddev_clip: 0.3
20 | rnd_scale: 1.0
21 | nstep: 3
22 | batch_size: 1024
23 | init_critic: true
24 | update_encoder: ${update_encoder}
25 |
--------------------------------------------------------------------------------
/DMC_state/configs/agent/smm.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.smm.SMMAgent
3 | name: smm
4 |
5 | # z params
6 | z_dim: 4 # default in codebase is 4
7 |
8 | # z discriminator params
9 | sp_lr: 1e-3
10 |
11 | # vae params
12 | vae_lr: 1e-2
13 | vae_beta: 0.5
14 |
15 | # reward params
16 | state_ent_coef: 1.0
17 | latent_ent_coef: 1.0
18 | latent_cond_ent_coef: 1.0
19 |
20 | # DDPG params
21 | reward_free: ${reward_free}
22 | obs_type: ??? # to be specified later
23 | obs_shape: ??? # to be specified later
24 | action_shape: ??? # to be specified later
25 | device: ${device}
26 | lr: 1e-4
27 | critic_target_tau: 0.01
28 | update_every_steps: 2
29 | use_tb: ${use_tb}
30 | use_wandb: ${use_wandb}
31 | num_expl_steps: ??? # to be specified later
32 | hidden_dim: 1024
33 | feature_dim: 50
34 | stddev_schedule: 0.2
35 | stddev_clip: 0.3
36 | nstep: 3
37 | batch_size: 1024
38 | init_critic: true
39 | update_encoder: ${update_encoder}
--------------------------------------------------------------------------------
/DMC_state/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from custom_dmc_tasks import walker
2 | from custom_dmc_tasks import quadruped
3 | from custom_dmc_tasks import jaco
4 |
5 |
6 | def make(domain, task,
7 | task_kwargs=None,
8 | environment_kwargs=None,
9 | visualize_reward=False):
10 |
11 | if domain == 'walker':
12 | return walker.make(task,
13 | task_kwargs=task_kwargs,
14 | environment_kwargs=environment_kwargs,
15 | visualize_reward=visualize_reward)
16 | elif domain == 'quadruped':
17 | return quadruped.make(task,
18 | task_kwargs=task_kwargs,
19 | environment_kwargs=environment_kwargs,
20 | visualize_reward=visualize_reward)
21 | else:
22 |         raise ValueError(f'{domain} not found')
23 | 
26 |
27 | def make_jaco(task, obs_type, seed):
28 | return jaco.make(task, obs_type, seed)
--------------------------------------------------------------------------------
/DMC_state/custom_dmc_tasks/jaco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """A task where the goal is to move the hand close to a target prop or site."""
17 |
18 | import collections
19 |
20 | from dm_control import composer
21 | from dm_control.composer import initializers
22 | from dm_control.composer.observation import observable
23 | from dm_control.composer.variation import distributions
24 | from dm_control.entities import props
25 | from dm_control.manipulation.shared import arenas
26 | from dm_control.manipulation.shared import cameras
27 | from dm_control.manipulation.shared import constants
28 | from dm_control.manipulation.shared import observations
29 | from dm_control.manipulation.shared import registry
30 | from dm_control.manipulation.shared import robots
31 | from dm_control.manipulation.shared import tags
32 | from dm_control.manipulation.shared import workspaces
33 | from dm_control.utils import rewards
34 | import numpy as np
35 |
36 | _ReachWorkspace = collections.namedtuple(
37 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
38 |
39 | # Ensures that the props are not touching the table before settling.
40 | _PROP_Z_OFFSET = 0.001
41 |
42 | _DUPLO_WORKSPACE = _ReachWorkspace(
43 | target_bbox=workspaces.BoundingBox(
44 | lower=(-0.1, -0.1, _PROP_Z_OFFSET),
45 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
46 | tcp_bbox=workspaces.BoundingBox(
47 | lower=(-0.1, -0.1, 0.2),
48 | upper=(0.1, 0.1, 0.4)),
49 | arm_offset=robots.ARM_OFFSET)
50 |
51 | _SITE_WORKSPACE = _ReachWorkspace(
52 | target_bbox=workspaces.BoundingBox(
53 | lower=(-0.2, -0.2, 0.02),
54 | upper=(0.2, 0.2, 0.4)),
55 | tcp_bbox=workspaces.BoundingBox(
56 | lower=(-0.2, -0.2, 0.02),
57 | upper=(0.2, 0.2, 0.4)),
58 | arm_offset=robots.ARM_OFFSET)
59 |
60 | _TARGET_RADIUS = 0.05
61 | _TIME_LIMIT = 10.
62 |
63 | TASKS = {
64 | 'reach_top_left': workspaces.BoundingBox(
65 | lower=(-0.09, 0.09, _PROP_Z_OFFSET),
66 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)),
67 | 'reach_top_right': workspaces.BoundingBox(
68 | lower=(0.09, 0.09, _PROP_Z_OFFSET),
69 | upper=(0.09, 0.09, _PROP_Z_OFFSET)),
70 | 'reach_bottom_left': workspaces.BoundingBox(
71 | lower=(-0.09, -0.09, _PROP_Z_OFFSET),
72 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)),
73 | 'reach_bottom_right': workspaces.BoundingBox(
74 | lower=(0.09, -0.09, _PROP_Z_OFFSET),
75 | upper=(0.09, -0.09, _PROP_Z_OFFSET)),
76 | }
77 |
78 |
79 | def make(task_id, obs_type, seed):
80 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
81 | if obs_type == 'states':
82 | global _TIME_LIMIT
83 | _TIME_LIMIT = 10.04
84 | # Note: Adding this fixes the problem of having 249 steps with action repeat = 1
85 | task = _reach(task_id, obs_settings=obs_settings, use_site=False)
86 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed)
87 |
88 |
89 | class MTReach(composer.Task):
90 | """Bring the hand close to a target prop or site."""
91 |
92 | def __init__(
93 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
94 | """Initializes a new `Reach` task.
95 |
96 | Args:
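   |       task_id: String key into `TASKS` selecting the target spawn region.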
97 | arena: `composer.Entity` instance.
98 | arm: `robot_base.RobotArm` instance.
99 | hand: `robot_base.RobotHand` instance.
100 | prop: `composer.Entity` instance specifying the prop to reach to, or None
101 | in which case the target is a fixed site whose position is specified by
102 | the workspace.
103 | obs_settings: `observations.ObservationSettings` instance.
104 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
105 | control_timestep: Float specifying the control timestep in seconds.
106 | """
107 | self._arena = arena
108 | self._arm = arm
109 | self._hand = hand
110 | self._arm.attach(self._hand)
111 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
112 | self.control_timestep = control_timestep
113 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
114 | self._hand, self._arm,
115 | position=distributions.Uniform(*workspace.tcp_bbox),
116 | quaternion=workspaces.DOWN_QUATERNION)
117 |
118 | # Add custom camera observable.
119 | self._task_observables = cameras.add_camera_observables(
120 | arena, obs_settings, cameras.FRONT_CLOSE)
121 |
122 | target_pos_distribution = distributions.Uniform(*TASKS[task_id])
123 | self._prop = prop
124 | if prop:
125 | # The prop itself is used to visualize the target location.
126 | self._make_target_site(parent_entity=prop, visible=False)
127 | self._target = self._arena.add_free_entity(prop)
128 | self._prop_placer = initializers.PropPlacer(
129 | props=[prop],
130 | position=target_pos_distribution,
131 | quaternion=workspaces.uniform_z_rotation,
132 | settle_physics=True)
133 | else:
134 | self._target = self._make_target_site(parent_entity=arena, visible=True)
135 | self._target_placer = target_pos_distribution
136 |
137 | obs = observable.MJCFFeature('pos', self._target)
138 | obs.configure(**obs_settings.prop_pose._asdict())
139 | self._task_observables['target_position'] = obs
140 |
141 | # Add sites for visualizing the prop and target bounding boxes.
142 | workspaces.add_bbox_site(
143 | body=self.root_entity.mjcf_model.worldbody,
144 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
145 | rgba=constants.GREEN, name='tcp_spawn_area')
146 | workspaces.add_bbox_site(
147 | body=self.root_entity.mjcf_model.worldbody,
148 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
149 | rgba=constants.BLUE, name='target_spawn_area')
150 |
151 | def _make_target_site(self, parent_entity, visible):
152 | return workspaces.add_target_site(
153 | body=parent_entity.mjcf_model.worldbody,
154 | radius=_TARGET_RADIUS, visible=visible,
155 | rgba=constants.RED, name='target_site')
156 |
157 | @property
158 | def root_entity(self):
159 | return self._arena
160 |
161 | @property
162 | def arm(self):
163 | return self._arm
164 |
165 | @property
166 | def hand(self):
167 | return self._hand
168 |
169 | @property
170 | def task_observables(self):
171 | return self._task_observables
172 |
173 | def get_reward(self, physics):
174 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
175 | target_pos = physics.bind(self._target).xpos
176 | distance = np.linalg.norm(hand_pos - target_pos)
177 | return rewards.tolerance(
178 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)
179 |
180 | def initialize_episode(self, physics, random_state):
181 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
182 | self._tcp_initializer(physics, random_state)
183 | if self._prop:
184 | self._prop_placer(physics, random_state)
185 | else:
186 | physics.bind(self._target).pos = (
187 | self._target_placer(random_state=random_state))
188 |
189 |
190 | def _reach(task_id, obs_settings, use_site):
191 | """Configure and instantiate a `Reach` task.
192 |
193 | Args:
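   |       task_id: String key into `TASKS` selecting the target spawn region.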
194 | obs_settings: An `observations.ObservationSettings` instance.
195 | use_site: Boolean, if True then the target will be a fixed site, otherwise
196 | it will be a moveable Duplo brick.
197 |
198 | Returns:
199 | An instance of `reach.Reach`.
200 | """
201 | arena = arenas.Standard()
202 | arm = robots.make_arm(obs_settings=obs_settings)
203 | hand = robots.make_hand(obs_settings=obs_settings)
204 | if use_site:
205 | workspace = _SITE_WORKSPACE
206 | prop = None
207 | else:
208 | workspace = _DUPLO_WORKSPACE
209 | prop = props.Duplo(observable_options=observations.make_options(
210 | obs_settings, observations.FREEPROP_OBSERVABLES))
211 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop,
212 | obs_settings=obs_settings,
213 | workspace=workspace,
214 | control_timestep=constants.CONTROL_TIMESTEP)
215 | return task
216 |
--------------------------------------------------------------------------------
/DMC_state/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Planar Walker Domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | from dm_control import suite
33 |
34 | _DEFAULT_TIME_LIMIT = 25
35 | _CONTROL_TIMESTEP = .025
36 |
37 | # Minimal height of torso over foot above which stand reward is 1.
38 | _STAND_HEIGHT = 1.2
39 |
40 | # Horizontal speeds (meters/second) above which move reward is 1.
41 | _WALK_SPEED = 1
42 | _RUN_SPEED = 8
43 | _SPIN_SPEED = 5
44 |
45 | SUITE = containers.TaggedTasks()
46 |
47 | def make(task,
48 | task_kwargs=None,
49 | environment_kwargs=None,
50 | visualize_reward=False):
51 | task_kwargs = task_kwargs or {}
52 | if environment_kwargs is not None:
53 | task_kwargs = task_kwargs.copy()
54 | task_kwargs['environment_kwargs'] = environment_kwargs
55 | env = SUITE[task](**task_kwargs)
56 | env.task.visualize_reward = visualize_reward
57 | return env
58 |
59 | def get_model_and_assets():
60 | """Returns a tuple containing the model XML string and a dict of assets."""
61 | root_dir = os.path.dirname(os.path.dirname(__file__))
62 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks',
63 | 'walker.xml'))
64 | return xml, common.ASSETS
65 |
66 |
67 |
68 |
69 |
70 |
71 | @SUITE.add('benchmarking')
72 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
73 | random=None,
74 | environment_kwargs=None):
75 |     """Returns the Flip task."""
76 | physics = Physics.from_xml_string(*get_model_and_assets())
77 | task = PlanarWalker(move_speed=_RUN_SPEED,
78 | forward=True,
79 | flip=True,
80 | random=random)
81 | environment_kwargs = environment_kwargs or {}
82 | return control.Environment(physics,
83 | task,
84 | time_limit=time_limit,
85 | control_timestep=_CONTROL_TIMESTEP,
86 | **environment_kwargs)
87 |
88 |
89 | class Physics(mujoco.Physics):
90 | """Physics simulation with additional features for the Walker domain."""
91 | def torso_upright(self):
92 | """Returns projection from z-axes of torso to the z-axes of world."""
93 | return self.named.data.xmat['torso', 'zz']
94 |
95 | def torso_height(self):
96 | """Returns the height of the torso."""
97 | return self.named.data.xpos['torso', 'z']
98 |
99 | def horizontal_velocity(self):
100 | """Returns the horizontal velocity of the center-of-mass."""
101 | return self.named.data.sensordata['torso_subtreelinvel'][0]
102 |
103 | def orientations(self):
104 | """Returns planar orientations of all bodies."""
105 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
106 |
107 | def angmomentum(self):
108 |         """Returns the angular momentum of the torso about the y-axis."""
109 | return self.named.data.subtree_angmom['torso'][1]
110 |
111 |
112 | class PlanarWalker(base.Task):
113 | """A planar walker task."""
114 | def __init__(self, move_speed, forward=True, flip=False, random=None):
115 | """Initializes an instance of `PlanarWalker`.
116 |
117 | Args:
118 | move_speed: A float. If this value is zero, reward is given simply for
119 | standing up. Otherwise this specifies a target horizontal velocity for
120 | the walking task.
121 | random: Optional, either a `numpy.random.RandomState` instance, an
122 | integer seed for creating a new `RandomState`, or None to select a seed
123 | automatically (default).
124 | """
125 | self._move_speed = move_speed
126 | self._forward = 1 if forward else -1
127 | self._flip = flip
128 | super(PlanarWalker, self).__init__(random=random)
129 |
130 | def initialize_episode(self, physics):
131 | """Sets the state of the environment at the start of each episode.
132 |
133 | In 'standing' mode, use initial orientation and small velocities.
134 | In 'random' mode, randomize joint angles and let fall to the floor.
135 |
136 | Args:
137 | physics: An instance of `Physics`.
138 |
139 | """
140 | randomizers.randomize_limited_and_rotational_joints(
141 | physics, self.random)
142 | super(PlanarWalker, self).initialize_episode(physics)
143 |
144 | def get_observation(self, physics):
145 |         """Returns an observation of body orientations, height and velocities."""
146 | obs = collections.OrderedDict()
147 | obs['orientations'] = physics.orientations()
148 | obs['height'] = physics.torso_height()
149 | obs['velocity'] = physics.velocity()
150 | return obs
151 |
152 | def get_reward(self, physics):
153 | """Returns a reward to the agent."""
154 | standing = rewards.tolerance(physics.torso_height(),
155 | bounds=(_STAND_HEIGHT, float('inf')),
156 | margin=_STAND_HEIGHT / 2)
157 | upright = (1 + physics.torso_upright()) / 2
158 | stand_reward = (3 * standing + upright) / 4
159 |
160 | if self._flip:
161 | move_reward = rewards.tolerance(self._forward *
162 | physics.angmomentum(),
163 | bounds=(_SPIN_SPEED, float('inf')),
164 | margin=_SPIN_SPEED,
165 | value_at_margin=0,
166 | sigmoid='linear')
167 | else:
168 | move_reward = rewards.tolerance(
169 | self._forward * physics.horizontal_velocity(),
170 | bounds=(self._move_speed, float('inf')),
171 | margin=self._move_speed / 2,
172 | value_at_margin=0.5,
173 | sigmoid='linear')
174 |
175 | return stand_reward * (5 * move_reward + 1) / 6
176 |
--------------------------------------------------------------------------------
/DMC_state/custom_dmc_tasks/walker.xml:
--------------------------------------------------------------------------------
1 | <!-- walker.xml: MuJoCo model for the custom Walker domain (markup not preserved in this dump) -->
--------------------------------------------------------------------------------
/DMC_state/dmc_benchmark.py:
--------------------------------------------------------------------------------
1 | DOMAINS = [
2 | 'walker',
3 | 'quadruped',
4 | 'jaco',
5 | ]
6 |
7 | WALKER_TASKS = [
8 | 'walker_stand',
9 | 'walker_walk',
10 | 'walker_run',
11 | 'walker_flip',
12 | ]
13 |
14 | QUADRUPED_TASKS = [
15 | 'quadruped_walk',
16 | 'quadruped_run',
17 | 'quadruped_stand',
18 | 'quadruped_jump',
19 | ]
20 |
21 | JACO_TASKS = [
22 | 'jaco_reach_top_left',
23 | 'jaco_reach_top_right',
24 | 'jaco_reach_bottom_left',
25 | 'jaco_reach_bottom_right',
26 | ]
27 |
28 | TASKS = WALKER_TASKS + QUADRUPED_TASKS + JACO_TASKS
29 |
30 | parameter_1 = ['0.2', '0.6', '1.0', '1.4', '1.8']
31 | parameter_1_eval = ['0.4', '0.8', '1.2', '1.6']
32 |
33 | parameter_2 = ['0.4', '0.8', '1.0', '1.4']
34 | parameter_2_eval = ['0.6', '1.2']
35 |
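   | # Task names of the form '<task>~<parameter>~<value>' (e.g. 'walker_stand~mass~0.6')
   | # denote embodiment variants; the suffix is presumably parsed downstream (dmc.make)
   | # to scale the named physical parameter.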
36 | PRETRAIN_TASKS = {
37 | 'walker': 'walker_stand',
38 | 'jaco': 'jaco_reach_top_left',
39 | 'quadruped': 'quadruped_walk',
40 | 'walker_mass': ['walker_stand~mass~' + para for para in parameter_1],
41 | 'quadruped_mass': ['quadruped_stand~mass~' + para for para in parameter_2],
42 | 'quadruped_damping': ['quadruped_stand~damping~' + para for para in parameter_1],
43 | }
44 |
45 | FINETUNE_TASKS = {
46 | 'walker_stand_mass': ['walker_stand~mass~' + para for para in parameter_1],
47 | 'walker_stand_mass_eval': ['walker_stand~mass~' + para for para in parameter_1_eval],
48 | 'walker_walk_mass': ['walker_walk~mass~' + para for para in parameter_1],
49 | 'walker_walk_mass_eval': ['walker_walk~mass~' + para for para in parameter_1_eval],
50 | 'walker_run_mass': ['walker_run~mass~' + para for para in parameter_1],
51 | 'walker_run_mass_eval': ['walker_run~mass~' + para for para in parameter_1_eval],
52 | 'walker_flip_mass': ['walker_flip~mass~' + para for para in parameter_1],
53 | 'walker_flip_mass_eval': ['walker_flip~mass~' + para for para in parameter_1_eval],
54 |
55 | 'quadruped_stand_mass': ['quadruped_stand~mass~' + para for para in parameter_2],
56 | 'quadruped_stand_mass_eval': ['quadruped_stand~mass~' + para for para in parameter_2_eval],
57 | 'quadruped_walk_mass': ['quadruped_walk~mass~' + para for para in parameter_2],
58 | 'quadruped_walk_mass_eval': ['quadruped_walk~mass~' + para for para in parameter_2_eval],
59 | 'quadruped_run_mass': ['quadruped_run~mass~' + para for para in parameter_2],
60 | 'quadruped_run_mass_eval': ['quadruped_run~mass~' + para for para in parameter_2_eval],
61 | 'quadruped_jump_mass': ['quadruped_jump~mass~' + para for para in parameter_2],
62 | 'quadruped_jump_mass_eval': ['quadruped_jump~mass~' + para for para in parameter_2_eval],
63 |
64 | 'quadruped_stand_damping': ['quadruped_stand~damping~' + para for para in parameter_1],
65 | 'quadruped_stand_damping_eval': ['quadruped_stand~damping~' + para for para in parameter_1_eval],
66 | 'quadruped_walk_damping': ['quadruped_walk~damping~' + para for para in parameter_1],
67 | 'quadruped_walk_damping_eval': ['quadruped_walk~damping~' + para for para in parameter_1_eval],
68 | 'quadruped_run_damping': ['quadruped_run~damping~' + para for para in parameter_1],
69 | 'quadruped_run_damping_eval': ['quadruped_run~damping~' + para for para in parameter_1_eval],
70 | 'quadruped_jump_damping': ['quadruped_jump~damping~' + para for para in parameter_1],
71 | 'quadruped_jump_damping_eval': ['quadruped_jump~damping~' + para for para in parameter_1_eval],
72 | }
73 |
--------------------------------------------------------------------------------
/DMC_state/finetune.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - configs/agent: ddpg
3 | - override hydra/launcher: submitit_local
4 |
5 | # mode
6 | reward_free: false
7 | # task settings
8 | task: none
9 | domain: walker_mass
10 | finetune_domain: walker_stand_mass
11 | obs_type: states # [states, pixels]
12 | frame_stack: 3 # only works if obs_type=pixels
13 | action_repeat: 1 # set to 2 for pixels
14 | discount: 0.99
15 | # train settings
16 | num_train_frames: 100010
17 | num_seed_frames: 4000
18 | # eval
19 | eval_every_frames: 10000
20 | num_eval_episodes: 10
21 | # pretrained
22 | snapshot_ts: 100000
23 | snapshot_base_dir: ../../../../../../pretrained_models
24 | # replay buffer
25 | replay_buffer_size: 1000000
26 | replay_buffer_num_workers: 4
27 | batch_size: ${agent.batch_size}
28 | nstep: ${agent.nstep}
29 | update_encoder: false # true or false, depending on whether we want to fine-tune the encoder
30 | # misc
31 | seed: 1
32 | device: cuda
33 | save_video: true
34 | save_train_video: false
35 | use_tb: false
36 | use_wandb: false
37 | # experiment
38 | experiment: exp
39 |
40 | init_task: 1.0
41 |
42 | hydra:
43 | run:
44 | dir: ./exp_local/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}
45 | sweep:
46 | dir: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment}
47 | subdir: ${hydra.job.num}
48 | launcher:
49 | timeout_min: 4300
50 | cpus_per_task: 10
51 | gpus_per_node: 1
52 | tasks_per_node: 1
53 | mem_gb: 160
54 | nodes: 1
55 | submitit_folder: ./exp_sweep/${domain}/finetune_${finetune_domain}/${agent.name}/${snapshot_ts}/${now:%Y.%m.%d.%H%M%S}_${seed}_${experiment}/.slurm
--------------------------------------------------------------------------------
/DMC_state/finetune_ddpg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DOMAIN=$1 # walker_mass, quadruped_mass, quadruped_damping
3 | GPU_ID=$2
4 | FINETUNE_TASK=$3
5 |
6 | echo "Experiments started."
7 | for seed in $(seq 0 9)
8 | do
9 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID}
10 | python finetune.py configs/agent=ddpg domain=${DOMAIN} seed=${seed} device=cuda:${GPU_ID} snapshot_ts=0 finetune_domain=${FINETUNE_TASK} num_train_frames=2000010
11 | done
12 | echo "Experiments ended."
13 |
14 | # e.g.
15 | # ./finetune_ddpg.sh walker_mass 0 walker_stand_mass
--------------------------------------------------------------------------------
/DMC_state/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | import wandb
9 | from termcolor import colored
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
14 | ('episode_reward', 'R', 'float'),
15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]
16 |
17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
19 | ('episode_reward', 'R', 'float'),
20 | ('total_time', 'T', 'time'),
21 | ('episode_train_reward', 'TR', 'float'),
22 | ('episode_eval_reward', 'ER', 'float')]
23 |
24 |
25 | class AverageMeter(object):
26 | def __init__(self):
27 | self._sum = 0
28 | self._count = 0
29 |
30 | def update(self, value, n=1):
31 | self._sum += value
32 | self._count += n
33 |
34 | def value(self):
35 | return self._sum / max(1, self._count)
36 |
37 |
38 | class MetersGroup(object):
39 | def __init__(self, csv_file_name, formating, use_wandb):
40 | self._csv_file_name = csv_file_name
41 | self._formating = formating
42 | self._meters = defaultdict(AverageMeter)
43 | self._csv_file = None
44 | self._csv_writer = None
45 | self.use_wandb = use_wandb
46 |
47 | def log(self, key, value, n=1):
48 | self._meters[key].update(value, n)
49 |
50 | def _prime_meters(self):
51 | data = dict()
52 | for key, meter in self._meters.items():
53 | if key.startswith('train'):
54 | key = key[len('train') + 1:]
55 | else:
56 | key = key[len('eval') + 1:]
57 | key = key.replace('/', '_')
58 | data[key] = meter.value()
59 | return data
60 |
61 | def _remove_old_entries(self, data):
62 | rows = []
63 | with self._csv_file_name.open('r') as f:
64 | reader = csv.DictReader(f)
65 | for row in reader:
66 | if float(row['episode']) >= data['episode']:
67 | break
68 | rows.append(row)
69 | with self._csv_file_name.open('w') as f:
70 | writer = csv.DictWriter(f,
71 | fieldnames=sorted(data.keys()),
72 | restval=0.0)
73 | writer.writeheader()
74 | for row in rows:
75 | writer.writerow(row)
76 |
77 | def _dump_to_csv(self, data):
78 | if self._csv_writer is None:
79 | should_write_header = True
80 | if self._csv_file_name.exists():
81 | self._remove_old_entries(data)
82 | should_write_header = False
83 |
84 | self._csv_file = self._csv_file_name.open('a')
85 | self._csv_writer = csv.DictWriter(self._csv_file,
86 | fieldnames=sorted(data.keys()),
87 | restval=0.0)
88 | if should_write_header:
89 | self._csv_writer.writeheader()
90 |
91 | self._csv_writer.writerow(data)
92 | self._csv_file.flush()
93 |
94 | def _format(self, key, value, ty):
95 | if ty == 'int':
96 | value = int(value)
97 | return f'{key}: {value}'
98 | elif ty == 'float':
99 | return f'{key}: {value:.04f}'
100 | elif ty == 'time':
101 | value = str(datetime.timedelta(seconds=int(value)))
102 | return f'{key}: {value}'
103 | else:
104 |             raise ValueError(f'invalid format type: {ty}')
105 |
106 | def _dump_to_console(self, data, prefix):
107 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
108 | pieces = [f'| {prefix: <14}']
109 | for key, disp_key, ty in self._formating:
110 | value = data.get(key, 0)
111 | pieces.append(self._format(disp_key, value, ty))
112 | print(' | '.join(pieces))
113 |
114 | def _dump_to_wandb(self, data):
115 | wandb.log(data)
116 |
117 | def dump(self, step, prefix):
118 | if len(self._meters) == 0:
119 | return
120 | data = self._prime_meters()
121 | data['frame'] = step
122 | if self.use_wandb:
123 | wandb_data = {prefix + '/' + key: val for key, val in data.items()}
124 | self._dump_to_wandb(data=wandb_data)
125 | self._dump_to_csv(data)
126 | self._dump_to_console(data, prefix)
127 | self._meters.clear()
128 |
129 |
130 | class Logger(object):
131 | def __init__(self, log_dir, use_tb, use_wandb):
132 | self._log_dir = log_dir
133 | self._train_mg = MetersGroup(log_dir / 'train.csv',
134 | formating=COMMON_TRAIN_FORMAT,
135 | use_wandb=use_wandb)
136 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
137 | formating=COMMON_EVAL_FORMAT,
138 | use_wandb=use_wandb)
139 | if use_tb:
140 | self._sw = SummaryWriter(str(log_dir / 'tb'))
141 | else:
142 | self._sw = None
143 | self.use_wandb = use_wandb
144 |
145 | def _try_sw_log(self, key, value, step):
146 | if self._sw is not None:
147 | self._sw.add_scalar(key, value, step)
148 |
149 | def log(self, key, value, step):
150 | assert key.startswith('train') or key.startswith('eval')
151 | if type(value) == torch.Tensor:
152 | value = value.item()
153 | self._try_sw_log(key, value, step)
154 | mg = self._train_mg if key.startswith('train') else self._eval_mg
155 | mg.log(key, value)
156 |
157 | def log_metrics(self, metrics, step, ty):
158 | for key, value in metrics.items():
159 | self.log(f'{ty}/{key}', value, step)
160 |
161 | def dump(self, step, ty=None):
162 | if ty is None or ty == 'eval':
163 | self._eval_mg.dump(step, 'eval')
164 | if ty is None or ty == 'train':
165 | self._train_mg.dump(step, 'train')
166 |
167 | def log_and_dump_ctx(self, step, ty):
168 | return LogAndDumpCtx(self, step, ty)
169 |
170 |
171 | class LogAndDumpCtx:
172 | def __init__(self, logger, step, ty):
173 | self._logger = logger
174 | self._step = step
175 | self._ty = ty
176 |
177 | def __enter__(self):
178 | return self
179 |
180 | def __call__(self, key, value):
181 | self._logger.log(f'{self._ty}/{key}', value, self._step)
182 |
183 | def __exit__(self, *args):
184 | self._logger.dump(self._step, self._ty)
185 |
--------------------------------------------------------------------------------
/DMC_state/pretrain.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import os
6 |
7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
8 | os.environ['MUJOCO_GL'] = 'egl'
9 |
10 | from pathlib import Path
11 |
12 | import hydra
13 | import numpy as np
14 | import torch
15 | import wandb
16 | from dm_env import specs
17 |
18 | import dmc
19 | import utils
20 | from logger import Logger
21 | from replay_buffer import ReplayBufferStorage, make_replay_loader
22 | from video import TrainVideoRecorder, VideoRecorder
23 |
24 | torch.backends.cudnn.benchmark = True
25 |
26 | from dmc_benchmark import PRETRAIN_TASKS
27 |
28 |
29 | def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
30 | cfg.obs_type = obs_type
31 | cfg.obs_shape = obs_spec.shape
32 | cfg.action_shape = action_spec.shape
33 | cfg.num_expl_steps = num_expl_steps
34 | return hydra.utils.instantiate(cfg)
35 |
36 |
37 | class Workspace:
38 | def __init__(self, cfg):
39 | self.work_dir = Path.cwd()
40 | print(f'workspace: {self.work_dir}')
41 |
42 | self.cfg = cfg
43 | utils.set_seed_everywhere(cfg.seed)
44 | self.device = torch.device(cfg.device)
45 |
46 | # create logger
47 | if cfg.use_wandb:
48 | exp_name = '_'.join([
49 | cfg.experiment, cfg.agent.name, cfg.domain, cfg.obs_type,
50 | str(cfg.seed)
51 | ])
52 | wandb.init(project="urlb", group=cfg.agent.name, name=exp_name)
53 |
54 | self.logger = Logger(self.work_dir,
55 | use_tb=cfg.use_tb,
56 | use_wandb=cfg.use_wandb)
57 |
58 | # create envs
59 | if cfg.task != 'none':
60 | # single task
61 | tasks = [cfg.task]
62 | else:
63 | # pre-define multi-task
64 | tasks = PRETRAIN_TASKS[self.cfg.domain]
65 | frame_stack = 1
66 | img_size = 64
67 |
68 | self.train_envs = [dmc.make(task, cfg.obs_type, cfg.frame_stack,
69 | cfg.action_repeat, cfg.seed, )
70 | for task in tasks]
71 | self.tasks_name = tasks
72 | self.train_envs_number = len(self.train_envs)
73 | self.current_train_id = 0
74 | self.eval_env = [dmc.make(task, cfg.obs_type, cfg.frame_stack,
75 | cfg.action_repeat, cfg.seed, )
76 | for task in tasks]
77 |
78 | # create agent
79 | if 'peac' in cfg.agent.name or 'context' in cfg.agent.name:
80 | cfg.agent['context_dim'] = self.train_envs_number
81 |
82 | self.agent = make_agent(cfg.obs_type,
83 | self.train_envs[0].observation_spec(),
84 | self.train_envs[0].action_spec(),
85 | cfg.num_seed_frames // cfg.action_repeat,
86 | cfg.agent)
87 |
88 | # get meta specs
89 | meta_specs = self.agent.get_meta_specs()
90 | # create replay buffer
91 | data_specs = (self.train_envs[0].observation_spec(),
92 | self.train_envs[0].action_spec(),
93 | specs.Array((1,), np.float32, 'reward'),
94 | specs.Array((1,), np.float32, 'discount'),
95 | specs.Array((1,), np.int64, 'embodiment_id'),)
96 |
97 | # create data storage
98 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs,
99 | self.work_dir / 'buffer')
100 |
101 | # create replay buffer
102 | his_o_a = cfg.agent.get('his_o_a', 0)
103 | print('history o and a:', his_o_a)
104 | self.replay_loader = make_replay_loader(self.replay_storage,
105 | cfg.replay_buffer_size,
106 | cfg.batch_size,
107 | cfg.replay_buffer_num_workers,
108 | False, cfg.nstep, cfg.discount,
109 | his_o_a=his_o_a)
110 | self._replay_iter = None
111 |
112 | # create video recorders
113 | self.video_recorder = VideoRecorder(
114 | self.work_dir if cfg.save_video else None,
115 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2,
116 | use_wandb=self.cfg.use_wandb)
117 | self.train_video_recorder = TrainVideoRecorder(
118 | self.work_dir if cfg.save_train_video else None,
119 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2,
120 | use_wandb=self.cfg.use_wandb)
121 |
122 | self.timer = utils.Timer()
123 | self._global_step = 0
124 | self._global_episode = 0
125 |
126 | @property
127 | def global_step(self):
128 | return self._global_step
129 |
130 | @property
131 | def global_episode(self):
132 | return self._global_episode
133 |
134 | @property
135 | def global_frame(self):
136 | return self.global_step * self.cfg.action_repeat
137 |
138 | @property
139 | def replay_iter(self):
140 | if self._replay_iter is None:
141 | self._replay_iter = iter(self.replay_loader)
142 | return self._replay_iter
143 |
144 | def eval(self):
145 | # we do not eval in the pre-train stage
146 | pass
147 |
148 | def train(self):
149 | # predicates
150 | train_until_step = utils.Until(self.cfg.num_train_frames,
151 | self.cfg.action_repeat)
152 | seed_until_step = utils.Until(self.cfg.num_seed_frames,
153 | self.cfg.action_repeat)
154 | eval_every_step = utils.Every(self.cfg.eval_every_frames,
155 | self.cfg.action_repeat)
156 |
157 | episode_step, episode_reward = 0, 0
158 | time_step = self.train_envs[self.current_train_id].reset()
159 | if hasattr(self.agent, "init_context"):
160 | self.agent.init_context()
161 | time_step['embodiment_id'] = self.current_train_id
162 | print('current task is', self.tasks_name[self.current_train_id])
163 | meta = self.agent.init_meta()
164 | self.replay_storage.add(time_step, meta)
165 | # self.train_video_recorder.init(time_step.observation)
166 | self.train_video_recorder.init(time_step['observation'])
167 | metrics = None
168 | while train_until_step(self.global_step):
169 | # if time_step.last():
170 | if time_step['is_last']:
171 | self._global_episode += 1
172 | self.train_video_recorder.save(f'{self.global_frame}.mp4')
173 | # wait until all the metrics schema is populated
174 | if metrics is not None:
175 | # log stats
176 | elapsed_time, total_time = self.timer.reset()
177 | episode_frame = episode_step * self.cfg.action_repeat
178 | with self.logger.log_and_dump_ctx(self.global_frame,
179 | ty='train') as log:
180 | log('fps', episode_frame / elapsed_time)
181 | log('total_time', total_time)
182 | log('episode_reward', episode_reward)
183 | log('episode_length', episode_frame)
184 | log('episode', self.global_episode)
185 | log('buffer_size', len(self.replay_storage))
186 | log('step', self.global_step)
187 |
188 | # reset env
189 | self.current_train_id = (self.current_train_id + 1) % self.train_envs_number
190 | time_step = self.train_envs[self.current_train_id].reset()
191 | if hasattr(self.agent, "init_context"):
192 | self.agent.init_context()
193 | print('current task is', self.tasks_name[self.current_train_id])
194 | time_step['embodiment_id'] = self.current_train_id
195 |
196 | meta = self.agent.init_meta()
197 | self.replay_storage.add(time_step, meta)
198 | # self.train_video_recorder.init(time_step.observation)
199 | self.train_video_recorder.init(time_step['observation'])
200 | # try to save snapshot
201 | if self.global_frame in self.cfg.snapshots:
202 | self.save_snapshot()
203 | episode_step = 0
204 | episode_reward = 0
205 |
206 | # try to evaluate
207 | if eval_every_step(self.global_step):
208 | self.logger.log('eval_total_time', self.timer.total_time(),
209 | self.global_frame)
210 | self.eval()
211 |
212 | meta = self.agent.update_meta(meta, self.global_step, time_step)
213 | # sample action
214 | with torch.no_grad(), utils.eval_mode(self.agent):
215 | action = self.agent.act(time_step['observation'],
216 | meta,
217 | self.global_step,
218 | eval_mode=False)
219 | # action = self.agent.act(time_step.observation,
220 | # meta,
221 | # self.global_step,
222 | # eval_mode=False)
223 |
224 | # try to update the agent
225 | if not seed_until_step(self.global_step):
226 | metrics = self.agent.update(self.replay_iter, self.global_step)
227 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
228 |
229 | # take env step
230 | time_step = self.train_envs[self.current_train_id].step(action)
231 | time_step['embodiment_id'] = self.current_train_id
232 | # episode_reward += time_step.reward
233 | episode_reward += time_step['reward']
234 | self.replay_storage.add(time_step, meta)
235 | self.train_video_recorder.record(time_step['observation'])
236 | # self.train_video_recorder.record(time_step.observation)
237 | episode_step += 1
238 | self._global_step += 1
239 |
240 | def save_snapshot(self):
241 | snapshot_dir = self.work_dir / Path(self.cfg.snapshot_dir)
242 | snapshot_dir.mkdir(exist_ok=True, parents=True)
243 | snapshot = snapshot_dir / f'snapshot_{self.global_frame}.pt'
244 | keys_to_save = ['agent', '_global_step', '_global_episode']
245 | payload = {k: self.__dict__[k] for k in keys_to_save}
246 | with snapshot.open('wb') as f:
247 | torch.save(payload, f)
248 |
249 |
250 | @hydra.main(config_path='.', config_name='pretrain')
251 | def main(cfg):
252 | from pretrain import Workspace as W
253 | root_dir = Path.cwd()
254 | workspace = W(cfg)
255 | snapshot = root_dir / 'snapshot.pt'
256 | if snapshot.exists():
257 | print(f'resuming: {snapshot}')
258 |         workspace.load_snapshot()  # NOTE: this Workspace only defines save_snapshot; the resume branch is effectively unused
259 | workspace.train()
260 |
261 |
262 | if __name__ == '__main__':
263 | main()
264 |
--------------------------------------------------------------------------------
/DMC_state/pretrain.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - configs/agent: ddpg
3 | - override hydra/launcher: submitit_local
4 |
5 | # mode
6 | reward_free: true
7 | # task settings
8 | task: none
9 | domain: walker # pre-training tasks for the domain are inferred at runtime
10 | obs_type: states # [states, pixels]
11 | frame_stack: 3 # only works if obs_type=pixels
12 | action_repeat: 1 # set to 2 for pixels
13 | discount: 0.99
14 | # train settings
15 | num_train_frames: 2000010
16 | num_seed_frames: 4000
17 | # eval
18 | eval_every_frames: 10000
19 | num_eval_episodes: 10
20 | # snapshot
21 | snapshots: [100000, 500000, 1000000, 2000000]
22 | snapshot_dir: ../../../../../pretrained_models/${obs_type}/${domain}/${agent.name}/${seed}
23 | # replay buffer
24 | replay_buffer_size: 1000000
25 | # number of worker processes for the replay DataLoader
26 | replay_buffer_num_workers: 4
27 | batch_size: ${agent.batch_size}
28 | nstep: ${agent.nstep}
29 | update_encoder: true # should always be true for pre-training
30 | # misc
31 | seed: 1
32 | device: cuda
33 | save_video: true
34 | save_train_video: false
35 | use_tb: false
36 | use_wandb: false
37 | # experiment
38 | experiment: exp
39 |
40 |
41 | hydra:
42 | run:
43 | dir: ./exp_local/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M%S}_${seed}
44 | sweep:
45 | dir: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M}_${seed}_${experiment}
46 | subdir: ${hydra.job.num}
47 | launcher:
48 | timeout_min: 4300
49 | cpus_per_task: 10
50 | gpus_per_node: 1
51 | tasks_per_node: 1
52 | mem_gb: 160
53 | nodes: 1
54 |     submitit_folder: ./exp_sweep/${domain}/pretrain/${agent.name}/${now:%Y.%m.%d.%H%M}_${seed}_${experiment}/.slurm
55 |
--------------------------------------------------------------------------------
/DMC_state/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import io
3 | import random
4 | import traceback
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import IterableDataset
11 |
12 |
13 | def episode_len(episode):
14 |     # subtract 1 because of the dummy first transition
15 | return next(iter(episode.values())).shape[0] - 1
16 |
17 |
18 | def save_episode(episode, fn):
19 | with io.BytesIO() as bs:
20 | np.savez_compressed(bs, **episode)
21 | bs.seek(0)
22 | with fn.open('wb') as f:
23 | f.write(bs.read())
24 |
25 |
26 | def load_episode(fn):
27 | with fn.open('rb') as f:
28 | episode = np.load(f)
29 | episode = {k: episode[k] for k in episode.keys()}
30 | return episode
31 |
32 |
33 | class ReplayBufferStorage:
34 | def __init__(self, data_specs, meta_specs, replay_dir):
35 | self._data_specs = data_specs
36 | self._meta_specs = meta_specs
37 | self._replay_dir = replay_dir
38 | replay_dir.mkdir(exist_ok=True)
39 | self._current_episode = defaultdict(list)
40 | self._preload()
41 |
42 | def __len__(self):
43 | return self._num_transitions
44 |
45 | def add(self, time_step, meta):
46 | for key, value in meta.items():
47 | self._current_episode[key].append(value)
48 | for spec in self._data_specs:
49 | value = time_step[spec.name]
50 | if np.isscalar(value):
51 | value = np.full(spec.shape, value, spec.dtype)
52 | assert spec.shape == value.shape and spec.dtype == value.dtype
53 | self._current_episode[spec.name].append(value)
54 | # if time_step.last():
55 | if time_step['is_last']:
56 | episode = dict()
57 | for spec in self._data_specs:
58 | value = self._current_episode[spec.name]
59 | episode[spec.name] = np.array(value, spec.dtype)
60 | for spec in self._meta_specs:
61 | value = self._current_episode[spec.name]
62 | episode[spec.name] = np.array(value, spec.dtype)
63 | self._current_episode = defaultdict(list)
64 | self._store_episode(episode)
65 |
66 | def _preload(self):
67 | self._num_episodes = 0
68 | self._num_transitions = 0
69 | for fn in self._replay_dir.glob('*.npz'):
70 | _, _, eps_len = fn.stem.split('_')
71 | self._num_episodes += 1
72 | self._num_transitions += int(eps_len)
73 |
74 | def _store_episode(self, episode):
75 | eps_idx = self._num_episodes
76 | eps_len = episode_len(episode)
77 | self._num_episodes += 1
78 | self._num_transitions += eps_len
79 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
80 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
81 | save_episode(episode, self._replay_dir / eps_fn)
82 |
83 |
84 | class ReplayBuffer(IterableDataset):
85 | def __init__(self, storage, max_size, num_workers, nstep, discount,
86 | fetch_every, save_snapshot, his_o_a=0):
87 | self._storage = storage
88 | self._size = 0
89 | self._max_size = max_size
90 | self._num_workers = max(1, num_workers)
91 | self._episode_fns = []
92 | self._episodes = dict()
93 | self._nstep = nstep
94 | self._discount = discount
95 | self._fetch_every = fetch_every
96 | self._samples_since_last_fetch = fetch_every
97 | self._save_snapshot = save_snapshot
98 | self._his_o_a = his_o_a
99 |
100 | def _sample_episode(self):
101 | # print(len(self._episode_fns))
102 | eps_fn = random.choice(self._episode_fns)
103 | # print('eps_fn embodiment id', eps_fn, self._episodes[eps_fn]['embodiment_id'])
104 | return self._episodes[eps_fn]
105 |
106 | def _store_episode(self, eps_fn):
107 | try:
108 | episode = load_episode(eps_fn)
109 |         except Exception:  # skip episodes that fail to load (e.g. partially written files)
110 | return False
111 | eps_len = episode_len(episode)
112 | while eps_len + self._size > self._max_size:
113 | early_eps_fn = self._episode_fns.pop(0)
114 | early_eps = self._episodes.pop(early_eps_fn)
115 | self._size -= episode_len(early_eps)
116 | early_eps_fn.unlink(missing_ok=True)
117 | self._episode_fns.append(eps_fn)
118 | self._episode_fns.sort()
119 | self._episodes[eps_fn] = episode
120 | self._size += eps_len
121 |
122 | if not self._save_snapshot:
123 | eps_fn.unlink(missing_ok=True)
124 | return True
125 |
126 | def _try_fetch(self):
127 | if self._samples_since_last_fetch < self._fetch_every:
128 | return
129 | self._samples_since_last_fetch = 0
130 | try:
131 | worker_id = torch.utils.data.get_worker_info().id
132 |         except Exception:  # get_worker_info() returns None in the main process
133 | worker_id = 0
134 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True)
135 | fetched_size = 0
136 | # print(eps_fns)
137 | for eps_fn in eps_fns:
138 | # print('hhh', eps_fn.stem)
139 | # print(eps_fn.stem.split('_')[1:])
140 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
141 | if eps_idx % self._num_workers != worker_id:
142 | continue
143 | if eps_fn in self._episodes.keys():
144 | break
145 | if fetched_size + eps_len > self._max_size:
146 | break
147 | fetched_size += eps_len
148 | if not self._store_episode(eps_fn):
149 | break
150 |
151 | def _sample(self):
152 | try:
153 | self._try_fetch()
154 |         except Exception:
155 | traceback.print_exc()
156 | self._samples_since_last_fetch += 1
157 | episode = self._sample_episode()
158 |         # add 1 to skip the dummy first transition
159 | idx = np.random.randint(0, episode_len(episode) - self._his_o_a -
160 | self._nstep + 1) + 1 + self._his_o_a
161 | # print(idx)
162 | meta = []
163 | for spec in self._storage._meta_specs:
164 | meta.append(episode[spec.name][idx - 1])
165 | obs = episode['observation'][idx - 1]
166 | action = episode['action'][idx]
167 | next_obs = episode['observation'][idx + self._nstep - 1]
168 | reward = np.zeros_like(episode['reward'][idx])
169 | discount = np.ones_like(episode['discount'][idx])
170 | embodiment_id = episode['embodiment_id'][idx]
171 | for i in range(self._nstep):
172 | step_reward = episode['reward'][idx + i]
173 | reward += discount * step_reward
174 | discount *= episode['discount'][idx + i] * self._discount
175 | # print(obs.shape)
176 | # print('lalala', episode['embodiment_id'], embodiment_id)
177 | if self._his_o_a == 0:
178 | return (obs, action, reward, discount, next_obs, embodiment_id, *meta)
179 | his_o = episode['observation'][idx - self._his_o_a - 1: idx - 1]
180 | his_a = episode['action'][idx - self._his_o_a: idx]
181 | return (obs, action, reward, discount, next_obs, embodiment_id, his_o, his_a, *meta)
182 |
183 | def __iter__(self):
184 | while True:
185 | yield self._sample()
186 |
187 |
188 | def _worker_init_fn(worker_id):
189 | seed = np.random.get_state()[1][0] + worker_id
190 | np.random.seed(seed)
191 | random.seed(seed)
192 |
193 |
194 | def make_replay_loader(storage, max_size, batch_size, num_workers,
195 | save_snapshot, nstep, discount, his_o_a=0):
196 | max_size_per_worker = max_size // max(1, num_workers)
197 |
198 | iterable = ReplayBuffer(storage,
199 | max_size_per_worker,
200 | num_workers,
201 | nstep,
202 | discount,
203 | fetch_every=1000,
204 | save_snapshot=save_snapshot,
205 | his_o_a=his_o_a,)
206 |
207 | loader = torch.utils.data.DataLoader(iterable,
208 | batch_size=batch_size,
209 | num_workers=num_workers,
210 | pin_memory=True,
211 | worker_init_fn=_worker_init_fn)
212 | return loader
213 |
--------------------------------------------------------------------------------
/DMC_state/train_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ALGO=$1
3 | DOMAIN=$2 # walker_mass, quadruped_mass, quadruped_damping
4 | GPU_ID=$3
5 |
6 |
7 | if [ "$DOMAIN" == "walker_mass" ]
8 | then
9 | ALL_TASKS=("walker_stand_mass" "walker_walk_mass" "walker_run_mass" "walker_flip_mass")
10 | elif [ "$DOMAIN" == "quadruped_mass" ]
11 | then
12 | ALL_TASKS=("quadruped_stand_mass" "quadruped_walk_mass" "quadruped_run_mass" "quadruped_jump_mass")
13 | elif [ "$DOMAIN" == "quadruped_damping" ]
14 | then
15 | ALL_TASKS=("quadruped_stand_damping" "quadruped_walk_damping" "quadruped_run_damping" "quadruped_jump_damping")
16 | else
17 | ALL_TASKS=()
18 | echo "No matching tasks"
19 |     exit 1  # unknown domain: fail instead of exiting successfully
20 | fi
21 |
22 | echo "Experiments started."
23 | for seed in $(seq 0 9)
24 | do
25 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID}
26 | python pretrain.py configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID}
27 | for string in "${ALL_TASKS[@]}"
28 | do
29 | export MUJOCO_EGL_DEVICE_ID=${GPU_ID}
30 | python finetune.py configs/agent=${ALGO} domain=${DOMAIN} seed=$seed device=cuda:${GPU_ID} snapshot_ts=2000000 finetune_domain=$string
31 | done
32 | done
33 | echo "Experiments ended."
34 |
35 | # e.g.
36 | # ./train_finetune.sh peac walker_mass 0
--------------------------------------------------------------------------------
/DMC_state/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 | import wandb
5 |
6 |
7 | class VideoRecorder:
8 | def __init__(self,
9 | root_dir,
10 | render_size=256,
11 | fps=20,
12 | camera_id=0,
13 | use_wandb=False):
14 | if root_dir is not None:
15 | self.save_dir = root_dir / 'eval_video'
16 | self.save_dir.mkdir(exist_ok=True)
17 | else:
18 | self.save_dir = None
19 |
20 | self.render_size = render_size
21 | self.fps = fps
22 | self.frames = []
23 | self.camera_id = camera_id
24 | self.use_wandb = use_wandb
25 |
26 | def init(self, env, enabled=True):
27 | self.frames = []
28 | self.enabled = self.save_dir is not None and enabled
29 | self.record(env)
30 |
31 | def record(self, env):
32 | if self.enabled:
33 | if hasattr(env, 'physics'):
34 | frame = env.physics.render(height=self.render_size,
35 | width=self.render_size,
36 | camera_id=self.camera_id)
37 | else:
38 | frame = env.render()
39 | self.frames.append(frame)
40 |
41 | def log_to_wandb(self):
42 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
43 | fps, skip = 6, 8
44 | wandb.log({
45 | 'eval/video':
46 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
47 | })
48 |
49 | def save(self, file_name):
50 | if self.enabled:
51 | if self.use_wandb:
52 | self.log_to_wandb()
53 | path = self.save_dir / file_name
54 | imageio.mimsave(str(path), self.frames, fps=self.fps)
55 |
56 |
57 | class TrainVideoRecorder:
58 | def __init__(self,
59 | root_dir,
60 | render_size=256,
61 | fps=20,
62 | camera_id=0,
63 | use_wandb=False):
64 | if root_dir is not None:
65 | self.save_dir = root_dir / 'train_video'
66 | self.save_dir.mkdir(exist_ok=True)
67 | else:
68 | self.save_dir = None
69 |
70 | self.render_size = render_size
71 | self.fps = fps
72 | self.frames = []
73 | self.camera_id = camera_id
74 | self.use_wandb = use_wandb
75 |
76 | def init(self, obs, enabled=True):
77 | self.frames = []
78 | self.enabled = self.save_dir is not None and enabled
79 | self.record(obs)
80 |
81 | def record(self, obs):
82 | if self.enabled:
83 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
84 | dsize=(self.render_size, self.render_size),
85 | interpolation=cv2.INTER_CUBIC)
86 | self.frames.append(frame)
87 |
88 | def log_to_wandb(self):
89 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
90 | fps, skip = 6, 8
91 | wandb.log({
92 | 'train/video':
93 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
94 | })
95 |
96 | def save(self, file_name):
97 | if self.enabled:
98 | if self.use_wandb:
99 | self.log_to_wandb()
100 | path = self.save_dir / file_name
101 | imageio.mimsave(str(path), self.frames, fps=self.fps)
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 TSAIL group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CEURL
2 |
3 | [](https://arxiv.org/abs/2405.14073) [](https://yingchengyang.github.io/ceurl)
4 |
5 | This is the official implementation of "PEAC: Unsupervised Pre-training for Cross-Embodiment Reinforcement Learning" (NeurIPS 2024).
6 |
7 | ## State-based DMC & Image-based DMC
8 |
9 | ### Installation
10 |
11 | The code is based on [URLB](https://github.com/rll-research/url_benchmark).
12 | 
13 | You can create a conda environment and install all required dependencies by running:
14 | ```sh
15 | conda create -n ceurl python=3.8
16 | conda activate ceurl
17 | pip install -r requirements.txt
18 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
19 | ```
20 |
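If the install succeeded, the core dependencies should import cleanly. As a quick, optional sanity check (our suggestion, not part of the original setup), you can run:

```sh
python -c "import mujoco, dm_control, torch; print('torch', torch.__version__, 'cuda:', torch.cuda.is_available())"
```
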
21 | ### Instructions
22 |
23 | The simplest way to try PEAC on the three embodiment distributions of state-based DMC is to run:
24 | ```sh
25 | cd DMC_state
26 | chmod +x train_finetune.sh
27 |
28 | ./train_finetune.sh peac walker_mass 0
29 | ./train_finetune.sh peac quadruped_mass 0
30 | ./train_finetune.sh peac quadruped_damping 0
31 | ```
32 |
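Under the hood, `train_finetune.sh` loops over ten seeds: for each seed it pre-trains the agent with `pretrain.py` and then fine-tunes the saved 2M-step snapshot on every downstream task of the chosen domain with `finetune.py`. To run a single seed/task pair directly, the equivalent commands look roughly like this (a sketch mirroring the Hydra overrides used by the script):

```sh
cd DMC_state
export MUJOCO_EGL_DEVICE_ID=0
# pre-train PEAC on the walker_mass embodiment distribution
python pretrain.py configs/agent=peac domain=walker_mass seed=0 device=cuda:0
# fine-tune the 2M-step snapshot on one downstream task
python finetune.py configs/agent=peac domain=walker_mass seed=0 device=cuda:0 \
    snapshot_ts=2000000 finetune_domain=walker_walk_mass
```
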
33 | The simplest way to try PEAC on the three embodiment distributions of image-based DMC is to run:
34 | ```sh
35 | cd DMC_image
36 | chmod +x train_finetune.sh
37 |
38 | ./train_finetune.sh peac_lbs walker_mass 0
39 | ./train_finetune.sh peac_lbs quadruped_mass 0
40 | ./train_finetune.sh peac_lbs quadruped_damping 0
41 |
42 | ./train_finetune.sh peac_diayn walker_mass 0
43 | ./train_finetune.sh peac_diayn quadruped_mass 0
44 | ./train_finetune.sh peac_diayn quadruped_damping 0
45 | ```
46 |
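In the state-based code, experiment logs and videos are written under `DMC_state/exp_local/<domain>/...` (the Hydra run directory), and pre-training snapshots are saved under `pretrained_models/<obs_type>/<domain>/<agent>/<seed>` (see `snapshot_dir` in `pretrain.yaml`). TensorBoard and Weights & Biases logging are disabled by default; they can be enabled per run with Hydra overrides, for example:

```sh
cd DMC_state
python pretrain.py configs/agent=peac domain=walker_mass seed=0 device=cuda:0 \
    use_tb=true use_wandb=true
```
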
47 | ## Citation
48 |
49 | If you find this work helpful, please cite our paper.
50 |
51 | ```
52 | @article{ying2024peac,
53 | title={PEAC: Unsupervised Pre-training for Cross-Embodiment Reinforcement Learning},
54 | author={Ying, Chengyang and Hao, Zhongkai and Zhou, Xinning and Xu, Xuezhou and Su, Hang and Zhang, Xingxing and Zhu, Jun},
55 | journal={arXiv preprint arXiv:2405.14073},
56 | year={2024}
57 | }
58 | ```
59 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core==1.1.0
2 | hydra-submitit-launcher==1.1.5
3 | termcolor==1.1.0
4 | wandb==0.11.1
5 | gym==0.26.2
6 | protobuf==3.20.0
7 | dm_env==1.6
8 | dm_control==1.0.14
9 | tensorboard==2.0.2
10 | numpy==1.19.2
11 | torchaudio==0.8.0
12 | pandas==1.3.0
13 | mujoco==2.3.7
14 | opencv-python
15 | imageio==2.9.0
16 | imageio-ffmpeg==0.4.4
--------------------------------------------------------------------------------