├── state_diff ├── env │ ├── kitchen │ │ ├── relay_policy_learning │ │ │ ├── adept_models │ │ │ │ ├── __init__.py │ │ │ │ ├── .gitignore │ │ │ │ ├── kitchen │ │ │ │ │ ├── meshes │ │ │ │ │ │ ├── hood.stl │ │ │ │ │ │ ├── knob.stl │ │ │ │ │ │ ├── oven.stl │ │ │ │ │ │ ├── tile.stl │ │ │ │ │ │ ├── wall.stl │ │ │ │ │ │ ├── faucet.stl │ │ │ │ │ │ ├── handle2.stl │ │ │ │ │ │ ├── kettle.stl │ │ │ │ │ │ ├── micro.stl │ │ │ │ │ │ ├── oventop.stl │ │ │ │ │ │ ├── hingedoor.stl │ │ │ │ │ │ ├── microdoor.stl │ │ │ │ │ │ ├── microfeet.stl │ │ │ │ │ │ ├── slidedoor.stl │ │ │ │ │ │ ├── stoverim.stl │ │ │ │ │ │ ├── burnerplate.stl │ │ │ │ │ │ ├── cabinetbase.stl │ │ │ │ │ │ ├── countertop.stl │ │ │ │ │ │ ├── hingecabinet.stl │ │ │ │ │ │ ├── hingehandle.stl │ │ │ │ │ │ ├── kettlehandle.stl │ │ │ │ │ │ ├── lightswitch.stl │ │ │ │ │ │ ├── microbutton.stl │ │ │ │ │ │ ├── microefeet.stl │ │ │ │ │ │ ├── microhandle.stl │ │ │ │ │ │ ├── microwindow.stl │ │ │ │ │ │ ├── ovenhandle.stl │ │ │ │ │ │ ├── ovenwindow.stl │ │ │ │ │ │ ├── slidecabinet.stl │ │ │ │ │ │ ├── cabinetdrawer.stl │ │ │ │ │ │ ├── cabinethandle.stl │ │ │ │ │ │ ├── burnerplate_mesh.stl │ │ │ │ │ │ └── lightswitchbase.stl │ │ │ │ │ ├── textures │ │ │ │ │ │ ├── tile1.png │ │ │ │ │ │ ├── wood1.png │ │ │ │ │ │ ├── marble1.png │ │ │ │ │ │ └── metal1.png │ │ │ │ │ ├── microwave.xml │ │ │ │ │ ├── oven.xml │ │ │ │ │ ├── kettle.xml │ │ │ │ │ ├── counters.xml │ │ │ │ │ ├── hingecabinet.xml │ │ │ │ │ ├── slidecabinet.xml │ │ │ │ │ ├── assets │ │ │ │ │ │ ├── backwall_chain.xml │ │ │ │ │ │ ├── backwall_asset.xml │ │ │ │ │ │ ├── slidecabinet_asset.xml │ │ │ │ │ │ ├── hingecabinet_asset.xml │ │ │ │ │ │ ├── kettle_chain.xml │ │ │ │ │ │ ├── kettle_asset.xml │ │ │ │ │ │ ├── microwave_asset.xml │ │ │ │ │ │ ├── counters_asset.xml │ │ │ │ │ │ ├── slidecabinet_chain.xml │ │ │ │ │ │ ├── microwave_chain.xml │ │ │ │ │ │ └── oven_asset.xml │ │ │ │ │ └── kitchen.xml │ │ │ │ ├── scenes │ │ │ │ │ ├── textures │ │ │ │ │ │ ├── white_marble_tile.png │ │ │ │ │ │ └── white_marble_tile2.png │ │ │ │ │ └── basic_scene.xml │ │ │ │ ├── README.public.md │ │ │ │ └── CONTRIBUTING.public.md │ │ │ ├── third_party │ │ │ │ └── franka │ │ │ │ │ ├── franka_panda.png │ │ │ │ │ ├── meshes │ │ │ │ │ ├── visual │ │ │ │ │ │ ├── hand.stl │ │ │ │ │ │ ├── finger.stl │ │ │ │ │ │ ├── link0.stl │ │ │ │ │ │ ├── link1.stl │ │ │ │ │ │ ├── link2.stl │ │ │ │ │ │ ├── link3.stl │ │ │ │ │ │ ├── link4.stl │ │ │ │ │ │ ├── link5.stl │ │ │ │ │ │ ├── link6.stl │ │ │ │ │ │ └── link7.stl │ │ │ │ │ └── collision │ │ │ │ │ │ ├── hand.stl │ │ │ │ │ │ ├── finger.stl │ │ │ │ │ │ ├── link0.stl │ │ │ │ │ │ ├── link1.stl │ │ │ │ │ │ ├── link2.stl │ │ │ │ │ │ ├── link3.stl │ │ │ │ │ │ ├── link4.stl │ │ │ │ │ │ ├── link5.stl │ │ │ │ │ │ ├── link6.stl │ │ │ │ │ │ └── link7.stl │ │ │ │ │ ├── README.md │ │ │ │ │ ├── assets │ │ │ │ │ ├── basic_scene.xml │ │ │ │ │ ├── teleop_actuator.xml │ │ │ │ │ ├── actuator1.xml │ │ │ │ │ ├── actuator0.xml │ │ │ │ │ ├── assets.xml │ │ │ │ │ └── chain0_overlay.xml │ │ │ │ │ ├── franka_panda.xml │ │ │ │ │ ├── franka_panda_teleop.xml │ │ │ │ │ └── bi-franka_panda.xml │ │ │ └── adept_envs │ │ │ │ └── adept_envs │ │ │ │ ├── __init__.py │ │ │ │ ├── franka │ │ │ │ └── __init__.py │ │ │ │ ├── utils │ │ │ │ ├── constants.py │ │ │ │ └── config.py │ │ │ │ └── simulation │ │ │ │ └── module.py │ │ ├── v0.py │ │ ├── __init__.py │ │ ├── kitchen_lowdim_wrapper.py │ │ └── kitchen_util.py │ ├── pusht │ │ └── __init__.py │ └── block_pushing │ │ ├── assets │ │ ├── plane.obj │ │ ├── zone.urdf │ │ ├── zone2.urdf │ │ ├── 
workspace.urdf │ │ ├── workspace_real.urdf │ │ ├── block.urdf │ │ ├── block2.urdf │ │ ├── blocks │ │ │ ├── red_moon.urdf │ │ │ ├── blue_cube.urdf │ │ │ ├── green_star.urdf │ │ │ └── yellow_pentagon.urdf │ │ ├── zone.obj │ │ ├── insert.urdf │ │ └── suction │ │ │ ├── suction-base.urdf │ │ │ ├── suction-head.urdf │ │ │ ├── cylinder_real.urdf │ │ │ ├── cylinder.urdf │ │ │ └── suction-head-long.urdf │ │ ├── oracles │ │ ├── pushing_info.py │ │ ├── reach_oracle.py │ │ └── discontinuous_push_oracle.py │ │ └── utils │ │ └── pose3d.py ├── env_runner │ └── base_lowdim_runner.py ├── model │ ├── common │ │ ├── module_attr_mixin.py │ │ ├── shape_util.py │ │ ├── dict_of_tensor_mixin.py │ │ ├── lr_scheduler.py │ │ ├── losses.py │ │ └── rotation_transformer.py │ └── diffusion │ │ ├── positional_embedding.py │ │ └── conv1d_components.py ├── common │ ├── env_util.py │ ├── nested_dict_util.py │ ├── pymunk_util.py │ ├── robomimic_config_util.py │ ├── precise_sleep.py │ ├── pytorch_util.py │ └── checkpoint_util.py ├── scripts │ ├── episode_lengths.py │ ├── blockpush_abs_conversion.py │ ├── bet_blockpush_conversion.py │ ├── real_dataset_conversion.py │ ├── generate_bet_blockpush.py │ └── real_pusht_successrate.py ├── config │ └── task │ │ ├── kitchen_lowdim.yaml │ │ ├── blockpush_lowdim_seed.yaml │ │ ├── blockpush_lowdim_seed_abs.yaml │ │ ├── blockpush_traj_lowdim_seed.yaml │ │ ├── blockpush_traj_lowdim_seed_abs.yaml │ │ ├── kitchen_lowdim_abs.yaml │ │ ├── pushl_traj_lowdim.yaml │ │ └── pushl_lowdim.yaml ├── shared_memory │ └── shared_memory_util.py ├── gym_util │ ├── video_wrapper.py │ └── video_recording_wrapper.py └── dataset │ ├── base_dataset.py │ ├── kitchen_lowdim_dataset.py │ ├── pusht_dataset.py │ └── pusht_traj_dataset.py ├── media └── teaser_image.png ├── pyrightconfig.json ├── setup.py ├── tests ├── test_cv2_util.py ├── test_robomimic_lowdim_runner.py ├── test_robomimic_image_runner.py ├── test_block_pushing.py ├── test_precise_sleep.py ├── test_shared_queue.py ├── test_replay_buffer.py ├── test_multi_realsense.py └── test_single_realsense.py ├── train.py ├── conda_environment.yaml ├── .gitignore └── README.md /state_diff/env/kitchen/relay_policy_learning/adept_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /media/teaser_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/media/teaser_image.png -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "data/**", 4 | "data_local/**", 5 | "outputs/**" 6 | ] 7 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'state_diff', 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | *.swp 4 | *.profraw 5 | 6 | # Editors 7 | .vscode 8 | .idea 9 | -------------------------------------------------------------------------------- 
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hood.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hood.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/knob.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/knob.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oven.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oven.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/tile.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/tile.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/wall.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/wall.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.png -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/faucet.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/faucet.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/handle2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/handle2.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettle.stl -------------------------------------------------------------------------------- 
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/micro.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/micro.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oventop.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/oventop.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/tile1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/tile1.png -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/wood1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/wood1.png -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingedoor.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingedoor.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microdoor.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microdoor.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microfeet.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microfeet.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidedoor.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidedoor.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/stoverim.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/stoverim.stl -------------------------------------------------------------------------------- 
/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/marble1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/marble1.png -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/metal1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/textures/metal1.png -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/hand.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetbase.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetbase.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/countertop.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/countertop.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingecabinet.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingecabinet.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingehandle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/hingehandle.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettlehandle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/kettlehandle.stl 
-------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitch.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitch.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microbutton.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microbutton.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microefeet.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microefeet.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microhandle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microhandle.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microwindow.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/microwindow.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenhandle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenhandle.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenwindow.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/ovenwindow.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidecabinet.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/slidecabinet.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/hand.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/finger.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link0.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link1.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link2.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link3.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link4.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link5.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link6.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link7.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/visual/link7.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetdrawer.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinetdrawer.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinethandle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/cabinethandle.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/finger.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link0.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link1.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link2.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link3.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link4.stl -------------------------------------------------------------------------------- 
/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link5.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link6.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/third_party/franka/meshes/collision/link7.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate_mesh.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/burnerplate_mesh.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitchbase.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/meshes/lightswitchbase.stl -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile.png -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haonan16/state_diff/HEAD/state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/textures/white_marble_tile2.png -------------------------------------------------------------------------------- /state_diff/env/pusht/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | import state_diff.env.pusht 3 | 4 | register( 5 | id='pusht-keypoints-v0', 6 | entry_point='envs.pusht.pusht_keypoints_env:PushTKeypointsEnv', 7 | max_episode_steps=200, 8 | reward_threshold=1.0 9 | ) -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/README.md: -------------------------------------------------------------------------------- 1 | # franka 2 | Franka panda mujoco models 3 | 4 | 5 | # Environment 6 | 7 | 
franka_panda.xml | coming soon 8 | :-------------------------:|:-------------------------: 9 | ![Alt text](franka_panda.png?raw=false "franka panda") | coming soon 10 | -------------------------------------------------------------------------------- /state_diff/env_runner/base_lowdim_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from state_diff.policy.base_lowdim_policy import BaseLowdimPolicy 3 | 4 | class BaseLowdimRunner: 5 | def __init__(self, output_dir): 6 | self.output_dir = output_dir 7 | 8 | def run(self, policy: BaseLowdimPolicy) -> Dict: 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/README.public.md: -------------------------------------------------------------------------------- 1 | # D'Suite Scenes 2 | 3 | This repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation 4 | scenes and common assets for D'Suite environments, and on code from the ROBEL suite: 5 | https://github.com/google-research/robel 6 | 7 | ## Disclaimer 8 | 9 | This is not an official Google product. 10 | 11 | -------------------------------------------------------------------------------- /state_diff/model/common/module_attr_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ModuleAttrMixin(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self._dummy_variable = nn.Parameter() 7 | 8 | @property 9 | def device(self): 10 | return next(iter(self.parameters())).device 11 | 12 | @property 13 | def dtype(self): 14 | return next(iter(self.parameters())).dtype 15 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/microwave.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/oven.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/kettle.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/counters.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/hingecabinet.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/plane.obj: -------------------------------------------------------------------------------- 1 | # Blender
v2.66 (sub 1) OBJ File: '' 2 | # www.blender.org 3 | mtllib plane.mtl 4 | o Plane 5 | v 15.000000 -15.000000 0.000000 6 | v 15.000000 15.000000 0.000000 7 | v -15.000000 15.000000 0.000000 8 | v -15.000000 -15.000000 0.000000 9 | 10 | vt 15.000000 0.000000 11 | vt 15.000000 15.000000 12 | vt 0.000000 15.000000 13 | vt 0.000000 0.000000 14 | 15 | usemtl Material 16 | s off 17 | f 1/1 2/2 3/3 18 | f 1/1 3/3 4/4 19 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/slidecabinet.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /state_diff/model/diffusion/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SinusoidalPosEmb(nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | self.dim = dim 9 | 10 | def forward(self, x): 11 | device = x.device 12 | half_dim = self.dim // 2 13 | emb = math.log(10000) / (half_dim - 1) 14 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 15 | emb = x[:, None] * emb[None, :] 16 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 17 | return emb 18 | -------------------------------------------------------------------------------- /tests/test_cv2_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | import numpy as np 9 | from state_diff.common.cv2_util import get_image_transform 10 | 11 | 12 | def test(): 13 | tf = get_image_transform((1280,720), (640,480), bgr_to_rgb=False) 14 | in_img = np.zeros((720,1280,3), dtype=np.uint8) 15 | out_img = tf(in_img) 16 | # print(out_img.shape) 17 | assert out_img.shape == (480,640,3) 18 | 19 | # import pdb; pdb.set_trace() 20 | 21 | if __name__ == '__main__': 22 | test() 23 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/zone.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/zone2.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/v0.py: -------------------------------------------------------------------------------- 1 | from state_diff.env.kitchen.base import KitchenBase 2 | 3 | 4 | class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase): 5 | TASK_ELEMENTS = ["microwave", "kettle", "bottom burner", "light switch"] 6 | COMPLETE_IN_ANY_ORDER = False 7 | 8 | 9 | class KitchenMicrowaveKettleLightSliderV0(KitchenBase): 10 | TASK_ELEMENTS = ["microwave", "kettle", "light switch", "slide cabinet"] 11 | COMPLETE_IN_ANY_ORDER = False 12 | 13 | 14 | class KitchenKettleMicrowaveLightSliderV0(KitchenBase): 15 | TASK_ELEMENTS = ["kettle", 
"microwave", "light switch", "slide cabinet"] 16 | COMPLETE_IN_ANY_ORDER = False 17 | 18 | 19 | class KitchenAllV0(KitchenBase): 20 | TASK_ELEMENTS = KitchenBase.ALL_TASKS 21 | -------------------------------------------------------------------------------- /state_diff/common/env_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def render_env_video(env, states, actions=None): 6 | observations = states 7 | imgs = list() 8 | for i in range(len(observations)): 9 | state = observations[i] 10 | env.set_state(state) 11 | if i == 0: 12 | env.set_state(state) 13 | img = env.render() 14 | # draw action 15 | if actions is not None: 16 | action = actions[i] 17 | coord = (action / 512 * 96).astype(np.int32) 18 | cv2.drawMarker(img, coord, 19 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 20 | markerSize=8, thickness=1) 21 | imgs.append(img) 22 | imgs = np.array(imgs) 23 | return imgs 24 | -------------------------------------------------------------------------------- /state_diff/model/common/shape_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | def get_module_device(m: nn.Module): 6 | device = torch.device('cpu') 7 | try: 8 | param = next(iter(m.parameters())) 9 | device = param.device 10 | except StopIteration: 11 | pass 12 | return device 13 | 14 | @torch.no_grad() 15 | def get_output_shape( 16 | input_shape: Tuple[int], 17 | net: Callable[[torch.Tensor], torch.Tensor] 18 | ): 19 | device = get_module_device(net) 20 | test_input = torch.zeros((1,)+tuple(input_shape), device=device) 21 | test_output = net(test_input) 22 | output_shape = tuple(test_output.shape[1:]) 23 | return output_shape 24 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import adept_envs.franka 18 | 19 | from adept_envs.utils.configurable import global_config 20 | -------------------------------------------------------------------------------- /state_diff/scripts/episode_lengths.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import click 10 | import numpy as np 11 | import json 12 | from state_diff.common.replay_buffer import ReplayBuffer 13 | 14 | @click.command() 15 | @click.option('--input', '-i', required=True) 16 | @click.option('--dt', default=0.1, type=float) 17 | def main(input, dt): 18 | buffer = ReplayBuffer.create_from_path(input) 19 | lengths = buffer.episode_lengths 20 | durations = lengths * dt 21 | result = { 22 | 'duration/mean': np.mean(durations) 23 | } 24 | 25 | text = json.dumps(result, indent=2) 26 | print(text) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/basic_scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /state_diff/common/nested_dict_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | def nested_dict_map(f, x): 4 | """ 5 | Map f over all leaf of nested dict x 6 | """ 7 | 8 | if not isinstance(x, dict): 9 | return f(x) 10 | y = dict() 11 | for key, value in x.items(): 12 | y[key] = nested_dict_map(f, value) 13 | return y 14 | 15 | def nested_dict_reduce(f, x): 16 | """ 17 | Map f over all values of nested dict x, and reduce to a single value 18 | """ 19 | if not isinstance(x, dict): 20 | return x 21 | 22 | reduced_values = list() 23 | for value in x.values(): 24 | reduced_values.append(nested_dict_reduce(f, value)) 25 | y = functools.reduce(f, reduced_values) 26 | return y 27 | 28 | 29 | def nested_dict_check(f, x): 30 | bool_dict = nested_dict_map(f, x) 31 | result = nested_dict_reduce(lambda x, y: x and y, bool_dict) 32 | return result 33 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/workspace.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/workspace_real.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /state_diff/config/task/kitchen_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: kitchen_lowdim 2 | 3 | obs_dim: 60 4 | action_dim: 9 5 | keypoint_dim: 3 6 | num_agents: 1 7 | 8 | dataset_dir: &dataset_dir data/kitchen 9 | 10 | env_runner: 11 | _target_: state_diff.env_runner.kitchen_lowdim_runner.KitchenLowdimRunner 12 | dataset_dir: *dataset_dir 13 | n_train: 6 14 | 
n_train_vis: 2 15 | train_start_seed: 0 16 | n_test: 50 17 | n_test_vis: 4 18 | test_start_seed: 100000 19 | max_steps: 280 20 | n_obs_steps: ${n_obs_steps} 21 | n_action_steps: ${n_action_steps} 22 | render_hw: [240, 360] 23 | fps: 12.5 24 | past_action: ${past_action_visible} 25 | n_envs: null 26 | 27 | dataset: 28 | _target_: state_diff.dataset.kitchen_lowdim_dataset.KitchenLowdimDataset 29 | dataset_dir: *dataset_dir 30 | horizon: ${horizon} 31 | pad_before: ${eval:'${n_obs_steps}-1'} 32 | pad_after: ${eval:'${n_action_steps}-1'} 33 | seed: 42 34 | val_ratio: 0.02 35 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/block.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/block2.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/franka/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from gym.envs.registration import register 18 | 19 | # Relax the robot 20 | register( 21 | id='kitchen_relax-v1', 22 | entry_point='adept_envs.franka.kitchen_multitask_v0:KitchenTaskRelaxV1', 23 | max_episode_steps=280, 24 | ) -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/utils/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import os 18 | 19 | ENVS_ROOT_PATH = os.path.abspath(os.path.join( 20 | os.path.dirname(os.path.abspath(__file__)), 21 | "../../")) 22 | 23 | MODELS_PATH = os.path.abspath(os.path.join(ENVS_ROOT_PATH, "../adept_models/")) 24 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/blocks/red_moon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/blocks/blue_cube.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/blocks/green_star.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/__init__.py: -------------------------------------------------------------------------------- 1 | """Environments using kitchen and Franka robot.""" 2 | from gym.envs.registration import register 3 | 4 | register( 5 | id="kitchen-microwave-kettle-light-slider-v0", 6 | entry_point="state_diff.env.kitchen.v0:KitchenMicrowaveKettleLightSliderV0", 7 | max_episode_steps=280, 8 | reward_threshold=1.0, 9 | ) 10 | 11 | register( 12 | id="kitchen-microwave-kettle-burner-light-v0", 13 | entry_point="state_diff.env.kitchen.v0:KitchenMicrowaveKettleBottomBurnerLightV0", 14 | max_episode_steps=280, 15 | reward_threshold=1.0, 16 | ) 17 | 18 | register( 19 | id="kitchen-kettle-microwave-light-slider-v0", 20 | entry_point="state_diff.env.kitchen.v0:KitchenKettleMicrowaveLightSliderV0", 21 | max_episode_steps=280, 22 | reward_threshold=1.0, 23 | ) 24 | 25 | register( 26 | id="kitchen-all-v0", 27 | entry_point="state_diff.env.kitchen.v0:KitchenAllV0", 28 | max_episode_steps=280, 29 | reward_threshold=1.0, 30 | ) 31 | -------------------------------------------------------------------------------- /state_diff/scripts/blockpush_abs_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import pathlib 12 | from state_diff.common.replay_buffer import ReplayBuffer 13 | 14 | 15 | @click.command() 16 | @click.option('-i', '--input', required=True) 17 | @click.option('-o', '--output', required=True) 18 | @click.option('-t', '--target_eef_idx', default=8, type=int) 19 | def main(input, output, target_eef_idx): 20 | buffer = ReplayBuffer.copy_from_path(input) 21 | obs = buffer['obs'] 22 | action = buffer['action'] 23 | prev_eef_target = obs[:,target_eef_idx:target_eef_idx+action.shape[1]] 24 | next_eef_target = prev_eef_target + action 25 | action[:] = next_eef_target 26 | buffer.save_to_path(zarr_path=output, chunk_length=-1) 27 | 28 | if 
__name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/blocks/yellow_pentagon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /state_diff/config/task/blockpush_lowdim_seed.yaml: -------------------------------------------------------------------------------- 1 | name: blockpush_lowdim_seed 2 | 3 | obs_dim: 16 4 | action_dim: 2 5 | keypoint_dim: 2 6 | obs_eef_target: True 7 | num_agents: 1 8 | 9 | env_runner: 10 | _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner 11 | n_train: 6 12 | n_train_vis: 2 13 | train_start_seed: 0 14 | n_test: 50 15 | n_test_vis: 4 16 | test_start_seed: 100000 17 | max_steps: 350 18 | n_obs_steps: ${n_obs_steps} 19 | n_action_steps: ${n_action_steps} 20 | fps: 5 21 | past_action: ${past_action_visible} 22 | abs_action: False 23 | obs_eef_target: ${task.obs_eef_target} 24 | n_envs: null 25 | 26 | dataset: 27 | _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset 28 | zarr_path: data/block_pushing/multimodal_push_seed.zarr 29 | horizon: ${horizon} 30 | pad_before: ${eval:'${n_obs_steps}-1'} 31 | pad_after: ${eval:'${n_action_steps}-1'} 32 | obs_eef_target: ${task.obs_eef_target} 33 | use_manual_normalizer: False 34 | seed: 42 35 | val_ratio: 0.02 36 | -------------------------------------------------------------------------------- /tests/test_robomimic_lowdim_runner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | from state_diff.env_runner.robomimic_lowdim_runner import RobomimicLowdimRunner 9 | 10 | def test(): 11 | import os 12 | from omegaconf import OmegaConf 13 | cfg_path = os.path.expanduser('/projects/haonan2/bimanual_ur5e/state_diff/config/task/lift_lowdim.yaml') 14 | cfg = OmegaConf.load(cfg_path) 15 | cfg['n_obs_steps'] = 1 16 | cfg['n_action_steps'] = 1 17 | cfg['past_action_visible'] = False 18 | runner_cfg = cfg['env_runner'] 19 | runner_cfg['n_train'] = 1 20 | runner_cfg['n_test'] = 0 21 | del runner_cfg['_target_'] 22 | runner = RobomimicLowdimRunner( 23 | **runner_cfg, 24 | output_dir='/tmp/test') 25 | 26 | # import pdb; pdb.set_trace() 27 | 28 | self = runner 29 | env = self.env 30 | env.seed(seeds=self.env_seeds) 31 | obs = env.reset() 32 | 33 | if __name__ == '__main__': 34 | test() 35 | -------------------------------------------------------------------------------- /state_diff/config/task/blockpush_lowdim_seed_abs.yaml: -------------------------------------------------------------------------------- 1 | name: blockpush_lowdim_seed_abs 2 | 3 | obs_dim: 16 4 | action_dim: 2 5 | keypoint_dim: 2 6 | obs_eef_target: True 7 | num_agents: 1 8 | 9 | env_runner: 10 | _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner 11 | n_train: 6 12 | n_train_vis: 2 13 | train_start_seed: 0 14 | n_test: 50 15 | n_test_vis: 4 16 | test_start_seed: 100000 17 | max_steps: 350 18 | n_obs_steps: ${n_obs_steps} 19 | n_action_steps: ${n_action_steps} 20 | fps: 5 21 | past_action: ${past_action_visible} 22 | abs_action: True 23 | obs_eef_target: 
${task.obs_eef_target} 24 | n_envs: null 25 | 26 | dataset: 27 | _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset 28 | zarr_path: data/block_pushing/multimodal_push_seed_abs.zarr 29 | horizon: ${horizon} 30 | pad_before: ${eval:'${n_obs_steps}-1'} 31 | pad_after: ${eval:'${n_action_steps}-1'} 32 | obs_eef_target: ${task.obs_eef_target} 33 | use_manual_normalizer: False 34 | seed: 42 35 | val_ratio: 0.02 36 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/backwall_chain.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /state_diff/config/task/blockpush_traj_lowdim_seed.yaml: -------------------------------------------------------------------------------- 1 | name: blockpush_traj_lowdim_seed 2 | 3 | obs_dim: 16 4 | action_dim: 2 5 | keypoint_dim: 2 6 | inv_hidden_dim: 256 7 | obs_eef_target: True 8 | num_agents: 1 9 | 10 | env_runner: 11 | _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner 12 | n_train: 6 13 | n_train_vis: 2 14 | train_start_seed: 0 15 | n_test: 50 16 | n_test_vis: 4 17 | test_start_seed: 100000 18 | max_steps: 350 19 | n_obs_steps: ${n_obs_steps} 20 | n_action_steps: ${n_action_steps} 21 | fps: 5 22 | past_action: ${past_action_visible} 23 | abs_action: False 24 | obs_eef_target: ${task.obs_eef_target} 25 | n_envs: null 26 | num_agents: ${task.num_agents} 27 | 28 | dataset: 29 | _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset 30 | zarr_path: data/block_pushing/multimodal_push_seed.zarr 31 | horizon: ${horizon} 32 | pad_before: ${eval:'${n_obs_steps}-1'} 33 | pad_after: ${eval:'${n_action_steps}-1'} 34 | obs_eef_target: ${task.obs_eef_target} 35 | use_manual_normalizer: False 36 | seed: 42 37 | val_ratio: 0.02 38 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/backwall_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /state_diff/config/task/blockpush_traj_lowdim_seed_abs.yaml: -------------------------------------------------------------------------------- 1 | name: blockpush_traj_lowdim_seed_abs 2 | 3 | obs_dim: 16 4 | action_dim: 2 5 | keypoint_dim: 2 6 | inv_hidden_dim: 256 7 | obs_eef_target: True 8 | num_agents: 1 9 | 10 | env_runner: 11 | _target_: state_diff.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner 12 | n_train: 6 13 | n_train_vis: 2 14 | train_start_seed: 0 15 | n_test: 50 16 | n_test_vis: 4 17 | test_start_seed: 100000 18 | max_steps: 350 19 | n_obs_steps: ${n_obs_steps} 20 | n_action_steps: ${n_action_steps} 21 | fps: 5 22 | past_action: ${past_action_visible} 23 | abs_action: True 24 | obs_eef_target: ${task.obs_eef_target} 25 | n_envs: null 26 | num_agents: ${task.num_agents} 27 | 28 | dataset: 29 | _target_: state_diff.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset 30 | zarr_path: data/block_pushing/multimodal_push_seed_abs.zarr 31 | horizon: ${horizon} 32 | pad_before: ${eval:'${n_obs_steps}-1'} 33 | pad_after: 
${eval:'${n_action_steps}-1'} 34 | obs_eef_target: ${task.obs_eef_target} 35 | use_manual_normalizer: False 36 | seed: 42 37 | val_ratio: 0.02 38 | -------------------------------------------------------------------------------- /state_diff/config/task/kitchen_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: kitchen_lowdim 2 | 3 | obs_dim: 60 4 | action_dim: 9 5 | keypoint_dim: 3 6 | num_agents: 1 7 | 8 | abs_action: True 9 | robot_noise_ratio: 0.1 10 | 11 | env_runner: 12 | _target_: state_diff.env_runner.kitchen_lowdim_runner.KitchenLowdimRunner 13 | dataset_dir: data/kitchen 14 | n_train: 6 15 | n_train_vis: 2 16 | train_start_seed: 0 17 | n_test: 50 18 | n_test_vis: 4 19 | test_start_seed: 100000 20 | max_steps: 280 21 | n_obs_steps: ${n_obs_steps} 22 | n_action_steps: ${n_action_steps} 23 | render_hw: [240, 360] 24 | fps: 12.5 25 | past_action: ${past_action_visible} 26 | abs_action: ${task.abs_action} 27 | robot_noise_ratio: ${task.robot_noise_ratio} 28 | n_envs: null 29 | 30 | dataset: 31 | _target_: state_diff.dataset.kitchen_mjl_lowdim_dataset.KitchenMjlLowdimDataset 32 | dataset_dir: data/kitchen/kitchen_demos_multitask 33 | horizon: ${horizon} 34 | pad_before: ${eval:'${n_obs_steps}-1'} 35 | pad_after: ${eval:'${n_action_steps}-1'} 36 | abs_action: ${task.abs_action} 37 | robot_noise_ratio: ${task.robot_noise_ratio} 38 | seed: 42 39 | val_ratio: 0.02 40 | -------------------------------------------------------------------------------- /tests/test_robomimic_image_runner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | from state_diff.env_runner.robomimic_image_runner import RobomimicImageRunner 9 | 10 | def test(): 11 | import os 12 | from omegaconf import OmegaConf 13 | cfg_path = os.path.expanduser('~/dev/state_diff/state_diff/config/task/lift_image.yaml') 14 | cfg = OmegaConf.load(cfg_path) 15 | cfg['n_obs_steps'] = 1 16 | cfg['n_action_steps'] = 1 17 | cfg['past_action_visible'] = False 18 | runner_cfg = cfg['env_runner'] 19 | runner_cfg['n_train'] = 1 20 | runner_cfg['n_test'] = 1 21 | del runner_cfg['_target_'] 22 | runner = RobomimicImageRunner( 23 | **runner_cfg, 24 | output_dir='/tmp/test') 25 | 26 | # import pdb; pdb.set_trace() 27 | 28 | self = runner 29 | env = self.env 30 | env.seed(seeds=self.env_seeds) 31 | obs = env.reset() 32 | for i in range(10): 33 | _ = env.step(env.action_space.sample()) 34 | 35 | imgs = env.render() 36 | 37 | if __name__ == '__main__': 38 | test() 39 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/slidecabinet_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | Training: 4 | python train.py --config-name=train_diffusion_trajectory_unet_lowdim_workspace 5 | """ 6 | 7 | import sys 8 | # use line-buffering for both stdout and stderr 9 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1) 10 | sys.stderr = open(sys.stderr.fileno(), mode='w', 
buffering=1) 11 | 12 | import hydra 13 | from omegaconf import OmegaConf 14 | import pathlib 15 | from state_diff.workspace.base_workspace import BaseWorkspace 16 | 17 | # allows arbitrary python code execution in configs using the ${eval:''} resolver 18 | OmegaConf.register_new_resolver("eval", eval, replace=True) 19 | 20 | @hydra.main( 21 | version_base=None, 22 | config_path=str(pathlib.Path(__file__).parent.joinpath( 23 | 'state_diff','config'), 24 | ), 25 | config_name="train_diffusion_trajectory_unet_lowdim_workspace" 26 | ) 27 | def main(cfg: OmegaConf): 28 | # resolve immediately so all the ${now:} resolvers 29 | # will use the same time. 30 | OmegaConf.resolve(cfg) 31 | 32 | cls = hydra.utils.get_class(cfg._target_) 33 | workspace: BaseWorkspace = cls(cfg) 34 | workspace.run() 35 | 36 | if __name__ == "__main__": 37 | main() -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/CONTRIBUTING.public.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 
29 | -------------------------------------------------------------------------------- /state_diff/shared_memory/shared_memory_util.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from dataclasses import dataclass 3 | import numpy as np 4 | from multiprocessing.managers import SharedMemoryManager 5 | from atomics import atomicview, MemoryOrder, UINT 6 | 7 | @dataclass 8 | class ArraySpec: 9 | name: str 10 | shape: Tuple[int] 11 | dtype: np.dtype 12 | 13 | 14 | class SharedAtomicCounter: 15 | def __init__(self, 16 | shm_manager: SharedMemoryManager, 17 | size :int=8 # 64bit int 18 | ): 19 | shm = shm_manager.SharedMemory(size=size) 20 | self.shm = shm 21 | self.size = size 22 | self.store(0) # initialize 23 | 24 | @property 25 | def buf(self): 26 | return self.shm.buf[:self.size] 27 | 28 | def load(self) -> int: 29 | with atomicview(buffer=self.buf, atype=UINT) as a: 30 | value = a.load(order=MemoryOrder.ACQUIRE) 31 | return value 32 | 33 | def store(self, value: int): 34 | with atomicview(buffer=self.buf, atype=UINT) as a: 35 | a.store(value, order=MemoryOrder.RELEASE) 36 | 37 | def add(self, value: int): 38 | with atomicview(buffer=self.buf, atype=UINT) as a: 39 | a.add(value, order=MemoryOrder.ACQ_REL) 40 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/oracles/pushing_info.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Reach ML Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Dataclass holding info needed for pushing oracles.""" 17 | import dataclasses 18 | from typing import Any 19 | 20 | 21 | @dataclasses.dataclass 22 | class PushingInfo: 23 | """Holds onto info necessary for pushing state machine.""" 24 | 25 | xy_block: Any = None 26 | xy_ee: Any = None 27 | xy_pre_block: Any = None 28 | xy_delta_to_nexttoblock: Any = None 29 | xy_delta_to_touchingblock: Any = None 30 | xy_dir_block_to_ee: Any = None 31 | theta_threshold_to_orient: Any = None 32 | theta_threshold_flat_enough: Any = None 33 | theta_error: Any = None 34 | obstacle_poses: Any = None 35 | distance_to_target: Any = None 36 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/teleop_actuator.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /state_diff/model/diffusion/conv1d_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from einops.layers.torch import Rearrange 5 | 6 | 7 | class Downsample1d(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 11 | 12 | def forward(self, x): 13 | return self.conv(x) 14 | 15 | class Upsample1d(nn.Module): 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 19 | 20 | def forward(self, x): 21 | return self.conv(x) 22 | 23 | class Conv1dBlock(nn.Module): 24 | ''' 25 | Conv1d --> GroupNorm --> Mish 26 | ''' 27 | 28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 29 | super().__init__() 30 | 31 | self.block = nn.Sequential( 32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'), 34 | nn.GroupNorm(n_groups, out_channels), 35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'), 36 | nn.Mish(), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.block(x) 41 | 42 | 43 | def test(): 44 | cb = Conv1dBlock(256, 128, kernel_size=3) 45 | x = torch.zeros((1,256,16)) 46 | o = cb(x) 47 | -------------------------------------------------------------------------------- /state_diff/gym_util/video_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | class VideoWrapper(gym.Wrapper): 5 | def __init__(self, 6 | env, 7 | mode='rgb_array', 8 | enabled=True, 9 | steps_per_render=1, 10 | **kwargs 11 | ): 12 | super().__init__(env) 13 | 14 | self.mode = mode 15 | self.enabled = enabled 16 | self.render_kwargs = kwargs 17 | self.steps_per_render = steps_per_render 18 | 19 | self.frames = list() 20 | self.step_count = 0 21 | 22 | def reset(self, **kwargs): 23 | obs = super().reset(**kwargs) 24 | self.frames = list() 25 | self.step_count = 1 26 | if self.enabled: 27 | frame = self.env.render( 28 | mode=self.mode, **self.render_kwargs) 29 | assert frame.dtype == np.uint8 30 | self.frames.append(frame) 31 | return obs 32 | 33 | def step(self, action): 34 | result = super().step(action) 35 | self.step_count += 1 36 | if self.enabled and ((self.step_count % self.steps_per_render) == 0): 37 | 
frame = self.env.render( 38 | mode=self.mode, **self.render_kwargs) 39 | assert frame.dtype == np.uint8 40 | self.frames.append(frame) 41 | return result 42 | 43 | def render(self, mode='rgb_array', **kwargs): 44 | return self.frames 45 | -------------------------------------------------------------------------------- /tests/test_block_pushing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | from state_diff.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal 9 | from gym.wrappers import FlattenObservation 10 | from state_diff.gym_util.multistep_wrapper import MultiStepWrapper 11 | from state_diff.gym_util.video_wrapper import VideoWrapper 12 | 13 | def test(): 14 | env = MultiStepWrapper( 15 | VideoWrapper( 16 | FlattenObservation( 17 | BlockPushMultimodal() 18 | ), 19 | enabled=True, 20 | steps_per_render=2 21 | ), 22 | n_obs_steps=2, 23 | n_action_steps=8, 24 | max_episode_steps=16 25 | ) 26 | env = BlockPushMultimodal() 27 | obs = env.reset() 28 | import pdb; pdb.set_trace() 29 | 30 | env = FlattenObservation(BlockPushMultimodal()) 31 | obs = env.reset() 32 | action = env.action_space.sample() 33 | next_obs, reward, done, info = env.step(action) 34 | print(obs[8:10] + action - next_obs[8:10]) 35 | import pdb; pdb.set_trace() 36 | 37 | for i in range(3): 38 | obs, reward, done, info = env.step(env.action_space.sample()) 39 | img = env.render() 40 | import pdb; pdb.set_trace() 41 | print("Done!", done) 42 | 43 | if __name__ == '__main__': 44 | test() 45 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/hingecabinet_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /state_diff/config/task/pushl_traj_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: pushl_traj_lowdim 2 | 3 | obs_dim: 40 # 40 or 42 # 9*2 keypoints + 10*2 keypoints + 4 state 4 | action_dim: 2 5 | keypoint_dim: 2 6 | inv_hidden_dim: 256 7 | num_agents: 1 8 | 9 | obs_key: keypoint # keypoint, structured_keypoint 10 | state_key: state 11 | action_key: action # action1,action2 12 | 13 | env_runner: 14 | _target_: state_diff.env_runner.push_keypoints_runner.PushKeypointsRunner 15 | keypoint_visible_rate: ${keypoint_visible_rate} 16 | n_train: 6 17 | n_train_vis: 2 18 | train_start_seed: 0 19 | n_test: 50 20 | n_test_vis: 4 21 | legacy_test: True 22 | test_start_seed: 100000 23 | max_steps: 500 24 | n_obs_steps: ${n_obs_steps} 25 | n_action_steps: ${n_action_steps} 26 | n_latency_steps: ${n_latency_steps} 27 | fps: 10 28 | agent_keypoints: False 29 | past_action: ${past_action_visible} 30 | n_envs: null 31 | obs_key: ${task.obs_key} 32 | state_key: ${task.state_key} 33 | action_key: ${task.action_key} 34 | env_type: ${task.name} 35 | 36 | dataset: 37 | _target_: state_diff.dataset.pusht_traj_dataset.PushTTrajLowdimDataset 38 | zarr_path: data/pushl_dataset 39 | horizon: ${horizon} 40 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 41 | pad_after: ${eval:'${n_action_steps}-1'} 42 | seed: 42 43 | val_ratio: 0.02 44 | 
max_train_episodes: null 45 | obs_key: ${task.obs_key} 46 | state_key: ${task.state_key} 47 | action_key: ${task.action_key} 48 | num_episodes: null 49 | 50 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/zone.obj: -------------------------------------------------------------------------------- 1 | # Object Export From Tinkercad Server 2015 2 | 3 | mtllib obj.mtl 4 | 5 | o obj_0 6 | v 10 -10 20 7 | v 10 -10 0 8 | v 10 10 0 9 | v 10 10 20 10 | v 9.002 9.003 20 11 | v 9.002 -9.002 20 12 | v -10 10 0 13 | v -10 10 20 14 | v -9.003 9.003 20 15 | v -9.003 9.003 0 16 | v 9.002 9.003 0 17 | v 9.002 -9.002 0 18 | v -9.003 -9.002 0 19 | v -9.003 -9.002 20 20 | v -10 -10 0 21 | v -10 -10 20 22 | # 16 vertices 23 | 24 | g group_0_15277357 25 | 26 | usemtl color_15277357 27 | s 0 28 | 29 | f 1 2 3 30 | f 1 3 4 31 | f 4 5 6 32 | f 4 6 1 33 | f 9 10 11 34 | f 9 11 5 35 | f 6 12 13 36 | f 6 13 14 37 | f 10 9 14 38 | f 10 14 13 39 | f 7 10 13 40 | f 7 13 15 41 | f 4 8 5 42 | f 9 5 8 43 | f 8 7 15 44 | f 8 15 16 45 | f 10 7 11 46 | f 3 11 7 47 | f 11 3 12 48 | f 2 12 3 49 | f 14 16 6 50 | f 1 6 16 51 | f 16 15 2 52 | f 16 2 1 53 | f 9 8 14 54 | f 16 14 8 55 | f 7 8 3 56 | f 4 3 8 57 | f 2 15 12 58 | f 13 12 15 59 | f 12 6 5 60 | f 12 5 11 61 | # 32 faces 62 | 63 | #end of obj_0 64 | 65 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/kettle_chain.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /state_diff/model/common/dict_of_tensor_mixin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DictOfTensorMixin(nn.Module): 5 | def __init__(self, params_dict=None): 6 | super().__init__() 7 | if params_dict is None: 8 | params_dict = nn.ParameterDict() 9 | self.params_dict = params_dict 10 | 11 | @property 12 | def device(self): 13 | return next(iter(self.parameters())).device 14 | 15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 16 | def dfs_add(dest, keys, value: torch.Tensor): 17 | if len(keys) == 1: 18 | dest[keys[0]] = value 19 | return 20 | 21 | if keys[0] not in dest: 22 | dest[keys[0]] = nn.ParameterDict() 23 | dfs_add(dest[keys[0]], keys[1:], value) 24 | 25 | def load_dict(state_dict, prefix): 26 | out_dict = nn.ParameterDict() 27 | for key, value in state_dict.items(): 28 | value: torch.Tensor 29 | if key.startswith(prefix): 30 | param_keys = key[len(prefix):].split('.')[1:] 31 | # if len(param_keys) == 0: 32 | # import pdb; pdb.set_trace() 33 | dfs_add(out_dict, param_keys, value.clone()) 34 | return out_dict 35 | 36 | self.params_dict = load_dict(state_dict, prefix + 'params_dict') 37 | self.params_dict.requires_grad_(False) 38 | return 39 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/kettle_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 
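A minimal usage sketch for the DictOfTensorMixin class defined above in state_diff/model/common/dict_of_tensor_mixin.py; the Stats subclass and tensor values below are illustrative only, while the save/load round trip relies on the custom _load_from_state_dict shown in that file:

import torch
import torch.nn as nn
from state_diff.model.common.dict_of_tensor_mixin import DictOfTensorMixin

class Stats(DictOfTensorMixin):
    # illustrative subclass; the mixin already provides params_dict and device
    pass

stats = Stats()
stats.params_dict['offset'] = nn.Parameter(torch.zeros(3), requires_grad=False)
state = stats.state_dict()              # contains the key 'params_dict.offset'

restored = Stats()
restored.load_state_dict(state)         # the nested ParameterDict is rebuilt on load
assert torch.allclose(restored.params_dict['offset'], stats.params_dict['offset'])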
-------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator1.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /state_diff/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn 5 | from state_diff.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseLowdimDataset(torch.utils.data.Dataset): 8 | def get_validation_dataset(self) -> 'BaseLowdimDataset': 9 | # return an empty dataset by default 10 | return BaseLowdimDataset() 11 | 12 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 13 | raise NotImplementedError() 14 | 15 | def get_all_actions(self) -> torch.Tensor: 16 | raise NotImplementedError() 17 | 18 | def __len__(self) -> int: 19 | return 0 20 | 21 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 22 | """ 23 | output: 24 | obs: T, Do 25 | action: T, Da 26 | """ 27 | raise NotImplementedError() 28 | 29 | 30 | class BaseImageDataset(torch.utils.data.Dataset): 31 | def get_validation_dataset(self) -> 'BaseLowdimDataset': 32 | # return an empty dataset by default 33 | return BaseImageDataset() 34 | 35 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 36 | raise NotImplementedError() 37 | 38 | def get_all_actions(self) -> torch.Tensor: 39 | raise NotImplementedError() 40 | 41 | def __len__(self) -> int: 42 | return 0 43 | 44 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 45 | """ 46 | output: 47 | obs: 48 | key: T, * 49 | action: T, Da 50 | """ 51 | raise NotImplementedError() 52 | -------------------------------------------------------------------------------- /tests/test_precise_sleep.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | import time 9 | import numpy as np 10 | from state_diff.common.precise_sleep import precise_sleep, precise_wait 11 | 12 | 13 | def test_sleep(): 14 | dt = 0.1 15 | tol = 1e-3 16 | time_samples = list() 17 | for i in range(100): 18 | precise_sleep(dt) 19 | # time.sleep(dt) 20 | time_samples.append(time.monotonic()) 21 | time_deltas = np.diff(time_samples) 22 | 23 | from matplotlib import pyplot as plt 24 | plt.plot(time_deltas) 25 | plt.ylim((dt-tol,dt+tol)) 26 | 27 | 28 | def test_wait(): 29 | dt = 0.1 30 | tol = 1e-3 31 | errors = list() 32 | t_start = time.monotonic() 33 | for i in range(1,100): 34 | t_end_desired = t_start + i * dt 35 | time.sleep(t_end_desired - time.monotonic()) 36 | t_end = time.monotonic() 37 | errors.append(t_end - t_end_desired) 38 | 39 | new_errors = list() 40 | t_start = time.monotonic() 41 | for i in range(1,100): 42 | t_end_desired = t_start + i * dt 43 | precise_wait(t_end_desired) 44 | t_end = time.monotonic() 45 | new_errors.append(t_end - t_end_desired) 46 | 47 | from matplotlib import pyplot as plt 48 | plt.plot(errors, label='time.sleep') 49 | plt.plot(new_errors, label='sleep/spin hybrid') 50 | plt.ylim((-tol,+tol)) 51 | plt.title('0.1 sec sleep error') 52 | plt.legend() 53 | 54 | 55 | if __name__ == '__main__': 56 | test_sleep() 57 | 
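The BaseLowdimDataset interface shown above in state_diff/dataset/base_dataset.py is what the task datasets referenced from the task configs implement. A minimal illustrative subclass, not taken from the repository and with the LinearNormalizer.fit signature assumed, might look like this:

from typing import Dict
import torch
from state_diff.dataset.base_dataset import BaseLowdimDataset
from state_diff.model.common.normalizer import LinearNormalizer

class ToyLowdimDataset(BaseLowdimDataset):
    def __init__(self, obs: torch.Tensor, action: torch.Tensor, horizon: int = 16):
        # obs: (N, Do), action: (N, Da) for a single episode
        self.obs, self.action, self.horizon = obs, action, horizon

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        normalizer = LinearNormalizer()
        # fit(...) signature assumed here; see the concrete datasets for the real call
        normalizer.fit({'obs': self.obs, 'action': self.action}, last_n_dims=1, **kwargs)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        return self.action

    def __len__(self) -> int:
        return max(0, self.obs.shape[0] - self.horizon + 1)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # each sample is a horizon-length window: obs (T, Do), action (T, Da)
        return {
            'obs': self.obs[idx:idx + self.horizon],
            'action': self.action[idx:idx + self.horizon],
        }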
-------------------------------------------------------------------------------- /state_diff/common/pymunk_util.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import pymunk 3 | import pymunk.pygame_util 4 | import numpy as np 5 | 6 | COLLTYPE_DEFAULT = 0 7 | COLLTYPE_MOUSE = 1 8 | COLLTYPE_BALL = 2 9 | 10 | def get_body_type(static=False): 11 | body_type = pymunk.Body.DYNAMIC 12 | if static: 13 | body_type = pymunk.Body.STATIC 14 | return body_type 15 | 16 | 17 | def create_rectangle(space, 18 | pos_x,pos_y,width,height, 19 | density=3,static=False): 20 | body = pymunk.Body(body_type=get_body_type(static)) 21 | body.position = (pos_x,pos_y) 22 | shape = pymunk.Poly.create_box(body,(width,height)) 23 | shape.density = density 24 | space.add(body,shape) 25 | return body, shape 26 | 27 | 28 | def create_rectangle_bb(space, 29 | left, bottom, right, top, 30 | **kwargs): 31 | pos_x = (left + right) / 2 32 | pos_y = (top + bottom) / 2 33 | height = top - bottom 34 | width = right - left 35 | return create_rectangle(space, pos_x, pos_y, width, height, **kwargs) 36 | 37 | def create_circle(space, pos_x, pos_y, radius, density=3, static=False): 38 | body = pymunk.Body(body_type=get_body_type(static)) 39 | body.position = (pos_x, pos_y) 40 | shape = pymunk.Circle(body, radius=radius) 41 | shape.density = density 42 | shape.collision_type = COLLTYPE_BALL 43 | space.add(body, shape) 44 | return body, shape 45 | 46 | def get_body_state(body): 47 | state = np.zeros(6, dtype=np.float32) 48 | state[:2] = body.position 49 | state[2] = body.angle 50 | state[3:5] = body.velocity 51 | state[5] = body.angular_velocity 52 | return state 53 | -------------------------------------------------------------------------------- /state_diff/scripts/bet_blockpush_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | 10 | import os 11 | import click 12 | import pathlib 13 | import numpy as np 14 | from state_diff.common.replay_buffer import ReplayBuffer 15 | 16 | @click.command() 17 | @click.option('-i', '--input', required=True, help='input dir contains npy files') 18 | @click.option('-o', '--output', required=True, help='output zarr path') 19 | @click.option('--abs_action', is_flag=True, default=False) 20 | def main(input, output, abs_action): 21 | data_directory = pathlib.Path(input) 22 | observations = np.load( 23 | data_directory / "multimodal_push_observations.npy" 24 | ) 25 | actions = np.load(data_directory / "multimodal_push_actions.npy") 26 | masks = np.load(data_directory / "multimodal_push_masks.npy") 27 | 28 | buffer = ReplayBuffer.create_empty_numpy() 29 | for i in range(len(masks)): 30 | eps_len = int(masks[i].sum()) 31 | obs = observations[i,:eps_len].astype(np.float32) 32 | action = actions[i,:eps_len].astype(np.float32) 33 | if abs_action: 34 | prev_eef_target = obs[:,8:10] 35 | next_eef_target = prev_eef_target + action 36 | action = next_eef_target 37 | data = { 38 | 'obs': obs, 39 | 'action': action 40 | } 41 | buffer.add_episode(data) 42 | 43 | buffer.save_to_path(zarr_path=output, chunk_length=-1) 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/kitchen_lowdim_wrapper.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional, Optional 2 | import numpy as np 3 | import gym 4 | from gym.spaces import Box 5 | from state_diff.env.kitchen.base import KitchenBase 6 | 7 | class KitchenLowdimWrapper(gym.Env): 8 | def __init__(self, 9 | env: KitchenBase, 10 | init_qpos: Optional[np.ndarray]=None, 11 | init_qvel: Optional[np.ndarray]=None, 12 | render_hw = (240,360) 13 | ): 14 | self.env = env 15 | self.init_qpos = init_qpos 16 | self.init_qvel = init_qvel 17 | self.render_hw = render_hw 18 | 19 | @property 20 | def action_space(self): 21 | return self.env.action_space 22 | 23 | @property 24 | def observation_space(self): 25 | return self.env.observation_space 26 | 27 | def seed(self, seed=None): 28 | return self.env.seed(seed) 29 | 30 | def reset(self): 31 | if self.init_qpos is not None: 32 | # reset anyway to be safe, not very expensive 33 | _ = self.env.reset() 34 | # start from known state 35 | self.env.set_state(self.init_qpos, self.init_qvel) 36 | obs = self.env._get_obs() 37 | return obs 38 | # obs, _, _, _ = self.env.step(np.zeros_like( 39 | # self.action_space.sample())) 40 | # return obs 41 | else: 42 | return self.env.reset() 43 | 44 | def render(self, mode='rgb_array'): 45 | h, w = self.render_hw 46 | return self.env.render(mode=mode, width=w, height=h) 47 | 48 | def step(self, a): 49 | return self.env.step(a) 50 | -------------------------------------------------------------------------------- /state_diff/config/task/pushl_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: pushl_lowdim 2 | 3 | 4 | 5 | # For single agent pushl 6 | obs_dim: 40 # 9*2 keypoints + 10*2 keypoints + 2 state 7 | action_dim: 2 8 | keypoint_dim: 2 9 | num_agents: 1 10 | 11 | 12 | obs_key: keypoint # keypoint, structured_keypoint 13 | state_key: state 14 | action_key: action 15 | 16 | 17 | # # For double agent pushl 18 | # obs_dim: 42 # 9*2 keypoints + 10*2 keypoints + 4 state 19 | # action_dim: 4 20 | # keypoint_dim: 2 21 | # obs_key: keypoint # keypoint, structured_keypoint 22 | # state_key: state 23 | # action_key: action1,action2 24 | 25 | 26 | 27 | env_runner: 28 | _target_: state_diff.env_runner.push_keypoints_runner.PushKeypointsRunner 29 | keypoint_visible_rate: ${keypoint_visible_rate} 30 | n_train: 6 31 | n_train_vis: 2 32 | train_start_seed: 0 33 | n_test: 200 # 50 34 | n_test_vis: 4 35 | legacy_test: True 36 | test_start_seed: 100000 37 | max_steps: 500 38 | n_obs_steps: ${n_obs_steps} 39 | n_action_steps: ${n_action_steps} 40 | n_latency_steps: ${n_latency_steps} 41 | fps: 10 42 | agent_keypoints: False 43 | past_action: ${past_action_visible} 44 | n_envs: null 45 | obs_key: ${task.obs_key} 46 | state_key: ${task.state_key} 47 | action_key: ${task.action_key} 48 | 49 | 50 | dataset: 51 | _target_: state_diff.dataset.pusht_dataset.PushTLowdimDataset 52 | zarr_path: data/pushl_dataset 53 | horizon: ${horizon} 54 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 55 | pad_after: ${eval:'${n_action_steps}-1'} 56 | seed: 42 57 | val_ratio: 0.02 58 | max_train_episodes: null 59 | obs_key: ${task.obs_key} 60 | state_key: ${task.state_key} 61 | action_key: ${task.action_key} 62 | num_episodes: 600 63 | 64 | -------------------------------------------------------------------------------- /tests/test_shared_queue.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = 
os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | import numpy as np 9 | from multiprocessing.managers import SharedMemoryManager 10 | from state_diff.shared_memory.shared_memory_queue import SharedMemoryQueue, Full, Empty 11 | 12 | 13 | def test(): 14 | shm_manager = SharedMemoryManager() 15 | shm_manager.start() 16 | example = { 17 | 'cmd': 0, 18 | 'pose': np.zeros((6,)) 19 | } 20 | queue = SharedMemoryQueue.create_from_examples( 21 | shm_manager=shm_manager, 22 | examples=example, 23 | buffer_size=3 24 | ) 25 | raised = False 26 | try: 27 | queue.get() 28 | except Empty: 29 | raised = True 30 | assert raised 31 | 32 | data = { 33 | 'cmd': 1, 34 | 'pose': np.ones((6,)) 35 | } 36 | queue.put(data) 37 | result = queue.get() 38 | assert result['cmd'] == data['cmd'] 39 | assert np.allclose(result['pose'], data['pose']) 40 | 41 | queue.put(data) 42 | queue.put(data) 43 | queue.put(data) 44 | assert queue.qsize() == 3 45 | raised = False 46 | try: 47 | queue.put(data) 48 | except Full: 49 | raised = True 50 | assert raised 51 | 52 | result = queue.get_all() 53 | assert np.allclose(result['cmd'], [1,1,1]) 54 | 55 | queue.put({'cmd': 0}) 56 | queue.put({'cmd': 1}) 57 | queue.put({'cmd': 2}) 58 | queue.get() 59 | queue.put({'cmd': 3}) 60 | 61 | result = queue.get_k(3) 62 | assert np.allclose(result['cmd'], [1,2,3]) 63 | 64 | queue.clear() 65 | 66 | if __name__ == "__main__": 67 | test() 68 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/microwave_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/actuator0.xml: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /state_diff/common/robomimic_config_util.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from robomimic.config import config_factory 3 | import robomimic.scripts.generate_paper_configs as gpc 4 | from robomimic.scripts.generate_paper_configs import ( 5 | modify_config_for_default_image_exp, 6 | modify_config_for_default_low_dim_exp, 7 | modify_config_for_dataset, 8 | ) 9 | 10 | def get_robomimic_config( 11 | algo_name='bc_rnn', 12 | hdf5_type='low_dim', 13 | task_name='square', 14 | dataset_type='ph' 15 | ): 16 | base_dataset_dir = '/tmp/null' 17 | filter_key = None 18 | 19 | # decide whether to use low-dim or image training defaults 20 | modifier_for_obs = modify_config_for_default_image_exp 21 | if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]: 22 | modifier_for_obs = modify_config_for_default_low_dim_exp 23 | 24 | algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name 25 | config = config_factory(algo_name=algo_config_name) 26 | # turn into default config for observation modalities (e.g.: low-dim or rgb) 27 | config = modifier_for_obs(config) 28 | # add in config based on the dataset 29 | config = modify_config_for_dataset( 30 | config=config, 31 | 
task_name=task_name, 32 | dataset_type=dataset_type, 33 | hdf5_type=hdf5_type, 34 | base_dataset_dir=base_dataset_dir, 35 | filter_key=filter_key, 36 | ) 37 | # add in algo hypers based on dataset 38 | algo_config_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset') 39 | config = algo_config_modifier( 40 | config=config, 41 | task_name=task_name, 42 | dataset_type=dataset_type, 43 | hdf5_type=hdf5_type, 44 | ) 45 | return config 46 | 47 | 48 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/scenes/basic_scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/insert.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda.xml: -------------------------------------------------------------------------------- 1 | 4 | 5 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /tests/test_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | import zarr 9 | from state_diff.common.replay_buffer import ReplayBuffer 10 | 11 | def test(): 12 | import numpy as np 13 | buff = ReplayBuffer.create_empty_numpy() 14 | buff.add_episode({ 15 | 'obs': np.zeros((100,10), dtype=np.float16) 16 | }) 17 | buff.add_episode({ 18 | 'obs': np.ones((50,10)), 19 | 'action': np.ones((50,2)) 20 | }) 21 | # buff.rechunk(256) 22 | obs = buff.get_episode(0) 23 | 24 | import numpy as np 25 | buff = ReplayBuffer.create_empty_zarr() 26 | buff.add_episode({ 27 | 'obs': np.zeros((100,10), dtype=np.float16) 28 | }) 29 | buff.add_episode({ 30 | 'obs': np.ones((50,10)), 31 | 'action': np.ones((50,2)) 32 | }) 33 | obs = buff.get_episode(0) 34 | buff.set_chunks({ 35 | 'obs': (100,10), 36 | 'action': (100,2) 37 | }) 38 | 39 | 40 | def test_real(): 41 | import os 42 | dist_group = zarr.open( 43 | os.path.expanduser('~/dev/state_diff/data/pusht/pusht_cchi_v2.zarr'), 'r') 44 | 45 | buff = ReplayBuffer.create_empty_numpy() 46 | key, group = next(iter(dist_group.items())) 47 | for key, group in dist_group.items(): 48 | buff.add_episode(group) 49 | 50 | # out_path = os.path.expanduser('~/dev/state_diff/data/pusht_cchi2_v2_replay.zarr') 51 | out_path = os.path.expanduser('~/dev/state_diff/data/test.zarr') 52 | out_store = zarr.DirectoryStore(out_path) 53 | buff.save_to_store(out_store) 54 | 55 | buff = ReplayBuffer.copy_from_path(out_path, store=zarr.MemoryStore()) 56 | buff.pop_episode() 57 | 58 | 59 | def test_pop(): 60 | buff = ReplayBuffer.create_from_path( 61 | 
'/home/chengchi/dev/state_diff/data/pusht_cchi_v3_replay.zarr', 62 | mode='rw') 63 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/counters_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/kitchen.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /conda_environment.yaml: -------------------------------------------------------------------------------- 1 | name: coord_bimanual 2 | channels: 3 | - pytorch 4 | - pytorch3d 5 | - nvidia 6 | - conda-forge 7 | dependencies: 8 | - python=3.9 9 | - pip=22.2.2 10 | - cudatoolkit=11.6 11 | - pytorch=1.12.1 12 | - torchvision=0.13.1 13 | - pytorch3d=0.7.0 14 | - numpy=1.23.3 15 | - numba==0.56.4 16 | - scipy==1.9.1 17 | - py-opencv=4.6.0 18 | - cffi=1.15.1 19 | - ipykernel=6.16 20 | - matplotlib=3.6.1 21 | - zarr=2.12.0 22 | - numcodecs=0.10.2 23 | - h5py=3.7.0 24 | - hydra-core=1.2.0 25 | - einops=0.4.1 26 | - tqdm=4.64.1 27 | - dill=0.3.5.1 28 | - scikit-video=1.1.11 29 | - scikit-image=0.19.3 30 | - gym=0.21.0 31 | - pymunk=6.2.1 32 | - wandb=0.13.3 33 | - threadpoolctl=3.1.0 34 | - shapely=1.8.4 35 | - cython=0.29.32 36 | - imageio=2.22.0 37 | - imageio-ffmpeg=0.4.7 38 | - termcolor=2.0.1 39 | - tensorboard=2.10.1 40 | - tensorboardx=2.5.1 41 | - psutil=5.9.2 42 | - click=8.0.4 43 | - boto3=1.24.96 44 | - accelerate=0.13.2 45 | - datasets=2.6.1 46 | - diffusers=0.11.1 47 | - av=10.0.0 48 | - cmake=3.24.3 49 | # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625 50 | - llvm-openmp=14 51 | # trick to force reinstall imagecodecs via pip 52 | - imagecodecs==2022.9.26 53 | - pip: 54 | - ray[default,tune]==2.2.0 55 | # requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf 56 | - free-mujoco-py==2.1.6 57 | - pygame==2.1.2 58 | - pybullet-svl==3.1.6.4 59 | - robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz 60 | - robomimic==0.2.0 61 | - pytorchvideo==0.1.5 62 | # pip package required for jpeg-xl 63 | - imagecodecs==2022.9.26 64 | - r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz 65 | - dm-control==1.0.9 66 | - huggingface-hub==0.25.2 67 | - -e . 
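  # to create this environment: conda env create -f conda_environment.yaml && conda activate coord_bimanual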
68 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/kitchen_util.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import numpy as np 3 | 4 | def parse_mjl_logs(read_filename, skipamount): 5 | with open(read_filename, mode='rb') as file: 6 | fileContent = file.read() 7 | headers = struct.unpack('iiiiiii', fileContent[:28]) 8 | nq = headers[0] 9 | nv = headers[1] 10 | nu = headers[2] 11 | nmocap = headers[3] 12 | nsensordata = headers[4] 13 | nuserdata = headers[5] 14 | name_len = headers[6] 15 | name = struct.unpack(str(name_len) + 's', fileContent[28:28+name_len])[0] 16 | rem_size = len(fileContent[28 + name_len:]) 17 | num_floats = int(rem_size/4) 18 | dat = np.asarray(struct.unpack(str(num_floats) + 'f', fileContent[28+name_len:])) 19 | recsz = 1 + nq + nv + nu + 7*nmocap + nsensordata + nuserdata 20 | if rem_size % recsz != 0: 21 | print("ERROR") 22 | else: 23 | dat = np.reshape(dat, (int(len(dat)/recsz), recsz)) 24 | dat = dat.T 25 | 26 | time = dat[0,:][::skipamount] - 0*dat[0, 0] 27 | qpos = dat[1:nq + 1, :].T[::skipamount, :] 28 | qvel = dat[nq+1:nq+nv+1,:].T[::skipamount, :] 29 | ctrl = dat[nq+nv+1:nq+nv+nu+1,:].T[::skipamount,:] 30 | mocap_pos = dat[nq+nv+nu+1:nq+nv+nu+3*nmocap+1,:].T[::skipamount, :] 31 | mocap_quat = dat[nq+nv+nu+3*nmocap+1:nq+nv+nu+7*nmocap+1,:].T[::skipamount, :] 32 | sensordata = dat[nq+nv+nu+7*nmocap+1:nq+nv+nu+7*nmocap+nsensordata+1,:].T[::skipamount,:] 33 | userdata = dat[nq+nv+nu+7*nmocap+nsensordata+1:,:].T[::skipamount,:] 34 | 35 | data = dict(nq=nq, 36 | nv=nv, 37 | nu=nu, 38 | nmocap=nmocap, 39 | nsensordata=nsensordata, 40 | name=name, 41 | time=time, 42 | qpos=qpos, 43 | qvel=qvel, 44 | ctrl=ctrl, 45 | mocap_pos=mocap_pos, 46 | mocap_quat=mocap_quat, 47 | sensordata=sensordata, 48 | userdata=userdata, 49 | logName = read_filename 50 | ) 51 | return data 52 | -------------------------------------------------------------------------------- /state_diff/model/common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers.optimization import ( 2 | Union, SchedulerType, Optional, 3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION 4 | ) 5 | 6 | def get_scheduler( 7 | name: Union[str, SchedulerType], 8 | optimizer: Optimizer, 9 | num_warmup_steps: Optional[int] = None, 10 | num_training_steps: Optional[int] = None, 11 | **kwargs 12 | ): 13 | """ 14 | Added kwargs vs diffuser's original implementation 15 | 16 | Unified API to get any scheduler from its name. 17 | 18 | Args: 19 | name (`str` or `SchedulerType`): 20 | The name of the scheduler to use. 21 | optimizer (`torch.optim.Optimizer`): 22 | The optimizer that will be used during training. 23 | num_warmup_steps (`int`, *optional*): 24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 25 | optional), the function will raise an error if it's unset and the scheduler type requires it. 26 | num_training_steps (`int``, *optional*): 27 | The number of training steps to do. This is not required by all schedulers (hence the argument being 28 | optional), the function will raise an error if it's unset and the scheduler type requires it. 
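    Example (illustrative, for any torch.optim.Optimizer instance ``optimizer``):
        scheduler = get_scheduler(
            'cosine', optimizer=optimizer,
            num_warmup_steps=500, num_training_steps=100000)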
29 | """ 30 | name = SchedulerType(name) 31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 32 | if name == SchedulerType.CONSTANT: 33 | return schedule_func(optimizer, **kwargs) 34 | 35 | # All other schedulers require `num_warmup_steps` 36 | if num_warmup_steps is None: 37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 38 | 39 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) 41 | 42 | # All other schedulers require `num_training_steps` 43 | if num_training_steps is None: 44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 45 | 46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs) 47 | -------------------------------------------------------------------------------- /state_diff/scripts/real_dataset_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import pathlib 12 | import zarr 13 | import cv2 14 | import threadpoolctl 15 | from state_diff.real_world.real_data_conversion import real_data_to_replay_buffer 16 | 17 | @click.command() 18 | @click.option('--input', '-i', required=True) 19 | @click.option('--output', '-o', default=None) 20 | @click.option('--resolution', '-r', default='640x480') 21 | @click.option('--n_decoding_threads', '-nd', default=-1, type=int) 22 | @click.option('--n_encoding_threads', '-ne', default=-1, type=int) 23 | def main(input, output, resolution, n_decoding_threads, n_encoding_threads): 24 | out_resolution = tuple(int(x) for x in resolution.split('x')) 25 | input = pathlib.Path(os.path.expanduser(input)) 26 | in_zarr_path = input.joinpath('replay_buffer.zarr') 27 | in_video_dir = input.joinpath('videos') 28 | assert in_zarr_path.is_dir() 29 | assert in_video_dir.is_dir() 30 | if output is None: 31 | output = input.joinpath(resolution + '.zarr.zip') 32 | else: 33 | output = pathlib.Path(os.path.expanduser(output)) 34 | 35 | if output.exists(): 36 | click.confirm('Output path already exists! 
Overwrite?', abort=True) 37 | 38 | cv2.setNumThreads(1) 39 | with threadpoolctl.threadpool_limits(1): 40 | replay_buffer = real_data_to_replay_buffer( 41 | dataset_path=str(input), 42 | out_resolutions=out_resolution, 43 | n_decoding_threads=n_decoding_threads, 44 | n_encoding_threads=n_encoding_threads 45 | ) 46 | 47 | print('Saving to disk') 48 | if output.suffix == '.zip': 49 | with zarr.ZipStore(output) as zip_store: 50 | replay_buffer.save_to_store( 51 | store=zip_store 52 | ) 53 | else: 54 | with zarr.DirectoryStore(output) as store: 55 | replay_buffer.save_to_store( 56 | store=store 57 | ) 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/suction/suction-base.urdf: -------------------------------------------------------------------------------- [URDF markup not captured in this text dump] -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/slidecabinet_chain.xml: -------------------------------------------------------------------------------- [XML markup not captured in this text dump] -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/suction/suction-head.urdf: -------------------------------------------------------------------------------- [URDF markup not captured in this text dump] -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/microwave_chain.xml: -------------------------------------------------------------------------------- [XML markup not captured in this text dump] -------------------------------------------------------------------------------- /state_diff/scripts/generate_bet_blockpush.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | 10 | import os 11 | import click 12 | import pathlib 13 | import numpy as np 14 | from tqdm import tqdm 15 | from state_diff.common.replay_buffer import ReplayBuffer 16 | from tf_agents.environments.wrappers import TimeLimit 17 | from tf_agents.environments.gym_wrapper import GymWrapper 18 | from tf_agents.trajectories.time_step import StepType 19 | from state_diff.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal 20 | from state_diff.env.block_pushing.block_pushing import BlockPush 21 | from state_diff.env.block_pushing.oracles.multimodal_push_oracle import
MultimodalOrientedPushOracle 22 | 23 | @click.command() 24 | @click.option('-o', '--output', required=True) 25 | @click.option('-n', '--n_episodes', default=1000) 26 | @click.option('-c', '--chunk_length', default=-1) 27 | def main(output, n_episodes, chunk_length): 28 | 29 | buffer = ReplayBuffer.create_empty_numpy() 30 | env = TimeLimit(GymWrapper(BlockPushMultimodal()), duration=350) 31 | for i in tqdm(range(n_episodes)): 32 | print(i) 33 | obs_history = list() 34 | action_history = list() 35 | 36 | env.seed(i) 37 | policy = MultimodalOrientedPushOracle(env) 38 | time_step = env.reset() 39 | policy_state = policy.get_initial_state(1) 40 | while True: 41 | action_step = policy.action(time_step, policy_state) 42 | obs = np.concatenate(list(time_step.observation.values()), axis=-1) 43 | action = action_step.action 44 | obs_history.append(obs) 45 | action_history.append(action) 46 | 47 | if time_step.step_type == 2: 48 | break 49 | 50 | # state = env.wrapped_env().gym.get_pybullet_state() 51 | time_step = env.step(action) 52 | obs_history = np.array(obs_history) 53 | action_history = np.array(action_history) 54 | 55 | episode = { 56 | 'obs': obs_history, 57 | 'action': action_history 58 | } 59 | buffer.add_episode(episode) 60 | 61 | buffer.save_to_path(output, chunk_length=chunk_length) 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/utils/pose3d.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Reach ML Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A simple 6DOF pose container. 17 | """ 18 | 19 | import dataclasses 20 | import numpy as np 21 | from scipy.spatial import transform 22 | 23 | 24 | class NoCopyAsDict(object): 25 | """Base class for dataclasses. Avoids a copy in the asdict() call.""" 26 | 27 | def asdict(self): 28 | """Replacement for dataclasses.asdict. 29 | 30 | TF Dataset does not handle dataclasses.asdict, which uses copy.deepcopy when 31 | setting values in the output dict. This causes issues with tf.Dataset. 32 | Instead, shallow copy contents. 33 | 34 | Returns: 35 | dict containing contents of dataclass. 
36 | """ 37 | return {k.name: getattr(self, k.name) for k in dataclasses.fields(self)} 38 | 39 | 40 | @dataclasses.dataclass 41 | class Pose3d(NoCopyAsDict): 42 | """Simple container for translation and rotation.""" 43 | 44 | rotation: transform.Rotation 45 | translation: np.ndarray 46 | 47 | @property 48 | def vec7(self): 49 | return np.concatenate([self.translation, self.rotation.as_quat()]) 50 | 51 | def serialize(self): 52 | return { 53 | "rotation": self.rotation.as_quat().tolist(), 54 | "translation": self.translation.tolist(), 55 | } 56 | 57 | @staticmethod 58 | def deserialize(data): 59 | return Pose3d( 60 | rotation=transform.Rotation.from_quat(data["rotation"]), 61 | translation=np.array(data["translation"]), 62 | ) 63 | 64 | def __eq__(self, other): 65 | return np.array_equal( 66 | self.rotation.as_quat(), other.rotation.as_quat() 67 | ) and np.array_equal(self.translation, other.translation) 68 | 69 | def __ne__(self, other): 70 | return not self.__eq__(other) 71 | -------------------------------------------------------------------------------- /state_diff/common/precise_sleep.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import time 4 | 5 | def precise_sleep(dt: float, slack_time: float = 0.001, time_func=time.monotonic): 6 | """ 7 | Performs a precise sleep for a specified duration using a combination of time.sleep and active spinning. 8 | 9 | Args: 10 | dt (float): The total duration to sleep, in seconds. 11 | slack_time (float, optional): The duration to actively spin (busy-wait) after sleeping, to achieve precision. Defaults to 0.001 seconds. 12 | time_func (function, optional): The time function used for measuring time. Defaults to time.monotonic. 13 | 14 | Description: 15 | This function first uses time.sleep to sleep for 'dt - slack_time' seconds, 16 | allowing for low CPU usage during most of the sleep duration. Then, it actively spins 17 | (busy-wait) for the remaining 'slack_time' to ensure precise wake-up. This hybrid approach 18 | minimizes jitter caused by the non-deterministic nature of time.sleep. 19 | """ 20 | t_start = time_func() 21 | if dt > slack_time: 22 | time.sleep(dt - slack_time) 23 | t_end = t_start + dt 24 | while time_func() < t_end: 25 | pass 26 | return 27 | 28 | def precise_wait(t_end: float, slack_time: float = 0.001, time_func=time.monotonic): 29 | """ 30 | Waits until a specified end time using a combination of time.sleep and active spinning for precision. 31 | 32 | Args: 33 | t_end (float): The target end time in seconds since a fixed point in the past (e.g., system start). 34 | slack_time (float, optional): The duration to actively spin (busy-wait) before reaching the target time, to achieve precision. Defaults to 0.001 seconds. 35 | time_func (function, optional): The time function used for measuring time. Defaults to time.monotonic. 36 | 37 | Description: 38 | This function calculates the remaining time to 't_end' and then uses a combination of 39 | time.sleep and active spinning to wait until this target time. It sleeps for the bulk of 40 | the remaining time, then actively spins for the final 'slack_time' duration, 41 | thus ensuring precise timing. 
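    Example (illustrative):
        t_next = time.monotonic() + 0.1
        precise_wait(t_next)   # returns very close to t_next, with far less jitter than time.sleep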
42 | """ 43 | t_start = time_func() 44 | t_wait = t_end - t_start 45 | if t_wait > 0: 46 | t_sleep = t_wait - slack_time 47 | if t_sleep > 0: 48 | time.sleep(t_sleep) 49 | while time_func() < t_end: 50 | pass 51 | return 52 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/franka_panda_teleop.xml: -------------------------------------------------------------------------------- 1 | 4 | 5 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/oracles/reach_oracle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Reach ML Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Reach oracle.""" 17 | import numpy as np 18 | from tf_agents.policies import py_policy 19 | from tf_agents.trajectories import policy_step 20 | from tf_agents.trajectories import time_step as ts 21 | from tf_agents.typing import types 22 | 23 | # Only used for debug visualization. 24 | import pybullet # pylint: disable=unused-import 25 | 26 | 27 | class ReachOracle(py_policy.PyPolicy): 28 | """Oracle for moving to a specific spot relative to the block and target.""" 29 | 30 | def __init__(self, env, block_pushing_oracles_action_std=0.0): 31 | super(ReachOracle, self).__init__(env.time_step_spec(), env.action_spec()) 32 | self._env = env 33 | self._np_random_state = np.random.RandomState(0) 34 | self._block_pushing_oracles_action_std = block_pushing_oracles_action_std 35 | 36 | def _action(self, time_step, policy_state): 37 | 38 | # Specifying this as velocity makes it independent of control frequency. 39 | max_step_velocity = 0.2 40 | 41 | xy_ee = time_step.observation["effector_target_translation"] 42 | 43 | # This should be observable from block and target translation, 44 | # but re-using the computation from the env so that it's only done once, and 45 | # used for reward / completion computation. 
46 | xy_pre_block = self._env.reach_target_translation 47 | 48 | xy_delta = xy_pre_block - xy_ee 49 | 50 | if self._block_pushing_oracles_action_std != 0.0: 51 | xy_delta += ( 52 | self._np_random_state.randn(2) * self._block_pushing_oracles_action_std 53 | ) 54 | 55 | max_step_distance = max_step_velocity * (1 / self._env.get_control_frequency()) 56 | length = np.linalg.norm(xy_delta) 57 | if length > max_step_distance: 58 | xy_direction = xy_delta / length 59 | xy_delta = xy_direction * max_step_distance 60 | 61 | return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32)) 62 | -------------------------------------------------------------------------------- /state_diff/scripts/real_pusht_successrate.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import collections 12 | import numpy as np 13 | from tqdm import tqdm 14 | import json 15 | 16 | @click.command() 17 | @click.option( 18 | '--reference', '-r', required=True, 19 | help='Reference metrics_raw.json from demonstration dataset.' 20 | ) 21 | @click.option( 22 | '--input', '-i', required=True, 23 | help='Data search path' 24 | ) 25 | def main(reference, input): 26 | # compute the min last metric for demo metrics 27 | demo_metrics = json.load(open(reference, 'r')) 28 | demo_min_metrics = collections.defaultdict(lambda:float('inf')) 29 | for episode_idx, metrics in demo_metrics.items(): 30 | for key, value in metrics.items(): 31 | last_value = value[-1] 32 | demo_min_metrics[key] = min(demo_min_metrics[key], last_value) 33 | print(demo_min_metrics) 34 | 35 | # find all metric 36 | name = 'metrics_raw.json' 37 | search_dir = pathlib.Path(input) 38 | success_rate_map = dict() 39 | for json_path in search_dir.glob('**/'+name): 40 | rel_path = json_path.relative_to(search_dir) 41 | rel_name = str(rel_path.parent) 42 | this_metrics = json.load(json_path.open('r')) 43 | metric_success_idxs = collections.defaultdict(list) 44 | metric_failure_idxs = collections.defaultdict(list) 45 | for episode_idx, metrics in this_metrics.items(): 46 | for key, value in metrics.items(): 47 | last_value = value[-1] 48 | # print(episode_idx, key, last_value) 49 | demo_min = demo_min_metrics[key] 50 | if last_value >= demo_min: 51 | # success 52 | metric_success_idxs[key].append(episode_idx) 53 | else: 54 | metric_failure_idxs[key].append(episode_idx) 55 | # in case of no success 56 | _ = metric_success_idxs[key] 57 | _ = metric_failure_idxs[key] 58 | metric_success_rate = dict() 59 | n_episodes = len(this_metrics) 60 | for key, value in metric_success_idxs.items(): 61 | metric_success_rate[key] = len(value) / n_episodes 62 | # metric_success_rate['failured_idxs'] = metric_failure_idxs 63 | success_rate_map[rel_name] = metric_success_rate 64 | 65 | text = json.dumps(success_rate_map, indent=2) 66 | print(text) 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /tests/test_multi_realsense.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | import cv2 9 | import json 10 | import time 11 | import numpy as np 12 | from 
state_diff.real_world.multi_realsense import MultiRealsense 13 | from state_diff.real_world.video_recorder import VideoRecorder 14 | 15 | def test(): 16 | config = json.load(open('/home/cchi/dev/state_diff/state_diff/real_world/realsense_config/415_high_accuracy_mode.json', 'r')) 17 | 18 | def transform(data): 19 | color = data['color'] 20 | h,w,_ = color.shape 21 | factor = 4 22 | color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA) 23 | # color = color[:,140:500] 24 | data['color'] = color 25 | return data 26 | 27 | from state_diff.common.cv2_util import get_image_transform 28 | color_transform = get_image_transform( 29 | input_res=(1280,720), 30 | output_res=(640,480), 31 | bgr_to_rgb=False) 32 | def transform(data): 33 | data['color'] = color_transform(data['color']) 34 | return data 35 | 36 | # one thread per camera 37 | video_recorder = VideoRecorder.create_h264( 38 | fps=30, 39 | codec='h264', 40 | thread_type='FRAME' 41 | ) 42 | 43 | with MultiRealsense( 44 | resolution=(1280,720), 45 | capture_fps=30, 46 | record_fps=15, 47 | enable_color=True, 48 | # advanced_mode_config=config, 49 | transform=transform, 50 | # recording_transform=transform, 51 | # video_recorder=video_recorder, 52 | verbose=True 53 | ) as realsense: 54 | realsense.set_exposure(exposure=150, gain=5) 55 | intr = realsense.get_intr_mat() 56 | print(intr) 57 | 58 | video_path = 'data_local/test' 59 | rec_start_time = time.time() + 1 60 | realsense.start_recording(video_path, start_time=rec_start_time) 61 | realsense.restart_put(rec_start_time) 62 | 63 | out = None 64 | vis_img = None 65 | while True: 66 | out = realsense.get(out=out) 67 | 68 | # bgr = out['color'] 69 | # print(bgr.shape) 70 | # vis_img = np.concatenate(list(bgr), axis=0, out=vis_img) 71 | # cv2.imshow('default', vis_img) 72 | # key = cv2.pollKey() 73 | # if key == ord('q'): 74 | # break 75 | 76 | time.sleep(1/60) 77 | if time.time() > (rec_start_time + 20.0): 78 | break 79 | 80 | 81 | if __name__ == "__main__": 82 | test() 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | logs 3 | wandb 4 | outputs 5 | data 6 | data_local 7 | .vscode 8 | _wandb 9 | 10 | **/.DS_Store 11 | 12 | fuse.cfg 13 | 14 | *.ai 15 | 16 | # Generation results 17 | results/ 18 | 19 | ray/auth.json 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | *.out 81 | local_settings.py 82 | db.sqlite3 83 | db.sqlite3-journal 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | 144 | wandb_api_key.txt 145 | 146 | scripts/slurm/slurm_outputs -------------------------------------------------------------------------------- /state_diff/model/common/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 
46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/oracles/discontinuous_push_oracle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Reach ML Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Pushes to first target, waits, then pushes to second target.""" 17 | 18 | import state_diff.env.block_pushing.oracles.oriented_push_oracle as oriented_push_oracle_module 19 | import numpy as np 20 | from tf_agents.trajectories import policy_step 21 | from tf_agents.trajectories import time_step as ts 22 | from tf_agents.typing import types 23 | 24 | # Only used for debug visualization. 25 | import pybullet # pylint: disable=unused-import 26 | 27 | 28 | class DiscontinuousOrientedPushOracle(oriented_push_oracle_module.OrientedPushOracle): 29 | """Pushes to first target, waits, then pushes to second target.""" 30 | 31 | def __init__(self, env, goal_tolerance=0.04, wait=0): 32 | super(DiscontinuousOrientedPushOracle, self).__init__(env) 33 | self._countdown = 0 34 | self._wait = wait 35 | self._goal_dist_tolerance = goal_tolerance 36 | 37 | def reset(self): 38 | self.phase = "move_to_pre_block" 39 | self._countdown = 0 40 | 41 | def _action(self, time_step, policy_state): 42 | if time_step.is_first(): 43 | self.reset() 44 | # Move to first target first. 
45 | self._current_target = "target" 46 | self._has_switched = False 47 | 48 | def _block_target_dist(block, target): 49 | dist = np.linalg.norm( 50 | time_step.observation["%s_translation" % block] 51 | - time_step.observation["%s_translation" % target] 52 | ) 53 | return dist 54 | 55 | d1 = _block_target_dist("block", "target") 56 | if d1 < self._goal_dist_tolerance and not self._has_switched: 57 | self._countdown = self._wait 58 | # If first block has been pushed to first target, switch to second block. 59 | self._has_switched = True 60 | self._current_target = "target2" 61 | 62 | xy_delta = self._get_action_for_block_target( 63 | time_step, block="block", target=self._current_target 64 | ) 65 | 66 | if self._countdown > 0: 67 | xy_delta = np.zeros_like(xy_delta) 68 | self._countdown -= 1 69 | 70 | return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32)) 71 | -------------------------------------------------------------------------------- /state_diff/common/pytorch_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import collections 3 | import torch 4 | import torch.nn as nn 5 | 6 | def dict_apply( 7 | x: Dict[str, torch.Tensor], 8 | func: Callable[[torch.Tensor], torch.Tensor] 9 | ) -> Dict[str, torch.Tensor]: 10 | result = dict() 11 | for key, value in x.items(): 12 | if isinstance(value, dict): 13 | result[key] = dict_apply(value, func) 14 | else: 15 | result[key] = func(value) 16 | return result 17 | 18 | def pad_remaining_dims(x, target): 19 | assert x.shape == target.shape[:len(x.shape)] 20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape))) 21 | 22 | def dict_apply_split( 23 | x: Dict[str, torch.Tensor], 24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]] 25 | ) -> Dict[str, torch.Tensor]: 26 | results = collections.defaultdict(dict) 27 | for key, value in x.items(): 28 | result = split_func(value) 29 | for k, v in result.items(): 30 | results[k][key] = v 31 | return results 32 | 33 | def dict_apply_reduce( 34 | x: List[Dict[str, torch.Tensor]], 35 | reduce_func: Callable[[List[torch.Tensor]], torch.Tensor] 36 | ) -> Dict[str, torch.Tensor]: 37 | result = dict() 38 | for key in x[0].keys(): 39 | result[key] = reduce_func([x_[key] for x_ in x]) 40 | return result 41 | 42 | 43 | def replace_submodules( 44 | root_module: nn.Module, 45 | predicate: Callable[[nn.Module], bool], 46 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 47 | """ 48 | predicate: Return true if the module is to be replaced. 49 | func: Return new module to use. 
50 | """ 51 | if predicate(root_module): 52 | return func(root_module) 53 | 54 | bn_list = [k.split('.') for k, m 55 | in root_module.named_modules(remove_duplicate=True) 56 | if predicate(m)] 57 | for *parent, k in bn_list: 58 | parent_module = root_module 59 | if len(parent) > 0: 60 | parent_module = root_module.get_submodule('.'.join(parent)) 61 | if isinstance(parent_module, nn.Sequential): 62 | src_module = parent_module[int(k)] 63 | else: 64 | src_module = getattr(parent_module, k) 65 | tgt_module = func(src_module) 66 | if isinstance(parent_module, nn.Sequential): 67 | parent_module[int(k)] = tgt_module 68 | else: 69 | setattr(parent_module, k, tgt_module) 70 | # verify that all BN are replaced 71 | bn_list = [k.split('.') for k, m 72 | in root_module.named_modules(remove_duplicate=True) 73 | if predicate(m)] 74 | assert len(bn_list) == 0 75 | return root_module 76 | 77 | def optimizer_to(optimizer, device): 78 | for state in optimizer.state.values(): 79 | for k, v in state.items(): 80 | if isinstance(v, torch.Tensor): 81 | state[k] = v.to(device=device) 82 | return optimizer 83 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_models/kitchen/assets/oven_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/suction/cylinder_real.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/suction/cylinder.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /tests/test_single_realsense.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 5 | sys.path.append(ROOT_DIR) 6 | os.chdir(ROOT_DIR) 7 | 8 | import cv2 9 | import json 10 | import time 11 | from multiprocessing.managers import SharedMemoryManager 12 | from state_diff.real_world.single_realsense import SingleRealsense 13 | 14 | def test(): 15 | 16 | serials = SingleRealsense.get_connected_devices_serial() 17 | # import pdb; pdb.set_trace() 18 | serial = serials[0] 19 | config = 
json.load(open('/home/cchi/dev/state_diff/state_diff/real_world/realsense_config/415_high_accuracy_mode.json', 'r')) 20 | 21 | def transform(data): 22 | color = data['color'] 23 | h,w,_ = color.shape 24 | factor = 2 25 | color = cv2.resize(color, (w//factor,h//factor), interpolation=cv2.INTER_AREA) 26 | # color = color[:,140:500] 27 | data['color'] = color 28 | return data 29 | 30 | # at 960x540 with //3, 60fps and 30fps are indistinguishable 31 | 32 | with SharedMemoryManager() as shm_manager: 33 | with SingleRealsense( 34 | shm_manager=shm_manager, 35 | serial_number=serial, 36 | resolution=(1280,720), 37 | # resolution=(960,540), 38 | # resolution=(640,480), 39 | capture_fps=30, 40 | enable_color=True, 41 | # enable_depth=True, 42 | # enable_infrared=True, 43 | # advanced_mode_config=config, 44 | # transform=transform, 45 | # recording_transform=transform 46 | # verbose=True 47 | ) as realsense: 48 | cv2.setNumThreads(1) 49 | realsense.set_exposure(exposure=150, gain=5) 50 | intr = realsense.get_intr_mat() 51 | print(intr) 52 | 53 | 54 | video_path = 'data_local/test.mp4' 55 | rec_start_time = time.time() + 2 56 | realsense.start_recording(video_path, start_time=rec_start_time) 57 | 58 | data = None 59 | while True: 60 | data = realsense.get(out=data) 61 | t = time.time() 62 | # print('capture_latency', data['receive_timestamp']-data['capture_timestamp'], 'receive_latency', t - data['receive_timestamp']) 63 | # print('receive', t - data['receive_timestamp']) 64 | 65 | # dt = time.time() - data['timestamp'] 66 | # print(dt) 67 | # print(data['capture_timestamp'] - rec_start_time) 68 | 69 | bgr = data['color'] 70 | # print(bgr.shape) 71 | cv2.imshow('default', bgr) 72 | key = cv2.pollKey() 73 | # if key == ord('q'): 74 | # break 75 | # elif key == ord('r'): 76 | # video_path = 'data_local/test.mp4' 77 | # realsense.start_recording(video_path) 78 | # elif key == ord('s'): 79 | # realsense.stop_recording() 80 | 81 | time.sleep(1/60) 82 | if time.time() > (rec_start_time + 20.0): 83 | break 84 | 85 | 86 | if __name__ == "__main__": 87 | test() 88 | -------------------------------------------------------------------------------- /state_diff/env/block_pushing/assets/suction/suction-head-long.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /state_diff/dataset/kitchen_lowdim_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import numpy as np 4 | import copy 5 | import pathlib 6 | from state_diff.common.pytorch_util import dict_apply 7 | from state_diff.common.replay_buffer import ReplayBuffer 8 | from state_diff.common.sampler import SequenceSampler, get_val_mask 9 | from state_diff.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer 10 | from state_diff.dataset.base_dataset import BaseLowdimDataset 11 | 12 | class KitchenLowdimDataset(BaseLowdimDataset): 13 | def __init__(self, 14 | dataset_dir, 15 | horizon=1, 16 | pad_before=0, 17 | pad_after=0, 18 | seed=42, 19 | val_ratio=0.0 20 | ): 21 | 
super().__init__() 22 | 23 | data_directory = pathlib.Path(dataset_dir) 24 | observations = np.load(data_directory / "observations_seq.npy") 25 | actions = np.load(data_directory / "actions_seq.npy") 26 | masks = np.load(data_directory / "existence_mask.npy") 27 | 28 | self.replay_buffer = ReplayBuffer.create_empty_numpy() 29 | for i in range(len(masks)): 30 | eps_len = int(masks[i].sum()) 31 | obs = observations[i,:eps_len].astype(np.float32) 32 | action = actions[i,:eps_len].astype(np.float32) 33 | data = { 34 | 'obs': obs, 35 | 'action': action 36 | } 37 | self.replay_buffer.add_episode(data) 38 | 39 | val_mask = get_val_mask( 40 | n_episodes=self.replay_buffer.n_episodes, 41 | val_ratio=val_ratio, 42 | seed=seed) 43 | train_mask = ~val_mask 44 | self.sampler = SequenceSampler( 45 | replay_buffer=self.replay_buffer, 46 | sequence_length=horizon, 47 | pad_before=pad_before, 48 | pad_after=pad_after, 49 | episode_mask=train_mask) 50 | 51 | self.train_mask = train_mask 52 | self.horizon = horizon 53 | self.pad_before = pad_before 54 | self.pad_after = pad_after 55 | 56 | def get_validation_dataset(self): 57 | val_set = copy.copy(self) 58 | val_set.sampler = SequenceSampler( 59 | replay_buffer=self.replay_buffer, 60 | sequence_length=self.horizon, 61 | pad_before=self.pad_before, 62 | pad_after=self.pad_after, 63 | episode_mask=~self.train_mask 64 | ) 65 | val_set.train_mask = ~self.train_mask 66 | return val_set 67 | 68 | def get_normalizer(self, mode='limits', **kwargs): 69 | data = { 70 | 'obs': self.replay_buffer['obs'], 71 | 'action': self.replay_buffer['action'] 72 | } 73 | if 'range_eps' not in kwargs: 74 | # to prevent blowing up dims that barely change 75 | kwargs['range_eps'] = 5e-2 76 | normalizer = LinearNormalizer() 77 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs) 78 | return normalizer 79 | 80 | def get_all_actions(self) -> torch.Tensor: 81 | return torch.from_numpy(self.replay_buffer['action']) 82 | 83 | def __len__(self) -> int: 84 | return len(self.sampler) 85 | 86 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 87 | sample = self.sampler.sample_sequence(idx) 88 | data = sample 89 | 90 | torch_data = dict_apply(data, torch.from_numpy) 91 | return torch_data 92 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/bi-franka_panda.xml: -------------------------------------------------------------------------------- 1 | 4 | 5 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | / 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /state_diff/common/checkpoint_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | import os 3 | 4 | class TopKCheckpointManager: 5 | """ 6 | A manager for saving top K performing model checkpoints based on a specified metric. 7 | 8 | Attributes: 9 | save_dir (str): Directory where checkpoints are saved. 10 | monitor_key (str): Key for the metric to monitor in the data dictionary. 11 | mode (str): Mode for evaluation ('max' for higher is better, 'min' for lower is better). 12 | k (int): Number of top-performing checkpoints to keep. 
13 | format_str (str): Format string for generating checkpoint filenames. 14 | path_value_map (dict): A dictionary mapping checkpoint paths to their metric values. 15 | """ 16 | 17 | def __init__(self, 18 | save_dir, 19 | monitor_key: str, 20 | mode='min', 21 | k=1, 22 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt' 23 | ): 24 | """ 25 | Initializes the TopKCheckpointManager with the specified configuration. 26 | 27 | Args: 28 | save_dir (str): The directory to save checkpoints. 29 | monitor_key (str): The metric key to monitor for checkpointing. 30 | mode (str, optional): Determines if 'max' or 'min' values of the monitored metric are better. Defaults to 'min'. 31 | k (int, optional): The number of top checkpoints to maintain. Defaults to 1. 32 | format_str (str, optional): Format string for checkpoint filenames. Defaults to 'epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt'. 33 | """ 34 | 35 | assert mode in ['max', 'min'] 36 | assert k >= 0 37 | 38 | self.save_dir = save_dir 39 | self.monitor_key = monitor_key 40 | self.mode = mode 41 | self.k = k 42 | self.format_str = format_str 43 | self.path_value_map = dict() 44 | 45 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]: 46 | """ 47 | Determines the checkpoint path for the current metrics and decides whether to replace an existing checkpoint. 48 | 49 | Args: 50 | data (Dict[str, float]): A dictionary containing current metric data. 51 | 52 | Returns: 53 | Optional[str]: The path to save the new checkpoint if it is among the top K; otherwise, None. 54 | """ 55 | 56 | if self.k == 0: 57 | return None 58 | 59 | value = data[self.monitor_key] 60 | ckpt_path = os.path.join( 61 | self.save_dir, self.format_str.format(**data)) 62 | 63 | if len(self.path_value_map) < self.k: 64 | # under-capacity 65 | self.path_value_map[ckpt_path] = value 66 | return ckpt_path 67 | 68 | # at capacity 69 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1]) 70 | min_path, min_value = sorted_map[0] 71 | max_path, max_value = sorted_map[-1] 72 | 73 | delete_path = None 74 | if self.mode == 'max': 75 | if value > min_value: 76 | delete_path = min_path 77 | else: 78 | if value < max_value: 79 | delete_path = max_path 80 | 81 | if delete_path is None: 82 | return None 83 | else: 84 | del self.path_value_map[delete_path] 85 | self.path_value_map[ckpt_path] = value 86 | 87 | if not os.path.exists(self.save_dir): 88 | os.mkdir(self.save_dir) 89 | 90 | if os.path.exists(delete_path): 91 | os.remove(delete_path) 92 | return ckpt_path 93 | -------------------------------------------------------------------------------- /state_diff/gym_util/video_recording_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from state_diff.gym_util.video_recorder import VideoRecorder 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | class VideoRecordingWrapper(gym.Wrapper): 7 | def __init__(self, 8 | env, 9 | video_recoder: VideoRecorder, 10 | mode='rgb_array', 11 | file_path=None, 12 | steps_per_render=1, 13 | **kwargs 14 | ): 15 | """ 16 | When file_path is None, don't record. 
17 | """ 18 | super().__init__(env) 19 | 20 | self.mode = mode 21 | self.render_kwargs = kwargs 22 | self.steps_per_render = steps_per_render 23 | self.file_path = file_path 24 | self.video_recoder = video_recoder 25 | 26 | self.step_count = 0 27 | 28 | def reset(self, **kwargs): 29 | obs = super().reset(**kwargs) 30 | self.frames = list() 31 | self.step_count = 1 32 | self.video_recoder.stop() 33 | return obs 34 | 35 | def step(self, action): 36 | deno_actions = None 37 | if len(action.shape) == 2: 38 | deno_actions = action 39 | action = action[-1] 40 | result = super().step(action) 41 | self.step_count += 1 42 | if self.file_path is not None \ 43 | and ((self.step_count % self.steps_per_render) == 0): 44 | if not self.video_recoder.is_ready(): 45 | self.video_recoder.start(self.file_path) 46 | 47 | frame = self.env.render( 48 | mode=self.mode, **self.render_kwargs) 49 | 50 | if deno_actions is not None: 51 | frame = self.overlay_points_on_frame(frame, deno_actions) 52 | 53 | assert frame.dtype == np.uint8 54 | 55 | self.video_recoder.write_frame(frame) 56 | return result 57 | 58 | def render(self, mode='rgb_array', **kwargs): 59 | if self.video_recoder.is_ready(): 60 | self.video_recoder.stop() 61 | return self.file_path 62 | 63 | 64 | 65 | def overlay_points_on_frame(self, frame, points, color=(255, 0, 0), gradient=False): 66 | 67 | """Helper method to overlay deno_actions as points on the frame.""" 68 | # if points.shape[-1] == 4: 69 | # points = np.concatenate([points[:, :2], points[:, 2:]], axis=0) 70 | # elif points.shape[-1] == 6: 71 | # points = np.concatenate([points[:, :2], points[:, 2:4], points[:, 4:]], axis=0) 72 | 73 | # Mapping of possible dimensions to slicing rules 74 | slicing_map = { 75 | 4: [slice(0, 2), slice(2, 4)], 76 | 6: [slice(0, 2), slice(2, 4), slice(4, 6)] 77 | } 78 | 79 | # Check if the last dimension matches a supported format 80 | slices = slicing_map.get(points.shape[-1]) 81 | if slices: 82 | points = np.concatenate([points[:, sl] for sl in slices], axis=0) 83 | frame_size = frame.shape[0] 84 | if frame_size == 96: 85 | radians = 1 86 | elif frame_size == 512: 87 | radius = 4 88 | 89 | if gradient: 90 | colors = plt.cm.plasma(np.linspace(0, 1, points.shape[0]))[:, :3] * 255 91 | else: 92 | colors = [color] * points.shape[0] 93 | 94 | coord = (points / 512 * frame_size).astype(np.int32) 95 | for idx, point in enumerate(coord): 96 | x, y = map(int, point) 97 | cv2.circle(frame, (x, y), radius=radius, color=tuple(colors[idx]), thickness=-1) 98 | return frame 99 | 100 | -------------------------------------------------------------------------------- /state_diff/model/common/rotation_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import pytorch3d.transforms as pt 3 | import torch 4 | import numpy as np 5 | import functools 6 | 7 | class RotationTransformer: 8 | valid_reps = [ 9 | 'axis_angle', 10 | 'euler_angles', 11 | 'quaternion', # (w, x, y, z) convention 12 | 'rotation_6d', 13 | 'matrix' 14 | ] 15 | 16 | def __init__(self, 17 | from_rep='axis_angle', 18 | to_rep='rotation_6d', 19 | from_convention=None, 20 | to_convention=None): 21 | """ 22 | Valid representations 23 | 24 | Always use matrix as intermediate representation. 
25 | """ 26 | assert from_rep != to_rep 27 | assert from_rep in self.valid_reps 28 | assert to_rep in self.valid_reps 29 | if from_rep == 'euler_angles': 30 | assert from_convention is not None 31 | if to_rep == 'euler_angles': 32 | assert to_convention is not None 33 | 34 | forward_funcs = list() 35 | inverse_funcs = list() 36 | 37 | if from_rep != 'matrix': 38 | funcs = [ 39 | getattr(pt, f'{from_rep}_to_matrix'), 40 | getattr(pt, f'matrix_to_{from_rep}') 41 | ] 42 | if from_convention is not None: 43 | funcs = [functools.partial(func, convention=from_convention) 44 | for func in funcs] 45 | forward_funcs.append(funcs[0]) 46 | inverse_funcs.append(funcs[1]) 47 | 48 | if to_rep != 'matrix': 49 | funcs = [ 50 | getattr(pt, f'matrix_to_{to_rep}'), 51 | getattr(pt, f'{to_rep}_to_matrix') 52 | ] 53 | if to_convention is not None: 54 | funcs = [functools.partial(func, convention=to_convention) 55 | for func in funcs] 56 | forward_funcs.append(funcs[0]) 57 | inverse_funcs.append(funcs[1]) 58 | 59 | inverse_funcs = inverse_funcs[::-1] 60 | 61 | self.forward_funcs = forward_funcs 62 | self.inverse_funcs = inverse_funcs 63 | 64 | @staticmethod 65 | def _apply_funcs(x: Union[np.ndarray, torch.Tensor], funcs: list) -> Union[np.ndarray, torch.Tensor]: 66 | x_ = x 67 | if isinstance(x, np.ndarray): 68 | x_ = torch.from_numpy(x) 69 | x_: torch.Tensor 70 | for func in funcs: 71 | x_ = func(x_) 72 | y = x_ 73 | if isinstance(x, np.ndarray): 74 | y = x_.numpy() 75 | return y 76 | 77 | def forward(self, x: Union[np.ndarray, torch.Tensor] 78 | ) -> Union[np.ndarray, torch.Tensor]: 79 | return self._apply_funcs(x, self.forward_funcs) 80 | 81 | def inverse(self, x: Union[np.ndarray, torch.Tensor] 82 | ) -> Union[np.ndarray, torch.Tensor]: 83 | return self._apply_funcs(x, self.inverse_funcs) 84 | 85 | 86 | def test(): 87 | tf = RotationTransformer() 88 | 89 | rotvec = np.random.uniform(-2*np.pi,2*np.pi,size=(1000,3)) 90 | rot6d = tf.forward(rotvec) 91 | new_rotvec = tf.inverse(rot6d) 92 | 93 | from scipy.spatial.transform import Rotation 94 | diff = Rotation.from_rotvec(rotvec) * Rotation.from_rotvec(new_rotvec).inv() 95 | dist = diff.magnitude() 96 | assert dist.max() < 1e-7 97 | 98 | tf = RotationTransformer('rotation_6d', 'matrix') 99 | rot6d_wrong = rot6d + np.random.normal(scale=0.1, size=rot6d.shape) 100 | mat = tf.forward(rot6d_wrong) 101 | mat_det = np.linalg.det(mat) 102 | assert np.allclose(mat_det, 1) 103 | # rotaiton_6d will be normalized to rotation matrix 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Coordinated Bimanual Manipulation Policies using State Diffusion and Inverse Dynamics Models 2 | 3 | 4 | 5 | [Haonan Chen](http://haonan16.github.io/)1, 6 | [Jiaming Xu](#)*1, 7 | [Lily Sheng](#)*1, 8 | [Tianchen Ji](https://tianchenji.github.io/)1 9 | [Shuijing Liu](https://shuijing725.github.io/)3 10 | [Yunzhu Li](https://yunzhuli.github.io/)2, 11 | [Katherine Driggs-Campbell](https://krdc.web.illinois.edu/)1 12 | 13 | 1University of Illinois, Urbana-Champaign, 14 | 2Columbia University, 15 | 3The University of Texas at Austin 16 | 17 | 18 | ### Environment Setup 19 | 20 | We recommend using [Mambaforge](https://gyithub.com/conda-forge/miniforge#mambaforge) over the standard Anaconda distribution for a faster installation process. Create your environment using: 21 | 22 | 1. 
Install the necessary dependencies: 23 | ```console 24 | sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf libglm-dev 25 | ``` 26 | 27 | 2. Clone the repository: 28 | ```console 29 | git clone https://github.com/haonan16/state_diff.git 30 | cd state_diff/ 31 | ``` 32 | 33 | 3. Update `mamba`, then create and activate the environment: 34 | 35 | Run the following commands: 36 | 37 | ```console 38 | mamba install mamba=1.5.1 -n base -c conda-forge 39 | mamba env create -f conda_environment.yaml 40 | mamba activate coord_bimanual 41 | ``` 42 | 43 | 44 | 45 | 46 | 47 | ## Data Preparation 48 | 49 | ```console 50 | mkdir -p data 51 | ``` 52 | 53 | Place the `pushl` dataset in the `data` folder; the expected directory structure is shown below. 54 | 55 | To obtain the dataset, download the zip file from the following link and unzip it: 56 | 57 | - [PushL Dataset](https://drive.google.com/file/d/1433HHxOH5nomDZ12aUZ4XZel4nVO_uwL/view?usp=sharing) 58 | 59 | You can automate this process by running the following commands: 60 | 61 | ```console 62 | gdown https://drive.google.com/uc?id=1433HHxOH5nomDZ12aUZ4XZel4nVO_uwL -O data/pushl_dataset.zip 63 | unzip data/pushl_dataset.zip -d data 64 | rm data/pushl_dataset.zip 65 | ``` 66 | 67 | 68 | Datasets from other simulation benchmarks can be found here: 69 | 70 | - [Simulation Benchmark Datasets from Diffusion Policy](https://github.com/real-stanford/diffusion_policy?tab=readme-ov-file) 71 | 72 | 73 | 74 | Once downloaded, extract the contents into the `data` folder (a quick loading check is sketched at the end of this README). 75 | 76 | Directory structure: 77 | 78 | ``` 79 | data/ 80 | └── pushl_dataset/ 81 | ``` 82 | 83 | 84 | 85 | 86 | 87 | ## Demo, Training and Eval 88 | 89 | 90 | Activate the conda environment and log in to [wandb](https://wandb.ai) (if you haven't already). 91 | ```console 92 | conda activate coord_bimanual 93 | wandb login 94 | ``` 95 | 96 | 97 | 98 | ### Collecting Human Demonstration Data 99 | 100 | Run the following script to collect human demonstrations: 101 | 102 | ```console 103 | python demo_push_data_collection.py 104 | ``` 105 | 106 | The agent position can be controlled with the mouse. The following keys can be used to control the environment: 107 | - `Space` - Pause and step forward (increments `plan_idx`) 108 | - `R` - Retry the current attempt 109 | - `Q` - Exit the script 110 | 111 | 112 | 113 | ### Training 114 | To launch training, run: 115 | 116 | 117 | ```console 118 | python train.py \ 119 | --config-name=train_diffusion_trajectory_unet_lowdim_workspace.yaml 120 | ``` 121 | 122 | 123 | ## Acknowledgement 124 | * Policy training implementation is adapted from [Diffusion Policy](https://github.com/real-stanford/diffusion_policy/tree/main).
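### Optional: Dataset Loading Check

The snippet below is a minimal sketch for verifying that the extracted dataset loads with the repo's low-dim dataset class (`state_diff/dataset/pusht_dataset.py`). The zarr filename is a placeholder, so point it at the file that actually ships inside `data/pushl_dataset/`, and treat the horizon/padding values as illustrative rather than the training configuration.

```python
# Minimal sanity check -- not part of the training pipeline.
# NOTE: 'data/pushl_dataset/pushl.zarr' is a placeholder path; use the
# zarr file that ships inside data/pushl_dataset/.
from state_diff.dataset.pusht_dataset import PushTLowdimDataset

dataset = PushTLowdimDataset(
    zarr_path='data/pushl_dataset/pushl.zarr',  # placeholder path
    horizon=16,      # illustrative values, not the training config
    pad_before=1,
    pad_after=7,
)
print('sequences:', len(dataset))

sample = dataset[0]
print('obs:', sample['obs'].shape)        # (horizon, D_o)
print('action:', sample['action'].shape)  # (horizon, D_a)

# The normalizer used during training can also be fit directly here.
normalizer = dataset.get_normalizer()
```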
125 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/assets.xml: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 7 | 64 | -------------------------------------------------------------------------------- /state_diff/dataset/pusht_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import numpy as np 4 | import copy 5 | from state_diff.common.pytorch_util import dict_apply 6 | from state_diff.common.replay_buffer import ReplayBuffer 7 | from state_diff.common.sampler import ( 8 | SequenceSampler, get_val_mask, downsample_mask) 9 | from state_diff.model.common.normalizer import LinearNormalizer 10 | from state_diff.dataset.base_dataset import BaseLowdimDataset 11 | 12 | class PushTLowdimDataset(BaseLowdimDataset): 13 | def __init__(self, 14 | zarr_path, 15 | horizon=1, 16 | pad_before=0, 17 | pad_after=0, 18 | obs_key='keypoint', 19 | state_key='state', 20 | action_key='action', 21 | seed=42, 22 | val_ratio=0.0, 23 | max_train_episodes=None, 24 | num_episodes=None, 25 | ): 26 | super().__init__() 27 | self.replay_buffer = ReplayBuffer.copy_from_path( 28 | zarr_path, keys=[obs_key, state_key, action_key], num_episodes=num_episodes) 29 | if ',' in action_key and 'action' in action_key: 30 | action_key = 'action' 31 | self.agent_len = 4 32 | else: 33 | self.agent_len = 2 34 | val_mask = get_val_mask( 35 | n_episodes=self.replay_buffer.n_episodes, 36 | val_ratio=val_ratio, 37 | seed=seed) 38 | train_mask = ~val_mask 39 | train_mask = downsample_mask( 40 | mask=train_mask, 41 | max_n=max_train_episodes, 42 | seed=seed) 43 | 44 | self.sampler = SequenceSampler( 45 | replay_buffer=self.replay_buffer, 46 | sequence_length=horizon, 47 | pad_before=pad_before, 48 | pad_after=pad_after, 49 | episode_mask=train_mask 50 | ) 51 | self.obs_key = obs_key 52 | self.state_key = state_key 53 | self.action_key = action_key 54 | self.train_mask = train_mask 55 | self.horizon = horizon 56 | self.pad_before = pad_before 57 | self.pad_after = pad_after 58 | 59 | def get_validation_dataset(self): 60 | val_set = copy.copy(self) 61 | val_set.sampler = SequenceSampler( 62 | replay_buffer=self.replay_buffer, 63 | sequence_length=self.horizon, 64 | pad_before=self.pad_before, 65 | pad_after=self.pad_after, 66 | episode_mask=~self.train_mask 67 | ) 68 | val_set.train_mask = ~self.train_mask 69 | return val_set 70 | 71 | def get_normalizer(self, mode='limits', **kwargs): 72 | data = self._sample_to_data(self.replay_buffer) 73 | normalizer = LinearNormalizer() 74 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs) 75 | return normalizer 76 | 77 | def get_all_actions(self) -> torch.Tensor: 78 | return torch.from_numpy(self.replay_buffer[self.action_key]) 79 | 80 | def __len__(self) -> int: 81 | return len(self.sampler) 82 | 83 | def _sample_to_data(self, sample): 84 | keypoint = sample[self.obs_key] 85 | state = sample[self.state_key] 86 | agent_pos = state[:,:self.agent_len] 87 | obs = np.concatenate([ 88 | keypoint.reshape(keypoint.shape[0], -1), 89 | agent_pos], axis=-1) 90 | 91 | data = { 92 | 'obs': obs, # T, D_o 93 | 'action': sample[self.action_key], # T, D_a 94 | } 95 | return data 96 | 97 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 98 | sample = self.sampler.sample_sequence(idx) 99 | data = self._sample_to_data(sample) 100 | 101 | torch_data = 
dict_apply(data, torch.from_numpy) 102 | return torch_data 103 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/third_party/franka/assets/chain0_overlay.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/utils/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import numpy as np 18 | try: 19 | import cElementTree as ET 20 | except ImportError: 21 | try: 22 | # Python 2.5 need to import a different module 23 | import xml.etree.cElementTree as ET 24 | except ImportError: 25 | exit_err("Failed to import cElementTree from any known place") 26 | 27 | CONFIG_XML_DATA = """ 28 | 29 | 30 | 31 | 32 | 33 | """ 34 | 35 | 36 | # Read config from root 37 | def read_config_from_node(root_node, parent_name, child_name, dtype=int): 38 | # find parent 39 | parent_node = root_node.find(parent_name) 40 | if parent_node == None: 41 | quit("Parent %s not found" % parent_name) 42 | 43 | # get child data 44 | child_data = parent_node.get(child_name) 45 | if child_data == None: 46 | quit("Child %s not found" % child_name) 47 | 48 | config_val = np.array(child_data.split(), dtype=dtype) 49 | return config_val 50 | 51 | 52 | # get config frlom file or string 53 | def get_config_root_node(config_file_name=None, config_file_data=None): 54 | try: 55 | # get root 56 | if config_file_data is None: 57 | config_file_content = open(config_file_name, "r") 58 | config = ET.parse(config_file_content) 59 | root_node = config.getroot() 60 | else: 61 | root_node = ET.fromstring(config_file_data) 62 | 63 | # get root data 64 | root_data = root_node.get('name') 65 | root_name = np.array(root_data.split(), dtype=str) 66 | except: 67 | quit("ERROR: Unable to process config file %s" % config_file_name) 68 | 69 | return root_node, root_name 70 | 71 | 72 | # Read config from config_file 73 | def read_config_from_xml(config_file_name, parent_name, child_name, dtype=int): 74 | root_node, root_name = get_config_root_node( 75 | config_file_name=config_file_name) 76 | return read_config_from_node(root_node, parent_name, child_name, dtype) 77 | 78 | 79 | # tests 80 | if __name__ == '__main__': 81 | print("Read config and parse -------------------------") 82 | root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA) 83 | print("Root:name \t", root_name) 84 | print("limit:low \t", 
read_config_from_node(root, "limits", "low", float)) 85 | print("limit:high \t", read_config_from_node(root, "limits", "high", float)) 86 | print("scale:joint \t", read_config_from_node(root, "scale", "joint", 87 | float)) 88 | print("data:type \t", read_config_from_node(root, "data", "type", str)) 89 | 90 | # read straight from xml (dumb the XML data as duh.xml for this test) 91 | root, root_name = get_config_root_node(config_file_name="duh.xml") 92 | print("Read from xml --------------------------------") 93 | print("limit:low \t", read_config_from_xml("duh.xml", "limits", "low", 94 | float)) 95 | print("limit:high \t", 96 | read_config_from_xml("duh.xml", "limits", "high", float)) 97 | print("scale:joint \t", 98 | read_config_from_xml("duh.xml", "scale", "joint", float)) 99 | print("data:type \t", read_config_from_xml("duh.xml", "data", "type", str)) 100 | -------------------------------------------------------------------------------- /state_diff/dataset/pusht_traj_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import numpy as np 4 | import copy 5 | from state_diff.common.pytorch_util import dict_apply 6 | from state_diff.common.replay_buffer import ReplayBuffer 7 | from state_diff.common.sampler import ( 8 | SequenceSampler, get_val_mask, downsample_mask) 9 | from state_diff.model.common.normalizer import LinearNormalizer 10 | from state_diff.dataset.base_dataset import BaseLowdimDataset 11 | 12 | class PushTTrajLowdimDataset(BaseLowdimDataset): 13 | def __init__(self, 14 | zarr_path, 15 | horizon=1, 16 | pad_before=0, 17 | pad_after=0, 18 | obs_key='keypoint', 19 | state_key='state', 20 | action_key='action', 21 | seed=42, 22 | val_ratio=0.0, 23 | max_train_episodes=None, 24 | num_episodes=None, 25 | **kwargs 26 | ): 27 | super().__init__() 28 | self.replay_buffer = ReplayBuffer.copy_from_path( 29 | zarr_path, keys=[obs_key, state_key, action_key], num_episodes=num_episodes) 30 | 31 | if ',' in action_key and 'action' in action_key: 32 | action_key = 'action' 33 | self.num_agents = kwargs.get('num_agents', 1) 34 | self.agent_len = kwargs.get('action_dim', 2) * self.num_agents 35 | # self.agent_len = kwargs.get('action_dim', 2) * self.num_agents need to debug here; kwargs.get('action_dim', 2) always returns 2 36 | 37 | val_mask = get_val_mask( 38 | n_episodes=self.replay_buffer.n_episodes, 39 | val_ratio=val_ratio, 40 | seed=seed) 41 | train_mask = ~val_mask 42 | train_mask = downsample_mask( 43 | mask=train_mask, 44 | max_n=max_train_episodes, 45 | seed=seed) 46 | 47 | self.sampler = SequenceSampler( 48 | replay_buffer=self.replay_buffer, 49 | sequence_length=horizon, 50 | pad_before=pad_before, 51 | pad_after=pad_after, 52 | episode_mask=train_mask 53 | ) 54 | self.obs_key = obs_key 55 | self.state_key = state_key 56 | self.action_key = action_key 57 | self.train_mask = train_mask 58 | self.horizon = horizon 59 | self.pad_before = pad_before 60 | self.pad_after = pad_after 61 | 62 | def get_validation_dataset(self): 63 | val_set = copy.copy(self) 64 | val_set.sampler = SequenceSampler( 65 | replay_buffer=self.replay_buffer, 66 | sequence_length=self.horizon, 67 | pad_before=self.pad_before, 68 | pad_after=self.pad_after, 69 | episode_mask=~self.train_mask 70 | ) 71 | val_set.train_mask = ~self.train_mask 72 | return val_set 73 | 74 | def get_normalizer(self, mode='limits', **kwargs): 75 | data = self._sample_to_data(self.replay_buffer) 76 | normalizer = LinearNormalizer() 77 | 
normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs) 78 | return normalizer 79 | 80 | def get_all_actions(self) -> torch.Tensor: 81 | return torch.from_numpy(self.replay_buffer[self.action_key]) 82 | 83 | 84 | def __len__(self) -> int: 85 | return len(self.sampler) 86 | 87 | def _sample_to_data(self, sample): 88 | keypoint = sample[self.obs_key] 89 | state = sample[self.state_key] 90 | agent_pos = state[:,:self.agent_len] 91 | obs = np.concatenate([ 92 | keypoint.reshape(keypoint.shape[0], -1), 93 | agent_pos], axis=-1) 94 | 95 | data = { 96 | 'obs': obs, # T, D_o 97 | 'action': sample[self.action_key], # T, D_a 98 | } 99 | return data 100 | 101 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 102 | sample = self.sampler.sample_sequence(idx) 103 | data = self._sample_to_data(sample) 104 | 105 | torch_data = dict_apply(data, torch.from_numpy) 106 | return torch_data 107 | -------------------------------------------------------------------------------- /state_diff/env/kitchen/relay_policy_learning/adept_envs/adept_envs/simulation/module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2020 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Module for caching Python modules related to simulation.""" 18 | 19 | import sys 20 | 21 | _MUJOCO_PY_MODULE = None 22 | 23 | _DM_MUJOCO_MODULE = None 24 | _DM_VIEWER_MODULE = None 25 | _DM_RENDER_MODULE = None 26 | 27 | _GLFW_MODULE = None 28 | 29 | 30 | def get_mujoco_py(): 31 | """Returns the mujoco_py module.""" 32 | global _MUJOCO_PY_MODULE 33 | if _MUJOCO_PY_MODULE: 34 | return _MUJOCO_PY_MODULE 35 | try: 36 | import mujoco_py 37 | # Override the warning function. 38 | from mujoco_py.builder import cymj 39 | cymj.set_warning_callback(_mj_warning_fn) 40 | except ImportError: 41 | print( 42 | 'Failed to import mujoco_py. Ensure that mujoco_py (using MuJoCo ' 43 | 'v1.50) is installed.', 44 | file=sys.stderr) 45 | sys.exit(1) 46 | _MUJOCO_PY_MODULE = mujoco_py 47 | return mujoco_py 48 | 49 | 50 | def get_mujoco_py_mjlib(): 51 | """Returns the mujoco_py mjlib module.""" 52 | 53 | class MjlibDelegate: 54 | """Wrapper that forwards mjlib calls.""" 55 | 56 | def __init__(self, lib): 57 | self._lib = lib 58 | 59 | def __getattr__(self, name: str): 60 | if name.startswith('mj'): 61 | return getattr(self._lib, '_' + name) 62 | raise AttributeError(name) 63 | 64 | return MjlibDelegate(get_mujoco_py().cymj) 65 | 66 | 67 | def get_dm_mujoco(): 68 | """Returns the DM Control mujoco module.""" 69 | global _DM_MUJOCO_MODULE 70 | if _DM_MUJOCO_MODULE: 71 | return _DM_MUJOCO_MODULE 72 | try: 73 | from dm_control import mujoco 74 | except ImportError: 75 | print( 76 | 'Failed to import dm_control.mujoco. 
Ensure that dm_control (using ' 77 | 'MuJoCo v2.00) is installed.', 78 | file=sys.stderr) 79 | sys.exit(1) 80 | _DM_MUJOCO_MODULE = mujoco 81 | return mujoco 82 | 83 | 84 | def get_dm_viewer(): 85 | """Returns the DM Control viewer module.""" 86 | global _DM_VIEWER_MODULE 87 | if _DM_VIEWER_MODULE: 88 | return _DM_VIEWER_MODULE 89 | try: 90 | from dm_control import viewer 91 | except ImportError: 92 | print( 93 | 'Failed to import dm_control.viewer. Ensure that dm_control (using ' 94 | 'MuJoCo v2.00) is installed.', 95 | file=sys.stderr) 96 | sys.exit(1) 97 | _DM_VIEWER_MODULE = viewer 98 | return viewer 99 | 100 | 101 | def get_dm_render(): 102 | """Returns the DM Control render module.""" 103 | global _DM_RENDER_MODULE 104 | if _DM_RENDER_MODULE: 105 | return _DM_RENDER_MODULE 106 | try: 107 | try: 108 | from dm_control import _render 109 | render = _render 110 | except ImportError: 111 | print('Warning: DM Control is out of date.') 112 | from dm_control import render 113 | except ImportError: 114 | print( 115 | 'Failed to import dm_control.render. Ensure that dm_control (using ' 116 | 'MuJoCo v2.00) is installed.', 117 | file=sys.stderr) 118 | sys.exit(1) 119 | _DM_RENDER_MODULE = render 120 | return render 121 | 122 | 123 | def _mj_warning_fn(warn_data: bytes): 124 | """Warning function override for mujoco_py.""" 125 | print('WARNING: Mujoco simulation is unstable (has NaNs): {}'.format( 126 | warn_data.decode())) 127 | --------------------------------------------------------------------------------
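Each accessor in this module imports its backend once and caches the module object, so downstream code (for example, the kitchen environments) can request the simulation bindings wherever needed without paying repeated import costs. Below is a minimal usage sketch; the import path follows the `adept_envs/adept_envs/simulation/module.py` layout shown above, and the model XML path is a placeholder.

```python
# Minimal sketch of using the cached simulation-module accessors.
# Assumes dm_control is installed; the XML path below is a placeholder.
from adept_envs.simulation import module

# Repeated calls return the same cached module object, so the heavy
# import work (and any warning-callback setup) happens only once.
dm_mujoco = module.get_dm_mujoco()
assert dm_mujoco is module.get_dm_mujoco()

# The cached object is the regular dm_control.mujoco package.
physics = dm_mujoco.Physics.from_xml_path('path/to/model.xml')  # placeholder path
print('nq =', physics.model.nq)  # number of generalized coordinates
```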