├── spirl
├── __init__.py
├── data
│ ├── __init__.py
│ ├── block_stacking
│ │ ├── __init__.py
│ │ ├── src
│ │ │ ├── __init__.py
│ │ │ ├── robosuite
│ │ │ │ ├── scripts
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── README.md
│ │ │ │ │ ├── browse_mjcf_model.py
│ │ │ │ │ ├── compile_mjcf_model.py
│ │ │ │ │ ├── demo_gripper_selection.py
│ │ │ │ │ ├── demo_baxter_ik_control.py
│ │ │ │ │ ├── demo_video_recording.py
│ │ │ │ │ ├── demo_pygame_renderer.py
│ │ │ │ │ ├── demo_learning_curriculum.py
│ │ │ │ │ └── demo_collect_and_playback_data.py
│ │ │ │ ├── environments
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── devices
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── device.py
│ │ │ │ │ └── README.md
│ │ │ │ ├── models
│ │ │ │ │ ├── robots
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── sawyer_robot.py
│ │ │ │ │ │ ├── baxter_robot.py
│ │ │ │ │ │ └── robot.py
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── assets
│ │ │ │ │ │ ├── textures
│ │ │ │ │ │ │ ├── can.png
│ │ │ │ │ │ │ ├── clay.png
│ │ │ │ │ │ │ ├── bread.png
│ │ │ │ │ │ │ ├── cereal.png
│ │ │ │ │ │ │ ├── glass.png
│ │ │ │ │ │ │ ├── lemon.png
│ │ │ │ │ │ │ ├── metal.png
│ │ │ │ │ │ │ ├── ceramic.png
│ │ │ │ │ │ │ ├── dark-wood.png
│ │ │ │ │ │ │ └── light-wood.png
│ │ │ │ │ │ ├── objects
│ │ │ │ │ │ │ ├── meshes
│ │ │ │ │ │ │ │ ├── can.stl
│ │ │ │ │ │ │ │ ├── bread.stl
│ │ │ │ │ │ │ │ ├── lemon.stl
│ │ │ │ │ │ │ │ ├── milk.stl
│ │ │ │ │ │ │ │ ├── bottle.stl
│ │ │ │ │ │ │ │ ├── cereal.stl
│ │ │ │ │ │ │ │ └── handles.stl
│ │ │ │ │ │ │ ├── can-visual.xml
│ │ │ │ │ │ │ ├── milk-visual.xml
│ │ │ │ │ │ │ ├── cereal-visual.xml
│ │ │ │ │ │ │ ├── bread-visual.xml
│ │ │ │ │ │ │ ├── bottle.xml
│ │ │ │ │ │ │ ├── can.xml
│ │ │ │ │ │ │ ├── plate-with-hole.xml
│ │ │ │ │ │ │ ├── bread.xml
│ │ │ │ │ │ │ ├── lemon.xml
│ │ │ │ │ │ │ ├── milk.xml
│ │ │ │ │ │ │ ├── cereal.xml
│ │ │ │ │ │ │ ├── square-nut.xml
│ │ │ │ │ │ │ └── round-nut.xml
│ │ │ │ │ │ ├── grippers
│ │ │ │ │ │ │ └── meshes
│ │ │ │ │ │ │ │ ├── deprecated
│ │ │ │ │ │ │ │ │ ├── limiter.STL
│ │ │ │ │ │ │ │ │ ├── paddle_tip.STL
│ │ │ │ │ │ │ │ │ ├── extended_wide.STL
│ │ │ │ │ │ │ │ │ ├── standard_wide.STL
│ │ │ │ │ │ │ │ │ ├── basic_hard_tip.STL
│ │ │ │ │ │ │ │ │ ├── basic_soft_tip.STL
│ │ │ │ │ │ │ │ │ ├── extended_narrow.STL
│ │ │ │ │ │ │ │ │ └── electric_gripper_w_fingers.STL
│ │ │ │ │ │ │ │ ├── pr2_gripper
│ │ │ │ │ │ │ │ │ ├── l_finger.stl
│ │ │ │ │ │ │ │ │ ├── gripper_palm.stl
│ │ │ │ │ │ │ │ │ └── l_finger_tip.stl
│ │ │ │ │ │ │ │ ├── robotiq_s_gripper
│ │ │ │ │ │ │ │ │ ├── palm.STL
│ │ │ │ │ │ │ │ │ ├── link_0.STL
│ │ │ │ │ │ │ │ │ ├── link_1.STL
│ │ │ │ │ │ │ │ │ ├── link_2.STL
│ │ │ │ │ │ │ │ │ └── link_3.STL
│ │ │ │ │ │ │ │ ├── robotiq_gripper
│ │ │ │ │ │ │ │ │ ├── adapter_plate.stl
│ │ │ │ │ │ │ │ │ ├── adapter_plate.dae.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_base.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_0_L.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_0_R.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_1_L.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_1_R.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_2_L.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_2_R.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_3_L.stl
│ │ │ │ │ │ │ │ │ ├── robotiq_85_gripper_joint_3_R.stl
│ │ │ │ │ │ │ │ │ └── robotiq_85_gripper_adapter_plate.stl
│ │ │ │ │ │ │ │ └── two_finger_gripper
│ │ │ │ │ │ │ │ │ ├── half_round_tip.STL
│ │ │ │ │ │ │ │ │ ├── standard_narrow.STL
│ │ │ │ │ │ │ │ │ └── electric_gripper_base.STL
│ │ │ │ │ │ ├── base.xml
│ │ │ │ │ │ └── arenas
│ │ │ │ │ │ │ ├── empty_arena.xml
│ │ │ │ │ │ │ ├── table_arena.xml
│ │ │ │ │ │ │ └── pegs_arena.xml
│ │ │ │ │ ├── arenas
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── empty_arena.py
│ │ │ │ │ │ ├── arena.py
│ │ │ │ │ │ ├── bins_arena.py
│ │ │ │ │ │ ├── pegs_arena.py
│ │ │ │ │ │ └── table_arena.py
│ │ │ │ │ ├── tasks
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── task.py
│ │ │ │ │ │ ├── table_top_task.py
│ │ │ │ │ │ └── nut_assembly_task.py
│ │ │ │ │ ├── world.py
│ │ │ │ │ ├── grippers
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── pushing_gripper.py
│ │ │ │ │ │ ├── gripper_factory.py
│ │ │ │ │ │ ├── pr2_gripper.py
│ │ │ │ │ │ ├── robotiq_three_finger_gripper.py
│ │ │ │ │ │ ├── robotiq_gripper.py
│ │ │ │ │ │ └── gripper.py
│ │ │ │ │ ├── objects
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── xml_objects.py
│ │ │ │ │ └── README.md
│ │ │ │ ├── utils
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── errors.py
│ │ │ │ │ └── mujoco_py_renderer.py
│ │ │ │ ├── controllers
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── controller.py
│ │ │ │ ├── wrappers
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── gym_wrapper.py
│ │ │ │ │ ├── wrapper.py
│ │ │ │ │ └── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ └── demo.py
│ │ │ ├── utils
│ │ │ │ ├── gripper_only_robot.py
│ │ │ │ ├── wide_range_gripper.py
│ │ │ │ └── numbered_box_object.py
│ │ │ ├── demo_gen
│ │ │ │ └── block_stacking_demo_agent.py
│ │ │ ├── block_stacking_data_loader.py
│ │ │ ├── block_stacking_logger.py
│ │ │ └── block_task_generator.py
│ │ └── assets
│ │ │ ├── textures
│ │ │ │ ├── obj0.png
│ │ │ │ ├── obj1.png
│ │ │ │ ├── obj2.png
│ │ │ │ ├── obj3.png
│ │ │ │ ├── obj4.png
│ │ │ │ ├── obj5.png
│ │ │ │ ├── obj6.png
│ │ │ │ ├── obj7.png
│ │ │ │ ├── obj8.png
│ │ │ │ └── obj9.png
│ │ │ └── gripper_only_robot.xml
│ ├── peg_in_hole
│ │ └── src
│ │ │ └── generate_peg_in_hole_hdf5.py
│ └── README.md
├── utils
│ ├── __init__.py
│ ├── transformations.py
│ ├── data_utils.py
│ ├── video_utils.py
│ └── dist_utils.py
├── components
│ ├── __init__.py
│ ├── params.py
│ └── trainer_base.py
├── models
│ └── __init__.py
├── modules
│ ├── __init__.py
│ └── losses.py
├── configs
│ ├── rl
│ │ ├── maze
│ │ │ ├── SAC
│ │ │ │ └── conf.py
│ │ │ ├── prior_initialized
│ │ │ │ ├── bc_finetune
│ │ │ │ │ └── conf.py
│ │ │ │ ├── flat_prior
│ │ │ │ │ └── conf.py
│ │ │ │ └── base_conf.py
│ │ │ └── base_conf.py
│ │ ├── kitchen
│ │ │ ├── SAC
│ │ │ │ └── conf.py
│ │ │ ├── prior_initialized
│ │ │ │ ├── bc_finetune
│ │ │ │ │ └── conf.py
│ │ │ │ ├── flat_prior
│ │ │ │ │ └── conf.py
│ │ │ │ └── base_conf.py
│ │ │ └── base_conf.py
│ │ ├── peg_in_hole
│ │ │ ├── SAC
│ │ │ │ └── conf.py
│ │ │ └── base_conf.py
│ │ └── block_stacking
│ │ │ ├── SAC
│ │ │ │ └── conf.py
│ │ │ ├── prior_initialized
│ │ │ │ ├── bc_finetune
│ │ │ │ │ └── conf.py
│ │ │ │ ├── flat_prior
│ │ │ │ │ └── conf.py
│ │ │ │ └── base_conf.py
│ │ │ └── base_conf.py
│ ├── hrl
│ │ ├── maze
│ │ │ ├── no_prior
│ │ │ │ └── conf.py
│ │ │ └── spirl
│ │ │ │ └── conf.py
│ │ ├── kitchen
│ │ │ ├── no_prior
│ │ │ │ └── conf.py
│ │ │ ├── spirl
│ │ │ │ └── conf.py
│ │ │ └── spirl_cl
│ │ │ │ └── conf.py
│ │ ├── block_stacking
│ │ │ ├── no_prior
│ │ │ │ └── conf.py
│ │ │ └── spirl
│ │ │ │ └── conf.py
│ │ └── peg_in_hole
│ │ │ └── spirl
│ │ │ │ └── conf.py
│ ├── default_data_configs
│ │ ├── maze.py
│ │ ├── block_stacking.py
│ │ ├── kitchen.py
│ │ └── peg_in_hole.py
│ ├── skill_prior_learning
│ │ ├── kitchen
│ │ │ ├── flat
│ │ │ │ └── conf.py
│ │ │ ├── hierarchical_cl
│ │ │ │ ├── conf.py
│ │ │ │ └── README.md
│ │ │ └── hierarchical
│ │ │ │ └── conf.py
│ │ ├── maze
│ │ │ ├── flat
│ │ │ │ └── conf.py
│ │ │ ├── hierarchical
│ │ │ │ └── conf.py
│ │ │ └── hierarchical_cl
│ │ │ │ └── conf.py
│ │ ├── block_stacking
│ │ │ ├── flat
│ │ │ │ └── conf.py
│ │ │ ├── hierarchical_cl
│ │ │ │ ├── conf.py
│ │ │ │ └── README.md
│ │ │ └── hierarchical
│ │ │ │ └── conf.py
│ │ └── peg_in_hole
│ │ │ └── hierarchical_cl
│ │ │ │ └── conf.py
│ └── data_collect
│ │ └── block_stacking
│ │ │ └── conf.py
├── rl
│ ├── utils
│ │ ├── reward_fcns.py
│ │ ├── robosuite_utils.py
│ │ ├── rollout_utils.py
│ │ └── wandb.py
│ ├── policies
│ │ └── basic_policies.py
│ ├── envs
│ │ ├── maze.py
│ │ └── kitchen.py
│ ├── components
│ │ └── params.py
│ └── agents
│ │ └── skill_space_agent.py
└── plotter
│ └── plot_trajectories.py
├── .gitignore
├── README.md
├── setup.py
└── requirements.txt
/spirl/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 | __pycache__
4 | *.egg-info
5 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/hrl/maze/no_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.maze.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/rl/peg_in_hole/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.peg_in_hole.base_conf import *
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # learning_impedance_actions
2 |
3 | The peg-in-hole dataset will be available soon.
4 |
--------------------------------------------------------------------------------
/spirl/configs/hrl/kitchen/no_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.kitchen.base_conf import *
2 |
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/hrl/block_stacking/no_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.block_stacking.base_conf import *
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(name='spirl', version='0.0.1', packages=['spirl'])
4 |
--------------------------------------------------------------------------------
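
With this minimal setup script, the package is typically installed in editable mode from the repository root (assuming pip is available): `pip3 install -e .`, after which `import spirl` resolves from anywhere.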
/spirl/data/block_stacking/src/robosuite/environments/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import REGISTERED_ENVS, MujocoEnv
2 |
3 | ALL_ENVS = REGISTERED_ENVS.keys()
4 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/devices/__init__.py:
--------------------------------------------------------------------------------
1 | from .device import Device
2 | from .keyboard import Keyboard
3 | from .spacemouse import SpaceMouse
4 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/robots/__init__.py:
--------------------------------------------------------------------------------
1 | from .robot import Robot
2 | from .sawyer_robot import Sawyer
3 | from .baxter_robot import Baxter
4 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj0.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj1.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj2.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj3.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj4.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj5.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj6.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj7.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj8.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj9.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .world import MujocoWorldBase
3 |
4 | assets_root = os.path.join(os.path.dirname(__file__), "assets")
5 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .errors import robosuiteError, XMLError, SimulationError, RandomizationError
2 | from .mujoco_py_renderer import MujocoPyRenderer
3 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/can.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/clay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/clay.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/controllers/__init__.py:
--------------------------------------------------------------------------------
1 | from .controller import Controller
2 | from .baxter_ik_controller import BaxterIKController
3 | from .sawyer_ik_controller import SawyerIKController
4 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/bread.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/bread.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/cereal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/cereal.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/glass.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/glass.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/lemon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/lemon.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/metal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/metal.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/can.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/can.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/ceramic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/ceramic.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/dark-wood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/dark-wood.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bread.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bread.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/lemon.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/lemon.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/milk.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/milk.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/light-wood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/light-wood.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bottle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bottle.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/cereal.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/cereal.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/handles.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/handles.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/__init__.py:
--------------------------------------------------------------------------------
1 | from .arena import Arena
2 | from .bins_arena import BinsArena
3 | from .empty_arena import EmptyArena
4 | from .pegs_arena import PegsArena
5 | from .table_arena import TableArena
6 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/README.md:
--------------------------------------------------------------------------------
1 | Scripts
2 | =======
3 |
4 | This folder contains a collection of scripts that demonstrate the functionality of robosuite. See the docstrings in each script for detailed instructions.
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/limiter.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/limiter.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/paddle_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/paddle_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_wide.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_wide.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/standard_wide.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/standard_wide.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/gripper_palm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/gripper_palm.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger_tip.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger_tip.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_hard_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_hard_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_soft_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_soft_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_narrow.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_narrow.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.dae.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.dae.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/half_round_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/half_round_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/standard_narrow.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/standard_narrow.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/electric_gripper_w_fingers.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/electric_gripper_w_fingers.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_base.stl
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/prior_initialized/bc_finetune/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import PriorInitializedPolicy
3 |
4 | agent_config.policy = PriorInitializedPolicy
5 | configuration.agent = SACAgent
6 |
7 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/electric_gripper_base.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/electric_gripper_base.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_adapter_plate.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_adapter_plate.stl
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/prior_initialized/bc_finetune/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACPriorInitializedPolicy
3 |
4 | # update agent
5 | agent_config.policy = ACPriorInitializedPolicy
6 | configuration.agent = SACAgent
7 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.wrappers.ik_wrapper import IKWrapper
2 | 
3 | try:
4 |     from spirl.data.block_stacking.src.robosuite.wrappers.gym_wrapper import GymWrapper
5 | except ImportError:
6 |     print("Warning: make sure gym is installed if you want to use the GymWrapper.")
7 |
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/prior_initialized/bc_finetune/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACPriorInitializedPolicy
3 | from spirl.data.maze.src.maze_agents import MazeSACAgent
4 |
5 | agent_config.policy = ACPriorInitializedPolicy
6 | configuration.agent = MazeSACAgent
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # core
2 | numpy
3 | matplotlib
4 | pillow
5 | h5py==2.10.0
6 | scikit-image
7 | funcsigs
8 | opencv-python
9 | moviepy
10 | torch==1.3.1
11 | torchvision==0.4.2
12 | tensorboard==2.1.1
13 | tensorboardX==2.0
14 | gym==0.15.4
15 | pandas
16 |
17 | # RL
18 | wandb
19 | mpi4py
20 |
21 | # Block Stacking
22 | # mujoco_py==2.0.2.9
23 |
--------------------------------------------------------------------------------
/spirl/rl/utils/reward_fcns.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def euclid_distance(state1, state2):
5 | return np.float64(np.linalg.norm(state1 - state2)) # rewards need to be np arrays for replay buffer
6 |
7 |
8 | def sparse_threshold(state1, state2, thresh):
9 | return np.float64(euclid_distance(state1, state2) < thresh)
10 |
11 |
12 |
--------------------------------------------------------------------------------
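
For reference, a minimal usage sketch of these reward functions; the states below are made-up numpy arrays, not values from the repo:

    import numpy as np

    from spirl.rl.utils.reward_fcns import euclid_distance, sparse_threshold

    state = np.array([0.0, 0.0])
    goal = np.array([3.0, 4.0])
    print(euclid_distance(state, goal))        # 5.0 -- dense distance-based signal
    print(sparse_threshold(state, goal, 6.0))  # 1.0 -- binary reward once within thresh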
/spirl/data/block_stacking/src/robosuite/models/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import Task
2 |
3 | from .placement_sampler import (
4 | ObjectPositionSampler,
5 | UniformRandomSampler,
6 | UniformRandomPegsSampler,
7 | )
8 |
9 | from .pick_place_task import PickPlaceTask
10 | from .nut_assembly_task import NutAssemblyTask
11 | from .table_top_task import TableTopTask
12 |
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/maze.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | from spirl.components.data_loader import GlobalSplitVideoDataset
3 |
4 |
5 | data_spec = AttrDict(
6 | dataset_class=GlobalSplitVideoDataset,
7 | n_actions=2,
8 | state_dim=4,
9 | split=AttrDict(train=0.9, val=0.1, test=0.0),
10 | res=32,
11 | crop_rand_subseq=True,
12 | )
13 | data_spec.max_seq_len = 300
14 |
--------------------------------------------------------------------------------
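
The crop_rand_subseq flag indicates that training windows are cut at random offsets from each trajectory (bounded by max_seq_len). Schematically, and purely as an illustration of the cropping semantics rather than the loader's actual code:

    import numpy as np

    seq = np.arange(300)                     # one trajectory of max_seq_len steps
    subseq_len = 10                          # set by the model config at train time
    start = np.random.randint(0, len(seq) - subseq_len + 1)
    window = seq[start:start + subseq_len]   # random crop used as one training sample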
/spirl/data/block_stacking/src/robosuite/models/world.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML
2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
3 |
4 |
5 | class MujocoWorldBase(MujocoXML):
6 | """Base class to inherit all mujoco worlds from."""
7 |
8 | def __init__(self):
9 | super().__init__(xml_path_completion("base.xml"))
10 |
--------------------------------------------------------------------------------
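
A minimal sketch of the intended build-a-world flow, assuming the `merge` and `get_model` helpers provided by the vendored MujocoXML base class:

    from spirl.data.block_stacking.src.robosuite.models import MujocoWorldBase
    from spirl.data.block_stacking.src.robosuite.models.arenas import TableArena

    world = MujocoWorldBase()                   # starts from assets/base.xml
    world.merge(TableArena())                   # splice the table arena into the world
    model = world.get_model(mode="mujoco_py")   # compile for mujoco_py simulation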
/spirl/configs/default_data_configs/block_stacking.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | from spirl.components.data_loader import GlobalSplitVideoDataset
3 |
4 | data_spec = AttrDict(
5 | dataset_class=GlobalSplitVideoDataset,
6 | n_actions=3,
7 | state_dim=41,
8 | split=AttrDict(train=0.95, val=0.05, test=0.0),
9 | res=32,
10 | crop_rand_subseq=True,
11 | )
12 | data_spec.max_seq_len = 150
13 |
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/kitchen.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | from spirl.data.kitchen.src.kitchen_data_loader import D4RLSequenceSplitDataset
3 |
4 |
5 | data_spec = AttrDict(
6 | dataset_class=D4RLSequenceSplitDataset,
7 | n_actions=9,
8 | state_dim=60,
9 | env_name="kitchen-mixed-v0",
10 | res=128,
11 | crop_rand_subseq=True,
12 | )
13 | data_spec.max_seq_len = 280
14 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/__init__.py:
--------------------------------------------------------------------------------
1 | from .gripper import Gripper
2 | from .gripper_factory import gripper_factory
3 | from .two_finger_gripper import TwoFingerGripper, LeftTwoFingerGripper
4 | from .gripper_tester import GripperTester
5 | from .pr2_gripper import PR2Gripper
6 | from .pushing_gripper import PushingGripper
7 | from .robotiq_gripper import RobotiqGripper
8 | from .robotiq_three_finger_gripper import RobotiqThreeFingerGripper
9 |
--------------------------------------------------------------------------------
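
A small usage sketch of the factory, assuming it dispatches on the class names imported above (the exact registered strings live in gripper_factory.py):

    from spirl.data.block_stacking.src.robosuite.models.grippers import gripper_factory

    gripper = gripper_factory("TwoFingerGripper")  # name assumed; see gripper_factory.py
    print(gripper.dof)                             # degrees of freedom exposed to control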
/spirl/configs/rl/kitchen/prior_initialized/flat_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 |
5 | agent_config.update(AttrDict(
6 | td_schedule_params=AttrDict(p=1.),
7 | ))
8 |
9 | agent_config.policy = LearnedPriorAugmentedPIPolicy
10 | configuration.agent = ActionPriorSACAgent
11 |
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/prior_initialized/flat_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.data.maze.src.maze_agents import MazeActionPriorSACAgent
4 |
5 | agent_config.update(AttrDict(
6 | td_schedule_params=AttrDict(p=1.),
7 | ))
8 |
9 | agent_config.policy = ACLearnedPriorAugmentedPIPolicy
10 | configuration.agent = MazeActionPriorSACAgent
11 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/empty_arena.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena
2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
3 |
4 |
5 | class EmptyArena(Arena):
6 | """Empty workspace."""
7 |
8 | def __init__(self):
9 | super().__init__(xml_path_completion("arenas/empty_arena.xml"))
10 | self.floor = self.worldbody.find("./geom[@name='floor']")
11 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/base.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this text dump; see the file in the repository.)
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/prior_initialized/base_conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.base_conf import *
2 | from spirl.models.bc_mdl import BCMdl
3 |
4 | policy_params.update(AttrDict(
5 | prior_model=BCMdl,
6 | prior_model_params=AttrDict(state_dim=data_spec.state_dim,
7 | action_dim=data_spec.n_actions,
8 | ),
9 | prior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior_learning/kitchen/flat"),
10 | ))
11 |
12 |
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/prior_initialized/flat_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 |
5 | agent_config.update(AttrDict(
6 | td_schedule_params=AttrDict(p=1.),
7 | ))
8 |
9 | # update agent
10 | agent_config.policy = ACLearnedPriorAugmentedPIPolicy
11 | configuration.agent = ActionPriorSACAgent
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/peg_in_hole.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | #from spirl.data.kitchen.src.kitchen_data_loader import D4RLSequenceSplitDataset
3 | from spirl.data.peg_in_hole.src.peg_in_hole_data_loader import PegInHoleSequenceSplitDataset
4 |
5 | data_spec = AttrDict(
6 | dataset_class=PegInHoleSequenceSplitDataset,
7 | n_actions=8,
8 | state_dim=19,
9 | env_name="peg-in-hole-v0",
10 | res=128,
11 | crop_rand_subseq=True,
12 | )
13 | data_spec.max_seq_len = 280
14 |
--------------------------------------------------------------------------------
/spirl/utils/transformations.py:
--------------------------------------------------------------------------------
1 | from scipy.spatial.transform import Rotation
2 | 
3 | 
4 | if __name__ == "__main__":
5 |     # Create a rotation object from Euler angles specifying axes of rotation
6 |     rot = Rotation.from_euler('xyz', [0.35, 0, 0.23], degrees=False)
7 | 
8 |     # Convert to quaternion and print
9 |     rot_quat = rot.as_quat()
10 |     print(rot_quat)
11 | 
12 |     # print(rot.as_euler('xyz', degrees=True))
13 | 
14 |     # Round-trip: reconstruct the rotation from the quaternion
15 |     rot = Rotation.from_quat(rot_quat)
16 | 
17 |     # Convert the rotation back to Euler angles given the axes of rotation
18 |     print(rot.as_euler('xyz', degrees=False))
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/pushing_gripper.py:
--------------------------------------------------------------------------------
1 | """
2 | A version of TwoFingerGripper but always closed.
3 | """
4 | import numpy as np
5 | from spirl.data.block_stacking.src.robosuite.models.grippers.two_finger_gripper import TwoFingerGripper
6 |
7 |
8 | class PushingGripper(TwoFingerGripper):
9 | """
10 | Same as TwoFingerGripper, but always closed
11 | """
12 |
13 | def format_action(self, action):
14 | return np.array([1, -1])
15 |
16 | @property
17 | def dof(self):
18 | return 1
19 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/utils/errors.py:
--------------------------------------------------------------------------------
1 | class robosuiteError(Exception):
2 | """Base class for exceptions in robosuite."""
3 |
4 | pass
5 |
6 |
7 | class XMLError(robosuiteError):
8 | """Exception raised for errors related to xml."""
9 |
10 | pass
11 |
12 |
13 | class SimulationError(robosuiteError):
14 | """Exception raised for errors during runtime."""
15 |
16 | pass
17 |
18 |
19 | class RandomizationError(robosuiteError):
20 | """Exception raised for really really bad RNG."""
21 |
22 | pass
23 |
--------------------------------------------------------------------------------
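
A short sketch of how this hierarchy is meant to be used: raise the narrow subclass at the failure site and catch the shared robosuiteError base at the call site (place_object below is illustrative, not a repo function):

    from spirl.data.block_stacking.src.robosuite.utils.errors import (
        RandomizationError,
        robosuiteError,
    )

    def place_object():
        # illustrative placement routine that gives up after too many sampling attempts
        raise RandomizationError("no collision-free placement found")

    try:
        place_object()
    except robosuiteError as e:  # catches any robosuite-specific failure
        print("placement failed:", e)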
/spirl/data/block_stacking/src/robosuite/models/objects/__init__.py:
--------------------------------------------------------------------------------
1 | from .objects import MujocoObject, MujocoXMLObject, MujocoGeneratedObject
2 |
3 | from .xml_objects import (
4 | BottleObject,
5 | CanObject,
6 | LemonObject,
7 | MilkObject,
8 | BreadObject,
9 | CerealObject,
10 | SquareNutObject,
11 | RoundNutObject,
12 | MilkVisualObject,
13 | BreadVisualObject,
14 | CerealVisualObject,
15 | CanVisualObject,
16 | PlateWithHoleObject,
17 | )
18 |
19 | from .generated_objects import (
20 | PotWithHandlesObject,
21 | BoxObject,
22 | CylinderObject,
23 | BallObject,
24 | CapsuleObject,
25 | )
26 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/devices/device.py:
--------------------------------------------------------------------------------
1 | import abc # for abstract base class definitions
2 |
3 |
4 | class Device(metaclass=abc.ABCMeta):
5 | """
6 | Base class for all robot controllers.
7 | Defines basic interface for all controllers to adhere to.
8 | """
9 |
10 | @abc.abstractmethod
11 | def start_control(self):
12 | """
13 | Method that should be called externally before controller can
14 | start receiving commands.
15 | """
16 | raise NotImplementedError
17 |
18 | @abc.abstractmethod
19 | def get_controller_state(self):
20 | """Returns the current state of the device, a dictionary of pos, orn, grasp, and reset."""
21 | raise NotImplementedError
22 |
--------------------------------------------------------------------------------
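
To make the interface contract concrete, a hypothetical no-op device might look as follows (DummyDevice and its state keys follow the docstring above; it is illustrative, not part of the repo):

    import numpy as np

    from spirl.data.block_stacking.src.robosuite.devices.device import Device

    class DummyDevice(Device):
        """Device stub that always reports a neutral state."""

        def start_control(self):
            pass  # nothing to initialize for this stub

        def get_controller_state(self):
            # pos/orn/grasp/reset keys as described in the abstract docstring
            return dict(pos=np.zeros(3),
                        orn=np.array([0.0, 0.0, 0.0, 1.0]),
                        grasp=0, reset=0)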
/spirl/configs/hrl/block_stacking/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.block_stacking.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 |
5 |
6 | # add prior model to policy
7 | hl_policy_params.update(AttrDict(
8 | prior_model=ll_agent_config.model,
9 | prior_model_params=ll_agent_config.model_params,
10 | prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = ACLearnedPriorAugmentedPIPolicy
13 |
14 | # update agent + set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 | td_schedule_params=AttrDict(p=5.),
18 | ))
19 |
20 |
--------------------------------------------------------------------------------
/spirl/configs/hrl/kitchen/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.kitchen.base_conf import *
2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 |
5 |
6 | # update policy to use prior model for computing divergence
7 | hl_policy_params.update(AttrDict(
8 | prior_model=ll_agent_config.model,
9 | prior_model_params=ll_agent_config.model_params,
10 | prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = LearnedPriorAugmentedPIPolicy
13 |
14 | # update agent, set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 | td_schedule_params=AttrDict(p=5.),
18 | ))
19 |
--------------------------------------------------------------------------------
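
The td_schedule_params above set the target divergence for the prior-regularized SAC agent: the regularization temperature is adapted so the policy's KL to the skill prior tracks the target p. A schematic dual-gradient sketch of that idea (not the repo's implementation; log alpha is parameterized to stay positive):

    import torch

    log_alpha = torch.zeros(1, requires_grad=True)
    opt = torch.optim.Adam([log_alpha], lr=3e-4)

    def update_alpha(kl_to_prior, target_div=5.0):
        # dual step: raise alpha when KL exceeds the target, lower it otherwise
        loss = log_alpha.exp() * (target_div - kl_to_prior)
        opt.zero_grad()
        loss.backward()
        opt.step()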
/spirl/configs/hrl/peg_in_hole/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.peg_in_hole.base_conf import *
2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 |
5 |
6 | # update policy to use prior model for computing divergence
7 | hl_policy_params.update(AttrDict(
8 | prior_model=ll_agent_config.model,
9 | prior_model_params=ll_agent_config.model_params,
10 | prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = LearnedPriorAugmentedPIPolicy
13 |
14 | # update agent, set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 | td_schedule_params=AttrDict(p=5.),
18 | ))
19 |
--------------------------------------------------------------------------------
/spirl/configs/hrl/maze/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.maze.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 |
5 |
6 | # update policy to use learned prior
7 | hl_policy_params.update(AttrDict(
8 | prior_model=ll_agent_config.model,
9 | prior_model_params=ll_agent_config.model_params,
10 | prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = ACLearnedPriorAugmentedPIPolicy
13 |
14 | # update agent to regularize with prior, set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 | td_schedule_params=AttrDict(p=1.),
18 | ))
19 |
20 |
21 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/can-visual.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/kitchen/flat/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.bc_mdl import BCMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.kitchen import data_spec
7 | from spirl.components.evaluator import DummyEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': BCMdl,
15 | 'logger': Logger,
16 | 'data_dir': '.',
17 | 'epoch_cycles_train': 10,
18 | 'evaluator': DummyEvaluator,
19 | }
20 | configuration = AttrDict(configuration)
21 |
22 | model_config = AttrDict(
23 | state_dim=data_spec.state_dim,
24 | action_dim=data_spec.n_actions,
25 | )
26 |
27 | # Dataset
28 | data_config = AttrDict()
29 | data_config.dataset_spec = data_spec
30 | data_config.dataset_spec.subseq_len = 1 + 1  # +1 because the last action of each sequence gets cropped
31 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/browse_mjcf_model.py:
--------------------------------------------------------------------------------
1 | """Visualize MJCF models.
2 |
3 | Loads an MJCF XML model from file and renders it on screen.
4 |
5 | Example:
6 | $ python browse_mjcf_model.py --filepath ../models/assets/arenas/table_arena.xml
7 | """
8 |
9 | import argparse
10 | import os
11 |
12 | from mujoco_py import load_model_from_path
13 | from mujoco_py import MjSim, MjViewer
14 |
15 | from spirl.data.block_stacking.src.robosuite import models
16 |
17 | if __name__ == "__main__":
18 |
19 | arena_file = os.path.join(models.assets_root, "arenas/pegs_arena.xml")
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--filepath", type=str, default=arena_file)
23 | args = parser.parse_args()
24 |
25 | model = load_model_from_path(args.filepath)
26 | sim = MjSim(model)
27 | viewer = MjViewer(sim)
28 |
29 | print("Press ESC to exit...")
30 | while True:
31 | sim.step()
32 | viewer.render()
33 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/__init__.py:
--------------------------------------------------------------------------------
1 | #import os
2 |
3 | #from spirl.data.block_stacking.src.robosuite.environments.base import make
4 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_lift import SawyerLift
5 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_stack import SawyerStack
6 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_pick_place import SawyerPickPlace
7 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_nut_assembly import SawyerNutAssembly
8 |
9 | #from spirl.data.block_stacking.src.robosuite.environments.baxter_lift import BaxterLift
10 | #from spirl.data.block_stacking.src.robosuite.environments.baxter_peg_in_hole import BaxterPegInHole
11 | #from spirl.data.block_stacking.src.robosuite.environments.baxter_modified import BaxterChange
12 |
13 | __version__ = "0.3.0"
14 | __logo__ = """
15 | ; / ,--.
16 | ["] ["] ,< |__**|
17 | /[_]\ [~]\/ |// |
18 | ] [ OOO /o|__|
19 | """
20 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/milk-visual.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/cereal-visual.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/maze/flat/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.bc_mdl import ImageBCMdl
4 | from spirl.utils.general_utils import AttrDict
5 | from spirl.configs.default_data_configs.maze import data_spec
6 | from spirl.components.evaluator import DummyEvaluator
7 | from spirl.components.logger import Logger
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': ImageBCMdl,
15 | 'logger': Logger,
16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze'),
17 | 'epoch_cycles_train': 4,
18 | 'evaluator': DummyEvaluator,
19 | }
20 | configuration = AttrDict(configuration)
21 |
22 | model_config = AttrDict(
23 | state_dim=data_spec.state_dim,
24 | action_dim=data_spec.n_actions,
25 | input_res=data_spec.res,
26 | n_input_frames=2,
27 | )
28 |
29 | # Dataset
30 | data_config = AttrDict()
31 | data_config.dataset_spec = data_spec
32 | data_config.dataset_spec.subseq_len = 1 + 1 + (model_config.n_input_frames - 1)
33 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/bread-visual.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/block_stacking/flat/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.bc_mdl import ImageBCMdl
4 | from spirl.utils.general_utils import AttrDict
5 | from spirl.configs.default_data_configs.block_stacking import data_spec
6 | from spirl.components.evaluator import DummyEvaluator
7 | from spirl.components.logger import Logger
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': ImageBCMdl,
15 | 'logger': Logger,
16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'),
17 | 'epoch_cycles_train': 4,
18 | 'evaluator': DummyEvaluator,
19 | }
20 | configuration = AttrDict(configuration)
21 |
22 | model_config = AttrDict(
23 | state_dim=data_spec.state_dim,
24 | action_dim=data_spec.n_actions,
25 | input_res=data_spec.res,
26 | n_input_frames=2,
27 | )
28 |
29 | # Dataset
30 | data_config = AttrDict()
31 | data_config.dataset_spec = data_spec
32 | data_config.dataset_spec.subseq_len = 1 + 1 + (model_config.n_input_frames - 1)
33 |
--------------------------------------------------------------------------------
/spirl/data/peg_in_hole/src/generate_peg_in_hole_hdf5.py:
--------------------------------------------------------------------------------
1 | """Script for generating the datasets for peg-in-hole tasks."""
2 | import h5py
3 | import numpy as np
4 | import pandas as pd
5 | import os
6 | import pickle
7 |
8 | def main():
9 | dataset_df = pd.read_csv('peg_in_hole_dataset.csv', header=None)
10 | #print(action_df)
11 | #action_df.to_hdf('action.hdf5', key='action', mode='w')
12 |
13 | with h5py.File('peg_in_hole_dataset.hdf5', 'w') as hdf:
14 | # create state group
15 | #group1 = hdf.create_group('observations')
16 | hdf.create_dataset('observations', data=dataset_df.iloc[:,0:19])
17 |
18 | # create action group
19 | #group2 = hdf.create_group('actions')
20 | hdf.create_dataset('actions', data=dataset_df.iloc[:,19:27])
21 |
22 | # create reward group
23 | hdf.create_dataset('rewards', data=dataset_df.iloc[:,27])
24 |
25 | #group3 = hdf.create_group('terminals')
26 | hdf.create_dataset('terminals', data=dataset_df.iloc[:,28])
27 |
28 |
29 | if __name__ == '__main__':
30 | main()
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/bottle.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/robots/sawyer_robot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spirl.data.block_stacking.src.robosuite.models.robots.robot import Robot
3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion, array_to_string
4 |
5 |
6 | class Sawyer(Robot):
7 | """Sawyer is a witty single-arm robot designed by Rethink Robotics."""
8 |
9 | def __init__(self):
10 | super().__init__(xml_path_completion("robots/sawyer/robot.xml"))
11 |
12 | self.bottom_offset = np.array([0, 0, -0.913])
13 |
14 | def set_base_xpos(self, pos):
15 | """Places the robot on position @pos."""
16 | node = self.worldbody.find("./body[@name='base']")
17 | node.set("pos", array_to_string(pos - self.bottom_offset))
18 |
19 | @property
20 | def dof(self):
21 | return 7
22 |
23 | @property
24 | def joints(self):
25 | return ["right_j{}".format(x) for x in range(7)]
26 |
27 | @property
28 | def init_qpos(self):
29 | return np.array([0, -1.18, 0.00, 2.18, 0.00, 0.57, 3.3161])
30 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/utils/gripper_only_robot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from spirl.data.block_stacking.src.robosuite.models.robots.robot import Robot
4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string
5 |
6 |
7 | class GripperOnlyRobot(Robot):
8 | """Sawyer is a witty single-arm robot designed by Rethink Robotics."""
9 |
10 | def __init__(self):
11 | super().__init__(os.path.join(os.getcwd(), "spirl/data/block_stacking/assets/gripper_only_robot.xml"))
12 |
13 | self.bottom_offset = np.array([0, 0, 0])
14 |
15 | def set_base_xpos(self, pos):
16 | """Places the robot on position @pos."""
17 | node = self.worldbody.find("./body[@name='base']")
18 | node.set("pos", array_to_string(pos - self.bottom_offset))
19 |
20 | @property
21 | def dof(self):
22 | return 4
23 |
24 | @property
25 | def joints(self):
26 | return ["slide_x", "slide_y", "slide_z", "rotate_z"]
27 |
28 | @property
29 | def init_qpos(self):
30 | return np.array([0, 0, 1.2, 0])
31 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/compile_mjcf_model.py:
--------------------------------------------------------------------------------
1 | """Loads a raw mjcf file and saves a compiled mjcf file.
2 |
3 | This prevents mujoco-py from complaining about the .urdf extension.
4 | It also allows assets to be compiled properly.
5 |
6 | Example:
7 | $ python compile_mjcf_model.py source_mjcf.xml target_mjcf.xml
8 | """
9 |
10 | import os
11 | import sys
12 | from shutil import copyfile
13 | from mujoco_py import load_model_from_path
14 |
15 |
16 | def print_usage():
17 | print("""python compile_mjcf_model.py input_file output_file""")
18 |
19 |
20 | if __name__ == "__main__":
21 |
22 | if len(sys.argv) != 3:
23 | print_usage()
24 | sys.exit(1)  # non-zero exit status on wrong usage
25 |
26 | input_file = sys.argv[1]
27 | output_file = sys.argv[2]
28 | input_folder = os.path.dirname(input_file)
29 |
30 | tempfile = os.path.join(input_folder, ".surreal_temp_model.xml")
31 | copyfile(input_file, tempfile)
32 |
33 | model = load_model_from_path(tempfile)
34 | xml_string = model.get_xml()
35 | with open(output_file, "w") as f:
36 | f.write(xml_string)
37 |
38 | os.remove(tempfile)
39 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/gripper_only_robot.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/controllers/controller.py:
--------------------------------------------------------------------------------
1 | import abc # for abstract base class definitions
2 |
3 |
4 | class Controller(metaclass=abc.ABCMeta):
5 | """
6 | Base class for all robot controllers.
7 | Defines basic interface for all controllers to adhere to.
8 | """
9 |
10 | def __init__(self, bullet_data_path, robot_jpos_getter):
11 | """
12 | Args:
13 | bullet_data_path (str): base path to bullet data.
14 |
15 | robot_jpos_getter (function): function that returns the position of the joints
16 | as a numpy array of the right dimension.
17 | """
18 | raise NotImplementedError
19 |
20 | @abc.abstractmethod
21 | def get_control(self, *args, **kwargs):
22 | """
23 | Retrieve a control input from the controller.
24 | """
25 | raise NotImplementedError
26 |
27 | @abc.abstractmethod
28 | def sync_state(self):
29 | """
30 | This function does internal bookkeeping to maintain
31 | consistency between the robot being controlled and
32 | the controller state.
33 | """
34 | raise NotImplementedError
35 |
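A minimal concrete controller, as a sketch: `PassthroughController` is a hypothetical name, and since the base `__init__` above raises `NotImplementedError`, the subclass defines its own instead of calling `super().__init__()`.

import numpy as np

from spirl.data.block_stacking.src.robosuite.controllers.controller import Controller


class PassthroughController(Controller):
    """Illustration only: commands the robot to hold its current joint positions."""

    def __init__(self, bullet_data_path, robot_jpos_getter):
        # intentionally no super().__init__() call (the base raises NotImplementedError)
        self._robot_jpos_getter = robot_jpos_getter

    def get_control(self, *args, **kwargs):
        # return the current joint positions as the control target
        return np.array(self._robot_jpos_getter())

    def sync_state(self):
        pass  # stateless controller: nothing to synchronize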
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/can.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/arenas/empty_arena.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/plotter/plot_trajectories.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | import seaborn as sns
5 |
6 | def plot_trajectory(prior_seqs, policy_seqs):
7 | ax = plt.axes(projection='3d')
8 |
9 | # plot prior action trajectory
10 | x = prior_seqs[:,0]
11 | y = prior_seqs[:,1]
12 | z = prior_seqs[:,2]
13 | ax.plot3D(x, y, z, 'gray')
14 | ax.scatter3D(x, y, z)
15 |
16 | # plot policy action trajectory
17 | x = policy_seqs[:,0]
18 | y = policy_seqs[:,1]
19 | z = policy_seqs[:,2]
20 | ax.plot3D(x, y, z, 'blue')
21 | ax.scatter3D(x, y, z)
22 |
23 | ax.set_xlabel('x')
24 | ax.set_ylabel('y')
25 | ax.set_zlabel('z')
26 | ax.set_xlim([0.1, 0.2])
27 | ax.set_ylim([0.2, 0.3])
28 | ax.set_zlim([0.0, 0.15])
29 |
30 | plt.savefig('fig/' + 'traj.png')  # save before show(): show() clears the active figure
31 | plt.show()
32 |
33 | if __name__ == '__main__':
34 | traj_df = pd.read_csv('~/Workspaces/rl_ws/spirl/trajectories2.csv', header=None)
35 | policy_traj_array = traj_df.iloc[:,0:3].values.astype(np.float32)#.reshape(-1,120)
36 | prior_traj_array = traj_df.iloc[:,3:6].values.astype(np.float32)#.reshape(-1,120)
37 | plot_trajectory(prior_traj_array, policy_traj_array)
--------------------------------------------------------------------------------
/spirl/rl/policies/basic_policies.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from spirl.modules.variational_inference import MultivariateGaussian
4 | from spirl.rl.components.policy import Policy
5 | from spirl.utils.general_utils import ParamDict
6 |
7 |
8 | class UniformGaussianPolicy(Policy):
9 | """Samples actions from a uniform Gaussian."""
10 | def __init__(self, config):
11 | self._hp = self._default_hparams().overwrite(config)
12 | super().__init__()
13 |
14 | def _default_hparams(self):
15 | default_dict = ParamDict({
16 | 'scale': 1, # scale of Uniform Gaussian
17 | })
18 | return super()._default_hparams().overwrite(default_dict)
19 |
20 | def _build_network(self):
21 | return torch.nn.Module() # dummy module
22 |
23 | def _compute_action_dist(self, obs):
24 | batch_size = obs.shape[0]
25 | return MultivariateGaussian(mu=torch.zeros((batch_size, self._hp.action_dim), device=obs.device),
26 | log_sigma=torch.log(self._hp.scale *
27 | torch.ones((batch_size, self._hp.action_dim), device=obs.device)))
28 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/maze/hierarchical/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.skill_prior_mdl import ImageSkillPriorMdl, SkillSpaceLogger
4 | from spirl.utils.general_utils import AttrDict
5 | from spirl.configs.default_data_configs.maze import data_spec
6 | from spirl.components.evaluator import TopOfNSequenceEvaluator
7 |
8 |
9 | current_dir = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | configuration = {
13 | 'model': ImageSkillPriorMdl,
14 | 'logger': SkillSpaceLogger,
15 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze'),
16 | 'epoch_cycles_train': 10,
17 | 'evaluator': TopOfNSequenceEvaluator,
18 | 'top_of_n_eval': 100,
19 | 'top_comp_metric': 'mse',
20 | }
21 | configuration = AttrDict(configuration)
22 |
23 | model_config = AttrDict(
24 | state_dim=data_spec.state_dim,
25 | action_dim=data_spec.n_actions,
26 | n_rollout_steps=10,
27 | kl_div_weight=1e-2,
28 | prior_input_res=data_spec.res,
29 | n_input_frames=2,
30 | )
31 |
32 | # Dataset
33 | data_config = AttrDict()
34 | data_config.dataset_spec = data_spec
35 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames
36 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/kitchen/hierarchical_cl/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.kitchen import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 | current_dir = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | configuration = {
13 | 'model': ClSPiRLMdl,
14 | 'logger': Logger,
15 | 'data_dir': '.',
16 | 'epoch_cycles_train': 10,
17 | 'evaluator': TopOfNSequenceEvaluator,
18 | 'top_of_n_eval': 100,
19 | 'top_comp_metric': 'mse',
20 | }
21 | configuration = AttrDict(configuration)
22 |
23 | model_config = AttrDict(
24 | state_dim=data_spec.state_dim,
25 | action_dim=data_spec.n_actions,
26 | n_rollout_steps=10,
27 | kl_div_weight=5e-4,
28 | nz_enc=128,
29 | nz_mid=128,
30 | n_processing_layers=5,
31 | cond_decode=True,
32 | )
33 |
34 | # Dataset
35 | data_config = AttrDict()
36 | data_config.dataset_spec = data_spec
37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1  # +1 because the last action of each sequence gets cropped
38 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/peg_in_hole/hierarchical_cl/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.peg_in_hole import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 | current_dir = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | configuration = {
13 | 'model': ClSPiRLMdl,
14 | 'logger': Logger,
15 | 'data_dir': '.',
16 | 'epoch_cycles_train': 10,
17 | 'evaluator': TopOfNSequenceEvaluator,
18 | 'top_of_n_eval': 100,
19 | 'top_comp_metric': 'mse',
20 | }
21 | configuration = AttrDict(configuration)
22 |
23 | model_config = AttrDict(
24 | state_dim=data_spec.state_dim,
25 | action_dim=data_spec.n_actions,
26 | n_rollout_steps=10,
27 | kl_div_weight=5e-4,
28 | nz_enc=128,
29 | nz_mid=128,
30 | n_processing_layers=5,
31 | cond_decode=True,
32 | )
33 |
34 | # Dataset
35 | data_config = AttrDict()
36 | data_config.dataset_spec = data_spec
37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1  # +1 because the last action of each sequence gets cropped
38 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/plate-with-hole.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/bread.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/lemon.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/milk.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/kitchen/hierarchical/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.skill_prior_mdl import SkillPriorMdl
4 | from spirl.components.logger import Logger
5 | from spirl.models.skill_prior_mdl import SkillSpaceLogger
6 | from spirl.utils.general_utils import AttrDict
7 | from spirl.configs.default_data_configs.kitchen import data_spec
8 | from spirl.components.evaluator import TopOfNSequenceEvaluator
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': SkillPriorMdl,
15 | 'logger': SkillSpaceLogger,
16 | 'data_dir': '.',
17 | 'epoch_cycles_train': 10,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=5e-4,
29 | nz_enc=128,
30 | nz_mid=128,
31 | n_processing_layers=5,
32 | )
33 |
34 | # Dataset
35 | data_config = AttrDict()
36 | data_config.dataset_spec = data_spec
37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1  # +1 because the last action of each sequence gets cropped
38 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/maze/hierarchical_cl/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.skill_prior_mdl import SkillSpaceLogger
4 | from spirl.models.closed_loop_spirl_mdl import ImageClSPiRLMdl
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.maze import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': ImageClSPiRLMdl,
15 | 'logger': SkillSpaceLogger,
16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze'),
17 | 'epoch_cycles_train': 10,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=1e-2,
29 | prior_input_res=data_spec.res,
30 | n_input_frames=2,
31 | cond_decode=True,
32 | )
33 |
34 | # Dataset
35 | data_config = AttrDict()
36 | data_config.dataset_spec = data_spec
37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames
38 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/cereal.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/demo_gen/block_stacking_demo_agent.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict, ParamDict
2 | from spirl.rl.components.agent import BaseAgent
3 |
4 | from spirl.data.block_stacking.src.demo_gen.block_demo_policy import ClosedLoopBlockStackDemoPolicy
5 |
6 |
7 | class BlockStackingDemoAgent(BaseAgent):
8 | """Wraps demo policy for block stacking."""
9 | def __init__(self, config):
10 | super().__init__(config)
11 | self._policy = self._hp.policy(self._hp.env_params)
12 |
13 | def _default_hparams(self):
14 | default_dict = ParamDict({
15 | 'policy': ClosedLoopBlockStackDemoPolicy, # policy class
16 | 'env_params': None, # parameters containing info about env -> set automatically
17 | })
18 | return super()._default_hparams().overwrite(default_dict)
19 |
20 | @property
21 | def rollout_valid(self):
22 | return self._hp.env_params.task_complete_check()
23 |
24 | def reset(self):
25 | self._policy.reset()
26 |
27 | def _act(self, obs):
28 | return AttrDict(action=self._policy.act(obs))
29 |
30 | def _act_rand(self, obs):
31 | raise NotImplementedError("This should not be called in the demo agent.")
32 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/block_stacking_data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.utils.general_utils import AttrDict
4 | from spirl.components.data_loader import GlobalSplitVideoDataset, GlobalSplitStateSequenceDataset
5 |
6 |
7 | class BlockStackSequenceDataset(GlobalSplitVideoDataset):
8 | """Adds info about env idx from file path."""
9 | def _get_aux_info(self, data, path):
10 | # extract env name from file path
11 | # TODO: design an env id system for block stacking envs
12 | return AttrDict(env_id=0)
13 |
14 | def __getitem__(self, index):
15 | data = super().__getitem__(index)
16 | for key in data.keys():
17 | if key.endswith('states') and data[key].shape[-1] == 40:
18 | # remove quaternion dimensions
19 | data[key] = data[key][:, :20]
20 | elif key.endswith('states') and data[key].shape[-1] == 43:
21 | data[key] = data[key][:, :23]
22 | if key.endswith('actions') and data[key].shape[-1] == 4:
23 | # remove rotation dimension
24 | data[key] = data[key][:, [0, 1, 3]]
25 | return data
26 |
27 |
28 | class BlockStackStateSequenceDataset(BlockStackSequenceDataset, GlobalSplitStateSequenceDataset):
29 | pass
30 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/gripper_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines a string based method of initializing grippers
3 | """
4 | from .two_finger_gripper import TwoFingerGripper, LeftTwoFingerGripper
5 | from .pr2_gripper import PR2Gripper
6 | from .robotiq_gripper import RobotiqGripper
7 | from .pushing_gripper import PushingGripper
8 | from .robotiq_three_finger_gripper import RobotiqThreeFingerGripper
9 |
10 |
11 | def gripper_factory(name):
12 | """
13 | Generator for grippers.
14 |
15 | Creates a Gripper instance with the provided name.
16 |
17 | Args:
18 | name: the name of the gripper class
19 |
20 | Returns:
21 | gripper: Gripper instance
22 |
23 | Raises:
24 | ValueError: if no gripper with the given name exists
25 | """
26 | if name == "TwoFingerGripper":
27 | return TwoFingerGripper()
28 | if name == "LeftTwoFingerGripper":
29 | return LeftTwoFingerGripper()
30 | if name == "PR2Gripper":
31 | return PR2Gripper()
32 | if name == "RobotiqGripper":
33 | return RobotiqGripper()
34 | if name == "PushingGripper":
35 | return PushingGripper()
36 | if name == "RobotiqThreeFingerGripper":
37 | return RobotiqThreeFingerGripper()
38 | raise ValueError("Unkown gripper name {}".format(name))
39 |
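Usage is a single string-keyed call, e.g.:

from spirl.data.block_stacking.src.robosuite.models.grippers.gripper_factory import gripper_factory

gripper = gripper_factory("TwoFingerGripper")  # returns a TwoFingerGripper instance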
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.skill_prior_mdl import SkillSpaceLogger
4 | from spirl.models.closed_loop_spirl_mdl import ImageClSPiRLMdl
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.block_stacking import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': ImageClSPiRLMdl,
15 | 'logger': SkillSpaceLogger,
16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'),
17 | 'epoch_cycles_train': 10,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=1e-4,
29 | prior_input_res=data_spec.res,
30 | n_input_frames=2,
31 | cond_decode=True,
32 | )
33 |
34 | # Dataset
35 | data_config = AttrDict()
36 | data_config.dataset_spec = data_spec
37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames
38 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/block_stacking/hierarchical/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.models.skill_prior_mdl import ImageSkillPriorMdl
4 | from spirl.data.block_stacking.src.block_stacking_logger import SkillSpaceBlockStackLogger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.block_stacking import data_spec
7 | from spirl.components.evaluator import TopOfNSequenceEvaluator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | configuration = {
14 | 'model': ImageSkillPriorMdl,
15 | 'logger': SkillSpaceBlockStackLogger,
16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'),
17 | 'epoch_cycles_train': 10,
18 | 'evaluator': TopOfNSequenceEvaluator,
19 | 'top_of_n_eval': 100,
20 | 'top_comp_metric': 'mse',
21 | }
22 | configuration = AttrDict(configuration)
23 |
24 | model_config = AttrDict(
25 | state_dim=data_spec.state_dim,
26 | action_dim=data_spec.n_actions,
27 | n_rollout_steps=10,
28 | kl_div_weight=1e-2,
29 | prior_input_res=data_spec.res,
30 | n_input_frames=2,
31 | )
32 |
33 | # Dataset
34 | data_config = AttrDict()
35 | data_config.dataset_spec = data_spec
36 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames
37 |
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/prior_initialized/base_conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.base_conf import *
2 | from spirl.rl.components.sampler import ACMultiImageAugmentedSampler
3 | from spirl.rl.policies.mlp_policies import ConvPolicy
4 | from spirl.rl.components.critic import SplitObsMLPCritic
5 | from spirl.models.bc_mdl import ImageBCMdl
6 |
7 |
8 | # update sampler
9 | configuration['sampler'] = ACMultiImageAugmentedSampler
10 | sampler_config = AttrDict(
11 | n_frames=2,
12 | )
13 |
14 | # update policy to conv policy
15 | agent_config.policy = ConvPolicy
16 | policy_params.update(AttrDict(
17 | input_nc=3 * sampler_config.n_frames,
18 | prior_model=ImageBCMdl,
19 | prior_model_params=AttrDict(state_dim=data_spec.state_dim,
20 | action_dim=data_spec.n_actions,
21 | input_res=data_spec.res,
22 | n_input_frames=2,
23 | ),
24 | prior_model_checkpoint=os.path.join(os.environ["EXP_DIR"],
25 | "skill_prior_learning/block_stacking/flat"),
26 | ))
27 |
28 | # update critic+policy to handle multi-frame combined observation
29 | agent_config.critic = SplitObsMLPCritic
30 | agent_config.critic_params.unused_obs_size = 32**2*3 * sampler_config.n_frames
31 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/arena.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML
4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array
5 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import new_geom, new_body, new_joint
6 |
7 |
8 | class Arena(MujocoXML):
9 | """Base arena class."""
10 |
11 | def set_origin(self, offset):
12 | """Applies a constant offset to all objects."""
13 | offset = np.array(offset)
14 | for node in self.worldbody.findall("./*[@pos]"):
15 | cur_pos = string_to_array(node.get("pos"))
16 | new_pos = cur_pos + offset
17 | node.set("pos", array_to_string(new_pos))
18 |
19 | def add_pos_indicator(self):
20 | """Adds a new position indicator."""
21 | body = new_body(name="pos_indicator")
22 | body.append(
23 | new_geom(
24 | "sphere",
25 | [0.03],
26 | rgba=[1, 0, 0, 0.5],
27 | group=1,
28 | contype="0",
29 | conaffinity="0",
30 | )
31 | )
32 | body.append(new_joint(type="free", name="pos_indicator"))
33 | self.worldbody.append(body)
34 |
--------------------------------------------------------------------------------
/spirl/configs/data_collect/block_stacking/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.utils.general_utils import AttrDict
4 |
5 | from spirl.data.block_stacking.src.demo_gen.block_stacking_demo_agent import BlockStackingDemoAgent
6 | from spirl.data.block_stacking.src.block_stacking_env import BlockStackEnv
7 | from spirl.data.block_stacking.src.block_task_generator import FixedSizeSingleTowerBlockTaskGenerator
8 |
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 | notes = 'used for generating block stacking dataset'
13 | SEED = 31
14 |
15 | configuration = {
16 | 'seed': SEED,
17 | 'agent': BlockStackingDemoAgent,
18 | 'environment': BlockStackEnv,
19 | 'max_rollout_len': 250,
20 | }
21 | configuration = AttrDict(configuration)
22 |
23 | # Task
24 | task_params = AttrDict(
25 | max_tower_height=4,
26 | seed=SEED,
27 | )
28 |
29 | # Agent
30 | agent_config = AttrDict(
31 |
32 | )
33 |
34 | # Dataset - Random data
35 | data_config = AttrDict(
36 |
37 | )
38 |
39 | # Environment
40 | env_config = AttrDict(
41 | task_generator=FixedSizeSingleTowerBlockTaskGenerator,
42 | task_params=task_params,
43 | dimension=2,
44 | n_steps=2,
45 | screen_width=32,
46 | screen_height=32,
47 | rand_task=True,
48 | rand_init_pos=True,
49 | camera_name='agentview',
50 | )
51 |
52 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/utils/wide_range_gripper.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.grippers.two_finger_gripper import TwoFingerGripper
2 |
3 |
4 | class WideRangeGripper(TwoFingerGripper):
5 | def __init__(self):
6 | super().__init__()
7 | l_gripper = self.worldbody.find(".//joint[@name='r_gripper_l_finger_joint']")
8 | l_gripper.set("range", "-0.020833 0.08")
9 | l_actuator = self.actuator.find(".//position[@joint='r_gripper_l_finger_joint']")
10 | l_actuator.set("ctrlrange", "-0.020833 0.08")
11 | l_actuator.set("gear", "2")
12 | r_gripper = self.worldbody.find(".//joint[@name='r_gripper_r_finger_joint']")
13 | r_gripper.set("range", "-0.08 0.020833")
14 | r_actuator = self.actuator.find(".//position[@joint='r_gripper_r_finger_joint']")
15 | r_actuator.set("ctrlrange", "-0.08 0.020833")
16 | r_actuator.set("gear", "2")
17 |
18 | for geom_name in ['l_finger_g1', 'r_finger_g1']:
19 | geom = self.worldbody.find(".//geom[@name='{}']".format(geom_name))
20 | geom.set("conaffinity", "0")
21 |
22 | for geom_name in ['l_fingertip_g0', 'r_fingertip_g0']:
23 | geom = self.worldbody.find(".//geom[@name='{}']".format(geom_name))
24 | geom.set("friction", "30 0.005 0.0001")
25 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/kitchen/hierarchical_cl/README.md:
--------------------------------------------------------------------------------
1 | # SPiRL w/ Closed-Loop Skill Decoder
2 |
3 | This version of the SPiRL model uses a [closed-loop action decoder](../../../../models/closed_loop_spirl_mdl.py):
4 | in contrast to the original SPiRL model it takes the current environment state as input in every skill decoding step.
5 |
6 | We find that this model improves performance over the original
7 | SPiRL model, particularly on tasks that require precise control, like in the kitchen environment.
8 |
14 | For an implementation of the closed-loop SPiRL model that supports image observations,
15 | see [here](../../block_stacking/hierarchical_cl/README.md).
16 |
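Schematically, the difference to the open-loop decoder looks like this (pseudocode; the names are illustrative, not the actual module API):

```python
# original SPiRL: open-loop decoding, actions depend only on the latent skill z
for t in range(n_rollout_steps):
    a[t] = decoder(z)

# closed-loop SPiRL (ClSPiRLMdl): the current state s[t] is an input at every step
for t in range(n_rollout_steps):
    a[t] = decoder(z, s[t])
```
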
17 | ## Example Commands
18 |
19 | To train the SPiRL model with closed-loop action decoder on the kitchen environment, run the following command:
20 | ```
21 | python3 spirl/train.py --path=spirl/configs/skill_prior_learning/kitchen/hierarchical_cl --val_data_size=160
22 | ```
23 |
24 | To train a downstream task policy with RL using the closed-loop SPiRL model we just trained, run the following command:
25 | ```
26 | python3 spirl/rl/train.py --path=spirl/configs/hrl/kitchen/spirl_cl --seed=0 --prefix=SPIRLv2_kitchen_seed0
27 | ```
28 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/tasks/task.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.world import MujocoWorldBase
2 |
3 |
4 | class Task(MujocoWorldBase):
5 | """
6 | Base class for creating MJCF model of a task.
7 |
8 | A task typically involves a robot interacting with objects in an arena
9 | (workspace). The purpose of a task class is to generate an MJCF model
10 | of the task by combining the MJCF models of each component and placing
11 | them in the right positions. Object placement can be done by
12 | ad-hoc methods or placement samplers.
13 | """
14 |
15 | def merge_robot(self, mujoco_robot):
16 | """Adds robot model to the MJCF model."""
17 | pass
18 |
19 | def merge_arena(self, mujoco_arena):
20 | """Adds arena model to the MJCF model."""
21 | pass
22 |
23 | def merge_objects(self, mujoco_objects):
24 | """Adds physical objects to the MJCF model."""
25 | pass
26 |
27 | def merge_visual(self, mujoco_objects):
28 | """Adds visual objects to the MJCF model."""
29 |
30 | def place_objects(self):
31 | """Places objects randomly until no collisions or max iterations hit."""
32 | pass
33 |
34 | def place_visual(self):
35 | """Places visual objects randomly until no collisions or max iterations hit."""
36 | pass
37 |
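A typical composition sequence for a concrete task, sketched with hypothetical `arena`, `robot`, and `objects` instances (the methods above are stubs that subclasses override):

task = Task()
task.merge_arena(arena)      # e.g. a TableArena model
task.merge_robot(robot)      # e.g. a Sawyer model
task.merge_objects(objects)  # physical objects to manipulate
task.place_objects()         # sample collision-free object placements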
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/prior_initialized/base_conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.base_conf import *
2 | from spirl.rl.components.sampler import ACMultiImageAugmentedSampler
3 | from spirl.rl.policies.mlp_policies import ConvPolicy
4 | from spirl.rl.components.critic import SplitObsMLPCritic
5 | from spirl.models.bc_mdl import ImageBCMdl
6 |
7 |
8 | # update sampler
9 | configuration['sampler'] = ACMultiImageAugmentedSampler
10 | sampler_config = AttrDict(
11 | n_frames=2,
12 | )
13 | env_config.screen_width = data_spec.res
14 | env_config.screen_height = data_spec.res
15 |
16 | # update policy to conv policy
17 | agent_config.policy = ConvPolicy
18 | policy_params.update(AttrDict(
19 | input_nc=3 * sampler_config.n_frames,
20 | prior_model=ImageBCMdl,
21 | prior_model_params=AttrDict(state_dim=data_spec.state_dim,
22 | action_dim=data_spec.n_actions,
23 | input_res=data_spec.res,
24 | n_input_frames=sampler_config.n_frames,
25 | ),
26 | prior_model_checkpoint=os.path.join(os.environ["EXP_DIR"],
27 | "skill_prior_learning/maze/flat"),
28 | ))
29 |
30 | # update critic+policy to handle multi-frame combined observation
31 | agent_config.critic = SplitObsMLPCritic
32 | agent_config.critic_params.unused_obs_size = 32**2*3 * sampler_config.n_frames
33 |
34 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spirl.data.block_stacking.src import robosuite as suite
3 |
4 | if __name__ == "__main__":
5 |
6 | # get the list of all environments
7 | envs = sorted(suite.environments.ALL_ENVS)
8 |
9 | # print info and select an environment
10 | print("Welcome to Surreal Robotics Suite v{}!".format(suite.__version__))
11 | print(suite.__logo__)
12 | print("Here is a list of environments in the suite:\n")
13 |
14 | for k, env in enumerate(envs):
15 | print("[{}] {}".format(k, env))
16 | print()
17 | try:
18 | s = input(
19 | "Choose an environment to run "
20 | + "(enter a number from 0 to {}): ".format(len(envs) - 1)
21 | )
22 | # parse input into a number within range
23 | k = min(max(int(s), 0), len(envs) - 1)
24 | except ValueError:
25 | print("Input is not valid. Using 0 by default.")
26 | k = 0
27 |
28 | # initialize the task
29 | env = suite.make(
30 | envs[k],
31 | has_renderer=True,
32 | ignore_done=True,
33 | use_camera_obs=False,
34 | control_freq=100,
35 | )
36 | env.reset()
37 | env.viewer.set_camera(camera_id=0)
38 |
39 | # do visualization
40 | for i in range(1000):
41 | action = np.random.randn(env.dof)
42 | obs, reward, done, _ = env.step(action)
43 | env.render()
44 |
--------------------------------------------------------------------------------
/spirl/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class DataSubsampler:
5 | def __init__(self, aggregator):
6 | self._aggregator = aggregator
7 |
8 | def __call__(self, *args, **kwargs):
9 | raise NotImplementedError("This function needs to be implemented by sub-classes!")
10 |
11 |
12 | class FixedFreqSubsampler(DataSubsampler):
13 | """Subsamples input array's first dimension by skipping given number of frames."""
14 | def __init__(self, n_skip, aggregator=None):
15 | super().__init__(aggregator)
16 | self._n_skip = n_skip
17 |
18 | def __call__(self, val, idxs=None, aggregate=False):
19 | """Subsamples with idxs if given, aggregates with aggregator if aggregate=True."""
20 | if self._n_skip == 0:
21 | return val, None
22 |
23 | if idxs is None:
24 | seq_len = val.shape[0]
25 | idxs = np.arange(0, seq_len - 1, self._n_skip + 1)
26 |
27 | if aggregate:
28 | assert self._aggregator is not None, "no aggregator given!"
29 | return self._aggregator(val, idxs), idxs
30 | else:
31 | return val[idxs], idxs
32 |
33 |
34 | class Aggregator:
35 | def __call__(self, *args, **kwargs):
36 | raise NotImplementedError("This function needs to be implemented by sub-classes!")
37 |
38 |
39 | class SumAggregator(Aggregator):
40 | def __call__(self, val, idxs):
41 | return np.add.reduceat(val, idxs, axis=0)
42 |
43 |
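For example, with `n_skip=1` the subsampler keeps every second frame, and with `aggregate=True` it sums the values between consecutive kept indices (a small sketch using the classes above):

import numpy as np

vals = np.arange(10)
subsampler = FixedFreqSubsampler(n_skip=1, aggregator=SumAggregator())

sub, idxs = subsampler(vals)               # sub == vals[[0, 2, 4, 6, 8]]
agg, _ = subsampler(vals, aggregate=True)  # np.add.reduceat(vals, idxs): [1, 5, 9, 13, 17]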
--------------------------------------------------------------------------------
/spirl/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | from torchvision.transforms import Resize
5 |
6 |
7 | def ch_first2last(video):
8 | return video.transpose((0,2,3,1))
9 |
10 |
11 | def ch_last2first(video):
12 | return video.transpose((0,3,1,2))
13 |
14 |
15 | def resize_video(video, size):
16 | if video.shape[1] == 3:
17 | video = np.transpose(video, (0,2,3,1))
18 | transformed_video = np.stack([np.asarray(Resize(size)(Image.fromarray(im))) for im in video], axis=0)
19 | return transformed_video
20 |
21 |
22 | def _make_dir(filename):
23 | folder = os.path.dirname(filename)
24 | if not os.path.exists(folder):
25 | os.makedirs(folder)
26 |
27 |
28 | def save_video(video_frames, filename, fps=60, video_format='mp4'):
29 | assert fps == int(fps), fps
30 | import skvideo.io
31 | _make_dir(filename)
32 |
33 | skvideo.io.vwrite(
34 | filename,
35 | video_frames,
36 | inputdict={
37 | '-r': str(int(fps)),
38 | },
39 | outputdict={
40 | '-f': video_format,
41 | '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
42 | }
43 | )
44 |
45 |
46 | def create_video_grid(col_and_row_frames):
47 | video_grid_frames = np.concatenate([
48 | np.concatenate(row_frames, axis=-2)
49 | for row_frames in col_and_row_frames
50 | ], axis=-3)
51 |
52 | return video_grid_frames
53 |
54 |
55 |
--------------------------------------------------------------------------------
/spirl/configs/hrl/kitchen/spirl_cl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.kitchen.spirl.conf import *
2 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
3 | from spirl.rl.policies.cl_model_policies import ClModelPolicy
4 |
5 | # update model params to condition the decoder on the state
6 | ll_model_params.cond_decode = True
7 |
8 | # create LL closed-loop policy
9 | ll_policy_params = AttrDict(
10 | policy_model=ClSPiRLMdl,
11 | policy_model_params=ll_model_params,
12 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"],
13 | "skill_prior_learning/kitchen/hierarchical_cl"),
14 | )
15 | ll_policy_params.update(ll_model_params)
16 |
17 | # create LL SAC agent (by default we will only use it for rolling out decoded skills, not finetuning skill decoder)
18 | ll_agent_config = AttrDict(
19 | policy=ClModelPolicy,
20 | policy_params=ll_policy_params,
21 | critic=MLPCritic, # LL critic is not used since we are not finetuning LL
22 | critic_params=hl_critic_params
23 | )
24 |
25 | # update HL policy model params
26 | hl_policy_params.update(AttrDict(
27 | prior_model=ll_policy_params.policy_model,
28 | prior_model_params=ll_policy_params.policy_model_params,
29 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
30 | ))
31 |
32 | # register new LL agent in agent_config and turn off LL agent updates
33 | agent_config.update(AttrDict(
34 | ll_agent=SACAgent,
35 | ll_agent_params=ll_agent_config,
36 | update_ll=False,
37 | ))
38 |
39 |
40 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/demo_gripper_selection.py:
--------------------------------------------------------------------------------
1 | """
2 | This script shows you how to select gripper for an environment.
3 | This is controlled by gripper_type keyword argument
4 | """
5 | from spirl.data.block_stacking.src import robosuite as suite
6 | from spirl.data.block_stacking.src.robosuite.wrappers import GymWrapper
7 |
8 |
9 | if __name__ == "__main__":
10 |
11 | grippers = ["TwoFingerGripper", "PR2Gripper", "RobotiqGripper"]
12 |
13 | for gripper in grippers:
14 |
15 | # create environment with selected grippers
16 | env = GymWrapper(
17 | suite.make(
18 | "SawyerPickPlace",
19 | gripper_type=gripper,
20 | use_camera_obs=False, # do not use pixel observations
21 | has_offscreen_renderer=False, # not needed since not using pixel obs
22 | has_renderer=True, # make sure we can render to the screen
23 | reward_shaping=True, # use dense rewards
24 | control_freq=100, # control should happen fast enough so that simulation looks smooth
25 | )
26 | )
27 |
28 | # run a random policy
29 | observation = env.reset()
30 | for t in range(500):
31 | env.render()
32 | action = env.action_space.sample()
33 | observation, reward, done, info = env.step(action)
34 | if done:
35 | print("Episode finished after {} timesteps".format(t + 1))
36 | break
37 |
38 | # close window
39 | env.close()
40 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/bins_arena.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena
3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array
5 |
6 |
7 | class BinsArena(Arena):
8 | """Workspace that contains two bins placed side by side."""
9 |
10 | def __init__(
11 | self, table_full_size=(0.39, 0.49, 0.82), table_friction=(1, 0.005, 0.0001)
12 | ):
13 | """
14 | Args:
15 | table_full_size: full dimensions of the table
16 | table_friction: friction parameters of the table
17 | """
18 | super().__init__(xml_path_completion("arenas/bins_arena.xml"))
19 |
20 | self.table_full_size = np.array(table_full_size)
21 | self.table_half_size = self.table_full_size / 2
22 | self.table_friction = table_friction
23 |
24 | self.floor = self.worldbody.find("./geom[@name='floor']")
25 | self.bin1_body = self.worldbody.find("./body[@name='bin1']")
26 | self.bin2_body = self.worldbody.find("./body[@name='bin2']")
27 |
28 | self.configure_location()
29 |
30 | def configure_location(self):
31 | self.bottom_pos = np.array([0, 0, 0])
32 | self.floor.set("pos", array_to_string(self.bottom_pos))
33 |
34 | @property
35 | def bin_abs(self):
36 | """Returns the absolute position of table top"""
37 | return string_to_array(self.bin1_body.get("pos"))
38 |
--------------------------------------------------------------------------------
/spirl/rl/envs/maze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import d4rl  # imported for its side effect of registering the d4rl environments with gym
3 |
4 | from spirl.rl.components.environment import GymEnv
5 | from spirl.utils.general_utils import ParamDict, AttrDict
6 |
7 |
8 | class MazeEnv(GymEnv):
9 | """Shallow wrapper around gym env for maze envs."""
10 | def __init__(self, *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 |
13 | def _default_hparams(self):
14 | default_dict = ParamDict({
15 | })
16 | return super()._default_hparams().overwrite(default_dict)
17 |
18 | def reset(self):
19 | super().reset()
20 | if self.TARGET_POS is not None and self.START_POS is not None:
21 | self._env.set_target(self.TARGET_POS)
22 | self._env.reset_to_location(self.START_POS)
23 | self._env.render(mode='rgb_array') # these are necessary to make sure new state is rendered on first frame
24 | obs, _, _, _ = self._env.step(np.zeros_like(self._env.action_space.sample()))
25 | return self._wrap_observation(obs)
26 |
27 | def step(self, *args, **kwargs):
28 | obs, rew, done, info = super().step(*args, **kwargs)
29 | return obs, np.float64(rew), done, info # casting reward to float64 is important for getting shape later
30 |
31 |
32 | class ACRandMaze0S40Env(MazeEnv):
33 | START_POS = np.array([10., 24.])
34 | TARGET_POS = np.array([18., 8.])
35 |
36 | def _default_hparams(self):
37 | default_dict = ParamDict({
38 | 'name': "maze2d-randMaze0S40-ac-v0",
39 | })
40 | return super()._default_hparams().overwrite(default_dict)
41 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/square-nut.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/robots/baxter_robot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spirl.data.block_stacking.src.robosuite.models.robots.robot import Robot
3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion, array_to_string
4 |
5 |
6 | class Baxter(Robot):
7 | """Baxter is a hunky bimanual robot designed by Rethink Robotics."""
8 |
9 | def __init__(self):
10 | super().__init__(xml_path_completion("robots/baxter/robot.xml"))
11 |
12 | self.bottom_offset = np.array([0, 0, -0.913])
13 | self.left_hand = self.worldbody.find(".//body[@name='left_hand']")
14 |
15 | def set_base_xpos(self, pos):
16 | """Places the robot on position @pos."""
17 | node = self.worldbody.find("./body[@name='base']")
18 | node.set("pos", array_to_string(pos - self.bottom_offset))
19 |
20 | @property
21 | def dof(self):
22 | return 14
23 |
24 | @property
25 | def joints(self):
26 | out = []
27 | for s in ["right_", "left_"]:
28 | out.extend(s + a for a in ["s0", "s1", "e0", "e1", "w0", "w1", "w2"])
29 | return out
30 |
31 | @property
32 | def init_qpos(self):
33 | # Arms ready to work on the table
34 | return np.array([
35 | 0.535, -0.093, 0.038, 0.166, 0.643, 1.960, -1.297,
36 | -0.518, -0.026, -0.076, 0.175, -0.748, 1.641, -0.158])
37 |
38 | # alternative: arms half extended (previously unreachable code, kept as reference)
39 | # return np.array([
40 | #     0.752, -0.038, -0.021, 0.161, 0.348, 2.095, -0.531,
41 | #     -0.585, -0.117, -0.037, 0.164, -0.536, 1.543, 0.204])
42 |
43 | # alternative: arms fully extended
44 | # return np.zeros(14)
45 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/arenas/table_arena.xml:
--------------------------------------------------------------------------------
(XML content lost in extraction)
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/demo_baxter_ik_control.py:
--------------------------------------------------------------------------------
1 | """End-effector control for bimanual Baxter robot.
2 |
3 | This script shows how to use inverse kinematics solver from Bullet
4 | to command the end-effectors of two arms of the Baxter robot.
5 | """
6 |
7 | import os
8 | import numpy as np
9 |
10 | from spirl.data.block_stacking.src import robosuite
11 | from spirl.data.block_stacking.src.robosuite.wrappers import IKWrapper
12 |
13 |
14 | if __name__ == "__main__":
15 |
16 | # initialize a Baxter environment
17 | env = robosuite.make(
18 | "BaxterLift",
19 | ignore_done=True,
20 | has_renderer=True,
21 | gripper_visualization=True,
22 | use_camera_obs=False,
23 | )
24 | env = IKWrapper(env)
25 |
26 | obs = env.reset()
27 |
28 | # rotate the gripper so we can see it easily
29 | env.set_robot_joint_positions([
30 | 0.00, -0.55, 0.00, 1.28, 0.00, 0.26, 0.00,
31 | 0.00, -0.55, 0.00, 1.28, 0.00, 0.26, 0.00,
32 | ])
33 |
34 | bullet_data_path = os.path.join(robosuite.models.assets_root, "bullet_data")
35 |
36 | def robot_jpos_getter():
37 | return np.array(env._joint_positions)
38 |
39 | for t in range(100000):
40 | omega = 2 * np.pi / 1000.
41 | A = 5e-4
42 | dpos_right = np.array([A * np.cos(omega * t), 0, A * np.sin(omega * t)])
43 | dpos_left = np.array([A * np.sin(omega * t), A * np.cos(omega * t), 0])
44 | dquat = np.array([0, 0, 0, 1])
45 | grasp = 0.
46 | action = np.concatenate([dpos_right, dquat, dpos_left, dquat, [grasp, grasp]])
47 |
48 | obs, reward, done, info = env.step(action)
49 | env.render()
50 |
51 | if done:
52 | break
53 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/pegs_arena.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena
3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array
5 |
6 |
7 | class PegsArena(Arena):
8 | """Workspace that contains a tabletop with two fixed pegs."""
9 |
10 | def __init__(
11 | self, table_full_size=(0.45, 0.69, 0.82), table_friction=(1, 0.005, 0.0001)
12 | ):
13 | """
14 | Args:
15 | table_full_size: full dimensions of the table
16 | table_friction: friction parameters of the table
17 | """
18 | super().__init__(xml_path_completion("arenas/pegs_arena.xml"))
19 |
20 | self.table_full_size = np.array(table_full_size)
21 | self.table_half_size = self.table_full_size / 2
22 | self.table_friction = table_friction
23 |
24 | self.floor = self.worldbody.find("./geom[@name='floor']")
25 | self.table_body = self.worldbody.find("./body[@name='table']")
26 | self.peg1_body = self.worldbody.find("./body[@name='peg1']")
27 | self.peg2_body = self.worldbody.find("./body[@name='peg2']")
28 | self.table_collision = self.table_body.find("./geom[@name='table_collision']")
29 |
30 | self.configure_location()
31 |
32 | def configure_location(self):
33 | self.bottom_pos = np.array([0, 0, 0])
34 | self.floor.set("pos", array_to_string(self.bottom_pos))
35 |
36 | @property
37 | def table_top_abs(self):
38 | """
39 | Returns the absolute position of table top.
40 | """
41 | return string_to_array(self.table_body.get("pos"))
42 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/demo_video_recording.py:
--------------------------------------------------------------------------------
1 | """
2 | Record video of agent episodes with the imageio library.
3 | This script uses offscreen rendering.
4 |
5 | Example:
6 | $ python demo_video_recording.py --environment SawyerLift
7 | """
8 |
9 | import argparse
10 | import imageio
11 | import numpy as np
12 |
13 | from spirl.data.block_stacking.src.robosuite import make
14 |
15 |
16 | if __name__ == "__main__":
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--environment", type=str, default="SawyerStack")
20 | parser.add_argument("--video_path", type=str, default="video.mp4")
21 | parser.add_argument("--timesteps", type=int, default=500)
22 | parser.add_argument("--height", type=int, default=512)
23 | parser.add_argument("--width", type=int, default=512)
24 | parser.add_argument("--skip_frame", type=int, default=1)
25 | args = parser.parse_args()
26 |
27 | # initialize an environment with offscreen renderer
28 | env = make(
29 | args.environment,
30 | has_renderer=False,
31 | ignore_done=True,
32 | use_camera_obs=True,
33 | use_object_obs=False,
34 | camera_height=args.height,
35 | camera_width=args.width,
36 | )
37 |
38 | obs = env.reset()
39 | dof = env.dof
40 |
41 | # create a video writer with imageio
42 | writer = imageio.get_writer(args.video_path, fps=20)
43 |
44 |
45 | for i in range(args.timesteps):
46 |
47 | # run a uniformly random agent
48 | action = 0.5 * np.random.randn(dof)
49 | obs, reward, done, info = env.step(action)
50 |
51 |         # save one frame every K steps
52 | if i % args.skip_frame == 0:
53 | frame = obs["image"][::-1]
54 | writer.append_data(frame)
55 | print("Saving frame #{}".format(i))
56 |
57 | if done:
58 | break
59 |
60 | writer.close()
61 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/utils/numbered_box_object.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 | from spirl.data.block_stacking.src.robosuite.models.objects import BoxObject
4 |
5 |
6 | class NumberedBoxObject(BoxObject):
7 | def __init__(self, **kwargs):
8 | self.number = kwargs.pop("number")
9 | super().__init__(**kwargs)
10 |
11 | if self.number is not None:
12 | asset_dir = os.path.join(os.getcwd(), "spirl/data/block_stacking/assets")
13 | texture_path = os.path.join(asset_dir, "textures/obj{}.png".format(self.number))
14 |
15 | texture = ET.SubElement(self.asset, "texture")
16 | texture.set("file", texture_path)
17 | texture.set("name", "obj-{}-texture".format(self.number))
18 | texture.set("gridsize", "1 2")
19 | texture.set("gridlayout", ".U")
20 | texture.set("rgb1", "1.0 1.0 1.0")
21 | texture.set("vflip", "true")
22 | texture.set("hflip", "true")
23 |
24 | material = ET.SubElement(self.asset, "material")
25 | material.set("name", "obj-{}-material".format(self.number))
26 | material.set("reflectance", "0.5")
27 | material.set("specular", "0.5")
28 | material.set("shininess", "0.1")
29 | material.set("texture", "obj-{}-texture".format(self.number))
30 | material.set("texuniform", "false")
31 |
32 | def get_collision_attrib_template(self):
33 | template = super().get_collision_attrib_template()
34 | if self.number is not None:
35 | template.update({"material": "obj-{}-material".format(self.number)})
36 | return template
37 |
38 | def get_visual_attrib_template(self):
39 | template = super().get_visual_attrib_template()
40 | if self.number is not None:
41 | template.update({"material": "obj-{}-material".format(self.number)})
42 | return template
43 |
--------------------------------------------------------------------------------
/spirl/rl/envs/kitchen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict
3 | import d4rl
4 |
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.utils.general_utils import ParamDict
7 | from spirl.rl.components.environment import GymEnv
8 |
9 |
10 | class KitchenEnv(GymEnv):
11 | """Tiny wrapper around GymEnv for Kitchen tasks."""
12 | SUBTASKS = ['microwave', 'kettle', 'slide cabinet', 'hinge cabinet', 'bottom burner', 'light switch', 'top burner']
13 | def _default_hparams(self):
14 | return super()._default_hparams().overwrite(ParamDict({
15 | 'name': "kitchen-mixed-v0",
16 | }))
17 |
18 | def step(self, *args, **kwargs):
19 | obs, rew, done, info = super().step(*args, **kwargs)
20 | return obs, np.float64(rew), done, self._postprocess_info(info) # casting reward to float64 is important for getting shape later
21 |
22 | def reset(self):
23 | self.solved_subtasks = defaultdict(lambda: 0)
24 | return super().reset()
25 |
26 | def get_episode_info(self):
27 | info = super().get_episode_info()
28 | info.update(AttrDict(self.solved_subtasks))
29 | return info
30 |
31 | def _postprocess_info(self, info):
32 | """Sorts solved subtasks into separately logged elements."""
33 | completed_subtasks = info.pop("completed_tasks")
34 | for task in self.SUBTASKS:
35 | self.solved_subtasks[task] = 1 if task in completed_subtasks or self.solved_subtasks[task] else 0
36 | return info
37 |
38 |
39 | class NoGoalKitchenEnv(KitchenEnv):
40 | """Splits off goal from obs."""
41 | def step(self, *args, **kwargs):
42 | obs, rew, done, info = super().step(*args, **kwargs)
43 | obs = obs[:int(obs.shape[0]/2)]
44 | return obs, rew, done, info
45 |
46 | def reset(self, *args, **kwargs):
47 | obs = super().reset(*args, **kwargs)
48 | return obs[:int(obs.shape[0]/2)]
49 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/demo_pygame_renderer.py:
--------------------------------------------------------------------------------
1 | """pygame rendering demo.
2 |
3 | This script provides an example of using the pygame library for rendering
4 | camera observations as an alternative to the default mujoco_py renderer.
5 |
6 | Example:
7 |     $ python demo_pygame_renderer.py --environment BaxterPegInHole --width 1000 --height 1000
8 | """
9 |
10 | import sys
11 | import argparse
12 | import pygame
13 | import numpy as np
14 |
15 | from spirl.data.block_stacking.src import robosuite
16 |
17 | if __name__ == "__main__":
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--environment", type=str, default="BaxterLift")
21 | parser.add_argument("--timesteps", type=int, default=10000)
22 | parser.add_argument("--width", type=int, default=512)
23 | parser.add_argument("--height", type=int, default=384)
24 | args = parser.parse_args()
25 |
26 | width = args.width
27 | height = args.height
28 | screen = pygame.display.set_mode((width, height))
29 |
30 | env = robosuite.make(
31 | args.environment,
32 | has_renderer=False,
33 | ignore_done=True,
34 | camera_height=height,
35 | camera_width=width,
36 |         gripper_visualization=True,
37 | use_camera_obs=True,
38 | use_object_obs=False,
39 | )
40 |
41 | for i in range(args.timesteps):
42 |
43 | # issue random actions
44 | action = 0.5 * np.random.randn(env.dof)
45 | obs, reward, done, info = env.step(action)
46 |
47 | for event in pygame.event.get():
48 | if event.type == pygame.QUIT:
49 | sys.exit()
50 |
51 | # read camera observation
52 | im = np.flip(obs["image"].transpose((1, 0, 2)), 1)
53 | pygame.pixelcopy.array_to_surface(screen, im)
54 | pygame.display.update()
55 |
56 | if i % 100 == 0:
57 | print("step #{}".format(i))
58 |
59 | if done:
60 | break
61 |
--------------------------------------------------------------------------------
/spirl/rl/utils/robosuite_utils.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.environments.base import MujocoEnv
2 | from mujoco_py import MjSim, MjRenderContextOffscreen
3 | from spirl.data.block_stacking.src.robosuite.utils import MujocoPyRenderer
4 |
5 |
6 | class FastResetMujocoEnv(MujocoEnv):
7 | """Only loads the mujoco XML file once to allow for quicker resets."""
8 | def _reset_internal(self):
9 | """Resets simulation internal configurations."""
10 | # instantiate simulation from MJCF model
11 | self._load_model()
12 | if not hasattr(self, "mjpy_model"):
13 | self.mjpy_model = self.model.get_model(mode="mujoco_py")
14 | self.sim = MjSim(self.mjpy_model)
15 | self.initialize_time(self.control_freq)
16 |
17 | # create visualization screen or renderer
18 | if self.has_renderer and self.viewer is None:
19 | self.viewer = MujocoPyRenderer(self.sim)
20 | self.viewer.viewer.vopt.geomgroup[0] = (
21 | 1 if self.render_collision_mesh else 0
22 | )
23 | self.viewer.viewer.vopt.geomgroup[1] = 1 if self.render_visual_mesh else 0
24 |
25 | # hiding the overlay speeds up rendering significantly
26 | self.viewer.viewer._hide_overlay = True
27 |
28 | elif self.has_offscreen_renderer:
29 | if self.sim._render_context_offscreen is None:
30 | render_context = MjRenderContextOffscreen(self.sim)
31 | self.sim.add_render_context(render_context)
32 | self.sim._render_context_offscreen.vopt.geomgroup[0] = (
33 | 1 if self.render_collision_mesh else 0
34 | )
35 | self.sim._render_context_offscreen.vopt.geomgroup[1] = (
36 | 1 if self.render_visual_mesh else 0
37 | )
38 |
39 | # additional housekeeping
40 | self.sim_state_initial = self.sim.get_state()
41 | self._get_reference()
42 | self.cur_time = 0
43 | self.timestep = 0
44 | self.done = False
45 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/pr2_gripper.py:
--------------------------------------------------------------------------------
1 | """
2 | 4 dof gripper with two fingers and its open/close variant
3 | """
4 | import numpy as np
5 |
6 | from spirl.data.block_stacking.src.robosuite.models.grippers import Gripper
7 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
8 |
9 |
10 | class PR2GripperBase(Gripper):
11 | """
12 | A 4 dof gripper with two fingers.
13 | """
14 |
15 | def __init__(self):
16 | super().__init__(xml_path_completion("grippers/pr2_gripper.xml"))
17 |
18 | def format_action(self, action):
19 | return action
20 |
21 | @property
22 | def init_qpos(self):
23 | return np.zeros(4)
24 |
25 | @property
26 | def joints(self):
27 | return [
28 | "r_gripper_r_finger_joint",
29 | "r_gripper_l_finger_joint",
30 | "r_gripper_r_finger_tip_joint",
31 | "r_gripper_l_finger_tip_joint",
32 | ]
33 |
34 | @property
35 | def dof(self):
36 | return 4
37 |
38 | def contact_geoms(self):
39 | return [
40 | "r_gripper_l_finger",
41 | "r_gripper_l_finger_tip",
42 | "r_gripper_r_finger",
43 | "r_gripper_r_finger_tip",
44 | ]
45 |
46 | @property
47 | def visualization_sites(self):
48 | return ["grip_site", "grip_site_cylinder"]
49 |
50 | @property
51 | def left_finger_geoms(self):
52 | return ["r_gripper_l_finger", "r_gripper_l_finger_tip"]
53 |
54 | @property
55 | def right_finger_geoms(self):
56 | return ["r_gripper_r_finger", "r_gripper_r_finger_tip"]
57 |
58 |
59 | class PR2Gripper(PR2GripperBase):
60 | """
61 | Open/close variant of PR2 gripper.
62 | """
63 |
64 | def format_action(self, action):
65 | """
66 | Args:
67 | action: 1 => open, -1 => closed
68 | """
69 | assert len(action) == 1
70 | return np.ones(4) * action
71 |
72 | @property
73 | def dof(self):
74 | return 1
75 |
--------------------------------------------------------------------------------
/spirl/rl/utils/rollout_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import h5py
4 | import numpy as np
5 |
6 |
7 | class RolloutSaver(object):
8 | """Saves rollout episodes to a target directory."""
9 | def __init__(self, save_dir):
10 | if not os.path.exists(save_dir):
11 | os.makedirs(save_dir)
12 | self.save_dir = save_dir
13 | self.counter = 0
14 |
15 | def save_rollout_to_file(self, episode):
16 | """Saves an episode to the next file index of the target folder."""
17 | # get save path
18 | save_path = os.path.join(self.save_dir, "rollout_{}.h5".format(self.counter))
19 |
20 | # save rollout to file
21 | f = h5py.File(save_path, "w")
22 | f.create_dataset("traj_per_file", data=1)
23 |
24 | # store trajectory info in traj0 group
25 | traj_data = f.create_group("traj0")
26 | traj_data.create_dataset("states", data=np.array(episode.observation))
27 | traj_data.create_dataset("images", data=np.array(episode.image, dtype=np.uint8))
28 | traj_data.create_dataset("actions", data=np.array(episode.action))
29 |
30 | terminals = np.array(episode.done)
31 | if np.sum(terminals) == 0:
32 | terminals[-1] = True
33 |
34 |         # build pad mask that indicates how long the sequence is
35 | is_terminal_idxs = np.nonzero(terminals)[0]
36 | pad_mask = np.zeros((len(terminals),))
37 | pad_mask[:is_terminal_idxs[0]] = 1.
38 | traj_data.create_dataset("pad_mask", data=pad_mask)
39 |
40 | f.close()
41 |
42 | self.counter += 1
43 |
44 | def _resize_video(self, images, dim=64):
45 | """Resize a video in numpy array form to target dimension."""
46 | ret = np.zeros((images.shape[0], dim, dim, 3))
47 |
48 | for i in range(images.shape[0]):
49 | ret[i] = cv2.resize(images[i], dsize=(dim, dim),
50 | interpolation=cv2.INTER_CUBIC)
51 |
52 | return ret.astype(np.uint8)
53 |
54 | def reset(self):
55 | """Resets counter."""
56 | self.counter = 0
57 |
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/README.md:
--------------------------------------------------------------------------------
1 | # Image-based SPiRL w/ Closed-Loop Skill Decoder
2 |
3 | This version of the SPiRL model uses a [closed-loop action decoder](../../../../models/closed_loop_spirl_mdl.py#L55):
4 | in contrast to the original SPiRL model it takes the current environment observation as input in every skill decoding step.
5 |
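To make the difference concrete, here is a minimal sketch of a closed-loop decoding step (illustrative only; the class and tensor names are assumptions, the actual implementation lives in `closed_loop_spirl_mdl.py`):

```python
import torch
import torch.nn as nn


class ClosedLoopActionDecoder(nn.Module):
    """Sketch: decodes one action per step from the skill latent AND the current observation."""

    def __init__(self, obs_dim, latent_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, obs, z):
        # an open-loop decoder maps the skill latent z (plus a step index) to actions;
        # the closed-loop variant re-reads the environment observation at every decoding step
        return self.net(torch.cat([obs, z], dim=-1))
```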
6 | This image-based model is a direct extension of the
7 | [state-based SPiRL model with closed-loop skill decoder](../../kitchen/hierarchical_cl/README.md).
8 | Similar to the state-based model we find that the image-based closed-loop model improves performance over the original
9 | image-based SPiRL model, particularly in tasks that require precise control.
10 | We evaluate it on a more challenging, sparse reward version of the block stacking environment
11 | where the agent is rewarded for the height of the tower it built, but does not receive any rewards for picking or lifting
12 | blocks. We find that in this challenging environment, the closed-loop skill decoder ("SPiRLv2") outperforms the original
13 | SPiRL model with open-loop skill decoder ("SPiRLv1").
14 |
20 | We also tried the closed-loop model on the image-based maze navigation task, but did not find it to improve performance,
21 | which we attribute to the easier control task that does not require closed-loop control.
22 |
23 | ## Example Commands
24 |
25 | To train the image-based SPiRL model with closed-loop action decoder on the block stacking environment, run the following command:
26 | ```
27 | python3 spirl/train.py --path=spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl --val_data_size=160
28 | ```
29 |
30 | To train a downstream task policy with RL using the closed-loop image-based SPiRL model
31 | on the sparse reward block stacking environment, run the following command:
32 | ```
33 | python3 spirl/rl/train.py --path=spirl/configs/hrl/block_stacking/spirl_cl --seed=0 --prefix=SPIRLv2_block_stacking_seed0
34 | ```
35 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/devices/README.md:
--------------------------------------------------------------------------------
1 | # Devices
2 |
3 | Devices are used to read user input and collect human demonstrations. Demonstrations can be collected using either a keyboard or a [SpaceNavigator 3D Mouse](https://www.3dconnexion.com/spacemouse_compact/en/) with the [collect_human_demonstrations](../scripts/collect_human_demonstrations.py) script. More generally, we support any interface that implements the [Device](device.py) abstract base class. To support your own custom device, simply subclass this base class and implement the required methods, as in the sketch below.
4 |
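A minimal sketch of such a subclass (the method names and the fields of the returned state are assumptions modeled on the keyboard and SpaceNavigator devices; see [Device](device.py) for the authoritative interface):

```python
from spirl.data.block_stacking.src.robosuite.devices import Device


class DummyDevice(Device):
    """Illustrative device that always commands zero motion."""

    def start_control(self):
        # (re-)initialize the device before reading user input
        pass

    def get_controller_state(self):
        # return the current user command; field names here mirror the built-in devices
        return dict(dpos=[0.0, 0.0, 0.0], rotation=None, grasp=0, reset=False)
```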
5 | ## Keyboard
6 |
7 | We support keyboard input through the GLFW window created by the mujoco-py renderer.
8 |
9 | **Keyboard controls**
10 |
11 | Note that the rendering window must be active for these commands to work.
12 |
13 | | Keys | Command |
14 | | :------: | :--------------------------------: |
15 | | q | reset simulation |
16 | | spacebar | toggle gripper (open/close) |
17 | | w-a-s-d | move arm horizontally in x-y plane |
18 | | r-f | move arm vertically |
19 | | z-x | rotate arm about x-axis |
20 | | t-g | rotate arm about y-axis |
21 | | c-v | rotate arm about z-axis |
22 | | ESC | quit |
23 |
24 | ## SpaceNavigator 3D Mouse
25 |
26 | We support the use of a [SpaceNavigator 3D Mouse](https://www.3dconnexion.com/spacemouse_compact/en/) as well.
27 |
28 | **SpaceNavigator 3D Mouse controls**
29 |
30 | | Control | Command |
31 | | :-----------------------: | :-----------------------------------: |
32 | | Right button | reset simulation |
33 | | Left button (hold) | close gripper |
34 | | Move mouse laterally | move arm horizontally in x-y plane |
35 | | Move mouse vertically | move arm vertically |
36 | | Twist mouse about an axis | rotate arm about a corresponding axis |
37 | | ESC (keyboard) | quit |
38 |
39 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/robots/robot.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML
4 | from spirl.data.block_stacking.src.robosuite.utils import XMLError
5 |
6 |
7 | class Robot(MujocoXML):
8 | """Base class for all robot models."""
9 |
10 | def __init__(self, fname):
11 | """Initializes from file @fname."""
12 | super().__init__(fname)
13 | # key: gripper name and value: gripper model
14 | self.grippers = OrderedDict()
15 |
16 | def add_gripper(self, arm_name, gripper):
17 | """
18 | Mounts gripper to arm.
19 |
20 | Throws error if robot already has a gripper or gripper type is incorrect.
21 |
22 | Args:
23 | arm_name (str): name of arm mount
24 | gripper (MujocoGripper instance): gripper MJCF model
25 | """
26 | if arm_name in self.grippers:
27 | raise ValueError("Attempts to add multiple grippers to one body")
28 |
29 | arm_subtree = self.worldbody.find(".//body[@name='{}']".format(arm_name))
30 |
31 | for actuator in gripper.actuator:
32 |
33 | if actuator.get("name") is None:
34 | raise XMLError("Actuator has no name")
35 |
36 | if not actuator.get("name").startswith("gripper"):
37 | raise XMLError(
38 | "Actuator name {} does not have prefix 'gripper'".format(
39 | actuator.get("name")
40 | )
41 | )
42 |
43 | for body in gripper.worldbody:
44 | arm_subtree.append(body)
45 |
46 | self.merge(gripper, merge_body=False)
47 | self.grippers[arm_name] = gripper
48 |
49 | @property
50 | def dof(self):
51 | """Returns the number of DOF of the robot, not including gripper."""
52 | raise NotImplementedError
53 |
54 | @property
55 | def joints(self):
56 | """Returns a list of joint names of the robot."""
57 | raise NotImplementedError
58 |
59 | @property
60 | def init_qpos(self):
61 | """Returns default qpos."""
62 | raise NotImplementedError
63 |
--------------------------------------------------------------------------------
/spirl/data/README.md:
--------------------------------------------------------------------------------
1 | ## Downloading Datasets via the Command Line
2 |
3 | To download the dataset files from Google Drive via the command line, you can use the
4 | [gdown](https://github.com/wkentaro/gdown) package. Install it with:
5 | ```
6 | pip install gdown
7 | ```
8 |
9 | Then navigate to the folder you want to download the data to and run the following commands:
10 | ```
11 | # Download Maze Dataset
12 | gdown https://drive.google.com/uc?id=1pXM-EDCwFrfgUjxITBsR48FqW9gMoXYZ
13 |
14 | # Download Block Stacking Dataset
15 | gdown https://drive.google.com/uc?id=1VobNYJQw_Uwax0kbFG7KOXTgv6ja2s1M
16 | ```
17 |
18 | ## Re-Generating Datasets
19 |
20 | ### Maze Dataset
21 | To regenerate the maze dataset, our fork of the [D4RL repo](https://github.com/kpertsch/d4rl) needs to be cloned and installed.
22 | It includes the script used to generate the maze dataset. Specifically, new data can be created by running:
23 | ```
24 | cd d4rl
25 | python3 scripts/generate_randMaze2d_datasets.py --render --agent_centric --save_images --data_dir=path_to_your_dir
26 | ```
27 | Optionally, the `--batch_idx` argument automatically creates a subfolder in `data_dir` named after the batch index,
28 | so that multiple data generation scripts with different batch indices can be run in parallel
29 | for accelerated data generation, as shown below.
30 |
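For example, two batches could be generated in parallel like this (illustrative; batch indices chosen arbitrarily, add the remaining arguments from the command above as needed):
```
python3 scripts/generate_randMaze2d_datasets.py --batch_idx=0 --data_dir=path_to_your_dir &
python3 scripts/generate_randMaze2d_datasets.py --batch_idx=1 --data_dir=path_to_your_dir &
```
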
31 | The number of generated trajectories can be controlled with the argument `--num_samples`; the size
32 | of the randomly generated training mazes can be changed with `--rand_maze_size`. For a full list of all arguments, see
33 | [```scripts/generate_randMaze2d_datasets.py```](https://github.com/kpertsch/d4rl/blob/master/scripts/generate_randMaze2d_datasets.py#L72).
34 |
35 |
36 | ### Block Stacking Dataset
37 | To regenerate the block stacking dataset we can use the config provided in [```spirl/configs/data_collect/block_stacking```](spirl/configs/data_collect/block_stacking/conf.py).
38 | To start generation, run:
39 | ```
40 | python3 spirl/rl/train.py --path=spirl/configs/data_collect/block_stacking --mode=rollout --n_val_samples=200000 --seed=42 --data_dir=path_to_your_dir
41 | ```
42 | If you want to run multiple data generation jobs in parallel, make sure to change the seed and set a different target
43 | data directory.
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/robotiq_three_finger_gripper.py:
--------------------------------------------------------------------------------
1 | """
2 | Gripper with 11-DoF controlling three fingers and its open/close variant.
3 | """
4 | import numpy as np
5 |
6 | from spirl.data.block_stacking.src.robosuite.models.grippers import Gripper
7 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
8 |
9 |
10 | class RobotiqThreeFingerGripperBase(Gripper):
11 | """
12 | Gripper with 11 dof controlling three fingers.
13 | """
14 |
15 | def __init__(self):
16 | super().__init__(xml_path_completion("grippers/robotiq_gripper_s.xml"))
17 |
18 | def format_action(self, action):
19 | return action
20 |
21 | @property
22 | def init_qpos(self):
23 | return np.zeros(11)
24 |
25 | @property
26 | def joints(self):
27 | return [
28 | "palm_finger_1_joint",
29 | "finger_1_joint_1",
30 | "finger_1_joint_2",
31 | "finger_1_joint_3",
32 | "palm_finger_2_joint",
33 | "finger_2_joint_1",
34 | "finger_2_joint_2",
35 | "finger_2_joint_3",
36 | "finger_middle_joint_1",
37 | "finger_middle_joint_2",
38 | "finger_middle_joint_3",
39 | ]
40 |
41 | @property
42 | def dof(self):
43 | return 11
44 |
45 | def contact_geoms(self):
46 | return [
47 | "f1_l0",
48 | "f1_l1",
49 | "f1_l2",
50 | "f1_l3",
51 | "f2_l0",
52 | "f2_l1",
53 | "f2_l2",
54 | "f2_l3",
55 | "f3_l0",
56 | "f3_l1",
57 | "f3_l2",
58 | "f3_l3",
59 | ]
60 |
61 | @property
62 | def visualization_sites(self):
63 | return ["grip_site", "grip_site_cylinder"]
64 |
65 |
66 | class RobotiqThreeFingerGripper(RobotiqThreeFingerGripperBase):
67 | """
68 | 1-DoF variant of RobotiqThreeFingerGripperBase.
69 | """
70 |
71 | def format_action(self, action):
72 | """
73 | Args:
74 | action: 1 => open, -1 => closed
75 | """
76 | movement = np.array([0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1])
77 | return -1 * movement * action
78 |
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/base_conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.utils.general_utils import AttrDict
4 | from spirl.rl.agents.ac_agent import SACAgent
5 | from spirl.rl.policies.mlp_policies import MLPPolicy
6 | from spirl.rl.components.critic import MLPCritic
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.envs.kitchen import KitchenEnv
9 | from spirl.rl.components.normalization import Normalizer
10 | from spirl.configs.default_data_configs.kitchen import data_spec
11 |
12 |
13 | current_dir = os.path.dirname(os.path.realpath(__file__))
14 |
15 | notes = 'non-hierarchical RL experiments in kitchen env'
16 |
17 | configuration = {
18 | 'seed': 42,
19 | 'agent': SACAgent,
20 | 'environment': KitchenEnv,
21 | 'data_dir': '.',
22 | 'num_epochs': 100,
23 | 'max_rollout_len': 280,
24 | 'n_steps_per_epoch': 50000,
25 | 'n_warmup_steps': 500,
26 | }
27 | configuration = AttrDict(configuration)
28 |
29 | # Policy
30 | policy_params = AttrDict(
31 | action_dim=data_spec.n_actions,
32 | input_dim=data_spec.state_dim,
33 | n_layers=5, # number of policy network layers
34 | nz_mid=256,
35 | max_action_range=1.,
36 | )
37 |
38 | # Critic
39 | critic_params = AttrDict(
40 | action_dim=policy_params.action_dim,
41 | input_dim=policy_params.input_dim,
42 | output_dim=1,
43 |     n_layers=2,      # number of critic network layers
44 | nz_mid=256,
45 | action_input=True,
46 | )
47 |
48 | # Replay Buffer
49 | replay_params = AttrDict(
50 | capacity=1e5,
51 | dump_replay=False,
52 | )
53 |
54 | # Observation Normalization
55 | obs_norm_params = AttrDict(
56 | )
57 |
58 | # Agent
59 | agent_config = AttrDict(
60 | policy=MLPPolicy,
61 | policy_params=policy_params,
62 | critic=MLPCritic,
63 | critic_params=critic_params,
64 | replay=UniformReplayBuffer,
65 | replay_params=replay_params,
66 | # obs_normalizer=Normalizer,
67 | # obs_normalizer_params=obs_norm_params,
68 | clip_q_target=False,
69 | batch_size=256,
70 | log_video_caption=True,
71 | )
72 |
73 | # Dataset - Random data
74 | data_config = AttrDict()
75 | data_config.dataset_spec = data_spec
76 |
77 | # Environment
78 | env_config = AttrDict(
79 | reward_norm=1.,
80 | )
81 |
82 |
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/base_conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.utils.general_utils import AttrDict
4 | from spirl.rl.policies.mlp_policies import MLPPolicy
5 | from spirl.rl.components.critic import MLPCritic
6 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
7 | from spirl.rl.envs.maze import ACRandMaze0S40Env
8 | from spirl.rl.components.normalization import Normalizer
9 | from spirl.configs.default_data_configs.maze import data_spec
10 | from spirl.data.maze.src.maze_agents import MazeSACAgent
11 |
12 |
13 | current_dir = os.path.dirname(os.path.realpath(__file__))
14 |
15 | notes = 'non-hierarchical RL experiments in maze env'
16 |
17 | configuration = {
18 | 'seed': 42,
19 | 'agent': MazeSACAgent,
20 | 'environment': ACRandMaze0S40Env,
21 | 'data_dir': '.',
22 | 'num_epochs': 100,
23 | 'max_rollout_len': 2000,
24 | 'n_steps_per_epoch': 100000,
25 | 'n_warmup_steps': 5e3,
26 | }
27 | configuration = AttrDict(configuration)
28 |
29 | # Policy
30 | policy_params = AttrDict(
31 | action_dim=data_spec.n_actions,
32 | input_dim=data_spec.state_dim,
33 | n_layers=2, # number of policy network layers
34 | nz_mid=256,
35 | max_action_range=1.,
36 | )
37 |
38 | # Critic
39 | critic_params = AttrDict(
40 | action_dim=policy_params.action_dim,
41 | input_dim=policy_params.input_dim,
42 | output_dim=1,
43 |     n_layers=2,      # number of critic network layers
44 | nz_mid=256,
45 | action_input=True,
46 | )
47 |
48 | # Replay Buffer
49 | replay_params = AttrDict(
50 | capacity=1e6,
51 | dump_replay=False,
52 | )
53 |
54 | # Observation Normalization
55 | obs_norm_params = AttrDict(
56 | )
57 |
58 | # Agent
59 | agent_config = AttrDict(
60 | policy=MLPPolicy,
61 | policy_params=policy_params,
62 | critic=MLPCritic,
63 | critic_params=critic_params,
64 | replay=UniformReplayBuffer,
65 | replay_params=replay_params,
66 | # obs_normalizer=Normalizer,
67 | # obs_normalizer_params=obs_norm_params,
68 | clip_q_target=False,
69 | batch_size=256,
70 | log_videos=False,
71 | )
72 |
73 | # Dataset - Random data
74 | data_config = AttrDict()
75 | data_config.dataset_spec = data_spec
76 |
77 | # Environment
78 | env_config = AttrDict(
79 | reward_norm=1.,
80 | )
81 |
82 |
--------------------------------------------------------------------------------
/spirl/configs/rl/peg_in_hole/base_conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.utils.general_utils import AttrDict
4 | from spirl.rl.agents.ac_agent import SACAgent
5 | from spirl.rl.policies.mlp_policies import MLPPolicy
6 | from spirl.rl.components.critic import MLPCritic
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.envs.peg_in_hole import PegInHoleEnv
9 | from spirl.rl.components.normalization import Normalizer
10 | from spirl.configs.default_data_configs.peg_in_hole import data_spec
11 |
12 |
13 | current_dir = os.path.dirname(os.path.realpath(__file__))
14 |
15 | notes = 'non-hierarchical RL experiments in peg-in-hole env'
16 |
17 | configuration = {
18 | 'seed': 42,
19 | 'agent': SACAgent,
20 | 'environment': PegInHoleEnv,
21 | 'data_dir': '.',
22 | 'num_epochs': 100,
23 | 'max_rollout_len': 280,
24 | 'n_steps_per_epoch': 50000,
25 | 'n_warmup_steps': 500,
26 | }
27 | configuration = AttrDict(configuration)
28 |
29 | # Policy
30 | policy_params = AttrDict(
31 | action_dim=data_spec.n_actions,
32 | input_dim=data_spec.state_dim,
33 | n_layers=5, # number of policy network layers
34 | nz_mid=256,
35 | max_action_range=1.,
36 | )
37 |
38 | # Critic
39 | critic_params = AttrDict(
40 | action_dim=policy_params.action_dim,
41 | input_dim=policy_params.input_dim,
42 | output_dim=1,
43 |     n_layers=2,      # number of critic network layers
44 | nz_mid=256,
45 | action_input=True,
46 | )
47 |
48 | # Replay Buffer
49 | replay_params = AttrDict(
50 | capacity=1e5,
51 | dump_replay=False,
52 | )
53 |
54 | # Observation Normalization
55 | obs_norm_params = AttrDict(
56 | )
57 |
58 | # Agent
59 | agent_config = AttrDict(
60 | policy=MLPPolicy,
61 | policy_params=policy_params,
62 | critic=MLPCritic,
63 | critic_params=critic_params,
64 | replay=UniformReplayBuffer,
65 | replay_params=replay_params,
66 | # obs_normalizer=Normalizer,
67 | # obs_normalizer_params=obs_norm_params,
68 | clip_q_target=False,
69 | batch_size=256,
70 | log_video_caption=True,
71 | )
72 |
73 | # Dataset - Random data
74 | data_config = AttrDict()
75 | data_config.dataset_spec = data_spec
76 |
77 | # Environment
78 | env_config = AttrDict(
79 | reward_norm=1.,
80 | )
81 |
82 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/block_stacking_logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from spirl.components.logger import Logger
4 | from spirl.models.skill_prior_mdl import SkillSpaceLogger
5 | from spirl.data.block_stacking.src.block_stacking_env import BlockStackEnv
6 | from spirl.utils.general_utils import AttrDict
7 | from spirl.utils.vis_utils import add_caption_to_img
8 | from spirl.data.block_stacking.src.block_task_generator import FixedSizeSingleTowerBlockTaskGenerator
9 |
10 |
11 | class BlockStackLogger(Logger):
12 | # logger to save visualizations of input and output trajectories in block stacking environment
13 |
14 | @staticmethod
15 | def _init_env_from_id(id):
16 | # TODO: return different environment variants depending on id
17 | task_params = AttrDict(
18 | max_tower_height=4
19 | )
20 |
21 | env_config = AttrDict(
22 | task_generator=FixedSizeSingleTowerBlockTaskGenerator,
23 | task_params=task_params,
24 | dimension=2,
25 | screen_width=128,
26 | screen_height=128
27 | )
28 |
29 | return BlockStackEnv(env_config)
30 |
31 | @staticmethod
32 | def _render_state(env, model_xml, obs, name=""):
33 | env.reset()
34 |
35 | unwrapped_obs = env._unflatten_block_obs(obs)
36 |
37 | sim_state = env.sim.get_state()
38 |
39 | sim_state.qpos[:len(sim_state.qpos)] = env.obs2qpos(obs)
40 | env.sim.set_state(sim_state)
41 | env.sim.forward()
42 | img = env.render()
43 |
44 | # round function
45 | rd = lambda x: np.round(x, 2)
46 |
47 | # add caption to the image
48 | info = {
49 | "Robot Pos": rd(unwrapped_obs["gripper_pos"]),
50 | "Robot Ang": rd(unwrapped_obs["gripper_angle"]),
51 | "Gripper Finger Pos": rd(unwrapped_obs["gripper_finger_pos"]),
52 | }
53 |
54 | for i in range(unwrapped_obs["block_pos"].shape[0]):
55 | info.update({
56 | "Block {}:".format(i): rd(unwrapped_obs["block_pos"][i])
57 | })
58 |
59 | img = add_caption_to_img(img, info, name, flip_rgb=True)
60 |
61 | return img
62 |
63 |
64 | class SkillSpaceBlockStackLogger(BlockStackLogger, SkillSpaceLogger):
65 | pass
66 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/wrappers/gym_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | This file implements a wrapper for facilitating compatibility with OpenAI gym.
3 | This is useful when using these environments with code that assumes a gym-like
4 | interface.
5 | """
6 |
7 | import numpy as np
8 | from gym import spaces
9 | from spirl.data.block_stacking.src.robosuite.wrappers import Wrapper
10 |
11 |
12 | class GymWrapper(Wrapper):
13 | env = None
14 |
15 | def __init__(self, env, keys=None):
16 | """
17 | Initializes the Gym wrapper.
18 |
19 | Args:
20 | env (MujocoEnv instance): The environment to wrap.
21 | keys (list of strings): If provided, each observation will
22 | consist of concatenated keys from the wrapped environment's
23 | observation dictionary. Defaults to robot-state and object-state.
24 | """
25 | self.env = env
26 |
27 | if keys is None:
28 | assert self.env.use_object_obs, "Object observations need to be enabled."
29 | keys = ["robot-state", "object-state"]
30 | self.keys = keys
31 |
32 | # set up observation and action spaces
33 | flat_ob = self._flatten_obs(self.env.reset(), verbose=True)
34 | self.obs_dim = flat_ob.size
35 | high = np.inf * np.ones(self.obs_dim)
36 | low = -high
37 | self.observation_space = spaces.Box(low=low, high=high)
38 | low, high = self.env.action_spec
39 | self.action_space = spaces.Box(low=low, high=high)
40 |
41 | def _flatten_obs(self, obs_dict, verbose=False):
42 | """
43 |         Filters out the keys of interest and concatenates their values.
44 |
45 | Args:
46 | obs_dict: ordered dictionary of observations
47 | """
48 | ob_lst = []
49 | for key in obs_dict:
50 | if key in self.keys:
51 | if verbose:
52 | print("adding key: {}".format(key))
53 | ob_lst.append(obs_dict[key])
54 | return np.concatenate(ob_lst)
55 |
56 | def reset(self):
57 | ob_dict = self.env.reset()
58 | return self._flatten_obs(ob_dict)
59 |
60 | def step(self, action):
61 | ob_dict, reward, done, info = self.env.step(action)
62 | return self._flatten_obs(ob_dict), reward, done, info
63 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/wrappers/wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the base wrapper class for Mujoco environments.
3 | Wrappers are useful for data collection and logging. Highly recommended.
4 | """
5 |
6 |
7 | class Wrapper:
8 | env = None
9 |
10 | def __init__(self, env):
11 | self.env = env
12 |
13 | @classmethod
14 | def class_name(cls):
15 | return cls.__name__
16 |
17 | def _warn_double_wrap(self):
18 | env = self.env
19 | while True:
20 | if isinstance(env, Wrapper):
21 | if env.class_name() == self.class_name():
22 | raise Exception(
23 | "Attempted to double wrap with Wrapper: {}".format(
24 | self.__class__.__name__
25 | )
26 | )
27 | env = env.env
28 | else:
29 | break
30 |
31 | def step(self, action):
32 | return self.env.step(action)
33 |
34 | def reset(self):
35 | return self.env.reset()
36 |
37 | def render(self, **kwargs):
38 | return self.env.render(**kwargs)
39 |
40 | def observation_spec(self):
41 | return self.env.observation_spec()
42 |
43 | def action_spec(self):
44 | return self.env.action_spec()
45 |
46 | @property
47 | def dof(self):
48 | return self.env.dof
49 |
50 | @property
51 | def unwrapped(self):
52 | if hasattr(self.env, "unwrapped"):
53 | return self.env.unwrapped
54 | else:
55 | return self.env
56 |
57 | # this method is a fallback option on any methods the original env might support
58 | def __getattr__(self, attr):
59 | # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called
60 | # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute)
61 | orig_attr = getattr(self.env, attr)
62 | if callable(orig_attr):
63 |
64 | def hooked(*args, **kwargs):
65 | result = orig_attr(*args, **kwargs)
66 | # prevent wrapped_class from becoming unwrapped
67 |                 if result is self.env:
68 | return self
69 | return result
70 |
71 | return hooked
72 | else:
73 | return orig_attr
74 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/table_arena.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena
3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array
5 |
6 |
7 | class TableArena(Arena):
8 | """Workspace that contains an empty table."""
9 |
10 | def __init__(
11 | self, table_full_size=(0.8, 0.8, 0.8), table_friction=(1, 0.005, 0.0001)
12 | ):
13 | """
14 | Args:
15 | table_full_size: full dimensions of the table
16 |             table_friction: friction parameters of the table
17 | """
18 | super().__init__(xml_path_completion("arenas/table_arena.xml"))
19 |
20 | self.table_full_size = np.array(table_full_size)
21 | self.table_half_size = self.table_full_size / 2
22 | self.table_friction = table_friction
23 |
24 | self.floor = self.worldbody.find("./geom[@name='floor']")
25 | self.table_body = self.worldbody.find("./body[@name='table']")
26 | self.table_collision = self.table_body.find("./geom[@name='table_collision']")
27 | self.table_visual = self.table_body.find("./geom[@name='table_visual']")
28 | self.table_top = self.table_body.find("./site[@name='table_top']")
29 |
30 | self.configure_location()
31 |
32 | def configure_location(self):
33 | self.bottom_pos = np.array([0, 0, 0])
34 | self.floor.set("pos", array_to_string(self.bottom_pos))
35 |
36 | self.center_pos = self.bottom_pos + np.array([0, 0, self.table_half_size[2]])
37 | self.table_body.set("pos", array_to_string(self.center_pos))
38 | self.table_collision.set("size", array_to_string(self.table_half_size))
39 | self.table_collision.set("friction", array_to_string(self.table_friction))
40 | self.table_visual.set("size", array_to_string(self.table_half_size))
41 |
42 | self.table_top.set(
43 | "pos", array_to_string(np.array([0, 0, self.table_half_size[2]]))
44 | )
45 |
46 | @property
47 | def table_top_abs(self):
48 | """Returns the absolute position of table top"""
49 | table_height = np.array([0, 0, self.table_full_size[2]])
50 | return string_to_array(self.floor.get("pos")) + table_height
51 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/arenas/pegs_arena.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/demo_learning_curriculum.py:
--------------------------------------------------------------------------------
1 | """Demo of learning curriculum utilities.
2 |
3 | Several prior works have demonstrated the effectiveness of altering the
4 | start state distribution of training episodes for learning RL policies.
5 | We provide a generic utility for setting various types of learning
6 | curricula, which dictate how to sample from demonstration episodes
7 | when doing an environment reset. For more information, see the
8 | `DemoSamplerWrapper` class.
9 |
10 | Related work:
11 |
12 | [1] Reinforcement and Imitation Learning for Diverse Visuomotor Skills
13 | Yuke Zhu, Ziyu Wang, Josh Merel, Andrei Rusu, Tom Erez, Serkan Cabi, Saran Tunyasuvunakool,
14 | János Kramár, Raia Hadsell, Nando de Freitas, Nicolas Heess
15 | RSS 2018
16 |
17 | [2] Backplay: "Man muss immer umkehren"
18 | Cinjon Resnick, Roberta Raileanu, Sanyam Kapoor, Alex Peysakhovich, Kyunghyun Cho, Joan Bruna
19 | arXiv:1807.06919
20 |
21 | [3] DeepMimic: Example-Guided Deep Reinforcement Learning of Physics-Based Character Skills
22 | Xue Bin Peng, Pieter Abbeel, Sergey Levine, Michiel van de Panne
23 | Transactions on Graphics 2018
24 |
25 | [4] Approximately optimal approximate reinforcement learning
26 | Sham Kakade and John Langford
27 | ICML 2002
28 | """
29 |
30 | import os
31 |
32 | from spirl.data.block_stacking.src import robosuite
33 | from spirl.data.block_stacking.src.robosuite import make
34 | from spirl.data.block_stacking.src.robosuite.wrappers import DemoSamplerWrapper
35 |
36 |
37 | if __name__ == "__main__":
38 |
39 | env = make(
40 | "SawyerPickPlace",
41 | has_renderer=True,
42 | has_offscreen_renderer=False,
43 | ignore_done=True,
44 | use_camera_obs=False,
45 | reward_shaping=True,
46 | gripper_visualization=True,
47 | )
48 |
49 | env = DemoSamplerWrapper(
50 | env,
51 | demo_path=os.path.join(
52 | robosuite.models.assets_root, "demonstrations/SawyerPickPlace"
53 | ),
54 | need_xml=True,
55 | num_traj=-1,
56 | sampling_schemes=["uniform", "random"],
57 | scheme_ratios=[0.9, 0.1],
58 | )
59 |
60 | for _ in range(100):
61 | env.reset()
62 | env.viewer.set_camera(0)
63 | env.render()
64 | for i in range(100):
65 | if i == 0:
66 | reward = env.reward()
67 | print("reward", reward)
68 | env.render()
69 |
--------------------------------------------------------------------------------
/spirl/rl/components/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--path", help="path to the config file directory")
7 |
8 | # Folder settings
9 | parser.add_argument("--prefix", help="experiment prefix, if given creates subfolder in experiment directory")
10 | parser.add_argument('--new_dir', default=False, type=int, help='If True, concat datetime string to exp_dir.')
11 | parser.add_argument('--dont_save', default=False, type=int,
12 | help="if True, nothing is saved to disk. Note: this doesn't work") # TODO this doesn't work
13 |
14 | # Running protocol
15 | parser.add_argument('--resume', default='latest', type=str, metavar='PATH',
16 |                         help='path to checkpoint to resume from (default: latest)')
17 | parser.add_argument('--mode', default='train', type=str,
18 | choices=['train', 'val', 'rollout'],
19 | help='mode of the program (training, validation, or generate rollout)')
20 |
21 | # Misc
22 | parser.add_argument('--seed', default=-1, type=int,
23 | help='overrides config/default seed for more convenient seed setting.')
24 | parser.add_argument('--gpu', default=0, type=int,
25 | help='will set CUDA_VISIBLE_DEVICES to selected value')
26 | parser.add_argument('--strict_weight_loading', default=True, type=int,
27 | help='if True, uses strict weight loading function')
28 | parser.add_argument('--deterministic', default=False, type=int,
29 | help='if True, sets fixed seeds for torch and numpy')
30 | parser.add_argument('--n_val_samples', default=10, type=int,
31 | help='number of validation episodes')
32 | parser.add_argument('--save_dir', type=str,
33 | help='directory for saving the generated rollouts in rollout mode')
34 | parser.add_argument('--config_override', default='', type=str,
35 | help='override to config file in format "key1.key2=val1,key3=val2"')
36 |
37 | # Debug
38 | parser.add_argument('--debug', default=False, type=int,
39 | help='if True, runs in debug mode')
40 |
41 | # Note
42 | parser.add_argument('--notes', default='', type=str,
43 | help='Notes for the run')
44 |
45 | return parser.parse_args()
46 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/round-nut.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/spirl/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def safe_entropy(dist, dim=None, eps=1e-12):
6 | """Computes entropy even if some entries are 0."""
7 | return -torch.sum(dist * safe_log_prob(dist, eps), dim=dim)
8 |
9 |
10 | def safe_log_prob(tensor, eps=1e-12):
11 | """Safe log of probability values (must be between 0 and 1)"""
12 | return torch.log(torch.clamp(tensor, eps, 1 - eps))
13 |
14 |
15 | def normalize(tensor, dim=1, eps=1e-7):
16 | norm = torch.clamp(tensor.sum(dim, keepdim=True), eps)
17 | return tensor / norm
18 |
19 |
20 | def gumbel_sample(shape, eps=1e-8):
21 | """Sample Gumbel noise."""
22 | uniform = torch.rand(shape).float()
23 | return -torch.log(eps - torch.log(uniform + eps))
24 |
25 |
26 | def gumbel_softmax_sample(logits, temp=1.):
27 | """Sample from the Gumbel softmax / concrete distribution."""
28 | gumbel_noise = gumbel_sample(logits.size()).to(logits.device)
29 | return F.softmax((logits + gumbel_noise) / temp, dim=1)
30 |
31 |
32 | def log_cumsum(probs, dim=1, eps=1e-8):
33 | """Calculate log of inclusive cumsum."""
34 | return torch.log(torch.cumsum(probs, dim=dim) + eps)
35 |
36 |
37 | def poisson_categorical_log_prior(length, rate, device):
38 | """Categorical prior populated with log probabilities of Poisson dist.
39 | From: https://github.com/tkipf/compile/blob/b88b17411c37e1ed95459a0a779d71d5acef9e3f/utils.py#L58"""
40 | rate = torch.tensor(rate, dtype=torch.float32, device=device)
41 | values = torch.arange(1, length + 1, dtype=torch.float32, device=device).unsqueeze(0)
42 | log_prob_unnormalized = torch.log(rate) * values - rate - (values + 1).lgamma()
43 | # TODO(tkipf): Length-sensitive normalization.
44 | return F.log_softmax(log_prob_unnormalized, dim=1) # Normalize.
45 |
46 |
47 | def kl_categorical(preds, log_prior, eps=1e-8):
48 | """KL divergence between two categorical distributions."""
49 | kl_div = preds * (torch.log(preds + eps) - log_prior)
50 | return kl_div.sum(1)
51 |
52 |
53 | class Dirac:
54 | """Dummy Dirac distribution."""
55 | def __init__(self, val):
56 | self._val = val
57 |
58 | def sample(self):
59 | return self._val
60 |
61 | def rsample(self):
62 | return self._val
63 |
64 | def log_prob(self, val):
65 | return torch.tensor(int(val == self._val), dtype=torch.float32, device=self._val.device)
66 |
67 | @property
68 | def logits(self):
69 | """This is more of a dummy return."""
70 | return self._val
71 |
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/base_conf.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from spirl.utils.general_utils import AttrDict
4 | from spirl.rl.agents.ac_agent import SACAgent
5 | from spirl.rl.policies.mlp_policies import MLPPolicy
6 | from spirl.rl.components.critic import MLPCritic
7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer
8 | from spirl.rl.envs.block_stacking import HighStack11StackEnvV0
9 | from spirl.rl.components.normalization import Normalizer
10 | from spirl.configs.default_data_configs.block_stacking import data_spec
11 |
12 |
13 | current_dir = os.path.dirname(os.path.realpath(__file__))
14 |
15 | notes = 'non-hierarchical RL experiments in block stacking env'
16 |
17 | configuration = {
18 | 'seed': 42,
19 | 'agent': SACAgent,
20 | 'environment': HighStack11StackEnvV0,
21 | 'data_dir': '.',
22 | 'num_epochs': 100,
23 | 'max_rollout_len': 1000,
24 | 'n_steps_per_epoch': 100000,
25 | 'n_warmup_steps': 5e3,
26 | }
27 | configuration = AttrDict(configuration)
28 |
29 | # Policy
30 | policy_params = AttrDict(
31 | action_dim=data_spec.n_actions,
32 | input_dim=data_spec.state_dim,
33 | n_layers=5, # number of policy network layers
34 | nz_mid=256,
35 | max_action_range=1.,
36 | )
37 |
38 | # Critic
39 | critic_params = AttrDict(
40 | action_dim=policy_params.action_dim,
41 | input_dim=policy_params.input_dim,
42 | output_dim=1,
43 |     n_layers=2,      # number of critic network layers
44 | nz_mid=256,
45 | action_input=True,
46 | )
47 |
48 | # Replay Buffer
49 | replay_params = AttrDict(
50 | capacity=1e5,
51 | dump_replay=False,
52 | )
53 |
54 | # Observation Normalization
55 | obs_norm_params = AttrDict(
56 | )
57 |
58 | # Agent
59 | agent_config = AttrDict(
60 | policy=MLPPolicy,
61 | policy_params=policy_params,
62 | critic=MLPCritic,
63 | critic_params=critic_params,
64 | replay=UniformReplayBuffer,
65 | replay_params=replay_params,
66 | # obs_normalizer=Normalizer,
67 | # obs_normalizer_params=obs_norm_params,
68 | clip_q_target=False,
69 | batch_size=256,
70 | log_video_caption=True,
71 | )
72 |
73 | # Dataset - Random data
74 | data_config = AttrDict()
75 | data_config.dataset_spec = data_spec
76 |
77 | # Environment
78 | env_config = AttrDict(
79 | name="block_stacking",
80 | reward_norm=1.,
81 | screen_width=data_spec.res,
82 | screen_height=data_spec.res,
83 | env_config=AttrDict(camera_name='agentview',
84 | screen_width=data_spec.res,
85 | screen_height=data_spec.res,)
86 | )
87 |
88 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/robotiq_gripper.py:
--------------------------------------------------------------------------------
1 | """
2 | 6-DoF gripper with its open/close variant
3 | """
4 | import numpy as np
5 |
6 | from spirl.data.block_stacking.src.robosuite.models.grippers import Gripper
7 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
8 |
9 |
10 | class RobotiqGripperBase(Gripper):
11 | """
12 | 6-DoF Robotiq gripper.
13 | """
14 |
15 | def __init__(self):
16 | super().__init__(xml_path_completion("grippers/robotiq_gripper.xml"))
17 |
18 | @property
19 | def init_qpos(self):
20 | return [3.3161, 0., 0., 0., 0., 0.]
21 |
22 | @property
23 | def joints(self):
24 | return [
25 | "robotiq_85_left_knuckle_joint",
26 | "robotiq_85_left_inner_knuckle_joint",
27 | "robotiq_85_left_finger_tip_joint",
28 | "robotiq_85_right_knuckle_joint",
29 | "robotiq_85_right_inner_knuckle_joint",
30 | "robotiq_85_right_finger_tip_joint",
31 | ]
32 |
33 | @property
34 | def dof(self):
35 | return 6
36 |
37 | def contact_geoms(self):
38 | return [
39 | "robotiq_85_gripper_joint_0_L",
40 | "robotiq_85_gripper_joint_1_L",
41 | "robotiq_85_gripper_joint_0_R",
42 | "robotiq_85_gripper_joint_1_R",
43 | "robotiq_85_gripper_joint_2_L",
44 | "robotiq_85_gripper_joint_3_L",
45 | "robotiq_85_gripper_joint_2_R",
46 | "robotiq_85_gripper_joint_3_R",
47 | ]
48 |
49 | @property
50 | def visualization_sites(self):
51 | return ["grip_site", "grip_site_cylinder"]
52 |
53 | @property
54 | def left_finger_geoms(self):
55 | return [
56 | "robotiq_85_gripper_joint_0_L",
57 | "robotiq_85_gripper_joint_1_L",
58 | "robotiq_85_gripper_joint_2_L",
59 | "robotiq_85_gripper_joint_3_L",
60 | ]
61 |
62 | @property
63 | def right_finger_geoms(self):
64 | return [
65 | "robotiq_85_gripper_joint_0_R",
66 | "robotiq_85_gripper_joint_1_R",
67 | "robotiq_85_gripper_joint_2_R",
68 | "robotiq_85_gripper_joint_3_R",
69 | ]
70 |
71 |
72 | class RobotiqGripper(RobotiqGripperBase):
73 | """
74 | 1-DoF variant of RobotiqGripperBase.
75 | """
76 |
77 | def format_action(self, action):
78 | """
79 | Args:
80 | action: 1 => open, -1 => closed
81 | """
82 | assert len(action) == 1
83 | return -1 * np.ones(6) * action
84 |
85 | @property
86 | def dof(self):
87 | return 1
88 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/wrappers/README.md:
--------------------------------------------------------------------------------
1 | Wrappers
2 | ========
3 |
4 | Wrappers offer additional features to an environment. Custom wrappers can be implemented for collecting data, recording videos, and modifying environments. We provide some example wrapper implementations in this folder. All wrappers should inherit from the base [Wrapper](wrapper.py) class.
5 |
6 | To use a wrapper for the environment, import it and wrap it around the environment.
7 |
8 | ```python
9 | import robosuite
10 | from robosuite.wrappers import CustomWrapper
11 |
12 | env = robosuite.make("SawyerLift")
13 | env = CustomWrapper(env)
14 | ```
15 |
16 | DataCollectionWrapper
17 | ---------------------
18 | [DataCollectionWrapper](data_collection_wrapper.py) saves trajectory information to disk. This [demo script](../scripts/demo_collect_and_playback_data.py) illustrates how to collect the rollout trajectories and replay them in the simulation.
19 |
20 | DemoSamplerWrapper
21 | ------------------
22 | [DemoSamplerWrapper](demo_sampler_wrapper.py) loads demonstrations as a dataset of trajectories and resets episodes to start states sampled along these demonstration trajectories according to a configurable schedule. This functionality is useful for training RL agents and has been adopted in several prior works (see [references](../scripts/demo_learning_curriculum.py)). We provide a [demo script](../scripts/demo_learning_curriculum.py) that shows how to configure the demo sampler to load demonstrations from files and use them to change the initial state distribution of episodes.
23 |
24 | GymWrapper
25 | ----------
26 | [GymWrapper](gym_wrapper.py) implements the standard methods of [OpenAI Gym](https://github.com/openai/gym), which allows popular RL libraries to run with our environments using the same APIs as Gym. This [demo script](../scripts/demo_gym_functionality.py) shows how to convert robosuite environments into Gym interfaces using this wrapper. The wrapper requires the `gym` package to be installed:
27 |
28 | ```bash
29 | pip install gym
30 | ```
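
Once installed, a robosuite environment can be converted and used like any other Gym environment (a minimal sketch; the environment name and keyword arguments are illustrative):

```python
import robosuite
from robosuite.wrappers import GymWrapper

env = GymWrapper(robosuite.make("SawyerLift", has_renderer=False, use_camera_obs=False))
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
```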
31 |
32 | IKWrapper
33 | ---------
34 | [IKWrapper](ik_wrapper.py) allows for using an end effector action space to control the robot in an environment instead of the default joint velocity action space. It uses our inverse kinematics robot controllers, located in the [controllers](../controllers) directory, which depend on PyBullet. In order to use this wrapper, you must run the following command to install PyBullet.
35 |
36 | ```bash
37 | pip install pybullet==1.9.5
38 | ```
39 |
40 | The main difference between the joint velocity action space and the end effector action space supported by this wrapper is that instead of supplying joint velocities per arm, a **delta position** vector and **delta quaternion** (xyzw) should be supplied per arm, where these correspond to the relative changes in position and rotation of the end effector from its current pose.
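
For example, a single-arm action might be assembled as follows (a sketch; the exact per-arm layout and the trailing gripper command are assumptions based on the description above):

```python
import numpy as np
import robosuite
from robosuite.wrappers import IKWrapper

env = IKWrapper(robosuite.make("SawyerLift", has_renderer=False, use_camera_obs=False))
dpos = np.array([0.01, 0.0, 0.0])       # move the end effector 1 cm along x
dquat = np.array([0.0, 0.0, 0.0, 1.0])  # identity rotation in xyzw convention
gripper = np.array([1.0])               # assumed 1-DoF gripper command
obs, reward, done, info = env.step(np.concatenate([dpos, dquat, gripper]))
```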
41 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/gripper.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the base class of all grippers
3 | """
4 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML
5 |
6 |
7 | class Gripper(MujocoXML):
8 | """Base class for grippers"""
9 |
10 | def __init__(self, fname):
11 | super().__init__(fname)
12 |
13 | def format_action(self, action):
14 | """
15 |         Given an abstract control command in (-1, 1) as an np array,
16 |         returns the (-1, 1) control signals for the
17 |         underlying actuators as a 1-d np array.
18 | """
19 | raise NotImplementedError
20 |
21 | @property
22 | def init_qpos(self):
23 | """
24 |         Returns the rest (open) qpos of the gripper.
25 | """
26 | raise NotImplementedError
27 |
28 | @property
29 | def dof(self):
30 | """
31 | Returns the number of DOF of the gripper
32 | """
33 | raise NotImplementedError
34 |
35 | @property
36 | def joints(self):
37 | """
38 | Returns a list of joint names of the gripper
39 | """
40 | raise NotImplementedError
41 |
42 | def contact_geoms(self):
43 | """
44 | Returns a list of names corresponding to the geoms
45 | used to determine contact with the gripper.
46 | """
47 | return []
48 |
49 | @property
50 | def visualization_sites(self):
51 | """
52 |         Returns a list of sites used to aid visualization
53 |         by humans. These sites should be hidden from the
54 |         robot's observations.
55 | """
56 | return []
57 |
58 | @property
59 | def visualization_geoms(self):
60 | """
61 |         Returns a list of geoms used to aid visualization
62 |         by humans. These geoms should be hidden from the
63 |         robot's observations.
64 | """
65 | return []
66 |
67 | @property
68 | def left_finger_geoms(self):
69 | """
70 | Geoms corresponding to left finger of a gripper
71 | """
72 | raise NotImplementedError
73 |
74 | @property
75 | def right_finger_geoms(self):
76 | """
77 |         Geoms corresponding to the right finger of a gripper
78 | """
79 | raise NotImplementedError
80 |
81 | def hide_visualization(self):
82 | """
83 | Hides all visualization geoms and sites.
84 |         This should be called before rendering to agents.
85 | """
86 | for site_name in self.visualization_sites:
87 | site = self.worldbody.find(".//site[@name='{}']".format(site_name))
88 | site.set("rgba", "0 0 0 0")
89 | for geom_name in self.visualization_geoms:
90 | geom = self.worldbody.find(".//geom[@name='{}']".format(geom_name))
91 | geom.set("rgba", "0 0 0 0")
92 |
--------------------------------------------------------------------------------
/spirl/modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from spirl.utils.general_utils import AttrDict, get_dim_inds
3 | from spirl.modules.variational_inference import Gaussian
4 |
5 |
6 | class Loss:
7 |     def __init__(self, weight=1.0, breakdown=None):
8 |         """
9 |         :param weight: weighting coefficient applied to this loss term
10 |         :param breakdown: if specified, a breakdown of the loss along this
11 |         dimension will additionally be recorded
12 |         """
13 | self.weight = weight
14 | self.breakdown = breakdown
15 |
16 |     def __call__(self, *args, weights=1, reduction='mean', store_raw=False, **kwargs):
17 |         """
18 |         :param weights: element-wise weights multiplied into the error before reduction
19 |         :param reduction: reduction mode; only 'mean' is currently supported
20 |         :param store_raw: if True, the unreduced error matrix is stored in the result
21 |         :return: AttrDict with the reduced loss value and its weight
22 |         """
23 | error = self.compute(*args, **kwargs) * weights
24 | if reduction != 'mean':
25 | raise NotImplementedError
26 | loss = AttrDict(value=error.mean(), weight=self.weight)
27 | if self.breakdown is not None:
28 | reduce_dim = get_dim_inds(error)[:self.breakdown] + get_dim_inds(error)[self.breakdown+1:]
29 | loss.breakdown = error.detach().mean(reduce_dim) if reduce_dim else error.detach()
30 | if store_raw:
31 | loss.error_mat = error.detach()
32 | return loss
33 |
34 | def compute(self, estimates, targets):
35 | raise NotImplementedError
36 |
37 |
38 | class L2Loss(Loss):
39 | def compute(self, estimates, targets, activation_function=None):
40 | # assert estimates.shape == targets.shape, "Input {} and targets {} for L2 loss need to have identical shape!"\
41 | # .format(estimates.shape, targets.shape)
42 | if activation_function is not None:
43 | estimates = activation_function(estimates)
44 | if not isinstance(targets, torch.Tensor):
45 | targets = torch.tensor(targets, device=estimates.device, dtype=estimates.dtype)
46 | l2_loss = torch.nn.MSELoss(reduction='none')(estimates, targets)
47 | return l2_loss
48 |
49 |
50 | class KLDivLoss(Loss):
51 | def compute(self, estimates, targets):
52 | if not isinstance(estimates, Gaussian): estimates = Gaussian(estimates)
53 | if not isinstance(targets, Gaussian): targets = Gaussian(targets)
54 | kl_divergence = estimates.kl_divergence(targets)
55 | return kl_divergence
56 |
57 |
58 | class CELoss(Loss):
59 | compute = staticmethod(torch.nn.functional.cross_entropy)
60 |
61 |
62 | class PenaltyLoss(Loss):
63 | def compute(self, val):
64 | """Computes weighted mean of val as penalty loss."""
65 | return val
66 |
67 |
68 | class NLL(Loss):
69 | # Note that cross entropy is an instance of NLL, as is L2 loss.
70 | def compute(self, estimates, targets):
71 | nll = estimates.nll(targets)
72 | return nll
73 |
74 |
75 |
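if __name__ == '__main__':
    # Minimal usage sketch (shapes are illustrative): a weighted L2 loss with a
    # per-dimension error breakdown along dim 1.
    estimates, targets = torch.randn(8, 4), torch.zeros(8, 4)
    loss = L2Loss(weight=0.5, breakdown=1)(estimates, targets)
    print(loss.value, loss.weight)  # scalar mean error and the configured weight
    print(loss.breakdown.shape)     # per-dimension mean error, shape [4]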
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/tasks/table_top_task.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.tasks import Task, UniformRandomSampler
2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import new_joint, array_to_string
3 |
4 |
5 | class TableTopTask(Task):
6 | """
7 | Creates MJCF model of a tabletop task.
8 |
9 | A tabletop task consists of one robot interacting with a variable number of
10 | objects placed on the tabletop. This class combines the robot, the table
11 |     arena, and the objects into a single MJCF model.
12 | """
13 |
14 | def __init__(self, mujoco_arena, mujoco_robot, mujoco_objects, initializer=None):
15 | """
16 | Args:
17 | mujoco_arena: MJCF model of robot workspace
18 | mujoco_robot: MJCF model of robot model
19 |             mujoco_objects: a dict mapping object names to MJCF models of physical objects
20 | initializer: placement sampler to initialize object positions.
21 | """
22 | super().__init__()
23 |
24 | self.merge_arena(mujoco_arena)
25 | self.merge_robot(mujoco_robot)
26 | self.merge_objects(mujoco_objects)
27 | if initializer is None:
28 | initializer = UniformRandomSampler()
29 | mjcfs = [x for _, x in self.mujoco_objects.items()]
30 |
31 | self.initializer = initializer
32 | self.initializer.setup(mjcfs, self.table_top_offset, self.table_size)
33 |
34 | def merge_robot(self, mujoco_robot):
35 | """Adds robot model to the MJCF model."""
36 | self.robot = mujoco_robot
37 | self.merge(mujoco_robot)
38 |
39 | def merge_arena(self, mujoco_arena):
40 | """Adds arena model to the MJCF model."""
41 | self.arena = mujoco_arena
42 | self.table_top_offset = mujoco_arena.table_top_abs
43 | self.table_size = mujoco_arena.table_full_size
44 | self.merge(mujoco_arena)
45 |
46 | def merge_objects(self, mujoco_objects):
47 | """Adds physical objects to the MJCF model."""
48 | self.mujoco_objects = mujoco_objects
49 | self.objects = [] # xml manifestation
50 | self.targets = [] # xml manifestation
51 | self.max_horizontal_radius = 0
52 |
53 | for obj_name, obj_mjcf in mujoco_objects.items():
54 | self.merge_asset(obj_mjcf)
55 | # Load object
56 | obj = obj_mjcf.get_collision(name=obj_name, site=True)
57 | obj.append(new_joint(name=obj_name, type="free"))
58 | self.objects.append(obj)
59 | self.worldbody.append(obj)
60 |
61 | self.max_horizontal_radius = max(
62 | self.max_horizontal_radius, obj_mjcf.get_horizontal_radius()
63 | )
64 |
65 | def place_objects(self):
66 | """Places objects randomly until no collisions or max iterations hit."""
67 | pos_arr, quat_arr = self.initializer.sample()
68 | for i in range(len(self.objects)):
69 | self.objects[i].set("pos", array_to_string(pos_arr[i]))
70 | self.objects[i].set("quat", array_to_string(quat_arr[i]))
71 |
--------------------------------------------------------------------------------
/spirl/rl/utils/wandb.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import inspect
3 | import numpy as np
4 |
5 | from spirl.utils.general_utils import flatten_dict, prefix_dict
6 |
7 |
8 | class WandBLogger:
9 | """Logs to WandB."""
10 | N_LOGGED_SAMPLES = 3 # how many examples should be logged in each logging step
11 |
12 | def __init__(self, exp_name, project_name, entity, path, conf, exclude=None):
13 | """
14 | :param exp_name: full name of experiment in WandB
15 | :param project_name: name of overall project
16 | :param entity: name of head entity in WandB that hosts the project
17 | :param path: path to which WandB log-files will be written
18 | :param conf: hyperparam config that will get logged to WandB
19 | :param exclude: (optional) list of (flattened) hyperparam names that should not get logged
20 | """
21 | if exclude is None: exclude = []
22 | flat_config = flatten_dict(conf)
23 | filtered_config = {k: v for k, v in flat_config.items() if (k not in exclude and not inspect.isclass(v))}
24 | print("INIT WANDB")
25 | wandb.init(
26 | resume=exp_name,
27 | project=project_name,
28 | config=filtered_config,
29 | dir=path,
30 | entity=entity,
31 | notes=conf.notes if 'notes' in conf else ''
32 | )
33 |
34 | def log_scalar_dict(self, d, prefix='', step=None):
35 | """Logs all entries from a dict of scalars. Optionally can prefix all keys in dict before logging."""
36 | if prefix: d = prefix_dict(d, prefix + '_')
37 |         # NOTE: `step` is currently ignored when logging scalars; WandB uses its own internal step counter.
38 |         wandb.log(d)
39 |
40 |         # side channel: additionally append the episode reward and a success flag to a
41 |         # local CSV (heuristic: episodes shorter than 200 steps terminated early, i.e.
42 |         # succeeded).
43 |         if 'train_episode_reward' in d and 'train_episode_length' in d:
44 |             with open('train_episode_reward.csv', 'a') as f:
45 |                 success = 1 if d['train_episode_length'] < 200 else 0
46 |                 f.write(str(d['train_episode_reward']) + ',' + str(success) + '\n')
47 |
48 | def log_videos(self, vids, name, step=None):
49 | """Logs videos to WandB in mp4 format.
50 | Assumes list of numpy arrays as input with [time, channels, height, width]."""
51 |         assert isinstance(vids[0], np.ndarray)
52 |         assert len(vids[0].shape) == 4 and vids[0].shape[1] == 3
53 | if vids[0].max() <= 1.0: vids = [np.asarray(vid * 255.0, dtype=np.uint8) for vid in vids]
54 | # TODO(karl) expose the FPS as a parameter
55 | log_dict = {name: [wandb.Video(vid, fps=20, format="mp4") for vid in vids]}
56 | wandb.log(log_dict) if step is None else wandb.log(log_dict, step=step)
57 |
58 | def log_plot(self, fig, name, step=None):
59 | """Logs matplotlib graph to WandB.
60 | fig is a matplotlib figure handle."""
61 | img = wandb.Image(fig)
62 | wandb.log({name: img}) if step is None else wandb.log({name: img}, step=step)
63 |
64 | @property
65 | def n_logged_samples(self):
66 | # TODO(karl) put this functionality in a base logger class + give it default parameters and config
67 | return self.N_LOGGED_SAMPLES
68 |
69 |
70 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/README.md:
--------------------------------------------------------------------------------
1 | Models
2 | ======
3 |
4 | Models provide a toolkit for composing modularized elements into a scene. Their central goal is to offer a set of modularized APIs that procedurally generate combinations of robots, arenas, and parameterized 3D objects, enabling control policies to be learned with better robustness and generalization.
5 |
6 | In MuJoCo, these scene elements are specified in [MJCF](http://mujoco.org/book/modeling.html#Summary) format and compiled into mjModel to instantiate the physical simulation. Each MJCF model can consist of multiple XML files as well as meshes and textures referenced from the XML file. These MJCF files are stored in the [assets](assets) folder.
7 |
8 | Below we describe the types of scene elements that we support:
9 |
10 | Robots
11 | ------
12 | The [robots](robots) folder contains robot classes which load robot specifications from MJCF/URDF files and instantiate a robot mjModel. All robot classes should inherit from the base [Robot](robots/robot.py) class, which defines a set of common robot APIs.
13 |
14 | Grippers
15 | --------
16 | The [grippers](grippers) folder contains a variety of end-effector models that can be mounted on the arms of a robot model via the [`add_gripper`](robots/robot.py#L20) method of the robot class, as sketched below.
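
For example, a gripper can be created with the gripper factory and mounted on a robot (a sketch; the arm name and gripper type are assumptions):

```python
from robosuite.models.robots import Sawyer
from robosuite.models.grippers import gripper_factory

robot = Sawyer()
robot.add_gripper("right_hand", gripper_factory("TwoFingerGripper"))
```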
17 |
18 | Objects
19 | -------
20 | [Objects](objects) are small scene elements that robots interact with using their actuators. Objects can either be defined as 3D [meshes](http://mujoco.org/book/modeling.html#mesh) (e.g., in STL format) or be procedurally generated from primitive shapes of MuJoCo [geoms](http://mujoco.org/book/modeling.html#geom).
21 |
22 | [MujocoObject](objects/mujoco_object.py) is the base object class. [MujocoXMLObject](objects/mujoco_object.py) is the base class for all objects that are loaded from MJCF XML files. [MujocoGeneratedObject](objects/mujoco_object.py) is the base class for procedurally generated objects with support for size and other physical property randomization.
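
For instance, a procedurally generated box primitive might be created and turned into its collision XML like this (a sketch; the `BoxObject` parameters are assumptions, while `get_collision` follows the task classes in this codebase):

```python
from robosuite.models.objects import BoxObject

box = BoxObject(size=[0.02, 0.02, 0.02], rgba=[1, 0, 0, 1])
box_xml = box.get_collision(name="red_box", site=True)
```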
23 |
24 | Arenas
25 | ------
26 | [Arenas](arenas) define the workspace, such as a tabletop or a set of bins, where the robot performs the tasks. All arena classes should inherit a base [Arena](arenas/arena.py) class. By default each arena contains 3 cameras (see [example](assets/arenas/empty_arena.xml)). The `frontview` camera provides an overview of the scene, which is often used to generate visualizations and video recordings. The `agentview` camera is the canonical camera typically used for visual observations when training visuomotor policies. The `birdview` camera is a top-down camera which is useful for debugging the placements of objects in the scene.
27 |
28 | Tasks
29 | -----
30 | [Tasks](tasks) put together all the necessary scene elements, which typically consist of a robot, an arena, and a set of objects, into a model of the whole scene. A task class handles merging the MJCF models of the individual elements and setting their initial placements in the scene. The resulting scene model is compiled and loaded into the MuJoCo backend to perform simulation. All task classes should inherit from a base [Task](tasks/task.py) class which specifies a set of common APIs for model merging and placement initialization.
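
For example, a tabletop scene might be assembled as follows (a sketch using classes from this codebase; the object set is illustrative):

```python
from robosuite.models.arenas import TableArena
from robosuite.models.robots import Sawyer
from robosuite.models.objects import BoxObject
from robosuite.models.tasks import TableTopTask

task = TableTopTask(TableArena(), Sawyer(),
                    {"block": BoxObject(size=[0.02, 0.02, 0.02], rgba=[1, 0, 0, 1])})
task.place_objects()  # sample initial object placements
```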
31 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/block_task_generator.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import numpy as np
3 |
4 | from spirl.utils.general_utils import ParamDict
5 |
6 |
7 | class BlockTaskGenerator:
8 | def __init__(self, hp, n_blocks):
9 | self._hp = self._default_hparams().overwrite(hp)
10 | self._n_blocks = n_blocks
11 | self._rng = np.random.default_rng(seed=self._hp.seed)
12 |
13 | def _default_hparams(self):
14 | default_dict = ParamDict({
15 | 'seed': None,
16 | })
17 | return default_dict
18 |
19 | def sample(self):
20 | """Generates task definition in the form of list of transition tuples.
21 | Each tuple contains (bottom_block, top_block)."""
22 | raise NotImplementedError
23 |
24 | def _sample_tower(self, size, blocks):
25 | """Samples single tower of specified size, pops blocks from queue of idxs."""
26 | block = blocks.popleft()
27 | tasks = []
28 | for _ in range(size):
29 | next_block = blocks.popleft()
30 | tasks.append((block, next_block))
31 | block = next_block
32 | return tasks
33 |
34 | def _init_block_queue(self):
35 | return deque(self._rng.permutation(self._n_blocks))
36 |
37 |
38 | class SingleTowerBlockTaskGenerator(BlockTaskGenerator):
39 | """Samples single tower with a limit on the number of stacked blocks."""
40 | def _default_hparams(self):
41 | default_dict = ParamDict({
42 |             'max_tower_height': 4,        # maximum height of the target tower
43 | })
44 | return super()._default_hparams().overwrite(default_dict)
45 |
46 | def sample(self):
47 | block_queue = self._init_block_queue()
48 | size = self._sample_size()
49 | return self._sample_tower(size, block_queue)
50 |
51 | def _sample_size(self):
52 |         return self._rng.integers(1, self._hp.max_tower_height + 1)        # ensure at least one stacking operation is performed
53 |
54 |
55 | class FixedSizeSingleTowerBlockTaskGenerator(SingleTowerBlockTaskGenerator):
56 | """Samples single tower, always samples maximum height."""
57 | def _sample_size(self):
58 | return self._hp.max_tower_height
59 |
60 |
61 | class MultiTowerBlockTaskGenerator(BlockTaskGenerator):
62 | """Samples multiple tower with a limit on the number of stacked blocks."""
63 | def _default_hparams(self):
64 | default_dict = ParamDict({
65 | 'max_tower_height': 4, # maximum height of target tower(s)
66 | })
67 | return super()._default_hparams().overwrite(default_dict)
68 |
69 | def sample(self):
70 | block_queue = self._init_block_queue()
71 | task = []
72 | while len(block_queue) > 1:
73 | size = self._sample_size(max_height=min(len(block_queue)-1, self._hp.max_tower_height))
74 | task += self._sample_tower(size, block_queue)
75 | return task
76 |
77 | def _sample_size(self, max_height):
78 |         return self._rng.integers(1, max_height + 1)        # ensure at least one stacking operation is performed
79 |
80 |
81 | if __name__ == '__main__':
82 | task_gen = MultiTowerBlockTaskGenerator(hp={}, n_blocks=10)
83 | for _ in range(10):
84 | print(task_gen.sample())
85 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/demo_collect_and_playback_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Record trajectory data with the DataCollectionWrapper and play it back.
3 |
4 | Example:
5 | $ python demo_collect_and_playback_data.py --environment BaxterLift
6 | """
7 |
8 | import os
9 | import argparse
10 | from glob import glob
11 | import numpy as np
12 |
13 | from spirl.data.block_stacking.src import robosuite
14 | from spirl.data.block_stacking.src.robosuite import DataCollectionWrapper
15 |
16 |
17 | def collect_random_trajectory(env, timesteps=1000):
18 | """Run a random policy to collect trajectories.
19 |
20 | The rollout trajectory is saved to files in npz format.
21 |     Modify the DataCollectionWrapper to add new fields or change data formats.
22 | """
23 |
24 | obs = env.reset()
25 | dof = env.dof
26 |
27 | for t in range(timesteps):
28 | action = 0.5 * np.random.randn(dof)
29 | obs, reward, done, info = env.step(action)
30 | env.render()
31 | if t % 100 == 0:
32 | print(t)
33 |
34 |
35 | def playback_trajectory(env, ep_dir):
36 | """Playback data from an episode.
37 |
38 | Args:
39 | ep_dir: The path to the directory containing data for an episode.
40 | """
41 |
42 | # first reload the model from the xml
43 | xml_path = os.path.join(ep_dir, "model.xml")
44 | with open(xml_path, "r") as f:
45 | env.reset_from_xml_string(f.read())
46 |
47 | state_paths = os.path.join(ep_dir, "state_*.npz")
48 |
49 | # read states back, load them one by one, and render
50 | t = 0
51 | for state_file in sorted(glob(state_paths)):
52 | print(state_file)
53 | dic = np.load(state_file)
54 | states = dic["states"]
55 | for state in states:
56 | env.sim.set_state_from_flattened(state)
57 | env.sim.forward()
58 | env.render()
59 | t += 1
60 | if t % 100 == 0:
61 | print(t)
62 |
63 |
64 | if __name__ == "__main__":
65 |
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument("--environment", type=str, default="SawyerStack")
68 | parser.add_argument("--directory", type=str, default="/tmp/")
69 | parser.add_argument("--timesteps", type=int, default=2000)
70 | args = parser.parse_args()
71 |
72 | # create original environment
73 | env = robosuite.make(
74 | args.environment,
75 | ignore_done=True,
76 | use_camera_obs=False,
77 | has_renderer=True,
78 | control_freq=100,
79 | )
80 | data_directory = args.directory
81 |
82 | # wrap the environment with data collection wrapper
83 | env = DataCollectionWrapper(env, data_directory)
84 |
85 | # testing to make sure multiple env.reset calls don't create multiple directories
86 | env.reset()
87 | env.reset()
88 | env.reset()
89 |
90 | # collect some data
91 | print("Collecting some random data...")
92 | collect_random_trajectory(env, timesteps=args.timesteps)
93 |
94 | # playback some data
95 | _ = input("Press any key to begin the playback...")
96 | print("Playing back the data...")
97 | data_directory = env.ep_directory
98 | playback_trajectory(env, data_directory)
99 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/tasks/nut_assembly_task.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import new_joint, array_to_string
2 | from spirl.data.block_stacking.src.robosuite.models.tasks import Task, UniformRandomPegsSampler
3 |
4 |
5 | class NutAssemblyTask(Task):
6 | """
7 | Creates MJCF model of a nut assembly task.
8 |
9 |     A nut assembly task consists of one robot picking up nuts from a table
10 |     and fitting them onto pegs positioned on the tabletop. This class combines
11 |     the robot, the arena with pegs, and the nut objects into a single MJCF model.
12 | """
13 |
14 | def __init__(self, mujoco_arena, mujoco_robot, mujoco_objects, initializer=None):
15 | """
16 | Args:
17 | mujoco_arena: MJCF model of robot workspace
18 | mujoco_robot: MJCF model of robot model
19 |             mujoco_objects: a dict mapping object names to MJCF models of physical objects
20 | initializer: placement sampler to initialize object positions.
21 | """
22 | super().__init__()
23 |
24 | self.object_metadata = []
25 | self.merge_arena(mujoco_arena)
26 | self.merge_robot(mujoco_robot)
27 | self.merge_objects(mujoco_objects)
28 |
29 | if initializer is None:
30 | initializer = UniformRandomPegsSampler()
31 | self.initializer = initializer
32 | self.initializer.setup(self.mujoco_objects, self.table_offset, self.table_size)
33 |
34 | def merge_robot(self, mujoco_robot):
35 | """Adds robot model to the MJCF model."""
36 | self.robot = mujoco_robot
37 | self.merge(mujoco_robot)
38 |
39 | def merge_arena(self, mujoco_arena):
40 | """Adds arena model to the MJCF model."""
41 | self.arena = mujoco_arena
42 | self.table_offset = mujoco_arena.table_top_abs
43 | self.table_size = mujoco_arena.table_full_size
44 | self.table_body = mujoco_arena.table_body
45 | self.peg1_body = mujoco_arena.peg1_body
46 | self.peg2_body = mujoco_arena.peg2_body
47 | self.merge(mujoco_arena)
48 |
49 | def merge_objects(self, mujoco_objects):
50 | """Adds physical objects to the MJCF model."""
51 | self.mujoco_objects = mujoco_objects
52 | self.objects = {} # xml manifestation
53 | self.max_horizontal_radius = 0
54 | for obj_name, obj_mjcf in mujoco_objects.items():
55 | self.merge_asset(obj_mjcf)
56 | # Load object
57 | obj = obj_mjcf.get_collision(name=obj_name, site=True)
58 | obj.append(new_joint(name=obj_name, type="free", damping="0.0005"))
59 | self.objects[obj_name] = obj
60 | self.worldbody.append(obj)
61 |
62 | self.max_horizontal_radius = max(
63 | self.max_horizontal_radius, obj_mjcf.get_horizontal_radius()
64 | )
65 |
66 | def place_objects(self):
67 | """Places objects randomly until no collisions or max iterations hit."""
68 | pos_arr, quat_arr = self.initializer.sample()
69 | for k, obj_name in enumerate(self.objects):
70 | self.objects[obj_name].set("pos", array_to_string(pos_arr[k]))
71 | self.objects[obj_name].set("quat", array_to_string(quat_arr[k]))
72 |
--------------------------------------------------------------------------------
/spirl/components/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--path", help="path to the config file directory")
7 |
8 | # Folder settings
9 | parser.add_argument("--prefix", help="experiment prefix, if given creates subfolder in experiment directory")
10 | parser.add_argument('--new_dir', default=False, type=int, help='If True, concat datetime string to exp_dir.')
11 | parser.add_argument('--dont_save', default=False, type=int,
12 | help="if True, nothing is saved to disk. Note: this doesn't work") # TODO this doesn't work
13 |
14 | # Running protocol
15 |     parser.add_argument('--resume', default='latest', type=str, metavar='PATH',
16 |                         help='path or tag of the checkpoint to resume from (default: latest)')
17 | parser.add_argument('--train', default=True, type=int,
18 | help='if False, will run one validation epoch')
19 | parser.add_argument('--test_prediction', default=True, type=int,
20 | help="if False, prediction isn't run at validation time")
21 | parser.add_argument('--skip_first_val', default=True, type=int,
22 | help='if True, will skip the first validation epoch')
23 | parser.add_argument('--val_sweep', default=False, type=int,
24 | help='if True, runs validation on all existing model checkpoints')
25 |
26 | # Misc
27 | parser.add_argument('--gpu', default=0, type=int,
28 | help='will set CUDA_VISIBLE_DEVICES to selected value')
29 | parser.add_argument('--strict_weight_loading', default=True, type=int,
30 | help='if True, uses strict weight loading function')
31 | parser.add_argument('--deterministic', default=False, type=int,
32 | help='if True, sets fixed seeds for torch and numpy')
33 | parser.add_argument('--log_interval', default=500, type=int,
34 | help='number of updates per training log')
35 | parser.add_argument('--per_epoch_img_logs', default=1, type=int,
36 | help='number of image loggings per epoch')
37 | parser.add_argument('--val_data_size', default=-1, type=int,
38 | help='number of sequences in the validation set. If -1, the full dataset is used')
39 | parser.add_argument('--val_interval', default=5, type=int,
40 | help='number of epochs per validation')
41 |
42 | # Debug
43 | parser.add_argument('--detect_anomaly', default=False, type=int,
44 | help='if True, uses autograd.detect_anomaly()')
45 | parser.add_argument('--feed_random_data', default=False, type=int,
46 | help='if True, we feed random data to the model to test its performance')
47 | parser.add_argument('--train_loop_pdb', default=False, type=int,
48 | help='if True, opens a pdb into training loop')
49 | parser.add_argument('--debug', default=False, type=int,
50 | help='if True, runs in debug mode')
51 |
52 |     # Video saving
53 |     parser.add_argument('--save2mp4', default=False, type=int,
54 |                         help='if True, videos will be saved locally')
55 |
56 | return parser.parse_args()
57 |
--------------------------------------------------------------------------------
/spirl/components/trainer_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 |
5 |
6 | def save_checkpoint(state, filename='checkpoint.pth'):
7 | torch.save(state, filename)
8 |
9 |
10 | class BaseTrainer:
11 | def override_defaults(self, policyparams):
12 | for name, value in policyparams.items():
13 | print('overriding param {} to value {}'.format(name, value))
14 | if value == getattr(self._hp, name):
15 |             raise ValueError("attribute {} is identical to default value!".format(name))
16 | self._hp.set_hparam(name, value)
17 |
18 | def call_hooks(self, inputs, output, losses, epoch):
19 | for hook in self.hooks:
20 | hook.run(inputs, output, losses, epoch)
21 |
22 | def check_nan_grads(self, grads):
23 |         """Returns True if any gradient in the given (name, param) pairs contains NaNs."""
24 | return torch.stack([torch.isnan(p.grad).any() for n, p in grads]).any()
25 |
26 | def nan_grads_hook(self, inputs, output, losses, epoch):
27 | non_none = list(filter(lambda x: x[1].grad is not None, self.model.named_parameters()))
28 | if self.check_nan_grads(non_none):
29 | self.dump_debug_data(inputs, output, losses, epoch)
30 | nan_parameters = [n for n,v in filter(lambda x: torch.isnan(x[-1].grad).any(), non_none)]
31 | non_nan_parameters = [n for n, v in filter(lambda x: not torch.isnan(x[-1].grad).any(), non_none)]
32 | # grad_dict = self.print_nan_grads()
33 | print('nan gradients here: ', nan_parameters)
34 |         print('these parameters are OK: ', non_nan_parameters)
35 |
36 | def get_nan_module(layer):
37 |             return np.unique(list(map(lambda n: '.'.join(n.split('.')[:layer]), nan_parameters)))
38 |
39 | import pdb; pdb.set_trace()
40 | raise ValueError('there are NaN Gradients')
41 |
42 | def dump_debug_data(self, inputs, outputs, losses, epoch):
43 | print("Detected NaN gradients, dumping data for diagnosis, ...")
44 | debug_dict = dict()
45 |
46 | def clean_dict(d):
47 | clean_d = dict()
48 | for key, val in d.items():
49 | if isinstance(val, torch.Tensor):
50 | clean_d[key] = val.data.cpu().numpy()
51 | elif isinstance(val, dict):
52 | clean_d[key] = clean_dict(val)
53 | return clean_d
54 |
55 | debug_dict['inputs'] = clean_dict(inputs)
56 | debug_dict['outputs'] = clean_dict(outputs)
57 | debug_dict['losses'] = clean_dict(losses)
58 |
59 |         import pickle as pkl
60 |         with open(os.path.join(self._hp.exp_path, "nan_debug_info.pkl"), "wb") as f:
61 |             pkl.dump(debug_dict, f)
62 |
63 |
64 | save_checkpoint({
65 | 'epoch': epoch,
66 | 'state_dict': self.model.state_dict(),
67 | 'optimizer': self.optimizer.state_dict(),
68 | }, os.path.join(self._hp.exp_path, "nan_debug_ckpt.pth"))
69 |
70 | def print_nan_grads(self):
71 | grad_dict = {}
72 | for name, param in self.model.named_parameters(recurse=True):
73 | print("{}:\t\t{}".format(name, bool(torch.isnan(param.grad).any().data.cpu().numpy())))
74 | grad_dict[name] = param.grad
75 | return grad_dict
76 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/utils/mujoco_py_renderer.py:
--------------------------------------------------------------------------------
1 | from mujoco_py import MjViewer
2 | from mujoco_py.generated import const
3 | import glfw
4 | from collections import defaultdict
5 |
6 |
7 | class CustomMjViewer(MjViewer):
8 |     # NOTE: these callback registries are class-level attributes, i.e. shared across all viewer instances
9 | keypress = defaultdict(list)
10 | keyup = defaultdict(list)
11 | keyrepeat = defaultdict(list)
12 |
13 | def key_callback(self, window, key, scancode, action, mods):
14 | if action == glfw.PRESS:
15 | tgt = self.keypress
16 | elif action == glfw.RELEASE:
17 | tgt = self.keyup
18 | elif action == glfw.REPEAT:
19 | tgt = self.keyrepeat
20 | else:
21 | return
22 | if tgt.get(key):
23 | for fn in tgt[key]:
24 | fn(window, key, scancode, action, mods)
25 | if tgt.get("any"):
26 | for fn in tgt["any"]:
27 | fn(window, key, scancode, action, mods)
28 |         # retain functionality for closing the viewer
29 |         if key == glfw.KEY_ESCAPE:
30 |             super().key_callback(window, key, scancode, action, mods)
31 |         elif not tgt.get("any"):
32 |             # only use default mujoco callbacks if "any" callbacks are unset
33 |             super().key_callback(window, key, scancode, action, mods)
34 |
35 |
36 | class MujocoPyRenderer:
37 | def __init__(self, sim):
38 | """
39 | Args:
40 | sim: MjSim object
41 | """
42 | self.viewer = CustomMjViewer(sim)
43 | self.callbacks = {}
44 |
45 | def set_camera(self, camera_id):
46 | """
47 | Set the camera view to the specified camera ID.
48 | """
49 | self.viewer.cam.fixedcamid = camera_id
50 | self.viewer.cam.type = const.CAMERA_FIXED
51 |
52 | def render(self):
53 | # safe for multiple calls
54 | self.viewer.render()
55 |
56 | def close(self):
57 | """
58 | Destroys the open window and renders (pun intended) the viewer useless.
59 | """
60 | glfw.destroy_window(self.viewer.window)
61 | self.viewer = None
62 |
63 | def add_keypress_callback(self, key, fn):
64 | """
65 | Allows for custom callback functions for the viewer. Called on key down.
66 | Parameter 'any' will ensure that the callback is called on any key down,
67 | and block default mujoco viewer callbacks from executing, except for
68 | the ESC callback to close the viewer.
69 | """
70 | self.viewer.keypress[key].append(fn)
71 |
72 | def add_keyup_callback(self, key, fn):
73 | """
74 | Allows for custom callback functions for the viewer. Called on key up.
75 | Parameter 'any' will ensure that the callback is called on any key up,
76 | and block default mujoco viewer callbacks from executing, except for
77 | the ESC callback to close the viewer.
78 | """
79 | self.viewer.keyup[key].append(fn)
80 |
81 | def add_keyrepeat_callback(self, key, fn):
82 | """
83 | Allows for custom callback functions for the viewer. Called on key repeat.
84 | Parameter 'any' will ensure that the callback is called on any key repeat,
85 | and block default mujoco viewer callbacks from executing, except for
86 | the ESC callback to close the viewer.
87 | """
88 | self.viewer.keyrepeat[key].append(fn)
89 |
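
if __name__ == "__main__":
    # Minimal usage sketch (assumes a working mujoco_py installation with a display
    # available for rendering; the inline model is illustrative).
    from mujoco_py import MjSim, load_model_from_xml

    MODEL_XML = "<mujoco><worldbody><body><geom size='0.1'/></body></worldbody></mujoco>"
    renderer = MujocoPyRenderer(MjSim(load_model_from_xml(MODEL_XML)))
    renderer.add_keypress_callback(glfw.KEY_SPACE, lambda *args: print("space pressed"))
    for _ in range(100):
        renderer.render()
    renderer.close()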
--------------------------------------------------------------------------------
/spirl/rl/agents/skill_space_agent.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import contextlib
3 |
4 | import numpy as np
5 |
6 | from spirl.rl.components.agent import BaseAgent
7 | from spirl.utils.general_utils import ParamDict, split_along_axis, AttrDict
8 | from spirl.utils.pytorch_utils import map2torch, map2np, no_batchnorm_update
9 |
10 |
11 | class SkillSpaceAgent(BaseAgent):
12 | """Agent that acts based on pre-trained VAE skill decoder."""
13 | def __init__(self, config):
14 | super().__init__(config)
15 | self._update_model_params() # transfer some parameters to model
16 |
17 | self._policy = self._hp.model(self._hp.model_params, logger=None)
18 | self.load_model_weights(self._policy, self._hp.model_checkpoint, self._hp.model_epoch)
19 |
20 | self.action_plan = deque()
21 |
22 | def _default_hparams(self):
23 | default_dict = ParamDict({
24 | 'model': None, # policy class
25 | 'model_params': None, # parameters for the policy class
26 | 'model_checkpoint': None, # checkpoint path of the model
27 | 'model_epoch': 'latest', # epoch that checkpoint should be loaded for (defaults to latest)
28 | })
29 | return super()._default_hparams().overwrite(default_dict)
30 |
31 | def _act(self, obs):
32 | assert len(obs.shape) == 2 and obs.shape[0] == 1 # assume single-observation batches with leading 1-dim
33 | if not self.action_plan:
34 | # generate action plan if the current one is empty
35 | split_obs = self._split_obs(obs)
36 | with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress():
37 | actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
38 | map2torch(split_obs.cond_input, self._hp.device),
39 | self._policy.n_rollout_steps)
40 | self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
41 | return AttrDict(action=self.action_plan.popleft())
42 |
43 | def reset(self):
44 | self.action_plan = deque() # reset action plan
45 |
46 | def update(self, experience_batch):
47 | return {} # TODO(karl) implement finetuning for policy
48 |
49 | def _split_obs(self, obs):
50 | assert obs.shape[1] == self._policy.state_dim + self._policy.latent_dim
51 | return AttrDict(
52 | cond_input=obs[:, :-self._policy.latent_dim], # condition decoding on state
53 | z=obs[:, -self._policy.latent_dim:],
54 | )
55 |
56 | def sync_networks(self):
57 | pass # TODO(karl) only need to implement if we implement finetuning
58 |
59 | def _update_model_params(self):
60 | self._hp.model_params.device = self._hp.device # transfer device to low-level model
61 | self._hp.model_params.batch_size = 1 # run only single-element batches
62 |
63 | def _act_rand(self, obs):
64 | return self._act(obs)
65 |
66 |
67 | class ACSkillSpaceAgent(SkillSpaceAgent):
68 | """Unflattens prior input part of observation."""
69 | def _split_obs(self, obs):
70 | unflattened_obs = map2np(self._policy.unflatten_obs(
71 | map2torch(obs[:, :-self._policy.latent_dim], device=self.device)))
72 | return AttrDict(
73 | cond_input=unflattened_obs.prior_obs,
74 | z=obs[:, -self._policy.latent_dim:],
75 | )
76 |
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/objects/xml_objects.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.objects import MujocoXMLObject
2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
3 |
4 |
5 | class BottleObject(MujocoXMLObject):
6 | """
7 | Bottle object
8 | """
9 |
10 | def __init__(self):
11 | super().__init__(xml_path_completion("objects/bottle.xml"))
12 |
13 |
14 | class CanObject(MujocoXMLObject):
15 | """
16 | Coke can object (used in SawyerPickPlace)
17 | """
18 |
19 | def __init__(self):
20 | super().__init__(xml_path_completion("objects/can.xml"))
21 |
22 |
23 | class LemonObject(MujocoXMLObject):
24 | """
25 | Lemon object
26 | """
27 |
28 | def __init__(self):
29 | super().__init__(xml_path_completion("objects/lemon.xml"))
30 |
31 |
32 | class MilkObject(MujocoXMLObject):
33 | """
34 | Milk carton object (used in SawyerPickPlace)
35 | """
36 |
37 | def __init__(self):
38 | super().__init__(xml_path_completion("objects/milk.xml"))
39 |
40 |
41 | class BreadObject(MujocoXMLObject):
42 | """
43 | Bread loaf object (used in SawyerPickPlace)
44 | """
45 |
46 | def __init__(self):
47 | super().__init__(xml_path_completion("objects/bread.xml"))
48 |
49 |
50 | class CerealObject(MujocoXMLObject):
51 | """
52 | Cereal box object (used in SawyerPickPlace)
53 | """
54 |
55 | def __init__(self):
56 | super().__init__(xml_path_completion("objects/cereal.xml"))
57 |
58 |
59 | class SquareNutObject(MujocoXMLObject):
60 | """
61 | Square nut object (used in SawyerNutAssembly)
62 | """
63 |
64 | def __init__(self):
65 | super().__init__(xml_path_completion("objects/square-nut.xml"))
66 |
67 |
68 | class RoundNutObject(MujocoXMLObject):
69 | """
70 | Round nut (used in SawyerNutAssembly)
71 | """
72 |
73 | def __init__(self):
74 | super().__init__(xml_path_completion("objects/round-nut.xml"))
75 |
76 |
77 | class MilkVisualObject(MujocoXMLObject):
78 | """
79 | Visual fiducial of milk carton (used in SawyerPickPlace).
80 |
81 | Fiducial objects are not involved in collision physics.
82 | They provide a point of reference to indicate a position.
83 | """
84 |
85 | def __init__(self):
86 | super().__init__(xml_path_completion("objects/milk-visual.xml"))
87 |
88 |
89 | class BreadVisualObject(MujocoXMLObject):
90 | """
91 | Visual fiducial of bread loaf (used in SawyerPickPlace)
92 | """
93 |
94 | def __init__(self):
95 | super().__init__(xml_path_completion("objects/bread-visual.xml"))
96 |
97 |
98 | class CerealVisualObject(MujocoXMLObject):
99 | """
100 | Visual fiducial of cereal box (used in SawyerPickPlace)
101 | """
102 |
103 | def __init__(self):
104 | super().__init__(xml_path_completion("objects/cereal-visual.xml"))
105 |
106 |
107 | class CanVisualObject(MujocoXMLObject):
108 | """
109 | Visual fiducial of coke can (used in SawyerPickPlace)
110 | """
111 |
112 | def __init__(self):
113 | super().__init__(xml_path_completion("objects/can-visual.xml"))
114 |
115 |
116 | class PlateWithHoleObject(MujocoXMLObject):
117 | """
118 | Square plate with a hole in the center (used in BaxterPegInHole)
119 | """
120 |
121 | def __init__(self):
122 | super().__init__(xml_path_completion("objects/plate-with-hole.xml"))
123 |
--------------------------------------------------------------------------------