Demonstration-Guided Reinforcement Learning with Learned Skills
Conference on Robot Learning (CoRL), 2021
Demonstration-guided reinforcement learning (RL) is a promising approach for learning complex behaviors by leveraging both reward feedback and a set of target task demonstrations. Prior approaches for demonstration-guided RL treat every new task as an independent learning problem and attempt to follow the provided demonstrations step-by-step, akin to a human trying to imitate a completely unseen behavior by following the demonstrator's exact muscle movements. Naturally, such learning will be slow, but often new behaviors are not completely unseen: they share subtasks with behaviors we have previously learned. In this work, we aim to exploit this shared subtask structure to increase the efficiency of demonstration-guided RL. We first learn a set of reusable skills from large offline datasets of prior experience collected across many tasks. We then propose Skill-based Learning with Demonstrations (SkiLD), an algorithm for demonstration-guided RL that efficiently leverages the provided demonstrations by following the demonstrated skills instead of the primitive actions, resulting in substantial performance improvements over prior demonstration-guided RL approaches. We validate the effectiveness of our approach on long-horizon maze navigation and complex robot manipulation tasks.
Overview
Our goal is to use skills extracted from prior experience to improve the efficiency of demonstration-guided RL on a new task. We aim to leverage a set of provided demonstrations by following the performed skills as opposed to the primitive actions.
Learning in our approach, SkiLD, is performed in three stages (see the sketch below):
(1) First, we extract a set of reusable skills from prior, task-agnostic experience. We build on prior work in skill-based RL for learning the skill extraction module (SPiRL, Pertsch et al. 2020).
(2) We then use the pre-trained skill encoder to infer the skills performed in the task-agnostic and demonstration sequences and learn state-conditioned skill distributions, which we call the skill prior and the skill posterior, respectively.
(3) Finally, we use both distributions to guide a hierarchical skill policy during learning of the downstream task.
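The following PyTorch-style sketch illustrates the three stages on toy data. It is a simplified assumption of how the pieces fit together (dimensions, network sizes, and names such as skill_enc or distill_loss are illustrative), not the released SkiLD implementation.

```python
# Minimal PyTorch-style sketch of the three SkiLD training stages on toy data.
# Dimensions, network sizes, and names are illustrative assumptions.
import torch
import torch.nn as nn
import torch.distributions as D

S, A, Z, H = 30, 9, 10, 10   # state dim, action dim, skill latent dim, skill horizon

def mlp(inp, out):
    return nn.Sequential(nn.Linear(inp, 128), nn.ReLU(), nn.Linear(128, out))

def gaussian(params):
    mean, log_std = params.chunk(2, dim=-1)
    return D.Normal(mean, log_std.clamp(-5, 2).exp())

# Stage 1: skill extraction -- a VAE over H-step action sequences (as in SPiRL).
skill_enc = mlp(S + A * H, 2 * Z)    # q(z | s_0, a_{0:H-1})
skill_dec = mlp(S + Z, A * H)        # low-level skill decoder: (s_0, z) -> actions

def vae_loss(s0, actions, beta=5e-4):
    q = gaussian(skill_enc(torch.cat([s0, actions.flatten(1)], -1)))
    recon = skill_dec(torch.cat([s0, q.rsample()], -1))
    rec_err = ((recon - actions.flatten(1)) ** 2).mean()
    kl = D.kl_divergence(q, D.Normal(0.0, 1.0)).mean()
    return rec_err + beta * kl

# Stage 2: state-conditioned skill distributions, distilled from the frozen encoder.
skill_prior = mlp(S, 2 * Z)      # p(z | s), trained on the task-agnostic dataset
skill_post = mlp(S, 2 * Z)       # q(z | s), trained on the demonstrations

def distill_loss(head, s0, actions):
    with torch.no_grad():        # freeze the skill encoder, use its output as the target
        target = gaussian(skill_enc(torch.cat([s0, actions.flatten(1)], -1)))
    return D.kl_divergence(target, gaussian(head(s0))).mean()

# Toy batch standing in for task-agnostic sequences (stage 1 / prior) and demos (posterior).
s0, acts = torch.randn(64, S), torch.randn(64, H, A)
opt = torch.optim.Adam([*skill_enc.parameters(), *skill_dec.parameters()], lr=1e-3)
loss = vae_loss(s0, acts)
opt.zero_grad(); loss.backward(); opt.step()
prior_loss = distill_loss(skill_prior, s0, acts)   # would be minimized on task-agnostic data
post_loss = distill_loss(skill_post, s0, acts)     # would be minimized on demonstration data
# Stage 3: the prior and posterior then regularize a hierarchical policy over z
# during downstream RL (see the objective sketch in the next section).
```

In practice, stages 1 and 2 run as separate pretraining jobs on the task-agnostic data and the demonstrations (see the skill_prior and skill_posterior configs in this repository); stage 3 is the downstream learning described next.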
Demonstration-Guided Downstream Learning
While we have learned a state-conditioned distribution over the demonstrated skills, we cannot always trust this skill posterior, since it is only valid within the demonstration support (green region). Thus, to guide the hierarchical policy during downstream learning, SkiLD leverages the skill posterior only within the support of the demonstrations and uses the learned skill prior otherwise, since the prior was trained on the task-agnostic experience dataset with a much wider support (red region). A schematic form of the resulting objective is given below.
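Schematically, and as a paraphrase rather than the exact equation from the paper, the high-level policy $\pi(z \mid s)$ over skill latents is trained with a SAC-style objective in which the entropy bonus is replaced by discriminator-weighted divergences toward the two learned skill distributions:

$$
\max_{\pi} \; \mathbb{E}_{\pi} \Big[ \sum_t \tilde{r}(s_t, z_t)
\;-\; \alpha_q \, D(s_t) \, D_{\mathrm{KL}}\big(\pi(z \mid s_t) \,\|\, q(z \mid s_t)\big)
\;-\; \alpha \, \big(1 - D(s_t)\big) \, D_{\mathrm{KL}}\big(\pi(z \mid s_t) \,\|\, p(z \mid s_t)\big) \Big]
$$

Here $q(z \mid s)$ is the skill posterior, $p(z \mid s)$ is the skill prior, and $D(s)$ is a learned discriminator estimating whether a state lies within the demonstration support; the weights $\alpha$ and $\alpha_q$ appear to correspond to the fixed_alpha / fixed_alpha_q (or td_schedule / tdq_schedule) parameters in the demo_rl configs further below.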
Environments

Maze Navigation
Kitchen Manipulation
Office Cleanup

We evaluate our approach on three long-horizon tasks: maze navigation, kitchen manipulation, and office cleanup. In each environment, we collect a large, task-agnostic dataset and a small set of task-specific demonstrations.
How does SkiLD Follow the Demonstrations?
We analyze the qualitative behavior of our approach in the maze environment: the discriminator D(s) can accurately estimate the support of the demonstrations (green). Thus, the SkiLD policy minimizes divergence to the demonstration-based skill posterior within the demonstration support (third panel, blue) and follows the task-agnostic skill prior otherwise (fourth panel). In summary, the agent learns to follow the demonstrations whenever it is within their support and falls back to prior-based exploration outside the support.
Qualitative Results
Kitchen Manipulation
SkiLD
SPiRL
SkillBC + SAC

Office Cleanup
Rollouts from the trained policies on the robotic manipulation tasks. In the kitchen environment the agent needs to perform four subtasks: open the microwave, flip the light switch, open the slide cabinet, and open the hinge cabinet. In the office cleanup task it needs to put the correct objects into the correct receptacles. In both environments, our approach SkiLD is the only method that can solve the full task. SPiRL lacks guidance from the demonstrations and thus solves the wrong subtasks and fails at the target task. Skill-based BC with SAC finetuning is brittle and unable to solve more than one subtask. For more qualitative result videos, please check our supplementary website.
Quantitative Results
Imitation Learning Results
We apply SkiLD in the pure imitation setting: without access to environment rewards, we instead use a GAIL-style reward based on our learned discriminator, which is trained to estimate the demonstration support (a minimal sketch is shown below). We show that our approach is able to leverage prior experience through skills for effective imitation of long-horizon tasks. By finetuning the learned discriminator we can further improve performance on the kitchen manipulation task, which requires more complex control.
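A minimal sketch of how such a reward can be computed from a pretrained demonstration-support discriminator follows; the network size and clipping constant are illustrative assumptions, not the released implementation.

```python
# GAIL-style reward from a demonstration-support discriminator D(s) in (0, 1).
# Architecture and constants are illustrative placeholders.
import torch
import torch.nn as nn

STATE_DIM = 30  # placeholder state dimension

discriminator = nn.Sequential(
    nn.Linear(STATE_DIM, 128), nn.ReLU(),
    nn.Linear(128, 1), nn.Sigmoid(),   # D(s): probability the state lies in the demo support
)

def gail_reward(states, eps=1e-6):
    """r(s) = log D(s) - log(1 - D(s)); used in place of the environment reward."""
    with torch.no_grad():
        d = discriminator(states).clamp(eps, 1 - eps)
    return (torch.log(d) - torch.log(1 - d)).squeeze(-1)

print(gail_reward(torch.randn(5, STATE_DIM)))   # one reward per visited state
```

In the released imitation configs, the discriminator finetuning mentioned above corresponds to setting freeze_discriminator=False, so the pretrained discriminator continues to be updated during downstream training.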
Source Code
We have released our PyTorch implementation on the GitHub page. Try our code!
Citation
@inproceedings{pertsch2021skild,
  title={Demonstration-Guided Reinforcement Learning with Learned Skills},
  author={Karl Pertsch and Youngwoon Lee and Yue Wu and Joseph J. Lim},
  booktitle={5th Conference on Robot Learning (CoRL)},
  year={2021},
}
--------------------------------------------------------------------------------
/docs/resources/clvr_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/clvr_icon.png
--------------------------------------------------------------------------------
/docs/resources/env_videos/kitchen.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/env_videos/kitchen.mp4
--------------------------------------------------------------------------------
/docs/resources/env_videos/maze.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/env_videos/maze.mp4
--------------------------------------------------------------------------------
/docs/resources/env_videos/office.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/env_videos/office.mp4
--------------------------------------------------------------------------------
/docs/resources/kitchen_subtask_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/kitchen_subtask_distribution.png
--------------------------------------------------------------------------------
/docs/resources/policy_videos/kitchen_skild.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/kitchen_skild.mp4
--------------------------------------------------------------------------------
/docs/resources/policy_videos/kitchen_skillBCSAC.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/kitchen_skillBCSAC.mp4
--------------------------------------------------------------------------------
/docs/resources/policy_videos/kitchen_spirl.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/kitchen_spirl.mp4
--------------------------------------------------------------------------------
/docs/resources/policy_videos/office_skild.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/office_skild.mp4
--------------------------------------------------------------------------------
/docs/resources/policy_videos/office_skillBCSAC.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/office_skillBCSAC.mp4
--------------------------------------------------------------------------------
/docs/resources/policy_videos/office_spirl.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/office_spirl.mp4
--------------------------------------------------------------------------------
/docs/resources/skild_downstream_sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_downstream_sketch.png
--------------------------------------------------------------------------------
/docs/resources/skild_imitation_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_imitation_results.png
--------------------------------------------------------------------------------
/docs/resources/skild_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_model.png
--------------------------------------------------------------------------------
/docs/resources/skild_quali_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_quali_results.png
--------------------------------------------------------------------------------
/docs/resources/skild_quant_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_quant_results.png
--------------------------------------------------------------------------------
/docs/resources/skild_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # core
2 | numpy
3 | matplotlib
4 | pillow
5 | h5py==2.10.0
6 | scikit-image
7 | funcsigs
8 | opencv-python
9 | moviepy
10 | torch==1.3.1
11 | torchvision==0.4.2
12 | tensorboard==2.1.1
13 | tensorboardX==2.0
14 | gym==0.15.4
15 | pandas
16 |
17 | # RL
18 | wandb
19 | mpi4py
20 | mujoco_py==2.0.2.9
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 | setup(name='skild', version='0.0.1', packages=['skild'])
4 |
--------------------------------------------------------------------------------
/skild/configs/demo_discriminator/kitchen/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 |
4 | from spirl.utils.general_utils import AttrDict
5 | from skild.models.demo_discriminator import DemoDiscriminator, DemoDiscriminatorLogger
6 | from spirl.configs.default_data_configs.kitchen import data_spec
7 | from spirl.components.evaluator import DummyEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': DemoDiscriminator,
15 | 'model_test': DemoDiscriminator,
16 | 'logger': DemoDiscriminatorLogger,
17 | 'logger_test': DemoDiscriminatorLogger,
18 | 'data_dir': ".",
19 | 'num_epochs': 100,
20 | 'epoch_cycles_train': 10,
21 | 'evaluator': DummyEvaluator,
22 | }
23 | configuration = AttrDict(configuration)
24 |
25 | model_config = AttrDict(
26 | action_dim=data_spec.n_actions,
27 | normalization='none',
28 | )
29 |
30 | # Demo Dataset
31 | demo_data_config = AttrDict()
32 | demo_data_config.dataset_spec = copy.deepcopy(data_spec)
33 | demo_data_config.dataset_spec.crop_rand_subseq = True
34 | demo_data_config.dataset_spec.subseq_len = 1+1
35 | demo_data_config.dataset_spec.filter_indices = [[320, 337], [339, 344]] # use only demos for one task (here: KBTS)
36 | demo_data_config.dataset_spec.demo_repeats = 10 # repeat those demos N times
37 | model_config.demo_data_conf = demo_data_config
38 | model_config.demo_data_path = '.'
39 |
40 | # Non-demo Dataset
41 | data_config = AttrDict()
42 | data_config.dataset_spec = data_spec
43 | data_config.dataset_spec.crop_rand_subseq = True
44 | data_config.dataset_spec.subseq_len = 1+1
45 |
--------------------------------------------------------------------------------
/skild/configs/demo_discriminator/maze/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 |
4 | from spirl.utils.general_utils import AttrDict
5 | from skild.models.demo_discriminator import DemoDiscriminator, DemoDiscriminatorLogger
6 | from spirl.configs.default_data_configs.maze import data_spec
7 | from spirl.components.evaluator import DummyEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': DemoDiscriminator,
15 | 'model_test': DemoDiscriminator,
16 | 'logger': DemoDiscriminatorLogger,
17 | 'logger_test': DemoDiscriminatorLogger,
18 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze_TA'),
19 | 'num_epochs': 100,
20 | 'epoch_cycles_train': 200,
21 | 'evaluator': DummyEvaluator,
22 | }
23 | configuration = AttrDict(configuration)
24 |
25 | model_config = AttrDict(
26 | action_dim=data_spec.n_actions,
27 | normalization='none',
28 | )
29 |
30 | # Demo Dataset
31 | demo_data_config = AttrDict()
32 | demo_data_config.dataset_spec = copy.deepcopy(data_spec)
33 | demo_data_config.dataset_spec.crop_rand_subseq = True
34 | demo_data_config.dataset_spec.subseq_len = 1+1
35 | demo_data_config.dataset_spec.n_seqs = 5 # number of demos used
36 | demo_data_config.dataset_spec.seq_repeat = 30 # repeat those demos N times
37 | model_config.demo_data_conf = demo_data_config
38 | model_config.demo_data_path = os.path.join(os.environ['DATA_DIR'], 'maze_demos')
39 |
40 | # Non-demo Dataset
41 | data_config = AttrDict()
42 | data_config.dataset_spec = data_spec
43 | data_config.dataset_spec.crop_rand_subseq = True
44 | data_config.dataset_spec.subseq_len = 1+1
45 |
--------------------------------------------------------------------------------
/skild/configs/demo_discriminator/office/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 |
4 | from spirl.utils.general_utils import AttrDict
5 | from skild.models.demo_discriminator import DemoDiscriminator, DemoDiscriminatorLogger
6 | from spirl.configs.default_data_configs.office import data_spec
7 | from spirl.components.evaluator import DummyEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': DemoDiscriminator,
15 | 'model_test': DemoDiscriminator,
16 | 'logger': DemoDiscriminatorLogger,
17 | 'logger_test': DemoDiscriminatorLogger,
18 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'office_TA'),
19 | 'num_epochs': 100,
20 | 'epoch_cycles_train': 300,
21 | 'evaluator': DummyEvaluator,
22 | }
23 | configuration = AttrDict(configuration)
24 |
25 | model_config = AttrDict(
26 | action_dim=data_spec.n_actions,
27 | normalization='none',
28 | )
29 |
30 | # Demo Dataset
31 | demo_data_config = AttrDict()
32 | demo_data_config.dataset_spec = copy.deepcopy(data_spec)
33 | demo_data_config.dataset_spec.crop_rand_subseq = True
34 | demo_data_config.dataset_spec.subseq_len = 1+1
35 | demo_data_config.dataset_spec.n_seqs = 50 # number of demos used
36 | demo_data_config.dataset_spec.seq_repeat = 3 # repeat those demos N times
37 | model_config.demo_data_conf = demo_data_config
38 | model_config.demo_data_path = os.path.join(os.environ['DATA_DIR'], 'office_demos')
39 |
40 | # Non-demo Dataset
41 | data_config = AttrDict()
42 | data_config.dataset_spec = data_spec
43 | data_config.dataset_spec.crop_rand_subseq = True
44 | data_config.dataset_spec.subseq_len = 1+1
45 |
--------------------------------------------------------------------------------
/skild/configs/demo_rl/kitchen/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 |
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.components.sampler import HierarchicalSampler
9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic
10 | from spirl.rl.agents.ac_agent import SACAgent
11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy
12 | from spirl.rl.envs.kitchen import KitchenEnv
13 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
14 | from spirl.configs.default_data_configs.kitchen import data_spec
15 |
16 | from skild.rl.policies.posterior_policies import LearnedPPPolicy
17 | from skild.models.demo_discriminator import DemoDiscriminator
18 | from skild.rl.agents.skild_agent import SkiLDAgent
19 |
20 |
21 | current_dir = os.path.dirname(os.path.realpath(__file__))
22 |
23 | notes = 'used to test the RL implementation'
24 |
25 | configuration = {
26 | 'seed': 42,
27 | 'agent': FixedIntervalHierarchicalAgent,
28 | 'environment': KitchenEnv,
29 | 'sampler': HierarchicalSampler,
30 | 'data_dir': '.',
31 | 'num_epochs': 200,
32 | 'max_rollout_len': 280,
33 | 'n_steps_per_epoch': 1e6,
34 | 'log_output_per_epoch': 1000,
35 | 'n_warmup_steps': 2e3,
36 | }
37 | configuration = AttrDict(configuration)
38 |
39 | # Observation Normalization
40 | obs_norm_params = AttrDict(
41 | )
42 |
43 | base_agent_params = AttrDict(
44 | batch_size=128,
45 | # update_iterations=XXX,
46 | )
47 |
48 | ###### Low-Level ######
49 | # LL Policy
50 | ll_model_params = AttrDict(
51 | state_dim=data_spec.state_dim,
52 | action_dim=data_spec.n_actions,
53 | n_rollout_steps=10,
54 | kl_div_weight=5e-4,
55 | nz_vae=10,
56 | nz_enc=128,
57 | nz_mid=128,
58 | n_processing_layers=5,
59 | cond_decode=True,
60 | )
61 |
62 | # LL Policy
63 | ll_policy_params = AttrDict(
64 | policy_model=ClSPiRLMdl,
65 | policy_model_params=ll_model_params,
66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/kitchen/kitchen_prior"),
67 | )
68 | ll_policy_params.update(ll_model_params)
69 |
70 | # LL Critic
71 | ll_critic_params = AttrDict(
72 | action_dim=data_spec.n_actions,
73 | input_dim=data_spec.state_dim,
74 | output_dim=1,
75 | action_input=True,
76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic
77 | )
78 |
79 | # LL Agent
80 | ll_agent_config = copy.deepcopy(base_agent_params)
81 | ll_agent_config.update(AttrDict(
82 | policy=ClModelPolicy,
83 | policy_params=ll_policy_params,
84 | critic=SplitObsMLPCritic,
85 | critic_params=ll_critic_params,
86 | ))
87 |
88 | ###### High-Level ########
89 | # HL Policy
90 | hl_policy_params = AttrDict(
91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE
92 | input_dim=data_spec.state_dim,
93 | squash_output_dist=True,
94 | max_action_range=2.,
95 | prior_model_params=ll_policy_params.policy_model_params,
96 | prior_model=ll_policy_params.policy_model,
97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
98 | posterior_model=ll_policy_params.policy_model,
99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params),
100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/kitchen/kitchen_post"),
101 | )
102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size
103 |
104 | hl_policy_params.policy_model = ll_policy_params.policy_model
105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params)
106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.posterior_model_checkpoint
107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size
108 |
109 |
110 | # HL Critic
111 | hl_critic_params = AttrDict(
112 | action_dim=hl_policy_params.action_dim,
113 | input_dim=hl_policy_params.input_dim,
114 | output_dim=1,
115 | n_layers=2,
116 | nz_mid=256,
117 | action_input=True,
118 | )
119 |
120 | # HL GAIL Demo Dataset
121 | from spirl.components.data_loader import GlobalSplitVideoDataset
122 | data_config = AttrDict()
123 | data_config.dataset_spec = data_spec
124 | data_config.dataset_spec.update(AttrDict(
125 | crop_rand_subseq=True,
126 | subseq_len=2,
127 | filter_indices=[[320, 337], [339, 344]],
128 | demo_repeats=10,
129 | ))
130 |
131 | # HL Pre-Trained Demo Discriminator
132 | demo_discriminator_config = AttrDict(
133 | state_dim=data_spec.state_dim,
134 | normalization='none',
135 | demo_data_conf=data_config,
136 | )
137 |
138 | # HL Agent
139 | hl_agent_config = copy.deepcopy(base_agent_params)
140 | hl_agent_config.update(AttrDict(
141 | policy=LearnedPPPolicy,
142 | policy_params=hl_policy_params,
143 | critic=MLPCritic,
144 | critic_params=hl_critic_params,
145 | discriminator=DemoDiscriminator,
146 | discriminator_params=demo_discriminator_config,
147 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/kitchen/kitchen_discr"),
148 | freeze_discriminator=True, # don't update pretrained discriminator
149 | buffer=UniformReplayBuffer,
150 | buffer_params={'capacity': 1e6,},
151 | reset_buffer=False,
152 | replay=UniformReplayBuffer,
153 | replay_params={'dump_replay': False, 'capacity': 2e6},
154 | expert_data_conf=data_config,
155 | expert_data_path=".",
156 | ))
157 |
158 | # SkiLD Parameters
159 | hl_agent_config.update(AttrDict(
160 | lambda_gail_schedule_params=AttrDict(p=0.9),
161 | fixed_alpha=1e-1,
162 | fixed_alpha_q=1e-1,
163 | ))
164 |
165 |
166 | ##### Joint Agent #######
167 | agent_config = AttrDict(
168 | hl_agent=SkiLDAgent,
169 | hl_agent_params=hl_agent_config,
170 | ll_agent=SACAgent,
171 | ll_agent_params=ll_agent_config,
172 | hl_interval=ll_model_params.n_rollout_steps,
173 | log_videos=True,
174 | update_hl=True,
175 | update_ll=False,
176 | )
177 |
178 | # Sampler
179 | sampler_config = AttrDict(
180 | )
181 |
182 | # Environment
183 | env_config = AttrDict(
184 | reward_norm=1,
185 | name='kitchen-kbts-v0',
186 | )
187 |
188 |
--------------------------------------------------------------------------------
/skild/configs/demo_rl/maze/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 |
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.components.sampler import HierarchicalSampler
9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic
10 | from spirl.rl.agents.ac_agent import SACAgent
11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy
12 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
13 | from spirl.configs.default_data_configs.maze import data_spec
14 |
15 | from skild.rl.policies.posterior_policies import LearnedPPPolicy
16 | from skild.models.demo_discriminator import DemoDiscriminator
17 | from skild.rl.envs.maze import ACRandMaze0S40Env
18 | from skild.rl.agents.skild_agent import SkiLDAgent
19 | from skild.data.maze.src.maze_agents import MazeSkiLDAgent
20 |
21 |
22 | current_dir = os.path.dirname(os.path.realpath(__file__))
23 |
24 | notes = 'used to test the RL implementation'
25 |
26 | configuration = {
27 | 'seed': 42,
28 | 'agent': FixedIntervalHierarchicalAgent,
29 | 'environment': ACRandMaze0S40Env,
30 | 'sampler': HierarchicalSampler,
31 | 'data_dir': '.',
32 | 'num_epochs': 200,
33 | 'max_rollout_len': 2000,
34 | 'n_steps_per_epoch': 1e5,
35 | 'log_output_per_epoch': 1000,
36 | 'n_warmup_steps': 2e3,
37 | }
38 | configuration = AttrDict(configuration)
39 |
40 | # Observation Normalization
41 | obs_norm_params = AttrDict(
42 | )
43 |
44 | base_agent_params = AttrDict(
45 | batch_size=128,
46 | )
47 |
48 | ###### Low-Level ######
49 | # LL Policy
50 | ll_model_params = AttrDict(
51 | state_dim=data_spec.state_dim,
52 | action_dim=data_spec.n_actions,
53 | n_rollout_steps=10,
54 | kl_div_weight=1e-3,
55 | nz_vae=10,
56 | nz_enc=128,
57 | nz_mid=128,
58 | n_processing_layers=5,
59 | cond_decode=True,
60 | )
61 |
62 | # LL Policy
63 | ll_policy_params = AttrDict(
64 | policy_model=ClSPiRLMdl,
65 | policy_model_params=ll_model_params,
66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/maze/maze_prior"),
67 | )
68 | ll_policy_params.update(ll_model_params)
69 |
70 | # LL Critic
71 | ll_critic_params = AttrDict(
72 | action_dim=data_spec.n_actions,
73 | input_dim=data_spec.state_dim,
74 | output_dim=1,
75 | action_input=True,
76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic
77 | )
78 |
79 | # LL Agent
80 | ll_agent_config = copy.deepcopy(base_agent_params)
81 | ll_agent_config.update(AttrDict(
82 | policy=ClModelPolicy,
83 | policy_params=ll_policy_params,
84 | critic=SplitObsMLPCritic,
85 | critic_params=ll_critic_params,
86 | ))
87 |
88 | ###### High-Level ########
89 | # HL Policy
90 | hl_policy_params = AttrDict(
91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE
92 | input_dim=data_spec.state_dim,
93 | squash_output_dist=True,
94 | max_action_range=2.,
95 | prior_model_params=ll_policy_params.policy_model_params,
96 | prior_model=ll_policy_params.policy_model,
97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
98 | posterior_model=ll_policy_params.policy_model,
99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params),
100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/maze/maze_post"),
101 | )
102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size
103 |
104 | hl_policy_params.policy_model = ll_policy_params.policy_model
105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params)
106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.prior_model_checkpoint
107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size
108 |
109 |
110 | # HL Critic
111 | hl_critic_params = AttrDict(
112 | action_dim=hl_policy_params.action_dim,
113 | input_dim=hl_policy_params.input_dim,
114 | output_dim=1,
115 | n_layers=2,
116 | nz_mid=256,
117 | action_input=True,
118 | )
119 |
120 | # HL GAIL Demo Dataset
121 | from spirl.components.data_loader import GlobalSplitVideoDataset
122 | data_config = AttrDict()
123 | data_config.dataset_spec = data_spec
124 | data_config.dataset_spec.update(AttrDict(
125 | crop_rand_subseq=True,
126 | subseq_len=2,
127 | n_seqs=10,
128 | seq_repeat=100,
129 | split=AttrDict(train=0.5, val=0.5, test=0.0),
130 | ))
131 |
132 | # HL Pre-Trained Demo Discriminator
133 | demo_discriminator_config = AttrDict(
134 | state_dim=data_spec.state_dim,
135 | normalization='none',
136 | demo_data_conf=data_config,
137 | )
138 |
139 | # HL Agent
140 | hl_agent_config = copy.deepcopy(base_agent_params)
141 | hl_agent_config.update(AttrDict(
142 | policy=LearnedPPPolicy,
143 | policy_params=hl_policy_params,
144 | critic=MLPCritic,
145 | critic_params=hl_critic_params,
146 | discriminator=DemoDiscriminator,
147 | discriminator_params=demo_discriminator_config,
148 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/maze/maze_discr"),
149 | freeze_discriminator=True, # don't update pretrained discriminator
150 | buffer=UniformReplayBuffer,
151 | buffer_params={'capacity': 1e6,},
152 | reset_buffer=False,
153 | replay=UniformReplayBuffer,
154 | replay_params={'dump_replay': False, 'capacity': 2e6},
155 | expert_data_conf=data_config,
156 | expert_data_path=os.path.join(os.environ['DATA_DIR'], 'maze_demos'),
157 | ))
158 |
159 | # SkiLD Parameters
160 | hl_agent_config.update(AttrDict(
161 | lambda_gail_schedule_params=AttrDict(p=0.9),
162 | td_schedule_params=AttrDict(p=10.0),
163 | tdq_schedule_params=AttrDict(p=1.0),
164 | ))
165 |
166 |
167 | ##### Joint Agent #######
168 | agent_config = AttrDict(
169 | hl_agent=MazeSkiLDAgent,
170 | hl_agent_params=hl_agent_config,
171 | ll_agent=SACAgent,
172 | ll_agent_params=ll_agent_config,
173 | hl_interval=ll_model_params.n_rollout_steps,
174 | log_videos=False,
175 | update_hl=True,
176 | update_ll=False,
177 | )
178 |
179 | # Sampler
180 | sampler_config = AttrDict(
181 | )
182 |
183 | # Environment
184 | env_config = AttrDict(
185 | reward_norm=1,
186 | )
187 |
188 |
--------------------------------------------------------------------------------
/skild/configs/demo_rl/office/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 |
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.components.sampler import HierarchicalSampler
9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic
10 | from spirl.rl.agents.ac_agent import SACAgent
11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy
12 | from spirl.rl.envs.office import OfficeEnv
13 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
14 | from spirl.configs.default_data_configs.office import data_spec
15 |
16 | from skild.rl.policies.posterior_policies import LearnedPPPolicy
17 | from skild.models.demo_discriminator import DemoDiscriminator
18 | from skild.rl.agents.skild_agent import SkiLDAgent
19 |
20 |
21 | current_dir = os.path.dirname(os.path.realpath(__file__))
22 |
23 | notes = 'used to test the RL implementation'
24 |
25 | configuration = {
26 | 'seed': 42,
27 | 'agent': FixedIntervalHierarchicalAgent,
28 | 'environment': OfficeEnv,
29 | 'sampler': HierarchicalSampler,
30 | 'data_dir': '.',
31 | 'num_epochs': 200,
32 | 'max_rollout_len': 350,
33 | 'n_steps_per_epoch': 5e5,
34 | 'log_output_per_epoch': 1000,
35 | 'n_warmup_steps': 2e3,
36 | }
37 | configuration = AttrDict(configuration)
38 |
39 | # Observation Normalization
40 | obs_norm_params = AttrDict(
41 | )
42 |
43 | base_agent_params = AttrDict(
44 | batch_size=128,
45 | )
46 |
47 | ###### Low-Level ######
48 | # LL Policy
49 | ll_model_params = AttrDict(
50 | state_dim=data_spec.state_dim,
51 | action_dim=data_spec.n_actions,
52 | n_rollout_steps=10,
53 | kl_div_weight=5e-4,
54 | nz_vae=10,
55 | nz_enc=128,
56 | nz_mid=128,
57 | n_processing_layers=5,
58 | cond_decode=True,
59 | )
60 |
61 | # LL Policy
62 | ll_policy_params = AttrDict(
63 | policy_model=ClSPiRLMdl,
64 | policy_model_params=ll_model_params,
65 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/office/office_prior"),
66 | )
67 | ll_policy_params.update(ll_model_params)
68 |
69 | # LL Critic
70 | ll_critic_params = AttrDict(
71 | action_dim=data_spec.n_actions,
72 | input_dim=data_spec.state_dim,
73 | output_dim=1,
74 | action_input=True,
75 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic
76 | )
77 |
78 | # LL Agent
79 | ll_agent_config = copy.deepcopy(base_agent_params)
80 | ll_agent_config.update(AttrDict(
81 | policy=ClModelPolicy,
82 | policy_params=ll_policy_params,
83 | critic=SplitObsMLPCritic,
84 | critic_params=ll_critic_params,
85 | ))
86 |
87 | ###### High-Level ########
88 | # HL Policy
89 | hl_policy_params = AttrDict(
90 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE
91 | input_dim=data_spec.state_dim,
92 | squash_output_dist=True,
93 | max_action_range=2.,
94 | prior_model_params=ll_policy_params.policy_model_params,
95 | prior_model=ll_policy_params.policy_model,
96 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
97 | posterior_model=ll_policy_params.policy_model,
98 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params),
99 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/office/office_post"),
100 | )
101 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size
102 |
103 | hl_policy_params.policy_model = ll_policy_params.policy_model
104 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params)
105 | hl_policy_params.policy_model_checkpoint = hl_policy_params.posterior_model_checkpoint
106 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size
107 |
108 |
109 | # HL Critic
110 | hl_critic_params = AttrDict(
111 | action_dim=hl_policy_params.action_dim,
112 | input_dim=hl_policy_params.input_dim,
113 | output_dim=1,
114 | n_layers=2,
115 | nz_mid=256,
116 | action_input=True,
117 | )
118 |
119 | # HL GAIL Demo Dataset
120 | from spirl.components.data_loader import GlobalSplitVideoDataset
121 | data_config = AttrDict()
122 | data_config.dataset_spec = data_spec
123 | data_config.dataset_spec.update(AttrDict(
124 | crop_rand_subseq=True,
125 | subseq_len=2,
126 | n_seqs=100,
127 | seq_repeat=100,
128 | split=AttrDict(train=0.5, val=0.5, test=0.0),
129 | ))
130 |
131 | # HL Pre-Trained Demo Discriminator
132 | demo_discriminator_config = AttrDict(
133 | state_dim=data_spec.state_dim,
134 | normalization='none',
135 | demo_data_conf=data_config,
136 | )
137 |
138 | # HL Agent
139 | hl_agent_config = copy.deepcopy(base_agent_params)
140 | hl_agent_config.update(AttrDict(
141 | policy=LearnedPPPolicy,
142 | policy_params=hl_policy_params,
143 | critic=MLPCritic,
144 | critic_params=hl_critic_params,
145 | discriminator=DemoDiscriminator,
146 | discriminator_params=demo_discriminator_config,
147 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/office/office_discr"),
148 | freeze_discriminator=True, # don't update pretrained discriminator
149 | buffer=UniformReplayBuffer,
150 | buffer_params={'capacity': 1e6,},
151 | reset_buffer=False,
152 | replay=UniformReplayBuffer,
153 | replay_params={'dump_replay': False, 'capacity': 2e6},
154 | expert_data_conf=data_config,
155 | expert_data_path=os.path.join(os.environ['DATA_DIR'], 'office_demos'),
156 | ))
157 |
158 | # SkiLD Parameters
159 | hl_agent_config.update(AttrDict(
160 | lambda_gail_schedule_params=AttrDict(p=0.9),
161 | fixed_alpha=5.0,
162 | fixed_alpha_q=5.0,
163 | ))
164 |
165 |
166 | ##### Joint Agent #######
167 | agent_config = AttrDict(
168 | hl_agent=SkiLDAgent,
169 | hl_agent_params=hl_agent_config,
170 | ll_agent=SACAgent,
171 | ll_agent_params=ll_agent_config,
172 | hl_interval=ll_model_params.n_rollout_steps,
173 | log_videos=True,
174 | update_hl=True,
175 | update_ll=False,
176 | )
177 |
178 | # Sampler
179 | sampler_config = AttrDict(
180 | )
181 |
182 | # Environment
183 | env_config = AttrDict(
184 | reward_norm=1,
185 | )
186 |
187 |
--------------------------------------------------------------------------------
/skild/configs/imitation/kitchen/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 |
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.components.sampler import HierarchicalSampler
9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic
10 | from spirl.rl.agents.ac_agent import SACAgent
11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy
12 | from spirl.rl.envs.kitchen import KitchenEnv
13 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
14 | from spirl.configs.default_data_configs.kitchen import data_spec
15 |
16 | from skild.rl.policies.posterior_policies import LearnedPPPolicy
17 | from skild.models.demo_discriminator import DemoDiscriminator
18 | from skild.rl.agents.skild_agent import SkiLDAgent
19 |
20 |
21 | current_dir = os.path.dirname(os.path.realpath(__file__))
22 |
23 | notes = 'used to test the RL implementation'
24 |
25 | configuration = {
26 | 'seed': 42,
27 | 'agent': FixedIntervalHierarchicalAgent,
28 | 'environment': KitchenEnv,
29 | 'sampler': HierarchicalSampler,
30 | 'data_dir': '.',
31 | 'num_epochs': 200,
32 | 'max_rollout_len': 280,
33 | 'n_steps_per_epoch': 1e6,
34 | 'log_output_per_epoch': 1000,
35 | 'n_warmup_steps': 2e3,
36 | }
37 | configuration = AttrDict(configuration)
38 |
39 | # Observation Normalization
40 | obs_norm_params = AttrDict(
41 | )
42 |
43 | base_agent_params = AttrDict(
44 | batch_size=128,
45 | # update_iterations=XXX,
46 | )
47 |
48 | ###### Low-Level ######
49 | # LL Policy
50 | ll_model_params = AttrDict(
51 | state_dim=data_spec.state_dim,
52 | action_dim=data_spec.n_actions,
53 | n_rollout_steps=10,
54 | kl_div_weight=5e-4,
55 | nz_vae=10,
56 | nz_enc=128,
57 | nz_mid=128,
58 | n_processing_layers=5,
59 | cond_decode=True,
60 | )
61 |
62 | # LL Policy
63 | ll_policy_params = AttrDict(
64 | policy_model=ClSPiRLMdl,
65 | policy_model_params=ll_model_params,
66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/kitchen/kitchen_prior"),
67 | )
68 | ll_policy_params.update(ll_model_params)
69 |
70 | # LL Critic
71 | ll_critic_params = AttrDict(
72 | action_dim=data_spec.n_actions,
73 | input_dim=data_spec.state_dim,
74 | output_dim=1,
75 | action_input=True,
76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic
77 | )
78 |
79 | # LL Agent
80 | ll_agent_config = copy.deepcopy(base_agent_params)
81 | ll_agent_config.update(AttrDict(
82 | policy=ClModelPolicy,
83 | policy_params=ll_policy_params,
84 | critic=SplitObsMLPCritic,
85 | critic_params=ll_critic_params,
86 | ))
87 |
88 | ###### High-Level ########
89 | # HL Policy
90 | hl_policy_params = AttrDict(
91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE
92 | input_dim=data_spec.state_dim,
93 | squash_output_dist=True,
94 | max_action_range=2.,
95 | prior_model_params=ll_policy_params.policy_model_params,
96 | prior_model=ll_policy_params.policy_model,
97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
98 | posterior_model=ll_policy_params.policy_model,
99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params),
100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/kitchen/kitchen_post"),
101 | )
102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size
103 |
104 | hl_policy_params.policy_model = ll_policy_params.policy_model
105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params)
106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.posterior_model_checkpoint
107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size
108 |
109 |
110 | # HL Critic
111 | hl_critic_params = AttrDict(
112 | action_dim=hl_policy_params.action_dim,
113 | input_dim=hl_policy_params.input_dim,
114 | output_dim=1,
115 | n_layers=2,
116 | nz_mid=256,
117 | action_input=True,
118 | )
119 |
120 | # HL GAIL Demo Dataset
121 | from spirl.components.data_loader import GlobalSplitVideoDataset
122 | data_config = AttrDict()
123 | data_config.dataset_spec = data_spec
124 | data_config.dataset_spec.update(AttrDict(
125 | crop_rand_subseq=True,
126 | subseq_len=2,
127 | filter_indices=[[320, 337], [339, 344]],
128 | demo_repeats=10,
129 | ))
130 |
131 | # HL Pre-Trained Demo Discriminator
132 | demo_discriminator_config = AttrDict(
133 | state_dim=data_spec.state_dim,
134 | normalization='none',
135 | demo_data_conf=data_config,
136 | )
137 |
138 | # HL Agent
139 | hl_agent_config = copy.deepcopy(base_agent_params)
140 | hl_agent_config.update(AttrDict(
141 | policy=LearnedPPPolicy,
142 | policy_params=hl_policy_params,
143 | critic=MLPCritic,
144 | critic_params=hl_critic_params,
145 | discriminator=DemoDiscriminator,
146 | discriminator_params=demo_discriminator_config,
147 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/kitchen/kitchen_discr"),
148 | freeze_discriminator=False, # finetune the pretrained discriminator during imitation
149 | discriminator_updates=5e-4,
150 | buffer=UniformReplayBuffer,
151 | buffer_params={'capacity': 1e6,},
152 | reset_buffer=False,
153 | replay=UniformReplayBuffer,
154 | replay_params={'dump_replay': False, 'capacity': 2e6},
155 | expert_data_conf=data_config,
156 | expert_data_path=".",
157 | ))
158 |
159 | # SkiLD Parameters
160 | hl_agent_config.update(AttrDict(
161 | lambda_gail_schedule_params=AttrDict(p=0.9),
162 | fixed_alpha=1e-1,
163 | fixed_alpha_q=1e-1,
164 | ))
165 |
166 |
167 | ##### Joint Agent #######
168 | agent_config = AttrDict(
169 | hl_agent=SkiLDAgent,
170 | hl_agent_params=hl_agent_config,
171 | ll_agent=SACAgent,
172 | ll_agent_params=ll_agent_config,
173 | hl_interval=ll_model_params.n_rollout_steps,
174 | log_videos=True,
175 | update_hl=True,
176 | update_ll=False,
177 | )
178 |
179 | # Sampler
180 | sampler_config = AttrDict(
181 | )
182 |
183 | # Environment
184 | env_config = AttrDict(
185 | reward_norm=1,
186 | name='kitchen-kbts-v0',
187 | )
188 |
189 |
--------------------------------------------------------------------------------
/skild/configs/imitation/maze/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 |
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.components.sampler import HierarchicalSampler
9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic
10 | from spirl.rl.agents.ac_agent import SACAgent
11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy
12 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
13 | from spirl.configs.default_data_configs.maze import data_spec
14 |
15 | from skild.rl.policies.posterior_policies import LearnedPPPolicy
16 | from skild.models.demo_discriminator import DemoDiscriminator
17 | from skild.rl.envs.maze import ACRandMaze0S40Env
18 | from skild.rl.agents.skild_agent import SkiLDAgent
19 | from skild.data.maze.src.maze_agents import MazeSkiLDAgent
20 |
21 |
22 | current_dir = os.path.dirname(os.path.realpath(__file__))
23 |
24 | notes = 'used to test the RL implementation'
25 |
26 | configuration = {
27 | 'seed': 42,
28 | 'agent': FixedIntervalHierarchicalAgent,
29 | 'environment': ACRandMaze0S40Env,
30 | 'sampler': HierarchicalSampler,
31 | 'data_dir': '.',
32 | 'num_epochs': 200,
33 | 'max_rollout_len': 2000,
34 | 'n_steps_per_epoch': 1e5,
35 | 'log_output_per_epoch': 1000,
36 | 'n_warmup_steps': 2e3,
37 | }
38 | configuration = AttrDict(configuration)
39 |
40 | # Observation Normalization
41 | obs_norm_params = AttrDict(
42 | )
43 |
44 | base_agent_params = AttrDict(
45 | batch_size=128,
46 | )
47 |
48 | ###### Low-Level ######
49 | # LL Policy
50 | ll_model_params = AttrDict(
51 | state_dim=data_spec.state_dim,
52 | action_dim=data_spec.n_actions,
53 | n_rollout_steps=10,
54 | kl_div_weight=1e-3,
55 | nz_vae=10,
56 | nz_enc=128,
57 | nz_mid=128,
58 | n_processing_layers=5,
59 | cond_decode=True,
60 | )
61 |
62 | # LL Policy
63 | ll_policy_params = AttrDict(
64 | policy_model=ClSPiRLMdl,
65 | policy_model_params=ll_model_params,
66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/maze/maze_prior"),
67 | )
68 | ll_policy_params.update(ll_model_params)
69 |
70 | # LL Critic
71 | ll_critic_params = AttrDict(
72 | action_dim=data_spec.n_actions,
73 | input_dim=data_spec.state_dim,
74 | output_dim=1,
75 | action_input=True,
76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic
77 | )
78 |
79 | # LL Agent
80 | ll_agent_config = copy.deepcopy(base_agent_params)
81 | ll_agent_config.update(AttrDict(
82 | policy=ClModelPolicy,
83 | policy_params=ll_policy_params,
84 | critic=SplitObsMLPCritic,
85 | critic_params=ll_critic_params,
86 | ))
87 |
88 | ###### High-Level ########
89 | # HL Policy
90 | hl_policy_params = AttrDict(
91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE
92 | input_dim=data_spec.state_dim,
93 | squash_output_dist=True,
94 | max_action_range=2.,
95 | prior_model_params=ll_policy_params.policy_model_params,
96 | prior_model=ll_policy_params.policy_model,
97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
98 | posterior_model=ll_policy_params.policy_model,
99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params),
100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/maze/maze_post"),
101 | )
102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size
103 |
104 | hl_policy_params.policy_model = ll_policy_params.policy_model
105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params)
106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.prior_model_checkpoint
107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size
108 |
109 |
110 | # HL Critic
111 | hl_critic_params = AttrDict(
112 | action_dim=hl_policy_params.action_dim,
113 | input_dim=hl_policy_params.input_dim,
114 | output_dim=1,
115 | n_layers=2,
116 | nz_mid=256,
117 | action_input=True,
118 | )
119 |
120 | # HL GAIL Demo Dataset
121 | from spirl.components.data_loader import GlobalSplitVideoDataset
122 | data_config = AttrDict()
123 | data_config.dataset_spec = data_spec
124 | data_config.dataset_spec.update(AttrDict(
125 | crop_rand_subseq=True,
126 | subseq_len=2,
127 | n_seqs=10,
128 | seq_repeat=100,
129 | split=AttrDict(train=0.5, val=0.5, test=0.0),
130 | ))
131 |
132 | # HL Pre-Trained Demo Discriminator
133 | demo_discriminator_config = AttrDict(
134 | state_dim=data_spec.state_dim,
135 | normalization='none',
136 | demo_data_conf=data_config,
137 | )
138 |
139 | # HL Agent
140 | hl_agent_config = copy.deepcopy(base_agent_params)
141 | hl_agent_config.update(AttrDict(
142 | policy=LearnedPPPolicy,
143 | policy_params=hl_policy_params,
144 | critic=MLPCritic,
145 | critic_params=hl_critic_params,
146 | discriminator=DemoDiscriminator,
147 | discriminator_params=demo_discriminator_config,
148 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/maze/maze_discr"),
149 | freeze_discriminator=False, # finetune the pretrained discriminator during imitation
150 | discriminator_updates=0.2,
151 | buffer=UniformReplayBuffer,
152 | buffer_params={'capacity': 1e6,},
153 | reset_buffer=False,
154 | replay=UniformReplayBuffer,
155 | replay_params={'dump_replay': False, 'capacity': 2e6},
156 | expert_data_conf=data_config,
157 | expert_data_path=os.path.join(os.environ['DATA_DIR'], 'maze_demos'),
158 | ))
159 |
160 | # SkiLD Parameters
161 | hl_agent_config.update(AttrDict(
162 | lambda_gail_schedule_params=AttrDict(p=0.9),
163 | td_schedule_params=AttrDict(p=10.0),
164 | tdq_schedule_params=AttrDict(p=1.0),
165 | ))
166 |
167 |
168 | ##### Joint Agent #######
169 | agent_config = AttrDict(
170 | hl_agent=MazeSkiLDAgent,
171 | hl_agent_params=hl_agent_config,
172 | ll_agent=SACAgent,
173 | ll_agent_params=ll_agent_config,
174 | hl_interval=ll_model_params.n_rollout_steps,
175 | log_videos=False,
176 | update_hl=True,
177 | update_ll=False,
178 | )
179 |
180 | # Sampler
181 | sampler_config = AttrDict(
182 | )
183 |
184 | # Environment
185 | env_config = AttrDict(
186 | reward_norm=1,
187 | )
188 |
189 |
--------------------------------------------------------------------------------
/skild/configs/skill_posterior/kitchen/conf.py:
--------------------------------------------------------------------------------
1 | from skild.configs.skill_prior.kitchen.conf import *
2 |
3 | data_config.dataset_spec.filter_indices = [[320, 337], [339, 344]] # use only demos for one task (here: KBTS)
4 | data_config.dataset_spec.demo_repeats = 10 # repeat those demos N times
5 |
6 | model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"],
7 | "skill_prior/kitchen/kitchen_prior/weights")
8 |
--------------------------------------------------------------------------------
/skild/configs/skill_posterior/maze/conf.py:
--------------------------------------------------------------------------------
1 | from skild.configs.skill_prior.maze.conf import *
2 |
3 | configuration['data_dir'] = os.path.join(os.environ['DATA_DIR'], 'maze_demos')
4 | data_config.dataset_spec.n_seqs = 5 # number of demos
5 | data_config.dataset_spec.seq_repeat = 30 # how often to repeat these demos
6 |
7 | configuration['epoch_cycles_train'] = 4200
8 |
9 | model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"],
10 | "skill_prior/maze/maze_prior/weights")
11 |
--------------------------------------------------------------------------------
/skild/configs/skill_posterior/office/conf.py:
--------------------------------------------------------------------------------
1 | from skild.configs.skill_prior.office.conf import *
2 |
3 | configuration['data_dir'] = os.path.join(os.environ['DATA_DIR'], 'office_demos')
4 | data_config.dataset_spec.n_seqs = 50 # number of demos
5 | data_config.dataset_spec.seq_repeat = 3 # how often to repeat these demos
6 |
7 | configuration['epoch_cycles_train'] = 6000
8 |
9 | model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"],
10 | "skill_prior/office/office_prior/weights")
11 |
--------------------------------------------------------------------------------
/skild/configs/skill_prior/kitchen/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.kitchen import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 | current_dir = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | configuration = {
13 | 'model': ClSPiRLMdl,
14 | 'logger': Logger,
15 | 'data_dir': '.',
16 | 'epoch_cycles_train': 50,
17 | 'num_epochs': 100,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=5e-4,
29 | nz_enc=128,
30 | nz_mid=128,
31 | n_processing_layers=5,
32 | cond_decode=True,
33 | )
34 |
35 | # Dataset
36 | data_config = AttrDict()
37 | data_config.dataset_spec = data_spec
38 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # +1 because the last action of the subsequence gets cropped
39 |
--------------------------------------------------------------------------------
/skild/configs/skill_prior/maze/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.maze import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 | current_dir = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | configuration = {
13 | 'model': ClSPiRLMdl,
14 | 'logger': Logger,
15 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze_TA'),
16 | 'epoch_cycles_train': 250,
17 | 'num_epochs': 100,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=1e-3,
29 | nz_enc=128,
30 | nz_mid=128,
31 | n_processing_layers=5,
32 | cond_decode=True,
33 | )
34 |
35 | # Dataset
36 | data_config = AttrDict()
37 | data_config.dataset_spec = data_spec
38 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # +1 because the last action of the subsequence gets cropped
39 |
--------------------------------------------------------------------------------
/skild/configs/skill_prior/office/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.office import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 | current_dir = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | configuration = {
13 | 'model': ClSPiRLMdl,
14 | 'logger': Logger,
15 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'office_TA'),
16 | 'epoch_cycles_train': 300,
17 | 'num_epochs': 100,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=5e-4,
29 | nz_enc=128,
30 | nz_mid=128,
31 | n_processing_layers=5,
32 | cond_decode=True,
33 | )
34 |
35 | # Dataset
36 | data_config = AttrDict()
37 | data_config.dataset_spec = data_spec
38 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # +1 because the last action of the subsequence gets cropped
39 |
--------------------------------------------------------------------------------
/skild/data/kitchen/README.md:
--------------------------------------------------------------------------------
1 | # Choosing Kitchen Target Tasks
2 |
3 | In the kitchen environment, a task is defined as the consecutive execution of four subtasks.
4 | Some subtask sequences can be more challenging for agents to learn than others. For SkiLD, as well as for SPiRL and any
5 | other approach that leverages prior experience, the task complexity is mainly influenced by how well the respective
6 | subtask transitions are represented in the prior experience data.
7 |
8 | We use the training data of Gupta et al., 2020 for training our models on the kitchen tasks.
9 | The subtask transitions in this dataset are not uniformly distributed, i.e., certain subtask sequences are more likely
10 | than others. Thus, we can define easier tasks in the kitchen environments as those that require more likely subtask
11 | transitions. Conversely, more challenging tasks will require unlikely or unseen subtask transitions.
12 |
13 | In the SkiLD paper, we analyze the effect of target tasks of differing alignment with the pre-training data
14 | in the kitchen environment (Section 4.4). We also provide an analysis of the subtask transition probabilities in the
15 | dataset of Gupta et al. (Figure 14, see below), which we can use to determine tasks of varying complexity.
16 |
17 |
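The README above suggests picking target tasks by how well their subtask transitions are covered in the prior data. A toy sketch of that ranking idea follows; the trajectories and subtask labels are invented for illustration and are not the statistics of the actual dataset (those are analyzed in Figure 14 of the paper).

```python
# Toy sketch: rank candidate 4-subtask kitchen tasks by how often their subtask
# transitions occur in prior-experience trajectories. All trajectories and labels
# below are made up for illustration only.
from collections import Counter
import math

prior_trajectories = [
    ["microwave", "kettle", "light switch", "slide cabinet"],
    ["microwave", "bottom burner", "light switch", "slide cabinet"],
    ["kettle", "bottom burner", "top burner", "slide cabinet"],
]

transitions = Counter()
for traj in prior_trajectories:
    transitions.update(zip(traj, traj[1:]))
total = sum(transitions.values())

def task_score(task, eps=1e-6):
    """Sum of log transition frequencies; higher means better covered by prior data."""
    return sum(math.log(transitions[t] / total + eps) for t in zip(task, task[1:]))

candidates = [
    ["microwave", "kettle", "light switch", "slide cabinet"],         # well-covered transitions
    ["microwave", "light switch", "slide cabinet", "hinge cabinet"],  # requires rare/unseen transitions
]
for task in candidates:
    print(task, round(task_score(task), 2))
```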