├── configs ├── template_dp.yaml ├── dataset │ ├── libero_book_caddy.yaml │ ├── robomimic_can_ph.yaml │ ├── libero_90.yaml │ ├── droid_mixture.yaml │ ├── robomimic_lift_ph.yaml │ ├── robomimic_square_ph.yaml │ ├── libero_moka_moka.yaml │ ├── libero_stove_moka.yaml │ ├── libero_mug_microwave.yaml │ ├── libero_cheese_butter.yaml │ ├── libero_soup_sauce.yaml │ ├── libero_bowl_drawer.yaml │ ├── libero_soup_cheese.yaml │ ├── libero_mug_mug.yaml │ ├── libero_mug_pudding.yaml │ ├── template_robomimic.yaml │ ├── robomimic_tool_hang_ph.yaml │ ├── webvideo_for_robomimic.yaml │ ├── robomimic_transport_ph.yaml │ ├── template_libero.yaml │ ├── droid.yaml │ ├── libero_90_webvideo.yaml │ └── droid_webvideo.yaml ├── finetune_gr1.yaml ├── finetune_pad.yaml ├── finetune_uwm.yaml ├── train_gr1.yaml ├── train_pad.yaml ├── train_uwm.yaml ├── train_dp.yaml ├── finetune_dp.yaml ├── train_gr1_robomimic.yaml ├── train_pad_robomimic.yaml ├── finetune_gr1_robomimic.yaml ├── finetune_pad_robomimic.yaml ├── train_dp_robomimic.yaml ├── finetune_dp_robomimic.yaml ├── train_uwm_robomimic.yaml ├── finetune_uwm_robomimic.yaml ├── template_finetune.yaml ├── template_finetune_robomimic.yaml ├── template_train.yaml ├── model │ ├── pad.yaml │ ├── gr1.yaml │ ├── uwm.yaml │ └── dp.yaml └── template_train_robomimic.yaml ├── models ├── gr1 │ ├── __init__.py │ ├── flamingo.py │ ├── obs_encoder.py │ └── gr1.py ├── pad │ ├── __init__.py │ └── obs_encoder.py ├── uwm │ ├── __init__.py │ └── obs_encoder.py ├── dp │ ├── __init__.py │ ├── image_policy.py │ ├── transformer.py │ ├── base_policy.py │ └── obs_encoder.py └── common │ ├── language.py │ ├── transforms.py │ └── adaln_attention.py ├── .gitignore ├── scripts ├── launch_droid_pretrain.sh ├── launch_droid_cotrain.sh ├── setup_aws.sh └── launch_droid_finetune.sh ├── datasets ├── utils │ ├── file_utils.py │ ├── obs_utils.py │ ├── loader.py │ ├── sampler.py │ ├── normalizer.py │ ├── mixture.py │ └── buffer.py ├── webvideo │ ├── __init__.py │ └── dataset.py ├── droid │ ├── __init__.py │ └── convert_dataset_zarr.py └── robomimic │ ├── __init__.py │ └── dataset.py ├── pyproject.toml ├── requirements.txt ├── experiments ├── utils.py ├── uwm │ ├── eval_robomimic.py │ ├── train_webvideo.py │ ├── ablate_inverse_dynamics.py │ ├── ablate_forward_dynamics.py │ ├── eval_droid.py │ └── train_robomimic.py ├── gr1 │ ├── train_robomimic.py │ └── train.py ├── pad │ ├── train_robomimic.py │ └── train.py └── dp │ └── train_robomimic.py ├── environments └── robomimic │ ├── __init__.py │ └── wrappers.py └── README.md /configs/template_dp.yaml: -------------------------------------------------------------------------------- 1 | num_frames: 17 -------------------------------------------------------------------------------- /configs/dataset/libero_book_caddy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ -------------------------------------------------------------------------------- /configs/dataset/robomimic_can_ph.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_robomimic 3 | - _self_ -------------------------------------------------------------------------------- /models/gr1/__init__.py: -------------------------------------------------------------------------------- 1 | from .gr1 import GR1 2 | from .obs_encoder import GR1ObservationEncoder 3 | -------------------------------------------------------------------------------- /models/pad/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .obs_encoder import PADObservationEncoder 2 | from .pad import PAD 3 | -------------------------------------------------------------------------------- /models/uwm/__init__.py: -------------------------------------------------------------------------------- 1 | from .obs_encoder import UWMObservationEncoder 2 | from .uwm import UnifiedWorldModel 3 | -------------------------------------------------------------------------------- /configs/finetune_gr1.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: gr1 4 | - template_finetune 5 | - _self_ -------------------------------------------------------------------------------- /configs/finetune_pad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: pad 4 | - template_finetune 5 | - _self_ -------------------------------------------------------------------------------- /configs/finetune_uwm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: uwm 4 | - template_finetune 5 | - _self_ -------------------------------------------------------------------------------- /configs/train_gr1.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: gr1 4 | - template_train 5 | - _self_ 6 | -------------------------------------------------------------------------------- /configs/train_pad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: pad 4 | - template_train 5 | - _self_ 6 | -------------------------------------------------------------------------------- /configs/train_uwm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: uwm 4 | - template_train 5 | - _self_ 6 | -------------------------------------------------------------------------------- /configs/train_dp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: dp 4 | - template_train 5 | - template_dp 6 | - _self_ -------------------------------------------------------------------------------- /configs/finetune_dp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: droid 3 | - model: dp 4 | - template_finetune 5 | - template_dp 6 | - _self_ -------------------------------------------------------------------------------- /configs/train_gr1_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_90 3 | - model: gr1 4 | - template_train_robomimic 5 | - _self_ -------------------------------------------------------------------------------- /configs/train_pad_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_90 3 | - model: pad 4 | - template_train_robomimic 5 | - _self_ -------------------------------------------------------------------------------- /configs/finetune_gr1_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_book_caddy 3 | - model: gr1 4 
| - template_finetune_robomimic 5 | - _self_ -------------------------------------------------------------------------------- /configs/finetune_pad_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_book_caddy 3 | - model: pad 4 | - template_finetune_robomimic 5 | - _self_ -------------------------------------------------------------------------------- /configs/train_dp_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_90 3 | - model: dp 4 | - template_train_robomimic 5 | - template_dp 6 | - _self_ -------------------------------------------------------------------------------- /configs/finetune_dp_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_book_caddy 3 | - model: dp 4 | - template_finetune_robomimic 5 | - template_dp 6 | - _self_ -------------------------------------------------------------------------------- /configs/train_uwm_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_90 3 | - model: uwm 4 | - template_train_robomimic 5 | - _self_ 6 | 7 | model: 8 | num_registers: 24 -------------------------------------------------------------------------------- /configs/finetune_uwm_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: libero_book_caddy 3 | - model: uwm 4 | - template_finetune_robomimic 5 | - _self_ 6 | 7 | model: 8 | num_registers: 24 -------------------------------------------------------------------------------- /configs/dataset/libero_90.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_90 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_90/*.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_90/buffer.zarr -------------------------------------------------------------------------------- /configs/dataset/droid_mixture.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - droid 3 | - _self_ 4 | 5 | name: droid_mixture 6 | buffer_path: /tmp/weirdlab/zchuning/data/droid/buffer.zarr 7 | video_buffer_path: /tmp/weirdlab/zchuning/data/droid/video_buffer.zarr 8 | balance_datasets: False -------------------------------------------------------------------------------- /configs/dataset/robomimic_lift_ph.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_robomimic 3 | - _self_ 4 | 5 | name: robomimic_lift_ph 6 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/lift/ph/image_v141.hdf5 7 | buffer_path: /home/ubuntu/robomimic/datasets/lift/ph/image_v141.zarr -------------------------------------------------------------------------------- /configs/dataset/robomimic_square_ph.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_robomimic 3 | - _self_ 4 | 5 | name: robomimic_square_ph 6 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/square/ph/image_v141.hdf5 7 | buffer_path: /home/ubuntu/robomimic/datasets/square/ph/image_v141.zarr -------------------------------------------------------------------------------- /configs/template_finetune.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - template_train 3 | - _self_ 4 | 5 | pretrain_checkpoint_path: "logdir/${algo}/droid/benchmark/0/models.pt" 6 | num_steps: 10000 7 | 8 | clip_grad_norm: 1.0 9 | scheduler: 10 | name: "cosine" 11 | num_training_steps: ${num_steps} 12 | num_warmup_steps: 1000 -------------------------------------------------------------------------------- /models/dp/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_policy import NoisePredictionNet, DiffusionPolicy, FlowPolicy 2 | from .image_policy import ImageDiffusionPolicy, ImageFlowPolicy 3 | from .obs_encoder import ImageObservationEncoder 4 | from .transformer import TransformerNoisePredictionNet 5 | from .unet import UnetNoisePredictionNet 6 | -------------------------------------------------------------------------------- /configs/dataset/libero_moka_moka.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_moka_moka 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE8_put_both_moka_pots_on_the_stove_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE8_put_both_moka_pots_on_the_stove_demo.zarr -------------------------------------------------------------------------------- /configs/template_finetune_robomimic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_train_robomimic 3 | - _self_ 4 | 5 | pretrain_checkpoint_path: "/home/ubuntu/efs/chuning/logdir/${algo}/libero_90/benchmark/0/models.pt" 6 | num_steps: 20000 7 | 8 | clip_grad_norm: 1.0 9 | scheduler: 10 | name: "cosine" 11 | num_training_steps: ${num_steps} 12 | num_warmup_steps: 1000 -------------------------------------------------------------------------------- /configs/dataset/libero_stove_moka.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_stove_moka 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE3_turn_on_the_stove_and_put_the_moka_pot_on_it_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_mug_microwave.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_mug_microwave 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE6_put_the_yellow_and_white_mug_in_the_microwave_and_close_it_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE6_put_the_yellow_and_white_mug_in_the_microwave_and_close_it_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_cheese_butter.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_cheese_butter 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_cream_cheese_box_and_the_butter_in_the_basket_demo.hdf5 7 | buffer_path: 
/home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_cream_cheese_box_and_the_butter_in_the_basket_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_soup_sauce.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_soup_sauce 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE2_put_both_the_alphabet_soup_and_the_tomato_sauce_in_the_basket_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_bowl_drawer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_bowl_drawer 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE4_put_the_black_bowl_in_the_bottom_drawer_of_the_cabinet_and_close_it_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/KITCHEN_SCENE4_put_the_black_bowl_in_the_bottom_drawer_of_the_cabinet_and_close_it_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_soup_cheese.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_soup_cheese 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE1_put_both_the_alphabet_soup_and_the_cream_cheese_box_in_the_basket_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE1_put_both_the_alphabet_soup_and_the_cream_cheese_box_in_the_basket_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_mug_mug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_mug_mug 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE5_put_the_white_mug_on_the_left_plate_and_put_the_yellow_and_white_mug_on_the_right_plate_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE5_put_the_white_mug_on_the_left_plate_and_put_the_yellow_and_white_mug_on_the_right_plate_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/libero_mug_pudding.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - template_libero 3 | - _self_ 4 | 5 | name: libero_mug_pudding 6 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE6_put_the_white_mug_on_the_plate_and_put_the_chocolate_pudding_to_the_right_of_the_plate_demo.hdf5 7 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/LIVING_ROOM_SCENE6_put_the_white_mug_on_the_plate_and_put_the_chocolate_pudding_to_the_right_of_the_plate_demo.zarr -------------------------------------------------------------------------------- /configs/dataset/template_robomimic.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.robomimic.make_robomimic_dataset 2 | name: robomimic_can_ph 3 | hdf5_path_globs: 
/home/ubuntu/robomimic/datasets/can/ph/image_v141.hdf5 4 | buffer_path: /home/ubuntu/robomimic/datasets/can/ph/image_v141.zarr 5 | shape_meta: 6 | obs: 7 | agentview_image: &camera_meta 8 | shape: [84, 84, 3] 9 | type: rgb 10 | robot0_eye_in_hand_image: *camera_meta 11 | action: 12 | shape: [7] # pos + quaternion + gripper 13 | seq_len: ${num_frames} 14 | val_ratio: 0.05 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # IDE config 31 | .idea 32 | .vscode 33 | 34 | # experiment data 35 | logdir/ 36 | wandb/ 37 | logs/ -------------------------------------------------------------------------------- /configs/dataset/robomimic_tool_hang_ph.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.robomimic.make_robomimic_dataset 2 | name: robomimic_tool_hang_ph 3 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/tool_hang/ph/image_v141.hdf5 4 | buffer_path: /home/ubuntu/robomimic/datasets/tool_hang/ph/image_v141.zarr 5 | shape_meta: 6 | obs: 7 | sideview_image: &camera_meta 8 | shape: [240, 240, 3] 9 | type: rgb 10 | robot0_eye_in_hand_image: *camera_meta 11 | action: 12 | shape: [7] # pos + quaternion + gripper 13 | seq_len: ${num_frames} 14 | val_ratio: 0.05 -------------------------------------------------------------------------------- /configs/dataset/webvideo_for_robomimic.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.webvideo.make_multiview_video_dataset 2 | name: webvideo_for_robomimic 3 | index_paths: [ 4 | /gscratch/weirdlab/zchuning/data/k400_filtered.txt, 5 | /gscratch/weirdlab/zchuning/data/ssv2_filtered.txt, 6 | ] 7 | shape_meta: 8 | obs: 9 | agentview_image: &camera_meta 10 | shape: [84, 84, 3] 11 | type: rgb 12 | robot0_eye_in_hand_image: *camera_meta 13 | action: 14 | shape: [7] # pos + quaternion + gripper 15 | seq_len: ${num_frames} 16 | obs_padding: "same" 17 | frame_skip: 2 18 | val_ratio: 0.05 19 | -------------------------------------------------------------------------------- /configs/template_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: logdir/${algo}/${dataset.name}/${exp_id}/${seed} 4 | 5 | algo: ${hydra:runtime.choices.model} 6 | exp_id: default 7 | seed: 0 8 | logdir: ${hydra:run.dir} 9 | resume: False 10 | use_amp: False 11 | eval_task_name: ${dataset.name} 12 | pretrain_checkpoint_path: null 13 | 14 | num_steps: 100000 15 | eval_every: 10000 16 | save_every: 10000 17 | 18 | batch_size: 36 19 | num_frames: 19 20 | obs_num_frames: 2 21 | clip_grad_norm: null 22 | 23 | optimizer: 24 | lr: 1e-4 25 | weight_decay: 1e-6 26 | betas: [0.9, 0.999] 27 | eps: 1e-8 28 | 29 | scheduler: 30 | name: "constant" -------------------------------------------------------------------------------- /configs/dataset/robomimic_transport_ph.yaml: -------------------------------------------------------------------------------- 1 |
_target_: datasets.robomimic.make_robomimic_dataset 2 | name: robomimic_transport_ph 3 | hdf5_path_globs: /home/ubuntu/robomimic/datasets/transport/ph/image_v141.hdf5 4 | buffer_path: /home/ubuntu/robomimic/datasets/transport/ph/image_v141.zarr 5 | shape_meta: 6 | obs: 7 | shouldercamera0_image: &camera_meta 8 | shape: [84, 84, 3] 9 | type: rgb 10 | shouldercamera1_image: *camera_meta 11 | robot0_eye_in_hand_image: *camera_meta 12 | robot1_eye_in_hand_image: *camera_meta 13 | action: 14 | shape: [14] # 2 * (pos + quaternion + gripper) 15 | seq_len: ${num_frames} 16 | val_ratio: 0.05 -------------------------------------------------------------------------------- /configs/dataset/template_libero.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.robomimic.make_robomimic_dataset 2 | name: libero_book_caddy 3 | hdf5_path_globs: /home/ubuntu/LIBERO/libero/datasets/libero_10/STUDY_SCENE1_pick_up_the_book_and_place_it_in_the_back_compartment_of_the_caddy_demo.hdf5 4 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_10/STUDY_SCENE1_pick_up_the_book_and_place_it_in_the_back_compartment_of_the_caddy_demo.zarr 5 | shape_meta: 6 | obs: 7 | agentview_rgb: &camera_meta 8 | shape: [128, 128, 3] 9 | type: rgb 10 | eye_in_hand_rgb: *camera_meta 11 | action: 12 | shape: [7] # pos + quaternion + gripper 13 | seq_len: ${num_frames} 14 | val_ratio: 0.05 15 | subsample_ratio: 1.0 16 | flip_rgb: True -------------------------------------------------------------------------------- /configs/model/pad.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.pad.PAD 2 | action_len: 16 3 | action_dim: ${dataset.shape_meta.action.shape[0]} 4 | obs_encoder: 5 | _target_: models.pad.PADObservationEncoder 6 | shape_meta: ${dataset.shape_meta} 7 | num_frames: ${obs_num_frames} 8 | resize_shape: [240, 320] 9 | crop_shape: [224, 224] 10 | random_crop: True 11 | color_jitter: 12 | brightness: 0.2 13 | contrast: 0.2 14 | saturation: 0.2 15 | hue: [-0.2, 0.2] 16 | imagenet_norm: False 17 | embed_dim: 768 18 | timestep_embed_dim: 512 19 | latent_patch_shape: [2, 4, 4] 20 | depth: 12 21 | num_heads: 12 22 | mlp_ratio: 4 23 | qkv_bias: True 24 | num_train_steps: 100 25 | num_inference_steps: 10 26 | beta_schedule: squaredcos_cap_v2 27 | clip_sample: True -------------------------------------------------------------------------------- /models/common/language.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPModel 4 | 5 | 6 | class CLIPTextEncoder(nn.Module): 7 | def __init__(self, embed_dim: int): 8 | super().__init__() 9 | self.language_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 10 | # Freeze pretrained model parameters 11 | for p in self.language_model.parameters(): 12 | p.requires_grad = False 13 | self.head = nn.Linear(512, embed_dim) 14 | 15 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): 16 | feats = self.language_model.get_text_features( 17 | input_ids=input_ids, attention_mask=attention_mask 18 | ) 19 | feats = self.head(feats) 20 | return feats 21 | -------------------------------------------------------------------------------- /configs/dataset/droid.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.droid.make_droid_dataset 2 | name: droid 3 | 4 | buffer_path:
/tmp/weirdlab/zchuning/data/droid/buffer.zarr 5 | shape_meta: 6 | # acceptable types: rgb, low_dim 7 | obs: 8 | # camera 9 | "exterior_image_1_left": &camera_meta 10 | shape: [180, 320, 3] 11 | type: rgb 12 | "exterior_image_2_left": *camera_meta 13 | "wrist_image_left": *camera_meta 14 | 15 | # lowdim 16 | "cartesian_position": 17 | shape: [6] 18 | type: low_dim 19 | "gripper_position": 20 | shape: [1] 21 | type: low_dim 22 | 23 | action: 24 | # 3(pos) + 6(rot) + 1(gripper) 25 | shape: [10] 26 | 27 | seq_len: ${num_frames} 28 | history_len: ${obs_num_frames} 29 | normalize_action: True 30 | normalize_lowdim: True 31 | val_ratio: 0.05 -------------------------------------------------------------------------------- /configs/template_train_robomimic.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /home/ubuntu/efs/chuning/logdir/${algo}/${dataset.name}/${exp_id}/${seed} 4 | 5 | algo: ${hydra:runtime.choices.model} 6 | exp_id: default 7 | seed: 0 8 | logdir: ${hydra:run.dir} 9 | resume: False 10 | use_amp: False 11 | eval_task_name: ${dataset.name} 12 | pretrain_checkpoint_path: null 13 | 14 | num_steps: 100000 15 | eval_every: 10000 16 | save_every: 10000 17 | rollout_every: 10000 18 | num_rollouts: 50 19 | rollout_length: 1000 20 | 21 | batch_size: 36 22 | num_frames: 19 23 | obs_num_frames: 2 24 | clip_grad_norm: null 25 | 26 | optimizer: 27 | lr: 1e-4 28 | weight_decay: 1e-6 29 | betas: [0.9, 0.999] 30 | eps: 1e-8 31 | 32 | scheduler: 33 | name: "constant" 34 | 35 | # Override resize shape 36 | model: 37 | obs_encoder: 38 | resize_shape: [240, 240] 39 | -------------------------------------------------------------------------------- /scripts/launch_droid_pretrain.sh: -------------------------------------------------------------------------------- 1 | # Cache dataset 2 | DATA_DIR="/gscratch/weirdlab/memmelma/data/" 3 | BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_weird.zarr" 4 | if [ ! -d $BUFFER_PATH ]; then 5 | # Cache dataset 6 | echo "Caching dataset..." 7 | python datasets/droid/convert_dataset_zarr.py --data_dir $DATA_DIR --buffer_path $BUFFER_PATH --num_episodes 2000 --num_workers 8 --filter_key WEIRD 8 | fi 9 | 10 | # UWM 11 | python experiments/uwm/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH 12 | 13 | # DP 14 | # python experiments/dp/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH 15 | 16 | # GR1 17 | # python experiments/gr1/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH 18 | 19 | # PAD 20 | # python experiments/pad/train.py dataset=droid exp_id=benchmark dataset.buffer_path=$BUFFER_PATH -------------------------------------------------------------------------------- /datasets/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import expanduser, expandvars 3 | from typing import Union 4 | 5 | from omegaconf.listconfig import ListConfig 6 | 7 | 8 | def glob_all(globs: Union[str, list[str], ListConfig]) -> list[str]: 9 | """ 10 | Expand a list of glob strings into a list of file paths. 11 | 12 | Args: 13 | globs: A glob string or a list of glob strings. 14 | 15 | Returns: 16 | A list of file paths matching the glob strings. 
17 | """ 18 | 19 | if isinstance(globs, str): 20 | globs = [globs] 21 | elif isinstance(globs, ListConfig): 22 | globs = list(globs) 23 | assert isinstance(globs, list) 24 | 25 | files = [] 26 | for glob_str in globs: 27 | glob_str = expandvars(expanduser(glob_str)) 28 | files.extend(glob.glob(glob_str)) 29 | 30 | return sorted(files) 31 | -------------------------------------------------------------------------------- /configs/model/gr1.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.gr1.GR1 2 | action_len: 16 3 | action_dim: ${dataset.shape_meta.action.shape[0]} 4 | obs_encoder: 5 | _target_: models.gr1.GR1ObservationEncoder 6 | shape_meta: ${dataset.shape_meta} 7 | num_frames: ${obs_num_frames} 8 | embed_dim: 768 9 | resize_shape: [240, 320] 10 | crop_shape: [224, 224] 11 | random_crop: True 12 | color_jitter: 13 | brightness: 0.2 14 | contrast: 0.2 15 | saturation: 0.2 16 | hue: [-0.2, 0.2] 17 | imagenet_norm: True 18 | resampler_params: 19 | dim: 768 20 | depth: 3 21 | dim_head: 128 22 | heads: 4 23 | num_latents: 9 24 | num_media_embeds: 1 25 | embed_dim: 768 26 | image_size: 224 27 | patch_size: 32 28 | depth: 12 29 | num_heads: 12 30 | mlp_ratio: 4 31 | qkv_bias: True 32 | decoder_depth: 3 33 | decoder_num_heads: 16 34 | decoder_mlp_ratio: 4 35 | decoder_qkv_bias: True -------------------------------------------------------------------------------- /configs/model/uwm.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.uwm.UnifiedWorldModel 2 | action_len: 16 3 | action_dim: ${dataset.shape_meta.action.shape[0]} 4 | obs_encoder: 5 | _target_: models.uwm.UWMObservationEncoder 6 | shape_meta: ${dataset.shape_meta} 7 | num_frames: ${obs_num_frames} 8 | embed_dim: 768 9 | resize_shape: [240, 320] 10 | crop_shape: [224, 224] 11 | random_crop: True 12 | color_jitter: 13 | brightness: 0.2 14 | contrast: 0.2 15 | saturation: 0.2 16 | hue: [-0.2, 0.2] 17 | imagenet_norm: False 18 | vision_backbone: resnet 19 | use_low_dim: False 20 | use_language: False 21 | embed_dim: 768 22 | timestep_embed_dim: 512 23 | latent_patch_shape: [2, 4, 4] 24 | depth: 12 25 | num_heads: 12 26 | mlp_ratio: 4 27 | qkv_bias: True 28 | num_registers: 8 29 | num_train_steps: 100 30 | num_inference_steps: 10 31 | beta_schedule: squaredcos_cap_v2 32 | clip_sample: True -------------------------------------------------------------------------------- /datasets/webvideo/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, random_split 2 | from torch.utils.data.distributed import DistributedSampler 3 | 4 | from .dataset import MultiviewVideoDataset 5 | 6 | 7 | def make_multiview_video_dataset( 8 | name: str, 9 | index_paths: list[str], 10 | shape_meta: dict, 11 | seq_len: int, 12 | frame_skip: int = 2, 13 | obs_padding: str = "same", 14 | val_ratio: float = 0.0, 15 | ): 16 | dataset = MultiviewVideoDataset( 17 | index_paths=index_paths, 18 | shape_meta=shape_meta, 19 | clip_len=seq_len, 20 | frame_skip=frame_skip, 21 | obs_padding=obs_padding, 22 | ) 23 | 24 | train_size = int(len(dataset) * (1 - val_ratio)) 25 | val_size = len(dataset) - train_size 26 | train_set, val_set = random_split(dataset, [train_size, val_size]) 27 | return train_set, val_set 28 | -------------------------------------------------------------------------------- /datasets/utils/obs_utils.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | 6 | def unflatten_obs(x: dict[str, torch.Tensor]) -> dict[str, Union[dict, torch.Tensor]]: 7 | """ 8 | Unflattens entries with keys starting with "obs" as follows: 9 | 10 | dict_data["obs.some_name"] -> dict_data["obs"]["some_name"] 11 | 12 | Args: 13 | x: A dictionary of tensors. 14 | 15 | Returns: 16 | The same dictionary, but keys starting with "obs" are unflattened. 17 | """ 18 | 19 | obs = {} 20 | to_delete = [] 21 | for key, value in x.items(): 22 | if key.startswith("obs."): 23 | new_key = key[4:] 24 | assert new_key not in obs, f"Duplicate key {new_key}" 25 | obs[new_key] = value 26 | to_delete.append(key) 27 | for key in to_delete: 28 | del x[key] 29 | x["obs"] = obs 30 | return x 31 | -------------------------------------------------------------------------------- /configs/model/dp.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.dp.ImageDiffusionPolicy 2 | action_len: 16 3 | action_dim: ${dataset.shape_meta.action.shape[0]} 4 | obs_encoder: 5 | _target_: models.dp.ImageObservationEncoder 6 | shape_meta: ${dataset.shape_meta} 7 | num_frames: ${obs_num_frames} 8 | embed_dim: 768 9 | resize_shape: [240, 320] 10 | crop_shape: [224, 224] 11 | random_crop: True 12 | color_jitter: 13 | brightness: 0.2 14 | contrast: 0.2 15 | saturation: 0.2 16 | hue: [-0.2, 0.2] 17 | imagenet_norm: True 18 | pretrained_weights: IMAGENET1K_V1 19 | use_low_dim: False 20 | use_language: False 21 | noise_pred_net: 22 | _target_: models.dp.TransformerNoisePredictionNet 23 | _partial_: True 24 | input_len: ${model.action_len} 25 | input_dim: ${model.action_dim} 26 | global_cond_dim: ??? 
27 | timestep_embed_dim: 256 28 | embed_dim: 768 29 | num_heads: 12 30 | mlp_ratio: 4 31 | qkv_bias: True 32 | num_train_steps: 100 33 | num_inference_steps: 10 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=75.1", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "unified_world_model" 7 | version = "0.1.0" 8 | description = "PyTorch implementation of Unified World Models: Coupling Video and Action Diffusion for Pretraining on Large Robotic Datasets" 9 | readme = "README.md" 10 | requires-python = ">=3.10.14" 11 | dynamic = ["dependencies"] 12 | 13 | authors = [ 14 | { name = "Chuning Zhu" }, 15 | { name = "Raymond Yu" }, 16 | { name = "Siyuan Feng" }, 17 | { name = "Benjamin Burchfiel" }, 18 | { name = "Paarth Shah" }, 19 | { name = "Abhishek Gupta" } 20 | ] 21 | 22 | classifiers = [ 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.10", 25 | "Operating System :: OS Independent" 26 | ] 27 | 28 | [tool.setuptools] 29 | packages = [ 30 | "models", 31 | "configs", 32 | "datasets", 33 | "experiments", 34 | "environments" 35 | ] 36 | 37 | [tool.setuptools.dynamic] 38 | dependencies = { file = ["requirements.txt"] } -------------------------------------------------------------------------------- /configs/dataset/libero_90_webvideo.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.utils.mixture.make_robot_video_mixture_dataset 2 | name: libero_90_webvideo 3 | shape_meta: &shape_meta 4 | obs: 5 | agentview_rgb: &camera_meta 6 | shape: [128, 128, 3] 7 | type: rgb 8 | eye_in_hand_rgb: *camera_meta 9 | action: 10 | shape: [7] # pos + quaternion + gripper 11 | hdf5_path_globs: &hdf5_path_globs /home/ubuntu/LIBERO/libero/datasets/libero_90/*.hdf5 12 | robot_train_val_sets: 13 | _target_: datasets.robomimic.make_robomimic_dataset 14 | name: libero_90 15 | hdf5_path_globs: *hdf5_path_globs 16 | buffer_path: /home/ubuntu/LIBERO/libero/datasets/libero_90/buffer.zarr 17 | shape_meta: *shape_meta 18 | seq_len: ${num_frames} 19 | val_ratio: 0.05 20 | video_train_val_sets: 21 | _target_: datasets.webvideo.make_multiview_video_dataset 22 | name: webvideo 23 | index_paths: [ 24 | /home/ubuntu/efs/chuning/k400_filtered.txt, 25 | /home/ubuntu/efs/chuning/ssv2_filtered.txt, 26 | ] 27 | shape_meta: *shape_meta 28 | seq_len: ${num_frames} 29 | obs_padding: "same" 30 | frame_skip: 2 31 | val_ratio: 0.05 32 | balance_datasets: False -------------------------------------------------------------------------------- /configs/dataset/droid_webvideo.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.utils.mixture.make_robot_video_mixture_dataset 2 | name: droid_webvideo 3 | shape_meta: &shape_meta 4 | obs: 5 | "exterior_image_1_left": &camera_meta 6 | shape: [180, 320, 3] 7 | type: rgb 8 | "exterior_image_2_left": *camera_meta 9 | "wrist_image_left": *camera_meta 10 | "cartesian_position": 11 | shape: [6] 12 | type: low_dim 13 | "gripper_position": 14 | shape: [1] 15 | type: low_dim 16 | action: 17 | shape: [10] 18 | robot_train_val_sets: 19 | _target_: datasets.droid.make_droid_dataset 20 | name: droid 21 | buffer_path: /tmp/weirdlab/zchuning/data/droid/buffer.zarr 22 | shape_meta: *shape_meta 23 | seq_len: ${num_frames} 24 | history_len: ${obs_num_frames} 25 |
normalize_action: True 26 | normalize_lowdim: True 27 | val_ratio: 0.05 28 | video_train_val_sets: 29 | _target_: datasets.webvideo.make_multiview_video_dataset 30 | name: webvideo 31 | index_paths: [ 32 | /gscratch/weirdlab/zchuning/data/k400_filtered.txt, 33 | /gscratch/weirdlab/zchuning/data/ssv2_filtered.txt, 34 | ] 35 | shape_meta: *shape_meta 36 | seq_len: ${num_frames} 37 | obs_padding: "same" 38 | frame_skip: 2 39 | val_ratio: 0.05 40 | balance_datasets: False -------------------------------------------------------------------------------- /datasets/droid/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from .dataset import DroidDataset, DroidMixtureDataset 7 | 8 | 9 | def make_droid_dataset( 10 | name: str, 11 | buffer_path: str, 12 | shape_meta: dict, 13 | seq_len: int, 14 | history_len: int = 1, 15 | normalize_action: bool = False, 16 | normalize_lowdim: bool = False, 17 | val_ratio: float = 0.0, 18 | video_buffer_path: Optional[str] = None, 19 | balance_datasets: bool = False, 20 | ): 21 | # Training dataset 22 | train_set = DroidDataset( 23 | name=name, 24 | buffer_path=buffer_path, 25 | shape_meta=shape_meta, 26 | seq_len=seq_len, 27 | history_len=history_len, 28 | normalize_lowdim=normalize_lowdim, 29 | normalize_action=normalize_action, 30 | val_ratio=val_ratio, 31 | ) 32 | if video_buffer_path is not None: 33 | train_set = DroidMixtureDataset( 34 | base_dataset=train_set, 35 | video_buffer_path=video_buffer_path, 36 | balance_datasets=balance_datasets, 37 | ) 38 | 39 | # Validation dataset 40 | val_set = train_set.get_validation_dataset() 41 | return train_set, val_set 42 | -------------------------------------------------------------------------------- /datasets/robomimic/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | 4 | import torch.distributed as dist 5 | from .dataset import RobomimicDataset 6 | 7 | 8 | def make_robomimic_dataset( 9 | name: str, 10 | hdf5_path_globs: Union[str, list[str]], 11 | buffer_path: str, 12 | shape_meta: dict, 13 | seq_len: int, 14 | val_ratio: float = 0.0, 15 | subsample_ratio: float = 1.0, 16 | flip_rgb: bool = False, 17 | ): 18 | # Cache compressed dataset in the main process 19 | if not os.path.exists(buffer_path): 20 | if not dist.is_initialized() or dist.get_rank() == 0: 21 | RobomimicDataset( 22 | name=name, 23 | hdf5_path_globs=hdf5_path_globs, 24 | buffer_path=buffer_path, 25 | shape_meta=shape_meta, 26 | seq_len=seq_len, 27 | flip_rgb=flip_rgb, 28 | ) 29 | if dist.is_initialized(): 30 | dist.barrier() 31 | 32 | # Training dataset 33 | train_set = RobomimicDataset( 34 | name=name, 35 | hdf5_path_globs=hdf5_path_globs, 36 | buffer_path=buffer_path, 37 | shape_meta=shape_meta, 38 | seq_len=seq_len, 39 | val_ratio=val_ratio, 40 | subsample_ratio=subsample_ratio, 41 | flip_rgb=flip_rgb, 42 | ) 43 | val_set = train_set.get_validation_dataset() 44 | return train_set, val_set 45 | -------------------------------------------------------------------------------- /scripts/launch_droid_cotrain.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR="/gscratch/weirdlab/memmelma/data/" 2 | 3 | # Cache robot dataset 4 | ROBOT_BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_weird.zarr" 5 | if [ !
-d $ROBOT_BUFFER_PATH ]; then 6 | # Cache dataset 7 | echo "Caching robot dataset..." 8 | python datasets/droid/convert_dataset_zarr.py --data_dir $DATA_DIR --buffer_path $ROBOT_BUFFER_PATH --num_episodes 2000 --num_workers 8 --filter_key WEIRD 9 | fi 10 | 11 | # Cache video dataset 12 | VIDEO_BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_video.zarr" 13 | if [ ! -d $VIDEO_BUFFER_PATH ]; then 14 | # Cache dataset 15 | echo "Caching video dataset..." 16 | python datasets/droid/convert_dataset_zarr.py --data_dir $DATA_DIR --buffer_path $VIDEO_BUFFER_PATH --num_episodes 2000 --num_workers 8 --except_key WEIRD 17 | fi 18 | 19 | # UWM 20 | python experiments/uwm/train.py dataset=droid_mixture exp_id=benchmark_cotrain \ 21 | dataset.buffer_path=$ROBOT_BUFFER_PATH \ 22 | dataset.video_buffer_path=$VIDEO_BUFFER_PATH 23 | 24 | # GR1 25 | # python experiments/gr1/train.py dataset=droid_mixture exp_id=benchmark_cotrain \ 26 | # dataset.buffer_path=$ROBOT_BUFFER_PATH \ 27 | # dataset.video_buffer_path=$VIDEO_BUFFER_PATH 28 | 29 | # PAD 30 | # python experiments/pad/train.py dataset=droid_mixture exp_id=benchmark_cotrain \ 31 | # dataset.buffer_path=$ROBOT_BUFFER_PATH \ 32 | # dataset.video_buffer_path=$VIDEO_BUFFER_PATH 33 | -------------------------------------------------------------------------------- /datasets/utils/loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from torch.utils.data.distributed import DistributedSampler 3 | 4 | 5 | def make_distributed_data_loader( 6 | train_set: Dataset, 7 | val_set: Dataset, 8 | batch_size: int, 9 | rank: int = 0, 10 | world_size: int = 1, 11 | num_workers: int = 8, 12 | pin_memory: bool = True, 13 | drop_last: bool = True, 14 | persistent_workers: bool = True, 15 | ): 16 | # Training sampler and loader 17 | train_sampler = DistributedSampler( 18 | train_set, 19 | num_replicas=world_size, 20 | rank=rank, 21 | shuffle=True, 22 | ) 23 | train_loader = DataLoader( 24 | train_set, 25 | batch_size=batch_size, 26 | sampler=train_sampler, 27 | num_workers=num_workers, 28 | pin_memory=pin_memory, 29 | drop_last=drop_last, 30 | persistent_workers=persistent_workers, 31 | ) 32 | 33 | # Validation sampler and loader 34 | val_sampler = DistributedSampler( 35 | val_set, 36 | num_replicas=world_size, 37 | rank=rank, 38 | shuffle=False, 39 | ) 40 | val_loader = DataLoader( 41 | val_set, 42 | batch_size=batch_size, 43 | sampler=val_sampler, 44 | num_workers=num_workers, 45 | pin_memory=pin_memory, 46 | drop_last=drop_last, 47 | persistent_workers=persistent_workers, 48 | ) 49 | 50 | return train_loader, val_loader 51 | -------------------------------------------------------------------------------- /scripts/setup_aws.sh: -------------------------------------------------------------------------------- 1 | # Install conda 2 | mkdir -p ~/miniconda3 3 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 4 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 5 | rm ~/miniconda3/miniconda.sh 6 | 7 | # Refresh terminal 8 | source ~/miniconda3/bin/activate 9 | 10 | # Modify bashrc 11 | cat << 'EOF' >> ~/.bashrc 12 | 13 | # >>> conda initialize >>> 14 | # !! Contents within this block are managed by 'conda init' !! 15 | __conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)" 16 | if [ $? 
-eq 0 ]; then 17 | eval "$__conda_setup" 18 | else 19 | if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then 20 | . "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" 21 | else 22 | export PATH="/home/ubuntu/miniconda3/bin:$PATH" 23 | fi 24 | fi 25 | unset __conda_setup 26 | # <<< conda initialize <<< 27 | EOF 28 | 29 | # Create conda environment 30 | conda create --yes -n val python==3.10.14 31 | conda activate val 32 | 33 | # Install requirements 34 | pip install -r requirements.txt 35 | 36 | # Clone robomimic fork 37 | cd .. 38 | git clone git@github.com:zchuning/robomimic.git 39 | cd robomimic 40 | pip install -e . 41 | 42 | # Download robomimic dataset 43 | python robomimic/scripts/setup_macros.py 44 | python robomimic/scripts/download_datasets.py --tasks sim --dataset_types ph --hdf5_types raw 45 | cd robomimic/scripts 46 | source extract_obs_from_raw_datasets.sh 47 | 48 | # Clone LIBERO fork 49 | cd .. 50 | git clone git@github.com:zchuning/LIBERO.git 51 | cd LIBERO 52 | pip install -e . 53 | pip install -r requirements.txt 54 | 55 | # Download LIBERO data 56 | python benchmark_scripts/download_libero_datasets.py --datasets libero_100 -------------------------------------------------------------------------------- /datasets/utils/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .buffer import CompressedTrajectoryBuffer 4 | 5 | 6 | class TrajectorySampler: 7 | """ 8 | A class that samples sequences of observations and actions from a trajectory buffer. 9 | """ 10 | 11 | def __init__( 12 | self, 13 | buffer: CompressedTrajectoryBuffer, 14 | seq_len: int, 15 | episode_mask: np.ndarray = None, 16 | ): 17 | """ 18 | Initialize the trajectory sampler. 19 | 20 | Args: 21 | buffer: The trajectory buffer containing the data. 22 | seq_len: The length of the sequences to sample. 23 | episode_mask: A binary mask indicating valid episodes. If None, all episodes are valid. 
24 | """ 25 | self.buffer = buffer 26 | self.seq_len = seq_len 27 | self.keys = list(self.buffer.keys()) 28 | 29 | # Compute all possible sample indices 30 | indices = [] 31 | episode_start = 0 32 | for i, episode_end in enumerate(self.buffer.episode_ends): 33 | if episode_mask is None or episode_mask[i]: 34 | for j in range(episode_start, episode_end + 1 - seq_len): 35 | indices.append([j, j + seq_len]) 36 | episode_start = episode_end 37 | self.indices = np.array(indices, dtype=np.int64) 38 | print(f"Total number of valid sequences: {len(self.indices)}") 39 | 40 | def __len__(self) -> int: 41 | return len(self.indices) 42 | 43 | def sample_sequence(self, index: int) -> dict[str, np.ndarray]: 44 | start, end = self.indices[index] 45 | data = {} 46 | for key in self.keys: 47 | arr = self.buffer[key] 48 | value = arr[start:end] 49 | data[key] = value 50 | return data 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 2 | antlr4-python3-runtime==4.9.3 3 | appdirs==1.4.4 4 | asciitree==0.3.3 5 | certifi==2024.6.2 6 | charset-normalizer==3.3.2 7 | click==8.1.7 8 | cloudpickle==3.0.0 9 | dask==2024.6.0 10 | decord==0.6.0 11 | diffusers==0.30.0 12 | dm-tree==0.1.8 13 | docker-pycreds==0.4.0 14 | einops==0.7.0 15 | fasteners==0.19 16 | filelock==3.15.3 17 | fsspec==2024.6.0 18 | ftfy==6.2.3 19 | gitdb==4.0.11 20 | GitPython==3.1.43 21 | huggingface-hub==0.23.4 22 | hydra-core==1.3.2 23 | idna==3.7 24 | imageio==2.21.2 25 | imageio_ffmpeg==0.5.1 26 | importlib_metadata==7.2.0 27 | Jinja2==3.1.4 28 | locket==1.0.0 29 | markdown-it-py==3.0.0 30 | MarkupSafe==2.1.5 31 | mdurl==0.1.2 32 | moviepy==2.1.1 33 | mpmath==1.3.0 34 | networkx==3.3 35 | numcodecs==0.12.1 36 | numpy==1.26.4 37 | nvidia-cublas-cu12==12.1.3.1 38 | nvidia-cuda-cupti-cu12==12.1.105 39 | nvidia-cuda-nvrtc-cu12==12.1.105 40 | nvidia-cuda-runtime-cu12==12.1.105 41 | nvidia-cudnn-cu12==8.9.2.26 42 | nvidia-cufft-cu12==11.0.2.54 43 | nvidia-curand-cu12==10.3.2.106 44 | nvidia-cusolver-cu12==11.4.5.107 45 | nvidia-cusparse-cu12==12.1.0.106 46 | nvidia-nccl-cu12==2.19.3 47 | nvidia-nvjitlink-cu12==12.5.40 48 | nvidia-nvtx-cu12==12.1.105 49 | omegaconf==2.3.0 50 | packaging==24.1 51 | pandas==2.2.1 52 | partd==1.4.2 53 | pillow==10.3.0 54 | promise==2.3 55 | protobuf==3.20.3 56 | psutil==6.0.0 57 | Pygments==2.18.0 58 | python-dateutil==2.9.0.post0 59 | pytz==2024.1 60 | PyYAML==6.0.1 61 | regex==2024.5.15 62 | requests==2.32.3 63 | rich==13.7.1 64 | robosuite==1.4.1 65 | safetensors==0.4.4 66 | scipy==1.14.0 67 | sentry-sdk==2.6.0 68 | setproctitle==1.3.3 69 | six==1.16.0 70 | smmap==5.0.1 71 | sympy==1.12.1 72 | tensorflow==2.15.0 73 | tensorflow-datasets==4.9.7 74 | tensorflow-metadata==1.16.1 75 | timm==1.0.9 76 | tokenizers==0.19.1 77 | toml==0.10.2 78 | toolz==0.12.1 79 | torch==2.2.2 80 | torchaudio==2.2.2 81 | torchvision==0.17.2 82 | tqdm==4.66.2 83 | transformers==4.44.0 84 | triton==2.2.0 85 | typing_extensions==4.12.2 86 | tzdata==2024.1 87 | urllib3==2.2.2 88 | wandb==0.19.1 89 | wcwidth==0.2.13 90 | zarr==2.18.2 91 | zipp==3.19.2 92 | -------------------------------------------------------------------------------- /datasets/utils/normalizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import numpy as np 3 | 4 | 5 | class LinearNormalizer: 6 | """ 7 | A class to linearly normalizes data to the range 
[-1, 1]. 8 | """ 9 | 10 | def __init__(self, scale: np.ndarray, offset: np.ndarray): 11 | """ 12 | Initializes a new instance of the LinearNormalizer class with given statistics. 13 | 14 | Args: 15 | scale: The scale factor for normalization. 16 | offset: The offset for normalization. 17 | """ 18 | self.scale = scale 19 | self.offset = offset 20 | 21 | def __call__(self, x: np.ndarray) -> np.ndarray: 22 | """ 23 | Normalizes the input data using the stored statistics. 24 | """ 25 | return (x - self.offset) / self.scale 26 | 27 | def reconstruct(self, x: np.ndarray) -> np.ndarray: 28 | """ 29 | Reconstructs the original data from the normalized data. 30 | """ 31 | return x * self.scale + self.offset 32 | 33 | 34 | class NestedDictLinearNormalizer(dict): 35 | """ 36 | A class that applies linear normalization to values in a nested dictionary structure. 37 | """ 38 | 39 | def __init__(self, stats: dict[str, Union[tuple[np.ndarray, np.ndarray], dict]]): 40 | """ 41 | Initializes a new instance of the NestedDictLinearNormalizer class with given statistics. 42 | 43 | Args: 44 | stats: A dictionary containing statistics for each key. The values can either 45 | be tuples representing scale and offset values or dictionaries that require 46 | recursive scaling. 47 | """ 48 | super().__init__() 49 | for k, v in stats.items(): 50 | if isinstance(v, dict): 51 | self[k] = NestedDictLinearNormalizer(v) 52 | else: 53 | self[k] = LinearNormalizer(np.array(v[0]), np.array(v[1])) 54 | 55 | def __call__(self, x: dict) -> dict: 56 | """ 57 | Normalizes all values in the input dictionary based on the stored normalizers. 58 | """ 59 | return {k: self[k](v) if k in self.keys() else v for k, v in x.items()} 60 | 61 | def reconstruct(self, x: dict) -> dict: 62 | """ 63 | Reconstructs the original values from normalized values in the input dictionary. 
64 | """ 65 | return { 66 | k: self[k].reconstruct(v) if k in self.keys() else v for k, v in x.items() 67 | } 68 | -------------------------------------------------------------------------------- /models/dp/image_policy.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from .obs_encoder import ImageObservationEncoder 4 | from .base_policy import NoisePredictionNet, DiffusionPolicy, FlowPolicy 5 | 6 | 7 | class ImageDiffusionPolicy(DiffusionPolicy): 8 | def __init__( 9 | self, 10 | action_len: int, 11 | action_dim: int, 12 | obs_encoder: ImageObservationEncoder, 13 | noise_pred_net: partial[NoisePredictionNet], 14 | num_train_steps: int = 100, 15 | num_inference_steps: int = 10, 16 | num_train_noise_samples: int = 1, 17 | beta_schedule: str = "squaredcos_cap_v2", 18 | clip_sample: bool = True, 19 | ): 20 | """ 21 | Assumes rgb input: (B, T, H, W, C) uint8 image 22 | Assumes low_dim input: (B, T, D) 23 | """ 24 | super().__init__( 25 | action_len=action_len, 26 | action_dim=action_dim, 27 | noise_pred_net=noise_pred_net(global_cond_dim=obs_encoder.output_len), 28 | num_train_steps=num_train_steps, 29 | num_inference_steps=num_inference_steps, 30 | num_train_noise_samples=num_train_noise_samples, 31 | beta_schedule=beta_schedule, 32 | clip_sample=clip_sample, 33 | ) 34 | 35 | # Observation encoder 36 | self.obs_encoder = obs_encoder 37 | 38 | def sample(self, obs_dict): 39 | obs = self.obs_encoder(obs_dict) 40 | action = super().sample(obs) 41 | return action 42 | 43 | def forward(self, obs_dict, action): 44 | obs = self.obs_encoder(obs_dict) 45 | loss = super().forward(obs, action) 46 | return loss 47 | 48 | 49 | class ImageFlowPolicy(FlowPolicy): 50 | def __init__( 51 | self, 52 | action_len: int, 53 | action_dim: int, 54 | obs_encoder: ImageObservationEncoder, 55 | noise_pred_net: partial[NoisePredictionNet], 56 | num_train_steps: int = 100, 57 | num_inference_steps: int = 10, 58 | timeshift: float = 1.0, 59 | ): 60 | """ 61 | Assumes rgb input: (B, T, H, W, C) uint8 image 62 | Assumes low_dim input: (B, T, D) 63 | """ 64 | super().__init__( 65 | action_len=action_len, 66 | action_dim=action_dim, 67 | noise_pred_net=noise_pred_net(global_cond_dim=obs_encoder.output_len), 68 | num_train_steps=num_train_steps, 69 | num_inference_steps=num_inference_steps, 70 | timeshift=timeshift, 71 | ) 72 | 73 | # Observation encoder 74 | self.obs_encoder = obs_encoder 75 | 76 | def sample(self, obs_dict): 77 | obs = self.obs_encoder(obs_dict) 78 | action = super().sample(obs) 79 | return action 80 | 81 | def forward(self, obs_dict, action): 82 | obs = self.obs_encoder(obs_dict) 83 | loss = super().forward(obs, action) 84 | return loss 85 | -------------------------------------------------------------------------------- /scripts/launch_droid_finetune.sh: -------------------------------------------------------------------------------- 1 | # Cache dataset 2 | DATA_NAME=stack_red_bowl_on_blue_bowl 3 | DATA_DIR="/gscratch/weirdlab/zchuning/data/" 4 | BUFFER_PATH="/tmp/weirdlab/zchuning/data/droid/buffer_$DATA_NAME.zarr" 5 | if [ ! -d $BUFFER_PATH ]; then 6 | # Cache dataset 7 | echo "Caching dataset..." 
8 | python datasets/droid/convert_dataset_zarr.py --data_name $DATA_NAME --data_dir $DATA_DIR --buffer_path $BUFFER_PATH --num_episodes 2000 --num_workers 8 9 | fi 10 | 11 | # UWM 12 | PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/uwm/droid/benchmark/0/models.pt" 13 | python experiments/uwm/train.py --config-name finetune_uwm.yaml dataset=droid exp_id=finetune_benchmark \ 14 | dataset.name=droid_$DATA_NAME \ 15 | dataset.buffer_path=$BUFFER_PATH \ 16 | pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH 17 | 18 | # UWM cotrained 19 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/uwm/droid_mixture/benchmark_cotrain/0/models.pt" 20 | # python experiments/uwm/finetune.py dataset=droid exp_id=finetune_benchmark_cotrain \ 21 | # dataset.name=droid_$DATA_NAME \ 22 | # dataset.buffer_path=$BUFFER_PATH \ 23 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH 24 | 25 | # DP 26 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/dp/droid/benchmark/0/models.pt" 27 | # python experiments/dp/finetune.py dataset=droid exp_id=finetune_benchmark \ 28 | # dataset.name=droid_$DATA_NAME \ 29 | # dataset.buffer_path=$BUFFER_PATH \ 30 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH 31 | 32 | # GR1 33 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/gr1/droid/benchmark/0/models.pt" 34 | # python experiments/gr1/finetune.py dataset=droid exp_id=finetune_benchmark \ 35 | # dataset.name=droid_$DATA_NAME \ 36 | # dataset.buffer_path=$BUFFER_PATH \ 37 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH 38 | 39 | # GR1 cotrained 40 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/gr1/droid_mixture/benchmark_cotrain/0/models.pt" 41 | # python experiments/gr1/finetune.py dataset=droid exp_id=finetune_benchmark_cotrain \ 42 | # dataset.name=droid_$DATA_NAME \ 43 | # dataset.buffer_path=$BUFFER_PATH \ 44 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH 45 | 46 | # PAD 47 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/pad/droid/benchmark/0/models.pt" 48 | # python experiments/pad/finetune.py dataset=droid exp_id=finetune_benchmark \ 49 | # dataset.name=droid_$DATA_NAME \ 50 | # dataset.buffer_path=$BUFFER_PATH \ 51 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH 52 | 53 | 54 | # PAD cotrained 55 | # PRETRAIN_CHECKPOINT_PATH="/gscratch/weirdlab/zchuning/video-action-learning/logdir/pad/droid_mixture/benchmark_cotrain/0/models.pt" 56 | # python experiments/pad/finetune.py dataset=droid exp_id=finetune_benchmark_cotrain \ 57 | # dataset.name=droid_$DATA_NAME \ 58 | # dataset.buffer_path=$BUFFER_PATH \ 59 | # pretrain_checkpoint_path=$PRETRAIN_CHECKPOINT_PATH -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from datetime import timedelta 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import wandb 10 | from omegaconf import OmegaConf 11 | 12 | DATA_TYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 13 | 14 | 15 | def set_seed(seed): 16 | """Set random seed for reproducibility.""" 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | 21 | 22 | def init_wandb(config, job_type): 23 | """Initialize WANDB 
logging. Only use in main process. 24 | 25 | Args: 26 | config: config dictionary 27 | job_type: "train" or "eval" 28 | """ 29 | run_id_path = os.path.join(config.logdir, f"run_id_{job_type}.json") 30 | if config.resume and os.path.exists(run_id_path): 31 | # Load WANDB run ID from log directory 32 | with open(run_id_path, "r") as f: 33 | run_id = json.load(f)["run_id"] 34 | else: 35 | # Generate new WANDB run ID 36 | run_id = wandb.util.generate_id() 37 | with open(run_id_path, "w") as f: 38 | json.dump({"run_id": run_id}, f) 39 | 40 | wandb.init( 41 | project="video-action-learning", 42 | job_type=job_type, 43 | group=config.algo, 44 | name="_".join([config.exp_id, str(config.seed)]), 45 | config=OmegaConf.to_container(config, resolve=True), 46 | resume=config.resume, 47 | id=run_id, 48 | ) 49 | 50 | 51 | def init_distributed(rank, world_size): 52 | """Initialize distributed training and set visible device. 53 | 54 | Args: 55 | rank: unique identifier of each process 56 | world_size: total number of processes 57 | """ 58 | os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") 59 | os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "25678") 60 | dist.init_process_group( 61 | backend="nccl", 62 | rank=rank, 63 | world_size=world_size, 64 | timeout=timedelta(seconds=3600), 65 | ) 66 | 67 | 68 | def get_rank(): 69 | if not dist.is_initialized(): 70 | return 0 71 | return dist.get_rank() 72 | 73 | 74 | def get_world_size(): 75 | if not dist.is_initialized(): 76 | return 1 77 | return dist.get_world_size() 78 | 79 | 80 | def is_main_process(): 81 | return get_rank() == 0 82 | 83 | 84 | def soft_update(target, source, tau): 85 | """Soft update target model with source model.""" 86 | for target_param, param in zip(target.parameters(), source.parameters()): 87 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 88 | 89 | 90 | class FreezeParameters: 91 | def __init__(self, params): 92 | self.params = params 93 | self.param_states = [p.requires_grad for p in self.params] 94 | 95 | def __enter__(self): 96 | for param in self.params: 97 | param.requires_grad = False 98 | 99 | def __exit__(self, exc_type, exc_val, exc_tb): 100 | for i, param in enumerate(self.params): 101 | param.requires_grad = self.param_states[i] 102 | -------------------------------------------------------------------------------- /environments/robomimic/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .wrappers import RoboMimicEnvWrapper, LIBEROEnvWrapper 4 | 5 | 6 | def make_robomimic_env( 7 | dataset_name, 8 | dataset_path, 9 | shape_meta, 10 | obs_horizon, 11 | max_episode_length, 12 | record=False, 13 | ): 14 | if "robomimic" in dataset_name: 15 | from robomimic.utils.env_utils import get_env_type, get_env_class 16 | from robomimic.utils.file_utils import ( 17 | get_env_metadata_from_dataset, 18 | get_shape_metadata_from_dataset, 19 | ) 20 | from robomimic.utils.obs_utils import initialize_obs_utils_with_obs_specs 21 | 22 | # Initialize observation modalities 23 | rgb_keys = [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"] 24 | low_dim_keys = [ 25 | k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim" 26 | ] 27 | all_obs_keys = rgb_keys + low_dim_keys 28 | initialize_obs_utils_with_obs_specs( 29 | {"obs": {"rgb": rgb_keys, "low_dim": low_dim_keys}} 30 | ) 31 | 32 | # Create environment 33 | env_meta = get_env_metadata_from_dataset(dataset_path=dataset_path) 34 | env_type = 
get_env_type(env_meta=env_meta) 35 | env_class = get_env_class(env_type=env_type) 36 | shape_meta = get_shape_metadata_from_dataset( 37 | dataset_path=dataset_path, 38 | all_obs_keys=all_obs_keys, 39 | verbose=True, 40 | ) 41 | # Set render device if CUDA_VISIBLE_DEVICES is set 42 | if os.environ.get("CUDA_VISIBLE_DEVICES", None): 43 | env_meta["env_kwargs"]["render_gpu_device_id"] = int( 44 | os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0] 45 | ) 46 | env = env_class( 47 | env_name=env_meta["env_name"], 48 | render=False, 49 | render_offscreen=True, 50 | use_image_obs=shape_meta["use_images"], 51 | use_depth_obs=shape_meta["use_depths"], 52 | postprocess_visual_obs=False, # use raw images 53 | **env_meta["env_kwargs"], 54 | ) 55 | env = RoboMimicEnvWrapper( 56 | env, all_obs_keys, obs_horizon, max_episode_length, record=record 57 | ) 58 | elif "libero" in dataset_name: 59 | from libero.libero.envs import OffScreenRenderEnv 60 | from libero.libero import get_libero_path 61 | 62 | # Construct environment kwargs 63 | bddl_file_name = os.path.join( 64 | get_libero_path("bddl_files"), 65 | "libero_10", 66 | dataset_path.split("/")[-1].replace("_demo.hdf5", ".bddl"), 67 | ) 68 | env_kwargs = { 69 | "bddl_file_name": bddl_file_name, 70 | "camera_heights": 128, 71 | "camera_widths": 128, 72 | } 73 | 74 | # Set render device if CUDA_VISIBLE_DEVICES is set 75 | if os.environ.get("CUDA_VISIBLE_DEVICES", None): 76 | env_kwargs["render_gpu_device_id"] = int( 77 | os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0] 78 | ) 79 | 80 | # Create environment 81 | env = OffScreenRenderEnv(**env_kwargs) 82 | obs_keys = list(shape_meta["obs"].keys()) 83 | env = LIBEROEnvWrapper( 84 | env, obs_keys, obs_horizon, max_episode_length, record=record 85 | ) 86 | else: 87 | raise NotImplementedError(f"Unsupported environment: {dataset_name}") 88 | return env 89 | -------------------------------------------------------------------------------- /experiments/uwm/eval_robomimic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import imageio 5 | import numpy as np 6 | import torch 7 | from diffusers.optimization import get_scheduler 8 | from omegaconf import OmegaConf 9 | from hydra.utils import instantiate 10 | from tqdm import trange 11 | 12 | from environments.robomimic import make_robomimic_env 13 | from experiments.uwm.train import maybe_resume_checkpoint 14 | from experiments.utils import set_seed, is_main_process 15 | 16 | 17 | def collect_rollout(config, model, device): 18 | model.eval() 19 | model = getattr(model, "module", model) # unwrap DDP 20 | 21 | # Create eval environment 22 | assert isinstance(config.dataset.hdf5_path_globs, str) 23 | env = make_robomimic_env( 24 | dataset_name=config.dataset.name, 25 | dataset_path=config.dataset.hdf5_path_globs, 26 | shape_meta=config.dataset.shape_meta, 27 | obs_horizon=model.obs_encoder.num_frames, 28 | max_episode_length=config.rollout_length, 29 | record=True, 30 | ) 31 | 32 | # Collect rollouts 33 | video_dir = os.path.join(config.logdir, "videos") 34 | if not os.path.exists(video_dir): 35 | os.mkdir(video_dir) 36 | successes = [] 37 | for e in trange( 38 | config.num_rollouts, desc="Collecting rollouts", disable=not is_main_process() 39 | ): 40 | env.seed(e) 41 | obs = env.reset() 42 | done = False 43 | while not done: 44 | obs_tensor = { 45 | k: torch.tensor(v, device=device)[None] for k, v in obs.items() 46 | } 47 | 48 | # Sample action from model 49 | action = 
model.sample(obs_tensor)[0].cpu().numpy() 50 | 51 | # Step environment 52 | obs, reward, done, info = env.step(action) 53 | successes.append(info["success"]) 54 | video = env.get_video() 55 | imageio.mimwrite(os.path.join(video_dir, f"{e}.mp4"), video, fps=30) 56 | print( 57 | f"Episode {e} success: {info['success']}, cumulative: {np.mean(successes):.2f}" 58 | ) 59 | 60 | # Compute success rate 61 | success_rate = sum(successes) / len(successes) 62 | return success_rate 63 | 64 | 65 | def maybe_collect_rollout(config, step, model, device): 66 | """Collect rollouts on the main process if it's the correct step.""" 67 | # Skip rollout collection for pretraining 68 | if "libero_90" in config.dataset.name: 69 | return 70 | 71 | if is_main_process() and ( 72 | step % config.rollout_every == 0 or step == (config.num_steps - 1) 73 | ): 74 | success_rate = collect_rollout(config, model, device) 75 | print(f"Step: {step} success rate: {success_rate}") 76 | 77 | 78 | @hydra.main( 79 | version_base=None, 80 | config_path="../../configs", 81 | config_name="train_uwm_robomimic.yaml", 82 | ) 83 | def main(config): 84 | # Resolve hydra config 85 | OmegaConf.resolve(config) 86 | set_seed(0) 87 | device = torch.device(f"cuda:0") 88 | 89 | # Create model 90 | model = instantiate(config.model).to(device) 91 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 92 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 93 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 94 | 95 | # Resume from checkpoint 96 | config.resume = True 97 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 98 | maybe_collect_rollout(config, 0, model, device) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /models/gr1/flamingo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | 4 | from einops import rearrange, repeat 5 | 6 | 7 | def FeedForward(dim, mult=4): 8 | inner_dim = int(dim * mult) 9 | return nn.Sequential( 10 | nn.LayerNorm(dim), 11 | nn.Linear(dim, inner_dim, bias=False), 12 | nn.GELU(), 13 | nn.Linear(inner_dim, dim, bias=False), 14 | ) 15 | 16 | 17 | class PerceiverAttention(nn.Module): 18 | def __init__(self, *, dim, dim_head=64, heads=8): 19 | super().__init__() 20 | self.scale = dim_head**-0.5 21 | self.heads = heads 22 | inner_dim = dim_head * heads 23 | 24 | self.norm_media = nn.LayerNorm(dim) 25 | self.norm_latents = nn.LayerNorm(dim) 26 | 27 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 28 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 29 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 30 | 31 | def forward(self, x, latents): 32 | """ 33 | einstein notation 34 | b - batch 35 | t - time 36 | n - sequence 37 | d - dimension 38 | """ 39 | x = self.norm_media(x) 40 | latents = self.norm_latents(latents) 41 | 42 | b, m, h = *x.shape[:2], self.heads 43 | 44 | q = self.to_q(latents) 45 | 46 | # the paper differs from Perceiver in that they also concat the key / values derived from the latents to be attended to 47 | kv_input = torch.cat((x, latents), dim=-2) 48 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 49 | 50 | q = rearrange(q, "b t n (h d) -> b h t n d", h=h) 51 | k = rearrange(k, "b t n (h d) -> b h t n d", h=h) 52 | v = rearrange(v, "b t n (h d) -> b h t n d", h=h) 53 | 54 | # attention 55 | q = q * self.scale 56 | sim = einsum("... i d, ... 
j d -> ... i j", q, k) 57 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 58 | attn = sim.softmax(dim=-1) 59 | 60 | out = einsum("... i j, ... j d -> ... i d", attn, v) 61 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 62 | return self.to_out(out) 63 | 64 | 65 | class PerceiverResampler(nn.Module): 66 | def __init__( 67 | self, 68 | *, 69 | dim, 70 | depth, 71 | dim_head=64, 72 | heads=8, 73 | num_latents=64, 74 | num_media_embeds=4, 75 | ff_mult=4 76 | ): 77 | super().__init__() 78 | self.num_latents = num_latents 79 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 80 | self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim)) 81 | 82 | self.layers = nn.ModuleList([]) 83 | for _ in range(depth): 84 | self.layers.append( 85 | nn.ModuleList( 86 | [ 87 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 88 | FeedForward(dim=dim, mult=ff_mult), 89 | ] 90 | ) 91 | ) 92 | 93 | self.norm = nn.LayerNorm(dim) 94 | 95 | def forward(self, x): 96 | if x.ndim == 3: 97 | x = rearrange(x, "b n d -> b 1 n d") 98 | 99 | times = x.shape[1] 100 | x = x + self.media_pos_emb[:times] 101 | 102 | latents = repeat(self.latents, "n d -> b m n d", b=x.shape[0], m=x.shape[1]) 103 | for attn, ff in self.layers: 104 | latents = attn(x, latents) + latents 105 | latents = ff(latents) + latents 106 | return self.norm(latents) 107 | -------------------------------------------------------------------------------- /models/dp/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.common.adaln_attention import AdaLNAttentionBlock, AdaLNFinalLayer 5 | from models.common.utils import SinusoidalPosEmb, init_weights 6 | from .base_policy import NoisePredictionNet 7 | 8 | 9 | class TransformerNoisePredictionNet(NoisePredictionNet): 10 | def __init__( 11 | self, 12 | input_len: int, 13 | input_dim: int, 14 | global_cond_dim: int, 15 | timestep_embed_dim: int = 256, 16 | embed_dim: int = 768, 17 | depth: int = 12, 18 | num_heads: int = 12, 19 | mlp_ratio: float = 4.0, 20 | qkv_bias: bool = True, 21 | ): 22 | super().__init__() 23 | self.input_len = input_len 24 | 25 | # Input encoder and decoder 26 | hidden_dim = int(max(input_dim, embed_dim) * mlp_ratio) 27 | self.input_encoder = nn.Sequential( 28 | nn.Linear(input_dim, hidden_dim), 29 | nn.Mish(), 30 | nn.Linear(hidden_dim, embed_dim), 31 | ) 32 | self.output_decoder = nn.Sequential( 33 | nn.Linear(embed_dim, hidden_dim), 34 | nn.Mish(), 35 | nn.Linear(hidden_dim, input_dim), 36 | ) 37 | 38 | # Timestep encoder 39 | self.timestep_encoder = nn.Sequential( 40 | SinusoidalPosEmb(timestep_embed_dim), 41 | nn.Linear(timestep_embed_dim, timestep_embed_dim * 4), 42 | nn.Mish(), 43 | nn.Linear(timestep_embed_dim * 4, timestep_embed_dim), 44 | ) 45 | 46 | # Model components 47 | self.pos_embed = nn.Parameter( 48 | torch.empty(1, input_len, embed_dim).normal_(std=0.02) 49 | ) 50 | cond_dim = global_cond_dim + timestep_embed_dim 51 | self.blocks = nn.ModuleList( 52 | [ 53 | AdaLNAttentionBlock( 54 | dim=embed_dim, 55 | cond_dim=cond_dim, 56 | num_heads=num_heads, 57 | mlp_ratio=mlp_ratio, 58 | qkv_bias=qkv_bias, 59 | ) 60 | for _ in range(depth) 61 | ] 62 | ) 63 | self.head = AdaLNFinalLayer(dim=embed_dim, cond_dim=cond_dim) 64 | 65 | # AdaLN-specific weight initialization 66 | self.initialize_weights() 67 | 68 | def initialize_weights(self): 69 | # Base initialization 70 | self.apply(init_weights) 71 | 72 | # Zero-out adaLN modulation 
layers in DiT blocks: 73 | for block in self.blocks: 74 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 75 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 76 | 77 | # Zero-out output layers: 78 | nn.init.constant_(self.head.adaLN_modulation[-1].weight, 0) 79 | nn.init.constant_(self.head.adaLN_modulation[-1].bias, 0) 80 | nn.init.constant_(self.head.linear.weight, 0) 81 | nn.init.constant_(self.head.linear.bias, 0) 82 | 83 | def forward(self, sample, timestep, global_cond): 84 | # Encode input 85 | embed = self.input_encoder(sample) 86 | 87 | # Encode timestep 88 | if len(timestep.shape) == 0: 89 | timestep = timestep.expand(sample.shape[0]).to( 90 | dtype=torch.long, device=sample.device 91 | ) 92 | temb = self.timestep_encoder(timestep) 93 | 94 | # Concatenate timestep embedding and global condition along the feature dimension 95 | x = embed + self.pos_embed 96 | cond = torch.cat([global_cond, temb], dim=-1) 97 | for block in self.blocks: 98 | x = block(x, cond) 99 | x = self.head(x, cond) 100 | 101 | # Decode output 102 | out = self.output_decoder(x) 103 | return out 104 | -------------------------------------------------------------------------------- /datasets/utils/mixture.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class RobotVideoMixtureDataset(Dataset): 8 | def __init__( 9 | self, 10 | robot_dataset: Dataset, 11 | video_dataset: Dataset, 12 | balance_datasets: bool = False, 13 | ): 14 | self.robot_dataset = robot_dataset 15 | self.video_dataset = video_dataset 16 | 17 | # Copy action and lowdim normalizers from robot dataset 18 | if hasattr(self.robot_dataset, "action_normalizer"): 19 | self.action_normalizer = self.robot_dataset.action_normalizer 20 | if hasattr(self.robot_dataset, "lowdim_normalizer"): 21 | self.lowdim_normalizer = self.robot_dataset.lowdim_normalizer 22 | 23 | # Balance robot and video datasets 24 | if balance_datasets: 25 | # Figure out dataset lengths 26 | len_robot = len(self.robot_dataset) 27 | len_video = len(self.video_dataset) 28 | max_len = max(len_robot, len_video) 29 | print( 30 | f"Balancing data: {len_robot} robot, {len_video} video -> upsample to {max_len} each" 31 | ) 32 | 33 | # Upsample robot data 34 | robot_factor = math.ceil(max_len / len_robot) 35 | robot_indices = [] 36 | for _ in range(robot_factor): 37 | robot_indices.extend(range(len_robot)) 38 | robot_indices = robot_indices[:max_len] 39 | 40 | # Upsample video data 41 | video_factor = math.ceil(max_len / len_video) 42 | video_indices = [] 43 | for _ in range(video_factor): 44 | video_indices.extend(range(len_video)) 45 | video_indices = video_indices[:max_len] 46 | 47 | # Combine robot and video data 48 | combined_indices = [] 49 | for r_i in robot_indices: 50 | combined_indices.append((True, r_i)) 51 | for v_i in video_indices: 52 | combined_indices.append((False, v_i)) 53 | self.index_map = combined_indices # list of (bool, idx) 54 | else: 55 | # Just combine the two datasets 56 | combined_indices = [] 57 | for r_i in range(len(self.robot_dataset)): 58 | combined_indices.append((True, r_i)) 59 | for v_i in range(len(self.video_dataset)): 60 | combined_indices.append((False, v_i)) 61 | self.index_map = combined_indices 62 | 63 | def __len__(self): 64 | return len(self.index_map) 65 | 66 | def __getitem__(self, idx): 67 | is_robot, dataset_idx = self.index_map[idx] 68 | if is_robot: 69 | data = self.robot_dataset[dataset_idx] 70 | data["action_mask"] = 
torch.tensor(1, dtype=torch.bool) 71 | else: 72 | data = self.video_dataset[dataset_idx] 73 | data["action_mask"] = torch.tensor(0, dtype=torch.bool) 74 | return data 75 | 76 | 77 | def make_robot_video_mixture_dataset( 78 | robot_train_val_sets: tuple[Dataset, Dataset], 79 | video_train_val_sets: tuple[Dataset, Dataset], 80 | balance_datasets: bool = False, 81 | **kwargs, 82 | ): 83 | """ 84 | Combine robot and video datasets into a mixture dataset. 85 | 86 | This function merges the training sets from both robot and video datasets to create a single training set, 87 | while using the robot dataset's validation set for evaluation. Additional keyword arguments (kwargs) are 88 | captured for compatibility with other functions in the codebase. 89 | """ 90 | robot_train_set, robot_val_set = robot_train_val_sets 91 | video_train_set, video_val_set = video_train_val_sets 92 | train_set = RobotVideoMixtureDataset( 93 | robot_train_set, video_train_set, balance_datasets 94 | ) 95 | return train_set, robot_val_set 96 | -------------------------------------------------------------------------------- /experiments/uwm/train_webvideo.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.multiprocessing as mp 4 | import wandb 5 | from diffusers.optimization import get_scheduler 6 | from hydra.utils import instantiate 7 | from omegaconf import OmegaConf 8 | from torch.nn.parallel import DistributedDataParallel 9 | from tqdm import tqdm 10 | 11 | from datasets.utils.loader import make_distributed_data_loader 12 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 13 | from experiments.uwm.train import ( 14 | train_one_step, 15 | maybe_resume_checkpoint, 16 | maybe_evaluate, 17 | maybe_save_checkpoint, 18 | ) 19 | 20 | 21 | def train(rank, world_size, config): 22 | # Set global seed 23 | set_seed(config.seed * world_size + rank) 24 | 25 | # Initialize distributed training 26 | init_distributed(rank, world_size) 27 | device = torch.device(f"cuda:{rank}") 28 | 29 | # Initialize WANDB 30 | if is_main_process(): 31 | init_wandb(config, job_type="train") 32 | 33 | # Create dataset and loader 34 | train_set, val_set = instantiate(config.dataset) 35 | train_loader, val_loader = make_distributed_data_loader( 36 | train_set, val_set, config.batch_size, rank, world_size 37 | ) 38 | 39 | # Create model 40 | model = instantiate(config.model).to(device) 41 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 42 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 43 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 44 | 45 | # Load pretrained model 46 | if config.pretrain_checkpoint_path: 47 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 48 | model.load_state_dict(ckpt["model"]) 49 | print( 50 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}" 51 | ) 52 | 53 | # Resume from checkpoint 54 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 55 | epoch = step // len(train_loader) 56 | 57 | # Wrap model with DDP 58 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 59 | 60 | # Training loop 61 | pbar = tqdm( 62 | total=config.num_steps, 63 | initial=step, 64 | desc="Training", 65 | disable=not is_main_process(), 66 | ) 67 | while step < config.num_steps: 68 | # Set epoch for distributed sampler to shuffle indices 69 | 
train_loader.sampler.set_epoch(epoch) 70 | 71 | # Train for one epoch 72 | for batch in train_loader: 73 | # --- Training step --- 74 | loss, info = train_one_step( 75 | config, model, optimizer, scheduler, scaler, batch, device 76 | ) 77 | 78 | # --- Logging --- 79 | if is_main_process(): 80 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 81 | wandb.log({f"train/{k}": v for k, v in info.items()}) 82 | 83 | # --- Evaluate if needed --- 84 | maybe_evaluate(config, step, model, val_loader, device) 85 | 86 | # --- Save checkpoint if needed --- 87 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler) 88 | 89 | step += 1 90 | pbar.update(1) 91 | if step >= config.num_steps: 92 | break 93 | 94 | epoch += 1 95 | 96 | 97 | @hydra.main( 98 | version_base=None, config_path="../../configs", config_name="train_uwm.yaml" 99 | ) 100 | def main(config): 101 | # Resolve hydra config 102 | OmegaConf.resolve(config) 103 | # Spawn processes 104 | world_size = torch.cuda.device_count() 105 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /experiments/gr1/train_robomimic.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.multiprocessing as mp 4 | import wandb 5 | from diffusers.optimization import get_scheduler 6 | from hydra.utils import instantiate 7 | from omegaconf import OmegaConf 8 | from torch.nn.parallel import DistributedDataParallel 9 | from tqdm import tqdm 10 | 11 | from datasets.utils.loader import make_distributed_data_loader 12 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 13 | from experiments.gr1.train import maybe_evaluate 14 | from experiments.uwm.train import ( 15 | train_one_step, 16 | maybe_resume_checkpoint, 17 | maybe_save_checkpoint, 18 | ) 19 | from experiments.uwm.train_robomimic import maybe_collect_rollout 20 | 21 | 22 | def train(rank, world_size, config): 23 | # Set global seed 24 | set_seed(config.seed * world_size + rank) 25 | 26 | # Initialize distributed training 27 | init_distributed(rank, world_size) 28 | device = torch.device(f"cuda:{rank}") 29 | 30 | # Initialize WANDB 31 | if is_main_process(): 32 | init_wandb(config, job_type="train") 33 | 34 | # Create dataset and loader 35 | train_set, val_set = instantiate(config.dataset) 36 | train_loader, val_loader = make_distributed_data_loader( 37 | train_set, val_set, config.batch_size, rank, world_size 38 | ) 39 | 40 | # Create model 41 | model = instantiate(config.model).to(device) 42 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 43 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 44 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 45 | 46 | # Load pretrained model 47 | if config.pretrain_checkpoint_path: 48 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 49 | model.load_state_dict(ckpt["model"]) 50 | print( 51 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}" 52 | ) 53 | 54 | # Resume from checkpoint 55 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 56 | epoch = step // len(train_loader) 57 | 58 | # Wrap model with DDP 59 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 60 | 61 | # Training loop 62 | pbar = tqdm( 
63 | total=config.num_steps, 64 | initial=step, 65 | desc="Training", 66 | disable=not is_main_process(), 67 | ) 68 | while step < config.num_steps: 69 | # Set epoch for distributed sampler to shuffle indices 70 | train_loader.sampler.set_epoch(epoch) 71 | 72 | # Train for one epoch 73 | for batch in train_loader: 74 | # --- Training step --- 75 | loss, info = train_one_step( 76 | config, model, optimizer, scheduler, scaler, batch, device 77 | ) 78 | 79 | # --- Logging --- 80 | if is_main_process(): 81 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 82 | wandb.log({f"train/{k}": v for k, v in info.items()}) 83 | 84 | # --- Evaluate if needed --- 85 | maybe_evaluate(config, step, model, val_loader, device) 86 | 87 | # ---Collect environment rollouts if needed --- 88 | maybe_collect_rollout(config, step, model, device) 89 | 90 | # --- Save checkpoint if needed --- 91 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler) 92 | 93 | step += 1 94 | pbar.update(1) 95 | if step >= config.num_steps: 96 | break 97 | 98 | epoch += 1 99 | 100 | 101 | @hydra.main( 102 | version_base=None, 103 | config_path="../../configs", 104 | config_name="train_gr1_robomimic.yaml", 105 | ) 106 | def main(config): 107 | # Resolve hydra config 108 | OmegaConf.resolve(config) 109 | # Spawn processes 110 | world_size = torch.cuda.device_count() 111 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /experiments/pad/train_robomimic.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.multiprocessing as mp 4 | import wandb 5 | from diffusers.optimization import get_scheduler 6 | from hydra.utils import instantiate 7 | from omegaconf import OmegaConf 8 | from torch.nn.parallel import DistributedDataParallel 9 | from tqdm import tqdm 10 | 11 | from datasets.utils.loader import make_distributed_data_loader 12 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 13 | from experiments.pad.train import maybe_evaluate 14 | from experiments.uwm.train import ( 15 | train_one_step, 16 | maybe_resume_checkpoint, 17 | maybe_save_checkpoint, 18 | ) 19 | from experiments.uwm.train_robomimic import maybe_collect_rollout 20 | 21 | 22 | def train(rank, world_size, config): 23 | # Set global seed 24 | set_seed(config.seed * world_size + rank) 25 | 26 | # Initialize distributed training 27 | init_distributed(rank, world_size) 28 | device = torch.device(f"cuda:{rank}") 29 | 30 | # Initialize WANDB 31 | if is_main_process(): 32 | init_wandb(config, job_type="train") 33 | 34 | # Create dataset and loader 35 | train_set, val_set = instantiate(config.dataset) 36 | train_loader, val_loader = make_distributed_data_loader( 37 | train_set, val_set, config.batch_size, rank, world_size 38 | ) 39 | 40 | # Create model 41 | model = instantiate(config.model).to(device) 42 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 43 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 44 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 45 | 46 | # Load pretrained model 47 | if config.pretrain_checkpoint_path: 48 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 49 | model.load_state_dict(ckpt["model"]) 50 | print( 51 | f"Loaded pretraining checkpoint 
{config.pretrain_checkpoint_path}, step: {ckpt['step']}" 52 | ) 53 | 54 | # Resume from checkpoint 55 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 56 | epoch = step // len(train_loader) 57 | 58 | # Wrap model with DDP 59 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 60 | 61 | # Training loop 62 | pbar = tqdm( 63 | total=config.num_steps, 64 | initial=step, 65 | desc="Training", 66 | disable=not is_main_process(), 67 | ) 68 | while step < config.num_steps: 69 | # Set epoch for distributed sampler to shuffle indices 70 | train_loader.sampler.set_epoch(epoch) 71 | 72 | # Train for one epoch 73 | for batch in train_loader: 74 | # --- Training step --- 75 | loss, info = train_one_step( 76 | config, model, optimizer, scheduler, scaler, batch, device 77 | ) 78 | 79 | # --- Logging --- 80 | if is_main_process(): 81 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 82 | wandb.log({f"train/{k}": v for k, v in info.items()}) 83 | 84 | # --- Evaluate if needed --- 85 | maybe_evaluate(config, step, model, val_loader, device) 86 | 87 | # ---Collect environment rollouts if needed --- 88 | maybe_collect_rollout(config, step, model, device) 89 | 90 | # --- Save checkpoint if needed --- 91 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler) 92 | 93 | step += 1 94 | pbar.update(1) 95 | if step >= config.num_steps: 96 | break 97 | 98 | epoch += 1 99 | 100 | 101 | @hydra.main( 102 | version_base=None, 103 | config_path="../../configs", 104 | config_name="train_pad_robomimic.yaml", 105 | ) 106 | def main(config): 107 | # Resolve hydra config 108 | OmegaConf.resolve(config) 109 | # Spawn processes 110 | world_size = torch.cuda.device_count() 111 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /experiments/uwm/ablate_inverse_dynamics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import h5py 5 | import imageio 6 | import torch 7 | from hydra.utils import instantiate 8 | from libero.libero.envs import OffScreenRenderEnv 9 | from libero.libero import get_libero_path 10 | from omegaconf import OmegaConf 11 | from tqdm import trange 12 | 13 | 14 | from datasets.utils.loader import make_distributed_data_loader 15 | from environments.robomimic.wrappers import LIBEROEnvWrapper 16 | from experiments.utils import set_seed 17 | 18 | 19 | def eval_inverse_dynamics( 20 | hdf5_path, obs_keys, obs_horizon, action_horizon, max_episode_length, model, device 21 | ): 22 | # Set eval mode 23 | model.eval() 24 | 25 | # Make environment to verify demo 26 | bddl_file_name = os.path.join( 27 | get_libero_path("bddl_files"), 28 | "libero_10", 29 | hdf5_path.split("/")[-1].replace("_demo.hdf5", ".bddl"), 30 | ) 31 | env_kwargs = { 32 | "bddl_file_name": bddl_file_name, 33 | "camera_heights": 128, 34 | "camera_widths": 128, 35 | } 36 | if os.environ.get("CUDA_VISIBLE_DEVICES", None): 37 | env_kwargs["render_gpu_device_id"] = int( 38 | os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0] 39 | ) 40 | env = OffScreenRenderEnv(**env_kwargs) 41 | env = LIBEROEnvWrapper(env, obs_keys, obs_horizon, max_episode_length, record=True) 42 | 43 | successes = [] 44 | with h5py.File(hdf5_path) as f: 45 | demos = f["data"] 46 | for i in trange(len(demos)): 47 | demo = demos[f"demo_{i}"] 48 | obs = 
env.reset_to_state(demo["states"][0]) 49 | next_index = action_horizon 50 | done = False 51 | while next_index + obs_horizon <= len(demo["actions"]) and not done: 52 | next_obs = { 53 | key: demo["obs"][key][next_index : next_index + obs_horizon] 54 | for key in obs_keys 55 | } 56 | obs_tensor = { 57 | k: torch.tensor(v, device=device)[None] for k, v in obs.items() 58 | } 59 | next_obs_tensor = { 60 | k: torch.tensor(v, device=device)[None] for k, v in next_obs.items() 61 | } 62 | action = model.sample_inverse_dynamics(obs_tensor, next_obs_tensor) 63 | # action = model.sample(obs_tensor) 64 | obs, _, done, info = env.step(action[0].cpu().numpy()) 65 | next_index += action_horizon 66 | successes.append(info["success"]) 67 | print(f"Episode {i}, success: {successes[-1]}") 68 | 69 | # Save video locally 70 | video = env.get_video() 71 | imageio.mimsave(f"episode_{i}.gif", video) 72 | 73 | print(f"Total: {len(demos)}, success: {sum(successes)}") 74 | 75 | 76 | def train(rank, world_size, config): 77 | # Set global seed 78 | set_seed(config.seed * world_size + rank) 79 | 80 | # Initialize distributed training 81 | device = torch.device(f"cuda:{rank}") 82 | 83 | # Create dataset 84 | train_set, val_set = instantiate(config.dataset) 85 | train_loader, val_loader = make_distributed_data_loader( 86 | train_set, val_set, config.batch_size, rank, world_size 87 | ) 88 | 89 | # Create model 90 | model = instantiate(config.model).to(device) 91 | 92 | # Resume from checkpoint 93 | ckpt_path = os.path.join(config.logdir, "models.pt") 94 | ckpt = torch.load(ckpt_path, map_location="cpu") 95 | model.load_state_dict(ckpt["model"]) 96 | step = ckpt["step"] + 1 97 | print(f"Loading checkpoint from step {step}") 98 | 99 | if ckpt["action_normalizer"] is not None: 100 | train_set.action_normalizer = ckpt["action_normalizer"] 101 | val_set.action_normalizer = ckpt["action_normalizer"] 102 | if ckpt["lowdim_normalizer"] is not None: 103 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 104 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 105 | 106 | obs_keys = list(model.obs_encoder.rgb_keys) + list(model.obs_encoder.low_dim_keys) 107 | eval_inverse_dynamics( 108 | config.dataset.hdf5_path_globs, 109 | obs_keys, 110 | config.model.obs_encoder.num_frames, 111 | config.model.action_len, 112 | config.rollout_length, 113 | model, 114 | device, 115 | ) 116 | 117 | 118 | @hydra.main( 119 | version_base=None, 120 | config_path="../../configs", 121 | config_name="train_uwm_robomimic.yaml", 122 | ) 123 | def main(config): 124 | # Resolve hydra config 125 | OmegaConf.resolve(config) 126 | train(0, 1, config) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /environments/robomimic/wrappers.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | 5 | 6 | class RoboMimicEnvWrapper: 7 | def __init__( 8 | self, 9 | env, 10 | obs_keys, 11 | obs_horizon, 12 | max_episode_length, 13 | record=False, 14 | render_size=(224, 224), 15 | ): 16 | self.env = env 17 | self.obs_keys = obs_keys 18 | self.obs_buffer = deque(maxlen=obs_horizon) 19 | 20 | self._max_episode_length = max_episode_length 21 | self._elapsed_steps = None 22 | 23 | self.record = record 24 | self.render_size = render_size 25 | if record: 26 | self.video_buffer = deque() 27 | 28 | def _is_success(self): 29 | return self.env.is_success()["task"] 30 | 31 | def 
_get_obs(self): 32 | # Return a dictionary of stacked observations 33 | stacked_obs = {} 34 | for key in self.obs_keys: 35 | stacked_obs[key] = np.stack([obs[key] for obs in self.obs_buffer]) 36 | return stacked_obs 37 | 38 | def seed(self, seed): 39 | np.random.seed(seed) 40 | 41 | def reset(self): 42 | # Clear buffers 43 | self.obs_buffer.clear() 44 | if self.record: 45 | self.video_buffer.clear() 46 | 47 | # Reset environment 48 | obs = self.env.reset() 49 | self._elapsed_steps = 0 50 | 51 | # Pad observation buffer 52 | for _ in range(self.obs_buffer.maxlen): 53 | self.obs_buffer.append(obs) 54 | 55 | return self._get_obs() 56 | 57 | def step(self, actions): 58 | # Roll out a sequence of actions in the environment 59 | total_reward = 0 60 | for action in actions: 61 | # Step environment 62 | obs, reward, done, info = self.env.step(action) 63 | total_reward += reward 64 | self.obs_buffer.append(obs) 65 | if self.record: 66 | self.video_buffer.append(self.render()) 67 | 68 | # Store success info 69 | info["success"] = self._is_success() 70 | 71 | # Terminate on success 72 | done = done or info["success"] 73 | 74 | # Terminate if max episode length is reached 75 | self._elapsed_steps += 1 76 | if self._elapsed_steps >= self._max_episode_length: 77 | info["truncated"] = not done 78 | done = True 79 | 80 | if done: 81 | break 82 | 83 | return self._get_obs(), total_reward, done, info 84 | 85 | def render(self): 86 | return self.env.render( 87 | mode="rgb_array", 88 | width=self.render_size[0], 89 | height=self.render_size[1], 90 | ) 91 | 92 | def get_video(self): 93 | if not self.record: 94 | raise ValueError("Video recording is disabled.") 95 | return np.stack(self.video_buffer) 96 | 97 | def close(self): 98 | self.env.close() 99 | 100 | 101 | class LIBEROEnvWrapper(RoboMimicEnvWrapper): 102 | def __init__( 103 | self, 104 | env, 105 | obs_keys, 106 | obs_horizon, 107 | max_episode_length, 108 | record=False, 109 | render_size=(224, 224), 110 | ): 111 | super().__init__( 112 | env, 113 | obs_keys, 114 | obs_horizon, 115 | max_episode_length, 116 | record, 117 | render_size, 118 | ) 119 | self.source_key_map = { 120 | "agentview_rgb": "agentview_image", 121 | "eye_in_hand_rgb": "robot0_eye_in_hand_image", 122 | } 123 | 124 | def reset_to_state(self, state): 125 | self.seed(0) 126 | self.reset() 127 | self.env.set_init_state(state) 128 | 129 | # Refresh obs buffer 130 | self.obs_buffer.clear() 131 | obs = self.env.env._get_observations() 132 | for _ in range(self.obs_buffer.maxlen): 133 | self.obs_buffer.append(obs) 134 | return self._get_obs() 135 | 136 | def _is_success(self): 137 | return self.env.check_success() 138 | 139 | def _get_obs(self): 140 | # Return a dictionary of stacked observations 141 | stacked_obs = {} 142 | for key in self.obs_keys: 143 | source_key = self.source_key_map.get(key, key) 144 | stacked_obs[key] = np.stack([obs[source_key] for obs in self.obs_buffer]) 145 | 146 | # Flip all image observations 147 | for key in self.obs_keys: 148 | if len(stacked_obs[key].shape) == 4: 149 | stacked_obs[key] = stacked_obs[key][:, ::-1].copy() 150 | return stacked_obs 151 | 152 | def render(self): 153 | img = self.env.env.sim.render( 154 | height=self.render_size[1], 155 | width=self.render_size[0], 156 | camera_name="frontview", 157 | ) 158 | return img[::-1] 159 | -------------------------------------------------------------------------------- /experiments/uwm/ablate_forward_dynamics.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | 3 | import hydra 4 | import imageio 5 | import torch 6 | import numpy as np 7 | from diffusers.optimization import get_scheduler 8 | from einops import rearrange 9 | from hydra.utils import instantiate 10 | from omegaconf import OmegaConf 11 | from torchvision.utils import make_grid 12 | from tqdm import tqdm 13 | 14 | from datasets.utils.loader import make_distributed_data_loader 15 | from experiments.utils import set_seed, is_main_process 16 | 17 | 18 | def process_batch(batch, obs_horizon, action_horizon, device): 19 | action_start = obs_horizon - 1 20 | action_end = action_start + action_horizon 21 | curr_obs = {k: v[:, : action_start + 1].to(device) for k, v in batch["obs"].items()} 22 | next_obs = {k: v[:, action_end:].to(device) for k, v in batch["obs"].items()} 23 | actions = batch["action"][:, action_start:action_end].to(device) 24 | 25 | # Add language tokens 26 | if "input_ids" in batch and "attention_mask" in batch: 27 | curr_obs["input_ids"] = batch["input_ids"].to(device) 28 | curr_obs["attention_mask"] = batch["attention_mask"].to(device) 29 | return curr_obs, next_obs, actions 30 | 31 | 32 | def eval_one_epoch(config, data_loader, device, model): 33 | model.eval() 34 | model = getattr(model, "module", model) # unwrap DDP 35 | 36 | def decode_and_plot(images, nrows=2): 37 | images = model.obs_encoder.apply_vae(images, inverse=True) 38 | images = rearrange(images, "b v c t h w -> (b v t) c h w").clamp(0, 1) 39 | images_grid = make_grid(images, nrows) 40 | return ( 41 | (images_grid.cpu().numpy() * 255) 42 | .round() 43 | .astype(np.uint8) 44 | .transpose(1, 2, 0) 45 | ) 46 | 47 | save_path = f"viz_{config.dataset.name}" 48 | if not os.path.exists(save_path): 49 | os.mkdir(save_path) 50 | step = 0 51 | for batch in tqdm(data_loader, desc="Evaluating", disable=not is_main_process()): 52 | # ------------ Preprocess data ------------ # 53 | curr_obs_dict, next_obs_dict, action_norm = process_batch( 54 | batch, config.model.obs_encoder.num_frames, config.model.action_len, device 55 | ) 56 | 57 | with torch.no_grad(): 58 | # Encode current observations 59 | curr_obs = model.obs_encoder.encode_next_obs(curr_obs_dict) 60 | 61 | # Encode next observations 62 | next_obs = model.obs_encoder.encode_next_obs(next_obs_dict) 63 | 64 | # Sample observations from forward dynamics 65 | next_obs_hat_forward = model.sample_forward_dynamics( 66 | curr_obs_dict, action_norm 67 | ) 68 | 69 | # Plot current observations 70 | curr_obs = decode_and_plot(curr_obs[:1]) 71 | imageio.imwrite(f"{save_path}/{step}_curr_obs.png", curr_obs) 72 | 73 | # Plot next observations 74 | next_obs = decode_and_plot(next_obs[:1]) 75 | imageio.imwrite(f"{save_path}/{step}_next_obs.png", next_obs) 76 | 77 | # Plot predicted next observations 78 | next_obs_hat_forward = decode_and_plot(next_obs_hat_forward[:1]) 79 | imageio.imwrite(f"{save_path}/{step}_next_obs_hat.png", next_obs_hat_forward) 80 | 81 | step += 1 82 | if step == 15: 83 | break 84 | 85 | 86 | def train(rank, world_size, config): 87 | # Set global seed 88 | set_seed(config.seed * world_size + rank) 89 | 90 | # Initialize distributed training 91 | device = torch.device(f"cuda:{rank}") 92 | 93 | # Create dataset 94 | train_set, val_set = instantiate(config.dataset) 95 | train_loader, val_loader = make_distributed_data_loader( 96 | train_set, val_set, config.batch_size, rank, world_size 97 | ) 98 | 99 | # Create model 100 | model = instantiate(config.model).to(device) 101 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 102 | 
scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 103 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 104 | 105 | # Resume from checkpoint 106 | ckpt_path = os.path.join(config.logdir, "models.pt") 107 | ckpt = torch.load(ckpt_path, map_location="cpu") 108 | model.load_state_dict(ckpt["model"]) 109 | optimizer.load_state_dict(ckpt["optimizer"]) 110 | scheduler.load_state_dict(ckpt["scheduler"]) 111 | scaler.load_state_dict(ckpt["scaler"]) 112 | step = ckpt["step"] + 1 113 | print(f"Resumed training from step {step}") 114 | 115 | train_set.action_normalizer = ckpt["action_normalizer"] 116 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 117 | val_set.action_normalizer = ckpt["action_normalizer"] 118 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 119 | 120 | eval_one_epoch(config, val_loader, device, model) 121 | 122 | 123 | @hydra.main( 124 | version_base=None, config_path="../../configs", config_name="train_uwm.yaml" 125 | ) 126 | def main(config): 127 | # Resolve hydra config 128 | OmegaConf.resolve(config) 129 | train(0, 1, config) 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /models/dp/base_policy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 7 | 8 | 9 | class NoisePredictionNet(nn.Module, ABC): 10 | 11 | @abstractmethod 12 | def forward(self, sample, timestep, global_cond): 13 | raise NotImplementedError 14 | 15 | 16 | class DiffusionPolicy(nn.Module): 17 | def __init__( 18 | self, 19 | action_len, 20 | action_dim, 21 | noise_pred_net, 22 | num_train_steps=100, 23 | num_inference_steps=10, 24 | num_train_noise_samples=1, 25 | beta_schedule="squaredcos_cap_v2", 26 | clip_sample=True, 27 | ): 28 | super().__init__() 29 | self.action_len = action_len 30 | self.action_dim = action_dim 31 | self.num_train_steps = num_train_steps 32 | self.num_inference_steps = num_inference_steps 33 | self.num_train_noise_samples = num_train_noise_samples 34 | 35 | # Noise prediction net 36 | assert isinstance(noise_pred_net, NoisePredictionNet) 37 | self.noise_pred_net = noise_pred_net 38 | 39 | # Noise scheduler 40 | self.noise_scheduler = DDIMScheduler( 41 | num_train_timesteps=num_train_steps, 42 | beta_schedule=beta_schedule, 43 | clip_sample=clip_sample, 44 | ) 45 | 46 | @torch.no_grad() 47 | def sample(self, obs): 48 | # Initialize sample 49 | action = torch.randn( 50 | (obs.shape[0], self.action_len, self.action_dim), device=obs.device 51 | ) 52 | 53 | # Initialize scheduler 54 | self.noise_scheduler.set_timesteps(self.num_inference_steps) 55 | 56 | # Reverse diffusion process 57 | for t in self.noise_scheduler.timesteps: 58 | # Predict noise 59 | noise_pred = self.noise_pred_net(action, t, global_cond=obs) 60 | 61 | # Diffusion step 62 | action = self.noise_scheduler.step(noise_pred, t, action).prev_sample 63 | 64 | return action 65 | 66 | def forward(self, obs, action): 67 | # Repeat observations and actions for multiple noise samples 68 | if self.num_train_noise_samples > 1: 69 | obs = obs.repeat_interleave(self.num_train_noise_samples, dim=0) 70 | action = action.repeat_interleave(self.num_train_noise_samples, dim=0) 71 | 72 | # Sample random noise 73 | noise = torch.randn_like(action) 74 | 75 | # Sample a random timestep 76 
| t = torch.randint( 77 | low=0, 78 | high=self.num_train_steps, 79 | size=(action.shape[0],), 80 | device=action.device, 81 | ).long() 82 | 83 | # Forward diffusion step 84 | noisy_action = self.noise_scheduler.add_noise(action, noise, t) 85 | 86 | # Diffusion loss 87 | noise_pred = self.noise_pred_net(noisy_action, t, global_cond=obs) 88 | loss = F.mse_loss(noise_pred, noise) 89 | return loss 90 | 91 | 92 | class FlowPolicy(nn.Module): 93 | def __init__( 94 | self, 95 | action_len, 96 | action_dim, 97 | noise_pred_net, 98 | num_train_steps=100, 99 | num_inference_steps=10, 100 | timeshift=1.0, 101 | ): 102 | super().__init__() 103 | self.action_len = action_len 104 | self.action_dim = action_dim 105 | 106 | # Noise prediction net 107 | assert isinstance(noise_pred_net, NoisePredictionNet) 108 | self.noise_pred_net = noise_pred_net 109 | 110 | self.num_train_steps = num_train_steps 111 | self.num_inference_steps = num_inference_steps 112 | timesteps = torch.linspace(1, 0, self.num_inference_steps + 1) 113 | self.timesteps = (timeshift * timesteps) / (1 + (timeshift - 1) * timesteps) 114 | 115 | @torch.no_grad() 116 | def sample(self, obs): 117 | # Initialize sample 118 | action = torch.randn( 119 | (obs.shape[0], self.action_len, self.action_dim), device=obs.device 120 | ) 121 | 122 | for tcont, tcont_next in zip(self.timesteps[:-1], self.timesteps[1:]): 123 | # Predict noise 124 | t = (tcont * self.num_train_steps).long() 125 | noise_pred = self.noise_pred_net(action, t, global_cond=obs) 126 | 127 | # Flow step 128 | action = action + (tcont_next - tcont) * noise_pred 129 | 130 | return action 131 | 132 | def forward(self, obs, action): 133 | # Sample random noise 134 | noise = torch.randn_like(action) 135 | 136 | # Sample random timestep 137 | tcont = torch.rand((action.shape[0],), device=action.device) 138 | 139 | # Forward flow step 140 | direction = noise - action 141 | noisy_action = ( 142 | action + tcont.view(-1, *[1 for _ in range(action.dim() - 1)]) * direction 143 | ) 144 | 145 | # Flow matching loss 146 | t = (tcont * self.num_train_steps).long() 147 | noise_pred = self.noise_pred_net(noisy_action, t, global_cond=obs) 148 | loss = F.mse_loss(noise_pred, direction) 149 | return loss 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unified World Models: Coupling Video and Action Diffusion for Pretraining on Large Robotic Datasets 2 | 3 | #### [[Website]](https://weirdlabuw.github.io/uwm/) [[Paper]](https://arxiv.org/abs/2504.02792) [[Talk]](https://www.youtube.com/watch?v=WwPRxBbZ4kw) 4 | 5 | [Chuning Zhu1](https://homes.cs.washington.edu/~zchuning/), [Raymond Yu1](https://raymondyu5.github.io/), [Siyuan Feng2](https://www.cs.cmu.edu/~sfeng/), [Benjamin Burchfiel2](https://scholar.google.com/citations?user=eGoTK1YAAAAJ&hl=en), [Paarth Shah2](https://www.paarthshah.me/about), [Abhishek Gupta1](https://homes.cs.washington.edu/~abhgupta/)
6 | 7 | 1University of Washington 2Toyota Research Institute 8 | 9 | This repository provides a PyTorch implementation of Unified World Model (UWM). UWM combines action diffusion and video diffusion to enable scalable pretraining on large, heterogeneous robotics datasets. 10 | 11 | 12 | ## Code structure 13 | * `configs`: Configuration files for pretraining and finetuning experiments. 14 | * `datasets`: Dataset wrappers for DROID, Robomimic, and LIBERO. We standardize all datasets using compressed [Zarr](https://zarr.readthedocs.io/en/stable/) buffers. 15 | * `environments`: Interface wrappers for Robomimic and LIBERO environments. 16 | * `experiments`: Training and evaluation scripts. 17 | * `models`: Model definitions for UWM and baselines. 18 | * `scripts`: Bash scripts for running DROID experiments. 19 | 20 | 21 | ## Setup 22 | Install the package via 23 | ``` 24 | pip install -e . 25 | ``` 26 | > Note: if you encounter issues using tensorflow-dataset with DROID, consider installing tensorflow-dataset from [source](https://github.com/tensorflow/datasets). 27 | 28 | ## Robomimic Experiments 29 | To run a Robomimic single-task experiment, 30 | 1. Install the [Robomimic](https://github.com/ARISE-Initiative/robomimic) dataset. 31 | 2. Update `hdf5_path` and `buffer_path` in the config (e.g., `configs/dataset/robomimic_can_ph.yaml`). 32 | 3. Run: 33 | ``` 34 | python experiments/uwm/train_robomimic.py --config-name train_uwm_robomimic.yaml dataset=robomimic_can_ph exp_id=singletask 35 | ``` 36 | This command will generate a Zarr compressed buffer at the `buffer_path` specified in the config file. 37 | 38 | ## LIBERO Experiments 39 | The LIBERO experiments share most infrastructure with the Robomimic experiments. 40 | 41 | ### Pretraining 42 | To pretrain a UWM on LIBERO-90, 43 | 1. Install the [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) dataset. 44 | 2. Update `hdf5_path` and `buffer_path` in `configs/dataset/libero_90.yaml`. 45 | 3. Run: 46 | ``` 47 | python experiments/uwm/train_robomimic.py --config-name train_uwm_robomimic.yaml dataset=libero_90 exp_id=pretrain 48 | ``` 49 | 50 | ### Finetuning 51 | To finetune a pretrained UWM on a downstream LIBERO task (e.g., Book-Caddy), 52 | 1. Update `hdf5_path` and `buffer_path` in `configs/dataset/libero_book_caddy.yaml`. 53 | 2. Run: 54 | ``` 55 | python experiments/uwm/train_robomimic.py --config-name finetune_uwm_robomimic.yaml dataset=libero_book_caddy exp_id=finetune pretrain_checkpoint_path="logdir/uwm/libero_90/pretrain/0/models.pt" 56 | ``` 57 | 58 | We release the pretrained LIBERO-90 checkpoint [here](https://drive.google.com/drive/folders/1M4AuVLMRpSwOf_YAp56bV9AqyZI9ul6g?usp=sharing). You can download and directly finetune from this checkpoint. 59 | 60 | ## DROID Experiments 61 | We provide shell scripts for DROID pretraining / cotraining / finetuning experiments in the `scripts` directory. Each script runs a dataset conversion pipeline to create a Zarr buffer for the corresponding DROID TFDS dataset and then launches training. 62 | 63 | ### Pretraining 64 | To launch a DROID pretraining experiment, 65 | 1. Install the [DROID](https://droid-dataset.github.io/) dataset 66 | 2. Update `DATA_DIR` and `BUFFER_PATH` in `scripts/launch_droid_pretrain.sh` 67 | 3. Run: 68 | ``` 69 | source scripts/launch_droid_pretrain.sh 70 | ``` 71 | 72 | ### Cotraining 73 | To launch a video cotraining experiment, 74 | 1. Install the [DROID](https://droid-dataset.github.io/) dataset 75 | 2. 
Update `DATA_DIR`, `ROBOT_BUFFER_PATH`, and `VIDEO_BUFFER_PATH` in `scripts/launch_droid_cotrain.sh` 76 | 3. Run: 77 | ``` 78 | source scripts/launch_droid_cotrain.sh 79 | ``` 80 | 81 | ### Finetuning 82 | To finetune a pretrained model on a downstream task, 83 | 1. Collect demonstrations using the DROID interface 84 | 2. Convert them into a TFDS dataset (via this [pipeline](https://github.com/kpertsch/droid_dataset_builder)) 85 | 3. Modify and run: 86 | ``` 87 | source scripts/launch_droid_finetune.sh 88 | ``` 89 | 90 | We release the pretrained and cotrained DROID UWM checkpoints [here](https://drive.google.com/drive/folders/1M4AuVLMRpSwOf_YAp56bV9AqyZI9ul6g?usp=sharing). You can download and directly finetune from these checkpoints. 91 | 92 | ## Bibtex 93 | If you find this code useful, please cite: 94 | 95 | ``` 96 | @inproceedings{zhu2025uwm, 97 | author = {Zhu, Chuning and Yu, Raymond and Feng, Siyuan and Burchfiel, Benjamin and Shah, Paarth and Gupta, Abhishek}, 98 | title = {Unified World Models: Coupling Video and Action Diffusion for Pretraining on Large Robotic Datasets}, 99 | booktitle = {Proceedings of Robotics: Science and Systems (RSS)}, 100 | year = {2025}, 101 | } 102 | ``` -------------------------------------------------------------------------------- /models/dp/obs_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | from torchvision.transforms import Normalize 7 | 8 | from models.common.language import CLIPTextEncoder 9 | from models.common.transforms import ImageTransform 10 | from models.common.vision import get_resnet, get_vit, get_clip 11 | 12 | 13 | class ImageObservationEncoder(nn.Module): 14 | def __init__( 15 | self, 16 | shape_meta: dict, 17 | num_frames: int, 18 | embed_dim: int, 19 | resize_shape: Tuple[int, int] = None, 20 | crop_shape: Tuple[int, int] = None, 21 | random_crop: bool = True, 22 | color_jitter: Optional[Dict] = None, 23 | imagenet_norm: bool = True, 24 | pretrained_weights: Optional[str] = None, 25 | use_low_dim: bool = True, 26 | use_language: bool = True, 27 | ): 28 | """ 29 | Assumes rgb input: (B, T, H, W, C) uint8 image 30 | Assumes low_dim input: (B, T, D) 31 | """ 32 | super().__init__() 33 | rgb_keys = list() 34 | low_dim_keys = list() 35 | key_shape_map = dict() 36 | key_transform_map = nn.ModuleDict() 37 | 38 | obs_shape_meta = shape_meta["obs"] 39 | for key, attr in obs_shape_meta.items(): 40 | obs_shape = tuple(attr["shape"]) 41 | key_shape_map[key] = obs_shape 42 | 43 | obs_type = attr.get("type", "low_dim") 44 | if obs_type == "rgb": 45 | rgb_keys.append(key) 46 | key_transform_map[key] = ImageTransform( 47 | resize_shape=resize_shape, 48 | crop_shape=crop_shape, 49 | random_crop=random_crop, 50 | color_jitter=color_jitter, 51 | imagenet_norm=imagenet_norm, 52 | ) 53 | elif obs_type == "low_dim": 54 | low_dim_keys.append(key) 55 | else: 56 | raise RuntimeError(f"Unsupported obs type: {obs_type}") 57 | 58 | self.shape_meta = shape_meta 59 | self.num_frames = num_frames 60 | self.embed_dim = embed_dim 61 | self.rgb_keys = sorted(rgb_keys) 62 | self.low_dim_keys = sorted(low_dim_keys) 63 | self.key_shape_map = key_shape_map 64 | self.key_transform_map = key_transform_map 65 | 66 | # RGB model 67 | if pretrained_weights == "clip": 68 | assert not imagenet_norm, "imagenet_norm must be False for CLIP encoder" 69 | norm = Normalize( 70 | mean=[0.48145466, 0.4578275, 
0.40821073], 71 | std=[0.26862954, 0.26130258, 0.27577711], 72 | inplace=True, 73 | ) 74 | model = get_clip(embed_dim) 75 | self.rgb_encoder = nn.Sequential(norm, model) 76 | elif pretrained_weights == "vit": 77 | self.rgb_encoder = get_vit( 78 | "vit_b_32", embed_dim, weights=pretrained_weights 79 | ) 80 | else: 81 | self.rgb_encoder = get_resnet( 82 | "resnet18", embed_dim, weights=pretrained_weights 83 | ) 84 | 85 | # Low dim model 86 | self.use_low_dim = use_low_dim 87 | self.low_dim_size = sum([key_shape_map[key][-1] for key in low_dim_keys]) 88 | 89 | # Language model 90 | self.use_language = use_language 91 | self.text_encoder = ( 92 | CLIPTextEncoder(embed_dim=embed_dim) if use_language else None 93 | ) 94 | 95 | def __call__(self, obs_dict): 96 | # Process rgb observations 97 | imgs = list() 98 | for key in self.rgb_keys: 99 | img = obs_dict[key].flatten(0, 1) 100 | assert img.shape[1:] == self.key_shape_map[key] 101 | img = self.key_transform_map[key](img) # (B*T, C, H, W) 102 | imgs.append(img) 103 | 104 | # Concatenate along batch dimension 105 | imgs = torch.cat(imgs, dim=0) # (N*B*T, C, H, W) 106 | feats = self.rgb_encoder(imgs) # (N*B*T, D) 107 | feats = rearrange( 108 | feats, "(n b t) d -> b (t n d)", n=len(self.rgb_keys), t=self.num_frames 109 | ) 110 | 111 | if self.use_low_dim: 112 | # Process low dim observations 113 | low_dims = list() 114 | for key in self.low_dim_keys: 115 | low_dim = obs_dict[key].flatten(0, 1) 116 | assert low_dim.shape[1:] == self.key_shape_map[key] 117 | low_dims.append(low_dim) 118 | low_dims = torch.cat(low_dims, dim=-1) # (B*T, D_low_dim) 119 | low_dims = rearrange(low_dims, "(b t) d -> b (t d)", t=self.num_frames) 120 | 121 | # Concatenate image and lowdim features 122 | feats = torch.cat([feats, low_dims], dim=-1) 123 | 124 | # Encode language 125 | if self.use_language: 126 | lang_feats = self.text_encoder( 127 | input_ids=obs_dict["input_ids"], 128 | attention_mask=obs_dict["attention_mask"], 129 | ) 130 | feats = torch.cat([feats, lang_feats], dim=-1) 131 | 132 | return feats 133 | 134 | @property 135 | def output_len(self): 136 | return ( 137 | len(self.rgb_keys) * self.embed_dim 138 | + int(self.use_low_dim) * self.low_dim_size 139 | ) * self.num_frames + int(self.use_language) * self.embed_dim 140 | -------------------------------------------------------------------------------- /models/pad/obs_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from models.common.transforms import VideoTransform, VAEDownsample 8 | 9 | 10 | class PADObservationEncoder(nn.Module): 11 | def __init__( 12 | self, 13 | shape_meta: dict, 14 | num_frames: int, 15 | resize_shape: Tuple[int, int] = None, 16 | crop_shape: Tuple[int, int] = None, 17 | random_crop: bool = True, 18 | color_jitter: Optional[Dict] = None, 19 | imagenet_norm: bool = False, 20 | ): 21 | super().__init__() 22 | self.shape_meta = shape_meta 23 | self.num_frames = num_frames 24 | self.rgb_keys = sorted( 25 | [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"] 26 | ) 27 | self.low_dim_keys = sorted( 28 | [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"] 29 | ) 30 | self.num_views = len(self.rgb_keys) 31 | 32 | # Image augmentation 33 | self.obs_transform = VideoTransform( 34 | resize_shape=resize_shape, 35 | crop_shape=crop_shape, 36 | random_crop=random_crop, 37 | 
color_jitter=color_jitter, 38 | imagenet_norm=imagenet_norm, 39 | ) 40 | 41 | # Downsample future observations using vae 42 | self.vae = VAEDownsample() 43 | 44 | def apply_transform(self, obs_dicts: Union[dict, list[dict]]): 45 | """ 46 | Accept a list of observation dictionaries and apply the same transform to each. 47 | """ 48 | if isinstance(obs_dicts, dict): 49 | obs_dicts = [obs_dicts] 50 | is_singleton = True 51 | else: 52 | is_singleton = False 53 | assert isinstance(obs_dicts, list) 54 | 55 | # Apply the same transform to each observation 56 | num_obs = len(obs_dicts) 57 | transformed_imgs = [[] for _ in range(num_obs)] 58 | for key in self.rgb_keys: 59 | combined_imgs = torch.cat([obs_dict[key] for obs_dict in obs_dicts], dim=0) 60 | combined_imgs = self.obs_transform(combined_imgs) 61 | chunked_imgs = combined_imgs.chunk(num_obs, dim=0) 62 | for i, img in enumerate(chunked_imgs): 63 | transformed_imgs[i].append(img) 64 | 65 | # Stack transformed images 66 | # Each image has shape (B, V, C, T, H, W) 67 | transformed_imgs = [torch.stack(imgs, dim=1) for imgs in transformed_imgs] 68 | if is_singleton: 69 | transformed_imgs = transformed_imgs[0] 70 | return transformed_imgs 71 | 72 | def apply_vae( 73 | self, 74 | imgs_list: Union[torch.Tensor, list[torch.Tensor]], 75 | inverse: bool = False, 76 | microbatch_size: int = 72, # Tuned for 40GB VRAM 77 | ): 78 | """ 79 | Accept a list of images and apply VAE to downsample or upsample images. 80 | If inverse is False, downsample images. Otherwise, upsample images. 81 | Process images in microbatches to reduce memory usage. 82 | """ 83 | if isinstance(imgs_list, torch.Tensor): 84 | imgs_list = [imgs_list] 85 | is_singleton = True 86 | else: 87 | is_singleton = False 88 | assert isinstance(imgs_list, list) 89 | imgs = torch.cat(imgs_list, dim=0) 90 | 91 | # Flatten multiview videos to images 92 | B, V = imgs.shape[:2] 93 | imgs = rearrange(imgs, "b v c t h w -> (b v t) c h w") 94 | 95 | # Process images in microbatches 96 | transformed_imgs = [] 97 | for i in range(0, imgs.shape[0], microbatch_size): 98 | batch_imgs = imgs[i : i + microbatch_size] 99 | if inverse: 100 | batch_transformed_imgs = self.vae.inverse(batch_imgs) 101 | else: 102 | batch_transformed_imgs = self.vae(batch_imgs) 103 | transformed_imgs.append(batch_transformed_imgs) 104 | transformed_imgs = torch.cat(transformed_imgs, dim=0) 105 | 106 | # Unflatten images to multiview videos 107 | transformed_imgs = rearrange( 108 | transformed_imgs, "(b v t) c h w -> b v c t h w", b=B, v=V 109 | ) 110 | if not is_singleton: 111 | chunk_sizes = [img.shape[0] for img in imgs_list] 112 | transformed_imgs = list(transformed_imgs.split(chunk_sizes, dim=0)) 113 | return transformed_imgs 114 | 115 | def encode_obs(self, obs_dict: dict): 116 | imgs = self.apply_transform(obs_dict) 117 | latents = self.apply_vae(imgs) 118 | return latents 119 | 120 | def encode_curr_and_next_obs(self, curr_obs_dict: dict, next_obs_dict: dict): 121 | # Apply the same transform to obs and next obs 122 | curr_imgs, next_imgs = self.apply_transform([curr_obs_dict, next_obs_dict]) 123 | curr_latents, next_latents = self.apply_vae([curr_imgs, next_imgs]) 124 | return curr_latents, next_latents 125 | 126 | def latent_img_shape(self): 127 | # Construct dummy image and forward pass to get latent shape 128 | dummy_obs_dict = {} 129 | for k in self.rgb_keys: 130 | img_shape = self.shape_meta["obs"][k]["shape"] 131 | dummy_obs_dict[k] = torch.zeros( 132 | 1, self.num_frames, *img_shape, dtype=torch.uint8 133 | ) 134 | 
with torch.no_grad(): 135 | dummy_imgs = self.apply_transform(dummy_obs_dict) 136 | dummy_latents = self.apply_vae(dummy_imgs) 137 | return tuple(dummy_latents.shape[1:]) 138 | -------------------------------------------------------------------------------- /models/gr1/obs_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from models.common.transforms import VideoTransform 8 | from models.common.vision import get_vit 9 | from models.gr1.flamingo import PerceiverResampler 10 | 11 | 12 | class MultiViewViTImageEncoder(nn.Module): 13 | def __init__( 14 | self, num_views: int, num_frames: int, embed_dim: int, resampler_params: dict 15 | ): 16 | super().__init__() 17 | self.num_views = num_views 18 | self.model = get_vit("vit_b_32", embed_dim, weights="IMAGENET1K_V1") 19 | 20 | # Perceiver resampler 21 | self.perceiver_resampler = PerceiverResampler(**resampler_params) 22 | 23 | # Learnable embeddings 24 | self.pos_shift = nn.Parameter( 25 | torch.zeros(1, num_views * num_frames, 1, embed_dim), 26 | requires_grad=True, 27 | ) 28 | self.pos_scale = nn.Parameter( 29 | torch.zeros(1, num_views * num_frames, 1, embed_dim), 30 | requires_grad=True, 31 | ) 32 | 33 | self.feat_head = nn.Linear(embed_dim, embed_dim) 34 | self.patch_head = nn.Linear(embed_dim, embed_dim) 35 | 36 | def forward(self, imgs: torch.Tensor): 37 | B, V = imgs.shape[:2] 38 | imgs = rearrange(imgs, "b v c t h w -> (b v t) c h w") 39 | 40 | # Reshape and permute the input tensor 41 | x = self.model._process_input(imgs) 42 | n = x.shape[0] 43 | 44 | # Expand the class token to the full batch 45 | batch_class_token = self.model.class_token.expand(n, -1, -1) 46 | x = torch.cat([batch_class_token, x], dim=1) 47 | 48 | # Get raw tokens 49 | x = self.model.encoder(x) 50 | feats, patch_embeds = x[:, :1], x[:, 1:] 51 | 52 | # Pass through perceiver resampler 53 | feats = self.feat_head(feats) 54 | patch_embeds = self.patch_head( 55 | self.perceiver_resampler(x.unsqueeze(1)).squeeze(1) 56 | ) 57 | 58 | # Add learned positional embeddings 59 | x = torch.cat([feats, patch_embeds], dim=1) 60 | x = rearrange(x, "(b v t) n c -> b (v t) n c", b=B, v=V) 61 | x = x * (1 + self.pos_scale) + self.pos_shift 62 | return x.flatten(1, 2) # (b, v*t*n, c) 63 | 64 | 65 | class GR1ObservationEncoder(nn.Module): 66 | def __init__( 67 | self, 68 | shape_meta: dict, 69 | num_frames: int, 70 | embed_dim: int, 71 | resize_shape: Tuple[int, int] = None, 72 | crop_shape: Tuple[int, int] = None, 73 | random_crop: bool = True, 74 | color_jitter: Optional[Dict] = None, 75 | imagenet_norm: bool = False, 76 | resampler_params: dict = None, 77 | ): 78 | super().__init__() 79 | self.shape_meta = shape_meta 80 | self.num_frames = num_frames 81 | self.rgb_keys = sorted( 82 | [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"] 83 | ) 84 | self.low_dim_keys = sorted( 85 | [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"] 86 | ) 87 | self.num_views = len(self.rgb_keys) 88 | 89 | # Image augmentation 90 | self.obs_transform = VideoTransform( 91 | resize_shape=resize_shape, 92 | crop_shape=crop_shape, 93 | random_crop=random_crop, 94 | color_jitter=color_jitter, 95 | imagenet_norm=imagenet_norm, 96 | ) 97 | 98 | # Image encoder 99 | self.img_encoder = MultiViewViTImageEncoder( 100 | num_views=self.num_views, 101 | num_frames=self.num_frames, 102 | embed_dim=embed_dim, 
103 | resampler_params=resampler_params, 104 | ) 105 | 106 | def apply_transform(self, obs_dicts: Union[dict, list[dict]]): 107 | """ 108 | Accept a list of observation dictionaries and apply the same transform to each. 109 | """ 110 | if isinstance(obs_dicts, dict): 111 | obs_dicts = [obs_dicts] 112 | is_singleton = True 113 | else: 114 | is_singleton = False 115 | assert isinstance(obs_dicts, list) 116 | 117 | # Apply the same transform to each observation 118 | num_obs = len(obs_dicts) 119 | transformed_imgs = [[] for _ in range(num_obs)] 120 | for key in self.rgb_keys: 121 | combined_imgs = torch.cat([obs_dict[key] for obs_dict in obs_dicts], dim=0) 122 | combined_imgs = self.obs_transform(combined_imgs) 123 | chunked_imgs = combined_imgs.chunk(num_obs, dim=0) 124 | for i, img in enumerate(chunked_imgs): 125 | transformed_imgs[i].append(img) 126 | 127 | # Stack transformed images 128 | # Each image has shape (B, V, C, T, H, W) 129 | transformed_imgs = [torch.stack(imgs, dim=1) for imgs in transformed_imgs] 130 | if is_singleton: 131 | transformed_imgs = transformed_imgs[0] 132 | return transformed_imgs 133 | 134 | def encode_obs(self, obs_dict: dict): 135 | imgs = self.apply_transform(obs_dict) 136 | feats = self.img_encoder(imgs) 137 | return feats 138 | 139 | def encode_curr_and_next_obs(self, curr_obs_dict: dict, next_obs_dict: dict): 140 | # Apply the same transform to obs and next obs 141 | curr_imgs, next_imgs = self.apply_transform([curr_obs_dict, next_obs_dict]) 142 | curr_feats = self.img_encoder(curr_imgs) 143 | return curr_feats, next_imgs 144 | 145 | @property 146 | def num_latents(self): 147 | return ( 148 | (self.img_encoder.perceiver_resampler.num_latents + 1) 149 | * self.num_views 150 | * self.num_frames 151 | ) 152 | -------------------------------------------------------------------------------- /models/common/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as ttf 6 | from diffusers import AutoencoderKL 7 | from einops import rearrange 8 | from torchvision.transforms import ( 9 | CenterCrop, 10 | ColorJitter, 11 | RandomCrop, 12 | Resize, 13 | Normalize, 14 | ) 15 | 16 | 17 | class ToTensor(nn.Module): 18 | """ 19 | Convert a batch of images from (B, H, W, C) to (B, C, H, W) 20 | and normalize the pixel values to the range [0, 1]. 21 | """ 22 | 23 | def forward(self, inputs: torch.Tensor): 24 | return inputs.permute((0, 3, 1, 2)).contiguous().float().div_(255.0) 25 | 26 | 27 | class AutoRandomCrop(nn.Module): 28 | """ 29 | Perform random cropping during training and center cropping during eval. 30 | """ 31 | 32 | def __init__(self, size: tuple[int, int]): 33 | super().__init__() 34 | self.size = size 35 | self.random_crop = RandomCrop(size=size) 36 | 37 | def forward(self, inputs: torch.Tensor): 38 | if self.training: 39 | return self.random_crop(inputs) 40 | else: 41 | # Take center crop during eval 42 | return ttf.center_crop(img=inputs, output_size=self.size) 43 | 44 | 45 | class AutoColorJitter(nn.Module): 46 | """ 47 | Perform color jittering during training and no-op during eval. 
48 | """ 49 | 50 | def __init__( 51 | self, 52 | brightness: float, 53 | contrast: float, 54 | saturation: float, 55 | hue: tuple[float], 56 | ): 57 | super().__init__() 58 | self.color_jitter = ColorJitter( 59 | brightness=brightness, 60 | contrast=contrast, 61 | saturation=saturation, 62 | hue=tuple(hue), 63 | ) 64 | 65 | def forward(self, inputs: torch.Tensor): 66 | if self.training: 67 | return self.color_jitter(inputs) 68 | else: 69 | return inputs # no-op during eval 70 | 71 | 72 | class VAEDownsample(nn.Module): 73 | """ 74 | Downsample images using a pre-trained VAE. 75 | """ 76 | 77 | def __init__(self): 78 | super().__init__() 79 | # Input normalization 80 | self.norm = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 81 | self.inv_norm = Normalize(mean=[-1, -1, -1], std=[2, 2, 2], inplace=True) 82 | 83 | # Normalization stats (computed from multitask 12) 84 | shift = torch.tensor([0.0, 0.0, 0.0, 0.0]).view(1, 4, 1, 1) 85 | scale = torch.tensor([3.0, 3.0, 3.0, 3.0]).view(1, 4, 1, 1) 86 | self.register_buffer("shift", shift) 87 | self.register_buffer("scale", scale) 88 | 89 | # Load pre-trained VAE 90 | self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") 91 | for p in self.vae.parameters(): 92 | p.requires_grad = False 93 | self.scaling_factor = self.vae.config.scaling_factor 94 | 95 | def forward(self, images: torch.Tensor): 96 | images = self.norm(images) 97 | feats = self.vae.encode(images).latent_dist.sample() 98 | feats = feats.mul_(self.scaling_factor) 99 | feats = feats.sub_(self.shift).div_(self.scale) 100 | return feats 101 | 102 | def inverse(self, feats: torch.Tensor): 103 | feats = feats.mul_(self.scale).add_(self.shift) 104 | feats = feats.div_(self.scaling_factor) 105 | images = self.vae.decode(feats).sample 106 | images = self.inv_norm(images) 107 | return images 108 | 109 | 110 | class ImageTransform(nn.Module): 111 | """ 112 | Apply a sequence of transforms to images. 
113 | """ 114 | 115 | def __init__( 116 | self, 117 | resize_shape: Optional[tuple[int, int]] = None, 118 | crop_shape: Optional[tuple[int, int]] = None, 119 | random_crop: bool = True, 120 | color_jitter: Optional[dict] = None, 121 | downsample: bool = False, 122 | imagenet_norm: bool = True, 123 | ): 124 | super().__init__() 125 | transform = list() 126 | 127 | # Convert image to tensor format 128 | transform.append(ToTensor()) 129 | 130 | # Resize images 131 | if resize_shape is not None: 132 | transform.append(Resize(resize_shape)) 133 | 134 | # Apply random crop during training and center crop during eval 135 | if crop_shape is not None: 136 | if random_crop: 137 | transform.append(AutoRandomCrop(crop_shape)) 138 | else: 139 | transform.append(CenterCrop(crop_shape)) 140 | 141 | # Apply color jitter during training 142 | if color_jitter is not None: 143 | transform.append( 144 | AutoColorJitter( 145 | brightness=color_jitter["brightness"], 146 | contrast=color_jitter["contrast"], 147 | saturation=color_jitter["saturation"], 148 | hue=tuple(color_jitter["hue"]), 149 | ) 150 | ) 151 | 152 | # Normalize using imagenet statistics 153 | if downsample: 154 | if imagenet_norm: 155 | print("Disabling imagenet normalization since downsample is enabled.") 156 | transform.append(VAEDownsample()) 157 | elif imagenet_norm: 158 | transform.append( 159 | Normalize( 160 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=True 161 | ) 162 | ) 163 | 164 | self.transform = nn.Sequential(*transform) 165 | 166 | @property 167 | def vae(self): 168 | assert isinstance(self.transform[-1], VAEDownsample) 169 | return self.transform[-1] 170 | 171 | def forward(self, images): 172 | return self.transform(images) 173 | 174 | 175 | class VideoTransform(ImageTransform): 176 | """ 177 | Flatten videos to images, apply transforms, and reshape back to videos. 
178 | """ 179 | 180 | def forward(self, images): 181 | num_frames = images.shape[1] 182 | images = rearrange(images, "b t h w c-> (b t) h w c") 183 | images = self.transform(images) 184 | images = rearrange(images, "(b t) c h w-> b c t h w", t=num_frames) 185 | return images 186 | -------------------------------------------------------------------------------- /datasets/robomimic/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import h5py 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from tqdm import tqdm 7 | 8 | from datasets.utils.buffer import CompressedTrajectoryBuffer 9 | from datasets.utils.file_utils import glob_all 10 | from datasets.utils.sampler import TrajectorySampler 11 | from datasets.utils.obs_utils import unflatten_obs 12 | 13 | 14 | class RobomimicDataset(Dataset): 15 | def __init__( 16 | self, 17 | name: str, 18 | hdf5_path_globs: str, 19 | buffer_path: str, 20 | shape_meta: dict, 21 | seq_len: int, 22 | val_ratio: float = 0.0, 23 | subsample_ratio: float = 1.0, 24 | flip_rgb: bool = False, 25 | ): 26 | self.name = name 27 | self.seq_len = seq_len 28 | self.flip_rgb = flip_rgb 29 | 30 | # Parse observation and action shapes 31 | obs_shape_meta = shape_meta["obs"] 32 | self._image_shapes = {} 33 | self._lowdim_shapes = {} 34 | for key, attr in obs_shape_meta.items(): 35 | obs_type = attr["type"] 36 | obs_shape = tuple(attr["shape"]) 37 | if obs_type == "rgb": 38 | self._image_shapes[key] = obs_shape 39 | elif obs_type == "low_dim": 40 | self._lowdim_shapes[key] = obs_shape 41 | else: 42 | raise RuntimeError(f"Unsupported obs type: {obs_type}") 43 | self._action_shape = tuple(shape_meta["action"]["shape"]) 44 | 45 | # Compressed buffer to store episode data 46 | self.buffer = self._init_buffer(hdf5_path_globs, buffer_path) 47 | 48 | # Create training-validation split 49 | num_episodes = self.buffer.num_episodes 50 | val_mask = np.zeros(num_episodes, dtype=bool) 51 | if val_ratio > 0: 52 | num_val_episodes = round(val_ratio * num_episodes) 53 | num_val_episodes = min(max(num_val_episodes, 1), num_episodes - 1) 54 | rng = np.random.default_rng(seed=0) 55 | val_inds = rng.choice(num_episodes, num_val_episodes, replace=False) 56 | val_mask[val_inds] = True 57 | self.val_mask = val_mask 58 | self.train_mask = ~val_mask 59 | 60 | # Apply subsample_ratio to training episodes 61 | if subsample_ratio < 1.0: 62 | train_indices = np.where(self.train_mask)[0] 63 | num_train_episodes = len(train_indices) 64 | num_subsampled = round(num_train_episodes * subsample_ratio) 65 | num_subsampled = max(1, num_subsampled) # Ensure at least one episode 66 | 67 | # Create a new mask with only the subsampled training episodes 68 | subsampled_train_mask = np.zeros(num_episodes, dtype=bool) 69 | rng = np.random.default_rng(seed=1) 70 | sampled_indices = rng.choice(train_indices, num_subsampled, replace=False) 71 | subsampled_train_mask[sampled_indices] = True 72 | self.train_mask = subsampled_train_mask 73 | 74 | # Sampler to draw sequences from buffer 75 | self.sampler = TrajectorySampler(self.buffer, self.seq_len, self.train_mask) 76 | 77 | def _init_buffer(self, hdf5_path_globs, buffer_path): 78 | hdf5_paths = glob_all(hdf5_path_globs) 79 | 80 | # Create metadata 81 | metadata = {} 82 | for key, shape in self._image_shapes.items(): 83 | metadata[f"obs.{key}"] = {"shape": shape, "dtype": np.uint8} 84 | for key, shape in self._lowdim_shapes.items(): 85 | metadata[f"obs.{key}"] = {"shape": shape, 
"dtype": np.float32} 86 | metadata["action"] = {"shape": self._action_shape, "dtype": np.float32} 87 | 88 | # Compute buffer capacity 89 | capacity = 0 90 | num_episodes = 0 91 | for hdf5_path in hdf5_paths: 92 | with h5py.File(hdf5_path) as f: 93 | demos = f["data"] 94 | for i in range(len(demos)): 95 | demo = demos[f"demo_{i}"] 96 | capacity += demo["actions"].shape[0] 97 | num_episodes += len(demos) 98 | 99 | # Initialize buffer 100 | buffer = CompressedTrajectoryBuffer( 101 | storage_path=buffer_path, 102 | metadata=metadata, 103 | capacity=capacity, 104 | ) 105 | 106 | # If buffer is restored from disk, return it 107 | if buffer.restored: 108 | return buffer 109 | 110 | # Otherwise, load episodes to buffer 111 | pbar = tqdm(total=num_episodes, desc="Loading episodes to buffer") 112 | for hdf5_path in hdf5_paths: 113 | with h5py.File(hdf5_path) as f: 114 | demos = f["data"] 115 | for i in range(len(demos)): 116 | demo = demos[f"demo_{i}"] 117 | episode = {} 118 | for key in self._image_shapes.keys(): 119 | if self.flip_rgb: 120 | episode[f"obs.{key}"] = demo["obs"][key][:][:, ::-1] 121 | else: 122 | episode[f"obs.{key}"] = demo["obs"][key][:] 123 | for key in self._lowdim_shapes.keys(): 124 | episode[f"obs.{key}"] = demo["obs"][key][:] 125 | episode["action"] = demo["actions"][:] 126 | buffer.add_episode(episode) 127 | pbar.update(1) 128 | pbar.close() 129 | return buffer 130 | 131 | def __len__(self) -> int: 132 | return len(self.sampler) 133 | 134 | def __repr__(self) -> str: 135 | return f"\nname: {self.name}\nnum_samples: {len(self)}\n{self.buffer}" 136 | 137 | def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: 138 | # Sample a sequence of observations and actions from the dataset. 139 | data = self.sampler.sample_sequence(idx) 140 | 141 | # Convert data to torch tensors 142 | data = {k: torch.from_numpy(v) for k, v in data.items()} 143 | 144 | # Unflatten observations 145 | data = unflatten_obs(data) 146 | return data 147 | 148 | def get_validation_dataset(self): 149 | val_set = copy.copy(self) 150 | val_set.train_mask = self.val_mask 151 | val_set.sampler = TrajectorySampler(self.buffer, self.seq_len, self.val_mask) 152 | return val_set 153 | -------------------------------------------------------------------------------- /models/common/adaln_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .attention import Attention, CrossAttention, MLP 4 | 5 | 6 | def modulate(x, shift, scale): 7 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 8 | 9 | 10 | class AdaLNAttentionBlock(nn.Module): 11 | """Multiheaded self-attention block with adaptive layer normalization modulation.""" 12 | 13 | def __init__( 14 | self, 15 | dim, 16 | cond_dim, 17 | num_heads=8, 18 | mlp_ratio=4.0, 19 | qkv_bias=False, 20 | drop=0.0, 21 | attn_drop=0.0, 22 | act=nn.GELU, 23 | norm=nn.LayerNorm, 24 | is_causal=False, 25 | causal_block=1, 26 | ): 27 | super().__init__() 28 | self.norm1 = norm(dim, elementwise_affine=False, eps=1e-6) 29 | self.attn = Attention( 30 | dim=dim, 31 | num_heads=num_heads, 32 | qkv_bias=qkv_bias, 33 | attn_drop=attn_drop, 34 | proj_drop=drop, 35 | is_causal=is_causal, 36 | causal_block=causal_block, 37 | ) 38 | self.norm2 = norm(dim, elementwise_affine=False, eps=1e-6) 39 | self.mlp = MLP( 40 | in_dim=dim, 41 | hidden_dim=int(dim * mlp_ratio), 42 | out_dim=dim, 43 | act=act, 44 | drop=drop, 45 | ) 46 | self.adaLN_modulation = nn.Sequential( 47 | nn.SiLU(), 48 | nn.Linear(cond_dim, 6 * 
dim), 49 | ) 50 | 51 | def forward(self, x, cond, pos_embed=None, attn_mask=None): 52 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 53 | self.adaLN_modulation(cond).chunk(6, dim=1) 54 | ) 55 | x = x + gate_msa.unsqueeze(1) * self.attn( 56 | modulate(self.norm1(x), shift_msa, scale_msa), pos_embed, attn_mask 57 | ) 58 | x = x + gate_mlp.unsqueeze(1) * self.mlp( 59 | modulate(self.norm2(x), shift_mlp, scale_mlp) 60 | ) 61 | return x 62 | 63 | 64 | class AdaLNCrossAttentionBlock(nn.Module): 65 | """Multiheaded cross-attention block with adaptive layer normalization modulation.""" 66 | 67 | def __init__( 68 | self, 69 | dim, 70 | cond_dim, 71 | num_heads=8, 72 | mlp_ratio=4.0, 73 | qkv_bias=False, 74 | drop=0.0, 75 | attn_drop=0.0, 76 | act=nn.GELU, 77 | norm=nn.LayerNorm, 78 | ): 79 | super().__init__() 80 | self.norm1 = norm(dim, elementwise_affine=False, eps=1e-6) 81 | self.xattn = CrossAttention( 82 | dim=dim, 83 | num_heads=num_heads, 84 | qkv_bias=qkv_bias, 85 | attn_drop=attn_drop, 86 | proj_drop=drop, 87 | ) 88 | self.norm2 = norm(dim, elementwise_affine=False, eps=1e-6) 89 | self.mlp = MLP( 90 | in_dim=dim, 91 | hidden_dim=int(dim * mlp_ratio), 92 | out_dim=dim, 93 | act=act, 94 | drop=drop, 95 | ) 96 | self.adaLN_modulation = nn.Sequential( 97 | nn.SiLU(), 98 | nn.Linear(cond_dim, 6 * dim), 99 | ) 100 | 101 | def forward(self, x, c, cond, x_pos_embed=None, c_pos_embed=None): 102 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 103 | self.adaLN_modulation(cond).chunk(6, dim=1) 104 | ) 105 | x = x + gate_msa.unsqueeze(1) * self.xattn( 106 | modulate(x, shift_msa, scale_msa), self.norm1(c), x_pos_embed, c_pos_embed 107 | ) 108 | x = x + gate_mlp.unsqueeze(1) * self.mlp( 109 | modulate(self.norm2(x), shift_mlp, scale_mlp) 110 | ) 111 | return x 112 | 113 | 114 | class AdaLNHybridAttentionBlock(nn.Module): 115 | """Multiheaded hybrid attention block with adaptive layer normalization modulation.""" 116 | 117 | def __init__( 118 | self, 119 | dim, 120 | cond_dim, 121 | num_heads=8, 122 | mlp_ratio=4.0, 123 | qkv_bias=False, 124 | drop=0.0, 125 | attn_drop=0.0, 126 | act=nn.GELU, 127 | norm=nn.LayerNorm, 128 | ): 129 | super().__init__() 130 | self.norm1 = norm(dim, elementwise_affine=False, eps=1e-6) 131 | self.attn = Attention( 132 | dim=dim, 133 | num_heads=num_heads, 134 | qkv_bias=qkv_bias, 135 | attn_drop=attn_drop, 136 | proj_drop=drop, 137 | ) 138 | self.norm2 = norm(dim, elementwise_affine=False, eps=1e-6) 139 | self.xattn = CrossAttention( 140 | dim=dim, 141 | num_heads=num_heads, 142 | qkv_bias=qkv_bias, 143 | attn_drop=attn_drop, 144 | proj_drop=drop, 145 | ) 146 | self.norm3 = norm(dim, elementwise_affine=False, eps=1e-6) 147 | self.mlp = MLP( 148 | in_dim=dim, 149 | hidden_dim=int(dim * mlp_ratio), 150 | out_dim=dim, 151 | act=act, 152 | drop=drop, 153 | ) 154 | self.adaLN_modulation = nn.Sequential( 155 | nn.SiLU(), 156 | nn.Linear(cond_dim, 6 * dim), 157 | ) 158 | 159 | def forward(self, x, c, cond, x_pos_embed=None, c_pos_embed=None): 160 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 161 | self.adaLN_modulation(cond).chunk(6, dim=1) 162 | ) 163 | x = x + gate_msa.unsqueeze(1) * self.attn( 164 | modulate(self.norm1(x), shift_msa, scale_msa), x_pos_embed 165 | ) 166 | x = x + self.xattn(self.norm2(x), c, x_pos_embed, c_pos_embed) 167 | x = x + gate_mlp.unsqueeze(1) * self.mlp( 168 | modulate(self.norm3(x), shift_mlp, scale_mlp) 169 | ) 170 | return x 171 | 172 | 173 | class AdaLNFinalLayer(nn.Module): 174 | def 
__init__(self, dim, cond_dim): 175 | super().__init__() 176 | self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 177 | self.linear = nn.Linear(dim, dim) 178 | self.adaLN_modulation = nn.Sequential( 179 | nn.SiLU(), 180 | nn.Linear(cond_dim, 2 * dim), 181 | ) 182 | 183 | def forward(self, x, cond): 184 | shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1) 185 | x = self.linear(modulate(self.norm(x), shift, scale)) 186 | return x 187 | -------------------------------------------------------------------------------- /experiments/dp/train_robomimic.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | import torch.nn.functional as F 6 | import wandb 7 | from diffusers.optimization import get_scheduler 8 | from omegaconf import OmegaConf 9 | from hydra.utils import instantiate 10 | from torch.nn.parallel import DistributedDataParallel 11 | from tqdm import trange, tqdm 12 | 13 | from datasets.utils.loader import make_distributed_data_loader 14 | from environments.robomimic import make_robomimic_env 15 | from experiments.dp.train import ( 16 | train_one_step, 17 | maybe_resume_checkpoint, 18 | maybe_evaluate, 19 | maybe_save_checkpoint, 20 | ) 21 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 22 | 23 | 24 | def collect_rollout(config, model, device): 25 | model.eval() 26 | model = getattr(model, "module", model) # unwrap DDP 27 | 28 | # Create eval environment 29 | assert isinstance(config.dataset.hdf5_path_globs, str) 30 | env = make_robomimic_env( 31 | dataset_name=config.dataset.name, 32 | dataset_path=config.dataset.hdf5_path_globs, 33 | shape_meta=config.dataset.shape_meta, 34 | obs_horizon=model.obs_encoder.num_frames, 35 | max_episode_length=config.rollout_length, 36 | record=True, 37 | ) 38 | 39 | # Collect rollouts 40 | successes = [] 41 | for e in trange( 42 | config.num_rollouts, desc="Collecting rollouts", disable=not is_main_process() 43 | ): 44 | env.seed(e) 45 | obs = env.reset() 46 | done = False 47 | while not done: 48 | obs_tensor = { 49 | k: torch.tensor(v, device=device)[None] for k, v in obs.items() 50 | } 51 | 52 | # Sample action from model 53 | action = model.sample(obs_tensor)[0].cpu().numpy() 54 | 55 | # Step environment 56 | obs, reward, done, info = env.step(action) 57 | successes.append(info["success"]) 58 | 59 | # Compute success rate 60 | success_rate = sum(successes) / len(successes) 61 | 62 | # Record video of the last episode 63 | video = env.get_video() 64 | return success_rate, video 65 | 66 | 67 | def maybe_collect_rollout(config, step, model, device): 68 | """Collect rollouts on the main process if it's the correct step.""" 69 | # Skip rollout rollection for pretraining 70 | if "libero_90" in config.dataset.name: 71 | return 72 | 73 | if is_main_process() and ( 74 | step % config.rollout_every == 0 or step == (config.num_steps - 1) 75 | ): 76 | success_rate, video = collect_rollout(config, model, device) 77 | print(f"Step: {step} success rate: {success_rate}") 78 | # Video shape: (T, H, W, C) -> (N, T, C, H, W) 79 | video = video.transpose(0, 3, 1, 2)[None] 80 | wandb.log( 81 | { 82 | "rollout/success_rate": success_rate, 83 | "rollout/video": wandb.Video(video, fps=10), 84 | } 85 | ) 86 | dist.barrier() 87 | 88 | 89 | def train(rank, world_size, config): 90 | # Set global seed 91 | set_seed(config.seed * world_size + rank) 92 | 93 | # Initialize 
distributed training 94 | init_distributed(rank, world_size) 95 | device = torch.device(f"cuda:{rank}") 96 | 97 | # Initialize WANDB 98 | if is_main_process(): 99 | init_wandb(config, job_type="train") 100 | 101 | # Create dataset 102 | train_set, val_set = instantiate(config.dataset) 103 | train_loader, val_loader = make_distributed_data_loader( 104 | train_set, val_set, config.batch_size, rank, world_size 105 | ) 106 | 107 | # Create model 108 | model = instantiate(config.model).to(device) 109 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 110 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 111 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 112 | 113 | # Load pretrained model 114 | if config.pretrain_checkpoint_path: 115 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 116 | model.load_state_dict(ckpt["model"]) 117 | print( 118 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}" 119 | ) 120 | 121 | # Resume from checkpoint 122 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 123 | epoch = step // len(train_loader) 124 | 125 | # Wrap model with DDP 126 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 127 | 128 | # Training loop 129 | pbar = tqdm( 130 | total=config.num_steps, 131 | initial=step, 132 | desc="Training", 133 | disable=not is_main_process(), 134 | ) 135 | while step < config.num_steps: 136 | # Set epoch for distributed sampler to shuffle indices 137 | train_loader.sampler.set_epoch(epoch) 138 | 139 | # Train for one epoch 140 | for batch in train_loader: 141 | # --- Training step --- 142 | loss, info = train_one_step( 143 | config, model, optimizer, scheduler, scaler, batch, device 144 | ) 145 | 146 | # --- Logging --- 147 | if is_main_process(): 148 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 149 | wandb.log({f"train/{k}": v for k, v in info.items()}) 150 | 151 | # --- Evaluate if needed --- 152 | maybe_evaluate(config, step, model, val_loader, device) 153 | 154 | # ---Collect environment rollouts if needed --- 155 | maybe_collect_rollout(config, step, model, device) 156 | 157 | # --- Save checkpoint if needed --- 158 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler) 159 | 160 | step += 1 161 | pbar.update(1) 162 | if step >= config.num_steps: 163 | break 164 | 165 | epoch += 1 166 | 167 | 168 | @hydra.main( 169 | version_base=None, 170 | config_path="../../configs", 171 | config_name="train_dp_robomimic.yaml", 172 | ) 173 | def main(config): 174 | # Resolve hydra config 175 | OmegaConf.resolve(config) 176 | # Spawn processes 177 | world_size = torch.cuda.device_count() 178 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /datasets/webvideo/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import warnings 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from decord import VideoReader, cpu 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms.functional import resize, center_crop 11 | 12 | 13 | def compute_resize_shape(original_size: tuple[int, int], target_size: tuple[int, int]): 14 | """ 15 | This method calculates the dimensions (height and width) needed to resize an image so that it 
16 | minimally bounds the target dimensions defined in `target_size` while preserving the original 17 | aspect ratio. 18 | 19 | Args: 20 | original_size (tuple): The original height and width of the image. 21 | target_size (tuple): The target height and width of the image. 22 | 23 | Returns: 24 | tuple: A tuple representing the computed height and width for resizing. 25 | 26 | """ 27 | h, w = original_size 28 | target_h, target_w = target_size 29 | scale = max(target_h / h, target_w / w) 30 | new_h = math.ceil(h * scale) 31 | new_w = math.ceil(w * scale) 32 | return (new_h, new_w) 33 | 34 | 35 | class MultiviewVideoDataset(Dataset): 36 | def __init__( 37 | self, 38 | index_paths: list[str], 39 | shape_meta: dict, 40 | clip_len: int, 41 | frame_skip: int = 2, 42 | obs_padding: str = "same", 43 | ): 44 | """A video dataset that augments single-view videos to multi-view videos 45 | by padding missing observations. 46 | 47 | Args: 48 | index_paths: list of paths to index files each containing paths to video files. 49 | shape_meta: dictionary containing metadata for observations and actions. 50 | clip_len: length of each clip in frames. 51 | frame_skip: number of frames to skip between frames. 52 | obs_padding: padding method for observations, chosen from ["none", "same", "random"] 53 | """ 54 | self.index_paths = index_paths 55 | self.clip_len = clip_len 56 | self.frame_skip = frame_skip 57 | self.obs_padding = obs_padding 58 | 59 | self.image_shapes = {} 60 | self.lowdim_shapes = {} 61 | for key, attr in shape_meta["obs"].items(): 62 | obs_type, obs_shape = attr["type"], tuple(attr["shape"]) 63 | if obs_type == "rgb": 64 | self.image_shapes[key] = obs_shape 65 | elif obs_type == "low_dim": 66 | self.lowdim_shapes[key] = obs_shape 67 | else: 68 | raise RuntimeError(f"Unsupported obs type: {obs_type}") 69 | self.action_shape = shape_meta["action"]["shape"] 70 | 71 | # Check that all image observations have the same shape 72 | assert ( 73 | len(set(self.image_shapes.values())) == 1 74 | ), "Image observations must have the same shape" 75 | self.image_size = next(iter(self.image_shapes.values()))[:2] 76 | self.image_keys = list(self.image_shapes.keys()) 77 | 78 | self.samples = [] 79 | for index_path in index_paths: 80 | # Each line in a data file is a path to a video clip 81 | with open(index_path, "r") as f: 82 | video_paths = f.read().splitlines() 83 | self.samples.extend(video_paths) 84 | 85 | def __len__(self): 86 | return len(self.samples) 87 | 88 | def __getitem__(self, index): 89 | # Keep trying to load video until successful 90 | clip = None 91 | while clip is None: 92 | clip = self.load_video(self.samples[index]) 93 | if clip is None: 94 | index = np.random.randint(self.__len__()) 95 | 96 | # Construct image observations based on padding method 97 | obs_dict = {} 98 | if self.obs_padding == "none": 99 | main_key = np.random.choice(self.image_keys) 100 | obs_dict[main_key] = clip 101 | for key in [k for k in self.image_keys if k != main_key]: 102 | obs_dict[key] = torch.zeros_like(clip) 103 | elif self.obs_padding == "same": 104 | for key in self.image_keys: 105 | obs_dict[key] = clip 106 | elif self.obs_padding == "random": 107 | main_key = np.random.choice(self.image_keys) 108 | obs_dict[main_key] = clip 109 | for key in [k for k in self.image_keys if k != main_key]: 110 | rand_clip = None 111 | while rand_clip is None: 112 | rand_index = np.random.randint(self.__len__()) 113 | rand_clip = self.load_video(self.samples[rand_index]) 114 | obs_dict[key] = rand_clip 115 | else: 116 | raise 
ValueError(f"Invalid padding method {self.obs_padding}") 117 | 118 | # Zero pad lowdim observations 119 | for key in self.lowdim_shapes.keys(): 120 | obs_dict[key] = torch.zeros(self.clip_len, *self.lowdim_shapes[key]) 121 | 122 | # Construct sample 123 | sample = { 124 | "obs": obs_dict, 125 | "action": torch.zeros(self.clip_len, *self.action_shape), 126 | "action_mask": torch.tensor(0, dtype=torch.bool), 127 | } 128 | return sample 129 | 130 | def load_video(self, fname: str): 131 | if not os.path.exists(fname): 132 | warnings.warn(f"video path not found {fname}") 133 | return None 134 | 135 | # Skip short or long videos 136 | fsize = os.path.getsize(fname) 137 | if fsize < 1 * 1024 or fsize > int(10**9): 138 | warnings.warn(f"video size {fsize} out of bounds {fname}") 139 | return None 140 | 141 | # Try loading video 142 | try: 143 | vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) 144 | except Exception: 145 | return None 146 | 147 | # Compute full clip length 148 | full_clip_len = int(self.clip_len * self.frame_skip) 149 | 150 | # Filter videos shorter than a single clip 151 | if len(vr) < full_clip_len: 152 | warnings.warn(f"video length {len(vr)} shorter than a single clip {fname}") 153 | return None 154 | 155 | # Sample random clip from video 156 | start_indx = np.random.randint(0, len(vr) - full_clip_len + 1) 157 | end_indx = start_indx + full_clip_len 158 | indices = np.linspace(start_indx, end_indx, self.clip_len) 159 | indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64) 160 | 161 | # Load clip 162 | vr.seek(0) 163 | clip = vr.get_batch(indices).asnumpy() 164 | 165 | # Postprocess video 166 | clip = torch.from_numpy(clip) 167 | clip = clip.permute(0, 3, 1, 2) # (T, C, H, W) 168 | clip = resize(clip, compute_resize_shape(clip.shape[2:], self.image_size)) 169 | clip = center_crop(clip, self.image_size) 170 | return clip.permute(0, 2, 3, 1) # (T, H, W, C) 171 | -------------------------------------------------------------------------------- /experiments/uwm/eval_droid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | 4 | import hydra 5 | import torch 6 | import numpy as np 7 | 8 | from droid.controllers.oculus_controller import VRPolicy 9 | from droid.data_processing.timestep_processing import TimestepProcesser 10 | from droid.robot_env import RobotEnv 11 | from droid.user_interface.data_collector import DataCollecter 12 | from droid.user_interface.gui import RobotGUI 13 | 14 | from datasets.droid.utils import rot_6d_to_euler_angles 15 | 16 | 17 | class DROIDAgent: 18 | def __init__( 19 | self, 20 | model, 21 | device, 22 | img_keys, 23 | action_normalizer, 24 | lowdim_normalizer, 25 | obs_horizon=2, 26 | action_horizon=8, 27 | ): 28 | self.model = model 29 | self.device = device 30 | self.img_keys = img_keys 31 | self.action_normalizer = action_normalizer 32 | self.lowdim_normalizer = lowdim_normalizer 33 | self.obs_horizon = obs_horizon 34 | self.action_horizon = action_horizon 35 | 36 | assert obs_horizon == model.obs_encoder.num_frames 37 | assert action_horizon <= model.action_len 38 | 39 | self.action_scale = action_normalizer.scale[None] 40 | self.action_offset = action_normalizer.offset[None] 41 | 42 | # Observation buffer 43 | self.obs_buffer = deque(maxlen=obs_horizon) 44 | self.act_buffer = deque(maxlen=action_horizon) 45 | 46 | # Timestep processor 47 | self.timestep_processor = TimestepProcesser( 48 | ignore_action=True, 49 | action_space="cartesian_position", 50 
| gripper_action_space="position", 51 | robot_state_keys=[ 52 | "cartesian_position", 53 | "gripper_position", 54 | "joint_positions", 55 | ], 56 | image_transform_kwargs=dict( 57 | remove_alpha=True, 58 | bgr_to_rgb=True, 59 | to_tensor=False, 60 | augment=False, 61 | ), 62 | ) 63 | 64 | def _convert_obs(self, observation): 65 | # Process camera images 66 | timestep = {"observation": observation} 67 | processed_timestep = self.timestep_processor.forward(timestep) 68 | camera_images = processed_timestep["observation"]["camera"]["image"] 69 | 70 | # Extract observations 71 | obs = { 72 | "exterior_image_1_left": camera_images["varied_camera"][0], 73 | "exterior_image_2_left": camera_images["varied_camera"][2], 74 | "wrist_image_left": camera_images["hand_camera"][0], 75 | "cartesian_position": observation["robot_state"]["cartesian_position"], 76 | "gripper_position": np.array( 77 | [observation["robot_state"]["gripper_position"]] 78 | ), 79 | } 80 | 81 | # Convert image observations to torch tensors 82 | for key in self.img_keys: 83 | obs[key] = torch.from_numpy(obs[key][None]).to(self.device) 84 | 85 | # Normalize low-dimensional observations and convert to torch tensors 86 | for key in self.lowdim_normalizer.keys(): 87 | lowdim_obs = self.lowdim_normalizer[key](obs[key]) 88 | obs[key] = torch.from_numpy(lowdim_obs[None]).float().to(self.device) 89 | 90 | return obs 91 | 92 | def _convert_action(self, action): 93 | xyz, rot6d, grippers = action[:3], action[3:9], action[9:] 94 | euler = rot_6d_to_euler_angles(torch.tensor(rot6d)).numpy() 95 | return np.concatenate([xyz, euler, grippers]) 96 | 97 | @torch.no_grad() 98 | def forward(self, observation): 99 | # Encode observations 100 | obs = self._convert_obs(observation) 101 | 102 | # Update observation buffer 103 | if len(self.obs_buffer) == 0: 104 | # Pad observation buffer if empty (only after reset) 105 | for _ in range(self.obs_horizon): 106 | self.obs_buffer.append(obs) 107 | else: 108 | self.obs_buffer.append(obs) 109 | 110 | # Update action buffer 111 | if len(self.act_buffer) == 0: 112 | # Stack observations by key 113 | obs_seq = {} 114 | for key in obs.keys(): 115 | obs_seq[key] = torch.stack([obs[key] for obs in self.obs_buffer], dim=1) 116 | 117 | # Sample actions 118 | act_seq = self.model.sample(obs_seq) 119 | act_seq = act_seq[0].cpu().numpy() 120 | act_seq = act_seq * self.action_scale + self.action_offset 121 | 122 | # Store new actions in buffer 123 | for t in range(self.action_horizon): 124 | self.act_buffer.append(self._convert_action(act_seq[t])) 125 | 126 | # Return next action 127 | action = self.act_buffer.popleft() 128 | # Clip action 129 | action = np.clip(action, -1, 1) 130 | return action 131 | 132 | def reset(self): 133 | self.obs_buffer.clear() 134 | self.act_buffer.clear() 135 | 136 | 137 | @hydra.main( 138 | version_base=None, config_path="../../configs", config_name="train_uwm.yaml" 139 | ) 140 | def main(config): 141 | device = torch.device(f"cuda:0") 142 | 143 | # Create model 144 | model = hydra.utils.instantiate(config.model).to(device) 145 | model.eval() 146 | 147 | # Load models 148 | ckpt = torch.load(os.path.join(config.logdir, "models.pt"), map_location="cpu") 149 | model.load_state_dict(ckpt["model"]) 150 | print(f"Loaded models from pretraining checkpoint, step: {ckpt['step']}") 151 | 152 | # Create agent 153 | img_keys = [ 154 | k for k, v in config.dataset.shape_meta["obs"].items() if v["type"] == "rgb" 155 | ] 156 | agent = DROIDAgent( 157 | model=model, 158 | device=device, 159 | 
img_keys=img_keys, 160 | action_normalizer=ckpt["action_normalizer"], 161 | lowdim_normalizer=ckpt["lowdim_normalizer"], 162 | obs_horizon=config.model.obs_encoder.num_frames, 163 | action_horizon=config.model.action_len // 2, 164 | ) 165 | 166 | # Create evaluation environment 167 | h, w = tuple(config.dataset.shape_meta["obs"][img_keys[0]]["shape"][:2]) 168 | img_size = (w, h) # flip width and height 169 | env = RobotEnv( 170 | action_space="cartesian_velocity", 171 | gripper_action_space="position", 172 | camera_kwargs=dict( 173 | hand_camera=dict( 174 | image=True, 175 | concatenate_images=False, 176 | resolution=img_size, 177 | resize_func="cv2", 178 | ), 179 | varied_camera=dict( 180 | image=True, 181 | concatenate_images=False, 182 | resolution=img_size, 183 | resize_func="cv2", 184 | ), 185 | ), 186 | ) 187 | controller = VRPolicy() 188 | 189 | # Launch GUI 190 | data_collector = DataCollecter( 191 | env=env, 192 | controller=controller, 193 | policy=agent, 194 | save_traj_dir=os.path.join(config.logdir, "videos"), 195 | save_data=True, 196 | ) 197 | RobotGUI(robot=data_collector) 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /experiments/uwm/train_robomimic.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | import wandb 6 | from diffusers.optimization import get_scheduler 7 | from hydra.utils import instantiate 8 | from omegaconf import OmegaConf 9 | from torch.nn.parallel import DistributedDataParallel 10 | from tqdm import trange, tqdm 11 | 12 | from datasets.utils.loader import make_distributed_data_loader 13 | from environments.robomimic import make_robomimic_env 14 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 15 | from experiments.uwm.train import ( 16 | train_one_step, 17 | maybe_resume_checkpoint, 18 | maybe_evaluate, 19 | maybe_save_checkpoint, 20 | ) 21 | 22 | 23 | def collect_rollout(config, model, device): 24 | model.eval() 25 | model = getattr(model, "module", model) # unwrap DDP 26 | 27 | # Create eval environment 28 | assert isinstance(config.dataset.hdf5_path_globs, str) 29 | env = make_robomimic_env( 30 | dataset_name=config.dataset.name, 31 | dataset_path=config.dataset.hdf5_path_globs, 32 | shape_meta=config.dataset.shape_meta, 33 | obs_horizon=model.obs_encoder.num_frames, 34 | max_episode_length=config.rollout_length, 35 | record=True, 36 | ) 37 | 38 | # Collect rollouts 39 | successes = [] 40 | for e in trange( 41 | config.num_rollouts, desc="Collecting rollouts", disable=not is_main_process() 42 | ): 43 | env.seed(e) 44 | obs = env.reset() 45 | done = False 46 | while not done: 47 | obs_tensor = { 48 | k: torch.tensor(v, device=device)[None] for k, v in obs.items() 49 | } 50 | 51 | # Sample action from model 52 | action = model.sample(obs_tensor)[0].cpu().numpy() 53 | 54 | # Step environment 55 | obs, reward, done, info = env.step(action) 56 | successes.append(info["success"]) 57 | 58 | # Compute success rate 59 | success_rate = sum(successes) / len(successes) 60 | 61 | # Record video of the last episode 62 | video = env.get_video() 63 | return success_rate, video 64 | 65 | 66 | def maybe_collect_rollout(config, step, model, device): 67 | """Collect rollouts on the main process if it's the correct step.""" 68 | # Skip rollout rollection for pretraining 69 | if "libero_90" in 
config.dataset.name: 70 | return 71 | 72 | if is_main_process() and ( 73 | step % config.rollout_every == 0 or step == (config.num_steps - 1) 74 | ): 75 | success_rate, video = collect_rollout(config, model, device) 76 | print(f"Step: {step} success rate: {success_rate}") 77 | # Video shape: (T, H, W, C) -> (N, T, C, H, W) 78 | video = video.transpose(0, 3, 1, 2)[None] 79 | wandb.log( 80 | { 81 | "rollout/success_rate": success_rate, 82 | "rollout/video": wandb.Video(video, fps=10), 83 | } 84 | ) 85 | dist.barrier() 86 | 87 | 88 | def train(rank, world_size, config): 89 | # Set global seed 90 | set_seed(config.seed * world_size + rank) 91 | 92 | # Initialize distributed training 93 | init_distributed(rank, world_size) 94 | device = torch.device(f"cuda:{rank}") 95 | 96 | # Initialize WANDB 97 | if is_main_process(): 98 | init_wandb(config, job_type="train") 99 | 100 | # Create dataset and loader 101 | train_set, val_set = instantiate(config.dataset) 102 | train_loader, val_loader = make_distributed_data_loader( 103 | train_set, val_set, config.batch_size, rank, world_size 104 | ) 105 | 106 | # Create model 107 | model = instantiate(config.model).to(device) 108 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 109 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 110 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 111 | 112 | # Load pretrained model 113 | if config.pretrain_checkpoint_path: 114 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 115 | model.load_state_dict(ckpt["model"]) 116 | print( 117 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}" 118 | ) 119 | 120 | # Use a smaller learning rate for the image encoder 121 | encoder_params, other_params = [], [] 122 | for name, param in model.named_parameters(): 123 | if not param.requires_grad: 124 | continue 125 | if name.startswith("obs_encoder.img_encoder"): 126 | encoder_params.append(param) 127 | else: 128 | other_params.append(param) 129 | 130 | # Define learning rates 131 | base_lr = config.optimizer.lr 132 | encoder_lr = base_lr * 0.05 133 | wd = config.optimizer.weight_decay 134 | 135 | # Construct optimizer with custom parameter groups 136 | optimizer = torch.optim.AdamW( 137 | [ 138 | {"params": encoder_params, "lr": encoder_lr, "weight_decay": wd}, 139 | {"params": other_params, "lr": base_lr, "weight_decay": wd}, 140 | ], 141 | betas=config.optimizer.betas, 142 | eps=config.optimizer.eps, 143 | ) 144 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 145 | 146 | # Resume from checkpoint 147 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 148 | epoch = step // len(train_loader) 149 | 150 | # Wrap model with DDP 151 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 152 | 153 | # Training loop 154 | pbar = tqdm( 155 | total=config.num_steps, 156 | initial=step, 157 | desc="Training", 158 | disable=not is_main_process(), 159 | ) 160 | while step < config.num_steps: 161 | # Set epoch for distributed sampler to shuffle indices 162 | train_loader.sampler.set_epoch(epoch) 163 | 164 | # Train for one epoch 165 | for batch in train_loader: 166 | # --- Training step --- 167 | loss, info = train_one_step( 168 | config, model, optimizer, scheduler, scaler, batch, device 169 | ) 170 | 171 | # --- Logging --- 172 | if is_main_process(): 173 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 174 | wandb.log({f"train/{k}": v for k, v in 
info.items()}) 175 | 176 | # --- Evaluate if needed --- 177 | maybe_evaluate(config, step, model, val_loader, device) 178 | 179 | # ---Collect environment rollouts if needed --- 180 | maybe_collect_rollout(config, step, model, device) 181 | 182 | # --- Save checkpoint if needed --- 183 | maybe_save_checkpoint(config, step, model, optimizer, scheduler, scaler) 184 | 185 | step += 1 186 | pbar.update(1) 187 | if step >= config.num_steps: 188 | break 189 | 190 | epoch += 1 191 | 192 | 193 | @hydra.main( 194 | version_base=None, 195 | config_path="../../configs", 196 | config_name="train_uwm_robomimic.yaml", 197 | ) 198 | def main(config): 199 | # Resolve hydra config 200 | OmegaConf.resolve(config) 201 | # Spawn processes 202 | world_size = torch.cuda.device_count() 203 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /experiments/gr1/train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | import torch.nn.functional as F 6 | import wandb 7 | from diffusers.optimization import get_scheduler 8 | from einops import rearrange 9 | from hydra.utils import instantiate 10 | from omegaconf import OmegaConf 11 | from torch.nn.parallel import DistributedDataParallel 12 | from torchvision.utils import make_grid 13 | from tqdm import tqdm 14 | 15 | from datasets.utils.loader import make_distributed_data_loader 16 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 17 | from experiments.uwm.train import ( 18 | process_batch, 19 | train_one_step, 20 | maybe_resume_checkpoint, 21 | maybe_save_checkpoint, 22 | ) 23 | 24 | 25 | def eval_one_epoch(config, data_loader, device, model, action_normalizer=None): 26 | model.eval() 27 | model = getattr(model, "module", model) # unwrap DDP 28 | 29 | # Unnormalize actions 30 | if action_normalizer is not None: 31 | action_scale = torch.tensor(action_normalizer.scale[None], device=device) 32 | action_offset = torch.tensor(action_normalizer.offset[None], device=device) 33 | unnormalize = lambda a: a * action_scale + action_offset 34 | else: 35 | unnormalize = lambda a: a 36 | 37 | def plot(images, nrows=4): 38 | images = rearrange(images, "b v c t h w -> (b v t) c h w") 39 | images_grid = make_grid(images, nrows) 40 | return images_grid 41 | 42 | stats = { 43 | "loss": 0, 44 | "action_loss": 0, 45 | "dynamics_loss": 0, 46 | "action_mse_joint": 0, 47 | "image_mse_joint": 0, 48 | } 49 | for batch in tqdm(data_loader, desc="Evaluating", disable=not is_main_process()): 50 | # ------------ Preprocess data ------------ # 51 | curr_obs_dict, next_obs_dict, action_norm = process_batch( 52 | batch, config.model.obs_encoder.num_frames, config.model.action_len, device 53 | ) 54 | 55 | with torch.no_grad(): 56 | # ------------ Validation loss ------------ # 57 | _, info = model(curr_obs_dict, next_obs_dict, action_norm) 58 | for k, v in info.items(): 59 | stats[k] += v 60 | 61 | # ------------ UWM Inference ------------ # 62 | action = unnormalize(action_norm) 63 | 64 | # Encode next observations 65 | next_obs = model.obs_encoder.apply_transform(next_obs_dict) 66 | 67 | # Sample next obs and actions from joint distribution 68 | next_obs_hat_joint, action_hat_joint = model.sample_joint(curr_obs_dict) 69 | joint_image_mse = 
F.mse_loss(next_obs_hat_joint, next_obs) 70 | dist.all_reduce(joint_image_mse, op=dist.ReduceOp.AVG) 71 | stats["image_mse_joint"] += joint_image_mse 72 | joint_action_mse = F.mse_loss(unnormalize(action_hat_joint), action) 73 | dist.all_reduce(joint_action_mse, op=dist.ReduceOp.AVG) 74 | stats["action_mse_joint"] += joint_action_mse 75 | 76 | # Average over all batches 77 | stats = {k: v / len(data_loader) for k, v in stats.items()} 78 | 79 | # Plot reconstruction 80 | stats["images"] = wandb.Image(plot(next_obs[0:1])) 81 | stats["images_joint"] = wandb.Image(plot(next_obs_hat_joint[0:1])) 82 | return stats 83 | 84 | 85 | def maybe_evaluate(config, step, model, loader, device, action_normalizer=None): 86 | """Evaluate if it's the correct step.""" 87 | if step % config.eval_every == 0: 88 | stats = eval_one_epoch(config, loader, device, model, action_normalizer) 89 | if is_main_process(): 90 | wandb.log({f"eval/{k}": v for k, v in stats.items()}) 91 | print(f"Step {step} action mse: {stats['action_mse_joint']:.4f}") 92 | 93 | 94 | def train(rank, world_size, config): 95 | # Set global seed 96 | set_seed(config.seed * world_size + rank) 97 | 98 | # Initialize distributed training 99 | init_distributed(rank, world_size) 100 | device = torch.device(f"cuda:{rank}") 101 | 102 | # Initialize WANDB 103 | if is_main_process(): 104 | init_wandb(config, job_type="train") 105 | 106 | # Create dataset 107 | train_set, val_set = instantiate(config.dataset) 108 | train_loader, val_loader = make_distributed_data_loader( 109 | train_set, val_set, config.batch_size, rank, world_size 110 | ) 111 | 112 | # Create model 113 | model = instantiate(config.model).to(device) 114 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 115 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 116 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 117 | 118 | # Load pretrained model 119 | if config.pretrain_checkpoint_path: 120 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 121 | model.load_state_dict(ckpt["model"]) 122 | print( 123 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}" 124 | ) 125 | 126 | # Replace dataset normalizers to make sure data is normalized correctly 127 | if ckpt["action_normalizer"] is not None: 128 | train_set.action_normalizer = ckpt["action_normalizer"] 129 | val_set.action_normalizer = ckpt["action_normalizer"] 130 | if ckpt["lowdim_normalizer"] is not None: 131 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 132 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 133 | 134 | # Resume from checkpoint 135 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 136 | epoch = step // len(train_loader) 137 | 138 | # Wrap model with DDP 139 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 140 | 141 | # Training loop 142 | pbar = tqdm( 143 | total=config.num_steps, 144 | initial=step, 145 | desc="Training", 146 | disable=not is_main_process(), 147 | ) 148 | while step < config.num_steps: 149 | # Set epoch for distributed sampler to shuffle indices 150 | train_loader.sampler.set_epoch(epoch) 151 | 152 | for batch in train_loader: 153 | # --- Training step --- 154 | loss, info = train_one_step( 155 | config, model, optimizer, scheduler, scaler, batch, device 156 | ) 157 | 158 | # --- Logging --- 159 | if is_main_process(): 160 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 161 | wandb.log({f"train/{k}": v for 
k, v in info.items()}) 162 | 163 | # --- Evaluate if needed --- 164 | maybe_evaluate( 165 | config, step, model, val_loader, device, train_set.action_normalizer 166 | ) 167 | 168 | # --- Save checkpoint if needed --- 169 | maybe_save_checkpoint( 170 | config, 171 | step, 172 | model, 173 | optimizer, 174 | scheduler, 175 | scaler, 176 | train_set.action_normalizer, 177 | train_set.lowdim_normalizer, 178 | ) 179 | 180 | step += 1 181 | pbar.update(1) 182 | if step >= config.num_steps: 183 | break 184 | 185 | epoch += 1 186 | 187 | 188 | @hydra.main( 189 | version_base=None, config_path="../../configs", config_name="train_gr1.yaml" 190 | ) 191 | def main(config): 192 | # Resolve hydra config 193 | OmegaConf.resolve(config) 194 | # Spawn processes 195 | world_size = torch.cuda.device_count() 196 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 197 | 198 | 199 | if __name__ == "__main__": 200 | main() 201 | -------------------------------------------------------------------------------- /datasets/droid/convert_dataset_zarr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | import tensorflow_datasets as tfds 10 | 11 | from datasets.utils.buffer import CompressedTrajectoryBuffer 12 | from datasets.droid.utils import euler_angles_to_rot_6d 13 | 14 | # Disable GPU otherwise jax allocates lots of memory 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 16 | 17 | 18 | # Shape metadata for zarr dataset 19 | shape_meta = { 20 | "obs": { 21 | "exterior_image_1_left": { 22 | "shape": (180, 320, 3), 23 | "type": "rgb", 24 | }, 25 | "exterior_image_2_left": { 26 | "shape": (180, 320, 3), 27 | "type": "rgb", 28 | }, 29 | "wrist_image_left": { 30 | "shape": (180, 320, 3), 31 | "type": "rgb", 32 | }, 33 | "cartesian_position": { 34 | "shape": (6,), 35 | "type": "low_dim", 36 | }, 37 | "gripper_position": { 38 | "shape": (1,), 39 | "type": "low_dim", 40 | }, 41 | }, 42 | "action": { 43 | "shape": (10,), 44 | }, 45 | } 46 | 47 | rgb_keys = [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"] 48 | lowdim_keys = [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"] 49 | 50 | 51 | class TruncatedDataset: 52 | def __init__(self, dataset, num_episodes, filter_key=None, except_key=None): 53 | # Generate up to num_episodes episodes that contain filter_key 54 | # and doesn't contain except_key in their folderpath 55 | self.dataset = dataset 56 | self.num_episodes = num_episodes 57 | if filter_key and except_key: 58 | assert ( 59 | filter_key != except_key 60 | ), "filter_key and except_key cannot be the same" 61 | self.filter_key = filter_key 62 | self.except_key = except_key 63 | 64 | def __len__(self): 65 | return self.num_episodes 66 | 67 | def __iter__(self): 68 | episode_count = 0 69 | for episode in self.dataset: 70 | folderpath = ( 71 | episode["episode_metadata"]["recording_folderpath"] 72 | .numpy() 73 | .decode("UTF-8") 74 | ) 75 | if self.filter_key and self.filter_key not in folderpath: 76 | continue 77 | if self.except_key and self.except_key in folderpath: 78 | continue 79 | yield episode 80 | episode_count += 1 81 | if episode_count >= self.num_episodes: 82 | break 83 | 84 | 85 | def _load_episode(episode): 86 | obs_dict = {k: [] for k in shape_meta["obs"].keys()} 87 | actions = [] 88 | for step in episode: 89 | obs = step["observation"] 90 | # Load image observations 91 | for k 
in rgb_keys: 92 | obs_dict[k].append(obs[k].numpy()) 93 | # Load low-dimensional observations 94 | obs_dict["cartesian_position"].append(obs["cartesian_position"].numpy()) 95 | obs_dict["gripper_position"].append(obs["gripper_position"].numpy()) 96 | 97 | # Convert and load actions 98 | action_dict = step["action_dict"] 99 | cartesian_velocity = action_dict["cartesian_velocity"].numpy() 100 | xyz, euler = cartesian_velocity[:3], cartesian_velocity[3:6] 101 | rot6d = euler_angles_to_rot_6d(torch.tensor(euler)).numpy() 102 | grippers = action_dict["gripper_position"].numpy() 103 | action = np.concatenate([xyz, rot6d, grippers]) 104 | actions.append(action) 105 | 106 | episode_dict = {"obs." + k: np.stack(v) for k, v in obs_dict.items()} 107 | episode_dict["action"] = np.stack(actions) 108 | return episode_dict 109 | 110 | 111 | def preprocess_episode(episode): 112 | """Preprocess episode by removing variant tensors and standardizing data types.""" 113 | processed_episode = [] 114 | for step in episode["steps"]: 115 | processed_step = { 116 | "observation": step["observation"], 117 | "action_dict": step["action_dict"], 118 | } 119 | processed_episode.append(processed_step) 120 | return processed_episode 121 | 122 | 123 | def episode_loader(queue, buffer_args): 124 | """Process episodes from input queue and put results in output queue.""" 125 | pid = mp.current_process().name 126 | print(f"Starting episode loader process {pid}") 127 | 128 | # Initialize buffer 129 | buffer = CompressedTrajectoryBuffer(**buffer_args) 130 | 131 | while True: 132 | episode = queue.get() 133 | if episode is None: 134 | print(f"Episode loader {pid} received a termination signal") 135 | break 136 | else: 137 | print(f"Episode loader {pid} received an episode") 138 | buffer.add_episode(_load_episode(episode)) 139 | 140 | 141 | def main(args): 142 | # Load dataset 143 | raw_dataset = tfds.load(args.data_name, data_dir=args.data_dir, split="train") 144 | dataset = TruncatedDataset( 145 | raw_dataset, args.num_episodes, args.filter_key, args.except_key 146 | ) 147 | 148 | # Create metadata 149 | metadata = {} 150 | for key, meta in shape_meta["obs"].items(): 151 | metadata[f"obs.{key}"] = { 152 | "shape": meta["shape"], 153 | "dtype": np.uint8 if meta["type"] == "rgb" else np.float32, 154 | } 155 | metadata["action"] = {"shape": shape_meta["action"]["shape"], "dtype": np.float32} 156 | 157 | # Compute buffer capacity 158 | capacity = sum([len(episode["steps"]) for episode in dataset]) 159 | print(f"Buffer capacity: {capacity}") 160 | 161 | # Create temporary buffer to check if buffer is restored 162 | buffer = CompressedTrajectoryBuffer( 163 | storage_path=args.buffer_path, 164 | metadata=metadata, 165 | capacity=capacity, 166 | ) 167 | if buffer.restored: 168 | print("Buffer restored from disk") 169 | return 170 | 171 | # Multiprocessing setup 172 | context = mp.get_context("spawn") 173 | queue = context.Queue(maxsize=args.num_workers * 2) 174 | lock = context.Lock() # share lock across processes 175 | buffer_args = { 176 | "storage_path": args.buffer_path, 177 | "metadata": metadata, 178 | "capacity": capacity, 179 | "lock": lock, 180 | } 181 | 182 | # Start episode loader processes 183 | episode_loader_processes = [] 184 | for i in range(args.num_workers): 185 | p = context.Process( 186 | target=episode_loader, 187 | args=(queue, buffer_args), 188 | name=f"EpisodeLoaderProcess-{i}", 189 | ) 190 | p.start() 191 | episode_loader_processes.append(p) 192 | 193 | # Preprocess episodes on main process 194 | for episode in
tqdm(dataset, desc="Preprocessing"): 195 | processed_episode = preprocess_episode(episode) 196 | queue.put(processed_episode) 197 | 198 | # Send termination signals to loaders 199 | for _ in range(args.num_workers): 200 | queue.put(None) 201 | 202 | # Wait for all processes to complete 203 | for p in episode_loader_processes: 204 | p.join() 205 | 206 | 207 | if __name__ == "__main__": 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument("--data_name", type=str, default="droid") 210 | parser.add_argument( 211 | "--data_dir", type=str, default="/gscratch/weirdlab/memmelma/data/" 212 | ) 213 | parser.add_argument( 214 | "--buffer_path", 215 | type=str, 216 | default="/gscratch/weirdlab/zchuning/data/droid/buffer.zarr", 217 | ) 218 | parser.add_argument("--num_episodes", type=int, default=500) 219 | parser.add_argument("--num_workers", type=int, default=8) 220 | parser.add_argument("--filter_key", type=str, default=None) 221 | parser.add_argument("--except_key", type=str, default=None) 222 | args = parser.parse_args() 223 | main(args) 224 | -------------------------------------------------------------------------------- /experiments/pad/train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.distributed as dist 4 | import torch.multiprocessing as mp 5 | import torch.nn.functional as F 6 | import wandb 7 | from diffusers.optimization import get_scheduler 8 | from einops import rearrange 9 | from hydra.utils import instantiate 10 | from omegaconf import OmegaConf 11 | from torch.nn.parallel import DistributedDataParallel 12 | from torchvision.utils import make_grid 13 | from tqdm import tqdm 14 | 15 | from datasets.utils.loader import make_distributed_data_loader 16 | from experiments.utils import set_seed, init_wandb, init_distributed, is_main_process 17 | from experiments.uwm.train import ( 18 | process_batch, 19 | train_one_step, 20 | maybe_resume_checkpoint, 21 | maybe_save_checkpoint, 22 | ) 23 | 24 | 25 | def eval_one_epoch(config, data_loader, device, model, action_normalizer=None): 26 | model.eval() 27 | model = getattr(model, "module", model) # unwrap DDP 28 | 29 | # Unnormalize actions 30 | if action_normalizer is not None: 31 | action_scale = torch.tensor(action_normalizer.scale[None], device=device) 32 | action_offset = torch.tensor(action_normalizer.offset[None], device=device) 33 | unnormalize = lambda a: a * action_scale + action_offset 34 | else: 35 | unnormalize = lambda a: a 36 | 37 | def decode_and_plot(latents, nrows=4): 38 | with torch.no_grad(): 39 | images = model.obs_encoder.apply_vae(latents, inverse=True) 40 | images = rearrange(images, "b v c t h w -> (b v t) c h w") 41 | images_grid = make_grid(images, nrows) 42 | return images_grid 43 | 44 | stats = { 45 | "loss": 0, 46 | "action_loss": 0, 47 | "dynamics_loss": 0, 48 | "action_mse_joint": 0, 49 | "image_mse_joint": 0, 50 | } 51 | for batch in tqdm(data_loader, desc="Evaluating", disable=not is_main_process()): 52 | # ------------ Preprocess data ------------ # 53 | curr_obs_dict, next_obs_dict, action_norm = process_batch( 54 | batch, config.model.obs_encoder.num_frames, config.model.action_len, device 55 | ) 56 | 57 | with torch.no_grad(): 58 | # ------------ Validation loss ------------ # 59 | _, info = model(curr_obs_dict, next_obs_dict, action_norm) 60 | for k, v in info.items(): 61 | stats[k] += v 62 | 63 | # ------------ UWM Inference ------------ # 64 | action = unnormalize(action_norm) 65 | 66 | # Encode next 
observations 67 | next_img = model.obs_encoder.apply_transform(next_obs_dict) 68 | next_latent = model.obs_encoder.apply_vae(next_img) 69 | 70 | # Sample next obs and actions from joint distribution 71 | next_latent_hat_joint, action_hat_joint = model.sample_joint(curr_obs_dict) 72 | joint_image_mse = F.mse_loss(next_latent_hat_joint, next_latent) 73 | dist.all_reduce(joint_image_mse, op=dist.ReduceOp.AVG) 74 | stats["image_mse_joint"] += joint_image_mse 75 | joint_action_mse = F.mse_loss(unnormalize(action_hat_joint), action) 76 | dist.all_reduce(joint_action_mse, op=dist.ReduceOp.AVG) 77 | stats["action_mse_joint"] += joint_action_mse 78 | 79 | # Average over all batches 80 | stats = {k: v / len(data_loader) for k, v in stats.items()} 81 | 82 | # Plot reconstruction 83 | stats["images"] = wandb.Image(decode_and_plot(next_latent[0:1])) 84 | stats["images_joint"] = wandb.Image(decode_and_plot(next_latent_hat_joint[0:1])) 85 | return stats 86 | 87 | 88 | def maybe_evaluate(config, step, model, loader, device, action_normalizer=None): 89 | """Evaluate if it's the correct step.""" 90 | if step % config.eval_every == 0: 91 | stats = eval_one_epoch(config, loader, device, model, action_normalizer) 92 | if is_main_process(): 93 | wandb.log({f"eval/{k}": v for k, v in stats.items()}) 94 | print(f"Step {step} action mse: {stats['action_mse_joint']:.4f}") 95 | 96 | 97 | def train(rank, world_size, config): 98 | # Set global seed 99 | set_seed(config.seed * world_size + rank) 100 | 101 | # Initialize distributed training 102 | init_distributed(rank, world_size) 103 | device = torch.device(f"cuda:{rank}") 104 | 105 | # Initialize WANDB 106 | if is_main_process(): 107 | init_wandb(config, job_type="train") 108 | 109 | # Create dataset 110 | train_set, val_set = instantiate(config.dataset) 111 | train_loader, val_loader = make_distributed_data_loader( 112 | train_set, val_set, config.batch_size, rank, world_size 113 | ) 114 | 115 | # Create model 116 | model = instantiate(config.model).to(device) 117 | optimizer = torch.optim.AdamW(model.parameters(), **config.optimizer) 118 | scheduler = get_scheduler(optimizer=optimizer, **config.scheduler) 119 | scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) 120 | 121 | # Load pretrained model 122 | if config.pretrain_checkpoint_path: 123 | ckpt = torch.load(config.pretrain_checkpoint_path, map_location="cpu") 124 | model.load_state_dict(ckpt["model"]) 125 | print( 126 | f"Loaded pretraining checkpoint {config.pretrain_checkpoint_path}, step: {ckpt['step']}" 127 | ) 128 | 129 | # Replace dataset normalizers to make sure data is normalized correctly 130 | if ckpt["action_normalizer"] is not None: 131 | train_set.action_normalizer = ckpt["action_normalizer"] 132 | val_set.action_normalizer = ckpt["action_normalizer"] 133 | if ckpt["lowdim_normalizer"] is not None: 134 | train_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 135 | val_set.lowdim_normalizer = ckpt["lowdim_normalizer"] 136 | 137 | # Resume from checkpoint 138 | step = maybe_resume_checkpoint(config, model, optimizer, scheduler, scaler) 139 | epoch = step // len(train_loader) 140 | 141 | # Wrap model with DDP 142 | model = DistributedDataParallel(model, device_ids=[rank], static_graph=True) 143 | 144 | # Training loop 145 | pbar = tqdm( 146 | total=config.num_steps, 147 | initial=step, 148 | desc="Training", 149 | disable=not is_main_process(), 150 | ) 151 | while step < config.num_steps: 152 | # Set epoch for distributed sampler to shuffle indices 153 | train_loader.sampler.set_epoch(epoch) 
154 | 155 | for batch in train_loader: 156 | # --- Training step --- 157 | loss, info = train_one_step( 158 | config, model, optimizer, scheduler, scaler, batch, device 159 | ) 160 | 161 | # --- Logging --- 162 | if is_main_process(): 163 | pbar.set_description(f"step: {step}, loss: {loss.item():.4f}") 164 | wandb.log({f"train/{k}": v for k, v in info.items()}) 165 | 166 | # --- Evaluate if needed --- 167 | maybe_evaluate( 168 | config, step, model, val_loader, device, train_set.action_normalizer 169 | ) 170 | 171 | # --- Save checkpoint if needed --- 172 | maybe_save_checkpoint( 173 | config, 174 | step, 175 | model, 176 | optimizer, 177 | scheduler, 178 | scaler, 179 | train_set.action_normalizer, 180 | train_set.lowdim_normalizer, 181 | ) 182 | 183 | step += 1 184 | pbar.update(1) 185 | if step >= config.num_steps: 186 | break 187 | 188 | epoch += 1 189 | 190 | 191 | @hydra.main( 192 | version_base=None, config_path="../../configs", config_name="train_pad.yaml" 193 | ) 194 | def main(config): 195 | # Resolve hydra config 196 | OmegaConf.resolve(config) 197 | # Spawn processes 198 | world_size = torch.cuda.device_count() 199 | mp.spawn(train, args=(world_size, config), nprocs=world_size, join=True) 200 | 201 | 202 | if __name__ == "__main__": 203 | main() 204 | -------------------------------------------------------------------------------- /datasets/utils/buffer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from multiprocessing import Lock 4 | from os.path import expanduser, expandvars 5 | from typing import Union, Optional 6 | 7 | import numcodecs 8 | import numpy as np 9 | import zarr 10 | 11 | 12 | def get_optimal_chunks( 13 | shape: tuple[int, ...], dtype: Union[str, np.dtype], target_chunk_bytes: float = 2e6 14 | ) -> tuple[int, ...]: 15 | """ 16 | Calculates the optimal chunk sizes for an array given its shape, data type, and a target chunk size in bytes. 17 | 18 | Args: 19 | shape: The shape of the array. 20 | dtype: The data type of the array. 21 | target_chunk_bytes: The target size for each chunk in bytes. Defaults to 2e6 (2 MB). 22 | 23 | Returns: 24 | The optimal chunk dimensions for the given array shape and data type, aiming to not exceed 25 | the target chunk size in bytes. 26 | """ 27 | itemsize = np.dtype(dtype).itemsize 28 | rshape = list(shape[::-1]) 29 | 30 | # Find the index to split the shape, starting from the rightmost dimension 31 | split_idx = len(shape) - 1 32 | for i in range(len(shape) - 1): 33 | this_chunk_bytes = itemsize * np.prod(rshape[:i]) 34 | next_chunk_bytes = itemsize * np.prod(rshape[: i + 1]) 35 | if this_chunk_bytes <= target_chunk_bytes < next_chunk_bytes: 36 | split_idx = i 37 | break 38 | 39 | # Handle jagged chunk dimension 40 | rchunks = rshape[:split_idx] 41 | item_chunk_bytes = itemsize * np.prod(rchunks) 42 | next_chunk_length = min( 43 | rshape[split_idx], math.ceil(target_chunk_bytes / item_chunk_bytes) 44 | ) 45 | rchunks.append(next_chunk_length) 46 | 47 | # Handle remaining dimensions 48 | rchunks.extend([1] * (len(shape) - len(rchunks))) 49 | chunks = tuple(rchunks[::-1]) 50 | return chunks 51 | 52 | 53 | class CompressedTrajectoryBuffer: 54 | """ 55 | A class that stores trajectory data in a compressed zarr array. 56 | """ 57 | 58 | def __init__( 59 | self, 60 | storage_path: str, 61 | metadata: dict[str, dict[str, any]], 62 | capacity: Optional[int] = None, 63 | lock: Optional[Lock] = None, 64 | ): 65 | """ 66 | Initialize the trajectory buffer. 
If there is an existing buffer at the given path, it will be restored. 67 | 68 | Args: 69 | storage_path: Path to the buffer storage. 70 | metadata: Dictionary containing metadata for each data key. Each key should 71 | map to a dictionary containing the following keys: 72 | - shape: shape of the data 73 | - dtype: dtype of the data 74 | capacity: Maximum number of transition steps that can be stored in the buffer. 75 | Only used when creating a new buffer. 76 | lock: Multiprocessing lock to synchronize access to the buffer. If None, a new lock will be created. 77 | """ 78 | # Create zarr storage and root group 79 | storage_path = expandvars(expanduser(storage_path)) 80 | self.restored = os.path.exists(storage_path) 81 | self.storage = zarr.DirectoryStore(storage_path) 82 | 83 | # Mutex for zarr storage 84 | self.lock = Lock() if lock is None else lock 85 | 86 | # Create data and metadata groups 87 | self.root = zarr.group(store=self.storage) 88 | self.data = self.root.require_group("data") 89 | self.meta = self.root.require_group("meta") 90 | 91 | if self.restored: 92 | print(f"Restoring buffer from {storage_path}") 93 | assert "episode_ends" in self.meta 94 | assert all(key in self.data for key in metadata) 95 | assert all( 96 | self.data[key].shape[1:] == value["shape"] 97 | for key, value in metadata.items() 98 | ) 99 | 100 | # Check that all data have the same length and restore capacity 101 | lengths = {self.data[key].shape[0] for key in self.data} 102 | assert len(lengths) == 1, "Inconsistent data lengths in the buffer" 103 | self.capacity = lengths.pop() 104 | else: 105 | with self.lock: 106 | print(f"Creating new buffer at {storage_path}") 107 | assert capacity is not None, "Capacity must be specified for new buffer" 108 | self.capacity = capacity 109 | 110 | # Create empty episode_ends 111 | self.meta.zeros( 112 | name="episode_ends", 113 | shape=(0,), 114 | dtype=np.int64, 115 | compressor=None, 116 | ) 117 | 118 | # Allocate space for data 119 | for key, value in metadata.items(): 120 | shape = (capacity,) + tuple(value["shape"]) 121 | dtype = value["dtype"] 122 | if dtype == np.uint8: 123 | # Chunk and compress images individually 124 | chunks = (1,) + shape[1:] 125 | compressor = numcodecs.Blosc( 126 | cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE 127 | ) 128 | else: 129 | # Chunk and compress other data by computing optimal chunks 130 | chunks = get_optimal_chunks(shape, dtype) 131 | compressor = numcodecs.Blosc( 132 | cname="lz4", clevel=0, shuffle=numcodecs.Blosc.NOSHUFFLE 133 | ) 134 | # Create new array 135 | self.data.zeros( 136 | name=key, 137 | shape=shape, 138 | chunks=chunks, 139 | dtype=dtype, 140 | compressor=compressor, 141 | object_codec=numcodecs.Pickle(), 142 | ) 143 | 144 | @property 145 | def episode_ends(self) -> np.ndarray: 146 | return self.meta["episode_ends"] 147 | 148 | @property 149 | def num_episodes(self) -> int: 150 | return len(self.episode_ends) 151 | 152 | @property 153 | def num_steps(self) -> int: 154 | return self.episode_ends[-1] if len(self.episode_ends) > 0 else 0 155 | 156 | def add_episode(self, data: dict[str, np.ndarray]): 157 | with self.lock: 158 | # Get episode length 159 | episode_lens = np.array([v.shape[0] for v in data.values()]) 160 | assert np.all(episode_lens == episode_lens[0]) 161 | episode_len = episode_lens[0] 162 | 163 | # Compute corresponding buffer indices 164 | start_ind = self.num_steps 165 | end_ind = start_ind + episode_len 166 | if end_ind > self.capacity: 167 | raise RuntimeError("Buffer capacity 
exceeded") 168 | 169 | # Copy data to buffer 170 | for key, value in data.items(): 171 | arr = self.data[key] 172 | arr[start_ind:end_ind] = value 173 | 174 | # Update episode_ends 175 | self.episode_ends.resize(len(self.episode_ends) + 1) 176 | self.episode_ends[-1] = end_ind 177 | 178 | # Rechunk and recompress episode_ends if necessary 179 | if self.episode_ends.chunks[0] < self.episode_ends.shape[0]: 180 | new_chunk_len = self.episode_ends.shape[0] * 1.5 181 | new_chunks = (new_chunk_len,) + self.episode_ends.chunks[1:] 182 | self.meta.move("episode_ends", "_temp") 183 | zarr.copy( 184 | source=self.meta["_temp"], 185 | dest=self.meta, 186 | name="episode_ends", 187 | chunks=new_chunks, 188 | compressor=None, 189 | ) 190 | del self.meta["_temp"] 191 | 192 | def __repr__(self) -> str: 193 | return str(self.root.tree()) 194 | 195 | def keys(self): 196 | return self.data.keys() 197 | 198 | def values(self): 199 | return self.data.values() 200 | 201 | def items(self): 202 | return self.data.items() 203 | 204 | def __getitem__(self, key): 205 | return self.data[key] 206 | 207 | def __contains__(self, key): 208 | return key in self.data 209 | -------------------------------------------------------------------------------- /models/gr1/gr1.py: -------------------------------------------------------------------------------- 1 | # Copyright (2024) Bytedance Ltd. and/or its affiliates 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """GR-1 model.""" 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from einops import rearrange 20 | 21 | from models.common.attention import AttentionBlock 22 | from models.common.utils import get_nd_sinusoidal_embed 23 | from models.gr1.obs_encoder import GR1ObservationEncoder 24 | from models.gr1.vision_transformer import Block 25 | 26 | 27 | class GR1(nn.Module): 28 | def __init__( 29 | self, 30 | action_len: int, 31 | action_dim: int, 32 | obs_encoder: GR1ObservationEncoder, 33 | embed_dim: int, 34 | image_size: tuple[int, ...], 35 | patch_size: tuple[int, ...], 36 | num_chans: int = 3, 37 | depth: int = 12, 38 | num_heads: int = 12, 39 | mlp_ratio: int = 4, 40 | qkv_bias: bool = True, 41 | decoder_depth: int = 3, 42 | decoder_num_heads: int = 16, 43 | decoder_mlp_ratio: int = 4, 44 | decoder_qkv_bias: bool = True, 45 | ): 46 | super().__init__() 47 | self.action_dim = action_dim 48 | self.action_len = action_len 49 | self.image_size = image_size 50 | self.patch_size = patch_size 51 | self.num_chans = num_chans 52 | 53 | # Observation encoder (with perceiver resampler) 54 | self.obs_encoder = obs_encoder 55 | self.obs_len = obs_encoder.num_latents 56 | 57 | # Action query token 58 | self.action_queries = nn.Parameter( 59 | torch.empty(1, self.action_len, embed_dim).normal_(mean=0, std=0.02) 60 | ) 61 | 62 | # Observation query token 63 | self.obs_queries = nn.Parameter( 64 | torch.empty(1, self.obs_len, embed_dim).normal_(mean=0, std=0.02) 65 | ) 66 | 67 | # Main transformer 68 | self.transformer = nn.ModuleList( 69 | [ 70 | AttentionBlock( 71 | dim=embed_dim, 72 | num_heads=num_heads, 73 | mlp_ratio=mlp_ratio, 74 | qkv_bias=qkv_bias, 75 | ) 76 | for _ in range(depth) 77 | ] 78 | ) 79 | 80 | # Action head 81 | self.action_head = nn.Sequential( 82 | nn.Linear(embed_dim, embed_dim // 2), 83 | nn.SiLU(), 84 | nn.Linear(embed_dim // 2, embed_dim // 2), 85 | nn.SiLU(), 86 | nn.Linear(embed_dim // 2, self.action_dim), 87 | ) 88 | 89 | # Image decoder 90 | self.decoder_query_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 91 | decoder_pos_embed = get_nd_sinusoidal_embed( 92 | embed_dim, 93 | (self.image_size // self.patch_size, self.image_size // self.patch_size), 94 | ) 95 | self.decoder_pos_embed = nn.Parameter( 96 | torch.from_numpy(decoder_pos_embed).float()[None], requires_grad=False 97 | ) 98 | self.decoder_proj = nn.Linear(embed_dim, embed_dim, bias=True) 99 | self.decoder_blocks = nn.ModuleList( 100 | [ 101 | Block( 102 | dim=embed_dim, 103 | num_heads=decoder_num_heads, 104 | mlp_ratio=decoder_mlp_ratio, 105 | qkv_bias=decoder_qkv_bias, 106 | ) 107 | for _ in range(decoder_depth) 108 | ] 109 | ) 110 | self.decoder_head = nn.Sequential( 111 | nn.LayerNorm(embed_dim), 112 | nn.Linear(embed_dim, self.patch_size**2 * 3, bias=True), 113 | ) 114 | 115 | def predict_next_obs(self, next_obs_embed): 116 | next_obs_embed = rearrange( 117 | next_obs_embed, 118 | "b (v t n) d -> (b v t) n d", 119 | v=self.obs_encoder.num_views, 120 | t=self.obs_encoder.num_frames, 121 | ) 122 | next_obs_embed = self.decoder_proj(next_obs_embed) 123 | next_obs_queries = self.decoder_query_token + self.decoder_pos_embed 124 | next_obs_pred = torch.cat( 125 | [next_obs_embed, next_obs_queries.repeat(next_obs_embed.shape[0], 1, 1)], 126 | dim=1, 127 | ) 128 | for block in self.decoder_blocks: 129 | next_obs_pred = block(next_obs_pred) 130 | next_obs_pred = self.decoder_head( 131 | next_obs_pred[:, -self.decoder_pos_embed.shape[1] :] 132 | ) 133 | next_obs_pred = 
rearrange( 134 | next_obs_pred, 135 | "(b v t) (h w) (c ph pw) -> b v c t (h ph) (w pw)", 136 | v=self.obs_encoder.num_views, 137 | t=self.obs_encoder.num_frames, 138 | h=self.image_size // self.patch_size, 139 | w=self.image_size // self.patch_size, 140 | ph=self.patch_size, 141 | pw=self.patch_size, 142 | ) 143 | return next_obs_pred 144 | 145 | def forward(self, obs_dict, next_obs_dict, action, action_mask=None): 146 | # Get obs patches and prediction targets 147 | curr_embeds, next_obs = self.obs_encoder.encode_curr_and_next_obs( 148 | obs_dict, next_obs_dict 149 | ) 150 | 151 | # Transformer inputs 152 | action_queries = self.action_queries.expand(action.shape[0], -1, -1) 153 | obs_queries = self.obs_queries.expand(next_obs.shape[0], -1, -1) 154 | x = torch.cat([curr_embeds, action_queries, obs_queries], dim=1) 155 | 156 | # Attention mask 157 | attn_mask = None 158 | if action_mask is not None: 159 | # Action mask has shape (B,) 160 | B = action_mask.shape[0] 161 | N = self.action_len + self.obs_len * 2 162 | attn_mask = torch.ones((B, N, N), device=action.device, dtype=torch.bool) 163 | attn_mask[ 164 | ~action_mask, :, self.obs_len : self.obs_len + self.action_len 165 | ] = 0 166 | 167 | # Transformer forward pass 168 | for block in self.transformer: 169 | x = block(x, attn_mask=attn_mask) 170 | 171 | # Action prediction 172 | action_embed = x[:, self.obs_len : self.obs_len + self.action_len] 173 | action_pred = self.action_head(action_embed) 174 | 175 | # Image prediction 176 | next_obs_embed = x[:, -self.obs_len :] 177 | next_obs_pred = self.predict_next_obs(next_obs_embed) 178 | 179 | # Compute losses 180 | if action_mask is None: 181 | action_loss = F.mse_loss(action_pred, action) 182 | else: 183 | action_loss = F.mse_loss(action_pred[action_mask], action[action_mask]) 184 | dynamics_loss = F.mse_loss(next_obs_pred, next_obs) 185 | loss = action_loss + dynamics_loss 186 | info = { 187 | "loss": loss.item(), 188 | "action_loss": action_loss.item(), 189 | "dynamics_loss": dynamics_loss.item(), 190 | } 191 | return loss, info 192 | 193 | @torch.no_grad() 194 | def sample(self, obs_dict): 195 | _, action_sample = self.sample_joint(obs_dict) 196 | return action_sample 197 | 198 | @torch.no_grad() 199 | def sample_joint(self, obs_dict): 200 | # Get obs patches and prediction targets 201 | curr_embeds = self.obs_encoder.encode_obs(obs_dict) 202 | 203 | # Transformer inputs 204 | action_queries = self.action_queries.expand(curr_embeds.shape[0], -1, -1) 205 | obs_queries = self.obs_queries.expand(curr_embeds.shape[0], -1, -1) 206 | x = torch.cat([curr_embeds, action_queries, obs_queries], dim=1) 207 | 208 | # Transformer forward pass 209 | for block in self.transformer: 210 | x = block(x) 211 | 212 | # Action prediction 213 | action_embed = x[:, self.obs_len : self.obs_len + self.action_len] 214 | action_pred = self.action_head(action_embed) 215 | 216 | # Image prediction 217 | next_obs_embed = x[:, -self.obs_len :] 218 | next_obs_pred = self.predict_next_obs(next_obs_embed) 219 | return next_obs_pred, action_pred 220 | -------------------------------------------------------------------------------- /models/uwm/obs_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from models.common.language import CLIPTextEncoder 8 | from models.common.transforms import VideoTransform, VAEDownsample 9 | from 
models.common.vision import ResNetImageEncoder, ViTImageEncoder 10 | 11 | 12 | class UWMObservationEncoder(nn.Module): 13 | def __init__( 14 | self, 15 | shape_meta: dict, 16 | num_frames: int, 17 | embed_dim: int, 18 | resize_shape: Tuple[int, int] = None, 19 | crop_shape: Tuple[int, int] = None, 20 | random_crop: bool = True, 21 | color_jitter: Optional[Dict] = None, 22 | imagenet_norm: bool = False, 23 | vision_backbone: str = "vit", 24 | use_low_dim: bool = True, 25 | use_language: bool = True, 26 | ): 27 | super().__init__() 28 | self.shape_meta = shape_meta 29 | self.num_frames = num_frames 30 | self.rgb_keys = sorted( 31 | [k for k, v in shape_meta["obs"].items() if v["type"] == "rgb"] 32 | ) 33 | self.low_dim_keys = sorted( 34 | [k for k, v in shape_meta["obs"].items() if v["type"] == "low_dim"] 35 | ) 36 | self.num_views = len(self.rgb_keys) 37 | self.embed_dim = embed_dim 38 | 39 | # Image augmentation 40 | self.obs_transform = VideoTransform( 41 | resize_shape=resize_shape, 42 | crop_shape=crop_shape, 43 | random_crop=random_crop, 44 | color_jitter=color_jitter, 45 | imagenet_norm=imagenet_norm, 46 | ) 47 | 48 | # Image encoder 49 | if vision_backbone == "vit": 50 | self.img_encoder = ViTImageEncoder( 51 | num_views=self.num_views, 52 | embed_dim=embed_dim, 53 | ) 54 | elif vision_backbone == "resnet": 55 | self.img_encoder = ResNetImageEncoder( 56 | num_views=self.num_views, 57 | embed_dim=embed_dim, 58 | ) 59 | else: 60 | raise NotImplementedError(f"Unsupported vision backbone: {vision_backbone}") 61 | 62 | # Low-dim observations 63 | self.use_low_dim = use_low_dim 64 | 65 | # Language encoder 66 | self.use_language = use_language 67 | self.text_encoder = ( 68 | CLIPTextEncoder(embed_dim=embed_dim) if use_language else None 69 | ) 70 | 71 | # VAE downsampling 72 | self.vae = VAEDownsample() 73 | 74 | def apply_transform(self, obs_dicts: Union[dict, list[dict]]): 75 | """ 76 | Accept a list of observation dictionaries and apply the same transform to each. 77 | """ 78 | if isinstance(obs_dicts, dict): 79 | obs_dicts = [obs_dicts] 80 | is_singleton = True 81 | else: 82 | is_singleton = False 83 | assert isinstance(obs_dicts, list) 84 | 85 | # Apply the same transform to each observation 86 | num_obs = len(obs_dicts) 87 | transformed_imgs = [[] for _ in range(num_obs)] 88 | for key in self.rgb_keys: 89 | combined_imgs = torch.cat([obs_dict[key] for obs_dict in obs_dicts], dim=0) 90 | combined_imgs = self.obs_transform(combined_imgs) 91 | chunked_imgs = combined_imgs.chunk(num_obs, dim=0) 92 | for i, img in enumerate(chunked_imgs): 93 | transformed_imgs[i].append(img) 94 | 95 | # Stack transformed images 96 | # Each image has shape (B, V, C, T, H, W) 97 | transformed_imgs = [torch.stack(imgs, dim=1) for imgs in transformed_imgs] 98 | if is_singleton: 99 | transformed_imgs = transformed_imgs[0] 100 | return transformed_imgs 101 | 102 | def apply_vae( 103 | self, 104 | imgs_list: Union[torch.Tensor, list[torch.Tensor]], 105 | inverse: bool = False, 106 | microbatch_size: int = 72, # Tuned for 40GB VRAM 107 | ): 108 | """ 109 | Accept a list of images and apply VAE to downsample or upsample images. 110 | If inverse is False, downsample images. Otherwise, upsample images. 111 | Process images in microbatches to reduce memory usage. 
112 | """ 113 | if isinstance(imgs_list, torch.Tensor): 114 | imgs_list = [imgs_list] 115 | is_singleton = True 116 | else: 117 | is_singleton = False 118 | assert isinstance(imgs_list, list) 119 | imgs = torch.cat(imgs_list, dim=0) 120 | 121 | # Flatten multiview videos to images 122 | B, V = imgs.shape[:2] 123 | imgs = rearrange(imgs, "b v c t h w -> (b v t) c h w") 124 | 125 | # Process images in microbatches 126 | transformed_imgs = [] 127 | for i in range(0, imgs.shape[0], microbatch_size): 128 | batch_imgs = imgs[i : i + microbatch_size] 129 | if inverse: 130 | batch_transformed_imgs = self.vae.inverse(batch_imgs) 131 | else: 132 | batch_transformed_imgs = self.vae(batch_imgs) 133 | transformed_imgs.append(batch_transformed_imgs) 134 | transformed_imgs = torch.cat(transformed_imgs, dim=0) 135 | 136 | # Unflatten images to multiview videos 137 | transformed_imgs = rearrange( 138 | transformed_imgs, "(b v t) c h w -> b v c t h w", b=B, v=V 139 | ) 140 | if not is_singleton: 141 | chunk_sizes = [img.shape[0] for img in imgs_list] 142 | transformed_imgs = list(transformed_imgs.split(chunk_sizes, dim=0)) 143 | return transformed_imgs 144 | 145 | def encode_curr_obs(self, curr_obs_dict: dict): 146 | # Encoder current observations to features 147 | curr_imgs = self.apply_transform(curr_obs_dict) 148 | curr_feats = self.img_encoder(curr_imgs) # (B, V*T*D) 149 | 150 | if self.use_low_dim: 151 | low_dims = [curr_obs_dict[key] for key in self.low_dim_keys] 152 | low_dims = torch.cat(low_dims, dim=-1).flatten(1) 153 | curr_feats = torch.cat([curr_feats, low_dims], dim=-1) 154 | 155 | if self.use_language: 156 | lang_feats = self.text_encoder( 157 | input_ids=curr_obs_dict["input_ids"], 158 | attention_mask=curr_obs_dict["attention_mask"], 159 | ) 160 | curr_feats = torch.cat([curr_feats, lang_feats], dim=-1) 161 | return curr_feats 162 | 163 | def encode_next_obs(self, next_obs_dict: dict): 164 | # Encoder next observations to latents 165 | next_imgs = self.apply_transform(next_obs_dict) 166 | next_latents = self.apply_vae(next_imgs) 167 | return next_latents 168 | 169 | def encode_curr_and_next_obs(self, curr_obs_dict: dict, next_obs_dict: dict): 170 | # Apply the same transform to obs and next obs 171 | curr_imgs, next_imgs = self.apply_transform([curr_obs_dict, next_obs_dict]) 172 | 173 | # Encode current obs to features 174 | curr_feats = self.img_encoder(curr_imgs) # (B, V*T*D) 175 | 176 | if self.use_low_dim: 177 | low_dims = [curr_obs_dict[key] for key in self.low_dim_keys] 178 | low_dims = torch.cat(low_dims, dim=-1).flatten(1) 179 | curr_feats = torch.cat([curr_feats, low_dims], dim=-1) 180 | 181 | if self.use_language: 182 | lang_feats = self.text_encoder( 183 | input_ids=curr_obs_dict["input_ids"], 184 | attention_mask=curr_obs_dict["attention_mask"], 185 | ) 186 | curr_feats = torch.cat([curr_feats, lang_feats], dim=-1) 187 | 188 | # Encode next obs to latents 189 | next_latents = self.apply_vae(next_imgs) 190 | return curr_feats, next_latents 191 | 192 | def feat_dim(self): 193 | # Return the dimension of encoded features 194 | low_dim_size = sum( 195 | self.shape_meta["obs"][key]["shape"][-1] for key in self.low_dim_keys 196 | ) 197 | return ( 198 | self.num_views * self.num_frames * self.embed_dim 199 | + int(self.use_low_dim) * self.num_frames * low_dim_size 200 | + int(self.use_language) * self.embed_dim 201 | ) 202 | 203 | def latent_img_shape(self): 204 | # Construct dummy image and forward pass to get latent image shape 205 | dummy_obs = {} 206 | for k in self.rgb_keys: 207 | 
img_shape = self.shape_meta["obs"][k]["shape"] 208 | dummy_obs[k] = torch.zeros( 209 | 1, self.num_frames, *img_shape, dtype=torch.uint8 210 | ) 211 | with torch.no_grad(): 212 | latent = self.encode_next_obs(dummy_obs) 213 | return tuple(latent.shape[1:]) 214 | --------------------------------------------------------------------------------
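A minimal usage sketch of CompressedTrajectoryBuffer (datasets/utils/buffer.py), based only on the constructor, add_episode, and the num_episodes/num_steps properties listed above. The storage path, observation keys, shapes, and episode length below are illustrative assumptions, not values taken from the repository's configs.

import numpy as np

from datasets.utils.buffer import CompressedTrajectoryBuffer

# Hypothetical metadata: one RGB camera, a 6-dim proprioceptive state, and a
# 10-dim action. Each entry provides the per-step shape and dtype.
metadata = {
    "obs.image": {"shape": (180, 320, 3), "dtype": np.uint8},
    "obs.cartesian_position": {"shape": (6,), "dtype": np.float32},
    "action": {"shape": (10,), "dtype": np.float32},
}

# Creating a new buffer requires a capacity (total number of transition steps).
# If the zarr store already exists at this (hypothetical) path, it is restored
# instead and the capacity argument is ignored.
buffer = CompressedTrajectoryBuffer(
    storage_path="/tmp/example_buffer.zarr",
    metadata=metadata,
    capacity=1000,
)

# Append one 50-step episode of dummy data. Every array in the episode dict
# must share the same leading (time) dimension.
T = 50
episode = {
    "obs.image": np.zeros((T, 180, 320, 3), dtype=np.uint8),
    "obs.cartesian_position": np.zeros((T, 6), dtype=np.float32),
    "action": np.zeros((T, 10), dtype=np.float32),
}
buffer.add_episode(episode)

print(buffer.num_episodes, buffer.num_steps)  # 1 50

This mirrors the pattern in datasets/droid/convert_dataset_zarr.py, which first sums episode lengths to size the buffer and then streams episodes into it from worker processes that share a multiprocessing lock.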