├── .gitignore ├── .gitmodules ├── README.md ├── docs ├── index.html └── resources │ ├── clvr_icon.png │ ├── env_videos │ ├── kitchen.mp4 │ ├── maze.mp4 │ └── office.mp4 │ ├── kitchen_subtask_distribution.png │ ├── policy_videos │ ├── kitchen_skild.mp4 │ ├── kitchen_skillBCSAC.mp4 │ ├── kitchen_spirl.mp4 │ ├── office_skild.mp4 │ ├── office_skillBCSAC.mp4 │ └── office_spirl.mp4 │ ├── skild_downstream_sketch.png │ ├── skild_imitation_results.png │ ├── skild_model.png │ ├── skild_quali_results.png │ ├── skild_quant_results.png │ └── skild_teaser.png ├── requirements.txt ├── setup.py └── skild ├── configs ├── demo_discriminator │ ├── kitchen │ │ └── conf.py │ ├── maze │ │ └── conf.py │ └── office │ │ └── conf.py ├── demo_rl │ ├── kitchen │ │ └── conf.py │ ├── maze │ │ └── conf.py │ └── office │ │ └── conf.py ├── imitation │ ├── kitchen │ │ └── conf.py │ └── maze │ │ └── conf.py ├── skill_posterior │ ├── kitchen │ │ └── conf.py │ ├── maze │ │ └── conf.py │ └── office │ │ └── conf.py └── skill_prior │ ├── kitchen │ └── conf.py │ ├── maze │ └── conf.py │ └── office │ └── conf.py ├── data ├── kitchen │ ├── README.md │ └── kitchen_subtasks.py └── maze │ └── src │ └── maze_agents.py ├── models └── demo_discriminator.py └── rl ├── agents ├── gail_agent.py ├── ppo_agent.py └── skild_agent.py ├── envs └── maze.py └── policies └── posterior_policies.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | *.egg-info 4 | *.DS_Store 5 | venv 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "spirl"] 2 | path = spirl 3 | url = git@github.com:clvrai/spirl.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Demonstration-Guided Reinforcement Learning with Learned Skills 2 | #### [[Project Website]](https://clvrai.github.io/skild/) [[Paper]](https://arxiv.org/abs/2107.10253) 3 | 4 | [Karl Pertsch](https://kpertsch.github.io/)1, [Youngwoon Lee](https://youngwoon.github.io/)1, 5 | [Yue Wu](https://ventusyue.github.io/)1, [Joseph Lim](https://www.clvrai.com/)1 6 | 7 | 1CLVR Lab, University of Southern California 8 | 9 | 10 |

11 | 12 |

13 |
14 | 15 | This is the official PyTorch implementation of the paper "**Demonstration-Guided Reinforcement Learning with Learned Skills**". 16 | 17 | ## Requirements 18 | 19 | - python 3.7+ 20 | - mujoco 2.1 (for RL experiments) 21 | - Ubuntu 18.04 22 | 23 | ## Installation Instructions 24 | 25 | Create a virtual environment and install all required packages: 26 | ``` 27 | cd skild 28 | pip3 install virtualenv 29 | virtualenv -p $(which python3) ./venv 30 | source ./venv/bin/activate 31 | 32 | # Install dependencies and package 33 | pip3 install -r requirements.txt 34 | pip3 install -e . 35 | ``` 36 | 37 | Install [SPiRL](https://github.com/clvrai/spirl) as a git submodule: 38 | ``` 39 | # Download SPiRL as a submodule (all requirements should already be installed) 40 | git submodule update --init --recursive 41 | cd spirl 42 | pip3 install -e . 43 | cd .. 44 | ``` 45 | 46 | Set the environment variables that specify the root experiment and data directories. For example: 47 | ``` 48 | mkdir ./experiments 49 | mkdir ./data 50 | export EXP_DIR=./experiments 51 | export DATA_DIR=./data 52 | ``` 53 | 54 | If you are planning to use GPUs, set the target GPU via `export CUDA_VISIBLE_DEVICES=XXX`. 55 | 56 | Finally, for running RL experiments on maze or kitchen environments, install our fork of the 57 | [D4RL benchmark](https://github.com/kpertsch/d4rl) repository by following its installation instructions. Also make sure 58 | to place your Mujoco license file `mj_key.txt` in `~/.mujoco`. 59 | For running RL in the office environment, install our fork of the [Roboverse repo](https://github.com/VentusYue/roboverse) 60 | and follow its installation instructions for installing PyBullet. 61 | 62 | ## Example Commands 63 | Our skill-based imitation / demo-guided RL pipeline runs in four steps: (1) train the skill embedding and skill prior, 64 | (2) train the skill posterior, (3) train the demo discriminator, and (4) use all components for demo-guided RL or imitation learning 65 | on the downstream task. 66 | 67 | All results will be written to [WandB](https://www.wandb.com/). Before running any of the commands below, 68 | create an account and then change the WandB entity and project name at the top of [train.py](https://github.com/clvrai/spirl/blob/master/spirl/train.py) and 69 | [rl/train.py](https://github.com/clvrai/spirl/blob/master/spirl/rl/train.py) to match your account. 70 | 71 | #### Skill Embedding & Prior 72 | To train the skill embedding and skill prior model for the kitchen environment, run: 73 | ``` 74 | python3 spirl/spirl/train.py --path=skild/configs/skill_prior/kitchen --val_data_size=160 --prefix=kitchen_prior 75 | ``` 76 | 77 | #### Skill Posterior 78 | For training the skill posterior on the demonstration data, run: 79 | ``` 80 | python3 spirl/spirl/train.py --path=skild/configs/skill_posterior/kitchen --val_data_size=160 --prefix=kitchen_post 81 | ``` 82 | Note that the skill posterior can only be trained once skill embedding and prior training is completed 83 | since it leverages the pre-trained skill embedding.
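Concretely, this dependency comes from the skill posterior config, which imports the skill prior config, restricts the training data to the target-task demonstrations, and loads the pre-trained skill embedding via `embedding_checkpoint`. The snippet below is quoted (abridged, with one added explanatory comment) from `skild/configs/skill_posterior/kitchen/conf.py`:
```
from skild.configs.skill_prior.kitchen.conf import *

data_config.dataset_spec.filter_indices = [[320, 337], [339, 344]]  # use only demos for one task (here: KBTS)
data_config.dataset_spec.demo_repeats = 10                          # repeat those demos N times

# reuse the skill embedding learned during skill prior training
model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"],
                                                 "skill_prior/kitchen/kitchen_prior/weights")
```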
84 | 85 | #### Demo Discriminator 86 | For training the demonstration discriminator, run: 87 | ``` 88 | python3 spirl/spirl/train.py --path=skild/configs/demo_discriminator/kitchen --val_data_size=160 --prefix=kitchen_discr 89 | ``` 90 | 91 | #### Demonstration-Guided RL 92 | For training a SkiLD agent on the kitchen environment using the pre-trained components from above, run: 93 | ``` 94 | python3 spirl/spirl/rl/train.py --path=skild/configs/demo_rl/kitchen --seed=0 --prefix=SkiLD_demoRL_kitchen_seed0 95 | ``` 96 | 97 | #### Imitation Learning 98 | For training a SkiLD agent on the kitchen environment with pure imitation learning, run: 99 | ``` 100 | python3 spirl/spirl/rl/train.py --path=skild/configs/imitation/kitchen --seed=0 --prefix=SkiLD_IL_kitchen_seed0 101 | ``` 102 | 103 | In all commands above, `kitchen` can be replaced with `maze / office` to run on the respective environment. Before training models 104 | on these environments, the corresponding datasets need to be downloaded (the kitchen dataset gets downloaded automatically) 105 | -- download links are provided below, and a consolidated maze example is shown further below. 106 | 107 | To accelerate RL / IL training, you can use MPI for multi-processing by prepending `mpirun -np XXX` to the above RL / IL commands, where `XXX` corresponds to the number of parallel workers you want to spawn. Also update the corresponding [config file](skild/configs/demo_rl/kitchen/conf.py) by uncommenting the `update_iterations = XXX` line and again replacing `XXX` with the desired number of workers. 108 | 109 | 110 | ## Datasets 111 | 112 | |Dataset | Link | Size | 113 | |:------------- |:-------------|:-----| 114 | | Maze Task-Agnostic | [https://drive.google.com/file/d/103RFpEg4ATnH06fd1ps8ZQL4sTtifrvX/view?usp=sharing](https://drive.google.com/file/d/103RFpEg4ATnH06fd1ps8ZQL4sTtifrvX/view?usp=sharing)| 470MB | 115 | | Maze Demos | [https://drive.google.com/file/d/1wTR9ns5QsEJnrMJRXFEJWCMk-d1s4S9t/view?usp=sharing](https://drive.google.com/file/d/1wTR9ns5QsEJnrMJRXFEJWCMk-d1s4S9t/view?usp=sharing)| 100MB | 116 | | Office Cleanup Task-Agnostic | [https://drive.google.com/file/d/1yNsTZkefMMvdbIBe-dTHJxgPIRXyxzb7/view?usp=sharing](https://drive.google.com/file/d/1yNsTZkefMMvdbIBe-dTHJxgPIRXyxzb7/view?usp=sharing)| 170MB | 117 | | Office Cleanup Demos | [https://drive.google.com/file/d/149trMTyh3A2KnbUOXwt6Lc3ba-1T9SXj/view?usp=sharing](https://drive.google.com/file/d/149trMTyh3A2KnbUOXwt6Lc3ba-1T9SXj/view?usp=sharing)| 6MB | 118 | 119 | To download the dataset files from Google Drive via the command line, you can use the 120 | [gdown](https://github.com/wkentaro/gdown) package. Install it with: 121 | ``` 122 | pip install gdown 123 | ``` 124 | 125 | Then navigate to the folder you want to download the data to and run the following commands: 126 | ``` 127 | # Download Maze Task-Agnostic Dataset 128 | gdown https://drive.google.com/uc?id=103RFpEg4ATnH06fd1ps8ZQL4sTtifrvX 129 | 130 | # Download Maze Demonstration Dataset 131 | gdown https://drive.google.com/uc?id=1wTR9ns5QsEJnrMJRXFEJWCMk-d1s4S9t 132 | ``` 133 | 134 | Finally, unzip the downloaded files with the `unzip` command. 135 | 136 | ## Code Structure & Modifying the Code 137 | For more detailed documentation of the code structure and how to extend the code (adding new environments, models, RL algos), 138 | please check the [documentation in the SPiRL repo](https://github.com/clvrai/spirl#starting-to-modify-the-code).
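## Full Pipeline Example (Maze)
As a consolidated sketch (not an officially tested script), the four pipeline stages for the maze environment look as follows. The commands assume the maze datasets from the table above have been unzipped to `$DATA_DIR/maze_TA` and `$DATA_DIR/maze_demos` (the directory names referenced by the maze configs); the `--prefix` values are chosen to line up with the checkpoint paths expected by `skild/configs/demo_rl/maze/conf.py`, and `--val_data_size=160` is simply carried over from the kitchen examples above.
```
# 1) Skill embedding & skill prior on the task-agnostic maze data
python3 spirl/spirl/train.py --path=skild/configs/skill_prior/maze --val_data_size=160 --prefix=maze_prior

# 2) Skill posterior on the maze demonstrations (requires the pre-trained prior)
python3 spirl/spirl/train.py --path=skild/configs/skill_posterior/maze --val_data_size=160 --prefix=maze_post

# 3) Demo discriminator
python3 spirl/spirl/train.py --path=skild/configs/demo_discriminator/maze --val_data_size=160 --prefix=maze_discr

# 4) Demonstration-guided RL using all pre-trained components
python3 spirl/spirl/rl/train.py --path=skild/configs/demo_rl/maze --seed=0 --prefix=SkiLD_demoRL_maze_seed0
```
To parallelize the RL stage, prepend e.g. `mpirun -np 8` to the last command and set `update_iterations = 8` in the config's `base_agent_params`, as described above.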
139 | 140 | ## Citation 141 | If you find this work useful in your research, please consider citing: 142 | ``` 143 | @article{pertsch2021skild, 144 | title={Demonstration-Guided Reinforcement Learning with Learned Skills}, 145 | author={Karl Pertsch and Youngwoon Lee and Yue Wu and Joseph J. Lim}, 146 | journal={5th Conference on Robot Learning}, 147 | year={2021}, 148 | } 149 | ``` 150 | 151 | ## Acknowledgements 152 | Most of the heavy-lifting in this code is done by the [SPiRL codebase](https://github.com/clvrai/spirl), published as part 153 | of our prior work. 154 | 155 | We thank Justin Fu and Aviral Kumar et al. for providing the [D4RL codebase](https://github.com/rail-berkeley/d4rl) 156 | which we use for some of our experiments. We also thank Avi Singh et al. for open-sourcing the [Roboverse repo](https://github.com/avisingh599/roboverse) 157 | which we build on for our office environment experiments. 158 | 159 | 160 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 176 | 177 | 178 | 179 | 187 | 188 | 189 | 190 | 191 | 192 |
193 | 194 | 195 | 196 | Demonstration-Guided Reinforcement Learning with Learned Skills 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 |
230 |
Demonstration-Guided Reinforcement Learning
with Learned Skills

231 |
232 |
Karl Pertsch
233 | 234 |
235 | 236 |
Youngwoon Lee
237 | 238 |
239 | 240 |
Yue Wu
241 | 242 |
243 | 244 |
Joseph Lim
245 | 246 |
247 |
248 | 249 | 250 | 251 | 252 |
CLVR Lab, University of Southern California
253 |
Conference on Robot Learning (CoRL), 2021
254 | 255 |
256 |
[Paper]
257 |
[Code]
258 | 259 |
260 | 261 | 262 | 264 | 265 | 266 | 267 |
268 |

269 |
270 | 271 |
272 | Demonstration-guided reinforcement learning (RL) is a promising approach for learning complex behaviors by leveraging both reward feedback and a set of target task demonstrations. Prior approaches for demonstration-guided RL treat every new task as an independent learning problem and attempt to follow the provided demonstrations step-by-step, akin to a human trying to imitate a completely unseen behavior by following the demonstrator's exact muscle movements. Naturally, such learning will be slow, but often new behaviors are not completely unseen: they share subtasks with behaviors we have previously learned. In this work, we aim to exploit this shared subtask structure to increase the efficiency of demonstration-guided RL. We first learn a set of reusable skills from large offline datasets of prior experience collected across many tasks. We then propose Skill-based Learning with Demonstrations (SkiLD), an algorithm for demonstration-guided RL that efficiently leverages the provided demonstrations by following the demonstrated skills instead of the primitive actions, resulting in substantial performance improvements over prior demonstration-guided RL approaches. We validate the effectiveness of our approach on long-horizon maze navigation and complex robot manipulation tasks. 273 |
274 |

275 | 276 | 277 | 278 |

Overview

279 |
280 | Our goal is to use skills extracted from prior experience to improve the efficiency of demonstration-guided RL on a new task. We aim to leverage a set of provided demonstrations by following the performed skills as opposed to the primitive actions. 281 |

282 |

283 |
284 |
285 | Learning in our approach, SkiLD, is performed in three stages. (1): First, we extract a set of reusable skills from prior, task-agnostic experience. We build on prior work in skill-based RL for learning the skill extraction module (SPiRL, Pertsch et al. 2020). (2): We then use the pre-trained skill encoder to infer the skills performed in task-agnostic and demonstration sequences and learn state-conditioned skill distributions, which we call skill prior and skill posterior respectively. (3): Finally, we use both distributions to guide a hierarchical skill policy during learning of the downstream task. 286 |

287 | 288 | 289 | 290 | 291 | 295 | 296 | 303 | 304 |
292 | 293 | 294 | 297 |

Demonstration-Guided Downstream Learning

298 | 299 | While we have learned a state-conditioned distribution over the demonstrated skills, we cannot always trust this skill posterior, since it is only valid within the demonstration support (green region). Thus, to guide the hierarchical policy during downstream learning, SkiLD leverages the skill posterior only within the support of the demonstrations and uses the learned skill prior otherwise, since it was trained on the task-agnostic experience dataset with a much wider support (red region). 300 | 301 | 302 |


305 | 306 | 307 | 308 | 309 | 310 |

Environments

311 | 312 | 313 | 317 | 318 | 322 | 323 | 327 | 328 |
314 |

Maze Navigation

315 | 316 |
319 |

Kitchen Manipulation

320 | 321 |
324 |

Office Cleanup

325 | 326 |

329 | 330 |
331 | We evaluate our approach on three long-horizon tasks: maze navigation, kitchen manipulation and office cleanup. In each environment, we collect a large, task-agnostic dataset and a small set of task-specific demonstrations. 332 |
333 |
334 | 335 | 336 | 337 | 338 |
339 |

How does SkiLD Follow the Demonstrations?

340 |
341 | 342 |

343 |
344 | We analyze the qualitative behavior of our approach in the maze environment: the discriminator D(s) can accurately estimate the support of the demonstrations (green). Thus, the SkiLD policy minimizes divergence to the demonstration-based skill posterior within the demonstration support (third panel, blue) and follows the task-agnostic skill prior otherwise (fourth panel). In summary, the agent learns to follow the demonstrations whenever it's within their support and falls back to prior-based exploration outside the support. 345 |
346 |

347 | 348 | 349 | 350 | 351 |
352 |

Qualitative Results

353 |
354 | 355 | 356 | 361 | 362 | 366 | 367 | 371 | 372 | 376 | 377 | 378 | 379 | 384 | 385 | 388 | 389 | 392 | 393 | 396 | 397 |
357 |
358 | Kitchen Manipulation 359 |
360 |
363 |

SkiLD

364 | 365 |
368 |

SPiRL

369 | 370 |
373 |

SkillBC + SAC

374 | 375 |
380 |
381 | Office Cleanup 382 |
383 |
386 | 387 | 390 | 391 | 394 | 395 |
398 |
399 | Rollouts from the trained policies on the robotic manipulation tasks. In the kitchen environment, the agent needs to perform four subtasks: open microwave, flip light switch, open slide cabinet, open hinge cabinet. In the office cleanup task, it needs to put the correct objects in the correct receptacles. In both environments, our approach SkiLD is the only method that can solve the full task. SPiRL lacks guidance through the demonstrations and thus solves the wrong subtasks and fails at the target task. Skill-based BC with SAC finetuning is brittle and unable to solve more than one subtask. For more qualitative result videos, please check our supplementary website. 400 |


401 | 402 | 403 | 404 |
405 |

Quantitative Results

406 |
407 | 408 |

409 |
410 | 411 | 412 |
413 |

Imitation Learning Results

414 |
415 | 416 |

417 |
418 | We apply SkiLD in the pure imitation setting without access to environment rewards; instead, we use a GAIL-style reward based on our learned discriminator, which is trained to estimate the demonstration support. We show that our approach is able to leverage prior experience through skills for effective imitation of long-horizon tasks. By finetuning the learned discriminator, we can further improve performance on the kitchen manipulation task, which requires more complex control. 419 |


420 | 421 | 422 |

Source Code

423 |
424 | We have released our implementation in PyTorch on the github page. Try our code! 425 |
426 |
427 | [GitHub] 428 |
429 |

430 | 431 | 432 | 433 |

Citation

434 | 435 | 445 | 446 |
436 |

437 |           @article{pertsch2021skild,
438 |             title={Demonstration-Guided Reinforcement Learning with Learned Skills},
439 |             author={Karl Pertsch and Youngwoon Lee and Yue Wu and Joseph J. Lim},
440 |             journal={5th Conference on Robot Learning},
441 |             year={2021},
442 |           }
443 |         
444 |
447 |

448 | 449 | 450 | 454 | 455 | 456 | 459 |
460 | 461 | 462 | -------------------------------------------------------------------------------- /docs/resources/clvr_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/clvr_icon.png -------------------------------------------------------------------------------- /docs/resources/env_videos/kitchen.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/env_videos/kitchen.mp4 -------------------------------------------------------------------------------- /docs/resources/env_videos/maze.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/env_videos/maze.mp4 -------------------------------------------------------------------------------- /docs/resources/env_videos/office.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/env_videos/office.mp4 -------------------------------------------------------------------------------- /docs/resources/kitchen_subtask_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/kitchen_subtask_distribution.png -------------------------------------------------------------------------------- /docs/resources/policy_videos/kitchen_skild.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/kitchen_skild.mp4 -------------------------------------------------------------------------------- /docs/resources/policy_videos/kitchen_skillBCSAC.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/kitchen_skillBCSAC.mp4 -------------------------------------------------------------------------------- /docs/resources/policy_videos/kitchen_spirl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/kitchen_spirl.mp4 -------------------------------------------------------------------------------- /docs/resources/policy_videos/office_skild.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/office_skild.mp4 -------------------------------------------------------------------------------- /docs/resources/policy_videos/office_skillBCSAC.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/office_skillBCSAC.mp4 -------------------------------------------------------------------------------- /docs/resources/policy_videos/office_spirl.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/policy_videos/office_spirl.mp4 -------------------------------------------------------------------------------- /docs/resources/skild_downstream_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_downstream_sketch.png -------------------------------------------------------------------------------- /docs/resources/skild_imitation_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_imitation_results.png -------------------------------------------------------------------------------- /docs/resources/skild_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_model.png -------------------------------------------------------------------------------- /docs/resources/skild_quali_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_quali_results.png -------------------------------------------------------------------------------- /docs/resources/skild_quant_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_quant_results.png -------------------------------------------------------------------------------- /docs/resources/skild_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/skild/91868b1fb1460e7e9013711bd01a047fc7a8ec7b/docs/resources/skild_teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # core 2 | numpy 3 | matplotlib 4 | pillow 5 | h5py==2.10.0 6 | scikit-image 7 | funcsigs 8 | opencv-python 9 | moviepy 10 | torch==1.3.1 11 | torchvision==0.4.2 12 | tensorboard==2.1.1 13 | tensorboardX==2.0 14 | gym==0.15.4 15 | pandas 16 | 17 | # RL 18 | wandb 19 | mpi4py 20 | mujoco_py==2.0.2.9 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup(name='skild', version='0.0.1', packages=['skild']) 4 | -------------------------------------------------------------------------------- /skild/configs/demo_discriminator/kitchen/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | from spirl.utils.general_utils import AttrDict 5 | from skild.models.demo_discriminator import DemoDiscriminator, DemoDiscriminatorLogger 6 | from spirl.configs.default_data_configs.kitchen import data_spec 7 | from spirl.components.evaluator import DummyEvaluator 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | 
configuration = { 14 | 'model': DemoDiscriminator, 15 | 'model_test': DemoDiscriminator, 16 | 'logger': DemoDiscriminatorLogger, 17 | 'logger_test': DemoDiscriminatorLogger, 18 | 'data_dir': ".", 19 | 'num_epochs': 100, 20 | 'epoch_cycles_train': 10, 21 | 'evaluator': DummyEvaluator, 22 | } 23 | configuration = AttrDict(configuration) 24 | 25 | model_config = AttrDict( 26 | action_dim=data_spec.n_actions, 27 | normalization='none', 28 | ) 29 | 30 | # Demo Dataset 31 | demo_data_config = AttrDict() 32 | demo_data_config.dataset_spec = copy.deepcopy(data_spec) 33 | demo_data_config.dataset_spec.crop_rand_subseq = True 34 | demo_data_config.dataset_spec.subseq_len = 1+1 35 | demo_data_config.dataset_spec.filter_indices = [[320, 337], [339, 344]] # use only demos for one task (here: KBTS) 36 | demo_data_config.dataset_spec.demo_repeats = 10 # repeat those demos N times 37 | model_config.demo_data_conf = demo_data_config 38 | model_config.demo_data_path = '.' 39 | 40 | # Non-demo Dataset 41 | data_config = AttrDict() 42 | data_config.dataset_spec = data_spec 43 | data_config.dataset_spec.crop_rand_subseq = True 44 | data_config.dataset_spec.subseq_len = 1+1 45 | -------------------------------------------------------------------------------- /skild/configs/demo_discriminator/maze/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | from spirl.utils.general_utils import AttrDict 5 | from skild.models.demo_discriminator import DemoDiscriminator, DemoDiscriminatorLogger 6 | from spirl.configs.default_data_configs.maze import data_spec 7 | from spirl.components.evaluator import DummyEvaluator 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': DemoDiscriminator, 15 | 'model_test': DemoDiscriminator, 16 | 'logger': DemoDiscriminatorLogger, 17 | 'logger_test': DemoDiscriminatorLogger, 18 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze_TA'), 19 | 'num_epochs': 100, 20 | 'epoch_cycles_train': 200, 21 | 'evaluator': DummyEvaluator, 22 | } 23 | configuration = AttrDict(configuration) 24 | 25 | model_config = AttrDict( 26 | action_dim=data_spec.n_actions, 27 | normalization='none', 28 | ) 29 | 30 | # Demo Dataset 31 | demo_data_config = AttrDict() 32 | demo_data_config.dataset_spec = copy.deepcopy(data_spec) 33 | demo_data_config.dataset_spec.crop_rand_subseq = True 34 | demo_data_config.dataset_spec.subseq_len = 1+1 35 | demo_data_config.dataset_spec.n_seqs = 5 # number of demos used 36 | demo_data_config.dataset_spec.seq_repeat = 30 # repeat those demos N times 37 | model_config.demo_data_conf = demo_data_config 38 | model_config.demo_data_path = os.path.join(os.environ['DATA_DIR'], 'maze_demos') 39 | 40 | # Non-demo Dataset 41 | data_config = AttrDict() 42 | data_config.dataset_spec = data_spec 43 | data_config.dataset_spec.crop_rand_subseq = True 44 | data_config.dataset_spec.subseq_len = 1+1 45 | -------------------------------------------------------------------------------- /skild/configs/demo_discriminator/office/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | from spirl.utils.general_utils import AttrDict 5 | from skild.models.demo_discriminator import DemoDiscriminator, DemoDiscriminatorLogger 6 | from spirl.configs.default_data_configs.office import data_spec 7 | from spirl.components.evaluator import DummyEvaluator 8 | 9 | 10 | current_dir = 
os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': DemoDiscriminator, 15 | 'model_test': DemoDiscriminator, 16 | 'logger': DemoDiscriminatorLogger, 17 | 'logger_test': DemoDiscriminatorLogger, 18 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'office_TA'), 19 | 'num_epochs': 100, 20 | 'epoch_cycles_train': 300, 21 | 'evaluator': DummyEvaluator, 22 | } 23 | configuration = AttrDict(configuration) 24 | 25 | model_config = AttrDict( 26 | action_dim=data_spec.n_actions, 27 | normalization='none', 28 | ) 29 | 30 | # Demo Dataset 31 | demo_data_config = AttrDict() 32 | demo_data_config.dataset_spec = copy.deepcopy(data_spec) 33 | demo_data_config.dataset_spec.crop_rand_subseq = True 34 | demo_data_config.dataset_spec.subseq_len = 1+1 35 | demo_data_config.dataset_spec.n_seqs = 50 # number of demos used 36 | demo_data_config.dataset_spec.seq_repeat = 3 # repeat those demos N times 37 | model_config.demo_data_conf = demo_data_config 38 | model_config.demo_data_path = os.path.join(os.environ['DATA_DIR'], 'office_demos') 39 | 40 | # Non-demo Dataset 41 | data_config = AttrDict() 42 | data_config.dataset_spec = data_spec 43 | data_config.dataset_spec.crop_rand_subseq = True 44 | data_config.dataset_spec.subseq_len = 1+1 45 | -------------------------------------------------------------------------------- /skild/configs/demo_rl/kitchen/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.components.sampler import HierarchicalSampler 9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic 10 | from spirl.rl.agents.ac_agent import SACAgent 11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy 12 | from spirl.rl.envs.kitchen import KitchenEnv 13 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 14 | from spirl.configs.default_data_configs.kitchen import data_spec 15 | 16 | from skild.rl.policies.posterior_policies import LearnedPPPolicy 17 | from skild.models.demo_discriminator import DemoDiscriminator 18 | from skild.rl.agents.skild_agent import SkiLDAgent 19 | 20 | 21 | current_dir = os.path.dirname(os.path.realpath(__file__)) 22 | 23 | notes = 'used to test the RL implementation' 24 | 25 | configuration = { 26 | 'seed': 42, 27 | 'agent': FixedIntervalHierarchicalAgent, 28 | 'environment': KitchenEnv, 29 | 'sampler': HierarchicalSampler, 30 | 'data_dir': '.', 31 | 'num_epochs': 200, 32 | 'max_rollout_len': 280, 33 | 'n_steps_per_epoch': 1e6, 34 | 'log_output_per_epoch': 1000, 35 | 'n_warmup_steps': 2e3, 36 | } 37 | configuration = AttrDict(configuration) 38 | 39 | # Observation Normalization 40 | obs_norm_params = AttrDict( 41 | ) 42 | 43 | base_agent_params = AttrDict( 44 | batch_size=128, 45 | # update_iterations=XXX, 46 | ) 47 | 48 | ###### Low-Level ###### 49 | # LL Policy 50 | ll_model_params = AttrDict( 51 | state_dim=data_spec.state_dim, 52 | action_dim=data_spec.n_actions, 53 | n_rollout_steps=10, 54 | kl_div_weight=5e-4, 55 | nz_vae=10, 56 | nz_enc=128, 57 | nz_mid=128, 58 | n_processing_layers=5, 59 | cond_decode=True, 60 | ) 61 | 62 | # LL Policy 63 | ll_policy_params = AttrDict( 64 | policy_model=ClSPiRLMdl, 65 | policy_model_params=ll_model_params, 66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], 
"skill_prior/kitchen/kitchen_prior"), 67 | ) 68 | ll_policy_params.update(ll_model_params) 69 | 70 | # LL Critic 71 | ll_critic_params = AttrDict( 72 | action_dim=data_spec.n_actions, 73 | input_dim=data_spec.state_dim, 74 | output_dim=1, 75 | action_input=True, 76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic 77 | ) 78 | 79 | # LL Agent 80 | ll_agent_config = copy.deepcopy(base_agent_params) 81 | ll_agent_config.update(AttrDict( 82 | policy=ClModelPolicy, 83 | policy_params=ll_policy_params, 84 | critic=SplitObsMLPCritic, 85 | critic_params=ll_critic_params, 86 | )) 87 | 88 | ###### High-Level ######## 89 | # HL Policy 90 | hl_policy_params = AttrDict( 91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE 92 | input_dim=data_spec.state_dim, 93 | squash_output_dist=True, 94 | max_action_range=2., 95 | prior_model_params=ll_policy_params.policy_model_params, 96 | prior_model=ll_policy_params.policy_model, 97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint, 98 | posterior_model=ll_policy_params.policy_model, 99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params), 100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/kitchen/kitchen_post"), 101 | ) 102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size 103 | 104 | hl_policy_params.policy_model = ll_policy_params.policy_model 105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params) 106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.posterior_model_checkpoint 107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size 108 | 109 | 110 | # HL Critic 111 | hl_critic_params = AttrDict( 112 | action_dim=hl_policy_params.action_dim, 113 | input_dim=hl_policy_params.input_dim, 114 | output_dim=1, 115 | n_layers=2, 116 | nz_mid=256, 117 | action_input=True, 118 | ) 119 | 120 | # HL GAIL Demo Dataset 121 | from spirl.components.data_loader import GlobalSplitVideoDataset 122 | data_config = AttrDict() 123 | data_config.dataset_spec = data_spec 124 | data_config.dataset_spec.update(AttrDict( 125 | crop_rand_subseq=True, 126 | subseq_len=2, 127 | filter_indices=[[320, 337], [339, 344]], 128 | demo_repeats=10, 129 | )) 130 | 131 | # HL Pre-Trained Demo Discriminator 132 | demo_discriminator_config = AttrDict( 133 | state_dim=data_spec.state_dim, 134 | normalization='none', 135 | demo_data_conf=data_config, 136 | ) 137 | 138 | # HL Agent 139 | hl_agent_config = copy.deepcopy(base_agent_params) 140 | hl_agent_config.update(AttrDict( 141 | policy=LearnedPPPolicy, 142 | policy_params=hl_policy_params, 143 | critic=MLPCritic, 144 | critic_params=hl_critic_params, 145 | discriminator=DemoDiscriminator, 146 | discriminator_params=demo_discriminator_config, 147 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/kitchen/kitchen_discr"), 148 | freeze_discriminator=True, # don't update pretrained discriminator 149 | buffer=UniformReplayBuffer, 150 | buffer_params={'capacity': 1e6,}, 151 | reset_buffer=False, 152 | replay=UniformReplayBuffer, 153 | replay_params={'dump_replay': False, 'capacity': 2e6}, 154 | expert_data_conf=data_config, 155 | expert_data_path=".", 156 | )) 157 | 158 | # SkiLD Parameters 159 | hl_agent_config.update(AttrDict( 160 | lambda_gail_schedule_params=AttrDict(p=0.9), 161 | fixed_alpha=1e-1, 162 | fixed_alpha_q=1e-1, 163 | )) 164 | 165 | 166 | ##### Joint 
Agent ####### 167 | agent_config = AttrDict( 168 | hl_agent=SkiLDAgent, 169 | hl_agent_params=hl_agent_config, 170 | ll_agent=SACAgent, 171 | ll_agent_params=ll_agent_config, 172 | hl_interval=ll_model_params.n_rollout_steps, 173 | log_videos=True, 174 | update_hl=True, 175 | update_ll=False, 176 | ) 177 | 178 | # Sampler 179 | sampler_config = AttrDict( 180 | ) 181 | 182 | # Environment 183 | env_config = AttrDict( 184 | reward_norm=1, 185 | name='kitchen-kbts-v0', 186 | ) 187 | 188 | -------------------------------------------------------------------------------- /skild/configs/demo_rl/maze/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.components.sampler import HierarchicalSampler 9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic 10 | from spirl.rl.agents.ac_agent import SACAgent 11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy 12 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 13 | from spirl.configs.default_data_configs.maze import data_spec 14 | 15 | from skild.rl.policies.posterior_policies import LearnedPPPolicy 16 | from skild.models.demo_discriminator import DemoDiscriminator 17 | from skild.rl.envs.maze import ACRandMaze0S40Env 18 | from skild.rl.agents.skild_agent import SkiLDAgent 19 | from skild.data.maze.src.maze_agents import MazeSkiLDAgent 20 | 21 | 22 | current_dir = os.path.dirname(os.path.realpath(__file__)) 23 | 24 | notes = 'used to test the RL implementation' 25 | 26 | configuration = { 27 | 'seed': 42, 28 | 'agent': FixedIntervalHierarchicalAgent, 29 | 'environment': ACRandMaze0S40Env, 30 | 'sampler': HierarchicalSampler, 31 | 'data_dir': '.', 32 | 'num_epochs': 200, 33 | 'max_rollout_len': 2000, 34 | 'n_steps_per_epoch': 1e5, 35 | 'log_output_per_epoch': 1000, 36 | 'n_warmup_steps': 2e3, 37 | } 38 | configuration = AttrDict(configuration) 39 | 40 | # Observation Normalization 41 | obs_norm_params = AttrDict( 42 | ) 43 | 44 | base_agent_params = AttrDict( 45 | batch_size=128, 46 | ) 47 | 48 | ###### Low-Level ###### 49 | # LL Policy 50 | ll_model_params = AttrDict( 51 | state_dim=data_spec.state_dim, 52 | action_dim=data_spec.n_actions, 53 | n_rollout_steps=10, 54 | kl_div_weight=1e-3, 55 | nz_vae=10, 56 | nz_enc=128, 57 | nz_mid=128, 58 | n_processing_layers=5, 59 | cond_decode=True, 60 | ) 61 | 62 | # LL Policy 63 | ll_policy_params = AttrDict( 64 | policy_model=ClSPiRLMdl, 65 | policy_model_params=ll_model_params, 66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/maze/maze_prior"), 67 | ) 68 | ll_policy_params.update(ll_model_params) 69 | 70 | # LL Critic 71 | ll_critic_params = AttrDict( 72 | action_dim=data_spec.n_actions, 73 | input_dim=data_spec.state_dim, 74 | output_dim=1, 75 | action_input=True, 76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic 77 | ) 78 | 79 | # LL Agent 80 | ll_agent_config = copy.deepcopy(base_agent_params) 81 | ll_agent_config.update(AttrDict( 82 | policy=ClModelPolicy, 83 | policy_params=ll_policy_params, 84 | critic=SplitObsMLPCritic, 85 | critic_params=ll_critic_params, 86 | )) 87 | 88 | ###### High-Level ######## 89 | # HL Policy 90 | hl_policy_params = AttrDict( 91 | action_dim=ll_model_params.nz_vae, 
# z-dimension of the skill VAE 92 | input_dim=data_spec.state_dim, 93 | squash_output_dist=True, 94 | max_action_range=2., 95 | prior_model_params=ll_policy_params.policy_model_params, 96 | prior_model=ll_policy_params.policy_model, 97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint, 98 | posterior_model=ll_policy_params.policy_model, 99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params), 100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/maze/maze_post"), 101 | ) 102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size 103 | 104 | hl_policy_params.policy_model = ll_policy_params.policy_model 105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params) 106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.prior_model_checkpoint 107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size 108 | 109 | 110 | # HL Critic 111 | hl_critic_params = AttrDict( 112 | action_dim=hl_policy_params.action_dim, 113 | input_dim=hl_policy_params.input_dim, 114 | output_dim=1, 115 | n_layers=2, 116 | nz_mid=256, 117 | action_input=True, 118 | ) 119 | 120 | # HL GAIL Demo Dataset 121 | from spirl.components.data_loader import GlobalSplitVideoDataset 122 | data_config = AttrDict() 123 | data_config.dataset_spec = data_spec 124 | data_config.dataset_spec.update(AttrDict( 125 | crop_rand_subseq=True, 126 | subseq_len=2, 127 | n_seqs=10, 128 | seq_repeat=100, 129 | split=AttrDict(train=0.5, val=0.5, test=0.0), 130 | )) 131 | 132 | # HL Pre-Trained Demo Discriminator 133 | demo_discriminator_config = AttrDict( 134 | state_dim=data_spec.state_dim, 135 | normalization='none', 136 | demo_data_conf=data_config, 137 | ) 138 | 139 | # HL Agent 140 | hl_agent_config = copy.deepcopy(base_agent_params) 141 | hl_agent_config.update(AttrDict( 142 | policy=LearnedPPPolicy, 143 | policy_params=hl_policy_params, 144 | critic=MLPCritic, 145 | critic_params=hl_critic_params, 146 | discriminator=DemoDiscriminator, 147 | discriminator_params=demo_discriminator_config, 148 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/maze/maze_discr"), 149 | freeze_discriminator=True, # don't update pretrained discriminator 150 | buffer=UniformReplayBuffer, 151 | buffer_params={'capacity': 1e6,}, 152 | reset_buffer=False, 153 | replay=UniformReplayBuffer, 154 | replay_params={'dump_replay': False, 'capacity': 2e6}, 155 | expert_data_conf=data_config, 156 | expert_data_path=os.path.join(os.environ['DATA_DIR'], 'maze_demos'), 157 | )) 158 | 159 | # SkiLD Parameters 160 | hl_agent_config.update(AttrDict( 161 | lambda_gail_schedule_params=AttrDict(p=0.9), 162 | td_schedule_params=AttrDict(p=10.0), 163 | tdq_schedule_params=AttrDict(p=1.0), 164 | )) 165 | 166 | 167 | ##### Joint Agent ####### 168 | agent_config = AttrDict( 169 | hl_agent=MazeSkiLDAgent, 170 | hl_agent_params=hl_agent_config, 171 | ll_agent=SACAgent, 172 | ll_agent_params=ll_agent_config, 173 | hl_interval=ll_model_params.n_rollout_steps, 174 | log_videos=False, 175 | update_hl=True, 176 | update_ll=False, 177 | ) 178 | 179 | # Sampler 180 | sampler_config = AttrDict( 181 | ) 182 | 183 | # Environment 184 | env_config = AttrDict( 185 | reward_norm=1, 186 | ) 187 | 188 | -------------------------------------------------------------------------------- /skild/configs/demo_rl/office/conf.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import copy 3 | import torch 4 | 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.components.sampler import HierarchicalSampler 9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic 10 | from spirl.rl.agents.ac_agent import SACAgent 11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy 12 | from spirl.rl.envs.office import OfficeEnv 13 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 14 | from spirl.configs.default_data_configs.office import data_spec 15 | 16 | from skild.rl.policies.posterior_policies import LearnedPPPolicy 17 | from skild.models.demo_discriminator import DemoDiscriminator 18 | from skild.rl.agents.skild_agent import SkiLDAgent 19 | 20 | 21 | current_dir = os.path.dirname(os.path.realpath(__file__)) 22 | 23 | notes = 'used to test the RL implementation' 24 | 25 | configuration = { 26 | 'seed': 42, 27 | 'agent': FixedIntervalHierarchicalAgent, 28 | 'environment': OfficeEnv, 29 | 'sampler': HierarchicalSampler, 30 | 'data_dir': '.', 31 | 'num_epochs': 200, 32 | 'max_rollout_len': 350, 33 | 'n_steps_per_epoch': 5e5, 34 | 'log_output_per_epoch': 1000, 35 | 'n_warmup_steps': 2e3, 36 | } 37 | configuration = AttrDict(configuration) 38 | 39 | # Observation Normalization 40 | obs_norm_params = AttrDict( 41 | ) 42 | 43 | base_agent_params = AttrDict( 44 | batch_size=128, 45 | ) 46 | 47 | ###### Low-Level ###### 48 | # LL Policy 49 | ll_model_params = AttrDict( 50 | state_dim=data_spec.state_dim, 51 | action_dim=data_spec.n_actions, 52 | n_rollout_steps=10, 53 | kl_div_weight=5e-4, 54 | nz_vae=10, 55 | nz_enc=128, 56 | nz_mid=128, 57 | n_processing_layers=5, 58 | cond_decode=True, 59 | ) 60 | 61 | # LL Policy 62 | ll_policy_params = AttrDict( 63 | policy_model=ClSPiRLMdl, 64 | policy_model_params=ll_model_params, 65 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/office/office_prior"), 66 | ) 67 | ll_policy_params.update(ll_model_params) 68 | 69 | # LL Critic 70 | ll_critic_params = AttrDict( 71 | action_dim=data_spec.n_actions, 72 | input_dim=data_spec.state_dim, 73 | output_dim=1, 74 | action_input=True, 75 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic 76 | ) 77 | 78 | # LL Agent 79 | ll_agent_config = copy.deepcopy(base_agent_params) 80 | ll_agent_config.update(AttrDict( 81 | policy=ClModelPolicy, 82 | policy_params=ll_policy_params, 83 | critic=SplitObsMLPCritic, 84 | critic_params=ll_critic_params, 85 | )) 86 | 87 | ###### High-Level ######## 88 | # HL Policy 89 | hl_policy_params = AttrDict( 90 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE 91 | input_dim=data_spec.state_dim, 92 | squash_output_dist=True, 93 | max_action_range=2., 94 | prior_model_params=ll_policy_params.policy_model_params, 95 | prior_model=ll_policy_params.policy_model, 96 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint, 97 | posterior_model=ll_policy_params.policy_model, 98 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params), 99 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/office/office_post"), 100 | ) 101 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size 102 | 103 | hl_policy_params.policy_model = ll_policy_params.policy_model 104 | hl_policy_params.policy_model_params = 
copy.deepcopy(ll_policy_params.policy_model_params) 105 | hl_policy_params.policy_model_checkpoint = hl_policy_params.posterior_model_checkpoint 106 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size 107 | 108 | 109 | # HL Critic 110 | hl_critic_params = AttrDict( 111 | action_dim=hl_policy_params.action_dim, 112 | input_dim=hl_policy_params.input_dim, 113 | output_dim=1, 114 | n_layers=2, 115 | nz_mid=256, 116 | action_input=True, 117 | ) 118 | 119 | # HL GAIL Demo Dataset 120 | from spirl.components.data_loader import GlobalSplitVideoDataset 121 | data_config = AttrDict() 122 | data_config.dataset_spec = data_spec 123 | data_config.dataset_spec.update(AttrDict( 124 | crop_rand_subseq=True, 125 | subseq_len=2, 126 | n_seqs=100, 127 | seq_repeat=100, 128 | split=AttrDict(train=0.5, val=0.5, test=0.0), 129 | )) 130 | 131 | # HL Pre-Trained Demo Discriminator 132 | demo_discriminator_config = AttrDict( 133 | state_dim=data_spec.state_dim, 134 | normalization='none', 135 | demo_data_conf=data_config, 136 | ) 137 | 138 | # HL Agent 139 | hl_agent_config = copy.deepcopy(base_agent_params) 140 | hl_agent_config.update(AttrDict( 141 | policy=LearnedPPPolicy, 142 | policy_params=hl_policy_params, 143 | critic=MLPCritic, 144 | critic_params=hl_critic_params, 145 | discriminator=DemoDiscriminator, 146 | discriminator_params=demo_discriminator_config, 147 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/office/office_discr"), 148 | freeze_discriminator=True, # don't update pretrained discriminator 149 | buffer=UniformReplayBuffer, 150 | buffer_params={'capacity': 1e6,}, 151 | reset_buffer=False, 152 | replay=UniformReplayBuffer, 153 | replay_params={'dump_replay': False, 'capacity': 2e6}, 154 | expert_data_conf=data_config, 155 | expert_data_path=os.path.join(os.environ['DATA_DIR'], 'office_demos'), 156 | )) 157 | 158 | # SkiLD Parameters 159 | hl_agent_config.update(AttrDict( 160 | lambda_gail_schedule_params=AttrDict(p=0.9), 161 | fixed_alpha=5.0, 162 | fixed_alpha_q=5.0, 163 | )) 164 | 165 | 166 | ##### Joint Agent ####### 167 | agent_config = AttrDict( 168 | hl_agent=SkiLDAgent, 169 | hl_agent_params=hl_agent_config, 170 | ll_agent=SACAgent, 171 | ll_agent_params=ll_agent_config, 172 | hl_interval=ll_model_params.n_rollout_steps, 173 | log_videos=True, 174 | update_hl=True, 175 | update_ll=False, 176 | ) 177 | 178 | # Sampler 179 | sampler_config = AttrDict( 180 | ) 181 | 182 | # Environment 183 | env_config = AttrDict( 184 | reward_norm=1, 185 | ) 186 | 187 | -------------------------------------------------------------------------------- /skild/configs/imitation/kitchen/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.components.sampler import HierarchicalSampler 9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic 10 | from spirl.rl.agents.ac_agent import SACAgent 11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy 12 | from spirl.rl.envs.kitchen import KitchenEnv 13 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 14 | from spirl.configs.default_data_configs.kitchen import data_spec 15 | 16 | from skild.rl.policies.posterior_policies import LearnedPPPolicy 17 | from 
skild.models.demo_discriminator import DemoDiscriminator 18 | from skild.rl.agents.skild_agent import SkiLDAgent 19 | 20 | 21 | current_dir = os.path.dirname(os.path.realpath(__file__)) 22 | 23 | notes = 'used to test the RL implementation' 24 | 25 | configuration = { 26 | 'seed': 42, 27 | 'agent': FixedIntervalHierarchicalAgent, 28 | 'environment': KitchenEnv, 29 | 'sampler': HierarchicalSampler, 30 | 'data_dir': '.', 31 | 'num_epochs': 200, 32 | 'max_rollout_len': 280, 33 | 'n_steps_per_epoch': 1e6, 34 | 'log_output_per_epoch': 1000, 35 | 'n_warmup_steps': 2e3, 36 | } 37 | configuration = AttrDict(configuration) 38 | 39 | # Observation Normalization 40 | obs_norm_params = AttrDict( 41 | ) 42 | 43 | base_agent_params = AttrDict( 44 | batch_size=128, 45 | # update_iterations=XXX, 46 | ) 47 | 48 | ###### Low-Level ###### 49 | # LL Policy 50 | ll_model_params = AttrDict( 51 | state_dim=data_spec.state_dim, 52 | action_dim=data_spec.n_actions, 53 | n_rollout_steps=10, 54 | kl_div_weight=5e-4, 55 | nz_vae=10, 56 | nz_enc=128, 57 | nz_mid=128, 58 | n_processing_layers=5, 59 | cond_decode=True, 60 | ) 61 | 62 | # LL Policy 63 | ll_policy_params = AttrDict( 64 | policy_model=ClSPiRLMdl, 65 | policy_model_params=ll_model_params, 66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/kitchen/kitchen_prior"), 67 | ) 68 | ll_policy_params.update(ll_model_params) 69 | 70 | # LL Critic 71 | ll_critic_params = AttrDict( 72 | action_dim=data_spec.n_actions, 73 | input_dim=data_spec.state_dim, 74 | output_dim=1, 75 | action_input=True, 76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic 77 | ) 78 | 79 | # LL Agent 80 | ll_agent_config = copy.deepcopy(base_agent_params) 81 | ll_agent_config.update(AttrDict( 82 | policy=ClModelPolicy, 83 | policy_params=ll_policy_params, 84 | critic=SplitObsMLPCritic, 85 | critic_params=ll_critic_params, 86 | )) 87 | 88 | ###### High-Level ######## 89 | # HL Policy 90 | hl_policy_params = AttrDict( 91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE 92 | input_dim=data_spec.state_dim, 93 | squash_output_dist=True, 94 | max_action_range=2., 95 | prior_model_params=ll_policy_params.policy_model_params, 96 | prior_model=ll_policy_params.policy_model, 97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint, 98 | posterior_model=ll_policy_params.policy_model, 99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params), 100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/kitchen/kitchen_post"), 101 | ) 102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size 103 | 104 | hl_policy_params.policy_model = ll_policy_params.policy_model 105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params) 106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.posterior_model_checkpoint 107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size 108 | 109 | 110 | # HL Critic 111 | hl_critic_params = AttrDict( 112 | action_dim=hl_policy_params.action_dim, 113 | input_dim=hl_policy_params.input_dim, 114 | output_dim=1, 115 | n_layers=2, 116 | nz_mid=256, 117 | action_input=True, 118 | ) 119 | 120 | # HL GAIL Demo Dataset 121 | from spirl.components.data_loader import GlobalSplitVideoDataset 122 | data_config = AttrDict() 123 | data_config.dataset_spec = data_spec 124 | data_config.dataset_spec.update(AttrDict( 125 | crop_rand_subseq=True, 
126 | subseq_len=2, 127 | filter_indices=[[320, 337], [339, 344]], 128 | demo_repeats=10, 129 | )) 130 | 131 | # HL Pre-Trained Demo Discriminator 132 | demo_discriminator_config = AttrDict( 133 | state_dim=data_spec.state_dim, 134 | normalization='none', 135 | demo_data_conf=data_config, 136 | ) 137 | 138 | # HL Agent 139 | hl_agent_config = copy.deepcopy(base_agent_params) 140 | hl_agent_config.update(AttrDict( 141 | policy=LearnedPPPolicy, 142 | policy_params=hl_policy_params, 143 | critic=MLPCritic, 144 | critic_params=hl_critic_params, 145 | discriminator=DemoDiscriminator, 146 | discriminator_params=demo_discriminator_config, 147 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], "demo_discriminator/kitchen/kitchen_discr"), 148 | freeze_discriminator=False, # don't update pretrained discriminator 149 | discriminator_updates=5e-4, 150 | buffer=UniformReplayBuffer, 151 | buffer_params={'capacity': 1e6,}, 152 | reset_buffer=False, 153 | replay=UniformReplayBuffer, 154 | replay_params={'dump_replay': False, 'capacity': 2e6}, 155 | expert_data_conf=data_config, 156 | expert_data_path=".", 157 | )) 158 | 159 | # SkiLD Parameters 160 | hl_agent_config.update(AttrDict( 161 | lambda_gail_schedule_params=AttrDict(p=0.9), 162 | fixed_alpha=1e-1, 163 | fixed_alpha_q=1e-1, 164 | )) 165 | 166 | 167 | ##### Joint Agent ####### 168 | agent_config = AttrDict( 169 | hl_agent=SkiLDAgent, 170 | hl_agent_params=hl_agent_config, 171 | ll_agent=SACAgent, 172 | ll_agent_params=ll_agent_config, 173 | hl_interval=ll_model_params.n_rollout_steps, 174 | log_videos=True, 175 | update_hl=True, 176 | update_ll=False, 177 | ) 178 | 179 | # Sampler 180 | sampler_config = AttrDict( 181 | ) 182 | 183 | # Environment 184 | env_config = AttrDict( 185 | reward_norm=1, 186 | name='kitchen-kbts-v0', 187 | ) 188 | 189 | -------------------------------------------------------------------------------- /skild/configs/imitation/maze/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.rl.components.agent import FixedIntervalHierarchicalAgent 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.components.sampler import HierarchicalSampler 9 | from spirl.rl.components.critic import MLPCritic, SplitObsMLPCritic 10 | from spirl.rl.agents.ac_agent import SACAgent 11 | from spirl.rl.policies.cl_model_policies import ClModelPolicy 12 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 13 | from spirl.configs.default_data_configs.maze import data_spec 14 | 15 | from skild.rl.policies.posterior_policies import LearnedPPPolicy 16 | from skild.models.demo_discriminator import DemoDiscriminator 17 | from skild.rl.envs.maze import ACRandMaze0S40Env 18 | from skild.rl.agents.skild_agent import SkiLDAgent 19 | from skild.data.maze.src.maze_agents import MazeSkiLDAgent 20 | 21 | 22 | current_dir = os.path.dirname(os.path.realpath(__file__)) 23 | 24 | notes = 'used to test the RL implementation' 25 | 26 | configuration = { 27 | 'seed': 42, 28 | 'agent': FixedIntervalHierarchicalAgent, 29 | 'environment': ACRandMaze0S40Env, 30 | 'sampler': HierarchicalSampler, 31 | 'data_dir': '.', 32 | 'num_epochs': 200, 33 | 'max_rollout_len': 2000, 34 | 'n_steps_per_epoch': 1e5, 35 | 'log_output_per_epoch': 1000, 36 | 'n_warmup_steps': 2e3, 37 | } 38 | configuration = AttrDict(configuration) 39 | 40 | # Observation Normalization 41 | obs_norm_params = 
AttrDict( 42 | ) 43 | 44 | base_agent_params = AttrDict( 45 | batch_size=128, 46 | ) 47 | 48 | ###### Low-Level ###### 49 | # LL Policy 50 | ll_model_params = AttrDict( 51 | state_dim=data_spec.state_dim, 52 | action_dim=data_spec.n_actions, 53 | n_rollout_steps=10, 54 | kl_div_weight=1e-3, 55 | nz_vae=10, 56 | nz_enc=128, 57 | nz_mid=128, 58 | n_processing_layers=5, 59 | cond_decode=True, 60 | ) 61 | 62 | # LL Policy 63 | ll_policy_params = AttrDict( 64 | policy_model=ClSPiRLMdl, 65 | policy_model_params=ll_model_params, 66 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior/maze/maze_prior"), 67 | ) 68 | ll_policy_params.update(ll_model_params) 69 | 70 | # LL Critic 71 | ll_critic_params = AttrDict( 72 | action_dim=data_spec.n_actions, 73 | input_dim=data_spec.state_dim, 74 | output_dim=1, 75 | action_input=True, 76 | unused_obs_size=ll_model_params.nz_vae, # ignore HL policy z output in observation for LL critic 77 | ) 78 | 79 | # LL Agent 80 | ll_agent_config = copy.deepcopy(base_agent_params) 81 | ll_agent_config.update(AttrDict( 82 | policy=ClModelPolicy, 83 | policy_params=ll_policy_params, 84 | critic=SplitObsMLPCritic, 85 | critic_params=ll_critic_params, 86 | )) 87 | 88 | ###### High-Level ######## 89 | # HL Policy 90 | hl_policy_params = AttrDict( 91 | action_dim=ll_model_params.nz_vae, # z-dimension of the skill VAE 92 | input_dim=data_spec.state_dim, 93 | squash_output_dist=True, 94 | max_action_range=2., 95 | prior_model_params=ll_policy_params.policy_model_params, 96 | prior_model=ll_policy_params.policy_model, 97 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint, 98 | posterior_model=ll_policy_params.policy_model, 99 | posterior_model_params=copy.deepcopy(ll_policy_params.policy_model_params), 100 | posterior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_posterior/maze/maze_post"), 101 | ) 102 | hl_policy_params.posterior_model_params.batch_size = base_agent_params.batch_size 103 | 104 | hl_policy_params.policy_model = ll_policy_params.policy_model 105 | hl_policy_params.policy_model_params = copy.deepcopy(ll_policy_params.policy_model_params) 106 | hl_policy_params.policy_model_checkpoint = hl_policy_params.prior_model_checkpoint 107 | hl_policy_params.policy_model_params.batch_size = base_agent_params.batch_size 108 | 109 | 110 | # HL Critic 111 | hl_critic_params = AttrDict( 112 | action_dim=hl_policy_params.action_dim, 113 | input_dim=hl_policy_params.input_dim, 114 | output_dim=1, 115 | n_layers=2, 116 | nz_mid=256, 117 | action_input=True, 118 | ) 119 | 120 | # HL GAIL Demo Dataset 121 | from spirl.components.data_loader import GlobalSplitVideoDataset 122 | data_config = AttrDict() 123 | data_config.dataset_spec = data_spec 124 | data_config.dataset_spec.update(AttrDict( 125 | crop_rand_subseq=True, 126 | subseq_len=2, 127 | n_seqs=10, 128 | seq_repeat=100, 129 | split=AttrDict(train=0.5, val=0.5, test=0.0), 130 | )) 131 | 132 | # HL Pre-Trained Demo Discriminator 133 | demo_discriminator_config = AttrDict( 134 | state_dim=data_spec.state_dim, 135 | normalization='none', 136 | demo_data_conf=data_config, 137 | ) 138 | 139 | # HL Agent 140 | hl_agent_config = copy.deepcopy(base_agent_params) 141 | hl_agent_config.update(AttrDict( 142 | policy=LearnedPPPolicy, 143 | policy_params=hl_policy_params, 144 | critic=MLPCritic, 145 | critic_params=hl_critic_params, 146 | discriminator=DemoDiscriminator, 147 | discriminator_params=demo_discriminator_config, 148 | discriminator_checkpoint=os.path.join(os.environ["EXP_DIR"], 
"demo_discriminator/maze/maze_discr"), 149 | freeze_discriminator=False, # don't update pretrained discriminator 150 | discriminator_updates=0.2, 151 | buffer=UniformReplayBuffer, 152 | buffer_params={'capacity': 1e6,}, 153 | reset_buffer=False, 154 | replay=UniformReplayBuffer, 155 | replay_params={'dump_replay': False, 'capacity': 2e6}, 156 | expert_data_conf=data_config, 157 | expert_data_path=os.path.join(os.environ['DATA_DIR'], 'maze_demos'), 158 | )) 159 | 160 | # SkiLD Parameters 161 | hl_agent_config.update(AttrDict( 162 | lambda_gail_schedule_params=AttrDict(p=0.9), 163 | td_schedule_params=AttrDict(p=10.0), 164 | tdq_schedule_params=AttrDict(p=1.0), 165 | )) 166 | 167 | 168 | ##### Joint Agent ####### 169 | agent_config = AttrDict( 170 | hl_agent=MazeSkiLDAgent, 171 | hl_agent_params=hl_agent_config, 172 | ll_agent=SACAgent, 173 | ll_agent_params=ll_agent_config, 174 | hl_interval=ll_model_params.n_rollout_steps, 175 | log_videos=False, 176 | update_hl=True, 177 | update_ll=False, 178 | ) 179 | 180 | # Sampler 181 | sampler_config = AttrDict( 182 | ) 183 | 184 | # Environment 185 | env_config = AttrDict( 186 | reward_norm=1, 187 | ) 188 | 189 | -------------------------------------------------------------------------------- /skild/configs/skill_posterior/kitchen/conf.py: -------------------------------------------------------------------------------- 1 | from skild.configs.skill_prior.kitchen.conf import * 2 | 3 | data_config.dataset_spec.filter_indices = [[320, 337], [339, 344]] # use only demos for one task (here: KBTS) 4 | data_config.dataset_spec.demo_repeats = 10 # repeat those demos N times 5 | 6 | model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"], 7 | "skill_prior/kitchen/kitchen_prior/weights") 8 | -------------------------------------------------------------------------------- /skild/configs/skill_posterior/maze/conf.py: -------------------------------------------------------------------------------- 1 | from skild.configs.skill_prior.maze.conf import * 2 | 3 | configuration['data_dir'] = os.path.join(os.environ['DATA_DIR'], 'maze_demos') 4 | data_config.dataset_spec.n_seqs = 5 # number of demos 5 | data_config.dataset_spec.seq_repeat = 30 # how often to repeat these demos 6 | 7 | configuration['epoch_cycles_train'] = 4200 8 | 9 | model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"], 10 | "skill_prior/maze/maze_prior/weights") 11 | -------------------------------------------------------------------------------- /skild/configs/skill_posterior/office/conf.py: -------------------------------------------------------------------------------- 1 | from skild.configs.skill_prior.office.conf import * 2 | 3 | configuration['data_dir'] = os.path.join(os.environ['DATA_DIR'], 'office_demos') 4 | data_config.dataset_spec.n_seqs = 50 # number of demos 5 | data_config.dataset_spec.seq_repeat = 3 # how often to repeat these demos 6 | 7 | configuration['epoch_cycles_train'] = 6000 8 | 9 | model_config.embedding_checkpoint = os.path.join(os.environ["EXP_DIR"], 10 | "skill_prior/office/office_prior/weights") 11 | -------------------------------------------------------------------------------- /skild/configs/skill_prior/kitchen/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 4 | from spirl.components.logger import Logger 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.kitchen import 
data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | configuration = { 13 | 'model': ClSPiRLMdl, 14 | 'logger': Logger, 15 | 'data_dir': '.', 16 | 'epoch_cycles_train': 50, 17 | 'num_epochs': 100, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=5e-4, 29 | nz_enc=128, 30 | nz_mid=128, 31 | n_processing_layers=5, 32 | cond_decode=True, 33 | ) 34 | 35 | # Dataset 36 | data_config = AttrDict() 37 | data_config.dataset_spec = data_spec 38 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # flat last action from seq gets cropped 39 | -------------------------------------------------------------------------------- /skild/configs/skill_prior/maze/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 4 | from spirl.components.logger import Logger 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.maze import data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | configuration = { 13 | 'model': ClSPiRLMdl, 14 | 'logger': Logger, 15 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze_TA'), 16 | 'epoch_cycles_train': 250, 17 | 'num_epochs': 100, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=1e-3, 29 | nz_enc=128, 30 | nz_mid=128, 31 | n_processing_layers=5, 32 | cond_decode=True, 33 | ) 34 | 35 | # Dataset 36 | data_config = AttrDict() 37 | data_config.dataset_spec = data_spec 38 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # flat last action from seq gets cropped 39 | -------------------------------------------------------------------------------- /skild/configs/skill_prior/office/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 4 | from spirl.components.logger import Logger 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.office import data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | configuration = { 13 | 'model': ClSPiRLMdl, 14 | 'logger': Logger, 15 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'office_TA'), 16 | 'epoch_cycles_train': 300, 17 | 'num_epochs': 100, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=5e-4, 29 | nz_enc=128, 30 | nz_mid=128, 31 | n_processing_layers=5, 32 | cond_decode=True, 33 | ) 34 | 35 | # Dataset 36 | data_config = 
AttrDict() 37 | data_config.dataset_spec = data_spec 38 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # flat last action from seq gets cropped 39 | -------------------------------------------------------------------------------- /skild/data/kitchen/README.md: -------------------------------------------------------------------------------- 1 | # Choosing Kitchen Target Tasks 2 | 3 | In the kitchen environment a task defines the consecutive execution of four subtasks. 4 | Some subtask sequences can be more challenging for agents to learn than others. For SkiLD, as well as for SPiRL and any 5 | other approach that leverages prior experience, the task complexity is mainly influenced by how well the respective 6 | subtask transitions are represented in the prior experience data. 7 | 8 | We use the training data of Gupta et al., 2020 for training our models on the kitchen tasks. 9 | The subtask transitions in this dataset are not uniformly distributed, i.e., certain subtask sequences are more likely 10 | than others. Thus, we can define easier tasks in the kitchen environments as those that require more likely subtask 11 | transitions. Conversely, more challenging tasks will require unlikely or unseen subtask transitions. 12 | 13 | In the SkiLD paper, we analyze the effect of target tasks of differing alignment with the pre-training data 14 | in the kitchen environment (Section 4.4). We also provide an analysis of the subtask transition probabilities in the 15 | dataset of Gupta et al. (Figure 14, see below), which we can use to determine tasks of varying complexity. 16 | 17 |
<img src="../../../docs/resources/kitchen_subtask_distribution.png"/>
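The per-trajectory subtask sequences underlying this analysis can be re-generated with the script provided in this directory (listed in full below). Assuming the environment from the installation instructions is set up (including our D4RL fork, which provides the kitchen dataset), the script can be run directly from the repository root:
```
python3 skild/data/kitchen/kitchen_subtasks.py
```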
18 | 19 | 20 | ## Changing the Kitchen Target Task 21 | 22 | By default, the configs provided in this repository can be used to train a SkiLD agent on the 23 | `Kettle --> Bottom Burner --> Top Burner --> Slide Cabinet` task, whose transitions are relatively well-represented in the 24 | pre-training data. In order to change the target task, the values for `filter_indices` in the configs for training 25 | of the [skill posterior](../../../skild/configs/skill_posterior/kitchen/conf.py#L3) and 26 | [demonstration discriminator](../../../skild/configs/demo_discriminator/kitchen/conf.py#L35) need to be adjusted. 27 | These indices determine the sequences from the pre-training data that are used as demonstrations for the downstream task. 28 | 29 | To obtain a mapping from the sequence of solved subtasks, and thus the task, to the trajectory indices, we provide 30 | [a script](../../../skild/data/kitchen/kitchen_subtasks.py) that determines the subtask sequence for each of the 603 sequences in the pre-training dataset. 31 | If you want to change the target task, simply run this script, determine the indices for the desired subtask sequence, and adjust 32 | the config files linked above. For example, for the challenging task 33 | `Microwave --> Light Switch --> Slide Cabinet --> Hinge Cabinet` used in the paper, the indices need to be changed to 34 | `[190, 210]` for training the skill posterior and discriminator. 35 | 36 | If you want to run RL on the new target task, you need to change the name of the Kitchen environment in the RL config 37 | accordingly, e.g., [HERE](../../../skild/configs/demo_rl/kitchen/conf.py#L185) from `kitchen-kbts-v0` to `kitchen-mlsh-v0`. 38 | 39 | 40 | -------------------------------------------------------------------------------- /skild/data/kitchen/kitchen_subtasks.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import gym 3 | import d4rl 4 | import numpy as np 5 | 6 | from spirl.utils.general_utils import AttrDict 7 | from spirl.configs.default_data_configs.kitchen import data_spec 8 | 9 | 10 | OBJS = ['bottom burner', 'top burner', 'light switch', 'slide cabinet', 'hinge cabinet', 'microwave', 'kettle'] 11 | OBS_ELEMENT_INDICES = { 12 | 'bottom burner': np.array([11, 12]), 13 | 'top burner': np.array([15, 16]), 14 | 'light switch': np.array([17, 18]), 15 | 'slide cabinet': np.array([19]), 16 | 'hinge cabinet': np.array([20, 21]), 17 | 'microwave': np.array([22]), 18 | 'kettle': np.array([23, 24, 25, 26, 27, 28, 29]), 19 | } 20 | OBS_ELEMENT_GOALS = { 21 | 'bottom burner': np.array([-0.88, -0.01]), 22 | 'top burner': np.array([-0.92, -0.01]), 23 | 'light switch': np.array([-0.69, -0.05]), 24 | 'slide cabinet': np.array([0.37]), 25 | 'hinge cabinet': np.array([0., 1.45]), 26 | 'microwave': np.array([-0.75]), 27 | 'kettle': np.array([-0.23, 0.75, 1.62, 0.99, 0., 0., -0.06]), 28 | } 29 | BONUS_THRESH = 0.3 30 | 31 | 32 | ## Demo Dataset 33 | demo_data_config = AttrDict() 34 | demo_data_config.device = 'cpu' 35 | demo_data_config.dataset_spec = data_spec 36 | demo_data_config.dataset_spec.crop_rand_subseq = True 37 | demo_data_config.dataset_spec.subseq_len = 1+1+(3-1) 38 | 39 | loader = data_spec.dataset_class('.', demo_data_config, resolution=32, phase='train', shuffle=True, dataset_size=-1) 40 | seqs = loader.seqs 41 | 42 | ## determine achieved subgoals + respective time steps 43 | n_seqs, n_objs = len(seqs), len(OBJS) 44 | subtask_steps = np.Inf * np.ones((n_seqs, n_objs)) 45 | for s_idx, seq in tqdm.tqdm(enumerate(seqs)): 
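# The loop body below records, for each object, the first timestep at which it reaches its goal
# configuration (distance below BONUS_THRESH); objects that are never manipulated keep np.Inf.
# Sorting these first-completion times per sequence (further below) yields the order of the four
# completed subtasks, i.e. the task solved by that demonstration.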
46 | for o_idx, obj in enumerate(OBJS): 47 | for t, state in enumerate(seq.states): 48 | obj_state, obj_goal = state[OBS_ELEMENT_INDICES[obj]], OBS_ELEMENT_GOALS[obj] 49 | dist = np.linalg.norm(obj_state - obj_goal) 50 | if dist < BONUS_THRESH and subtask_steps[s_idx, o_idx] == np.Inf: 51 | subtask_steps[s_idx, o_idx] = t 52 | 53 | ## print subtask orders 54 | print("\n\n") 55 | 56 | subtask_freqs = {k+'_'+j+'_'+i+'_'+kk: 0 for k in OBJS for j in OBJS for i in OBJS for kk in OBJS} 57 | for s_idx, subtasks in enumerate(subtask_steps): 58 | min_task_idxs = np.argsort(subtasks)[:4] 59 | objs = [OBJS[i] for i in min_task_idxs] 60 | subtask_freqs[OBJS[min_task_idxs[0]]+'_'+OBJS[min_task_idxs[1]]\ 61 | +'_'+OBJS[min_task_idxs[2]]+'_'+OBJS[min_task_idxs[3]]] += 1 62 | print("seq {}: {}".format(s_idx, objs)) 63 | 64 | print("\n\n") 65 | for k in subtask_freqs: 66 | if subtask_freqs[k] > 0: 67 | print(k,": ", subtask_freqs[k]) 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /skild/data/maze/src/maze_agents.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 4 | from spirl.data.maze.src.maze_agents import MazeAgent 5 | from skild.rl.agents.skild_agent import SkiLDAgent 6 | 7 | 8 | 9 | class MazeSkiLDAgent(SkiLDAgent, MazeAgent): 10 | def __init__(self, *args, **kwargs): 11 | SkiLDAgent.__init__(self, *args, **kwargs) 12 | self.vis_replay_buffer = UniformReplayBuffer({'capacity': 1e7}) # purely for logging purposes 13 | 14 | def visualize(self, logger, rollout_storage, step): 15 | MazeAgent._vis_replay_buffer(self, logger, step) 16 | MazeSkiLDAgent._vis_replay_buffer(self, logger, step) 17 | 18 | def _vis_replay_buffer(self, logger, step): 19 | # visualize discriminator rewards 20 | if 'discr_reward' in self.vis_replay_buffer: 21 | # get data 22 | size = self.vis_replay_buffer.size 23 | start = max(0, size-5000) 24 | states = self.vis_replay_buffer.get().observation[start:size, :2] 25 | rewards = self.vis_replay_buffer.get().discr_reward[start:size] 26 | 27 | fig = plt.figure() 28 | plt.scatter(states[:, 0], states[:, 1], s=5, c=rewards, cmap='RdYlGn') 29 | plt.axis("equal") 30 | logger.log_plot(fig, "discr_reward_vis", step) 31 | plt.close(fig) 32 | -------------------------------------------------------------------------------- /skild/models/demo_discriminator.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from matplotlib import pyplot as plt 7 | 8 | from spirl.components.base_model import BaseModel 9 | from spirl.components.logger import Logger 10 | from spirl.components.checkpointer import CheckpointHandler 11 | from spirl.modules.losses import BCELogitsLoss 12 | from spirl.modules.subnetworks import Predictor, Encoder 13 | from spirl.utils.general_utils import AttrDict, ParamDict, map_dict 14 | from spirl.utils.vis_utils import fig2img 15 | from spirl.utils.pytorch_utils import RemoveSpatial, ResizeSpatial 16 | from spirl.modules.layers import LayerBuilderParams 17 | 18 | 19 | class DemoDiscriminator(BaseModel): 20 | """Simple feed forward predictor network that distinguishes demo and non-demo states.""" 21 | def __init__(self, params, logger=None): 22 | BaseModel.__init__(self, logger) 23 | self._hp = self._default_hparams() 24 | 
self._hp.overwrite(params) # override defaults with config file 25 | self._hp.builder = LayerBuilderParams(self._hp.use_convs, self._hp.normalization) 26 | self.device = self._hp.device 27 | 28 | # set up demo dataset 29 | if self._hp.demo_data_path is not None: 30 | self._hp.demo_data_conf.device = self.device 31 | self._demo_data_loader = self._hp.demo_data_conf.dataset_spec.dataset_class( 32 | self._hp.demo_data_path, self._hp.demo_data_conf, resolution=self._hp.demo_data_conf.dataset_spec.res, 33 | phase="train", shuffle=True).get_data_loader(self._hp.batch_size, 34 | n_repeat=10000) # making new iterators is slow, so repeat often 35 | self._demo_data_iter = iter(self._demo_data_loader) 36 | 37 | self.build_network() 38 | 39 | def _default_hparams(self): 40 | # put new parameters in here: 41 | return super()._default_hparams().overwrite(ParamDict({ 42 | 'state_dim': None, 43 | 'action_dim': None, 44 | 'use_convs': False, 45 | 'device': None, 46 | 'nz_enc': 32, # number of dimensions in encoder-latent space 47 | 'nz_mid': 32, # number of dimensions for internal feature spaces 48 | 'n_processing_layers': 3, # number of layers in MLPs 49 | 'action_input': False, # if True, conditions on action in addition to state 50 | 'demo_data_conf': {}, # data configuration for demo dataset 51 | 'demo_data_path': None, # path to demo data directory 52 | })) 53 | 54 | def build_network(self): 55 | assert not self._hp.use_convs # currently only supports non-image inputs 56 | self.demo_discriminator = self.build_discriminator() 57 | 58 | def forward(self, inputs): 59 | """forward pass at training time""" 60 | # run discriminator in test-mode if no dataset for training is given 61 | if self._hp.demo_data_path is None: 62 | return self.demo_discriminator(inputs) 63 | 64 | output = AttrDict() 65 | 66 | # sample demo inputs 67 | demo_inputs = self._get_demo_batch() 68 | 69 | # run discriminator on demo and non-demo data 70 | output.demo_logits = self.demo_discriminator(self._discriminator_input(demo_inputs)) 71 | output.nondemo_logits = self.demo_discriminator(self._discriminator_input(inputs)) 72 | output.logits = torch.cat((output.demo_logits, output.nondemo_logits)) 73 | 74 | # compute targets for discriminator outputs 75 | output.targets = torch.cat((torch.ones_like(output.demo_logits), torch.zeros_like(output.nondemo_logits))) 76 | 77 | return output 78 | 79 | def loss(self, model_output, inputs): 80 | losses = AttrDict() 81 | 82 | # discriminator loss 83 | losses.discriminator_loss = BCELogitsLoss(1.)(model_output.logits, model_output.targets) 84 | 85 | losses.total = self._compute_total_loss(losses) 86 | return losses 87 | 88 | def _log_outputs(self, model_output, inputs, losses, step, log_images, phase, logger, **logging_kwargs): 89 | # log videos/gifs in tensorboard 90 | if 'demo_logits' in model_output: 91 | logger.log_scalar(torch.sigmoid(model_output.demo_logits).mean(), "p_demo", step, phase) 92 | logger.log_scalar(torch.sigmoid(model_output.nondemo_logits).mean(), "p_nondemo", step, phase) 93 | 94 | def build_discriminator(self): 95 | return Predictor(self._hp, input_size=self.discriminator_input_size, output_size=1, 96 | num_layers=self._hp.n_processing_layers, mid_size=self._hp.nz_mid) 97 | 98 | def _discriminator_input(self, inputs): 99 | if not self._hp.action_input: 100 | return inputs.states[:, 0] 101 | else: 102 | return torch.cat((inputs.states[:, 0], inputs.actions[:, 0]), dim=-1) 103 | 104 | def _get_demo_batch(self): 105 | try: 106 | demo_batch = next(self._demo_data_iter) 107 | 
except StopIteration: 108 | self._demo_data_iter = iter(self._demo_data_loader) 109 | demo_batch = next(self._demo_data_iter) 110 | return AttrDict(map_dict(lambda x: x.to(self.device), demo_batch)) 111 | 112 | def evaluate_discriminator(self, state): 113 | """Evaluates discriminator probability.""" 114 | return nn.Sigmoid()(self.demo_discriminator(state)) 115 | 116 | @property 117 | def resolution(self): 118 | return 64 # return dummy resolution, images are not used by this model 119 | 120 | @property 121 | def discriminator_input_size(self): 122 | return self._hp.state_dim if not self._hp.action_input else self._hp.state_dim + self._hp.action_dim 123 | 124 | @contextmanager 125 | def val_mode(self): 126 | pass 127 | yield 128 | pass 129 | 130 | 131 | class ImageDemoDiscriminator(DemoDiscriminator): 132 | """Implements demo discriminator with image input.""" 133 | def _default_hparams(self): 134 | default_dict = ParamDict({ 135 | 'discriminator_input_res': 32, # input resolution of prior images 136 | 'encoder_ngf': 8, # number of feature maps in shallowest level of encoder 137 | 'n_input_frames': 1, # number of prior input frames 138 | }) 139 | # add new params to parent params 140 | return super()._default_hparams().overwrite(default_dict) 141 | 142 | def _updated_encoder_params(self): 143 | params = copy.deepcopy(self._hp) 144 | return params.overwrite(AttrDict( 145 | use_convs=True, 146 | use_skips=False, # no skip connections needed bc we are not reconstructing 147 | img_sz=self._hp.discriminator_input_res, # image resolution 148 | input_nc=3*self._hp.n_input_frames, # number of input feature maps 149 | ngf=self._hp.encoder_ngf, # number of feature maps in shallowest level 150 | nz_enc=self.discriminator_input_size, # size of image encoder output feature 151 | builder=LayerBuilderParams(use_convs=True, normalization=self._hp.normalization) 152 | )) 153 | 154 | def build_discriminator(self): 155 | return nn.Sequential( 156 | ResizeSpatial(self._hp.discriminator_input_res), 157 | Encoder(self._updated_encoder_params()), 158 | RemoveSpatial(), 159 | super().build_discriminator(), 160 | ) 161 | 162 | def _discriminator_input(self, inputs): 163 | assert not self._hp.action_input # action input currently not supported for image discriminator 164 | return inputs.images[:, :self._hp.n_input_frames]\ 165 | .reshape(inputs.images.shape[0], -1, self.resolution, self.resolution) 166 | 167 | def filter_input(self, raw_input): 168 | assert raw_input.shape[-1] == raw_input.shape[-2] == self.resolution 169 | assert len(raw_input.shape) == 4 # [batch, channels, res, res] 170 | return raw_input[:, :self._hp.n_input_frames*3] 171 | 172 | @property 173 | def discriminator_input_size(self): 174 | return self._hp.nz_mid 175 | 176 | @property 177 | def resolution(self): 178 | return self._hp.discriminator_input_res 179 | 180 | 181 | class DemoDiscriminatorLogger(Logger): 182 | """ 183 | Logger for Skill Space model. No extra methods needed to implement by 184 | environment-specific logger implementation. 
185 | """ 186 | N_LOGGING_SAMPLES = 5000 # number of samples from demo / non-demo used for logging 187 | 188 | def visualize(self, model_output, inputs, losses, step, phase, logger): 189 | pass 190 | 191 | @staticmethod 192 | def plot_discriminator_samples(demo_samples, non_demo_samples, logger, step, phase): 193 | # plot histogram of demo and non-demo sample probabilities 194 | bins = np.linspace(0, 1, 50) 195 | fig = plt.figure() 196 | plt.hist(demo_samples.p_demo, bins, alpha=0.5, label='demo') 197 | plt.hist(non_demo_samples.p_demo, bins, alpha=0.5, label='nondemo') 198 | plt.legend(loc='upper right') 199 | logger.log_images([fig2img(fig)], "p_demo_hist", step, phase) 200 | plt.close(fig) 201 | 202 | # plot 2D map of states with color-coded demo probabilities 203 | fig = plt.figure() 204 | plt.scatter(np.concatenate((demo_samples.states[:, 0], non_demo_samples.states[:, 0])), 205 | np.concatenate((demo_samples.states[:, 1], non_demo_samples.states[:, 1])), s=5, 206 | c=np.concatenate((demo_samples.p_demo, non_demo_samples.p_demo)), cmap='RdYlGn') 207 | plt.axis("equal") 208 | logger.log_images([fig2img(fig)], "maze_p_demo_vis", step, phase) 209 | plt.close(fig) 210 | -------------------------------------------------------------------------------- /skild/rl/agents/gail_agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | from torch.nn import BCEWithLogitsLoss 5 | from torch import autograd 6 | 7 | from spirl.utils.general_utils import ParamDict, AttrDict, map_dict, ConstantSchedule 8 | from spirl.utils.pytorch_utils import map2torch, map2np, ten2ar, update_optimizer_lr 9 | from spirl.rl.agents.ac_agent import SACAgent 10 | from spirl.rl.components.agent import BaseAgent 11 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent 12 | from skild.rl.agents.ppo_agent import PPOAgent 13 | 14 | 15 | class GAILAgent(PPOAgent): 16 | """Implements GAIL-based agent. 
Discriminator determines reward, policy update is inherited.""" 17 | EPS = 1e-20 # constant for numerical stability in computing discriminator-based rewards 18 | 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | self._init_gail() 22 | 23 | def _init_gail(self): 24 | # set up discriminator 25 | self.discriminator = self._hp.discriminator(self._hp.discriminator_params) 26 | self.discriminator_opt = self._get_optimizer(self._hp.optimizer, self.discriminator, self._hp.discriminator_lr) 27 | if self._hp.discriminator_checkpoint is not None: 28 | BaseAgent.load_model_weights(self.discriminator, 29 | self._hp.discriminator_checkpoint, 30 | self._hp.discriminator_epoch) 31 | 32 | # load demo dataset 33 | self._hp.expert_data_conf.device = self.device.type 34 | self._expert_data_loader = self._hp.expert_data_conf.dataset_spec.dataset_class( 35 | self._hp.expert_data_path, self._hp.expert_data_conf, resolution=self._hp.expert_data_conf.dataset_spec.res, 36 | phase="train", shuffle=True).get_data_loader(self._hp.batch_size, n_repeat=10000) # making new iterators is slow, so repeat often 37 | self._expert_data_iter = iter(self._expert_data_loader) 38 | 39 | # set up trajectory buffer for discriminator training 40 | self.gail_trajectory_buffer = self._hp.buffer(self._hp.buffer_params) \ 41 | if 'buffer' in self._hp and self._hp.buffer is not None \ 42 | else self._hp.replay(self._hp.replay_params) # in case we use GAIL w/ SAC 43 | self.gail_trajectory_buffer.reset() 44 | 45 | # misc 46 | self._discriminator_update_cycles = 0 47 | self._lambda_gail = self._hp.lambda_gail_schedule(self._hp.lambda_gail_schedule_params) 48 | 49 | # optionally run BC for policy init 50 | if self._hp.bc_init_steps > 0: 51 | self._run_bc_init() 52 | 53 | def _default_hparams(self): 54 | default_dict = ParamDict({ 55 | 'discriminator': None, # discriminator class 56 | 'discriminator_params': None, # parameters for the discriminator class 57 | 'discriminator_checkpoint': None, # checkpoint to load discriminator from 58 | 'discriminator_epoch': 'latest', # epoch at which to load discriminator weights 59 | 'discriminator_lr': 3e-4, # learning rate for discriminator update 60 | 'freeze_discriminator': False, # if True, does not update discriminator 61 | 'expert_data_conf': None, # data config for expert sequences 62 | 'expert_data_path': None, # path to expert data sequences 63 | 'reset_buffer': True, # if True, resets online buffer every update iteration 64 | 'discriminator_updates': 5, # number of discriminator updates per PPO policy update cycle 65 | 'lambda_gail_schedule': ConstantSchedule, # schedule for lambda parameter 66 | 'lambda_gail_schedule_params': AttrDict(p=0.0), # factor for original reward when mixing with GAIL reward 67 | 'grad_penalty_coefficient': 0.0, # discriminator gradient penalty coefficient 68 | 'entropy_coefficient_gail': 0.0, # discriminator entropy loss coefficient 69 | 'warmup_cycles': 0, # number of first calls to update() in which only discriminator gets trained 70 | 'bc_init_steps': 0, # number of BC steps for policy before GAIL training starts 71 | }) 72 | return super()._default_hparams().overwrite(default_dict) 73 | 74 | def update(self, experience_batch): 75 | self.gail_info = {} 76 | if self._lr(self._env_steps) < 1e-10: return {} # stop running updates if learning rate is decayed to 0 77 | if self._discriminator_update_cycles < self._hp.warmup_cycles: 78 | # only train discriminator during warmup, do not update policy 79 | 
self._add_experience_discriminator_buffer(experience_batch) 80 | self._update_discriminator() 81 | return self.gail_info 82 | else: 83 | # after warmup we first update discriminator, then policy (both handled by super().update()) 84 | info = super().update(experience_batch) 85 | info.update(self.gail_info) 86 | return info 87 | 88 | def _update_discriminator(self): 89 | """Performs one training update for the discriminator.""" 90 | if self._hp.freeze_discriminator: 91 | return # do not update discriminator if it is frozen 92 | 93 | n_discriminator_updates = self._hp.discriminator_updates if self._hp.discriminator_updates >= 1 else \ 94 | int(np.random.rand() < self._hp.discriminator_updates) 95 | for _ in range(n_discriminator_updates): 96 | # sample expert and policy data batches 97 | expert_batch = self._get_expert_batch() 98 | policy_batch = self.gail_trajectory_buffer.sample(n_samples=self._hp.batch_size) 99 | policy_batch = map2torch(policy_batch, self._hp.device) 100 | 101 | # run discriminator 102 | expert_disc_outputs = self.discriminator(self.discriminator._discriminator_input( 103 | AttrDict(states=expert_batch.states, 104 | actions=expert_batch.actions))) 105 | policy_disc_outputs = self.discriminator(self.discriminator._discriminator_input( 106 | AttrDict(states=policy_batch.observation[:, None], 107 | actions=policy_batch.action[:, None]))) 108 | 109 | # compute discriminator losses: cross-entropy, entropy and gradient penalty loss 110 | expert_logits, policy_logits = expert_disc_outputs, policy_disc_outputs 111 | logits = torch.cat((expert_logits, policy_logits)) 112 | targets = torch.cat((torch.ones_like(expert_logits), torch.zeros_like(policy_logits))) 113 | discriminator_loss = BCEWithLogitsLoss()(logits, targets) 114 | discriminator_entropy = torch.distributions.Bernoulli(logits=logits).entropy().mean() 115 | discriminator_loss -= self._hp.entropy_coefficient_gail * discriminator_entropy 116 | discriminator_accuracy = ((torch.sigmoid(logits) > 0.5).float() == targets).float().mean() 117 | if self._hp.grad_penalty_coefficient > 0: 118 | grad_penalty_loss = self._hp.grad_penalty_coefficient * self._compute_gradient_penalty(expert_batch, 119 | policy_batch) 120 | discriminator_loss += grad_penalty_loss 121 | discriminator_loss += self._regularization_losses(expert_disc_outputs, policy_disc_outputs) 122 | 123 | # update discriminator 124 | self._perform_update(discriminator_loss, self.discriminator_opt, self.discriminator) 125 | 126 | # log info 127 | info = AttrDict( 128 | discriminator_loss=discriminator_loss, 129 | discriminator_entropy=discriminator_entropy, 130 | discriminator_accuracy=discriminator_accuracy, 131 | discr_real_output=torch.sigmoid(expert_logits).mean(), 132 | discr_fake_output=torch.sigmoid(policy_logits).mean(), 133 | ) 134 | info.update(self._get_obs_norm_info()) 135 | if self._hp.grad_penalty_coefficient > 0: 136 | info.update(AttrDict(grad_penalty_loss=grad_penalty_loss)) 137 | self.gail_info = map_dict(ten2ar, info) 138 | self._discriminator_update_cycles += 1 139 | 140 | def _add_experience_discriminator_buffer(self, experience_batch): 141 | """Normalizes experience and adds to discriminator replay buffer.""" 142 | # fill policy trajectories in buffer 143 | if self._hp.reset_buffer: 144 | self.gail_trajectory_buffer.reset() 145 | self.gail_trajectory_buffer.append(map2np(experience_batch)) 146 | 147 | def _aux_updates(self): 148 | """Update discriminator before updating policy & critic.""" 149 | self._update_discriminator() 150 | 151 | def 
add_aux_experience(self, experience_batch): 152 | self._add_experience_discriminator_buffer(experience_batch) 153 | 154 | def _get_expert_batch(self): 155 | try: 156 | expert_batch = next(self._expert_data_iter) 157 | except StopIteration: 158 | self._expert_data_iter = iter(self._expert_data_loader) 159 | expert_batch = next(self._expert_data_iter) 160 | expert_batch = map2np(AttrDict(expert_batch)) 161 | expert_batch.states = self._obs_normalizer(expert_batch.states) 162 | expert_batch = map2torch(expert_batch, device=self.device) 163 | return expert_batch 164 | 165 | def _preprocess_experience(self, experience_batch, policy_outputs=None): 166 | """Trains discriminator and then uses it to relabel rewards.""" 167 | assert isinstance(experience_batch.reward[0], torch.Tensor) # expects tensors as input 168 | with torch.no_grad(): 169 | if 'orig_reward' not in experience_batch: 170 | experience_batch.orig_reward = copy.deepcopy(experience_batch.reward) 171 | experience_batch.discr_reward, experience_batch.p_demo = \ 172 | self._compute_discriminator_reward(experience_batch, policy_outputs) 173 | experience_batch.reward = [(1 - self._lambda_gail(self.schedule_steps)) 174 | * dr + self._lambda_gail(self.schedule_steps) * r \ 175 | for dr, r in zip(experience_batch.discr_reward, experience_batch.orig_reward)] 176 | if isinstance(experience_batch.orig_reward, torch.Tensor): 177 | # merge list into tensor in case input is also tensor not list (during RL update) 178 | experience_batch.reward = torch.tensor(experience_batch.reward, 179 | device=experience_batch.orig_reward.device) 180 | self.gail_info.update({'discriminator_reward': np.mean(map2np(experience_batch.discr_reward)), 181 | 'rl_training_reward': np.mean(map2np(experience_batch.reward)), 182 | 'lambda_gail': self._lambda_gail(self.schedule_steps), 183 | 'buffer_size': self.gail_trajectory_buffer.size,}) 184 | return experience_batch 185 | 186 | def _compute_discriminator_reward(self, experience_batch, unused_policy_outputs): 187 | """Uses discriminator to compute GAIL reward.""" 188 | logits = self._run_discriminator(experience_batch, unused_policy_outputs) 189 | D = torch.sigmoid(logits) 190 | discriminator_reward = (D + self.EPS).log() - (1 - D + self.EPS).log() 191 | return [r for r in discriminator_reward], D 192 | 193 | def _run_discriminator(self, experience_batch, unused_policy_outputs): 194 | """Runs discriminator on experience batch [obs, act], returns logits.""" 195 | input_states = torch.stack(experience_batch.observation) if isinstance(experience_batch.observation, list) \ 196 | else experience_batch.observation 197 | input_actions = torch.stack(experience_batch.action) if isinstance(experience_batch.action, list) \ 198 | else experience_batch.action 199 | discr_output = self.discriminator(self.discriminator._discriminator_input( 200 | AttrDict(states=input_states[:, None], actions=input_actions[:, None]))) 201 | return discr_output[:, 0] 202 | 203 | def _compute_gradient_penalty(self, expert_batch, policy_batch): 204 | """Computes mixup gradient penalty for discriminator.""" 205 | # create mixed policy + expert input 206 | alpha = torch.rand([policy_batch.observation.shape[0], 1], device=policy_batch.observation.device) 207 | mixup_state = alpha * policy_batch.observation + (1-alpha) * expert_batch.states[:, 0] 208 | mixup_action = alpha * policy_batch.action + (1-alpha) * expert_batch.actions[:, 0] 209 | mixup_state.requires_grad = True; mixup_action.requires_grad = True 210 | 211 | # compute discriminator gradients 212 | 
disc_output = self.discriminator(mixup_state, mixup_action).q[:, 0] 213 | grad = torch.cat(autograd.grad(outputs=disc_output, 214 | inputs=[mixup_state, mixup_action], 215 | grad_outputs=torch.ones_like(disc_output), 216 | create_graph=True, 217 | retain_graph=True, 218 | only_inputs=True), dim=-1) 219 | 220 | # compute gradient penalty 221 | grad_penalty = (grad.norm(2, dim=1) - 1).pow(2).mean() 222 | return grad_penalty 223 | 224 | def _regularization_losses(self, *unused_args, **unused_kwargs): 225 | """Optionally add more regularization losses to discriminator update.""" 226 | return 0. 227 | 228 | def _run_bc_init(self): 229 | """Performs BC-based policy initialization.""" 230 | self.to(self.device) 231 | policy_bc_opt = self._get_optimizer(self._hp.optimizer, self.policy, self._hp.policy_lr) 232 | for step in range(self._hp.bc_init_steps): 233 | data = self._get_expert_batch() 234 | policy_output = self.policy(data.states[:, 0]) 235 | loss = -1 * policy_output.dist.log_prob(data.actions[:, 0]).mean() 236 | self._perform_update(loss, policy_bc_opt, self.policy) 237 | if step % int(self._hp.bc_init_steps / 100) == 0: 238 | print("It {}: \tBC loss: {}, \tEntropy: {}" 239 | .format(step, loss, policy_output.dist.entropy().mean().data.cpu().numpy())) 240 | 241 | def _update_lr(self): 242 | super()._update_lr() 243 | if not isinstance(self._lr, ConstantSchedule): 244 | update_optimizer_lr(self.discriminator_opt, self._lr(self._env_steps)) 245 | 246 | 247 | class GAILSACAgent(SACAgent, GAILAgent): 248 | """GAIL agent that optimizes the discriminator reward using SAC.""" 249 | def __init__(self, *args, **kwargs): 250 | super().__init__(*args, **kwargs) 251 | self._init_gail() 252 | 253 | def _default_hparams(self): 254 | params = SACAgent._default_hparams(self) 255 | params.update(GAILAgent._default_hparams(self)) 256 | return params 257 | 258 | def update(self, experience_batch): 259 | if self._discriminator_update_cycles < self._hp.warmup_cycles: 260 | # only train discriminator during warmup, do not update policy 261 | self._add_experience_discriminator_buffer(experience_batch) 262 | self._update_discriminator() 263 | return self.gail_info 264 | else: 265 | # after warmup we first update discriminator, then policy (both handled by super().update()) 266 | info = SACAgent.update(self, experience_batch) 267 | info.update(self.gail_info) 268 | return info 269 | 270 | def add_experience(self, experience_batch): 271 | self._add_experience_discriminator_buffer(experience_batch) 272 | SACAgent.add_experience(self, experience_batch) 273 | 274 | def _preprocess_experience(self, experience_batch, policy_outputs=None): 275 | processed_experience = GAILAgent._preprocess_experience(self, experience_batch, policy_outputs) 276 | if hasattr(self, 'vis_replay_buffer'): 277 | self.vis_replay_buffer.append(map2np(processed_experience)) # for visualization 278 | return processed_experience 279 | 280 | 281 | class GAILActionPriorSACAgent(ActionPriorSACAgent, GAILAgent): 282 | """GAIL agent that optimizes the discriminator reward using SPiRL.""" 283 | def __init__(self, *args, **kwargs): 284 | super().__init__(*args, **kwargs) 285 | self._init_gail() 286 | 287 | def _default_hparams(self): 288 | params = ActionPriorSACAgent._default_hparams(self) 289 | params.update(GAILAgent._default_hparams(self)) 290 | return params 291 | 292 | def update(self, experience_batch): 293 | self.gail_info = {} 294 | if self._discriminator_update_cycles < self._hp.warmup_cycles: 295 | # only train discriminator during warmup, do 
not update policy 296 | self._add_experience_discriminator_buffer(experience_batch) 297 | self._update_discriminator() 298 | return self.gail_info 299 | else: 300 | # after warmup we first update discriminator, then policy (both handled by super().update()) 301 | info = ActionPriorSACAgent.update(self, experience_batch) 302 | info.update(self.gail_info) 303 | return info 304 | 305 | def _preprocess_experience(self, experience_batch, policy_outputs=None): 306 | processed_experience = GAILAgent._preprocess_experience(self, experience_batch, policy_outputs) 307 | if hasattr(self, 'vis_replay_buffer'): 308 | self.vis_replay_buffer.append(map2np(processed_experience)) # for visualization 309 | return processed_experience 310 | 311 | def add_experience(self, experience_batch): 312 | self._add_experience_discriminator_buffer(experience_batch) 313 | ActionPriorSACAgent.add_experience(self, experience_batch) 314 | -------------------------------------------------------------------------------- /skild/rl/agents/ppo_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from spirl.utils.general_utils import ParamDict, AttrDict, map_dict, ConstantSchedule 5 | from spirl.utils.pytorch_utils import map2torch, map2np, ten2ar, ar2ten, avg_grad_norm, update_optimizer_lr 6 | from spirl.rl.agents.ac_agent import ACAgent 7 | from spirl.rl.components.normalization import DummyNormalizer 8 | 9 | 10 | class PPOAgent(ACAgent): 11 | """Implements PPO algorithm.""" 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | # build old actor policy 15 | self.old_policy = self._hp.policy(self._hp.policy_params) 16 | 17 | # build critic and critic optimizer 18 | self.critic = self._hp.critic(self._hp.critic_params) 19 | self.critic_opt = self._get_optimizer(self._hp.optimizer, self.critic, self._hp.critic_lr) 20 | self._lr = self._hp.lr_schedule(self._hp.lr_schedule_params) 21 | 22 | # build trajectory buffer and reward normalizer 23 | self.trajectory_buffer = self._hp.buffer(self._hp.buffer_params) 24 | self._reward_normalizer = self._hp.reward_normalizer(self._hp.reward_normalizer_params) 25 | 26 | self._update_steps = 0 27 | self._env_steps = 0 28 | 29 | def _default_hparams(self): 30 | default_dict = ParamDict({ 31 | 'critic': None, # critic class 32 | 'critic_params': None, # parameters for the critic class 33 | 'critic_lr': 3e-4, # learning rate for critic update 34 | 'buffer': None, # trajectory buffer class 35 | 'buffer_params': None, # parameters for trajectory buffer 36 | 'clip_ratio': 0.2, # policy update clipping value 37 | 'entropy_coefficient': 0.0, # coefficient for weighting of entropy loss 38 | 'gae_lambda': 0.95, # GAE lambda coefficient 39 | 'target_network_update_factor': 1.0, # always overwrite old actor policy completely 40 | 'gradient_clip': 0.5, # overwrite default to cligrad norm at 0.5 41 | 'clip_value_loss': False, # if True, applies clipping to value loss 42 | 'reward_normalizer': DummyNormalizer, # normalizer for rewards 43 | 'reward_normalizer_params': {}, # optional parameters for reward normalizer 44 | 'lr_schedule': ConstantSchedule, # schedule for learning rate 45 | 'lr_schedule_params': AttrDict(p=3e-4), # parameters for learning rate schedule 46 | }) 47 | return super()._default_hparams().overwrite(default_dict) 48 | 49 | def update(self, experience_batch): 50 | """Updates actor and critic.""" 51 | # normalize experience batch 52 | experience_batch = 
self._normalize_batch(experience_batch) 53 | 54 | # perform any auxiliary updates 55 | self.add_aux_experience(experience_batch) 56 | self._aux_updates() 57 | 58 | # prepare experience batch for policy update 59 | self.add_experience(experience_batch) 60 | self._env_steps += self.trajectory_buffer.size 61 | self._update_lr() 62 | 63 | # copy actor weights 64 | self._soft_update_target_network(self.old_policy, self.policy) 65 | 66 | for _ in range(self._hp.update_iterations): 67 | # sample update sample 68 | experience_batch = self.trajectory_buffer.sample(n_samples=self._hp.batch_size) 69 | experience_batch = map2torch(experience_batch, device=self.device) 70 | 71 | # compute policy loss 72 | policy_loss, entropy, pi_ratio = self._compute_policy_loss(experience_batch) 73 | 74 | # compute critic loss 75 | critic_loss = self._compute_critic_loss(experience_batch) 76 | 77 | # update networks & learning rate 78 | self._update_steps += 1 79 | self._perform_update(policy_loss, self.policy_opt, self.policy) 80 | self._perform_update(critic_loss, self.critic_opt, self.critic) 81 | 82 | # log info 83 | info = AttrDict( 84 | policy_loss=policy_loss, 85 | critic_loss=critic_loss, 86 | entropy=entropy, 87 | pi_ratio=pi_ratio.mean(), 88 | lr=self._lr(self.schedule_steps), 89 | ) 90 | if self._update_steps % 100 == 0: 91 | info.update(AttrDict( # gradient norms 92 | policy_grad_norm=avg_grad_norm(self.policy), 93 | critic_grad_norm=avg_grad_norm(self.critic), 94 | )) 95 | info = map_dict(ten2ar, info) 96 | return info 97 | 98 | def _normalize_batch(self, experience_batch): 99 | self._obs_normalizer.update(experience_batch.observation) 100 | self._reward_normalizer.update(experience_batch.reward) 101 | experience_batch.observation = self._obs_normalizer(experience_batch.observation) 102 | experience_batch.observation_next = self._obs_normalizer(experience_batch.observation_next) 103 | experience_batch.reward = self._reward_normalizer(experience_batch.reward) 104 | return experience_batch 105 | 106 | def add_experience(self, experience_batch): 107 | experience_batch = self._preprocess_experience(map2torch(experience_batch, self.device)) 108 | experience_batch = self._compute_advantage(map2np(experience_batch)) 109 | self.trajectory_buffer.reset() 110 | self.trajectory_buffer.append(experience_batch) 111 | 112 | def _compute_advantage(self, experience_batch): 113 | """Computes advantage and return of input trajectories using critic.""" 114 | n_steps = len(experience_batch.observation) - 1 115 | 116 | # compute estimated value 117 | with torch.no_grad(): 118 | value = ten2ar(self.critic( 119 | ar2ten(np.array(experience_batch.observation, dtype=np.float32), device=self.device)).q).squeeze(-1) 120 | 121 | # recursively compute returns and advantage 122 | advantage = np.empty_like(value[:-1]) 123 | last_adv = 0 124 | for t in reversed(range(n_steps)): 125 | advantage[t] = experience_batch.reward[t] \ 126 | + (1 - experience_batch.done[t]) * self._hp.discount_factor * value[t+1] \ 127 | - value[t] \ 128 | + self._hp.discount_factor * self._hp.gae_lambda * (1 - experience_batch.done[t]) * last_adv 129 | last_adv = advantage[t] 130 | 131 | # compute returns and normalized advantage 132 | returns = advantage + value[:-1] 133 | norm_advantage = (advantage - advantage.mean()) / advantage.std() 134 | 135 | # remove final transitions for which we don't have advantages + add computed adv to experience batch 136 | for key in experience_batch: 137 | experience_batch[key] = experience_batch[key][:advantage.shape[0]] 138 
| experience_batch.returns = [r for r in returns] 139 | experience_batch.advantage = [a for a in norm_advantage] 140 | experience_batch.value_pred = [v for v in value[:-1]] 141 | 142 | return experience_batch 143 | 144 | def _compute_policy_loss(self, experience_batch): 145 | """Computes policy update loss.""" 146 | # run actors 147 | policy_output = self.policy(experience_batch.observation) 148 | old_policy_output = self.old_policy(experience_batch.observation) 149 | log_pi, old_log_pi = policy_output.dist.log_prob(experience_batch.action), \ 150 | old_policy_output.dist.log_prob(experience_batch.action) 151 | 152 | # compute actor loss 153 | ratio = torch.exp(log_pi - old_log_pi) 154 | surr1 = ratio * experience_batch.advantage 155 | surr2 = torch.clamp(ratio, 1.0 - self._hp.clip_ratio, 1.0 + self._hp.clip_ratio) * experience_batch.advantage 156 | actor_loss = -torch.min(surr1, surr2).mean() 157 | 158 | # compute entropy loss 159 | entropy_loss = -1 * policy_output.dist.entropy().mean() 160 | 161 | return actor_loss + self._hp.entropy_coefficient * entropy_loss, -1 * entropy_loss, ratio 162 | 163 | def _compute_critic_loss(self, experience_batch): 164 | value = self.critic(experience_batch.observation).q.squeeze(-1) 165 | if not self._hp.clip_value_loss: 166 | return 0.5 * (experience_batch.returns - value).pow(2).mean() 167 | else: 168 | value_clipped = experience_batch.value_pred + \ 169 | (value - experience_batch.value_pred).clamp(-self._hp.clip_ratio, self._hp.clip_ratio) 170 | value_losses = (experience_batch.returns - value).pow(2) 171 | value_losses_clipped = (value_clipped - experience_batch.returns).pow(2) 172 | return 0.5 * torch.max(value_losses, value_losses_clipped).mean() 173 | 174 | def _update_lr(self): 175 | """Updates learning rates with schedule.""" 176 | if not isinstance(self._lr, ConstantSchedule): 177 | update_optimizer_lr(self.policy_opt, self._lr(self.schedule_steps)) 178 | update_optimizer_lr(self.critic_opt, self._lr(self.schedule_steps)) 179 | 180 | @property 181 | def schedule_steps(self): 182 | return self._env_steps 183 | -------------------------------------------------------------------------------- /skild/rl/agents/skild_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from spirl.utils.general_utils import ParamDict, ConstantSchedule, AttrDict 5 | from spirl.utils.pytorch_utils import TensorModule, check_shape, ar2ten 6 | from skild.rl.agents.gail_agent import GAILActionPriorSACAgent 7 | 8 | 9 | class SkiLDAgent(GAILActionPriorSACAgent): 10 | """Implements the SkiLD algorithm.""" 11 | def __init__(self, *args, **kwargs): 12 | GAILActionPriorSACAgent.__init__(self, *args, **kwargs) 13 | self._posterior_target_divergence = self._hp.tdq_schedule(self._hp.tdq_schedule_params) 14 | 15 | # define posterior divergence multiplier alpha_q 16 | if self._hp.fixed_alpha_q is not None: 17 | self._log_alpha_q = TensorModule(np.log(self._hp.fixed_alpha_q) 18 | * torch.ones(1, requires_grad=False, device=self._hp.device)) 19 | else: 20 | self._log_alpha_q = TensorModule(torch.zeros(1, requires_grad=True, device=self._hp.device)) 21 | self.alpha_q_opt = self._get_optimizer(self._hp.optimizer, self._log_alpha_q, self._hp.alpha_lr) 22 | 23 | def _default_hparams(self): 24 | return GAILActionPriorSACAgent._default_hparams(self).overwrite(ParamDict({ 25 | 'tdq_schedule': ConstantSchedule, # schedule used for posterior target divergence param 26 | 'tdq_schedule_params': AttrDict( # 
parameters for posterior target divergence schedule 27 | p = 1., 28 | ), 29 | 'action_cond_discriminator': False, # if True, conditions discriminator on actions 30 | 'fixed_alpha_q': None, 31 | })) 32 | 33 | def update(self, experience_batch): 34 | info = GAILActionPriorSACAgent.update(self, experience_batch) 35 | info.posterior_target_divergence = self._posterior_target_divergence(self.schedule_steps) 36 | return info 37 | 38 | def _update_alpha(self, experience_batch, policy_output): 39 | # update alpha_q 40 | if self._hp.fixed_alpha_q is None: 41 | self.alpha_q_loss = (self._compute_alpha_q_loss(policy_output) * experience_batch.p_demo.detach()).mean() 42 | self._perform_update(self.alpha_q_loss, self.alpha_q_opt, self._log_alpha_q) 43 | else: 44 | self.alpha_q_loss = 0. 45 | 46 | # update alpha 47 | if self._hp.fixed_alpha is None: 48 | alpha_loss = (self._compute_alpha_loss(policy_output) * (1-experience_batch.p_demo).detach()).mean() 49 | self._perform_update(alpha_loss, self.alpha_opt, self._log_alpha) 50 | else: 51 | alpha_loss = 0. 52 | return alpha_loss 53 | 54 | def _compute_alpha_q_loss(self, policy_output): 55 | return self.alpha_q * (self._posterior_target_divergence(self.schedule_steps) 56 | - policy_output.posterior_divergence).detach() 57 | 58 | def _compute_alpha_loss(self, policy_output): 59 | return self.alpha * (self._target_divergence(self.schedule_steps) - policy_output.prior_divergence).detach() 60 | 61 | def _compute_policy_loss(self, experience_batch, policy_output): 62 | q_est = torch.min(*[critic(experience_batch.observation, self._prep_action(policy_output.action)).q 63 | for critic in self.critics]) 64 | weighted_divergence = self.alpha * policy_output.prior_divergence[:, None] \ 65 | * (1 - experience_batch.p_demo[:, None]) \ 66 | + self.alpha_q * policy_output.posterior_divergence[:, None] \ 67 | * experience_batch.p_demo[:, None] 68 | policy_loss = -1 * q_est + weighted_divergence 69 | check_shape(policy_loss, [self._hp.batch_size, 1]) 70 | return policy_loss.mean() 71 | 72 | def _compute_next_value(self, experience_batch, policy_output): 73 | q_next = torch.min(*[critic_target(experience_batch.observation_next, self._prep_action(policy_output.action)).q 74 | for critic_target in self.critic_targets]) 75 | weighted_divergence = self.alpha * policy_output.prior_divergence[:, None] \ 76 | * (1 - experience_batch.p_demo[:, None]) \ 77 | + self.alpha_q * policy_output.posterior_divergence[:, None] \ 78 | * experience_batch.p_demo[:, None] 79 | next_val = (q_next - weighted_divergence) 80 | check_shape(next_val, [self._hp.batch_size, 1]) 81 | return next_val.squeeze(-1) 82 | 83 | def _aux_info(self, experience_batch, policy_output): 84 | aux_info = GAILActionPriorSACAgent._aux_info(self, experience_batch, policy_output) 85 | aux_info.update(AttrDict( 86 | prior_divergence=(policy_output.prior_divergence[experience_batch.p_demo < 0.5]).mean(), 87 | posterior_divergence=(policy_output.posterior_divergence[experience_batch.p_demo > 0.5]).mean(), 88 | alpha_q_loss=self.alpha_q_loss, 89 | alpha_q=self.alpha_q, 90 | p_demo=experience_batch.p_demo.mean(), 91 | )) 92 | aux_info.update(AttrDict( # log all reward components 93 | env_reward=self._hp.reward_scale * self._lambda_gail(self.schedule_steps) 94 | * experience_batch.orig_reward.mean(), 95 | gail_reward=self._hp.reward_scale * (1-self._lambda_gail(self.schedule_steps)) 96 | * torch.stack(experience_batch.discr_reward).mean(), 97 | prior_reward=self.alpha * (policy_output.prior_divergence * (1 - 
experience_batch.p_demo)).mean(), 98 | posterior_reward=self.alpha_q * (policy_output.posterior_divergence * experience_batch.p_demo).mean(), 99 | )) 100 | return aux_info 101 | 102 | def _run_discriminator(self, experience_batch, policy_output=None): 103 | # optionally unflatten observation (in case we have image inputs) 104 | if self._hp.action_cond_discriminator and policy_output is None: 105 | # first call -- before policy was called 106 | return torch.zeros_like(experience_batch.observation[:, 0]) 107 | if hasattr(self.policy.net, "unflatten_obs"): 108 | discriminator_input = self.discriminator.filter_input(self.policy.net.unflatten_obs( 109 | ar2ten(experience_batch.observation, device=self.device)).prior_obs) 110 | else: 111 | discriminator_input = ar2ten(experience_batch.observation, device=self.device) 112 | if self._hp.action_cond_discriminator: 113 | discriminator_input = torch.cat((discriminator_input, policy_output.action), dim=-1) 114 | return self.discriminator(discriminator_input)[..., 0] 115 | 116 | def _update_experience(self, experience_batch, policy_outputs): 117 | """Run discriminator with action input.""" 118 | if not self._hp.action_cond_discriminator: 119 | return super()._update_experience(experience_batch, policy_outputs) 120 | return self._preprocess_experience(experience_batch, policy_outputs) 121 | 122 | def state_dict(self, *args, **kwargs): 123 | d = GAILActionPriorSACAgent.state_dict(self) 124 | if hasattr(self, 'alpha_q_opt'): 125 | d['alpha_q_opt'] = self.alpha_q_opt.state_dict() 126 | return d 127 | 128 | def load_state_dict(self, state_dict, *args, **kwargs): 129 | if 'alpha_q_opt' in state_dict: 130 | self.alpha_q_opt.load_state_dict(state_dict.pop('alpha_q_opt')) 131 | GAILActionPriorSACAgent.load_state_dict(self, state_dict, *args, **kwargs) 132 | 133 | @property 134 | def alpha_q(self): 135 | if self._hp.alpha_min is not None: 136 | return torch.clamp(self._log_alpha_q().exp(), min=self._hp.alpha_min) 137 | return self._log_alpha_q().exp() 138 | -------------------------------------------------------------------------------- /skild/rl/envs/maze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import d4rl 3 | 4 | from spirl.rl.components.environment import GymEnv 5 | from spirl.utils.general_utils import ParamDict, AttrDict 6 | 7 | 8 | class MazeEnv(GymEnv): 9 | """Extends SPiRL maze env with randomized init position and episode termination upon goal reaching.""" 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | 13 | def _default_hparams(self): 14 | default_dict = ParamDict({ 15 | 'start_rand_range': 2., # range of start position randomization, fixed pos if 0. 16 | }) 17 | return super()._default_hparams().overwrite(default_dict) 18 | 19 | def reset(self): 20 | super().reset() 21 | if self.TARGET_POS is not None and self.START_POS is not None: 22 | start_pos = self.START_POS + self._hp.start_rand_range * (np.random.rand(2) * 2 - 1) 23 | self._env.set_target(self.TARGET_POS) 24 | self._env.reset_to_location(start_pos) 25 | self._env.render(mode='rgb_array') # these are necessary to make sure new state is rendered on first frame 26 | obs, _, _, _ = self._env.step(np.zeros_like(self._env.action_space.sample())) 27 | return self._wrap_observation(obs) 28 | 29 | def step(self, *args, **kwargs): 30 | obs, rew, done, info = super().step(*args, **kwargs) 31 | if rew > 0: 32 | rew *= 100. 
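# a positive reward indicates the goal was reached: scale it up and terminate the episode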
33 | done = True 34 | return obs, np.float64(rew), done, info # casting reward to float64 is important for getting shape later 35 | 36 | 37 | class ACRandMaze0S40Env(MazeEnv): 38 | START_POS = np.array([10., 24.]) 39 | TARGET_POS = np.array([18., 8.]) 40 | 41 | def _default_hparams(self): 42 | default_dict = ParamDict({ 43 | 'name': "maze2d-randMaze0S40-ac-v0", 44 | }) 45 | return super()._default_hparams().overwrite(default_dict) 46 | -------------------------------------------------------------------------------- /skild/rl/policies/posterior_policies.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from spirl.utils.pytorch_utils import no_batchnorm_update 4 | from spirl.utils.general_utils import ParamDict, AttrDict 5 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy 6 | from spirl.rl.components.agent import BaseAgent 7 | 8 | 9 | class LearnedPPPolicy(LearnedPriorAugmentedPIPolicy): 10 | """Computes both learned prior and posterior distribution.""" 11 | def __init__(self, *args, **kwargs): 12 | LearnedPriorAugmentedPIPolicy.__init__(self, *args, **kwargs) 13 | self.posterior_net = self._hp.posterior_model(self._hp.posterior_model_params, None) 14 | BaseAgent.load_model_weights(self.posterior_net, 15 | self._hp.posterior_model_checkpoint, 16 | self._hp.posterior_model_epoch) 17 | 18 | def _default_hparams(self): 19 | return LearnedPriorAugmentedPIPolicy._default_hparams(self).overwrite(ParamDict({ 20 | 'posterior_model': None, # posterior model class 21 | 'posterior_model_params': None, # parameters for the posterior model 22 | 'posterior_model_checkpoint': None, # checkpoint path of the posterior model 23 | 'posterior_model_epoch': 'latest', # epoch that checkpoint should be loaded for (defaults to latest) 24 | })) 25 | 26 | def forward(self, obs): 27 | policy_output = LearnedPriorAugmentedPIPolicy.forward(self, obs) 28 | if not self._rollout_mode: 29 | raw_posterior_divergence, policy_output.posterior_dist = \ 30 | self._compute_posterior_divergence(policy_output, obs) 31 | policy_output.posterior_divergence = self.clamp_divergence(raw_posterior_divergence) 32 | return policy_output 33 | 34 | def _compute_posterior_divergence(self, policy_output, obs): 35 | with no_batchnorm_update(self.posterior_net): 36 | posterior_dist = self.posterior_net.compute_learned_prior(obs, first_only=True).detach() 37 | if self._hp.analytic_KL: 38 | return self._analytic_divergence(policy_output, posterior_dist), posterior_dist 39 | return self._mc_divergence(policy_output, posterior_dist), posterior_dist 40 | 41 | 42 | class ACLearnedPPPolicy(LearnedPPPolicy): 43 | """LearnedPPPolicy for case with separate prior obs --> uses prior observation as input only.""" 44 | def forward(self, obs): 45 | if obs.shape[0] == 1: 46 | return super().forward(self.net.unflatten_obs(obs).prior_obs) # use policy_net or batch_size 1 inputs 47 | return super().forward(self.prior_net.unflatten_obs(obs).prior_obs) 48 | --------------------------------------------------------------------------------