├── real_preprocess
│   ├── octo_oxe_data_utils
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── goal_relabeling.py
│   │   │   ├── task_augmentation.py
│   │   │   └── text_processing.py
│   │   ├── setup.py
│   │   ├── requirements.txt
│   │   ├── obs_transforms.py
│   │   ├── traj_transforms.py
│   │   └── oxe
│   │       ├── __init__.py
│   │       └── oxe_dataset_mixes.py
│   ├── mujoco_menagerie
│   │   └── robotiq_2f85
│   │       ├── 2f85.png
│   │       ├── assets
│   │       │   ├── base.stl
│   │       │   ├── pad.stl
│   │       │   ├── coupler.stl
│   │       │   ├── driver.stl
│   │       │   ├── follower.stl
│   │       │   ├── base_mount.stl
│   │       │   ├── silicone_pad.stl
│   │       │   └── spring_link.stl
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── scene.xml
│   │       └── 2f85.xml
│   └── requirements.txt
├── data_info
│   ├── ep_lens.npy
│   ├── ep_start_end_ids.npy
│   ├── except_lang_idx
│   │   └── except_lang_idx.npy
│   ├── .hydra
│   │   ├── overrides.yaml
│   │   ├── config.yaml
│   │   ├── hydra.yaml
│   │   └── merged_config.yaml
│   ├── droid_success_languaged_0803_stat_0.1_0.5.json
│   ├── droid_success_full_0803_stat.json
│   ├── droid_success_languaged_0803_stat_0.02_0.05.json
│   ├── droid_failure.json
│   └── austin_buds_dataset_converted_externally_to_rlds.json
├── assets
│   └── seer_method.jpg
├── requirements.txt
├── docs
│   ├── LIBERO_LONG_INSTALL.md
│   ├── REAL-WORLD_INSTALL.md
│   ├── REAL-WORLD_FT_SC.md
│   ├── CALVIN_ABC-D_INSTALL.md
│   ├── REAL-WORLD_PRETRAIN.md
│   ├── REAL-WORLD_INFERENCE.md
│   ├── LIBERO_LONG_RUN.md
│   ├── CALVIN_ABC-D_RUN.md
│   ├── REAL-WORLD_PREPROCESS.md
│   └── REAL-WORLD_POSTPROCESS.md
├── scripts
│   ├── CALVIN_ABC_D
│   │   ├── Seer
│   │   │   ├── scratch.sh
│   │   │   ├── pretrain.sh
│   │   │   ├── finetune.sh
│   │   │   └── eval.sh
│   │   └── Seer-Large
│   │       ├── scratch.sh
│   │       ├── pretrain.sh
│   │       ├── eval.sh
│   │       └── finetune.sh
│   ├── REAL
│   │   ├── slurm_s_oxe_cluster.sh
│   │   ├── deploy.sh
│   │   ├── slurm_s_full_cluster.sh
│   │   ├── slurm_s_language_cluster.sh
│   │   ├── single_node_full_cluster.sh
│   │   ├── single_node_language_cluster.sh
│   │   ├── single_node_scratch.sh
│   │   └── single_node_ft.sh
│   └── LIBERO_LONG
│       └── Seer
│           ├── scratch.sh
│           ├── pretrain.sh
│           ├── eval.sh
│           └── finetune.sh
├── deploy.py
├── .gitignore
├── eval_libero.py
├── models
│   ├── perceiver_resampler.py
│   └── vit_mae.py
├── README.md
├── eval_calvin.py
├── utils
│   ├── distributed_utils.py
│   ├── convert_libero_per_step.py
│   └── arguments_utils.py
├── slurm_train_intern.py
├── LICENSE
└── train.py
/real_preprocess/octo_oxe_data_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data_info/ep_lens.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/data_info/ep_lens.npy
--------------------------------------------------------------------------------
/assets/seer_method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/assets/seer_method.jpg
--------------------------------------------------------------------------------
/data_info/ep_start_end_ids.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/data_info/ep_start_end_ids.npy
--------------------------------------------------------------------------------
/data_info/except_lang_idx/except_lang_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/data_info/except_lang_idx/except_lang_idx.npy
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/2f85.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/2f85.png
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/pad.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/pad.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/coupler.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/coupler.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/driver.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/driver.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/follower.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/follower.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base_mount.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/base_mount.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/silicone_pad.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/silicone_pad.stl
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/spring_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/assets/spring_link.stl
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(
5 | name='octo_oxe_data_utils',
6 | version='1.0.0',
7 | packages=find_packages(),
8 | )
--------------------------------------------------------------------------------
/data_info/.hydra/overrides.yaml:
--------------------------------------------------------------------------------
1 | - load_dir=/work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/
2 | - set_static_cam=false
3 | - processes=16
4 | - +scene=calvin_scene_A
5 | - show_gui=false
6 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow == 2.15.0
2 | tensorflow_datasets == 4.9.2
3 | tensorflow_graphics == 2021.12.3
4 | dlimp @ git+https://github.com/kvablack/dlimp@d08da3852c149548aaa8551186d619d87375df08
5 | imageio
6 | absl-py
7 | mujoco
8 | roboticstoolbox-python[collision]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | flamingo_pytorch
2 | tensorboard
3 | ftfy==6.2.0
4 | regex
5 | tqdm
6 | matplotlib
7 | git+https://github.com/openai/CLIP.git
8 | timm==0.9.16
9 | h5py==3.11.0
10 | transformers==4.40.2
11 | packaging==24.0
12 | setuptools==57.5.0
13 | omegaconf==2.1.2
14 | moviepy==1.0.3
15 | einops_exts
16 | numpy==1.23.1
--------------------------------------------------------------------------------
/docs/LIBERO_LONG_INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | **(1) Conda Env**
4 | ```
5 | conda create -n seer python=3.10
6 | conda activate seer
7 | ```
8 |
9 | **(2) LIBERO Env**
10 | ```
11 | git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git
12 | cd LIBERO
13 | pip install -r requirements.txt
14 | pip install transformers==4.40.2
15 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
16 | pip install -e .
17 | ```
--------------------------------------------------------------------------------
/real_preprocess/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.3
2 | tensorflow_probability==0.23.0
3 | tensorflow==2.15.0
4 | ml_collections>=0.1.0
5 | tqdm>=4.60.0
6 | absl-py>=0.12.0
7 | scipy>=1.6.0
8 | einops>=0.6.1
9 | tensorflow_hub>=0.14.0
10 | tensorflow_text>=2.13.0
11 | tensorflow_datasets==4.9.2
12 | tensorflow_graphics==2021.12.3
13 | mujoco
14 | Pillow
15 | pybullet
16 | dlimp@git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972
17 | roboticstoolbox-python
--------------------------------------------------------------------------------
/data_info/droid_success_languaged_0803_stat_0.1_0.5.json:
--------------------------------------------------------------------------------
1 | {
2 | "droid_success_languaged_0803": {
3 | "mean": [
4 | -0.013946113699228197,
5 | -0.0008144068900243393,
6 | -0.010159969904343499,
7 | 0.029293912297435065,
8 | -0.025690169644747598,
9 | -0.015989328273960466
10 | ],
11 | "std": [
12 | 0.20220592438479257,
13 | 0.21943531464788582,
14 | 0.210883565255364,
15 | 0.17382451705057012,
16 | 0.14346100321749855,
17 | 0.14715284488965125
18 | ]
19 | }
20 | }
--------------------------------------------------------------------------------
/data_info/droid_success_full_0803_stat.json:
--------------------------------------------------------------------------------
1 | {
2 | "droid_success_full_0803": {
3 | "mean": [
4 | 0.034469623780328515,
5 | -0.001633081419506495,
6 | 0.0319105912448654,
7 | -0.0027474652911320127,
8 | 0.007313525898785993,
9 | 0.006798099658541783,
10 | 0.12350288160670808
11 | ],
12 | "std": [
13 | 0.2520571699338127,
14 | 0.2486948654319061,
15 | 0.2937194360082029,
16 | 0.25230853439244694,
17 | 0.2796370525416708,
18 | 0.34477038757871253,
19 | 0.9923442135864566
20 | ]
21 | }
22 | }
--------------------------------------------------------------------------------
/data_info/droid_success_languaged_0803_stat_0.02_0.05.json:
--------------------------------------------------------------------------------
1 | {
2 | "droid_success_languaged_0803": {
3 | "mean": [
4 | 0.034524966522835786,
5 | -0.002034788243149046,
6 | 0.03286368149243478,
7 | -0.002270051002688882,
8 | 0.007490289567743871,
9 | 0.008484724839843296,
10 | 0.1251103956939712
11 | ],
12 | "std": [
13 | 0.25377941149232486,
14 | 0.2485252318164213,
15 | 0.29530677038638276,
16 | 0.2521710628202299,
17 | 0.2795374759449668,
18 | 0.3412635660188041,
19 | 0.9921428268718969
20 | ]
21 | }
22 | }
--------------------------------------------------------------------------------
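The two stat files above hold per-dimension action means and standard deviations. A minimal, hedged sketch of how such statistics are typically consumed — assuming z-score normalization of actions, which is an assumption here; the authoritative handling lives in the Seer training and deployment code — is:

```python
# Hedged sketch: z-score normalization using the mean/std stored in a
# data_info stat JSON. The file path and the normalization scheme are
# assumptions for illustration, not the repo's exact logic.
import json
import numpy as np

with open("data_info/droid_success_full_0803_stat.json") as f:
    stats = json.load(f)["droid_success_full_0803"]
mean, std = np.array(stats["mean"]), np.array(stats["std"])

def normalize(action: np.ndarray) -> np.ndarray:
    """Map a raw action to zero-mean, unit-variance coordinates."""
    return (action - mean) / (std + 1e-8)

def denormalize(action: np.ndarray) -> np.ndarray:
    """Invert the normalization for execution on the robot."""
    return action * (std + 1e-8) + mean
```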
/docs/REAL-WORLD_INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | To set up Seer, we will create an isolated environment called seer. This environment is designed to support pre-training, fine-tuning, and inference workflows.
3 |
4 | ## seer env
5 | **(1) Env**
6 | ```bash
7 | conda create -n seer python=3.10
8 | conda activate seer
9 | ```
10 | **(2) Third Party Packages**
11 | ```bash
12 | cd ${YOUR_PATH_TO_SEER}
13 | pip install -r requirements.txt
14 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/REAL-WORLD_FT_SC.md:
--------------------------------------------------------------------------------
1 | # Quick Training
2 | Preparation
3 | ```bash
4 | cd ${YOUR_PATH_TO_SEER}
5 | conda activate seer
6 | ```
7 | Download related checkpoints from the [checkpoint repository](https://drive.google.com/drive/folders/1rT8JKLhJGIo97jfYUm2JiFUrogOq-dgJ?usp=drive_link).
8 | ## :sparkles: Fine-tuning
9 | * For single-node fine-tuning:
10 | ```bash
11 | bash scripts/REAL/single_node_ft.sh
12 | ```
13 | ## :sparkles: Training from Scratch
14 | * For single-node training from scratch:
15 | ```bash
16 | bash scripts/REAL/single_node_scratch.sh
17 | ```
18 |
--------------------------------------------------------------------------------
/data_info/droid_failure.json:
--------------------------------------------------------------------------------
1 | [
2 | [
3 | "004897",
4 | 141
5 | ],
6 | [
7 | "009466",
8 | 80
9 | ],
10 | [
11 | "021346",
12 | 154
13 | ],
14 | [
15 | "023455",
16 | 412
17 | ],
18 | [
19 | "026167",
20 | 174
21 | ],
22 | [
23 | "026189",
24 | 207
25 | ],
26 | [
27 | "026614",
28 | 214
29 | ],
30 | [
31 | "027372",
32 | 145
33 | ],
34 | [
35 | "031500",
36 | 125
37 | ],
38 | [
39 | "032599",
40 | 180
41 | ],
42 | [
43 | "033248",
44 | 107
45 | ],
46 | [
47 | "035609",
48 | 126
49 | ],
50 | [
51 | "036914",
52 | 148
53 | ],
54 | [
55 | "037104",
56 | 188
57 | ],
58 | [
59 | "037173",
60 | 108
61 | ],
62 | [
63 | "042263",
64 | 226
65 | ],
66 | [
67 | "042479",
68 | 241
69 | ],
70 | [
71 | "045810",
72 | 202
73 | ],
74 | [
75 | "054451",
76 | 381
77 | ],
78 | [
79 | "060790",
80 | 204
81 | ],
82 | [
83 | "061356",
84 | 59
85 | ],
86 | [
87 | "062455",
88 | 225
89 | ],
90 | [
91 | "062856",
92 | 235
93 | ],
94 | [
95 | "081293",
96 | 723
97 | ]
98 | ]
--------------------------------------------------------------------------------
/docs/CALVIN_ABC-D_INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | **(1) Conda Env**
4 | ```
5 | conda create -n seer python=3.10
6 | conda activate seer
7 | ```
8 |
9 | **(2) CALVIN**
10 | > Follow the exact instructions from [CALVIN](https://github.com/mees/calvin).
11 | ```
12 | git clone --recurse-submodules https://github.com/mees/calvin.git
13 | export CALVIN_ROOT=$(pwd)/calvin
14 | cd $CALVIN_ROOT
15 | sh install.sh
16 | ```
17 |
18 | **(3) Dataset Download**
19 | > We only download CALVIN ABC-D.
20 | ```
21 | cd $CALVIN_ROOT/dataset
22 | sh download_data.sh ABC
23 | ```
24 |
25 | **(4) Third Party Packages**
26 | ```
27 | cd ${YOUR_PATH_TO_SEER}
28 | pip install -r requirements.txt
29 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
30 | ```
31 |
32 | **(5) Create a soft link to CALVIN**
33 | ```
34 | cd ${YOUR_PATH_TO_SEER}
35 | ln -s $CALVIN_ROOT calvin
36 | ```
37 |
38 | **(6) Copy the index file `except_lang_idx.npy` to the CALVIN ABC-D training data directory.**
39 | ```python
40 | cp -r data_info/except_lang_idx/except_lang_idx.npy calvin/dataset/task_ABC_D/training
41 | ```
--------------------------------------------------------------------------------
/docs/REAL-WORLD_PRETRAIN.md:
--------------------------------------------------------------------------------
1 | # Pre-train
2 | ## Notice
3 | We provide code for pre-training on both the DROID and OXE datasets. Users should update save_checkpoint_path to the directory where the training checkpoints will be saved, and set root_dir to the location where the preprocessed real data is stored. Additionally, users should configure the SLURM information in the provided scripts.
4 |
5 | Preparation
6 | ```python
7 | cd ${YOUR_PATH_TO_SEER}
8 | conda activate seer
9 | ```
10 |
11 | ## Pre-train (DROID FULL)
12 | * For single-node pre-training:
13 | ```bash
14 | bash scripts/REAL/single_node_full_cluster.sh
15 | ```
16 | * For multi-node pre-training:
17 | ```bash
18 | bash scripts/REAL/slurm_s_full_cluster.sh
19 | ```
20 | ## Pre-train (DROID with Language)
21 | * For single-node pre-training:
22 | ```bash
23 | bash scripts/REAL/single_node_language_cluster.sh
24 | ```
25 | * For multi-node pre-training:
26 | ```bash
27 | bash scripts/REAL/slurm_s_language_cluster.sh
28 | ```
29 | ## Pre-train (OXE)
30 | * For multi-node pre-training:
31 | We should first generate the data info:
32 | ```bash
33 | # use the oxe_dataset_info in Seer/utils/real_ft_data.py
34 | ```
35 | Then we train the model
36 | ```bash
37 | bash scripts/REAL/slurm_s_oxe_cluster.sh
38 | ```
39 |
--------------------------------------------------------------------------------
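The "generate data info" step described above ultimately needs per-dataset action statistics in the same `{"<dataset_name>": {"mean": [...], "std": [...]}}` layout used by the JSON files under `data_info/`. The following is a minimal, hedged sketch of producing such a file; it assumes the dataset's (delta) actions are already gathered into an `(N, D)` array and is not the repository's `real_ft_data.py` implementation.

```python
# Hedged sketch: compute per-dimension action mean/std and write them in the
# same layout as data_info/droid_success_full_0803_stat.json. Function and
# variable names are illustrative, not the repo's own API.
import json
import numpy as np

def write_action_stats(actions: np.ndarray, dataset_name: str, out_path: str) -> None:
    """actions: (N, D) array of actions collected over the whole dataset."""
    stats = {
        dataset_name: {
            "mean": actions.mean(axis=0).tolist(),
            "std": actions.std(axis=0).tolist(),
        }
    }
    with open(out_path, "w") as f:
        json.dump(stats, f, indent=4)

# Hypothetical usage:
# write_action_stats(all_actions, "droid_success_full_0803",
#                    "data_info/droid_success_full_0803_stat.json")
```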
/real_preprocess/octo_oxe_data_utils/utils/goal_relabeling.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required.
3 | Each function should add entries to the "task" dict.
4 | """
5 |
6 | import tensorflow as tf
7 |
8 | from octo_oxe_data_utils.utils.data_utils import tree_merge
9 |
10 |
11 | def uniform(traj: dict) -> dict:
12 | """Relabels with a true uniform distribution over future states."""
13 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0]
14 |
15 | # select a random future index for each transition i in the range [i + 1, traj_len)
16 | rand = tf.random.uniform([traj_len])
17 | low = tf.cast(tf.range(traj_len) + 1, tf.float32)
18 | high = tf.cast(traj_len, tf.float32)
19 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
20 |
21 | # sometimes there are floating-point errors that cause an out-of-bounds index
22 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1)
23 |
24 | # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from
25 | # "observation" and "task" properly)
26 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"])
27 | traj["task"] = tree_merge(traj["task"], goal)
28 |
29 | return traj
30 |
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2013, ROS-Industrial
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification,
5 | are permitted provided that the following conditions are met:
6 |
7 | Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | Redistributions in binary form must reproduce the above copyright notice, this
11 | list of conditions and the following disclaimer in the documentation and/or
12 | other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
18 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
21 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/README.md:
--------------------------------------------------------------------------------
1 | # Robotiq 2F-85 Description (MJCF)
2 |
3 | Requires MuJoCo 2.2.2 or later.
4 |
5 | ## Overview
6 |
7 | This package contains a simplified robot description (MJCF) of the [Robotiq 85mm
8 | 2-Finger Adaptive
9 | Gripper](https://robotiq.com/products/2f85-140-adaptive-robot-gripper) developed
10 | by [Robotiq](https://robotiq.com/). It is derived from the [publicly available
11 | URDF
12 | description](https://github.com/ros-industrial/robotiq/tree/kinetic-devel/robotiq_2f_85_gripper_visualization).
13 |
14 |
15 |
16 |
17 |
18 | ## URDF → MJCF derivation steps
19 |
20 | 1. Added `<mujoco> <compiler discardvisual="false"/> </mujoco>` to the URDF's
21 | `<robot>` clause in order to preserve visual geometries.
22 | 2. Loaded the URDF into MuJoCo and saved a corresponding MJCF.
23 | 3. Manually edited the MJCF to extract common properties into the `<default>` section.
24 | 4. Added `<exclude>` clauses to prevent collisions between the linkage bodies.
25 | 5. Broke up collision pads into two pads for more contacts.
26 | 6. Increased pad friction and priority.
27 | 7. Added `impratio=10` for better noslip.
28 | 8. Added `scene.xml` which includes the robot, with a textured groundplane, skybox, and haze.
29 | 9. Added hanging box to `scene.xml`.
30 |
31 | ## License
32 |
33 | This model is released under a [BSD-2-Clause License](LICENSE).
34 |
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer/scratch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### NEED TO CHANGE ###
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | save_checkpoint_path="checkpoints/"
5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 |
7 | node=8
8 | node_num=8
9 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
10 | --traj_cons \
11 | --rgb_pad 10 \
12 | --gripper_pad 4 \
13 | --gradient_accumulation_steps 1 \
14 | --bf16_module "vision_encoder" \
15 | --vit_checkpoint_path ${vit_checkpoint_path} \
16 | --calvin_dataset ${calvin_dataset_path} \
17 | --workers 8 \
18 | --lr_scheduler cosine \
19 | --save_every_iter 100000 \
20 | --num_epochs 20 \
21 | --seed 42 \
22 | --batch_size 8 \
23 | --precision fp32 \
24 | --learning_rate 1e-3 \
25 | --finetune_type "calvin" \
26 | --wandb_project seer \
27 | --weight_decay 1e-4 \
28 | --num_resampler_query 6 \
29 | --run_name scratch_seer_calvin_abc_d \
30 | --save_checkpoint \
31 | --save_checkpoint_path ${save_checkpoint_path} \
32 | --transformer_layers 24 \
33 | --phase "finetune" \
34 | --action_pred_steps 3 \
35 | --sequence_length 10 \
36 | --future_steps 3 \
37 | --window_size 13 \
38 | --obs_pred \
39 | --loss_image \
40 | --loss_action \
41 | --report_to_wandb \
42 |
--------------------------------------------------------------------------------
/scripts/REAL/slurm_s_oxe_cluster.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="xxx/preprocess"
4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
5 | ### NEED TO CHANGE ###
6 |
7 | python slurm_train_intern.py \
8 | --traj_cons \
9 | --rgb_pad 10 \
10 | --gripper_pad 4 \
11 | --gradient_accumulation_steps 2 \
12 | --bf16_module "vision_encoder" \
13 | --vit_checkpoint_path ${vit_checkpoint_path} \
14 | --calvin_dataset "" \
15 | --workers 8 \
16 | --lr_scheduler cosine \
17 | --save_every_iter 20000 \
18 | --num_epochs 30 \
19 | --seed 42 \
20 | --batch_size 32 \
21 | --precision fp32 \
22 | --learning_rate 1e-4 \
23 | --save_checkpoint \
24 | --finetune_type "oxe" \
25 | --wandb_project seer \
26 | --weight_decay 1e-4 \
27 | --num_resampler_query 6 \
28 | --run_name mn_oxe \
29 | --save_checkpoint_path ${save_checkpoint_path} \
30 | --except_lang \
31 | --transformer_layers 24 \
32 | --phase "pretrain" \
33 | --obs_pred \
34 | --action_pred_steps 3 \
35 | --sequence_length 11 \
36 | --window_size 11 \
37 | --future_steps 3 \
38 | --loss_action \
39 | --loss_image \
40 | --atten_goal 4 \
41 | --atten_goal_state \
42 | --atten_only_obs \
43 | --real_dataset_names "" \
44 | --root_dir ${root_dir} \
45 | --report_to_wandb \
46 | --warmup_epochs 3 \
--------------------------------------------------------------------------------
/scripts/REAL/deploy.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | resume_from_checkpoint="xxx/xxx.pth"
3 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
4 | ### NEED TO CHANGE ###
5 |
6 | IFS='/' read -ra path_parts <<< "$resume_from_checkpoint"
7 | run_name="${path_parts[-2]}"
8 | log_name="${path_parts[-1]}"
9 | log_folder="eval_logs/$run_name"
10 | mkdir -p "$log_folder"
11 | log_file="eval_logs/$run_name/evaluate_$log_name.log"
12 |
13 | node=1
14 | node_num=1
15 | # vision_encoder_causal_transformer_image_decoder
16 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10113 deploy.py \
17 | --traj_cons \
18 | --rgb_pad 10 \
19 | --gripper_pad 4 \
20 | --gradient_accumulation_steps 1 \
21 | --bf16_module "vision_encoder" \
22 | --vit_checkpoint_path ${vit_checkpoint_path} \
23 | --workers 16 \
24 | --calvin_dataset "" \
25 | --lr_scheduler cosine \
26 | --save_every_iter 50000 \
27 | --num_epochs 20 \
28 | --seed 42 \
29 | --batch_size 64 \
30 | --precision fp32 \
31 | --weight_decay 1e-4 \
32 | --num_resampler_query 6 \
33 | --run_name test \
34 | --transformer_layers 24 \
35 | --phase "evaluate" \
36 | --finetune_type "real" \
37 | --action_pred_steps 3 \
38 | --future_steps 3 \
39 | --sequence_length 7 \
40 | --obs_pred \
41 | --resume_from_checkpoint ${resume_from_checkpoint} \
42 | --real_eval_max_steps 600 \
43 | --eval_libero_ensembling \
--------------------------------------------------------------------------------
/docs/REAL-WORLD_INFERENCE.md:
--------------------------------------------------------------------------------
1 | # Inference
2 | Seer has been encapsulated into a real-world controller, allowing for straightforward adaptation to downstream tasks. To ensure smooth implementation and avoid common errors, we offer the following recommendations:
3 | * :fire: **Proprio First vs. Image First:** During data collection and inference, always acquire robot proprioception data first, followed by image observations. This order minimizes the timestep interval since capturing images is significantly more time-intensive than reading proprioception data.
4 | * :fire: **Delta-to-Absolute Action Labels:** To simplify matrix computation and transformation, we provide a delta-action-to-absolute-action conversion in the [deployment script](../deploy.py). This aligns with the absolute-action-to-delta-action transformation found in the [post-process script](../utils/real_ft_data.py).
5 | * :fire: **Consistent Control Frequency:** Ensure that the control frequencies used during data collection match those during inference. Discrepancies in frequency can lead to inconsistent results.
6 |
7 | ## :star: Real-World Controller
8 | A [wrapped seer controller](../real_controller/controller.py) is provided for real-world deployment. This controller is modular and can be easily adapted to specific tasks or environments.
9 |
10 | ## :star2: Real-World Deployment Pseudocode
11 | To deploy the wrapped Seer controller for real-world tasks, modify the [deployment script](../deploy.py) to fit your specific environment. Then, execute the deployment with the following command:
12 | ```python
13 | bash scripts/REAL/deploy.sh
14 | ```
--------------------------------------------------------------------------------
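The delta-to-absolute action conversion recommended in the inference notes above can be illustrated with a short, hedged sketch: given the current end-effector pose from proprioception and a predicted delta action `(dx, dy, dz, droll, dpitch, dyaw, gripper)`, compose them into an absolute target pose. This is an illustrative sketch only, not the conversion in `deploy.py` or `utils/real_ft_data.py`; frame conventions and names are assumptions.

```python
# Hedged sketch: compose a predicted delta action with the current EE pose to
# get an absolute target pose. Assumes the delta rotation is expressed in the
# robot base frame; swap the composition order for an end-effector-frame delta.
import numpy as np
from scipy.spatial.transform import Rotation as R

def delta_to_absolute(curr_pos, curr_rpy, delta_action):
    """curr_pos: (3,) xyz; curr_rpy: (3,) roll-pitch-yaw in radians;
    delta_action: (7,) = (dx, dy, dz, droll, dpitch, dyaw, gripper)."""
    d_pos, d_rpy, gripper = delta_action[:3], delta_action[3:6], delta_action[6]
    target_pos = np.asarray(curr_pos) + d_pos
    # Compose rotations as rotation objects rather than adding Euler angles.
    target_rot = R.from_euler("xyz", d_rpy) * R.from_euler("xyz", curr_rpy)
    return target_pos, target_rot.as_euler("xyz"), gripper
```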
/scripts/CALVIN_ABC_D/Seer/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### NEED TO CHANGE ###
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | save_checkpoint_path="checkpoints/"
5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 | ### NEED TO CHANGE ###
7 |
8 | node=8
9 | node_num=8
10 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
11 | --traj_cons \
12 | --rgb_pad 10 \
13 | --gripper_pad 4 \
14 | --gradient_accumulation_steps 1 \
15 | --bf16_module "vision_encoder" \
16 | --vit_checkpoint_path ${vit_checkpoint_path} \
17 | --calvin_dataset ${calvin_dataset_path} \
18 | --workers 8 \
19 | --lr_scheduler cosine \
20 | --save_every_iter 100000 \
21 | --num_epochs 20 \
22 | --seed 42 \
23 | --batch_size 10 \
24 | --precision fp32 \
25 | --learning_rate 1e-4 \
26 | --finetune_type "calvin" \
27 | --wandb_project seer \
28 | --weight_decay 1e-4 \
29 | --num_resampler_query 6 \
30 | --run_name pretrain_seer_calvin_abc_d \
31 | --save_checkpoint_path ${save_checkpoint_path} \
32 | --transformer_layers 24 \
33 | --phase "pretrain" \
34 | --action_pred_steps 3 \
35 | --sequence_length 14 \
36 | --future_steps 3 \
37 | --window_size 17 \
38 | --obs_pred \
39 | --loss_image \
40 | --loss_action \
41 | --atten_goal 4 \
42 | --atten_goal_state \
43 | --atten_only_obs \
44 | --except_lang \
45 | --save_checkpoint \
46 | --report_to_wandb \
47 |
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/scene.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternRobotics/Seer/HEAD/real_preprocess/mujoco_menagerie/robotiq_2f85/scene.xml
--------------------------------------------------------------------------------
/scripts/LIBERO_LONG/Seer/scratch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ### NEED TO CHANGE ###
4 | save_checkpoint_path="checkpoints/"
5 | root_dir="PATH_TO_PARENT_DIR_OF_LIBERO_CONVERTED"
6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth"
7 | libero_path="PATH_TO_LIBERO"
8 | ### NEED TO CHANGE ###
9 | calvin_dataset_path="calvin/dataset/task_ABC_D"
10 |
11 | node=1
12 | node_num=8
13 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
14 | --traj_cons \
15 | --rgb_pad 10 \
16 | --gripper_pad 4 \
17 | --gradient_accumulation_steps 4 \
18 | --bf16_module "vision_encoder" \
19 | --vit_checkpoint_path ${vit_checkpoint_path} \
20 | --calvin_dataset ${calvin_dataset_path} \
21 | --workers 8 \
22 | --lr_scheduler cosine \
23 | --save_every_iter 100000 \
24 | --num_epochs 40 \
25 | --seed 42 \
26 | --batch_size 16 \
27 | --precision fp32 \
28 | --learning_rate 1e-3 \
29 | --save_checkpoint \
30 | --finetune_type libero_finetune \
31 | --root_dir ${root_dir} \
32 | --wandb_project seer \
33 | --weight_decay 1e-4 \
34 | --num_resampler_query 6 \
35 | --run_name libero_scratch \
36 | --save_checkpoint_path ${save_checkpoint_path} \
37 | --transformer_layers 24 \
38 | --phase "finetune" \
39 | --obs_pred \
40 | --action_pred_steps 3 \
41 | --sequence_length 7 \
42 | --future_steps 3 \
43 | --window_size 10 \
44 | --loss_image \
45 | --loss_action \
46 | --save_checkpoint_seq 1 \
47 | --start_save_checkpoint 25 \
48 | --gripper_width \
49 | --warmup_epochs 5 \
50 | --libero_path ${libero_path} \
51 | --report_to_wandb \
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### NEED TO CHANGE ###
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | save_checkpoint_path="checkpoints/"
5 | finetune_from_pretrained_ckpt="checkpoints/pretrain_calvin_abc_d/4.pth"
6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
7 | ### NEED TO CHANGE ###
8 |
9 | node=4
10 | node_num=8
11 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
12 | --traj_cons \
13 | --rgb_pad 10 \
14 | --gripper_pad 4 \
15 | --gradient_accumulation_steps 1 \
16 | --bf16_module "vision_encoder" \
17 | --vit_checkpoint_path ${vit_checkpoint_path} \
18 | --calvin_dataset ${calvin_dataset_path} \
19 | --workers 8 \
20 | --lr_scheduler cosine \
21 | --save_every_iter 100000 \
22 | --num_epochs 20 \
23 | --seed 42 \
24 | --batch_size 16 \
25 | --precision fp32 \
26 | --learning_rate 1e-3 \
27 | --save_checkpoint \
28 | --finetune_type "calvin" \
29 | --wandb_project seer \
30 | --weight_decay 1e-4 \
31 | --num_resampler_query 6 \
32 | --run_name finetune_calvin_abc_d_ep5 \
33 | --save_checkpoint_path ${save_checkpoint_path} \
34 | --transformer_layers 24 \
35 | --phase "finetune" \
36 | --action_pred_steps 3 \
37 | --sequence_length 10 \
38 | --future_steps 3 \
39 | --window_size 13 \
40 | --obs_pred \
41 | --loss_image \
42 | --loss_action \
43 | --report_to_wandb \
44 | --reset_action_token \
45 | --reset_obs_token \
46 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \
--------------------------------------------------------------------------------
/scripts/LIBERO_LONG/Seer/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ### NEED TO CHANGE ###
4 | save_checkpoint_path="checkpoints/"
5 | root_dir="PATH_TO_PARENT_DIR_OF_LIBERO_CONVERTED"
6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth"
7 | libero_path="PATH_TO_LIBERO"
8 | ### NEED TO CHANGE ###
9 | calvin_dataset_path="calvin/dataset/task_ABC_D"
10 |
11 | node=1
12 | node_num=8
13 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
14 | --traj_cons \
15 | --rgb_pad 10 \
16 | --gripper_pad 4 \
17 | --gradient_accumulation_steps 8 \
18 | --bf16_module "vision_encoder" \
19 | --vit_checkpoint_path ${vit_checkpoint_path} \
20 | --calvin_dataset ${calvin_dataset_path} \
21 | --workers 16 \
22 | --lr_scheduler cosine \
23 | --save_every_iter 100000 \
24 | --num_epochs 30 \
25 | --seed 42 \
26 | --batch_size 10 \
27 | --precision fp32 \
28 | --learning_rate 1e-4 \
29 | --save_checkpoint \
30 | --finetune_type libero_pretrain \
31 | --root_dir ${root_dir} \
32 | --wandb_project seer \
33 | --weight_decay 1e-4 \
34 | --num_resampler_query 6 \
35 | --run_name libero_pretrain \
36 | --save_checkpoint_path ${save_checkpoint_path} \
37 | --transformer_layers 24 \
38 | --phase "pretrain" \
39 | --obs_pred \
40 | --sequence_length 11 \
41 | --action_pred_steps 3 \
42 | --future_steps 3 \
43 | --atten_goal 4 \
44 | --window_size 11 \
45 | --loss_image \
46 | --loss_action \
47 | --gripper_width \
48 | --atten_only_obs \
49 | --atten_goal_state \
50 | --mask_l_obs_ratio 0.5 \
51 | --warmup_epochs 1 \
52 | --libero_path ${libero_path} \
53 | --report_to_wandb \
--------------------------------------------------------------------------------
/docs/LIBERO_LONG_RUN.md:
--------------------------------------------------------------------------------
1 | # Running
2 | ## Notice
3 |
4 | For convenience, some checkpoints, such as the MAE-pretrained ViT-B model, are provided for manual download. Users must update the following paths accordingly. Relevant checkpoints can be acquired from the [website](https://drive.google.com/drive/folders/1zwqGvKKtjyuWdDaNSLVGJprJMPoSqAPk?usp=drive_link).
5 | * :exclamation: **pretrain.sh, finetune.sh, scratch.sh, eval.sh:**
6 | Please update the following:
7 | * **save_checkpoint_path** to the parent directory where your experiment checkpoints are saved. We recommend creating a ```checkpoints``` folder in the project root directory.
8 | * **finetune_from_pretrained_ckpt** to the location of your pre-trained checkpoint.
9 | * **resume_from_checkpoint** to the location of your fine-tuned checkpoint.
10 | * **vit_checkpoint_path** to the location of your ViT checkpoint (downloaded from the [website](https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing)). We recommend storing it at ```checkpoints/vit_mae/mae_pretrain_vit_base.pth```.
11 | * **libero_path** to the location of the LIBERO directory.
12 |
13 | ## Seer
14 | ### Convert Data
15 | ```bash
16 | python utils/convert_libero_per_step.py
17 | ```
18 |
19 | ### Pre-train
20 | ```bash
21 | # Pre-train Seer on LIBERO-90 dataset
22 | bash scripts/LIBERO_LONG/Seer/pretrain.sh
23 | ```
24 |
25 | ### Fine-tune
26 | ```bash
27 | # Fine-tune Seer on LIBERO-10 dataset
28 | bash scripts/LIBERO_LONG/Seer/finetune.sh
29 | ```
30 |
31 | ### Train from Scratch
32 | ```bash
33 | # Train Seer on LIBERO-10 dataset from scratch
34 | bash scripts/LIBERO_LONG/Seer/scratch.sh
35 | ```
36 |
37 | ### Eval
38 | ```bash
39 | # Evaluate Seer on LIBERO-10 benchmark
40 | bash scripts/LIBERO_LONG/Seer/eval.sh
41 | ```
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer-Large/scratch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### need to change to your path ###
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | save_checkpoint_path="checkpoints/"
5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 | # node=8
7 | node_num=8
8 | torchrun \
9 | --nnodes=${MLP_WORKER_NUM} \
10 | --node_rank=${MLP_ROLE_INDEX} \
11 | --nproc_per_node=${node_num} \
12 | --master_addr=${MLP_WORKER_0_HOST} \
13 | --master_port=${MLP_WORKER_0_PORT} \
14 | train.py \
15 | --traj_cons \
16 | --rgb_pad 10 \
17 | --gripper_pad 4 \
18 | --gradient_accumulation_steps 1 \
19 | --bf16_module "vision_encoder" \
20 | --vit_checkpoint_path ${vit_checkpoint_path} \
21 | --calvin_dataset ${calvin_dataset_path} \
22 | --workers 8 \
23 | --lr_scheduler cosine \
24 | --save_every_iter 100000 \
25 | --num_epochs 20 \
26 | --seed 42 \
27 | --batch_size 8 \
28 | --precision fp32 \
29 | --learning_rate 1e-3 \
30 | --warmup_epochs 1 \
31 | --finetune_type "calvin" \
32 | --wandb_project seer \
33 | --weight_decay 1e-4 \
34 | --num_resampler_query 16 \
35 | --num_obs_token_per_image 16 \
36 | --run_name scratch-Seer-Large \
37 | --save_checkpoint \
38 | --save_checkpoint_path ${save_checkpoint_path} \
39 | --transformer_layers 24 \
40 | --hidden_dim 1024 \
41 | --transformer_heads 16 \
42 | --phase "finetune" \
43 | --action_pred_steps 3 \
44 | --sequence_length 10 \
45 | --future_steps 3 \
46 | --window_size 13 \
47 | --obs_pred \
48 | --loss_image \
49 | --loss_action \
50 | --report_to_wandb \
51 |
--------------------------------------------------------------------------------
/scripts/LIBERO_LONG/Seer/eval.sh:
--------------------------------------------------------------------------------
1 | pthlist=("30" "31" "32" "33" "34" "35" "36" "37" "38" "39")
2 | for ckpt_id in "${pthlist[@]}"; do
3 | resume_from_checkpoint="/home/tianyang/Checkpoints/libero_scratch"
4 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth"
5 | this_resume_from_checkpoint="${resume_from_checkpoint}/${ckpt_id}.pth"
6 | save_checkpoint_path="checkpoints/"
7 | dirname=$(basename "$resume_from_checkpoint")
8 | LOG_DIR="/home/tianyang/Code/Eval/${dirname}"
9 | mkdir -p ${LOG_DIR}
10 | test_id="${ckpt_id}"
11 | logfile="${LOG_DIR}/${test_id}.log"
12 |
13 | node=1
14 | node_num=8
15 |
16 | python -m torch.distributed.run --nnodes=${node} --nproc_per_node=${node_num} --master_port=10133 eval_libero.py \
17 | --traj_cons \
18 | --rgb_pad 10 \
19 | --gripper_pad 4 \
20 | --gradient_accumulation_steps 1 \
21 | --bf16_module "vision_encoder" \
22 | --vit_checkpoint_path ${vit_checkpoint_path} \
23 | --calvin_dataset "" \
24 | --workers 16 \
25 | --lr_scheduler cosine \
26 | --save_every_iter 50000 \
27 | --num_epochs 20 \
28 | --seed 42 \
29 | --batch_size 64 \
30 | --precision fp32 \
31 | --weight_decay 1e-4 \
32 | --num_resampler_query 6 \
33 | --run_name test \
34 | --transformer_layers 24 \
35 | --phase "evaluate" \
36 | --finetune_type "libero_10" \
37 | --save_checkpoint_path ${save_checkpoint_path} \
38 | --action_pred_steps 3 \
39 | --future_steps 3 \
40 | --sequence_length 7 \
41 | --obs_pred \
42 | --gripper_width \
43 | --eval_libero_ensembling \
44 | --resume_from_checkpoint ${this_resume_from_checkpoint} | tee ${logfile}
45 | done
--------------------------------------------------------------------------------
/scripts/REAL/slurm_s_full_cluster.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="xxx/preprocess"
4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
5 | ### NEED TO CHANGE ###
6 |
7 | ### EXAMPLE ###
8 | # - root_dir
9 | # - droid_success
10 | # - episodes
11 | # - 000000
12 | # - ......
13 | # - xxxxxx
14 | # - meta_info.h5
15 | # - shape_info.h5
16 | ### EXAMPLE ###
17 |
18 | python slurm_train_intern.py \
19 | --traj_cons \
20 | --rgb_pad 10 \
21 | --gripper_pad 4 \
22 | --gradient_accumulation_steps 2 \
23 | --bf16_module "vision_encoder" \
24 | --vit_checkpoint_path ${vit_checkpoint_path} \
25 | --calvin_dataset "" \
26 | --workers 8 \
27 | --lr_scheduler cosine \
28 | --save_every_iter 20000 \
29 | --num_epochs 30 \
30 | --seed 42 \
31 | --batch_size 32 \
32 | --precision fp32 \
33 | --learning_rate 1e-4 \
34 | --save_checkpoint \
35 | --finetune_type "droid" \
36 | --wandb_project seer \
37 | --weight_decay 1e-4 \
38 | --num_resampler_query 6 \
39 | --run_name mn_full_droid \
40 | --save_checkpoint_path ${save_checkpoint_path} \
41 | --except_lang \
42 | --transformer_layers 24 \
43 | --phase "pretrain" \
44 | --obs_pred \
45 | --action_pred_steps 3 \
46 | --sequence_length 11 \
47 | --window_size 11 \
48 | --future_steps 3 \
49 | --loss_action \
50 | --loss_image \
51 | --atten_goal 4 \
52 | --atten_goal_state \
53 | --atten_only_obs \
54 | --real_dataset_names "" \
55 | --root_dir ${root_dir} \
56 | --dataset_info droid_success_full_0803 \
57 | --report_to_wandb \
58 | --warmup_epochs 3 \
--------------------------------------------------------------------------------
/scripts/REAL/slurm_s_language_cluster.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="xxx/preprocess"
4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
5 | ### NEED TO CHANGE ###
6 |
7 | ### EXAMPLE ###
8 | # - root_dir
9 | # - droid_success
10 | # - episodes
11 | # - 000000
12 | # - ......
13 | # - xxxxxx
14 | # - meta_info.h5
15 | # - shape_info.h5
16 | ### EXAMPLE ###
17 |
18 | python slurm_train_intern.py \
19 | --traj_cons \
20 | --rgb_pad 10 \
21 | --gripper_pad 4 \
22 | --gradient_accumulation_steps 2 \
23 | --bf16_module "vision_encoder" \
24 | --vit_checkpoint_path ${vit_checkpoint_path} \
25 | --calvin_dataset "" \
26 | --workers 8 \
27 | --lr_scheduler cosine \
28 | --save_every_iter 20000 \
29 | --num_epochs 30 \
30 | --seed 42 \
31 | --batch_size 32 \
32 | --precision fp32 \
33 | --learning_rate 1e-4 \
34 | --save_checkpoint \
35 | --finetune_type "droid" \
36 | --wandb_project seer \
37 | --weight_decay 1e-4 \
38 | --num_resampler_query 6 \
39 | --run_name mn_lang_droid \
40 | --save_checkpoint_path ${save_checkpoint_path} \
41 | --except_lang \
42 | --transformer_layers 24 \
43 | --phase "pretrain" \
44 | --obs_pred \
45 | --action_pred_steps 3 \
46 | --sequence_length 11 \
47 | --window_size 11 \
48 | --future_steps 3 \
49 | --loss_action \
50 | --loss_image \
51 | --atten_goal 4 \
52 | --atten_goal_state \
53 | --atten_only_obs \
54 | --real_dataset_names "" \
55 | --root_dir ${root_dir} \
56 | --dataset_info droid_success_languaged_0803 \
57 | --report_to_wandb \
58 | --warmup_epochs 3 \
--------------------------------------------------------------------------------
/scripts/LIBERO_LONG/Seer/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ### NEED TO CHANGE ###
4 | save_checkpoint_path="checkpoints/"
5 | root_dir="PATH_TO_PARENT_DIR_OF_LIBERO_CONVERTED"
6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth"
7 | finetune_from_pretrained_ckpt="libero_pretrain/14.pth"
8 | libero_path="PATH_TO_LIBERO"
9 | ### NEED TO CHANGE ###
10 | calvin_dataset_path="calvin/dataset/task_ABC_D"
11 |
12 | node=1
13 | node_num=8
14 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
15 | --traj_cons \
16 | --rgb_pad 10 \
17 | --gripper_pad 4 \
18 | --gradient_accumulation_steps 4 \
19 | --bf16_module "vision_encoder" \
20 | --vit_checkpoint_path ${vit_checkpoint_path} \
21 | --calvin_dataset ${calvin_dataset_path} \
22 | --workers 8 \
23 | --lr_scheduler cosine \
24 | --save_every_iter 100000 \
25 | --num_epochs 40 \
26 | --seed 42 \
27 | --batch_size 16 \
28 | --precision fp32 \
29 | --learning_rate 1e-3 \
30 | --save_checkpoint \
31 | --finetune_type libero_finetune \
32 | --root_dir ${root_dir} \
33 | --wandb_project seer \
34 | --weight_decay 1e-4 \
35 | --num_resampler_query 6 \
36 | --run_name libero_finetune \
37 | --save_checkpoint_path ${save_checkpoint_path} \
38 | --transformer_layers 24 \
39 | --phase "finetune" \
40 | --obs_pred \
41 | --action_pred_steps 3 \
42 | --sequence_length 7 \
43 | --future_steps 3 \
44 | --window_size 10 \
45 | --loss_image \
46 | --loss_action \
47 | --reset_action_token \
48 | --reset_obs_token \
49 | --save_checkpoint_seq 1 \
50 | --start_save_checkpoint 25 \
51 | --gripper_width \
52 | --warmup_epochs 5 \
53 | --libero_path ${libero_path} \
54 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \
55 | --report_to_wandb \
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer-Large/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### need to change to your path ###
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | save_checkpoint_path="checkpoints/"
5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 | # node=8
7 | node_num=8
8 | torchrun \
9 | --nnodes=${MLP_WORKER_NUM} \
10 | --node_rank=${MLP_ROLE_INDEX} \
11 | --nproc_per_node=${node_num} \
12 | --master_addr=${MLP_WORKER_0_HOST} \
13 | --master_port=${MLP_WORKER_0_PORT} \
14 | train.py \
15 | --traj_cons \
16 | --rgb_pad 10 \
17 | --gripper_pad 4 \
18 | --gradient_accumulation_steps 1 \
19 | --bf16_module "vision_encoder" \
20 | --vit_checkpoint_path ${vit_checkpoint_path} \
21 | --calvin_dataset ${calvin_dataset_path} \
22 | --workers 8 \
23 | --lr_scheduler cosine \
24 | --save_every_iter 100000 \
25 | --num_epochs 20 \
26 | --seed 42 \
27 | --batch_size 8 \
28 | --precision fp32 \
29 | --learning_rate 1e-4 \
30 | --finetune_type "calvin" \
31 | --wandb_project seer \
32 | --weight_decay 1e-4 \
33 | --num_resampler_query 16 \
34 | --num_obs_token_per_image 16 \
35 | --run_name pretrain_Seer-Large \
36 | --save_checkpoint_path ${save_checkpoint_path} \
37 | --transformer_layers 24 \
38 | --hidden_dim 1024 \
39 | --transformer_heads 16 \
40 | --phase "pretrain" \
41 | --action_pred_steps 3 \
42 | --sequence_length 14 \
43 | --future_steps 3 \
44 | --window_size 17 \
45 | --obs_pred \
46 | --loss_image \
47 | --loss_action \
48 | --atten_goal 4 \
49 | --atten_goal_state \
50 | --atten_only_obs \
51 | --attn_robot_proprio_state \
52 | --except_lang \
53 | --save_checkpoint \
54 | --report_to_wandb \
55 |
--------------------------------------------------------------------------------
/scripts/REAL/single_node_full_cluster.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="xxx/preprocess"
4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
5 | ### NEED TO CHANGE ###
6 |
7 | ### EXAMPLE ###
8 | # - root_dir
9 | # - droid_success
10 | # - episodes
11 | # - 000000
12 | # - ......
13 | # - xxxxxx
14 | # - meta_info.h5
15 | # - shape_info.h5
16 | ### EXAMPLE ###
17 |
18 | node=4
19 | node_num=8
20 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=62345 train.py \
21 | --traj_cons \
22 | --rgb_pad 10 \
23 | --gripper_pad 4 \
24 | --gradient_accumulation_steps 2 \
25 | --bf16_module "vision_encoder" \
26 | --vit_checkpoint_path ${vit_checkpoint_path} \
27 | --calvin_dataset "" \
28 | --workers 8 \
29 | --lr_scheduler cosine \
30 | --save_every_iter 20000 \
31 | --num_epochs 20 \
32 | --seed 42 \
33 | --batch_size 32 \
34 | --precision fp32 \
35 | --learning_rate 1e-4 \
36 | --save_checkpoint \
37 | --finetune_type "droid" \
38 | --wandb_project seer \
39 | --weight_decay 1e-4 \
40 | --num_resampler_query 6 \
41 | --run_name sn_full_droid \
42 | --save_checkpoint_path ${save_checkpoint_path} \
43 | --except_lang \
44 | --transformer_layers 24 \
45 | --phase "pretrain" \
46 | --obs_pred \
47 | --action_pred_steps 3 \
48 | --sequence_length 11 \
49 | --window_size 11 \
50 | --future_steps 3 \
51 | --loss_action \
52 | --loss_image \
53 | --atten_goal 4 \
54 | --atten_goal_state \
55 | --atten_only_obs \
56 | --real_dataset_names "" \
57 | --root_dir ${root_dir} \
58 | --dataset_info droid_success_full_0803 \
59 | --warmup_epochs 3 \
60 | --report_to_wandb \
61 |
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export GIT_PYTHON_REFRESH=quiet
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | calvin_conf_path="calvin/calvin_models/conf"
5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 | save_checkpoint_path="checkpoints/"
7 | ### NEED TO CHANGE the checkpoint path ###
8 | resume_from_checkpoint="checkpoints/xxx/xx.pth"
9 | IFS='/' read -ra path_parts <<< "$resume_from_checkpoint"
10 | run_name="${path_parts[-2]}"
11 | log_name="${path_parts[-1]}"
12 | log_folder="eval_logs/$run_name"
13 | mkdir -p "$log_folder"
14 | log_file="eval_logs/$run_name/evaluate_$log_name.log"
15 | node=1
16 | node_num=8
17 |
18 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10012 eval_calvin.py \
19 | --traj_cons \
20 | --rgb_pad 10 \
21 | --gripper_pad 4 \
22 | --gradient_accumulation_steps 1 \
23 | --bf16_module "vision_encoder" \
24 | --vit_checkpoint_path ${vit_checkpoint_path} \
25 | --calvin_dataset ${calvin_dataset_path} \
26 | --calvin_conf_path ${calvin_conf_path} \
27 | --workers 16 \
28 | --lr_scheduler cosine \
29 | --save_every_iter 50000 \
30 | --num_epochs 20 \
31 | --seed 42 \
32 | --batch_size 64 \
33 | --precision fp32 \
34 | --weight_decay 1e-4 \
35 | --num_resampler_query 6 \
36 | --num_obs_token_per_image 9 \
37 | --run_name ${run_name} \
38 | --save_checkpoint_path ${save_checkpoint_path} \
39 | --transformer_layers 24 \
40 | --hidden_dim 384 \
41 | --transformer_heads 12 \
42 | --phase "evaluate" \
43 | --finetune_type "calvin" \
44 | --action_pred_steps 3 \
45 | --sequence_length 10 \
46 | --future_steps 3 \
47 | --window_size 13 \
48 | --obs_pred \
49 | --resume_from_checkpoint ${resume_from_checkpoint} | tee ${log_file} \
50 |
--------------------------------------------------------------------------------
/scripts/REAL/single_node_language_cluster.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="xxx/preprocess"
4 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
5 | ### NEED TO CHANGE ###
6 |
7 | ### EXAMPLE ###
8 | # - root_dir
9 | # - droid_success
10 | # - episodes
11 | # - 000000
12 | # - ......
13 | # - xxxxxx
14 | # - meta_info.h5
15 | # - shape_info.h5
16 | ### EXAMPLE ###
17 |
18 | node=4
19 | node_num=8
20 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=62345 train.py \
21 | --traj_cons \
22 | --rgb_pad 10 \
23 | --gripper_pad 4 \
24 | --gradient_accumulation_steps 2 \
25 | --bf16_module "vision_encoder" \
26 | --vit_checkpoint_path ${vit_checkpoint_path} \
27 | --calvin_dataset "" \
28 | --workers 8 \
29 | --lr_scheduler cosine \
30 | --save_every_iter 20000 \
31 | --num_epochs 20 \
32 | --seed 42 \
33 | --batch_size 32 \
34 | --precision fp32 \
35 | --learning_rate 1e-4 \
36 | --save_checkpoint \
37 | --finetune_type "droid" \
38 | --wandb_project seer \
39 | --weight_decay 1e-4 \
40 | --num_resampler_query 6 \
41 | --run_name sn_lang_droid \
42 | --save_checkpoint_path ${save_checkpoint_path} \
43 | --except_lang \
44 | --transformer_layers 24 \
45 | --phase "pretrain" \
46 | --obs_pred \
47 | --action_pred_steps 3 \
48 | --sequence_length 11 \
49 | --window_size 11 \
50 | --future_steps 3 \
51 | --loss_action \
52 | --loss_image \
53 | --atten_goal 4 \
54 | --atten_goal_state \
55 | --atten_only_obs \
56 | --real_dataset_names "" \
57 | --root_dir ${root_dir} \
58 | --dataset_info droid_success_languaged_0803 \
59 | --warmup_epochs 3 \
60 | --report_to_wandb \
61 |
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer-Large/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export GIT_PYTHON_REFRESH=quiet
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | calvin_conf_path="calvin/calvin_models/conf"
5 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 | save_checkpoint_path="checkpoints/"
7 | ### NEED TO CHANGE the checkpoint path ###
8 | resume_from_checkpoint="checkpoints/xxx/xxx.pth"
9 |
10 |
11 | IFS='/' read -ra path_parts <<< "$resume_from_checkpoint"
12 | run_name="${path_parts[-2]}"
13 | log_name="${path_parts[-1]}"
14 | log_folder="eval_logs/$run_name"
15 | mkdir -p "$log_folder"
16 | log_file="eval_logs/$run_name/evaluate_$log_name.log"
17 | node=1
18 | node_num=8
19 |
20 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10212 eval_calvin.py\
21 | --traj_cons \
22 | --rgb_pad 10 \
23 | --gripper_pad 4 \
24 | --gradient_accumulation_steps 1 \
25 | --bf16_module "vision_encoder" \
26 | --vit_checkpoint_path ${vit_checkpoint_path} \
27 | --calvin_dataset ${calvin_dataset_path} \
28 | --calvin_conf_path ${calvin_conf_path} \
29 | --workers 16 \
30 | --lr_scheduler cosine \
31 | --save_every_iter 50000 \
32 | --num_epochs 20 \
33 | --seed 42 \
34 | --batch_size 64 \
35 | --precision fp32 \
36 | --weight_decay 1e-4 \
37 | --num_resampler_query 16 \
38 | --num_obs_token_per_image 16 \
39 | --run_name ${run_name} \
40 | --save_checkpoint_path ${save_checkpoint_path} \
41 | --transformer_layers 24 \
42 | --hidden_dim 1024 \
43 | --transformer_heads 16 \
44 | --phase "evaluate" \
45 | --finetune_type "calvin" \
46 | --action_pred_steps 3 \
47 | --sequence_length 10 \
48 | --future_steps 3 \
49 | --window_size 13 \
50 | --obs_pred \
51 | --resume_from_checkpoint ${resume_from_checkpoint} | tee ${log_file} \
52 |
--------------------------------------------------------------------------------
/scripts/REAL/single_node_scratch.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="your_path_to_the_parent_folder_of_real_data"
4 | real_dataset_names="your_real_dataset_name"
5 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
6 | ### NEED TO CHANGE ###
7 |
8 | ### EXAMPLE ###
9 | # - root_dir
10 | # - real_dataset_names
11 | # - 0000
12 | # - 000000
13 | # - ......
14 | # - xxxxxx
15 | # - ....
16 | # - 00xx
17 | ### EXAMPLE ###
18 |
19 | node=1
20 | node_num=8
21 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
22 | --traj_cons \
23 | --rgb_pad 10 \
24 | --gripper_pad 4 \
25 | --gradient_accumulation_steps 4 \
26 | --bf16_module "vision_encoder" \
27 | --vit_checkpoint_path ${vit_checkpoint_path} \
28 | --calvin_dataset "" \
29 | --workers 8 \
30 | --lr_scheduler cosine \
31 | --save_every_iter 100000 \
32 | --num_epochs 40 \
33 | --seed 42 \
34 | --batch_size 16 \
35 | --precision fp32 \
36 | --learning_rate 1e-3 \
37 | --save_checkpoint \
38 | --finetune_type real \
39 | --root_dir ${root_dir} \
40 | --wandb_project seer \
41 | --weight_decay 1e-4 \
42 | --num_resampler_query 6 \
43 | --run_name sn_scratch \
44 | --save_checkpoint_path ${save_checkpoint_path} \
45 | --except_lang \
46 | --transformer_layers 24 \
47 | --phase "finetune" \
48 | --action_pred_steps 3 \
49 | --sequence_length 7 \
50 | --future_steps 3 \
51 | --window_size 10 \
52 | --obs_pred \
53 | --loss_action \
54 | --loss_image \
55 | --save_checkpoint_seq 1 \
56 | --start_save_checkpoint 15 \
57 | --warmup_epochs 5 \
58 | --real_dataset_names ${real_dataset_names} \
59 | --use_aug_data \
60 | --report_to_wandb \
61 |
--------------------------------------------------------------------------------
/scripts/CALVIN_ABC_D/Seer-Large/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### need to change to your path ###
3 | calvin_dataset_path="calvin/dataset/task_ABC_D"
4 | save_checkpoint_path="checkpoints/"
5 | finetune_from_pretrained_ckpt="checkpoints/pretrain_Seer_ptbs512_24layers_16heads_hd1024-Large/5.pth"
6 | vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
7 | # node=8
8 | node_num=8
9 | torchrun \
10 | --nnodes=${MLP_WORKER_NUM} \
11 | --node_rank=${MLP_ROLE_INDEX} \
12 | --nproc_per_node=${node_num} \
13 | --master_addr=${MLP_WORKER_0_HOST} \
14 | --master_port=${MLP_WORKER_0_PORT} \
15 | train.py \
16 | --traj_cons \
17 | --rgb_pad 10 \
18 | --gripper_pad 4 \
19 | --gradient_accumulation_steps 1 \
20 | --bf16_module "vision_encoder" \
21 | --vit_checkpoint_path ${vit_checkpoint_path} \
22 | --calvin_dataset ${calvin_dataset_path} \
23 | --workers 8 \
24 | --lr_scheduler cosine \
25 | --save_every_iter 100000 \
26 | --num_epochs 20 \
27 | --seed 42 \
28 | --batch_size 8 \
29 | --precision fp32 \
30 | --learning_rate 1e-3 \
31 | --warmup_epochs 1 \
32 | --finetune_type "calvin" \
33 | --wandb_project seer \
34 | --weight_decay 1e-4 \
35 | --num_resampler_query 16 \
36 | --num_obs_token_per_image 16 \
37 | --run_name finetune_Seer-Large \
38 | --save_checkpoint_path ${save_checkpoint_path} \
39 | --transformer_layers 24 \
40 | --hidden_dim 1024 \
41 | --transformer_heads 16 \
42 | --phase "finetune" \
43 | --action_pred_steps 3 \
44 | --sequence_length 10 \
45 | --future_steps 3 \
46 | --window_size 13 \
47 | --obs_pred \
48 | --loss_image \
49 | --loss_action \
50 | --save_checkpoint \
51 | --report_to_wandb \
52 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \
53 |
54 |
55 |
--------------------------------------------------------------------------------
/scripts/REAL/single_node_ft.sh:
--------------------------------------------------------------------------------
1 | ### NEED TO CHANGE ###
2 | save_checkpoint_path="xxx/checkpoints"
3 | root_dir="your_path_to_the_parent_folder_of_real_data"
4 | real_dataset_names="your_real_dataset_name"
5 | finetune_from_pretrained_ckpt="xxx/xxx.pth"
6 | vit_checkpoint_path="xxx/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
7 | ### NEED TO CHANGE ###
8 |
9 | ### EXAMPLE ###
10 | # - root_dir
11 | # - real_dataset_names
12 | # - 0000
13 | # - 000000
14 | # - ......
15 | # - xxxxxx
16 | # - ....
17 | # - 00xx
18 | ### EXAMPLE ###
19 |
20 | node=1
21 | node_num=8
22 | torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
23 | --traj_cons \
24 | --rgb_pad 10 \
25 | --gripper_pad 4 \
26 | --gradient_accumulation_steps 4 \
27 | --bf16_module "vision_encoder" \
28 | --vit_checkpoint_path ${vit_checkpoint_path} \
29 | --calvin_dataset "" \
30 | --workers 8 \
31 | --lr_scheduler cosine \
32 | --save_every_iter 100000 \
33 | --num_epochs 40 \
34 | --seed 42 \
35 | --batch_size 16 \
36 | --precision fp32 \
37 | --learning_rate 1e-3 \
38 | --save_checkpoint \
39 | --finetune_type real \
40 | --root_dir ${root_dir} \
41 | --wandb_project seer \
42 | --weight_decay 1e-4 \
43 | --num_resampler_query 6 \
44 | --run_name sn_ft \
45 | --save_checkpoint_path ${save_checkpoint_path} \
46 | --except_lang \
47 | --transformer_layers 24 \
48 | --phase "finetune" \
49 | --action_pred_steps 3 \
50 | --sequence_length 7 \
51 | --future_steps 3 \
52 | --window_size 10 \
53 | --obs_pred \
54 | --loss_action \
55 | --loss_image \
56 | --save_checkpoint_seq 1 \
57 | --start_save_checkpoint 15 \
58 | --warmup_epochs 5 \
59 | --real_dataset_names ${real_dataset_names} \
60 | --reset_action_token \
61 | --reset_obs_token \
62 | --use_aug_data \
63 | --report_to_wandb \
64 | --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \
65 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/utils/task_augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains basic logic for randomly zero-ing out keys in the task specification.
3 | """
4 |
5 | import tensorflow as tf
6 |
7 | from octo_oxe_data_utils.utils.data_utils import to_padding
8 |
9 |
10 | def delete_task_conditioning(
11 | traj: dict,
12 | keep_image_prob: float,
13 | ):
14 | """
15 | Randomly drops out either the goal images or the language instruction. Only does something if both of
16 | these are present.
17 |
18 | Args:
19 | traj: A dictionary containing trajectory data. Should have a "task" key.
20 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language
21 | instruction is 1 - keep_image_prob.
22 | """
23 | if "language_instruction" not in traj["task"]:
24 | return traj
25 |
26 | image_keys = {
27 | key
28 | for key in traj["task"].keys()
29 | if key.startswith("image_") or key.startswith("depth_")
30 | }
31 | if not image_keys:
32 | return traj
33 |
34 | traj_len = tf.shape(traj["action"])[0]
35 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob
36 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"]
37 |
38 | for key in image_keys | {"language_instruction"}:
39 | should_keep = should_keep_images if key in image_keys else ~should_keep_images
40 | # pad out the key
41 | traj["task"][key] = tf.where(
42 | should_keep,
43 | traj["task"][key],
44 | to_padding(traj["task"][key]),
45 | )
46 | # zero out the pad mask dict for the key
47 | traj["task"]["pad_mask_dict"][key] = tf.where(
48 | should_keep,
49 | traj["task"]["pad_mask_dict"][key],
50 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]),
51 | )
52 |
53 | # when no goal images are present, the goal timestep becomes the final timestep
54 | traj["task"]["timestep"] = tf.where(
55 | should_keep_images,
56 | traj["task"]["timestep"],
57 | traj_len - 1,
58 | )
59 |
60 | return traj
61 |
--------------------------------------------------------------------------------
/deploy.py:
--------------------------------------------------------------------------------
1 | from real_controller.controller import SeerController
2 | import time
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 | import torch
6 |
7 | def _6d_to_pose(
8 | pose6d,
9 | degrees=False
10 | ):
11 | pose = np.eye(4)
12 | pose[:3, 3] = pose6d[:3]
13 | pose[:3, :3] = R.from_euler("xyz", pose6d[3:6], degrees=degrees).as_matrix()
14 |
15 | return pose
16 |
17 | def pose_to_6d(
18 | pose,
19 | degrees=False
20 | ):
21 | pose6d = np.zeros(6)
22 | pose6d[:3] = pose[:3, 3]
23 | pose6d[3:6] = R.from_matrix(pose[:3, :3]).as_euler("xyz", degrees=degrees)
24 |
25 | return pose6d
26 |
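# Illustrative sanity check (not part of the original script, added for clarity):
# the two helpers above should round-trip, i.e. converting a 6D pose to a 4x4
# matrix and back recovers the original position and euler angles (away from gimbal lock).
_example_pose6d = np.array([0.4, 0.0, 0.3, 0.1, -0.2, 1.5])
assert np.allclose(pose_to_6d(_6d_to_pose(_example_pose6d)), _example_pose6d, atol=1e-8)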
27 | ### pseudocode example to deploy seer ###
28 |
29 | # Recommended control frequency: 15 Hz, same as DROID
30 | control_freq = 15 # Hz
31 | max_rel_pos = 0.02 # magic number, same as training
32 | max_rel_orn = 0.05 # magic number, same as training
33 |
34 | # set a controller
35 | controller = SeerController()
36 | last2robot_pose = env.get_robot_state()["pose"] # absolute 4x4 pose matrix in robot space; `env` is your robot environment interface
37 | # warm up
38 | for i in range(3):
39 | obs = {}
40 | obs["robot_state"] = env.get_robot_state()
41 | obs["color_image"] = env.get_color_images()
42 | target_pos, target_euler, target_gripper, _ = controller.forward(obs, include_info=True)
43 |
44 | while True:
45 | # at each time step t
46 | torch.cuda.synchronize()
47 | t1 = time.time()
48 |
49 | obs["robot_state"] = env.get_robot_state()
50 | obs["color_image"] = env.get_color_images()
51 | target_pos, target_euler, target_gripper, _ = controller.forward(obs, include_info=True)
52 |
53 | # delta-action-2-absolute-action
54 |     target_pos *= max_rel_pos    # undo the position normalization used during training
55 |     target_euler *= max_rel_orn  # undo the orientation normalization used during training
56 | cur2last_pose = _6d_to_pose(np.concatenate([target_pos, target_euler]))
57 | last2robot_pose = last2robot_pose @ cur2last_pose
58 | target_pose = pose_to_6d(last2robot_pose)
59 |
60 | torch.cuda.synchronize()
61 | t2 = time.time()
62 | sleep_left = 1. / control_freq - (t2 - t1)
63 |
64 | if sleep_left > 0:
65 | time.sleep(sleep_left)
66 |
67 | env.step(target_pose, target_gripper)
68 |
69 | ### pseudocode example to deploy seer ###
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/docs/CALVIN_ABC-D_RUN.md:
--------------------------------------------------------------------------------
1 | # Running
2 | ## Notice
3 |
4 | For convenience, some checkpoints, such as the MAE-pretrained ViT-B model, are provided for manual download. Users must update the following paths accordingly. Relevant checkpoints can be acquired from the [website](https://drive.google.com/drive/folders/1F3IE95z2THAQ_lt3DKUFdRGc86Thsnc7?usp=sharing).
5 | * :exclamation: **pretrain.sh, finetune.sh, scratch.sh, eval.sh:**
6 | Please update the following:
7 | * **calvin_dataset_path** to the directory where you have stored the CALVIN ABC-D data.
8 |    * **save_checkpoint_path** to the parent directory where your experiment checkpoints are saved. We recommend creating a ```checkpoints``` folder in the project root directory.
9 | * **finetune_from_pretrained_ckpt** to the location of your pre-trained checkpoint.
10 | * **resume_from_checkpoint** to the location of your fine-tuned checkpoint.
11 |    * **vit_checkpoint_path** to the location of your ViT checkpoint (downloaded from the [website](https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing)). We recommend storing it at ```checkpoints/vit_mae/mae_pretrain_vit_base.pth```.
12 |
13 | * :exclamation: **networkx:**
14 |    Due to compatibility issues between the networkx library in CALVIN and Python 3.10, we provide a compatible version of networkx.zip on the [website](https://drive.google.com/file/d/1z-d1SaI0rXfBtBicw1zPSsP-wE-26oLq/view?usp=sharing). Download and unzip it, then replace the existing networkx library in your environment with the unzipped version.
15 |
16 | ## Seer
17 | ### Pre-train
18 | ```bash
19 | # Pre-train Seer on Calvin ABC-D dataset
20 | bash scripts/CALVIN_ABC_D/Seer/pretrain.sh
21 | # Pre-train Seer-Large on Calvin ABC-D dataset
22 | bash scripts/CALVIN_ABC_D/Seer-Large/pretrain.sh
23 | ```
24 |
25 | ### Fine-tune
26 | ```bash
27 | # Fine-tune Seer on Calvin ABC-D dataset
28 | bash scripts/CALVIN_ABC_D/Seer/finetune.sh
29 | # Fine-tune Seer-Large on Calvin ABC-D dataset
30 | bash scripts/CALVIN_ABC_D/Seer-Large/finetune.sh
31 | ```
32 |
33 | ### Train from Scratch
34 | ```bash
35 | # Train Seer on Calvin ABC-D dataset from scratch
36 | bash scripts/CALVIN_ABC_D/Seer/scratch.sh
37 | # Train Seer-Large on Calvin ABC-D dataset from scratch
38 | bash scripts/CALVIN_ABC_D/Seer-Large/scratch.sh
39 | ```
40 |
41 | ### Eval
42 | ```bash
43 | # Evaluate Seer on Calvin ABC-D benchmark
44 | bash scripts/CALVIN_ABC_D/Seer/eval.sh
45 | # Evaluate Seer-Large on Calvin ABC-D benchmark
46 | bash scripts/CALVIN_ABC_D/Seer-Large/eval.sh
47 | ```
48 |
49 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/utils/text_processing.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, Sequence
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
8 |
9 |
10 | class TextProcessor(ABC):
11 | """
12 | Base class for text tokenization or text embedding.
13 | """
14 |
15 | @abstractmethod
16 | def encode(self, strings: Sequence[str]):
17 | raise NotImplementedError
18 |
19 |
20 | class HFTokenizer(TextProcessor):
21 | def __init__(
22 | self,
23 | tokenizer_name: str,
24 | tokenizer_kwargs: Optional[dict] = {
25 | "max_length": 64,
26 | "padding": "max_length",
27 | "truncation": True,
28 | "return_tensors": "np",
29 | },
30 | encode_with_model: bool = False,
31 | ):
32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import
33 |
34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
35 | self.tokenizer_kwargs = tokenizer_kwargs
36 | self.encode_with_model = encode_with_model
37 | if self.encode_with_model:
38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name)
39 |
40 | def encode(self, strings: Sequence[str]):
41 | # this creates another nested layer with "input_ids", "attention_mask", etc.
42 | inputs = self.tokenizer(
43 | strings,
44 | **self.tokenizer_kwargs,
45 | )
46 | if self.encode_with_model:
47 | return np.array(self.model(**inputs).last_hidden_state)
48 | else:
49 | return dict(inputs)
50 |
51 |
52 | class MuseEmbedding(TextProcessor):
53 | def __init__(self):
54 | import tensorflow_hub as hub # lazy import
55 | import tensorflow_text # noqa: F401
56 |
57 | self.muse_model = hub.load(MULTI_MODULE)
58 |
59 | def encode(self, strings: Sequence[str]):
60 | with tf.device("/cpu:0"):
61 | return self.muse_model(strings).numpy()
62 |
63 |
64 | class CLIPTextProcessor(TextProcessor):
65 | def __init__(
66 | self,
67 | tokenizer_kwargs: Optional[dict] = {
68 | "max_length": 64,
69 | "padding": "max_length",
70 | "truncation": True,
71 | "return_tensors": "np",
72 | },
73 | ):
74 | from transformers import CLIPProcessor
75 |
76 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
77 | self.kwargs = tokenizer_kwargs
78 |
79 | def encode(self, strings: Sequence[str]):
80 | inputs = self.processor(
81 | text=strings,
82 | **self.kwargs,
83 | )
84 | inputs["position_ids"] = np.expand_dims(
85 | np.arange(inputs["input_ids"].shape[1]), axis=0
86 | ).repeat(inputs["input_ids"].shape[0], axis=0)
87 | return inputs
88 |
--------------------------------------------------------------------------------
/data_info/austin_buds_dataset_converted_externally_to_rlds.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "dataset_name": "austin_buds_dataset_converted_externally_to_rlds",
4 | "wrist_image": "Flip vertically & horizontally",
5 | "s_ratio": 1.0,
6 | "accumulated_num_steps": 34112
7 | },
8 | [
9 | "episodes/000000",
10 | 640
11 | ],
12 | [
13 | "episodes/000001",
14 | 827
15 | ],
16 | [
17 | "episodes/000002",
18 | 639
19 | ],
20 | [
21 | "episodes/000003",
22 | 694
23 | ],
24 | [
25 | "episodes/000004",
26 | 786
27 | ],
28 | [
29 | "episodes/000005",
30 | 549
31 | ],
32 | [
33 | "episodes/000006",
34 | 668
35 | ],
36 | [
37 | "episodes/000007",
38 | 623
39 | ],
40 | [
41 | "episodes/000008",
42 | 741
43 | ],
44 | [
45 | "episodes/000009",
46 | 671
47 | ],
48 | [
49 | "episodes/000010",
50 | 624
51 | ],
52 | [
53 | "episodes/000011",
54 | 722
55 | ],
56 | [
57 | "episodes/000012",
58 | 602
59 | ],
60 | [
61 | "episodes/000013",
62 | 645
63 | ],
64 | [
65 | "episodes/000014",
66 | 640
67 | ],
68 | [
69 | "episodes/000015",
70 | 630
71 | ],
72 | [
73 | "episodes/000016",
74 | 603
75 | ],
76 | [
77 | "episodes/000017",
78 | 788
79 | ],
80 | [
81 | "episodes/000018",
82 | 727
83 | ],
84 | [
85 | "episodes/000019",
86 | 758
87 | ],
88 | [
89 | "episodes/000020",
90 | 708
91 | ],
92 | [
93 | "episodes/000021",
94 | 670
95 | ],
96 | [
97 | "episodes/000022",
98 | 550
99 | ],
100 | [
101 | "episodes/000023",
102 | 652
103 | ],
104 | [
105 | "episodes/000024",
106 | 688
107 | ],
108 | [
109 | "episodes/000025",
110 | 634
111 | ],
112 | [
113 | "episodes/000026",
114 | 746
115 | ],
116 | [
117 | "episodes/000027",
118 | 643
119 | ],
120 | [
121 | "episodes/000028",
122 | 675
123 | ],
124 | [
125 | "episodes/000029",
126 | 734
127 | ],
128 | [
129 | "episodes/000030",
130 | 668
131 | ],
132 | [
133 | "episodes/000031",
134 | 698
135 | ],
136 | [
137 | "episodes/000032",
138 | 669
139 | ],
140 | [
141 | "episodes/000033",
142 | 623
143 | ],
144 | [
145 | "episodes/000034",
146 | 694
147 | ],
148 | [
149 | "episodes/000035",
150 | 731
151 | ],
152 | [
153 | "episodes/000036",
154 | 575
155 | ],
156 | [
157 | "episodes/000037",
158 | 636
159 | ],
160 | [
161 | "episodes/000038",
162 | 637
163 | ],
164 | [
165 | "episodes/000039",
166 | 713
167 | ],
168 | [
169 | "episodes/000040",
170 | 750
171 | ],
172 | [
173 | "episodes/000041",
174 | 751
175 | ],
176 | [
177 | "episodes/000042",
178 | 683
179 | ],
180 | [
181 | "episodes/000043",
182 | 649
183 | ],
184 | [
185 | "episodes/000044",
186 | 766
187 | ],
188 | [
189 | "episodes/000045",
190 | 676
191 | ],
192 | [
193 | "episodes/000046",
194 | 671
195 | ],
196 | [
197 | "episodes/000047",
198 | 794
199 | ],
200 | [
201 | "episodes/000048",
202 | 737
203 | ],
204 | [
205 | "episodes/000049",
206 | 714
207 | ]
208 | ]
--------------------------------------------------------------------------------
/docs/REAL-WORLD_PREPROCESS.md:
--------------------------------------------------------------------------------
1 | # Pre-process
2 | Seer exclusively utilizes the [DROID](https://droid-dataset.github.io/) dataset for pre-training. In this section, we describe the data pre-processing and transformation steps for both the [DROID](https://droid-dataset.github.io/) and [OXE](https://robotics-transformer-x.github.io/) datasets. These transformations convert the RLDS format into a standard per-step dataset format consisting of image, .npz, and .h5 files. The transformed dataset is organized as follows: /subset_name/episodes/000000/steps/0000/xxx.jpg (h5).
3 | The pre-processing step also unifies action labels across the different subsets. For example, it standardizes all control methods to delta end-effector pose control and enforces consistent robot base and end-effector origins and axes. This carefully designed alignment minimizes the confusion caused by different robots and control methods.
4 | To facilitate this process, we create a new environment, seer_pre, which is specifically used for pre-processing the DROID and OXE datasets into our desired format.
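As a quick sanity check of the converted layout described above, the following sketch (illustrative only; `subset_dir` is a placeholder, and the snippet assumes nothing beyond the `episodes/<episode_id>/steps/<timestep_id>` directory structure) counts the episodes and steps in one converted subset:
```python
import os

subset_dir = "/path/to/converted/subset_name"  # placeholder: one converted subset
episodes_dir = os.path.join(subset_dir, "episodes")
episode_ids = sorted(os.listdir(episodes_dir))
total_steps = sum(
    len(os.listdir(os.path.join(episodes_dir, ep, "steps"))) for ep in episode_ids
)
print(f"{len(episode_ids)} episodes, {total_steps} steps in {subset_dir}")
```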
5 |
6 |
7 | ## seer_pre env
8 | **(1) Env**
9 | ```bash
10 | conda create -n seer_pre python=3.10
11 | conda activate seer_pre
12 | ```
13 | **(2) Move to real_preprocess**
14 | ```bash
15 | cd ${YOUR_PATH_TO_SEER}/real_preprocess
16 | ```
17 | **(3) Third Party Packages**
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 | **(4) octo_oxe_data_utils (Optional for DROID, Required for OXE)**
22 | ```bash
23 | cd octo_oxe_data_utils
24 | python setup.py install
25 | cd ..
26 | ```
27 | **(5) Mujoco**
28 | ```bash
29 | pip install mujoco
30 | ```
31 | **(6) Torch**
32 | ```bash
33 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118
34 | ```
35 | **(7) Dlimp (Important):**
36 | We use multiprocessing to speed up data processing. However, the dataset.py shipped with Dlimp introduces randomness.
37 | Replace the dataset.py at /your_anaconda/envs/seer_pre/lib/python3.10/site-packages/dlimp/dataset.py with the one in [dlimp/dataset.py](../real_preprocess/dlimp/dataset.py).
38 |
39 | ## Run Instructions
40 | You can download the full DROID dataset (1.7TB) in RLDS format using the following command:
41 | ```bash
42 | gsutil -m cp -r gs://gresearch/robotics/droid <your_local_target_dir>
43 | ```
44 | If needed, follow the download instructions provided on the [OXE Github page](https://github.com/google-deepmind/open_x_embodiment).
45 |
46 | Preparation
47 | ```bash
48 | cd ${YOUR_PATH_TO_SEER}/real_preprocess
49 | conda activate seer_pre
50 | ```
51 | To process the DROID dataset, set the src_dir and tgt_dir paths. You can adjust the num_worker argument to specify the number of processes to use:
52 | ```bash
53 | python convert_public_droid_to_h5_per_step.py
54 | ```
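The `src_dir`, `tgt_dir`, and `num_worker` settings are edited inside the conversion script itself. As a rough, hypothetical sketch (the values are placeholders; only the names come from the description above), the assignments look like:
```python
# Placeholders only -- edit the corresponding variables inside
# convert_public_droid_to_h5_per_step.py before running it.
src_dir = "/path/to/droid_rlds"        # the downloaded DROID data in RLDS format
tgt_dir = "/path/to/droid_converted"   # output directory for the converted per-step data
num_worker = 16                        # number of parallel worker processes
```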
55 | For processing the Franka subsets in the OXE dataset, update the src_root_dir and tgt_dataset_dir paths. Similarly, adjust the num_worker argument for parallel processing:
56 | ```bash
57 | python convert_tfds_to_h5_per_step_oxe_franka.py
58 | ```
59 | To process other subsets of the OXE dataset (excluding Franka), update the src_root_dir and tgt_dataset_dir paths and set the number of worker processes:
60 | ```bash
61 | python convert_tfds_to_h5_per_step_oxe_others.py
62 | ```
63 |
--------------------------------------------------------------------------------
/data_info/.hydra/config.yaml:
--------------------------------------------------------------------------------
1 | cameras:
2 | static:
3 | _target_: vr_env.camera.static_camera.StaticCamera
4 | name: static
5 | fov: 10
6 | aspect: 1
7 | nearval: 0.01
8 | farval: 10
9 | width: 200
10 | height: 200
11 | look_at:
12 | - -0.026242351159453392
13 | - -0.0302329882979393
14 | - 0.3920000493526459
15 | look_from:
16 | - 2.871459009488717
17 | - -2.166602199425597
18 | - 2.555159848480571
19 | up_vector:
20 | - 0.4041403970338857
21 | - 0.22629790978217404
22 | - 0.8862616969685161
23 | gripper:
24 | _target_: vr_env.camera.gripper_camera.GripperCamera
25 | name: gripper
26 | fov: 75
27 | aspect: 1
28 | nearval: 0.01
29 | farval: 2
30 | width: 84
31 | height: 84
32 | tactile:
33 | _target_: vr_env.camera.tactile_sensor.TactileSensor
34 | name: tactile
35 | width: 120
36 | height: 160
37 | digit_link_ids:
38 | - 10
39 | - 12
40 | visualize_gui: false
41 | config_path: conf/digit_sensor/config_digit.yml
42 | load_dir: /work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/
43 | data_path: data
44 | save_dir: ???
45 | show_gui: false
46 | processes: 16
47 | max_episode_frames: 1
48 | save_body_infos: true
49 | set_static_cam: false
50 | env:
51 | cameras: ${cameras}
52 | show_gui: ${show_gui}
53 | use_vr: false
54 | scene:
55 | _target_: vr_env.scene.play_table_scene.PlayTableScene
56 | _recursive_: false
57 | name: calvin_scene_A
58 | data_path: ${data_path}
59 | global_scaling: 0.8
60 | euler_obs: ${robot.euler_obs}
61 | robot_base_position:
62 | - -0.34
63 | - -0.46
64 | - 0.24
65 | robot_base_orientation:
66 | - 0
67 | - 0
68 | - 0
69 | robot_initial_joint_positions:
70 | - -1.21779206
71 | - 1.03987646
72 | - 2.11978261
73 | - -2.34205014
74 | - -0.87015947
75 | - 1.64119353
76 | - 0.55344866
77 | surfaces:
78 | table:
79 | - - 0.0
80 | - -0.15
81 | - 0.46
82 | - - 0.35
83 | - -0.03
84 | - 0.46
85 | slider_left:
86 | - - -0.32
87 | - 0.05
88 | - 0.46
89 | - - -0.16
90 | - 0.12
91 | - 0.46
92 | slider_right:
93 | - - -0.05
94 | - 0.05
95 | - 0.46
96 | - - 0.13
97 | - 0.12
98 | - 0.46
99 | objects:
100 | fixed_objects:
101 | table:
102 | file: calvin_table_A/urdf/calvin_table_A.urdf
103 | initial_pos:
104 | - 0
105 | - 0
106 | - 0
107 | initial_orn:
108 | - 0
109 | - 0
110 | - 0
111 | joints:
112 | base__slide:
113 | initial_state: 0
114 | base__drawer:
115 | initial_state: 0
116 | buttons:
117 | base__button:
118 | initial_state: 0
119 | effect: led
120 | switches:
121 | base__switch:
122 | initial_state: 0
123 | effect: lightbulb
124 | lights:
125 | lightbulb:
126 | link: light_link
127 | color:
128 | - 1
129 | - 1
130 | - 0
131 | - 1
132 | led:
133 | link: led_link
134 | color:
135 | - 0
136 | - 1
137 | - 0
138 | - 1
139 | movable_objects:
140 | block_red:
141 | file: blocks/block_red_middle.urdf
142 | initial_pos: any
143 | initial_orn: any
144 | block_blue:
145 | file: blocks/block_blue_small.urdf
146 | initial_pos: any
147 | initial_orn: any
148 | block_pink:
149 | file: blocks/block_pink_big.urdf
150 | initial_pos: any
151 | initial_orn: any
152 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # workspace
2 | calvin
3 | checkpoints
4 | eval_logs
5 | evaluate
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/#use-with-ide
116 | .pdm.toml
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | #.idea/
--------------------------------------------------------------------------------
/eval_libero.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import wandb
7 | import clip
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from utils.distributed_utils import init_distributed_device, world_info_from_env
10 |
11 | from utils.eval_utils_libero import eval_one_epoch_libero_ddp
12 |
13 | # try:
14 | # from utils.eval_utils_libero import eval_one_epoch_libero_ddp as eval_one_epoch_calvin_ddp
15 | # except:
16 | # pass
17 | # from utils.eval_utils_libero import eval_one_epoch_libero_ddp as eval_one_epoch_calvin_ddp
18 | from torch.distributed.elastic.multiprocessing.errors import record
19 | # from utils.arguments_utils import get_args_and_cfg
20 | from utils.arguments_utils import get_parser
21 | from pdb import set_trace
22 | from models.seer_model import SeerAgent
23 |
24 | def random_seed(seed=42, rank=0):
25 | torch.manual_seed(seed + rank)
26 | np.random.seed(seed + rank)
27 | random.seed(seed + rank)
28 |
29 | @record
30 | def main():
31 | parser = get_parser(is_eval=True)
32 | args = parser.parse_args()
33 | args.local_rank, args.rank, args.world_size = world_info_from_env()
34 | device_id = init_distributed_device(args)
35 | print("device_id: ", device_id)
36 | random_seed(args.seed)
37 |
38 | # model
39 | model = SeerAgent(
40 | finetune_type=args.finetune_type,
41 | clip_device=device_id,
42 | vit_checkpoint_path=args.vit_checkpoint_path,
43 | sequence_length=args.sequence_length,
44 | num_resampler_query=args.num_resampler_query,
45 | num_obs_token_per_image=args.num_obs_token_per_image,
46 | calvin_input_image_size=args.calvin_input_image_size,
47 | patch_size=args.patch_size,
48 | action_pred_steps=args.action_pred_steps,
49 | obs_pred=args.obs_pred,
50 | atten_only_obs=args.atten_only_obs,
51 | attn_robot_proprio_state=args.attn_robot_proprio_state,
52 | atten_goal=args.atten_goal,
53 | atten_goal_state=args.atten_goal_state,
54 | mask_l_obs_ratio=args.mask_l_obs_ratio,
55 | transformer_layers=args.transformer_layers,
56 | hidden_dim=args.hidden_dim,
57 | transformer_heads=args.transformer_heads,
58 | phase=args.phase,
59 | gripper_width=args.gripper_width,
60 | )
61 |
62 | random_seed(args.seed, args.rank)
63 |     print(f"Start running evaluation on rank {args.rank}.")
64 |
65 | device_id = args.rank % torch.cuda.device_count()
66 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16":
67 | model = model.bfloat16()
68 | elif args.precision == "fp16":
69 | model = model.half()
70 | elif args.precision == "fp32":
71 | model = model.float()
72 | if 'vision_encoder' in args.bf16_module:
73 | model.vision_encoder.bfloat16()
74 | if "causal_transformer" in args.bf16_module:
75 | model.transformer_backbone.bfloat16()
76 | if "image_decoder" in args.bf16_module:
77 | model.image_decoder.bfloat16()
78 | model.image_decoder_obs_pred_projector.bfloat16()
79 |
80 | model.clip_model.requires_grad_(False)
81 | model.vision_encoder.requires_grad_(False)
82 | model = model.to(device_id)
83 | model._init_model_type()
84 |
85 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)
86 |
87 | if args.resume_from_checkpoint is not None:
88 | if args.rank == 0:
89 | print(f"Loading checkpoint from {args.resume_from_checkpoint}")
90 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
91 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
92 |
93 | ddp_model.eval()
94 |
95 | eval_log_dir = 'evaluate'
96 |
97 | if args.finetune_type == "libero_10":
98 | eval_one_epoch_libero_ddp(
99 | args=args,
100 | model=ddp_model,
101 | image_processor=model.image_processor,
102 | tokenizer=clip,
103 | )
104 | else:
105 | raise NotImplementedError
106 |
107 | if __name__ == "__main__":
108 | os.environ["NCCL_BLOCKING_WAIT"] = "0"
109 | main()
110 |
--------------------------------------------------------------------------------
/models/perceiver_resampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange, repeat
3 | from einops_exts import rearrange_many
4 | from torch import einsum, nn
5 |
6 |
7 | def exists(val):
8 | return val is not None
9 |
10 |
11 | def FeedForward(dim, mult=4):
12 | inner_dim = int(dim * mult)
13 | return nn.Sequential(
14 | nn.LayerNorm(dim),
15 | nn.Linear(dim, inner_dim, bias=False),
16 | nn.GELU(),
17 | nn.Linear(inner_dim, dim, bias=False),
18 | )
19 |
20 |
21 | class PerceiverAttention(nn.Module):
22 | def __init__(self, *, dim, dim_head=64, heads=8):
23 | super().__init__()
24 | self.scale = dim_head**-0.5
25 | self.heads = heads
26 | inner_dim = dim_head * heads
27 |
28 | self.norm_media = nn.LayerNorm(dim)
29 | self.norm_latents = nn.LayerNorm(dim)
30 |
31 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
32 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
33 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
34 |
35 | def forward(self, x, latents):
36 | """
37 | Args:
38 | x (torch.Tensor): image features
39 | shape (b, T, n1, D)
40 |             latents (torch.Tensor): latent features
41 | shape (b, T, n2, D)
42 | """
43 | x = self.norm_media(x)
44 | latents = self.norm_latents(latents)
45 |
46 | h = self.heads
47 |
48 | q = self.to_q(latents)
49 | kv_input = torch.cat((x, latents), dim=-2)
50 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
51 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
52 | q = q * self.scale
53 |
54 | # attention
55 | sim = einsum("... i d, ... j d -> ... i j", q, k)
56 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
57 | attn = sim.softmax(dim=-1)
58 |
59 | out = einsum("... i j, ... j d -> ... i d", attn, v)
60 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
61 | return self.to_out(out)
62 |
63 |
64 | class PerceiverResampler(nn.Module):
65 | def __init__(
66 | self,
67 | *,
68 | dim,
69 | depth=6,
70 | dim_head=64,
71 | heads=8,
72 | num_latents=64,
73 | max_num_media=None,
74 | max_num_frames=None,
75 | ff_mult=4,
76 | ):
77 | super().__init__()
78 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
79 | self.frame_embs = (
80 | nn.Parameter(torch.randn(max_num_frames, dim))
81 | if exists(max_num_frames)
82 | else None
83 | )
84 | self.media_time_embs = (
85 | nn.Parameter(torch.randn(max_num_media, 1, dim))
86 | if exists(max_num_media)
87 | else None
88 | )
89 |
90 | self.layers = nn.ModuleList([])
91 | for _ in range(depth):
92 | self.layers.append(
93 | nn.ModuleList(
94 | [
95 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
96 | FeedForward(dim=dim, mult=ff_mult),
97 | ]
98 | )
99 | )
100 |
101 | self.norm = nn.LayerNorm(dim)
102 |
103 | def forward(self, x):
104 | """
105 | Args:
106 | x (torch.Tensor): image features
107 | shape (b, T, F, v, D)
108 | Returns:
109 | shape (b, T, n, D) where n is self.num_latents
110 | """
111 | b, T, F, v = x.shape[:4]
112 |
113 | # frame and media time embeddings
114 | if exists(self.frame_embs):
115 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
116 | x = x + frame_embs
117 | x = rearrange(
118 | x, "b T F v d -> b T (F v) d"
119 | ) # flatten the frame and spatial dimensions
120 | if exists(self.media_time_embs):
121 | x = x + self.media_time_embs[:T]
122 |
123 | # blocks
124 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
125 | for attn, ff in self.layers:
126 | latents = attn(x, latents) + latents
127 | latents = ff(latents) + latents
128 | return self.norm(latents)
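# --- Illustrative usage sketch (not part of the original module) ---
# A resampler with `num_latents` query tokens maps per-frame patch features to a
# fixed number of latent tokens per timestep, e.g.:
#   resampler = PerceiverResampler(dim=768, depth=3, num_latents=6)
#   feats = torch.randn(2, 4, 1, 197, 768)  # (batch, time, frames, patch tokens, dim)
#   latents = resampler(feats)              # -> shape (2, 4, 6, 768)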
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Predictive Inverse Dynamics Models are Scalable Learners for Robotic Manipulation
4 |
5 |
6 |
10 |
11 |
12 | https://github.com/user-attachments/assets/49036e84-c397-4589-9024-efb05b14efa0
13 |
14 |
15 |
16 |
17 | ## :books: Table of Contents:
18 | 1. [Highlights](#high)
19 | 2. [Getting Started](#start)
20 | - [Simulation](#simulation)
21 | - [Real-World](#real-world)
22 | 3. [Checkpoints](#checkpoints)
23 | 4. [TODO List](#todos)
24 | 5. [License](#license)
25 | 6. [Citation](#citation)
26 | 7. [Acknowledgment](#acknowledgment)
27 |
28 | ## :fire: Highlights
29 |
30 |
31 | - :trophy: **SOTA simulation performance:** Seer achieves state-of-the-art performance on the CALVIN ABC-D and LIBERO-LONG simulation benchmarks.
32 | - :muscle: **Impressive real-world performance:** Seer demonstrates strong effectiveness and generalization across diverse real-world downstream tasks.
33 |
34 | ## :door: Getting Started
35 | We provide step-by-step guidance for running Seer in simulations and real-world experiments.
36 | Follow the specific instructions for a seamless setup.
37 |
38 | ### Simulation
39 | #### CALVIN ABC-D
40 | - [Installation](docs/CALVIN_ABC-D_INSTALL.md)
41 | - [Running Code](docs/CALVIN_ABC-D_RUN.md)
42 | #### LIBERO LONG
43 | - [Installation](docs/LIBERO_LONG_INSTALL.md)
44 | - [Running Code](docs/LIBERO_LONG_RUN.md)
45 | ### Real-World
46 | #### Real-World (Quick Training w & w/o pre-training)
47 | For users aiming to train Seer from scratch or fine-tune it, we provide comprehensive instructions for environment setup, downstream task data preparation, training, and deployment.
48 | - [Installation](docs/REAL-WORLD_INSTALL.md)
49 | - [Post-processing](docs/REAL-WORLD_POSTPROCESS.md)
50 | - [Fine-tuning & Scratch](docs/REAL-WORLD_FT_SC.md)
51 | - [Inference](docs/REAL-WORLD_INFERENCE.md)
52 |
53 | #### Real-World (Pre-training)
54 | This section details the pre-training process of Seer in real-world experiments, including environment setup, dataset preparation, and training procedures. Downstream task processing and fine-tuning are covered in [Real-World (Quick Training w & w/o pre-training)](#real-world-qs).
55 | - [Installation](docs/REAL-WORLD_INSTALL.md)
56 | - [Pre-processing](docs/REAL-WORLD_PREPROCESS.md)
57 | - [Pre-training](docs/REAL-WORLD_PRETRAIN.md)
58 |
59 |
60 | ## :pencil2: Checkpoints
61 | Relevant checkpoints are available on the [website](https://drive.google.com/drive/folders/1F3IE95z2THAQ_lt3DKUFdRGc86Thsnc7?usp=sharing).
62 | |Model|Checkpoint|
63 | |:------:|:------:|
64 | |CALVIN ABC-D|[Seer](https://drive.google.com/drive/folders/17Gv9snGCkViuhHmzN3eTWlI0tMfGSGT3?usp=sharing) (Avg.Len. : 3.98) / [Seer Large](https://drive.google.com/drive/folders/1AFabqfDEi69oMo0FTGhEiH2QSRLYBR9r?usp=drive_link) (Avg.Len. : 4.30)|
65 | |Real-World|[Seer (Droid Pre-trained)](https://drive.google.com/drive/folders/1rT8JKLhJGIo97jfYUm2JiFUrogOq-dgJ?usp=drive_link)|
66 |
67 | ## 📆 TODO
68 | - [x] Release real-world experiment code.
69 | - [x] Release CALVIN ABC-D experiment code (Seer).
70 | - [x] Release the evaluation code of Seer-Large on CALVIN ABC-D experiment.
71 | - [x] Release the training code of Seer-Large on CALVIN ABC-D experiment.
72 | - [x] Release LIBERO-LONG experiment code.
73 | - [ ] Release simpleseer, a lightweight codebase for quick from-scratch training and deployment.
74 |
75 | ## License
76 |
77 | All assets and code are under the [Apache 2.0 license](./LICENSE) unless specified otherwise.
78 |
79 | ## Citation
80 | If you find the project helpful for your research, please consider citing our paper:
81 | ```bibtex
82 | @article{tian2024predictive,
83 | title={Predictive Inverse Dynamics Models are Scalable Learners for Robotic Manipulation},
84 | author={Tian, Yang and Yang, Sizhe and Zeng, Jia and Wang, Ping and Lin, Dahua and Dong, Hao and Pang, Jiangmiao},
85 | journal={arXiv preprint arXiv:2412.15109},
86 | year={2024}
87 | }
88 | ```
89 |
90 | ## Acknowledgment
91 | This project builds upon [GR-1](https://github.com/bytedance/GR-1) and [Roboflamingo](https://github.com/RoboFlamingo/RoboFlamingo). We thank these teams for their open-source contributions.
92 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/obs_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains observation-level transforms used in the octo data pipeline. These transforms operate on the
3 | "observation" dictionary, and are applied at a per-frame level.
4 | """
5 | from typing import Mapping, Tuple, Union
6 |
7 | from absl import logging
8 | import dlimp as dl
9 | import tensorflow as tf
10 |
11 |
12 | def augment(
13 | obs: dict, seed: tf.Tensor, augment_kwargs: Union[dict, Mapping[str, dict]]
14 | ) -> dict:
15 | """Augments images, skipping padding images."""
16 | image_names = {key[6:] for key in obs if key.startswith("image_")}
17 |
18 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed
19 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image
20 | # name to augmentation dict)
21 | if "augment_order" in augment_kwargs:
22 | augment_kwargs = {name: augment_kwargs for name in image_names}
23 |
24 | for i, name in enumerate(image_names):
25 | if name not in augment_kwargs:
26 | continue
27 | kwargs = augment_kwargs[name]
28 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}")
29 | obs[f"image_{name}"] = tf.cond(
30 | obs["pad_mask_dict"][f"image_{name}"],
31 | lambda: dl.transforms.augment_image(
32 | obs[f"image_{name}"],
33 | **kwargs,
34 | seed=seed + i, # augment each image differently
35 | ),
36 | lambda: obs[f"image_{name}"], # skip padding images
37 | )
38 |
39 | return obs
40 |
41 |
42 | def decode_and_resize(
43 | obs: dict,
44 | resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]],
45 | depth_resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]],
46 | ) -> dict:
47 | """Decodes images and depth images, and then optionally resizes them."""
48 | # just gets the part after "image_" or "depth_"
49 | image_names = {key[6:] for key in obs if key.startswith("image_")}
50 | depth_names = {key[6:] for key in obs if key.startswith("depth_")}
51 |
52 | if isinstance(resize_size, tuple):
53 | resize_size = {name: resize_size for name in image_names}
54 | if isinstance(depth_resize_size, tuple):
55 | depth_resize_size = {name: depth_resize_size for name in depth_names}
56 |
57 | for name in image_names:
58 | if name not in resize_size:
59 | logging.warning(
60 | f"No resize_size was provided for image_{name}. This will result in 1x1 "
61 | "padding images, which may cause errors if you mix padding and non-padding images."
62 | )
63 | image = obs[f"image_{name}"]
64 | if image.dtype == tf.string:
65 | if tf.strings.length(image) == 0:
66 | # this is a padding image
67 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8)
68 | else:
69 | image = tf.io.decode_image(
70 | image, expand_animations=False, dtype=tf.uint8
71 | )
72 | elif image.dtype != tf.uint8:
73 | raise ValueError(
74 | f"Unsupported image dtype: found image_{name} with dtype {image.dtype}"
75 | )
76 | if name in resize_size:
77 | image = dl.transforms.resize_image(image, size=resize_size[name])
78 | obs[f"image_{name}"] = image
79 |
80 | for name in depth_names:
81 | if name not in depth_resize_size:
82 | logging.warning(
83 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 "
84 | "padding depth images, which may cause errors if you mix padding and non-padding images."
85 | )
86 | depth = obs[f"depth_{name}"]
87 | if depth.dtype == tf.string:
88 | if tf.strings.length(depth) == 0:
89 | # this is a padding image
90 | depth = tf.zeros(
91 | (*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32
92 | )
93 | else:
94 | depth = tf.io.decode_image(
95 | depth, expand_animations=False, dtype=tf.float32
96 | )[..., 0]
97 | elif depth.dtype != tf.float32:
98 | raise ValueError(
99 | f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}"
100 | )
101 | if name in depth_resize_size:
102 | depth = dl.transforms.resize_depth_image(
103 | depth, size=depth_resize_size[name]
104 | )
105 | obs[f"depth_{name}"] = depth
106 |
107 | return obs
108 |
--------------------------------------------------------------------------------
/eval_calvin.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import wandb
7 | import clip
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from torch.distributed.elastic.multiprocessing.errors import record
10 | from models.seer_model import SeerAgent
11 | from utils.data_utils import get_calvin_dataset, get_calvin_val_dataset
12 | from utils.distributed_utils import init_distributed_device, world_info_from_env
13 | from utils.eval_utils_calvin import eval_one_epoch_calvin_ddp
14 | from utils.arguments_utils import get_parser
15 |
16 |
17 | def random_seed(seed=42, rank=0):
18 | torch.manual_seed(seed + rank)
19 | np.random.seed(seed + rank)
20 | random.seed(seed + rank)
21 |
22 | @record
23 | def main():
24 | parser = get_parser(is_eval=True)
25 | args = parser.parse_args()
26 | if args.save_checkpoints_to_wandb and args.save_checkpoint and not args.report_to_wandb:
27 | raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
28 | if args.offline:
29 | os.environ["WANDB_MODE"] = "offline"
30 | os.environ["TRANSFORMERS_OFFLINE"] = "1"
31 | args.local_rank, args.rank, args.world_size = world_info_from_env()
32 | device_id = init_distributed_device(args)
33 | print("device_id: ", device_id)
34 | random_seed(args.seed)
35 | model = SeerAgent(
36 | finetune_type=args.finetune_type,
37 | clip_device=device_id,
38 | vit_checkpoint_path=args.vit_checkpoint_path,
39 | sequence_length=args.sequence_length,
40 | num_resampler_query=args.num_resampler_query,
41 | num_obs_token_per_image=args.num_obs_token_per_image,
42 | calvin_input_image_size=args.calvin_input_image_size,
43 | patch_size=args.patch_size,
44 | action_pred_steps=args.action_pred_steps,
45 | obs_pred=args.obs_pred,
46 | atten_only_obs=args.atten_only_obs,
47 | attn_robot_proprio_state=args.attn_robot_proprio_state,
48 | atten_goal=args.atten_goal,
49 | atten_goal_state=args.atten_goal_state,
50 | mask_l_obs_ratio=args.mask_l_obs_ratio,
51 | transformer_layers=args.transformer_layers,
52 | hidden_dim=args.hidden_dim,
53 | transformer_heads=args.transformer_heads,
54 | phase=args.phase,
55 | gripper_width=args.gripper_width,
56 | )
57 | calvin_dataset = get_calvin_dataset(args, model.image_processor, clip, epoch=0)
58 | random_seed(args.seed, args.rank)
59 |     print(f"Start running evaluation on rank {args.rank}.")
60 | if args.rank == 0 and args.report_to_wandb:
61 | wandb.init(
62 | project=args.wandb_project,
63 | entity=args.wandb_entity,
64 | name=args.run_name,
65 | config=vars(args),
66 | )
67 | device_id = args.rank % torch.cuda.device_count()
68 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16":
69 | model = model.bfloat16()
70 | elif args.precision == "fp16":
71 | model = model.half()
72 | elif args.precision == "fp32":
73 | model = model.float()
74 | if 'vision_encoder' in args.bf16_module:
75 | model.vision_encoder.bfloat16()
76 | if "causal_transformer" in args.bf16_module:
77 | model.transformer_backbone.bfloat16()
78 | if "image_decoder" in args.bf16_module:
79 | model.image_decoder.bfloat16()
80 | model.image_decoder_obs_pred_projector.bfloat16()
81 | model.clip_model.requires_grad_(False)
82 | model.vision_encoder.requires_grad_(False)
83 | model = model.to(device_id)
84 | model._init_model_type()
85 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)
86 | if args.resume_from_checkpoint is not None:
87 | if args.rank == 0:
88 | print(f"Loading checkpoint from {args.resume_from_checkpoint}")
89 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
90 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
91 | ddp_model.eval()
92 | eval_log_dir = 'evaluate'
93 | if args.finetune_type == "calvin":
94 | eval_one_epoch_calvin_ddp(
95 | args=args,
96 | model=ddp_model,
97 | image_processor=model.image_processor,
98 | tokenizer=clip,
99 | dataset_path=args.calvin_dataset,
100 | future_act_len=args.future_act_len,
101 | eval_log_dir=eval_log_dir,
102 | debug=args.visualize,
103 | reset=args.reset,
104 | diverse_inst=args.diverse_inst
105 | )
106 | else:
107 | raise NotImplementedError
108 |
109 | if __name__ == "__main__":
110 | os.environ["NCCL_BLOCKING_WAIT"] = "0"
111 | main()
112 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/traj_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains trajectory transforms used in the octo data pipeline. Trajectory transforms operate on a dictionary
3 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory
4 | length).
5 | """
6 | import logging
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def chunk_act_obs(
12 | traj: dict,
13 | window_size: int,
14 | future_action_window_size: int = 0,
15 | ) -> dict:
16 | """Chunks actions and observations into the given window_size.
17 |
18 | "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1`
19 | observations from the past and the current observation. "action" is given a new axis (at index 1) of size
20 | `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current
21 | action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and
22 | indicates whether an observation should be considered padding (i.e. if it would have come from a timestep
23 | before the start of the trajectory).
24 | """
25 | traj_len = tf.shape(traj["action"])[0]
26 | action_dim = traj["action"].shape[-1]
27 | chunk_indices = tf.broadcast_to(
28 | tf.range(-window_size + 1, 1), [traj_len, window_size]
29 | ) + tf.broadcast_to(tf.range(traj_len)[:, None], [traj_len, window_size])
30 |
31 | action_chunk_indices = tf.broadcast_to(
32 | tf.range(-window_size + 1, 1 + future_action_window_size),
33 | [traj_len, window_size + future_action_window_size],
34 | ) + tf.broadcast_to(
35 | tf.range(traj_len)[:, None],
36 | [traj_len, window_size + future_action_window_size],
37 | )
38 |
39 | floored_chunk_indices = tf.maximum(chunk_indices, 0)
40 |
41 | if "timestep" in traj["task"]:
42 | goal_timestep = traj["task"]["timestep"]
43 | else:
44 | goal_timestep = tf.fill([traj_len], traj_len - 1)
45 |
46 | floored_action_chunk_indices = tf.minimum(
47 | tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]
48 | )
49 |
50 | traj["observation"] = tf.nest.map_structure(
51 | lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]
52 | )
53 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices)
54 |
55 | # indicates whether an entire observation is padding
56 | traj["observation"]["pad_mask"] = chunk_indices >= 0
57 |
58 | # if no absolute_action_mask was provided, assume all actions are relative
59 | if "absolute_action_mask" not in traj and future_action_window_size > 0:
60 | logging.warning(
61 | "future_action_window_size > 0 but no absolute_action_mask was provided. "
62 | "Assuming all actions are relative for the purpose of making neutral actions."
63 | )
64 | absolute_action_mask = traj.get(
65 | "absolute_action_mask", tf.zeros([traj_len, action_dim], dtype=tf.bool)
66 | )
67 | neutral_actions = tf.where(
68 | absolute_action_mask[:, None, :],
69 | traj["action"], # absolute actions are repeated (already done during chunking)
70 | tf.zeros_like(traj["action"]), # relative actions are zeroed
71 | )
72 |
73 | # actions past the goal timestep become neutral
74 | action_past_goal = action_chunk_indices > goal_timestep[:, None]
75 | traj["action"] = tf.where(
76 | action_past_goal[:, :, None], neutral_actions, traj["action"]
77 | )
78 | return traj
79 |
80 |
81 | def subsample(traj: dict, subsample_length: int) -> dict:
82 | """Subsamples trajectories to the given length."""
83 | traj_len = tf.shape(traj["action"])[0]
84 | if traj_len > subsample_length:
85 | indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length]
86 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj)
87 | return traj
88 |
89 |
90 | def add_pad_mask_dict(traj: dict) -> dict:
91 | """Adds a dictionary indicating which elements of the observation/task should be treated as padding.
92 |
93 | traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding}
94 | """
95 | traj_len = tf.shape(traj["action"])[0]
96 | for key in ["observation", "task"]:
97 | pad_mask_dict = {}
98 | for subkey in traj[key]:
99 | if traj[key][subkey].dtype == tf.string:
100 | # handles "language_instruction", "image_*", and "depth_*"
101 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0
102 | else:
103 | # all other keys should not be treated as padding
104 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool)
105 | traj[key]["pad_mask_dict"] = pad_mask_dict
106 | return traj
107 |
--------------------------------------------------------------------------------
/docs/REAL-WORLD_POSTPROCESS.md:
--------------------------------------------------------------------------------
1 | # Post-process
2 | Self-collected data and pre-training datasets often exhibit certain challenges that can negatively affect model performance. Below, we outline these issues and provide standardized solutions for data formatting and post-processing.
3 | * :warning: **Varied Action Labels:**
4 | Different embodiments, and sometimes even identical ones, may use diverse action labels such as:
5 | * **Absolute target joint positions** (qpos)
6 |    * **Absolute target joint velocities** (qvel)
7 | * **Absolute end-effector poses** (ee-pose)
8 | * **Delta target end-effector poses** (delta ee-pose)
9 | Furthermore, rotation representations may vary, including quaternions, Euler angles, rotation vectors, and rotation matrices.
10 | * :warning: **Jittering and Long Pauses:** Inexperienced data collectors often hesitate during collection, introducing long pauses or jittering. Without proper filtering, such data significantly degrades model performance.
11 | * :warning: **Quick Gripper Open/Close Actions:** A frequency mismatch between camera capture and gripper control often results in abrupt changes in gripper states, especially during grasping or releasing motions.
12 |
13 | To address these issues, we recommend a uniform, clear, and effective format for saving self-collected data and provide tools for post-processing.
14 |
15 | ## :exclamation: Data Format
16 | For each task, we collect 100 demos. The recommended directory structure is:
17 | ```
18 | 0000 (exp_id)
19 | |—— 000000 (episode_id)
20 | |—— steps
21 | |—— 0000 (timestep_id, start)
22 | |—— image_primary.jpg (Eye-on-Base camera rgb image)
23 | |—— image_wrist.jpg (Eye-on-Hand camera rgb image)
24 | └── other.npz (robot state, language, action)
25 | |—— ......
26 | └── xxxx (timestep_id, end)
27 | |—— 000001 (episode_id)
28 | |—— steps
29 | |—— ......
30 | |—— ......
31 | └── 000099 (episode_id)
32 | |—— steps
33 | |—— ......
34 | ```
35 | ### File Details:
36 | * **image_primary.jpg** and **image_wrist.jpg**: RGB images saved at a resolution of 640 x 480 pixels.
37 | * **other.npz**: Contains key robot metadata. An example of the saved format is:
38 | ```python
39 | # at each timestep i
40 | npz_path = "other.npz"
41 | 
42 | # absolute current gripper pose in robot space: position + Euler angles, in meters and radians
43 | gripper_pose = np.array([x, y, z, euler_x, euler_y, euler_z])
44 | 
45 | # absolute current gripper open state (1.0 if the gripper is open, otherwise -1.0)
46 | gripper_open_state = np.array([1.0]) if gripper_is_open else np.array([-1.0])
47 |
48 | # absolute current joints position (qpos)
49 | joints = np.array([q0, q1, q2, q3, q4, q5, q6])
50 |
51 | # language instruction
52 | language_instruction = "Pick the apple."
53 | 
54 | # absolute target pose action label (target_gripper_open_or_close is 1.0 if targeting open, else -1.0)
55 | action_gripper_pose = np.array([target_x, target_y, target_z, target_euler_x, target_euler_y, target_euler_z, target_gripper_open_or_close])
56 |
57 | # delta pose action label
58 | delta_cur_2_last_action = np.array([target_delta_x, target_delta_y, target_delta_z, target_delta_euler_x, target_delta_euler_y, target_delta_euler_z, target_gripper_open_or_close])
59 |
60 | # save npz
61 | np.savez_compressed(
62 | npz_path,
63 | joints=joints,
64 | gripper_pose=gripper_pose,
65 | gripper_open_state=gripper_open_state,
66 | action_gripper_pose=action_gripper_pose,
67 | delta_cur_2_last_action=delta_cur_2_last_action,
68 | language_instruction=language_instruction,
69 | )
70 | ```
71 | For most robotic systems, all of the above metadata except `delta_cur_2_last_action` can be extracted directly. We provide a helper function to compute the delta pose action label in the [script](../utils/real_ft_data.py):
72 | ```python
73 | compute_delta_action(
74 | data_list,
75 | )
76 | ```
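For reference, below is a minimal sketch of how such a delta label could be derived from two consecutive absolute gripper poses. It is illustrative only (not the implementation in `real_ft_data.py`); the Euler convention (`xyz`) and the use of the previous timestep as the reference pose are assumptions.
```python
import numpy as np
from scipy.spatial.transform import Rotation as R

def delta_pose_label(prev_pose, cur_pose, target_gripper_open_or_close):
    """prev_pose / cur_pose: [x, y, z, euler_x, euler_y, euler_z] in meters and radians."""
    # translation difference
    delta_xyz = cur_pose[:3] - prev_pose[:3]
    # relative rotation, expressed again as Euler angles (convention assumed to be "xyz")
    delta_rot = R.from_euler("xyz", cur_pose[3:]) * R.from_euler("xyz", prev_pose[3:]).inv()
    delta_euler = delta_rot.as_euler("xyz")
    return np.concatenate([delta_xyz, delta_euler, [target_gripper_open_or_close]])
```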
77 |
78 | ## :star: Post-processing Self-collected Data
79 | * **Filtering Jitter and Pauses:** To filter out jittering and long pauses, use the following function in the [script](../utils/real_ft_data.py):
80 | ```python
81 | filter_real_data(
82 | exp_id,
83 | root_path, # path to your raw data
84 |     save_data_path, # a desired path to save the filtered data
85 | save_gif_path # a desired path to save the filtered gif (only for visualization and debugging)
86 | )
87 | ```
88 | * **Data Augmentation for Gripper Actions:** To augment data by increasing sampling ratios during gripper open/close events, use the following function in the same [script](../utils/real_ft_data.py):
89 | ```python
90 | make_aug_short_real_dataset_info(
91 |     root_path, # path to your filtered data
92 |     root_info_path, # path to your data info directory, e.g. xxx/Seer/data_info
93 | dataset_name, # your dataset name, e.g. "ft"
94 | select_ratio=1.0,
95 | sequence_length=7,
96 | action_pred_steps=3,
97 | replicate_steps=10
98 | )
99 | ```
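Putting the two steps together, a minimal driver could look like the sketch below. The paths and `exp_id` are placeholders, and the import path simply mirrors the script linked above; argument names follow the calls shown in this document.
```python
from utils.real_ft_data import filter_real_data, make_aug_short_real_dataset_info

exp_id = "0000"
# 1) remove jittering frames and long pauses from the raw recordings
filter_real_data(
    exp_id,
    root_path="/path/to/raw_data",
    save_data_path="/path/to/filtered_data",
    save_gif_path="/path/to/filtered_gifs",
)
# 2) build dataset info with higher sampling ratios around gripper open/close events
make_aug_short_real_dataset_info(
    root_path="/path/to/filtered_data",
    root_info_path="/path/to/Seer/data_info",
    dataset_name="ft",
    select_ratio=1.0,
    sequence_length=7,
    action_pred_steps=3,
    replicate_steps=10,
)
```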
100 |
101 |
--------------------------------------------------------------------------------
/data_info/.hydra/hydra.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: /work/dlclarge2/meeso-lfp/calvin_data/task_A_A/
4 | sweep:
5 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6 | subdir: ${hydra.job.num}
7 | launcher:
8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9 | sweeper:
10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11 | max_batch_size: null
12 | help:
13 | app_name: ${hydra.job.name}
14 | header: '${hydra.help.app_name} is powered by Hydra.
15 |
16 | '
17 | footer: 'Powered by Hydra (https://hydra.cc)
18 |
19 | Use --hydra-help to view Hydra specific help
20 |
21 | '
22 | template: '${hydra.help.header}
23 |
24 | == Configuration groups ==
25 |
26 | Compose your configuration from those groups (group=option)
27 |
28 |
29 | $APP_CONFIG_GROUPS
30 |
31 |
32 | == Config ==
33 |
34 | Override anything in the config (foo.bar=value)
35 |
36 |
37 | $CONFIG
38 |
39 |
40 | ${hydra.help.footer}
41 |
42 | '
43 | hydra_help:
44 | template: 'Hydra (${hydra.runtime.version})
45 |
46 | See https://hydra.cc for more info.
47 |
48 |
49 | == Flags ==
50 |
51 | $FLAGS_HELP
52 |
53 |
54 | == Configuration groups ==
55 |
56 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled
57 | to command line)
58 |
59 |
60 | $HYDRA_CONFIG_GROUPS
61 |
62 |
63 | Use ''--cfg hydra'' to Show the Hydra config.
64 |
65 | '
66 | hydra_help: ???
67 | hydra_logging:
68 | version: 1
69 | formatters:
70 | colorlog:
71 | (): colorlog.ColoredFormatter
72 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s'
73 | handlers:
74 | console:
75 | class: logging.StreamHandler
76 | formatter: colorlog
77 | stream: ext://sys.stdout
78 | root:
79 | level: INFO
80 | handlers:
81 | - console
82 | disable_existing_loggers: false
83 | job_logging:
84 | version: 1
85 | formatters:
86 | simple:
87 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
88 | colorlog:
89 | (): colorlog.ColoredFormatter
90 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
91 | - %(message)s'
92 | log_colors:
93 | DEBUG: purple
94 | INFO: green
95 | WARNING: yellow
96 | ERROR: red
97 | CRITICAL: red
98 | handlers:
99 | console:
100 | class: logging.StreamHandler
101 | formatter: colorlog
102 | stream: ext://sys.stdout
103 | file:
104 | class: logging.FileHandler
105 | formatter: simple
106 | filename: ${hydra.job.name}.log
107 | root:
108 | level: INFO
109 | handlers:
110 | - console
111 | - file
112 | disable_existing_loggers: false
113 | env: {}
114 | searchpath: []
115 | callbacks: {}
116 | output_subdir: .hydra
117 | overrides:
118 | hydra:
119 | - hydra.run.dir=/work/dlclarge2/meeso-lfp/calvin_data/task_A_A/
120 | task:
121 | - load_dir=/work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/
122 | - set_static_cam=false
123 | - processes=16
124 | - +scene=calvin_scene_A
125 | - show_gui=false
126 | job:
127 | name: datarenderer
128 | override_dirname: +scene=calvin_scene_A,load_dir=/work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-10-06/16-23-57/,processes=16,set_static_cam=false,show_gui=false
129 | id: ???
130 | num: ???
131 | config_name: config_rendering
132 | env_set: {}
133 | env_copy: []
134 | config:
135 | override_dirname:
136 | kv_sep: '='
137 | item_sep: ','
138 | exclude_keys: []
139 | runtime:
140 | version: 1.1.1
141 | cwd: /home/meeso/VRENVCalvinRender/vr_env
142 | config_sources:
143 | - path: hydra.conf
144 | schema: pkg
145 | provider: hydra
146 | - path: /home/meeso/VRENVCalvinRender/conf
147 | schema: file
148 | provider: main
149 | - path: hydra_plugins.hydra_colorlog.conf
150 | schema: pkg
151 | provider: hydra-colorlog
152 | - path: ''
153 | schema: structured
154 | provider: schema
155 | choices:
156 | scene: calvin_scene_A
157 | cameras: static_gripper_tactile
158 | cameras/cameras@cameras.tactile: tactile
159 | cameras/cameras@cameras.gripper: gripper
160 | cameras/cameras@cameras.static: static
161 | hydra/env: default
162 | hydra/callbacks: null
163 | hydra/job_logging: colorlog
164 | hydra/hydra_logging: colorlog
165 | hydra/hydra_help: default
166 | hydra/help: default
167 | hydra/sweeper: basic
168 | hydra/launcher: basic
169 | hydra/output: default
170 | verbose: false
171 |
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/oxe/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | from typing import Any, Dict, List, Sequence, Tuple, Union
4 |
5 | from octo_oxe_data_utils.oxe.oxe_dataset_configs import ActionEncoding, OXE_DATASET_CONFIGS
6 | from octo_oxe_data_utils.oxe.oxe_dataset_mixes import OXE_NAMED_MIXES
7 | from octo_oxe_data_utils.oxe.oxe_standardization_transforms import OXE_STANDARDIZATION_TRANSFORMS
8 | from octo_oxe_data_utils.utils.data_utils import NormalizationType
9 |
10 |
11 | def make_oxe_dataset_kwargs(
12 | name: str,
13 | data_dir: str,
14 | load_camera_views: Sequence[str] = ("primary",),
15 | load_depth: bool = False,
16 | load_proprio: bool = True,
17 | load_language: bool = True,
18 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL,
19 | ) -> Dict[str, Any]:
20 | """Generates dataset kwargs for a given dataset from Open X-Embodiment. The returned kwargs can be passed
21 | directly into `octo.data.dataset.make_dataset_from_rlds`.
22 |
23 | Args:
24 | name: Name of the dataset to load. See `oxe_dataset_configs.py` for available datasets.
25 | data_dir: Base data directory that contains the dataset.
26 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views.
27 | load_depth: If True, loads corresponding depth channels for each RGB channel.
28 | load_proprio: If True, loads proprioceptive information.
29 | load_language: If True, loads language instructions.
30 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions.
31 | """
32 | dataset_kwargs = copy.deepcopy(OXE_DATASET_CONFIGS[name])
33 | if dataset_kwargs["action_encoding"] is not ActionEncoding.EEF_POS:
34 | raise ValueError(
35 | f"Cannot load {name} since only EEF pose delta action encoding is supported."
36 | )
37 |
38 | # with EEF_POS actions, only the last action dimension (the gripper) is absolute
39 | dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True]
40 |
41 | # we also want to skip normalizing the gripper action
42 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False]
43 |
44 | # adjust loaded camera views
45 | if missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"])):
46 | raise ValueError(
47 | f"Cannot load {name} with views {missing_keys} since they are not available."
48 | )
49 | dataset_kwargs["image_obs_keys"] = {
50 | k: v
51 | for k, v in dataset_kwargs["image_obs_keys"].items()
52 | if k in load_camera_views
53 | }
54 | dataset_kwargs["depth_obs_keys"] = {
55 | k: v
56 | for k, v in dataset_kwargs["depth_obs_keys"].items()
57 | if k in load_camera_views
58 | }
59 |
60 | if not load_depth:
61 | dataset_kwargs.pop("depth_obs_keys")
62 | if not load_proprio:
63 | dataset_kwargs.pop("state_obs_keys")
64 |
65 | if load_language:
66 | dataset_kwargs["language_key"] = "language_instruction"
67 |
68 | dataset_kwargs[
69 | "action_proprio_normalization_type"
70 | ] = action_proprio_normalization_type
71 |
72 | del dataset_kwargs["state_encoding"]
73 | del dataset_kwargs["action_encoding"]
74 |
75 | dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[name]
76 |
77 | return {"name": name, "data_dir": data_dir, **dataset_kwargs}
78 |
79 |
80 | def make_oxe_dataset_kwargs_and_weights(
81 | data_mix: Union[str, Sequence[Tuple[str, float]]],
82 | data_dir: str,
83 | load_camera_views: Sequence[str] = ("primary",),
84 | load_depth: bool = False,
85 | load_proprio: bool = True,
86 | load_language: bool = True,
87 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL,
88 | ) -> Tuple[List[Dict[str, Any]], List[float]]:
89 | """
90 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs
91 | and weights can be passed directly into `octo.data.dataset.make_interleaved_dataset`.
92 |
93 | Args:
94 | data_mix: List of (dataset name, sampling weight) tuples, or a string specifying a pre-defined mix to
95 | load from `OXE_NAMED_MIXES`.
96 | data_dir: Base data directory that contains the datasets.
97 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views.
98 | load_depth: If True, loads corresponding depth channels for each RGB channel.
99 | load_proprio: If True, loads proprioceptive information.
100 | load_language: If True, loads language instructions.
101 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions.
102 | Returns:
103 | Tuple of (dataset_kwargs_list, sampling weights).
104 | """
105 | if isinstance(data_mix, str):
106 | data_mix = OXE_NAMED_MIXES[data_mix]
107 |
108 | filtered_datasets, included_dataset_names = [], []
109 | for name, weight in data_mix:
110 | if name not in included_dataset_names:
111 | filtered_datasets.append((name, weight))
112 | included_dataset_names.append(name)
113 | else:
114 | logging.warning(f"Skipping duplicate: {(name, weight)}.")
115 | data_mix = filtered_datasets
116 |
117 | data_kwargs_list, weights = [], []
118 | for name, weight in data_mix:
119 | try:
120 | data_kwargs_list.append(
121 | make_oxe_dataset_kwargs(
122 | name,
123 | data_dir,
124 | load_camera_views,
125 | load_depth,
126 | load_proprio,
127 | load_language,
128 | action_proprio_normalization_type,
129 | )
130 | )
131 | weights.append(weight)
132 | except ValueError as e:
133 | logging.warning(f"Skipping {name} due to error: {e}")
134 |
135 | return data_kwargs_list, weights
136 |
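# Illustrative usage (not part of the original module): "bridge" is one of the named
# mixes defined in `oxe_dataset_mixes.py`; the data directory below is a placeholder.
#
# dataset_kwargs_list, weights = make_oxe_dataset_kwargs_and_weights(
#     "bridge",
#     data_dir="/path/to/open_x_embodiment",
#     load_camera_views=("primary",),
# )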
--------------------------------------------------------------------------------
/utils/distributed_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Util functions for setting up distributed training.
3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py
4 | """
5 | import os
6 | import torch
7 | from datetime import timedelta
8 |
9 | def is_global_master(args):
10 | return args.rank == 0
11 |
12 | def is_local_master(args):
13 | return args.local_rank == 0
14 |
15 | def is_master(args, local=False):
16 | return is_local_master(args) if local else is_global_master(args)
17 |
18 | def is_using_distributed():
19 | if "WORLD_SIZE" in os.environ:
20 | return int(os.environ["WORLD_SIZE"]) > 1
21 | if "SLURM_NTASKS" in os.environ:
22 | return int(os.environ["SLURM_NTASKS"]) > 1
23 | return False
24 |
25 | def world_info_from_env():
26 | local_rank = 0
27 | for v in (
28 | "LOCAL_RANK",
29 | "MPI_LOCALRANKID",
30 | "SLURM_LOCALID",
31 | "OMPI_COMM_WORLD_LOCAL_RANK",
32 | ):
33 | if v in os.environ:
34 | local_rank = int(os.environ[v])
35 | break
36 | global_rank = 0
37 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
38 | if v in os.environ:
39 | global_rank = int(os.environ[v])
40 | break
41 | world_size = 1
42 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
43 | if v in os.environ:
44 | world_size = int(os.environ[v])
45 | break
46 |
47 | return local_rank, global_rank, world_size
48 |
49 | # def init_distributed_device(args):
50 | # # Distributed training = training on more than one GPU.
51 | # # Works in both single and multi-node scenarios.
52 | # args.distributed = False
53 | # args.world_size = 1
54 | # args.rank = 0 # global rank
55 | # args.local_rank = 0
56 |
57 | # if is_using_distributed():
58 | # if "SLURM_PROCID" in os.environ:
59 | # # DDP via SLURM
60 | # args.local_rank, args.rank, args.world_size = world_info_from_env()
61 | # # SLURM var -> torch.distributed vars in case needed
62 | # os.environ["LOCAL_RANK"] = str(args.local_rank)
63 | # os.environ["RANK"] = str(args.rank)
64 | # os.environ["WORLD_SIZE"] = str(args.world_size)
65 | # torch.distributed.init_process_group(
66 | # backend=args.dist_backend,
67 | # init_method=args.dist_url,
68 | # world_size=args.world_size,
69 | # rank=args.rank,
70 | # timeout=timedelta(seconds=36000000)
71 | # )
72 | # else:
73 | # # DDP via torchrun, torch.distributed.launch
74 | # args.local_rank, _, _ = world_info_from_env()
75 | # torch.distributed.init_process_group(
76 | # backend=args.dist_backend, init_method=args.dist_url, timeout=timedelta(seconds=36000000)
77 | # )
78 | # args.world_size = torch.distributed.get_world_size()
79 | # args.rank = torch.distributed.get_rank()
80 | # args.distributed = True
81 | # else:
82 | # # needed to run on single gpu
83 | # torch.distributed.init_process_group(
84 | # backend=args.dist_backend,
85 | # init_method=args.dist_url,
86 | # world_size=1,
87 | # rank=0,
88 | # timeout=timedelta(seconds=36000000)
89 | # )
90 |
91 | # if torch.cuda.is_available():
92 | # if args.distributed and not args.no_set_device_rank:
93 | # device = "cuda:%d" % args.local_rank
94 | # else:
95 | # device = "cuda:0"
96 | # torch.cuda.set_device(device)
97 | # else:
98 | # device = "cpu"
99 | # args.device = device
100 | # device = torch.device(device)
101 | # return device
102 |
103 | def init_distributed_device(args):
104 | # Distributed training = training on more than one GPU.
105 | # Works in both single and multi-node scenarios.
106 | args.distributed = False
107 | args.world_size = 1
108 | args.rank = 0 # global rank
109 | args.local_rank = 0
111 | import datetime
112 | if is_using_distributed():
113 | if "SLURM_PROCID" in os.environ:
114 | # DDP via SLURM
115 | args.local_rank, args.rank, args.world_size = world_info_from_env()
116 | # SLURM var -> torch.distributed vars in case needed
117 | os.environ["LOCAL_RANK"] = str(args.local_rank)
118 | os.environ["RANK"] = str(args.rank)
119 | os.environ["WORLD_SIZE"] = str(args.world_size)
120 | os.environ['NCCL_BLOCKING_WAIT'] = '0' # not to enforce timeout
121 | torch.distributed.init_process_group(
122 | # timeout=timedelta(seconds=7200000), # was 1800000
123 | timeout=datetime.timedelta(seconds=7200),
124 | backend=args.dist_backend,
125 | init_method=args.dist_url,
126 | world_size=args.world_size,
127 | rank=args.rank,
128 | )
129 | else:
130 | # DDP via torchrun, torch.distributed.launch
131 | os.environ['NCCL_BLOCKING_WAIT'] = '0' # not to enforce timeout
132 | args.local_rank, _, _ = world_info_from_env()
133 | torch.distributed.init_process_group(
134 | # timeout=timedelta(seconds=7200000), # was 18000
135 | timeout=datetime.timedelta(seconds=17200),
136 | backend=args.dist_backend,
137 | init_method=args.dist_url
138 | )
139 | args.world_size = torch.distributed.get_world_size()
140 | args.rank = torch.distributed.get_rank()
141 | args.distributed = True
142 | else:
143 | # needed to run on single gpu
144 | torch.distributed.init_process_group(
145 | backend=args.dist_backend,
146 | init_method=args.dist_url,
147 | world_size=1,
148 | rank=0,
149 | )
150 |
151 | if torch.cuda.is_available():
152 | if args.distributed and not args.no_set_device_rank:
153 | device = "cuda:%d" % args.local_rank
154 | else:
155 | device = "cuda:0"
156 | torch.cuda.set_device(device)
157 | else:
158 | device = "cpu"
159 | args.device = device
160 | device = torch.device(device)
161 | return device
162 |
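# Illustrative usage (not part of the original file); the argparse fields mirror the
# attributes referenced above (dist_backend, dist_url, no_set_device_rank):
#
# import argparse
# args = argparse.Namespace(dist_backend="nccl", dist_url="env://", no_set_device_rank=False)
# device = init_distributed_device(args)  # fills in args.rank, args.world_size, args.device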
--------------------------------------------------------------------------------
/slurm_train_intern.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A script to run multinode training with submitit.
16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
17 | """
18 | import argparse
19 | import os
20 | import re
21 | import random
22 | import uuid
23 | from pathlib import Path
24 | import train
25 | from utils.arguments_utils import get_parser
26 | import submitit
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[get_parser()])
32 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
33 | parser.add_argument("--nodes", default=4, type=int, help="Number of nodes to request")
34 |     parser.add_argument("--timeout", default=72000, type=int, help="Duration of the job in minutes")
35 | parser.add_argument("--partition", default="mozi-S1", type=str, help="Partition where to submit")
36 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
37 | parser.add_argument('--comment', default="", type=str,
38 | help='Comment to pass to scheduler, e.g. priority message')
39 | parser.add_argument("--exclude", default="", type=str, help="Nodes to exclude")
40 | parser.add_argument("--output_dir", default="/mnt/petrelfs/tianyang/Code/ICLR_Manipulation/out", type=str)
41 | return parser.parse_args()
42 |
43 | def get_shared_folder() -> Path:
44 | user = os.getenv("USER")
45 | if Path(f"/ailab/user/{user}/").is_dir():
46 | p = Path(f"/ailab/user/{user}/experiments")
47 | p.mkdir(exist_ok=True)
48 | return p
49 | raise RuntimeError("No shared folder available")
50 |
51 | def get_init_file():
52 |     # Init file must not exist, but its parent dir must exist.
53 | os.makedirs(str(get_shared_folder()), exist_ok=True)
54 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
55 | if init_file.exists():
56 | os.remove(str(init_file))
57 | return init_file
58 |
59 | def _get_master_port(seed):
60 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
61 | master_port_str = os.environ.get("MASTER_PORT")
62 | if master_port_str is None:
63 | rng = random.Random(seed)
64 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
65 | return int(master_port_str)
66 |
67 | def _parse_slurm_node_list(s):
68 | nodes = []
69 | # Extract "hostname", "hostname[1-2,3,4-5]," substrings
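    # e.g. "node[001-003],gpu12" -> ["node001", "node002", "node003", "gpu12"]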
70 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
71 | for m in p.finditer(s):
72 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
73 | prefix_list = prefix.split(',')
74 | if len(prefix_list) > 1:
75 | nodes += prefix_list[:-1]
76 | prefix = prefix_list[-1]
77 | for suffix in suffixes.split(","):
78 | span = suffix.split("-")
79 | if len(span) == 1:
80 | nodes.append(prefix + suffix)
81 | else:
82 | width = len(span[0])
83 | start, end = int(span[0]), int(span[1]) + 1
84 | for i in range(start, end):
85 | nodes.append(prefix + f"{i:0{width}}")
86 |
87 | return nodes
88 |
89 | class Trainer(object):
90 | def __init__(self, args):
91 | self.args = args
92 | def __call__(self):
93 | # import run_beit_pretraining
94 | self._setup_gpu_args()
95 | train.main(self.args)
96 |
97 | def checkpoint(self):
98 | import os
99 | import submitit
100 | # self.args.dist_url = get_init_file().as_uri()
101 | print("Requeuing ", self.args)
102 | empty_trainer = type(self)(self.args)
103 | return submitit.helpers.DelayedSubmission(empty_trainer)
104 |
105 | def _setup_gpu_args(self):
106 | import submitit
107 | from pathlib import Path
108 | job_id = int(os.environ["SLURM_JOB_ID"])
109 | node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
110 | print("node_list :", os.environ["SLURM_JOB_NODELIST"])
111 | nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
112 | print("node_count :", node_count)
113 | print("nodes :", nodes)
114 | assert len(nodes) == node_count
115 | master_addr = nodes[0]
116 | master_port = _get_master_port(seed=job_id)
117 | os.environ['MASTER_ADDR'] = master_addr
118 | os.environ['MASTER_PORT'] = str(master_port)
119 | job_env = submitit.JobEnvironment()
120 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
121 | self.args.gpu = job_env.local_rank
122 | self.args.rank = job_env.global_rank
123 | self.args.world_size = job_env.num_tasks
124 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
125 |
126 | def main():
127 | args = parse_args()
128 | if args.output_dir == "":
129 | args.output_dir = get_shared_folder() / "%j"
130 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
131 | executor = submitit.SlurmExecutor(folder=args.output_dir, max_num_timeout=30)
132 | num_gpus_per_node = args.ngpus
133 | nodes = args.nodes
134 | timeout_min = args.timeout
135 | partition = args.partition
136 | kwargs = {}
137 | if args.use_volta32:
138 | kwargs['slurm_constraint'] = 'volta32gb'
139 | if args.comment:
140 | kwargs['slurm_comment'] = args.comment
141 | if args.exclude:
142 | kwargs["exclude"] = args.exclude
143 | executor.update_parameters(
144 | gres=f"gpu:{num_gpus_per_node}",
145 | ntasks_per_node=num_gpus_per_node, # one task per GPU
146 | cpus_per_task=6,
147 | nodes=nodes,
148 | time=timeout_min,
149 | # Below are cluster dependent parameters
150 | signal_delay_s=120,
151 | partition=partition,
152 | **kwargs
153 | )
154 | executor.update_parameters(job_name="seer")
155 | # args.dist_url = get_init_file().as_uri()
156 | trainer = Trainer(args)
157 | job = executor.submit(trainer)
158 | print(f"Submitted job_id: {job.job_id}")
159 | print(f"Logs and checkpoints will be saved at: {args.output_dir}")
160 |
161 | if __name__ == "__main__":
162 | main()
--------------------------------------------------------------------------------
/real_preprocess/octo_oxe_data_utils/oxe/oxe_dataset_mixes.py:
--------------------------------------------------------------------------------
1 | """Defines dataset mixtures and weights for the Open X-Embodiment Datasets."""
2 |
3 |
4 | BRIDGE_MIX = [
5 | ("bridge_dataset", 1.0),
6 | ]
7 |
8 | RT_X_MIX = [
9 | ("fractal20220817_data", 0.54087122203),
10 | ("kuka", 0.8341046294),
11 | ("bridge_dataset", 1.0),
12 | ("taco_play", 2.0),
13 | ("jaco_play", 2.0),
14 | ("berkeley_cable_routing", 3.0),
15 | ("roboturk", 1.0),
16 | ("nyu_door_opening_surprising_effectiveness", 5.0),
17 | ("viola", 2.0),
18 | ("berkeley_autolab_ur5", 1.0),
19 | ("toto", 1.0),
20 | ]
21 |
22 |
23 | OXE_FRANKA_MIX = [
24 | ("taco_play", 1.0),
25 | ("berkeley_cable_routing", 1.0),
26 | ("viola", 1.0),
27 | ("toto", 1.0),
28 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
29 | ("austin_buds_dataset_converted_externally_to_rlds", 3.0),
30 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
31 | ("maniskill_dataset_converted_externally_to_rlds", 0.1),
32 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
33 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0),
34 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
35 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
36 | ("berkeley_rpt_converted_externally_to_rlds", 1.0),
37 | ("kaist_nonprehensile_converted_externally_to_rlds", 3.0),
38 | ("stanford_robocook_converted_externally_to_rlds", 1.0),
39 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
40 | ("utaustin_mutex", 1.0),
41 | # ("cmu_playing_with_food", 1.0),
42 | ("cmu_play_fusion", 1.0),
43 | ]
44 |
45 |
46 | OXE_MAGIC_SOUP = [
47 | ("fractal20220817_data", 0.54087122203),
48 | ("kuka", 0.8341046294),
49 | ("bridge_dataset", 1.0),
50 | ("taco_play", 2.0),
51 | ("jaco_play", 1.0),
52 | ("berkeley_cable_routing", 1.0),
53 | ("roboturk", 2.0),
54 | ("nyu_door_opening_surprising_effectiveness", 1.0),
55 | ("viola", 2.0),
56 | ("berkeley_autolab_ur5", 2.0),
57 | ("toto", 1.0),
58 | ("language_table", 0.1),
59 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0),
60 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
61 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
62 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
63 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0),
64 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
65 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
66 | ("bc_z", 0.2),
67 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
68 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
69 | ("utaustin_mutex", 1.0),
70 | ("berkeley_fanuc_manipulation", 2.0),
71 | ("cmu_stretch", 1.0),
72 | ]
73 |
74 | OXE_MAGIC_SOUP_EEFP_ONLY = [
75 | ("fractal20220817_data", 0.54087122203),
76 | ("kuka", 0.8341046294),
77 | ("bridge_dataset", 1.0),
78 | ("taco_play", 2.0),
79 | ("jaco_play", 1.0),
80 | # ("berkeley_cable_routing", 1.0),
81 | ("roboturk", 2.0),
82 | ("nyu_door_opening_surprising_effectiveness", 1.0),
83 | ("viola", 2.0),
84 | ("berkeley_autolab_ur5", 2.0),
85 | # ("toto", 1.0),
86 | ("language_table", 0.1),
87 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0),
88 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
89 | # ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
90 | # ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
91 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0),
92 | # ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
93 | # ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
94 | ("bc_z", 0.2),
95 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
96 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
97 | ("utaustin_mutex", 1.0),
98 | ("berkeley_fanuc_manipulation", 2.0),
99 | ("cmu_stretch", 1.0),
100 | ]
101 |
102 | OXE_FULL_MIX = [
103 | ("fractal20220817_data", 1.0),
104 | ("kuka", 1.0),
105 | ("bridge_dataset", 1),
106 | ("taco_play", 1.0),
107 | ("jaco_play", 1.0),
108 | ("berkeley_cable_routing", 1.0),
109 | ("roboturk", 1.0),
110 | ("nyu_door_opening_surprising_effectiveness", 1.0),
111 | ("viola", 1.0),
112 | ("berkeley_autolab_ur5", 1.0),
113 | ("toto", 1.0),
114 | ("language_table", 1.0),
115 | ("columbia_cairlab_pusht_real", 1.0),
116 | ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0),
117 | ("nyu_rot_dataset_converted_externally_to_rlds", 1.0),
118 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
119 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
120 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0),
121 | ("maniskill_dataset_converted_externally_to_rlds", 1.0),
122 | ("furniture_bench_dataset_converted_externally_to_rlds", 1.0),
123 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0),
124 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0),
125 | ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0),
126 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
127 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
128 | ("bc_z", 1.0),
129 | ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0),
130 | ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0),
131 | ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0),
132 | ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0),
133 | ("robo_net", 1.0),
134 | ("berkeley_mvp_converted_externally_to_rlds", 1.0),
135 | ("berkeley_rpt_converted_externally_to_rlds", 1.0),
136 | ("kaist_nonprehensile_converted_externally_to_rlds", 1.0),
137 | ("stanford_mask_vit_converted_externally_to_rlds", 1.0),
138 | ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0),
139 | ("dlr_sara_pour_converted_externally_to_rlds", 1.0),
140 | ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0),
141 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
142 | ("asu_table_top_converted_externally_to_rlds", 1.0),
143 | ("stanford_robocook_converted_externally_to_rlds", 1.0),
144 | ("imperialcollege_sawyer_wrist_cam", 1.0),
145 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
146 | ("uiuc_d3field", 1.0),
147 | ("utaustin_mutex", 1.0),
148 | ("berkeley_fanuc_manipulation", 1.0),
149 | ("cmu_playing_with_food", 1.0),
150 | ("cmu_play_fusion", 1.0),
151 | ("cmu_stretch", 1.0),
152 | ("berkeley_gnm_recon", 1.0),
153 | ("berkeley_gnm_cory_hall", 1.0),
154 | ("berkeley_gnm_sac_son", 1.0),
155 | ]
156 |
157 | MYTEST_MIX = [
158 | ("cmu_stretch", 1.0),
159 | ]
160 | MYTEST_MIX1 = [
161 | ("berkeley_autolab_ur5", 1.0),
162 | ]
163 | MYTEST_MIX2 = [
164 | ("viola", 1.0),
165 | ]
166 | MYTEST_MIX3 = [
167 | ("nyu_door_opening_surprising_effectiveness", 1.0),
168 | ]
169 |
170 | OXE_NAMED_MIXES = {
171 | "bridge": BRIDGE_MIX,
172 | "rtx": RT_X_MIX,
173 | "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX,
174 | "oxe_magic_soup": OXE_MAGIC_SOUP,
175 | "oxe_magic_soup_eefp_only": OXE_MAGIC_SOUP_EEFP_ONLY,
176 | "mytest": MYTEST_MIX,
177 | }
178 |
--------------------------------------------------------------------------------
/data_info/.hydra/merged_config.yaml:
--------------------------------------------------------------------------------
1 | cameras:
2 | static:
3 | _target_: calvin_env.camera.static_camera.StaticCamera
4 | name: static
5 | fov: 10
6 | aspect: 1
7 | nearval: 0.01
8 | farval: 10
9 | width: 200
10 | height: 200
11 | look_at:
12 | - -0.026242351159453392
13 | - -0.0302329882979393
14 | - 0.3920000493526459
15 | look_from:
16 | - 2.871459009488717
17 | - -2.166602199425597
18 | - 2.555159848480571
19 | up_vector:
20 | - 0.4041403970338857
21 | - 0.22629790978217404
22 | - 0.8862616969685161
23 | gripper:
24 | _target_: calvin_env.camera.gripper_camera.GripperCamera
25 | name: gripper
26 | fov: 75
27 | aspect: 1
28 | nearval: 0.01
29 | farval: 2
30 | width: 84
31 | height: 84
32 | tactile:
33 | _target_: calvin_env.camera.tactile_sensor.TactileSensor
34 | name: tactile
35 | width: 120
36 | height: 160
37 | digit_link_ids:
38 | - 10
39 | - 12
40 | visualize_gui: false
41 | config_path: conf/digit_sensor/config_digit.yml
42 | vr_input:
43 | vr_controller:
44 | POSITION: 1
45 | ORIENTATION: 2
46 | ANALOG: 3
47 | BUTTONS: 6
48 | BUTTON_A: 2
49 | BUTTON_B: 1
50 | vr_controller_id: 3
51 | gripper_orientation_offset:
52 | - 0
53 | - 3
54 | - 3.14
55 | gripper_position_offset:
56 | - -0.2
57 | - 0.3
58 | - 0
59 | _target_: calvin_env.io_utils.vr_input.VrInput
60 | limit_angle:
61 | - 90
62 | - 0
63 | - 0
64 | - -1
65 | visualize_vr_pos: true
66 | reset_button_queue_len: 60
67 | env:
68 | _target_: calvin_env.envs.play_table_env.PlayTableSimEnv
69 | _recursive_: false
70 | cameras: ${cameras}
71 | seed: 0
72 | bullet_time_step: 240.0
73 | use_vr: false
74 | show_gui: ${show_gui}
75 | robot_cfg: ${robot}
76 | scene_cfg: ${scene}
77 | use_scene_info: false
78 | use_egl: true
79 | control_freq: 30
80 | scene:
81 | _target_: calvin_env.scene.play_table_scene.PlayTableScene
82 | _recursive_: false
83 | data_path: ${data_path}
84 | global_scaling: 0.8
85 | euler_obs: ${robot.euler_obs}
86 | robot_base_position:
87 | - -0.34
88 | - -0.46
89 | - 0.24
90 | robot_base_orientation:
91 | - 0
92 | - 0
93 | - 0
94 | robot_initial_joint_positions:
95 | - -1.21779206
96 | - 1.03987646
97 | - 2.11978261
98 | - -2.34205014
99 | - -0.87015947
100 | - 1.64119353
101 | - 0.55344866
102 | surfaces:
103 | table:
104 | - - 0.0
105 | - -0.15
106 | - 0.46
107 | - - 0.35
108 | - -0.03
109 | - 0.46
110 | slider_left:
111 | - - -0.32
112 | - 0.05
113 | - 0.46
114 | - - -0.16
115 | - 0.12
116 | - 0.46
117 | slider_right:
118 | - - -0.05
119 | - 0.05
120 | - 0.46
121 | - - 0.13
122 | - 0.12
123 | - 0.46
124 | objects:
125 | fixed_objects:
126 | table:
127 | file: calvin_table_D/urdf/calvin_table_D.urdf
128 | initial_pos:
129 | - 0
130 | - 0
131 | - 0
132 | initial_orn:
133 | - 0
134 | - 0
135 | - 0
136 | joints:
137 | base__slide:
138 | initial_state: 0
139 | base__drawer:
140 | initial_state: 0
141 | buttons:
142 | base__button:
143 | initial_state: 0
144 | effect: led
145 | switches:
146 | base__switch:
147 | initial_state: 0
148 | effect: lightbulb
149 | lights:
150 | lightbulb:
151 | link: light_link
152 | color:
153 | - 1
154 | - 1
155 | - 0
156 | - 1
157 | led:
158 | link: led_link
159 | color:
160 | - 0
161 | - 1
162 | - 0
163 | - 1
164 | movable_objects:
165 | block_red:
166 | file: blocks/block_red_middle.urdf
167 | initial_pos: any
168 | initial_orn: any
169 | block_blue:
170 | file: blocks/block_blue_small.urdf
171 | initial_pos: any
172 | initial_orn: any
173 | block_pink:
174 | file: blocks/block_pink_big.urdf
175 | initial_pos: any
176 | initial_orn: any
177 | name: calvin_scene_D
178 | robot:
179 | _target_: calvin_env.robot.robot.Robot
180 | filename: franka_panda/panda_longer_finger.urdf
181 | base_position: ${scene.robot_base_position}
182 | base_orientation: ${scene.robot_base_orientation}
183 | initial_joint_positions: ${scene.robot_initial_joint_positions}
184 | max_joint_force: 200.0
185 | gripper_force: 200
186 | arm_joint_ids:
187 | - 0
188 | - 1
189 | - 2
190 | - 3
191 | - 4
192 | - 5
193 | - 6
194 | gripper_joint_ids:
195 | - 9
196 | - 11
197 | gripper_joint_limits:
198 | - 0
199 | - 0.04
200 | tcp_link_id: 15
201 | end_effector_link_id: 7
202 | gripper_cam_link: 12
203 | use_nullspace: true
204 | max_velocity: 2
205 | use_ik_fast: false
206 | magic_scaling_factor_pos: 1
207 | magic_scaling_factor_orn: 1
208 | use_target_pose: true
209 | euler_obs: true
210 | tasks:
211 | _target_: calvin_env.envs.tasks.Tasks
212 | tasks:
213 | rotate_red_block_right:
214 | - rotate_object
215 | - block_red
216 | - -60
217 | rotate_red_block_left:
218 | - rotate_object
219 | - block_red
220 | - 60
221 | rotate_blue_block_right:
222 | - rotate_object
223 | - block_blue
224 | - -60
225 | rotate_blue_block_left:
226 | - rotate_object
227 | - block_blue
228 | - 60
229 | rotate_pink_block_right:
230 | - rotate_object
231 | - block_pink
232 | - -60
233 | rotate_pink_block_left:
234 | - rotate_object
235 | - block_pink
236 | - 60
237 | push_red_block_right:
238 | - push_object
239 | - block_red
240 | - 0.1
241 | - 0
242 | push_red_block_left:
243 | - push_object
244 | - block_red
245 | - -0.1
246 | - 0
247 | push_blue_block_right:
248 | - push_object
249 | - block_blue
250 | - 0.1
251 | - 0
252 | push_blue_block_left:
253 | - push_object
254 | - block_blue
255 | - -0.1
256 | - 0
257 | push_pink_block_right:
258 | - push_object
259 | - block_pink
260 | - 0.1
261 | - 0
262 | push_pink_block_left:
263 | - push_object
264 | - block_pink
265 | - -0.1
266 | - 0
267 | move_slider_left:
268 | - move_door_rel
269 | - base__slide
270 | - 0.15
271 | move_slider_right:
272 | - move_door_rel
273 | - base__slide
274 | - -0.15
275 | open_drawer:
276 | - move_door_rel
277 | - base__drawer
278 | - 0.12
279 | close_drawer:
280 | - move_door_rel
281 | - base__drawer
282 | - -0.12
283 | lift_red_block_table:
284 | - lift_object
285 | - block_red
286 | - 0.05
287 | - table
288 | - base_link
289 | lift_red_block_slider:
290 | - lift_object
291 | - block_red
292 | - 0.03
293 | - table
294 | - plank_link
295 | lift_red_block_drawer:
296 | - lift_object
297 | - block_red
298 | - 0.05
299 | - table
300 | - drawer_link
301 | lift_blue_block_table:
302 | - lift_object
303 | - block_blue
304 | - 0.05
305 | - table
306 | - base_link
307 | lift_blue_block_slider:
308 | - lift_object
309 | - block_blue
310 | - 0.03
311 | - table
312 | - plank_link
313 | lift_blue_block_drawer:
314 | - lift_object
315 | - block_blue
316 | - 0.05
317 | - table
318 | - drawer_link
319 | lift_pink_block_table:
320 | - lift_object
321 | - block_pink
322 | - 0.05
323 | - table
324 | - base_link
325 | lift_pink_block_slider:
326 | - lift_object
327 | - block_pink
328 | - 0.03
329 | - table
330 | - plank_link
331 | lift_pink_block_drawer:
332 | - lift_object
333 | - block_pink
334 | - 0.05
335 | - table
336 | - drawer_link
337 | place_in_slider:
338 | - place_object
339 | - table
340 | - plank_link
341 | place_in_drawer:
342 | - place_object
343 | - table
344 | - drawer_link
345 | stack_block:
346 | - stack_objects
347 | unstack_block:
348 | - unstack_objects
349 | turn_on_lightbulb:
350 | - toggle_light
351 | - lightbulb
352 | - 0
353 | - 1
354 | turn_off_lightbulb:
355 | - toggle_light
356 | - lightbulb
357 | - 1
358 | - 0
359 | turn_on_led:
360 | - toggle_light
361 | - led
362 | - 0
363 | - 1
364 | turn_off_led:
365 | - toggle_light
366 | - led
367 | - 1
368 | - 0
369 | push_into_drawer:
370 | - push_object_into
371 | - - block_red
372 | - block_blue
373 | - block_pink
374 | - table
375 | - base_link
376 | - table
377 | - drawer_link
378 | recorder:
379 | record: ${record}
380 | record_fps: 30.0
381 | show_fps: false
382 | enable_tts: true
383 | seed: 0
384 | use_vr: true
385 | data_path: data
386 | save_dir: /home/meeso/recordings
387 | record: true
388 | load_dir: /work/dlclarge2/meeso-lfp/calvin_recordings/play_env_A/2021-09-12/11-36-05/
389 | show_gui: false
390 | processes: 16
391 | max_episode_frames: 1
392 | save_body_infos: true
393 | set_static_cam: false
394 |
--------------------------------------------------------------------------------
/real_preprocess/mujoco_menagerie/robotiq_2f85/2f85.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/utils/convert_libero_per_step.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import random
4 | import torch.multiprocessing as mp
5 | import torch.distributed as dist
6 | import numpy as np
7 | import h5py
8 | import tqdm
9 | import argparse
10 | from pathlib import Path
11 | import yaml
12 | from PIL import Image
13 | import time
14 | import json
15 | import math
16 | # from sentence_transformers import SentenceTransformer
17 |
18 | ### Rotation ###
19 | from scipy.spatial.transform import Rotation
20 |
21 |
22 | def setup(rank, world_size, port):
23 | os.environ["MASTER_ADDR"] = "localhost"
24 | os.environ["MASTER_PORT"] = str(port)
25 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
26 |
27 | def extract_task_information(file_name, path):
28 | """
29 | Extracts task information from the given file name.
30 | """
31 | # Regular expression pattern to extract the task name
32 | pattern = r'{}/((.+)_SCENE[0-9]+_(.+))_demo\.hdf5'.format(path)
33 |
34 | # Extracting the task name
35 | match = re.search(pattern, file_name)
36 |
37 |     if match: print(match.group(3).lower().replace("_", " "))
38 |     return (match.group(1).lower(), match.group(3).lower().replace("_", " ")) if match else (None, None)
39 |
40 |
41 | class DatasetConverter:
42 | def __init__(
43 | self,
44 | src_dir: str,
45 | tgt_dir: str,
46 | rank: int,
47 | num_worker: int,
48 | start_episode_idx,
49 | end_episode_idx,
50 | ):
51 | self.src_dir = src_dir
52 | self.tgt_dir = tgt_dir
53 | self.rank = rank
54 | self.num_worker = num_worker
55 | self.start_episode_idx = start_episode_idx
56 | self.end_episode_idx = end_episode_idx
57 | # self.lang_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
58 |
59 | def process_episode(self, episode_dir, language_instructions, demo_data, episode_index, episode_index_in_task):
60 | i = episode_index_in_task
61 | # get episode dir
62 | episode_dir.mkdir(exist_ok=True)
63 |
64 | ### Get agent's view camera
65 | obs = np.array(demo_data['demo_{}'.format(i)]['obs']['agentview_rgb'])
66 | # obs = obs.transpose(0,3,1,2)
67 |
68 | ### Get wrist's view camera
69 | obs_wrist = np.array(demo_data['demo_{}'.format(i)]['obs']['eye_in_hand_rgb'])
70 | # obs_wrist = obs_wrist.transpose(0,3,1,2)
71 |
72 | ### Get actions
73 | action = np.array(demo_data['demo_{}'.format(i)]['actions']) # -1 open, 1 close
74 |
75 | joint_state = np.array(demo_data['demo_{}'.format(i)]['obs']['joint_states'])
76 | ee_pos = np.array(demo_data['demo_{}'.format(i)]['obs']['ee_pos'])
77 | ee_ori = np.array(demo_data['demo_{}'.format(i)]['obs']['ee_ori'])
78 | ee_state = np.array(demo_data['demo_{}'.format(i)]['obs']['ee_states'])
79 |
80 | gripper_state = np.zeros_like(action[:, -1])
81 | gripper_state[1:] = action[:-1, -1]
82 | gripper_state[0] = action[0, -1]
83 |
84 | gripper_position = np.array(demo_data['demo_{}'.format(i)]['obs']['gripper_states'])
85 | gripper_command = action[:, -1]
86 |
87 | # task emb
88 | # task_emb = self.lang_model.encode(language_instructions)
89 |
90 | # get episode length
91 | num_steps = obs.shape[0]
92 |
93 | episode_dir = episode_dir/str(episode_index).zfill(6)
94 | episode_dir.mkdir(exist_ok=True)
95 |
96 | # save episode length and language instruction
97 | with h5py.File(f'{episode_dir}/meta_info.h5', 'w') as h5_file:
98 | h5_file.create_dataset(name='length', data=num_steps)
99 |
100 | steps_dir = episode_dir/'steps'
101 | steps_dir.mkdir(exist_ok=True)
102 | for step_index in range(num_steps):
103 | step_dir = episode_dir/'steps'/str(step_index).zfill(4)
104 | step_dir.mkdir(exist_ok=True)
105 |
106 | with h5py.File(f'{step_dir}/other.h5', 'w') as h5_file:
107 | # language instruction
108 | h5_file.create_dataset('language_instruction', data=np.array(language_instructions, dtype=h5py.string_dtype(encoding='utf-8')))
109 | # task emb
110 | # h5_file.create_dataset(name='task_emb', data=task_emb)
111 |
112 | # episode length
113 | h5_file.create_dataset(name='episode_length', data=num_steps)
114 |
115 | # action
116 | h5_file.create_dataset(name='action', data=action[step_index])
117 |
118 | # observation (timestep, proprio, image_XXX)
119 | observation_group = h5_file.create_group(name='observation')
120 |
121 | ## image
122 | # ### image_primary
123 | Image.fromarray(obs[step_index]).save(f'{step_dir}/image_primary.jpg')
124 | ### image_wrist
125 | Image.fromarray(obs_wrist[step_index]).save(f'{step_dir}/image_wrist.jpg')
126 |
127 | ## proprio
128 | observation_group.create_dataset(name='proprio', data=joint_state[step_index])
129 |
130 | ## tcp_pose
131 | observation_group.create_dataset(name='tcp_pose', data=ee_state[step_index])
132 |
133 | ## gripper state (-1 or 1)
134 | observation_group.create_dataset(name='gripper_state', data=gripper_state[step_index])
135 |
136 | ## gripper position (n, 2)
137 | observation_group.create_dataset(name='gripper_position', data=gripper_position[step_index])
138 |
139 | def convert_origin_dataset_to_target(self, dataset_by_task):
140 | # /dataset_0
141 | # |_meta_info.h5
142 | # |_/episodes
143 | # | |_/0
144 | # | | |_/steps
145 | # | | |_/0
146 | # | | |_other.h5
147 | # | | |_XXX.jpg
148 | # | |...
149 | # | |_/1
150 | # | |_...
151 | # /dataset_1
152 | # |
153 | episodes_dir = self.tgt_dir/'episodes'
154 | episodes_dir.mkdir(exist_ok=True)
155 |
156 | num_episodes = 0
157 | for dataset in dataset_by_task:
158 | num_episodes += dataset['num_episode']
159 |
160 | if self.rank == 0:
161 | with h5py.File(f'{str(self.tgt_dir)}/meta_info.h5', 'w') as h5_file:
162 | h5_file.create_dataset(name='num_episodes', data=num_episodes)
163 |
164 | processed_task_num_episode = 0
165 | task_index = 0
166 |
167 | for episode_index in range(num_episodes):
168 | episode_index_in_task = episode_index - processed_task_num_episode
169 |
170 | if episode_index < self.start_episode_idx:
171 | continue
172 | if self.end_episode_idx is not None:
173 | if episode_index >= self.end_episode_idx:
174 | break
175 | if episode_index % self.num_worker != self.rank:
176 | if episode_index_in_task+1 == dataset_by_task[task_index]['num_episode']:
177 | processed_task_num_episode += dataset_by_task[task_index]['num_episode']
178 | task_index += 1
179 | continue
180 | print(self.rank, episode_index, '/' , num_episodes)
181 | self.process_episode(episode_dir=episodes_dir, language_instructions=dataset_by_task[task_index]['language'], demo_data=dataset_by_task[task_index]['data'], episode_index=episode_index, episode_index_in_task=episode_index_in_task)
182 |
183 | if episode_index_in_task+1 == dataset_by_task[task_index]['num_episode']:
184 | processed_task_num_episode += dataset_by_task[task_index]['num_episode']
185 | task_index += 1
186 |
187 | def run(self):
188 | print(f'target dir: {self.tgt_dir}')
189 |
190 | dataset_by_task = []
191 | for path in list(Path(self.src_dir).iterdir()):
192 | path_name = str(path)
193 | task_name, task_language = extract_task_information(path_name, self.src_dir)
194 | demo_data = h5py.File(path_name, 'r')['data']
195 | num_episode = len(demo_data)
196 | dataset = {
197 | 'language': task_language,
198 | 'num_episode': num_episode,
199 | 'data': demo_data
200 | }
201 | dataset_by_task.append(dataset)
202 |
203 | self.convert_origin_dataset_to_target(dataset_by_task)
204 |
205 | print(f'data saved at {self.tgt_dir}')
206 |
207 | # get data_info.json
208 | data_info = []
209 | episode_idx = 0
210 | total_step = 0
211 | for path in list(Path(self.src_dir).iterdir()):
212 | path_name = str(path)
213 | demo_data = h5py.File(path_name, 'r')['data']
214 | num_episode = len(demo_data)
215 | for i in range(num_episode):
216 | num_steps = np.array(demo_data['demo_{}'.format(i)]['obs']['agentview_rgb']).shape[0]
217 | data_info.append([str(episode_idx).zfill(6), num_steps])
218 | episode_idx += 1
219 | total_step += num_steps
220 | # print(total_step)
221 | with open(f'./data_info/{dataset_name}_converted.json', 'w') as f:
222 | json.dump(data_info, f)
223 |
224 | def main(rank, port, num_worker, start_episode_idx=0, end_episode_idx=None):
225 | if num_worker > 1:
226 | setup(rank, world_size=num_worker, port=port)
227 |
228 | global dataset_name
229 | dataset_name = "libero_90" # "libero_10"
230 | src_dir = f"/fs-computility/efm/shared/datasets/Banana/tianyang/Data/{dataset_name}"
231 | tgt_dir = Path(f"/fs-computility/efm/shared/datasets/Banana/tianyang/Data/{dataset_name}_converted")
232 | tgt_dir.mkdir(exist_ok=True)
233 |
234 | dataset_converter = DatasetConverter(
235 | src_dir=src_dir,
236 | tgt_dir=tgt_dir,
237 | rank=rank,
238 | num_worker=num_worker,
239 | start_episode_idx=start_episode_idx, # the dataset[start_episode_idx] will be processed
240 | end_episode_idx=end_episode_idx, # None means the last episode. if not none, the dataset[end_episode_idx - 1] will be processed and the dataset[end_episode_idx] will not be processed
241 | )
242 | dataset_converter.run()
243 |
244 | if __name__ == '__main__':
245 | start_episode_idx = 0
246 | end_episode_idx = None
247 | num_worker = 8
248 | port = (random.randint(0, 3000) % 3000) + 27000
249 |
250 | assert num_worker > 1
251 | mp.spawn(main, args=(port, num_worker, start_episode_idx, end_episode_idx), nprocs=num_worker, join=True)
252 |
253 | # main(0, port, 1, start_episode_idx=start_episode_idx, end_episode_idx=end_episode_idx)
254 |
--------------------------------------------------------------------------------
/models/vit_mae.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from timm.models.vision_transformer import PatchEmbed, Block
7 |
8 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
9 | """
10 | grid_size: int of the grid height and width
11 | return:
12 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
13 | """
14 | grid_h = np.arange(grid_size, dtype=np.float32)
15 | grid_w = np.arange(grid_size, dtype=np.float32)
16 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
17 | grid = np.stack(grid, axis=0)
18 |
19 | grid = grid.reshape([2, 1, grid_size, grid_size])
20 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
21 | if cls_token:
22 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
23 | return pos_embed
24 |
25 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
26 | assert embed_dim % 2 == 0
27 |
28 | # use half of dimensions to encode grid_h
29 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
30 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
31 |
32 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
33 | return emb
34 |
35 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
36 | """
37 | embed_dim: output dimension for each position
38 | pos: a list of positions to be encoded: size (M,)
39 | out: (M, D)
40 | """
41 | assert embed_dim % 2 == 0
42 | omega = np.arange(embed_dim // 2, dtype=np.float32)
43 | omega /= embed_dim / 2.
44 | omega = 1. / 10000**omega # (D/2,)
45 |
46 | pos = pos.reshape(-1) # (M,)
47 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
48 |
49 | emb_sin = np.sin(out) # (M, D/2)
50 | emb_cos = np.cos(out) # (M, D/2)
51 |
52 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
53 | return emb
54 |
55 | class MaskedAutoencoderViT(nn.Module):
56 | """ Masked Autoencoder with VisionTransformer backbone
57 | """
58 | def __init__(self, img_size=224, patch_size=16, in_chans=3,
59 | embed_dim=1024, depth=24, num_heads=16,
60 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
61 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
62 | super().__init__()
63 |
64 | # --------------------------------------------------------------------------
65 | # MAE encoder specifics
66 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
67 | # print("path_embed.device: ", self.patch_embed.device)
68 | num_patches = self.patch_embed.num_patches
69 |
70 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
71 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
72 |
73 | self.blocks = nn.ModuleList([
74 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) # qk_scale=None,
75 | for i in range(depth)])
76 | self.norm = norm_layer(embed_dim)
77 | # --------------------------------------------------------------------------
78 |
79 | # --------------------------------------------------------------------------
80 | # MAE decoder specifics
81 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
82 |
83 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
84 |
85 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
86 |
87 | self.decoder_blocks = nn.ModuleList([
88 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) # qk_scale=None,
89 | for i in range(decoder_depth)])
90 |
91 | self.decoder_norm = norm_layer(decoder_embed_dim)
92 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
93 | # --------------------------------------------------------------------------
94 |
95 | self.norm_pix_loss = norm_pix_loss
96 |
97 | self.initialize_weights()
98 |
99 | def initialize_weights(self):
100 | # initialization
101 | # initialize (and freeze) pos_embed by sin-cos embedding
102 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
103 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
104 |
105 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
106 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
107 |
108 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
109 | w = self.patch_embed.proj.weight.data
110 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
111 |
112 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
113 | torch.nn.init.normal_(self.cls_token, std=.02)
114 | torch.nn.init.normal_(self.mask_token, std=.02)
115 |
116 | # initialize nn.Linear and nn.LayerNorm
117 | self.apply(self._init_weights)
118 |
119 | def _init_weights(self, m):
120 | if isinstance(m, nn.Linear):
121 | # we use xavier_uniform following official JAX ViT:
122 | torch.nn.init.xavier_uniform_(m.weight)
123 | if isinstance(m, nn.Linear) and m.bias is not None:
124 | nn.init.constant_(m.bias, 0)
125 | elif isinstance(m, nn.LayerNorm):
126 | nn.init.constant_(m.bias, 0)
127 | nn.init.constant_(m.weight, 1.0)
128 |
129 | def patchify(self, imgs):
130 | """
131 | imgs: (N, 3, H, W)
132 | x: (N, L, patch_size**2 *3)
133 | """
134 | p = self.patch_embed.patch_size[0]
135 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
136 |
137 | h = w = imgs.shape[2] // p
138 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
139 | x = torch.einsum('nchpwq->nhwpqc', x)
140 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
141 | return x
142 |
143 | def unpatchify(self, x):
144 | """
145 | x: (N, L, patch_size**2 *3)
146 | imgs: (N, 3, H, W)
147 | """
148 | p = self.patch_embed.patch_size[0]
149 | h = w = int(x.shape[1]**.5)
150 | assert h * w == x.shape[1]
151 |
152 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
153 | x = torch.einsum('nhwpqc->nchpwq', x)
154 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
155 | return imgs
156 |
157 | def random_masking(self, x, mask_ratio):
158 | """
159 | Perform per-sample random masking by per-sample shuffling.
160 | Per-sample shuffling is done by argsort random noise.
161 | x: [N, L, D], sequence
162 | """
163 | N, L, D = x.shape # batch, length, dim
164 | len_keep = int(L * (1 - mask_ratio))
165 |
166 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
167 |
168 | # sort noise for each sample
169 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
170 | ids_restore = torch.argsort(ids_shuffle, dim=1)
171 |
172 | # keep the first subset
173 | ids_keep = ids_shuffle[:, :len_keep]
174 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
175 |
176 | # generate the binary mask: 0 is keep, 1 is remove
177 | mask = torch.ones([N, L], device=x.device)
178 | mask[:, :len_keep] = 0
179 | # unshuffle to get the binary mask
180 | mask = torch.gather(mask, dim=1, index=ids_restore)
181 |
182 | return x_masked, mask, ids_restore
183 |
184 | def forward_encoder(self, x, mask_ratio):
185 | # embed patches
186 | # set_trace()
187 | # print("patch_embed cuda: ", next(self.patch_embed.parameters()).is_cuda)
188 | x = self.patch_embed(x)
189 |
190 | # add pos embed w/o cls token
191 | x = x + self.pos_embed[:, 1:, :]
192 |
193 | # masking: length -> length * mask_ratio
194 | x, mask, ids_restore = self.random_masking(x, mask_ratio)
195 |
196 | # append cls token
197 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
198 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
199 | x = torch.cat((cls_tokens, x), dim=1)
200 |
201 | # apply Transformer blocks
202 | for blk in self.blocks:
203 | x = blk(x)
204 | x = self.norm(x)
205 |
206 | return x, mask, ids_restore
207 |
208 | def forward_decoder(self, x, ids_restore):
209 | # embed tokens
210 | x = self.decoder_embed(x)
211 |
212 | # append mask tokens to sequence
213 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
214 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
215 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
216 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
217 |
218 | # add pos embed
219 | x = x + self.decoder_pos_embed
220 |
221 | # apply Transformer blocks
222 | for blk in self.decoder_blocks:
223 | x = blk(x)
224 | x = self.decoder_norm(x)
225 |
226 | # predictor projection
227 | x = self.decoder_pred(x)
228 |
229 | # remove cls token
230 | x = x[:, 1:, :]
231 |
232 | return x
233 |
234 | def forward_loss(self, imgs, pred, mask):
235 | """
236 | imgs: [N, 3, H, W]
237 | pred: [N, L, p*p*3]
238 | mask: [N, L], 0 is keep, 1 is remove,
239 | """
240 | target = self.patchify(imgs)
241 | if self.norm_pix_loss:
242 | mean = target.mean(dim=-1, keepdim=True)
243 | var = target.var(dim=-1, keepdim=True)
244 | target = (target - mean) / (var + 1.e-6)**.5
245 |
246 | loss = (pred - target) ** 2
247 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
248 |
249 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
250 | return loss
251 |
252 | def forward(self, imgs, mask_ratio=0.75):
253 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
254 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
255 | loss = self.forward_loss(imgs, pred, mask)
256 | return loss, pred, mask
--------------------------------------------------------------------------------
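
Usage sketch (not a file in this repository): a minimal example of how the MaskedAutoencoderViT defined above behaves, assuming the repository root is on PYTHONPATH so that models.vit_mae is importable; the ViT-Base-sized constructor arguments and the two-image dummy batch are illustrative choices, while the 0.75 mask ratio is simply the default of forward() above.

    import torch
    from models.vit_mae import MaskedAutoencoderViT  # the file listed above

    # ViT-Base-sized encoder; the decoder keeps the defaults (512-dim, 8 blocks, 16 heads)
    mae = MaskedAutoencoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    imgs = torch.randn(2, 3, 224, 224)  # dummy batch of two square images

    # patchify <-> unpatchify only rearrange pixels, so the round trip is exact
    patches = mae.patchify(imgs)                       # (2, 196, 16*16*3) = (2, 196, 768)
    assert torch.equal(mae.unpatchify(patches), imgs)

    # full MAE step: encode the 25% of patches that survive random_masking,
    # reconstruct every patch, score only the 75% that were masked out
    loss, pred, mask = mae(imgs, mask_ratio=0.75)
    print(loss.item(), pred.shape, mask.shape)         # scalar, (2, 196, 768), (2, 196)

With mask_ratio=0.75 the encoder attends over int(196 * 0.25) = 49 patch tokens plus the class token, and forward_loss averages the per-patch MSE over the masked positions only.
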
/utils/arguments_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import glob
4 | import os
5 | import random
6 | from collections import OrderedDict
7 | import numpy as np
8 | import yaml
9 | import torch
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 | from torch.distributed.elastic.multiprocessing.errors import record
12 |
13 |
14 | def random_seed(seed=42, rank=0):
15 | torch.manual_seed(seed + rank)
16 | np.random.seed(seed + rank)
17 | random.seed(seed + rank)
18 |
19 | def world_info_from_env():
20 | local_rank = 0
21 | for v in (
22 | "LOCAL_RANK",
23 | "MPI_LOCALRANKID",
24 | "SLURM_LOCALID",
25 | "OMPI_COMM_WORLD_LOCAL_RANK",
26 | ):
27 | if v in os.environ:
28 | local_rank = int(os.environ[v])
29 | break
30 | global_rank = 0
31 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
32 | if v in os.environ:
33 | global_rank = int(os.environ[v])
34 | break
35 | world_size = 1
36 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
37 | if v in os.environ:
38 | world_size = int(os.environ[v])
39 | break
40 |
41 | return local_rank, global_rank, world_size
42 |
43 | def get_parser(is_eval=False):
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument(
46 | "--run_name",
47 | type=str,
48 | default="RobotFlamingo",
49 | help="used to name saving directory and wandb run",
50 | )
51 | parser.add_argument("--offline", action="store_true")
52 | parser.add_argument("--num_epochs", type=int, default=1)
53 | # Sum of gradient optimization batch size
54 | parser.add_argument("--batch_size", type=int, default=1)
55 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
56 | parser.add_argument(
57 | "--resume_from_checkpoint",
58 | type=str,
59 | help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
60 | default=None,
61 | )
62 | parser.add_argument(
63 | "--delete_previous_checkpoint",
64 | action="store_true",
65 | help="delete previous checkpoint when saving new checkpoint",
66 | )
67 | parser.add_argument("--seed", type=int, default=42)
68 | parser.add_argument("--learning_rate", default=1e-4, type=float) # 1e-4
69 | parser.add_argument(
70 | "--lr_scheduler",
71 | default="constant",
72 | type=str,
73 | help="constant, linear, or cosine",
74 | )
75 | parser.add_argument(
76 | "--calvin_dataset",
77 | type=str,
78 | default='/mnt/petrelfs/share_data/robomani/calvin_data/task_ABCD_D',
79 | help="path to calvin_dataset",
80 | )
81 | parser.add_argument("--warmup_epochs", default=1, type=int)
82 | parser.add_argument("--local-rank", default=0, type=int)
83 | parser.add_argument("--weight_decay", default=0.1, type=float)
84 | # hot fix for torch.distributed.launch
85 | parser.add_argument(
86 | "--precision",
87 | choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32", "bf16_and_fp32"],
88 | default="fp32",
89 | help="Floating point precision.",
90 | )
91 | # data args
92 | parser.add_argument("--workers", type=int, default=16)
93 | # distributed training args
94 | parser.add_argument(
95 | "--dist-url",
96 | default="env://",
97 | type=str,
98 | help="url used to set up distributed training",
99 | )
100 | parser.add_argument(
101 | "--dist-backend", default="nccl", type=str, help="distributed backend"
102 | )
103 | parser.add_argument(
104 | "--no-set-device-rank",
105 | default=False,
106 | action="store_true",
107 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
108 | )
109 | # wandb args
110 | parser.add_argument("--report_to_wandb", default=False, action="store_true")
111 | parser.add_argument(
112 | "--wandb_project",
113 | type=str,
114 | )
115 | parser.add_argument(
116 | "--wandb_entity",
117 | type=str,
118 | )
119 | parser.add_argument(
120 | "--save_checkpoints_to_wandb",
121 | default=False,
122 | action="store_true",
123 | help="save checkpoints to wandb",
124 | )
125 | parser.add_argument('--rgb_pad', type=int, default=-1)
126 | parser.add_argument('--gripper_pad', type=int, default=-1)
127 | parser.add_argument(
128 | "--traj_cons",
129 | default=False,
130 | action="store_true"
131 | )
132 | parser.add_argument(
133 | "--text_aug",
134 | default=False,
135 | action="store_true"
136 | )
137 | parser.add_argument(
138 | "--residual",
139 | default=False,
140 | action="store_true"
141 | )
142 | parser.add_argument(
143 | "--dif_ws",
144 | default=False,
145 | action="store_true"
146 | )
147 | parser.add_argument(
148 | "--partial_data",
149 | default=False,
150 | action="store_true"
151 | )
152 | # data
153 | parser.add_argument("--save_every_iter", type=int, default=-1)
154 | parser.add_argument("--min_window_size", type=int, default=12)
155 | parser.add_argument("--max_window_size", type=int, default=24)
156 | parser.add_argument("--multi_step_action", type=int, default=1, help="multiple step action prediction")
157 | # ceph
158 | parser.add_argument("--data_in_ceph",default=False, action="store_true")
159 | # oxe
160 | parser.add_argument("--root_dir", type=str, default="s3://real_data")
161 | parser.add_argument("--image_primary_size", type=int, default=200)
162 | parser.add_argument("--image_wrist_size", type=int, default=84)
163 | parser.add_argument("--finetune_type", type=str, default="",)
164 | # save checkpoint
165 | parser.add_argument("--start_save_checkpoint", default=-1, type=int)
166 | parser.add_argument("--save_checkpoint", default=False, action="store_true")
167 | parser.add_argument("--save_checkpoint_path", required=True, type=str)
168 | parser.add_argument("--save_checkpoint_seq", type=int, default=1)
169 | # if validate
170 | parser.add_argument("--validation", default=False, action="store_true")
171 | # bf16 module
172 | parser.add_argument("--bf16_module", type=str, default="")
173 | # model structure
174 | parser.add_argument("--sequence_length", type=int, default=10)
175 | # for image prediction
176 | parser.add_argument("--future_steps", type=int, default=3)
177 | parser.add_argument("--num_resampler_query", type=int, default=9)
178 | parser.add_argument("--num_obs_token_per_image", type=int, default=9)
179 | parser.add_argument("--calvin_input_image_size", type=int, default=224)
180 | parser.add_argument("--patch_size", type=int, default=16)
181 | # droid
182 | parser.add_argument("--primary_mode", type=str, default="image_primary")
183 | parser.add_argument("--small_size", type=int, default=0)
184 | parser.add_argument("--dataset_info", type=str, default="droid_success")
185 | # pretrain
186 | parser.add_argument("--finetune_from_pretrained_ckpt", type=str, default=None)
187 | # loss
188 | parser.add_argument("--loss_arm_action_ratio", type=float, default=1.0)
189 | parser.add_argument("--loss_gripper_action_ratio", type=float, default=0.01)
190 | # action_pred_steps
191 | parser.add_argument("--action_pred_steps", type=int, default=1)
192 | # obs_pred
193 | parser.add_argument("--obs_pred", default=False, action="store_true")
194 | parser.add_argument("--atten_only_obs", default=False, action="store_true")
195 | parser.add_argument("--attn_robot_proprio_state", default=False, action="store_true")
196 | parser.add_argument("--atten_goal", default=0, type=int)
197 | parser.add_argument("--atten_goal_state", default=False, action="store_true")
198 | # action mask ratio
199 | parser.add_argument("--mask_l_obs_ratio", default=0.00, type=float)
200 | # reset during finetuning
201 | parser.add_argument("--reset_action_token", default=False, action="store_true")
202 | parser.add_argument("--reset_obs_token", default=False, action="store_true")
203 | parser.add_argument("--reset_mask_token", default=False, action="store_true")
204 | parser.add_argument("--reset_image_decoder", default=False, action="store_true")
205 | parser.add_argument("--reset_action_decoder", default=False, action="store_true")
206 | # loss
207 | parser.add_argument("--loss_action", default=False, action="store_true")
208 | parser.add_argument("--loss_image", default=False, action="store_true")
209 |
210 | # calvin
211 | parser.add_argument("--except_lang", default=False, action="store_true")
212 | # gpt2
213 | parser.add_argument("--transformer_layers", default=12, type=int)
214 | parser.add_argument("--hidden_dim", default=384, type=int)
215 | parser.add_argument("--transformer_heads", default=12, type=int)
216 | # pretrain, finetune, evaluate
217 | parser.add_argument('--phase', required=True, help='pretrain, finetune, evaluate')
218 | # libero
219 | parser.add_argument("--libero_path", default="/ailab/user/tianyang/Code/LIBERO")
220 | parser.add_argument("--libero_img_size", default=128, type=int)
221 | parser.add_argument("--libero_eval_max_steps", default=600, type=int)
222 | parser.add_argument("--gripper_width", default=False, action="store_true")
223 | parser.add_argument("--load_libero_file", type=str, default="h5")
224 | parser.add_argument("--eval_libero_ensembling", default=False, action="store_true")
225 | parser.add_argument("--ensembling_temp", default=0.01, type=float)
226 | # real
227 | parser.add_argument("--real_dataset_names", type=str)
228 | parser.add_argument("--use_aug_data", default=False, action="store_true")
229 | parser.add_argument("--real_eval_max_steps", default=600, type=int)
230 | # preprocess
231 | parser.add_argument("--max_rel_pos", type=float, default=0.02)
232 | parser.add_argument("--max_rel_orn", type=float, default=0.05)
233 | parser.add_argument("--magic_scaling_factor_pos", type=float, default=1.0)
234 | parser.add_argument("--magic_scaling_factor_orn", type=float, default=1.0)
235 | # for eval
236 | if is_eval:
237 | parser.add_argument("--calvin_conf_path", type=str, help="path to calvin configuration file")
238 | parser.add_argument("--future_act_len", default=-1, type=int)
239 | parser.add_argument(
240 | "--visualize",
241 | default=False,
242 | action="store_true"
243 | )
244 | parser.add_argument(
245 | "--reset",
246 | default=False,
247 | action="store_true"
248 | )
249 | parser.add_argument(
250 | "--diverse_inst",
251 | default=False,
252 | action="store_true"
253 | )
254 | parser.add_argument("--pad_length", type=int, default=-1)
255 | parser.add_argument("--window_size", type=int, default=13)
256 | parser.add_argument("--vit_checkpoint_path", type=str)
257 |     # parsing is left to the caller, e.g. args = get_parser().parse_args()
258 | 
259 |     return parser
260 |
261 | # if args.dataloading_type == "seer":
262 | # if args.phase == "pretrain":
263 | # if args.finetune_type == "calvin":
264 | # args.window_size = args.sequence_length + args.future_steps
265 | # else:
266 | # args.window_size = args.sequence_length
267 | # elif args.phase == "finetune":
268 | # args.window_size = args.sequence_length + args.future_steps
--------------------------------------------------------------------------------
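
Usage sketch (not a file in this repository): a minimal example of the helpers above, assuming the repository root is on PYTHONPATH; the environment values are made up to show which variables world_info_from_env probes, and the two flags passed to parse_args are the only required ones.

    import os
    from utils.arguments_utils import get_parser, world_info_from_env, random_seed

    # torchrun-style variables; the SLURM_* / MPI_* / OMPI_* names are probed as fallbacks
    os.environ.update({"LOCAL_RANK": "0", "RANK": "3", "WORLD_SIZE": "8"})
    local_rank, global_rank, world_size = world_info_from_env()
    print(local_rank, global_rank, world_size)  # -> 0 3 8

    # only --save_checkpoint_path and --phase are required; everything else has a default
    parser = get_parser()
    args = parser.parse_args(["--save_checkpoint_path", "ckpts", "--phase", "pretrain"])
    random_seed(args.seed, rank=global_rank)  # offset the seed by rank so workers decorrelate
    print(args.batch_size, args.sequence_length, args.action_pred_steps)  # defaults: 1 10 1

Under torchrun, LOCAL_RANK / RANK / WORLD_SIZE are set automatically; the SLURM and MPI variable names let the same code run unchanged under srun or mpirun.
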
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | from collections import OrderedDict
5 | import numpy as np
6 | import torch
7 | import wandb
8 | import clip
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.distributed.elastic.multiprocessing.errors import record
11 | from transformers import (
12 | get_constant_schedule_with_warmup,
13 | get_cosine_schedule_with_warmup,
14 | get_linear_schedule_with_warmup,
15 | )
16 | from models.seer_model import SeerAgent
17 | from utils.train_utils import get_checkpoint, train_one_epoch_calvin, get_ckpt_name
18 | from utils.arguments_utils import get_parser
19 | from utils.data_utils import get_calvin_dataset, get_calvin_val_dataset, get_droid_dataset, get_libero_pretrain_dataset, get_libero_finetune_dataset, get_real_finetune_dataset, get_oxe_dataset
20 | from utils.distributed_utils import init_distributed_device, world_info_from_env
21 |
22 |
23 | def random_seed(seed=42, rank=0):
24 | torch.manual_seed(seed + rank)
25 | np.random.seed(seed + rank)
26 | random.seed(seed + rank)
27 |
28 | def count_parameters(model):
29 | total_params = 0
30 | trainable_params = 0
31 | for param in model.parameters():
32 | total_params += param.numel()
33 | if param.requires_grad:
34 | trainable_params += param.numel()
35 | return total_params, trainable_params
36 |
37 | @record
38 | def main(args):
39 | os.environ["WANDB_DIR"] = f"{os.path.abspath(args.save_checkpoint_path)}"
40 | if args.save_checkpoints_to_wandb and args.save_checkpoint and not args.report_to_wandb:
41 | raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
42 | if args.offline:
43 | os.environ["WANDB_MODE"] = "offline"
44 | os.environ["TRANSFORMERS_OFFLINE"] = "1"
45 | args.local_rank, args.rank, args.world_size = world_info_from_env()
46 | device_id = init_distributed_device(args)
47 | print("device_id: ", device_id)
48 | random_seed(args.seed)
49 | ptbs = args.world_size * args.batch_size * args.gradient_accumulation_steps
50 | print("training batch size:", ptbs)
51 | args.run_name = args.run_name.replace("Seer", f"Seer_ptbs{ptbs}_{args.transformer_layers}layers_{args.transformer_heads}heads_hd{args.hidden_dim}")
52 | print("run_name:", args.run_name)
53 | model = SeerAgent(
54 | finetune_type=args.finetune_type,
55 | clip_device=device_id,
56 | vit_checkpoint_path=args.vit_checkpoint_path,
57 | sequence_length=args.sequence_length,
58 | num_resampler_query=args.num_resampler_query,
59 | num_obs_token_per_image=args.num_obs_token_per_image,
60 | calvin_input_image_size=args.calvin_input_image_size,
61 | patch_size=args.patch_size,
62 | action_pred_steps=args.action_pred_steps,
63 | obs_pred=args.obs_pred,
64 | atten_only_obs=args.atten_only_obs,
65 | attn_robot_proprio_state=args.attn_robot_proprio_state,
66 | atten_goal=args.atten_goal,
67 | atten_goal_state=args.atten_goal_state,
68 | mask_l_obs_ratio=args.mask_l_obs_ratio,
69 | transformer_layers=args.transformer_layers,
70 | hidden_dim=args.hidden_dim,
71 | transformer_heads=args.transformer_heads,
72 | phase=args.phase,
73 | gripper_width=args.gripper_width,
74 | )
75 | if args.finetune_type == "calvin":
76 | calvin_dataset = get_calvin_dataset(args, model.image_processor, clip, epoch=0, except_lang=args.except_lang)
77 | elif args.finetune_type == "droid":
78 | calvin_dataset = get_droid_dataset(args, model.image_processor, clip, epoch=0)
79 | elif args.finetune_type == "libero_pretrain":
80 | calvin_dataset = get_libero_pretrain_dataset(args, model.image_processor, clip, epoch=0)
81 | elif args.finetune_type == "libero_finetune":
82 | calvin_dataset = get_libero_finetune_dataset(args, model.image_processor, clip, epoch=0)
83 | elif args.finetune_type == "real":
84 | calvin_dataset = get_real_finetune_dataset(args, model.image_processor, clip, epoch=0)
85 | elif args.finetune_type == "oxe":
86 | calvin_dataset = get_oxe_dataset(args, model.image_processor, clip, epoch=0)
87 | random_seed(args.seed, args.rank)
88 | print(f"Start running training on rank {args.rank}.")
89 | if args.rank == 0 and args.report_to_wandb:
90 | print("wandb_project :", args.wandb_project)
91 | print("wandb_entity :", args.wandb_entity)
92 | wandb.init(
93 | project=args.wandb_project,
94 | entity=args.wandb_entity,
95 | name=args.run_name,
96 | config=vars(args),
97 | )
98 | device_id = args.rank % torch.cuda.device_count()
99 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16":
100 | model = model.bfloat16()
101 | elif args.precision == "fp16":
102 | model = model.half()
103 | elif args.precision == "fp32":
104 | model = model.float()
105 | if 'vision_encoder' in args.bf16_module:
106 | model.vision_encoder.bfloat16()
107 | if "causal_transformer" in args.bf16_module:
108 | model.transformer_backbone.bfloat16()
109 | if "image_decoder" in args.bf16_module:
110 | model.image_decoder.bfloat16()
111 | model.image_decoder_obs_pred_projector.bfloat16()
112 | model.clip_model.requires_grad_(False)
113 | model.vision_encoder.requires_grad_(False)
114 | total_params, trainable_params = count_parameters(model)
115 |     print("total_params: {:.2f} M".format(total_params / 1e6))
116 |     print("trainable_params: {:.2f} M".format(trainable_params / 1e6))
117 | model = model.to(device_id)
118 | model._init_model_type()
119 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)
120 |     optimizer = torch.optim.AdamW([p for p in ddp_model.parameters() if p.requires_grad], lr=args.learning_rate, weight_decay=args.weight_decay) # TODO: make sure the parameters that need to be optimized are passed in
121 | total_training_steps = calvin_dataset.dataloader.num_batches * args.num_epochs
122 | args.warmup_steps = calvin_dataset.dataloader.num_batches * args.warmup_epochs
123 | if args.rank == 0:
124 | print(f"Total training steps: {total_training_steps}")
125 | if args.lr_scheduler == "linear":
126 | if args.gradient_accumulation_steps > 1:
127 | lr_scheduler = get_linear_schedule_with_warmup(
128 | optimizer,
129 | num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps + 1,
130 | num_training_steps=total_training_steps // args.gradient_accumulation_steps + 1,
131 | )
132 | else:
133 | lr_scheduler = get_linear_schedule_with_warmup(
134 | optimizer,
135 | num_warmup_steps=args.warmup_steps,
136 | num_training_steps=total_training_steps,
137 | )
138 | elif args.lr_scheduler == "cosine":
139 | if args.gradient_accumulation_steps > 1:
140 | lr_scheduler = get_cosine_schedule_with_warmup(
141 | optimizer,
142 | num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps + 1,
143 | num_training_steps=total_training_steps // args.gradient_accumulation_steps + 1,
144 | )
145 | else:
146 | lr_scheduler = get_cosine_schedule_with_warmup(
147 | optimizer,
148 | num_warmup_steps=args.warmup_steps,
149 | num_training_steps=total_training_steps,
150 | )
151 | elif args.lr_scheduler == 'cosine_restart':
152 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-7)
153 | else:
154 | lr_scheduler = get_constant_schedule_with_warmup(
155 | optimizer, num_warmup_steps=args.warmup_steps
156 | )
157 | resume_from_epoch = 0
158 | if args.finetune_from_pretrained_ckpt is not None:
159 | if args.rank == 0:
160 | print(f"Starting finetuning from pretrained checkpoint {args.finetune_from_pretrained_ckpt}")
161 | checkpoint = torch.load(args.finetune_from_pretrained_ckpt, map_location="cpu")
162 | image_decoder_keys = [k for k in checkpoint["model_state_dict"].keys() if "image_decoder" in k]
163 | projector_keys = [k for k in checkpoint["model_state_dict"].keys() if "projector" in k]
164 | action_decoder_keys = [k for k in checkpoint["model_state_dict"].keys() if "action_decoder" in k]
165 | if args.reset_action_token:
166 | del checkpoint["model_state_dict"]["module.action_pred_token"]
167 | if args.reset_obs_token:
168 | del checkpoint["model_state_dict"]["module.obs_tokens"]
169 | if args.reset_mask_token:
170 | del checkpoint["model_state_dict"]["module.mask_token"]
171 | if args.reset_image_decoder:
172 | for k in image_decoder_keys:
173 | if k in checkpoint["model_state_dict"]:
174 | del checkpoint["model_state_dict"][k]
175 | if args.reset_action_decoder:
176 | for k in action_decoder_keys:
177 | if k in checkpoint["model_state_dict"]:
178 | del checkpoint["model_state_dict"][k]
179 | if checkpoint["model_state_dict"]["module.transformer_backbone_position_embedding"].shape != ddp_model.module.transformer_backbone_position_embedding.shape:
180 | checkpoint["model_state_dict"]["module.transformer_backbone_position_embedding"] = checkpoint["model_state_dict"]["module.transformer_backbone_position_embedding"][:, :args.sequence_length, :, :]
181 | print("loading pretrained weights :", checkpoint["model_state_dict"].keys())
182 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
183 | if args.resume_from_checkpoint is not None:
184 | if args.rank == 0:
185 | print(f"Loading checkpoint from {args.resume_from_checkpoint}")
186 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
187 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
188 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
189 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
190 | resume_from_epoch = checkpoint["epoch"] + 1
191 |
192 | ckpt_dir = os.path.join(f"{args.save_checkpoint_path}", args.run_name)
193 | if args.rank == 0 and not os.path.exists(ckpt_dir):
194 | os.makedirs(ckpt_dir)
195 |
196 | ddp_model.train()
197 | for epoch in range(resume_from_epoch, args.num_epochs):
198 | calvin_dataset.set_epoch(epoch)
199 | calvin_loader = calvin_dataset.dataloader
200 | train_one_epoch_calvin(
201 | args=args,
202 | model=ddp_model,
203 | epoch=epoch,
204 | optimizer=optimizer,
205 | lr_scheduler=lr_scheduler,
206 | calvin_loader=calvin_loader,
207 | device_id=device_id,
208 | wandb=wandb,
209 | )
210 | if args.rank == 0 and args.save_checkpoint and epoch % args.save_checkpoint_seq == 0 and epoch > args.start_save_checkpoint:
211 | checkpoint_dict = {
212 | "epoch": epoch,
213 | "model_state_dict": get_checkpoint(ddp_model),
214 | "optimizer_state_dict": optimizer.state_dict(),
215 | "lr_scheduler_state_dict": lr_scheduler.state_dict(),
216 | }
217 | ckpt_name = get_ckpt_name(args, epoch)
218 | ckpt_path = os.path.join(ckpt_dir, ckpt_name)
219 | print(f"Saving checkpoint to {ckpt_path}")
220 | torch.save(checkpoint_dict, ckpt_path)
221 |             if args.delete_previous_checkpoint and epoch > 0:
222 |                 prev_ckpt_path = os.path.join(ckpt_dir, get_ckpt_name(args, epoch - args.save_checkpoint_seq))
223 |                 if os.path.exists(prev_ckpt_path): os.remove(prev_ckpt_path)  # delete the previous save, not the checkpoint just written
224 |
225 | if __name__ == "__main__":
226 | parser = get_parser()
227 | args = parser.parse_args()
228 | main(args)
229 |
--------------------------------------------------------------------------------
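
Usage sketch (not a file in this repository): the step bookkeeping main() performs before building the learning-rate scheduler, with illustrative stand-in numbers for calvin_dataset.dataloader.num_batches and the CLI arguments; the arithmetic mirrors the code above.

    # stand-ins for calvin_dataset.dataloader.num_batches and the CLI args (illustrative)
    num_batches, num_epochs, warmup_epochs = 5000, 5, 1
    world_size, batch_size, gradient_accumulation_steps = 8, 6, 4

    print(world_size * batch_size * gradient_accumulation_steps)   # 192: the "ptbs" folded into run_name

    total_training_steps = num_batches * num_epochs  # 25000 forward/backward batches
    warmup_steps = num_batches * warmup_epochs       # 5000

    # when gradient_accumulation_steps > 1, both counts handed to the linear/cosine
    # scheduler are divided by it (the +1 rounds up), matching the branches above
    print(warmup_steps // gradient_accumulation_steps + 1)          # 1251
    print(total_training_steps // gradient_accumulation_steps + 1)  # 6251
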