├── models ├── __init__.py ├── encoder │ ├── __init__.py │ ├── r3m │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── models_language.py │ │ │ └── models_r3m.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ └── data_loaders.py │ │ ├── cfgs │ │ │ ├── hydra │ │ │ │ ├── output │ │ │ │ │ └── local.yaml │ │ │ │ └── launcher │ │ │ │ │ └── local.yaml │ │ │ └── config_rep.yaml │ │ ├── r3m_base.yaml │ │ └── __init__.py │ └── resnet.py ├── dino.py ├── dummy.py ├── proprio.py ├── decoder │ └── transposed_conv.py └── vit.py ├── datasets ├── __init__.py ├── img_transforms.py ├── point_maze_dset.py └── traj_dset.py ├── planning ├── __init__.py ├── base_planner.py ├── objectives.py ├── gd.py ├── cem.py └── mpc.py ├── env ├── wall │ ├── envs │ │ └── __init__.py │ ├── data │ │ └── __init__.py │ └── wall_env_wrapper.py ├── pointmaze │ ├── gridcraft │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── wrappers.py │ │ └── grid_spec.py │ ├── waypoint_controller.py │ ├── point_maze_wrapper.py │ ├── q_iteration.py │ └── dynamic_mjc.py ├── pusht │ ├── __init__.py │ └── pusht_wrapper.py ├── deformable_env │ └── src │ │ └── sim │ │ ├── assets │ │ └── xarm │ │ │ ├── xarm_gripper │ │ │ └── meshes │ │ │ │ ├── base_link.STL │ │ │ │ ├── left_finger.STL │ │ │ │ ├── right_finger.STL │ │ │ │ ├── left_inner_knuckle.STL │ │ │ │ ├── left_outer_knuckle.STL │ │ │ │ ├── right_inner_knuckle.STL │ │ │ │ ├── right_outer_knuckle.STL │ │ │ │ ├── stick_4cm.obj │ │ │ │ ├── stick_3cm.obj │ │ │ │ ├── stick_6cm.obj │ │ │ │ ├── stick_1cm.obj │ │ │ │ └── stick_2cm.obj │ │ │ ├── xarm_description │ │ │ └── meshes │ │ │ │ └── xarm6 │ │ │ │ ├── visual │ │ │ │ ├── base.stl │ │ │ │ ├── link1.stl │ │ │ │ ├── link2.stl │ │ │ │ ├── link3.stl │ │ │ │ ├── link4.stl │ │ │ │ ├── link5.stl │ │ │ │ └── link6.stl │ │ │ │ └── collision │ │ │ │ ├── base.mtl │ │ │ │ ├── link1.mtl │ │ │ │ ├── link2.mtl │ │ │ │ ├── link3.mtl │ │ │ │ ├── link4.mtl │ │ │ │ ├── link5.mtl │ │ │ │ ├── link6.mtl │ │ │ │ ├── link2_vhacd2.mtl │ │ │ │ └── link6_vhacd.obj │ │ │ ├── LICENSE │ │ │ ├── link6.urdf │ │ │ ├── base.urdf │ │ │ ├── link1.urdf │ │ │ ├── link6_com.urdf │ │ │ ├── link3.urdf │ │ │ ├── base_com.urdf │ │ │ ├── link5.urdf │ │ │ ├── link4.urdf │ │ │ ├── link1_com.urdf │ │ │ ├── link2.urdf │ │ │ ├── link3_com.urdf │ │ │ ├── link5_com.urdf │ │ │ ├── link4_com.urdf │ │ │ └── link2_com.urdf │ │ ├── sim_env │ │ ├── flex_scene.py │ │ ├── cameras.py │ │ └── robot_env.py │ │ └── data_gen │ │ └── data.py ├── __init__.py └── serial_vector_env.py ├── .gitignore ├── conf ├── encoder │ ├── dummy.yaml │ ├── r3m.yaml │ ├── resnet.yaml │ ├── dino.yaml │ └── dino_cls.yaml ├── action_encoder │ ├── dummy.yaml │ └── proprio.yaml ├── proprio_encoder │ ├── dummy.yaml │ └── proprio.yaml ├── decoder │ ├── vqvae.yaml │ └── transposed_conv.yaml ├── predictor │ └── vit.yaml ├── planner │ ├── cem.yaml │ ├── gd.yaml │ ├── mpc_cem.yaml │ └── mpc_gd.yaml ├── env │ ├── point_maze.yaml │ ├── wall.yaml │ ├── pusht.yaml │ ├── deformable_env.yaml │ ├── rope.yaml │ └── granular.yaml ├── plan.yaml ├── plan_wall.yaml ├── plan_point_maze.yaml ├── plan_pusht.yaml └── train.yaml ├── assets └── intro.png ├── distributed_fn ├── __init__.py ├── launch.py └── distributed.py ├── custom_resolvers.py ├── metrics ├── lpipsPyTorch │ ├── __init__.py │ └── modules │ │ ├── utils.py │ │ ├── lpips.py │ │ └── networks.py └── image_metrics.py ├── LICENSE ├── preprocessor.py └── utils.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /planning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /env/wall/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /env/pointmaze/gridcraft/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/encoder/r3m/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/encoder/r3m/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /env/pusht/__init__.py: -------------------------------------------------------------------------------- 1 | from .pusht_env import PushTEnv 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | outputs 4 | plan_outputs 5 | -------------------------------------------------------------------------------- /conf/encoder/dummy.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.dummy.DummyModel 2 | emb_dim: 7 -------------------------------------------------------------------------------- /conf/action_encoder/dummy.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.dummy.DummyRepeatActionEncoder 2 | -------------------------------------------------------------------------------- /conf/encoder/r3m.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoder.r3m.load_r3m 2 | modelid: resnet18 -------------------------------------------------------------------------------- /conf/proprio_encoder/dummy.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.dummy.DummyRepeatActionEncoder 2 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/assets/intro.png -------------------------------------------------------------------------------- /conf/encoder/resnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.encoder.resnet.resnet18 2 | pretrained: True 3 | unit_norm: False -------------------------------------------------------------------------------- /conf/encoder/dino.yaml: -------------------------------------------------------------------------------- 
1 | _target_: models.dino.DinoV2Encoder 2 | name: "dinov2_vits14" 3 | feature_key: "x_norm_patchtokens" 4 | -------------------------------------------------------------------------------- /conf/encoder/dino_cls.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.dino.DinoV2Encoder 2 | name: "dinov2_vits14" 3 | feature_key: "x_norm_clstoken" 4 | -------------------------------------------------------------------------------- /conf/action_encoder/proprio.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.proprio.ProprioceptiveEmbedding 2 | num_frames: 1 3 | tubelet_size: 1 4 | use_3d_pos: False -------------------------------------------------------------------------------- /conf/proprio_encoder/proprio.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.proprio.ProprioceptiveEmbedding 2 | num_frames: 1 3 | tubelet_size: 1 4 | use_3d_pos: False -------------------------------------------------------------------------------- /conf/decoder/vqvae.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.vqvae.VQVAE 2 | channel: 384 3 | n_embed: 2048 4 | n_res_block: 4 5 | n_res_channel: 128 6 | quantize: False -------------------------------------------------------------------------------- /conf/predictor/vit.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.vit.ViTPredictor 2 | depth: 6 3 | heads: 16 4 | mlp_dim: 2048 5 | dropout: 0.1 6 | emb_dropout: 0 7 | pool: 'mean' -------------------------------------------------------------------------------- /conf/planner/cem.yaml: -------------------------------------------------------------------------------- 1 | _target_: planning.cem.CEMPlanner 2 | horizon: 5 3 | topk: 30 4 | num_samples: 300 5 | var_scale: 1 6 | opt_steps: 30 7 | eval_every: 1 8 | 9 | name: cem -------------------------------------------------------------------------------- /conf/decoder/transposed_conv.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.decoder.transposed_conv.TransposedConvDecoder 2 | observation_shape: [3, 224, 224] 3 | depth: 64 4 | kernel_size: 5 5 | stride: 3 -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/base_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/base_link.STL -------------------------------------------------------------------------------- /conf/planner/gd.yaml: -------------------------------------------------------------------------------- 1 | _target_: planning.gd.GDPlanner 2 | horizon: 5 3 | action_noise: 0.003 4 | sample_type: 'randn' # 'zero' or 'randn' 5 | lr: 1 6 | opt_steps: 1000 7 | eval_every: 10 8 | 9 | name: gd -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_finger.STL -------------------------------------------------------------------------------- 
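The encoder, decoder, predictor, and planner YAML files above are plain Hydra instantiation targets: each `_target_` names a class or factory and the remaining keys become its constructor arguments. Below is a minimal stand-alone sketch (not part of the repo) of turning one such config into a model, assuming the repo root is on PYTHONPATH and torch.hub can download the DINOv2 weights; the config path and dummy input are illustrative only.

from omegaconf import OmegaConf
from hydra.utils import instantiate
import torch

cfg = OmegaConf.load("conf/encoder/dino.yaml")   # _target_: models.dino.DinoV2Encoder
encoder = instantiate(cfg)                       # builds DinoV2Encoder(name=..., feature_key=...)

images = torch.randn(2, 3, 224, 224)             # 224 is divisible by the ViT-S/14 patch size
with torch.no_grad():
    emb = encoder(images)                        # (2, 256, 384) for x_norm_patchtokens
print(emb.shape)

The same pattern applies to the decoder, predictor, and planner configs; the planners additionally take runtime arguments (world model, objective, preprocessor) that the training/planning scripts pass in alongside the YAML values.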
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_finger.STL -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_inner_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_inner_knuckle.STL -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_outer_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_outer_knuckle.STL -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_inner_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_inner_knuckle.STL -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_outer_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_outer_knuckle.STL -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/base.stl -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link1.stl -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link2.stl -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link3.stl -------------------------------------------------------------------------------- 
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link4.stl -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link5.stl -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link6.stl -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/base.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link1.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link2.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link3.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link4.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link5.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | 
Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link6.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl None 5 | Ns 500 6 | Ka 0.8 0.8 0.8 7 | Kd 0.8 0.8 0.8 8 | Ks 0.8 0.8 0.8 9 | d 1 10 | illum 2 11 | -------------------------------------------------------------------------------- /conf/planner/mpc_cem.yaml: -------------------------------------------------------------------------------- 1 | _target_: planning.mpc.MPCPlanner 2 | max_iter: null # unlimited if null 3 | n_taken_actions: 5 4 | sub_planner: 5 | target: planning.cem.CEMPlanner 6 | horizon: 5 7 | topk: 30 8 | num_samples: 300 9 | var_scale: 1 10 | opt_steps: 30 11 | eval_every: 1 12 | 13 | name: mpc_cem -------------------------------------------------------------------------------- /distributed_fn/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import ( 2 | get_rank, 3 | get_local_rank, 4 | is_primary, 5 | synchronize, 6 | get_world_size, 7 | all_reduce, 8 | all_gather, 9 | reduce_dict, 10 | data_sampler, 11 | LOCAL_PROCESS_GROUP, 12 | ) 13 | from .launch import launch 14 | -------------------------------------------------------------------------------- /conf/planner/mpc_gd.yaml: -------------------------------------------------------------------------------- 1 | _target_: planning.mpc.MPCPlanner 2 | max_iter: null # unlimited if null 3 | n_taken_actions: 1 4 | sub_planner: 5 | target: planning.gd.GDPlanner 6 | horizon: 5 7 | action_noise: 0.003 8 | sample_type: 'randn' # 'zero' or 'randn' 9 | lr: 1 10 | opt_steps: 1000 11 | eval_every: 10 12 | 13 | name: mpc_gd -------------------------------------------------------------------------------- /datasets/img_transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | def default_transform(img_size=224): 4 | return transforms.Compose( 5 | [ 6 | transforms.Resize(img_size), 7 | transforms.CenterCrop(img_size), 8 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 9 | ] 10 | ) -------------------------------------------------------------------------------- /models/encoder/r3m/cfgs/hydra/output/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: ./r3moutput/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 5 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} 6 | sweep: 7 | dir: ./r3moutput/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link2_vhacd2.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Default_OBJ 5 | Ns 225.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | -------------------------------------------------------------------------------- /env/wall/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | from .single import DotDataset, DotDatasetConfig, Sample 3 | from .wall import WallDataset, WallDatasetConfig 4 | from .configs import ConfigBase 5 | 6 | class DatasetType(Enum): 7 | Single = auto() 8 | Multiple = auto() 9 | Wall = auto() 10 | WallExpert = auto() 11 | WallEigenfunc = auto() 12 | -------------------------------------------------------------------------------- /models/encoder/r3m/cfgs/hydra/launcher/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | cpus_per_task: 20 5 | gpus_per_node: 0 6 | tasks_per_node: 1 7 | timeout_min: 600 8 | mem_gb: 64 9 | name: ${hydra.job.name} 10 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher 11 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 12 | -------------------------------------------------------------------------------- /conf/env/point_maze.yaml: -------------------------------------------------------------------------------- 1 | name: point_maze 2 | args: [] 3 | kwargs: {} 4 | 5 | dataset: 6 | _target_: "datasets.point_maze_dset.load_point_maze_slice_train_val" 7 | n_rollout: null 8 | normalize_action: ${normalize_action} 9 | data_path: ${oc.env:DATASET_DIR}/point_maze 10 | split_ratio: 0.9 11 | transform: 12 | _target_: "datasets.img_transforms.default_transform" 13 | img_size: ${img_size} 14 | 15 | decoder_path: null 16 | num_workers: 16 -------------------------------------------------------------------------------- /conf/env/wall.yaml: -------------------------------------------------------------------------------- 1 | name: wall 2 | args: [] 3 | kwargs: {} 4 | 5 | dataset: 6 | _target_: "datasets.wall_dset.load_wall_slice_train_val" 7 | n_rollout: null 8 | normalize_action: ${normalize_action} 9 | data_path: ${oc.env:DATASET_DIR}/wall_single 10 | split_ratio: 0.9 11 | split_mode: "random" 12 | transform: 13 | _target_: "datasets.img_transforms.default_transform" 14 | img_size: ${img_size} 15 | 16 | decoder_path: null 17 | num_workers: 16 -------------------------------------------------------------------------------- /custom_resolvers.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import OmegaConf 3 | 4 | @hydra.main(config_path=None) 5 | def register_resolvers(cfg): 6 | pass 7 | 8 | # Define the resolver function 9 | def replace_slash(value: str) -> str: 10 | return value.replace('/', '_') 11 | 12 | # Register the resolver with Hydra 13 | OmegaConf.register_new_resolver("replace_slash", replace_slash) 14 | 15 | if __name__ == "__main__": 16 | register_resolvers() 17 | 18 | -------------------------------------------------------------------------------- /conf/env/pusht.yaml: -------------------------------------------------------------------------------- 1 | name: pusht 2 | args: [] 3 | kwargs: 4 | with_velocity: true 5 | with_target: true 6 | 7 | dataset: 8 | _target_: "datasets.pusht_dset.load_pusht_slice_train_val" 9 | with_velocity: true 10 | n_rollout: null 11 | normalize_action: ${normalize_action} 12 | data_path: ${oc.env:DATASET_DIR}/pusht_noise 13 | split_ratio: 0.9 14 | transform: 15 | _target_: "datasets.img_transforms.default_transform" 16 | img_size: ${img_size} 17 | 18 | decoder_path: null 19 | num_workers: 16 -------------------------------------------------------------------------------- 
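custom_resolvers.py above registers the replace_slash resolver that the planning configs use to build run directories from `${replace_slash:${model_name}}`. A small stand-alone sketch of that behaviour follows; the config values are made up, and the lambda simply mirrors the replace_slash function in the repo.

from omegaconf import OmegaConf

# equivalent to the resolver registered in custom_resolvers.py
OmegaConf.register_new_resolver("replace_slash", lambda value: value.replace("/", "_"))

cfg = OmegaConf.create({
    "model_name": "outputs/2024-01-01/pusht_run",               # hypothetical checkpoint name
    "run_dir": "plan_outputs/${replace_slash:${model_name}}",
})
print(cfg.run_dir)  # plan_outputs/outputs_2024-01-01_pusht_run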
/conf/env/deformable_env.yaml: -------------------------------------------------------------------------------- 1 | name: deformable_env 2 | args: [] 3 | kwargs: 4 | object_name: "granular" 5 | 6 | load_dir: "" 7 | 8 | dataset: 9 | _target_: "datasets.deformable_env_dset.load_deformable_dset_slice_train_val" 10 | n_rollout: null 11 | normalize_action: ${normalize_action} 12 | data_path: ${oc.env:DATASET_DIR}/deformable 13 | object_name: "granular" 14 | split_ratio: 0.9 15 | transform: 16 | _target_: "datasets.img_transforms.default_transform" 17 | img_size: ${img_size} 18 | 19 | decoder_path: null 20 | num_workers: 16 -------------------------------------------------------------------------------- /metrics/lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /models/encoder/r3m/cfgs/config_rep.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: local 4 | - override hydra/output: local 5 | 6 | 7 | # snapshot 8 | save_snapshot: false 9 | load_snap: "" 10 | # replay buffer 11 | num_workers: 10 12 | batch_size: 32 #256 13 | train_steps: 2000000 14 | eval_freq: 20000 15 | # misc 16 | seed: 1 17 | device: cuda 18 | # experiment 19 | experiment: train_r3m 20 | # agent 21 | lr: 1e-4 22 | # data 23 | alpha: 0.2 24 | dataset: "ego4d" 25 | wandbproject: 26 | wandbuser: 27 | doaug: "none" 28 | datapath: 29 | 30 | agent: 31 | _target_: r3m.R3M 32 | device: ${device} 33 | lr: ${lr} 34 | hidden_dim: 1024 35 | size: 34 36 | l2weight: 0.00001 37 | l1weight: 0.00001 38 | tcnweight: 1.0 39 | langweight: 0.0 40 | l2dist: true 41 | bs: ${batch_size} 42 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_4cm.obj: -------------------------------------------------------------------------------- 1 | # File units = meters 2 | mtllib Stick.mtl 3 | g Part 1 4 | v 0.0217921 0.0189459 0.1 5 | v 0.0217921 0.0189459 0 6 | v -0.0182079 0.0189459 0 7 | v -0.0182079 0.0189459 0.1 8 | v -0.0182079 -0.0210542 0 9 | v -0.0182079 -0.0210542 0.1 10 | v 0.0217921 -0.0210542 0 11 | v 0.0217921 -0.0210542 0.1 12 | vn 0 1 0 13 | vn -1 0 0 14 | vn 0 -1 0 15 | vn 1 0 0 16 | vn 0 0 1 17 | vn 0 0 -1 18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000 19 | o mesh0 20 | f 1//1 2//1 3//1 21 | f 3//1 4//1 1//1 22 | o mesh1 23 | f 4//2 3//2 5//2 24 | f 5//2 6//2 4//2 25 | o mesh2 26 | f 6//3 5//3 7//3 27 | f 7//3 8//3 6//3 28 | o mesh3 29 | f 8//4 7//4 2//4 30 | f 2//4 1//4 8//4 31 | o mesh4 32 | f 8//5 1//5 4//5 33 | f 4//5 6//5 8//5 34 | o mesh5 35 | f 5//6 3//6 2//6 36 | f 2//6 7//6 5//6 37 | -------------------------------------------------------------------------------- 
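As a usage sketch for the lpips helper defined in metrics/lpipsPyTorch/__init__.py above: the random tensors below are placeholders, LPIPS implementations conventionally expect images scaled to [-1, 1], and the first call downloads the backbone and linear-layer weights.

import torch
from metrics.lpipsPyTorch import lpips

# two batches of RGB images, shape (B, 3, H, W), assumed scaled to [-1, 1]
x = torch.rand(4, 3, 224, 224) * 2 - 1
y = torch.rand(4, 3, 224, 224) * 2 - 1

score = lpips(x, y, net_type="alex")  # scalar tensor: LPIPS averaged over the batch
print(score.item())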
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_3cm.obj: -------------------------------------------------------------------------------- 1 | # File units = meters 2 | mtllib stick_3cm.mtl 3 | g Part 1 4 | v 0.0167921 0.0139459 0.1 5 | v 0.0167921 0.0139459 0 6 | v -0.0132079 0.0139459 0 7 | v -0.0132079 0.0139459 0.1 8 | v -0.0132079 -0.0160542 0 9 | v -0.0132079 -0.0160542 0.1 10 | v 0.0167921 -0.0160542 0 11 | v 0.0167921 -0.0160542 0.1 12 | vn 0 1 0 13 | vn -1 0 0 14 | vn 0 -1 0 15 | vn 1 0 0 16 | vn 0 0 1 17 | vn 0 0 -1 18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000 19 | o mesh0 20 | f 1//1 2//1 3//1 21 | f 3//1 4//1 1//1 22 | o mesh1 23 | f 4//2 3//2 5//2 24 | f 5//2 6//2 4//2 25 | o mesh2 26 | f 6//3 5//3 7//3 27 | f 7//3 8//3 6//3 28 | o mesh3 29 | f 8//4 7//4 2//4 30 | f 2//4 1//4 8//4 31 | o mesh4 32 | f 8//5 1//5 4//5 33 | f 4//5 6//5 8//5 34 | o mesh5 35 | f 5//6 3//6 2//6 36 | f 2//6 7//6 5//6 37 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_6cm.obj: -------------------------------------------------------------------------------- 1 | # File units = meters 2 | mtllib stick_6cm.mtl 3 | g Part 1 4 | v 0.0317921 0.0289459 0.1 5 | v 0.0317921 0.0289459 0 6 | v -0.0282079 0.0289459 0 7 | v -0.0282079 0.0289459 0.1 8 | v -0.0282079 -0.0310542 0 9 | v -0.0282079 -0.0310542 0.1 10 | v 0.0317921 -0.0310542 0 11 | v 0.0317921 -0.0310542 0.1 12 | vn 0 1 0 13 | vn -1 0 0 14 | vn 0 -1 0 15 | vn 1 0 0 16 | vn 0 0 1 17 | vn 0 0 -1 18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000 19 | o mesh0 20 | f 1//1 2//1 3//1 21 | f 3//1 4//1 1//1 22 | o mesh1 23 | f 4//2 3//2 5//2 24 | f 5//2 6//2 4//2 25 | o mesh2 26 | f 6//3 5//3 7//3 27 | f 7//3 8//3 6//3 28 | o mesh3 29 | f 8//4 7//4 2//4 30 | f 2//4 1//4 8//4 31 | o mesh4 32 | f 8//5 1//5 4//5 33 | f 4//5 6//5 8//5 34 | o mesh5 35 | f 5//6 3//6 2//6 36 | f 2//6 7//6 5//6 37 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_1cm.obj: -------------------------------------------------------------------------------- 1 | # File units = meters 2 | mtllib stick.mtl 3 | g Part 1 4 | v 0.0100144 0.00716822 0.1 5 | v 0.0100144 0.00716822 0 6 | v -0.00643032 0.00716822 0 7 | v -0.00643032 0.00716822 0.1 8 | v -0.00643032 -0.00927652 0 9 | v -0.00643032 -0.00927652 0.1 10 | v 0.0100144 -0.00927652 0 11 | v 0.0100144 -0.00927652 0.1 12 | vn 0 1 0 13 | vn -1 0 0 14 | vn 0 -1 0 15 | vn 1 0 0 16 | vn 0 0 1 17 | vn 0 0 -1 18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000 19 | o mesh0 20 | f 1//1 2//1 3//1 21 | f 3//1 4//1 1//1 22 | o mesh1 23 | f 4//2 3//2 5//2 24 | f 5//2 6//2 4//2 25 | o mesh2 26 | f 6//3 5//3 7//3 27 | f 7//3 8//3 6//3 28 | o mesh3 29 | f 8//4 7//4 2//4 30 | f 2//4 1//4 8//4 31 | o mesh4 32 | f 8//5 1//5 4//5 33 | f 4//5 6//5 8//5 34 | o mesh5 35 | f 5//6 3//6 2//6 36 | f 2//6 7//6 5//6 37 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_2cm.obj: -------------------------------------------------------------------------------- 1 | # File units = meters 2 | mtllib stick_2cm.mtl 3 | g Part 1 4 | v 0.0117921 0.00894585 0.1 5 | v 0.0117921 0.00894585 0 6 | v -0.00820794 0.00894585 0 7 | v -0.00820794 0.00894585 0.1 8 | v -0.00820794 -0.0110542 0 9 | v -0.00820794 -0.0110542 0.1 10 | v 0.0117921 -0.0110542 0 11 | v 
0.0117921 -0.0110542 0.1 12 | vn 0 1 0 13 | vn -1 0 0 14 | vn 0 -1 0 15 | vn 1 0 0 16 | vn 0 0 1 17 | vn 0 0 -1 18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000 19 | o mesh0 20 | f 1//1 2//1 3//1 21 | f 3//1 4//1 1//1 22 | o mesh1 23 | f 4//2 3//2 5//2 24 | f 5//2 6//2 4//2 25 | o mesh2 26 | f 6//3 5//3 7//3 27 | f 7//3 8//3 6//3 28 | o mesh3 29 | f 8//4 7//4 2//4 30 | f 2//4 1//4 8//4 31 | o mesh4 32 | f 8//5 1//5 4//5 33 | f 4//5 6//5 8//5 34 | o mesh5 35 | f 5//6 3//6 2//6 36 | f 2//6 7//6 5//6 37 | -------------------------------------------------------------------------------- /conf/env/rope.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | obj: "rope" 3 | obj_params: 4 | rope_length: 2.5 5 | stiffness: 0.8 6 | z_rotation: 10.0 7 | 8 | 9 | # data collection 10 | base: 0 11 | n_episode: 1000 12 | n_timestep: 20 13 | n_worker: 40 14 | 15 | # sim env 16 | headless: True # False: OpenGL visualization 17 | camera_view: 1 # 0 (top view), 1, 2, 3, 4 18 | screenWidth: 224 19 | screenHeight: 224 20 | 21 | robot_type: 'xarm6' 22 | robot_end_idx: 6 23 | robot_num_dofs: 6 24 | robot_speed_inv: 300 25 | 26 | action_dim: 4 # [x_start, z_start, x_end, z_end] 27 | action_space: 4 # random action space scope 28 | 29 | # Tool 30 | gripper: False 31 | pusher_len: 1.0 32 | 33 | # Save particles 34 | fps: False 35 | fps_number: 2000 36 | 37 | rob_obj_dist_thresh: 0.2 38 | contact_interval: 40 39 | non_contact_interval: 80 40 | 41 | # others 42 | color_threshold: 0.01 43 | -------------------------------------------------------------------------------- /models/dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | torch.hub._validate_not_a_forked_repo=lambda a,b,c: True 5 | 6 | class DinoV2Encoder(nn.Module): 7 | def __init__(self, name, feature_key): 8 | super().__init__() 9 | self.name = name 10 | self.base_model = torch.hub.load("facebookresearch/dinov2", name) 11 | self.feature_key = feature_key 12 | self.emb_dim = self.base_model.num_features 13 | if feature_key == "x_norm_patchtokens": 14 | self.latent_ndim = 2 15 | elif feature_key == "x_norm_clstoken": 16 | self.latent_ndim = 1 17 | else: 18 | raise ValueError(f"Invalid feature key: {feature_key}") 19 | 20 | self.patch_size = self.base_model.patch_size 21 | 22 | def forward(self, x): 23 | emb = self.base_model.forward_features(x)[self.feature_key] 24 | if self.latent_ndim == 1: 25 | emb = emb.unsqueeze(1) # dummy patch dim 26 | return emb -------------------------------------------------------------------------------- /conf/env/granular.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | obj: "granular" 3 | obj_params: 4 | granular_scale: 0.2 5 | area: 5.0 # change particle numbers 6 | # area: 1.5 7 | xz_ratio: 1.0 8 | granular_dis: 0.03 9 | 10 | # data collection 11 | base: 0 12 | n_episode: 1000 13 | n_timestep: 20 14 | n_worker: 40 15 | 16 | # sim env 17 | headless: True # False: OpenGL visualization 18 | camera_view: 1 # 0 (top view), 1, 2, 3, 4 19 | screenWidth: 224 20 | screenHeight: 224 21 | 22 | robot_type: 'xarm6' 23 | robot_end_idx: 6 24 | robot_num_dofs: 6 25 | robot_speed_inv: 300 26 | 27 | action_dim: 4 # [x_start, z_start, x_end, z_end] 28 | action_space: 4 # random action space scope 29 | 30 | # Tool 31 | gripper: False 32 | pusher_len: 1.3 33 | 34 | # Save particles 35 | fps: False 36 | fps_number: 2000 37 | 38 | rob_obj_dist_thresh: 
0.2 39 | contact_interval: 40 40 | non_contact_interval: 80 41 | 42 | # others 43 | color_threshold: 0.01 44 | -------------------------------------------------------------------------------- /metrics/lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | from .pointmaze import U_MAZE 3 | register( 4 | id="pusht", 5 | entry_point="env.pusht.pusht_wrapper:PushTWrapper", 6 | max_episode_steps=300, 7 | reward_threshold=1.0, 8 | ) 9 | register( 10 | id='point_maze', 11 | entry_point='env.pointmaze:PointMazeWrapper', 12 | max_episode_steps=300, 13 | kwargs={ 14 | 'maze_spec':U_MAZE, 15 | 'reward_type':'sparse', 16 | 'reset_target': False, 17 | 'ref_min_score': 23.85, 18 | 'ref_max_score': 161.86, 19 | 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5' 20 | } 21 | ) 22 | register( 23 | id="wall", 24 | entry_point="env.wall.wall_env_wrapper:WallEnvWrapper", 25 | max_episode_steps=300, 26 | reward_threshold=1.0, 27 | ) 28 | 29 | register( 30 | id="deformable_env", 31 | entry_point="env.deformable_env.FlexEnvWrapper:FlexEnvWrapper", 32 | max_episode_steps=300, 33 | reward_threshold=1.0, 34 | ) -------------------------------------------------------------------------------- /env/pointmaze/gridcraft/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def flat_to_one_hot(val, ndim): 4 | """ 5 | 6 | >>> flat_to_one_hot(2, ndim=4) 7 | array([ 0., 0., 1., 0.]) 8 | >>> flat_to_one_hot(4, ndim=5) 9 | array([ 0., 0., 0., 0., 1.]) 10 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5) 11 | array([[ 0., 0., 1., 0., 0.], 12 | [ 0., 0., 0., 0., 1.], 13 | [ 0., 0., 0., 1., 0.]]) 14 | """ 15 | shape =np.array(val).shape 16 | v = np.zeros(shape + (ndim,)) 17 | if len(shape) == 1: 18 | v[np.arange(shape[0]), val] = 1.0 19 | else: 20 | v[val] = 1.0 21 | return v 22 | 23 | def one_hot_to_flat(val): 24 | """ 25 | >>> one_hot_to_flat(np.array([0,0,0,0,1])) 26 | 4 27 | >>> one_hot_to_flat(np.array([0,0,1,0])) 28 | 2 29 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]])) 30 | array([2, 0, 1]) 31 | """ 32 | idxs = np.array(np.where(val == 1.0))[-1] 33 | if len(val.shape) == 1: 34 | return int(idxs) 35 | return idxs -------------------------------------------------------------------------------- 
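The register() calls in env/__init__.py above expose the wrappers through the standard gym factory; importing the env package is what triggers the registration. The sketch below assumes the classic (pre-gymnasium) gym API that the repo imports, with kwargs taken from conf/env/pusht.yaml.

import gym
import env  # noqa: F401  -- importing the package runs the register() calls in env/__init__.py

e = gym.make("pusht", with_velocity=True, with_target=True)   # kwargs forwarded to PushTWrapper
obs = e.reset()
obs, reward, done, info = e.step(e.action_space.sample())     # old-style 4-tuple step
e.close()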
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 gaoyuezhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /planning/base_planner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class BasePlanner(ABC): 7 | def __init__( 8 | self, 9 | wm, 10 | action_dim, 11 | objective_fn, 12 | preprocessor, 13 | evaluator, 14 | wandb_run, 15 | log_filename, 16 | **kwargs, 17 | ): 18 | self.wm = wm 19 | self.action_dim = action_dim 20 | self.objective_fn = objective_fn 21 | self.preprocessor = preprocessor 22 | self.device = next(wm.parameters()).device 23 | 24 | self.evaluator = evaluator 25 | self.wandb_run = wandb_run 26 | self.log_filename = log_filename # do not log if None 27 | 28 | def dump_logs(self, logs): 29 | logs_entry = { 30 | key: ( 31 | value.item() 32 | if isinstance(value, (np.float32, np.int32, np.int64)) 33 | else value 34 | ) 35 | for key, value in logs.items() 36 | } 37 | if self.log_filename is not None: 38 | with open(self.log_filename, "a") as file: 39 | file.write(json.dumps(logs_entry) + "\n") 40 | 41 | @abstractmethod 42 | def plan(self): 43 | pass 44 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/sim_env/flex_scene.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyflex 3 | 4 | from .scenes import * 5 | 6 | 7 | class FlexScene: 8 | def __init__(self): 9 | self.obj = None 10 | self.env_idx = None 11 | self.scene_params = None 12 | 13 | self.property_params = None 14 | self.clusters = None 15 | 16 | def set_scene(self, obj, obj_params): 17 | self.obj = obj 18 | 19 | if self.obj == "rope": 20 | self.env_idx = 26 21 | self.scene_params, self.property_params = rope_scene(obj_params) 22 | elif self.obj == "granular": 23 | self.env_idx = 35 24 | self.scene_params, self.property_params = granular_scene(obj_params) 25 | elif self.obj == "cloth": 26 | self.env_idx = 29 27 | self.scene_params, self.property_params = cloth_scene(obj_params) 28 | else: 29 | raise ValueError("Unknown Scene.") 30 | 31 | assert self.env_idx is not None 32 | assert self.scene_params is not None 33 | zeros = np.array([0]) 34 | 
pyflex.set_scene(self.env_idx, self.scene_params, zeros, zeros, zeros, zeros, 0) 35 | 36 | def get_property_params(self): 37 | assert self.property_params is not None 38 | return self.property_params 39 | -------------------------------------------------------------------------------- /metrics/lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | # return torch.sum(torch.cat(res, 0), 0, True) 36 | return torch.mean(torch.sum(torch.cat(res, 1), dim=1)) # return average across batch instead 37 | -------------------------------------------------------------------------------- /models/dummy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DummyModel(nn.Module): 5 | def __init__(self, emb_dim, **kwargs): 6 | super().__init__() 7 | self.name = "dummy" 8 | self.latent_ndim = 1 9 | self.emb_dim = emb_dim 10 | self.fc = nn.Linear(emb_dim, 1) # not used 11 | 12 | def forward(self, x): 13 | b, dim = x.shape 14 | num_repeat = self.emb_dim // dim 15 | processed_x = torch.zeros([b, self.emb_dim]).to(x.device) 16 | x_repeated = x.repeat(1, num_repeat) 17 | processed_x[:, :num_repeat * dim] = x_repeated 18 | return x.unsqueeze(1) # return shape: (b, 1(or # patches), num_features) 19 | 20 | class DummyRepeatActionEncoder(nn.Module): 21 | def __init__(self, in_chans, emb_dim, **kwargs): 22 | super().__init__() 23 | self.name = "dummy_repeat" 24 | self.latent_ndim = 1 25 | self.in_chans = in_chans 26 | self.emb_dim = emb_dim 27 | self.fc = nn.Linear(in_chans, 1) # not used 28 | 29 | def forward(self, act): 30 | ''' 31 | (b, t, act_dim) --> (b, t, action_emb_dim) 32 | ''' 33 | b, t, act_dim = act.shape 34 | num_repeat = self.emb_dim // act_dim 35 | processed_act = torch.zeros([b, t, self.emb_dim]).to(act.device) 36 | act_repeated = act.repeat(1, 1, num_repeat) 37 | processed_act[:, :, :num_repeat * act_dim] = act_repeated 38 | return processed_act -------------------------------------------------------------------------------- /conf/plan.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - planner: gd 4 | - override hydra/launcher: submitit_slurm 5 | 6 | hydra: 7 | run: 8 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 9 | sweep: 10 | dir: 
plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 11 | subdir: ${hydra.job.num} 12 | launcher: 13 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 14 | nodes: 1 15 | tasks_per_node: 1 16 | cpus_per_task: 16 17 | mem_gb: 256 18 | gres: "gpu:h100:1" 19 | qos: "explore" 20 | timeout_min: 720 21 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)", 22 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)", 23 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",] 24 | 25 | # model to load for planning 26 | ckpt_base_path: ./ # put absolute path here. Checkpoints will be loaded from ${ckpt_base_path}/outputs 27 | model_name: null 28 | model_epoch: latest 29 | 30 | seed: 99 31 | n_evals: 10 32 | goal_source: 'dset' # 'random_state' or 'dset' or 'random_action' 33 | goal_H: 5 # specifies how far away the goal is if goal_source is 'dset' 34 | n_plot_samples: 10 35 | 36 | debug_dset_init: False 37 | 38 | objective: 39 | _target_: planning.objectives.create_objective_fn 40 | alpha: 1 41 | base: 2 # coeff base for weighting all frames. Only applies when mode == 'all' 42 | mode: last 43 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, UFACTORY Inc. 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | * Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software 15 | without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 21 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 22 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 23 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 26 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /conf/plan_wall.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: submitit_slurm 4 | 5 | hydra: 6 | run: 7 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 8 | sweep: 9 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 10 | subdir: ${hydra.job.num} 11 | launcher: 12 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 13 | nodes: 1 14 | tasks_per_node: 1 15 | cpus_per_task: 16 16 | mem_gb: 256 17 | gres: "gpu:h100:1" 18 | qos: "explore" 19 | timeout_min: 720 20 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)", 21 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)", 22 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",] 23 | 24 | # model to load for planning 25 | ckpt_base_path: ./checkpoints # put absolute path here. Checkpoints will be loaded from ${ckpt_base_path}/outputs 26 | model_name: null 27 | model_epoch: latest 28 | 29 | seed: 99 30 | n_evals: 50 31 | goal_source: 'random_state' 32 | goal_H: 5 33 | n_plot_samples: 10 34 | 35 | debug_dset_init: False 36 | 37 | objective: 38 | _target_: planning.objectives.create_objective_fn 39 | alpha: 1 40 | base: 2 41 | mode: last 42 | 43 | planner: 44 | _target_: planning.mpc.MPCPlanner 45 | max_iter: null 46 | n_taken_actions: 5 47 | sub_planner: 48 | target: planning.cem.CEMPlanner 49 | horizon: 5 50 | topk: 30 51 | num_samples: 300 52 | var_scale: 1 53 | opt_steps: 10 54 | eval_every: 1 55 | name: mpc_cem -------------------------------------------------------------------------------- /conf/plan_point_maze.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: submitit_slurm 4 | 5 | hydra: 6 | run: 7 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 8 | sweep: 9 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 10 | subdir: ${hydra.job.num} 11 | launcher: 12 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 13 | nodes: 1 14 | tasks_per_node: 1 15 | cpus_per_task: 16 16 | mem_gb: 256 17 | gres: "gpu:h100:1" 18 | qos: "explore" 19 | timeout_min: 720 20 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)", 21 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)", 22 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",] 23 | 24 | # model to load for planning 25 | ckpt_base_path: ./checkpoints # put absolute path here. 
Checkpoints will be loaded from ${ckpt_base_path}/outputs 26 | model_name: null 27 | model_epoch: latest 28 | 29 | seed: 99 30 | n_evals: 50 31 | goal_source: 'random_state' 32 | goal_H: 5 33 | n_plot_samples: 10 34 | 35 | debug_dset_init: False 36 | 37 | objective: 38 | _target_: planning.objectives.create_objective_fn 39 | alpha: 0 40 | base: 2 41 | mode: last 42 | 43 | planner: 44 | _target_: planning.mpc.MPCPlanner 45 | max_iter: null 46 | n_taken_actions: 5 47 | sub_planner: 48 | target: planning.cem.CEMPlanner 49 | horizon: 5 50 | topk: 30 51 | num_samples: 300 52 | var_scale: 1 53 | opt_steps: 10 54 | eval_every: 1 55 | name: mpc_cem -------------------------------------------------------------------------------- /conf/plan_pusht.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: submitit_slurm 4 | 5 | hydra: 6 | run: 7 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 8 | sweep: 9 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H} 10 | subdir: ${hydra.job.num} 11 | launcher: 12 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 13 | nodes: 1 14 | tasks_per_node: 1 15 | cpus_per_task: 16 16 | mem_gb: 256 17 | gres: "gpu:h100:1" 18 | qos: "explore" 19 | timeout_min: 720 20 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)", 21 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)", 22 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",] 23 | 24 | # model to load for planning 25 | ckpt_base_path: ./checkpoints # put absolute path here. 
Checkpoints will be loaded from ${ckpt_base_path}/outputs 26 | model_name: null 27 | model_epoch: latest 28 | 29 | seed: 99 30 | n_evals: 50 31 | goal_source: 'dset' 32 | goal_H: 5 33 | n_plot_samples: 10 34 | 35 | debug_dset_init: False 36 | 37 | objective: 38 | _target_: planning.objectives.create_objective_fn 39 | alpha: 1 40 | base: 2 41 | mode: last 42 | 43 | planner: 44 | _target_: planning.mpc.MPCPlanner 45 | max_iter: null 46 | n_taken_actions: 5 47 | sub_planner: 48 | target: planning.cem.CEMPlanner 49 | horizon: 5 50 | topk: 30 51 | num_samples: 300 52 | var_scale: 1 53 | opt_steps: 30 54 | eval_every: 1 55 | name: mpc_cem 56 | 57 | -------------------------------------------------------------------------------- /models/encoder/r3m/r3m_base.yaml: -------------------------------------------------------------------------------- 1 | name: r3m_base 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=1_gnu 8 | - bzip2=1.0.8=h7f98852_4 9 | - ca-certificates=2021.10.8=ha878542_0 10 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 11 | - libffi=3.4.2=h7f98852_5 12 | - libgcc-ng=11.2.0=h1d223b6_13 13 | - libgomp=11.2.0=h1d223b6_13 14 | - libnsl=2.0.0=h7f98852_0 15 | - libuuid=2.32.1=h7f98852_1000 16 | - libzlib=1.2.11=h36c2ea0_1013 17 | - ncurses=6.3=h9c3ff4c_0 18 | - openssl=3.0.0=h7f98852_2 19 | - pip=22.0.4=pyhd8ed1ab_0 20 | - python=3.9.10=hc74c709_2_cpython 21 | - python_abi=3.9=2_cp39 22 | - readline=8.1=h46c0cb4_0 23 | - setuptools=60.9.3=py39hf3d152e_0 24 | - sqlite=3.37.0=h9cd32fc_0 25 | - tk=8.6.12=h27826a3_0 26 | - tzdata=2021e=he74cb21_0 27 | - wheel=0.37.1=pyhd8ed1ab_0 28 | - xz=5.2.5=h516909a_1 29 | - zlib=1.2.11=h36c2ea0_1013 30 | - pip: 31 | - antlr4-python3-runtime==4.8 32 | - beautifulsoup4==4.10.0 33 | - certifi==2021.10.8 34 | - charset-normalizer==2.0.12 35 | - click==8.0.4 36 | - cycler==0.11.0 37 | - filelock==3.6.0 38 | - fonttools==4.30.0 39 | - gdown==4.4.0 40 | - huggingface-hub==0.4.0 41 | - hydra-core==1.1.1 42 | - idna==3.3 43 | - joblib==1.1.0 44 | - kiwisolver==1.3.2 45 | - matplotlib==3.5.1 46 | - numpy==1.22.3 47 | - omegaconf==2.1.1 48 | - packaging==21.3 49 | - pillow==9.0.1 50 | - pyparsing==3.0.7 51 | - pysocks==1.7.1 52 | - python-dateutil==2.8.2 53 | - pyyaml==6.0 54 | - regex==2022.3.2 55 | - requests==2.27.1 56 | - sacremoses==0.0.47 57 | - six==1.16.0 58 | - soupsieve==2.3.1 59 | - tokenizers==0.11.6 60 | - torch==1.7.1 61 | - torchvision==0.8.2 62 | - tqdm==4.63.0 63 | - transformers==4.17.0 64 | - typing-extensions==4.1.1 65 | - urllib3==1.26.8 66 | prefix: /private/home/surajn/.conda/envs/r3m_base 67 | -------------------------------------------------------------------------------- /preprocessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | class Preprocessor: 5 | def __init__(self, 6 | action_mean, 7 | action_std, 8 | state_mean, 9 | state_std, 10 | proprio_mean, 11 | proprio_std, 12 | transform, 13 | ): 14 | self.action_mean = action_mean 15 | self.action_std = action_std 16 | self.state_mean = state_mean 17 | self.state_std = state_std 18 | self.proprio_mean = proprio_mean 19 | self.proprio_std = proprio_std 20 | self.transform = transform 21 | 22 | def normalize_actions(self, actions): 23 | ''' 24 | actions: (b, t, action_dim) 25 | ''' 26 | return (actions - self.action_mean) / self.action_std 27 | 28 | def denormalize_actions(self, actions): 29 | ''' 30 | actions: (b, t, 
action_dim) 31 | ''' 32 | return actions * self.action_std + self.action_mean 33 | 34 | def normalize_proprios(self, proprio): 35 | ''' 36 | input shape (..., proprio_dim) 37 | ''' 38 | return (proprio - self.proprio_mean) / self.proprio_std 39 | 40 | def normalize_states(self, state): 41 | ''' 42 | input shape (..., state_dim) 43 | ''' 44 | return (state - self.state_mean) / self.state_std 45 | 46 | def preprocess_obs_visual(self, obs_visual): 47 | return rearrange(obs_visual, "b t h w c -> b t c h w") / 255.0 48 | 49 | def transform_obs_visual(self, obs_visual): 50 | transformed_obs_visual = torch.tensor(obs_visual) 51 | transformed_obs_visual = self.preprocess_obs_visual(transformed_obs_visual) 52 | transformed_obs_visual = self.transform(transformed_obs_visual) 53 | return transformed_obs_visual 54 | 55 | def transform_obs(self, obs): 56 | ''' 57 | np arrays to tensors 58 | ''' 59 | transformed_obs = {} 60 | transformed_obs['visual'] = self.transform_obs_visual(obs['visual']) 61 | transformed_obs['proprio'] = self.normalize_proprios(torch.tensor(obs['proprio'])) 62 | return transformed_obs 63 | -------------------------------------------------------------------------------- /models/proprio.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/ijepa/blob/main/src/models/vision_transformer.py 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | 6 | def get_1d_sincos_pos_embed(emb_dim, grid_size, cls_token=False): 7 | """ 8 | emb_dim: output dimension for each position 9 | grid_size: int of the grid length 10 | returns: 11 | pos_embed: [grid_size, emb_dim] (w/o cls_token) 12 | or [1+grid_size, emb_dim] (w/ cls_token) 13 | """ 14 | grid = np.arange(grid_size, dtype=float) 15 | pos_embed = get_1d_sincos_pos_embed_from_grid(emb_dim, grid) 16 | if cls_token: 17 | pos_embed = np.concatenate([np.zeros([1, emb_dim]), pos_embed], axis=0) 18 | return pos_embed 19 | 20 | def get_1d_sincos_pos_embed_from_grid(emb_dim, pos): 21 | """ 22 | emb_dim: output dimension for each position 23 | pos: a list of positions to be encoded: size (M,) 24 | returns: (M, D) 25 | """ 26 | assert emb_dim % 2 == 0 27 | omega = np.arange(emb_dim // 2, dtype=float) 28 | omega /= emb_dim / 2. 29 | omega = 1. 
/ 10000**omega # (D/2,) 30 | 31 | pos = pos.reshape(-1) # (M,) 32 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 33 | 34 | emb_sin = np.sin(out) # (M, D/2) 35 | emb_cos = np.cos(out) # (M, D/2) 36 | 37 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 38 | return emb 39 | 40 | class ProprioceptiveEmbedding(nn.Module): 41 | def __init__( 42 | self, 43 | num_frames=16, # horizon 44 | tubelet_size=1, 45 | in_chans=8, # action_dim 46 | emb_dim=384, # output_dim 47 | use_3d_pos=False # always False for now 48 | ): 49 | super().__init__() 50 | print(f'using 3d prop position {use_3d_pos=}') 51 | 52 | # Map input to predictor dimension 53 | self.num_frames = num_frames 54 | self.tubelet_size = tubelet_size 55 | self.in_chans = in_chans 56 | self.emb_dim = emb_dim 57 | 58 | self.patch_embed = nn.Conv1d( 59 | in_chans, 60 | emb_dim, 61 | kernel_size=tubelet_size, 62 | stride=tubelet_size) 63 | 64 | def forward(self, x): 65 | # x: proprioceptive vectors of shape [B T D] 66 | x = x.permute(0, 2, 1) 67 | x = self.patch_embed(x) 68 | x = x.permute(0, 2, 1) 69 | return x -------------------------------------------------------------------------------- /models/encoder/r3m/models/models_language.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | from numpy.core.numeric import full 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.activation import Sigmoid 10 | 11 | epsilon = 1e-8 12 | 13 | 14 | class LangEncoder(nn.Module): 15 | def __init__(self, device, finetune=False, scratch=False): 16 | super().__init__() 17 | from transformers import AutoTokenizer, AutoModel, AutoConfig 18 | 19 | self.device = device 20 | self.modelname = "distilbert-base-uncased" 21 | self.tokenizer = AutoTokenizer.from_pretrained(self.modelname) 22 | self.model = AutoModel.from_pretrained(self.modelname).to(self.device) 23 | self.lang_size = 768 24 | 25 | def forward(self, langs): 26 | try: 27 | langs = langs.tolist() 28 | except: 29 | pass 30 | 31 | with torch.no_grad(): 32 | encoded_input = self.tokenizer(langs, return_tensors="pt", padding=True) 33 | input_ids = encoded_input["input_ids"].to(self.device) 34 | attention_mask = encoded_input["attention_mask"].to(self.device) 35 | lang_embedding = self.model( 36 | input_ids, attention_mask=attention_mask 37 | ).last_hidden_state 38 | lang_embedding = lang_embedding.mean(1) 39 | return lang_embedding 40 | 41 | 42 | class LanguageReward(nn.Module): 43 | def __init__(self, ltype, im_dim, hidden_dim, lang_dim, simfunc=None): 44 | super().__init__() 45 | self.ltype = ltype 46 | self.sim = simfunc 47 | self.sigm = Sigmoid() 48 | self.pred = nn.Sequential( 49 | nn.Linear(im_dim * 2 + lang_dim, hidden_dim), 50 | nn.ReLU(inplace=True), 51 | nn.Linear(hidden_dim, hidden_dim), 52 | nn.ReLU(inplace=True), 53 | nn.Linear(hidden_dim, hidden_dim), 54 | nn.ReLU(inplace=True), 55 | nn.Linear(hidden_dim, hidden_dim), 56 | nn.ReLU(inplace=True), 57 | nn.Linear(hidden_dim, 1), 58 | ) 59 | 60 | def forward(self, e0, eg, le): 61 | info = {} 62 | return self.pred(torch.cat([e0, eg, le], -1)).squeeze(), info 63 | -------------------------------------------------------------------------------- /planning/objectives.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def create_objective_fn(alpha, base, mode="last"): 7 | """ 8 | Loss calculated on the last pred frame. 9 | Args: 10 | alpha: int 11 | base: int. only used for objective_fn_all 12 | Returns: 13 | loss: tensor (B, ) 14 | """ 15 | metric = nn.MSELoss(reduction="none") 16 | 17 | def objective_fn_last(z_obs_pred, z_obs_tgt): 18 | """ 19 | Args: 20 | z_obs_pred: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)} 21 | z_obs_tgt: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)} 22 | Returns: 23 | loss: tensor (B, ) 24 | """ 25 | loss_visual = metric(z_obs_pred["visual"][:, -1:], z_obs_tgt["visual"]).mean( 26 | dim=tuple(range(1, z_obs_pred["visual"].ndim)) 27 | ) 28 | loss_proprio = metric(z_obs_pred["proprio"][:, -1:], z_obs_tgt["proprio"]).mean( 29 | dim=tuple(range(1, z_obs_pred["proprio"].ndim)) 30 | ) 31 | loss = loss_visual + alpha * loss_proprio 32 | return loss 33 | 34 | def objective_fn_all(z_obs_pred, z_obs_tgt): 35 | """ 36 | Loss calculated on all pred frames. 37 | Args: 38 | z_obs_pred: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)} 39 | z_obs_tgt: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)} 40 | Returns: 41 | loss: tensor (B, ) 42 | """ 43 | coeffs = np.array( 44 | [base**i for i in range(z_obs_pred["visual"].shape[1])], dtype=np.float32 45 | ) 46 | coeffs = torch.tensor(coeffs / np.sum(coeffs)).to(z_obs_pred["visual"].device) 47 | loss_visual = metric(z_obs_pred["visual"], z_obs_tgt["visual"]).mean( 48 | dim=tuple(range(2, z_obs_pred["visual"].ndim)) 49 | ) 50 | loss_proprio = metric(z_obs_pred["proprio"], z_obs_tgt["proprio"]).mean( 51 | dim=tuple(range(2, z_obs_pred["proprio"].ndim)) 52 | ) 53 | loss_visual = (loss_visual * coeffs).mean(dim=1) 54 | loss_proprio = (loss_proprio * coeffs).mean(dim=1) 55 | loss = loss_visual + alpha * loss_proprio 56 | return loss 57 | 58 | if mode == "last": 59 | return objective_fn_last 60 | elif mode == "all": 61 | return objective_fn_all 62 | else: 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /conf/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - env: point_maze 4 | - encoder: dino 5 | - action_encoder: proprio 6 | - proprio_encoder: proprio 7 | - decoder: vqvae 8 | - predictor: vit 9 | - override hydra/launcher: submitit_slurm 10 | 11 | # base path to save model outputs. Checkpoints will be saved to ${ckpt_base_path}/outputs. 
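# A hypothetical usage note (the entry-point name and paths below are assumptions, not taken
# from this file): with Hydra, any key in this config can be overridden from the command line,
# e.g. `python train.py ckpt_base_path=/abs/path/to/storage env=pusht frameskip=5`, and nested
# keys use dotted paths such as `training.batch_size=16`. The groups listed under `defaults:`
# above (env, encoder, decoder, predictor, ...) are swapped the same way, e.g. `encoder=resnet`.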
12 | ckpt_base_path: ./ # put absolute path here 13 | 14 | hydra: 15 | run: 16 | dir: ${ckpt_base_path}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 17 | sweep: 18 | dir: ${ckpt_base_path}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 19 | subdir: ${hydra.job.num} 20 | launcher: 21 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 22 | nodes: 1 23 | tasks_per_node: 1 24 | cpus_per_task: 8 25 | mem_gb: 512 26 | gres: "gpu:h100:1" 27 | timeout_min: 2880 28 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)", 29 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)", 30 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",] 31 | 32 | training: 33 | seed: 0 34 | epochs: 100 35 | batch_size: 32 # should >= nodes * tasks_per_node 36 | save_every_x_epoch: 1 37 | reconstruct_every_x_batch: 500 38 | num_reconstruct_samples: 6 39 | encoder_lr: 1e-6 40 | decoder_lr: 3e-4 41 | predictor_lr: 5e-4 42 | action_encoder_lr: 5e-4 43 | 44 | img_size: 224 # should be a multiple of 224 45 | frameskip: 5 46 | concat_dim: 1 47 | 48 | normalize_action: True 49 | 50 | # action encoder 51 | action_emb_dim: 10 52 | num_action_repeat: 1 53 | 54 | # proprio encoder 55 | proprio_emb_dim: 10 56 | num_proprio_repeat: 1 57 | 58 | num_hist: 3 59 | num_pred: 1 # only supports 1 60 | has_predictor: True # set this to False for only training a decoder 61 | has_decoder: True # set this to False for only training a predictor 62 | 63 | model: 64 | _target_: models.visual_world_model.VWorldModel 65 | image_size: ${img_size} 66 | num_hist: ${num_hist} 67 | num_pred: ${num_pred} 68 | train_encoder: False 69 | train_predictor: True 70 | train_decoder: True 71 | 72 | debug: False 73 | 74 | # Planning params for planning eval jobs launched during training 75 | plan_settings: 76 | # plan_cfg_path: conf/plan.yaml # set to null for no planning evals 77 | plan_cfg_path: null 78 | planner: ['gd', 'cem'] 79 | goal_source: ['dset', 'random_state'] 80 | goal_H: [5] 81 | alpha: [0.1, 1] -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link6.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/base.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link1.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 
| 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link6_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link3.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/base_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link5.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link4.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link1_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 
11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link2.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link3_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link5_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link4_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/link2_com.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | /xarm 14 | 15 | gazebo_ros_control/DefaultRobotHWSim 16 | true 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /distributed_fn/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 
| from torch import distributed as dist 5 | from torch import multiprocessing as mp 6 | 7 | import distributed_fn as dist_fn 8 | 9 | 10 | def find_free_port(): 11 | import socket 12 | 13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | 15 | sock.bind(("", 0)) 16 | port = sock.getsockname()[1] 17 | sock.close() 18 | 19 | return port 20 | 21 | 22 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): 23 | world_size = n_machine * n_gpu_per_machine 24 | 25 | if world_size > 1: 26 | if "OMP_NUM_THREADS" not in os.environ: 27 | os.environ["OMP_NUM_THREADS"] = "1" 28 | 29 | if dist_url == "auto": 30 | if n_machine != 1: 31 | raise ValueError('dist_url="auto" not supported in multi-machine jobs') 32 | 33 | port = find_free_port() 34 | dist_url = f"tcp://127.0.0.1:{port}" 35 | 36 | if n_machine > 1 and dist_url.startswith("file://"): 37 | raise ValueError( 38 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" 39 | ) 40 | 41 | mp.spawn( 42 | distributed_worker, 43 | nprocs=n_gpu_per_machine, 44 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), 45 | daemon=False, 46 | ) 47 | 48 | else: 49 | fn(*args) 50 | 51 | 52 | def distributed_worker( 53 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args 54 | ): 55 | if not torch.cuda.is_available(): 56 | raise OSError("CUDA is not available. Please check your environments") 57 | 58 | global_rank = machine_rank * n_gpu_per_machine + local_rank 59 | 60 | try: 61 | dist.init_process_group( 62 | backend="NCCL", 63 | init_method=dist_url, 64 | world_size=world_size, 65 | rank=global_rank, 66 | ) 67 | 68 | except Exception: 69 | raise OSError("failed to initialize NCCL groups") 70 | 71 | dist_fn.synchronize() 72 | 73 | if n_gpu_per_machine > torch.cuda.device_count(): 74 | raise ValueError( 75 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" 76 | ) 77 | 78 | torch.cuda.set_device(local_rank) 79 | 80 | if dist_fn.LOCAL_PROCESS_GROUP is not None: 81 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") 82 | 83 | n_machine = world_size // n_gpu_per_machine 84 | 85 | for i in range(n_machine): 86 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) 87 | pg = dist.new_group(ranks_on_i) 88 | 89 | if i == machine_rank: 90 | dist_fn.distributed.LOCAL_PROCESS_GROUP = pg 91 | 92 | fn(*args) 93 | -------------------------------------------------------------------------------- /metrics/lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def 
__init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /metrics/image_metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | from math import exp 15 | from metrics.lpipsPyTorch import lpips 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | def mse(img1, img2): # batched input only 66 | return (((img1 - img2)) ** 2).reshape(img1.shape[0], -1).mean() 67 | 68 | def psnr(img1, img2): # batched input only 69 | mse = (((img1 - img2)) ** 2).reshape(img1.shape[0], -1).mean() 70 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 71 | 72 | def eval_images(img1, img2, item_only=True): 73 | metrics = {} 74 | metrics['l1'] = l1_loss(img1, img2) 75 | metrics['l2'] = l2_loss(img1, img2) 76 | metrics['ssim'] = ssim(img1, img2) 77 | metrics['mse'] = mse(img1, img2) 78 | metrics['psnr'] = psnr(img1, img2) 79 | metrics['lpips'] = lpips(img1, img2, net_type='vgg') 80 | return metrics -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import argparse 5 | import numpy as np 6 | from PIL import Image 7 | from omegaconf import OmegaConf 8 | from typing import Callable, Dict 9 | import psutil 10 | 11 | def get_ram_usage(): 12 | process = psutil.Process(os.getpid()) 13 | return process.memory_info().rss / (1024 * 1024 * 1024) # Memory usage in MB 14 | 15 | def get_available_ram(): 16 | mem = psutil.virtual_memory() 17 | return mem.available / (1024 * 1024 * 1024) # Available memory in MB 18 | 19 | def dict_to_namespace(cfg_dict): 20 | args = argparse.Namespace() 21 | for key in 
cfg_dict: 22 | setattr(args, key, cfg_dict[key]) 23 | return args 24 | 25 | def move_to_device(dct, device): 26 | for key, value in dct.items(): 27 | if isinstance(value, torch.Tensor): 28 | dct[key] = value.to(device) 29 | return dct 30 | 31 | def slice_trajdict_with_t(data_dict, start_idx=0, end_idx=None, step=1): 32 | if end_idx is None: 33 | end_idx = max(arr.shape[1] for arr in data_dict.values()) 34 | return {key: arr[:, start_idx:end_idx:step, ...] for key, arr in data_dict.items()} 35 | 36 | def concat_trajdict(dcts): 37 | full_dct = {} 38 | for k in dcts[0].keys(): 39 | if isinstance(dcts[0][k], np.ndarray): 40 | full_dct[k] = np.concatenate([dct[k] for dct in dcts], axis=1) 41 | elif isinstance(dcts[0][k], torch.Tensor): 42 | full_dct[k] = torch.cat([dct[k] for dct in dcts], dim=1) 43 | else: 44 | raise TypeError(f"Unsupported data type: {type(dcts[0][k])}") 45 | return full_dct 46 | 47 | def aggregate_dct(dcts): 48 | full_dct = {} 49 | for dct in dcts: 50 | for key, value in dct.items(): 51 | if key not in full_dct: 52 | full_dct[key] = [] 53 | full_dct[key].append(value) 54 | for key, value in full_dct.items(): 55 | if isinstance(value[0], torch.Tensor): 56 | full_dct[key] = torch.stack(value) 57 | else: 58 | full_dct[key] = np.stack(value) 59 | return full_dct 60 | 61 | def sample_tensors(tensors, n, indices=None): 62 | if indices is None: 63 | b = tensors[0].shape[0] 64 | indices = torch.randperm(b)[:n] 65 | indices = torch.tensor(indices) 66 | for i, tensor in enumerate(tensors): 67 | if tensor is not None: 68 | tensors[i] = tensor[indices] 69 | return tensors 70 | 71 | 72 | def cfg_to_dict(cfg): 73 | cfg_dict = OmegaConf.to_container(cfg) 74 | for key in cfg_dict: 75 | if isinstance(cfg_dict[key], list): 76 | cfg_dict[key] = ",".join(cfg_dict[key]) 77 | return cfg_dict 78 | 79 | def reduce_dict(f: Callable, d: Dict): 80 | return {k: reduce_dict(f, v) if isinstance(v, dict) else f(v) for k, v in d.items()} 81 | 82 | def seed(seed): 83 | random.seed(seed) 84 | torch.manual_seed(seed) 85 | np.random.seed(seed) 86 | if torch.cuda.is_available(): 87 | torch.cuda.manual_seed_all(seed) 88 | 89 | 90 | def pil_loader(path): 91 | with open(path, "rb") as f: 92 | with Image.open(f) as img: 93 | return img.convert("RGB") -------------------------------------------------------------------------------- /env/deformable_env/src/sim/data_gen/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | 5 | def store_data(filename, data, action): 6 | """ 7 | action: (action_dim,) 8 | imgs_list: (T, num_cameras, H, W, 5) 9 | particle_pos_list: (T, N, 3) 10 | eef_states_list: (T, 14) 11 | """ 12 | # load data 13 | imgs_list, particle_pos_list, eef_states_list = data 14 | imgs_list_np, particle_pos_list_np, eef_states_list_np = ( 15 | np.array(imgs_list), 16 | np.array(particle_pos_list), 17 | np.array(eef_states_list), 18 | ) 19 | 20 | # stat 21 | T, n_cam = imgs_list_np.shape[:2] 22 | n_particles = particle_pos_list_np.shape[1] 23 | 24 | # process images 25 | color_imgs, depth_imgs = process_imgs(imgs_list_np) 26 | 27 | # init episode data 28 | episode = { 29 | "info": {"n_cams": n_cam, "timestamp": T, "n_particles": n_particles}, 30 | "action": action, 31 | "positions": particle_pos_list_np, 32 | "eef_states": eef_states_list_np, 33 | "observations": {"color": color_imgs, "depth": depth_imgs}, 34 | } 35 | 36 | # save to h5py 37 | save_data(filename, episode) 38 | 39 | 40 | def process_imgs(imgs_list): 41 | T, n_cam, 
H, W, _ = imgs_list.shape 42 | color_imgs = {} 43 | depth_imgs = {} 44 | 45 | for cam_idx in range(n_cam): 46 | img = imgs_list[:, cam_idx] # (T, H, W, 5) 47 | color_imgs[f"cam_{cam_idx}"] = img[:, :, :, :3][..., ::-1] # (T, H, W, 3) 48 | depth_imgs[f"cam_{cam_idx}"] = (img[:, :, :, -1] * 1000).astype( 49 | np.uint16 50 | ) # (T, H, W) 51 | 52 | assert color_imgs["cam_0"].shape == (T, H, W, 3) 53 | assert depth_imgs["cam_0"].shape == (T, H, W) 54 | 55 | return color_imgs, depth_imgs 56 | 57 | 58 | def save_data(filename, save_data): 59 | with h5py.File(filename, "w") as f: 60 | for key, value in save_data.items(): 61 | if key in ["observations"]: 62 | for sub_key, sub_value in value.items(): 63 | for subsub_key, subsub_value in sub_value.items(): 64 | f.create_dataset( 65 | f"{key}/{sub_key}/{subsub_key}", data=subsub_value 66 | ) 67 | elif key in ["info"]: 68 | for sub_key, sub_value in value.items(): 69 | f.create_dataset(f"{key}/{sub_key}", data=sub_value) 70 | else: 71 | f.create_dataset(key, data=value) 72 | 73 | 74 | def load_data(filename): 75 | data = {} 76 | with h5py.File(filename, "r") as f: 77 | for key in f.keys(): 78 | if key in ["observations"]: 79 | data[key] = {} 80 | for sub_key in f[key].keys(): 81 | data[key][sub_key] = {} 82 | for subsub_key in f[key][sub_key].keys(): 83 | data[key][sub_key][subsub_key] = f[key][sub_key][subsub_key][()] 84 | elif key in ["info"]: 85 | data[key] = {} 86 | for sub_key in f[key].keys(): 87 | data[key][sub_key] = f[key][sub_key][()] 88 | else: 89 | data[key] = f[key][()] 90 | return data 91 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/sim_env/cameras.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyflex 3 | 4 | class Camera(): 5 | def __init__(self, screenWidth, screenHeight): 6 | 7 | self.screenWidth = screenWidth 8 | self.screenHeight = screenHeight 9 | 10 | self.num_cameras = 4 11 | self.camera_view = None 12 | 13 | self.cam_dis = 6. 14 | self.cam_height = 10. 15 | self.cam_deg = np.array([0., 90., 180., 270.]) + 45. 
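    # Note on the rig defined above: the four multi-view cameras sit on the corners of a square of
    # half-width cam_dis around the scene origin at height cam_height, with yaw angles
    # cam_deg = [45, 135, 225, 315] degrees and a fixed 45-degree downward pitch
    # (see set_init_camera / init_multiview_cameras below). camera_view 0 is instead a top-down
    # view from cam_height + 10, looking straight down.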
16 | 17 | def set_init_camera(self, camera_view): 18 | self.camera_view = camera_view 19 | 20 | if self.camera_view == 0: # top view 21 | self.camPos = np.array([0., self.cam_height+10., 0.]) 22 | self.camAngle = np.array([0., -np.deg2rad(90.), 0.]) 23 | elif self.camera_view == 1: 24 | self.camPos = np.array([self.cam_dis, self.cam_height, self.cam_dis]) 25 | self.camAngle = np.array([np.deg2rad(self.cam_deg[0]), -np.deg2rad(45.), 0.]) 26 | elif self.camera_view == 2: 27 | self.camPos = np.array([self.cam_dis, self.cam_height, -self.cam_dis]) 28 | self.camAngle = np.array([np.deg2rad(self.cam_deg[1]), -np.deg2rad(45.), 0.]) 29 | elif self.camera_view == 3: 30 | self.camPos = np.array([-self.cam_dis, self.cam_height, -self.cam_dis]) 31 | self.camAngle = np.array([np.deg2rad(self.cam_deg[2]), -np.deg2rad(45.), 0.]) 32 | elif self.camera_view == 4: 33 | self.camPos = np.array([-self.cam_dis, self.cam_height, self.cam_dis]) 34 | self.camAngle = np.array([np.deg2rad(self.cam_deg[3]), -np.deg2rad(45.), 0.]) 35 | else: 36 | raise ValueError('camera_view not defined') 37 | 38 | # set camera 39 | pyflex.set_camPos(self.camPos) 40 | pyflex.set_camAngle(self.camAngle) 41 | 42 | def init_multiview_cameras(self): 43 | self.camPos_list, self.camAngle_list = [], [] 44 | self.cam_x_list = np.array([self.cam_dis, self.cam_dis, -self.cam_dis, -self.cam_dis]) 45 | self.cam_z_list = np.array([self.cam_dis, -self.cam_dis, -self.cam_dis, self.cam_dis]) 46 | 47 | self.rad_list = np.deg2rad(self.cam_deg) 48 | for i in range(self.num_cameras): 49 | self.camPos_list.append(np.array([self.cam_x_list[i], self.cam_height, self.cam_z_list[i]])) 50 | self.camAngle_list.append(np.array([self.rad_list[i], -np.deg2rad(45.), 0.])) 51 | 52 | self.cam_intrinsic_params = np.zeros([len(self.camPos_list), 4]) # [fx, fy, cx, cy] 53 | self.cam_extrinsic_matrix = np.zeros([len(self.camPos_list), 4, 4]) # [R, t] 54 | 55 | return self.camPos_list, self.camAngle_list, self.cam_intrinsic_params, self.cam_extrinsic_matrix 56 | 57 | def get_cam_params(self): 58 | # camera intrinsic parameters 59 | projMat = pyflex.get_projMatrix().reshape(4, 4).T 60 | cx = self.screenWidth / 2.0 61 | cy = self.screenHeight / 2.0 62 | fx = projMat[0, 0] * cx 63 | fy = projMat[1, 1] * cy 64 | camera_intrinsic_params = np.array([fx, fy, cx, cy]) 65 | 66 | # camera extrinsic parameters 67 | cam_extrinsic_matrix = pyflex.get_viewMatrix().reshape(4, 4).T 68 | 69 | return camera_intrinsic_params, cam_extrinsic_matrix 70 | -------------------------------------------------------------------------------- /env/serial_vector_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import aggregate_dct 3 | 4 | # for debugging environments, and envs that are not compatible with SubprocVectorEnv 5 | class SerialVectorEnv: 6 | """ 7 | obs, reward, done, info 8 | obs: dict, each key has shape (num_env, ...) 
9 | reward: (num_env, ) 10 | done: (num_env, ) 11 | info: tuple of length num_env, each element is a dict 12 | """ 13 | 14 | def __init__(self, envs): 15 | self.envs = envs 16 | self.num_envs = len(envs) 17 | 18 | def sample_random_init_goal_states(self, seed): 19 | init_state, goal_state = zip(*(self.envs[i].sample_random_init_goal_states(seed[i]) for i in range(self.num_envs))) 20 | return np.stack(init_state), np.stack(goal_state) 21 | 22 | def update_env(self, env_info): 23 | [self.envs[i].update_env(env_info[i]) for i in range(self.num_envs)] 24 | 25 | def eval_state(self, goal_state, cur_state): 26 | eval_result = [] 27 | for i in range(self.num_envs): 28 | env = self.envs[i] 29 | eval_result.append(env.eval_state(goal_state[i], cur_state[i])) 30 | eval_result = aggregate_dct(eval_result) 31 | return eval_result 32 | 33 | def prepare(self, seed, init_state): 34 | """ 35 | Reset with controlled init_state 36 | obs: (num_envs, H, W, C) 37 | state: tuple of num_envs dicts 38 | """ 39 | obs = [] 40 | state = [] 41 | for i in range(self.num_envs): 42 | env = self.envs[i] 43 | cur_seed = seed[i] 44 | cur_init_state = init_state[i] 45 | o, s = env.prepare(cur_seed, cur_init_state) 46 | obs.append(o) 47 | state.append(s) 48 | obs = aggregate_dct(obs) 49 | state = np.stack(state) 50 | return obs, state 51 | 52 | def step_multiple(self, actions): 53 | """ 54 | actions: (num_envs, T, action_dim) 55 | obses: (num_envs, T, H, W, C) 56 | infos: tuple of length num_envs, each element is a dict 57 | """ 58 | obses = [] 59 | rewards = [] 60 | dones = [] 61 | infos = [] 62 | for i in range(self.num_envs): 63 | env = self.envs[i] 64 | cur_actions = actions[i] 65 | obs, reward, done, info = env.step_multiple(cur_actions) 66 | obses.append(obs) 67 | rewards.append(reward) 68 | dones.append(done) 69 | infos.append(info) 70 | obses = np.stack(obses) 71 | rewards = np.stack(rewards) 72 | dones = np.stack(dones) 73 | infos = tuple(infos) 74 | return obses, rewards, dones, infos 75 | 76 | def rollout(self, seed, init_state, actions): 77 | """ 78 | only returns np arrays of observations and states 79 | obses: (num_envs, T, H, W, C) 80 | states: (num_envs, T, D) 81 | proprios: (num_envs, T, D_p) 82 | """ 83 | obses = [] 84 | states = [] 85 | for i in range(self.num_envs): 86 | env = self.envs[i] 87 | cur_seed = seed[i] 88 | cur_init_state = init_state[i] 89 | cur_actions = actions[i] 90 | obs, state = env.rollout(cur_seed, cur_init_state, cur_actions) 91 | obses.append(obs) 92 | states.append(state) 93 | obses = aggregate_dct(obses) 94 | states = np.stack(states) 95 | return obses, states 96 | -------------------------------------------------------------------------------- /env/pointmaze/waypoint_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from . 
import q_iteration 3 | from .gridcraft import grid_env 4 | from .gridcraft import grid_spec 5 | 6 | 7 | ZEROS = np.zeros((2,), dtype=np.float32) 8 | ONES = np.zeros((2,), dtype=np.float32) 9 | 10 | 11 | class WaypointController(object): 12 | def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0): 13 | self.maze_str = maze_str 14 | self._target = -1000 * ONES 15 | 16 | self.p_gain = p_gain 17 | self.d_gain = d_gain 18 | self.solve_thresh = solve_thresh 19 | self.vel_thresh = 0.1 20 | 21 | self._waypoint_idx = 0 22 | self._waypoints = [] 23 | self._waypoint_prev_loc = ZEROS 24 | 25 | self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str)) 26 | 27 | def current_waypoint(self): 28 | return self._waypoints[self._waypoint_idx] 29 | 30 | def get_action(self, location, velocity, target): 31 | if np.linalg.norm(self._target - np.array(self.gridify_state(target))) > 1e-3: 32 | #print('New target!', target, 'old:', self._target) 33 | self._new_target(location, target) 34 | 35 | dist = np.linalg.norm(location - self._target) 36 | vel = self._waypoint_prev_loc - location 37 | vel_norm = np.linalg.norm(vel) 38 | task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh) 39 | 40 | if task_not_solved: 41 | next_wpnt = self._waypoints[self._waypoint_idx] 42 | else: 43 | next_wpnt = self._target 44 | 45 | # Compute control 46 | prop = next_wpnt - location 47 | action = self.p_gain * prop + self.d_gain * velocity 48 | 49 | dist_next_wpnt = np.linalg.norm(location - next_wpnt) 50 | if task_not_solved and (dist_next_wpnt < self.solve_thresh) and (vel_norm= 0 ] = 1.0 28 | pol_probs[pol_probs < 0 ] = 0.0 29 | else: 30 | pol_probs = np.exp((1.0/ent_wt)*adv_rew) 31 | pol_probs /= np.sum(pol_probs, axis=1, keepdims=True) 32 | assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs) 33 | return pol_probs 34 | 35 | 36 | def softq_iteration(env, transition_matrix=None, reward_matrix=None, num_itrs=50, discount=0.99, ent_wt=0.1, warmstart_q=None, policy=None): 37 | """ 38 | Perform tabular soft Q-iteration 39 | """ 40 | dim_obs = env.num_states 41 | dim_act = env.num_actions 42 | if reward_matrix is None: 43 | reward_matrix = env.reward_matrix() 44 | reward_matrix = reward_matrix[:,:,0] 45 | 46 | if warmstart_q is None: 47 | q_fn = np.zeros((dim_obs, dim_act)) 48 | else: 49 | q_fn = warmstart_q 50 | 51 | if transition_matrix is None: 52 | t_matrix = env.transition_matrix() 53 | else: 54 | t_matrix = transition_matrix 55 | 56 | for k in range(num_itrs): 57 | if policy is None: 58 | v_fn = logsumexp(q_fn, alpha=ent_wt) 59 | else: 60 | v_fn = np.sum((q_fn - ent_wt*np.log(policy))*policy, axis=1) 61 | new_q = reward_matrix + discount*t_matrix.dot(v_fn) 62 | q_fn = new_q 63 | return q_fn 64 | 65 | 66 | def q_iteration(env, **kwargs): 67 | return softq_iteration(env, ent_wt=0.0, **kwargs) 68 | 69 | 70 | def compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0): 71 | pol_probs = get_policy(q_fn, ent_wt=ent_wt) 72 | 73 | dim_obs = env.num_states 74 | dim_act = env.num_actions 75 | state_visitation = np.zeros((dim_obs, 1)) 76 | for (state, prob) in env.initial_state_distribution.items(): 77 | state_visitation[state] = prob 78 | t_matrix = env.transition_matrix() # S x A x S 79 | sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit)) 80 | 81 | for i in range(env_time_limit): 82 | sa_visit = state_visitation * pol_probs 83 | # sa_visit_t[:, :, i] = (discount ** i) * sa_visit 84 | sa_visit_t[:, :, i] = sa_visit 85 | # sum-out (SA)S 86 | 
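        # The contraction below marginalises the (state, action) visitation through the
        # transition tensor: new_state_visitation[k] = sum_{i,j} sa_visit[i, j] * t_matrix[i, j, k],
        # i.e. the probability mass arriving in each next state k after one step.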
new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) 87 | state_visitation = np.expand_dims(new_state_visitation, axis=1) 88 | return np.sum(sa_visit_t, axis=2) / float(env_time_limit) 89 | 90 | 91 | def compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0): 92 | pol_probs = get_policy(q_fn, ent_wt=ent_wt) 93 | 94 | dim_obs = env.num_states 95 | dim_act = env.num_actions 96 | state_visitation = np.zeros((dim_obs, 1)) 97 | for (state, prob) in env.initial_state_distribution.items(): 98 | state_visitation[state] = prob 99 | t_matrix = env.transition_matrix() # S x A x S 100 | sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit)) 101 | 102 | for i in range(env_time_limit): 103 | sa_visit = state_visitation * pol_probs 104 | sa_visit_t[:, :, i] = (discount ** i) * sa_visit 105 | # sa_visit_t[:, :, i] = sa_visit 106 | # sum-out (SA)S 107 | new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) 108 | state_visitation = np.expand_dims(new_state_visitation, axis=1) 109 | return np.sum(sa_visit_t, axis=2) #/ float(env_time_limit) 110 | -------------------------------------------------------------------------------- /models/decoder/transposed_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def initialize_weights(m): 7 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 8 | nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu") 9 | nn.init.constant_(m.bias.data, 0) 10 | elif isinstance(m, nn.Linear): 11 | nn.init.kaiming_uniform_(m.weight.data) 12 | nn.init.constant_(m.bias.data, 0) 13 | 14 | def horizontal_forward(network, x, input_shape=(-1,), output_shape=(-1,)): 15 | batch_with_horizon_shape = x.shape[: -len(input_shape)] 16 | if not batch_with_horizon_shape: 17 | batch_with_horizon_shape = (1,) 18 | x = x.reshape(-1, *input_shape) 19 | x = network(x) 20 | x = x.reshape(*batch_with_horizon_shape, *output_shape) 21 | return x 22 | 23 | def create_normal_dist( 24 | x, 25 | std=None, 26 | mean_scale=1, 27 | init_std=0, 28 | min_std=0.1, 29 | activation=None, 30 | event_shape=None, 31 | ): 32 | if std == None: 33 | mean, std = torch.chunk(x, 2, -1) 34 | mean = mean / mean_scale 35 | if activation: 36 | mean = activation(mean) 37 | mean = mean_scale * mean 38 | std = F.softplus(std + init_std) + min_std 39 | else: 40 | mean = x 41 | dist = torch.distributions.Normal(mean, std) 42 | if event_shape: 43 | dist = torch.distributions.Independent(dist, event_shape) 44 | return dist 45 | 46 | 47 | class TransposedConvDecoder(nn.Module): 48 | def __init__(self, observation_shape=(3, 224, 224), emb_dim=512, activation=nn.ReLU, depth=32, kernel_size=5, stride=3): 49 | super().__init__() 50 | 51 | activation = activation() 52 | self.observation_shape = observation_shape 53 | self.depth = depth 54 | self.kernel_size = kernel_size 55 | self.stride = stride 56 | self.emb_dim = emb_dim 57 | 58 | self.network = nn.Sequential( 59 | nn.Linear( 60 | emb_dim, self.depth * 32 61 | ), 62 | nn.Unflatten(1, (self.depth * 32, 1)), 63 | nn.Unflatten(2, (1,1)), 64 | nn.ConvTranspose2d( 65 | self.depth * 32, 66 | self.depth * 8, 67 | self.kernel_size, 68 | self.stride, 69 | padding=1 70 | ), 71 | activation, 72 | nn.ConvTranspose2d( 73 | self.depth * 8, 74 | self.depth * 4, 75 | self.kernel_size, 76 | self.stride, 77 | padding=1 78 | ), 79 | activation, 80 | nn.ConvTranspose2d( 81 | self.depth * 4, 82 | self.depth * 2, 83 | 
self.kernel_size, 84 | self.stride, 85 | padding=1 86 | ), 87 | activation, 88 | nn.ConvTranspose2d( 89 | self.depth * 2, 90 | self.depth * 1, 91 | self.kernel_size, 92 | self.stride, 93 | padding=1 94 | ), 95 | activation, 96 | nn.ConvTranspose2d( 97 | self.depth * 1, 98 | self.observation_shape[0], 99 | self.kernel_size, 100 | self.stride, 101 | padding=1 102 | ), 103 | nn.Upsample(size=(observation_shape[1], observation_shape[2]), mode='bilinear', align_corners=False) 104 | ) 105 | self.network.apply(initialize_weights) 106 | 107 | def forward(self, posterior): 108 | x = horizontal_forward( 109 | self.network, posterior, input_shape=[self.emb_dim],output_shape=self.observation_shape 110 | ) 111 | dist = create_normal_dist(x, std=1, event_shape=len(self.observation_shape)) 112 | img = dist.mean.squeeze(2) 113 | img = einops.rearrange(img, "b t c h w -> (b t) c h w") 114 | return img, torch.zeros(1).to(posterior.device) # dummy placeholder -------------------------------------------------------------------------------- /env/pusht/pusht_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import gym 4 | from env.pusht.pusht_env import PushTEnv 5 | from utils import aggregate_dct 6 | 7 | class PushTWrapper(PushTEnv): 8 | def __init__( 9 | self, 10 | with_velocity=True, 11 | with_target=True, 12 | ): 13 | super().__init__( 14 | with_velocity=with_velocity, 15 | with_target=with_target, 16 | ) 17 | self.action_dim = self.action_space.shape[0] 18 | 19 | def sample_random_init_goal_states(self, seed): 20 | """ 21 | Return two random states: one as the initial state and one as the goal state. 22 | """ 23 | rs = np.random.RandomState(seed) 24 | 25 | def generate_state(): 26 | if self.with_velocity: 27 | return np.array( 28 | [ 29 | rs.randint(50, 450), 30 | rs.randint(50, 450), 31 | rs.randint(100, 400), 32 | rs.randint(100, 400), 33 | rs.randn() * 2 * np.pi - np.pi, 34 | 0, 35 | 0, # agent velocities default 0 36 | ] 37 | ) 38 | else: 39 | return np.array( 40 | [ 41 | rs.randint(50, 450), 42 | rs.randint(50, 450), 43 | rs.randint(100, 400), 44 | rs.randint(100, 400), 45 | rs.randn() * 2 * np.pi - np.pi, 46 | ] 47 | ) 48 | 49 | init_state = generate_state() 50 | goal_state = generate_state() 51 | 52 | return init_state, goal_state 53 | 54 | def update_env(self, env_info): 55 | self.shape = env_info['shape'] 56 | 57 | def eval_state(self, goal_state, cur_state): 58 | """ 59 | Return True if the goal is reached 60 | [agent_x, agent_y, T_x, T_y, angle, agent_vx, agent_vy] 61 | """ 62 | # if position difference is < 20, and angle difference < np.pi/9, then success 63 | pos_diff = np.linalg.norm(goal_state[:4] - cur_state[:4]) 64 | angle_diff = np.abs(goal_state[4] - cur_state[4]) 65 | angle_diff = np.minimum(angle_diff, 2 * np.pi - angle_diff) 66 | success = pos_diff < 20 and angle_diff < np.pi / 9 67 | state_dist = np.linalg.norm(goal_state - cur_state) 68 | return { 69 | 'success': success, 70 | 'state_dist': state_dist, 71 | } 72 | 73 | def prepare(self, seed, init_state): 74 | """ 75 | Reset with controlled init_state 76 | obs: (H W C) 77 | state: (state_dim) 78 | """ 79 | self.seed(seed) 80 | self.reset_to_state = init_state 81 | obs, state = self.reset() 82 | return obs, state 83 | 84 | def step_multiple(self, actions): 85 | """ 86 | infos: dict, each key has shape (T, ...) 
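        actions: (T, action_dim)
        obses: dict, each key has shape (T, ...)
        rewards: (T, )
        dones: (T, )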
87 | """ 88 | obses = [] 89 | rewards = [] 90 | dones = [] 91 | infos = [] 92 | for action in actions: 93 | o, r, d, info = self.step(action) 94 | obses.append(o) 95 | rewards.append(r) 96 | dones.append(d) 97 | infos.append(info) 98 | obses = aggregate_dct(obses) 99 | rewards = np.stack(rewards) 100 | dones = np.stack(dones) 101 | infos = aggregate_dct(infos) 102 | return obses, rewards, dones, infos 103 | 104 | def rollout(self, seed, init_state, actions): 105 | """ 106 | only returns np arrays of observations and states 107 | seed: int 108 | init_state: (state_dim, ) 109 | actions: (T, action_dim) 110 | obses: dict (T, H, W, C) 111 | states: (T, D) 112 | """ 113 | obs, state = self.prepare(seed, init_state) 114 | obses, rewards, dones, infos = self.step_multiple(actions) 115 | for k in obses.keys(): 116 | obses[k] = np.vstack([np.expand_dims(obs[k], 0), obses[k]]) 117 | states = np.vstack([np.expand_dims(state, 0), infos["state"]]) 118 | states = np.stack(states) 119 | return obses, states 120 | -------------------------------------------------------------------------------- /env/pointmaze/gridcraft/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .grid_env import REWARD, GridEnv 3 | from .wrappers import ObsWrapper 4 | from gym.spaces import Box 5 | 6 | 7 | class GridObsWrapper(ObsWrapper): 8 | def __init__(self, env): 9 | super(GridObsWrapper, self).__init__(env) 10 | 11 | def render(self): 12 | self.env.render() 13 | 14 | 15 | 16 | class EyesWrapper(ObsWrapper): 17 | def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8): 18 | super(EyesWrapper, self).__init__(env) 19 | self.types = types 20 | self.range = range 21 | self.angle_thresh = angle_thresh 22 | 23 | eyes_low = np.ones(5*len(types)) 24 | eyes_high = np.ones(5*len(types)) 25 | low = np.r_[env.observation_space.low, eyes_low] 26 | high = np.r_[env.observation_space.high, eyes_high] 27 | self.__observation_space = Box(low, high) 28 | 29 | def wrap_obs(self, obs, info=None): 30 | gs = self.env.gs # grid spec 31 | xy = gs.idx_to_xy(self.env.obs_to_state(obs)) 32 | #xy = np.array([x, y]) 33 | 34 | extra_obs = [] 35 | for tile_type in self.types: 36 | idxs = gs.find(tile_type).astype(np.float32) # N x 2 37 | # gather all idxs that are close 38 | diffs = idxs-np.expand_dims(xy, axis=0) 39 | dists = np.linalg.norm(diffs, axis=1) 40 | valid_idxs = np.where(dists <= self.range)[0] 41 | if len(valid_idxs) == 0: 42 | eye_data = np.array([0,0,0,0,0], dtype=np.float32) 43 | else: 44 | diffs = diffs[valid_idxs, :] 45 | dists = dists[valid_idxs]+1e-6 46 | cosines = diffs[:,0]/dists 47 | cosines = np.r_[cosines, 0] 48 | sines = diffs[:,1]/dists 49 | sines = np.r_[sines, 0] 50 | on_target = 0.0 51 | if np.any(dists<=1.0): 52 | on_target = 1.0 53 | eye_data = np.abs(np.array([on_target, np.max(cosines), np.min(cosines), np.max(sines), np.min(sines)])) 54 | eye_data[np.where(eye_data<=self.angle_thresh)] = 0 55 | extra_obs.append(eye_data) 56 | extra_obs = np.concatenate(extra_obs) 57 | obs = np.r_[obs, extra_obs] 58 | return obs 59 | 60 | def unwrap_obs(self, obs, info=None): 61 | if len(obs.shape) == 1: 62 | return obs[:-5*len(self.types)] 63 | else: 64 | return obs[:,:-5*len(self.types)] 65 | 66 | @property 67 | def observation_space(self): 68 | return self.__observation_space 69 | 70 | 71 | """ 72 | class CoordinateWiseWrapper(GridObsWrapper): 73 | def __init__(self, env): 74 | assert isinstance(env, GridEnv) 75 | super(CoordinateWiseWrapper, 
self).__init__(env) 76 | self.gs = env.gs 77 | self.dO = self.gs.width+self.gs.height 78 | 79 | self.__observation_space = Box(0, 1, self.dO) 80 | 81 | def wrap_obs(self, obs, info=None): 82 | state = one_hot_to_flat(obs) 83 | xy = self.gs.idx_to_xy(state) 84 | x = flat_to_one_hot(xy[0], self.gs.width) 85 | y = flat_to_one_hot(xy[1], self.gs.height) 86 | obs = np.r_[x, y] 87 | return obs 88 | 89 | def unwrap_obs(self, obs, info=None): 90 | 91 | if len(obs.shape) == 1: 92 | x = obs[:self.gs.width] 93 | y = obs[self.gs.width:] 94 | x = one_hot_to_flat(x) 95 | y = one_hot_to_flat(y) 96 | state = self.gs.xy_to_idx(np.c_[x,y]) 97 | return flat_to_one_hot(state, self.dO) 98 | else: 99 | raise NotImplementedError() 100 | """ 101 | 102 | 103 | class RandomObsWrapper(GridObsWrapper): 104 | def __init__(self, env, dO): 105 | assert isinstance(env, GridEnv) 106 | super(RandomObsWrapper, self).__init__(env) 107 | self.gs = env.gs 108 | self.dO = dO 109 | self.obs_matrix = np.random.randn(self.dO, len(self.gs)) 110 | self.__observation_space = Box(np.min(self.obs_matrix), np.max(self.obs_matrix), 111 | shape=(self.dO,), dtype=np.float32) 112 | 113 | def wrap_obs(self, obs, info=None): 114 | return np.inner(self.obs_matrix, obs) 115 | 116 | def unwrap_obs(self, obs, info=None): 117 | raise NotImplementedError() 118 | 119 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link6_vhacd.obj: -------------------------------------------------------------------------------- 1 | o convex_0 2 | v 0.033187 0.018503 -0.027247 3 | v -0.037691 -0.004112 -0.027247 4 | v -0.036680 -0.006625 -0.028254 5 | v 0.013576 -0.034758 0.000399 6 | v -0.012050 0.035092 0.000399 7 | v 0.001516 -0.051863 -0.020710 8 | v -0.035174 -0.014657 -0.000608 9 | v 0.037202 0.000916 0.000399 10 | v -0.009541 0.036598 -0.027749 11 | v 0.032685 -0.018178 -0.028254 12 | v 0.016093 0.034593 -0.000608 13 | v -0.032658 0.019510 -0.000608 14 | v -0.005016 -0.051863 -0.008652 15 | v -0.024620 -0.028233 -0.028254 16 | v 0.028153 -0.025719 -0.000608 17 | v -0.028642 0.025037 -0.027247 18 | v 0.013074 0.035092 -0.028254 19 | v 0.006542 -0.051863 -0.011667 20 | v -0.029646 -0.022706 0.000399 21 | v 0.028662 0.025046 -0.000608 22 | v 0.013576 -0.034758 -0.028254 23 | v 0.037711 -0.000091 -0.027749 24 | v -0.037691 0.004927 -0.000608 25 | v -0.034673 0.013984 -0.028254 26 | v 0.005029 0.037614 -0.000608 27 | v 0.036700 -0.009629 -0.000608 28 | v -0.031654 -0.021191 -0.027247 29 | v -0.007024 -0.051855 -0.014175 30 | v -0.022103 0.031072 -0.000608 31 | v -0.013054 -0.034767 0.000399 32 | v 0.022123 0.031072 -0.027247 33 | v 0.036700 0.009963 -0.000608 34 | v 0.028153 -0.025719 -0.027247 35 | v 0.001014 -0.051863 -0.007143 36 | v -0.014058 -0.034767 -0.028254 37 | v 0.005029 0.037614 -0.027247 38 | v -0.034171 0.014982 0.000399 39 | v 0.021119 0.030572 0.000399 40 | v -0.016073 0.034593 -0.027247 41 | v -0.028133 -0.025719 -0.000608 42 | v -0.008028 0.037098 -0.000608 43 | v 0.034692 -0.015664 -0.027247 44 | v 0.035194 0.012477 -0.028254 45 | v -0.005016 -0.051863 -0.019201 46 | v -0.036680 0.009955 -0.027247 47 | v 0.006542 -0.051863 -0.017193 48 | v -0.037691 -0.004112 -0.000608 49 | v 0.031674 -0.021191 -0.000608 50 | v -0.035174 -0.014657 -0.027247 51 | v -0.023616 0.029066 -0.028254 52 | v 0.033187 0.018503 -0.000608 53 | v 0.023636 -0.028732 0.000399 54 | v 0.028662 0.025046 -0.027247 55 | v 0.036700 0.009963 -0.027247 56 | v -0.005016 
0.037614 -0.027247 57 | v 0.037711 0.004936 -0.000608 58 | v 0.037711 -0.004112 -0.027247 59 | v -0.032658 0.019510 -0.027247 60 | v -0.035676 0.012976 -0.000608 61 | v 0.022123 0.031072 -0.000608 62 | v -0.016073 0.034593 -0.000608 63 | v 0.020617 -0.031246 -0.028254 64 | v -0.028642 0.025037 -0.000608 65 | v 0.016093 0.034593 -0.027247 66 | f 36 11 64 67 | f 5 4 8 68 | f 3 10 14 69 | f 10 3 17 70 | f 13 6 18 71 | f 4 5 19 72 | f 14 10 21 73 | f 17 3 24 74 | f 3 14 27 75 | f 27 14 28 76 | f 4 19 30 77 | f 15 18 33 78 | f 13 18 34 79 | f 4 30 34 80 | f 30 13 34 81 | f 21 6 35 82 | f 14 21 35 83 | f 17 9 36 84 | f 25 11 36 85 | f 19 5 37 86 | f 5 29 37 87 | f 5 8 38 88 | f 25 5 38 89 | f 11 25 38 90 | f 19 7 40 91 | f 7 27 40 92 | f 27 28 40 93 | f 28 13 40 94 | f 13 30 40 95 | f 30 19 40 96 | f 5 25 41 97 | f 9 39 41 98 | f 33 10 42 99 | f 10 17 43 100 | f 22 10 43 101 | f 6 13 44 102 | f 13 28 44 103 | f 28 14 44 104 | f 35 6 44 105 | f 14 35 44 106 | f 3 2 45 107 | f 2 23 45 108 | f 24 3 45 109 | f 18 6 46 110 | f 33 18 46 111 | f 2 7 47 112 | f 7 19 47 113 | f 23 2 47 114 | f 19 37 47 115 | f 37 23 47 116 | f 26 8 48 117 | f 15 33 48 118 | f 42 26 48 119 | f 33 42 48 120 | f 2 3 49 121 | f 7 2 49 122 | f 3 27 49 123 | f 27 7 49 124 | f 9 17 50 125 | f 24 16 50 126 | f 17 24 50 127 | f 16 29 50 128 | f 39 9 50 129 | f 29 39 50 130 | f 1 20 51 131 | f 32 1 51 132 | f 8 32 51 133 | f 38 8 51 134 | f 20 38 51 135 | f 8 4 52 136 | f 18 15 52 137 | f 4 34 52 138 | f 34 18 52 139 | f 48 8 52 140 | f 15 48 52 141 | f 20 1 53 142 | f 17 31 53 143 | f 31 20 53 144 | f 43 17 53 145 | f 1 43 53 146 | f 1 32 54 147 | f 22 43 54 148 | f 43 1 54 149 | f 36 9 55 150 | f 25 36 55 151 | f 41 25 55 152 | f 9 41 55 153 | f 8 26 56 154 | f 32 8 56 155 | f 54 32 56 156 | f 22 54 56 157 | f 10 22 57 158 | f 42 10 57 159 | f 26 42 57 160 | f 56 26 57 161 | f 22 56 57 162 | f 12 16 58 163 | f 16 24 58 164 | f 24 45 58 165 | f 58 45 59 166 | f 37 12 59 167 | f 23 37 59 168 | f 45 23 59 169 | f 12 58 59 170 | f 31 11 60 171 | f 20 31 60 172 | f 38 20 60 173 | f 11 38 60 174 | f 29 5 61 175 | f 39 29 61 176 | f 5 41 61 177 | f 41 39 61 178 | f 6 21 62 179 | f 21 10 62 180 | f 10 33 62 181 | f 46 6 62 182 | f 33 46 62 183 | f 16 12 63 184 | f 29 16 63 185 | f 12 37 63 186 | f 37 29 63 187 | f 31 17 64 188 | f 11 31 64 189 | f 17 36 64 190 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | 6 | # helpers 7 | NUM_FRAMES = 1 8 | NUM_PATCHES = 1 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | def generate_mask_matrix(npatch, nwindow): 14 | zeros = torch.zeros(npatch, npatch) 15 | ones = torch.ones(npatch, npatch) 16 | rows = [] 17 | for i in range(nwindow): 18 | row = torch.cat([ones] * (i+1) + [zeros] * (nwindow - i-1), dim=1) 19 | rows.append(row) 20 | mask = torch.cat(rows, dim=0).unsqueeze(0).unsqueeze(0) 21 | return mask 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, hidden_dim), 29 | nn.GELU(), 30 | nn.Dropout(dropout), 31 | nn.Linear(hidden_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | 35 | def forward(self, x): 36 | return 
self.net(x) 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 40 | super().__init__() 41 | inner_dim = dim_head * heads 42 | project_out = not (heads == 1 and dim_head == dim) 43 | 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | 47 | self.norm = nn.LayerNorm(dim) 48 | 49 | self.attend = nn.Softmax(dim = -1) 50 | self.dropout = nn.Dropout(dropout) 51 | 52 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 53 | 54 | self.to_out = nn.Sequential( 55 | nn.Linear(inner_dim, dim), 56 | nn.Dropout(dropout) 57 | ) if project_out else nn.Identity() 58 | self.bias = generate_mask_matrix(NUM_PATCHES, NUM_FRAMES).to('cuda') 59 | 60 | def forward(self, x): 61 | ( 62 | B, 63 | T, 64 | C, 65 | ) = x.size() 66 | x = self.norm(x) 67 | 68 | qkv = self.to_qkv(x).chunk(3, dim = -1) 69 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 70 | 71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 72 | # apply causal mask 73 | dots = dots.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 74 | 75 | attn = self.attend(dots) 76 | attn = self.dropout(attn) 77 | 78 | out = torch.matmul(attn, v) 79 | out = rearrange(out, 'b h n d -> b n (h d)') 80 | return self.to_out(out) 81 | 82 | class Transformer(nn.Module): 83 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 84 | super().__init__() 85 | self.norm = nn.LayerNorm(dim) 86 | self.layers = nn.ModuleList([]) 87 | for _ in range(depth): 88 | self.layers.append(nn.ModuleList([ 89 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 90 | FeedForward(dim, mlp_dim, dropout = dropout) 91 | ])) 92 | 93 | def forward(self, x): 94 | for attn, ff in self.layers: 95 | x = attn(x) + x 96 | x = ff(x) + x 97 | 98 | return self.norm(x) 99 | 100 | class ViTPredictor(nn.Module): 101 | def __init__(self, *, num_patches, num_frames, dim, depth, heads, mlp_dim, pool='cls', dim_head=64, dropout=0., emb_dropout=0.): 102 | super().__init__() 103 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 104 | 105 | # update params for adding causal attention masks 106 | global NUM_FRAMES, NUM_PATCHES 107 | NUM_FRAMES = num_frames 108 | NUM_PATCHES = num_patches 109 | 110 | self.pos_embedding = nn.Parameter(torch.randn(1, num_frames * (num_patches), dim)) # dim for the pos encodings 111 | self.dropout = nn.Dropout(emb_dropout) 112 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 113 | self.pool = pool 114 | 115 | def forward(self, x): # x: (b, window_size * H/patch_size * W/patch_size, 384) 116 | b, n, _ = x.shape 117 | x = x + self.pos_embedding[:, :n] 118 | x = self.dropout(x) 119 | x = self.transformer(x) 120 | return x -------------------------------------------------------------------------------- /env/wall/wall_env_wrapper.py: -------------------------------------------------------------------------------- 1 | from .envs.wall import DotWall 2 | from .envs.wall import WallDatasetConfig 3 | import numpy as np 4 | import torch 5 | import random 6 | from torchvision import transforms 7 | 8 | 9 | from utils import aggregate_dct 10 | from .data.wall_utils import generate_wall_layouts 11 | 12 | ENV_ACTION_DIM = 2 13 | STATE_RANGES = np.array([ 14 | [16.6840, 46.9885], 15 | [4.0083,25.2532] 16 | ]) 17 | 18 | DEFAULT_CFG = WallDatasetConfig( 19 | action_angle_noise=0.2, 20 | action_step_mean=1.0, 21 | action_step_std=0.4, 22 | action_lower_bd=0.2, 23 | 
action_upper_bd=1.8, 24 | batch_size=64, 25 | device='cuda', 26 | dot_std=1.7, 27 | border_wall_loc=5, 28 | fix_wall_batch_k=None, 29 | fix_wall=True, 30 | fix_door_location=30, 31 | fix_wall_location=32, 32 | exclude_wall_train='', 33 | exclude_door_train='', 34 | only_wall_val='', 35 | only_door_val='', 36 | wall_padding=20, 37 | door_padding=10, 38 | wall_width=6, 39 | door_space=4, 40 | num_train_layouts=-1, 41 | img_size=65, 42 | max_step=1, 43 | n_steps=17, 44 | n_steps_reduce_factor=1, 45 | size=20000, 46 | val_size=10000, 47 | train=True, 48 | repeat_actions=1 49 | ) 50 | 51 | resize_transform = transforms.Resize((224, 224)) 52 | TRANSFORM = resize_transform 53 | 54 | class WallEnvWrapper(DotWall): 55 | def __init__(self, rng=42, wall_config=DEFAULT_CFG, fix_wall=True, cross_wall=False, fix_wall_location=32, fix_door_location=10, device='cpu', **kwargs): 56 | super().__init__(rng, wall_config, fix_wall, cross_wall, fix_wall_location=fix_wall_location, fix_door_location=fix_door_location, device=device,**kwargs) 57 | self.action_dim = ENV_ACTION_DIM 58 | self.transform = TRANSFORM 59 | 60 | def eval_state(self, goal_state, cur_state): 61 | success = np.linalg.norm(goal_state[:2] - cur_state[:2]) < 4.5 62 | state_dist = np.linalg.norm(goal_state - cur_state) 63 | return { 64 | 'success': success, 65 | 'state_dist': state_dist, 66 | } 67 | 68 | def sample_random_init_goal_states(self, seed): 69 | """ 70 | Return a random state 71 | """ 72 | return self.generate_random_state(seed) 73 | 74 | def update_env(self, env_info): # change door and wall locations 75 | self.wall_config.fix_door_location = env_info["fix_door_location"].item() 76 | self.wall_config.fix_wall_location = env_info["fix_wall_location"].item() 77 | layouts, other_layouts = generate_wall_layouts(self.wall_config) 78 | self.layouts = layouts 79 | self.wall_x, self.hole_y = self._generate_wall() 80 | 81 | def prepare(self, seed, init_state): 82 | """ 83 | Reset with controlled init_state 84 | """ 85 | self.seed(seed) 86 | self.set_init_state(init_state) 87 | obs, state = self.reset() 88 | obs['visual'] = self.transform(obs['visual']) 89 | obs['visual'] = obs['visual'].permute(1, 2, 0) 90 | return obs, state 91 | 92 | def step_multiple(self, actions): 93 | obses = [] 94 | rewards = [] 95 | dones = [] 96 | infos = [] 97 | for action in actions: 98 | o, r, d, info = self.step(action) 99 | o['visual'] = self.transform(o['visual']) 100 | o['visual'] = o['visual'].permute(1, 2, 0) 101 | obses.append(o) 102 | rewards.append(r) 103 | dones.append(d) 104 | infos.append(info) 105 | obses = aggregate_dct(obses) 106 | rewards = np.stack(rewards) 107 | dones = np.stack(dones) 108 | infos = aggregate_dct(infos) 109 | return obses, rewards, dones, infos 110 | 111 | def rollout(self, seed, init_state, actions): 112 | """ 113 | only returns np arrays of observations and states 114 | seed: int 115 | init_state: (state_dim, ) 116 | actions: (T, action_dim) 117 | obses: dict (T, H, W, C) 118 | states: (T, D) 119 | """ 120 | obs, state = self.prepare(seed, init_state) 121 | obses, rewards, dones, infos = self.step_multiple(actions) 122 | for k in obses.keys(): 123 | obses[k] = np.vstack([np.expand_dims(obs[k], 0), obses[k]]) 124 | states = np.vstack([np.expand_dims(state, 0), infos["state"]]) 125 | states = np.stack(states) 126 | return obses, states 127 | 128 | -------------------------------------------------------------------------------- /env/pointmaze/gridcraft/grid_spec.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | EMPTY = 110 5 | WALL = 111 6 | START = 112 7 | REWARD = 113 8 | OUT_OF_BOUNDS = 114 9 | REWARD2 = 115 10 | REWARD3 = 116 11 | REWARD4 = 117 12 | LAVA = 118 13 | GOAL = 119 14 | 15 | TILES = {EMPTY, WALL, START, REWARD, REWARD2, REWARD3, REWARD4, LAVA, GOAL} 16 | 17 | STR_MAP = { 18 | 'O': EMPTY, 19 | '#': WALL, 20 | 'S': START, 21 | 'R': REWARD, 22 | '2': REWARD2, 23 | '3': REWARD3, 24 | '4': REWARD4, 25 | 'G': GOAL, 26 | 'L': LAVA 27 | } 28 | 29 | RENDER_DICT = {v:k for k, v in STR_MAP.items()} 30 | RENDER_DICT[EMPTY] = ' ' 31 | RENDER_DICT[START] = ' ' 32 | 33 | 34 | 35 | def spec_from_string(s, valmap=STR_MAP): 36 | if s.endswith('\\'): 37 | s = s[:-1] 38 | rows = s.split('\\') 39 | rowlens = np.array([len(row) for row in rows]) 40 | assert np.all(rowlens == rowlens[0]) 41 | w, h = len(rows), len(rows[0])#len(rows[0]), len(rows) 42 | 43 | gs = GridSpec(w, h) 44 | for i in range(w): 45 | for j in range(h): 46 | gs[i,j] = valmap[rows[i][j]] 47 | return gs 48 | 49 | 50 | def spec_from_sparse_locations(w, h, tile_to_locs): 51 | """ 52 | 53 | Example usage: 54 | >> spec_from_sparse_locations(10, 10, {START: [(0,0)], REWARD: [(7,8), (8,8)]}) 55 | 56 | """ 57 | gs = GridSpec(w, h) 58 | for tile_type in tile_to_locs: 59 | locs = np.array(tile_to_locs[tile_type]) 60 | for i in range(locs.shape[0]): 61 | gs[tuple(locs[i])] = tile_type 62 | return gs 63 | 64 | 65 | def local_spec(map, xpnt): 66 | """ 67 | >>> local_spec("yOy\\\\Oxy", xpnt=(5,5)) 68 | array([[4, 4], 69 | [6, 4], 70 | [6, 5]]) 71 | """ 72 | Y = 0; X=1; O=2 73 | valmap={ 74 | 'y': Y, 75 | 'x': X, 76 | 'O': O 77 | } 78 | gs = spec_from_string(map, valmap=valmap) 79 | ys = gs.find(Y) 80 | x = gs.find(X) 81 | result = ys-x + np.array(xpnt) 82 | return result 83 | 84 | 85 | 86 | class GridSpec(object): 87 | def __init__(self, w, h): 88 | self.__data = np.zeros((w, h), dtype=np.int32) 89 | self.__w = w 90 | self.__h = h 91 | 92 | def __setitem__(self, key, val): 93 | self.__data[key] = val 94 | 95 | def __getitem__(self, key): 96 | if self.out_of_bounds(key): 97 | raise NotImplementedError("Out of bounds:"+str(key)) 98 | return self.__data[tuple(key)] 99 | 100 | def out_of_bounds(self, wh): 101 | """ Return true if x, y is out of bounds """ 102 | w, h = wh 103 | if w<0 or w>=self.__w: 104 | return True 105 | if h < 0 or h >= self.__h: 106 | return True 107 | return False 108 | 109 | def get_neighbors(self, k, xy=False): 110 | """ Return values of up, down, left, and right tiles """ 111 | if not xy: 112 | k = self.idx_to_xy(k) 113 | offsets = [np.array([0,-1]), np.array([0,1]), 114 | np.array([-1,0]), np.array([1,0])] 115 | neighbors = \ 116 | [self[k+offset] if (not self.out_of_bounds(k+offset)) else OUT_OF_BOUNDS for offset in offsets ] 117 | return neighbors 118 | 119 | def get_value(self, k, xy=False): 120 | """ Return values of up, down, left, and right tiles """ 121 | if not xy: 122 | k = self.idx_to_xy(k) 123 | return self[k] 124 | 125 | def find(self, value): 126 | return np.array(np.where(self.spec == value)).T 127 | 128 | @property 129 | def spec(self): 130 | return self.__data 131 | 132 | @property 133 | def width(self): 134 | return self.__w 135 | 136 | def __len__(self): 137 | return self.__w*self.__h 138 | 139 | @property 140 | def height(self): 141 | return self.__h 142 | 143 | def idx_to_xy(self, idx): 144 | if hasattr(idx, '__len__'): # array 145 | x = idx % self.__w 146 | y = 
np.floor(idx/self.__w).astype(np.int32) 147 | xy = np.c_[x,y] 148 | return xy 149 | else: 150 | return np.array([ idx % self.__w, int(np.floor(idx/self.__w))]) 151 | 152 | def xy_to_idx(self, key): 153 | shape = np.array(key).shape 154 | if len(shape) == 1: 155 | return key[0] + key[1]*self.__w 156 | elif len(shape) == 2: 157 | return key[:,0] + key[:,1]*self.__w 158 | else: 159 | raise NotImplementedError() 160 | 161 | def __hash__(self): 162 | data = (self.__w, self.__h) + tuple(self.__data.reshape([-1]).tolist()) 163 | return hash(data) 164 | -------------------------------------------------------------------------------- /env/pointmaze/dynamic_mjc.py: -------------------------------------------------------------------------------- 1 | """ 2 | dynamic_mjc.py 3 | A small library for programmatically building MuJoCo XML files 4 | """ 5 | from contextlib import contextmanager 6 | import tempfile 7 | import numpy as np 8 | 9 | 10 | def default_model(name): 11 | """ 12 | Get a model with basic settings such as gravity and RK4 integration enabled 13 | """ 14 | model = MJCModel(name) 15 | root = model.root 16 | 17 | # Setup 18 | root.compiler(angle="radian", inertiafromgeom="true") 19 | default = root.default() 20 | default.joint(armature=1, damping=1, limited="true") 21 | default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1') 22 | root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01) 23 | return model 24 | 25 | def pointmass_model(name): 26 | """ 27 | Get a model with basic settings such as gravity and Euler integration enabled 28 | """ 29 | model = MJCModel(name) 30 | root = model.root 31 | 32 | # Setup 33 | root.compiler(angle="radian", inertiafromgeom="true", coordinate="local") 34 | default = root.default() 35 | default.joint(limited="false", damping=1) 36 | default.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002") 37 | root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler") 38 | return model 39 | 40 | 41 | class MJCModel(object): 42 | def __init__(self, name): 43 | self.name = name 44 | self.root = MJCTreeNode("mujoco").add_attr('model', name) 45 | 46 | @contextmanager 47 | def asfile(self): 48 | """ 49 | Usage: 50 | model = MJCModel('reacher') 51 | with model.asfile() as f: 52 | print f.read() # prints a dump of the model 53 | """ 54 | with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f: 55 | self.root.write(f) 56 | f.seek(0) 57 | yield f 58 | 59 | def open(self): 60 | self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) 61 | self.root.write(self.file) 62 | self.file.seek(0) 63 | return self.file 64 | 65 | def close(self): 66 | self.file.close() 67 | 68 | def find_attr(self, attr, value): 69 | return self.root.find_attr(attr, value) 70 | 71 | def __getstate__(self): 72 | return {} 73 | 74 | def __setstate__(self, state): 75 | pass 76 | 77 | 78 | class MJCTreeNode(object): 79 | def __init__(self, name): 80 | self.name = name 81 | self.attrs = {} 82 | self.children = [] 83 | 84 | def add_attr(self, key, value): 85 | if isinstance(value, str): 86 | pass 87 | elif isinstance(value, list) or isinstance(value, np.ndarray): 88 | value = ' '.join([str(val).lower() for val in value]) 89 | else: 90 | value = str(value).lower() 91 | 92 | self.attrs[key] = value 93 | return self 94 | 95 | def __getattr__(self, name): 96 | def wrapper(**kwargs): 97 | newnode = MJCTreeNode(name) 98 | for (k, v) in kwargs.items(): 99 | newnode.add_attr(k, v) 100 | 
self.children.append(newnode) 101 | return newnode 102 | return wrapper 103 | 104 | def dfs(self): 105 | yield self 106 | if self.children: 107 | for child in self.children: 108 | for node in child.dfs(): 109 | yield node 110 | 111 | def find_attr(self, attr, value): 112 | """ Run DFS to find a matching attr """ 113 | if attr in self.attrs and self.attrs[attr] == value: 114 | return self 115 | for child in self.children: 116 | res = child.find_attr(attr, value) 117 | if res is not None: 118 | return res 119 | return None 120 | 121 | 122 | def write(self, ostream, tabs=0): 123 | contents = ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) 124 | if self.children: 125 | ostream.write('\t'*tabs) 126 | ostream.write('<%s %s>\n' % (self.name, contents)) 127 | for child in self.children: 128 | child.write(ostream, tabs=tabs+1) 129 | ostream.write('\t'*tabs) 130 | ostream.write('</%s>\n' % self.name) 131 | else: 132 | ostream.write('\t'*tabs) 133 | ostream.write('<%s %s/>\n' % (self.name, contents)) 134 | 135 | def __str__(self): 136 | s = "<"+self.name 137 | s += ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) 138 | return s+">" 139 | -------------------------------------------------------------------------------- /env/deformable_env/src/sim/sim_env/robot_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pyflex 4 | 5 | import pybullet as p 6 | from bs4 import BeautifulSoup 7 | 8 | from .transformations import quaternion_from_matrix, quaternion_matrix 9 | 10 | class FlexRobotHelper: 11 | def __init__(self): 12 | self.transform_bullet_to_flex = np.array([ 13 | [1, 0, 0, 0], 14 | [0, 0, 1, 0], 15 | [0, -1, 0, 0], 16 | [0, 0, 0, 1]]) 17 | self.robotId = None 18 | 19 | def loadURDF(self, fileName, basePosition, baseOrientation, useFixedBase = True, globalScaling = 1.0): 20 | if self.robotId is None: 21 | # print("Loading robot from file: ", fileName) 22 | # fileName = os.path.join('/home/gary/AdaptiGraph/src/', fileName) 23 | # print("Loading robot from file: ", fileName) 24 | self.robotId = p.loadURDF(fileName, basePosition, baseOrientation, useFixedBase = useFixedBase, globalScaling = globalScaling) 25 | p.resetBasePositionAndOrientation(self.robotId, basePosition, baseOrientation) 26 | 27 | robot_path = fileName # changed the urdf file 28 | robot_path_par = os.path.abspath(os.path.join(robot_path, os.pardir)) 29 | with open(robot_path, 'r') as f: 30 | robot = f.read() 31 | robot_data = BeautifulSoup(robot, 'xml') 32 | links = robot_data.find_all('link') 33 | 34 | # add the mesh to pyflex 35 | self.num_meshes = 0 36 | self.has_mesh = np.ones(len(links), dtype=bool) 37 | 38 | """ 39 | XARM6 with gripper: 40 | 0: base_link; 41 | 1 - 6: link1 - link6; (without gripper - 7: stick/finger) 42 | 43 | 7: base_link; 44 | 8: left outer knuckle; 45 | 9: left finger; 46 | 10: left inner knuckle; 47 | 11: right outer knuckle; 48 | 12: right finger; 49 | 13: right inner knuckle; 50 | """ 51 | for i in range(len(links)): 52 | link = links[i] 53 | if link.find_all('geometry'): 54 | mesh_name = link.find_all('geometry')[0].find_all('mesh')[0].get('filename') 55 | pyflex.add_mesh(os.path.join(robot_path_par, mesh_name), globalScaling, 0, np.ones(3), np.zeros(3), np.zeros(4), False) 56 | self.num_meshes += 1 57 | else: 58 | self.has_mesh[i] = False 59 | 60 | self.num_link = len(links) 61 | self.state_pre = None 62 | 63 | return self.robotId 64 | 65 | def resetJointState(self, i, pose): 66 | 
p.resetJointState(self.robotId, i, pose) 67 | return self.getRobotShapeStates() 68 | 69 | def getRobotShapeStates(self): 70 | # convert pybullet link state to pyflex link state 71 | state_cur = [] 72 | base_com_pos, base_com_orn = p.getBasePositionAndOrientation(self.robotId) 73 | di = p.getDynamicsInfo(self.robotId, -1) 74 | local_inertial_pos, local_inertial_orn = di[3], di[4] 75 | 76 | pos_inv, orn_inv = p.invertTransform(local_inertial_pos, local_inertial_orn) 77 | pos, orn = p.multiplyTransforms(base_com_pos, base_com_orn, pos_inv, orn_inv) 78 | 79 | state_cur.append(list(pos) + [1] + list(orn)) 80 | 81 | for l in range(self.num_link-1): 82 | ls = p.getLinkState(self.robotId, l) 83 | pos = ls[4] 84 | orn = ls[5] 85 | state_cur.append(list(pos) + [1] + list(orn)) 86 | 87 | state_cur = np.array(state_cur) 88 | 89 | shape_states = np.zeros((self.num_meshes, 14)) 90 | if self.state_pre is None: 91 | self.state_pre = state_cur.copy() 92 | 93 | mesh_idx = 0 94 | for i in range(self.num_link): 95 | if self.has_mesh[i]: 96 | # pos + [1] 97 | shape_states[mesh_idx, 0:3] = np.matmul( 98 | self.transform_bullet_to_flex, state_cur[i, :4])[:3] 99 | shape_states[mesh_idx, 3:6] = np.matmul( 100 | self.transform_bullet_to_flex, self.state_pre[i, :4])[:3] 101 | # orientation 102 | shape_states[mesh_idx, 6:10] = quaternion_from_matrix( 103 | np.matmul(self.transform_bullet_to_flex, 104 | quaternion_matrix(state_cur[i, 4:]))) 105 | shape_states[mesh_idx, 10:14] = quaternion_from_matrix( 106 | np.matmul(self.transform_bullet_to_flex, 107 | quaternion_matrix(self.state_pre[i, 4:]))) 108 | mesh_idx += 1 109 | 110 | self.state_pre = state_cur 111 | return shape_states -------------------------------------------------------------------------------- /planning/gd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from einops import rearrange 4 | from .base_planner import BasePlanner 5 | from utils import move_to_device 6 | 7 | 8 | class GDPlanner(BasePlanner): 9 | def __init__( 10 | self, 11 | horizon, 12 | action_noise, 13 | sample_type, 14 | lr, 15 | opt_steps, 16 | eval_every, 17 | wm, 18 | action_dim, 19 | objective_fn, 20 | preprocessor, 21 | evaluator, 22 | wandb_run, 23 | logging_prefix="plan_0", 24 | log_filename="logs.json", 25 | **kwargs, 26 | ): 27 | super().__init__( 28 | wm, 29 | action_dim, 30 | objective_fn, 31 | preprocessor, 32 | evaluator, 33 | wandb_run, 34 | log_filename, 35 | ) 36 | self.horizon = horizon 37 | self.action_noise = action_noise 38 | self.sample_type = sample_type 39 | self.lr = lr 40 | self.opt_steps = opt_steps 41 | self.eval_every = eval_every 42 | self.logging_prefix = logging_prefix 43 | 44 | def init_actions(self, obs_0, actions=None): 45 | """ 46 | Initializes or appends actions for planning, ensuring the output shape is (b, self.horizon, action_dim). 47 | """ 48 | n_evals = obs_0["visual"].shape[0] 49 | if actions is None: 50 | actions = torch.zeros(n_evals, 0, self.action_dim) 51 | device = actions.device 52 | t = actions.shape[1] 53 | remaining_t = self.horizon - t 54 | 55 | if remaining_t > 0: 56 | if self.sample_type == "randn": 57 | new_actions = torch.randn(n_evals, remaining_t, self.action_dim) 58 | elif self.sample_type == "zero": # zero action of env 59 | new_actions = torch.zeros(n_evals, remaining_t, self.action_dim) 60 | new_actions = rearrange( 61 | new_actions, "... (f d) -> ... 
f d", f=self.evaluator.frameskip 62 | ) 63 | new_actions = self.preprocessor.normalize_actions(new_actions) 64 | new_actions = rearrange(new_actions, "... f d -> ... (f d)") 65 | actions = torch.cat([actions, new_actions.to(device)], dim=1) 66 | return actions 67 | 68 | def get_action_optimizer(self, actions): 69 | return torch.optim.SGD([actions], lr=self.lr) 70 | 71 | def plan(self, obs_0, obs_g, actions=None): 72 | """ 73 | Args: 74 | actions: normalized 75 | Returns: 76 | actions: (B, T, action_dim) torch.Tensor 77 | """ 78 | trans_obs_0 = move_to_device( 79 | self.preprocessor.transform_obs(obs_0), self.device 80 | ) 81 | trans_obs_g = move_to_device( 82 | self.preprocessor.transform_obs(obs_g), self.device 83 | ) 84 | z_obs_g = self.wm.encode_obs(trans_obs_g) 85 | z_obs_g_detached = {key: value.detach() for key, value in z_obs_g.items()} 86 | 87 | actions = self.init_actions(obs_0, actions).to(self.device) 88 | actions.requires_grad = True 89 | optimizer = self.get_action_optimizer(actions) 90 | n_evals = actions.shape[0] 91 | 92 | for i in range(self.opt_steps): 93 | optimizer.zero_grad() 94 | i_z_obses, i_zs = self.wm.rollout( 95 | obs_0=trans_obs_0, 96 | act=actions, 97 | ) 98 | loss = self.objective_fn(i_z_obses, z_obs_g_detached) # (n_evals, ) 99 | total_loss = loss.mean() * n_evals # loss for each eval is independent 100 | total_loss.backward() 101 | with torch.no_grad(): 102 | actions_new = actions - optimizer.param_groups[0]["lr"] * actions.grad 103 | actions_new += ( 104 | torch.randn_like(actions_new) * self.action_noise 105 | ) # Add Gaussian noise 106 | actions.copy_(actions_new) 107 | 108 | self.wandb_run.log( 109 | {f"{self.logging_prefix}/loss": total_loss.item(), "step": i + 1} 110 | ) 111 | if self.evaluator is not None and i % self.eval_every == 0: 112 | logs, successes, _, _ = self.evaluator.eval_actions( 113 | actions.detach(), filename=f"{self.logging_prefix}_output_{i+1}" 114 | ) 115 | logs = {f"{self.logging_prefix}/{k}": v for k, v in logs.items()} 116 | logs.update({"step": i + 1}) 117 | self.wandb_run.log(logs) 118 | self.dump_logs(logs) 119 | if np.all(successes): 120 | break # terminate planning if all success 121 | return actions, np.full(n_evals, np.inf) # all actions are valid 122 | -------------------------------------------------------------------------------- /planning/cem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from einops import rearrange, repeat 4 | from .base_planner import BasePlanner 5 | from utils import move_to_device 6 | 7 | 8 | class CEMPlanner(BasePlanner): 9 | def __init__( 10 | self, 11 | horizon, 12 | topk, 13 | num_samples, 14 | var_scale, 15 | opt_steps, 16 | eval_every, 17 | wm, 18 | action_dim, 19 | objective_fn, 20 | preprocessor, 21 | evaluator, 22 | wandb_run, 23 | logging_prefix="plan_0", 24 | log_filename="logs.json", 25 | **kwargs, 26 | ): 27 | super().__init__( 28 | wm, 29 | action_dim, 30 | objective_fn, 31 | preprocessor, 32 | evaluator, 33 | wandb_run, 34 | log_filename, 35 | ) 36 | self.horizon = horizon 37 | self.topk = topk 38 | self.num_samples = num_samples 39 | self.var_scale = var_scale 40 | self.opt_steps = opt_steps 41 | self.eval_every = eval_every 42 | self.logging_prefix = logging_prefix 43 | 44 | def init_mu_sigma(self, obs_0, actions=None): 45 | """ 46 | actions: (B, T, action_dim) torch.Tensor, T <= self.horizon 47 | mu, sigma could depend on current obs, but obs_0 is only used for providing n_evals for now 48 | """ 49 | 
n_evals = obs_0["visual"].shape[0] 50 | sigma = self.var_scale * torch.ones([n_evals, self.horizon, self.action_dim]) 51 | if actions is None: 52 | mu = torch.zeros(n_evals, 0, self.action_dim) 53 | else: 54 | mu = actions 55 | device = mu.device 56 | t = mu.shape[1] 57 | remaining_t = self.horizon - t 58 | 59 | if remaining_t > 0: 60 | new_mu = torch.zeros(n_evals, remaining_t, self.action_dim) 61 | mu = torch.cat([mu, new_mu.to(device)], dim=1) 62 | return mu, sigma 63 | 64 | def plan(self, obs_0, obs_g, actions=None): 65 | """ 66 | Args: 67 | actions: normalized 68 | Returns: 69 | actions: (B, T, action_dim) torch.Tensor, T <= self.horizon 70 | """ 71 | trans_obs_0 = move_to_device( 72 | self.preprocessor.transform_obs(obs_0), self.device 73 | ) 74 | trans_obs_g = move_to_device( 75 | self.preprocessor.transform_obs(obs_g), self.device 76 | ) 77 | z_obs_g = self.wm.encode_obs(trans_obs_g) 78 | 79 | mu, sigma = self.init_mu_sigma(obs_0, actions) 80 | mu, sigma = mu.to(self.device), sigma.to(self.device) 81 | n_evals = mu.shape[0] 82 | 83 | for i in range(self.opt_steps): 84 | # optimize individual instances 85 | losses = [] 86 | for traj in range(n_evals): 87 | cur_trans_obs_0 = { 88 | key: repeat( 89 | arr[traj].unsqueeze(0), "1 ... -> n ...", n=self.num_samples 90 | ) 91 | for key, arr in trans_obs_0.items() 92 | } 93 | cur_z_obs_g = { 94 | key: repeat( 95 | arr[traj].unsqueeze(0), "1 ... -> n ...", n=self.num_samples 96 | ) 97 | for key, arr in z_obs_g.items() 98 | } 99 | action = ( 100 | torch.randn(self.num_samples, self.horizon, self.action_dim).to( 101 | self.device 102 | ) 103 | * sigma[traj] 104 | + mu[traj] 105 | ) 106 | action[0] = mu[traj] # optional: make the first one mu itself 107 | with torch.no_grad(): 108 | i_z_obses, i_zs = self.wm.rollout( 109 | obs_0=cur_trans_obs_0, 110 | act=action, 111 | ) 112 | 113 | loss = self.objective_fn(i_z_obses, cur_z_obs_g) 114 | topk_idx = torch.argsort(loss)[: self.topk] 115 | topk_action = action[topk_idx] 116 | losses.append(loss[topk_idx[0]].item()) 117 | mu[traj] = topk_action.mean(dim=0) 118 | sigma[traj] = topk_action.std(dim=0) 119 | 120 | self.wandb_run.log( 121 | {f"{self.logging_prefix}/loss": np.mean(losses), "step": i + 1} 122 | ) 123 | if self.evaluator is not None and i % self.eval_every == 0: 124 | logs, successes, _, _ = self.evaluator.eval_actions( 125 | mu, filename=f"{self.logging_prefix}_output_{i+1}" 126 | ) 127 | logs = {f"{self.logging_prefix}/{k}": v for k, v in logs.items()} 128 | logs.update({"step": i + 1}) 129 | self.wandb_run.log(logs) 130 | self.dump_logs(logs) 131 | if np.all(successes): 132 | break # terminate planning if all success 133 | 134 | return mu, np.full(n_evals, np.inf) # all actions are valid 135 | -------------------------------------------------------------------------------- /models/encoder/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | 5 | 6 | class resnet18(nn.Module): 7 | def __init__( 8 | self, 9 | pretrained: bool = True, 10 | unit_norm: bool = False, 11 | ): 12 | super().__init__() 13 | resnet = torchvision.models.resnet18(pretrained=pretrained) 14 | self.resnet = nn.Sequential(*list(resnet.children())[:-1]) 15 | self.flatten = nn.Flatten() 16 | self.pretrained = pretrained 17 | self.normalize = torchvision.transforms.Normalize( 18 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 19 | ) 20 | self.unit_norm = unit_norm 21 | 22 | self.latent_ndim = 1 23 | 
self.emb_dim = 512 24 | self.name = "resnet" 25 | 26 | def forward(self, x): 27 | dims = len(x.shape) 28 | orig_shape = x.shape 29 | if dims == 3: 30 | x = x.unsqueeze(0) 31 | elif dims > 4: 32 | # flatten all dimensions to batch, then reshape back at the end 33 | x = x.reshape(-1, *orig_shape[-3:]) 34 | x = self.normalize(x) 35 | out = self.resnet(x) 36 | out = self.flatten(out) 37 | if self.unit_norm: 38 | out = torch.nn.functional.normalize(out, p=2, dim=-1) 39 | if dims == 3: 40 | out = out.squeeze(0) 41 | elif dims > 4: 42 | out = out.reshape(*orig_shape[:-3], -1) 43 | out = out.unsqueeze(1) 44 | return out 45 | 46 | 47 | class resblock(nn.Module): 48 | # this implementation assumes square images 49 | def __init__(self, input_dim, output_dim, kernel_size, resample=None, hw=32): 50 | super(resblock, self).__init__() 51 | self.input_dim = input_dim 52 | self.output_dim = output_dim 53 | self.kernel_size = kernel_size 54 | self.resample = resample 55 | 56 | padding = int((kernel_size - 1) / 2) 57 | 58 | if resample == "down": 59 | self.skip = nn.Sequential( 60 | nn.AvgPool2d(2, stride=2), 61 | nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding), 62 | ) 63 | self.conv1 = nn.Conv2d( 64 | input_dim, input_dim, kernel_size, padding=padding, bias=False 65 | ) 66 | self.conv2 = nn.Sequential( 67 | nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding), 68 | nn.MaxPool2d(2, stride=2), 69 | ) 70 | self.bn1 = nn.BatchNorm2d(input_dim) 71 | self.bn2 = nn.BatchNorm2d(output_dim) 72 | elif resample is None: 73 | self.skip = nn.Conv2d(input_dim, output_dim, 1) 74 | self.conv1 = nn.Conv2d( 75 | input_dim, output_dim, kernel_size, padding=padding, bias=False 76 | ) 77 | self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size, padding=padding) 78 | self.bn1 = nn.BatchNorm2d(output_dim) 79 | self.bn2 = nn.BatchNorm2d(output_dim) 80 | 81 | self.leakyrelu1 = nn.LeakyReLU() 82 | self.leakyrelu2 = nn.LeakyReLU() 83 | 84 | def forward(self, x): 85 | if (self.input_dim == self.output_dim) and self.resample is None: 86 | idnty = x 87 | else: 88 | idnty = self.skip(x) 89 | 90 | residual = x 91 | residual = self.conv1(residual) 92 | residual = self.bn1(residual) 93 | residual = self.leakyrelu1(residual) 94 | 95 | residual = self.conv2(residual) 96 | residual = self.bn2(residual) 97 | residual = self.leakyrelu2(residual) 98 | 99 | return idnty + residual 100 | 101 | 102 | class SmallResNet(nn.Module): 103 | def __init__(self, output_dim=512): 104 | super(SmallResNet, self).__init__() 105 | 106 | self.hw = 224 107 | 108 | # 3x224x224 109 | self.rb1 = resblock(3, 16, 3, resample="down", hw=self.hw) 110 | # 16x112x112 111 | self.rb2 = resblock(16, 32, 3, resample="down", hw=self.hw // 2) 112 | # 32x56x56 113 | self.rb3 = resblock(32, 64, 3, resample="down", hw=self.hw // 4) 114 | # 64x28x28 115 | self.rb4 = resblock(64, 128, 3, resample="down", hw=self.hw // 8) 116 | # 128x14x14 117 | self.rb5 = resblock(128, 512, 3, resample="down", hw=self.hw // 16) 118 | # 512x7x7 119 | self.maxpool = nn.MaxPool2d(7) 120 | # 512x1x1 121 | self.flat = nn.Flatten() 122 | 123 | def forward(self, x): 124 | dims = len(x.shape) 125 | orig_shape = x.shape 126 | if dims == 3: 127 | x = x.unsqueeze(0) 128 | elif dims > 4: 129 | # flatten all dimensions to batch, then reshape back at the end 130 | x = x.reshape(-1, *orig_shape[-3:]) 131 | x = self.rb1(x) 132 | x = self.rb2(x) 133 | x = self.rb3(x) 134 | x = self.rb4(x) 135 | x = self.rb5(x) 136 | x = self.maxpool(x) 137 | out = x.flatten(start_dim=-3) 138 | if dims == 
3: 139 | out = out.squeeze(0) 140 | elif dims > 4: 141 | out = out.reshape(*orig_shape[:-3], -1) 142 | return out 143 | -------------------------------------------------------------------------------- /models/encoder/r3m/models/models_r3m.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | from numpy.core.numeric import full 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.activation import Sigmoid 10 | from torch.nn.modules.linear import Identity 11 | import torchvision 12 | from torchvision import transforms 13 | from .. import utils 14 | from pathlib import Path 15 | from torchvision.utils import save_image 16 | import torchvision.transforms as T 17 | 18 | epsilon = 1e-8 19 | 20 | 21 | def do_nothing(x): 22 | return x 23 | 24 | 25 | class R3M(nn.Module): 26 | def __init__( 27 | self, 28 | device, 29 | lr, 30 | hidden_dim, 31 | size=34, 32 | l2weight=1.0, 33 | l1weight=1.0, 34 | langweight=1.0, 35 | tcnweight=0.0, 36 | l2dist=True, 37 | bs=16, 38 | ): 39 | super().__init__() 40 | 41 | self.device = device 42 | self.use_tb = False 43 | self.l2weight = l2weight 44 | self.l1weight = l1weight 45 | self.tcnweight = tcnweight ## Weight on TCN loss (states closer in same clip closer in embedding) 46 | self.l2dist = l2dist ## Use -l2 or cosine sim 47 | self.langweight = langweight ## Weight on language reward 48 | self.size = size ## Size ResNet or ViT 49 | self.num_negatives = 3 50 | 51 | self.latent_ndim = 1 52 | self.emb_dim = 512 53 | self.name = "r3m" 54 | 55 | ## Distances and Metrics 56 | self.cs = torch.nn.CosineSimilarity(1) 57 | self.bce = nn.BCELoss(reduce=False) 58 | self.sigm = Sigmoid() 59 | 60 | params = [] 61 | ######################################################################## Sub Modules 62 | ## Visual Encoder 63 | if size == 18: 64 | self.outdim = 512 65 | self.convnet = torchvision.models.resnet18(pretrained=False) 66 | elif size == 34: 67 | self.outdim = 512 68 | self.convnet = torchvision.models.resnet34(pretrained=False) 69 | elif size == 50: 70 | self.outdim = 2048 71 | self.convnet = torchvision.models.resnet50(pretrained=False) 72 | elif size == 0: 73 | from transformers import AutoConfig 74 | 75 | self.outdim = 768 76 | self.convnet = AutoModel.from_config( 77 | config=AutoConfig.from_pretrained("google/vit-base-patch32-224-in21k") 78 | ).to(self.device) 79 | 80 | if self.size == 0: 81 | self.normlayer = transforms.Normalize( 82 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] 83 | ) 84 | else: 85 | self.normlayer = transforms.Normalize( 86 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 87 | ) 88 | self.convnet.fc = Identity() 89 | self.convnet.train() 90 | params += list(self.convnet.parameters()) 91 | 92 | ## Language Reward 93 | if self.langweight > 0.0: 94 | ## Pretrained DistilBERT Sentence Encoder 95 | from r3m.models.models_language import LangEncoder, LanguageReward 96 | 97 | self.lang_enc = LangEncoder(self.device, 0, 0) 98 | self.lang_rew = LanguageReward( 99 | None, self.outdim, hidden_dim, self.lang_enc.lang_size, simfunc=self.sim 100 | ) 101 | params += list(self.lang_rew.parameters()) 102 | ######################################################################## 103 | 104 | ## Optimizer 105 | self.encoder_opt = torch.optim.Adam(params, lr=lr) 106 | 107 | def 
get_reward(self, e0, es, sentences): 108 | ## Only callable is langweight was set to be 1 109 | le = self.lang_enc(sentences) 110 | return self.lang_rew(e0, es, le) 111 | 112 | ## Forward Call (im --> representation) 113 | def forward(self, obs, num_ims=1, obs_shape=[3, 224, 224]): 114 | if obs_shape != [3, 224, 224]: 115 | preprocess = nn.Sequential( 116 | transforms.Resize(256), 117 | transforms.CenterCrop(224), 118 | self.normlayer, 119 | ) 120 | else: 121 | preprocess = nn.Sequential( 122 | self.normlayer, 123 | ) 124 | 125 | ## Input must be [0, 1], [3,244,244] 126 | dims = len(obs.shape) 127 | orig_shape = obs.shape 128 | if dims == 3: 129 | obs = obs.unsqueeze(0) 130 | elif dims > 4: 131 | obs = obs.reshape(-1, *orig_shape[-3:]) 132 | obs_p = preprocess(obs) 133 | h = self.convnet(obs_p) 134 | if dims == 3: 135 | h = h.squeeze(0) 136 | elif dims > 4: 137 | h = h.reshape(*orig_shape[:-3], -1) 138 | h = h.unsqueeze(1) 139 | return h 140 | 141 | def sim(self, tensor1, tensor2): 142 | if self.l2dist: 143 | d = -torch.linalg.norm(tensor1 - tensor2, dim=-1) 144 | else: 145 | d = self.cs(tensor1, tensor2) 146 | return d 147 | -------------------------------------------------------------------------------- /models/encoder/r3m/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from .models.models_r3m import R3M 6 | 7 | import os 8 | from os.path import expanduser 9 | import omegaconf 10 | import hydra 11 | import gdown 12 | import torch 13 | import copy 14 | 15 | VALID_ARGS = [ 16 | "_target_", 17 | "device", 18 | "lr", 19 | "hidden_dim", 20 | "size", 21 | "l2weight", 22 | "l1weight", 23 | "langweight", 24 | "tcnweight", 25 | "l2dist", 26 | "bs", 27 | ] 28 | if torch.cuda.is_available(): 29 | device = "cuda" 30 | else: 31 | device = "cpu" 32 | 33 | 34 | def cleanup_config(cfg): 35 | config = copy.deepcopy(cfg) 36 | keys = config.agent.keys() 37 | for key in list(keys): 38 | if key not in VALID_ARGS: 39 | del config.agent[key] 40 | config.agent["_target_"] = "models.encoder.r3m.R3M" 41 | config["device"] = device 42 | 43 | ## Hardcodes to remove the language head 44 | ## Assumes downstream use is as visual representation 45 | config.agent["langweight"] = 0 46 | return config.agent 47 | 48 | 49 | def remove_language_head(state_dict): 50 | keys = state_dict.keys() 51 | ## Hardcodes to remove the language head 52 | ## Assumes downstream use is as visual representation 53 | for key in list(keys): 54 | if ("lang_enc" in key) or ("lang_rew" in key): 55 | del state_dict[key] 56 | return state_dict 57 | 58 | 59 | def load_r3m(modelid): 60 | home = os.path.join(expanduser("~"), ".model_checkpoints", "r3m") 61 | if modelid == "resnet50": 62 | foldername = "r3m_50" 63 | modelurl = "https://drive.google.com/uc?id=1Xu0ssuG0N1zjZS54wmWzJ7-nb0-7XzbA" 64 | configurl = "https://drive.google.com/uc?id=10jY2VxrrhfOdNPmsFdES568hjjIoBJx8" 65 | elif modelid == "resnet34": 66 | foldername = "r3m_34" 67 | modelurl = "https://drive.google.com/uc?id=15bXD3QRhspIRacOKyWPw5y2HpoWUCEnE" 68 | configurl = "https://drive.google.com/uc?id=1RY0NS-Tl4G7M1Ik_lOym0b5VIBxX9dqW" 69 | elif modelid == "resnet18": 70 | foldername = "r3m_18" 71 | modelurl = "https://drive.google.com/uc?id=1A1ic-p4KtYlKXdXHcV2QV0cUzI4kn0u-" 72 | configurl = 
"https://drive.google.com/uc?id=1nitbHQ-GRorxc7vMUiEHjHWP5N11Jvc6" 73 | else: 74 | raise NameError("Invalid Model ID") 75 | 76 | if not os.path.exists(os.path.join(home, foldername)): 77 | os.makedirs(os.path.join(home, foldername)) 78 | modelpath = os.path.join(home, foldername, "model.pt") 79 | configpath = os.path.join(home, foldername, "config.yaml") 80 | if not os.path.exists(modelpath): 81 | gdown.download(modelurl, modelpath, quiet=False) 82 | gdown.download(configurl, configpath, quiet=False) 83 | 84 | modelcfg = omegaconf.OmegaConf.load(configpath) 85 | cleancfg = cleanup_config(modelcfg) 86 | rep = hydra.utils.instantiate(cleancfg) 87 | rep = torch.nn.DataParallel(rep) 88 | r3m_state_dict = remove_language_head( 89 | torch.load(modelpath, map_location=torch.device(device))["r3m"] 90 | ) 91 | rep.load_state_dict(r3m_state_dict) 92 | rep = rep.module 93 | return rep 94 | 95 | def load_r3m_reproduce(modelid): 96 | home = os.path.join(expanduser("~"), ".r3m") 97 | if modelid == "r3m": 98 | foldername = "original_r3m" 99 | modelurl = "https://drive.google.com/uc?id=1jLb1yldIMfAcGVwYojSQmMpmRM7vqjp9" 100 | configurl = "https://drive.google.com/uc?id=1cu-Pb33qcfAieRIUptNlG1AQIMZlAI-q" 101 | elif modelid == "r3m_noaug": 102 | foldername = "original_r3m_noaug" 103 | modelurl = "https://drive.google.com/uc?id=1k_ZlVtvlktoYLtBcfD0aVFnrZcyCNS9D" 104 | configurl = "https://drive.google.com/uc?id=1hPmJwDiWPkd6GGez6ywSC7UOTIX7NgeS" 105 | elif modelid == "r3m_nol1": 106 | foldername = "original_r3m_nol1" 107 | modelurl = "https://drive.google.com/uc?id=1LpW3aBMdjoXsjYlkaDnvwx7q22myM_nB" 108 | configurl = "https://drive.google.com/uc?id=1rZUBrYJZvlF1ReFwRidZsH7-xe7csvab" 109 | elif modelid == "r3m_nolang": 110 | foldername = "original_r3m_nolang" 111 | modelurl = "https://drive.google.com/uc?id=1FXcniRei2JDaGMJJ_KlVxHaLy0Fs_caV" 112 | configurl = "https://drive.google.com/uc?id=192G4UkcNJO4EKN46ECujMcH0AQVhnyQe" 113 | else: 114 | raise NameError("Invalid Model ID") 115 | 116 | if not os.path.exists(os.path.join(home, foldername)): 117 | os.makedirs(os.path.join(home, foldername)) 118 | modelpath = os.path.join(home, foldername, "model.pt") 119 | configpath = os.path.join(home, foldername, "config.yaml") 120 | if not os.path.exists(modelpath): 121 | gdown.download(modelurl, modelpath, quiet=False) 122 | gdown.download(configurl, configpath, quiet=False) 123 | 124 | modelcfg = omegaconf.OmegaConf.load(configpath) 125 | cleancfg = cleanup_config(modelcfg) 126 | rep = hydra.utils.instantiate(cleancfg) 127 | rep = torch.nn.DataParallel(rep) 128 | r3m_state_dict = remove_language_head( 129 | torch.load(modelpath, map_location=torch.device(device))["r3m"] 130 | ) 131 | 132 | rep.load_state_dict(r3m_state_dict) 133 | rep = rep.module 134 | return rep 135 | -------------------------------------------------------------------------------- /datasets/point_maze_dset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import decord 3 | import numpy as np 4 | from pathlib import Path 5 | from einops import rearrange 6 | from decord import VideoReader 7 | from typing import Callable, Optional 8 | from .traj_dset import TrajDataset, get_train_val_sliced 9 | from typing import Optional, Callable, Any 10 | decord.bridge.set_bridge("torch") 11 | 12 | class PointMazeDataset(TrajDataset): 13 | def __init__( 14 | self, 15 | data_path: str = "data/point_maze", 16 | n_rollout: Optional[int] = None, 17 | transform: Optional[Callable] = None, 18 | normalize_action: bool 
= False, 19 | action_scale=1.0, 20 | ): 21 | self.data_path = Path(data_path) 22 | self.transform = transform 23 | self.normalize_action = normalize_action 24 | states = torch.load(self.data_path / "states.pth").float() 25 | self.states = states 26 | self.actions = torch.load(self.data_path / "actions.pth").float() 27 | self.actions = self.actions / action_scale # scaled back up in env 28 | self.seq_lengths = torch.load(self.data_path /'seq_lengths.pth') 29 | 30 | self.n_rollout = n_rollout 31 | if self.n_rollout: 32 | n = self.n_rollout 33 | else: 34 | n = len(self.states) 35 | 36 | self.states = self.states[:n] 37 | self.actions = self.actions[:n] 38 | self.seq_lengths = self.seq_lengths[:n] 39 | self.proprios = self.states.clone() 40 | print(f"Loaded {n} rollouts") 41 | 42 | self.action_dim = self.actions.shape[-1] 43 | self.state_dim = self.states.shape[-1] 44 | self.proprio_dim = self.proprios.shape[-1] 45 | 46 | if normalize_action: 47 | self.action_mean, self.action_std = self.get_data_mean_std(self.actions, self.seq_lengths) 48 | self.state_mean, self.state_std = self.get_data_mean_std(self.states, self.seq_lengths) 49 | self.proprio_mean, self.proprio_std = self.get_data_mean_std(self.proprios, self.seq_lengths) 50 | else: 51 | self.action_mean = torch.zeros(self.action_dim) 52 | self.action_std = torch.ones(self.action_dim) 53 | self.state_mean = torch.zeros(self.state_dim) 54 | self.state_std = torch.ones(self.state_dim) 55 | self.proprio_mean = torch.zeros(self.proprio_dim) 56 | self.proprio_std = torch.ones(self.proprio_dim) 57 | 58 | self.actions = (self.actions - self.action_mean) / self.action_std 59 | self.proprios = (self.proprios - self.proprio_mean) / self.proprio_std 60 | 61 | def get_data_mean_std(self, data, traj_lengths): 62 | all_data = [] 63 | for traj in range(len(traj_lengths)): 64 | traj_len = traj_lengths[traj] 65 | traj_data = data[traj, :traj_len] 66 | all_data.append(traj_data) 67 | all_data = torch.vstack(all_data) 68 | data_mean = torch.mean(all_data, dim=0) 69 | data_std = torch.std(all_data, dim=0) 70 | return data_mean, data_std 71 | 72 | def get_seq_length(self, idx): 73 | return self.seq_lengths[idx] 74 | 75 | def get_all_actions(self): 76 | result = [] 77 | for i in range(len(self.seq_lengths)): 78 | T = self.seq_lengths[i] 79 | result.append(self.actions[i, :T, :]) 80 | return torch.cat(result, dim=0) 81 | 82 | def get_frames(self, idx, frames): 83 | obs_dir = self.data_path / "obses" 84 | image = torch.load(obs_dir / f"episode_{idx:03d}.pth") 85 | proprio = self.proprios[idx, frames] 86 | act = self.actions[idx, frames] 87 | state = self.states[idx, frames] 88 | 89 | image = image[frames] # THWC 90 | image = image / 255.0 91 | image = rearrange(image, "T H W C -> T C H W") 92 | if self.transform: 93 | image = self.transform(image) 94 | obs = { 95 | "visual": image, 96 | "proprio": proprio 97 | } 98 | return obs, act, state, {} # env_info 99 | 100 | def __getitem__(self, idx): 101 | return self.get_frames(idx, range(self.get_seq_length(idx))) 102 | 103 | def __len__(self): 104 | return len(self.seq_lengths) 105 | 106 | def preprocess_imgs(self, imgs): 107 | if isinstance(imgs, np.ndarray): 108 | raise NotImplementedError 109 | elif isinstance(imgs, torch.Tensor): 110 | return rearrange(imgs, "b h w c -> b c h w") / 255.0 111 | 112 | def load_point_maze_slice_train_val( 113 | transform, 114 | n_rollout=50, 115 | data_path='data/pusht_dataset', 116 | normalize_action=False, 117 | split_ratio=0.8, 118 | num_hist=0, 119 | num_pred=0, 120 | frameskip=0, 
121 | ): 122 | dset = PointMazeDataset( 123 | n_rollout=n_rollout, 124 | transform=transform, 125 | data_path=data_path, 126 | normalize_action=normalize_action, 127 | ) 128 | dset_train, dset_val, train_slices, val_slices = get_train_val_sliced( 129 | traj_dataset=dset, 130 | train_fraction=split_ratio, 131 | num_frames=num_hist + num_pred, 132 | frameskip=frameskip 133 | ) 134 | 135 | datasets = {} 136 | datasets['train'] = train_slices 137 | datasets['valid'] = val_slices 138 | traj_dset = {} 139 | traj_dset['train'] = dset_train 140 | traj_dset['valid'] = dset_val 141 | return datasets, traj_dset 142 | -------------------------------------------------------------------------------- /planning/mpc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import copy 4 | import numpy as np 5 | from einops import rearrange, repeat 6 | from utils import slice_trajdict_with_t 7 | from .base_planner import BasePlanner 8 | 9 | 10 | class MPCPlanner(BasePlanner): 11 | """ 12 | an online planner so feedback from env is allowed 13 | """ 14 | 15 | def __init__( 16 | self, 17 | max_iter, 18 | n_taken_actions, 19 | sub_planner, 20 | wm, 21 | env, # for online exec 22 | action_dim, 23 | objective_fn, 24 | preprocessor, 25 | evaluator, 26 | wandb_run, 27 | logging_prefix="mpc", 28 | log_filename="logs.json", 29 | **kwargs, 30 | ): 31 | super().__init__( 32 | wm, 33 | action_dim, 34 | objective_fn, 35 | preprocessor, 36 | evaluator, 37 | wandb_run, 38 | log_filename, 39 | ) 40 | self.env = env 41 | self.max_iter = np.inf if max_iter is None else max_iter 42 | self.n_taken_actions = n_taken_actions 43 | self.logging_prefix = logging_prefix 44 | sub_planner["_target_"] = sub_planner["target"] 45 | self.sub_planner = hydra.utils.instantiate( 46 | sub_planner, 47 | wm=self.wm, 48 | action_dim=self.action_dim, 49 | objective_fn=self.objective_fn, 50 | preprocessor=self.preprocessor, 51 | evaluator=self.evaluator, # evaluator is shared for mpc and sub_planner 52 | wandb_run=self.wandb_run, 53 | log_filename=None, 54 | ) 55 | self.is_success = None 56 | self.action_len = None # keep track of the step each traj reaches success 57 | self.iter = 0 58 | self.planned_actions = [] 59 | 60 | def _apply_success_mask(self, actions): 61 | device = actions.device 62 | mask = torch.tensor(self.is_success).bool() 63 | actions[mask] = 0 64 | masked_actions = rearrange( 65 | actions[mask], "... (f d) -> ... f d", f=self.evaluator.frameskip 66 | ) 67 | masked_actions = self.preprocessor.normalize_actions(masked_actions.cpu()) 68 | masked_actions = rearrange(masked_actions, "... f d -> ... 
(f d)") 69 | actions[mask] = masked_actions.to(device) 70 | return actions 71 | 72 | def plan(self, obs_0, obs_g, actions=None): 73 | """ 74 | actions is NOT used 75 | Returns: 76 | actions: (B, T, action_dim) torch.Tensor 77 | """ 78 | n_evals = obs_0["visual"].shape[0] 79 | self.is_success = np.zeros(n_evals, dtype=bool) 80 | self.action_len = np.full(n_evals, np.inf) 81 | init_obs_0, init_state_0 = self.evaluator.get_init_cond() 82 | 83 | cur_obs_0 = obs_0 84 | memo_actions = None 85 | while not np.all(self.is_success) and self.iter < self.max_iter: 86 | self.sub_planner.logging_prefix = f"plan_{self.iter}" 87 | actions, _ = self.sub_planner.plan( 88 | obs_0=cur_obs_0, 89 | obs_g=obs_g, 90 | actions=memo_actions, 91 | ) # (b, t, act_dim) 92 | taken_actions = actions.detach()[:, : self.n_taken_actions] 93 | self._apply_success_mask(taken_actions) 94 | memo_actions = actions.detach()[:, self.n_taken_actions :] 95 | self.planned_actions.append(taken_actions) 96 | 97 | print(f"MPC iter {self.iter} Eval ------- ") 98 | action_so_far = torch.cat(self.planned_actions, dim=1) 99 | self.evaluator.assign_init_cond( 100 | obs_0=init_obs_0, 101 | state_0=init_state_0, 102 | ) 103 | logs, successes, e_obses, e_states = self.evaluator.eval_actions( 104 | action_so_far, 105 | self.action_len, 106 | filename=f"plan{self.iter}", 107 | save_video=True, 108 | ) 109 | new_successes = successes & ~self.is_success # Identify new successes 110 | self.is_success = ( 111 | self.is_success | successes 112 | ) # Update overall success status 113 | self.action_len[new_successes] = ( 114 | (self.iter + 1) * self.n_taken_actions 115 | ) # Update only for the newly successful trajectories 116 | 117 | print("self.is_success: ", self.is_success) 118 | logs = {f"{self.logging_prefix}/{k}": v for k, v in logs.items()} 119 | logs.update({"step": self.iter + 1}) 120 | self.wandb_run.log(logs) 121 | self.dump_logs(logs) 122 | 123 | # update evaluator's init conditions with new env feedback 124 | e_final_obs = slice_trajdict_with_t(e_obses, start_idx=-1) 125 | cur_obs_0 = e_final_obs 126 | e_final_state = e_states[:, -1] 127 | self.evaluator.assign_init_cond( 128 | obs_0=e_final_obs, 129 | state_0=e_final_state, 130 | ) 131 | self.iter += 1 132 | self.sub_planner.logging_prefix = f"plan_{self.iter}" 133 | 134 | planned_actions = torch.cat(self.planned_actions, dim=1) 135 | self.evaluator.assign_init_cond( 136 | obs_0=init_obs_0, 137 | state_0=init_state_0, 138 | ) 139 | 140 | return planned_actions, self.action_len 141 | -------------------------------------------------------------------------------- /datasets/traj_dset.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | from typing import Optional, Sequence, List 6 | from torch.utils.data import Dataset, Subset 7 | from torch import default_generator, randperm 8 | from einops import rearrange 9 | 10 | # https://github.com/JaidedAI/EasyOCR/issues/1243 11 | def _accumulate(iterable, fn=lambda x, y: x + y): 12 | "Return running totals" 13 | # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 14 | # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 15 | it = iter(iterable) 16 | try: 17 | total = next(it) 18 | except StopIteration: 19 | return 20 | yield total 21 | for element in it: 22 | total = fn(total, element) 23 | yield total 24 | 25 | class TrajDataset(Dataset, abc.ABC): 26 | @abc.abstractmethod 27 | def get_seq_length(self, idx): 28 | """ 29 | 
Returns the length of the idx-th trajectory. 30 | """ 31 | raise NotImplementedError 32 | 33 | class TrajSubset(TrajDataset, Subset): 34 | """ 35 | Subset of a trajectory dataset at specified indices. 36 | 37 | Args: 38 | dataset (TrajectoryDataset): The whole Dataset 39 | indices (sequence): Indices in the whole set selected for subset 40 | """ 41 | def __init__(self, dataset: TrajDataset, indices: Sequence[int]): 42 | Subset.__init__(self, dataset, indices) 43 | 44 | def get_seq_length(self, idx): 45 | return self.dataset.get_seq_length(self.indices[idx]) 46 | 47 | def __getattr__(self, name): 48 | if hasattr(self.dataset, name): 49 | return getattr(self.dataset, name) 50 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") 51 | 52 | 53 | class TrajSlicerDataset(TrajDataset): 54 | def __init__( 55 | self, 56 | dataset: TrajDataset, 57 | num_frames: int, 58 | frameskip: int = 1, 59 | process_actions: str = "concat", 60 | ): 61 | self.dataset = dataset 62 | self.num_frames = num_frames 63 | self.frameskip = frameskip 64 | self.slices = [] 65 | for i in range(len(self.dataset)): 66 | T = self.dataset.get_seq_length(i) 67 | if T - num_frames < 0: 68 | print(f"Ignored short sequence #{i}: len={T}, num_frames={num_frames}") 69 | else: 70 | self.slices += [ 71 | (i, start, start + num_frames * self.frameskip) 72 | for start in range(T - num_frames * frameskip + 1) 73 | ] # slice indices follow convention [start, end) 74 | # randomly permute the slices 75 | self.slices = np.random.permutation(self.slices) 76 | 77 | self.proprio_dim = self.dataset.proprio_dim 78 | if process_actions == "concat": 79 | self.action_dim = self.dataset.action_dim * self.frameskip 80 | else: 81 | self.action_dim = self.dataset.action_dim 82 | 83 | self.state_dim = self.dataset.state_dim 84 | 85 | 86 | def get_seq_length(self, idx: int) -> int: 87 | return self.num_frames 88 | 89 | def __len__(self): 90 | return len(self.slices) 91 | 92 | def __getitem__(self, idx): 93 | i, start, end = self.slices[idx] 94 | obs, act, state, _ = self.dataset[i] 95 | for k, v in obs.items(): 96 | obs[k] = v[start:end:self.frameskip] 97 | state = state[start:end:self.frameskip] 98 | act = act[start:end] 99 | act = rearrange(act, "(n f) d -> n (f d)", n=self.num_frames) # concat actions 100 | return tuple([obs, act, state]) 101 | 102 | 103 | def random_split_traj( 104 | dataset: TrajDataset, 105 | lengths: Sequence[int], 106 | generator: Optional[torch.Generator] = default_generator, 107 | ) -> List[TrajSubset]: 108 | if sum(lengths) != len(dataset): # type: ignore[arg-type] 109 | raise ValueError( 110 | "Sum of input lengths does not equal the length of the input dataset!" 
111 | ) 112 | 113 | indices = randperm(sum(lengths), generator=generator).tolist() 114 | print( 115 | [ 116 | indices[offset - length : offset] 117 | for offset, length in zip(_accumulate(lengths), lengths) 118 | ] 119 | ) 120 | return [ 121 | TrajSubset(dataset, indices[offset - length : offset]) 122 | for offset, length in zip(_accumulate(lengths), lengths) 123 | ] 124 | 125 | 126 | def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42): 127 | dataset_length = len(dataset) 128 | lengths = [ 129 | int(train_fraction * dataset_length), 130 | dataset_length - int(train_fraction * dataset_length), 131 | ] 132 | train_set, val_set = random_split_traj( 133 | dataset, lengths, generator=torch.Generator().manual_seed(random_seed) 134 | ) 135 | return train_set, val_set 136 | 137 | 138 | def get_train_val_sliced( 139 | traj_dataset: TrajDataset, 140 | train_fraction: float = 0.9, 141 | random_seed: int = 42, 142 | num_frames: int = 10, 143 | frameskip: int = 1, 144 | ): 145 | train, val = split_traj_datasets( 146 | traj_dataset, 147 | train_fraction=train_fraction, 148 | random_seed=random_seed, 149 | ) 150 | train_slices = TrajSlicerDataset(train, num_frames, frameskip) 151 | val_slices = TrajSlicerDataset(val, num_frames, frameskip) 152 | return train, val, train_slices, val_slices --------------------------------------------------------------------------------
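
Usage note (editor's addition, not part of the repository): a minimal sketch of how the trajectory dataset and the slicing utilities above are meant to be combined, mirroring what load_point_maze_slice_train_val does. The data path matches PointMazeDataset's default and the expected files (states.pth, actions.pth, seq_lengths.pth, obses/episode_XXX.pth) come from its loader; the rollout count, window length, frameskip, batch size, and Resize transform below are illustrative assumptions, not values taken from the repository configs.

import torch
from torchvision import transforms
from datasets.point_maze_dset import PointMazeDataset
from datasets.traj_dset import get_train_val_sliced

# PointMazeDataset reads states.pth, actions.pth, seq_lengths.pth and
# obses/episode_XXX.pth from data_path ("data/point_maze" is its default).
dset = PointMazeDataset(
    data_path="data/point_maze",
    n_rollout=50,                             # assumed; None would load all rollouts
    transform=transforms.Resize((224, 224)),  # assumed; matches the 224x224 frames used elsewhere
    normalize_action=True,
)

# Split whole trajectories into train/val, then cut each trajectory into windows
# of num_frames observations taken every `frameskip` steps; the frameskip actions
# between consecutive kept frames are concatenated into a single action vector.
train_traj, val_traj, train_slices, val_slices = get_train_val_sliced(
    traj_dataset=dset,
    train_fraction=0.9,
    num_frames=4,   # assumed, e.g. num_hist + num_pred
    frameskip=5,    # assumed
)

loader = torch.utils.data.DataLoader(train_slices, batch_size=32, shuffle=True)
obs, act, state = next(iter(loader))
# obs["visual"]: (32, 4, 3, 224, 224), act: (32, 4, frameskip * action_dim), state: (32, 4, state_dim)
print(obs["visual"].shape, act.shape, state.shape)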