├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── config ├── bridge_statistics.json ├── eval │ ├── bridge.yaml │ ├── fractal_apple.yaml │ ├── fractal_coke.yaml │ ├── fractal_drawer.yaml │ └── fractal_move.yaml ├── fractal_statistics.json └── train │ ├── bridge.yaml │ └── fractal.yaml ├── doc ├── convention.md ├── error.md └── notes.md ├── media ├── maniskill_pp.png └── open-pi-zero-overview.png ├── pyproject.toml ├── scripts ├── data │ ├── check_bridge.py │ ├── check_fractal.py │ └── modify_rlds_dataset.py ├── run.py ├── set_path.sh ├── tests │ ├── oxe.py │ ├── run_paligemma.py │ ├── sampling.py │ └── simpler.py └── try_checkpoint_in_simpler.py ├── slurm ├── eval_simpler_bridge.sh ├── eval_simpler_fractal.sh ├── modify_rlds.sh ├── test_training_single_gpu_no_slurm.sh ├── train_multi_gpu.sh └── train_multi_node.sh └── src ├── agent ├── dataset.py ├── env_adapter │ ├── base.py │ └── simpler.py ├── eval.py ├── model_averaging.py └── train.py ├── data ├── dataset.py ├── dataset_torch.py ├── dlimp │ ├── __init__.py │ ├── augmentations.py │ ├── dataset.py │ ├── transforms │ │ ├── __init__.py │ │ ├── common.py │ │ ├── frame_transforms.py │ │ ├── goal_relabeling.py │ │ └── traj_transforms.py │ └── utils.py ├── obs_transforms.py ├── oxe │ ├── __init__.py │ ├── oxe_dataset_configs.py │ ├── oxe_dataset_mixes.py │ ├── oxe_standardization_transforms.py │ └── preprocess │ │ ├── mod_functions.py │ │ └── multithreaded_adhoc_tfds_builder.py ├── traj_transforms.py └── utils │ ├── data_utils.py │ ├── goal_relabeling.py │ ├── task_augmentation.py │ └── text_processing.py ├── model ├── kv_cache.py ├── lora.py ├── paligemma │ ├── config.py │ ├── gemma.py │ ├── load.py │ ├── modules.py │ ├── processing.py │ └── siglip.py ├── utils.py └── vla │ ├── joint_model.py │ ├── mixture.py │ ├── modules.py │ ├── pizero.py │ └── processing.py └── utils ├── decorator.py ├── geometry.py ├── metric.py ├── monitor.py ├── optim.py └── spec.py /.gitignore: -------------------------------------------------------------------------------- 1 | launch.sh 2 | uv.lock 3 | temp/ 4 | .ruff_cache/ 5 | launch/ 6 | logs/ 7 | results/ 8 | wandb/ 9 | *.mp4 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
172 | #.idea/ 173 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: ".git" 2 | 3 | repos: 4 | - repo: https://github.com/astral-sh/ruff-pre-commit 5 | rev: v0.9.3 6 | hooks: 7 | - id: ruff-format # formatter 8 | types_or: [python, pyi, jupyter, toml] 9 | - id: ruff # linter 10 | types_or: [python, pyi, jupyter, toml] 11 | args: [--fix, --show-fixes] 12 | 13 | # - repo: https://github.com/pycqa/isort 14 | # rev: 5.12.0 15 | # hooks: 16 | # - id: isort 17 | # exclude: __init__.py 18 | # args: ["--profile", "black"] 19 | 20 | # - repo: https://github.com/RobertCraigie/pyright-python 21 | # rev: v1.1.379 22 | # hooks: 23 | # - id: pyright 24 | # language_version: python3.10 25 | # additional_dependencies: 26 | # [ 27 | # einops, 28 | # pillow, 29 | # tensorflow, 30 | # torch, 31 | # ] 32 | 33 | - repo: https://github.com/pre-commit/pre-commit-hooks 34 | rev: v4.5.0 35 | hooks: 36 | - id: check-added-large-files 37 | - id: check-ast 38 | - id: check-case-conflict 39 | - id: check-merge-conflict 40 | - id: check-toml 41 | - id: check-yaml 42 | - id: end-of-file-fixer 43 | - id: trailing-whitespace 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Allen Z. Ren 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/bridge_statistics.json: -------------------------------------------------------------------------------- 1 | {"action": {"mean": [0.00021758403454441577, 0.00012507825158536434, -0.00017109014152083546, -0.0001617111702216789, -0.0002524859446566552, 0.0002515816013328731, 0.5879487991333008], "std": [0.009632210247218609, 0.013500974513590336, 0.012510341592133045, 0.028145477175712585, 0.03028254210948944, 0.07585873454809189, 0.4877150356769562], "max": [0.41691166162490845, 0.25864794850349426, 0.21218234300613403, 3.122201919555664, 1.8618112802505493, 6.280478477478027, 1.0], "min": [-0.4007510244846344, -0.13874775171279907, -0.22553899884223938, -3.2010786533355713, -1.8618112802505493, -6.279075622558594, 0.0], "p99": [0.028122276067733765, 0.040630316659808145, 0.03994889184832546, 0.08121915772557152, 0.07724379181861864, 0.20214049845933896, 1.0], "p01": [-0.028539552688598632, -0.041432044506073, -0.025977383628487588, -0.08020886614918708, -0.09213060349225997, -0.2054861941933632, 0.0]}, "num_transitions": 2195527, "num_trajectories": 60064, "proprio": {"mean": [0.30904945731163025, 0.03045589290559292, 0.06558273732662201, 0.00706630339846015, -0.07828629016876221, 0.10661222040653229, 0.7149746417999268], "std": [0.06059328466653824, 0.09172434359788895, 0.05185756832361221, 0.1313914805650711, 0.1698099821805954, 0.573583722114563, 0.3517141044139862], "max": [0.5862360596656799, 0.4034728705883026, 0.36494991183280945, 1.514088749885559, 1.570796251296997, 3.1415255069732666, 1.1154625415802002], "min": [-0.04167502000927925, -0.3945816159248352, -0.15537554025650024, -3.141592502593994, -1.4992541074752808, -3.14153790473938, 0.04637829214334488], "p99": [0.4527312242984769, 0.23490807592868757, 0.1973453593254087, 0.37877989292144754, 0.27723048210143925, 1.8378053522109963, 1.0105689764022827], "p01": [0.17017078369855881, -0.16965715914964677, -0.054787094071507454, -0.3655692100524902, -0.5435487496852874, -1.3501438736915587, 0.052190229296684265]}} 2 | -------------------------------------------------------------------------------- /config/eval/bridge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.eval.EvalAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/eval_bridge/${name}_ta${act_steps}_${seed}/${env.task}_${now:%H-%M-%S} 9 | name: 10 | device: cuda 11 | seed: 42 12 | checkpoint_path: 13 | n_eval_episode: 240 # octo simpler runs 3 seeds with 24 configs each, we will run 10 trials for each config 14 | n_video: ${n_eval_episode} 15 | # sweeps: 16 | # urdf_version: 17 | # - null 18 | 19 | env: 20 | task: 21 | adapter: 22 | _target_: src.agent.env_adapter.simpler.BridgeSimplerAdapter 23 | dataset_statistics_path: config/bridge_statistics.json 24 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 25 | tokenizer_padding: max_length 26 | max_seq_len: 276 # fixed 256 for image + max 20 for text 27 | num_image_tokens: 256 28 | image_size: [224, 224] 29 | 30 | flow_sampling: beta 31 | num_inference_steps: 10 32 | final_action_clip_value: 1.0 # data normalized in [-1,1] 33 | use_torch_compile: True 34 | use_bf16: False 35 | 36 | cond_steps: 1 37 | horizon_steps: 4 38 | act_steps: 4 39 | action_dim: 7 # EEF_POS 40 | proprio_dim: 7 # POS_EULER 41 | 42 | mixture: 43 | vlm: # gemma 44 | hidden_size: 2048 45 | intermediate_size: 16384 46 
| use_final_norm: False 47 | cache: True 48 | use_quantize: ${quantize} 49 | use_lora: ${lora} 50 | adaptive_mode: # not applicable for gemma 51 | rope_theta: 10000.0 # 10000 in gemma 52 | proprio: 53 | hidden_size: 1024 54 | intermediate_size: 4096 55 | use_final_norm: True # technically no, but sharing mixture with action 56 | cache: True 57 | use_quantize: False 58 | use_lora: False 59 | adaptive_mode: ${action_expert_adaptive_mode} 60 | rope_theta: ${action_expert_rope_theta} 61 | action: 62 | hidden_size: 1024 63 | intermediate_size: 4096 64 | use_final_norm: True 65 | cache: False 66 | use_quantize: False 67 | use_lora: False 68 | adaptive_mode: ${action_expert_adaptive_mode} 69 | rope_theta: ${action_expert_rope_theta} 70 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 71 | time_hidden_size: 256 # only applicable if using adaptive 72 | time_max_period: 10000.0 # provided ckpts used 10000.0 for both time_max_period and action_expert_rope_theta 73 | action_expert_rope_theta: 10000.0 74 | quantize: False 75 | lora: False 76 | lora_r: 32 77 | lora_dropout: 0.0 78 | max_image_text_tokens: ${env.adapter.max_seq_len} 79 | 80 | # Fixed 81 | image_token_index: 257152 82 | vocab_size: 257216 83 | pad_token_id: 0 84 | 85 | vision: 86 | _target_: src.model.paligemma.siglip.SiglipVisionModel 87 | config: 88 | hidden_size: 1152 # siglip 89 | intermediate_size: 4304 90 | num_hidden_layers: 27 91 | num_attention_heads: 16 92 | num_channels: 3 93 | image_size: 224 94 | patch_size: 14 95 | layer_norm_eps: 1e-6 96 | attention_dropout: 0.0 97 | num_image_tokens: 256 98 | 99 | vision_projector: 100 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 101 | config: 102 | vision_config: 103 | hidden_size: 1152 104 | projection_dim: 2048 105 | 106 | joint: 107 | _target_: src.model.vla.joint_model.JointModel 108 | config: 109 | action_expert_adaptive_mode: ${action_expert_adaptive_mode} 110 | time_hidden_size: ${time_hidden_size} 111 | mixture: ${mixture} 112 | lora: 113 | r: ${lora_r} 114 | dropout: ${lora_dropout} 115 | # 116 | num_hidden_layers: 18 117 | num_attention_heads: 8 118 | num_key_value_heads: 1 119 | head_dim: 256 120 | rms_norm_eps: 1e-6 121 | attention_bias: False 122 | attention_dropout: 0.0 123 | pad_token_id: ${pad_token_id} 124 | -------------------------------------------------------------------------------- /config/eval/fractal_apple.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.eval.EvalAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/eval_fractal/${name}_ta${act_steps}_${seed}/${env.task}_${now:%H-%M-%S} 9 | name: 10 | device: cuda 11 | seed: 42 12 | checkpoint_path: 13 | n_eval_episode: ${eval:'9 * 4 * 3 * 10'} # 9 apple locations, 4 urdfs, 3 robot locations/rgb_overlay_paths, 10 trials each 14 | n_video: ${n_eval_episode} 15 | # From Simpler: We place the robot at 3 different positions on the floor and the apple at 9 different positions within a grid on the cabinet top, yielding a total of 3 x 9 = 27 trials 16 | # The 3 locations are fixed w.r.t. rgb_overlay_path! 
https://github.com/simpler-env/ManiSkill2_real2sim/blob/87dc84508520310e61c972ece399a0f034095e42/mani_skill2_real2sim/envs/custom_scenes/place_in_closed_drawer_in_scene.py#L184 17 | # sweeps: 18 | # urdf_version: 19 | # - null 20 | # - "recolor_tabletop_visual_matching_1" 21 | # - "recolor_tabletop_visual_matching_2" 22 | # - "recolor_cabinet_visual_matching_1" 23 | # rgb_overlay_path: 24 | # - ./SimplerEnv/ManiSkill2_real2sim/data/real_inpainting/open_drawer_a0.png 25 | # - ./SimplerEnv/ManiSkill2_real2sim/data/real_inpainting/open_drawer_b0.png 26 | # - ./SimplerEnv/ManiSkill2_real2sim/data/real_inpainting/open_drawer_c0.png 27 | 28 | env: 29 | task: 30 | adapter: 31 | _target_: src.agent.env_adapter.simpler.EDRSimplerAdapter 32 | dataset_statistics_path: config/fractal_statistics.json 33 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 34 | tokenizer_padding: max_length 35 | max_seq_len: 276 # fixed 256 for image + max 20 for text 36 | num_image_tokens: 256 37 | image_size: [224, 224] 38 | 39 | flow_sampling: beta 40 | num_inference_steps: 10 41 | final_action_clip_value: 1.0 # data normalized in [-1,1] 42 | use_torch_compile: True 43 | use_bf16: False 44 | 45 | cond_steps: 1 46 | horizon_steps: 4 47 | act_steps: 2 48 | action_dim: 7 # EEF_POS 49 | proprio_dim: 8 # POS_QUAT 50 | 51 | mixture: 52 | vlm: # gemma 53 | hidden_size: 2048 54 | intermediate_size: 16384 55 | use_final_norm: False 56 | cache: True 57 | use_quantize: ${quantize} 58 | use_lora: ${lora} 59 | adaptive_mode: # not applicable for gemma 60 | rope_theta: 10000.0 # 10000 in gemma 61 | proprio: 62 | hidden_size: 1024 63 | intermediate_size: 4096 64 | use_final_norm: True # technically no, but sharing mixture with action 65 | cache: True 66 | use_quantize: False 67 | use_lora: False 68 | adaptive_mode: ${action_expert_adaptive_mode} 69 | rope_theta: ${action_expert_rope_theta} 70 | action: 71 | hidden_size: 1024 72 | intermediate_size: 4096 73 | use_final_norm: True 74 | cache: False 75 | use_quantize: False 76 | use_lora: False 77 | adaptive_mode: ${action_expert_adaptive_mode} 78 | rope_theta: ${action_expert_rope_theta} 79 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 80 | time_hidden_size: 256 # only applicable if using adaptive 81 | time_max_period: 10000.0 # provided ckpts used 10000.0 for both time_max_period and action_expert_rope_theta 82 | action_expert_rope_theta: 10000.0 83 | quantize: False 84 | lora: False 85 | lora_r: 32 86 | lora_dropout: 0.0 87 | max_image_text_tokens: ${env.adapter.max_seq_len} 88 | 89 | # Fixed 90 | image_token_index: 257152 91 | vocab_size: 257216 92 | pad_token_id: 0 93 | 94 | vision: 95 | _target_: src.model.paligemma.siglip.SiglipVisionModel 96 | config: 97 | hidden_size: 1152 # siglip 98 | intermediate_size: 4304 99 | num_hidden_layers: 27 100 | num_attention_heads: 16 101 | num_channels: 3 102 | image_size: 224 103 | patch_size: 14 104 | layer_norm_eps: 1e-6 105 | attention_dropout: 0.0 106 | num_image_tokens: 256 107 | 108 | vision_projector: 109 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 110 | config: 111 | vision_config: 112 | hidden_size: 1152 113 | projection_dim: 2048 114 | 115 | joint: 116 | _target_: src.model.vla.joint_model.JointModel 117 | config: 118 | action_expert_adaptive_mode: ${action_expert_adaptive_mode} 119 | time_hidden_size: ${time_hidden_size} 120 | mixture: ${mixture} 121 | lora: 122 | r: ${lora_r} 123 | dropout: ${lora_dropout} 124 | # 125 | num_hidden_layers: 18 126 | num_attention_heads: 8 
127 | num_key_value_heads: 1 128 | head_dim: 256 129 | rms_norm_eps: 1e-6 130 | attention_bias: False 131 | attention_dropout: 0.0 132 | pad_token_id: ${pad_token_id} 133 | -------------------------------------------------------------------------------- /config/eval/fractal_coke.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.eval.EvalAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/eval_fractal/${name}_ta${act_steps}_${seed}/${env.task}_${now:%H-%M-%S} 9 | name: 10 | device: cuda 11 | seed: 42 12 | checkpoint_path: 13 | n_eval_episode: ${eval:'25 * 4 * 10'} # 25 locations, 4 urdfs, 10 trials each 14 | n_video: ${n_eval_episode} 15 | # From Simpler: for each orientation, we place the coke can at 25 grid positions within a rectangle on the tabletop 16 | # sweeps: 17 | # urdf_version: 18 | # - null 19 | # - "recolor_tabletop_visual_matching_1" 20 | # - "recolor_tabletop_visual_matching_2" 21 | # - "recolor_cabinet_visual_matching_1" 22 | 23 | env: 24 | task: 25 | adapter: 26 | _target_: src.agent.env_adapter.simpler.EDRSimplerAdapter 27 | dataset_statistics_path: config/fractal_statistics.json 28 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 29 | tokenizer_padding: max_length 30 | max_seq_len: 276 # fixed 256 for image + max 20 for text 31 | num_image_tokens: 256 32 | image_size: [224, 224] 33 | 34 | flow_sampling: beta 35 | num_inference_steps: 10 36 | final_action_clip_value: 1.0 # data normalized in [-1,1] 37 | use_torch_compile: True 38 | use_bf16: False 39 | 40 | cond_steps: 1 41 | horizon_steps: 4 42 | act_steps: 2 43 | action_dim: 7 # EEF_POS 44 | proprio_dim: 8 # POS_QUAT 45 | 46 | mixture: 47 | vlm: # gemma 48 | hidden_size: 2048 49 | intermediate_size: 16384 50 | use_final_norm: False 51 | cache: True 52 | use_quantize: ${quantize} 53 | use_lora: ${lora} 54 | adaptive_mode: # not applicable for gemma 55 | rope_theta: 10000.0 # 10000 in gemma 56 | proprio: 57 | hidden_size: 1024 58 | intermediate_size: 4096 59 | use_final_norm: True # technically no, but sharing mixture with action 60 | cache: True 61 | use_quantize: False 62 | use_lora: False 63 | adaptive_mode: ${action_expert_adaptive_mode} 64 | rope_theta: ${action_expert_rope_theta} 65 | action: 66 | hidden_size: 1024 67 | intermediate_size: 4096 68 | use_final_norm: True 69 | cache: False 70 | use_quantize: False 71 | use_lora: False 72 | adaptive_mode: ${action_expert_adaptive_mode} 73 | rope_theta: ${action_expert_rope_theta} 74 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 75 | time_hidden_size: 256 # only applicable if using adaptive 76 | time_max_period: 10000.0 # provided ckpts used 10000.0 for both time_max_period and action_expert_rope_theta 77 | action_expert_rope_theta: 10000.0 78 | quantize: False 79 | lora: False 80 | lora_r: 32 81 | lora_dropout: 0.0 82 | max_image_text_tokens: ${env.adapter.max_seq_len} 83 | 84 | # Fixed 85 | image_token_index: 257152 86 | vocab_size: 257216 87 | pad_token_id: 0 88 | 89 | vision: 90 | _target_: src.model.paligemma.siglip.SiglipVisionModel 91 | config: 92 | hidden_size: 1152 # siglip 93 | intermediate_size: 4304 94 | num_hidden_layers: 27 95 | num_attention_heads: 16 96 | num_channels: 3 97 | image_size: 224 98 | patch_size: 14 99 | layer_norm_eps: 1e-6 100 | attention_dropout: 0.0 101 | num_image_tokens: 256 102 | 103 | vision_projector: 104 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 105 | 
config: 106 | vision_config: 107 | hidden_size: 1152 108 | projection_dim: 2048 109 | 110 | joint: 111 | _target_: src.model.vla.joint_model.JointModel 112 | config: 113 | action_expert_adaptive_mode: ${action_expert_adaptive_mode} 114 | time_hidden_size: ${time_hidden_size} 115 | mixture: ${mixture} 116 | lora: 117 | r: ${lora_r} 118 | dropout: ${lora_dropout} 119 | # 120 | num_hidden_layers: 18 121 | num_attention_heads: 8 122 | num_key_value_heads: 1 123 | head_dim: 256 124 | rms_norm_eps: 1e-6 125 | attention_bias: False 126 | attention_dropout: 0.0 127 | pad_token_id: ${pad_token_id} 128 | -------------------------------------------------------------------------------- /config/eval/fractal_drawer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.eval.EvalAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/eval_fractal/${name}_ta${act_steps}_${seed}/${env.task}_${now:%H-%M-%S} 9 | name: 10 | device: cuda 11 | seed: 42 12 | checkpoint_path: 13 | n_eval_episode: ${eval:'3 * 4 * 9 * 10'} # 3 drawers, 4 urdfs, 9 locations/rgb_overlay_paths, 10 trials each 14 | n_video: ${n_eval_episode} 15 | # From Simpler: The robot is positioned in front of a cabinet that contains 3 drawers and instructed to open / close a specific drawer, testing its ability to manipulate articulated objects. We place the robot at 9 grid positions within a rectangle on the floor, yielding a total of 9 x 3 x 2 = 54 trials. 16 | # The 9 locations are fixed w.r.t. rgb_overlay_path! https://github.com/simpler-env/ManiSkill2_real2sim/blob/87dc84508520310e61c972ece399a0f034095e42/mani_skill2_real2sim/envs/custom_scenes/open_drawer_in_scene.py#L210 17 | # sweeps: 18 | # urdf_version: 19 | # - null 20 | # - "recolor_tabletop_visual_matching_1" 21 | # - "recolor_tabletop_visual_matching_2" 22 | # - "recolor_cabinet_visual_matching_1" 23 | # rgb_overlay_path: 24 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a0.png 25 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a1.png 26 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a2.png 27 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b0.png 28 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b1.png 29 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b2.png 30 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c0.png 31 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c1.png 32 | # - ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c2.png 33 | 34 | env: 35 | task: 36 | adapter: 37 | _target_: src.agent.env_adapter.simpler.EDRSimplerAdapter 38 | dataset_statistics_path: config/fractal_statistics.json 39 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 40 | tokenizer_padding: max_length 41 | max_seq_len: 276 # fixed 256 for image + max 20 for text 42 | num_image_tokens: 256 43 | image_size: [224, 224] 44 | 45 | flow_sampling: beta 46 | num_inference_steps: 10 47 | final_action_clip_value: 1.0 # data normalized in [-1,1] 48 | use_torch_compile: True 49 | use_bf16: False 50 | 51 | cond_steps: 1 52 | horizon_steps: 4 53 | act_steps: 2 54 | action_dim: 7 # EEF_POS 55 | proprio_dim: 8 # POS_QUAT 56 | 57 | mixture: 58 | vlm: # gemma 59 | hidden_size: 2048 60 | intermediate_size: 16384 61 | use_final_norm: False 62 | cache: True 63 | use_quantize: ${quantize} 64 | use_lora: ${lora} 65 | adaptive_mode: # not applicable for gemma 66 | rope_theta: 
10000.0 # 10000 in gemma 67 | proprio: 68 | hidden_size: 1024 69 | intermediate_size: 4096 70 | use_final_norm: True # technically no, but sharing mixture with action 71 | cache: True 72 | use_quantize: False 73 | use_lora: False 74 | adaptive_mode: ${action_expert_adaptive_mode} 75 | rope_theta: ${action_expert_rope_theta} 76 | action: 77 | hidden_size: 1024 78 | intermediate_size: 4096 79 | use_final_norm: True 80 | cache: False 81 | use_quantize: False 82 | use_lora: False 83 | adaptive_mode: ${action_expert_adaptive_mode} 84 | rope_theta: ${action_expert_rope_theta} 85 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 86 | time_hidden_size: 256 # only applicable if using adaptive 87 | time_max_period: 10000.0 # provided ckpts used 10000.0 for both time_max_period and action_expert_rope_theta 88 | action_expert_rope_theta: 10000.0 89 | quantize: False 90 | lora: False 91 | lora_r: 32 92 | lora_dropout: 0.0 93 | max_image_text_tokens: ${env.adapter.max_seq_len} 94 | 95 | # Fixed 96 | image_token_index: 257152 97 | vocab_size: 257216 98 | pad_token_id: 0 99 | 100 | vision: 101 | _target_: src.model.paligemma.siglip.SiglipVisionModel 102 | config: 103 | hidden_size: 1152 # siglip 104 | intermediate_size: 4304 105 | num_hidden_layers: 27 106 | num_attention_heads: 16 107 | num_channels: 3 108 | image_size: 224 109 | patch_size: 14 110 | layer_norm_eps: 1e-6 111 | attention_dropout: 0.0 112 | num_image_tokens: 256 113 | 114 | vision_projector: 115 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 116 | config: 117 | vision_config: 118 | hidden_size: 1152 119 | projection_dim: 2048 120 | 121 | joint: 122 | _target_: src.model.vla.joint_model.JointModel 123 | config: 124 | action_expert_adaptive_mode: ${action_expert_adaptive_mode} 125 | time_hidden_size: ${time_hidden_size} 126 | mixture: ${mixture} 127 | lora: 128 | r: ${lora_r} 129 | dropout: ${lora_dropout} 130 | # 131 | num_hidden_layers: 18 132 | num_attention_heads: 8 133 | num_key_value_heads: 1 134 | head_dim: 256 135 | rms_norm_eps: 1e-6 136 | attention_bias: False 137 | attention_dropout: 0.0 138 | pad_token_id: ${pad_token_id} 139 | -------------------------------------------------------------------------------- /config/eval/fractal_move.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.eval.EvalAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/eval_fractal/${name}_ta${act_steps}_${seed}/${env.task}_${now:%H-%M-%S} 9 | name: 10 | device: cuda 11 | seed: 42 12 | checkpoint_path: 13 | n_eval_episode: ${eval:'60 * 4 * 10'} # 60 locations, 4 urdfs, 10 trials each 14 | n_video: ${n_eval_episode} 15 | # From Simpler: In each trial, one object serves as the source object, one serves as the target, and the other serves as the distractor (this creates 6 trials for each triplet and each triangle pattern). We randomly choose 5 triplets of objects among a total of 8 objects (blue plastic bottle, pepsi can, orange, 7up can, apple, sponge, coke can, redbull can), and adopt 2 triangle patterns (upright and inverted). 
This creates a total of 5 x 2 x 6 = 60 trials 16 | # sweeps: 17 | # urdf_version: 18 | # - null 19 | # - "recolor_tabletop_visual_matching_1" 20 | # - "recolor_tabletop_visual_matching_2" 21 | # - "recolor_cabinet_visual_matching_1" 22 | 23 | env: 24 | task: 25 | adapter: 26 | _target_: src.agent.env_adapter.simpler.EDRSimplerAdapter 27 | dataset_statistics_path: config/fractal_statistics.json 28 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 29 | tokenizer_padding: max_length 30 | max_seq_len: 276 # fixed 256 for image + max 20 for text 31 | num_image_tokens: 256 32 | image_size: [224, 224] 33 | 34 | flow_sampling: beta 35 | num_inference_steps: 10 36 | final_action_clip_value: 1.0 # data normalized in [-1,1] 37 | use_torch_compile: True 38 | use_bf16: False 39 | 40 | cond_steps: 1 41 | horizon_steps: 4 42 | act_steps: 2 43 | action_dim: 7 # EEF_POS 44 | proprio_dim: 8 # POS_QUAT 45 | 46 | mixture: 47 | vlm: # gemma 48 | hidden_size: 2048 49 | intermediate_size: 16384 50 | use_final_norm: False 51 | cache: True 52 | use_quantize: ${quantize} 53 | use_lora: ${lora} 54 | adaptive_mode: # not applicable for gemma 55 | rope_theta: 10000.0 # 10000 in gemma 56 | proprio: 57 | hidden_size: 1024 58 | intermediate_size: 4096 59 | use_final_norm: True # technically no, but sharing mixture with action 60 | cache: True 61 | use_quantize: False 62 | use_lora: False 63 | adaptive_mode: ${action_expert_adaptive_mode} 64 | rope_theta: ${action_expert_rope_theta} 65 | action: 66 | hidden_size: 1024 67 | intermediate_size: 4096 68 | use_final_norm: True 69 | cache: False 70 | use_quantize: False 71 | use_lora: False 72 | adaptive_mode: ${action_expert_adaptive_mode} 73 | rope_theta: ${action_expert_rope_theta} 74 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 75 | time_hidden_size: 256 # only applicable if using adaptive 76 | time_max_period: 10000.0 # provided ckpts used 10000.0 for both time_max_period and action_expert_rope_theta 77 | action_expert_rope_theta: 10000.0 78 | quantize: False 79 | lora: False 80 | lora_r: 32 81 | lora_dropout: 0.0 82 | max_image_text_tokens: ${env.adapter.max_seq_len} 83 | 84 | # Fixed 85 | image_token_index: 257152 86 | vocab_size: 257216 87 | pad_token_id: 0 88 | 89 | vision: 90 | _target_: src.model.paligemma.siglip.SiglipVisionModel 91 | config: 92 | hidden_size: 1152 # siglip 93 | intermediate_size: 4304 94 | num_hidden_layers: 27 95 | num_attention_heads: 16 96 | num_channels: 3 97 | image_size: 224 98 | patch_size: 14 99 | layer_norm_eps: 1e-6 100 | attention_dropout: 0.0 101 | num_image_tokens: 256 102 | 103 | vision_projector: 104 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 105 | config: 106 | vision_config: 107 | hidden_size: 1152 108 | projection_dim: 2048 109 | 110 | joint: 111 | _target_: src.model.vla.joint_model.JointModel 112 | config: 113 | action_expert_adaptive_mode: ${action_expert_adaptive_mode} 114 | time_hidden_size: ${time_hidden_size} 115 | mixture: ${mixture} 116 | lora: 117 | r: ${lora_r} 118 | dropout: ${lora_dropout} 119 | # 120 | num_hidden_layers: 18 121 | num_attention_heads: 8 122 | num_key_value_heads: 1 123 | head_dim: 256 124 | rms_norm_eps: 1e-6 125 | attention_bias: False 126 | attention_dropout: 0.0 127 | pad_token_id: ${pad_token_id} 128 | -------------------------------------------------------------------------------- /config/fractal_statistics.json: -------------------------------------------------------------------------------- 1 | {"action": {"mean": 
[0.006987567059695721, 0.006265868898481131, -0.012625112198293209, 0.04333272576332092, -0.005756245460361242, 0.0009130232501775026, 0.5354204773902893], "std": [0.06921152025461197, 0.05971040576696396, 0.07353048771619797, 0.156105175614357, 0.1316440999507904, 0.14593836665153503, 0.4971115291118622], "max": [2.9984593391418457, 22.09052848815918, 2.7507524490356445, 1.570636510848999, 1.5321086645126343, 1.5691522359848022, 1.0], "min": [-2.0204520225524902, -5.497899532318115, -2.031663417816162, -1.569917917251587, -1.569892168045044, -1.570419430732727, 0.0], "p99": [0.17824687153100965, 0.14938379630446405, 0.21842354819178575, 0.5892666035890578, 0.35272657424211445, 0.44796681255102094, 1.0], "p01": [-0.22453527510166169, -0.14820013284683228, -0.231589707583189, -0.3517994859814644, -0.4193011274933815, -0.43643461108207704, 0.0]}, "num_transitions": 3786400, "num_trajectories": 87212, "proprio": {"mean": [0.5599020719528198, -0.08333852887153625, 0.7770926356315613, -0.24803675711154938, 0.49517107009887695, 0.09266142547130585, 0.20975486934185028, 0.42613455653190613], "std": [0.12432780861854553, 0.11558882147073746, 0.24595776200294495, 0.5126982927322388, 0.5218101143836975, 0.16630391776561737, 0.2754841148853302, 0.45544859766960144], "max": [1.0534898042678833, 0.48018959164619446, 1.6896663904190063, 0.9999993443489075, 0.9999874830245972, 0.9554369449615479, 0.9914546012878418, 1.0], "min": [-0.4436439275741577, -0.9970501065254211, -0.006579156965017319, -0.8643477559089661, -0.7079970240592957, -0.7688722014427185, -0.4999994933605194, 0.0], "p99": [0.8750156319141384, 0.21247054174542404, 1.0727112340927123, 0.9377871316671368, 0.9563051050901409, 0.45990042358636823, 0.7216041100025177, 1.0], "p01": [0.32481380939483645, -0.28334290891885755, 0.14107070609927178, -0.686474204659462, -0.6808923494815826, -0.36045596331357954, -0.454380963742733, 0.0]}} 2 | -------------------------------------------------------------------------------- /config/train/bridge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.train.TrainAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/train/${name}/${now:%Y-%m-%d}_${now:%H-%M}_${seed} 9 | name: ${data.train.dataset_mix}_${data.train.split}_tp${horizon_steps}_${flow_sampling} 10 | device: cuda 11 | n_nodes: 1 12 | seed: 42 13 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 14 | load_pretrained_weights: True 15 | resume_checkpoint_path: 16 | train_vlm: True 17 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 18 | use_torch_compile: True 19 | use_bf16: True 20 | use_amp: True 21 | quantize: False 22 | lora: False 23 | lora_r: 32 24 | lora_dropout: 0.0 25 | debug: False 26 | 27 | use_ema: False 28 | ema_decay: 0.99 29 | ema_start: ${save_model_start} 30 | ema_freq: 1 31 | ema_device: cuda 32 | use_swa: False 33 | swa_start: ${eval:'1550000 // ${global_batch_size} * 5'} 34 | swa_freq: ${eval:'1550000 // ${global_batch_size} // 4'} 35 | swa_device: cpu 36 | 37 | data: 38 | val: # stil aplying image randomization 39 | split: # full val split 40 | shuffle_buffer_size: 10000 41 | train: 42 | dataset_mix: bridge 43 | split: train 44 | data_path: ${oc.env:VLA_DATA_DIR}/resize_224 45 | window_size: ${cond_steps} 46 | action_horizon: ${horizon_steps} 47 | skip_unlabeled: True 48 | load_proprio: True 49 | shuffle_buffer_size: 200000 50 | num_parallel_calls: 100 51 | traj_transform_threads: 
10 52 | traj_read_threads: 10 53 | 54 | wandb: 55 | entity: ${oc.env:VLA_WANDB_ENTITY} 56 | project: open-pi-zero 57 | run: ${now:%H-%M-%S}_${name} 58 | 59 | log_freq: 16 60 | n_epochs: 15 # provided ckpts were about 12 epochs 61 | n_updates: ${eval:'1550000 // ${global_batch_size} * ${n_epochs}'} # bridge dataset has 2195527 transitions in total, but we are skipping (text-)unlabeled episodes, which is roughly 30% of the data 62 | save_model_freq: ${eval:'1550000 // ${global_batch_size} * 1'} 63 | save_model_start: ${eval:'1550000 // ${global_batch_size} * 5'} 64 | eval_freq: 2000 65 | eval_size: 1024 66 | eval_thresholds: [0.05, 0.1, 0.2, 0.3, 0.5] 67 | 68 | global_batch_size: 1024 69 | per_device_batch_size: 16 70 | action_lr: 5e-5 71 | vlm_lr: 5e-5 72 | action_lr_scheduler: 73 | first_cycle_steps: 10000000 # basically no decaying 74 | min_lr: 1e-8 75 | warmup_steps: 200 # a bit of warmup 76 | vlm_lr_scheduler: 77 | first_cycle_steps: 10000000 78 | min_lr: 1e-8 79 | warmup_steps: 200 80 | action_weight_decay: 0 81 | vlm_weight_decay: 0 82 | max_grad_norm: 1.0 83 | 84 | flow_sampling: beta 85 | num_inference_steps: 10 86 | final_action_clip_value: 1.0 # data normalized in [-1,1] 87 | 88 | cond_steps: 1 89 | horizon_steps: 4 90 | action_dim: 7 # EEF_POS 91 | proprio_dim: 7 # POS_EULER 92 | max_seq_len: 276 # fixed 256 for image + max 20 for text 93 | tokenizer_padding: max_length # instead of truncating to longest 94 | max_image_text_tokens: ${max_seq_len} 95 | 96 | mixture: 97 | vlm: # gemma 98 | hidden_size: 2048 99 | intermediate_size: 16384 100 | use_final_norm: False 101 | cache: True 102 | use_quantize: ${quantize} 103 | use_lora: ${lora} 104 | adaptive_mode: # not applicable for gemma 105 | rope_theta: 10000.0 # 10000 in gemma 106 | proprio: 107 | hidden_size: 1024 108 | intermediate_size: 4096 109 | use_final_norm: True # technically no, but sharing weights with action anyway 110 | cache: True 111 | use_quantize: False 112 | use_lora: False 113 | adaptive_mode: ${action_expert_adaptive_mode} 114 | rope_theta: ${action_expert_rope_theta} 115 | action: 116 | hidden_size: 1024 117 | intermediate_size: 4096 118 | use_final_norm: True 119 | cache: False 120 | use_quantize: False 121 | use_lora: False 122 | adaptive_mode: ${action_expert_adaptive_mode} 123 | rope_theta: ${action_expert_rope_theta} 124 | time_hidden_size: 256 # only applicable if using adaptive 125 | time_max_period: 100.0 126 | action_expert_rope_theta: 100.0 # since action/proprio seq_len is pretty small 127 | 128 | # Fixed 129 | image_token_index: 257152 130 | vocab_size: 257216 131 | pad_token_id: 0 132 | 133 | vision: 134 | _target_: src.model.paligemma.siglip.SiglipVisionModel 135 | config: 136 | hidden_size: 1152 # siglip 137 | intermediate_size: 4304 138 | num_hidden_layers: 27 139 | num_attention_heads: 16 140 | num_channels: 3 141 | image_size: 224 142 | patch_size: 14 143 | layer_norm_eps: 1e-6 144 | attention_dropout: 0.0 145 | num_image_tokens: 256 146 | lora: 147 | r: ${lora_r} 148 | dropout: ${lora_dropout} 149 | use_quantize: ${quantize} 150 | use_lora: ${lora} 151 | 152 | vision_projector: 153 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 154 | config: 155 | vision_config: 156 | hidden_size: 1152 157 | projection_dim: 2048 158 | lora: 159 | r: ${lora_r} 160 | dropout: ${lora_dropout} 161 | use_quantize: ${quantize} 162 | use_lora: ${lora} 163 | 164 | joint: 165 | _target_: src.model.vla.joint_model.JointModel 166 | config: 167 | action_expert_adaptive_mode: 
${action_expert_adaptive_mode} 168 | time_hidden_size: ${time_hidden_size} 169 | mixture: ${mixture} 170 | lora: 171 | r: ${lora_r} 172 | dropout: ${lora_dropout} 173 | # 174 | num_hidden_layers: 18 175 | num_attention_heads: 8 176 | num_key_value_heads: 1 177 | head_dim: 256 178 | rms_norm_eps: 1e-6 179 | attention_bias: False 180 | attention_dropout: 0.0 181 | pad_token_id: ${pad_token_id} 182 | -------------------------------------------------------------------------------- /config/train/fractal.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | hydra: 4 | run: 5 | dir: ${log_dir} 6 | _target_: src.agent.train.TrainAgent 7 | 8 | log_dir: ${oc.env:VLA_LOG_DIR}/train/${name}/${now:%Y-%m-%d}_${now:%H-%M}_${seed} 9 | name: ${data.train.dataset_mix}_${data.train.split}_tp${horizon_steps}_${flow_sampling} 10 | device: cuda 11 | n_nodes: 1 12 | seed: 42 13 | pretrained_model_path: ${oc.env:TRANSFORMERS_CACHE}/paligemma-3b-pt-224 14 | load_pretrained_weights: True 15 | resume_checkpoint_path: 16 | train_vlm: True 17 | action_expert_adaptive_mode: # adaLN, adaLN-Zero, or None 18 | use_torch_compile: True 19 | use_bf16: True 20 | use_amp: True 21 | quantize: False 22 | lora: False 23 | lora_r: 32 24 | lora_dropout: 0.0 25 | debug: False 26 | 27 | use_ema: False 28 | ema_decay: 0.99 29 | ema_start: ${save_model_start} 30 | ema_freq: 1 31 | ema_device: cuda 32 | use_swa: False 33 | swa_start: ${eval:'3786400 // ${global_batch_size} * 5'} 34 | swa_freq: ${eval:'3786400 // ${global_batch_size} // 4'} 35 | swa_device: cpu 36 | 37 | data: 38 | val: # stil aplying image randomization 39 | split: train[95%:] 40 | shuffle_buffer_size: 10000 41 | train: 42 | dataset_mix: fractal 43 | split: train[:95%] 44 | data_path: ${oc.env:VLA_DATA_DIR}/resize_224 45 | window_size: ${cond_steps} 46 | action_horizon: ${horizon_steps} 47 | skip_unlabeled: True 48 | load_proprio: True 49 | shuffle_buffer_size: 200000 50 | num_parallel_calls: 100 51 | traj_transform_threads: 10 52 | traj_read_threads: 10 53 | 54 | wandb: 55 | entity: ${oc.env:VLA_WANDB_ENTITY} 56 | project: open-pi-zero 57 | run: ${now:%H-%M-%S}_${name} 58 | 59 | log_freq: 16 60 | n_epochs: 10 # provided ckpts were about 8 epochs 61 | n_updates: ${eval:'3786400 // ${global_batch_size} * ${n_epochs}'} # fractal dataset has 3786400 transitions in total, all used 62 | save_model_freq: ${eval:'3786400 // ${global_batch_size} * 1'} 63 | save_model_start: ${eval:'3786400 // ${global_batch_size} * 3'} 64 | eval_freq: 2000 65 | eval_size: 1024 66 | eval_thresholds: [0.05, 0.1, 0.2, 0.3, 0.5] 67 | 68 | global_batch_size: 1024 69 | per_device_batch_size: 16 70 | action_lr: 5e-5 71 | vlm_lr: 5e-5 72 | action_lr_scheduler: 73 | first_cycle_steps: 10000000 # basically no decaying 74 | min_lr: 1e-8 75 | warmup_steps: 200 # a bit of warmup 76 | vlm_lr_scheduler: 77 | first_cycle_steps: 10000000 78 | min_lr: 1e-8 79 | warmup_steps: 200 80 | action_weight_decay: 0 81 | vlm_weight_decay: 0 82 | max_grad_norm: 1.0 83 | 84 | flow_sampling: beta 85 | num_inference_steps: 10 86 | final_action_clip_value: 1.0 # data normalized in [-1,1] 87 | 88 | cond_steps: 1 89 | horizon_steps: 4 90 | action_dim: 7 # EEF_POS 91 | proprio_dim: 8 # POS_QUAT 92 | max_seq_len: 276 # fixed 256 for image + max 20 for text 93 | tokenizer_padding: max_length # instead of truncating to longest 94 | max_image_text_tokens: ${max_seq_len} 95 | 96 | mixture: 97 | vlm: # gemma 98 | hidden_size: 2048 99 | intermediate_size: 16384 100 | 
use_final_norm: False 101 | cache: True 102 | use_quantize: ${quantize} 103 | use_lora: ${lora} 104 | adaptive_mode: # not applicable for gemma 105 | rope_theta: 10000.0 # 10000 in gemma 106 | proprio: 107 | hidden_size: 1024 108 | intermediate_size: 4096 109 | use_final_norm: True # technically no, but sharing weights with action anyway 110 | cache: True 111 | use_quantize: False 112 | use_lora: False 113 | adaptive_mode: ${action_expert_adaptive_mode} 114 | rope_theta: ${action_expert_rope_theta} 115 | action: 116 | hidden_size: 1024 117 | intermediate_size: 4096 118 | use_final_norm: True 119 | cache: False 120 | use_quantize: False 121 | use_lora: False 122 | adaptive_mode: ${action_expert_adaptive_mode} 123 | rope_theta: ${action_expert_rope_theta} 124 | time_hidden_size: 256 # only applicable if using adaptive 125 | time_max_period: 100.0 126 | action_expert_rope_theta: 100.0 # since action/proprio seq_len is pretty small 127 | 128 | # Fixed 129 | image_token_index: 257152 130 | vocab_size: 257216 131 | pad_token_id: 0 132 | 133 | vision: 134 | _target_: src.model.paligemma.siglip.SiglipVisionModel 135 | config: 136 | hidden_size: 1152 # siglip 137 | intermediate_size: 4304 138 | num_hidden_layers: 27 139 | num_attention_heads: 16 140 | num_channels: 3 141 | image_size: 224 142 | patch_size: 14 143 | layer_norm_eps: 1e-6 144 | attention_dropout: 0.0 145 | num_image_tokens: 256 146 | lora: 147 | r: ${lora_r} 148 | dropout: ${lora_dropout} 149 | use_quantize: ${quantize} 150 | use_lora: ${lora} 151 | 152 | vision_projector: 153 | _target_: src.model.paligemma.siglip.PaliGemmaMultiModalProjector 154 | config: 155 | vision_config: 156 | hidden_size: 1152 157 | projection_dim: 2048 158 | lora: 159 | r: ${lora_r} 160 | dropout: ${lora_dropout} 161 | use_quantize: ${quantize} 162 | use_lora: ${lora} 163 | 164 | joint: 165 | _target_: src.model.vla.joint_model.JointModel 166 | config: 167 | action_expert_adaptive_mode: ${action_expert_adaptive_mode} 168 | time_hidden_size: ${time_hidden_size} 169 | mixture: ${mixture} 170 | lora: 171 | r: ${lora_r} 172 | dropout: ${lora_dropout} 173 | # 174 | num_hidden_layers: 18 175 | num_attention_heads: 8 176 | num_key_value_heads: 1 177 | head_dim: 256 178 | rms_norm_eps: 1e-6 179 | attention_bias: False 180 | attention_dropout: 0.0 181 | pad_token_id: ${pad_token_id} 182 | -------------------------------------------------------------------------------- /doc/convention.md: -------------------------------------------------------------------------------- 1 | ## Conventions in datasets / Simpler 2 | 3 | ### EE proprio 4 | 5 | Fractal data has an xyzw quaternion in proprio (upon inspection), while I have been using wxyz in Simpler since it follows the transforms3d library. Bridge uses sxyz Euler angles. The EE pose saved in bridge data is relative to a top-down pose (instead of the base pose). Both datasets use +x for forward, +y for left, and +z for upward. 6 | 7 | ### Gripper proprio and action 8 | 9 | In Octo, bridge data has 1 for gripper state open and -1 for closed after normalization (continuous), and 1 for gripper action open and 0 for closing (without normalization, binarized). Fractal data has -1 for gripper state open and 1 for closed after normalization (continuous), and also 1 for gripper action open and 0 for closing (without normalization, binarized). 10 | 11 | I added gripper width (1 for open and 0 for closed) to the environment observation in Simpler. Then for the action in Simpler, the WidowX robot (bridge) has 1 for opening the gripper and -1 for closing.
The Google robot has 1 for closing the gripper and -1 for opening. 12 | -------------------------------------------------------------------------------- /doc/error.md: -------------------------------------------------------------------------------- 1 | ## Misc errors 2 | 3 | ### Data processing error 4 | 5 | Comment out Lines 299-306 in `.venv/lib/python3.10/site-packages/tensorflow_datasets/core/dataset_builder.py` to avoid the `AttributeError: 'MultiplexedPath' object has no attribute 'parts'` error (this seems to be an issue with Python 3.10; using `tensorflow_datasets==4.9.2` fixes it, but then disabling GCS somehow no longer works). 6 | 7 | ### Quantization error in training 8 | 9 | If using quantization in training, you might need to modify Line 474 in `.venv/lib64/python3.10/site-packages/bitsandbytes/autograd/_functions.py` from `return output` to `return output.clone()` ([related issue](https://github.com/bitsandbytes-foundation/bitsandbytes/issues/736)). 10 | 11 | ### torch.compile + quantization 12 | 13 | Quantization currently does not work well with torch.compile when running eval. 14 | -------------------------------------------------------------------------------- /doc/notes.md: -------------------------------------------------------------------------------- 1 | ## Observations from training 2 | 3 | Tried Gaussian Fourier features for the proprio/action input, but they did not help. 4 | 5 | adaLN(-Zero) for time conditioning seemed to speed up training a bit initially, but did not make a significant difference after a while. 6 | 7 | Overall, Beta sampling of the flow-matching timesteps achieved better validation loss, but I usually saw Uniform sampling matching or even outperforming it in validation accuracy at a low threshold (e.g., predicted actions within 0.05 of the normalized ground truth in all dimensions). At a higher threshold like 0.3 or 0.5, Beta seemed to consistently outperform Uniform. 8 | 9 | I was able to train with a learning rate as high as 3e-4 at batch size 1024, presumably thanks to the stability of the flow-matching / diffusion objective. 10 | 11 | I tried training with batch sizes from 256 to 2048, and the training curves (wall-clock time vs. training loss) were similar. 12 | 13 | I switched from the unit-Gaussian normalization used in Octo to [-1, 1] normalization because I found the bridge dataset has some weird, very large action values (e.g., 80). Without clipping after normalizing with unit std, these cause large spikes in the training loss. 14 | 15 | Not using the pre-trained PaliGemma weights trained much worse. Training only the action expert (freezing PaliGemma) did not work at all.
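For reference, a minimal sketch of the kind of [-1, 1] normalization with clipping described above, assuming the per-dimension p01/p99 statistics stored in `config/bridge_statistics.json` serve as the normalization bounds (the helper name is illustrative; the actual implementation lives under `src/data/` and may differ in details):

```python
import json

import numpy as np


def normalize_to_unit_range(
    actions: np.ndarray, stats_path: str = "config/bridge_statistics.json"
) -> np.ndarray:
    """Map raw actions to [-1, 1] using 1st/99th-percentile bounds, then clip outliers."""
    with open(stats_path) as f:
        stats = json.load(f)["action"]
    low = np.asarray(stats["p01"])   # per-dimension 1st percentile
    high = np.asarray(stats["p99"])  # per-dimension 99th percentile
    # scale so that [p01, p99] maps to [-1, 1]
    scaled = 2.0 * (actions - low) / (high - low + 1e-8) - 1.0
    # clip the rare extreme raw values (e.g., the ~80s mentioned above) instead of letting them spike the loss
    return np.clip(scaled, -1.0, 1.0)
```

Clipping at 1.0 is consistent with the `final_action_clip_value: 1.0 # data normalized in [-1,1]` comment in the eval configs.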
16 | -------------------------------------------------------------------------------- /media/maniskill_pp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenzren/open-pi-zero/c3df7fb062175c16f69d7ca4ce042958ea238fb7/media/maniskill_pp.png -------------------------------------------------------------------------------- /media/open-pi-zero-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenzren/open-pi-zero/c3df7fb062175c16f69d7ca4ce042958ea238fb7/media/open-pi-zero-overview.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "open-pi-zero" 3 | version = "0.1.1" 4 | description = "Re-implementation of Pi0 vision-language-action (VLA) model from Physical Intelligence" 5 | authors = [ 6 | {name = "Allen Z. Ren", email = "allenzren1@gmail.com"}, 7 | ] 8 | readme = "README.md" 9 | requires-python = "==3.10.*" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3", 12 | ] 13 | dependencies = [ 14 | "bitsandbytes==0.45.0", 15 | "einops", 16 | "gsutil>=5.32", 17 | "hydra-core", 18 | "imageio", 19 | "matplotlib", 20 | "numpy==1.26.4", 21 | "omegaconf", 22 | "pillow", 23 | "pre-commit>=4.0.1", 24 | "pretty_errors", 25 | "protobuf==3.20.3", 26 | "tensorflow==2.15.0", 27 | "tensorflow_datasets==4.9.2", 28 | "torch==2.5.0", 29 | "transformers==4.47.1", 30 | "tqdm", 31 | "wandb", 32 | ] 33 | 34 | [build-system] 35 | requires = ["setuptools>=61.0"] 36 | build-backend = "setuptools.build_meta" 37 | 38 | [tool.setuptools.packages.find] 39 | exclude = [] 40 | 41 | [tool.ruff] 42 | line-length = 88 43 | target-version = "py310" 44 | extend-exclude = ["src/data/obs_transforms.py", "src/data/utils/data_utils.py"] 45 | 46 | [tool.ruff.lint] 47 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 48 | ignore = ["E203", "E501", "B006", "B026", "B905"] 49 | 50 | [tool.ruff.lint.per-file-ignores] 51 | "__init__.py" = ["E402", "F401", "F403"] 52 | -------------------------------------------------------------------------------- /scripts/data/check_bridge.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tensorflow as tf 4 | import tqdm 5 | from PIL import Image 6 | from torch.utils.data import DataLoader 7 | 8 | from src.data.dataset import make_interleaved_dataset 9 | from src.data.dataset_torch import TorchRLDSDataset 10 | from src.data.oxe import make_oxe_dataset_kwargs_and_weights 11 | 12 | tf.config.set_visible_devices([], "GPU") 13 | 14 | 15 | if __name__ == "__main__": 16 | import argparse 17 | import os 18 | 19 | import einops 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--data_path", type=str, default=f"{os.environ['VLA_DATA_DIR']}/resize_224" 24 | ) 25 | parser.add_argument("--mix", type=str, default="bridge") 26 | parser.add_argument("--camera_views", nargs="*", default=("primary",)) 27 | parser.add_argument( 28 | "--skip_norm", action="store_true", help="Use raw actions and proprio" 29 | ) 30 | args = parser.parse_args() 31 | 32 | # config 33 | start_time = time.time() 34 | dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights( 35 | args.mix, 36 | args.data_path, 37 | load_depth=False, 38 | load_language=True, 39 | load_proprio=True, 40 | load_camera_views=args.camera_views, 41 
| skip_norm=args.skip_norm, 42 | ) 43 | 44 | # dataset --- bridge has 2195527 transitions and 60064 trajectories 45 | dataset = make_interleaved_dataset( 46 | dataset_kwargs_list, 47 | sample_weights, 48 | train=True, 49 | split="train", # bridge has a separate validation split 50 | shuffle_buffer_size=10000, # change to 500k for training, large shuffle buffers are important, but adjust to your RAM 51 | batch_size=None, # batching will be handles in PyTorch Dataloader object 52 | balance_weights=True, 53 | traj_transform_kwargs=dict( # no neeed for goal relabeling 54 | window_size=2, 55 | action_horizon=4, 56 | subsample_length=100, 57 | skip_unlabeled=False, # skip ones without language annotation 58 | # max_action_from_stats=True, 59 | # max_proprio_from_stats=True, 60 | ), 61 | frame_transform_kwargs=dict( 62 | image_augment_kwargs={ 63 | "primary": dict( 64 | random_resized_crop=dict( 65 | scale=[0.8, 1.0], 66 | ratio=[0.9, 1.1], 67 | ), 68 | random_brightness=[0.1], 69 | random_contrast=[0.9, 1.1], 70 | random_saturation=[0.9, 1.1], 71 | random_hue=[0.05], 72 | augment_order=[ 73 | "random_resized_crop", 74 | "random_brightness", 75 | "random_contrast", 76 | "random_saturation", 77 | "random_hue", 78 | ], 79 | ), 80 | "wrist": dict( 81 | random_brightness=[0.1], 82 | random_contrast=[0.9, 1.1], 83 | random_saturation=[0.9, 1.1], 84 | random_hue=[0.05], 85 | augment_order=[ 86 | "random_brightness", 87 | "random_contrast", 88 | "random_saturation", 89 | "random_hue", 90 | ], 91 | ), 92 | }, 93 | resize_size=dict( 94 | primary=(224, 224), 95 | wrist=(224, 224), 96 | ), 97 | num_parallel_calls=200, 98 | ), 99 | traj_transform_threads=48, 100 | traj_read_threads=48, 101 | ) 102 | 103 | # convert for torch 104 | pytorch_dataset = TorchRLDSDataset(dataset) 105 | print("Dataset length (traj):", len(pytorch_dataset)) 106 | dataloader = DataLoader( 107 | pytorch_dataset, 108 | batch_size=64, 109 | num_workers=0, # important to keep this to 0 so PyTorch does not mess with the parallelism 110 | ) 111 | prep_time = time.time() 112 | print(f"Preparation time: {prep_time - start_time:.2f}s") 113 | 114 | print("Starting dataloader") 115 | cnt_batch = 0 116 | for _, _sample in tqdm.tqdm(enumerate(dataloader)): 117 | # _sample: dict with keys 'observation', 'task', 'action', 'dataset_name', 'action_pad_mask' 118 | # observation: 'image_primary' (torch.Size([16, 2, 256, 256, 3]), 'image_wrist', 'timestep' (torch.Size([16, 2])), 'pad_mask_dict', 'timestep_pad_mask', 'task_completed' (torch.Size([16, 2, 4]), 'proprio' (fractal: torch.Size([16, 2, 8])) 119 | # task: 'language_instruction', 'pad_mask_dict', 'image_primary', 'image_wrist', 'timestep' (torch.Size([16])) 120 | # action (torch.Size([16, 2, 4, 7]) 121 | # dataset_name 122 | # action_pad_mask (torch.Size([16, 2, 4, 7])) 123 | 124 | # timestep_pad_mask: which observations at the beginning of the trajectory are padding --- repeat the first observation at the beginning of the trajectory rather than going out of bounds 125 | # action_pad_mask: mark actions past the goal timestep as padding --- repeat the last action at the end of the trajectory rather than going out of bounds 126 | # task_completed should correspond to action_pad_mask 127 | # timestep should correspond to timestep_pad_mask (e.g., timestep [0, 0] for a datapoint indicates padding the first observation) 128 | images = _sample["observation"]["image_primary"] 129 | images = einops.rearrange( 130 | images, "B T H W C -> B (T C) H W" 131 | ) # remove cond_steps dimension 132 | texts = [ 
133 | text.decode("utf-8") for text in _sample["task"]["language_instruction"] 134 | ] 135 | actions = _sample["action"] 136 | proprios = _sample["observation"]["proprio"] 137 | num_unlabled_texts = len([t for t in texts if t == ""]) 138 | print(num_unlabled_texts) 139 | 140 | # save an image 141 | img = Image.fromarray(images[0, :3].numpy().transpose(1, 2, 0)) 142 | img.save("temp/bridge_sample_img_first.png") 143 | img = Image.fromarray(images[0, -3:].numpy().transpose(1, 2, 0)) 144 | img.save("temp/bridge_sample_img_last.png") 145 | print(texts[0]) 146 | print(actions[0, -1]) 147 | breakpoint() 148 | 149 | # check padding 150 | if not _sample["observation"]["timestep_pad_mask"].all(): 151 | print("Padding for history obs past trajectory start") 152 | if not _sample["action_pad_mask"].all(): 153 | print("Padding for action chunks past trajectory end") 154 | 155 | # verify the normalization 156 | if not args.skip_norm and (actions.abs().max() > 1 or proprios.abs().max() > 1): 157 | breakpoint() 158 | cnt_batch += 1 159 | load_time = time.time() 160 | print(f"Iterative over {cnt_batch} batches: {load_time - prep_time:.2f}s") 161 | -------------------------------------------------------------------------------- /scripts/data/check_fractal.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tensorflow as tf 4 | import torch 5 | import tqdm 6 | from PIL import Image 7 | from torch.utils.data import DataLoader 8 | 9 | from src.data.dataset import make_interleaved_dataset 10 | from src.data.dataset_torch import TorchRLDSDataset 11 | from src.data.oxe import make_oxe_dataset_kwargs_and_weights 12 | from src.utils.geometry import quat2euler 13 | 14 | tf.config.set_visible_devices([], "GPU") 15 | 16 | 17 | if __name__ == "__main__": 18 | import argparse 19 | import os 20 | 21 | import einops 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "--data_path", type=str, default=f"{os.environ['VLA_DATA_DIR']}/resize_224" 26 | ) 27 | parser.add_argument("--mix", type=str, default="fractal") 28 | parser.add_argument("--camera_views", nargs="*", default=("primary",)) 29 | parser.add_argument( 30 | "--skip_norm", action="store_true", help="Use raw actions and proprio" 31 | ) 32 | args = parser.parse_args() 33 | 34 | # config 35 | start_time = time.time() 36 | dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights( 37 | args.mix, 38 | args.data_path, 39 | load_depth=False, 40 | load_language=True, 41 | load_proprio=True, 42 | load_camera_views=args.camera_views, 43 | skip_norm=args.skip_norm, 44 | ) 45 | 46 | # dataset - fractal has 82851 trajectories and 3786400 transitions 47 | dataset = make_interleaved_dataset( 48 | dataset_kwargs_list, 49 | sample_weights, 50 | train=True, 51 | split="train[:95%]", # fractal does not have validation split 52 | shuffle_buffer_size=10000, # change to 500k for training, large shuffle buffers are important, but adjust to your RAM 53 | batch_size=None, # batching will be handles in PyTorch Dataloader object 54 | balance_weights=True, 55 | traj_transform_kwargs=dict( # no neeed for goal relabeling 56 | window_size=2, 57 | action_horizon=4, 58 | subsample_length=100, 59 | skip_unlabeled=False, # skip ones without language annotation 60 | # max_action_from_stats=True, 61 | # max_proprio_from_stats=True, 62 | ), 63 | frame_transform_kwargs=dict( 64 | image_augment_kwargs={ 65 | "primary": dict( 66 | random_resized_crop=dict( 67 | scale=[0.8, 1.0], 68 | ratio=[0.9, 1.1], 
69 | ), 70 | random_brightness=[0.1], 71 | random_contrast=[0.9, 1.1], 72 | random_saturation=[0.9, 1.1], 73 | random_hue=[0.05], 74 | augment_order=[ 75 | "random_resized_crop", 76 | "random_brightness", 77 | "random_contrast", 78 | "random_saturation", 79 | "random_hue", 80 | ], 81 | ), 82 | "wrist": dict( 83 | random_brightness=[0.1], 84 | random_contrast=[0.9, 1.1], 85 | random_saturation=[0.9, 1.1], 86 | random_hue=[0.05], 87 | augment_order=[ 88 | "random_brightness", 89 | "random_contrast", 90 | "random_saturation", 91 | "random_hue", 92 | ], 93 | ), 94 | }, 95 | resize_size=dict( 96 | primary=(224, 224), 97 | wrist=(224, 224), 98 | ), 99 | num_parallel_calls=200, 100 | ), 101 | traj_transform_threads=48, 102 | traj_read_threads=48, 103 | ) 104 | 105 | # convert for torch 106 | pytorch_dataset = TorchRLDSDataset(dataset) 107 | print("Dataset length (traj):", len(pytorch_dataset)) 108 | dataloader = DataLoader( 109 | pytorch_dataset, 110 | batch_size=16, 111 | num_workers=0, # important to keep this to 0 so PyTorch does not mess with the parallelism 112 | ) 113 | prep_time = time.time() 114 | print(f"Preparation time: {prep_time - start_time:.2f}s") 115 | 116 | print("Starting dataloader") 117 | cnt_batch = 0 118 | for _, _sample in tqdm.tqdm(enumerate(dataloader)): 119 | # _sample: dict with keys 'observation', 'task', 'action', 'dataset_name', 'action_pad_mask' 120 | # observation: 'image_primary' (torch.Size([16, 2, 256, 256, 3]), 'image_wrist', 'timestep' (torch.Size([16, 2])), 'pad_mask_dict', 'timestep_pad_mask', 'task_completed' (torch.Size([16, 2, 4]), 'proprio' (fractal: torch.Size([16, 2, 8])) 121 | # task: 'language_instruction', 'pad_mask_dict', 'image_primary', 'image_wrist', 'timestep' (torch.Size([16])) 122 | # action (torch.Size([16, 2, 4, 7]) 123 | # dataset_name 124 | # action_pad_mask (torch.Size([16, 2, 4, 7])) 125 | 126 | # timestep_pad_mask: which observations at the beginning of the trajectory are padding --- repeat the first observation at the beginning of the trajectory rather than going out of bounds 127 | # action_pad_mask: mark actions past the goal timestep as padding --- repeat the last action at the end of the trajectory rather than going out of bounds 128 | # task_completed should correspond to action_pad_mask 129 | # timestep should correspond to timestep_pad_mask (e.g., timestep [0, 0] for a datapoint indicates padding the first observation) 130 | images = _sample["observation"]["image_primary"] 131 | images = einops.rearrange( 132 | images, "B T H W C -> B (T C) H W" 133 | ) # remove cond_steps dimension 134 | texts = [ 135 | text.decode("utf-8") for text in _sample["task"]["language_instruction"] 136 | ] 137 | actions = _sample["action"] 138 | proprios = _sample["observation"]["proprio"] # pos, quat, gripper; 139 | sample_quat = torch.cat( 140 | (proprios[0, -1, -2:-1], proprios[0, -1, -5:-2]) 141 | ) # quat [x, y, z, w] to [w, x, y, z] 142 | sample_rpy = quat2euler(sample_quat) 143 | num_unlabled_texts = len([t for t in texts if t == ""]) 144 | print(num_unlabled_texts) 145 | 146 | # quat is in [x, y, z, w], and relative to robot base (unlike bridge that is relative to a top-down rotation). 
z is pointing forward/downward from the fingers, green is pointing left to the finger (sideway), and red is pointing away from the palm (pointing behind) 147 | 148 | # save an image 149 | img = Image.fromarray(images[0, :3].numpy().transpose(1, 2, 0)) 150 | img.save("temp/fractal_sample_img_first.png") 151 | img = Image.fromarray(images[0, -3:].numpy().transpose(1, 2, 0)) 152 | img.save("temp/fractal_sample_img_last.png") 153 | print(texts[0]) 154 | print(actions[0, -1].numpy()) 155 | print("w x y z", sample_quat.numpy()) 156 | breakpoint() 157 | 158 | # check padding 159 | if not _sample["observation"]["timestep_pad_mask"].all(): 160 | print("Padding for history obs past trajectory start") 161 | if not _sample["action_pad_mask"].all(): 162 | print("Padding for action chunks past trajectory end") 163 | 164 | # verify the normalization 165 | if not args.skip_norm and (actions.abs().max() > 1 or proprios.abs().max() > 1): 166 | breakpoint() 167 | cnt_batch += 1 168 | load_time = time.time() 169 | print(f"Iterative over {cnt_batch} batches: {load_time - prep_time:.2f}s") 170 | -------------------------------------------------------------------------------- /scripts/data/modify_rlds_dataset.py: -------------------------------------------------------------------------------- 1 | """Modifies TFDS dataset with a map function, updates the feature definition and stores new dataset.""" 2 | 3 | import os 4 | from functools import partial 5 | 6 | import tensorflow_datasets as tfds 7 | 8 | from src.data.oxe.preprocess.mod_functions import TFDS_MOD_FUNCTIONS 9 | from src.data.oxe.preprocess.multithreaded_adhoc_tfds_builder import ( 10 | MultiThreadedAdhocDatasetBuilder, 11 | ) 12 | 13 | # avoid GCS nonsense errors 14 | tfds.core.utils.gcs_utils._is_gcs_disabled = True 15 | os.environ["NO_GCE_CHECK"] = "true" 16 | 17 | 18 | def mod_features(mods, features): 19 | """Modifies feature dict.""" 20 | for mod in mods: 21 | features = TFDS_MOD_FUNCTIONS[mod].mod_features(features) 22 | return features 23 | 24 | 25 | def mod_dataset_generator(builder, split, mods): 26 | """Modifies dataset features.""" 27 | ds = builder.as_dataset(split=split) 28 | for mod in mods: 29 | ds = TFDS_MOD_FUNCTIONS[mod].mod_dataset(ds) 30 | for episode in tfds.core.dataset_utils.as_numpy(ds): 31 | yield episode 32 | 33 | 34 | def main(args): 35 | builder = tfds.builder(args.dataset, data_dir=args.data_dir) 36 | 37 | features = mod_features(args.mods, builder.info.features) 38 | print("############# Target features: ###############") 39 | print(features) 40 | print("##############################################") 41 | assert args.data_dir != args.target_dir # prevent overwriting original dataset 42 | 43 | mod_dataset_builder = MultiThreadedAdhocDatasetBuilder( 44 | name=args.dataset, 45 | version=builder.version, 46 | features=features, 47 | split_datasets={ 48 | split: builder.info.splits[split] for split in builder.info.splits 49 | }, 50 | config=builder.builder_config, 51 | data_dir=args.target_dir, 52 | description=builder.info.description, 53 | generator_fcn=partial(mod_dataset_generator, builder=builder, mods=args.mods), 54 | n_workers=args.n_workers, 55 | max_episodes_in_memory=args.max_episodes_in_memory, 56 | ) 57 | mod_dataset_builder.download_and_prepare() 58 | 59 | 60 | if __name__ == "__main__": 61 | import argparse 62 | 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--dataset", type=str, required=True) 65 | parser.add_argument("--data_dir", type=str, required=True) 66 | 
parser.add_argument("--target_dir", type=str, required=True) 67 | parser.add_argument( 68 | "--mods", type=str, nargs="+", default=["resize_and_jpeg_encode"] 69 | ) 70 | parser.add_argument("--n_workers", type=int, default=10) 71 | parser.add_argument("--max_episodes_in_memory", type=int, default=100) 72 | args = parser.parse_args() 73 | 74 | main(args) 75 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Launcher for all experiments. 3 | 4 | """ 5 | 6 | import logging 7 | import math 8 | import os 9 | import random 10 | import sys 11 | 12 | import hydra 13 | import numpy as np 14 | import pretty_errors 15 | import torch 16 | from omegaconf import OmegaConf, open_dict 17 | 18 | # dummy 19 | print(pretty_errors.__version__) 20 | 21 | # allows arbitrary python code execution in configs using the ${eval:''} resolver 22 | OmegaConf.register_new_resolver("eval", eval, replace=True) 23 | OmegaConf.register_new_resolver("round_up", math.ceil) 24 | OmegaConf.register_new_resolver("round_down", math.floor) 25 | 26 | # add logger 27 | log = logging.getLogger(__name__) 28 | 29 | # use line-buffering for both stdout and stderr 30 | sys.stdout = open(sys.stdout.fileno(), mode="w", buffering=1) 31 | sys.stderr = open(sys.stderr.fileno(), mode="w", buffering=1) 32 | 33 | 34 | def _main(cfg: OmegaConf): 35 | # resolve immediately so all the ${now:} resolvers will use the same time. 36 | OmegaConf.resolve(cfg) 37 | 38 | # figure out the current gpu 39 | multi_gpu = torch.cuda.device_count() > 1 or cfg.get("n_nodes", 1) > 1 40 | if multi_gpu: 41 | from torch.distributed import destroy_process_group, init_process_group 42 | 43 | def ddp_setup(): 44 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 45 | init_process_group(backend="nccl") 46 | 47 | ddp_setup() 48 | gpu_id = int(os.environ["LOCAL_RANK"]) 49 | else: 50 | gpu_id = 0 51 | with open_dict(cfg): 52 | cfg.gpu_id = gpu_id 53 | cfg.multi_gpu = multi_gpu 54 | 55 | # seeding 56 | seed = cfg.get("seed", 42) 57 | random.seed(seed) 58 | np.random.seed(seed) 59 | torch.manual_seed(seed) 60 | 61 | # run agent 62 | cls = hydra.utils.get_class(cfg._target_) 63 | agent = cls(cfg) 64 | agent.run() 65 | 66 | if multi_gpu: 67 | destroy_process_group() 68 | 69 | 70 | @hydra.main( 71 | version_base=None, 72 | config_path=os.path.join(os.getcwd(), "config/train"), 73 | config_name="bridge.yaml", 74 | ) # defaults 75 | def main(cfg: OmegaConf): 76 | _main(cfg) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /scripts/set_path.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ##################### Paths ##################### 4 | 5 | # Set default paths 6 | DEFAULT_DATA_DIR="${PWD}/data" 7 | DEFAULT_LOG_DIR="${PWD}/log" 8 | 9 | # Prompt the user for input, allowing overrides 10 | read -p "Enter the desired data directory [default: ${DEFAULT_DATA_DIR}], leave empty to use default: " DATA_DIR 11 | VLA_DATA_DIR=${DATA_DIR:-$DEFAULT_DATA_DIR} # Use user input or default if input is empty 12 | 13 | read -p "Enter the desired logging directory [default: ${DEFAULT_LOG_DIR}], leave empty to use default: " LOG_DIR 14 | VLA_LOG_DIR=${LOG_DIR:-$DEFAULT_LOG_DIR} # Use user input or default if input is empty 15 | 16 | # Export to current session 17 | export VLA_DATA_DIR="$VLA_DATA_DIR" 18 | 
export VLA_LOG_DIR="$VLA_LOG_DIR" 19 | 20 | # Confirm the paths with the user 21 | echo "Data directory set to: $VLA_DATA_DIR" 22 | echo "Log directory set to: $VLA_LOG_DIR" 23 | 24 | # Append environment variables to .bashrc 25 | echo "export VLA_DATA_DIR=\"$VLA_DATA_DIR\"" >> ~/.bashrc 26 | echo "export VLA_LOG_DIR=\"$VLA_LOG_DIR\"" >> ~/.bashrc 27 | 28 | echo "Environment variables VLA_DATA_DIR and VLA_LOG_DIR added to .bashrc and applied to the current session." 29 | 30 | ##################### WandB ##################### 31 | 32 | # Prompt the user for input, allowing overrides 33 | read -p "Enter your WandB entity (username or team name), leave empty to skip: " ENTITY 34 | 35 | # Check if ENTITY is not empty 36 | if [ -n "$ENTITY" ]; then 37 | # If ENTITY is not empty, set the environment variable 38 | export VLA_WANDB_ENTITY="$ENTITY" 39 | 40 | # Confirm the entity with the user 41 | echo "WandB entity set to: $VLA_WANDB_ENTITY" 42 | 43 | # Append environment variable to .bashrc 44 | echo "export VLA_WANDB_ENTITY=\"$ENTITY\"" >> ~/.bashrc 45 | 46 | echo "Environment variable VLA_WANDB_ENTITY added to .bashrc and applied to the current session." 47 | else 48 | # If ENTITY is empty, skip setting the environment variable 49 | echo "No WandB entity provided. Please set wandb=null when running scripts to disable wandb logging and avoid error." 50 | fi 51 | 52 | ##################### HF ##################### 53 | 54 | echo "Please also set TRANSFORMERS_CACHE (Huggingface cache) and download PaliGemma weights there." 55 | -------------------------------------------------------------------------------- /scripts/tests/oxe.py: -------------------------------------------------------------------------------- 1 | """ " 2 | Visualize and download OXE dataset 3 | 4 | Reference: https://colab.research.google.com/github/google-deepmind/open_x_embodiment/blob/main/colabs/Open_X_Embodiment_Datasets.ipynb 5 | 6 | """ 7 | 8 | import os 9 | 10 | import tensorflow_datasets as tfds 11 | from PIL import Image 12 | 13 | DATASETS = [ 14 | "fractal20220817_data", 15 | # "kuka", 16 | "bridge", 17 | # "taco_play", 18 | # "jaco_play", 19 | # "berkeley_cable_routing", 20 | # "roboturk", 21 | # "nyu_door_opening_surprising_effectiveness", 22 | # "viola", 23 | # "berkeley_autolab_ur5", 24 | # "toto", 25 | # "language_table", 26 | # "columbia_cairlab_pusht_real", 27 | # "stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 28 | # "nyu_rot_dataset_converted_externally_to_rlds", 29 | # "stanford_hydra_dataset_converted_externally_to_rlds", 30 | # "austin_buds_dataset_converted_externally_to_rlds", 31 | # "nyu_franka_play_dataset_converted_externally_to_rlds", 32 | # "maniskill_dataset_converted_externally_to_rlds", 33 | # "cmu_franka_exploration_dataset_converted_externally_to_rlds", 34 | # "ucsd_kitchen_dataset_converted_externally_to_rlds", 35 | # "ucsd_pick_and_place_dataset_converted_externally_to_rlds", 36 | # "austin_sailor_dataset_converted_externally_to_rlds", 37 | # "austin_sirius_dataset_converted_externally_to_rlds", 38 | "bc_z", 39 | # "usc_cloth_sim_converted_externally_to_rlds", 40 | # "utokyo_pr2_opening_fridge_converted_externally_to_rlds", 41 | # "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 42 | # "utokyo_saytap_converted_externally_to_rlds", 43 | # "utokyo_xarm_pick_and_place_converted_externally_to_rlds", 44 | # "utokyo_xarm_bimanual_converted_externally_to_rlds", 45 | # "robo_net", 46 | # "berkeley_mvp_converted_externally_to_rlds", 47 | # 
"berkeley_rpt_converted_externally_to_rlds", 48 | # "kaist_nonprehensile_converted_externally_to_rlds", 49 | # "stanford_mask_vit_converted_externally_to_rlds", 50 | # "tokyo_u_lsmo_converted_externally_to_rlds", 51 | # "dlr_sara_pour_converted_externally_to_rlds", 52 | # "dlr_sara_grid_clamp_converted_externally_to_rlds", 53 | # "dlr_edan_shared_control_converted_externally_to_rlds", 54 | # "asu_table_top_converted_externally_to_rlds", 55 | # "stanford_robocook_converted_externally_to_rlds", 56 | # "eth_agent_affordances", 57 | # "imperialcollege_sawyer_wrist_cam", 58 | # "iamlab_cmu_pickup_insert_converted_externally_to_rlds", 59 | # "uiuc_d3field", 60 | # "utaustin_mutex", 61 | # "berkeley_fanuc_manipulation", 62 | # "cmu_play_fusion", 63 | # "cmu_stretch", 64 | # "berkeley_gnm_recon", 65 | # "berkeley_gnm_cory_hall", 66 | # "berkeley_gnm_sac_son", 67 | ] 68 | 69 | 70 | def dataset2path(dataset_name): 71 | if dataset_name == "robo_net": 72 | version = "1.0.0" 73 | elif dataset_name == "language_table": 74 | version = "0.0.1" 75 | else: 76 | version = "0.1.0" 77 | return f"{os.environ['VLA_DATA_DIR']}/{dataset_name}/{version}" 78 | 79 | 80 | def as_gif(images, path="temp.gif"): 81 | # Render the images as the gif: 82 | images[0].save(path, save_all=True, append_images=images[1:], duration=1000, loop=0) 83 | gif_bytes = open(path, "rb").read() 84 | return gif_bytes 85 | 86 | 87 | def visualize_image( 88 | dataset="bridge_dataset", 89 | display_key="image_0", 90 | data_dir=f"{os.environ['VLA_DATA_DIR']}/resize_224", 91 | ): 92 | ds, ds_info = tfds.load( 93 | name=dataset, 94 | data_dir=data_dir, 95 | download=False, 96 | split="train[1000:2000]", 97 | shuffle_files=True, 98 | with_info=True, 99 | ) 100 | if display_key not in ds_info.features["steps"]["observation"]: 101 | raise ValueError( 102 | f"The key {display_key} was not found in this dataset.\n" 103 | + "Please choose a different image key to display for this dataset.\n" 104 | + "Here is the observation spec:\n" 105 | + str(ds_info.features["steps"]["observation"]) 106 | ) 107 | 108 | # save ds_info in text 109 | with open(f"temp/{dataset}_info.txt", "w") as f: 110 | f.write(str(ds_info)) 111 | 112 | # inspect data 113 | iterator = iter(ds) 114 | while 1: 115 | episode = next(iterator) 116 | 117 | instructions = [ 118 | step["language_instruction"].numpy().decode("utf-8") 119 | for step in episode["steps"] 120 | ] 121 | images = [step["observation"][display_key] for step in episode["steps"]] 122 | images = [Image.fromarray(image.numpy()) for image in images] 123 | proprios = [step["observation"]["state"] for step in episode["steps"]] 124 | if "carrot" in instructions[0].lower(): 125 | print(instructions) 126 | print(proprios) 127 | 128 | # print image info and save one 129 | print(f"Image shape: {images[0].size}") 130 | images[0].save(f"temp/{dataset}_sample_img.png") 131 | breakpoint() 132 | 133 | # check instructions 134 | print(instructions[0]) 135 | 136 | 137 | if __name__ == "__main__": 138 | import argparse 139 | import os 140 | 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument("--dataset", type=str, default="bridge_dataset") 143 | parser.add_argument("--display_key", type=str, default="image_0") 144 | parser.add_argument( 145 | "--data_dir", type=str, default=f"{os.environ['VLA_DATA_DIR']}/resize_224" 146 | ) 147 | args = parser.parse_args() 148 | 149 | visualize_image(**vars(args)) 150 | -------------------------------------------------------------------------------- /scripts/tests/run_paligemma.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | from PIL import Image 6 | 7 | from src.model.kv_cache import KVCache 8 | from src.model.paligemma.load import load_hf_model 9 | from src.model.paligemma.processing import PaliGemmaProcessor 10 | from src.utils.monitor import log_allocated_gpu_memory 11 | 12 | 13 | def move_inputs_to_device(model_inputs: dict, device: str): 14 | model_inputs = {k: v.to(device) for k, v in model_inputs.items()} 15 | return model_inputs 16 | 17 | 18 | def get_model_inputs( 19 | processor, 20 | prompt: str, 21 | image_file_path: str, 22 | device: str, 23 | ): 24 | image = Image.open(image_file_path).convert("RGB") 25 | images = [image] 26 | prompts = [prompt] 27 | model_inputs = processor(text=prompts, images=images) 28 | model_inputs = move_inputs_to_device(model_inputs, device) 29 | return model_inputs 30 | 31 | 32 | def test_inference( 33 | model, 34 | processor, 35 | device: str, 36 | prompt: str, 37 | image_file_path: str, 38 | max_tokens_to_generate: int, 39 | temperature: float, 40 | top_p: float, 41 | do_sample: bool, 42 | ): 43 | model_inputs = get_model_inputs(processor, prompt, image_file_path, device) 44 | input_ids = model_inputs["input_ids"] 45 | attention_mask = model_inputs["attention_mask"] 46 | pixel_values = model_inputs["pixel_values"] 47 | 48 | kv_cache = KVCache() 49 | 50 | # Generate tokens until you see the stop token 51 | stop_token = processor.tokenizer.eos_token_id 52 | generated_tokens = [] 53 | 54 | for _ in range(max_tokens_to_generate): 55 | # Get the model outputs 56 | outputs = model( 57 | input_ids=input_ids, 58 | pixel_values=pixel_values, 59 | attention_mask=attention_mask, 60 | kv_cache=kv_cache, 61 | ) 62 | kv_cache = outputs["kv_cache"] 63 | next_token_logits = outputs["logits"][:, -1, :] 64 | # Sample the next token 65 | if do_sample: 66 | # Apply temperature 67 | next_token_logits = torch.softmax(next_token_logits / temperature, dim=-1) 68 | next_token = _sample_top_p(next_token_logits, top_p) 69 | else: 70 | next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) 71 | assert next_token.size() == (1, 1) 72 | next_token = next_token.squeeze(0) # Remove batch dimension 73 | generated_tokens.append(next_token) 74 | # Stop if the stop token has been generated 75 | if next_token.item() == stop_token: 76 | break 77 | # Append the next token to the input --- use cache so only the new token 78 | input_ids = next_token.unsqueeze(-1) 79 | attention_mask = torch.cat( 80 | [attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1 81 | ) 82 | 83 | generated_tokens = torch.cat(generated_tokens, dim=-1) 84 | # Decode the generated tokens 85 | decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) 86 | 87 | print(prompt + decoded) 88 | 89 | 90 | def _sample_top_p(probs: torch.Tensor, p: float): 91 | # (B, vocab_size) 92 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 93 | # (B, vocab_size) 94 | probs_sum = torch.cumsum(probs_sort, dim=-1) 95 | # (B, vocab_size) 96 | # (Substracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking) 97 | mask = probs_sum - probs_sort > p 98 | # Zero out all the probabilities of tokens that are not selected by the Top P 99 | probs_sort[mask] = 0.0 100 | # Redistribute the probabilities so that they sum up to 1. 
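    # Illustrative example with hypothetical numbers: for sorted probs
    # [0.5, 0.3, 0.15, 0.05] and p = 0.7, the cumulative sums are
    # [0.5, 0.8, 0.95, 1.0]; subtracting probs_sort gives [0.0, 0.5, 0.8, 0.95],
    # so the mask zeroes out 0.15 and 0.05, and the division below renormalizes
    # the survivors to [0.625, 0.375, 0.0, 0.0] before multinomial sampling.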
101 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 102 | # Sample a token (its index) from the top p distribution 103 | next_token = torch.multinomial(probs_sort, num_samples=1) 104 | # Get the token position in the vocabulary corresponding to the sampled index 105 | next_token = torch.gather(probs_idx, -1, next_token) 106 | return next_token 107 | 108 | 109 | def main( 110 | prompt, 111 | image_file_path, 112 | max_tokens_to_generate: int = 100, 113 | temperature: float = 0.8, 114 | top_p: float = 0.9, 115 | do_sample: bool = False, 116 | only_cpu: bool = False, 117 | quantize: bool = False, 118 | ): 119 | device = "cpu" 120 | model_path = f"{os.environ['TRANSFORMERS_CACHE']}/paligemma-3b-pt-224" 121 | 122 | if not only_cpu: 123 | if torch.cuda.is_available(): 124 | device = "cuda" 125 | elif torch.backends.mps.is_available(): 126 | device = "mps" 127 | 128 | print("Device in use: ", device) 129 | 130 | print("Loading model") 131 | time_start_load = time.time() 132 | model, tokenizer = load_hf_model(model_path, device, quantize=quantize) 133 | model = model.to(device).eval() 134 | # cast 135 | model = model.to(torch.bfloat16) 136 | time_end_load = time.time() 137 | print(f"Model loaded in {time_end_load - time_start_load:.2f} seconds") 138 | log_allocated_gpu_memory(stage="loading model") 139 | print(f"lm head dtype: {model.language_model.lm_head.weight.dtype}") 140 | 141 | num_image_tokens = model.config.vision_config.num_image_tokens 142 | image_size = model.config.vision_config.image_size 143 | processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size) 144 | 145 | print("Running inference") 146 | time_start_inference = time.time() 147 | with torch.inference_mode(): 148 | test_inference( 149 | model, 150 | processor, 151 | device, 152 | prompt, 153 | image_file_path, 154 | max_tokens_to_generate, 155 | temperature, 156 | top_p, 157 | do_sample, 158 | ) 159 | print("Inference time: ", time.time() - time_start_inference) 160 | 161 | 162 | if __name__ == "__main__": 163 | import argparse 164 | 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument("--prompt", type=str) 167 | parser.add_argument("--image_file_path", type=str) 168 | parser.add_argument("--max_tokens_to_generate", type=int, default=100) 169 | parser.add_argument("--temperature", type=float, default=0.8) 170 | parser.add_argument("--top_p", type=float, default=0.9) 171 | parser.add_argument("--do_sample", action="store_true") 172 | parser.add_argument("--only_cpu", action="store_true") 173 | parser.add_argument("--quantize", action="store_true") 174 | args = parser.parse_args() 175 | 176 | main(**vars(args)) 177 | -------------------------------------------------------------------------------- /scripts/tests/sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate Gamma-distributed samples without using scipy, proposed by pi0 paper 3 | 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | # Parameters 10 | s = 0.999 11 | alpha = 1.5 12 | beta = 1.0 13 | n_sample = int(1e5) 14 | 15 | # use gamma: https://math.stackexchange.com/questions/190670/how-exactly-are-the-beta-and-gamma-distributions-related 16 | gamma_alpha_dist = torch.distributions.Gamma(alpha, 1) 17 | gamma_beta_dist = torch.distributions.Gamma(beta, 1) 18 | 19 | x = gamma_alpha_dist.sample((n_sample,)) 20 | y = gamma_beta_dist.sample((n_sample,)) 21 | z = x / (x + y) 22 | scaled_samples = s * (1 - z) 23 | 24 | print("Min:", scaled_samples.min()) 25 | 
print("Max:", scaled_samples.max()) 26 | 27 | plt.figure() 28 | plt.hist(scaled_samples, bins=30, density=True, alpha=0.7, color="blue") 29 | plt.xlabel("τ") 30 | plt.ylabel("Density") 31 | plt.title(f"Samples from Beta((s-τ)/s; {alpha}, {beta}) with s={s}") 32 | plt.grid(True) 33 | plt.savefig("beta_samples_from_gamma.png") 34 | 35 | # use beta directly 36 | beta_dist = torch.distributions.Beta(alpha, beta) 37 | samples = beta_dist.sample((n_sample,)) 38 | scaled_samples = s * (1 - samples) 39 | print("Min:", scaled_samples.min()) 40 | print("Max:", scaled_samples.max()) 41 | 42 | plt.figure() 43 | plt.hist(scaled_samples, bins=30, density=True, alpha=0.7, color="blue") 44 | plt.xlabel("τ") 45 | plt.ylabel("Density") 46 | plt.title(f"Samples from Beta((s-τ)/s; {alpha}, {beta}) with s={s}") 47 | plt.grid(True) 48 | plt.savefig("beta_samples.png") 49 | -------------------------------------------------------------------------------- /scripts/tests/simpler.py: -------------------------------------------------------------------------------- 1 | import simpler_env 2 | from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict 3 | 4 | 5 | def main( 6 | task: str = "google_robot_pick_coke_can", 7 | ): 8 | env = simpler_env.make("google_robot_pick_coke_can") 9 | obs, reset_info = env.reset() 10 | instruction = env.get_language_instruction() 11 | print("Reset info", reset_info) 12 | print("Instruction", instruction) 13 | image = get_image_from_maniskill2_obs_dict(env, obs) 14 | print("Image shape and dtype", image.shape, image.dtype) 15 | 16 | done, truncated = False, False 17 | step = 0 18 | while not (done or truncated): 19 | # action[:3]: delta xyz; action[3:6]: delta rotation in axis-angle representation; 20 | # action[6:7]: gripper (the meaning of open / close depends on robot URDF) 21 | image = get_image_from_maniskill2_obs_dict(env, obs) 22 | action = env.action_space.sample() # replace this with your policy inference 23 | if step == 0: 24 | print("Action shape and dtype", action.shape, action.dtype) 25 | print("Action", action) 26 | obs, reward, done, truncated, info = env.step( 27 | action 28 | ) # for long horizon tasks, you can call env.advance_to_next_subtask() to advance to the next subtask; the environment might also autoadvance if env._elapsed_steps is larger than a threshold 29 | new_instruction = env.get_language_instruction() 30 | if new_instruction != instruction: 31 | # for long horizon tasks, we get a new instruction when robot proceeds to the next subtask 32 | instruction = new_instruction 33 | print("New Instruction", instruction) 34 | 35 | print("Step", step, "Reward", reward, "Done", done, "Truncated", truncated) 36 | step += 1 37 | 38 | episode_stats = info.get("episode_stats", {}) 39 | print("Episode stats", episode_stats) 40 | 41 | 42 | if __name__ == "__main__": 43 | import argparse 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--task", type=str, default="google_robot_pick_coke_can") 47 | args = parser.parse_args() 48 | 49 | print("=====================================") 50 | print("Available environments:") 51 | print(simpler_env.ENVIRONMENTS) 52 | print("=====================================") 53 | print("Running task", args.task) 54 | # [ 55 | # "google_robot_pick_coke_can", 56 | # "google_robot_pick_horizontal_coke_can", 57 | # "google_robot_pick_vertical_coke_can", 58 | # "google_robot_pick_standing_coke_can", 59 | # "google_robot_pick_object", 60 | # "google_robot_move_near_v0", 61 | # 
"google_robot_move_near_v1", 62 | # "google_robot_move_near", 63 | # "google_robot_open_drawer", 64 | # "google_robot_open_top_drawer", 65 | # "google_robot_open_middle_drawer", 66 | # "google_robot_open_bottom_drawer", 67 | # "google_robot_close_drawer", 68 | # "google_robot_close_top_drawer", 69 | # "google_robot_close_middle_drawer", 70 | # "google_robot_close_bottom_drawer", 71 | # "google_robot_place_in_closed_drawer", 72 | # "google_robot_place_in_closed_top_drawer", 73 | # "google_robot_place_in_closed_middle_drawer", 74 | # "google_robot_place_in_closed_bottom_drawer", 75 | # "google_robot_place_apple_in_closed_top_drawer", 76 | # "widowx_spoon_on_towel", 77 | # "widowx_carrot_on_plate", 78 | # "widowx_stack_cube", 79 | # "widowx_put_eggplant_in_basket", 80 | # ] 81 | main(args.task) 82 | -------------------------------------------------------------------------------- /scripts/try_checkpoint_in_simpler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import hydra 6 | import imageio 7 | import numpy as np 8 | import simpler_env 9 | import torch 10 | from omegaconf import OmegaConf 11 | 12 | from src.model.vla.pizero import PiZeroInference 13 | from src.utils.monitor import log_allocated_gpu_memory, log_execution_time 14 | 15 | 16 | @log_execution_time() 17 | def load_checkpoint(model, path): 18 | """load to cpu first, then move to gpu""" 19 | data = torch.load(path, weights_only=True, map_location="cpu") 20 | # remove "_orig_mod." prefix if saved model was compiled 21 | data["model"] = {k.replace("_orig_mod.", ""): v for k, v in data["model"].items()} 22 | model.load_state_dict(data["model"], strict=True) 23 | print(f"Loaded model from {path}") 24 | 25 | 26 | def main(args): 27 | # seeding 28 | random.seed(args.seed) 29 | np.random.seed(args.seed) 30 | torch.manual_seed(args.seed) 31 | 32 | # devices 33 | device = torch.device(f"cuda:{args.gpu_id}") 34 | 35 | # load default configs 36 | if "fractal" in args.checkpoint_path: 37 | cfg = OmegaConf.load( 38 | "config/eval/fractal_apple.yaml" 39 | ) # doesn't matter which task 40 | if "bridge" in args.checkpoint_path: 41 | cfg = OmegaConf.load("config/eval/bridge.yaml") 42 | 43 | # model 44 | dtype = torch.bfloat16 if args.use_bf16 else torch.float32 45 | model = PiZeroInference(cfg, use_ddp=False) 46 | load_checkpoint(model, args.checkpoint_path) 47 | model.freeze_all_weights() 48 | model.to(dtype) 49 | model.to(device) 50 | if ( 51 | args.use_torch_compile 52 | ): # model being compiled in the first batch which takes some time 53 | model = torch.compile( 54 | model, 55 | mode="default", # "reduce-overhead; max-autotune(-no-cudagraphs) 56 | # backend="inductor", # default: inductor; cudagraphs 57 | ) 58 | # modes: https://pytorch.org/docs/main/generated/torch.compile.html 59 | # backends: https://pytorch.org/docs/stable/torch.compiler.html 60 | model.eval() 61 | print(f"Using cuda device: {device} dtype: {dtype}") 62 | log_allocated_gpu_memory(None, "loading model", args.gpu_id) 63 | 64 | # simpler env 65 | env = simpler_env.make(args.task) 66 | 67 | # env specifics 68 | env_adapter = hydra.utils.instantiate(cfg.env.adapter) 69 | env_adapter.reset() 70 | 71 | # run an episode 72 | episode_id = random.randint(0, 20) 73 | env_reset_options = {} 74 | env_reset_options["obj_init_options"] = { 75 | "episode_id": episode_id, # this determines the obj inits in bridge 76 | } 77 | obs, reset_info = env.reset(options=env_reset_options) 78 | instruction = 
env.get_language_instruction() 79 | if args.recording: 80 | os.environ["TOKENIZERS_PARALLELISM"] = ( 81 | "false" # avoid tokenizer forking warning about deadlock 82 | ) 83 | video_writer = imageio.get_writer(f"try_{args.task}_{episode_id}.mp4") 84 | print( 85 | f"Reset info: {reset_info} Instruction: {instruction} Max episode length: {env.spec.max_episode_steps}" 86 | ) 87 | cnt_step = 0 88 | inference_times = [] 89 | while 1: 90 | # infer action chunk 91 | inputs = env_adapter.preprocess(env, obs, instruction) 92 | causal_mask, vlm_position_ids, proprio_position_ids, action_position_ids = ( 93 | model.build_causal_mask_and_position_ids( 94 | inputs["attention_mask"], dtype=dtype 95 | ) 96 | ) 97 | image_text_proprio_mask, action_mask = model.split_full_mask_into_submasks( 98 | causal_mask 99 | ) 100 | inputs = { 101 | "input_ids": inputs["input_ids"], 102 | "pixel_values": inputs["pixel_values"].to(dtype), 103 | "image_text_proprio_mask": image_text_proprio_mask, 104 | "action_mask": action_mask, 105 | "vlm_position_ids": vlm_position_ids, 106 | "proprio_position_ids": proprio_position_ids, 107 | "action_position_ids": action_position_ids, 108 | "proprios": inputs["proprios"].to(dtype), 109 | } 110 | inputs = {k: v.to(device) for k, v in inputs.items()} 111 | start_inference_time = time.time() 112 | with torch.inference_mode(): # speeds up 113 | actions = model(**inputs) 114 | if cnt_step > 0: 115 | inference_times.append(time.time() - start_inference_time) 116 | env_actions = env_adapter.postprocess(actions[0].float().cpu().numpy()) 117 | 118 | # environment step 119 | for env_action in env_actions[: cfg.act_steps]: 120 | obs, reward, success, truncated, info = env.step(env_action) 121 | cnt_step += 1 122 | if truncated: 123 | break 124 | 125 | # save frame 126 | if args.recording: 127 | video_writer.append_data(env_adapter.get_video_frame(env, obs)) 128 | 129 | # update instruction in long horizon tasks, e.g., pick apple ---> put in top drawer 130 | new_instruction = env.get_language_instruction() 131 | if new_instruction != instruction: 132 | instruction = new_instruction 133 | 134 | # original octo eval only done when timeout, i.e., not upon success 135 | if truncated: 136 | if args.recording: 137 | video_writer.close() 138 | break 139 | 140 | # summary 141 | print("\n\n============ Summary ============") 142 | print(f"Checkpoint: {args.checkpoint_path}") 143 | print(f"Action chunk steps (predicted): {cfg.horizon_steps}") 144 | print(f"Action chunk steps (executed): {cfg.act_steps}") 145 | print(f"Avg inference time (excluding first step): {np.mean(inference_times):.3f}s") 146 | print( 147 | f"Peak VRAM usage: {torch.cuda.max_memory_reserved(args.gpu_id) / 1024 ** 3:.2f} GB" 148 | ) 149 | print(f"Task: {args.task}") 150 | print(f"Total environment steps: {cnt_step}") 151 | print(f"Success: {success}") 152 | if args.recording: 153 | print(f"Video saved as try_{args.task}_{episode_id}.mp4") 154 | print("======================================\n\n") 155 | 156 | 157 | if __name__ == "__main__": 158 | import argparse 159 | 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument( 162 | "--task", 163 | type=str, 164 | default="google_robot_pick_horizontal_coke_can", 165 | choices=[ 166 | "widowx_carrot_on_plate", 167 | "widowx_put_eggplant_in_basket", 168 | "widowx_spoon_on_towel", 169 | "widowx_stack_cube", 170 | "google_robot_pick_horizontal_coke_can", 171 | "google_robot_pick_vertical_coke_can", 172 | "google_robot_pick_standing_coke_can", 173 | "google_robot_move_near_v0", 174 
| "google_robot_open_drawer", 175 | "google_robot_close_drawer", 176 | "google_robot_place_apple_in_closed_top_drawer", 177 | ], 178 | ) 179 | parser.add_argument("--checkpoint_path", type=str) 180 | parser.add_argument("--gpu_id", type=int, default=0) 181 | parser.add_argument("--seed", type=int, default=42) 182 | parser.add_argument("--use_bf16", action="store_true") 183 | parser.add_argument("--use_torch_compile", action="store_true") 184 | parser.add_argument("--recording", action="store_true") 185 | args = parser.parse_args() 186 | 187 | # check task 188 | if "google_robot" in args.task: 189 | assert "fractal" in args.checkpoint_path 190 | if "widowx" in args.task: 191 | assert "bridge" in args.checkpoint_path 192 | 193 | main(args) 194 | -------------------------------------------------------------------------------- /slurm/eval_simpler_bridge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=eval-bridge 4 | #SBATCH --output=logs/eval/%A.out 5 | #SBATCH --error=logs/eval/%A.err 6 | #SBATCH --time=5:59:59 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --cpus-per-task=8 11 | #SBATCH --mem=40G 12 | 13 | # better to run jobs for each task 14 | TASKS=( 15 | "widowx_carrot_on_plate" 16 | "widowx_put_eggplant_in_basket" 17 | "widowx_spoon_on_towel" 18 | "widowx_stack_cube" 19 | ) 20 | 21 | N_EVAL_EPISODE=240 # octo simpler runs 3 seeds with 24 configs each, here we run 10 seeds 22 | 23 | for TASK in ${TASKS[@]}; do 24 | 25 | CUDA_VISIBLE_DEVICES=0 HYDRA_FULL_ERROR=1 uv run \ 26 | scripts/run.py \ 27 | --config-name=bridge \ 28 | --config-path=../config/eval \ 29 | device=cuda:0 \ 30 | seed=42 \ 31 | n_eval_episode=$N_EVAL_EPISODE \ 32 | n_video=$N_EVAL_EPISODE \ 33 | env.task=$TASK \ 34 | horizon_steps=4 \ 35 | act_steps=4 \ 36 | use_bf16=False \ 37 | use_torch_compile=True \ 38 | name=bridge_beta \ 39 | 'checkpoint_path="...bridge_beta.pt"' 40 | done 41 | -------------------------------------------------------------------------------- /slurm/eval_simpler_fractal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=eval-fractal 4 | #SBATCH --output=logs/eval/%A.out 5 | #SBATCH --error=logs/eval/%A.err 6 | #SBATCH --time=15:59:59 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --cpus-per-task=8 11 | #SBATCH --mem=40G 12 | 13 | # better to run jobs for each task 14 | TASK_CONFIGS=( 15 | "google_robot_pick_horizontal_coke_can:fractal_coke" 16 | "google_robot_pick_vertical_coke_can:fractal_coke" 17 | "google_robot_pick_standing_coke_can:fractal_coke" 18 | "google_robot_move_near_v0:fractal_move" 19 | "google_robot_open_drawer:fractal_drawer" 20 | "google_robot_close_drawer:fractal_drawer" 21 | "google_robot_place_apple_in_closed_top_drawer:fractal_apple" 22 | ) 23 | # see the config file for the number of episodes in each task 24 | 25 | for TASK_CONFIG in "${TASK_CONFIGS[@]}" ; do 26 | 27 | TASK="${TASK_CONFIG%%:*}" 28 | CONFIG_NAME="${TASK_CONFIG##*:}" 29 | 30 | CUDA_VISIBLE_DEVICES=0 HYDRA_FULL_ERROR=1 uv run \ 31 | scripts/run.py \ 32 | --config-name=$CONFIG_NAME \ 33 | --config-path=../config/eval \ 34 | device=cuda:0 \ 35 | seed=42 \ 36 | env.task=$TASK \ 37 | horizon_steps=4 \ 38 | act_steps=2 \ 39 | use_bf16=False \ 40 | use_torch_compile=True \ 41 | name=fractal_beta \ 42 | 'checkpoint_path="...fractal_beta.pt"' 43 | done 44 | 
-------------------------------------------------------------------------------- /slurm/modify_rlds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=modify-rlds 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks=1 6 | #SBATCH --cpus-per-task=50 7 | #SBATCH --mem=200G 8 | #SBATCH --time=23:59:59 9 | #SBATCH --job-name=rlds 10 | #SBATCH --output=logs/rlds-%J.log 11 | #SBATCH --error=logs/rlds-%J.err 12 | 13 | # increase limit on number of files opened in parallel to 20k --> conversion opens up to 1k temporary files in /tmp to store dataset during conversion 14 | ulimit -n 20000 15 | 16 | # dataset: bridge_dataset, or fractal20220817_data 17 | uv run python scripts/data/modify_rlds_dataset.py \ 18 | --dataset=bridge_dataset \ 19 | --data_dir=$VLA_DATA_DIR \ 20 | --target_dir=$VLA_DATA_DIR/resize_224 \ 21 | --mods=resize_and_jpeg_encode \ 22 | --n_workers=40 \ 23 | --max_episodes_in_memory=200 24 | -------------------------------------------------------------------------------- /slurm/test_training_single_gpu_no_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # no wandb logging 4 | # logging GPU memory usage 5 | # smaller global batch size 6 | # small resource for dataloading 7 | # try saving model 8 | 9 | # first batch will take a while with torch.compile as model being compiled 10 | CUDA_VISIBLE_DEVICES=0 HYDRA_FULL_ERROR=1 uv run \ 11 | scripts/run.py \ 12 | --config-name=bridge \ 13 | device=cuda:0 \ 14 | debug=True \ 15 | wandb=null \ 16 | log_dir=results/test/ \ 17 | global_batch_size=16 \ 18 | per_device_batch_size=8 \ 19 | flow_sampling=beta \ 20 | data.train.shuffle_buffer_size=10000 \ 21 | data.train.num_parallel_calls=10 \ 22 | eval_freq=32 \ 23 | eval_size=64 \ 24 | save_model_freq=16 \ 25 | save_model_start=0 \ 26 | lora=False \ 27 | quantize=False \ 28 | use_torch_compile=True \ 29 | use_bf16=True \ 30 | use_amp=True \ 31 | use_ema=True \ 32 | ema_decay=0.99 \ 33 | ema_device=cuda \ 34 | use_swa=False \ 35 | swa_start=0 \ 36 | swa_freq=2 \ 37 | swa_device=cpu \ 38 | action_lr_scheduler.warmup_steps=0 \ 39 | vlm_lr_scheduler.warmup_steps=0 40 | # 'resume_checkpoint_path="...fractal_train[:95%]_tp4_beta...ckpt....pt"' 41 | -------------------------------------------------------------------------------- /slurm/train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pg-vla 4 | #SBATCH --output=logs/%A.out 5 | #SBATCH --error=logs/%A.err 6 | #SBATCH --time=71:59:59 7 | #SBATCH --nodes=1 8 | #SBATCH --gres=gpu:8 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --cpus-per-task=104 11 | #SBATCH --mem=500G 12 | 13 | export WANDB__SERVICE_WAIT=300 14 | 15 | # GPU check 16 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 17 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 18 | echo "NUM_GPU=$NUM_GPU" 19 | 20 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) 21 | find_free_port() { 22 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 23 | } 24 | export MASTER_PORT=$(find_free_port) 25 | 26 | # run script with selected configuration using torchrun 27 | HYDRA_FULL_ERROR=1 uv run torchrun \ 28 | --nnodes=1 \ 29 | --nproc_per_node=$NUM_GPU \ 30 | --rdzv_id=$RANDOM \ 31 | --rdzv_backend=c10d \ 32 | --max-restarts=3 \ 33 | --standalone \ 34 | 
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 35 | scripts/run.py \ 36 | --config-name=bridge \ 37 | action_lr=0.00005 \ 38 | vlm_lr=0.00005 \ 39 | flow_sampling=beta \ 40 | use_torch_compile=True \ 41 | use_bf16=True \ 42 | use_amp=True 43 | -------------------------------------------------------------------------------- /slurm/train_multi_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pg-vla 4 | #SBATCH --output=logs/%A-%N.out 5 | #SBATCH --error=logs/%A-%N.err 6 | #SBATCH --time=47:59:59 7 | #SBATCH --nodes=3 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gres=gpu:8 10 | #SBATCH --cpus-per-task=104 11 | #SBATCH --mem=500G # per node 12 | 13 | export WANDB__SERVICE_WAIT=300 14 | 15 | # export NCCL_P2P_DISABLE=1 16 | # export NCCL_IB_DISABLE=1 17 | # export NCCL_SHM_DISABLE=1 18 | export NCCL_NSOCKS_PERTHREAD=4 19 | export NCCL_SOCKET_NTHREADS=2 20 | 21 | NUM_GPU="$(nvidia-smi --list-gpus | wc -l)" 22 | echo "NUM_GPU=$NUM_GPU" 23 | 24 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 25 | nodes_array=($nodes) 26 | head_node=${nodes_array[0]} 27 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 28 | find_free_port() { 29 | python -c "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.bind(('', 0)); port = s.getsockname()[1]; s.close(); print(port)" 30 | } 31 | export MASTER_PORT=$(find_free_port) 32 | echo Node IP: $head_node_ip 33 | echo Master Port: $MASTER_PORT 34 | 35 | # export LOGLEVEL=INFO 36 | # export NCCL_DEBUG=INFO 37 | # export PYTHONFAULTHANDLER=1 38 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL 39 | 40 | # Run script with selected configuration using torchrun 41 | # TORCH_CPP_LOG_LEVEL=INFO TORCH_DISTRIBUTED_DEBUG=INFO TORCH_SHOW_CPP_STACKTRACES=1 NCCL_DEBUG=INFO 42 | NCCL_SOCKET_IFNAME=ens27f0 srun uv run torchrun \ 43 | --nnodes=$SLURM_JOB_NUM_NODES \ 44 | --nproc_per_node=8 \ 45 | --rdzv_id $RANDOM \ 46 | --rdzv_backend c10d \ 47 | --max-restarts=3 \ 48 | --rdzv_endpoint $head_node_ip:$MASTER_PORT \ 49 | scripts/run.py \ 50 | --config-name=bridge \ 51 | n_nodes=$SLURM_JOB_NUM_NODES \ 52 | action_lr=0.00005 \ 53 | vlm_lr=0.00005 \ 54 | flow_sampling=beta \ 55 | use_torch_compile=True \ 56 | use_bf16=True 57 | -------------------------------------------------------------------------------- /src/agent/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import tensorflow as tf 4 | 5 | from src.data.dataset import make_interleaved_dataset 6 | from src.data.dataset_torch import TorchRLDSDataset 7 | from src.data.oxe import make_oxe_dataset_kwargs_and_weights 8 | from src.utils.monitor import log_execution_time 9 | 10 | tf.config.set_visible_devices([], "GPU") 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | class TorchRLDSInterleavedDataset: 15 | @log_execution_time(log) 16 | def __init__(self, config, train=True): 17 | dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights( 18 | config.dataset_mix, 19 | config.data_path, 20 | load_proprio=config.load_proprio, 21 | load_camera_views=("primary",), 22 | ) 23 | dataset = make_interleaved_dataset( 24 | dataset_kwargs_list, 25 | sample_weights, 26 | train=train, 27 | split=config.get("split", None), 28 | shuffle_buffer_size=config.shuffle_buffer_size, 29 | batch_size=None, # batching will be handles in PyTorch Dataloader object 30 | balance_weights=True, 31 | traj_transform_kwargs=dict( 32 | # 
goal_relabeling_strategy="uniform", # no neeed for goal relabeling 33 | window_size=config.window_size, 34 | action_horizon=config.action_horizon, 35 | subsample_length=100, 36 | skip_unlabeled=config.skip_unlabeled, # skip ones without language annotation 37 | ), 38 | frame_transform_kwargs=dict( 39 | image_augment_kwargs={ 40 | "primary": dict( 41 | random_resized_crop=dict( 42 | scale=[0.8, 1.0], 43 | ratio=[0.9, 1.1], 44 | ), 45 | random_brightness=[0.1], 46 | random_contrast=[0.9, 1.1], 47 | random_saturation=[0.9, 1.1], 48 | random_hue=[0.05], 49 | augment_order=[ 50 | "random_resized_crop", 51 | "random_brightness", 52 | "random_contrast", 53 | "random_saturation", 54 | "random_hue", 55 | ], 56 | ), 57 | "wrist": dict( 58 | random_brightness=[0.1], 59 | random_contrast=[0.9, 1.1], 60 | random_saturation=[0.9, 1.1], 61 | random_hue=[0.05], 62 | augment_order=[ 63 | "random_brightness", 64 | "random_contrast", 65 | "random_saturation", 66 | "random_hue", 67 | ], 68 | ), 69 | }, 70 | resize_size=dict( 71 | primary=(224, 224), 72 | wrist=(224, 224), 73 | ), 74 | num_parallel_calls=config.num_parallel_calls, 75 | ), 76 | traj_transform_threads=config.traj_transform_threads, 77 | traj_read_threads=config.traj_read_threads, 78 | ) 79 | 80 | # convert for torch 81 | self.dataset = TorchRLDSDataset(dataset, train=train) 82 | -------------------------------------------------------------------------------- /src/agent/env_adapter/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class BaseEnvAdapter: 5 | def __init__(self): 6 | pass 7 | 8 | def normalize_bound( 9 | self, 10 | data: np.ndarray, 11 | data_min: np.ndarray, 12 | data_max: np.ndarray, 13 | clip_min: float = -1, 14 | clip_max: float = 1, 15 | eps: float = 1e-8, 16 | ) -> np.ndarray: 17 | ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1 18 | return np.clip(ndata, clip_min, clip_max) 19 | 20 | def denormalize_bound( 21 | self, 22 | data: np.ndarray, 23 | data_min: np.ndarray, 24 | data_max: np.ndarray, 25 | clip_min: float = -1, 26 | clip_max: float = 1, 27 | eps=1e-8, 28 | ) -> np.ndarray: 29 | clip_range = clip_max - clip_min 30 | rdata = (data - clip_min) / clip_range * (data_max - data_min) + data_min 31 | return rdata 32 | 33 | def normalize_gaussian( 34 | self, 35 | data: np.ndarray, 36 | mean: np.ndarray, 37 | std: np.ndarray, 38 | eps: float = 1e-8, 39 | ) -> np.ndarray: 40 | return (data - mean) / (std + eps) 41 | 42 | def denormalize_gaussian( 43 | self, 44 | data: np.ndarray, 45 | mean: np.ndarray, 46 | std: np.ndarray, 47 | eps: float = 1e-8, 48 | ) -> np.ndarray: 49 | return data * (std + eps) + mean 50 | -------------------------------------------------------------------------------- /src/agent/model_averaging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | log = logging.getLogger(__name__) 6 | 7 | 8 | class ModelAveraging: 9 | """Not supporting resume from checkpoint currently""" 10 | 11 | def __init__(self, model, cfg, device): 12 | self.use_ema = cfg.get("use_ema", False) 13 | self.use_swa = cfg.get("use_swa", False) 14 | assert not (self.use_ema and self.use_swa), ( 15 | "Cannot use both EMA and SWA at once" 16 | ) 17 | 18 | self.model_avg = None 19 | self.model = model 20 | self.device = device 21 | 22 | # EMA configuration 23 | if self.use_ema: 24 | self.ema_start = cfg.ema_start 25 | self.ema_decay = cfg.get("ema_decay", 0.99) 26 | 
self.ema_freq = cfg.get("ema_freq", 1) 27 | self.ema_device = cfg.get("ema_device", self.device) 28 | 29 | # SWA configuration 30 | if self.use_swa: 31 | self.swa_start = cfg.swa_start 32 | self.swa_freq = cfg.swa_freq 33 | self.swa_device = cfg.get("swa_device", "cpu") 34 | 35 | def maybe_initialize(self, cnt_update): 36 | if self.use_swa and cnt_update == self.swa_start: 37 | self.model_avg = torch.optim.swa_utils.AveragedModel( 38 | self.model, device=self.swa_device 39 | ) 40 | logging.info("Starting SWA...") 41 | 42 | if self.use_ema and cnt_update == self.ema_start: 43 | self.model_avg = torch.optim.swa_utils.AveragedModel( 44 | self.model, 45 | multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(self.ema_decay), 46 | device=self.ema_device, 47 | ) 48 | logging.info(f"Starting EMA with decay {self.ema_decay}...") 49 | 50 | def maybe_update(self, cnt_update): 51 | if self.model_avg is None: 52 | return 53 | if self.use_ema and cnt_update % self.ema_freq == 0: 54 | self.model_avg.update_parameters(self.model) 55 | logging.info("EMA updated") 56 | if self.use_swa and cnt_update % self.swa_freq == 0: 57 | self.model_avg.update_parameters(self.model) 58 | logging.info("SWA updated") 59 | 60 | def get_model_module(self) -> dict: 61 | if self.model_avg: 62 | return self.model_avg.module 63 | return self.model 64 | 65 | def state_dict(self) -> dict: 66 | if self.model_avg: 67 | return { 68 | "state_dict": self.model_avg.module.state_dict(), 69 | "n_averaged": self.model_avg.state_dict().get("n_averaged", 1), 70 | "model_type": "ema" if self.use_ema else "swa", 71 | } 72 | return {} 73 | -------------------------------------------------------------------------------- /src/data/dataset_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | From: https://github.com/octo-models/octo/blob/main/examples/06_pytorch_oxe_dataloader.py 3 | 4 | This example shows how to use the `src.data` dataloader with PyTorch by wrapping it in a simple PyTorch dataloader. The config below also happens to be our exact pretraining config (except for the batch size and shuffle buffer size, which are reduced for demonstration purposes). 5 | """ 6 | 7 | import tensorflow as tf 8 | import torch 9 | 10 | tf.config.set_visible_devices([], "GPU") 11 | 12 | 13 | class TorchRLDSDataset(torch.utils.data.IterableDataset): 14 | """Thin wrapper around RLDS dataset for use with PyTorch dataloaders.""" 15 | 16 | def __init__( 17 | self, 18 | rlds_dataset, 19 | train=True, 20 | ): 21 | self._rlds_dataset = rlds_dataset 22 | self._is_train = train 23 | 24 | def __iter__(self): 25 | for sample in self._rlds_dataset.as_numpy_iterator(): 26 | yield sample 27 | 28 | def __len__(self): 29 | # TODO(allenzren): account for sample weights? 30 | return self._rlds_dataset.true_total_length 31 | # lengths = np.array( 32 | # [ 33 | # stats["num_transitions"] 34 | # for stats in self._rlds_dataset.dataset_statistics.values() 35 | # ], 36 | # dtype=float, 37 | # ) 38 | # if hasattr(self._rlds_dataset, "sample_weights"): 39 | # lengths *= self._rlds_dataset.sample_weights 40 | # total_len = lengths.sum() 41 | # if self._is_train: 42 | # return int(0.95 * total_len) 43 | # else: 44 | # return int(0.05 * total_len) 45 | -------------------------------------------------------------------------------- /src/data/dlimp/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import transforms 2 | from .dataset import DLataset 3 | from .utils import parallel_vmap, vmap 4 | -------------------------------------------------------------------------------- /src/data/dlimp/augmentations.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def random_resized_crop(image, scale, ratio, seed): 7 | assert image.shape.ndims == 3 or image.shape.ndims == 4 8 | if image.shape.ndims == 3: 9 | image = tf.expand_dims(image, axis=0) 10 | batch_size = tf.shape(image)[0] 11 | # taken from https://keras.io/examples/vision/nnclr/#random-resized-crops 12 | log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1])) 13 | height = tf.shape(image)[1] 14 | width = tf.shape(image)[2] 15 | 16 | random_scales = tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1]) 17 | random_ratios = tf.exp( 18 | tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1]) 19 | ) 20 | 21 | new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1) 22 | new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1) 23 | height_offsets = tf.random.stateless_uniform( 24 | (batch_size,), seed, 0, 1 - new_heights 25 | ) 26 | width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths) 27 | 28 | bounding_boxes = tf.stack( 29 | [ 30 | height_offsets, 31 | width_offsets, 32 | height_offsets + new_heights, 33 | width_offsets + new_widths, 34 | ], 35 | axis=1, 36 | ) 37 | 38 | image = tf.image.crop_and_resize( 39 | image, bounding_boxes, tf.range(batch_size), (height, width) 40 | ) 41 | 42 | if image.shape[0] == 1: 43 | return image[0] 44 | else: 45 | return image 46 | 47 | 48 | def random_rot90(image, seed): 49 | k = tf.random.stateless_uniform((), seed, 0, 4, dtype=tf.int32) 50 | return tf.image.rot90(image, k=k) 51 | 52 | 53 | AUGMENT_OPS = { 54 | "random_resized_crop": random_resized_crop, 55 | "random_brightness": tf.image.stateless_random_brightness, 56 | "random_contrast": tf.image.stateless_random_contrast, 57 | "random_saturation": tf.image.stateless_random_saturation, 58 | "random_hue": tf.image.stateless_random_hue, 59 | "random_flip_left_right": tf.image.stateless_random_flip_left_right, 60 | "random_flip_up_down": tf.image.stateless_random_flip_up_down, 61 | "random_rot90": random_rot90, 62 | } 63 | 64 | 65 | def augment_image( 66 | image: tf.Tensor, 67 | seed: Optional[tf.Tensor] = None, 68 | **augment_kwargs, 69 | ) -> tf.Tensor: 70 | """Unified image augmentation function for TensorFlow. 71 | 72 | This function is primarily configured through `augment_kwargs`. There must be one kwarg called "augment_order", 73 | which is a list of strings specifying the augmentation operations to apply and the order in which to apply them. See 74 | the `AUGMENT_OPS` dictionary above for a list of available operations. 75 | 76 | For each entry in "augment_order", there may be a corresponding kwarg with the same name. The value of this kwarg 77 | can be a dictionary of kwargs or a sequence of positional args to pass to the corresponding augmentation operation. 78 | This additional kwarg is required for all operations that take additional arguments other than the image and random 79 | seed. For example, the "random_resized_crop" operation requires a "scale" and "ratio" argument that can be specified 80 | either positionally or by name. 
"random_flip_left_right", on the other hand, does not take any additional arguments 81 | and so does not require an additional kwarg to configure it. 82 | 83 | Here is an example config: 84 | 85 | ``` 86 | augment_kwargs = { 87 | "augment_order": ["random_resized_crop", "random_brightness", "random_contrast", "random_flip_left_right"], 88 | "random_resized_crop": { 89 | "scale": [0.8, 1.0], 90 | "ratio": [3/4, 4/3], 91 | }, 92 | "random_brightness": [0.1], 93 | "random_contrast": [0.9, 1.1], 94 | ``` 95 | 96 | Args: 97 | image: A `Tensor` of shape [height, width, channels] with the image. May be uint8 or float32 with values in [0, 255]. 98 | seed (optional): A `Tensor` of shape [2] with the seed for the random number generator. 99 | **augment_kwargs: Keyword arguments for the augmentation operations. The order of operations is determined by 100 | the "augment_order" keyword argument. Other keyword arguments are passed to the corresponding augmentation 101 | operation. See above for a list of operations. 102 | """ 103 | assert isinstance(augment_kwargs, dict) 104 | 105 | if "augment_order" not in augment_kwargs: 106 | raise ValueError("augment_kwargs must contain an 'augment_order' key.") 107 | 108 | # convert to float at the beginning to avoid each op converting back and 109 | # forth between uint8 and float32 internally 110 | orig_dtype = image.dtype 111 | image = tf.image.convert_image_dtype(image, tf.float32) 112 | 113 | if seed is None: 114 | seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) 115 | 116 | for op in augment_kwargs["augment_order"]: 117 | seed = tf.random.stateless_uniform( 118 | [2], seed, maxval=tf.dtypes.int32.max, dtype=tf.int32 119 | ) 120 | if op in augment_kwargs: 121 | if hasattr(augment_kwargs[op], "items"): 122 | image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op]) 123 | else: 124 | image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op]) 125 | else: 126 | image = AUGMENT_OPS[op](image, seed=seed) 127 | # float images are expected to be in [0, 1] 128 | image = tf.clip_by_value(image, 0, 1) 129 | 130 | # convert back to original dtype and scale 131 | image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True) 132 | 133 | return image 134 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from . import goal_relabeling 2 | from .common import * 3 | from .frame_transforms import * 4 | from .traj_transforms import * 5 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/common.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | from typing import Any, Callable, Dict, Union 3 | 4 | 5 | def selective_tree_map( 6 | x: Dict[str, Any], 7 | match: Union[str, Callable[[str, Any], bool]], 8 | map_fn: Callable, 9 | *, 10 | _keypath: str = "", 11 | ) -> Dict[str, Any]: 12 | """Maps a function over a nested dictionary, only applying it leaves that match a criterion. 13 | 14 | If `match` is a string, it follows glob-style syntax. For example, "bar" will only match 15 | a top-level key called "bar", "*bar" will match any leaf whose key ends with "bar", 16 | and "*bar*" will match any subtree with a key that contains "bar". 17 | 18 | Key paths are separated by "/". For example, "foo/bar" will match a leaf with key "bar" that 19 | is nested under a key "foo". 
20 | 21 | Args: 22 | x (Dict[str, Any]): The (possibly nested) dictionary to map over. 23 | match (str or Callable[[str, Any], bool]): If a string or list of strings, `map_fn` will 24 | only be applied to leaves whose key path matches `match` using glob-style syntax. If a 25 | function, `map_fn` will only be applied to leaves for which `match(key_path, value)` 26 | returns True. 27 | map_fn (Callable): The function to apply. 28 | """ 29 | if not callable(match): 30 | 31 | def match_fn(keypath, value): 32 | return fnmatch.fnmatch(keypath, match) 33 | else: 34 | match_fn = match 35 | 36 | out = {} 37 | for key in x: 38 | if isinstance(x[key], dict): 39 | out[key] = selective_tree_map( 40 | x[key], match_fn, map_fn, _keypath=_keypath + key + "/" 41 | ) 42 | elif match_fn(_keypath + key, x[key]): 43 | out[key] = map_fn(x[key]) 44 | else: 45 | out[key] = x[key] 46 | return out 47 | 48 | 49 | def flatten_dict(d: Dict[str, Any], sep="/") -> Dict[str, Any]: 50 | """Given a nested dictionary, flatten it by concatenating keys with sep.""" 51 | flattened = {} 52 | for k, v in d.items(): 53 | if isinstance(v, dict): 54 | for k2, v2 in flatten_dict(v, sep=sep).items(): 55 | flattened[k + sep + k2] = v2 56 | else: 57 | flattened[k] = v 58 | return flattened 59 | 60 | 61 | def unflatten_dict(d: Dict[str, Any], sep="/") -> Dict[str, Any]: 62 | """Given a flattened dictionary, unflatten it by splitting keys by sep.""" 63 | unflattened = {} 64 | for k, v in d.items(): 65 | keys = k.split(sep) 66 | if len(keys) == 1: 67 | unflattened[k] = v 68 | else: 69 | if keys[0] not in unflattened: 70 | unflattened[keys[0]] = {} 71 | unflattened[keys[0]][sep.join(keys[1:])] = v 72 | return unflattened 73 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/frame_transforms.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Dict, Sequence, Tuple, Union 3 | 4 | import tensorflow as tf 5 | 6 | from src.data.dlimp.augmentations import augment_image 7 | from src.data.dlimp.utils import resize_depth_image, resize_image 8 | 9 | from .common import selective_tree_map 10 | 11 | 12 | def decode_images( 13 | x: Dict[str, Any], match: Union[str, Sequence[str]] = "image" 14 | ) -> Dict[str, Any]: 15 | """Can operate on nested dicts. Decodes any leaves that have `match` anywhere in their path.""" 16 | if isinstance(match, str): 17 | match = [match] 18 | 19 | return selective_tree_map( 20 | x, 21 | lambda keypath, value: any([s in keypath for s in match]) 22 | and value.dtype == tf.string, 23 | partial(tf.io.decode_image, expand_animations=False), 24 | ) 25 | 26 | 27 | def resize_images( 28 | x: Dict[str, Any], 29 | match: Union[str, Sequence[str]] = "image", 30 | size: Tuple[int, int] = (128, 128), 31 | ) -> Dict[str, Any]: 32 | """Can operate on nested dicts. Resizes any leaves that have `match` anywhere in their path. Takes uint8 images 33 | as input and returns float images (still in [0, 255]). 34 | """ 35 | if isinstance(match, str): 36 | match = [match] 37 | 38 | return selective_tree_map( 39 | x, 40 | lambda keypath, value: any([s in keypath for s in match]) 41 | and value.dtype == tf.uint8, 42 | partial(resize_image, size=size), 43 | ) 44 | 45 | 46 | def resize_depth_images( 47 | x: Dict[str, Any], 48 | match: Union[str, Sequence[str]] = "depth", 49 | size: Tuple[int, int] = (128, 128), 50 | ) -> Dict[str, Any]: 51 | """Can operate on nested dicts. 
Resizes any leaves that have `match` anywhere in their path. Takes float32 images 52 | as input and returns float images (in arbitrary range). 53 | """ 54 | if isinstance(match, str): 55 | match = [match] 56 | 57 | return selective_tree_map( 58 | x, 59 | lambda keypath, value: any([s in keypath for s in match]) 60 | and value.dtype == tf.float32, 61 | partial(resize_depth_image, size=size), 62 | ) 63 | 64 | 65 | def augment( 66 | x: Dict[str, Any], 67 | match: Union[str, Callable[[str, Any], bool]] = "*image", 68 | traj_identical: bool = True, 69 | keys_identical: bool = True, 70 | augment_kwargs: dict = {}, 71 | ) -> Dict[str, Any]: 72 | """ 73 | Augments the input dictionary `x` by applying image augmentation to all values whose keypath contains `match`. 74 | 75 | Args: 76 | x (Dict[str, Any]): The input dictionary to augment. 77 | match (str or Callable[[str, Any], bool]): See documentation for `selective_tree_map`. 78 | Defaults to "*image", which matches all leaves whose key ends in "image". 79 | traj_identical (bool, optional): Whether to use the same random seed for all images in a trajectory. 80 | keys_identical (bool, optional): Whether to use the same random seed for all keys that are augmented. 81 | augment_kwargs (dict, optional): Additional keyword arguments to pass to the `augment_image` function. 82 | """ 83 | toplevel_seed = tf.random.uniform([2], 0, 2**31 - 1, dtype=tf.int32) 84 | 85 | def map_fn(value): 86 | if keys_identical and traj_identical: 87 | seed = [x["_traj_index"], x["_traj_index"]] 88 | elif keys_identical and not traj_identical: 89 | seed = toplevel_seed 90 | elif not keys_identical and traj_identical: 91 | raise NotImplementedError() 92 | else: 93 | seed = None 94 | 95 | return augment_image(value, seed=seed, **augment_kwargs) 96 | 97 | return selective_tree_map( 98 | x, 99 | match, 100 | map_fn, 101 | ) 102 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains goal relabeling and reward logic written in TensorFlow. 3 | 4 | Each relabeling function takes a trajectory with keys `obs` and `next_obs`. It returns a new trajectory with the added 5 | keys `goals` and `rewards`. Keep in mind that `obs` and `next_obs` may themselves be dictionaries, and `goals` must 6 | match their structure. 7 | """ 8 | 9 | from typing import Any, Dict 10 | 11 | import tensorflow as tf 12 | 13 | 14 | def uniform(traj: Dict[str, Any], reached_proportion: float): 15 | """Relabels with a true uniform distribution over future states. With probability reached_proportion, 16 | obs[i] gets a goal equal to next_obs[i]. In this case, the reward is 0. Otherwise, 17 | obs[i] gets a goal sampled uniformly from the set next_obs[i + 1:], and the reward is -1. 
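    As a concrete illustration (hypothetical numbers): with reached_proportion = 0.2 and a
    trajectory of length 5, each transition has a 20% chance (and the final transition a 100%
    chance) of being relabeled with its own next_obs as the goal and reward 0; every other
    transition receives a goal drawn uniformly from next_obs[i + 1:] and reward -1.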
18 | """ 19 | traj_len = tf.shape(tf.nest.flatten(traj)[0])[0] 20 | 21 | # select a random future index for each transition i in the range [i + 1, traj_len) 22 | rand = tf.random.uniform([traj_len]) 23 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 24 | high = tf.cast(traj_len, tf.float32) 25 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 26 | 27 | # TODO(kvablack): don't know how I got an out-of-bounds during training, 28 | # could not reproduce, trying to patch it for now 29 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 30 | 31 | # select a random proportion of transitions to relabel with the next obs 32 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 33 | 34 | # the last transition must be goal-reaching 35 | goal_reached_mask = tf.logical_or( 36 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 37 | ) 38 | 39 | # make goal-reaching transitions have an offset of 0 40 | goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs) 41 | 42 | # select goals 43 | traj["goals"] = tf.nest.map_structure( 44 | lambda x: tf.gather(x, goal_idxs), 45 | traj["next_obs"], 46 | ) 47 | 48 | # reward is 0 for goal-reaching transitions, -1 otherwise 49 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 50 | 51 | return traj 52 | 53 | 54 | def last_state_upweighted(traj: Dict[str, Any], reached_proportion: float): 55 | """ 56 | A weird relabeling scheme where the last state gets upweighted. For each transition i, a uniform random number is 57 | generated in the range [i + 1, i + traj_len). It then gets clipped to be less than traj_len. Therefore, the first 58 | transition (i = 0) gets a goal sampled uniformly from the future, but for i > 0 the last state gets more and more 59 | upweighted. 60 | """ 61 | traj_len = tf.shape(tf.nest.flatten(traj)[0])[0] 62 | 63 | # select a random future index for each transition 64 | offsets = tf.random.uniform( 65 | [traj_len], 66 | minval=1, 67 | maxval=traj_len, 68 | dtype=tf.int32, 69 | ) 70 | 71 | # select random transitions to relabel as goal-reaching 72 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 73 | # last transition is always goal-reaching 74 | goal_reached_mask = tf.logical_or( 75 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 76 | ) 77 | 78 | # the goal will come from the current transition if the goal was reached 79 | offsets = tf.where(goal_reached_mask, 0, offsets) 80 | 81 | # convert from relative to absolute indices 82 | indices = tf.range(traj_len) + offsets 83 | 84 | # clamp out of bounds indices to the last transition 85 | indices = tf.minimum(indices, traj_len - 1) 86 | 87 | # select goals 88 | traj["goals"] = tf.nest.map_structure( 89 | lambda x: tf.gather(x, indices), 90 | traj["next_obs"], 91 | ) 92 | 93 | # reward is 0 for goal-reaching transitions, -1 otherwise 94 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 95 | 96 | return traj 97 | 98 | 99 | def geometric(traj: Dict[str, Any], reached_proportion: float, discount: float): 100 | """ 101 | Relabels with a geometric distribution over future states. With probability reached_proportion, obs[i] gets 102 | a goal equal to next_obs[i]. In this case, the reward is 0. Otherwise, obs[i] gets a goal sampled 103 | geometrically from the set next_obs[i + 1:], and the reward is -1. 
104 | """ 105 | traj_len = tf.shape(tf.nest.flatten(traj)[0])[0] 106 | 107 | # geometrically select a future index for each transition i in the range [i + 1, traj_len) 108 | arange = tf.range(traj_len) 109 | is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32) 110 | d = discount ** tf.cast(arange[None] - arange[:, None], tf.float32) 111 | 112 | probs = is_future_mask * d 113 | # The indexing changes the shape from [seq_len, 1] to [seq_len] 114 | goal_idxs = tf.random.categorical( 115 | logits=tf.math.log(probs), num_samples=1, dtype=tf.int32 116 | )[:, 0] 117 | 118 | # select a random proportion of transitions to relabel with the next obs 119 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 120 | 121 | # the last transition must be goal-reaching 122 | goal_reached_mask = tf.logical_or( 123 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 124 | ) 125 | 126 | # make goal-reaching transitions have an offset of 0 127 | goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs) 128 | 129 | # select goals 130 | traj["goals"] = tf.nest.map_structure( 131 | lambda x: tf.gather(x, goal_idxs), 132 | traj["next_obs"], 133 | ) 134 | 135 | # reward is 0 for goal-reaching transitions, -1 otherwise 136 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 137 | 138 | return traj 139 | -------------------------------------------------------------------------------- /src/data/dlimp/transforms/traj_transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def add_next_obs(traj: Dict[str, Any], pad: bool = True) -> Dict[str, Any]: 7 | """ 8 | Given a trajectory with a key "observations", add the key "next_observations". If pad is False, discards the last 9 | value of all other keys. Otherwise, the last transition will have "observations" == "next_observations". 10 | """ 11 | if not pad: 12 | traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) 13 | traj_truncated["next_observations"] = tf.nest.map_structure( 14 | lambda x: x[1:], traj["observations"] 15 | ) 16 | return traj_truncated 17 | else: 18 | traj["next_observations"] = tf.nest.map_structure( 19 | lambda x: tf.concat((x[1:], x[-1:]), axis=0), traj["observations"] 20 | ) 21 | return traj 22 | -------------------------------------------------------------------------------- /src/data/dlimp/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def tensor_feature(value): 7 | return tf.train.Feature( 8 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 9 | ) 10 | 11 | 12 | def resize_image(image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor: 13 | """Resizes an image using Lanczos3 interpolation. Expects & returns uint8.""" 14 | assert image.dtype == tf.uint8 15 | image = tf.image.resize(image, size, method="lanczos3", antialias=True) 16 | image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8) 17 | return image 18 | 19 | 20 | def resize_depth_image(depth_image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor: 21 | """Resizes a depth image using bilinear interpolation. 
Expects & returns float32 in arbitrary range.""" 22 | assert depth_image.dtype == tf.float32 23 | if len(depth_image.shape) < 3: 24 | depth_image = tf.image.resize( 25 | depth_image[..., None], size, method="bilinear", antialias=True 26 | )[..., 0] 27 | else: 28 | depth_image = tf.image.resize( 29 | depth_image, size, method="bilinear", antialias=True 30 | ) 31 | return depth_image 32 | 33 | 34 | def read_resize_encode_image(path: str, size: Tuple[int, int]) -> tf.Tensor: 35 | """Reads, decodes, resizes, and then re-encodes an image.""" 36 | data = tf.io.read_file(path) 37 | image = tf.image.decode_jpeg(data) 38 | image = resize_image(image, size) 39 | image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8) 40 | return tf.io.encode_jpeg(image, quality=95) 41 | 42 | 43 | def vmap(fn: Callable) -> Callable: 44 | """ 45 | Vmap a function over the first dimension of a tensor (or nested structure of tensors). This 46 | version does NOT parallelize the function; however, it fuses the function calls in a way that 47 | appears to be more performant than tf.map_fn or tf.vectorized_map (when falling back to 48 | while_loop) for certain situations. 49 | 50 | Requires the first dimension of the input to be statically known. 51 | """ 52 | 53 | def wrapped(structure): 54 | return tf.nest.map_structure( 55 | lambda *x: tf.stack(x), 56 | *[ 57 | fn(tf.nest.pack_sequence_as(structure, x)) 58 | for x in zip(*map(tf.unstack, tf.nest.flatten(structure))) 59 | ], 60 | ) 61 | 62 | return wrapped 63 | 64 | 65 | def parallel_vmap(fn: Callable, num_parallel_calls=tf.data.AUTOTUNE) -> Callable: 66 | """ 67 | Vmap a function over the first dimension of a tensor (or nested structure of tensors). This 68 | version attempts to parallelize the function using the tf.data API. I found this to be more 69 | performant than tf.map_fn or tf.vectorized_map (when falling back to while_loop), but the batch 70 | call appears to add significant overhead that may make it slower for some situations. 71 | """ 72 | 73 | def wrapped(structure): 74 | return ( 75 | tf.data.Dataset.from_tensor_slices(structure) 76 | .map(fn, deterministic=True, num_parallel_calls=num_parallel_calls) 77 | .batch( 78 | tf.cast(tf.shape(tf.nest.flatten(structure)[0])[0], tf.int64), 79 | ) 80 | .get_single_element() 81 | ) 82 | 83 | return wrapped 84 | -------------------------------------------------------------------------------- /src/data/obs_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains observation-level transforms used in the octo data pipeline. These transforms operate on the 3 | "observation" dictionary, and are applied at a per-frame level. 4 | """ 5 | 6 | from typing import Mapping, Optional, Tuple, Union 7 | 8 | import tensorflow as tf 9 | import logging 10 | 11 | import src.data.dlimp as dl 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | def augment( 16 | obs: dict, 17 | seed: tf.Tensor, 18 | augment_kwargs: Union[dict, Mapping[str, dict]], 19 | ) -> dict: 20 | """Augments images, skipping padding images.""" 21 | if not hasattr(augment_kwargs, "items"): 22 | raise ValueError( 23 | "augment_kwargs must be a dict with keys corresponding to image names, or a single dict " 24 | "with an 'augment_order' key." 
25 | ) 26 | image_names = {key[6:] for key in obs if key.startswith("image_")} 27 | 28 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed 29 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image 30 | # name to augmentation dict) 31 | if "augment_order" in augment_kwargs: 32 | augment_kwargs = {name: augment_kwargs for name in image_names} 33 | 34 | for i, name in enumerate(image_names): 35 | if name not in augment_kwargs: 36 | continue 37 | kwargs = augment_kwargs[name] 38 | log.debug(f"Augmenting image_{name} with kwargs {kwargs}") 39 | obs[f"image_{name}"] = tf.cond( 40 | obs["pad_mask_dict"][f"image_{name}"], 41 | lambda: dl.transforms.augment_image( 42 | obs[f"image_{name}"], 43 | **kwargs, 44 | seed=seed + i, # augment each image differently 45 | ), 46 | lambda: obs[f"image_{name}"], # skip padding images 47 | ) 48 | 49 | return obs 50 | 51 | 52 | def image_dropout( 53 | obs: dict, 54 | seed: tf.Tensor, 55 | dropout_prob: float, 56 | always_keep_key: Optional[str] = None, 57 | ) -> dict: 58 | """Independently drops out image keys, each with probability `dropout_prob`, but always keeps at least one 59 | image present. 60 | """ 61 | image_keys = [key for key in obs if key.startswith("image_")] 62 | if not image_keys: 63 | return obs 64 | pad_mask = tf.stack([obs["pad_mask_dict"][key] for key in image_keys]) 65 | # if any non-padding images exist, pick one of them to keep no matter what 66 | shuffle_seed, seed = tf.unstack(tf.random.split(seed)) 67 | 68 | if always_keep_key: 69 | assert ( 70 | always_keep_key in image_keys 71 | ), f"Specified always_keep_key {always_keep_key} not present in image_keys: {image_keys} during dropout." 72 | always_keep_index = tf.constant( 73 | image_keys.index(always_keep_key), dtype=tf.int64 74 | ) 75 | else: 76 | always_keep_index = tf.cond( 77 | tf.reduce_any(pad_mask), 78 | # pick a random index from the non-padding images 79 | lambda: tf.random.experimental.stateless_shuffle( 80 | tf.where(pad_mask)[:, 0], seed=shuffle_seed 81 | )[0], 82 | # all images are padding, so it doesn't matter 83 | lambda: tf.constant(0, dtype=tf.int64), 84 | ) 85 | 86 | # drop images independently, except for the one at always_keep_index 87 | rands = tf.random.stateless_uniform([len(image_keys)], seed=seed) 88 | pad_mask = tf.logical_and( 89 | pad_mask, 90 | tf.logical_or( 91 | tf.range(len(image_keys), dtype=tf.int64) == always_keep_index, 92 | rands > dropout_prob, 93 | ), 94 | ) 95 | 96 | # perform the dropout and update pad_mask_dict 97 | for i, key in enumerate(image_keys): 98 | obs["pad_mask_dict"][key] = pad_mask[i] 99 | obs[key] = tf.cond( 100 | pad_mask[i], 101 | lambda: obs[key], 102 | lambda: tf.zeros_like(obs[key]), 103 | ) 104 | return obs 105 | 106 | 107 | def decode_and_resize( 108 | obs: dict, 109 | resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], 110 | depth_resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], 111 | ) -> dict: 112 | """Decodes images and depth images, and then optionally resizes them.""" 113 | # just gets the part after "image_" or "depth_" 114 | image_names = {key[6:] for key in obs if key.startswith("image_")} 115 | depth_names = {key[6:] for key in obs if key.startswith("depth_")} 116 | 117 | if isinstance(resize_size, tuple): 118 | resize_size = {name: resize_size for name in image_names} 119 | if isinstance(depth_resize_size, tuple): 120 | depth_resize_size = {name: depth_resize_size for name in 
depth_names} 121 | 122 | for name in image_names: 123 | if name not in resize_size: 124 | log.warning( 125 | f"No resize_size was provided for image_{name}. This will result in 1x1 " 126 | "padding images, which may cause errors if you mix padding and non-padding images." 127 | ) 128 | image = obs[f"image_{name}"] 129 | if image.dtype == tf.string: 130 | if tf.strings.length(image) == 0: 131 | # this is a padding image 132 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) 133 | else: 134 | image = tf.io.decode_image( 135 | image, expand_animations=False, dtype=tf.uint8 136 | ) 137 | elif image.dtype != tf.uint8: 138 | raise ValueError( 139 | f"Unsupported image dtype: found image_{name} with dtype {image.dtype}" 140 | ) 141 | if name in resize_size: 142 | image = dl.transforms.resize_image(image, size=resize_size[name]) 143 | obs[f"image_{name}"] = image 144 | 145 | for name in depth_names: 146 | if name not in depth_resize_size: 147 | log.warning( 148 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " 149 | "padding depth images, which may cause errors if you mix padding and non-padding images." 150 | ) 151 | depth = obs[f"depth_{name}"] 152 | if depth.dtype == tf.string: 153 | if tf.strings.length(depth) == 0: 154 | # this is a padding image 155 | depth = tf.zeros( 156 | (*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32 157 | ) 158 | else: 159 | depth = tf.io.decode_image( 160 | depth, expand_animations=False, dtype=tf.float32 161 | )[..., 0] 162 | elif depth.dtype != tf.float32: 163 | raise ValueError( 164 | f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}" 165 | ) 166 | if name in depth_resize_size: 167 | depth = dl.transforms.resize_depth_image( 168 | depth, size=depth_resize_size[name] 169 | ) 170 | obs[f"depth_{name}"] = depth 171 | 172 | return obs 173 | -------------------------------------------------------------------------------- /src/data/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from typing import Any, Dict, List, Sequence, Tuple, Union 4 | 5 | from src.data.oxe.oxe_dataset_configs import OXE_DATASET_CONFIGS, ActionEncoding 6 | from src.data.oxe.oxe_dataset_mixes import OXE_NAMED_MIXES 7 | from src.data.oxe.oxe_standardization_transforms import OXE_STANDARDIZATION_TRANSFORMS 8 | from src.data.utils.data_utils import NormalizationType 9 | from src.utils.spec import ModuleSpec 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | def make_oxe_dataset_kwargs( 15 | name: str, 16 | data_dir: str, 17 | load_camera_views: Sequence[str] = ("primary",), 18 | load_depth: bool = False, 19 | load_proprio: bool = False, 20 | load_language: bool = True, 21 | force_recompute_dataset_statistics: bool = False, 22 | action_proprio_normalization_type: NormalizationType = NormalizationType.BOUNDS, 23 | skip_norm: bool = False, 24 | ) -> Dict[str, Any]: 25 | """Generates dataset kwargs for a given dataset from Open X-Embodiment. The returned kwargs can be passed 26 | directly into `octo.data.dataset.make_dataset_from_rlds`. 27 | 28 | Args: 29 | name: Name of the dataset to load. See `oxe_dataset_configs.py` for available datasets. 30 | data_dir: Base data directory that contains the dataset. 31 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views. 32 | load_depth: If True, loads corresponding depth channels for each RGB channel. 
33 | load_proprio: If True, loads proprioceptive information. 34 | load_language: If True, loads language instructions. 35 | force_recompute_dataset_statistics: If True, recompute dataset statistics. 36 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions. 37 | """ 38 | dataset_kwargs = copy.deepcopy(OXE_DATASET_CONFIGS[name]) 39 | 40 | if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: 41 | # with EEF_POS actions, the last action dimension is gripper 42 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] 43 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS: 44 | # with JOINT_POS actions, last dimension is gripper 45 | dataset_kwargs["action_normalization_mask"] = [True] * 7 + [False] 46 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL: 47 | # with JOINT_POS_BIMANUAL actions, 7th and 14th dimension are gripper 48 | dataset_kwargs["action_normalization_mask"] = ( 49 | [True] * 6 + [False] + [True] * 6 + [False] 50 | ) 51 | elif dataset_kwargs["action_encoding"] is ActionEncoding.NAV_2D: 52 | # with NAV_2D actions, all dimensions are deltas 53 | dataset_kwargs["action_normalization_mask"] = [True] * 2 54 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL_NAV: 55 | # with JOINT_POS_BIMANUAL_NAV actions, 7th and 14th dimension are gripper 56 | dataset_kwargs["action_normalization_mask"] = ( 57 | [True] * 6 + [False] + [True] * 6 + [False] + [True] * 2 58 | ) 59 | else: 60 | raise ValueError( 61 | f"Cannot load {name} with unsupported action encoding {dataset_kwargs['action_encoding']}." 62 | ) 63 | 64 | # adjust loaded camera views 65 | if missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"])): 66 | raise ValueError( 67 | f"Cannot load {name} with views {missing_keys} since they are not available." 
68 | ) 69 | dataset_kwargs["image_obs_keys"] = { 70 | k: v 71 | for k, v in dataset_kwargs["image_obs_keys"].items() 72 | if k in load_camera_views 73 | } 74 | dataset_kwargs["depth_obs_keys"] = { 75 | k: v 76 | for k, v in dataset_kwargs["depth_obs_keys"].items() 77 | if k in load_camera_views 78 | } 79 | 80 | if not load_depth: 81 | dataset_kwargs.pop("depth_obs_keys") 82 | if load_proprio: 83 | dataset_kwargs["proprio_obs_key"] = "proprio" 84 | if load_language: 85 | dataset_kwargs["language_key"] = "language_instruction" 86 | 87 | dataset_kwargs["action_proprio_normalization_type"] = ( 88 | action_proprio_normalization_type 89 | ) 90 | dataset_kwargs["skip_norm"] = skip_norm 91 | 92 | del dataset_kwargs["proprio_encoding"] 93 | del dataset_kwargs["action_encoding"] 94 | 95 | dataset_kwargs["standardize_fn"] = ModuleSpec.create( 96 | OXE_STANDARDIZATION_TRANSFORMS[name] 97 | ) 98 | 99 | if force_recompute_dataset_statistics: 100 | dataset_kwargs["force_recompute_dataset_statistics"] = True 101 | 102 | return {"name": name, "data_dir": data_dir, **dataset_kwargs} 103 | 104 | 105 | def make_oxe_dataset_kwargs_and_weights( 106 | data_mix: Union[str, Sequence[Tuple[str, float]]], 107 | data_dir: str, 108 | load_camera_views: Sequence[str] = ("primary",), 109 | load_depth: bool = False, 110 | load_proprio: bool = False, 111 | load_language: bool = True, 112 | force_recompute_dataset_statistics: bool = False, 113 | action_proprio_normalization_type: NormalizationType = NormalizationType.BOUNDS, 114 | skip_norm: bool = False, 115 | ) -> Tuple[Dict[str, Any], List[float]]: 116 | """ 117 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs 118 | and weights can be passed directly into `octo.data.dataset.make_interleaved_dataset`. 119 | 120 | Args: 121 | data_mix: List of (dataset name, sampling weight) tuples, or a string specifying a pre-defined mix to 122 | load from `OXE_NAMED_MIXES`. 123 | data_dir: Base data directory that contains the datasets. 124 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views. 125 | load_depth: If True, loads corresponding depth channels for each RGB channel. 126 | load_proprio: If True, loads proprioceptive information. 127 | load_language: If True, loads language instructions. 128 | force_recompute_dataset_statistics: If True, recompute dataset statistics. 129 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions. 130 | Returns: 131 | Tuple of (dataset_kwargs_list, sampling weights). 
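    Example (illustrative call; the data directory path is hypothetical):

        dataset_kwargs_list, weights = make_oxe_dataset_kwargs_and_weights(
            "bridge",
            data_dir="/path/to/rlds_datasets",
            load_camera_views=("primary",),
            load_proprio=True,
        )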
132 | """ 133 | if isinstance(data_mix, str): 134 | data_mix = OXE_NAMED_MIXES[data_mix] 135 | 136 | filtered_datasets, included_dataset_names = [], [] 137 | for name, weight in data_mix: 138 | if name not in included_dataset_names: 139 | filtered_datasets.append((name, weight)) 140 | included_dataset_names.append(name) 141 | else: 142 | log.warning(f"Skipping duplicate: {(name, weight)}.") 143 | data_mix = filtered_datasets 144 | 145 | data_kwargs_list, weights = [], [] 146 | for name, weight in data_mix: 147 | try: 148 | data_kwargs_list.append( 149 | make_oxe_dataset_kwargs( 150 | name, 151 | data_dir, 152 | load_camera_views, 153 | load_depth, 154 | load_proprio, 155 | load_language, 156 | force_recompute_dataset_statistics, 157 | action_proprio_normalization_type, 158 | skip_norm, 159 | ) 160 | ) 161 | weights.append(weight) 162 | except ValueError as e: 163 | log.warning(f"Skipping {name} due to error: {e}") 164 | 165 | return data_kwargs_list, weights 166 | -------------------------------------------------------------------------------- /src/data/oxe/oxe_dataset_mixes.py: -------------------------------------------------------------------------------- 1 | """Defines dataset mixtures and weights for the Open X-Embodiment Datasets.""" 2 | 3 | OXE_SIMPLE = [ 4 | ("fractal20220817_data", 1.0), 5 | ("bridge_dataset", 1.0), 6 | # ("bc_z", 0.2), 7 | ] 8 | 9 | BRIDGE_MIX = [ 10 | ("bridge_dataset", 1.0), 11 | ] 12 | 13 | FRACTAL_MIX = [ 14 | ("fractal20220817_data", 1.0), 15 | ] 16 | 17 | RT_X_MIX = [ 18 | ("fractal20220817_data", 0.54087122203), 19 | ("kuka", 0.8341046294), 20 | ("bridge_dataset", 1.0), 21 | ("taco_play", 2.0), 22 | ("jaco_play", 2.0), 23 | ("berkeley_cable_routing", 3.0), 24 | ("roboturk", 1.0), 25 | ("nyu_door_opening_surprising_effectiveness", 5.0), 26 | ("viola", 2.0), 27 | ("berkeley_autolab_ur5", 1.0), 28 | ("toto", 1.0), 29 | ] 30 | 31 | 32 | OXE_FRANKA_MIX = [ 33 | ("taco_play", 1.0), 34 | ("berkeley_cable_routing", 1.0), 35 | ("viola", 1.0), 36 | ("toto", 1.0), 37 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), 38 | ("austin_buds_dataset_converted_externally_to_rlds", 3.0), 39 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 40 | ("maniskill_dataset_converted_externally_to_rlds", 0.1), 41 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 42 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), 43 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 44 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 45 | ("berkeley_rpt_converted_externally_to_rlds", 1.0), 46 | ("kaist_nonprehensile_converted_externally_to_rlds", 3.0), 47 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 48 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 49 | ("utaustin_mutex", 1.0), 50 | # ("cmu_playing_with_food", 1.0), 51 | ("cmu_play_fusion", 1.0), 52 | ] 53 | 54 | 55 | OXE_MAGIC_SOUP = [ 56 | ("fractal20220817_data", 0.54087122203), 57 | ("kuka", 0.8341046294), 58 | ("bridge_dataset", 1.0), 59 | ("taco_play", 2.0), 60 | ("jaco_play", 1.0), 61 | ("berkeley_cable_routing", 1.0), 62 | ("roboturk", 2.0), 63 | ("nyu_door_opening_surprising_effectiveness", 1.0), 64 | ("viola", 2.0), 65 | ("berkeley_autolab_ur5", 2.0), 66 | ("toto", 1.0), 67 | ("language_table", 0.1), 68 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 69 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 70 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 71 | 
("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 72 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 73 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 74 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 75 | ("bc_z", 0.2), 76 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 77 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 78 | # ("uiuc_d3field", 1.0), --> somehow raw data is broken 79 | ("utaustin_mutex", 1.0), 80 | ("berkeley_fanuc_manipulation", 2.0), 81 | ("cmu_stretch", 1.0), 82 | ] 83 | 84 | 85 | OXE_FLEX_ACT_SOUP = [ 86 | ("fractal20220817_data", 0.54087122203), 87 | ("kuka", 0.8341046294), 88 | ("bridge_dataset", 1.0), 89 | ("taco_play", 2.0), 90 | ("jaco_play", 1.0), 91 | ("berkeley_cable_routing", 1.0), 92 | ("roboturk", 2.0), 93 | ("nyu_door_opening_surprising_effectiveness", 1.0), 94 | ("viola", 2.0), 95 | ("berkeley_autolab_ur5", 2.0), 96 | ("toto", 1.0), 97 | ("language_table", 0.1), 98 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 99 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 100 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 101 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 102 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 103 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 104 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 105 | ("bc_z", 0.2), 106 | ("berkeley_mvp_converted_externally_to_rlds", 1.0), 107 | # ("berkeley_rpt_converted_externally_to_rlds", 1.0), 108 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 109 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 110 | # ("uiuc_d3field", 1.0), --> somehow raw data is broken 111 | ("utaustin_mutex", 1.0), 112 | ("berkeley_fanuc_manipulation", 2.0), 113 | ("cmu_stretch", 1.0), 114 | ("gnm_dataset", 1.0), 115 | ("aloha_static_dataset", 3.0), 116 | # ("aloha_dagger_dataset", 1.0), 117 | ("aloha_mobile_dataset", 2.0), 118 | # ("fmb_dataset", 1.0), 119 | ("dobbe", 1.0), 120 | ("roboset", 0.5), 121 | ("rh20t", 0.5), 122 | ] 123 | 124 | 125 | OXE_FULL_MIX = [ 126 | ("fractal20220817_data", 1.0), 127 | ("kuka", 1.0), 128 | ("bridge_dataset", 1), 129 | ("taco_play", 1.0), 130 | ("jaco_play", 1.0), 131 | ("berkeley_cable_routing", 1.0), 132 | ("roboturk", 1.0), 133 | ("nyu_door_opening_surprising_effectiveness", 1.0), 134 | ("viola", 1.0), 135 | ("berkeley_autolab_ur5", 1.0), 136 | ("toto", 1.0), 137 | ("language_table", 1.0), 138 | ("columbia_cairlab_pusht_real", 1.0), 139 | ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0), 140 | ("nyu_rot_dataset_converted_externally_to_rlds", 1.0), 141 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), 142 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 143 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0), 144 | ("maniskill_dataset_converted_externally_to_rlds", 1.0), 145 | ("furniture_bench_dataset_converted_externally_to_rlds", 1.0), 146 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0), 147 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0), 148 | ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0), 149 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 150 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 151 | ("bc_z", 1.0), 152 | ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0), 153 | 
("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0), 154 | ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0), 155 | ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0), 156 | ("robo_net", 1.0), 157 | ("berkeley_mvp_converted_externally_to_rlds", 1.0), 158 | ("berkeley_rpt_converted_externally_to_rlds", 1.0), 159 | ("kaist_nonprehensile_converted_externally_to_rlds", 1.0), 160 | ("stanford_mask_vit_converted_externally_to_rlds", 1.0), 161 | ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0), 162 | ("dlr_sara_pour_converted_externally_to_rlds", 1.0), 163 | ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0), 164 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 165 | ("asu_table_top_converted_externally_to_rlds", 1.0), 166 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 167 | ("imperialcollege_sawyer_wrist_cam", 1.0), 168 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 169 | ("uiuc_d3field", 1.0), 170 | ("utaustin_mutex", 1.0), 171 | ("berkeley_fanuc_manipulation", 1.0), 172 | ("cmu_playing_with_food", 1.0), 173 | ("cmu_play_fusion", 1.0), 174 | ("cmu_stretch", 1.0), 175 | ("gnm_dataset", 1.0), 176 | ] 177 | 178 | OXE_NAMED_MIXES = { 179 | "bridge": BRIDGE_MIX, 180 | "fractal": FRACTAL_MIX, 181 | "rtx": RT_X_MIX, 182 | "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX, 183 | "oxe_magic_soup": OXE_MAGIC_SOUP, 184 | "oxe_flex_act_soup": OXE_FLEX_ACT_SOUP, 185 | "oxe_simple": OXE_SIMPLE, 186 | } 187 | -------------------------------------------------------------------------------- /src/data/oxe/preprocess/mod_functions.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import ClassVar 3 | 4 | import tensorflow as tf 5 | import tensorflow_datasets as tfds 6 | 7 | import src.data.dlimp as dl 8 | 9 | 10 | class TfdsModFunction(ABC): 11 | @classmethod 12 | @abstractmethod 13 | def mod_features( 14 | cls, 15 | features: tfds.features.FeaturesDict, 16 | ) -> tfds.features.FeaturesDict: 17 | """ 18 | Modifies the data builder feature dict to reflect feature changes of ModFunction. 19 | """ 20 | ... 21 | 22 | @classmethod 23 | @abstractmethod 24 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 25 | """ 26 | Perform arbitrary modifications on the dataset that comply with the modified feature definition. 27 | """ 28 | ... 
29 | 30 | 31 | def mod_obs_features(features, obs_feature_mod_function): 32 | """Utility function to only modify keys in observation dict.""" 33 | return tfds.features.FeaturesDict( 34 | { 35 | "steps": tfds.features.Dataset( 36 | { 37 | "observation": tfds.features.FeaturesDict( 38 | { 39 | key: obs_feature_mod_function( 40 | key, features["steps"]["observation"][key] 41 | ) 42 | for key in features["steps"]["observation"].keys() 43 | } 44 | ), 45 | **{ 46 | key: features["steps"][key] 47 | for key in features["steps"].keys() 48 | if key not in ("observation",) 49 | }, 50 | } 51 | ), 52 | **{key: features[key] for key in features.keys() if key not in ("steps",)}, 53 | } 54 | ) 55 | 56 | 57 | class ResizeAndJpegEncode(TfdsModFunction): 58 | MAX_RES: int = 224 59 | 60 | @classmethod 61 | def mod_features( 62 | cls, 63 | features: tfds.features.FeaturesDict, 64 | ) -> tfds.features.FeaturesDict: 65 | def downsize_and_jpeg(key, feat): 66 | """Downsizes image features, encodes as jpeg.""" 67 | if ( 68 | len(feat.shape) >= 2 and feat.shape[0] >= 64 and feat.shape[1] >= 64 69 | ): # is image / depth feature 70 | should_jpeg_encode = ( 71 | isinstance(feat, tfds.features.Image) and "depth" not in key 72 | ) 73 | if len(feat.shape) > 2: 74 | new_shape = ( 75 | ResizeAndJpegEncode.MAX_RES, 76 | ResizeAndJpegEncode.MAX_RES, 77 | feat.shape[2], 78 | ) 79 | else: 80 | new_shape = ( 81 | ResizeAndJpegEncode.MAX_RES, 82 | ResizeAndJpegEncode.MAX_RES, 83 | ) 84 | 85 | if isinstance(feat, tfds.features.Image): 86 | return tfds.features.Image( 87 | shape=new_shape, 88 | dtype=feat.dtype, 89 | encoding_format="jpeg" if should_jpeg_encode else "png", 90 | doc=feat.doc, 91 | ) 92 | else: 93 | return tfds.features.Tensor( 94 | shape=new_shape, 95 | dtype=feat.dtype, 96 | doc=feat.doc, 97 | ) 98 | 99 | return feat 100 | 101 | return mod_obs_features(features, downsize_and_jpeg) 102 | 103 | @classmethod 104 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 105 | def resize_image_fn(step): 106 | # resize images 107 | for key in step["observation"]: 108 | if len(step["observation"][key].shape) >= 2 and ( 109 | step["observation"][key].shape[0] >= 64 110 | or step["observation"][key].shape[1] >= 64 111 | ): 112 | size = (ResizeAndJpegEncode.MAX_RES, ResizeAndJpegEncode.MAX_RES) 113 | if "depth" in key: 114 | step["observation"][key] = tf.cast( 115 | dl.utils.resize_depth_image( 116 | tf.cast(step["observation"][key], tf.float32), size 117 | ), 118 | step["observation"][key].dtype, 119 | ) 120 | else: 121 | step["observation"][key] = tf.cast( 122 | dl.utils.resize_image(step["observation"][key], size), 123 | tf.uint8, 124 | ) 125 | return step 126 | 127 | def episode_map_fn(episode): 128 | episode["steps"] = episode["steps"].map(resize_image_fn) 129 | return episode 130 | 131 | return ds.map(episode_map_fn) 132 | 133 | 134 | class FilterSuccess(TfdsModFunction): 135 | @classmethod 136 | def mod_features( 137 | cls, 138 | features: tfds.features.FeaturesDict, 139 | ) -> tfds.features.FeaturesDict: 140 | return features # no feature changes 141 | 142 | @classmethod 143 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 144 | return ds.filter(lambda e: e["success"]) 145 | 146 | 147 | class FlipImgChannels(TfdsModFunction): 148 | FLIP_KEYS: ClassVar[list[str]] = ["image"] 149 | 150 | @classmethod 151 | def mod_features( 152 | cls, 153 | features: tfds.features.FeaturesDict, 154 | ) -> tfds.features.FeaturesDict: 155 | return features # no feature changes 156 | 157 | @classmethod 158 | def 
mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 159 | def flip(step): 160 | for key in cls.FLIP_KEYS: 161 | if key in step["observation"]: 162 | step["observation"][key] = step["observation"][key][..., ::-1] 163 | return step 164 | 165 | def episode_map_fn(episode): 166 | episode["steps"] = episode["steps"].map(flip) 167 | return episode 168 | 169 | return ds.map(episode_map_fn) 170 | 171 | 172 | class FlipWristImgChannels(FlipImgChannels): 173 | FLIP_KEYS: ClassVar[list[str]] = ["wrist_image", "hand_image"] 174 | 175 | 176 | TFDS_MOD_FUNCTIONS = { 177 | "resize_and_jpeg_encode": ResizeAndJpegEncode, 178 | "filter_success": FilterSuccess, 179 | "flip_image_channels": FlipImgChannels, 180 | "flip_wrist_image_channels": FlipWristImgChannels, 181 | } 182 | -------------------------------------------------------------------------------- /src/data/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 3 | Each function should add entries to the "task" dict. 4 | """ 5 | 6 | from typing import Optional 7 | 8 | import tensorflow as tf 9 | 10 | from src.data.utils.data_utils import tree_merge 11 | 12 | 13 | def uniform(traj: dict, max_goal_distance: Optional[int] = None) -> dict: 14 | """ 15 | Relabels with a true uniform distribution over future states. 16 | Optionally caps goal distance. 17 | """ 18 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 19 | 20 | # select a random future index for each transition i in the range [i, traj_len) 21 | rand = tf.random.uniform([traj_len]) 22 | low = tf.cast(tf.range(traj_len), tf.float32) 23 | if max_goal_distance is not None: 24 | high = tf.cast( 25 | tf.minimum(tf.range(traj_len) + max_goal_distance, traj_len), tf.float32 26 | ) 27 | else: 28 | high = tf.cast(traj_len, tf.float32) 29 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 30 | 31 | # sometimes there are floating-point errors that cause an out-of-bounds 32 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 33 | 34 | # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from 35 | # "observation" and "task" properly) 36 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 37 | traj["task"] = tree_merge(traj["task"], goal) 38 | 39 | return traj 40 | -------------------------------------------------------------------------------- /src/data/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains basic logic for randomly zero-ing out keys in the task specification. 
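For example (illustrative parameters), delete_and_rephrase(traj, paraphrases_repo, paraphrases_filename,
rephrase_prob=0.5, keep_image_prob=0.5) first swaps the language instruction for a precomputed paraphrase
with probability 0.5, then, per timestep, keeps either the goal images (with probability 0.5) or the
language instruction and zeroes out the other.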
3 | """ 4 | 5 | import pickle 6 | 7 | import tensorflow as tf 8 | from huggingface_hub import hf_hub_download 9 | 10 | from src.data.utils.data_utils import to_padding 11 | 12 | 13 | def delete_and_rephrase( 14 | traj, 15 | paraphrases_repo: str, 16 | paraphrases_filename: str, 17 | rephrase_prob: float, 18 | keep_image_prob: float, 19 | ): 20 | traj = rephrase_instruction( 21 | traj, paraphrases_repo, paraphrases_filename, rephrase_prob 22 | ) 23 | traj = delete_task_conditioning(traj, keep_image_prob) 24 | return traj 25 | 26 | 27 | class Rephraser: 28 | def create_static_hash_table(self, dictionary): 29 | """Takes a python dictionary with string keys and values and creates a tf static hash table""" 30 | keys = list(dictionary.keys()) 31 | values = list(dictionary.values()) 32 | initializer = tf.lookup.KeyValueTensorInitializer( 33 | keys, values, key_dtype=tf.string, value_dtype=tf.string 34 | ) 35 | hash_table = tf.lookup.StaticHashTable(initializer, default_value="") 36 | return hash_table 37 | 38 | def __init__(self, paraphrases_repo: str, paraphrases_filename: str): 39 | if isinstance(paraphrases_repo, str) and isinstance(paraphrases_filename, str): 40 | with open( 41 | hf_hub_download( 42 | repo_id=paraphrases_repo, 43 | filename=paraphrases_filename, 44 | repo_type="dataset", 45 | ), 46 | "rb", 47 | ) as file: 48 | lang_paraphrases = pickle.load(file) 49 | # Create StaticHashTable 50 | self.rephrase_lookup = self.create_static_hash_table(lang_paraphrases) 51 | 52 | 53 | def rephrase_instruction( 54 | traj: dict, paraphrases_repo: str, paraphrases_filename: str, rephrase_prob: float 55 | ) -> dict: 56 | """Randomly rephrases language instructions with precomputed paraphrases 57 | Args: 58 | traj: A dictionary containing trajectory data. Should have a "task" key. 59 | paraphrases_repo: The name of the HF repo containing the paraphrases file. 60 | paraphrases_filename: The name of the file containing the paraphrases. 61 | rephrase_prob: The probability of augmenting the language instruction. The probability of keeping the language 62 | instruction is 1 - rephrase_prob. 63 | """ 64 | rephraser = Rephraser(paraphrases_repo, paraphrases_filename) 65 | 66 | if "language_instruction" not in traj["task"]: 67 | return traj 68 | original_language = traj["task"]["language_instruction"] 69 | # check the language key is not empty 70 | string_is_not_empty = tf.reduce_all(tf.strings.length(original_language) > 0) 71 | # check dict is not empty 72 | dict_is_not_empty = bool(rephraser.rephrase_lookup) 73 | if dict_is_not_empty and string_is_not_empty: 74 | rephrased_instruction = rephraser.rephrase_lookup.lookup(original_language[0]) 75 | rephrased_instruction = tf.where( 76 | tf.strings.length(rephrased_instruction) > 0, 77 | original_language[0] + "." 
+ rephrased_instruction, 78 | original_language[0], 79 | ) 80 | split_tensor = tf.strings.split(rephrased_instruction, sep=".") 81 | num_strings = tf.cast(tf.shape(split_tensor)[0], tf.int32) 82 | random_index = tf.random.uniform( 83 | (tf.shape(original_language)[0],), 84 | minval=0, 85 | maxval=num_strings, 86 | dtype=tf.int32, 87 | ) 88 | sampled_language = tf.gather(split_tensor, random_index) 89 | rand = tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32) 90 | sampled_language = tf.where( 91 | rand < rephrase_prob, 92 | sampled_language, 93 | original_language, 94 | ) 95 | traj["task"]["language_instruction"] = sampled_language 96 | return traj 97 | 98 | 99 | def delete_task_conditioning( 100 | traj: dict, 101 | keep_image_prob: float, 102 | ): 103 | """ 104 | Randomly drops out either the goal images or the language instruction. Only does something if both of 105 | these are present. 106 | 107 | Args: 108 | traj: A dictionary containing trajectory data. Should have a "task" key. 109 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 110 | instruction is 1 - keep_image_prob. 111 | """ 112 | if "language_instruction" not in traj["task"]: 113 | return traj 114 | 115 | image_keys = { 116 | key 117 | for key in traj["task"].keys() 118 | if key.startswith("image_") or key.startswith("depth_") 119 | } 120 | if not image_keys: 121 | return traj 122 | 123 | traj_len = tf.shape(traj["action"])[0] 124 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 125 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 126 | 127 | for key in image_keys | {"language_instruction"}: 128 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 129 | # pad out the key 130 | traj["task"][key] = tf.where( 131 | should_keep, 132 | traj["task"][key], 133 | to_padding(traj["task"][key]), 134 | ) 135 | # zero out the pad mask dict for the key 136 | traj["task"]["pad_mask_dict"][key] = tf.where( 137 | should_keep, 138 | traj["task"]["pad_mask_dict"][key], 139 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 140 | ) 141 | 142 | # when no goal images are present, the goal timestep becomes the final timestep 143 | traj["task"]["timestep"] = tf.where( 144 | should_keep_images, 145 | traj["task"]["timestep"], 146 | traj_len - 1, 147 | ) 148 | 149 | return traj 150 | -------------------------------------------------------------------------------- /src/data/utils/text_processing.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Sequence 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" 8 | 9 | 10 | class TextProcessor(ABC): 11 | """ 12 | Base class for text tokenization or text embedding. 
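    Subclasses implement encode(); for example (an illustrative tokenizer name),
    HFTokenizer("t5-base").encode(["pick up the cube"]) returns the HuggingFace tokenizer
    output dict, while MuseEmbedding and CLIPTextProcessor return sentence embeddings and
    CLIP-style token inputs, respectively.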
13 | """ 14 | 15 | @abstractmethod 16 | def encode(self, strings: Sequence[str]): 17 | raise NotImplementedError 18 | 19 | 20 | class HFTokenizer(TextProcessor): 21 | def __init__( 22 | self, 23 | tokenizer_name: str, 24 | tokenizer_kwargs: Optional[dict] = { 25 | "max_length": 64, 26 | "padding": "max_length", 27 | "truncation": True, 28 | "return_tensors": "np", 29 | }, 30 | encode_with_model: bool = False, 31 | ): 32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 35 | self.tokenizer_kwargs = tokenizer_kwargs 36 | self.encode_with_model = encode_with_model 37 | if self.encode_with_model: 38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name) 39 | 40 | def encode(self, strings: Sequence[str]): 41 | # this creates another nested layer with "input_ids", "attention_mask", etc. 42 | inputs = self.tokenizer( 43 | strings, 44 | **self.tokenizer_kwargs, 45 | ) 46 | if self.encode_with_model: 47 | return np.array(self.model(**inputs).last_hidden_state) 48 | else: 49 | return dict(inputs) 50 | 51 | 52 | class MuseEmbedding(TextProcessor): 53 | def __init__(self): 54 | import tensorflow_hub as hub # lazy import 55 | import tensorflow_text # noqa: F401 56 | 57 | self.muse_model = hub.load(MULTI_MODULE) 58 | 59 | def encode(self, strings: Sequence[str]): 60 | with tf.device("/cpu:0"): 61 | return self.muse_model(strings).numpy() 62 | 63 | 64 | class CLIPTextProcessor(TextProcessor): 65 | def __init__( 66 | self, 67 | tokenizer_kwargs: Optional[dict] = { 68 | "max_length": 64, 69 | "padding": "max_length", 70 | "truncation": True, 71 | "return_tensors": "np", 72 | }, 73 | ): 74 | from transformers import CLIPProcessor 75 | 76 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 77 | self.kwargs = tokenizer_kwargs 78 | 79 | def encode(self, strings: Sequence[str]): 80 | inputs = self.processor( 81 | text=strings, 82 | **self.kwargs, 83 | ) 84 | inputs["position_ids"] = np.expand_dims( 85 | np.arange(inputs["input_ids"].shape[1]), axis=0 86 | ).repeat(inputs["input_ids"].shape[0], axis=0) 87 | return inputs 88 | -------------------------------------------------------------------------------- /src/model/kv_cache.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | 5 | 6 | class KVCache: 7 | def __init__(self) -> None: 8 | """list for layers""" 9 | self.key_cache: List[torch.Tensor] = [] 10 | self.value_cache: List[torch.Tensor] = [] 11 | 12 | def has_item(self, layer_idx) -> bool: 13 | return len(self.key_cache) > layer_idx 14 | 15 | def num_items(self) -> int: 16 | if len(self.key_cache) == 0: 17 | return 0 18 | else: 19 | # The shape of the key_cache is [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim] 20 | return self.key_cache[0].shape[-2] 21 | 22 | def get(self, layer_idx) -> Tuple[torch.Tensor, torch.Tensor]: 23 | return self.key_cache[layer_idx], self.value_cache[layer_idx] 24 | 25 | def update( 26 | self, 27 | key_states: torch.Tensor, 28 | value_states: torch.Tensor, 29 | layer_idx: int, 30 | ) -> Tuple[torch.Tensor, torch.Tensor]: 31 | if len(self.key_cache) <= layer_idx: 32 | # If we never added anything to the KV-Cache of this layer, let's create it. 33 | self.key_cache.append(key_states) 34 | self.value_cache.append(value_states) 35 | else: 36 | # ... otherwise we concatenate the new keys with the existing ones. 
37 | # each tensor has shape: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim] 38 | self.key_cache[layer_idx] = torch.cat( 39 | [self.key_cache[layer_idx], key_states], dim=-2 40 | ) 41 | self.value_cache[layer_idx] = torch.cat( 42 | [self.value_cache[layer_idx], value_states], dim=-2 43 | ) 44 | 45 | # ... and then we return all the existing keys + the new ones. 46 | return self.key_cache[layer_idx], self.value_cache[layer_idx] 47 | -------------------------------------------------------------------------------- /src/model/paligemma/config.py: -------------------------------------------------------------------------------- 1 | class SiglipVisionConfig: 2 | def __init__( 3 | self, 4 | hidden_size=768, 5 | intermediate_size=3072, 6 | num_hidden_layers=12, 7 | num_attention_heads=12, 8 | num_channels=3, 9 | image_size=224, 10 | patch_size=16, 11 | layer_norm_eps=1e-6, 12 | attention_dropout=0.0, 13 | num_image_tokens=None, 14 | **kwargs, 15 | ): 16 | super().__init__() 17 | self.hidden_size = hidden_size 18 | self.intermediate_size = intermediate_size 19 | self.num_hidden_layers = num_hidden_layers 20 | self.num_attention_heads = num_attention_heads 21 | self.num_channels = num_channels 22 | self.patch_size = patch_size 23 | self.image_size = image_size 24 | self.attention_dropout = attention_dropout 25 | self.layer_norm_eps = layer_norm_eps 26 | self.num_image_tokens = num_image_tokens 27 | 28 | 29 | class GemmaConfig: 30 | def __init__( 31 | self, 32 | vocab_size, 33 | hidden_size, 34 | intermediate_size, 35 | num_hidden_layers, 36 | num_attention_heads, 37 | num_key_value_heads, 38 | head_dim=256, 39 | max_position_embeddings=8192, 40 | rms_norm_eps=1e-6, 41 | rope_theta=10000.0, 42 | attention_bias=False, 43 | attention_dropout=0.0, 44 | pad_token_id=None, 45 | **kwargs, 46 | ): 47 | super().__init__() 48 | self.vocab_size = vocab_size 49 | self.max_position_embeddings = max_position_embeddings 50 | self.hidden_size = hidden_size 51 | self.intermediate_size = intermediate_size 52 | self.num_hidden_layers = num_hidden_layers 53 | self.num_attention_heads = num_attention_heads 54 | self.head_dim = head_dim 55 | self.num_key_value_heads = num_key_value_heads 56 | self.rms_norm_eps = rms_norm_eps 57 | self.rope_theta = rope_theta 58 | self.attention_bias = attention_bias 59 | self.attention_dropout = attention_dropout 60 | self.pad_token_id = pad_token_id 61 | 62 | 63 | class PaliGemmaConfig: 64 | def __init__( 65 | self, 66 | vision_config=None, 67 | text_config=None, 68 | ignore_index=-100, 69 | image_token_index=256000, 70 | vocab_size=257152, 71 | projection_dim=2048, 72 | hidden_size=2048, 73 | pad_token_id=None, 74 | **kwargs, 75 | ): 76 | super().__init__() 77 | self.ignore_index = ignore_index 78 | self.image_token_index = image_token_index 79 | self.vocab_size = vocab_size 80 | self.projection_dim = projection_dim 81 | self.hidden_size = hidden_size 82 | self.vision_config = vision_config 83 | self.is_encoder_decoder = False 84 | self.pad_token_id = pad_token_id 85 | 86 | self.vision_config = SiglipVisionConfig(**vision_config) 87 | self.text_config = text_config 88 | 89 | self.text_config = GemmaConfig(**text_config, pad_token_id=pad_token_id) 90 | self.vocab_size = self.text_config.vocab_size 91 | 92 | self.text_config.num_image_tokens = ( 93 | self.vision_config.image_size // self.vision_config.patch_size 94 | ) ** 2 95 | self.vision_config.projection_dim = projection_dim 96 | -------------------------------------------------------------------------------- 
/src/model/paligemma/load.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | 5 | from safetensors import safe_open 6 | from transformers import AutoTokenizer 7 | 8 | from src.model.paligemma.config import PaliGemmaConfig 9 | from src.model.paligemma.gemma import PaliGemmaForConditionalGeneration 10 | 11 | 12 | def load_hf_model( 13 | model_path: str, 14 | device: str, 15 | quantize: bool = False, 16 | ): 17 | if quantize: 18 | print("Running qunatized model") 19 | 20 | # Load the tokenizer 21 | tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") 22 | assert tokenizer.padding_side == "right" 23 | 24 | # Find all the *.safetensors files 25 | safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors")) 26 | 27 | # ... and load them one by one in the tensors dictionary 28 | tensors = {} 29 | for safetensors_file in safetensors_files: 30 | with safe_open(safetensors_file, framework="pt", device="cpu") as f: 31 | for key in f.keys(): 32 | tensors[key] = f.get_tensor(key) 33 | 34 | # Load the model's config 35 | with open(os.path.join(model_path, "config.json"), "r") as f: 36 | model_config_file = json.load(f) 37 | config = PaliGemmaConfig(**model_config_file) 38 | 39 | # Create the model using the configuration 40 | model = PaliGemmaForConditionalGeneration(config, use_quantize=quantize) 41 | 42 | # Load the state dict of the model 43 | model.load_state_dict(tensors, strict=False) 44 | 45 | # Move the model to the device --- quantization happens if the model is quantized 46 | model = model.to(device) 47 | 48 | # Tie weights 49 | model.tie_weights() 50 | 51 | return (model, tokenizer) 52 | -------------------------------------------------------------------------------- /src/model/paligemma/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from src.model.lora import get_layer 5 | 6 | 7 | class GemmaRMSNorm(nn.Module): 8 | def __init__(self, dim: int, eps: float = 1e-6): 9 | super().__init__() 10 | self.eps = eps 11 | self.weight = nn.Parameter(torch.zeros(dim)) 12 | 13 | def _norm(self, x): 14 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 15 | 16 | def forward(self, x): 17 | output = self._norm(x.float()) 18 | # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) 19 | # See https://github.com/huggingface/transformers/pull/29402 20 | output = output * (1.0 + self.weight.float()) 21 | return output.type_as(x) 22 | 23 | 24 | class GemmaRotaryEmbedding(nn.Module): 25 | """ 26 | forces RoPE to use float32 for full accuracy 27 | 28 | https://github.com/huggingface/transformers/pull/29402 29 | https://github.com/huggingface/transformers/pull/29285 30 | """ 31 | 32 | def __init__(self, dim, base=10000): 33 | super().__init__() 34 | 35 | self.dim = dim # it is set to the head_dim 36 | self.base = ( 37 | base # should be tuned based on the max_seq_len, e.g., in action expert 38 | ) 39 | 40 | # Calculate the theta according to the formula theta_i = base^(2i/dim) where i = 0, 1, 2, ..., dim // 2 41 | inv_freq = 1.0 / ( 42 | self.base 43 | ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) 44 | ) 45 | self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) 46 | 47 | @torch.no_grad() 48 | def forward(self, x, position_ids): 49 | # x: [bs, num_attention_heads, seq_len, head_size] 50 | # Copy the inv_freq tensor for batch in the 
51 |         # inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
52 |         inv_freq_expanded = (
53 |             self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
54 |         )
55 |         # position_ids_expanded: [Batch_Size, 1, Seq_Len]
56 |         position_ids_expanded = position_ids[:, None, :].float()
57 |         # Multiply each theta by the position (which is the argument of the sin and cos functions)
58 |         # freqs: [Batch_Size, Head_Dim // 2, 1] @ [Batch_Size, 1, Seq_Len] --> [Batch_Size, Seq_Len, Head_Dim // 2]
59 |         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
60 |             1, 2
61 |         )
62 |         # emb: [Batch_Size, Seq_Len, Head_Dim]
63 |         emb = torch.cat((freqs, freqs), dim=-1)
64 |         # cos, sin: [Batch_Size, Seq_Len, Head_Dim]
65 |         cos = emb.cos()
66 |         sin = emb.sin()
67 |         return cos.to(x.dtype), sin.to(x.dtype)
68 |
69 |
70 | class GemmaMLP(nn.Module):
71 |     def __init__(self, config, use_quantize=False, use_lora=False):
72 |         super().__init__()
73 |         self.config = config
74 |         self.hidden_size = config.hidden_size
75 |         self.intermediate_size = config.intermediate_size
76 |
77 |         layer = get_layer(
78 |             use_quantize,
79 |             use_lora,
80 |             **config.lora if use_lora else {},
81 |         )
82 |         self.gate_proj = layer(self.hidden_size, self.intermediate_size, bias=False)
83 |         self.up_proj = layer(self.hidden_size, self.intermediate_size, bias=False)
84 |         self.down_proj = layer(self.intermediate_size, self.hidden_size, bias=False)
85 |
86 |     def forward(self, x):
87 |         # Equivalent to:
88 |         # y = self.gate_proj(x)  # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
89 |         # y = nn.functional.gelu(y, approximate="tanh")  # [Batch_Size, Seq_Len, Intermediate_Size]
90 |         # j = self.up_proj(x)  # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
91 |         # z = y * j  # [Batch_Size, Seq_Len, Intermediate_Size]
92 |         # z = self.down_proj(z)  # [Batch_Size, Seq_Len, Intermediate_Size] -> [Batch_Size, Seq_Len, Hidden_Size]
93 |         return self.down_proj(
94 |             nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
95 |         )
96 |
--------------------------------------------------------------------------------
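GemmaMLP is the GeGLU-style feed-forward used in Gemma: the gate and up projections run in parallel and are multiplied elementwise before being projected back down. A minimal sketch of the same computation with plain nn.Linear layers (the sizes are assumed, and the quantization/LoRA handling done by get_layer is left out):

import torch
import torch.nn as nn

hidden, intermediate = 2048, 16384  # assumed sizes
gate_proj = nn.Linear(hidden, intermediate, bias=False)
up_proj = nn.Linear(hidden, intermediate, bias=False)
down_proj = nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(1, 7, hidden)  # [Batch_Size, Seq_Len, Hidden_Size]
y = down_proj(nn.functional.gelu(gate_proj(x), approximate="tanh") * up_proj(x))
print(y.shape)  # torch.Size([1, 7, 2048])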
/src/model/paligemma/processing.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, List, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 |
7 | IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
8 | IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
9 |
10 |
11 | def add_image_tokens_to_prompt(
12 |     prefix_prompt,
13 |     bos_token,
14 |     image_seq_len,
15 |     image_token,
16 | ):
17 |     # Quoting from the blog (https://huggingface.co/blog/paligemma#detailed-inference-process):
18 |     # The input text is tokenized normally.
19 |     # A <bos> token is added at the beginning, and an additional newline token (\n) is appended.
20 |     # This newline token is an essential part of the input prompt the model was trained with, so adding it explicitly ensures it's always there.
21 |     # The tokenized text is also prefixed with a fixed number of <image> tokens.
22 |     # NOTE: from the paper it looks like the `\n` should be tokenized separately, but in the HF implementation this is not done.
23 |     # ref to HF implementation: https://github.com/huggingface/transformers/blob/7f79a97399bb52aad8460e1da2f36577d5dccfed/src/transformers/models/paligemma/processing_paligemma.py#L55-L73
24 |     return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
25 |
26 |
27 | def rescale(
28 |     image: np.ndarray,
29 |     scale: float,
30 |     dtype: np.dtype = np.float32,
31 | ) -> np.ndarray:
32 |     rescaled_image = image * scale
33 |     rescaled_image = rescaled_image.astype(dtype)
34 |     return rescaled_image
35 |
36 |
37 | def resize(
38 |     image: Image.Image,
39 |     size: Tuple[int, int],
40 |     resample: Image.Resampling = None,
41 |     reducing_gap: Optional[int] = None,
42 | ) -> np.ndarray:
43 |     height, width = size
44 |     resized_image = image.resize(
45 |         (width, height), resample=resample, reducing_gap=reducing_gap
46 |     )
47 |     return resized_image
48 |
49 |
50 | def normalize(
51 |     image: np.ndarray,
52 |     mean: Union[float, Iterable[float]],
53 |     std: Union[float, Iterable[float]],
54 | ) -> np.ndarray:
55 |     mean = np.array(mean, dtype=image.dtype)
56 |     std = np.array(std, dtype=image.dtype)
57 |     image = (image - mean) / std
58 |     return image
59 |
60 |
61 | def process_images(
62 |     images: List[Image.Image],
63 |     size: Tuple[int, int],
64 |     resample: Image.Resampling,
65 |     rescale_factor: float,
66 |     image_mean: Optional[Union[float, List[float]]] = None,
67 |     image_std: Optional[Union[float, List[float]]] = None,
68 | ) -> List[np.ndarray]:
69 |     height, width = size[0], size[1]
70 |     images = [
71 |         resize(image=image, size=(height, width), resample=resample) for image in images
72 |     ]
73 |     # Convert each image to a numpy array
74 |     images = [np.array(image) for image in images]
75 |     # Rescale the pixel values to be in the range [0, 1]
76 |     images = [rescale(image, scale=rescale_factor) for image in images]
77 |     # Normalize the images to have mean 0 and standard deviation 1
78 |     images = [normalize(image, mean=image_mean, std=image_std) for image in images]
79 |     # Move the channel dimension to the first dimension. The model expects images in the format [Channel, Height, Width]
80 |     images = [image.transpose(2, 0, 1) for image in images]
81 |     return images
82 |
83 |
84 | class PaliGemmaProcessor:
85 |     IMAGE_TOKEN = "<image>"
86 |
87 |     def __init__(
88 |         self,
89 |         tokenizer,
90 |         num_image_tokens: int,
91 |         image_size: int,
92 |     ):
93 |         super().__init__()
94 |
95 |         self.image_seq_length = num_image_tokens
96 |         self.image_size = image_size
97 |
98 |         # Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
99 |         tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
100 |         tokenizer.add_special_tokens(tokens_to_add)
101 |         EXTRA_TOKENS = [
102 |             f"<loc{i:04d}>" for i in range(1024)
103 |         ]  # These tokens are used for object detection (bounding boxes)
104 |         EXTRA_TOKENS += [
105 |             f"<seg{i:03d}>" for i in range(128)
106 |         ]  # These tokens are used for object segmentation
107 |         tokenizer.add_tokens(EXTRA_TOKENS)
108 |         self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
109 |         # We will add the BOS and EOS tokens ourselves
110 |         tokenizer.add_bos_token = False
111 |         tokenizer.add_eos_token = False
112 |
113 |         self.tokenizer = tokenizer
114 |
115 |     def __call__(
116 |         self,
117 |         text: List[str],
118 |         images: List[Image.Image],
119 |         padding: str = "longest",
120 |         truncation: bool = True,
121 |     ) -> dict:
122 |         assert (
123 |             len(images) == 1 and len(text) == 1
124 |         ), f"Received {len(images)} images for {len(text)} prompts."
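        # (Illustrative, assuming num_image_tokens = 256 and image_size = 224 as in the
        #  PaliGemma-224 checkpoints: a single prompt and a single PIL image yield
        #  pixel_values of shape [1, 3, 224, 224] and input_ids that start with 256
        #  <image> tokens, followed by <bos>, the prompt tokens, and a trailing "\n".)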
125 | 126 | pixel_values = process_images( 127 | images, 128 | size=(self.image_size, self.image_size), 129 | resample=Image.Resampling.BICUBIC, 130 | rescale_factor=1 / 255.0, 131 | image_mean=IMAGENET_STANDARD_MEAN, 132 | image_std=IMAGENET_STANDARD_STD, 133 | ) 134 | # Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width] 135 | pixel_values = np.stack(pixel_values, axis=0) 136 | # Convert the numpy array to a PyTorch tensor 137 | pixel_values = torch.tensor(pixel_values) 138 | 139 | # Prepend a `self.image_seq_length` number of image tokens to the prompt 140 | input_strings = [ 141 | add_image_tokens_to_prompt( 142 | prefix_prompt=prompt, 143 | bos_token=self.tokenizer.bos_token, 144 | image_seq_len=self.image_seq_length, 145 | image_token=self.IMAGE_TOKEN, 146 | ) 147 | for prompt in text 148 | ] 149 | 150 | # Returns the input_ids and attention_mask as PyTorch tensors 151 | inputs = self.tokenizer( 152 | input_strings, 153 | return_tensors="pt", 154 | padding=padding, 155 | truncation=truncation, 156 | ) 157 | output = {"pixel_values": pixel_values, **inputs} 158 | return output 159 | -------------------------------------------------------------------------------- /src/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rotate_half(x): 5 | # Build the [-x2, x1, -x4, x3, ...] tensor for the sin part of the positional encoding. 6 | x1 = x[..., : x.shape[-1] // 2] # Takes the first half of the last dimension 7 | x2 = x[..., x.shape[-1] // 2 :] # Takes the second half of the last dimension 8 | return torch.cat((-x2, x1), dim=-1) 9 | 10 | 11 | def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): 12 | cos = cos.unsqueeze(unsqueeze_dim) # Add the head dimension 13 | sin = sin.unsqueeze(unsqueeze_dim) # Add the head dimension 14 | # Apply the formula (34) of the Rotary Positional Encoding paper. 
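#     (In this half-split convention each feature x_i is paired with x_{i + d/2} and the
#      pair is rotated by the angle m * theta_i, i.e.
#      x_rotated = x * cos(m * theta) + rotate_half(x) * sin(m * theta).)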
15 | x = (x * cos) + (rotate_half(x) * sin) 16 | return x 17 | 18 | 19 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 20 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 21 | if n_rep == 1: 22 | return hidden_states 23 | hidden_states = hidden_states[:, :, None, :, :].expand( 24 | batch, num_key_value_heads, n_rep, slen, head_dim 25 | ) 26 | return hidden_states.reshape( 27 | batch, 28 | num_key_value_heads * n_rep, 29 | slen, 30 | head_dim, 31 | ) 32 | -------------------------------------------------------------------------------- /src/model/vla/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | 8 | 9 | class SinusoidalPosEmb(nn.Module): 10 | def __init__(self, dim: int, max_period: float = 10000.0): 11 | super().__init__() 12 | self.half_dim = dim // 2 13 | self.max_period = max_period 14 | 15 | def forward(self, t: torch.FloatTensor) -> torch.FloatTensor: 16 | emb = math.log(self.max_period) / (self.half_dim - 1) 17 | emb = torch.exp( 18 | torch.arange(self.half_dim, device=t.device, dtype=t.dtype) * -emb 19 | ) 20 | emb = t[:, None] * emb[None, :] 21 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 22 | return emb 23 | 24 | 25 | class ActionEncoder(nn.Module): 26 | """Matching pi0 appendix""" 27 | 28 | def __init__(self, action_dim: int, width: int, time_cond: bool = False): 29 | super().__init__() 30 | self.linear_1 = nn.Linear(action_dim, width) 31 | if time_cond: 32 | self.linear_2 = nn.Linear(2 * width, width) 33 | else: 34 | self.linear_2 = nn.Linear(width, width) 35 | self.nonlinearity = nn.SiLU() # swish 36 | self.linear_3 = nn.Linear(width, width) 37 | self.time_cond = time_cond 38 | 39 | def forward( 40 | self, 41 | action: torch.FloatTensor, 42 | time_emb: Optional[torch.FloatTensor] = None, 43 | ) -> torch.FloatTensor: 44 | # [Batch_Size, Seq_Len, Width] 45 | emb = self.linear_1(action) 46 | if self.time_cond: 47 | # repeat time embedding for seq_len 48 | # [Batch_Size, Seq_Len, Width] 49 | time_emb_full = time_emb.unsqueeze(1).expand(-1, action.size(1), -1) 50 | emb = torch.cat([time_emb_full, emb], dim=-1) 51 | emb = self.nonlinearity(self.linear_2(emb)) 52 | emb = self.linear_3(emb) 53 | return emb 54 | 55 | 56 | class GaussianFourierFeatureTransform(torch.nn.Module): 57 | """ 58 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": 59 | https://arxiv.org/abs/2006.10739 60 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 61 | """ 62 | 63 | def __init__( 64 | self, 65 | input_dim, 66 | embed_dim=256, 67 | scale=10, 68 | ): 69 | super(GaussianFourierFeatureTransform, self).__init__() 70 | self.b = torch.randn(input_dim, embed_dim) * scale 71 | self.pi = 3.14159265359 72 | 73 | def forward(self, v: torch.FloatTensor) -> torch.FloatTensor: 74 | x_proj = torch.matmul(2 * self.pi * v, self.b.to(v.device).to(v.dtype)) 75 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], -1) 76 | 77 | 78 | class AdaptiveRMSNorm(nn.Module): 79 | def __init__(self, dim: int, dim_cond: int, eps: float = 1e-6): 80 | super().__init__() 81 | self.eps = eps 82 | self.to_gamma = nn.Sequential( 83 | nn.Linear(dim_cond, dim), 84 | nn.Sigmoid(), 85 | ) 86 | self.to_beta = nn.Linear(dim_cond, dim, bias=False) 87 | 88 | def _norm(self, x: torch.FloatTensor) -> torch.FloatTensor: 89 | return x * torch.rsqrt(x.pow(2).mean(-1, 
keepdim=True) + self.eps)
90 |
91 |     def forward(
92 |         self, x: torch.FloatTensor, cond: torch.FloatTensor
93 |     ) -> torch.FloatTensor:
94 |         output = self._norm(x)
95 |         if cond.ndim == 2:
96 |             cond = rearrange(cond, "b d -> b 1 d")
97 |         gamma = self.to_gamma(cond)
98 |         beta = self.to_beta(cond)
99 |         return output * gamma + beta
100 |
101 |
102 | class AdaptiveLayerscale(nn.Module):
103 |     def __init__(
104 |         self, dim: int, dim_cond: int, adaln_zero_bias_init_value: float = -2.0
105 |     ):
106 |         super().__init__()
107 |         adaln_zero_gamma_linear = nn.Linear(dim_cond, dim)
108 |         nn.init.zeros_(adaln_zero_gamma_linear.weight)
109 |         nn.init.constant_(adaln_zero_gamma_linear.bias, adaln_zero_bias_init_value)
110 |
111 |         self.to_adaln_zero_gamma = adaln_zero_gamma_linear
112 |
113 |     def forward(
114 |         self, x: torch.FloatTensor, cond: torch.FloatTensor
115 |     ) -> torch.FloatTensor:
116 |         if cond.ndim == 2:
117 |             cond = rearrange(cond, "b d -> b 1 d")
118 |         gamma = self.to_adaln_zero_gamma(cond)
119 |         return x * gamma.sigmoid()
120 |
--------------------------------------------------------------------------------
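A small sketch of how these pieces can fit together: a scalar time t is embedded, fused with the action chunk by ActionEncoder, and the adaptive norm/scale layers modulate features with that time embedding. The width, horizon, and action dimension below are illustrative, not the repo's configured values:

import torch

from src.model.vla.modules import (
    ActionEncoder,
    AdaptiveLayerscale,
    AdaptiveRMSNorm,
    SinusoidalPosEmb,
)

width, action_dim, horizon = 1024, 7, 4  # assumed sizes
time_emb = SinusoidalPosEmb(width)(torch.rand(2))  # [2, width]
action = torch.randn(2, horizon, action_dim)
act_emb = ActionEncoder(action_dim, width, time_cond=True)(action, time_emb)  # [2, horizon, width]

# adaLN-style modulation with the time embedding as the condition
norm = AdaptiveRMSNorm(dim=width, dim_cond=width)
scale = AdaptiveLayerscale(dim=width, dim_cond=width)
out = scale(norm(act_emb, cond=time_emb), cond=time_emb)  # [2, horizon, width]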
/src/model/vla/processing.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 |
5 | IMAGENET_STANDARD_MEAN = torch.tensor([0.5, 0.5, 0.5])
6 | IMAGENET_STANDARD_STD = torch.tensor([0.5, 0.5, 0.5])
7 |
8 |
9 | def add_image_tokens_to_prompt(
10 |     prefix_prompt,
11 |     bos_token,
12 |     image_seq_len,
13 |     image_token,
14 | ):
15 |     # Quoting from the blog (https://huggingface.co/blog/paligemma#detailed-inference-process):
16 |     # The input text is tokenized normally.
17 |     # A <bos> token is added at the beginning, and an additional newline token (\n) is appended.
18 |     # This newline token is an essential part of the input prompt the model was trained with, so adding it explicitly ensures it's always there.
19 |     # The tokenized text is also prefixed with a fixed number of <image> tokens.
20 |     # NOTE: from the paper it looks like the `\n` should be tokenized separately, but in the HF implementation this is not done.
21 |     # ref to HF implementation: https://github.com/huggingface/transformers/blob/7f79a97399bb52aad8460e1da2f36577d5dccfed/src/transformers/models/paligemma/processing_paligemma.py#L55-L73
22 |     return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
23 |
24 |
25 | def rescale(
26 |     image: torch.LongTensor,
27 |     scale: float,
28 | ) -> torch.FloatTensor:
29 |     rescaled_image = image * scale
30 |     return rescaled_image
31 |
32 |
33 | def normalize(
34 |     image: torch.LongTensor,
35 |     mean: torch.FloatTensor,
36 |     std: torch.FloatTensor,
37 | ) -> torch.FloatTensor:
38 |     assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor."
39 |     assert (
40 |         image.shape[1] == 3
41 |     ), f"Expected 3 channels at axis 1, got {image.shape[1]} channels."
42 |     mean = mean[None, :, None, None]  # add batch and spatial dimensions
43 |     std = std[None, :, None, None]
44 |     image = (image - mean) / std
45 |     return image
46 |
47 |
48 | def process_images(
49 |     images: torch.LongTensor,
50 |     rescale_factor: float,
51 |     image_mean: torch.FloatTensor,
52 |     image_std: torch.FloatTensor,
53 | ) -> torch.FloatTensor:
54 |     # Rescale the pixel values to be in the range [0, 1]
55 |     images = rescale(images, scale=rescale_factor)
56 |
57 |     # Normalize the images to have mean 0 and standard deviation 1
58 |     images = normalize(images, mean=image_mean, std=image_std)
59 |
60 |     return images
61 |
62 |
63 | class VLAProcessor:
64 |     IMAGE_TOKEN = "<image>"
65 |
66 |     def __init__(
67 |         self,
68 |         tokenizer,
69 |         num_image_tokens: int,
70 |         max_seq_len: int,
71 |         tokenizer_padding: str = "max_length",  # instead of padding to the longest sequence in the batch
72 |     ):
73 |         super().__init__()
74 |
75 |         self.image_seq_length = num_image_tokens
76 |         self.max_seq_len = max_seq_len
77 |         self.tokenizer_padding = tokenizer_padding
78 |
79 |         # Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
80 |         tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
81 |         tokenizer.add_special_tokens(tokens_to_add)
82 |         EXTRA_TOKENS = [
83 |             f"<loc{i:04d}>" for i in range(1024)
84 |         ]  # These tokens are used for object detection (bounding boxes)
85 |         EXTRA_TOKENS += [
86 |             f"<seg{i:03d}>" for i in range(128)
87 |         ]  # These tokens are used for object segmentation
88 |         tokenizer.add_tokens(EXTRA_TOKENS)
89 |         self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
90 |         # We will add the BOS and EOS tokens ourselves
91 |         tokenizer.add_bos_token = False
92 |         tokenizer.add_eos_token = False
93 |
94 |         self.tokenizer = tokenizer
95 |
96 |     def __call__(
97 |         self,
98 |         text: List[str],
99 |         images: torch.LongTensor,
100 |         truncation: bool = True,
101 |     ) -> dict:
102 |         assert len(images) == len(
103 |             text
104 |         ), f"Received {len(images)} images for {len(text)} prompts."
105 |         assert (
106 |             images.dtype == torch.uint8
107 |         ), f"Expected uint8 tensor for images, got {images.dtype}."
108 |
109 |         pixel_values = process_images(
110 |             images,
111 |             rescale_factor=1 / 255.0,
112 |             image_mean=IMAGENET_STANDARD_MEAN,
113 |             image_std=IMAGENET_STANDARD_STD,
114 |         )
115 |
116 |         # Prepend a `self.image_seq_length` number of image tokens to the prompt
117 |         input_strings = [
118 |             add_image_tokens_to_prompt(
119 |                 prefix_prompt=prompt,
120 |                 bos_token=self.tokenizer.bos_token,
121 |                 image_seq_len=self.image_seq_length,
122 |                 image_token=self.IMAGE_TOKEN,
123 |             )
124 |             for prompt in text
125 |         ]
126 |
127 |         # Returns the input_ids and attention_mask as PyTorch tensors
128 |         inputs = self.tokenizer(
129 |             input_strings,
130 |             return_tensors="pt",
131 |             max_length=self.max_seq_len,
132 |             padding=self.tokenizer_padding,
133 |             truncation=truncation,
134 |         )
135 |         output = {"pixel_values": pixel_values, **inputs}
136 |         return output
137 |
--------------------------------------------------------------------------------
/src/utils/decorator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def conditional_decorator(dec, condition):
5 |     def decorator(func):
6 |         if not condition:
7 |             # Return the function unchanged, not decorated.
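#             (When `condition` is False the decorator is a no-op; when it is True,
#              `dec` is applied -- e.g., illustratively,
#              @conditional_decorator(torch.no_grad(), condition=skip_grad)
#              wraps the function in torch.no_grad() only if `skip_grad` is True.)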
8 | return func 9 | return dec(func) 10 | 11 | return decorator 12 | 13 | 14 | class NoSyncBase: 15 | def no_sync(self): 16 | if self.use_ddp: 17 | # If DDP is used, call the actual `no_sync` method 18 | return torch.nn.parallel.DistributedDataParallel.no_sync(self) 19 | else: 20 | # Otherwise, return the dummy context manager 21 | class DummyContext: 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, exc_type, exc_value, traceback): 26 | pass 27 | 28 | return DummyContext() 29 | 30 | 31 | def main_rank_only(func): 32 | def wrapper(*args, **kwargs): 33 | if not kwargs.get("main_rank", False): 34 | return None 35 | return func(*args, **kwargs) 36 | 37 | return wrapper 38 | -------------------------------------------------------------------------------- /src/utils/metric.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def get_action_accuracy( 7 | gt: torch.FloatTensor, # [Batch_Size, Horizon, Action_Dim] 8 | pred: torch.FloatTensor, 9 | thresholds: List[float] = [0.1, 0.2], 10 | ) -> torch.FloatTensor: 11 | device = gt.device 12 | diff = torch.abs(gt - pred).reshape(-1, gt.shape[-1]) 13 | 14 | # get the percentage of diff lower than threshold for all action dimensions 15 | accuracies = torch.zeros(len(thresholds), device=device) 16 | for idx, threshold in enumerate(thresholds): 17 | accuracy = torch.mean( 18 | (torch.mean((diff < threshold).float(), dim=1) >= 1.0).float() 19 | ) 20 | accuracies[idx] = accuracy 21 | return accuracies 22 | -------------------------------------------------------------------------------- /src/utils/monitor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import time 4 | 5 | import torch 6 | 7 | 8 | def log_allocated_gpu_memory(log=None, stage="loading model", device=0): 9 | if torch.cuda.is_available(): 10 | allocated_memory = torch.cuda.memory_allocated(device) 11 | msg = f"Allocated GPU memory after {stage}: {allocated_memory/1024/1024/1024:.2f} GB" 12 | print(msg) if log is None else log.info(msg) 13 | 14 | 15 | def log_execution_time(logger=None): 16 | """Decorator to log the execution time of a function""" 17 | 18 | def decorator(func): 19 | @functools.wraps(func) 20 | def wrapper(*args, **kwargs): 21 | start_time = time.time() 22 | result = func(*args, **kwargs) 23 | end_time = time.time() 24 | elapsed_time = end_time - start_time 25 | if logger is None: 26 | print(f"{func.__name__} took {elapsed_time:.2f} seconds to execute.") 27 | else: 28 | logger.info( 29 | f"{func.__name__} took {elapsed_time:.2f} seconds to execute." 
30 | ) 31 | return result 32 | 33 | return wrapper 34 | 35 | return decorator 36 | 37 | 38 | class Timer: 39 | def __init__(self): 40 | self._start = time.time() 41 | 42 | def __call__(self, reset=True): 43 | now = time.time() 44 | diff = now - self._start 45 | if reset: 46 | self._start = now 47 | return diff 48 | 49 | 50 | # Filter to log only on the main rank 51 | class MainRankFilter(logging.Filter): 52 | def __init__(self, main_rank): 53 | super().__init__() 54 | self.main_rank = main_rank 55 | 56 | def filter(self, record): 57 | # Only log if this is the main rank 58 | return self.main_rank 59 | -------------------------------------------------------------------------------- /src/utils/optim.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2022 Naoki Katsura 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | # Modified from https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup. Not inherited from _LRScheduler so compatible with offloaded optimizer 26 | import math 27 | 28 | import torch 29 | 30 | 31 | class CosineAnnealingWarmupRestarts: 32 | """ 33 | optimizer (Optimizer): Wrapped optimizer. 34 | first_cycle_steps (int): First cycle step size. 35 | cycle_mult(float): Cycle steps magnification. Default: 1. 36 | max_lr(float): First cycle's max learning rate. Default: 0.1. 37 | min_lr(float): Min learning rate. Default: 0.001. 38 | warmup_steps(int): Linear warmup step size. Default: 0. 39 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 40 | last_epoch (int): The index of last epoch. Default: -1. 
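    Example (illustrative hyperparameters):
        scheduler = CosineAnnealingWarmupRestarts(
            optimizer, first_cycle_steps=10_000, max_lr=1e-4, min_lr=1e-6, warmup_steps=500
        )
        for _ in range(num_train_steps):  # num_train_steps is a placeholder
            ...  # forward/backward
            optimizer.step()
            scheduler.step()  # call once per optimizer step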
41 | """ 42 | 43 | def __init__( 44 | self, 45 | optimizer: torch.optim.Optimizer, 46 | first_cycle_steps: int, 47 | cycle_mult: float = 1.0, 48 | max_lr: float = 0.1, 49 | min_lr: float = 0.001, 50 | warmup_steps: int = 0, 51 | gamma: float = 1.0, 52 | last_epoch: int = -1, 53 | ): 54 | assert warmup_steps < first_cycle_steps 55 | 56 | self.first_cycle_steps = first_cycle_steps # first cycle step size 57 | self.cycle_mult = cycle_mult # cycle steps magnification 58 | self.base_max_lr = max_lr # first max learning rate 59 | self.max_lr = max_lr # max learning rate in the current cycle 60 | self.min_lr = min_lr # min learning rate 61 | self.warmup_steps = warmup_steps # warmup step size 62 | self.gamma = gamma # decrease rate of max learning rate by cycle 63 | 64 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 65 | self.cycle = 0 # cycle count 66 | self.step_in_cycle = last_epoch # step size of the current cycle 67 | self.last_epoch = last_epoch 68 | 69 | self.optimizer = optimizer 70 | self.last_epoch = last_epoch 71 | 72 | # set learning rate min_lr 73 | self.init_lr() 74 | 75 | def state_dict(self): 76 | return { 77 | key: value for key, value in self.__dict__.items() if key != "optimizer" 78 | } 79 | 80 | def load_state_dict(self, state_dict): 81 | self.__dict__.update(state_dict) 82 | 83 | def init_lr(self): 84 | self.base_lrs = [] 85 | for param_group in self.optimizer.param_groups: 86 | param_group["lr"] = self.min_lr 87 | self.base_lrs.append(self.min_lr) 88 | 89 | def get_lr(self): 90 | if self.step_in_cycle == -1: 91 | return self.base_lrs 92 | elif self.step_in_cycle < self.warmup_steps: 93 | return [ 94 | (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps 95 | + base_lr 96 | for base_lr in self.base_lrs 97 | ] 98 | else: 99 | return [ 100 | base_lr 101 | + (self.max_lr - base_lr) 102 | * ( 103 | 1 104 | + math.cos( 105 | math.pi 106 | * (self.step_in_cycle - self.warmup_steps) 107 | / (self.cur_cycle_steps - self.warmup_steps) 108 | ) 109 | ) 110 | / 2 111 | for base_lr in self.base_lrs 112 | ] 113 | 114 | def step(self, epoch=None): 115 | if epoch is None: 116 | epoch = self.last_epoch + 1 117 | self.step_in_cycle = self.step_in_cycle + 1 118 | if self.step_in_cycle >= self.cur_cycle_steps: 119 | self.cycle += 1 120 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 121 | self.cur_cycle_steps = ( 122 | int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) 123 | + self.warmup_steps 124 | ) 125 | else: 126 | if epoch >= self.first_cycle_steps: 127 | if self.cycle_mult == 1.0: 128 | self.step_in_cycle = epoch % self.first_cycle_steps 129 | self.cycle = epoch // self.first_cycle_steps 130 | else: 131 | n = int( 132 | math.log( 133 | ( 134 | epoch / self.first_cycle_steps * (self.cycle_mult - 1) 135 | + 1 136 | ), 137 | self.cycle_mult, 138 | ) 139 | ) 140 | self.cycle = n 141 | self.step_in_cycle = epoch - int( 142 | self.first_cycle_steps 143 | * (self.cycle_mult**n - 1) 144 | / (self.cycle_mult - 1) 145 | ) 146 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** ( 147 | n 148 | ) 149 | else: 150 | self.cur_cycle_steps = self.first_cycle_steps 151 | self.step_in_cycle = epoch 152 | 153 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 154 | self.last_epoch = math.floor(epoch) 155 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 156 | if isinstance(param_group["lr"], torch.Tensor): 157 | param_group["lr"].fill_(lr) 158 | else: 159 | param_group["lr"] = lr 160 | 161 | 162 | 
def get_num_params_in_billions(optimizer): 163 | return ( 164 | sum(p.numel() for group in optimizer.param_groups for p in group["params"]) 165 | / 1e9 166 | ) 167 | 168 | 169 | def optimizer_to(optim, device): 170 | for param in optim.state.values(): 171 | # Not sure there are any global tensors in the state dict 172 | if isinstance(param, torch.Tensor): 173 | param.data = param.data.to(device) 174 | if param._grad is not None: 175 | param._grad.data = param._grad.data.to(device) 176 | elif isinstance(param, dict): 177 | for subparam in param.values(): 178 | if isinstance(subparam, torch.Tensor): 179 | subparam.data = subparam.data.to(device) 180 | if subparam._grad is not None: 181 | subparam._grad.data = subparam._grad.data.to(device) 182 | -------------------------------------------------------------------------------- /src/utils/spec.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from functools import partial 3 | from typing import Any, Dict, Tuple, TypedDict, Union 4 | 5 | 6 | class ModuleSpec(TypedDict): 7 | """A JSON-serializable representation of a function or class with some default args and kwargs to pass to it. Useful for specifying a particular class or function in a config file, while keeping it serializable and overridable from the command line using ml_collections. 8 | 9 | Usage: 10 | 11 | # Preferred way to create a spec: 12 | >>> from src.model.components.transformer import Transformer 13 | >>> spec = ModuleSpec.create(Transformer, num_layers=3) 14 | # Same as above using the fully qualified import string: 15 | >>> spec = ModuleSpec.create("src.model.components.transformer:Transformer", num_layers=3) 16 | 17 | # Usage: 18 | >>> ModuleSpec.instantiate(spec) == partial(Transformer, num_layers=3) 19 | # can pass additional kwargs at instantiation time 20 | >>> transformer = ModuleSpec.instantiate(spec, num_heads=8) 21 | 22 | Note: ModuleSpec is just an alias for a dictionary (that is strongly typed), not a real class. So from 23 | your code's perspective, it is just a dictionary. 24 | 25 | module (str): The module the callable is located in 26 | name (str): The name of the callable in the module 27 | args (tuple): The args to pass to the callable 28 | kwargs (dict): The kwargs to pass to the callable 29 | """ 30 | 31 | module: str 32 | name: str 33 | args: Tuple[Any, ...] 34 | kwargs: Dict[str, Any] 35 | 36 | @staticmethod 37 | def create( 38 | callable_or_full_name: Union[str, callable], *args, **kwargs 39 | ) -> "ModuleSpec": # type: ignore 40 | """Create a module spec from a callable or import string. 41 | 42 | Args: 43 | callable_or_full_name (str or object): Either the object itself or a fully qualified import string 44 | (e.g. "src.model.components.transformer:Transformer") 45 | args (tuple, optional): Passed into callable upon instantiation. 46 | kwargs (dict, optional): Passed into callable upon instantiation. 47 | """ 48 | if isinstance(callable_or_full_name, str): 49 | assert callable_or_full_name.count(":") == 1, ( 50 | "If passing in a string, it must be a fully qualified import string " 51 | "(e.g. 
'src.model.components.transformer:Transformer')" 52 | ) 53 | module, name = callable_or_full_name.split(":") 54 | else: 55 | module, name = _infer_full_name(callable_or_full_name) 56 | 57 | return ModuleSpec(module=module, name=name, args=args, kwargs=kwargs) 58 | 59 | @staticmethod 60 | def instantiate(spec: "ModuleSpec"): # type: ignore 61 | if set(spec.keys()) != {"module", "name", "args", "kwargs"}: 62 | raise ValueError( 63 | f"Expected ModuleSpec, but got {spec}. " 64 | "ModuleSpec must have keys 'module', 'name', 'args', and 'kwargs'." 65 | ) 66 | cls = _import_from_string(spec["module"], spec["name"]) 67 | return partial(cls, *spec["args"], **spec["kwargs"]) 68 | 69 | @staticmethod 70 | def to_string(spec: "ModuleSpec"): # type: ignore 71 | return ( 72 | f"{spec['module']}:{spec['name']}" 73 | f"({', '.join(spec['args'])}" 74 | f"{', ' if spec['args'] and spec['kwargs'] else ''}" 75 | f"{', '.join(f'{k}={v}' for k, v in spec['kwargs'].items())})" 76 | ) 77 | 78 | 79 | def _infer_full_name(o: object): 80 | if hasattr(o, "__module__") and hasattr(o, "__name__"): 81 | return o.__module__, o.__name__ 82 | else: 83 | raise ValueError( 84 | f"Could not infer identifier for {o}. " 85 | "Please pass in a fully qualified import string instead " 86 | "e.g. 'src.model.components.transformer:Transformer'" 87 | ) 88 | 89 | 90 | def _import_from_string(module_string: str, name: str): 91 | try: 92 | module = importlib.import_module(module_string) 93 | return getattr(module, name) 94 | except Exception as e: 95 | raise ValueError(f"Could not import {module_string}:{name}") from e 96 | --------------------------------------------------------------------------------
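Because a ModuleSpec is just a (typed) dict, it serializes cleanly into a YAML/JSON config and can be rebuilt at runtime. A small end-to-end sketch; torch.nn:Linear is an arbitrary example target, and any importable callable works:

from src.utils.spec import ModuleSpec

# Create a spec from a fully qualified import string (nothing is imported yet)
spec = ModuleSpec.create("torch.nn:Linear", in_features=16, out_features=32, bias=False)
print(ModuleSpec.to_string(spec))  # torch.nn:Linear(in_features=16, out_features=32, bias=False)

# Import and instantiate later
layer = ModuleSpec.instantiate(spec)()
print(layer)  # Linear(in_features=16, out_features=32, bias=False)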