├── spirl
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── block_stacking
│   │   │   ├── __init__.py
│   │   │   ├── src
│   │   │   │   ├── __init__.py
│   │   │   │   ├── robosuite
│   │   │   │   │   ├── scripts
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── README.md
│   │   │   │   │   │   ├── browse_mjcf_model.py
│   │   │   │   │   │   ├── compile_mjcf_model.py
│   │   │   │   │   │   ├── demo_gripper_selection.py
│   │   │   │   │   │   ├── demo_baxter_ik_control.py
│   │   │   │   │   │   ├── demo_video_recording.py
│   │   │   │   │   │   ├── demo_pygame_renderer.py
│   │   │   │   │   │   ├── demo_learning_curriculum.py
│   │   │   │   │   │   └── demo_collect_and_playback_data.py
│   │   │   │   │   ├── environments
│   │   │   │   │   │   └── __init__.py
│   │   │   │   │   ├── devices
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── device.py
│   │   │   │   │   │   └── README.md
│   │   │   │   │   ├── models
│   │   │   │   │   │   ├── robots
│   │   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   │   ├── sawyer_robot.py
│   │   │   │   │   │   │   ├── baxter_robot.py
│   │   │   │   │   │   │   └── robot.py
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── assets
│   │   │   │   │   │   │   ├── textures
│   │   │   │   │   │   │   │   ├── can.png
│   │   │   │   │   │   │   │   ├── clay.png
│   │   │   │   │   │   │   │   ├── bread.png
│   │   │   │   │   │   │   │   ├── cereal.png
│   │   │   │   │   │   │   │   ├── glass.png
│   │   │   │   │   │   │   │   ├── lemon.png
│   │   │   │   │   │   │   │   ├── metal.png
│   │   │   │   │   │   │   │   ├── ceramic.png
│   │   │   │   │   │   │   │   ├── dark-wood.png
│   │   │   │   │   │   │   │   └── light-wood.png
│   │   │   │   │   │   │   ├── objects
│   │   │   │   │   │   │   │   ├── meshes
│   │   │   │   │   │   │   │   │   ├── can.stl
│   │   │   │   │   │   │   │   │   ├── bread.stl
│   │   │   │   │   │   │   │   │   ├── lemon.stl
│   │   │   │   │   │   │   │   │   ├── milk.stl
│   │   │   │   │   │   │   │   │   ├── bottle.stl
│   │   │   │   │   │   │   │   │   ├── cereal.stl
│   │   │   │   │   │   │   │   │   └── handles.stl
│   │   │   │   │   │   │   │   ├── can-visual.xml
│   │   │   │   │   │   │   │   ├── milk-visual.xml
│   │   │   │   │   │   │   │   ├── cereal-visual.xml
│   │   │   │   │   │   │   │   ├── bread-visual.xml
│   │   │   │   │   │   │   │   ├── bottle.xml
│   │   │   │   │   │   │   │   ├── can.xml
│   │   │   │   │   │   │   │   ├── plate-with-hole.xml
│   │   │   │   │   │   │   │   ├── bread.xml
│   │   │   │   │   │   │   │   ├── lemon.xml
│   │   │   │   │   │   │   │   ├── milk.xml
│   │   │   │   │   │   │   │   ├── cereal.xml
│   │   │   │   │   │   │   │   ├── square-nut.xml
│   │   │   │   │   │   │   │   └── round-nut.xml
│   │   │   │   │   │   │   ├── grippers
│   │   │   │   │   │   │   │   └── meshes
│   │   │   │   │   │   │   │       ├── deprecated
│   │   │   │   │   │   │   │       │   ├── limiter.STL
│   │   │   │   │   │   │   │       │   ├── paddle_tip.STL
│   │   │   │   │   │   │   │       │   ├── extended_wide.STL
│   │   │   │   │   │   │   │       │   ├── standard_wide.STL
│   │   │   │   │   │   │   │       │   ├── basic_hard_tip.STL
│   │   │   │   │   │   │   │       │   ├── basic_soft_tip.STL
│   │   │   │   │   │   │   │       │   ├── extended_narrow.STL
│   │   │   │   │   │   │   │       │   └── electric_gripper_w_fingers.STL
│   │   │   │   │   │   │   │       ├── pr2_gripper
│   │   │   │   │   │   │   │       │   ├── l_finger.stl
│   │   │   │   │   │   │   │       │   ├── gripper_palm.stl
│   │   │   │   │   │   │   │       │   └── l_finger_tip.stl
│   │   │   │   │   │   │   │       ├── robotiq_s_gripper
│   │   │   │   │   │   │   │       │   ├── palm.STL
│   │   │   │   │   │   │   │       │   ├── link_0.STL
│   │   │   │   │   │   │   │       │   ├── link_1.STL
│   │   │   │   │   │   │   │       │   ├── link_2.STL
│   │   │   │   │   │   │   │       │   └── link_3.STL
│   │   │   │   │   │   │   │       ├── robotiq_gripper
│   │   │   │   │   │   │   │       │   ├── adapter_plate.stl
│   │   │   │   │   │   │   │       │   ├── adapter_plate.dae.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_base.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_0_L.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_0_R.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_1_L.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_1_R.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_2_L.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_2_R.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_3_L.stl
│   │   │   │   │   │   │   │       │   ├── robotiq_85_gripper_joint_3_R.stl
│   │   │   │   │   │   │   │       │   └── robotiq_85_gripper_adapter_plate.stl
│   │   │   │   │   │   │   │       └── two_finger_gripper
│   │   │   │   │   │   │   │           ├── half_round_tip.STL
│   │   │   │   │   │   │   │           ├── standard_narrow.STL
│   │   │   │   │   │   │   │           └── electric_gripper_base.STL
│   │   │   │   │   │   │   ├── base.xml
│   │   │   │   │   │   │   └── arenas
│   │   │   │   │   │   │       ├── empty_arena.xml
│   │   │   │   │   │   │       ├── table_arena.xml
│   │   │   │   │   │   │       └── pegs_arena.xml
│   │   │   │   │   │   ├── arenas
│   │   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   │   ├── empty_arena.py
│   │   │   │   │   │   │   ├── arena.py
│   │   │   │   │   │   │   ├── bins_arena.py
│   │   │   │   │   │   │   ├── pegs_arena.py
│   │   │   │   │   │   │   └── table_arena.py
│   │   │   │   │   │   ├── tasks
│   │   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   │   ├── task.py
│   │   │   │   │   │   │   ├── table_top_task.py
│   │   │   │   │   │   │   └── nut_assembly_task.py
│   │   │   │   │   │   ├── world.py
│   │   │   │   │   │   ├── grippers
│   │   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   │   ├── pushing_gripper.py
│   │   │   │   │   │   │   ├── gripper_factory.py
│   │   │   │   │   │   │   ├── pr2_gripper.py
│   │   │   │   │   │   │   ├── robotiq_three_finger_gripper.py
│   │   │   │   │   │   │   ├── robotiq_gripper.py
│   │   │   │   │   │   │   └── gripper.py
│   │   │   │   │   │   ├── objects
│   │   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   │   └── xml_objects.py
│   │   │   │   │   │   └── README.md
│   │   │   │   │   ├── utils
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── errors.py
│   │   │   │   │   │   └── mujoco_py_renderer.py
│   │   │   │   │   ├── controllers
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── controller.py
│   │   │   │   │   ├── wrappers
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── gym_wrapper.py
│   │   │   │   │   │   ├── wrapper.py
│   │   │   │   │   │   └── README.md
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── demo.py
│   │   │   │   ├── utils
│   │   │   │   │   ├── gripper_only_robot.py
│   │   │   │   │   ├── wide_range_gripper.py
│   │   │   │   │   └── numbered_box_object.py
│   │   │   │   ├── demo_gen
│   │   │   │   │   └── block_stacking_demo_agent.py
│   │   │   │   ├── block_stacking_data_loader.py
│   │   │   │   ├── block_stacking_logger.py
│   │   │   │   └── block_task_generator.py
│   │   │   └── assets
│   │   │       ├── textures
│   │   │       │   ├── obj0.png
│   │   │       │   ├── obj1.png
│   │   │       │   ├── obj2.png
│   │   │       │   ├── obj3.png
│   │   │       │   ├── obj4.png
│   │   │       │   ├── obj5.png
│   │   │       │   ├── obj6.png
│   │   │       │   ├── obj7.png
│   │   │       │   ├── obj8.png
│   │   │       │   └── obj9.png
│   │   │       └── gripper_only_robot.xml
│   │   ├── peg_in_hole
│   │   │   └── src
│   │   │       └── generate_peg_in_hole_hdf5.py
│   │   └── README.md
│   ├── utils
│   │   ├── __init__.py
│   │   ├── transformations.py
│   │   ├── data_utils.py
│   │   ├── video_utils.py
│   │   └── dist_utils.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── params.py
│   │   └── trainer_base.py
│   ├── models
│   │   └── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── losses.py
│   ├── configs
│   │   ├── rl
│   │   │   ├── maze
│   │   │   │   ├── SAC
│   │   │   │   │   └── conf.py
│   │   │   │   ├── prior_initialized
│   │   │   │   │   ├── bc_finetune
│   │   │   │   │   │   └── conf.py
│   │   │   │   │   ├── flat_prior
│   │   │   │   │   │   └── conf.py
│   │   │   │   │   └── base_conf.py
│   │   │   │   └── base_conf.py
│   │   │   ├── kitchen
│   │   │   │   ├── SAC
│   │   │   │   │   └── conf.py
│   │   │   │   ├── prior_initialized
│   │   │   │   │   ├── bc_finetune
│   │   │   │   │   │   └── conf.py
│   │   │   │   │   ├── flat_prior
│   │   │   │   │   │   └── conf.py
│   │   │   │   │   └── base_conf.py
│   │   │   │   └── base_conf.py
│   │   │   ├── peg_in_hole
│   │   │   │   ├── SAC
│   │   │   │   │   └── conf.py
│   │   │   │   └── base_conf.py
│   │   │   └── block_stacking
│   │   │       ├── SAC
│   │   │       │   └── conf.py
│   │   │       ├── prior_initialized
│   │   │       │   ├── bc_finetune
│   │   │       │   │   └── conf.py
│   │   │       │   ├── flat_prior
│   │   │       │   │   └── conf.py
│   │   │       │   └── base_conf.py
│   │   │       └── base_conf.py
│   │   ├── hrl
│   │   │   ├── maze
│   │   │   │   ├── no_prior
│   │   │   │   │   └── conf.py
│   │   │   │   └── spirl
│   │   │   │       └── conf.py
│   │   │   ├── kitchen
│   │   │   │   ├── no_prior
│   │   │   │   │   └── conf.py
│   │   │   │   ├── spirl
│   │   │   │   │   └── conf.py
│   │   │   │   └── spirl_cl
│   │   │   │       └── conf.py
│   │   │   ├── block_stacking
│   │   │   │   ├── no_prior
│   │   │   │   │   └── conf.py
│   │   │   │   └── spirl
│   │   │   │       └── conf.py
│   │   │   └── peg_in_hole
│   │   │       └── spirl
│   │   │           └── conf.py
│   │   ├── default_data_configs
│   │   │   ├── maze.py
│   │   │   ├── block_stacking.py
│   │   │   ├── kitchen.py
│   │   │   └── peg_in_hole.py
│   │   ├── skill_prior_learning
│   │   │   ├── kitchen
│   │   │   │   ├── flat
│   │   │   │   │   └── conf.py
│   │   │   │   ├── hierarchical_cl
│   │   │   │   │   ├── conf.py
│   │   │   │   │   └── README.md
│   │   │   │   └── hierarchical
│   │   │   │       └── conf.py
│   │   │   ├── maze
│   │   │   │   ├── flat
│   │   │   │   │   └── conf.py
│   │   │   │   ├── hierarchical
│   │   │   │   │   └── conf.py
│   │   │   │   └── hierarchical_cl
│   │   │   │       └── conf.py
│   │   │   ├── block_stacking
│   │   │   │   ├── flat
│   │   │   │   │   └── conf.py
│   │   │   │   ├── hierarchical_cl
│   │   │   │   │   ├── conf.py
│   │   │   │   │   └── README.md
│   │   │   │   └── hierarchical
│   │   │   │       └── conf.py
│   │   │   └── peg_in_hole
│   │   │       └── hierarchical_cl
│   │   │           └── conf.py
│   │   └── data_collect
│   │       └── block_stacking
│   │           └── conf.py
│   ├── rl
│   │   ├── utils
│   │   │   ├── reward_fcns.py
│   │   │   ├── robosuite_utils.py
│   │   │   ├── rollout_utils.py
│   │   │   └── wandb.py
│   │   ├── policies
│   │   │   └── basic_policies.py
│   │   ├── envs
│   │   │   ├── maze.py
│   │   │   └── kitchen.py
│   │   ├── components
│   │   │   └── params.py
│   │   └── agents
│   │       └── skill_space_agent.py
│   └── plotter
│       └── plot_trajectories.py
├── .gitignore
├── README.md
├── setup.py
└── requirements.txt

/spirl/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/components/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/modules/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/data/block_stacking/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 | __pycache__
4 | *.egg-info
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/hrl/maze/no_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.maze.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/rl/peg_in_hole/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.peg_in_hole.base_conf import *
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # learning_impedance_actions
2 | 
3 | Peg-in-hole dataset will be available soon.
--------------------------------------------------------------------------------
/spirl/configs/hrl/kitchen/no_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.kitchen.base_conf import *
2 | 
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/SAC/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.base_conf import *
--------------------------------------------------------------------------------
/spirl/configs/hrl/block_stacking/no_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.block_stacking.base_conf import *
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | 
3 | setup(name='spirl', version='0.0.1', packages=['spirl'])
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/environments/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import REGISTERED_ENVS, MujocoEnv
2 | 
3 | ALL_ENVS = REGISTERED_ENVS.keys()
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/devices/__init__.py:
--------------------------------------------------------------------------------
1 | from .device import Device
2 | from .keyboard import Keyboard
3 | from .spacemouse import SpaceMouse
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/robots/__init__.py:
--------------------------------------------------------------------------------
1 | from .robot import Robot
2 | from .sawyer_robot import Sawyer
3 | from .baxter_robot import Baxter
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj0.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj1.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj2.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj3.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj4.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj5.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj6.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj7.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj8.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/assets/textures/obj9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/assets/textures/obj9.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .world import MujocoWorldBase
3 | 
4 | assets_root = os.path.join(os.path.dirname(__file__), "assets")
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .errors import robosuiteError, XMLError, SimulationError, RandomizationError
2 | from .mujoco_py_renderer import MujocoPyRenderer
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/can.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/clay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/clay.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/controllers/__init__.py:
--------------------------------------------------------------------------------
1 | from .controller import Controller
2 | from .baxter_ik_controller import BaxterIKController
3 | from .sawyer_ik_controller import SawyerIKController
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/bread.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/bread.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/cereal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/cereal.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/glass.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/glass.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/lemon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/lemon.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/metal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/metal.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/can.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/can.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/ceramic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/ceramic.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/dark-wood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/dark-wood.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bread.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bread.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/lemon.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/lemon.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/milk.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/milk.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/textures/light-wood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/textures/light-wood.png
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bottle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/bottle.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/cereal.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/cereal.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/handles.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/objects/meshes/handles.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/__init__.py:
--------------------------------------------------------------------------------
1 | from .arena import Arena
2 | from .bins_arena import BinsArena
3 | from .empty_arena import EmptyArena
4 | from .pegs_arena import PegsArena
5 | from .table_arena import TableArena
--------------------------------------------------------------------------------
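A minimal usage sketch for the arena classes exported above (not part of the repository; merge() and get_model() are assumed to behave as in upstream robosuite's MujocoXML):

    from spirl.data.block_stacking.src.robosuite.models import MujocoWorldBase
    from spirl.data.block_stacking.src.robosuite.models.arenas import EmptyArena

    # build a world from base.xml and merge an empty workspace into it
    world = MujocoWorldBase()
    arena = EmptyArena()
    world.merge(arena)                         # assumed MujocoXML API, as in upstream robosuite
    model = world.get_model(mode="mujoco_py")  # assumed API; returns a mujoco_py model
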
/spirl/data/block_stacking/src/robosuite/scripts/README.md:
--------------------------------------------------------------------------------
1 | Scripts
2 | =======
3 | 
4 | This folder contains a collection of scripts to demonstrate the functionalities of robosuite. Check the documentation in the script files for detailed instructions.
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/limiter.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/limiter.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/paddle_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/paddle_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_wide.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_wide.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/standard_wide.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/standard_wide.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/gripper_palm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/gripper_palm.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger_tip.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/pr2_gripper/l_finger_tip.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_hard_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_hard_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_soft_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/basic_soft_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_narrow.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/extended_narrow.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.dae.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/adapter_plate.dae.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/half_round_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/half_round_tip.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/standard_narrow.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/standard_narrow.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/electric_gripper_w_fingers.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/deprecated/electric_gripper_w_fingers.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_base.stl
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/prior_initialized/bc_finetune/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import PriorInitializedPolicy
3 | 
4 | agent_config.policy = PriorInitializedPolicy
5 | configuration.agent = SACAgent
6 | 
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/electric_gripper_base.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/two_finger_gripper/electric_gripper_base.STL
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_0_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_1_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_2_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_L.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_joint_3_R.stl
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_adapter_plate.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yquantao/learning_impedance_actions/HEAD/spirl/data/block_stacking/src/robosuite/models/assets/grippers/meshes/robotiq_gripper/robotiq_85_gripper_adapter_plate.stl
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/prior_initialized/bc_finetune/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACPriorInitializedPolicy
3 | 
4 | # update agent
5 | agent_config.policy = ACPriorInitializedPolicy
6 | configuration.agent = SACAgent
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.wrappers.ik_wrapper import IKWrapper
2 | 
3 | try:
4 |     from spirl.data.block_stacking.src.robosuite.wrappers.gym_wrapper import GymWrapper
5 | except ImportError:
6 |     print("Warning: make sure gym is installed if you want to use the GymWrapper.")
--------------------------------------------------------------------------------
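A hedged usage sketch for the optional GymWrapper imported above (constructor and attributes assumed from upstream robosuite, not verified against this fork):

    from spirl.data.block_stacking.src.robosuite.wrappers.gym_wrapper import GymWrapper

    # `env` stands for any robosuite environment instance (hypothetical here);
    # the wrapper exposes the standard gym reset/step interface
    gym_env = GymWrapper(env)
    obs = gym_env.reset()
    obs, reward, done, info = gym_env.step(gym_env.action_space.sample())
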
/spirl/configs/rl/maze/prior_initialized/bc_finetune/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACPriorInitializedPolicy
3 | from spirl.data.maze.src.maze_agents import MazeSACAgent
4 | 
5 | agent_config.policy = ACPriorInitializedPolicy
6 | configuration.agent = MazeSACAgent
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # core
2 | numpy
3 | matplotlib
4 | pillow
5 | h5py==2.10.0
6 | scikit-image
7 | funcsigs
8 | opencv-python
9 | moviepy
10 | torch==1.3.1
11 | torchvision==0.4.2
12 | tensorboard==2.1.1
13 | tensorboardX==2.0
14 | gym==0.15.4
15 | pandas
16 | 
17 | # RL
18 | wandb
19 | mpi4py
20 | 
21 | # Block Stacking
22 | # mujoco_py==2.0.2.9
--------------------------------------------------------------------------------
/spirl/rl/utils/reward_fcns.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def euclid_distance(state1, state2):
5 |     return np.float64(np.linalg.norm(state1 - state2))  # rewards need to be np arrays for replay buffer
6 | 
7 | 
8 | def sparse_threshold(state1, state2, thresh):
9 |     return np.float64(euclid_distance(state1, state2) < thresh)
--------------------------------------------------------------------------------
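The helpers above return numpy float64 scalars so they can be stored directly in the replay buffer; a quick self-contained example (state and goal values are made up):

    import numpy as np
    from spirl.rl.utils.reward_fcns import euclid_distance, sparse_threshold

    goal = np.array([1.0, 0.0])
    state = np.array([0.6, 0.3])
    dense_reward = -euclid_distance(state, goal)        # negative distance, larger is better
    sparse_reward = sparse_threshold(state, goal, 0.5)  # 1.0 once within 0.5 of the goal
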
/spirl/data/block_stacking/src/robosuite/models/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import Task
2 | 
3 | from .placement_sampler import (
4 |     ObjectPositionSampler,
5 |     UniformRandomSampler,
6 |     UniformRandomPegsSampler,
7 | )
8 | 
9 | from .pick_place_task import PickPlaceTask
10 | from .nut_assembly_task import NutAssemblyTask
11 | from .table_top_task import TableTopTask
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/maze.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | from spirl.components.data_loader import GlobalSplitVideoDataset
3 | 
4 | 
5 | data_spec = AttrDict(
6 |     dataset_class=GlobalSplitVideoDataset,
7 |     n_actions=2,
8 |     state_dim=4,
9 |     split=AttrDict(train=0.9, val=0.1, test=0.0),
10 |     res=32,
11 |     crop_rand_subseq=True,
12 | )
13 | data_spec.max_seq_len = 300
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/world.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML
2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
3 | 
4 | 
5 | class MujocoWorldBase(MujocoXML):
6 |     """Base class to inherit all mujoco worlds from."""
7 | 
8 |     def __init__(self):
9 |         super().__init__(xml_path_completion("base.xml"))
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/block_stacking.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | from spirl.components.data_loader import GlobalSplitVideoDataset
3 | 
4 | data_spec = AttrDict(
5 |     dataset_class=GlobalSplitVideoDataset,
6 |     n_actions=3,
7 |     state_dim=41,
8 |     split=AttrDict(train=0.95, val=0.05, test=0.0),
9 |     res=32,
10 |     crop_rand_subseq=True,
11 | )
12 | data_spec.max_seq_len = 150
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/kitchen.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | from spirl.data.kitchen.src.kitchen_data_loader import D4RLSequenceSplitDataset
3 | 
4 | 
5 | data_spec = AttrDict(
6 |     dataset_class=D4RLSequenceSplitDataset,
7 |     n_actions=9,
8 |     state_dim=60,
9 |     env_name="kitchen-mixed-v0",
10 |     res=128,
11 |     crop_rand_subseq=True,
12 | )
13 | data_spec.max_seq_len = 280
--------------------------------------------------------------------------------
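The data_spec AttrDicts above are consumed attribute-style by the experiment confs later in this dump; a condensed sketch of that pattern (the subseq_len value is a made-up example):

    from spirl.utils.general_utils import AttrDict
    from spirl.configs.default_data_configs.kitchen import data_spec

    model_config = AttrDict(
        state_dim=data_spec.state_dim,    # 60 for kitchen-mixed-v0
        action_dim=data_spec.n_actions,   # 9
    )

    data_config = AttrDict()
    data_config.dataset_spec = data_spec
    data_config.dataset_spec.subseq_len = 10 + 1  # e.g. skill horizon + 1
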
/spirl/data/block_stacking/src/robosuite/models/grippers/__init__.py:
--------------------------------------------------------------------------------
1 | from .gripper import Gripper
2 | from .gripper_factory import gripper_factory
3 | from .two_finger_gripper import TwoFingerGripper, LeftTwoFingerGripper
4 | from .gripper_tester import GripperTester
5 | from .pr2_gripper import PR2Gripper
6 | from .pushing_gripper import PushingGripper
7 | from .robotiq_gripper import RobotiqGripper
8 | from .robotiq_three_finger_gripper import RobotiqThreeFingerGripper
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/prior_initialized/flat_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 | 
5 | agent_config.update(AttrDict(
6 |     td_schedule_params=AttrDict(p=1.),
7 | ))
8 | 
9 | agent_config.policy = LearnedPriorAugmentedPIPolicy
10 | configuration.agent = ActionPriorSACAgent
--------------------------------------------------------------------------------
/spirl/configs/rl/maze/prior_initialized/flat_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.maze.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.data.maze.src.maze_agents import MazeActionPriorSACAgent
4 | 
5 | agent_config.update(AttrDict(
6 |     td_schedule_params=AttrDict(p=1.),
7 | ))
8 | 
9 | agent_config.policy = ACLearnedPriorAugmentedPIPolicy
10 | configuration.agent = MazeActionPriorSACAgent
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/arenas/empty_arena.py:
--------------------------------------------------------------------------------
1 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena
2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion
3 | 
4 | 
5 | class EmptyArena(Arena):
6 |     """Empty workspace."""
7 | 
8 |     def __init__(self):
9 |         super().__init__(xml_path_completion("arenas/empty_arena.xml"))
10 |         self.floor = self.worldbody.find("./geom[@name='floor']")
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/base.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
--------------------------------------------------------------------------------
/spirl/configs/rl/kitchen/prior_initialized/base_conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.kitchen.base_conf import *
2 | from spirl.models.bc_mdl import BCMdl
3 | 
4 | policy_params.update(AttrDict(
5 |     prior_model=BCMdl,
6 |     prior_model_params=AttrDict(state_dim=data_spec.state_dim,
7 |                                 action_dim=data_spec.n_actions,
8 |                                 ),
9 |     prior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], "skill_prior_learning/kitchen/flat"),
10 | ))
11 | 
--------------------------------------------------------------------------------
/spirl/configs/rl/block_stacking/prior_initialized/flat_prior/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.rl.block_stacking.prior_initialized.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 | 
5 | agent_config.update(AttrDict(
6 |     td_schedule_params=AttrDict(p=1.),
7 | ))
8 | 
9 | # update agent
10 | agent_config.policy = ACLearnedPriorAugmentedPIPolicy
11 | configuration.agent = ActionPriorSACAgent
--------------------------------------------------------------------------------
/spirl/configs/default_data_configs/peg_in_hole.py:
--------------------------------------------------------------------------------
1 | from spirl.utils.general_utils import AttrDict
2 | #from spirl.data.kitchen.src.kitchen_data_loader import D4RLSequenceSplitDataset
3 | from spirl.data.peg_in_hole.src.peg_in_hole_data_loader import PegInHoleSequenceSplitDataset
4 | 
5 | data_spec = AttrDict(
6 |     dataset_class=PegInHoleSequenceSplitDataset,
7 |     n_actions=8,
8 |     state_dim=19,
9 |     env_name="peg-in-hole-v0",
10 |     res=128,
11 |     crop_rand_subseq=True,
12 | )
13 | data_spec.max_seq_len = 280
--------------------------------------------------------------------------------
/spirl/utils/transformations.py:
--------------------------------------------------------------------------------
1 | from scipy.spatial.transform import Rotation
2 | 
3 | 
4 | if __name__ == "__main__":
5 |     # Demo: round-trip Euler angles -> quaternion -> Euler angles.
6 |     # Create a rotation object from Euler angles specifying axes of rotation.
7 |     rot = Rotation.from_euler('xyz', [0.35, 0, 0.23], degrees=False)
8 | 
9 |     # Convert to a quaternion and print.
10 |     rot_quat = rot.as_quat()
11 |     print(rot_quat)
12 | 
13 |     # Recover the rotation from the quaternion and convert back to Euler angles.
14 |     rot = Rotation.from_quat(rot_quat)
15 |     print(rot.as_euler('xyz', degrees=False))
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/grippers/pushing_gripper.py:
--------------------------------------------------------------------------------
1 | """
2 | A version of TwoFingerGripper but always closed.
3 | """
4 | import numpy as np
5 | from spirl.data.block_stacking.src.robosuite.models.grippers.two_finger_gripper import TwoFingerGripper
6 | 
7 | 
8 | class PushingGripper(TwoFingerGripper):
9 |     """
10 |     Same as TwoFingerGripper, but always closed
11 |     """
12 | 
13 |     def format_action(self, action):
14 |         return np.array([1, -1])
15 | 
16 |     @property
17 |     def dof(self):
18 |         return 1
--------------------------------------------------------------------------------
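PushingGripper above illustrates the override contract: format_action() maps the policy's gripper command to the two finger actuators, and dof reports the controllable dimension. A mirror-image sketch, a hypothetical always-open variant (not part of the repository):

    import numpy as np
    from spirl.data.block_stacking.src.robosuite.models.grippers.two_finger_gripper import TwoFingerGripper


    class OpenGripper(TwoFingerGripper):
        """Same as TwoFingerGripper, but always open (hypothetical)."""

        def format_action(self, action):
            return np.array([-1, 1])  # mirrors PushingGripper's always-closed command

        @property
        def dof(self):
            return 1
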
8 | """ 9 | 10 | @abc.abstractmethod 11 | def start_control(self): 12 | """ 13 | Method that should be called externally before controller can 14 | start receiving commands. 15 | """ 16 | raise NotImplementedError 17 | 18 | @abc.abstractmethod 19 | def get_controller_state(self): 20 | """Returns the current state of the device, a dictionary of pos, orn, grasp, and reset.""" 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /spirl/configs/hrl/block_stacking/spirl/conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.hrl.block_stacking.base_conf import * 2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy 3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent 4 | 5 | 6 | # add prior model to policy 7 | hl_policy_params.update(AttrDict( 8 | prior_model=ll_agent_config.model, 9 | prior_model_params=ll_agent_config.model_params, 10 | prior_model_checkpoint=ll_agent_config.model_checkpoint, 11 | )) 12 | hl_agent_config.policy = ACLearnedPriorAugmentedPIPolicy 13 | 14 | # update agent + set target divergence 15 | agent_config.hl_agent = ActionPriorSACAgent 16 | agent_config.hl_agent_params.update(AttrDict( 17 | td_schedule_params=AttrDict(p=5.), 18 | )) 19 | 20 | -------------------------------------------------------------------------------- /spirl/configs/hrl/kitchen/spirl/conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.hrl.kitchen.base_conf import * 2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy 3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent 4 | 5 | 6 | # update policy to use prior model for computing divergence 7 | hl_policy_params.update(AttrDict( 8 | prior_model=ll_agent_config.model, 9 | prior_model_params=ll_agent_config.model_params, 10 | prior_model_checkpoint=ll_agent_config.model_checkpoint, 11 | )) 12 | hl_agent_config.policy = LearnedPriorAugmentedPIPolicy 13 | 14 | # update agent, set target divergence 15 | agent_config.hl_agent = ActionPriorSACAgent 16 | agent_config.hl_agent_params.update(AttrDict( 17 | td_schedule_params=AttrDict(p=5.), 18 | )) 19 | -------------------------------------------------------------------------------- /spirl/configs/hrl/peg_in_hole/spirl/conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.hrl.peg_in_hole.base_conf import * 2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy 3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent 4 | 5 | 6 | # update policy to use prior model for computing divergence 7 | hl_policy_params.update(AttrDict( 8 | prior_model=ll_agent_config.model, 9 | prior_model_params=ll_agent_config.model_params, 10 | prior_model_checkpoint=ll_agent_config.model_checkpoint, 11 | )) 12 | hl_agent_config.policy = LearnedPriorAugmentedPIPolicy 13 | 14 | # update agent, set target divergence 15 | agent_config.hl_agent = ActionPriorSACAgent 16 | agent_config.hl_agent_params.update(AttrDict( 17 | td_schedule_params=AttrDict(p=5.), 18 | )) 19 | -------------------------------------------------------------------------------- /spirl/configs/hrl/maze/spirl/conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.hrl.maze.base_conf import * 2 | from spirl.rl.policies.prior_policies import 
/spirl/configs/hrl/block_stacking/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.block_stacking.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 | 
5 | 
6 | # add prior model to policy
7 | hl_policy_params.update(AttrDict(
8 |     prior_model=ll_agent_config.model,
9 |     prior_model_params=ll_agent_config.model_params,
10 |     prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = ACLearnedPriorAugmentedPIPolicy
13 | 
14 | # update agent + set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 |     td_schedule_params=AttrDict(p=5.),
18 | ))
--------------------------------------------------------------------------------
/spirl/configs/hrl/kitchen/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.kitchen.base_conf import *
2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 | 
5 | 
6 | # update policy to use prior model for computing divergence
7 | hl_policy_params.update(AttrDict(
8 |     prior_model=ll_agent_config.model,
9 |     prior_model_params=ll_agent_config.model_params,
10 |     prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = LearnedPriorAugmentedPIPolicy
13 | 
14 | # update agent, set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 |     td_schedule_params=AttrDict(p=5.),
18 | ))
--------------------------------------------------------------------------------
/spirl/configs/hrl/peg_in_hole/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.peg_in_hole.base_conf import *
2 | from spirl.rl.policies.prior_policies import LearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 | 
5 | 
6 | # update policy to use prior model for computing divergence
7 | hl_policy_params.update(AttrDict(
8 |     prior_model=ll_agent_config.model,
9 |     prior_model_params=ll_agent_config.model_params,
10 |     prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = LearnedPriorAugmentedPIPolicy
13 | 
14 | # update agent, set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 |     td_schedule_params=AttrDict(p=5.),
18 | ))
--------------------------------------------------------------------------------
/spirl/configs/hrl/maze/spirl/conf.py:
--------------------------------------------------------------------------------
1 | from spirl.configs.hrl.maze.base_conf import *
2 | from spirl.rl.policies.prior_policies import ACLearnedPriorAugmentedPIPolicy
3 | from spirl.rl.agents.prior_sac_agent import ActionPriorSACAgent
4 | 
5 | 
6 | # update policy to use learned prior
7 | hl_policy_params.update(AttrDict(
8 |     prior_model=ll_agent_config.model,
9 |     prior_model_params=ll_agent_config.model_params,
10 |     prior_model_checkpoint=ll_agent_config.model_checkpoint,
11 | ))
12 | hl_agent_config.policy = ACLearnedPriorAugmentedPIPolicy
13 | 
14 | # update agent to regularize with prior, set target divergence
15 | agent_config.hl_agent = ActionPriorSACAgent
16 | agent_config.hl_agent_params.update(AttrDict(
17 |     td_schedule_params=AttrDict(p=1.),
18 | ))
--------------------------------------------------------------------------------
/spirl/data/block_stacking/src/robosuite/models/assets/objects/can-visual.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
14 | 
15 | 
16 | 
--------------------------------------------------------------------------------
/spirl/configs/skill_prior_learning/kitchen/flat/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from spirl.models.bc_mdl import BCMdl
4 | from spirl.components.logger import Logger
5 | from spirl.utils.general_utils import AttrDict
6 | from spirl.configs.default_data_configs.kitchen import data_spec
7 | from spirl.components.evaluator import DummyEvaluator
8 | 
9 | 
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 | 
12 | 
13 | configuration = {
14 |     'model': BCMdl,
15 |     'logger': Logger,
16 |     'data_dir': '.',
17 |     'epoch_cycles_train': 10,
18 |     'evaluator': DummyEvaluator,
19 | }
20 | configuration = AttrDict(configuration)
21 | 
22 | model_config = AttrDict(
23 |     state_dim=data_spec.state_dim,
24 |     action_dim=data_spec.n_actions,
25 | )
26 | 
27 | # Dataset
28 | data_config = AttrDict()
29 | data_config.dataset_spec = data_spec
30 | data_config.dataset_spec.subseq_len = 1 + 1  # flat last action from seq gets cropped
--------------------------------------------------------------------------------
4 | 5 | Example: 6 | $ python browse_arena_model.py --filepath ../models/assets/arenas/table_arena.xml 7 | """ 8 | 9 | import argparse 10 | import os 11 | 12 | from mujoco_py import load_model_from_path 13 | from mujoco_py import MjSim, MjViewer 14 | 15 | from spirl.data.block_stacking.src import robosuite 16 | 17 | if __name__ == "__main__": 18 | 19 | arena_file = os.path.join(robosuite.models.assets_root, "arenas/pegs_arena.xml") 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--filepath", type=str, default=arena_file) 23 | args = parser.parse_args() 24 | 25 | model = load_model_from_path(args.filepath) 26 | sim = MjSim(model) 27 | viewer = MjViewer(sim) 28 | 29 | print("Press ESC to exit...") 30 | while True: 31 | sim.step() 32 | viewer.render() 33 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/__init__.py: -------------------------------------------------------------------------------- 1 | #import os 2 | 3 | #from spirl.data.block_stacking.src.robosuite.environments.base import make 4 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_lift import SawyerLift 5 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_stack import SawyerStack 6 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_pick_place import SawyerPickPlace 7 | #from spirl.data.block_stacking.src.robosuite.environments.sawyer_nut_assembly import SawyerNutAssembly 8 | 9 | #from spirl.data.block_stacking.src.robosuite.environments.baxter_lift import BaxterLift 10 | #from spirl.data.block_stacking.src.robosuite.environments.baxter_peg_in_hole import BaxterPegInHole 11 | #from spirl.data.block_stacking.src.robosuite.environments.baxter_modified import BaxterChange 12 | 13 | __version__ = "0.3.0" 14 | __logo__ = """ 15 | ; / ,--. 
16 | ["] ["] ,< |__**| 17 | /[_]\ [~]\/ |// | 18 | ] [ OOO /o|__| 19 | """ 20 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/milk-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/cereal-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/maze/flat/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.bc_mdl import ImageBCMdl 4 | from spirl.utils.general_utils import AttrDict 5 | from spirl.configs.default_data_configs.maze import data_spec 6 | from spirl.components.evaluator import DummyEvaluator 7 | from spirl.components.logger import Logger 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': ImageBCMdl, 15 | 'logger': Logger, 16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze'), 17 | 'epoch_cycles_train': 4, 18 | 'evaluator': DummyEvaluator, 19 | } 20 | configuration = AttrDict(configuration) 21 | 22 | model_config = AttrDict( 23 | state_dim=data_spec.state_dim, 24 | action_dim=data_spec.n_actions, 25 | input_res=data_spec.res, 26 | n_input_frames=2, 27 | ) 28 | 29 | # Dataset 30 | data_config = AttrDict() 31 | data_config.dataset_spec = data_spec 32 | data_config.dataset_spec.subseq_len = 1 + 1 + (model_config.n_input_frames - 1) 33 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/bread-visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/block_stacking/flat/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.bc_mdl import ImageBCMdl 4 | from spirl.utils.general_utils import AttrDict 5 | from spirl.configs.default_data_configs.block_stacking import data_spec 6 | from spirl.components.evaluator import DummyEvaluator 7 | from spirl.components.logger import Logger 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': ImageBCMdl, 15 | 'logger': Logger, 16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'), 17 | 'epoch_cycles_train': 4, 18 | 'evaluator': DummyEvaluator, 19 | } 20 | configuration = AttrDict(configuration) 21 | 22 | model_config = AttrDict( 23 | state_dim=data_spec.state_dim, 24 | action_dim=data_spec.n_actions, 25 | input_res=data_spec.res, 26 | n_input_frames=2, 27 | ) 28 | 29 | # Dataset 30 | data_config = AttrDict() 31 | data_config.dataset_spec = data_spec 32 | data_config.dataset_spec.subseq_len = 1 + 1 + (model_config.n_input_frames - 1) 33 | -------------------------------------------------------------------------------- 
/spirl/data/peg_in_hole/src/generate_peg_in_hole_hdf5.py: -------------------------------------------------------------------------------- 1 | """Script for generating the datasets for peg-in-hole tasks.""" 2 | import h5py 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import pickle 7 | 8 | def main(): 9 | dataset_df = pd.read_csv('peg_in_hole_dataset.csv', header=None) 10 | #print(action_df) 11 | #action_df.to_hdf('action.hdf5', key='action', mode='w') 12 | 13 | with h5py.File('peg_in_hole_dataset.hdf5', 'w') as hdf: 14 | # create state group 15 | #group1 = hdf.create_group('observations') 16 | hdf.create_dataset('observations', data=dataset_df.iloc[:,0:19]) 17 | 18 | # create action group 19 | #group2 = hdf.create_group('actions') 20 | hdf.create_dataset('actions', data=dataset_df.iloc[:,19:27]) 21 | 22 | # create reward group 23 | hdf.create_dataset('rewards', data=dataset_df.iloc[:,27]) 24 | 25 | #group3 = hdf.create_group('terminals') 26 | hdf.create_dataset('terminals', data=dataset_df.iloc[:,28]) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/bottle.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/robots/sawyer_robot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spirl.data.block_stacking.src.robosuite.models.robots.robot import Robot 3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion, array_to_string 4 | 5 | 6 | class Sawyer(Robot): 7 | """Sawyer is a witty single-arm robot designed by Rethink Robotics.""" 8 | 9 | def __init__(self): 10 | super().__init__(xml_path_completion("robots/sawyer/robot.xml")) 11 | 12 | self.bottom_offset = np.array([0, 0, -0.913]) 13 | 14 | def set_base_xpos(self, pos): 15 | """Places the robot on position @pos.""" 16 | node = self.worldbody.find("./body[@name='base']") 17 | node.set("pos", array_to_string(pos - self.bottom_offset)) 18 | 19 | @property 20 | def dof(self): 21 | return 7 22 | 23 | @property 24 | def joints(self): 25 | return ["right_j{}".format(x) for x in range(7)] 26 | 27 | @property 28 | def init_qpos(self): 29 | return np.array([0, -1.18, 0.00, 2.18, 0.00, 0.57, 3.3161]) 30 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/utils/gripper_only_robot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from spirl.data.block_stacking.src.robosuite.models.robots.robot import Robot 4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string 5 | 6 | 7 | class GripperOnlyRobot(Robot): 8 | """Floating gripper-only robot (three translational joints plus a z-rotation) used for the block stacking environment.""" 9 | 10 | def __init__(self): 11 | super().__init__(os.path.join(os.getcwd(), "spirl/data/block_stacking/assets/gripper_only_robot.xml")) 12 | 13 | self.bottom_offset = np.array([0, 0, 0]) 14 | 15 | def set_base_xpos(self, pos): 16 | """Places the robot on position @pos.""" 17 | node = self.worldbody.find("./body[@name='base']") 18 | node.set("pos", array_to_string(pos - self.bottom_offset)) 19 | 20 | @property
21 | def dof(self): 22 | return 4 23 | 24 | @property 25 | def joints(self): 26 | return ["slide_x", "slide_y", "slide_z", "rotate_z"] 27 | 28 | @property 29 | def init_qpos(self): 30 | return np.array([0, 0, 1.2, 0]) 31 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/compile_mjcf_model.py: -------------------------------------------------------------------------------- 1 | """Loads a raw mjcf file and saves a compiled mjcf file. 2 | 3 | This prevents mujoco-py from complaining about the .urdf extension. 4 | Also allows assets to be compiled properly. 5 | 6 | Example: 7 | $ python compile_mjcf_model.py source_mjcf.xml target_mjcf.xml 8 | """ 9 | 10 | import os 11 | import sys 12 | from shutil import copyfile 13 | from mujoco_py import load_model_from_path 14 | 15 | 16 | def print_usage(): 17 | print("""python compile_mjcf_model.py input_file output_file""") 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | if len(sys.argv) != 3: 23 | print_usage() 24 | exit(0) 25 | 26 | input_file = sys.argv[1] 27 | output_file = sys.argv[2] 28 | input_folder = os.path.dirname(input_file) 29 | 30 | tempfile = os.path.join(input_folder, ".surreal_temp_model.xml") 31 | copyfile(input_file, tempfile) 32 | 33 | model = load_model_from_path(tempfile) 34 | xml_string = model.get_xml() 35 | with open(output_file, "w") as f: 36 | f.write(xml_string) 37 | 38 | os.remove(tempfile) 39 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/assets/gripper_only_robot.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/controllers/controller.py: -------------------------------------------------------------------------------- 1 | import abc # for abstract base class definitions 2 | 3 | 4 | class Controller(metaclass=abc.ABCMeta): 5 | """ 6 | Base class for all robot controllers. 7 | Defines basic interface for all controllers to adhere to. 8 | """ 9 | 10 | def __init__(self, bullet_data_path, robot_jpos_getter): 11 | """ 12 | Args: 13 | bullet_data_path (str): base path to bullet data. 14 | 15 | robot_jpos_getter (function): function that returns the position of the joints 16 | as a numpy array of the right dimension. 17 | """ 18 | raise NotImplementedError 19 | 20 | @abc.abstractmethod 21 | def get_control(self, *args, **kwargs): 22 | """ 23 | Retrieve a control input from the controller. 24 | """ 25 | raise NotImplementedError 26 | 27 | @abc.abstractmethod 28 | def sync_state(self): 29 | """ 30 | This function does internal bookkeeping to maintain 31 | consistency between the robot being controlled and 32 | the controller state.
33 | """ 34 | raise NotImplementedError 35 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/can.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/arenas/empty_arena.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /spirl/plotter/plot_trajectories.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import seaborn as sns 5 | 6 | def plot_trajectory(prior_seqs, policy_seqs): 7 | ax = plt.axes(projection='3d') 8 | 9 | # plot prior action sequence 10 | x = prior_seqs[:,0] 11 | y = prior_seqs[:,1] 12 | z = prior_seqs[:,2] 13 | ax.plot3D(x, y, z, 'gray') 14 | ax.scatter3D(x, y, z) 15 | 16 | # plot policy action sequence 17 | x = policy_seqs[:,0] 18 | y = policy_seqs[:,1] 19 | z = policy_seqs[:,2] 20 | ax.plot3D(x, y, z, 'blue') 21 | ax.scatter3D(x, y, z) 22 | 23 | ax.set_xlabel('x') 24 | ax.set_ylabel('y') 25 | ax.set_zlabel('z') 26 | ax.set_xlim([0.1, 0.2]) 27 | ax.set_ylim([0.2, 0.3]) 28 | ax.set_zlim([0.0, 0.15]) 29 | 30 | plt.savefig('fig/'+'traj.png') # save before show(), which clears the current figure 31 | plt.show() 32 | 33 | if __name__ == '__main__': 34 | traj_df = pd.read_csv('~/Workspaces/rl_ws/spirl/trajectories2.csv', header=None) 35 | policy_traj_array = traj_df.iloc[:,0:3].values.astype(np.float32)#.reshape(-1,120) 36 | prior_traj_array = traj_df.iloc[:,3:6].values.astype(np.float32)#.reshape(-1,120) 37 | plot_trajectory(prior_traj_array, policy_traj_array) -------------------------------------------------------------------------------- /spirl/rl/policies/basic_policies.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from spirl.modules.variational_inference import MultivariateGaussian 4 | from spirl.rl.components.policy import Policy 5 | from spirl.utils.general_utils import ParamDict 6 | 7 | 8 | class UniformGaussianPolicy(Policy): 9 | """Samples actions from a uniform Gaussian.""" 10 | def __init__(self, config): 11 | self._hp = self._default_hparams().overwrite(config) 12 | super().__init__() 13 | 14 | def _default_hparams(self): 15 | default_dict = ParamDict({ 16 | 'scale': 1, # scale of Uniform Gaussian 17 | }) 18 | return super()._default_hparams().overwrite(default_dict) 19 | 20 | def _build_network(self): 21 | return torch.nn.Module() # dummy module 22 | 23 | def _compute_action_dist(self, obs): 24 | batch_size = obs.shape[0] 25 | return MultivariateGaussian(mu=torch.zeros((batch_size, self._hp.action_dim), device=obs.device), 26 | log_sigma=torch.log(self._hp.scale * 27 | torch.ones((batch_size, self._hp.action_dim), device=obs.device))) 28 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/maze/hierarchical/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.skill_prior_mdl import ImageSkillPriorMdl, SkillSpaceLogger 4 | from spirl.utils.general_utils import AttrDict
5 | from spirl.configs.default_data_configs.maze import data_spec 6 | from spirl.components.evaluator import TopOfNSequenceEvaluator 7 | 8 | 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | configuration = { 13 | 'model': ImageSkillPriorMdl, 14 | 'logger': SkillSpaceLogger, 15 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze'), 16 | 'epoch_cycles_train': 10, 17 | 'evaluator': TopOfNSequenceEvaluator, 18 | 'top_of_n_eval': 100, 19 | 'top_comp_metric': 'mse', 20 | } 21 | configuration = AttrDict(configuration) 22 | 23 | model_config = AttrDict( 24 | state_dim=data_spec.state_dim, 25 | action_dim=data_spec.n_actions, 26 | n_rollout_steps=10, 27 | kl_div_weight=1e-2, 28 | prior_input_res=data_spec.res, 29 | n_input_frames=2, 30 | ) 31 | 32 | # Dataset 33 | data_config = AttrDict() 34 | data_config.dataset_spec = data_spec 35 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames 36 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/kitchen/hierarchical_cl/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 4 | from spirl.components.logger import Logger 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.kitchen import data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | configuration = { 13 | 'model': ClSPiRLMdl, 14 | 'logger': Logger, 15 | 'data_dir': '.', 16 | 'epoch_cycles_train': 10, 17 | 'evaluator': TopOfNSequenceEvaluator, 18 | 'top_of_n_eval': 100, 19 | 'top_comp_metric': 'mse', 20 | } 21 | configuration = AttrDict(configuration) 22 | 23 | model_config = AttrDict( 24 | state_dim=data_spec.state_dim, 25 | action_dim=data_spec.n_actions, 26 | n_rollout_steps=10, 27 | kl_div_weight=5e-4, 28 | nz_enc=128, 29 | nz_mid=128, 30 | n_processing_layers=5, 31 | cond_decode=True, 32 | ) 33 | 34 | # Dataset 35 | data_config = AttrDict() 36 | data_config.dataset_spec = data_spec 37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # flat last action from seq gets cropped 38 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/peg_in_hole/hierarchical_cl/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 4 | from spirl.components.logger import Logger 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.peg_in_hole import data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | configuration = { 13 | 'model': ClSPiRLMdl, 14 | 'logger': Logger, 15 | 'data_dir': '.', 16 | 'epoch_cycles_train': 10, 17 | 'evaluator': TopOfNSequenceEvaluator, 18 | 'top_of_n_eval': 100, 19 | 'top_comp_metric': 'mse', 20 | } 21 | configuration = AttrDict(configuration) 22 | 23 | model_config = AttrDict( 24 | state_dim=data_spec.state_dim, 25 | action_dim=data_spec.n_actions, 26 | n_rollout_steps=10, 27 | kl_div_weight=5e-4, 28 | nz_enc=128, 29 | nz_mid=128, 30 | n_processing_layers=5, 31 | cond_decode=True, 32 | ) 33 | 34 | # Dataset 35 | data_config = AttrDict() 36 | 
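# the loader below needs n_rollout_steps + 1 steps per sampled subsequence:
# the decoder is unrolled for n_rollout_steps actions and the final action of
# each sequence gets cropped (see the trailing comment on subseq_len)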
data_config.dataset_spec = data_spec 37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # flat last action from seq gets cropped 38 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/plate-with-hole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/bread.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/lemon.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/milk.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/kitchen/hierarchical/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.skill_prior_mdl import SkillPriorMdl 4 | from spirl.components.logger import Logger 5 | from spirl.models.skill_prior_mdl import SkillSpaceLogger 6 | from spirl.utils.general_utils import AttrDict 7 | from spirl.configs.default_data_configs.kitchen import data_spec 8 | from spirl.components.evaluator import TopOfNSequenceEvaluator 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': SkillPriorMdl, 15 | 'logger': SkillSpaceLogger, 16 | 'data_dir': '.', 17 | 'epoch_cycles_train': 10, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=5e-4, 29 | nz_enc=128, 30 | nz_mid=128, 31 | n_processing_layers=5, 32 | ) 33 | 34 | # Dataset 35 | data_config = AttrDict() 36 | data_config.dataset_spec = data_spec 37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1 # flat last action from seq gets cropped 38 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/maze/hierarchical_cl/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.skill_prior_mdl import SkillSpaceLogger 4 | from spirl.models.closed_loop_spirl_mdl import ImageClSPiRLMdl 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.maze import data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration 
= { 14 | 'model': ImageClSPiRLMdl, 15 | 'logger': SkillSpaceLogger, 16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'maze'), 17 | 'epoch_cycles_train': 10, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=1e-2, 29 | prior_input_res=data_spec.res, 30 | n_input_frames=2, 31 | cond_decode=True, 32 | ) 33 | 34 | # Dataset 35 | data_config = AttrDict() 36 | data_config.dataset_spec = data_spec 37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames 38 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/cereal.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/demo_gen/block_stacking_demo_agent.py: -------------------------------------------------------------------------------- 1 | from spirl.utils.general_utils import AttrDict, ParamDict 2 | from spirl.rl.components.agent import BaseAgent 3 | 4 | from spirl.data.block_stacking.src.demo_gen.block_demo_policy import ClosedLoopBlockStackDemoPolicy 5 | 6 | 7 | class BlockStackingDemoAgent(BaseAgent): 8 | """Wraps demo policy for block stacking.""" 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self._policy = self._hp.policy(self._hp.env_params) 12 | 13 | def _default_hparams(self): 14 | default_dict = ParamDict({ 15 | 'policy': ClosedLoopBlockStackDemoPolicy, # policy class 16 | 'env_params': None, # parameters containing info about env -> set automatically 17 | }) 18 | return super()._default_hparams().overwrite(default_dict) 19 | 20 | @property 21 | def rollout_valid(self): 22 | return self._hp.env_params.task_complete_check() 23 | 24 | def reset(self): 25 | self._policy.reset() 26 | 27 | def _act(self, obs): 28 | return AttrDict(action=self._policy.act(obs)) 29 | 30 | def _act_rand(self, obs): 31 | raise NotImplementedError("This should not be called in the demo agent.") 32 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/block_stacking_data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.utils.general_utils import AttrDict 4 | from spirl.components.data_loader import GlobalSplitVideoDataset, GlobalSplitStateSequenceDataset 5 | 6 | 7 | class BlockStackSequenceDataset(GlobalSplitVideoDataset): 8 | """Adds info about env idx from file path.""" 9 | def _get_aux_info(self, data, path): 10 | # extract env name from file path 11 | # TODO: design an env id system for block stacking envs 12 | return AttrDict(env_id=0) 13 | 14 | def __getitem__(self, index): 15 | data = super().__getitem__(index) 16 | for key in data.keys(): 17 | if key.endswith('states') and data[key].shape[-1] == 40: 18 | # remove quaternion dimensions 19 | data[key] = data[key][:, :20] 20 | elif key.endswith('states') and data[key].shape[-1] == 43: 21 | data[key] = data[key][:, :23] 22 | if key.endswith('actions') and data[key].shape[-1] == 4: 23 | # remove rotation dimension 24 |
data[key] = data[key][:, [0, 1, 3]] 25 | return data 26 | 27 | 28 | class BlockStackStateSequenceDataset(BlockStackSequenceDataset, GlobalSplitStateSequenceDataset): 29 | pass 30 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/grippers/gripper_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a string based method of initializing grippers 3 | """ 4 | from .two_finger_gripper import TwoFingerGripper, LeftTwoFingerGripper 5 | from .pr2_gripper import PR2Gripper 6 | from .robotiq_gripper import RobotiqGripper 7 | from .pushing_gripper import PushingGripper 8 | from .robotiq_three_finger_gripper import RobotiqThreeFingerGripper 9 | 10 | 11 | def gripper_factory(name): 12 | """ 13 | Generator for grippers 14 | 15 | Creates a Gripper instance with the provided name. 16 | 17 | Args: 18 | name: the name of the gripper class 19 | 20 | Returns: 21 | gripper: Gripper instance 22 | 23 | Raises: 24 | ValueError: if an unknown gripper name is given 25 | """ 26 | if name == "TwoFingerGripper": 27 | return TwoFingerGripper() 28 | if name == "LeftTwoFingerGripper": 29 | return LeftTwoFingerGripper() 30 | if name == "PR2Gripper": 31 | return PR2Gripper() 32 | if name == "RobotiqGripper": 33 | return RobotiqGripper() 34 | if name == "PushingGripper": 35 | return PushingGripper() 36 | if name == "RobotiqThreeFingerGripper": 37 | return RobotiqThreeFingerGripper() 38 | raise ValueError("Unknown gripper name {}".format(name)) 39 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.skill_prior_mdl import SkillSpaceLogger 4 | from spirl.models.closed_loop_spirl_mdl import ImageClSPiRLMdl 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.block_stacking import data_spec 7 | from spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': ImageClSPiRLMdl, 15 | 'logger': SkillSpaceLogger, 16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'), 17 | 'epoch_cycles_train': 10, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=1e-4, 29 | prior_input_res=data_spec.res, 30 | n_input_frames=2, 31 | cond_decode=True, 32 | ) 33 | 34 | # Dataset 35 | data_config = AttrDict() 36 | data_config.dataset_spec = data_spec 37 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames 38 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/block_stacking/hierarchical/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.models.skill_prior_mdl import ImageSkillPriorMdl 4 | from spirl.data.block_stacking.src.block_stacking_logger import SkillSpaceBlockStackLogger 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.configs.default_data_configs.block_stacking import data_spec 7 | from
spirl.components.evaluator import TopOfNSequenceEvaluator 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | configuration = { 14 | 'model': ImageSkillPriorMdl, 15 | 'logger': SkillSpaceBlockStackLogger, 16 | 'data_dir': os.path.join(os.environ['DATA_DIR'], 'block_stacking'), 17 | 'epoch_cycles_train': 10, 18 | 'evaluator': TopOfNSequenceEvaluator, 19 | 'top_of_n_eval': 100, 20 | 'top_comp_metric': 'mse', 21 | } 22 | configuration = AttrDict(configuration) 23 | 24 | model_config = AttrDict( 25 | state_dim=data_spec.state_dim, 26 | action_dim=data_spec.n_actions, 27 | n_rollout_steps=10, 28 | kl_div_weight=1e-2, 29 | prior_input_res=data_spec.res, 30 | n_input_frames=2, 31 | ) 32 | 33 | # Dataset 34 | data_config = AttrDict() 35 | data_config.dataset_spec = data_spec 36 | data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + model_config.n_input_frames 37 | -------------------------------------------------------------------------------- /spirl/configs/rl/block_stacking/prior_initialized/base_conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.rl.block_stacking.base_conf import * 2 | from spirl.rl.components.sampler import ACMultiImageAugmentedSampler 3 | from spirl.rl.policies.mlp_policies import ConvPolicy 4 | from spirl.rl.components.critic import SplitObsMLPCritic 5 | from spirl.models.bc_mdl import ImageBCMdl 6 | 7 | 8 | # update sampler 9 | configuration['sampler'] = ACMultiImageAugmentedSampler 10 | sampler_config = AttrDict( 11 | n_frames=2, 12 | ) 13 | 14 | # update policy to conv policy 15 | agent_config.policy = ConvPolicy 16 | policy_params.update(AttrDict( 17 | input_nc=3 * sampler_config.n_frames, 18 | prior_model=ImageBCMdl, 19 | prior_model_params=AttrDict(state_dim=data_spec.state_dim, 20 | action_dim=data_spec.n_actions, 21 | input_res=data_spec.res, 22 | n_input_frames=2, 23 | ), 24 | prior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], 25 | "skill_prior_learning/block_stacking/flat"), 26 | )) 27 | 28 | # update critic+policy to handle multi-frame combined observation 29 | agent_config.critic = SplitObsMLPCritic 30 | agent_config.critic_params.unused_obs_size = 32**2*3 * sampler_config.n_frames 31 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/arenas/arena.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML 4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array 5 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import new_geom, new_body, new_joint 6 | 7 | 8 | class Arena(MujocoXML): 9 | """Base arena class.""" 10 | 11 | def set_origin(self, offset): 12 | """Applies a constant offset to all objects.""" 13 | offset = np.array(offset) 14 | for node in self.worldbody.findall("./*[@pos]"): 15 | cur_pos = string_to_array(node.get("pos")) 16 | new_pos = cur_pos + offset 17 | node.set("pos", array_to_string(new_pos)) 18 | 19 | def add_pos_indicator(self): 20 | """Adds a new position indicator.""" 21 | body = new_body(name="pos_indicator") 22 | body.append( 23 | new_geom( 24 | "sphere", 25 | [0.03], 26 | rgba=[1, 0, 0, 0.5], 27 | group=1, 28 | contype="0", 29 | conaffinity="0", 30 | ) 31 | ) 32 | body.append(new_joint(type="free", name="pos_indicator")) 33 | 
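# a free joint gives the indicator body its own six degrees of freedom, so it
# can be repositioned anywhere in the world without being welded to other bodies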
self.worldbody.append(body) 34 | -------------------------------------------------------------------------------- /spirl/configs/data_collect/block_stacking/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.utils.general_utils import AttrDict 4 | 5 | from spirl.data.block_stacking.src.demo_gen.block_stacking_demo_agent import BlockStackingDemoAgent 6 | from spirl.data.block_stacking.src.block_stacking_env import BlockStackEnv 7 | from spirl.data.block_stacking.src.block_task_generator import FixedSizeSingleTowerBlockTaskGenerator 8 | 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | notes = 'used for generating block stacking dataset' 13 | SEED = 31 14 | 15 | configuration = { 16 | 'seed': SEED, 17 | 'agent': BlockStackingDemoAgent, 18 | 'environment': BlockStackEnv, 19 | 'max_rollout_len': 250, 20 | } 21 | configuration = AttrDict(configuration) 22 | 23 | # Task 24 | task_params = AttrDict( 25 | max_tower_height=4, 26 | seed=SEED, 27 | ) 28 | 29 | # Agent 30 | agent_config = AttrDict( 31 | 32 | ) 33 | 34 | # Dataset - Random data 35 | data_config = AttrDict( 36 | 37 | ) 38 | 39 | # Environment 40 | env_config = AttrDict( 41 | task_generator=FixedSizeSingleTowerBlockTaskGenerator, 42 | task_params=task_params, 43 | dimension=2, 44 | n_steps=2, 45 | screen_width=32, 46 | screen_height=32, 47 | rand_task=True, 48 | rand_init_pos=True, 49 | camera_name='agentview', 50 | ) 51 | 52 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/utils/wide_range_gripper.py: -------------------------------------------------------------------------------- 1 | from spirl.data.block_stacking.src.robosuite.models.grippers.two_finger_gripper import TwoFingerGripper 2 | 3 | 4 | class WideRangeGripper(TwoFingerGripper): 5 | def __init__(self): 6 | super().__init__() 7 | l_gripper = self.worldbody.find(".//joint[@name='r_gripper_l_finger_joint']") 8 | l_gripper.set("range", "-0.020833 0.08") 9 | l_actuator = self.actuator.find(".//position[@joint='r_gripper_l_finger_joint']") 10 | l_actuator.set("ctrlrange", "-0.020833 0.08") 11 | l_actuator.set("gear", "2") 12 | r_gripper = self.worldbody.find(".//joint[@name='r_gripper_r_finger_joint']") 13 | r_gripper.set("range", "-0.08 0.020833") 14 | r_actuator = self.actuator.find(".//position[@joint='r_gripper_r_finger_joint']") 15 | r_actuator.set("ctrlrange", "-0.08 0.020833") 16 | r_actuator.set("gear", "2") 17 | 18 | for geom_name in ['l_finger_g1', 'r_finger_g1']: 19 | geom = self.worldbody.find(".//geom[@name='{}']".format(geom_name)) 20 | geom.set("conaffinity", "0") 21 | 22 | for geom_name in ['l_fingertip_g0', 'r_fingertip_g0']: 23 | geom = self.worldbody.find(".//geom[@name='{}']".format(geom_name)) 24 | geom.set("friction", "30 0.005 0.0001") 25 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/kitchen/hierarchical_cl/README.md: -------------------------------------------------------------------------------- 1 | # SPiRL w/ Closed-Loop Skill Decoder 2 | 3 | This version of the SPiRL model uses a [closed-loop action decoder](../../../../models/closed_loop_spirl_mdl.py): 4 | in contrast to the original SPiRL model it takes the current environment state as input in every skill decoding step. 
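In code, the difference is only in what the action decoder sees at each decoding step. The snippet below is a minimal sketch for illustration, not the repository's actual `ClSPiRLMdl` implementation (the class name `ClosedLoopDecoderSketch` and its layer sizes are made up here):

```
import torch
import torch.nn as nn

class ClosedLoopDecoderSketch(nn.Module):
    """Sketch: decode action a_t from (s_t, z) instead of from z alone."""
    def __init__(self, state_dim, action_dim, nz_vae, nz_mid=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + nz_vae, nz_mid), nn.ReLU(),
            nn.Linear(nz_mid, action_dim),
        )

    def forward(self, state, z):
        # conditioning every step on the current state lets the skill react
        # to perturbations that occur while it is being executed
        return self.net(torch.cat([state, z], dim=-1))
```

This per-step state input is what the `cond_decode=True` flag in the `hierarchical_cl` model configs switches on.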
5 | 6 | We find that this model improves performance over the original 7 | SPiRL model, particularly on tasks that require precise control, like in the kitchen environment. 8 | 9 | 10 | 11 | 12 | 13 | 14 | For an implementation of the closed-loop SPiRL model that supports image observations, 15 | see [here](../../block_stacking/hierarchical_cl/README.md). 16 | 17 | ## Example Commands 18 | 19 | To train the SPiRL model with closed-loop action decoder on the kitchen environment, run the following command: 20 | ``` 21 | python3 spirl/train.py --path=spirl/configs/skill_prior_learning/kitchen/hierarchical_cl --val_data_size=160 22 | ``` 23 | 24 | To train a downstream task policy with RL using the closed-loop SPiRL model we just trained, run the following command: 25 | ``` 26 | python3 spirl/rl/train.py --path=spirl/configs/hrl/kitchen/spirl_cl --seed=0 --prefix=SPIRLv2_kitchen_seed0 27 | ``` 28 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/tasks/task.py: -------------------------------------------------------------------------------- 1 | from spirl.data.block_stacking.src.robosuite.models.world import MujocoWorldBase 2 | 3 | 4 | class Task(MujocoWorldBase): 5 | """ 6 | Base class for creating the MJCF model of a task. 7 | 8 | A task typically involves a robot interacting with objects in an arena 9 | (workspace). The purpose of a task class is to generate an MJCF model 10 | of the task by combining the MJCF models of each component together and 11 | place them to the right positions. Object placement can be done by 12 | ad-hoc methods or placement samplers. 13 | """ 14 | 15 | def merge_robot(self, mujoco_robot): 16 | """Adds robot model to the MJCF model.""" 17 | pass 18 | 19 | def merge_arena(self, mujoco_arena): 20 | """Adds arena model to the MJCF model.""" 21 | pass 22 | 23 | def merge_objects(self, mujoco_objects): 24 | """Adds physical objects to the MJCF model.""" 25 | pass 26 | 27 | def merge_visual(self, mujoco_objects): 28 | """Adds visual objects to the MJCF model.""" 29 | 30 | def place_objects(self): 31 | """Places objects randomly until no collisions or max iterations hit.""" 32 | pass 33 | 34 | def place_visual(self): 35 | """Places visual objects randomly until no collisions or max iterations hit.""" 36 | pass 37 | -------------------------------------------------------------------------------- /spirl/configs/rl/maze/prior_initialized/base_conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.rl.maze.base_conf import * 2 | from spirl.rl.components.sampler import ACMultiImageAugmentedSampler 3 | from spirl.rl.policies.mlp_policies import ConvPolicy 4 | from spirl.rl.components.critic import SplitObsMLPCritic 5 | from spirl.models.bc_mdl import ImageBCMdl 6 | 7 | 8 | # update sampler 9 | configuration['sampler'] = ACMultiImageAugmentedSampler 10 | sampler_config = AttrDict( 11 | n_frames=2, 12 | ) 13 | env_config.screen_width = data_spec.res 14 | env_config.screen_height = data_spec.res 15 | 16 | # update policy to conv policy 17 | agent_config.policy = ConvPolicy 18 | policy_params.update(AttrDict( 19 | input_nc=3 * sampler_config.n_frames, 20 | prior_model=ImageBCMdl, 21 | prior_model_params=AttrDict(state_dim=data_spec.state_dim, 22 | action_dim=data_spec.n_actions, 23 | input_res=data_spec.res, 24 | n_input_frames=sampler_config.n_frames, 25 | ), 26 | prior_model_checkpoint=os.path.join(os.environ["EXP_DIR"], 27 | "skill_prior_learning/maze/flat"), 28 | )) 29 | 30 | # update critic+policy to handle multi-frame combined observation 31 | agent_config.critic = SplitObsMLPCritic 32 | agent_config.critic_params.unused_obs_size =
32**2*3 * sampler_config.n_frames 33 | 34 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spirl.data.block_stacking.src import robosuite as suite 3 | 4 | if __name__ == "__main__": 5 | 6 | # get the list of all environments 7 | envs = sorted(suite.environments.ALL_ENVS) 8 | 9 | # print info and select an environment 10 | print("Welcome to Surreal Robotics Suite v{}!".format(suite.__version__)) 11 | print(suite.__logo__) 12 | print("Here is a list of environments in the suite:\n") 13 | 14 | for k, env in enumerate(envs): 15 | print("[{}] {}".format(k, env)) 16 | print() 17 | try: 18 | s = input( 19 | "Choose an environment to run " 20 | + "(enter a number from 0 to {}): ".format(len(envs) - 1) 21 | ) 22 | # parse input into a number within range 23 | k = min(max(int(s), 0), len(envs) - 1) 24 | except: 25 | print("Input is not valid. Use 0 by default.") 26 | k = 0 27 | 28 | # initialize the task 29 | env = suite.make( 30 | envs[k], 31 | has_renderer=True, 32 | ignore_done=True, 33 | use_camera_obs=False, 34 | control_freq=100, 35 | ) 36 | env.reset() 37 | env.viewer.set_camera(camera_id=0) 38 | 39 | # do visualization 40 | for i in range(1000): 41 | action = np.random.randn(env.dof) 42 | obs, reward, done, _ = env.step(action) 43 | env.render() 44 | -------------------------------------------------------------------------------- /spirl/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DataSubsampler: 5 | def __init__(self, aggregator): 6 | self._aggregator = aggregator 7 | 8 | def __call__(self, *args, **kwargs): 9 | raise NotImplementedError("This function needs to be implemented by sub-classes!") 10 | 11 | 12 | class FixedFreqSubsampler(DataSubsampler): 13 | """Subsamples input array's first dimension by skipping given number of frames.""" 14 | def __init__(self, n_skip, aggregator=None): 15 | super().__init__(aggregator) 16 | self._n_skip = n_skip 17 | 18 | def __call__(self, val, idxs=None, aggregate=False): 19 | """Subsamples with idxs if given, aggregates with aggregator if aggregate=True.""" 20 | if self._n_skip == 0: 21 | return val, None 22 | 23 | if idxs is None: 24 | seq_len = val.shape[0] 25 | idxs = np.arange(0, seq_len - 1, self._n_skip + 1) 26 | 27 | if aggregate: 28 | assert self._aggregator is not None # no aggregator given!
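# with aggregate=True, each skipped window is collapsed into a single value;
# e.g. the SumAggregator below sums the frames between consecutive subsampled
# indices via np.add.reduceat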
29 | return self._aggregator(val, idxs), idxs 30 | else: 31 | return val[idxs], idxs 32 | 33 | 34 | class Aggregator: 35 | def __call__(self, *args, **kwargs): 36 | raise NotImplementedError("This function needs to be implemented by sub-classes!") 37 | 38 | 39 | class SumAggregator(Aggregator): 40 | def __call__(self, val, idxs): 41 | return np.add.reduceat(val, idxs, axis=0) 42 | 43 | -------------------------------------------------------------------------------- /spirl/utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torchvision.transforms import Resize 5 | 6 | 7 | def ch_first2last(video): 8 | return video.transpose((0,2,3,1)) 9 | 10 | 11 | def ch_last2first(video): 12 | return video.transpose((0,3,1,2)) 13 | 14 | 15 | def resize_video(video, size): 16 | if video.shape[1] == 3: 17 | video = np.transpose(video, (0,2,3,1)) 18 | transformed_video = np.stack([np.asarray(Resize(size)(Image.fromarray(im))) for im in video], axis=0) 19 | return transformed_video 20 | 21 | 22 | def _make_dir(filename): 23 | folder = os.path.dirname(filename) 24 | if not os.path.exists(folder): 25 | os.makedirs(folder) 26 | 27 | 28 | def save_video(video_frames, filename, fps=60, video_format='mp4'): 29 | assert fps == int(fps), fps 30 | import skvideo.io 31 | _make_dir(filename) 32 | 33 | skvideo.io.vwrite( 34 | filename, 35 | video_frames, 36 | inputdict={ 37 | '-r': str(int(fps)), 38 | }, 39 | outputdict={ 40 | '-f': video_format, 41 | '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74 42 | } 43 | ) 44 | 45 | 46 | def create_video_grid(col_and_row_frames): 47 | video_grid_frames = np.concatenate([ 48 | np.concatenate(row_frames, axis=-2) 49 | for row_frames in col_and_row_frames 50 | ], axis=-3) 51 | 52 | return video_grid_frames 53 | 54 | 55 | -------------------------------------------------------------------------------- /spirl/configs/hrl/kitchen/spirl_cl/conf.py: -------------------------------------------------------------------------------- 1 | from spirl.configs.hrl.kitchen.spirl.conf import * 2 | from spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl 3 | from spirl.rl.policies.cl_model_policies import ClModelPolicy 4 | 5 | # update model params to conditioned decoder on state 6 | ll_model_params.cond_decode = True 7 | 8 | # create LL closed-loop policy 9 | ll_policy_params = AttrDict( 10 | policy_model=ClSPiRLMdl, 11 | policy_model_params=ll_model_params, 12 | policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"], 13 | "skill_prior_learning/kitchen/hierarchical_cl"), 14 | ) 15 | ll_policy_params.update(ll_model_params) 16 | 17 | # create LL SAC agent (by default we will only use it for rolling out decoded skills, not finetuning skill decoder) 18 | ll_agent_config = AttrDict( 19 | policy=ClModelPolicy, 20 | policy_params=ll_policy_params, 21 | critic=MLPCritic, # LL critic is not used since we are not finetuning LL 22 | critic_params=hl_critic_params 23 | ) 24 | 25 | # update HL policy model params 26 | hl_policy_params.update(AttrDict( 27 | prior_model=ll_policy_params.policy_model, 28 | prior_model_params=ll_policy_params.policy_model_params, 29 | prior_model_checkpoint=ll_policy_params.policy_model_checkpoint, 30 | )) 31 | 32 | # register new LL agent in agent_config and turn off LL agent updates 33 | agent_config.update(AttrDict( 34 | ll_agent=SACAgent, 35 | ll_agent_params=ll_agent_config, 36 | 
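# update_ll=False freezes the pre-trained low-level skill decoder, so only the
# high-level policy is optimized during downstream RL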
update_ll=False, 37 | )) 38 | 39 | 40 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/demo_gripper_selection.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script shows you how to select gripper for an environment. 3 | This is controlled by gripper_type keyword argument 4 | """ 5 | from spirl.data.block_stacking.src import robosuite as suite 6 | from spirl.data.block_stacking.src.robosuite.wrappers import GymWrapper 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | grippers = ["TwoFingerGripper", "PR2Gripper", "RobotiqGripper"] 12 | 13 | for gripper in grippers: 14 | 15 | # create environment with selected grippers 16 | env = GymWrapper( 17 | suite.make( 18 | "SawyerPickPlace", 19 | gripper_type=gripper, 20 | use_camera_obs=False, # do not use pixel observations 21 | has_offscreen_renderer=False, # not needed since not using pixel obs 22 | has_renderer=True, # make sure we can render to the screen 23 | reward_shaping=True, # use dense rewards 24 | control_freq=100, # control should happen fast enough so that simulation looks smooth 25 | ) 26 | ) 27 | 28 | # run a random policy 29 | observation = env.reset() 30 | for t in range(500): 31 | env.render() 32 | action = env.action_space.sample() 33 | observation, reward, done, info = env.step(action) 34 | if done: 35 | print("Episode finished after {} timesteps".format(t + 1)) 36 | break 37 | 38 | # close window 39 | env.close() 40 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/arenas/bins_arena.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena 3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array 5 | 6 | 7 | class BinsArena(Arena): 8 | """Workspace that contains two bins placed side by side.""" 9 | 10 | def __init__( 11 | self, table_full_size=(0.39, 0.49, 0.82), table_friction=(1, 0.005, 0.0001) 12 | ): 13 | """ 14 | Args: 15 | table_full_size: full dimensions of the table 16 | friction: friction parameters of the table 17 | """ 18 | super().__init__(xml_path_completion("arenas/bins_arena.xml")) 19 | 20 | self.table_full_size = np.array(table_full_size) 21 | self.table_half_size = self.table_full_size / 2 22 | self.table_friction = table_friction 23 | 24 | self.floor = self.worldbody.find("./geom[@name='floor']") 25 | self.bin1_body = self.worldbody.find("./body[@name='bin1']") 26 | self.bin2_body = self.worldbody.find("./body[@name='bin2']") 27 | 28 | self.configure_location() 29 | 30 | def configure_location(self): 31 | self.bottom_pos = np.array([0, 0, 0]) 32 | self.floor.set("pos", array_to_string(self.bottom_pos)) 33 | 34 | @property 35 | def bin_abs(self): 36 | """Returns the absolute position of table top""" 37 | return string_to_array(self.bin1_body.get("pos")) 38 | -------------------------------------------------------------------------------- /spirl/rl/envs/maze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import d4rl 3 | 4 | from spirl.rl.components.environment import GymEnv 5 | from spirl.utils.general_utils import ParamDict, AttrDict 6 | 7 | 8 | class MazeEnv(GymEnv): 9 | 
"""Shallow wrapper around gym env for maze envs.""" 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | 13 | def _default_hparams(self): 14 | default_dict = ParamDict({ 15 | }) 16 | return super()._default_hparams().overwrite(default_dict) 17 | 18 | def reset(self): 19 | super().reset() 20 | if self.TARGET_POS is not None and self.START_POS is not None: 21 | self._env.set_target(self.TARGET_POS) 22 | self._env.reset_to_location(self.START_POS) 23 | self._env.render(mode='rgb_array') # these are necessary to make sure new state is rendered on first frame 24 | obs, _, _, _ = self._env.step(np.zeros_like(self._env.action_space.sample())) 25 | return self._wrap_observation(obs) 26 | 27 | def step(self, *args, **kwargs): 28 | obs, rew, done, info = super().step(*args, **kwargs) 29 | return obs, np.float64(rew), done, info # casting reward to float64 is important for getting shape later 30 | 31 | 32 | class ACRandMaze0S40Env(MazeEnv): 33 | START_POS = np.array([10., 24.]) 34 | TARGET_POS = np.array([18., 8.]) 35 | 36 | def _default_hparams(self): 37 | default_dict = ParamDict({ 38 | 'name': "maze2d-randMaze0S40-ac-v0", 39 | }) 40 | return super()._default_hparams().overwrite(default_dict) 41 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/square-nut.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/robots/baxter_robot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spirl.data.block_stacking.src.robosuite.models.robots.robot import Robot 3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion, array_to_string 4 | 5 | 6 | class Baxter(Robot): 7 | """Baxter is a hunky bimanual robot designed by Rethink Robotics.""" 8 | 9 | def __init__(self): 10 | super().__init__(xml_path_completion("robots/baxter/robot.xml")) 11 | 12 | self.bottom_offset = np.array([0, 0, -0.913]) 13 | self.left_hand = self.worldbody.find(".//body[@name='left_hand']") 14 | 15 | def set_base_xpos(self, pos): 16 | """Places the robot on position @pos.""" 17 | node = self.worldbody.find("./body[@name='base']") 18 | node.set("pos", array_to_string(pos - self.bottom_offset)) 19 | 20 | @property 21 | def dof(self): 22 | return 14 23 | 24 | @property 25 | def joints(self): 26 | out = [] 27 | for s in ["right_", "left_"]: 28 | out.extend(s + a for a in ["s0", "s1", "e0", "e1", "w0", "w1", "w2"]) 29 | return out 30 | 31 | @property 32 | def init_qpos(self): 33 | # Arms ready to work on the table 34 | return np.array([ 35 | 0.535, -0.093, 0.038, 0.166, 0.643, 1.960, -1.297, 36 | -0.518, -0.026, -0.076, 0.175, -0.748, 1.641, -0.158]) 37 | 38 | # Arms half extended 39 | return np.array([ 40 | 0.752, -0.038, -0.021, 0.161, 0.348, 2.095, -0.531, 41 | -0.585, -0.117, -0.037, 0.164, -0.536, 1.543, 0.204]) 42 | 43 | # Arms fully extended 44 | return np.zeros(14) 45 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/arenas/table_arena.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 
| 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/demo_baxter_ik_control.py: -------------------------------------------------------------------------------- 1 | """End-effector control for bimanual Baxter robot. 2 | 3 | This script shows how to use inverse kinematics solver from Bullet 4 | to command the end-effectors of two arms of the Baxter robot. 5 | """ 6 | 7 | import os 8 | import numpy as np 9 | 10 | from spirl.data.block_stacking.src import robosuite 11 | from spirl.data.block_stacking.src.robosuite.wrappers import IKWrapper 12 | 13 | 14 | if __name__ == "__main__": 15 | 16 | # initialize a Baxter environment 17 | env = robosuite.make( 18 | "BaxterLift", 19 | ignore_done=True, 20 | has_renderer=True, 21 | gripper_visualization=True, 22 | use_camera_obs=False, 23 | ) 24 | env = IKWrapper(env) 25 | 26 | obs = env.reset() 27 | 28 | # rotate the gripper so we can see it easily 29 | env.set_robot_joint_positions([ 30 | 0.00, -0.55, 0.00, 1.28, 0.00, 0.26, 0.00, 31 | 0.00, -0.55, 0.00, 1.28, 0.00, 0.26, 0.00, 32 | ]) 33 | 34 | bullet_data_path = os.path.join(robosuite.models.assets_root, "bullet_data") 35 | 36 | def robot_jpos_getter(): 37 | return np.array(env._joint_positions) 38 | 39 | for t in range(100000): 40 | omega = 2 * np.pi / 1000. 41 | A = 5e-4 42 | dpos_right = np.array([A * np.cos(omega * t), 0, A * np.sin(omega * t)]) 43 | dpos_left = np.array([A * np.sin(omega * t), A * np.cos(omega * t), 0]) 44 | dquat = np.array([0, 0, 0, 1]) 45 | grasp = 0. 46 | action = np.concatenate([dpos_right, dquat, dpos_left, dquat, [grasp, grasp]]) 47 | 48 | obs, reward, done, info = env.step(action) 49 | env.render() 50 | 51 | if done: 52 | break 53 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/arenas/pegs_arena.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena 3 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array 5 | 6 | 7 | class PegsArena(Arena): 8 | """Workspace that contains a tabletop with two fixed pegs.""" 9 | 10 | def __init__( 11 | self, table_full_size=(0.45, 0.69, 0.82), table_friction=(1, 0.005, 0.0001) 12 | ): 13 | """ 14 | Args: 15 | table_full_size: full dimensions of the table 16 | table_friction: friction parameters of the table 17 | """ 18 | super().__init__(xml_path_completion("arenas/pegs_arena.xml")) 19 | 20 | self.table_full_size = np.array(table_full_size) 21 | self.table_half_size = self.table_full_size / 2 22 | self.table_friction = table_friction 23 | 24 | self.floor = self.worldbody.find("./geom[@name='floor']") 25 | self.table_body = self.worldbody.find("./body[@name='table']") 26 | self.peg1_body = self.worldbody.find("./body[@name='peg1']") 27 | self.peg2_body = self.worldbody.find("./body[@name='peg2']") 28 | self.table_collision = self.table_body.find("./geom[@name='table_collision']") 29 | 30 | self.configure_location() 31 | 32 | def configure_location(self): 33 | self.bottom_pos = np.array([0, 0, 0]) 34 | self.floor.set("pos", array_to_string(self.bottom_pos)) 35 | 36 | @property 37 | def table_top_abs(self): 38 | """ 39 | Returns the absolute position of table 
top. 40 | """ 41 | return string_to_array(self.table_body.get("pos")) 42 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/demo_video_recording.py: -------------------------------------------------------------------------------- 1 | """ 2 | Record video of agent episodes with the imageio library. 3 | This script uses offscreen rendering. 4 | 5 | Example: 6 | $ python demo_video_recording.py --environment SawyerLift 7 | """ 8 | 9 | import argparse 10 | import imageio 11 | import numpy as np 12 | 13 | from spirl.data.block_stacking.src.robosuite import make 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--environment", type=str, default="SawyerStack") 20 | parser.add_argument("--video_path", type=str, default="video.mp4") 21 | parser.add_argument("--timesteps", type=int, default=500) 22 | parser.add_argument("--height", type=int, default=512) 23 | parser.add_argument("--width", type=int, default=512) 24 | parser.add_argument("--skip_frame", type=int, default=1) 25 | args = parser.parse_args() 26 | 27 | # initialize an environment with offscreen renderer 28 | env = make( 29 | args.environment, 30 | has_renderer=False, 31 | ignore_done=True, 32 | use_camera_obs=True, 33 | use_object_obs=False, 34 | camera_height=args.height, 35 | camera_width=args.width, 36 | ) 37 | 38 | obs = env.reset() 39 | dof = env.dof 40 | 41 | # create a video writer with imageio 42 | writer = imageio.get_writer(args.video_path, fps=20) 43 | 44 | frames = [] 45 | for i in range(args.timesteps): 46 | 47 | # run a uniformly random agent 48 | action = 0.5 * np.random.randn(dof) 49 | obs, reward, done, info = env.step(action) 50 | 51 | # dump a frame from every K frames 52 | if i % args.skip_frame == 0: 53 | frame = obs["image"][::-1] 54 | writer.append_data(frame) 55 | print("Saving frame #{}".format(i)) 56 | 57 | if done: 58 | break 59 | 60 | writer.close() 61 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/utils/numbered_box_object.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | from spirl.data.block_stacking.src.robosuite.models.objects import BoxObject 4 | 5 | 6 | class NumberedBoxObject(BoxObject): 7 | def __init__(self, **kwargs): 8 | self.number = kwargs.pop("number") 9 | super().__init__(**kwargs) 10 | 11 | if self.number is not None: 12 | asset_dir = os.path.join(os.getcwd(), "spirl/data/block_stacking/assets") 13 | texture_path = os.path.join(asset_dir, "textures/obj{}.png".format(self.number)) 14 | 15 | texture = ET.SubElement(self.asset, "texture") 16 | texture.set("file", texture_path) 17 | texture.set("name", "obj-{}-texture".format(self.number)) 18 | texture.set("gridsize", "1 2") 19 | texture.set("gridlayout", ".U") 20 | texture.set("rgb1", "1.0 1.0 1.0") 21 | texture.set("vflip", "true") 22 | texture.set("hflip", "true") 23 | 24 | material = ET.SubElement(self.asset, "material") 25 | material.set("name", "obj-{}-material".format(self.number)) 26 | material.set("reflectance", "0.5") 27 | material.set("specular", "0.5") 28 | material.set("shininess", "0.1") 29 | material.set("texture", "obj-{}-texture".format(self.number)) 30 | material.set("texuniform", "false") 31 | 32 | def get_collision_attrib_template(self): 33 | template = super().get_collision_attrib_template() 34 | if self.number is not None: 35 | 
template.update({"material": "obj-{}-material".format(self.number)}) 36 | return template 37 | 38 | def get_visual_attrib_template(self): 39 | template = super().get_visual_attrib_template() 40 | if self.number is not None: 41 | template.update({"material": "obj-{}-material".format(self.number)}) 42 | return template 43 | -------------------------------------------------------------------------------- /spirl/rl/envs/kitchen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | import d4rl 4 | 5 | from spirl.utils.general_utils import AttrDict 6 | from spirl.utils.general_utils import ParamDict 7 | from spirl.rl.components.environment import GymEnv 8 | 9 | 10 | class KitchenEnv(GymEnv): 11 | """Tiny wrapper around GymEnv for Kitchen tasks.""" 12 | SUBTASKS = ['microwave', 'kettle', 'slide cabinet', 'hinge cabinet', 'bottom burner', 'light switch', 'top burner'] 13 | def _default_hparams(self): 14 | return super()._default_hparams().overwrite(ParamDict({ 15 | 'name': "kitchen-mixed-v0", 16 | })) 17 | 18 | def step(self, *args, **kwargs): 19 | obs, rew, done, info = super().step(*args, **kwargs) 20 | return obs, np.float64(rew), done, self._postprocess_info(info) # casting reward to float64 is important for getting shape later 21 | 22 | def reset(self): 23 | self.solved_subtasks = defaultdict(lambda: 0) 24 | return super().reset() 25 | 26 | def get_episode_info(self): 27 | info = super().get_episode_info() 28 | info.update(AttrDict(self.solved_subtasks)) 29 | return info 30 | 31 | def _postprocess_info(self, info): 32 | """Sorts solved subtasks into separately logged elements.""" 33 | completed_subtasks = info.pop("completed_tasks") 34 | for task in self.SUBTASKS: 35 | self.solved_subtasks[task] = 1 if task in completed_subtasks or self.solved_subtasks[task] else 0 36 | return info 37 | 38 | 39 | class NoGoalKitchenEnv(KitchenEnv): 40 | """Splits off goal from obs.""" 41 | def step(self, *args, **kwargs): 42 | obs, rew, done, info = super().step(*args, **kwargs) 43 | obs = obs[:int(obs.shape[0]/2)] 44 | return obs, rew, done, info 45 | 46 | def reset(self, *args, **kwargs): 47 | obs = super().reset(*args, **kwargs) 48 | return obs[:int(obs.shape[0]/2)] 49 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/demo_pygame_renderer.py: -------------------------------------------------------------------------------- 1 | """pygame rendering demo. 2 | 3 | This script provides an example of using the pygame library for rendering 4 | camera observations as an alternative to the default mujoco_py renderer. 
5 | 6 | Example: 7 | $ python demo_pygame_renderer.py --environment BaxterPegInHole --width 1000 --height 1000 8 | """ 9 | 10 | import sys 11 | import argparse 12 | import pygame 13 | import numpy as np 14 | 15 | from spirl.data.block_stacking.src import robosuite 16 | 17 | if __name__ == "__main__": 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--environment", type=str, default="BaxterLift") 21 | parser.add_argument("--timesteps", type=int, default=10000) 22 | parser.add_argument("--width", type=int, default=512) 23 | parser.add_argument("--height", type=int, default=384) 24 | args = parser.parse_args() 25 | 26 | width = args.width 27 | height = args.height 28 | screen = pygame.display.set_mode((width, height)) 29 | 30 | env = robosuite.make( 31 | args.environment, 32 | has_renderer=False, 33 | ignore_done=True, 34 | camera_height=height, 35 | camera_width=width, 36 | show_gripper_visualization=True, 37 | use_camera_obs=True, 38 | use_object_obs=False, 39 | ) 40 | 41 | for i in range(args.timesteps): 42 | 43 | # issue random actions 44 | action = 0.5 * np.random.randn(env.dof) 45 | obs, reward, done, info = env.step(action) 46 | 47 | for event in pygame.event.get(): 48 | if event.type == pygame.QUIT: 49 | sys.exit() 50 | 51 | # read camera observation 52 | im = np.flip(obs["image"].transpose((1, 0, 2)), 1) 53 | pygame.pixelcopy.array_to_surface(screen, im) 54 | pygame.display.update() 55 | 56 | if i % 100 == 0: 57 | print("step #{}".format(i)) 58 | 59 | if done: 60 | break 61 | -------------------------------------------------------------------------------- /spirl/rl/utils/robosuite_utils.py: -------------------------------------------------------------------------------- 1 | from spirl.data.block_stacking.src.robosuite.environments.base import MujocoEnv 2 | from mujoco_py import MjSim, MjRenderContextOffscreen 3 | from spirl.data.block_stacking.src.robosuite.utils import MujocoPyRenderer 4 | 5 | 6 | class FastResetMujocoEnv(MujocoEnv): 7 | """Only loads the mujoco XML file once to allow for quicker resets.""" 8 | def _reset_internal(self): 9 | """Resets simulation internal configurations.""" 10 | # instantiate simulation from MJCF model 11 | self._load_model() 12 | if not hasattr(self, "mjpy_model"): 13 | self.mjpy_model = self.model.get_model(mode="mujoco_py") 14 | self.sim = MjSim(self.mjpy_model) 15 | self.initialize_time(self.control_freq) 16 | 17 | # create visualization screen or renderer 18 | if self.has_renderer and self.viewer is None: 19 | self.viewer = MujocoPyRenderer(self.sim) 20 | self.viewer.viewer.vopt.geomgroup[0] = ( 21 | 1 if self.render_collision_mesh else 0 22 | ) 23 | self.viewer.viewer.vopt.geomgroup[1] = 1 if self.render_visual_mesh else 0 24 | 25 | # hiding the overlay speeds up rendering significantly 26 | self.viewer.viewer._hide_overlay = True 27 | 28 | elif self.has_offscreen_renderer: 29 | if self.sim._render_context_offscreen is None: 30 | render_context = MjRenderContextOffscreen(self.sim) 31 | self.sim.add_render_context(render_context) 32 | self.sim._render_context_offscreen.vopt.geomgroup[0] = ( 33 | 1 if self.render_collision_mesh else 0 34 | ) 35 | self.sim._render_context_offscreen.vopt.geomgroup[1] = ( 36 | 1 if self.render_visual_mesh else 0 37 | ) 38 | 39 | # additional housekeeping 40 | self.sim_state_initial = self.sim.get_state() 41 | self._get_reference() 42 | self.cur_time = 0 43 | self.timestep = 0 44 | self.done = False 45 | --------------------------------------------------------------------------------
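Editor's note on `demo_pygame_renderer.py` above: the line `im = np.flip(obs["image"].transpose((1, 0, 2)), 1)` adapts MuJoCo's frame layout to pygame's. The offscreen renderer returns an `(H, W, 3)` array with the bottom image row first (which is also why `demo_video_recording.py` flips frames with `[::-1]`), while pygame surfaces are indexed `(x, y)`, i.e. `(W, H)`. A minimal self-contained sketch of the same conversion on a toy array:

```python
import numpy as np

# Toy stand-in for obs["image"]: a (H, W, 3) frame as returned by the offscreen renderer.
frame = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)

# Swap the first two axes for pygame's (x, y) indexing, then flip the new
# axis 1 (the image rows) to undo MuJoCo's bottom-up row order.
surface_array = np.flip(frame.transpose((1, 0, 2)), 1)
assert surface_array.shape == (6, 4, 3)  # (W, H, 3), ready for pygame.pixelcopy.array_to_surface
```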
/spirl/data/block_stacking/src/robosuite/models/grippers/pr2_gripper.py: -------------------------------------------------------------------------------- 1 | """ 2 | 4 dof gripper with two fingers and its open/close variant 3 | """ 4 | import numpy as np 5 | 6 | from spirl.data.block_stacking.src.robosuite.models.grippers import Gripper 7 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 8 | 9 | 10 | class PR2GripperBase(Gripper): 11 | """ 12 | A 4 dof gripper with two fingers. 13 | """ 14 | 15 | def __init__(self): 16 | super().__init__(xml_path_completion("grippers/pr2_gripper.xml")) 17 | 18 | def format_action(self, action): 19 | return action 20 | 21 | @property 22 | def init_qpos(self): 23 | return np.zeros(4) 24 | 25 | @property 26 | def joints(self): 27 | return [ 28 | "r_gripper_r_finger_joint", 29 | "r_gripper_l_finger_joint", 30 | "r_gripper_r_finger_tip_joint", 31 | "r_gripper_l_finger_tip_joint", 32 | ] 33 | 34 | @property 35 | def dof(self): 36 | return 4 37 | 38 | def contact_geoms(self): 39 | return [ 40 | "r_gripper_l_finger", 41 | "r_gripper_l_finger_tip", 42 | "r_gripper_r_finger", 43 | "r_gripper_r_finger_tip", 44 | ] 45 | 46 | @property 47 | def visualization_sites(self): 48 | return ["grip_site", "grip_site_cylinder"] 49 | 50 | @property 51 | def left_finger_geoms(self): 52 | return ["r_gripper_l_finger", "r_gripper_l_finger_tip"] 53 | 54 | @property 55 | def right_finger_geoms(self): 56 | return ["r_gripper_r_finger", "r_gripper_r_finger_tip"] 57 | 58 | 59 | class PR2Gripper(PR2GripperBase): 60 | """ 61 | Open/close variant of PR2 gripper. 62 | """ 63 | 64 | def format_action(self, action): 65 | """ 66 | Args: 67 | action: 1 => open, -1 => closed 68 | """ 69 | assert len(action) == 1 70 | return np.ones(4) * action 71 | 72 | @property 73 | def dof(self): 74 | return 1 75 | -------------------------------------------------------------------------------- /spirl/rl/utils/rollout_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import h5py 4 | import numpy as np 5 | 6 | 7 | class RolloutSaver(object): 8 | """Saves rollout episodes to a target directory.""" 9 | def __init__(self, save_dir): 10 | if not os.path.exists(save_dir): 11 | os.makedirs(save_dir) 12 | self.save_dir = save_dir 13 | self.counter = 0 14 | 15 | def save_rollout_to_file(self, episode): 16 | """Saves an episode to the next file index of the target folder.""" 17 | # get save path 18 | save_path = os.path.join(self.save_dir, "rollout_{}.h5".format(self.counter)) 19 | 20 | # save rollout to file 21 | f = h5py.File(save_path, "w") 22 | f.create_dataset("traj_per_file", data=1) 23 | 24 | # store trajectory info in traj0 group 25 | traj_data = f.create_group("traj0") 26 | traj_data.create_dataset("states", data=np.array(episode.observation)) 27 | traj_data.create_dataset("images", data=np.array(episode.image, dtype=np.uint8)) 28 | traj_data.create_dataset("actions", data=np.array(episode.action)) 29 | 30 | terminals = np.array(episode.done) 31 | if np.sum(terminals) == 0: 32 | terminals[-1] = True 33 | 34 | # build pad-mask that indicates how long sequence is 35 | is_terminal_idxs = np.nonzero(terminals)[0] 36 | pad_mask = np.zeros((len(terminals),)) 37 | pad_mask[:is_terminal_idxs[0]] = 1. 
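# (editor's note) pad_mask marks the valid prefix of the episode: entries are 1.
# strictly before the first terminal step and 0. from there on, so downstream
# loaders can tell real transitions apart from padding.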
38 | traj_data.create_dataset("pad_mask", data=pad_mask) 39 | 40 | f.close() 41 | 42 | self.counter += 1 43 | 44 | def _resize_video(self, images, dim=64): 45 | """Resize a video in numpy array form to target dimension.""" 46 | ret = np.zeros((images.shape[0], dim, dim, 3)) 47 | 48 | for i in range(images.shape[0]): 49 | ret[i] = cv2.resize(images[i], dsize=(dim, dim), 50 | interpolation=cv2.INTER_CUBIC) 51 | 52 | return ret.astype(np.uint8) 53 | 54 | def reset(self): 55 | """Resets counter.""" 56 | self.counter = 0 57 | -------------------------------------------------------------------------------- /spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl/README.md: -------------------------------------------------------------------------------- 1 | # Image-based SPiRL w/ Closed-Loop Skill Decoder 2 | 3 | This version of the SPiRL model uses a [closed-loop action decoder](../../../../models/closed_loop_spirl_mdl.py#L55): 4 | in contrast to the original SPiRL model it takes the current environment observation as input in every skill decoding step. 5 | 6 | This image-based model is a direct extension of the 7 | [state-based SPiRL model with closed-loop skill decoder](../../kitchen/hierarchical_cl/README.md). 8 | Similar to the state-based model we find that the image-based closed-loop model improves performance over the original 9 | image-based SPiRL model, particularly in tasks that require precise control. 10 | We evaluate it on a more challenging, sparse reward version of the block stacking environment 11 | where the agent is rewarded for the height of the tower it built, but does not receive any rewards for picking or lifting 12 | blocks. We find that on this challenging environment, the closed-loop skill decoder ("SPiRLv2") outperforms the original 13 | SPiRL model with open-loop skill decoder ("SPiRLv1"). 14 | 15 |
[results figures: closed-loop SPiRL ("SPiRLv2") vs. open-loop SPiRL ("SPiRLv1") on the sparse-reward block stacking environment]
18 | 19 | 20 | We also tried the closed-loop model on the image-based maze navigation task, but did not find it to improve performance, 21 | which we attribute to the easier control task that does not require closed-loop control. 22 | 23 | ## Example Commands 24 | 25 | To train the image-based SPiRL model with closed-loop action decoder on the block stacking environment, run the following command: 26 | ``` 27 | python3 spirl/train.py --path=spirl/configs/skill_prior_learning/block_stacking/hierarchical_cl --val_data_size=160 28 | ``` 29 | 30 | To train a downstream task policy with RL using the closed-loop image-based SPiRL model 31 | on the sparse reward block stacking environment, run the following command: 32 | ``` 33 | python3 spirl/rl/train.py --path=spirl/configs/hrl/block_stacking/spirl_cl --seed=0 --prefix=SPIRLv2_block_stacking_seed0 34 | ``` 35 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/devices/README.md: -------------------------------------------------------------------------------- 1 | # Devices 2 | 3 | Devices are used to read user input and collect human demonstrations. Demonstrations can be collected by either using a keyboard or using a [SpaceNavigator 3D Mouse](https://www.3dconnexion.com/spacemouse_compact/en/) with the [collect_human_demonstrations](robosuite/scripts/collect_human_demonstrations.py) script. More generally, we support any interface that implements the [Device](device.py) abstract base class. In order to support your own custom device, simply subclass this base class and implement the required methods. 4 | 5 | ## Keyboard 6 | 7 | We support keyboard input through the GLFW window created by the mujoco-py renderer. 8 | 9 | **Keyboard controls** 10 | 11 | Note that the rendering window must be active for these commands to work. 12 | 13 | | Keys | Command | 14 | | :------: | :--------------------------------: | 15 | | q | reset simulation | 16 | | spacebar | toggle gripper (open/close) | 17 | | w-a-s-d | move arm horizontally in x-y plane | 18 | | r-f | move arm vertically | 19 | | z-x | rotate arm about x-axis | 20 | | t-g | rotate arm about y-axis | 21 | | c-v | rotate arm about z-axis | 22 | | ESC | quit | 23 | 24 | ## SpaceNavigator 3D Mouse 25 | 26 | We support the use of a [SpaceNavigator 3D Mouse](https://www.3dconnexion.com/spacemouse_compact/en/) as well. 
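Editor's note: the custom-device extension point described at the top of this README is small enough to sketch. The following is an illustrative example only -- the method names (`start_control`, `get_controller_state`) and the returned keys follow the [Device](device.py) base class as used by the keyboard and SpaceNavigator devices, but check that file for the exact contract in this version of the code:

```python
# Hedged sketch of a minimal custom device (not part of the repo).
# Verify the abstract methods and expected controller-state keys in device.py.
import numpy as np
from spirl.data.block_stacking.src.robosuite.devices.device import Device


class NullDevice(Device):
    """Dummy input device that always commands zero motion -- handy for smoke tests."""

    def start_control(self):
        pass  # nothing to initialize for this dummy device

    def get_controller_state(self):
        return dict(
            dpos=np.zeros(3),    # no end-effector translation
            rotation=np.eye(3),  # keep the current orientation
            grasp=0,             # gripper open
            reset=False,         # do not reset the simulation
        )
```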
27 | 28 | **SpaceNavigator 3D Mouse controls** 29 | 30 | | Control | Command | 31 | | :-----------------------: | :-----------------------------------: | 32 | | Right button | reset simulation | 33 | | Left button (hold) | close gripper | 34 | | Move mouse laterally | move arm horizontally in x-y plane | 35 | | Move mouse vertically | move arm vertically | 36 | | Twist mouse about an axis | rotate arm about a corresponding axis | 37 | | ESC (keyboard) | quit | 38 | 39 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/robots/robot.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML 4 | from spirl.data.block_stacking.src.robosuite.utils import XMLError 5 | 6 | 7 | class Robot(MujocoXML): 8 | """Base class for all robot models.""" 9 | 10 | def __init__(self, fname): 11 | """Initializes from file @fname.""" 12 | super().__init__(fname) 13 | # key: gripper name and value: gripper model 14 | self.grippers = OrderedDict() 15 | 16 | def add_gripper(self, arm_name, gripper): 17 | """ 18 | Mounts gripper to arm. 19 | 20 | Throws error if robot already has a gripper or gripper type is incorrect. 21 | 22 | Args: 23 | arm_name (str): name of arm mount 24 | gripper (MujocoGripper instance): gripper MJCF model 25 | """ 26 | if arm_name in self.grippers: 27 | raise ValueError("Attempts to add multiple grippers to one body") 28 | 29 | arm_subtree = self.worldbody.find(".//body[@name='{}']".format(arm_name)) 30 | 31 | for actuator in gripper.actuator: 32 | 33 | if actuator.get("name") is None: 34 | raise XMLError("Actuator has no name") 35 | 36 | if not actuator.get("name").startswith("gripper"): 37 | raise XMLError( 38 | "Actuator name {} does not have prefix 'gripper'".format( 39 | actuator.get("name") 40 | ) 41 | ) 42 | 43 | for body in gripper.worldbody: 44 | arm_subtree.append(body) 45 | 46 | self.merge(gripper, merge_body=False) 47 | self.grippers[arm_name] = gripper 48 | 49 | @property 50 | def dof(self): 51 | """Returns the number of DOF of the robot, not including gripper.""" 52 | raise NotImplementedError 53 | 54 | @property 55 | def joints(self): 56 | """Returns a list of joint names of the robot.""" 57 | raise NotImplementedError 58 | 59 | @property 60 | def init_qpos(self): 61 | """Returns default qpos.""" 62 | raise NotImplementedError 63 | -------------------------------------------------------------------------------- /spirl/data/README.md: -------------------------------------------------------------------------------- 1 | ## Downloading Datasets via the Command Line 2 | 3 | To download the dataset files from Google Drive via the command line, you can use the 4 | [gdown](https://github.com/wkentaro/gdown) package. Install it with: 5 | ``` 6 | pip install gdown 7 | ``` 8 | 9 | Then navigate to the folder you want to download the data to and run the following commands: 10 | ``` 11 | # Download Maze Dataset 12 | gdown https://drive.google.com/uc?id=1pXM-EDCwFrfgUjxITBsR48FqW9gMoXYZ 13 | 14 | # Download Block Stacking Dataset 15 | gdown https://drive.google.com/uc?id=1VobNYJQw_Uwax0kbFG7KOXTgv6ja2s1M 16 | ``` 17 | 18 | ## Re-Generating Datasets 19 | 20 | ### Maze Dataset 21 | To regenerate the maze dataset, our fork of the [D4RL repo](https://github.com/kpertsch/d4rl) needs to be cloned and installed. 22 | It includes the script used to generate the maze dataset. 
Specifically, new data can be created by running: 23 | ``` 24 | cd d4rl 25 | python3 scripts/generate_randMaze2d_datasets.py --render --agent_centric --save_images --data_dir=path_to_your_dir 26 | ``` 27 | Optionally, an argument `--batch_idx` allows to automatically generate a subfolder in `data_dir` with the batch index, 28 | so that multiple data generation scripts with different batch indices can be run in parallel 29 | for accelerated data generation. 30 | 31 | The number of trajectories that are getting generated can be controlled through the argument `--num_samples`; the size 32 | of the randomly generated training mazes can be changed with `--rand_maze_size`. For a full list of all arguments, see 33 | [```scripts/generate_randMaze2d_datasets.py```](https://github.com/kpertsch/d4rl/scripts/generate_randMaze2d_datasets.py#L72). 34 | 35 | 36 | ### Block Stacking Dataset 37 | To regenerate the block stacking dataset we can use the config provided in [```spirl/configs/data_collect/block_stacking```](spirl/configs/data_collect/block_stacking/conf.py). 38 | To start generation, run: 39 | ``` 40 | python3 spirl/rl/train.py --path=spirl/configs/data_collect/block_stacking --mode=rollout --n_val_samples=2e5 --seed=42 --data_dir=path_to_your_dir 41 | ``` 42 | If you want to run multiple data generation jobs in parallel, make sure to change the seed and set a different target 43 | data directory. -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/grippers/robotiq_three_finger_gripper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gripper with 11-DoF controlling three fingers and its open/close variant. 3 | """ 4 | import numpy as np 5 | 6 | from spirl.data.block_stacking.src.robosuite.models.grippers import Gripper 7 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 8 | 9 | 10 | class RobotiqThreeFingerGripperBase(Gripper): 11 | """ 12 | Gripper with 11 dof controlling three fingers. 13 | """ 14 | 15 | def __init__(self): 16 | super().__init__(xml_path_completion("grippers/robotiq_gripper_s.xml")) 17 | 18 | def format_action(self, action): 19 | return action 20 | 21 | @property 22 | def init_qpos(self): 23 | return np.zeros(11) 24 | 25 | @property 26 | def joints(self): 27 | return [ 28 | "palm_finger_1_joint", 29 | "finger_1_joint_1", 30 | "finger_1_joint_2", 31 | "finger_1_joint_3", 32 | "palm_finger_2_joint", 33 | "finger_2_joint_1", 34 | "finger_2_joint_2", 35 | "finger_2_joint_3", 36 | "finger_middle_joint_1", 37 | "finger_middle_joint_2", 38 | "finger_middle_joint_3", 39 | ] 40 | 41 | @property 42 | def dof(self): 43 | return 11 44 | 45 | def contact_geoms(self): 46 | return [ 47 | "f1_l0", 48 | "f1_l1", 49 | "f1_l2", 50 | "f1_l3", 51 | "f2_l0", 52 | "f2_l1", 53 | "f2_l2", 54 | "f2_l3", 55 | "f3_l0", 56 | "f3_l1", 57 | "f3_l2", 58 | "f3_l3", 59 | ] 60 | 61 | @property 62 | def visualization_sites(self): 63 | return ["grip_site", "grip_site_cylinder"] 64 | 65 | 66 | class RobotiqThreeFingerGripper(RobotiqThreeFingerGripperBase): 67 | """ 68 | 1-DoF variant of RobotiqThreeFingerGripperBase. 
69 | """ 70 | 71 | def format_action(self, action): 72 | """ 73 | Args: 74 | action: 1 => open, -1 => closed 75 | """ 76 | movement = np.array([0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1]) 77 | return -1 * movement * action 78 | -------------------------------------------------------------------------------- /spirl/configs/rl/kitchen/base_conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.utils.general_utils import AttrDict 4 | from spirl.rl.agents.ac_agent import SACAgent 5 | from spirl.rl.policies.mlp_policies import MLPPolicy 6 | from spirl.rl.components.critic import MLPCritic 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.envs.kitchen import KitchenEnv 9 | from spirl.rl.components.normalization import Normalizer 10 | from spirl.configs.default_data_configs.kitchen import data_spec 11 | 12 | 13 | current_dir = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | notes = 'non-hierarchical RL experiments in kitchen env' 16 | 17 | configuration = { 18 | 'seed': 42, 19 | 'agent': SACAgent, 20 | 'environment': KitchenEnv, 21 | 'data_dir': '.', 22 | 'num_epochs': 100, 23 | 'max_rollout_len': 280, 24 | 'n_steps_per_epoch': 50000, 25 | 'n_warmup_steps': 500, 26 | } 27 | configuration = AttrDict(configuration) 28 | 29 | # Policy 30 | policy_params = AttrDict( 31 | action_dim=data_spec.n_actions, 32 | input_dim=data_spec.state_dim, 33 | n_layers=5, # number of policy network layers 34 | nz_mid=256, 35 | max_action_range=1., 36 | ) 37 | 38 | # Critic 39 | critic_params = AttrDict( 40 | action_dim=policy_params.action_dim, 41 | input_dim=policy_params.input_dim, 42 | output_dim=1, 43 | n_layers=2, # number of policy network layers 44 | nz_mid=256, 45 | action_input=True, 46 | ) 47 | 48 | # Replay Buffer 49 | replay_params = AttrDict( 50 | capacity=1e5, 51 | dump_replay=False, 52 | ) 53 | 54 | # Observation Normalization 55 | obs_norm_params = AttrDict( 56 | ) 57 | 58 | # Agent 59 | agent_config = AttrDict( 60 | policy=MLPPolicy, 61 | policy_params=policy_params, 62 | critic=MLPCritic, 63 | critic_params=critic_params, 64 | replay=UniformReplayBuffer, 65 | replay_params=replay_params, 66 | # obs_normalizer=Normalizer, 67 | # obs_normalizer_params=obs_norm_params, 68 | clip_q_target=False, 69 | batch_size=256, 70 | log_video_caption=True, 71 | ) 72 | 73 | # Dataset - Random data 74 | data_config = AttrDict() 75 | data_config.dataset_spec = data_spec 76 | 77 | # Environment 78 | env_config = AttrDict( 79 | reward_norm=1., 80 | ) 81 | 82 | -------------------------------------------------------------------------------- /spirl/configs/rl/maze/base_conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.utils.general_utils import AttrDict 4 | from spirl.rl.policies.mlp_policies import MLPPolicy 5 | from spirl.rl.components.critic import MLPCritic 6 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 7 | from spirl.rl.envs.maze import ACRandMaze0S40Env 8 | from spirl.rl.components.normalization import Normalizer 9 | from spirl.configs.default_data_configs.maze import data_spec 10 | from spirl.data.maze.src.maze_agents import MazeSACAgent 11 | 12 | 13 | current_dir = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | notes = 'non-hierarchical RL experiments in maze env' 16 | 17 | configuration = { 18 | 'seed': 42, 19 | 'agent': MazeSACAgent, 20 | 'environment': ACRandMaze0S40Env, 21 | 'data_dir': '.', 22 | 'num_epochs': 
100, 23 | 'max_rollout_len': 2000, 24 | 'n_steps_per_epoch': 100000, 25 | 'n_warmup_steps': 5e3, 26 | } 27 | configuration = AttrDict(configuration) 28 | 29 | # Policy 30 | policy_params = AttrDict( 31 | action_dim=data_spec.n_actions, 32 | input_dim=data_spec.state_dim, 33 | n_layers=2, # number of policy network layers 34 | nz_mid=256, 35 | max_action_range=1., 36 | ) 37 | 38 | # Critic 39 | critic_params = AttrDict( 40 | action_dim=policy_params.action_dim, 41 | input_dim=policy_params.input_dim, 42 | output_dim=1, 43 | n_layers=2, # number of policy network layers 44 | nz_mid=256, 45 | action_input=True, 46 | ) 47 | 48 | # Replay Buffer 49 | replay_params = AttrDict( 50 | capacity=1e6, 51 | dump_replay=False, 52 | ) 53 | 54 | # Observation Normalization 55 | obs_norm_params = AttrDict( 56 | ) 57 | 58 | # Agent 59 | agent_config = AttrDict( 60 | policy=MLPPolicy, 61 | policy_params=policy_params, 62 | critic=MLPCritic, 63 | critic_params=critic_params, 64 | replay=UniformReplayBuffer, 65 | replay_params=replay_params, 66 | # obs_normalizer=Normalizer, 67 | # obs_normalizer_params=obs_norm_params, 68 | clip_q_target=False, 69 | batch_size=256, 70 | log_videos=False, 71 | ) 72 | 73 | # Dataset - Random data 74 | data_config = AttrDict() 75 | data_config.dataset_spec = data_spec 76 | 77 | # Environment 78 | env_config = AttrDict( 79 | reward_norm=1., 80 | ) 81 | 82 | -------------------------------------------------------------------------------- /spirl/configs/rl/peg_in_hole/base_conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.utils.general_utils import AttrDict 4 | from spirl.rl.agents.ac_agent import SACAgent 5 | from spirl.rl.policies.mlp_policies import MLPPolicy 6 | from spirl.rl.components.critic import MLPCritic 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.envs.peg_in_hole import PegInHoleEnv 9 | from spirl.rl.components.normalization import Normalizer 10 | from spirl.configs.default_data_configs.peg_in_hole import data_spec 11 | 12 | 13 | current_dir = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | notes = 'non-hierarchical RL experiments in peg-in-hole env' 16 | 17 | configuration = { 18 | 'seed': 42, 19 | 'agent': SACAgent, 20 | 'environment': PegInHoleEnv, 21 | 'data_dir': '.', 22 | 'num_epochs': 100, 23 | 'max_rollout_len': 280, 24 | 'n_steps_per_epoch': 50000, 25 | 'n_warmup_steps': 500, 26 | } 27 | configuration = AttrDict(configuration) 28 | 29 | # Policy 30 | policy_params = AttrDict( 31 | action_dim=data_spec.n_actions, 32 | input_dim=data_spec.state_dim, 33 | n_layers=5, # number of policy network layers 34 | nz_mid=256, 35 | max_action_range=1., 36 | ) 37 | 38 | # Critic 39 | critic_params = AttrDict( 40 | action_dim=policy_params.action_dim, 41 | input_dim=policy_params.input_dim, 42 | output_dim=1, 43 | n_layers=2, # number of policy network layers 44 | nz_mid=256, 45 | action_input=True, 46 | ) 47 | 48 | # Replay Buffer 49 | replay_params = AttrDict( 50 | capacity=1e5, 51 | dump_replay=False, 52 | ) 53 | 54 | # Observation Normalization 55 | obs_norm_params = AttrDict( 56 | ) 57 | 58 | # Agent 59 | agent_config = AttrDict( 60 | policy=MLPPolicy, 61 | policy_params=policy_params, 62 | critic=MLPCritic, 63 | critic_params=critic_params, 64 | replay=UniformReplayBuffer, 65 | replay_params=replay_params, 66 | # obs_normalizer=Normalizer, 67 | # obs_normalizer_params=obs_norm_params, 68 | clip_q_target=False, 69 | batch_size=256, 70 | 
log_video_caption=True, 71 | ) 72 | 73 | # Dataset - Random data 74 | data_config = AttrDict() 75 | data_config.dataset_spec = data_spec 76 | 77 | # Environment 78 | env_config = AttrDict( 79 | reward_norm=1., 80 | ) 81 | 82 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/block_stacking_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from spirl.components.logger import Logger 4 | from spirl.models.skill_prior_mdl import SkillSpaceLogger 5 | from spirl.data.block_stacking.src.block_stacking_env import BlockStackEnv 6 | from spirl.utils.general_utils import AttrDict 7 | from spirl.utils.vis_utils import add_caption_to_img 8 | from spirl.data.block_stacking.src.block_task_generator import FixedSizeSingleTowerBlockTaskGenerator 9 | 10 | 11 | class BlockStackLogger(Logger): 12 | # logger to save visualizations of input and output trajectories in block stacking environment 13 | 14 | @staticmethod 15 | def _init_env_from_id(id): 16 | # TODO: return different environment variants depending on id 17 | task_params = AttrDict( 18 | max_tower_height=4 19 | ) 20 | 21 | env_config = AttrDict( 22 | task_generator=FixedSizeSingleTowerBlockTaskGenerator, 23 | task_params=task_params, 24 | dimension=2, 25 | screen_width=128, 26 | screen_height=128 27 | ) 28 | 29 | return BlockStackEnv(env_config) 30 | 31 | @staticmethod 32 | def _render_state(env, model_xml, obs, name=""): 33 | env.reset() 34 | 35 | unwrapped_obs = env._unflatten_block_obs(obs) 36 | 37 | sim_state = env.sim.get_state() 38 | 39 | sim_state.qpos[:len(sim_state.qpos)] = env.obs2qpos(obs) 40 | env.sim.set_state(sim_state) 41 | env.sim.forward() 42 | img = env.render() 43 | 44 | # round function 45 | rd = lambda x: np.round(x, 2) 46 | 47 | # add caption to the image 48 | info = { 49 | "Robot Pos": rd(unwrapped_obs["gripper_pos"]), 50 | "Robot Ang": rd(unwrapped_obs["gripper_angle"]), 51 | "Gripper Finger Pos": rd(unwrapped_obs["gripper_finger_pos"]), 52 | } 53 | 54 | for i in range(unwrapped_obs["block_pos"].shape[0]): 55 | info.update({ 56 | "Block {}:".format(i): rd(unwrapped_obs["block_pos"][i]) 57 | }) 58 | 59 | img = add_caption_to_img(img, info, name, flip_rgb=True) 60 | 61 | return img 62 | 63 | 64 | class SkillSpaceBlockStackLogger(BlockStackLogger, SkillSpaceLogger): 65 | pass 66 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/wrappers/gym_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file implements a wrapper for facilitating compatibility with OpenAI gym. 3 | This is useful when using these environments with code that assumes a gym-like 4 | interface. 5 | """ 6 | 7 | import numpy as np 8 | from gym import spaces 9 | from spirl.data.block_stacking.src.robosuite.wrappers import Wrapper 10 | 11 | 12 | class GymWrapper(Wrapper): 13 | env = None 14 | 15 | def __init__(self, env, keys=None): 16 | """ 17 | Initializes the Gym wrapper. 18 | 19 | Args: 20 | env (MujocoEnv instance): The environment to wrap. 21 | keys (list of strings): If provided, each observation will 22 | consist of concatenated keys from the wrapped environment's 23 | observation dictionary. Defaults to robot-state and object-state. 24 | """ 25 | self.env = env 26 | 27 | if keys is None: 28 | assert self.env.use_object_obs, "Object observations need to be enabled." 
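# (editor's note) with keys=None the wrapper defaults to concatenating
# proprioceptive "robot-state" with "object-state" into one flat observation: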
29 | keys = ["robot-state", "object-state"] 30 | self.keys = keys 31 | 32 | # set up observation and action spaces 33 | flat_ob = self._flatten_obs(self.env.reset(), verbose=True) 34 | self.obs_dim = flat_ob.size 35 | high = np.inf * np.ones(self.obs_dim) 36 | low = -high 37 | self.observation_space = spaces.Box(low=low, high=high) 38 | low, high = self.env.action_spec 39 | self.action_space = spaces.Box(low=low, high=high) 40 | 41 | def _flatten_obs(self, obs_dict, verbose=False): 42 | """ 43 | Filters keys of interest out and concatenate the information. 44 | 45 | Args: 46 | obs_dict: ordered dictionary of observations 47 | """ 48 | ob_lst = [] 49 | for key in obs_dict: 50 | if key in self.keys: 51 | if verbose: 52 | print("adding key: {}".format(key)) 53 | ob_lst.append(obs_dict[key]) 54 | return np.concatenate(ob_lst) 55 | 56 | def reset(self): 57 | ob_dict = self.env.reset() 58 | return self._flatten_obs(ob_dict) 59 | 60 | def step(self, action): 61 | ob_dict, reward, done, info = self.env.step(action) 62 | return self._flatten_obs(ob_dict), reward, done, info 63 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/wrappers/wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the base wrapper class for Mujoco environments. 3 | Wrappers are useful for data collection and logging. Highly recommended. 4 | """ 5 | 6 | 7 | class Wrapper: 8 | env = None 9 | 10 | def __init__(self, env): 11 | self.env = env 12 | 13 | @classmethod 14 | def class_name(cls): 15 | return cls.__name__ 16 | 17 | def _warn_double_wrap(self): 18 | env = self.env 19 | while True: 20 | if isinstance(env, Wrapper): 21 | if env.class_name() == self.class_name(): 22 | raise Exception( 23 | "Attempted to double wrap with Wrapper: {}".format( 24 | self.__class__.__name__ 25 | ) 26 | ) 27 | env = env.env 28 | else: 29 | break 30 | 31 | def step(self, action): 32 | return self.env.step(action) 33 | 34 | def reset(self): 35 | return self.env.reset() 36 | 37 | def render(self, **kwargs): 38 | return self.env.render(**kwargs) 39 | 40 | def observation_spec(self): 41 | return self.env.observation_spec() 42 | 43 | def action_spec(self): 44 | return self.env.action_spec() 45 | 46 | @property 47 | def dof(self): 48 | return self.env.dof 49 | 50 | @property 51 | def unwrapped(self): 52 | if hasattr(self.env, "unwrapped"): 53 | return self.env.unwrapped 54 | else: 55 | return self.env 56 | 57 | # this method is a fallback option on any methods the original env might support 58 | def __getattr__(self, attr): 59 | # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called 60 | # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute) 61 | orig_attr = getattr(self.env, attr) 62 | if callable(orig_attr): 63 | 64 | def hooked(*args, **kwargs): 65 | result = orig_attr(*args, **kwargs) 66 | # prevent wrapped_class from becoming unwrapped 67 | if result == self.env: 68 | return self 69 | return result 70 | 71 | return hooked 72 | else: 73 | return orig_attr 74 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/arenas/table_arena.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spirl.data.block_stacking.src.robosuite.models.arenas import Arena 3 | from 
spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 4 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import array_to_string, string_to_array 5 | 6 | 7 | class TableArena(Arena): 8 | """Workspace that contains an empty table.""" 9 | 10 | def __init__( 11 | self, table_full_size=(0.8, 0.8, 0.8), table_friction=(1, 0.005, 0.0001) 12 | ): 13 | """ 14 | Args: 15 | table_full_size: full dimensions of the table 16 | friction: friction parameters of the table 17 | """ 18 | super().__init__(xml_path_completion("arenas/table_arena.xml")) 19 | 20 | self.table_full_size = np.array(table_full_size) 21 | self.table_half_size = self.table_full_size / 2 22 | self.table_friction = table_friction 23 | 24 | self.floor = self.worldbody.find("./geom[@name='floor']") 25 | self.table_body = self.worldbody.find("./body[@name='table']") 26 | self.table_collision = self.table_body.find("./geom[@name='table_collision']") 27 | self.table_visual = self.table_body.find("./geom[@name='table_visual']") 28 | self.table_top = self.table_body.find("./site[@name='table_top']") 29 | 30 | self.configure_location() 31 | 32 | def configure_location(self): 33 | self.bottom_pos = np.array([0, 0, 0]) 34 | self.floor.set("pos", array_to_string(self.bottom_pos)) 35 | 36 | self.center_pos = self.bottom_pos + np.array([0, 0, self.table_half_size[2]]) 37 | self.table_body.set("pos", array_to_string(self.center_pos)) 38 | self.table_collision.set("size", array_to_string(self.table_half_size)) 39 | self.table_collision.set("friction", array_to_string(self.table_friction)) 40 | self.table_visual.set("size", array_to_string(self.table_half_size)) 41 | 42 | self.table_top.set( 43 | "pos", array_to_string(np.array([0, 0, self.table_half_size[2]])) 44 | ) 45 | 46 | @property 47 | def table_top_abs(self): 48 | """Returns the absolute position of table top""" 49 | table_height = np.array([0, 0, self.table_full_size[2]]) 50 | return string_to_array(self.floor.get("pos")) + table_height 51 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/arenas/pegs_arena.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/demo_learning_curriculum.py: -------------------------------------------------------------------------------- 1 | """Demo of learning curriculum utilities. 2 | 3 | Several prior works have demonstrated the effectiveness of altering the 4 | start state distribution of training episodes for learning RL policies. 5 | We provide a generic utility for setting various types of learning 6 | curriculums which dictate how to sample from demonstration episodes 7 | when doing an environment reset. For more information see the 8 | `DemoSamplerWrapper` class. 
9 | 10 | Related work: 11 | 12 | [1] Reinforcement and Imitation Learning for Diverse Visuomotor Skills 13 | Yuke Zhu, Ziyu Wang, Josh Merel, Andrei Rusu, Tom Erez, Serkan Cabi,Saran Tunyasuvunakool, 14 | János Kramár, Raia Hadsell, Nando de Freitas, Nicolas Heess 15 | RSS 2018 16 | 17 | [2] Backplay: "Man muss immer umkehren" 18 | Cinjon Resnick, Roberta Raileanu, Sanyam Kapoor, Alex Peysakhovich, Kyunghyun Cho, Joan Bruna 19 | arXiv:1807.06919 20 | 21 | [3] DeepMimic: Example-Guided Deep Reinforcement Learning of Physics-Based Character Skills 22 | Xue Bin Peng, Pieter Abbeel, Sergey Levine, Michiel van de Panne 23 | Transactions on Graphics 2018 24 | 25 | [4] Approximately optimal approximate reinforcement learning 26 | Sham Kakade and John Langford 27 | ICML 2002 28 | """ 29 | 30 | import os 31 | 32 | from spirl.data.block_stacking.src import robosuite 33 | from spirl.data.block_stacking.src.robosuite import make 34 | from spirl.data.block_stacking.src.robosuite.wrappers import DemoSamplerWrapper 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | env = make( 40 | "SawyerPickPlace", 41 | has_renderer=True, 42 | has_offscreen_renderer=False, 43 | ignore_done=True, 44 | use_camera_obs=False, 45 | reward_shaping=True, 46 | gripper_visualization=True, 47 | ) 48 | 49 | env = DemoSamplerWrapper( 50 | env, 51 | demo_path=os.path.join( 52 | robosuite.models.assets_root, "demonstrations/SawyerPickPlace" 53 | ), 54 | need_xml=True, 55 | num_traj=-1, 56 | sampling_schemes=["uniform", "random"], 57 | scheme_ratios=[0.9, 0.1], 58 | ) 59 | 60 | for _ in range(100): 61 | env.reset() 62 | env.viewer.set_camera(0) 63 | env.render() 64 | for i in range(100): 65 | if i == 0: 66 | reward = env.reward() 67 | print("reward", reward) 68 | env.render() 69 | -------------------------------------------------------------------------------- /spirl/rl/components/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--path", help="path to the config file directory") 7 | 8 | # Folder settings 9 | parser.add_argument("--prefix", help="experiment prefix, if given creates subfolder in experiment directory") 10 | parser.add_argument('--new_dir', default=False, type=int, help='If True, concat datetime string to exp_dir.') 11 | parser.add_argument('--dont_save', default=False, type=int, 12 | help="if True, nothing is saved to disk. 
Note: this doesn't work") # TODO this doesn't work 13 | 14 | # Running protocol 15 | parser.add_argument('--resume', default='latest', type=str, metavar='PATH', 16 | help='path to latest checkpoint (default: none)') 17 | parser.add_argument('--mode', default='train', type=str, 18 | choices=['train', 'val', 'rollout'], 19 | help='mode of the program (training, validation, or generate rollout)') 20 | 21 | # Misc 22 | parser.add_argument('--seed', default=-1, type=int, 23 | help='overrides config/default seed for more convenient seed setting.') 24 | parser.add_argument('--gpu', default=0, type=int, 25 | help='will set CUDA_VISIBLE_DEVICES to selected value') 26 | parser.add_argument('--strict_weight_loading', default=True, type=int, 27 | help='if True, uses strict weight loading function') 28 | parser.add_argument('--deterministic', default=False, type=int, 29 | help='if True, sets fixed seeds for torch and numpy') 30 | parser.add_argument('--n_val_samples', default=10, type=int, 31 | help='number of validation episodes') 32 | parser.add_argument('--save_dir', type=str, 33 | help='directory for saving the generated rollouts in rollout mode') 34 | parser.add_argument('--config_override', default='', type=str, 35 | help='override to config file in format "key1.key2=val1,key3=val2"') 36 | 37 | # Debug 38 | parser.add_argument('--debug', default=False, type=int, 39 | help='if True, runs in debug mode') 40 | 41 | # Note 42 | parser.add_argument('--notes', default='', type=str, 43 | help='Notes for the run') 44 | 45 | return parser.parse_args() 46 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/assets/objects/round-nut.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /spirl/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def safe_entropy(dist, dim=None, eps=1e-12): 6 | """Computes entropy even if some entries are 0.""" 7 | return -torch.sum(dist * safe_log_prob(dist, eps), dim=dim) 8 | 9 | 10 | def safe_log_prob(tensor, eps=1e-12): 11 | """Safe log of probability values (must be between 0 and 1)""" 12 | return torch.log(torch.clamp(tensor, eps, 1 - eps)) 13 | 14 | 15 | def normalize(tensor, dim=1, eps=1e-7): 16 | norm = torch.clamp(tensor.sum(dim, keepdim=True), eps) 17 | return tensor / norm 18 | 19 | 20 | def gumbel_sample(shape, eps=1e-8): 21 | """Sample Gumbel noise.""" 22 | uniform = torch.rand(shape).float() 23 | return -torch.log(eps - torch.log(uniform + eps)) 24 | 25 | 26 | def gumbel_softmax_sample(logits, temp=1.): 27 | """Sample from the Gumbel softmax / concrete distribution.""" 28 | gumbel_noise = gumbel_sample(logits.size()).to(logits.device) 29 | return F.softmax((logits + gumbel_noise) / temp, dim=1) 30 | 31 | 32 | def log_cumsum(probs, dim=1, eps=1e-8): 33 | """Calculate log of inclusive cumsum.""" 34 | return torch.log(torch.cumsum(probs, dim=dim) + eps) 35 | 36 | 37 | def poisson_categorical_log_prior(length, rate, device): 38 | """Categorical prior populated with log probabilities of Poisson dist. 
39 | From: https://github.com/tkipf/compile/blob/b88b17411c37e1ed95459a0a779d71d5acef9e3f/utils.py#L58""" 40 | rate = torch.tensor(rate, dtype=torch.float32, device=device) 41 | values = torch.arange(1, length + 1, dtype=torch.float32, device=device).unsqueeze(0) 42 | log_prob_unnormalized = torch.log(rate) * values - rate - (values + 1).lgamma() 43 | # TODO(tkipf): Length-sensitive normalization. 44 | return F.log_softmax(log_prob_unnormalized, dim=1) # Normalize. 45 | 46 | 47 | def kl_categorical(preds, log_prior, eps=1e-8): 48 | """KL divergence between two categorical distributions.""" 49 | kl_div = preds * (torch.log(preds + eps) - log_prior) 50 | return kl_div.sum(1) 51 | 52 | 53 | class Dirac: 54 | """Dummy Dirac distribution.""" 55 | def __init__(self, val): 56 | self._val = val 57 | 58 | def sample(self): 59 | return self._val 60 | 61 | def rsample(self): 62 | return self._val 63 | 64 | def log_prob(self, val): 65 | return torch.tensor(int(val == self._val), dtype=torch.float32, device=self._val.device) 66 | 67 | @property 68 | def logits(self): 69 | """This is more of a dummy return.""" 70 | return self._val 71 | -------------------------------------------------------------------------------- /spirl/configs/rl/block_stacking/base_conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from spirl.utils.general_utils import AttrDict 4 | from spirl.rl.agents.ac_agent import SACAgent 5 | from spirl.rl.policies.mlp_policies import MLPPolicy 6 | from spirl.rl.components.critic import MLPCritic 7 | from spirl.rl.components.replay_buffer import UniformReplayBuffer 8 | from spirl.rl.envs.block_stacking import HighStack11StackEnvV0 9 | from spirl.rl.components.normalization import Normalizer 10 | from spirl.configs.default_data_configs.block_stacking import data_spec 11 | 12 | 13 | current_dir = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | notes = 'non-hierarchical RL experiments in block stacking env' 16 | 17 | configuration = { 18 | 'seed': 42, 19 | 'agent': SACAgent, 20 | 'environment': HighStack11StackEnvV0, 21 | 'data_dir': '.', 22 | 'num_epochs': 100, 23 | 'max_rollout_len': 1000, 24 | 'n_steps_per_epoch': 100000, 25 | 'n_warmup_steps': 5e3, 26 | } 27 | configuration = AttrDict(configuration) 28 | 29 | # Policy 30 | policy_params = AttrDict( 31 | action_dim=data_spec.n_actions, 32 | input_dim=data_spec.state_dim, 33 | n_layers=5, # number of policy network layers 34 | nz_mid=256, 35 | max_action_range=1., 36 | ) 37 | 38 | # Critic 39 | critic_params = AttrDict( 40 | action_dim=policy_params.action_dim, 41 | input_dim=policy_params.input_dim, 42 | output_dim=1, 43 | n_layers=2, # number of policy network layers 44 | nz_mid=256, 45 | action_input=True, 46 | ) 47 | 48 | # Replay Buffer 49 | replay_params = AttrDict( 50 | capacity=1e5, 51 | dump_replay=False, 52 | ) 53 | 54 | # Observation Normalization 55 | obs_norm_params = AttrDict( 56 | ) 57 | 58 | # Agent 59 | agent_config = AttrDict( 60 | policy=MLPPolicy, 61 | policy_params=policy_params, 62 | critic=MLPCritic, 63 | critic_params=critic_params, 64 | replay=UniformReplayBuffer, 65 | replay_params=replay_params, 66 | # obs_normalizer=Normalizer, 67 | # obs_normalizer_params=obs_norm_params, 68 | clip_q_target=False, 69 | batch_size=256, 70 | log_video_caption=True, 71 | ) 72 | 73 | # Dataset - Random data 74 | data_config = AttrDict() 75 | data_config.dataset_spec = data_spec 76 | 77 | # Environment 78 | env_config = AttrDict( 79 | name="block_stacking", 80 | 
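# (editor's note) reward normalization constant; 1. leaves env rewards
# unscaled (see the GymEnv-based environment wrapper for how it is applied)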
reward_norm=1., 81 | screen_width=data_spec.res, 82 | screen_height=data_spec.res, 83 | env_config=AttrDict(camera_name='agentview', 84 | screen_width=data_spec.res, 85 | screen_height=data_spec.res,) 86 | ) 87 | 88 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/grippers/robotiq_gripper.py: -------------------------------------------------------------------------------- 1 | """ 2 | 6-DoF gripper with its open/close variant 3 | """ 4 | import numpy as np 5 | 6 | from spirl.data.block_stacking.src.robosuite.models.grippers import Gripper 7 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 8 | 9 | 10 | class RobotiqGripperBase(Gripper): 11 | """ 12 | 6-DoF Robotiq gripper. 13 | """ 14 | 15 | def __init__(self): 16 | super().__init__(xml_path_completion("grippers/robotiq_gripper.xml")) 17 | 18 | @property 19 | def init_qpos(self): 20 | return [3.3161, 0., 0., 0., 0., 0.] 21 | 22 | @property 23 | def joints(self): 24 | return [ 25 | "robotiq_85_left_knuckle_joint", 26 | "robotiq_85_left_inner_knuckle_joint", 27 | "robotiq_85_left_finger_tip_joint", 28 | "robotiq_85_right_knuckle_joint", 29 | "robotiq_85_right_inner_knuckle_joint", 30 | "robotiq_85_right_finger_tip_joint", 31 | ] 32 | 33 | @property 34 | def dof(self): 35 | return 6 36 | 37 | def contact_geoms(self): 38 | return [ 39 | "robotiq_85_gripper_joint_0_L", 40 | "robotiq_85_gripper_joint_1_L", 41 | "robotiq_85_gripper_joint_0_R", 42 | "robotiq_85_gripper_joint_1_R", 43 | "robotiq_85_gripper_joint_2_L", 44 | "robotiq_85_gripper_joint_3_L", 45 | "robotiq_85_gripper_joint_2_R", 46 | "robotiq_85_gripper_joint_3_R", 47 | ] 48 | 49 | @property 50 | def visualization_sites(self): 51 | return ["grip_site", "grip_site_cylinder"] 52 | 53 | @property 54 | def left_finger_geoms(self): 55 | return [ 56 | "robotiq_85_gripper_joint_0_L", 57 | "robotiq_85_gripper_joint_1_L", 58 | "robotiq_85_gripper_joint_2_L", 59 | "robotiq_85_gripper_joint_3_L", 60 | ] 61 | 62 | @property 63 | def right_finger_geoms(self): 64 | return [ 65 | "robotiq_85_gripper_joint_0_R", 66 | "robotiq_85_gripper_joint_1_R", 67 | "robotiq_85_gripper_joint_2_R", 68 | "robotiq_85_gripper_joint_3_R", 69 | ] 70 | 71 | 72 | class RobotiqGripper(RobotiqGripperBase): 73 | """ 74 | 1-DoF variant of RobotiqGripperBase. 75 | """ 76 | 77 | def format_action(self, action): 78 | """ 79 | Args: 80 | action: 1 => open, -1 => closed 81 | """ 82 | assert len(action) == 1 83 | return -1 * np.ones(6) * action 84 | 85 | @property 86 | def dof(self): 87 | return 1 88 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/wrappers/README.md: -------------------------------------------------------------------------------- 1 | Wrappers 2 | ======== 3 | 4 | Wrappers offer additional features to an environment. Custom wrappers can be implemented for the purpose of collecting data, recording videos, and modifying environments. We provide some example wrapper implementations in this folder. All wrappers should inherit a base [Wrapper](wrapper.py) class. 5 | 6 | To use a wrapper for the environment, import it and wrap it around the environment. 
7 | 8 | ```python 9 | import robosuite 10 | from robosuite.wrappers import CustomWrapper 11 | 12 | env = robosuite.make("SawyerLift") 13 | env = CustomWrapper(env) 14 | ``` 15 | 16 | DataCollectionWrapper 17 | --------------------- 18 | [DataCollectionWrapper](data_collection_wrapper.py) saves trajectory information to disk. This [demo script](../scripts/demo_collect_and_playback_data.py) illustrates how to collect the rollout trajectories and replay them in the simulation. 19 | 20 | DemoSamplerWrapper 21 | ------------------ 22 | [DemoSamplerWrapper](demo_sampler_wrapper.py) loads demonstrations as a dataset of trajectories and randomly resets the start state of episodes along the demonstration trajectories based on a certain schedule. This functionality is useful for training RL agents and has been adopted in several prior work (see [references](../scripts/demo_learning_curriculum.py)). We provide a [demo script](../scripts/demo_learning_curriculum.py) to show how to configure the demo sampler to load demonstrations from files and use them to change the initial state distribution of episodes. 23 | 24 | GymWrapper 25 | ---------- 26 | [GymWrapper](gym_wrapper.py) implements the standard methods in [OpenAI Gym](https://github.com/openai/gym), which allows popular RL libraries to run with our environments using the same APIs as Gym. This [demo script](../scripts/demo_gym_functionality.py) shows how to convert robosuite environments into Gym interfaces using this wrapper. 27 | 28 | ```bash 29 | pip install gym 30 | ``` 31 | 32 | ## IKWrapper 33 | 34 | [IKWrapper](ik_wrapper.py) allows for using an end effector action space to control the robot in an environment instead of the default joint velocity action space. It uses our inverse kinematics robot controllers, located in the [controllers](../controllers) directory, which depend on PyBullet. In order to use this wrapper, you must run the following command to install PyBullet. 35 | 36 | ```bash 37 | pip install pybullet==1.9.5 38 | ``` 39 | 40 | The main difference between the joint velocity action space and the end effector action space supported by this wrapper is that instead of supplying joint velocities per arm, a **delta position** vector and **delta quaternion** (xyzw) should be supplied per arm, where these correspond to the relative changes in position and rotation of the end effector from its current pose. 
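Editor's note: a hedged sketch of assembling a single-arm IKWrapper action under the convention described above (delta position, then delta quaternion in xyzw order, then a gripper command); `demo_baxter_ik_control.py` earlier in this repo builds the bimanual analogue. The gripper sign convention here is an assumption -- check the wrapped environment:

```python
import numpy as np

dpos = np.array([0.01, 0.0, 0.0])       # move the end effector 1 cm along x
dquat = np.array([0.0, 0.0, 0.0, 1.0])  # identity quaternion (xyzw): keep orientation
grasp = np.array([1.0])                 # gripper command; sign convention is env-specific
action = np.concatenate([dpos, dquat, grasp])

# env = IKWrapper(robosuite.make("SawyerLift", has_renderer=True, use_camera_obs=False))
# obs, reward, done, info = env.step(action)
```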
41 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/grippers/gripper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the base class of all grippers 3 | """ 4 | from spirl.data.block_stacking.src.robosuite.models.base import MujocoXML 5 | 6 | 7 | class Gripper(MujocoXML): 8 | """Base class for grippers""" 9 | 10 | def __init__(self, fname): 11 | super().__init__(fname) 12 | 13 | def format_action(self, action): 14 | """ 15 | Given (-1,1) abstract control as np-array 16 | returns the (-1,1) control signals 17 | for underlying actuators as 1-d np array 18 | """ 19 | raise NotImplementedError 20 | 21 | @property 22 | def init_qpos(self): 23 | """ 24 | Returns rest(open) qpos of the gripper 25 | """ 26 | raise NotImplementedError 27 | 28 | @property 29 | def dof(self): 30 | """ 31 | Returns the number of DOF of the gripper 32 | """ 33 | raise NotImplementedError 34 | 35 | @property 36 | def joints(self): 37 | """ 38 | Returns a list of joint names of the gripper 39 | """ 40 | raise NotImplementedError 41 | 42 | def contact_geoms(self): 43 | """ 44 | Returns a list of names corresponding to the geoms 45 | used to determine contact with the gripper. 46 | """ 47 | return [] 48 | 49 | @property 50 | def visualization_sites(self): 51 | """ 52 | Returns a list of sites corresponding to the geoms 53 | used to aid visualization by human. 54 | (and should be hidden from robots) 55 | """ 56 | return [] 57 | 58 | @property 59 | def visualization_geoms(self): 60 | """ 61 | Returns a list of sites corresponding to the geoms 62 | used to aid visualization by human. 63 | (and should be hidden from robots) 64 | """ 65 | return [] 66 | 67 | @property 68 | def left_finger_geoms(self): 69 | """ 70 | Geoms corresponding to left finger of a gripper 71 | """ 72 | raise NotImplementedError 73 | 74 | @property 75 | def right_finger_geoms(self): 76 | """ 77 | Geoms corresponding to raise finger of a gripper 78 | """ 79 | raise NotImplementedError 80 | 81 | def hide_visualization(self): 82 | """ 83 | Hides all visualization geoms and sites. 
84 | This should be called before rendering to agents. 85 | """ 86 | for site_name in self.visualization_sites: 87 | site = self.worldbody.find(".//site[@name='{}']".format(site_name)) 88 | site.set("rgba", "0 0 0 0") 89 | for geom_name in self.visualization_geoms: 90 | geom = self.worldbody.find(".//geom[@name='{}']".format(geom_name)) 91 | geom.set("rgba", "0 0 0 0") 92 | -------------------------------------------------------------------------------- /spirl/modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spirl.utils.general_utils import AttrDict, get_dim_inds 3 | from spirl.modules.variational_inference import Gaussian 4 | 5 | 6 | class Loss(): 7 | def __init__(self, weight=1.0, breakdown=None): 8 | """ 9 | 10 | :param weight: the balance term on the loss 11 | :param breakdown: if specified, a breakdown of the loss by this dimension will be recorded 12 | """ 13 | self.weight = weight 14 | self.breakdown = breakdown 15 | 16 | def __call__(self, *args, weights=1, reduction='mean', store_raw=False, **kwargs): 17 | """ 18 | Computes the weighted loss on the error returned by self.compute(*args, **kwargs). 19 | :param weights: (optional) elementwise weighting of the error tensor 20 | :param reduction: reduction mode; only 'mean' is currently supported 21 | :return: AttrDict with the scalar loss 'value' and its 'weight' (and optionally 'breakdown' and 'error_mat') 22 | """ 23 | error = self.compute(*args, **kwargs) * weights 24 | if reduction != 'mean': 25 | raise NotImplementedError 26 | loss = AttrDict(value=error.mean(), weight=self.weight) 27 | if self.breakdown is not None: 28 | reduce_dim = get_dim_inds(error)[:self.breakdown] + get_dim_inds(error)[self.breakdown+1:] 29 | loss.breakdown = error.detach().mean(reduce_dim) if reduce_dim else error.detach() 30 | if store_raw: 31 | loss.error_mat = error.detach() 32 | return loss 33 | 34 | def compute(self, estimates, targets): 35 | raise NotImplementedError 36 | 37 | 38 | class L2Loss(Loss): 39 | def compute(self, estimates, targets, activation_function=None): 40 | # assert estimates.shape == targets.shape, "Input {} and targets {} for L2 loss need to have identical shape!"\ 41 | # .format(estimates.shape, targets.shape) 42 | if activation_function is not None: 43 | estimates = activation_function(estimates) 44 | if not isinstance(targets, torch.Tensor): 45 | targets = torch.tensor(targets, device=estimates.device, dtype=estimates.dtype) 46 | l2_loss = torch.nn.MSELoss(reduction='none')(estimates, targets) 47 | return l2_loss 48 | 49 | 50 | class KLDivLoss(Loss): 51 | def compute(self, estimates, targets): 52 | if not isinstance(estimates, Gaussian): estimates = Gaussian(estimates) 53 | if not isinstance(targets, Gaussian): targets = Gaussian(targets) 54 | kl_divergence = estimates.kl_divergence(targets) 55 | return kl_divergence 56 | 57 | 58 | class CELoss(Loss): 59 | compute = staticmethod(torch.nn.functional.cross_entropy) 60 | 61 | 62 | class PenaltyLoss(Loss): 63 | def compute(self, val): 64 | """Computes weighted mean of val as penalty loss.""" 65 | return val 66 | 67 | 68 | class NLL(Loss): 69 | # Note that cross entropy is an instance of NLL, as is L2 loss.
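# (For instance, L2 is, up to an additive constant, the NLL of a unit-variance Gaussian, and cross entropy is the NLL of a categorical distribution.)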
70 | def compute(self, estimates, targets): 71 | nll = estimates.nll(targets) 72 | return nll 73 | 74 | 75 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/tasks/table_top_task.py: -------------------------------------------------------------------------------- 1 | from spirl.data.block_stacking.src.robosuite.models.tasks import Task, UniformRandomSampler 2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import new_joint, array_to_string 3 | 4 | 5 | class TableTopTask(Task): 6 | """ 7 | Creates MJCF model of a tabletop task. 8 | 9 | A tabletop task consists of one robot interacting with a variable number of 10 | objects placed on the tabletop. This class combines the robot, the table 11 | arena, and the objects into a single MJCF model. 12 | """ 13 | 14 | def __init__(self, mujoco_arena, mujoco_robot, mujoco_objects, initializer=None): 15 | """ 16 | Args: 17 | mujoco_arena: MJCF model of robot workspace 18 | mujoco_robot: MJCF model of robot model 19 | mujoco_objects: a dict mapping object names to MJCF models of physical objects 20 | initializer: placement sampler to initialize object positions. 21 | """ 22 | super().__init__() 23 | 24 | self.merge_arena(mujoco_arena) 25 | self.merge_robot(mujoco_robot) 26 | self.merge_objects(mujoco_objects) 27 | if initializer is None: 28 | initializer = UniformRandomSampler() 29 | mjcfs = [x for _, x in self.mujoco_objects.items()] 30 | 31 | self.initializer = initializer 32 | self.initializer.setup(mjcfs, self.table_top_offset, self.table_size) 33 | 34 | def merge_robot(self, mujoco_robot): 35 | """Adds robot model to the MJCF model.""" 36 | self.robot = mujoco_robot 37 | self.merge(mujoco_robot) 38 | 39 | def merge_arena(self, mujoco_arena): 40 | """Adds arena model to the MJCF model.""" 41 | self.arena = mujoco_arena 42 | self.table_top_offset = mujoco_arena.table_top_abs 43 | self.table_size = mujoco_arena.table_full_size 44 | self.merge(mujoco_arena) 45 | 46 | def merge_objects(self, mujoco_objects): 47 | """Adds physical objects to the MJCF model.""" 48 | self.mujoco_objects = mujoco_objects 49 | self.objects = [] # xml manifestation 50 | self.targets = [] # xml manifestation 51 | self.max_horizontal_radius = 0 52 | 53 | for obj_name, obj_mjcf in mujoco_objects.items(): 54 | self.merge_asset(obj_mjcf) 55 | # Load object 56 | obj = obj_mjcf.get_collision(name=obj_name, site=True) 57 | obj.append(new_joint(name=obj_name, type="free")) 58 | self.objects.append(obj) 59 | self.worldbody.append(obj) 60 | 61 | self.max_horizontal_radius = max( 62 | self.max_horizontal_radius, obj_mjcf.get_horizontal_radius() 63 | ) 64 | 65 | def place_objects(self): 66 | """Places objects randomly until no collisions or max iterations hit.""" 67 | pos_arr, quat_arr = self.initializer.sample() 68 | for i in range(len(self.objects)): 69 | self.objects[i].set("pos", array_to_string(pos_arr[i])) 70 | self.objects[i].set("quat", array_to_string(quat_arr[i])) 71 | -------------------------------------------------------------------------------- /spirl/rl/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import inspect 3 | import numpy as np 4 | 5 | from spirl.utils.general_utils import flatten_dict, prefix_dict 6 | import csv 7 | 8 | class WandBLogger: 9 | """Logs to WandB.""" 10 | N_LOGGED_SAMPLES = 3 # how many examples should be logged in each logging step 11 | 12 | def __init__(self, exp_name, project_name, entity, path, conf,
exclude=None): 13 | """ 14 | :param exp_name: full name of experiment in WandB 15 | :param project_name: name of overall project 16 | :param entity: name of head entity in WandB that hosts the project 17 | :param path: path to which WandB log-files will be written 18 | :param conf: hyperparam config that will get logged to WandB 19 | :param exclude: (optional) list of (flattened) hyperparam names that should not get logged 20 | """ 21 | if exclude is None: exclude = [] 22 | flat_config = flatten_dict(conf) 23 | filtered_config = {k: v for k, v in flat_config.items() if (k not in exclude and not inspect.isclass(v))} 24 | print("INIT WANDB") 25 | wandb.init( 26 | resume=exp_name, 27 | project=project_name, 28 | config=filtered_config, 29 | dir=path, 30 | entity=entity, 31 | notes=conf.notes if 'notes' in conf else '' 32 | ) 33 | 34 | def log_scalar_dict(self, d, prefix='', step=None): 35 | """Logs all entries from a dict of scalars. Optionally can prefix all keys in dict before logging.""" 36 | if prefix: d = prefix_dict(d, prefix + '_') 37 | #wandb.log(d) if step is None else wandb.log(d, step=step) 38 | wandb.log(d) 39 | 40 | if 'train_episode_reward' in d: 41 | with open('train_episode_reward.csv','a') as f: 42 | if d['train_episode_length'] < 200: # heuristic: episodes shorter than 200 steps count as successes 43 | success = 1 44 | else: 45 | success = 0 46 | f.write(str(d['train_episode_reward'])+','+str(success)+'\n') 47 | 48 | def log_videos(self, vids, name, step=None): 49 | """Logs videos to WandB in mp4 format. 50 | Assumes list of numpy arrays as input with [time, channels, height, width].""" 51 | assert len(vids[0].shape) == 4 and vids[0].shape[1] == 3 52 | assert isinstance(vids[0], np.ndarray) 53 | if vids[0].max() <= 1.0: vids = [np.asarray(vid * 255.0, dtype=np.uint8) for vid in vids] 54 | # TODO(karl) expose the FPS as a parameter 55 | log_dict = {name: [wandb.Video(vid, fps=20, format="mp4") for vid in vids]} 56 | wandb.log(log_dict) if step is None else wandb.log(log_dict, step=step) 57 | 58 | def log_plot(self, fig, name, step=None): 59 | """Logs matplotlib graph to WandB. 60 | fig is a matplotlib figure handle.""" 61 | img = wandb.Image(fig) 62 | wandb.log({name: img}) if step is None else wandb.log({name: img}, step=step) 63 | 64 | @property 65 | def n_logged_samples(self): 66 | # TODO(karl) put this functionality in a base logger class + give it default parameters and config 67 | return self.N_LOGGED_SAMPLES 68 | 69 | 70 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/README.md: -------------------------------------------------------------------------------- 1 | Models 2 | ====== 3 | 4 | Models provide a toolkit for composing modular elements into a scene. The central goal of models is to provide a set of modular APIs that procedurally generate combinations of robots, arenas, and parameterized 3D objects, enabling us to learn control policies with better robustness and generalization. 5 | 6 | In MuJoCo, these scene elements are specified in [MJCF](http://mujoco.org/book/modeling.html#Summary) format and compiled into mjModel to instantiate the physical simulation. Each MJCF model can consist of multiple XML files as well as meshes and textures referenced from these XML files. These MJCF files are stored in the [assets](assets) folder.
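As a rough sketch of this composition pattern (the `TableTopTask` and `CanObject` classes are defined in this repo; the `TableArena` and `SawyerRobot` constructors and the `get_model` call follow robosuite conventions and may differ in detail here):

```python
from spirl.data.block_stacking.src.robosuite.models.arenas import TableArena
from spirl.data.block_stacking.src.robosuite.models.robots import SawyerRobot
from spirl.data.block_stacking.src.robosuite.models.objects import CanObject
from spirl.data.block_stacking.src.robosuite.models.tasks import TableTopTask

arena = TableArena()                        # tabletop workspace
robot = SawyerRobot()                       # robot MJCF model
objects = {"can": CanObject()}              # dict of named object models

task = TableTopTask(arena, robot, objects)  # merge all elements into one MJCF scene
task.place_objects()                        # sample initial object placements
model = task.get_model()                    # compile the scene into a MuJoCo model
```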
7 | 8 | Below we describe the types of scene elements that we support: 9 | 10 | Robots 11 | ------ 12 | The [robots](robots) folder contains robot classes which load robot specifications from MJCF/URDF files and instantiate a robot mjModel. All robot classes should inherit from a base [Robot](robots/robot.py) class which defines a set of common robot APIs. 13 | 14 | Grippers 15 | -------- 16 | The [grippers](grippers) folder consists of a variety of end-effector models that can be mounted to the arms of a robot model by the [`add_gripper`](robots/robot.py#L20) method in the robot class. 17 | 18 | Objects 19 | ------- 20 | [Objects](objects) are small scene elements that robots interact with using their actuators. Objects can be defined either as 3D [meshes](http://mujoco.org/book/modeling.html#mesh) (e.g., in STL format) or procedurally generated from primitive shapes of MuJoCo [geoms](http://mujoco.org/book/modeling.html#geom). 21 | 22 | [MujocoObject](objects/mujoco_object.py) is the base object class. [MujocoXMLObject](objects/mujoco_object.py) is the base class for all objects that are loaded from MJCF XML files. [MujocoGeneratedObject](objects/mujoco_object.py) is the base class for procedurally generated objects with support for size and other physical property randomization. 23 | 24 | Arenas 25 | ------ 26 | [Arenas](arenas) define the workspace, such as a tabletop or a set of bins, where the robot performs the tasks. All arena classes should inherit from a base [Arena](arenas/arena.py) class. By default, each arena contains 3 cameras (see [example](assets/arenas/empty_arena.xml)). The `frontview` camera provides an overview of the scene, which is often used to generate visualizations and video recordings. The `agentview` camera is the canonical camera typically used for visual observations when training visuomotor policies. The `birdview` camera is a top-down camera which is useful for debugging the placements of objects in the scene. 27 | 28 | Tasks 29 | ----- 30 | [Tasks](tasks) put together all the necessary scene elements, which typically consist of a robot, an arena, and a set of objects, into a model of the whole scene. It handles merging the MJCF models of individual elements and setting the initial placements of these elements in the scene. The resulting scene model is compiled and loaded into the MuJoCo backend to perform simulation. All task classes should inherit from a base [Task](tasks/task.py) class which specifies a set of common APIs for model merging and placement initialization. 31 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/block_task_generator.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import numpy as np 3 | 4 | from spirl.utils.general_utils import ParamDict 5 | 6 | 7 | class BlockTaskGenerator: 8 | def __init__(self, hp, n_blocks): 9 | self._hp = self._default_hparams().overwrite(hp) 10 | self._n_blocks = n_blocks 11 | self._rng = np.random.default_rng(seed=self._hp.seed) 12 | 13 | def _default_hparams(self): 14 | default_dict = ParamDict({ 15 | 'seed': None, 16 | }) 17 | return default_dict 18 | 19 | def sample(self): 20 | """Generates a task definition in the form of a list of transition tuples.
21 | Each tuple contains (bottom_block, top_block).""" 22 | raise NotImplementedError 23 | 24 | def _sample_tower(self, size, blocks): 25 | """Samples a single tower of the specified size, popping blocks from the queue of block indices.""" 26 | block = blocks.popleft() 27 | tasks = [] 28 | for _ in range(size): 29 | next_block = blocks.popleft() 30 | tasks.append((block, next_block)) 31 | block = next_block 32 | return tasks 33 | 34 | def _init_block_queue(self): 35 | return deque(self._rng.permutation(self._n_blocks)) 36 | 37 | 38 | class SingleTowerBlockTaskGenerator(BlockTaskGenerator): 39 | """Samples a single tower with a limit on the number of stacked blocks.""" 40 | def _default_hparams(self): 41 | default_dict = ParamDict({ 42 | 'max_tower_height': 4, # maximum height of the target tower 43 | }) 44 | return super()._default_hparams().overwrite(default_dict) 45 | 46 | def sample(self): 47 | block_queue = self._init_block_queue() 48 | size = self._sample_size() 49 | return self._sample_tower(size, block_queue) 50 | 51 | def _sample_size(self): 52 | return self._rng.integers(1, self._hp.max_tower_height + 1) # ensure at least one stack is performed 53 | 54 | 55 | class FixedSizeSingleTowerBlockTaskGenerator(SingleTowerBlockTaskGenerator): 56 | """Samples single tower, always samples maximum height.""" 57 | def _sample_size(self): 58 | return self._hp.max_tower_height 59 | 60 | 61 | class MultiTowerBlockTaskGenerator(BlockTaskGenerator): 62 | """Samples multiple towers with a limit on the number of stacked blocks per tower.""" 63 | def _default_hparams(self): 64 | default_dict = ParamDict({ 65 | 'max_tower_height': 4, # maximum height of target tower(s) 66 | }) 67 | return super()._default_hparams().overwrite(default_dict) 68 | 69 | def sample(self): 70 | block_queue = self._init_block_queue() 71 | task = [] 72 | while len(block_queue) > 1: 73 | size = self._sample_size(max_height=min(len(block_queue)-1, self._hp.max_tower_height)) 74 | task += self._sample_tower(size, block_queue) 75 | return task 76 | 77 | def _sample_size(self, max_height): 78 | return self._rng.integers(1, max_height + 1) # ensure at least one stack is performed 79 | 80 | 81 | if __name__ == '__main__': 82 | task_gen = MultiTowerBlockTaskGenerator(hp={}, n_blocks=10) 83 | for _ in range(10): 84 | print(task_gen.sample()) 85 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/scripts/demo_collect_and_playback_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Record trajectory data with the DataCollectionWrapper and play it back. 3 | 4 | Example: 5 | $ python demo_collect_and_playback_data.py --environment BaxterLift 6 | """ 7 | 8 | import os 9 | import argparse 10 | from glob import glob 11 | import numpy as np 12 | 13 | from spirl.data.block_stacking.src import robosuite 14 | from spirl.data.block_stacking.src.robosuite import DataCollectionWrapper 15 | 16 | 17 | def collect_random_trajectory(env, timesteps=1000): 18 | """Run a random policy to collect trajectories. 19 | 20 | The rollout trajectory is saved to files in npz format. 21 | Modify the DataCollectionWrapper to add new fields or change data formats.
22 | """ 23 | 24 | obs = env.reset() 25 | dof = env.dof 26 | 27 | for t in range(timesteps): 28 | action = 0.5 * np.random.randn(dof) 29 | obs, reward, done, info = env.step(action) 30 | env.render() 31 | if t % 100 == 0: 32 | print(t) 33 | 34 | 35 | def playback_trajectory(env, ep_dir): 36 | """Playback data from an episode. 37 | 38 | Args: 39 | ep_dir: The path to the directory containing data for an episode. 40 | """ 41 | 42 | # first reload the model from the xml 43 | xml_path = os.path.join(ep_dir, "model.xml") 44 | with open(xml_path, "r") as f: 45 | env.reset_from_xml_string(f.read()) 46 | 47 | state_paths = os.path.join(ep_dir, "state_*.npz") 48 | 49 | # read states back, load them one by one, and render 50 | t = 0 51 | for state_file in sorted(glob(state_paths)): 52 | print(state_file) 53 | dic = np.load(state_file) 54 | states = dic["states"] 55 | for state in states: 56 | env.sim.set_state_from_flattened(state) 57 | env.sim.forward() 58 | env.render() 59 | t += 1 60 | if t % 100 == 0: 61 | print(t) 62 | 63 | 64 | if __name__ == "__main__": 65 | 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--environment", type=str, default="SawyerStack") 68 | parser.add_argument("--directory", type=str, default="/tmp/") 69 | parser.add_argument("--timesteps", type=int, default=2000) 70 | args = parser.parse_args() 71 | 72 | # create original environment 73 | env = robosuite.make( 74 | args.environment, 75 | ignore_done=True, 76 | use_camera_obs=False, 77 | has_renderer=True, 78 | control_freq=100, 79 | ) 80 | data_directory = args.directory 81 | 82 | # wrap the environment with data collection wrapper 83 | env = DataCollectionWrapper(env, data_directory) 84 | 85 | # testing to make sure multiple env.reset calls don't create multiple directories 86 | env.reset() 87 | env.reset() 88 | env.reset() 89 | 90 | # collect some data 91 | print("Collecting some random data...") 92 | collect_random_trajectory(env, timesteps=args.timesteps) 93 | 94 | # playback some data 95 | _ = input("Press any key to begin the playback...") 96 | print("Playing back the data...") 97 | data_directory = env.ep_directory 98 | playback_trajectory(env, data_directory) 99 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/tasks/nut_assembly_task.py: -------------------------------------------------------------------------------- 1 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import new_joint, array_to_string 2 | from spirl.data.block_stacking.src.robosuite.models.tasks import Task, UniformRandomPegsSampler 3 | 4 | 5 | class NutAssemblyTask(Task): 6 | """ 7 | Creates MJCF model of a nut assembly task. 8 | 9 | A nut assembly task consists of one robot picking up nuts from a table and 10 | and assembling them into pegs positioned on the tabletop. This class combines 11 | the robot, the arena with pegs, and the nut objetcts into a single MJCF model. 12 | """ 13 | 14 | def __init__(self, mujoco_arena, mujoco_robot, mujoco_objects, initializer=None): 15 | """ 16 | Args: 17 | mujoco_arena: MJCF model of robot workspace 18 | mujoco_robot: MJCF model of robot model 19 | mujoco_objects: a list of MJCF models of physical objects 20 | initializer: placement sampler to initialize object positions. 
21 | """ 22 | super().__init__() 23 | 24 | self.object_metadata = [] 25 | self.merge_arena(mujoco_arena) 26 | self.merge_robot(mujoco_robot) 27 | self.merge_objects(mujoco_objects) 28 | 29 | if initializer is None: 30 | initializer = UniformRandomPegsSampler() 31 | self.initializer = initializer 32 | self.initializer.setup(self.mujoco_objects, self.table_offset, self.table_size) 33 | 34 | def merge_robot(self, mujoco_robot): 35 | """Adds robot model to the MJCF model.""" 36 | self.robot = mujoco_robot 37 | self.merge(mujoco_robot) 38 | 39 | def merge_arena(self, mujoco_arena): 40 | """Adds arena model to the MJCF model.""" 41 | self.arena = mujoco_arena 42 | self.table_offset = mujoco_arena.table_top_abs 43 | self.table_size = mujoco_arena.table_full_size 44 | self.table_body = mujoco_arena.table_body 45 | self.peg1_body = mujoco_arena.peg1_body 46 | self.peg2_body = mujoco_arena.peg2_body 47 | self.merge(mujoco_arena) 48 | 49 | def merge_objects(self, mujoco_objects): 50 | """Adds physical objects to the MJCF model.""" 51 | self.mujoco_objects = mujoco_objects 52 | self.objects = {} # xml manifestation 53 | self.max_horizontal_radius = 0 54 | for obj_name, obj_mjcf in mujoco_objects.items(): 55 | self.merge_asset(obj_mjcf) 56 | # Load object 57 | obj = obj_mjcf.get_collision(name=obj_name, site=True) 58 | obj.append(new_joint(name=obj_name, type="free", damping="0.0005")) 59 | self.objects[obj_name] = obj 60 | self.worldbody.append(obj) 61 | 62 | self.max_horizontal_radius = max( 63 | self.max_horizontal_radius, obj_mjcf.get_horizontal_radius() 64 | ) 65 | 66 | def place_objects(self): 67 | """Places objects randomly until no collisions or max iterations hit.""" 68 | pos_arr, quat_arr = self.initializer.sample() 69 | for k, obj_name in enumerate(self.objects): 70 | self.objects[obj_name].set("pos", array_to_string(pos_arr[k])) 71 | self.objects[obj_name].set("quat", array_to_string(quat_arr[k])) 72 | -------------------------------------------------------------------------------- /spirl/components/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--path", help="path to the config file directory") 7 | 8 | # Folder settings 9 | parser.add_argument("--prefix", help="experiment prefix, if given creates subfolder in experiment directory") 10 | parser.add_argument('--new_dir', default=False, type=int, help='If True, concat datetime string to exp_dir.') 11 | parser.add_argument('--dont_save', default=False, type=int, 12 | help="if True, nothing is saved to disk. 
Note: this doesn't work") # TODO this doesn't work 13 | 14 | # Running protocol 15 | parser.add_argument('--resume', default='latest', type=str, metavar='PATH', 16 | help='path to the checkpoint to resume from (default: latest)') 17 | parser.add_argument('--train', default=True, type=int, 18 | help='if False, will run one validation epoch') 19 | parser.add_argument('--test_prediction', default=True, type=int, 20 | help="if False, prediction isn't run at validation time") 21 | parser.add_argument('--skip_first_val', default=True, type=int, 22 | help='if True, will skip the first validation epoch') 23 | parser.add_argument('--val_sweep', default=False, type=int, 24 | help='if True, runs validation on all existing model checkpoints') 25 | 26 | # Misc 27 | parser.add_argument('--gpu', default=0, type=int, 28 | help='will set CUDA_VISIBLE_DEVICES to selected value') 29 | parser.add_argument('--strict_weight_loading', default=True, type=int, 30 | help='if True, uses strict weight loading function') 31 | parser.add_argument('--deterministic', default=False, type=int, 32 | help='if True, sets fixed seeds for torch and numpy') 33 | parser.add_argument('--log_interval', default=500, type=int, 34 | help='number of updates per training log') 35 | parser.add_argument('--per_epoch_img_logs', default=1, type=int, 36 | help='number of image loggings per epoch') 37 | parser.add_argument('--val_data_size', default=-1, type=int, 38 | help='number of sequences in the validation set. If -1, the full dataset is used') 39 | parser.add_argument('--val_interval', default=5, type=int, 40 | help='number of epochs per validation') 41 | 42 | # Debug 43 | parser.add_argument('--detect_anomaly', default=False, type=int, 44 | help='if True, uses autograd.detect_anomaly()') 45 | parser.add_argument('--feed_random_data', default=False, type=int, 46 | help='if True, we feed random data to the model to test its performance') 47 | parser.add_argument('--train_loop_pdb', default=False, type=int, 48 | help='if True, opens a pdb into training loop') 49 | parser.add_argument('--debug', default=False, type=int, 50 | help='if True, runs in debug mode') 51 | 52 | # Video saving (int flag like the others: argparse's type=bool would parse any non-empty string as True) 53 | parser.add_argument('--save2mp4', default=False, type=int, 54 | help='if True, videos will be saved locally') 55 | 56 | return parser.parse_args() 57 | -------------------------------------------------------------------------------- /spirl/components/trainer_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | 6 | def save_checkpoint(state, filename='checkpoint.pth'): 7 | torch.save(state, filename) 8 | 9 | 10 | class BaseTrainer(): 11 | def override_defaults(self, policyparams): 12 | for name, value in policyparams.items(): 13 | print('overriding param {} to value {}'.format(name, value)) 14 | if value == getattr(self._hp, name): 15 | raise ValueError("attribute {} is identical to its default value!".format(name)) 16 | self._hp.set_hparam(name, value) 17 | 18 | def call_hooks(self, inputs, output, losses, epoch): 19 | for hook in self.hooks: 20 | hook.run(inputs, output, losses, epoch) 21 | 22 | def check_nan_grads(self, grads): 23 | # grads: iterable of (name, param) pairs whose gradients are checked for NaNs 24 | return torch.stack([torch.isnan(p.grad).any() for n, p in grads]).any() 25 | 26 | def nan_grads_hook(self, inputs, output, losses, epoch): 27 | non_none = list(filter(lambda x: x[1].grad is not None, self.model.named_parameters())) 28 | if self.check_nan_grads(non_none): 29 |
self.dump_debug_data(inputs, output, losses, epoch) 30 | nan_parameters = [n for n,v in filter(lambda x: torch.isnan(x[-1].grad).any(), non_none)] 31 | non_nan_parameters = [n for n, v in filter(lambda x: not torch.isnan(x[-1].grad).any(), non_none)] 32 | # grad_dict = self.print_nan_grads() 33 | print('nan gradients here: ', nan_parameters) 34 | print('these parameters are OK: ', non_nan_parameters) 35 | 36 | def get_nan_module(layer): 37 | return np.unique(list(map(lambda n: '.'.join(n.split('.')[:layer]), nan_parameters))) 38 | 39 | import pdb; pdb.set_trace() 40 | raise ValueError('there are NaN gradients') 41 | 42 | def dump_debug_data(self, inputs, outputs, losses, epoch): 43 | print("Detected NaN gradients, dumping data for diagnosis...") 44 | debug_dict = dict() 45 | 46 | def clean_dict(d): 47 | clean_d = dict() 48 | for key, val in d.items(): 49 | if isinstance(val, torch.Tensor): 50 | clean_d[key] = val.data.cpu().numpy() 51 | elif isinstance(val, dict): 52 | clean_d[key] = clean_dict(val) 53 | return clean_d 54 | 55 | debug_dict['inputs'] = clean_dict(inputs) 56 | debug_dict['outputs'] = clean_dict(outputs) 57 | debug_dict['losses'] = clean_dict(losses) 58 | 59 | import pickle as pkl 60 | f = open(os.path.join(self._hp.exp_path, "nan_debug_info.pkl"), "wb") 61 | pkl.dump(debug_dict, f) 62 | f.close() 63 | 64 | save_checkpoint({ 65 | 'epoch': epoch, 66 | 'state_dict': self.model.state_dict(), 67 | 'optimizer': self.optimizer.state_dict(), 68 | }, os.path.join(self._hp.exp_path, "nan_debug_ckpt.pth")) 69 | 70 | def print_nan_grads(self): 71 | grad_dict = {} 72 | for name, param in self.model.named_parameters(recurse=True): 73 | print("{}:\t\t{}".format(name, bool(torch.isnan(param.grad).any().data.cpu().numpy()))) 74 | grad_dict[name] = param.grad 75 | return grad_dict 76 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/utils/mujoco_py_renderer.py: -------------------------------------------------------------------------------- 1 | from mujoco_py import MjViewer 2 | from mujoco_py.generated import const 3 | import glfw 4 | from collections import defaultdict 5 | 6 | 7 | class CustomMjViewer(MjViewer): 8 | 9 | keypress = defaultdict(list) 10 | keyup = defaultdict(list) 11 | keyrepeat = defaultdict(list) 12 | 13 | def key_callback(self, window, key, scancode, action, mods): 14 | if action == glfw.PRESS: 15 | tgt = self.keypress 16 | elif action == glfw.RELEASE: 17 | tgt = self.keyup 18 | elif action == glfw.REPEAT: 19 | tgt = self.keyrepeat 20 | else: 21 | return 22 | if tgt.get(key): 23 | for fn in tgt[key]: 24 | fn(window, key, scancode, action, mods) 25 | if tgt.get("any"): 26 | for fn in tgt["any"]: 27 | fn(window, key, scancode, action, mods) 28 | # retain functionality for closing the viewer 29 | if key == glfw.KEY_ESCAPE: 30 | super().key_callback(window, key, scancode, action, mods) 31 | elif not tgt.get("any"): 32 | # only use default mujoco callbacks if "any" callbacks are unset 33 | super().key_callback(window, key, scancode, action, mods) 34 | 35 | 36 | class MujocoPyRenderer: 37 | def __init__(self, sim): 38 | """ 39 | Args: 40 | sim: MjSim object 41 | """ 42 | self.viewer = CustomMjViewer(sim) 43 | self.callbacks = {} 44 | 45 | def set_camera(self, camera_id): 46 | """ 47 | Set the camera view to the specified camera ID.
48 | """ 49 | self.viewer.cam.fixedcamid = camera_id 50 | self.viewer.cam.type = const.CAMERA_FIXED 51 | 52 | def render(self): 53 | # safe for multiple calls 54 | self.viewer.render() 55 | 56 | def close(self): 57 | """ 58 | Destroys the open window and renders (pun intended) the viewer useless. 59 | """ 60 | glfw.destroy_window(self.viewer.window) 61 | self.viewer = None 62 | 63 | def add_keypress_callback(self, key, fn): 64 | """ 65 | Allows for custom callback functions for the viewer. Called on key down. 66 | Parameter 'any' will ensure that the callback is called on any key down, 67 | and block default mujoco viewer callbacks from executing, except for 68 | the ESC callback to close the viewer. 69 | """ 70 | self.viewer.keypress[key].append(fn) 71 | 72 | def add_keyup_callback(self, key, fn): 73 | """ 74 | Allows for custom callback functions for the viewer. Called on key up. 75 | Parameter 'any' will ensure that the callback is called on any key up, 76 | and block default mujoco viewer callbacks from executing, except for 77 | the ESC callback to close the viewer. 78 | """ 79 | self.viewer.keyup[key].append(fn) 80 | 81 | def add_keyrepeat_callback(self, key, fn): 82 | """ 83 | Allows for custom callback functions for the viewer. Called on key repeat. 84 | Parameter 'any' will ensure that the callback is called on any key repeat, 85 | and block default mujoco viewer callbacks from executing, except for 86 | the ESC callback to close the viewer. 87 | """ 88 | self.viewer.keyrepeat[key].append(fn) 89 | -------------------------------------------------------------------------------- /spirl/rl/agents/skill_space_agent.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import contextlib 3 | 4 | import numpy as np 5 | 6 | from spirl.rl.components.agent import BaseAgent 7 | from spirl.utils.general_utils import ParamDict, split_along_axis, AttrDict 8 | from spirl.utils.pytorch_utils import map2torch, map2np, no_batchnorm_update 9 | 10 | 11 | class SkillSpaceAgent(BaseAgent): 12 | """Agent that acts based on pre-trained VAE skill decoder.""" 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self._update_model_params() # transfer some parameters to model 16 | 17 | self._policy = self._hp.model(self._hp.model_params, logger=None) 18 | self.load_model_weights(self._policy, self._hp.model_checkpoint, self._hp.model_epoch) 19 | 20 | self.action_plan = deque() 21 | 22 | def _default_hparams(self): 23 | default_dict = ParamDict({ 24 | 'model': None, # policy class 25 | 'model_params': None, # parameters for the policy class 26 | 'model_checkpoint': None, # checkpoint path of the model 27 | 'model_epoch': 'latest', # epoch that checkpoint should be loaded for (defaults to latest) 28 | }) 29 | return super()._default_hparams().overwrite(default_dict) 30 | 31 | def _act(self, obs): 32 | assert len(obs.shape) == 2 and obs.shape[0] == 1 # assume single-observation batches with leading 1-dim 33 | if not self.action_plan: 34 | # generate action plan if the current one is empty 35 | split_obs = self._split_obs(obs) 36 | with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress(): 37 | actions = self._policy.decode(map2torch(split_obs.z, self._hp.device), 38 | map2torch(split_obs.cond_input, self._hp.device), 39 | self._policy.n_rollout_steps) 40 | self.action_plan = deque(split_along_axis(map2np(actions), axis=1)) 41 | return AttrDict(action=self.action_plan.popleft()) 42 | 43 | def 
reset(self): 44 | self.action_plan = deque() # reset action plan 45 | 46 | def update(self, experience_batch): 47 | return {} # TODO(karl) implement finetuning for policy 48 | 49 | def _split_obs(self, obs): 50 | assert obs.shape[1] == self._policy.state_dim + self._policy.latent_dim 51 | return AttrDict( 52 | cond_input=obs[:, :-self._policy.latent_dim], # condition decoding on state 53 | z=obs[:, -self._policy.latent_dim:], 54 | ) 55 | 56 | def sync_networks(self): 57 | pass # TODO(karl) only need to implement if we implement finetuning 58 | 59 | def _update_model_params(self): 60 | self._hp.model_params.device = self._hp.device # transfer device to low-level model 61 | self._hp.model_params.batch_size = 1 # run only single-element batches 62 | 63 | def _act_rand(self, obs): 64 | return self._act(obs) 65 | 66 | 67 | class ACSkillSpaceAgent(SkillSpaceAgent): 68 | """Unflattens prior input part of observation.""" 69 | def _split_obs(self, obs): 70 | unflattened_obs = map2np(self._policy.unflatten_obs( 71 | map2torch(obs[:, :-self._policy.latent_dim], device=self.device))) 72 | return AttrDict( 73 | cond_input=unflattened_obs.prior_obs, 74 | z=obs[:, -self._policy.latent_dim:], 75 | ) 76 | -------------------------------------------------------------------------------- /spirl/data/block_stacking/src/robosuite/models/objects/xml_objects.py: -------------------------------------------------------------------------------- 1 | from spirl.data.block_stacking.src.robosuite.models.objects import MujocoXMLObject 2 | from spirl.data.block_stacking.src.robosuite.utils.mjcf_utils import xml_path_completion 3 | 4 | 5 | class BottleObject(MujocoXMLObject): 6 | """ 7 | Bottle object 8 | """ 9 | 10 | def __init__(self): 11 | super().__init__(xml_path_completion("objects/bottle.xml")) 12 | 13 | 14 | class CanObject(MujocoXMLObject): 15 | """ 16 | Coke can object (used in SawyerPickPlace) 17 | """ 18 | 19 | def __init__(self): 20 | super().__init__(xml_path_completion("objects/can.xml")) 21 | 22 | 23 | class LemonObject(MujocoXMLObject): 24 | """ 25 | Lemon object 26 | """ 27 | 28 | def __init__(self): 29 | super().__init__(xml_path_completion("objects/lemon.xml")) 30 | 31 | 32 | class MilkObject(MujocoXMLObject): 33 | """ 34 | Milk carton object (used in SawyerPickPlace) 35 | """ 36 | 37 | def __init__(self): 38 | super().__init__(xml_path_completion("objects/milk.xml")) 39 | 40 | 41 | class BreadObject(MujocoXMLObject): 42 | """ 43 | Bread loaf object (used in SawyerPickPlace) 44 | """ 45 | 46 | def __init__(self): 47 | super().__init__(xml_path_completion("objects/bread.xml")) 48 | 49 | 50 | class CerealObject(MujocoXMLObject): 51 | """ 52 | Cereal box object (used in SawyerPickPlace) 53 | """ 54 | 55 | def __init__(self): 56 | super().__init__(xml_path_completion("objects/cereal.xml")) 57 | 58 | 59 | class SquareNutObject(MujocoXMLObject): 60 | """ 61 | Square nut object (used in SawyerNutAssembly) 62 | """ 63 | 64 | def __init__(self): 65 | super().__init__(xml_path_completion("objects/square-nut.xml")) 66 | 67 | 68 | class RoundNutObject(MujocoXMLObject): 69 | """ 70 | Round nut (used in SawyerNutAssembly) 71 | """ 72 | 73 | def __init__(self): 74 | super().__init__(xml_path_completion("objects/round-nut.xml")) 75 | 76 | 77 | class MilkVisualObject(MujocoXMLObject): 78 | """ 79 | Visual fiducial of milk carton (used in SawyerPickPlace). 80 | 81 | Fiducial objects are not involved in collision physics. 82 | They provide a point of reference to indicate a position. 
83 | """ 84 | 85 | def __init__(self): 86 | super().__init__(xml_path_completion("objects/milk-visual.xml")) 87 | 88 | 89 | class BreadVisualObject(MujocoXMLObject): 90 | """ 91 | Visual fiducial of bread loaf (used in SawyerPickPlace) 92 | """ 93 | 94 | def __init__(self): 95 | super().__init__(xml_path_completion("objects/bread-visual.xml")) 96 | 97 | 98 | class CerealVisualObject(MujocoXMLObject): 99 | """ 100 | Visual fiducial of cereal box (used in SawyerPickPlace) 101 | """ 102 | 103 | def __init__(self): 104 | super().__init__(xml_path_completion("objects/cereal-visual.xml")) 105 | 106 | 107 | class CanVisualObject(MujocoXMLObject): 108 | """ 109 | Visual fiducial of coke can (used in SawyerPickPlace) 110 | """ 111 | 112 | def __init__(self): 113 | super().__init__(xml_path_completion("objects/can-visual.xml")) 114 | 115 | 116 | class PlateWithHoleObject(MujocoXMLObject): 117 | """ 118 | Square plate with a hole in the center (used in BaxterPegInHole) 119 | """ 120 | 121 | def __init__(self): 122 | super().__init__(xml_path_completion("objects/plate-with-hole.xml")) 123 | --------------------------------------------------------------------------------