├── .gitignore ├── .gitmodules ├── LICENSE ├── Pipfile ├── README.md ├── conda ├── environment_cpu.yaml ├── environment_gpu.yaml └── requirements.txt ├── configs └── pybullet │ ├── dynamics │ └── table_env.yaml │ └── envs │ ├── assets │ ├── franka_panda │ │ ├── collision │ │ │ ├── finger.stl │ │ │ ├── hand.stl │ │ │ ├── link0.stl │ │ │ ├── link1.stl │ │ │ ├── link2.stl │ │ │ ├── link3.stl │ │ │ ├── link4.stl │ │ │ ├── link5.stl │ │ │ ├── link6.stl │ │ │ ├── link7.stl │ │ │ ├── robotiq_arg2f_85_base_link.stl │ │ │ ├── robotiq_arg2f_85_inner_finger.dae │ │ │ ├── robotiq_arg2f_85_inner_knuckle.dae │ │ │ ├── robotiq_arg2f_85_outer_finger.dae │ │ │ ├── robotiq_arg2f_85_outer_knuckle.dae │ │ │ └── robotiq_arg2f_base_link.stl │ │ ├── franka_panda.urdf │ │ ├── franka_panda_robotiq.urdf │ │ └── visual │ │ │ ├── finger.dae │ │ │ ├── hand.dae │ │ │ ├── link0.dae │ │ │ ├── link1.dae │ │ │ ├── link2.dae │ │ │ ├── link3.dae │ │ │ ├── link4.dae │ │ │ ├── link5.dae │ │ │ ├── link6.dae │ │ │ ├── link7.dae │ │ │ ├── robotiq_arg2f_85_base_link.dae │ │ │ ├── robotiq_arg2f_85_inner_finger.dae │ │ │ ├── robotiq_arg2f_85_inner_knuckle.dae │ │ │ ├── robotiq_arg2f_85_outer_finger.dae │ │ │ ├── robotiq_arg2f_85_outer_knuckle.dae │ │ │ └── robotiq_arg2f_85_pad.dae │ ├── icecream.yaml │ ├── iprl_table.urdf │ ├── milk.yaml │ ├── salt.yaml │ └── yogurt.yaml │ ├── official │ ├── domains │ │ ├── constrained_packing │ │ │ ├── task0.yaml │ │ │ ├── task1.yaml │ │ │ └── task2.yaml │ │ ├── hook_reach │ │ │ ├── task0.yaml │ │ │ ├── task1.yaml │ │ │ └── task2.yaml │ │ └── rearrangement_push │ │ │ ├── task0.yaml │ │ │ ├── task1.yaml │ │ │ ├── task2.yaml │ │ │ └── task3.yaml │ └── primitives │ │ ├── pick.yaml │ │ ├── pick_eval.yaml │ │ ├── place.yaml │ │ ├── place_eval.yaml │ │ ├── pull.yaml │ │ ├── pull_eval.yaml │ │ ├── push.yaml │ │ └── push_eval.yaml │ └── robots │ └── franka_panda_sim.yaml ├── generative_skill_chaining ├── __init__.py ├── agents │ ├── __init__.py │ ├── base.py │ ├── rl.py │ ├── sac.py │ ├── utils.py │ └── wrapper.py ├── diff_models │ ├── __init__.py │ ├── classifier_transformer.py │ ├── unet_transformer.py │ └── utils │ │ ├── __init__.py │ │ └── helpers.py ├── encoders │ ├── __init__.py │ ├── base.py │ ├── state.py │ └── utils.py ├── envs │ ├── __init__.py │ ├── base.py │ ├── empty.py │ ├── pybullet │ │ ├── __init__.py │ │ ├── base.py │ │ ├── real │ │ │ ├── __init__.py │ │ │ ├── arm.py │ │ │ ├── gripper.py │ │ │ ├── object_tracker.py │ │ │ └── redisgl.py │ │ ├── sim │ │ │ ├── __init__.py │ │ │ ├── arm.py │ │ │ ├── articulated_body.py │ │ │ ├── body.py │ │ │ ├── gripper.py │ │ │ ├── math.py │ │ │ ├── redisgl.py │ │ │ ├── robot.py │ │ │ └── shapes.py │ │ ├── table │ │ │ ├── __init__.py │ │ │ ├── object_state.py │ │ │ ├── objects.py │ │ │ ├── predicates.py │ │ │ ├── primitive_actions.py │ │ │ ├── primitives.py │ │ │ └── utils.py │ │ ├── table_env.py │ │ └── utils.py │ ├── utils.py │ └── variant.py ├── mixed_diffusion │ ├── __init__.py │ ├── cond_diffusion1D.py │ ├── datasets_transformer.py │ ├── datasets_transformer_class.py │ ├── grad_discriminator.py │ ├── sde_cont.py │ └── utils │ │ ├── __init__.py │ │ ├── cont_utils.py │ │ └── diff_utils.py ├── networks │ ├── __init__.py │ ├── actors │ │ ├── __init__.py │ │ ├── base.py │ │ └── mlp.py │ ├── critics │ │ ├── __init__.py │ │ ├── base.py │ │ └── mlp.py │ ├── encoders │ │ ├── __init__.py │ │ ├── base.py │ │ ├── normalize.py │ │ └── table_env.py │ ├── mlp.py │ └── utils.py └── utils │ ├── __init__.py │ ├── configs.py │ ├── logging.py │ ├── metrics.py │ ├── nest.py │ ├── random.py │ ├── recording.py │ ├── spaces.py │ ├── tensors.py │ ├── timing.py │ └── typing.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── eval │ ├── eval_constrained_packing.py │ ├── eval_constrained_packing.sh │ ├── eval_diffusion_transformer.py │ ├── eval_diffusion_transformer.sh │ ├── eval_hook_reach.py │ └── eval_hook_reach.sh └── train │ ├── train_diffusion_transformer.py │ ├── train_diffusion_transformer.sh │ ├── train_diffusion_transformer_w_classifier.py │ └── train_diffusion_transformer_w_classifier.sh ├── setup.cfg └── setup_shell.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # generative-skill-chaining outputs 2 | archive/ 3 | final_results/ 4 | models/ 5 | plots/ 6 | plots_debug/ 7 | logs/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | .mypy.ini 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # MacOS 141 | .DS_Store 142 | 143 | *.png 144 | **/**/*.png 145 | *.jpg 146 | **/**/*.jpg 147 | *.gif 148 | **/**/*.gif 149 | *.pkl 150 | **/**/*.pkl 151 | *.pt 152 | **/**/*.pt 153 | *.pth 154 | **/**/*.pth 155 | 156 | diffusion_models/ 157 | 158 | third_party/ 159 | *.pkl 160 | datasets/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/scod-regression"] 2 | path = third_party/scod-regression 3 | url = git@github.com:agiachris/scod-regression.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Utkarsh Aashu Mishra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | generative-skill-chaining = { editable = true, path = "." } 8 | matplotlib = "*" 9 | pandas = "*" 10 | seaborn = "*" 11 | shapely = "*" 12 | 13 | [dev-packages] 14 | black = "*" 15 | mypy = "*" 16 | flake8 = "*" 17 | 18 | [requires] 19 | python_version = "3.8" 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Skill Chaining (GSC): Long-Horizon Skill Planning with Diffusion Models 2 | 3 | GSC Preview 4 | 5 | ## Citation 6 | If you find this package helpful, please consider citing our work: 7 | ``` 8 | @inproceedings{ 9 | mishra2023generative, 10 | title={Generative Skill Chaining: Long-Horizon Skill Planning with Diffusion Models}, 11 | author={Utkarsh Aashu Mishra and Shangjie Xue and Yongxin Chen and Danfei Xu}, 12 | booktitle={7th Annual Conference on Robot Learning}, 13 | year={2023}, 14 | url={https://openreview.net/forum?id=HtJE9ly5dT} 15 | } 16 | ``` 17 | -------------------------------------------------------------------------------- /conda/environment_cpu.yaml: -------------------------------------------------------------------------------- 1 | name: generative_skill_chaining 2 | 3 | channels: 4 | - pytorch 5 | - conda-forge 6 | - default 7 | 8 | dependencies: 9 | - python=3.7 10 | 11 | # PyTorch 12 | - pytorch 13 | - cpuonly 14 | - torchvision 15 | - torchtext 16 | 17 | # NumPy Family 18 | - numpy 19 | - scipy 20 | - networkx 21 | - scikit-image 22 | 23 | # Env 24 | - pybox2d 25 | 26 | # IO 27 | - imageio 28 | - pillow 29 | - pyyaml 30 | - cloudpickle 31 | - h5py 32 | - absl-py 33 | - pyparsing 34 | 35 | # Plotting 36 | - tensorboard 37 | - pandas 38 | - matplotlib 39 | - seaborn 40 | 41 | # Other 42 | - pytest 43 | - tqdm 44 | - future 45 | 46 | - pip 47 | - pip: 48 | - -r requirements.txt -------------------------------------------------------------------------------- /conda/environment_gpu.yaml: -------------------------------------------------------------------------------- 1 | name: dplan 2 | 3 | channels: 4 | - pytorch 5 | - conda-forge 6 | - nvidia 7 | - default 8 | 9 | dependencies: 10 | - python=3.7 11 | 12 | # PyTorch 13 | - pytorch 14 | - cudatoolkit=11.1 15 | - torchvision 16 | - torchtext 17 | 18 | # NumPy Family 19 | - numpy 20 | - scipy 21 | - networkx 22 | - scikit-image 23 | 24 | # Env 25 | - pybox2d 26 | 27 | # IO 28 | - imageio 29 | - pillow 30 | - pyyaml 31 | - cloudpickle 32 | - h5py 33 | - absl-py 34 | - pyparsing 35 | 36 | # Plotting 37 | - tensorboard 38 | - pandas 39 | - matplotlib 40 | - seaborn 41 | 42 | # Other 43 | - pytest 44 | - tqdm 45 | - future 46 | 47 | - pip 48 | - pip: 49 | - -r requirements.txt -------------------------------------------------------------------------------- /conda/requirements.txt: -------------------------------------------------------------------------------- 1 | gym>=0.12 2 | pygame 3 | functorch 4 | -e ../third_party/scod-regression 5 | black 6 | mypy 7 | flake8 8 | -e ../. -------------------------------------------------------------------------------- /configs/pybullet/dynamics/table_env.yaml: -------------------------------------------------------------------------------- 1 | dynamics: TableEnvDynamics 2 | dynamics_kwargs: 3 | network_class: dynamics.MLPDynamics 4 | network_kwargs: 5 | hidden_layers: [256, 256] 6 | ortho_init: true 7 | -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/finger.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/hand.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link0.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link1.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link2.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link3.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link4.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link5.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link6.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/link7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/link7.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/robotiq_arg2f_85_base_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/robotiq_arg2f_85_base_link.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/collision/robotiq_arg2f_base_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/configs/pybullet/envs/assets/franka_panda/collision/robotiq_arg2f_base_link.stl -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/franka_panda/visual/robotiq_arg2f_85_pad.dae: -------------------------------------------------------------------------------- 1 | 2 | 3 | 2016-07-17T22:25:43.361178 4 | 2016-07-17T22:25:43.361188 5 | Z_UP 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 0.0 0.0 0.0 1.0 14 | 15 | 16 | 0.0 0.0 0.0 1.0 17 | 18 | 19 | 0.7 0.7 0.7 1.0 20 | 21 | 22 | 1 1 1 1.0 23 | 24 | 25 | 0.0 26 | 27 | 28 | 0.0 0.0 0.0 1.0 29 | 30 | 31 | 0.0 32 | 33 | 34 | 0.0 0.0 0.0 1.0 35 | 36 | 37 | 1.0 38 | 39 | 40 | 41 | 42 | 43 | 0 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 4.38531e-14 -1 -4.451336e-05 4.38531e-14 -1 -4.451336e-05 -4.38531e-14 1 4.451336e-05 -4.38531e-14 1 4.451336e-05 -1 -4.385301e-14 -2.011189e-15 -1 -4.385301e-14 -2.011189e-15 -2.009237e-15 -4.451336e-05 1 -2.009237e-15 -4.451336e-05 1 1 4.385301e-14 2.011189e-15 1 4.385301e-14 2.011189e-15 2.009237e-15 4.451336e-05 -1 2.009237e-15 4.451336e-05 -1 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -10 -23.90175 13.51442 10 -23.9033 48.51442 -10 -23.9033 48.51442 10 -23.90175 13.51442 -10 -18.90175 13.51464 -10 -18.9033 48.51464 10 -18.90175 13.51464 10 -18.9033 48.51464 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |

0 0 1 0 2 0 3 1 1 1 0 1 4 2 5 2 6 2 5 3 7 3 6 3 2 4 5 4 4 4 2 5 4 5 0 5 5 6 2 6 1 6 5 7 1 7 7 7 7 8 1 8 6 8 1 9 3 9 6 9 0 10 4 10 3 10 4 11 6 11 3 11

