├── .gitignore
├── LICENSE
├── README.md
├── agents
│   ├── drqv2.py
│   └── taco.py
├── cfgs
│   ├── agent
│   │   ├── drqv2.yaml
│   │   └── taco.yaml
│   ├── config.yaml
│   └── task
│       ├── acrobot_swingup.yaml
│       ├── acrobot_swingup_sparse.yaml
│       ├── cartpole_balance.yaml
│       ├── cartpole_balance_sparse.yaml
│       ├── cartpole_swingup.yaml
│       ├── cartpole_swingup_sparse.yaml
│       ├── cheetah_run.yaml
│       ├── cup_catch.yaml
│       ├── easy.yaml
│       ├── finger_spin.yaml
│       ├── finger_turn_easy.yaml
│       ├── finger_turn_hard.yaml
│       ├── hard.yaml
│       ├── hopper_hop.yaml
│       ├── hopper_stand.yaml
│       ├── manipulator_bring_ball.yaml
│       ├── medium.yaml
│       ├── pendulum_swingup.yaml
│       ├── quadruped_run.yaml
│       ├── quadruped_walk.yaml
│       ├── reach_duplo.yaml
│       ├── reacher_easy.yaml
│       ├── reacher_hard.yaml
│       ├── walker_run.yaml
│       ├── walker_stand.yaml
│       └── walker_walk.yaml
├── dmc.py
├── environment.yml
├── logger.py
├── media
│   ├── dmc.gif
│   ├── overview.png
│   └── taco.jpg
├── replay_buffer.py
├── train.py
├── utils.py
└── video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | exp_local/
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ruijie Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TACO: Temporal Action-driven Contrastive Learning
2 |
3 | Original PyTorch implementation of **TACO** from
4 |
5 | [TACO: Temporal Latent Action-Driven Contrastive Loss for Visual Reinforcement Learning](https://arxiv.org/pdf/2306.13229.pdf) by
6 |
7 | [Ruijie Zheng](https://ruijiezheng.com), [Xiyao Wang](https://si0wang.github.io)\*, [Yanchao Sun](https://ycsun2017.github.io)\*, [Shuang Ma](https://www.shuangma.me)\*, [Jieyu Zhao](https://jyzhao.net)\*, [Huazhe Xu](http://hxu.rocks)\*, [Hal Daumé III](http://users.umiacs.umd.edu/~hal/)\*, [Furong Huang](https://furong-huang.com)\*
8 |
9 |
10 |
11 |
12 | [Paper] [Website]
13 |
14 |
15 |
16 | ## Method
17 |
18 | **TACO** is a simple yet powerful temporal contrastive learning approach that facilitates the concurrent acquisition of latent state and action representations for agents. **TACO** simultaneously learns a state and an action representation by optimizing the mutual information between representations of current states paired with action sequences and representations of the corresponding future states.
19 |
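The following is a minimal sketch of that objective as it is implemented in `agents/taco.py` (`TACO.project_sa`, `TACO.compute_logits`, and a cross-entropy loss over in-batch negatives). All shapes and dimensions below are illustrative placeholders rather than the exact training configuration.

```python
# Minimal sketch of the TACO contrastive objective (assumed shapes; see agents/taco.py for the real modules).
import torch
import torch.nn as nn
import torch.nn.functional as F

B, feature_dim, latent_a_dim, K = 32, 50, 8, 3        # batch, state feature dim, latent action dim, action-sequence length

proj_sa = nn.Sequential(                              # projects (z_t, a_t..a_{t+K-1}) into the state feature space
    nn.Linear(feature_dim + latent_a_dim * K, 1024),
    nn.ReLU(),
    nn.Linear(1024, feature_dim),
)
W = nn.Parameter(torch.rand(feature_dim, feature_dim))  # bilinear similarity matrix, as in TACO.W

z_t = torch.randn(B, feature_dim)         # encoded current observations
a_seq = torch.randn(B, latent_a_dim * K)  # encoded action sequences
z_tk = torch.randn(B, feature_dim)        # encoded observations K steps ahead (EMA / no-grad branch)

za = proj_sa(torch.cat([z_t, a_seq], dim=-1))             # (B, feature_dim)
logits = za @ (W @ z_tk.T)                                # (B, B): each (state, action-seq) row scored against every future state
logits = logits - logits.max(dim=1, keepdim=True).values  # subtract row max for numerical stability
labels = torch.arange(B)                                  # positives are the matching (diagonal) pairs
taco_loss = F.cross_entropy(logits, labels)               # InfoNCE-style bound on the mutual information
```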
20 |
21 |
22 |
23 |
24 |
25 | ## Citation
26 |
27 | If you use our method or code in your research, please consider citing the paper as follows:
28 |
29 | ```
30 | @inproceedings{
31 | zheng2023taco,
32 | title={$\texttt{TACO}$: Temporal Latent Action-Driven Contrastive Loss for Visual Reinforcement Learning},
33 | author={Ruijie Zheng and Xiyao Wang and Yanchao Sun and Shuang Ma and Jieyu Zhao and Huazhe Xu and Hal Daumé III and Furong Huang},
34 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
35 | year={2023},
36 | url={https://openreview.net/forum?id=ezCsMOy1w9}
37 | }
38 |
39 | ```
40 |
41 | ## Instructions
42 |
43 | Assuming that you already have [MuJoCo](http://www.mujoco.org) installed, install dependencies using `conda`:
44 |
45 | ```
46 | conda env create -f environment.yml
47 | conda activate taco
48 | ```
49 |
50 | After installing dependencies, you can train a **TACO** agent with the following command (using `quadruped_run` as an example):
51 |
52 | ```
53 | CUDA_VISIBLE_DEVICES=X python train.py agent=taco task=quadruped_run exp_name=${EXP_NAME}
54 | ```
55 |
56 | To train a **DrQ-v2** agent:
57 | ```
58 | CUDA_VISIBLE_DEVICES=X python train.py agent=drqv2 task=quadruped_run exp_name=${EXP_NAME}
59 | ```
60 |
61 | Evaluation videos and model weights can be saved with arguments `save_video=True` and `save_model=True`. Refer to the `cfgs` directory for a full list of options and default hyperparameters.
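For example, the command below (a sketch; the override keys are the Hydra options defined in `cfgs/config.yaml` and `cfgs/agent/taco.yaml`) trains **TACO** with TensorBoard logging, saved evaluation videos, and a 3-step action sequence for the contrastive objective:

```
CUDA_VISIBLE_DEVICES=X python train.py agent=taco task=quadruped_run exp_name=${EXP_NAME} \
    use_tb=true save_video=true multistep=3
```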
62 |
63 |
64 | ## Acknowledgement
65 | TACO is licensed under the MIT license. MuJoCo and the DeepMind Control Suite are licensed under the Apache 2.0 license. We would like to thank the DrQ-v2 authors for open-sourcing the [DrQ-v2](https://github.com/facebookresearch/drqv2) codebase; our implementation builds on top of their repository.
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/agents/drqv2.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import utils
7 | import itertools
8 |
9 |
10 | class RandomShiftsAug(nn.Module):
11 | def __init__(self, pad):
12 | super().__init__()
13 | self.pad = pad
14 |
15 | def forward(self, x):
16 | n, c, h, w = x.size()
17 | assert h == w
18 | padding = tuple([self.pad] * 4)
19 | x = F.pad(x, padding, 'replicate')
20 | eps = 1.0 / (h + 2 * self.pad)
21 | arange = torch.linspace(-1.0 + eps,
22 | 1.0 - eps,
23 | h + 2 * self.pad,
24 | device=x.device,
25 | dtype=x.dtype)[:h]
26 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
27 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
28 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
29 |
30 | shift = torch.randint(0,
31 | 2 * self.pad + 1,
32 | size=(n, 1, 1, 2),
33 | device=x.device,
34 | dtype=x.dtype)
35 | shift *= 2.0 / (h + 2 * self.pad)
36 |
37 | grid = base_grid + shift
38 | return F.grid_sample(x,
39 | grid,
40 | padding_mode='zeros',
41 | align_corners=False)
42 |
43 | class Encoder(nn.Module):
44 | def __init__(self, obs_shape, feature_dim):
45 | super().__init__()
46 |
47 | assert len(obs_shape) == 3
48 | self.repr_dim = 32 * 35 * 35
49 |
50 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
51 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
52 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
53 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
54 | nn.ReLU())
55 |
56 | self.apply(utils.weight_init)
57 |
58 | def forward(self, obs):
59 | obs = obs / 255.0 - 0.5
60 | h = self.convnet(obs)
61 | h = h.view(h.shape[0], -1)
62 | return h
63 |
64 |
65 |
66 | class Actor(nn.Module):
67 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
68 | super().__init__()
69 |
70 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
71 | nn.LayerNorm(feature_dim), nn.Tanh())
72 |
73 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
74 | nn.ReLU(inplace=True),
75 | nn.Linear(hidden_dim, hidden_dim),
76 | nn.ReLU(inplace=True),
77 | nn.Linear(hidden_dim, action_shape[0]))
78 |
79 | self.apply(utils.weight_init)
80 |
81 | def forward(self, obs, std):
82 | h = self.trunk(obs)
83 | mu = self.policy(h)
84 | mu = torch.tanh(mu)
85 | std = torch.ones_like(mu) * std
86 |
87 | dist = utils.TruncatedNormal(mu, std)
88 | return dist
89 |
90 |
91 | class Critic(nn.Module):
92 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
93 | super().__init__()
94 |
95 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
96 | nn.LayerNorm(feature_dim), nn.Tanh())
97 |
98 | self.Q1 = nn.Sequential(
99 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
100 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
101 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
102 |
103 | self.Q2 = nn.Sequential(
104 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
105 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
106 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
107 |
108 | self.apply(utils.weight_init)
109 |
110 | def forward(self, obs, action):
111 | h = self.trunk(obs)
112 | h_action = torch.cat([h, action], dim=-1)
113 | q1 = self.Q1(h_action)
114 | q2 = self.Q2(h_action)
115 |
116 | return q1, q2
117 |
118 |
119 | class DrQV2Agent:
120 | def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
121 | hidden_dim, critic_target_tau, num_expl_steps,
122 | update_every_steps, stddev_schedule, stddev_clip, use_tb):
123 | self.device = device
124 | self.critic_target_tau = critic_target_tau
125 | self.update_every_steps = update_every_steps
126 | self.use_tb = use_tb
127 | self.num_expl_steps = num_expl_steps
128 | self.stddev_schedule = stddev_schedule
129 | self.stddev_clip = stddev_clip
130 |
131 | self.encoder = Encoder(obs_shape, feature_dim).to(device)
132 | self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim,
133 | hidden_dim).to(device)
134 |
135 | self.critic = Critic(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device)
136 | self.critic_target = Critic(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device)
137 | self.critic_target.load_state_dict(self.critic.state_dict())
138 |
139 | # optimizers
140 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
141 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
142 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
143 |
144 | self.cross_entropy_loss = nn.CrossEntropyLoss()
145 | self.mse_loss = nn.MSELoss()
146 |
147 | # data augmentation
148 | self.aug = RandomShiftsAug(pad=4)
149 |
150 | self.train()
151 | self.critic_target.train()
152 |
153 | def train(self, training=True):
154 | self.training = training
155 | self.encoder.train(training)
156 | self.actor.train(training)
157 | self.critic.train(training)
158 |
159 | def act(self, obs, step, eval_mode):
160 | obs = torch.as_tensor(obs, device=self.device)
161 | obs = self.encoder(obs.unsqueeze(0))
162 | stddev = utils.schedule(self.stddev_schedule, step)
163 | dist = self.actor(obs, stddev)
164 | if eval_mode:
165 | action = dist.mean
166 | else:
167 | action = dist.sample(clip=None)
168 | if step < self.num_expl_steps:
169 | action.uniform_(-1.0, 1.0)
170 | return action.cpu().numpy()[0]
171 |
172 | def update_critic(self, obs, action, reward, discount, next_obs, step):
173 | metrics = dict()
174 |
175 | with torch.no_grad():
176 | stddev = utils.schedule(self.stddev_schedule, step)
177 | dist = self.actor(next_obs, stddev)
178 | next_action = dist.sample(clip=self.stddev_clip)
179 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
180 | target_V = torch.min(target_Q1, target_Q2)
181 | target_Q = reward + (discount * target_V)
182 |
183 | Q1, Q2 = self.critic(obs, action)
184 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
185 |
186 | if self.use_tb:
187 | metrics['critic_target_q'] = target_Q.mean().item()
188 | metrics['critic_q1'] = Q1.mean().item()
189 | metrics['critic_q2'] = Q2.mean().item()
190 | metrics['critic_loss'] = critic_loss.item()
191 |
192 | # optimize encoder and critic
193 | self.encoder_opt.zero_grad(set_to_none=True)
194 | self.critic_opt.zero_grad(set_to_none=True)
195 | critic_loss.backward()
196 | self.critic_opt.step()
197 | self.encoder_opt.step()
198 |
199 | return metrics
200 |
201 | def update_actor(self, obs, step):
202 | metrics = dict()
203 |
204 | stddev = utils.schedule(self.stddev_schedule, step)
205 | dist = self.actor(obs, stddev)
206 | action = dist.sample(clip=self.stddev_clip)
207 | log_prob = dist.log_prob(action).sum(-1, keepdim=True)
208 | Q1, Q2 = self.critic(obs, action)
209 | Q = torch.min(Q1, Q2)
210 |
211 | actor_loss = -Q.mean()
212 |
213 | # optimize actor
214 | self.actor_opt.zero_grad(set_to_none=True)
215 | actor_loss.backward()
216 | self.actor_opt.step()
217 |
218 | if self.use_tb:
219 | metrics['actor_loss'] = actor_loss.item()
220 | metrics['actor_logprob'] = log_prob.mean().item()
221 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
222 |
223 | return metrics
224 |
225 |
226 |
227 | def update(self, replay_iter, step):
228 | metrics = dict()
229 | if step % self.update_every_steps != 0:
230 | return metrics
231 |
232 | batch = next(replay_iter)
233 | obs, action, action_seq, reward, discount, next_obs, r_next_obs = utils.to_torch(
234 | batch, self.device)
235 |
236 | # augment
237 | obs_en = self.aug(obs.float())
238 | next_obs_en = self.aug(next_obs.float())
239 | # encode
240 | obs_en = self.encoder(obs_en)
241 | with torch.no_grad():
242 | next_obs_en = self.encoder(next_obs_en)
243 |
244 | if self.use_tb:
245 | metrics['batch_reward'] = reward.mean().item()
246 |
247 | # update critic
248 | metrics.update(
249 | self.update_critic(obs_en, action, reward, discount, next_obs_en, step))
250 |
251 | # update actor
252 | metrics.update(self.update_actor(obs_en.detach(), step))
253 |
254 | # update critic target
255 | utils.soft_update_params(self.critic, self.critic_target,
256 | self.critic_target_tau)
257 |
258 | return metrics
259 |
--------------------------------------------------------------------------------
/agents/taco.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import utils
7 | import itertools
8 |
9 |
10 | class RandomShiftsAug(nn.Module):
11 | def __init__(self, pad):
12 | super().__init__()
13 | self.pad = pad
14 |
15 | def forward(self, x):
16 | n, c, h, w = x.size()
17 | assert h == w
18 | padding = tuple([self.pad] * 4)
19 | x = F.pad(x, padding, 'replicate')
20 | eps = 1.0 / (h + 2 * self.pad)
21 | arange = torch.linspace(-1.0 + eps,
22 | 1.0 - eps,
23 | h + 2 * self.pad,
24 | device=x.device,
25 | dtype=x.dtype)[:h]
26 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
27 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
28 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
29 |
30 | shift = torch.randint(0,
31 | 2 * self.pad + 1,
32 | size=(n, 1, 1, 2),
33 | device=x.device,
34 | dtype=x.dtype)
35 | shift *= 2.0 / (h + 2 * self.pad)
36 |
37 | grid = base_grid + shift
38 | return F.grid_sample(x,
39 | grid,
40 | padding_mode='zeros',
41 | align_corners=False)
42 |
43 | class Encoder(nn.Module):
44 | def __init__(self, obs_shape, feature_dim):
45 | super().__init__()
46 |
47 | assert len(obs_shape) == 3
48 | self.repr_dim = 32 * 35 * 35
49 |
50 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
51 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
52 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
53 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
54 | nn.ReLU())
55 |
56 | self.apply(utils.weight_init)
57 |
58 | def forward(self, obs):
59 | obs = obs / 255.0 - 0.5
60 | h = self.convnet(obs)
61 | h = h.view(h.shape[0], -1)
62 | return h
63 |
64 | class TACO(nn.Module):
65 | """
66 | TACO Contrastive loss
67 | """
68 |
69 | def __init__(self, repr_dim, feature_dim, action_shape, latent_a_dim, hidden_dim, act_tok, encoder, multistep, device):
70 | super(TACO, self).__init__()
71 |
72 | self.multistep = multistep
73 | self.encoder = encoder
74 | self.device = device
75 |
76 | a_dim = action_shape[0]
77 |
78 | self.proj_sa = nn.Sequential(
79 | nn.Linear(feature_dim + latent_a_dim*multistep, hidden_dim),
80 | nn.ReLU(inplace=True),
81 | nn.Linear(hidden_dim, feature_dim)
82 | )
83 |
84 | self.act_tok = act_tok
85 |
86 | self.proj_s = nn.Sequential(nn.Linear(repr_dim, feature_dim),
87 | nn.LayerNorm(feature_dim), nn.Tanh())
88 |
89 | self.reward = nn.Sequential(
90 | nn.Linear(feature_dim+latent_a_dim*multistep, hidden_dim),
91 | nn.ReLU(inplace=True),
92 | nn.Linear(hidden_dim, 1)
93 | )
94 |
95 | self.W = nn.Parameter(torch.rand(feature_dim, feature_dim))
96 | self.apply(utils.weight_init)
97 |
98 | def encode(self, x, ema=False):
99 | """
100 | Encoder: z_t = e(x_t)
101 | :param x: x_t, image observation (stacked frames)
102 | :return: z_t, projected latent state representation
103 | """
104 | if ema:
105 | with torch.no_grad():
106 | z_out = self.proj_s(self.encoder(x))
107 | else:
108 | z_out = self.proj_s(self.encoder(x))
109 | return z_out
110 |
111 | def project_sa(self, s, a):
112 | x = torch.concat([s,a], dim=-1)
113 | return self.proj_sa(x)
114 |
115 | def compute_logits(self, z_a, z_pos):
116 | """
117 | - compute (B,B) matrix z_a (W z_pos.T)
118 | - positives are all diagonal elements
119 | - negatives are all other elements
120 | - to compute loss use multiclass cross entropy with identity matrix for labels
121 | """
122 |
123 | Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B)
124 | logits = torch.matmul(z_a, Wz) # (B,B)
125 | logits = logits - torch.max(logits, 1)[0][:, None]
126 | return logits
127 |
128 |
129 | class Actor(nn.Module):
130 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
131 | super().__init__()
132 |
133 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
134 | nn.LayerNorm(feature_dim), nn.Tanh())
135 |
136 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
137 | nn.ReLU(inplace=True),
138 | nn.Linear(hidden_dim, hidden_dim),
139 | nn.ReLU(inplace=True),
140 | nn.Linear(hidden_dim, action_shape[0]))
141 |
142 | self.apply(utils.weight_init)
143 |
144 | def forward(self, obs, std):
145 | h = self.trunk(obs)
146 | mu = self.policy(h)
147 | mu = torch.tanh(mu)
148 | std = torch.ones_like(mu) * std
149 |
150 | dist = utils.TruncatedNormal(mu, std)
151 | return dist
152 |
153 |
154 | class Critic(nn.Module):
155 | def __init__(self, repr_dim, latent_a_dim, feature_dim, hidden_dim):
156 | super().__init__()
157 |
158 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
159 | nn.LayerNorm(feature_dim), nn.Tanh())
160 |
161 | self.Q1 = nn.Sequential(
162 | nn.Linear(feature_dim + latent_a_dim, hidden_dim),
163 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
164 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
165 |
166 | self.Q2 = nn.Sequential(
167 | nn.Linear(feature_dim + latent_a_dim, hidden_dim),
168 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
169 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
170 |
171 | self.apply(utils.weight_init)
172 |
173 | def forward(self, obs, action, act_tok=None):
174 | if act_tok is not None:
175 | action = act_tok(action)
176 | h = self.trunk(obs)
177 | h_action = torch.cat([h, action], dim=-1)
178 | q1 = self.Q1(h_action)
179 | q2 = self.Q2(h_action)
180 |
181 | return q1, q2
182 |
183 |
184 | class TACOAgent:
185 | def __init__(self, obs_shape, action_shape, device, lr, encoder_lr, feature_dim,
186 | hidden_dim, critic_target_tau, num_expl_steps,
187 | update_every_steps, stddev_schedule, stddev_clip, use_tb,
188 | reward, multistep, latent_a_dim, curl):
189 | self.device = device
190 | self.critic_target_tau = critic_target_tau
191 | self.update_every_steps = update_every_steps
192 | self.use_tb = use_tb
193 | self.num_expl_steps = num_expl_steps
194 | self.stddev_schedule = stddev_schedule
195 | self.stddev_clip = stddev_clip
196 |
197 | self.reward = reward
198 | self.multistep = multistep
199 | self.curl = curl
200 |
201 | ### A heuristic to choose the dimensionality of latent actions
202 | if latent_a_dim == 'none':
203 | latent_a_dim = int(action_shape[0]*1.25)+1
204 | ### Create action embeddings
205 | self.act_tok = utils.ActionEncoding(action_shape[0], latent_a_dim, multistep)
206 | self.encoder = Encoder(obs_shape, feature_dim).to(device)
207 |
208 | self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim,
209 | hidden_dim).to(device)
210 | self.critic = Critic(self.encoder.repr_dim, latent_a_dim, feature_dim,
211 | hidden_dim).to(device)
212 | self.critic_target = Critic(self.encoder.repr_dim, latent_a_dim,
213 | feature_dim, hidden_dim).to(device)
214 | self.critic_target.load_state_dict(self.critic.state_dict())
215 | self.TACO = TACO(self.encoder.repr_dim, feature_dim, action_shape, latent_a_dim, hidden_dim, self.act_tok, self.encoder, multistep, device).to(device)
216 |
217 | ### State & Action Encoders
218 | parameters = itertools.chain(self.encoder.parameters(),
219 | self.act_tok.parameters(),
220 | )
221 | self.encoder_opt = torch.optim.Adam(parameters, lr=encoder_lr)
222 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
223 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
224 | self.taco_opt = torch.optim.Adam(self.TACO.parameters(), lr=encoder_lr)
225 |
226 | self.cross_entropy_loss = nn.CrossEntropyLoss()
227 |
228 | # data augmentation
229 | self.aug = RandomShiftsAug(pad=4)
230 |
231 | self.train()
232 | self.critic_target.train()
233 |
234 | def train(self, training=True):
235 | self.training = training
236 | self.encoder.train(training)
237 | self.actor.train(training)
238 | self.critic.train(training)
239 | self.TACO.train()
240 |
241 | def act(self, obs, step, eval_mode):
242 | obs = torch.as_tensor(obs, device=self.device)
243 | obs = self.encoder(obs.unsqueeze(0))
244 | stddev = utils.schedule(self.stddev_schedule, step)
245 | dist = self.actor(obs, stddev)
246 | if eval_mode:
247 | action = dist.mean
248 | else:
249 | action = dist.sample(clip=None)
250 | if step < self.num_expl_steps:
251 | action.uniform_(-1.0, 1.0)
252 | return action.cpu().numpy()[0]
253 |
254 | def update_critic(self, obs, action, reward, discount, next_obs, step):
255 | metrics = dict()
256 |
257 | with torch.no_grad():
258 | stddev = utils.schedule(self.stddev_schedule, step)
259 | dist = self.actor(next_obs, stddev)
260 | next_action = dist.sample(clip=self.stddev_clip)
261 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action, self.act_tok)
262 | target_V = torch.min(target_Q1, target_Q2)
263 | target_Q = reward + (discount * target_V)
264 |
265 | Q1, Q2 = self.critic(obs, action, self.act_tok)
266 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
267 |
268 | if self.use_tb:
269 | metrics['critic_target_q'] = target_Q.mean().item()
270 | metrics['critic_q1'] = Q1.mean().item()
271 | metrics['critic_q2'] = Q2.mean().item()
272 | metrics['critic_loss'] = critic_loss.item()
273 |
274 | # optimize encoder and critic
275 | self.encoder_opt.zero_grad(set_to_none=True)
276 | self.critic_opt.zero_grad(set_to_none=True)
277 | critic_loss.backward()
278 | self.critic_opt.step()
279 | self.encoder_opt.step()
280 |
281 | return metrics
282 |
283 | def update_actor(self, obs, step):
284 | metrics = dict()
285 |
286 | stddev = utils.schedule(self.stddev_schedule, step)
287 | dist = self.actor(obs, stddev)
288 | action = dist.sample(clip=self.stddev_clip)
289 | log_prob = dist.log_prob(action).sum(-1, keepdim=True)
290 | Q1, Q2 = self.critic(obs, action, self.act_tok)
291 | Q = torch.min(Q1, Q2)
292 |
293 | actor_loss = -Q.mean()
294 |
295 | # optimize actor
296 | self.actor_opt.zero_grad(set_to_none=True)
297 | actor_loss.backward()
298 | self.actor_opt.step()
299 |
300 | if self.use_tb:
301 | metrics['actor_loss'] = actor_loss.item()
302 | metrics['actor_logprob'] = log_prob.mean().item()
303 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
304 |
305 | return metrics
306 |
307 | def update_taco(self, obs, action, action_seq, next_obs, reward):
308 | metrics = dict()
309 |
310 | obs_anchor = self.aug(obs.float())
311 | obs_pos = self.aug(obs.float())
312 | z_a = self.TACO.encode(obs_anchor)
313 | z_pos = self.TACO.encode(obs_pos, ema=True)
314 | ### Compute CURL loss
315 | if self.curl:
316 | logits = self.TACO.compute_logits(z_a, z_pos)
317 | labels = torch.arange(logits.shape[0]).long().to(self.device)
318 | curl_loss = self.cross_entropy_loss(logits, labels)
319 | else:
320 | curl_loss = torch.tensor(0.)
321 |
322 | ### Compute action encodings
323 | action_en = self.TACO.act_tok(action, seq=False)
324 | action_seq_en = self.TACO.act_tok(action_seq, seq=True)
325 |
326 | ### Compute reward prediction loss
327 | if self.reward:
328 | reward_pred = self.TACO.reward(torch.concat([z_a, action_seq_en], dim=-1))
329 | reward_loss = F.mse_loss(reward_pred, reward)
330 | else:
331 | reward_loss = torch.tensor(0.)
332 |
333 | ### Compute TACO loss
334 | next_z = self.TACO.encode(self.aug(next_obs.float()), ema=True)
335 | curr_za = self.TACO.project_sa(z_a, action_seq_en)
336 | logits = self.TACO.compute_logits(curr_za, next_z)
337 | labels = torch.arange(logits.shape[0]).long().to(self.device)
338 | taco_loss = self.cross_entropy_loss(logits, labels)
339 |
340 | self.taco_opt.zero_grad()
341 | (taco_loss + curl_loss + reward_loss).backward()
342 | self.taco_opt.step()
343 | if self.use_tb:
344 | metrics['reward_loss'] = reward_loss.item()
345 | metrics['curl_loss'] = curl_loss.item()
346 | metrics['taco_loss'] = taco_loss.item()
347 | return metrics
348 |
349 |
350 |
351 | def update(self, replay_iter, step):
352 | metrics = dict()
353 | if step % self.update_every_steps != 0:
354 | return metrics
355 |
356 | batch = next(replay_iter)
357 | obs, action, action_seq, reward, discount, next_obs, r_next_obs = utils.to_torch(
358 | batch, self.device)
359 |
360 | # augment
361 | obs_en = self.aug(obs.float())
362 | next_obs_en = self.aug(next_obs.float())
363 | # encode
364 | obs_en = self.encoder(obs_en)
365 | with torch.no_grad():
366 | next_obs_en = self.encoder(next_obs_en)
367 |
368 | if self.use_tb:
369 | metrics['batch_reward'] = reward.mean().item()
370 |
371 | # update critic
372 | metrics.update(
373 | self.update_critic(obs_en, action, reward, discount, next_obs_en, step))
374 |
375 | # update actor
376 | metrics.update(self.update_actor(obs_en.detach(), step))
377 |
378 | # update critic target
379 | utils.soft_update_params(self.critic, self.critic_target,
380 | self.critic_target_tau)
381 |
382 | metrics.update(self.update_taco(obs, action, action_seq, r_next_obs, reward))
383 |
384 | return metrics
385 |
--------------------------------------------------------------------------------
/cfgs/agent/drqv2.yaml:
--------------------------------------------------------------------------------
1 | agent:
2 | _target_: agents.drqv2.DrQV2Agent
3 | obs_shape: ??? # to be specified later
4 | action_shape: ??? # to be specified later
5 | device: ${device}
6 | lr: ${lr}
7 | critic_target_tau: 0.01
8 | update_every_steps: 2
9 | use_tb: ${use_tb}
10 | num_expl_steps: 2000
11 | hidden_dim: 1024
12 | feature_dim: ${feature_dim}
13 | stddev_schedule: ${stddev_schedule}
14 | stddev_clip: 0.3
15 |
16 | batch_size: 256
--------------------------------------------------------------------------------
/cfgs/agent/taco.yaml:
--------------------------------------------------------------------------------
1 | agent:
2 | _target_: agents.taco.TACOAgent
3 | obs_shape: ??? # to be specified later
4 | action_shape: ??? # to be specified later
5 | device: ${device}
6 | lr: ${lr}
7 | encoder_lr: ${encoder_lr}
8 | critic_target_tau: 0.01
9 | update_every_steps: 2
10 | use_tb: ${use_tb}
11 | num_expl_steps: 2000
12 | hidden_dim: 1024
13 | feature_dim: ${feature_dim}
14 | stddev_schedule: ${stddev_schedule}
15 | stddev_clip: 0.3
16 | curl: ${curl}
17 | reward: ${reward}
18 | multistep: ${multistep}
19 | latent_a_dim: ${latent_a_dim}
20 |
21 | ### TACO parameters
22 | curl: true
23 | reward: true
24 | multistep: 3
25 | latent_a_dim: none
26 | batch_size: 1024
--------------------------------------------------------------------------------
/cfgs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task@_global_: quadruped_walk
4 | - agent@_global_: taco
5 | - override hydra/launcher: submitit_local
6 |
7 | domain: dmc
8 | # task settings
9 | frame_stack: 3
10 | action_repeat: 2
11 | discount: 0.99
12 | # train settings
13 | num_seed_frames: 4000
14 | # eval
15 | eval_every_frames: 10000
16 | num_eval_episodes: 10
17 | # snapshot
18 | save_snapshot: true
19 | # replay buffer
20 | replay_buffer_size: 1000000
21 | replay_buffer_num_workers: 4
22 | nstep: 3
23 | batch_size: 256
24 | # misc
25 | seed: 1
26 | device: cuda
27 | save_video: false
28 | save_train_video: false
29 | use_tb: false
30 | # experiment
31 | experiment: exp
32 | # agent
33 | lr: 1e-4
34 | encoder_lr: 1e-4
35 | feature_dim: 50
36 | exp_name: default
37 |
38 |
39 | hydra:
40 | run:
41 | dir: ./exp_local/${exp_name}
42 | sweep:
43 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment}
44 | subdir: ${hydra.job.num}
45 | launcher:
46 | timeout_min: 4300
47 | cpus_per_task: 10
48 | gpus_per_node: 1
49 | tasks_per_node: 1
50 | mem_gb: 160
51 | nodes: 1
52 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm
--------------------------------------------------------------------------------
/cfgs/task/acrobot_swingup.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: acrobot_swingup
6 |
--------------------------------------------------------------------------------
/cfgs/task/acrobot_swingup_sparse.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: acrobot_swingup_sparse
6 |
--------------------------------------------------------------------------------
/cfgs/task/cartpole_balance.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: cartpole_balance
6 |
--------------------------------------------------------------------------------
/cfgs/task/cartpole_balance_sparse.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: cartpole_balance_sparse
6 |
--------------------------------------------------------------------------------
/cfgs/task/cartpole_swingup.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: cartpole_swingup
6 |
--------------------------------------------------------------------------------
/cfgs/task/cartpole_swingup_sparse.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: cartpole_swingup_sparse
6 |
--------------------------------------------------------------------------------
/cfgs/task/cheetah_run.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: cheetah_run
6 |
--------------------------------------------------------------------------------
/cfgs/task/cup_catch.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: cup_catch
6 |
--------------------------------------------------------------------------------
/cfgs/task/easy.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 1100000
2 | stddev_schedule: 'linear(1.0,0.1,100000)'
--------------------------------------------------------------------------------
/cfgs/task/finger_spin.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: finger_spin
6 |
--------------------------------------------------------------------------------
/cfgs/task/finger_turn_easy.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: finger_turn_easy
--------------------------------------------------------------------------------
/cfgs/task/finger_turn_hard.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: finger_turn_hard
6 |
--------------------------------------------------------------------------------
/cfgs/task/hard.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 30100000
2 | stddev_schedule: 'linear(1.0,0.1,2000000)'
3 |
--------------------------------------------------------------------------------
/cfgs/task/hopper_hop.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: hopper_hop
6 |
--------------------------------------------------------------------------------
/cfgs/task/hopper_stand.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: hopper_stand
6 |
--------------------------------------------------------------------------------
/cfgs/task/manipulator_bring_ball.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - hard
3 | - _self_
4 |
5 | task_name: manipulator_bring_ball
6 |
--------------------------------------------------------------------------------
/cfgs/task/medium.yaml:
--------------------------------------------------------------------------------
1 | num_train_frames: 3100000
2 | stddev_schedule: 'linear(1.0,0.1,500000)'
3 |
--------------------------------------------------------------------------------
/cfgs/task/pendulum_swingup.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: pendulum_swingup
6 |
--------------------------------------------------------------------------------
/cfgs/task/quadruped_run.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: quadruped_run
6 | replay_buffer_size: 100000
--------------------------------------------------------------------------------
/cfgs/task/quadruped_walk.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: quadruped_walk
6 |
--------------------------------------------------------------------------------
/cfgs/task/reach_duplo.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: reach_duplo
6 |
--------------------------------------------------------------------------------
/cfgs/task/reacher_easy.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: reacher_easy
6 |
--------------------------------------------------------------------------------
/cfgs/task/reacher_hard.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: reacher_hard
6 |
--------------------------------------------------------------------------------
/cfgs/task/walker_run.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - medium
3 | - _self_
4 |
5 | task_name: walker_run
6 | nstep: 1
7 | batch_size: 512
--------------------------------------------------------------------------------
/cfgs/task/walker_stand.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: walker_stand
6 | nstep: 1
7 | batch_size: 512
8 |
--------------------------------------------------------------------------------
/cfgs/task/walker_walk.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - easy
3 | - _self_
4 |
5 | task_name: walker_walk
6 | nstep: 1
7 | batch_size: 512
8 |
--------------------------------------------------------------------------------
/dmc.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from typing import Any, NamedTuple
3 |
4 | import dm_env
5 | import numpy as np
6 | from dm_control import manipulation, suite
7 | from dm_control.suite.wrappers import action_scale, pixels
8 | from dm_env import StepType, specs
9 |
10 |
11 | class ExtendedTimeStep(NamedTuple):
12 | step_type: Any
13 | reward: Any
14 | discount: Any
15 | observation: Any
16 | action: Any
17 |
18 | def first(self):
19 | return self.step_type == StepType.FIRST
20 |
21 | def mid(self):
22 | return self.step_type == StepType.MID
23 |
24 | def last(self):
25 | return self.step_type == StepType.LAST
26 |
27 | def __getitem__(self, attr):
28 | if isinstance(attr, str):
29 | return getattr(self, attr)
30 | else:
31 | return tuple.__getitem__(self, attr)
32 |
33 |
34 | class ActionRepeatWrapper(dm_env.Environment):
35 | def __init__(self, env, num_repeats):
36 | self._env = env
37 | self._num_repeats = num_repeats
38 |
39 | def step(self, action):
40 | reward = 0.0
41 | discount = 1.0
42 | for i in range(self._num_repeats):
43 | time_step = self._env.step(action)
44 | reward += (time_step.reward or 0.0) * discount
45 | discount *= time_step.discount
46 | if time_step.last():
47 | break
48 |
49 | return time_step._replace(reward=reward, discount=discount)
50 |
51 | def observation_spec(self):
52 | return self._env.observation_spec()
53 |
54 | def action_spec(self):
55 | return self._env.action_spec()
56 |
57 | def reset(self):
58 | return self._env.reset()
59 |
60 | def __getattr__(self, name):
61 | return getattr(self._env, name)
62 |
63 |
64 | class FrameStackWrapper(dm_env.Environment):
65 | def __init__(self, env, num_frames, pixels_key='pixels'):
66 | self._env = env
67 | self._num_frames = num_frames
68 | self._frames = deque([], maxlen=num_frames)
69 | self._pixels_key = pixels_key
70 |
71 | wrapped_obs_spec = env.observation_spec()
72 | assert pixels_key in wrapped_obs_spec
73 |
74 | pixels_shape = wrapped_obs_spec[pixels_key].shape
75 | # remove batch dim
76 | if len(pixels_shape) == 4:
77 | pixels_shape = pixels_shape[1:]
78 | self._obs_spec = specs.BoundedArray(shape=np.concatenate(
79 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
80 | dtype=np.uint8,
81 | minimum=0,
82 | maximum=255,
83 | name='observation')
84 |
85 | def _transform_observation(self, time_step):
86 | assert len(self._frames) == self._num_frames
87 | obs = np.concatenate(list(self._frames), axis=0)
88 | return time_step._replace(observation=obs)
89 |
90 | def _extract_pixels(self, time_step):
91 | pixels = time_step.observation[self._pixels_key]
92 | # remove batch dim
93 | if len(pixels.shape) == 4:
94 | pixels = pixels[0]
95 | return pixels.transpose(2, 0, 1).copy()
96 |
97 | def reset(self):
98 | time_step = self._env.reset()
99 | pixels = self._extract_pixels(time_step)
100 | for _ in range(self._num_frames):
101 | self._frames.append(pixels)
102 | return self._transform_observation(time_step)
103 |
104 | def step(self, action):
105 | time_step = self._env.step(action)
106 | pixels = self._extract_pixels(time_step)
107 | self._frames.append(pixels)
108 | return self._transform_observation(time_step)
109 |
110 | def observation_spec(self):
111 | return self._obs_spec
112 |
113 | def action_spec(self):
114 | return self._env.action_spec()
115 |
116 | def __getattr__(self, name):
117 | return getattr(self._env, name)
118 |
119 |
120 | class ActionDTypeWrapper(dm_env.Environment):
121 | def __init__(self, env, dtype):
122 | self._env = env
123 | wrapped_action_spec = env.action_spec()
124 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
125 | dtype,
126 | wrapped_action_spec.minimum,
127 | wrapped_action_spec.maximum,
128 | 'action')
129 |
130 | def step(self, action):
131 | action = action.astype(self._env.action_spec().dtype)
132 | return self._env.step(action)
133 |
134 | def observation_spec(self):
135 | return self._env.observation_spec()
136 |
137 | def action_spec(self):
138 | return self._action_spec
139 |
140 | def reset(self):
141 | return self._env.reset()
142 |
143 | def __getattr__(self, name):
144 | return getattr(self._env, name)
145 |
146 |
147 | class ExtendedTimeStepWrapper(dm_env.Environment):
148 | def __init__(self, env):
149 | self._env = env
150 |
151 | def reset(self):
152 | time_step = self._env.reset()
153 | return self._augment_time_step(time_step)
154 |
155 | def step(self, action):
156 | time_step = self._env.step(action)
157 | return self._augment_time_step(time_step, action)
158 |
159 | def _augment_time_step(self, time_step, action=None):
160 | if action is None:
161 | action_spec = self.action_spec()
162 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
163 | return ExtendedTimeStep(observation=time_step.observation,
164 | step_type=time_step.step_type,
165 | action=action,
166 | reward=time_step.reward or 0.0,
167 | discount=time_step.discount or 1.0)
168 |
169 | def observation_spec(self):
170 | return self._env.observation_spec()
171 |
172 | def action_spec(self):
173 | return self._env.action_spec()
174 |
175 | def __getattr__(self, name):
176 | return getattr(self._env, name)
177 |
178 |
179 | def make(name, frame_stack, action_repeat, seed):
180 | domain, task = name.split('_', 1)
181 | # overwrite cup to ball_in_cup
182 | domain = dict(cup='ball_in_cup').get(domain, domain)
183 | # make sure reward is not visualized
184 | if (domain, task) in suite.ALL_TASKS:
185 | env = suite.load(domain,
186 | task,
187 | task_kwargs={'random': seed},
188 | visualize_reward=False)
189 | pixels_key = 'pixels'
190 | else:
191 | name = f'{domain}_{task}_vision'
192 | env = manipulation.load(name, seed=seed)
193 | pixels_key = 'front_close'
194 | # add wrappers
195 | env = ActionDTypeWrapper(env, np.float32)
196 | env = ActionRepeatWrapper(env, action_repeat)
197 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
198 | # add renderings for classical tasks
199 | if (domain, task) in suite.ALL_TASKS:
200 | # zoom in camera for quadruped
201 | camera_id = dict(quadruped=2).get(domain, 0)
202 | render_kwargs = dict(height=84, width=84, camera_id=camera_id)
203 | env = pixels.Wrapper(env,
204 | pixels_only=True,
205 | render_kwargs=render_kwargs)
206 | # stack several frames
207 | env = FrameStackWrapper(env, frame_stack, pixels_key)
208 | env = ExtendedTimeStepWrapper(env)
209 | return env
210 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: taco
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.8
6 | - pip=21.1.3
7 | - numpy=1.19.2
8 | - absl-py=0.13.0
9 | - pyparsing=2.4.7
10 | - jupyterlab=3.0.14
11 | - scikit-image=0.18.1
12 | - nvidia::cudatoolkit=11.1
13 | - pytorch::pytorch
14 | - pytorch::torchvision
15 | - pytorch::torchaudio
16 | - pip:
17 | - chardet
18 | - gym
19 | - termcolor==1.1.0
20 | - dm_control
21 | - tb-nightly
22 | - imageio==2.9.0
23 | - imageio-ffmpeg==0.4.4
24 | - hydra-core==1.1.0
25 | - hydra-submitit-launcher==1.1.5
26 | - pandas==1.3.0
27 | - ipdb==0.13.9
28 | - yapf==0.31.0
29 | - mujoco_py==2.0.2.13
30 | - sklearn==0.0
31 | - matplotlib==3.4.2
32 | - opencv-python==4.5.3.56
33 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from termcolor import colored
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
12 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
13 | ('episode_reward', 'R', 'float'),
14 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'),
15 | ('total_time', 'T', 'time')]
16 |
17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
19 | ('episode_reward', 'R', 'float'),
20 | ('total_time', 'T', 'time')]
21 |
22 |
23 | class AverageMeter(object):
24 | def __init__(self):
25 | self._sum = 0
26 | self._count = 0
27 |
28 | def update(self, value, n=1):
29 | self._sum += value
30 | self._count += n
31 |
32 | def value(self):
33 | return self._sum / max(1, self._count)
34 |
35 |
36 | class MetersGroup(object):
37 | def __init__(self, csv_file_name, formating):
38 | self._csv_file_name = csv_file_name
39 | self._formating = formating
40 | self._meters = defaultdict(AverageMeter)
41 | self._csv_file = None
42 | self._csv_writer = None
43 |
44 | def log(self, key, value, n=1):
45 | self._meters[key].update(value, n)
46 |
47 | def _prime_meters(self):
48 | data = dict()
49 | for key, meter in self._meters.items():
50 | if key.startswith('train'):
51 | key = key[len('train') + 1:]
52 | else:
53 | key = key[len('eval') + 1:]
54 | key = key.replace('/', '_')
55 | data[key] = meter.value()
56 | return data
57 |
58 | def _remove_old_entries(self, data):
59 | rows = []
60 | with self._csv_file_name.open('r') as f:
61 | reader = csv.DictReader(f)
62 | for row in reader:
63 | if float(row['episode']) >= data['episode']:
64 | break
65 | rows.append(row)
66 | with self._csv_file_name.open('w') as f:
67 | writer = csv.DictWriter(f,
68 | fieldnames=sorted(data.keys()),
69 | restval=0.0)
70 | writer.writeheader()
71 | for row in rows:
72 | writer.writerow(row)
73 |
74 | def _dump_to_csv(self, data):
75 | if self._csv_writer is None:
76 | should_write_header = True
77 | if self._csv_file_name.exists():
78 | self._remove_old_entries(data)
79 | should_write_header = False
80 |
81 | self._csv_file = self._csv_file_name.open('a')
82 | self._csv_writer = csv.DictWriter(self._csv_file,
83 | fieldnames=sorted(data.keys()),
84 | restval=0.0)
85 | if should_write_header:
86 | self._csv_writer.writeheader()
87 |
88 | self._csv_writer.writerow(data)
89 | self._csv_file.flush()
90 |
91 | def _format(self, key, value, ty):
92 | if ty == 'int':
93 | value = int(value)
94 | return f'{key}: {value}'
95 | elif ty == 'float':
96 | return f'{key}: {value:.04f}'
97 | elif ty == 'time':
98 | value = str(datetime.timedelta(seconds=int(value)))
99 | return f'{key}: {value}'
100 | else:
101 | raise ValueError(f'invalid format type: {ty}')
102 |
103 | def _dump_to_console(self, data, prefix):
104 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
105 | pieces = [f'| {prefix: <14}']
106 | for key, disp_key, ty in self._formating:
107 | value = data.get(key, 0)
108 | pieces.append(self._format(disp_key, value, ty))
109 | print(' | '.join(pieces), flush=True)
110 |
111 | def dump(self, step, prefix):
112 | if len(self._meters) == 0:
113 | return
114 | data = self._prime_meters()
115 | data['frame'] = step
116 | self._dump_to_csv(data)
117 | self._dump_to_console(data, prefix)
118 | self._meters.clear()
119 |
120 |
121 | class Logger(object):
122 | def __init__(self, log_dir, use_tb):
123 | self._log_dir = log_dir
124 | self._train_mg = MetersGroup(log_dir / 'train.csv',
125 | formating=COMMON_TRAIN_FORMAT)
126 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
127 | formating=COMMON_EVAL_FORMAT)
128 | if use_tb:
129 | self._sw = SummaryWriter(str(log_dir / 'tb'))
130 | else:
131 | self._sw = None
132 |
133 | def _try_sw_log(self, key, value, step):
134 | if self._sw is not None:
135 | self._sw.add_scalar(key, value, step)
136 |
137 | def log(self, key, value, step):
138 | assert key.startswith('train') or key.startswith('eval')
139 | if type(value) == torch.Tensor:
140 | value = value.item()
141 | self._try_sw_log(key, value, step)
142 | mg = self._train_mg if key.startswith('train') else self._eval_mg
143 | mg.log(key, value)
144 |
145 | def log_metrics(self, metrics, step, ty):
146 | for key, value in metrics.items():
147 | self.log(f'{ty}/{key}', value, step)
148 |
149 | def dump(self, step, ty=None):
150 | if ty is None or ty == 'eval':
151 | self._eval_mg.dump(step, 'eval')
152 | if ty is None or ty == 'train':
153 | self._train_mg.dump(step, 'train')
154 |
155 | def log_and_dump_ctx(self, step, ty):
156 | return LogAndDumpCtx(self, step, ty)
157 |
158 |
159 | class LogAndDumpCtx:
160 | def __init__(self, logger, step, ty):
161 | self._logger = logger
162 | self._step = step
163 | self._ty = ty
164 |
165 | def __enter__(self):
166 | return self
167 |
168 | def __call__(self, key, value):
169 | self._logger.log(f'{self._ty}/{key}', value, self._step)
170 |
171 | def __exit__(self, *args):
172 | self._logger.dump(self._step, self._ty)
--------------------------------------------------------------------------------
/media/dmc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankZheng2022/TACO/84c38e34f4f9dfd2b059fb6d1356757e8d40712e/media/dmc.gif
--------------------------------------------------------------------------------
/media/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankZheng2022/TACO/84c38e34f4f9dfd2b059fb6d1356757e8d40712e/media/overview.png
--------------------------------------------------------------------------------
/media/taco.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FrankZheng2022/TACO/84c38e34f4f9dfd2b059fb6d1356757e8d40712e/media/taco.jpg
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import io
3 | import random
4 | import traceback
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import IterableDataset
11 |
12 |
13 | def episode_len(episode):
14 | # subtract 1 because of the dummy first transition
15 | return next(iter(episode.values())).shape[0] - 1
16 |
17 |
18 | def save_episode(episode, fn):
19 | with io.BytesIO() as bs:
20 | np.savez_compressed(bs, **episode)
21 | bs.seek(0)
22 | with fn.open('wb') as f:
23 | f.write(bs.read())
24 |
25 |
26 | def load_episode(fn):
27 | with fn.open('rb') as f:
28 | episode = np.load(f)
29 | episode = {k: episode[k] for k in episode.keys()}
30 | return episode
31 |
32 |
33 | class ReplayBufferStorage:
34 | def __init__(self, data_specs, replay_dir):
35 | self._data_specs = data_specs
36 | self._replay_dir = replay_dir
37 | replay_dir.mkdir(exist_ok=True)
38 | self._current_episode = defaultdict(list)
39 | self._preload()
40 |
41 | def __len__(self):
42 | return self._num_transitions
43 |
44 | def add(self, time_step):
45 | for spec in self._data_specs:
46 | value = time_step[spec.name]
47 | if np.isscalar(value):
48 | value = np.full(spec.shape, value, spec.dtype)
49 | assert spec.shape == value.shape and spec.dtype == value.dtype
50 | self._current_episode[spec.name].append(value)
51 | if time_step.last():
52 | episode = dict()
53 | for spec in self._data_specs:
54 | value = self._current_episode[spec.name]
55 | episode[spec.name] = np.array(value, spec.dtype)
56 | self._current_episode = defaultdict(list)
57 | self._store_episode(episode)
58 |
59 | def _preload(self):
60 | self._num_episodes = 0
61 | self._num_transitions = 0
62 | for fn in self._replay_dir.glob('*.npz'):
63 | _, _, eps_len = fn.stem.split('_')
64 | self._num_episodes += 1
65 | self._num_transitions += int(eps_len)
66 |
67 | def _store_episode(self, episode):
68 | eps_idx = self._num_episodes
69 | eps_len = episode_len(episode)
70 | self._num_episodes += 1
71 | self._num_transitions += eps_len
72 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
73 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
74 | save_episode(episode, self._replay_dir / eps_fn)
75 |
76 |
77 | class ReplayBuffer(IterableDataset):
78 | def __init__(self, replay_dir, max_size, num_workers, nstep, multistep,
79 | discount, fetch_every, save_snapshot):
80 | self._replay_dir = replay_dir
81 | self._size = 0
82 | self._max_size = max_size
83 | self._num_workers = max(1, num_workers)
84 | self._episode_fns = []
85 | self._episodes = dict()
86 | self._nstep = nstep
87 | self._discount = discount
88 | self._fetch_every = fetch_every
89 | self._samples_since_last_fetch = fetch_every
90 | self._save_snapshot = save_snapshot
91 | self._multistep = multistep
92 | print('Loading Data into CPU Memory')
93 | self._preload()
94 |
95 | def _sample_episode(self):
96 | eps_fn = random.choice(self._episode_fns)
97 | return self._episodes[eps_fn]
98 |
99 | def _store_episode(self, eps_fn):
100 | try:
101 | episode = load_episode(eps_fn)
102 |         except Exception:  # skip episodes that fail to load (e.g. partially written files)
103 | return False
104 | eps_len = episode_len(episode)
105 | while eps_len + self._size > self._max_size:
106 | early_eps_fn = self._episode_fns.pop(0)
107 | early_eps = self._episodes.pop(early_eps_fn)
108 | self._size -= episode_len(early_eps)
109 | early_eps_fn.unlink(missing_ok=True)
110 | self._episode_fns.append(eps_fn)
111 | self._episode_fns.sort()
112 | self._episodes[eps_fn] = episode
113 | self._size += eps_len
114 |
115 | if not self._save_snapshot:
116 | eps_fn.unlink(missing_ok=True)
117 | return True
118 |
119 | def _try_fetch(self):
120 | if self._samples_since_last_fetch < self._fetch_every:
121 | return
122 | self._samples_since_last_fetch = 0
123 | try:
124 | worker_id = torch.utils.data.get_worker_info().id
125 |         except Exception:  # get_worker_info() is None outside a DataLoader worker
126 | worker_id = 0
127 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
128 | fetched_size = 0
129 | for eps_fn in eps_fns:
130 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
131 | if eps_idx % self._num_workers != worker_id:
132 | continue
133 | if eps_fn in self._episodes.keys():
134 | break
135 | if fetched_size + eps_len > self._max_size:
136 | break
137 | fetched_size += eps_len
138 | if not self._store_episode(eps_fn):
139 | break
140 |
141 | def _preload(self):
142 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
143 | for eps_fn in eps_fns:
144 | self._store_episode(eps_fn)
145 |
146 | def _sample(self):
147 | try:
148 | self._try_fetch()
149 |         except Exception:
150 | traceback.print_exc()
151 | self._samples_since_last_fetch += 1
152 | episode = self._sample_episode()
153 |         # offset the index by +1 to skip the dummy first transition
154 | n_step = max(self._nstep, self._multistep)
155 | idx = np.random.randint(0, episode_len(episode) - n_step + 1) + 1
156 | obs = episode['observation'][idx - 1]
157 | r_next_obs = episode['observation'][idx + self._multistep - 1]
158 | action = episode['action'][idx]
159 | action_seq = np.concatenate([episode['action'][idx+i][None, :] for i in range(self._multistep)])
160 | next_obs = episode['observation'][idx + self._nstep - 1]
161 | reward = np.zeros_like(episode['reward'][idx])
162 | discount = np.ones_like(episode['discount'][idx])
163 | for i in range(self._nstep):
164 | step_reward = episode['reward'][idx + i]
165 | reward += discount * step_reward
166 | discount *= episode['discount'][idx + i] * self._discount
167 | return (obs, action, action_seq, reward, discount, next_obs, r_next_obs)
168 |
169 | def __iter__(self):
170 | while True:
171 | yield self._sample()
172 |
173 |
174 | def _worker_init_fn(worker_id):
175 | seed = np.random.get_state()[1][0] + worker_id
176 | np.random.seed(seed)
177 | random.seed(seed)
178 |
179 |
180 | def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
181 | save_snapshot, nstep, multistep, discount):
182 | max_size_per_worker = max_size // max(1, num_workers)
183 |
184 | iterable = ReplayBuffer(replay_dir,
185 | max_size_per_worker,
186 | num_workers,
187 | nstep,
188 | multistep,
189 | discount,
190 | fetch_every=1000,
191 | save_snapshot=save_snapshot)
192 |
193 | loader = torch.utils.data.DataLoader(iterable,
194 | batch_size=batch_size,
195 | num_workers=num_workers,
196 | pin_memory=True,
197 | worker_init_fn=_worker_init_fn)
198 | return loader
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
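
A minimal usage sketch (not part of the repository) of driving the replay pipeline above on its own, assuming `./buffer` already contains `.npz` episodes written by `ReplayBufferStorage`; the directory, batch size, n-step, and discount values are placeholders, not the repo's configured defaults.

```python
# Sketch only: build a loader over previously saved episodes and draw one batch.
from pathlib import Path

from replay_buffer import make_replay_loader

if __name__ == '__main__':
    loader = make_replay_loader(Path('./buffer'),      # directory of .npz episodes (assumed to exist)
                                max_size=100_000,
                                batch_size=256,
                                num_workers=2,
                                save_snapshot=False,
                                nstep=3,                # horizon of the discounted reward sum
                                multistep=3,            # length of the returned action sequence
                                discount=0.99)

    # Each batch mirrors the tuple returned by ReplayBuffer._sample().
    obs, action, action_seq, reward, discount, next_obs, r_next_obs = next(iter(loader))
    print(obs.shape, action_seq.shape, reward.shape)
```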
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings('ignore', category=DeprecationWarning)
3 |
4 | import os
5 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
6 | os.environ['MUJOCO_GL'] = 'egl'
7 |
8 | from pathlib import Path
9 |
10 | import hydra
11 | import numpy as np
12 | import torch
13 | from dm_env import specs
14 |
15 | import dmc
16 | import utils
17 | from logger import Logger
18 | from replay_buffer import ReplayBufferStorage, make_replay_loader
19 | from video import TrainVideoRecorder, VideoRecorder
20 |
21 | torch.backends.cudnn.benchmark = True
22 |
23 |
24 | def make_agent(obs_spec, action_spec, cfg):
25 | cfg.obs_shape = obs_spec.shape
26 | cfg.action_shape = action_spec.shape
27 | return hydra.utils.instantiate(cfg)
28 |
29 |
30 | class Workspace:
31 | def __init__(self, cfg):
32 | self.work_dir = Path.cwd()
33 | print(f'workspace: {self.work_dir}')
34 |
35 | self.cfg = cfg
36 | utils.set_seed_everywhere(cfg.seed)
37 | self.device = torch.device(cfg.device)
38 | self.setup()
39 |
40 | self.agent = make_agent(self.train_env.observation_spec(),
41 | self.train_env.action_spec(),
42 | self.cfg.agent)
43 | self.timer = utils.Timer()
44 | self._global_step = 0
45 | self._global_episode = 0
46 |
47 | def setup(self):
48 | # create logger
49 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb)
50 | self.train_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack,
51 | self.cfg.action_repeat, self.cfg.seed)
52 | self.eval_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack,
53 | self.cfg.action_repeat, self.cfg.seed)
54 | # create replay buffer
55 | data_specs = (self.train_env.observation_spec(),
56 | self.train_env.action_spec(),
57 | specs.Array((1,), np.float32, 'reward'),
58 | specs.Array((1,), np.float32, 'discount'))
59 |
60 | self.replay_storage = ReplayBufferStorage(data_specs,
61 | self.work_dir / 'buffer')
62 |
63 | self.replay_loader = make_replay_loader(
64 | self.work_dir / 'buffer', self.cfg.replay_buffer_size,
65 | self.cfg.batch_size, self.cfg.replay_buffer_num_workers,
66 | self.cfg.save_snapshot, self.cfg.nstep, self.cfg.multistep, self.cfg.discount)
67 | self._replay_iter = None
68 |
69 | self.video_recorder = VideoRecorder(
70 | self.work_dir if self.cfg.save_video else None)
71 | self.train_video_recorder = TrainVideoRecorder(
72 | self.work_dir if self.cfg.save_train_video else None)
73 |
74 |
75 | @property
76 | def global_step(self):
77 | return self._global_step
78 |
79 | @property
80 | def global_episode(self):
81 | return self._global_episode
82 |
83 | @property
84 | def global_frame(self):
85 | return self.global_step * self.cfg.action_repeat
86 |
87 | @property
88 | def replay_iter(self):
89 | if self._replay_iter is None:
90 | self._replay_iter = iter(self.replay_loader)
91 | return self._replay_iter
92 |
93 | def eval(self):
94 | step, episode, total_reward, success = 0, 0, 0, 0
95 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
96 |
97 | while eval_until_episode(episode):
98 | time_step = self.eval_env.reset()
99 | self.video_recorder.init(self.eval_env, enabled=(episode == 0))
100 | while not time_step.last():
101 | with torch.no_grad(), utils.eval_mode(self.agent):
102 | action = self.agent.act(time_step.observation,
103 | self.global_step,
104 | eval_mode=True)
105 | time_step = self.eval_env.step(action)
106 | self.video_recorder.record(self.eval_env)
107 | total_reward += time_step.reward
108 | step += 1
109 |
110 | episode += 1
111 | self.video_recorder.save(f'{self.global_frame}.mp4')
112 |
113 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
114 | log('episode_reward', total_reward / episode)
115 | log('episode_length', step * self.cfg.action_repeat / episode)
116 | log('episode', self.global_episode)
117 | log('step', self.global_step)
118 |
119 | def train(self):
120 | # predicates
121 | train_until_step = utils.Until(self.cfg.num_train_frames,
122 | self.cfg.action_repeat)
123 | seed_until_step = utils.Until(self.cfg.num_seed_frames,
124 | self.cfg.action_repeat)
125 | eval_every_step = utils.Every(self.cfg.eval_every_frames,
126 | self.cfg.action_repeat)
127 |
128 | episode_step, episode_reward = 0, 0
129 | time_step = self.train_env.reset()
130 | self.replay_storage.add(time_step)
131 | self.train_video_recorder.init(time_step.observation)
132 | metrics = None
133 | while train_until_step(self.global_step):
134 | if time_step.last():
135 | self._global_episode += 1
136 | self.train_video_recorder.save(f'{self.global_frame}.mp4')
137 | # wait until all the metrics schema is populated
138 | if metrics is not None:
139 | # log stats
140 | elapsed_time, total_time = self.timer.reset()
141 | episode_frame = episode_step * self.cfg.action_repeat
142 | with self.logger.log_and_dump_ctx(self.global_frame,
143 | ty='train') as log:
144 | log('fps', episode_frame / elapsed_time)
145 | log('total_time', total_time)
146 | log('episode_reward', episode_reward)
147 | log('episode_length', episode_frame)
148 | log('episode', self.global_episode)
149 | log('buffer_size', len(self.replay_storage))
150 | log('step', self.global_step)
151 |
152 | # reset env
153 | time_step = self.train_env.reset()
154 | self.replay_storage.add(time_step)
155 | self.train_video_recorder.init(time_step.observation)
156 | # try to save snapshot
157 | if self.cfg.save_snapshot:
158 | self.save_snapshot()
159 | episode_step = 0
160 | episode_reward = 0
161 |
162 | # try to evaluate
163 | if eval_every_step(self.global_step):
164 | self.logger.log('eval_total_time', self.timer.total_time(),
165 | self.global_frame)
166 | self.eval()
167 |
168 | # sample action
169 | with torch.no_grad(), utils.eval_mode(self.agent):
170 | action = self.agent.act(time_step.observation,
171 | self.global_step,
172 | eval_mode=False)
173 |
174 | # try to update the agent
175 | if not seed_until_step(self.global_step):
176 | metrics = self.agent.update(self.replay_iter, self.global_step)
177 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
178 |
179 | # take env step
180 | time_step = self.train_env.step(action)
181 | episode_reward += time_step.reward
182 | self.replay_storage.add(time_step)
183 | self.train_video_recorder.record(time_step.observation)
184 | episode_step += 1
185 | self._global_step += 1
186 |
187 | def save_snapshot(self):
188 | snapshot = self.work_dir / 'snapshot.pt'
189 | keys_to_save = ['agent', 'timer', '_global_step', '_global_episode']
190 | payload = {k: self.__dict__[k] for k in keys_to_save}
191 | with snapshot.open('wb') as f:
192 | torch.save(payload, f)
193 |
194 | def load_snapshot(self):
195 | snapshot = self.work_dir / 'snapshot.pt'
196 | with snapshot.open('rb') as f:
197 | payload = torch.load(f)
198 | for k, v in payload.items():
199 | self.__dict__[k] = v
200 |
201 |
202 | @hydra.main(config_path='cfgs', config_name='config')
203 | def main(cfg):
204 | from train import Workspace as W
205 | root_dir = Path.cwd()
206 | workspace = W(cfg)
207 | snapshot = root_dir / 'snapshot.pt'
208 | if snapshot.exists():
209 | print(f'resuming: {snapshot}')
210 | workspace.load_snapshot()
211 | workspace.train()
212 |
213 |
214 | if __name__ == '__main__':
215 | main()
216 |
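
For reference, a small sketch (not part of the repository) of the snapshot contract: `save_snapshot()` serializes `agent`, `timer`, `_global_step`, and `_global_episode` to `snapshot.pt` in the Hydra run directory, and `main()` resumes from it if the file exists. The path below is a placeholder, and unpickling the agent requires running from the repo root so its classes are importable.

```python
# Sketch only: inspect a snapshot written by Workspace.save_snapshot().
import torch

payload = torch.load('snapshot.pt', map_location='cpu')  # placeholder path inside a Hydra run dir
print(sorted(payload))            # ['_global_episode', '_global_step', 'agent', 'timer']
print(payload['_global_step'])    # environment steps so far (frames = steps * action_repeat)
```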
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from omegaconf import OmegaConf
10 | from torch import distributions as pyd
11 | from torch.distributions.utils import _standard_normal
12 |
13 | ### seq=True:  input (batch_size, multistep, action_dim) -> output (batch_size, latent_action_dim * multistep)
14 | ### seq=False: input (batch_size, action_dim)             -> output (batch_size, latent_action_dim)
15 | class ActionEncoding(nn.Module):
16 | def __init__(self, action_dim, latent_action_dim, multistep):
17 | super().__init__()
18 | self.action_dim = action_dim
19 | self.action_tokenizer = nn.Sequential(
20 | nn.Linear(action_dim, 64), nn.Tanh(),
21 | nn.Linear(64, latent_action_dim)
22 | )
23 | self.action_seq_tokenizer = nn.Sequential(
24 | nn.Linear(latent_action_dim*multistep, latent_action_dim*multistep),
25 | nn.LayerNorm(latent_action_dim*multistep), nn.Tanh()
26 | )
27 | self.apply(weight_init)
28 |
29 | def forward(self, action, seq=False):
30 | if seq:
31 | batch_size = action.shape[0]
32 |             action = self.action_tokenizer(action) # (batch_size, multistep, latent_action_dim)
33 | action = action.reshape(batch_size, -1)
34 | return self.action_seq_tokenizer(action)
35 | else:
36 | return self.action_tokenizer(action)
37 |
38 |
39 |
40 | class eval_mode:
41 | def __init__(self, *models):
42 | self.models = models
43 |
44 | def __enter__(self):
45 | self.prev_states = []
46 | for model in self.models:
47 | self.prev_states.append(model.training)
48 | model.train(False)
49 |
50 | def __exit__(self, *args):
51 | for model, state in zip(self.models, self.prev_states):
52 | model.train(state)
53 | return False
54 |
55 |
56 | def set_seed_everywhere(seed):
57 | torch.manual_seed(seed)
58 | if torch.cuda.is_available():
59 | torch.cuda.manual_seed_all(seed)
60 | np.random.seed(seed)
61 | random.seed(seed)
62 |
63 |
64 | def soft_update_params(net, target_net, tau):
65 | for param, target_param in zip(net.parameters(), target_net.parameters()):
66 | target_param.data.copy_(tau * param.data +
67 | (1 - tau) * target_param.data)
68 |
69 | def expectile_loss(diff, expectile=0.8):
70 | weight = torch.where(diff > 0, expectile, (1 - expectile))
71 | return weight * (diff**2)
72 |
73 |
74 | def to_torch(xs, device):
75 | return tuple(torch.as_tensor(x, device=device) for x in xs)
76 |
77 | def encode_multiple(encoder, xs, detach_lst):
78 | length = [x.shape[0] for x in xs]
79 | xs, xs_lst = torch.cat(xs, dim=0), []
80 | xs = encoder(xs)
81 | start = 0
82 | for i in range(len(detach_lst)):
83 | x = xs[start:start+length[i], :]
84 | if detach_lst[i]:
85 | x = x.detach()
86 | xs_lst.append(x)
87 | start += length[i]
88 | return xs_lst
89 |
90 |
91 |
92 | def weight_init(m):
93 | if isinstance(m, nn.Linear):
94 | nn.init.orthogonal_(m.weight.data)
95 | if hasattr(m.bias, 'data'):
96 | m.bias.data.fill_(0.0)
97 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
98 | gain = nn.init.calculate_gain('relu')
99 | nn.init.orthogonal_(m.weight.data, gain)
100 | if hasattr(m.bias, 'data'):
101 | m.bias.data.fill_(0.0)
102 |
103 |
104 | class Until:
105 | def __init__(self, until, action_repeat=1):
106 | self._until = until
107 | self._action_repeat = action_repeat
108 |
109 | def __call__(self, step):
110 | if self._until is None:
111 | return True
112 | until = self._until // self._action_repeat
113 | return step < until
114 |
115 |
116 | class Every:
117 | def __init__(self, every, action_repeat=1):
118 | self._every = every
119 | self._action_repeat = action_repeat
120 |
121 | def __call__(self, step):
122 | if self._every is None:
123 | return False
124 | every = self._every // self._action_repeat
125 | if step % every == 0:
126 | return True
127 | return False
128 |
129 |
130 | class Timer:
131 | def __init__(self):
132 | self._start_time = time.time()
133 | self._last_time = time.time()
134 |
135 | def reset(self):
136 | elapsed_time = time.time() - self._last_time
137 | self._last_time = time.time()
138 | total_time = time.time() - self._start_time
139 | return elapsed_time, total_time
140 |
141 | def total_time(self):
142 | return time.time() - self._start_time
143 |
144 |
145 | class TruncatedNormal(pyd.Normal):
146 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
147 | super().__init__(loc, scale, validate_args=False)
148 | self.low = low
149 | self.high = high
150 | self.eps = eps
151 |
152 | def _clamp(self, x):
153 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
154 | x = x - x.detach() + clamped_x.detach()
155 | return x
156 |
157 | def sample(self, clip=None, sample_shape=torch.Size()):
158 | shape = self._extended_shape(sample_shape)
159 | eps = _standard_normal(shape,
160 | dtype=self.loc.dtype,
161 | device=self.loc.device)
162 | eps *= self.scale
163 | if clip is not None:
164 | eps = torch.clamp(eps, -clip, clip)
165 | x = self.loc + eps
166 | return self._clamp(x)
167 |
168 |
169 | def schedule(schdl, step):
170 | try:
171 | return float(schdl)
172 | except ValueError:
173 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
174 | if match:
175 | init, final, duration = [float(g) for g in match.groups()]
176 | mix = np.clip(step / duration, 0.0, 1.0)
177 | return (1.0 - mix) * init + mix * final
178 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
179 | if match:
180 | init, final1, duration1, final2, duration2 = [
181 | float(g) for g in match.groups()
182 | ]
183 | if step <= duration1:
184 | mix = np.clip(step / duration1, 0.0, 1.0)
185 | return (1.0 - mix) * init + mix * final1
186 | else:
187 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
188 | return (1.0 - mix) * final1 + mix * final2
189 | raise NotImplementedError(schdl)
190 |
191 |
192 |
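
A quick worked example of the two string formats parsed by `schedule()` above: `linear(init, final, duration)` interpolates from `init` to `final` over `duration` steps (clipped afterwards), and `step_linear(init, final1, duration1, final2, duration2)` chains two linear segments. A bare number is treated as a constant. The concrete values below are arbitrary, chosen only to make the arithmetic easy to check.

```python
# Sketch only: evaluating schedule strings at a few steps.
from utils import schedule

print(schedule('linear(1.0,0.1,100)', 0))    # 1.0
print(schedule('linear(1.0,0.1,100)', 50))   # 0.55  (halfway between 1.0 and 0.1)
print(schedule('linear(1.0,0.1,100)', 200))  # 0.1   (clipped after `duration`)

# Second segment: halfway from final1=0.5 to final2=0.1.
print(schedule('step_linear(1.0,0.5,100,0.1,100)', 150))  # ~0.3

# A bare number is a constant schedule.
print(schedule('0.2', 12345))  # 0.2
```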
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 |
5 |
6 | class VideoRecorder:
7 | def __init__(self, root_dir, render_size=256, fps=20):
8 | if root_dir is not None:
9 | self.save_dir = root_dir / 'eval_video'
10 | self.save_dir.mkdir(exist_ok=True)
11 | else:
12 | self.save_dir = None
13 |
14 | self.render_size = render_size
15 | self.fps = fps
16 | self.frames = []
17 |
18 | def init(self, env, enabled=True):
19 | self.frames = []
20 | self.enabled = self.save_dir is not None and enabled
21 | self.record(env)
22 |
23 | def record(self, env):
24 | if self.enabled:
25 | if hasattr(env, 'physics'):
26 | frame = env.physics.render(height=self.render_size,
27 | width=self.render_size,
28 | camera_id=0)
29 | else:
30 | frame = env.render()
31 | self.frames.append(frame)
32 |
33 | def save(self, file_name):
34 | if self.enabled:
35 | path = self.save_dir / file_name
36 | imageio.mimsave(str(path), self.frames, fps=self.fps)
37 |
38 |
39 | class TrainVideoRecorder:
40 | def __init__(self, root_dir, render_size=256, fps=20):
41 | if root_dir is not None:
42 | self.save_dir = root_dir / 'train_video'
43 | self.save_dir.mkdir(exist_ok=True)
44 | else:
45 | self.save_dir = None
46 |
47 | self.render_size = render_size
48 | self.fps = fps
49 | self.frames = []
50 |
51 | def init(self, obs, enabled=True):
52 | self.frames = []
53 | self.enabled = self.save_dir is not None and enabled
54 | self.record(obs)
55 |
56 | def record(self, obs):
57 | if self.enabled:
58 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
59 | dsize=(self.render_size, self.render_size),
60 | interpolation=cv2.INTER_CUBIC)
61 | self.frames.append(frame)
62 |
63 | def save(self, file_name):
64 | if self.enabled:
65 | path = self.save_dir / file_name
66 | imageio.mimsave(str(path), self.frames, fps=self.fps)
67 |
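
A minimal sketch (not part of the repository) of how these recorders are driven, mirroring the eval loop in train.py. `dmc.make` is called positionally as train.py does, `(task_name, frame_stack, action_repeat, seed)`; the task name and wrapper arguments are illustrative, and rendering requires the MuJoCo/EGL setup from environment.yml.

```python
# Sketch only: record one episode of a random policy with VideoRecorder.
from pathlib import Path

import numpy as np

import dmc
from video import VideoRecorder

env = dmc.make('walker_walk', 3, 2, 0)        # illustrative task and wrapper settings
recorder = VideoRecorder(Path.cwd())          # MP4s go to ./eval_video
spec = env.action_spec()

time_step = env.reset()
recorder.init(env, enabled=True)              # clears frames and records the first one
while not time_step.last():
    action = np.random.uniform(-1.0, 1.0, size=spec.shape).astype(spec.dtype)
    time_step = env.step(action)
    recorder.record(env)                      # renders via env.physics at 256x256
recorder.save('random_walker.mp4')
```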
--------------------------------------------------------------------------------