├── state_diff
├── env
│ ├── kitchen
│ │ ├── relay_policy_learning
│ │ │ ├── adept_models
│ │ │ │ ├── __init__.py
│ │ │ │ ├── .gitignore
│ │ │ │ ├── kitchen
│ │ │ │ │ ├── meshes
│ │ │ │ │ │ ├── hood.stl
│ │ │ │ │ │ ├── knob.stl
│ │ │ │ │ │ ├── oven.stl
│ │ │ │ │ │ ├── tile.stl
│ │ │ │ │ │ ├── wall.stl
│ │ │ │ │ │ ├── faucet.stl
│ │ │ │ │ │ ├── handle2.stl
│ │ │ │ │ │ ├── kettle.stl
│ │ │ │ │ │ ├── micro.stl
│ │ │ │ │ │ ├── oventop.stl
│ │ │ │ │ │ ├── hingedoor.stl
│ │ │ │ │ │ ├── microdoor.stl
│ │ │ │ │ │ ├── microfeet.stl
│ │ │ │ │ │ ├── slidedoor.stl
│ │ │ │ │ │ ├── stoverim.stl
│ │ │ │ │ │ ├── burnerplate.stl
│ │ │ │ │ │ ├── cabinetbase.stl
│ │ │ │ │ │ ├── countertop.stl
│ │ │ │ │ │ ├── hingecabinet.stl
│ │ │ │ │ │ ├── hingehandle.stl
│ │ │ │ │ │ ├── kettlehandle.stl
│ │ │ │ │ │ ├── lightswitch.stl
│ │ │ │ │ │ ├── microbutton.stl
│ │ │ │ │ │ ├── microefeet.stl
│ │ │ │ │ │ ├── microhandle.stl
│ │ │ │ │ │ ├── microwindow.stl
│ │ │ │ │ │ ├── ovenhandle.stl
│ │ │ │ │ │ ├── ovenwindow.stl
│ │ │ │ │ │ ├── slidecabinet.stl
│ │ │ │ │ │ ├── cabinetdrawer.stl
│ │ │ │ │ │ ├── cabinethandle.stl
│ │ │ │ │ │ ├── burnerplate_mesh.stl
│ │ │ │ │ │ └── lightswitchbase.stl
│ │ │ │ │ ├── textures
│ │ │ │ │ │ ├── tile1.png
│ │ │ │ │ │ ├── wood1.png
│ │ │ │ │ │ ├── marble1.png
│ │ │ │ │ │ └── metal1.png
│ │ │ │ │ ├── microwave.xml
│ │ │ │ │ ├── oven.xml
│ │ │ │ │ ├── kettle.xml
│ │ │ │ │ ├── counters.xml
│ │ │ │ │ ├── hingecabinet.xml
│ │ │ │ │ ├── slidecabinet.xml
│ │ │ │ │ ├── assets
│ │ │ │ │ │ ├── backwall_chain.xml
│ │ │ │ │ │ ├── backwall_asset.xml
│ │ │ │ │ │ ├── slidecabinet_asset.xml
│ │ │ │ │ │ ├── hingecabinet_asset.xml
│ │ │ │ │ │ ├── kettle_chain.xml
│ │ │ │ │ │ ├── kettle_asset.xml
│ │ │ │ │ │ ├── microwave_asset.xml
│ │ │ │ │ │ ├── counters_asset.xml
│ │ │ │ │ │ ├── slidecabinet_chain.xml
│ │ │ │ │ │ ├── microwave_chain.xml
│ │ │ │ │ │ └── oven_asset.xml
│ │ │ │ │ └── kitchen.xml
│ │ │ │ ├── scenes
│ │ │ │ │ ├── textures
│ │ │ │ │ │ ├── white_marble_tile.png
│ │ │ │ │ │ └── white_marble_tile2.png
│ │ │ │ │ └── basic_scene.xml
│ │ │ │ ├── README.public.md
│ │ │ │ └── CONTRIBUTING.public.md
│ │ │ ├── third_party
│ │ │ │ └── franka
│ │ │ │ │ ├── franka_panda.png
│ │ │ │ │ ├── meshes
│ │ │ │ │ │ ├── visual
│ │ │ │ │ │ │ ├── hand.stl
│ │ │ │ │ │ │ ├── finger.stl
│ │ │ │ │ │ │ ├── link0.stl
│ │ │ │ │ │ │ ├── link1.stl
│ │ │ │ │ │ │ ├── link2.stl
│ │ │ │ │ │ │ ├── link3.stl
│ │ │ │ │ │ │ ├── link4.stl
│ │ │ │ │ │ │ ├── link5.stl
│ │ │ │ │ │ │ ├── link6.stl
│ │ │ │ │ │ │ └── link7.stl
│ │ │ │ │ │ └── collision
│ │ │ │ │ │ │ ├── hand.stl
│ │ │ │ │ │ │ ├── finger.stl
│ │ │ │ │ │ │ ├── link0.stl
│ │ │ │ │ │ │ ├── link1.stl
│ │ │ │ │ │ │ ├── link2.stl
│ │ │ │ │ │ │ ├── link3.stl
│ │ │ │ │ │ │ ├── link4.stl
│ │ │ │ │ │ │ ├── link5.stl
│ │ │ │ │ │ │ ├── link6.stl
│ │ │ │ │ │ │ └── link7.stl
│ │ │ │ │ ├── README.md
│ │ │ │ │ ├── assets
│ │ │ │ │ │ ├── basic_scene.xml
│ │ │ │ │ │ ├── teleop_actuator.xml
│ │ │ │ │ │ ├── actuator1.xml
│ │ │ │ │ │ ├── actuator0.xml
│ │ │ │ │ │ ├── assets.xml
│ │ │ │ │ │ └── chain0_overlay.xml
│ │ │ │ │ ├── franka_panda.xml
│ │ │ │ │ ├── franka_panda_teleop.xml
│ │ │ │ │ └── bi-franka_panda.xml
│ │ │ └── adept_envs
│ │ │ │ └── adept_envs
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── franka
│ │ │ │ │ │ └── __init__.py
│ │ │ │ │ ├── utils
│ │ │ │ │ │ ├── constants.py
│ │ │ │ │ │ └── config.py
│ │ │ │ │ └── simulation
│ │ │ │ │ │ └── module.py
│ │ ├── v0.py
│ │ ├── __init__.py
│ │ ├── kitchen_lowdim_wrapper.py
│ │ └── kitchen_util.py
│ ├── pusht
│ │ └── __init__.py
│ └── block_pushing
│ │ ├── assets
│ │ │ ├── plane.obj
│ │ │ ├── zone.urdf
│ │ │ ├── zone2.urdf
│ │ │ ├── workspace.urdf
│ │ │ ├── workspace_real.urdf
│ │ │ ├── block.urdf
│ │ │ ├── block2.urdf
│ │ │ ├── blocks
│ │ │ │ ├── red_moon.urdf
│ │ │ │ ├── blue_cube.urdf
│ │ │ │ ├── green_star.urdf
│ │ │ │ └── yellow_pentagon.urdf
│ │ │ ├── zone.obj
│ │ │ ├── insert.urdf
│ │ │ └── suction
│ │ │ │ ├── suction-base.urdf
│ │ │ │ ├── suction-head.urdf
│ │ │ │ ├── cylinder_real.urdf
│ │ │ │ ├── cylinder.urdf
│ │ │ │ └── suction-head-long.urdf
│ │ ├── oracles
│ │ │ ├── pushing_info.py
│ │ │ ├── reach_oracle.py
│ │ │ └── discontinuous_push_oracle.py
│ │ └── utils
│ │ │ └── pose3d.py
├── env_runner
│ └── base_lowdim_runner.py
├── model
│ ├── common
│ │ ├── module_attr_mixin.py
│ │ ├── shape_util.py
│ │ ├── dict_of_tensor_mixin.py
│ │ ├── lr_scheduler.py
│ │ ├── losses.py
│ │ └── rotation_transformer.py
│ └── diffusion
│ │ ├── positional_embedding.py
│ │ └── conv1d_components.py
├── common
│ ├── env_util.py
│ ├── nested_dict_util.py
│ ├── pymunk_util.py
│ ├── robomimic_config_util.py
│ ├── precise_sleep.py
│ ├── pytorch_util.py
│ └── checkpoint_util.py
├── scripts
│ ├── episode_lengths.py
│ ├── blockpush_abs_conversion.py
│ ├── bet_blockpush_conversion.py
│ ├── real_dataset_conversion.py
│ ├── generate_bet_blockpush.py
│ └── real_pusht_successrate.py
├── config
│ └── task
│ │ ├── kitchen_lowdim.yaml
│ │ ├── blockpush_lowdim_seed.yaml
│ │ ├── blockpush_lowdim_seed_abs.yaml
│ │ ├── blockpush_traj_lowdim_seed.yaml
│ │ ├── blockpush_traj_lowdim_seed_abs.yaml
│ │ ├── kitchen_lowdim_abs.yaml
│ │ ├── pushl_traj_lowdim.yaml
│ │ └── pushl_lowdim.yaml
├── shared_memory
│ └── shared_memory_util.py
├── gym_util
│ ├── video_wrapper.py
│ └── video_recording_wrapper.py
└── dataset
│ ├── base_dataset.py
│ ├── kitchen_lowdim_dataset.py
│ ├── pusht_dataset.py
│ └── pusht_traj_dataset.py
├── media
│ └── teaser_image.png
├── pyrightconfig.json
├── setup.py
├── tests
│ ├── test_cv2_util.py
│ ├── test_robomimic_lowdim_runner.py
│ ├── test_robomimic_image_runner.py
│ ├── test_block_pushing.py
│ ├── test_precise_sleep.py
│ ├── test_shared_queue.py
│ ├── test_replay_buffer.py
│ ├── test_multi_realsense.py
│ └── test_single_realsense.py
├── train.py
├── conda_environment.yaml
├── .gitignore
└── README.md
/state_diff/env/kitchen/relay_policy_learning/adept_models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/teaser_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/media/teaser_image.png
--------------------------------------------------------------------------------
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |     "exclude": [
3 |         "data/**",
4 |         "data_local/**",
5 |         "outputs/**"
6 |     ]
7 | }
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 |     name = 'state_diff',
5 |     packages=find_packages(),
6 | )
7 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | *.swp
4 | *.profraw
5 |
6 | # Editors
7 | .vscode
8 | .idea
9 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hood.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hood.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/knob.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/knob.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oven.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oven.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/tile.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/tile.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/wall.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/wall.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.png
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/faucet.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/faucet.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/handle2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/handle2.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettle.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/micro.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/micro.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oventop.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oventop.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/tile1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/tile1.png
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/wood1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/wood1.png
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingedoor.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingedoor.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microdoor.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microdoor.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microfeet.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microfeet.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidedoor.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidedoor.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/stoverim.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/stoverim.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/marble1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/marble1.png
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/metal1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/metal1.png
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/hand.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/hand.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetbase.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetbase.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/countertop.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/countertop.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingecabinet.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingecabinet.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingehandle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingehandle.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettlehandle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettlehandle.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitch.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitch.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microbutton.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microbutton.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microefeet.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microefeet.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microhandle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microhandle.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microwindow.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microwindow.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenhandle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenhandle.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenwindow.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenwindow.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidecabinet.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidecabinet.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link0.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link0.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link1.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link2.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link3.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link4.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link5.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link6.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link7.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link7.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetdrawer.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetdrawer.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinethandle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinethandle.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate_mesh.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate_mesh.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitchbase.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitchbase.stl
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile.png
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile2.png
--------------------------------------------------------------------------------
/state_diff/env/pusht/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 | import state_diff.env.pusht
3 |
4 | register(
5 |     id='pusht-keypoints-v0',
6 |     entry_point='envs.pusht.pusht_keypoints_env:PushTKeypointsEnv',
7 |     max_episode_steps=200,
8 |     reward_threshold=1.0
9 | )
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/README.md:
--------------------------------------------------------------------------------
1 | # franka
2 | Franka panda mujoco models
3 |
4 |
5 | # Environment
6 |
7 | franka_panda.xml | coming soon
8 | :-------------------------:|:-------------------------:
9 | ![franka_panda](franka_panda.png) | coming soon
10 |
--------------------------------------------------------------------------------
/state_diff/env_runner/base_lowdim_runner.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from state_diff.policy.base_lowdim_policy import BaseLowdimPolicy
3 |
4 | class BaseLowdimRunner:
5 |     def __init__(self, output_dir):
6 |         self.output_dir = output_dir
7 | 
8 |     def run(self, policy: BaseLowdimPolicy) -> Dict:
9 |         raise NotImplementedError()
10 |
--------------------------------------------------------------------------------
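
BaseLowdimRunner is the abstract interface that task-specific evaluation runners implement. A minimal sketch of a concrete subclass follows (the class name, constructor arguments, and returned keys are hypothetical, not part of the repo):

```python
# Hypothetical sketch: a concrete runner subclassing BaseLowdimRunner.
from typing import Dict
from state_diff.env_runner.base_lowdim_runner import BaseLowdimRunner

class DummyLowdimRunner(BaseLowdimRunner):
    def __init__(self, output_dir, n_episodes: int = 1):
        super().__init__(output_dir)
        self.n_episodes = n_episodes

    def run(self, policy) -> Dict:
        # A real runner would roll the policy out in an environment and
        # aggregate per-episode metrics; here we only return a placeholder log.
        return {'test/mean_score': 0.0, 'n_episodes': self.n_episodes}
```
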
/state_diff/env/kitchen/relay_policy_learning/adept_models/README.public.md:
--------------------------------------------------------------------------------
1 | # D'Suite Scenes
2 |
3 | This repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation
4 | scenes and common assets for D'Suite environments. Based on code in the ROBEL suite
5 | https://github.com/google-research/robel
6 |
7 | ## Disclaimer
8 |
9 | This is not an official Google product.
10 |
11 |
--------------------------------------------------------------------------------
/state_diff/model/common/module_attr_mixin.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class ModuleAttrMixin(nn.Module):
4 |     def __init__(self):
5 |         super().__init__()
6 |         self._dummy_variable = nn.Parameter()
7 | 
8 |     @property
9 |     def device(self):
10 |         return next(iter(self.parameters())).device
11 | 
12 |     @property
13 |     def dtype(self):
14 |         return next(iter(self.parameters())).dtype
15 |
--------------------------------------------------------------------------------
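
A brief usage sketch (assuming the package is importable): any nn.Module that inherits ModuleAttrMixin exposes `.device` and `.dtype`, read off its first parameter. The `TinyMLP` class below is illustrative only.

```python
import torch.nn as nn
from state_diff.model.common.module_attr_mixin import ModuleAttrMixin

class TinyMLP(ModuleAttrMixin):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 4)

model = TinyMLP()
# .device / .dtype are derived from the module's parameters.
print(model.device, model.dtype)  # e.g. cpu torch.float32
```
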
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/microwave.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/oven.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/kettle.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/counters.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/hingecabinet.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/plane.obj:
--------------------------------------------------------------------------------
1 | # Blender v2.66 (sub 1) OBJ File: ''
2 | # www.blender.org
3 | mtllib plane.mtl
4 | o Plane
5 | v 15.000000 -15.000000 0.000000
6 | v 15.000000 15.000000 0.000000
7 | v -15.000000 15.000000 0.000000
8 | v -15.000000 -15.000000 0.000000
9 |
10 | vt 15.000000 0.000000
11 | vt 15.000000 15.000000
12 | vt 0.000000 15.000000
13 | vt 0.000000 0.000000
14 |
15 | usemtl Material
16 | s off
17 | f 1/1 2/2 3/3
18 | f 1/1 3/3 4/4
19 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/slidecabinet.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/model/diffusion/positional_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | class SinusoidalPosEmb(nn.Module):
6 |     def __init__(self, dim):
7 |         super().__init__()
8 |         self.dim = dim
9 | 
10 |     def forward(self, x):
11 |         device = x.device
12 |         half_dim = self.dim // 2
13 |         emb = math.log(10000) / (half_dim - 1)
14 |         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
15 |         emb = x[:, None] * emb[None, :]
16 |         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
17 |         return emb
18 |
--------------------------------------------------------------------------------
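
A usage sketch (assuming torch and the package are importable): the module maps a batch of scalar positions, e.g. diffusion timesteps, to `dim`-dimensional sinusoidal features.

```python
import torch
from state_diff.model.diffusion.positional_embedding import SinusoidalPosEmb

pos_emb = SinusoidalPosEmb(dim=128)
timesteps = torch.tensor([0, 10, 99])  # (B,) scalar positions / diffusion steps
features = pos_emb(timesteps)          # (B, 128): concatenated sin and cos terms
print(features.shape)                  # torch.Size([3, 128])
```
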
/tests/test_cv2_util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | import numpy as np
9 | from state_diff.common.cv2_util import get_image_transform
10 |
11 |
12 | def test():
13 |     tf = get_image_transform((1280,720), (640,480), bgr_to_rgb=False)
14 |     in_img = np.zeros((720,1280,3), dtype=np.uint8)
15 |     out_img = tf(in_img)
16 |     # print(out_img.shape)
17 |     assert out_img.shape == (480,640,3)
18 | 
19 |     # import pdb; pdb.set_trace()
20 | 
21 | if __name__ == '__main__':
22 |     test()
23 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/zone.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/zone2.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/v0.py:
--------------------------------------------------------------------------------
1 | from state_diff.env.kitchen.base import KitchenBase
2 |
3 |
4 | class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):
5 |     TASK_ELEMENTS = ["microwave", "kettle", "bottom burner", "light switch"]
6 |     COMPLETE_IN_ANY_ORDER = False
7 | 
8 | 
9 | class KitchenMicrowaveKettleLightSliderV0(KitchenBase):
10 |     TASK_ELEMENTS = ["microwave", "kettle", "light switch", "slide cabinet"]
11 |     COMPLETE_IN_ANY_ORDER = False
12 | 
13 | 
14 | class KitchenKettleMicrowaveLightSliderV0(KitchenBase):
15 |     TASK_ELEMENTS = ["kettle", "microwave", "light switch", "slide cabinet"]
16 |     COMPLETE_IN_ANY_ORDER = False
17 | 
18 | 
19 | class KitchenAllV0(KitchenBase):
20 |     TASK_ELEMENTS = KitchenBase.ALL_TASKS
21 |
--------------------------------------------------------------------------------
/state_diff/common/env_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def render_env_video(env, states, actions=None):
6 |     observations = states
7 |     imgs = list()
8 |     for i in range(len(observations)):
9 |         state = observations[i]
10 |         env.set_state(state)
11 |         if i == 0:
12 |             env.set_state(state)
13 |         img = env.render()
14 |         # draw action
15 |         if actions is not None:
16 |             action = actions[i]
17 |             coord = (action / 512 * 96).astype(np.int32)
18 |             cv2.drawMarker(img, coord,
19 |                 color=(255,0,0), markerType=cv2.MARKER_CROSS,
20 |                 markerSize=8, thickness=1)
21 |         imgs.append(img)
22 |     imgs = np.array(imgs)
23 |     return imgs
24 |
--------------------------------------------------------------------------------
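
render_env_video only assumes the passed env exposes set_state(state) and render() returning an HxWx3 image; the marker math assumes PushT-style 512-unit world coordinates drawn onto 96-pixel frames. A sketch with a hypothetical stand-in env:

```python
import numpy as np
from state_diff.common.env_util import render_env_video

class DummyEnv:
    """Hypothetical stand-in implementing the interface render_env_video expects."""
    def set_state(self, state):
        self.state = state
    def render(self):
        return np.zeros((96, 96, 3), dtype=np.uint8)

states = np.zeros((5, 4))          # 5 steps of a 4-dim state
actions = np.full((5, 2), 256.0)   # actions in [0, 512) world coordinates
video = render_env_video(DummyEnv(), states, actions)
print(video.shape)                 # (5, 96, 96, 3)
```
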
/state_diff/model/common/shape_util.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Callable
2 | import torch
3 | import torch.nn as nn
4 |
5 | def get_module_device(m: nn.Module):
6 |     device = torch.device('cpu')
7 |     try:
8 |         param = next(iter(m.parameters()))
9 |         device = param.device
10 |     except StopIteration:
11 |         pass
12 |     return device
13 | 
14 | @torch.no_grad()
15 | def get_output_shape(
16 |         input_shape: Tuple[int],
17 |         net: Callable[[torch.Tensor], torch.Tensor]
18 |     ):
19 |     device = get_module_device(net)
20 |     test_input = torch.zeros((1,)+tuple(input_shape), device=device)
21 |     test_output = net(test_input)
22 |     output_shape = tuple(test_output.shape[1:])
23 |     return output_shape
24 |
--------------------------------------------------------------------------------
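
A usage sketch: get_output_shape probes a network with a zero tensor of shape (1, *input_shape) on the module's own device and reports the per-sample output shape.

```python
import torch.nn as nn
from state_diff.model.common.shape_util import get_output_shape

net = nn.Sequential(
    nn.Conv1d(in_channels=3, out_channels=16, kernel_size=5),
    nn.ReLU(),
    nn.Flatten(),
)
# input_shape excludes the batch dimension.
print(get_output_shape((3, 32), net))  # (448,) == 16 channels * 28 positions
```
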
/state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import adept_envs.franka
18 |
19 | from adept_envs.utils.configurable import global_config
20 |
--------------------------------------------------------------------------------
/state_diff/scripts/episode_lengths.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 |     import sys
3 |     import os
4 |     import pathlib
5 | 
6 |     ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 |     sys.path.append(ROOT_DIR)
8 | 
9 | import click
10 | import numpy as np
11 | import json
12 | from state_diff.common.replay_buffer import ReplayBuffer
13 | 
14 | @click.command()
15 | @click.option('--input', '-i', required=True)
16 | @click.option('--dt', default=0.1, type=float)
17 | def main(input, dt):
18 |     buffer = ReplayBuffer.create_from_path(input)
19 |     lengths = buffer.episode_lengths
20 |     durations = lengths * dt
21 |     result = {
22 |         'duration/mean': np.mean(durations)
23 |     }
24 | 
25 |     text = json.dumps(result, indent=2)
26 |     print(text)
27 | 
28 | if __name__ == '__main__':
29 |     main()
30 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/basic_scene.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/common/nested_dict_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | def nested_dict_map(f, x):
4 |     """
5 |     Map f over all leaf of nested dict x
6 |     """
7 | 
8 |     if not isinstance(x, dict):
9 |         return f(x)
10 |     y = dict()
11 |     for key, value in x.items():
12 |         y[key] = nested_dict_map(f, value)
13 |     return y
14 | 
15 | def nested_dict_reduce(f, x):
16 |     """
17 |     Map f over all values of nested dict x, and reduce to a single value
18 |     """
19 |     if not isinstance(x, dict):
20 |         return x
21 | 
22 |     reduced_values = list()
23 |     for value in x.values():
24 |         reduced_values.append(nested_dict_reduce(f, value))
25 |     y = functools.reduce(f, reduced_values)
26 |     return y
27 | 
28 | 
29 | def nested_dict_check(f, x):
30 |     bool_dict = nested_dict_map(f, x)
31 |     result = nested_dict_reduce(lambda x, y: x and y, bool_dict)
32 |     return result
33 |
--------------------------------------------------------------------------------
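
A usage sketch: nested_dict_map applies a function to every leaf, and nested_dict_check maps a predicate over the leaves and reduces the results with logical and.

```python
import numpy as np
from state_diff.common.nested_dict_util import nested_dict_map, nested_dict_check

batch = {
    'obs': {'pos': np.zeros((4, 2)), 'vel': np.ones((4, 2))},
    'action': np.zeros((4, 2)),
}

shapes = nested_dict_map(lambda x: x.shape, batch)
# {'obs': {'pos': (4, 2), 'vel': (4, 2)}, 'action': (4, 2)}

all_finite = nested_dict_check(lambda x: bool(np.isfinite(x).all()), batch)
print(shapes, all_finite)  # ... True
```
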
/state_diff/env/block_pushing/assets/workspace.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/workspace_real.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/config/task/kitchen_lowdim.yaml:
--------------------------------------------------------------------------------
1 | name: kitchen_lowdim
2 |
3 | obs_dim: 60
4 | action_dim: 9
5 | keypoint_dim: 3
6 | num_agents: 1
7 |
8 | dataset_dir: &dataset_dir data/kitchen
9 |
10 | env_runner:
11 |   _target_: state_diff.env_runner.kitchen_lowdim_runner.KitchenLowdimRunner
12 |   dataset_dir: *dataset_dir
13 |   n_train: 6
14 |   n_train_vis: 2
15 |   train_start_seed: 0
16 |   n_test: 50
17 |   n_test_vis: 4
18 |   test_start_seed: 100000
19 |   max_steps: 280
20 |   n_obs_steps: ${n_obs_steps}
21 |   n_action_steps: ${n_action_steps}
22 |   render_hw: [240, 360]
23 |   fps: 12.5
24 |   past_action: ${past_action_visible}
25 |   n_envs: null
26 | 
27 | dataset:
28 |   _target_: state_diff.dataset.kitchen_lowdim_dataset.KitchenLowdimDataset
29 |   dataset_dir: *dataset_dir
30 |   horizon: ${horizon}
31 |   pad_before: ${eval:'${n_obs_steps}-1'}
32 |   pad_after: ${eval:'${n_action_steps}-1'}
33 |   seed: 42
34 |   val_ratio: 0.02
35 |
--------------------------------------------------------------------------------
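
A minimal sketch of how such a task config is consumed (assumptions: omegaconf and hydra-core are installed, the `eval` resolver is registered the way the training entry point does it, the top-level keys such as `horizon` normally come from the parent training config, and `data/kitchen` exists locally):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# The ${eval:...} interpolations assume an "eval" resolver; registered here for the sketch.
OmegaConf.register_new_resolver('eval', eval, replace=True)

task_cfg = OmegaConf.load('state_diff/config/task/kitchen_lowdim.yaml')
# horizon / n_obs_steps / n_action_steps / past_action_visible are normally supplied by
# the parent training config; they are filled in by hand here purely for illustration.
root = OmegaConf.create({
    'horizon': 16, 'n_obs_steps': 2, 'n_action_steps': 8,
    'past_action_visible': False, 'task': task_cfg,
})
dataset = instantiate(root.task.dataset)
```
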
/state_diff/env/block_pushing/assets/block.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/block2.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/franka/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from gym.envs.registration import register
18 |
19 | # Relax the robot
20 | register(
21 |     id='kitchen_relax-v1',
22 |     entry_point='adept_envs.franka.kitchen_multitask_v0:KitchenTaskRelaxV1',
23 |     max_episode_steps=280,
24 | )
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/utils/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 |
19 | ENVS_ROOT_PATH = os.path.abspath(os.path.join(
20 |     os.path.dirname(os.path.abspath(__file__)),
21 |     "../../"))
22 |
23 | MODELS_PATH = os.path.abspath(os.path.join(ENVS_ROOT_PATH, "../adept_models/"))
24 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/blocks/red_moon.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/blocks/blue_cube.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/blocks/green_star.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/__init__.py:
--------------------------------------------------------------------------------
1 | """Environments using kitchen and Franka robot."""
2 | from gym.envs.registration import register
3 |
4 | register(
5 |     id="kitchen-microwave-kettle-light-slider-v0",
6 |     entry_point="state_diff.env.kitchen.v0:KitchenMicrowaveKettleLightSliderV0",
7 |     max_episode_steps=280,
8 |     reward_threshold=1.0,
9 | )
10 | 
11 | register(
12 |     id="kitchen-microwave-kettle-burner-light-v0",
13 |     entry_point="state_diff.env.kitchen.v0:KitchenMicrowaveKettleBottomBurnerLightV0",
14 |     max_episode_steps=280,
15 |     reward_threshold=1.0,
16 | )
17 | 
18 | register(
19 |     id="kitchen-kettle-microwave-light-slider-v0",
20 |     entry_point="state_diff.env.kitchen.v0:KitchenKettleMicrowaveLightSliderV0",
21 |     max_episode_steps=280,
22 |     reward_threshold=1.0,
23 | )
24 | 
25 | register(
26 |     id="kitchen-all-v0",
27 |     entry_point="state_diff.env.kitchen.v0:KitchenAllV0",
28 |     max_episode_steps=280,
29 |     reward_threshold=1.0,
30 | )
31 |
--------------------------------------------------------------------------------
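
A usage sketch (assuming the MuJoCo / adept_envs dependencies are installed): importing the package runs the register() calls above, after which the environments can be created through the standard gym factory.

```python
import gym
import state_diff.env.kitchen  # noqa: F401  (side effect: registers the kitchen-* env ids)

env = gym.make('kitchen-all-v0')
obs = env.reset()
```
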
/state_diff/scripts/blockpush_abs_conversion.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 |     import sys
3 |     import os
4 |     import pathlib
5 | 
6 |     ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 |     sys.path.append(ROOT_DIR)
8 | 
9 | import os
10 | import click
11 | import pathlib
12 | from state_diff.common.replay_buffer import ReplayBuffer
13 | 
14 | 
15 | @click.command()
16 | @click.option('-i', '--input', required=True)
17 | @click.option('-o', '--output', required=True)
18 | @click.option('-t', '--target_eef_idx', default=8, type=int)
19 | def main(input, output, target_eef_idx):
20 |     buffer = ReplayBuffer.copy_from_path(input)
21 |     obs = buffer['obs']
22 |     action = buffer['action']
23 |     prev_eef_target = obs[:,target_eef_idx:target_eef_idx+action.shape[1]]
24 |     next_eef_target = prev_eef_target + action
25 |     action[:] = next_eef_target
26 |     buffer.save_to_path(zarr_path=output, chunk_length=-1)
27 | 
28 | if __name__ == '__main__':
29 |     main()
30 |
--------------------------------------------------------------------------------
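
A toy numpy illustration of the conversion performed above: relative end-effector actions are turned into absolute targets by adding the previous target stored in the observation. The column layout (targets at indices 8-9) mirrors the script's default target_eef_idx=8; the numbers themselves are made up.

```python
import numpy as np

target_eef_idx = 8
obs = np.zeros((3, 16))
obs[:, target_eef_idx:target_eef_idx + 2] = [[0.0, 0.0], [0.1, 0.0], [0.2, 0.1]]
action = np.array([[0.1, 0.0], [0.1, 0.1], [0.0, 0.1]])  # relative eef actions

prev_eef_target = obs[:, target_eef_idx:target_eef_idx + action.shape[1]]
abs_action = prev_eef_target + action
print(abs_action)  # [[0.1 0. ] [0.2 0.1] [0.2 0.2]]
```
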
/state_diff/env/block_pushing/assets/blocks/yellow_pentagon.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/config/task/blockpush_lowdim_seed.yaml:
--------------------------------------------------------------------------------
1 | name: blockpush_lowdim_seed
2 |
3 | obs_dim: 16
4 | action_dim: 2
5 | keypoint_dim: 2
6 | obs_eef_target: True
7 | num_agents: 1
8 |
9 | env_runner:
10 |   _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner
11 |   n_train: 6
12 |   n_train_vis: 2
13 |   train_start_seed: 0
14 |   n_test: 50
15 |   n_test_vis: 4
16 |   test_start_seed: 100000
17 |   max_steps: 350
18 |   n_obs_steps: ${n_obs_steps}
19 |   n_action_steps: ${n_action_steps}
20 |   fps: 5
21 |   past_action: ${past_action_visible}
22 |   abs_action: False
23 |   obs_eef_target: ${task.obs_eef_target}
24 |   n_envs: null
25 | 
26 | dataset:
27 |   _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset
28 |   zarr_path: data/block_pushing/multimodal_push_seed.zarr
29 |   horizon: ${horizon}
30 |   pad_before: ${eval:'${n_obs_steps}-1'}
31 |   pad_after: ${eval:'${n_action_steps}-1'}
32 |   obs_eef_target: ${task.obs_eef_target}
33 |   use_manual_normalizer: False
34 |   seed: 42
35 |   val_ratio: 0.02
36 |
--------------------------------------------------------------------------------
/tests/test_robomimic_lowdim_runner.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | from state_diff.env_runner.robomimic_lowdim_runner import RobomimicLowdimRunner
9 |
10 | def test():
11 |     import os
12 |     from omegaconf import OmegaConf
13 |     cfg_path = os.path.expanduser('/projects/haonan2/bimanual_ur5e/state_diff/config/task/lift_lowdim.yaml')
14 |     cfg = OmegaConf.load(cfg_path)
15 |     cfg['n_obs_steps'] = 1
16 |     cfg['n_action_steps'] = 1
17 |     cfg['past_action_visible'] = False
18 |     runner_cfg = cfg['env_runner']
19 |     runner_cfg['n_train'] = 1
20 |     runner_cfg['n_test'] = 0
21 |     del runner_cfg['_target_']
22 |     runner = RobomimicLowdimRunner(
23 |         **runner_cfg,
24 |         output_dir='/tmp/test')
25 | 
26 |     # import pdb; pdb.set_trace()
27 | 
28 |     self = runner
29 |     env = self.env
30 |     env.seed(seeds=self.env_seeds)
31 |     obs = env.reset()
32 | 
33 | if __name__ == '__main__':
34 |     test()
35 |
--------------------------------------------------------------------------------
/state_diff/config/task/blockpush_lowdim_seed_abs.yaml:
--------------------------------------------------------------------------------
1 | name: blockpush_lowdim_seed_abs
2 |
3 | obs_dim: 16
4 | action_dim: 2
5 | keypoint_dim: 2
6 | obs_eef_target: True
7 | num_agents: 1
8 |
9 | env_runner:
10 |   _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner
11 |   n_train: 6
12 |   n_train_vis: 2
13 |   train_start_seed: 0
14 |   n_test: 50
15 |   n_test_vis: 4
16 |   test_start_seed: 100000
17 |   max_steps: 350
18 |   n_obs_steps: ${n_obs_steps}
19 |   n_action_steps: ${n_action_steps}
20 |   fps: 5
21 |   past_action: ${past_action_visible}
22 |   abs_action: True
23 |   obs_eef_target: ${task.obs_eef_target}
24 |   n_envs: null
25 | 
26 | dataset:
27 |   _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset
28 |   zarr_path: data/block_pushing/multimodal_push_seed_abs.zarr
29 |   horizon: ${horizon}
30 |   pad_before: ${eval:'${n_obs_steps}-1'}
31 |   pad_after: ${eval:'${n_action_steps}-1'}
32 |   obs_eef_target: ${task.obs_eef_target}
33 |   use_manual_normalizer: False
34 |   seed: 42
35 |   val_ratio: 0.02
36 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/backwall_chain.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/config/task/blockpush_traj_lowdim_seed.yaml:
--------------------------------------------------------------------------------
1 | name: blockpush_traj_lowdim_seed
2 |
3 | obs_dim: 16
4 | action_dim: 2
5 | keypoint_dim: 2
6 | inv_hidden_dim: 256
7 | obs_eef_target: True
8 | num_agents: 1
9 |
10 | env_runner:
11 |   _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner
12 |   n_train: 6
13 |   n_train_vis: 2
14 |   train_start_seed: 0
15 |   n_test: 50
16 |   n_test_vis: 4
17 |   test_start_seed: 100000
18 |   max_steps: 350
19 |   n_obs_steps: ${n_obs_steps}
20 |   n_action_steps: ${n_action_steps}
21 |   fps: 5
22 |   past_action: ${past_action_visible}
23 |   abs_action: False
24 |   obs_eef_target: ${task.obs_eef_target}
25 |   n_envs: null
26 |   num_agents: ${task.num_agents}
27 | 
28 | dataset:
29 |   _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset
30 |   zarr_path: data/block_pushing/multimodal_push_seed.zarr
31 |   horizon: ${horizon}
32 |   pad_before: ${eval:'${n_obs_steps}-1'}
33 |   pad_after: ${eval:'${n_action_steps}-1'}
34 |   obs_eef_target: ${task.obs_eef_target}
35 |   use_manual_normalizer: False
36 |   seed: 42
37 |   val_ratio: 0.02
38 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/backwall_asset.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/config/task/blockpush_traj_lowdim_seed_abs.yaml:
--------------------------------------------------------------------------------
1 | name: blockpush_traj_lowdim_seed_abs
2 |
3 | obs_dim: 16
4 | action_dim: 2
5 | keypoint_dim: 2
6 | inv_hidden_dim: 256
7 | obs_eef_target: True
8 | num_agents: 1
9 |
10 | env_runner:
11 | _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner
12 | n_train: 6
13 | n_train_vis: 2
14 | train_start_seed: 0
15 | n_test: 50
16 | n_test_vis: 4
17 | test_start_seed: 100000
18 | max_steps: 350
19 | n_obs_steps: ${n_obs_steps}
20 | n_action_steps: ${n_action_steps}
21 | fps: 5
22 | past_action: ${past_action_visible}
23 | abs_action: True
24 | obs_eef_target: ${task.obs_eef_target}
25 | n_envs: null
26 | num_agents: ${task.num_agents}
27 |
28 | dataset:
29 | _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset
30 | zarr_path: data/block_pushing/multimodal_push_seed_abs.zarr
31 | horizon: ${horizon}
32 | pad_before: ${eval:'${n_obs_steps}-1'}
33 | pad_after: ${eval:'${n_action_steps}-1'}
34 | obs_eef_target: ${task.obs_eef_target}
35 | use_manual_normalizer: False
36 | seed: 42
37 | val_ratio: 0.02
38 |
--------------------------------------------------------------------------------
/state_diff/config/task/kitchen_lowdim_abs.yaml:
--------------------------------------------------------------------------------
1 | name: kitchen_lowdim
2 |
3 | obs_dim: 60
4 | action_dim: 9
5 | keypoint_dim: 3
6 | num_agents: 1
7 |
8 | abs_action: True
9 | robot_noise_ratio: 0.1
10 |
11 | env_runner:
12 | _target_: state_diff.env_runner.kitchen_lowdim_runner.KitchenLowdimRunner
13 | dataset_dir: data/kitchen
14 | n_train: 6
15 | n_train_vis: 2
16 | train_start_seed: 0
17 | n_test: 50
18 | n_test_vis: 4
19 | test_start_seed: 100000
20 | max_steps: 280
21 | n_obs_steps: ${n_obs_steps}
22 | n_action_steps: ${n_action_steps}
23 | render_hw: [240, 360]
24 | fps: 12.5
25 | past_action: ${past_action_visible}
26 | abs_action: ${task.abs_action}
27 | robot_noise_ratio: ${task.robot_noise_ratio}
28 | n_envs: null
29 |
30 | dataset:
31 | _target_: state_diff.dataset.kitchen_mjl_lowdim_dataset.KitchenMjlLowdimDataset
32 | dataset_dir: data/kitchen/kitchen_demos_multitask
33 | horizon: ${horizon}
34 | pad_before: ${eval:'${n_obs_steps}-1'}
35 | pad_after: ${eval:'${n_action_steps}-1'}
36 | abs_action: ${task.abs_action}
37 | robot_noise_ratio: ${task.robot_noise_ratio}
38 | seed: 42
39 | val_ratio: 0.02
40 |
--------------------------------------------------------------------------------
/tests/test_robomimic_image_runner.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | from state_diff.env_runner.robomimic_image_runner import RobomimicImageRunner
9 |
10 | def test():
11 | import os
12 | from omegaconf import OmegaConf
13 | cfg_path = os.path.expanduser('~/dev/state_diff/state_diff/config/task/lift_image.yaml')
14 | cfg = OmegaConf.load(cfg_path)
15 | cfg['n_obs_steps'] = 1
16 | cfg['n_action_steps'] = 1
17 | cfg['past_action_visible'] = False
18 | runner_cfg = cfg['env_runner']
19 | runner_cfg['n_train'] = 1
20 | runner_cfg['n_test'] = 1
21 | del runner_cfg['_target_']
22 | runner = RobomimicImageRunner(
23 | **runner_cfg,
24 | output_dir='/tmp/test')
25 |
26 | # import pdb; pdb.set_trace()
27 |
28 | self = runner
29 | env = self.env
30 | env.seed(seeds=self.env_seeds)
31 | obs = env.reset()
32 | for i in range(10):
33 | _ = env.step(env.action_space.sample())
34 |
35 | imgs = env.render()
36 |
37 | if __name__ == '__main__':
38 | test()
39 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/slidecabinet_asset.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | Training:
4 | python train.py --config-name=train_diffusion_trajectory_unet_lowdim_workspace
5 | """
6 |
7 | import sys
8 | # use line-buffering for both stdout and stderr
9 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
10 | sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
11 |
12 | import hydra
13 | from omegaconf import OmegaConf
14 | import pathlib
15 | from state_diff.workspace.base_workspace import BaseWorkspace
16 |
17 | # allows arbitrary python code execution in configs using the ${eval:''} resolver
18 | OmegaConf.register_new_resolver("eval", eval, replace=True)
19 |
20 | @hydra.main(
21 | version_base=None,
22 | config_path=str(pathlib.Path(__file__).parent.joinpath(
23 | 'state_diff','config'),
24 | ),
25 | config_name="train_diffusion_trajectory_unet_lowdim_workspace"
26 | )
27 | def main(cfg: OmegaConf):
28 | # resolve immediately so all the ${now:} resolvers
29 | # will use the same time.
30 | OmegaConf.resolve(cfg)
31 |
32 | cls = hydra.utils.get_class(cfg._target_)
33 | workspace: BaseWorkspace = cls(cfg)
34 | workspace.run()
35 |
36 | if __name__ == "__main__":
37 | main()
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/CONTRIBUTING.public.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/state_diff/shared_memory/shared_memory_util.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from dataclasses import dataclass
3 | import numpy as np
4 | from multiprocessing.managers import SharedMemoryManager
5 | from atomics import atomicview, MemoryOrder, UINT
6 |
7 | @dataclass
8 | class ArraySpec:
9 | name: str
10 | shape: Tuple[int]
11 | dtype: np.dtype
12 |
13 |
14 | class SharedAtomicCounter:
15 | def __init__(self,
16 | shm_manager: SharedMemoryManager,
17 | size :int=8 # 64bit int
18 | ):
19 | shm = shm_manager.SharedMemory(size=size)
20 | self.shm = shm
21 | self.size = size
22 | self.store(0) # initialize
23 |
24 | @property
25 | def buf(self):
26 | return self.shm.buf[:self.size]
27 |
28 | def load(self) -> int:
29 | with atomicview(buffer=self.buf, atype=UINT) as a:
30 | value = a.load(order=MemoryOrder.ACQUIRE)
31 | return value
32 |
33 | def store(self, value: int):
34 | with atomicview(buffer=self.buf, atype=UINT) as a:
35 | a.store(value, order=MemoryOrder.RELEASE)
36 |
37 | def add(self, value: int):
38 | with atomicview(buffer=self.buf, atype=UINT) as a:
39 | a.add(value, order=MemoryOrder.ACQ_REL)
40 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/oracles/pushing_info.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Reach ML Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Dataclass holding info needed for pushing oracles."""
17 | import dataclasses
18 | from typing import Any
19 |
20 |
21 | @dataclasses.dataclass
22 | class PushingInfo:
23 | """Holds onto info necessary for pushing state machine."""
24 |
25 | xy_block: Any = None
26 | xy_ee: Any = None
27 | xy_pre_block: Any = None
28 | xy_delta_to_nexttoblock: Any = None
29 | xy_delta_to_touchingblock: Any = None
30 | xy_dir_block_to_ee: Any = None
31 | theta_threshold_to_orient: Any = None
32 | theta_threshold_flat_enough: Any = None
33 | theta_error: Any = None
34 | obstacle_poses: Any = None
35 | distance_to_target: Any = None
36 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/teleop_actuator.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/model/diffusion/conv1d_components.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from einops.layers.torch import Rearrange
5 |
6 |
7 | class Downsample1d(nn.Module):
8 | def __init__(self, dim):
9 | super().__init__()
10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
11 |
12 | def forward(self, x):
13 | return self.conv(x)
14 |
15 | class Upsample1d(nn.Module):
16 | def __init__(self, dim):
17 | super().__init__()
18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
19 |
20 | def forward(self, x):
21 | return self.conv(x)
22 |
23 | class Conv1dBlock(nn.Module):
24 | '''
25 | Conv1d --> GroupNorm --> Mish
26 | '''
27 |
28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
29 | super().__init__()
30 |
31 | self.block = nn.Sequential(
32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'),
34 | nn.GroupNorm(n_groups, out_channels),
35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'),
36 | nn.Mish(),
37 | )
38 |
39 | def forward(self, x):
40 | return self.block(x)
41 |
42 |
43 | def test():
44 | cb = Conv1dBlock(256, 128, kernel_size=3)
45 | x = torch.zeros((1,256,16))
46 | o = cb(x)
47 |
--------------------------------------------------------------------------------
/state_diff/gym_util/video_wrapper.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 |
4 | class VideoWrapper(gym.Wrapper):
5 | def __init__(self,
6 | env,
7 | mode='rgb_array',
8 | enabled=True,
9 | steps_per_render=1,
10 | **kwargs
11 | ):
12 | super().__init__(env)
13 |
14 | self.mode = mode
15 | self.enabled = enabled
16 | self.render_kwargs = kwargs
17 | self.steps_per_render = steps_per_render
18 |
19 | self.frames = list()
20 | self.step_count = 0
21 |
22 | def reset(self, **kwargs):
23 | obs = super().reset(**kwargs)
24 | self.frames = list()
25 | self.step_count = 1
26 | if self.enabled:
27 | frame = self.env.render(
28 | mode=self.mode, **self.render_kwargs)
29 | assert frame.dtype == np.uint8
30 | self.frames.append(frame)
31 | return obs
32 |
33 | def step(self, action):
34 | result = super().step(action)
35 | self.step_count += 1
36 | if self.enabled and ((self.step_count % self.steps_per_render) == 0):
37 | frame = self.env.render(
38 | mode=self.mode, **self.render_kwargs)
39 | assert frame.dtype == np.uint8
40 | self.frames.append(frame)
41 | return result
42 |
43 | def render(self, mode='rgb_array', **kwargs):
44 | return self.frames
45 |
--------------------------------------------------------------------------------
/tests/test_block_pushing.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | from state_diff.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
9 | from gym.wrappers import FlattenObservation
10 | from state_diff.gym_util.multistep_wrapper import MultiStepWrapper
11 | from state_diff.gym_util.video_wrapper import VideoWrapper
12 |
13 | def test():
14 | env = MultiStepWrapper(
15 | VideoWrapper(
16 | FlattenObservation(
17 | BlockPushMultimodal()
18 | ),
19 | enabled=True,
20 | steps_per_render=2
21 | ),
22 | n_obs_steps=2,
23 | n_action_steps=8,
24 | max_episode_steps=16
25 | )
26 | env = BlockPushMultimodal()
27 | obs = env.reset()
28 | import pdb; pdb.set_trace()
29 |
30 | env = FlattenObservation(BlockPushMultimodal())
31 | obs = env.reset()
32 | action = env.action_space.sample()
33 | next_obs, reward, done, info = env.step(action)
34 | print(obs[8:10] + action - next_obs[8:10])
35 | import pdb; pdb.set_trace()
36 |
37 | for i in range(3):
38 | obs, reward, done, info = env.step(env.action_space.sample())
39 | img = env.render()
40 | import pdb; pdb.set_trace()
41 | print("Done!", done)
42 |
43 | if __name__ == '__main__':
44 | test()
45 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/hingecabinet_asset.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/config/task/pushl_traj_lowdim.yaml:
--------------------------------------------------------------------------------
1 | name: pushl_traj_lowdim
2 |
3 | obs_dim: 40 # 40 for single agent (9*2 + 10*2 keypoints + 2 state), 42 for double agent (9*2 + 10*2 keypoints + 4 state)
4 | action_dim: 2
5 | keypoint_dim: 2
6 | inv_hidden_dim: 256
7 | num_agents: 1
8 |
9 | obs_key: keypoint # keypoint, structured_keypoint
10 | state_key: state
11 | action_key: action # action1,action2
12 |
13 | env_runner:
14 | _target_: state_diff.env_runner.push_keypoints_runner.PushKeypointsRunner
15 | keypoint_visible_rate: ${keypoint_visible_rate}
16 | n_train: 6
17 | n_train_vis: 2
18 | train_start_seed: 0
19 | n_test: 50
20 | n_test_vis: 4
21 | legacy_test: True
22 | test_start_seed: 100000
23 | max_steps: 500
24 | n_obs_steps: ${n_obs_steps}
25 | n_action_steps: ${n_action_steps}
26 | n_latency_steps: ${n_latency_steps}
27 | fps: 10
28 | agent_keypoints: False
29 | past_action: ${past_action_visible}
30 | n_envs: null
31 | obs_key: ${task.obs_key}
32 | state_key: ${task.state_key}
33 | action_key: ${task.action_key}
34 | env_type: ${task.name}
35 |
36 | dataset:
37 | _target_: state_diff.dataset.pusht_traj_dataset.PushTTrajLowdimDataset
38 | zarr_path: data/pushl_dataset
39 | horizon: ${horizon}
40 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
41 | pad_after: ${eval:'${n_action_steps}-1'}
42 | seed: 42
43 | val_ratio: 0.02
44 | max_train_episodes: null
45 | obs_key: ${task.obs_key}
46 | state_key: ${task.state_key}
47 | action_key: ${task.action_key}
48 | num_episodes: null
49 |
50 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/zone.obj:
--------------------------------------------------------------------------------
1 | # Object Export From Tinkercad Server 2015
2 |
3 | mtllib obj.mtl
4 |
5 | o obj_0
6 | v 10 -10 20
7 | v 10 -10 0
8 | v 10 10 0
9 | v 10 10 20
10 | v 9.002 9.003 20
11 | v 9.002 -9.002 20
12 | v -10 10 0
13 | v -10 10 20
14 | v -9.003 9.003 20
15 | v -9.003 9.003 0
16 | v 9.002 9.003 0
17 | v 9.002 -9.002 0
18 | v -9.003 -9.002 0
19 | v -9.003 -9.002 20
20 | v -10 -10 0
21 | v -10 -10 20
22 | # 16 vertices
23 |
24 | g group_0_15277357
25 |
26 | usemtl color_15277357
27 | s 0
28 |
29 | f 1 2 3
30 | f 1 3 4
31 | f 4 5 6
32 | f 4 6 1
33 | f 9 10 11
34 | f 9 11 5
35 | f 6 12 13
36 | f 6 13 14
37 | f 10 9 14
38 | f 10 14 13
39 | f 7 10 13
40 | f 7 13 15
41 | f 4 8 5
42 | f 9 5 8
43 | f 8 7 15
44 | f 8 15 16
45 | f 10 7 11
46 | f 3 11 7
47 | f 11 3 12
48 | f 2 12 3
49 | f 14 16 6
50 | f 1 6 16
51 | f 16 15 2
52 | f 16 2 1
53 | f 9 8 14
54 | f 16 14 8
55 | f 7 8 3
56 | f 4 3 8
57 | f 2 15 12
58 | f 13 12 15
59 | f 12 6 5
60 | f 12 5 11
61 | # 32 faces
62 |
63 | #end of obj_0
64 |
65 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/kettle_chain.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/model/common/dict_of_tensor_mixin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DictOfTensorMixin(nn.Module):
5 | def __init__(self, params_dict=None):
6 | super().__init__()
7 | if params_dict is None:
8 | params_dict = nn.ParameterDict()
9 | self.params_dict = params_dict
10 |
11 | @property
12 | def device(self):
13 | return next(iter(self.parameters())).device
14 |
15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
16 | def dfs_add(dest, keys, value: torch.Tensor):
17 | if len(keys) == 1:
18 | dest[keys[0]] = value
19 | return
20 |
21 | if keys[0] not in dest:
22 | dest[keys[0]] = nn.ParameterDict()
23 | dfs_add(dest[keys[0]], keys[1:], value)
24 |
25 | def load_dict(state_dict, prefix):
26 | out_dict = nn.ParameterDict()
27 | for key, value in state_dict.items():
28 | value: torch.Tensor
29 | if key.startswith(prefix):
30 | param_keys = key[len(prefix):].split('.')[1:]
31 | # if len(param_keys) == 0:
32 | # import pdb; pdb.set_trace()
33 | dfs_add(out_dict, param_keys, value.clone())
34 | return out_dict
35 |
36 | self.params_dict = load_dict(state_dict, prefix + 'params_dict')
37 | self.params_dict.requires_grad_(False)
38 | return
39 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/kettle_asset.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator1.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn
5 | from state_diff.model.common.normalizer import LinearNormalizer
6 |
7 | class BaseLowdimDataset(torch.utils.data.Dataset):
8 | def get_validation_dataset(self) -> 'BaseLowdimDataset':
9 | # return an empty dataset by default
10 | return BaseLowdimDataset()
11 |
12 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
13 | raise NotImplementedError()
14 |
15 | def get_all_actions(self) -> torch.Tensor:
16 | raise NotImplementedError()
17 |
18 | def __len__(self) -> int:
19 | return 0
20 |
21 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
22 | """
23 | output:
24 | obs: T, Do
25 | action: T, Da
26 | """
27 | raise NotImplementedError()
28 |
29 |
30 | class BaseImageDataset(torch.utils.data.Dataset):
31 |     def get_validation_dataset(self) -> 'BaseImageDataset':
32 | # return an empty dataset by default
33 | return BaseImageDataset()
34 |
35 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
36 | raise NotImplementedError()
37 |
38 | def get_all_actions(self) -> torch.Tensor:
39 | raise NotImplementedError()
40 |
41 | def __len__(self) -> int:
42 | return 0
43 |
44 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
45 | """
46 | output:
47 | obs:
48 | key: T, *
49 | action: T, Da
50 | """
51 | raise NotImplementedError()
52 |
--------------------------------------------------------------------------------
/tests/test_precise_sleep.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | import time
9 | import numpy as np
10 | from state_diff.common.precise_sleep import precise_sleep, precise_wait
11 |
12 |
13 | def test_sleep():
14 | dt = 0.1
15 | tol = 1e-3
16 | time_samples = list()
17 | for i in range(100):
18 | precise_sleep(dt)
19 | # time.sleep(dt)
20 | time_samples.append(time.monotonic())
21 | time_deltas = np.diff(time_samples)
22 |
23 | from matplotlib import pyplot as plt
24 | plt.plot(time_deltas)
25 | plt.ylim((dt-tol,dt+tol))
26 |
27 |
28 | def test_wait():
29 | dt = 0.1
30 | tol = 1e-3
31 | errors = list()
32 | t_start = time.monotonic()
33 | for i in range(1,100):
34 | t_end_desired = t_start + i * dt
35 | time.sleep(t_end_desired - time.monotonic())
36 | t_end = time.monotonic()
37 | errors.append(t_end - t_end_desired)
38 |
39 | new_errors = list()
40 | t_start = time.monotonic()
41 | for i in range(1,100):
42 | t_end_desired = t_start + i * dt
43 | precise_wait(t_end_desired)
44 | t_end = time.monotonic()
45 | new_errors.append(t_end - t_end_desired)
46 |
47 | from matplotlib import pyplot as plt
48 | plt.plot(errors, label='time.sleep')
49 | plt.plot(new_errors, label='sleep/spin hybrid')
50 | plt.ylim((-tol,+tol))
51 | plt.title('0.1 sec sleep error')
52 | plt.legend()
53 |
54 |
55 | if __name__ == '__main__':
56 | test_sleep()
57 |
--------------------------------------------------------------------------------
/state_diff/common/pymunk_util.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import pymunk
3 | import pymunk.pygame_util
4 | import numpy as np
5 |
6 | COLLTYPE_DEFAULT = 0
7 | COLLTYPE_MOUSE = 1
8 | COLLTYPE_BALL = 2
9 |
10 | def get_body_type(static=False):
11 | body_type = pymunk.Body.DYNAMIC
12 | if static:
13 | body_type = pymunk.Body.STATIC
14 | return body_type
15 |
16 |
17 | def create_rectangle(space,
18 | pos_x,pos_y,width,height,
19 | density=3,static=False):
20 | body = pymunk.Body(body_type=get_body_type(static))
21 | body.position = (pos_x,pos_y)
22 | shape = pymunk.Poly.create_box(body,(width,height))
23 | shape.density = density
24 | space.add(body,shape)
25 | return body, shape
26 |
27 |
28 | def create_rectangle_bb(space,
29 | left, bottom, right, top,
30 | **kwargs):
31 | pos_x = (left + right) / 2
32 | pos_y = (top + bottom) / 2
33 | height = top - bottom
34 | width = right - left
35 | return create_rectangle(space, pos_x, pos_y, width, height, **kwargs)
36 |
37 | def create_circle(space, pos_x, pos_y, radius, density=3, static=False):
38 | body = pymunk.Body(body_type=get_body_type(static))
39 | body.position = (pos_x, pos_y)
40 | shape = pymunk.Circle(body, radius=radius)
41 | shape.density = density
42 | shape.collision_type = COLLTYPE_BALL
43 | space.add(body, shape)
44 | return body, shape
45 |
46 | def get_body_state(body):
47 | state = np.zeros(6, dtype=np.float32)
48 | state[:2] = body.position
49 | state[2] = body.angle
50 | state[3:5] = body.velocity
51 | state[5] = body.angular_velocity
52 | return state
53 |
--------------------------------------------------------------------------------
/state_diff/scripts/bet_blockpush_conversion.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 |
10 | import os
11 | import click
12 | import pathlib
13 | import numpy as np
14 | from state_diff.common.replay_buffer import ReplayBuffer
15 |
16 | @click.command()
17 | @click.option('-i', '--input', required=True, help='input dir contains npy files')
18 | @click.option('-o', '--output', required=True, help='output zarr path')
19 | @click.option('--abs_action', is_flag=True, default=False)
20 | def main(input, output, abs_action):
21 | data_directory = pathlib.Path(input)
22 | observations = np.load(
23 | data_directory / "multimodal_push_observations.npy"
24 | )
25 | actions = np.load(data_directory / "multimodal_push_actions.npy")
26 | masks = np.load(data_directory / "multimodal_push_masks.npy")
27 |
28 | buffer = ReplayBuffer.create_empty_numpy()
29 | for i in range(len(masks)):
30 | eps_len = int(masks[i].sum())
31 | obs = observations[i,:eps_len].astype(np.float32)
32 | action = actions[i,:eps_len].astype(np.float32)
33 | if abs_action:
34 | prev_eef_target = obs[:,8:10]
35 | next_eef_target = prev_eef_target + action
36 | action = next_eef_target
37 | data = {
38 | 'obs': obs,
39 | 'action': action
40 | }
41 | buffer.add_episode(data)
42 |
43 | buffer.save_to_path(zarr_path=output, chunk_length=-1)
44 |
45 | if __name__ == '__main__':
46 | main()
47 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/kitchen_lowdim_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional
2 | import numpy as np
3 | import gym
4 | from gym.spaces import Box
5 | from state_diff.env.kitchen.base import KitchenBase
6 |
7 | class KitchenLowdimWrapper(gym.Env):
8 | def __init__(self,
9 | env: KitchenBase,
10 | init_qpos: Optional[np.ndarray]=None,
11 | init_qvel: Optional[np.ndarray]=None,
12 | render_hw = (240,360)
13 | ):
14 | self.env = env
15 | self.init_qpos = init_qpos
16 | self.init_qvel = init_qvel
17 | self.render_hw = render_hw
18 |
19 | @property
20 | def action_space(self):
21 | return self.env.action_space
22 |
23 | @property
24 | def observation_space(self):
25 | return self.env.observation_space
26 |
27 | def seed(self, seed=None):
28 | return self.env.seed(seed)
29 |
30 | def reset(self):
31 | if self.init_qpos is not None:
32 | # reset anyway to be safe, not very expensive
33 | _ = self.env.reset()
34 | # start from known state
35 | self.env.set_state(self.init_qpos, self.init_qvel)
36 | obs = self.env._get_obs()
37 | return obs
38 | # obs, _, _, _ = self.env.step(np.zeros_like(
39 | # self.action_space.sample()))
40 | # return obs
41 | else:
42 | return self.env.reset()
43 |
44 | def render(self, mode='rgb_array'):
45 | h, w = self.render_hw
46 | return self.env.render(mode=mode, width=w, height=h)
47 |
48 | def step(self, a):
49 | return self.env.step(a)
50 |
--------------------------------------------------------------------------------
/state_diff/config/task/pushl_lowdim.yaml:
--------------------------------------------------------------------------------
1 | name: pushl_lowdim
2 |
3 |
4 |
5 | # For single agent pushl
6 | obs_dim: 40 # 9*2 keypoints + 10*2 keypoints + 2 state
7 | action_dim: 2
8 | keypoint_dim: 2
9 | num_agents: 1
10 |
11 |
12 | obs_key: keypoint # keypoint, structured_keypoint
13 | state_key: state
14 | action_key: action
15 |
16 |
17 | # # For double agent pushl
18 | # obs_dim: 42 # 9*2 keypoints + 10*2 keypoints + 4 state
19 | # action_dim: 4
20 | # keypoint_dim: 2
21 | # obs_key: keypoint # keypoint, structured_keypoint
22 | # state_key: state
23 | # action_key: action1,action2
24 |
25 |
26 |
27 | env_runner:
28 | _target_: state_diff.env_runner.push_keypoints_runner.PushKeypointsRunner
29 | keypoint_visible_rate: ${keypoint_visible_rate}
30 | n_train: 6
31 | n_train_vis: 2
32 | train_start_seed: 0
33 | n_test: 200 # 50
34 | n_test_vis: 4
35 | legacy_test: True
36 | test_start_seed: 100000
37 | max_steps: 500
38 | n_obs_steps: ${n_obs_steps}
39 | n_action_steps: ${n_action_steps}
40 | n_latency_steps: ${n_latency_steps}
41 | fps: 10
42 | agent_keypoints: False
43 | past_action: ${past_action_visible}
44 | n_envs: null
45 | obs_key: ${task.obs_key}
46 | state_key: ${task.state_key}
47 | action_key: ${task.action_key}
48 |
49 |
50 | dataset:
51 | _target_: state_diff.dataset.pusht_dataset.PushTLowdimDataset
52 | zarr_path: data/pushl_dataset
53 | horizon: ${horizon}
54 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
55 | pad_after: ${eval:'${n_action_steps}-1'}
56 | seed: 42
57 | val_ratio: 0.02
58 | max_train_episodes: null
59 | obs_key: ${task.obs_key}
60 | state_key: ${task.state_key}
61 | action_key: ${task.action_key}
62 | num_episodes: 600
63 |
64 |
--------------------------------------------------------------------------------
/tests/test_shared_queue.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | import numpy as np
9 | from multiprocessing.managers import SharedMemoryManager
10 | from state_diff.shared_memory.shared_memory_queue import SharedMemoryQueue, Full, Empty
11 |
12 |
13 | def test():
14 | shm_manager = SharedMemoryManager()
15 | shm_manager.start()
16 | example = {
17 | 'cmd': 0,
18 | 'pose': np.zeros((6,))
19 | }
20 | queue = SharedMemoryQueue.create_from_examples(
21 | shm_manager=shm_manager,
22 | examples=example,
23 | buffer_size=3
24 | )
25 | raised = False
26 | try:
27 | queue.get()
28 | except Empty:
29 | raised = True
30 | assert raised
31 |
32 | data = {
33 | 'cmd': 1,
34 | 'pose': np.ones((6,))
35 | }
36 | queue.put(data)
37 | result = queue.get()
38 | assert result['cmd'] == data['cmd']
39 | assert np.allclose(result['pose'], data['pose'])
40 |
41 | queue.put(data)
42 | queue.put(data)
43 | queue.put(data)
44 | assert queue.qsize() == 3
45 | raised = False
46 | try:
47 | queue.put(data)
48 | except Full:
49 | raised = True
50 | assert raised
51 |
52 | result = queue.get_all()
53 | assert np.allclose(result['cmd'], [1,1,1])
54 |
55 | queue.put({'cmd': 0})
56 | queue.put({'cmd': 1})
57 | queue.put({'cmd': 2})
58 | queue.get()
59 | queue.put({'cmd': 3})
60 |
61 | result = queue.get_k(3)
62 | assert np.allclose(result['cmd'], [1,2,3])
63 |
64 | queue.clear()
65 |
66 | if __name__ == "__main__":
67 | test()
68 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/microwave_asset.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator0.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/common/robomimic_config_util.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from robomimic.config import config_factory
3 | import robomimic.scripts.generate_paper_configs as gpc
4 | from robomimic.scripts.generate_paper_configs import (
5 | modify_config_for_default_image_exp,
6 | modify_config_for_default_low_dim_exp,
7 | modify_config_for_dataset,
8 | )
9 |
10 | def get_robomimic_config(
11 | algo_name='bc_rnn',
12 | hdf5_type='low_dim',
13 | task_name='square',
14 | dataset_type='ph'
15 | ):
16 | base_dataset_dir = '/tmp/null'
17 | filter_key = None
18 |
19 | # decide whether to use low-dim or image training defaults
20 | modifier_for_obs = modify_config_for_default_image_exp
21 | if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
22 | modifier_for_obs = modify_config_for_default_low_dim_exp
23 |
24 | algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name
25 | config = config_factory(algo_name=algo_config_name)
26 | # turn into default config for observation modalities (e.g.: low-dim or rgb)
27 | config = modifier_for_obs(config)
28 | # add in config based on the dataset
29 | config = modify_config_for_dataset(
30 | config=config,
31 | task_name=task_name,
32 | dataset_type=dataset_type,
33 | hdf5_type=hdf5_type,
34 | base_dataset_dir=base_dataset_dir,
35 | filter_key=filter_key,
36 | )
37 | # add in algo hypers based on dataset
38 | algo_config_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset')
39 | config = algo_config_modifier(
40 | config=config,
41 | task_name=task_name,
42 | dataset_type=dataset_type,
43 | hdf5_type=hdf5_type,
44 | )
45 | return config
46 |
47 |
48 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/basic_scene.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/insert.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/tests/test_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | import zarr
9 | from state_diff.common.replay_buffer import ReplayBuffer
10 |
11 | def test():
12 | import numpy as np
13 | buff = ReplayBuffer.create_empty_numpy()
14 | buff.add_episode({
15 | 'obs': np.zeros((100,10), dtype=np.float16)
16 | })
17 | buff.add_episode({
18 | 'obs': np.ones((50,10)),
19 | 'action': np.ones((50,2))
20 | })
21 | # buff.rechunk(256)
22 | obs = buff.get_episode(0)
23 |
24 | import numpy as np
25 | buff = ReplayBuffer.create_empty_zarr()
26 | buff.add_episode({
27 | 'obs': np.zeros((100,10), dtype=np.float16)
28 | })
29 | buff.add_episode({
30 | 'obs': np.ones((50,10)),
31 | 'action': np.ones((50,2))
32 | })
33 | obs = buff.get_episode(0)
34 | buff.set_chunks({
35 | 'obs': (100,10),
36 | 'action': (100,2)
37 | })
38 |
39 |
40 | def test_real():
41 | import os
42 | dist_group = zarr.open(
43 | os.path.expanduser('~/dev/state_diff/data/pusht/pusht_cchi_v2.zarr'), 'r')
44 |
45 | buff = ReplayBuffer.create_empty_numpy()
46 | key, group = next(iter(dist_group.items()))
47 | for key, group in dist_group.items():
48 | buff.add_episode(group)
49 |
50 | # out_path = os.path.expanduser('~/dev/state_diff/data/pusht_cchi2_v2_replay.zarr')
51 | out_path = os.path.expanduser('~/dev/state_diff/data/test.zarr')
52 | out_store = zarr.DirectoryStore(out_path)
53 | buff.save_to_store(out_store)
54 |
55 | buff = ReplayBuffer.copy_from_path(out_path, store=zarr.MemoryStore())
56 | buff.pop_episode()
57 |
58 |
59 | def test_pop():
60 | buff = ReplayBuffer.create_from_path(
61 | '/home/chengchi/dev/state_diff/data/pusht_cchi_v3_replay.zarr',
62 | mode='rw')
63 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/counters_asset.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/kitchen.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/conda_environment.yaml:
--------------------------------------------------------------------------------
1 | name: coord_bimanual
2 | channels:
3 | - pytorch
4 | - pytorch3d
5 | - nvidia
6 | - conda-forge
7 | dependencies:
8 | - python=3.9
9 | - pip=22.2.2
10 | - cudatoolkit=11.6
11 | - pytorch=1.12.1
12 | - torchvision=0.13.1
13 | - pytorch3d=0.7.0
14 | - numpy=1.23.3
15 | - numba==0.56.4
16 | - scipy==1.9.1
17 | - py-opencv=4.6.0
18 | - cffi=1.15.1
19 | - ipykernel=6.16
20 | - matplotlib=3.6.1
21 | - zarr=2.12.0
22 | - numcodecs=0.10.2
23 | - h5py=3.7.0
24 | - hydra-core=1.2.0
25 | - einops=0.4.1
26 | - tqdm=4.64.1
27 | - dill=0.3.5.1
28 | - scikit-video=1.1.11
29 | - scikit-image=0.19.3
30 | - gym=0.21.0
31 | - pymunk=6.2.1
32 | - wandb=0.13.3
33 | - threadpoolctl=3.1.0
34 | - shapely=1.8.4
35 | - cython=0.29.32
36 | - imageio=2.22.0
37 | - imageio-ffmpeg=0.4.7
38 | - termcolor=2.0.1
39 | - tensorboard=2.10.1
40 | - tensorboardx=2.5.1
41 | - psutil=5.9.2
42 | - click=8.0.4
43 | - boto3=1.24.96
44 | - accelerate=0.13.2
45 | - datasets=2.6.1
46 | - diffusers=0.11.1
47 | - av=10.0.0
48 | - cmake=3.24.3
49 | # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
50 | - llvm-openmp=14
51 | # trick to force reinstall imagecodecs via pip
52 | - imagecodecs==2022.9.26
53 | - pip:
54 | - ray[default,tune]==2.2.0
55 | # requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
56 | - free-mujoco-py==2.1.6
57 | - pygame==2.1.2
58 | - pybullet-svl==3.1.6.4
59 | - robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz
60 | - robomimic==0.2.0
61 | - pytorchvideo==0.1.5
62 | # pip package required for jpeg-xl
63 | - imagecodecs==2022.9.26
64 | - r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz
65 | - dm-control==1.0.9
66 | - huggingface-hub==0.25.2
67 | - -e .
68 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/kitchen_util.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import numpy as np
3 |
4 | def parse_mjl_logs(read_filename, skipamount):
5 | with open(read_filename, mode='rb') as file:
6 | fileContent = file.read()
7 | headers = struct.unpack('iiiiiii', fileContent[:28])
8 | nq = headers[0]
9 | nv = headers[1]
10 | nu = headers[2]
11 | nmocap = headers[3]
12 | nsensordata = headers[4]
13 | nuserdata = headers[5]
14 | name_len = headers[6]
15 | name = struct.unpack(str(name_len) + 's', fileContent[28:28+name_len])[0]
16 | rem_size = len(fileContent[28 + name_len:])
17 | num_floats = int(rem_size/4)
18 | dat = np.asarray(struct.unpack(str(num_floats) + 'f', fileContent[28+name_len:]))
19 | recsz = 1 + nq + nv + nu + 7*nmocap + nsensordata + nuserdata
20 | if rem_size % recsz != 0:
21 | print("ERROR")
22 | else:
23 | dat = np.reshape(dat, (int(len(dat)/recsz), recsz))
24 | dat = dat.T
25 |
26 | time = dat[0,:][::skipamount] - 0*dat[0, 0]
27 | qpos = dat[1:nq + 1, :].T[::skipamount, :]
28 | qvel = dat[nq+1:nq+nv+1,:].T[::skipamount, :]
29 | ctrl = dat[nq+nv+1:nq+nv+nu+1,:].T[::skipamount,:]
30 | mocap_pos = dat[nq+nv+nu+1:nq+nv+nu+3*nmocap+1,:].T[::skipamount, :]
31 | mocap_quat = dat[nq+nv+nu+3*nmocap+1:nq+nv+nu+7*nmocap+1,:].T[::skipamount, :]
32 | sensordata = dat[nq+nv+nu+7*nmocap+1:nq+nv+nu+7*nmocap+nsensordata+1,:].T[::skipamount,:]
33 | userdata = dat[nq+nv+nu+7*nmocap+nsensordata+1:,:].T[::skipamount,:]
34 |
35 | data = dict(nq=nq,
36 | nv=nv,
37 | nu=nu,
38 | nmocap=nmocap,
39 | nsensordata=nsensordata,
40 | name=name,
41 | time=time,
42 | qpos=qpos,
43 | qvel=qvel,
44 | ctrl=ctrl,
45 | mocap_pos=mocap_pos,
46 | mocap_quat=mocap_quat,
47 | sensordata=sensordata,
48 | userdata=userdata,
49 | logName = read_filename
50 | )
51 | return data
52 |
--------------------------------------------------------------------------------
/state_diff/model/common/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers.optimization import (
2 | Union, SchedulerType, Optional,
3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION
4 | )
5 |
6 | def get_scheduler(
7 | name: Union[str, SchedulerType],
8 | optimizer: Optimizer,
9 | num_warmup_steps: Optional[int] = None,
10 | num_training_steps: Optional[int] = None,
11 | **kwargs
12 | ):
13 | """
14 | Added kwargs vs diffuser's original implementation
15 |
16 | Unified API to get any scheduler from its name.
17 |
18 | Args:
19 | name (`str` or `SchedulerType`):
20 | The name of the scheduler to use.
21 | optimizer (`torch.optim.Optimizer`):
22 | The optimizer that will be used during training.
23 | num_warmup_steps (`int`, *optional*):
24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being
25 | optional), the function will raise an error if it's unset and the scheduler type requires it.
26 |         num_training_steps (`int`, *optional*):
27 | The number of training steps to do. This is not required by all schedulers (hence the argument being
28 | optional), the function will raise an error if it's unset and the scheduler type requires it.
29 | """
30 | name = SchedulerType(name)
31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
32 | if name == SchedulerType.CONSTANT:
33 | return schedule_func(optimizer, **kwargs)
34 |
35 | # All other schedulers require `num_warmup_steps`
36 | if num_warmup_steps is None:
37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
38 |
39 | if name == SchedulerType.CONSTANT_WITH_WARMUP:
40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
41 |
42 | # All other schedulers require `num_training_steps`
43 | if num_training_steps is None:
44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
45 |
46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs)
47 |
--------------------------------------------------------------------------------
/state_diff/scripts/real_dataset_conversion.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 | import os
10 | import click
11 | import pathlib
12 | import zarr
13 | import cv2
14 | import threadpoolctl
15 | from state_diff.real_world.real_data_conversion import real_data_to_replay_buffer
16 |
17 | @click.command()
18 | @click.option('--input', '-i', required=True)
19 | @click.option('--output', '-o', default=None)
20 | @click.option('--resolution', '-r', default='640x480')
21 | @click.option('--n_decoding_threads', '-nd', default=-1, type=int)
22 | @click.option('--n_encoding_threads', '-ne', default=-1, type=int)
23 | def main(input, output, resolution, n_decoding_threads, n_encoding_threads):
24 | out_resolution = tuple(int(x) for x in resolution.split('x'))
25 | input = pathlib.Path(os.path.expanduser(input))
26 | in_zarr_path = input.joinpath('replay_buffer.zarr')
27 | in_video_dir = input.joinpath('videos')
28 | assert in_zarr_path.is_dir()
29 | assert in_video_dir.is_dir()
30 | if output is None:
31 | output = input.joinpath(resolution + '.zarr.zip')
32 | else:
33 | output = pathlib.Path(os.path.expanduser(output))
34 |
35 | if output.exists():
36 |         click.confirm('Output path already exists! Overwrite?', abort=True)
37 |
38 | cv2.setNumThreads(1)
39 | with threadpoolctl.threadpool_limits(1):
40 | replay_buffer = real_data_to_replay_buffer(
41 | dataset_path=str(input),
42 | out_resolutions=out_resolution,
43 | n_decoding_threads=n_decoding_threads,
44 | n_encoding_threads=n_encoding_threads
45 | )
46 |
47 | print('Saving to disk')
48 | if output.suffix == '.zip':
49 | with zarr.ZipStore(output) as zip_store:
50 | replay_buffer.save_to_store(
51 | store=zip_store
52 | )
53 | else:
54 | with zarr.DirectoryStore(output) as store:
55 | replay_buffer.save_to_store(
56 | store=store
57 | )
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/suction/suction-base.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/slidecabinet_chain.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/suction/suction-head.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/microwave_chain.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/state_diff/scripts/generate_bet_blockpush.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 |
10 | import os
11 | import click
12 | import pathlib
13 | import numpy as np
14 | from tqdm import tqdm
15 | from state_diff.common.replay_buffer import ReplayBuffer
16 | from tf_agents.environments.wrappers import TimeLimit
17 | from tf_agents.environments.gym_wrapper import GymWrapper
18 | from tf_agents.trajectories.time_step import StepType
19 | from state_diff.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal
20 | from state_diff.env.block_pushing.block_pushing import BlockPush
21 | from state_diff.env.block_pushing.oracles.multimodal_push_oracle import MultimodalOrientedPushOracle
22 |
23 | @click.command()
24 | @click.option('-o', '--output', required=True)
25 | @click.option('-n', '--n_episodes', default=1000)
26 | @click.option('-c', '--chunk_length', default=-1)
27 | def main(output, n_episodes, chunk_length):
28 |
29 | buffer = ReplayBuffer.create_empty_numpy()
30 | env = TimeLimit(GymWrapper(BlockPushMultimodal()), duration=350)
31 | for i in tqdm(range(n_episodes)):
32 | print(i)
33 | obs_history = list()
34 | action_history = list()
35 |
36 | env.seed(i)
37 | policy = MultimodalOrientedPushOracle(env)
38 | time_step = env.reset()
39 | policy_state = policy.get_initial_state(1)
40 | while True:
41 | action_step = policy.action(time_step, policy_state)
42 | obs = np.concatenate(list(time_step.observation.values()), axis=-1)
43 | action = action_step.action
44 | obs_history.append(obs)
45 | action_history.append(action)
46 |
47 | if time_step.step_type == 2:
48 | break
49 |
50 | # state = env.wrapped_env().gym.get_pybullet_state()
51 | time_step = env.step(action)
52 | obs_history = np.array(obs_history)
53 | action_history = np.array(action_history)
54 |
55 | episode = {
56 | 'obs': obs_history,
57 | 'action': action_history
58 | }
59 | buffer.add_episode(episode)
60 |
61 | buffer.save_to_path(output, chunk_length=chunk_length)
62 |
63 | if __name__ == '__main__':
64 | main()
65 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/utils/pose3d.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Reach ML Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple 6DOF pose container.
17 | """
18 |
19 | import dataclasses
20 | import numpy as np
21 | from scipy.spatial import transform
22 |
23 |
24 | class NoCopyAsDict(object):
25 | """Base class for dataclasses. Avoids a copy in the asdict() call."""
26 |
27 | def asdict(self):
28 | """Replacement for dataclasses.asdict.
29 |
30 | TF Dataset does not handle dataclasses.asdict, which uses copy.deepcopy when
31 | setting values in the output dict. This causes issues with tf.Dataset.
32 | Instead, shallow copy contents.
33 |
34 | Returns:
35 | dict containing contents of dataclass.
36 | """
37 | return {k.name: getattr(self, k.name) for k in dataclasses.fields(self)}
38 |
39 |
40 | @dataclasses.dataclass
41 | class Pose3d(NoCopyAsDict):
42 | """Simple container for translation and rotation."""
43 |
44 | rotation: transform.Rotation
45 | translation: np.ndarray
46 |
47 | @property
48 | def vec7(self):
49 | return np.concatenate([self.translation, self.rotation.as_quat()])
50 |
51 | def serialize(self):
52 | return {
53 | "rotation": self.rotation.as_quat().tolist(),
54 | "translation": self.translation.tolist(),
55 | }
56 |
57 | @staticmethod
58 | def deserialize(data):
59 | return Pose3d(
60 | rotation=transform.Rotation.from_quat(data["rotation"]),
61 | translation=np.array(data["translation"]),
62 | )
63 |
64 | def __eq__(self, other):
65 | return np.array_equal(
66 | self.rotation.as_quat(), other.rotation.as_quat()
67 | ) and np.array_equal(self.translation, other.translation)
68 |
69 | def __ne__(self, other):
70 | return not self.__eq__(other)
71 |
--------------------------------------------------------------------------------
/state_diff/common/precise_sleep.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 |
5 | def precise_sleep(dt: float, slack_time: float = 0.001, time_func=time.monotonic):
6 | """
7 | Performs a precise sleep for a specified duration using a combination of time.sleep and active spinning.
8 |
9 | Args:
10 | dt (float): The total duration to sleep, in seconds.
11 | slack_time (float, optional): The duration to actively spin (busy-wait) after sleeping, to achieve precision. Defaults to 0.001 seconds.
12 | time_func (function, optional): The time function used for measuring time. Defaults to time.monotonic.
13 |
14 | Description:
15 | This function first uses time.sleep to sleep for 'dt - slack_time' seconds,
16 | allowing for low CPU usage during most of the sleep duration. Then, it actively spins
17 | (busy-wait) for the remaining 'slack_time' to ensure precise wake-up. This hybrid approach
18 | minimizes jitter caused by the non-deterministic nature of time.sleep.
19 | """
20 | t_start = time_func()
21 | if dt > slack_time:
22 | time.sleep(dt - slack_time)
23 | t_end = t_start + dt
24 | while time_func() < t_end:
25 | pass
26 | return
27 |
28 | def precise_wait(t_end: float, slack_time: float = 0.001, time_func=time.monotonic):
29 | """
30 | Waits until a specified end time using a combination of time.sleep and active spinning for precision.
31 |
32 | Args:
33 | t_end (float): The target end time in seconds since a fixed point in the past (e.g., system start).
34 | slack_time (float, optional): The duration to actively spin (busy-wait) before reaching the target time, to achieve precision. Defaults to 0.001 seconds.
35 | time_func (function, optional): The time function used for measuring time. Defaults to time.monotonic.
36 |
37 | Description:
38 | This function calculates the remaining time to 't_end' and then uses a combination of
39 | time.sleep and active spinning to wait until this target time. It sleeps for the bulk of
40 | the remaining time, then actively spins for the final 'slack_time' duration,
41 | thus ensuring precise timing.
42 | """
43 | t_start = time_func()
44 | t_wait = t_end - t_start
45 | if t_wait > 0:
46 | t_sleep = t_wait - slack_time
47 | if t_sleep > 0:
48 | time.sleep(t_sleep)
49 | while time_func() < t_end:
50 | pass
51 | return
52 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda_teleop.xml:
--------------------------------------------------------------------------------
1 |
4 |
5 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/oracles/reach_oracle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Reach ML Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Reach oracle."""
17 | import numpy as np
18 | from tf_agents.policies import py_policy
19 | from tf_agents.trajectories import policy_step
20 | from tf_agents.trajectories import time_step as ts
21 | from tf_agents.typing import types
22 |
23 | # Only used for debug visualization.
24 | import pybullet # pylint: disable=unused-import
25 |
26 |
27 | class ReachOracle(py_policy.PyPolicy):
28 | """Oracle for moving to a specific spot relative to the block and target."""
29 |
30 | def __init__(self, env, block_pushing_oracles_action_std=0.0):
31 | super(ReachOracle, self).__init__(env.time_step_spec(), env.action_spec())
32 | self._env = env
33 | self._np_random_state = np.random.RandomState(0)
34 | self._block_pushing_oracles_action_std = block_pushing_oracles_action_std
35 |
36 | def _action(self, time_step, policy_state):
37 |
38 | # Specifying this as velocity makes it independent of control frequency.
39 | max_step_velocity = 0.2
40 |
41 | xy_ee = time_step.observation["effector_target_translation"]
42 |
43 | # This should be observable from block and target translation,
44 | # but re-using the computation from the env so that it's only done once, and
45 | # used for reward / completion computation.
46 | xy_pre_block = self._env.reach_target_translation
47 |
48 | xy_delta = xy_pre_block - xy_ee
49 |
50 | if self._block_pushing_oracles_action_std != 0.0:
51 | xy_delta += (
52 | self._np_random_state.randn(2) * self._block_pushing_oracles_action_std
53 | )
54 |
55 | max_step_distance = max_step_velocity * (1 / self._env.get_control_frequency())
56 | length = np.linalg.norm(xy_delta)
57 | if length > max_step_distance:
58 | xy_direction = xy_delta / length
59 | xy_delta = xy_direction * max_step_distance
60 |
61 | return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32))
62 |
--------------------------------------------------------------------------------
/state_diff/scripts/real_pusht_successrate.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 | import os
10 | import click
11 | import collections
12 | import numpy as np
13 | from tqdm import tqdm
14 | import json
15 |
16 | @click.command()
17 | @click.option(
18 | '--reference', '-r', required=True,
19 | help='Reference metrics_raw.json from demonstration dataset.'
20 | )
21 | @click.option(
22 | '--input', '-i', required=True,
23 | help='Data search path'
24 | )
25 | def main(reference, input):
26 | # compute the min last metric for demo metrics
27 | demo_metrics = json.load(open(reference, 'r'))
28 | demo_min_metrics = collections.defaultdict(lambda:float('inf'))
29 | for episode_idx, metrics in demo_metrics.items():
30 | for key, value in metrics.items():
31 | last_value = value[-1]
32 | demo_min_metrics[key] = min(demo_min_metrics[key], last_value)
33 | print(demo_min_metrics)
34 |
35 | # find all metric
36 | name = 'metrics_raw.json'
37 | search_dir = pathlib.Path(input)
38 | success_rate_map = dict()
39 | for json_path in search_dir.glob('**/'+name):
40 | rel_path = json_path.relative_to(search_dir)
41 | rel_name = str(rel_path.parent)
42 | this_metrics = json.load(json_path.open('r'))
43 | metric_success_idxs = collections.defaultdict(list)
44 | metric_failure_idxs = collections.defaultdict(list)
45 | for episode_idx, metrics in this_metrics.items():
46 | for key, value in metrics.items():
47 | last_value = value[-1]
48 | # print(episode_idx, key, last_value)
49 | demo_min = demo_min_metrics[key]
50 | if last_value >= demo_min:
51 | # success
52 | metric_success_idxs[key].append(episode_idx)
53 | else:
54 | metric_failure_idxs[key].append(episode_idx)
55 | # in case of no success
56 | _ = metric_success_idxs[key]
57 | _ = metric_failure_idxs[key]
58 | metric_success_rate = dict()
59 | n_episodes = len(this_metrics)
60 | for key, value in metric_success_idxs.items():
61 | metric_success_rate[key] = len(value) / n_episodes
62 | # metric_success_rate['failured_idxs'] = metric_failure_idxs
63 | success_rate_map[rel_name] = metric_success_rate
64 |
65 | text = json.dumps(success_rate_map, indent=2)
66 | print(text)
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
--------------------------------------------------------------------------------
/tests/test_multi_realsense.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | import cv2
9 | import json
10 | import time
11 | import numpy as np
12 | from state_diff.real_world.multi_realsense import MultiRealsense
13 | from state_diff.real_world.video_recorder import VideoRecorder
14 |
15 | def test():
16 | config = json.load(open('/home/cchi/dev/state_diff/state_diff/real_world/realsense_config/415_high_accuracy_mode.json', 'r'))
17 |
18 | def transform(data):
19 | color = data['color']
20 | h,w,_ = color.shape
21 | factor = 4
22 | color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA)
23 | # color = color[:,140:500]
24 | data['color'] = color
25 | return data
26 |
27 | from state_diff.common.cv2_util import get_image_transform
28 | color_transform = get_image_transform(
29 | input_res=(1280,720),
30 | output_res=(640,480),
31 | bgr_to_rgb=False)
32 | def transform(data):
33 | data['color'] = color_transform(data['color'])
34 | return data
35 |
36 | # one thread per camera
37 | video_recorder = VideoRecorder.create_h264(
38 | fps=30,
39 | codec='h264',
40 | thread_type='FRAME'
41 | )
42 |
43 | with MultiRealsense(
44 | resolution=(1280,720),
45 | capture_fps=30,
46 | record_fps=15,
47 | enable_color=True,
48 | # advanced_mode_config=config,
49 | transform=transform,
50 | # recording_transform=transform,
51 | # video_recorder=video_recorder,
52 | verbose=True
53 | ) as realsense:
54 | realsense.set_exposure(exposure=150, gain=5)
55 | intr = realsense.get_intr_mat()
56 | print(intr)
57 |
58 | video_path = 'data_local/test'
59 | rec_start_time = time.time() + 1
60 | realsense.start_recording(video_path, start_time=rec_start_time)
61 | realsense.restart_put(rec_start_time)
62 |
63 | out = None
64 | vis_img = None
65 | while True:
66 | out = realsense.get(out=out)
67 |
68 | # bgr = out['color']
69 | # print(bgr.shape)
70 | # vis_img = np.concatenate(list(bgr), axis=0, out=vis_img)
71 | # cv2.imshow('default', vis_img)
72 | # key = cv2.pollKey()
73 | # if key == ord('q'):
74 | # break
75 |
76 | time.sleep(1/60)
77 | if time.time() > (rec_start_time + 20.0):
78 | break
79 |
80 |
81 | if __name__ == "__main__":
82 | test()
83 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin
2 | logs
3 | wandb
4 | outputs
5 | data
6 | data_local
7 | .vscode
8 | _wandb
9 |
10 | **/.DS_Store
11 |
12 | fuse.cfg
13 |
14 | *.ai
15 |
16 | # Generation results
17 | results/
18 |
19 | ray/auth.json
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | *.out
81 | local_settings.py
82 | db.sqlite3
83 | db.sqlite3-journal
84 |
85 | # Flask stuff:
86 | instance/
87 | .webassets-cache
88 |
89 | # Scrapy stuff:
90 | .scrapy
91 |
92 | # Sphinx documentation
93 | docs/_build/
94 |
95 | # PyBuilder
96 | target/
97 |
98 | # Jupyter Notebook
99 | .ipynb_checkpoints
100 |
101 | # IPython
102 | profile_default/
103 | ipython_config.py
104 |
105 | # pyenv
106 | .python-version
107 |
108 | # pipenv
109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
112 | # install all needed dependencies.
113 | #Pipfile.lock
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 |
144 | wandb_api_key.txt
145 |
146 | scripts/slurm/slurm_outputs
--------------------------------------------------------------------------------
/state_diff/model/common/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
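A small usage sketch for `normal_kl`, relying on the broadcasting behavior described in its docstring; the batch shape here is arbitrary and chosen purely for illustration:

```python
import torch as th

# KL between a predicted Gaussian and the standard normal, computed per element.
mean = th.randn(4, 3, 32, 32)
logvar = th.zeros(4, 3, 32, 32)
kl = normal_kl(mean, logvar, 0.0, 0.0)      # scalars broadcast against tensors
kl_per_sample = kl.flatten(1).mean(dim=1)   # one value per batch element
```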
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/oracles/discontinuous_push_oracle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Reach ML Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Pushes to first target, waits, then pushes to second target."""
17 |
18 | import state_diff.env.block_pushing.oracles.oriented_push_oracle as oriented_push_oracle_module
19 | import numpy as np
20 | from tf_agents.trajectories import policy_step
21 | from tf_agents.trajectories import time_step as ts
22 | from tf_agents.typing import types
23 |
24 | # Only used for debug visualization.
25 | import pybullet # pylint: disable=unused-import
26 |
27 |
28 | class DiscontinuousOrientedPushOracle(oriented_push_oracle_module.OrientedPushOracle):
29 | """Pushes to first target, waits, then pushes to second target."""
30 |
31 | def __init__(self, env, goal_tolerance=0.04, wait=0):
32 | super(DiscontinuousOrientedPushOracle, self).__init__(env)
33 | self._countdown = 0
34 | self._wait = wait
35 | self._goal_dist_tolerance = goal_tolerance
36 |
37 | def reset(self):
38 | self.phase = "move_to_pre_block"
39 | self._countdown = 0
40 |
41 | def _action(self, time_step, policy_state):
42 | if time_step.is_first():
43 | self.reset()
44 | # Move to first target first.
45 | self._current_target = "target"
46 | self._has_switched = False
47 |
48 | def _block_target_dist(block, target):
49 | dist = np.linalg.norm(
50 | time_step.observation["%s_translation" % block]
51 | - time_step.observation["%s_translation" % target]
52 | )
53 | return dist
54 |
55 | d1 = _block_target_dist("block", "target")
56 | if d1 < self._goal_dist_tolerance and not self._has_switched:
57 | self._countdown = self._wait
58 | # Once the block reaches the first target, switch to the second target.
59 | self._has_switched = True
60 | self._current_target = "target2"
61 |
62 | xy_delta = self._get_action_for_block_target(
63 | time_step, block="block", target=self._current_target
64 | )
65 |
66 | if self._countdown > 0:
67 | xy_delta = np.zeros_like(xy_delta)
68 | self._countdown -= 1
69 |
70 | return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32))
71 |
--------------------------------------------------------------------------------
/state_diff/common/pytorch_util.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Callable, List
2 | import collections
3 | import torch
4 | import torch.nn as nn
5 |
6 | def dict_apply(
7 | x: Dict[str, torch.Tensor],
8 | func: Callable[[torch.Tensor], torch.Tensor]
9 | ) -> Dict[str, torch.Tensor]:
10 | result = dict()
11 | for key, value in x.items():
12 | if isinstance(value, dict):
13 | result[key] = dict_apply(value, func)
14 | else:
15 | result[key] = func(value)
16 | return result
17 |
18 | def pad_remaining_dims(x, target):
19 | assert x.shape == target.shape[:len(x.shape)]
20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape)))
21 |
22 | def dict_apply_split(
23 | x: Dict[str, torch.Tensor],
24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
25 | ) -> Dict[str, torch.Tensor]:
26 | results = collections.defaultdict(dict)
27 | for key, value in x.items():
28 | result = split_func(value)
29 | for k, v in result.items():
30 | results[k][key] = v
31 | return results
32 |
33 | def dict_apply_reduce(
34 | x: List[Dict[str, torch.Tensor]],
35 | reduce_func: Callable[[List[torch.Tensor]], torch.Tensor]
36 | ) -> Dict[str, torch.Tensor]:
37 | result = dict()
38 | for key in x[0].keys():
39 | result[key] = reduce_func([x_[key] for x_ in x])
40 | return result
41 |
42 |
43 | def replace_submodules(
44 | root_module: nn.Module,
45 | predicate: Callable[[nn.Module], bool],
46 | func: Callable[[nn.Module], nn.Module]) -> nn.Module:
47 | """
48 | predicate: Return true if the module is to be replaced.
49 | func: Return new module to use.
50 | """
51 | if predicate(root_module):
52 | return func(root_module)
53 |
54 | bn_list = [k.split('.') for k, m
55 | in root_module.named_modules(remove_duplicate=True)
56 | if predicate(m)]
57 | for *parent, k in bn_list:
58 | parent_module = root_module
59 | if len(parent) > 0:
60 | parent_module = root_module.get_submodule('.'.join(parent))
61 | if isinstance(parent_module, nn.Sequential):
62 | src_module = parent_module[int(k)]
63 | else:
64 | src_module = getattr(parent_module, k)
65 | tgt_module = func(src_module)
66 | if isinstance(parent_module, nn.Sequential):
67 | parent_module[int(k)] = tgt_module
68 | else:
69 | setattr(parent_module, k, tgt_module)
70 | # verify that all BN are replaced
71 | bn_list = [k.split('.') for k, m
72 | in root_module.named_modules(remove_duplicate=True)
73 | if predicate(m)]
74 | assert len(bn_list) == 0
75 | return root_module
76 |
77 | def optimizer_to(optimizer, device):
78 | for state in optimizer.state.values():
79 | for k, v in state.items():
80 | if isinstance(v, torch.Tensor):
81 | state[k] = v.to(device=device)
82 | return optimizer
83 |
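A usage sketch for `replace_submodules`, swapping every BatchNorm2d in a torchvision ResNet for GroupNorm; `torchvision` and the group size of 16 are assumptions made only for this example:

```python
import torch.nn as nn
import torchvision

model = torchvision.models.resnet18()
model = replace_submodules(
    root_module=model,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(
        num_groups=m.num_features // 16,     # 16 channels per group
        num_channels=m.num_features),
)
```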
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/oven_asset.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/suction/cylinder_real.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/suction/cylinder.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/tests/test_single_realsense.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
5 | sys.path.append(ROOT_DIR)
6 | os.chdir(ROOT_DIR)
7 |
8 | import cv2
9 | import json
10 | import time
11 | from multiprocessing.managers import SharedMemoryManager
12 | from state_diff.real_world.single_realsense import SingleRealsense
13 |
14 | def test():
15 |
16 | serials = SingleRealsense.get_connected_devices_serial()
17 | # import pdb; pdb.set_trace()
18 | serial = serials[0]
19 | config = json.load(open('/home/cchi/dev/state_diff/state_diff/real_world/realsense_config/415_high_accuracy_mode.json', 'r'))
20 |
21 | def transform(data):
22 | color = data['color']
23 | h,w,_ = color.shape
24 | factor = 2
25 | color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA)
26 | # color = color[:,140:500]
27 | data['color'] = color
28 | return data
29 |
30 | # at 960x540 with //3, 60fps and 30fps are indistinguishable
31 |
32 | with SharedMemoryManager() as shm_manager:
33 | with SingleRealsense(
34 | shm_manager=shm_manager,
35 | serial_number=serial,
36 | resolution=(1280,720),
37 | # resolution=(960,540),
38 | # resolution=(640,480),
39 | capture_fps=30,
40 | enable_color=True,
41 | # enable_depth=True,
42 | # enable_infrared=True,
43 | # advanced_mode_config=config,
44 | # transform=transform,
45 | # recording_transform=transform
46 | # verbose=True
47 | ) as realsense:
48 | cv2.setNumThreads(1)
49 | realsense.set_exposure(exposure=150, gain=5)
50 | intr = realsense.get_intr_mat()
51 | print(intr)
52 |
53 |
54 | video_path = 'data_local/test.mp4'
55 | rec_start_time = time.time() + 2
56 | realsense.start_recording(video_path, start_time=rec_start_time)
57 |
58 | data = None
59 | while True:
60 | data = realsense.get(out=data)
61 | t = time.time()
62 | # print('capture_latency', data['receive_timestamp']-data['capture_timestamp'], 'receive_latency', t - data['receive_timestamp'])
63 | # print('receive', t - data['receive_timestamp'])
64 |
65 | # dt = time.time() - data['timestamp']
66 | # print(dt)
67 | # print(data['capture_timestamp'] - rec_start_time)
68 |
69 | bgr = data['color']
70 | # print(bgr.shape)
71 | cv2.imshow('default', bgr)
72 | key = cv2.pollKey()
73 | # if key == ord('q'):
74 | # break
75 | # elif key == ord('r'):
76 | # video_path = 'data_local/test.mp4'
77 | # realsense.start_recording(video_path)
78 | # elif key == ord('s'):
79 | # realsense.stop_recording()
80 |
81 | time.sleep(1/60)
82 | if time.time() > (rec_start_time + 20.0):
83 | break
84 |
85 |
86 | if __name__ == "__main__":
87 | test()
88 |
--------------------------------------------------------------------------------
/state_diff/env/block_pushing/assets/suction/suction-head-long.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/state_diff/dataset/kitchen_lowdim_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import numpy as np
4 | import copy
5 | import pathlib
6 | from state_diff.common.pytorch_util import dict_apply
7 | from state_diff.common.replay_buffer import ReplayBuffer
8 | from state_diff.common.sampler import SequenceSampler, get_val_mask
9 | from state_diff.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
10 | from state_diff.dataset.base_dataset import BaseLowdimDataset
11 |
12 | class KitchenLowdimDataset(BaseLowdimDataset):
13 | def __init__(self,
14 | dataset_dir,
15 | horizon=1,
16 | pad_before=0,
17 | pad_after=0,
18 | seed=42,
19 | val_ratio=0.0
20 | ):
21 | super().__init__()
22 |
23 | data_directory = pathlib.Path(dataset_dir)
24 | observations = np.load(data_directory / "observations_seq.npy")
25 | actions = np.load(data_directory / "actions_seq.npy")
26 | masks = np.load(data_directory / "existence_mask.npy")
27 |
28 | self.replay_buffer = ReplayBuffer.create_empty_numpy()
29 | for i in range(len(masks)):
30 | eps_len = int(masks[i].sum())
31 | obs = observations[i,:eps_len].astype(np.float32)
32 | action = actions[i,:eps_len].astype(np.float32)
33 | data = {
34 | 'obs': obs,
35 | 'action': action
36 | }
37 | self.replay_buffer.add_episode(data)
38 |
39 | val_mask = get_val_mask(
40 | n_episodes=self.replay_buffer.n_episodes,
41 | val_ratio=val_ratio,
42 | seed=seed)
43 | train_mask = ~val_mask
44 | self.sampler = SequenceSampler(
45 | replay_buffer=self.replay_buffer,
46 | sequence_length=horizon,
47 | pad_before=pad_before,
48 | pad_after=pad_after,
49 | episode_mask=train_mask)
50 |
51 | self.train_mask = train_mask
52 | self.horizon = horizon
53 | self.pad_before = pad_before
54 | self.pad_after = pad_after
55 |
56 | def get_validation_dataset(self):
57 | val_set = copy.copy(self)
58 | val_set.sampler = SequenceSampler(
59 | replay_buffer=self.replay_buffer,
60 | sequence_length=self.horizon,
61 | pad_before=self.pad_before,
62 | pad_after=self.pad_after,
63 | episode_mask=~self.train_mask
64 | )
65 | val_set.train_mask = ~self.train_mask
66 | return val_set
67 |
68 | def get_normalizer(self, mode='limits', **kwargs):
69 | data = {
70 | 'obs': self.replay_buffer['obs'],
71 | 'action': self.replay_buffer['action']
72 | }
73 | if 'range_eps' not in kwargs:
74 | # to prevent blowing up dims that barely change
75 | kwargs['range_eps'] = 5e-2
76 | normalizer = LinearNormalizer()
77 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
78 | return normalizer
79 |
80 | def get_all_actions(self) -> torch.Tensor:
81 | return torch.from_numpy(self.replay_buffer['action'])
82 |
83 | def __len__(self) -> int:
84 | return len(self.sampler)
85 |
86 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
87 | sample = self.sampler.sample_sequence(idx)
88 | data = sample
89 |
90 | torch_data = dict_apply(data, torch.from_numpy)
91 | return torch_data
92 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/bi-franka_panda.xml:
--------------------------------------------------------------------------------
1 |
4 |
5 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/state_diff/common/checkpoint_util.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 | import os
3 |
4 | class TopKCheckpointManager:
5 | """
6 | A manager for saving top K performing model checkpoints based on a specified metric.
7 |
8 | Attributes:
9 | save_dir (str): Directory where checkpoints are saved.
10 | monitor_key (str): Key for the metric to monitor in the data dictionary.
11 | mode (str): Mode for evaluation ('max' for higher is better, 'min' for lower is better).
12 | k (int): Number of top-performing checkpoints to keep.
13 | format_str (str): Format string for generating checkpoint filenames.
14 | path_value_map (dict): A dictionary mapping checkpoint paths to their metric values.
15 | """
16 |
17 | def __init__(self,
18 | save_dir,
19 | monitor_key: str,
20 | mode='min',
21 | k=1,
22 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt'
23 | ):
24 | """
25 | Initializes the TopKCheckpointManager with the specified configuration.
26 |
27 | Args:
28 | save_dir (str): The directory to save checkpoints.
29 | monitor_key (str): The metric key to monitor for checkpointing.
30 | mode (str, optional): Determines if 'max' or 'min' values of the monitored metric are better. Defaults to 'min'.
31 | k (int, optional): The number of top checkpoints to maintain. Defaults to 1.
32 | format_str (str, optional): Format string for checkpoint filenames. Defaults to 'epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt'.
33 | """
34 |
35 | assert mode in ['max', 'min']
36 | assert k >= 0
37 |
38 | self.save_dir = save_dir
39 | self.monitor_key = monitor_key
40 | self.mode = mode
41 | self.k = k
42 | self.format_str = format_str
43 | self.path_value_map = dict()
44 |
45 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]:
46 | """
47 | Determines the checkpoint path for the current metrics and decides whether to replace an existing checkpoint.
48 |
49 | Args:
50 | data (Dict[str, float]): A dictionary containing current metric data.
51 |
52 | Returns:
53 | Optional[str]: The path to save the new checkpoint if it is among the top K; otherwise, None.
54 | """
55 |
56 | if self.k == 0:
57 | return None
58 |
59 | value = data[self.monitor_key]
60 | ckpt_path = os.path.join(
61 | self.save_dir, self.format_str.format(**data))
62 |
63 | if len(self.path_value_map) < self.k:
64 | # under-capacity
65 | self.path_value_map[ckpt_path] = value
66 | return ckpt_path
67 |
68 | # at capacity
69 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1])
70 | min_path, min_value = sorted_map[0]
71 | max_path, max_value = sorted_map[-1]
72 |
73 | delete_path = None
74 | if self.mode == 'max':
75 | if value > min_value:
76 | delete_path = min_path
77 | else:
78 | if value < max_value:
79 | delete_path = max_path
80 |
81 | if delete_path is None:
82 | return None
83 | else:
84 | del self.path_value_map[delete_path]
85 | self.path_value_map[ckpt_path] = value
86 |
87 | if not os.path.exists(self.save_dir):
88 | os.mkdir(self.save_dir)
89 |
90 | if os.path.exists(delete_path):
91 | os.remove(delete_path)
92 | return ckpt_path
93 |
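A hypothetical training-loop sketch for the manager above; `model` and the metric values are placeholders:

```python
import os
import torch

topk_manager = TopKCheckpointManager(
    save_dir='checkpoints',
    monitor_key='train_loss',
    mode='min',
    k=3,
)
os.makedirs('checkpoints', exist_ok=True)

# At the end of each epoch:
metrics = {'epoch': 12, 'train_loss': 0.0834}
ckpt_path = topk_manager.get_ckpt_path(metrics)
if ckpt_path is not None:
    # Only save when this run is among the top 3 so far;
    # the manager has already deleted the evicted checkpoint.
    torch.save(model.state_dict(), ckpt_path)
```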
--------------------------------------------------------------------------------
/state_diff/gym_util/video_recording_wrapper.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from state_diff.gym_util.video_recorder import VideoRecorder
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | class VideoRecordingWrapper(gym.Wrapper):
7 | def __init__(self,
8 | env,
9 | video_recoder: VideoRecorder,
10 | mode='rgb_array',
11 | file_path=None,
12 | steps_per_render=1,
13 | **kwargs
14 | ):
15 | """
16 | When file_path is None, don't record.
17 | """
18 | super().__init__(env)
19 |
20 | self.mode = mode
21 | self.render_kwargs = kwargs
22 | self.steps_per_render = steps_per_render
23 | self.file_path = file_path
24 | self.video_recoder = video_recoder
25 |
26 | self.step_count = 0
27 |
28 | def reset(self, **kwargs):
29 | obs = super().reset(**kwargs)
30 | self.frames = list()
31 | self.step_count = 1
32 | self.video_recoder.stop()
33 | return obs
34 |
35 | def step(self, action):
36 | deno_actions = None
37 | if len(action.shape) == 2:
38 | deno_actions = action
39 | action = action[-1]
40 | result = super().step(action)
41 | self.step_count += 1
42 | if self.file_path is not None \
43 | and ((self.step_count % self.steps_per_render) == 0):
44 | if not self.video_recoder.is_ready():
45 | self.video_recoder.start(self.file_path)
46 |
47 | frame = self.env.render(
48 | mode=self.mode, **self.render_kwargs)
49 |
50 | if deno_actions is not None:
51 | frame = self.overlay_points_on_frame(frame, deno_actions)
52 |
53 | assert frame.dtype == np.uint8
54 |
55 | self.video_recoder.write_frame(frame)
56 | return result
57 |
58 | def render(self, mode='rgb_array', **kwargs):
59 | if self.video_recoder.is_ready():
60 | self.video_recoder.stop()
61 | return self.file_path
62 |
63 |
64 |
65 | def overlay_points_on_frame(self, frame, points, color=(255, 0, 0), gradient=False):
66 |
67 | """Helper method to overlay deno_actions as points on the frame."""
68 | # if points.shape[-1] == 4:
69 | # points = np.concatenate([points[:, :2], points[:, 2:]], axis=0)
70 | # elif points.shape[-1] == 6:
71 | # points = np.concatenate([points[:, :2], points[:, 2:4], points[:, 4:]], axis=0)
72 |
73 | # Mapping of possible dimensions to slicing rules
74 | slicing_map = {
75 | 4: [slice(0, 2), slice(2, 4)],
76 | 6: [slice(0, 2), slice(2, 4), slice(4, 6)]
77 | }
78 |
79 | # Check if the last dimension matches a supported format
80 | slices = slicing_map.get(points.shape[-1])
81 | if slices:
82 | points = np.concatenate([points[:, sl] for sl in slices], axis=0)
83 | frame_size = frame.shape[0]
84 | if frame_size == 96:
85 | radius = 1
86 | elif frame_size == 512:
87 | radius = 4
88 |
89 | if gradient:
90 | colors = plt.cm.plasma(np.linspace(0, 1, points.shape[0]))[:, :3] * 255
91 | else:
92 | colors = [color] * points.shape[0]
93 |
94 | coord = (points / 512 * frame_size).astype(np.int32)
95 | for idx, point in enumerate(coord):
96 | x, y = map(int, point)
97 | cv2.circle(frame, (x, y), radius=radius, color=tuple(colors[idx]), thickness=-1)
98 | return frame
99 |
100 |
--------------------------------------------------------------------------------
/state_diff/model/common/rotation_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import pytorch3d.transforms as pt
3 | import torch
4 | import numpy as np
5 | import functools
6 |
7 | class RotationTransformer:
8 | valid_reps = [
9 | 'axis_angle',
10 | 'euler_angles',
11 | 'quaternion', # (w, x, y, z) convention
12 | 'rotation_6d',
13 | 'matrix'
14 | ]
15 |
16 | def __init__(self,
17 | from_rep='axis_angle',
18 | to_rep='rotation_6d',
19 | from_convention=None,
20 | to_convention=None):
21 | """
22 | Valid representations
23 |
24 | Always use matrix as intermediate representation.
25 | """
26 | assert from_rep != to_rep
27 | assert from_rep in self.valid_reps
28 | assert to_rep in self.valid_reps
29 | if from_rep == 'euler_angles':
30 | assert from_convention is not None
31 | if to_rep == 'euler_angles':
32 | assert to_convention is not None
33 |
34 | forward_funcs = list()
35 | inverse_funcs = list()
36 |
37 | if from_rep != 'matrix':
38 | funcs = [
39 | getattr(pt, f'{from_rep}_to_matrix'),
40 | getattr(pt, f'matrix_to_{from_rep}')
41 | ]
42 | if from_convention is not None:
43 | funcs = [functools.partial(func, convention=from_convention)
44 | for func in funcs]
45 | forward_funcs.append(funcs[0])
46 | inverse_funcs.append(funcs[1])
47 |
48 | if to_rep != 'matrix':
49 | funcs = [
50 | getattr(pt, f'matrix_to_{to_rep}'),
51 | getattr(pt, f'{to_rep}_to_matrix')
52 | ]
53 | if to_convention is not None:
54 | funcs = [functools.partial(func, convention=to_convention)
55 | for func in funcs]
56 | forward_funcs.append(funcs[0])
57 | inverse_funcs.append(funcs[1])
58 |
59 | inverse_funcs = inverse_funcs[::-1]
60 |
61 | self.forward_funcs = forward_funcs
62 | self.inverse_funcs = inverse_funcs
63 |
64 | @staticmethod
65 | def _apply_funcs(x: Union[np.ndarray, torch.Tensor], funcs: list) -> Union[np.ndarray, torch.Tensor]:
66 | x_ = x
67 | if isinstance(x, np.ndarray):
68 | x_ = torch.from_numpy(x)
69 | x_: torch.Tensor
70 | for func in funcs:
71 | x_ = func(x_)
72 | y = x_
73 | if isinstance(x, np.ndarray):
74 | y = x_.numpy()
75 | return y
76 |
77 | def forward(self, x: Union[np.ndarray, torch.Tensor]
78 | ) -> Union[np.ndarray, torch.Tensor]:
79 | return self._apply_funcs(x, self.forward_funcs)
80 |
81 | def inverse(self, x: Union[np.ndarray, torch.Tensor]
82 | ) -> Union[np.ndarray, torch.Tensor]:
83 | return self._apply_funcs(x, self.inverse_funcs)
84 |
85 |
86 | def test():
87 | tf = RotationTransformer()
88 |
89 | rotvec = np.random.uniform(-2*np.pi,2*np.pi,size=(1000,3))
90 | rot6d = tf.forward(rotvec)
91 | new_rotvec = tf.inverse(rot6d)
92 |
93 | from scipy.spatial.transform import Rotation
94 | diff = Rotation.from_rotvec(rotvec) * Rotation.from_rotvec(new_rotvec).inv()
95 | dist = diff.magnitude()
96 | assert dist.max() < 1e-7
97 |
98 | tf = RotationTransformer('rotation_6d', 'matrix')
99 | rot6d_wrong = rot6d + np.random.normal(scale=0.1, size=rot6d.shape)
100 | mat = tf.forward(rot6d_wrong)
101 | mat_det = np.linalg.det(mat)
102 | assert np.allclose(mat_det, 1)
103 | # rotation_6d will be normalized to a valid rotation matrix
104 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Coordinated Bimanual Manipulation Policies using State Diffusion and Inverse Dynamics Models
2 |
3 |
4 |
5 | [Haonan Chen](http://haonan16.github.io/)1,
6 | [Jiaming Xu](#)*1,
7 | [Lily Sheng](#)*1,
8 | [Tianchen Ji](https://tianchenji.github.io/)1,
9 | [Shuijing Liu](https://shuijing725.github.io/)3,
10 | [Yunzhu Li](https://yunzhuli.github.io/)2,
11 | [Katherine Driggs-Campbell](https://krdc.web.illinois.edu/)1
12 |
13 | 1University of Illinois, Urbana-Champaign,
14 | 2Columbia University,
15 | 3The University of Texas at Austin
16 |
17 |
18 | ### Environment Setup
19 |
20 | We recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) over the standard Anaconda distribution for a faster installation. Set up the environment as follows:
21 |
22 | 1. Install the necessary dependencies:
23 | ```console
24 | sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf libglm-dev
25 | ```
26 |
27 | 2. Clone the repository:
28 | ```console
29 | git clone https://github.com/haonan16/state_diff.git
30 | cd state_diff/
31 | ```
32 |
33 | 3. Update `mamba`, then create and activate the environment:
34 |
35 |
36 |
37 | ```console
38 | mamba install mamba=1.5.1 -n base -c conda-forge
39 | mamba env create -f conda_environment.yaml
40 | mamba activate coord_bimanual
41 | ```
42 |
43 |
44 |
45 |
46 |
47 | ## Data Preparation
48 |
49 | ```console
50 | mkdir -p data
51 | ```
52 |
53 | Place the `pushl` dataset in the `data` folder; the expected directory structure is shown below.
54 |
55 | To obtain the dataset, download the zip file from the following link and unzip it:
56 |
57 | - [PushL Dataset](https://drive.google.com/file/d/1433HHxOH5nomDZ12aUZ4XZel4nVO_uwL/view?usp=sharing)
58 |
59 | You can automate this process by running the following command:
60 |
61 | ```console
62 | gdown https://drive.google.com/uc?id=1433HHxOH5nomDZ12aUZ4XZel4nVO_uwL -O data/pushl_dataset.zip
63 | unzip data/pushl_dataset.zip -d data
64 | rm data/pushl_dataset.zip
65 | ```
66 |
67 |
68 | Datasets from other simulation benchmarks are available here:
69 |
70 | - [Simulation Benchmark Datasets from Diffusion Policy](https://github.com/real-stanford/diffusion_policy?tab=readme-ov-file)
71 |
72 |
73 |
74 | Once downloaded, extract the contents into the `data` folder.
75 |
76 | Directory structure:
77 |
78 | ```
79 | data/
80 | └── pushl_dataset/
81 | ```
82 |
83 |
84 |
85 |
86 |
87 | ## Demo, Training and Eval
88 |
89 |
90 | Activate conda environment and login to [wandb](https://wandb.ai) (if you haven't already).
91 | ```console
92 | conda activate coord_bimanual
93 | wandb login
94 | ```
95 |
96 |
97 |
98 | ### Collecting Human Demonstration Data
99 |
100 | Run the following script to collect human demonstrations:
101 |
102 | ```console
103 | python demo_push_data_collection.py
104 | ```
105 |
106 | The agent position can be controlled with the mouse. The following keys control the environment:
107 | - `Space` - Pause and step forward (increments `plan_idx`)
108 | - `R` - Retry current attempt
109 | - `Q` - Exit the script
110 |
111 |
112 |
113 | ### Training
114 | To launch training, run:
115 |
116 |
117 | ```console
118 | python train.py \
119 | --config-name=train_diffusion_trajectory_unet_lowdim_workspace.yaml
120 | ```
121 |
122 |
123 | ## Acknowledgement
124 | * Policy training implementation is adapted from [Diffusion Policy](https://github.com/real-stanford/diffusion_policy/tree/main).
125 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/assets.xml:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/state_diff/dataset/pusht_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import numpy as np
4 | import copy
5 | from state_diff.common.pytorch_util import dict_apply
6 | from state_diff.common.replay_buffer import ReplayBuffer
7 | from state_diff.common.sampler import (
8 | SequenceSampler, get_val_mask, downsample_mask)
9 | from state_diff.model.common.normalizer import LinearNormalizer
10 | from state_diff.dataset.base_dataset import BaseLowdimDataset
11 |
12 | class PushTLowdimDataset(BaseLowdimDataset):
13 | def __init__(self,
14 | zarr_path,
15 | horizon=1,
16 | pad_before=0,
17 | pad_after=0,
18 | obs_key='keypoint',
19 | state_key='state',
20 | action_key='action',
21 | seed=42,
22 | val_ratio=0.0,
23 | max_train_episodes=None,
24 | num_episodes=None,
25 | ):
26 | super().__init__()
27 | self.replay_buffer = ReplayBuffer.copy_from_path(
28 | zarr_path, keys=[obs_key, state_key, action_key], num_episodes=num_episodes)
29 | if ',' in action_key and 'action' in action_key:
30 | action_key = 'action'
31 | self.agent_len = 4
32 | else:
33 | self.agent_len = 2
34 | val_mask = get_val_mask(
35 | n_episodes=self.replay_buffer.n_episodes,
36 | val_ratio=val_ratio,
37 | seed=seed)
38 | train_mask = ~val_mask
39 | train_mask = downsample_mask(
40 | mask=train_mask,
41 | max_n=max_train_episodes,
42 | seed=seed)
43 |
44 | self.sampler = SequenceSampler(
45 | replay_buffer=self.replay_buffer,
46 | sequence_length=horizon,
47 | pad_before=pad_before,
48 | pad_after=pad_after,
49 | episode_mask=train_mask
50 | )
51 | self.obs_key = obs_key
52 | self.state_key = state_key
53 | self.action_key = action_key
54 | self.train_mask = train_mask
55 | self.horizon = horizon
56 | self.pad_before = pad_before
57 | self.pad_after = pad_after
58 |
59 | def get_validation_dataset(self):
60 | val_set = copy.copy(self)
61 | val_set.sampler = SequenceSampler(
62 | replay_buffer=self.replay_buffer,
63 | sequence_length=self.horizon,
64 | pad_before=self.pad_before,
65 | pad_after=self.pad_after,
66 | episode_mask=~self.train_mask
67 | )
68 | val_set.train_mask = ~self.train_mask
69 | return val_set
70 |
71 | def get_normalizer(self, mode='limits', **kwargs):
72 | data = self._sample_to_data(self.replay_buffer)
73 | normalizer = LinearNormalizer()
74 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
75 | return normalizer
76 |
77 | def get_all_actions(self) -> torch.Tensor:
78 | return torch.from_numpy(self.replay_buffer[self.action_key])
79 |
80 | def __len__(self) -> int:
81 | return len(self.sampler)
82 |
83 | def _sample_to_data(self, sample):
84 | keypoint = sample[self.obs_key]
85 | state = sample[self.state_key]
86 | agent_pos = state[:,:self.agent_len]
87 | obs = np.concatenate([
88 | keypoint.reshape(keypoint.shape[0], -1),
89 | agent_pos], axis=-1)
90 |
91 | data = {
92 | 'obs': obs, # T, D_o
93 | 'action': sample[self.action_key], # T, D_a
94 | }
95 | return data
96 |
97 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
98 | sample = self.sampler.sample_sequence(idx)
99 | data = self._sample_to_data(sample)
100 |
101 | torch_data = dict_apply(data, torch.from_numpy)
102 | return torch_data
103 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/chain0_overlay.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/utils/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 | try:
19 | import cElementTree as ET
20 | except ImportError:
21 | try:
22 | # Python 2.5 need to import a different module
23 | import xml.etree.cElementTree as ET
24 | except ImportError:
25 | exit_err("Failed to import cElementTree from any known place")
26 |
27 | CONFIG_XML_DATA = """
28 |
29 |
30 |
31 |
32 |
33 | """
34 |
35 |
36 | # Read config from root
37 | def read_config_from_node(root_node, parent_name, child_name, dtype=int):
38 | # find parent
39 | parent_node = root_node.find(parent_name)
40 | if parent_node == None:
41 | quit("Parent %s not found" % parent_name)
42 |
43 | # get child data
44 | child_data = parent_node.get(child_name)
45 | if child_data == None:
46 | quit("Child %s not found" % child_name)
47 |
48 | config_val = np.array(child_data.split(), dtype=dtype)
49 | return config_val
50 |
51 |
52 | # Get config from file or string
53 | def get_config_root_node(config_file_name=None, config_file_data=None):
54 | try:
55 | # get root
56 | if config_file_data is None:
57 | config_file_content = open(config_file_name, "r")
58 | config = ET.parse(config_file_content)
59 | root_node = config.getroot()
60 | else:
61 | root_node = ET.fromstring(config_file_data)
62 |
63 | # get root data
64 | root_data = root_node.get('name')
65 | root_name = np.array(root_data.split(), dtype=str)
66 | except:
67 | quit("ERROR: Unable to process config file %s" % config_file_name)
68 |
69 | return root_node, root_name
70 |
71 |
72 | # Read config from config_file
73 | def read_config_from_xml(config_file_name, parent_name, child_name, dtype=int):
74 | root_node, root_name = get_config_root_node(
75 | config_file_name=config_file_name)
76 | return read_config_from_node(root_node, parent_name, child_name, dtype)
77 |
78 |
79 | # tests
80 | if __name__ == '__main__':
81 | print("Read config and parse -------------------------")
82 | root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA)
83 | print("Root:name \t", root_name)
84 | print("limit:low \t", read_config_from_node(root, "limits", "low", float))
85 | print("limit:high \t", read_config_from_node(root, "limits", "high", float))
86 | print("scale:joint \t", read_config_from_node(root, "scale", "joint",
87 | float))
88 | print("data:type \t", read_config_from_node(root, "data", "type", str))
89 |
90 | # read straight from xml (dumb the XML data as duh.xml for this test)
91 | root, root_name = get_config_root_node(config_file_name="duh.xml")
92 | print("Read from xml --------------------------------")
93 | print("limit:low \t", read_config_from_xml("duh.xml", "limits", "low",
94 | float))
95 | print("limit:high \t",
96 | read_config_from_xml("duh.xml", "limits", "high", float))
97 | print("scale:joint \t",
98 | read_config_from_xml("duh.xml", "scale", "joint", float))
99 | print("data:type \t", read_config_from_xml("duh.xml", "data", "type", str))
100 |
--------------------------------------------------------------------------------
/state_diff/dataset/pusht_traj_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import numpy as np
4 | import copy
5 | from state_diff.common.pytorch_util import dict_apply
6 | from state_diff.common.replay_buffer import ReplayBuffer
7 | from state_diff.common.sampler import (
8 | SequenceSampler, get_val_mask, downsample_mask)
9 | from state_diff.model.common.normalizer import LinearNormalizer
10 | from state_diff.dataset.base_dataset import BaseLowdimDataset
11 |
12 | class PushTTrajLowdimDataset(BaseLowdimDataset):
13 | def __init__(self,
14 | zarr_path,
15 | horizon=1,
16 | pad_before=0,
17 | pad_after=0,
18 | obs_key='keypoint',
19 | state_key='state',
20 | action_key='action',
21 | seed=42,
22 | val_ratio=0.0,
23 | max_train_episodes=None,
24 | num_episodes=None,
25 | **kwargs
26 | ):
27 | super().__init__()
28 | self.replay_buffer = ReplayBuffer.copy_from_path(
29 | zarr_path, keys=[obs_key, state_key, action_key], num_episodes=num_episodes)
30 |
31 | if ',' in action_key and 'action' in action_key:
32 | action_key = 'action'
33 | self.num_agents = kwargs.get('num_agents', 1)
34 | self.agent_len = kwargs.get('action_dim', 2) * self.num_agents
35 | # NOTE: kwargs.get('action_dim', 2) currently always falls back to the default of 2 here; revisit agent_len if action_dim is ever passed explicitly.
36 |
37 | val_mask = get_val_mask(
38 | n_episodes=self.replay_buffer.n_episodes,
39 | val_ratio=val_ratio,
40 | seed=seed)
41 | train_mask = ~val_mask
42 | train_mask = downsample_mask(
43 | mask=train_mask,
44 | max_n=max_train_episodes,
45 | seed=seed)
46 |
47 | self.sampler = SequenceSampler(
48 | replay_buffer=self.replay_buffer,
49 | sequence_length=horizon,
50 | pad_before=pad_before,
51 | pad_after=pad_after,
52 | episode_mask=train_mask
53 | )
54 | self.obs_key = obs_key
55 | self.state_key = state_key
56 | self.action_key = action_key
57 | self.train_mask = train_mask
58 | self.horizon = horizon
59 | self.pad_before = pad_before
60 | self.pad_after = pad_after
61 |
62 | def get_validation_dataset(self):
63 | val_set = copy.copy(self)
64 | val_set.sampler = SequenceSampler(
65 | replay_buffer=self.replay_buffer,
66 | sequence_length=self.horizon,
67 | pad_before=self.pad_before,
68 | pad_after=self.pad_after,
69 | episode_mask=~self.train_mask
70 | )
71 | val_set.train_mask = ~self.train_mask
72 | return val_set
73 |
74 | def get_normalizer(self, mode='limits', **kwargs):
75 | data = self._sample_to_data(self.replay_buffer)
76 | normalizer = LinearNormalizer()
77 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
78 | return normalizer
79 |
80 | def get_all_actions(self) -> torch.Tensor:
81 | return torch.from_numpy(self.replay_buffer[self.action_key])
82 |
83 |
84 | def __len__(self) -> int:
85 | return len(self.sampler)
86 |
87 | def _sample_to_data(self, sample):
88 | keypoint = sample[self.obs_key]
89 | state = sample[self.state_key]
90 | agent_pos = state[:,:self.agent_len]
91 | obs = np.concatenate([
92 | keypoint.reshape(keypoint.shape[0], -1),
93 | agent_pos], axis=-1)
94 |
95 | data = {
96 | 'obs': obs, # T, D_o
97 | 'action': sample[self.action_key], # T, D_a
98 | }
99 | return data
100 |
101 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
102 | sample = self.sampler.sample_sequence(idx)
103 | data = self._sample_to_data(sample)
104 |
105 | torch_data = dict_apply(data, torch.from_numpy)
106 | return torch_data
107 |
--------------------------------------------------------------------------------
/state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/simulation/module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Module for caching Python modules related to simulation."""
18 |
19 | import sys
20 |
21 | _MUJOCO_PY_MODULE = None
22 |
23 | _DM_MUJOCO_MODULE = None
24 | _DM_VIEWER_MODULE = None
25 | _DM_RENDER_MODULE = None
26 |
27 | _GLFW_MODULE = None
28 |
29 |
30 | def get_mujoco_py():
31 | """Returns the mujoco_py module."""
32 | global _MUJOCO_PY_MODULE
33 | if _MUJOCO_PY_MODULE:
34 | return _MUJOCO_PY_MODULE
35 | try:
36 | import mujoco_py
37 | # Override the warning function.
38 | from mujoco_py.builder import cymj
39 | cymj.set_warning_callback(_mj_warning_fn)
40 | except ImportError:
41 | print(
42 | 'Failed to import mujoco_py. Ensure that mujoco_py (using MuJoCo '
43 | 'v1.50) is installed.',
44 | file=sys.stderr)
45 | sys.exit(1)
46 | _MUJOCO_PY_MODULE = mujoco_py
47 | return mujoco_py
48 |
49 |
50 | def get_mujoco_py_mjlib():
51 | """Returns the mujoco_py mjlib module."""
52 |
53 | class MjlibDelegate:
54 | """Wrapper that forwards mjlib calls."""
55 |
56 | def __init__(self, lib):
57 | self._lib = lib
58 |
59 | def __getattr__(self, name: str):
60 | if name.startswith('mj'):
61 | return getattr(self._lib, '_' + name)
62 | raise AttributeError(name)
63 |
64 | return MjlibDelegate(get_mujoco_py().cymj)
65 |
66 |
67 | def get_dm_mujoco():
68 | """Returns the DM Control mujoco module."""
69 | global _DM_MUJOCO_MODULE
70 | if _DM_MUJOCO_MODULE:
71 | return _DM_MUJOCO_MODULE
72 | try:
73 | from dm_control import mujoco
74 | except ImportError:
75 | print(
76 | 'Failed to import dm_control.mujoco. Ensure that dm_control (using '
77 | 'MuJoCo v2.00) is installed.',
78 | file=sys.stderr)
79 | sys.exit(1)
80 | _DM_MUJOCO_MODULE = mujoco
81 | return mujoco
82 |
83 |
84 | def get_dm_viewer():
85 | """Returns the DM Control viewer module."""
86 | global _DM_VIEWER_MODULE
87 | if _DM_VIEWER_MODULE:
88 | return _DM_VIEWER_MODULE
89 | try:
90 | from dm_control import viewer
91 | except ImportError:
92 | print(
93 | 'Failed to import dm_control.viewer. Ensure that dm_control (using '
94 | 'MuJoCo v2.00) is installed.',
95 | file=sys.stderr)
96 | sys.exit(1)
97 | _DM_VIEWER_MODULE = viewer
98 | return viewer
99 |
100 |
101 | def get_dm_render():
102 | """Returns the DM Control render module."""
103 | global _DM_RENDER_MODULE
104 | if _DM_RENDER_MODULE:
105 | return _DM_RENDER_MODULE
106 | try:
107 | try:
108 | from dm_control import _render
109 | render = _render
110 | except ImportError:
111 | print('Warning: DM Control is out of date.')
112 | from dm_control import render
113 | except ImportError:
114 | print(
115 | 'Failed to import dm_control.render. Ensure that dm_control (using '
116 | 'MuJoCo v2.00) is installed.',
117 | file=sys.stderr)
118 | sys.exit(1)
119 | _DM_RENDER_MODULE = render
120 | return render
121 |
122 |
123 | def _mj_warning_fn(warn_data: bytes):
124 | """Warning function override for mujoco_py."""
125 | print('WARNING: Mujoco simulation is unstable (has NaNs): {}'.format(
126 | warn_data.decode()))
127 |
--------------------------------------------------------------------------------