79 |
80 |
81 |
82 |
83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 |
105 | -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/icecream.yaml: -------------------------------------------------------------------------------- 1 | object_type: Box 2 | object_kwargs: 3 | name: cyan_box 4 | size: [0.06, 0.05, 0.07] 5 | # size: [0.06, 0.03, 0.06] 6 | color: [0.0, 1.0, 1.0, 1.0] 7 | -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/iprl_table.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/milk.yaml: -------------------------------------------------------------------------------- 1 | object_type: Box 2 | object_kwargs: 3 | name: red_box 4 | size: [0.08, 0.055, 0.08] 5 | # size: [0.05, 0.05, 0.07] 6 | color: [1.0, 0.0, 0.0, 1.0] 7 | -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/salt.yaml: -------------------------------------------------------------------------------- 1 | object_type: Box 2 | object_kwargs: 3 | name: blue_box 4 | size: [0.05, 0.05, 0.1] 5 | # size: [0.08, 0.05, 0.07] 6 | color: [0.0, 0.0, 1.0, 1.0] 7 | -------------------------------------------------------------------------------- /configs/pybullet/envs/assets/yogurt.yaml: -------------------------------------------------------------------------------- 1 | object_type: Box 2 | object_kwargs: 3 | name: yellow_box 4 | size: [0.07, 0.05, 0.065] 5 | # size: [0.07, 0.06, 0.1] 6 | color: [1.0, 1.0, 0.0, 1.0] 7 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/constrained_packing/task0.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: constrained_packing_0 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(yellow_box, table) 15 | - place(yellow_box, rack) 16 | - pick(red_box, table) 17 | - place(red_box, rack) 18 | - pick(cyan_box, table) 19 | - place(cyan_box, rack) 20 | initial_state: 21 | - free(yellow_box) 22 | - free(red_box) 23 | - free(cyan_box) 24 | - free(blue_box) 25 | - aligned(rack) 26 | - poslimit(rack) 27 | - inworkspace(rack) 28 | - inworkspace(yellow_box) 29 | - inworkspace(red_box) 30 | - inworkspace(cyan_box) 31 | - on(rack, table) 32 | - on(cyan_box, table) 33 | - on(yellow_box, table) 34 | - on(red_box, table) 35 | - on(blue_box, rack) 36 | 37 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 38 | 39 | objects: 40 | - object_type: Urdf 41 | object_kwargs: 42 | name: table 43 | path: configs/pybullet/envs/assets/iprl_table.urdf 44 | is_static: true 45 | - object_type: Rack 46 | object_kwargs: 47 | name: rack 48 | size: [0.22, 0.32, 0.16] 49 | color: [0.4, 0.2, 0.0, 1.0] 50 | - configs/pybullet/envs/assets/yogurt.yaml 51 | - configs/pybullet/envs/assets/milk.yaml 52 | - configs/pybullet/envs/assets/icecream.yaml 53 | - configs/pybullet/envs/assets/salt.yaml 54 | # - object_type: Box 55 | # object_kwargs: 56 | # name: yellow_box 57 | # size: [0.07, 0.06, 0.1] 58 | # color: [1.0, 1.0, 0.0, 1.0] 59 | # - object_type: Box 60 | # object_kwargs: 61 | # name: red_box 62 | # size: [0.05, 0.05, 0.07] 63 | # color: [1.0, 0.0, 0.0, 1.0] 64 | # - object_type: Box 65 | # object_kwargs: 66 | # name: cyan_box 67 | # size: [0.06, 0.03, 0.06] 68 | # color: [0.0, 1.0, 1.0, 1.0] 69 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/constrained_packing/task1.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: constrained_packing_1 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(yellow_box, table) 15 | - place(yellow_box, rack) 16 | - pick(cyan_box, table) 17 | - place(cyan_box, rack) 18 | - pick(blue_box, table) 19 | - place(blue_box, rack) 20 | initial_state: 21 | - free(yellow_box) 22 | - free(cyan_box) 23 | - free(blue_box) 24 | - aligned(rack) 25 | - poslimit(rack) 26 | - inworkspace(rack) 27 | - inworkspace(yellow_box) 28 | - inworkspace(cyan_box) 29 | - inworkspace(blue_box) 30 | - on(rack, table) 31 | - on(yellow_box, table) 32 | - on(cyan_box, table) 33 | - on(blue_box, table) 34 | 35 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 36 | 37 | objects: 38 | - object_type: Urdf 39 | object_kwargs: 40 | name: table 41 | path: configs/pybullet/envs/assets/iprl_table.urdf 42 | is_static: true 43 | - object_type: Rack 44 | object_kwargs: 45 | name: rack 46 | size: [0.22, 0.32, 0.16] 47 | color: [0.4, 0.2, 0.0, 1.0] 48 | - configs/pybullet/envs/assets/salt.yaml 49 | - configs/pybullet/envs/assets/yogurt.yaml 50 | - configs/pybullet/envs/assets/icecream.yaml 51 | # - object_type: Box 52 | # object_kwargs: 53 | # name: blue_box 54 | # size: [0.08, 0.05, 0.07] 55 | # color: [0.0, 0.0, 1.0, 1.0] 56 | # - object_type: Box 57 | # object_kwargs: 58 | # name: yellow_box 59 | # size: [0.07, 0.06, 0.1] 60 | # color: [1.0, 1.0, 0.0, 1.0] 61 | # - object_type: Box 62 | # object_kwargs: 63 | # name: cyan_box 64 | # size: [0.06, 0.03, 0.06] 65 | # color: [0.0, 1.0, 1.0, 1.0] 66 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/constrained_packing/task2.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: constrained_packing_2 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(red_box, table) 15 | - place(red_box, rack) 16 | - pick(yellow_box, table) 17 | - place(yellow_box, rack) 18 | - pick(cyan_box, table) 19 | - place(cyan_box, rack) 20 | - pick(blue_box, table) 21 | - place(blue_box, rack) 22 | initial_state: 23 | - free(red_box) 24 | - free(yellow_box) 25 | - free(cyan_box) 26 | - free(blue_box) 27 | - aligned(rack) 28 | - poslimit(rack) 29 | - inworkspace(rack) 30 | - inworkspace(red_box) 31 | - inworkspace(yellow_box) 32 | - inworkspace(cyan_box) 33 | - inworkspace(blue_box) 34 | - on(rack, table) 35 | - on(red_box, table) 36 | - on(yellow_box, table) 37 | - on(cyan_box, table) 38 | - on(blue_box, table) 39 | 40 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 41 | 42 | objects: 43 | - object_type: Urdf 44 | object_kwargs: 45 | name: table 46 | path: configs/pybullet/envs/assets/iprl_table.urdf 47 | is_static: true 48 | - object_type: Rack 49 | object_kwargs: 50 | name: rack 51 | size: [0.22, 0.32, 0.16] 52 | color: [0.4, 0.2, 0.0, 1.0] 53 | - configs/pybullet/envs/assets/salt.yaml 54 | - configs/pybullet/envs/assets/yogurt.yaml 55 | - configs/pybullet/envs/assets/milk.yaml 56 | - configs/pybullet/envs/assets/icecream.yaml 57 | # - object_type: Box 58 | # object_kwargs: 59 | # name: blue_box 60 | # size: [0.08, 0.05, 0.07] 61 | # color: [0.0, 0.0, 1.0, 1.0] 62 | # - object_type: Box 63 | # object_kwargs: 64 | # name: yellow_box 65 | # size: [0.07, 0.06, 0.1] 66 | # color: [1.0, 1.0, 0.0, 1.0] 67 | # - object_type: Box 68 | # object_kwargs: 69 | # name: red_box 70 | # size: [0.05, 0.05, 0.07] 71 | # color: [1.0, 0.0, 0.0, 1.0] 72 | # - object_type: Box 73 | # object_kwargs: 74 | # name: cyan_box 75 | # size: [0.06, 0.03, 0.06] 76 | # color: [0.0, 1.0, 1.0, 1.0] 77 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/hook_reach/task0.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: hook_reach_0 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(hook, table) 15 | - pull(blue_box, hook) 16 | - place(hook, table) 17 | - pick(blue_box, table) 18 | initial_state: 19 | - free(hook) 20 | - free(blue_box) 21 | - inworkspace(hook) 22 | - beyondworkspace(blue_box) 23 | - on(hook, table) 24 | - on(blue_box, table) 25 | 26 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 27 | 28 | objects: 29 | - object_type: Urdf 30 | object_kwargs: 31 | name: table 32 | path: configs/pybullet/envs/assets/iprl_table.urdf 33 | is_static: true 34 | - object_type: Hook 35 | object_kwargs: 36 | name: hook 37 | head_length: 0.2 38 | handle_length: 0.38 39 | handle_y: -1.0 40 | color: [0.6, 0.6, 0.6, 1.0] 41 | - configs/pybullet/envs/assets/salt.yaml 42 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/hook_reach/task1.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: hook_reach_1 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(hook, table) 15 | - pull(yellow_box, hook) 16 | - place(hook, table) 17 | - pick(yellow_box, table) 18 | - place(yellow_box, rack) 19 | initial_state: 20 | - free(hook) 21 | - free(yellow_box) 22 | - free(blue_box) 23 | - aligned(rack) 24 | - poslimit(rack) 25 | - inworkspace(rack) 26 | - inworkspace(hook) 27 | - beyondworkspace(yellow_box) 28 | - nonblocking(yellow_box, rack) 29 | - nonblocking(yellow_box, blue_box) 30 | - on(rack, table) 31 | - on(hook, table) 32 | - on(yellow_box, table) 33 | - on(blue_box, table) 34 | 35 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 36 | 37 | objects: 38 | - object_type: Urdf 39 | object_kwargs: 40 | name: table 41 | path: configs/pybullet/envs/assets/iprl_table.urdf 42 | is_static: true 43 | - object_type: Rack 44 | object_kwargs: 45 | name: rack 46 | size: [0.22, 0.32, 0.16] 47 | color: [0.4, 0.2, 0.0, 1.0] 48 | - object_type: Hook 49 | object_kwargs: 50 | name: hook 51 | head_length: 0.2 52 | handle_length: 0.38 53 | handle_y: -1.0 54 | color: [0.6, 0.6, 0.6, 1.0] 55 | - configs/pybullet/envs/assets/salt.yaml 56 | - configs/pybullet/envs/assets/yogurt.yaml 57 | # - object_type: Box 58 | # object_kwargs: 59 | # name: blue_box 60 | # size: [0.08, 0.05, 0.07] 61 | # color: [0.0, 0.0, 1.0, 1.0] 62 | # - object_type: Box 63 | # object_kwargs: 64 | # name: yellow_box 65 | # size: [0.07, 0.06, 0.1] 66 | # color: [1.0, 1.0, 0.0, 1.0] 67 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/hook_reach/task2.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: hook_reach_2 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(hook, table) 15 | - pull(red_box, hook) 16 | - place(hook, table) 17 | - pick(red_box, table) 18 | - place(red_box, rack) 19 | initial_state: 20 | - free(hook) 21 | - free(red_box) 22 | - aligned(rack) 23 | - poslimit(rack) 24 | - inworkspace(rack) 25 | - inworkspace(hook) 26 | - beyondworkspace(red_box) 27 | - nonblocking(red_box, rack) 28 | - nonblocking(red_box, cyan_box) 29 | - on(rack, table) 30 | - on(hook, table) 31 | - on(yellow_box, rack) 32 | - on(blue_box, rack) 33 | - on(red_box, table) 34 | - on(cyan_box, table) 35 | 36 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 37 | 38 | objects: 39 | - object_type: Urdf 40 | object_kwargs: 41 | name: table 42 | path: configs/pybullet/envs/assets/iprl_table.urdf 43 | is_static: true 44 | - object_type: Rack 45 | object_kwargs: 46 | name: rack 47 | size: [0.22, 0.32, 0.16] 48 | color: [0.4, 0.2, 0.0, 1.0] 49 | - object_type: Hook 50 | object_kwargs: 51 | name: hook 52 | head_length: 0.2 53 | handle_length: 0.38 54 | handle_y: -1.0 55 | color: [0.6, 0.6, 0.6, 1.0] 56 | - configs/pybullet/envs/assets/salt.yaml 57 | - configs/pybullet/envs/assets/yogurt.yaml 58 | - configs/pybullet/envs/assets/milk.yaml 59 | - configs/pybullet/envs/assets/icecream.yaml 60 | # - object_type: Box 61 | # object_kwargs: 62 | # name: blue_box 63 | # size: [0.08, 0.05, 0.07] 64 | # color: [0.0, 0.0, 1.0, 1.0] 65 | # - object_type: Box 66 | # object_kwargs: 67 | # name: yellow_box 68 | # size: [0.07, 0.06, 0.1] 69 | # color: [1.0, 1.0, 0.0, 1.0] 70 | # - object_type: Box 71 | # object_kwargs: 72 | # name: red_box 73 | # size: [0.05, 0.05, 0.07] 74 | # color: [1.0, 0.0, 0.0, 1.0] 75 | # - object_type: Box 76 | # object_kwargs: 77 | # name: cyan_box 78 | # size: [0.06, 0.03, 0.06] 79 | # color: [0.0, 1.0, 1.0, 1.0] 80 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/rearrangement_push/task0.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: rearrangement_push_0 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(cyan_box, table) 15 | - place(cyan_box, table) 16 | - pick(hook, table) 17 | - push(cyan_box, hook, rack) 18 | initial_state: 19 | - free(hook) 20 | - free(cyan_box) 21 | - aligned(rack) 22 | - poslimit(rack) 23 | - inworkspace(hook) 24 | - inworkspace(cyan_box) 25 | - beyondworkspace(rack) 26 | - nonblocking(rack, hook) 27 | - on(rack, table) 28 | - on(hook, table) 29 | - on(cyan_box, table) 30 | 31 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 32 | 33 | objects: 34 | - object_type: Urdf 35 | object_kwargs: 36 | name: table 37 | path: configs/pybullet/envs/assets/iprl_table.urdf 38 | is_static: true 39 | - object_type: Rack 40 | object_kwargs: 41 | name: rack 42 | size: [0.22, 0.32, 0.16] 43 | color: [0.4, 0.2, 0.0, 1.0] 44 | - object_type: Hook 45 | object_kwargs: 46 | name: hook 47 | head_length: 0.2 48 | handle_length: 0.38 49 | handle_y: -1.0 50 | color: [0.6, 0.6, 0.6, 1.0] 51 | - configs/pybullet/envs/assets/icecream.yaml 52 | # - object_type: Box 53 | # object_kwargs: 54 | # name: cyan_box 55 | # size: [0.06, 0.03, 0.06] 56 | # color: [0.0, 1.0, 1.0, 1.0] 57 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/rearrangement_push/task1.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: rearrangement_push_1 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(hook, table) 15 | - place(hook, table) 16 | - pick(cyan_box, table) 17 | - place(cyan_box, table) 18 | - pick(hook, table) 19 | - push(yellow_box, hook, rack) 20 | initial_state: 21 | - free(hook) 22 | - free(red_box) 23 | - free(cyan_box) 24 | - free(yellow_box) 25 | - aligned(rack) 26 | - poslimit(rack) 27 | - inworkspace(hook) 28 | - inworkspace(red_box) 29 | - inoperationalzone(yellow_box) 30 | - inobstructionzone(cyan_box) 31 | - beyondworkspace(rack) 32 | - infront(red_box, rack) 33 | - infront(cyan_box, rack) 34 | - infront(yellow_box, rack) 35 | - on(rack, table) 36 | - on(hook, table) 37 | - on(cyan_box, table) 38 | - on(red_box, table) 39 | - on(yellow_box, table) 40 | 41 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 42 | 43 | objects: 44 | - object_type: Urdf 45 | object_kwargs: 46 | name: table 47 | path: configs/pybullet/envs/assets/iprl_table.urdf 48 | is_static: true 49 | - object_type: Rack 50 | object_kwargs: 51 | name: rack 52 | size: [0.22, 0.32, 0.16] 53 | color: [0.4, 0.2, 0.0, 1.0] 54 | - object_type: Hook 55 | object_kwargs: 56 | name: hook 57 | head_length: 0.2 58 | handle_length: 0.38 59 | handle_y: -1.0 60 | color: [0.6, 0.6, 0.6, 1.0] 61 | - configs/pybullet/envs/assets/yogurt.yaml 62 | - configs/pybullet/envs/assets/milk.yaml 63 | - configs/pybullet/envs/assets/icecream.yaml 64 | # - object_type: Box 65 | # object_kwargs: 66 | # name: yellow_box 67 | # size: [0.07, 0.06, 0.1] 68 | # color: [1.0, 1.0, 0.0, 1.0] 69 | # - object_type: Box 70 | # object_kwargs: 71 | # name: red_box 72 | # size: [0.05, 0.05, 0.07] 73 | # color: [1.0, 0.0, 0.0, 1.0] 74 | # - object_type: Box 75 | # object_kwargs: 76 | # name: cyan_box 77 | # size: [0.06, 0.03, 0.06] 78 | # color: [0.0, 1.0, 1.0, 1.0] 79 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/rearrangement_push/task2.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: rearrangement_push_2 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(red_box, table) 15 | - place(red_box, table) 16 | - pick(yellow_box, table) 17 | - place(yellow_box, table) 18 | - pick(cyan_box, table) 19 | - place(cyan_box, table) 20 | - pick(hook, table) 21 | - push(blue_box, hook, rack) 22 | initial_state: 23 | - free(hook) 24 | - free(red_box) 25 | - free(cyan_box) 26 | - free(yellow_box) 27 | - free(blue_box) 28 | - aligned(rack) 29 | - poslimit(rack) 30 | - inworkspace(hook) 31 | - inworkspace(red_box) 32 | - inworkspace(yellow_box) 33 | - inobstructionzone(cyan_box) 34 | - beyondworkspace(rack) 35 | - infront(red_box, rack) 36 | - infront(cyan_box, rack) 37 | - infront(yellow_box, rack) 38 | - infront(blue_box, rack) 39 | - on(rack, table) 40 | - on(hook, table) 41 | - on(red_box, table) 42 | - on(cyan_box, table) 43 | - on(yellow_box, table) 44 | - on(blue_box, table) 45 | 46 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 47 | 48 | objects: 49 | - object_type: Urdf 50 | object_kwargs: 51 | name: table 52 | path: configs/pybullet/envs/assets/iprl_table.urdf 53 | is_static: true 54 | - object_type: Rack 55 | object_kwargs: 56 | name: rack 57 | size: [0.22, 0.32, 0.16] 58 | color: [0.4, 0.2, 0.0, 1.0] 59 | - object_type: Hook 60 | object_kwargs: 61 | name: hook 62 | head_length: 0.2 63 | handle_length: 0.38 64 | handle_y: -1.0 65 | color: [0.6, 0.6, 0.6, 1.0] 66 | - configs/pybullet/envs/assets/salt.yaml 67 | - configs/pybullet/envs/assets/yogurt.yaml 68 | - configs/pybullet/envs/assets/milk.yaml 69 | - configs/pybullet/envs/assets/icecream.yaml 70 | # - object_type: Box 71 | # object_kwargs: 72 | # name: blue_box 73 | # size: [0.08, 0.05, 0.07] 74 | # color: [0.0, 0.0, 1.0, 1.0] 75 | # - object_type: Box 76 | # object_kwargs: 77 | # name: yellow_box 78 | # size: [0.07, 0.06, 0.1] 79 | # color: [1.0, 1.0, 0.0, 1.0] 80 | # - object_type: Box 81 | # object_kwargs: 82 | # name: red_box 83 | # size: [0.05, 0.05, 0.07] 84 | # color: [1.0, 0.0, 0.0, 1.0] 85 | # - object_type: Box 86 | # object_kwargs: 87 | # name: cyan_box 88 | # size: [0.06, 0.03, 0.06] 89 | # color: [0.0, 1.0, 1.0, 1.0] 90 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/domains/rearrangement_push/task3.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: rearrangement_push_3 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | - place 9 | - pull 10 | - push 11 | 12 | tasks: 13 | - action_skeleton: 14 | - pick(hook, table) 15 | - pull(cyan_box, hook) 16 | - place(hook, table) 17 | - pick(cyan_box, table) 18 | - place(cyan_box, table) 19 | - pick(hook, table) 20 | - push(cyan_box, hook, rack) 21 | initial_state: 22 | - free(hook) 23 | - free(cyan_box) 24 | - aligned(rack) 25 | - poslimit(rack) 26 | - inworkspace(hook) 27 | # - inworkspace(cyan_box) 28 | - beyondworkspace(cyan_box) 29 | - beyondworkspace(rack) 30 | - nonblocking(rack, hook) 31 | - on(rack, table) 32 | - on(hook, table) 33 | - on(cyan_box, table) 34 | 35 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 36 | 37 | objects: 38 | - object_type: Urdf 39 | object_kwargs: 40 | name: table 41 | path: configs/pybullet/envs/assets/iprl_table.urdf 42 | is_static: true 43 | - object_type: Rack 44 | object_kwargs: 45 | name: rack 46 | size: [0.22, 0.32, 0.16] 47 | color: [0.4, 0.2, 0.0, 1.0] 48 | - object_type: Hook 49 | object_kwargs: 50 | name: hook 51 | head_length: 0.2 52 | handle_length: 0.38 53 | handle_y: -1.0 54 | color: [0.6, 0.6, 0.6, 1.0] 55 | - configs/pybullet/envs/assets/icecream.yaml 56 | # - object_type: Box 57 | # object_kwargs: 58 | # name: cyan_box 59 | # size: [0.06, 0.03, 0.06] 60 | # color: [0.0, 1.0, 1.0, 1.0] 61 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/pick.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: pick 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | 9 | tasks: 10 | # - action_skeleton: 11 | # - pick(table_box_1, table) 12 | # initial_state: 13 | # - aligned(rack) 14 | # - poslimit(rack) 15 | # - on(rack, table) 16 | # - on(hook, table) 17 | # - on(table_box_1, table) 18 | # - on(table_box_2, table) 19 | # - on(table_box_3, table) 20 | # - on(table_box_4, table) 21 | # - on(rack_box_1, rack) 22 | # - on(rack_box_2, rack) 23 | # - on(rack_box_3, rack) 24 | # - on(rack_box_4, rack) 25 | - action_skeleton: 26 | - pick(hook, table) 27 | initial_state: 28 | - aligned(rack) 29 | - poslimit(rack) 30 | - on(rack, table) 31 | - on(hook, table) 32 | - on(table_box_1, table) 33 | - on(table_box_2, table) 34 | - on(table_box_3, table) 35 | - on(table_box_4, table) 36 | - on(rack_box_1, rack) 37 | - on(rack_box_2, rack) 38 | - on(rack_box_3, rack) 39 | - on(rack_box_4, rack) 40 | 41 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 42 | 43 | object_groups: 44 | - name: boxes 45 | objects: 46 | - configs/pybullet/envs/assets/salt.yaml 47 | - configs/pybullet/envs/assets/milk.yaml 48 | - configs/pybullet/envs/assets/yogurt.yaml 49 | - configs/pybullet/envs/assets/icecream.yaml 50 | - object_type: Null 51 | - object_type: Null 52 | - object_type: Null 53 | - object_type: Null 54 | - object_type: Null 55 | - object_type: Null 56 | - object_type: Null 57 | - object_type: Null 58 | 59 | objects: 60 | - object_type: Urdf 61 | object_kwargs: 62 | name: table 63 | path: configs/pybullet/envs/assets/iprl_table.urdf 64 | is_static: true 65 | - object_type: Variant 66 | object_kwargs: 67 | name: rack 68 | variants: 69 | - object_type: Rack 70 | object_kwargs: 71 | size: [0.22, 0.32, 0.16] 72 | color: [0.4, 0.2, 0.0, 1.0] 73 | - object_type: Null 74 | - object_type: Variant 75 | object_kwargs: 76 | name: hook 77 | variants: 78 | - object_type: Hook 79 | object_kwargs: 80 | head_length: 0.2 81 | handle_length: 0.38 82 | handle_y: -1.0 83 | color: [0.6, 0.6, 0.6, 1.0] 84 | - object_type: Null 85 | - object_type: Variant 86 | object_kwargs: 87 | name: table_box_1 88 | group: boxes 89 | - object_type: Variant 90 | object_kwargs: 91 | name: table_box_2 92 | group: boxes 93 | - object_type: Variant 94 | object_kwargs: 95 | name: table_box_3 96 | group: boxes 97 | - object_type: Variant 98 | object_kwargs: 99 | name: table_box_4 100 | group: boxes 101 | - object_type: Variant 102 | object_kwargs: 103 | name: rack_box_1 104 | group: boxes 105 | - object_type: Variant 106 | object_kwargs: 107 | name: rack_box_2 108 | group: boxes 109 | - object_type: Variant 110 | object_kwargs: 111 | name: rack_box_3 112 | group: boxes 113 | - object_type: Variant 114 | object_kwargs: 115 | name: rack_box_4 116 | group: boxes 117 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/pick_eval.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: pick 4 | gui: false 5 | 6 | primitives: 7 | - pick 8 | 9 | tasks: 10 | - action_skeleton: 11 | - pick(table_box_1, table) 12 | initial_state: 13 | - free(table_box_1) 14 | - aligned(rack) 15 | - poslimit(rack) 16 | - inworkspace(table_box_1) 17 | - on(rack, table) 18 | - on(hook, table) 19 | - on(table_box_1, table) 20 | - on(table_box_2, table) 21 | - on(table_box_3, table) 22 | - on(table_box_4, table) 23 | - on(rack_box_1, rack) 24 | - on(rack_box_2, rack) 25 | - on(rack_box_3, rack) 26 | - on(rack_box_4, rack) 27 | - action_skeleton: 28 | - pick(hook, table) 29 | initial_state: 30 | - free(hook) 31 | - aligned(rack) 32 | - poslimit(rack) 33 | - inworkspace(hook) 34 | - on(rack, table) 35 | - on(hook, table) 36 | - on(table_box_1, table) 37 | - on(table_box_2, table) 38 | - on(table_box_3, table) 39 | - on(table_box_4, table) 40 | - on(rack_box_1, rack) 41 | - on(rack_box_2, rack) 42 | - on(rack_box_3, rack) 43 | - on(rack_box_4, rack) 44 | 45 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 46 | 47 | object_groups: 48 | - name: boxes 49 | objects: 50 | - configs/pybullet/envs/assets/salt.yaml 51 | - configs/pybullet/envs/assets/milk.yaml 52 | - configs/pybullet/envs/assets/yogurt.yaml 53 | - configs/pybullet/envs/assets/icecream.yaml 54 | - object_type: Null 55 | - object_type: Null 56 | - object_type: Null 57 | - object_type: Null 58 | - object_type: Null 59 | - object_type: Null 60 | - object_type: Null 61 | - object_type: Null 62 | 63 | objects: 64 | - object_type: Urdf 65 | object_kwargs: 66 | name: table 67 | path: configs/pybullet/envs/assets/iprl_table.urdf 68 | is_static: true 69 | - object_type: Variant 70 | object_kwargs: 71 | name: rack 72 | variants: 73 | - object_type: Rack 74 | object_kwargs: 75 | size: [0.22, 0.32, 0.16] 76 | color: [0.4, 0.2, 0.0, 1.0] 77 | - object_type: Null 78 | - object_type: Variant 79 | object_kwargs: 80 | name: hook 81 | variants: 82 | - object_type: Hook 83 | object_kwargs: 84 | head_length: 0.2 85 | handle_length: 0.38 86 | handle_y: -1.0 87 | color: [0.6, 0.6, 0.6, 1.0] 88 | - object_type: Null 89 | - object_type: Variant 90 | object_kwargs: 91 | name: table_box_1 92 | group: boxes 93 | - object_type: Variant 94 | object_kwargs: 95 | name: table_box_2 96 | group: boxes 97 | - object_type: Variant 98 | object_kwargs: 99 | name: table_box_3 100 | group: boxes 101 | - object_type: Variant 102 | object_kwargs: 103 | name: table_box_4 104 | group: boxes 105 | - object_type: Variant 106 | object_kwargs: 107 | name: rack_box_1 108 | group: boxes 109 | - object_type: Variant 110 | object_kwargs: 111 | name: rack_box_2 112 | group: boxes 113 | - object_type: Variant 114 | object_kwargs: 115 | name: rack_box_3 116 | group: boxes 117 | - object_type: Variant 118 | object_kwargs: 119 | name: rack_box_4 120 | group: boxes 121 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/place.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: place 4 | gui: false 5 | 6 | primitives: 7 | - place 8 | 9 | tasks: 10 | - prob: 0.333 11 | action_skeleton: 12 | - place(table_box_1, table) 13 | initial_state: 14 | - aligned(rack) 15 | - poslimit(rack) 16 | - on(rack, table) 17 | - on(hook, table) 18 | - on(table_box_2, table) 19 | - on(table_box_3, table) 20 | - on(table_box_4, table) 21 | - on(rack_box_1, rack) 22 | - on(rack_box_2, rack) 23 | - on(rack_box_3, rack) 24 | - on(rack_box_4, rack) 25 | - inhand(table_box_1) 26 | - prob: 0.333 27 | action_skeleton: 28 | - place(rack_box_1, rack) 29 | initial_state: 30 | - aligned(rack) 31 | - poslimit(rack) 32 | - on(rack, table) 33 | - on(hook, table) 34 | - on(table_box_1, table) 35 | - on(table_box_2, table) 36 | - on(table_box_3, table) 37 | - on(table_box_4, table) 38 | - on(rack_box_2, rack) 39 | - on(rack_box_3, rack) 40 | - on(rack_box_4, rack) 41 | - inhand(rack_box_1) 42 | - prob: 0.167 43 | action_skeleton: 44 | - place(hook, table) 45 | initial_state: 46 | - aligned(rack) 47 | - poslimit(rack) 48 | - on(rack, table) 49 | - on(table_box_1, table) 50 | - on(table_box_2, table) 51 | - on(table_box_3, table) 52 | - on(table_box_4, table) 53 | - on(rack_box_1, rack) 54 | - on(rack_box_2, rack) 55 | - on(rack_box_3, rack) 56 | - on(rack_box_4, rack) 57 | - inhand(hook) 58 | - prob: 0.167 59 | action_skeleton: 60 | - place(hook, table) 61 | initial_state: 62 | - handlegrasp(hook) 63 | - aligned(rack) 64 | - poslimit(rack) 65 | - on(rack, table) 66 | - on(table_box_1, table) 67 | - on(table_box_2, table) 68 | - on(table_box_3, table) 69 | - on(table_box_4, table) 70 | - on(rack_box_1, rack) 71 | - on(rack_box_2, rack) 72 | - on(rack_box_3, rack) 73 | - on(rack_box_4, rack) 74 | - inhand(hook) 75 | 76 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 77 | 78 | object_groups: 79 | - name: boxes 80 | objects: 81 | - configs/pybullet/envs/assets/salt.yaml 82 | - configs/pybullet/envs/assets/milk.yaml 83 | - configs/pybullet/envs/assets/yogurt.yaml 84 | - configs/pybullet/envs/assets/icecream.yaml 85 | - object_type: Null 86 | - object_type: Null 87 | - object_type: Null 88 | - object_type: Null 89 | - object_type: Null 90 | - object_type: Null 91 | - object_type: Null 92 | - object_type: Null 93 | 94 | objects: 95 | - object_type: Urdf 96 | object_kwargs: 97 | name: table 98 | path: configs/pybullet/envs/assets/iprl_table.urdf 99 | is_static: true 100 | - object_type: Variant 101 | object_kwargs: 102 | name: rack 103 | variants: 104 | - object_type: Rack 105 | object_kwargs: 106 | size: [0.22, 0.32, 0.16] 107 | color: [0.4, 0.2, 0.0, 1.0] 108 | - object_type: Null 109 | - object_type: Variant 110 | object_kwargs: 111 | name: hook 112 | variants: 113 | - object_type: Hook 114 | object_kwargs: 115 | head_length: 0.2 116 | handle_length: 0.38 117 | handle_y: -1.0 118 | color: [0.6, 0.6, 0.6, 1.0] 119 | - object_type: Null 120 | - object_type: Variant 121 | object_kwargs: 122 | name: table_box_1 123 | group: boxes 124 | - object_type: Variant 125 | object_kwargs: 126 | name: table_box_2 127 | group: boxes 128 | - object_type: Variant 129 | object_kwargs: 130 | name: table_box_3 131 | group: boxes 132 | - object_type: Variant 133 | object_kwargs: 134 | name: table_box_4 135 | group: boxes 136 | - object_type: Variant 137 | object_kwargs: 138 | name: rack_box_1 139 | group: boxes 140 | - object_type: Variant 141 | object_kwargs: 142 | name: rack_box_2 143 | group: boxes 144 | - object_type: Variant 145 | object_kwargs: 146 | name: rack_box_3 147 | group: boxes 148 | - object_type: Variant 149 | object_kwargs: 150 | name: rack_box_4 151 | group: boxes 152 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/place_eval.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: place 4 | gui: false 5 | 6 | primitives: 7 | - place 8 | 9 | tasks: 10 | # - prob: 0.333 11 | # action_skeleton: 12 | # - place(table_box_1, table) 13 | # initial_state: 14 | # - aligned(rack) 15 | # - poslimit(rack) 16 | # - on(rack, table) 17 | # - on(hook, table) 18 | # - on(table_box_2, table) 19 | # - on(table_box_3, table) 20 | # - on(table_box_4, table) 21 | # - on(rack_box_1, rack) 22 | # - on(rack_box_2, rack) 23 | # - on(rack_box_3, rack) 24 | # - on(rack_box_4, rack) 25 | # - inhand(table_box_1) 26 | - prob: 1.0 # 0.333 27 | action_skeleton: 28 | - place(rack_box_1, rack) 29 | initial_state: 30 | - aligned(rack) 31 | - poslimit(rack) 32 | - inworkspace(rack) 33 | - on(rack, table) 34 | - on(hook, table) 35 | - on(table_box_1, table) 36 | - on(table_box_2, table) 37 | - on(table_box_3, table) 38 | - on(table_box_4, table) 39 | - on(rack_box_2, rack) 40 | - on(rack_box_3, rack) 41 | - on(rack_box_4, rack) 42 | - inhand(rack_box_1) 43 | # - prob: 0.167 44 | # action_skeleton: 45 | # - place(hook, table) 46 | # initial_state: 47 | # - aligned(rack) 48 | # - poslimit(rack) 49 | # - on(rack, table) 50 | # - on(table_box_1, table) 51 | # - on(table_box_2, table) 52 | # - on(table_box_3, table) 53 | # - on(table_box_4, table) 54 | # - on(rack_box_1, rack) 55 | # - on(rack_box_2, rack) 56 | # - on(rack_box_3, rack) 57 | # - on(rack_box_4, rack) 58 | # - inhand(hook) 59 | # - prob: 0.167 60 | # action_skeleton: 61 | # - place(hook, table) 62 | # initial_state: 63 | # - handlegrasp(hook) 64 | # - aligned(rack) 65 | # - poslimit(rack) 66 | # - on(rack, table) 67 | # - on(table_box_1, table) 68 | # - on(table_box_2, table) 69 | # - on(table_box_3, table) 70 | # - on(table_box_4, table) 71 | # - on(rack_box_1, rack) 72 | # - on(rack_box_2, rack) 73 | # - on(rack_box_3, rack) 74 | # - on(rack_box_4, rack) 75 | # - inhand(hook) 76 | 77 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 78 | 79 | object_groups: 80 | - name: boxes 81 | objects: 82 | - configs/pybullet/envs/assets/salt.yaml 83 | - configs/pybullet/envs/assets/milk.yaml 84 | - configs/pybullet/envs/assets/yogurt.yaml 85 | - configs/pybullet/envs/assets/icecream.yaml 86 | - object_type: Null 87 | - object_type: Null 88 | - object_type: Null 89 | - object_type: Null 90 | - object_type: Null 91 | - object_type: Null 92 | - object_type: Null 93 | - object_type: Null 94 | 95 | objects: 96 | - object_type: Urdf 97 | object_kwargs: 98 | name: table 99 | path: configs/pybullet/envs/assets/iprl_table.urdf 100 | is_static: true 101 | - object_type: Variant 102 | object_kwargs: 103 | name: rack 104 | variants: 105 | - object_type: Rack 106 | object_kwargs: 107 | size: [0.22, 0.32, 0.16] 108 | color: [0.4, 0.2, 0.0, 1.0] 109 | - object_type: Null 110 | - object_type: Variant 111 | object_kwargs: 112 | name: hook 113 | variants: 114 | - object_type: Hook 115 | object_kwargs: 116 | head_length: 0.2 117 | handle_length: 0.38 118 | handle_y: -1.0 119 | color: [0.6, 0.6, 0.6, 1.0] 120 | - object_type: Null 121 | - object_type: Variant 122 | object_kwargs: 123 | name: table_box_1 124 | group: boxes 125 | - object_type: Variant 126 | object_kwargs: 127 | name: table_box_2 128 | group: boxes 129 | - object_type: Variant 130 | object_kwargs: 131 | name: table_box_3 132 | group: boxes 133 | - object_type: Variant 134 | object_kwargs: 135 | name: table_box_4 136 | group: boxes 137 | - object_type: Variant 138 | object_kwargs: 139 | name: rack_box_1 140 | group: boxes 141 | - object_type: Variant 142 | object_kwargs: 143 | name: rack_box_2 144 | group: boxes 145 | - object_type: Variant 146 | object_kwargs: 147 | name: rack_box_3 148 | group: boxes 149 | - object_type: Variant 150 | object_kwargs: 151 | name: rack_box_4 152 | group: boxes 153 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/pull.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: pull 4 | gui: false 5 | 6 | primitives: 7 | - pull 8 | 9 | tasks: 10 | - action_skeleton: 11 | - pull(table_box_1, hook) 12 | initial_state: 13 | - aligned(rack) 14 | - poslimit(rack) 15 | - beyondworkspace(table_box_1) 16 | - on(rack, table) 17 | - on(table_box_1, table) 18 | - on(table_box_2, table) 19 | - on(table_box_3, table) 20 | - on(table_box_4, table) 21 | - on(rack_box_1, rack) 22 | - on(rack_box_2, rack) 23 | - on(rack_box_3, rack) 24 | - on(rack_box_4, rack) 25 | - inhand(hook) 26 | - action_skeleton: 27 | - pull(table_box_1, hook) 28 | initial_state: 29 | - handlegrasp(hook) 30 | - aligned(rack) 31 | - poslimit(rack) 32 | - beyondworkspace(table_box_1) 33 | - on(rack, table) 34 | - on(table_box_1, table) 35 | - on(table_box_2, table) 36 | - on(table_box_3, table) 37 | - on(table_box_4, table) 38 | - on(rack_box_1, rack) 39 | - on(rack_box_2, rack) 40 | - on(rack_box_3, rack) 41 | - on(rack_box_4, rack) 42 | - inhand(hook) 43 | 44 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 45 | 46 | object_groups: 47 | - name: boxes 48 | objects: 49 | - configs/pybullet/envs/assets/salt.yaml 50 | - configs/pybullet/envs/assets/milk.yaml 51 | - configs/pybullet/envs/assets/yogurt.yaml 52 | - configs/pybullet/envs/assets/icecream.yaml 53 | - object_type: Null 54 | - object_type: Null 55 | - object_type: Null 56 | - object_type: Null 57 | - object_type: Null 58 | - object_type: Null 59 | - object_type: Null 60 | 61 | objects: 62 | - object_type: Urdf 63 | object_kwargs: 64 | name: table 65 | path: configs/pybullet/envs/assets/iprl_table.urdf 66 | is_static: true 67 | - object_type: Variant 68 | object_kwargs: 69 | name: rack 70 | variants: 71 | - object_type: Rack 72 | object_kwargs: 73 | size: [0.22, 0.32, 0.16] 74 | color: [0.4, 0.2, 0.0, 1.0] 75 | - object_type: Null 76 | - object_type: Hook 77 | object_kwargs: 78 | name: hook 79 | head_length: 0.2 80 | handle_length: 0.38 81 | handle_y: -1.0 82 | color: [0.6, 0.6, 0.6, 1.0] 83 | - object_type: Variant 84 | object_kwargs: 85 | name: table_box_1 86 | group: boxes 87 | - object_type: Variant 88 | object_kwargs: 89 | name: table_box_2 90 | group: boxes 91 | - object_type: Variant 92 | object_kwargs: 93 | name: table_box_3 94 | group: boxes 95 | - object_type: Variant 96 | object_kwargs: 97 | name: table_box_4 98 | group: boxes 99 | - object_type: Variant 100 | object_kwargs: 101 | name: rack_box_1 102 | group: boxes 103 | - object_type: Variant 104 | object_kwargs: 105 | name: rack_box_2 106 | group: boxes 107 | - object_type: Variant 108 | object_kwargs: 109 | name: rack_box_3 110 | group: boxes 111 | - object_type: Variant 112 | object_kwargs: 113 | name: rack_box_4 114 | group: boxes 115 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/pull_eval.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: pull 4 | gui: false 5 | 6 | primitives: 7 | - pull 8 | 9 | tasks: 10 | - action_skeleton: 11 | - pull(table_box_1, hook) 12 | initial_state: 13 | - handlegrasp(hook) 14 | - free(table_box_1) 15 | - aligned(rack) 16 | - poslimit(rack) 17 | - beyondworkspace(table_box_1) 18 | - nonblocking(table_box_1, rack) 19 | - on(rack, table) 20 | - on(table_box_1, table) 21 | - on(table_box_2, table) 22 | - on(table_box_3, table) 23 | - on(table_box_4, table) 24 | - on(rack_box_1, rack) 25 | - on(rack_box_2, rack) 26 | - on(rack_box_3, rack) 27 | - on(rack_box_4, rack) 28 | - inhand(hook) 29 | 30 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 31 | 32 | object_groups: 33 | - name: boxes 34 | objects: 35 | - configs/pybullet/envs/assets/salt.yaml 36 | - configs/pybullet/envs/assets/milk.yaml 37 | - configs/pybullet/envs/assets/yogurt.yaml 38 | - configs/pybullet/envs/assets/icecream.yaml 39 | - object_type: Null 40 | - object_type: Null 41 | - object_type: Null 42 | - object_type: Null 43 | - object_type: Null 44 | - object_type: Null 45 | - object_type: Null 46 | 47 | objects: 48 | - object_type: Urdf 49 | object_kwargs: 50 | name: table 51 | path: configs/pybullet/envs/assets/iprl_table.urdf 52 | is_static: true 53 | - object_type: Variant 54 | object_kwargs: 55 | name: rack 56 | variants: 57 | - object_type: Rack 58 | object_kwargs: 59 | size: [0.22, 0.32, 0.16] 60 | color: [0.4, 0.2, 0.0, 1.0] 61 | - object_type: Null 62 | - object_type: Hook 63 | object_kwargs: 64 | name: hook 65 | head_length: 0.2 66 | handle_length: 0.38 67 | handle_y: -1.0 68 | color: [0.6, 0.6, 0.6, 1.0] 69 | - object_type: Variant 70 | object_kwargs: 71 | name: table_box_1 72 | group: boxes 73 | - object_type: Variant 74 | object_kwargs: 75 | name: table_box_2 76 | group: boxes 77 | - object_type: Variant 78 | object_kwargs: 79 | name: table_box_3 80 | group: boxes 81 | - object_type: Variant 82 | object_kwargs: 83 | name: table_box_4 84 | group: boxes 85 | - object_type: Variant 86 | object_kwargs: 87 | name: rack_box_1 88 | group: boxes 89 | - object_type: Variant 90 | object_kwargs: 91 | name: rack_box_2 92 | group: boxes 93 | - object_type: Variant 94 | object_kwargs: 95 | name: rack_box_3 96 | group: boxes 97 | - object_type: Variant 98 | object_kwargs: 99 | name: rack_box_4 100 | group: boxes 101 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/push.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: push 4 | gui: false 5 | 6 | primitives: 7 | - push 8 | 9 | tasks: 10 | - action_skeleton: 11 | - push(table_box_1, hook, rack) 12 | initial_state: 13 | - aligned(rack) 14 | - poslimit(rack) 15 | - inworkspace(table_box_1) 16 | - beyondworkspace(rack) 17 | - on(rack, table) 18 | - on(table_box_1, table) 19 | - on(table_box_2, table) 20 | - on(table_box_3, table) 21 | - on(table_box_4, table) 22 | - on(rack_box_1, rack) 23 | - on(rack_box_2, rack) 24 | - on(rack_box_3, rack) 25 | - on(rack_box_4, rack) 26 | - inhand(hook) 27 | - action_skeleton: 28 | - push(table_box_1, hook, rack) 29 | initial_state: 30 | - upperhandlegrasp(hook) 31 | - aligned(rack) 32 | - poslimit(rack) 33 | - inworkspace(table_box_1) 34 | - beyondworkspace(rack) 35 | - on(rack, table) 36 | - on(table_box_1, table) 37 | - on(table_box_2, table) 38 | - on(table_box_3, table) 39 | - on(table_box_4, table) 40 | - on(rack_box_1, rack) 41 | - on(rack_box_2, rack) 42 | - on(rack_box_3, rack) 43 | - on(rack_box_4, rack) 44 | - inhand(hook) 45 | 46 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 47 | 48 | object_groups: 49 | - name: boxes 50 | objects: 51 | - configs/pybullet/envs/assets/salt.yaml 52 | - configs/pybullet/envs/assets/milk.yaml 53 | - configs/pybullet/envs/assets/yogurt.yaml 54 | - configs/pybullet/envs/assets/icecream.yaml 55 | - object_type: Null 56 | - object_type: Null 57 | - object_type: Null 58 | - object_type: Null 59 | - object_type: Null 60 | - object_type: Null 61 | - object_type: Null 62 | 63 | objects: 64 | - object_type: Urdf 65 | object_kwargs: 66 | name: table 67 | path: configs/pybullet/envs/assets/iprl_table.urdf 68 | is_static: true 69 | - object_type: Rack 70 | object_kwargs: 71 | name: rack 72 | size: [0.22, 0.32, 0.16] 73 | color: [0.4, 0.2, 0.0, 1.0] 74 | - object_type: Hook 75 | object_kwargs: 76 | name: hook 77 | head_length: 0.2 78 | handle_length: 0.38 79 | handle_y: -1.0 80 | color: [0.6, 0.6, 0.6, 1.0] 81 | - object_type: Variant 82 | object_kwargs: 83 | name: table_box_1 84 | group: boxes 85 | - object_type: Variant 86 | object_kwargs: 87 | name: table_box_2 88 | group: boxes 89 | - object_type: Variant 90 | object_kwargs: 91 | name: table_box_3 92 | group: boxes 93 | - object_type: Variant 94 | object_kwargs: 95 | name: table_box_4 96 | group: boxes 97 | - object_type: Variant 98 | object_kwargs: 99 | name: rack_box_1 100 | group: boxes 101 | - object_type: Variant 102 | object_kwargs: 103 | name: rack_box_2 104 | group: boxes 105 | - object_type: Variant 106 | object_kwargs: 107 | name: rack_box_3 108 | group: boxes 109 | - object_type: Variant 110 | object_kwargs: 111 | name: rack_box_4 112 | group: boxes 113 | -------------------------------------------------------------------------------- /configs/pybullet/envs/official/primitives/push_eval.yaml: -------------------------------------------------------------------------------- 1 | env: pybullet.TableEnv 2 | env_kwargs: 3 | name: push 4 | gui: false 5 | 6 | primitives: 7 | - push 8 | 9 | tasks: 10 | - action_skeleton: 11 | - push(table_box_1, hook, rack) 12 | initial_state: 13 | - upperhandlegrasp(hook) 14 | - free(table_box_1) 15 | - aligned(rack) 16 | - poslimit(rack) 17 | - inoperationalzone(table_box_1) 18 | - beyondworkspace(rack) 19 | - infront(table_box_1, rack) 20 | - on(rack, table) 21 | - nonblocking(rack, table_box_2) 22 | - nonblocking(rack, table_box_3) 23 | - nonblocking(rack, table_box_4) 24 | - nonblocking(table_box_1, table_box_2) 25 | - nonblocking(table_box_1, table_box_3) 26 | - nonblocking(table_box_1, table_box_4) 27 | - on(table_box_1, table) 28 | - on(table_box_2, table) 29 | - on(table_box_3, table) 30 | - on(table_box_4, table) 31 | - on(rack_box_1, rack) 32 | - on(rack_box_2, rack) 33 | - on(rack_box_3, rack) 34 | - on(rack_box_4, rack) 35 | - inhand(hook) 36 | 37 | robot_config: configs/pybullet/envs/robots/franka_panda_sim.yaml 38 | 39 | object_groups: 40 | - name: boxes 41 | objects: 42 | - configs/pybullet/envs/assets/salt.yaml 43 | - configs/pybullet/envs/assets/milk.yaml 44 | - configs/pybullet/envs/assets/yogurt.yaml 45 | - configs/pybullet/envs/assets/icecream.yaml 46 | - object_type: Null 47 | - object_type: Null 48 | - object_type: Null 49 | - object_type: Null 50 | - object_type: Null 51 | - object_type: Null 52 | - object_type: Null 53 | 54 | objects: 55 | - object_type: Urdf 56 | object_kwargs: 57 | name: table 58 | path: configs/pybullet/envs/assets/iprl_table.urdf 59 | is_static: true 60 | - object_type: Rack 61 | object_kwargs: 62 | name: rack 63 | size: [0.22, 0.32, 0.16] 64 | color: [0.4, 0.2, 0.0, 1.0] 65 | - object_type: Hook 66 | object_kwargs: 67 | name: hook 68 | head_length: 0.2 69 | handle_length: 0.38 70 | handle_y: -1.0 71 | color: [0.6, 0.6, 0.6, 1.0] 72 | - object_type: Variant 73 | object_kwargs: 74 | name: table_box_1 75 | group: boxes 76 | - object_type: Variant 77 | object_kwargs: 78 | name: table_box_2 79 | group: boxes 80 | - object_type: Variant 81 | object_kwargs: 82 | name: table_box_3 83 | group: boxes 84 | - object_type: Variant 85 | object_kwargs: 86 | name: table_box_4 87 | group: boxes 88 | - object_type: Variant 89 | object_kwargs: 90 | name: rack_box_1 91 | group: boxes 92 | - object_type: Variant 93 | object_kwargs: 94 | name: rack_box_2 95 | group: boxes 96 | - object_type: Variant 97 | object_kwargs: 98 | name: rack_box_3 99 | group: boxes 100 | - object_type: Variant 101 | object_kwargs: 102 | name: rack_box_4 103 | group: boxes 104 | -------------------------------------------------------------------------------- /configs/pybullet/envs/robots/franka_panda_sim.yaml: -------------------------------------------------------------------------------- 1 | urdf: configs/pybullet/envs/assets/franka_panda/franka_panda_robotiq.urdf 2 | 3 | arm_class: sim.arm.Arm 4 | arm_kwargs: 5 | arm_urdf: configs/pybullet/envs/assets/franka_panda/franka_panda.urdf 6 | torque_joints: 7 | - joint1 8 | - joint2 9 | - joint3 10 | - joint4 11 | - joint5 12 | - joint6 13 | - joint7 14 | 15 | # q_home: [0.0, -0.52359878, 0.0, -2.61799388, 0.0, 2.0943951, 0.0] 16 | q_home: [0.0, -0.5, 0.0, -2.28, 0.0, 1.78, 0.0] 17 | ee_offset: [0.0, 0.0, 0.251] 18 | 19 | pos_gains: [121, 22] 20 | ori_gains: [121, 16] 21 | nullspace_joint_gains: [40, 10] 22 | nullspace_joint_indices: [2, 4] 23 | 24 | pos_threshold: [0.01, 0.01] 25 | ori_threshold: [0.01, 0.01] 26 | timeout: 10.0 27 | 28 | redisgl_config: 29 | gripper_offset: [0.0, 0.0, 0.107] 30 | 31 | redis_host: "127.0.0.1" 32 | redis_port: 6000 33 | redis_password: taps 34 | 35 | redis_keys: 36 | namespace: "franka_panda" 37 | control_pos: "franka_panda::control::pos" 38 | control_ori: "franka_panda::control::ori" 39 | control_pos_des: "franka_panda::control::pos_des" 40 | control_ori_des: "franka_panda::control::ori_des" 41 | opspace_inertia_pos: "franka_panda::opspace::inertia_pos" 42 | opspace_inertia_ori: "franka_panda::opspace::inertia_ori" 43 | sensor_q: "franka_panda::sensor::q" 44 | sensor_dq: "franka_panda::sensor::dq" 45 | sensor_pos: "franka_panda::sensor::pos" 46 | sensor_ori: "franka_panda::sensor::ori" 47 | 48 | gripper_class: sim.gripper.Gripper 49 | gripper_kwargs: 50 | torque_joints: 51 | - left_inner_finger_pad_prismatic 52 | 53 | position_joints: 54 | - right_inner_finger_pad_prismatic 55 | - finger_joint 56 | - left_inner_knuckle_joint 57 | - left_inner_finger_joint 58 | - right_outer_knuckle_joint 59 | - right_inner_knuckle_joint 60 | - right_inner_finger_joint 61 | 62 | finger_links: 63 | - left_inner_finger_pad_collision 64 | - right_inner_finger_pad_collision 65 | 66 | base_link: robotiq_arg2f_base_link 67 | 68 | command_multipliers: 69 | [0.04, -0.04, 0.813, 0.813, -0.813, 0.813, 0.813, -0.813] 70 | 71 | finger_contact_normals: 72 | - [0.0, 1.0, 0.0] # left_inner_finger_pad_collision 73 | - [0.0, -1.0, 0.0] # right_inner_finger_pad_collision 74 | 75 | inertia_kwargs: 76 | mass: 0.83 77 | com: [0, 0, 0.11] 78 | inertia: [0.001, 0.0025, 0.0017, 0, 0, 0] 79 | 80 | pos_gains: [400, 40] 81 | pos_threshold: 0.001 82 | timeout: 1.0 83 | -------------------------------------------------------------------------------- /generative_skill_chaining/__init__.py: -------------------------------------------------------------------------------- 1 | from . import agents 2 | # from . import controllers 3 | # from . import datasets 4 | # from . import dynamics 5 | from . import encoders 6 | from . import envs 7 | from . import networks 8 | # from . import planners 9 | # from . import processors 10 | # from . import schedulers 11 | from . import utils 12 | -------------------------------------------------------------------------------- /generative_skill_chaining/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Agent 2 | from .rl import RLAgent 3 | from .sac import SAC 4 | from .utils import * 5 | from .wrapper import WrapperAgent 6 | -------------------------------------------------------------------------------- /generative_skill_chaining/agents/base.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import gym 4 | import torch 5 | 6 | from generative_skill_chaining import encoders, networks 7 | from generative_skill_chaining.utils import tensors 8 | 9 | 10 | class Agent: 11 | """Base agent class.""" 12 | 13 | def __init__( 14 | self, 15 | state_space: gym.spaces.Box, 16 | action_space: gym.spaces.Box, 17 | observation_space: gym.spaces.Box, 18 | actor: networks.actors.Actor, 19 | critic: networks.critics.Critic, 20 | encoder: encoders.Encoder, 21 | device: str = "auto", 22 | ): 23 | """Assigns the required properties of the Agent. 24 | 25 | Args: 26 | state_space: Policy state space (encoder output, actor/critic input). 27 | action_space: Action space (actor output). 28 | observation_space: Observation space (encoder input). 29 | actor: Actor network. 30 | critic: Critic network. 31 | encoder: Encoder network. 32 | device: Torch device. 33 | """ 34 | assert isinstance(action_space, gym.spaces.Box) 35 | self._state_space = state_space 36 | self._action_space = action_space 37 | self._observation_space = observation_space 38 | self._actor = actor 39 | self._critic = critic 40 | self._encoder = encoder 41 | self.to(device) 42 | 43 | @property 44 | def state_space(self) -> gym.spaces.Box: 45 | """Policy state space (encoder output, actor/critic input).""" 46 | return self._state_space 47 | 48 | @property 49 | def action_space(self) -> gym.spaces.Box: 50 | """Action space (actor output).""" 51 | return self._action_space 52 | 53 | @property 54 | def observation_space(self) -> gym.spaces.Box: 55 | """Observation space (encoder input).""" 56 | return self._observation_space 57 | 58 | @property 59 | def actor(self) -> networks.actors.Actor: 60 | """Actor network that takes as input a state and outputs an action.""" 61 | return self._actor 62 | 63 | @property 64 | def critic(self) -> networks.critics.Critic: 65 | """Critic network that takes as input a state/action and outputs a 66 | success probability.""" 67 | return self._critic 68 | 69 | @property 70 | def encoder(self) -> encoders.Encoder: 71 | """Encoder network that encodes observations into states.""" 72 | return self._encoder 73 | 74 | @property 75 | def device(self) -> torch.device: 76 | """Torch device.""" 77 | return self._device 78 | 79 | def to(self, device: Union[str, torch.device]) -> "Agent": 80 | """Transfers networks to a device.""" 81 | self._device = tensors.device(device) 82 | self.actor.to(self.device) 83 | self.critic.to(self.device) 84 | self.encoder.to(self.device) 85 | return self 86 | 87 | def train_mode(self) -> None: 88 | """Switches the networks to train mode.""" 89 | self.actor.train() 90 | self.critic.train() 91 | self.encoder.train_mode() 92 | 93 | def eval_mode(self) -> None: 94 | """Switches the networks to eval mode.""" 95 | self.actor.eval() 96 | self.critic.eval() 97 | self.encoder.eval_mode() 98 | -------------------------------------------------------------------------------- /generative_skill_chaining/agents/rl.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Dict, Optional, OrderedDict, Union 3 | 4 | import torch 5 | 6 | from generative_skill_chaining import encoders, envs, networks 7 | from generative_skill_chaining.agents.base import Agent 8 | from generative_skill_chaining.utils.typing import Batch, Model 9 | 10 | 11 | class RLAgent(Agent, Model[Batch]): 12 | """RL agent base class.""" 13 | 14 | def __init__( 15 | self, 16 | env: envs.Env, 17 | actor: networks.actors.Actor, 18 | critic: networks.critics.Critic, 19 | encoder: encoders.Encoder, 20 | checkpoint: Optional[Union[str, pathlib.Path]] = None, 21 | device: str = "auto", 22 | ): 23 | """Sets up the agent and loads from checkpoint if available. 24 | 25 | Args: 26 | env: Agent env. 27 | actor: Actor network. 28 | critic: Critic network. 29 | encoder: Encoder network. 30 | checkpoint: Policy checkpoint. 31 | device: Torch device. 32 | """ 33 | super().__init__( 34 | state_space=encoder.state_space, 35 | action_space=env.action_space, 36 | observation_space=env.observation_space, 37 | actor=actor, 38 | critic=critic, 39 | encoder=encoder, 40 | device=device, 41 | ) 42 | 43 | self._env = env 44 | 45 | if checkpoint is not None: 46 | self.load(checkpoint, strict=True) 47 | 48 | @property 49 | def env(self) -> envs.Env: 50 | """Agent environment.""" 51 | return self._env 52 | 53 | def load_state_dict( 54 | self, state_dict: Dict[str, OrderedDict[str, torch.Tensor]], strict: bool = True 55 | ) -> None: 56 | """Loads the agent state dict. 57 | 58 | Args: 59 | state_dict: Torch state dict. 60 | strict: Ensure state_dict keys match networks exactly. 61 | """ 62 | self.critic.load_state_dict(state_dict["critic"], strict=strict) 63 | self.actor.load_state_dict(state_dict["actor"], strict=strict) 64 | self.encoder.network.load_state_dict(state_dict["encoder"], strict=strict) 65 | 66 | def state_dict(self) -> Dict[str, Dict[str, torch.Tensor]]: 67 | """Gets the agent state dicts.""" 68 | return { 69 | "critic": self.critic.state_dict(), 70 | "actor": self.actor.state_dict(), 71 | "encoder": self.encoder.network.state_dict(), 72 | } 73 | -------------------------------------------------------------------------------- /generative_skill_chaining/agents/wrapper.py: -------------------------------------------------------------------------------- 1 | from generative_skill_chaining.agents import base as agents 2 | 3 | 4 | class WrapperAgent(agents.Agent): 5 | """Base wrapper agent class.""" 6 | 7 | pass 8 | -------------------------------------------------------------------------------- /generative_skill_chaining/diff_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/generative_skill_chaining/diff_models/__init__.py -------------------------------------------------------------------------------- /generative_skill_chaining/diff_models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/generative_skill_chaining/diff_models/utils/__init__.py -------------------------------------------------------------------------------- /generative_skill_chaining/diff_models/utils/helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ############################################################################## 4 | # Code derived from "Planning with Diffusion for Flexible Behavior Synthesis" 5 | # Janner et al. (2022) https://diffusion-planning.github.io/ 6 | ############################################################################## 7 | 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import einops 14 | from einops.layers.torch import Rearrange 15 | import pdb 16 | 17 | class SinusoidalPosEmb(nn.Module): 18 | def __init__(self, dim): 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | device = x.device 24 | half_dim = self.dim // 2 25 | emb = math.log(10000) / (half_dim - 1) 26 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 27 | emb = x[:, None] * emb[None, :] 28 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 29 | return emb 30 | 31 | class Downsample1d(nn.Module): 32 | def __init__(self, dim): 33 | super().__init__() 34 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 35 | 36 | def forward(self, x): 37 | return self.conv(x) 38 | 39 | class Upsample1d(nn.Module): 40 | def __init__(self, dim): 41 | super().__init__() 42 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 43 | 44 | def forward(self, x): 45 | return self.conv(x) 46 | 47 | class Conv1dBlock(nn.Module): 48 | ''' 49 | Conv1d --> GroupNorm --> Mish 50 | ''' 51 | 52 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 53 | super().__init__() 54 | 55 | self.block = nn.Sequential( 56 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 57 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 58 | nn.GroupNorm(n_groups, out_channels), 59 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 60 | nn.Mish(), 61 | ) 62 | 63 | def forward(self, x): 64 | return self.block(x) 65 | 66 | class Linear1DBlock(nn.Module): 67 | ''' 68 | Linear --> Norm --> Mish 69 | ''' 70 | 71 | def __init__(self, inp_channels, out_channels): 72 | super().__init__() 73 | 74 | self.block = nn.Sequential( 75 | nn.Linear(inp_channels, out_channels), 76 | nn.LayerNorm(out_channels), 77 | nn.Mish(), 78 | ) 79 | 80 | def forward(self, x): 81 | return self.block(x) 82 | 83 | class UpsampleLinear(nn.Module): 84 | def __init__(self, dim_in, dim_out, embed_dim): 85 | super(UpsampleLinear, self).__init__() 86 | self._layer = nn.Linear(dim_in, dim_out) 87 | self._hyper_bias = nn.Linear(embed_dim, dim_out, bias=False) 88 | self._hyper_gate = nn.Linear(embed_dim, dim_out) 89 | 90 | def forward(self, x, ctx): 91 | gate = torch.sigmoid(self._hyper_gate(ctx)) 92 | bias = self._hyper_bias(ctx) 93 | ret = self._layer(x) * gate + bias 94 | return ret 95 | 96 | class DownsampleLinear(nn.Module): 97 | def __init__(self, dim_in, dim_out, embed_dim): 98 | super(DownsampleLinear, self).__init__() 99 | self._layer = nn.Linear(dim_in, dim_out) 100 | self._hyper_bias = nn.Linear(embed_dim, dim_out, bias=False) 101 | self._hyper_gate = nn.Linear(embed_dim, dim_out) 102 | 103 | def forward(self, x, ctx): 104 | gate = torch.sigmoid(self._hyper_gate(ctx)) 105 | bias = self._hyper_bias(ctx) 106 | ret = self._layer(x) * gate + bias 107 | return ret 108 | 109 | class ConcatSquashLinear(nn.Module): 110 | def __init__(self, dim_in, dim_out, dim_ctx): 111 | super(ConcatSquashLinear, self).__init__() 112 | self._layer = nn.Linear(dim_in, dim_out) 113 | self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False) 114 | self._hyper_gate = nn.Linear(dim_ctx, dim_out) 115 | 116 | def forward(self, ctx, x): 117 | gate = torch.sigmoid(self._hyper_gate(ctx)) 118 | bias = self._hyper_bias(ctx) 119 | # if x.dim() == 3: 120 | # gate = gate.unsqueeze(1) 121 | # bias = bias.unsqueeze(1) 122 | ret = self._layer(x) * gate + bias 123 | return ret -------------------------------------------------------------------------------- /generative_skill_chaining/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Encoder 2 | from .state import StateEncoder 3 | from .utils import EncoderFactory, load, load_config -------------------------------------------------------------------------------- /generative_skill_chaining/encoders/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Type, Union 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | 7 | from generative_skill_chaining import envs, networks 8 | from generative_skill_chaining.utils import configs, tensors 9 | 10 | 11 | class Encoder: 12 | """Base encooder class.""" 13 | 14 | def __init__( 15 | self, 16 | env: envs.Env, 17 | network_class: Union[str, Type[networks.encoders.Encoder]], 18 | network_kwargs: Dict[str, Any] = {}, 19 | device: str = "auto", 20 | ): 21 | """Initializes the dynamics model network, dataset, and optimizer. 22 | 23 | Args: 24 | env: Encoder env. 25 | network_class: Dynamics model network class. 26 | network_kwargs: Kwargs for network class. 27 | device: Torch device. 28 | """ 29 | network_class = configs.get_class(network_class, networks) 30 | self._network = network_class(env, **network_kwargs) 31 | 32 | self._observation_space = env.observation_space 33 | 34 | self.to(device) 35 | 36 | @property 37 | def observation_space(self) -> gym.spaces.Box: 38 | return self._observation_space 39 | 40 | @property 41 | def state_space(self) -> gym.spaces.Box: 42 | return self.network.state_space 43 | 44 | @property 45 | def network(self) -> networks.encoders.Encoder: 46 | return self._network 47 | 48 | @property 49 | def device(self) -> torch.device: 50 | """Torch device.""" 51 | return self._device 52 | 53 | def to(self, device: Union[str, torch.device]) -> "Encoder": 54 | """Transfers networks to device.""" 55 | self._device = torch.device(tensors.device(device)) 56 | self.network.to(self.device) 57 | return self 58 | 59 | def train_mode(self) -> None: 60 | """Switches the networks to train mode.""" 61 | self.network.train() 62 | 63 | def eval_mode(self) -> None: 64 | """Switches the networks to eval mode.""" 65 | self.network.eval() 66 | 67 | def encode( 68 | self, 69 | observation: torch.Tensor, 70 | policy_args: Union[np.ndarray, Optional[Any]], 71 | **kwargs 72 | ) -> torch.Tensor: 73 | return self.network.predict(observation, policy_args, **kwargs) 74 | 75 | def decode( 76 | self, 77 | latent: torch.Tensor, 78 | policy_args: Union[np.ndarray, Optional[Any]], 79 | **kwargs 80 | ) -> torch.Tensor: 81 | return self.network.reverse(latent, policy_args, **kwargs) 82 | 83 | def unnormalize(self, observation: torch.Tensor) -> torch.Tensor: 84 | return self.network.unnormalize(observation) 85 | -------------------------------------------------------------------------------- /generative_skill_chaining/encoders/state.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Any, Dict, Optional, Tuple, Type, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from generative_skill_chaining import envs, networks 8 | from generative_skill_chaining.encoders.base import Encoder 9 | from generative_skill_chaining.utils.typing import StateEncoderBatch, Model, Scalar 10 | 11 | 12 | class StateEncoder(Encoder, Model[StateEncoderBatch]): 13 | """Vanilla autoencoder.""" 14 | 15 | def __init__( 16 | self, 17 | env: envs.Env, 18 | encoder_class: Union[str, Type[networks.encoders.Encoder]], 19 | encoder_kwargs: Dict[str, Any], 20 | checkpoint: Optional[Union[str, pathlib.Path]] = None, 21 | device: str = "auto", 22 | ): 23 | """Initializes the autoencoder network. 24 | 25 | Args: 26 | env: Encoder env. 27 | encoder_class: Encoder network class. 28 | encoder_kwargs: Kwargs for encoder network class. 29 | decoder_class: decoder network class. 30 | decoder_kwargs: Kwargs for decoder network class. 31 | checkpoint: Autoencoder checkpoint. 32 | device: Torch device. 33 | """ 34 | super().__init__( 35 | env=env, 36 | network_class=encoder_class, 37 | network_kwargs=encoder_kwargs, 38 | device=device, 39 | ) 40 | 41 | if checkpoint is not None: 42 | self.load(checkpoint, strict=True) 43 | 44 | def compute_loss( 45 | self, 46 | observation: torch.Tensor, 47 | state: torch.Tensor, 48 | policy_args: np.ndarray, 49 | ) -> Tuple[torch.Tensor, Dict[str, Union[Scalar, np.ndarray]]]: 50 | state_prediction = self.network.predict(observation, policy_args) 51 | loss = torch.nn.functional.mse_loss(state_prediction, state) 52 | metrics: Dict[str, Union[Scalar, np.ndarray]] = { 53 | "loss": loss.item(), 54 | } 55 | 56 | return loss, metrics 57 | 58 | def train_step( 59 | self, 60 | step: int, 61 | batch: StateEncoderBatch, 62 | optimizers: Dict[str, torch.optim.Optimizer], 63 | schedulers: Dict[str, torch.optim.lr_scheduler._LRScheduler], 64 | ) -> Dict[str, Union[Scalar, np.ndarray]]: 65 | """Performs a single training step. 66 | 67 | Args: 68 | step: Training step. 69 | batch: Training batch. 70 | optimizers: Optimizers created in `LatentDynamics.create_optimizers()`. 71 | schedulers: Schedulers with the same keys as `optimizers`. 72 | 73 | Returns: 74 | Dict of training metrics for logging. 75 | """ 76 | assert isinstance(batch["observation"], torch.Tensor) 77 | loss, metrics = self.compute_loss(**batch) # type: ignore 78 | 79 | optimizers["encoder"].zero_grad() 80 | loss.backward() 81 | optimizers["encoder"].step() 82 | 83 | return metrics 84 | 85 | def train_mode(self) -> None: 86 | """Switches to training mode.""" 87 | self.network.train() 88 | 89 | def eval_mode(self) -> None: 90 | """Switches to eval mode.""" 91 | self.network.eval() 92 | -------------------------------------------------------------------------------- /generative_skill_chaining/encoders/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | import pathlib 3 | 4 | from generative_skill_chaining import encoders, envs 5 | from generative_skill_chaining.utils import configs 6 | 7 | 8 | class EncoderFactory(configs.Factory): 9 | """Encoder factory.""" 10 | 11 | def __init__( 12 | self, 13 | config: Optional[Union[str, pathlib.Path, Dict[str, Any]]] = None, 14 | checkpoint: Optional[Union[str, pathlib.Path]] = None, 15 | env: Optional[envs.Env] = None, 16 | device: str = "auto", 17 | ): 18 | """Creates the dynamics model factory from a config or checkpoint. 19 | 20 | Args: 21 | config: Optional dynamics config path or dict. Must be provided if 22 | checkpoint is None. 23 | checkpoint: Optional dynamics checkpoint path. Must be provided if 24 | config is None. 25 | env: Encoder env. 26 | device: Torch device. 27 | """ 28 | if checkpoint is not None: 29 | ckpt_config = load_config(checkpoint) 30 | if config is None: 31 | config = ckpt_config 32 | if env is None: 33 | ckpt_env_config = envs.load_config(checkpoint) 34 | env = envs.EnvFactory(ckpt_env_config)() 35 | 36 | if config is None: 37 | raise ValueError("Either config or checkpoint must be specified") 38 | if env is None: 39 | raise ValueError("Either env or checkpoint must be specified") 40 | 41 | super().__init__(config, "encoder", encoders) 42 | 43 | if checkpoint is not None and self.config["encoder"] != ckpt_config["encoder"]: 44 | raise ValueError( 45 | f"Config encoder [{self.config['encoder']}] and checkpoint" 46 | f"encoder [{ckpt_config['encoder']}] must be the same" 47 | ) 48 | 49 | self.kwargs["env"] = env 50 | self.kwargs["device"] = device 51 | 52 | 53 | def load( 54 | config: Optional[Union[str, pathlib.Path, Dict[str, Any]]] = None, 55 | checkpoint: Optional[Union[str, pathlib.Path]] = None, 56 | env: Optional[envs.Env] = None, 57 | device: str = "auto", 58 | **kwargs, 59 | ) -> encoders.Encoder: 60 | """Loads the encoder from a config or checkpoint. 61 | 62 | Args: 63 | config: Optional encoder config path or dict. Must be provided if 64 | checkpoint is None. 65 | checkpoint: Optional encoder checkpoint path. Must be provided if 66 | config is None. 67 | env: Encoder env. 68 | device: Torch device. 69 | kwargs: Optional encoder constructor kwargs. 70 | 71 | Returns: 72 | Encoder instance. 73 | """ 74 | encoder_factory = EncoderFactory( 75 | config=config, 76 | checkpoint=checkpoint, 77 | env=env, 78 | device=device, 79 | ) 80 | return encoder_factory(**kwargs) 81 | 82 | 83 | def load_config(path: Union[str, pathlib.Path]) -> Dict[str, Any]: 84 | """Loads a encoder config from path. 85 | 86 | Args: 87 | path: Path to the config, config directory, or checkpoint. 88 | 89 | Returns: 90 | Encoder config dict. 91 | """ 92 | return configs.load_config(path, "encoder") 93 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Env, Primitive 2 | from .empty import EmptyEnv 3 | from .utils import * 4 | from .variant import VariantEnv 5 | from . import pybullet 6 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/empty.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from generative_skill_chaining.envs import base as envs 7 | 8 | 9 | def _get_space(low=None, high=None, shape=None, dtype=None): 10 | 11 | all_vars = [low, high, shape, dtype] 12 | if any([isinstance(v, dict) for v in all_vars]): 13 | all_keys = set() # get all the keys 14 | for v in all_vars: 15 | if isinstance(v, dict): 16 | all_keys.update(v.keys()) 17 | # Construct all the sets 18 | spaces = {} 19 | for k in all_keys: 20 | ll = low.get(k, None) if isinstance(low, dict) else low 21 | h = high.get(k, None) if isinstance(high, dict) else high 22 | s = shape.get(k, None) if isinstance(shape, dict) else shape 23 | d = dtype.get(k, None) if isinstance(dtype, dict) else dtype 24 | spaces[k] = _get_space(ll, h, s, d) 25 | # Construct the gym dict space 26 | return gym.spaces.Dict(**spaces) 27 | 28 | if shape is None and isinstance(high, int): 29 | assert low is None, "Tried to specify a discrete space with both high and low." 30 | return gym.spaces.Discrete(high) 31 | 32 | # Otherwise assume its a box. 33 | if low is None: 34 | low = -np.inf 35 | if high is None: 36 | high = np.inf 37 | if dtype is None: 38 | dtype = np.float32 39 | return gym.spaces.Box(low=low, high=high, shape=shape, dtype=dtype) 40 | 41 | 42 | class EmptyEnv(envs.Env): 43 | 44 | """ 45 | An empty holder for defining supervised learning problems 46 | It works by specifying the ranges and shapes. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | observation_low=None, 52 | observation_high=None, 53 | observation_shape=None, 54 | observation_dtype=np.float32, 55 | action_low=None, 56 | action_high=None, 57 | action_shape=None, 58 | action_dtype=np.float32, 59 | ): 60 | self.observation_space = _get_space( 61 | observation_low, observation_high, observation_shape, observation_dtype 62 | ) 63 | self._action_space = _get_space( 64 | action_low, action_high, action_shape, action_dtype 65 | ) 66 | 67 | @property 68 | def action_space(self) -> gym.spaces.Box: # type: ignore 69 | return self._action_space 70 | 71 | @property 72 | def action_scale(self) -> gym.spaces.Box: 73 | return self._action_space 74 | 75 | def get_primitive(self) -> envs.Primitive: 76 | raise NotImplementedError("Empty Env does not have primitives") 77 | 78 | def set_primitive( 79 | self, 80 | primitive: Optional[envs.Primitive] = None, 81 | action_call: Optional[str] = None, 82 | idx_policy: Optional[int] = None, 83 | policy_args: Optional[Any] = None, 84 | ) -> envs.Env: 85 | raise NotImplementedError("Empty Env does not have primitives") 86 | 87 | def get_primitive_info( 88 | self, 89 | action_call: Optional[str] = None, 90 | idx_policy: Optional[int] = None, 91 | policy_args: Optional[Any] = None, 92 | ) -> envs.Primitive: 93 | """Gets the primitive info.""" 94 | raise NotImplementedError("Empty Env does not have primitives") 95 | 96 | def get_state(self) -> np.ndarray: 97 | """Gets the environment state.""" 98 | raise NotImplementedError("Empty Env does not have states") 99 | 100 | def set_state(self, state: np.ndarray) -> bool: 101 | """Sets the environment state.""" 102 | raise NotImplementedError("Empty Env does not have states") 103 | 104 | def get_observation(self, image: Optional[bool] = None) -> np.ndarray: 105 | """Gets an observation for the current environment state.""" 106 | raise NotImplementedError("Empty Env does not have observations") 107 | 108 | def step(self, action): 109 | raise NotImplementedError("Empty Env does not have step") 110 | 111 | def reset(self, **kwargs): 112 | raise NotImplementedError("Empty Env does not have reset") 113 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/__init__.py: -------------------------------------------------------------------------------- 1 | from .table_env import TableEnv, VariantTableEnv 2 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from generative_skill_chaining.envs.base import Env 5 | from generative_skill_chaining.envs.pybullet.utils import RedirectStream 6 | 7 | with RedirectStream(sys.stderr): 8 | import pybullet as p 9 | 10 | 11 | def connect_pybullet(gui: bool = True, options: str = "") -> int: 12 | if not gui: 13 | with RedirectStream(): 14 | physics_id = p.connect(p.DIRECT, options=options) 15 | elif not os.environ["DISPLAY"]: 16 | raise p.error 17 | else: 18 | with RedirectStream(): 19 | physics_id = p.connect(p.GUI, options=options) 20 | 21 | p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0, physicsClientId=physics_id) 22 | p.configureDebugVisualizer( 23 | p.COV_ENABLE_SEGMENTATION_MARK_PREVIEW, 24 | 0, 25 | physicsClientId=physics_id, 26 | ) 27 | p.configureDebugVisualizer( 28 | p.COV_ENABLE_DEPTH_BUFFER_PREVIEW, 0, physicsClientId=physics_id 29 | ) 30 | p.configureDebugVisualizer( 31 | p.COV_ENABLE_RGB_BUFFER_PREVIEW, 0, physicsClientId=physics_id 32 | ) 33 | p.configureDebugVisualizer(p.COV_ENABLE_SHADOWS, 0, physicsClientId=physics_id) 34 | p.resetDebugVisualizerCamera( 35 | cameraDistance=0.25, 36 | cameraYaw=90, 37 | cameraPitch=-48, 38 | cameraTargetPosition=[0.76, 0.07, 0.37], 39 | physicsClientId=physics_id, 40 | ) 41 | 42 | return physics_id 43 | 44 | 45 | class PybulletEnv(Env): 46 | def __init__(self, name: str, gui: bool = True): 47 | self.name = name 48 | options = ( 49 | "--background_color_red=0.12 " 50 | "--background_color_green=0.12 " 51 | "--background_color_blue=0.25" 52 | ) 53 | try: 54 | self._physics_id = connect_pybullet(gui=gui, options=options) 55 | except p.error as e: 56 | print(e) 57 | print("Unable to connect to pybullet with gui. Connecting without gui...") 58 | self._physics_id = connect_pybullet(gui=False, options=options) 59 | 60 | p.setGravity(0, 0, -9.8, physicsClientId=self.physics_id) 61 | 62 | @property 63 | def physics_id(self) -> int: 64 | return self._physics_id 65 | 66 | def close(self) -> None: 67 | with RedirectStream(): 68 | try: 69 | p.disconnect(physicsClientId=self.physics_id) 70 | except (AttributeError, p.error): 71 | pass 72 | 73 | def __del__(self) -> None: 74 | self.close() 75 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/real/__init__.py: -------------------------------------------------------------------------------- 1 | from . import arm, gripper 2 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/real/object_tracker.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Dict, Iterable, List, Optional, Sequence, Union 3 | 4 | import ctrlutils 5 | from ctrlutils import eigen 6 | import numpy as np 7 | 8 | from generative_skill_chaining.envs.pybullet.real import redisgl 9 | from generative_skill_chaining.envs.pybullet.sim import math, shapes 10 | from generative_skill_chaining.envs.pybullet.table.objects import Object, Variant 11 | 12 | 13 | def create_pose(shape: shapes.Shape) -> redisgl.Pose: 14 | if shape.pose is None: 15 | return redisgl.Pose() 16 | elif isinstance(shape, shapes.Cylinder): 17 | quat_pybullet_to_redisgl = eigen.Quaterniond( 18 | eigen.AngleAxisd(np.pi / 2, np.array([1.0, 0.0, 0.0])) 19 | ) 20 | quat = eigen.Quaterniond(shape.pose.quat) * quat_pybullet_to_redisgl 21 | return redisgl.Pose(shape.pose.pos, quat.coeffs) 22 | else: 23 | return redisgl.Pose(shape.pose.pos, shape.pose.quat) 24 | 25 | 26 | def create_geometry(shape: shapes.Shape) -> redisgl.Geometry: 27 | if isinstance(shape, shapes.Box): 28 | return redisgl.Box(scale=shape.size) 29 | elif isinstance(shape, shapes.Cylinder): 30 | return redisgl.Cylinder(radius=shape.radius, length=shape.length) 31 | elif isinstance(shape, shapes.Sphere): 32 | return redisgl.Sphere(radius=shape.radius) 33 | 34 | raise NotImplementedError(f"Shape type {shape} is not supported.") 35 | 36 | 37 | def create_graphics(object: Object) -> Sequence[redisgl.Graphics]: 38 | if isinstance(object, Variant): 39 | return [] 40 | 41 | return [ 42 | redisgl.Graphics( 43 | name=object.name, 44 | geometry=create_geometry(shape), 45 | T_to_parent=create_pose(shape), 46 | ) 47 | for shape in object.shapes 48 | ] 49 | 50 | 51 | def create_object_model(object: Object, key_namespace: str) -> redisgl.ObjectModel: 52 | return redisgl.ObjectModel( 53 | name=object.name, 54 | graphics=create_graphics(object), 55 | key_pos=f"{key_namespace}::objects::{object.name}::pos", 56 | key_ori=f"{key_namespace}::objects::{object.name}::ori", 57 | ) 58 | 59 | 60 | class ObjectTracker: 61 | def __init__( 62 | self, 63 | objects: Dict[str, Object], 64 | redis_host: str, 65 | redis_port: int, 66 | redis_password: str, 67 | key_namespace: str, 68 | object_key_prefix: str, 69 | assets_path: Union[str, pathlib.Path], 70 | ): 71 | self._redis = ctrlutils.RedisClient(redis_host, redis_port, redis_password) 72 | self._redis_pipe = self._redis.pipeline() 73 | self._object_key_prefix = object_key_prefix 74 | 75 | self._assets_path = str(pathlib.Path(assets_path).absolute()) 76 | redisgl.register_resource_path(self._redis_pipe, self._assets_path) 77 | self._model_keys = redisgl.ModelKeys(key_namespace) 78 | redisgl.register_model_keys(self._redis_pipe, self._model_keys) 79 | 80 | self._redis_pipe.execute() 81 | self._tracked_objects = [] # self.get_tracked_objects(objects.values()) 82 | for object in objects.values(): 83 | try: 84 | redisgl.register_object( 85 | self._redis_pipe, 86 | self._model_keys, 87 | object=create_object_model(object, key_namespace), 88 | ) 89 | except NotImplementedError: 90 | continue 91 | self._tracked_objects.append(object) 92 | self._redis_pipe.execute() 93 | 94 | def __del__(self) -> None: 95 | redisgl.unregister_resource_path(self._redis_pipe, self._assets_path) 96 | redisgl.unregister_model_keys(self._redis_pipe, self._model_keys) 97 | for object in self._tracked_objects: 98 | redisgl.unregister_object(self._redis_pipe, self._model_keys, object.name) 99 | self._redis_pipe.execute() 100 | 101 | def get_tracked_objects(self, objects: Iterable[Object]) -> List[Object]: 102 | for object in objects: 103 | self._redis_pipe.get(self._object_key_prefix + object.name + "::pos") 104 | object_models = self._redis_pipe.execute() 105 | 106 | return [ 107 | object 108 | for object, object_model in zip(objects, object_models) 109 | if object_model is not None 110 | ] 111 | 112 | def update_poses( 113 | self, 114 | objects: Optional[Iterable[Object]] = None, 115 | exclude: Optional[Sequence[Object]] = None, 116 | ) -> List[Object]: 117 | if objects is None: 118 | objects = self._tracked_objects 119 | 120 | # Query all object poses. 121 | for object in objects: 122 | self._redis_pipe.get(self._object_key_prefix + object.name + "::pos") 123 | self._redis_pipe.get(self._object_key_prefix + object.name + "::ori") 124 | b_object_poses = self._redis_pipe.execute() 125 | 126 | # Set returned poses. 127 | updated_objects = [] 128 | for i, object in enumerate(objects): 129 | if exclude is not None and object in exclude: 130 | continue 131 | b_object_pos = b_object_poses[2 * i] 132 | b_object_quat = b_object_poses[2 * i + 1] 133 | if b_object_pos is None or b_object_quat is None: 134 | continue 135 | 136 | object_pos = ctrlutils.redis.decode_matlab(b_object_pos) 137 | object_quat = ctrlutils.redis.decode_matlab(b_object_quat) 138 | 139 | object.set_pose(math.Pose(object_pos, object_quat)) 140 | updated_objects.append(object) 141 | 142 | return updated_objects 143 | 144 | def send_poses(self, objects: Optional[Iterable[Object]] = None) -> None: 145 | if objects is None: 146 | objects = self._tracked_objects 147 | 148 | for object in objects: 149 | pose = object.pose() 150 | self._redis_pipe.set_matrix( 151 | self._object_key_prefix + object.name + "::pos", pose.pos 152 | ) 153 | self._redis_pipe.set_matrix( 154 | self._object_key_prefix + object.name + "::ori", pose.quat 155 | ) 156 | self._redis_pipe.execute() 157 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/sim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/generative_skill_chaining/envs/pybullet/sim/__init__.py -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/sim/body.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from ctrlutils import eigen 4 | import spatialdyn as dyn 5 | import numpy as np 6 | import pybullet as p 7 | 8 | from generative_skill_chaining.envs.pybullet.sim import math 9 | 10 | 11 | @dataclasses.dataclass 12 | class Body: 13 | physics_id: int 14 | body_id: int 15 | 16 | def aabb(self) -> np.ndarray: 17 | """Body aabb. 18 | 19 | Note: The aabb given by Pybullet is larger than the true aabb for 20 | collision detection purposes. 21 | 22 | Also, the aabb is only reported for the object base. 23 | """ 24 | return np.array(p.getAABB(self.body_id, physicsClientId=self.physics_id)) 25 | 26 | def pose(self) -> math.Pose: 27 | """Base pose.""" 28 | pos, quat = p.getBasePositionAndOrientation( 29 | self.body_id, physicsClientId=self.physics_id 30 | ) 31 | return math.Pose(np.array(pos), np.array(quat)) 32 | 33 | def set_pose(self, pose: math.Pose) -> None: 34 | """Sets the base pose.""" 35 | p.resetBasePositionAndOrientation( 36 | self.body_id, pose.pos, pose.quat, physicsClientId=self.physics_id 37 | ) 38 | 39 | def twist(self) -> np.ndarray: 40 | """Base twist.""" 41 | v, w = p.getBaseVelocity(self.body_id, physicsClientId=self.physics_id) 42 | return np.concatenate([v, w]) 43 | 44 | @property 45 | def dof(self) -> int: 46 | """Total number of joints in the articulated body, including non-controlled joints.""" 47 | try: 48 | return self._dof # type: ignore 49 | except AttributeError: 50 | pass 51 | 52 | self._dof = p.getNumJoints(self.body_id, physicsClientId=self.physics_id) 53 | return self._dof 54 | 55 | @property 56 | def inertia(self) -> dyn.SpatialInertiad: 57 | """Base inertia.""" 58 | try: 59 | return self._inertia # type: ignore 60 | except AttributeError: 61 | pass 62 | 63 | was_frozen = self.unfreeze() 64 | dynamics_info = p.getDynamicsInfo( 65 | self.body_id, -1, physicsClientId=self.physics_id 66 | ) 67 | if was_frozen: 68 | self.freeze() 69 | 70 | mass = dynamics_info[0] 71 | inertia_xyz = dynamics_info[2] 72 | com = np.array(dynamics_info[3]) 73 | quat_inertia = eigen.Quaterniond(dynamics_info[4]) 74 | T_inertia = eigen.Translation3d.identity() * quat_inertia 75 | self._inertia = ( 76 | dyn.SpatialInertiad(mass, com, np.concatenate([inertia_xyz, np.zeros(3)])) 77 | * T_inertia 78 | ) 79 | 80 | return self._inertia 81 | 82 | def freeze(self) -> bool: 83 | """Disable simulation for this body. 84 | 85 | Returns: 86 | Whether the object's frozen status changed. 87 | """ 88 | if not hasattr(self, "_is_frozen"): 89 | self._mass = p.getDynamicsInfo( 90 | self.body_id, -1, physicsClientId=self.physics_id 91 | )[0] 92 | elif self._is_frozen: # type: ignore 93 | return False 94 | 95 | p.changeDynamics(self.body_id, -1, mass=0, physicsClientId=self.physics_id) 96 | self._is_frozen = True 97 | return True 98 | 99 | def unfreeze(self) -> bool: 100 | """Enable simulation for this body. 101 | 102 | Returns: 103 | Whether the object's frozen status changed. 104 | """ 105 | if not hasattr(self, "_is_frozen") or not self._is_frozen: 106 | return False 107 | 108 | p.changeDynamics( 109 | self.body_id, -1, mass=self._mass, physicsClientId=self.physics_id 110 | ) 111 | self._is_frozen = False 112 | return True 113 | 114 | 115 | @dataclasses.dataclass 116 | class Link: 117 | physics_id: int 118 | body_id: int 119 | link_id: int 120 | 121 | @property 122 | def name(self) -> str: 123 | """Link name.""" 124 | try: 125 | return self._name # type: ignore 126 | except AttributeError: 127 | pass 128 | 129 | self._name = p.getJointInfo( 130 | self.body_id, self.link_id, physicsClientId=self.physics_id 131 | )[12].decode("utf8") 132 | return self._name 133 | 134 | def pose(self) -> math.Pose: 135 | """World pose of the center of mass.""" 136 | pos, quat = p.getLinkState( 137 | self.body_id, self.link_id, physicsClientId=self.physics_id 138 | )[:2] 139 | return math.Pose(np.array(pos), np.array(quat)) 140 | 141 | @property 142 | def inertia(self) -> str: 143 | """Inertia at the center of mass frame.""" 144 | try: 145 | return self._inertia # type: ignore 146 | except AttributeError: 147 | pass 148 | 149 | dynamics_info = p.getDynamicsInfo( 150 | self.body_id, self.link_id, physicsClientId=self.physics_id 151 | ) 152 | mass = dynamics_info[0] 153 | inertia_xyz = dynamics_info[2] 154 | com = np.zeros(3) 155 | self._inertia = dyn.SpatialInertiad( 156 | mass, com, np.concatenate([inertia_xyz, np.zeros(3)]) 157 | ) 158 | 159 | return self._inertia 160 | 161 | @property 162 | def joint_limits(self) -> np.ndarray: 163 | """(lower, upper) joint limits.""" 164 | try: 165 | return self._joint_limits # type: ignore 166 | except AttributeError: 167 | pass 168 | 169 | joint_info = p.getJointInfo( 170 | self.body_id, self.link_id, physicsClientId=self.physics_id 171 | ) 172 | self._joint_limits = np.array(joint_info[8:10]) 173 | 174 | return self._joint_limits 175 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/sim/math.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import dataclasses 4 | 5 | from ctrlutils import eigen 6 | import numpy as np 7 | 8 | 9 | PYBULLET_STEPS_PER_SEC = 240 10 | PYBULLET_TIMESTEP = 1 / PYBULLET_STEPS_PER_SEC 11 | 12 | 13 | @dataclasses.dataclass 14 | class Pose: 15 | """6d pose. 16 | 17 | Args: 18 | pos: 3d position. 19 | quat: xyzw quaternion. 20 | """ 21 | 22 | pos: np.ndarray = np.zeros(3) 23 | quat: np.ndarray = np.array([0.0, 0.0, 0.0, 1.0]) 24 | 25 | @staticmethod 26 | def from_eigen(pose: eigen.Isometry3d) -> "Pose": 27 | """Creates a pose from an Eigen Isometry3d.""" 28 | pos = np.array(pose.translation) 29 | quat = np.array(eigen.Quaterniond(pose.linear).coeffs) 30 | return Pose(pos, quat) 31 | 32 | def to_eigen(self) -> eigen.Isometry3d: 33 | """Converts a pose to an Eigen Isometry3d.""" 34 | return eigen.Translation3d(self.pos) * eigen.Quaterniond(self.quat) 35 | 36 | 37 | def comb(n: int, r: int) -> int: 38 | """Computes (n choose r).""" 39 | try: 40 | return math.comb(n, r) 41 | except AttributeError: 42 | return math.factorial(n) // (math.factorial(r) * math.factorial(n - 4)) 43 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/sim/shapes.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import enum 3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 4 | 5 | import numpy as np 6 | import pybullet as p 7 | 8 | from generative_skill_chaining.envs.pybullet.sim.math import Pose 9 | 10 | 11 | def create_body( 12 | shapes: Union["Shape", Sequence["Shape"]], 13 | link_parents: Optional[Sequence[int]] = None, 14 | physics_id: int = 0, 15 | ) -> int: 16 | if isinstance(shapes, Shape): 17 | base_shape = shapes 18 | link_shapes: Sequence[Shape] = [] 19 | else: 20 | base_shape = shapes[0] 21 | link_shapes = shapes[1:] 22 | 23 | base_collision_id, base_visual_id = base_shape.create_visual( 24 | physics_id, is_base=True 25 | ) 26 | kwargs: Dict[str, Any] = { 27 | "baseMass": base_shape.mass, 28 | "baseCollisionShapeIndex": base_collision_id, 29 | "baseVisualShapeIndex": base_visual_id, 30 | } 31 | if base_shape.pose is not None: 32 | kwargs["baseInertialFramePosition"] = base_shape.pose.pos 33 | kwargs["baseInertialFrameOrientation"] = base_shape.pose.quat 34 | 35 | if len(link_shapes) > 0: 36 | masses, poses, joints, collision_ids, visual_ids = zip( 37 | *[ 38 | (shape.mass, shape.pose, shape.joint, *shape.create_visual(physics_id)) 39 | for shape in link_shapes 40 | ] 41 | ) 42 | 43 | kwargs["linkMasses"] = masses 44 | kwargs["linkCollisionShapeIndices"] = collision_ids 45 | kwargs["linkVisualShapeIndices"] = visual_ids 46 | 47 | link_poses = [Pose() if pose is None else pose for pose in poses] 48 | link_inertia_poses = [Pose()] * len(poses) 49 | kwargs["linkPositions"] = [pose.pos for pose in link_poses] 50 | kwargs["linkOrientations"] = [pose.quat for pose in link_poses] 51 | kwargs["linkInertialFramePositions"] = [pose.pos for pose in link_inertia_poses] 52 | kwargs["linkInertialFrameOrientations"] = [ 53 | pose.quat for pose in link_inertia_poses 54 | ] 55 | 56 | link_joints = [Joint() if joint is None else joint for joint in joints] 57 | if link_parents is None: 58 | link_parents = list(range(len(joints))) 59 | kwargs["linkParentIndices"] = link_parents 60 | kwargs["linkJointTypes"] = [int(joint.joint_type) for joint in link_joints] 61 | kwargs["linkJointAxis"] = [joint.axis for joint in link_joints] 62 | 63 | body_id = p.createMultiBody( 64 | physicsClientId=physics_id, 65 | **kwargs, 66 | ) 67 | 68 | return body_id 69 | 70 | 71 | class JointType(enum.IntEnum): 72 | REVOLUTE = p.JOINT_REVOLUTE 73 | PRISMATIC = p.JOINT_PRISMATIC 74 | SPHERICAL = p.JOINT_SPHERICAL 75 | FIXED = p.JOINT_FIXED 76 | 77 | 78 | @dataclasses.dataclass 79 | class Joint: 80 | joint_type: JointType = JointType.FIXED 81 | axis: np.ndarray = np.array([0.0, 0.0, 1.0]) 82 | 83 | 84 | @dataclasses.dataclass 85 | class Shape: 86 | mass: float = 0.0 87 | color: Optional[np.ndarray] = None 88 | pose: Optional[Pose] = None 89 | joint: Optional[Joint] = None 90 | 91 | def visual_kwargs( 92 | self, is_base: bool = False 93 | ) -> Tuple[Dict[str, Any], Dict[str, Any]]: 94 | collision_kwargs = {} 95 | visual_kwargs = {} 96 | if self.color is not None: 97 | visual_kwargs["rgbaColor"] = self.color 98 | 99 | if is_base and self.pose is not None: 100 | collision_kwargs["collisionFramePosition"] = self.pose.pos 101 | collision_kwargs["collisionFrameOrientation"] = self.pose.quat 102 | visual_kwargs["visualFramePosition"] = self.pose.pos 103 | visual_kwargs["visualFrameOrientation"] = self.pose.quat 104 | 105 | return collision_kwargs, visual_kwargs 106 | 107 | def create_visual(self, physics_id: int, is_base: bool = False) -> Tuple[int, int]: 108 | collision_kwargs, visual_kwargs = self.visual_kwargs(is_base) 109 | 110 | collision_id = p.createCollisionShape( 111 | physicsClientId=physics_id, **collision_kwargs 112 | ) 113 | visual_id = p.createVisualShape(physicsClientId=physics_id, **visual_kwargs) 114 | 115 | return collision_id, visual_id 116 | 117 | 118 | @dataclasses.dataclass 119 | class Box(Shape): 120 | size: np.ndarray = np.array([0.1, 0.1, 0.1]) 121 | 122 | def visual_kwargs( 123 | self, is_base: bool = False 124 | ) -> Tuple[Dict[str, Any], Dict[str, Any]]: 125 | collision_kwargs, visual_kwargs = super().visual_kwargs(is_base) 126 | 127 | collision_kwargs["shapeType"] = p.GEOM_BOX 128 | collision_kwargs["halfExtents"] = self.size / 2 129 | 130 | visual_kwargs["shapeType"] = p.GEOM_BOX 131 | visual_kwargs["halfExtents"] = self.size / 2 132 | 133 | return collision_kwargs, visual_kwargs 134 | 135 | 136 | @dataclasses.dataclass 137 | class Cylinder(Shape): 138 | radius: float = 0.05 139 | length: float = 0.1 140 | 141 | def visual_kwargs( 142 | self, is_base: bool = False 143 | ) -> Tuple[Dict[str, Any], Dict[str, Any]]: 144 | collision_kwargs, visual_kwargs = super().visual_kwargs(is_base) 145 | 146 | collision_kwargs["shapeType"] = p.GEOM_CYLINDER 147 | collision_kwargs["radius"] = self.radius 148 | collision_kwargs["height"] = self.length 149 | 150 | visual_kwargs["shapeType"] = p.GEOM_CYLINDER 151 | visual_kwargs["radius"] = self.radius 152 | visual_kwargs["length"] = self.length 153 | 154 | return collision_kwargs, visual_kwargs 155 | 156 | 157 | @dataclasses.dataclass 158 | class Sphere(Shape): 159 | radius: float = 0.05 160 | 161 | def visual_kwargs( 162 | self, is_base: bool = False 163 | ) -> Tuple[Dict[str, Any], Dict[str, Any]]: 164 | collision_kwargs, visual_kwargs = super().visual_kwargs(is_base) 165 | 166 | collision_kwargs["shapeType"] = p.GEOM_SPHERE 167 | collision_kwargs["radius"] = self.radius 168 | 169 | visual_kwargs["shapeType"] = p.GEOM_SPHERE 170 | visual_kwargs["radius"] = self.radius 171 | 172 | return collision_kwargs, visual_kwargs 173 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/table/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/generative_skill_chaining/envs/pybullet/table/__init__.py -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/table/object_state.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ctrlutils import eigen 4 | import numpy as np 5 | 6 | from generative_skill_chaining.envs.pybullet.sim import math 7 | 8 | 9 | class ObjectState: 10 | RANGES = { 11 | "x": (-0.3, 0.9), 12 | "y": (-0.5, 0.5), 13 | "z": (-0.1, 0.5), 14 | "wx": (-np.pi, np.pi), 15 | "wy": (-np.pi, np.pi), 16 | "wz": (-np.pi, np.pi), 17 | "box_size_x": (0.0, 0.1), 18 | "box_size_y": (0.0, 0.1), 19 | "box_size_z": (0.0, 0.2), 20 | "head_length": (0.0, 0.3), 21 | "handle_length": (0.0, 0.5), 22 | "handle_y": (-1.0, 1.0), 23 | } 24 | 25 | def __init__(self, vector: Optional[np.ndarray] = None): 26 | if vector is None: 27 | vector = np.zeros(len(self.RANGES), dtype=np.float32) 28 | elif vector.shape[-1] != len(self.RANGES): 29 | vector = vector.reshape( 30 | ( 31 | *vector.shape[:-1], 32 | vector.shape[-1] // len(self.RANGES), 33 | len(self.RANGES), 34 | ) 35 | ) 36 | self.vector = vector 37 | 38 | @property 39 | def pos(self) -> np.ndarray: 40 | return self.vector[..., :3] 41 | 42 | @pos.setter 43 | def pos(self, pos: np.ndarray) -> None: 44 | self.vector[..., :3] = pos 45 | 46 | @property 47 | def aa(self) -> np.ndarray: 48 | return self.vector[..., 3:6] 49 | 50 | @aa.setter 51 | def aa(self, aa: np.ndarray) -> None: 52 | self.vector[..., 3:6] = aa 53 | 54 | @property 55 | def box_size(self) -> np.ndarray: 56 | return self.vector[..., 6:9] 57 | 58 | @box_size.setter 59 | def box_size(self, box_size: np.ndarray) -> None: 60 | self.vector[..., 6:9] = box_size 61 | 62 | @property 63 | def head_length(self) -> Union[float, np.ndarray]: 64 | if self.vector.ndim > 1: 65 | return self.vector[..., 9:10] 66 | return self.vector[9] 67 | 68 | @head_length.setter 69 | def head_length(self, head_length: Union[float, np.ndarray]) -> None: 70 | self.vector[..., 9:10] = head_length 71 | 72 | @property 73 | def handle_length(self) -> Union[float, np.ndarray]: 74 | if self.vector.ndim > 1: 75 | return self.vector[..., 10:11] 76 | return self.vector[10] 77 | 78 | @handle_length.setter 79 | def handle_length(self, handle_length: Union[float, np.ndarray]) -> None: 80 | self.vector[..., 10:11] = handle_length 81 | 82 | @property 83 | def handle_y(self) -> Union[float, np.ndarray]: 84 | if self.vector.ndim > 1: 85 | return self.vector[..., 11:12] 86 | return self.vector[11] 87 | 88 | @handle_y.setter 89 | def handle_y(self, handle_y: Union[float, np.ndarray]) -> None: 90 | self.vector[..., 11:12] = handle_y 91 | 92 | @classmethod 93 | def range(cls) -> np.ndarray: 94 | return np.array(list(cls.RANGES.values()), dtype=np.float32).T 95 | 96 | def pose(self) -> math.Pose: 97 | angle = np.linalg.norm(self.aa) 98 | if angle == 0: 99 | quat = eigen.Quaterniond.identity() 100 | else: 101 | axis = self.aa / angle 102 | quat = eigen.Quaterniond(eigen.AngleAxisd(angle, axis)) 103 | return math.Pose(pos=self.pos, quat=quat.coeffs) 104 | 105 | def set_pose(self, pose: math.Pose) -> None: 106 | aa = eigen.AngleAxisd(eigen.Quaterniond(pose.quat)) 107 | self.pos = pose.pos 108 | self.aa = aa.angle * aa.axis 109 | 110 | def __repr__(self) -> str: 111 | return ( 112 | "{\n" 113 | f" pos: {self.pos},\n" 114 | f" aa: {self.aa},\n" 115 | f" box_size: {self.box_size},\n" 116 | f" head_length: {self.head_length},\n" 117 | f" handle_length: {self.handle_length},\n" 118 | f" handle_y: {self.handle_y},\n" 119 | "}" 120 | ) 121 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/pybullet/utils.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import sys 3 | import os 4 | 5 | 6 | class RedirectStream(object): 7 | """Taken from https://github.com/bulletphysics/bullet3/discussions/3441#discussioncomment-657321.""" 8 | 9 | @staticmethod 10 | def _flush_c_stream(stream): 11 | streamname = stream.name[1:-1] 12 | libc = ctypes.CDLL(None) 13 | libc.fflush(ctypes.c_void_p.in_dll(libc, streamname)) 14 | 15 | def __init__(self, stream=sys.stdout, file=os.devnull): 16 | self.stream = stream 17 | self.file = file 18 | 19 | def __enter__(self): 20 | self.stream.flush() # ensures python stream unaffected 21 | try: 22 | self.fd = open(self.file, "w+") 23 | except NameError: 24 | return 25 | self.dup_stream = os.dup(self.stream.fileno()) 26 | os.dup2(self.fd.fileno(), self.stream.fileno()) # replaces stream 27 | 28 | def __exit__(self, type, value, traceback): 29 | RedirectStream._flush_c_stream(self.stream) # ensures C stream buffer empty 30 | try: 31 | os.dup2(self.dup_stream, self.stream.fileno()) # restores stream 32 | except AttributeError: 33 | return 34 | os.close(self.dup_stream) 35 | self.fd.close() 36 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/utils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from generative_skill_chaining import envs 5 | from generative_skill_chaining.utils import configs 6 | 7 | 8 | class EnvFactory(configs.Factory[envs.Env]): 9 | """Env factory.""" 10 | 11 | def __init__( 12 | self, 13 | config: Union[str, pathlib.Path, Dict[str, Any]], 14 | ): 15 | """Creates the env factory from an env config or policy checkpoint. 16 | 17 | Args: 18 | config: Env config path or dict. 19 | """ 20 | super().__init__(config, "env", envs) 21 | 22 | if issubclass(self.cls, envs.VariantEnv): 23 | self._variants = [ 24 | EnvFactory(env_config) for env_config in self.kwargs["variants"] 25 | ] 26 | 27 | @property 28 | def env_factories(self) -> List["EnvFactory"]: 29 | """Primitive env factories for sequential env.""" 30 | if not issubclass(self.cls, envs.pybox2d.Sequential2D): 31 | raise AttributeError("Only Sequential2D has attribute env_factories") 32 | return self.kwargs["env_factories"] 33 | 34 | def __call__(self, *args, multiprocess: bool = False, **kwargs) -> envs.Env: 35 | """Creates an env instance. 36 | 37 | Args: 38 | *args: Env constructor args. 39 | multiprocess: Whether to wrap the env in a ProcessEnv. 40 | **kwargs: Env constructor kwargs. 41 | 42 | Returns: 43 | Env instance. 44 | """ 45 | if multiprocess: 46 | raise NotImplementedError 47 | # merged_kwargs = dict(self.kwargs) 48 | # merged_kwargs.update(kwargs) 49 | # instance = envs.ProcessEnv(self.cls, *args, **kwargs) 50 | # 51 | # self.run_post_hooks(instance) 52 | # 53 | # return instance 54 | 55 | if issubclass(self.cls, envs.VariantEnv): 56 | variants = [env_factory(*args, **kwargs) for env_factory in self._variants] 57 | return super().__call__(variants=variants) 58 | 59 | return super().__call__(*args, **kwargs) 60 | 61 | 62 | def load( 63 | config: Optional[Union[str, pathlib.Path, Dict[str, Any]]] = None, 64 | checkpoint: Optional[Union[str, pathlib.Path]] = None, 65 | multiprocess: bool = False, 66 | **kwargs, 67 | ) -> envs.Env: 68 | """Loads the agent from an env config or policy checkpoint. 69 | 70 | Args: 71 | config: Optional env config path or dict. Must be set if checkpoint is 72 | None. 73 | checkpoint: Optional policy checkpoint path. 74 | multiprocess: Whether to run the env in a separate process. 75 | kwargs: Additional env constructor kwargs. 76 | 77 | Returns: 78 | Env instance. 79 | """ 80 | if config is None: 81 | if checkpoint is None: 82 | raise ValueError("Env config or checkpoint must be specified") 83 | config = load_config(checkpoint) 84 | 85 | env_factory = EnvFactory(config) 86 | return env_factory(multiprocess=multiprocess, **kwargs) 87 | 88 | 89 | def load_config(path: Union[str, pathlib.Path]) -> Dict[str, Any]: 90 | """Loads an env config from path. 91 | 92 | Args: 93 | path: Path to the config, config directory, or checkpoint. 94 | 95 | Returns: 96 | Env config dict. 97 | """ 98 | return configs.load_config(path, "env") 99 | -------------------------------------------------------------------------------- /generative_skill_chaining/envs/variant.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Any, List, Optional, Sequence, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from generative_skill_chaining.envs.base import Env, Primitive 8 | 9 | 10 | class VariantEnv(Env): 11 | def __init__(self, variants: Sequence[Env]): 12 | self._variants = variants 13 | self._idx_variant = np.random.randint(len(self.variants)) 14 | 15 | @property 16 | def variants(self) -> Sequence[Env]: 17 | return self._variants 18 | 19 | @property 20 | def env(self) -> Env: 21 | return self.variants[self._idx_variant] 22 | 23 | @property 24 | def metadata(self): 25 | return self.env.metadata 26 | 27 | @property 28 | def render_mode(self): 29 | return self.env.render_mode 30 | 31 | @property 32 | def observation_space(self) -> gym.spaces.Box: # type: ignore 33 | return self.env.observation_space 34 | 35 | @property 36 | def state_space(self) -> gym.spaces.Box: # type: ignore 37 | return self.env.state_space 38 | 39 | @property 40 | def image_space(self) -> gym.spaces.Box: # type: ignore 41 | return self.env.image_space 42 | 43 | @property 44 | def action_space(self) -> gym.spaces.Box: # type: ignore 45 | return self.env.action_space 46 | 47 | @property 48 | def action_scale(self) -> gym.spaces.Box: 49 | return self.env.action_scale 50 | 51 | @property 52 | def action_skeleton(self) -> Sequence[Primitive]: 53 | return self.env.action_skeleton 54 | 55 | @property 56 | def primitives(self) -> List[str]: 57 | return self.env.primitives 58 | 59 | def get_primitive(self) -> Primitive: 60 | return self.env.get_primitive() 61 | 62 | def set_primitive( 63 | self, 64 | primitive: Optional[Primitive] = None, 65 | action_call: Optional[str] = None, 66 | idx_policy: Optional[int] = None, 67 | policy_args: Optional[Any] = None, 68 | ) -> "Env": 69 | return self.env.set_primitive(primitive, action_call, idx_policy, policy_args) 70 | 71 | def get_primitive_info( 72 | self, 73 | action_call: Optional[str] = None, 74 | idx_policy: Optional[int] = None, 75 | policy_args: Optional[Any] = None, 76 | ) -> Primitive: 77 | return self.env.get_primitive_info(action_call, idx_policy, policy_args) 78 | 79 | def create_primitive_env(self, primitive: Primitive) -> "Env": 80 | return self.env.create_primitive_env(primitive) 81 | 82 | def get_state(self) -> np.ndarray: 83 | return self.env.get_state() 84 | 85 | def set_state(self, state: np.ndarray) -> bool: 86 | return self.env.set_state(state) 87 | 88 | def get_observation(self, image: Optional[bool] = None) -> np.ndarray: 89 | return self.env.get_observation(image) 90 | 91 | def reset( 92 | self, *, seed: Optional[int] = None, options: Optional[dict] = None 93 | ) -> Tuple[np.ndarray, dict]: 94 | self._idx_variant = np.random.randint(len(self.variants)) 95 | return self.env.reset(seed=seed, options=options) 96 | 97 | def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, dict]: 98 | return self.env.step(action) 99 | 100 | def close(self): 101 | for env in self.variants: 102 | env.close() 103 | 104 | def render(self) -> np.ndarray: # type: ignore 105 | return self.env.render() 106 | 107 | def record_start( 108 | self, 109 | prepend_id: Optional[Any] = None, 110 | frequency: Optional[int] = None, 111 | mode: str = "default", 112 | ) -> bool: 113 | return self.env.record_start(prepend_id, frequency, mode) 114 | 115 | def record_stop(self, save_id: Optional[Any] = None, mode: str = "default") -> bool: 116 | return self.env.record_stop(save_id, mode) 117 | 118 | def record_save( 119 | self, 120 | path: Union[str, pathlib.Path], 121 | reset: bool = True, 122 | mode: Optional[str] = None, 123 | ) -> bool: 124 | return self.env.record_save(path, reset, mode) 125 | -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/generative_skill_chaining/mixed_diffusion/__init__.py -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/cond_diffusion1D.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn import Module, Parameter, ModuleList 6 | import numpy as np 7 | 8 | from generative_skill_chaining.mixed_diffusion.utils.diff_utils import * 9 | from generative_skill_chaining.diff_models.unet_transformer import ScoreNet 10 | from generative_skill_chaining.mixed_diffusion.sde_cont import PluginReverseSDE 11 | 12 | 13 | class VarianceSchedule(Module): 14 | 15 | def __init__(self, num_steps, beta_1, beta_T, mode='linear'): 16 | super().__init__() 17 | assert mode in ('linear', ) 18 | self.num_steps = num_steps 19 | self.beta_1 = beta_1 20 | self.beta_T = beta_T 21 | self.mode = mode 22 | 23 | if mode == 'linear': 24 | betas = torch.linspace(beta_1, beta_T, steps=num_steps) 25 | 26 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding 27 | 28 | alphas = 1 - betas 29 | log_alphas = torch.log(alphas) 30 | for i in range(1, log_alphas.size(0)): # 1 to T 31 | log_alphas[i] += log_alphas[i - 1] 32 | alpha_bars = log_alphas.exp() 33 | 34 | sigmas_flex = torch.sqrt(betas) 35 | sigmas_inflex = torch.zeros_like(sigmas_flex) 36 | for i in range(1, sigmas_flex.size(0)): 37 | sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] 38 | sigmas_inflex = torch.sqrt(sigmas_inflex) 39 | 40 | self.register_buffer('betas', betas) 41 | self.register_buffer('alphas', alphas) 42 | self.register_buffer('alpha_bars', alpha_bars) 43 | self.register_buffer('sigmas_flex', sigmas_flex) 44 | self.register_buffer('sigmas_inflex', sigmas_inflex) 45 | 46 | def uniform_sample_t(self, batch_size): 47 | ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size) 48 | return ts.tolist() 49 | 50 | def get_sigmas(self, t, flexibility): 51 | assert 0 <= flexibility and flexibility <= 1 52 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility) 53 | return sigmas 54 | 55 | 56 | class Diffusion(Module): 57 | 58 | def __init__(self, net:ScoreNet): 59 | super().__init__() 60 | self.net = net 61 | self.gensde = PluginReverseSDE(scorefunc=net) 62 | 63 | def get_loss(self, x_0, obs_ind=None): 64 | ############################################### 65 | # x_0: Input sample, (B, N, D). 66 | # condition: (B, C) 67 | ############################################### 68 | 69 | loss = self.gensde.dsm(x_0, obs_ind=obs_ind).mean() 70 | 71 | return loss 72 | 73 | @torch.no_grad() 74 | def sample(self, 75 | device, 76 | num_samples=16, 77 | sample_dim=7, 78 | num_steps=256, 79 | return_diffusion=False, 80 | grad_fn=None, 81 | replace=None, 82 | mask=None, 83 | dynamics=None): 84 | 85 | condition = None 86 | 87 | # if replace is not None: 88 | # replace = replace.repeat(batch_size*int(num_samples/replace.size(0)), 1).to(condition.device) 89 | 90 | x_T = torch.randn([num_samples, sample_dim]).to(device) 91 | diffusion = {num_steps: x_T.cpu().numpy().reshape(num_samples, sample_dim)} 92 | 93 | sde = self.gensde 94 | delta = sde.T / num_steps 95 | sde.base_sde.dt = delta 96 | ts = torch.linspace(1, 0, num_steps + 1).to(x_T) * sde.T 97 | ones = torch.ones(num_samples, 1).to(x_T)/num_steps 98 | 99 | for t in range(num_steps, 0, -1): 100 | xt = sde.sample(ones * t, x_T, condition, grad_fn, replace=replace, mask=mask, dynamics=dynamics) 101 | 102 | # if replace is not None: 103 | # xt[:, 0:replace.shape[1]] = replace 104 | 105 | x_T = xt 106 | 107 | diffusion[t - 1] = xt.cpu().numpy().reshape(num_samples, sample_dim) 108 | 109 | if return_diffusion: 110 | return diffusion[0], diffusion 111 | else: 112 | return diffusion[0] 113 | 114 | def configure_sdes(self, num_steps, x_T, num_samples=16): 115 | 116 | sde = self.gensde 117 | delta = sde.T / num_steps 118 | sde.base_sde.dt = delta 119 | ts = torch.linspace(1, 0, num_steps + 1).to(x_T) * sde.T 120 | ones = torch.ones(num_samples, 1).to(x_T)/num_steps 121 | 122 | return sde, ones 123 | 124 | def configure_sdes_forward(self, num_steps, x_0, num_samples=1): 125 | 126 | sde = self.gensde 127 | delta = sde.T / num_steps 128 | sde.base_sde.dt = delta 129 | ts = torch.linspace(1, 0, num_steps + 1).to(x_0) * sde.T 130 | ones = torch.ones(num_samples, 1).to(x_0)/num_steps 131 | 132 | return sde, ones -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/datasets_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | import tqdm 6 | import pickle 7 | 8 | class StandardDataset(Dataset): 9 | 10 | def __init__(self, data): 11 | self.data = data 12 | 13 | def __len__(self): 14 | return len(self.data) 15 | 16 | def __getitem__(self, idx): 17 | return self.data[idx] 18 | 19 | class MixedDiffusionDataset(Dataset): 20 | ##################################################### 21 | # This class is used to load the dataset for the # 22 | # diffusion transformer scorenet and classifier # 23 | ##################################################### 24 | 25 | def __init__(self, dataset_path, obs_processor): 26 | ##################################################### 27 | # dataset_path: path to the dataset # 28 | # obs_processor: function to process the observation# 29 | ##################################################### 30 | 31 | self.dataset_path = dataset_path 32 | self.obs_processor = obs_processor 33 | 34 | self.eval_datapath = self.dataset_path.split(".")[0] + "_eval_random.pkl" # for the classifier 35 | self.neg_eval_datapath = self.dataset_path.split(".")[0] + "_eval_random_neg.pkl" # for the classifier 36 | 37 | with open(self.eval_datapath, 'rb') as f: 38 | self.data = pickle.load(f) 39 | self.data = np.random.permutation(self.data).tolist() 40 | self.data = self.data[:10000] 41 | 42 | 43 | with open(dataset_path, 'rb') as f: 44 | self.data.extend(pickle.load(f)) 45 | 46 | with open(self.neg_eval_datapath, 'rb') as f: 47 | self.neg_data = pickle.load(f) 48 | self.neg_data = np.random.permutation(self.neg_data)[:10000] 49 | 50 | self.obs1 = [] 51 | self.obs2 = [] 52 | self.act = [] 53 | self.obs_indices = [] 54 | 55 | self.neg_obs1 = [] 56 | self.neg_obs = [] 57 | 58 | for data in tqdm.tqdm(self.data): 59 | 60 | if len(data['observations']) != 2: 61 | continue 62 | 63 | observation1 = data['observations'][0] 64 | observation2 = data['observations'][1] 65 | action = data['actions'][0] 66 | reset_info = data['reset_info'][0] if type(data['reset_info']) == list else data['reset_info'] 67 | 68 | self.obs1.append( 69 | self.obs_processor(observation1, reset_info['policy_args']) 70 | ) 71 | 72 | self.obs2.append( 73 | self.obs_processor(observation2, reset_info['policy_args']) 74 | ) 75 | 76 | self.obs_indices.append(reset_info['policy_args']['observation_indices']) 77 | 78 | self.act.append(action) 79 | 80 | if len(self.obs1) > 10000: 81 | break 82 | 83 | self.obs1 = np.array(self.obs1)*2 # scale to [-1, 1] from [-1/2, 1/2] 84 | self.obs2 = np.array(self.obs2)*2 # scale to [-1, 1] from [-1/2, 1/2] 85 | self.act = np.array(self.act) 86 | self.obs_indices = np.array(self.obs_indices) 87 | 88 | 89 | self.neg_obs_indices = [] 90 | 91 | for data in tqdm.tqdm(self.neg_data): 92 | observation1 = data['observations'][0] 93 | observation = data['observations'][1] 94 | reset_info = data['reset_info'][0] if type(data['reset_info']) == list else data['reset_info'] 95 | 96 | self.neg_obs.append( 97 | self.obs_processor(observation, reset_info['policy_args']) 98 | ) 99 | 100 | self.neg_obs_indices.append(reset_info['policy_args']['observation_indices']) 101 | 102 | self.neg_obs = np.array(self.neg_obs)*2 # scale to [-1, 1] from [-1/2, 1/2] 103 | self.neg_obs_indices = np.array(self.neg_obs_indices) 104 | 105 | def get_data_for_mode(self, mode): 106 | ##################################################### 107 | # mode: "transition", "state", "classifier" # 108 | ##################################################### 109 | 110 | if mode == "transition": 111 | data = np.concatenate([self.obs1, self.act, self.obs2, self.obs_indices], axis=1) 112 | elif mode == "state": 113 | indices = np.random.permutation(len(self.obs1)) 114 | data = self.obs1[indices] 115 | obs_indices = self.obs_indices[indices] 116 | data = np.concatenate([data, obs_indices], axis=1) 117 | elif mode == "classifier": 118 | pos_data = np.concatenate([self.obs2, self.obs_indices, np.ones([len(self.obs2), 1])], axis=1)[np.random.permutation(len(self.obs2))][:8000] 119 | neg_data = np.concatenate([self.neg_obs, self.neg_obs_indices, np.zeros([len(self.neg_obs), 1])], axis=1)[np.random.permutation(len(self.neg_obs))][:8000] 120 | data = np.concatenate([pos_data, neg_data], axis=0) 121 | 122 | return data 123 | 124 | def get_primitive_loader(dataset_path, obs_processor, modes, batch_size=128): 125 | ##################################################### 126 | # dataset_path: path to the dataset # 127 | # obs_processor: function to process the observation# 128 | # modes: list of modes # 129 | # batch_size: batch size # 130 | # returns: list of (dataloader, dataset) for each # 131 | # mode # 132 | ##################################################### 133 | 134 | dataset = MixedDiffusionDataset(dataset_path, obs_processor) 135 | 136 | all_dataset_loader = [] 137 | 138 | for mode in modes: 139 | data = dataset.get_data_for_mode(mode) 140 | mode_dataset = StandardDataset(data) 141 | mode_loader = DataLoader(mode_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True) 142 | 143 | all_dataset_loader.append((mode_loader, mode_dataset)) 144 | 145 | return all_dataset_loader -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/datasets_transformer_class.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | import tqdm 6 | import pickle 7 | 8 | class StandardDataset(Dataset): 9 | 10 | def __init__(self, data): 11 | self.data = data 12 | 13 | def __len__(self): 14 | return len(self.data) 15 | 16 | def __getitem__(self, idx): 17 | return self.data[idx] 18 | 19 | class MixedDiffusionDataset(Dataset): 20 | 21 | def __init__(self, dataset_path, obs_processor): 22 | self.dataset_path = dataset_path 23 | self.obs_processor = obs_processor 24 | 25 | self.eval_datapath = self.dataset_path.split(".")[0] + "_eval_random.pkl" 26 | self.neg_eval_datapath = self.dataset_path.split(".")[0] + "_eval_random_neg.pkl" 27 | 28 | with open(self.eval_datapath, 'rb') as f: 29 | self.data = pickle.load(f) 30 | self.data = np.random.permutation(self.data).tolist() 31 | self.data = self.data 32 | 33 | with open(dataset_path, 'rb') as f: 34 | self.data.extend(pickle.load(f)) 35 | 36 | with open(self.neg_eval_datapath, 'rb') as f: 37 | self.neg_data = pickle.load(f) 38 | self.neg_data = np.random.permutation(self.neg_data)[:10000] 39 | 40 | # with open(dataset_path, 'rb') as f: 41 | # self.data = pickle.load(f) 42 | 43 | self.obs1 = [] 44 | self.obs2 = [] 45 | self.act = [] 46 | self.obs_indices = [] 47 | 48 | self.neg_obs = [] 49 | 50 | for data in tqdm.tqdm(self.data): 51 | 52 | if len(data['observations']) != 2: 53 | continue 54 | 55 | observation1 = data['observations'][0] 56 | observation2 = data['observations'][1] 57 | action = data['actions'][0] 58 | reset_info = data['reset_info'][0] if type(data['reset_info']) == list else data['reset_info'] 59 | 60 | self.obs1.append( 61 | self.obs_processor(observation1, reset_info['policy_args']) 62 | ) 63 | 64 | self.obs2.append( 65 | self.obs_processor(observation2, reset_info['policy_args']) 66 | ) 67 | 68 | self.obs_indices.append(reset_info['policy_args']['observation_indices']) 69 | 70 | self.act.append(action) 71 | 72 | if len(self.obs1) > 40000: 73 | break 74 | 75 | self.obs1 = np.array(self.obs1)*2 # scale to [-1, 1] from [-1/2, 1/2] 76 | self.obs2 = np.array(self.obs2)*2 # scale to [-1, 1] from [-1/2, 1/2] 77 | self.act = np.array(self.act) 78 | self.obs_indices = np.array(self.obs_indices) 79 | 80 | 81 | self.neg_obs_indices = [] 82 | 83 | for data in tqdm.tqdm(self.neg_data): 84 | observation = data['observations'][1] 85 | reset_info = data['reset_info'][0] if type(data['reset_info']) == list else data['reset_info'] 86 | 87 | self.neg_obs.append( 88 | self.obs_processor(observation, reset_info['policy_args']) 89 | ) 90 | 91 | self.neg_obs_indices.append(reset_info['policy_args']['observation_indices']) 92 | 93 | self.neg_obs = np.array(self.neg_obs)*2 # scale to [-1, 1] from [-1/2, 1/2] 94 | self.neg_obs_indices = np.array(self.neg_obs_indices) 95 | 96 | # self.neg_obs = self.obs1 97 | 98 | # print(np.max(self.act, axis=0), np.min(self.act, axis=0), np.max(self.obs1, axis=0), np.min(self.obs1, axis=0), np.max(self.obs2, axis=0), np.min(self.obs2, axis=0)) 99 | 100 | # for i in range(self.obs1.shape[0]): 101 | # if abs(np.linalg.norm(self.obs1[i][:48] - self.obs2[i][:48]) - np.linalg.norm(self.obs1[i] - self.obs2[i])) > 1e-3: 102 | # print("Warning:", i) 103 | 104 | 105 | # print("Dataset size:", self.obs1[48990] - self.obs2[48990]) 106 | 107 | # assert False 108 | 109 | def get_data_for_mode(self, mode): 110 | 111 | if mode == "transition": 112 | data = np.concatenate([self.obs1, self.act, self.obs_indices, self.obs2], axis=1) 113 | if mode == "inverse": 114 | data = np.concatenate([self.obs1, self.obs2, self.obs_indices, self.act], axis=1) 115 | elif mode == "state": 116 | indices = np.random.permutation(len(self.obs1)) 117 | data = self.obs1[indices] 118 | obs_indices = self.obs_indices[indices] 119 | data = np.concatenate([data, obs_indices], axis=1) 120 | elif mode == "classifier": 121 | pos_data = np.concatenate([self.obs2, self.obs_indices, np.ones([len(self.obs2), 1])], axis=1)[np.random.permutation(len(self.obs2))][:8000] 122 | neg_data = np.concatenate([self.neg_obs, self.neg_obs_indices, np.zeros([len(self.neg_obs), 1])], axis=1)[np.random.permutation(len(self.neg_obs))][:8000] 123 | data = np.concatenate([pos_data, neg_data], axis=0) 124 | 125 | return data 126 | 127 | def get_primitive_loader(dataset_path, obs_processor, modes, batch_size=64): 128 | dataset = MixedDiffusionDataset(dataset_path, obs_processor) 129 | 130 | all_dataset_loader = [] 131 | 132 | for mode in modes: 133 | data = dataset.get_data_for_mode(mode) 134 | mode_dataset = StandardDataset(data) 135 | mode_loader = DataLoader(mode_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True) 136 | 137 | all_dataset_loader.append((mode_loader, mode_dataset)) 138 | 139 | return all_dataset_loader -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/grad_discriminator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | class GradDiscriminator(): 10 | 11 | ################################################# 12 | # Initialize the discriminator 13 | # score_models: list of score models 14 | # conditions: list of condition kwargs for each score model 15 | # batch_size: batch size 16 | ################################################# 17 | 18 | def __init__(self, score_models, batch_size): 19 | self.score_models = score_models 20 | self.batch_size = batch_size 21 | 22 | def calc_grad(self, candidate_samples, conditions): 23 | 24 | ############################################ 25 | # Calculate the gradient of the predicted scores 26 | # with respect to the candidate samples 27 | # candidate_samples: [batch_size, num_samples, sample_dim] 28 | # returns: [batch_size, num_samples, sample_dim] 29 | ############################################ 30 | 31 | device = candidate_samples.device 32 | 33 | candidate_samples = candidate_samples.reshape(self.batch_size, -1, candidate_samples.shape[-1]) 34 | 35 | candidate_samples.requires_grad = True 36 | 37 | predicted_scores = [] 38 | 39 | loss = 0 40 | for score_model, condition in zip(self.score_models, conditions): 41 | predicted_score = score_model(candidate_samples, **condition) 42 | loss += torch.abs(torch.ones_like(predicted_score) - predicted_score)**2 43 | 44 | loss = loss.mean() 45 | 46 | grad_value = torch.autograd.grad(loss, candidate_samples)[0] - 0.05 * torch.randn_like(candidate_samples) 47 | 48 | return grad_value.reshape(-1, grad_value.shape[-1]).to(device) 49 | 50 | def last_step_discrimination(self, candidate_samples, conditions, num_grad_steps=10): 51 | 52 | ############################################ 53 | # Perform the last step discrimination (refinement) 54 | # candidate_samples: [batch_size, num_samples, sample_dim] 55 | # returns: [batch_size, num_samples, sample_dim] 56 | ############################################ 57 | 58 | candidate_samples = torch.Tensor(candidate_samples) 59 | 60 | prev_loss = 0 61 | num_hits = 0 62 | 63 | candidate_samples.requires_grad = True 64 | 65 | while True: 66 | 67 | loss = 0 68 | 69 | for score_model, condition in zip(self.score_models, conditions): 70 | predicted_score = score_model(candidate_samples, **condition) 71 | loss += torch.abs(torch.ones_like(predicted_score) - predicted_score)**2 72 | 73 | loss = loss.mean() 74 | 75 | grad_value = torch.autograd.grad(loss, candidate_samples)[0] - 0.05 * torch.randn_like(candidate_samples) 76 | 77 | candidate_samples = candidate_samples - 0.01 * grad_value 78 | 79 | num_hits += 1 80 | 81 | if num_hits > num_grad_steps: 82 | break 83 | 84 | return candidate_samples 85 | 86 | def order_for_model(self, candidate_samples, score_model, condition, n_top): 87 | 88 | ############################################ 89 | # Order the candidate samples for a specific score model 90 | # based on the predicted scores 91 | # candidate_samples: [batch_size, num_samples, sample_dim] 92 | # score_model: score model 93 | # condition: condition kwargs for the score model 94 | # n_top: number of top samples to return 95 | # returns: [batch_size, n_top, sample_dim] 96 | ############################################ 97 | 98 | B, N, D = candidate_samples.shape # B: batch size, N: num samples, D: sample dimension 99 | 100 | predicted_scores = [] 101 | 102 | predicted_scores = score_model(samples=torch.Tensor(candidate_samples), **condition) 103 | 104 | predicted_scores = predicted_scores.detach().cpu().numpy() # (B, N) 105 | 106 | indices_reo = np.argsort(predicted_scores, axis=1)[:, ::-1][:, :n_top] # (B, n_top) 107 | 108 | candidate_samples = candidate_samples[np.arange(B)[:, None], indices_reo] # (B, n_top, D) 109 | 110 | return candidate_samples 111 | 112 | -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generative-skill-chaining/gsc-code/b8f5bc0d44bd453827f3f254ca1b03ec049849c5/generative_skill_chaining/mixed_diffusion/utils/__init__.py -------------------------------------------------------------------------------- /generative_skill_chaining/mixed_diffusion/utils/diff_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ############################################### 4 | # Code derived from Song et al. 2021 5 | # https://openreview.net/pdf?id=PxTIG12RRHS 6 | ############################################### 7 | 8 | import torch 9 | from torch.nn import Module, Linear 10 | from torch.optim.lr_scheduler import LambdaLR 11 | import numpy as np 12 | 13 | def reparameterize_gaussian(mean, logvar): 14 | std = torch.exp(0.5 * logvar) 15 | eps = torch.randn(std.size()).to(mean) 16 | return mean + std * eps 17 | 18 | 19 | def gaussian_entropy(logvar): 20 | const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2)) 21 | ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const 22 | return ent 23 | 24 | 25 | def standard_normal_logprob(z): 26 | dim = z.size(-1) 27 | log_z = -0.5 * dim * np.log(2 * np.pi) 28 | return log_z - z.pow(2) / 2 29 | 30 | 31 | def truncated_normal_(tensor, mean=0, std=1, trunc_std=2): 32 | """ 33 | Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15 34 | """ 35 | size = tensor.shape 36 | tmp = tensor.new_empty(size + (4,)).normal_() 37 | valid = (tmp < trunc_std) & (tmp > -trunc_std) 38 | ind = valid.max(-1, keepdim=True)[1] 39 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 40 | tensor.data.mul_(std).add_(mean) 41 | return tensor 42 | 43 | 44 | class ConcatSquashLinear(Module): 45 | def __init__(self, dim_in, dim_out, dim_ctx): 46 | super(ConcatSquashLinear, self).__init__() 47 | self._layer = Linear(dim_in, dim_out) 48 | self._hyper_bias = Linear(dim_ctx, dim_out, bias=False) 49 | self._hyper_gate = Linear(dim_ctx, dim_out) 50 | 51 | def forward(self, ctx, x): 52 | gate = torch.sigmoid(self._hyper_gate(ctx)) 53 | bias = self._hyper_bias(ctx) 54 | # if x.dim() == 3: 55 | # gate = gate.unsqueeze(1) 56 | # bias = bias.unsqueeze(1) 57 | ret = self._layer(x) * gate + bias 58 | return ret 59 | 60 | 61 | def get_linear_scheduler(optimizer, start_epoch, end_epoch, start_lr, end_lr): 62 | def lr_func(epoch): 63 | if epoch <= start_epoch: 64 | return 1.0 65 | elif epoch <= end_epoch: 66 | total = end_epoch - start_epoch 67 | delta = epoch - start_epoch 68 | frac = delta / total 69 | return (1-frac) * 1.0 + frac * (end_lr / start_lr) 70 | else: 71 | return end_lr / start_lr 72 | return LambdaLR(optimizer, lr_lambda=lr_func) 73 | 74 | def lr_func(epoch): 75 | if epoch <= start_epoch: 76 | return 1.0 77 | elif epoch <= end_epoch: 78 | total = end_epoch - start_epoch 79 | delta = epoch - start_epoch 80 | frac = delta / total 81 | return (1-frac) * 1.0 + frac * (end_lr / start_lr) 82 | else: 83 | return end_lr / start_lr -------------------------------------------------------------------------------- /generative_skill_chaining/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Register Network Classes here. 2 | from . import actors 3 | from . import critics 4 | from . import encoders 5 | from .mlp import MLP 6 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/actors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Actor 2 | from .mlp import ContinuousMLPActor, DiagonalGaussianMLPActor 3 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/actors/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | 5 | 6 | class Actor(torch.nn.Module, abc.ABC): 7 | """Base actor class.""" 8 | 9 | @abc.abstractmethod 10 | def forward(self, state: torch.Tensor) -> torch.distributions.Distribution: 11 | """Outputs the actor distribution. 12 | 13 | Args: 14 | state: Environment state. 15 | 16 | Returns: 17 | Action distribution. 18 | """ 19 | pass 20 | 21 | @abc.abstractmethod 22 | def predict(self, state: torch.Tensor, sample: bool = False) -> torch.Tensor: 23 | """Outputs the actor prediction. 24 | 25 | Args: 26 | state: Environment state. 27 | sample: Whether to sample from the distribution or return the mode. 28 | 29 | Returns: 30 | Action. 31 | """ 32 | pass 33 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/actors/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Type 2 | 3 | import gym 4 | import torch 5 | 6 | from generative_skill_chaining.networks.mlp import LFF, MLP, weight_init 7 | from generative_skill_chaining.networks.actors import base 8 | from generative_skill_chaining.networks.utils import SquashedNormal 9 | 10 | 11 | class ContinuousMLPActor(base.Actor): 12 | def __init__( 13 | self, 14 | state_space: gym.spaces.Box, 15 | action_space: gym.spaces.Box, 16 | hidden_layers: Sequence[int] = [256, 256], 17 | act: Type[torch.nn.Module] = torch.nn.ReLU, 18 | output_act: Type[torch.nn.Module] = torch.nn.Tanh, 19 | ortho_init: bool = False, 20 | ): 21 | super().__init__() 22 | self.mlp = MLP( 23 | state_space.shape[0], 24 | action_space.shape[0], 25 | hidden_layers=hidden_layers, 26 | act=act, 27 | output_act=output_act, 28 | ) 29 | if ortho_init: 30 | self.apply(weight_init) 31 | 32 | def forward(self, state: torch.Tensor) -> torch.distributions.Distribution: 33 | return self.mlp(state) 34 | 35 | def predict(self, state: torch.Tensor, sample: bool = False) -> torch.Tensor: 36 | return self.mlp(state) 37 | 38 | 39 | class DiagonalGaussianMLPActor(base.Actor): 40 | def __init__( 41 | self, 42 | state_space: gym.spaces.Box, 43 | action_space: gym.spaces.Box, 44 | hidden_layers: Sequence[int] = [256, 256], 45 | act: Type[torch.nn.Module] = torch.nn.ReLU, 46 | ortho_init: bool = False, 47 | log_std_bounds: Sequence[int] = [-5, 2], 48 | fourier_features: Optional[int] = None, 49 | ): 50 | super().__init__() 51 | self.log_std_bounds = log_std_bounds 52 | if log_std_bounds is not None: 53 | assert log_std_bounds[0] < log_std_bounds[1] 54 | if fourier_features is not None: 55 | lff = LFF(state_space.shape[0], fourier_features) 56 | mlp = MLP( 57 | fourier_features, 58 | 2 * action_space.shape[0], 59 | hidden_layers=hidden_layers, 60 | act=act, 61 | output_act=None, 62 | ) 63 | self.mlp: torch.nn.Module = torch.nn.Sequential(lff, mlp) 64 | else: 65 | self.mlp = MLP( 66 | state_space.shape[0], 67 | 2 * action_space.shape[0], 68 | hidden_layers=hidden_layers, 69 | act=act, 70 | output_act=None, 71 | ) 72 | if ortho_init: 73 | self.apply(weight_init) 74 | self.action_range = [ 75 | float(action_space.low.min()), 76 | float(action_space.high.max()), 77 | ] 78 | 79 | def forward(self, state: torch.Tensor) -> torch.distributions.Distribution: 80 | mu, log_std = self.mlp(state).chunk(2, dim=-1) 81 | if self.log_std_bounds is not None: 82 | log_std = torch.tanh(log_std) 83 | log_std_min, log_std_max = self.log_std_bounds 84 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1) 85 | dist_class: Type[torch.distributions.Distribution] = SquashedNormal 86 | else: 87 | dist_class = torch.distributions.Normal 88 | std = log_std.exp() 89 | dist = dist_class(mu, std) 90 | return dist 91 | 92 | def predict(self, state: torch.Tensor, sample: bool = False) -> torch.Tensor: 93 | dist = self(state) 94 | if sample: 95 | action = dist.sample() 96 | else: 97 | action = dist.loc 98 | action = action.clamp(*self.action_range) 99 | return action 100 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/critics/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Critic 2 | from .mlp import ContinuousMLPCritic 3 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/critics/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | 5 | 6 | class Critic(torch.nn.Module, abc.ABC): 7 | """Base critic class.""" 8 | 9 | @abc.abstractmethod 10 | def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 11 | """Predicts the expected value of the given (state, action) pair. 12 | 13 | Args: 14 | state: State. 15 | action: Action. 16 | 17 | Returns: 18 | Predicted expected value. 19 | """ 20 | pass 21 | 22 | def predict(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 23 | """Predicts the expected value of the given (state, action) pair. 24 | 25 | Args: 26 | state: State. 27 | action: Action. 28 | 29 | Returns: 30 | Predicted expected value. 31 | """ 32 | return self.forward(state, action) 33 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/critics/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | 5 | from generative_skill_chaining.networks.critics.base import Critic 6 | from generative_skill_chaining.networks.mlp import LFF, MLP, weight_init 7 | 8 | 9 | def create_q_network( 10 | observation_space, action_space, hidden_layers, act, fourier_features: Optional[int] 11 | ) -> torch.nn.Module: 12 | if fourier_features is not None: 13 | lff = LFF(observation_space.shape[0] + action_space.shape[0], fourier_features) 14 | mlp = MLP( 15 | fourier_features, 16 | 1, 17 | hidden_layers=hidden_layers, 18 | act=act, 19 | ) 20 | return torch.nn.Sequential(lff, mlp) 21 | else: 22 | mlp = MLP( 23 | observation_space.shape[0] + action_space.shape[0], 24 | 1, 25 | hidden_layers=hidden_layers, 26 | act=act, 27 | ) 28 | return mlp 29 | 30 | 31 | class ContinuousMLPCritic(Critic): 32 | def __init__( 33 | self, 34 | observation_space, 35 | action_space, 36 | hidden_layers=[256, 256], 37 | act=torch.nn.ReLU, 38 | num_q_fns=2, 39 | ortho_init=False, 40 | fourier_features: Optional[int] = None, 41 | ): 42 | super().__init__() 43 | 44 | self.qs = torch.nn.ModuleList( 45 | [ 46 | create_q_network( 47 | observation_space, 48 | action_space, 49 | hidden_layers, 50 | act, 51 | fourier_features, 52 | ) 53 | for _ in range(num_q_fns) 54 | ] 55 | ) 56 | if ortho_init: 57 | self.apply(weight_init) 58 | 59 | def forward(self, state: torch.Tensor, action: torch.Tensor) -> List[torch.Tensor]: # type: ignore 60 | """Predicts the expected value of the given (state, action) pair. 61 | 62 | Args: 63 | state: State. 64 | action: Action. 65 | 66 | Returns: 67 | Predicted expected value. 68 | """ 69 | x = torch.cat((state, action), dim=-1) 70 | return [q(x).squeeze(-1) for q in self.qs] 71 | 72 | def predict(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 73 | """Predicts the expected value of the given (state, action) pair. 74 | 75 | Args: 76 | state: State. 77 | action: Action. 78 | 79 | Returns: 80 | Predicted expected value. 81 | """ 82 | qs = self.forward(state, action) 83 | return torch.min(*qs) 84 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Decoder, Encoder 2 | from .normalize import NormalizeObservation 3 | from .table_env import TableEnvEncoder 4 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/encoders/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Optional, Union 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | 8 | from generative_skill_chaining import envs 9 | 10 | 11 | class Encoder(torch.nn.Module, abc.ABC): 12 | """Base encoder class.""" 13 | 14 | def __init__(self, env: envs.Env, state_space: gym.spaces.Box): 15 | """Sets up the encoder spaces. 16 | 17 | Args: 18 | state_space: Policy latent state space. 19 | """ 20 | super().__init__() 21 | 22 | self._state_space = state_space 23 | 24 | @property 25 | def state_space(self) -> gym.spaces.Box: 26 | """Policy latent state space.""" 27 | return self._state_space 28 | 29 | @abc.abstractmethod 30 | def forward( 31 | self, 32 | observation: torch.Tensor, 33 | policy_args: Union[np.ndarray, Optional[Any]], 34 | **kwargs 35 | ) -> torch.Tensor: 36 | """Encodes the observation to the policy latent state. 37 | 38 | For VAEs, this will return the latent distribution parameters. 39 | 40 | Args: 41 | observation: Environment observation. 42 | policy_args: Auxiliary policy arguments. 43 | 44 | Returns: 45 | Encoded policy state. 46 | """ 47 | pass 48 | 49 | # @abc.abstractmethod 50 | # def backward( 51 | # self, 52 | # latent: torch.Tensor, 53 | # policy_args: Union[np.ndarray, Optional[Any]], 54 | # **kwargs 55 | # ) -> torch.Tensor: 56 | 57 | # """ 58 | # Decodes the latent state into an observation. 59 | # """ 60 | # pass 61 | 62 | def predict( 63 | self, 64 | observation: torch.Tensor, 65 | policy_args: Union[np.ndarray, Optional[Any]], 66 | **kwargs 67 | ) -> torch.Tensor: 68 | """Encodes the observation to the policy latent state. 69 | 70 | Args: 71 | observation: Environment observation. 72 | policy_args: Auxiliary policy arguments. 73 | 74 | Returns: 75 | Encoded policy state. 76 | """ 77 | return self.forward(observation, policy_args, **kwargs) 78 | 79 | def reverse( 80 | self, 81 | latent: torch.Tensor, 82 | policy_args: Union[np.ndarray, Optional[Any]], 83 | **kwargs 84 | ) -> torch.Tensor: 85 | """Decodes the latent state into an observation. 86 | 87 | Args: 88 | latent: Encoded latent. 89 | policy_args: Auxiliary policy arguments. 90 | 91 | Returns: 92 | Decoded observation. 93 | """ 94 | return self.backward(latent, policy_args, **kwargs) 95 | 96 | 97 | class Decoder(torch.nn.Module, abc.ABC): 98 | """Base decoder class.""" 99 | 100 | def __init__(self, env: envs.Env, **kwargs): 101 | super().__init__() 102 | 103 | @abc.abstractmethod 104 | def forward(self, latent: torch.Tensor) -> torch.Tensor: 105 | """Decodes the latent state into an observation. 106 | 107 | Args: 108 | latent: Encoded latent. 109 | 110 | Returns: 111 | Decoded observation. 112 | """ 113 | pass 114 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/encoders/normalize.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | 7 | from generative_skill_chaining import envs 8 | from generative_skill_chaining.networks.encoders import Encoder 9 | 10 | 11 | class NormalizeObservation(Encoder): 12 | """Normalizes observation to the range (-0.5, 0.5).""" 13 | 14 | def __init__(self, env: envs.Env): 15 | observation_space = env.observation_space 16 | if not isinstance(observation_space, gym.spaces.Box): 17 | raise NotImplementedError 18 | 19 | state_space = gym.spaces.Box( 20 | low=-0.5, high=0.5, shape=observation_space.shape, dtype=np.float32 21 | ) 22 | super().__init__(env, state_space) 23 | 24 | self.observation_mid = torch.from_numpy( 25 | (observation_space.low + observation_space.high) / 2 26 | ) 27 | self.observation_range = torch.from_numpy( 28 | observation_space.high - observation_space.low 29 | ) 30 | 31 | def _apply(self, fn): 32 | """Ensures members get transferred with NormalizeObservation.to(device).""" 33 | super()._apply(fn) 34 | self.observation_mid = fn(self.observation_mid) 35 | self.observation_range = fn(self.observation_range) 36 | return self 37 | 38 | def forward( 39 | self, 40 | observation: torch.Tensor, 41 | policy_args: Union[np.ndarray, Optional[Any]], 42 | **kwargs 43 | ) -> torch.Tensor: 44 | """Normalizes observation to the range (-0.5, 0.5).""" 45 | return (observation - self.observation_mid) / self.observation_range 46 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Sequence, Type 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def weight_init(m): 9 | if isinstance(m, torch.nn.Linear): 10 | torch.nn.init.orthogonal_(m.weight.data) 11 | if hasattr(m.bias, "data"): 12 | m.bias.data.fill_(0.0) 13 | 14 | 15 | class MLP(torch.nn.Module): 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | output_dim: int, 20 | hidden_layers: Sequence[int] = [256, 256], 21 | act: Type[torch.nn.Module] = torch.nn.ReLU, 22 | output_act: Optional[Type[torch.nn.Module]] = None, 23 | ): 24 | super().__init__() 25 | net: List[torch.nn.Module] = [] 26 | last_dim = input_dim 27 | for dim in hidden_layers: 28 | net.append(torch.nn.Linear(last_dim, dim)) 29 | net.append(act()) 30 | last_dim = dim 31 | net.append(torch.nn.Linear(last_dim, output_dim)) 32 | if output_act is not None: 33 | net.append(output_act()) 34 | self.net = torch.nn.Sequential(*net) 35 | 36 | def forward(self, x): 37 | return self.net(x) 38 | 39 | 40 | class LFF(torch.nn.Module): 41 | """ 42 | get torch.std_mean(self.B) 43 | """ 44 | 45 | def __init__(self, in_features, out_features, scale=1.0, init="iso", sincos=False): 46 | super().__init__() 47 | self.in_features = in_features 48 | self.sincos = sincos 49 | self.out_features = out_features 50 | self.scale = scale 51 | if self.sincos: 52 | self.linear = torch.nn.Linear(in_features, self.out_features // 2) 53 | else: 54 | self.linear = torch.nn.Linear(in_features, self.out_features) 55 | if init == "iso": 56 | torch.nn.init.normal_(self.linear.weight, 0, scale / self.in_features) 57 | torch.nn.init.normal_(self.linear.bias, 0, 1) 58 | else: 59 | torch.nn.init.uniform_( 60 | self.linear.weight, -scale / self.in_features, scale / self.in_features 61 | ) 62 | torch.nn.init.uniform_(self.linear.bias, -1, 1) 63 | if self.sincos: 64 | torch.nn.init.zeros_(self.linear.bias) 65 | 66 | def forward(self, x, **_): 67 | x = np.pi * self.linear(x) 68 | if self.sincos: 69 | return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 70 | else: 71 | return torch.sin(x) 72 | 73 | 74 | class LinearEnsemble(torch.nn.Module): 75 | def __init__( 76 | self, 77 | in_features, 78 | out_features, 79 | bias=True, 80 | ensemble_size=3, 81 | device=None, 82 | dtype=None, 83 | ): 84 | factory_kwargs = {"device": device, "dtype": dtype} 85 | super().__init__() 86 | self.in_features = in_features 87 | self.out_features = out_features 88 | self.ensemble_size = ensemble_size 89 | self.weight = torch.nn.Parameter( 90 | torch.empty((ensemble_size, in_features, out_features), **factory_kwargs) 91 | ) 92 | if bias: 93 | self.bias = torch.nn.Parameter( 94 | torch.empty((ensemble_size, 1, out_features), **factory_kwargs) 95 | ) 96 | else: 97 | self.register_parameter("bias", None) 98 | self.reset_parameters() 99 | 100 | def reset_parameters(self): 101 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 102 | if self.bias is not None: 103 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight[0].T) 104 | bound = 1 / math.sqrt(fan_in) 105 | torch.nn.init.uniform_(self.bias, -bound, bound) 106 | 107 | def forward(self, input): 108 | if len(input.shape) == 2: 109 | input = input.repeat(self.ensemble_size, 1, 1) 110 | elif len(input.shape) > 3: 111 | raise ValueError( 112 | "LinearEnsemble layer does not support inputs with more than 3 dimensions." 113 | ) 114 | return torch.baddbmm(self.bias, input, self.weight) 115 | 116 | def extra_repr(self) -> str: 117 | return "ensemble_size={}, in_features={}, out_features={}, bias={}".format( 118 | self.ensemble_size, 119 | self.in_features, 120 | self.out_features, 121 | self.bias is not None, 122 | ) 123 | 124 | 125 | class EnsembleMLP(torch.nn.Module): 126 | def __init__( 127 | self, 128 | input_dim, 129 | output_dim, 130 | ensemble_size=3, 131 | hidden_layers=[256, 256], 132 | act=torch.nn.ReLU, 133 | output_act=None, 134 | ): 135 | super().__init__() 136 | net = [] 137 | last_dim = input_dim 138 | for dim in hidden_layers: 139 | net.append(LinearEnsemble(last_dim, dim, ensemble_size=ensemble_size)) 140 | net.append(act()) 141 | last_dim = dim 142 | net.append(LinearEnsemble(last_dim, output_dim, ensemble_size=ensemble_size)) 143 | if output_act is not None: 144 | net.append(output_act()) 145 | self.net = torch.nn.Sequential(*net) 146 | 147 | def forward(self, x): 148 | return self.net(x) 149 | -------------------------------------------------------------------------------- /generative_skill_chaining/networks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SquashedNormal(torch.distributions.TransformedDistribution): 5 | def __init__(self, loc, scale): 6 | self._loc = loc 7 | self.scale = scale 8 | self.base_dist = torch.distributions.Normal(loc, scale) 9 | transforms = [torch.distributions.transforms.TanhTransform(cache_size=1)] 10 | super().__init__(self.base_dist, transforms) 11 | 12 | @property 13 | def loc(self): 14 | loc = self._loc 15 | for transform in self.transforms: 16 | loc = transform(loc) 17 | return loc 18 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configs 2 | from . import logging 3 | from . import metrics 4 | from . import nest 5 | from . import random 6 | from . import recording 7 | from . import spaces 8 | from . import tensors 9 | from . import timing 10 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/logging.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import pathlib 3 | from typing import Any, Dict, IO, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils import tensorboard 8 | 9 | from generative_skill_chaining.utils import metrics 10 | 11 | 12 | class Logger(object): 13 | def __init__(self, path: Union[str, pathlib.Path]): 14 | self.path = pathlib.Path(path) 15 | 16 | self._writer: Optional[tensorboard.SummaryWriter] = None 17 | self._csv_writer: Optional[csv.DictWriter] = None 18 | self._csv_file: Optional[IO] = None 19 | 20 | self._staged: Dict[str, Any] = {} 21 | self._flushed: Dict[str, Any] = {} 22 | self._images: Dict[str, np.ndarray] = {} 23 | self._embeddings: Dict[str, Dict[str, torch.Tensor]] = {} 24 | 25 | def log( 26 | self, key: str, value: Union[Any, Dict[str, Any]], std: bool = False 27 | ) -> None: 28 | """Logs the (key, value) pair. 29 | 30 | If `value` is an array, the mean and standard deviation of values will 31 | be logged. 32 | 33 | flush() must be manually called to send the results to Tensorboard. 34 | 35 | Args: 36 | key: Tensorboard key. 37 | value: Single value or dict of values. 38 | std: Whether to log the standard deviations of arrays. 39 | """ 40 | subkey = key.split("/")[-1] 41 | if subkey.startswith("emb"): 42 | # Stage embedding. 43 | self._embeddings[key] = value 44 | return 45 | 46 | if isinstance(value, np.ndarray): 47 | subkey = key.split("/")[-1] 48 | 49 | # Stage image. 50 | if subkey.startswith("img"): 51 | self._images[key] = value 52 | return 53 | 54 | # Log mean/std of array. 55 | self.log(key, np.mean(value)) 56 | if subkey in metrics.METRIC_AGGREGATION_FNS: 57 | self.log(f"{key}/std", np.std(value)) 58 | return 59 | 60 | if isinstance(value, dict): 61 | for subkey, subval in value.items(): 62 | self.log(f"{key}/{subkey}", subval) 63 | return 64 | 65 | # Stage scalar value. 66 | self._staged[key] = value 67 | 68 | def flush(self, step: int, dump_csv: bool = False): 69 | """Flushes the logged values to Tensorboard. 70 | 71 | Args: 72 | step: Training step. 73 | dump_csv: Whether to write the log to a CSV file. 74 | """ 75 | if self._writer is None: 76 | self._writer = tensorboard.SummaryWriter(log_dir=self.path) 77 | 78 | for key, value in self._staged.items(): 79 | self._writer.add_scalar(key, value, step) 80 | for key, img in self._images.items(): 81 | self._writer.add_images(key, img, step) 82 | for key, emb in self._embeddings.items(): 83 | self._writer.add_embedding(tag=f"{key}_{step}", **emb) 84 | self._writer.flush() 85 | 86 | self._flushed.update(self._staged) 87 | self._staged = {} 88 | self._images = {} 89 | self._embeddings = {} 90 | 91 | if dump_csv: 92 | if self._csv_writer is None: 93 | self._csv_file = open(self.path / "log.csv", "w") 94 | self._csv_writer = csv.DictWriter( 95 | self._csv_file, fieldnames=list(self._flushed.keys()) 96 | ) 97 | self._csv_writer.writeheader() 98 | assert self._csv_file is not None 99 | 100 | try: 101 | self._csv_writer.writerow(self._flushed) 102 | except ValueError: 103 | # Recreate csv headers. 104 | self._csv_file.close() 105 | self._csv_file = open(self.path / "log.csv", "w") 106 | self._csv_writer = csv.DictWriter( 107 | self._csv_file, fieldnames=list(self._flushed.keys()) 108 | ) 109 | self._csv_writer.writeheader() 110 | self._csv_writer.writerow(self._flushed) 111 | 112 | self._csv_file.flush() 113 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Mapping 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from generative_skill_chaining.utils import nest, typing 7 | 8 | 9 | METRIC_CHOICE_FNS = { 10 | "accuracy": max, 11 | "reward": max, 12 | "success": max, 13 | "loss": min, 14 | "l2_loss": min, 15 | "elbo_loss": min, 16 | "q_loss": min, 17 | "q1_loss": min, 18 | "q2_loss": min, 19 | "actor_loss": min, 20 | "alpha_loss": min, 21 | "entropy": min, 22 | "alpha": min, 23 | "target_q": max, 24 | "length": min, 25 | } 26 | 27 | METRIC_AGGREGATION_FNS: Dict[str, Callable[[np.ndarray], float]] = { 28 | "accuracy": np.mean, 29 | "reward": np.sum, 30 | "success": lambda x: x[-1], 31 | "loss": np.mean, 32 | "l2_loss": np.mean, 33 | "elbo_loss": np.mean, 34 | "q_loss": np.mean, 35 | "q1_loss": np.mean, 36 | "q2_loss": np.mean, 37 | "actor_loss": np.mean, 38 | "alpha_loss": np.mean, 39 | "entropy": np.mean, 40 | "alpha": np.mean, 41 | "target_q": np.mean, 42 | "length": np.sum, 43 | } 44 | 45 | 46 | def init_metric(metric: str) -> float: 47 | """Returns the initial value for the metric. 48 | 49 | Args: 50 | metric: Metric type. 51 | 52 | Returns: 53 | inf for min metrics, -inf for max metrics. 54 | """ 55 | a = -float("inf") 56 | b = float("inf") 57 | return b if METRIC_CHOICE_FNS[metric](a, b) == a else a 58 | 59 | 60 | def best_metric(metric: str, *values) -> float: 61 | """Returns the best metric value. 62 | 63 | Args: 64 | metric: Metric type. 65 | values: Values to compare. 66 | 67 | Returns: 68 | Min or max value depending on the metric. 69 | """ 70 | return METRIC_CHOICE_FNS[metric](*values) 71 | 72 | 73 | def aggregate_metric(metric: str, values: np.ndarray) -> float: 74 | """Aggregates the metric values. 75 | 76 | Args: 77 | metric: Metric type. 78 | values: Values to aggregate. 79 | 80 | Returns: 81 | Aggregated value. 82 | """ 83 | return METRIC_AGGREGATION_FNS[metric](values) 84 | 85 | 86 | def aggregate_metrics(metrics_list: List[Mapping[str, Any]]) -> Dict[str, float]: 87 | """Aggregates a list of metric value dicts. 88 | 89 | Args: 90 | metric_list: List of metric value dicts. 91 | 92 | Returns: 93 | Aggregated metric value dict. 94 | """ 95 | metrics = collect_metrics(metrics_list) 96 | aggregated_metrics = { 97 | metric: aggregate_metric(metric, values) for metric, values in metrics.items() 98 | } 99 | return aggregated_metrics 100 | 101 | 102 | def collect_metrics(metrics_list: List[Mapping[str, Any]]) -> Dict[str, np.ndarray]: 103 | """Transforms a list of metric value dicts to a dict of metric value arrays. 104 | 105 | Args: 106 | metric_list: List of metric value dicts. 107 | 108 | Returns: 109 | Dict of metric value arrays. 110 | """ 111 | 112 | def stack(*args): 113 | args = [arg for arg in args if arg is not None] 114 | if isinstance(args[0], torch.Tensor): 115 | return torch.stack(args, dim=0) 116 | return np.array(args) 117 | 118 | metrics = nest.map_structure( 119 | stack, *metrics_list, atom_type=(*typing.scalars, np.ndarray, torch.Tensor) 120 | ) 121 | return metrics 122 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/nest.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Generator, Iterator, Optional, Tuple, Type, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from generative_skill_chaining.utils import typing 7 | 8 | 9 | def map_structure( 10 | func: Callable, 11 | *args, 12 | atom_type: Union[Type, Tuple[Type, ...]] = ( 13 | torch.Tensor, 14 | np.ndarray, 15 | *typing.scalars, 16 | type(None), 17 | ), 18 | skip_type: Optional[Union[Type, Tuple[Type, ...]]] = None, 19 | ): 20 | """Applies the function over the nested structure atoms. 21 | 22 | Works like tensorflow.nest.map_structure(): 23 | https://www.tensorflow.org/api_docs/python/tf/nest/map_structure 24 | 25 | Args: 26 | func: Function applied to the atoms of *args. 27 | *args: Nested structure arguments of `func`. 28 | atom_type: Types considered to be atoms in the nested structure. 29 | skip_type: Types to be skipped and returned as-is in the nested structure. 30 | 31 | Returns: 32 | Results of func(*args_atoms) in the same nested structure as *args. 33 | """ 34 | arg_0 = args[0] 35 | if isinstance(arg_0, atom_type): 36 | return func(*args) 37 | elif skip_type is not None and isinstance(arg_0, skip_type): 38 | return arg_0 if len(args) == 1 else args 39 | elif isinstance(arg_0, dict): 40 | return { 41 | key: map_structure( 42 | func, 43 | *(arg[key] for arg in args), 44 | atom_type=atom_type, 45 | skip_type=skip_type, 46 | ) 47 | for key in arg_0 48 | } 49 | elif hasattr(arg_0, "__iter__"): 50 | iterable_class = type(arg_0) 51 | return iterable_class( 52 | map_structure(func, *args_i, atom_type=atom_type, skip_type=skip_type) 53 | for args_i in zip(*args) 54 | ) 55 | else: 56 | return arg_0 if len(args) == 1 else args 57 | 58 | 59 | def structure_iterator( 60 | structure, 61 | atom_type: Union[Type, Tuple[Type, ...]] = ( 62 | torch.Tensor, 63 | np.ndarray, 64 | *typing.scalars, 65 | type(None), 66 | ), 67 | skip_type: Optional[Union[Type, Tuple[Type, ...]]] = None, 68 | ) -> Iterator: 69 | """Provides an iterator over the atom values in the flattened nested structure. 70 | 71 | Args: 72 | structure: Nested structure. 73 | atom_type: Types considered to be atoms in the nested structure. 74 | skip_type: Types to be skipped and returned as-is in the nested structure. 75 | 76 | Returns: 77 | Iterator over the atom values in the flattened nested structure. 78 | """ 79 | 80 | def iterate_structure( 81 | structure, 82 | ) -> Generator: 83 | if isinstance(structure, atom_type): 84 | yield structure 85 | elif skip_type is not None and isinstance(structure, skip_type): 86 | pass 87 | elif isinstance(structure, dict): 88 | for val in structure.values(): 89 | for elem in iterate_structure(val): 90 | yield elem 91 | elif hasattr(structure, "__iter__"): 92 | for val in structure: 93 | for elem in iterate_structure(val): 94 | yield elem 95 | else: 96 | pass 97 | 98 | return iter(iterate_structure(structure)) 99 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def seed(n: Optional[int]) -> None: 9 | """Sets the random seed. 10 | 11 | Args: 12 | n: Optional seed. If None, no seed is set. 13 | """ 14 | if n is None: 15 | return 16 | 17 | torch.manual_seed(n) 18 | np.random.seed(n) 19 | random.seed(n) 20 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/recording.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import pathlib 3 | from typing import Any, Callable, Dict, List, Optional, Union 4 | 5 | import imageio 6 | import numpy as np 7 | 8 | 9 | class Recorder: 10 | def __init__(self, frequency: int = 1, max_size: Optional[int] = 1000): 11 | self.frequency = frequency 12 | self.max_size = max_size 13 | 14 | self._recordings: Dict[Any, List[np.ndarray]] = collections.defaultdict(list) 15 | self._buffer: Optional[List[np.ndarray]] = None 16 | self._timestep = 0 17 | 18 | def timestep(self) -> int: 19 | return self._timestep 20 | 21 | def is_recording(self) -> bool: 22 | return self._buffer is not None 23 | 24 | def start( 25 | self, 26 | prepend_id: Optional[str] = None, 27 | frequency: Optional[int] = None, 28 | ) -> None: 29 | """Starts recording. 30 | 31 | Existing frame buffer will be wiped out. 32 | 33 | Args: 34 | prepend_id: Upcoming recording will be prepended with the recording at this id. 35 | frequency: Recording frequency. 36 | """ 37 | prepend_buffer = ( 38 | [] if prepend_id is None else list(self._recordings[prepend_id]) 39 | ) 40 | self._buffer = prepend_buffer 41 | self._timestep = 0 42 | 43 | if frequency is not None: 44 | self.frequency = frequency 45 | 46 | def stop(self, save_id: str = "") -> bool: 47 | """Stops recording. 48 | 49 | Args: 50 | save_id: Saves the recording to this id. 51 | Returns: 52 | False if there is no recording to stop. 53 | """ 54 | if self._buffer is None or len(self._buffer) == 0: 55 | return False 56 | 57 | self._recordings[save_id] = self._buffer 58 | self._buffer = None 59 | 60 | return True 61 | 62 | def save(self, path: Union[str, pathlib.Path], reset: bool = True) -> bool: 63 | """Saves all the recordings. 64 | 65 | Args: 66 | path: Path for the recording. 67 | reset: Reset the recordings after saving. 68 | Returns: 69 | False if there were no recordings to save. 70 | """ 71 | if not isinstance(path, pathlib.Path): 72 | path = pathlib.Path(path) 73 | path.parent.mkdir(parents=True, exist_ok=True) 74 | 75 | num_saved = 0 76 | for id, recording in self._recordings.items(): 77 | if len(recording) == 0: 78 | continue 79 | 80 | if id is None or id == "": 81 | path_video = path 82 | else: 83 | path_video = path.parent / f"{path.stem}-{id}{path.suffix}" 84 | 85 | imageio.mimsave(path_video, recording) # type: ignore 86 | num_saved += 1 87 | 88 | if reset: 89 | self._recordings.clear() 90 | 91 | return num_saved > 0 92 | 93 | def add_frame( 94 | self, 95 | grab_frame_fn: Optional[Callable[[], np.ndarray]] = None, 96 | frame: Optional[np.ndarray] = None, 97 | override_frequency: bool = False, 98 | ) -> bool: 99 | """Adds a frame to the buffer. 100 | 101 | Args: 102 | grab_frame_fn: Callback function for grabbing a frame that is only 103 | called if a frame is needed. Use this if rendering is expensive. 104 | frame: Frame to add. 105 | override_frequency: Add a frame regardless of the frequency. 106 | Returns: 107 | True if a frame was added. 108 | """ 109 | self._timestep += 1 110 | 111 | if self._buffer is None: 112 | return False 113 | if self.max_size is not None and len(self._buffer) >= self.max_size: 114 | return False 115 | if not override_frequency and (self._timestep - 1) % self.frequency != 0: 116 | return False 117 | 118 | if grab_frame_fn is not None: 119 | frame = grab_frame_fn() 120 | elif frame is None: 121 | raise ValueError("One of grab_frame_fn or frame must not be None.") 122 | 123 | self._buffer.append(frame) 124 | 125 | return True 126 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/timing.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | from typing import Dict, List, Sequence 4 | 5 | import numpy as np 6 | 7 | 8 | class Timer: 9 | """Timer to keep track of timing intervals for different keys.""" 10 | 11 | def __init__(self): 12 | self._tics = {} 13 | 14 | def keys(self) -> Sequence[str]: 15 | """Timer keys.""" 16 | return self._tics.keys() 17 | 18 | def tic(self, key: str) -> float: 19 | """Starts timing for the given key. 20 | 21 | Args: 22 | key: Time interval key. 23 | 24 | Returns: 25 | Current time. 26 | """ 27 | self._tics[key] = time.time() 28 | return self._tics[key] 29 | 30 | def toc(self, key: str, set_tic: bool = False) -> float: 31 | """Returns the time elapsed since the last tic for the given key. 32 | 33 | Args: 34 | key: Time interval key. 35 | set_tic: Reset the tic to the current time. 36 | 37 | Returns: 38 | Time elapsed since the last tic. 39 | """ 40 | toc = time.time() 41 | tic = self._tics[key] 42 | if set_tic: 43 | self._tics[key] = toc 44 | return toc - tic 45 | 46 | 47 | class Profiler(Timer): 48 | """Profiler to keep track of average time interval for different keys.""" 49 | 50 | class ProfilerContext: 51 | """Context manager for timing code inside a `with` block.""" 52 | 53 | def __init__(self, profiler: "Profiler", key: str): 54 | self.profiler = profiler 55 | self.key = key 56 | 57 | def __enter__(self) -> float: 58 | return self.profiler.tic(self.key) 59 | 60 | def __exit__(self, type, value, traceback) -> None: 61 | self.profiler.toc(self.key) 62 | 63 | def __init__(self, disabled: bool = False): 64 | """Initializes the profiler with the given status. 65 | 66 | Args: 67 | disabled: Disable the profiler. 68 | """ 69 | super().__init__() 70 | self._disabled = disabled 71 | self._tictocs: Dict[str, List[float]] = collections.defaultdict(list) 72 | 73 | def disable(self) -> None: 74 | """Disables the profiler so that tic and toc do nothing.""" 75 | self._disabled = True 76 | 77 | def enable(self) -> None: 78 | """Enables the profiler.""" 79 | self._disabled = False 80 | 81 | def tic(self, key: str) -> float: 82 | """Starts timing for the given key. 83 | 84 | Args: 85 | key: Time interval key. 86 | 87 | Returns: 88 | Current time. 89 | """ 90 | if self._disabled: 91 | return 0.0 92 | return super().tic(key) 93 | 94 | def toc(self, key: str, set_tic: bool = False) -> float: 95 | """Returns the time elapsed since the last tic for the given key. 96 | 97 | Args: 98 | key: Time interval key. 99 | set_tic: Reset the tic to the current time. 100 | 101 | Returns: 102 | Time elapsed since the last tic. 103 | """ 104 | if self._disabled: 105 | return 0.0 106 | tictoc = super().toc(key, set_tic) 107 | self._tictocs[key].append(tictoc) 108 | return tictoc 109 | 110 | def profile(self, key: str) -> ProfilerContext: 111 | """Times the code inside a `with` block for the given key. 112 | 113 | Args: 114 | key: Time interval key. 115 | 116 | Returns: 117 | Profiler context. 118 | """ 119 | return Profiler.ProfilerContext(self, key) 120 | 121 | def compute_average(self, key: str, reset: bool = False) -> float: 122 | """Computes the average time interval for the given key. 123 | 124 | Args: 125 | key: Time interval key. 126 | reset: Reset the collected time intervals. 127 | 128 | Returns: 129 | Average time interval. 130 | """ 131 | mean = float(np.mean(self._tictocs[key])) 132 | if reset: 133 | self._tictocs[key] = [] 134 | return mean 135 | 136 | def compute_sum(self, key: str, reset: bool = False) -> float: 137 | """Computes the total time interval for the given key. 138 | 139 | Args: 140 | key: Time interval key. 141 | reset: Reset the collected time intervals. 142 | 143 | Returns: 144 | Total time interval. 145 | """ 146 | sum = float(np.sum(self._tictocs[key])) 147 | if reset: 148 | self._tictocs[key] = [] 149 | return sum 150 | 151 | def collect_profiles(self) -> Dict[str, float]: 152 | """Collects and resets the average time intervals for all keys. 153 | 154 | Returns: 155 | Dict mapping from key to average time interval. 156 | """ 157 | return { 158 | key: self.compute_average(key, reset=True) 159 | for key, tictoc in self._tictocs.items() 160 | if len(tictoc) > 0 161 | } 162 | -------------------------------------------------------------------------------- /generative_skill_chaining/utils/typing.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import pathlib 3 | from typing import Any, Dict, Generic, Mapping, Type, TypeVar, Union 4 | 5 | try: 6 | from typing import TypedDict 7 | except ModuleNotFoundError: 8 | from typing_extensions import TypedDict 9 | 10 | import numpy as np 11 | import torch 12 | 13 | Scalar = Union[np.generic, float, int, bool] 14 | scalars = (np.generic, float, int, bool) 15 | Tensor = Union[np.ndarray, torch.Tensor] 16 | 17 | 18 | ArrayType = TypeVar("ArrayType", np.ndarray, torch.Tensor) 19 | StateType = TypeVar("StateType") 20 | ObsType = TypeVar("ObsType") 21 | ActType = TypeVar("ActType") 22 | BatchType = TypeVar("BatchType", bound=Mapping) 23 | 24 | 25 | class Model(abc.ABC, Generic[BatchType]): 26 | @abc.abstractmethod 27 | def create_optimizers( 28 | self, 29 | optimizer_class: Type[torch.optim.Optimizer], 30 | optimizer_kwargs: Dict[str, Any], 31 | ) -> Dict[str, torch.optim.Optimizer]: 32 | pass 33 | 34 | @abc.abstractmethod 35 | def train_step( 36 | self, 37 | step: int, 38 | batch: BatchType, 39 | optimizers: Dict[str, torch.optim.Optimizer], 40 | schedulers: Dict[str, torch.optim.lr_scheduler._LRScheduler], 41 | ) -> Dict[str, Any]: 42 | pass 43 | 44 | @abc.abstractmethod 45 | def state_dict(self) -> Dict[str, Any]: 46 | pass 47 | 48 | @abc.abstractmethod 49 | def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True) -> None: 50 | pass 51 | 52 | def save(self, path: Union[str, pathlib.Path], name: str): 53 | """Saves a checkpoint of the model and the optimizers. 54 | 55 | Args: 56 | path: Directory of checkpoint. 57 | name: Name of checkpoint (saved as `path/name.pt`). 58 | """ 59 | torch.save(self.state_dict(), pathlib.Path(path) / f"{name}.pt") 60 | 61 | def load(self, checkpoint: Union[str, pathlib.Path], strict: bool = True) -> None: 62 | """Loads the model from the given checkpoint. 63 | 64 | Args: 65 | checkpoint: Checkpoint path. 66 | strict: Make sure the state dict keys match. 67 | """ 68 | try: 69 | device = self.device # type: ignore 70 | except AttributeError: 71 | device = None 72 | state_dict = torch.load(checkpoint, map_location=device) 73 | self.load_state_dict(state_dict) 74 | 75 | @abc.abstractmethod 76 | def train_mode(self) -> None: 77 | pass 78 | 79 | @abc.abstractmethod 80 | def eval_mode(self) -> None: 81 | pass 82 | 83 | @abc.abstractmethod 84 | def to(self, device): 85 | pass 86 | 87 | 88 | ModelType = TypeVar("ModelType", bound=Model) 89 | 90 | 91 | class Batch(TypedDict): 92 | observation: Tensor 93 | action: Tensor 94 | reward: Tensor 95 | next_observation: Tensor 96 | discount: Tensor 97 | policy_args: np.ndarray 98 | 99 | 100 | class WrappedBatch(Batch): 101 | idx_replay_buffer: np.ndarray 102 | 103 | 104 | class DynamicsBatch(TypedDict): 105 | observation: Tensor 106 | idx_policy: Tensor 107 | action: Tensor 108 | next_observation: Tensor 109 | policy_args: np.ndarray 110 | 111 | 112 | class StateBatch(TypedDict): 113 | state: Tensor 114 | observation: Tensor 115 | image: Tensor 116 | 117 | 118 | class AutoencoderBatch(TypedDict): 119 | observation: Tensor 120 | 121 | 122 | class StateEncoderBatch(TypedDict): 123 | observation: Tensor 124 | state: Tensor 125 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=40.8.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "generative_skill_chaining" 7 | version = "1.0.0" 8 | authors = [ 9 | { name = "Utkarsh A. Mishra", email = "umishra31@gatech.edu" }, 10 | ] 11 | description = "Long-Horizon Skill Planning with Diffusion Models." 12 | license = { file = "LICENSE" } 13 | readme = "README.md" 14 | requires-python = ">=3.7" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ] 20 | dependencies = [ 21 | # Computing 22 | "numpy", 23 | "scipy", 24 | "torch==1.11", 25 | # Image processing 26 | "imageio", 27 | "pillow", 28 | "scikit-image", 29 | # IO 30 | "pyyaml", 31 | "tensorboard", 32 | "tqdm", 33 | # Env 34 | "box2d-py", 35 | "gym>=0.25", 36 | "pybullet", 37 | # Ours 38 | "ctrlutils==1.4.1", 39 | "pysymbolic>=1.0.1", 40 | # scod-regression, 41 | "spatialdyn==1.4.4", 42 | ] 43 | 44 | [tool.setuptools.packages.find] 45 | include = ["generative_skill_chaining*"] 46 | 47 | [[tool.mypy.overrides]] 48 | module = [ 49 | "ctrlutils", 50 | "matplotlib.*", 51 | "pandas", 52 | "PIL", 53 | "pybullet", 54 | "redis.*", 55 | "scipy.*", 56 | "seaborn", 57 | "skimage", 58 | "shapely.*", 59 | "spatialdyn", 60 | "symbolic", 61 | "tqdm" 62 | ] 63 | ignore_missing_imports = true 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | appdirs==1.4.4 3 | async-timeout==4.0.2 4 | black==23.3.0 5 | box2d-py==2.3.8 6 | cachetools==5.3.0 7 | certifi==2023.5.7 8 | charset-normalizer==3.1.0 9 | click==8.1.3 10 | cloudpickle==2.2.1 11 | contourpy==1.0.7 12 | ctrlutils==1.4.1 13 | cycler==0.11.0 14 | docker-pycreds==0.4.0 15 | einops==0.6.1 16 | filelock==3.12.0 17 | flake8==6.0.0 18 | fonttools==4.39.3 19 | fsspec==2023.5.0 20 | functorch==0.1.1 21 | gitdb==4.0.10 22 | GitPython==3.1.31 23 | google-auth==2.17.3 24 | google-auth-oauthlib==1.0.0 25 | grpcio==1.54.0 26 | gym==0.26.2 27 | gym-notices==0.0.8 28 | h5py==3.8.0 29 | huggingface-hub==0.14.1 30 | idna==3.4 31 | imageio==2.28.1 32 | importlib-metadata==6.6.0 33 | importlib-resources==5.12.0 34 | kiwisolver==1.4.4 35 | lazy_loader==0.2 36 | Markdown==3.4.3 37 | MarkupSafe==2.1.2 38 | matplotlib==3.7.1 39 | mccabe==0.7.0 40 | mypy==1.2.0 41 | mypy-extensions==1.0.0 42 | networkx==3.1 43 | numpy==1.24.3 44 | oauthlib==3.2.2 45 | opencv-python==4.7.0.72 46 | packaging==23.1 47 | pathspec==0.11.1 48 | pathtools==0.1.2 49 | Pillow==9.5.0 50 | platformdirs==3.5.0 51 | protobuf==4.22.4 52 | psutil==5.9.5 53 | pyasn1==0.5.0 54 | pyasn1-modules==0.3.0 55 | pybullet==3.2.5 56 | pycodestyle==2.10.0 57 | pyflakes==3.0.1 58 | pygame==2.4.0 59 | pyparsing==3.0.9 60 | pysymbolic==1.0.2 61 | python-dateutil==2.8.2 62 | PyWavelets==1.4.1 63 | PyYAML==6.0 64 | redis==4.5.5 65 | requests==2.30.0 66 | requests-oauthlib==1.3.1 67 | rsa==4.9 68 | safetensors==0.3.1 69 | scikit-image==0.20.0 70 | scipy==1.9.1 71 | -e git+https://github.com/agiachris/scod-regression@060a5502caecd5fc1c7dba85cc1211f27c9a90c4#egg=scod_regression 72 | sentry-sdk==1.22.2 73 | setproctitle==1.3.2 74 | shapely==2.0.1 75 | six==1.16.0 76 | smmap==5.0.0 77 | spatialdyn==1.4.4 78 | -e git+https://github.com/UtkarshMishra04/generative-skill-chaining.git@47449862d5fcc5f76dae9932346a058897a26990#egg=generative_skill_chaining 79 | tensorboard==2.13.0 80 | tensorboard-data-server==0.7.0 81 | tifffile==2023.4.12 82 | timm==0.9.2 83 | tomli==2.0.1 84 | torch==1.11.0+cu113 85 | torchvision==0.10.1+cu111 86 | tqdm==4.65.0 87 | typing_extensions==4.5.0 88 | urllib3==1.26.15 89 | wandb==0.15.2 90 | Werkzeug==2.3.3 91 | zipp==3.15.0 92 | -------------------------------------------------------------------------------- /scripts/eval/eval_constrained_packing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | function run_cmd { 6 | echo "" 7 | echo "${CMD}" 8 | ${CMD} 9 | } 10 | 11 | function eval_tamp_diffusion { 12 | args="" 13 | args="${args} --env-config ${ENV_CONFIG}" 14 | if [ ${#POLICY_CHECKPOINTS[@]} -gt 0 ]; then 15 | args="${args} --policy-checkpoints ${POLICY_CHECKPOINTS[@]}" 16 | fi 17 | if [ ${#DIFFUSION_CHECKPOINTS[@]} -gt 0 ]; then 18 | args="${args} --diffusion-checkpoints ${DIFFUSION_CHECKPOINTS[@]}" 19 | fi 20 | args="${args} --seed ${SEED}" 21 | args="${args} --max-depth 4" 22 | args="${args} --timeout 10" 23 | args="${args} ${ENV_KWARGS}" 24 | if [[ $DEBUG -ne 0 ]]; then 25 | args="${args} --num-eval 10" 26 | args="${args} --path ${PLANNER_OUTPUT_PATH}_debug" 27 | args="${args} --verbose 1" 28 | else 29 | args="${args} --num-eval 50" 30 | args="${args} --path ${PLANNER_OUTPUT_PATH}" 31 | args="${args} --verbose 0" 32 | fi 33 | CMD="python scripts/eval/eval_constrained_packing.py ${args}" 34 | run_cmd 35 | } 36 | 37 | function run_planners { 38 | for planner in "${PLANNERS[@]}"; do 39 | 40 | POLICY_CHECKPOINTS=() 41 | for policy_env in "${POLICY_ENVS[@]}"; do 42 | POLICY_CHECKPOINTS+=("${POLICY_INPUT_PATH}/${policy_env}/${CKPT}.pt") 43 | done 44 | 45 | DIFFUSION_CHECKPOINTS=() 46 | for policy_env in "${POLICY_ENVS[@]}"; do 47 | DIFFUSION_CHECKPOINTS+=("diffusion_models/${exp_name}/${policy_env}/") 48 | done 49 | 50 | eval_tamp_diffusion 51 | done 52 | } 53 | 54 | SEED=100 55 | 56 | # Setup. 57 | DEBUG=0 58 | input_path="models" 59 | output_path="plots" 60 | 61 | # Evaluate planners. 62 | PLANNERS=( 63 | "diffusion" 64 | ) 65 | 66 | # Experiments. 67 | 68 | # Pybullet. 69 | exp_name="official" 70 | PLANNER_CONFIG_PATH="configs/pybullet/planners" 71 | ENVS=( 72 | "constrained_packing/task0" 73 | ) 74 | POLICY_ENVS=("pick" "place" "pull" "push") 75 | CKPT="best_model" 76 | ENV_KWARGS="--closed-loop 1" 77 | 78 | # Run planners. 79 | POLICY_INPUT_PATH="${input_path}/${exp_name}" 80 | 81 | for env in "${ENVS[@]}"; do 82 | ENV_CONFIG="configs/pybullet/envs/official/domains/${env}.yaml" 83 | PLANNER_OUTPUT_PATH="${output_path}/${exp_name}/tamp_experiment/${env}" 84 | run_planners 85 | done -------------------------------------------------------------------------------- /scripts/eval/eval_diffusion_transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | function run_cmd { 6 | echo "" 7 | echo "${CMD}" 8 | ${CMD} 9 | } 10 | 11 | function eval_diffusion { 12 | args="" 13 | args="${args} --checkpoint ${POLICY_CHECKPOINT}" 14 | args="${args} --diffusion-checkpoint ${DIFFUSION_CHECKPOINT}" 15 | args="${args} --env-config ${ENV_CONFIG}" 16 | args="${args} --seed ${SEED}" 17 | args="${args} ${ENV_KWARGS}" 18 | if [[ $DEBUG -ne 0 ]]; then 19 | args="${args} --path plots/${EXP_NAME}_debug" 20 | args="${args} --num-episodes 1" 21 | args="${args} --verbose 1" 22 | else 23 | args="${args} --path plots/${EXP_NAME}" 24 | args="${args} --verbose 0" 25 | args="${args} --num-episodes ${NUM_EPISODES}" 26 | fi 27 | if [[ -n "${DEBUG_RESULTS}" ]]; then 28 | args="${args} --debug-results ${DEBUG_RESULTS}" 29 | fi 30 | CMD="python scripts/eval/eval_diffusion_transformer.py ${args}" 31 | run_cmd 32 | } 33 | 34 | # Setup. 35 | 36 | DEBUG=0 37 | NUM_EPISODES=5 38 | 39 | # Evaluate policies. 40 | 41 | SEED=0 42 | 43 | policy_envs=( 44 | "pick" 45 | "place" 46 | "pull" 47 | "push" 48 | ) 49 | experiments=( 50 | "official" 51 | ) 52 | ckpts=( 53 | "best_model" 54 | ) 55 | 56 | for exp_name in "${experiments[@]}"; do 57 | for ckpt in "${ckpts[@]}"; do 58 | for policy_env in "${policy_envs[@]}"; do 59 | EXP_NAME="${exp_name}/${ckpt}" 60 | POLICY_CHECKPOINT="models/${exp_name}/${policy_env}/${ckpt}.pt" 61 | DIFFUSION_CHECKPOINT="diffusion_models/${exp_name}/${policy_env}/" 62 | ENV_CONFIG="configs/pybullet/envs/official/primitives/${policy_env}_eval.yaml" 63 | eval_diffusion 64 | done 65 | done 66 | done 67 | -------------------------------------------------------------------------------- /scripts/eval/eval_hook_reach.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | function run_cmd { 6 | echo "" 7 | echo "${CMD}" 8 | ${CMD} 9 | } 10 | 11 | function eval_tamp_diffusion { 12 | args="" 13 | args="${args} --env-config ${ENV_CONFIG}" 14 | if [ ${#POLICY_CHECKPOINTS[@]} -gt 0 ]; then 15 | args="${args} --policy-checkpoints ${POLICY_CHECKPOINTS[@]}" 16 | fi 17 | if [ ${#DIFFUSION_CHECKPOINTS[@]} -gt 0 ]; then 18 | args="${args} --diffusion-checkpoints ${DIFFUSION_CHECKPOINTS[@]}" 19 | fi 20 | args="${args} --seed ${SEED}" 21 | args="${args} --max-depth 4" 22 | args="${args} --timeout 10" 23 | args="${args} ${ENV_KWARGS}" 24 | if [[ $DEBUG -ne 0 ]]; then 25 | args="${args} --num-eval 10" 26 | args="${args} --path ${PLANNER_OUTPUT_PATH}_debug" 27 | args="${args} --verbose 1" 28 | else 29 | args="${args} --num-eval 50" 30 | args="${args} --path ${PLANNER_OUTPUT_PATH}" 31 | args="${args} --verbose 0" 32 | fi 33 | CMD="python scripts/eval/eval_hook_reach.py ${args}" 34 | run_cmd 35 | } 36 | 37 | function run_planners { 38 | for planner in "${PLANNERS[@]}"; do 39 | 40 | POLICY_CHECKPOINTS=() 41 | for policy_env in "${POLICY_ENVS[@]}"; do 42 | POLICY_CHECKPOINTS+=("${POLICY_INPUT_PATH}/${policy_env}/${CKPT}.pt") 43 | done 44 | 45 | DIFFUSION_CHECKPOINTS=() 46 | for policy_env in "${POLICY_ENVS[@]}"; do 47 | DIFFUSION_CHECKPOINTS+=("diffusion_models/${exp_name}/${policy_env}/") 48 | done 49 | 50 | eval_tamp_diffusion 51 | done 52 | } 53 | 54 | SEED=100 55 | 56 | # Setup. 57 | DEBUG=0 58 | input_path="models" 59 | output_path="plots" 60 | 61 | # Evaluate planners. 62 | PLANNERS=( 63 | "diffusion" 64 | ) 65 | 66 | # Experiments. 67 | 68 | # Pybullet. 69 | exp_name="official" 70 | PLANNER_CONFIG_PATH="configs/pybullet/planners" 71 | ENVS=( 72 | "hook_reach/task2" 73 | ) 74 | POLICY_ENVS=("pick" "place" "pull" "push") 75 | CKPT="best_model" 76 | ENV_KWARGS="--closed-loop 1" 77 | 78 | # Run planners. 79 | POLICY_INPUT_PATH="${input_path}/${exp_name}" 80 | 81 | for env in "${ENVS[@]}"; do 82 | ENV_CONFIG="configs/pybullet/envs/official/domains/${env}.yaml" 83 | PLANNER_OUTPUT_PATH="${output_path}/${exp_name}/tamp_experiment/${env}" 84 | run_planners 85 | done -------------------------------------------------------------------------------- /scripts/train/train_diffusion_transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | function run_cmd { 6 | echo "" 7 | echo "${CMD}" 8 | ${CMD} 9 | } 10 | 11 | function train_diffusion { 12 | args="" 13 | args="${args} --checkpoint ${POLICY_CHECKPOINT}" 14 | args="${args} --dataset-checkpoint ${DATASET_CHECKPOINT}" 15 | args="${args} --diffusion-checkpoint ${DIFFUSION_CHECKPOINT}" 16 | args="${args} --train-classifier" 17 | args="${args} --seed 0" 18 | args="${args} ${ENV_KWARGS}" 19 | if [[ $DEBUG -ne 0 ]]; then 20 | args="${args} --path plots/${EXP_NAME}_debug" 21 | args="${args} --num-episodes 1" 22 | args="${args} --verbose 1" 23 | else 24 | args="${args} --path plots/${EXP_NAME}" 25 | args="${args} --verbose 1" 26 | args="${args} --num-episodes ${NUM_EPISODES}" 27 | fi 28 | if [[ -n "${DEBUG_RESULTS}" ]]; then 29 | args="${args} --debug-results ${DEBUG_RESULTS}" 30 | fi 31 | if [[ -n "${ENV_CONFIG}" ]]; then 32 | args="${args} --env-config ${ENV_CONFIG}" 33 | fi 34 | CMD="python scripts/train/train_diffusion_transformer.py ${args}" 35 | run_cmd 36 | } 37 | 38 | # Setup. 39 | 40 | DEBUG=0 41 | NUM_EPISODES=50000 42 | 43 | # Evaluate policies. 44 | 45 | policy_envs=( 46 | "pick" 47 | "place" 48 | "pull" 49 | "push" 50 | ) 51 | experiments=( 52 | "official" 53 | ) 54 | ckpts=( 55 | "best_model" 56 | ) 57 | 58 | for exp_name in "${experiments[@]}"; do 59 | for ckpt in "${ckpts[@]}"; do 60 | for policy_env in "${policy_envs[@]}"; do 61 | EXP_NAME="${exp_name}/${ckpt}" 62 | POLICY_CHECKPOINT="models/${exp_name}/${policy_env}/${ckpt}.pt" 63 | DATASET_CHECKPOINT="datasets/${policy_env}.pkl" 64 | DIFFUSION_CHECKPOINT="diffusion_models/${exp_name}/${policy_env}/" 65 | train_diffusion 66 | done 67 | done 68 | done 69 | -------------------------------------------------------------------------------- /scripts/train/train_diffusion_transformer_w_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | function run_cmd { 6 | echo "" 7 | echo "${CMD}" 8 | ${CMD} 9 | } 10 | 11 | function train_diffusion { 12 | args="" 13 | args="${args} --checkpoint ${POLICY_CHECKPOINT}" 14 | args="${args} --dataset-checkpoint ${DATASET_CHECKPOINT}" 15 | args="${args} --diffusion-checkpoint ${DIFFUSION_CHECKPOINT}" 16 | args="${args} --train-classifier" 17 | args="${args} --seed 0" 18 | args="${args} ${ENV_KWARGS}" 19 | if [[ $DEBUG -ne 0 ]]; then 20 | args="${args} --path plots/${EXP_NAME}_debug" 21 | args="${args} --num-episodes 1" 22 | args="${args} --verbose 1" 23 | else 24 | args="${args} --path plots/${EXP_NAME}" 25 | args="${args} --verbose 1" 26 | args="${args} --num-episodes ${NUM_EPISODES}" 27 | fi 28 | if [[ -n "${DEBUG_RESULTS}" ]]; then 29 | args="${args} --debug-results ${DEBUG_RESULTS}" 30 | fi 31 | if [[ -n "${ENV_CONFIG}" ]]; then 32 | args="${args} --env-config ${ENV_CONFIG}" 33 | fi 34 | CMD="python scripts/train/train_diffusion_transformer_w_classifier.py ${args}" 35 | run_cmd 36 | } 37 | 38 | # Setup. 39 | 40 | DEBUG=0 41 | NUM_EPISODES=50000 42 | 43 | # Evaluate policies. 44 | 45 | policy_envs=( 46 | "pick" 47 | "place" 48 | "pull" 49 | "push" 50 | ) 51 | experiments=( 52 | "official" 53 | ) 54 | ckpts=( 55 | "best_model" 56 | ) 57 | 58 | for exp_name in "${experiments[@]}"; do 59 | for ckpt in "${ckpts[@]}"; do 60 | for policy_env in "${policy_envs[@]}"; do 61 | EXP_NAME="${exp_name}/${ckpt}" 62 | POLICY_CHECKPOINT="models/${exp_name}/${policy_env}/${ckpt}.pt" 63 | DATASET_CHECKPOINT="datasets/${policy_env}.pkl" 64 | DIFFUSION_CHECKPOINT="diffusion_models/${exp_name}/${policy_env}/" 65 | train_diffusion 66 | done 67 | done 68 | done 69 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Dummy setup.cfg for editable installs until setuptools supports PEP 660. 2 | 3 | [options] 4 | packages = generative_skill_chaining 5 | -------------------------------------------------------------------------------- /setup_shell.sh: -------------------------------------------------------------------------------- 1 | # Make sure we have the conda environment set up. 2 | CONDA_PATH=~/miniconda3/bin/activate 3 | ENV_NAME=generative_skill_chaining 4 | REPO_PATH=path/to/your/repo 5 | 6 | # Setup Conda 7 | source $CONDA_PATH 8 | conda activate $ENV_NAME 9 | cd $REPO_PATH 10 | 11 | unset DISPLAY # Make sure display is not set or it will prevent scripts from running in headless mode. 12 | 13 | # Try to import CUDA if we can. 14 | if [ -d "/usr/local/cuda-11.1" ]; then 15 | export PATH=/usr/local/cuda-11.1/bin:$PATH 16 | elif [ -d "/usr/local/cuda-11.0" ]; then 17 | export PATH=/usr/local/cuda-11.0/bin:$PATH 18 | elif [ -d "/usr/local/cuda-10.2" ]; then 19 | export PATH=/usr/local/cuda-10.2/bin:$PATH 20 | fi 21 | --------------------------------------------------------------------------------