├── .gitignore ├── .gitmodules ├── README.md ├── config ├── dataset │ ├── bridge_statistics.json │ └── fractal_statistics.json ├── experiment │ ├── simpler │ │ ├── magma_bridge_ev.yaml │ │ ├── magma_bridge_ev_lang1 copy.yaml │ │ ├── magma_bridge_ev_lang1.yaml │ │ ├── magma_bridge_ev_lang2.yaml │ │ ├── magma_bridge_ev_ood.yaml │ │ ├── octo_base_bridge_ev.yaml │ │ ├── octo_small_bridge_ev.yaml │ │ ├── pi0_baseline_bridge_ev.yaml │ │ ├── pi0_baseline_bridge_ev_lang1.yaml │ │ ├── pi0_baseline_bridge_ev_lang2.yaml │ │ ├── pi0_baseline_bridge_ev_ood.yaml │ │ ├── pi0_baseline_freezevlm_ev1.yaml │ │ ├── pi0_baseline_paraphrase_ev1.yaml │ │ ├── pi0_finetune_bridge_ev.yaml │ │ ├── pi0_finetune_bridge_ev_lang1.yaml │ │ ├── pi0_finetune_bridge_ev_lang2.yaml │ │ ├── pi0_finetune_bridge_ev_ood.yaml │ │ ├── pi0_rephrase_ft_bridge_ev1.yaml │ │ ├── spatialvla_finetune_bridge_ev.yaml │ │ ├── spatialvla_finetune_bridge_ev_lang1.yaml │ │ ├── spatialvla_finetune_bridge_ev_lang2.yaml │ │ └── spatialvla_finetune_bridge_ev_ood.yaml │ └── simplerMS3 │ │ ├── pi0_baseline_bridge_ev.yaml │ │ ├── pi0_finetune_bridge_ev.yaml │ │ └── pi0fast_finetune_bridge_ev.yaml ├── models │ ├── hf_pi0.json │ ├── magma.json │ ├── octo.json │ ├── pi0_baseline_bridge.json │ ├── pi0_baseline_fractal.json │ ├── pi0_finetune_bridge.json │ ├── pi0fast_baseline_bridge.json │ ├── pi0fast_finetune_bridge.json │ └── spatialvla_finetune_bridge.json └── train │ ├── pi0_baseline_bridge.yaml │ ├── pi0_baseline_bridge_freezevlm.yaml │ ├── pi0_baseline_bridge_paraphrase.yaml │ ├── pi0_baseline_fractal.yaml │ ├── pi0_finetune_bridge.yaml │ └── pi0_finetune_bridge_paraphrase.yaml ├── doc ├── evaluation.md └── training_finetuning.md ├── packages └── policy-server-client │ ├── README.md │ ├── pyproject.toml │ └── src │ └── policy_server_client │ ├── __init__.py │ ├── base_policy.py │ ├── image_tools.py │ ├── msgpack_numpy.py │ ├── websocket_policy_client.py │ └── websocket_policy_server.py ├── pyproject.toml ├── scripts ├── dataset │ ├── modify_rlds_dataset.py │ ├── rename_rlds.sh │ ├── rlds2lerobot.py │ ├── test_lerobot_dataset.py │ └── test_rlds_dataset.py ├── eval │ ├── experiment_vis.py │ ├── gather_data_delta.py │ ├── gather_data_to_csv.py │ ├── gather_data_to_wandb.py │ ├── get_eval_obj_bbox.py │ └── test_evaluator.sh ├── singularity │ ├── cpu_activate_singularity.sh │ └── gpu_activate_singularity.sh └── training │ └── test_nyu_training.sh ├── slurms ├── dataset_scripts │ ├── convert_proprio.sh │ ├── resize_jpeg.sh │ ├── rlds2lerobot.sh │ └── test_rlds_dataset.sh ├── eval_scripts │ └── simpler │ │ ├── ev_baselinepi0_bridge_lang1_simpler.sh │ │ ├── ev_baselinepi0_bridge_lang2_simple.sh │ │ ├── ev_baselinepi0_bridge_ood_simpler.sh │ │ ├── ev_baselinepi0_bridge_simpler.sh │ │ ├── ev_baselinepi0_freezevlm_bridge_simpler_1.sh │ │ ├── ev_baselinepi0_rephrase_bridge_simpler_1.sh │ │ ├── ev_magma_bridge_lang1_simpler.sh │ │ ├── ev_magma_bridge_lang2_simpler.sh │ │ ├── ev_magma_bridge_ood_simpler.sh │ │ ├── ev_magma_bridge_simpler.sh │ │ ├── ev_octo_base_bridge_simpler.sh │ │ ├── ev_octo_small_bridge_simpler.sh │ │ ├── ev_pi0_bridge_lang1_simpler.sh │ │ ├── ev_pi0_bridge_lang2_simpler.sh │ │ ├── ev_pi0_bridge_ood_simpler.sh │ │ ├── ev_pi0_bridge_simpler.sh │ │ ├── ev_pi0_rephrase_ft_bridge_simpler_part1.sh │ │ ├── ev_spatialvla_bridge_lang1_simpler.sh │ │ ├── ev_spatialvla_bridge_lang2_simpler.sh │ │ ├── ev_spatialvla_bridge_ood_simpler.sh │ │ └── ev_spatialvla_bridge_simpler.sh └── train_scripts │ ├── pi0_baseline_bridge.sh │ ├── pi0_baseline_fractal.sh │ 
├── pi0_baseline_freezevlm_bridge.sh │ ├── pi0_baseline_rephrase_bridge.sh │ ├── pi0_finetune_bridge.sh │ └── pi0_finetune_bridge_paraphrase.sh └── src ├── __init__.py ├── agent ├── __init__.py ├── configuration_pipeline.py ├── dataset.py ├── run.py └── trainer.py ├── data ├── dlimp │ ├── __init__.py │ ├── augmentations.py │ ├── dataset.py │ ├── transforms │ │ ├── __init__.py │ │ ├── common.py │ │ ├── frame_transforms.py │ │ ├── goal_relabeling.py │ │ └── traj_transforms.py │ └── utils.py ├── obs_transforms.py ├── oxe │ ├── __init__.py │ ├── oxe_dataset_configs.py │ ├── oxe_dataset_mixes.py │ ├── oxe_standardization_transforms.py │ └── preprocess │ │ ├── mod_functions.py │ │ └── multithreaded_adhoc_tfds_builder.py ├── rlds_dataset.py ├── rlds_dataset_torch.py ├── traj_transforms.py └── utils │ ├── data_utils.py │ ├── goal_relabeling.py │ ├── task_augmentation.py │ └── text_processing.py ├── experiments ├── env_adapters │ ├── base.py │ ├── language_mapper.py │ ├── libero.py │ ├── simpler.py │ └── simplerMS3.py ├── envs │ ├── base_evaluator.py │ ├── libero │ │ ├── libero_evaluator.py │ │ └── pyproject.toml │ ├── simpler │ │ ├── pyproject.toml │ │ └── simpler_evaluator.py │ └── simplerMS3 │ │ ├── pyproject.toml │ │ └── simplerMS3_evaluator.py └── policies │ ├── magma_policy_server │ └── pyproject.toml │ ├── octo_policy_server │ └── pyproject.toml │ └── policy_wrapper.py ├── model ├── magma │ ├── configuration_magma.py │ └── modeling_magma.py ├── octo │ ├── configuration_octo.py │ └── modeling_octo.py └── spatialvla │ ├── configuration_spatialvla.py │ └── modeling_spatialvla.py └── utils ├── decorator.py ├── geometry.py ├── metric.py ├── monitor.py ├── optim.py ├── pipeline.py └── spec.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.ext3 2 | hf_pretrained_ckpts/ 3 | cache/ 4 | hf-models/ 5 | uv.lock 6 | .vscode/ 7 | *.csv 8 | 9 | # Dataset. This doesn't ignore other folders named 'data' 10 | /data/ 11 | 12 | *.mp4 13 | *.png 14 | set_path.sh 15 | 16 | # Hugging Face model checkpoints and safetensor files 17 | *.safetensors 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | .venv_cp 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | uv.lock 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # Ruff 152 | .ruff_cache/ 153 | 154 | # Auth Tokens / Hidden Files 155 | .hf_token 156 | .wandb_api_key 157 | .*_token 158 | .*api_key 159 | 160 | # IDE Caches 161 | .idea/ 162 | .vscode/ 163 | 164 | # Mac OS 165 | .DS_Store 166 | 167 | # Caches 168 | cache/ 169 | 170 | # Rollout videos and wandb logs 171 | rollouts/ 172 | wandb/ 173 | log/ 174 | experiments/logs/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/LIBERO"] 2 | path = third_party/LIBERO 3 | url = https://github.com/Lifelong-Robot-Learning/LIBERO.git 4 | branch = master 5 | [submodule "third_party/lerobot"] 6 | path = third_party/lerobot 7 | url = https://github.com/IrvingF7/lerobot.git 8 | branch = pi0 9 | [submodule "third_party/SimplerEnv"] 10 | path = third_party/SimplerEnv 11 | url = https://github.com/juexZZ/SimplerEnv.git 12 | branch = my-changes-after-allen 13 | [submodule "third_party/SimplerEnv_MS3"] 14 | path = third_party/SimplerEnv_MS3 15 | url = https://github.com/IrvingF7/SimplerEnv_MS3.git 16 | branch = maniskill3 17 | [submodule "third_party/ManiSkill"] 18 | path = third_party/ManiSkill 19 | url = https://github.com/IrvingF7/ManiSkill.git 20 | branch = main 21 | [submodule "third_party/ManiSkill2_real2sim"] 22 | path = third_party/ManiSkill2_real2sim 23 | url = https://github.com/juexZZ/ManiSkill2_real2sim.git 24 | branch = main 25 | -------------------------------------------------------------------------------- /config/dataset/bridge_statistics.json: 
-------------------------------------------------------------------------------- 1 | {"action": {"mean": [0.00021758403454441577, 0.00012507825158536434, -0.00017109014152083546, -0.0001617111702216789, -0.0002524859446566552, 0.0002515816013328731, 0.5879487991333008], "std": [0.009632210247218609, 0.013500974513590336, 0.012510341592133045, 0.028145477175712585, 0.03028254210948944, 0.07585873454809189, 0.4877150356769562], "max": [0.41691166162490845, 0.25864794850349426, 0.21218234300613403, 3.122201919555664, 1.8618112802505493, 6.280478477478027, 1.0], "min": [-0.4007510244846344, -0.13874775171279907, -0.22553899884223938, -3.2010786533355713, -1.8618112802505493, -6.279075622558594, 0.0], "p99": [0.028122276067733765, 0.040630316659808145, 0.03994889184832546, 0.08121915772557152, 0.07724379181861864, 0.20214049845933896, 1.0], "p01": [-0.028539552688598632, -0.041432044506073, -0.025977383628487588, -0.08020886614918708, -0.09213060349225997, -0.2054861941933632, 0.0]}, "num_transitions": 2195527, "num_trajectories": 60064, "proprio": {"mean": [0.30904945731163025, 0.03045589290559292, 0.06558273732662201, 0.00706630339846015, -0.07828629016876221, 0.10661222040653229, 0.7149746417999268], "std": [0.06059328466653824, 0.09172434359788895, 0.05185756832361221, 0.1313914805650711, 0.1698099821805954, 0.573583722114563, 0.3517141044139862], "max": [0.5862360596656799, 0.4034728705883026, 0.36494991183280945, 1.514088749885559, 1.570796251296997, 3.1415255069732666, 1.1154625415802002], "min": [-0.04167502000927925, -0.3945816159248352, -0.15537554025650024, -3.141592502593994, -1.4992541074752808, -3.14153790473938, 0.04637829214334488], "p99": [0.4527312242984769, 0.23490807592868757, 0.1973453593254087, 0.37877989292144754, 0.27723048210143925, 1.8378053522109963, 1.0105689764022827], "p01": [0.17017078369855881, -0.16965715914964677, -0.054787094071507454, -0.3655692100524902, -0.5435487496852874, -1.3501438736915587, 0.052190229296684265]}} 2 | -------------------------------------------------------------------------------- /config/dataset/fractal_statistics.json: -------------------------------------------------------------------------------- 1 | {"action": {"mean": [0.006987567059695721, 0.006265868898481131, -0.012625112198293209, 0.04333272576332092, -0.005756245460361242, 0.0009130232501775026, 0.5354204773902893], "std": [0.06921152025461197, 0.05971040576696396, 0.07353048771619797, 0.156105175614357, 0.1316440999507904, 0.14593836665153503, 0.4971115291118622], "max": [2.9984593391418457, 22.09052848815918, 2.7507524490356445, 1.570636510848999, 1.5321086645126343, 1.5691522359848022, 1.0], "min": [-2.0204520225524902, -5.497899532318115, -2.031663417816162, -1.569917917251587, -1.569892168045044, -1.570419430732727, 0.0], "p99": [0.17824687153100965, 0.14938379630446405, 0.21842354819178575, 0.5892666035890578, 0.35272657424211445, 0.44796681255102094, 1.0], "p01": [-0.22453527510166169, -0.14820013284683228, -0.231589707583189, -0.3517994859814644, -0.4193011274933815, -0.43643461108207704, 0.0]}, "num_transitions": 3786400, "num_trajectories": 87212, "proprio": {"mean": [0.5599020719528198, -0.08333852887153625, 0.7770926356315613, -0.24803675711154938, 0.49517107009887695, 0.09266142547130585, 0.20975486934185028, 0.42613455653190613], "std": [0.12432780861854553, 0.11558882147073746, 0.24595776200294495, 0.5126982927322388, 0.5218101143836975, 0.16630391776561737, 0.2754841148853302, 0.45544859766960144], "max": [1.0534898042678833, 0.48018959164619446, 
1.6896663904190063, 0.9999993443489075, 0.9999874830245972, 0.9554369449615479, 0.9914546012878418, 1.0], "min": [-0.4436439275741577, -0.9970501065254211, -0.006579156965017319, -0.8643477559089661, -0.7079970240592957, -0.7688722014427185, -0.4999994933605194, 0.0], "p99": [0.8750156319141384, 0.21247054174542404, 1.0727112340927123, 0.9377871316671368, 0.9563051050901409, 0.45990042358636823, 0.7216041100025177, 1.0], "p01": [0.32481380939483645, -0.28334290891885755, 0.14107070609927178, -0.686474204659462, -0.6808923494815826, -0.36045596331357954, -0.454380963742733, 0.0]}} 2 | -------------------------------------------------------------------------------- /config/experiment/simpler/magma_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: magma_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/magma.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerMagmaAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generalization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | # * new 47 | "widowx_orange_juice_on_plate_clean", 48 | "widowx_orange_juice_on_plate_distract", 49 | "widowx_orange_juice_on_plate_lang_neg", 50 | "widowx_orange_juice_on_plate_lang_common", 51 | "widowx_orange_juice_on_plate_lang_common_distract", 52 | "widowx_orange_juice_on_plate_lang_common_distractv2", 53 | "widowx_nut_on_plate_clean", 54 | "widowx_nut_on_plate_lang_common", 55 | "widowx_eggplant_on_keyboard_clean", 56 | "widowx_carrot_on_ramekin_clean", 57 | "widowx_carrot_on_wheel_clean", 58 | "widowx_coke_can_on_ramekin_clean", 59 | "widowx_coke_can_on_wheel_clean", 60 | "widowx_nut_on_wheel_clean", 61 | "widowx_cube_on_plate_lang_shape", 62 | "widowx_spoon_on_towel_lang_neg", 63 | 
"widowx_carrot_on_plate_lang_color", 65 | ] 66 | 67 | n_eval_episode: 24 68 | n_video: 24 69 | recording: True 70 | pretrained_model_path: "microsoft/Magma-8B" 71 | action_step: 1 # OpenVLA-like models usually use action emsemble to fold 4 future steps in to 1 72 | 73 | env: 74 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 75 | 76 | wandb: 77 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/magma_bridge_ev_lang1 copy.yaml: -------------------------------------------------------------------------------- 1 | name: magma_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/magma.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerMagmaAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generatization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: "microsoft/Magma-8B" 52 | action_step: 1 # OpenVLA-like models usually use action emsemble to fold 4 future steps in to 1 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/magma_bridge_ev_lang1.yaml: -------------------------------------------------------------------------------- 1 | name: magma_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/magma.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerMagmaAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # 
"widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generatization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: "microsoft/Magma-8B" 52 | action_step: 1 # OpenVLA-like models usually use action emsemble to fold 4 future steps in to 1 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/magma_bridge_ev_lang2.yaml: -------------------------------------------------------------------------------- 1 | name: magma_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/magma.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerMagmaAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generatization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # # * language 
variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: "microsoft/Magma-8B" 52 | action_step: 1 # OpenVLA-like models usually use action ensembling to fold 4 future steps into 1 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/magma_bridge_ev_ood.yaml: -------------------------------------------------------------------------------- 1 | name: magma_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/magma.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerMagmaAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # * generalization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: 
"microsoft/Magma-8B" 52 | action_step: 1 # OpenVLA-like models usually use action emsemble to fold 4 future steps in to 1 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/octo_base_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: octo_base_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/octo.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerOctoAdapter" 8 | task_list: [ 9 | # "widowx_spoon_on_towel", 10 | # "widowx_carrot_on_plate", 11 | # "widowx_stack_cube", 12 | # "widowx_put_eggplant_in_basket", 13 | # # * generatization test 14 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 15 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 16 | # "widowx_coke_can_on_plate_clean", # ood source 17 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 18 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 19 | # "widowx_eggplant_on_sponge_clean", 20 | # "widowx_carrot_on_keyboard_clean", # ood target 21 | # "widowx_coke_can_on_keyboard_clean", 22 | # # * object distraction 23 | # "widowx_spoon_on_towel_distract", 24 | # "widowx_carrot_on_plate_distract", 25 | # "widowx_carrot_on_keyboard_distract", 26 | # "widowx_coke_can_on_plate_distract", 27 | # "widowx_coke_can_on_keyboard_distract", 28 | # # * language variation 29 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 30 | # "widowx_carrot_on_plate_lang_action", 31 | # "widowx_carrot_on_plate_lang_neg", 32 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 33 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 34 | # "widowx_spoon_on_towel_lang_action", 35 | # "widowx_spoon_on_towel_lang_common", 36 | # "widowx_spoon_on_towel_lang_common_distract", 37 | # "widowx_stack_cube_lang_action", 38 | # "widowx_eggplant_in_basket_lang_action", 39 | # "widowx_eggplant_in_basket_lang_color", 40 | # "widowx_eggplant_in_basket_lang_common", 41 | # "widowx_carrot_on_keyboard_lang_common", 42 | # "widowx_coke_can_on_plate_lang_common", 43 | # "widowx_coke_can_on_plate_lang_neg", 44 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 45 | "widowx_orange_juice_on_plate_clean", 46 | "widowx_orange_juice_on_plate_distract", 47 | "widowx_orange_juice_on_plate_lang_neg", 48 | "widowx_orange_juice_on_plate_lang_common", 49 | "widowx_orange_juice_on_plate_lang_common_distract", 50 | "widowx_orange_juice_on_plate_lang_common_distractv2", 51 | "widowx_nut_on_plate_clean", 52 | "widowx_nut_on_plate_lang_common", 53 | "widowx_eggplant_on_keyboard_clean", 54 | "widowx_carrot_on_ramekin_clean", 55 | "widowx_carrot_on_wheel_clean", 56 | "widowx_coke_can_on_ramekin_clean", 57 | "widowx_coke_can_on_wheel_clean", 58 | "widowx_nut_on_wheel_clean", 59 | "widowx_cube_on_plate_lang_shape", 60 | "widowx_spoon_on_towel_lang_neg", 61 | "widowx_spoon_on_towel_lang_color", 62 | "widowx_carrot_on_plate_lang_color", 63 | ] 64 | 65 | n_eval_episode: 24 66 | n_video: 24 67 | recording: True 68 | pretrained_model_path: "rail-berkeley/octo-base" 69 | action_step: 1 # octo use action emsemble to fold 4 future steps in to 1 70 | 71 | env: 72 | dataset_statistics_path: 
./config/dataset/bridge_statistics.json 73 | image_size: [256, 256] 74 | action_normalization_type: "gaussian" 75 | wandb: 76 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/octo_small_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: octo_small_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/octo.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerOctoAdapter" 8 | task_list: [ 9 | # "widowx_spoon_on_towel", 10 | # "widowx_carrot_on_plate", 11 | # "widowx_stack_cube", 12 | # "widowx_put_eggplant_in_basket", 13 | # # * generalization test 14 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 15 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 16 | # "widowx_coke_can_on_plate_clean", # ood source 17 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 18 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 19 | # "widowx_eggplant_on_sponge_clean", 20 | # "widowx_carrot_on_keyboard_clean", # ood target 21 | # "widowx_coke_can_on_keyboard_clean", 22 | # # * object distraction 23 | # "widowx_spoon_on_towel_distract", 24 | # "widowx_carrot_on_plate_distract", 25 | # "widowx_carrot_on_keyboard_distract", 26 | # "widowx_coke_can_on_plate_distract", 27 | # "widowx_coke_can_on_keyboard_distract", 28 | # # * language variation 29 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 30 | # "widowx_carrot_on_plate_lang_action", 31 | # "widowx_carrot_on_plate_lang_neg", 32 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 33 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 34 | # "widowx_spoon_on_towel_lang_action", 35 | # "widowx_spoon_on_towel_lang_common", 36 | # "widowx_spoon_on_towel_lang_common_distract", 37 | # "widowx_stack_cube_lang_action", 38 | # "widowx_eggplant_in_basket_lang_action", 39 | # "widowx_eggplant_in_basket_lang_color", 40 | # "widowx_eggplant_in_basket_lang_common", 41 | # "widowx_carrot_on_keyboard_lang_common", 42 | # "widowx_coke_can_on_plate_lang_common", 43 | # "widowx_coke_can_on_plate_lang_neg", 44 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 45 | "widowx_orange_juice_on_plate_clean", 46 | "widowx_orange_juice_on_plate_distract", 47 | "widowx_orange_juice_on_plate_lang_neg", 48 | "widowx_orange_juice_on_plate_lang_common", 49 | "widowx_orange_juice_on_plate_lang_common_distract", 50 | "widowx_orange_juice_on_plate_lang_common_distractv2", 51 | "widowx_nut_on_plate_clean", 52 | "widowx_nut_on_plate_lang_common", 53 | "widowx_eggplant_on_keyboard_clean", 54 | "widowx_carrot_on_ramekin_clean", 55 | "widowx_carrot_on_wheel_clean", 56 | "widowx_coke_can_on_ramekin_clean", 57 | "widowx_coke_can_on_wheel_clean", 58 | "widowx_nut_on_wheel_clean", 59 | "widowx_cube_on_plate_lang_shape", 60 | "widowx_spoon_on_towel_lang_neg", 61 | "widowx_spoon_on_towel_lang_color", 62 | "widowx_carrot_on_plate_lang_color", 63 | ] 64 | 65 | n_eval_episode: 24 66 | n_video: 24 67 | recording: True 68 | pretrained_model_path: "rail-berkeley/octo-small" 69 | action_step: 1 # Octo uses action ensembling to fold 4 future steps into 1 70 | 71 | env: 72 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 73 | image_size: [256, 256] 74 | action_normalization_type: "gaussian" 75 | 
wandb: 76 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_baseline_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generalization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | # * new 47 | "widowx_orange_juice_on_plate_clean", 48 | "widowx_orange_juice_on_plate_distract", 49 | "widowx_orange_juice_on_plate_lang_neg", 50 | "widowx_orange_juice_on_plate_lang_common", 51 | "widowx_orange_juice_on_plate_lang_common_distract", 52 | "widowx_orange_juice_on_plate_lang_common_distractv2", 53 | "widowx_nut_on_plate_clean", 54 | "widowx_nut_on_plate_lang_common", 55 | "widowx_eggplant_on_keyboard_clean", 56 | "widowx_carrot_on_ramekin_clean", 57 | "widowx_carrot_on_wheel_clean", 58 | "widowx_coke_can_on_ramekin_clean", 59 | "widowx_coke_can_on_wheel_clean", 60 | "widowx_nut_on_wheel_clean", 61 | "widowx_cube_on_plate_lang_shape", 62 | "widowx_spoon_on_towel_lang_neg", 63 | "widowx_spoon_on_towel_lang_color", 64 | "widowx_carrot_on_plate_lang_color", 65 | ] 66 | 67 | n_eval_episode: 24 68 | n_video: 24 69 | recording: True 70 | pretrained_model_path: ./log/train/pi0_baseline_bridge/2025-03-04_00-00-38-42/checkpoint 71 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 72 | 73 | env: 74 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 75 | 76 | wandb: 77 | project: "vla_benchmark" -------------------------------------------------------------------------------- 
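A note on the `env` block shared by the eval configs above: `dataset_statistics_path` points at the per-dataset statistics JSON shown earlier (mean/std plus p01/p99 quantiles over actions and proprio), and the Octo configs additionally select `action_normalization_type: "gaussian"`. A policy trained on normalized actions must have its outputs denormalized with the same statistics before they reach the simulator. The helper below is a minimal hypothetical sketch of that bookkeeping, not this repo's actual API (the real logic presumably lives in the env adapters under src/experiments/env_adapters/ and src/data/utils/data_utils.py):

```python
# Hypothetical sketch of consuming config/dataset/bridge_statistics.json.
# Function names here are illustrative, not the repo's own.
import json
import numpy as np

def load_action_stats(path: str) -> dict:
    with open(path) as f:
        return json.load(f)["action"]  # keys: mean, std, max, min, p99, p01

def denormalize(action: np.ndarray, stats: dict, kind: str = "bound") -> np.ndarray:
    """Map a normalized policy output back to the raw action space.

    kind="gaussian": inverse of (a - mean) / std, matching
                     action_normalization_type: "gaussian" in the Octo configs.
    kind="bound":    inverse of scaling the [p01, p99] range to [-1, 1],
                     a common quantile-based alternative.
    """
    if kind == "gaussian":
        return action * np.asarray(stats["std"]) + np.asarray(stats["mean"])
    low, high = np.asarray(stats["p01"]), np.asarray(stats["p99"])
    return 0.5 * (action + 1.0) * (high - low) + low

stats = load_action_stats("./config/dataset/bridge_statistics.json")
print(denormalize(np.zeros(7), stats, kind="gaussian"))  # recovers the dataset mean
```

(In practice the gripper dimension is often left unnormalized; whether that applies here is adapter-specific.)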
/config/experiment/simpler/pi0_baseline_bridge_ev_lang1.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generalization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: ./log/train/pi0_baseline_bridge/2025-03-04_00-00-38-42/checkpoint 52 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_baseline_bridge_ev_lang2.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generalization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # 
"widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: ./log/train/pi0_baseline_bridge/2025-03-04_00-00-38-42/checkpoint 52 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_baseline_bridge_ev_ood.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # * generatization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # 
"widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: ./log/train/pi0_baseline_bridge/2025-03-04_00-00-38-42/checkpoint 52 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_baseline_freezevlm_ev1.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline_freeze_vlm 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generatization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 34 | "widowx_spoon_on_towel_lang_action", 35 | "widowx_spoon_on_towel_lang_common", 36 | "widowx_spoon_on_towel_lang_common_distract", 37 | "widowx_stack_cube_lang_action", 38 | "widowx_eggplant_in_basket_lang_action", 39 | "widowx_eggplant_in_basket_lang_color", 40 | "widowx_eggplant_in_basket_lang_common", 41 | "widowx_carrot_on_keyboard_lang_common", 42 | "widowx_coke_can_on_plate_lang_common", 43 | "widowx_coke_can_on_plate_lang_neg", 44 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 45 | 46 | "widowx_orange_juice_on_plate_clean", 47 | "widowx_orange_juice_on_plate_distract", 48 | "widowx_orange_juice_on_plate_lang_neg", 49 | "widowx_orange_juice_on_plate_lang_common", 50 | "widowx_orange_juice_on_plate_lang_common_distract", 51 | "widowx_orange_juice_on_plate_lang_common_distractv2", 52 | "widowx_nut_on_plate_clean", 53 | "widowx_nut_on_plate_lang_common", 54 | "widowx_eggplant_on_keyboard_clean", 55 | "widowx_carrot_on_ramekin_clean", 56 | "widowx_carrot_on_wheel_clean", 57 | "widowx_coke_can_on_ramekin_clean", 58 | "widowx_coke_can_on_wheel_clean", 59 | 
"widowx_nut_on_wheel_clean", 60 | "widowx_cube_on_plate_lang_shape", 61 | "widowx_spoon_on_towel_lang_neg", 62 | "widowx_spoon_on_towel_lang_color", 63 | "widowx_carrot_on_plate_lang_color", 64 | ] 65 | 66 | n_eval_episode: 24 67 | n_video: 24 68 | recording: True 69 | pretrained_model_path: ./log/train/pi0_baseline_bridge_freeze_vlm/2025-05-19_01-27-14_42/checkpoint/ 70 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 71 | 72 | env: 73 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 74 | 75 | wandb: 76 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_baseline_paraphrase_ev1.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline_paraphrase 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generatization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 34 | "widowx_spoon_on_towel_lang_action", 35 | "widowx_spoon_on_towel_lang_common", 36 | "widowx_spoon_on_towel_lang_common_distract", 37 | "widowx_stack_cube_lang_action", 38 | "widowx_eggplant_in_basket_lang_action", 39 | "widowx_eggplant_in_basket_lang_color", 40 | "widowx_eggplant_in_basket_lang_common", 41 | "widowx_carrot_on_keyboard_lang_common", 42 | "widowx_coke_can_on_plate_lang_common", 43 | "widowx_coke_can_on_plate_lang_neg", 44 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 45 | 46 | "widowx_orange_juice_on_plate_clean", 47 | "widowx_orange_juice_on_plate_distract", 48 | "widowx_orange_juice_on_plate_lang_neg", 49 | "widowx_orange_juice_on_plate_lang_common", 50 | "widowx_orange_juice_on_plate_lang_common_distract", 51 | "widowx_orange_juice_on_plate_lang_common_distractv2", 52 | "widowx_nut_on_plate_clean", 53 | "widowx_nut_on_plate_lang_common", 54 | "widowx_eggplant_on_keyboard_clean", 55 | "widowx_carrot_on_ramekin_clean", 56 | "widowx_carrot_on_wheel_clean", 57 | "widowx_coke_can_on_ramekin_clean", 58 | "widowx_coke_can_on_wheel_clean", 59 | "widowx_nut_on_wheel_clean", 60 | "widowx_cube_on_plate_lang_shape", 61 | "widowx_spoon_on_towel_lang_neg", 62 | "widowx_spoon_on_towel_lang_color", 63 | "widowx_carrot_on_plate_lang_color", 64 | ] 65 | 66 | n_eval_episode: 24 67 | 
n_video: 24 68 | recording: True 69 | pretrained_model_path: ./log/train/pi0_baseline_bridge_paraphrase/2025-05-16_23-37-18_42/checkpoint/ 70 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 71 | 72 | env: 73 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 74 | 75 | wandb: 76 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_finetune_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_finetune 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generalization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | # * new 47 | "widowx_orange_juice_on_plate_clean", 48 | "widowx_orange_juice_on_plate_distract", 49 | "widowx_orange_juice_on_plate_lang_neg", 50 | "widowx_orange_juice_on_plate_lang_common", 51 | "widowx_orange_juice_on_plate_lang_common_distract", 52 | "widowx_orange_juice_on_plate_lang_common_distractv2", 53 | "widowx_nut_on_plate_clean", 54 | "widowx_nut_on_plate_lang_common", 55 | "widowx_eggplant_on_keyboard_clean", 56 | "widowx_carrot_on_ramekin_clean", 57 | "widowx_carrot_on_wheel_clean", 58 | "widowx_coke_can_on_ramekin_clean", 59 | "widowx_coke_can_on_wheel_clean", 60 | "widowx_nut_on_wheel_clean", 61 | "widowx_cube_on_plate_lang_shape", 62 | "widowx_spoon_on_towel_lang_neg", 63 | "widowx_spoon_on_towel_lang_color", 64 | "widowx_carrot_on_plate_lang_color", 65 | ] 66 | 67 | n_eval_episode: 24 68 | n_video: 24 69 | recording: True 70 | pretrained_model_path: ./log/train/pi0_finetune_bridge/2025-03-11_22-42-06-42/checkpoint 71 | 
pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 72 | 73 | env: 74 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 75 | 76 | wandb: 77 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_finetune_bridge_ev_lang1.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_finetune 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generalization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: ./log/train/pi0_finetune_bridge/2025-03-11_22-42-06-42/checkpoint 52 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_finetune_bridge_ev_lang2.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_finetune 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # 
"widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # # * generatization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: ./log/train/pi0_finetune_bridge/2025-03-11_22-42-06-42/checkpoint 52 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_finetune_bridge_ev_ood.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_finetune 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # * generatization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 
32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: ./log/train/pi0_finetune_bridge/2025-03-11_22-42-06-42/checkpoint 52 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 53 | 54 | env: 55 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 56 | 57 | wandb: 58 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/pi0_rephrase_ft_bridge_ev1.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_finetune_paraphrase 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generatization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 34 | "widowx_spoon_on_towel_lang_action", 35 | "widowx_spoon_on_towel_lang_common", 36 | "widowx_spoon_on_towel_lang_common_distract", 37 | "widowx_stack_cube_lang_action", 38 | "widowx_eggplant_in_basket_lang_action", 39 | "widowx_eggplant_in_basket_lang_color", 40 | "widowx_eggplant_in_basket_lang_common", 41 | "widowx_carrot_on_keyboard_lang_common", 42 | "widowx_coke_can_on_plate_lang_common", 43 | "widowx_coke_can_on_plate_lang_neg", 44 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 45 | 46 | "widowx_orange_juice_on_plate_clean", 47 | "widowx_orange_juice_on_plate_distract", 48 | "widowx_orange_juice_on_plate_lang_neg", 49 | "widowx_orange_juice_on_plate_lang_common", 50 | "widowx_orange_juice_on_plate_lang_common_distract", 51 | "widowx_orange_juice_on_plate_lang_common_distractv2", 52 
| "widowx_nut_on_plate_clean", 53 | "widowx_nut_on_plate_lang_common", 54 | "widowx_eggplant_on_keyboard_clean", 55 | "widowx_carrot_on_ramekin_clean", 56 | "widowx_carrot_on_wheel_clean", 57 | "widowx_coke_can_on_ramekin_clean", 58 | "widowx_coke_can_on_wheel_clean", 59 | "widowx_nut_on_wheel_clean", 60 | "widowx_cube_on_plate_lang_shape", 61 | "widowx_spoon_on_towel_lang_neg", 62 | "widowx_spoon_on_towel_lang_color", 63 | "widowx_carrot_on_plate_lang_color", 64 | ] 65 | 66 | n_eval_episode: 24 67 | n_video: 24 68 | recording: True 69 | pretrained_model_path: ./log/train/pi0_finetune_paraphrase/2025-05-11_13-25-23_42/checkpoint 70 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 71 | 72 | env: 73 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 74 | 75 | wandb: 76 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/spatialvla_finetune_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: sptialvla_finetune_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/spatialvla_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerSpatialVLAAdapter" 8 | task_list: [ 9 | # * original tasks 10 | "widowx_spoon_on_towel", 11 | "widowx_carrot_on_plate", 12 | "widowx_stack_cube", 13 | "widowx_put_eggplant_in_basket", 14 | # * object distraction 15 | "widowx_spoon_on_towel_distract", 16 | "widowx_carrot_on_plate_distract", 17 | "widowx_carrot_on_keyboard_distract", 18 | "widowx_coke_can_on_plate_distract", 19 | "widowx_coke_can_on_keyboard_distract", 20 | # * generatization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | * new 47 | "widowx_orange_juice_on_plate_clean", 48 | "widowx_orange_juice_on_plate_distract", 49 | "widowx_orange_juice_on_plate_lang_neg", 50 | "widowx_orange_juice_on_plate_lang_common", 51 | "widowx_orange_juice_on_plate_lang_common_distract", 52 | "widowx_orange_juice_on_plate_lang_common_distractv2", 53 | "widowx_nut_on_plate_clean", 54 | "widowx_nut_on_plate_lang_common", 55 | "widowx_eggplant_on_keyboard_clean", 56 | 
"widowx_carrot_on_ramekin_clean", 57 | "widowx_carrot_on_wheel_clean", 58 | "widowx_coke_can_on_ramekin_clean", 59 | "widowx_coke_can_on_wheel_clean", 60 | "widowx_nut_on_wheel_clean", 61 | "widowx_cube_on_plate_lang_shape", 62 | "widowx_spoon_on_towel_lang_neg", 63 | "widowx_spoon_on_towel_lang_color", 64 | "widowx_carrot_on_plate_lang_color", 65 | ] 66 | 67 | n_eval_episode: 24 68 | n_video: 24 69 | recording: True 70 | pretrained_model_path: "IPEC-COMMUNITY/spatialvla-4b-224-sft-bridge" 71 | action_step: 1 # OpenVLA-like models usually use action emsemble to fold 4 future steps in to 1 72 | unnorm_key: "bridge_orig/1.0.0" 73 | 74 | env: 75 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 76 | 77 | wandb: 78 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/spatialvla_finetune_bridge_ev_lang1.yaml: -------------------------------------------------------------------------------- 1 | name: sptialvla_finetune_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/spatialvla_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerSpatialVLAAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # * generatization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | "widowx_carrot_on_plate_lang_action", 32 | "widowx_carrot_on_plate_lang_neg", 33 | "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | "widowx_spoon_on_towel_lang_action", 36 | "widowx_spoon_on_towel_lang_common", 37 | "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: "IPEC-COMMUNITY/spatialvla-4b-224-sft-bridge" 52 | action_step: 1 53 | unnorm_key: "bridge_orig/1.0.0" 54 | 55 | env: 56 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 57 | 58 | wandb: 59 | project: "vla_benchmark" -------------------------------------------------------------------------------- 
/config/experiment/simpler/spatialvla_finetune_bridge_ev_lang2.yaml: -------------------------------------------------------------------------------- 1 | name: spatialvla_finetune_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/spatialvla_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerSpatialVLAAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # # # * generalization test 21 | # "widowx_cube_on_plate_clean", # seen source and target, unseen combination 22 | # "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | # "widowx_coke_can_on_plate_clean", # ood source 24 | # "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | # "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | # "widowx_eggplant_on_sponge_clean", 27 | # "widowx_carrot_on_keyboard_clean", # ood target 28 | # "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | "widowx_stack_cube_lang_action", 39 | "widowx_eggplant_in_basket_lang_action", 40 | "widowx_eggplant_in_basket_lang_color", 41 | "widowx_eggplant_in_basket_lang_common", 42 | "widowx_carrot_on_keyboard_lang_common", 43 | "widowx_coke_can_on_plate_lang_common", 44 | "widowx_coke_can_on_plate_lang_neg", 45 | "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: "IPEC-COMMUNITY/spatialvla-4b-224-sft-bridge" 52 | action_step: 1 53 | unnorm_key: "bridge_orig/1.0.0" 54 | 55 | env: 56 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 57 | 58 | wandb: 59 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simpler/spatialvla_finetune_bridge_ev_ood.yaml: -------------------------------------------------------------------------------- 1 | name: spatialvla_finetune_bridge_simpler 2 | seed: 42 3 | model_cfg: !include ../../models/spatialvla_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simpler" 7 | env_adapter: "BridgeSimplerSpatialVLAAdapter" 8 | task_list: [ 9 | # # * original tasks 10 | # "widowx_spoon_on_towel", 11 | # "widowx_carrot_on_plate", 12 | # "widowx_stack_cube", 13 | # "widowx_put_eggplant_in_basket", 14 | # # * object distraction 15 | # "widowx_spoon_on_towel_distract", 16 | # "widowx_carrot_on_plate_distract", 17 | # "widowx_carrot_on_keyboard_distract", 18 | # "widowx_coke_can_on_plate_distract", 19 | # "widowx_coke_can_on_keyboard_distract", 20 | # * generalization test 21 | "widowx_cube_on_plate_clean", # seen source and target, unseen
combination 22 | "widowx_small_plate_on_green_cube_clean", # seen source and target, unseen combination (a reverse) 23 | "widowx_coke_can_on_plate_clean", # ood source 24 | "widowx_pepsi_on_plate_clean", # another OOD source besides coke can, also a texture difference 25 | "widowx_carrot_on_sponge_clean", # seen source and target, unseen combination 26 | "widowx_eggplant_on_sponge_clean", 27 | "widowx_carrot_on_keyboard_clean", # ood target 28 | "widowx_coke_can_on_keyboard_clean", 29 | # # * language variation 30 | # "widowx_carrot_on_plate_lang_common", # rabbit, no distract 31 | # "widowx_carrot_on_plate_lang_action", 32 | # "widowx_carrot_on_plate_lang_neg", 33 | # "widowx_carrot_on_plate_lang_neg_action", # on the table not on the plate 34 | # "widowx_carrot_on_plate_lang_common_distract", # rabbit 35 | # "widowx_spoon_on_towel_lang_action", 36 | # "widowx_spoon_on_towel_lang_common", 37 | # "widowx_spoon_on_towel_lang_common_distract", 38 | # "widowx_stack_cube_lang_action", 39 | # "widowx_eggplant_in_basket_lang_action", 40 | # "widowx_eggplant_in_basket_lang_color", 41 | # "widowx_eggplant_in_basket_lang_common", 42 | # "widowx_carrot_on_keyboard_lang_common", 43 | # "widowx_coke_can_on_plate_lang_common", 44 | # "widowx_coke_can_on_plate_lang_neg", 45 | # "widowx_coke_can_on_plate_lang_common_distract", # thirsty 46 | ] 47 | 48 | n_eval_episode: 24 49 | n_video: 24 50 | recording: True 51 | pretrained_model_path: "IPEC-COMMUNITY/spatialvla-4b-224-sft-bridge" 52 | action_step: 1 53 | unnorm_key: "bridge_orig/1.0.0" 54 | 55 | env: 56 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 57 | 58 | wandb: 59 | project: "vla_benchmark" -------------------------------------------------------------------------------- /config/experiment/simplerMS3/pi0_baseline_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_baseline_bridge_simplerMS3_rotation_fix 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_baseline_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simplerMS3" 7 | env_adapter: "BridgeSimplerBatchAdapter" 8 | task_list: [ 9 | # "widowx_carrot_on_plate", 10 | # "widowx_put_eggplant_in_basket", 11 | # "widowx_spoon_on_towel", 12 | # "widowx_stack_cube", 13 | "widowx_coke_can_on_plate", 14 | # "widowx_eggplant_on_carrot" 15 | ] 16 | 17 | n_eval_episode: 240 18 | n_video: 240 19 | n_parallel_eval: 60 20 | recording: True 21 | pretrained_model_path: ./log/train/pi0_baseline_bridge/2025-03-04_00-00-38-42/checkpoint 22 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 23 | 24 | env: 25 | dataset_statistics_path: ./config/dataset/bridge_statistics.json 26 | -------------------------------------------------------------------------------- /config/experiment/simplerMS3/pi0_finetune_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: pi0_finetune_bridge_simplerMS3_rotation_fix 2 | seed: 42 3 | model_cfg: !include ../../models/pi0_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simplerMS3" 7 | env_adapter: "BridgeSimplerBatchAdapter" 8 | task_list: [ 9 | # "widowx_carrot_on_plate", 10 | # "widowx_put_eggplant_in_basket", 11 | # "widowx_spoon_on_towel", 12 | # "widowx_stack_cube", 13 | "widowx_coke_can_on_plate", 14 | # "widowx_eggplant_on_carrot" 15 | ] 16 | 17 | n_eval_episode: 240 18 | n_video: 240 19 | n_parallel_eval: 60 20 | recording: True 21 | pretrained_model_path: 
./log/train/pi0_finetune_bridge/2025-03-11_22-42-06-42/checkpoint 22 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 23 | 24 | env: 25 | dataset_statistics_path: ./config/dataset/bridge_statistics.json -------------------------------------------------------------------------------- /config/experiment/simplerMS3/pi0fast_finetune_bridge_ev.yaml: -------------------------------------------------------------------------------- 1 | name: pi0fast_finetune_bridge_simplerMS3 2 | seed: 42 3 | model_cfg: !include ../../models/pi0fast_finetune_bridge.json 4 | 5 | eval_cfg: 6 | simulator_name: "simplerMS3" 7 | env_adapter: "BridgeSimplerBatchAdapter" 8 | task_list: ["widowx_carrot_on_plate", 9 | "widowx_put_eggplant_in_basket", 10 | "widowx_spoon_on_towel", 11 | "widowx_stack_cube"] 12 | 13 | n_eval_episode: 240 14 | n_video: 240 15 | n_parallel_eval: 60 16 | recording: True 17 | pretrained_model_path: ./log/train/pi0fast_finetune_bridge/2025-04-27_05-31-14_42/checkpoint 18 | pretrained_model_gradient_step_cnt: [1513, 3026, 4539, 7565, 15130, 22695] 19 | 20 | env: 21 | dataset_statistics_path: ./config/dataset/bridge_statistics.json -------------------------------------------------------------------------------- /config/models/hf_pi0.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "pi0", 3 | "n_obs_steps": 1, 4 | "normalization_mapping": { 5 | "VISUAL": "IDENTITY", 6 | "STATE": "MEAN_STD", 7 | "ACTION": "MEAN_STD" 8 | }, 9 | "input_features": { 10 | "observation.images.top": { 11 | "shape": [ 12 | 3, 13 | 224, 14 | 224 15 | ], 16 | "type": "VISUAL" 17 | }, 18 | "observation.images.left": { 19 | "shape": [ 20 | 3, 21 | 224, 22 | 224 23 | ], 24 | "type": "VISUAL" 25 | }, 26 | "observation.images.right": { 27 | "shape": [ 28 | 3, 29 | 224, 30 | 224 31 | ], 32 | "type": "VISUAL" 33 | }, 34 | "observation.state": { 35 | "shape": [ 36 | 7 37 | ], 38 | "type": "STATE" 39 | } 40 | }, 41 | "output_features": { 42 | "action": { 43 | "type": "ACTION", 44 | "shape": [ 45 | 7 46 | ] 47 | } 48 | }, 49 | "chunk_size": 4, 50 | "n_action_steps": 4, 51 | "max_state_dim": 32, 52 | "max_action_dim": 32, 53 | "resize_imgs_with_padding": [ 54 | 224, 55 | 224 56 | ], 57 | "empty_cameras": 0, 58 | "adapt_to_pi_aloha": false, 59 | "use_delta_joint_actions_aloha": false, 60 | "tokenizer_max_length": 72, 61 | "proj_width": 1024, 62 | "num_steps": 10, 63 | "use_cache": true, 64 | "attention_implementation": "eager", 65 | "freeze_vision_encoder": false, 66 | "train_expert_only": false, 67 | "train_state_proj": true, 68 | "optimizer_lr": 2.5e-05, 69 | "optimizer_betas": [ 70 | 0.9, 71 | 0.95 72 | ], 73 | "optimizer_eps": 1e-08, 74 | "optimizer_weight_decay": 1e-10, 75 | "scheduler_warmup_steps": 1000, 76 | "scheduler_decay_steps": 30000, 77 | "scheduler_decay_lr": 2.5e-06 78 | } -------------------------------------------------------------------------------- /config/models/magma.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "magma", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "chunk_size": 4 7 | } -------------------------------------------------------------------------------- /config/models/octo.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "octo", 3 | "n_obs_steps": 2, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "chunk_size": 4 7 | } 
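The experiment and training YAMLs above splice these model JSON files in through the `!include` tag, which is not part of standard YAML. A minimal sketch of how such a tag can be registered with PyYAML follows; the `IncludeLoader` name and the path-resolution details are assumptions for illustration, and this repo's actual loader may differ.

import json
import pathlib

import yaml


class IncludeLoader(yaml.SafeLoader):
    """Hypothetical loader that resolves !include relative to the including file."""


def _include(loader: IncludeLoader, node: yaml.Node):
    # loader.name holds the path of the stream being parsed.
    base = pathlib.Path(loader.name).parent
    target = (base / loader.construct_scalar(node)).resolve()
    with open(target) as f:
        return json.load(f) if target.suffix == ".json" else yaml.load(f, IncludeLoader)


IncludeLoader.add_constructor("!include", _include)

with open("config/experiment/simpler/pi0_finetune_bridge_ev.yaml") as f:
    cfg = yaml.load(f, IncludeLoader)  # cfg["model_cfg"] is now the parsed JSON dict

Keeping each model definition in a single JSON file this way lets training and evaluation configs share it instead of duplicating hyperparameters.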
-------------------------------------------------------------------------------- /config/models/pi0_baseline_bridge.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "pi0", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "normalization_mapping": { 7 | "VISUAL": "IDENTITY", 8 | "STATE": "IDENTITY", 9 | "ACTION": "IDENTITY" 10 | }, 11 | "input_features": { 12 | "observation.images.top": { 13 | "shape": [ 14 | 3, 15 | 224, 16 | 224 17 | ], 18 | "type": "VISUAL" 19 | }, 20 | "observation.state": { 21 | "shape": [ 22 | 7 23 | ], 24 | "type": "STATE" 25 | } 26 | }, 27 | "output_features": { 28 | "action": { 29 | "type": "ACTION", 30 | "shape": [ 31 | 7 32 | ] 33 | } 34 | }, 35 | "chunk_size": 4, 36 | "n_action_steps": 4, 37 | "max_state_dim": 7, 38 | "max_action_dim": 7, 39 | "resize_imgs_with_padding": [ 40 | 224, 41 | 224 42 | ], 43 | "empty_cameras": 0, 44 | "adapt_to_pi_aloha": false, 45 | "use_delta_joint_actions_aloha": false, 46 | "tokenizer_max_length": 72, 47 | "proj_width": 1024, 48 | "num_steps": 10, 49 | "use_cache": true, 50 | "attention_implementation": "eager", 51 | "freeze_vision_encoder": false, 52 | "train_expert_only": false, 53 | "train_state_proj": true, 54 | "optimizer_lr": 5e-5, 55 | "optimizer_betas": [ 56 | 0.9, 57 | 0.999 58 | ], 59 | "optimizer_eps": 1e-8, 60 | "optimizer_weight_decay": 0, 61 | "scheduler_warmup_steps": 200, 62 | "scheduler_decay_steps": 30000, 63 | "scheduler_decay_lr": 2.5e-06 64 | } -------------------------------------------------------------------------------- /config/models/pi0_baseline_fractal.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "pi0", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "normalization_mapping": { 7 | "VISUAL": "IDENTITY", 8 | "STATE": "IDENTITY", 9 | "ACTION": "IDENTITY" 10 | }, 11 | "input_features": { 12 | "observation.images.top": { 13 | "shape": [ 14 | 3, 15 | 224, 16 | 224 17 | ], 18 | "type": "VISUAL" 19 | }, 20 | "observation.state": { 21 | "shape": [ 22 | 8 23 | ], 24 | "type": "STATE" 25 | } 26 | }, 27 | "output_features": { 28 | "action": { 29 | "type": "ACTION", 30 | "shape": [ 31 | 7 32 | ] 33 | } 34 | }, 35 | "chunk_size": 4, 36 | "n_action_steps": 4, 37 | "max_state_dim": 8, 38 | "max_action_dim": 7, 39 | "resize_imgs_with_padding": [ 40 | 224, 41 | 224 42 | ], 43 | "empty_cameras": 0, 44 | "adapt_to_pi_aloha": false, 45 | "use_delta_joint_actions_aloha": false, 46 | "tokenizer_max_length": 72, 47 | "proj_width": 1024, 48 | "num_steps": 10, 49 | "use_cache": true, 50 | "attention_implementation": "eager", 51 | "freeze_vision_encoder": false, 52 | "train_expert_only": false, 53 | "train_state_proj": true, 54 | "optimizer_lr": 5e-5, 55 | "optimizer_betas": [ 56 | 0.9, 57 | 0.999 58 | ], 59 | "optimizer_eps": 1e-8, 60 | "optimizer_weight_decay": 0, 61 | "scheduler_warmup_steps": 200, 62 | "scheduler_decay_steps": 30000, 63 | "scheduler_decay_lr": 2.5e-06 64 | } -------------------------------------------------------------------------------- /config/models/pi0_finetune_bridge.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "pi0", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "normalization_mapping": { 7 | "VISUAL": "IDENTITY", 8 | "STATE": "IDENTITY", 9 | "ACTION": "IDENTITY" 10 | }, 11 | "input_features": { 12 | "observation.images.top": { 13 | "shape": [ 14 | 3, 15 | 224, 16 | 224 
17 | ], 18 | "type": "VISUAL" 19 | }, 20 | "observation.state": { 21 | "shape": [ 22 | 7 23 | ], 24 | "type": "STATE" 25 | } 26 | }, 27 | "output_features": { 28 | "action": { 29 | "type": "ACTION", 30 | "shape": [ 31 | 7 32 | ] 33 | } 34 | }, 35 | "chunk_size": 4, 36 | "n_action_steps": 4, 37 | "max_state_dim": 32, 38 | "max_action_dim": 32, 39 | "resize_imgs_with_padding": [ 40 | 224, 41 | 224 42 | ], 43 | "empty_cameras": 0, 44 | "adapt_to_pi_aloha": false, 45 | "use_delta_joint_actions_aloha": false, 46 | "tokenizer_max_length": 72, 47 | "proj_width": 1024, 48 | "num_steps": 10, 49 | "use_cache": true, 50 | "attention_implementation": "eager", 51 | "freeze_vision_encoder": false, 52 | "train_expert_only": false, 53 | "train_state_proj": true, 54 | "paligemma_pretrained_path": null, 55 | "optimizer_lr": 5e-5, 56 | "optimizer_betas": [ 57 | 0.9, 58 | 0.999 59 | ], 60 | "optimizer_eps": 1e-8, 61 | "optimizer_weight_decay": 0, 62 | "scheduler_warmup_steps": 200, 63 | "scheduler_decay_steps": 30000, 64 | "scheduler_decay_lr": 2.5e-06 65 | } -------------------------------------------------------------------------------- /config/models/pi0fast_baseline_bridge.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "pi0fast", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "normalization_mapping": { 7 | "VISUAL": "IDENTITY", 8 | "STATE": "IDENTITY", 9 | "ACTION": "IDENTITY" 10 | }, 11 | "input_features": { 12 | "observation.images.top": { 13 | "shape": [ 14 | 3, 15 | 224, 16 | 224 17 | ], 18 | "type": "VISUAL" 19 | }, 20 | "observation.state": { 21 | "shape": [ 22 | 7 23 | ], 24 | "type": "STATE" 25 | } 26 | }, 27 | "output_features": { 28 | "action": { 29 | "type": "ACTION", 30 | "shape": [ 31 | 7 32 | ] 33 | } 34 | }, 35 | "chunk_size": 4, 36 | "n_action_steps": 4, 37 | "max_state_dim": 7, 38 | "max_action_dim": 7, 39 | "resize_imgs_with_padding": [ 40 | 224, 41 | 224 42 | ], 43 | "interpolate_like_pi": false, 44 | "empty_cameras": 0, 45 | "adapt_to_pi_aloha": false, 46 | "use_delta_joint_actions_aloha": false, 47 | "tokenizer_max_length": 72, 48 | "proj_width": 1024, 49 | "max_decoding_steps": 256, 50 | "fast_skip_tokens": 128, 51 | "max_input_seq_len": 256, 52 | "use_cache": true, 53 | "freeze_vision_encoder": false, 54 | "freeze_lm_head": true, 55 | "optimizer_lr": 5e-5, 56 | "optimizer_betas": [ 57 | 0.9, 58 | 0.999 59 | ], 60 | "optimizer_eps": 1e-8, 61 | "optimizer_weight_decay": 0, 62 | "scheduler_warmup_steps": 200, 63 | "scheduler_decay_steps": 30000, 64 | "scheduler_decay_lr": 2.5e-06 65 | } 66 | -------------------------------------------------------------------------------- /config/models/pi0fast_finetune_bridge.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "pi0fast", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "normalization_mapping": { 7 | "VISUAL": "IDENTITY", 8 | "STATE": "IDENTITY", 9 | "ACTION": "IDENTITY" 10 | }, 11 | "input_features": { 12 | "observation.images.top": { 13 | "shape": [ 14 | 3, 15 | 224, 16 | 224 17 | ], 18 | "type": "VISUAL" 19 | }, 20 | "observation.state": { 21 | "shape": [ 22 | 7 23 | ], 24 | "type": "STATE" 25 | } 26 | }, 27 | "output_features": { 28 | "action": { 29 | "type": "ACTION", 30 | "shape": [ 31 | 7 32 | ] 33 | } 34 | }, 35 | "chunk_size": 4, 36 | "n_action_steps": 4, 37 | "max_state_dim": 32, 38 | "max_action_dim": 32, 39 | "resize_imgs_with_padding": [ 40 | 224, 41 | 224 42 | ], 43 | 
"interpolate_like_pi": false, 44 | "empty_cameras": 0, 45 | "adapt_to_pi_aloha": false, 46 | "use_delta_joint_actions_aloha": false, 47 | "tokenizer_max_length": 72, 48 | "proj_width": 1024, 49 | "max_decoding_steps": 256, 50 | "fast_skip_tokens": 128, 51 | "max_input_seq_len": 256, 52 | "use_cache": true, 53 | "freeze_vision_encoder": false, 54 | "freeze_lm_head": true, 55 | "optimizer_lr": 5e-5, 56 | "optimizer_betas": [ 57 | 0.9, 58 | 0.999 59 | ], 60 | "optimizer_eps": 1e-8, 61 | "optimizer_weight_decay": 0, 62 | "scheduler_warmup_steps": 200, 63 | "scheduler_decay_steps": 30000, 64 | "scheduler_decay_lr": 2.5e-06 65 | } 66 | -------------------------------------------------------------------------------- /config/models/spatialvla_finetune_bridge.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "spatial-vla", 3 | "n_obs_steps": 1, 4 | "device": "cpu", 5 | "use_amp": true, 6 | "chunk_size": 4 7 | } -------------------------------------------------------------------------------- /config/train/pi0_baseline_bridge.yaml: -------------------------------------------------------------------------------- 1 | # our pi0_baseline. 2 | # max_dim set to 7. language token max set to 72. all hyperparameters follows Allen 3 | name: "pi0_baseline_bridge" 4 | model_cfg: !include ../models/pi0_baseline_bridge.json 5 | freeze_lm_head: true -------------------------------------------------------------------------------- /config/train/pi0_baseline_bridge_freezevlm.yaml: -------------------------------------------------------------------------------- 1 | # our pi0_baseline. 2 | # max_dim set to 7. language token max set to 72. 3 | name: "pi0_baseline_bridge_freeze_vlm" 4 | model_cfg: !include ../models/pi0_baseline_bridge.json 5 | # freeze_lm_head: true 6 | freeze_vlm: true -------------------------------------------------------------------------------- /config/train/pi0_baseline_bridge_paraphrase.yaml: -------------------------------------------------------------------------------- 1 | # our pi0_baseline. 2 | # max_dim set to 7. language token max set to 72. all hyperparameters follows Allen 3 | name: "pi0_baseline_bridge_paraphrase" 4 | model_cfg: !include ../models/pi0_baseline_bridge.json 5 | freeze_lm_head: true 6 | task_paraphrase: true -------------------------------------------------------------------------------- /config/train/pi0_baseline_fractal.yaml: -------------------------------------------------------------------------------- 1 | # our pi0_baseline on fractal 2 | name: "pi0_baseline_fractal" 3 | model_cfg: !include ../models/pi0_baseline_fractal.json 4 | freeze_lm_head: true 5 | 6 | n_epochs: 10 7 | 8 | data: 9 | train: 10 | dataset_mix: fractal 11 | split: train[:95%] 12 | val: 13 | dataset_mix: fractal 14 | split: train[95%:] 15 | shuffle_buffer_size: 10000 16 | train_episode_count: 3786400 17 | 18 | -------------------------------------------------------------------------------- /config/train/pi0_finetune_bridge.yaml: -------------------------------------------------------------------------------- 1 | # our pi0 finetune on bridge. 2 | # max_dim set to 32. language token max set to 72. 
all hyperparameters follow Allen's settings 3 | name: "pi0_finetune" 4 | 5 | model_cfg: !include ../models/pi0_finetune_bridge.json 6 | freeze_lm_head: true 7 | load_from_checkpoint: "lerobot/pi0" -------------------------------------------------------------------------------- /config/train/pi0_finetune_bridge_paraphrase.yaml: -------------------------------------------------------------------------------- 1 | # our pi0 finetune on bridge. 2 | # max_dim set to 32. language token max set to 72. all hyperparameters follow Allen's settings 3 | name: "pi0_finetune_paraphrase" 4 | 5 | model_cfg: !include ../models/pi0_finetune_bridge.json 6 | freeze_lm_head: true 7 | load_from_checkpoint: "lerobot/pi0" 8 | task_paraphrase: true -------------------------------------------------------------------------------- /packages/policy-server-client/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4ce/INT-ACT/66ae3ed4719c68cb2f3f22868015a7a171b04e1f/packages/policy-server-client/README.md -------------------------------------------------------------------------------- /packages/policy-server-client/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "policy-server-client" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.7" 7 | dependencies = [ 8 | "msgpack>=1.0.5", 9 | "numpy>=1.21.6", 10 | "pillow>=9.0.0", 11 | "websockets>=11.0", 12 | ] 13 | 14 | [build-system] 15 | requires = ["hatchling"] 16 | build-backend = "hatchling.build" 17 | -------------------------------------------------------------------------------- /packages/policy-server-client/src/policy_server_client/__init__.py: -------------------------------------------------------------------------------- 1 | __VERSION__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /packages/policy-server-client/src/policy_server_client/base_policy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class BasePolicy(abc.ABC): 5 | @abc.abstractmethod 6 | def infer(self, obs: dict) -> dict: 7 | """Infer actions from observations.""" 8 | raise NotImplementedError("infer() not implemented") 9 | 10 | @abc.abstractmethod 11 | def reset(self) -> None: 12 | """Reset the policy to its initial state.""" 13 | raise NotImplementedError("reset() not implemented") 14 | -------------------------------------------------------------------------------- /packages/policy-server-client/src/policy_server_client/image_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/Physical-Intelligence/openpi/blob/main/packages/openpi-client/src/openpi_client/image_tools.py 3 | """ 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def convert_to_uint8(img: np.ndarray) -> np.ndarray: 10 | """Converts an image to uint8 if it is a float image. 11 | 12 | This is important for reducing the size of the image when sending it over the network. 13 | """ 14 | if np.issubdtype(img.dtype, np.floating): 15 | img = (255 * img).astype(np.uint8) 16 | return img 17 | 18 | 19 | def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray: 20 | """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height. 
21 | 22 | Args: 23 | images: A batch of images in [..., height, width, channel] format. 24 | height: The target height of the image. 25 | width: The target width of the image. 26 | method: The interpolation method to use. Default is bilinear. 27 | 28 | Returns: 29 | The resized images in [..., height, width, channel]. 30 | """ 31 | # If the images are already the correct size, return them as is. 32 | if images.shape[-3:-1] == (height, width): 33 | return images 34 | 35 | original_shape = images.shape 36 | 37 | images = images.reshape(-1, *original_shape[-3:]) 38 | resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images]) 39 | return resized.reshape(*original_shape[:-3], *resized.shape[-3:]) 40 | 41 | 42 | def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image: 43 | """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and 44 | width without distortion by padding with zeros. 45 | 46 | Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c]. 47 | """ 48 | cur_width, cur_height = image.size 49 | if cur_width == width and cur_height == height: 50 | return image # No need to resize if the image is already the correct size. 51 | 52 | ratio = max(cur_width / width, cur_height / height) 53 | resized_height = int(cur_height / ratio) 54 | resized_width = int(cur_width / ratio) 55 | resized_image = image.resize((resized_width, resized_height), resample=method) 56 | 57 | zero_image = Image.new(resized_image.mode, (width, height), 0) 58 | pad_height = max(0, int((height - resized_height) / 2)) 59 | pad_width = max(0, int((width - resized_width) / 2)) 60 | zero_image.paste(resized_image, (pad_width, pad_height)) 61 | assert zero_image.size == (width, height) 62 | return zero_image 63 | -------------------------------------------------------------------------------- /packages/policy-server-client/src/policy_server_client/msgpack_numpy.py: -------------------------------------------------------------------------------- 1 | """Adds NumPy array support to msgpack. 2 | 3 | msgpack is good for (de)serializing data over a network for multiple reasons: 4 | - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution) 5 | - msgpack is widely used and has good cross-language support 6 | - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed 7 | languages like Python and JavaScript 8 | - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster 9 | than pickle for serializing large arrays using the below strategy 10 | 11 | The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is 12 | that it falls back to pickle for object arrays. 
13 | """ 14 | 15 | import functools 16 | 17 | import msgpack 18 | import numpy as np 19 | 20 | 21 | def pack_array(obj): 22 | if (isinstance(obj, np.ndarray | np.generic)) and obj.dtype.kind in ("V", "O", "c"): 23 | raise ValueError(f"Unsupported dtype: {obj.dtype}") 24 | 25 | if isinstance(obj, np.ndarray): 26 | return { 27 | b"__ndarray__": True, 28 | b"data": obj.tobytes(), 29 | b"dtype": obj.dtype.str, 30 | b"shape": obj.shape, 31 | } 32 | 33 | if isinstance(obj, np.generic): 34 | return { 35 | b"__npgeneric__": True, 36 | b"data": obj.item(), 37 | b"dtype": obj.dtype.str, 38 | } 39 | 40 | return obj 41 | 42 | 43 | def unpack_array(obj): 44 | if b"__ndarray__" in obj: 45 | return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) 46 | 47 | if b"__npgeneric__" in obj: 48 | return np.dtype(obj[b"dtype"]).type(obj[b"data"]) 49 | 50 | return obj 51 | 52 | 53 | Packer = functools.partial(msgpack.Packer, default=pack_array) 54 | packb = functools.partial(msgpack.packb, default=pack_array) 55 | 56 | Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) 57 | unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) 58 | -------------------------------------------------------------------------------- /packages/policy-server-client/src/policy_server_client/websocket_policy_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/Physical-Intelligence/openpi 3 | """ 4 | 5 | import logging 6 | import time 7 | from typing import ( 8 | Dict, 9 | Tuple, 10 | ) 11 | 12 | import websockets.sync.client 13 | from typing_extensions import override 14 | 15 | from policy_server_client import base_policy as _base_policy 16 | from policy_server_client import msgpack_numpy 17 | 18 | 19 | class WebsocketPolicyClient(_base_policy.BasePolicy): 20 | """Websocket client for the policy server.""" 21 | 22 | def __init__(self, host: str, port: int): 23 | """Initialize the WebsocketPolicyClient. 24 | 25 | Args: 26 | host (str): The hostname of the policy server. 27 | port (int): The port number of the policy server. 28 | """ 29 | self.host = host 30 | self.port = port 31 | self.logger = logging.getLogger("websockets.client") 32 | self._uri = f"ws://{self.host}:{self.port}" 33 | self._ws, self._server_metadata = self._wait_for_server() 34 | self._packer = msgpack_numpy.Packer() 35 | 36 | def get_server_metadata(self) -> Dict: 37 | return self._server_metadata 38 | 39 | def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]: 40 | self.logger.info(f"Waiting for server at {self._uri}...") 41 | while True: 42 | try: 43 | conn = websockets.sync.client.connect(self._uri, 44 | compression=None, 45 | max_size=None, 46 | ping_timeout=None,) 47 | metadata = msgpack_numpy.unpackb(conn.recv()) 48 | return conn, metadata 49 | except ConnectionRefusedError: 50 | self.logger.info("Still waiting for server...") 51 | time.sleep(5) 52 | 53 | @override 54 | def infer(self, obs: Dict) -> Dict: 55 | """Infer actions from observations.""" 56 | data = self._packer.pack(obs) 57 | self._ws.send(data) 58 | response = self._ws.recv() 59 | if isinstance(response, str): 60 | # we're expecting bytes; if the server sends a string, it's an error. 
61 | raise RuntimeError(f"Error in inference server:\n{response}") 62 | return msgpack_numpy.unpackb(response) 63 | 64 | @override 65 | def reset(self) -> None: 66 | """Reset the policy and associated env adapter to its initial state.""" 67 | self._ws.send(self._packer.pack({"reset": True})) 68 | response = self._ws.recv() 69 | if isinstance(response, str): 70 | # we're expecting bytes; if the server sends a string, it's an error. 71 | raise RuntimeError(f"Error in inference server:\n{response}") 72 | return msgpack_numpy.unpackb(response) 73 | 74 | def switch_model(self, new_model_path) -> None: 75 | """Switch the model to a new checkpoint step.""" 76 | self._ws.send(self._packer.pack({"new_model_path": new_model_path})) 77 | response = self._ws.recv() 78 | if isinstance(response, str): 79 | # we're expecting bytes; if the server sends a string, it's an error. 80 | raise RuntimeError(f"Error in inference server:\n{response}") 81 | return msgpack_numpy.unpackb(response) 82 | -------------------------------------------------------------------------------- /packages/policy-server-client/src/policy_server_client/websocket_policy_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/Physical-Intelligence/openpi 3 | """ 4 | 5 | import asyncio 6 | import logging 7 | import traceback 8 | 9 | import websockets.asyncio.server 10 | import websockets.frames 11 | 12 | from policy_server_client import msgpack_numpy 13 | from src.utils.monitor import setup_logger 14 | 15 | 16 | class WebsocketPolicyServer: 17 | """Serves a policy using the websocket protocol. See websocket_policy_client.py for a client implementation. 18 | 19 | Currently implements the `infer`, `reset`, and model-switching requests. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | policy, 25 | host: str = "0.0.0.0", 26 | port: int = 8000, 27 | metadata: dict | None = None, 28 | ) -> None: 29 | self._policy = policy 30 | self._host = host 31 | self._port = port 32 | self._metadata = metadata or {} 33 | self.logger = setup_logger( 34 | main_rank=True, 35 | filename=None, 36 | name="policy_server", 37 | ) 38 | self.logger.setLevel(logging.INFO) 39 | 40 | def serve_forever(self) -> None: 41 | """Starts the server and runs it forever. 
This is a blocking call.""" 42 | self.logger.info(f"Starting server on {self._host}:{self._port}") 43 | asyncio.run(self.run()) 44 | 45 | async def run(self): 46 | async with websockets.asyncio.server.serve( 47 | self._handler, 48 | self._host, 49 | self._port, 50 | compression=None, 51 | max_size=None, 52 | ) as server: 53 | await server.serve_forever() 54 | 55 | async def _handler(self, websocket: websockets.asyncio.server.ServerConnection): 56 | self.logger.info(f"Connection from {websocket.remote_address} opened") 57 | packer = msgpack_numpy.Packer() 58 | 59 | await websocket.send(packer.pack(self._metadata)) 60 | 61 | while True: 62 | try: 63 | obs = msgpack_numpy.unpackb(await websocket.recv()) 64 | 65 | # Check if the client is requesting a new model checkpoint 66 | new_model_path = obs.get("new_model_path", None) 67 | if new_model_path is not None: 68 | self._policy.switch_model(new_model_path) 69 | self.logger.info(f"Loaded new model checkpoint: {new_model_path}") 70 | await websocket.send(packer.pack({"status": "model switched"})) 71 | continue # no actual observation will be sent with this, so we skip the rest of the loop 72 | 73 | # Check if the client is requesting a reset 74 | if obs.get("reset", False): 75 | self._policy.reset() 76 | await websocket.send(packer.pack({"status": "reset"})) 77 | continue 78 | 79 | action = self._policy.select_action(obs) 80 | 81 | await websocket.send(packer.pack(action)) 82 | except websockets.ConnectionClosed: 83 | self.logger.info(f"Connection from {websocket.remote_address} closed") 84 | break 85 | except Exception: 86 | await websocket.send(traceback.format_exc()) 87 | await websocket.close( 88 | code=websockets.frames.CloseCode.INTERNAL_ERROR, 89 | reason="Internal server error. Traceback included in previous frame.", 90 | ) 91 | raise 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "INT-ACT" 7 | version = "0.0.1" 8 | readme = "README.md" 9 | description = "Official repo of From Intention to Execution: Probing the Generalization Boundaries of Vision-Language-Action Models" 10 | authors = [ 11 | {name = "Irving Fang", email = "irving.fang@nyu.edu"}, 12 | {name = "Juexiao Zhang", email = "juexiao.zhang@nyu.edu"} 13 | ] 14 | requires-python = "==3.10.12" 15 | 16 | dependencies = [ 17 | "lerobot", 18 | "transformers==4.48.3", 19 | "opencv-python>=4.11.0", 20 | "pytest", 21 | "policy-server-client", 22 | "numpy==1.26.4", # it's important to use pre-2.0 numpy. 
2.0 has breaking changes and lots of downstream packages haven't updated yet 23 | "torch==2.6.0", 24 | "bitsandbytes", 25 | "einops", 26 | "draccus", 27 | "protobuf==3.20.3", 28 | "tensorflow==2.15.0", 29 | "tensorflow_datasets==4.9.2", 30 | "tensorflow_graphics", 31 | "evaluate", 32 | ] 33 | 34 | [tool.ruff] 35 | line-length = 170 36 | target-version = "py310" 37 | 38 | [tool.ruff.lint] 39 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 40 | ignore = ["F722"] 41 | 42 | [tool.ruff.lint.per-file-ignores] 43 | "__init__.py" = ["E402", "F401"] 44 | 45 | [tool.uv.sources] 46 | policy-server-client = { workspace = true } 47 | lerobot = { path = "./third_party/lerobot", editable = true } 48 | 49 | [tool.uv.workspace] 50 | members = ["packages/policy-server-client"] 51 | -------------------------------------------------------------------------------- /scripts/dataset/modify_rlds_dataset.py: -------------------------------------------------------------------------------- 1 | """Modifies TFDS dataset with a map function, updates the feature definition and stores new dataset.""" 2 | 3 | import os 4 | import sys 5 | from functools import partial 6 | 7 | import tensorflow_datasets as tfds 8 | 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) 10 | 11 | from src.data.oxe.preprocess.mod_functions import TFDS_MOD_FUNCTIONS 12 | from src.data.oxe.preprocess.multithreaded_adhoc_tfds_builder import ( 13 | MultiThreadedAdhocDatasetBuilder, 14 | ) 15 | 16 | # avoid GCS nonsense errors 17 | tfds.core.utils.gcs_utils._is_gcs_disabled = True 18 | os.environ["NO_GCE_CHECK"] = "true" 19 | 20 | 21 | def mod_features(mods, features): 22 | """Modifies feature dict.""" 23 | for mod in mods: 24 | features = TFDS_MOD_FUNCTIONS[mod].mod_features(features) 25 | return features 26 | 27 | 28 | def mod_dataset_generator(builder, split, mods): 29 | """Modifies dataset features.""" 30 | ds = builder.as_dataset(split=split) 31 | for mod in mods: 32 | ds = TFDS_MOD_FUNCTIONS[mod].mod_dataset(ds) 33 | for episode in tfds.core.dataset_utils.as_numpy(ds): 34 | yield episode 35 | 36 | 37 | def main(args): 38 | builder = tfds.builder(args.dataset, data_dir=args.data_dir) 39 | 40 | features = mod_features(args.mods, builder.info.features) 41 | print("############# Target features: ###############") 42 | print(features) 43 | print("##############################################") 44 | assert args.data_dir != args.target_dir # prevent overwriting original dataset 45 | 46 | mod_dataset_builder = MultiThreadedAdhocDatasetBuilder( 47 | name=args.dataset, 48 | version=builder.version, 49 | features=features, 50 | split_datasets={ 51 | split: builder.info.splits[split] for split in builder.info.splits 52 | }, 53 | config=builder.builder_config, 54 | data_dir=args.target_dir, 55 | description=builder.info.description, 56 | generator_fcn=partial(mod_dataset_generator, builder=builder, mods=args.mods), 57 | n_workers=args.n_workers, 58 | max_episodes_in_memory=args.max_episodes_in_memory, 59 | ) 60 | mod_dataset_builder.download_and_prepare() 61 | 62 | 63 | if __name__ == "__main__": 64 | import argparse 65 | 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--dataset", type=str, required=True) 68 | parser.add_argument("--data_dir", type=str, required=True) 69 | parser.add_argument("--target_dir", type=str, required=True) 70 | parser.add_argument( 71 | "--mods", type=str, nargs="+", default=["resize_and_jpeg_encode"] 72 | ) 73 | parser.add_argument("--n_workers", type=int, default=10) 74 | 
parser.add_argument("--max_episodes_in_memory", type=int, default=100) 75 | args = parser.parse_args() 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /scripts/dataset/rename_rlds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if correct number of arguments is provided 4 | if [ "$#" -ne 3 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | directory=$1 10 | old_phrase=$2 11 | new_phrase=$3 12 | 13 | # Check if directory exists 14 | if [ ! -d "$directory" ]; then 15 | echo "Error: Directory '$directory' does not exist." 16 | exit 1 17 | fi 18 | 19 | # Iterate through files in the directory 20 | for file in "$directory"/*; do 21 | if [[ -f "$file" ]]; then 22 | base_name=$(basename "$file") 23 | if [[ "$base_name" == *"$old_phrase"* ]]; then 24 | new_base_name="${base_name//$old_phrase/$new_phrase}" 25 | new_name="$directory/$new_base_name" 26 | mv "$file" "$new_name" 27 | echo "Renamed: '$file' -> '$new_name'" 28 | fi 29 | fi 30 | done 31 | 32 | echo "Renaming completed." -------------------------------------------------------------------------------- /scripts/dataset/rlds2lerobot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow_datasets as tfds 8 | import tqdm 9 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 10 | 11 | REPO_NAME = "IrvingF7/taco_play_test" # Name of the output dataset, also used for the Hugging Face Hub 12 | RAW_DATASET_NAMES = [ 13 | "taco_play" 14 | ] # For simplicity we will combine multiple Libero datasets into one training dataset 15 | DATA_DIR = Path(os.environ["VLA_DATA_DIR"]) / "resize_224" 16 | 17 | def main(data_dir: str = DATA_DIR, push_to_hub: bool = True): 18 | # Clean up any existing dataset in the output directory 19 | output_path = DATA_DIR / REPO_NAME 20 | if output_path.exists(): 21 | shutil.rmtree(output_path) 22 | 23 | # Create LeRobot dataset, define features to store 24 | # OpenPi assumes that proprio is stored in `state` and actions in `action` 25 | # LeRobot assumes that dtype of image data is `image` 26 | dataset = LeRobotDataset.create( 27 | repo_id=REPO_NAME, 28 | robot_type="panda", 29 | fps=15, 30 | features={ 31 | "image": { 32 | "dtype": "image", 33 | "shape": (224, 224, 3), 34 | "names": ["height", "width", "channel"], 35 | }, 36 | "state": { 37 | "dtype": "float32", 38 | "shape": (9,), 39 | "names": ["state"], 40 | }, 41 | "actions": { 42 | "dtype": "float32", 43 | "shape": (7,), 44 | "names": ["actions"], 45 | }, 46 | }, 47 | image_writer_threads=10, 48 | image_writer_processes=20, 49 | ) 50 | 51 | # Loop over raw Libero datasets and write episodes to the LeRobot dataset 52 | # You can modify this for your own data format 53 | for raw_dataset_name in RAW_DATASET_NAMES: 54 | raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train") 55 | print(f"current dataset: {raw_dataset_name}") 56 | for episode in tqdm.tqdm(raw_dataset): 57 | for step in episode["steps"].as_numpy_iterator(): 58 | step["action"] = step["action"]["rel_actions_world"] 59 | 60 | # clip gripper action, +1 = open, 0 = close 61 | step["action"] = np.concatenate( 62 | ( 63 | step["action"][:6], 64 | tf.clip_by_value(step["action"][-1:], 0, 1).numpy(), 65 | ) 66 | ) 67 | dataset.add_frame( 68 | { 69 | "image": 
step["observation"]["rgb_static"], 70 | "state": np.concatenate( 71 | ( 72 | step["observation"]["robot_obs"][:6], 73 | step["observation"]["robot_obs"][6:8], 74 | step["observation"]["robot_obs"][-1:], 75 | ), 76 | ), 77 | "actions": step["action"], 78 | "task": step["observation"]["natural_language_instruction"].decode(), 79 | } 80 | ) 81 | dataset.save_episode() 82 | 83 | # Optionally push to the Hugging Face Hub 84 | if push_to_hub: 85 | dataset.push_to_hub( 86 | tags=["taco_play", "panda", "rlds"], 87 | private=False, 88 | push_videos=True, 89 | license="apache-2.0", 90 | ) 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /scripts/dataset/test_lerobot_dataset.py: -------------------------------------------------------------------------------- 1 | """This scripts demonstrates how to train Diffusion Policy on the PushT environment. 2 | 3 | Once you have trained a model with this script, you can try to evaluate it on 4 | examples/2_evaluate_pretrained_policy.py 5 | """ 6 | 7 | from pathlib import Path 8 | 9 | import torch 10 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata 11 | from lerobot.common.datasets.utils import dataset_to_policy_features 12 | from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig 13 | from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy 14 | from lerobot.configs.types import FeatureType 15 | 16 | 17 | def main(): 18 | # Create a directory to store the training checkpoint. 19 | output_directory = Path("outputs/train/example_pusht_diffusion") 20 | output_directory.mkdir(parents=True, exist_ok=True) 21 | 22 | # # Select your device 23 | device = torch.device("cuda") 24 | 25 | # Number of offline training steps (we'll only do offline training for this example.) 26 | # Adjust as you prefer. 5000 steps are needed to get something worth evaluating. 27 | training_steps = 5000 28 | log_freq = 1 29 | 30 | # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before 31 | # creating the policy: 32 | # - input/output shapes: to properly size the policy 33 | # - dataset stats: for normalization and denormalization of input/outputs 34 | dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") 35 | features = dataset_to_policy_features(dataset_metadata.features) 36 | output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} 37 | input_features = {key: ft for key, ft in features.items() if key not in output_features} 38 | 39 | # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example, 40 | # we'll just use the defaults and so no arguments other than input/output features need to be passed. 41 | cfg = DiffusionConfig(input_features=input_features, output_features=output_features) 42 | 43 | # We can now instantiate our policy with this config and the dataset stats. 44 | policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats) 45 | policy.train() 46 | policy.to(device) 47 | 48 | # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames 49 | # which can differ for inputs, outputs and rewards (if there are some). 
50 | delta_timestamps = { 51 | "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], 52 | "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], 53 | "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices], 54 | } 55 | 56 | # In this case with the standard configuration for Diffusion Policy, it is equivalent to this: 57 | delta_timestamps = { 58 | # Load the previous image and state at -0.1 seconds before the current frame, 59 | # then load the current image and state corresponding to 0.0 second. 60 | "observation.image": [-0.1, 0.0], 61 | "observation.state": [-0.1, 0.0], 62 | # Load the previous action (-0.1), the next action to be executed (0.0), 63 | # and 14 future actions with a 0.1 seconds spacing. All these actions will be 64 | # used to supervise the policy. 65 | "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], 66 | } 67 | 68 | # We can then instantiate the dataset with this delta_timestamps configuration. 69 | dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps) 70 | 71 | # Then we create our optimizer and dataloader for offline training. 72 | optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4) 73 | dataloader = torch.utils.data.DataLoader( 74 | dataset, 75 | num_workers=4, 76 | batch_size=64, 77 | shuffle=True, 78 | pin_memory=device.type != "cpu", 79 | drop_last=True, 80 | ) 81 | 82 | # Run training loop. 83 | step = 0 84 | 85 | for batch in dataloader: 86 | # batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} 87 | 88 | print(batch.keys()) 89 | print("image shape", batch["observation.image"].shape) 90 | print("state shape", batch["observation.state"].shape) 91 | print("action shape", batch["action"].shape) 92 | break  # NOTE: with this break, the loop inspects a single batch only; the training code below is left over from the original lerobot example and never runs here. 93 | output_dict = policy.forward(batch) 94 | loss = output_dict["loss"] 95 | loss.backward() 96 | optimizer.step() 97 | optimizer.zero_grad() 98 | 99 | if step % log_freq == 0: 100 | print(f"step: {step} loss: {loss.item():.3f}") 101 | step += 1 102 | if step >= training_steps: 103 | break 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /scripts/dataset/test_rlds_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import draccus 5 | from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy 6 | from torch.utils.data import DataLoader 7 | 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) 9 | 10 | from src.agent.configuration_pipeline import TrainPipelineConfig 11 | from src.agent.dataset import TorchRLDSInterleavedDataset 12 | 13 | 14 | @draccus.wrap() 15 | def main(train_cfg: TrainPipelineConfig): 16 | train_dataloader = DataLoader( 17 | TorchRLDSInterleavedDataset(train_cfg.data.train, train=True).dataset, 18 | batch_size=1, 19 | pin_memory=True, 20 | ) 21 | 22 | transition_count = 0 23 | for _ in train_dataloader: 24 | transition_count += 1 25 | print(f"Total usable transitions in train dataloader: {transition_count}") 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /scripts/eval/get_eval_obj_bbox.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | 3 | # Load; force='mesh' will merge a single-mesh scene into one Trimesh 4 | mesh 
= trimesh.load('path/to/keyboard.glb', force='mesh') 5 | 6 | # bounds is a (2, 3) array: [ [xmin, ymin, zmin], [xmax, ymax, zmax] ] 7 | min_corner, max_corner = mesh.bounds 8 | 9 | xmin, ymin, zmin = min_corner 10 | xmax, ymax, zmax = max_corner 11 | 12 | print(f"X range: {xmin:.3f} → {xmax:.3f}") 13 | print(f"Y range: {ymin:.3f} → {ymax:.3f}") 14 | print(f"Z range: {zmin:.3f} → {zmax:.3f}") 15 | 16 | # If you also want the size along each axis: 17 | size_x, size_y, size_z = max_corner - min_corner 18 | print(f"Size: ΔX={size_x:.3f}, ΔY={size_y:.3f}, ΔZ={size_z:.3f}") 19 | -------------------------------------------------------------------------------- /scripts/eval/test_evaluator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Trap Ctrl+C and clean up all child processes (single quotes so $SERVER_PID expands when the trap fires, not when it is set) 4 | trap 'echo "Ctrl+C received, killing server..."; kill $SERVER_PID; exit 1' SIGINT 5 | 6 | CONFIG_NAMES=("test_simpler.yaml") 7 | 8 | for CONFIG_NAME in "${CONFIG_NAMES[@]}"; do 9 | singularity exec --nv \ 10 | --bind /usr/share/nvidia \ 11 | --bind /usr/share/glvnd/egl_vendor.d/10_nvidia.json \ 12 | --bind /usr/share/vulkan/icd.d/nvidia_icd.x86_64.json \ 13 | --overlay /scratch/zf540/pi0/pi_overlay.ext3:ro \ 14 | --overlay /scratch/work/public/singularity/vulkan-1.4.309-cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sqf:ro \ 15 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 16 | /bin/bash -c "source ~/set_path.sh; \ 17 | export PATH='/ext3/uv:$PATH'; \ 18 | source ./.venv/bin/activate; \ 19 | uv run src/agent/run.py \ 20 | --config_path config/experiment/simpler/${CONFIG_NAME} \ 21 | --use_bf16 False \ 22 | --eval_cfg.role server" & 23 | SERVER_PID=$! 24 | 25 | singularity exec --nv \ 26 | --bind /usr/share/nvidia \ 27 | --bind /usr/share/glvnd/egl_vendor.d/10_nvidia.json \ 28 | --bind /usr/share/vulkan/icd.d/nvidia_icd.x86_64.json \ 29 | --overlay /scratch/zf540/pi0/pi_overlay.ext3:ro \ 30 | --overlay /scratch/work/public/singularity/vulkan-1.4.309-cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sqf:ro \ 31 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 32 | /bin/bash -c "source ~/set_path.sh; \ 33 | export PATH='/ext3/uv:$PATH'; \ 34 | source ./src/experiments/envs/simpler_ms3/.venv/bin/activate; \ 35 | python src/agent/run.py \ 36 | --config_path config/experiment/simpler/${CONFIG_NAME} \ 37 | --use_bf16 False \ 38 | --eval_cfg.role client" 39 | kill $SERVER_PID 40 | 41 | done -------------------------------------------------------------------------------- /scripts/singularity/cpu_activate_singularity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | singularity exec --overlay pi_overlay.ext3:ro /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif /bin/bash 4 | -------------------------------------------------------------------------------- /scripts/singularity/gpu_activate_singularity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | singularity exec --nv \ 4 | --bind /usr/share/nvidia \ 5 | --bind /usr/share/glvnd/egl_vendor.d/10_nvidia.json \ 6 | --bind /usr/share/vulkan/icd.d/nvidia_icd.x86_64.json \ 7 | --overlay pi_overlay.ext3:ro \ 8 | --overlay /scratch/work/public/singularity/vulkan-1.4.309-cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sqf:ro \ 9 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif /bin/bash 10 | 
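# Usage sketch (an assumption, not part of the original script): launch from the directory that holds pi_overlay.ext3, e.g. the repo root on the cluster, then activate the project venv inside the container as the other scripts here do:
#   bash scripts/singularity/gpu_activate_singularity.sh
#   source ./.venv/bin/activate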
-------------------------------------------------------------------------------- /scripts/training/test_nyu_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CPU check 4 | TOTAL_CORES=$(nproc) 5 | echo "TOTAL_CORES=$TOTAL_CORES" 6 | 7 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 8 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 9 | echo "NUM_GPU=$NUM_GPU" 10 | 11 | # Compute OMP_NUM_THREADS (avoid division by zero) 12 | OMP_THREADS=$((TOTAL_CORES / NUM_GPU)) 13 | 14 | # Ensure OMP_NUM_THREADS is at least 1 15 | OMP_THREADS=$((OMP_THREADS > 0 ? OMP_THREADS : 1)) 16 | echo "OMP_NUM_THREADS=$OMP_THREADS" 17 | 18 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 19 | find_free_port() { 20 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 21 | } 22 | export MASTER_PORT=$(find_free_port) 23 | 24 | # export NCCL_DEBUG=INFO 25 | 26 | singularity exec --nv \ 27 | --overlay /scratch/zf540/pi0/pi_overlay.ext3:ro \ 28 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 29 | /bin/bash -c "source ~/.bashrc; \ 30 | export PATH='/ext3/uv:$PATH'; \ 31 | source ./.venv/bin/activate; \ 32 | export OMP_NUM_THREADS=$OMP_THREADS; \ 33 | uv run torchrun \ 34 | --nnodes=1 \ 35 | --nproc_per_node=$NUM_GPU \ 36 | --rdzv_id=$RANDOM \ 37 | --rdzv_backend=c10d \ 38 | --max-restarts=0 \ 39 | --standalone \ 40 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 41 | src/agent/run.py \ 42 | --config_path config/train/pi0_finetune_taco_libero.yaml \ 43 | --use_wandb False" 44 | # --debug True \ 45 | # --use_wandb False" 46 | 47 | -------------------------------------------------------------------------------- /slurms/dataset_scripts/convert_proprio.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=convert_proprio 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks=1 6 | #SBATCH --cpus-per-task=45 7 | #SBATCH --mem=200G 8 | #SBATCH --time=04:00:00 9 | #SBATCH --output=log/slurm/dataset/%x.out 10 | #SBATCH --error=log/slurm/dataset/%x.err 11 | #SBATCH --mail-type=ALL 12 | #SBATCH --mail-user=zf540@nyu.edu 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | module purge 16 | # increase limit on number of files opened in parallel to 20k --> conversion opens up to 1k temporary files in /tmp to store dataset during conversion 17 | ulimit -n 20000 18 | 19 | source ./set_path.sh 20 | 21 | # dataset: bridge_dataset, or fractal20220817_data 22 | singularity exec \ 23 | --overlay ${OVERLAY_EXT3}:ro \ 24 | /scratch/work/public/singularity/cuda11.8.86-cudnn8.7-devel-ubuntu22.04.2.sif \ 25 | /bin/bash -c "export PATH='/ext3/uv:$PATH'; \ 26 | source ./.venv/bin/activate; \ 27 | uv run python scripts/dataset/modify_rlds_dataset.py \ 28 | --dataset=fractal20220817_data \ 29 | --data_dir=$VLA_DATA_DIR/resize_224 \ 30 | --target_dir=$VLA_DATA_DIR/temp \ 31 | --mods=convert_proprio_to_euler \ 32 | --n_workers=40 \ 33 | --max_episodes_in_memory=200" 34 | -------------------------------------------------------------------------------- /slurms/dataset_scripts/resize_jpeg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=resize_jpeg 4 | #SBATCH --output=log/slurm/dataset/%x.out 5 | #SBATCH --error=log/slurm/dataset/%x.err 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | 
#SBATCH --cpus-per-task=40 9 | #SBATCH --mem=200G 10 | #SBATCH --time=03:55:00 11 | #SBATCH --account=pr_109_tandon_advanced 12 | 13 | module purge 14 | # increase limit on number of files opened in parallel to 20k --> conversion opens up to 1k temporary files in /tmp to store dataset during conversion 15 | ulimit -n 20000 16 | 17 | source ./set_path.sh 18 | 19 | singularity exec \ 20 | /scratch/work/public/singularity/cuda11.8.86-cudnn8.7-devel-ubuntu22.04.2.sif \ 21 | /bin/bash -c "source ./set_path.sh; \ 22 | export PATH='${HOME}/.local/bin/uv:$PATH'; \ 23 | uv run scripts/dataset/modify_rlds_dataset.py \ 24 | --dataset=bridge_dataset \ 25 | --data_dir=/vast/work/public/ml-datasets/x-embodiment \ 26 | --target_dir=$VLA_DATA_DIR/resize_224 \ 27 | --mods=resize_and_jpeg_encode \ 28 | --n_workers=40 \ 29 | --max_episodes_in_memory=200" -------------------------------------------------------------------------------- /slurms/dataset_scripts/rlds2lerobot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=rlds2lerobot 4 | #SBATCH --output=log/slurm/dataset/%x.out 5 | #SBATCH --error=log/slurm/dataset/%x.err 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=45 9 | #SBATCH --mem=100G 10 | #SBATCH --time=04:00:00 11 | #SBATCH --mail-type=ALL 12 | #SBATCH --mail-user=zf540@nyu.edu 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | module purge 16 | # increase limit on number of files opened in parallel to 20k --> conversion opens up to 1k temporary files in /tmp to store dataset during conversion 17 | ulimit -n 20000 18 | source ./set_path.sh 19 | # dataset: bridge_dataset, or fractal20220817_data 20 | singularity exec \ 21 | --overlay ${OVERLAY_EXT3}:ro \ 22 | /scratch/work/public/singularity/cuda11.8.86-cudnn8.7-devel-ubuntu22.04.2.sif \ 23 | /bin/bash -c "export PATH='/ext3/uv:$PATH'; \ 24 | source ./.venv/bin/activate; \ 25 | uv run python scripts/dataset/rlds2lerobot.py" 26 | -------------------------------------------------------------------------------- /slurms/dataset_scripts/test_rlds_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=test_rlds_dataset 4 | #SBATCH --output=log/slurm/dataset/%x.out 5 | #SBATCH --error=log/slurm/dataset/%x.err 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks=1 8 | #SBATCH --cpus-per-task=45 9 | #SBATCH --mem=200G 10 | #SBATCH --time=04:00:00 11 | #SBATCH --mail-type=ALL 12 | #SBATCH --mail-user=zf540@nyu.edu 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | module purge 16 | # increase limit on number of files opened in parallel to 20k --> conversion opens up to 1k temporary files in /tmp to store dataset during conversion 17 | ulimit -n 20000 18 | 19 | source ./set_path.sh 20 | 21 | # dataset: bridge_dataset, or fractal20220817_data 22 | singularity exec \ 23 | --overlay ${OVERLAY_EXT3}:ro \ 24 | /scratch/work/public/singularity/cuda11.8.86-cudnn8.7-devel-ubuntu22.04.2.sif \ 25 | /bin/bash -c "source ~/.bashrc; 26 | export PATH='/ext3/uv:$PATH'; \ 27 | source ./.venv/bin/activate; \ 28 | uv run python scripts/dataset/test_rlds_dataset.py \ 29 | --config_path config/train/pi0_finetune_taco.yaml" 30 | -------------------------------------------------------------------------------- /slurms/eval_scripts/simpler/ev_magma_bridge_simpler.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH 
--job-name=ev_magma_bridge_simpler 4 | #SBATCH --output=log/slurm/eval/simpler/%x_%j.out 5 | #SBATCH --error=log/slurm/eval/simpler/%x_%j.err 6 | #SBATCH --time=48:00:00 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --constraint="a100|h100" 11 | #SBATCH --cpus-per-task=15 12 | #SBATCH --mem=80G 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | 16 | # Trap Ctrl+C and clean up all child processes 17 | trap "echo 'Ctrl+C received, killing server...'; kill $SERVER_PID; exit 1" SIGINT 18 | 19 | CONFIG_NAMES=("magma_bridge_ev.yaml") 20 | 21 | SEEDS=(42 7 314) 22 | 23 | find_available_port() { 24 | local port 25 | for port in $(shuf -i 10000-65500 -n 200); do 26 | if ! ss -tuln | grep ":$port" > /dev/null; then 27 | echo $port 28 | return 0 29 | fi 30 | done 31 | # Fallback to a default port if no random port is found (unlikely) 32 | echo 5000 33 | return 1 34 | } 35 | 36 | # set all the paths to environment variables 37 | source ./set_path.sh 38 | 39 | for SEED in "${SEEDS[@]}"; do 40 | echo "Running with seed $SEED" 41 | 42 | for CONFIG_NAME in "${CONFIG_NAMES[@]}"; do 43 | echo "Running with config $CONFIG_NAME" 44 | # pull the pretrained_model_gradient_step_cnt list via Python+PyYAML with a custom !include handler 45 | STEP_COUNTS=( $(python3 - < 0 ? OMP_THREADS : 1)) 31 | echo "OMP_NUM_THREADS=$OMP_THREADS" 32 | 33 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 34 | find_free_port() { 35 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 36 | } 37 | export MASTER_PORT=$(find_free_port) 38 | 39 | # export NCCL_DEBUG=INFO 40 | echo "Job restart count: $SLURM_RESTART_COUNT" 41 | 42 | singularity exec --nv \ 43 | --overlay ${OVERLAY_EXT3}:ro \ 44 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 45 | /bin/bash -c "source ./set_path.sh; \ 46 | export PATH='/ext3/uv:$PATH'; \ 47 | source ./.venv/bin/activate; \ 48 | export OMP_NUM_THREADS=$OMP_THREADS; \ 49 | uv run torchrun \ 50 | --nnodes=1 \ 51 | --nproc_per_node=$NUM_GPU \ 52 | --rdzv_id=$RANDOM \ 53 | --rdzv_backend=c10d \ 54 | --max-restarts=0 \ 55 | --standalone \ 56 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 57 | src/agent/run.py \ 58 | --config_path config/train/pi0_baseline_bridge.yaml" 59 | 60 | -------------------------------------------------------------------------------- /slurms/train_scripts/pi0_baseline_fractal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pi0_baseline_fractal_euler 4 | #SBATCH --output=log/slurm/train/%x.out 5 | #SBATCH --error=log/slurm/train/%x.err 6 | #SBATCH --time=44:00:00 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:4 9 | #SBATCH --constraint="a100|h100" 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --cpus-per-task=54 12 | #SBATCH --mem=440G 13 | #SBATCH --account=pr_109_tandon_advanced 14 | #SBATCH --mail-type=ALL 15 | #SBATCH --mail-user=zf540@nyu.edu 16 | 17 | # CPU check 18 | TOTAL_CORES=$(nproc) 19 | echo "TOTAL_CORES=$TOTAL_CORES" 20 | # GPU check 21 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 22 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 23 | echo "NUM_GPU=$NUM_GPU" 24 | 25 | # Compute OMP_NUM_THREADS (avoid division by zero) 26 | OMP_THREADS=$((TOTAL_CORES / NUM_GPU)) 27 | 28 | # Ensure OMP_NUM_THREADS is at least 1 29 | OMP_THREADS=$((OMP_THREADS > 0 ? 
OMP_THREADS : 1)) 30 | echo "OMP_NUM_THREADS=$OMP_THREADS" 31 | 32 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 33 | find_free_port() { 34 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 35 | } 36 | export MASTER_PORT=$(find_free_port) 37 | 38 | 39 | singularity exec --nv \ 40 | --overlay /scratch/zf540/pi0/pi_overlay.ext3:ro \ 41 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 42 | /bin/bash -c "source ~/.bashrc; \ 43 | export PATH='/ext3/uv:$PATH'; \ 44 | source ./.venv/bin/activate; \ 45 | export OMP_NUM_THREADS=$OMP_THREADS; \ 46 | uv run torchrun \ 47 | --nnodes=1 \ 48 | --nproc_per_node=$NUM_GPU \ 49 | --rdzv_id=$RANDOM \ 50 | --rdzv_backend=c10d \ 51 | --max-restarts=0 \ 52 | --standalone \ 53 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 54 | src/agent/run.py \ 55 | --config_path config/train/pi0_baseline_fractal_euler.yaml" 56 | 57 | -------------------------------------------------------------------------------- /slurms/train_scripts/pi0_baseline_freezevlm_bridge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pi0_baseline_freezevlm_bridge 4 | #SBATCH --output=log/slurm/train/%x.out 5 | #SBATCH --error=log/slurm/train/%x.err 6 | #SBATCH --time=44:00:00 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:4 9 | #SBATCH --constraint="a100|h100" 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --cpus-per-task=54 12 | #SBATCH --mem=440G 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | # set all the paths to environment variables 16 | source ./set_path.sh 17 | 18 | # CPU check 19 | TOTAL_CORES=$(nproc) 20 | echo "TOTAL_CORES=$TOTAL_CORES" 21 | 22 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 23 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 24 | echo "NUM_GPU=$NUM_GPU" 25 | 26 | # Compute OMP_NUM_THREADS (avoid division by zero) 27 | OMP_THREADS=$((TOTAL_CORES / NUM_GPU)) 28 | 29 | # Ensure OMP_NUM_THREADS is at least 1 30 | OMP_THREADS=$((OMP_THREADS > 0 ? 
OMP_THREADS : 1)) 31 | echo "OMP_NUM_THREADS=$OMP_THREADS" 32 | 33 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 34 | find_free_port() { 35 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 36 | } 37 | export MASTER_PORT=$(find_free_port) 38 | 39 | # export NCCL_DEBUG=INFO 40 | echo "Job restart count: $SLURM_RESTART_COUNT" 41 | 42 | singularity exec --nv \ 43 | --overlay ${OVERLAY_EXT3}:ro \ 44 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 45 | /bin/bash -c "source ./set_path.sh; \ 46 | export PATH='/ext3/uv:$PATH'; \ 47 | source ./.venv/bin/activate; \ 48 | export OMP_NUM_THREADS=$OMP_THREADS; \ 49 | uv run torchrun \ 50 | --nnodes=1 \ 51 | --nproc_per_node=$NUM_GPU \ 52 | --rdzv_id=$RANDOM \ 53 | --rdzv_backend=c10d \ 54 | --max-restarts=0 \ 55 | --standalone \ 56 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 57 | src/agent/run.py \ 58 | --config_path config/train/pi0_baseline_bridge_freezevlm.yaml" 59 | 60 | -------------------------------------------------------------------------------- /slurms/train_scripts/pi0_baseline_rephrase_bridge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pi0_baseline_bridge 4 | #SBATCH --output=log/slurm/train/%x.out 5 | #SBATCH --error=log/slurm/train/%x.err 6 | #SBATCH --time=44:00:00 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:4 9 | #SBATCH --constraint="h100" 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --cpus-per-task=54 12 | #SBATCH --mem=440G 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | # set all the paths to environment variables 16 | source ./set_path.sh 17 | 18 | # CPU check 19 | TOTAL_CORES=$(nproc) 20 | echo "TOTAL_CORES=$TOTAL_CORES" 21 | 22 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 23 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 24 | echo "NUM_GPU=$NUM_GPU" 25 | 26 | # Compute OMP_NUM_THREADS (avoid division by zero) 27 | OMP_THREADS=$((TOTAL_CORES / NUM_GPU)) 28 | 29 | # Ensure OMP_NUM_THREADS is at least 1 30 | OMP_THREADS=$((OMP_THREADS > 0 ? 
OMP_THREADS : 1)) 31 | echo "OMP_NUM_THREADS=$OMP_THREADS" 32 | 33 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 34 | find_free_port() { 35 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 36 | } 37 | export MASTER_PORT=$(find_free_port) 38 | 39 | # export NCCL_DEBUG=INFO 40 | echo "Job restart count: $SLURM_RESTART_COUNT" 41 | 42 | singularity exec --nv \ 43 | --overlay ${OVERLAY_EXT3}:ro \ 44 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 45 | /bin/bash -c "source ./set_path.sh; \ 46 | export PATH='/ext3/uv:$PATH'; \ 47 | source ./.venv/bin/activate; \ 48 | export OMP_NUM_THREADS=$OMP_THREADS; \ 49 | uv run torchrun \ 50 | --nnodes=1 \ 51 | --nproc_per_node=$NUM_GPU \ 52 | --rdzv_id=$RANDOM \ 53 | --rdzv_backend=c10d \ 54 | --max-restarts=0 \ 55 | --standalone \ 56 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 57 | src/agent/run.py \ 58 | --config_path config/train/pi0_baseline_bridge_paraphrase.yaml" 59 | 60 | -------------------------------------------------------------------------------- /slurms/train_scripts/pi0_finetune_bridge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pi0_finetune_bridge 4 | #SBATCH --output=log/slurm/train/%x.out 5 | #SBATCH --error=log/slurm/train/%x.err 6 | #SBATCH --time=44:00:00 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:4 9 | #SBATCH --constraint="a100|h100" 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --cpus-per-task=54 12 | #SBATCH --mem=440G 13 | #SBATCH --account=pr_109_tandon_advanced 14 | #SBATCH --requeue 15 | 16 | # set all the paths to environment variables 17 | source ./set_path.sh 18 | 19 | # CPU check 20 | TOTAL_CORES=$(nproc) 21 | echo "TOTAL_CORES=$TOTAL_CORES" 22 | 23 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 24 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 25 | echo "NUM_GPU=$NUM_GPU" 26 | 27 | # Compute OMP_NUM_THREADS (avoid division by zero) 28 | OMP_THREADS=$((TOTAL_CORES / NUM_GPU)) 29 | 30 | # Ensure OMP_NUM_THREADS is at least 1 31 | OMP_THREADS=$((OMP_THREADS > 0 ? 
OMP_THREADS : 1)) 32 | echo "OMP_NUM_THREADS=$OMP_THREADS" 33 | 34 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 35 | find_free_port() { 36 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 37 | } 38 | export MASTER_PORT=$(find_free_port) 39 | 40 | # export NCCL_DEBUG=INFO 41 | echo "Job restart count: $SLURM_RESTART_COUNT" 42 | 43 | singularity exec --nv \ 44 | --overlay ${OVERLAY_EXT3}:ro \ 45 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 46 | /bin/bash -c "source ./set_path.sh; \ 47 | export PATH='/ext3/uv:$PATH'; \ 48 | source ./.venv/bin/activate; \ 49 | export OMP_NUM_THREADS=$OMP_THREADS; \ 50 | uv run torchrun \ 51 | --nnodes=1 \ 52 | --nproc_per_node=$NUM_GPU \ 53 | --rdzv_id=$RANDOM \ 54 | --rdzv_backend=c10d \ 55 | --max-restarts=0 \ 56 | --standalone \ 57 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 58 | src/agent/run.py \ 59 | --config_path config/train/pi0_finetune_bridge.yaml" || scontrol requeue $SLURM_JOB_ID 60 | 61 | -------------------------------------------------------------------------------- /slurms/train_scripts/pi0_finetune_bridge_paraphrase.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pi0_finetune_bridge_rephrase 4 | #SBATCH --output=log/slurm/train/%x.out 5 | #SBATCH --error=log/slurm/train/%x.err 6 | #SBATCH --time=44:00:00 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:4 9 | #SBATCH --constraint="a100|h100" 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --cpus-per-task=54 12 | #SBATCH --mem=440G 13 | #SBATCH --account=pr_109_tandon_advanced 14 | 15 | # set all the paths to environment variables 16 | source ./set_path.sh 17 | 18 | # CPU check 19 | TOTAL_CORES=$(nproc) 20 | echo "TOTAL_CORES=$TOTAL_CORES" 21 | 22 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 23 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 24 | echo "NUM_GPU=$NUM_GPU" 25 | 26 | # Compute OMP_NUM_THREADS (avoid division by zero) 27 | OMP_THREADS=$((TOTAL_CORES / NUM_GPU)) 28 | 29 | # Ensure OMP_NUM_THREADS is at least 1 30 | OMP_THREADS=$((OMP_THREADS > 0 ? 
OMP_THREADS : 1)) 31 | echo "OMP_NUM_THREADS=$OMP_THREADS" 32 | 33 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 34 | find_free_port() { 35 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 36 | } 37 | export MASTER_PORT=$(find_free_port) 38 | 39 | # export NCCL_DEBUG=INFO 40 | echo "Job restart count: $SLURM_RESTART_COUNT" 41 | 42 | singularity exec --nv \ 43 | --overlay ${OVERLAY_EXT3}:ro \ 44 | /scratch/work/public/singularity/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \ 45 | /bin/bash -c "source ./set_path.sh; \ 46 | export PATH='/ext3/uv:$PATH'; \ 47 | source ./.venv/bin/activate; \ 48 | export OMP_NUM_THREADS=$OMP_THREADS; \ 49 | uv run torchrun \ 50 | --nnodes=1 \ 51 | --nproc_per_node=$NUM_GPU \ 52 | --rdzv_id=$RANDOM \ 53 | --rdzv_backend=c10d \ 54 | --max-restarts=0 \ 55 | --standalone \ 56 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 57 | src/agent/run.py \ 58 | --config_path config/train/pi0_finetune_bridge_paraphrase.yaml" 59 | 60 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.magma.modeling_magma import MagmaPolicy 2 | from src.model.octo.modeling_octo import OctoPolicy 3 | from src.model.spatialvla.modeling_spatialvla import SpatialVLAPolicy 4 | -------------------------------------------------------------------------------- /src/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4ce/INT-ACT/66ae3ed4719c68cb2f3f22868015a7a171b04e1f/src/agent/__init__.py -------------------------------------------------------------------------------- /src/agent/dataset.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from src.data.oxe import make_oxe_dataset_kwargs_and_weights 4 | from src.data.rlds_dataset import make_interleaved_dataset 5 | from src.data.rlds_dataset_torch import TorchRLDSDataset 6 | from src.utils.monitor import log_execution_time 7 | 8 | tf.config.set_visible_devices([], "GPU") 9 | 10 | 11 | class TorchRLDSInterleavedDataset: 12 | @log_execution_time() 13 | def __init__(self, config, train=True, task_paraphrase=False, shuffle=None): 14 | dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights( 15 | config.dataset_mix, 16 | config.data_path, 17 | load_proprio=config.load_proprio, 18 | load_camera_views=("primary",), 19 | ) 20 | # # ! 
debug purpose: fix the data loading to make it deterministic, see: https://github.com/kvablack/dlimp/pull/4 21 | # for dataset_kwargs in dataset_kwargs_list: 22 | #     dataset_kwargs["deterministic"] = config.deterministic_dataset 23 | if shuffle is None: 24 | shuffle = train 25 | traj_transform_kwargs = dict( 26 | # goal_relabeling_strategy="uniform",  # no need for goal relabeling 27 | window_size=config.window_size, 28 | action_horizon=config.action_horizon, 29 | subsample_length=100, 30 | 31 | max_action_future=config.max_action_future, 32 | skip_unlabeled=config.skip_unlabeled,  # skip ones without language annotation 33 | ) 34 | if task_paraphrase: 35 | traj_transform_kwargs["task_augment_strategy"] = "rephrase_instruction" 36 | traj_transform_kwargs["task_augment_kwargs"] = dict( 37 | paraphrases_repo="rail-berkeley/OXE_paraphrases", 38 | paraphrases_filename="paraphrases_oxe.pkl", 39 | rephrase_prob=0.5, 40 | ) 41 | dataset = make_interleaved_dataset( 42 | dataset_kwargs_list, 43 | sample_weights, 44 | train=train, 45 | split=config.split, 46 | shuffle=shuffle, 47 | shuffle_buffer_size=config.shuffle_buffer_size, 48 | batch_size=None,  # batching will be handled in the PyTorch DataLoader 49 | balance_weights=True, 50 | traj_transform_kwargs=traj_transform_kwargs, 51 | frame_transform_kwargs=dict( 52 | image_augment_kwargs={ 53 | "primary": dict( 54 | random_resized_crop=dict( 55 | scale=[0.8, 1.0], 56 | ratio=[0.9, 1.1], 57 | ), 58 | random_brightness=[0.1], 59 | random_contrast=[0.9, 1.1], 60 | random_saturation=[0.9, 1.1], 61 | random_hue=[0.05], 62 | augment_order=[ 63 | "random_resized_crop", 64 | "random_brightness", 65 | "random_contrast", 66 | "random_saturation", 67 | "random_hue", 68 | ], 69 | ), 70 | "wrist": dict( 71 | random_brightness=[0.1], 72 | random_contrast=[0.9, 1.1], 73 | random_saturation=[0.9, 1.1], 74 | random_hue=[0.05], 75 | augment_order=[ 76 | "random_brightness", 77 | "random_contrast", 78 | "random_saturation", 79 | "random_hue", 80 | ], 81 | ), 82 | }, 83 | resize_size=dict( 84 | primary=(224, 224), 85 | wrist=(224, 224), 86 | ), 87 | num_parallel_calls=config.num_parallel_calls, 88 | ), 89 | traj_transform_threads=config.traj_transform_threads, 90 | traj_read_threads=config.traj_read_threads, 91 | ) 92 | 93 | # convert for torch 94 | self.dataset = TorchRLDSDataset(dataset, train=train) 95 | -------------------------------------------------------------------------------- /src/agent/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | The entry script for training or evaluation. 3 | Has a model factory feature to select the model to train or evaluate.
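Example invocations (config paths from this repo's config/ tree; the slurm train scripts wrap the first form in torchrun, and the eval scripts pair a server with a client):
    uv run src/agent/run.py --config_path config/train/pi0_baseline_bridge.yaml
    uv run src/agent/run.py --config_path config/experiment/simpler/pi0_finetune_bridge_ev.yaml --eval_cfg.role server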
4 | 5 | """ 6 | import os 7 | import sys 8 | 9 | import draccus 10 | from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy 11 | from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy 12 | 13 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) 14 | from policy_server_client.websocket_policy_server import WebsocketPolicyServer 15 | 16 | from src.agent.configuration_pipeline import TrainPipelineConfig 17 | from src.utils.pipeline import get_class_from_path 18 | 19 | 20 | @draccus.wrap() 21 | def main(pipeline_cfg: TrainPipelineConfig): 22 | model_type = pipeline_cfg.model_cfg.type 23 | 24 | model_map = { 25 | "pi0": PI0Policy, 26 | } 27 | 28 | if pipeline_cfg.eval_cfg is None: 29 | # only training 30 | from src.agent.trainer import PI0Trainer 31 | trainer_map = { 32 | "pi0": PI0Trainer, 33 | } 34 | 35 | model_class = model_map.get(model_type, None) 36 | 37 | trainer_class = trainer_map.get(model_type, None) 38 | if trainer_class is None: 39 | raise ValueError(f"Model type {model_type} not supported for training.") 40 | 41 | trainer = trainer_class(train_cfg=pipeline_cfg, model_class=model_class) 42 | trainer.train() 43 | else: 44 | # evaluation 45 | if pipeline_cfg.eval_cfg.role == "server": 46 | 47 | model_class = model_map.get(model_type, None) 48 | 49 | from src.experiments.policies.policy_wrapper import LeRobotPolicyWrapper, MagmaPolicyWrapper, OctoPolicyWrapper, SpatialVLAPolicyWrapper 50 | policy_wrapper_map = { 51 | "pi0": LeRobotPolicyWrapper, 52 | "spatial-vla": SpatialVLAPolicyWrapper, 53 | "magma": MagmaPolicyWrapper, 54 | "octo": OctoPolicyWrapper, 55 | } 56 | policy_wrapper_class = policy_wrapper_map.get(model_type, None) 57 | policy = policy_wrapper_class(pipeline_cfg=pipeline_cfg, model_class=model_class) 58 | 59 | websocket_server = WebsocketPolicyServer( 60 | policy=policy, 61 | host=policy.host, 62 | port=policy.port, 63 | ) 64 | websocket_server.serve_forever() 65 | 66 | elif pipeline_cfg.eval_cfg.role == "client": 67 | evaluator_class = get_class_from_path(pipeline_cfg.eval_cfg.simulator_path) 68 | evaluator = evaluator_class(pipeline_cfg=pipeline_cfg) 69 | evaluator.evaluate() 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /src/data/dlimp/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import transforms 2 | from .dataset import DLataset 3 | from .utils import parallel_vmap, vmap 4 | -------------------------------------------------------------------------------- /src/data/dlimp/augmentations.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def random_resized_crop(image, scale, ratio, seed): 7 | assert image.shape.ndims == 3 or image.shape.ndims == 4 8 | if image.shape.ndims == 3: 9 | image = tf.expand_dims(image, axis=0) 10 | batch_size = tf.shape(image)[0] 11 | # taken from https://keras.io/examples/vision/nnclr/#random-resized-crops 12 | log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1])) 13 | height = tf.shape(image)[1] 14 | width = tf.shape(image)[2] 15 | 16 | random_scales = tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1]) 17 | random_ratios = tf.exp( 18 | tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1]) 19 | ) 20 | 21 | new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1) 22 | new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1) 23 | height_offsets = tf.random.stateless_uniform( 24 | (batch_size,), seed, 0, 1 - new_heights 25 | ) 26 | width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths) 27 | 28 | bounding_boxes = tf.stack( 29 | [ 30 | height_offsets, 31 | width_offsets, 32 | height_offsets + new_heights, 33 | width_offsets + new_widths, 34 | ], 35 | axis=1, 36 | ) 37 | 38 | image = tf.image.crop_and_resize( 39 | image, bounding_boxes, tf.range(batch_size), (height, width) 40 | ) 41 | 42 | if image.shape[0] == 1: 43 | return image[0] 44 | else: 45 | return image 46 | 47 | 48 | def random_rot90(image, seed): 49 | k = tf.random.stateless_uniform((), seed, 0, 4, dtype=tf.int32) 50 | return tf.image.rot90(image, k=k) 51 | 52 | 53 | AUGMENT_OPS = { 54 | "random_resized_crop": random_resized_crop, 55 | "random_brightness": tf.image.stateless_random_brightness, 56 | "random_contrast": tf.image.stateless_random_contrast, 57 | "random_saturation": tf.image.stateless_random_saturation, 58 | "random_hue": tf.image.stateless_random_hue, 59 | "random_flip_left_right": tf.image.stateless_random_flip_left_right, 60 | "random_flip_up_down": tf.image.stateless_random_flip_up_down, 61 | "random_rot90": random_rot90, 62 | } 63 | 64 | 65 | def augment_image( 66 | image: tf.Tensor, 67 | seed: Optional[tf.Tensor] = None, 68 | **augment_kwargs, 69 | ) -> tf.Tensor: 70 | """Unified image augmentation function for TensorFlow. 71 | 72 | This function is primarily configured through `augment_kwargs`. There must be one kwarg called "augment_order", 73 | which is a list of strings specifying the augmentation operations to apply and the order in which to apply them. See 74 | the `AUGMENT_OPS` dictionary above for a list of available operations. 75 | 76 | For each entry in "augment_order", there may be a corresponding kwarg with the same name. The value of this kwarg 77 | can be a dictionary of kwargs or a sequence of positional args to pass to the corresponding augmentation operation. 78 | This additional kwarg is required for all operations that take additional arguments other than the image and random 79 | seed. For example, the "random_resized_crop" operation requires a "scale" and "ratio" argument that can be specified 80 | either positionally or by name. 
"random_flip_left_right", on the other hand, does not take any additional arguments 81 | and so does not require an additional kwarg to configure it. 82 | 83 | Here is an example config: 84 | 85 | ``` 86 | augment_kwargs = { 87 | "augment_order": ["random_resized_crop", "random_brightness", "random_contrast", "random_flip_left_right"], 88 | "random_resized_crop": { 89 | "scale": [0.8, 1.0], 90 | "ratio": [3/4, 4/3], 91 | }, 92 | "random_brightness": [0.1], 93 | "random_contrast": [0.9, 1.1], 94 | ``` 95 | 96 | Args: 97 | image: A `Tensor` of shape [height, width, channels] with the image. May be uint8 or float32 with values in [0, 255]. 98 | seed (optional): A `Tensor` of shape [2] with the seed for the random number generator. 99 | **augment_kwargs: Keyword arguments for the augmentation operations. The order of operations is determined by 100 | the "augment_order" keyword argument. Other keyword arguments are passed to the corresponding augmentation 101 | operation. See above for a list of operations. 102 | """ 103 | assert isinstance(augment_kwargs, dict) 104 | 105 | if "augment_order" not in augment_kwargs: 106 | raise ValueError("augment_kwargs must contain an 'augment_order' key.") 107 | 108 | # convert to float at the beginning to avoid each op converting back and 109 | # forth between uint8 and float32 internally 110 | orig_dtype = image.dtype 111 | image = tf.image.convert_image_dtype(image, tf.float32) 112 | 113 | if seed is None: 114 | seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) 115 | 116 | for op in augment_kwargs["augment_order"]: 117 | seed = tf.random.stateless_uniform( 118 | [2], seed, maxval=tf.dtypes.int32.max, dtype=tf.int32 119 | ) 120 | if op in augment_kwargs: 121 | if hasattr(augment_kwargs[op], "items"): 122 | image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op]) 123 | else: 124 | image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op]) 125 | else: 126 | image = AUGMENT_OPS[op](image, seed=seed) 127 | # float images are expected to be in [0, 1] 128 | image = tf.clip_by_value(image, 0, 1) 129 | 130 | # convert back to original dtype and scale 131 | image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True) 132 | 133 | return image 134 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from . import goal_relabeling 2 | from .common import * 3 | from .frame_transforms import * 4 | from .traj_transforms import * 5 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/common.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | from typing import Any, Callable, Dict, Union 3 | 4 | 5 | def selective_tree_map( 6 | x: Dict[str, Any], 7 | match: Union[str, Callable[[str, Any], bool]], 8 | map_fn: Callable, 9 | *, 10 | _keypath: str = "", 11 | ) -> Dict[str, Any]: 12 | """Maps a function over a nested dictionary, only applying it leaves that match a criterion. 13 | 14 | If `match` is a string, it follows glob-style syntax. For example, "bar" will only match 15 | a top-level key called "bar", "*bar" will match any leaf whose key ends with "bar", 16 | and "*bar*" will match any subtree with a key that contains "bar". 17 | 18 | Key paths are separated by "/". For example, "foo/bar" will match a leaf with key "bar" that 19 | is nested under a key "foo". 
20 | 21 | Args: 22 | x (Dict[str, Any]): The (possibly nested) dictionary to map over. 23 | match (str or Callable[[str, Any], bool]): If a string, `map_fn` will 24 | only be applied to leaves whose key path matches `match` using glob-style syntax. If a 25 | function, `map_fn` will only be applied to leaves for which `match(key_path, value)` 26 | returns True. 27 | map_fn (Callable): The function to apply. 28 | """ 29 | if not callable(match): 30 | 31 | def match_fn(keypath, value): 32 | return fnmatch.fnmatch(keypath, match) 33 | else: 34 | match_fn = match 35 | 36 | out = {} 37 | for key in x: 38 | if isinstance(x[key], dict): 39 | out[key] = selective_tree_map( 40 | x[key], match_fn, map_fn, _keypath=_keypath + key + "/" 41 | ) 42 | elif match_fn(_keypath + key, x[key]): 43 | out[key] = map_fn(x[key]) 44 | else: 45 | out[key] = x[key] 46 | return out 47 | 48 | 49 | def flatten_dict(d: Dict[str, Any], sep="/") -> Dict[str, Any]: 50 | """Given a nested dictionary, flatten it by concatenating keys with sep.""" 51 | flattened = {} 52 | for k, v in d.items(): 53 | if isinstance(v, dict): 54 | for k2, v2 in flatten_dict(v, sep=sep).items(): 55 | flattened[k + sep + k2] = v2 56 | else: 57 | flattened[k] = v 58 | return flattened 59 | 60 | 61 | def unflatten_dict(d: Dict[str, Any], sep="/") -> Dict[str, Any]: 62 | """Given a flattened dictionary, unflatten it by splitting keys by sep.""" 63 | unflattened = {} 64 | for k, v in d.items(): 65 | keys = k.split(sep) 66 | if len(keys) == 1: 67 | unflattened[k] = v 68 | else: 69 | if keys[0] not in unflattened: 70 | unflattened[keys[0]] = {} 71 | unflattened[keys[0]][sep.join(keys[1:])] = v 72 | return unflattened 73 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/frame_transforms.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Dict, Sequence, Tuple, Union 3 | 4 | import tensorflow as tf 5 | 6 | from src.data.dlimp.augmentations import augment_image 7 | from src.data.dlimp.utils import resize_depth_image, resize_image 8 | 9 | from .common import selective_tree_map 10 | 11 | 12 | def decode_images( 13 | x: Dict[str, Any], match: Union[str, Sequence[str]] = "image" 14 | ) -> Dict[str, Any]: 15 | """Can operate on nested dicts. Decodes any leaves that have `match` anywhere in their path.""" 16 | if isinstance(match, str): 17 | match = [match] 18 | 19 | return selective_tree_map( 20 | x, 21 | lambda keypath, value: any([s in keypath for s in match]) 22 | and value.dtype == tf.string, 23 | partial(tf.io.decode_image, expand_animations=False), 24 | ) 25 | 26 | 27 | def resize_images( 28 | x: Dict[str, Any], 29 | match: Union[str, Sequence[str]] = "image", 30 | size: Tuple[int, int] = (128, 128), 31 | ) -> Dict[str, Any]: 32 | """Can operate on nested dicts. Resizes any leaves that have `match` anywhere in their path. Takes uint8 images 33 | as input and returns uint8 images (still in [0, 255]). 34 | """ 35 | if isinstance(match, str): 36 | match = [match] 37 | 38 | return selective_tree_map( 39 | x, 40 | lambda keypath, value: any([s in keypath for s in match]) 41 | and value.dtype == tf.uint8, 42 | partial(resize_image, size=size), 43 | ) 44 | 45 | 46 | def resize_depth_images( 47 | x: Dict[str, Any], 48 | match: Union[str, Sequence[str]] = "depth", 49 | size: Tuple[int, int] = (128, 128), 50 | ) -> Dict[str, Any]: 51 | """Can operate on nested dicts. 
Resizes any leaves that have `match` anywhere in their path. Takes float32 images 52 | as input and returns float images (in arbitrary range). 53 | """ 54 | if isinstance(match, str): 55 | match = [match] 56 | 57 | return selective_tree_map( 58 | x, 59 | lambda keypath, value: any([s in keypath for s in match]) 60 | and value.dtype == tf.float32, 61 | partial(resize_depth_image, size=size), 62 | ) 63 | 64 | 65 | def augment( 66 | x: Dict[str, Any], 67 | match: Union[str, Callable[[str, Any], bool]] = "*image", 68 | traj_identical: bool = True, 69 | keys_identical: bool = True, 70 | augment_kwargs: dict = {}, 71 | ) -> Dict[str, Any]: 72 | """ 73 | Augments the input dictionary `x` by applying image augmentation to all values whose keypath contains `match`. 74 | 75 | Args: 76 | x (Dict[str, Any]): The input dictionary to augment. 77 | match (str or Callable[[str, Any], bool]): See documentation for `selective_tree_map`. 78 | Defaults to "*image", which matches all leaves whose key ends in "image". 79 | traj_identical (bool, optional): Whether to use the same random seed for all images in a trajectory. 80 | keys_identical (bool, optional): Whether to use the same random seed for all keys that are augmented. 81 | augment_kwargs (dict, optional): Additional keyword arguments to pass to the `augment_image` function. 82 | """ 83 | toplevel_seed = tf.random.uniform([2], 0, 2**31 - 1, dtype=tf.int32) 84 | 85 | def map_fn(value): 86 | if keys_identical and traj_identical: 87 | seed = [x["_traj_index"], x["_traj_index"]] 88 | elif keys_identical and not traj_identical: 89 | seed = toplevel_seed 90 | elif not keys_identical and traj_identical: 91 | raise NotImplementedError() 92 | else: 93 | seed = None 94 | 95 | return augment_image(value, seed=seed, **augment_kwargs) 96 | 97 | return selective_tree_map( 98 | x, 99 | match, 100 | map_fn, 101 | ) 102 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains goal relabeling and reward logic written in TensorFlow. 3 | 4 | Each relabeling function takes a trajectory with keys `obs` and `next_obs`. It returns a new trajectory with the added 5 | keys `goals` and `rewards`. Keep in mind that `obs` and `next_obs` may themselves be dictionaries, and `goals` must 6 | match their structure. 7 | """ 8 | 9 | from typing import Any, Dict 10 | 11 | import tensorflow as tf 12 | 13 | 14 | def uniform(traj: Dict[str, Any], reached_proportion: float): 15 | """Relabels with a true uniform distribution over future states. With probability reached_proportion, 16 | obs[i] gets a goal equal to next_obs[i]. In this case, the reward is 0. Otherwise, 17 | obs[i] gets a goal sampled uniformly from the set next_obs[i + 1:], and the reward is -1. 
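For instance, with traj_len = 4 the last transition always keeps its own next_obs as the
goal (reward 0), while obs[0] either does the same (with probability reached_proportion)
or receives a goal drawn uniformly from next_obs[1:] with reward -1.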
18 | """ 19 | traj_len = tf.shape(tf.nest.flatten(traj)[0])[0] 20 | 21 | # select a random future index for each transition i in the range [i + 1, traj_len) 22 | rand = tf.random.uniform([traj_len]) 23 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 24 | high = tf.cast(traj_len, tf.float32) 25 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 26 | 27 | # TODO(kvablack): don't know how I got an out-of-bounds during training, 28 | # could not reproduce, trying to patch it for now 29 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 30 | 31 | # select a random proportion of transitions to relabel with the next obs 32 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 33 | 34 | # the last transition must be goal-reaching 35 | goal_reached_mask = tf.logical_or( 36 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 37 | ) 38 | 39 | # make goal-reaching transitions have an offset of 0 40 | goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs) 41 | 42 | # select goals 43 | traj["goals"] = tf.nest.map_structure( 44 | lambda x: tf.gather(x, goal_idxs), 45 | traj["next_obs"], 46 | ) 47 | 48 | # reward is 0 for goal-reaching transitions, -1 otherwise 49 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 50 | 51 | return traj 52 | 53 | 54 | def last_state_upweighted(traj: Dict[str, Any], reached_proportion: float): 55 | """ 56 | A weird relabeling scheme where the last state gets upweighted. For each transition i, a uniform random number is 57 | generated in the range [i + 1, i + traj_len). It then gets clipped to be less than traj_len. Therefore, the first 58 | transition (i = 0) gets a goal sampled uniformly from the future, but for i > 0 the last state gets more and more 59 | upweighted. 60 | """ 61 | traj_len = tf.shape(tf.nest.flatten(traj)[0])[0] 62 | 63 | # select a random future index for each transition 64 | offsets = tf.random.uniform( 65 | [traj_len], 66 | minval=1, 67 | maxval=traj_len, 68 | dtype=tf.int32, 69 | ) 70 | 71 | # select random transitions to relabel as goal-reaching 72 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 73 | # last transition is always goal-reaching 74 | goal_reached_mask = tf.logical_or( 75 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 76 | ) 77 | 78 | # the goal will come from the current transition if the goal was reached 79 | offsets = tf.where(goal_reached_mask, 0, offsets) 80 | 81 | # convert from relative to absolute indices 82 | indices = tf.range(traj_len) + offsets 83 | 84 | # clamp out of bounds indices to the last transition 85 | indices = tf.minimum(indices, traj_len - 1) 86 | 87 | # select goals 88 | traj["goals"] = tf.nest.map_structure( 89 | lambda x: tf.gather(x, indices), 90 | traj["next_obs"], 91 | ) 92 | 93 | # reward is 0 for goal-reaching transitions, -1 otherwise 94 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 95 | 96 | return traj 97 | 98 | 99 | def geometric(traj: Dict[str, Any], reached_proportion: float, discount: float): 100 | """ 101 | Relabels with a geometric distribution over future states. With probability reached_proportion, obs[i] gets 102 | a goal equal to next_obs[i]. In this case, the reward is 0. Otherwise, obs[i] gets a goal sampled 103 | geometrically from the set next_obs[i + 1:], and the reward is -1. 
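Concretely, a future index j > i is drawn with probability proportional to
discount ** (j - i), so a discount close to 1 spreads goals far into the future while a
small discount keeps them near the current step.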
104 | """ 105 | traj_len = tf.shape(tf.nest.flatten(traj)[0])[0] 106 | 107 | # geometrically select a future index for each transition i in the range [i + 1, traj_len) 108 | arange = tf.range(traj_len) 109 | is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32) 110 | d = discount ** tf.cast(arange[None] - arange[:, None], tf.float32) 111 | 112 | probs = is_future_mask * d 113 | # The indexing changes the shape from [seq_len, 1] to [seq_len] 114 | goal_idxs = tf.random.categorical( 115 | logits=tf.math.log(probs), num_samples=1, dtype=tf.int32 116 | )[:, 0] 117 | 118 | # select a random proportion of transitions to relabel with the next obs 119 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 120 | 121 | # the last transition must be goal-reaching 122 | goal_reached_mask = tf.logical_or( 123 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 124 | ) 125 | 126 | # make goal-reaching transitions have an offset of 0 127 | goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs) 128 | 129 | # select goals 130 | traj["goals"] = tf.nest.map_structure( 131 | lambda x: tf.gather(x, goal_idxs), 132 | traj["next_obs"], 133 | ) 134 | 135 | # reward is 0 for goal-reaching transitions, -1 otherwise 136 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 137 | 138 | return traj 139 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/traj_transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def add_next_obs(traj: Dict[str, Any], pad: bool = True) -> Dict[str, Any]: 7 | """ 8 | Given a trajectory with a key "observations", add the key "next_observations". If pad is False, discards the last 9 | value of all other keys. Otherwise, the last transition will have "observations" == "next_observations". 10 | """ 11 | if not pad: 12 | traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) 13 | traj_truncated["next_observations"] = tf.nest.map_structure( 14 | lambda x: x[1:], traj["observations"] 15 | ) 16 | return traj_truncated 17 | else: 18 | traj["next_observations"] = tf.nest.map_structure( 19 | lambda x: tf.concat((x[1:], x[-1:]), axis=0), traj["observations"] 20 | ) 21 | return traj 22 | -------------------------------------------------------------------------------- /src/data/dlimp/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def tensor_feature(value): 7 | return tf.train.Feature( 8 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 9 | ) 10 | 11 | 12 | def resize_image(image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor: 13 | """Resizes an image using Lanczos3 interpolation. Expects & returns uint8.""" 14 | assert image.dtype == tf.uint8 15 | image = tf.image.resize(image, size, method="lanczos3", antialias=True) 16 | image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8) 17 | return image 18 | 19 | 20 | def resize_depth_image(depth_image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor: 21 | """Resizes a depth image using bilinear interpolation. 
Expects & returns float32 in arbitrary range.""" 22 | assert depth_image.dtype == tf.float32 23 | if len(depth_image.shape) < 3: 24 | depth_image = tf.image.resize( 25 | depth_image[..., None], size, method="bilinear", antialias=True 26 | )[..., 0] 27 | else: 28 | depth_image = tf.image.resize( 29 | depth_image, size, method="bilinear", antialias=True 30 | ) 31 | return depth_image 32 | 33 | 34 | def read_resize_encode_image(path: str, size: Tuple[int, int]) -> tf.Tensor: 35 | """Reads, decodes, resizes, and then re-encodes an image.""" 36 | data = tf.io.read_file(path) 37 | image = tf.image.decode_jpeg(data) 38 | image = resize_image(image, size) 39 | image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8) 40 | return tf.io.encode_jpeg(image, quality=95) 41 | 42 | 43 | def vmap(fn: Callable) -> Callable: 44 | """ 45 | Vmap a function over the first dimension of a tensor (or nested structure of tensors). This 46 | version does NOT parallelize the function; however, it fuses the function calls in a way that 47 | appears to be more performant than tf.map_fn or tf.vectorized_map (when falling back to 48 | while_loop) for certain situations. 49 | 50 | Requires the first dimension of the input to be statically known. 51 | """ 52 | 53 | def wrapped(structure): 54 | return tf.nest.map_structure( 55 | lambda *x: tf.stack(x), 56 | *[ 57 | fn(tf.nest.pack_sequence_as(structure, x)) 58 | for x in zip(*map(tf.unstack, tf.nest.flatten(structure))) 59 | ], 60 | ) 61 | 62 | return wrapped 63 | 64 | 65 | def parallel_vmap(fn: Callable, num_parallel_calls=tf.data.AUTOTUNE) -> Callable: 66 | """ 67 | Vmap a function over the first dimension of a tensor (or nested structure of tensors). This 68 | version attempts to parallelize the function using the tf.data API. I found this to be more 69 | performant than tf.map_fn or tf.vectorized_map (when falling back to while_loop), but the batch 70 | call appears to add significant overhead that may make it slower for some situations. 71 | """ 72 | 73 | def wrapped(structure): 74 | return ( 75 | tf.data.Dataset.from_tensor_slices(structure) 76 | .map(fn, deterministic=True, num_parallel_calls=num_parallel_calls) 77 | .batch( 78 | tf.cast(tf.shape(tf.nest.flatten(structure)[0])[0], tf.int64), 79 | ) 80 | .get_single_element() 81 | ) 82 | 83 | return wrapped 84 | -------------------------------------------------------------------------------- /src/data/rlds_dataset_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from Allen's open-pi-zero: https://github.com/allenzren/open-pi-zero 3 | From: https://github.com/octo-models/octo/blob/main/examples/06_pytorch_oxe_dataloader.py 4 | 5 | This example shows how to use the `src.data` dataloader with PyTorch by wrapping it in a simple PyTorch dataloader. The config below also happens to be our exact pretraining config (except for the batch size and shuffle buffer size, which are reduced for demonstration purposes). 
6 | """
7 | 
8 | import tensorflow as tf
9 | import torch
10 | import numpy as np
11 | tf.config.set_visible_devices([], "GPU")
12 | 
13 | 
14 | class TorchRLDSDataset(torch.utils.data.IterableDataset):
15 |     """Thin wrapper around RLDS dataset for use with PyTorch dataloaders."""
16 | 
17 |     def __init__(
18 |         self,
19 |         rlds_dataset,
20 |         train=True,
21 |     ):
22 |         self._rlds_dataset = rlds_dataset
23 |         self._is_train = train
24 |         # computed based on sampling weights
25 |         self.split_transition_length, self.transition_lengths = self.__get_sampling_length()
26 | 
27 |     def __iter__(self):
28 |         for sample in self._rlds_dataset.as_numpy_iterator():
29 |             yield sample
30 | 
31 |     def __len__(self):
32 |         # TODO(allenzren): account for sample weights?
33 |         return self._rlds_dataset.true_total_length
34 |         # lengths = np.array(
35 |         #     [
36 |         #         stats["num_transitions"]
37 |         #         for stats in self._rlds_dataset.dataset_statistics.values()
38 |         #     ],
39 |         #     dtype=float,
40 |         # )
41 |         # if hasattr(self._rlds_dataset, "sample_weights"):
42 |         #     lengths *= self._rlds_dataset.sample_weights
43 |         # total_len = lengths.sum()
44 |         # if self._is_train:
45 |         #     return int(0.95 * total_len)
46 |         # else:
47 |         #     return int(0.05 * total_len)
48 | 
49 |     def __get_sampling_length(self):
50 |         # returns the train/val split length and the per-dataset transition counts, scaled by sampling weights
51 |         lengths = np.array(
52 |             [
53 |                 stats["num_transitions"]
54 |                 for stats in self._rlds_dataset.dataset_statistics.values()
55 |             ],
56 |             dtype=float,
57 |         )
58 | 
59 |         if hasattr(self._rlds_dataset, "sample_weights"):
60 |             lengths *= self._rlds_dataset.sample_weights
61 |         total_len = lengths.sum()
62 |         if self._is_train:
63 |             split_len = int(0.95 * total_len)
64 |         else:
65 |             split_len = int(0.05 * total_len)
66 |         return split_len, lengths
67 | 
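A minimal consumption sketch (here `rlds_dataset` stands in for the interleaved tf.data pipeline built elsewhere in `src.data`; the DataLoader arguments are illustrative, not prescriptive):

    import torch
    from src.data.rlds_dataset_torch import TorchRLDSDataset

    dataset = TorchRLDSDataset(rlds_dataset, train=True)  # rlds_dataset: assumed pre-built RLDS pipeline
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,  # batching already happens inside the tf.data pipeline
        num_workers=0,    # >0 would hand every worker a full copy of the iterable dataset
    )
    for batch in loader:
        ...  # dict of numpy arrays, ready to be wrapped in torch tensors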
-------------------------------------------------------------------------------- /src/data/utils/goal_relabeling.py: --------------------------------------------------------------------------------
1 | """
2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required.
3 | Each function should add entries to the "task" dict.
4 | """
5 | 
6 | from typing import Optional
7 | 
8 | import tensorflow as tf
9 | 
10 | from src.data.utils.data_utils import tree_merge
11 | 
12 | 
13 | def uniform(traj: dict, max_goal_distance: Optional[int] = None) -> dict:
14 |     """
15 |     Relabels with a true uniform distribution over future states.
16 |     Optionally caps goal distance.
17 |     """
18 |     traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0]
19 | 
20 |     # select a random future index for each transition i in the range [i, traj_len)
21 |     rand = tf.random.uniform([traj_len])
22 |     low = tf.cast(tf.range(traj_len), tf.float32)
23 |     if max_goal_distance is not None:
24 |         high = tf.cast(
25 |             tf.minimum(tf.range(traj_len) + max_goal_distance, traj_len), tf.float32
26 |         )
27 |     else:
28 |         high = tf.cast(traj_len, tf.float32)
29 |     goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
30 | 
31 |     # sometimes there are floating-point errors that cause an out-of-bounds
32 |     goal_idxs = tf.minimum(goal_idxs, traj_len - 1)
33 | 
34 |     # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from
35 |     # "observation" and "task" properly)
36 |     goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"])
37 |     traj["task"] = tree_merge(traj["task"], goal)
38 | 
39 |     return traj
40 | 
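A standalone sketch of just the index-sampling step above, with toy sizes and no repo imports:

    import tensorflow as tf

    traj_len, max_goal_distance = 5, 2
    rand = tf.random.uniform([traj_len])
    low = tf.cast(tf.range(traj_len), tf.float32)
    high = tf.cast(tf.minimum(tf.range(traj_len) + max_goal_distance, traj_len), tf.float32)
    goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
    goal_idxs = tf.minimum(goal_idxs, traj_len - 1)
    # each goal_idxs[i] lies in [i, min(i + 2, 5)); the last transition can only pick itself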
-------------------------------------------------------------------------------- /src/data/utils/text_processing.py: --------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, Sequence
3 | 
4 | import numpy as np
5 | import tensorflow as tf
6 | 
7 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
8 | 
9 | 
10 | class TextProcessor(ABC):
11 |     """
12 |     Base class for text tokenization or text embedding.
13 |     """
14 | 
15 |     @abstractmethod
16 |     def encode(self, strings: Sequence[str]):
17 |         raise NotImplementedError
18 | 
19 | 
20 | class HFTokenizer(TextProcessor):
21 |     def __init__(
22 |         self,
23 |         tokenizer_name: str,
24 |         tokenizer_kwargs: Optional[dict] = {
25 |             "max_length": 64,
26 |             "padding": "max_length",
27 |             "truncation": True,
28 |             "return_tensors": "np",
29 |         },
30 |         encode_with_model: bool = False,
31 |     ):
32 |         from transformers import AutoTokenizer, FlaxAutoModel  # lazy import
33 | 
34 |         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
35 |         self.tokenizer_kwargs = tokenizer_kwargs
36 |         self.encode_with_model = encode_with_model
37 |         if self.encode_with_model:
38 |             self.model = FlaxAutoModel.from_pretrained(tokenizer_name)
39 | 
40 |     def encode(self, strings: Sequence[str]):
41 |         # this creates another nested layer with "input_ids", "attention_mask", etc.
42 |         inputs = self.tokenizer(
43 |             strings,
44 |             **self.tokenizer_kwargs,
45 |         )
46 |         if self.encode_with_model:
47 |             return np.array(self.model(**inputs).last_hidden_state)
48 |         else:
49 |             return dict(inputs)
50 | 
51 | 
52 | class MuseEmbedding(TextProcessor):
53 |     def __init__(self):
54 |         import tensorflow_hub as hub  # lazy import
55 |         import tensorflow_text  # noqa: F401
56 | 
57 |         self.muse_model = hub.load(MULTI_MODULE)
58 | 
59 |     def encode(self, strings: Sequence[str]):
60 |         with tf.device("/cpu:0"):
61 |             return self.muse_model(strings).numpy()
62 | 
63 | 
64 | class CLIPTextProcessor(TextProcessor):
65 |     def __init__(
66 |         self,
67 |         tokenizer_kwargs: Optional[dict] = {
68 |             "max_length": 64,
69 |             "padding": "max_length",
70 |             "truncation": True,
71 |             "return_tensors": "np",
72 |         },
73 |     ):
74 |         from transformers import CLIPProcessor
75 | 
76 |         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
77 |         self.kwargs = tokenizer_kwargs
78 | 
79 |     def encode(self, strings: Sequence[str]):
80 |         inputs = self.processor(
81 |             text=strings,
82 |             **self.kwargs,
83 |         )
84 |         inputs["position_ids"] = np.expand_dims(
85 |             np.arange(inputs["input_ids"].shape[1]), axis=0
86 |         ).repeat(inputs["input_ids"].shape[0], axis=0)
87 |         return inputs
88 | 
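A usage sketch for the tokenizer path (the checkpoint id is illustrative; any Hugging Face tokenizer works, and the default kwargs pad/truncate to 64 tokens):

    from src.data.utils.text_processing import HFTokenizer

    tokenizer = HFTokenizer("t5-base")
    batch = tokenizer.encode(["put the spoon on the towel", "close the drawer"])
    # batch is a plain dict with "input_ids" and "attention_mask", each a (2, 64) numpy array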
-------------------------------------------------------------------------------- /src/experiments/env_adapters/base.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class BaseEnvAdapter:
5 |     def __init__(self):
6 |         pass
7 | 
8 |     def normalize_bound(
9 |         self,
10 |         data: np.ndarray,
11 |         data_min: np.ndarray,
12 |         data_max: np.ndarray,
13 |         clip_min: float = -1,
14 |         clip_max: float = 1,
15 |         eps: float = 1e-8,
16 |     ) -> np.ndarray:
17 |         ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
18 |         return np.clip(ndata, clip_min, clip_max)
19 | 
20 |     def denormalize_bound(
21 |         self,
22 |         data: np.ndarray,
23 |         data_min: np.ndarray,
24 |         data_max: np.ndarray,
25 |         clip_min: float = -1,
26 |         clip_max: float = 1,
27 |         eps=1e-8,
28 |     ) -> np.ndarray:
29 |         clip_range = clip_max - clip_min
30 |         rdata = (data - clip_min) / clip_range * (data_max - data_min) + data_min
31 |         return rdata
32 | 
33 |     def normalize_gaussian(
34 |         self,
35 |         data: np.ndarray,
36 |         mean: np.ndarray,
37 |         std: np.ndarray,
38 |         eps: float = 1e-8,
39 |     ) -> np.ndarray:
40 |         return (data - mean) / (std + eps)
41 | 
42 |     def denormalize_gaussian(
43 |         self,
44 |         data: np.ndarray,
45 |         mean: np.ndarray,
46 |         std: np.ndarray,
47 |         eps: float = 1e-8,
48 |     ) -> np.ndarray:
49 |         return data * (std + eps) + mean
50 | 
51 | 
-------------------------------------------------------------------------------- /src/experiments/env_adapters/language_mapper.py: --------------------------------------------------------------------------------
1 | import random
2 | 
3 | 
4 | class PersistentLanguageMapper:
5 |     def __init__(self, mapping_candidates: dict[str, list[str]], seed: int = 42):
6 |         """
7 |         mapping_candidates: Dict mapping keys to list of possible output strings.
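The first map() call for a key fixes its phrasing; later calls return the same string until reset() clears the mapping (the RNG state is kept so the draw sequence continues).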
8 | Example: {'A': ['apple', 'apricot'], 'B': ['banana', 'blueberry']} 9 | """ 10 | self.mapping_candidates = mapping_candidates 11 | self.mapping = {} 12 | self._random = random.Random(seed) 13 | 14 | def map(self, key: str) -> str: 15 | if key not in self.mapping: 16 | if key not in self.mapping_candidates: 17 | raise KeyError(f"No candidates defined for key '{key}'") 18 | self.mapping[key] = self._random.choice(self.mapping_candidates[key]) 19 | return self.mapping[key] 20 | 21 | def reset(self): 22 | self.mapping.clear() # keep RNG state to continue sequence 23 | -------------------------------------------------------------------------------- /src/experiments/envs/libero/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "libero-experiment" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "lerobot", 9 | "libero", 10 | "draccus", 11 | "policy-server-client", 12 | "torch==2.5.1", # there is a small change in 2.6.0 that breaks libero. It's easy to fix in libero's code but we decided to respect their versioning 13 | "torchvision==0.20.1", 14 | "hydra-core==1.2.0", 15 | "numpy>=1.22.4", 16 | "wandb>=0.13.1", 17 | "easydict==1.9", 18 | "transformers>=4.48.0", 19 | "tensorflow==2.15.0", 20 | "opencv-python", 21 | "robomimic==0.2.0", 22 | "einops", 23 | "thop==0.1.1-2209072238", 24 | "robosuite==1.4.0", 25 | "bddl==1.0.1", 26 | "future==0.18.2", 27 | "matplotlib==3.5.3", 28 | "cloudpickle==2.1.0", 29 | "gym==0.25.2" 30 | ] 31 | 32 | [tool.uv.sources] 33 | libero = { path = "../../../../third_party/LIBERO", editable = true } 34 | policy-server-client = { path = "../../../../packages/policy-server-client", editable = true } 35 | lerobot = { path = "../../../../third_party/lerobot", editable = true } 36 | -------------------------------------------------------------------------------- /src/experiments/envs/simpler/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "simpler-experiment" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "lerobot", 9 | "draccus", 10 | "policy-server-client", 11 | "simpler-env", 12 | "mani-skill2-real2sim", 13 | "pytest>=8.3.5", 14 | "tensorflow>=2.15.0", 15 | # # maniskill dependencies 16 | "numpy<=1.26.4", # it's important to use pre-2.0 numpy. 
2.0 has breaking changes and lots of downstream packages haven't updated yet
17 |     # # Simpler dependencies
18 |     "transformers>=4.48.3",
19 | ]
20 | 
21 | [tool.uv.sources]
22 | dlimp = {git = "https://github.com/kvablack/dlimp", rev = "d08da3852c149548aaa8551186d619d87375df08"}
23 | simpler-env = { path = "../../../../third_party/SimplerEnv", editable = true }
24 | mani-skill2-real2sim = { path = "../../../../third_party/ManiSkill2_real2sim", editable = true }
25 | policy-server-client = { path = "../../../../packages/policy-server-client", editable = true }
26 | lerobot = { path = "../../../../third_party/lerobot", editable = true }
27 | 
-------------------------------------------------------------------------------- /src/experiments/envs/simplerMS3/pyproject.toml: --------------------------------------------------------------------------------
1 | [project]
2 | name = "simplerMS3-experiment"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.10"
7 | dependencies = [
8 |     "lerobot",
9 |     "draccus",
10 |     "policy-server-client",
11 |     "simpler-env",
12 |     "mani_skill",
13 |     "pytest>=8.3.5",
14 |     # "tensorflow>=2.15.0",
15 |     "torch==2.6.0",
16 |     "torchvision==0.21.0",
17 |     # # maniskill dependencies
18 |     "numpy<=1.26.4", # it's important to use pre-2.0 numpy. 2.0 has breaking changes and lots of downstream packages haven't updated yet
19 |     # # Simpler dependencies
20 |     "transformers>=4.48.3",
21 | ]
22 | 
23 | [tool.uv.sources]
24 | simpler-env = { path = "../../../../third_party/SimplerEnv_MS3", editable = true }
25 | mani_skill = { path = "../../../../third_party/ManiSkill/", editable = true }
26 | policy-server-client = { path = "../../../../packages/policy-server-client", editable = true }
27 | lerobot = { path = "../../../../third_party/lerobot", editable = true }
28 | 
29 | # [tool.uv.pip]
30 | # # Allow all pre-release versions when resolving or installing
31 | # prerelease = "allow"
-------------------------------------------------------------------------------- /src/experiments/policies/magma_policy_server/pyproject.toml: --------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "magma_policy_server"
7 | version = "0.0.1"
8 | description = "policy environment for magma"
9 | authors = [
10 |     {name = "Irving Fang", email = "irving.fang@nyu.edu"},
11 |     {name = "Juexiao Zhang", email = "place@holder.com"}
12 | ]
13 | requires-python = ">=3.10, <3.13"
14 | 
15 | dependencies = [
16 |     "lerobot",
17 |     "transformers==4.49.0",
18 |     "opencv-python>=4.11.0",
19 |     "pytest",
20 |     "policy-server-client",
21 |     "numpy==1.26.4", # it's important to use pre-2.0 numpy.
2.0 has breaking changes and lots of downstream packages haven't updated yet
22 |     "scipy<=1.12.0,>=1.6.0",
23 |     "torch==2.6.0",
24 |     "draccus",
25 |     "protobuf==3.20.3",
26 |     "open-clip-torch>=2.32.0",
27 |     "flash-attn",
28 |     "accelerate>=1.6.0",
29 | ]
30 | 
31 | 
32 | [tool.ruff]
33 | line-length = 170
34 | target-version = "py310"
35 | 
36 | [tool.ruff.lint]
37 | select = ["A", "B", "E", "F", "I", "RUF", "W"]
38 | ignore = ["F722"]
39 | 
40 | [tool.ruff.lint.per-file-ignores]
41 | "__init__.py" = ["E402", "F401"]
42 | 
43 | [tool.uv.sources]
44 | policy-server-client = { path = "../../../../packages/policy-server-client", editable = true }
45 | lerobot = { path = "../../../../third_party/lerobot", editable = true }
46 | 
47 | [tool.uv]
48 | no-build-isolation-package = ['flash-attn']
49 | 
-------------------------------------------------------------------------------- /src/experiments/policies/octo_policy_server/pyproject.toml: --------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "octo_policy_server"
7 | version = "0.0.1"
8 | description = "policy environment for octo"
9 | authors = [
10 |     {name = "Irving Fang", email = "irving.fang@nyu.edu"},
11 |     {name = "Juexiao Zhang", email = "place@holder.com"}
12 | ]
13 | requires-python = ">=3.10, <3.13"
14 | 
15 | dependencies = [
16 |     "lerobot",
17 |     "opencv-python>=4.11.0",
18 |     "transformers==4.48.3",
19 |     "pytest",
20 |     "policy-server-client",
21 |     "numpy==1.26.4", # it's important to use pre-2.0 numpy. 2.0 has breaking changes and lots of downstream packages haven't updated yet
22 |     "scipy<=1.12.0,>=1.6.0",
23 |     "torch==2.6.0",
24 |     "draccus",
25 |     "protobuf==3.20.3",
26 |     "accelerate>=1.6.0",
27 |     "jax[cuda12-pip]==0.4.20",
28 |     "jaxlib",
29 |     "flax==0.8.1",
30 |     "dlimp",
31 |     "octo",
32 |     "distrax==0.1.5",
33 |     "tensorflow_probability==0.23.0",
34 | ]
35 | 
36 | 
37 | [tool.ruff]
38 | line-length = 170
39 | target-version = "py310"
40 | 
41 | [tool.ruff.lint]
42 | select = ["A", "B", "E", "F", "I", "RUF", "W"]
43 | ignore = ["F722"]
44 | 
45 | [tool.ruff.lint.per-file-ignores]
46 | "__init__.py" = ["E402", "F401"]
47 | 
48 | [tool.uv.sources]
49 | policy-server-client = { path = "../../../../packages/policy-server-client", editable = true }
50 | lerobot = { path = "../../../../third_party/lerobot", editable = true }
51 | jaxlib = { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.20+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl" }
52 | dlimp = { git = "https://github.com/kvablack/dlimp", rev = "5edaa4691567873d495633f2708982b42edf1972" }
53 | octo = { git = "https://github.com/octo-models/octo", branch = "main" }
54 | 
-------------------------------------------------------------------------------- /src/model/magma/configuration_magma.py: --------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | from lerobot.configs.policies import PreTrainedConfig
4 | 
5 | 
6 | @PreTrainedConfig.register_subclass("magma")
7 | @dataclass
8 | class MagmaConfig(PreTrainedConfig):
9 |     # Input / output structure.
10 |     n_obs_steps: int = 1
11 |     chunk_size: int = 4
12 | 
13 |     def validate_features(self) -> None:
14 |         pass
15 | 
16 |     def get_optimizer_preset(self):
17 |         return None
18 | 
19 |     def get_scheduler_preset(self):
20 |         return None
21 | 
22 |     @property
23 |     def observation_delta_indices(self) -> None:
24 |         return None
25 | 
26 |     @property
27 |     def action_delta_indices(self) -> list:
28 |         return list(range(self.chunk_size))
29 | 
30 |     @property
31 |     def reward_delta_indices(self) -> None:
32 |         return None
33 | 
-------------------------------------------------------------------------------- /src/model/magma/modeling_magma.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | ##################################################################
4 | # Placeholder for Magma model
5 | ##################################################################
6 | 
7 | from lerobot.common.policies.pretrained import PreTrainedPolicy
8 | 
9 | from src.model.magma.configuration_magma import MagmaConfig
10 | 
11 | 
12 | class MagmaPolicy(PreTrainedPolicy):
13 |     """Wrapper class around Magma model to train and run inference within LeRobot."""
14 | 
15 |     config_class = MagmaConfig
16 |     name = "magma"
17 | 
18 |     def __init__(
19 |         self,
20 |         config: MagmaConfig,
21 |         dataset_stats: dict[str, dict] | None = None,
22 |     ):
23 |         """
24 |         Args:
25 |             config: Policy configuration class instance or None, in which case the default instantiation of
26 |                 the configuration class is used.
27 |             dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
28 |                 that they will be passed with a call to `load_state_dict` before the policy is used.
29 |         """
30 | 
31 |         super().__init__(config)
32 |         config.validate_features()
33 |         self.config = config
34 | 
-------------------------------------------------------------------------------- /src/model/octo/configuration_octo.py: --------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | from lerobot.configs.policies import PreTrainedConfig
4 | 
5 | 
6 | @PreTrainedConfig.register_subclass("octo")
7 | @dataclass
8 | class OctoConfig(PreTrainedConfig):
9 |     # Input / output structure.
10 | n_obs_steps: int = 2 11 | chunk_size: int = 4 12 | action_ensemble_temp: float = 0 13 | 14 | def validate_features(self) -> None: 15 | pass 16 | 17 | def get_optimizer_preset(self): 18 | return None 19 | 20 | def get_scheduler_preset(self): 21 | return None 22 | 23 | @property 24 | def observation_delta_indices(self) -> None: 25 | return None 26 | 27 | @property 28 | def action_delta_indices(self) -> list: 29 | return list(range(self.chunk_size)) 30 | 31 | @property 32 | def reward_delta_indices(self) -> None: 33 | return None 34 | -------------------------------------------------------------------------------- /src/model/octo/modeling_octo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ################################################################## 4 | # Placeholder for Octo model 5 | ################################################################## 6 | 7 | from lerobot.common.policies.pretrained import PreTrainedPolicy 8 | 9 | from src.model.octo.configuration_octo import OctoConfig 10 | 11 | 12 | class OctoPolicy(PreTrainedPolicy): 13 | """Wrapper class around Octo model to train and run inference within LeRobot.""" 14 | 15 | config_class = OctoConfig 16 | name = "octo" 17 | 18 | def __init__( 19 | self, 20 | config: OctoConfig, 21 | dataset_stats: dict[str, dict] | None = None, 22 | ): 23 | """ 24 | Args: 25 | config: Policy configuration class instance or None, in which case the default instantiation of 26 | the configuration class is used. 27 | dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected 28 | that they will be passed with a call to `load_state_dict` before the policy is used. 29 | """ 30 | 31 | super().__init__(config) 32 | config.validate_features() 33 | self.config = config 34 | 35 | -------------------------------------------------------------------------------- /src/model/spatialvla/configuration_spatialvla.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from lerobot.configs.policies import PreTrainedConfig 4 | 5 | 6 | @PreTrainedConfig.register_subclass("spatial-vla") 7 | @dataclass 8 | class SpatialVLAConfig(PreTrainedConfig): 9 | # Input / output structure. 
10 |     n_obs_steps: int = 1
11 |     chunk_size: int = 4
12 |     action_ensemble_temp: float = -0.8
13 | 
14 |     def validate_features(self) -> None:
15 |         pass
16 | 
17 |     def get_optimizer_preset(self):
18 |         return None
19 | 
20 |     def get_scheduler_preset(self):
21 |         return None
22 | 
23 |     @property
24 |     def observation_delta_indices(self) -> None:
25 |         return None
26 | 
27 |     @property
28 |     def action_delta_indices(self) -> list:
29 |         return list(range(self.chunk_size))
30 | 
31 |     @property
32 |     def reward_delta_indices(self) -> None:
33 |         return None
34 | 
-------------------------------------------------------------------------------- /src/model/spatialvla/modeling_spatialvla.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | ##################################################################
4 | # Placeholder for SpatialVLA model
5 | ##################################################################
6 | 
7 | from lerobot.common.policies.pretrained import PreTrainedPolicy
8 | 
9 | from src.model.spatialvla.configuration_spatialvla import SpatialVLAConfig
10 | 
11 | 
12 | class SpatialVLAPolicy(PreTrainedPolicy):
13 |     """Wrapper class around SpatialVLA model to train and run inference within LeRobot."""
14 | 
15 |     config_class = SpatialVLAConfig
16 |     name = "spatial-vla"
17 | 
18 |     def __init__(
19 |         self,
20 |         config: SpatialVLAConfig,
21 |         dataset_stats: dict[str, dict] | None = None,
22 |     ):
23 |         """
24 |         Args:
25 |             config: Policy configuration class instance or None, in which case the default instantiation of
26 |                 the configuration class is used.
27 |             dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
28 |                 that they will be passed with a call to `load_state_dict` before the policy is used.
29 |         """
30 | 
31 |         super().__init__(config)
32 |         config.validate_features()
33 |         self.config = config
34 | 
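To make the chunking contract concrete, a small sketch (assuming, as in lerobot, that the parent PreTrainedConfig dataclass supplies defaults for its remaining fields so the config can be constructed directly):

    from src.model.spatialvla.configuration_spatialvla import SpatialVLAConfig

    cfg = SpatialVLAConfig()
    print(cfg.action_delta_indices)       # [0, 1, 2, 3]: the policy is queried for a 4-step action chunk
    print(cfg.observation_delta_indices)  # None: only the current observation is used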
-------------------------------------------------------------------------------- /src/utils/decorator.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def conditional_decorator(dec, condition):
5 |     def decorator(func):
6 |         if not condition:
7 |             # Return the function unchanged, not decorated.
8 |             return func
9 |         return dec(func)
10 | 
11 |     return decorator
12 | 
13 | 
14 | class NoSyncBase:
15 |     def no_sync(self):
16 |         if self.use_ddp:
17 |             # If DDP is used, call the actual `no_sync` method
18 |             return torch.nn.parallel.DistributedDataParallel.no_sync(self)
19 |         else:
20 |             # Otherwise, return the dummy context manager
21 |             class DummyContext:
22 |                 def __enter__(self):
23 |                     pass
24 | 
25 |                 def __exit__(self, exc_type, exc_value, traceback):
26 |                     pass
27 | 
28 |             return DummyContext()
29 | 
-------------------------------------------------------------------------------- /src/utils/metric.py: --------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | import torch
4 | 
5 | 
6 | def get_action_accuracy(
7 |     gt: torch.FloatTensor,  # [Batch_Size, Horizon, Action_Dim]
8 |     pred: torch.FloatTensor,
9 |     thresholds: List[float] = [0.1, 0.2],
10 | ) -> torch.FloatTensor:
11 |     device = gt.device
12 |     diff = torch.abs(gt - pred).reshape(-1, gt.shape[-1])
13 | 
14 |     # for each threshold: the fraction of (batch * horizon) steps whose action dimensions are all within it
15 |     accuracies = torch.zeros(len(thresholds), device=device)
16 |     for idx, threshold in enumerate(thresholds):
17 |         accuracy = torch.mean(
18 |             (torch.mean((diff < threshold).float(), dim=1) >= 1.0).float()
19 |         )
20 |         accuracies[idx] = accuracy
21 |     return accuracies
22 | 
23 | # TODO: (juexiao) add text accuracy or other metric
-------------------------------------------------------------------------------- /src/utils/monitor.py: --------------------------------------------------------------------------------
1 | import functools
2 | import inspect
3 | import logging
4 | import os
5 | import sys
6 | import time
7 | from typing import Optional
8 | 
9 | import torch
10 | 
11 | def save_bad_data(data, step, save_dir="debug-bad-data"):
12 |     os.makedirs(save_dir, exist_ok=True)
13 |     save_path = os.path.join(save_dir, f"bad_data_step_{step}.pt")
14 |     torch.save(data, save_path)
15 |     print(f"Saved bad data to {save_path}")
16 | 
17 | def log_allocated_gpu_memory(log=None, stage="loading model", device=0):
18 |     if torch.cuda.is_available():
19 |         allocated_memory = torch.cuda.memory_allocated(device)
20 |         msg = f"Allocated GPU memory after {stage}: {allocated_memory/1024/1024/1024:.2f} GB"
21 |         print(msg) if log is None else log.info(msg)
22 | 
23 | def log_execution_time(logger=None):
24 |     """Decorator to log the execution time of a function"""
25 | 
26 |     def decorator(func):
27 |         @functools.wraps(func)
28 |         def wrapper(*args, **kwargs):
29 |             start_time = time.time()
30 |             result = func(*args, **kwargs)
31 |             end_time = time.time()
32 |             elapsed_time = end_time - start_time
33 |             if logger is None:
34 |                 print(f"{func.__name__} took {elapsed_time:.2f} seconds to execute.")
35 |             else:
36 |                 logger.info(
37 |                     f"{func.__name__} took {elapsed_time:.2f} seconds to execute."
38 |                 )
39 |             return result
40 | 
41 |         return wrapper
42 | 
43 |     return decorator
44 | 
45 | def blockprint():
46 |     """Block printing to stdout and stderr for the current process."""
47 |     sys.stdout = open(os.devnull, 'w')
48 |     sys.stderr = open(os.devnull, 'w')
49 | 
50 | def setup_logger(main_rank: bool,
51 |                  filename: str|None = None,
52 |                  debug: bool = False,
53 |                  name: Optional[str] = None) -> logging.Logger:
54 |     '''
55 |     Set up a logger for the script.
56 |     main_rank: bool, whether this is the main process. We only set up logging for the main process.
57 |     filename: str, the name of the file to log to. If None, logs to stdout.
58 | debug: bool, whether to log in debug mode. 59 | name: str, the name of the logger. If None, uses the name of the calling module. 60 | ''' 61 | 62 | if name is None: 63 | # Use the name of the calling module as the logger name 64 | logger = logging.getLogger(name= 65 | os.path.splitext(os.path.basename(inspect.stack()[1].filename))[0] 66 | ) 67 | else: 68 | logger = logging.getLogger(name=name) 69 | 70 | # Remove any existing handlers to avoid duplicate logs 71 | if logger.hasHandlers(): 72 | logger.handlers.clear() 73 | 74 | # Only set up logging for rank 0 75 | if main_rank: 76 | if debug: 77 | logger.setLevel(logging.DEBUG) # Everything at DEBUG level and above will be logged 78 | else: 79 | logger.setLevel(logging.INFO) # Everything at INFO level and above will be logged 80 | 81 | # Create a file handler if filename is provided, otherwise use stdout 82 | if filename: 83 | handler = logging.FileHandler(filename) 84 | else: 85 | handler = logging.StreamHandler(sys.stdout) 86 | 87 | format_str = "[%(asctime)s,%(msecs)03d][%(name)s:%(lineno)d][%(levelname)s] - %(message)s" 88 | date_format = "%Y-%m-%d %H:%M:%S" 89 | formatter = logging.Formatter(fmt=format_str, datefmt=date_format) 90 | handler.setFormatter(formatter) 91 | logger.addHandler(handler) 92 | else: 93 | # Disable logging for non-master processes. 94 | logger.setLevel(logging.ERROR) 95 | # Optionally add a NullHandler to absorb logs. 96 | logger.addHandler(logging.NullHandler()) 97 | 98 | logger.propagate = False # so no duplicate, unformatted log in stderr 99 | return logger 100 | 101 | class Timer: 102 | def __init__(self): 103 | self._start = time.time() 104 | 105 | def __call__(self, reset=True): 106 | now = time.time() 107 | diff = now - self._start 108 | if reset: 109 | self._start = now 110 | return diff 111 | -------------------------------------------------------------------------------- /src/utils/pipeline.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | IMAGENET_STANDARD_MEAN = torch.tensor([0.5, 0.5, 0.5]) 9 | IMAGENET_STANDARD_STD = torch.tensor([0.5, 0.5, 0.5]) 10 | 11 | def set_seed_everywhere(seed: int, train: bool = True): 12 | """Sets the random seed for Python, NumPy, and PyTorch functions.""" 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | if train: 18 | import tensorflow as tf 19 | tf.random.set_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | os.environ["PYTHONHASHSEED"] = str(seed) 23 | 24 | def get_class_from_path(class_path: str): 25 | ''' 26 | class_path: str 27 | The full path to the class, including the module name and class name. 28 | For example: "my_module.MyClass" 29 | ''' 30 | module_name, class_name = class_path.rsplit(".", 1) 31 | module = importlib.import_module(module_name) 32 | return getattr(module, class_name) 33 | 34 | def rescale( 35 | image: torch.LongTensor, 36 | scale: float, 37 | ) -> torch.FloatTensor: 38 | rescaled_image = image * scale 39 | return rescaled_image 40 | 41 | 42 | def normalize( 43 | image: torch.LongTensor, 44 | mean: torch.FloatTensor, 45 | std: torch.FloatTensor, 46 | ) -> torch.FloatTensor: 47 | assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor." 48 | assert ( 49 | image.shape[1] == 3 50 | ), f"Expected 3 channels at axis 1, got {image.shape[1]} channels." 
51 | mean = mean[None, :, None, None] # add batch and spatial dimensions 52 | std = std[None, :, None, None] 53 | image = (image - mean) / std 54 | return image 55 | 56 | 57 | def process_images( 58 | images: torch.LongTensor, 59 | rescale_factor: float, 60 | image_mean: torch.FloatTensor = IMAGENET_STANDARD_MEAN, 61 | image_std: torch.FloatTensor = IMAGENET_STANDARD_STD, 62 | ) -> torch.FloatTensor: 63 | # Rescale the pixel values to be in the range [0, 1] 64 | images = rescale(images, scale=rescale_factor) 65 | 66 | # Normalize the images to have mean 0 and standard deviation 1 67 | images = normalize(images, mean=image_mean, std=image_std) 68 | 69 | return images 70 | 71 | def revert_processed_images( 72 | processed_image: torch.FloatTensor, 73 | image_mean: torch.FloatTensor = IMAGENET_STANDARD_MEAN, 74 | image_std: torch.FloatTensor = IMAGENET_STANDARD_STD, 75 | rescale_factor: float = 1/255.0 76 | ) -> torch.LongTensor: 77 | rescale_factor = torch.tensor(rescale_factor).to(processed_image.device) 78 | # Undo normalization 79 | mean = image_mean[None, :, None, None] # Add batch and spatial dimensions 80 | mean = mean.to(processed_image.device) 81 | std = image_std[None, :, None, None] 82 | std = std.to(processed_image.device) 83 | image = processed_image * std + mean # Convert back to [0,1] 84 | 85 | # Undo rescaling 86 | image = image / rescale_factor # Convert back to [0,255] 87 | 88 | # Clip values to valid range and convert to integer 89 | image = torch.clamp(image, 0, 255).to(torch.uint8) 90 | 91 | return image 92 | -------------------------------------------------------------------------------- /src/utils/spec.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from functools import partial 3 | from typing import Any, Dict, Tuple, TypedDict, Union 4 | 5 | 6 | class ModuleSpec(TypedDict): 7 | """A JSON-serializable representation of a function or class with some default args and kwargs to pass to it. Useful for specifying a particular class or function in a config file, while keeping it serializable and overridable from the command line using ml_collections. 8 | 9 | Usage: 10 | 11 | # Preferred way to create a spec: 12 | >>> from src.model.components.transformer import Transformer 13 | >>> spec = ModuleSpec.create(Transformer, num_layers=3) 14 | # Same as above using the fully qualified import string: 15 | >>> spec = ModuleSpec.create("src.model.components.transformer:Transformer", num_layers=3) 16 | 17 | # Usage: 18 | >>> ModuleSpec.instantiate(spec) == partial(Transformer, num_layers=3) 19 | # can pass additional kwargs at instantiation time 20 | >>> transformer = ModuleSpec.instantiate(spec, num_heads=8) 21 | 22 | Note: ModuleSpec is just an alias for a dictionary (that is strongly typed), not a real class. So from 23 | your code's perspective, it is just a dictionary. 24 | 25 | module (str): The module the callable is located in 26 | name (str): The name of the callable in the module 27 | args (tuple): The args to pass to the callable 28 | kwargs (dict): The kwargs to pass to the callable 29 | """ 30 | 31 | module: str 32 | name: str 33 | args: Tuple[Any, ...] 34 | kwargs: Dict[str, Any] 35 | 36 | @staticmethod 37 | def create( 38 | callable_or_full_name: Union[str, callable], *args, **kwargs 39 | ) -> "ModuleSpec": # type: ignore 40 | """Create a module spec from a callable or import string. 
41 | 42 | Args: 43 | callable_or_full_name (str or object): Either the object itself or a fully qualified import string 44 | (e.g. "src.model.components.transformer:Transformer") 45 | args (tuple, optional): Passed into callable upon instantiation. 46 | kwargs (dict, optional): Passed into callable upon instantiation. 47 | """ 48 | if isinstance(callable_or_full_name, str): 49 | assert callable_or_full_name.count(":") == 1, ( 50 | "If passing in a string, it must be a fully qualified import string " 51 | "(e.g. 'src.model.components.transformer:Transformer')" 52 | ) 53 | module, name = callable_or_full_name.split(":") 54 | else: 55 | module, name = _infer_full_name(callable_or_full_name) 56 | 57 | return ModuleSpec(module=module, name=name, args=args, kwargs=kwargs) 58 | 59 | @staticmethod 60 | def instantiate(spec: "ModuleSpec"): # type: ignore 61 | if set(spec.keys()) != {"module", "name", "args", "kwargs"}: 62 | raise ValueError( 63 | f"Expected ModuleSpec, but got {spec}. " 64 | "ModuleSpec must have keys 'module', 'name', 'args', and 'kwargs'." 65 | ) 66 | cls = _import_from_string(spec["module"], spec["name"]) 67 | return partial(cls, *spec["args"], **spec["kwargs"]) 68 | 69 | @staticmethod 70 | def to_string(spec: "ModuleSpec"): # type: ignore 71 | return ( 72 | f"{spec['module']}:{spec['name']}" 73 | f"({', '.join(spec['args'])}" 74 | f"{', ' if spec['args'] and spec['kwargs'] else ''}" 75 | f"{', '.join(f'{k}={v}' for k, v in spec['kwargs'].items())})" 76 | ) 77 | 78 | 79 | def _infer_full_name(o: object): 80 | if hasattr(o, "__module__") and hasattr(o, "__name__"): 81 | return o.__module__, o.__name__ 82 | else: 83 | raise ValueError( 84 | f"Could not infer identifier for {o}. " 85 | "Please pass in a fully qualified import string instead " 86 | "e.g. 'src.model.components.transformer:Transformer'" 87 | ) 88 | 89 | 90 | def _import_from_string(module_string: str, name: str): 91 | try: 92 | module = importlib.import_module(module_string) 93 | return getattr(module, name) 94 | except Exception as e: 95 | raise ValueError(f"Could not import {module_string}:{name}") from e 96 | --------------------------------------------------------------------------------
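A round-trip sketch for ModuleSpec (a standard-library callable is used so the snippet runs anywhere; any importable callable behaves the same way):

    from src.utils.spec import ModuleSpec

    spec = ModuleSpec.create("collections:OrderedDict")
    factory = ModuleSpec.instantiate(spec)   # partial(collections.OrderedDict)
    d = factory(a=1)                         # OrderedDict([('a', 1)])
    print(ModuleSpec.to_string(spec))        # collections:OrderedDict()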