├── real_preprocess ├── octo_oxe_data_utils │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── goal_relabeling.py │ │ ├── task_augmentation.py │ │ └── text_processing.py │ ├── setup.py │ ├── requirements.txt │ ├── obs_transforms.py │ ├── traj_transforms.py │ └── oxe │ │ ├── __init__.py │ │ └── oxe_dataset_mixes.py ├── mujoco_menagerie │ └── robotiq_2f85 │ │ ├── 2f85.png │ │ ├── assets │ │ ├── base.stl │ │ ├── pad.stl │ │ ├── coupler.stl │ │ ├── driver.stl │ │ ├── follower.stl │ │ ├── base_mount.stl │ │ ├── silicone_pad.stl │ │ └── spring_link.stl │ │ ├── LICENSE │ │ ├── README.md │ │ ├── scene.xml │ │ └── 2f85.xml └── requirements.txt ├── data_info ├── ep_lens.npy ├── ep_start_end_ids.npy ├── except_lang_idx │ └── except_lang_idx.npy ├── .hydra │ ├── overrides.yaml │ ├── config.yaml │ ├── hydra.yaml │ └── merged_config.yaml ├── droid_success_languaged_0803_stat_0.1_0.5.json ├── droid_success_full_0803_stat.json ├── droid_success_languaged_0803_stat_0.02_0.05.json ├── droid_failure.json └── austin_buds_dataset_converted_externally_to_rlds.json ├── assets └── seer_method.jpg ├── requirements.txt ├── docs ├── LIBERO_LONG_INSTALL.md ├── REAL-WORLD_INSTALL.md ├── REAL-WORLD_FT_SC.md ├── CALVIN_ABC-D_INSTALL.md ├── REAL-WORLD_PRETRAIN.md ├── REAL-WORLD_INFERENCE.md ├── LIBERO_LONG_RUN.md ├── CALVIN_ABC-D_RUN.md ├── REAL-WORLD_PREPROCESS.md └── REAL-WORLD_POSTPROCESS.md ├── scripts ├── CALVIN_ABC_D │ ├── Seer │ │ ├── scratch.sh │ │ ├── pretrain.sh │ │ ├── finetune.sh │ │ └── eval.sh │ └── Seer-Large │ │ ├── scratch.sh │ │ ├── pretrain.sh │ │ ├── eval.sh │ │ └── finetune.sh ├── REAL │ ├── slurm_s_oxe_cluster.sh │ ├── deploy.sh │ ├── slurm_s_full_cluster.sh │ ├── slurm_s_language_cluster.sh │ ├── single_node_full_cluster.sh │ ├── single_node_language_cluster.sh │ ├── single_node_scratch.sh │ └── single_node_ft.sh └── LIBERO_LONG │ └── Seer │ ├── scratch.sh │ ├── pretrain.sh │ ├── eval.sh │ └── finetune.sh ├── deploy.py ├── .gitignore ├── eval_libero.py ├── models ├── perceiver_resampler.py └── vit_mae.py ├── README.md ├── eval_calvin.py ├── utils ├── distributed_utils.py ├── convert_libero_per_step.py └── arguments_utils.py ├── slurm_train_intern.py ├── LICENSE └── train.py /real_preprocess/octo_oxe_data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_info/ep_lens.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/data_info/ep_lens.npy -------------------------------------------------------------------------------- /assets/seer_method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/assets/seer_method.jpg -------------------------------------------------------------------------------- /data_info/ep_start_end_ids.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/data_info/ep_start_end_ids.npy -------------------------------------------------------------------------------- /data_info/except_lang_idx/except_lang_idx.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/data_info/except_lang_idx/except_lang_idx.npy -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/2f85.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/2f85.png -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/pad.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/pad.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/coupler.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/coupler.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/driver.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/driver.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/follower.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/follower.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base_mount.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base_mount.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/silicone_pad.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/silicone_pad.stl -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/assets/spring_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/spring_link.stl -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='octo_oxe_data_utils', 6 | version='1.0.0', 7 | packages=find_packages(), 8 | ) -------------------------------------------------------------------------------- /data_info/.hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | - load_dir=/work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/ 2 | - set_static_cam=false 3 | - processes=16 4 | - +scene=calvin_scene_A 5 | - show_gui=false 6 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow == 2.15.0 2 | tensorflow_datasets == 4.9.2 3 | tensorflow_graphics == 2021.12.3 4 | dlimp @ git+https://github.com/kvablack/dlimp@d08da3852c149548aaa8551186d619d87375df08 5 | imageio 6 | absl-py 7 | mujoco 8 | roboticstoolbox-python[collision] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flamingo_pytorch 2 | tensorboard 3 | ftfy==6.2.0 4 | regex 5 | tqdm 6 | matplotlib 7 | git+https://github.com/openai/CLIP.git 8 | timm==0.9.16 9 | h5py==3.11.0 10 | transformers==4.40.2 11 | packaging==24.0 12 | setuptools==57.5.0 13 | omegaconf==2.1.2 14 | moviepy==1.0.3 15 | einops_exts 16 | numpy==1.23.1 -------------------------------------------------------------------------------- /docs/LIBERO_LONG_INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | **(1) Conda Env** 4 | ``` 5 | conda create -n seer python=3.10 6 | conda activate seer 7 | ``` 8 | 9 | **(2) LIBERO Env** 10 | ``` 11 | git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git 12 | cd LIBERO 13 | pip install -r requirements.txt 14 | pip install transformers==4.40.2 15 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 16 | pip install -e . 
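# Optional: an illustrative sanity check (not a required step) that the editable install resolved
# python -c "import libero; print(libero.__file__)"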
17 | ``` -------------------------------------------------------------------------------- /real_preprocess/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.3 2 | tensorflow_probability==0.23.0 3 | tensorflow==2.15.0 4 | ml_collections>=0.1.0 5 | tqdm>=4.60.0 6 | absl-py>=0.12.0 7 | scipy>=1.6.0 8 | einops>=0.6.1 9 | tensorflow_hub>=0.14.0 10 | tensorflow_text>=2.13.0 11 | tensorflow_datasets==4.9.2 12 | tensorflow_graphics==2021.12.3 13 | mujoco 14 | Pillow 15 | pybullet 16 | dlimp@git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972 17 | roboticstoolbox-python -------------------------------------------------------------------------------- /data_info/droid_success_languaged_0803_stat_0.1_0.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "droid_success_languaged_0803": { 3 | "mean": [ 4 | -0.013946113699228197, 5 | -0.0008144068900243393, 6 | -0.010159969904343499, 7 | 0.029293912297435065, 8 | -0.025690169644747598, 9 | -0.015989328273960466 10 | ], 11 | "std": [ 12 | 0.20220592438479257, 13 | 0.21943531464788582, 14 | 0.210883565255364, 15 | 0.17382451705057012, 16 | 0.14346100321749855, 17 | 0.14715284488965125 18 | ] 19 | } 20 | } -------------------------------------------------------------------------------- /data_info/droid_success_full_0803_stat.json: -------------------------------------------------------------------------------- 1 | { 2 | "droid_success_full_0803": { 3 | "mean": [ 4 | 0.034469623780328515, 5 | -0.001633081419506495, 6 | 0.0319105912448654, 7 | -0.0027474652911320127, 8 | 0.007313525898785993, 9 | 0.006798099658541783, 10 | 0.12350288160670808 11 | ], 12 | "std": [ 13 | 0.2520571699338127, 14 | 0.2486948654319061, 15 | 0.2937194360082029, 16 | 0.25230853439244694, 17 | 0.2796370525416708, 18 | 0.34477038757871253, 19 | 0.9923442135864566 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /data_info/droid_success_languaged_0803_stat_0.02_0.05.json: -------------------------------------------------------------------------------- 1 | { 2 | "droid_success_languaged_0803": { 3 | "mean": [ 4 | 0.034524966522835786, 5 | -0.002034788243149046, 6 | 0.03286368149243478, 7 | -0.002270051002688882, 8 | 0.007490289567743871, 9 | 0.008484724839843296, 10 | 0.1251103956939712 11 | ], 12 | "std": [ 13 | 0.25377941149232486, 14 | 0.2485252318164213, 15 | 0.29530677038638276, 16 | 0.2521710628202299, 17 | 0.2795374759449668, 18 | 0.3412635660188041, 19 | 0.9921428268718969 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /docs/REAL-WORLD_INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | To set up Seer, we will create an isolated environment called seer. This environment is designed to support pre-training, fine-tuning, and inference workflows. 
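Once the steps below finish, it can help to confirm that the interpreter picked up the intended builds. This is a minimal sketch; it only assumes the packages pinned in this guide and in `requirements.txt`:
```python
import torch, torchvision, transformers

# Expect torch 2.2.0 (cu121 build), torchvision 0.17.0, transformers 4.40.2 per the pins
print("torch:", torch.__version__, "| cuda:", torch.version.cuda, "| gpu available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__, "| transformers:", transformers.__version__)
```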
3 | 4 | ## seer env 5 | **(1) Env** 6 | ```python 7 | conda create -n seer python=3.10 8 | conda activate seer 9 | ``` 10 | **(2) Third Party Packages** 11 | ```python 12 | cd ${YOUR_PATH_TO_SEER} 13 | pip install -r requirements.txt 14 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/REAL-WORLD_FT_SC.md: -------------------------------------------------------------------------------- 1 | # Quick Training 2 | Preparation 3 | ```python 4 | cd ${YOUR_PATH_TO_SEER} 5 | conda activate seer 6 | ``` 7 | Download related checkpoints from the [checkpoint repository](https://drive.google.com/drive/folders/1rT8JKLhJGIo97jfYUm2JiFUrogOq-dgJ?usp=drive_link). 8 | ## :sparkles: Fine-tuning 9 | * For single-node fine-tuning: 10 | ```bash 11 | bash scripts/REAL/single_node_ft.sh 12 | ``` 13 | ## :sparkles: Training from Scratch 14 | * For single-node training from scratch: 15 | ```bash 16 | bash scripts/REAL/single_node_scratch.sh 17 | ``` 18 | -------------------------------------------------------------------------------- /data_info/droid_failure.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | "004897", 4 | 141 5 | ], 6 | [ 7 | "009466", 8 | 80 9 | ], 10 | [ 11 | "021346", 12 | 154 13 | ], 14 | [ 15 | "023455", 16 | 412 17 | ], 18 | [ 19 | "026167", 20 | 174 21 | ], 22 | [ 23 | "026189", 24 | 207 25 | ], 26 | [ 27 | "026614", 28 | 214 29 | ], 30 | [ 31 | "027372", 32 | 145 33 | ], 34 | [ 35 | "031500", 36 | 125 37 | ], 38 | [ 39 | "032599", 40 | 180 41 | ], 42 | [ 43 | "033248", 44 | 107 45 | ], 46 | [ 47 | "035609", 48 | 126 49 | ], 50 | [ 51 | "036914", 52 | 148 53 | ], 54 | [ 55 | "037104", 56 | 188 57 | ], 58 | [ 59 | "037173", 60 | 108 61 | ], 62 | [ 63 | "042263", 64 | 226 65 | ], 66 | [ 67 | "042479", 68 | 241 69 | ], 70 | [ 71 | "045810", 72 | 202 73 | ], 74 | [ 75 | "054451", 76 | 381 77 | ], 78 | [ 79 | "060790", 80 | 204 81 | ], 82 | [ 83 | "061356", 84 | 59 85 | ], 86 | [ 87 | "062455", 88 | 225 89 | ], 90 | [ 91 | "062856", 92 | 235 93 | ], 94 | [ 95 | "081293", 96 | 723 97 | ] 98 | ] -------------------------------------------------------------------------------- /docs/CALVIN_ABC-D_INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | **(1) Conda Env** 4 | ``` 5 | conda create -n seer python=3.10 6 | conda activate seer 7 | ``` 8 | 9 | **(2) CALVIN** 10 | > Exact instructions on [CALVIN](https://github.com/mees/calvin). 11 | ``` 12 | git clone --recurse-submodules https://github.com/mees/calvin.git 13 | export CALVIN_ROOT=$(pwd)/calvin 14 | cd $CALVIN_ROOT 15 | sh install.sh 16 | ``` 17 | 18 | **(3) Dataset Download** 19 | > We only download CALVIN ABC-D. 
20 | ``` 21 | cd $CALVIN_ROOT/dataset 22 | sh download_data.sh ABC 23 | ``` 24 | 25 | **(4) Third Party Packages** 26 | ``` 27 | cd ${YOUR_PATH_TO_SEER} 28 | pip install -r requirements.txt 29 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 30 | ``` 31 | 32 | **(5) Create a soft link to CALVIN** 33 | ``` 34 | cd ${YOUR_PATH_TO_SEER} 35 | ln -s $CALVIN_ROOT calvin 36 | ``` 37 | 38 | **(6) Copy the index file `except_lang_idx.npy` to the CALVIN ABC-D training data directory.** 39 | ```python 40 | cp -r data_info/except_lang_idx/except_lang_idx.npy calvin/dataset/task_ABC_D/training 41 | ``` -------------------------------------------------------------------------------- /docs/REAL-WORLD_PRETRAIN.md: -------------------------------------------------------------------------------- 1 | # Pre-train 2 | ## Notice 3 | We provide code for pre-training on both the DROID and OXE datasets. Users should update the save_checkpoint_path to the directory where you want to save the training checkpoints, and modify the root_dir to the location where the preprocessed real data is stored. Additionally, users should configure the SLURM information in the provided scripts. 4 | 5 | Preparation 6 | ```python 7 | cd ${YOUR_PATH_TO_SEER} 8 | conda activate seer 9 | ``` 10 | 11 | ## Pre-train (DROID FULL) 12 | * For single-node pre-training: 13 | ```bash 14 | bash scripts/REAL/single_node_full_cluster.sh 15 | ``` 16 | * For multi-node pre-training: 17 | ```bash 18 | bash scripts/REAL/slurm_s_full_cluster.sh 19 | ``` 20 | ## Pre-train (DROID with Language) 21 | * For single-node pre-training: 22 | ```bash 23 | bash scripts/REAL/single_node_language_cluster.sh 24 | ``` 25 | * For multi-node pre-training: 26 | ```bash 27 | bash scripts/REAL/slurm_s_language_cluster.sh 28 | ``` 29 | ## Pre-train (OXE) 30 | * For multi-node pre-training: 31 | We should first generate data info 32 | ```bash 33 | use the oxe_dataset_info in Seer/utils/real_ft_data.py 34 | ``` 35 | Then we train the model 36 | ```bash 37 | bash scripts/REAL/slurm_s_language_cluster.sh 38 | ``` 39 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 3 | Each function should add entries to the "task" dict. 
4 | """ 5 | 6 | import tensorflow as tf 7 | 8 | from octo_oxe_data_utils.utils.data_utils import tree_merge 9 | 10 | 11 | def uniform(traj: dict) -> dict: 12 | """Relabels with a true uniform distribution over future states.""" 13 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 14 | 15 | # select a random future index for each transition i in the range [i + 1, traj_len) 16 | rand = tf.random.uniform([traj_len]) 17 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 18 | high = tf.cast(traj_len, tf.float32) 19 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 20 | 21 | # sometimes there are floating-point errors that cause an out-of-bounds 22 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 23 | 24 | # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from 25 | # "observation" and "task" properly) 26 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 27 | traj["task"] = tree_merge(traj["task"], goal) 28 | 29 | return traj 30 | -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, ROS-Industrial 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 18 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 20 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 21 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/README.md: -------------------------------------------------------------------------------- 1 | # Robotiq 2F-85 Description (MJCF) 2 | 3 | Requires MuJoCo 2.2.2 or later. 4 | 5 | ## Overview 6 | 7 | This package contains a simplified robot description (MJCF) of the [Robotiq 85mm 8 | 2-Finger Adaptive 9 | Gripper](https://robotiq.com/products/2f85-140-adaptive-robot-gripper) developed 10 | by [Robotiq](https://robotiq.com/). It is derived from the [publicly available 11 | URDF 12 | description](https://github.com/ros-industrial/robotiq/tree/kinetic-devel/robotiq_2f_85_gripper_visualization). 13 | 14 |

15 | [figure: 2f85.png] 16 |

17 | 18 | ## URDF → MJCF derivation steps 19 | 20 | 1. Added ` ` to the URDF's 21 | `` clause in order to preserve visual geometries. 22 | 2. Loaded the URDF into MuJoCo and saved a corresponding MJCF. 23 | 3. Manually edited the MJCF to extract common properties into the `` section. 24 | 4. Added `` clauses to prevent collisions between the linkage bodies. 25 | 5. Broke up collision pads into two pads for more contacts. 26 | 6. Increased pad friction and priority. 27 | 7. Added `impratio=10` for better noslip. 28 | 8. Added `scene.xml` which includes the robot, with a textured groundplane, skybox, and haze. 29 | 9. Added hanging box to `scene.xml`. 30 | 31 | ## License 32 | 33 | This model is released under a [BSD-2-Clause License](LICENSE). 34 | -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer/scratch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### NEED TO CHANGE ### 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | save_checkpoint_path="checkpoints/" 5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | 7 | node=8 8 | node_num=8 9 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 10 | --traj_cons \ 11 | --rgb_pad 10 \ 12 | --gripper_pad 4 \ 13 | --gradient_accumulation_steps 1 \ 14 | --bf16_module "vision_encoder" \ 15 | --vit_checkpoint_path ${vit_checkpoint_path} \ 16 | --calvin_dataset ${calvin_dataset_path} \ 17 | --workers 8 \ 18 | --lr_scheduler cosine \ 19 | --save_every_iter 100000 \ 20 | --num_epochs 20 \ 21 | --seed 42 \ 22 | --batch_size 8 \ 23 | --precision fp32 \ 24 | --learning_rate 1e-3 \ 25 | --finetune_type "calvin" \ 26 | --wandb_project seer \ 27 | --weight_decay 1e-4 \ 28 | --num_resampler_query 6 \ 29 | --run_name scratch_seer_calvin_abc_d \ 30 | --save_checkpoint \ 31 | --save_checkpoint_path ${save_checkpoint_path} \ 32 | --transformer_layers 24 \ 33 | --phase "finetune" \ 34 | --action_pred_steps 3 \ 35 | --sequence_length 10 \ 36 | --future_steps 3 \ 37 | --window_size 13 \ 38 | --obs_pred \ 39 | --loss_image \ 40 | --loss_action \ 41 | --report_to_wandb \ 42 | -------------------------------------------------------------------------------- /scripts/REAL/slurm_s_oxe_cluster.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="xxx/preprocess" 4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 5 | ### NEED TO CHANGE ### 6 | 7 | python slurm_train_intern.py \ 8 | --traj_cons \ 9 | --rgb_pad 10 \ 10 | --gripper_pad 4 \ 11 | --gradient_accumulation_steps 2 \ 12 | --bf16_module "vision_encoder" \ 13 | --vit_checkpoint_path ${vit_checkpoint_path} \ 14 | --calvin_dataset "" \ 15 | --workers 8 \ 16 | --lr_scheduler cosine \ 17 | --save_every_iter 20000 \ 18 | --num_epochs 30 \ 19 | --seed 42 \ 20 | --batch_size 32 \ 21 | --precision fp32 \ 22 | --learning_rate 1e-4 \ 23 | --save_checkpoint \ 24 | --finetune_type "oxe" \ 25 | --wandb_project seer \ 26 | --weight_decay 1e-4 \ 27 | --num_resampler_query 6 \ 28 | --run_name mn_oxe \ 29 | --save_checkpoint_path ${save_checkpoint_path} \ 30 | --except_lang \ 31 | --transformer_layers 24 \ 32 | --phase "pretrain" \ 33 | 
--obs_pred \ 34 | --action_pred_steps 3 \ 35 | --sequence_length 11 \ 36 | --window_size 11 \ 37 | --future_steps 3 \ 38 | --loss_action \ 39 | --loss_image \ 40 | --atten_goal 4 \ 41 | --atten_goal_state \ 42 | --atten_only_obs \ 43 | --real_dataset_names "" \ 44 | --root_dir ${root_dir} \ 45 | --report_to_wandb \ 46 | --warmup_epochs 3 \ -------------------------------------------------------------------------------- /scripts/REAL/deploy.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | resume_from_checkpoint="xxx/xxx.pth" 3 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 4 | ### NEED TO CHANGE ### 5 | 6 | IFS='/' read -ra path_parts <<< "$resume_from_checkpoint" 7 | run_name="${path_parts[-2]}" 8 | log_name="${path_parts[-1]}" 9 | log_folder="eval_logs/$run_name" 10 | mkdir -p "$log_folder" 11 | log_file="eval_logs/$run_name/evaluate_$log_name.log" 12 | 13 | node=1 14 | node_num=1 15 | # vision_encoder_causal_transformer_image_decoder 16 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10113 deploy.py \ 17 | --traj_cons \ 18 | --rgb_pad 10 \ 19 | --gripper_pad 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --bf16_module "vision_encoder" \ 22 | --vit_checkpoint_path ${vit_checkpoint_path} \ 23 | --workers 16 \ 24 | --calvin_dataset "" \ 25 | --lr_scheduler cosine \ 26 | --save_every_iter 50000 \ 27 | --num_epochs 20 \ 28 | --seed 42 \ 29 | --batch_size 64 \ 30 | --precision fp32 \ 31 | --weight_decay 1e-4 \ 32 | --num_resampler_query 6 \ 33 | --run_name test \ 34 | --transformer_layers 24 \ 35 | --phase "evaluate" \ 36 | --finetune_type "real" \ 37 | --action_pred_steps 3 \ 38 | --future_steps 3 \ 39 | --sequence_length 7 \ 40 | --obs_pred \ 41 | --resume_from_checkpoint ${resume_from_checkpoint} \ 42 | --real_eval_max_steps 600 \ 43 | --eval_libero_ensembling \ -------------------------------------------------------------------------------- /docs/REAL-WORLD_INFERENCE.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | Seer has been encapsulated into a real-world controller, allowing for straightforward adaptation to downstream tasks. To ensure smooth implementation and avoid common errors, we offer the following recommendations: 3 | * :fire: **Proprio First v.s. Image First:** During data collection and inference, always acquire robot proprioception data first, followed by image observations. This order minimizes the timestep interval since capturing images is significantly more time-intensive than reading proprioception data. 4 | * :fire: **Delta-to-Absolute Action Labels:** To simplify matrix computation and transformation, we provide a delta-action-to-absolute-action conversion in the [deployment script](../deploy.py). This aligns with the absolute-action-to-delta-action transformation found in the [post-process script](../utils/real_ft_data.py). 5 | * :fire: **Consistent Control Frequency:** Ensure that the control frequencies used during data collection match those during inference. Discrepancies in frequency can lead to inconsistent results. 6 | 7 | ## :star: Real-World Controller 8 | A [wrapped seer controller](../real_controller/controller.py) is provided for real-world deployment. This controller is modular and can be easily adapted to specific tasks or environments. 
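For reference, the delta-to-absolute conversion mentioned above boils down to scaling the predicted deltas and chaining them onto the last absolute end-effector pose. The sketch below mirrors the helpers in the [deployment script](../deploy.py); the wrapper function `delta_to_absolute` and the constant names are illustrative, while the 0.02 / 0.05 scaling factors follow the training setup:
```python
import numpy as np
from scipy.spatial.transform import Rotation as R

MAX_REL_POS, MAX_REL_ORN = 0.02, 0.05  # same scaling as used during training

def _6d_to_pose(pose6d, degrees=False):
    # [x, y, z, roll, pitch, yaw] -> 4x4 homogeneous transform
    pose = np.eye(4)
    pose[:3, 3] = pose6d[:3]
    pose[:3, :3] = R.from_euler("xyz", pose6d[3:6], degrees=degrees).as_matrix()
    return pose

def pose_to_6d(pose, degrees=False):
    # 4x4 homogeneous transform -> [x, y, z, roll, pitch, yaw]
    pose6d = np.zeros(6)
    pose6d[:3] = pose[:3, 3]
    pose6d[3:6] = R.from_matrix(pose[:3, :3]).as_euler("xyz", degrees=degrees)
    return pose6d

def delta_to_absolute(last2robot_pose, pred_pos, pred_euler):
    # Un-normalize the predicted delta, chain it onto the last absolute pose,
    # and return the updated 4x4 pose plus its 6-DoF target representation.
    cur2last_pose = _6d_to_pose(np.concatenate([pred_pos * MAX_REL_POS,
                                                pred_euler * MAX_REL_ORN]))
    last2robot_pose = last2robot_pose @ cur2last_pose
    return last2robot_pose, pose_to_6d(last2robot_pose)
```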
9 | 10 | ## :star2: Real-World Deployment Pseudocode 11 | To deploy the wrapped Seer controller for real-world tasks, modify the [deployment script](../deploy.py) to fit your specific environment. Then, execute the deployment with the following command: 12 | ```python 13 | bash scripts/REAL/deploy.sh 14 | ``` -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### NEED TO CHANGE ### 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | save_checkpoint_path="checkpoints/" 5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | ### NEED TO CHANGE ### 7 | 8 | node=8 9 | node_num=8 10 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 11 | --traj_cons \ 12 | --rgb_pad 10 \ 13 | --gripper_pad 4 \ 14 | --gradient_accumulation_steps 1 \ 15 | --bf16_module "vision_encoder" \ 16 | --vit_checkpoint_path ${vit_checkpoint_path} \ 17 | --calvin_dataset ${calvin_dataset_path} \ 18 | --workers 8 \ 19 | --lr_scheduler cosine \ 20 | --save_every_iter 100000 \ 21 | --num_epochs 20 \ 22 | --seed 42 \ 23 | --batch_size 10 \ 24 | --precision fp32 \ 25 | --learning_rate 1e-4 \ 26 | --finetune_type "calvin" \ 27 | --wandb_project seer \ 28 | --weight_decay 1e-4 \ 29 | --num_resampler_query 6 \ 30 | --run_name pretrain_seer_calvin_abc_d \ 31 | --save_checkpoint_path ${save_checkpoint_path} \ 32 | --transformer_layers 24 \ 33 | --phase "pretrain" \ 34 | --action_pred_steps 3 \ 35 | --sequence_length 14 \ 36 | --future_steps 3 \ 37 | --window_size 17 \ 38 | --obs_pred \ 39 | --loss_image \ 40 | --loss_action \ 41 | --atten_goal 4 \ 42 | --atten_goal_state \ 43 | --atten_only_obs \ 44 | --except_lang \ 45 | --save_checkpoint \ 46 | --report_to_wandb \ 47 | -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 41 | -------------------------------------------------------------------------------- /scripts/LIBERO_LONG/Seer/scratch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### NEED TO CHANGE ### 4 | save_checkpoint_path="checkpoints/" 5 | root_dir="PATH_TO_PARENT_DIR_OF_LIBERO_CONVERTED" 6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" 7 | libero_path="PATH_TO_LIBERO" 8 | ### NEED TO CHANGE ### 9 | calvin_dataset_path="calvin/dataset/task_ABC_D" 10 | 11 | node=1 12 | node_num=8 13 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 14 | --traj_cons \ 15 | --rgb_pad 10 \ 16 | --gripper_pad 4 \ 17 | --gradient_accumulation_steps 4 \ 18 | --bf16_module "vision_encoder" \ 19 | --vit_checkpoint_path ${vit_checkpoint_path} \ 20 | --calvin_dataset ${calvin_dataset_path} \ 21 | --workers 8 \ 22 | --lr_scheduler cosine \ 23 | --save_every_iter 100000 \ 24 | --num_epochs 40 \ 25 | --seed 42 \ 26 | --batch_size 16 \ 27 | --precision fp32 \ 28 | --learning_rate 1e-3 \ 29 | --save_checkpoint \ 30 | --finetune_type libero_finetune \ 31 | --root_dir ${root_dir} \ 32 | --wandb_project seer \ 33 | --weight_decay 1e-4 \ 34 | --num_resampler_query 6 \ 35 | --run_name libero_scratch \ 36 | 
--save_checkpoint_path ${save_checkpoint_path} \ 37 | --transformer_layers 24 \ 38 | --phase "finetune" \ 39 | --obs_pred \ 40 | --action_pred_steps 3 \ 41 | --sequence_length 7 \ 42 | --future_steps 3 \ 43 | --window_size 10 \ 44 | --loss_image \ 45 | --loss_action \ 46 | --save_checkpoint_seq 1 \ 47 | --start_save_checkpoint 25 \ 48 | --gripper_width \ 49 | --warmup_epochs 5 \ 50 | --libero_path ${libero_path} \ 51 | --report_to_wandb \ -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### NEED TO CHANGE ### 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | save_checkpoint_path="checkpoints/" 5 | finetune_from_pretrained_ckpt="checkpoints/pretrain_calvin_abc_d/4.pth" 6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 7 | ### NEED TO CHANGE ### 8 | 9 | node=4 10 | node_num=8 11 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 12 | --traj_cons \ 13 | --rgb_pad 10 \ 14 | --gripper_pad 4 \ 15 | --gradient_accumulation_steps 1 \ 16 | --bf16_module "vision_encoder" \ 17 | --vit_checkpoint_path ${vit_checkpoint_path} \ 18 | --calvin_dataset ${calvin_dataset_path} \ 19 | --workers 8 \ 20 | --lr_scheduler cosine \ 21 | --save_every_iter 100000 \ 22 | --num_epochs 20 \ 23 | --seed 42 \ 24 | --batch_size 16 \ 25 | --precision fp32 \ 26 | --learning_rate 1e-3 \ 27 | --save_checkpoint \ 28 | --finetune_type "calvin" \ 29 | --wandb_project seer \ 30 | --weight_decay 1e-4 \ 31 | --num_resampler_query 6 \ 32 | --run_name finetune_calvin_abc_d_ep5 \ 33 | --save_checkpoint_path ${save_checkpoint_path} \ 34 | --transformer_layers 24 \ 35 | --phase "finetune" \ 36 | --action_pred_steps 3 \ 37 | --sequence_length 10 \ 38 | --future_steps 3 \ 39 | --window_size 13 \ 40 | --obs_pred \ 41 | --loss_image \ 42 | --loss_action \ 43 | --report_to_wandb \ 44 | --reset_action_token \ 45 | --reset_obs_token \ 46 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \ -------------------------------------------------------------------------------- /scripts/LIBERO_LONG/Seer/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### NEED TO CHANGE ### 4 | save_checkpoint_path="checkpoints/" 5 | root_dir="PATH_TO_PARENT_DIR_OF_LIBERO_CONVERTED" 6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" 7 | libero_path="PATH_TO_LIBERO" 8 | ### NEED TO CHANGE ### 9 | calvin_dataset_path="calvin/dataset/task_ABC_D" 10 | 11 | node=1 12 | node_num=8 13 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 14 | --traj_cons \ 15 | --rgb_pad 10 \ 16 | --gripper_pad 4 \ 17 | --gradient_accumulation_steps 8 \ 18 | --bf16_module "vision_encoder" \ 19 | --vit_checkpoint_path ${vit_checkpoint_path} \ 20 | --calvin_dataset ${calvin_dataset_path} \ 21 | --workers 16 \ 22 | --lr_scheduler cosine \ 23 | --save_every_iter 100000 \ 24 | --num_epochs 30 \ 25 | --seed 42 \ 26 | --batch_size 10 \ 27 | --precision fp32 \ 28 | --learning_rate 1e-4 \ 29 | --save_checkpoint \ 30 | --finetune_type libero_pretrain \ 31 | --root_dir ${root_dir} \ 32 | --wandb_project seer \ 33 | --weight_decay 1e-4 \ 34 | --num_resampler_query 6 \ 35 | --run_name libero_pretrain \ 36 | --save_checkpoint_path 
${save_checkpoint_path} \ 37 | --transformer_layers 24 \ 38 | --phase "pretrain" \ 39 | --obs_pred \ 40 | --sequence_length 11 \ 41 | --action_pred_steps 3 \ 42 | --future_steps 3 \ 43 | --atten_goal 4 \ 44 | --window_size 11 \ 45 | --loss_image \ 46 | --loss_action \ 47 | --gripper_width \ 48 | --atten_only_obs \ 49 | --atten_goal_state \ 50 | --mask_l_obs_ratio 0.5 \ 51 | --warmup_epochs 1 \ 52 | --libero_path ${libero_path} \ 53 | --report_to_wandb \ -------------------------------------------------------------------------------- /docs/LIBERO_LONG_RUN.md: -------------------------------------------------------------------------------- 1 | # Running 2 | ## Notice 3 | 4 | For convenience, some checkpoints, such as the MAE-pretrained ViT-B model, are provided for manual download. Users must update the following paths accordingly. Relevant checkpoints can be acquired from the [website](https://drive.google.com/drive/folders/1zwqGvKKtjyuWdDaNSLVGJprJMPoSqAPk?usp=drive_link). 5 | * :exclamation: **pretrain.sh, finetune.sh, scratch, eval.sh:** 6 | Please update the following: 7 | * **save_checkpoint_path** to the parent directory where your experiment checkpoints are saved. Recommend to create a ```checkpoints``` folder in the project root directory. 8 | * **finetune_from_pretrained_ckpt** to the location of your pre-trained checkpoint. 9 | * **resume_from_checkpoint** to the location of your fine-tuned checkpoint. 10 | * **vit_checkpoint_path** to the location of your ViT checkpoint (downloaded from the [website](https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing)). Recommend to be stored in ```checkpoints/vit_mae/mae_pretrain_vit_base.pth```. 11 | * **libero_path** to the location of LIBERO dir. 12 | 13 | ## Seer 14 | ### Convert Data 15 | ```bash 16 | python utils/convert_libero_per_step.py 17 | ``` 18 | 19 | ### Pre-train 20 | ```bash 21 | # Pre-train Seer on LIBERO-90 dataset 22 | bash scripts/LIBERO_LONG/Seer/pretrain.sh 23 | ``` 24 | 25 | ### Fine-tune 26 | ```bash 27 | # Fine-tune Seer on LIBERO-10 dataset 28 | bash scripts/LIBERO_LONG/Seer/finetune.sh 29 | ``` 30 | 31 | ### Train from Scratch 32 | ```bash 33 | # Train Seer on LIBERO-10 dataset from scratch 34 | bash scripts/LIBERO_LONG/Seer/scratch.sh 35 | ``` 36 | 37 | ### Eval 38 | ```bash 39 | # Evaluate Seer on LIBERO-10 benchmark 40 | bash scripts/LIBERO_LONG/Seer/eval.sh 41 | ``` -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer-Large/scratch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### need to change to your path ### 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | save_checkpoint_path="checkpoints/" 5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | # node=8 7 | node_num=8 8 | torchrun \ 9 | --nnodes=${MLP_WORKER_NUM} \ 10 | --node_rank=${MLP_ROLE_INDEX} \ 11 | --nproc_per_node=${node_num} \ 12 | --master_addr=${MLP_WORKER_0_HOST} \ 13 | --master_port=${MLP_WORKER_0_PORT} \ 14 | train.py \ 15 | --traj_cons \ 16 | --rgb_pad 10 \ 17 | --gripper_pad 4 \ 18 | --gradient_accumulation_steps 1 \ 19 | --bf16_module "vision_encoder" \ 20 | --vit_checkpoint_path ${vit_checkpoint_path} \ 21 | --calvin_dataset ${calvin_dataset_path} \ 22 | --workers 8 \ 23 | --lr_scheduler cosine \ 24 | --save_every_iter 100000 \ 25 | --num_epochs 20 \ 26 | --seed 42 \ 27 | 
--batch_size 8 \ 28 | --precision fp32 \ 29 | --learning_rate 1e-3 \ 30 | --warmup_epochs 1 \ 31 | --finetune_type "calvin" \ 32 | --wandb_project seer \ 33 | --weight_decay 1e-4 \ 34 | --num_resampler_query 16 \ 35 | --num_obs_token_per_image 16 \ 36 | --run_name scratch-Seer-Large \ 37 | --save_checkpoint \ 38 | --save_checkpoint_path ${save_checkpoint_path} \ 39 | --transformer_layers 24 \ 40 | --hidden_dim 1024 \ 41 | --transformer_heads 16 \ 42 | --phase "finetune" \ 43 | --action_pred_steps 3 \ 44 | --sequence_length 10 \ 45 | --future_steps 3 \ 46 | --window_size 13 \ 47 | --obs_pred \ 48 | --loss_image \ 49 | --loss_action \ 50 | --report_to_wandb \ 51 | -------------------------------------------------------------------------------- /scripts/LIBERO_LONG/Seer/eval.sh: -------------------------------------------------------------------------------- 1 | pthlist=("30" "31" "32" "33" "34" "35" "36" "37" "38" "39") 2 | for ckpt_id in "${pthlist[@]}"; do 3 | resume_from_checkpoint="/home/tianyang/Checkpoints/libero_scratch" 4 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" 5 | this_resume_from_checkpoint="${resume_from_checkpoint}/${ckpt_id}.pth" 6 | save_checkpoint_path="checkpoints/" 7 | dirname=$(basename "$resume_from_checkpoint") 8 | LOG_DIR="/home/tianyang/Code/Eval/${dirname}" 9 | mkdir -p ${LOG_DIR} 10 | test_id="${ckpt_id}" 11 | logfile="${LOG_DIR}/${test_id}.log" 12 | 13 | node=1 14 | node_num=8 15 | 16 | python -m torch.distributed.run --nnodes=${node} --nproc_per_node=${node_num} --master_port=10133 eval_libero.py \ 17 | --traj_cons \ 18 | --rgb_pad 10 \ 19 | --gripper_pad 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --bf16_module "vision_encoder" \ 22 | --vit_checkpoint_path ${vit_checkpoint_path} \ 23 | --calvin_dataset "" \ 24 | --workers 16 \ 25 | --lr_scheduler cosine \ 26 | --save_every_iter 50000 \ 27 | --num_epochs 20 \ 28 | --seed 42 \ 29 | --batch_size 64 \ 30 | --precision fp32 \ 31 | --weight_decay 1e-4 \ 32 | --num_resampler_query 6 \ 33 | --run_name test \ 34 | --transformer_layers 24 \ 35 | --phase "evaluate" \ 36 | --finetune_type "libero_10" \ 37 | --save_checkpoint_path ${save_checkpoint_path} \ 38 | --action_pred_steps 3 \ 39 | --future_steps 3 \ 40 | --sequence_length 7 \ 41 | --obs_pred \ 42 | --gripper_width \ 43 | --eval_libero_ensembling \ 44 | --resume_from_checkpoint ${this_resume_from_checkpoint} | tee ${logfile} 45 | done -------------------------------------------------------------------------------- /scripts/REAL/slurm_s_full_cluster.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="xxx/preprocess" 4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 5 | ### NEED TO CHANGE ### 6 | 7 | ### EXAMPLE ### 8 | # - root_dir 9 | # - droid_success 10 | # - epsiodes 11 | # - 000000 12 | # - ...... 
13 | # - xxxxxx 14 | # - meta_info.h5 15 | # - shape_info.h5 16 | ### EXAMPLE ### 17 | 18 | python slurm_train_intern.py \ 19 | --traj_cons \ 20 | --rgb_pad 10 \ 21 | --gripper_pad 4 \ 22 | --gradient_accumulation_steps 2 \ 23 | --bf16_module "vision_encoder" \ 24 | --vit_checkpoint_path ${vit_checkpoint_path} \ 25 | --calvin_dataset "" \ 26 | --workers 8 \ 27 | --lr_scheduler cosine \ 28 | --save_every_iter 20000 \ 29 | --num_epochs 30 \ 30 | --seed 42 \ 31 | --batch_size 32 \ 32 | --precision fp32 \ 33 | --learning_rate 1e-4 \ 34 | --save_checkpoint \ 35 | --finetune_type "droid" \ 36 | --wandb_project seer \ 37 | --weight_decay 1e-4 \ 38 | --num_resampler_query 6 \ 39 | --run_name mn_full_droid \ 40 | --save_checkpoint_path ${save_checkpoint_path} \ 41 | --except_lang \ 42 | --transformer_layers 24 \ 43 | --phase "pretrain" \ 44 | --obs_pred \ 45 | --action_pred_steps 3 \ 46 | --sequence_length 11 \ 47 | --window_size 11 \ 48 | --future_steps 3 \ 49 | --loss_action \ 50 | --loss_image \ 51 | --atten_goal 4 \ 52 | --atten_goal_state \ 53 | --atten_only_obs \ 54 | --real_dataset_names "" \ 55 | --root_dir ${root_dir} \ 56 | --dataset_info droid_success_full_0803 \ 57 | --report_to_wandb \ 58 | --warmup_epochs 3 \ -------------------------------------------------------------------------------- /scripts/REAL/slurm_s_language_cluster.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="xxx/preprocess" 4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 5 | ### NEED TO CHANGE ### 6 | 7 | ### EXAMPLE ### 8 | # - root_dir 9 | # - droid_success 10 | # - epsiodes 11 | # - 000000 12 | # - ...... 
13 | # - xxxxxx 14 | # - meta_info.h5 15 | # - shape_info.h5 16 | ### EXAMPLE ### 17 | 18 | python slurm_train_intern.py \ 19 | --traj_cons \ 20 | --rgb_pad 10 \ 21 | --gripper_pad 4 \ 22 | --gradient_accumulation_steps 2 \ 23 | --bf16_module "vision_encoder" \ 24 | --vit_checkpoint_path ${vit_checkpoint_path} \ 25 | --calvin_dataset "" \ 26 | --workers 8 \ 27 | --lr_scheduler cosine \ 28 | --save_every_iter 20000 \ 29 | --num_epochs 30 \ 30 | --seed 42 \ 31 | --batch_size 32 \ 32 | --precision fp32 \ 33 | --learning_rate 1e-4 \ 34 | --save_checkpoint \ 35 | --finetune_type "droid" \ 36 | --wandb_project seer \ 37 | --weight_decay 1e-4 \ 38 | --num_resampler_query 6 \ 39 | --run_name mn_lang_droid \ 40 | --save_checkpoint_path ${save_checkpoint_path} \ 41 | --except_lang \ 42 | --transformer_layers 24 \ 43 | --phase "pretrain" \ 44 | --obs_pred \ 45 | --action_pred_steps 3 \ 46 | --sequence_length 11 \ 47 | --window_size 11 \ 48 | --future_steps 3 \ 49 | --loss_action \ 50 | --loss_image \ 51 | --atten_goal 4 \ 52 | --atten_goal_state \ 53 | --atten_only_obs \ 54 | --real_dataset_names "" \ 55 | --root_dir ${root_dir} \ 56 | --dataset_info droid_success_languaged_0803 \ 57 | --report_to_wandb \ 58 | --warmup_epochs 3 \ -------------------------------------------------------------------------------- /scripts/LIBERO_LONG/Seer/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### NEED TO CHANGE ### 4 | save_checkpoint_path="checkpoints/" 5 | root_dir="PATH_TO_PARENT_DIR_OF_LIBERO_CONVERTED" 6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" 7 | finetune_from_pretrained_ckpt="libero_pretrain/14.pth" 8 | libero_path="PATH_TO_LIBERO" 9 | ### NEED TO CHANGE ### 10 | calvin_dataset_path="calvin/dataset/task_ABC_D" 11 | 12 | node=1 13 | node_num=8 14 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 15 | --traj_cons \ 16 | --rgb_pad 10 \ 17 | --gripper_pad 4 \ 18 | --gradient_accumulation_steps 4 \ 19 | --bf16_module "vision_encoder" \ 20 | --vit_checkpoint_path ${vit_checkpoint_path} \ 21 | --calvin_dataset ${calvin_dataset_path} \ 22 | --workers 8 \ 23 | --lr_scheduler cosine \ 24 | --save_every_iter 100000 \ 25 | --num_epochs 40 \ 26 | --seed 42 \ 27 | --batch_size 16 \ 28 | --precision fp32 \ 29 | --learning_rate 1e-3 \ 30 | --save_checkpoint \ 31 | --finetune_type libero_finetune \ 32 | --root_dir ${root_dir} \ 33 | --wandb_project seer \ 34 | --weight_decay 1e-4 \ 35 | --num_resampler_query 6 \ 36 | --run_name libero_finetune \ 37 | --save_checkpoint_path ${save_checkpoint_path} \ 38 | --transformer_layers 24 \ 39 | --phase "finetune" \ 40 | --obs_pred \ 41 | --action_pred_steps 3 \ 42 | --sequence_length 7 \ 43 | --future_steps 3 \ 44 | --window_size 10 \ 45 | --loss_image \ 46 | --loss_action \ 47 | --reset_action_token \ 48 | --reset_obs_token \ 49 | --save_checkpoint_seq 1 \ 50 | --start_save_checkpoint 25 \ 51 | --gripper_width \ 52 | --warmup_epochs 5 \ 53 | --libero_path ${libero_path} \ 54 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \ 55 | --report_to_wandb \ -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer-Large/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### need to change to your path ### 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | save_checkpoint_path="checkpoints/" 5 | 
vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | # node=8 7 | node_num=8 8 | torchrun \ 9 | --nnodes=${MLP_WORKER_NUM} \ 10 | --node_rank=${MLP_ROLE_INDEX} \ 11 | --nproc_per_node=${node_num} \ 12 | --master_addr=${MLP_WORKER_0_HOST} \ 13 | --master_port=${MLP_WORKER_0_PORT} \ 14 | train.py \ 15 | --traj_cons \ 16 | --rgb_pad 10 \ 17 | --gripper_pad 4 \ 18 | --gradient_accumulation_steps 1 \ 19 | --bf16_module "vision_encoder" \ 20 | --vit_checkpoint_path ${vit_checkpoint_path} \ 21 | --calvin_dataset ${calvin_dataset_path} \ 22 | --workers 8 \ 23 | --lr_scheduler cosine \ 24 | --save_every_iter 100000 \ 25 | --num_epochs 20 \ 26 | --seed 42 \ 27 | --batch_size 8 \ 28 | --precision fp32 \ 29 | --learning_rate 1e-4 \ 30 | --finetune_type "calvin" \ 31 | --wandb_project seer \ 32 | --weight_decay 1e-4 \ 33 | --num_resampler_query 16 \ 34 | --num_obs_token_per_image 16 \ 35 | --run_name pretrain_Seer-Large \ 36 | --save_checkpoint_path ${save_checkpoint_path} \ 37 | --transformer_layers 24 \ 38 | --hidden_dim 1024 \ 39 | --transformer_heads 16 \ 40 | --phase "pretrain" \ 41 | --action_pred_steps 3 \ 42 | --sequence_length 14 \ 43 | --future_steps 3 \ 44 | --window_size 17 \ 45 | --obs_pred \ 46 | --loss_image \ 47 | --loss_action \ 48 | --atten_goal 4 \ 49 | --atten_goal_state \ 50 | --atten_only_obs \ 51 | --attn_robot_proprio_state \ 52 | --except_lang \ 53 | --save_checkpoint \ 54 | --report_to_wandb \ 55 | -------------------------------------------------------------------------------- /scripts/REAL/single_node_full_cluster.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="xxx/preprocess" 4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 5 | ### NEED TO CHANGE ### 6 | 7 | ### EXAMPLE ### 8 | # - root_dir 9 | # - droid_success 10 | # - epsiodes 11 | # - 000000 12 | # - ...... 
13 | # - xxxxxx 14 | # - meta_info.h5 15 | # - shape_info.h5 16 | ### EXAMPLE ### 17 | 18 | node=4 19 | node_num=8 20 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=62345 train.py \ 21 | --traj_cons \ 22 | --rgb_pad 10 \ 23 | --gripper_pad 4 \ 24 | --gradient_accumulation_steps 2 \ 25 | --bf16_module "vision_encoder" \ 26 | --vit_checkpoint_path ${vit_checkpoint_path} \ 27 | --calvin_dataset "" \ 28 | --workers 8 \ 29 | --lr_scheduler cosine \ 30 | --save_every_iter 20000 \ 31 | --num_epochs 20 \ 32 | --seed 42 \ 33 | --batch_size 32 \ 34 | --precision fp32 \ 35 | --learning_rate 1e-4 \ 36 | --save_checkpoint \ 37 | --finetune_type "droid" \ 38 | --wandb_project seer \ 39 | --weight_decay 1e-4 \ 40 | --num_resampler_query 6 \ 41 | --run_name sn_full_droid \ 42 | --save_checkpoint_path ${save_checkpoint_path} \ 43 | --except_lang \ 44 | --transformer_layers 24 \ 45 | --phase "pretrain" \ 46 | --obs_pred \ 47 | --action_pred_steps 3 \ 48 | --sequence_length 11 \ 49 | --window_size 11 \ 50 | --future_steps 3 \ 51 | --loss_action \ 52 | --loss_image \ 53 | --atten_goal 4 \ 54 | --atten_goal_state \ 55 | --atten_only_obs \ 56 | --real_dataset_names "" \ 57 | --root_dir ${root_dir} \ 58 | --dataset_info droid_success_full_0803 \ 59 | --warmup_epochs 3 \ 60 | --report_to_wandb \ 61 | -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export GIT_PYTHON_REFRESH=quiet 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | calvin_conf_path="calvin/calvin_models/conf" 5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | save_checkpoint_path="checkpoints/" 7 | ### NEED TO CHANGE the checkpoint path ### 8 | resume_from_checkpoint="checkpoints/xxx/xx.pth" 9 | IFS='/' read -ra path_parts <<< "$resume_from_checkpoint" 10 | run_name="${path_parts[-2]}" 11 | log_name="${path_parts[-1]}" 12 | log_folder="eval_logs/$run_name" 13 | mkdir -p "$log_folder" 14 | log_file="eval_logs/$run_name/evaluate_$log_name.log" 15 | node=1 16 | node_num=8 17 | 18 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10012 eval_calvin.py \ 19 | --traj_cons \ 20 | --rgb_pad 10 \ 21 | --gripper_pad 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --bf16_module "vision_encoder" \ 24 | --vit_checkpoint_path ${vit_checkpoint_path} \ 25 | --calvin_dataset ${calvin_dataset_path} \ 26 | --calvin_conf_path ${calvin_conf_path} \ 27 | --workers 16 \ 28 | --lr_scheduler cosine \ 29 | --save_every_iter 50000 \ 30 | --num_epochs 20 \ 31 | --seed 42 \ 32 | --batch_size 64 \ 33 | --precision fp32 \ 34 | --weight_decay 1e-4 \ 35 | --num_resampler_query 6 \ 36 | --num_obs_token_per_image 9 \ 37 | --run_name ${run_name} \ 38 | --save_checkpoint_path ${save_checkpoint_path} \ 39 | --transformer_layers 24 \ 40 | --hidden_dim 384 \ 41 | --transformer_heads 12 \ 42 | --phase "evaluate" \ 43 | --finetune_type "calvin" \ 44 | --action_pred_steps 3 \ 45 | --sequence_length 10 \ 46 | --future_steps 3 \ 47 | --window_size 13 \ 48 | --obs_pred \ 49 | --resume_from_checkpoint ${resume_from_checkpoint} | tee ${log_file} \ 50 | -------------------------------------------------------------------------------- /scripts/REAL/single_node_language_cluster.sh: -------------------------------------------------------------------------------- 
1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="xxx/preprocess" 4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 5 | ### NEED TO CHANGE ### 6 | 7 | ### EXAMPLE ### 8 | # - root_dir 9 | # - droid_success 10 | # - epsiodes 11 | # - 000000 12 | # - ...... 13 | # - xxxxxx 14 | # - meta_info.h5 15 | # - shape_info.h5 16 | ### EXAMPLE ### 17 | 18 | node=4 19 | node_num=8 20 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=62345 train.py \ 21 | --traj_cons \ 22 | --rgb_pad 10 \ 23 | --gripper_pad 4 \ 24 | --gradient_accumulation_steps 2 \ 25 | --bf16_module "vision_encoder" \ 26 | --vit_checkpoint_path ${vit_checkpoint_path} \ 27 | --calvin_dataset "" \ 28 | --workers 8 \ 29 | --lr_scheduler cosine \ 30 | --save_every_iter 20000 \ 31 | --num_epochs 20 \ 32 | --seed 42 \ 33 | --batch_size 32 \ 34 | --precision fp32 \ 35 | --learning_rate 1e-4 \ 36 | --save_checkpoint \ 37 | --finetune_type "droid" \ 38 | --wandb_project seer \ 39 | --weight_decay 1e-4 \ 40 | --num_resampler_query 6 \ 41 | --run_name sn_lang_droid \ 42 | --save_checkpoint_path ${save_checkpoint_path} \ 43 | --except_lang \ 44 | --transformer_layers 24 \ 45 | --phase "pretrain" \ 46 | --obs_pred \ 47 | --action_pred_steps 3 \ 48 | --sequence_length 11 \ 49 | --window_size 11 \ 50 | --future_steps 3 \ 51 | --loss_action \ 52 | --loss_image \ 53 | --atten_goal 4 \ 54 | --atten_goal_state \ 55 | --atten_only_obs \ 56 | --real_dataset_names "" \ 57 | --root_dir ${root_dir} \ 58 | --dataset_info droid_success_languaged_0803 \ 59 | --warmup_epochs 3 \ 60 | --report_to_wandb \ 61 | -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer-Large/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export GIT_PYTHON_REFRESH=quiet 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | calvin_conf_path="calvin/calvin_models/conf" 5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | save_checkpoint_path="checkpoints/" 7 | ### NEED TO CHANGE the checkpoint path ### 8 | resume_from_checkpoint="checkpoints/xxx/xxx.pth" 9 | 10 | 11 | IFS='/' read -ra path_parts <<< "$resume_from_checkpoint" 12 | run_name="${path_parts[-2]}" 13 | log_name="${path_parts[-1]}" 14 | log_folder="eval_logs/$run_name" 15 | mkdir -p "$log_folder" 16 | log_file="eval_logs/$run_name/evaluate_$log_name.log" 17 | node=1 18 | node_num=8 19 | 20 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10212 eval_calvin.py\ 21 | --traj_cons \ 22 | --rgb_pad 10 \ 23 | --gripper_pad 4 \ 24 | --gradient_accumulation_steps 1 \ 25 | --bf16_module "vision_encoder" \ 26 | --vit_checkpoint_path ${vit_checkpoint_path} \ 27 | --calvin_dataset ${calvin_dataset_path} \ 28 | --calvin_conf_path ${calvin_conf_path} \ 29 | --workers 16 \ 30 | --lr_scheduler cosine \ 31 | --save_every_iter 50000 \ 32 | --num_epochs 20 \ 33 | --seed 42 \ 34 | --batch_size 64 \ 35 | --precision fp32 \ 36 | --weight_decay 1e-4 \ 37 | --num_resampler_query 16 \ 38 | --num_obs_token_per_image 16 \ 39 | --run_name ${run_name} \ 40 | --save_checkpoint_path ${save_checkpoint_path} \ 41 | --transformer_layers 24 \ 42 | --hidden_dim 1024 \ 43 | --transformer_heads 16 \ 44 | --phase "evaluate" \ 45 | --finetune_type 
"calvin" \ 46 | --action_pred_steps 3 \ 47 | --sequence_length 10 \ 48 | --future_steps 3 \ 49 | --window_size 13 \ 50 | --obs_pred \ 51 | --resume_from_checkpoint ${resume_from_checkpoint} | tee ${log_file} \ 52 | -------------------------------------------------------------------------------- /scripts/REAL/single_node_scratch.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="your_path_to_the_parent_folder_of_real_data" 4 | real_dataset_names="your_real_dataset_name" 5 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 6 | ### NEED TO CHANGE ### 7 | 8 | ### EXAMPLE ### 9 | # - root_dir 10 | # - real_dataset_names 11 | # - 0000 12 | # - 000000 13 | # - ...... 14 | # - xxxxxx 15 | # - .... 16 | # - 00xx 17 | ### EXAMPLE ### 18 | 19 | node=1 20 | node_num=8 21 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 22 | --traj_cons \ 23 | --rgb_pad 10 \ 24 | --gripper_pad 4 \ 25 | --gradient_accumulation_steps 4 \ 26 | --bf16_module "vision_encoder" \ 27 | --vit_checkpoint_path ${vit_checkpoint_path} \ 28 | --calvin_dataset "" \ 29 | --workers 8 \ 30 | --lr_scheduler cosine \ 31 | --save_every_iter 100000 \ 32 | --num_epochs 40 \ 33 | --seed 42 \ 34 | --batch_size 16 \ 35 | --precision fp32 \ 36 | --learning_rate 1e-3 \ 37 | --save_checkpoint \ 38 | --finetune_type real \ 39 | --root_dir ${root_dir} \ 40 | --wandb_project seer \ 41 | --weight_decay 1e-4 \ 42 | --num_resampler_query 6 \ 43 | --run_name sn_scratch \ 44 | --save_checkpoint_path ${save_checkpoint_path} \ 45 | --except_lang \ 46 | --transformer_layers 24 \ 47 | --phase "finetune" \ 48 | --action_pred_steps 3 \ 49 | --sequence_length 7 \ 50 | --future_steps 3 \ 51 | --window_size 10 \ 52 | --obs_pred \ 53 | --loss_action \ 54 | --loss_image \ 55 | --save_checkpoint_seq 1 \ 56 | --start_save_checkpoint 15 \ 57 | --warmup_epochs 5 \ 58 | --real_dataset_names ${real_dataset_names} \ 59 | --use_aug_data \ 60 | --report_to_wandb \ 61 | -------------------------------------------------------------------------------- /scripts/CALVIN_ABC_D/Seer-Large/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### need to change to your path ### 3 | calvin_dataset_path="calvin/dataset/task_ABC_D" 4 | save_checkpoint_path="checkpoints/" 5 | finetune_from_pretrained_ckpt="checkpoints/pretrain_Seer_ptbs512_24layers_16heads_hd1024-Large/5.pth" 6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 7 | # node=8 8 | node_num=8 9 | torchrun \ 10 | --nnodes=${MLP_WORKER_NUM} \ 11 | --node_rank=${MLP_ROLE_INDEX} \ 12 | --nproc_per_node=${node_num} \ 13 | --master_addr=${MLP_WORKER_0_HOST} \ 14 | --master_port=${MLP_WORKER_0_PORT} \ 15 | train.py \ 16 | --traj_cons \ 17 | --rgb_pad 10 \ 18 | --gripper_pad 4 \ 19 | --gradient_accumulation_steps 1 \ 20 | --bf16_module "vision_encoder" \ 21 | --vit_checkpoint_path ${vit_checkpoint_path} \ 22 | --calvin_dataset ${calvin_dataset_path} \ 23 | --workers 8 \ 24 | --lr_scheduler cosine \ 25 | --save_every_iter 100000 \ 26 | --num_epochs 20 \ 27 | --seed 42 \ 28 | --batch_size 8 \ 29 | --precision fp32 \ 30 | --learning_rate 1e-3 \ 31 | --warmup_epochs 1 \ 32 | --finetune_type "calvin" \ 33 
| --wandb_project seer \ 34 | --weight_decay 1e-4 \ 35 | --num_resampler_query 16 \ 36 | --num_obs_token_per_image 16 \ 37 | --run_name finetune_Seer-Large \ 38 | --save_checkpoint_path ${save_checkpoint_path} \ 39 | --transformer_layers 24 \ 40 | --hidden_dim 1024 \ 41 | --transformer_heads 16 \ 42 | --phase "finetune" \ 43 | --action_pred_steps 3 \ 44 | --sequence_length 10 \ 45 | --future_steps 3 \ 46 | --window_size 13 \ 47 | --obs_pred \ 48 | --loss_image \ 49 | --loss_action \ 50 | --save_checkpoint \ 51 | --report_to_wandb \ 52 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \ 53 | 54 | 55 | -------------------------------------------------------------------------------- /scripts/REAL/single_node_ft.sh: -------------------------------------------------------------------------------- 1 | ### NEED TO CHANGE ### 2 | save_checkpoint_path="xxx/checkpoints" 3 | root_dir="your_path_to_the_parent_folder_of_real_data" 4 | real_dataset_names="your_real_dataset_name" 5 | finetune_from_pretrained_ckpt="xxx/xxx.pth" 6 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing 7 | ### NEED TO CHANGE ### 8 | 9 | ### EXAMPLE ### 10 | # - root_dir 11 | # - real_dataset_names 12 | # - 0000 13 | # - 000000 14 | # - ...... 15 | # - xxxxxx 16 | # - .... 17 | # - 00xx 18 | ### EXAMPLE ### 19 | 20 | node=1 21 | node_num=8 22 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \ 23 | --traj_cons \ 24 | --rgb_pad 10 \ 25 | --gripper_pad 4 \ 26 | --gradient_accumulation_steps 4 \ 27 | --bf16_module "vision_encoder" \ 28 | --vit_checkpoint_path ${vit_checkpoint_path} \ 29 | --calvin_dataset "" \ 30 | --workers 8 \ 31 | --lr_scheduler cosine \ 32 | --save_every_iter 100000 \ 33 | --num_epochs 40 \ 34 | --seed 42 \ 35 | --batch_size 16 \ 36 | --precision fp32 \ 37 | --learning_rate 1e-3 \ 38 | --save_checkpoint \ 39 | --finetune_type real \ 40 | --root_dir ${root_dir} \ 41 | --wandb_project seer \ 42 | --weight_decay 1e-4 \ 43 | --num_resampler_query 6 \ 44 | --run_name sn_ft \ 45 | --save_checkpoint_path ${save_checkpoint_path} \ 46 | --except_lang \ 47 | --transformer_layers 24 \ 48 | --phase "finetune" \ 49 | --action_pred_steps 3 \ 50 | --sequence_length 7 \ 51 | --future_steps 3 \ 52 | --window_size 10 \ 53 | --obs_pred \ 54 | --loss_action \ 55 | --loss_image \ 56 | --save_checkpoint_seq 1 \ 57 | --start_save_checkpoint 15 \ 58 | --warmup_epochs 5 \ 59 | --real_dataset_names ${real_dataset_names} \ 60 | --reset_action_token \ 61 | --reset_obs_token \ 62 | --use_aug_data \ 63 | --report_to_wandb \ 64 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \ 65 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains basic logic for randomly zero-ing out keys in the task specification. 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | from octo_oxe_data_utils.utils.data_utils import to_padding 8 | 9 | 10 | def delete_task_conditioning( 11 | traj: dict, 12 | keep_image_prob: float, 13 | ): 14 | """ 15 | Randomly drops out either the goal images or the language instruction. Only does something if both of 16 | these are present. 17 | 18 | Args: 19 | traj: A dictionary containing trajectory data. Should have a "task" key. 
20 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 21 | instruction is 1 - keep_image_prob. 22 | """ 23 | if "language_instruction" not in traj["task"]: 24 | return traj 25 | 26 | image_keys = { 27 | key 28 | for key in traj["task"].keys() 29 | if key.startswith("image_") or key.startswith("depth_") 30 | } 31 | if not image_keys: 32 | return traj 33 | 34 | traj_len = tf.shape(traj["action"])[0] 35 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 36 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 37 | 38 | for key in image_keys | {"language_instruction"}: 39 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 40 | # pad out the key 41 | traj["task"][key] = tf.where( 42 | should_keep, 43 | traj["task"][key], 44 | to_padding(traj["task"][key]), 45 | ) 46 | # zero out the pad mask dict for the key 47 | traj["task"]["pad_mask_dict"][key] = tf.where( 48 | should_keep, 49 | traj["task"]["pad_mask_dict"][key], 50 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 51 | ) 52 | 53 | # when no goal images are present, the goal timestep becomes the final timestep 54 | traj["task"]["timestep"] = tf.where( 55 | should_keep_images, 56 | traj["task"]["timestep"], 57 | traj_len - 1, 58 | ) 59 | 60 | return traj 61 | -------------------------------------------------------------------------------- /deploy.py: -------------------------------------------------------------------------------- 1 | from real_controller.controller import SeerController 2 | import time 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | import torch 6 | 7 | def _6d_to_pose( 8 | pose6d, 9 | degrees=False 10 | ): 11 | pose = np.eye(4) 12 | pose[:3, 3] = pose6d[:3] 13 | pose[:3, :3] = R.from_euler("xyz", pose6d[3:6], degrees=degrees).as_matrix() 14 | 15 | return pose 16 | 17 | def pose_to_6d( 18 | pose, 19 | degrees=False 20 | ): 21 | pose6d = np.zeros(6) 22 | pose6d[:3] = pose[:3, 3] 23 | pose6d[3:6] = R.from_matrix(pose[:3, :3]).as_euler("xyz", degrees=degrees) 24 | 25 | return pose6d 26 | 27 | ### pseudocode example to deploy seer ### 28 | 29 | # Recommended control frequency: 15 Hz, same as DROID 30 | control_freq = 15 # Hz 31 | max_rel_pos = 0.02 # magic number, same as training 32 | max_rel_orn = 0.05 # magic number, same as training 33 | 34 | # set up the controller; `env` denotes your robot environment interface 35 | controller = SeerController() 36 | last2robot_pose = env.get_robot_state()["pose"] # absolute 4x4 pose matrix in robot space 37 | # warm up 38 | for i in range(3): 39 | obs = {} 40 | obs["robot_state"] = env.get_robot_state() 41 | obs["color_image"] = env.get_color_images() 42 | target_pos, target_euler, target_gripper, _ = controller.forward(obs, include_info=True) 43 | 44 | while True: 45 | # at each time step t 46 | torch.cuda.synchronize() 47 | t1 = time.time() 48 | 49 | obs["robot_state"] = env.get_robot_state() 50 | obs["color_image"] = env.get_color_images() 51 | target_pos, target_euler, target_gripper, _ = controller.forward(obs, include_info=True) 52 | 53 | # convert the predicted delta action into an absolute target pose 54 | target_pos *= max_rel_pos 55 | target_euler *= max_rel_orn 56 | cur2last_pose = _6d_to_pose(np.concatenate([target_pos, target_euler])) 57 | last2robot_pose = last2robot_pose @ cur2last_pose 58 | target_pose = pose_to_6d(last2robot_pose) 59 | 60 | torch.cuda.synchronize() 61 | t2 = time.time() 62 | sleep_left = 1.
/ control_freq - (t2 - t1) 63 | 64 | if sleep_left > 0: 65 | time.sleep(sleep_left) 66 | 67 | env.step(target_pose, target_gripper) 68 | 69 | ### pseudocode example to deploy seer ### 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/CALVIN_ABC-D_RUN.md: -------------------------------------------------------------------------------- 1 | # Running 2 | ## Notice 3 | 4 | For convenience, some checkpoints, such as the MAE-pretrained ViT-B model, are provided for manual download. Users must update the following paths accordingly. Relevant checkpoints can be acquired from the [website](https://drive.google.com/drive/folders/1F3IE95z2THAQ_lt3DKUFdRGc86Thsnc7?usp=sharing). 5 | * :exclamation: **pretrain.sh, finetune.sh, scratch.sh, eval.sh:** 6 | Please update the following: 7 | * **calvin_dataset_path** to the directory where you have stored the CALVIN ABC-D data. 8 | * **save_checkpoint_path** to the parent directory where your experiment checkpoints are saved. We recommend creating a ```checkpoints``` folder in the project root directory. 9 | * **finetune_from_pretrained_ckpt** to the location of your pre-trained checkpoint. 10 | * **resume_from_checkpoint** to the location of your fine-tuned checkpoint. 11 | * **vit_checkpoint_path** to the location of your ViT checkpoint (downloaded from the [website](https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing)). We recommend storing it at ```checkpoints/vit_mae/mae_pretrain_vit_base.pth```. 12 | 13 | * :exclamation: **networkx:** 14 | Due to compatibility issues between the networkx library in CALVIN and Python 3.10, we provide a compatible version of networkx.zip on the [website](https://drive.google.com/file/d/1z-d1SaI0rXfBtBicw1zPSsP-wE-26oLq/view?usp=sharing). Download and unzip it, then use it to replace the existing networkx library installed in your environment. 15 | 16 | ## Seer 17 | ### Pre-train 18 | ```bash 19 | # Pre-train Seer on the CALVIN ABC-D dataset 20 | bash scripts/CALVIN_ABC_D/Seer/pretrain.sh 21 | # Pre-train Seer-Large on the CALVIN ABC-D dataset 22 | bash scripts/CALVIN_ABC_D/Seer-Large/pretrain.sh 23 | ``` 24 | 25 | ### Fine-tune 26 | ```bash 27 | # Fine-tune Seer on the CALVIN ABC-D dataset 28 | bash scripts/CALVIN_ABC_D/Seer/finetune.sh 29 | # Fine-tune Seer-Large on the CALVIN ABC-D dataset 30 | bash scripts/CALVIN_ABC_D/Seer-Large/finetune.sh 31 | ``` 32 | 33 | ### Train from Scratch 34 | ```bash 35 | # Train Seer on the CALVIN ABC-D dataset from scratch 36 | bash scripts/CALVIN_ABC_D/Seer/scratch.sh 37 | # Train Seer-Large on the CALVIN ABC-D dataset from scratch 38 | bash scripts/CALVIN_ABC_D/Seer-Large/scratch.sh 39 | ``` 40 | 41 | ### Eval 42 | ```bash 43 | # Evaluate Seer on the CALVIN ABC-D benchmark 44 | bash scripts/CALVIN_ABC_D/Seer/eval.sh 45 | # Evaluate Seer-Large on the CALVIN ABC-D benchmark 46 | bash scripts/CALVIN_ABC_D/Seer-Large/eval.sh 47 | ``` 48 | 49 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/utils/text_processing.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Sequence 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" 8 | 9 | 10 | class TextProcessor(ABC): 11 | """ 12 | Base class for text tokenization or text embedding.
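Subclasses implement `encode`, which maps a batch of strings either to a dict of token arrays (HFTokenizer, CLIPTextProcessor) or to a dense embedding array (MuseEmbedding, or HFTokenizer with encode_with_model=True).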
13 | """ 14 | 15 | @abstractmethod 16 | def encode(self, strings: Sequence[str]): 17 | raise NotImplementedError 18 | 19 | 20 | class HFTokenizer(TextProcessor): 21 | def __init__( 22 | self, 23 | tokenizer_name: str, 24 | tokenizer_kwargs: Optional[dict] = { 25 | "max_length": 64, 26 | "padding": "max_length", 27 | "truncation": True, 28 | "return_tensors": "np", 29 | }, 30 | encode_with_model: bool = False, 31 | ): 32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 35 | self.tokenizer_kwargs = tokenizer_kwargs 36 | self.encode_with_model = encode_with_model 37 | if self.encode_with_model: 38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name) 39 | 40 | def encode(self, strings: Sequence[str]): 41 | # this creates another nested layer with "input_ids", "attention_mask", etc. 42 | inputs = self.tokenizer( 43 | strings, 44 | **self.tokenizer_kwargs, 45 | ) 46 | if self.encode_with_model: 47 | return np.array(self.model(**inputs).last_hidden_state) 48 | else: 49 | return dict(inputs) 50 | 51 | 52 | class MuseEmbedding(TextProcessor): 53 | def __init__(self): 54 | import tensorflow_hub as hub # lazy import 55 | import tensorflow_text # noqa: F401 56 | 57 | self.muse_model = hub.load(MULTI_MODULE) 58 | 59 | def encode(self, strings: Sequence[str]): 60 | with tf.device("/cpu:0"): 61 | return self.muse_model(strings).numpy() 62 | 63 | 64 | class CLIPTextProcessor(TextProcessor): 65 | def __init__( 66 | self, 67 | tokenizer_kwargs: Optional[dict] = { 68 | "max_length": 64, 69 | "padding": "max_length", 70 | "truncation": True, 71 | "return_tensors": "np", 72 | }, 73 | ): 74 | from transformers import CLIPProcessor 75 | 76 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 77 | self.kwargs = tokenizer_kwargs 78 | 79 | def encode(self, strings: Sequence[str]): 80 | inputs = self.processor( 81 | text=strings, 82 | **self.kwargs, 83 | ) 84 | inputs["position_ids"] = np.expand_dims( 85 | np.arange(inputs["input_ids"].shape[1]), axis=0 86 | ).repeat(inputs["input_ids"].shape[0], axis=0) 87 | return inputs 88 | -------------------------------------------------------------------------------- /data_info/austin_buds_dataset_converted_externally_to_rlds.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dataset_name": "austin_buds_dataset_converted_externally_to_rlds", 4 | "wrist_image": "Flip vertically & horizontally", 5 | "s_ratio": 1.0, 6 | "accumulated_num_steps": 34112 7 | }, 8 | [ 9 | "episodes/000000", 10 | 640 11 | ], 12 | [ 13 | "episodes/000001", 14 | 827 15 | ], 16 | [ 17 | "episodes/000002", 18 | 639 19 | ], 20 | [ 21 | "episodes/000003", 22 | 694 23 | ], 24 | [ 25 | "episodes/000004", 26 | 786 27 | ], 28 | [ 29 | "episodes/000005", 30 | 549 31 | ], 32 | [ 33 | "episodes/000006", 34 | 668 35 | ], 36 | [ 37 | "episodes/000007", 38 | 623 39 | ], 40 | [ 41 | "episodes/000008", 42 | 741 43 | ], 44 | [ 45 | "episodes/000009", 46 | 671 47 | ], 48 | [ 49 | "episodes/000010", 50 | 624 51 | ], 52 | [ 53 | "episodes/000011", 54 | 722 55 | ], 56 | [ 57 | "episodes/000012", 58 | 602 59 | ], 60 | [ 61 | "episodes/000013", 62 | 645 63 | ], 64 | [ 65 | "episodes/000014", 66 | 640 67 | ], 68 | [ 69 | "episodes/000015", 70 | 630 71 | ], 72 | [ 73 | "episodes/000016", 74 | 603 75 | ], 76 | [ 77 | "episodes/000017", 78 | 788 79 | ], 80 | [ 81 | "episodes/000018", 82 | 727 83 | ], 84 | [ 85 | "episodes/000019", 86 | 758 87 | ], 88 
| [ 89 | "episodes/000020", 90 | 708 91 | ], 92 | [ 93 | "episodes/000021", 94 | 670 95 | ], 96 | [ 97 | "episodes/000022", 98 | 550 99 | ], 100 | [ 101 | "episodes/000023", 102 | 652 103 | ], 104 | [ 105 | "episodes/000024", 106 | 688 107 | ], 108 | [ 109 | "episodes/000025", 110 | 634 111 | ], 112 | [ 113 | "episodes/000026", 114 | 746 115 | ], 116 | [ 117 | "episodes/000027", 118 | 643 119 | ], 120 | [ 121 | "episodes/000028", 122 | 675 123 | ], 124 | [ 125 | "episodes/000029", 126 | 734 127 | ], 128 | [ 129 | "episodes/000030", 130 | 668 131 | ], 132 | [ 133 | "episodes/000031", 134 | 698 135 | ], 136 | [ 137 | "episodes/000032", 138 | 669 139 | ], 140 | [ 141 | "episodes/000033", 142 | 623 143 | ], 144 | [ 145 | "episodes/000034", 146 | 694 147 | ], 148 | [ 149 | "episodes/000035", 150 | 731 151 | ], 152 | [ 153 | "episodes/000036", 154 | 575 155 | ], 156 | [ 157 | "episodes/000037", 158 | 636 159 | ], 160 | [ 161 | "episodes/000038", 162 | 637 163 | ], 164 | [ 165 | "episodes/000039", 166 | 713 167 | ], 168 | [ 169 | "episodes/000040", 170 | 750 171 | ], 172 | [ 173 | "episodes/000041", 174 | 751 175 | ], 176 | [ 177 | "episodes/000042", 178 | 683 179 | ], 180 | [ 181 | "episodes/000043", 182 | 649 183 | ], 184 | [ 185 | "episodes/000044", 186 | 766 187 | ], 188 | [ 189 | "episodes/000045", 190 | 676 191 | ], 192 | [ 193 | "episodes/000046", 194 | 671 195 | ], 196 | [ 197 | "episodes/000047", 198 | 794 199 | ], 200 | [ 201 | "episodes/000048", 202 | 737 203 | ], 204 | [ 205 | "episodes/000049", 206 | 714 207 | ] 208 | ] -------------------------------------------------------------------------------- /docs/REAL-WORLD_PREPROCESS.md: -------------------------------------------------------------------------------- 1 | # Pre-process 2 | Seer exclusively utilizes the [DROID](https://droid-dataset.github.io/) dataset for pre-training. In this section, we describe the data pre-processing and transformation steps for both the [DROID](https://droid-dataset.github.io/) and [OXE](https://robotics-transformer-x.github.io/) datasets. These transformations convert the RLDS format into a standard dataset format, including .png, .npz, and .h5 files. The transformed dataset is organized as follows: /subset_name/episodes/000000/steps/0000/xxx.jpg (h5). 3 | The pre-processing step also unifies action labels across different subsets. For example, it standardizes all control methods to use the delta end-effector pose control, ensuring consistency in the robot's base and end-effector origin and axes. This carefully designed alignment process minimizes confusion caused by different robots and control methods. 4 | To facilitate this process, we create a new environment, seer_pre, which is specifically used for pre-processing the DROID and OXE datasets into our desired format. 5 | 6 | 7 | ## seer_pre env 8 | **(1) Env** 9 | ```python 10 | conda create -n seer_pre python=3.10 11 | conda activate seer_pre 12 | ``` 13 | **(2) Move to real_preprocess** 14 | ```python 15 | cd ${YOUR_PATH_TO_SEER}/real_preprocess 16 | ``` 17 | **(3) Third Party Packages** 18 | ```python 19 | pip install -r requirements.txt 20 | ``` 21 | **(4) octo_oxe_data_utils (Optional for DROID, Required for OXE)** 22 | ```python 23 | cd octo_oxe_data_utils 24 | python setup.py install 25 | cd .. 
26 | ``` 27 | **(5) Mujoco** 28 | ```python 29 | pip install mujoco 30 | ``` 31 | **(6) Torch** 32 | ```python 33 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 34 | ``` 35 | **(7) Dlimp (Important):** 36 | We try to use multiprocess to process data. However, the dataset.py in Dlimp introduce randomness. 37 | replace the dataset.py in /your_anaconda/envs/seer_pre/lib/python3.10/site-packages/dlimp/dataset.py with the one in [dlimp/dataset.py](../real_preprocess/dlimp/dataset.py) 38 | 39 | ## Run Instructions 40 | You can download the full DROID dataset (1.7TB) in RLDS format using the following command: 41 | ```python 42 | gsutil -m cp -r gs://gresearch/robotics/droid 43 | ``` 44 | If needed, follow the download instructions provided on the [OXE Github page](https://github.com/google-deepmind/open_x_embodiment). 45 | 46 | Preparation 47 | ```python 48 | cd ${YOUR_PATH_TO_SEER}/real_preprocess 49 | conda activate seer_pre 50 | ``` 51 | To process the DROID dataset, set the src_dir and tgt_dir paths. You can adjust the num_worker argument to specify the number of processes to use: 52 | ```python 53 | python convert_public_droid_to_h5_per_step.py 54 | ``` 55 | For processing the Franka subsets in the OXE dataset, update the src_root_dir and tgt_dataset_dir paths. Similarly, adjust the num_worker argument for parallel processing: 56 | ```python 57 | python convert_tfds_to_h5_per_step_oxe_franka.py 58 | ``` 59 | To process other subsets of the OXE dataset (excluding Franka), update the src_root_dir and tgt_dataset_dir paths and set the number of worker processes: 60 | ```python 61 | python convert_tfds_to_h5_per_step_oxe_others.py 62 | ``` 63 | -------------------------------------------------------------------------------- /data_info/.hydra/config.yaml: -------------------------------------------------------------------------------- 1 | cameras: 2 | static: 3 | _target_: vr_env.camera.static_camera.StaticCamera 4 | name: static 5 | fov: 10 6 | aspect: 1 7 | nearval: 0.01 8 | farval: 10 9 | width: 200 10 | height: 200 11 | look_at: 12 | - -0.026242351159453392 13 | - -0.0302329882979393 14 | - 0.3920000493526459 15 | look_from: 16 | - 2.871459009488717 17 | - -2.166602199425597 18 | - 2.555159848480571 19 | up_vector: 20 | - 0.4041403970338857 21 | - 0.22629790978217404 22 | - 0.8862616969685161 23 | gripper: 24 | _target_: vr_env.camera.gripper_camera.GripperCamera 25 | name: gripper 26 | fov: 75 27 | aspect: 1 28 | nearval: 0.01 29 | farval: 2 30 | width: 84 31 | height: 84 32 | tactile: 33 | _target_: vr_env.camera.tactile_sensor.TactileSensor 34 | name: tactile 35 | width: 120 36 | height: 160 37 | digit_link_ids: 38 | - 10 39 | - 12 40 | visualize_gui: false 41 | config_path: conf/digit_sensor/config_digit.yml 42 | load_dir: /work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/ 43 | data_path: data 44 | save_dir: ??? 
45 | show_gui: false 46 | processes: 16 47 | max_episode_frames: 1 48 | save_body_infos: true 49 | set_static_cam: false 50 | env: 51 | cameras: ${cameras} 52 | show_gui: ${show_gui} 53 | use_vr: false 54 | scene: 55 | _target_: vr_env.scene.play_table_scene.PlayTableScene 56 | _recursive_: false 57 | name: calvin_scene_A 58 | data_path: ${data_path} 59 | global_scaling: 0.8 60 | euler_obs: ${robot.euler_obs} 61 | robot_base_position: 62 | - -0.34 63 | - -0.46 64 | - 0.24 65 | robot_base_orientation: 66 | - 0 67 | - 0 68 | - 0 69 | robot_initial_joint_positions: 70 | - -1.21779206 71 | - 1.03987646 72 | - 2.11978261 73 | - -2.34205014 74 | - -0.87015947 75 | - 1.64119353 76 | - 0.55344866 77 | surfaces: 78 | table: 79 | - - 0.0 80 | - -0.15 81 | - 0.46 82 | - - 0.35 83 | - -0.03 84 | - 0.46 85 | slider_left: 86 | - - -0.32 87 | - 0.05 88 | - 0.46 89 | - - -0.16 90 | - 0.12 91 | - 0.46 92 | slider_right: 93 | - - -0.05 94 | - 0.05 95 | - 0.46 96 | - - 0.13 97 | - 0.12 98 | - 0.46 99 | objects: 100 | fixed_objects: 101 | table: 102 | file: calvin_table_A/urdf/calvin_table_A.urdf 103 | initial_pos: 104 | - 0 105 | - 0 106 | - 0 107 | initial_orn: 108 | - 0 109 | - 0 110 | - 0 111 | joints: 112 | base__slide: 113 | initial_state: 0 114 | base__drawer: 115 | initial_state: 0 116 | buttons: 117 | base__button: 118 | initial_state: 0 119 | effect: led 120 | switches: 121 | base__switch: 122 | initial_state: 0 123 | effect: lightbulb 124 | lights: 125 | lightbulb: 126 | link: light_link 127 | color: 128 | - 1 129 | - 1 130 | - 0 131 | - 1 132 | led: 133 | link: led_link 134 | color: 135 | - 0 136 | - 1 137 | - 0 138 | - 1 139 | movable_objects: 140 | block_red: 141 | file: blocks/block_red_middle.urdf 142 | initial_pos: any 143 | initial_orn: any 144 | block_blue: 145 | file: blocks/block_blue_small.urdf 146 | initial_pos: any 147 | initial_orn: any 148 | block_pink: 149 | file: blocks/block_pink_big.urdf 150 | initial_pos: any 151 | initial_orn: any 152 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # workspace 2 | calvin 3 | checkpoints 4 | eval_logs 5 | evaluate 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ -------------------------------------------------------------------------------- /eval_libero.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import wandb 7 | import clip 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from utils.distributed_utils import init_distributed_device, world_info_from_env 10 | 11 | from utils.eval_utils_libero import eval_one_epoch_libero_ddp 12 | 13 | # try: 14 | # from utils.eval_utils_libero import eval_one_epoch_libero_ddp as eval_one_epoch_calvin_ddp 15 | # except: 16 | # pass 17 | # from utils.eval_utils_libero import eval_one_epoch_libero_ddp as eval_one_epoch_calvin_ddp 18 | from torch.distributed.elastic.multiprocessing.errors import record 19 | # from utils.arguments_utils import get_args_and_cfg 20 | from utils.arguments_utils import get_parser 21 | from pdb import set_trace 22 | from models.seer_model import SeerAgent 23 | 24 | def random_seed(seed=42, rank=0): 25 | torch.manual_seed(seed + rank) 26 | np.random.seed(seed + rank) 27 | random.seed(seed + rank) 28 | 29 | @record 30 | def main(): 31 | parser = get_parser(is_eval=True) 32 | args = parser.parse_args() 33 | args.local_rank, args.rank, args.world_size = world_info_from_env() 34 | device_id = init_distributed_device(args) 35 | print("device_id: ", device_id) 36 | random_seed(args.seed) 37 | 38 | # model 39 | model = SeerAgent( 40 | finetune_type=args.finetune_type, 41 | clip_device=device_id, 42 | vit_checkpoint_path=args.vit_checkpoint_path, 43 | sequence_length=args.sequence_length, 44 | num_resampler_query=args.num_resampler_query, 45 | num_obs_token_per_image=args.num_obs_token_per_image, 46 | calvin_input_image_size=args.calvin_input_image_size, 47 | patch_size=args.patch_size, 48 | action_pred_steps=args.action_pred_steps, 49 | obs_pred=args.obs_pred, 50 | atten_only_obs=args.atten_only_obs, 51 | attn_robot_proprio_state=args.attn_robot_proprio_state, 52 | atten_goal=args.atten_goal, 53 | atten_goal_state=args.atten_goal_state, 54 | mask_l_obs_ratio=args.mask_l_obs_ratio, 55 | transformer_layers=args.transformer_layers, 56 | hidden_dim=args.hidden_dim, 57 | transformer_heads=args.transformer_heads, 58 | phase=args.phase, 59 | gripper_width=args.gripper_width, 60 | ) 61 | 62 | random_seed(args.seed, args.rank) 63 | print(f"Start running training on rank {args.rank}.") 64 | 65 | device_id = args.rank % torch.cuda.device_count() 66 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16": 67 | model = model.bfloat16() 68 | elif args.precision == "fp16": 69 | model = model.half() 70 | elif args.precision == "fp32": 71 | model = model.float() 72 | if 'vision_encoder' in args.bf16_module: 73 | model.vision_encoder.bfloat16() 74 | if "causal_transformer" in args.bf16_module: 75 | model.transformer_backbone.bfloat16() 76 | if "image_decoder" in args.bf16_module: 77 | model.image_decoder.bfloat16() 78 | model.image_decoder_obs_pred_projector.bfloat16() 79 | 80 | model.clip_model.requires_grad_(False) 81 | model.vision_encoder.requires_grad_(False) 82 | model = model.to(device_id) 83 | model._init_model_type() 84 | 85 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True) 86 | 87 | if args.resume_from_checkpoint is not None: 88 | if args.rank == 0: 89 | print(f"Loading checkpoint from {args.resume_from_checkpoint}") 90 | checkpoint = torch.load(args.resume_from_checkpoint, 
map_location="cpu") 91 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False) 92 | 93 | ddp_model.eval() 94 | 95 | eval_log_dir = 'evaluate' 96 | 97 | if args.finetune_type == "libero_10": 98 | eval_one_epoch_libero_ddp( 99 | args=args, 100 | model=ddp_model, 101 | image_processor=model.image_processor, 102 | tokenizer=clip, 103 | ) 104 | else: 105 | raise NotImplementedError 106 | 107 | if __name__ == "__main__": 108 | os.environ["NCCL_BLOCKING_WAIT"] = "0" 109 | main() 110 | -------------------------------------------------------------------------------- /models/perceiver_resampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from einops_exts import rearrange_many 4 | from torch import einsum, nn 5 | 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | 11 | def FeedForward(dim, mult=4): 12 | inner_dim = int(dim * mult) 13 | return nn.Sequential( 14 | nn.LayerNorm(dim), 15 | nn.Linear(dim, inner_dim, bias=False), 16 | nn.GELU(), 17 | nn.Linear(inner_dim, dim, bias=False), 18 | ) 19 | 20 | 21 | class PerceiverAttention(nn.Module): 22 | def __init__(self, *, dim, dim_head=64, heads=8): 23 | super().__init__() 24 | self.scale = dim_head**-0.5 25 | self.heads = heads 26 | inner_dim = dim_head * heads 27 | 28 | self.norm_media = nn.LayerNorm(dim) 29 | self.norm_latents = nn.LayerNorm(dim) 30 | 31 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 32 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 33 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 34 | 35 | def forward(self, x, latents): 36 | """ 37 | Args: 38 | x (torch.Tensor): image features 39 | shape (b, T, n1, D) 40 | latent (torch.Tensor): latent features 41 | shape (b, T, n2, D) 42 | """ 43 | x = self.norm_media(x) 44 | latents = self.norm_latents(latents) 45 | 46 | h = self.heads 47 | 48 | q = self.to_q(latents) 49 | kv_input = torch.cat((x, latents), dim=-2) 50 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 51 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 52 | q = q * self.scale 53 | 54 | # attention 55 | sim = einsum("... i d, ... j d -> ... i j", q, k) 56 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 57 | attn = sim.softmax(dim=-1) 58 | 59 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 60 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 61 | return self.to_out(out) 62 | 63 | 64 | class PerceiverResampler(nn.Module): 65 | def __init__( 66 | self, 67 | *, 68 | dim, 69 | depth=6, 70 | dim_head=64, 71 | heads=8, 72 | num_latents=64, 73 | max_num_media=None, 74 | max_num_frames=None, 75 | ff_mult=4, 76 | ): 77 | super().__init__() 78 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 79 | self.frame_embs = ( 80 | nn.Parameter(torch.randn(max_num_frames, dim)) 81 | if exists(max_num_frames) 82 | else None 83 | ) 84 | self.media_time_embs = ( 85 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 86 | if exists(max_num_media) 87 | else None 88 | ) 89 | 90 | self.layers = nn.ModuleList([]) 91 | for _ in range(depth): 92 | self.layers.append( 93 | nn.ModuleList( 94 | [ 95 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 96 | FeedForward(dim=dim, mult=ff_mult), 97 | ] 98 | ) 99 | ) 100 | 101 | self.norm = nn.LayerNorm(dim) 102 | 103 | def forward(self, x): 104 | """ 105 | Args: 106 | x (torch.Tensor): image features 107 | shape (b, T, F, v, D) 108 | Returns: 109 | shape (b, T, n, D) where n is self.num_latents 110 | """ 111 | b, T, F, v = x.shape[:4] 112 | 113 | # frame and media time embeddings 114 | if exists(self.frame_embs): 115 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 116 | x = x + frame_embs 117 | x = rearrange( 118 | x, "b T F v d -> b T (F v) d" 119 | ) # flatten the frame and spatial dimensions 120 | if exists(self.media_time_embs): 121 | x = x + self.media_time_embs[:T] 122 | 123 | # blocks 124 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 125 | for attn, ff in self.layers: 126 | latents = attn(x, latents) + latents 127 | latents = ff(latents) + latents 128 | return self.norm(latents) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Predictive Inverse Dynamics Models are Scalable Learners for Robotic Manipulation 4 | 5 | 6 | 7 | [Arxiv](https://arxiv.org/abs/2412.15109) | 8 | Webpage 9 | 10 | 11 | 12 | https://github.com/user-attachments/assets/49036e84-c397-4589-9024-efb05b14efa0 13 | 14 | 15 |
16 | 17 | ## :books: Table of Contents 18 | 1. [Highlights](#high) 19 | 2. [Getting Started](#start) 20 | - [Simulation](#simulation) 21 | - [Real-World](#real-world) 22 | 3. [Checkpoints](#checkpoints) 23 | 4. [TODO List](#todos) 24 | 5. [License](#license) 25 | 6. [Citation](#citation) 26 | 7. [Acknowledgment](#acknowledgment) 27 | 28 | ## :fire: Highlights 29 | seer 30 | 31 | - :trophy: **SOTA simulation performance:** Seer achieves state-of-the-art performance on the CALVIN ABC-D and LIBERO-LONG simulation benchmarks. 32 | - :muscle: **Impressive real-world performance:** Seer demonstrates strong effectiveness and generalization across diverse real-world downstream tasks. 33 | 34 | ## :door: Getting Started 35 | We provide step-by-step guidance for running Seer in simulations and real-world experiments. 36 | Follow the specific instructions for a seamless setup. 37 | 38 | ### Simulation 39 | #### CALVIN ABC-D 40 | - [Installation](docs/CALVIN_ABC-D_INSTALL.md) 41 | - [Running Code](docs/CALVIN_ABC-D_RUN.md) 42 | #### LIBERO LONG 43 | - [Installation](docs/LIBERO_LONG_INSTALL.md) 44 | - [Running Code](docs/LIBERO_LONG_RUN.md) 45 | ### Real-World 46 | #### Real-World (Quick Training w & w/o pre-training) 47 | For users aiming to train Seer from scratch or fine-tune it, we provide comprehensive instructions for environment setup, downstream task data preparation, training, and deployment. 48 | - [Installation](docs/REAL-WORLD_INSTALL.md) 49 | - [Post-processing](docs/REAL-WORLD_POSTPROCESS.md) 50 | - [Fine-tuning & Scratch](docs/REAL-WORLD_FT_SC.md) 51 | - [Inference](docs/REAL-WORLD_INFERENCE.md) 52 | 53 | #### Real-World (Pre-training) 54 | This section details the pre-training process of Seer in real-world experiments, including environment setup, dataset preparation, and training procedures. Downstream task processing and fine-tuning are covered in [Real-World (Quick Training w & w/o pre-training)](#real-world-qs). 55 | - [Installation](docs/REAL-WORLD_INSTALL.md) 56 | - [Pre-processing](docs/REAL-WORLD_PREPROCESS.md) 57 | - [Pre-training](docs/REAL-WORLD_PRETRAIN.md) 58 | 59 | 60 | ## :pencil2: Checkpoints 61 | Relevant checkpoints are available on the [website](https://drive.google.com/drive/folders/1F3IE95z2THAQ_lt3DKUFdRGc86Thsnc7?usp=sharing). 62 | |Model|Checkpoint| 63 | |:------:|:------:| 64 | |CALVIN ABC-D|[Seer](https://drive.google.com/drive/folders/17Gv9snGCkViuhHmzN3eTWlI0tMfGSGT3?usp=sharing) (Avg.Len. : 3.98) / [Seer Large](https://drive.google.com/drive/folders/1AFabqfDEi69oMo0FTGhEiH2QSRLYBR9r?usp=drive_link) (Avg.Len. : 4.30)| 65 | |Real-World|[Seer (Droid Pre-trained)](https://drive.google.com/drive/folders/1rT8JKLhJGIo97jfYUm2JiFUrogOq-dgJ?usp=drive_link)| 66 | 67 | ## 📆 TODO 68 | - [x] Release real-world experiment code. 69 | - [x] Release CALVIN ABC-D experiment code (Seer). 70 | - [x] Release the evaluation code of Seer-Large on the CALVIN ABC-D experiment. 71 | - [x] Release the training code of Seer-Large on the CALVIN ABC-D experiment. 72 | - [x] Release LIBERO-LONG experiment code. 73 | - [ ] Release simpleseer, a streamlined codebase for quick from-scratch training and deployment. 74 | 75 | ## License 76 | 77 | All assets and code are under the [Apache 2.0 license](./LICENSE) unless specified otherwise.
78 | 79 | ## Citation 80 | If you find the project helpful for your research, please consider citing our paper: 81 | ```bibtex 82 | @article{tian2024predictive, 83 | title={Predictive Inverse Dynamics Models are Scalable Learners for Robotic Manipulation}, 84 | author={Tian, Yang and Yang, Sizhe and Zeng, Jia and Wang, Ping and Lin, Dahua and Dong, Hao and Pang, Jiangmiao}, 85 | journal={arXiv preprint arXiv:2412.15109}, 86 | year={2024} 87 | } 88 | ``` 89 | 90 | ## Acknowledgment 91 | This project builds upon [GR-1](https://github.com/bytedance/GR-1) and [Roboflamingo](https://github.com/RoboFlamingo/RoboFlamingo). We thank these teams for their open-source contributions. 92 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/obs_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains observation-level transforms used in the octo data pipeline. These transforms operate on the 3 | "observation" dictionary, and are applied at a per-frame level. 4 | """ 5 | from typing import Mapping, Tuple, Union 6 | 7 | from absl import logging 8 | import dlimp as dl 9 | import tensorflow as tf 10 | 11 | 12 | def augment( 13 | obs: dict, seed: tf.Tensor, augment_kwargs: Union[dict, Mapping[str, dict]] 14 | ) -> dict: 15 | """Augments images, skipping padding images.""" 16 | image_names = {key[6:] for key in obs if key.startswith("image_")} 17 | 18 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed 19 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image 20 | # name to augmentation dict) 21 | if "augment_order" in augment_kwargs: 22 | augment_kwargs = {name: augment_kwargs for name in image_names} 23 | 24 | for i, name in enumerate(image_names): 25 | if name not in augment_kwargs: 26 | continue 27 | kwargs = augment_kwargs[name] 28 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") 29 | obs[f"image_{name}"] = tf.cond( 30 | obs["pad_mask_dict"][f"image_{name}"], 31 | lambda: dl.transforms.augment_image( 32 | obs[f"image_{name}"], 33 | **kwargs, 34 | seed=seed + i, # augment each image differently 35 | ), 36 | lambda: obs[f"image_{name}"], # skip padding images 37 | ) 38 | 39 | return obs 40 | 41 | 42 | def decode_and_resize( 43 | obs: dict, 44 | resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], 45 | depth_resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], 46 | ) -> dict: 47 | """Decodes images and depth images, and then optionally resizes them.""" 48 | # just gets the part after "image_" or "depth_" 49 | image_names = {key[6:] for key in obs if key.startswith("image_")} 50 | depth_names = {key[6:] for key in obs if key.startswith("depth_")} 51 | 52 | if isinstance(resize_size, tuple): 53 | resize_size = {name: resize_size for name in image_names} 54 | if isinstance(depth_resize_size, tuple): 55 | depth_resize_size = {name: depth_resize_size for name in depth_names} 56 | 57 | for name in image_names: 58 | if name not in resize_size: 59 | logging.warning( 60 | f"No resize_size was provided for image_{name}. This will result in 1x1 " 61 | "padding images, which may cause errors if you mix padding and non-padding images." 
62 | ) 63 | image = obs[f"image_{name}"] 64 | if image.dtype == tf.string: 65 | if tf.strings.length(image) == 0: 66 | # this is a padding image 67 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) 68 | else: 69 | image = tf.io.decode_image( 70 | image, expand_animations=False, dtype=tf.uint8 71 | ) 72 | elif image.dtype != tf.uint8: 73 | raise ValueError( 74 | f"Unsupported image dtype: found image_{name} with dtype {image.dtype}" 75 | ) 76 | if name in resize_size: 77 | image = dl.transforms.resize_image(image, size=resize_size[name]) 78 | obs[f"image_{name}"] = image 79 | 80 | for name in depth_names: 81 | if name not in depth_resize_size: 82 | logging.warning( 83 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " 84 | "padding depth images, which may cause errors if you mix padding and non-padding images." 85 | ) 86 | depth = obs[f"depth_{name}"] 87 | if depth.dtype == tf.string: 88 | if tf.strings.length(depth) == 0: 89 | # this is a padding image 90 | depth = tf.zeros( 91 | (*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32 92 | ) 93 | else: 94 | depth = tf.io.decode_image( 95 | depth, expand_animations=False, dtype=tf.float32 96 | )[..., 0] 97 | elif depth.dtype != tf.float32: 98 | raise ValueError( 99 | f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}" 100 | ) 101 | if name in depth_resize_size: 102 | depth = dl.transforms.resize_depth_image( 103 | depth, size=depth_resize_size[name] 104 | ) 105 | obs[f"depth_{name}"] = depth 106 | 107 | return obs 108 | -------------------------------------------------------------------------------- /eval_calvin.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import wandb 7 | import clip 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.distributed.elastic.multiprocessing.errors import record 10 | from models.seer_model import SeerAgent 11 | from utils.data_utils import get_calvin_dataset, get_calvin_val_dataset 12 | from utils.distributed_utils import init_distributed_device, world_info_from_env 13 | from utils.eval_utils_calvin import eval_one_epoch_calvin_ddp 14 | from utils.arguments_utils import get_parser 15 | 16 | 17 | def random_seed(seed=42, rank=0): 18 | torch.manual_seed(seed + rank) 19 | np.random.seed(seed + rank) 20 | random.seed(seed + rank) 21 | 22 | @record 23 | def main(): 24 | parser = get_parser(is_eval=True) 25 | args = parser.parse_args() 26 | if args.save_checkpoints_to_wandb and args.save_checkpoint and not args.report_to_wandb: 27 | raise ValueError("save_checkpoints_to_wandb requires report_to_wandb") 28 | if args.offline: 29 | os.environ["WANDB_MODE"] = "offline" 30 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 31 | args.local_rank, args.rank, args.world_size = world_info_from_env() 32 | device_id = init_distributed_device(args) 33 | print("device_id: ", device_id) 34 | random_seed(args.seed) 35 | model = SeerAgent( 36 | finetune_type=args.finetune_type, 37 | clip_device=device_id, 38 | vit_checkpoint_path=args.vit_checkpoint_path, 39 | sequence_length=args.sequence_length, 40 | num_resampler_query=args.num_resampler_query, 41 | num_obs_token_per_image=args.num_obs_token_per_image, 42 | calvin_input_image_size=args.calvin_input_image_size, 43 | patch_size=args.patch_size, 44 | action_pred_steps=args.action_pred_steps, 45 | obs_pred=args.obs_pred, 46 | atten_only_obs=args.atten_only_obs, 47 | 
attn_robot_proprio_state=args.attn_robot_proprio_state, 48 | atten_goal=args.atten_goal, 49 | atten_goal_state=args.atten_goal_state, 50 | mask_l_obs_ratio=args.mask_l_obs_ratio, 51 | transformer_layers=args.transformer_layers, 52 | hidden_dim=args.hidden_dim, 53 | transformer_heads=args.transformer_heads, 54 | phase=args.phase, 55 | gripper_width=args.gripper_width, 56 | ) 57 | calvin_dataset = get_calvin_dataset(args, model.image_processor, clip, epoch=0) 58 | random_seed(args.seed, args.rank) 59 | print(f"Start running training on rank {args.rank}.") 60 | if args.rank == 0 and args.report_to_wandb: 61 | wandb.init( 62 | project=args.wandb_project, 63 | entity=args.wandb_entity, 64 | name=args.run_name, 65 | config=vars(args), 66 | ) 67 | device_id = args.rank % torch.cuda.device_count() 68 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16": 69 | model = model.bfloat16() 70 | elif args.precision == "fp16": 71 | model = model.half() 72 | elif args.precision == "fp32": 73 | model = model.float() 74 | if 'vision_encoder' in args.bf16_module: 75 | model.vision_encoder.bfloat16() 76 | if "causal_transformer" in args.bf16_module: 77 | model.transformer_backbone.bfloat16() 78 | if "image_decoder" in args.bf16_module: 79 | model.image_decoder.bfloat16() 80 | model.image_decoder_obs_pred_projector.bfloat16() 81 | model.clip_model.requires_grad_(False) 82 | model.vision_encoder.requires_grad_(False) 83 | model = model.to(device_id) 84 | model._init_model_type() 85 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True) 86 | if args.resume_from_checkpoint is not None: 87 | if args.rank == 0: 88 | print(f"Loading checkpoint from {args.resume_from_checkpoint}") 89 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") 90 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False) 91 | ddp_model.eval() 92 | eval_log_dir = 'evaluate' 93 | if args.finetune_type == "calvin": 94 | eval_one_epoch_calvin_ddp( 95 | args=args, 96 | model=ddp_model, 97 | image_processor=model.image_processor, 98 | tokenizer=clip, 99 | dataset_path=args.calvin_dataset, 100 | future_act_len=args.future_act_len, 101 | eval_log_dir=eval_log_dir, 102 | debug=args.visualize, 103 | reset=args.reset, 104 | diverse_inst=args.diverse_inst 105 | ) 106 | else: 107 | raise NotImplementedError 108 | 109 | if __name__ == "__main__": 110 | os.environ["NCCL_BLOCKING_WAIT"] = "0" 111 | main() 112 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/traj_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains trajectory transforms used in the octo data pipeline. Trajectory transforms operate on a dictionary 3 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory 4 | length). 5 | """ 6 | import logging 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def chunk_act_obs( 12 | traj: dict, 13 | window_size: int, 14 | future_action_window_size: int = 0, 15 | ) -> dict: 16 | """Chunks actions and observations into the given window_size. 17 | 18 | "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` 19 | observations from the past and the current observation. 
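For example (illustrative): with window_size=3, the observation chunk gathered at timestep t=2 covers timesteps [0, 1, 2]; at t=0 the out-of-range indices [-2, -1, 0] are floored to [0, 0, 0] and the corresponding pad_mask is [False, False, True].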
"action" is given a new axis (at index 1) of size 20 | `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current 21 | action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and 22 | indicates whether an observation should be considered padding (i.e. if it would have come from a timestep 23 | before the start of the trajectory). 24 | """ 25 | traj_len = tf.shape(traj["action"])[0] 26 | action_dim = traj["action"].shape[-1] 27 | chunk_indices = tf.broadcast_to( 28 | tf.range(-window_size + 1, 1), [traj_len, window_size] 29 | ) + tf.broadcast_to(tf.range(traj_len)[:, None], [traj_len, window_size]) 30 | 31 | action_chunk_indices = tf.broadcast_to( 32 | tf.range(-window_size + 1, 1 + future_action_window_size), 33 | [traj_len, window_size + future_action_window_size], 34 | ) + tf.broadcast_to( 35 | tf.range(traj_len)[:, None], 36 | [traj_len, window_size + future_action_window_size], 37 | ) 38 | 39 | floored_chunk_indices = tf.maximum(chunk_indices, 0) 40 | 41 | if "timestep" in traj["task"]: 42 | goal_timestep = traj["task"]["timestep"] 43 | else: 44 | goal_timestep = tf.fill([traj_len], traj_len - 1) 45 | 46 | floored_action_chunk_indices = tf.minimum( 47 | tf.maximum(action_chunk_indices, 0), goal_timestep[:, None] 48 | ) 49 | 50 | traj["observation"] = tf.nest.map_structure( 51 | lambda x: tf.gather(x, floored_chunk_indices), traj["observation"] 52 | ) 53 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) 54 | 55 | # indicates whether an entire observation is padding 56 | traj["observation"]["pad_mask"] = chunk_indices >= 0 57 | 58 | # if no absolute_action_mask was provided, assume all actions are relative 59 | if "absolute_action_mask" not in traj and future_action_window_size > 0: 60 | logging.warning( 61 | "future_action_window_size > 0 but no absolute_action_mask was provided. " 62 | "Assuming all actions are relative for the purpose of making neutral actions." 63 | ) 64 | absolute_action_mask = traj.get( 65 | "absolute_action_mask", tf.zeros([traj_len, action_dim], dtype=tf.bool) 66 | ) 67 | neutral_actions = tf.where( 68 | absolute_action_mask[:, None, :], 69 | traj["action"], # absolute actions are repeated (already done during chunking) 70 | tf.zeros_like(traj["action"]), # relative actions are zeroed 71 | ) 72 | 73 | # actions past the goal timestep become neutral 74 | action_past_goal = action_chunk_indices > goal_timestep[:, None] 75 | traj["action"] = tf.where( 76 | action_past_goal[:, :, None], neutral_actions, traj["action"] 77 | ) 78 | return traj 79 | 80 | 81 | def subsample(traj: dict, subsample_length: int) -> dict: 82 | """Subsamples trajectories to the given length.""" 83 | traj_len = tf.shape(traj["action"])[0] 84 | if traj_len > subsample_length: 85 | indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] 86 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) 87 | return traj 88 | 89 | 90 | def add_pad_mask_dict(traj: dict) -> dict: 91 | """Adds a dictionary indicating which elements of the observation/task should be treated as padding. 
92 | 93 | traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} 94 | """ 95 | traj_len = tf.shape(traj["action"])[0] 96 | for key in ["observation", "task"]: 97 | pad_mask_dict = {} 98 | for subkey in traj[key]: 99 | if traj[key][subkey].dtype == tf.string: 100 | # handles "language_instruction", "image_*", and "depth_*" 101 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 102 | else: 103 | # all other keys should not be treated as padding 104 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) 105 | traj[key]["pad_mask_dict"] = pad_mask_dict 106 | return traj 107 | -------------------------------------------------------------------------------- /docs/REAL-WORLD_POSTPROCESS.md: -------------------------------------------------------------------------------- 1 | # Post-process 2 | Self-collected data and pre-training datasets often exhibit certain challenges that can negatively affect model performance. Below, we outline these issues and provide standardized solutions for data formatting and post-processing. 3 | * :warning: **Varied Action Labels:** 4 | Different embodiments, and sometimes even identical ones, may use diverse action labels such as: 5 | * **Absolute target joint positions** (qpos) 6 | * **Absolute target joint velocities** (qvel) 7 | * **Absolute end-effector poses** (ee-pose) 8 | * **Delta target end-effector poses** (delta ee-pose) 9 | Furthermore, rotation representations may vary, including quaternions, Euler angles, rotation vectors, and rotation matrices. 10 | * :warning: **Jittering and Long Pauses:** Inexperienced data collectors often introduce hesitation, leading to long pauses or jittering during data collection. Without proper filtering, such data significantly degrades model performance. 11 | * :warning: **Quick Gripper Open/Close Actions:** A frequency mismatch between camera capture and gripper control often results in abrupt changes in gripper states, especially during grasping or releasing motions. 12 | 13 | To address these issues, we recommend a uniform, clear, and effective format for saving self-collected data and provide tools for post-processing. 14 | 15 | ## :exclamation: Data Format 16 | For each task, we collect 100 demos. The recommended directory structure is: 17 | ``` 18 | 0000 (exp_id) 19 | |—— 000000 (episode_id) 20 | |—— steps 21 | |—— 0000 (timestep_id, start) 22 | |—— image_primary.jpg (Eye-on-Base camera rgb image) 23 | |—— image_wrist.jpg (Eye-on-Hand camera rgb image) 24 | └── other.npz (robot state, language, action) 25 | |—— ...... 26 | └── xxxx (timestep_id, end) 27 | |—— 000001 (episode_id) 28 | |—— steps 29 | |—— ...... 30 | |—— ...... 31 | └── 000099 (episode_id) 32 | |—— steps 33 | |—— ...... 34 | ``` 35 | ### File Details: 36 | * **image_primary.jpg** and **image_wrist.jpg**: Images saved with a resolution of 640 x 480 pixels. 37 | * **other.npz**: Contains key robot metadata. An example of the saved format is: 38 | ```python 39 | # at each timestep i 40 | npz_path = f"other.npz" 41 | 42 | # absolute current gripper pose in robot space, position + euler angles, in meters and radians 43 | gripper_pose = np.array([x, y, z, euler_x, euler_y, euler_z]) 44 | 45 | # absolute current gripper open state 46 | gripper_open_state = np.array([1.0]) if the gripper is open else np.array([-1.0]) 47 | 48 | # absolute current joint positions (qpos) 49 | joints = np.array([q0, q1, q2, q3, q4, q5, q6]) 50 | 51 | # language instruction 52 | language_instruction = f"Pick the apple."
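# (illustrative sketch, not part of the original recipe) the delta pose label below can be derived
# from two consecutive absolute gripper poses with 4x4-pose helpers such as _6d_to_pose / pose_to_6d
# from deploy.py; last_gripper_pose is a hypothetical name for the previous timestep's gripper_pose:
#   T_last = _6d_to_pose(last_gripper_pose)
#   T_cur = _6d_to_pose(gripper_pose)
#   delta_6d = pose_to_6d(np.linalg.inv(T_last) @ T_cur)
# the repo's compute_delta_action helper (see below) performs this for a whole episode, and the
# training pipeline may further normalize the deltas (cf. max_rel_pos / max_rel_orn in deploy.py).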
53 | 54 | # absolute target pose action label (target_gripper_open_or_close is 1.0 if targetting open, else -1.0) 55 | action_gripper_pose = np.array([target_x, target_y, target_z, target_euler_x, target_euler_y, target_euler_z, target_gripper_open_or_close]) 56 | 57 | # delta pose action label 58 | delta_cur_2_last_action = np.array([target_delta_x, target_delta_y, target_delta_z, target_delta_euler_x, target_delta_euler_y, target_delta_euler_z, target_gripper_open_or_close]) 59 | 60 | # save npz 61 | np.savez_compressed( 62 | npz_path, 63 | joints=joints, 64 | gripper_pose=gripper_pose, 65 | gripper_open_state=gripper_open_state, 66 | action_gripper_pose=action_gripper_pose, 67 | delta_cur_2_last_action=delta_cur_2_last_action, 68 | language_instruction=language_instruction, 69 | ) 70 | ``` 71 | For most robotic systems, all metadata except delta_cur_2_last_action can be directly extracted. We provide a helper function to compute the delta pose action label in the [script](../utils/real_ft_data.py): 72 | ```python 73 | compute_delta_action( 74 | data_list, 75 | ) 76 | ``` 77 | 78 | ## :star: Post-processing Self-collected Data 79 | * **Filtering Jitter and Pauses:** To filter out jittering and long pauses, use the following function in the [script](../utils/real_ft_data.py): 80 | ```python 81 | filter_real_data( 82 | exp_id, 83 | root_path, # path to your raw data 84 | save_data_path, # a desired path to save filterd data 85 | save_gif_path # a desired path to save the filtered gif (only for visualization and debugging) 86 | ) 87 | ``` 88 | * **Data Augmentation for Gripper Actions:** To augment data by increasing sampling ratios during gripper open/close events, use the following function in the same [script](../utils/real_ft_data.py): 89 | ```python 90 | make_aug_short_real_dataset_info( 91 | root_path, # path to your filterd data 92 | root_info_path, # path to your data info, it should be like xxx/Seer/data_info 93 | dataset_name, # your dataset name, e.g. "ft" 94 | select_ratio=1.0, 95 | sequence_length=7, 96 | action_pred_steps=3, 97 | replicate_steps=10 98 | ) 99 | ``` 100 | 101 | -------------------------------------------------------------------------------- /data_info/.hydra/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /work/dlclarge2/meeso-lfp/calvin_data/task_A_A/ 4 | sweep: 5 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 6 | subdir: ${hydra.job.num} 7 | launcher: 8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher 9 | sweeper: 10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 11 | max_batch_size: null 12 | help: 13 | app_name: ${hydra.job.name} 14 | header: '${hydra.help.app_name} is powered by Hydra. 15 | 16 | ' 17 | footer: 'Powered by Hydra (https://hydra.cc) 18 | 19 | Use --hydra-help to view Hydra specific help 20 | 21 | ' 22 | template: '${hydra.help.header} 23 | 24 | == Configuration groups == 25 | 26 | Compose your configuration from those groups (group=option) 27 | 28 | 29 | $APP_CONFIG_GROUPS 30 | 31 | 32 | == Config == 33 | 34 | Override anything in the config (foo.bar=value) 35 | 36 | 37 | $CONFIG 38 | 39 | 40 | ${hydra.help.footer} 41 | 42 | ' 43 | hydra_help: 44 | template: 'Hydra (${hydra.runtime.version}) 45 | 46 | See https://hydra.cc for more info. 
47 | 48 | 49 | == Flags == 50 | 51 | $FLAGS_HELP 52 | 53 | 54 | == Configuration groups == 55 | 56 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 57 | to command line) 58 | 59 | 60 | $HYDRA_CONFIG_GROUPS 61 | 62 | 63 | Use ''--cfg hydra'' to Show the Hydra config. 64 | 65 | ' 66 | hydra_help: ??? 67 | hydra_logging: 68 | version: 1 69 | formatters: 70 | colorlog: 71 | (): colorlog.ColoredFormatter 72 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 73 | handlers: 74 | console: 75 | class: logging.StreamHandler 76 | formatter: colorlog 77 | stream: ext://sys.stdout 78 | root: 79 | level: INFO 80 | handlers: 81 | - console 82 | disable_existing_loggers: false 83 | job_logging: 84 | version: 1 85 | formatters: 86 | simple: 87 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 88 | colorlog: 89 | (): colorlog.ColoredFormatter 90 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 91 | - %(message)s' 92 | log_colors: 93 | DEBUG: purple 94 | INFO: green 95 | WARNING: yellow 96 | ERROR: red 97 | CRITICAL: red 98 | handlers: 99 | console: 100 | class: logging.StreamHandler 101 | formatter: colorlog 102 | stream: ext://sys.stdout 103 | file: 104 | class: logging.FileHandler 105 | formatter: simple 106 | filename: ${hydra.job.name}.log 107 | root: 108 | level: INFO 109 | handlers: 110 | - console 111 | - file 112 | disable_existing_loggers: false 113 | env: {} 114 | searchpath: [] 115 | callbacks: {} 116 | output_subdir: .hydra 117 | overrides: 118 | hydra: 119 | - hydra.run.dir=/work/dlclarge2/meeso-lfp/calvin_data/task_A_A/ 120 | task: 121 | - load_dir=/work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/ 122 | - set_static_cam=false 123 | - processes=16 124 | - +scene=calvin_scene_A 125 | - show_gui=false 126 | job: 127 | name: datarenderer 128 | override_dirname: +scene=calvin_scene_A,load_dir=/work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/,processes=16,set_static_cam=false,show_gui=false 129 | id: ??? 130 | num: ??? 
131 | config_name: config_rendering 132 | env_set: {} 133 | env_copy: [] 134 | config: 135 | override_dirname: 136 | kv_sep: '=' 137 | item_sep: ',' 138 | exclude_keys: [] 139 | runtime: 140 | version: 1.1.1 141 | cwd: /home/meeso/VRENVCalvinRender/vr_env 142 | config_sources: 143 | - path: hydra.conf 144 | schema: pkg 145 | provider: hydra 146 | - path: /home/meeso/VRENVCalvinRender/conf 147 | schema: file 148 | provider: main 149 | - path: hydra_plugins.hydra_colorlog.conf 150 | schema: pkg 151 | provider: hydra-colorlog 152 | - path: '' 153 | schema: structured 154 | provider: schema 155 | choices: 156 | scene: calvin_scene_A 157 | cameras: static_gripper_tactile 158 | cameras/cameras@cameras.tactile: tactile 159 | cameras/cameras@cameras.gripper: gripper 160 | cameras/cameras@cameras.static: static 161 | hydra/env: default 162 | hydra/callbacks: null 163 | hydra/job_logging: colorlog 164 | hydra/hydra_logging: colorlog 165 | hydra/hydra_help: default 166 | hydra/help: default 167 | hydra/sweeper: basic 168 | hydra/launcher: basic 169 | hydra/output: default 170 | verbose: false 171 | -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from typing import Any, Dict, List, Sequence, Tuple, Union 4 | 5 | from octo_oxe_data_utils.oxe.oxe_dataset_configs import ActionEncoding, OXE_DATASET_CONFIGS 6 | from octo_oxe_data_utils.oxe.oxe_dataset_mixes import OXE_NAMED_MIXES 7 | from octo_oxe_data_utils.oxe.oxe_standardization_transforms import OXE_STANDARDIZATION_TRANSFORMS 8 | from octo_oxe_data_utils.utils.data_utils import NormalizationType 9 | 10 | 11 | def make_oxe_dataset_kwargs( 12 | name: str, 13 | data_dir: str, 14 | load_camera_views: Sequence[str] = ("primary",), 15 | load_depth: bool = False, 16 | load_proprio: bool = True, 17 | load_language: bool = True, 18 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 19 | ) -> Dict[str, Any]: 20 | """Generates dataset kwargs for a given dataset from Open X-Embodiment. The returned kwargs can be passed 21 | directly into `octo.data.dataset.make_dataset_from_rlds`. 22 | 23 | Args: 24 | name: Name of the dataset to load. See `oxe_dataset_configs.py` for available datasets. 25 | data_dir: Base data directory that contains the dataset. 26 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views. 27 | load_depth: If True, loads corresponding depth channels for each RGB channel. 28 | load_proprio: If True, loads proprioceptive information. 29 | load_language: If True, loads language instructions. 30 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions. 31 | """ 32 | dataset_kwargs = copy.deepcopy(OXE_DATASET_CONFIGS[name]) 33 | if dataset_kwargs["action_encoding"] is not ActionEncoding.EEF_POS: 34 | raise ValueError( 35 | f"Cannot load {name} since only EEF pose delta action encoding is supported." 
36 | ) 37 | 38 | # with EEF_POS actions, only the last action dimension (the gripper) is absolute 39 | dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] 40 | 41 | # we also want to skip normalizing the gripper action 42 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] 43 | 44 | # adjust loaded camera views 45 | if missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"])): 46 | raise ValueError( 47 | f"Cannot load {name} with views {missing_keys} since they are not available." 48 | ) 49 | dataset_kwargs["image_obs_keys"] = { 50 | k: v 51 | for k, v in dataset_kwargs["image_obs_keys"].items() 52 | if k in load_camera_views 53 | } 54 | dataset_kwargs["depth_obs_keys"] = { 55 | k: v 56 | for k, v in dataset_kwargs["depth_obs_keys"].items() 57 | if k in load_camera_views 58 | } 59 | 60 | if not load_depth: 61 | dataset_kwargs.pop("depth_obs_keys") 62 | if not load_proprio: 63 | dataset_kwargs.pop("state_obs_keys") 64 | 65 | if load_language: 66 | dataset_kwargs["language_key"] = "language_instruction" 67 | 68 | dataset_kwargs[ 69 | "action_proprio_normalization_type" 70 | ] = action_proprio_normalization_type 71 | 72 | del dataset_kwargs["state_encoding"] 73 | del dataset_kwargs["action_encoding"] 74 | 75 | dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[name] 76 | 77 | return {"name": name, "data_dir": data_dir, **dataset_kwargs} 78 | 79 | 80 | def make_oxe_dataset_kwargs_and_weights( 81 | data_mix: Union[str, Sequence[Tuple[str, float]]], 82 | data_dir: str, 83 | load_camera_views: Sequence[str] = ("primary",), 84 | load_depth: bool = False, 85 | load_proprio: bool = True, 86 | load_language: bool = True, 87 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 88 | ) -> Tuple[Dict[str, Any], List[float]]: 89 | """ 90 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs 91 | and weights can be passed directly into `octo.data.dataset.make_interleaved_dataset`. 92 | 93 | Args: 94 | data_mix: List of (dataset name, sampling weight) tuples, or a string specifying a pre-defined mix to 95 | load from `OXE_NAMED_MIXES`. 96 | data_dir: Base data directory that contains the datasets. 97 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views. 98 | load_depth: If True, loads corresponding depth channels for each RGB channel. 99 | load_proprio: If True, loads proprioceptive information. 100 | load_language: If True, loads language instructions. 101 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions. 102 | Returns: 103 | Tuple of (dataset_kwargs_list, sampling weights). 
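    Example (an illustrative sketch; the data directory below is a placeholder path, and "bridge" is one of the mixes defined in `OXE_NAMED_MIXES`):
        dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
            "bridge", data_dir="/path/to/rlds_datasets"
        )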
104 | """ 105 | if isinstance(data_mix, str): 106 | data_mix = OXE_NAMED_MIXES[data_mix] 107 | 108 | filtered_datasets, included_dataset_names = [], [] 109 | for name, weight in data_mix: 110 | if name not in included_dataset_names: 111 | filtered_datasets.append((name, weight)) 112 | included_dataset_names.append(name) 113 | else: 114 | logging.warning(f"Skipping duplicate: {(name, weight)}.") 115 | data_mix = filtered_datasets 116 | 117 | data_kwargs_list, weights = [], [] 118 | for name, weight in data_mix: 119 | try: 120 | data_kwargs_list.append( 121 | make_oxe_dataset_kwargs( 122 | name, 123 | data_dir, 124 | load_camera_views, 125 | load_depth, 126 | load_proprio, 127 | load_language, 128 | action_proprio_normalization_type, 129 | ) 130 | ) 131 | weights.append(weight) 132 | except ValueError as e: 133 | logging.warning(f"Skipping {name} due to error: {e}") 134 | 135 | return data_kwargs_list, weights 136 | -------------------------------------------------------------------------------- /utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for setting up distributed training. 3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py 4 | """ 5 | import os 6 | import torch 7 | from datetime import timedelta 8 | 9 | def is_global_master(args): 10 | return args.rank == 0 11 | 12 | def is_local_master(args): 13 | return args.local_rank == 0 14 | 15 | def is_master(args, local=False): 16 | return is_local_master(args) if local else is_global_master(args) 17 | 18 | def is_using_distributed(): 19 | if "WORLD_SIZE" in os.environ: 20 | return int(os.environ["WORLD_SIZE"]) > 1 21 | if "SLURM_NTASKS" in os.environ: 22 | return int(os.environ["SLURM_NTASKS"]) > 1 23 | return False 24 | 25 | def world_info_from_env(): 26 | local_rank = 0 27 | for v in ( 28 | "LOCAL_RANK", 29 | "MPI_LOCALRANKID", 30 | "SLURM_LOCALID", 31 | "OMPI_COMM_WORLD_LOCAL_RANK", 32 | ): 33 | if v in os.environ: 34 | local_rank = int(os.environ[v]) 35 | break 36 | global_rank = 0 37 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 38 | if v in os.environ: 39 | global_rank = int(os.environ[v]) 40 | break 41 | world_size = 1 42 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 43 | if v in os.environ: 44 | world_size = int(os.environ[v]) 45 | break 46 | 47 | return local_rank, global_rank, world_size 48 | 49 | # def init_distributed_device(args): 50 | # # Distributed training = training on more than one GPU. 51 | # # Works in both single and multi-node scenarios. 
52 | # args.distributed = False 53 | # args.world_size = 1 54 | # args.rank = 0 # global rank 55 | # args.local_rank = 0 56 | 57 | # if is_using_distributed(): 58 | # if "SLURM_PROCID" in os.environ: 59 | # # DDP via SLURM 60 | # args.local_rank, args.rank, args.world_size = world_info_from_env() 61 | # # SLURM var -> torch.distributed vars in case needed 62 | # os.environ["LOCAL_RANK"] = str(args.local_rank) 63 | # os.environ["RANK"] = str(args.rank) 64 | # os.environ["WORLD_SIZE"] = str(args.world_size) 65 | # torch.distributed.init_process_group( 66 | # backend=args.dist_backend, 67 | # init_method=args.dist_url, 68 | # world_size=args.world_size, 69 | # rank=args.rank, 70 | # timeout=timedelta(seconds=36000000) 71 | # ) 72 | # else: 73 | # # DDP via torchrun, torch.distributed.launch 74 | # args.local_rank, _, _ = world_info_from_env() 75 | # torch.distributed.init_process_group( 76 | # backend=args.dist_backend, init_method=args.dist_url, timeout=timedelta(seconds=36000000) 77 | # ) 78 | # args.world_size = torch.distributed.get_world_size() 79 | # args.rank = torch.distributed.get_rank() 80 | # args.distributed = True 81 | # else: 82 | # # needed to run on single gpu 83 | # torch.distributed.init_process_group( 84 | # backend=args.dist_backend, 85 | # init_method=args.dist_url, 86 | # world_size=1, 87 | # rank=0, 88 | # timeout=timedelta(seconds=36000000) 89 | # ) 90 | 91 | # if torch.cuda.is_available(): 92 | # if args.distributed and not args.no_set_device_rank: 93 | # device = "cuda:%d" % args.local_rank 94 | # else: 95 | # device = "cuda:0" 96 | # torch.cuda.set_device(device) 97 | # else: 98 | # device = "cpu" 99 | # args.device = device 100 | # device = torch.device(device) 101 | # return device 102 | 103 | def init_distributed_device(args): 104 | # Distributed training = training on more than one GPU. 105 | # Works in both single and multi-node scenarios. 
106 | args.distributed = False 107 | args.world_size = 1 108 | args.rank = 0 # global rank 109 | args.local_rank = 0 110 | from pdb import set_trace 111 | import datetime 112 | if is_using_distributed(): 113 | if "SLURM_PROCID" in os.environ: 114 | # DDP via SLURM 115 | args.local_rank, args.rank, args.world_size = world_info_from_env() 116 | # SLURM var -> torch.distributed vars in case needed 117 | os.environ["LOCAL_RANK"] = str(args.local_rank) 118 | os.environ["RANK"] = str(args.rank) 119 | os.environ["WORLD_SIZE"] = str(args.world_size) 120 | os.environ['NCCL_BLOCKING_WAIT'] = '0' # not to enforce timeout 121 | torch.distributed.init_process_group( 122 | # timeout=timedelta(seconds=7200000), # was 1800000 123 | timeout=datetime.timedelta(seconds=7200), 124 | backend=args.dist_backend, 125 | init_method=args.dist_url, 126 | world_size=args.world_size, 127 | rank=args.rank, 128 | ) 129 | else: 130 | # DDP via torchrun, torch.distributed.launch 131 | os.environ['NCCL_BLOCKING_WAIT'] = '0' # not to enforce timeout 132 | args.local_rank, _, _ = world_info_from_env() 133 | torch.distributed.init_process_group( 134 | # timeout=timedelta(seconds=7200000), # was 18000 135 | timeout=datetime.timedelta(seconds=17200), 136 | backend=args.dist_backend, 137 | init_method=args.dist_url 138 | ) 139 | args.world_size = torch.distributed.get_world_size() 140 | args.rank = torch.distributed.get_rank() 141 | args.distributed = True 142 | else: 143 | # needed to run on single gpu 144 | torch.distributed.init_process_group( 145 | backend=args.dist_backend, 146 | init_method=args.dist_url, 147 | world_size=1, 148 | rank=0, 149 | ) 150 | 151 | if torch.cuda.is_available(): 152 | if args.distributed and not args.no_set_device_rank: 153 | device = "cuda:%d" % args.local_rank 154 | else: 155 | device = "cuda:0" 156 | torch.cuda.set_device(device) 157 | else: 158 | device = "cpu" 159 | args.device = device 160 | device = torch.device(device) 161 | return device 162 | -------------------------------------------------------------------------------- /slurm_train_intern.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A script to run multinode training with submitit. 
16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py 17 | """ 18 | import argparse 19 | import os 20 | import re 21 | import random 22 | import uuid 23 | from pathlib import Path 24 | import train 25 | from utils.arguments_utils import get_parser 26 | import submitit 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[get_parser()]) 31 | print("!!!") 32 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 33 | parser.add_argument("--nodes", default=4, type=int, help="Number of nodes to request") 34 | parser.add_argument("--timeout", default=72000, type=int, help="Duration of the job") 35 | parser.add_argument("--partition", default="mozi-S1", type=str, help="Partition where to submit") 36 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 37 | parser.add_argument('--comment', default="", type=str, 38 | help='Comment to pass to scheduler, e.g. priority message') 39 | parser.add_argument("--exclude", default="", type=str, help="Nodes to exclude") 40 | parser.add_argument("--output_dir", default="/mnt/petrelfs/tianyang/Code/ICLR_Manipulation/out", type=str) 41 | return parser.parse_args() 42 | 43 | def get_shared_folder() -> Path: 44 | user = os.getenv("USER") 45 | if Path(f"/ailab/user/{user}/").is_dir(): 46 | p = Path(f"/ailab/user/{user}/experiments") 47 | p.mkdir(exist_ok=True) 48 | return p 49 | raise RuntimeError("No shared folder available") 50 | 51 | def get_init_file(): 52 | # Init file must not exist, but it's parent dir must exist. 53 | os.makedirs(str(get_shared_folder()), exist_ok=True) 54 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 55 | if init_file.exists(): 56 | os.remove(str(init_file)) 57 | return init_file 58 | 59 | def _get_master_port(seed): 60 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) 61 | master_port_str = os.environ.get("MASTER_PORT") 62 | if master_port_str is None: 63 | rng = random.Random(seed) 64 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) 65 | return int(master_port_str) 66 | 67 | def _parse_slurm_node_list(s): 68 | nodes = [] 69 | # Extract "hostname", "hostname[1-2,3,4-5]," substrings 70 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") 71 | for m in p.finditer(s): 72 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] 73 | prefix_list = prefix.split(',') 74 | if len(prefix_list) > 1: 75 | nodes += prefix_list[:-1] 76 | prefix = prefix_list[-1] 77 | for suffix in suffixes.split(","): 78 | span = suffix.split("-") 79 | if len(span) == 1: 80 | nodes.append(prefix + suffix) 81 | else: 82 | width = len(span[0]) 83 | start, end = int(span[0]), int(span[1]) + 1 84 | for i in range(start, end): 85 | nodes.append(prefix + f"{i:0{width}}") 86 | 87 | return nodes 88 | 89 | class Trainer(object): 90 | def __init__(self, args): 91 | self.args = args 92 | def __call__(self): 93 | # import run_beit_pretraining 94 | self._setup_gpu_args() 95 | train.main(self.args) 96 | 97 | def checkpoint(self): 98 | import os 99 | import submitit 100 | # self.args.dist_url = get_init_file().as_uri() 101 | print("Requeuing ", self.args) 102 | empty_trainer = type(self)(self.args) 103 | return submitit.helpers.DelayedSubmission(empty_trainer) 104 | 105 | def _setup_gpu_args(self): 106 | import submitit 107 | from pathlib import Path 108 | job_id = int(os.environ["SLURM_JOB_ID"]) 109 | node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) 110 | 
print("node_list :", os.environ["SLURM_JOB_NODELIST"]) 111 | nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) 112 | print("node_count :", node_count) 113 | print("nodes :", nodes) 114 | assert len(nodes) == node_count 115 | master_addr = nodes[0] 116 | master_port = _get_master_port(seed=job_id) 117 | os.environ['MASTER_ADDR'] = master_addr 118 | os.environ['MASTER_PORT'] = str(master_port) 119 | job_env = submitit.JobEnvironment() 120 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 121 | self.args.gpu = job_env.local_rank 122 | self.args.rank = job_env.global_rank 123 | self.args.world_size = job_env.num_tasks 124 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 125 | 126 | def main(): 127 | args = parse_args() 128 | if args.output_dir == "": 129 | args.output_dir = get_shared_folder() / "%j" 130 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 131 | executor = submitit.SlurmExecutor(folder=args.output_dir, max_num_timeout=30) 132 | num_gpus_per_node = args.ngpus 133 | nodes = args.nodes 134 | timeout_min = args.timeout 135 | partition = args.partition 136 | kwargs = {} 137 | if args.use_volta32: 138 | kwargs['slurm_constraint'] = 'volta32gb' 139 | if args.comment: 140 | kwargs['slurm_comment'] = args.comment 141 | if args.exclude: 142 | kwargs["exclude"] = args.exclude 143 | executor.update_parameters( 144 | gres=f"gpu:{num_gpus_per_node}", 145 | ntasks_per_node=num_gpus_per_node, # one task per GPU 146 | cpus_per_task=6, 147 | nodes=nodes, 148 | time=timeout_min, 149 | # Below are cluster dependent parameters 150 | signal_delay_s=120, 151 | partition=partition, 152 | **kwargs 153 | ) 154 | executor.update_parameters(job_name="seer") 155 | # args.dist_url = get_init_file().as_uri() 156 | trainer = Trainer(args) 157 | job = executor.submit(trainer) 158 | print(f"Submitted job_id: {job.job_id}") 159 | print(f"Logs and checkpoints will be saved at: {args.output_dir}") 160 | 161 | if __name__ == "__main__": 162 | main() -------------------------------------------------------------------------------- /real_preprocess/octo_oxe_data_utils/oxe/oxe_dataset_mixes.py: -------------------------------------------------------------------------------- 1 | """Defines dataset mixtures and weights for the Open X-Embodiment Datasets.""" 2 | 3 | 4 | BRIDGE_MIX = [ 5 | ("bridge_dataset", 1.0), 6 | ] 7 | 8 | RT_X_MIX = [ 9 | ("fractal20220817_data", 0.54087122203), 10 | ("kuka", 0.8341046294), 11 | ("bridge_dataset", 1.0), 12 | ("taco_play", 2.0), 13 | ("jaco_play", 2.0), 14 | ("berkeley_cable_routing", 3.0), 15 | ("roboturk", 1.0), 16 | ("nyu_door_opening_surprising_effectiveness", 5.0), 17 | ("viola", 2.0), 18 | ("berkeley_autolab_ur5", 1.0), 19 | ("toto", 1.0), 20 | ] 21 | 22 | 23 | OXE_FRANKA_MIX = [ 24 | ("taco_play", 1.0), 25 | ("berkeley_cable_routing", 1.0), 26 | ("viola", 1.0), 27 | ("toto", 1.0), 28 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), 29 | ("austin_buds_dataset_converted_externally_to_rlds", 3.0), 30 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 31 | ("maniskill_dataset_converted_externally_to_rlds", 0.1), 32 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 33 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), 34 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 35 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 36 | ("berkeley_rpt_converted_externally_to_rlds", 1.0), 37 | 
("kaist_nonprehensile_converted_externally_to_rlds", 3.0), 38 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 39 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 40 | ("utaustin_mutex", 1.0), 41 | # ("cmu_playing_with_food", 1.0), 42 | ("cmu_play_fusion", 1.0), 43 | ] 44 | 45 | 46 | OXE_MAGIC_SOUP = [ 47 | ("fractal20220817_data", 0.54087122203), 48 | ("kuka", 0.8341046294), 49 | ("bridge_dataset", 1.0), 50 | ("taco_play", 2.0), 51 | ("jaco_play", 1.0), 52 | ("berkeley_cable_routing", 1.0), 53 | ("roboturk", 2.0), 54 | ("nyu_door_opening_surprising_effectiveness", 1.0), 55 | ("viola", 2.0), 56 | ("berkeley_autolab_ur5", 2.0), 57 | ("toto", 1.0), 58 | ("language_table", 0.1), 59 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 60 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 61 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 62 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 63 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 64 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 65 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 66 | ("bc_z", 0.2), 67 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 68 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 69 | ("utaustin_mutex", 1.0), 70 | ("berkeley_fanuc_manipulation", 2.0), 71 | ("cmu_stretch", 1.0), 72 | ] 73 | 74 | OXE_MAGIC_SOUP_EEFP_ONLY = [ 75 | ("fractal20220817_data", 0.54087122203), 76 | ("kuka", 0.8341046294), 77 | ("bridge_dataset", 1.0), 78 | ("taco_play", 2.0), 79 | ("jaco_play", 1.0), 80 | # ("berkeley_cable_routing", 1.0), 81 | ("roboturk", 2.0), 82 | ("nyu_door_opening_surprising_effectiveness", 1.0), 83 | ("viola", 2.0), 84 | ("berkeley_autolab_ur5", 2.0), 85 | # ("toto", 1.0), 86 | ("language_table", 0.1), 87 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 88 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 89 | # ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 90 | # ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 91 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 92 | # ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 93 | # ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 94 | ("bc_z", 0.2), 95 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 96 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 97 | ("utaustin_mutex", 1.0), 98 | ("berkeley_fanuc_manipulation", 2.0), 99 | ("cmu_stretch", 1.0), 100 | ] 101 | 102 | OXE_FULL_MIX = [ 103 | ("fractal20220817_data", 1.0), 104 | ("kuka", 1.0), 105 | ("bridge_dataset", 1), 106 | ("taco_play", 1.0), 107 | ("jaco_play", 1.0), 108 | ("berkeley_cable_routing", 1.0), 109 | ("roboturk", 1.0), 110 | ("nyu_door_opening_surprising_effectiveness", 1.0), 111 | ("viola", 1.0), 112 | ("berkeley_autolab_ur5", 1.0), 113 | ("toto", 1.0), 114 | ("language_table", 1.0), 115 | ("columbia_cairlab_pusht_real", 1.0), 116 | ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0), 117 | ("nyu_rot_dataset_converted_externally_to_rlds", 1.0), 118 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), 119 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 120 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0), 121 | ("maniskill_dataset_converted_externally_to_rlds", 1.0), 122 | ("furniture_bench_dataset_converted_externally_to_rlds", 1.0), 123 | 
("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0), 124 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0), 125 | ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0), 126 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 127 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 128 | ("bc_z", 1.0), 129 | ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0), 130 | ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0), 131 | ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0), 132 | ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0), 133 | ("robo_net", 1.0), 134 | ("berkeley_mvp_converted_externally_to_rlds", 1.0), 135 | ("berkeley_rpt_converted_externally_to_rlds", 1.0), 136 | ("kaist_nonprehensile_converted_externally_to_rlds", 1.0), 137 | ("stanford_mask_vit_converted_externally_to_rlds", 1.0), 138 | ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0), 139 | ("dlr_sara_pour_converted_externally_to_rlds", 1.0), 140 | ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0), 141 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 142 | ("asu_table_top_converted_externally_to_rlds", 1.0), 143 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 144 | ("imperialcollege_sawyer_wrist_cam", 1.0), 145 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 146 | ("uiuc_d3field", 1.0), 147 | ("utaustin_mutex", 1.0), 148 | ("berkeley_fanuc_manipulation", 1.0), 149 | ("cmu_playing_with_food", 1.0), 150 | ("cmu_play_fusion", 1.0), 151 | ("cmu_stretch", 1.0), 152 | ("berkeley_gnm_recon", 1.0), 153 | ("berkeley_gnm_cory_hall", 1.0), 154 | ("berkeley_gnm_sac_son", 1.0), 155 | ] 156 | 157 | MYTEST_MIX = [ 158 | ("cmu_stretch", 1.0), 159 | ] 160 | MYTEST_MIX1 = [ 161 | ("berkeley_autolab_ur5", 1.0), 162 | ] 163 | MYTEST_MIX2 = [ 164 | ("viola", 1.0), 165 | ] 166 | MYTEST_MIX3 = [ 167 | ("nyu_door_opening_surprising_effectiveness", 1.0), 168 | ] 169 | 170 | OXE_NAMED_MIXES = { 171 | "bridge": BRIDGE_MIX, 172 | "rtx": RT_X_MIX, 173 | "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX, 174 | "oxe_magic_soup": OXE_MAGIC_SOUP, 175 | "oxe_magic_soup_eefp_only": OXE_MAGIC_SOUP_EEFP_ONLY, 176 | "mytest": MYTEST_MIX, 177 | } 178 | -------------------------------------------------------------------------------- /data_info/.hydra/merged_config.yaml: -------------------------------------------------------------------------------- 1 | cameras: 2 | static: 3 | _target_: calvin_env.camera.static_camera.StaticCamera 4 | name: static 5 | fov: 10 6 | aspect: 1 7 | nearval: 0.01 8 | farval: 10 9 | width: 200 10 | height: 200 11 | look_at: 12 | - -0.026242351159453392 13 | - -0.0302329882979393 14 | - 0.3920000493526459 15 | look_from: 16 | - 2.871459009488717 17 | - -2.166602199425597 18 | - 2.555159848480571 19 | up_vector: 20 | - 0.4041403970338857 21 | - 0.22629790978217404 22 | - 0.8862616969685161 23 | gripper: 24 | _target_: calvin_env.camera.gripper_camera.GripperCamera 25 | name: gripper 26 | fov: 75 27 | aspect: 1 28 | nearval: 0.01 29 | farval: 2 30 | width: 84 31 | height: 84 32 | tactile: 33 | _target_: calvin_env.camera.tactile_sensor.TactileSensor 34 | name: tactile 35 | width: 120 36 | height: 160 37 | digit_link_ids: 38 | - 10 39 | - 12 40 | visualize_gui: false 41 | config_path: conf/digit_sensor/config_digit.yml 42 | vr_input: 43 | vr_controller: 44 | POSITION: 1 45 | ORIENTATION: 2 46 | ANALOG: 3 47 | BUTTONS: 6 48 | BUTTON_A: 2 49 | BUTTON_B: 1 50 | vr_controller_id: 3 51 | 
gripper_orientation_offset: 52 | - 0 53 | - 3 54 | - 3.14 55 | gripper_position_offset: 56 | - -0.2 57 | - 0.3 58 | - 0 59 | _target_: calvin_env.io_utils.vr_input.VrInput 60 | limit_angle: 61 | - 90 62 | - 0 63 | - 0 64 | - -1 65 | visualize_vr_pos: true 66 | reset_button_queue_len: 60 67 | env: 68 | _target_: calvin_env.envs.play_table_env.PlayTableSimEnv 69 | _recursive_: false 70 | cameras: ${cameras} 71 | seed: 0 72 | bullet_time_step: 240.0 73 | use_vr: false 74 | show_gui: ${show_gui} 75 | robot_cfg: ${robot} 76 | scene_cfg: ${scene} 77 | use_scene_info: false 78 | use_egl: true 79 | control_freq: 30 80 | scene: 81 | _target_: calvin_env.scene.play_table_scene.PlayTableScene 82 | _recursive_: false 83 | data_path: ${data_path} 84 | global_scaling: 0.8 85 | euler_obs: ${robot.euler_obs} 86 | robot_base_position: 87 | - -0.34 88 | - -0.46 89 | - 0.24 90 | robot_base_orientation: 91 | - 0 92 | - 0 93 | - 0 94 | robot_initial_joint_positions: 95 | - -1.21779206 96 | - 1.03987646 97 | - 2.11978261 98 | - -2.34205014 99 | - -0.87015947 100 | - 1.64119353 101 | - 0.55344866 102 | surfaces: 103 | table: 104 | - - 0.0 105 | - -0.15 106 | - 0.46 107 | - - 0.35 108 | - -0.03 109 | - 0.46 110 | slider_left: 111 | - - -0.32 112 | - 0.05 113 | - 0.46 114 | - - -0.16 115 | - 0.12 116 | - 0.46 117 | slider_right: 118 | - - -0.05 119 | - 0.05 120 | - 0.46 121 | - - 0.13 122 | - 0.12 123 | - 0.46 124 | objects: 125 | fixed_objects: 126 | table: 127 | file: calvin_table_D/urdf/calvin_table_D.urdf 128 | initial_pos: 129 | - 0 130 | - 0 131 | - 0 132 | initial_orn: 133 | - 0 134 | - 0 135 | - 0 136 | joints: 137 | base__slide: 138 | initial_state: 0 139 | base__drawer: 140 | initial_state: 0 141 | buttons: 142 | base__button: 143 | initial_state: 0 144 | effect: led 145 | switches: 146 | base__switch: 147 | initial_state: 0 148 | effect: lightbulb 149 | lights: 150 | lightbulb: 151 | link: light_link 152 | color: 153 | - 1 154 | - 1 155 | - 0 156 | - 1 157 | led: 158 | link: led_link 159 | color: 160 | - 0 161 | - 1 162 | - 0 163 | - 1 164 | movable_objects: 165 | block_red: 166 | file: blocks/block_red_middle.urdf 167 | initial_pos: any 168 | initial_orn: any 169 | block_blue: 170 | file: blocks/block_blue_small.urdf 171 | initial_pos: any 172 | initial_orn: any 173 | block_pink: 174 | file: blocks/block_pink_big.urdf 175 | initial_pos: any 176 | initial_orn: any 177 | name: calvin_scene_D 178 | robot: 179 | _target_: calvin_env.robot.robot.Robot 180 | filename: franka_panda/panda_longer_finger.urdf 181 | base_position: ${scene.robot_base_position} 182 | base_orientation: ${scene.robot_base_orientation} 183 | initial_joint_positions: ${scene.robot_initial_joint_positions} 184 | max_joint_force: 200.0 185 | gripper_force: 200 186 | arm_joint_ids: 187 | - 0 188 | - 1 189 | - 2 190 | - 3 191 | - 4 192 | - 5 193 | - 6 194 | gripper_joint_ids: 195 | - 9 196 | - 11 197 | gripper_joint_limits: 198 | - 0 199 | - 0.04 200 | tcp_link_id: 15 201 | end_effector_link_id: 7 202 | gripper_cam_link: 12 203 | use_nullspace: true 204 | max_velocity: 2 205 | use_ik_fast: false 206 | magic_scaling_factor_pos: 1 207 | magic_scaling_factor_orn: 1 208 | use_target_pose: true 209 | euler_obs: true 210 | tasks: 211 | _target_: calvin_env.envs.tasks.Tasks 212 | tasks: 213 | rotate_red_block_right: 214 | - rotate_object 215 | - block_red 216 | - -60 217 | rotate_red_block_left: 218 | - rotate_object 219 | - block_red 220 | - 60 221 | rotate_blue_block_right: 222 | - rotate_object 223 | - block_blue 224 | - -60 225 | 
rotate_blue_block_left: 226 | - rotate_object 227 | - block_blue 228 | - 60 229 | rotate_pink_block_right: 230 | - rotate_object 231 | - block_pink 232 | - -60 233 | rotate_pink_block_left: 234 | - rotate_object 235 | - block_pink 236 | - 60 237 | push_red_block_right: 238 | - push_object 239 | - block_red 240 | - 0.1 241 | - 0 242 | push_red_block_left: 243 | - push_object 244 | - block_red 245 | - -0.1 246 | - 0 247 | push_blue_block_right: 248 | - push_object 249 | - block_blue 250 | - 0.1 251 | - 0 252 | push_blue_block_left: 253 | - push_object 254 | - block_blue 255 | - -0.1 256 | - 0 257 | push_pink_block_right: 258 | - push_object 259 | - block_pink 260 | - 0.1 261 | - 0 262 | push_pink_block_left: 263 | - push_object 264 | - block_pink 265 | - -0.1 266 | - 0 267 | move_slider_left: 268 | - move_door_rel 269 | - base__slide 270 | - 0.15 271 | move_slider_right: 272 | - move_door_rel 273 | - base__slide 274 | - -0.15 275 | open_drawer: 276 | - move_door_rel 277 | - base__drawer 278 | - 0.12 279 | close_drawer: 280 | - move_door_rel 281 | - base__drawer 282 | - -0.12 283 | lift_red_block_table: 284 | - lift_object 285 | - block_red 286 | - 0.05 287 | - table 288 | - base_link 289 | lift_red_block_slider: 290 | - lift_object 291 | - block_red 292 | - 0.03 293 | - table 294 | - plank_link 295 | lift_red_block_drawer: 296 | - lift_object 297 | - block_red 298 | - 0.05 299 | - table 300 | - drawer_link 301 | lift_blue_block_table: 302 | - lift_object 303 | - block_blue 304 | - 0.05 305 | - table 306 | - base_link 307 | lift_blue_block_slider: 308 | - lift_object 309 | - block_blue 310 | - 0.03 311 | - table 312 | - plank_link 313 | lift_blue_block_drawer: 314 | - lift_object 315 | - block_blue 316 | - 0.05 317 | - table 318 | - drawer_link 319 | lift_pink_block_table: 320 | - lift_object 321 | - block_pink 322 | - 0.05 323 | - table 324 | - base_link 325 | lift_pink_block_slider: 326 | - lift_object 327 | - block_pink 328 | - 0.03 329 | - table 330 | - plank_link 331 | lift_pink_block_drawer: 332 | - lift_object 333 | - block_pink 334 | - 0.05 335 | - table 336 | - drawer_link 337 | place_in_slider: 338 | - place_object 339 | - table 340 | - plank_link 341 | place_in_drawer: 342 | - place_object 343 | - table 344 | - drawer_link 345 | stack_block: 346 | - stack_objects 347 | unstack_block: 348 | - unstack_objects 349 | turn_on_lightbulb: 350 | - toggle_light 351 | - lightbulb 352 | - 0 353 | - 1 354 | turn_off_lightbulb: 355 | - toggle_light 356 | - lightbulb 357 | - 1 358 | - 0 359 | turn_on_led: 360 | - toggle_light 361 | - led 362 | - 0 363 | - 1 364 | turn_off_led: 365 | - toggle_light 366 | - led 367 | - 1 368 | - 0 369 | push_into_drawer: 370 | - push_object_into 371 | - - block_red 372 | - block_blue 373 | - block_pink 374 | - table 375 | - base_link 376 | - table 377 | - drawer_link 378 | recorder: 379 | record: ${record} 380 | record_fps: 30.0 381 | show_fps: false 382 | enable_tts: true 383 | seed: 0 384 | use_vr: true 385 | data_path: data 386 | save_dir: /home/meeso/recordings 387 | record: true 388 | load_dir: /work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-09-12/11-36-05/ 389 | show_gui: false 390 | processes: 16 391 | max_episode_frames: 1 392 | save_body_infos: true 393 | set_static_cam: false 394 | -------------------------------------------------------------------------------- /real_preprocess/mujoco_menagerie/robotiq_2f85/2f85.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 192 | 
-------------------------------------------------------------------------------- /utils/convert_libero_per_step.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import torch.multiprocessing as mp 5 | import torch.distributed as dist 6 | import numpy as np 7 | import h5py 8 | import tqdm 9 | import argparse 10 | from pathlib import Path 11 | import yaml 12 | from PIL import Image 13 | import time 14 | import json 15 | import math 16 | # from sentence_transformers import SentenceTransformer 17 | 18 | ### Rotation ### 19 | from scipy.spatial.transform import Rotation 20 | 21 | 22 | def setup(rank, world_size, port): 23 | os.environ["MASTER_ADDR"] = "localhost" 24 | os.environ["MASTER_PORT"] = str(port) 25 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 26 | 27 | def extract_task_information(file_name, path): 28 | """ 29 | Extracts task information from the given file name. 30 | """ 31 | # Regular expression pattern to extract the task name 32 | pattern = r'{}/((.+)_SCENE[0-9]+_(.+))_demo\.hdf5'.format(path) 33 | 34 | # Extracting the task name 35 | match = re.search(pattern, file_name) 36 | 37 | print(match.group(3).lower().replace("_", " ")) 38 | return match.group(1).lower() if match else None, match.group(3).lower().replace("_", " ") 39 | 40 | 41 | class DatasetConverter: 42 | def __init__( 43 | self, 44 | src_dir: str, 45 | tgt_dir: str, 46 | rank: int, 47 | num_worker: int, 48 | start_episode_idx, 49 | end_episode_idx, 50 | ): 51 | self.src_dir = src_dir 52 | self.tgt_dir = tgt_dir 53 | self.rank = rank 54 | self.num_worker = num_worker 55 | self.start_episode_idx = start_episode_idx 56 | self.end_episode_idx = end_episode_idx 57 | # self.lang_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") 58 | 59 | def process_episode(self, episode_dir, language_instructions, demo_data, episode_index, episode_index_in_task): 60 | i = episode_index_in_task 61 | # get episode dir 62 | episode_dir.mkdir(exist_ok=True) 63 | 64 | ### Get agent's view camera 65 | obs = np.array(demo_data['demo_{}'.format(i)]['obs']['agentview_rgb']) 66 | # obs = obs.transpose(0,3,1,2) 67 | 68 | ### Get wrist's view camera 69 | obs_wrist = np.array(demo_data['demo_{}'.format(i)]['obs']['eye_in_hand_rgb']) 70 | # obs_wrist = obs_wrist.transpose(0,3,1,2) 71 | 72 | ### Get actions 73 | action = np.array(demo_data['demo_{}'.format(i)]['actions']) # -1 open, 1 close 74 | 75 | joint_state = np.array(demo_data['demo_{}'.format(i)]['obs']['joint_states']) 76 | ee_pos = np.array(demo_data['demo_{}'.format(i)]['obs']['ee_pos']) 77 | ee_ori = np.array(demo_data['demo_{}'.format(i)]['obs']['ee_ori']) 78 | ee_state = np.array(demo_data['demo_{}'.format(i)]['obs']['ee_states']) 79 | 80 | gripper_state = np.zeros_like(action[:, -1]) 81 | gripper_state[1:] = action[:-1, -1] 82 | gripper_state[0] = action[0, -1] 83 | 84 | gripper_position = np.array(demo_data['demo_{}'.format(i)]['obs']['gripper_states']) 85 | gripper_command = action[:, -1] 86 | 87 | # task emb 88 | # task_emb = self.lang_model.encode(language_instructions) 89 | 90 | # get episode length 91 | num_steps = obs.shape[0] 92 | 93 | episode_dir = episode_dir/str(episode_index).zfill(6) 94 | episode_dir.mkdir(exist_ok=True) 95 | 96 | # save episode length and language instruction 97 | with h5py.File(f'{episode_dir}/meta_info.h5', 'w') as h5_file: 98 | h5_file.create_dataset(name='length', data=num_steps) 99 | 100 | steps_dir = episode_dir/'steps' 101 | 
steps_dir.mkdir(exist_ok=True) 102 | for step_index in range(num_steps): 103 | step_dir = episode_dir/'steps'/str(step_index).zfill(4) 104 | step_dir.mkdir(exist_ok=True) 105 | 106 | with h5py.File(f'{step_dir}/other.h5', 'w') as h5_file: 107 | # language instruction 108 | h5_file.create_dataset('language_instruction', data=np.array(language_instructions, dtype=h5py.string_dtype(encoding='utf-8'))) 109 | # task emb 110 | # h5_file.create_dataset(name='task_emb', data=task_emb) 111 | 112 | # episode length 113 | h5_file.create_dataset(name='episode_length', data=num_steps) 114 | 115 | # action 116 | h5_file.create_dataset(name='action', data=action[step_index]) 117 | 118 | # observation (timestep, proprio, image_XXX) 119 | observation_group = h5_file.create_group(name='observation') 120 | 121 | ## image 122 | # ### image_primary 123 | Image.fromarray(obs[step_index]).save(f'{step_dir}/image_primary.jpg') 124 | ### image_wrist 125 | Image.fromarray(obs_wrist[step_index]).save(f'{step_dir}/image_wrist.jpg') 126 | 127 | ## proprio 128 | observation_group.create_dataset(name='proprio', data=joint_state[step_index]) 129 | 130 | ## tcp_pose 131 | observation_group.create_dataset(name='tcp_pose', data=ee_state[step_index]) 132 | 133 | ## gripper state (-1 or 1) 134 | observation_group.create_dataset(name='gripper_state', data=gripper_state[step_index]) 135 | 136 | ## gripper position (n, 2) 137 | observation_group.create_dataset(name='gripper_position', data=gripper_position[step_index]) 138 | 139 | def convert_origin_dataset_to_target(self, dataset_by_task): 140 | # /dataset_0 141 | # |_meta_info.h5 142 | # |_/episodes 143 | # | |_/0 144 | # | | |_/steps 145 | # | | |_/0 146 | # | | |_other.h5 147 | # | | |_XXX.jpg 148 | # | |... 149 | # | |_/1 150 | # | |_... 
151 | # /dataset_1 152 | # | 153 | episodes_dir = self.tgt_dir/'episodes' 154 | episodes_dir.mkdir(exist_ok=True) 155 | 156 | num_episodes = 0 157 | for dataset in dataset_by_task: 158 | num_episodes += dataset['num_episode'] 159 | 160 | if self.rank == 0: 161 | with h5py.File(f'{str(self.tgt_dir)}/meta_info.h5', 'w') as h5_file: 162 | h5_file.create_dataset(name='num_episodes', data=num_episodes) 163 | 164 | processed_task_num_episode = 0 165 | task_index = 0 166 | 167 | for episode_index in range(num_episodes): 168 | episode_index_in_task = episode_index - processed_task_num_episode 169 | 170 | if episode_index < self.start_episode_idx: 171 | continue 172 | if self.end_episode_idx is not None: 173 | if episode_index >= self.end_episode_idx: 174 | break 175 | if episode_index % self.num_worker != self.rank: 176 | if episode_index_in_task+1 == dataset_by_task[task_index]['num_episode']: 177 | processed_task_num_episode += dataset_by_task[task_index]['num_episode'] 178 | task_index += 1 179 | continue 180 | print(self.rank, episode_index, '/' , num_episodes) 181 | self.process_episode(episode_dir=episodes_dir, language_instructions=dataset_by_task[task_index]['language'], demo_data=dataset_by_task[task_index]['data'], episode_index=episode_index, episode_index_in_task=episode_index_in_task) 182 | 183 | if episode_index_in_task+1 == dataset_by_task[task_index]['num_episode']: 184 | processed_task_num_episode += dataset_by_task[task_index]['num_episode'] 185 | task_index += 1 186 | 187 | def run(self): 188 | print(f'target dir: {self.tgt_dir}') 189 | 190 | dataset_by_task = [] 191 | for path in list(Path(self.src_dir).iterdir()): 192 | path_name = str(path) 193 | task_name, task_language = extract_task_information(path_name, self.src_dir) 194 | demo_data = h5py.File(path_name, 'r')['data'] 195 | num_episode = len(demo_data) 196 | dataset = { 197 | 'language': task_language, 198 | 'num_episode': num_episode, 199 | 'data': demo_data 200 | } 201 | dataset_by_task.append(dataset) 202 | 203 | self.convert_origin_dataset_to_target(dataset_by_task) 204 | 205 | print(f'data saved at {self.tgt_dir}') 206 | 207 | # get data_info.json 208 | data_info = [] 209 | episode_idx = 0 210 | total_step = 0 211 | for path in list(Path(self.src_dir).iterdir()): 212 | path_name = str(path) 213 | demo_data = h5py.File(path_name, 'r')['data'] 214 | num_episode = len(demo_data) 215 | for i in range(num_episode): 216 | num_steps = np.array(demo_data['demo_{}'.format(i)]['obs']['agentview_rgb']).shape[0] 217 | data_info.append([str(episode_idx).zfill(6), num_steps]) 218 | episode_idx += 1 219 | total_step += num_steps 220 | # print(total_step) 221 | with open(f'./data_info/{dataset_name}_converted.json', 'w') as f: 222 | json.dump(data_info, f) 223 | 224 | def main(rank, port, num_worker, start_episode_idx=0, end_episode_idx=None): 225 | if num_worker > 1: 226 | setup(rank, world_size=num_worker, port=port) 227 | 228 | global dataset_name 229 | dataset_name = "libero_90" # "libero_10" 230 | src_dir = f"/fs-computility/efm/shared/datasets/Banana/tianyang/Data/{dataset_name}" 231 | tgt_dir = Path(f"/fs-computility/efm/shared/datasets/Banana/tianyang/Data/{dataset_name}_converted") 232 | tgt_dir.mkdir(exist_ok=True) 233 | 234 | dataset_converter = DatasetConverter( 235 | src_dir=src_dir, 236 | tgt_dir=tgt_dir, 237 | rank=rank, 238 | num_worker=num_worker, 239 | start_episode_idx=start_episode_idx, # the dataset[start_episode_idx] will be processed 240 | end_episode_idx=end_episode_idx, # None means the last episode. 
if not none, the dataset[end_episode_idx - 1] will be processed and the dataset[end_episode_idx] will not be processed 241 | ) 242 | dataset_converter.run() 243 | 244 | if __name__ == '__main__': 245 | start_episode_idx = 0 246 | end_episode_idx = None 247 | num_worker = 8 248 | port = (random.randint(0, 3000) % 3000) + 27000 249 | 250 | assert num_worker > 1 251 | mp.spawn(main, args=(port, num_worker, start_episode_idx, end_episode_idx), nprocs=num_worker, join=True) 252 | 253 | # main(0, port, 1, start_episode_idx=start_episode_idx, end_episode_idx=end_episode_idx) 254 | -------------------------------------------------------------------------------- /models/vit_mae.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from timm.models.vision_transformer import PatchEmbed, Block 7 | 8 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 9 | """ 10 | grid_size: int of the grid height and width 11 | return: 12 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 13 | """ 14 | grid_h = np.arange(grid_size, dtype=np.float32) 15 | grid_w = np.arange(grid_size, dtype=np.float32) 16 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 17 | grid = np.stack(grid, axis=0) 18 | 19 | grid = grid.reshape([2, 1, grid_size, grid_size]) 20 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 21 | if cls_token: 22 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 23 | return pos_embed 24 | 25 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 26 | assert embed_dim % 2 == 0 27 | 28 | # use half of dimensions to encode grid_h 29 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 30 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 31 | 32 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 33 | return emb 34 | 35 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 36 | """ 37 | embed_dim: output dimension for each position 38 | pos: a list of positions to be encoded: size (M,) 39 | out: (M, D) 40 | """ 41 | assert embed_dim % 2 == 0 42 | omega = np.arange(embed_dim // 2, dtype=np.float32) 43 | omega /= embed_dim / 2. 44 | omega = 1. 
/ 10000**omega # (D/2,) 45 | 46 | pos = pos.reshape(-1) # (M,) 47 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 48 | 49 | emb_sin = np.sin(out) # (M, D/2) 50 | emb_cos = np.cos(out) # (M, D/2) 51 | 52 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 53 | return emb 54 | 55 | class MaskedAutoencoderViT(nn.Module): 56 | """ Masked Autoencoder with VisionTransformer backbone 57 | """ 58 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 59 | embed_dim=1024, depth=24, num_heads=16, 60 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 61 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 62 | super().__init__() 63 | 64 | # -------------------------------------------------------------------------- 65 | # MAE encoder specifics 66 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 67 | # print("path_embed.device: ", self.patch_embed.device) 68 | num_patches = self.patch_embed.num_patches 69 | 70 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 71 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 72 | 73 | self.blocks = nn.ModuleList([ 74 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) # qk_scale=None, 75 | for i in range(depth)]) 76 | self.norm = norm_layer(embed_dim) 77 | # -------------------------------------------------------------------------- 78 | 79 | # -------------------------------------------------------------------------- 80 | # MAE decoder specifics 81 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 82 | 83 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 84 | 85 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 86 | 87 | self.decoder_blocks = nn.ModuleList([ 88 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) # qk_scale=None, 89 | for i in range(decoder_depth)]) 90 | 91 | self.decoder_norm = norm_layer(decoder_embed_dim) 92 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 93 | # -------------------------------------------------------------------------- 94 | 95 | self.norm_pix_loss = norm_pix_loss 96 | 97 | self.initialize_weights() 98 | 99 | def initialize_weights(self): 100 | # initialization 101 | # initialize (and freeze) pos_embed by sin-cos embedding 102 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 103 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 104 | 105 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 106 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 107 | 108 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 109 | w = self.patch_embed.proj.weight.data 110 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 111 | 112 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
113 | torch.nn.init.normal_(self.cls_token, std=.02) 114 | torch.nn.init.normal_(self.mask_token, std=.02) 115 | 116 | # initialize nn.Linear and nn.LayerNorm 117 | self.apply(self._init_weights) 118 | 119 | def _init_weights(self, m): 120 | if isinstance(m, nn.Linear): 121 | # we use xavier_uniform following official JAX ViT: 122 | torch.nn.init.xavier_uniform_(m.weight) 123 | if isinstance(m, nn.Linear) and m.bias is not None: 124 | nn.init.constant_(m.bias, 0) 125 | elif isinstance(m, nn.LayerNorm): 126 | nn.init.constant_(m.bias, 0) 127 | nn.init.constant_(m.weight, 1.0) 128 | 129 | def patchify(self, imgs): 130 | """ 131 | imgs: (N, 3, H, W) 132 | x: (N, L, patch_size**2 *3) 133 | """ 134 | p = self.patch_embed.patch_size[0] 135 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 136 | 137 | h = w = imgs.shape[2] // p 138 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 139 | x = torch.einsum('nchpwq->nhwpqc', x) 140 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 141 | return x 142 | 143 | def unpatchify(self, x): 144 | """ 145 | x: (N, L, patch_size**2 *3) 146 | imgs: (N, 3, H, W) 147 | """ 148 | p = self.patch_embed.patch_size[0] 149 | h = w = int(x.shape[1]**.5) 150 | assert h * w == x.shape[1] 151 | 152 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 153 | x = torch.einsum('nhwpqc->nchpwq', x) 154 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 155 | return imgs 156 | 157 | def random_masking(self, x, mask_ratio): 158 | """ 159 | Perform per-sample random masking by per-sample shuffling. 160 | Per-sample shuffling is done by argsort random noise. 161 | x: [N, L, D], sequence 162 | """ 163 | N, L, D = x.shape # batch, length, dim 164 | len_keep = int(L * (1 - mask_ratio)) 165 | 166 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 167 | 168 | # sort noise for each sample 169 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 170 | ids_restore = torch.argsort(ids_shuffle, dim=1) 171 | 172 | # keep the first subset 173 | ids_keep = ids_shuffle[:, :len_keep] 174 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 175 | 176 | # generate the binary mask: 0 is keep, 1 is remove 177 | mask = torch.ones([N, L], device=x.device) 178 | mask[:, :len_keep] = 0 179 | # unshuffle to get the binary mask 180 | mask = torch.gather(mask, dim=1, index=ids_restore) 181 | 182 | return x_masked, mask, ids_restore 183 | 184 | def forward_encoder(self, x, mask_ratio): 185 | # embed patches 186 | # set_trace() 187 | # print("patch_embed cuda: ", next(self.patch_embed.parameters()).is_cuda) 188 | x = self.patch_embed(x) 189 | 190 | # add pos embed w/o cls token 191 | x = x + self.pos_embed[:, 1:, :] 192 | 193 | # masking: length -> length * mask_ratio 194 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 195 | 196 | # append cls token 197 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 198 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 199 | x = torch.cat((cls_tokens, x), dim=1) 200 | 201 | # apply Transformer blocks 202 | for blk in self.blocks: 203 | x = blk(x) 204 | x = self.norm(x) 205 | 206 | return x, mask, ids_restore 207 | 208 | def forward_decoder(self, x, ids_restore): 209 | # embed tokens 210 | x = self.decoder_embed(x) 211 | 212 | # append mask tokens to sequence 213 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 214 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 215 | x_ = torch.gather(x_, 
dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 216 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 217 | 218 | # add pos embed 219 | x = x + self.decoder_pos_embed 220 | 221 | # apply Transformer blocks 222 | for blk in self.decoder_blocks: 223 | x = blk(x) 224 | x = self.decoder_norm(x) 225 | 226 | # predictor projection 227 | x = self.decoder_pred(x) 228 | 229 | # remove cls token 230 | x = x[:, 1:, :] 231 | 232 | return x 233 | 234 | def forward_loss(self, imgs, pred, mask): 235 | """ 236 | imgs: [N, 3, H, W] 237 | pred: [N, L, p*p*3] 238 | mask: [N, L], 0 is keep, 1 is remove, 239 | """ 240 | target = self.patchify(imgs) 241 | if self.norm_pix_loss: 242 | mean = target.mean(dim=-1, keepdim=True) 243 | var = target.var(dim=-1, keepdim=True) 244 | target = (target - mean) / (var + 1.e-6)**.5 245 | 246 | loss = (pred - target) ** 2 247 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 248 | 249 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 250 | return loss 251 | 252 | def forward(self, imgs, mask_ratio=0.75): 253 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 254 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 255 | loss = self.forward_loss(imgs, pred, mask) 256 | return loss, pred, mask -------------------------------------------------------------------------------- /utils/arguments_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import glob 4 | import os 5 | import random 6 | from collections import OrderedDict 7 | import numpy as np 8 | import yaml 9 | import torch 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.distributed.elastic.multiprocessing.errors import record 12 | 13 | 14 | def random_seed(seed=42, rank=0): 15 | torch.manual_seed(seed + rank) 16 | np.random.seed(seed + rank) 17 | random.seed(seed + rank) 18 | 19 | def world_info_from_env(): 20 | local_rank = 0 21 | for v in ( 22 | "LOCAL_RANK", 23 | "MPI_LOCALRANKID", 24 | "SLURM_LOCALID", 25 | "OMPI_COMM_WORLD_LOCAL_RANK", 26 | ): 27 | if v in os.environ: 28 | local_rank = int(os.environ[v]) 29 | break 30 | global_rank = 0 31 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 32 | if v in os.environ: 33 | global_rank = int(os.environ[v]) 34 | break 35 | world_size = 1 36 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 37 | if v in os.environ: 38 | world_size = int(os.environ[v]) 39 | break 40 | 41 | return local_rank, global_rank, world_size 42 | 43 | def get_parser(is_eval=False): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument( 46 | "--run_name", 47 | type=str, 48 | default="RobotFlamingo", 49 | help="used to name saving directory and wandb run", 50 | ) 51 | parser.add_argument("--offline", action="store_true") 52 | parser.add_argument("--num_epochs", type=int, default=1) 53 | # Sum of gradient optimization batch size 54 | parser.add_argument("--batch_size", type=int, default=1) 55 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 56 | parser.add_argument( 57 | "--resume_from_checkpoint", 58 | type=str, 59 | help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", 60 | default=None, 61 | ) 62 | parser.add_argument( 63 | "--delete_previous_checkpoint", 64 | action="store_true", 65 | help="delete previous checkpoint when saving new checkpoint", 66 | ) 
67 | parser.add_argument("--seed", type=int, default=42) 68 | parser.add_argument("--learning_rate", default=1e-4, type=float) # 1e-4 69 | parser.add_argument( 70 | "--lr_scheduler", 71 | default="constant", 72 | type=str, 73 | help="constant, linear, or cosine", 74 | ) 75 | parser.add_argument( 76 | "--calvin_dataset", 77 | type=str, 78 | default='/mnt/petrelfs/share_data/robomani/calvin_data/task_ABCD_D', 79 | help="path to calvin_dataset", 80 | ) 81 | parser.add_argument("--warmup_epochs", default=1, type=int) 82 | parser.add_argument("--local-rank", default=0, type=int) 83 | parser.add_argument("--weight_decay", default=0.1, type=float) 84 | # hot fix for torch.distributed.launch 85 | parser.add_argument( 86 | "--precision", 87 | choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32", "bf16_and_fp32"], 88 | default="fp32", 89 | help="Floating point precision.", 90 | ) 91 | # data args 92 | parser.add_argument("--workers", type=int, default=16) 93 | # distributed training args 94 | parser.add_argument( 95 | "--dist-url", 96 | default="env://", 97 | type=str, 98 | help="url used to set up distributed training", 99 | ) 100 | parser.add_argument( 101 | "--dist-backend", default="nccl", type=str, help="distributed backend" 102 | ) 103 | parser.add_argument( 104 | "--no-set-device-rank", 105 | default=False, 106 | action="store_true", 107 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", 108 | ) 109 | # wandb args 110 | parser.add_argument("--report_to_wandb", default=False, action="store_true") 111 | parser.add_argument( 112 | "--wandb_project", 113 | type=str, 114 | ) 115 | parser.add_argument( 116 | "--wandb_entity", 117 | type=str, 118 | ) 119 | parser.add_argument( 120 | "--save_checkpoints_to_wandb", 121 | default=False, 122 | action="store_true", 123 | help="save checkpoints to wandb", 124 | ) 125 | parser.add_argument('--rgb_pad', type=int, default=-1) 126 | parser.add_argument('--gripper_pad', type=int, default=-1) 127 | parser.add_argument( 128 | "--traj_cons", 129 | default=False, 130 | action="store_true" 131 | ) 132 | parser.add_argument( 133 | "--text_aug", 134 | default=False, 135 | action="store_true" 136 | ) 137 | parser.add_argument( 138 | "--residual", 139 | default=False, 140 | action="store_true" 141 | ) 142 | parser.add_argument( 143 | "--dif_ws", 144 | default=False, 145 | action="store_true" 146 | ) 147 | parser.add_argument( 148 | "--partial_data", 149 | default=False, 150 | action="store_true" 151 | ) 152 | # data 153 | parser.add_argument("--save_every_iter", type=int, default=-1) 154 | parser.add_argument("--min_window_size", type=int, default=12) 155 | parser.add_argument("--max_window_size", type=int, default=24) 156 | parser.add_argument("--multi_step_action", type=int, default=1, help="multiple step action prediction") 157 | # ceph 158 | parser.add_argument("--data_in_ceph",default=False, action="store_true") 159 | # oxe 160 | parser.add_argument("--root_dir", type=str, default="s3://real_data") 161 | parser.add_argument("--image_primary_size", type=int, default=200) 162 | parser.add_argument("--image_wrist_size", type=int, default=84) 163 | parser.add_argument("--finetune_type", type=str, default="",) 164 | # save checkpoint 165 | parser.add_argument("--start_save_checkpoint", default=-1, type=int) 166 | parser.add_argument("--save_checkpoint", default=False, action="store_true") 167 | parser.add_argument("--save_checkpoint_path", required=True, type=str) 168 | 
parser.add_argument("--save_checkpoint_seq", type=int, default=1) 169 | # if validate 170 | parser.add_argument("--validation", default=False, action="store_true") 171 | # bf16 module 172 | parser.add_argument("--bf16_module", type=str, default="") 173 | # model structure 174 | parser.add_argument("--sequence_length", type=int, default=10) 175 | # for image prediction 176 | parser.add_argument("--future_steps", type=int, default=3) 177 | parser.add_argument("--num_resampler_query", type=int, default=9) 178 | parser.add_argument("--num_obs_token_per_image", type=int, default=9) 179 | parser.add_argument("--calvin_input_image_size", type=int, default=224) 180 | parser.add_argument("--patch_size", type=int, default=16) 181 | # droid 182 | parser.add_argument("--primary_mode", type=str, default="image_primary") 183 | parser.add_argument("--small_size", type=int, default=0) 184 | parser.add_argument("--dataset_info", type=str, default="droid_success") 185 | # pretrain 186 | parser.add_argument("--finetune_from_pretrained_ckpt", type=str, default=None) 187 | # loss 188 | parser.add_argument("--loss_arm_action_ratio", type=float, default=1.0) 189 | parser.add_argument("--loss_gripper_action_ratio", type=float, default=0.01) 190 | # action_pred_steps 191 | parser.add_argument("--action_pred_steps", type=int, default=1) 192 | # obs_pred 193 | parser.add_argument("--obs_pred", default=False, action="store_true") 194 | parser.add_argument("--atten_only_obs", default=False, action="store_true") 195 | parser.add_argument("--attn_robot_proprio_state", default=False, action="store_true") 196 | parser.add_argument("--atten_goal", default=0, type=int) 197 | parser.add_argument("--atten_goal_state", default=False, action="store_true") 198 | # action mask ratio 199 | parser.add_argument("--mask_l_obs_ratio", default=0.00, type=float) 200 | # reset during finetuning 201 | parser.add_argument("--reset_action_token", default=False, action="store_true") 202 | parser.add_argument("--reset_obs_token", default=False, action="store_true") 203 | parser.add_argument("--reset_mask_token", default=False, action="store_true") 204 | parser.add_argument("--reset_image_decoder", default=False, action="store_true") 205 | parser.add_argument("--reset_action_decoder", default=False, action="store_true") 206 | # loss 207 | parser.add_argument("--loss_action", default=False, action="store_true") 208 | parser.add_argument("--loss_image", default=False, action="store_true") 209 | 210 | # calvin 211 | parser.add_argument("--except_lang", default=False, action="store_true") 212 | # gpt2 213 | parser.add_argument("--transformer_layers", default=12, type=int) 214 | parser.add_argument("--hidden_dim", default=384, type=int) 215 | parser.add_argument("--transformer_heads", default=12, type=int) 216 | # pretrain, finetune, evaluate 217 | parser.add_argument('--phase', required=True, help='pretrain, finetune, evaluate') 218 | # libero 219 | parser.add_argument("--libero_path", default="/ailab/user/tianyang/Code/LIBERO") 220 | parser.add_argument("--libero_img_size", default=128, type=int) 221 | parser.add_argument("--libero_eval_max_steps", default=600, type=int) 222 | parser.add_argument("--gripper_width", default=False, action="store_true") 223 | parser.add_argument("--load_libero_file", type=str, default="h5") 224 | parser.add_argument("--eval_libero_ensembling", default=False, action="store_true") 225 | parser.add_argument("--ensembling_temp", default=0.01, type=float) 226 | # real 227 | parser.add_argument("--real_dataset_names", 
type=str) 228 | parser.add_argument("--use_aug_data", default=False, action="store_true") 229 | parser.add_argument("--real_eval_max_steps", default=600, type=int) 230 | # preprocess 231 | parser.add_argument("--max_rel_pos", type=float, default=0.02) 232 | parser.add_argument("--max_rel_orn", type=float, default=0.05) 233 | parser.add_argument("--magic_scaling_factor_pos", type=float, default=1.0) 234 | parser.add_argument("--magic_scaling_factor_orn", type=float, default=1.0) 235 | # for eval 236 | if is_eval: 237 | parser.add_argument("--calvin_conf_path", type=str, help="path to calvin configuration file") 238 | parser.add_argument("--future_act_len", default=-1, type=int) 239 | parser.add_argument( 240 | "--visualize", 241 | default=False, 242 | action="store_true" 243 | ) 244 | parser.add_argument( 245 | "--reset", 246 | default=False, 247 | action="store_true" 248 | ) 249 | parser.add_argument( 250 | "--diverse_inst", 251 | default=False, 252 | action="store_true" 253 | ) 254 | parser.add_argument("--pad_length", type=int, default=-1) 255 | parser.add_argument("--window_size", type=int, default=13) 256 | parser.add_argument("--vit_checkpoint_path", type=str) 257 | args = parser.parse_args() 258 | 259 | return parser 260 | 261 | # if args.dataloading_type == "seer": 262 | # if args.phase == "pretrain": 263 | # if args.finetune_type == "calvin": 264 | # args.window_size = args.sequence_length + args.future_steps 265 | # else: 266 | # args.window_size = args.sequence_length 267 | # elif args.phase == "finetune": 268 | # args.window_size = args.sequence_length + args.future_steps -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import wandb 8 | import clip 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed.elastic.multiprocessing.errors import record 11 | from transformers import ( 12 | get_constant_schedule_with_warmup, 13 | get_cosine_schedule_with_warmup, 14 | get_linear_schedule_with_warmup, 15 | ) 16 | from models.seer_model import SeerAgent 17 | from utils.train_utils import get_checkpoint, train_one_epoch_calvin, get_ckpt_name 18 | from utils.arguments_utils import get_parser 19 | from utils.data_utils import get_calvin_dataset, get_calvin_val_dataset, get_droid_dataset, get_libero_pretrain_dataset, get_libero_finetune_dataset, get_real_finetune_dataset, get_oxe_dataset 20 | from utils.distributed_utils import init_distributed_device, world_info_from_env 21 | 22 | 23 | def random_seed(seed=42, rank=0): 24 | torch.manual_seed(seed + rank) 25 | np.random.seed(seed + rank) 26 | random.seed(seed + rank) 27 | 28 | def count_parameters(model): 29 | total_params = 0 30 | trainable_params = 0 31 | for param in model.parameters(): 32 | total_params += param.numel() 33 | if param.requires_grad: 34 | trainable_params += param.numel() 35 | return total_params, trainable_params 36 | 37 | @record 38 | def main(args): 39 | os.environ["WANDB_DIR"] = f"{os.path.abspath(args.save_checkpoint_path)}" 40 | if args.save_checkpoints_to_wandb and args.save_checkpoint and not args.report_to_wandb: 41 | raise ValueError("save_checkpoints_to_wandb requires report_to_wandb") 42 | if args.offline: 43 | os.environ["WANDB_MODE"] = "offline" 44 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 45 | args.local_rank, args.rank, args.world_size = world_info_from_env() 46 | device_id = init_distributed_device(args) 47 | print("device_id: ", device_id) 48 | random_seed(args.seed) 49 | ptbs = args.world_size * args.batch_size * args.gradient_accumulation_steps 50 | print("training batch size:", ptbs) 51 | args.run_name = args.run_name.replace("Seer", f"Seer_ptbs{ptbs}_{args.transformer_layers}layers_{args.transformer_heads}heads_hd{args.hidden_dim}") 52 | print("run_name:", args.run_name) 53 | model = SeerAgent( 54 | finetune_type=args.finetune_type, 55 | clip_device=device_id, 56 | vit_checkpoint_path=args.vit_checkpoint_path, 57 | sequence_length=args.sequence_length, 58 | num_resampler_query=args.num_resampler_query, 59 | num_obs_token_per_image=args.num_obs_token_per_image, 60 | calvin_input_image_size=args.calvin_input_image_size, 61 | patch_size=args.patch_size, 62 | action_pred_steps=args.action_pred_steps, 63 | obs_pred=args.obs_pred, 64 | atten_only_obs=args.atten_only_obs, 65 | attn_robot_proprio_state=args.attn_robot_proprio_state, 66 | atten_goal=args.atten_goal, 67 | atten_goal_state=args.atten_goal_state, 68 | mask_l_obs_ratio=args.mask_l_obs_ratio, 69 | transformer_layers=args.transformer_layers, 70 | hidden_dim=args.hidden_dim, 71 | transformer_heads=args.transformer_heads, 72 | phase=args.phase, 73 | gripper_width=args.gripper_width, 74 | ) 75 | if args.finetune_type == "calvin": 76 | calvin_dataset = get_calvin_dataset(args, model.image_processor, clip, epoch=0, except_lang=args.except_lang) 77 | elif args.finetune_type == "droid": 78 | calvin_dataset = get_droid_dataset(args, 
model.image_processor, clip, epoch=0) 79 | elif args.finetune_type == "libero_pretrain": 80 | calvin_dataset = get_libero_pretrain_dataset(args, model.image_processor, clip, epoch=0) 81 | elif args.finetune_type == "libero_finetune": 82 | calvin_dataset = get_libero_finetune_dataset(args, model.image_processor, clip, epoch=0) 83 | elif args.finetune_type == "real": 84 | calvin_dataset = get_real_finetune_dataset(args, model.image_processor, clip, epoch=0) 85 | elif args.finetune_type == "oxe": 86 | calvin_dataset = get_oxe_dataset(args, model.image_processor, clip, epoch=0) 87 | random_seed(args.seed, args.rank) 88 | print(f"Start running training on rank {args.rank}.") 89 | if args.rank == 0 and args.report_to_wandb: 90 | print("wandb_project :", args.wandb_project) 91 | print("wandb_entity :", args.wandb_entity) 92 | wandb.init( 93 | project=args.wandb_project, 94 | entity=args.wandb_entity, 95 | name=args.run_name, 96 | config=vars(args), 97 | ) 98 | device_id = args.rank % torch.cuda.device_count() 99 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16": 100 | model = model.bfloat16() 101 | elif args.precision == "fp16": 102 | model = model.half() 103 | elif args.precision == "fp32": 104 | model = model.float() 105 | if 'vision_encoder' in args.bf16_module: 106 | model.vision_encoder.bfloat16() 107 | if "causal_transformer" in args.bf16_module: 108 | model.transformer_backbone.bfloat16() 109 | if "image_decoder" in args.bf16_module: 110 | model.image_decoder.bfloat16() 111 | model.image_decoder_obs_pred_projector.bfloat16() 112 | model.clip_model.requires_grad_(False) 113 | model.vision_encoder.requires_grad_(False) 114 | total_params, trainable_params = count_parameters(model) 115 | print("total_params: {} M".format(total_params/1024/1024)) 116 | print("trainable_params: {} M".format(trainable_params/1024/1024)) 117 | model = model.to(device_id) 118 | model._init_model_type() 119 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True) 120 | optimizer = torch.optim.AdamW([p for p in ddp_model.parameters() if p.requires_grad], lr=args.learning_rate, weight_decay=args.weight_decay) # TODO make sure the parameters which need to be optimized are passing 121 | total_training_steps = calvin_dataset.dataloader.num_batches * args.num_epochs 122 | args.warmup_steps = calvin_dataset.dataloader.num_batches * args.warmup_epochs 123 | if args.rank == 0: 124 | print(f"Total training steps: {total_training_steps}") 125 | if args.lr_scheduler == "linear": 126 | if args.gradient_accumulation_steps > 1: 127 | lr_scheduler = get_linear_schedule_with_warmup( 128 | optimizer, 129 | num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps + 1, 130 | num_training_steps=total_training_steps // args.gradient_accumulation_steps + 1, 131 | ) 132 | else: 133 | lr_scheduler = get_linear_schedule_with_warmup( 134 | optimizer, 135 | num_warmup_steps=args.warmup_steps, 136 | num_training_steps=total_training_steps, 137 | ) 138 | elif args.lr_scheduler == "cosine": 139 | if args.gradient_accumulation_steps > 1: 140 | lr_scheduler = get_cosine_schedule_with_warmup( 141 | optimizer, 142 | num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps + 1, 143 | num_training_steps=total_training_steps // args.gradient_accumulation_steps + 1, 144 | ) 145 | else: 146 | lr_scheduler = get_cosine_schedule_with_warmup( 147 | optimizer, 148 | num_warmup_steps=args.warmup_steps, 149 | num_training_steps=total_training_steps, 150 | ) 151 | 
elif args.lr_scheduler == 'cosine_restart': 152 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-7) 153 | else: 154 | lr_scheduler = get_constant_schedule_with_warmup( 155 | optimizer, num_warmup_steps=args.warmup_steps 156 | ) 157 | resume_from_epoch = 0 158 | if args.finetune_from_pretrained_ckpt is not None: 159 | if args.rank == 0: 160 | print(f"Starting finetuning from pretrained checkpoint {args.finetune_from_pretrained_ckpt}") 161 | checkpoint = torch.load(args.finetune_from_pretrained_ckpt, map_location="cpu") 162 | image_decoder_keys = [k for k in checkpoint["model_state_dict"].keys() if "image_decoder" in k] 163 | projector_keys = [k for k in checkpoint["model_state_dict"].keys() if "projector" in k] 164 | action_decoder_keys = [k for k in checkpoint["model_state_dict"].keys() if "action_decoder" in k] 165 | if args.reset_action_token: 166 | del checkpoint["model_state_dict"]["module.action_pred_token"] 167 | if args.reset_obs_token: 168 | del checkpoint["model_state_dict"]["module.obs_tokens"] 169 | if args.reset_mask_token: 170 | del checkpoint["model_state_dict"]["module.mask_token"] 171 | if args.reset_image_decoder: 172 | for k in image_decoder_keys: 173 | if k in checkpoint["model_state_dict"]: 174 | del checkpoint["model_state_dict"][k] 175 | if args.reset_action_decoder: 176 | for k in action_decoder_keys: 177 | if k in checkpoint["model_state_dict"]: 178 | del checkpoint["model_state_dict"][k] 179 | if checkpoint["model_state_dict"]["module.transformer_backbone_position_embedding"].shape != ddp_model.module.transformer_backbone_position_embedding.shape: 180 | checkpoint["model_state_dict"]["module.transformer_backbone_position_embedding"] = checkpoint["model_state_dict"]["module.transformer_backbone_position_embedding"][:, :args.sequence_length, :, :] 181 | print("loading pretrained weights :", checkpoint["model_state_dict"].keys()) 182 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False) 183 | if args.resume_from_checkpoint is not None: 184 | if args.rank == 0: 185 | print(f"Loading checkpoint from {args.resume_from_checkpoint}") 186 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") 187 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False) 188 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 189 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) 190 | resume_from_epoch = checkpoint["epoch"] + 1 191 | 192 | ckpt_dir = os.path.join(f"{args.save_checkpoint_path}", args.run_name) 193 | if args.rank == 0 and not os.path.exists(ckpt_dir): 194 | os.makedirs(ckpt_dir) 195 | 196 | ddp_model.train() 197 | for epoch in range(resume_from_epoch, args.num_epochs): 198 | calvin_dataset.set_epoch(epoch) 199 | calvin_loader = calvin_dataset.dataloader 200 | train_one_epoch_calvin( 201 | args=args, 202 | model=ddp_model, 203 | epoch=epoch, 204 | optimizer=optimizer, 205 | lr_scheduler=lr_scheduler, 206 | calvin_loader=calvin_loader, 207 | device_id=device_id, 208 | wandb=wandb, 209 | ) 210 | if args.rank == 0 and args.save_checkpoint and epoch % args.save_checkpoint_seq == 0 and epoch > args.start_save_checkpoint: 211 | checkpoint_dict = { 212 | "epoch": epoch, 213 | "model_state_dict": get_checkpoint(ddp_model), 214 | "optimizer_state_dict": optimizer.state_dict(), 215 | "lr_scheduler_state_dict": lr_scheduler.state_dict(), 216 | } 217 | ckpt_name = get_ckpt_name(args, epoch) 218 | ckpt_path = os.path.join(ckpt_dir, ckpt_name) 219 | 
print(f"Saving checkpoint to {ckpt_path}") 220 | torch.save(checkpoint_dict, ckpt_path) 221 | if args.delete_previous_checkpoint: 222 | if epoch > 0: 223 | os.remove(ckpt_path) 224 | 225 | if __name__ == "__main__": 226 | parser = get_parser() 227 | args = parser.parse_args() 228 | main(args) 229 | --------------------------------------------------------------------------------