├── models
│   ├── __init__.py
│   ├── encoder
│   │   ├── __init__.py
│   │   ├── r3m
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── models_language.py
│   │   │   │   └── models_r3m.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   └── data_loaders.py
│   │   │   ├── cfgs
│   │   │   │   ├── hydra
│   │   │   │   │   ├── output
│   │   │   │   │   │   └── local.yaml
│   │   │   │   │   └── launcher
│   │   │   │   │       └── local.yaml
│   │   │   │   └── config_rep.yaml
│   │   │   ├── r3m_base.yaml
│   │   │   └── __init__.py
│   │   └── resnet.py
│   ├── dino.py
│   ├── dummy.py
│   ├── proprio.py
│   ├── decoder
│   │   └── transposed_conv.py
│   └── vit.py
├── datasets
│   ├── __init__.py
│   ├── img_transforms.py
│   ├── point_maze_dset.py
│   └── traj_dset.py
├── planning
│   ├── __init__.py
│   ├── base_planner.py
│   ├── objectives.py
│   ├── gd.py
│   ├── cem.py
│   └── mpc.py
├── env
│   ├── wall
│   │   ├── envs
│   │   │   └── __init__.py
│   │   ├── data
│   │   │   └── __init__.py
│   │   └── wall_env_wrapper.py
│   ├── pointmaze
│   │   ├── gridcraft
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── wrappers.py
│   │   │   └── grid_spec.py
│   │   ├── waypoint_controller.py
│   │   ├── point_maze_wrapper.py
│   │   ├── q_iteration.py
│   │   └── dynamic_mjc.py
│   ├── pusht
│   │   ├── __init__.py
│   │   └── pusht_wrapper.py
│   ├── deformable_env
│   │   └── src
│   │       └── sim
│   │           ├── assets
│   │           │   └── xarm
│   │           │       ├── xarm_gripper
│   │           │       │   └── meshes
│   │           │       │       ├── base_link.STL
│   │           │       │       ├── left_finger.STL
│   │           │       │       ├── right_finger.STL
│   │           │       │       ├── left_inner_knuckle.STL
│   │           │       │       ├── left_outer_knuckle.STL
│   │           │       │       ├── right_inner_knuckle.STL
│   │           │       │       ├── right_outer_knuckle.STL
│   │           │       │       ├── stick_4cm.obj
│   │           │       │       ├── stick_3cm.obj
│   │           │       │       ├── stick_6cm.obj
│   │           │       │       ├── stick_1cm.obj
│   │           │       │       └── stick_2cm.obj
│   │           │       ├── xarm_description
│   │           │       │   └── meshes
│   │           │       │       └── xarm6
│   │           │       │           ├── visual
│   │           │       │           │   ├── base.stl
│   │           │       │           │   ├── link1.stl
│   │           │       │           │   ├── link2.stl
│   │           │       │           │   ├── link3.stl
│   │           │       │           │   ├── link4.stl
│   │           │       │           │   ├── link5.stl
│   │           │       │           │   └── link6.stl
│   │           │       │           └── collision
│   │           │       │               ├── base.mtl
│   │           │       │               ├── link1.mtl
│   │           │       │               ├── link2.mtl
│   │           │       │               ├── link3.mtl
│   │           │       │               ├── link4.mtl
│   │           │       │               ├── link5.mtl
│   │           │       │               ├── link6.mtl
│   │           │       │               ├── link2_vhacd2.mtl
│   │           │       │               └── link6_vhacd.obj
│   │           │       ├── LICENSE
│   │           │       ├── link6.urdf
│   │           │       ├── base.urdf
│   │           │       ├── link1.urdf
│   │           │       ├── link6_com.urdf
│   │           │       ├── link3.urdf
│   │           │       ├── base_com.urdf
│   │           │       ├── link5.urdf
│   │           │       ├── link4.urdf
│   │           │       ├── link1_com.urdf
│   │           │       ├── link2.urdf
│   │           │       ├── link3_com.urdf
│   │           │       ├── link5_com.urdf
│   │           │       ├── link4_com.urdf
│   │           │       └── link2_com.urdf
│   │           ├── sim_env
│   │           │   ├── flex_scene.py
│   │           │   ├── cameras.py
│   │           │   └── robot_env.py
│   │           └── data_gen
│   │               └── data.py
│   ├── __init__.py
│   └── serial_vector_env.py
├── .gitignore
├── conf
│   ├── encoder
│   │   ├── dummy.yaml
│   │   ├── r3m.yaml
│   │   ├── resnet.yaml
│   │   ├── dino.yaml
│   │   └── dino_cls.yaml
│   ├── action_encoder
│   │   ├── dummy.yaml
│   │   └── proprio.yaml
│   ├── proprio_encoder
│   │   ├── dummy.yaml
│   │   └── proprio.yaml
│   ├── decoder
│   │   ├── vqvae.yaml
│   │   └── transposed_conv.yaml
│   ├── predictor
│   │   └── vit.yaml
│   ├── planner
│   │   ├── cem.yaml
│   │   ├── gd.yaml
│   │   ├── mpc_cem.yaml
│   │   └── mpc_gd.yaml
│   ├── env
│   │   ├── point_maze.yaml
│   │   ├── wall.yaml
│   │   ├── pusht.yaml
│   │   ├── deformable_env.yaml
│   │   ├── rope.yaml
│   │   └── granular.yaml
│   ├── plan.yaml
│   ├── plan_wall.yaml
│   ├── plan_point_maze.yaml
│   ├── plan_pusht.yaml
│   └── train.yaml
├── assets
│   └── intro.png
├── distributed_fn
│   ├── __init__.py
│   ├── launch.py
│   └── distributed.py
├── custom_resolvers.py
├── metrics
│   ├── lpipsPyTorch
│   │   ├── __init__.py
│   │   └── modules
│   │       ├── utils.py
│   │       ├── lpips.py
│   │       └── networks.py
│   └── image_metrics.py
├── LICENSE
├── preprocessor.py
└── utils.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/planning/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/env/wall/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/encoder/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/env/pointmaze/gridcraft/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/encoder/r3m/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/encoder/r3m/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/env/pusht/__init__.py:
--------------------------------------------------------------------------------
1 | from .pusht_env import PushTEnv
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | outputs
4 | plan_outputs
5 |
--------------------------------------------------------------------------------
/conf/encoder/dummy.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.dummy.DummyModel
2 | emb_dim: 7
--------------------------------------------------------------------------------
/conf/action_encoder/dummy.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.dummy.DummyRepeatActionEncoder
2 |
--------------------------------------------------------------------------------
/conf/encoder/r3m.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoder.r3m.load_r3m
2 | modelid: resnet18
--------------------------------------------------------------------------------
/conf/proprio_encoder/dummy.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.dummy.DummyRepeatActionEncoder
2 |
--------------------------------------------------------------------------------
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/assets/intro.png
--------------------------------------------------------------------------------
/conf/encoder/resnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.encoder.resnet.resnet18
2 | pretrained: True
3 | unit_norm: False
--------------------------------------------------------------------------------
/conf/encoder/dino.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.dino.DinoV2Encoder
2 | name: "dinov2_vits14"
3 | feature_key: "x_norm_patchtokens"
4 |
--------------------------------------------------------------------------------
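The conf/*.yaml files in this dump are Hydra object configs: `_target_` names the class and the remaining keys are constructor arguments. A minimal sketch (not part of the repo; assumes it is run from the repo root so the `models` package is importable) of building this encoder directly from the file above:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("conf/encoder/dino.yaml")
    encoder = instantiate(cfg)   # -> models.dino.DinoV2Encoder(name=..., feature_key=...)
    print(type(encoder).__name__, encoder.emb_dim)
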
/conf/encoder/dino_cls.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.dino.DinoV2Encoder
2 | name: "dinov2_vits14"
3 | feature_key: "x_norm_clstoken"
4 |
--------------------------------------------------------------------------------
/conf/action_encoder/proprio.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.proprio.ProprioceptiveEmbedding
2 | num_frames: 1
3 | tubelet_size: 1
4 | use_3d_pos: False
--------------------------------------------------------------------------------
/conf/proprio_encoder/proprio.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.proprio.ProprioceptiveEmbedding
2 | num_frames: 1
3 | tubelet_size: 1
4 | use_3d_pos: False
--------------------------------------------------------------------------------
/conf/decoder/vqvae.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.vqvae.VQVAE
2 | channel: 384
3 | n_embed: 2048
4 | n_res_block: 4
5 | n_res_channel: 128
6 | quantize: False
--------------------------------------------------------------------------------
/conf/predictor/vit.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.vit.ViTPredictor
2 | depth: 6
3 | heads: 16
4 | mlp_dim: 2048
5 | dropout: 0.1
6 | emb_dropout: 0
7 | pool: 'mean'
--------------------------------------------------------------------------------
/conf/planner/cem.yaml:
--------------------------------------------------------------------------------
1 | _target_: planning.cem.CEMPlanner
2 | horizon: 5
3 | topk: 30
4 | num_samples: 300
5 | var_scale: 1
6 | opt_steps: 30
7 | eval_every: 1
8 |
9 | name: cem
--------------------------------------------------------------------------------
/conf/decoder/transposed_conv.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.decoder.transposed_conv.TransposedConvDecoder
2 | observation_shape: [3, 224, 224]
3 | depth: 64
4 | kernel_size: 5
5 | stride: 3
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/base_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/base_link.STL
--------------------------------------------------------------------------------
/conf/planner/gd.yaml:
--------------------------------------------------------------------------------
1 | _target_: planning.gd.GDPlanner
2 | horizon: 5
3 | action_noise: 0.003
4 | sample_type: 'randn' # 'zero' or 'randn'
5 | lr: 1
6 | opt_steps: 1000
7 | eval_every: 10
8 |
9 | name: gd
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_finger.STL
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_finger.STL
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_inner_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_inner_knuckle.STL
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_outer_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/left_outer_knuckle.STL
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_inner_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_inner_knuckle.STL
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_outer_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/right_outer_knuckle.STL
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/base.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link1.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link2.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link3.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link4.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link5.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaoyuezhou/dino_wm/HEAD/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/visual/link6.stl
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/base.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link1.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link2.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link3.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link4.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link5.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link6.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl None
5 | Ns 500
6 | Ka 0.8 0.8 0.8
7 | Kd 0.8 0.8 0.8
8 | Ks 0.8 0.8 0.8
9 | d 1
10 | illum 2
11 |
--------------------------------------------------------------------------------
/conf/planner/mpc_cem.yaml:
--------------------------------------------------------------------------------
1 | _target_: planning.mpc.MPCPlanner
2 | max_iter: null # unlimited if null
3 | n_taken_actions: 5
4 | sub_planner:
5 | target: planning.cem.CEMPlanner
6 | horizon: 5
7 | topk: 30
8 | num_samples: 300
9 | var_scale: 1
10 | opt_steps: 30
11 | eval_every: 1
12 |
13 | name: mpc_cem
--------------------------------------------------------------------------------
/distributed_fn/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed import (
2 | get_rank,
3 | get_local_rank,
4 | is_primary,
5 | synchronize,
6 | get_world_size,
7 | all_reduce,
8 | all_gather,
9 | reduce_dict,
10 | data_sampler,
11 | LOCAL_PROCESS_GROUP,
12 | )
13 | from .launch import launch
14 |
--------------------------------------------------------------------------------
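A hypothetical usage sketch, assuming the conventional no-argument signatures for these helpers (their definitions live in distributed_fn/distributed.py, which is not excerpted here):

    from distributed_fn import get_rank, get_world_size, is_primary, synchronize

    def log_once(msg):
        if is_primary():        # only the rank-0 process prints
            print(f"[{get_rank()}/{get_world_size()}] {msg}")

    synchronize()               # barrier; typically a no-op when world_size == 1
    log_once("setup complete")
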
/conf/planner/mpc_gd.yaml:
--------------------------------------------------------------------------------
1 | _target_: planning.mpc.MPCPlanner
2 | max_iter: null # unlimited if null
3 | n_taken_actions: 1
4 | sub_planner:
5 | target: planning.gd.GDPlanner
6 | horizon: 5
7 | action_noise: 0.003
8 | sample_type: 'randn' # 'zero' or 'randn'
9 | lr: 1
10 | opt_steps: 1000
11 | eval_every: 10
12 |
13 | name: mpc_gd
--------------------------------------------------------------------------------
/datasets/img_transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 |
3 | def default_transform(img_size=224):
4 | return transforms.Compose(
5 | [
6 | transforms.Resize(img_size),
7 | transforms.CenterCrop(img_size),
8 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
9 | ]
10 | )
--------------------------------------------------------------------------------
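Note that the Compose above has no ToTensor step, so it operates on images that are already float tensors scaled to [0, 1] (which is what Preprocessor.preprocess_obs_visual, shown later in this dump, produces). A hypothetical shape check:

    import torch
    from datasets.img_transforms import default_transform

    transform = default_transform(img_size=224)
    img = torch.rand(3, 256, 256)       # fake RGB image, values in [0, 1]
    out = transform(img)                # resize + center-crop + map to roughly [-1, 1]
    print(out.shape)                    # torch.Size([3, 224, 224])
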
/models/encoder/r3m/cfgs/hydra/output/local.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | hydra:
3 | run:
4 | dir: ./r3moutput/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
5 | subdir: ${hydra.job.num}_${hydra.job.override_dirname}
6 | sweep:
7 | dir: ./r3moutput/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8 | subdir: ${hydra.job.num}_${hydra.job.override_dirname}
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link2_vhacd2.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'None'
2 | # Material Count: 1
3 |
4 | newmtl Default_OBJ
5 | Ns 225.000000
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.800000 0.800000 0.800000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.450000
11 | d 1.000000
12 | illum 2
13 |
--------------------------------------------------------------------------------
/env/wall/data/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto
2 | from .single import DotDataset, DotDatasetConfig, Sample
3 | from .wall import WallDataset, WallDatasetConfig
4 | from .configs import ConfigBase
5 |
6 | class DatasetType(Enum):
7 | Single = auto()
8 | Multiple = auto()
9 | Wall = auto()
10 | WallExpert = auto()
11 | WallEigenfunc = auto()
12 |
--------------------------------------------------------------------------------
/models/encoder/r3m/cfgs/hydra/launcher/local.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | hydra:
3 | launcher:
4 | cpus_per_task: 20
5 | gpus_per_node: 0
6 | tasks_per_node: 1
7 | timeout_min: 600
8 | mem_gb: 64
9 | name: ${hydra.job.name}
10 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher
11 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j
12 |
--------------------------------------------------------------------------------
/conf/env/point_maze.yaml:
--------------------------------------------------------------------------------
1 | name: point_maze
2 | args: []
3 | kwargs: {}
4 |
5 | dataset:
6 | _target_: "datasets.point_maze_dset.load_point_maze_slice_train_val"
7 | n_rollout: null
8 | normalize_action: ${normalize_action}
9 | data_path: ${oc.env:DATASET_DIR}/point_maze
10 | split_ratio: 0.9
11 | transform:
12 | _target_: "datasets.img_transforms.default_transform"
13 | img_size: ${img_size}
14 |
15 | decoder_path: null
16 | num_workers: 16
--------------------------------------------------------------------------------
/conf/env/wall.yaml:
--------------------------------------------------------------------------------
1 | name: wall
2 | args: []
3 | kwargs: {}
4 |
5 | dataset:
6 | _target_: "datasets.wall_dset.load_wall_slice_train_val"
7 | n_rollout: null
8 | normalize_action: ${normalize_action}
9 | data_path: ${oc.env:DATASET_DIR}/wall_single
10 | split_ratio: 0.9
11 | split_mode: "random"
12 | transform:
13 | _target_: "datasets.img_transforms.default_transform"
14 | img_size: ${img_size}
15 |
16 | decoder_path: null
17 | num_workers: 16
--------------------------------------------------------------------------------
/custom_resolvers.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from omegaconf import OmegaConf
3 |
4 | @hydra.main(config_path=None)
5 | def register_resolvers(cfg):
6 | pass
7 |
8 | # Define the resolver function
9 | def replace_slash(value: str) -> str:
10 | return value.replace('/', '_')
11 |
12 | # Register the resolver with Hydra
13 | OmegaConf.register_new_resolver("replace_slash", replace_slash)
14 |
15 | if __name__ == "__main__":
16 | register_resolvers()
17 |
18 |
--------------------------------------------------------------------------------
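Importing this module registers the `replace_slash` resolver that the `${replace_slash:${model_name}}` interpolations in the plan configs below depend on. An illustrative sketch (the config values here are made up):

    from omegaconf import OmegaConf
    import custom_resolvers  # noqa: F401 -- module import runs register_new_resolver

    cfg = OmegaConf.create(
        {"model_name": "2024-01-01/pusht_run", "run_dir": "${replace_slash:${model_name}}"}
    )
    print(cfg.run_dir)  # 2024-01-01_pusht_run
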
/conf/env/pusht.yaml:
--------------------------------------------------------------------------------
1 | name: pusht
2 | args: []
3 | kwargs:
4 | with_velocity: true
5 | with_target: true
6 |
7 | dataset:
8 | _target_: "datasets.pusht_dset.load_pusht_slice_train_val"
9 | with_velocity: true
10 | n_rollout: null
11 | normalize_action: ${normalize_action}
12 | data_path: ${oc.env:DATASET_DIR}/pusht_noise
13 | split_ratio: 0.9
14 | transform:
15 | _target_: "datasets.img_transforms.default_transform"
16 | img_size: ${img_size}
17 |
18 | decoder_path: null
19 | num_workers: 16
--------------------------------------------------------------------------------
/conf/env/deformable_env.yaml:
--------------------------------------------------------------------------------
1 | name: deformable_env
2 | args: []
3 | kwargs:
4 | object_name: "granular"
5 |
6 | load_dir: ""
7 |
8 | dataset:
9 | _target_: "datasets.deformable_env_dset.load_deformable_dset_slice_train_val"
10 | n_rollout: null
11 | normalize_action: ${normalize_action}
12 | data_path: ${oc.env:DATASET_DIR}/deformable
13 | object_name: "granular"
14 | split_ratio: 0.9
15 | transform:
16 | _target_: "datasets.img_transforms.default_transform"
17 | img_size: ${img_size}
18 |
19 | decoder_path: null
20 | num_workers: 16
--------------------------------------------------------------------------------
/metrics/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
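A hypothetical call sketch: the inputs are image batches roughly in [-1, 1] (the range default_transform produces), and the first call downloads the pretrained backbone and linear-layer weights:

    import torch
    from metrics.lpipsPyTorch import lpips

    x = torch.rand(4, 3, 224, 224) * 2 - 1
    y = torch.rand(4, 3, 224, 224) * 2 - 1
    score = lpips(x, y, net_type='alex')   # scalar tensor: mean LPIPS over the batch
    print(score.item())
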
/models/encoder/r3m/cfgs/config_rep.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/launcher: local
4 | - override hydra/output: local
5 |
6 |
7 | # snapshot
8 | save_snapshot: false
9 | load_snap: ""
10 | # replay buffer
11 | num_workers: 10
12 | batch_size: 32 #256
13 | train_steps: 2000000
14 | eval_freq: 20000
15 | # misc
16 | seed: 1
17 | device: cuda
18 | # experiment
19 | experiment: train_r3m
20 | # agent
21 | lr: 1e-4
22 | # data
23 | alpha: 0.2
24 | dataset: "ego4d"
25 | wandbproject:
26 | wandbuser:
27 | doaug: "none"
28 | datapath:
29 |
30 | agent:
31 | _target_: r3m.R3M
32 | device: ${device}
33 | lr: ${lr}
34 | hidden_dim: 1024
35 | size: 34
36 | l2weight: 0.00001
37 | l1weight: 0.00001
38 | tcnweight: 1.0
39 | langweight: 0.0
40 | l2dist: true
41 | bs: ${batch_size}
42 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_4cm.obj:
--------------------------------------------------------------------------------
1 | # File units = meters
2 | mtllib Stick.mtl
3 | g Part 1
4 | v 0.0217921 0.0189459 0.1
5 | v 0.0217921 0.0189459 0
6 | v -0.0182079 0.0189459 0
7 | v -0.0182079 0.0189459 0.1
8 | v -0.0182079 -0.0210542 0
9 | v -0.0182079 -0.0210542 0.1
10 | v 0.0217921 -0.0210542 0
11 | v 0.0217921 -0.0210542 0.1
12 | vn 0 1 0
13 | vn -1 0 0
14 | vn 0 -1 0
15 | vn 1 0 0
16 | vn 0 0 1
17 | vn 0 0 -1
18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000
19 | o mesh0
20 | f 1//1 2//1 3//1
21 | f 3//1 4//1 1//1
22 | o mesh1
23 | f 4//2 3//2 5//2
24 | f 5//2 6//2 4//2
25 | o mesh2
26 | f 6//3 5//3 7//3
27 | f 7//3 8//3 6//3
28 | o mesh3
29 | f 8//4 7//4 2//4
30 | f 2//4 1//4 8//4
31 | o mesh4
32 | f 8//5 1//5 4//5
33 | f 4//5 6//5 8//5
34 | o mesh5
35 | f 5//6 3//6 2//6
36 | f 2//6 7//6 5//6
37 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_3cm.obj:
--------------------------------------------------------------------------------
1 | # File units = meters
2 | mtllib stick_3cm.mtl
3 | g Part 1
4 | v 0.0167921 0.0139459 0.1
5 | v 0.0167921 0.0139459 0
6 | v -0.0132079 0.0139459 0
7 | v -0.0132079 0.0139459 0.1
8 | v -0.0132079 -0.0160542 0
9 | v -0.0132079 -0.0160542 0.1
10 | v 0.0167921 -0.0160542 0
11 | v 0.0167921 -0.0160542 0.1
12 | vn 0 1 0
13 | vn -1 0 0
14 | vn 0 -1 0
15 | vn 1 0 0
16 | vn 0 0 1
17 | vn 0 0 -1
18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000
19 | o mesh0
20 | f 1//1 2//1 3//1
21 | f 3//1 4//1 1//1
22 | o mesh1
23 | f 4//2 3//2 5//2
24 | f 5//2 6//2 4//2
25 | o mesh2
26 | f 6//3 5//3 7//3
27 | f 7//3 8//3 6//3
28 | o mesh3
29 | f 8//4 7//4 2//4
30 | f 2//4 1//4 8//4
31 | o mesh4
32 | f 8//5 1//5 4//5
33 | f 4//5 6//5 8//5
34 | o mesh5
35 | f 5//6 3//6 2//6
36 | f 2//6 7//6 5//6
37 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_6cm.obj:
--------------------------------------------------------------------------------
1 | # File units = meters
2 | mtllib stick_6cm.mtl
3 | g Part 1
4 | v 0.0317921 0.0289459 0.1
5 | v 0.0317921 0.0289459 0
6 | v -0.0282079 0.0289459 0
7 | v -0.0282079 0.0289459 0.1
8 | v -0.0282079 -0.0310542 0
9 | v -0.0282079 -0.0310542 0.1
10 | v 0.0317921 -0.0310542 0
11 | v 0.0317921 -0.0310542 0.1
12 | vn 0 1 0
13 | vn -1 0 0
14 | vn 0 -1 0
15 | vn 1 0 0
16 | vn 0 0 1
17 | vn 0 0 -1
18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000
19 | o mesh0
20 | f 1//1 2//1 3//1
21 | f 3//1 4//1 1//1
22 | o mesh1
23 | f 4//2 3//2 5//2
24 | f 5//2 6//2 4//2
25 | o mesh2
26 | f 6//3 5//3 7//3
27 | f 7//3 8//3 6//3
28 | o mesh3
29 | f 8//4 7//4 2//4
30 | f 2//4 1//4 8//4
31 | o mesh4
32 | f 8//5 1//5 4//5
33 | f 4//5 6//5 8//5
34 | o mesh5
35 | f 5//6 3//6 2//6
36 | f 2//6 7//6 5//6
37 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_1cm.obj:
--------------------------------------------------------------------------------
1 | # File units = meters
2 | mtllib stick.mtl
3 | g Part 1
4 | v 0.0100144 0.00716822 0.1
5 | v 0.0100144 0.00716822 0
6 | v -0.00643032 0.00716822 0
7 | v -0.00643032 0.00716822 0.1
8 | v -0.00643032 -0.00927652 0
9 | v -0.00643032 -0.00927652 0.1
10 | v 0.0100144 -0.00927652 0
11 | v 0.0100144 -0.00927652 0.1
12 | vn 0 1 0
13 | vn -1 0 0
14 | vn 0 -1 0
15 | vn 1 0 0
16 | vn 0 0 1
17 | vn 0 0 -1
18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000
19 | o mesh0
20 | f 1//1 2//1 3//1
21 | f 3//1 4//1 1//1
22 | o mesh1
23 | f 4//2 3//2 5//2
24 | f 5//2 6//2 4//2
25 | o mesh2
26 | f 6//3 5//3 7//3
27 | f 7//3 8//3 6//3
28 | o mesh3
29 | f 8//4 7//4 2//4
30 | f 2//4 1//4 8//4
31 | o mesh4
32 | f 8//5 1//5 4//5
33 | f 4//5 6//5 8//5
34 | o mesh5
35 | f 5//6 3//6 2//6
36 | f 2//6 7//6 5//6
37 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_gripper/meshes/stick_2cm.obj:
--------------------------------------------------------------------------------
1 | # File units = meters
2 | mtllib stick_2cm.mtl
3 | g Part 1
4 | v 0.0117921 0.00894585 0.1
5 | v 0.0117921 0.00894585 0
6 | v -0.00820794 0.00894585 0
7 | v -0.00820794 0.00894585 0.1
8 | v -0.00820794 -0.0110542 0
9 | v -0.00820794 -0.0110542 0.1
10 | v 0.0117921 -0.0110542 0
11 | v 0.0117921 -0.0110542 0.1
12 | vn 0 1 0
13 | vn -1 0 0
14 | vn 0 -1 0
15 | vn 1 0 0
16 | vn 0 0 1
17 | vn 0 0 -1
18 | usemtl 0.615686_0.811765_0.929412_0.000000_0.000000
19 | o mesh0
20 | f 1//1 2//1 3//1
21 | f 3//1 4//1 1//1
22 | o mesh1
23 | f 4//2 3//2 5//2
24 | f 5//2 6//2 4//2
25 | o mesh2
26 | f 6//3 5//3 7//3
27 | f 7//3 8//3 6//3
28 | o mesh3
29 | f 8//4 7//4 2//4
30 | f 2//4 1//4 8//4
31 | o mesh4
32 | f 8//5 1//5 4//5
33 | f 4//5 6//5 8//5
34 | o mesh5
35 | f 5//6 3//6 2//6
36 | f 2//6 7//6 5//6
37 |
--------------------------------------------------------------------------------
/conf/env/rope.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | obj: "rope"
3 | obj_params:
4 | rope_length: 2.5
5 | stiffness: 0.8
6 | z_rotation: 10.0
7 |
8 |
9 | # data collection
10 | base: 0
11 | n_episode: 1000
12 | n_timestep: 20
13 | n_worker: 40
14 |
15 | # sim env
16 | headless: True # False: OpenGL visualization
17 | camera_view: 1 # 0 (top view), 1, 2, 3, 4
18 | screenWidth: 224
19 | screenHeight: 224
20 |
21 | robot_type: 'xarm6'
22 | robot_end_idx: 6
23 | robot_num_dofs: 6
24 | robot_speed_inv: 300
25 |
26 | action_dim: 4 # [x_start, z_start, x_end, z_end]
27 | action_space: 4 # random action space scope
28 |
29 | # Tool
30 | gripper: False
31 | pusher_len: 1.0
32 |
33 | # Save particles
34 | fps: False
35 | fps_number: 2000
36 |
37 | rob_obj_dist_thresh: 0.2
38 | contact_interval: 40
39 | non_contact_interval: 80
40 |
41 | # others
42 | color_threshold: 0.01
43 |
--------------------------------------------------------------------------------
/models/dino.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
5 |
6 | class DinoV2Encoder(nn.Module):
7 | def __init__(self, name, feature_key):
8 | super().__init__()
9 | self.name = name
10 | self.base_model = torch.hub.load("facebookresearch/dinov2", name)
11 | self.feature_key = feature_key
12 | self.emb_dim = self.base_model.num_features
13 | if feature_key == "x_norm_patchtokens":
14 | self.latent_ndim = 2
15 | elif feature_key == "x_norm_clstoken":
16 | self.latent_ndim = 1
17 | else:
18 | raise ValueError(f"Invalid feature key: {feature_key}")
19 |
20 | self.patch_size = self.base_model.patch_size
21 |
22 | def forward(self, x):
23 | emb = self.base_model.forward_features(x)[self.feature_key]
24 | if self.latent_ndim == 1:
25 | emb = emb.unsqueeze(1) # dummy patch dim
26 | return emb
--------------------------------------------------------------------------------
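A hypothetical usage sketch (the first run downloads DINOv2 weights through torch.hub). For dinov2_vits14 the patch size is 14 and num_features is 384, so a 224x224 input yields 16x16 = 256 patch tokens:

    import torch
    from models.dino import DinoV2Encoder

    enc = DinoV2Encoder(name="dinov2_vits14", feature_key="x_norm_patchtokens")
    img = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        emb = enc(img)
    print(emb.shape)    # torch.Size([2, 256, 384])
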
/conf/env/granular.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | obj: "granular"
3 | obj_params:
4 | granular_scale: 0.2
5 | area: 5.0 # change particle numbers
6 | # area: 1.5
7 | xz_ratio: 1.0
8 | granular_dis: 0.03
9 |
10 | # data collection
11 | base: 0
12 | n_episode: 1000
13 | n_timestep: 20
14 | n_worker: 40
15 |
16 | # sim env
17 | headless: True # False: OpenGL visualization
18 | camera_view: 1 # 0 (top view), 1, 2, 3, 4
19 | screenWidth: 224
20 | screenHeight: 224
21 |
22 | robot_type: 'xarm6'
23 | robot_end_idx: 6
24 | robot_num_dofs: 6
25 | robot_speed_inv: 300
26 |
27 | action_dim: 4 # [x_start, z_start, x_end, z_end]
28 | action_space: 4 # random action space scope
29 |
30 | # Tool
31 | gripper: False
32 | pusher_len: 1.3
33 |
34 | # Save particles
35 | fps: False
36 | fps_number: 2000
37 |
38 | rob_obj_dist_thresh: 0.2
39 | contact_interval: 40
40 | non_contact_interval: 80
41 |
42 | # others
43 | color_threshold: 0.01
44 |
--------------------------------------------------------------------------------
/metrics/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/env/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 | from .pointmaze import U_MAZE
3 | register(
4 | id="pusht",
5 | entry_point="env.pusht.pusht_wrapper:PushTWrapper",
6 | max_episode_steps=300,
7 | reward_threshold=1.0,
8 | )
9 | register(
10 | id='point_maze',
11 | entry_point='env.pointmaze:PointMazeWrapper',
12 | max_episode_steps=300,
13 | kwargs={
14 | 'maze_spec':U_MAZE,
15 | 'reward_type':'sparse',
16 | 'reset_target': False,
17 | 'ref_min_score': 23.85,
18 | 'ref_max_score': 161.86,
19 | 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5'
20 | }
21 | )
22 | register(
23 | id="wall",
24 | entry_point="env.wall.wall_env_wrapper:WallEnvWrapper",
25 | max_episode_steps=300,
26 | reward_threshold=1.0,
27 | )
28 |
29 | register(
30 | id="deformable_env",
31 | entry_point="env.deformable_env.FlexEnvWrapper:FlexEnvWrapper",
32 | max_episode_steps=300,
33 | reward_threshold=1.0,
34 | )
--------------------------------------------------------------------------------
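Importing the env package executes the register() calls above, after which environments are created through gym's registry. A sketch, assuming gym plus the simulator dependencies of the chosen environment are installed and that the installed gym version accepts these unversioned ids; the kwargs mirror conf/env/pusht.yaml:

    import gym
    import env  # noqa: F401 -- triggers the register() calls above

    e = gym.make("pusht", with_velocity=True, with_target=True)
    obs = e.reset()
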
/env/pointmaze/gridcraft/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def flat_to_one_hot(val, ndim):
4 | """
5 |
6 | >>> flat_to_one_hot(2, ndim=4)
7 | array([ 0., 0., 1., 0.])
8 | >>> flat_to_one_hot(4, ndim=5)
9 | array([ 0., 0., 0., 0., 1.])
10 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
11 | array([[ 0., 0., 1., 0., 0.],
12 | [ 0., 0., 0., 0., 1.],
13 | [ 0., 0., 0., 1., 0.]])
14 | """
15 | shape =np.array(val).shape
16 | v = np.zeros(shape + (ndim,))
17 | if len(shape) == 1:
18 | v[np.arange(shape[0]), val] = 1.0
19 | else:
20 | v[val] = 1.0
21 | return v
22 |
23 | def one_hot_to_flat(val):
24 | """
25 | >>> one_hot_to_flat(np.array([0,0,0,0,1]))
26 | 4
27 | >>> one_hot_to_flat(np.array([0,0,1,0]))
28 | 2
29 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
30 | array([2, 0, 1])
31 | """
32 | idxs = np.array(np.where(val == 1.0))[-1]
33 | if len(val.shape) == 1:
34 | return int(idxs)
35 | return idxs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 gaoyuezhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/planning/base_planner.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from abc import ABC, abstractmethod
4 |
5 |
6 | class BasePlanner(ABC):
7 | def __init__(
8 | self,
9 | wm,
10 | action_dim,
11 | objective_fn,
12 | preprocessor,
13 | evaluator,
14 | wandb_run,
15 | log_filename,
16 | **kwargs,
17 | ):
18 | self.wm = wm
19 | self.action_dim = action_dim
20 | self.objective_fn = objective_fn
21 | self.preprocessor = preprocessor
22 | self.device = next(wm.parameters()).device
23 |
24 | self.evaluator = evaluator
25 | self.wandb_run = wandb_run
26 | self.log_filename = log_filename # do not log if None
27 |
28 | def dump_logs(self, logs):
29 | logs_entry = {
30 | key: (
31 | value.item()
32 | if isinstance(value, (np.float32, np.int32, np.int64))
33 | else value
34 | )
35 | for key, value in logs.items()
36 | }
37 | if self.log_filename is not None:
38 | with open(self.log_filename, "a") as file:
39 | file.write(json.dumps(logs_entry) + "\n")
40 |
41 | @abstractmethod
42 | def plan(self):
43 | pass
44 |
--------------------------------------------------------------------------------
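plan() is the only abstract method, so concrete planners such as planning/cem.py and planning/gd.py subclass BasePlanner and implement it. An illustrative-only subclass (not in the repo; the plan() signature used by the real planners is assumed here):

    import torch
    from planning.base_planner import BasePlanner

    class RandomShootingPlanner(BasePlanner):
        def plan(self, obs_0, obs_g, actions=None):
            horizon, n_samples = 5, 32
            candidates = torch.randn(
                n_samples, horizon, self.action_dim, device=self.device
            )
            # ...score candidates by rolling them out with self.wm and applying
            # self.objective_fn against obs_g, then return the best sequence...
            return candidates[0], {}
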
/env/deformable_env/src/sim/sim_env/flex_scene.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyflex
3 |
4 | from .scenes import *
5 |
6 |
7 | class FlexScene:
8 | def __init__(self):
9 | self.obj = None
10 | self.env_idx = None
11 | self.scene_params = None
12 |
13 | self.property_params = None
14 | self.clusters = None
15 |
16 | def set_scene(self, obj, obj_params):
17 | self.obj = obj
18 |
19 | if self.obj == "rope":
20 | self.env_idx = 26
21 | self.scene_params, self.property_params = rope_scene(obj_params)
22 | elif self.obj == "granular":
23 | self.env_idx = 35
24 | self.scene_params, self.property_params = granular_scene(obj_params)
25 | elif self.obj == "cloth":
26 | self.env_idx = 29
27 | self.scene_params, self.property_params = cloth_scene(obj_params)
28 | else:
29 | raise ValueError("Unknown Scene.")
30 |
31 | assert self.env_idx is not None
32 | assert self.scene_params is not None
33 | zeros = np.array([0])
34 | pyflex.set_scene(self.env_idx, self.scene_params, zeros, zeros, zeros, zeros, 0)
35 |
36 | def get_property_params(self):
37 | assert self.property_params is not None
38 | return self.property_params
39 |
--------------------------------------------------------------------------------
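A sketch only, with heavy assumptions: it requires a working PyFleX build, the sim packages being importable from the repo root, and pyflex having been initialized by the surrounding environment code before set_scene is called. The obj_params mirror conf/env/granular.yaml:

    from env.deformable_env.src.sim.sim_env.flex_scene import FlexScene

    scene = FlexScene()
    scene.set_scene(
        "granular",
        {"granular_scale": 0.2, "area": 5.0, "xz_ratio": 1.0, "granular_dis": 0.03},
    )
    print(scene.get_property_params())
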
/metrics/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 | # return torch.sum(torch.cat(res, 0), 0, True)
36 | return torch.mean(torch.sum(torch.cat(res, 1), dim=1)) # return average across batch instead
37 |
--------------------------------------------------------------------------------
/models/dummy.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | class DummyModel(nn.Module):
5 | def __init__(self, emb_dim, **kwargs):
6 | super().__init__()
7 | self.name = "dummy"
8 | self.latent_ndim = 1
9 | self.emb_dim = emb_dim
10 | self.fc = nn.Linear(emb_dim, 1) # not used
11 |
12 | def forward(self, x):
13 | b, dim = x.shape
14 | num_repeat = self.emb_dim // dim
15 | processed_x = torch.zeros([b, self.emb_dim]).to(x.device)
16 | x_repeated = x.repeat(1, num_repeat)
17 | processed_x[:, :num_repeat * dim] = x_repeated
18 | return x.unsqueeze(1) # return shape: (b, 1(or # patches), num_features)
19 |
20 | class DummyRepeatActionEncoder(nn.Module):
21 | def __init__(self, in_chans, emb_dim, **kwargs):
22 | super().__init__()
23 | self.name = "dummy_repeat"
24 | self.latent_ndim = 1
25 | self.in_chans = in_chans
26 | self.emb_dim = emb_dim
27 | self.fc = nn.Linear(in_chans, 1) # not used
28 |
29 | def forward(self, act):
30 | '''
31 | (b, t, act_dim) --> (b, t, action_emb_dim)
32 | '''
33 | b, t, act_dim = act.shape
34 | num_repeat = self.emb_dim // act_dim
35 | processed_act = torch.zeros([b, t, self.emb_dim]).to(act.device)
36 | act_repeated = act.repeat(1, 1, num_repeat)
37 | processed_act[:, :, :num_repeat * act_dim] = act_repeated
38 | return processed_act
--------------------------------------------------------------------------------
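A hypothetical shape check for the action encoder: each action vector is tiled along the feature dimension and zero-padded up to emb_dim:

    import torch
    from models.dummy import DummyRepeatActionEncoder

    enc = DummyRepeatActionEncoder(in_chans=2, emb_dim=10)
    act = torch.randn(4, 3, 2)      # (batch, time, action_dim)
    out = enc(act)
    print(out.shape)                # torch.Size([4, 3, 10])
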
/conf/plan.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - planner: gd
4 | - override hydra/launcher: submitit_slurm
5 |
6 | hydra:
7 | run:
8 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
9 | sweep:
10 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
11 | subdir: ${hydra.job.num}
12 | launcher:
13 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j
14 | nodes: 1
15 | tasks_per_node: 1
16 | cpus_per_task: 16
17 | mem_gb: 256
18 | gres: "gpu:h100:1"
19 | qos: "explore"
20 | timeout_min: 720
21 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)",
22 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)",
23 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",]
24 |
25 | # model to load for planning
26 | ckpt_base_path: ./ # put absolute path here. Checkpoints will be loaded from ${ckpt_base_path}/outputs
27 | model_name: null
28 | model_epoch: latest
29 |
30 | seed: 99
31 | n_evals: 10
32 | goal_source: 'dset' # 'random_state' or 'dset' or 'random_action'
33 | goal_H: 5 # specifies how far away the goal is if goal_source is 'dset'
34 | n_plot_samples: 10
35 |
36 | debug_dset_init: False
37 |
38 | objective:
39 | _target_: planning.objectives.create_objective_fn
40 | alpha: 1
41 | base: 2 # coeff base for weighting all frames. Only applies when mode == 'all'
42 | mode: last
43 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, UFACTORY Inc.
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice,
9 | this list of conditions and the following disclaimer.
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 | * Neither the name of the copyright holder nor the names of its contributors
14 | may be used to endorse or promote products derived from this software
15 | without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
21 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/conf/plan_wall.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/launcher: submitit_slurm
4 |
5 | hydra:
6 | run:
7 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
8 | sweep:
9 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
10 | subdir: ${hydra.job.num}
11 | launcher:
12 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j
13 | nodes: 1
14 | tasks_per_node: 1
15 | cpus_per_task: 16
16 | mem_gb: 256
17 | gres: "gpu:h100:1"
18 | qos: "explore"
19 | timeout_min: 720
20 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)",
21 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)",
22 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",]
23 |
24 | # model to load for planning
25 | ckpt_base_path: ./checkpoints # put absolute path here. Checkpoints will be loaded from ${ckpt_base_path}/outputs
26 | model_name: null
27 | model_epoch: latest
28 |
29 | seed: 99
30 | n_evals: 50
31 | goal_source: 'random_state'
32 | goal_H: 5
33 | n_plot_samples: 10
34 |
35 | debug_dset_init: False
36 |
37 | objective:
38 | _target_: planning.objectives.create_objective_fn
39 | alpha: 1
40 | base: 2
41 | mode: last
42 |
43 | planner:
44 | _target_: planning.mpc.MPCPlanner
45 | max_iter: null
46 | n_taken_actions: 5
47 | sub_planner:
48 | target: planning.cem.CEMPlanner
49 | horizon: 5
50 | topk: 30
51 | num_samples: 300
52 | var_scale: 1
53 | opt_steps: 10
54 | eval_every: 1
55 | name: mpc_cem
--------------------------------------------------------------------------------
/conf/plan_point_maze.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/launcher: submitit_slurm
4 |
5 | hydra:
6 | run:
7 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
8 | sweep:
9 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
10 | subdir: ${hydra.job.num}
11 | launcher:
12 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j
13 | nodes: 1
14 | tasks_per_node: 1
15 | cpus_per_task: 16
16 | mem_gb: 256
17 | gres: "gpu:h100:1"
18 | qos: "explore"
19 | timeout_min: 720
20 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)",
21 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)",
22 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",]
23 |
24 | # model to load for planning
25 | ckpt_base_path: ./checkpoints # put absolute path here. Checkpoints will be loaded from ${ckpt_base_path}/outputs
26 | model_name: null
27 | model_epoch: latest
28 |
29 | seed: 99
30 | n_evals: 50
31 | goal_source: 'random_state'
32 | goal_H: 5
33 | n_plot_samples: 10
34 |
35 | debug_dset_init: False
36 |
37 | objective:
38 | _target_: planning.objectives.create_objective_fn
39 | alpha: 0
40 | base: 2
41 | mode: last
42 |
43 | planner:
44 | _target_: planning.mpc.MPCPlanner
45 | max_iter: null
46 | n_taken_actions: 5
47 | sub_planner:
48 | target: planning.cem.CEMPlanner
49 | horizon: 5
50 | topk: 30
51 | num_samples: 300
52 | var_scale: 1
53 | opt_steps: 10
54 | eval_every: 1
55 | name: mpc_cem
--------------------------------------------------------------------------------
/conf/plan_pusht.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/launcher: submitit_slurm
4 |
5 | hydra:
6 | run:
7 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
8 | sweep:
9 | dir: plan_outputs/${now:%Y%m%d%H%M%S}_${replace_slash:${model_name}}_gH${goal_H}
10 | subdir: ${hydra.job.num}
11 | launcher:
12 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j
13 | nodes: 1
14 | tasks_per_node: 1
15 | cpus_per_task: 16
16 | mem_gb: 256
17 | gres: "gpu:h100:1"
18 | qos: "explore"
19 | timeout_min: 720
20 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)",
21 | export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)",
22 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",]
23 |
24 | # model to load for planning
25 | ckpt_base_path: ./checkpoints # put absolute path here. Checkpoints will be loaded from ${ckpt_base_path}/outputs
26 | model_name: null
27 | model_epoch: latest
28 |
29 | seed: 99
30 | n_evals: 50
31 | goal_source: 'dset'
32 | goal_H: 5
33 | n_plot_samples: 10
34 |
35 | debug_dset_init: False
36 |
37 | objective:
38 | _target_: planning.objectives.create_objective_fn
39 | alpha: 1
40 | base: 2
41 | mode: last
42 |
43 | planner:
44 | _target_: planning.mpc.MPCPlanner
45 | max_iter: null
46 | n_taken_actions: 5
47 | sub_planner:
48 | target: planning.cem.CEMPlanner
49 | horizon: 5
50 | topk: 30
51 | num_samples: 300
52 | var_scale: 1
53 | opt_steps: 30
54 | eval_every: 1
55 | name: mpc_cem
56 |
57 |
--------------------------------------------------------------------------------
/models/encoder/r3m/r3m_base.yaml:
--------------------------------------------------------------------------------
1 | name: r3m_base
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=1_gnu
8 | - bzip2=1.0.8=h7f98852_4
9 | - ca-certificates=2021.10.8=ha878542_0
10 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
11 | - libffi=3.4.2=h7f98852_5
12 | - libgcc-ng=11.2.0=h1d223b6_13
13 | - libgomp=11.2.0=h1d223b6_13
14 | - libnsl=2.0.0=h7f98852_0
15 | - libuuid=2.32.1=h7f98852_1000
16 | - libzlib=1.2.11=h36c2ea0_1013
17 | - ncurses=6.3=h9c3ff4c_0
18 | - openssl=3.0.0=h7f98852_2
19 | - pip=22.0.4=pyhd8ed1ab_0
20 | - python=3.9.10=hc74c709_2_cpython
21 | - python_abi=3.9=2_cp39
22 | - readline=8.1=h46c0cb4_0
23 | - setuptools=60.9.3=py39hf3d152e_0
24 | - sqlite=3.37.0=h9cd32fc_0
25 | - tk=8.6.12=h27826a3_0
26 | - tzdata=2021e=he74cb21_0
27 | - wheel=0.37.1=pyhd8ed1ab_0
28 | - xz=5.2.5=h516909a_1
29 | - zlib=1.2.11=h36c2ea0_1013
30 | - pip:
31 | - antlr4-python3-runtime==4.8
32 | - beautifulsoup4==4.10.0
33 | - certifi==2021.10.8
34 | - charset-normalizer==2.0.12
35 | - click==8.0.4
36 | - cycler==0.11.0
37 | - filelock==3.6.0
38 | - fonttools==4.30.0
39 | - gdown==4.4.0
40 | - huggingface-hub==0.4.0
41 | - hydra-core==1.1.1
42 | - idna==3.3
43 | - joblib==1.1.0
44 | - kiwisolver==1.3.2
45 | - matplotlib==3.5.1
46 | - numpy==1.22.3
47 | - omegaconf==2.1.1
48 | - packaging==21.3
49 | - pillow==9.0.1
50 | - pyparsing==3.0.7
51 | - pysocks==1.7.1
52 | - python-dateutil==2.8.2
53 | - pyyaml==6.0
54 | - regex==2022.3.2
55 | - requests==2.27.1
56 | - sacremoses==0.0.47
57 | - six==1.16.0
58 | - soupsieve==2.3.1
59 | - tokenizers==0.11.6
60 | - torch==1.7.1
61 | - torchvision==0.8.2
62 | - tqdm==4.63.0
63 | - transformers==4.17.0
64 | - typing-extensions==4.1.1
65 | - urllib3==1.26.8
66 | prefix: /private/home/surajn/.conda/envs/r3m_base
67 |
--------------------------------------------------------------------------------
/preprocessor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 | class Preprocessor:
5 | def __init__(self,
6 | action_mean,
7 | action_std,
8 | state_mean,
9 | state_std,
10 | proprio_mean,
11 | proprio_std,
12 | transform,
13 | ):
14 | self.action_mean = action_mean
15 | self.action_std = action_std
16 | self.state_mean = state_mean
17 | self.state_std = state_std
18 | self.proprio_mean = proprio_mean
19 | self.proprio_std = proprio_std
20 | self.transform = transform
21 |
22 | def normalize_actions(self, actions):
23 | '''
24 | actions: (b, t, action_dim)
25 | '''
26 | return (actions - self.action_mean) / self.action_std
27 |
28 | def denormalize_actions(self, actions):
29 | '''
30 | actions: (b, t, action_dim)
31 | '''
32 | return actions * self.action_std + self.action_mean
33 |
34 | def normalize_proprios(self, proprio):
35 | '''
36 | input shape (..., proprio_dim)
37 | '''
38 | return (proprio - self.proprio_mean) / self.proprio_std
39 |
40 | def normalize_states(self, state):
41 | '''
42 | input shape (..., state_dim)
43 | '''
44 | return (state - self.state_mean) / self.state_std
45 |
46 | def preprocess_obs_visual(self, obs_visual):
47 | return rearrange(obs_visual, "b t h w c -> b t c h w") / 255.0
48 |
49 | def transform_obs_visual(self, obs_visual):
50 | transformed_obs_visual = torch.tensor(obs_visual)
51 | transformed_obs_visual = self.preprocess_obs_visual(transformed_obs_visual)
52 | transformed_obs_visual = self.transform(transformed_obs_visual)
53 | return transformed_obs_visual
54 |
55 | def transform_obs(self, obs):
56 | '''
57 | np arrays to tensors
58 | '''
59 | transformed_obs = {}
60 | transformed_obs['visual'] = self.transform_obs_visual(obs['visual'])
61 | transformed_obs['proprio'] = self.normalize_proprios(torch.tensor(obs['proprio']))
62 | return transformed_obs
63 |
--------------------------------------------------------------------------------
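
A minimal usage sketch for Preprocessor; the dimensions and the Resize transform are illustrative assumptions, not values taken from the repo:

import numpy as np
import torch
from torchvision import transforms
from preprocessor import Preprocessor

preprocessor = Preprocessor(
    action_mean=torch.zeros(2), action_std=torch.ones(2),
    state_mean=torch.zeros(4), state_std=torch.ones(4),
    proprio_mean=torch.zeros(4), proprio_std=torch.ones(4),
    transform=transforms.Resize((224, 224)),
)
obs = {
    "visual": np.random.randint(0, 256, (1, 3, 64, 64, 3)),  # (b, t, h, w, c) raw frames
    "proprio": np.zeros((1, 3, 4)),
}
out = preprocessor.transform_obs(obs)
# out["visual"]: (1, 3, 3, 224, 224) float in [0, 1]; out["proprio"]: normalized tensor
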
/models/proprio.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/ijepa/blob/main/src/models/vision_transformer.py
2 | import numpy as np
3 | import torch.nn as nn
4 |
5 |
6 | def get_1d_sincos_pos_embed(emb_dim, grid_size, cls_token=False):
7 | """
8 | emb_dim: output dimension for each position
9 | grid_size: int of the grid length
10 | returns:
11 | pos_embed: [grid_size, emb_dim] (w/o cls_token)
12 | or [1+grid_size, emb_dim] (w/ cls_token)
13 | """
14 | grid = np.arange(grid_size, dtype=float)
15 | pos_embed = get_1d_sincos_pos_embed_from_grid(emb_dim, grid)
16 | if cls_token:
17 | pos_embed = np.concatenate([np.zeros([1, emb_dim]), pos_embed], axis=0)
18 | return pos_embed
19 |
20 | def get_1d_sincos_pos_embed_from_grid(emb_dim, pos):
21 | """
22 | emb_dim: output dimension for each position
23 | pos: a list of positions to be encoded: size (M,)
24 | returns: (M, D)
25 | """
26 | assert emb_dim % 2 == 0
27 | omega = np.arange(emb_dim // 2, dtype=float)
28 | omega /= emb_dim / 2.
29 | omega = 1. / 10000**omega # (D/2,)
30 |
31 | pos = pos.reshape(-1) # (M,)
32 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
33 |
34 | emb_sin = np.sin(out) # (M, D/2)
35 | emb_cos = np.cos(out) # (M, D/2)
36 |
37 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
38 | return emb
39 |
40 | class ProprioceptiveEmbedding(nn.Module):
41 | def __init__(
42 | self,
43 | num_frames=16, # horizon
44 | tubelet_size=1,
45 | in_chans=8, # action_dim
46 | emb_dim=384, # output_dim
47 | use_3d_pos=False # always False for now
48 | ):
49 | super().__init__()
50 | print(f'using 3d prop position {use_3d_pos=}')
51 |
52 | # Map input to predictor dimension
53 | self.num_frames = num_frames
54 | self.tubelet_size = tubelet_size
55 | self.in_chans = in_chans
56 | self.emb_dim = emb_dim
57 |
58 | self.patch_embed = nn.Conv1d(
59 | in_chans,
60 | emb_dim,
61 | kernel_size=tubelet_size,
62 | stride=tubelet_size)
63 |
64 | def forward(self, x):
65 | # x: proprioceptive vectors of shape [B T D]
66 | x = x.permute(0, 2, 1)
67 | x = self.patch_embed(x)
68 | x = x.permute(0, 2, 1)
69 | return x
--------------------------------------------------------------------------------
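
A small shape-check sketch for the sin/cos table and the proprio embedding; the 8-dim input and 384-dim output are illustrative assumptions:

import torch
from models.proprio import ProprioceptiveEmbedding, get_1d_sincos_pos_embed

pos_embed = get_1d_sincos_pos_embed(emb_dim=384, grid_size=16)  # (16, 384) fixed table
embed = ProprioceptiveEmbedding(num_frames=16, in_chans=8, emb_dim=384)
z = embed(torch.randn(2, 16, 8))  # (B, T, 8) -> (B, T, 384) via a kernel-1 Conv1d
print(pos_embed.shape, z.shape)
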
/models/encoder/r3m/models/models_language.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import numpy as np
6 | from numpy.core.numeric import full
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.modules.activation import Sigmoid
10 |
11 | epsilon = 1e-8
12 |
13 |
14 | class LangEncoder(nn.Module):
15 | def __init__(self, device, finetune=False, scratch=False):
16 | super().__init__()
17 | from transformers import AutoTokenizer, AutoModel, AutoConfig
18 |
19 | self.device = device
20 | self.modelname = "distilbert-base-uncased"
21 | self.tokenizer = AutoTokenizer.from_pretrained(self.modelname)
22 | self.model = AutoModel.from_pretrained(self.modelname).to(self.device)
23 | self.lang_size = 768
24 |
25 | def forward(self, langs):
26 | try:
27 | langs = langs.tolist()
28 |         except AttributeError:
29 | pass
30 |
31 | with torch.no_grad():
32 | encoded_input = self.tokenizer(langs, return_tensors="pt", padding=True)
33 | input_ids = encoded_input["input_ids"].to(self.device)
34 | attention_mask = encoded_input["attention_mask"].to(self.device)
35 | lang_embedding = self.model(
36 | input_ids, attention_mask=attention_mask
37 | ).last_hidden_state
38 | lang_embedding = lang_embedding.mean(1)
39 | return lang_embedding
40 |
41 |
42 | class LanguageReward(nn.Module):
43 | def __init__(self, ltype, im_dim, hidden_dim, lang_dim, simfunc=None):
44 | super().__init__()
45 | self.ltype = ltype
46 | self.sim = simfunc
47 | self.sigm = Sigmoid()
48 | self.pred = nn.Sequential(
49 | nn.Linear(im_dim * 2 + lang_dim, hidden_dim),
50 | nn.ReLU(inplace=True),
51 | nn.Linear(hidden_dim, hidden_dim),
52 | nn.ReLU(inplace=True),
53 | nn.Linear(hidden_dim, hidden_dim),
54 | nn.ReLU(inplace=True),
55 | nn.Linear(hidden_dim, hidden_dim),
56 | nn.ReLU(inplace=True),
57 | nn.Linear(hidden_dim, 1),
58 | )
59 |
60 | def forward(self, e0, eg, le):
61 | info = {}
62 | return self.pred(torch.cat([e0, eg, le], -1)).squeeze(), info
63 |
--------------------------------------------------------------------------------
/planning/objectives.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def create_objective_fn(alpha, base, mode="last"):
7 | """
8 |     Create the objective function used by the planner.
9 |     Args:
10 |         alpha: float, weight on the proprioceptive term of the loss
11 |         base: float, exponential weighting base; only used by objective_fn_all
12 |         mode: 'last' scores only the final predicted frame, 'all' scores every frame
13 |     Returns: objective_fn, a callable (z_obs_pred, z_obs_tgt) -> loss tensor of shape (B, )
14 | """
15 | metric = nn.MSELoss(reduction="none")
16 |
17 | def objective_fn_last(z_obs_pred, z_obs_tgt):
18 | """
19 | Args:
20 | z_obs_pred: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)}
21 | z_obs_tgt: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)}
22 | Returns:
23 | loss: tensor (B, )
24 | """
25 | loss_visual = metric(z_obs_pred["visual"][:, -1:], z_obs_tgt["visual"]).mean(
26 | dim=tuple(range(1, z_obs_pred["visual"].ndim))
27 | )
28 | loss_proprio = metric(z_obs_pred["proprio"][:, -1:], z_obs_tgt["proprio"]).mean(
29 | dim=tuple(range(1, z_obs_pred["proprio"].ndim))
30 | )
31 | loss = loss_visual + alpha * loss_proprio
32 | return loss
33 |
34 | def objective_fn_all(z_obs_pred, z_obs_tgt):
35 | """
36 | Loss calculated on all pred frames.
37 | Args:
38 | z_obs_pred: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)}
39 | z_obs_tgt: dict, {'visual': (B, T, *D_visual), 'proprio': (B, T, *D_proprio)}
40 | Returns:
41 | loss: tensor (B, )
42 | """
43 | coeffs = np.array(
44 | [base**i for i in range(z_obs_pred["visual"].shape[1])], dtype=np.float32
45 | )
46 | coeffs = torch.tensor(coeffs / np.sum(coeffs)).to(z_obs_pred["visual"].device)
47 | loss_visual = metric(z_obs_pred["visual"], z_obs_tgt["visual"]).mean(
48 | dim=tuple(range(2, z_obs_pred["visual"].ndim))
49 | )
50 | loss_proprio = metric(z_obs_pred["proprio"], z_obs_tgt["proprio"]).mean(
51 | dim=tuple(range(2, z_obs_pred["proprio"].ndim))
52 | )
53 | loss_visual = (loss_visual * coeffs).mean(dim=1)
54 | loss_proprio = (loss_proprio * coeffs).mean(dim=1)
55 | loss = loss_visual + alpha * loss_proprio
56 | return loss
57 |
58 | if mode == "last":
59 | return objective_fn_last
60 | elif mode == "all":
61 | return objective_fn_all
62 | else:
63 | raise NotImplementedError
64 |
--------------------------------------------------------------------------------
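
A small sketch of the objective on dummy latents; the latent shapes are illustrative, not the encoder's actual output sizes:

import torch
from planning.objectives import create_objective_fn

objective_fn = create_objective_fn(alpha=1, base=2, mode="last")
z_pred = {"visual": torch.randn(4, 5, 196, 384), "proprio": torch.randn(4, 5, 10)}   # 4 candidates, 5 steps
z_tgt  = {"visual": torch.randn(4, 1, 196, 384), "proprio": torch.randn(4, 1, 10)}   # goal frame
loss = objective_fn(z_pred, z_tgt)  # (4,), visual MSE + alpha * proprio MSE on the last step
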
/conf/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - env: point_maze
4 | - encoder: dino
5 | - action_encoder: proprio
6 | - proprio_encoder: proprio
7 | - decoder: vqvae
8 | - predictor: vit
9 | - override hydra/launcher: submitit_slurm
10 |
11 | # base path to save model outputs. Checkpoints will be saved to ${ckpt_base_path}/outputs.
12 | ckpt_base_path: ./ # put absolute path here
13 |
14 | hydra:
15 | run:
16 | dir: ${ckpt_base_path}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
17 | sweep:
18 | dir: ${ckpt_base_path}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
19 | subdir: ${hydra.job.num}
20 | launcher:
21 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j
22 | nodes: 1
23 | tasks_per_node: 1
24 | cpus_per_task: 8
25 | mem_gb: 512
26 | gres: "gpu:h100:1"
27 | timeout_min: 2880
28 | setup: ["export DEBUGVAR=$(scontrol show hostnames $SLURM_JOB_NODELIST)",
29 |         'export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)"',
30 | "export MASTER_PORT=$(for port in $(shuf -i 30000-65500 -n 20); do if [[ $(netstat -tupln 2>&1 | grep $port | wc -l) -eq 0 ]] ; then echo $port; break; fi; done;)",]
31 |
32 | training:
33 | seed: 0
34 | epochs: 100
35 |   batch_size: 32 # should be >= nodes * tasks_per_node
36 | save_every_x_epoch: 1
37 | reconstruct_every_x_batch: 500
38 | num_reconstruct_samples: 6
39 | encoder_lr: 1e-6
40 | decoder_lr: 3e-4
41 | predictor_lr: 5e-4
42 | action_encoder_lr: 5e-4
43 |
44 | img_size: 224 # should be a multiple of 224
45 | frameskip: 5
46 | concat_dim: 1
47 |
48 | normalize_action: True
49 |
50 | # action encoder
51 | action_emb_dim: 10
52 | num_action_repeat: 1
53 |
54 | # proprio encoder
55 | proprio_emb_dim: 10
56 | num_proprio_repeat: 1
57 |
58 | num_hist: 3
59 | num_pred: 1 # only supports 1
60 | has_predictor: True # set this to False for only training a decoder
61 | has_decoder: True # set this to False for only training a predictor
62 |
63 | model:
64 | _target_: models.visual_world_model.VWorldModel
65 | image_size: ${img_size}
66 | num_hist: ${num_hist}
67 | num_pred: ${num_pred}
68 | train_encoder: False
69 | train_predictor: True
70 | train_decoder: True
71 |
72 | debug: False
73 |
74 | # Planning params for planning eval jobs launched during training
75 | plan_settings:
76 | # plan_cfg_path: conf/plan.yaml # set to null for no planning evals
77 | plan_cfg_path: null
78 | planner: ['gd', 'cem']
79 | goal_source: ['dset', 'random_state']
80 | goal_H: [5]
81 | alpha: [0.1, 1]
--------------------------------------------------------------------------------
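
A sketch of composing this config with Hydra's compose API, assuming it is run from the repo root and that hydra-submitit-launcher is installed (required by the hydra/launcher override in the defaults list); the overrides are illustrative:

from hydra import initialize, compose

with initialize(config_path="conf"):
    cfg = compose(config_name="train", overrides=["env=pusht", "num_hist=3"])
print(cfg.model["_target_"], cfg.training.batch_size, cfg.img_size)
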
/env/deformable_env/src/sim/assets/xarm/link6.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/base.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link1.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link6_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link3.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/base_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link5.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link4.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link1_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link2.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link3_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link5_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link4_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/link2_com.urdf:
--------------------------------------------------------------------------------
13 | /xarm
15 | gazebo_ros_control/DefaultRobotHWSim
16 | true
--------------------------------------------------------------------------------
/distributed_fn/launch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import distributed as dist
5 | from torch import multiprocessing as mp
6 |
7 | import distributed_fn as dist_fn
8 |
9 |
10 | def find_free_port():
11 | import socket
12 |
13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
14 |
15 | sock.bind(("", 0))
16 | port = sock.getsockname()[1]
17 | sock.close()
18 |
19 | return port
20 |
21 |
22 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
23 | world_size = n_machine * n_gpu_per_machine
24 |
25 | if world_size > 1:
26 | if "OMP_NUM_THREADS" not in os.environ:
27 | os.environ["OMP_NUM_THREADS"] = "1"
28 |
29 | if dist_url == "auto":
30 | if n_machine != 1:
31 | raise ValueError('dist_url="auto" not supported in multi-machine jobs')
32 |
33 | port = find_free_port()
34 | dist_url = f"tcp://127.0.0.1:{port}"
35 |
36 | if n_machine > 1 and dist_url.startswith("file://"):
37 | raise ValueError(
38 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
39 | )
40 |
41 | mp.spawn(
42 | distributed_worker,
43 | nprocs=n_gpu_per_machine,
44 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
45 | daemon=False,
46 | )
47 |
48 | else:
49 | fn(*args)
50 |
51 |
52 | def distributed_worker(
53 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
54 | ):
55 | if not torch.cuda.is_available():
56 |         raise OSError("CUDA is not available. Please check your environment")
57 |
58 | global_rank = machine_rank * n_gpu_per_machine + local_rank
59 |
60 | try:
61 | dist.init_process_group(
62 | backend="NCCL",
63 | init_method=dist_url,
64 | world_size=world_size,
65 | rank=global_rank,
66 | )
67 |
68 |     except Exception as e:
69 |         raise OSError("failed to initialize NCCL groups") from e
70 |
71 | dist_fn.synchronize()
72 |
73 | if n_gpu_per_machine > torch.cuda.device_count():
74 | raise ValueError(
75 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
76 | )
77 |
78 | torch.cuda.set_device(local_rank)
79 |
80 | if dist_fn.LOCAL_PROCESS_GROUP is not None:
81 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")
82 |
83 | n_machine = world_size // n_gpu_per_machine
84 |
85 | for i in range(n_machine):
86 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
87 | pg = dist.new_group(ranks_on_i)
88 |
89 | if i == machine_rank:
90 | dist_fn.distributed.LOCAL_PROCESS_GROUP = pg
91 |
92 | fn(*args)
93 |
--------------------------------------------------------------------------------
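
A usage sketch for launch(); the training function body is a placeholder, and with a single visible GPU the function is simply called in-process without spawning:

import torch
from distributed_fn.launch import launch

def training_fn(cfg):
    # placeholder worker: real code would build the model and dataloaders here
    print("worker started, received cfg:", cfg)

if __name__ == "__main__":
    # one process per local GPU, free TCP port chosen via dist_url="auto"
    launch(training_fn, n_gpu_per_machine=torch.cuda.device_count(), dist_url="auto", args=({"lr": 1e-4},))
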
/metrics/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 |         self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 |         self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/metrics/image_metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.autograd import Variable
14 | from math import exp
15 | from metrics.lpipsPyTorch import lpips
16 |
17 | def l1_loss(network_output, gt):
18 | return torch.abs((network_output - gt)).mean()
19 |
20 | def l2_loss(network_output, gt):
21 | return ((network_output - gt) ** 2).mean()
22 |
23 | def gaussian(window_size, sigma):
24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 | return gauss / gauss.sum()
26 |
27 | def create_window(window_size, channel):
28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31 | return window
32 |
33 | def ssim(img1, img2, window_size=11, size_average=True):
34 | channel = img1.size(-3)
35 | window = create_window(window_size, channel)
36 |
37 | if img1.is_cuda:
38 | window = window.cuda(img1.get_device())
39 | window = window.type_as(img1)
40 |
41 | return _ssim(img1, img2, window, window_size, channel, size_average)
42 |
43 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46 |
47 | mu1_sq = mu1.pow(2)
48 | mu2_sq = mu2.pow(2)
49 | mu1_mu2 = mu1 * mu2
50 |
51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54 |
55 | C1 = 0.01 ** 2
56 | C2 = 0.03 ** 2
57 |
58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59 |
60 | if size_average:
61 | return ssim_map.mean()
62 | else:
63 | return ssim_map.mean(1).mean(1).mean(1)
64 |
65 | def mse(img1, img2): # batched input only
66 | return (((img1 - img2)) ** 2).reshape(img1.shape[0], -1).mean()
67 |
68 | def psnr(img1, img2): # batched input only
69 | mse = (((img1 - img2)) ** 2).reshape(img1.shape[0], -1).mean()
70 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
71 |
72 | def eval_images(img1, img2, item_only=True):
73 | metrics = {}
74 | metrics['l1'] = l1_loss(img1, img2)
75 | metrics['l2'] = l2_loss(img1, img2)
76 | metrics['ssim'] = ssim(img1, img2)
77 | metrics['mse'] = mse(img1, img2)
78 | metrics['psnr'] = psnr(img1, img2)
79 | metrics['lpips'] = lpips(img1, img2, net_type='vgg')
80 | return metrics
--------------------------------------------------------------------------------
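
A quick sketch exercising the metrics on random batches; the numbers are meaningless, and the LPIPS call downloads VGG16 weights on first use:

import torch
from metrics.image_metrics import eval_images

img1 = torch.rand(4, 3, 64, 64)  # batched images in [0, 1]
img2 = torch.rand(4, 3, 64, 64)
metrics = eval_images(img1, img2)
print({k: float(v.mean()) for k, v in metrics.items()})
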
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import argparse
5 | import numpy as np
6 | from PIL import Image
7 | from omegaconf import OmegaConf
8 | from typing import Callable, Dict
9 | import psutil
10 |
11 | def get_ram_usage():
12 | process = psutil.Process(os.getpid())
13 |     return process.memory_info().rss / (1024 * 1024 * 1024)  # Memory usage in GB
14 |
15 | def get_available_ram():
16 | mem = psutil.virtual_memory()
17 |     return mem.available / (1024 * 1024 * 1024)  # Available memory in GB
18 |
19 | def dict_to_namespace(cfg_dict):
20 | args = argparse.Namespace()
21 | for key in cfg_dict:
22 | setattr(args, key, cfg_dict[key])
23 | return args
24 |
25 | def move_to_device(dct, device):
26 | for key, value in dct.items():
27 | if isinstance(value, torch.Tensor):
28 | dct[key] = value.to(device)
29 | return dct
30 |
31 | def slice_trajdict_with_t(data_dict, start_idx=0, end_idx=None, step=1):
32 | if end_idx is None:
33 | end_idx = max(arr.shape[1] for arr in data_dict.values())
34 | return {key: arr[:, start_idx:end_idx:step, ...] for key, arr in data_dict.items()}
35 |
36 | def concat_trajdict(dcts):
37 | full_dct = {}
38 | for k in dcts[0].keys():
39 | if isinstance(dcts[0][k], np.ndarray):
40 | full_dct[k] = np.concatenate([dct[k] for dct in dcts], axis=1)
41 | elif isinstance(dcts[0][k], torch.Tensor):
42 | full_dct[k] = torch.cat([dct[k] for dct in dcts], dim=1)
43 | else:
44 | raise TypeError(f"Unsupported data type: {type(dcts[0][k])}")
45 | return full_dct
46 |
47 | def aggregate_dct(dcts):
48 | full_dct = {}
49 | for dct in dcts:
50 | for key, value in dct.items():
51 | if key not in full_dct:
52 | full_dct[key] = []
53 | full_dct[key].append(value)
54 | for key, value in full_dct.items():
55 | if isinstance(value[0], torch.Tensor):
56 | full_dct[key] = torch.stack(value)
57 | else:
58 | full_dct[key] = np.stack(value)
59 | return full_dct
60 |
61 | def sample_tensors(tensors, n, indices=None):
62 | if indices is None:
63 | b = tensors[0].shape[0]
64 | indices = torch.randperm(b)[:n]
65 | indices = torch.tensor(indices)
66 | for i, tensor in enumerate(tensors):
67 | if tensor is not None:
68 | tensors[i] = tensor[indices]
69 | return tensors
70 |
71 |
72 | def cfg_to_dict(cfg):
73 | cfg_dict = OmegaConf.to_container(cfg)
74 | for key in cfg_dict:
75 | if isinstance(cfg_dict[key], list):
76 | cfg_dict[key] = ",".join(cfg_dict[key])
77 | return cfg_dict
78 |
79 | def reduce_dict(f: Callable, d: Dict):
80 | return {k: reduce_dict(f, v) if isinstance(v, dict) else f(v) for k, v in d.items()}
81 |
82 | def seed(seed):
83 | random.seed(seed)
84 | torch.manual_seed(seed)
85 | np.random.seed(seed)
86 | if torch.cuda.is_available():
87 | torch.cuda.manual_seed_all(seed)
88 |
89 |
90 | def pil_loader(path):
91 | with open(path, "rb") as f:
92 | with Image.open(f) as img:
93 | return img.convert("RGB")
--------------------------------------------------------------------------------
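
A short sketch of the trajectory-dict helpers; shapes are illustrative:

import numpy as np
import torch
from utils import slice_trajdict_with_t, concat_trajdict, aggregate_dct

traj = {"visual": torch.randn(2, 10, 8), "proprio": torch.randn(2, 10, 4)}  # (B, T, ...)
head = slice_trajdict_with_t(traj, start_idx=0, end_idx=5)      # every value sliced to t < 5
tail = slice_trajdict_with_t(traj, start_idx=5)
rejoined = concat_trajdict([head, tail])                        # concatenated back along T

steps = [{"reward": np.float32(r)} for r in (0.0, 1.0)]
print(aggregate_dct(steps))                                     # {"reward": array([0., 1.], dtype=float32)}
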
/env/deformable_env/src/sim/data_gen/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 |
4 |
5 | def store_data(filename, data, action):
6 | """
7 | action: (action_dim,)
8 | imgs_list: (T, num_cameras, H, W, 5)
9 | particle_pos_list: (T, N, 3)
10 | eef_states_list: (T, 14)
11 | """
12 | # load data
13 | imgs_list, particle_pos_list, eef_states_list = data
14 | imgs_list_np, particle_pos_list_np, eef_states_list_np = (
15 | np.array(imgs_list),
16 | np.array(particle_pos_list),
17 | np.array(eef_states_list),
18 | )
19 |
20 | # stat
21 | T, n_cam = imgs_list_np.shape[:2]
22 | n_particles = particle_pos_list_np.shape[1]
23 |
24 | # process images
25 | color_imgs, depth_imgs = process_imgs(imgs_list_np)
26 |
27 | # init episode data
28 | episode = {
29 | "info": {"n_cams": n_cam, "timestamp": T, "n_particles": n_particles},
30 | "action": action,
31 | "positions": particle_pos_list_np,
32 | "eef_states": eef_states_list_np,
33 | "observations": {"color": color_imgs, "depth": depth_imgs},
34 | }
35 |
36 | # save to h5py
37 | save_data(filename, episode)
38 |
39 |
40 | def process_imgs(imgs_list):
41 | T, n_cam, H, W, _ = imgs_list.shape
42 | color_imgs = {}
43 | depth_imgs = {}
44 |
45 | for cam_idx in range(n_cam):
46 | img = imgs_list[:, cam_idx] # (T, H, W, 5)
47 | color_imgs[f"cam_{cam_idx}"] = img[:, :, :, :3][..., ::-1] # (T, H, W, 3)
48 | depth_imgs[f"cam_{cam_idx}"] = (img[:, :, :, -1] * 1000).astype(
49 | np.uint16
50 | ) # (T, H, W)
51 |
52 | assert color_imgs["cam_0"].shape == (T, H, W, 3)
53 | assert depth_imgs["cam_0"].shape == (T, H, W)
54 |
55 | return color_imgs, depth_imgs
56 |
57 |
58 | def save_data(filename, episode):
59 |     with h5py.File(filename, "w") as f:
60 |         for key, value in episode.items():
61 | if key in ["observations"]:
62 | for sub_key, sub_value in value.items():
63 | for subsub_key, subsub_value in sub_value.items():
64 | f.create_dataset(
65 | f"{key}/{sub_key}/{subsub_key}", data=subsub_value
66 | )
67 | elif key in ["info"]:
68 | for sub_key, sub_value in value.items():
69 | f.create_dataset(f"{key}/{sub_key}", data=sub_value)
70 | else:
71 | f.create_dataset(key, data=value)
72 |
73 |
74 | def load_data(filename):
75 | data = {}
76 | with h5py.File(filename, "r") as f:
77 | for key in f.keys():
78 | if key in ["observations"]:
79 | data[key] = {}
80 | for sub_key in f[key].keys():
81 | data[key][sub_key] = {}
82 | for subsub_key in f[key][sub_key].keys():
83 | data[key][sub_key][subsub_key] = f[key][sub_key][subsub_key][()]
84 | elif key in ["info"]:
85 | data[key] = {}
86 | for sub_key in f[key].keys():
87 | data[key][sub_key] = f[key][sub_key][()]
88 | else:
89 | data[key] = f[key][()]
90 | return data
91 |
--------------------------------------------------------------------------------
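
A round-trip sketch with tiny random arrays; shapes follow the store_data docstring, and the import path assumes these sim directories are importable as packages:

import numpy as np
from env.deformable_env.src.sim.data_gen.data import store_data, load_data

T, n_cams, H, W = 3, 4, 16, 16
imgs = np.random.rand(T, n_cams, H, W, 5).astype(np.float32)   # 3 color channels + 1 spare + depth
particles = np.random.rand(T, 100, 3)
eef_states = np.random.rand(T, 14)

store_data("episode_000.h5", (imgs, particles, eef_states), action=np.zeros(4, dtype=np.float32))
episode = load_data("episode_000.h5")
print(episode["info"]["n_cams"], episode["observations"]["color"]["cam_0"].shape)  # 4, (3, 16, 16, 3)
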
/env/deformable_env/src/sim/sim_env/cameras.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyflex
3 |
4 | class Camera():
5 | def __init__(self, screenWidth, screenHeight):
6 |
7 | self.screenWidth = screenWidth
8 | self.screenHeight = screenHeight
9 |
10 | self.num_cameras = 4
11 | self.camera_view = None
12 |
13 | self.cam_dis = 6.
14 | self.cam_height = 10.
15 | self.cam_deg = np.array([0., 90., 180., 270.]) + 45.
16 |
17 | def set_init_camera(self, camera_view):
18 | self.camera_view = camera_view
19 |
20 | if self.camera_view == 0: # top view
21 | self.camPos = np.array([0., self.cam_height+10., 0.])
22 | self.camAngle = np.array([0., -np.deg2rad(90.), 0.])
23 | elif self.camera_view == 1:
24 | self.camPos = np.array([self.cam_dis, self.cam_height, self.cam_dis])
25 | self.camAngle = np.array([np.deg2rad(self.cam_deg[0]), -np.deg2rad(45.), 0.])
26 | elif self.camera_view == 2:
27 | self.camPos = np.array([self.cam_dis, self.cam_height, -self.cam_dis])
28 | self.camAngle = np.array([np.deg2rad(self.cam_deg[1]), -np.deg2rad(45.), 0.])
29 | elif self.camera_view == 3:
30 | self.camPos = np.array([-self.cam_dis, self.cam_height, -self.cam_dis])
31 | self.camAngle = np.array([np.deg2rad(self.cam_deg[2]), -np.deg2rad(45.), 0.])
32 | elif self.camera_view == 4:
33 | self.camPos = np.array([-self.cam_dis, self.cam_height, self.cam_dis])
34 | self.camAngle = np.array([np.deg2rad(self.cam_deg[3]), -np.deg2rad(45.), 0.])
35 | else:
36 | raise ValueError('camera_view not defined')
37 |
38 | # set camera
39 | pyflex.set_camPos(self.camPos)
40 | pyflex.set_camAngle(self.camAngle)
41 |
42 | def init_multiview_cameras(self):
43 | self.camPos_list, self.camAngle_list = [], []
44 | self.cam_x_list = np.array([self.cam_dis, self.cam_dis, -self.cam_dis, -self.cam_dis])
45 | self.cam_z_list = np.array([self.cam_dis, -self.cam_dis, -self.cam_dis, self.cam_dis])
46 |
47 | self.rad_list = np.deg2rad(self.cam_deg)
48 | for i in range(self.num_cameras):
49 | self.camPos_list.append(np.array([self.cam_x_list[i], self.cam_height, self.cam_z_list[i]]))
50 | self.camAngle_list.append(np.array([self.rad_list[i], -np.deg2rad(45.), 0.]))
51 |
52 | self.cam_intrinsic_params = np.zeros([len(self.camPos_list), 4]) # [fx, fy, cx, cy]
53 | self.cam_extrinsic_matrix = np.zeros([len(self.camPos_list), 4, 4]) # [R, t]
54 |
55 | return self.camPos_list, self.camAngle_list, self.cam_intrinsic_params, self.cam_extrinsic_matrix
56 |
57 | def get_cam_params(self):
58 | # camera intrinsic parameters
59 | projMat = pyflex.get_projMatrix().reshape(4, 4).T
60 | cx = self.screenWidth / 2.0
61 | cy = self.screenHeight / 2.0
62 | fx = projMat[0, 0] * cx
63 | fy = projMat[1, 1] * cy
64 | camera_intrinsic_params = np.array([fx, fy, cx, cy])
65 |
66 | # camera extrinsic parameters
67 | cam_extrinsic_matrix = pyflex.get_viewMatrix().reshape(4, 4).T
68 |
69 | return camera_intrinsic_params, cam_extrinsic_matrix
70 |
--------------------------------------------------------------------------------
/env/serial_vector_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils import aggregate_dct
3 |
4 | # for debugging environments, and envs that are not compatible with SubprocVectorEnv
5 | class SerialVectorEnv:
6 | """
7 | obs, reward, done, info
8 | obs: dict, each key has shape (num_env, ...)
9 | reward: (num_env, )
10 | done: (num_env, )
11 | info: tuple of length num_env, each element is a dict
12 | """
13 |
14 | def __init__(self, envs):
15 | self.envs = envs
16 | self.num_envs = len(envs)
17 |
18 | def sample_random_init_goal_states(self, seed):
19 | init_state, goal_state = zip(*(self.envs[i].sample_random_init_goal_states(seed[i]) for i in range(self.num_envs)))
20 | return np.stack(init_state), np.stack(goal_state)
21 |
22 | def update_env(self, env_info):
23 | [self.envs[i].update_env(env_info[i]) for i in range(self.num_envs)]
24 |
25 | def eval_state(self, goal_state, cur_state):
26 | eval_result = []
27 | for i in range(self.num_envs):
28 | env = self.envs[i]
29 | eval_result.append(env.eval_state(goal_state[i], cur_state[i]))
30 | eval_result = aggregate_dct(eval_result)
31 | return eval_result
32 |
33 | def prepare(self, seed, init_state):
34 | """
35 | Reset with controlled init_state
36 | obs: (num_envs, H, W, C)
37 | state: tuple of num_envs dicts
38 | """
39 | obs = []
40 | state = []
41 | for i in range(self.num_envs):
42 | env = self.envs[i]
43 | cur_seed = seed[i]
44 | cur_init_state = init_state[i]
45 | o, s = env.prepare(cur_seed, cur_init_state)
46 | obs.append(o)
47 | state.append(s)
48 | obs = aggregate_dct(obs)
49 | state = np.stack(state)
50 | return obs, state
51 |
52 | def step_multiple(self, actions):
53 | """
54 | actions: (num_envs, T, action_dim)
55 | obses: (num_envs, T, H, W, C)
56 | infos: tuple of length num_envs, each element is a dict
57 | """
58 | obses = []
59 | rewards = []
60 | dones = []
61 | infos = []
62 | for i in range(self.num_envs):
63 | env = self.envs[i]
64 | cur_actions = actions[i]
65 | obs, reward, done, info = env.step_multiple(cur_actions)
66 | obses.append(obs)
67 | rewards.append(reward)
68 | dones.append(done)
69 | infos.append(info)
70 | obses = np.stack(obses)
71 | rewards = np.stack(rewards)
72 | dones = np.stack(dones)
73 | infos = tuple(infos)
74 | return obses, rewards, dones, infos
75 |
76 | def rollout(self, seed, init_state, actions):
77 | """
78 | only returns np arrays of observations and states
79 | obses: (num_envs, T, H, W, C)
80 | states: (num_envs, T, D)
81 | proprios: (num_envs, T, D_p)
82 | """
83 | obses = []
84 | states = []
85 | for i in range(self.num_envs):
86 | env = self.envs[i]
87 | cur_seed = seed[i]
88 | cur_init_state = init_state[i]
89 | cur_actions = actions[i]
90 | obs, state = env.rollout(cur_seed, cur_init_state, cur_actions)
91 | obses.append(obs)
92 | states.append(state)
93 | obses = aggregate_dct(obses)
94 | states = np.stack(states)
95 | return obses, states
96 |
--------------------------------------------------------------------------------
/env/pointmaze/waypoint_controller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from . import q_iteration
3 | from .gridcraft import grid_env
4 | from .gridcraft import grid_spec
5 |
6 |
7 | ZEROS = np.zeros((2,), dtype=np.float32)
8 | ONES = np.ones((2,), dtype=np.float32)
9 |
10 |
11 | class WaypointController(object):
12 | def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0):
13 | self.maze_str = maze_str
14 | self._target = -1000 * ONES
15 |
16 | self.p_gain = p_gain
17 | self.d_gain = d_gain
18 | self.solve_thresh = solve_thresh
19 | self.vel_thresh = 0.1
20 |
21 | self._waypoint_idx = 0
22 | self._waypoints = []
23 | self._waypoint_prev_loc = ZEROS
24 |
25 | self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str))
26 |
27 | def current_waypoint(self):
28 | return self._waypoints[self._waypoint_idx]
29 |
30 | def get_action(self, location, velocity, target):
31 | if np.linalg.norm(self._target - np.array(self.gridify_state(target))) > 1e-3:
32 | #print('New target!', target, 'old:', self._target)
33 | self._new_target(location, target)
34 |
35 | dist = np.linalg.norm(location - self._target)
36 | vel = self._waypoint_prev_loc - location
37 | vel_norm = np.linalg.norm(vel)
38 | task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh)
39 |
40 | if task_not_solved:
41 | next_wpnt = self._waypoints[self._waypoint_idx]
42 | else:
43 | next_wpnt = self._target
44 |
45 | # Compute control
46 | prop = next_wpnt - location
47 | action = self.p_gain * prop + self.d_gain * velocity
48 |
49 | dist_next_wpnt = np.linalg.norm(location - next_wpnt)
50 |         if task_not_solved and (dist_next_wpnt < self.solve_thresh) and (vel_norm < self.vel_thresh):
--------------------------------------------------------------------------------
/env/pointmaze/q_iteration.py:
--------------------------------------------------------------------------------
27 |         pol_probs[adv_rew >= 0 ] = 1.0
28 | pol_probs[pol_probs < 0 ] = 0.0
29 | else:
30 | pol_probs = np.exp((1.0/ent_wt)*adv_rew)
31 | pol_probs /= np.sum(pol_probs, axis=1, keepdims=True)
32 | assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs)
33 | return pol_probs
34 |
35 |
36 | def softq_iteration(env, transition_matrix=None, reward_matrix=None, num_itrs=50, discount=0.99, ent_wt=0.1, warmstart_q=None, policy=None):
37 | """
38 | Perform tabular soft Q-iteration
39 | """
40 | dim_obs = env.num_states
41 | dim_act = env.num_actions
42 | if reward_matrix is None:
43 | reward_matrix = env.reward_matrix()
44 | reward_matrix = reward_matrix[:,:,0]
45 |
46 | if warmstart_q is None:
47 | q_fn = np.zeros((dim_obs, dim_act))
48 | else:
49 | q_fn = warmstart_q
50 |
51 | if transition_matrix is None:
52 | t_matrix = env.transition_matrix()
53 | else:
54 | t_matrix = transition_matrix
55 |
56 | for k in range(num_itrs):
57 | if policy is None:
58 | v_fn = logsumexp(q_fn, alpha=ent_wt)
59 | else:
60 | v_fn = np.sum((q_fn - ent_wt*np.log(policy))*policy, axis=1)
61 | new_q = reward_matrix + discount*t_matrix.dot(v_fn)
62 | q_fn = new_q
63 | return q_fn
64 |
65 |
66 | def q_iteration(env, **kwargs):
67 | return softq_iteration(env, ent_wt=0.0, **kwargs)
68 |
69 |
70 | def compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0):
71 | pol_probs = get_policy(q_fn, ent_wt=ent_wt)
72 |
73 | dim_obs = env.num_states
74 | dim_act = env.num_actions
75 | state_visitation = np.zeros((dim_obs, 1))
76 | for (state, prob) in env.initial_state_distribution.items():
77 | state_visitation[state] = prob
78 | t_matrix = env.transition_matrix() # S x A x S
79 | sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit))
80 |
81 | for i in range(env_time_limit):
82 | sa_visit = state_visitation * pol_probs
83 | # sa_visit_t[:, :, i] = (discount ** i) * sa_visit
84 | sa_visit_t[:, :, i] = sa_visit
85 | # sum-out (SA)S
86 | new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix)
87 | state_visitation = np.expand_dims(new_state_visitation, axis=1)
88 | return np.sum(sa_visit_t, axis=2) / float(env_time_limit)
89 |
90 |
91 | def compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0):
92 | pol_probs = get_policy(q_fn, ent_wt=ent_wt)
93 |
94 | dim_obs = env.num_states
95 | dim_act = env.num_actions
96 | state_visitation = np.zeros((dim_obs, 1))
97 | for (state, prob) in env.initial_state_distribution.items():
98 | state_visitation[state] = prob
99 | t_matrix = env.transition_matrix() # S x A x S
100 | sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit))
101 |
102 | for i in range(env_time_limit):
103 | sa_visit = state_visitation * pol_probs
104 | sa_visit_t[:, :, i] = (discount ** i) * sa_visit
105 | # sa_visit_t[:, :, i] = sa_visit
106 | # sum-out (SA)S
107 | new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix)
108 | state_visitation = np.expand_dims(new_state_visitation, axis=1)
109 | return np.sum(sa_visit_t, axis=2) #/ float(env_time_limit)
110 |
--------------------------------------------------------------------------------
/models/decoder/transposed_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import einops
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def initialize_weights(m):
7 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
8 | nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
9 | nn.init.constant_(m.bias.data, 0)
10 | elif isinstance(m, nn.Linear):
11 | nn.init.kaiming_uniform_(m.weight.data)
12 | nn.init.constant_(m.bias.data, 0)
13 |
14 | def horizontal_forward(network, x, input_shape=(-1,), output_shape=(-1,)):
15 | batch_with_horizon_shape = x.shape[: -len(input_shape)]
16 | if not batch_with_horizon_shape:
17 | batch_with_horizon_shape = (1,)
18 | x = x.reshape(-1, *input_shape)
19 | x = network(x)
20 | x = x.reshape(*batch_with_horizon_shape, *output_shape)
21 | return x
22 |
23 | def create_normal_dist(
24 | x,
25 | std=None,
26 | mean_scale=1,
27 | init_std=0,
28 | min_std=0.1,
29 | activation=None,
30 | event_shape=None,
31 | ):
32 |     if std is None:
33 | mean, std = torch.chunk(x, 2, -1)
34 | mean = mean / mean_scale
35 | if activation:
36 | mean = activation(mean)
37 | mean = mean_scale * mean
38 | std = F.softplus(std + init_std) + min_std
39 | else:
40 | mean = x
41 | dist = torch.distributions.Normal(mean, std)
42 | if event_shape:
43 | dist = torch.distributions.Independent(dist, event_shape)
44 | return dist
45 |
46 |
47 | class TransposedConvDecoder(nn.Module):
48 | def __init__(self, observation_shape=(3, 224, 224), emb_dim=512, activation=nn.ReLU, depth=32, kernel_size=5, stride=3):
49 | super().__init__()
50 |
51 | activation = activation()
52 | self.observation_shape = observation_shape
53 | self.depth = depth
54 | self.kernel_size = kernel_size
55 | self.stride = stride
56 | self.emb_dim = emb_dim
57 |
58 | self.network = nn.Sequential(
59 | nn.Linear(
60 | emb_dim, self.depth * 32
61 | ),
62 | nn.Unflatten(1, (self.depth * 32, 1)),
63 | nn.Unflatten(2, (1,1)),
64 | nn.ConvTranspose2d(
65 | self.depth * 32,
66 | self.depth * 8,
67 | self.kernel_size,
68 | self.stride,
69 | padding=1
70 | ),
71 | activation,
72 | nn.ConvTranspose2d(
73 | self.depth * 8,
74 | self.depth * 4,
75 | self.kernel_size,
76 | self.stride,
77 | padding=1
78 | ),
79 | activation,
80 | nn.ConvTranspose2d(
81 | self.depth * 4,
82 | self.depth * 2,
83 | self.kernel_size,
84 | self.stride,
85 | padding=1
86 | ),
87 | activation,
88 | nn.ConvTranspose2d(
89 | self.depth * 2,
90 | self.depth * 1,
91 | self.kernel_size,
92 | self.stride,
93 | padding=1
94 | ),
95 | activation,
96 | nn.ConvTranspose2d(
97 | self.depth * 1,
98 | self.observation_shape[0],
99 | self.kernel_size,
100 | self.stride,
101 | padding=1
102 | ),
103 | nn.Upsample(size=(observation_shape[1], observation_shape[2]), mode='bilinear', align_corners=False)
104 | )
105 | self.network.apply(initialize_weights)
106 |
107 | def forward(self, posterior):
108 | x = horizontal_forward(
109 | self.network, posterior, input_shape=[self.emb_dim],output_shape=self.observation_shape
110 | )
111 | dist = create_normal_dist(x, std=1, event_shape=len(self.observation_shape))
112 | img = dist.mean.squeeze(2)
113 | img = einops.rearrange(img, "b t c h w -> (b t) c h w")
114 | return img, torch.zeros(1).to(posterior.device) # dummy placeholder
--------------------------------------------------------------------------------
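
A shape-check sketch for the decoder; emb_dim=512 is the default above and the (B, T, emb_dim) latent layout follows horizontal_forward:

import torch
from models.decoder.transposed_conv import TransposedConvDecoder

decoder = TransposedConvDecoder(observation_shape=(3, 224, 224), emb_dim=512)
posterior = torch.randn(2, 3, 512)        # (B, T, emb_dim)
img, dummy_loss = decoder(posterior)
print(img.shape)                          # torch.Size([6, 3, 224, 224]), i.e. (B*T, C, H, W)
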
/env/pusht/pusht_wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import gym
4 | from env.pusht.pusht_env import PushTEnv
5 | from utils import aggregate_dct
6 |
7 | class PushTWrapper(PushTEnv):
8 | def __init__(
9 | self,
10 | with_velocity=True,
11 | with_target=True,
12 | ):
13 | super().__init__(
14 | with_velocity=with_velocity,
15 | with_target=with_target,
16 | )
17 | self.action_dim = self.action_space.shape[0]
18 |
19 | def sample_random_init_goal_states(self, seed):
20 | """
21 | Return two random states: one as the initial state and one as the goal state.
22 | """
23 | rs = np.random.RandomState(seed)
24 |
25 | def generate_state():
26 | if self.with_velocity:
27 | return np.array(
28 | [
29 | rs.randint(50, 450),
30 | rs.randint(50, 450),
31 | rs.randint(100, 400),
32 | rs.randint(100, 400),
33 | rs.randn() * 2 * np.pi - np.pi,
34 | 0,
35 | 0, # agent velocities default 0
36 | ]
37 | )
38 | else:
39 | return np.array(
40 | [
41 | rs.randint(50, 450),
42 | rs.randint(50, 450),
43 | rs.randint(100, 400),
44 | rs.randint(100, 400),
45 | rs.randn() * 2 * np.pi - np.pi,
46 | ]
47 | )
48 |
49 | init_state = generate_state()
50 | goal_state = generate_state()
51 |
52 | return init_state, goal_state
53 |
54 | def update_env(self, env_info):
55 | self.shape = env_info['shape']
56 |
57 | def eval_state(self, goal_state, cur_state):
58 | """
59 | Return True if the goal is reached
60 | [agent_x, agent_y, T_x, T_y, angle, agent_vx, agent_vy]
61 | """
62 | # if position difference is < 20, and angle difference < np.pi/9, then success
63 | pos_diff = np.linalg.norm(goal_state[:4] - cur_state[:4])
64 | angle_diff = np.abs(goal_state[4] - cur_state[4])
65 | angle_diff = np.minimum(angle_diff, 2 * np.pi - angle_diff)
66 | success = pos_diff < 20 and angle_diff < np.pi / 9
67 | state_dist = np.linalg.norm(goal_state - cur_state)
68 | return {
69 | 'success': success,
70 | 'state_dist': state_dist,
71 | }
72 |
73 | def prepare(self, seed, init_state):
74 | """
75 | Reset with controlled init_state
76 | obs: (H W C)
77 | state: (state_dim)
78 | """
79 | self.seed(seed)
80 | self.reset_to_state = init_state
81 | obs, state = self.reset()
82 | return obs, state
83 |
84 | def step_multiple(self, actions):
85 | """
86 | infos: dict, each key has shape (T, ...)
87 | """
88 | obses = []
89 | rewards = []
90 | dones = []
91 | infos = []
92 | for action in actions:
93 | o, r, d, info = self.step(action)
94 | obses.append(o)
95 | rewards.append(r)
96 | dones.append(d)
97 | infos.append(info)
98 | obses = aggregate_dct(obses)
99 | rewards = np.stack(rewards)
100 | dones = np.stack(dones)
101 | infos = aggregate_dct(infos)
102 | return obses, rewards, dones, infos
103 |
104 | def rollout(self, seed, init_state, actions):
105 | """
106 | only returns np arrays of observations and states
107 | seed: int
108 | init_state: (state_dim, )
109 | actions: (T, action_dim)
110 | obses: dict (T, H, W, C)
111 | states: (T, D)
112 | """
113 | obs, state = self.prepare(seed, init_state)
114 | obses, rewards, dones, infos = self.step_multiple(actions)
115 | for k in obses.keys():
116 | obses[k] = np.vstack([np.expand_dims(obs[k], 0), obses[k]])
117 | states = np.vstack([np.expand_dims(state, 0), infos["state"]])
118 | states = np.stack(states)
119 | return obses, states
120 |
--------------------------------------------------------------------------------
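
A short rollout sketch; it requires the PushT environment's own dependencies (e.g. pygame/pymunk), and the zero actions are placeholders:

import numpy as np
from env.pusht.pusht_wrapper import PushTWrapper

env = PushTWrapper(with_velocity=True, with_target=True)
init_state, goal_state = env.sample_random_init_goal_states(seed=0)
actions = np.zeros((10, env.action_dim))                  # placeholder action sequence
obses, states = env.rollout(seed=0, init_state=init_state, actions=actions)
print(states.shape, env.eval_state(goal_state, states[-1]))   # (11, state_dim), {'success': ..., 'state_dist': ...}
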
/env/pointmaze/gridcraft/wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .grid_env import REWARD, GridEnv
3 | from .wrappers import ObsWrapper
4 | from gym.spaces import Box
5 |
6 |
7 | class GridObsWrapper(ObsWrapper):
8 | def __init__(self, env):
9 | super(GridObsWrapper, self).__init__(env)
10 |
11 | def render(self):
12 | self.env.render()
13 |
14 |
15 |
16 | class EyesWrapper(ObsWrapper):
17 | def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8):
18 | super(EyesWrapper, self).__init__(env)
19 | self.types = types
20 | self.range = range
21 | self.angle_thresh = angle_thresh
22 |
23 | eyes_low = np.ones(5*len(types))
24 | eyes_high = np.ones(5*len(types))
25 | low = np.r_[env.observation_space.low, eyes_low]
26 | high = np.r_[env.observation_space.high, eyes_high]
27 | self.__observation_space = Box(low, high)
28 |
29 | def wrap_obs(self, obs, info=None):
30 | gs = self.env.gs # grid spec
31 | xy = gs.idx_to_xy(self.env.obs_to_state(obs))
32 | #xy = np.array([x, y])
33 |
34 | extra_obs = []
35 | for tile_type in self.types:
36 | idxs = gs.find(tile_type).astype(np.float32) # N x 2
37 | # gather all idxs that are close
38 | diffs = idxs-np.expand_dims(xy, axis=0)
39 | dists = np.linalg.norm(diffs, axis=1)
40 | valid_idxs = np.where(dists <= self.range)[0]
41 | if len(valid_idxs) == 0:
42 | eye_data = np.array([0,0,0,0,0], dtype=np.float32)
43 | else:
44 | diffs = diffs[valid_idxs, :]
45 | dists = dists[valid_idxs]+1e-6
46 | cosines = diffs[:,0]/dists
47 | cosines = np.r_[cosines, 0]
48 | sines = diffs[:,1]/dists
49 | sines = np.r_[sines, 0]
50 | on_target = 0.0
51 | if np.any(dists<=1.0):
52 | on_target = 1.0
53 | eye_data = np.abs(np.array([on_target, np.max(cosines), np.min(cosines), np.max(sines), np.min(sines)]))
54 | eye_data[np.where(eye_data<=self.angle_thresh)] = 0
55 | extra_obs.append(eye_data)
56 | extra_obs = np.concatenate(extra_obs)
57 | obs = np.r_[obs, extra_obs]
58 | return obs
59 |
60 | def unwrap_obs(self, obs, info=None):
61 | if len(obs.shape) == 1:
62 | return obs[:-5*len(self.types)]
63 | else:
64 | return obs[:,:-5*len(self.types)]
65 |
66 | @property
67 | def observation_space(self):
68 | return self.__observation_space
69 |
70 |
71 | """
72 | class CoordinateWiseWrapper(GridObsWrapper):
73 | def __init__(self, env):
74 | assert isinstance(env, GridEnv)
75 | super(CoordinateWiseWrapper, self).__init__(env)
76 | self.gs = env.gs
77 | self.dO = self.gs.width+self.gs.height
78 |
79 | self.__observation_space = Box(0, 1, self.dO)
80 |
81 | def wrap_obs(self, obs, info=None):
82 | state = one_hot_to_flat(obs)
83 | xy = self.gs.idx_to_xy(state)
84 | x = flat_to_one_hot(xy[0], self.gs.width)
85 | y = flat_to_one_hot(xy[1], self.gs.height)
86 | obs = np.r_[x, y]
87 | return obs
88 |
89 | def unwrap_obs(self, obs, info=None):
90 |
91 | if len(obs.shape) == 1:
92 | x = obs[:self.gs.width]
93 | y = obs[self.gs.width:]
94 | x = one_hot_to_flat(x)
95 | y = one_hot_to_flat(y)
96 | state = self.gs.xy_to_idx(np.c_[x,y])
97 | return flat_to_one_hot(state, self.dO)
98 | else:
99 | raise NotImplementedError()
100 | """
101 |
102 |
103 | class RandomObsWrapper(GridObsWrapper):
104 | def __init__(self, env, dO):
105 | assert isinstance(env, GridEnv)
106 | super(RandomObsWrapper, self).__init__(env)
107 | self.gs = env.gs
108 | self.dO = dO
109 | self.obs_matrix = np.random.randn(self.dO, len(self.gs))
110 | self.__observation_space = Box(np.min(self.obs_matrix), np.max(self.obs_matrix),
111 | shape=(self.dO,), dtype=np.float32)
112 |
113 | def wrap_obs(self, obs, info=None):
114 | return np.inner(self.obs_matrix, obs)
115 |
116 | def unwrap_obs(self, obs, info=None):
117 | raise NotImplementedError()
118 |
119 |
--------------------------------------------------------------------------------
/env/deformable_env/src/sim/assets/xarm/xarm_description/meshes/xarm6/collision/link6_vhacd.obj:
--------------------------------------------------------------------------------
1 | o convex_0
2 | v 0.033187 0.018503 -0.027247
3 | v -0.037691 -0.004112 -0.027247
4 | v -0.036680 -0.006625 -0.028254
5 | v 0.013576 -0.034758 0.000399
6 | v -0.012050 0.035092 0.000399
7 | v 0.001516 -0.051863 -0.020710
8 | v -0.035174 -0.014657 -0.000608
9 | v 0.037202 0.000916 0.000399
10 | v -0.009541 0.036598 -0.027749
11 | v 0.032685 -0.018178 -0.028254
12 | v 0.016093 0.034593 -0.000608
13 | v -0.032658 0.019510 -0.000608
14 | v -0.005016 -0.051863 -0.008652
15 | v -0.024620 -0.028233 -0.028254
16 | v 0.028153 -0.025719 -0.000608
17 | v -0.028642 0.025037 -0.027247
18 | v 0.013074 0.035092 -0.028254
19 | v 0.006542 -0.051863 -0.011667
20 | v -0.029646 -0.022706 0.000399
21 | v 0.028662 0.025046 -0.000608
22 | v 0.013576 -0.034758 -0.028254
23 | v 0.037711 -0.000091 -0.027749
24 | v -0.037691 0.004927 -0.000608
25 | v -0.034673 0.013984 -0.028254
26 | v 0.005029 0.037614 -0.000608
27 | v 0.036700 -0.009629 -0.000608
28 | v -0.031654 -0.021191 -0.027247
29 | v -0.007024 -0.051855 -0.014175
30 | v -0.022103 0.031072 -0.000608
31 | v -0.013054 -0.034767 0.000399
32 | v 0.022123 0.031072 -0.027247
33 | v 0.036700 0.009963 -0.000608
34 | v 0.028153 -0.025719 -0.027247
35 | v 0.001014 -0.051863 -0.007143
36 | v -0.014058 -0.034767 -0.028254
37 | v 0.005029 0.037614 -0.027247
38 | v -0.034171 0.014982 0.000399
39 | v 0.021119 0.030572 0.000399
40 | v -0.016073 0.034593 -0.027247
41 | v -0.028133 -0.025719 -0.000608
42 | v -0.008028 0.037098 -0.000608
43 | v 0.034692 -0.015664 -0.027247
44 | v 0.035194 0.012477 -0.028254
45 | v -0.005016 -0.051863 -0.019201
46 | v -0.036680 0.009955 -0.027247
47 | v 0.006542 -0.051863 -0.017193
48 | v -0.037691 -0.004112 -0.000608
49 | v 0.031674 -0.021191 -0.000608
50 | v -0.035174 -0.014657 -0.027247
51 | v -0.023616 0.029066 -0.028254
52 | v 0.033187 0.018503 -0.000608
53 | v 0.023636 -0.028732 0.000399
54 | v 0.028662 0.025046 -0.027247
55 | v 0.036700 0.009963 -0.027247
56 | v -0.005016 0.037614 -0.027247
57 | v 0.037711 0.004936 -0.000608
58 | v 0.037711 -0.004112 -0.027247
59 | v -0.032658 0.019510 -0.027247
60 | v -0.035676 0.012976 -0.000608
61 | v 0.022123 0.031072 -0.000608
62 | v -0.016073 0.034593 -0.000608
63 | v 0.020617 -0.031246 -0.028254
64 | v -0.028642 0.025037 -0.000608
65 | v 0.016093 0.034593 -0.027247
66 | f 36 11 64
67 | f 5 4 8
68 | f 3 10 14
69 | f 10 3 17
70 | f 13 6 18
71 | f 4 5 19
72 | f 14 10 21
73 | f 17 3 24
74 | f 3 14 27
75 | f 27 14 28
76 | f 4 19 30
77 | f 15 18 33
78 | f 13 18 34
79 | f 4 30 34
80 | f 30 13 34
81 | f 21 6 35
82 | f 14 21 35
83 | f 17 9 36
84 | f 25 11 36
85 | f 19 5 37
86 | f 5 29 37
87 | f 5 8 38
88 | f 25 5 38
89 | f 11 25 38
90 | f 19 7 40
91 | f 7 27 40
92 | f 27 28 40
93 | f 28 13 40
94 | f 13 30 40
95 | f 30 19 40
96 | f 5 25 41
97 | f 9 39 41
98 | f 33 10 42
99 | f 10 17 43
100 | f 22 10 43
101 | f 6 13 44
102 | f 13 28 44
103 | f 28 14 44
104 | f 35 6 44
105 | f 14 35 44
106 | f 3 2 45
107 | f 2 23 45
108 | f 24 3 45
109 | f 18 6 46
110 | f 33 18 46
111 | f 2 7 47
112 | f 7 19 47
113 | f 23 2 47
114 | f 19 37 47
115 | f 37 23 47
116 | f 26 8 48
117 | f 15 33 48
118 | f 42 26 48
119 | f 33 42 48
120 | f 2 3 49
121 | f 7 2 49
122 | f 3 27 49
123 | f 27 7 49
124 | f 9 17 50
125 | f 24 16 50
126 | f 17 24 50
127 | f 16 29 50
128 | f 39 9 50
129 | f 29 39 50
130 | f 1 20 51
131 | f 32 1 51
132 | f 8 32 51
133 | f 38 8 51
134 | f 20 38 51
135 | f 8 4 52
136 | f 18 15 52
137 | f 4 34 52
138 | f 34 18 52
139 | f 48 8 52
140 | f 15 48 52
141 | f 20 1 53
142 | f 17 31 53
143 | f 31 20 53
144 | f 43 17 53
145 | f 1 43 53
146 | f 1 32 54
147 | f 22 43 54
148 | f 43 1 54
149 | f 36 9 55
150 | f 25 36 55
151 | f 41 25 55
152 | f 9 41 55
153 | f 8 26 56
154 | f 32 8 56
155 | f 54 32 56
156 | f 22 54 56
157 | f 10 22 57
158 | f 42 10 57
159 | f 26 42 57
160 | f 56 26 57
161 | f 22 56 57
162 | f 12 16 58
163 | f 16 24 58
164 | f 24 45 58
165 | f 58 45 59
166 | f 37 12 59
167 | f 23 37 59
168 | f 45 23 59
169 | f 12 58 59
170 | f 31 11 60
171 | f 20 31 60
172 | f 38 20 60
173 | f 11 38 60
174 | f 29 5 61
175 | f 39 29 61
176 | f 5 41 61
177 | f 41 39 61
178 | f 6 21 62
179 | f 21 10 62
180 | f 10 33 62
181 | f 46 6 62
182 | f 33 46 62
183 | f 16 12 63
184 | f 29 16 63
185 | f 12 37 63
186 | f 37 29 63
187 | f 31 17 64
188 | f 11 31 64
189 | f 17 36 64
190 |
--------------------------------------------------------------------------------
/models/vit.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
2 | import torch
3 | from torch import nn
4 | from einops import rearrange, repeat
5 |
6 | # helpers
7 | NUM_FRAMES = 1
8 | NUM_PATCHES = 1
9 |
10 | def pair(t):
11 | return t if isinstance(t, tuple) else (t, t)
12 |
13 | def generate_mask_matrix(npatch, nwindow):
14 | zeros = torch.zeros(npatch, npatch)
15 | ones = torch.ones(npatch, npatch)
16 | rows = []
17 | for i in range(nwindow):
18 | row = torch.cat([ones] * (i+1) + [zeros] * (nwindow - i-1), dim=1)
19 | rows.append(row)
20 | mask = torch.cat(rows, dim=0).unsqueeze(0).unsqueeze(0)
21 | return mask
22 |
23 | class FeedForward(nn.Module):
24 | def __init__(self, dim, hidden_dim, dropout = 0.):
25 | super().__init__()
26 | self.net = nn.Sequential(
27 | nn.LayerNorm(dim),
28 | nn.Linear(dim, hidden_dim),
29 | nn.GELU(),
30 | nn.Dropout(dropout),
31 | nn.Linear(hidden_dim, dim),
32 | nn.Dropout(dropout)
33 | )
34 |
35 | def forward(self, x):
36 | return self.net(x)
37 |
38 | class Attention(nn.Module):
39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
40 | super().__init__()
41 | inner_dim = dim_head * heads
42 | project_out = not (heads == 1 and dim_head == dim)
43 |
44 | self.heads = heads
45 | self.scale = dim_head ** -0.5
46 |
47 | self.norm = nn.LayerNorm(dim)
48 |
49 | self.attend = nn.Softmax(dim = -1)
50 | self.dropout = nn.Dropout(dropout)
51 |
52 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
53 |
54 | self.to_out = nn.Sequential(
55 | nn.Linear(inner_dim, dim),
56 | nn.Dropout(dropout)
57 | ) if project_out else nn.Identity()
58 | self.bias = generate_mask_matrix(NUM_PATCHES, NUM_FRAMES).to('cuda')
59 |
60 | def forward(self, x):
61 | (
62 | B,
63 | T,
64 | C,
65 | ) = x.size()
66 | x = self.norm(x)
67 |
68 | qkv = self.to_qkv(x).chunk(3, dim = -1)
69 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
70 |
71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
72 | # apply causal mask
73 | dots = dots.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
74 |
75 | attn = self.attend(dots)
76 | attn = self.dropout(attn)
77 |
78 | out = torch.matmul(attn, v)
79 | out = rearrange(out, 'b h n d -> b n (h d)')
80 | return self.to_out(out)
81 |
82 | class Transformer(nn.Module):
83 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
84 | super().__init__()
85 | self.norm = nn.LayerNorm(dim)
86 | self.layers = nn.ModuleList([])
87 | for _ in range(depth):
88 | self.layers.append(nn.ModuleList([
89 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
90 | FeedForward(dim, mlp_dim, dropout = dropout)
91 | ]))
92 |
93 | def forward(self, x):
94 | for attn, ff in self.layers:
95 | x = attn(x) + x
96 | x = ff(x) + x
97 |
98 | return self.norm(x)
99 |
100 | class ViTPredictor(nn.Module):
101 | def __init__(self, *, num_patches, num_frames, dim, depth, heads, mlp_dim, pool='cls', dim_head=64, dropout=0., emb_dropout=0.):
102 | super().__init__()
103 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
104 |
105 | # update params for adding causal attention masks
106 | global NUM_FRAMES, NUM_PATCHES
107 | NUM_FRAMES = num_frames
108 | NUM_PATCHES = num_patches
109 |
110 | self.pos_embedding = nn.Parameter(torch.randn(1, num_frames * (num_patches), dim)) # dim for the pos encodings
111 | self.dropout = nn.Dropout(emb_dropout)
112 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
113 | self.pool = pool
114 |
115 | def forward(self, x): # x: (b, window_size * H/patch_size * W/patch_size, 384)
116 | b, n, _ = x.shape
117 | x = x + self.pos_embedding[:, :n]
118 | x = self.dropout(x)
119 | x = self.transformer(x)
120 | return x
--------------------------------------------------------------------------------
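A quick sketch of the block-causal mask that generate_mask_matrix above produces: each frame can attend to all patches of itself and of earlier frames only. The sizes are assumed toy values; the printed matrix is what the function returns for npatch=2, nwindow=3.

    import torch

    def generate_mask_matrix(npatch, nwindow):   # same logic as models/vit.py above
        zeros, ones = torch.zeros(npatch, npatch), torch.ones(npatch, npatch)
        rows = [torch.cat([ones] * (i + 1) + [zeros] * (nwindow - i - 1), dim=1)
                for i in range(nwindow)]
        return torch.cat(rows, dim=0).unsqueeze(0).unsqueeze(0)

    mask = generate_mask_matrix(npatch=2, nwindow=3)   # shape (1, 1, 6, 6)
    print(mask[0, 0])
    # rows 0-1 attend to columns 0-1, rows 2-3 to columns 0-3, rows 4-5 to all columns;
    # Attention.forward sets scores to -inf wherever this mask is 0 before the softmax.
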
/env/wall/wall_env_wrapper.py:
--------------------------------------------------------------------------------
1 | from .envs.wall import DotWall
2 | from .envs.wall import WallDatasetConfig
3 | import numpy as np
4 | import torch
5 | import random
6 | from torchvision import transforms
7 |
8 |
9 | from utils import aggregate_dct
10 | from .data.wall_utils import generate_wall_layouts
11 |
12 | ENV_ACTION_DIM = 2
13 | STATE_RANGES = np.array([
14 | [16.6840, 46.9885],
15 | [4.0083,25.2532]
16 | ])
17 |
18 | DEFAULT_CFG = WallDatasetConfig(
19 | action_angle_noise=0.2,
20 | action_step_mean=1.0,
21 | action_step_std=0.4,
22 | action_lower_bd=0.2,
23 | action_upper_bd=1.8,
24 | batch_size=64,
25 | device='cuda',
26 | dot_std=1.7,
27 | border_wall_loc=5,
28 | fix_wall_batch_k=None,
29 | fix_wall=True,
30 | fix_door_location=30,
31 | fix_wall_location=32,
32 | exclude_wall_train='',
33 | exclude_door_train='',
34 | only_wall_val='',
35 | only_door_val='',
36 | wall_padding=20,
37 | door_padding=10,
38 | wall_width=6,
39 | door_space=4,
40 | num_train_layouts=-1,
41 | img_size=65,
42 | max_step=1,
43 | n_steps=17,
44 | n_steps_reduce_factor=1,
45 | size=20000,
46 | val_size=10000,
47 | train=True,
48 | repeat_actions=1
49 | )
50 |
51 | resize_transform = transforms.Resize((224, 224))
52 | TRANSFORM = resize_transform
53 |
54 | class WallEnvWrapper(DotWall):
55 | def __init__(self, rng=42, wall_config=DEFAULT_CFG, fix_wall=True, cross_wall=False, fix_wall_location=32, fix_door_location=10, device='cpu', **kwargs):
56 | super().__init__(rng, wall_config, fix_wall, cross_wall, fix_wall_location=fix_wall_location, fix_door_location=fix_door_location, device=device,**kwargs)
57 | self.action_dim = ENV_ACTION_DIM
58 | self.transform = TRANSFORM
59 |
60 | def eval_state(self, goal_state, cur_state):
61 | success = np.linalg.norm(goal_state[:2] - cur_state[:2]) < 4.5
62 | state_dist = np.linalg.norm(goal_state - cur_state)
63 | return {
64 | 'success': success,
65 | 'state_dist': state_dist,
66 | }
67 |
68 | def sample_random_init_goal_states(self, seed):
69 | """
70 |         Return random initial and goal states
71 | """
72 | return self.generate_random_state(seed)
73 |
74 | def update_env(self, env_info): # change door and wall locations
75 | self.wall_config.fix_door_location = env_info["fix_door_location"].item()
76 | self.wall_config.fix_wall_location = env_info["fix_wall_location"].item()
77 | layouts, other_layouts = generate_wall_layouts(self.wall_config)
78 | self.layouts = layouts
79 | self.wall_x, self.hole_y = self._generate_wall()
80 |
81 | def prepare(self, seed, init_state):
82 | """
83 | Reset with controlled init_state
84 | """
85 | self.seed(seed)
86 | self.set_init_state(init_state)
87 | obs, state = self.reset()
88 | obs['visual'] = self.transform(obs['visual'])
89 | obs['visual'] = obs['visual'].permute(1, 2, 0)
90 | return obs, state
91 |
92 | def step_multiple(self, actions):
93 | obses = []
94 | rewards = []
95 | dones = []
96 | infos = []
97 | for action in actions:
98 | o, r, d, info = self.step(action)
99 | o['visual'] = self.transform(o['visual'])
100 | o['visual'] = o['visual'].permute(1, 2, 0)
101 | obses.append(o)
102 | rewards.append(r)
103 | dones.append(d)
104 | infos.append(info)
105 | obses = aggregate_dct(obses)
106 | rewards = np.stack(rewards)
107 | dones = np.stack(dones)
108 | infos = aggregate_dct(infos)
109 | return obses, rewards, dones, infos
110 |
111 | def rollout(self, seed, init_state, actions):
112 | """
113 | only returns np arrays of observations and states
114 | seed: int
115 | init_state: (state_dim, )
116 | actions: (T, action_dim)
117 | obses: dict (T, H, W, C)
118 | states: (T, D)
119 | """
120 | obs, state = self.prepare(seed, init_state)
121 | obses, rewards, dones, infos = self.step_multiple(actions)
122 | for k in obses.keys():
123 | obses[k] = np.vstack([np.expand_dims(obs[k], 0), obses[k]])
124 | states = np.vstack([np.expand_dims(state, 0), infos["state"]])
125 | states = np.stack(states)
126 | return obses, states
127 |
128 |
--------------------------------------------------------------------------------
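The rollout interface above prepends the initial observation and state, so it returns one more frame than the number of actions. A shape-level sketch under assumed toy dimensions; env stands for an already constructed WallEnvWrapper:

    import numpy as np

    T, action_dim = 5, 2                                  # assumed toy sizes
    actions = np.zeros((T, action_dim), dtype=np.float32)
    init_state = ...                                      # a state from sample_random_init_goal_states
    # obses, states = env.rollout(seed=0, init_state=init_state, actions=actions)
    # obses["visual"]: (T + 1, 224, 224, C)  -- resized by TRANSFORM and permuted to HWC
    # states:          (T + 1, state_dim)
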
/env/pointmaze/gridcraft/grid_spec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | EMPTY = 110
5 | WALL = 111
6 | START = 112
7 | REWARD = 113
8 | OUT_OF_BOUNDS = 114
9 | REWARD2 = 115
10 | REWARD3 = 116
11 | REWARD4 = 117
12 | LAVA = 118
13 | GOAL = 119
14 |
15 | TILES = {EMPTY, WALL, START, REWARD, REWARD2, REWARD3, REWARD4, LAVA, GOAL}
16 |
17 | STR_MAP = {
18 | 'O': EMPTY,
19 | '#': WALL,
20 | 'S': START,
21 | 'R': REWARD,
22 | '2': REWARD2,
23 | '3': REWARD3,
24 | '4': REWARD4,
25 | 'G': GOAL,
26 | 'L': LAVA
27 | }
28 |
29 | RENDER_DICT = {v:k for k, v in STR_MAP.items()}
30 | RENDER_DICT[EMPTY] = ' '
31 | RENDER_DICT[START] = ' '
32 |
33 |
34 |
35 | def spec_from_string(s, valmap=STR_MAP):
36 | if s.endswith('\\'):
37 | s = s[:-1]
38 | rows = s.split('\\')
39 | rowlens = np.array([len(row) for row in rows])
40 | assert np.all(rowlens == rowlens[0])
41 | w, h = len(rows), len(rows[0])#len(rows[0]), len(rows)
42 |
43 | gs = GridSpec(w, h)
44 | for i in range(w):
45 | for j in range(h):
46 | gs[i,j] = valmap[rows[i][j]]
47 | return gs
48 |
49 |
50 | def spec_from_sparse_locations(w, h, tile_to_locs):
51 | """
52 |
53 | Example usage:
54 | >> spec_from_sparse_locations(10, 10, {START: [(0,0)], REWARD: [(7,8), (8,8)]})
55 |
56 | """
57 | gs = GridSpec(w, h)
58 | for tile_type in tile_to_locs:
59 | locs = np.array(tile_to_locs[tile_type])
60 | for i in range(locs.shape[0]):
61 | gs[tuple(locs[i])] = tile_type
62 | return gs
63 |
64 |
65 | def local_spec(map, xpnt):
66 | """
67 | >>> local_spec("yOy\\\\Oxy", xpnt=(5,5))
68 | array([[4, 4],
69 | [6, 4],
70 | [6, 5]])
71 | """
72 | Y = 0; X=1; O=2
73 | valmap={
74 | 'y': Y,
75 | 'x': X,
76 | 'O': O
77 | }
78 | gs = spec_from_string(map, valmap=valmap)
79 | ys = gs.find(Y)
80 | x = gs.find(X)
81 | result = ys-x + np.array(xpnt)
82 | return result
83 |
84 |
85 |
86 | class GridSpec(object):
87 | def __init__(self, w, h):
88 | self.__data = np.zeros((w, h), dtype=np.int32)
89 | self.__w = w
90 | self.__h = h
91 |
92 | def __setitem__(self, key, val):
93 | self.__data[key] = val
94 |
95 | def __getitem__(self, key):
96 | if self.out_of_bounds(key):
97 | raise NotImplementedError("Out of bounds:"+str(key))
98 | return self.__data[tuple(key)]
99 |
100 | def out_of_bounds(self, wh):
101 | """ Return true if x, y is out of bounds """
102 | w, h = wh
103 | if w<0 or w>=self.__w:
104 | return True
105 | if h < 0 or h >= self.__h:
106 | return True
107 | return False
108 |
109 | def get_neighbors(self, k, xy=False):
110 | """ Return values of up, down, left, and right tiles """
111 | if not xy:
112 | k = self.idx_to_xy(k)
113 | offsets = [np.array([0,-1]), np.array([0,1]),
114 | np.array([-1,0]), np.array([1,0])]
115 | neighbors = \
116 | [self[k+offset] if (not self.out_of_bounds(k+offset)) else OUT_OF_BOUNDS for offset in offsets ]
117 | return neighbors
118 |
119 | def get_value(self, k, xy=False):
120 |         """ Return the value of the tile at k ((x, y) if xy=True, else a flat index) """
121 | if not xy:
122 | k = self.idx_to_xy(k)
123 | return self[k]
124 |
125 | def find(self, value):
126 | return np.array(np.where(self.spec == value)).T
127 |
128 | @property
129 | def spec(self):
130 | return self.__data
131 |
132 | @property
133 | def width(self):
134 | return self.__w
135 |
136 | def __len__(self):
137 | return self.__w*self.__h
138 |
139 | @property
140 | def height(self):
141 | return self.__h
142 |
143 | def idx_to_xy(self, idx):
144 | if hasattr(idx, '__len__'): # array
145 | x = idx % self.__w
146 | y = np.floor(idx/self.__w).astype(np.int32)
147 | xy = np.c_[x,y]
148 | return xy
149 | else:
150 | return np.array([ idx % self.__w, int(np.floor(idx/self.__w))])
151 |
152 | def xy_to_idx(self, key):
153 | shape = np.array(key).shape
154 | if len(shape) == 1:
155 | return key[0] + key[1]*self.__w
156 | elif len(shape) == 2:
157 | return key[:,0] + key[:,1]*self.__w
158 | else:
159 | raise NotImplementedError()
160 |
161 | def __hash__(self):
162 | data = (self.__w, self.__h) + tuple(self.__data.reshape([-1]).tolist())
163 | return hash(data)
164 |
--------------------------------------------------------------------------------
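A short sketch of building a GridSpec from a string and converting between flat indices and (x, y) coordinates with the module above. The maze layout is an assumed example, and the import path is a guess; adjust it to wherever grid_spec.py lives.

    from env.pointmaze.gridcraft.grid_spec import spec_from_string, START, GOAL  # path assumed

    gs = spec_from_string("SOO\\O#O\\OOG")   # rows separated by backslashes: start, one wall, goal
    print(gs.width, gs.height)               # 3 3
    print(gs.find(START), gs.find(GOAL))     # tile coordinates: [[0 0]] and [[2 2]]
    idx = gs.xy_to_idx((2, 2))               # flat index of cell (2, 2)
    print(gs.idx_to_xy(idx))                 # back to [2 2]
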
/env/pointmaze/dynamic_mjc.py:
--------------------------------------------------------------------------------
1 | """
2 | dynamic_mjc.py
3 | A small library for programmatically building MuJoCo XML files
4 | """
5 | from contextlib import contextmanager
6 | import tempfile
7 | import numpy as np
8 |
9 |
10 | def default_model(name):
11 | """
12 | Get a model with basic settings such as gravity and RK4 integration enabled
13 | """
14 | model = MJCModel(name)
15 | root = model.root
16 |
17 | # Setup
18 | root.compiler(angle="radian", inertiafromgeom="true")
19 | default = root.default()
20 | default.joint(armature=1, damping=1, limited="true")
21 | default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1')
22 | root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01)
23 | return model
24 |
25 | def pointmass_model(name):
26 | """
27 | Get a model with basic settings such as gravity and Euler integration enabled
28 | """
29 | model = MJCModel(name)
30 | root = model.root
31 |
32 | # Setup
33 | root.compiler(angle="radian", inertiafromgeom="true", coordinate="local")
34 | default = root.default()
35 | default.joint(limited="false", damping=1)
36 | default.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002")
37 | root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler")
38 | return model
39 |
40 |
41 | class MJCModel(object):
42 | def __init__(self, name):
43 | self.name = name
44 | self.root = MJCTreeNode("mujoco").add_attr('model', name)
45 |
46 | @contextmanager
47 | def asfile(self):
48 | """
49 | Usage:
50 | model = MJCModel('reacher')
51 | with model.asfile() as f:
52 |                 print(f.read()) # prints a dump of the model
53 | """
54 | with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f:
55 | self.root.write(f)
56 | f.seek(0)
57 | yield f
58 |
59 | def open(self):
60 | self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True)
61 | self.root.write(self.file)
62 | self.file.seek(0)
63 | return self.file
64 |
65 | def close(self):
66 | self.file.close()
67 |
68 | def find_attr(self, attr, value):
69 | return self.root.find_attr(attr, value)
70 |
71 | def __getstate__(self):
72 | return {}
73 |
74 | def __setstate__(self, state):
75 | pass
76 |
77 |
78 | class MJCTreeNode(object):
79 | def __init__(self, name):
80 | self.name = name
81 | self.attrs = {}
82 | self.children = []
83 |
84 | def add_attr(self, key, value):
85 | if isinstance(value, str):
86 | pass
87 | elif isinstance(value, list) or isinstance(value, np.ndarray):
88 | value = ' '.join([str(val).lower() for val in value])
89 | else:
90 | value = str(value).lower()
91 |
92 | self.attrs[key] = value
93 | return self
94 |
95 | def __getattr__(self, name):
96 | def wrapper(**kwargs):
97 | newnode = MJCTreeNode(name)
98 | for (k, v) in kwargs.items():
99 | newnode.add_attr(k, v)
100 | self.children.append(newnode)
101 | return newnode
102 | return wrapper
103 |
104 | def dfs(self):
105 | yield self
106 | if self.children:
107 | for child in self.children:
108 | for node in child.dfs():
109 | yield node
110 |
111 | def find_attr(self, attr, value):
112 | """ Run DFS to find a matching attr """
113 | if attr in self.attrs and self.attrs[attr] == value:
114 | return self
115 | for child in self.children:
116 | res = child.find_attr(attr, value)
117 | if res is not None:
118 | return res
119 | return None
120 |
121 |
122 | def write(self, ostream, tabs=0):
123 | contents = ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()])
124 | if self.children:
125 | ostream.write('\t'*tabs)
126 | ostream.write('<%s %s>\n' % (self.name, contents))
127 | for child in self.children:
128 | child.write(ostream, tabs=tabs+1)
129 | ostream.write('\t'*tabs)
130 |             ostream.write('</%s>\n' % self.name)
131 | else:
132 | ostream.write('\t'*tabs)
133 | ostream.write('<%s %s/>\n' % (self.name, contents))
134 |
135 | def __str__(self):
136 | s = "<"+self.name
137 | s += ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()])
138 | return s+">"
139 |
--------------------------------------------------------------------------------
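A small sketch of building MuJoCo XML programmatically with the module above, in the spirit of the asfile docstring. The body/geom names and attribute values are assumed examples, and the import path is a guess.

    from env.pointmaze.dynamic_mjc import pointmass_model   # path assumed

    model = pointmass_model("toy")
    worldbody = model.root.worldbody()                       # child elements are created via attribute calls
    particle = worldbody.body(name="particle", pos=[0.0, 0.0, 0.0])
    particle.geom(name="particle_geom", type="sphere", size=0.1, rgba="0.0 0.0 1.0 1")
    particle.joint(name="ball_x", type="slide", pos=[0, 0, 0], axis=[1, 0, 0])

    with model.asfile() as f:
        print(f.read())                                      # dumps the generated <mujoco> XML
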
/env/deformable_env/src/sim/sim_env/robot_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pyflex
4 |
5 | import pybullet as p
6 | from bs4 import BeautifulSoup
7 |
8 | from .transformations import quaternion_from_matrix, quaternion_matrix
9 |
10 | class FlexRobotHelper:
11 | def __init__(self):
12 | self.transform_bullet_to_flex = np.array([
13 | [1, 0, 0, 0],
14 | [0, 0, 1, 0],
15 | [0, -1, 0, 0],
16 | [0, 0, 0, 1]])
17 | self.robotId = None
18 |
19 | def loadURDF(self, fileName, basePosition, baseOrientation, useFixedBase = True, globalScaling = 1.0):
20 | if self.robotId is None:
21 | # print("Loading robot from file: ", fileName)
22 | # fileName = os.path.join('/home/gary/AdaptiGraph/src/', fileName)
23 | # print("Loading robot from file: ", fileName)
24 | self.robotId = p.loadURDF(fileName, basePosition, baseOrientation, useFixedBase = useFixedBase, globalScaling = globalScaling)
25 | p.resetBasePositionAndOrientation(self.robotId, basePosition, baseOrientation)
26 |
27 | robot_path = fileName # changed the urdf file
28 | robot_path_par = os.path.abspath(os.path.join(robot_path, os.pardir))
29 | with open(robot_path, 'r') as f:
30 | robot = f.read()
31 | robot_data = BeautifulSoup(robot, 'xml')
32 | links = robot_data.find_all('link')
33 |
34 | # add the mesh to pyflex
35 | self.num_meshes = 0
36 | self.has_mesh = np.ones(len(links), dtype=bool)
37 |
38 | """
39 | XARM6 with gripper:
40 | 0: base_link;
41 | 1 - 6: link1 - link6; (without gripper - 7: stick/finger)
42 |
43 | 7: base_link;
44 | 8: left outer knuckle;
45 | 9: left finger;
46 | 10: left inner knuckle;
47 | 11: right outer knuckle;
48 | 12: right finger;
49 | 13: right inner knuckle;
50 | """
51 | for i in range(len(links)):
52 | link = links[i]
53 | if link.find_all('geometry'):
54 | mesh_name = link.find_all('geometry')[0].find_all('mesh')[0].get('filename')
55 | pyflex.add_mesh(os.path.join(robot_path_par, mesh_name), globalScaling, 0, np.ones(3), np.zeros(3), np.zeros(4), False)
56 | self.num_meshes += 1
57 | else:
58 | self.has_mesh[i] = False
59 |
60 | self.num_link = len(links)
61 | self.state_pre = None
62 |
63 | return self.robotId
64 |
65 | def resetJointState(self, i, pose):
66 | p.resetJointState(self.robotId, i, pose)
67 | return self.getRobotShapeStates()
68 |
69 | def getRobotShapeStates(self):
70 | # convert pybullet link state to pyflex link state
71 | state_cur = []
72 | base_com_pos, base_com_orn = p.getBasePositionAndOrientation(self.robotId)
73 | di = p.getDynamicsInfo(self.robotId, -1)
74 | local_inertial_pos, local_inertial_orn = di[3], di[4]
75 |
76 | pos_inv, orn_inv = p.invertTransform(local_inertial_pos, local_inertial_orn)
77 | pos, orn = p.multiplyTransforms(base_com_pos, base_com_orn, pos_inv, orn_inv)
78 |
79 | state_cur.append(list(pos) + [1] + list(orn))
80 |
81 | for l in range(self.num_link-1):
82 | ls = p.getLinkState(self.robotId, l)
83 | pos = ls[4]
84 | orn = ls[5]
85 | state_cur.append(list(pos) + [1] + list(orn))
86 |
87 | state_cur = np.array(state_cur)
88 |
89 | shape_states = np.zeros((self.num_meshes, 14))
90 | if self.state_pre is None:
91 | self.state_pre = state_cur.copy()
92 |
93 | mesh_idx = 0
94 | for i in range(self.num_link):
95 | if self.has_mesh[i]:
96 |                 # position as homogeneous coordinates [x, y, z, 1]
97 | shape_states[mesh_idx, 0:3] = np.matmul(
98 | self.transform_bullet_to_flex, state_cur[i, :4])[:3]
99 | shape_states[mesh_idx, 3:6] = np.matmul(
100 | self.transform_bullet_to_flex, self.state_pre[i, :4])[:3]
101 | # orientation
102 | shape_states[mesh_idx, 6:10] = quaternion_from_matrix(
103 | np.matmul(self.transform_bullet_to_flex,
104 | quaternion_matrix(state_cur[i, 4:])))
105 | shape_states[mesh_idx, 10:14] = quaternion_from_matrix(
106 | np.matmul(self.transform_bullet_to_flex,
107 | quaternion_matrix(self.state_pre[i, 4:])))
108 | mesh_idx += 1
109 |
110 | self.state_pre = state_cur
111 | return shape_states
--------------------------------------------------------------------------------
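The 4x4 matrix at the top of FlexRobotHelper appears to convert PyBullet's z-up coordinates into PyFleX's y-up convention. A standalone numpy sketch of applying it to one homogeneous position (the input point is an assumed example):

    import numpy as np

    transform_bullet_to_flex = np.array([
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, -1, 0, 0],
        [0, 0, 0, 1]])

    p_bullet = np.array([0.2, 0.3, 0.5, 1.0])            # homogeneous [x, y, z, 1]
    p_flex = (transform_bullet_to_flex @ p_bullet)[:3]   # [0.2, 0.5, -0.3]: y and z swapped, with a sign flip
    print(p_flex)
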
/planning/gd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from einops import rearrange
4 | from .base_planner import BasePlanner
5 | from utils import move_to_device
6 |
7 |
8 | class GDPlanner(BasePlanner):
9 | def __init__(
10 | self,
11 | horizon,
12 | action_noise,
13 | sample_type,
14 | lr,
15 | opt_steps,
16 | eval_every,
17 | wm,
18 | action_dim,
19 | objective_fn,
20 | preprocessor,
21 | evaluator,
22 | wandb_run,
23 | logging_prefix="plan_0",
24 | log_filename="logs.json",
25 | **kwargs,
26 | ):
27 | super().__init__(
28 | wm,
29 | action_dim,
30 | objective_fn,
31 | preprocessor,
32 | evaluator,
33 | wandb_run,
34 | log_filename,
35 | )
36 | self.horizon = horizon
37 | self.action_noise = action_noise
38 | self.sample_type = sample_type
39 | self.lr = lr
40 | self.opt_steps = opt_steps
41 | self.eval_every = eval_every
42 | self.logging_prefix = logging_prefix
43 |
44 | def init_actions(self, obs_0, actions=None):
45 | """
46 | Initializes or appends actions for planning, ensuring the output shape is (b, self.horizon, action_dim).
47 | """
48 | n_evals = obs_0["visual"].shape[0]
49 | if actions is None:
50 | actions = torch.zeros(n_evals, 0, self.action_dim)
51 | device = actions.device
52 | t = actions.shape[1]
53 | remaining_t = self.horizon - t
54 |
55 | if remaining_t > 0:
56 | if self.sample_type == "randn":
57 | new_actions = torch.randn(n_evals, remaining_t, self.action_dim)
58 | elif self.sample_type == "zero": # zero action of env
59 | new_actions = torch.zeros(n_evals, remaining_t, self.action_dim)
60 | new_actions = rearrange(
61 | new_actions, "... (f d) -> ... f d", f=self.evaluator.frameskip
62 | )
63 | new_actions = self.preprocessor.normalize_actions(new_actions)
64 | new_actions = rearrange(new_actions, "... f d -> ... (f d)")
65 | actions = torch.cat([actions, new_actions.to(device)], dim=1)
66 | return actions
67 |
68 | def get_action_optimizer(self, actions):
69 | return torch.optim.SGD([actions], lr=self.lr)
70 |
71 | def plan(self, obs_0, obs_g, actions=None):
72 | """
73 | Args:
74 | actions: normalized
75 | Returns:
76 | actions: (B, T, action_dim) torch.Tensor
77 | """
78 | trans_obs_0 = move_to_device(
79 | self.preprocessor.transform_obs(obs_0), self.device
80 | )
81 | trans_obs_g = move_to_device(
82 | self.preprocessor.transform_obs(obs_g), self.device
83 | )
84 | z_obs_g = self.wm.encode_obs(trans_obs_g)
85 | z_obs_g_detached = {key: value.detach() for key, value in z_obs_g.items()}
86 |
87 | actions = self.init_actions(obs_0, actions).to(self.device)
88 | actions.requires_grad = True
89 | optimizer = self.get_action_optimizer(actions)
90 | n_evals = actions.shape[0]
91 |
92 | for i in range(self.opt_steps):
93 | optimizer.zero_grad()
94 | i_z_obses, i_zs = self.wm.rollout(
95 | obs_0=trans_obs_0,
96 | act=actions,
97 | )
98 | loss = self.objective_fn(i_z_obses, z_obs_g_detached) # (n_evals, )
99 | total_loss = loss.mean() * n_evals # loss for each eval is independent
100 | total_loss.backward()
101 | with torch.no_grad():
102 | actions_new = actions - optimizer.param_groups[0]["lr"] * actions.grad
103 | actions_new += (
104 | torch.randn_like(actions_new) * self.action_noise
105 | ) # Add Gaussian noise
106 | actions.copy_(actions_new)
107 |
108 | self.wandb_run.log(
109 | {f"{self.logging_prefix}/loss": total_loss.item(), "step": i + 1}
110 | )
111 | if self.evaluator is not None and i % self.eval_every == 0:
112 | logs, successes, _, _ = self.evaluator.eval_actions(
113 | actions.detach(), filename=f"{self.logging_prefix}_output_{i+1}"
114 | )
115 | logs = {f"{self.logging_prefix}/{k}": v for k, v in logs.items()}
116 | logs.update({"step": i + 1})
117 | self.wandb_run.log(logs)
118 | self.dump_logs(logs)
119 | if np.all(successes):
120 | break # terminate planning if all success
121 | return actions, np.full(n_evals, np.inf) # all actions are valid
122 |
--------------------------------------------------------------------------------
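The optimization loop in GDPlanner.plan above applies a plain gradient step plus exploration noise instead of calling optimizer.step(): a <- a - lr * grad(L(a)) + eps, with eps ~ N(0, action_noise^2). A toy sketch of the same update where the world-model rollout is replaced by a quadratic stand-in objective (all values are assumed):

    import torch

    lr, action_noise, opt_steps = 0.1, 0.01, 50       # assumed toy hyperparameters
    target = torch.tensor([0.5, -0.2])
    actions = torch.zeros(2, requires_grad=True)

    for _ in range(opt_steps):
        loss = ((actions - target) ** 2).sum()        # stand-in for the latent-rollout objective
        loss.backward()
        with torch.no_grad():
            actions -= lr * actions.grad                          # gradient step
            actions += torch.randn_like(actions) * action_noise   # exploration noise
            actions.grad.zero_()
    print(actions)                                    # close to target, up to the injected noise
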
/planning/cem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from einops import rearrange, repeat
4 | from .base_planner import BasePlanner
5 | from utils import move_to_device
6 |
7 |
8 | class CEMPlanner(BasePlanner):
9 | def __init__(
10 | self,
11 | horizon,
12 | topk,
13 | num_samples,
14 | var_scale,
15 | opt_steps,
16 | eval_every,
17 | wm,
18 | action_dim,
19 | objective_fn,
20 | preprocessor,
21 | evaluator,
22 | wandb_run,
23 | logging_prefix="plan_0",
24 | log_filename="logs.json",
25 | **kwargs,
26 | ):
27 | super().__init__(
28 | wm,
29 | action_dim,
30 | objective_fn,
31 | preprocessor,
32 | evaluator,
33 | wandb_run,
34 | log_filename,
35 | )
36 | self.horizon = horizon
37 | self.topk = topk
38 | self.num_samples = num_samples
39 | self.var_scale = var_scale
40 | self.opt_steps = opt_steps
41 | self.eval_every = eval_every
42 | self.logging_prefix = logging_prefix
43 |
44 | def init_mu_sigma(self, obs_0, actions=None):
45 | """
46 | actions: (B, T, action_dim) torch.Tensor, T <= self.horizon
47 | mu, sigma could depend on current obs, but obs_0 is only used for providing n_evals for now
48 | """
49 | n_evals = obs_0["visual"].shape[0]
50 | sigma = self.var_scale * torch.ones([n_evals, self.horizon, self.action_dim])
51 | if actions is None:
52 | mu = torch.zeros(n_evals, 0, self.action_dim)
53 | else:
54 | mu = actions
55 | device = mu.device
56 | t = mu.shape[1]
57 | remaining_t = self.horizon - t
58 |
59 | if remaining_t > 0:
60 | new_mu = torch.zeros(n_evals, remaining_t, self.action_dim)
61 | mu = torch.cat([mu, new_mu.to(device)], dim=1)
62 | return mu, sigma
63 |
64 | def plan(self, obs_0, obs_g, actions=None):
65 | """
66 | Args:
67 | actions: normalized
68 | Returns:
69 | actions: (B, T, action_dim) torch.Tensor, T <= self.horizon
70 | """
71 | trans_obs_0 = move_to_device(
72 | self.preprocessor.transform_obs(obs_0), self.device
73 | )
74 | trans_obs_g = move_to_device(
75 | self.preprocessor.transform_obs(obs_g), self.device
76 | )
77 | z_obs_g = self.wm.encode_obs(trans_obs_g)
78 |
79 | mu, sigma = self.init_mu_sigma(obs_0, actions)
80 | mu, sigma = mu.to(self.device), sigma.to(self.device)
81 | n_evals = mu.shape[0]
82 |
83 | for i in range(self.opt_steps):
84 | # optimize individual instances
85 | losses = []
86 | for traj in range(n_evals):
87 | cur_trans_obs_0 = {
88 | key: repeat(
89 | arr[traj].unsqueeze(0), "1 ... -> n ...", n=self.num_samples
90 | )
91 | for key, arr in trans_obs_0.items()
92 | }
93 | cur_z_obs_g = {
94 | key: repeat(
95 | arr[traj].unsqueeze(0), "1 ... -> n ...", n=self.num_samples
96 | )
97 | for key, arr in z_obs_g.items()
98 | }
99 | action = (
100 | torch.randn(self.num_samples, self.horizon, self.action_dim).to(
101 | self.device
102 | )
103 | * sigma[traj]
104 | + mu[traj]
105 | )
106 | action[0] = mu[traj] # optional: make the first one mu itself
107 | with torch.no_grad():
108 | i_z_obses, i_zs = self.wm.rollout(
109 | obs_0=cur_trans_obs_0,
110 | act=action,
111 | )
112 |
113 | loss = self.objective_fn(i_z_obses, cur_z_obs_g)
114 | topk_idx = torch.argsort(loss)[: self.topk]
115 | topk_action = action[topk_idx]
116 | losses.append(loss[topk_idx[0]].item())
117 | mu[traj] = topk_action.mean(dim=0)
118 | sigma[traj] = topk_action.std(dim=0)
119 |
120 | self.wandb_run.log(
121 | {f"{self.logging_prefix}/loss": np.mean(losses), "step": i + 1}
122 | )
123 | if self.evaluator is not None and i % self.eval_every == 0:
124 | logs, successes, _, _ = self.evaluator.eval_actions(
125 | mu, filename=f"{self.logging_prefix}_output_{i+1}"
126 | )
127 | logs = {f"{self.logging_prefix}/{k}": v for k, v in logs.items()}
128 | logs.update({"step": i + 1})
129 | self.wandb_run.log(logs)
130 | self.dump_logs(logs)
131 | if np.all(successes):
132 | break # terminate planning if all success
133 |
134 | return mu, np.full(n_evals, np.inf) # all actions are valid
135 |
--------------------------------------------------------------------------------
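A compact sketch of the cross-entropy-method loop that CEMPlanner.plan runs per trajectory: sample actions around (mu, sigma), score them, and refit (mu, sigma) to the top-k elites. The world-model rollout is replaced here by a toy quadratic objective, and all sizes are assumed:

    import torch

    horizon, action_dim = 4, 2                                 # assumed toy sizes
    num_samples, topk, opt_steps, var_scale = 64, 8, 20, 1.0   # assumed toy hyperparameters
    target = torch.randn(horizon, action_dim)

    mu = torch.zeros(horizon, action_dim)
    sigma = var_scale * torch.ones(horizon, action_dim)

    for _ in range(opt_steps):
        samples = torch.randn(num_samples, horizon, action_dim) * sigma + mu
        samples[0] = mu                                    # keep the current mean as a candidate, as above
        loss = ((samples - target) ** 2).sum(dim=(1, 2))   # stand-in objective, one score per sample
        elite = samples[torch.argsort(loss)[:topk]]
        mu, sigma = elite.mean(dim=0), elite.std(dim=0)

    print((mu - target).abs().max())                       # mu converges toward the toy target
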
/models/encoder/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 |
5 |
6 | class resnet18(nn.Module):
7 | def __init__(
8 | self,
9 | pretrained: bool = True,
10 | unit_norm: bool = False,
11 | ):
12 | super().__init__()
13 | resnet = torchvision.models.resnet18(pretrained=pretrained)
14 | self.resnet = nn.Sequential(*list(resnet.children())[:-1])
15 | self.flatten = nn.Flatten()
16 | self.pretrained = pretrained
17 | self.normalize = torchvision.transforms.Normalize(
18 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
19 | )
20 | self.unit_norm = unit_norm
21 |
22 | self.latent_ndim = 1
23 | self.emb_dim = 512
24 | self.name = "resnet"
25 |
26 | def forward(self, x):
27 | dims = len(x.shape)
28 | orig_shape = x.shape
29 | if dims == 3:
30 | x = x.unsqueeze(0)
31 | elif dims > 4:
32 | # flatten all dimensions to batch, then reshape back at the end
33 | x = x.reshape(-1, *orig_shape[-3:])
34 | x = self.normalize(x)
35 | out = self.resnet(x)
36 | out = self.flatten(out)
37 | if self.unit_norm:
38 | out = torch.nn.functional.normalize(out, p=2, dim=-1)
39 | if dims == 3:
40 | out = out.squeeze(0)
41 | elif dims > 4:
42 | out = out.reshape(*orig_shape[:-3], -1)
43 | out = out.unsqueeze(1)
44 | return out
45 |
46 |
47 | class resblock(nn.Module):
48 | # this implementation assumes square images
49 | def __init__(self, input_dim, output_dim, kernel_size, resample=None, hw=32):
50 | super(resblock, self).__init__()
51 | self.input_dim = input_dim
52 | self.output_dim = output_dim
53 | self.kernel_size = kernel_size
54 | self.resample = resample
55 |
56 | padding = int((kernel_size - 1) / 2)
57 |
58 | if resample == "down":
59 | self.skip = nn.Sequential(
60 | nn.AvgPool2d(2, stride=2),
61 | nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding),
62 | )
63 | self.conv1 = nn.Conv2d(
64 | input_dim, input_dim, kernel_size, padding=padding, bias=False
65 | )
66 | self.conv2 = nn.Sequential(
67 | nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding),
68 | nn.MaxPool2d(2, stride=2),
69 | )
70 | self.bn1 = nn.BatchNorm2d(input_dim)
71 | self.bn2 = nn.BatchNorm2d(output_dim)
72 | elif resample is None:
73 | self.skip = nn.Conv2d(input_dim, output_dim, 1)
74 | self.conv1 = nn.Conv2d(
75 | input_dim, output_dim, kernel_size, padding=padding, bias=False
76 | )
77 | self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size, padding=padding)
78 | self.bn1 = nn.BatchNorm2d(output_dim)
79 | self.bn2 = nn.BatchNorm2d(output_dim)
80 |
81 | self.leakyrelu1 = nn.LeakyReLU()
82 | self.leakyrelu2 = nn.LeakyReLU()
83 |
84 | def forward(self, x):
85 | if (self.input_dim == self.output_dim) and self.resample is None:
86 | idnty = x
87 | else:
88 | idnty = self.skip(x)
89 |
90 | residual = x
91 | residual = self.conv1(residual)
92 | residual = self.bn1(residual)
93 | residual = self.leakyrelu1(residual)
94 |
95 | residual = self.conv2(residual)
96 | residual = self.bn2(residual)
97 | residual = self.leakyrelu2(residual)
98 |
99 | return idnty + residual
100 |
101 |
102 | class SmallResNet(nn.Module):
103 | def __init__(self, output_dim=512):
104 | super(SmallResNet, self).__init__()
105 |
106 | self.hw = 224
107 |
108 | # 3x224x224
109 | self.rb1 = resblock(3, 16, 3, resample="down", hw=self.hw)
110 | # 16x112x112
111 | self.rb2 = resblock(16, 32, 3, resample="down", hw=self.hw // 2)
112 | # 32x56x56
113 | self.rb3 = resblock(32, 64, 3, resample="down", hw=self.hw // 4)
114 | # 64x28x28
115 | self.rb4 = resblock(64, 128, 3, resample="down", hw=self.hw // 8)
116 | # 128x14x14
117 | self.rb5 = resblock(128, 512, 3, resample="down", hw=self.hw // 16)
118 | # 512x7x7
119 | self.maxpool = nn.MaxPool2d(7)
120 | # 512x1x1
121 | self.flat = nn.Flatten()
122 |
123 | def forward(self, x):
124 | dims = len(x.shape)
125 | orig_shape = x.shape
126 | if dims == 3:
127 | x = x.unsqueeze(0)
128 | elif dims > 4:
129 | # flatten all dimensions to batch, then reshape back at the end
130 | x = x.reshape(-1, *orig_shape[-3:])
131 | x = self.rb1(x)
132 | x = self.rb2(x)
133 | x = self.rb3(x)
134 | x = self.rb4(x)
135 | x = self.rb5(x)
136 | x = self.maxpool(x)
137 | out = x.flatten(start_dim=-3)
138 | if dims == 3:
139 | out = out.squeeze(0)
140 | elif dims > 4:
141 | out = out.reshape(*orig_shape[:-3], -1)
142 | return out
143 |
--------------------------------------------------------------------------------
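Both encoders above flatten any extra leading dimensions into the batch axis and restore them after the backbone. A quick shape check for the plain image-batch case (the import path and tensor sizes are assumed):

    import torch
    from models.encoder.resnet import resnet18   # path assumed

    enc = resnet18(pretrained=False)
    x = torch.rand(8, 3, 224, 224)               # (batch, C, H, W) in [0, 1]
    print(enc(x).shape)                          # torch.Size([8, 512]): one emb_dim vector per image
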
/models/encoder/r3m/models/models_r3m.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import numpy as np
6 | from numpy.core.numeric import full
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.modules.activation import Sigmoid
10 | from torch.nn.modules.linear import Identity
11 | import torchvision
12 | from torchvision import transforms
13 | from .. import utils
14 | from pathlib import Path
15 | from torchvision.utils import save_image
16 | import torchvision.transforms as T
17 |
18 | epsilon = 1e-8
19 |
20 |
21 | def do_nothing(x):
22 | return x
23 |
24 |
25 | class R3M(nn.Module):
26 | def __init__(
27 | self,
28 | device,
29 | lr,
30 | hidden_dim,
31 | size=34,
32 | l2weight=1.0,
33 | l1weight=1.0,
34 | langweight=1.0,
35 | tcnweight=0.0,
36 | l2dist=True,
37 | bs=16,
38 | ):
39 | super().__init__()
40 |
41 | self.device = device
42 | self.use_tb = False
43 | self.l2weight = l2weight
44 | self.l1weight = l1weight
45 | self.tcnweight = tcnweight ## Weight on TCN loss (states closer in same clip closer in embedding)
46 | self.l2dist = l2dist ## Use -l2 or cosine sim
47 | self.langweight = langweight ## Weight on language reward
48 |         self.size = size ## ResNet size (18/34/50), or 0 for ViT
49 | self.num_negatives = 3
50 |
51 | self.latent_ndim = 1
52 | self.emb_dim = 512
53 | self.name = "r3m"
54 |
55 | ## Distances and Metrics
56 | self.cs = torch.nn.CosineSimilarity(1)
57 |         self.bce = nn.BCELoss(reduction="none")
58 | self.sigm = Sigmoid()
59 |
60 | params = []
61 | ######################################################################## Sub Modules
62 | ## Visual Encoder
63 | if size == 18:
64 | self.outdim = 512
65 | self.convnet = torchvision.models.resnet18(pretrained=False)
66 | elif size == 34:
67 | self.outdim = 512
68 | self.convnet = torchvision.models.resnet34(pretrained=False)
69 | elif size == 50:
70 | self.outdim = 2048
71 | self.convnet = torchvision.models.resnet50(pretrained=False)
72 | elif size == 0:
73 |             from transformers import AutoConfig, AutoModel
74 |
75 | self.outdim = 768
76 | self.convnet = AutoModel.from_config(
77 | config=AutoConfig.from_pretrained("google/vit-base-patch32-224-in21k")
78 | ).to(self.device)
79 |
80 | if self.size == 0:
81 | self.normlayer = transforms.Normalize(
82 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
83 | )
84 | else:
85 | self.normlayer = transforms.Normalize(
86 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
87 | )
88 | self.convnet.fc = Identity()
89 | self.convnet.train()
90 | params += list(self.convnet.parameters())
91 |
92 | ## Language Reward
93 | if self.langweight > 0.0:
94 | ## Pretrained DistilBERT Sentence Encoder
95 | from r3m.models.models_language import LangEncoder, LanguageReward
96 |
97 | self.lang_enc = LangEncoder(self.device, 0, 0)
98 | self.lang_rew = LanguageReward(
99 | None, self.outdim, hidden_dim, self.lang_enc.lang_size, simfunc=self.sim
100 | )
101 | params += list(self.lang_rew.parameters())
102 | ########################################################################
103 |
104 | ## Optimizer
105 | self.encoder_opt = torch.optim.Adam(params, lr=lr)
106 |
107 | def get_reward(self, e0, es, sentences):
108 |         ## Only callable if langweight was set to be > 0
109 | le = self.lang_enc(sentences)
110 | return self.lang_rew(e0, es, le)
111 |
112 | ## Forward Call (im --> representation)
113 | def forward(self, obs, num_ims=1, obs_shape=[3, 224, 224]):
114 | if obs_shape != [3, 224, 224]:
115 | preprocess = nn.Sequential(
116 | transforms.Resize(256),
117 | transforms.CenterCrop(224),
118 | self.normlayer,
119 | )
120 | else:
121 | preprocess = nn.Sequential(
122 | self.normlayer,
123 | )
124 |
125 |         ## Input must be [0, 1], [3,224,224]
126 | dims = len(obs.shape)
127 | orig_shape = obs.shape
128 | if dims == 3:
129 | obs = obs.unsqueeze(0)
130 | elif dims > 4:
131 | obs = obs.reshape(-1, *orig_shape[-3:])
132 | obs_p = preprocess(obs)
133 | h = self.convnet(obs_p)
134 | if dims == 3:
135 | h = h.squeeze(0)
136 | elif dims > 4:
137 | h = h.reshape(*orig_shape[:-3], -1)
138 | h = h.unsqueeze(1)
139 | return h
140 |
141 | def sim(self, tensor1, tensor2):
142 | if self.l2dist:
143 | d = -torch.linalg.norm(tensor1 - tensor2, dim=-1)
144 | else:
145 | d = self.cs(tensor1, tensor2)
146 | return d
147 |
--------------------------------------------------------------------------------
/models/encoder/r3m/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | from .models.models_r3m import R3M
6 |
7 | import os
8 | from os.path import expanduser
9 | import omegaconf
10 | import hydra
11 | import gdown
12 | import torch
13 | import copy
14 |
15 | VALID_ARGS = [
16 | "_target_",
17 | "device",
18 | "lr",
19 | "hidden_dim",
20 | "size",
21 | "l2weight",
22 | "l1weight",
23 | "langweight",
24 | "tcnweight",
25 | "l2dist",
26 | "bs",
27 | ]
28 | if torch.cuda.is_available():
29 | device = "cuda"
30 | else:
31 | device = "cpu"
32 |
33 |
34 | def cleanup_config(cfg):
35 | config = copy.deepcopy(cfg)
36 | keys = config.agent.keys()
37 | for key in list(keys):
38 | if key not in VALID_ARGS:
39 | del config.agent[key]
40 | config.agent["_target_"] = "models.encoder.r3m.R3M"
41 | config["device"] = device
42 |
43 | ## Hardcodes to remove the language head
44 | ## Assumes downstream use is as visual representation
45 | config.agent["langweight"] = 0
46 | return config.agent
47 |
48 |
49 | def remove_language_head(state_dict):
50 | keys = state_dict.keys()
51 | ## Hardcodes to remove the language head
52 | ## Assumes downstream use is as visual representation
53 | for key in list(keys):
54 | if ("lang_enc" in key) or ("lang_rew" in key):
55 | del state_dict[key]
56 | return state_dict
57 |
58 |
59 | def load_r3m(modelid):
60 | home = os.path.join(expanduser("~"), ".model_checkpoints", "r3m")
61 | if modelid == "resnet50":
62 | foldername = "r3m_50"
63 | modelurl = "https://drive.google.com/uc?id=1Xu0ssuG0N1zjZS54wmWzJ7-nb0-7XzbA"
64 | configurl = "https://drive.google.com/uc?id=10jY2VxrrhfOdNPmsFdES568hjjIoBJx8"
65 | elif modelid == "resnet34":
66 | foldername = "r3m_34"
67 | modelurl = "https://drive.google.com/uc?id=15bXD3QRhspIRacOKyWPw5y2HpoWUCEnE"
68 | configurl = "https://drive.google.com/uc?id=1RY0NS-Tl4G7M1Ik_lOym0b5VIBxX9dqW"
69 | elif modelid == "resnet18":
70 | foldername = "r3m_18"
71 | modelurl = "https://drive.google.com/uc?id=1A1ic-p4KtYlKXdXHcV2QV0cUzI4kn0u-"
72 | configurl = "https://drive.google.com/uc?id=1nitbHQ-GRorxc7vMUiEHjHWP5N11Jvc6"
73 | else:
74 | raise NameError("Invalid Model ID")
75 |
76 | if not os.path.exists(os.path.join(home, foldername)):
77 | os.makedirs(os.path.join(home, foldername))
78 | modelpath = os.path.join(home, foldername, "model.pt")
79 | configpath = os.path.join(home, foldername, "config.yaml")
80 | if not os.path.exists(modelpath):
81 | gdown.download(modelurl, modelpath, quiet=False)
82 | gdown.download(configurl, configpath, quiet=False)
83 |
84 | modelcfg = omegaconf.OmegaConf.load(configpath)
85 | cleancfg = cleanup_config(modelcfg)
86 | rep = hydra.utils.instantiate(cleancfg)
87 | rep = torch.nn.DataParallel(rep)
88 | r3m_state_dict = remove_language_head(
89 | torch.load(modelpath, map_location=torch.device(device))["r3m"]
90 | )
91 | rep.load_state_dict(r3m_state_dict)
92 | rep = rep.module
93 | return rep
94 |
95 | def load_r3m_reproduce(modelid):
96 | home = os.path.join(expanduser("~"), ".r3m")
97 | if modelid == "r3m":
98 | foldername = "original_r3m"
99 | modelurl = "https://drive.google.com/uc?id=1jLb1yldIMfAcGVwYojSQmMpmRM7vqjp9"
100 | configurl = "https://drive.google.com/uc?id=1cu-Pb33qcfAieRIUptNlG1AQIMZlAI-q"
101 | elif modelid == "r3m_noaug":
102 | foldername = "original_r3m_noaug"
103 | modelurl = "https://drive.google.com/uc?id=1k_ZlVtvlktoYLtBcfD0aVFnrZcyCNS9D"
104 | configurl = "https://drive.google.com/uc?id=1hPmJwDiWPkd6GGez6ywSC7UOTIX7NgeS"
105 | elif modelid == "r3m_nol1":
106 | foldername = "original_r3m_nol1"
107 | modelurl = "https://drive.google.com/uc?id=1LpW3aBMdjoXsjYlkaDnvwx7q22myM_nB"
108 | configurl = "https://drive.google.com/uc?id=1rZUBrYJZvlF1ReFwRidZsH7-xe7csvab"
109 | elif modelid == "r3m_nolang":
110 | foldername = "original_r3m_nolang"
111 | modelurl = "https://drive.google.com/uc?id=1FXcniRei2JDaGMJJ_KlVxHaLy0Fs_caV"
112 | configurl = "https://drive.google.com/uc?id=192G4UkcNJO4EKN46ECujMcH0AQVhnyQe"
113 | else:
114 | raise NameError("Invalid Model ID")
115 |
116 | if not os.path.exists(os.path.join(home, foldername)):
117 | os.makedirs(os.path.join(home, foldername))
118 | modelpath = os.path.join(home, foldername, "model.pt")
119 | configpath = os.path.join(home, foldername, "config.yaml")
120 | if not os.path.exists(modelpath):
121 | gdown.download(modelurl, modelpath, quiet=False)
122 | gdown.download(configurl, configpath, quiet=False)
123 |
124 | modelcfg = omegaconf.OmegaConf.load(configpath)
125 | cleancfg = cleanup_config(modelcfg)
126 | rep = hydra.utils.instantiate(cleancfg)
127 | rep = torch.nn.DataParallel(rep)
128 | r3m_state_dict = remove_language_head(
129 | torch.load(modelpath, map_location=torch.device(device))["r3m"]
130 | )
131 |
132 | rep.load_state_dict(r3m_state_dict)
133 | rep = rep.module
134 | return rep
135 |
--------------------------------------------------------------------------------
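A minimal usage sketch for the loader above. On first call the checkpoint and config are downloaded with gdown into ~/.model_checkpoints/r3m; the import path is assumed.

    import torch
    from models.encoder.r3m import load_r3m   # path assumed

    rep = load_r3m("resnet18")                # "resnet18", "resnet34", or "resnet50"
    rep.eval()
    with torch.no_grad():
        emb = rep(torch.rand(1, 3, 224, 224)) # images expected in [0, 1]
    print(emb.shape)                          # torch.Size([1, 512]) for the ResNet-18 variant
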
/datasets/point_maze_dset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import decord
3 | import numpy as np
4 | from pathlib import Path
5 | from einops import rearrange
6 | from decord import VideoReader
7 | from typing import Any, Callable, Optional
8 | from .traj_dset import TrajDataset, get_train_val_sliced
9 | 
10 | decord.bridge.set_bridge("torch")
11 |
12 | class PointMazeDataset(TrajDataset):
13 | def __init__(
14 | self,
15 | data_path: str = "data/point_maze",
16 | n_rollout: Optional[int] = None,
17 | transform: Optional[Callable] = None,
18 | normalize_action: bool = False,
19 | action_scale=1.0,
20 | ):
21 | self.data_path = Path(data_path)
22 | self.transform = transform
23 | self.normalize_action = normalize_action
24 | states = torch.load(self.data_path / "states.pth").float()
25 | self.states = states
26 | self.actions = torch.load(self.data_path / "actions.pth").float()
27 | self.actions = self.actions / action_scale # scaled back up in env
28 | self.seq_lengths = torch.load(self.data_path /'seq_lengths.pth')
29 |
30 | self.n_rollout = n_rollout
31 | if self.n_rollout:
32 | n = self.n_rollout
33 | else:
34 | n = len(self.states)
35 |
36 | self.states = self.states[:n]
37 | self.actions = self.actions[:n]
38 | self.seq_lengths = self.seq_lengths[:n]
39 | self.proprios = self.states.clone()
40 | print(f"Loaded {n} rollouts")
41 |
42 | self.action_dim = self.actions.shape[-1]
43 | self.state_dim = self.states.shape[-1]
44 | self.proprio_dim = self.proprios.shape[-1]
45 |
46 | if normalize_action:
47 | self.action_mean, self.action_std = self.get_data_mean_std(self.actions, self.seq_lengths)
48 | self.state_mean, self.state_std = self.get_data_mean_std(self.states, self.seq_lengths)
49 | self.proprio_mean, self.proprio_std = self.get_data_mean_std(self.proprios, self.seq_lengths)
50 | else:
51 | self.action_mean = torch.zeros(self.action_dim)
52 | self.action_std = torch.ones(self.action_dim)
53 | self.state_mean = torch.zeros(self.state_dim)
54 | self.state_std = torch.ones(self.state_dim)
55 | self.proprio_mean = torch.zeros(self.proprio_dim)
56 | self.proprio_std = torch.ones(self.proprio_dim)
57 |
58 | self.actions = (self.actions - self.action_mean) / self.action_std
59 | self.proprios = (self.proprios - self.proprio_mean) / self.proprio_std
60 |
61 | def get_data_mean_std(self, data, traj_lengths):
62 | all_data = []
63 | for traj in range(len(traj_lengths)):
64 | traj_len = traj_lengths[traj]
65 | traj_data = data[traj, :traj_len]
66 | all_data.append(traj_data)
67 | all_data = torch.vstack(all_data)
68 | data_mean = torch.mean(all_data, dim=0)
69 | data_std = torch.std(all_data, dim=0)
70 | return data_mean, data_std
71 |
72 | def get_seq_length(self, idx):
73 | return self.seq_lengths[idx]
74 |
75 | def get_all_actions(self):
76 | result = []
77 | for i in range(len(self.seq_lengths)):
78 | T = self.seq_lengths[i]
79 | result.append(self.actions[i, :T, :])
80 | return torch.cat(result, dim=0)
81 |
82 | def get_frames(self, idx, frames):
83 | obs_dir = self.data_path / "obses"
84 | image = torch.load(obs_dir / f"episode_{idx:03d}.pth")
85 | proprio = self.proprios[idx, frames]
86 | act = self.actions[idx, frames]
87 | state = self.states[idx, frames]
88 |
89 | image = image[frames] # THWC
90 | image = image / 255.0
91 | image = rearrange(image, "T H W C -> T C H W")
92 | if self.transform:
93 | image = self.transform(image)
94 | obs = {
95 | "visual": image,
96 | "proprio": proprio
97 | }
98 | return obs, act, state, {} # env_info
99 |
100 | def __getitem__(self, idx):
101 | return self.get_frames(idx, range(self.get_seq_length(idx)))
102 |
103 | def __len__(self):
104 | return len(self.seq_lengths)
105 |
106 | def preprocess_imgs(self, imgs):
107 | if isinstance(imgs, np.ndarray):
108 | raise NotImplementedError
109 | elif isinstance(imgs, torch.Tensor):
110 | return rearrange(imgs, "b h w c -> b c h w") / 255.0
111 |
112 | def load_point_maze_slice_train_val(
113 | transform,
114 | n_rollout=50,
115 |     data_path='data/point_maze',
116 | normalize_action=False,
117 | split_ratio=0.8,
118 | num_hist=0,
119 | num_pred=0,
120 | frameskip=0,
121 | ):
122 | dset = PointMazeDataset(
123 | n_rollout=n_rollout,
124 | transform=transform,
125 | data_path=data_path,
126 | normalize_action=normalize_action,
127 | )
128 | dset_train, dset_val, train_slices, val_slices = get_train_val_sliced(
129 | traj_dataset=dset,
130 | train_fraction=split_ratio,
131 | num_frames=num_hist + num_pred,
132 | frameskip=frameskip
133 | )
134 |
135 | datasets = {}
136 | datasets['train'] = train_slices
137 | datasets['valid'] = val_slices
138 | traj_dset = {}
139 | traj_dset['train'] = dset_train
140 | traj_dset['valid'] = dset_val
141 | return datasets, traj_dset
142 |
--------------------------------------------------------------------------------
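A short sketch of building sliced train/val datasets with the helper above. The path, rollout count, and frame parameters are assumed example values; the data directory is expected to hold states.pth, actions.pth, seq_lengths.pth and an obses/ folder.

    from datasets.point_maze_dset import load_point_maze_slice_train_val   # path assumed

    datasets, traj_dset = load_point_maze_slice_train_val(
        transform=None,
        n_rollout=50,
        data_path="data/point_maze",
        normalize_action=True,
        split_ratio=0.8,
        num_hist=3,
        num_pred=1,
        frameskip=5,
    )
    obs, act, state = datasets["train"][0]
    # obs["visual"]: (num_hist + num_pred, C, H, W)
    # act:           (num_hist + num_pred, action_dim * frameskip)
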
/planning/mpc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import hydra
3 | import copy
4 | import numpy as np
5 | from einops import rearrange, repeat
6 | from utils import slice_trajdict_with_t
7 | from .base_planner import BasePlanner
8 |
9 |
10 | class MPCPlanner(BasePlanner):
11 | """
12 |     An online planner: environment feedback is available between replanning iterations
13 | """
14 |
15 | def __init__(
16 | self,
17 | max_iter,
18 | n_taken_actions,
19 | sub_planner,
20 | wm,
21 | env, # for online exec
22 | action_dim,
23 | objective_fn,
24 | preprocessor,
25 | evaluator,
26 | wandb_run,
27 | logging_prefix="mpc",
28 | log_filename="logs.json",
29 | **kwargs,
30 | ):
31 | super().__init__(
32 | wm,
33 | action_dim,
34 | objective_fn,
35 | preprocessor,
36 | evaluator,
37 | wandb_run,
38 | log_filename,
39 | )
40 | self.env = env
41 | self.max_iter = np.inf if max_iter is None else max_iter
42 | self.n_taken_actions = n_taken_actions
43 | self.logging_prefix = logging_prefix
44 | sub_planner["_target_"] = sub_planner["target"]
45 | self.sub_planner = hydra.utils.instantiate(
46 | sub_planner,
47 | wm=self.wm,
48 | action_dim=self.action_dim,
49 | objective_fn=self.objective_fn,
50 | preprocessor=self.preprocessor,
51 | evaluator=self.evaluator, # evaluator is shared for mpc and sub_planner
52 | wandb_run=self.wandb_run,
53 | log_filename=None,
54 | )
55 | self.is_success = None
56 | self.action_len = None # keep track of the step each traj reaches success
57 | self.iter = 0
58 | self.planned_actions = []
59 |
60 | def _apply_success_mask(self, actions):
61 | device = actions.device
62 | mask = torch.tensor(self.is_success).bool()
63 | actions[mask] = 0
64 | masked_actions = rearrange(
65 | actions[mask], "... (f d) -> ... f d", f=self.evaluator.frameskip
66 | )
67 | masked_actions = self.preprocessor.normalize_actions(masked_actions.cpu())
68 | masked_actions = rearrange(masked_actions, "... f d -> ... (f d)")
69 | actions[mask] = masked_actions.to(device)
70 | return actions
71 |
72 | def plan(self, obs_0, obs_g, actions=None):
73 | """
74 | actions is NOT used
75 | Returns:
76 | actions: (B, T, action_dim) torch.Tensor
77 | """
78 | n_evals = obs_0["visual"].shape[0]
79 | self.is_success = np.zeros(n_evals, dtype=bool)
80 | self.action_len = np.full(n_evals, np.inf)
81 | init_obs_0, init_state_0 = self.evaluator.get_init_cond()
82 |
83 | cur_obs_0 = obs_0
84 | memo_actions = None
85 | while not np.all(self.is_success) and self.iter < self.max_iter:
86 | self.sub_planner.logging_prefix = f"plan_{self.iter}"
87 | actions, _ = self.sub_planner.plan(
88 | obs_0=cur_obs_0,
89 | obs_g=obs_g,
90 | actions=memo_actions,
91 | ) # (b, t, act_dim)
92 | taken_actions = actions.detach()[:, : self.n_taken_actions]
93 | self._apply_success_mask(taken_actions)
94 | memo_actions = actions.detach()[:, self.n_taken_actions :]
95 | self.planned_actions.append(taken_actions)
96 |
97 | print(f"MPC iter {self.iter} Eval ------- ")
98 | action_so_far = torch.cat(self.planned_actions, dim=1)
99 | self.evaluator.assign_init_cond(
100 | obs_0=init_obs_0,
101 | state_0=init_state_0,
102 | )
103 | logs, successes, e_obses, e_states = self.evaluator.eval_actions(
104 | action_so_far,
105 | self.action_len,
106 | filename=f"plan{self.iter}",
107 | save_video=True,
108 | )
109 | new_successes = successes & ~self.is_success # Identify new successes
110 | self.is_success = (
111 | self.is_success | successes
112 | ) # Update overall success status
113 | self.action_len[new_successes] = (
114 | (self.iter + 1) * self.n_taken_actions
115 | ) # Update only for the newly successful trajectories
116 |
117 | print("self.is_success: ", self.is_success)
118 | logs = {f"{self.logging_prefix}/{k}": v for k, v in logs.items()}
119 | logs.update({"step": self.iter + 1})
120 | self.wandb_run.log(logs)
121 | self.dump_logs(logs)
122 |
123 | # update evaluator's init conditions with new env feedback
124 | e_final_obs = slice_trajdict_with_t(e_obses, start_idx=-1)
125 | cur_obs_0 = e_final_obs
126 | e_final_state = e_states[:, -1]
127 | self.evaluator.assign_init_cond(
128 | obs_0=e_final_obs,
129 | state_0=e_final_state,
130 | )
131 | self.iter += 1
132 | self.sub_planner.logging_prefix = f"plan_{self.iter}"
133 |
134 | planned_actions = torch.cat(self.planned_actions, dim=1)
135 | self.evaluator.assign_init_cond(
136 | obs_0=init_obs_0,
137 | state_0=init_state_0,
138 | )
139 |
140 | return planned_actions, self.action_len
141 |
--------------------------------------------------------------------------------
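MPCPlanner above wraps the sub-planner in a receding-horizon loop: plan a full horizon, execute only the first n_taken_actions in the environment, then replan from the resulting state. A toy, self-contained sketch of that control flow (the planner and environment are 1-D stand-ins, and all values are assumed):

    import numpy as np

    def toy_plan(state, goal, horizon):
        # stand-in for sub_planner.plan: spread the remaining gap over the horizon
        return np.full(horizon, (goal - state) / horizon)

    state, goal = 0.0, 1.0
    horizon, n_taken_actions, max_iter = 8, 2, 20   # assumed toy values
    executed = []

    for it in range(max_iter):
        actions = toy_plan(state, goal, horizon)    # replan from the latest env feedback
        for a in actions[:n_taken_actions]:         # execute only the first few actions
            state += a                              # stand-in for env.step
            executed.append(a)
        if abs(goal - state) < 0.1:                 # success check ends the loop early, as above
            break

    print(len(executed), state)
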
/datasets/traj_dset.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from typing import Optional, Sequence, List
6 | from torch.utils.data import Dataset, Subset
7 | from torch import default_generator, randperm
8 | from einops import rearrange
9 |
10 | # https://github.com/JaidedAI/EasyOCR/issues/1243
11 | def _accumulate(iterable, fn=lambda x, y: x + y):
12 | "Return running totals"
13 | # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
14 | # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
15 | it = iter(iterable)
16 | try:
17 | total = next(it)
18 | except StopIteration:
19 | return
20 | yield total
21 | for element in it:
22 | total = fn(total, element)
23 | yield total
24 |
25 | class TrajDataset(Dataset, abc.ABC):
26 | @abc.abstractmethod
27 | def get_seq_length(self, idx):
28 | """
29 | Returns the length of the idx-th trajectory.
30 | """
31 | raise NotImplementedError
32 |
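# NOTE (illustrative sketch, not part of the original file): the interface a
# concrete TrajDataset is expected to expose, inferred from how
# TrajSlicerDataset (further below) consumes it; names and shapes are assumed.
class _ExampleTrajDataset(TrajDataset):
    proprio_dim, action_dim, state_dim = 4, 2, 6  # per-step feature sizes

    def __len__(self):
        return 8  # number of trajectories

    def get_seq_length(self, idx):
        return 50  # length T of trajectory `idx`

    def __getitem__(self, idx):
        T = self.get_seq_length(idx)
        obs = {"visual": torch.zeros(T, 3, 64, 64), "proprio": torch.zeros(T, self.proprio_dim)}
        act = torch.zeros(T, self.action_dim)
        state = torch.zeros(T, self.state_dim)
        return obs, act, state, {}  # the 4th item is ignored by TrajSlicerDataset
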
33 | class TrajSubset(TrajDataset, Subset):
34 | """
35 | Subset of a trajectory dataset at specified indices.
36 |
37 | Args:
38 | dataset (TrajDataset): the whole trajectory dataset
39 | indices (Sequence[int]): indices in the whole set selected for the subset
40 | """
41 | def __init__(self, dataset: TrajDataset, indices: Sequence[int]):
42 | Subset.__init__(self, dataset, indices)
43 |
44 | def get_seq_length(self, idx):
45 | return self.dataset.get_seq_length(self.indices[idx])
46 |
47 | def __getattr__(self, name):
48 | if hasattr(self.dataset, name):
49 | return getattr(self.dataset, name)
50 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
51 |
52 |
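# NOTE (illustrative sketch, not part of the original file): TrajSubset
# forwards unknown attributes to the wrapped dataset via __getattr__, so
# metadata such as action_dim stays accessible; never called, for reference only.
def _example_traj_subset(dataset: TrajDataset):
    subset = TrajSubset(dataset, indices=[0, 2, 5])
    length = subset.get_seq_length(1)  # length of trajectory 2 of the full dataset
    action_dim = subset.action_dim     # delegated to dataset.action_dim
    return subset, length, action_dim
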
53 | class TrajSlicerDataset(TrajDataset):
54 | def __init__(
55 | self,
56 | dataset: TrajDataset,
57 | num_frames: int,
58 | frameskip: int = 1,
59 | process_actions: str = "concat",
60 | ):
61 | self.dataset = dataset
62 | self.num_frames = num_frames
63 | self.frameskip = frameskip
64 | self.slices = []
65 | for i in range(len(self.dataset)):
66 | T = self.dataset.get_seq_length(i)
67 | if T - num_frames * frameskip < 0:
68 |     print(f"Ignored short sequence #{i}: len={T}, num_frames={num_frames}, frameskip={frameskip}")
69 | else:
70 | self.slices += [
71 | (i, start, start + num_frames * self.frameskip)
72 | for start in range(T - num_frames * frameskip + 1)
73 | ] # slice indices follow convention [start, end)
74 | # randomly permute the slices
75 | self.slices = np.random.permutation(self.slices)
76 |
77 | self.proprio_dim = self.dataset.proprio_dim
78 | if process_actions == "concat":
79 | self.action_dim = self.dataset.action_dim * self.frameskip
80 | else:  # NOTE: __getitem__ always concatenates the frameskip actions per frame, so this dim only matches when frameskip == 1
81 | self.action_dim = self.dataset.action_dim
82 |
83 | self.state_dim = self.dataset.state_dim
84 |
85 |
86 | def get_seq_length(self, idx: int) -> int:
87 | return self.num_frames
88 |
89 | def __len__(self):
90 | return len(self.slices)
91 |
92 | def __getitem__(self, idx):
93 | i, start, end = self.slices[idx]
94 | obs, act, state, _ = self.dataset[i]
95 | for k, v in obs.items():
96 | obs[k] = v[start:end:self.frameskip]
97 | state = state[start:end:self.frameskip]
98 | act = act[start:end]
99 | act = rearrange(act, "(n f) d -> n (f d)", n=self.num_frames) # concat actions
100 | return tuple([obs, act, state])
101 |
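# NOTE (illustrative sketch, not part of the original file): the slicing
# arithmetic above, worked through. For T = 7, num_frames = 3, frameskip = 2,
# valid window starts are range(7 - 3*2 + 1) = {0, 1}, giving slices (i, 0, 6)
# and (i, 1, 7); __getitem__ keeps every 2nd frame (e.g. frames 0, 2, 4) and
# concatenates the 2 actions per kept frame: (3*2, d) -> (3, 2*d). Never called.
def _example_slicing(dataset: TrajDataset):
    sliced = TrajSlicerDataset(dataset, num_frames=3, frameskip=2)
    obs, act, state = sliced[0]
    # obs[k]: (3, ...), act: (3, 2 * dataset.action_dim), state: (3, state_dim)
    return sliced, obs, act, state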
102 |
103 | def random_split_traj(
104 | dataset: TrajDataset,
105 | lengths: Sequence[int],
106 | generator: Optional[torch.Generator] = default_generator,
107 | ) -> List[TrajSubset]:
108 | if sum(lengths) != len(dataset): # type: ignore[arg-type]
109 | raise ValueError(
110 | "Sum of input lengths does not equal the length of the input dataset!"
111 | )
112 |
113 | indices = randperm(sum(lengths), generator=generator).tolist()
114 | print(
115 | [
116 | indices[offset - length : offset]
117 | for offset, length in zip(_accumulate(lengths), lengths)
118 | ]
119 | )
120 | return [
121 | TrajSubset(dataset, indices[offset - length : offset])
122 | for offset, length in zip(_accumulate(lengths), lengths)
123 | ]
124 |
125 |
126 | def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42):
127 | dataset_length = len(dataset)
128 | lengths = [
129 | int(train_fraction * dataset_length),
130 | dataset_length - int(train_fraction * dataset_length),
131 | ]
132 | train_set, val_set = random_split_traj(
133 | dataset, lengths, generator=torch.Generator().manual_seed(random_seed)
134 | )
135 | return train_set, val_set
136 |
137 |
138 | def get_train_val_sliced(
139 | traj_dataset: TrajDataset,
140 | train_fraction: float = 0.9,
141 | random_seed: int = 42,
142 | num_frames: int = 10,
143 | frameskip: int = 1,
144 | ):
145 | train, val = split_traj_datasets(
146 | traj_dataset,
147 | train_fraction=train_fraction,
148 | random_seed=random_seed,
149 | )
150 | train_slices = TrajSlicerDataset(train, num_frames, frameskip)
151 | val_slices = TrajSlicerDataset(val, num_frames, frameskip)
152 | return train, val, train_slices, val_slices
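
# NOTE (illustrative sketch, not part of the original file): typical end-to-end
# use of the helpers above: split into train/val, window both splits, and wrap
# them in DataLoaders. Batch size and worker count are placeholder values.
def _example_pipeline(dataset: TrajDataset):
    from torch.utils.data import DataLoader

    train, val, train_slices, val_slices = get_train_val_sliced(
        dataset, train_fraction=0.9, random_seed=42, num_frames=10, frameskip=1
    )
    train_loader = DataLoader(train_slices, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_slices, batch_size=32, shuffle=False, num_workers=4)
    obs, act, state = next(iter(train_loader))  # obs: dict of (32, 10, ...) tensors
    return train_loader, val_loader, obs, act, state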
--------------------------------------------------------------------------------