├── configs
│   ├── template_dp.yaml
│   ├── dataset
│   │   ├── libero_book_caddy.yaml
│   │   ├── robomimic_can_ph.yaml
│   │   ├── libero_90.yaml
│   │   ├── droid_mixture.yaml
│   │   ├── robomimic_lift_ph.yaml
│   │   ├── robomimic_square_ph.yaml
│   │   ├── libero_moka_moka.yaml
│   │   ├── libero_stove_moka.yaml
│   │   ├── libero_mug_microwave.yaml
│   │   ├── libero_cheese_butter.yaml
│   │   ├── libero_soup_sauce.yaml
│   │   ├── libero_bowl_drawer.yaml
│   │   ├── libero_soup_cheese.yaml
│   │   ├── libero_mug_mug.yaml
│   │   ├── libero_mug_pudding.yaml
│   │   ├── template_robomimic.yaml
│   │   ├── robomimic_tool_hang_ph.yaml
│   │   ├── webvideo_for_robomimic.yaml
│   │   ├── robomimic_transport_ph.yaml
│   │   ├── template_libero.yaml
│   │   ├── droid.yaml
│   │   ├── libero_90_webvideo.yaml
│   │   └── droid_webvideo.yaml
│   ├── finetune_gr1.yaml
│   ├── finetune_pad.yaml
│   ├── finetune_uwm.yaml
│   ├── train_gr1.yaml
│   ├── train_pad.yaml
│   ├── train_uwm.yaml
│   ├── train_dp.yaml
│   ├── finetune_dp.yaml
│   ├── train_gr1_robomimic.yaml
│   ├── train_pad_robomimic.yaml
│   ├── finetune_gr1_robomimic.yaml
│   ├── finetune_pad_robomimic.yaml
│   ├── train_dp_robomimic.yaml
│   ├── finetune_dp_robomimic.yaml
│   ├── train_uwm_robomimic.yaml
│   ├── finetune_uwm_robomimic.yaml
│   ├── template_finetune.yaml
│   ├── template_finetune_robomimic.yaml
│   ├── template_train.yaml
│   ├── model
│   │   ├── pad.yaml
│   │   ├── gr1.yaml
│   │   ├── uwm.yaml
│   │   └── dp.yaml
│   └── template_train_robomimic.yaml
├── models
│   ├── gr1
│   │   ├── __init__.py
│   │   ├── flamingo.py
│   │   ├── obs_encoder.py
│   │   └── gr1.py
│   ├── pad
│   │   ├── __init__.py
│   │   └── obs_encoder.py
│   ├── uwm
│   │   ├── __init__.py
│   │   └── obs_encoder.py
│   ├── dp
│   │   ├── __init__.py
│   │   ├── image_policy.py
│   │   ├── transformer.py
│   │   ├── base_policy.py
│   │   └── obs_encoder.py
│   └── common
│       ├── language.py
│       ├── transforms.py
│       └── adaln_attention.py
├── .gitignore
├── scripts
│   ├── launch_droid_pretrain.sh
│   ├── launch_droid_cotrain.sh
│   ├── setup_aws.sh
│   └── launch_droid_finetune.sh
├── datasets
│   ├── utils
│   │   ├── file_utils.py
│   │   ├── obs_utils.py
│   │   ├── loader.py
│   │   ├── sampler.py
│   │   ├── normalizer.py
│   │   ├── mixture.py
│   │   └── buffer.py
│   ├── webvideo
│   │   ├── __init__.py
│   │   └── dataset.py
│   ├── droid
│   │   ├── __init__.py
│   │   └── convert_dataset_zarr.py
│   └── robomimic
│       ├── __init__.py
│       └── dataset.py
├── pyproject.toml
├── requirements.txt
├── experiments
│   ├── utils.py
│   ├── uwm
│   │   ├── eval_robomimic.py
│   │   ├── train_webvideo.py
│   │   ├── ablate_inverse_dynamics.py
│   │   ├── ablate_forward_dynamics.py
│   │   ├── eval_droid.py
│   │   └── train_robomimic.py
│   ├── gr1
│   │   ├── train_robomimic.py
│   │   └── train.py
│   ├── pad
│   │   ├── train_robomimic.py
│   │   └── train.py
│   └── dp
│       └── train_robomimic.py
├── environments
│   └── robomimic
│       ├── __init__.py
│       └── wrappers.py
└── README.md
/configs/template_dp.yaml:
--------------------------------------------------------------------------------
1 | num_frames: 17
--------------------------------------------------------------------------------
/configs/dataset/libero_book_caddy.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
--------------------------------------------------------------------------------
/configs/dataset/robomimic_can_ph.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_robomimic
3 | - _self_
--------------------------------------------------------------------------------
/models/gr1/__init__.py:
--------------------------------------------------------------------------------
1 | from .gr1 import GR1
2 | from .obs_encoder import GR1ObservationEncoder
3 |
--------------------------------------------------------------------------------
/models/pad/__init__.py:
--------------------------------------------------------------------------------
1 | from .obs_encoder import PADObservationEncoder
2 | from .pad import PAD
3 |
--------------------------------------------------------------------------------
/models/uwm/__init__.py:
--------------------------------------------------------------------------------
1 | from .obs_encoder import UWMObservationEncoder
2 | from .uwm import UnifiedWorldModel
3 |
--------------------------------------------------------------------------------
/configs/finetune_gr1.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: gr1
4 | - template_finetune
5 | - _self_
--------------------------------------------------------------------------------
/configs/finetune_pad.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: pad
4 | - template_finetune
5 | - _self_
--------------------------------------------------------------------------------
/configs/finetune_uwm.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: uwm
4 | - template_finetune
5 | - _self_
--------------------------------------------------------------------------------
/configs/train_gr1.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: gr1
4 | - template_train
5 | - _self_
6 |
--------------------------------------------------------------------------------
/configs/train_pad.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: pad
4 | - template_train
5 | - _self_
6 |
--------------------------------------------------------------------------------
/configs/train_uwm.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: uwm
4 | - template_train
5 | - _self_
6 |
--------------------------------------------------------------------------------
/configs/train_dp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: dp
4 | - template_train
5 | - template_dp
6 | - _self_
--------------------------------------------------------------------------------
/configs/finetune_dp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: droid
3 | - model: dp
4 | - template_finetune
5 | - template_dp
6 | - _self_
--------------------------------------------------------------------------------
/configs/train_gr1_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_90
3 | - model: gr1
4 | - template_train_robomimic
5 | - _self_
--------------------------------------------------------------------------------
/configs/train_pad_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_90
3 | - model: pad
4 | - template_train_robomimic
5 | - _self_
--------------------------------------------------------------------------------
/configs/finetune_gr1_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_book_caddy
3 | - model: gr1
4 | - template_finetune_robomimic
5 | - _self_
--------------------------------------------------------------------------------
/configs/finetune_pad_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_book_caddy
3 | - model: pad
4 | - template_finetune_robomimic
5 | - _self_
--------------------------------------------------------------------------------
/configs/train_dp_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_90
3 | - model: dp
4 | - template_train_robomimic
5 | - template_dp
6 | - _self_
--------------------------------------------------------------------------------
/configs/finetune_dp_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_book_caddy
3 | - model: dp
4 | - template_finetune_robomimic
5 | - template_dp
6 | - _self_
--------------------------------------------------------------------------------
/configs/train_uwm_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_90
3 | - model: uwm
4 | - template_train_robomimic
5 | - _self_
6 |
7 | model:
8 | num_registers: 24
--------------------------------------------------------------------------------
/configs/finetune_uwm_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: libero_book_caddy
3 | - model: uwm
4 | - template_finetune_robomimic
5 | - _self_
6 |
7 | model:
8 | num_registers: 24
--------------------------------------------------------------------------------
/configs/dataset/libero_90.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_90
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_90/*.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_90/buffer.zarr
--------------------------------------------------------------------------------
/configs/dataset/droid_mixture.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - droid
3 | - _self_
4 |
5 | name: droid_mixture
6 | buffer_path: /tmp/weirdlab/zchuning/data/droid/buffer.zarr
7 | video_buffer_path: /tmp/weirdlab/zchuning/data/droid/video_buffer.zarr
8 | balance_datasets: False
--------------------------------------------------------------------------------
/configs/dataset/robomimic_lift_ph.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_robomimic
3 | - _self_
4 |
5 | name: robomimic_lift_ph
6 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/lift/ph/image_v141.hdf5
7 | buffer_path: /home/ubuntu/robomimic/datasets/lift/ph/image_v141.zarr
--------------------------------------------------------------------------------
/configs/dataset/robomimic_square_ph.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_robomimic
3 | - _self_
4 |
5 | name: robomimic_square_ph
6 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/square/ph/image_v141.hdf5
7 | buffer_path: /home/ubuntu/robomimic/datasets/square/ph/image_v141.zarr
--------------------------------------------------------------------------------
/configs/template_finetune.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_train
3 | - _self_
4 |
5 | pretrain_checkpoint_path: "logdir/${algo}/droid/benchmark/0/models.pt"
6 | num_steps: 10000
7 |
8 | clip_grad_norm: 1.0
9 | scheduler:
10 | name: "cosine"
11 | num_training_steps: ${num_steps}
12 | num_warmup_steps: 1000
--------------------------------------------------------------------------------
/models/dp/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_policy import NoisePredictionNet, DiffusionPolicy, FlowPolicy
2 | from .image_policy import ImageDiffusionPolicy, ImageFlowPolicy
3 | from .obs_encoder import ImageObservationEncoder
4 | from .transformer import TransformerNoisePredictionNet
5 | from .unet import UnetNoisePredictionNet
6 |
--------------------------------------------------------------------------------
/configs/dataset/libero_moka_moka.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_moka_moka
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE8_put_both_moka_pots_on_the_stove_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE8_put_both_moka_pots_on_the_stove_demo.zarr
--------------------------------------------------------------------------------
/configs/template_finetune_robomimic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_train_robomimic
3 | - _self_
4 |
5 | pretrain_checkpoint_path: "/home/ubuntu/efs/chuning/logdir/${algo}/libero_90/benchmark/0/models.pt"
6 | num_steps: 20000
7 |
8 | clip_grad_norm: 1.0
9 | scheduler:
10 | name: "cosine"
11 | num_training_steps: ${num_steps}
12 | num_warmup_steps: 1000
--------------------------------------------------------------------------------
/configs/dataset/libero_stove_moka.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_stove_moka
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_mug_microwave.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_mug_microwave
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE6_put_the_yellow_and_white_mug_in_the_microwave_and_close_it_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE6_put_the_yellow_and_white_mug_in_the_microwave_and_close_it_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_cheese_butter.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_cheese_butter
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_cream_cheese_box_and_the_butter_in_the_basket_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_cream_cheese_box_and_the_butter_in_the_basket_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_soup_sauce.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_soup_sauce
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_bowl_drawer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_bowl_drawer
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE4_put_the_black_bowl_in_the_bottom_drawer_of_the_cabinet_and_close_it_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE4_put_the_black_bowl_in_the_bottom_drawer_of_the_cabinet_and_close_it_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_soup_cheese.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_soup_cheese
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE1_put_both_the_alphabet_soup_and_the_cream_cheese_box_in_the_basket_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE1_put_both_the_alphabet_soup_and_the_cream_cheese_box_in_the_basket_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_mug_mug.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_mug_mug
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE5_put_the_white_mug_on_the_left_plate_and_put_the_yellow_and_white_mug_on_the_right_plate_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE5_put_the_white_mug_on_the_left_plate_and_put_the_yellow_and_white_mug_on_the_right_plate_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/libero_mug_pudding.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - template_libero
3 | - _self_
4 |
5 | name: libero_mug_pudding
6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE6_put_the_white_mug_on_the_plate_and_put_the_chocolate_pudding_to_the_right_of_the_plate_demo.hdf5
7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE6_put_the_white_mug_on_the_plate_and_put_the_chocolate_pudding_to_the_right_of_the_plate_demo.zarr
--------------------------------------------------------------------------------
/configs/dataset/template_robomimic.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.robomimic.make_robomimic_dataset
2 | name: robomimic_can_ph
3 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/can/ph/image_v141.hdf5
4 | buffer_path: /home/ubuntu/robomimic/datasets/can/ph/image_v141.zarr
5 | shape_meta:
6 | obs:
7 | agentview_image: &camera_meta
8 | shape: [84, 84, 3]
9 | type: rgb
10 | robot0_eye_in_hand_image: *camera_meta
11 | action:
12 | shape: [7] # pos + quaternion + gripper
13 | seq_len: ${num_frames}
14 | val_ratio: 0.05
--------------------------------------------------------------------------------
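
Each dataset config above is a Hydra node whose `_target_` is a factory returning a `(train_set, val_set)` pair (see `datasets/robomimic/__init__.py` below). A minimal sketch of how such a node could be instantiated, assuming the standard `hydra.main` entry-point pattern (the repo's actual training scripts are not shown in this dump):

```python
# Minimal sketch, not the repo's training script: compose a config and build the dataset.
import hydra
from omegaconf import OmegaConf


@hydra.main(config_path="configs", config_name="train_uwm_robomimic", version_base=None)
def main(cfg):
    # ${num_frames} and other interpolations resolve against the composed config.
    print(OmegaConf.to_yaml(cfg.dataset, resolve=True))
    # _target_ = datasets.robomimic.make_robomimic_dataset -> (train_set, val_set)
    train_set, val_set = hydra.utils.instantiate(cfg.dataset)
    print(len(train_set), len(val_set))


if __name__ == "__main__":
    main()
```
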
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # IDE config
31 | .idea
32 | .vscode
33 |
34 | # experiment data
35 | logdir/
36 | wandb/
37 | logs/
--------------------------------------------------------------------------------
/configs/dataset/robomimic_tool_hang_ph.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.robomimic.make_robomimic_dataset
2 | name: robomimic_tool_hang_ph
3 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/tool_hang/ph/image_v141.hdf5
4 | buffer_path: /home/ubuntu/robomimic/datasets/tool_hang/ph/image_v141.zarr
5 | shape_meta:
6 | obs:
7 | sideview_image: &camera_meta
8 | shape: [240, 240, 3]
9 | type: rgb
10 | robot0_eye_in_hand_image: *camera_meta
11 | action:
12 | shape: [7] # pos + quaternion + gripper
13 | seq_len: ${num_frames}
14 | val_ratio: 0.05
--------------------------------------------------------------------------------
/configs/dataset/webvideo_for_robomimic.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.webvideo.make_multiview_video_dataset
2 | name: webvideo_for_robomimic
3 | index_paths: [
4 | /gscratch/weirdlab/zchuning/data/k400_filtered.txt,
5 | /gscratch/weirdlab/zchuning/data/ssv2_filtered.txt,
6 | ]
7 | shape_meta:
8 | obs:
9 | agentview_image: &camera_meta
10 | shape: [84, 84, 3]
11 | type: rgb
12 | robot0_eye_in_hand_image: *camera_meta
13 | action:
14 | shape: [7] # pos + quaternion + gripper
15 | seq_len: ${num_frames}
16 | obs_padding: "same"
17 | frame_skip: 2
18 | val_ratio: 0.05
19 |
--------------------------------------------------------------------------------
/configs/template_train.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: logdir/${algo}/${dataset.name}/${exp_id}/${seed}
4 |
5 | algo: ${hydra:runtime.choices.model}
6 | exp_id: default
7 | seed: 0
8 | logdir: ${hydra:run.dir}
9 | resume: False
10 | use_amp: False
11 | eval_task_name: ${dataset.name}
12 | pretrain_checkpoint_path: null
13 |
14 | num_steps: 100000
15 | eval_every: 10000
16 | save_every: 10000
17 |
18 | batch_size: 36
19 | num_frames: 19
20 | obs_num_frames: 2
21 | clip_grad_norm: null
22 |
23 | optimizer:
24 | lr: 1e-4
25 | weight_decay: 1e-6
26 | betas: [0.9, 0.999]
27 | eps: 1e-8
28 |
29 | scheduler:
30 | name: "constant"
--------------------------------------------------------------------------------
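
The `optimizer` and `scheduler` blocks map onto `torch.optim.AdamW` and a named learning-rate schedule. The sketch below is one plausible way to consume them (an assumption, since the training scripts are not included in this dump); `diffusers.optimization.get_scheduler` accepts the same `name`, `num_warmup_steps`, and `num_training_steps` fields used by `template_finetune.yaml`:

```python
# Sketch only: turning the optimizer/scheduler config blocks into objects.
# Assumes a composed Hydra config `cfg` with the fields defined in template_train.yaml.
import torch
from diffusers.optimization import get_scheduler


def build_optimizer_and_scheduler(model: torch.nn.Module, cfg):
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg.optimizer.lr),
        weight_decay=float(cfg.optimizer.weight_decay),
        betas=tuple(cfg.optimizer.betas),
        eps=float(cfg.optimizer.eps),
    )
    # "constant" during pretraining; "cosine" with warmup in template_finetune.yaml.
    scheduler = get_scheduler(
        cfg.scheduler.name,
        optimizer=optimizer,
        num_warmup_steps=cfg.scheduler.get("num_warmup_steps", 0),
        num_training_steps=cfg.scheduler.get("num_training_steps", cfg.num_steps),
    )
    return optimizer, scheduler
```
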
/configs/dataset/robomimic_transport_ph.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.robomimic.make_robomimic_dataset
2 | name: robomimic_transport_ph
3 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/transport/ph/image_v141.hdf5
4 | buffer_path: /home/ubuntu/robomimic/datasets/transport/ph/image_v141.zarr
5 | shape_meta:
6 | obs:
7 | shouldercamera0_image: &camera_meta
8 | shape: [84, 84, 3]
9 | type: rgb
10 | shouldercamera1_image: *camera_meta
11 | robot0_eye_in_hand_image: *camera_meta
12 | robot1_eye_in_hand_image: *camera_meta
13 | action:
14 | shape: [14] # 2 * (pos + quaternion + gripper)
15 | seq_len: ${num_frames}
16 | val_ratio: 0.05
--------------------------------------------------------------------------------
/configs/dataset/template_libero.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.robomimic.make_robomimic_dataset
2 | name: libero_book_caddy
3 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/STUDY_SCENE1_pick_up_the_book_and_place_it_in_the_back_compartment_of_the_caddy_demo.hdf5
4 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/STUDY_SCENE1_pick_up_the_book_and_place_it_in_the_back_compartment_of_the_caddy_demo.zarr
5 | shape_meta:
6 | obs:
7 | agentview_rgb: &camera_meta
8 | shape: [128, 128, 3]
9 | type: rgb
10 | eye_in_hand_rgb: *camera_meta
11 | action:
12 | shape: [7] # pos + quaternion + gripper
13 | seq_len: ${num_frames}
14 | val_ratio: 0.05
15 | subsample_ratio: 1.0
16 | flip_rgb: True
--------------------------------------------------------------------------------
/configs/model/pad.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.pad.PAD
2 | action_len: 16
3 | action_dim: ${dataset.shape_meta.action.shape[0]}
4 | obs_encoder:
5 | _target_: models.pad.PADObservationEncoder
6 | shape_meta: ${dataset.shape_meta}
7 | num_frames: ${obs_num_frames}
8 | resize_shape: [240, 320]
9 | crop_shape: [224, 224]
10 | random_crop: True
11 | color_jitter:
12 | brightness: 0.2
13 | contrast: 0.2
14 | saturation: 0.2
15 | hue: [-0.2, 0.2]
16 | imagenet_norm: False
17 | embed_dim: 768
18 | timestep_embed_dim: 512
19 | latent_patch_shape: [2, 4, 4]
20 | depth: 12
21 | num_heads: 12
22 | mlp_ratio: 4
23 | qkv_bias: True
24 | num_train_steps: 100
25 | num_inference_steps: 10
26 | beta_schedule: squaredcos_cap_v2
27 | clip_sample: True
--------------------------------------------------------------------------------
/models/common/language.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import CLIPModel
4 |
5 |
6 | class CLIPTextEncoder(nn.Module):
7 | def __init__(self, embed_dim: int):
8 | super().__init__()
9 | self.language_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
10 | # Freeze pretrained model parameters
11 | for p in self.language_model.parameters():
12 | p.requires_grad = False
13 | self.head = nn.Linear(512, embed_dim)
14 |
15 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
16 | feats = self.language_model.get_text_features(
17 | input_ids=input_ids, attention_mask=attention_mask
18 | )
19 | feats = self.head(feats)
20 | return feats
21 |
--------------------------------------------------------------------------------
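
`CLIPTextEncoder` expects pre-tokenized text; the 512-dim input of `self.head` matches the text projection size of `clip-vit-base-patch32`. A brief usage sketch (the task string is just an example):

```python
# Usage sketch for CLIPTextEncoder (assumed calling convention).
import torch
from transformers import CLIPTokenizer

from models.common.language import CLIPTextEncoder

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
encoder = CLIPTextEncoder(embed_dim=768)

tokens = tokenizer(
    ["pick up the book and place it in the caddy"],
    padding=True,
    return_tensors="pt",
)
with torch.no_grad():
    feats = encoder(tokens["input_ids"], tokens["attention_mask"])
print(feats.shape)  # torch.Size([1, 768])
```
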
/configs/dataset/droid.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.droid.make_droid_dataset
2 | name: droid
3 |
4 | buffer_path: /tmp/weirdlab/zchuning/data/droid/buffer.zarr
5 | shape_meta:
6 | # acceptable types: rgb, low_dim
7 | obs:
8 | # camera
9 | "exterior_image_1_left": &camera_meta
10 | shape: [180, 320, 3]
11 | type: rgb
12 | "exterior_image_2_left": *camera_meta
13 | "wrist_image_left": *camera_meta
14 |
15 | # lowdim
16 | "cartesian_position":
17 | shape: [6]
18 | type: low_dim
19 | "gripper_position":
20 | shape: [1]
21 | type: low_dim
22 |
23 | action:
24 | # 3(pos) + 6(rot) + 1(gripper)
25 | shape: [10]
26 |
27 | seq_len: ${num_frames}
28 | history_len: ${obs_num_frames}
29 | normalize_action: True
30 | normalize_lowdim: True
31 | val_ratio: 0.05
--------------------------------------------------------------------------------
/configs/template_train_robomimic.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: /home/ubuntu/efs/chuning/logdir/${algo}/${dataset.name}/${exp_id}/${seed}
4 |
5 | algo: ${hydra:runtime.choices.model}
6 | exp_id: default
7 | seed: 0
8 | logdir: ${hydra:run.dir}
9 | resume: False
10 | use_amp: False
11 | eval_task_name: ${dataset.name}
12 | pretrain_checkpoint_path: null
13 |
14 | num_steps: 100000
15 | eval_every: 10000
16 | save_every: 10000
17 | rollout_every: 10000
18 | num_rollouts: 50
19 | rollout_length: 1000
20 |
21 | batch_size: 36
22 | num_frames: 19
23 | obs_num_frames: 2
24 | clip_grad_norm: null
25 |
26 | optimizer:
27 | lr: 1e-4
28 | weight_decay: 1e-6
29 | betas: [0.9, 0.999]
30 | eps: 1e-8
31 |
32 | scheduler:
33 | name: "constant"
34 |
35 | # Override resize shape
36 | model:
37 | obs_encoder:
38 | resize_shape: [240, 240]
39 |
--------------------------------------------------------------------------------
/scripts/launch_droid_pretrain.sh:
--------------------------------------------------------------------------------
1 | # Cache dataset
2 | DATA_DIR="/gscratch/weirdlab/memmelma/data/"
3 | BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_weird.zarr"
4 | if [ ! -d $BUFFER_PATH ]; then
5 | # Cache dataset
6 | echo "Caching dataset..."
7 | python datasets/droid/convert_dataset_zarr.py --data_dir $DATA_DIR --buffer_path $BUFFER_PATH --num_episodes 2000 --num_workers 8 --filter_key WEIRD
8 | fi
9 |
10 | # UWM
11 | python experiments/uwm/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH
12 |
13 | # DP
14 | # python experiments/dp/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH
15 |
16 | # GR1
17 | # python experiments/gr1/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH
18 |
19 | # PAD
20 | # python experiments/pad/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH
--------------------------------------------------------------------------------
/datasets/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from os.path import expanduser, expandvars
3 | from typing import Union
4 |
5 | from omegaconf.listconfig import ListConfig
6 |
7 |
8 | def glob_all(globs: Union[str, list[str], ListConfig]) -> list[str]:
9 | """
10 | Expand a list of glob strings into a list of file paths.
11 |
12 | Args:
13 | globs: A glob string or a list of glob strings.
14 |
15 | Returns:
16 | A list of file paths matching the glob strings.
17 | """
18 |
19 | if isinstance(globs, str):
20 | globs = [globs]
21 | elif isinstance(globs, ListConfig):
22 | globs = list(globs)
23 | assert isinstance(globs, list)
24 |
25 | files = []
26 | for glob_str in globs:
27 | glob_str = expandvars(expanduser(glob_str))
28 | files.extend(glob.glob(glob_str))
29 |
30 | return sorted(files)
31 |
--------------------------------------------------------------------------------
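
`glob_all` is what expands the `hdf5_path_globs` entries from the dataset configs into concrete, sorted file lists. For example:

```python
# Example: expand the libero_90 glob from configs/dataset/libero_90.yaml.
from datasets.utils.file_utils import glob_all

paths = glob_all("~/LIBERO/libero/datasets/libero_90/*.hdf5")
print(len(paths), paths[:2])  # sorted list of matching files (empty if nothing matches)

# A list of globs (or an OmegaConf ListConfig) is also accepted; paths here are illustrative.
paths = glob_all(["$HOME/data/a/*.hdf5", "$HOME/data/b/*.hdf5"])
```
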
/configs/model/gr1.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.gr1.GR1
2 | action_len: 16
3 | action_dim: ${dataset.shape_meta.action.shape[0]}
4 | obs_encoder:
5 | _target_: models.gr1.GR1ObservationEncoder
6 | shape_meta: ${dataset.shape_meta}
7 | num_frames: ${obs_num_frames}
8 | embed_dim: 768
9 | resize_shape: [240, 320]
10 | crop_shape: [224, 224]
11 | random_crop: True
12 | color_jitter:
13 | brightness: 0.2
14 | contrast: 0.2
15 | saturation: 0.2
16 | hue: [-0.2, 0.2]
17 | imagenet_norm: True
18 | resampler_params:
19 | dim: 768
20 | depth: 3
21 | dim_head: 128
22 | heads: 4
23 | num_latents: 9
24 | num_media_embeds: 1
25 | embed_dim: 768
26 | image_size: 224
27 | patch_size: 32
28 | depth: 12
29 | num_heads: 12
30 | mlp_ratio: 4
31 | qkv_bias: True
32 | decoder_depth: 3
33 | decoder_num_heads: 16
34 | decoder_mlp_ratio: 4
35 | decoder_qkv_bias: True
--------------------------------------------------------------------------------
/configs/model/uwm.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.uwm.UnifiedWorldModel
2 | action_len: 16
3 | action_dim: ${dataset.shape_meta.action.shape[0]}
4 | obs_encoder:
5 | _target_: models.uwm.UWMObservationEncoder
6 | shape_meta: ${dataset.shape_meta}
7 | num_frames: ${obs_num_frames}
8 | embed_dim: 768
9 | resize_shape: [240, 320]
10 | crop_shape: [224, 224]
11 | random_crop: True
12 | color_jitter:
13 | brightness: 0.2
14 | contrast: 0.2
15 | saturation: 0.2
16 | hue: [-0.2, 0.2]
17 | imagenet_norm: False
18 | vision_backbone: resnet
19 | use_low_dim: False
20 | use_language: False
21 | embed_dim: 768
22 | timestep_embed_dim: 512
23 | latent_patch_shape: [2, 4, 4]
24 | depth: 12
25 | num_heads: 12
26 | mlp_ratio: 4
27 | qkv_bias: True
28 | num_registers: 8
29 | num_train_steps: 100
30 | num_inference_steps: 10
31 | beta_schedule: squaredcos_cap_v2
32 | clip_sample: True
--------------------------------------------------------------------------------
/datasets/webvideo/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, random_split
2 | from torch.utils.data.distributed import DistributedSampler
3 |
4 | from .dataset import MultiviewVideoDataset
5 |
6 |
7 | def make_multiview_video_dataset(
8 | name: str,
9 | index_paths: list[str],
10 | shape_meta: dict,
11 | seq_len: int,
12 | frame_skip: int = 2,
13 | obs_padding: str = "same",
14 | val_ratio: float = 0.0,
15 | ):
16 | dataset = MultiviewVideoDataset(
17 | index_paths=index_paths,
18 | shape_meta=shape_meta,
19 | clip_len=seq_len,
20 | frame_skip=frame_skip,
21 | obs_padding=obs_padding,
22 | )
23 |
24 | train_size = int(len(dataset) * (1 - val_ratio))
25 | val_size = len(dataset) - train_size
26 | train_set, val_set = random_split(dataset, [train_size, val_size])
27 | return train_set, val_set
28 |
--------------------------------------------------------------------------------
/datasets/utils/obs_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 |
5 |
6 | def unflatten_obs(x: dict[str, torch.Tensor]) -> dict[str, Union[dict, torch.Tensor]]:
7 | """
8 | Unflattens entries with keys starting with "obs." as follows:
9 |
10 | x["obs.some_name"] -> x["obs"]["some_name"]
11 |
12 | Args:
13 | x: A dictionary of tensors.
14 |
15 | Returns:
16 | The same dictionary (modified in place), with keys starting with "obs." unflattened.
17 | """
18 |
19 | obs = {}
20 | to_delete = []
21 | for key, value in x.items():
22 | if key.startswith("obs."):
23 | new_key = key[4:]
24 | assert new_key not in obs, f"Duplicate key {new_key}"
25 | obs[new_key] = value
26 | to_delete.append(key)
27 | for key in to_delete:
28 | del x[key]
29 | x["obs"] = obs
30 | return x
31 |
--------------------------------------------------------------------------------
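
A tiny example of the key re-nesting performed by `unflatten_obs` (shapes follow the LIBERO configs and are illustrative only):

```python
# Example: flat "obs.*" keys are regrouped under a nested "obs" dict (in place).
import torch

from datasets.utils.obs_utils import unflatten_obs

batch = {
    "obs.agentview_rgb": torch.zeros(2, 19, 128, 128, 3, dtype=torch.uint8),
    "obs.eye_in_hand_rgb": torch.zeros(2, 19, 128, 128, 3, dtype=torch.uint8),
    "action": torch.zeros(2, 19, 7),
}
batch = unflatten_obs(batch)
print(batch.keys())         # dict_keys(['action', 'obs'])
print(batch["obs"].keys())  # dict_keys(['agentview_rgb', 'eye_in_hand_rgb'])
```
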
/configs/model/dp.yaml:
--------------------------------------------------------------------------------
1 | _target_: models.dp.ImageDiffusionPolicy
2 | action_len: 16
3 | action_dim: ${dataset.shape_meta.action.shape[0]}
4 | obs_encoder:
5 | _target_: models.dp.ImageObservationEncoder
6 | shape_meta: ${dataset.shape_meta}
7 | num_frames: ${obs_num_frames}
8 | embed_dim: 768
9 | resize_shape: [240, 320]
10 | crop_shape: [224, 224]
11 | random_crop: True
12 | color_jitter:
13 | brightness: 0.2
14 | contrast: 0.2
15 | saturation: 0.2
16 | hue: [-0.2, 0.2]
17 | imagenet_norm: True
18 | pretrained_weights: IMAGENET1K_V1
19 | use_low_dim: False
20 | use_language: False
21 | noise_pred_net:
22 | _target_: models.dp.TransformerNoisePredictionNet
23 | _partial_: True
24 | input_len: ${model.action_len}
25 | input_dim: ${model.action_dim}
26 | global_cond_dim: ???
27 | timestep_embed_dim: 256
28 | embed_dim: 768
29 | num_heads: 12
30 | mlp_ratio: 4
31 | qkv_bias: True
32 | num_train_steps: 100
33 | num_inference_steps: 10
--------------------------------------------------------------------------------
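
Note the `_partial_: True` and `global_cond_dim: ???` on `noise_pred_net`: Hydra's `instantiate` turns this node into a `functools.partial` with the missing argument left out, and `ImageDiffusionPolicy` (see `models/dp/image_policy.py` below) supplies `global_cond_dim` from the observation encoder. A rough sketch of the equivalent plain-Python construction (`input_dim=7` is an example value, not fixed by this config):

```python
# Sketch of what the _partial_ noise_pred_net node amounts to in plain Python.
from functools import partial

from models.dp import TransformerNoisePredictionNet

noise_pred_net = partial(
    TransformerNoisePredictionNet,
    input_len=16,            # model.action_len
    input_dim=7,             # example: 7-DoF robomimic/LIBERO actions
    timestep_embed_dim=256,
    embed_dim=768,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
)
# ImageDiffusionPolicy later completes it:
#   noise_pred_net(global_cond_dim=obs_encoder.output_len)
```
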
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=75.1", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "unified_world_model"
7 | version = "0.1.0"
8 | description = "PyTorch implementation of Unified World Models: Coupling Video and Action Diffusion for Pretraining on Large Robotic Datasets"
9 | readme = "README.md"
10 | requires-python = ">=3.10.14"
11 | dynamic = ["dependencies"]
12 |
13 | authors = [
14 | { name = "Chuning Zhu" },
15 | { name = "Raymond Yu" },
16 | { name = "Siyuan Feng" },
17 | { name = "Benjamin Burchfiel" },
18 | { name = "Paarth Shah" },
19 | { name = "Abhishek Gupta" }
20 | ]
21 |
22 | classifiers = [
23 | "Programming Language :: Python :: 3",
24 | "Programming Language :: Python :: 3.10",
25 | "Operating System :: OS Independent"
26 | ]
27 |
28 | [tool.setuptools]
29 | packages = [
30 | "models",
31 | "configs",
32 | "datasets",
33 | "experiments",
34 | "environments"
35 | ]
36 |
37 | [tool.setuptools.dynamic]
38 | dependencies = { file = ["requirements.txt"] }
--------------------------------------------------------------------------------
/configs/dataset/libero_90_webvideo.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.utils.mixture.make_robot_video_mixture_dataset
2 | name: libero_90_webvideo
3 | shape_meta: &shape_meta
4 | obs:
5 | agentview_rgb: &camera_meta
6 | shape: [128, 128, 3]
7 | type: rgb
8 | eye_in_hand_rgb: *camera_meta
9 | action:
10 | shape: [7] # pos + quaternion + gripper
11 | hdf5_path_globs: &hdf5_path_globs /home/ubuntu/LIBERO/libero/datasets/libero_90/*.hdf5
12 | robot_train_val_sets:
13 | _target_: datasets.robomimic.make_robomimic_dataset
14 | name: libero_90
15 | hdf5_path_globs: *hdf5_path_globs
16 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_90/buffer.zarr
17 | shape_meta: *shape_meta
18 | seq_len: ${num_frames}
19 | val_ratio: 0.05
20 | video_train_val_sets:
21 | _target_: datasets.webvideo.make_multiview_video_dataset
22 | name: webvideo
23 | index_paths: [
24 | /home/ubuntu/efs/chuning/k400_filtered.txt,
25 | /home/ubuntu/efs/chuning/ssv2_filtered.txt,
26 | ]
27 | shape_meta: *shape_meta
28 | seq_len: ${num_frames}
29 | obs_padding: "same"
30 | frame_skip: 2
31 | val_ratio: 0.05
32 | balance_datasets: False
--------------------------------------------------------------------------------
/configs/dataset/droid_webvideo.yaml:
--------------------------------------------------------------------------------
1 | _target_: datasets.utils.mixture.make_robot_video_mixture_dataset
2 | name: droid_webvideo
3 | shape_meta: &shape_meta
4 | obs:
5 | "exterior_image_1_left": &camera_meta
6 | shape: [180, 320, 3]
7 | type: rgb
8 | "exterior_image_2_left": *camera_meta
9 | "wrist_image_left": *camera_meta
10 | "cartesian_position":
11 | shape: [6]
12 | type: low_dim
13 | "gripper_position":
14 | shape: [1]
15 | type: low_dim
16 | action:
17 | shape: [10]
18 | robot_train_val_sets:
19 | _target_: datasets.droid.make_droid_dataset
20 | name: droid
21 | buffer_path: /tmp/weirdlab/zchuning/data/droid/buffer.zarr
22 | shape_meta: *shape_meta
23 | seq_len: ${num_frames}
24 | history_len: ${obs_num_frames}
25 | normalize_action: True
26 | normalize_lowdim: True
27 | val_ratio: 0.05
28 | video_train_val_sets:
29 | _target_: datasets.webvideo.make_multiview_video_dataset
30 | name: webvideo
31 | index_paths: [
32 | /gscratch/weirdlab/zchuning/data/k400_filtered.txt,
33 | /gscratch/weirdlab/zchuning/data/ssv2_filtered.txt,
34 | ]
35 | shape_meta: *shape_meta
36 | seq_len: ${num_frames}
37 | obs_padding: "same"
38 | frame_skip: 2
39 | val_ratio: 0.05
40 | balance_datasets: False
--------------------------------------------------------------------------------
/datasets/droid/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch.utils.data import DataLoader
4 | from torch.utils.data.distributed import DistributedSampler
5 |
6 | from .dataset import DroidDataset, DroidMixtureDataset
7 |
8 |
9 | def make_droid_dataset(
10 | name: str,
11 | buffer_path: str,
12 | shape_meta: dict,
13 | seq_len: int,
14 | history_len: int = 1,
15 | normalize_action: bool = False,
16 | normalize_lowdim: bool = False,
17 | val_ratio: float = 0.0,
18 | video_buffer_path: Optional[str] = None,
19 | balance_datasets: bool = False,
20 | ):
21 | # Training dataset
22 | train_set = DroidDataset(
23 | name=name,
24 | buffer_path=buffer_path,
25 | shape_meta=shape_meta,
26 | seq_len=seq_len,
27 | history_len=history_len,
28 | normalize_lowdim=normalize_lowdim,
29 | normalize_action=normalize_action,
30 | val_ratio=val_ratio,
31 | )
32 | if video_buffer_path is not None:
33 | train_set = DroidMixtureDataset(
34 | base_dataset=train_set,
35 | video_buffer_path=video_buffer_path,
36 | balance_datasets=balance_datasets,
37 | )
38 |
39 | # Validation dataset
40 | val_set = train_set.get_validation_dataset()
41 | return train_set, val_set
42 |
--------------------------------------------------------------------------------
/datasets/robomimic/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union
3 |
4 | import torch.distributed as dist
5 | from .dataset import RobomimicDataset
6 |
7 |
8 | def make_robomimic_dataset(
9 | name: str,
10 | hdf5_path_globs: Union[str, list[str]],
11 | buffer_path: str,
12 | shape_meta: dict,
13 | seq_len: int,
14 | val_ratio: float = 0.0,
15 | subsample_ratio: float = 1.0,
16 | flip_rgb: bool = False,
17 | ):
18 | # Cache compressed dataset in the main process
19 | if not os.path.exists(buffer_path):
20 | if not dist.is_initialized() or dist.get_rank() == 0:
21 | RobomimicDataset(
22 | name=name,
23 | hdf5_path_globs=hdf5_path_globs,
24 | buffer_path=buffer_path,
25 | shape_meta=shape_meta,
26 | seq_len=seq_len,
27 | flip_rgb=flip_rgb,
28 | )
29 | if dist.is_initialized():
30 | dist.barrier()
31 |
32 | # Training dataset
33 | train_set = RobomimicDataset(
34 | name=name,
35 | hdf5_path_globs=hdf5_path_globs,
36 | buffer_path=buffer_path,
37 | shape_meta=shape_meta,
38 | seq_len=seq_len,
39 | val_ratio=val_ratio,
40 | subsample_ratio=subsample_ratio,
41 | flip_rgb=flip_rgb,
42 | )
43 | val_set = train_set.get_validation_dataset()
44 | return train_set, val_set
45 |
--------------------------------------------------------------------------------
/scripts/launch_droid_cotrain.sh:
--------------------------------------------------------------------------------
1 | DATA_DIR="/gscratch/weirdlab/memmelma/data/"
2 |
3 | # Cache robot dataset
4 | ROBOT_BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_weird.zarr"
5 | if [ ! -d $ROBOT_BUFFER_PATH ]; then
6 | # Cache dataset
7 | echo "Caching robot dataset..."
8 | python datasets/droid/convert_dataset_zarr.py --data_dir $DATA_DIR --buffer_path $ROBOT_BUFFER_PATH --num_episodes 2000 --num_workers 8 --filter_key WEIRD
9 | fi
10 |
11 | # Cache video dataset
12 | VIDEO_BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_video.zarr"
13 | if [ ! -d $VIDEO_BUFFER_PATH ]; then
14 | # Cache dataset
15 | echo "Caching video dataset..."
16 | python datasets/droid/convert_dataset_zarr.py --data_dir $DATA_DIR --buffer_path $VIDEO_BUFFER_PATH --num_episodes 2000 --num_workers 8 --except_key WEIRD
17 | fi
18 |
19 | # UWM
20 | python experiments/uwm/train.py dataset=droid_mixture exp_id=benchmark_cotrain \
21 | dataset.buffer_path=$ROBOT_BUFFER_PATH \
22 | dataset.video_buffer_path=$VIDEO_BUFFER_PATH
23 |
24 | # GR1
25 | # python experiments/gr1/train.py dataset=droid_mixture exp_id=benchmark_cotrain \
26 | # dataset.buffer_path=$ROBOT_BUFFER_PATH \
27 | # dataset.video_buffer_path=$VIDEO_BUFFER_PATH
28 |
29 | # PAD
30 | # python experiments/pad/train.py dataset=droid_mixture exp_id=benchmark_cotrain \
31 | # dataset.buffer_path=$ROBOT_BUFFER_PATH \
32 | # dataset.video_buffer_path=$VIDEO_BUFFER_PATH
33 |
--------------------------------------------------------------------------------
/datasets/utils/loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torch.utils.data.distributed import DistributedSampler
3 |
4 |
5 | def make_distributed_data_loader(
6 | train_set: Dataset,
7 | val_set: Dataset,
8 | batch_size: int,
9 | rank: int = 0,
10 | world_size: int = 1,
11 | num_workers: int = 8,
12 | pin_memory: bool = True,
13 | drop_last: bool = True,
14 | persistent_workers: bool = True,
15 | ):
16 | # Training sampler and loader
17 | train_sampler = DistributedSampler(
18 | train_set,
19 | num_replicas=world_size,
20 | rank=rank,
21 | shuffle=True,
22 | )
23 | train_loader = DataLoader(
24 | train_set,
25 | batch_size=batch_size,
26 | sampler=train_sampler,
27 | num_workers=num_workers,
28 | pin_memory=pin_memory,
29 | drop_last=drop_last,
30 | persistent_workers=persistent_workers,
31 | )
32 |
33 | # Validation sampler and loader
34 | val_sampler = DistributedSampler(
35 | val_set,
36 | num_replicas=world_size,
37 | rank=rank,
38 | shuffle=False,
39 | )
40 | val_loader = DataLoader(
41 | val_set,
42 | batch_size=batch_size,
43 | sampler=val_sampler,
44 | num_workers=num_workers,
45 | pin_memory=pin_memory,
46 | drop_last=drop_last,
47 | persistent_workers=persistent_workers,
48 | )
49 |
50 | return train_loader, val_loader
51 |
--------------------------------------------------------------------------------
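
A usage sketch for `make_distributed_data_loader` (an assumed training-loop shape, not code from this repo). Because both loaders wrap a `DistributedSampler`, `set_epoch` should be called each epoch so that shuffling differs across epochs:

```python
# Usage sketch (assumed training-loop shape).
from torch.utils.data import Dataset

from datasets.utils.loader import make_distributed_data_loader


def run_epochs(train_set: Dataset, val_set: Dataset, rank: int, world_size: int, num_epochs: int):
    train_loader, val_loader = make_distributed_data_loader(
        train_set, val_set, batch_size=36, rank=rank, world_size=world_size
    )
    for epoch in range(num_epochs):
        # DistributedSampler reshuffles based on its epoch counter; without set_epoch,
        # every epoch would iterate the shards in the same order.
        train_loader.sampler.set_epoch(epoch)
        for batch in train_loader:
            pass  # forward/backward step goes here
```
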
/scripts/setup_aws.sh:
--------------------------------------------------------------------------------
1 | # Install conda
2 | mkdir -p ~/miniconda3
3 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
4 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
5 | rm ~/miniconda3/miniconda.sh
6 |
7 | # Refresh terminal
8 | source ~/miniconda3/bin/activate
9 |
10 | # Modify bashrc
11 | cat << 'EOF' >> ~/.bashrc
12 |
13 | # >>> conda initialize >>>
14 | # !! Contents within this block are managed by 'conda init' !!
15 | __conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
16 | if [ $? -eq 0 ]; then
17 | eval "$__conda_setup"
18 | else
19 | if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
20 | . "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
21 | else
22 | export PATH="/home/ubuntu/miniconda3/bin:$PATH"
23 | fi
24 | fi
25 | unset __conda_setup
26 | # <<< conda initialize <<<
27 | EOF
28 |
29 | # Create conda environment
30 | conda create --yes -n val python==3.10.14
31 | conda activate val
32 |
33 | # Install requirements
34 | pip install -r requirements.txt
35 |
36 | # Clone robomimic fork
37 | cd ..
38 | git clone git@github.com:zchuning/robomimic.git
39 | cd robomimic
40 | pip install -e .
41 |
42 | # Download robomimic dataset
43 | python robomimic/scripts/setup_macros.py
44 | python robomimic/scripts/download_datasets.py --tasks sim --dataset_types ph --hdf5_types raw
45 | cd robomimic/scripts
46 | source extract_obs_from_raw_datasets.sh
47 |
48 | # Clone LIBERO fork
49 | cd ..
50 | git clone git@github.com:zchuning/LIBERO.git
51 | cd LIBERO
52 | pip install -e .
53 | pip install -r requirements.txt
54 |
55 | # Download LIBERO data
56 | python benchmark_scripts/download_libero_datasets.py --datasets libero_100
--------------------------------------------------------------------------------
/datasets/utils/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .buffer import CompressedTrajectoryBuffer
4 |
5 |
6 | class TrajectorySampler:
7 | """
8 | A class that samples sequences of observations and actions from a trajectory buffer.
9 | """
10 |
11 | def __init__(
12 | self,
13 | buffer: CompressedTrajectoryBuffer,
14 | seq_len: int,
15 | episode_mask: np.ndarray = None,
16 | ):
17 | """
18 | Initialize the trajectory sampler.
19 |
20 | Args:
21 | buffer: The trajectory buffer containing the data.
22 | seq_len: The length of the sequences to sample.
23 | episode_mask: A binary mask indicating valid episodes. If None, all episodes are valid.
24 | """
25 | self.buffer = buffer
26 | self.seq_len = seq_len
27 | self.keys = list(self.buffer.keys())
28 |
29 | # Compute all possible sample indices
30 | indices = []
31 | episode_start = 0
32 | for i, episode_end in enumerate(self.buffer.episode_ends):
33 | if episode_mask is None or episode_mask[i]:
34 | for j in range(episode_start, episode_end + 1 - seq_len):
35 | indices.append([j, j + seq_len])
36 | episode_start = episode_end
37 | self.indices = np.array(indices, dtype=np.int64)
38 | print(f"Total number of valid sequences: {len(self.indices)}")
39 |
40 | def __len__(self) -> int:
41 | return len(self.indices)
42 |
43 | def sample_sequence(self, index: int) -> dict[str, np.ndarray]:
44 | start, end = self.indices[index]
45 | data = {}
46 | for key in self.keys:
47 | arr = self.buffer[key]
48 | value = arr[start:end]
49 | data[key] = value
50 | return data
51 |
--------------------------------------------------------------------------------
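
`TrajectorySampler` enumerates every length-`seq_len` window inside each valid episode; for example, an episode spanning buffer indices [0, 100) with `seq_len=19` contributes start indices 0 through 81. A hypothetical wrapper (illustrative only, not a class from this repo) shows how it is typically consumed:

```python
# Hypothetical illustration of how TrajectorySampler is consumed by a torch Dataset.
from torch.utils.data import Dataset

from datasets.utils.buffer import CompressedTrajectoryBuffer
from datasets.utils.sampler import TrajectorySampler


class WindowDataset(Dataset):
    def __init__(self, buffer: CompressedTrajectoryBuffer, seq_len: int):
        self.sampler = TrajectorySampler(buffer, seq_len)

    def __len__(self):
        return len(self.sampler)

    def __getitem__(self, index):
        # Returns a dict of arrays, each of shape (seq_len, ...).
        return self.sampler.sample_sequence(index)
```
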
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.33.0
2 | antlr4-python3-runtime==4.9.3
3 | appdirs==1.4.4
4 | asciitree==0.3.3
5 | certifi==2024.6.2
6 | charset-normalizer==3.3.2
7 | click==8.1.7
8 | cloudpickle==3.0.0
9 | dask==2024.6.0
10 | decord==0.6.0
11 | diffusers==0.30.0
12 | dm-tree==0.1.8
13 | docker-pycreds==0.4.0
14 | einops==0.7.0
15 | fasteners==0.19
16 | filelock==3.15.3
17 | fsspec==2024.6.0
18 | ftfy==6.2.3
19 | gitdb==4.0.11
20 | GitPython==3.1.43
21 | huggingface-hub==0.23.4
22 | hydra-core==1.3.2
23 | idna==3.7
24 | imageio==2.21.2
25 | imageio_ffmpeg==0.5.1
26 | importlib_metadata==7.2.0
27 | Jinja2==3.1.4
28 | locket==1.0.0
29 | markdown-it-py==3.0.0
30 | MarkupSafe==2.1.5
31 | mdurl==0.1.2
32 | moviepy==2.1.1
33 | mpmath==1.3.0
34 | networkx==3.3
35 | numcodecs==0.12.1
36 | numpy==1.26.4
37 | nvidia-cublas-cu12==12.1.3.1
38 | nvidia-cuda-cupti-cu12==12.1.105
39 | nvidia-cuda-nvrtc-cu12==12.1.105
40 | nvidia-cuda-runtime-cu12==12.1.105
41 | nvidia-cudnn-cu12==8.9.2.26
42 | nvidia-cufft-cu12==11.0.2.54
43 | nvidia-curand-cu12==10.3.2.106
44 | nvidia-cusolver-cu12==11.4.5.107
45 | nvidia-cusparse-cu12==12.1.0.106
46 | nvidia-nccl-cu12==2.19.3
47 | nvidia-nvjitlink-cu12==12.5.40
48 | nvidia-nvtx-cu12==12.1.105
49 | omegaconf==2.3.0
50 | packaging==24.1
51 | pandas==2.2.1
52 | partd==1.4.2
53 | pillow==10.3.0
54 | promise==2.3
55 | protobuf==3.20.3
56 | psutil==6.0.0
57 | Pygments==2.18.0
58 | python-dateutil==2.9.0.post0
59 | pytz==2024.1
60 | PyYAML==6.0.1
61 | regex==2024.5.15
62 | requests==2.32.3
63 | rich==13.7.1
64 | robosuite==1.4.1
65 | safetensors==0.4.4
66 | scipy==1.14.0
67 | sentry-sdk==2.6.0
68 | setproctitle==1.3.3
69 | six==1.16.0
70 | smmap==5.0.1
71 | sympy==1.12.1
72 | tensorflow==2.15.0
73 | tensorflow-datasets==4.9.7
74 | tensorflow-metadata==1.16.1
75 | timm==1.0.9
76 | tokenizers==0.19.1
77 | toml==0.10.2
78 | toolz==0.12.1
79 | torch==2.2.2
80 | torchaudio==2.2.2
81 | torchvision==0.17.2
82 | tqdm==4.66.2
83 | transformers==4.44.0
84 | triton==2.2.0
85 | typing_extensions==4.12.2
86 | tzdata==2024.1
87 | urllib3==2.2.2
88 | wandb==0.19.1
89 | wcwidth==0.2.13
90 | zarr==2.18.2
91 | zipp==3.19.2
92 |
--------------------------------------------------------------------------------
/datasets/utils/normalizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import numpy as np
3 |
4 |
5 | class LinearNormalizer:
6 | """
7 | A class that linearly normalizes data to the range [-1, 1].
8 | """
9 |
10 | def __init__(self, scale: np.ndarray, offset: np.ndarray):
11 | """
12 | Initializes a new instance of the LinearNormalizer class with given statistics.
13 |
14 | Args:
15 | scale: The scale factor for normalization.
16 | offset: The offset for normalization.
17 | """
18 | self.scale = scale
19 | self.offset = offset
20 |
21 | def __call__(self, x: np.ndarray) -> np.ndarray:
22 | """
23 | Normalizes the input data using the stored statistics.
24 | """
25 | return (x - self.offset) / self.scale
26 |
27 | def reconstruct(self, x: np.ndarray) -> np.ndarray:
28 | """
29 | Reconstructs the original data from the normalized data.
30 | """
31 | return x * self.scale + self.offset
32 |
33 |
34 | class NestedDictLinearNormalizer(dict):
35 | """
36 | A class that applies linear normalization to values in a nested dictionary structure.
37 | """
38 |
39 | def __init__(self, stats: dict[str, Union[tuple[np.ndarray, np.ndarray], dict]]):
40 | """
41 | Initializes a new instance of the NestedDictLinearNormalizer class with given statistics.
42 |
43 | Args:
44 | stats: A dictionary containing statistics for each key. The values can either
45 | be tuples representing scale and offset values or dictionaries that require
46 | recursive scaling.
47 | """
48 | super().__init__()
49 | for k, v in stats.items():
50 | if isinstance(v, dict):
51 | self[k] = NestedDictLinearNormalizer(v)
52 | else:
53 | self[k] = LinearNormalizer(np.array(v[0]), np.array(v[1]))
54 |
55 | def __call__(self, x: dict) -> dict:
56 | """
57 | Normalizes all values in the input dictionary based on the stored normalizers.
58 | """
59 | return {k: self[k](v) if k in self.keys() else v for k, v in x.items()}
60 |
61 | def reconstruct(self, x: dict) -> dict:
62 | """
63 | Reconstructs the original values from normalized values in the input dictionary.
64 | """
65 | return {
66 | k: self[k].reconstruct(v) if k in self.keys() else v for k, v in x.items()
67 | }
68 |
--------------------------------------------------------------------------------
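
A small example of `NestedDictLinearNormalizer` with made-up statistics; each leaf is a `(scale, offset)` pair and keys without statistics pass through unchanged:

```python
# Example usage of NestedDictLinearNormalizer with made-up statistics.
import numpy as np

from datasets.utils.normalizer import NestedDictLinearNormalizer

stats = {
    "action": (np.array([2.0]), np.array([1.0])),  # (scale, offset)
    "obs": {"cartesian_position": (np.ones(6), np.zeros(6))},
}
normalizer = NestedDictLinearNormalizer(stats)

batch = {"action": np.array([3.0]), "reward": np.array([1.0])}
norm = normalizer(batch)
print(norm["action"])  # (3 - 1) / 2 = [1.0]
print(norm["reward"])  # no stats for "reward", so it is untouched: [1.0]
print(normalizer.reconstruct(norm)["action"])  # back to [3.0]
```
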
/models/dp/image_policy.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from .obs_encoder import ImageObservationEncoder
4 | from .base_policy import NoisePredictionNet, DiffusionPolicy, FlowPolicy
5 |
6 |
7 | class ImageDiffusionPolicy(DiffusionPolicy):
8 | def __init__(
9 | self,
10 | action_len: int,
11 | action_dim: int,
12 | obs_encoder: ImageObservationEncoder,
13 | noise_pred_net: partial[NoisePredictionNet],
14 | num_train_steps: int = 100,
15 | num_inference_steps: int = 10,
16 | num_train_noise_samples: int = 1,
17 | beta_schedule: str = "squaredcos_cap_v2",
18 | clip_sample: bool = True,
19 | ):
20 | """
21 | Assumes rgb input: (B, T, H, W, C) uint8 image
22 | Assumes low_dim input: (B, T, D)
23 | """
24 | super().__init__(
25 | action_len=action_len,
26 | action_dim=action_dim,
27 | noise_pred_net=noise_pred_net(global_cond_dim=obs_encoder.output_len),
28 | num_train_steps=num_train_steps,
29 | num_inference_steps=num_inference_steps,
30 | num_train_noise_samples=num_train_noise_samples,
31 | beta_schedule=beta_schedule,
32 | clip_sample=clip_sample,
33 | )
34 |
35 | # Observation encoder
36 | self.obs_encoder = obs_encoder
37 |
38 | def sample(self, obs_dict):
39 | obs = self.obs_encoder(obs_dict)
40 | action = super().sample(obs)
41 | return action
42 |
43 | def forward(self, obs_dict, action):
44 | obs = self.obs_encoder(obs_dict)
45 | loss = super().forward(obs, action)
46 | return loss
47 |
48 |
49 | class ImageFlowPolicy(FlowPolicy):
50 | def __init__(
51 | self,
52 | action_len: int,
53 | action_dim: int,
54 | obs_encoder: ImageObservationEncoder,
55 | noise_pred_net: partial[NoisePredictionNet],
56 | num_train_steps: int = 100,
57 | num_inference_steps: int = 10,
58 | timeshift: float = 1.0,
59 | ):
60 | """
61 | Assumes rgb input: (B, T, H, W, C) uint8 image
62 | Assumes low_dim input: (B, T, D)
63 | """
64 | super().__init__(
65 | action_len=action_len,
66 | action_dim=action_dim,
67 | noise_pred_net=noise_pred_net(global_cond_dim=obs_encoder.output_len),
68 | num_train_steps=num_train_steps,
69 | num_inference_steps=num_inference_steps,
70 | timeshift=timeshift,
71 | )
72 |
73 | # Observation encoder
74 | self.obs_encoder = obs_encoder
75 |
76 | def sample(self, obs_dict):
77 | obs = self.obs_encoder(obs_dict)
78 | action = super().sample(obs)
79 | return action
80 |
81 | def forward(self, obs_dict, action):
82 | obs = self.obs_encoder(obs_dict)
83 | loss = super().forward(obs, action)
84 | return loss
85 |
--------------------------------------------------------------------------------
/scripts/launch_droid_finetune.sh:
--------------------------------------------------------------------------------
1 | # Cache dataset
2 | DATA_NAME=stack_red_bowl_on_blue_bowl
3 | DATA_DIR="/gscratch/weirdlab/zchuning/data/"
4 | BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_$DATA_NAME.zarr"
5 | if [ ! -d $BUFFER_PATH ]; then
6 | # Cache dataset
7 | echo "Caching dataset..."
8 | python datasets/droid/convert_dataset_zarr.py --data_name $DATA_NAME --data_dir $DATA_DIR --buffer_path $BUFFER_PATH --num_episodes 2000 --num_workers 8
9 | fi
10 |
11 | # UWM
12 | PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/uwm/droid/benchmark/0/models.pt"
13 | python experiments/uwm/train.py --config-name finetune_uwm.yaml dataset=droid exp_id=finetune_benchmark \
14 | dataset.name=droid_$DATA_NAME \
15 | dataset.buffer_path=$BUFFER_PATH \
16 | pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
17 |
18 | # UWM cotrained
19 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/uwm/droid_mixture/benchmark_cotrain/0/models.pt"
20 | # python experiments/uwm/finetune.py dataset=droid exp_id=finetune_benchmark_cotrain \
21 | # dataset.name=droid_$DATA_NAME \
22 | # dataset.buffer_path=$BUFFER_PATH \
23 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
24 |
25 | # DP
26 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/dp/droid/benchmark/0/models.pt"
27 | # python experiments/dp/finetune.py dataset=droid exp_id=finetune_benchmark \
28 | # dataset.name=droid_$DATA_NAME \
29 | # dataset.buffer_path=$BUFFER_PATH \
30 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
31 |
32 | # GR1
33 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/gr1/droid/benchmark/0/models.pt"
34 | # python experiments/gr1/finetune.py dataset=droid exp_id=finetune_benchmark \
35 | # dataset.name=droid_$DATA_NAME \
36 | # dataset.buffer_path=$BUFFER_PATH \
37 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
38 |
39 | # GR1 cotrained
40 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/gr1/droid_mixture/benchmark_cotrain/0/models.pt"
41 | # python experiments/gr1/finetune.py dataset=droid exp_id=finetune_benchmark_cotrain \
42 | # dataset.name=droid_$DATA_NAME \
43 | # dataset.buffer_path=$BUFFER_PATH \
44 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
45 |
46 | # PAD
47 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/pad/droid/benchmark/0/models.pt"
48 | # python experiments/pad/finetune.py dataset=droid exp_id=finetune_benchmark \
49 | # dataset.name=droid_$DATA_NAME \
50 | # dataset.buffer_path=$BUFFER_PATH \
51 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
52 |
53 |
54 | # PAD cotrained
55 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/pad/droid_mixture/benchmark_cotrain/0/models.pt"
56 | # python experiments/pad/finetune.py dataset=droid exp_id=finetune_benchmark_cotrain \
57 | # dataset.name=droid_$DATA_NAME \
58 | # dataset.buffer_path=$BUFFER_PATH \
59 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH
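60 |
61 | # Note (usage): to finetune on a different task, set DATA_NAME above to the name of your
62 | # converted TFDS dataset and uncomment exactly one of the model blocks. The Zarr buffer is
63 | # only rebuilt when BUFFER_PATH does not exist; delete it to force reconversion.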
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | from datetime import timedelta
5 |
6 | import numpy as np
7 | import torch
8 | import torch.distributed as dist
9 | import wandb
10 | from omegaconf import OmegaConf
11 |
12 | DATA_TYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
13 |
14 |
15 | def set_seed(seed):
16 | """Set random seed for reproducibility."""
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 |
21 |
22 | def init_wandb(config, job_type):
23 | """Initialize WANDB logging. Only use in main process.
24 |
25 | Args:
26 | config: config dictionary
27 | job_type: "train" or "eval"
28 | """
29 | run_id_path = os.path.join(config.logdir, f"run_id_{job_type}.json")
30 | if config.resume and os.path.exists(run_id_path):
31 | # Load WANDB run ID from log directory
32 | with open(run_id_path, "r") as f:
33 | run_id = json.load(f)["run_id"]
34 | else:
35 | # Generate new WANDB run ID
36 | run_id = wandb.util.generate_id()
37 | with open(run_id_path, "w") as f:
38 | json.dump({"run_id": run_id}, f)
39 |
40 | wandb.init(
41 | project="video-action-learning",
42 | job_type=job_type,
43 | group=config.algo,
44 | name="_".join([config.exp_id, str(config.seed)]),
45 | config=OmegaConf.to_container(config, resolve=True),
46 | resume=config.resume,
47 | id=run_id,
48 | )
49 |
50 |
51 | def init_distributed(rank, world_size):
52 | """Initialize distributed training and set visible device.
53 |
54 | Args:
55 | rank: unique identifier of each process
56 | world_size: total number of processes
57 | """
58 | os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
59 | os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "25678")
60 | dist.init_process_group(
61 | backend="nccl",
62 | rank=rank,
63 | world_size=world_size,
64 | timeout=timedelta(seconds=3600),
65 | )
66 |
67 |
68 | def get_rank():
69 | if not dist.is_initialized():
70 | return 0
71 | return dist.get_rank()
72 |
73 |
74 | def get_world_size():
75 | if not dist.is_initialized():
76 | return 1
77 | return dist.get_world_size()
78 |
79 |
80 | def is_main_process():
81 | return get_rank() == 0
82 |
83 |
84 | def soft_update(target, source, tau):
85 | """Soft update target model with source model."""
86 | for target_param, param in zip(target.parameters(), source.parameters()):
87 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
88 |
89 |
90 | class FreezeParameters:
91 | def __init__(self, params):
92 | self.params = params
93 | self.param_states = [p.requires_grad for p in self.params]
94 |
95 | def __enter__(self):
96 | for param in self.params:
97 | param.requires_grad = False
98 |
99 | def __exit__(self, exc_type, exc_val, exc_tb):
100 | for i, param in enumerate(self.params):
101 | param.requires_grad = self.param_states[i]
102 |
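103 | if __name__ == "__main__":
104 |     # Minimal usage sketch (illustrative only, not part of the training code):
105 |     # exercise the generic helpers above on two throwaway linear layers.
106 |     source = torch.nn.Linear(4, 4)
107 |     target = torch.nn.Linear(4, 4)
108 |
109 |     # Polyak-average the target parameters towards the source parameters.
110 |     soft_update(target, source, tau=0.005)
111 |
112 |     # Temporarily freeze the source parameters; requires_grad is restored on exit.
113 |     with FreezeParameters(list(source.parameters())):
114 |         assert not any(p.requires_grad for p in source.parameters())
115 |     assert all(p.requires_grad for p in source.parameters())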
--------------------------------------------------------------------------------
/environments/robomimic/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .wrappers import RoboMimicEnvWrapper, LIBEROEnvWrapper
4 |
5 |
6 | def make_robomimic_env(
7 | dataset_name,
8 | dataset_path,
9 | shape_meta,
10 | obs_horizon,
11 | max_episode_length,
12 | record=False,
13 | ):
14 | if "robomimic" in dataset_name:
15 | from robomimic.utils.env_utils import get_env_type, get_env_class
16 | from robomimic.utils.file_utils import (
17 | get_env_metadata_from_dataset,
18 | get_shape_metadata_from_dataset,
19 | )
20 | from robomimic.utils.obs_utils import initialize_obs_utils_with_obs_specs
21 |
22 | # Initialize observation modalities
23 | rgb_keys = [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"]
24 | low_dim_keys = [
25 | k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"
26 | ]
27 | all_obs_keys = rgb_keys + low_dim_keys
28 | initialize_obs_utils_with_obs_specs(
29 | {"obs": {"rgb": rgb_keys, "low_dim": low_dim_keys}}
30 | )
31 |
32 | # Create environment
33 | env_meta = get_env_metadata_from_dataset(dataset_path=dataset_path)
34 | env_type = get_env_type(env_meta=env_meta)
35 | env_class = get_env_class(env_type=env_type)
36 | shape_meta = get_shape_metadata_from_dataset(
37 | dataset_path=dataset_path,
38 | all_obs_keys=all_obs_keys,
39 | verbose=True,
40 | )
41 | # Set render device if CUDA_VISIBLE_DEVICES is set
42 | if os.environ.get("CUDA_VISIBLE_DEVICES", None):
43 | env_meta["env_kwargs"]["render_gpu_device_id"] = int(
44 | os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
45 | )
46 | env = env_class(
47 | env_name=env_meta["env_name"],
48 | render=False,
49 | render_offscreen=True,
50 | use_image_obs=shape_meta["use_images"],
51 | use_depth_obs=shape_meta["use_depths"],
52 | postprocess_visual_obs=False, # use raw images
53 | **env_meta["env_kwargs"],
54 | )
55 | env = RoboMimicEnvWrapper(
56 | env, all_obs_keys, obs_horizon, max_episode_length, record=record
57 | )
58 | elif "libero" in dataset_name:
59 | from libero.libero.envs import OffScreenRenderEnv
60 | from libero.libero import get_libero_path
61 |
62 | # Construct environment kwargs
63 | bddl_file_name = os.path.join(
64 | get_libero_path("bddl_files"),
65 | "libero_10",
66 | dataset_path.split("/")[-1].replace("_demo.hdf5", ".bddl"),
67 | )
68 | env_kwargs = {
69 | "bddl_file_name": bddl_file_name,
70 | "camera_heights": 128,
71 | "camera_widths": 128,
72 | }
73 |
74 | # Set render device if CUDA_VISIBLE_DEVICES is set
75 | if os.environ.get("CUDA_VISIBLE_DEVICES", None):
76 | env_kwargs["render_gpu_device_id"] = int(
77 | os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
78 | )
79 |
80 | # Create environment
81 | env = OffScreenRenderEnv(**env_kwargs)
82 | obs_keys = list(shape_meta["obs"].keys())
83 | env = LIBEROEnvWrapper(
84 | env, obs_keys, obs_horizon, max_episode_length, record=record
85 | )
86 | else:
87 | raise NotImplementedError(f"Unsupported environment: {dataset_name}")
88 | return env
89 |
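90 | if __name__ == "__main__":
91 |     # Illustrative usage sketch only: the hdf5 path below is a placeholder and the
92 |     # shape_meta keys/shapes are assumptions mirroring what the dataset configs
93 |     # provide, not values copied from this repository's config files.
94 |     shape_meta = {
95 |         "obs": {
96 |             "agentview_image": {"type": "rgb", "shape": (84, 84, 3)},
97 |             "robot0_eef_pos": {"type": "low_dim", "shape": (3,)},
98 |         }
99 |     }
100 |     env = make_robomimic_env(
101 |         dataset_name="robomimic_can_ph",
102 |         dataset_path="/path/to/robomimic/can/ph/image.hdf5",  # placeholder path
103 |         shape_meta=shape_meta,
104 |         obs_horizon=2,
105 |         max_episode_length=400,
106 |     )
107 |     obs = env.reset()
108 |     print({k: v.shape for k, v in obs.items()})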
--------------------------------------------------------------------------------
/experiments/uwm/eval_robomimic.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import hydra
4 | import imageio
5 | import numpy as np
6 | import torch
7 | from diffusers.optimization import get_scheduler
8 | from omegaconf import OmegaConf
9 | from hydra.utils import instantiate
10 | from tqdm import trange
11 |
12 | from environments.robomimic import make_robomimic_env
13 | from experiments.uwm.train import maybe_resume_checkpoint
14 | from experiments.utils import set_seed, is_main_process
15 |
16 |
17 | def collect_rollout(config, model, device):
18 | model.eval()
19 | model = getattr(model, "module", model) # unwrap DDP
20 |
21 | # Create eval environment
22 | assert isinstance(config.dataset.hdf5_path_globs, str)
23 | env = make_robomimic_env(
24 | dataset_name=config.dataset.name,
25 | dataset_path=config.dataset.hdf5_path_globs,
26 | shape_meta=config.dataset.shape_meta,
27 | obs_horizon=model.obs_encoder.num_frames,
28 | max_episode_length=config.rollout_length,
29 | record=True,
30 | )
31 |
32 | # Collect rollouts
33 | video_dir = os.path.join(config.logdir, "videos")
34 | if not os.path.exists(video_dir):
35 | os.mkdir(video_dir)
36 | successes = []
37 | for e in trange(
38 | config.num_rollouts, desc="Collecting rollouts", disable=not is_main_process()
39 | ):
40 | env.seed(e)
41 | obs = env.reset()
42 | done = False
43 | while not done:
44 | obs_tensor = {
45 | k: torch.tensor(v, device=device)[None] for k, v in obs.items()
46 | }
47 |
48 | # Sample action from model
49 | action = model.sample(obs_tensor)[0].cpu().numpy()
50 |
51 | # Step environment
52 | obs, reward, done, info = env.step(action)
53 | successes.append(info["success"])
54 | video = env.get_video()
55 | imageio.mimwrite(os.path.join(video_dir, f"{e}.mp4"), video, fps=30)
56 | print(
57 | f"Episode {e} success: {info['success']}, cumulative: {np.mean(successes):.2f}"
58 | )
59 |
60 | # Compute success rate
61 | success_rate = sum(successes) / len(successes)
62 | return success_rate
63 |
64 |
65 | def maybe_collect_rollout(config, step, model, device):
66 | """Collect rollouts on the main process if it's the correct step."""
67 |     # Skip rollout collection for pretraining
68 | if "libero_90" in config.dataset.name:
69 | return
70 |
71 | if is_main_process() and (
72 | step % config.rollout_every == 0 or step == (config.num_steps - 1)
73 | ):
74 | success_rate = collect_rollout(config, model, device)
75 | print(f"Step: {step} success rate: {success_rate}")
76 |
77 |
78 | @hydra.main(
79 | version_base=None,
80 | config_path="../../configs",
81 | config_name="train_uwm_robomimic.yaml",
82 | )
83 | def main(config):
84 | # Resolve hydra config
85 | OmegaConf.resolve(config)
86 | set_seed(0)
87 |     device = torch.device("cuda:0")
88 |
89 | # Create model
90 | model = instantiate(config.model).to(device)
91 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
92 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
93 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
94 |
95 | # Resume from checkpoint
96 | config.resume = True
97 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
98 | maybe_collect_rollout(config, 0, model, device)
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
/models/gr1/flamingo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 |
4 | from einops import rearrange, repeat
5 |
6 |
7 | def FeedForward(dim, mult=4):
8 | inner_dim = int(dim * mult)
9 | return nn.Sequential(
10 | nn.LayerNorm(dim),
11 | nn.Linear(dim, inner_dim, bias=False),
12 | nn.GELU(),
13 | nn.Linear(inner_dim, dim, bias=False),
14 | )
15 |
16 |
17 | class PerceiverAttention(nn.Module):
18 | def __init__(self, *, dim, dim_head=64, heads=8):
19 | super().__init__()
20 | self.scale = dim_head**-0.5
21 | self.heads = heads
22 | inner_dim = dim_head * heads
23 |
24 | self.norm_media = nn.LayerNorm(dim)
25 | self.norm_latents = nn.LayerNorm(dim)
26 |
27 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
28 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
29 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
30 |
31 | def forward(self, x, latents):
32 | """
33 | einstein notation
34 | b - batch
35 | t - time
36 | n - sequence
37 | d - dimension
38 | """
39 | x = self.norm_media(x)
40 | latents = self.norm_latents(latents)
41 |
42 | b, m, h = *x.shape[:2], self.heads
43 |
44 | q = self.to_q(latents)
45 |
46 |         # unlike the original Perceiver, Flamingo also concatenates the latents to the keys/values that the latents attend to
47 | kv_input = torch.cat((x, latents), dim=-2)
48 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
49 |
50 | q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
51 | k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
52 | v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
53 |
54 | # attention
55 | q = q * self.scale
56 | sim = einsum("... i d, ... j d -> ... i j", q, k)
57 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
58 | attn = sim.softmax(dim=-1)
59 |
60 | out = einsum("... i j, ... j d -> ... i d", attn, v)
61 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
62 | return self.to_out(out)
63 |
64 |
65 | class PerceiverResampler(nn.Module):
66 | def __init__(
67 | self,
68 | *,
69 | dim,
70 | depth,
71 | dim_head=64,
72 | heads=8,
73 | num_latents=64,
74 | num_media_embeds=4,
75 | ff_mult=4
76 | ):
77 | super().__init__()
78 | self.num_latents = num_latents
79 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
80 | self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim))
81 |
82 | self.layers = nn.ModuleList([])
83 | for _ in range(depth):
84 | self.layers.append(
85 | nn.ModuleList(
86 | [
87 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
88 | FeedForward(dim=dim, mult=ff_mult),
89 | ]
90 | )
91 | )
92 |
93 | self.norm = nn.LayerNorm(dim)
94 |
95 | def forward(self, x):
96 | if x.ndim == 3:
97 | x = rearrange(x, "b n d -> b 1 n d")
98 |
99 | times = x.shape[1]
100 | x = x + self.media_pos_emb[:times]
101 |
102 | latents = repeat(self.latents, "n d -> b m n d", b=x.shape[0], m=x.shape[1])
103 | for attn, ff in self.layers:
104 | latents = attn(x, latents) + latents
105 | latents = ff(latents) + latents
106 | return self.norm(latents)
107 |
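108 | if __name__ == "__main__":
109 |     # Shape sanity check (illustrative only; sizes are arbitrary assumptions):
110 |     # resample 50 patch tokens per frame down to the 64 learned latents per frame.
111 |     resampler = PerceiverResampler(dim=512, depth=2)
112 |     x = torch.randn(2, 3, 50, 512)  # (batch, time, tokens, dim)
113 |     print(resampler(x).shape)  # torch.Size([2, 3, 64, 512])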
--------------------------------------------------------------------------------
/models/dp/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.common.adaln_attention import AdaLNAttentionBlock, AdaLNFinalLayer
5 | from models.common.utils import SinusoidalPosEmb, init_weights
6 | from .base_policy import NoisePredictionNet
7 |
8 |
9 | class TransformerNoisePredictionNet(NoisePredictionNet):
10 | def __init__(
11 | self,
12 | input_len: int,
13 | input_dim: int,
14 | global_cond_dim: int,
15 | timestep_embed_dim: int = 256,
16 | embed_dim: int = 768,
17 | depth: int = 12,
18 | num_heads: int = 12,
19 | mlp_ratio: float = 4.0,
20 | qkv_bias: bool = True,
21 | ):
22 | super().__init__()
23 | self.input_len = input_len
24 |
25 | # Input encoder and decoder
26 | hidden_dim = int(max(input_dim, embed_dim) * mlp_ratio)
27 | self.input_encoder = nn.Sequential(
28 | nn.Linear(input_dim, hidden_dim),
29 | nn.Mish(),
30 | nn.Linear(hidden_dim, embed_dim),
31 | )
32 | self.output_decoder = nn.Sequential(
33 | nn.Linear(embed_dim, hidden_dim),
34 | nn.Mish(),
35 | nn.Linear(hidden_dim, input_dim),
36 | )
37 |
38 | # Timestep encoder
39 | self.timestep_encoder = nn.Sequential(
40 | SinusoidalPosEmb(timestep_embed_dim),
41 | nn.Linear(timestep_embed_dim, timestep_embed_dim * 4),
42 | nn.Mish(),
43 | nn.Linear(timestep_embed_dim * 4, timestep_embed_dim),
44 | )
45 |
46 | # Model components
47 | self.pos_embed = nn.Parameter(
48 | torch.empty(1, input_len, embed_dim).normal_(std=0.02)
49 | )
50 | cond_dim = global_cond_dim + timestep_embed_dim
51 | self.blocks = nn.ModuleList(
52 | [
53 | AdaLNAttentionBlock(
54 | dim=embed_dim,
55 | cond_dim=cond_dim,
56 | num_heads=num_heads,
57 | mlp_ratio=mlp_ratio,
58 | qkv_bias=qkv_bias,
59 | )
60 | for _ in range(depth)
61 | ]
62 | )
63 | self.head = AdaLNFinalLayer(dim=embed_dim, cond_dim=cond_dim)
64 |
65 | # AdaLN-specific weight initialization
66 | self.initialize_weights()
67 |
68 | def initialize_weights(self):
69 | # Base initialization
70 | self.apply(init_weights)
71 |
72 | # Zero-out adaLN modulation layers in DiT blocks:
73 | for block in self.blocks:
74 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
75 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
76 |
77 | # Zero-out output layers:
78 | nn.init.constant_(self.head.adaLN_modulation[-1].weight, 0)
79 | nn.init.constant_(self.head.adaLN_modulation[-1].bias, 0)
80 | nn.init.constant_(self.head.linear.weight, 0)
81 | nn.init.constant_(self.head.linear.bias, 0)
82 |
83 | def forward(self, sample, timestep, global_cond):
84 | # Encode input
85 | embed = self.input_encoder(sample)
86 |
87 | # Encode timestep
88 | if len(timestep.shape) == 0:
89 | timestep = timestep.expand(sample.shape[0]).to(
90 | dtype=torch.long, device=sample.device
91 | )
92 | temb = self.timestep_encoder(timestep)
93 |
94 | # Concatenate timestep and condition along the sequence dimension
95 | x = embed + self.pos_embed
96 | cond = torch.cat([global_cond, temb], dim=-1)
97 | for block in self.blocks:
98 | x = block(x, cond)
99 | x = self.head(x, cond)
100 |
101 | # Decode output
102 | out = self.output_decoder(x)
103 | return out
104 |
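105 | if __name__ == "__main__":
106 |     # Shape sanity check (illustrative only; sizes are arbitrary assumptions, not
107 |     # repo defaults): denoise an 8-step, 7-dim action chunk conditioned on a
108 |     # 512-dim global feature vector.
109 |     net = TransformerNoisePredictionNet(
110 |         input_len=8, input_dim=7, global_cond_dim=512, embed_dim=256, depth=2, num_heads=4
111 |     )
112 |     sample = torch.randn(2, 8, 7)
113 |     out = net(sample, torch.tensor(5), global_cond=torch.randn(2, 512))
114 |     print(out.shape)  # torch.Size([2, 8, 7])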
--------------------------------------------------------------------------------
/datasets/utils/mixture.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class RobotVideoMixtureDataset(Dataset):
8 | def __init__(
9 | self,
10 | robot_dataset: Dataset,
11 | video_dataset: Dataset,
12 | balance_datasets: bool = False,
13 | ):
14 | self.robot_dataset = robot_dataset
15 | self.video_dataset = video_dataset
16 |
17 | # Copy action and lowdim normalizers from robot dataset
18 | if hasattr(self.robot_dataset, "action_normalizer"):
19 | self.action_normalizer = self.robot_dataset.action_normalizer
20 | if hasattr(self.robot_dataset, "lowdim_normalizer"):
21 | self.lowdim_normalizer = self.robot_dataset.lowdim_normalizer
22 |
23 | # Balance robot and video datasets
24 | if balance_datasets:
25 | # Figure out dataset lengths
26 | len_robot = len(self.robot_dataset)
27 | len_video = len(self.video_dataset)
28 | max_len = max(len_robot, len_video)
29 | print(
30 |                 f"Balancing data: {len_robot} robot, {len_video} video -> upsampling both to {max_len}"
31 | )
32 |
33 | # Upsample robot data
34 | robot_factor = math.ceil(max_len / len_robot)
35 | robot_indices = []
36 | for _ in range(robot_factor):
37 | robot_indices.extend(range(len_robot))
38 | robot_indices = robot_indices[:max_len]
39 |
40 | # Upsample video data
41 | video_factor = math.ceil(max_len / len_video)
42 | video_indices = []
43 | for _ in range(video_factor):
44 | video_indices.extend(range(len_video))
45 | video_indices = video_indices[:max_len]
46 |
47 | # Combine robot and video data
48 | combined_indices = []
49 | for r_i in robot_indices:
50 | combined_indices.append((True, r_i))
51 | for v_i in video_indices:
52 | combined_indices.append((False, v_i))
53 | self.index_map = combined_indices # list of (bool, idx)
54 | else:
55 | # Just combine the two datasets
56 | combined_indices = []
57 | for r_i in range(len(self.robot_dataset)):
58 | combined_indices.append((True, r_i))
59 | for v_i in range(len(self.video_dataset)):
60 | combined_indices.append((False, v_i))
61 | self.index_map = combined_indices
62 |
63 | def __len__(self):
64 | return len(self.index_map)
65 |
66 | def __getitem__(self, idx):
67 | is_robot, dataset_idx = self.index_map[idx]
68 | if is_robot:
69 | data = self.robot_dataset[dataset_idx]
70 | data["action_mask"] = torch.tensor(1, dtype=torch.bool)
71 | else:
72 | data = self.video_dataset[dataset_idx]
73 | data["action_mask"] = torch.tensor(0, dtype=torch.bool)
74 | return data
75 |
76 |
77 | def make_robot_video_mixture_dataset(
78 | robot_train_val_sets: tuple[Dataset, Dataset],
79 | video_train_val_sets: tuple[Dataset, Dataset],
80 | balance_datasets: bool = False,
81 | **kwargs,
82 | ):
83 | """
84 | Combine robot and video datasets into a mixture dataset.
85 |
86 | This function merges the training sets from both robot and video datasets to create a single training set,
87 | while using the robot dataset's validation set for evaluation. Additional keyword arguments (kwargs) are
88 | captured for compatibility with other functions in the codebase.
89 | """
90 | robot_train_set, robot_val_set = robot_train_val_sets
91 | video_train_set, video_val_set = video_train_val_sets
92 | train_set = RobotVideoMixtureDataset(
93 | robot_train_set, video_train_set, balance_datasets
94 | )
95 | return train_set, robot_val_set
96 |
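97 | if __name__ == "__main__":
98 |     # Tiny illustrative check (not part of the training pipeline): wrap two toy
99 |     # dict-returning datasets and verify balancing and the action_mask flag.
100 |     class _ToyDataset(Dataset):
101 |         def __init__(self, n):
102 |             self.n = n
103 |
104 |         def __len__(self):
105 |             return self.n
106 |
107 |         def __getitem__(self, idx):
108 |             return {"action": torch.zeros(1)}
109 |
110 |     mixed = RobotVideoMixtureDataset(_ToyDataset(3), _ToyDataset(5), balance_datasets=True)
111 |     print(len(mixed))  # 10: both datasets are upsampled to max(3, 5) = 5
112 |     print(mixed[0]["action_mask"].item(), mixed[-1]["action_mask"].item())  # True (robot), False (video)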
--------------------------------------------------------------------------------
/experiments/uwm/train_webvideo.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.multiprocessing as mp
4 | import wandb
5 | from diffusers.optimization import get_scheduler
6 | from hydra.utils import instantiate
7 | from omegaconf import OmegaConf
8 | from torch.nn.parallel import DistributedDataParallel
9 | from tqdm import tqdm
10 |
11 | from datasets.utils.loader import make_distributed_data_loader
12 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
13 | from experiments.uwm.train import (
14 | train_one_step,
15 | maybe_resume_checkpoint,
16 | maybe_evaluate,
17 | maybe_save_checkpoint,
18 | )
19 |
20 |
21 | def train(rank, world_size, config):
22 | # Set global seed
23 | set_seed(config.seed * world_size + rank)
24 |
25 | # Initialize distributed training
26 | init_distributed(rank, world_size)
27 | device = torch.device(f"cuda:{rank}")
28 |
29 | # Initialize WANDB
30 | if is_main_process():
31 | init_wandb(config, job_type="train")
32 |
33 | # Create dataset and loader
34 | train_set, val_set = instantiate(config.dataset)
35 | train_loader, val_loader = make_distributed_data_loader(
36 | train_set, val_set, config.batch_size, rank, world_size
37 | )
38 |
39 | # Create model
40 | model = instantiate(config.model).to(device)
41 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
42 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
43 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
44 |
45 | # Load pretrained model
46 | if config.pretrain_checkpoint_path:
47 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
48 | model.load_state_dict(ckpt["model"])
49 | print(
50 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
51 | )
52 |
53 | # Resume from checkpoint
54 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
55 | epoch = step // len(train_loader)
56 |
57 | # Wrap model with DDP
58 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
59 |
60 | # Training loop
61 | pbar = tqdm(
62 | total=config.num_steps,
63 | initial=step,
64 | desc="Training",
65 | disable=not is_main_process(),
66 | )
67 | while step < config.num_steps:
68 | # Set epoch for distributed sampler to shuffle indices
69 | train_loader.sampler.set_epoch(epoch)
70 |
71 | # Train for one epoch
72 | for batch in train_loader:
73 | # --- Training step ---
74 | loss, info = train_one_step(
75 | config, model, optimizer, scheduler, scaler, batch, device
76 | )
77 |
78 | # --- Logging ---
79 | if is_main_process():
80 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
81 | wandb.log({f"train/{k}": v for k, v in info.items()})
82 |
83 | # --- Evaluate if needed ---
84 | maybe_evaluate(config, step, model, val_loader, device)
85 |
86 | # --- Save checkpoint if needed ---
87 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler)
88 |
89 | step += 1
90 | pbar.update(1)
91 | if step >= config.num_steps:
92 | break
93 |
94 | epoch += 1
95 |
96 |
97 | @hydra.main(
98 | version_base=None, config_path="../../configs", config_name="train_uwm.yaml"
99 | )
100 | def main(config):
101 | # Resolve hydra config
102 | OmegaConf.resolve(config)
103 | # Spawn processes
104 | world_size = torch.cuda.device_count()
105 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
--------------------------------------------------------------------------------
/experiments/gr1/train_robomimic.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.multiprocessing as mp
4 | import wandb
5 | from diffusers.optimization import get_scheduler
6 | from hydra.utils import instantiate
7 | from omegaconf import OmegaConf
8 | from torch.nn.parallel import DistributedDataParallel
9 | from tqdm import tqdm
10 |
11 | from datasets.utils.loader import make_distributed_data_loader
12 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
13 | from experiments.gr1.train import maybe_evaluate
14 | from experiments.uwm.train import (
15 | train_one_step,
16 | maybe_resume_checkpoint,
17 | maybe_save_checkpoint,
18 | )
19 | from experiments.uwm.train_robomimic import maybe_collect_rollout
20 |
21 |
22 | def train(rank, world_size, config):
23 | # Set global seed
24 | set_seed(config.seed * world_size + rank)
25 |
26 | # Initialize distributed training
27 | init_distributed(rank, world_size)
28 | device = torch.device(f"cuda:{rank}")
29 |
30 | # Initialize WANDB
31 | if is_main_process():
32 | init_wandb(config, job_type="train")
33 |
34 | # Create dataset and loader
35 | train_set, val_set = instantiate(config.dataset)
36 | train_loader, val_loader = make_distributed_data_loader(
37 | train_set, val_set, config.batch_size, rank, world_size
38 | )
39 |
40 | # Create model
41 | model = instantiate(config.model).to(device)
42 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
43 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
44 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
45 |
46 | # Load pretrained model
47 | if config.pretrain_checkpoint_path:
48 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
49 | model.load_state_dict(ckpt["model"])
50 | print(
51 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
52 | )
53 |
54 | # Resume from checkpoint
55 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
56 | epoch = step // len(train_loader)
57 |
58 | # Wrap model with DDP
59 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
60 |
61 | # Training loop
62 | pbar = tqdm(
63 | total=config.num_steps,
64 | initial=step,
65 | desc="Training",
66 | disable=not is_main_process(),
67 | )
68 | while step < config.num_steps:
69 | # Set epoch for distributed sampler to shuffle indices
70 | train_loader.sampler.set_epoch(epoch)
71 |
72 | # Train for one epoch
73 | for batch in train_loader:
74 | # --- Training step ---
75 | loss, info = train_one_step(
76 | config, model, optimizer, scheduler, scaler, batch, device
77 | )
78 |
79 | # --- Logging ---
80 | if is_main_process():
81 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
82 | wandb.log({f"train/{k}": v for k, v in info.items()})
83 |
84 | # --- Evaluate if needed ---
85 | maybe_evaluate(config, step, model, val_loader, device)
86 |
87 |             # --- Collect environment rollouts if needed ---
88 | maybe_collect_rollout(config, step, model, device)
89 |
90 | # --- Save checkpoint if needed ---
91 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler)
92 |
93 | step += 1
94 | pbar.update(1)
95 | if step >= config.num_steps:
96 | break
97 |
98 | epoch += 1
99 |
100 |
101 | @hydra.main(
102 | version_base=None,
103 | config_path="../../configs",
104 | config_name="train_gr1_robomimic.yaml",
105 | )
106 | def main(config):
107 | # Resolve hydra config
108 | OmegaConf.resolve(config)
109 | # Spawn processes
110 | world_size = torch.cuda.device_count()
111 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
112 |
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
--------------------------------------------------------------------------------
/experiments/pad/train_robomimic.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.multiprocessing as mp
4 | import wandb
5 | from diffusers.optimization import get_scheduler
6 | from hydra.utils import instantiate
7 | from omegaconf import OmegaConf
8 | from torch.nn.parallel import DistributedDataParallel
9 | from tqdm import tqdm
10 |
11 | from datasets.utils.loader import make_distributed_data_loader
12 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
13 | from experiments.pad.train import maybe_evaluate
14 | from experiments.uwm.train import (
15 | train_one_step,
16 | maybe_resume_checkpoint,
17 | maybe_save_checkpoint,
18 | )
19 | from experiments.uwm.train_robomimic import maybe_collect_rollout
20 |
21 |
22 | def train(rank, world_size, config):
23 | # Set global seed
24 | set_seed(config.seed * world_size + rank)
25 |
26 | # Initialize distributed training
27 | init_distributed(rank, world_size)
28 | device = torch.device(f"cuda:{rank}")
29 |
30 | # Initialize WANDB
31 | if is_main_process():
32 | init_wandb(config, job_type="train")
33 |
34 | # Create dataset and loader
35 | train_set, val_set = instantiate(config.dataset)
36 | train_loader, val_loader = make_distributed_data_loader(
37 | train_set, val_set, config.batch_size, rank, world_size
38 | )
39 |
40 | # Create model
41 | model = instantiate(config.model).to(device)
42 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
43 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
44 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
45 |
46 | # Load pretrained model
47 | if config.pretrain_checkpoint_path:
48 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
49 | model.load_state_dict(ckpt["model"])
50 | print(
51 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
52 | )
53 |
54 | # Resume from checkpoint
55 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
56 | epoch = step // len(train_loader)
57 |
58 | # Wrap model with DDP
59 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
60 |
61 | # Training loop
62 | pbar = tqdm(
63 | total=config.num_steps,
64 | initial=step,
65 | desc="Training",
66 | disable=not is_main_process(),
67 | )
68 | while step < config.num_steps:
69 | # Set epoch for distributed sampler to shuffle indices
70 | train_loader.sampler.set_epoch(epoch)
71 |
72 | # Train for one epoch
73 | for batch in train_loader:
74 | # --- Training step ---
75 | loss, info = train_one_step(
76 | config, model, optimizer, scheduler, scaler, batch, device
77 | )
78 |
79 | # --- Logging ---
80 | if is_main_process():
81 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
82 | wandb.log({f"train/{k}": v for k, v in info.items()})
83 |
84 | # --- Evaluate if needed ---
85 | maybe_evaluate(config, step, model, val_loader, device)
86 |
87 |             # --- Collect environment rollouts if needed ---
88 | maybe_collect_rollout(config, step, model, device)
89 |
90 | # --- Save checkpoint if needed ---
91 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler)
92 |
93 | step += 1
94 | pbar.update(1)
95 | if step >= config.num_steps:
96 | break
97 |
98 | epoch += 1
99 |
100 |
101 | @hydra.main(
102 | version_base=None,
103 | config_path="../../configs",
104 | config_name="train_pad_robomimic.yaml",
105 | )
106 | def main(config):
107 | # Resolve hydra config
108 | OmegaConf.resolve(config)
109 | # Spawn processes
110 | world_size = torch.cuda.device_count()
111 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
112 |
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
--------------------------------------------------------------------------------
/experiments/uwm/ablate_inverse_dynamics.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import hydra
4 | import h5py
5 | import imageio
6 | import torch
7 | from hydra.utils import instantiate
8 | from libero.libero.envs import OffScreenRenderEnv
9 | from libero.libero import get_libero_path
10 | from omegaconf import OmegaConf
11 | from tqdm import trange
12 |
13 |
14 | from datasets.utils.loader import make_distributed_data_loader
15 | from environments.robomimic.wrappers import LIBEROEnvWrapper
16 | from experiments.utils import set_seed
17 |
18 |
19 | def eval_inverse_dynamics(
20 | hdf5_path, obs_keys, obs_horizon, action_horizon, max_episode_length, model, device
21 | ):
22 | # Set eval mode
23 | model.eval()
24 |
25 | # Make environment to verify demo
26 | bddl_file_name = os.path.join(
27 | get_libero_path("bddl_files"),
28 | "libero_10",
29 | hdf5_path.split("/")[-1].replace("_demo.hdf5", ".bddl"),
30 | )
31 | env_kwargs = {
32 | "bddl_file_name": bddl_file_name,
33 | "camera_heights": 128,
34 | "camera_widths": 128,
35 | }
36 | if os.environ.get("CUDA_VISIBLE_DEVICES", None):
37 | env_kwargs["render_gpu_device_id"] = int(
38 | os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
39 | )
40 | env = OffScreenRenderEnv(**env_kwargs)
41 | env = LIBEROEnvWrapper(env, obs_keys, obs_horizon, max_episode_length, record=True)
42 |
43 | successes = []
44 | with h5py.File(hdf5_path) as f:
45 | demos = f["data"]
46 | for i in trange(len(demos)):
47 | demo = demos[f"demo_{i}"]
48 | obs = env.reset_to_state(demo["states"][0])
49 | next_index = action_horizon
50 | done = False
51 | while next_index + obs_horizon <= len(demo["actions"]) and not done:
52 | next_obs = {
53 | key: demo["obs"][key][next_index : next_index + obs_horizon]
54 | for key in obs_keys
55 | }
56 | obs_tensor = {
57 | k: torch.tensor(v, device=device)[None] for k, v in obs.items()
58 | }
59 | next_obs_tensor = {
60 | k: torch.tensor(v, device=device)[None] for k, v in next_obs.items()
61 | }
62 | action = model.sample_inverse_dynamics(obs_tensor, next_obs_tensor)
63 | # action = model.sample(obs_tensor)
64 | obs, _, done, info = env.step(action[0].cpu().numpy())
65 | next_index += action_horizon
66 | successes.append(info["success"])
67 | print(f"Episode {i}, success: {successes[-1]}")
68 |
69 | # Save video locally
70 | video = env.get_video()
71 | imageio.mimsave(f"episode_{i}.gif", video)
72 |
73 | print(f"Total: {len(demos)}, success: {sum(successes)}")
74 |
75 |
76 | def train(rank, world_size, config):
77 | # Set global seed
78 | set_seed(config.seed * world_size + rank)
79 |
80 | # Initialize distributed training
81 | device = torch.device(f"cuda:{rank}")
82 |
83 | # Create dataset
84 | train_set, val_set = instantiate(config.dataset)
85 | train_loader, val_loader = make_distributed_data_loader(
86 | train_set, val_set, config.batch_size, rank, world_size
87 | )
88 |
89 | # Create model
90 | model = instantiate(config.model).to(device)
91 |
92 | # Resume from checkpoint
93 | ckpt_path = os.path.join(config.logdir, "models.pt")
94 | ckpt = torch.load(ckpt_path, map_location="cpu")
95 | model.load_state_dict(ckpt["model"])
96 | step = ckpt["step"] + 1
97 | print(f"Loading checkpoint from step {step}")
98 |
99 | if ckpt["action_normalizer"] is not None:
100 | train_set.action_normalizer = ckpt["action_normalizer"]
101 | val_set.action_normalizer = ckpt["action_normalizer"]
102 | if ckpt["lowdim_normalizer"] is not None:
103 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
104 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
105 |
106 | obs_keys = list(model.obs_encoder.rgb_keys) + list(model.obs_encoder.low_dim_keys)
107 | eval_inverse_dynamics(
108 | config.dataset.hdf5_path_globs,
109 | obs_keys,
110 | config.model.obs_encoder.num_frames,
111 | config.model.action_len,
112 | config.rollout_length,
113 | model,
114 | device,
115 | )
116 |
117 |
118 | @hydra.main(
119 | version_base=None,
120 | config_path="../../configs",
121 | config_name="train_uwm_robomimic.yaml",
122 | )
123 | def main(config):
124 | # Resolve hydra config
125 | OmegaConf.resolve(config)
126 | train(0, 1, config)
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
--------------------------------------------------------------------------------
/environments/robomimic/wrappers.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import numpy as np
4 |
5 |
6 | class RoboMimicEnvWrapper:
7 | def __init__(
8 | self,
9 | env,
10 | obs_keys,
11 | obs_horizon,
12 | max_episode_length,
13 | record=False,
14 | render_size=(224, 224),
15 | ):
16 | self.env = env
17 | self.obs_keys = obs_keys
18 | self.obs_buffer = deque(maxlen=obs_horizon)
19 |
20 | self._max_episode_length = max_episode_length
21 | self._elapsed_steps = None
22 |
23 | self.record = record
24 | self.render_size = render_size
25 | if record:
26 | self.video_buffer = deque()
27 |
28 | def _is_success(self):
29 | return self.env.is_success()["task"]
30 |
31 | def _get_obs(self):
32 | # Return a dictionary of stacked observations
33 | stacked_obs = {}
34 | for key in self.obs_keys:
35 | stacked_obs[key] = np.stack([obs[key] for obs in self.obs_buffer])
36 | return stacked_obs
37 |
38 | def seed(self, seed):
39 | np.random.seed(seed)
40 |
41 | def reset(self):
42 | # Clear buffers
43 | self.obs_buffer.clear()
44 | if self.record:
45 | self.video_buffer.clear()
46 |
47 | # Reset environment
48 | obs = self.env.reset()
49 | self._elapsed_steps = 0
50 |
51 | # Pad observation buffer
52 | for _ in range(self.obs_buffer.maxlen):
53 | self.obs_buffer.append(obs)
54 |
55 | return self._get_obs()
56 |
57 | def step(self, actions):
58 | # Roll out a sequence of actions in the environment
59 | total_reward = 0
60 | for action in actions:
61 | # Step environment
62 | obs, reward, done, info = self.env.step(action)
63 | total_reward += reward
64 | self.obs_buffer.append(obs)
65 | if self.record:
66 | self.video_buffer.append(self.render())
67 |
68 | # Store success info
69 | info["success"] = self._is_success()
70 |
71 | # Terminate on success
72 | done = done or info["success"]
73 |
74 | # Terminate if max episode length is reached
75 | self._elapsed_steps += 1
76 | if self._elapsed_steps >= self._max_episode_length:
77 | info["truncated"] = not done
78 | done = True
79 |
80 | if done:
81 | break
82 |
83 | return self._get_obs(), total_reward, done, info
84 |
85 | def render(self):
86 | return self.env.render(
87 | mode="rgb_array",
88 | width=self.render_size[0],
89 | height=self.render_size[1],
90 | )
91 |
92 | def get_video(self):
93 | if not self.record:
94 | raise ValueError("Video recording is disabled.")
95 | return np.stack(self.video_buffer)
96 |
97 | def close(self):
98 | self.env.close()
99 |
100 |
101 | class LIBEROEnvWrapper(RoboMimicEnvWrapper):
102 | def __init__(
103 | self,
104 | env,
105 | obs_keys,
106 | obs_horizon,
107 | max_episode_length,
108 | record=False,
109 | render_size=(224, 224),
110 | ):
111 | super().__init__(
112 | env,
113 | obs_keys,
114 | obs_horizon,
115 | max_episode_length,
116 | record,
117 | render_size,
118 | )
119 | self.source_key_map = {
120 | "agentview_rgb": "agentview_image",
121 | "eye_in_hand_rgb": "robot0_eye_in_hand_image",
122 | }
123 |
124 | def reset_to_state(self, state):
125 | self.seed(0)
126 | self.reset()
127 | self.env.set_init_state(state)
128 |
129 | # Refresh obs buffer
130 | self.obs_buffer.clear()
131 | obs = self.env.env._get_observations()
132 | for _ in range(self.obs_buffer.maxlen):
133 | self.obs_buffer.append(obs)
134 | return self._get_obs()
135 |
136 | def _is_success(self):
137 | return self.env.check_success()
138 |
139 | def _get_obs(self):
140 | # Return a dictionary of stacked observations
141 | stacked_obs = {}
142 | for key in self.obs_keys:
143 | source_key = self.source_key_map.get(key, key)
144 | stacked_obs[key] = np.stack([obs[source_key] for obs in self.obs_buffer])
145 |
146 | # Flip all image observations
147 | for key in self.obs_keys:
148 | if len(stacked_obs[key].shape) == 4:
149 | stacked_obs[key] = stacked_obs[key][:, ::-1].copy()
150 | return stacked_obs
151 |
152 | def render(self):
153 | img = self.env.env.sim.render(
154 | height=self.render_size[1],
155 | width=self.render_size[0],
156 | camera_name="frontview",
157 | )
158 | return img[::-1]
159 |
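160 | if __name__ == "__main__":
161 |     # Illustrative sketch with a stub environment (not a real Robomimic env): it
162 |     # shows that step() consumes a whole action chunk and that observations come
163 |     # back stacked over the obs_horizon most recent frames.
164 |     class _StubEnv:
165 |         def reset(self):
166 |             return {"agentview_image": np.zeros((84, 84, 3), dtype=np.uint8)}
167 |
168 |         def step(self, action):
169 |             return self.reset(), 0.0, False, {}
170 |
171 |         def is_success(self):
172 |             return {"task": False}
173 |
174 |     env = RoboMimicEnvWrapper(
175 |         _StubEnv(), obs_keys=["agentview_image"], obs_horizon=2, max_episode_length=10
176 |     )
177 |     obs = env.reset()
178 |     print(obs["agentview_image"].shape)  # (2, 84, 84, 3)
179 |     obs, reward, done, info = env.step(np.zeros((4, 7)))  # chunk of 4 actions
180 |     print(done, info["success"])  # False False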
--------------------------------------------------------------------------------
/experiments/uwm/ablate_forward_dynamics.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import hydra
4 | import imageio
5 | import torch
6 | import numpy as np
7 | from diffusers.optimization import get_scheduler
8 | from einops import rearrange
9 | from hydra.utils import instantiate
10 | from omegaconf import OmegaConf
11 | from torchvision.utils import make_grid
12 | from tqdm import tqdm
13 |
14 | from datasets.utils.loader import make_distributed_data_loader
15 | from experiments.utils import set_seed, is_main_process
16 |
17 |
18 | def process_batch(batch, obs_horizon, action_horizon, device):
19 | action_start = obs_horizon - 1
20 | action_end = action_start + action_horizon
21 | curr_obs = {k: v[:, : action_start + 1].to(device) for k, v in batch["obs"].items()}
22 | next_obs = {k: v[:, action_end:].to(device) for k, v in batch["obs"].items()}
23 | actions = batch["action"][:, action_start:action_end].to(device)
24 |
25 | # Add language tokens
26 | if "input_ids" in batch and "attention_mask" in batch:
27 | curr_obs["input_ids"] = batch["input_ids"].to(device)
28 | curr_obs["attention_mask"] = batch["attention_mask"].to(device)
29 | return curr_obs, next_obs, actions
30 |
31 |
32 | def eval_one_epoch(config, data_loader, device, model):
33 | model.eval()
34 | model = getattr(model, "module", model) # unwrap DDP
35 |
36 | def decode_and_plot(images, nrows=2):
37 | images = model.obs_encoder.apply_vae(images, inverse=True)
38 | images = rearrange(images, "b v c t h w -> (b v t) c h w").clamp(0, 1)
39 | images_grid = make_grid(images, nrows)
40 | return (
41 | (images_grid.cpu().numpy() * 255)
42 | .round()
43 | .astype(np.uint8)
44 | .transpose(1, 2, 0)
45 | )
46 |
47 | save_path = f"viz_{config.dataset.name}"
48 | if not os.path.exists(save_path):
49 | os.mkdir(save_path)
50 | step = 0
51 | for batch in tqdm(data_loader, desc="Evaluating", disable=not is_main_process()):
52 | # ------------ Preprocess data ------------ #
53 | curr_obs_dict, next_obs_dict, action_norm = process_batch(
54 | batch, config.model.obs_encoder.num_frames, config.model.action_len, device
55 | )
56 |
57 | with torch.no_grad():
58 | # Encode current observations
59 | curr_obs = model.obs_encoder.encode_next_obs(curr_obs_dict)
60 |
61 | # Encode next observations
62 | next_obs = model.obs_encoder.encode_next_obs(next_obs_dict)
63 |
64 | # Sample observations from forward dynamics
65 | next_obs_hat_forward = model.sample_forward_dynamics(
66 | curr_obs_dict, action_norm
67 | )
68 |
69 | # Plot current observations
70 | curr_obs = decode_and_plot(curr_obs[:1])
71 | imageio.imwrite(f"{save_path}/{step}_curr_obs.png", curr_obs)
72 |
73 | # Plot next observations
74 | next_obs = decode_and_plot(next_obs[:1])
75 | imageio.imwrite(f"{save_path}/{step}_next_obs.png", next_obs)
76 |
77 | # Plot predicted next observations
78 | next_obs_hat_forward = decode_and_plot(next_obs_hat_forward[:1])
79 | imageio.imwrite(f"{save_path}/{step}_next_obs_hat.png", next_obs_hat_forward)
80 |
81 | step += 1
82 | if step == 15:
83 | break
84 |
85 |
86 | def train(rank, world_size, config):
87 | # Set global seed
88 | set_seed(config.seed * world_size + rank)
89 |
90 | # Initialize distributed training
91 | device = torch.device(f"cuda:{rank}")
92 |
93 | # Create dataset
94 | train_set, val_set = instantiate(config.dataset)
95 | train_loader, val_loader = make_distributed_data_loader(
96 | train_set, val_set, config.batch_size, rank, world_size
97 | )
98 |
99 | # Create model
100 | model = instantiate(config.model).to(device)
101 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
102 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
103 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
104 |
105 | # Resume from checkpoint
106 | ckpt_path = os.path.join(config.logdir, "models.pt")
107 | ckpt = torch.load(ckpt_path, map_location="cpu")
108 | model.load_state_dict(ckpt["model"])
109 | optimizer.load_state_dict(ckpt["optimizer"])
110 | scheduler.load_state_dict(ckpt["scheduler"])
111 | scaler.load_state_dict(ckpt["scaler"])
112 | step = ckpt["step"] + 1
113 | print(f"Resumed training from step {step}")
114 |
115 | train_set.action_normalizer = ckpt["action_normalizer"]
116 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
117 | val_set.action_normalizer = ckpt["action_normalizer"]
118 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
119 |
120 | eval_one_epoch(config, val_loader, device, model)
121 |
122 |
123 | @hydra.main(
124 | version_base=None, config_path="../../configs", config_name="train_uwm.yaml"
125 | )
126 | def main(config):
127 | # Resolve hydra config
128 | OmegaConf.resolve(config)
129 | train(0, 1, config)
130 |
131 |
132 | if __name__ == "__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
/models/dp/base_policy.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
7 |
8 |
9 | class NoisePredictionNet(nn.Module, ABC):
10 |
11 | @abstractmethod
12 | def forward(self, sample, timestep, global_cond):
13 | raise NotImplementedError
14 |
15 |
16 | class DiffusionPolicy(nn.Module):
17 | def __init__(
18 | self,
19 | action_len,
20 | action_dim,
21 | noise_pred_net,
22 | num_train_steps=100,
23 | num_inference_steps=10,
24 | num_train_noise_samples=1,
25 | beta_schedule="squaredcos_cap_v2",
26 | clip_sample=True,
27 | ):
28 | super().__init__()
29 | self.action_len = action_len
30 | self.action_dim = action_dim
31 | self.num_train_steps = num_train_steps
32 | self.num_inference_steps = num_inference_steps
33 | self.num_train_noise_samples = num_train_noise_samples
34 |
35 | # Noise prediction net
36 | assert isinstance(noise_pred_net, NoisePredictionNet)
37 | self.noise_pred_net = noise_pred_net
38 |
39 | # Noise scheduler
40 | self.noise_scheduler = DDIMScheduler(
41 | num_train_timesteps=num_train_steps,
42 | beta_schedule=beta_schedule,
43 | clip_sample=clip_sample,
44 | )
45 |
46 | @torch.no_grad()
47 | def sample(self, obs):
48 | # Initialize sample
49 | action = torch.randn(
50 | (obs.shape[0], self.action_len, self.action_dim), device=obs.device
51 | )
52 |
53 | # Initialize scheduler
54 | self.noise_scheduler.set_timesteps(self.num_inference_steps)
55 |
56 | # Reverse diffusion process
57 | for t in self.noise_scheduler.timesteps:
58 | # Predict noise
59 | noise_pred = self.noise_pred_net(action, t, global_cond=obs)
60 |
61 | # Diffusion step
62 | action = self.noise_scheduler.step(noise_pred, t, action).prev_sample
63 |
64 | return action
65 |
66 | def forward(self, obs, action):
67 | # Repeat observations and actions for multiple noise samples
68 | if self.num_train_noise_samples > 1:
69 | obs = obs.repeat_interleave(self.num_train_noise_samples, dim=0)
70 | action = action.repeat_interleave(self.num_train_noise_samples, dim=0)
71 |
72 | # Sample random noise
73 | noise = torch.randn_like(action)
74 |
75 | # Sample a random timestep
76 | t = torch.randint(
77 | low=0,
78 | high=self.num_train_steps,
79 | size=(action.shape[0],),
80 | device=action.device,
81 | ).long()
82 |
83 | # Forward diffusion step
84 | noisy_action = self.noise_scheduler.add_noise(action, noise, t)
85 |
86 | # Diffusion loss
87 | noise_pred = self.noise_pred_net(noisy_action, t, global_cond=obs)
88 | loss = F.mse_loss(noise_pred, noise)
89 | return loss
90 |
91 |
92 | class FlowPolicy(nn.Module):
93 | def __init__(
94 | self,
95 | action_len,
96 | action_dim,
97 | noise_pred_net,
98 | num_train_steps=100,
99 | num_inference_steps=10,
100 | timeshift=1.0,
101 | ):
102 | super().__init__()
103 | self.action_len = action_len
104 | self.action_dim = action_dim
105 |
106 | # Noise prediction net
107 | assert isinstance(noise_pred_net, NoisePredictionNet)
108 | self.noise_pred_net = noise_pred_net
109 |
110 | self.num_train_steps = num_train_steps
111 | self.num_inference_steps = num_inference_steps
112 | timesteps = torch.linspace(1, 0, self.num_inference_steps + 1)
113 | self.timesteps = (timeshift * timesteps) / (1 + (timeshift - 1) * timesteps)
114 |
115 | @torch.no_grad()
116 | def sample(self, obs):
117 | # Initialize sample
118 | action = torch.randn(
119 | (obs.shape[0], self.action_len, self.action_dim), device=obs.device
120 | )
121 |
122 | for tcont, tcont_next in zip(self.timesteps[:-1], self.timesteps[1:]):
123 |             # Predict velocity (the flow direction, noise - action)
124 | t = (tcont * self.num_train_steps).long()
125 | noise_pred = self.noise_pred_net(action, t, global_cond=obs)
126 |
127 | # Flow step
128 | action = action + (tcont_next - tcont) * noise_pred
129 |
130 | return action
131 |
132 | def forward(self, obs, action):
133 | # Sample random noise
134 | noise = torch.randn_like(action)
135 |
136 | # Sample random timestep
137 | tcont = torch.rand((action.shape[0],), device=action.device)
138 |
139 | # Forward flow step
140 | direction = noise - action
141 | noisy_action = (
142 | action + tcont.view(-1, *[1 for _ in range(action.dim() - 1)]) * direction
143 | )
144 |
145 | # Flow matching loss
146 | t = (tcont * self.num_train_steps).long()
147 | noise_pred = self.noise_pred_net(noisy_action, t, global_cond=obs)
148 | loss = F.mse_loss(noise_pred, direction)
149 | return loss
150 |
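151 | if __name__ == "__main__":
152 |     # Minimal smoke test (illustrative only): plug a toy MLP noise-prediction net
153 |     # into FlowPolicy and run one training pass and one sampling pass. The toy net
154 |     # and all dimensions below are assumptions, not components of this repository.
155 |     class _ToyNoiseNet(NoisePredictionNet):
156 |         def __init__(self, action_len, action_dim, cond_dim, hidden=128):
157 |             super().__init__()
158 |             self.out_shape = (action_len, action_dim)
159 |             self.net = nn.Sequential(
160 |                 nn.Linear(action_len * action_dim + cond_dim + 1, hidden),
161 |                 nn.Mish(),
162 |                 nn.Linear(hidden, action_len * action_dim),
163 |             )
164 |
165 |         def forward(self, sample, timestep, global_cond):
166 |             b = sample.shape[0]
167 |             t = timestep.float().reshape(-1, 1).expand(b, 1)
168 |             x = torch.cat([sample.flatten(1), global_cond, t], dim=-1)
169 |             return self.net(x).view(b, *self.out_shape)
170 |
171 |     policy = FlowPolicy(action_len=8, action_dim=7, noise_pred_net=_ToyNoiseNet(8, 7, cond_dim=32))
172 |     obs = torch.randn(2, 32)  # flattened observation features
173 |     print("flow matching loss:", policy(obs, torch.randn(2, 8, 7)).item())
174 |     print("sampled action chunk:", policy.sample(obs).shape)  # torch.Size([2, 8, 7])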
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unified World Models: Coupling Video and Action Diffusion for Pretraining on Large Robotic Datasets
2 |
3 | #### [[Website]](https://weirdlabuw.github.io/uwm/) [[Paper]](https://arxiv.org/abs/2504.02792) [[Talk]](https://www.youtube.com/watch?v=WwPRxBbZ4kw)
4 |
5 | [Chuning Zhu<sup>1</sup>](https://homes.cs.washington.edu/~zchuning/), [Raymond Yu<sup>1</sup>](https://raymondyu5.github.io/), [Siyuan Feng<sup>2</sup>](https://www.cs.cmu.edu/~sfeng/), [Benjamin Burchfiel<sup>2</sup>](https://scholar.google.com/citations?user=eGoTK1YAAAAJ&hl=en), [Paarth Shah<sup>2</sup>](https://www.paarthshah.me/about), [Abhishek Gupta<sup>1</sup>](https://homes.cs.washington.edu/~abhgupta/)
6 |
7 | <sup>1</sup>University of Washington <sup>2</sup>Toyota Research Institute
8 |
9 | This repository provides a PyTorch implementation of Unified World Model (UWM). UWM combines action diffusion and video diffusion to enable scalable pretraining on large, heterogeneous robotics datasets.
10 |
11 |
12 | ## Code structure
13 | * `configs`: Configuration files for pretraining and finetuning experiments.
14 | * `datasets`: Dataset wrappers for DROID, Robomimic, and LIBERO. We standardize all datasets using compressed [Zarr](https://zarr.readthedocs.io/en/stable/) buffers.
15 | * `environments`: Interface wrappers for Robomimic and LIBERO environments.
16 | * `experiments`: Training and evaluation scripts.
17 | * `models`: Model definitions for UWM and baselines.
18 | * `scripts`: Bash scripts for running DROID experiments.
19 |
20 |
21 | ## Setup
22 | Install the package via
23 | ```
24 | pip install -e .
25 | ```
26 | > Note: If you encounter issues using tensorflow-datasets with DROID, consider installing tensorflow-datasets from [source](https://github.com/tensorflow/datasets).
27 |
28 | ## Robomimic Experiments
29 | To run a Robomimic single-task experiment,
30 | 1. Install the [Robomimic](https://github.com/ARISE-Initiative/robomimic) dataset.
31 | 2. Update `hdf5_path` and `buffer_path` in the config (e.g., `configs/dataset/robomimic_can_ph.yaml`).
32 | 3. Run:
33 | ```
34 | python experiments/uwm/train_robomimic.py --config-name train_uwm_robomimic.yaml dataset=robomimic_can_ph exp_id=singletask
35 | ```
36 | This command will generate a compressed Zarr buffer at the `buffer_path` specified in the config file.
37 |
38 | ## LIBERO Experiments
39 | The LIBERO experiments share most infrastructure with the Robomimic experiments.
40 |
41 | ### Pretraining
42 | To pretrain a UWM on LIBERO-90,
43 | 1. Install the [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) dataset.
44 | 2. Update `hdf5_path` and `buffer_path` in `configs/dataset/libero_90.yaml`.
45 | 3. Run:
46 | ```
47 | python experiments/uwm/train_robomimic.py --config-name train_uwm_robomimic.yaml dataset=libero_90 exp_id=pretrain
48 | ```
49 |
50 | ### Finetuning
51 | To finetune a pretrained UWM on a downstream LIBERO task (e.g., Book-Caddy),
52 | 1. Update `hdf5_path` and `buffer_path` in `configs/dataset/libero_book_caddy.yaml`.
53 | 2. Run:
54 | ```
55 | python experiments/uwm/train_robomimic.py --config-name finetune_uwm_robomimic.yaml dataset=libero_book_caddy exp_id=finetune pretrain_checkpoint_path="logdir/uwm/libero_90/pretrain/0/models.pt"
56 | ```
57 |
58 | We release the pretrained LIBERO-90 checkpoint [here](https://drive.google.com/drive/folders/1M4AuVLMRpSwOf_YAp56bV9AqyZI9ul6g?usp=sharing). You can download and directly finetune from this checkpoint.
59 |
60 | ## DROID Experiments
61 | We provide shell scripts for DROID pretraining / cotraining / finetuning experiments in the `scripts` directory. Each script runs a dataset conversion pipeline to create a Zarr buffer for the corresponding DROID TFDS dataset and then launches training.
62 |
63 | ### Pretraining
64 | To launch a DROID pretraining experiment,
65 | 1. Install the [DROID](https://droid-dataset.github.io/) dataset
66 | 2. Update `DATA_DIR` and `BUFFER_PATH` in `scripts/launch_droid_pretrain.sh`
67 | 3. Run:
68 | ```
69 | source scripts/launch_droid_pretrain.sh
70 | ```
71 |
72 | ### Cotraining
73 | To launch a video cotraining experiment,
74 | 1. Install the [DROID](https://droid-dataset.github.io/) dataset
75 | 2. Update `DATA_DIR`, `ROBOT_BUFFER_PATH`, and `VIDEO_BUFFER_PATH` in `scripts/launch_droid_cotrain.sh`
76 | 3. Run:
77 | ```
78 | source scripts/launch_droid_cotrain.sh
79 | ```
80 |
81 | ### Finetuning
82 | To finetune a pretrained model on a downstream task,
83 | 1. Collect demonstrations using the DROID interface
84 | 2. Convert them into a TFDS dataset (via this [pipeline](https://github.com/kpertsch/droid_dataset_builder))
85 | 3. Modify and run:
86 | ```
87 | source scripts/launch_droid_finetune.sh
88 | ```
89 |
90 | We release the pretrained and cotrained DROID UWM checkpoints [here](https://drive.google.com/drive/folders/1M4AuVLMRpSwOf_YAp56bV9AqyZI9ul6g?usp=sharing). You can download and directly finetune from these checkpoints.
91 |
92 | ## Bibtex
93 | If you find this code useful, please cite:
94 |
95 | ```
96 | @inproceedings{zhu2025uwm,
97 | author = {Zhu, Chuning and Yu, Raymond and Feng, Siyuan and Burchfiel, Benjamin and Shah, Paarth and Gupta, Abhishek},
98 | title = {Unified World Models: Coupling Video and Action Diffusion for Pretraining on Large Robotic Datasets},
99 | booktitle = {Proceedings of Robotics: Science and Systems (RSS)},
100 | year = {2025},
101 | }
102 | ```
--------------------------------------------------------------------------------
/models/dp/obs_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange
6 | from torchvision.transforms import Normalize
7 |
8 | from models.common.language import CLIPTextEncoder
9 | from models.common.transforms import ImageTransform
10 | from models.common.vision import get_resnet, get_vit, get_clip
11 |
12 |
13 | class ImageObservationEncoder(nn.Module):
14 | def __init__(
15 | self,
16 | shape_meta: dict,
17 | num_frames: int,
18 | embed_dim: int,
19 | resize_shape: Tuple[int, int] = None,
20 | crop_shape: Tuple[int, int] = None,
21 | random_crop: bool = True,
22 | color_jitter: Optional[Dict] = None,
23 | imagenet_norm: bool = True,
24 | pretrained_weights: Optional[str] = None,
25 | use_low_dim: bool = True,
26 | use_language: bool = True,
27 | ):
28 | """
29 | Assumes rgb input: (B, T, H, W, C) uint8 image
30 | Assumes low_dim input: (B, T, D)
31 | """
32 | super().__init__()
33 | rgb_keys = list()
34 | low_dim_keys = list()
35 | key_shape_map = dict()
36 | key_transform_map = nn.ModuleDict()
37 |
38 | obs_shape_meta = shape_meta["obs"]
39 | for key, attr in obs_shape_meta.items():
40 | obs_shape = tuple(attr["shape"])
41 | key_shape_map[key] = obs_shape
42 |
43 | obs_type = attr.get("type", "low_dim")
44 | if obs_type == "rgb":
45 | rgb_keys.append(key)
46 | key_transform_map[key] = ImageTransform(
47 | resize_shape=resize_shape,
48 | crop_shape=crop_shape,
49 | random_crop=random_crop,
50 | color_jitter=color_jitter,
51 | imagenet_norm=imagenet_norm,
52 | )
53 | elif obs_type == "low_dim":
54 | low_dim_keys.append(key)
55 | else:
56 |                 raise RuntimeError(f"Unsupported obs type: {obs_type}")
57 |
58 | self.shape_meta = shape_meta
59 | self.num_frames = num_frames
60 | self.embed_dim = embed_dim
61 | self.rgb_keys = sorted(rgb_keys)
62 | self.low_dim_keys = sorted(low_dim_keys)
63 | self.key_shape_map = key_shape_map
64 | self.key_transform_map = key_transform_map
65 |
66 | # RGB model
67 | if pretrained_weights == "clip":
68 | assert not imagenet_norm, "imagenet_norm must be False for CLIP encoder"
69 | norm = Normalize(
70 | mean=[0.48145466, 0.4578275, 0.40821073],
71 | std=[0.26862954, 0.26130258, 0.27577711],
72 | inplace=True,
73 | )
74 | model = get_clip(embed_dim)
75 | self.rgb_encoder = nn.Sequential(norm, model)
76 | elif pretrained_weights == "vit":
77 | self.rgb_encoder = get_vit(
78 | "vit_b_32", embed_dim, weights=pretrained_weights
79 | )
80 | else:
81 | self.rgb_encoder = get_resnet(
82 | "resnet18", embed_dim, weights=pretrained_weights
83 | )
84 |
85 | # Low dim model
86 | self.use_low_dim = use_low_dim
87 | self.low_dim_size = sum([key_shape_map[key][-1] for key in low_dim_keys])
88 |
89 | # Language model
90 | self.use_language = use_language
91 | self.text_encoder = (
92 | CLIPTextEncoder(embed_dim=embed_dim) if use_language else None
93 | )
94 |
95 | def __call__(self, obs_dict):
96 | # Process rgb observations
97 | imgs = list()
98 | for key in self.rgb_keys:
99 | img = obs_dict[key].flatten(0, 1)
100 | assert img.shape[1:] == self.key_shape_map[key]
101 | img = self.key_transform_map[key](img) # (B*T, C, H, W)
102 | imgs.append(img)
103 |
104 | # Concatenate along batch dimension
105 | imgs = torch.cat(imgs, dim=0) # (N*B*T, C, H, W)
106 | feats = self.rgb_encoder(imgs) # (N*B*T, D)
107 | feats = rearrange(
108 | feats, "(n b t) d -> b (t n d)", n=len(self.rgb_keys), t=self.num_frames
109 | )
110 |
111 | if self.use_low_dim:
112 | # Process low dim observations
113 | low_dims = list()
114 | for key in self.low_dim_keys:
115 | low_dim = obs_dict[key].flatten(0, 1)
116 | assert low_dim.shape[1:] == self.key_shape_map[key]
117 | low_dims.append(low_dim)
118 | low_dims = torch.cat(low_dims, dim=-1) # (B*T, D_low_dim)
119 | low_dims = rearrange(low_dims, "(b t) d -> b (t d)", t=self.num_frames)
120 |
121 | # Concatenate image and lowdim features
122 | feats = torch.cat([feats, low_dims], dim=-1)
123 |
124 | # Encode language
125 | if self.use_language:
126 | lang_feats = self.text_encoder(
127 | input_ids=obs_dict["input_ids"],
128 | attention_mask=obs_dict["attention_mask"],
129 | )
130 | feats = torch.cat([feats, lang_feats], dim=-1)
131 |
132 | return feats
133 |
134 | @property
135 | def output_len(self):
136 | return (
137 | len(self.rgb_keys) * self.embed_dim
138 | + int(self.use_low_dim) * self.low_dim_size
139 | ) * self.num_frames + int(self.use_language) * self.embed_dim
140 |
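A minimal usage sketch for `ImageObservationEncoder`, with a hypothetical `shape_meta` in the `{obs: {key: {shape, type}}, action: {shape}}` layout used by the dataset code. It assumes `get_resnet` maps each frame to an `embed_dim`-dimensional vector (as it is used above) and disables the CLIP text encoder to keep the example self-contained:

```python
# Sketch only: hypothetical observation keys and shapes.
import torch

from models.dp.obs_encoder import ImageObservationEncoder

shape_meta = {
    "obs": {
        "agentview_image": {"shape": [84, 84, 3], "type": "rgb"},
        "robot0_eef_pos": {"shape": [3], "type": "low_dim"},
    },
    "action": {"shape": [7]},
}
encoder = ImageObservationEncoder(
    shape_meta=shape_meta,
    num_frames=2,
    embed_dim=128,
    use_language=False,  # skip the CLIP text encoder for this sketch
)

obs_dict = {
    "agentview_image": torch.zeros(4, 2, 84, 84, 3, dtype=torch.uint8),  # (B, T, H, W, C)
    "robot0_eef_pos": torch.zeros(4, 2, 3),                              # (B, T, D)
}
feats = encoder(obs_dict)
# (rgb_keys * embed_dim + low_dim_size) * num_frames = (128 + 3) * 2 = 262
assert feats.shape == (4, encoder.output_len)
```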
--------------------------------------------------------------------------------
/models/pad/obs_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange
6 |
7 | from models.common.transforms import VideoTransform, VAEDownsample
8 |
9 |
10 | class PADObservationEncoder(nn.Module):
11 | def __init__(
12 | self,
13 | shape_meta: dict,
14 | num_frames: int,
15 | resize_shape: Tuple[int, int] = None,
16 | crop_shape: Tuple[int, int] = None,
17 | random_crop: bool = True,
18 | color_jitter: Optional[Dict] = None,
19 | imagenet_norm: bool = False,
20 | ):
21 | super().__init__()
22 | self.shape_meta = shape_meta
23 | self.num_frames = num_frames
24 | self.rgb_keys = sorted(
25 | [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"]
26 | )
27 | self.low_dim_keys = sorted(
28 | [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"]
29 | )
30 | self.num_views = len(self.rgb_keys)
31 |
32 | # Image augmentation
33 | self.obs_transform = VideoTransform(
34 | resize_shape=resize_shape,
35 | crop_shape=crop_shape,
36 | random_crop=random_crop,
37 | color_jitter=color_jitter,
38 | imagenet_norm=imagenet_norm,
39 | )
40 |
41 | # Downsample future observations using vae
42 | self.vae = VAEDownsample()
43 |
44 | def apply_transform(self, obs_dicts: Union[dict, list[dict]]):
45 | """
46 |         Accept an observation dictionary or a list of them and apply the same transform to each.
47 | """
48 | if isinstance(obs_dicts, dict):
49 | obs_dicts = [obs_dicts]
50 | is_singleton = True
51 | else:
52 | is_singleton = False
53 | assert isinstance(obs_dicts, list)
54 |
55 | # Apply the same transform to each observation
56 | num_obs = len(obs_dicts)
57 | transformed_imgs = [[] for _ in range(num_obs)]
58 | for key in self.rgb_keys:
59 | combined_imgs = torch.cat([obs_dict[key] for obs_dict in obs_dicts], dim=0)
60 | combined_imgs = self.obs_transform(combined_imgs)
61 | chunked_imgs = combined_imgs.chunk(num_obs, dim=0)
62 | for i, img in enumerate(chunked_imgs):
63 | transformed_imgs[i].append(img)
64 |
65 | # Stack transformed images
66 | # Each image has shape (B, V, C, T, H, W)
67 | transformed_imgs = [torch.stack(imgs, dim=1) for imgs in transformed_imgs]
68 | if is_singleton:
69 | transformed_imgs = transformed_imgs[0]
70 | return transformed_imgs
71 |
72 | def apply_vae(
73 | self,
74 | imgs_list: Union[torch.Tensor, list[torch.Tensor]],
75 | inverse: bool = False,
76 | microbatch_size: int = 72, # Tuned for 40GB VRAM
77 | ):
78 | """
79 |         Accept an image tensor or a list of image tensors and apply the VAE to downsample or upsample them.
80 | If inverse is False, downsample images. Otherwise, upsample images.
81 | Process images in microbatches to reduce memory usage.
82 | """
83 | if isinstance(imgs_list, torch.Tensor):
84 | imgs_list = [imgs_list]
85 | is_singleton = True
86 | else:
87 | is_singleton = False
88 | assert isinstance(imgs_list, list)
89 | imgs = torch.cat(imgs_list, dim=0)
90 |
91 | # Flatten multiview videos to images
92 | B, V = imgs.shape[:2]
93 | imgs = rearrange(imgs, "b v c t h w -> (b v t) c h w")
94 |
95 | # Process images in microbatches
96 | transformed_imgs = []
97 | for i in range(0, imgs.shape[0], microbatch_size):
98 | batch_imgs = imgs[i : i + microbatch_size]
99 | if inverse:
100 | batch_transformed_imgs = self.vae.inverse(batch_imgs)
101 | else:
102 | batch_transformed_imgs = self.vae(batch_imgs)
103 | transformed_imgs.append(batch_transformed_imgs)
104 | transformed_imgs = torch.cat(transformed_imgs, dim=0)
105 |
106 | # Unflatten images to multiview videos
107 | transformed_imgs = rearrange(
108 | transformed_imgs, "(b v t) c h w -> b v c t h w", b=B, v=V
109 | )
110 | if not is_singleton:
111 | chunk_sizes = [img.shape[0] for img in imgs_list]
112 | transformed_imgs = list(transformed_imgs.split(chunk_sizes, dim=0))
113 | return transformed_imgs
114 |
115 | def encode_obs(self, obs_dict: dict):
116 | imgs = self.apply_transform(obs_dict)
117 | latents = self.apply_vae(imgs)
118 | return latents
119 |
120 | def encode_curr_and_next_obs(self, curr_obs_dict: dict, next_obs_dict: dict):
121 | # Apply the same transform to obs and next obs
122 | curr_imgs, next_imgs = self.apply_transform([curr_obs_dict, next_obs_dict])
123 | curr_latents, next_latents = self.apply_vae([curr_imgs, next_imgs])
124 | return curr_latents, next_latents
125 |
126 | def latent_img_shape(self):
127 | # Construct dummy image and forward pass to get latent shape
128 | dummy_obs_dict = {}
129 | for k in self.rgb_keys:
130 | img_shape = self.shape_meta["obs"][k]["shape"]
131 | dummy_obs_dict[k] = torch.zeros(
132 | 1, self.num_frames, *img_shape, dtype=torch.uint8
133 | )
134 | with torch.no_grad():
135 | dummy_imgs = self.apply_transform(dummy_obs_dict)
136 | dummy_latents = self.apply_vae(dummy_imgs)
137 | return tuple(dummy_latents.shape[1:])
138 |
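A sketch of how the PAD encoder can be exercised. Constructing it downloads the pretrained `stabilityai/sdxl-vae` weights through `diffusers`; the `shape_meta` below is hypothetical, with image sides chosen as multiples of 8 so the VAE downsamples cleanly:

```python
# Sketch only: hypothetical shape_meta; downloads the SDXL VAE on first use.
import torch

from models.pad.obs_encoder import PADObservationEncoder

shape_meta = {
    "obs": {
        "agentview_image": {"shape": [96, 96, 3], "type": "rgb"},
        "robot0_eef_pos": {"shape": [3], "type": "low_dim"},
    },
    "action": {"shape": [7]},
}
encoder = PADObservationEncoder(shape_meta=shape_meta, num_frames=2)

# Latent shape of a dummy multiview video: (num_views, 4, num_frames, H/8, W/8)
print(encoder.latent_img_shape())  # -> (1, 4, 2, 12, 12)

# Encoding observations follows the same transform -> VAE path
obs_dict = {"agentview_image": torch.zeros(1, 2, 96, 96, 3, dtype=torch.uint8)}
with torch.no_grad():
    latents = encoder.encode_obs(obs_dict)  # (B, V, 4, T, 12, 12)
```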
--------------------------------------------------------------------------------
/models/gr1/obs_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange
6 |
7 | from models.common.transforms import VideoTransform
8 | from models.common.vision import get_vit
9 | from models.gr1.flamingo import PerceiverResampler
10 |
11 |
12 | class MultiViewViTImageEncoder(nn.Module):
13 | def __init__(
14 | self, num_views: int, num_frames: int, embed_dim: int, resampler_params: dict
15 | ):
16 | super().__init__()
17 | self.num_views = num_views
18 | self.model = get_vit("vit_b_32", embed_dim, weights="IMAGENET1K_V1")
19 |
20 | # Perceiver resampler
21 | self.perceiver_resampler = PerceiverResampler(**resampler_params)
22 |
23 | # Learnable embeddings
24 | self.pos_shift = nn.Parameter(
25 | torch.zeros(1, num_views * num_frames, 1, embed_dim),
26 | requires_grad=True,
27 | )
28 | self.pos_scale = nn.Parameter(
29 | torch.zeros(1, num_views * num_frames, 1, embed_dim),
30 | requires_grad=True,
31 | )
32 |
33 | self.feat_head = nn.Linear(embed_dim, embed_dim)
34 | self.patch_head = nn.Linear(embed_dim, embed_dim)
35 |
36 | def forward(self, imgs: torch.Tensor):
37 | B, V = imgs.shape[:2]
38 | imgs = rearrange(imgs, "b v c t h w -> (b v t) c h w")
39 |
40 | # Reshape and permute the input tensor
41 | x = self.model._process_input(imgs)
42 | n = x.shape[0]
43 |
44 | # Expand the class token to the full batch
45 | batch_class_token = self.model.class_token.expand(n, -1, -1)
46 | x = torch.cat([batch_class_token, x], dim=1)
47 |
48 | # Get raw tokens
49 | x = self.model.encoder(x)
50 | feats, patch_embeds = x[:, :1], x[:, 1:]
51 |
52 | # Pass through perceiver resampler
53 | feats = self.feat_head(feats)
54 | patch_embeds = self.patch_head(
55 | self.perceiver_resampler(x.unsqueeze(1)).squeeze(1)
56 | )
57 |
58 | # Add learned positional embeddings
59 | x = torch.cat([feats, patch_embeds], dim=1)
60 | x = rearrange(x, "(b v t) n c -> b (v t) n c", b=B, v=V)
61 | x = x * (1 + self.pos_scale) + self.pos_shift
62 | return x.flatten(1, 2) # (b, v*t*n, c)
63 |
64 |
65 | class GR1ObservationEncoder(nn.Module):
66 | def __init__(
67 | self,
68 | shape_meta: dict,
69 | num_frames: int,
70 | embed_dim: int,
71 | resize_shape: Tuple[int, int] = None,
72 | crop_shape: Tuple[int, int] = None,
73 | random_crop: bool = True,
74 | color_jitter: Optional[Dict] = None,
75 | imagenet_norm: bool = False,
76 | resampler_params: dict = None,
77 | ):
78 | super().__init__()
79 | self.shape_meta = shape_meta
80 | self.num_frames = num_frames
81 | self.rgb_keys = sorted(
82 | [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"]
83 | )
84 | self.low_dim_keys = sorted(
85 | [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"]
86 | )
87 | self.num_views = len(self.rgb_keys)
88 |
89 | # Image augmentation
90 | self.obs_transform = VideoTransform(
91 | resize_shape=resize_shape,
92 | crop_shape=crop_shape,
93 | random_crop=random_crop,
94 | color_jitter=color_jitter,
95 | imagenet_norm=imagenet_norm,
96 | )
97 |
98 | # Image encoder
99 | self.img_encoder = MultiViewViTImageEncoder(
100 | num_views=self.num_views,
101 | num_frames=self.num_frames,
102 | embed_dim=embed_dim,
103 | resampler_params=resampler_params,
104 | )
105 |
106 | def apply_transform(self, obs_dicts: Union[dict, list[dict]]):
107 | """
108 |         Accept an observation dictionary or a list of them and apply the same transform to each.
109 | """
110 | if isinstance(obs_dicts, dict):
111 | obs_dicts = [obs_dicts]
112 | is_singleton = True
113 | else:
114 | is_singleton = False
115 | assert isinstance(obs_dicts, list)
116 |
117 | # Apply the same transform to each observation
118 | num_obs = len(obs_dicts)
119 | transformed_imgs = [[] for _ in range(num_obs)]
120 | for key in self.rgb_keys:
121 | combined_imgs = torch.cat([obs_dict[key] for obs_dict in obs_dicts], dim=0)
122 | combined_imgs = self.obs_transform(combined_imgs)
123 | chunked_imgs = combined_imgs.chunk(num_obs, dim=0)
124 | for i, img in enumerate(chunked_imgs):
125 | transformed_imgs[i].append(img)
126 |
127 | # Stack transformed images
128 | # Each image has shape (B, V, C, T, H, W)
129 | transformed_imgs = [torch.stack(imgs, dim=1) for imgs in transformed_imgs]
130 | if is_singleton:
131 | transformed_imgs = transformed_imgs[0]
132 | return transformed_imgs
133 |
134 | def encode_obs(self, obs_dict: dict):
135 | imgs = self.apply_transform(obs_dict)
136 | feats = self.img_encoder(imgs)
137 | return feats
138 |
139 | def encode_curr_and_next_obs(self, curr_obs_dict: dict, next_obs_dict: dict):
140 | # Apply the same transform to obs and next obs
141 | curr_imgs, next_imgs = self.apply_transform([curr_obs_dict, next_obs_dict])
142 | curr_feats = self.img_encoder(curr_imgs)
143 | return curr_feats, next_imgs
144 |
145 | @property
146 | def num_latents(self):
147 | return (
148 | (self.img_encoder.perceiver_resampler.num_latents + 1)
149 | * self.num_views
150 | * self.num_frames
151 | )
152 |
--------------------------------------------------------------------------------
/models/common/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as ttf
6 | from diffusers import AutoencoderKL
7 | from einops import rearrange
8 | from torchvision.transforms import (
9 | CenterCrop,
10 | ColorJitter,
11 | RandomCrop,
12 | Resize,
13 | Normalize,
14 | )
15 |
16 |
17 | class ToTensor(nn.Module):
18 | """
19 | Convert a batch of images from (B, H, W, C) to (B, C, H, W)
20 | and normalize the pixel values to the range [0, 1].
21 | """
22 |
23 | def forward(self, inputs: torch.Tensor):
24 | return inputs.permute((0, 3, 1, 2)).contiguous().float().div_(255.0)
25 |
26 |
27 | class AutoRandomCrop(nn.Module):
28 | """
29 | Perform random cropping during training and center cropping during eval.
30 | """
31 |
32 | def __init__(self, size: tuple[int, int]):
33 | super().__init__()
34 | self.size = size
35 | self.random_crop = RandomCrop(size=size)
36 |
37 | def forward(self, inputs: torch.Tensor):
38 | if self.training:
39 | return self.random_crop(inputs)
40 | else:
41 | # Take center crop during eval
42 | return ttf.center_crop(img=inputs, output_size=self.size)
43 |
44 |
45 | class AutoColorJitter(nn.Module):
46 | """
47 | Perform color jittering during training and no-op during eval.
48 | """
49 |
50 | def __init__(
51 | self,
52 | brightness: float,
53 | contrast: float,
54 | saturation: float,
55 | hue: tuple[float],
56 | ):
57 | super().__init__()
58 | self.color_jitter = ColorJitter(
59 | brightness=brightness,
60 | contrast=contrast,
61 | saturation=saturation,
62 | hue=tuple(hue),
63 | )
64 |
65 | def forward(self, inputs: torch.Tensor):
66 | if self.training:
67 | return self.color_jitter(inputs)
68 | else:
69 | return inputs # no-op during eval
70 |
71 |
72 | class VAEDownsample(nn.Module):
73 | """
74 | Downsample images using a pre-trained VAE.
75 | """
76 |
77 | def __init__(self):
78 | super().__init__()
79 | # Input normalization
80 | self.norm = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
81 | self.inv_norm = Normalize(mean=[-1, -1, -1], std=[2, 2, 2], inplace=True)
82 |
83 | # Normalization stats (computed from multitask 12)
84 | shift = torch.tensor([0.0, 0.0, 0.0, 0.0]).view(1, 4, 1, 1)
85 | scale = torch.tensor([3.0, 3.0, 3.0, 3.0]).view(1, 4, 1, 1)
86 | self.register_buffer("shift", shift)
87 | self.register_buffer("scale", scale)
88 |
89 | # Load pre-trained VAE
90 | self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
91 | for p in self.vae.parameters():
92 | p.requires_grad = False
93 | self.scaling_factor = self.vae.config.scaling_factor
94 |
95 | def forward(self, images: torch.Tensor):
96 | images = self.norm(images)
97 | feats = self.vae.encode(images).latent_dist.sample()
98 | feats = feats.mul_(self.scaling_factor)
99 | feats = feats.sub_(self.shift).div_(self.scale)
100 | return feats
101 |
102 | def inverse(self, feats: torch.Tensor):
103 | feats = feats.mul_(self.scale).add_(self.shift)
104 | feats = feats.div_(self.scaling_factor)
105 | images = self.vae.decode(feats).sample
106 | images = self.inv_norm(images)
107 | return images
108 |
109 |
110 | class ImageTransform(nn.Module):
111 | """
112 | Apply a sequence of transforms to images.
113 | """
114 |
115 | def __init__(
116 | self,
117 | resize_shape: Optional[tuple[int, int]] = None,
118 | crop_shape: Optional[tuple[int, int]] = None,
119 | random_crop: bool = True,
120 | color_jitter: Optional[dict] = None,
121 | downsample: bool = False,
122 | imagenet_norm: bool = True,
123 | ):
124 | super().__init__()
125 | transform = list()
126 |
127 | # Convert image to tensor format
128 | transform.append(ToTensor())
129 |
130 | # Resize images
131 | if resize_shape is not None:
132 | transform.append(Resize(resize_shape))
133 |
134 | # Apply random crop during training and center crop during eval
135 | if crop_shape is not None:
136 | if random_crop:
137 | transform.append(AutoRandomCrop(crop_shape))
138 | else:
139 | transform.append(CenterCrop(crop_shape))
140 |
141 | # Apply color jitter during training
142 | if color_jitter is not None:
143 | transform.append(
144 | AutoColorJitter(
145 | brightness=color_jitter["brightness"],
146 | contrast=color_jitter["contrast"],
147 | saturation=color_jitter["saturation"],
148 | hue=tuple(color_jitter["hue"]),
149 | )
150 | )
151 |
152 | # Normalize using imagenet statistics
153 | if downsample:
154 | if imagenet_norm:
155 | print("Disabling imagenet normalization since downsample is enabled.")
156 | transform.append(VAEDownsample())
157 | elif imagenet_norm:
158 | transform.append(
159 | Normalize(
160 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=True
161 | )
162 | )
163 |
164 | self.transform = nn.Sequential(*transform)
165 |
166 | @property
167 | def vae(self):
168 | assert isinstance(self.transform[-1], VAEDownsample)
169 | return self.transform[-1]
170 |
171 | def forward(self, images):
172 | return self.transform(images)
173 |
174 |
175 | class VideoTransform(ImageTransform):
176 | """
177 | Flatten videos to images, apply transforms, and reshape back to videos.
178 | """
179 |
180 | def forward(self, images):
181 | num_frames = images.shape[1]
182 |         images = rearrange(images, "b t h w c -> (b t) h w c")
183 |         images = self.transform(images)
184 |         images = rearrange(images, "(b t) c h w -> b c t h w", t=num_frames)
185 | return images
186 |
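A minimal sketch of the transforms above: `ImageTransform` maps uint8 `(B, H, W, C)` batches to float `(B, C, H', W')` tensors, cropping randomly only in training mode, while `VideoTransform` applies the same pipeline per frame to `(B, T, H, W, C)` videos (shapes below are illustrative):

```python
# Sketch only: illustrative shapes.
import torch

from models.common.transforms import ImageTransform, VideoTransform

imgs = torch.randint(0, 256, (4, 128, 128, 3), dtype=torch.uint8)  # (B, H, W, C)
tf = ImageTransform(resize_shape=(120, 120), crop_shape=(112, 112), random_crop=True)

tf.train()
assert tf(imgs).shape == (4, 3, 112, 112)  # random crop during training

tf.eval()
assert tf(imgs).shape == (4, 3, 112, 112)  # deterministic center crop at eval time

video = torch.randint(0, 256, (2, 8, 128, 128, 3), dtype=torch.uint8)  # (B, T, H, W, C)
vtf = VideoTransform(crop_shape=(112, 112), imagenet_norm=False)
assert vtf(video).shape == (2, 3, 8, 112, 112)  # (B, C, T, H, W)
```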
--------------------------------------------------------------------------------
/datasets/robomimic/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import h5py
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset
6 | from tqdm import tqdm
7 |
8 | from datasets.utils.buffer import CompressedTrajectoryBuffer
9 | from datasets.utils.file_utils import glob_all
10 | from datasets.utils.sampler import TrajectorySampler
11 | from datasets.utils.obs_utils import unflatten_obs
12 |
13 |
14 | class RobomimicDataset(Dataset):
15 | def __init__(
16 | self,
17 | name: str,
18 | hdf5_path_globs: str,
19 | buffer_path: str,
20 | shape_meta: dict,
21 | seq_len: int,
22 | val_ratio: float = 0.0,
23 | subsample_ratio: float = 1.0,
24 | flip_rgb: bool = False,
25 | ):
26 | self.name = name
27 | self.seq_len = seq_len
28 | self.flip_rgb = flip_rgb
29 |
30 | # Parse observation and action shapes
31 | obs_shape_meta = shape_meta["obs"]
32 | self._image_shapes = {}
33 | self._lowdim_shapes = {}
34 | for key, attr in obs_shape_meta.items():
35 | obs_type = attr["type"]
36 | obs_shape = tuple(attr["shape"])
37 | if obs_type == "rgb":
38 | self._image_shapes[key] = obs_shape
39 | elif obs_type == "low_dim":
40 | self._lowdim_shapes[key] = obs_shape
41 | else:
42 | raise RuntimeError(f"Unsupported obs type: {obs_type}")
43 | self._action_shape = tuple(shape_meta["action"]["shape"])
44 |
45 | # Compressed buffer to store episode data
46 | self.buffer = self._init_buffer(hdf5_path_globs, buffer_path)
47 |
48 | # Create training-validation split
49 | num_episodes = self.buffer.num_episodes
50 | val_mask = np.zeros(num_episodes, dtype=bool)
51 | if val_ratio > 0:
52 | num_val_episodes = round(val_ratio * num_episodes)
53 | num_val_episodes = min(max(num_val_episodes, 1), num_episodes - 1)
54 | rng = np.random.default_rng(seed=0)
55 | val_inds = rng.choice(num_episodes, num_val_episodes, replace=False)
56 | val_mask[val_inds] = True
57 | self.val_mask = val_mask
58 | self.train_mask = ~val_mask
59 |
60 | # Apply subsample_ratio to training episodes
61 | if subsample_ratio < 1.0:
62 | train_indices = np.where(self.train_mask)[0]
63 | num_train_episodes = len(train_indices)
64 | num_subsampled = round(num_train_episodes * subsample_ratio)
65 | num_subsampled = max(1, num_subsampled) # Ensure at least one episode
66 |
67 | # Create a new mask with only the subsampled training episodes
68 | subsampled_train_mask = np.zeros(num_episodes, dtype=bool)
69 | rng = np.random.default_rng(seed=1)
70 | sampled_indices = rng.choice(train_indices, num_subsampled, replace=False)
71 | subsampled_train_mask[sampled_indices] = True
72 | self.train_mask = subsampled_train_mask
73 |
74 | # Sampler to draw sequences from buffer
75 | self.sampler = TrajectorySampler(self.buffer, self.seq_len, self.train_mask)
76 |
77 | def _init_buffer(self, hdf5_path_globs, buffer_path):
78 | hdf5_paths = glob_all(hdf5_path_globs)
79 |
80 | # Create metadata
81 | metadata = {}
82 | for key, shape in self._image_shapes.items():
83 | metadata[f"obs.{key}"] = {"shape": shape, "dtype": np.uint8}
84 | for key, shape in self._lowdim_shapes.items():
85 | metadata[f"obs.{key}"] = {"shape": shape, "dtype": np.float32}
86 | metadata["action"] = {"shape": self._action_shape, "dtype": np.float32}
87 |
88 | # Compute buffer capacity
89 | capacity = 0
90 | num_episodes = 0
91 | for hdf5_path in hdf5_paths:
92 | with h5py.File(hdf5_path) as f:
93 | demos = f["data"]
94 | for i in range(len(demos)):
95 | demo = demos[f"demo_{i}"]
96 | capacity += demo["actions"].shape[0]
97 | num_episodes += len(demos)
98 |
99 | # Initialize buffer
100 | buffer = CompressedTrajectoryBuffer(
101 | storage_path=buffer_path,
102 | metadata=metadata,
103 | capacity=capacity,
104 | )
105 |
106 | # If buffer is restored from disk, return it
107 | if buffer.restored:
108 | return buffer
109 |
110 | # Otherwise, load episodes to buffer
111 | pbar = tqdm(total=num_episodes, desc="Loading episodes to buffer")
112 | for hdf5_path in hdf5_paths:
113 | with h5py.File(hdf5_path) as f:
114 | demos = f["data"]
115 | for i in range(len(demos)):
116 | demo = demos[f"demo_{i}"]
117 | episode = {}
118 | for key in self._image_shapes.keys():
119 | if self.flip_rgb:
120 | episode[f"obs.{key}"] = demo["obs"][key][:][:, ::-1]
121 | else:
122 | episode[f"obs.{key}"] = demo["obs"][key][:]
123 | for key in self._lowdim_shapes.keys():
124 | episode[f"obs.{key}"] = demo["obs"][key][:]
125 | episode["action"] = demo["actions"][:]
126 | buffer.add_episode(episode)
127 | pbar.update(1)
128 | pbar.close()
129 | return buffer
130 |
131 | def __len__(self) -> int:
132 | return len(self.sampler)
133 |
134 | def __repr__(self) -> str:
135 | return f"\nname: {self.name}\nnum_samples: {len(self)}\n{self.buffer}"
136 |
137 | def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
138 | # Sample a sequence of observations and actions from the dataset.
139 | data = self.sampler.sample_sequence(idx)
140 |
141 | # Convert data to torch tensors
142 | data = {k: torch.from_numpy(v) for k, v in data.items()}
143 |
144 | # Unflatten observations
145 | data = unflatten_obs(data)
146 | return data
147 |
148 | def get_validation_dataset(self):
149 | val_set = copy.copy(self)
150 | val_set.train_mask = self.val_mask
151 | val_set.sampler = TrajectorySampler(self.buffer, self.seq_len, self.val_mask)
152 | return val_set
153 |
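A sketch of constructing the dataset above with a hypothetical HDF5 path, buffer cache location, and `shape_meta` (in the repo these come from the Hydra configs under `configs/dataset/`); it assumes `unflatten_obs` nests the flat `obs.*` keys back under `obs`:

```python
# Sketch only: hypothetical paths and keys; requires a robomimic-format HDF5 on disk.
import torch

from datasets.robomimic.dataset import RobomimicDataset

shape_meta = {
    "obs": {
        "agentview_image": {"shape": [84, 84, 3], "type": "rgb"},
        "robot0_eef_pos": {"shape": [3], "type": "low_dim"},
    },
    "action": {"shape": [7]},
}
train_set = RobomimicDataset(
    name="robomimic_can_ph",
    hdf5_path_globs="/path/to/can/ph/image.hdf5",  # hypothetical
    buffer_path="/tmp/robomimic_can_ph_buffer",    # hypothetical cache location
    shape_meta=shape_meta,
    seq_len=16,
    val_ratio=0.05,
)
val_set = train_set.get_validation_dataset()

loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
batch = next(iter(loader))
# batch["obs"]["agentview_image"]: (8, 16, 84, 84, 3) uint8; batch["action"]: (8, 16, 7)
```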
--------------------------------------------------------------------------------
/models/common/adaln_attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .attention import Attention, CrossAttention, MLP
4 |
5 |
6 | def modulate(x, shift, scale):
7 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
8 |
9 |
10 | class AdaLNAttentionBlock(nn.Module):
11 | """Multiheaded self-attention block with adaptive layer normalization modulation."""
12 |
13 | def __init__(
14 | self,
15 | dim,
16 | cond_dim,
17 | num_heads=8,
18 | mlp_ratio=4.0,
19 | qkv_bias=False,
20 | drop=0.0,
21 | attn_drop=0.0,
22 | act=nn.GELU,
23 | norm=nn.LayerNorm,
24 | is_causal=False,
25 | causal_block=1,
26 | ):
27 | super().__init__()
28 | self.norm1 = norm(dim, elementwise_affine=False, eps=1e-6)
29 | self.attn = Attention(
30 | dim=dim,
31 | num_heads=num_heads,
32 | qkv_bias=qkv_bias,
33 | attn_drop=attn_drop,
34 | proj_drop=drop,
35 | is_causal=is_causal,
36 | causal_block=causal_block,
37 | )
38 | self.norm2 = norm(dim, elementwise_affine=False, eps=1e-6)
39 | self.mlp = MLP(
40 | in_dim=dim,
41 | hidden_dim=int(dim * mlp_ratio),
42 | out_dim=dim,
43 | act=act,
44 | drop=drop,
45 | )
46 | self.adaLN_modulation = nn.Sequential(
47 | nn.SiLU(),
48 | nn.Linear(cond_dim, 6 * dim),
49 | )
50 |
51 | def forward(self, x, cond, pos_embed=None, attn_mask=None):
52 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
53 | self.adaLN_modulation(cond).chunk(6, dim=1)
54 | )
55 | x = x + gate_msa.unsqueeze(1) * self.attn(
56 | modulate(self.norm1(x), shift_msa, scale_msa), pos_embed, attn_mask
57 | )
58 | x = x + gate_mlp.unsqueeze(1) * self.mlp(
59 | modulate(self.norm2(x), shift_mlp, scale_mlp)
60 | )
61 | return x
62 |
63 |
64 | class AdaLNCrossAttentionBlock(nn.Module):
65 | """Multiheaded cross-attention block with adaptive layer normalization modulation."""
66 |
67 | def __init__(
68 | self,
69 | dim,
70 | cond_dim,
71 | num_heads=8,
72 | mlp_ratio=4.0,
73 | qkv_bias=False,
74 | drop=0.0,
75 | attn_drop=0.0,
76 | act=nn.GELU,
77 | norm=nn.LayerNorm,
78 | ):
79 | super().__init__()
80 | self.norm1 = norm(dim, elementwise_affine=False, eps=1e-6)
81 | self.xattn = CrossAttention(
82 | dim=dim,
83 | num_heads=num_heads,
84 | qkv_bias=qkv_bias,
85 | attn_drop=attn_drop,
86 | proj_drop=drop,
87 | )
88 | self.norm2 = norm(dim, elementwise_affine=False, eps=1e-6)
89 | self.mlp = MLP(
90 | in_dim=dim,
91 | hidden_dim=int(dim * mlp_ratio),
92 | out_dim=dim,
93 | act=act,
94 | drop=drop,
95 | )
96 | self.adaLN_modulation = nn.Sequential(
97 | nn.SiLU(),
98 | nn.Linear(cond_dim, 6 * dim),
99 | )
100 |
101 | def forward(self, x, c, cond, x_pos_embed=None, c_pos_embed=None):
102 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
103 | self.adaLN_modulation(cond).chunk(6, dim=1)
104 | )
105 | x = x + gate_msa.unsqueeze(1) * self.xattn(
106 | modulate(x, shift_msa, scale_msa), self.norm1(c), x_pos_embed, c_pos_embed
107 | )
108 | x = x + gate_mlp.unsqueeze(1) * self.mlp(
109 | modulate(self.norm2(x), shift_mlp, scale_mlp)
110 | )
111 | return x
112 |
113 |
114 | class AdaLNHybridAttentionBlock(nn.Module):
115 | """Multiheaded hybrid attention block with adaptive layer normalization modulation."""
116 |
117 | def __init__(
118 | self,
119 | dim,
120 | cond_dim,
121 | num_heads=8,
122 | mlp_ratio=4.0,
123 | qkv_bias=False,
124 | drop=0.0,
125 | attn_drop=0.0,
126 | act=nn.GELU,
127 | norm=nn.LayerNorm,
128 | ):
129 | super().__init__()
130 | self.norm1 = norm(dim, elementwise_affine=False, eps=1e-6)
131 | self.attn = Attention(
132 | dim=dim,
133 | num_heads=num_heads,
134 | qkv_bias=qkv_bias,
135 | attn_drop=attn_drop,
136 | proj_drop=drop,
137 | )
138 | self.norm2 = norm(dim, elementwise_affine=False, eps=1e-6)
139 | self.xattn = CrossAttention(
140 | dim=dim,
141 | num_heads=num_heads,
142 | qkv_bias=qkv_bias,
143 | attn_drop=attn_drop,
144 | proj_drop=drop,
145 | )
146 | self.norm3 = norm(dim, elementwise_affine=False, eps=1e-6)
147 | self.mlp = MLP(
148 | in_dim=dim,
149 | hidden_dim=int(dim * mlp_ratio),
150 | out_dim=dim,
151 | act=act,
152 | drop=drop,
153 | )
154 | self.adaLN_modulation = nn.Sequential(
155 | nn.SiLU(),
156 | nn.Linear(cond_dim, 6 * dim),
157 | )
158 |
159 | def forward(self, x, c, cond, x_pos_embed=None, c_pos_embed=None):
160 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
161 | self.adaLN_modulation(cond).chunk(6, dim=1)
162 | )
163 | x = x + gate_msa.unsqueeze(1) * self.attn(
164 | modulate(self.norm1(x), shift_msa, scale_msa), x_pos_embed
165 | )
166 | x = x + self.xattn(self.norm2(x), c, x_pos_embed, c_pos_embed)
167 | x = x + gate_mlp.unsqueeze(1) * self.mlp(
168 | modulate(self.norm3(x), shift_mlp, scale_mlp)
169 | )
170 | return x
171 |
172 |
173 | class AdaLNFinalLayer(nn.Module):
174 | def __init__(self, dim, cond_dim):
175 | super().__init__()
176 | self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
177 | self.linear = nn.Linear(dim, dim)
178 | self.adaLN_modulation = nn.Sequential(
179 | nn.SiLU(),
180 | nn.Linear(cond_dim, 2 * dim),
181 | )
182 |
183 | def forward(self, x, cond):
184 | shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1)
185 | x = self.linear(modulate(self.norm(x), shift, scale))
186 | return x
187 |
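A quick sketch of the AdaLN blocks above: tokens have shape `(B, N, dim)` and the conditioning vector has shape `(B, cond_dim)`, which `adaLN_modulation` projects to per-channel shift/scale/gate terms. This assumes the `Attention`/`MLP` modules imported at the top of the file behave as they are called here:

```python
# Sketch only: random tensors standing in for token and conditioning embeddings.
import torch

from models.common.adaln_attention import AdaLNAttentionBlock, AdaLNFinalLayer

block = AdaLNAttentionBlock(dim=256, cond_dim=128, num_heads=8)
final = AdaLNFinalLayer(dim=256, cond_dim=128)

x = torch.randn(2, 16, 256)   # (B, N, dim) tokens
cond = torch.randn(2, 128)    # (B, cond_dim) conditioning, e.g. a timestep embedding
out = final(block(x, cond), cond)
assert out.shape == (2, 16, 256)
```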
--------------------------------------------------------------------------------
/experiments/dp/train_robomimic.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.distributed as dist
4 | import torch.multiprocessing as mp
5 | import torch.nn.functional as F
6 | import wandb
7 | from diffusers.optimization import get_scheduler
8 | from omegaconf import OmegaConf
9 | from hydra.utils import instantiate
10 | from torch.nn.parallel import DistributedDataParallel
11 | from tqdm import trange, tqdm
12 |
13 | from datasets.utils.loader import make_distributed_data_loader
14 | from environments.robomimic import make_robomimic_env
15 | from experiments.dp.train import (
16 | train_one_step,
17 | maybe_resume_checkpoint,
18 | maybe_evaluate,
19 | maybe_save_checkpoint,
20 | )
21 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
22 |
23 |
24 | def collect_rollout(config, model, device):
25 | model.eval()
26 | model = getattr(model, "module", model) # unwrap DDP
27 |
28 | # Create eval environment
29 | assert isinstance(config.dataset.hdf5_path_globs, str)
30 | env = make_robomimic_env(
31 | dataset_name=config.dataset.name,
32 | dataset_path=config.dataset.hdf5_path_globs,
33 | shape_meta=config.dataset.shape_meta,
34 | obs_horizon=model.obs_encoder.num_frames,
35 | max_episode_length=config.rollout_length,
36 | record=True,
37 | )
38 |
39 | # Collect rollouts
40 | successes = []
41 | for e in trange(
42 | config.num_rollouts, desc="Collecting rollouts", disable=not is_main_process()
43 | ):
44 | env.seed(e)
45 | obs = env.reset()
46 | done = False
47 | while not done:
48 | obs_tensor = {
49 | k: torch.tensor(v, device=device)[None] for k, v in obs.items()
50 | }
51 |
52 | # Sample action from model
53 | action = model.sample(obs_tensor)[0].cpu().numpy()
54 |
55 | # Step environment
56 | obs, reward, done, info = env.step(action)
57 | successes.append(info["success"])
58 |
59 | # Compute success rate
60 | success_rate = sum(successes) / len(successes)
61 |
62 | # Record video of the last episode
63 | video = env.get_video()
64 | return success_rate, video
65 |
66 |
67 | def maybe_collect_rollout(config, step, model, device):
68 | """Collect rollouts on the main process if it's the correct step."""
69 |     # Skip rollout collection for pretraining
70 | if "libero_90" in config.dataset.name:
71 | return
72 |
73 | if is_main_process() and (
74 | step % config.rollout_every == 0 or step == (config.num_steps - 1)
75 | ):
76 | success_rate, video = collect_rollout(config, model, device)
77 | print(f"Step: {step} success rate: {success_rate}")
78 | # Video shape: (T, H, W, C) -> (N, T, C, H, W)
79 | video = video.transpose(0, 3, 1, 2)[None]
80 | wandb.log(
81 | {
82 | "rollout/success_rate": success_rate,
83 | "rollout/video": wandb.Video(video, fps=10),
84 | }
85 | )
86 | dist.barrier()
87 |
88 |
89 | def train(rank, world_size, config):
90 | # Set global seed
91 | set_seed(config.seed * world_size + rank)
92 |
93 | # Initialize distributed training
94 | init_distributed(rank, world_size)
95 | device = torch.device(f"cuda:{rank}")
96 |
97 | # Initialize WANDB
98 | if is_main_process():
99 | init_wandb(config, job_type="train")
100 |
101 | # Create dataset
102 | train_set, val_set = instantiate(config.dataset)
103 | train_loader, val_loader = make_distributed_data_loader(
104 | train_set, val_set, config.batch_size, rank, world_size
105 | )
106 |
107 | # Create model
108 | model = instantiate(config.model).to(device)
109 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
110 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
111 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
112 |
113 | # Load pretrained model
114 | if config.pretrain_checkpoint_path:
115 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
116 | model.load_state_dict(ckpt["model"])
117 | print(
118 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
119 | )
120 |
121 | # Resume from checkpoint
122 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
123 | epoch = step // len(train_loader)
124 |
125 | # Wrap model with DDP
126 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
127 |
128 | # Training loop
129 | pbar = tqdm(
130 | total=config.num_steps,
131 | initial=step,
132 | desc="Training",
133 | disable=not is_main_process(),
134 | )
135 | while step < config.num_steps:
136 | # Set epoch for distributed sampler to shuffle indices
137 | train_loader.sampler.set_epoch(epoch)
138 |
139 | # Train for one epoch
140 | for batch in train_loader:
141 | # --- Training step ---
142 | loss, info = train_one_step(
143 | config, model, optimizer, scheduler, scaler, batch, device
144 | )
145 |
146 | # --- Logging ---
147 | if is_main_process():
148 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
149 | wandb.log({f"train/{k}": v for k, v in info.items()})
150 |
151 | # --- Evaluate if needed ---
152 | maybe_evaluate(config, step, model, val_loader, device)
153 |
154 |             # --- Collect environment rollouts if needed ---
155 | maybe_collect_rollout(config, step, model, device)
156 |
157 | # --- Save checkpoint if needed ---
158 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler)
159 |
160 | step += 1
161 | pbar.update(1)
162 | if step >= config.num_steps:
163 | break
164 |
165 | epoch += 1
166 |
167 |
168 | @hydra.main(
169 | version_base=None,
170 | config_path="../../configs",
171 | config_name="train_dp_robomimic.yaml",
172 | )
173 | def main(config):
174 | # Resolve hydra config
175 | OmegaConf.resolve(config)
176 | # Spawn processes
177 | world_size = torch.cuda.device_count()
178 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
179 |
180 |
181 | if __name__ == "__main__":
182 | main()
183 |
--------------------------------------------------------------------------------
/datasets/webvideo/dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import warnings
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from decord import VideoReader, cpu
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms.functional import resize, center_crop
11 |
12 |
13 | def compute_resize_shape(original_size: tuple[int, int], target_size: tuple[int, int]):
14 | """
15 | This method calculates the dimensions (height and width) needed to resize an image so that it
16 | minimally bounds the target dimensions defined in `target_size` while preserving the original
17 | aspect ratio.
18 |
19 | Args:
20 | original_size (tuple): The original height and width of the image.
21 | target_size (tuple): The target height and width of the image.
22 |
23 | Returns:
24 | tuple: A tuple representing the computed height and width for resizing.
25 |
26 | """
27 | h, w = original_size
28 | target_h, target_w = target_size
29 | scale = max(target_h / h, target_w / w)
30 | new_h = math.ceil(h * scale)
31 | new_w = math.ceil(w * scale)
32 | return (new_h, new_w)
33 |
34 |
35 | class MultiviewVideoDataset(Dataset):
36 | def __init__(
37 | self,
38 | index_paths: list[str],
39 | shape_meta: dict,
40 | clip_len: int,
41 | frame_skip: int = 2,
42 | obs_padding: str = "same",
43 | ):
44 | """A video dataset that augments single-view videos to multi-view videos
45 | by padding missing observations.
46 |
47 | Args:
48 | index_paths: list of paths to index files each containing paths to video files.
49 | shape_meta: dictionary containing metadata for observations and actions.
50 | clip_len: length of each clip in frames.
51 | frame_skip: number of frames to skip between frames.
52 | obs_padding: padding method for observations, chosen from ["none", "same", "random"]
53 | """
54 | self.index_paths = index_paths
55 | self.clip_len = clip_len
56 | self.frame_skip = frame_skip
57 | self.obs_padding = obs_padding
58 |
59 | self.image_shapes = {}
60 | self.lowdim_shapes = {}
61 | for key, attr in shape_meta["obs"].items():
62 | obs_type, obs_shape = attr["type"], tuple(attr["shape"])
63 | if obs_type == "rgb":
64 | self.image_shapes[key] = obs_shape
65 | elif obs_type == "low_dim":
66 | self.lowdim_shapes[key] = obs_shape
67 | else:
68 | raise RuntimeError(f"Unsupported obs type: {obs_type}")
69 | self.action_shape = shape_meta["action"]["shape"]
70 |
71 | # Check that all image observations have the same shape
72 | assert (
73 | len(set(self.image_shapes.values())) == 1
74 | ), "Image observations must have the same shape"
75 | self.image_size = next(iter(self.image_shapes.values()))[:2]
76 | self.image_keys = list(self.image_shapes.keys())
77 |
78 | self.samples = []
79 | for index_path in index_paths:
80 | # Each line in a data file is a path to a video clip
81 | with open(index_path, "r") as f:
82 | video_paths = f.read().splitlines()
83 | self.samples.extend(video_paths)
84 |
85 | def __len__(self):
86 | return len(self.samples)
87 |
88 | def __getitem__(self, index):
89 | # Keep trying to load video until successful
90 | clip = None
91 | while clip is None:
92 | clip = self.load_video(self.samples[index])
93 | if clip is None:
94 | index = np.random.randint(self.__len__())
95 |
96 | # Construct image observations based on padding method
97 | obs_dict = {}
98 | if self.obs_padding == "none":
99 | main_key = np.random.choice(self.image_keys)
100 | obs_dict[main_key] = clip
101 | for key in [k for k in self.image_keys if k != main_key]:
102 | obs_dict[key] = torch.zeros_like(clip)
103 | elif self.obs_padding == "same":
104 | for key in self.image_keys:
105 | obs_dict[key] = clip
106 | elif self.obs_padding == "random":
107 | main_key = np.random.choice(self.image_keys)
108 | obs_dict[main_key] = clip
109 | for key in [k for k in self.image_keys if k != main_key]:
110 | rand_clip = None
111 | while rand_clip is None:
112 | rand_index = np.random.randint(self.__len__())
113 | rand_clip = self.load_video(self.samples[rand_index])
114 | obs_dict[key] = rand_clip
115 | else:
116 | raise ValueError(f"Invalid padding method {self.obs_padding}")
117 |
118 | # Zero pad lowdim observations
119 | for key in self.lowdim_shapes.keys():
120 | obs_dict[key] = torch.zeros(self.clip_len, *self.lowdim_shapes[key])
121 |
122 | # Construct sample
123 | sample = {
124 | "obs": obs_dict,
125 | "action": torch.zeros(self.clip_len, *self.action_shape),
126 | "action_mask": torch.tensor(0, dtype=torch.bool),
127 | }
128 | return sample
129 |
130 | def load_video(self, fname: str):
131 | if not os.path.exists(fname):
132 | warnings.warn(f"video path not found {fname}")
133 | return None
134 |
135 | # Skip short or long videos
136 | fsize = os.path.getsize(fname)
137 | if fsize < 1 * 1024 or fsize > int(10**9):
138 | warnings.warn(f"video size {fsize} out of bounds {fname}")
139 | return None
140 |
141 | # Try loading video
142 | try:
143 | vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
144 | except Exception:
145 | return None
146 |
147 | # Compute full clip length
148 | full_clip_len = int(self.clip_len * self.frame_skip)
149 |
150 | # Filter videos shorter than a single clip
151 | if len(vr) < full_clip_len:
152 | warnings.warn(f"video length {len(vr)} shorter than a single clip {fname}")
153 | return None
154 |
155 | # Sample random clip from video
156 | start_indx = np.random.randint(0, len(vr) - full_clip_len + 1)
157 | end_indx = start_indx + full_clip_len
158 | indices = np.linspace(start_indx, end_indx, self.clip_len)
159 | indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64)
160 |
161 | # Load clip
162 | vr.seek(0)
163 | clip = vr.get_batch(indices).asnumpy()
164 |
165 | # Postprocess video
166 | clip = torch.from_numpy(clip)
167 | clip = clip.permute(0, 3, 1, 2) # (T, C, H, W)
168 | clip = resize(clip, compute_resize_shape(clip.shape[2:], self.image_size))
169 | clip = center_crop(clip, self.image_size)
170 | return clip.permute(0, 2, 3, 1) # (T, H, W, C)
171 |
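A worked example of `compute_resize_shape` above: the scale is the larger of the two height/width ratios, so the resized frame always covers the target crop while preserving aspect ratio.

```python
# Resizing a 480x640 frame to bound a 224x224 crop:
# scale = max(224/480, 224/640) = 224/480, giving (224, ceil(640 * 224/480)) = (224, 299).
from datasets.webvideo.dataset import compute_resize_shape

assert compute_resize_shape((480, 640), (224, 224)) == (224, 299)
```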
--------------------------------------------------------------------------------
/experiments/uwm/eval_droid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import deque
3 |
4 | import hydra
5 | import torch
6 | import numpy as np
7 |
8 | from droid.controllers.oculus_controller import VRPolicy
9 | from droid.data_processing.timestep_processing import TimestepProcesser
10 | from droid.robot_env import RobotEnv
11 | from droid.user_interface.data_collector import DataCollecter
12 | from droid.user_interface.gui import RobotGUI
13 |
14 | from datasets.droid.utils import rot_6d_to_euler_angles
15 |
16 |
17 | class DROIDAgent:
18 | def __init__(
19 | self,
20 | model,
21 | device,
22 | img_keys,
23 | action_normalizer,
24 | lowdim_normalizer,
25 | obs_horizon=2,
26 | action_horizon=8,
27 | ):
28 | self.model = model
29 | self.device = device
30 | self.img_keys = img_keys
31 | self.action_normalizer = action_normalizer
32 | self.lowdim_normalizer = lowdim_normalizer
33 | self.obs_horizon = obs_horizon
34 | self.action_horizon = action_horizon
35 |
36 | assert obs_horizon == model.obs_encoder.num_frames
37 | assert action_horizon <= model.action_len
38 |
39 | self.action_scale = action_normalizer.scale[None]
40 | self.action_offset = action_normalizer.offset[None]
41 |
42 | # Observation buffer
43 | self.obs_buffer = deque(maxlen=obs_horizon)
44 | self.act_buffer = deque(maxlen=action_horizon)
45 |
46 | # Timestep processor
47 | self.timestep_processor = TimestepProcesser(
48 | ignore_action=True,
49 | action_space="cartesian_position",
50 | gripper_action_space="position",
51 | robot_state_keys=[
52 | "cartesian_position",
53 | "gripper_position",
54 | "joint_positions",
55 | ],
56 | image_transform_kwargs=dict(
57 | remove_alpha=True,
58 | bgr_to_rgb=True,
59 | to_tensor=False,
60 | augment=False,
61 | ),
62 | )
63 |
64 | def _convert_obs(self, observation):
65 | # Process camera images
66 | timestep = {"observation": observation}
67 | processed_timestep = self.timestep_processor.forward(timestep)
68 | camera_images = processed_timestep["observation"]["camera"]["image"]
69 |
70 | # Extract observations
71 | obs = {
72 | "exterior_image_1_left": camera_images["varied_camera"][0],
73 | "exterior_image_2_left": camera_images["varied_camera"][2],
74 | "wrist_image_left": camera_images["hand_camera"][0],
75 | "cartesian_position": observation["robot_state"]["cartesian_position"],
76 | "gripper_position": np.array(
77 | [observation["robot_state"]["gripper_position"]]
78 | ),
79 | }
80 |
81 | # Convert image observations to torch tensors
82 | for key in self.img_keys:
83 | obs[key] = torch.from_numpy(obs[key][None]).to(self.device)
84 |
85 | # Normalize low-dimensional observations and convert to torch tensors
86 | for key in self.lowdim_normalizer.keys():
87 | lowdim_obs = self.lowdim_normalizer[key](obs[key])
88 | obs[key] = torch.from_numpy(lowdim_obs[None]).float().to(self.device)
89 |
90 | return obs
91 |
92 | def _convert_action(self, action):
93 | xyz, rot6d, grippers = action[:3], action[3:9], action[9:]
94 | euler = rot_6d_to_euler_angles(torch.tensor(rot6d)).numpy()
95 | return np.concatenate([xyz, euler, grippers])
96 |
97 | @torch.no_grad()
98 | def forward(self, observation):
99 | # Encode observations
100 | obs = self._convert_obs(observation)
101 |
102 | # Update observation buffer
103 | if len(self.obs_buffer) == 0:
104 | # Pad observation buffer if empty (only after reset)
105 | for _ in range(self.obs_horizon):
106 | self.obs_buffer.append(obs)
107 | else:
108 | self.obs_buffer.append(obs)
109 |
110 | # Update action buffer
111 | if len(self.act_buffer) == 0:
112 | # Stack observations by key
113 | obs_seq = {}
114 | for key in obs.keys():
115 | obs_seq[key] = torch.stack([obs[key] for obs in self.obs_buffer], dim=1)
116 |
117 | # Sample actions
118 | act_seq = self.model.sample(obs_seq)
119 | act_seq = act_seq[0].cpu().numpy()
120 | act_seq = act_seq * self.action_scale + self.action_offset
121 |
122 | # Store new actions in buffer
123 | for t in range(self.action_horizon):
124 | self.act_buffer.append(self._convert_action(act_seq[t]))
125 |
126 | # Return next action
127 | action = self.act_buffer.popleft()
128 | # Clip action
129 | action = np.clip(action, -1, 1)
130 | return action
131 |
132 | def reset(self):
133 | self.obs_buffer.clear()
134 | self.act_buffer.clear()
135 |
136 |
137 | @hydra.main(
138 | version_base=None, config_path="../../configs", config_name="train_uwm.yaml"
139 | )
140 | def main(config):
141 |     device = torch.device("cuda:0")
142 |
143 | # Create model
144 | model = hydra.utils.instantiate(config.model).to(device)
145 | model.eval()
146 |
147 | # Load models
148 | ckpt = torch.load(os.path.join(config.logdir, "models.pt"), map_location="cpu")
149 | model.load_state_dict(ckpt["model"])
150 | print(f"Loaded models from pretraining checkpoint, step: {ckpt['step']}")
151 |
152 | # Create agent
153 | img_keys = [
154 | k for k, v in config.dataset.shape_meta["obs"].items() if v["type"] == "rgb"
155 | ]
156 | agent = DROIDAgent(
157 | model=model,
158 | device=device,
159 | img_keys=img_keys,
160 | action_normalizer=ckpt["action_normalizer"],
161 | lowdim_normalizer=ckpt["lowdim_normalizer"],
162 | obs_horizon=config.model.obs_encoder.num_frames,
163 | action_horizon=config.model.action_len // 2,
164 | )
165 |
166 | # Create evaluation environment
167 | h, w = tuple(config.dataset.shape_meta["obs"][img_keys[0]]["shape"][:2])
168 | img_size = (w, h) # flip width and height
169 | env = RobotEnv(
170 | action_space="cartesian_velocity",
171 | gripper_action_space="position",
172 | camera_kwargs=dict(
173 | hand_camera=dict(
174 | image=True,
175 | concatenate_images=False,
176 | resolution=img_size,
177 | resize_func="cv2",
178 | ),
179 | varied_camera=dict(
180 | image=True,
181 | concatenate_images=False,
182 | resolution=img_size,
183 | resize_func="cv2",
184 | ),
185 | ),
186 | )
187 | controller = VRPolicy()
188 |
189 | # Launch GUI
190 | data_collector = DataCollecter(
191 | env=env,
192 | controller=controller,
193 | policy=agent,
194 | save_traj_dir=os.path.join(config.logdir, "videos"),
195 | save_data=True,
196 | )
197 | RobotGUI(robot=data_collector)
198 |
199 |
200 | if __name__ == "__main__":
201 | main()
202 |
--------------------------------------------------------------------------------
/experiments/uwm/train_robomimic.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.distributed as dist
4 | import torch.multiprocessing as mp
5 | import wandb
6 | from diffusers.optimization import get_scheduler
7 | from hydra.utils import instantiate
8 | from omegaconf import OmegaConf
9 | from torch.nn.parallel import DistributedDataParallel
10 | from tqdm import trange, tqdm
11 |
12 | from datasets.utils.loader import make_distributed_data_loader
13 | from environments.robomimic import make_robomimic_env
14 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
15 | from experiments.uwm.train import (
16 | train_one_step,
17 | maybe_resume_checkpoint,
18 | maybe_evaluate,
19 | maybe_save_checkpoint,
20 | )
21 |
22 |
23 | def collect_rollout(config, model, device):
24 | model.eval()
25 | model = getattr(model, "module", model) # unwrap DDP
26 |
27 | # Create eval environment
28 | assert isinstance(config.dataset.hdf5_path_globs, str)
29 | env = make_robomimic_env(
30 | dataset_name=config.dataset.name,
31 | dataset_path=config.dataset.hdf5_path_globs,
32 | shape_meta=config.dataset.shape_meta,
33 | obs_horizon=model.obs_encoder.num_frames,
34 | max_episode_length=config.rollout_length,
35 | record=True,
36 | )
37 |
38 | # Collect rollouts
39 | successes = []
40 | for e in trange(
41 | config.num_rollouts, desc="Collecting rollouts", disable=not is_main_process()
42 | ):
43 | env.seed(e)
44 | obs = env.reset()
45 | done = False
46 | while not done:
47 | obs_tensor = {
48 | k: torch.tensor(v, device=device)[None] for k, v in obs.items()
49 | }
50 |
51 | # Sample action from model
52 | action = model.sample(obs_tensor)[0].cpu().numpy()
53 |
54 | # Step environment
55 | obs, reward, done, info = env.step(action)
56 | successes.append(info["success"])
57 |
58 | # Compute success rate
59 | success_rate = sum(successes) / len(successes)
60 |
61 | # Record video of the last episode
62 | video = env.get_video()
63 | return success_rate, video
64 |
65 |
66 | def maybe_collect_rollout(config, step, model, device):
67 | """Collect rollouts on the main process if it's the correct step."""
68 |     # Skip rollout collection for pretraining
69 | if "libero_90" in config.dataset.name:
70 | return
71 |
72 | if is_main_process() and (
73 | step % config.rollout_every == 0 or step == (config.num_steps - 1)
74 | ):
75 | success_rate, video = collect_rollout(config, model, device)
76 | print(f"Step: {step} success rate: {success_rate}")
77 | # Video shape: (T, H, W, C) -> (N, T, C, H, W)
78 | video = video.transpose(0, 3, 1, 2)[None]
79 | wandb.log(
80 | {
81 | "rollout/success_rate": success_rate,
82 | "rollout/video": wandb.Video(video, fps=10),
83 | }
84 | )
85 | dist.barrier()
86 |
87 |
88 | def train(rank, world_size, config):
89 | # Set global seed
90 | set_seed(config.seed * world_size + rank)
91 |
92 | # Initialize distributed training
93 | init_distributed(rank, world_size)
94 | device = torch.device(f"cuda:{rank}")
95 |
96 | # Initialize WANDB
97 | if is_main_process():
98 | init_wandb(config, job_type="train")
99 |
100 | # Create dataset and loader
101 | train_set, val_set = instantiate(config.dataset)
102 | train_loader, val_loader = make_distributed_data_loader(
103 | train_set, val_set, config.batch_size, rank, world_size
104 | )
105 |
106 | # Create model
107 | model = instantiate(config.model).to(device)
108 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
109 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
110 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
111 |
112 | # Load pretrained model
113 | if config.pretrain_checkpoint_path:
114 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
115 | model.load_state_dict(ckpt["model"])
116 | print(
117 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
118 | )
119 |
120 | # Use a smaller learning rate for the image encoder
121 | encoder_params, other_params = [], []
122 | for name, param in model.named_parameters():
123 | if not param.requires_grad:
124 | continue
125 | if name.startswith("obs_encoder.img_encoder"):
126 | encoder_params.append(param)
127 | else:
128 | other_params.append(param)
129 |
130 | # Define learning rates
131 | base_lr = config.optimizer.lr
132 | encoder_lr = base_lr * 0.05
133 | wd = config.optimizer.weight_decay
134 |
135 | # Construct optimizer with custom parameter groups
136 | optimizer = torch.optim.AdamW(
137 | [
138 | {"params": encoder_params, "lr": encoder_lr, "weight_decay": wd},
139 | {"params": other_params, "lr": base_lr, "weight_decay": wd},
140 | ],
141 | betas=config.optimizer.betas,
142 | eps=config.optimizer.eps,
143 | )
144 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
145 |
146 | # Resume from checkpoint
147 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
148 | epoch = step // len(train_loader)
149 |
150 | # Wrap model with DDP
151 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
152 |
153 | # Training loop
154 | pbar = tqdm(
155 | total=config.num_steps,
156 | initial=step,
157 | desc="Training",
158 | disable=not is_main_process(),
159 | )
160 | while step < config.num_steps:
161 | # Set epoch for distributed sampler to shuffle indices
162 | train_loader.sampler.set_epoch(epoch)
163 |
164 | # Train for one epoch
165 | for batch in train_loader:
166 | # --- Training step ---
167 | loss, info = train_one_step(
168 | config, model, optimizer, scheduler, scaler, batch, device
169 | )
170 |
171 | # --- Logging ---
172 | if is_main_process():
173 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
174 | wandb.log({f"train/{k}": v for k, v in info.items()})
175 |
176 | # --- Evaluate if needed ---
177 | maybe_evaluate(config, step, model, val_loader, device)
178 |
179 |             # --- Collect environment rollouts if needed ---
180 | maybe_collect_rollout(config, step, model, device)
181 |
182 | # --- Save checkpoint if needed ---
183 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler)
184 |
185 | step += 1
186 | pbar.update(1)
187 | if step >= config.num_steps:
188 | break
189 |
190 | epoch += 1
191 |
192 |
193 | @hydra.main(
194 | version_base=None,
195 | config_path="../../configs",
196 | config_name="train_uwm_robomimic.yaml",
197 | )
198 | def main(config):
199 | # Resolve hydra config
200 | OmegaConf.resolve(config)
201 | # Spawn processes
202 | world_size = torch.cuda.device_count()
203 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
204 |
205 |
206 | if __name__ == "__main__":
207 | main()
208 |
--------------------------------------------------------------------------------
/experiments/gr1/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.distributed as dist
4 | import torch.multiprocessing as mp
5 | import torch.nn.functional as F
6 | import wandb
7 | from diffusers.optimization import get_scheduler
8 | from einops import rearrange
9 | from hydra.utils import instantiate
10 | from omegaconf import OmegaConf
11 | from torch.nn.parallel import DistributedDataParallel
12 | from torchvision.utils import make_grid
13 | from tqdm import tqdm
14 |
15 | from datasets.utils.loader import make_distributed_data_loader
16 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
17 | from experiments.uwm.train import (
18 | process_batch,
19 | train_one_step,
20 | maybe_resume_checkpoint,
21 | maybe_save_checkpoint,
22 | )
23 |
24 |
25 | def eval_one_epoch(config, data_loader, device, model, action_normalizer=None):
26 | model.eval()
27 | model = getattr(model, "module", model) # unwrap DDP
28 |
29 | # Unnormalize actions
30 | if action_normalizer is not None:
31 | action_scale = torch.tensor(action_normalizer.scale[None], device=device)
32 | action_offset = torch.tensor(action_normalizer.offset[None], device=device)
33 | unnormalize = lambda a: a * action_scale + action_offset
34 | else:
35 | unnormalize = lambda a: a
36 |
37 | def plot(images, nrows=4):
38 | images = rearrange(images, "b v c t h w -> (b v t) c h w")
39 | images_grid = make_grid(images, nrows)
40 | return images_grid
41 |
42 | stats = {
43 | "loss": 0,
44 | "action_loss": 0,
45 | "dynamics_loss": 0,
46 | "action_mse_joint": 0,
47 | "image_mse_joint": 0,
48 | }
49 | for batch in tqdm(data_loader, desc="Evaluating", disable=not is_main_process()):
50 | # ------------ Preprocess data ------------ #
51 | curr_obs_dict, next_obs_dict, action_norm = process_batch(
52 | batch, config.model.obs_encoder.num_frames, config.model.action_len, device
53 | )
54 |
55 | with torch.no_grad():
56 | # ------------ Validation loss ------------ #
57 | _, info = model(curr_obs_dict, next_obs_dict, action_norm)
58 | for k, v in info.items():
59 | stats[k] += v
60 |
61 |             # ------------ GR1 Inference ------------ #
62 | action = unnormalize(action_norm)
63 |
64 | # Encode next observations
65 | next_obs = model.obs_encoder.apply_transform(next_obs_dict)
66 |
67 | # Sample next obs and actions from joint distribution
68 | next_obs_hat_joint, action_hat_joint = model.sample_joint(curr_obs_dict)
69 | joint_image_mse = F.mse_loss(next_obs_hat_joint, next_obs)
70 | dist.all_reduce(joint_image_mse, op=dist.ReduceOp.AVG)
71 | stats["image_mse_joint"] += joint_image_mse
72 | joint_action_mse = F.mse_loss(unnormalize(action_hat_joint), action)
73 | dist.all_reduce(joint_action_mse, op=dist.ReduceOp.AVG)
74 | stats["action_mse_joint"] += joint_action_mse
75 |
76 | # Average over all batches
77 | stats = {k: v / len(data_loader) for k, v in stats.items()}
78 |
79 | # Plot reconstruction
80 | stats["images"] = wandb.Image(plot(next_obs[0:1]))
81 | stats["images_joint"] = wandb.Image(plot(next_obs_hat_joint[0:1]))
82 | return stats
83 |
84 |
85 | def maybe_evaluate(config, step, model, loader, device, action_normalizer=None):
86 | """Evaluate if it's the correct step."""
87 | if step % config.eval_every == 0:
88 | stats = eval_one_epoch(config, loader, device, model, action_normalizer)
89 | if is_main_process():
90 | wandb.log({f"eval/{k}": v for k, v in stats.items()})
91 | print(f"Step {step} action mse: {stats['action_mse_joint']:.4f}")
92 |
93 |
94 | def train(rank, world_size, config):
95 | # Set global seed
96 | set_seed(config.seed * world_size + rank)
97 |
98 | # Initialize distributed training
99 | init_distributed(rank, world_size)
100 | device = torch.device(f"cuda:{rank}")
101 |
102 | # Initialize WANDB
103 | if is_main_process():
104 | init_wandb(config, job_type="train")
105 |
106 | # Create dataset
107 | train_set, val_set = instantiate(config.dataset)
108 | train_loader, val_loader = make_distributed_data_loader(
109 | train_set, val_set, config.batch_size, rank, world_size
110 | )
111 |
112 | # Create model
113 | model = instantiate(config.model).to(device)
114 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
115 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
116 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
117 |
118 | # Load pretrained model
119 | if config.pretrain_checkpoint_path:
120 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
121 | model.load_state_dict(ckpt["model"])
122 | print(
123 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
124 | )
125 |
126 | # Replace dataset normalizers to make sure data is normalized correctly
127 | if ckpt["action_normalizer"] is not None:
128 | train_set.action_normalizer = ckpt["action_normalizer"]
129 | val_set.action_normalizer = ckpt["action_normalizer"]
130 | if ckpt["lowdim_normalizer"] is not None:
131 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
132 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
133 |
134 | # Resume from checkpoint
135 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
136 | epoch = step // len(train_loader)
137 |
138 | # Wrap model with DDP
139 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
140 |
141 | # Training loop
142 | pbar = tqdm(
143 | total=config.num_steps,
144 | initial=step,
145 | desc="Training",
146 | disable=not is_main_process(),
147 | )
148 | while step < config.num_steps:
149 | # Set epoch for distributed sampler to shuffle indices
150 | train_loader.sampler.set_epoch(epoch)
151 |
152 | for batch in train_loader:
153 | # --- Training step ---
154 | loss, info = train_one_step(
155 | config, model, optimizer, scheduler, scaler, batch, device
156 | )
157 |
158 | # --- Logging ---
159 | if is_main_process():
160 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
161 | wandb.log({f"train/{k}": v for k, v in info.items()})
162 |
163 | # --- Evaluate if needed ---
164 | maybe_evaluate(
165 | config, step, model, val_loader, device, train_set.action_normalizer
166 | )
167 |
168 | # --- Save checkpoint if needed ---
169 | maybe_save_checkpoint(
170 | config,
171 | step,
172 | model,
173 | optimizer,
174 | scheduler,
175 | scaler,
176 | train_set.action_normalizer,
177 | train_set.lowdim_normalizer,
178 | )
179 |
180 | step += 1
181 | pbar.update(1)
182 | if step >= config.num_steps:
183 | break
184 |
185 | epoch += 1
186 |
187 |
188 | @hydra.main(
189 | version_base=None, config_path="../../configs", config_name="train_gr1.yaml"
190 | )
191 | def main(config):
192 | # Resolve hydra config
193 | OmegaConf.resolve(config)
194 | # Spawn processes
195 | world_size = torch.cuda.device_count()
196 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
197 |
198 |
199 | if __name__ == "__main__":
200 | main()
201 |
--------------------------------------------------------------------------------
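Editor's note: `eval_one_epoch` above maps predicted actions back to their original scale with the normalizer's `scale` and `offset` before computing MSE. A tiny sketch of that linear round trip, assuming the forward normalization is `(a - offset) / scale` (the inverse map is exactly the `unnormalize` lambda in the eval code); the statistics below are made up:

import torch

# Illustrative per-dimension statistics; the repo's action normalizer exposes
# `scale` and `offset` arrays playing this role.
scale = torch.tensor([2.0, 0.5])
offset = torch.tensor([1.0, -1.0])

action = torch.tensor([[3.0, 0.0], [-1.0, 2.0]])
action_norm = (action - offset) / scale      # assumed forward normalization
action_back = action_norm * scale + offset   # matches unnormalize() above
assert torch.allclose(action, action_back)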
/datasets/droid/convert_dataset_zarr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing as mp
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 | import tensorflow_datasets as tfds
10 |
11 | from datasets.utils.buffer import CompressedTrajectoryBuffer
12 | from datasets.droid.utils import euler_angles_to_rot_6d
13 |
14 | # Disable GPU otherwise jax allocates lots of memory
15 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
16 |
17 |
18 | # Shape metadata for zarr dataset
19 | shape_meta = {
20 | "obs": {
21 | "exterior_image_1_left": {
22 | "shape": (180, 320, 3),
23 | "type": "rgb",
24 | },
25 | "exterior_image_2_left": {
26 | "shape": (180, 320, 3),
27 | "type": "rgb",
28 | },
29 | "wrist_image_left": {
30 | "shape": (180, 320, 3),
31 | "type": "rgb",
32 | },
33 | "cartesian_position": {
34 | "shape": (6,),
35 | "type": "low_dim",
36 | },
37 | "gripper_position": {
38 | "shape": (1,),
39 | "type": "low_dim",
40 | },
41 | },
42 | "action": {
43 | "shape": (10,),
44 | },
45 | }
46 |
47 | rgb_keys = [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"]
48 | lowdim_keys = [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"]
49 |
50 |
51 | class TruncatedDataset:
52 | def __init__(self, dataset, num_episodes, filter_key=None, except_key=None):
53 |         # Yield up to num_episodes episodes whose recording folderpath
54 |         # contains filter_key and does not contain except_key
55 | self.dataset = dataset
56 | self.num_episodes = num_episodes
57 | if filter_key and except_key:
58 | assert (
59 | filter_key != except_key
60 | ), "filter_key and except_key cannot be the same"
61 | self.filter_key = filter_key
62 | self.except_key = except_key
63 |
64 | def __len__(self):
65 | return self.num_episodes
66 |
67 | def __iter__(self):
68 | episode_count = 0
69 | for episode in self.dataset:
70 | folderpath = (
71 | episode["episode_metadata"]["recording_folderpath"]
72 | .numpy()
73 | .decode("UTF-8")
74 | )
75 | if self.filter_key and self.filter_key not in folderpath:
76 | continue
77 | if self.except_key and self.except_key in folderpath:
78 | continue
79 | yield episode
80 | episode_count += 1
81 | if episode_count >= self.num_episodes:
82 | break
83 |
84 |
85 | def _load_episode(episode):
86 | obs_dict = {k: [] for k in shape_meta["obs"].keys()}
87 | actions = []
88 | for step in episode:
89 | obs = step["observation"]
90 | # Load image observations
91 | for k in rgb_keys:
92 | obs_dict[k].append(obs[k].numpy())
93 |         # Load low-dimensional observations
94 | obs_dict["cartesian_position"].append(obs["cartesian_position"].numpy())
95 | obs_dict["gripper_position"].append(obs["gripper_position"].numpy())
96 |
97 | # Convert and load actions
98 | action_dict = step["action_dict"]
99 | cartesian_velocity = action_dict["cartesian_velocity"].numpy()
100 | xyz, euler = cartesian_velocity[:3], cartesian_velocity[3:6]
101 | rot6d = euler_angles_to_rot_6d(torch.tensor(euler)).numpy()
102 | grippers = action_dict["gripper_position"].numpy()
103 | action = np.concatenate([xyz, rot6d, grippers])
104 | actions.append(action)
105 |
106 | episode_dict = {"obs." + k: np.stack(v) for k, v in obs_dict.items()}
107 | episode_dict["action"] = np.stack(actions)
108 | return episode_dict
109 |
110 |
111 | def preprocess_episode(episode):
112 | """Preprocess episode by removing variant tensors and standardizing data types."""
113 | processed_episode = []
114 | for step in episode["steps"]:
115 | processed_step = {
116 | "observation": step["observation"],
117 | "action_dict": step["action_dict"],
118 | }
119 | processed_episode.append(processed_step)
120 | return processed_episode
121 |
122 |
123 | def episode_loader(queue, buffer_args):
124 |     """Consume episodes from the queue and write them into the shared trajectory buffer."""
125 | pid = mp.current_process().name
126 | print(f"Starting episode loader process {pid}")
127 |
128 | # Initialize buffer
129 | buffer = CompressedTrajectoryBuffer(**buffer_args)
130 |
131 | while True:
132 | episode = queue.get()
133 | if episode is None:
134 | print(f"Episode loader {pid} received a termination signal")
135 | break
136 | else:
137 | print(f"Episode loader {pid} received an episode")
138 | buffer.add_episode(_load_episode(episode))
139 |
140 |
141 | def main(args):
142 | # Load dataset
143 |     raw_dataset = tfds.load(args.data_name, data_dir=args.data_dir, split="train")
144 | dataset = TruncatedDataset(
145 | raw_dataset, args.num_episodes, args.filter_key, args.except_key
146 | )
147 |
148 | # Create metadata
149 | metadata = {}
150 | for key, meta in shape_meta["obs"].items():
151 | metadata[f"obs.{key}"] = {
152 | "shape": meta["shape"],
153 | "dtype": np.uint8 if meta["type"] == "rgb" else np.float32,
154 | }
155 | metadata["action"] = {"shape": shape_meta["action"]["shape"], "dtype": np.float32}
156 |
157 | # Compute buffer capacity
158 | capacity = sum([len(episode["steps"]) for episode in dataset])
159 | print(f"Buffer capacity: {capacity}")
160 |
161 | # Create temporary buffer to check if buffer is restored
162 | buffer = CompressedTrajectoryBuffer(
163 | storage_path=args.buffer_path,
164 | metadata=metadata,
165 | capacity=capacity,
166 | )
167 | if buffer.restored:
168 | print("Buffer restored from disk")
169 | return
170 |
171 | # Multiprocessing setup
172 | context = mp.get_context("spawn")
173 | queue = context.Queue(maxsize=args.num_workers * 2)
174 | lock = context.Lock() # share lock across processes
175 | buffer_args = {
176 | "storage_path": args.buffer_path,
177 | "metadata": metadata,
178 | "capacity": capacity,
179 | "lock": lock,
180 | }
181 |
182 | # Start episode loader processes
183 | episode_loader_processes = []
184 | for i in range(args.num_workers):
185 | p = context.Process(
186 | target=episode_loader,
187 | args=(queue, buffer_args),
188 | name=f"EpisodeLoaderProcess-{i}",
189 | )
190 | p.start()
191 | episode_loader_processes.append(p)
192 |
193 | # Preprocess episodes on main process
194 | for episode in tqdm(dataset, desc="Preprocessing"):
195 | processed_episode = preprocess_episode(episode)
196 | queue.put(processed_episode)
197 |
198 | # Send termination signals to loaders
199 | for _ in range(args.num_workers):
200 | queue.put(None)
201 |
202 | # Wait for all processes to complete
203 | for p in episode_loader_processes:
204 | p.join()
205 |
206 |
207 | if __name__ == "__main__":
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument("--data_name", type=str, default="droid")
210 | parser.add_argument(
211 | "--data_dir", type=str, default="/gscratch/weirdlab/memmelma/data/"
212 | )
213 | parser.add_argument(
214 | "--buffer_path",
215 | type=str,
216 | default="/gscratch/weirdlab/zchuning/data/droid/buffer.zarr",
217 | )
218 | parser.add_argument("--num_episodes", type=int, default=500)
219 | parser.add_argument("--num_workers", type=int, default=8)
220 | parser.add_argument("--filter_key", type=str, default=None)
221 | parser.add_argument("--except_key", type=str, default=None)
222 | args = parser.parse_args()
223 | main(args)
224 |
--------------------------------------------------------------------------------
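Editor's note: once the conversion script finishes, the output is a zarr directory store with a `data` group (one array per `obs.*` key plus `action`) and a `meta/episode_ends` array, as laid out by `CompressedTrajectoryBuffer` further below. A quick sanity check with plain zarr, assuming the v2-style DirectoryStore that buffer.py creates; the path mirrors the script's `--buffer_path` default and will differ on other machines:

import zarr

# Adjust to wherever --buffer_path pointed.
root = zarr.open("/gscratch/weirdlab/zchuning/data/droid/buffer.zarr", mode="r")

# `data` holds one array per key, `meta` holds the episode boundaries.
for name, arr in root["data"].arrays():
    print(name, arr.shape, arr.dtype, arr.chunks)
print("episode_ends:", root["meta"]["episode_ends"][:])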
/experiments/pad/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.distributed as dist
4 | import torch.multiprocessing as mp
5 | import torch.nn.functional as F
6 | import wandb
7 | from diffusers.optimization import get_scheduler
8 | from einops import rearrange
9 | from hydra.utils import instantiate
10 | from omegaconf import OmegaConf
11 | from torch.nn.parallel import DistributedDataParallel
12 | from torchvision.utils import make_grid
13 | from tqdm import tqdm
14 |
15 | from datasets.utils.loader import make_distributed_data_loader
16 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process
17 | from experiments.uwm.train import (
18 | process_batch,
19 | train_one_step,
20 | maybe_resume_checkpoint,
21 | maybe_save_checkpoint,
22 | )
23 |
24 |
25 | def eval_one_epoch(config, data_loader, device, model, action_normalizer=None):
26 | model.eval()
27 | model = getattr(model, "module", model) # unwrap DDP
28 |
29 | # Unnormalize actions
30 | if action_normalizer is not None:
31 | action_scale = torch.tensor(action_normalizer.scale[None], device=device)
32 | action_offset = torch.tensor(action_normalizer.offset[None], device=device)
33 | unnormalize = lambda a: a * action_scale + action_offset
34 | else:
35 | unnormalize = lambda a: a
36 |
37 | def decode_and_plot(latents, nrows=4):
38 | with torch.no_grad():
39 | images = model.obs_encoder.apply_vae(latents, inverse=True)
40 | images = rearrange(images, "b v c t h w -> (b v t) c h w")
41 | images_grid = make_grid(images, nrows)
42 | return images_grid
43 |
44 | stats = {
45 | "loss": 0,
46 | "action_loss": 0,
47 | "dynamics_loss": 0,
48 | "action_mse_joint": 0,
49 | "image_mse_joint": 0,
50 | }
51 | for batch in tqdm(data_loader, desc="Evaluating", disable=not is_main_process()):
52 | # ------------ Preprocess data ------------ #
53 | curr_obs_dict, next_obs_dict, action_norm = process_batch(
54 | batch, config.model.obs_encoder.num_frames, config.model.action_len, device
55 | )
56 |
57 | with torch.no_grad():
58 | # ------------ Validation loss ------------ #
59 | _, info = model(curr_obs_dict, next_obs_dict, action_norm)
60 | for k, v in info.items():
61 | stats[k] += v
62 |
63 |             # ------------ PAD Inference ------------ #
64 | action = unnormalize(action_norm)
65 |
66 | # Encode next observations
67 | next_img = model.obs_encoder.apply_transform(next_obs_dict)
68 | next_latent = model.obs_encoder.apply_vae(next_img)
69 |
70 | # Sample next obs and actions from joint distribution
71 | next_latent_hat_joint, action_hat_joint = model.sample_joint(curr_obs_dict)
72 | joint_image_mse = F.mse_loss(next_latent_hat_joint, next_latent)
73 | dist.all_reduce(joint_image_mse, op=dist.ReduceOp.AVG)
74 | stats["image_mse_joint"] += joint_image_mse
75 | joint_action_mse = F.mse_loss(unnormalize(action_hat_joint), action)
76 | dist.all_reduce(joint_action_mse, op=dist.ReduceOp.AVG)
77 | stats["action_mse_joint"] += joint_action_mse
78 |
79 | # Average over all batches
80 | stats = {k: v / len(data_loader) for k, v in stats.items()}
81 |
82 | # Plot reconstruction
83 | stats["images"] = wandb.Image(decode_and_plot(next_latent[0:1]))
84 | stats["images_joint"] = wandb.Image(decode_and_plot(next_latent_hat_joint[0:1]))
85 | return stats
86 |
87 |
88 | def maybe_evaluate(config, step, model, loader, device, action_normalizer=None):
89 | """Evaluate if it's the correct step."""
90 | if step % config.eval_every == 0:
91 | stats = eval_one_epoch(config, loader, device, model, action_normalizer)
92 | if is_main_process():
93 | wandb.log({f"eval/{k}": v for k, v in stats.items()})
94 | print(f"Step {step} action mse: {stats['action_mse_joint']:.4f}")
95 |
96 |
97 | def train(rank, world_size, config):
98 | # Set global seed
99 | set_seed(config.seed * world_size + rank)
100 |
101 | # Initialize distributed training
102 | init_distributed(rank, world_size)
103 | device = torch.device(f"cuda:{rank}")
104 |
105 | # Initialize WANDB
106 | if is_main_process():
107 | init_wandb(config, job_type="train")
108 |
109 | # Create dataset
110 | train_set, val_set = instantiate(config.dataset)
111 | train_loader, val_loader = make_distributed_data_loader(
112 | train_set, val_set, config.batch_size, rank, world_size
113 | )
114 |
115 | # Create model
116 | model = instantiate(config.model).to(device)
117 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer)
118 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler)
119 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)
120 |
121 | # Load pretrained model
122 | if config.pretrain_checkpoint_path:
123 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu")
124 | model.load_state_dict(ckpt["model"])
125 | print(
126 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}"
127 | )
128 |
129 | # Replace dataset normalizers to make sure data is normalized correctly
130 | if ckpt["action_normalizer"] is not None:
131 | train_set.action_normalizer = ckpt["action_normalizer"]
132 | val_set.action_normalizer = ckpt["action_normalizer"]
133 | if ckpt["lowdim_normalizer"] is not None:
134 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
135 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"]
136 |
137 | # Resume from checkpoint
138 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler)
139 | epoch = step // len(train_loader)
140 |
141 | # Wrap model with DDP
142 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True)
143 |
144 | # Training loop
145 | pbar = tqdm(
146 | total=config.num_steps,
147 | initial=step,
148 | desc="Training",
149 | disable=not is_main_process(),
150 | )
151 | while step < config.num_steps:
152 | # Set epoch for distributed sampler to shuffle indices
153 | train_loader.sampler.set_epoch(epoch)
154 |
155 | for batch in train_loader:
156 | # --- Training step ---
157 | loss, info = train_one_step(
158 | config, model, optimizer, scheduler, scaler, batch, device
159 | )
160 |
161 | # --- Logging ---
162 | if is_main_process():
163 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}")
164 | wandb.log({f"train/{k}": v for k, v in info.items()})
165 |
166 | # --- Evaluate if needed ---
167 | maybe_evaluate(
168 | config, step, model, val_loader, device, train_set.action_normalizer
169 | )
170 |
171 | # --- Save checkpoint if needed ---
172 | maybe_save_checkpoint(
173 | config,
174 | step,
175 | model,
176 | optimizer,
177 | scheduler,
178 | scaler,
179 | train_set.action_normalizer,
180 | train_set.lowdim_normalizer,
181 | )
182 |
183 | step += 1
184 | pbar.update(1)
185 | if step >= config.num_steps:
186 | break
187 |
188 | epoch += 1
189 |
190 |
191 | @hydra.main(
192 | version_base=None, config_path="../../configs", config_name="train_pad.yaml"
193 | )
194 | def main(config):
195 | # Resolve hydra config
196 | OmegaConf.resolve(config)
197 | # Spawn processes
198 | world_size = torch.cuda.device_count()
199 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True)
200 |
201 |
202 | if __name__ == "__main__":
203 | main()
204 |
--------------------------------------------------------------------------------
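Editor's note: both `eval_one_epoch` implementations above average their per-batch metrics across data-parallel ranks with `dist.all_reduce(x, op=dist.ReduceOp.AVG)`, which is available on the NCCL backend. A minimal sketch of the same reduction written as SUM followed by a division, which also runs on the CPU `gloo` backend; the single-process group exists only to make the example runnable:

import torch
import torch.distributed as dist


def average_across_ranks(x: torch.Tensor) -> torch.Tensor:
    # Same effect as dist.all_reduce(x, op=dist.ReduceOp.AVG) in the eval loops.
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    return x / dist.get_world_size()


if __name__ == "__main__":
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    print(average_across_ranks(torch.tensor(3.0)))  # tensor(3.)
    dist.destroy_process_group()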
/datasets/utils/buffer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from multiprocessing import Lock
4 | from os.path import expanduser, expandvars
5 | from typing import Any, Optional, Union
6 |
7 | import numcodecs
8 | import numpy as np
9 | import zarr
10 |
11 |
12 | def get_optimal_chunks(
13 | shape: tuple[int, ...], dtype: Union[str, np.dtype], target_chunk_bytes: float = 2e6
14 | ) -> tuple[int, ...]:
15 | """
16 | Calculates the optimal chunk sizes for an array given its shape, data type, and a target chunk size in bytes.
17 |
18 | Args:
19 | shape: The shape of the array.
20 | dtype: The data type of the array.
21 | target_chunk_bytes: The target size for each chunk in bytes. Defaults to 2e6 (2 MB).
22 |
23 | Returns:
24 | The optimal chunk dimensions for the given array shape and data type, aiming to not exceed
25 | the target chunk size in bytes.
26 | """
27 | itemsize = np.dtype(dtype).itemsize
28 | rshape = list(shape[::-1])
29 |
30 | # Find the index to split the shape, starting from the rightmost dimension
31 | split_idx = len(shape) - 1
32 | for i in range(len(shape) - 1):
33 | this_chunk_bytes = itemsize * np.prod(rshape[:i])
34 | next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
35 | if this_chunk_bytes <= target_chunk_bytes < next_chunk_bytes:
36 | split_idx = i
37 | break
38 |
39 | # Handle jagged chunk dimension
40 | rchunks = rshape[:split_idx]
41 | item_chunk_bytes = itemsize * np.prod(rchunks)
42 | next_chunk_length = min(
43 | rshape[split_idx], math.ceil(target_chunk_bytes / item_chunk_bytes)
44 | )
45 | rchunks.append(next_chunk_length)
46 |
47 | # Handle remaining dimensions
48 | rchunks.extend([1] * (len(shape) - len(rchunks)))
49 | chunks = tuple(rchunks[::-1])
50 | return chunks
51 |
52 |
53 | class CompressedTrajectoryBuffer:
54 | """
55 | A class that stores trajectory data in a compressed zarr array.
56 | """
57 |
58 | def __init__(
59 | self,
60 | storage_path: str,
61 |         metadata: dict[str, dict[str, Any]],
62 | capacity: Optional[int] = None,
63 | lock: Optional[Lock] = None,
64 | ):
65 | """
66 | Initialize the trajectory buffer. If there is an existing buffer at the given path, it will be restored.
67 |
68 | Args:
69 | storage_path: Path to the buffer storage.
70 | metadata: Dictionary containing metadata for each data key. Each key should
71 | map to a dictionary containing the following keys:
72 | - shape: shape of the data
73 | - dtype: dtype of the data
74 | capacity: Maximum number of transition steps that can be stored in the buffer.
75 | Only used when creating a new buffer.
76 | lock: Multiprocessing lock to synchronize access to the buffer. If None, a new lock will be created.
77 | """
78 | # Create zarr storage and root group
79 | storage_path = expandvars(expanduser(storage_path))
80 | self.restored = os.path.exists(storage_path)
81 | self.storage = zarr.DirectoryStore(storage_path)
82 |
83 | # Mutex for zarr storage
84 | self.lock = Lock() if lock is None else lock
85 |
86 | # Create data and metadata groups
87 | self.root = zarr.group(store=self.storage)
88 | self.data = self.root.require_group("data")
89 | self.meta = self.root.require_group("meta")
90 |
91 | if self.restored:
92 | print(f"Restoring buffer from {storage_path}")
93 | assert "episode_ends" in self.meta
94 | assert all(key in self.data for key in metadata)
95 | assert all(
96 | self.data[key].shape[1:] == value["shape"]
97 | for key, value in metadata.items()
98 | )
99 |
100 | # Check that all data have the same length and restore capacity
101 | lengths = {self.data[key].shape[0] for key in self.data}
102 | assert len(lengths) == 1, "Inconsistent data lengths in the buffer"
103 | self.capacity = lengths.pop()
104 | else:
105 | with self.lock:
106 | print(f"Creating new buffer at {storage_path}")
107 | assert capacity is not None, "Capacity must be specified for new buffer"
108 | self.capacity = capacity
109 |
110 | # Create empty episode_ends
111 | self.meta.zeros(
112 | name="episode_ends",
113 | shape=(0,),
114 | dtype=np.int64,
115 | compressor=None,
116 | )
117 |
118 | # Allocate space for data
119 | for key, value in metadata.items():
120 | shape = (capacity,) + tuple(value["shape"])
121 | dtype = value["dtype"]
122 | if dtype == np.uint8:
123 | # Chunk and compress images individually
124 | chunks = (1,) + shape[1:]
125 | compressor = numcodecs.Blosc(
126 | cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
127 | )
128 | else:
129 | # Chunk and compress other data by computing optimal chunks
130 | chunks = get_optimal_chunks(shape, dtype)
131 | compressor = numcodecs.Blosc(
132 | cname="lz4", clevel=0, shuffle=numcodecs.Blosc.NOSHUFFLE
133 | )
134 | # Create new array
135 | self.data.zeros(
136 | name=key,
137 | shape=shape,
138 | chunks=chunks,
139 | dtype=dtype,
140 | compressor=compressor,
141 | object_codec=numcodecs.Pickle(),
142 | )
143 |
144 | @property
145 | def episode_ends(self) -> np.ndarray:
146 | return self.meta["episode_ends"]
147 |
148 | @property
149 | def num_episodes(self) -> int:
150 | return len(self.episode_ends)
151 |
152 | @property
153 | def num_steps(self) -> int:
154 | return self.episode_ends[-1] if len(self.episode_ends) > 0 else 0
155 |
156 | def add_episode(self, data: dict[str, np.ndarray]):
157 | with self.lock:
158 | # Get episode length
159 | episode_lens = np.array([v.shape[0] for v in data.values()])
160 | assert np.all(episode_lens == episode_lens[0])
161 | episode_len = episode_lens[0]
162 |
163 | # Compute corresponding buffer indices
164 | start_ind = self.num_steps
165 | end_ind = start_ind + episode_len
166 | if end_ind > self.capacity:
167 | raise RuntimeError("Buffer capacity exceeded")
168 |
169 | # Copy data to buffer
170 | for key, value in data.items():
171 | arr = self.data[key]
172 | arr[start_ind:end_ind] = value
173 |
174 | # Update episode_ends
175 | self.episode_ends.resize(len(self.episode_ends) + 1)
176 | self.episode_ends[-1] = end_ind
177 |
178 | # Rechunk and recompress episode_ends if necessary
179 | if self.episode_ends.chunks[0] < self.episode_ends.shape[0]:
180 |                 new_chunk_len = math.ceil(self.episode_ends.shape[0] * 1.5)  # chunk length must be an int
181 | new_chunks = (new_chunk_len,) + self.episode_ends.chunks[1:]
182 | self.meta.move("episode_ends", "_temp")
183 | zarr.copy(
184 | source=self.meta["_temp"],
185 | dest=self.meta,
186 | name="episode_ends",
187 | chunks=new_chunks,
188 | compressor=None,
189 | )
190 | del self.meta["_temp"]
191 |
192 | def __repr__(self) -> str:
193 | return str(self.root.tree())
194 |
195 | def keys(self):
196 | return self.data.keys()
197 |
198 | def values(self):
199 | return self.data.values()
200 |
201 | def items(self):
202 | return self.data.items()
203 |
204 | def __getitem__(self, key):
205 | return self.data[key]
206 |
207 | def __contains__(self, key):
208 | return key in self.data
209 |
--------------------------------------------------------------------------------
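Editor's note: a minimal usage sketch for `CompressedTrajectoryBuffer`, following the metadata format that `convert_dataset_zarr.py` builds (per-step shape and dtype per key). The path and shapes are toy values, and the import assumes the repo root is on PYTHONPATH:

import numpy as np

from datasets.utils.buffer import CompressedTrajectoryBuffer

# Per-key, per-step shape/dtype, mirroring the converter's metadata dict.
metadata = {
    "obs.image": {"shape": (32, 32, 3), "dtype": np.uint8},
    "action": {"shape": (4,), "dtype": np.float32},
}

buffer = CompressedTrajectoryBuffer(
    storage_path="/tmp/toy_buffer.zarr",  # illustrative path
    metadata=metadata,
    capacity=100,  # total number of steps the buffer can hold
)

# One episode = a dict of arrays with a shared leading time dimension.
episode = {
    "obs.image": np.zeros((10, 32, 32, 3), dtype=np.uint8),
    "action": np.zeros((10, 4), dtype=np.float32),
}
buffer.add_episode(episode)
print(buffer.num_episodes, buffer.num_steps)  # 1 10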
/models/gr1/gr1.py:
--------------------------------------------------------------------------------
1 | # Copyright (2024) Bytedance Ltd. and/or its affiliates
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """GR-1 model."""
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from einops import rearrange
20 |
21 | from models.common.attention import AttentionBlock
22 | from models.common.utils import get_nd_sinusoidal_embed
23 | from models.gr1.obs_encoder import GR1ObservationEncoder
24 | from models.gr1.vision_transformer import Block
25 |
26 |
27 | class GR1(nn.Module):
28 | def __init__(
29 | self,
30 | action_len: int,
31 | action_dim: int,
32 | obs_encoder: GR1ObservationEncoder,
33 | embed_dim: int,
34 | image_size: tuple[int, ...],
35 | patch_size: tuple[int, ...],
36 | num_chans: int = 3,
37 | depth: int = 12,
38 | num_heads: int = 12,
39 | mlp_ratio: int = 4,
40 | qkv_bias: bool = True,
41 | decoder_depth: int = 3,
42 | decoder_num_heads: int = 16,
43 | decoder_mlp_ratio: int = 4,
44 | decoder_qkv_bias: bool = True,
45 | ):
46 | super().__init__()
47 | self.action_dim = action_dim
48 | self.action_len = action_len
49 | self.image_size = image_size
50 | self.patch_size = patch_size
51 | self.num_chans = num_chans
52 |
53 | # Observation encoder (with perceiver resampler)
54 | self.obs_encoder = obs_encoder
55 | self.obs_len = obs_encoder.num_latents
56 |
57 | # Action query token
58 | self.action_queries = nn.Parameter(
59 | torch.empty(1, self.action_len, embed_dim).normal_(mean=0, std=0.02)
60 | )
61 |
62 | # Observation query token
63 | self.obs_queries = nn.Parameter(
64 | torch.empty(1, self.obs_len, embed_dim).normal_(mean=0, std=0.02)
65 | )
66 |
67 | # Main transformer
68 | self.transformer = nn.ModuleList(
69 | [
70 | AttentionBlock(
71 | dim=embed_dim,
72 | num_heads=num_heads,
73 | mlp_ratio=mlp_ratio,
74 | qkv_bias=qkv_bias,
75 | )
76 | for _ in range(depth)
77 | ]
78 | )
79 |
80 | # Action head
81 | self.action_head = nn.Sequential(
82 | nn.Linear(embed_dim, embed_dim // 2),
83 | nn.SiLU(),
84 | nn.Linear(embed_dim // 2, embed_dim // 2),
85 | nn.SiLU(),
86 | nn.Linear(embed_dim // 2, self.action_dim),
87 | )
88 |
89 | # Image decoder
90 | self.decoder_query_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
91 | decoder_pos_embed = get_nd_sinusoidal_embed(
92 | embed_dim,
93 | (self.image_size // self.patch_size, self.image_size // self.patch_size),
94 | )
95 | self.decoder_pos_embed = nn.Parameter(
96 | torch.from_numpy(decoder_pos_embed).float()[None], requires_grad=False
97 | )
98 | self.decoder_proj = nn.Linear(embed_dim, embed_dim, bias=True)
99 | self.decoder_blocks = nn.ModuleList(
100 | [
101 | Block(
102 | dim=embed_dim,
103 | num_heads=decoder_num_heads,
104 | mlp_ratio=decoder_mlp_ratio,
105 | qkv_bias=decoder_qkv_bias,
106 | )
107 | for _ in range(decoder_depth)
108 | ]
109 | )
110 | self.decoder_head = nn.Sequential(
111 | nn.LayerNorm(embed_dim),
112 | nn.Linear(embed_dim, self.patch_size**2 * 3, bias=True),
113 | )
114 |
115 | def predict_next_obs(self, next_obs_embed):
116 | next_obs_embed = rearrange(
117 | next_obs_embed,
118 | "b (v t n) d -> (b v t) n d",
119 | v=self.obs_encoder.num_views,
120 | t=self.obs_encoder.num_frames,
121 | )
122 | next_obs_embed = self.decoder_proj(next_obs_embed)
123 | next_obs_queries = self.decoder_query_token + self.decoder_pos_embed
124 | next_obs_pred = torch.cat(
125 | [next_obs_embed, next_obs_queries.repeat(next_obs_embed.shape[0], 1, 1)],
126 | dim=1,
127 | )
128 | for block in self.decoder_blocks:
129 | next_obs_pred = block(next_obs_pred)
130 | next_obs_pred = self.decoder_head(
131 | next_obs_pred[:, -self.decoder_pos_embed.shape[1] :]
132 | )
133 | next_obs_pred = rearrange(
134 | next_obs_pred,
135 | "(b v t) (h w) (c ph pw) -> b v c t (h ph) (w pw)",
136 | v=self.obs_encoder.num_views,
137 | t=self.obs_encoder.num_frames,
138 | h=self.image_size // self.patch_size,
139 | w=self.image_size // self.patch_size,
140 | ph=self.patch_size,
141 | pw=self.patch_size,
142 | )
143 | return next_obs_pred
144 |
145 | def forward(self, obs_dict, next_obs_dict, action, action_mask=None):
146 | # Get obs patches and prediction targets
147 | curr_embeds, next_obs = self.obs_encoder.encode_curr_and_next_obs(
148 | obs_dict, next_obs_dict
149 | )
150 |
151 | # Transformer inputs
152 | action_queries = self.action_queries.expand(action.shape[0], -1, -1)
153 | obs_queries = self.obs_queries.expand(next_obs.shape[0], -1, -1)
154 | x = torch.cat([curr_embeds, action_queries, obs_queries], dim=1)
155 |
156 | # Attention mask
157 | attn_mask = None
158 | if action_mask is not None:
159 | # Action mask has shape (B,)
160 | B = action_mask.shape[0]
161 | N = self.action_len + self.obs_len * 2
162 | attn_mask = torch.ones((B, N, N), device=action.device, dtype=torch.bool)
163 | attn_mask[
164 | ~action_mask, :, self.obs_len : self.obs_len + self.action_len
165 | ] = 0
166 |
167 | # Transformer forward pass
168 | for block in self.transformer:
169 | x = block(x, attn_mask=attn_mask)
170 |
171 | # Action prediction
172 | action_embed = x[:, self.obs_len : self.obs_len + self.action_len]
173 | action_pred = self.action_head(action_embed)
174 |
175 | # Image prediction
176 | next_obs_embed = x[:, -self.obs_len :]
177 | next_obs_pred = self.predict_next_obs(next_obs_embed)
178 |
179 | # Compute losses
180 | if action_mask is None:
181 | action_loss = F.mse_loss(action_pred, action)
182 | else:
183 | action_loss = F.mse_loss(action_pred[action_mask], action[action_mask])
184 | dynamics_loss = F.mse_loss(next_obs_pred, next_obs)
185 | loss = action_loss + dynamics_loss
186 | info = {
187 | "loss": loss.item(),
188 | "action_loss": action_loss.item(),
189 | "dynamics_loss": dynamics_loss.item(),
190 | }
191 | return loss, info
192 |
193 | @torch.no_grad()
194 | def sample(self, obs_dict):
195 | _, action_sample = self.sample_joint(obs_dict)
196 | return action_sample
197 |
198 | @torch.no_grad()
199 | def sample_joint(self, obs_dict):
200 | # Get obs patches and prediction targets
201 | curr_embeds = self.obs_encoder.encode_obs(obs_dict)
202 |
203 | # Transformer inputs
204 | action_queries = self.action_queries.expand(curr_embeds.shape[0], -1, -1)
205 | obs_queries = self.obs_queries.expand(curr_embeds.shape[0], -1, -1)
206 | x = torch.cat([curr_embeds, action_queries, obs_queries], dim=1)
207 |
208 | # Transformer forward pass
209 | for block in self.transformer:
210 | x = block(x)
211 |
212 | # Action prediction
213 | action_embed = x[:, self.obs_len : self.obs_len + self.action_len]
214 | action_pred = self.action_head(action_embed)
215 |
216 | # Image prediction
217 | next_obs_embed = x[:, -self.obs_len :]
218 | next_obs_pred = self.predict_next_obs(next_obs_embed)
219 | return next_obs_pred, action_pred
220 |
--------------------------------------------------------------------------------
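Editor's note: `predict_next_obs` above folds the decoder's per-patch outputs back into multi-view video frames with a single einops rearrange. A standalone shape check with made-up sizes (none of these numbers are the repo's defaults):

import torch
from einops import rearrange

b, v, t = 1, 2, 3                 # batch, views, frames
h, w, ph, pw, c = 4, 4, 8, 8, 3   # patch grid, patch size, channels

# One row per (batch, view, frame), one token per patch,
# each token holding a flattened c*ph*pw pixel patch.
tokens = torch.randn(b * v * t, h * w, c * ph * pw)

images = rearrange(
    tokens,
    "(b v t) (h w) (c ph pw) -> b v c t (h ph) (w pw)",
    v=v, t=t, h=h, w=w, ph=ph, pw=pw,
)
print(images.shape)  # torch.Size([1, 2, 3, 3, 32, 32])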
/models/uwm/obs_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange
6 |
7 | from models.common.language import CLIPTextEncoder
8 | from models.common.transforms import VideoTransform, VAEDownsample
9 | from models.common.vision import ResNetImageEncoder, ViTImageEncoder
10 |
11 |
12 | class UWMObservationEncoder(nn.Module):
13 | def __init__(
14 | self,
15 | shape_meta: dict,
16 | num_frames: int,
17 | embed_dim: int,
18 | resize_shape: Tuple[int, int] = None,
19 | crop_shape: Tuple[int, int] = None,
20 | random_crop: bool = True,
21 | color_jitter: Optional[Dict] = None,
22 | imagenet_norm: bool = False,
23 | vision_backbone: str = "vit",
24 | use_low_dim: bool = True,
25 | use_language: bool = True,
26 | ):
27 | super().__init__()
28 | self.shape_meta = shape_meta
29 | self.num_frames = num_frames
30 | self.rgb_keys = sorted(
31 | [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"]
32 | )
33 | self.low_dim_keys = sorted(
34 | [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"]
35 | )
36 | self.num_views = len(self.rgb_keys)
37 | self.embed_dim = embed_dim
38 |
39 | # Image augmentation
40 | self.obs_transform = VideoTransform(
41 | resize_shape=resize_shape,
42 | crop_shape=crop_shape,
43 | random_crop=random_crop,
44 | color_jitter=color_jitter,
45 | imagenet_norm=imagenet_norm,
46 | )
47 |
48 | # Image encoder
49 | if vision_backbone == "vit":
50 | self.img_encoder = ViTImageEncoder(
51 | num_views=self.num_views,
52 | embed_dim=embed_dim,
53 | )
54 | elif vision_backbone == "resnet":
55 | self.img_encoder = ResNetImageEncoder(
56 | num_views=self.num_views,
57 | embed_dim=embed_dim,
58 | )
59 | else:
60 | raise NotImplementedError(f"Unsupported vision backbone: {vision_backbone}")
61 |
62 | # Low-dim observations
63 | self.use_low_dim = use_low_dim
64 |
65 | # Language encoder
66 | self.use_language = use_language
67 | self.text_encoder = (
68 | CLIPTextEncoder(embed_dim=embed_dim) if use_language else None
69 | )
70 |
71 | # VAE downsampling
72 | self.vae = VAEDownsample()
73 |
74 | def apply_transform(self, obs_dicts: Union[dict, list[dict]]):
75 | """
76 | Accept a list of observation dictionaries and apply the same transform to each.
77 | """
78 | if isinstance(obs_dicts, dict):
79 | obs_dicts = [obs_dicts]
80 | is_singleton = True
81 | else:
82 | is_singleton = False
83 | assert isinstance(obs_dicts, list)
84 |
85 | # Apply the same transform to each observation
86 | num_obs = len(obs_dicts)
87 | transformed_imgs = [[] for _ in range(num_obs)]
88 | for key in self.rgb_keys:
89 | combined_imgs = torch.cat([obs_dict[key] for obs_dict in obs_dicts], dim=0)
90 | combined_imgs = self.obs_transform(combined_imgs)
91 | chunked_imgs = combined_imgs.chunk(num_obs, dim=0)
92 | for i, img in enumerate(chunked_imgs):
93 | transformed_imgs[i].append(img)
94 |
95 | # Stack transformed images
96 | # Each image has shape (B, V, C, T, H, W)
97 | transformed_imgs = [torch.stack(imgs, dim=1) for imgs in transformed_imgs]
98 | if is_singleton:
99 | transformed_imgs = transformed_imgs[0]
100 | return transformed_imgs
101 |
102 | def apply_vae(
103 | self,
104 | imgs_list: Union[torch.Tensor, list[torch.Tensor]],
105 | inverse: bool = False,
106 | microbatch_size: int = 72, # Tuned for 40GB VRAM
107 | ):
108 | """
109 | Accept a list of images and apply VAE to downsample or upsample images.
110 | If inverse is False, downsample images. Otherwise, upsample images.
111 | Process images in microbatches to reduce memory usage.
112 | """
113 | if isinstance(imgs_list, torch.Tensor):
114 | imgs_list = [imgs_list]
115 | is_singleton = True
116 | else:
117 | is_singleton = False
118 | assert isinstance(imgs_list, list)
119 | imgs = torch.cat(imgs_list, dim=0)
120 |
121 | # Flatten multiview videos to images
122 | B, V = imgs.shape[:2]
123 | imgs = rearrange(imgs, "b v c t h w -> (b v t) c h w")
124 |
125 | # Process images in microbatches
126 | transformed_imgs = []
127 | for i in range(0, imgs.shape[0], microbatch_size):
128 | batch_imgs = imgs[i : i + microbatch_size]
129 | if inverse:
130 | batch_transformed_imgs = self.vae.inverse(batch_imgs)
131 | else:
132 | batch_transformed_imgs = self.vae(batch_imgs)
133 | transformed_imgs.append(batch_transformed_imgs)
134 | transformed_imgs = torch.cat(transformed_imgs, dim=0)
135 |
136 | # Unflatten images to multiview videos
137 | transformed_imgs = rearrange(
138 | transformed_imgs, "(b v t) c h w -> b v c t h w", b=B, v=V
139 | )
140 | if not is_singleton:
141 | chunk_sizes = [img.shape[0] for img in imgs_list]
142 | transformed_imgs = list(transformed_imgs.split(chunk_sizes, dim=0))
143 | return transformed_imgs
144 |
145 | def encode_curr_obs(self, curr_obs_dict: dict):
146 |         # Encode current observations into features
147 | curr_imgs = self.apply_transform(curr_obs_dict)
148 | curr_feats = self.img_encoder(curr_imgs) # (B, V*T*D)
149 |
150 | if self.use_low_dim:
151 | low_dims = [curr_obs_dict[key] for key in self.low_dim_keys]
152 | low_dims = torch.cat(low_dims, dim=-1).flatten(1)
153 | curr_feats = torch.cat([curr_feats, low_dims], dim=-1)
154 |
155 | if self.use_language:
156 | lang_feats = self.text_encoder(
157 | input_ids=curr_obs_dict["input_ids"],
158 | attention_mask=curr_obs_dict["attention_mask"],
159 | )
160 | curr_feats = torch.cat([curr_feats, lang_feats], dim=-1)
161 | return curr_feats
162 |
163 | def encode_next_obs(self, next_obs_dict: dict):
164 |         # Encode next observations into latents
165 | next_imgs = self.apply_transform(next_obs_dict)
166 | next_latents = self.apply_vae(next_imgs)
167 | return next_latents
168 |
169 | def encode_curr_and_next_obs(self, curr_obs_dict: dict, next_obs_dict: dict):
170 | # Apply the same transform to obs and next obs
171 | curr_imgs, next_imgs = self.apply_transform([curr_obs_dict, next_obs_dict])
172 |
173 | # Encode current obs to features
174 | curr_feats = self.img_encoder(curr_imgs) # (B, V*T*D)
175 |
176 | if self.use_low_dim:
177 | low_dims = [curr_obs_dict[key] for key in self.low_dim_keys]
178 | low_dims = torch.cat(low_dims, dim=-1).flatten(1)
179 | curr_feats = torch.cat([curr_feats, low_dims], dim=-1)
180 |
181 | if self.use_language:
182 | lang_feats = self.text_encoder(
183 | input_ids=curr_obs_dict["input_ids"],
184 | attention_mask=curr_obs_dict["attention_mask"],
185 | )
186 | curr_feats = torch.cat([curr_feats, lang_feats], dim=-1)
187 |
188 | # Encode next obs to latents
189 | next_latents = self.apply_vae(next_imgs)
190 | return curr_feats, next_latents
191 |
192 | def feat_dim(self):
193 | # Return the dimension of encoded features
194 | low_dim_size = sum(
195 | self.shape_meta["obs"][key]["shape"][-1] for key in self.low_dim_keys
196 | )
197 | return (
198 | self.num_views * self.num_frames * self.embed_dim
199 | + int(self.use_low_dim) * self.num_frames * low_dim_size
200 | + int(self.use_language) * self.embed_dim
201 | )
202 |
203 | def latent_img_shape(self):
204 | # Construct dummy image and forward pass to get latent image shape
205 | dummy_obs = {}
206 | for k in self.rgb_keys:
207 | img_shape = self.shape_meta["obs"][k]["shape"]
208 | dummy_obs[k] = torch.zeros(
209 | 1, self.num_frames, *img_shape, dtype=torch.uint8
210 | )
211 | with torch.no_grad():
212 | latent = self.encode_next_obs(dummy_obs)
213 | return tuple(latent.shape[1:])
214 |
--------------------------------------------------------------------------------
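Editor's note: `apply_transform` above concatenates several observation dicts along the batch dimension, runs the (random) transform once, and chunks the result back, so current and next observations receive identical augmentation. A toy illustration of that pattern with a stand-in random crop (the crop function below is not the repo's `VideoTransform`):

import torch


def random_crop(x, size=8):
    # Stand-in augmentation: offsets sampled once per call, so every item
    # in the batch is cropped at the same location.
    top = torch.randint(0, x.shape[-2] - size + 1, (1,)).item()
    left = torch.randint(0, x.shape[-1] - size + 1, (1,)).item()
    return x[..., top : top + size, left : left + size]


curr = torch.randn(4, 3, 16, 16)
nxt = torch.randn(4, 3, 16, 16)

# Same trick as UWMObservationEncoder.apply_transform: transform the
# concatenation once, then split back into per-dict chunks.
combined = random_crop(torch.cat([curr, nxt], dim=0))
curr_crop, next_crop = combined.chunk(2, dim=0)
print(curr_crop.shape, next_crop.shape)  # torch.Size([4, 3, 8, 8]) twice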