├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── action_model ├── __init__.py ├── action_model.py ├── diffusion_utils.py ├── gaussian_diffusion.py ├── models.py ├── respace.py └── timestep_sampler.py ├── conf ├── __init__.py └── vla.py ├── pyproject.toml ├── scripts ├── deploy.py └── train.py ├── sim_cogact ├── __init__.py ├── adaptive_ensemble.py ├── cogact_policy.py └── scripts │ ├── cogact_bridge.sh │ ├── cogact_drawer_variant_agg.sh │ ├── cogact_drawer_visual_matching.sh │ ├── cogact_move_near_variant_agg.sh │ ├── cogact_move_near_visual_matching.sh │ ├── cogact_pick_coke_can_variant_agg.sh │ ├── cogact_pick_coke_can_visual_matching.sh │ ├── cogact_put_in_drawer_variant_agg.sh │ ├── cogact_put_in_drawer_visual_matching.sh │ └── octo_bridge.sh ├── training ├── __init__.py ├── materialize.py ├── metrics.py └── strategies │ ├── __init__.py │ ├── base_strategy_cogact.py │ └── fsdp.py └── vla ├── __init__.py ├── cogactvla.py └── load.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ruff 132 | .ruff_cache/ 133 | 134 | # Auth Tokens / Hidden Files 135 | .hf_token 136 | .wandb_api_key 137 | .*_token 138 | .*api_key 139 | 140 | # IDE Caches 141 | .idea/ 142 | .vscode/ 143 | 144 | # Mac OS 145 | .DS_Store 146 | 147 | # Caches and Datasets 148 | cache/ 149 | data/ 150 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help clean check autoformat 2 | .DEFAULT: help 3 | 4 | # Generates a useful overview/help message for various make features - add to this as necessary! 5 | help: 6 | @echo "make clean" 7 | @echo " Remove all temporary pyc/pycache files" 8 | @echo "make check" 9 | @echo " Run code style and linting (black, ruff) *without* changing files!" 
10 | 	@echo "make autoformat"
11 | 	@echo "  Run code styling (black, ruff) and update in place - committing with pre-commit also does this."
12 | 
13 | clean:
14 | 	find . -name "*.pyc" | xargs rm -f && \
15 | 	find . -name "__pycache__" | xargs rm -rf
16 | 
17 | check:
18 | 	black --check .
19 | 	ruff check --show-source .
20 | 
21 | autoformat:
22 | 	black .
23 | 	ruff check --fix --show-fixes .
24 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CogACT: A Foundational Vision-Language-Action Model for Synergizing Cognition and Action in Robotic Manipulation
2 | ### 🚩[Project Page](https://cogact.github.io/) | 📑[Paper](https://arxiv.org/abs/2411.19650) | 🤗[Models](https://huggingface.co/CogACT)
3 | 
4 | 
5 | This is the code for CogACT: A Foundational Vision-Language-Action Model for Synergizing Cognition and Action in Robotic Manipulation.
6 | ## News
7 | 
8 | - 🔥 [2024-12-23] Added new sections: [Deployment in The Real World](#deployment-in-the-real-world) and [Inference Speed](#inference-speed). They demonstrate the advantage of our approach in inference speed.
9 | - 🔥 [2024-12-01] Initial release.
10 | 
11 | ## Contents
12 | * [**Installation**](#installation)
13 | * [**Getting Started**](#getting-started)
14 | * [**Fully Fine-Tuning**](#fully-fine-tuning)
15 | * [**Training CogACT from Scratch**](#training-cogact-from-scratch)
16 | * [**Evaluation in SIMPLER**](#evaluation-in-simpler)
17 | * [**Deployment in The Real World**](#deployment-in-the-real-world)
18 | * [**Inference Speed**](#inference-speed)
19 | ## Installation
20 | The code is built with Python 3.10 and can run in any environment with Python 3.8 or above. We require PyTorch >= 2.2.0 and CUDA >= 12.0 (it may work with lower versions, but we have not tested them).
21 | 
22 | We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) and setting up an environment:
23 | ```bash
24 | conda create --name cogact python=3.10
25 | ```
26 | Next, clone our repo and install the required packages:
27 | ```bash
28 | git clone https://github.com/microsoft/CogACT
29 | cd CogACT
30 | pip install -e .
31 | ```
32 | If you need to use the training code, please also install [Flash Attention](https://github.com/Dao-AILab/flash-attention). You can simply run (the first run might take a bit longer):
33 | ```bash
34 | pip install -e .[train]
35 | ```
36 | or install it manually:
37 | ```bash
38 | # [Optional]
39 | # Training additionally requires Flash-Attention 2 (https://github.com/Dao-AILab/flash-attention)
40 | pip install packaging ninja
41 | 
42 | # Verify Ninja --> should return exit code "0"
43 | ninja --version; echo $?
44 | 
45 | # Install Flash Attention 2
46 | # =>> If you run into difficulty, try `pip cache remove flash_attn` first
47 | pip install "flash-attn==2.5.5" --no-build-isolation
48 | ```
49 | ## Getting Started
50 | We release three CogACT models with different model sizes, including [Small](https://huggingface.co/CogACT/CogACT-Small), [Base](https://huggingface.co/CogACT/CogACT-Base) and [Large](https://huggingface.co/CogACT/CogACT-Large). Checkpoints, configs, and model cards are available on our [Hugging Face page](https://huggingface.co/CogACT).
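If you prefer to fetch a checkpoint explicitly instead of passing the model id, one option is ``huggingface_hub``. The snippet below is a minimal, optional sketch (it is not required for the example that follows; `load_vla` accepts either the model id or a local path):

```python
from huggingface_hub import snapshot_download

# Download the CogACT-Base checkpoint into the local Hugging Face cache
# and return its local directory, which can then be passed to `load_vla`.
local_path = snapshot_download(repo_id="CogACT/CogACT-Base")
print(local_path)
```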
Refer to the code below for a minimal inference example:
51 | ```python
52 | from PIL import Image
53 | from vla import load_vla
54 | import torch
55 | 
56 | model = load_vla(
57 |     'CogACT/CogACT-Base',          # choose from [CogACT-Small, CogACT-Base, CogACT-Large] or a local path
58 |     load_for_training=False,
59 |     action_model_type='DiT-B',     # choose from ['DiT-S', 'DiT-B', 'DiT-L'] to match the model weights
60 |     future_action_window_size=15,
61 | )
62 | # requires about 30 GB of memory in fp32
63 | 
64 | # (Optional) use "model.vlm = model.vlm.to(torch.bfloat16)" to load the VLM in bf16
65 | 
66 | model.to('cuda:0').eval()
67 | 
68 | image: Image.Image = 
69 | prompt = "move sponge near apple"  # input your prompt
70 | 
71 | # Predict action (7-DoF; un-normalize for RT-1 Google robot data, i.e., fractal20220817_data)
72 | actions, _ = model.predict_action(
73 |     image,
74 |     prompt,
75 |     unnorm_key='fractal20220817_data',  # input the unnorm_key of your dataset
76 |     cfg_scale=1.5,                      # cfg from 1.5 to 7 also performs well
77 |     use_ddim=True,                      # use DDIM sampling
78 |     num_ddim_steps=10,                  # number of steps for DDIM sampling
79 | )
80 | 
81 | # results in a 7-DoF action sequence of 16 steps with shape [16, 7]
82 | ```
83 | 
84 | Alternatively, you can use the batch inference function ``predict_action_batch`` from [vla/cogactvla.py](./vla/cogactvla.py) to accelerate inference in the simulator. For our ``Adaptive Action Ensemble`` strategy, please refer to [adaptive_ensemble.py](./sim_cogact/adaptive_ensemble.py).
85 | 
86 | ## Fully Fine-Tuning
87 | To fully fine-tune the pretrained models, we use PyTorch Fully Sharded Data Parallel ([FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)). The training script is adapted from [Prismatic VLMs](https://github.com/TRI-ML/prismatic-vlms).
88 | We recommend full fine-tuning on your dataset instead of LoRA, because a fully fine-tuned model performs better with a shorter training time. Empirically, fully fine-tuning the pretrained model for around 30 epochs already yields good results. Pretrained models can be downloaded from our [Hugging Face page](https://huggingface.co/CogACT/CogACT-Base) or fetched automatically by passing the model_id to the training scripts.
89 | 
90 | Download from our [Hugging Face page](https://huggingface.co/CogACT/CogACT-Base), using CogACT-Base as an example (optional):
91 | ```bash
92 | # Change directory to your base model PATH
93 | cd 
94 | 
95 | # Make sure you have git-lfs installed (https://git-lfs.com)
96 | git lfs install
97 | 
98 | # Download checkpoint (30 GB)
99 | git clone https://huggingface.co/CogACT/CogACT-Base
100 | ```
101 | You can also pass the model_id (e.g., `CogACT/CogACT-Base`) to the training scripts for automatic download. (See below.)
102 | 
103 | Next, create a [Hugging Face user access token](https://huggingface.co/docs/hub/en/security-tokens) and export the token value.
104 | 
105 | ```bash
106 | # Export the Hugging Face user access token
107 | export HF_TOKEN=hf_..
108 | ```
109 | 
110 | Then launch the training script. We use one node with 8 A100 GPUs as an example.
111 | ```bash
112 | torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/train.py \
113 |   --pretrained_checkpoint \
114 |   --vla.type prism-dinosiglip-224px+oxe+diffusion \
115 |   --vla.data_mix \
116 |   --vla.expected_world_size 8 \
117 |   --vla.global_batch_size 256 \
118 |   --vla.per_device_batch_size 32 \
119 |   --vla.learning_rate 2e-5 \
120 |   --data_root_dir \
121 |   --run_root_dir \
122 |   --run_id \
123 |   --image_aug \
124 |   --wandb_project \
125 |   --wandb_entity \
126 |   --save_interval \
127 |   --repeated_diffusion_steps 8 \
128 |   --future_action_window_size 15 \
129 |   --action_model_type DiT-B \
130 |   --is_resume False
131 | ```
132 | More customized training settings can be configured in [`conf/vla.py`](conf/vla.py) by modifying and registering a new VLA type. If you want to resume from a checkpoint instead of starting training from scratch, please set `is_resume=True`. Note that you also need to set `--resume_step` and `--resume_epoch` to match the checkpoint, and the optimizer state saved in the checkpoint also needs to be loaded.
133 | 
134 | To fine-tune on datasets belonging to [Open X-Embodiment (OXE)](https://robotics-transformer-x.github.io/), you can download them from [OXE](https://robotics-transformer-x.github.io/) and change ``vla.data_mix`` to the corresponding name. To fine-tune on your own customized data, please follow the instructions in [rlds_dataset_builder](https://github.com/kpertsch/rlds_dataset_builder) to convert your data to the RLDS format. The actions should be end-effector deltas: ``EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)``. Once your customized data is ready, place it directly under the ``/custom_finetuning/1.0.0`` directory. Then set ``vla.data_mix="custom_finetuning"``.
135 | 
136 | ## Training CogACT from Scratch
137 | You can start training from the weights of [OpenVLA](https://github.com/openvla/openvla) for greater efficiency. Please follow the instructions of [OpenVLA](https://github.com/openvla/openvla) to download their weights:
138 | ```bash
139 | # From the OpenVLA repo
140 | # Change directory to your base model checkpoints folder
141 | cd 
142 | 
143 | # Download checkpoint (30 GB) -- may take a few minutes
144 | git clone git@hf.co:openvla/openvla-7b-prismatic
145 | 
146 | # If the command above did not download the full checkpoint,
147 | # manually fetch it via git Large File Storage (LFS)
148 | # Note: You may have to configure an SSH key for this to work
149 | cd openvla-7b-prismatic
150 | git lfs fetch --all
151 | ```
152 | The data of [Open X-Embodiment (OXE)](https://robotics-transformer-x.github.io/) can be downloaded by following [OXE](https://robotics-transformer-x.github.io/) and [OpenVLA](https://github.com/openvla/openvla). Then launch the training script. We use one node with 8 A100 GPUs as an example.
153 | 
154 | ```bash
155 | torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/train.py \
156 |   --pretrained_checkpoint openvla-7b-prismatic/checkpoints/step-295000-epoch-40-loss=0.2200.pt \
157 |   --vla.type prism-dinosiglip-224px+oxe+diffusion \
158 |   --vla.data_mix oxe_magic_soup_plus_minus \
159 |   --vla.expected_world_size 8 \
160 |   --vla.global_batch_size 256 \
161 |   --vla.per_device_batch_size 32 \
162 |   --vla.learning_rate 2e-5 \
163 |   --data_root_dir \
164 |   --run_root_dir \
165 |   --run_id \
166 |   --image_aug \
167 |   --wandb_project \
168 |   --wandb_entity \
169 |   --save_interval \
170 |   --repeated_diffusion_steps 8 \
171 |   --future_action_window_size 15 \
172 |   --action_model_type DiT-B \
173 |   --is_resume False
174 | ```
175 | You can also start training from PrismaticVLM and simply omit ``--pretrained_checkpoint``. However, it will take longer to converge.
176 | 
177 | ## Evaluation in SIMPLER
178 | In this section, we provide a minimal evaluation of our models in [SIMPLER](https://simpler-env.github.io/). First, please follow the instructions of [SimplerEnv](https://github.com/simpler-env/SimplerEnv) to install the simulation environment. Next, add our [./sim_cogact](./sim_cogact) to [SimplerEnv/simpler_env/policies](https://github.com/simpler-env/SimplerEnv/tree/main/simpler_env/policies).
179 | ```bash
180 | cp ./sim_cogact /simpler_env/policies -r
181 | ```
182 | Then add a new policy model in [SimplerEnv/simpler_env/main_inference.py](https://github.com/simpler-env/SimplerEnv/blob/main/simpler_env/main_inference.py) as below:
183 | ```python
184 | elif args.policy_model == "cogact":
185 |     from simpler_env.policies.sim_cogact import CogACTInference
186 |     assert args.ckpt_path is not None
187 |     model = CogACTInference(
188 |         saved_model_path=args.ckpt_path,  # e.g., CogACT/CogACT-Base
189 |         policy_setup=args.policy_setup,
190 |         action_scale=args.action_scale,
191 |         action_model_type='DiT-B',
192 |         cfg_scale=1.5                     # cfg from 1.5 to 7 also performs well
193 |     )
194 | ```
195 | After that, you can modify and launch the scripts in [`sim_cogact/scripts`](sim_cogact/scripts), e.g.:
196 | ```bash
197 | cd 
198 | bash simpler_env/policies/sim_cogact/scripts/cogact_put_in_drawer_visual_matching.sh
199 | ```
200 | 
201 | ## Deployment in The Real World
202 | 
203 | For your own environment or robot, please first collect the corresponding real-world operation data (e.g., via teleoperation). Then use the data to fine-tune the pretrained model we provide, following the instructions in the section [Fully Fine-Tuning](#fully-fine-tuning).
204 | 
205 | Next, set up the server and client as instructed in [`scripts/deploy.py`](scripts/deploy.py) and deploy it on the real robot according to the hardware you are using. Run the following line to serve the fine-tuned model (we use 'fractal20220817_data' as an example; replace ``unnorm_key`` with the value for your fine-tuning dataset in actual use):
206 | ```bash
207 | python scripts/deploy.py --saved_model_path --unnorm_key fractal20220817_data --action_ensemble --use_bf16 --action_ensemble_horizon 2 --adaptive_ensemble_alpha 0.1 --cfg_scale 1.5 --port 5500
208 | ```
209 | You can also use other inference strategies by modifying the parameters in [`scripts/deploy.py`](scripts/deploy.py), such as action chunking (outputting multiple actions without ensembling).
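For reference, the ``--action_ensemble`` flags above blend the overlapping action chunks predicted at consecutive control steps. The snippet below is only an illustrative sketch of this idea under our assumptions (the class and variable names are ours, not the repo's API); the implementation actually used here is [`sim_cogact/adaptive_ensemble.py`](./sim_cogact/adaptive_ensemble.py) and the ensembler wired into [`scripts/deploy.py`](scripts/deploy.py):

```python
import numpy as np
from collections import deque

class AdaptiveEnsembleSketch:
    """Illustrative only: blend overlapping action chunks from recent control steps."""

    def __init__(self, horizon: int = 2, alpha: float = 0.1):
        self.alpha = alpha                    # sharpness of the similarity weighting
        self.history = deque(maxlen=horizon)  # most recently predicted chunks

    def step(self, action_chunk: np.ndarray) -> np.ndarray:
        """action_chunk: [T, 7] actions predicted at the current control step."""
        self.history.append(action_chunk)
        n = len(self.history)
        # The chunk cached k steps ago proposes its k-th action for the current step.
        candidates = np.stack([chunk[n - 1 - i] for i, chunk in enumerate(self.history)])
        latest = candidates[-1]
        # Weight each candidate by exp(alpha * cosine similarity to the newest prediction),
        # so stale chunks that disagree with the latest one are down-weighted.
        cos_sim = (candidates @ latest) / (
            np.linalg.norm(candidates, axis=-1) * np.linalg.norm(latest) + 1e-8
        )
        weights = np.exp(self.alpha * cos_sim)
        weights /= weights.sum()
        return (weights[:, None] * candidates).sum(axis=0)
```

In a control loop you would call ``step`` once per timestep with the freshly predicted chunk and execute the single returned 7-DoF action; with action chunking you would instead execute the first *k* actions of one chunk before querying the model again.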
210 | 
211 | As for the client, only a Python environment and the ``requests`` library (``pip install requests``) are required;
212 | no other dependencies need to be installed.
213 | 
214 | A simple standalone client example (assuming a server running on 127.0.0.1:5500):
215 | 
216 | ```python
217 | import requests
218 | import json
219 | 
220 | # Define the API endpoint
221 | url = 'http://127.0.0.1:5500/api/inference'
222 | 
223 | # Define the parameters you want to send
224 | data = {
225 |     'task_description': "Pick up the red can.",
226 | }
227 | image = "image/google_robot.png"
228 | 
229 | json.dump(data, open("data.json", "w"))
230 | 
231 | with open("data.json", "r") as query_file:
232 |     with open(image, "rb") as image_file:
233 |         file = [
234 |             ('images', (image, image_file, 'image/png')),
235 |             ('json', ("data.json", query_file, 'application/json'))
236 |         ]
237 | 
238 |         response = requests.post(url, files=file)
239 | # print(response)
240 | if response.status_code == 200:
241 |     pass
242 | else:
243 |     print("Failed to get a response from the API")
244 |     print(response.text)
245 | ```
246 | 
247 | ## Inference Speed
248 | We serve ``CogACT-Base`` on a single A6000 GPU in bfloat16 and invoke it 100 times repeatedly (see [Deployment in The Real World](#deployment-in-the-real-world) for deployment details). Each inference takes about 181 ms on average, so the action generation frequency is approximately 5.5 Hz on a single A6000 GPU using our ``Adaptive Action Ensemble`` strategy. If the action chunking strategy is used and *k* actions (*k* is at most 16) are output each time, the frequency becomes *k* times the original. However, the accuracy of the actions gradually decreases as *k* increases due to the longer open-loop prediction.
249 | 
250 | For comparison, we also deploy OpenVLA in bfloat16 on the same device, measure its average inference time, and list the number of actions each model can generate in a single inference in the following table.
251 | 
252 | |                             | CogACT-Base | OpenVLA |
253 | |-----------------------------|-------------|---------|
254 | | Inference time (ms)         | 181         | 307     |
255 | | Number of generated actions | 16          | 1       |
256 | 
257 | As shown in the table, our method has a faster inference speed because we use a single cognition token to generate an entire action sequence, whereas an OpenVLA-style model needs to generate 7 tokens to represent a 7-dimensional action. Even when accounting for the time taken by our DiT inference, our model still achieves a significant speedup over OpenVLA. Additionally, our approach can use action chunking to generate multiple actions in a single inference.
258 | 
259 | ## Citing
260 | If you find our work useful, please consider citing [our paper](https://cogact.github.io/CogACT_paper.pdf):
261 | 
262 | ```bibtex
263 | @article{li2024cogact,
264 |   title={CogACT: A Foundational Vision-Language-Action Model for Synergizing Cognition and Action in Robotic Manipulation},
265 |   author={Li, Qixiu and Liang, Yaobo and Wang, Zeyu and Luo, Lin and Chen, Xi and Liao, Mozheng and Wei, Fangyun and Deng, Yu and Xu, Sicheng and Zhang, Yizhong and others},
266 |   journal={arXiv preprint arXiv:2411.19650},
267 |   year={2024}
268 | }
269 | ```
270 | 
271 | ## Contributing
272 | 
273 | This project welcomes contributions and suggestions.
Most contributions require you to agree to a 274 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 275 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 276 | 277 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 278 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 279 | provided by the bot. You will only need to do this once across all repos using our CLA. 280 | 281 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 282 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 283 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 284 | 285 | ## Trademarks 286 | 287 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 288 | trademarks or logos is subject to and must follow 289 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 290 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 291 | Any use of third-party trademarks or logos are subject to those third-party's policies. 292 | 293 | ## License 294 | 295 | All the code, model weights, and data are licensed under [MIT license](./LICENSE). 296 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /action_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /action_model/action_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | action_model.py 3 | 4 | """ 5 | from action_model.models import DiT 6 | from action_model import create_diffusion 7 | from . 
import gaussian_diffusion as gd 8 | import torch 9 | from torch import nn 10 | 11 | # Create model sizes of ActionModels 12 | def DiT_S(**kwargs): 13 | return DiT(depth=6, hidden_size=384, num_heads=4, **kwargs) 14 | def DiT_B(**kwargs): 15 | return DiT(depth=12, hidden_size=768, num_heads=12, **kwargs) 16 | def DiT_L(**kwargs): 17 | return DiT(depth=24, hidden_size=1024, num_heads=16, **kwargs) 18 | 19 | # Model size 20 | DiT_models = {'DiT-S': DiT_S, 'DiT-B': DiT_B, 'DiT-L': DiT_L} 21 | 22 | # Create ActionModel 23 | class ActionModel(nn.Module): 24 | def __init__(self, 25 | token_size, 26 | model_type, 27 | in_channels, 28 | future_action_window_size, 29 | past_action_window_size, 30 | diffusion_steps = 100, 31 | noise_schedule = 'squaredcos_cap_v2' 32 | ): 33 | super().__init__() 34 | self.in_channels = in_channels 35 | self.noise_schedule = noise_schedule 36 | # GaussianDiffusion offers forward and backward functions q_sample and p_sample. 37 | self.diffusion_steps = diffusion_steps 38 | self.diffusion = create_diffusion(timestep_respacing="", noise_schedule = noise_schedule, diffusion_steps=self.diffusion_steps, sigma_small=True, learn_sigma = False) 39 | self.ddim_diffusion = None 40 | if self.diffusion.model_var_type in [gd.ModelVarType.LEARNED, gd.ModelVarType.LEARNED_RANGE]: 41 | learn_sigma = True 42 | else: 43 | learn_sigma = False 44 | self.past_action_window_size = past_action_window_size 45 | self.future_action_window_size = future_action_window_size 46 | self.net = DiT_models[model_type]( 47 | token_size = token_size, 48 | in_channels=in_channels, 49 | class_dropout_prob = 0.1, 50 | learn_sigma = learn_sigma, 51 | future_action_window_size = future_action_window_size, 52 | past_action_window_size = past_action_window_size 53 | ) 54 | 55 | # Given condition z and ground truth token x, compute loss 56 | def loss(self, x, z): 57 | # sample random noise and timestep 58 | noise = torch.randn_like(x) # [B, T, C] 59 | timestep = torch.randint(0, self.diffusion.num_timesteps, (x.size(0),), device= x.device) 60 | 61 | # sample x_t from x 62 | x_t = self.diffusion.q_sample(x, timestep, noise) 63 | 64 | # predict noise from x_t 65 | noise_pred = self.net(x_t, timestep, z) 66 | 67 | assert noise_pred.shape == noise.shape == x.shape 68 | # Compute L2 loss 69 | loss = ((noise_pred - noise) ** 2).mean() 70 | # Optional: loss += loss_vlb 71 | 72 | return loss 73 | 74 | # Create DDIM sampler 75 | def create_ddim(self, ddim_step=10): 76 | self.ddim_diffusion = create_diffusion(timestep_respacing = "ddim"+str(ddim_step), 77 | noise_schedule = self.noise_schedule, 78 | diffusion_steps = self.diffusion_steps, 79 | sigma_small = True, 80 | learn_sigma = False 81 | ) 82 | return self.ddim_diffusion -------------------------------------------------------------------------------- /action_model/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 
15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /action_model/models.py: -------------------------------------------------------------------------------- 1 | # Modified from facebookresearch's DiT repos 2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # -------------------------------------------------------- 10 | # References: 11 | # GLIDE: https://github.com/openai/glide-text2im 12 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 13 | # -------------------------------------------------------- 14 | 15 | import torch 16 | import torch.nn as nn 17 | import math 18 | from timm.models.vision_transformer import Attention, Mlp 19 | 20 | def modulate(x, shift, scale): 21 | return x * (1 + scale) + shift 22 | 23 | 24 | ################################################################################# 25 | # Embedding Layers for Timesteps and conditions # 26 | ################################################################################# 27 | 28 | class TimestepEmbedder(nn.Module): 29 | """ 30 | Embeds scalar timesteps into vector representations. 31 | """ 32 | def __init__(self, hidden_size, frequency_embedding_size=256): 33 | super().__init__() 34 | self.mlp = nn.Sequential( 35 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 36 | nn.SiLU(), 37 | nn.Linear(hidden_size, hidden_size, bias=True), 38 | ) 39 | self.frequency_embedding_size = frequency_embedding_size 40 | 41 | @staticmethod 42 | def timestep_embedding(t, dim, max_period=10000): 43 | """ 44 | Create sinusoidal timestep embeddings. 45 | :param t: a 1-D Tensor of N indices, one per batch element. 46 | These may be fractional. 47 | :param dim: the dimension of the output. 48 | :param max_period: controls the minimum frequency of the embeddings. 49 | :return: an (N, D) Tensor of positional embeddings. 50 | """ 51 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 52 | half = dim // 2 53 | freqs = torch.exp( 54 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 55 | ).to(device=t.device) 56 | args = t[:, None].float() * freqs[None] 57 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 58 | if dim % 2: 59 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 60 | return embedding 61 | 62 | def forward(self, t): 63 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(next(self.mlp.parameters()).dtype) 64 | t_emb = self.mlp(t_freq) 65 | return t_emb 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds conditions into vector representations. Also handles label dropout for classifier-free guidance. 70 | """ 71 | def __init__(self, in_size, hidden_size, dropout_prob=0.1, conditions_shape=(1, 1, 4096)): 72 | super().__init__() 73 | self.linear = nn.Linear(in_size, hidden_size) 74 | self.dropout_prob = dropout_prob 75 | if dropout_prob > 0: 76 | self.uncondition = nn.Parameter(torch.empty(conditions_shape[1:])) 77 | 78 | def token_drop(self, conditions, force_drop_ids=None): 79 | """ 80 | Drops conditions to enable classifier-free guidance. 
81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(conditions.shape[0], device=conditions.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | conditions = torch.where(drop_ids.unsqueeze(1).unsqueeze(1).expand(conditions.shape[0], *self.uncondition.shape), self.uncondition, conditions) 87 | return conditions 88 | 89 | 90 | def forward(self, conditions, train, force_drop_ids=None): 91 | use_dropout = self.dropout_prob > 0 92 | if (train and use_dropout) or (force_drop_ids is not None): 93 | conditions = self.token_drop(conditions, force_drop_ids) 94 | embeddings = self.linear(conditions) 95 | return embeddings 96 | 97 | ################################################################################# 98 | # Embedding Layers for Actions and # 99 | ################################################################################# 100 | class ActionEmbedder(nn.Module): 101 | def __init__(self, action_size, hidden_size): 102 | super().__init__() 103 | self.linear = nn.Linear(action_size, hidden_size) 104 | 105 | def forward(self, x): 106 | x = self.linear(x) 107 | return x 108 | 109 | # Action_History is not used now 110 | class HistoryEmbedder(nn.Module): 111 | def __init__(self, action_size, hidden_size): 112 | super().__init__() 113 | self.linear = nn.Linear(action_size, hidden_size) 114 | 115 | def forward(self, x): 116 | x = self.linear(x) 117 | return x 118 | 119 | ################################################################################# 120 | # Core DiT Model # 121 | ################################################################################# 122 | 123 | class DiTBlock(nn.Module): 124 | """ 125 | A DiT block with self-attention conditioning. 126 | """ 127 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 128 | super().__init__() 129 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 130 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 131 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 133 | approx_gelu = lambda: nn.GELU(approximate="tanh") 134 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 135 | 136 | def forward(self, x): 137 | x = x + self.attn(self.norm1(x)) 138 | x = x + self.mlp(self.norm2(x)) 139 | return x 140 | 141 | 142 | class FinalLayer(nn.Module): 143 | """ 144 | The final layer of DiT. 145 | """ 146 | def __init__(self, hidden_size, out_channels): 147 | super().__init__() 148 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 149 | self.linear = nn.Linear(hidden_size, out_channels, bias=True) 150 | 151 | def forward(self, x): 152 | x = self.norm_final(x) 153 | x = self.linear(x) 154 | return x 155 | 156 | 157 | class DiT(nn.Module): 158 | """ 159 | Diffusion model with a Transformer backbone. 
160 | """ 161 | def __init__( 162 | self, 163 | in_channels=7, 164 | hidden_size=1152, 165 | depth=28, 166 | num_heads=16, 167 | mlp_ratio=4.0, 168 | class_dropout_prob=0.1, 169 | token_size=4096, 170 | future_action_window_size=1, 171 | past_action_window_size=0, 172 | learn_sigma=False, 173 | ): 174 | super().__init__() 175 | 176 | assert past_action_window_size == 0, "Error: action_history is not used now" 177 | 178 | self.learn_sigma = learn_sigma 179 | self.in_channels = in_channels 180 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 181 | self.class_dropout_prob = class_dropout_prob 182 | self.num_heads = num_heads 183 | self.past_action_window_size = past_action_window_size 184 | self.future_action_window_size = future_action_window_size 185 | 186 | # Action history is not used now. 187 | self.history_embedder = HistoryEmbedder(action_size=in_channels, hidden_size=hidden_size) 188 | 189 | self.x_embedder = ActionEmbedder(action_size=in_channels, hidden_size=hidden_size) 190 | self.t_embedder = TimestepEmbedder(hidden_size) 191 | self.z_embedder = LabelEmbedder(in_size=token_size, hidden_size=hidden_size, dropout_prob=class_dropout_prob) 192 | scale = hidden_size ** -0.5 193 | 194 | # Learnable positional embeddings 195 | # +2, one for the conditional token, and one for the current action prediction 196 | self.positional_embedding = nn.Parameter( 197 | scale * torch.randn(future_action_window_size + past_action_window_size + 2, hidden_size)) 198 | 199 | self.blocks = nn.ModuleList([ 200 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 201 | ]) 202 | self.final_layer = FinalLayer(hidden_size, self.out_channels) 203 | self.initialize_weights() 204 | 205 | def initialize_weights(self): 206 | # Initialize transformer layers: 207 | def _basic_init(module): 208 | if isinstance(module, nn.Linear): 209 | torch.nn.init.xavier_uniform_(module.weight) 210 | if module.bias is not None: 211 | nn.init.constant_(module.bias, 0) 212 | self.apply(_basic_init) 213 | 214 | # # Initialize token_embed like nn.Linear 215 | nn.init.normal_(self.x_embedder.linear.weight, std=0.02) 216 | nn.init.constant_(self.x_embedder.linear.bias, 0) 217 | 218 | nn.init.normal_(self.history_embedder.linear.weight, std=0.02) 219 | nn.init.constant_(self.history_embedder.linear.bias, 0) 220 | 221 | # Initialize label embedding table: 222 | if self.class_dropout_prob > 0: 223 | nn.init.normal_(self.z_embedder.uncondition, std=0.02) 224 | nn.init.normal_(self.z_embedder.linear.weight, std=0.02) 225 | nn.init.constant_(self.z_embedder.linear.bias, 0) 226 | 227 | # Initialize timestep embedding MLP: 228 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 229 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 230 | 231 | nn.init.constant_(self.final_layer.linear.weight, 0) 232 | nn.init.constant_(self.final_layer.linear.bias, 0) 233 | 234 | def forward(self, x, t, z): 235 | """ 236 | Forward pass of DiT. 
237 | history: (N, H, D) tensor of action history # not used now 238 | x: (N, T, D) tensor of predicting action inputs 239 | t: (N,) tensor of diffusion timesteps 240 | z: (N, 1, D) tensor of conditions 241 | """ 242 | x = self.x_embedder(x) # (N, T, D) 243 | t = self.t_embedder(t) # (N, D) 244 | z = self.z_embedder(z, self.training) # (N, 1, D) 245 | c = t.unsqueeze(1) + z # (N, 1, D) 246 | x = torch.cat((c, x), dim=1) # (N, T+1, D) 247 | x = x + self.positional_embedding # (N, T+1, D) 248 | for block in self.blocks: 249 | x = block(x) # (N, T+1, D) 250 | x = self.final_layer(x) # (N, T+1, out_channels) 251 | return x[:, 1:, :] # (N, T, C) 252 | 253 | def forward_with_cfg(self, x, t, z, cfg_scale): 254 | """ 255 | Forward pass of Diffusion, but also batches the unconditional forward pass for classifier-free guidance. 256 | """ 257 | 258 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 259 | half = x[: len(x) // 2] 260 | combined = torch.cat([half, half], dim=0).to(next(self.x_embedder.parameters()).dtype) 261 | model_out = self.forward(combined, t, z) 262 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 263 | eps, rest = model_out[:, :, :self.in_channels], model_out[:, :, self.in_channels:] 264 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 265 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 266 | eps = torch.cat([half_eps, half_eps], dim=0) 267 | # return torch.cat([eps, rest], dim=1) 268 | return torch.cat([eps, rest], dim=2) -------------------------------------------------------------------------------- /action_model/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | if desired_count == 1: 35 | return set([50]) 36 | for i in range(1, num_timesteps): 37 | if len(range(0, num_timesteps, i)) == desired_count: 38 | return set(range(0, num_timesteps, i)) 39 | raise ValueError( 40 | f"cannot create exactly {num_timesteps} steps with an integer stride" 41 | ) 42 | section_counts = [int(x) for x in section_counts.split(",")] 43 | size_per = num_timesteps // len(section_counts) 44 | extra = num_timesteps % len(section_counts) 45 | start_idx = 0 46 | all_steps = [] 47 | for i, section_count in enumerate(section_counts): 48 | size = size_per + (1 if i < extra else 0) 49 | if size < section_count: 50 | raise ValueError( 51 | f"cannot divide section of {size} steps into {section_count}" 52 | ) 53 | if section_count <= 1: 54 | frac_stride = 1 55 | else: 56 | frac_stride = (size - 1) / (section_count - 1) 57 | cur_idx = 0.0 58 | taken_steps = [] 59 | for _ in range(section_count): 60 | taken_steps.append(start_idx + round(cur_idx)) 61 | cur_idx += frac_stride 62 | all_steps += taken_steps 63 | start_idx += size 64 | return set(all_steps) 65 | 66 | 67 | class SpacedDiffusion(GaussianDiffusion): 68 | """ 69 | A diffusion process which can skip steps in a base diffusion process. 70 | :param use_timesteps: a collection (sequence or set) of timesteps from the 71 | original diffusion process to retain. 72 | :param kwargs: the kwargs to create the base diffusion process. 73 | """ 74 | 75 | def __init__(self, use_timesteps, **kwargs): 76 | self.use_timesteps = set(use_timesteps) 77 | self.timestep_map = [] 78 | self.original_num_steps = len(kwargs["betas"]) 79 | 80 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 81 | last_alpha_cumprod = 1.0 82 | new_betas = [] 83 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 84 | if i in self.use_timesteps: 85 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 86 | last_alpha_cumprod = alpha_cumprod 87 | self.timestep_map.append(i) 88 | kwargs["betas"] = np.array(new_betas) 89 | super().__init__(**kwargs) 90 | 91 | def p_mean_variance( 92 | self, model, *args, **kwargs 93 | ): # pylint: disable=signature-differs 94 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 95 | 96 | def training_losses( 97 | self, model, *args, **kwargs 98 | ): # pylint: disable=signature-differs 99 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 100 | 101 | def condition_mean(self, cond_fn, *args, **kwargs): 102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def condition_score(self, cond_fn, *args, **kwargs): 105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 106 | 107 | def _wrap_model(self, model): 108 | if isinstance(model, _WrappedModel): 109 | return model 110 | return _WrappedModel( 111 | model, self.timestep_map, self.original_num_steps 112 | ) 113 | 114 | def _scale_timesteps(self, t): 115 | # Scaling is done by the wrapped model. 
116 | return t 117 | 118 | 119 | class _WrappedModel: 120 | def __init__(self, model, timestep_map, original_num_steps): 121 | self.model = model 122 | self.timestep_map = timestep_map 123 | # self.rescale_timesteps = rescale_timesteps 124 | self.original_num_steps = original_num_steps 125 | 126 | def __call__(self, x, ts, **kwargs): 127 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 128 | new_ts = map_tensor[ts] 129 | # if self.rescale_timesteps: 130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 131 | return self.model(x, new_ts, **kwargs) 132 | -------------------------------------------------------------------------------- /action_model/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .vla import VLAConfig, VLARegistry 2 | -------------------------------------------------------------------------------- /conf/vla.py: -------------------------------------------------------------------------------- 1 | """ 2 | vla.py 3 | 4 | Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and 5 | model configuration thereof. A given VLA model (`policy`) configures the following attributes: 6 | - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.) 7 | - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`) 8 | - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning) 9 | - Training / Optimization Hyperparameters 10 | """ 11 | 12 | from dataclasses import dataclass 13 | from enum import Enum, unique 14 | from pathlib import Path 15 | from typing import Optional, Union 16 | 17 | from draccus import ChoiceRegistry 18 | 19 | 20 | @dataclass 21 | class VLAConfig(ChoiceRegistry): 22 | # fmt: off 23 | vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant 24 | base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`) 25 | freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining) 26 | freeze_llm_backbone: bool # Freeze LLM Backbone parameters 27 | unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen) 28 | 29 | # Data Mixture Parameters 30 | data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`) 31 | shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE) 32 | 33 | # Optimization Parameters 34 | epochs: int # Epochs to Run (in case `max_steps` is not specified) 35 | max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`) 36 | 37 | expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware 38 | global_batch_size: int # Global Batch Size (divided across processes / world size) 39 | per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU) 40 | # =>> # of accumulation steps is auto-computed 41 | 42 | learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay) 43 | weight_decay: float # Weight Decay for AdamW Optimizer 44 | max_grad_norm: float # Max Grad Norm (for global gradient clipping) 45 | lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay") 46 | warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers) 47 | 48 | train_strategy: str # Train Strategy (default "fsdp-full-shard") 49 | 50 | # Enable Gradient/Activation Checkpointing (for the LLM Backbone) 51 | enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training 52 | 53 | # Mixed Precision Training via Torch Native AMP (`autocast`) 54 | enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision 55 | reduce_in_full_precision: bool = True # 
Accumulate/Reduce All-Gather Gradients in FP32 Full Precision 56 | 57 | # fmt: on 58 | 59 | 60 | # === OpenVLA Training Configurations === 61 | 62 | 63 | # = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge = 64 | @dataclass 65 | class Exp_SigLIP_224px_Bridge(VLAConfig): 66 | vla_id: str = "siglip-224px+mx-bridge" 67 | base_vlm: Union[str, Path] = "siglip-224px+7b" 68 | 69 | freeze_vision_backbone: bool = False 70 | freeze_llm_backbone: bool = False 71 | unfreeze_last_llm_layer: bool = False 72 | 73 | # Data Mixture Parameters 74 | data_mix: str = "bridge" 75 | shuffle_buffer_size: int = 256_000 76 | 77 | # Optimization Parameters 78 | epochs: int = 1000 79 | max_steps: Optional[int] = None 80 | 81 | expected_world_size: int = 8 82 | global_batch_size: int = 256 83 | per_device_batch_size: int = 32 84 | 85 | learning_rate: float = 2e-5 86 | weight_decay: float = 0.0 87 | max_grad_norm: float = 1.0 88 | lr_scheduler_type: str = "constant" 89 | warmup_ratio: float = 0.0 90 | 91 | train_strategy: str = "fsdp-full-shard" 92 | 93 | 94 | # === CogACT-VLA Pretraining Configs === 95 | 96 | @dataclass 97 | class Exp_CogACT_OXE_Magic_Soup_Plus_Minus(Exp_SigLIP_224px_Bridge): 98 | vla_id: str = "prism-dinosiglip-224px+oxe+diffusion" 99 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" 100 | 101 | # data_mix: str = "oxe_magic_soup_plus" 102 | data_mix: str = "oxe_magic_soup_plus_minus" 103 | shuffle_buffer_size: int = 250_000 104 | expected_world_size: int = 16 105 | global_batch_size: int = 256 106 | per_device_batch_size: int = 16 107 | max_grad_norm: float = 1.0 108 | learning_rate: float = 2e-5 109 | 110 | epochs: int = 100 111 | 112 | # === Define a VLA Registry Enum for Reference & Validation === 113 | @unique 114 | class VLARegistry(Enum): 115 | # Sanity Check Configurations =>> BridgeV2 116 | SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge 117 | 118 | # === CogACT-VLA Pretraining Configs === 119 | EXP_COGACT_OXE_MAGIC_SOUP_PLUS_MINUS = Exp_CogACT_OXE_Magic_Soup_Plus_Minus 120 | 121 | @property 122 | def vla_id(self) -> str: 123 | return self.value.vla_id 124 | 125 | 126 | # Register VLAs in Choice Registry 127 | for vla_variant in VLARegistry: 128 | VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value) 129 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "cogact" 7 | authors = [ 8 | {name = "Qixiu Li", email="liqx23@mails.tsinghua.edu.cn"}, 9 | {name = "Yaobo Liang", email="yalia@microsoft.com"}, 10 | {name = "Zeyu Wang", email="skyw95357@gmail.com"} 11 | ] 12 | description = "CogACT: A Foundational Vision-Language-Action Model for Synergizing Cognition and Action in Robotic Manipulation" 13 | version = "0.0.1" 14 | readme = "README.md" 15 | requires-python = ">=3.8" 16 | keywords = ["vision-language-actions models", " robotic manipulation", "multimodal pretraining", "robot learning", "generative models"] 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "Intended Audience :: Education", 22 | "Intended Audience :: Science/Research", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming 
Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ] 32 | dependencies = [ 33 | "accelerate>=0.25.0", 34 | "draccus>=0.8.0", 35 | "einops", 36 | # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README) 37 | "huggingface_hub", 38 | "json-numpy", 39 | "jsonlines", 40 | "matplotlib", 41 | "peft==0.11.1", 42 | "protobuf", 43 | "rich", 44 | "sentencepiece==0.1.99", 45 | "timm==0.9.10", 46 | "tokenizers==0.19.1", 47 | "torch>=2.2.0", 48 | "torchvision>=0.16.0", 49 | "torchaudio", 50 | "transformers==4.40.1", 51 | "wandb", 52 | "tensorflow==2.15.0", 53 | "tensorflow_datasets==4.9.3", 54 | "tensorflow_graphics==2021.12.3", 55 | "dlimp @ git+https://github.com/moojink/dlimp_openvla", 56 | "openvla @ git+https://github.com/arnoldland/openvla", # For data loader and prismatic VLM. Here we modify and use the dataloader of OpenVLA 57 | ] 58 | 59 | [project.optional-dependencies] 60 | dev = [ 61 | "black>=24.2.0", 62 | "gpustat", 63 | "ipython", 64 | "pre-commit", 65 | "ruff>=0.2.2", 66 | ] 67 | sagemaker = [ 68 | "boto3", 69 | "sagemaker" 70 | ] 71 | 72 | train = [ 73 | "flash-attn==2.5.5", 74 | "ninja", 75 | "packaging" 76 | ] 77 | 78 | [tool.setuptools.packages.find] 79 | where = ["."] 80 | exclude = ["cache"] 81 | 82 | [tool.setuptools.package-data] 83 | "cogact" = ["py.typed"] 84 | 85 | [tool.black] 86 | line-length = 121 87 | target-version = ["py38", "py39", "py310"] 88 | preview = true 89 | 90 | [tool.ruff] 91 | line-length = 121 92 | target-version = "py310" 93 | 94 | [tool.ruff.lint] 95 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 96 | ignore = ["F722"] 97 | 98 | [tool.ruff.lint.per-file-ignores] 99 | "__init__.py" = ["E402", "F401"] 100 | -------------------------------------------------------------------------------- /scripts/deploy.py: -------------------------------------------------------------------------------- 1 | """ 2 | deploy.py 3 | 4 | A simple deployment example for serving a fine-tuned model. 5 | Please execute the code below to start the server. Using 'fractal20220817_data' as an example, 6 | please replace "unnorm_key" with the value from your fine-tuned dataset in actual use. 7 | 8 | ``` 9 | python scripts/deploy.py --saved_model_path --unnorm_key fractal20220817_data --action_ensemble --use_bf16 --action_ensemble_horizon 2 --adaptive_ensemble_alpha 0.1 --cfg_scale 1.5 --port 5500 10 | 11 | ``` 12 | 13 | 14 | The client only needs a Python environment and the requests library (pip install requests); 15 | no other dependencies need to be installed. 
16 | 17 | Client (Standalone) Usage (assuming a server running on 0.0.0.0:5500): 18 | 19 | ``` 20 | import requests 21 | import json 22 | 23 | # Define the API endpoint 24 | url = 'http://127.0.0.1:5500/api/inference' 25 | 26 | # Define the parameters you want to send 27 | data = { 28 | 'task_description': "Pick up the red can.", 29 | } 30 | image = "image/google_robot.png" 31 | 32 | json.dump(data, open("data.json", "w")) 33 | 34 | with open ("data.json", "r") as query_file: 35 | with open(image, "rb") as image_file: 36 | file = [ 37 | ('images', (image, image_file, 'image/png')), 38 | ('json', ("data.json", query_file, 'application/json')) 39 | ] 40 | 41 | response = requests.post(url, files=file) 42 | if response.status_code == 200: 43 | pass 44 | else: 45 | print("Failed to get a response from the API") 46 | print(response.text) 47 | ``` 48 | 49 | """ 50 | 51 | import numpy as np 52 | from PIL import Image 53 | from typing import Optional, Tuple, Union 54 | import os 55 | import argparse 56 | import json 57 | import math 58 | from flask import Flask, request, jsonify 59 | import tempfile 60 | import torch 61 | from vla import load_vla 62 | from sim_cogact.adaptive_ensemble import AdaptiveEnsembler 63 | 64 | app = Flask(__name__) 65 | 66 | 67 | class CogACTService: 68 | def __init__( 69 | self, 70 | saved_model_path: str = "CogACT/CogACT-Base", 71 | unnorm_key: str = None, 72 | image_size: list[int] = [224, 224], 73 | action_model_type: str = "DiT-B", # choose from ['DiT-Small', 'DiT-Base', 'DiT-Large'] to match the model weight 74 | future_action_window_size: int = 15, 75 | cfg_scale: float = 1.5, 76 | num_ddim_steps: int = 10, 77 | use_ddim: bool = True, 78 | use_bf16: bool = True, 79 | action_dim: int = 7, 80 | action_ensemble: bool = True, 81 | adaptive_ensemble_alpha: float = 0.1, 82 | action_ensemble_horizon: int = 2, 83 | action_chunking: bool = False, 84 | action_chunking_window: Optional[int] = None, 85 | args=None 86 | ) -> None: 87 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 88 | assert not (action_chunking and action_ensemble), "Now 'action_chunking' and 'action_ensemble' cannot both be True." 
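# A minimal direct-use sketch (hypothetical values; the supported entry point is the CLI shown in
# the module docstring, and 'fractal20220817_data' is only an example unnorm_key):
#
#   cli = argparse.Namespace(saved_model_path="CogACT/CogACT-Base", port=5500)
#   service = CogACTService(saved_model_path=cli.saved_model_path,
#                           unnorm_key="fractal20220817_data", use_bf16=True, args=cli)
#   action = service.step("image/google_robot.png", task_description="Pick up the red can.")
#
# Note that action_ensemble returns a single smoothed 7-DoF action per call, while action_chunking
# returns the first `action_chunking_window` actions of the predicted chunk; the two modes are
# mutually exclusive, hence the assert above.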
89 | 90 | self.unnorm_key = unnorm_key 91 | 92 | print(f"*** unnorm_key: {unnorm_key} ***") 93 | self.vla = load_vla( 94 | saved_model_path, 95 | load_for_training=False, 96 | action_model_type=action_model_type, 97 | future_action_window_size=future_action_window_size, 98 | action_dim=action_dim, 99 | ) 100 | if use_bf16: 101 | self.vla.vlm = self.vla.vlm.to(torch.bfloat16) 102 | self.vla = self.vla.to("cuda").eval() 103 | self.cfg_scale = cfg_scale 104 | 105 | self.image_size = image_size 106 | self.use_ddim = use_ddim 107 | self.num_ddim_steps = num_ddim_steps 108 | self.action_ensemble = action_ensemble 109 | self.adaptive_ensemble_alpha = adaptive_ensemble_alpha 110 | self.action_ensemble_horizon = action_ensemble_horizon 111 | self.action_chunking = action_chunking 112 | self.action_chunking_window = action_chunking_window 113 | if self.action_ensemble: 114 | self.action_ensembler = AdaptiveEnsembler(self.action_ensemble_horizon, self.adaptive_ensemble_alpha) 115 | else: 116 | self.action_ensembler = None 117 | 118 | self.args = args 119 | self.reset() 120 | 121 | def reset(self) -> None: 122 | if self.action_ensemble: 123 | self.action_ensembler.reset() 124 | 125 | def step( 126 | self, image: str, 127 | task_description: Optional[str] = None, 128 | *args, **kwargs, 129 | ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: 130 | """ 131 | Input: 132 | image: Path to the image file 133 | task_description: Optional[str], task description 134 | Output: 135 | action: list[float], the ensembled 7-DoFs action of End-effector and gripper 136 | 137 | """ 138 | 139 | image: Image.Image = Image.open(image) 140 | 141 | # [IMPORTANT!]: Please process the input images here in exactly the same way as the images 142 | # were processed during finetuning to ensure alignment between inference and training. 143 | # Make sure, as much as possible, that the gripper is visible in the processed images. 144 | resized_image = resize_image(image, size=self.image_size) 145 | unnormed_actions, normalized_actions = self.vla.predict_action( 146 | image=resized_image, 147 | instruction=task_description, 148 | unnorm_key=self.unnorm_key, 149 | do_sample=False, 150 | cfg_scale=self.cfg_scale, 151 | use_ddim=self.use_ddim, 152 | num_ddim_steps=self.num_ddim_steps, 153 | ) 154 | 155 | if self.action_ensemble: 156 | unnormed_actions = self.action_ensembler.ensemble_action(unnormed_actions) 157 | # Translate the value of the gripper's open/close state to 0 or 1. 158 | # Please adjust this line according to the control mode of different grippers. 159 | unnormed_actions[6] = unnormed_actions[6] > 0.5 160 | action = unnormed_actions.tolist() 161 | elif self.action_chunking: 162 | # [IMPORTANT!]: Please modify the code here to output multiple actions at once. 163 | # The code below only outputs the first action in the chunking. 164 | # The chunking window size can be adjusted by modifying the 'action_chunking_window' parameter. 165 | if self.action_chunking_window is not None: 166 | chunked_actions = [] 167 | for i in range(0, self.action_chunking_window): 168 | chunked_actions.append(unnormed_actions[i].tolist()) 169 | action = chunked_actions 170 | else: 171 | raise ValueError("Please specify the 'action_chunking_window' when using action chunking.") 172 | else: 173 | # Output the first action in the chunking. Can be modified to output multiple actions at once. 
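# e.g., with the defaults above (future_action_window_size=15, action_dim=7) the prediction is
# expected to hold 16 = 15 + 1 actions of 7 DoF each; indexing [0] keeps only the immediate action.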
174 | unnormed_actions = unnormed_actions[0] 175 | action = unnormed_actions.tolist() 176 | 177 | print(f"Instruction: {task_description}") 178 | print(f"Model path: {self.args.saved_model_path} at port {self.args.port}") 179 | return action 180 | 181 | 182 | # [IMPORTANT!]: Please modify the image processing code here to ensure that the input images 183 | # are handled in exactly the same way as during the finetuning phase. 184 | # Make sure, as much as possible, that the gripper is visible in the processed images. 185 | def resize_image(image: Image, size=(224, 224), shift_to_left=0): 186 | w, h = image.size 187 | assert h < w, "Height should be less than width" 188 | left_margin = (w - h) // 2 - shift_to_left 189 | left_margin = min(max(left_margin, 0), w - h) 190 | image = image.crop((left_margin, 0, left_margin + h, h)) 191 | 192 | image = image.resize(size, resample=Image.LANCZOS) 193 | 194 | image = scale_and_resize(image, target_size=(224, 224), scale=0.9, margin_w_ratio=0.5, margin_h_ratio=0.5) 195 | return image 196 | 197 | # Here the image is first center cropped and then resized back to its original size 198 | # because random crop data augmentation was used during finetuning. 199 | def scale_and_resize(image : Image, target_size=(224, 224), scale=0.9, margin_w_ratio=0.5, margin_h_ratio=0.5): 200 | w, h = image.size 201 | new_w = int(w * math.sqrt(scale)) 202 | new_h = int(h * math.sqrt(scale)) 203 | margin_w_max = w - new_w 204 | margin_h_max = h - new_h 205 | margin_w = int(margin_w_max * margin_w_ratio) 206 | margin_h = int(margin_h_max * margin_h_ratio) 207 | image = image.crop((margin_w, margin_h, margin_w + new_w, margin_h + new_h)) 208 | image = image.resize(target_size, resample=Image.LANCZOS) 209 | return image 210 | 211 | 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument("--saved_model_path", type=str, default="CogACT/CogACT-Base") 214 | parser.add_argument("--unnorm_key", type=str, default=None) 215 | parser.add_argument("--image_size", type=list[int], default=[224, 224]) 216 | parser.add_argument("--action_model_type", type=str, default="DiT-B") 217 | parser.add_argument("--future_action_window_size", type=int, default=15) 218 | parser.add_argument("--cfg_scale", type=float, default=1.5) 219 | parser.add_argument("--port", type=int, default=5500) 220 | parser.add_argument("--use_bf16", action="store_true") 221 | parser.add_argument("--action_dim", type=int, default=7) 222 | parser.add_argument("--action_ensemble", action="store_true") 223 | parser.add_argument("--action_ensemble_horizon", type=int, default=2) 224 | parser.add_argument("--adaptive_ensemble_alpha", type=float, default=0.1) 225 | parser.add_argument("--action_chunking", action="store_true") 226 | parser.add_argument("--action_chunking_window", type=int, default=None) 227 | 228 | args = parser.parse_args() 229 | 230 | inferencer = CogACTService( 231 | saved_model_path=args.saved_model_path, 232 | unnorm_key=args.unnorm_key, 233 | image_size=args.image_size, 234 | action_model_type=args.action_model_type, 235 | future_action_window_size=args.future_action_window_size, 236 | cfg_scale=args.cfg_scale, 237 | use_bf16=args.use_bf16, 238 | action_dim=args.action_dim, 239 | action_ensemble=args.action_ensemble, 240 | adaptive_ensemble_alpha=args.adaptive_ensemble_alpha, 241 | action_ensemble_horizon=args.action_ensemble_horizon, 242 | action_chunking=args.action_chunking, 243 | action_chunking_window=args.action_chunking_window, 244 | args=args 245 | ) 246 | 247 | @app.route('/api/inference', 
methods=['POST']) 248 | def inference(): 249 | image = request.files['images'] 250 | query = request.files['json'] 251 | with tempfile.NamedTemporaryFile(delete=False) as temp_image: 252 | image.save(temp_image.name) 253 | temp_image_path = temp_image.name 254 | with tempfile.NamedTemporaryFile(delete=False) as temp_query: 255 | query.save(temp_query.name) 256 | temp_query_path = temp_query.name 257 | input_query = json.load(open(temp_query_path)) 258 | print(input_query) 259 | answer = inferencer.step(temp_image_path, **input_query) 260 | # clean files 261 | os.remove(temp_image_path) 262 | os.remove(temp_query_path) 263 | return jsonify(answer) 264 | 265 | if __name__ == "__main__": 266 | 267 | app.run(host="0.0.0.0", debug=False, port=args.port) -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | train.py 3 | 4 | Training script for Vision-Language-Action (VLA) Policies, built on top of pretrained VLMs, trained using mixtures of 5 | the Open-X Embodiment dataset. Performs training in native PyTorch, using Fully-Sharded Data Parallel (FSDP) to run 6 | distributed across GPUs (and nodes). By default, assumes that CUDA toolkit is >= 11.0 (to support BF16 mixed precision). 7 | 8 | Notes & Prerequisites: 9 | - If you want to set a custom location for all HF / TIMM artifacts --> `export HF_HOME=""` *before* running! 10 | => For example (add to end of .bashrc): `export HF_HOME="/mnt/fsx/skaramcheti/cache"` 11 | - If you want to suppress random Tensorflow logs --> `export TF_CPP_MIN_LOG_LEVEL=3` 12 | 13 | Run with: 14 | - [Single Node One-GPU (Debug)] : torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/train.py 15 | - [Single Node Multi-GPU (= $K)]: torchrun --standalone --nnodes 1 --nproc-per-node $K vla-scripts/train.py 16 | """ 17 | 18 | import json 19 | import os 20 | import re 21 | from dataclasses import dataclass, field 22 | from pathlib import Path 23 | from typing import Optional, Tuple, Union 24 | 25 | import draccus 26 | import torch 27 | import torch.distributed as dist 28 | import yaml 29 | import wandb 30 | 31 | from prismatic.overwatch import initialize_overwatch 32 | from prismatic.util import set_global_seed 33 | from prismatic.vla import get_vla_dataset_and_collator 34 | from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics 35 | 36 | from training import VLAMetrics, get_train_strategy 37 | from conf import VLAConfig, VLARegistry 38 | from vla import load, load_vla 39 | from vla import CogACT 40 | 41 | # Sane Defaults 42 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 43 | 44 | 45 | # Initialize Overwatch =>> Wraps `logging.Logger` 46 | overwatch = initialize_overwatch(__name__) 47 | 48 | 49 | @dataclass 50 | class TrainConfig: 51 | # fmt: off 52 | 53 | # VLAConfig (`conf/vla.py`); override with --vla.type `VLARegistry..vla_id` 54 | vla: VLAConfig = field( 55 | default_factory=VLAConfig.get_choice_class(VLARegistry.EXP_COGACT_OXE_MAGIC_SOUP_PLUS_MINUS.vla_id) 56 | ) 57 | 58 | # Directory Paths 59 | data_root_dir: Path = Path( # Path to Open-X dataset directory 60 | "datasets/open-x-embodiment" 61 | ) 62 | run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints 63 | 64 | # Resume Run Parameters 65 | pretrained_checkpoint: Optional[Union[str, Path]] = None # Absolute Path to Checkpoint 66 | is_resume: bool = True # Whether we are continuing a prior training run 67 | # (only 
applicable given pretrained checkpoint) 68 | resume_step: Optional[int] = None # Global Step to Resume (should match checkpoint) 69 | resume_epoch: Optional[int] = None # Epoch to Resume (should match checkpoint) 70 | 71 | # Run Arguments 72 | run_id: Optional[str] = None # Run ID for logging, Weights & Biases 73 | run_id_note: Optional[str] = None # Extra note for logging, Weights & Biases 74 | save_interval: int = 2500 # Interval for saving checkpoints (in steps) 75 | image_aug: bool = False # Whether to enable image augmentations 76 | seed: int = 42 # Random seed (for reproducibility) 77 | 78 | # HF Hub Credentials (for any gated models) 79 | hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token 80 | 81 | # Tracking Parameters 82 | trackers: Tuple[str, ...] = ("jsonl", "wandb") # Trackers to initialize (if W&B, add config!) 83 | #trackers: Tuple[str, ...] = ("jsonl",) # Trackers to initialize (if W&B, add config!) 84 | wandb_project: str = "" # Name of W&B project to log to (use default!) 85 | wandb_entity: str = "" # Name of entity to log under 86 | repeated_diffusion_steps: int = 8 # Repeated steps for training action model (a diffusion model) 87 | load_all_data_for_training: bool = True # Load all training data 88 | future_action_window_size: int = 15 # Action chunking, predicting future actions + current action 89 | past_action_window_size: int = 0 # Action history window size, not used now, set to be 0 90 | action_model_type: str = 'DiT-B' # Action model type, chose from ['DiT-S', 'DiT-B', 'DiT-L'] 91 | use_ema: bool = False # EMA version of action model 92 | action_dim: int = 7 # Dimension of action space 93 | 94 | def __post_init__(self) -> None: 95 | """Lift optimization parameters from `self.vla` for ease of use =>> validate on `expected_world_size`""" 96 | self.epochs = self.vla.epochs 97 | self.max_steps = self.vla.max_steps 98 | self.global_batch_size = self.vla.global_batch_size 99 | self.per_device_batch_size = self.vla.per_device_batch_size 100 | 101 | self.learning_rate = self.vla.learning_rate 102 | self.weight_decay = self.vla.weight_decay 103 | self.max_grad_norm = self.vla.max_grad_norm 104 | self.lr_scheduler_type = self.vla.lr_scheduler_type 105 | self.warmup_ratio = self.vla.warmup_ratio 106 | 107 | self.train_strategy = self.vla.train_strategy 108 | 109 | # [Validate] Assert on `expected_world_size` 110 | assert ( 111 | self.vla.expected_world_size == overwatch.world_size() 112 | ), f"Expected World Size = {self.vla.expected_world_size} but Found {overwatch.world_size()} GPUs!" 
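# For reference, the gradient-accumulation count implied by the fields above is
#   grad_accum_steps = global_batch_size // (expected_world_size * per_device_batch_size)
# e.g. 256 // (16 * 16) = 1 for the registered OXE config (illustrative arithmetic; the actual
# computation happens inside the train strategy, not in this dataclass).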
113 | 114 | # fmt: on 115 | 116 | 117 | @draccus.wrap() 118 | def train(cfg: TrainConfig) -> None: 119 | overwatch.info("CogACT-VLA Training :: Warming Up") 120 | 121 | # Note => Under `torchrun` initializing `overwatch` will automatically set up `torch.distributed` 122 | torch.cuda.set_device(device_id := overwatch.local_rank()) 123 | torch.cuda.empty_cache() 124 | 125 | # Configure Unique Run Name & Save Directory 126 | vla_id = cfg.vla.vla_id 127 | cfg.run_id = ( 128 | f"{vla_id}+n{cfg.vla.expected_world_size // 8}+b{cfg.per_device_batch_size}+x{cfg.seed}" 129 | if cfg.run_id is None 130 | else cfg.run_id 131 | ) 132 | if cfg.run_id_note is not None: 133 | cfg.run_id += f"--{cfg.run_id_note}" 134 | if cfg.image_aug: 135 | cfg.run_id += "--image_aug" 136 | 137 | # Start =>> Build Directories and Set Randomness 138 | overwatch.info('"Do or do not; there is no try."', ctx_level=1) 139 | hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token] 140 | worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True) 141 | os.makedirs(run_dir := (cfg.run_root_dir / cfg.run_id), exist_ok=True) 142 | os.makedirs(cfg.run_root_dir / cfg.run_id / "checkpoints", exist_ok=True) 143 | 144 | # Save Configuration =>> additionally save a JSON version for later HF Integration 145 | if overwatch.is_rank_zero(): 146 | draccus.dump(cfg, open(run_dir / "config.yaml", "w")) 147 | with open(run_dir / "config.yaml", "r") as f_yaml, open(run_dir / "config.json", "w") as f_json: 148 | yaml_cfg = yaml.safe_load(f_yaml) 149 | json.dump(yaml_cfg, f_json, indent=2) 150 | 151 | dist.barrier() 152 | # Load VLA checkpoint (if resuming from training) or Base VLM otherwise (from `cfg.vla.base_vlm` ID or Path) 153 | # =>> Note :: Verifies that all parameters are loaded in FP32 on load! 154 | overwatch.info(f"Loading Base VLM `{cfg.vla.base_vlm}` from ID/Path") 155 | if cfg.pretrained_checkpoint is not None: 156 | # [Validate] Pretrained Checkpoint `step` and `epoch` should match `resume_step` and `resume_epoch` 157 | # =>> Note :: We make developers pass in `resume_*` arguments as an extra sanity check! 158 | if cfg.is_resume: 159 | assert int(re.search("step-(.+?)-", cfg.pretrained_checkpoint.name).group(1)) == cfg.resume_step 160 | assert int(re.search("epoch-(.+?)-", cfg.pretrained_checkpoint.name).group(1)) == cfg.resume_epoch 161 | overwatch.info("Loading VLA Checkpoint") 162 | if cfg.use_ema: 163 | overwatch.info("Loading EMA of Diffusion") 164 | vla = load_vla(cfg.pretrained_checkpoint, 165 | hf_token=hf_token, 166 | load_for_training=True, 167 | action_model_type=cfg.action_model_type, 168 | action_dim=cfg.action_dim, 169 | future_action_window_size=cfg.future_action_window_size, 170 | past_action_window_size=cfg.past_action_window_size, 171 | use_ema=cfg.use_ema, 172 | ) 173 | 174 | else: 175 | vlm = load(cfg.vla.base_vlm, hf_token=hf_token, load_for_training=True) 176 | overwatch.info("Creating VLA from Base VLM") 177 | if cfg.use_ema: 178 | overwatch.info("Creating EMA for Diffusion") 179 | vla = CogACT(vlm, 180 | action_model_type=cfg.action_model_type, 181 | action_dim=cfg.action_dim, 182 | future_action_window_size=cfg.future_action_window_size, 183 | past_action_window_size=cfg.past_action_window_size, 184 | use_ema=cfg.use_ema, 185 | ) 186 | # del this variable to avoid bugs. The vlm shouldn't be used anymore 187 | del vlm 188 | 189 | # [Validate] Model should be in Full Precision! 
190 | for param in vla.parameters(): 191 | assert param.dtype == torch.float32, f"Loaded VLM parameter not in full precision: {param}" 192 | 193 | # Determine training "stage" based on frozen vs unfrozen parameters --> supports different fine-tuning schemes! 194 | if not cfg.vla.freeze_vision_backbone and not cfg.vla.freeze_llm_backbone: 195 | stage = "full-finetune" # Full fine-tuning 196 | elif cfg.vla.freeze_vision_backbone and not cfg.vla.freeze_llm_backbone: 197 | stage = "finetune" # Frozen vision encoder 198 | elif cfg.vla.freeze_vision_backbone and cfg.vla.freeze_llm_backbone: 199 | stage = "align" # Fine-tuning projector 200 | elif not cfg.vla.freeze_vision_backbone and cfg.vla.freeze_llm_backbone and cfg.vla.unfreeze_last_llm_layer: 201 | stage = "vla-sandwich-train" # Fine-tuning vision encoder, projector, and LLM last layer 202 | elif cfg.vla.freeze_vision_backbone and cfg.vla.freeze_llm_backbone and cfg.vla.unfreeze_last_llm_layer: 203 | stage = "vla-last-layer-train" # Fine-tuning LLM last layer only 204 | else: 205 | raise ValueError( 206 | "Weight freezing configuration not supported. VLA config has the following parameters: " 207 | f"freeze_vision_backbone: {cfg.vla.freeze_vision_backbone}" 208 | f"freeze_llm_backbone: {cfg.vla.freeze_llm_backbone}" 209 | f"unfreeze_last_llm_layer: {cfg.vla.unfreeze_last_llm_layer}" 210 | ) 211 | 212 | # [Explicit] Call to `freeze_backbones` here for clarity =>> will log exactly what is/is not frozen 213 | overwatch.info(f"Invoking `VLM.freeze_backbones()` for `{vla_id}` => Stage: `{stage}`") 214 | vla.freeze_backbones(stage) 215 | 216 | # Print number of total/trainable model parameters 217 | num_params = sum(p.numel() for p in vla.parameters()) 218 | num_trainable_params = sum(p.numel() for p in vla.parameters() if p.requires_grad) 219 | overwatch.info( 220 | f"# Parameters (in millions): {num_params / 10**6:.3f} Total, {num_trainable_params / 10**6:.3f} Trainable" 221 | ) 222 | 223 | overwatch.info(f"Creating VLA Open-X Dataset with Mixture `{cfg.vla.data_mix}`") 224 | vla_dataset, _, collator = get_vla_dataset_and_collator( 225 | cfg.data_root_dir, 226 | cfg.vla.data_mix, 227 | image_transform=vla.vision_backbone.get_image_transform(), 228 | tokenizer=vla.llm_backbone.get_tokenizer(), 229 | prompt_builder_fn=vla.llm_backbone.prompt_builder_fn, 230 | default_image_resolution=vla.vision_backbone.default_image_resolution, 231 | shuffle_buffer_size=cfg.vla.shuffle_buffer_size, 232 | image_aug=cfg.image_aug, 233 | load_all_data_for_training=cfg.load_all_data_for_training, 234 | future_action_window_size=cfg.future_action_window_size, 235 | past_action_window_size=cfg.past_action_window_size, 236 | ) 237 | 238 | # Save dataset statistics for de-normalization at inference time 239 | if overwatch.is_rank_zero(): 240 | save_dataset_statistics(vla_dataset.dataset_statistics, run_dir) 241 | 242 | dist.barrier() 243 | # Create Train Strategy 244 | overwatch.info(f"Initializing Train Strategy `{cfg.train_strategy}`") 245 | train_strategy = get_train_strategy( 246 | train_strategy=cfg.train_strategy, 247 | vlm=vla, 248 | device_id=device_id, 249 | stage=stage, 250 | epochs=cfg.epochs, 251 | max_steps=cfg.max_steps, 252 | global_batch_size=cfg.global_batch_size, 253 | per_device_batch_size=cfg.per_device_batch_size, 254 | learning_rate=cfg.learning_rate, 255 | weight_decay=cfg.weight_decay, 256 | max_grad_norm=cfg.max_grad_norm, 257 | lr_scheduler_type=cfg.lr_scheduler_type, 258 | warmup_ratio=cfg.warmup_ratio, 259 | 
enable_gradient_checkpointing=cfg.vla.enable_gradient_checkpointing, 260 | enable_mixed_precision_training=cfg.vla.enable_mixed_precision_training, 261 | reduce_in_full_precision=cfg.vla.reduce_in_full_precision, 262 | worker_init_fn=worker_init_fn, 263 | ) 264 | train_strategy.run_setup(run_dir=run_dir, n_train_examples=len(vla_dataset)) 265 | if cfg.pretrained_checkpoint is not None and cfg.is_resume: 266 | train_strategy.load_optimizer_and_scheduler(cfg.pretrained_checkpoint) 267 | 268 | # Create Metrics =>> Handles on the fly tracking, logging to specified trackers (e.g., JSONL, Weights & Biases) 269 | overwatch.info(f"Creating Metrics with Active Trackers => `{cfg.trackers}`") 270 | metrics = VLAMetrics( 271 | cfg.trackers, 272 | cfg.run_id, 273 | run_dir, 274 | draccus.encode(cfg), 275 | wandb_project=cfg.wandb_project, 276 | wandb_entity=cfg.wandb_entity, 277 | resume_step=cfg.resume_step, 278 | resume_epoch=cfg.resume_epoch, 279 | ) 280 | 281 | # Run VLA Training 282 | overwatch.info("Starting VLA Training Loop") 283 | train_strategy.run_vla_training( 284 | vla_dataset, 285 | collator, 286 | metrics, 287 | save_interval=cfg.save_interval, 288 | action_model=True, 289 | ) 290 | 291 | # Finalize 292 | overwatch.info("Done with Training =>> Finalizing Metrics") 293 | metrics.finalize() 294 | 295 | # And... we're done! 296 | overwatch.info("... and that's all, folks!") 297 | dist.barrier() 298 | dist.destroy_process_group() 299 | 300 | 301 | if __name__ == "__main__": 302 | train() 303 | -------------------------------------------------------------------------------- /sim_cogact/__init__.py: -------------------------------------------------------------------------------- 1 | from .cogact_policy import CogACTInference 2 | from .adaptive_ensemble import AdaptiveEnsembler -------------------------------------------------------------------------------- /sim_cogact/adaptive_ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | adaptive_ensemble.py 3 | 4 | """ 5 | 6 | from collections import deque 7 | import numpy as np 8 | 9 | 10 | class AdaptiveEnsembler: 11 | def __init__(self, pred_action_horizon, adaptive_ensemble_alpha=0.0): 12 | self.pred_action_horizon = pred_action_horizon 13 | self.action_history = deque(maxlen=self.pred_action_horizon) 14 | self.adaptive_ensemble_alpha = adaptive_ensemble_alpha 15 | 16 | def reset(self): 17 | self.action_history.clear() 18 | 19 | def ensemble_action(self, cur_action): 20 | self.action_history.append(cur_action) 21 | num_actions = len(self.action_history) 22 | if cur_action.ndim == 1: 23 | curr_act_preds = np.stack(self.action_history) 24 | else: 25 | curr_act_preds = np.stack( 26 | [pred_actions[i] for (i, pred_actions) in zip(range(num_actions - 1, -1, -1), self.action_history)] 27 | ) 28 | 29 | # calculate cosine similarity between the current prediction and all previous predictions 30 | ref = curr_act_preds[num_actions-1, :] 31 | previous_pred = curr_act_preds 32 | dot_product = np.sum(previous_pred * ref, axis=1) 33 | norm_previous_pred = np.linalg.norm(previous_pred, axis=1) 34 | norm_ref = np.linalg.norm(ref) 35 | cos_similarity = dot_product / (norm_previous_pred * norm_ref + 1e-7) 36 | 37 | # compute the weights for each prediction 38 | weights = np.exp(self.adaptive_ensemble_alpha * cos_similarity) 39 | weights = weights / weights.sum() 40 | 41 | # compute the weighted average across all predictions for this timestep 42 | cur_action = np.sum(weights[:, None] * curr_act_preds, axis=0) 43 
| 44 | return cur_action -------------------------------------------------------------------------------- /sim_cogact/cogact_policy.py: -------------------------------------------------------------------------------- 1 | """ 2 | cogact_policy.py 3 | 4 | """ 5 | from collections import deque 6 | from typing import Optional, Sequence 7 | import os 8 | from PIL import Image 9 | import torch 10 | import cv2 as cv 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | from transforms3d.euler import euler2axangle 15 | from transformers import AutoModelForVision2Seq, AutoProcessor 16 | 17 | 18 | from vla import CogACT 19 | from vla import load_vla 20 | from sim_cogact.adaptive_ensemble import AdaptiveEnsembler 21 | 22 | class CogACTInference: 23 | def __init__( 24 | self, 25 | saved_model_path: str = 'CogACT/CogACT-Base', 26 | unnorm_key: Optional[str] = None, 27 | policy_setup: str = "widowx_bridge", 28 | horizon: int = 0, 29 | action_ensemble_horizon: Optional[int] = None, 30 | image_size: list[int] = [224, 224], 31 | future_action_window_size: int = 15, 32 | action_dim: int = 7, 33 | action_model_type: str = "DiT-B", 34 | action_scale: float = 1.0, 35 | cfg_scale: float = 1.5, 36 | use_ddim: bool = True, 37 | num_ddim_steps: int = 10, 38 | use_bf16: bool = False, 39 | action_ensemble = True, 40 | adaptive_ensemble_alpha = 0.1, 41 | ) -> None: 42 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 43 | if policy_setup == "widowx_bridge": 44 | unnorm_key = "bridge_orig" if unnorm_key is None else unnorm_key 45 | action_ensemble = action_ensemble 46 | adaptive_ensemble_alpha = adaptive_ensemble_alpha 47 | if action_ensemble_horizon is None: 48 | # Set 7 for widowx_bridge to fix the window size of motion scale between each frame. see appendix in our paper for details 49 | action_ensemble_horizon = 7 50 | self.sticky_gripper_num_repeat = 1 51 | elif policy_setup == "google_robot": 52 | unnorm_key = "fractal20220817_data" if unnorm_key is None else unnorm_key 53 | action_ensemble = action_ensemble 54 | adaptive_ensemble_alpha = adaptive_ensemble_alpha 55 | if action_ensemble_horizon is None: 56 | # Set 2 for google_robot to fix the window size of motion scale between each frame. see appendix in our paper for details 57 | action_ensemble_horizon = 2 58 | self.sticky_gripper_num_repeat = 10 59 | else: 60 | raise NotImplementedError( 61 | f"Policy setup {policy_setup} not supported for octo models. The other datasets can be found in the huggingface config.json file." 
62 | ) 63 | self.policy_setup = policy_setup 64 | self.unnorm_key = unnorm_key 65 | 66 | print(f"*** policy_setup: {policy_setup}, unnorm_key: {unnorm_key} ***") 67 | self.use_ddim = use_ddim 68 | self.num_ddim_steps = num_ddim_steps 69 | self.vla = load_vla( 70 | saved_model_path, # choose from ['CogACT/CogACT-Small', 'CogACT/CogACT-Base', 'CogACT/CogACT-Large'] or the local path 71 | load_for_training=False, 72 | action_model_type=action_model_type, # choose from ['DiT-Small', 'DiT-Base', 'DiT-Large'] to match the model weight 73 | future_action_window_size=future_action_window_size, 74 | action_dim=action_dim, 75 | ) 76 | 77 | if use_bf16: 78 | self.vla.vlm = self.vla.vlm.to(torch.bfloat16) 79 | self.vla = self.vla.to("cuda").eval() 80 | self.cfg_scale = cfg_scale 81 | 82 | self.image_size = image_size 83 | self.action_scale = action_scale 84 | self.horizon = horizon 85 | self.action_ensemble = action_ensemble 86 | self.adaptive_ensemble_alpha = adaptive_ensemble_alpha 87 | self.action_ensemble_horizon = action_ensemble_horizon 88 | self.sticky_action_is_on = False 89 | self.gripper_action_repeat = 0 90 | self.sticky_gripper_action = 0.0 91 | self.previous_gripper_action = None 92 | 93 | self.task_description = None 94 | self.image_history = deque(maxlen=self.horizon) 95 | if self.action_ensemble: 96 | self.action_ensembler = AdaptiveEnsembler(self.action_ensemble_horizon, self.adaptive_ensemble_alpha) 97 | else: 98 | self.action_ensembler = None 99 | self.num_image_history = 0 100 | 101 | def _add_image_to_history(self, image: np.ndarray) -> None: 102 | self.image_history.append(image) 103 | self.num_image_history = min(self.num_image_history + 1, self.horizon) 104 | 105 | def reset(self, task_description: str) -> None: 106 | self.task_description = task_description 107 | self.image_history.clear() 108 | if self.action_ensemble: 109 | self.action_ensembler.reset() 110 | self.num_image_history = 0 111 | 112 | self.sticky_action_is_on = False 113 | self.gripper_action_repeat = 0 114 | self.sticky_gripper_action = 0.0 115 | self.previous_gripper_action = None 116 | 117 | def step( 118 | self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs 119 | ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: 120 | """ 121 | Input: 122 | image: np.ndarray of shape (H, W, 3), uint8 123 | task_description: Optional[str], task description; if different from previous task description, policy state is reset 124 | Output: 125 | raw_action: dict; raw policy action output 126 | action: dict; processed action to be sent to the maniskill2 environment, with the following keys: 127 | - 'world_vector': np.ndarray of shape (3,), xyz translation of robot end-effector 128 | - 'rot_axangle': np.ndarray of shape (3,), axis-angle representation of end-effector rotation 129 | - 'gripper': np.ndarray of shape (1,), gripper action 130 | - 'terminate_episode': np.ndarray of shape (1,), 1 if episode should be terminated, 0 otherwise 131 | """ 132 | if task_description is not None: 133 | if task_description != self.task_description: 134 | self.reset(task_description) 135 | 136 | assert image.dtype == np.uint8 137 | self._add_image_to_history(self._resize_image(image)) 138 | image: Image.Image = Image.fromarray(image) 139 | raw_actions, normalized_actions = self.vla.predict_action(image=image, 140 | instruction=self.task_description, 141 | unnorm_key=self.unnorm_key, 142 | do_sample=False, 143 | cfg_scale=self.cfg_scale, 144 | use_ddim=self.use_ddim, 145 | num_ddim_steps=self.num_ddim_steps, 
146 | ) 147 | 148 | if self.action_ensemble: 149 | raw_actions = self.action_ensembler.ensemble_action(raw_actions)[None] 150 | raw_action = { 151 | "world_vector": np.array(raw_actions[0, :3]), 152 | "rotation_delta": np.array(raw_actions[0, 3:6]), 153 | "open_gripper": np.array(raw_actions[0, 6:7]), # range [0, 1]; 1 = open; 0 = close 154 | } 155 | 156 | # process raw_action to obtain the action to be sent to the maniskill2 environment 157 | action = {} 158 | action["world_vector"] = raw_action["world_vector"] * self.action_scale 159 | action_rotation_delta = np.asarray(raw_action["rotation_delta"], dtype=np.float64) 160 | 161 | roll, pitch, yaw = action_rotation_delta 162 | axes, angles = euler2axangle(roll, pitch, yaw) 163 | action_rotation_axangle = axes * angles 164 | action["rot_axangle"] = action_rotation_axangle * self.action_scale 165 | 166 | if self.policy_setup == "google_robot": 167 | action["gripper"] = 0 168 | current_gripper_action = raw_action["open_gripper"] 169 | if self.previous_gripper_action is None: 170 | relative_gripper_action = np.array([0]) 171 | self.previous_gripper_action = current_gripper_action 172 | else: 173 | relative_gripper_action = self.previous_gripper_action - current_gripper_action 174 | # fix a bug in the SIMPLER code here 175 | # self.previous_gripper_action = current_gripper_action 176 | 177 | if np.abs(relative_gripper_action) > 0.5 and (not self.sticky_action_is_on): 178 | self.sticky_action_is_on = True 179 | self.sticky_gripper_action = relative_gripper_action 180 | self.previous_gripper_action = current_gripper_action 181 | 182 | if self.sticky_action_is_on: 183 | self.gripper_action_repeat += 1 184 | relative_gripper_action = self.sticky_gripper_action 185 | 186 | if self.gripper_action_repeat == self.sticky_gripper_num_repeat: 187 | self.sticky_action_is_on = False 188 | self.gripper_action_repeat = 0 189 | self.sticky_gripper_action = 0.0 190 | 191 | action["gripper"] = relative_gripper_action 192 | 193 | elif self.policy_setup == "widowx_bridge": 194 | action["gripper"] = 2.0 * (raw_action["open_gripper"] > 0.5) - 1.0 195 | 196 | action["terminate_episode"] = np.array([0.0]) 197 | return raw_action, action 198 | 199 | def _resize_image(self, image: np.ndarray) -> np.ndarray: 200 | image = cv.resize(image, tuple(self.image_size), interpolation=cv.INTER_AREA) 201 | return image 202 | 203 | def visualize_epoch( 204 | self, predicted_raw_actions: Sequence[np.ndarray], images: Sequence[np.ndarray], save_path: str 205 | ) -> None: 206 | images = [self._resize_image(image) for image in images] 207 | ACTION_DIM_LABELS = ["x", "y", "z", "roll", "pitch", "yaw", "grasp"] 208 | 209 | img_strip = np.concatenate(np.array(images[::3]), axis=1) 210 | 211 | # set up plt figure 212 | figure_layout = [["image"] * len(ACTION_DIM_LABELS), ACTION_DIM_LABELS] 213 | plt.rcParams.update({"font.size": 12}) 214 | fig, axs = plt.subplot_mosaic(figure_layout) 215 | fig.set_size_inches([45, 10]) 216 | 217 | # plot actions 218 | pred_actions = np.array( 219 | [ 220 | np.concatenate([a["world_vector"], a["rotation_delta"], a["open_gripper"]], axis=-1) 221 | for a in predicted_raw_actions 222 | ] 223 | ) 224 | for action_dim, action_label in enumerate(ACTION_DIM_LABELS): 225 | # actions have batch, horizon, dim, in this example we just take the first action for simplicity 226 | axs[action_label].plot(pred_actions[:, action_dim], label="predicted action") 227 | axs[action_label].set_title(action_label) 228 | axs[action_label].set_xlabel("Time in one episode") 229 | 230 | 
axs["image"].imshow(img_strip) 231 | axs["image"].set_xlabel("Time in one episode (subsampled)") 232 | plt.legend() 233 | plt.savefig(save_path) -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_bridge.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 2 | policy_model=cogact 3 | ckpt_path=CogACT/CogACT-Base # CogACT/CogACT-Large CogACT/CogACT-Small 4 | 5 | scene_name=bridge_table_1_v1 6 | robot=widowx 7 | rgb_overlay_path=ManiSkill2_real2sim/data/real_inpainting/bridge_real_eval_1.png 8 | robot_init_x=0.147 9 | robot_init_y=0.028 10 | 11 | 12 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 13 | --robot ${robot} --policy-setup widowx_bridge \ 14 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 15 | --env-name StackGreenCubeOnYellowCubeBakedTexInScene-v0 --scene-name ${scene_name} \ 16 | --rgb-overlay-path ${rgb_overlay_path} \ 17 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 18 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1; 19 | 20 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 21 | --robot ${robot} --policy-setup widowx_bridge \ 22 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 23 | --env-name PutCarrotOnPlateInScene-v0 --scene-name ${scene_name} \ 24 | --rgb-overlay-path ${rgb_overlay_path} \ 25 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 26 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1; 27 | 28 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 29 | --robot ${robot} --policy-setup widowx_bridge \ 30 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 31 | --env-name PutSpoonOnTableClothInScene-v0 --scene-name ${scene_name} \ 32 | --rgb-overlay-path ${rgb_overlay_path} \ 33 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 34 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1; 35 | 36 | 37 | scene_name=bridge_table_1_v2 38 | robot=widowx_sink_camera_setup 39 | rgb_overlay_path=ManiSkill2_real2sim/data/real_inpainting/bridge_sink.png 40 | robot_init_x=0.127 41 | robot_init_y=0.06 42 | 43 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 44 | --robot ${robot} --policy-setup widowx_bridge \ 45 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 46 | --env-name PutEggplantInBasketScene-v0 --scene-name ${scene_name} \ 47 | --rgb-overlay-path ${rgb_overlay_path} \ 48 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 49 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1; -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_drawer_variant_agg.sh: 
-------------------------------------------------------------------------------- 1 | # shader_dir=rt means that we turn on ray-tracing rendering; this is quite crucial for the open / close drawer task as policies often rely on shadows to infer depth 2 | gpu_id=0 3 | 4 | declare -a ckpt_paths=( 5 | "CogACT/CogACT-Base" 6 | ) 7 | # CogACT/CogACT-Large CogACT/CogACT-Small 8 | declare -a env_names=( 9 | OpenTopDrawerCustomInScene-v0 10 | OpenMiddleDrawerCustomInScene-v0 11 | OpenBottomDrawerCustomInScene-v0 12 | CloseTopDrawerCustomInScene-v0 13 | CloseMiddleDrawerCustomInScene-v0 14 | CloseBottomDrawerCustomInScene-v0 15 | ) 16 | 17 | EXTRA_ARGS="--enable-raytracing" 18 | 19 | 20 | # base setup 21 | scene_name=frl_apartment_stage_simple 22 | 23 | EvalSim() { 24 | echo ${ckpt_path} ${env_name} 25 | 26 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 27 | --robot google_robot_static \ 28 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 29 | --env-name ${env_name} --scene-name ${scene_name} \ 30 | --robot-init-x 0.65 0.85 3 --robot-init-y -0.2 0.2 3 \ 31 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0.0 0.0 1 \ 32 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 33 | ${EXTRA_ARGS} 34 | } 35 | 36 | 37 | for ckpt_path in "${ckpt_paths[@]}"; do 38 | for env_name in "${env_names[@]}"; do 39 | EvalSim 40 | done 41 | done 42 | 43 | 44 | # backgrounds 45 | 46 | declare -a scene_names=( 47 | "modern_bedroom_no_roof" 48 | "modern_office_no_roof" 49 | ) 50 | 51 | for scene_name in "${scene_names[@]}"; do 52 | for ckpt_path in "${ckpt_paths[@]}"; do 53 | for env_name in "${env_names[@]}"; do 54 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt" 55 | EvalSim 56 | done 57 | done 58 | done 59 | 60 | 61 | # lightings 62 | scene_name=frl_apartment_stage_simple 63 | 64 | for ckpt_path in "${ckpt_paths[@]}"; do 65 | for env_name in "${env_names[@]}"; do 66 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt light_mode=brighter" 67 | EvalSim 68 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt light_mode=darker" 69 | EvalSim 70 | done 71 | done 72 | 73 | 74 | # new cabinets 75 | scene_name=frl_apartment_stage_simple 76 | 77 | for ckpt_path in "${ckpt_paths[@]}"; do 78 | for env_name in "${env_names[@]}"; do 79 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt station_name=mk_station2" 80 | EvalSim 81 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt station_name=mk_station3" 82 | EvalSim 83 | done 84 | done 85 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_drawer_visual_matching.sh: -------------------------------------------------------------------------------- 1 | # shader_dir=rt means that we turn on ray-tracing rendering; this is quite crucial for the open / close drawer task as policies often rely on shadows to infer depth 2 | gpu_id=0 3 | 4 | declare -a ckpt_paths=( 5 | "CogACT/CogACT-Base" 6 | ) 7 | 8 | declare -a env_names=( 9 | OpenTopDrawerCustomInScene-v0 10 | OpenMiddleDrawerCustomInScene-v0 11 | OpenBottomDrawerCustomInScene-v0 12 | CloseTopDrawerCustomInScene-v0 13 | CloseMiddleDrawerCustomInScene-v0 14 | CloseBottomDrawerCustomInScene-v0 15 | ) 16 | 17 | # URDF variations 18 | declare -a urdf_version_arr=("recolor_cabinet_visual_matching_1" "recolor_tabletop_visual_matching_1" "recolor_tabletop_visual_matching_2" None) 19 | 20 | for urdf_version in "${urdf_version_arr[@]}"; do 21 | 22 
| EXTRA_ARGS="--enable-raytracing --additional-env-build-kwargs station_name=mk_station_recolor light_mode=simple disable_bad_material=True urdf_version=${urdf_version}" 23 | 24 | EvalOverlay() { 25 | # A0 26 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 27 | --robot google_robot_static \ 28 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 29 | --env-name ${env_name} --scene-name dummy_drawer \ 30 | --robot-init-x 0.644 0.644 1 --robot-init-y -0.179 -0.179 1 \ 31 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.03 -0.03 1 \ 32 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 33 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a0.png \ 34 | ${EXTRA_ARGS} 35 | 36 | # A1 37 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 38 | --robot google_robot_static \ 39 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 40 | --env-name ${env_name} --scene-name dummy_drawer \ 41 | --robot-init-x 0.765 0.765 1 --robot-init-y -0.182 -0.182 1 \ 42 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.02 -0.02 1 \ 43 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 44 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a1.png \ 45 | ${EXTRA_ARGS} 46 | 47 | # A2 48 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 49 | --robot google_robot_static \ 50 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 51 | --env-name ${env_name} --scene-name dummy_drawer \ 52 | --robot-init-x 0.889 0.889 1 --robot-init-y -0.203 -0.203 1 \ 53 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.06 -0.06 1 \ 54 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 55 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a2.png \ 56 | ${EXTRA_ARGS} 57 | 58 | # B0 59 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 60 | --robot google_robot_static \ 61 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 62 | --env-name ${env_name} --scene-name dummy_drawer \ 63 | --robot-init-x 0.652 0.652 1 --robot-init-y 0.009 0.009 1 \ 64 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 65 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 66 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b0.png \ 67 | ${EXTRA_ARGS} 68 | 69 | # B1 70 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 71 | --robot google_robot_static \ 72 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 73 | --env-name ${env_name} --scene-name dummy_drawer \ 74 | --robot-init-x 0.752 0.752 1 --robot-init-y 0.009 0.009 1 \ 75 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 76 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 77 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b1.png \ 78 | ${EXTRA_ARGS} 79 | 80 | # B2 81 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 82 | --robot google_robot_static \ 83 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 84 | --env-name ${env_name} --scene-name dummy_drawer 
\ 85 | --robot-init-x 0.851 0.851 1 --robot-init-y 0.035 0.035 1 \ 86 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 87 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 88 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b2.png \ 89 | ${EXTRA_ARGS} 90 | 91 | # C0 92 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 93 | --robot google_robot_static \ 94 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 95 | --env-name ${env_name} --scene-name dummy_drawer \ 96 | --robot-init-x 0.665 0.665 1 --robot-init-y 0.224 0.224 1 \ 97 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 98 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 99 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c0.png \ 100 | ${EXTRA_ARGS} 101 | 102 | # C1 103 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 104 | --robot google_robot_static \ 105 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 106 | --env-name ${env_name} --scene-name dummy_drawer \ 107 | --robot-init-x 0.765 0.765 1 --robot-init-y 0.222 0.222 1 \ 108 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.025 -0.025 1 \ 109 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 110 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c1.png \ 111 | ${EXTRA_ARGS} 112 | 113 | # C2 114 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 115 | --robot google_robot_static \ 116 | --control-freq 3 --sim-freq 513 --max-episode-steps 113 \ 117 | --env-name ${env_name} --scene-name dummy_drawer \ 118 | --robot-init-x 0.865 0.865 1 --robot-init-y 0.222 0.222 1 \ 119 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.025 -0.025 1 \ 120 | --obj-init-x-range 0 0 1 --obj-init-y-range 0 0 1 \ 121 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c2.png \ 122 | ${EXTRA_ARGS} 123 | } 124 | 125 | 126 | for ckpt_path in "${ckpt_paths[@]}"; do 127 | for env_name in "${env_names[@]}"; do 128 | EvalOverlay 129 | done 130 | done 131 | 132 | 133 | 134 | done 135 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_move_near_variant_agg.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 2 | 3 | declare -a arr=("CogACT/CogACT-Base") 4 | # CogACT/CogACT-Large CogACT/CogACT-Small 5 | for ckpt_path in "${arr[@]}"; do echo "$ckpt_path"; done 6 | 7 | 8 | # base setup 9 | 10 | env_name=MoveNearGoogleInScene-v0 11 | scene_name=google_pick_coke_can_1_v4 12 | 13 | for ckpt_path in "${arr[@]}"; 14 | 15 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 16 | --robot google_robot_static \ 17 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 18 | --env-name ${env_name} --scene-name ${scene_name} \ 19 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 20 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1; 21 | 22 | done 23 | 24 | 25 | 26 | # distractor 27 | 28 | for ckpt_path in "${arr[@]}"; 29 | 30 | do CUDA_VISIBLE_DEVICES=${gpu_id} python 
simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 31 | --robot google_robot_static \ 32 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 33 | --env-name ${env_name} --scene-name ${scene_name} \ 34 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 35 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1 \ 36 | --additional-env-build-kwargs no_distractor=True; 37 | 38 | done 39 | 40 | 41 | # backgrounds 42 | 43 | env_name=MoveNearGoogleInScene-v0 44 | declare -a scene_arr=("google_pick_coke_can_1_v4_alt_background" \ 45 | "google_pick_coke_can_1_v4_alt_background_2") 46 | 47 | for scene_name in "${scene_arr[@]}"; 48 | 49 | do for ckpt_path in "${arr[@]}"; 50 | 51 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 52 | --robot google_robot_static \ 53 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 54 | --env-name ${env_name} --scene-name ${scene_name} \ 55 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 56 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1; 57 | 58 | done 59 | 60 | done 61 | 62 | 63 | 64 | 65 | 66 | # lighting 67 | 68 | env_name=MoveNearGoogleInScene-v0 69 | scene_name=google_pick_coke_can_1_v4 70 | 71 | for ckpt_path in "${arr[@]}"; 72 | 73 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 74 | --robot google_robot_static \ 75 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 76 | --env-name ${env_name} --scene-name ${scene_name} \ 77 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 78 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1 \ 79 | --additional-env-build-kwargs slightly_darker_lighting=True; 80 | 81 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 82 | --robot google_robot_static \ 83 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 84 | --env-name ${env_name} --scene-name ${scene_name} \ 85 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 86 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1 \ 87 | --additional-env-build-kwargs slightly_brighter_lighting=True; 88 | 89 | done 90 | 91 | 92 | 93 | 94 | 95 | # table textures 96 | 97 | env_name=MoveNearGoogleInScene-v0 98 | declare -a scene_arr=("Baked_sc1_staging_objaverse_cabinet1_h870" \ 99 | "Baked_sc1_staging_objaverse_cabinet2_h870") 100 | 101 | for scene_name in "${scene_arr[@]}"; 102 | 103 | do for ckpt_path in "${arr[@]}"; 104 | 105 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 106 | --robot google_robot_static \ 107 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 108 | --env-name ${env_name} --scene-name ${scene_name} \ 109 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 110 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1; 111 | 112 | done 113 | 114 | done 115 | 116 | 117 | 118 | 119 | # camera orientations 120 | 121 | declare -a 
env_arr=("MoveNearAltGoogleCameraInScene-v0" \ 122 | "MoveNearAltGoogleCamera2InScene-v0") 123 | scene_name=google_pick_coke_can_1_v4 124 | 125 | for env_name in "${env_arr[@]}"; 126 | 127 | do for ckpt_path in "${arr[@]}"; 128 | 129 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 130 | --robot google_robot_static \ 131 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 132 | --env-name ${env_name} --scene-name ${scene_name} \ 133 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 134 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1; 135 | 136 | done 137 | 138 | done 139 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_move_near_visual_matching.sh: -------------------------------------------------------------------------------- 1 | 2 | gpu_id=0 3 | 4 | declare -a arr=("CogACT/CogACT-Base") 5 | 6 | env_name=MoveNearGoogleBakedTexInScene-v0 7 | # env_name=MoveNearGoogleBakedTexInScene-v1 8 | scene_name=google_pick_coke_can_1_v4 9 | rgb_overlay_path=./ManiSkill2_real2sim/data/real_inpainting/google_move_near_real_eval_1.png 10 | 11 | # URDF variations 12 | declare -a urdf_version_arr=(None "recolor_tabletop_visual_matching_1" "recolor_tabletop_visual_matching_2" "recolor_cabinet_visual_matching_1") 13 | 14 | for ckpt_path in "${arr[@]}"; do echo "$ckpt_path"; done 15 | 16 | 17 | for urdf_version in "${urdf_version_arr[@]}"; 18 | 19 | do for ckpt_path in "${arr[@]}"; 20 | 21 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 22 | --robot google_robot_static \ 23 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 24 | --env-name ${env_name} --scene-name ${scene_name} \ 25 | --rgb-overlay-path ${rgb_overlay_path} \ 26 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \ 27 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1 \ 28 | --additional-env-build-kwargs urdf_version=${urdf_version} \ 29 | --additional-env-save-tags baked_except_bpb_orange; 30 | 31 | done 32 | 33 | done 34 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_pick_coke_can_variant_agg.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 2 | 3 | declare -a arr=("CogACT/CogACT-Base") 4 | 5 | # lr_switch=laying horizontally but flipped left-right to match real eval; upright=standing; laid_vertically=laying vertically 6 | declare -a coke_can_options_arr=("lr_switch=True" "upright=True" "laid_vertically=True") 7 | 8 | for ckpt_path in "${arr[@]}"; do echo "$ckpt_path"; done 9 | 10 | 11 | # base setup 12 | 13 | env_name=GraspSingleOpenedCokeCanInScene-v0 14 | scene_name=google_pick_coke_can_1_v4 15 | 16 | for coke_can_option in "${coke_can_options_arr[@]}"; 17 | 18 | do for ckpt_path in "${arr[@]}"; 19 | 20 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 21 | --robot google_robot_static \ 22 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 23 | --env-name ${env_name} --scene-name ${scene_name} \ 24 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 25 | 
--robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 26 | --additional-env-build-kwargs ${coke_can_option}; 27 | 28 | done 29 | 30 | done 31 | 32 | 33 | 34 | # table textures 35 | 36 | env_name=GraspSingleOpenedCokeCanInScene-v0 37 | 38 | declare -a scene_arr=("Baked_sc1_staging_objaverse_cabinet1_h870" \ 39 | "Baked_sc1_staging_objaverse_cabinet2_h870") 40 | 41 | 42 | for coke_can_option in "${coke_can_options_arr[@]}"; 43 | 44 | do for scene_name in "${scene_arr[@]}"; 45 | 46 | do for ckpt_path in "${arr[@]}"; 47 | 48 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 49 | --robot google_robot_static \ 50 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 51 | --env-name ${env_name} --scene-name ${scene_name} \ 52 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 53 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 54 | --additional-env-build-kwargs ${coke_can_option}; 55 | 56 | done 57 | 58 | done 59 | 60 | done 61 | 62 | 63 | 64 | 65 | # distractors 66 | 67 | env_name=GraspSingleOpenedCokeCanDistractorInScene-v0 68 | scene_name=google_pick_coke_can_1_v4 69 | 70 | for coke_can_option in "${coke_can_options_arr[@]}"; 71 | 72 | do for ckpt_path in "${arr[@]}"; 73 | 74 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 75 | --robot google_robot_static \ 76 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 77 | --env-name ${env_name} --scene-name ${scene_name} \ 78 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 79 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 80 | --additional-env-build-kwargs ${coke_can_option}; 81 | 82 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 83 | --robot google_robot_static \ 84 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 85 | --env-name ${env_name} --scene-name ${scene_name} \ 86 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 87 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 88 | --additional-env-build-kwargs ${coke_can_option} distractor_config=more; 89 | 90 | done 91 | 92 | done 93 | 94 | 95 | 96 | 97 | # backgrounds 98 | 99 | env_name=GraspSingleOpenedCokeCanInScene-v0 100 | declare -a scene_arr=("google_pick_coke_can_1_v4_alt_background" \ 101 | "google_pick_coke_can_1_v4_alt_background_2") 102 | 103 | for coke_can_option in "${coke_can_options_arr[@]}"; 104 | 105 | do for scene_name in "${scene_arr[@]}"; 106 | 107 | do for ckpt_path in "${arr[@]}"; 108 | 109 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 110 | --robot google_robot_static \ 111 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 112 | --env-name ${env_name} --scene-name ${scene_name} \ 113 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 114 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 115 | --additional-env-build-kwargs ${coke_can_option}; 116 | 117 | done 118 | 119 | done 120 | 121 | done 122 | 123 | 124 | 125 | # lightings 126 | 127 | 
env_name=GraspSingleOpenedCokeCanInScene-v0 128 | scene_name=google_pick_coke_can_1_v4 129 | 130 | for coke_can_option in "${coke_can_options_arr[@]}"; 131 | 132 | do for ckpt_path in "${arr[@]}"; 133 | 134 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 135 | --robot google_robot_static \ 136 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 137 | --env-name ${env_name} --scene-name ${scene_name} \ 138 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 139 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 140 | --additional-env-build-kwargs ${coke_can_option} slightly_darker_lighting=True; 141 | 142 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 143 | --robot google_robot_static \ 144 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 145 | --env-name ${env_name} --scene-name ${scene_name} \ 146 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 147 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 148 | --additional-env-build-kwargs ${coke_can_option} slightly_brighter_lighting=True; 149 | 150 | done 151 | 152 | done 153 | 154 | 155 | 156 | 157 | # camera orientations 158 | 159 | declare -a env_arr=("GraspSingleOpenedCokeCanAltGoogleCameraInScene-v0" \ 160 | "GraspSingleOpenedCokeCanAltGoogleCamera2InScene-v0") 161 | scene_name=google_pick_coke_can_1_v4 162 | 163 | for coke_can_option in "${coke_can_options_arr[@]}"; 164 | 165 | do for env_name in "${env_arr[@]}"; 166 | 167 | do for ckpt_path in "${arr[@]}"; 168 | 169 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 170 | --robot google_robot_static \ 171 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 172 | --env-name ${env_name} --scene-name ${scene_name} \ 173 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 174 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 175 | --additional-env-build-kwargs ${coke_can_option}; 176 | 177 | done 178 | 179 | done 180 | 181 | done 182 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_pick_coke_can_visual_matching.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 2 | 3 | declare -a arr=("CogACT/CogACT-Base") 4 | 5 | # lr_switch=laying horizontally but flipped left-right to match real eval; upright=standing; laid_vertically=laying vertically 6 | declare -a coke_can_options_arr=("lr_switch=True" "upright=True" "laid_vertically=True") 7 | 8 | # URDF variations 9 | declare -a urdf_version_arr=(None "recolor_tabletop_visual_matching_1" "recolor_tabletop_visual_matching_2" "recolor_cabinet_visual_matching_1") 10 | 11 | env_name=GraspSingleOpenedCokeCanInScene-v0 12 | scene_name=google_pick_coke_can_1_v4 13 | rgb_overlay_path=./ManiSkill2_real2sim/data/real_inpainting/google_coke_can_real_eval_1.png 14 | 15 | for ckpt_path in "${arr[@]}"; do echo "$ckpt_path"; done 16 | 17 | 18 | 19 | for urdf_version in "${urdf_version_arr[@]}"; 20 | 21 | do for coke_can_option in "${coke_can_options_arr[@]}"; 22 | 23 | do for ckpt_path in "${arr[@]}"; 24 | 25 | do 
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 26 | --robot google_robot_static \ 27 | --control-freq 3 --sim-freq 513 --max-episode-steps 80 \ 28 | --env-name ${env_name} --scene-name ${scene_name} \ 29 | --rgb-overlay-path ${rgb_overlay_path} \ 30 | --robot-init-x 0.35 0.35 1 --robot-init-y 0.20 0.20 1 --obj-init-x -0.35 -0.12 5 --obj-init-y -0.02 0.42 5 \ 31 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 32 | --additional-env-build-kwargs ${coke_can_option} urdf_version=${urdf_version}; 33 | 34 | done 35 | 36 | done 37 | 38 | done 39 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_put_in_drawer_variant_agg.sh: -------------------------------------------------------------------------------- 1 | # shader_dir=rt means that we turn on ray-tracing rendering; this is quite crucial for the open / close drawer task as policies often rely on shadows to infer depth 2 | gpu_id=0 3 | 4 | 5 | declare -a ckpt_paths=( 6 | "CogACT/CogACT-Base" 7 | ) 8 | 9 | declare -a env_names=( 10 | PlaceIntoClosedTopDrawerCustomInScene-v0 11 | ) 12 | 13 | EXTRA_ARGS="--enable-raytracing --additional-env-build-kwargs model_ids=apple" 14 | 15 | 16 | # base setup 17 | scene_name=frl_apartment_stage_simple 18 | 19 | EvalSim() { 20 | echo ${ckpt_path} ${env_name} 21 | 22 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 23 | --robot google_robot_static \ 24 | --control-freq 3 --sim-freq 513 --max-episode-steps 200 \ 25 | --env-name ${env_name} --scene-name ${scene_name} \ 26 | --robot-init-x 0.65 0.65 1 --robot-init-y -0.2 0.2 3 \ 27 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0.0 0.0 1 \ 28 | --obj-init-x-range -0.08 -0.02 3 --obj-init-y-range -0.02 0.08 3 \ 29 | ${EXTRA_ARGS} 30 | } 31 | 32 | 33 | for ckpt_path in "${ckpt_paths[@]}"; do 34 | for env_name in "${env_names[@]}"; do 35 | EvalSim 36 | done 37 | done 38 | 39 | 40 | # backgrounds 41 | 42 | declare -a scene_names=( 43 | "modern_bedroom_no_roof" 44 | "modern_office_no_roof" 45 | ) 46 | 47 | for scene_name in "${scene_names[@]}"; do 48 | for ckpt_path in "${ckpt_paths[@]}"; do 49 | for env_name in "${env_names[@]}"; do 50 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt model_ids=apple" 51 | EvalSim 52 | done 53 | done 54 | done 55 | 56 | 57 | # lightings 58 | scene_name=frl_apartment_stage_simple 59 | 60 | for ckpt_path in "${ckpt_paths[@]}"; do 61 | for env_name in "${env_names[@]}"; do 62 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt light_mode=brighter model_ids=apple" 63 | EvalSim 64 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt light_mode=darker model_ids=apple" 65 | EvalSim 66 | done 67 | done 68 | 69 | 70 | # new cabinets 71 | scene_name=frl_apartment_stage_simple 72 | 73 | for ckpt_path in "${ckpt_paths[@]}"; do 74 | for env_name in "${env_names[@]}"; do 75 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt station_name=mk_station2 model_ids=apple" 76 | EvalSim 77 | EXTRA_ARGS="--additional-env-build-kwargs shader_dir=rt station_name=mk_station3 model_ids=apple" 78 | EvalSim 79 | done 80 | done 81 | -------------------------------------------------------------------------------- /sim_cogact/scripts/cogact_put_in_drawer_visual_matching.sh: -------------------------------------------------------------------------------- 1 | # shader_dir=rt means 
that we turn on ray-tracing rendering; this is quite crucial for the open / close drawer task as policies often rely on shadows to infer depth 2 | gpu_id=0 3 | declare -a ckpt_paths=( 4 | "CogACT/CogACT-Base" 5 | ) 6 | 7 | 8 | declare -a env_names=( 9 | PlaceIntoClosedTopDrawerCustomInScene-v0 10 | # PlaceIntoClosedMiddleDrawerCustomInScene-v0 11 | # PlaceIntoClosedBottomDrawerCustomInScene-v0 12 | ) 13 | 14 | 15 | # URDF variations 16 | declare -a urdf_version_arr=("recolor_cabinet_visual_matching_1" "recolor_tabletop_visual_matching_1" "recolor_tabletop_visual_matching_2" None) 17 | 18 | for urdf_version in "${urdf_version_arr[@]}"; do 19 | 20 | EXTRA_ARGS="--enable-raytracing --additional-env-build-kwargs station_name=mk_station_recolor light_mode=simple disable_bad_material=True urdf_version=${urdf_version} model_ids=baked_apple_v2" 21 | 22 | 23 | EvalOverlay() { 24 | # A0 25 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 26 | --robot google_robot_static \ 27 | --control-freq 3 --sim-freq 513 --max-episode-steps 200 \ 28 | --env-name ${env_name} --scene-name dummy_drawer \ 29 | --robot-init-x 0.644 0.644 1 --robot-init-y -0.179 -0.179 1 \ 30 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.03 -0.03 1 \ 31 | --obj-init-x-range -0.08 -0.02 3 --obj-init-y-range -0.02 0.08 3 \ 32 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_a0.png \ 33 | ${EXTRA_ARGS} 34 | 35 | # B0 36 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 37 | --robot google_robot_static \ 38 | --control-freq 3 --sim-freq 513 --max-episode-steps 200 \ 39 | --env-name ${env_name} --scene-name dummy_drawer \ 40 | --robot-init-x 0.652 0.652 1 --robot-init-y 0.009 0.009 1 \ 41 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 42 | --obj-init-x-range -0.08 -0.02 3 --obj-init-y-range -0.02 0.08 3 \ 43 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_b0.png \ 44 | ${EXTRA_ARGS} 45 | 46 | # C0 47 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model cogact --ckpt-path ${ckpt_path} \ 48 | --robot google_robot_static \ 49 | --control-freq 3 --sim-freq 513 --max-episode-steps 200 \ 50 | --env-name ${env_name} --scene-name dummy_drawer \ 51 | --robot-init-x 0.665 0.665 1 --robot-init-y 0.224 0.224 1 \ 52 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 53 | --obj-init-x-range -0.08 -0.02 3 --obj-init-y-range -0.02 0.08 3 \ 54 | --rgb-overlay-path ./ManiSkill2_real2sim/data/real_inpainting/open_drawer_c0.png \ 55 | ${EXTRA_ARGS} 56 | } 57 | 58 | 59 | for ckpt_path in "${ckpt_paths[@]}"; do 60 | for env_name in "${env_names[@]}"; do 61 | EvalOverlay 62 | done 63 | done 64 | 65 | 66 | 67 | done 68 | -------------------------------------------------------------------------------- /sim_cogact/scripts/octo_bridge.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 2 | declare -a policy_models=( 3 | "octo-small" 4 | "octo-base" 5 | # "octo-server" 6 | ) 7 | 8 | ckpt_path=None 9 | 10 | scene_name=bridge_table_1_v1 11 | robot=widowx 12 | rgb_overlay_path=ManiSkill2_real2sim/data/real_inpainting/bridge_real_eval_1.png 13 | robot_init_x=0.147 14 | robot_init_y=0.028 15 | 16 | for init_rng in 0 2 4 6 8; 17 | 18 | do for policy_model in "${policy_models[@]}"; 19 | 20 | do 
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 21 | --robot ${robot} --policy-setup widowx_bridge --octo-init-rng ${init_rng} \ 22 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 23 | --env-name StackGreenCubeOnYellowCubeBakedTexInScene-v0 --scene-name ${scene_name} \ 24 | --rgb-overlay-path ${rgb_overlay_path} \ 25 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 26 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 27 | --additional-env-save-tags octo_init_rng_${init_rng}; 28 | 29 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 30 | --robot ${robot} --policy-setup widowx_bridge --octo-init-rng ${init_rng} \ 31 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 32 | --env-name PutCarrotOnPlateInScene-v0 --scene-name ${scene_name} \ 33 | --rgb-overlay-path ${rgb_overlay_path} \ 34 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 35 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 36 | --additional-env-save-tags octo_init_rng_${init_rng}; 37 | 38 | CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 39 | --robot ${robot} --policy-setup widowx_bridge --octo-init-rng ${init_rng} \ 40 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 41 | --env-name PutSpoonOnTableClothInScene-v0 --scene-name ${scene_name} \ 42 | --rgb-overlay-path ${rgb_overlay_path} \ 43 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 44 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 45 | --additional-env-save-tags octo_init_rng_${init_rng}; 46 | 47 | done 48 | 49 | done 50 | 51 | 52 | 53 | 54 | 55 | scene_name=bridge_table_1_v2 56 | robot=widowx_sink_camera_setup 57 | rgb_overlay_path=ManiSkill2_real2sim/data/real_inpainting/bridge_sink.png 58 | robot_init_x=0.127 59 | robot_init_y=0.06 60 | 61 | for init_rng in 0 2 4 6 8; 62 | 63 | do for policy_model in "${policy_models[@]}"; 64 | 65 | do CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \ 66 | --robot ${robot} --policy-setup widowx_bridge --octo-init-rng ${init_rng} \ 67 | --control-freq 5 --sim-freq 500 --max-episode-steps 120 \ 68 | --env-name PutEggplantInBasketScene-v0 --scene-name ${scene_name} \ 69 | --rgb-overlay-path ${rgb_overlay_path} \ 70 | --robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \ 71 | --robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 \ 72 | --additional-env-save-tags octo_init_rng_${init_rng}; 73 | 74 | done 75 | 76 | done 77 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_train_strategy 2 | from .metrics import Metrics, VLAMetrics 3 | 
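# Usage sketch for the two exports above. All argument values (batch sizes, learning rate,
# scheduler name, the `vla` model handle) are illustrative placeholders rather than repository
# defaults; the real orchestration lives in the training entry point.
#
#   from training import get_train_strategy
#
#   strategy = get_train_strategy(
#       train_strategy="fsdp-full-shard",   # or "fsdp-shard-grad-op"
#       vlm=vla,                            # a CogACT / PrismaticVLM instance
#       device_id=device_id,
#       stage="vla-full-train",
#       epochs=1,
#       max_steps=None,
#       global_batch_size=256,
#       per_device_batch_size=16,
#       learning_rate=2e-5,
#       weight_decay=0.0,
#       max_grad_norm=1.0,
#       lr_scheduler_type="constant",       # placeholder; see `run_setup` for supported schedulers
#       warmup_ratio=0.0,
#   )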
-------------------------------------------------------------------------------- /training/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, 5 | and strategy configurations. 6 | """ 7 | 8 | from typing import Callable, Optional, Union 9 | 10 | import torch 11 | 12 | from prismatic.models.vlms import PrismaticVLM 13 | from vla import CogACT 14 | from training.strategies import FSDPStrategy, TrainingStrategy 15 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 16 | TRAIN_STRATEGIES = { 17 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, 18 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, 19 | } 20 | 21 | 22 | def get_train_strategy( 23 | train_strategy: str, 24 | vlm: Union[PrismaticVLM, CogACT], 25 | device_id: int, 26 | stage: str, 27 | epochs: int, 28 | max_steps: Optional[int], 29 | global_batch_size: int, 30 | per_device_batch_size: int, 31 | learning_rate: float, 32 | weight_decay: float, 33 | max_grad_norm: float, 34 | lr_scheduler_type: str, 35 | warmup_ratio: float, 36 | enable_gradient_checkpointing: bool = True, 37 | enable_mixed_precision_training: bool = True, 38 | reduce_in_full_precision: bool = False, 39 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 40 | worker_init_fn: Optional[Callable[[int], None]] = None, 41 | ) -> TrainingStrategy: 42 | if train_strategy in TRAIN_STRATEGIES: 43 | strategy_cfg = TRAIN_STRATEGIES[train_strategy] 44 | strategy = strategy_cfg["cls"]( 45 | vlm=vlm, 46 | device_id=device_id, 47 | stage=stage, 48 | epochs=epochs, 49 | max_steps=max_steps, 50 | global_batch_size=global_batch_size, 51 | per_device_batch_size=per_device_batch_size, 52 | learning_rate=learning_rate, 53 | weight_decay=weight_decay, 54 | max_grad_norm=max_grad_norm, 55 | lr_scheduler_type=lr_scheduler_type, 56 | warmup_ratio=warmup_ratio, 57 | enable_gradient_checkpointing=enable_gradient_checkpointing, 58 | enable_mixed_precision_training=enable_mixed_precision_training, 59 | reduce_in_full_precision=reduce_in_full_precision, 60 | mixed_precision_dtype=mixed_precision_dtype, 61 | worker_init_fn=worker_init_fn, 62 | **strategy_cfg["kwargs"], 63 | ) 64 | return strategy 65 | else: 66 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") 67 | -------------------------------------------------------------------------------- /training/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | metrics.py 3 | 4 | Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various 5 | endpoints (e.g., JSONL local logs, Weights & Biases). 6 | """ 7 | 8 | import time 9 | from collections import defaultdict, deque 10 | from pathlib import Path 11 | from typing import Any, Dict, Optional, Protocol, Tuple, Union 12 | 13 | import jsonlines 14 | import numpy as np 15 | import torch 16 | import wandb 17 | 18 | from prismatic.overwatch import initialize_overwatch 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | # === Define Tracker Interface === 25 | class Tracker(Protocol): 26 | def write_hyperparameters(self) -> None: ... 
27 | 28 | def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: ... 29 | 30 | def finalize(self) -> None: ... 31 | 32 | 33 | # === Individual Tracker Definitions === 34 | class JSONLinesTracker: 35 | def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None: 36 | self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams 37 | 38 | @overwatch.rank_zero_only 39 | def write_hyperparameters(self) -> None: 40 | with jsonlines.open(self.run_dir / "run-metrics.jsonl", mode="w", sort_keys=True) as js_tracker: 41 | js_tracker.write({"run_id": self.run_id, "hparams": self.hparams}) 42 | 43 | @overwatch.rank_zero_only 44 | def write(self, _: int, metrics: Dict[str, Union[int, float]]) -> None: 45 | with jsonlines.open(self.run_dir / f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_tracker: 46 | js_tracker.write(metrics) 47 | 48 | def finalize(self) -> None: 49 | return 50 | 51 | 52 | class WeightsBiasesTracker: 53 | def __init__( 54 | self, 55 | run_id: str, 56 | run_dir: Path, 57 | hparams: Dict[str, Any], 58 | project: str = "prismatic", 59 | entity: Optional[str] = None, 60 | group: str = "align", 61 | ) -> None: 62 | self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams 63 | 64 | # Get W&B-Specific Initialization Parameters 65 | self.project, self.entity, self.group, self.wandb_dir = project, entity, group, self.run_dir 66 | 67 | # Call W&B.init() 68 | self.initialize() 69 | 70 | @overwatch.rank_zero_only 71 | def initialize(self) -> None: 72 | wandb.init( 73 | name=self.run_id, 74 | dir=self.wandb_dir, 75 | config=self.hparams, 76 | project=self.project, 77 | entity=self.entity, 78 | group=self.group, 79 | ) 80 | 81 | @overwatch.rank_zero_only 82 | def write_hyperparameters(self) -> None: 83 | wandb.config = self.hparams 84 | 85 | @overwatch.rank_zero_only 86 | def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: 87 | wandb.log(metrics, step=global_step) 88 | 89 | @staticmethod 90 | def finalize() -> None: 91 | if overwatch.is_rank_zero(): 92 | wandb.finish() 93 | 94 | # A job gets 210 seconds to get its affairs in order 95 | time.sleep(210) 96 | 97 | 98 | # === Core Metrics Container :: Initializes Trackers => Compiles/Pushes Metrics === 99 | 100 | 101 | class Metrics: 102 | def __init__( 103 | self, 104 | active_trackers: Tuple[str, ...], 105 | run_id: str, 106 | run_dir: Path, 107 | hparams: Dict[str, Any], 108 | stage: str, 109 | wandb_project: str = "prismatic", 110 | wandb_entity: Optional[str] = None, 111 | grad_accumulation_steps: int = 1, 112 | window_size: int = 128, 113 | ) -> None: 114 | self.run_id, self.run_dir, self.hparams, self.stage = run_id, run_dir, hparams, stage 115 | 116 | # Initialize Trackers 117 | self.trackers = [] 118 | for tracker_type in active_trackers: 119 | if tracker_type == "jsonl": 120 | tracker = JSONLinesTracker(run_id, run_dir, hparams) 121 | elif tracker_type == "wandb": 122 | tracker = WeightsBiasesTracker( 123 | run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group=self.stage 124 | ) 125 | else: 126 | raise ValueError(f"Tracker with type `{tracker_type} is not supported!") 127 | 128 | # Add Hyperparameters --> add to `self.trackers` 129 | tracker.write_hyperparameters() 130 | self.trackers.append(tracker) 131 | 132 | # Create Universal Metrics Buffers 133 | self.global_step, self.start_time, self.step_start_time = 0, time.time(), time.time() 134 | self.state = { 135 | "loss_raw": deque(maxlen=grad_accumulation_steps), 136 | 
"loss": deque(maxlen=window_size), 137 | "step_time": deque(maxlen=window_size), 138 | "lr": [], 139 | } 140 | 141 | def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: 142 | for tracker in self.trackers: 143 | tracker.write(global_step, metrics) 144 | 145 | def get_status(self, loss: Optional[torch.Tensor] = None) -> str: 146 | lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 147 | if loss is None: 148 | return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f}" 149 | 150 | # Otherwise, embed `loss` in status report! 151 | return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}" 152 | 153 | def commit( 154 | self, *, global_step: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs 155 | ) -> None: 156 | """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" 157 | if global_step is not None: 158 | self.global_step = global_step 159 | 160 | # For all other variables --> only track on rank zero! 161 | if not overwatch.is_rank_zero(): 162 | return 163 | 164 | # Special Positional Arguments 165 | if lr is not None: 166 | self.state["lr"].append(lr) 167 | 168 | if update_step_time: 169 | self.state["step_time"].append(time.time() - self.step_start_time) 170 | self.step_start_time = time.time() 171 | 172 | # Generic Keyword Arguments 173 | for key, value in kwargs.items(): 174 | if key == "loss": 175 | loss_val = value.detach() 176 | self.state["loss_raw"].append(loss_val) 177 | self.state["loss"].append(loss_val) 178 | else: 179 | self.state[key].append(value.detach()) 180 | 181 | @overwatch.rank_zero_only 182 | def push(self) -> str: 183 | # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! 
184 | loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() 185 | loss = torch.stack(list(self.state["loss"])).mean().item() 186 | step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] 187 | status = self.get_status(loss) 188 | 189 | # Fire to Trackers 190 | prefix = self.stage.capitalize() 191 | self.log( 192 | self.global_step, 193 | metrics={ 194 | f"{prefix}/Step": self.global_step, 195 | f"{prefix}/Loss": loss, 196 | f"{prefix}/Loss (Raw)": loss_raw, 197 | f"{prefix}/Learning Rate": lr, 198 | f"{prefix}/Step Time": step_time, 199 | }, 200 | ) 201 | return status 202 | 203 | def finalize(self) -> str: 204 | for tracker in self.trackers: 205 | tracker.finalize() 206 | 207 | 208 | class VLAMetrics: 209 | def __init__( 210 | self, 211 | active_trackers: Tuple[str, ...], 212 | run_id: str, 213 | run_dir: Path, 214 | hparams: Dict[str, Any], 215 | wandb_project: str = " ", 216 | wandb_entity: Optional[str] = "", 217 | grad_accumulation_steps: int = 1, 218 | window_size: int = 1, 219 | resume_step: Optional[int] = None, 220 | resume_epoch: Optional[int] = None, 221 | ) -> None: 222 | self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams 223 | 224 | # Initialize Trackers 225 | self.trackers = [] 226 | for tracker_type in active_trackers: 227 | if tracker_type == "jsonl": 228 | tracker = JSONLinesTracker(run_id, run_dir, hparams) 229 | elif tracker_type == "wandb": 230 | tracker = WeightsBiasesTracker( 231 | run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group="vla-train" 232 | ) 233 | else: 234 | raise ValueError(f"Tracker with type `{tracker_type} is not supported!") 235 | 236 | # Add Hyperparameters --> add to `self.trackers` 237 | tracker.write_hyperparameters() 238 | self.trackers.append(tracker) 239 | 240 | # Create Universal Metrics Buffers 241 | self.global_step = 0 if resume_step is None else resume_step 242 | self.epoch = 0 if resume_epoch is None else resume_epoch 243 | self.start_time, self.step_start_time = time.time(), time.time() 244 | self.state = { 245 | "loss_raw": deque(maxlen=grad_accumulation_steps), 246 | "loss": deque(maxlen=window_size), 247 | "step_time": deque(maxlen=window_size), 248 | "lr": [], 249 | } 250 | 251 | # Created metrics buffers for individual tracked datasets 252 | self.dataset_trackers = defaultdict(lambda: VLAMetrics([], "", "", {})) 253 | 254 | def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: 255 | for tracker in self.trackers: 256 | tracker.write(global_step, metrics) 257 | 258 | def get_status(self, loss: Optional[torch.Tensor] = None) -> str: 259 | lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 260 | if loss is None: 261 | return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}" 262 | 263 | # Otherwise, embed `loss` in status report! 264 | return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} - Loss :: {loss:.4f}" 265 | 266 | def commit( 267 | self, 268 | *, 269 | global_step: Optional[int] = None, 270 | epoch: Optional[int] = None, 271 | lr: Optional[float] = None, 272 | update_step_time: bool = False, 273 | **kwargs, 274 | ) -> None: 275 | """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" 276 | if global_step is not None: 277 | self.global_step = global_step 278 | 279 | if epoch is not None: 280 | self.epoch = epoch 281 | 282 | # For all other variables --> only track on rank zero! 
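        # (`global_step` / `epoch` above are updated on every rank, not just rank zero, so that
        # all ranks agree when checking max-step termination and save intervals in the training
        # loop; only the loss / lr / step-time buffers below are tracked on rank zero alone.)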
283 | if not overwatch.is_rank_zero(): 284 | return 285 | 286 | # Special Positional Arguments 287 | if lr is not None: 288 | self.state["lr"].append(lr) 289 | 290 | if update_step_time: 291 | self.state["step_time"].append(time.time() - self.step_start_time) 292 | self.step_start_time = time.time() 293 | 294 | # Generic Keyword Arguments 295 | for key, value in kwargs.items(): 296 | if key == "loss": 297 | loss_val = value.detach() 298 | self.state["loss_raw"].append(loss_val) 299 | self.state["loss"].append(loss_val) 300 | else: 301 | self.state[key].append(value.detach()) 302 | 303 | def commit_for_dataset(self, dataset_name: str, **kwargs) -> None: 304 | self.dataset_trackers[dataset_name].commit(**kwargs) 305 | 306 | @overwatch.rank_zero_only 307 | def push(self) -> str: 308 | # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! 309 | loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() 310 | loss = torch.stack(list(self.state["loss"])).mean().item() 311 | step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] 312 | status = self.get_status(loss) 313 | 314 | 315 | # Fire to Trackers 316 | prefix = "VLA Train" 317 | self.log( 318 | self.global_step, 319 | metrics={ 320 | f"{prefix}/Step": self.global_step, 321 | f"{prefix}/Epoch": self.epoch, 322 | f"{prefix}/Loss": loss, 323 | f"{prefix}/Loss (Raw)": loss_raw, 324 | f"{prefix}/Learning Rate": lr, 325 | f"{prefix}/Step Time": step_time, 326 | }, 327 | ) 328 | return status 329 | 330 | def finalize(self) -> str: 331 | for tracker in self.trackers: 332 | tracker.finalize() 333 | -------------------------------------------------------------------------------- /training/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_strategy_cogact import TrainingStrategy 2 | from .fsdp import FSDPStrategy 3 | -------------------------------------------------------------------------------- /training/strategies/base_strategy_cogact.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_strategy_cogact.py 3 | 4 | Abstract class definition of a (distributed) training strategy, with full annotations of class methods, utility 5 | functions, and initialization logic. 6 | 7 | Training Strategies (DDP, FSDP-Grad, FSDP-Full) tend to have a lot of repeated components; this class does a lot of 8 | heavy lifting. 
9 | """ 10 | import torch 11 | import torchvision.transforms.functional as TF 12 | import torch.distributed as dist 13 | import numpy as np 14 | 15 | from abc import ABC, abstractmethod 16 | from pathlib import Path 17 | from typing import Callable, Optional, Union 18 | from tqdm import tqdm 19 | from transformers.modeling_outputs import CausalLMOutputWithPast 20 | from collections import OrderedDict 21 | from PIL import Image 22 | from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset 23 | 24 | from prismatic.models.vlms import PrismaticVLM 25 | from prismatic.overwatch import initialize_overwatch 26 | from prismatic.training.metrics import Metrics, VLAMetrics 27 | from prismatic.util import check_bloat16_supported 28 | from prismatic.util.batching_utils import SplitModalitySampler 29 | from prismatic.util.data_utils import PaddedCollatorForActionPrediction, PaddedCollatorForLanguageModeling 30 | 31 | from vla import CogACT 32 | 33 | @torch.no_grad() 34 | def update_ema(ema_model, model, decay=0.9999): 35 | """ 36 | Step the EMA model towards the current model. 37 | """ 38 | ema_params = OrderedDict(ema_model.named_parameters()) 39 | model_params = OrderedDict(model.named_parameters()) 40 | for name, param in model_params.items(): 41 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 42 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 43 | 44 | # Initialize Overwatch =>> Wraps `logging.Logger` 45 | overwatch = initialize_overwatch(__name__) 46 | 47 | 48 | # === Abstract Base Class for an arbitrary Training Strategy === 49 | class TrainingStrategy(ABC): 50 | def __init__( 51 | self, 52 | vlm: Union[PrismaticVLM, CogACT], 53 | device_id: int, 54 | stage: str, 55 | epochs: int, 56 | max_steps: Optional[int], 57 | global_batch_size: int, 58 | per_device_batch_size: int, 59 | learning_rate: float, 60 | weight_decay: float, 61 | max_grad_norm: float, 62 | lr_scheduler_type: str, 63 | warmup_ratio: float, 64 | enable_gradient_checkpointing: bool = True, 65 | enable_mixed_precision_training: bool = True, 66 | reduce_in_full_precision: bool = False, 67 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 68 | worker_init_fn: Optional[Callable[[int], None]] = None, 69 | repeated_diffusion_steps: int = 4, 70 | **_: str, 71 | ) -> None: 72 | self.vlm, self.device_id, self.stage = vlm, device_id, stage 73 | 74 | # Get relevant VLM instance parameters before they get (potentially) wrapped 75 | self.all_module_keys, self.trainable_module_keys = self.vlm.all_module_keys, self.vlm.trainable_module_keys 76 | self.llm_transformer_layer_cls = self.vlm.llm_backbone.transformer_layer_cls 77 | 78 | # Optimization Parameters 79 | self.epochs, self.max_steps = epochs, max_steps 80 | self.global_batch_size, self.per_device_batch_size = global_batch_size, per_device_batch_size 81 | 82 | self.learning_rate, self.weight_decay, self.max_grad_norm = learning_rate, weight_decay, max_grad_norm 83 | self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio 84 | 85 | # Generic Strategy Parameters 86 | self.enable_gradient_checkpointing = enable_gradient_checkpointing 87 | self.enable_mixed_precision_training = enable_mixed_precision_training 88 | self.reduce_in_full_precision = reduce_in_full_precision 89 | self.mixed_precision_dtype = mixed_precision_dtype 90 | self.repeated_diffusion_steps = repeated_diffusion_steps 91 | 92 | # DataLoader Parameters 93 | self.worker_init_fn = worker_init_fn 94 | 95 | 
# Optimizers & Scheduler (initialized in `run_setup`) 96 | self.optimizer, self.lr_scheduler = None, None 97 | 98 | # Lightweight Validation 99 | assert ( 100 | self.global_batch_size % self.per_device_batch_size == 0 101 | ), "Per-device batch size must evenly divide global batch size!" 102 | self.grad_accumulation_steps = self.global_batch_size // self.per_device_batch_size // overwatch.world_size() 103 | if self.enable_mixed_precision_training: 104 | assert self.mixed_precision_dtype == torch.bfloat16, "Only BF16 mixed precision training is supported!" 105 | assert check_bloat16_supported(), "BFloat16 is not supported on this hardware; unset `mixed_precision`" 106 | 107 | @abstractmethod 108 | def save_checkpoint( 109 | self, 110 | run_dir: Path, 111 | global_step: int, 112 | epoch: int, 113 | train_loss: Optional[float] = None, 114 | only_trainable: bool = True, 115 | ) -> None: ... 116 | 117 | @abstractmethod 118 | def load_optimizer_and_scheduler(self, checkpoint_path: str) -> None: ... 119 | 120 | @abstractmethod 121 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None: ... 122 | 123 | @abstractmethod 124 | def clip_grad_norm(self) -> None: ... 125 | 126 | def run_training( 127 | self, 128 | dataset: Dataset, 129 | collator: PaddedCollatorForLanguageModeling, 130 | metrics: Metrics, 131 | stage: str = "finetune", 132 | batch_construction_strategy: str = "split-modality", 133 | seed: int = 7, 134 | ) -> None: 135 | """Run the training loop for the given `dataset` and `collator`; log losses, results to `metrics`""" 136 | if "finetune" in stage and batch_construction_strategy == "split-modality": 137 | # Instantiate the split-modality sampler; if you want to extend with other batch construction schemes, 138 | # (e.g., grouping by length) =>> can easily add them here! 139 | modality_lengths = dataset.get_modality_lengths() 140 | sampler = SplitModalitySampler( 141 | dataset, 142 | modality_lengths, 143 | global_batch_size=self.global_batch_size, 144 | num_replicas=overwatch.world_size(), 145 | rank=overwatch.rank(), 146 | seed=seed, 147 | drop_last=False, 148 | ) 149 | 150 | else: 151 | sampler = DistributedSampler( 152 | dataset, 153 | num_replicas=overwatch.world_size(), 154 | rank=overwatch.rank(), 155 | shuffle=True, 156 | seed=seed, 157 | drop_last=False, 158 | ) 159 | 160 | # Create a DataLoader with the initialized sampler, per-device-bsz, and collator 161 | dataloader = DataLoader( 162 | dataset, 163 | batch_size=self.per_device_batch_size, 164 | sampler=sampler, 165 | collate_fn=collator, 166 | num_workers=2, 167 | worker_init_fn=self.worker_init_fn, 168 | ) 169 | 170 | # Max Steps vs. 
Epochs Computation 171 | steps_per_epoch = len(dataloader) // self.grad_accumulation_steps 172 | if self.max_steps is not None and steps_per_epoch < self.max_steps: 173 | # Just set `epochs` to some large number --> we'll short-circuit based on steps anyway 174 | self.epochs = 100 175 | 176 | # === Train === 177 | status = metrics.get_status() 178 | with tqdm( 179 | total=( 180 | (self.epochs * (len(dataloader) // self.grad_accumulation_steps)) 181 | if self.max_steps is None 182 | else self.max_steps 183 | ), 184 | desc=status, 185 | leave=False, 186 | disable=not overwatch.is_rank_zero(), 187 | ) as progress: 188 | for epoch in range(self.epochs): 189 | self.vlm.train() 190 | sampler.set_epoch(epoch) 191 | 192 | # Zero-Gradients (just in case) 193 | self.optimizer.zero_grad() 194 | 195 | # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call 196 | # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! 197 | for train_idx, batch in enumerate(dataloader): 198 | # [Contract] self.vlm.forward() must automatically compute `loss` and return! 199 | with torch.autocast( 200 | "cuda", 201 | dtype=self.mixed_precision_dtype, 202 | enabled=self.enable_mixed_precision_training, 203 | ): 204 | loss, output = self.vlm( 205 | input_ids=batch["input_ids"], 206 | attention_mask=batch["attention_mask"], 207 | pixel_values=batch["pixel_values"], 208 | labels=batch["labels"], 209 | multimodal_indices=batch["multimodal_indices"], 210 | repeated_diffusion_steps = self.repeated_diffusion_steps 211 | ) 212 | 213 | # Commit Loss (Prior to Gradient Accumulation Normalization) 214 | metrics.commit(loss=loss) 215 | normalized_loss = loss / self.grad_accumulation_steps 216 | normalized_loss.backward() 217 | 218 | # Step =>> Only if Done w/ Gradient Accumulation 219 | if (train_idx + 1) % self.grad_accumulation_steps == 0: 220 | metrics.commit(update_step_time=True) 221 | 222 | # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality-assumptions 223 | self.clip_grad_norm() 224 | 225 | # Optimizer & LR Scheduler Step 226 | self.optimizer.step() 227 | self.lr_scheduler.step() 228 | self.optimizer.zero_grad() 229 | 230 | # Push Metrics 231 | metrics.commit(global_step=metrics.global_step + 1, lr=self.lr_scheduler.get_last_lr()[0]) 232 | status = metrics.push() 233 | 234 | # Check for Termination & Save Final Checkpoint (in case `max_steps` is not None) 235 | if self.max_steps is not None and metrics.global_step >= self.max_steps: 236 | self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item()) 237 | dist.barrier() 238 | 239 | return 240 | 241 | # Update Progress Bar 242 | progress.update() 243 | progress.set_description(status) 244 | 245 | # Save checkpoint at end each epoch (if `self.max_steps` is None) 246 | if self.max_steps is None: 247 | self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item()) 248 | dist.barrier() 249 | 250 | # === VLA Training === 251 | 252 | def run_vla_training( 253 | self, 254 | vla_dataset: IterableDataset, 255 | collator: PaddedCollatorForActionPrediction, 256 | metrics: VLAMetrics, 257 | save_interval: int = 2500, 258 | save_full_model: bool = True, 259 | action_model: bool = True, 260 | ) -> None: 261 | """Run the VLA training loop for the given `dataset` and `collator`; log losses, action metrics to `metrics`.""" 262 | assert isinstance(vla_dataset, IterableDataset), "VLA training expects an IterableDataset!" 
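        # Gradient accumulation here reuses `grad_accumulation_steps` computed in `__init__` as
        # global_batch_size // per_device_batch_size // world_size. As an illustration only (not
        # repository defaults): global_batch_size=256 with per_device_batch_size=16 on 8 GPUs
        # gives 256 // 16 // 8 = 2 micro-batches per optimizer step, and each micro-batch loss
        # below is divided by that factor before calling backward().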
263 | #assert self.grad_accumulation_steps == 1, "VLA training does not support gradient accumulation!" 264 | 265 | # Create a DataLoader =>> Set `num_workers` to 0; RLDS loader handles parallelism! 266 | dataloader = DataLoader( 267 | vla_dataset, 268 | batch_size=self.per_device_batch_size, 269 | sampler=None, 270 | collate_fn=collator, 271 | num_workers=0, 272 | worker_init_fn=self.worker_init_fn, 273 | ) 274 | 275 | # === Train === 276 | status = metrics.get_status() 277 | with tqdm( 278 | total=(self.epochs * (len(dataloader) // self.grad_accumulation_steps)) if self.max_steps is None else self.max_steps, 279 | desc=status, 280 | leave=False, 281 | disable=not overwatch.is_rank_zero(), 282 | ) as progress: 283 | self.vlm.train() 284 | 285 | # Zero Gradients (just in case) 286 | if self.vlm.use_ema is not None and self.vlm.use_ema == True: 287 | self.vlm.ema_diffusion.eval() 288 | self.optimizer.zero_grad() 289 | 290 | # [Contract] DataLoader wraps RLDS Loader (`.as_numpy_iterator() =>> implicit `.repeat()`) 291 | # => This means looping over the DataLoader is basically "infinite" (so no outer loop over epochs). 292 | # Slightly breaks default PyTorch semantics, which is why we adaptively compute `epoch` below. 293 | for train_idx, batch in enumerate(dataloader): 294 | # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call 295 | # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! 296 | with torch.autocast( 297 | "cuda", dtype=self.mixed_precision_dtype, enabled=self.enable_mixed_precision_training 298 | ): 299 | if action_model: 300 | loss, output = self.vlm( 301 | input_ids=batch["input_ids"], 302 | attention_mask=batch["attention_mask"], 303 | actions=batch["actions"], 304 | pixel_values=batch["pixel_values"], 305 | action_masks=batch["action_masks"], 306 | labels=batch["labels"], 307 | output_hidden_states = True, 308 | ) 309 | else: 310 | # [Contract] self.vlm.forward() must automatically compute `loss` and return! 311 | output: CausalLMOutputWithPast = self.vlm( 312 | input_ids=batch["input_ids"], 313 | attention_mask=batch["attention_mask"], 314 | pixel_values=batch["pixel_values"], 315 | labels=batch["labels"], 316 | ) 317 | loss = output.loss 318 | 319 | # Commit Loss =>> Backward! 320 | metrics.commit(loss=loss) 321 | 322 | normalized_loss = loss / self.grad_accumulation_steps 323 | normalized_loss.backward() 324 | 325 | # === Gradient Step === 326 | # Step =>> Only if Done w/ Gradient Accumulation 327 | if (train_idx + 1) % self.grad_accumulation_steps == 0: 328 | # Clip Gradients --> this is custom, per-strategy because of DDP vs. 
FSDP locality-assumptions 329 | self.clip_grad_norm() 330 | 331 | # Optimizer & LR Scheduler Step 332 | self.optimizer.step() 333 | self.lr_scheduler.step() 334 | if self.vlm.use_ema is not None and self.vlm.use_ema == True: 335 | update_ema(self.vlm.ema_diffusion, self.vlm.action_model) 336 | self.optimizer.zero_grad() 337 | # Compute epoch value using number of completed gradient steps 338 | epoch = (metrics.global_step + 1) // (len(vla_dataset) // self.global_batch_size) 339 | 340 | # Push Metrics 341 | metrics.commit(update_step_time=True, global_step=metrics.global_step + 1, epoch=epoch, lr=self.lr_scheduler.get_last_lr()[0]) 342 | status = metrics.push() 343 | 344 | # Check for Save Interval or Max Steps & Save Checkpoint 345 | if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or ( 346 | (metrics.global_step % save_interval) == 0 347 | ): 348 | self.save_checkpoint( 349 | metrics.run_dir, metrics.global_step, epoch, loss.item(), only_trainable=not save_full_model 350 | ) 351 | dist.barrier() 352 | 353 | if terminate: 354 | return 355 | 356 | # Update Progress Bar 357 | progress.update() 358 | progress.set_description(status) -------------------------------------------------------------------------------- /training/strategies/fsdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | fsdp.py 3 | 4 | Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for 5 | fine-grained control over wrapping policies and mixed precision per component). 6 | """ 7 | 8 | import math 9 | from collections import OrderedDict 10 | from functools import partial 11 | from pathlib import Path 12 | from typing import Callable, Optional, Union 13 | 14 | import torch 15 | import torch.distributed as dist 16 | import torch.nn as nn 17 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 18 | CheckpointImpl, 19 | apply_activation_checkpointing, 20 | checkpoint_wrapper, 21 | ) 22 | from torch.distributed.fsdp import ( 23 | FullStateDictConfig, 24 | FullOptimStateDictConfig, 25 | MixedPrecision, 26 | ShardingStrategy, 27 | StateDictType, 28 | ) 29 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 30 | from torch.optim import AdamW 31 | from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup 32 | 33 | from prismatic.models.vlms import PrismaticVLM 34 | from prismatic.overwatch import initialize_overwatch 35 | 36 | from vla import CogACT 37 | from training.strategies.base_strategy_cogact import TrainingStrategy 38 | 39 | # Initialize Overwatch =>> Wraps `logging.Logger` 40 | overwatch = initialize_overwatch(__name__) 41 | 42 | 43 | class FSDPStrategy(TrainingStrategy): 44 | def __init__( 45 | self, 46 | vlm: Union[PrismaticVLM, CogACT], 47 | device_id: int, 48 | stage: str, 49 | epochs: int, 50 | max_steps: Optional[int], 51 | global_batch_size: int, 52 | per_device_batch_size: int, 53 | learning_rate: float, 54 | weight_decay: float, 55 | max_grad_norm: float, 56 | lr_scheduler_type: str, 57 | warmup_ratio: float, 58 | enable_gradient_checkpointing: bool = True, 59 | enable_mixed_precision_training: bool = True, 60 | reduce_in_full_precision: bool = False, 61 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 62 | worker_init_fn: Optional[Callable[[int], None]] = None, 63 | sharding_strategy: str = "shard-grad-op", 64 | state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT, 65 | ) -> 
None: 66 | super().__init__( 67 | vlm=vlm, 68 | device_id=device_id, 69 | stage=stage, 70 | epochs=epochs, 71 | max_steps=max_steps, 72 | global_batch_size=global_batch_size, 73 | per_device_batch_size=per_device_batch_size, 74 | learning_rate=learning_rate, 75 | weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm, 77 | lr_scheduler_type=lr_scheduler_type, 78 | warmup_ratio=warmup_ratio, 79 | enable_gradient_checkpointing=enable_gradient_checkpointing, 80 | enable_mixed_precision_training=enable_mixed_precision_training, 81 | reduce_in_full_precision=reduce_in_full_precision, 82 | mixed_precision_dtype=mixed_precision_dtype, 83 | worker_init_fn=worker_init_fn, 84 | ) 85 | 86 | # FSDP-Specific Parameters 87 | if sharding_strategy == "shard-grad-op": 88 | self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 89 | elif sharding_strategy == "full-shard": 90 | self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD 91 | else: 92 | raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!") 93 | 94 | assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!" 95 | self.fsdp_state_dict_type = state_dict_type 96 | self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 97 | self.fsdp_save_optimizer_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) 98 | 99 | def save_checkpoint( 100 | self, 101 | run_dir: Path, 102 | global_step: int, 103 | epoch: int, 104 | train_loss: Optional[float] = None, 105 | only_trainable: bool = True, 106 | ) -> None: 107 | """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" 108 | assert isinstance(self.vlm, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!" 
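        # The full (unsharded) state dict gathered below is split per top-level module key
        # (only trainable modules by default), so the saved checkpoint maps each module name to
        # its own sub-state-dict; e.g. a hypothetical parameter name "action_model.net.weight"
        # would be stored as model_state_dicts["action_model"]["net.weight"].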
109 | 110 | # Summon Full State Dictionary =>> Reconstitute from Shards 111 | with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, self.fsdp_save_policy, self.fsdp_save_optimizer_policy): 112 | full_vlm_state_dict = self.vlm.state_dict() 113 | model_state_dicts = { 114 | mkey: OrderedDict() for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) 115 | } 116 | 117 | # Iterate through `full_vlm_state_dict` and split `mkey.{full_dotted_path}` -> `mkey: {full_dotted_path}` 118 | for key, param in full_vlm_state_dict.items(): 119 | for mkey in model_state_dicts: 120 | if key.startswith(mprefix := f"{mkey}."): 121 | model_state_dicts[mkey][key.removeprefix(mprefix)] = param 122 | 123 | # Save on rank zero *only* 124 | if overwatch.is_rank_zero(): 125 | checkpoint_dir = run_dir / "checkpoints" 126 | if train_loss is None: 127 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" 128 | else: 129 | checkpoint_path = ( 130 | checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" 131 | ) 132 | 133 | for key in list(model_state_dicts.keys()): 134 | if key.startswith("vlm."): 135 | value = model_state_dicts.pop(key) 136 | model_state_dicts[key[4:]] = value 137 | 138 | # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` 139 | torch.save({"model": model_state_dicts}, checkpoint_path) 140 | dist.barrier() 141 | optim_state_dict = FSDP.optim_state_dict(self.vlm, self.optimizer) 142 | if overwatch.is_rank_zero(): 143 | optimizer_path = self._get_optimizer_path(checkpoint_path) 144 | optimizer_checkpoint = { 145 | "optimizer": optim_state_dict, 146 | "scheduler": { 147 | "epoch": epoch, 148 | "global_step": global_step, 149 | } 150 | } 151 | torch.save(optimizer_checkpoint, optimizer_path) 152 | overwatch.info(f"Saved optimizer state dict to {optimizer_path}") 153 | # TODO (siddk) :: This breaks w/ Sagemaker default permissions (root vs. )... skip? 154 | # shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") 155 | 156 | def _get_optimizer_path(self, checkpoint_path: Path) -> Path: 157 | """Get the path to the optimizer checkpoint file.""" 158 | return checkpoint_path.with_suffix(".optimizer") 159 | 160 | def load_optimizer_and_scheduler(self, checkpoint_path: str) -> None: 161 | """Load a checkpoint from the specified `checkpoint_path`.""" 162 | assert isinstance(self.vlm, FSDP), "FSDPStrategy.load_optimizer_and_scheduler assumes VLM is already wrapped in FSDP!" 163 | checkpoint_path = Path(checkpoint_path) 164 | optimizer_path = self._get_optimizer_path(checkpoint_path) 165 | if not optimizer_path.exists(): 166 | overwatch.warning(f"Optimizer checkpoint not found at {optimizer_path}!") 167 | return 168 | # Load Checkpoint =>> Note that FSDP will automatically handle device placement! 
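        # Note =>> `FSDP.optim_state_dict_to_load` below converts the consolidated optimizer state dict saved at
        #          checkpoint time into the form expected by this rank's (sharded) optimizer; it runs inside a
        #          `state_dict_type` context configured with `rank0_only=False` so that every rank receives the
        #          state it needs.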
169 | optim_state_dict = torch.load(optimizer_path, map_location="cpu") 170 | with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, FullStateDictConfig(offload_to_cpu=True, rank0_only=False), FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)): 171 | optim_state_dict = FSDP.optim_state_dict_to_load(self.vlm, self.optimizer, optim_state_dict["optimizer"]) 172 | self.optimizer.load_state_dict(optim_state_dict) 173 | overwatch.info(f"Loaded optimizer state dict from {optimizer_path}") 174 | 175 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None: 176 | # Iteratively Assemble FSDP Wrapping Policy by fetching the wrapping policies for each backbone/constituent 177 | vlm_fsdp_wrapping_policy = self.vlm.get_fsdp_wrapping_policy() 178 | 179 | # Assemble the Default FSDP Mixed Precision Policy 180 | if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16: 181 | # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only) 182 | # => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision 183 | reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32 184 | fsdp_precision_policy = MixedPrecision( 185 | param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype 186 | ) 187 | 188 | # When running FSDP with a frozen vision backbone --> move to half precision! 189 | if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}: 190 | overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`") 191 | self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype) 192 | 193 | else: 194 | # If we're not using mixed precision, everything is in default full precision! 195 | fsdp_precision_policy = MixedPrecision( 196 | param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32 197 | ) 198 | 199 | # => note that FSDP will automatically take care of device placement (similar to `autocast`) 200 | self.vlm = FSDP( 201 | self.vlm, 202 | auto_wrap_policy=vlm_fsdp_wrapping_policy, 203 | mixed_precision=fsdp_precision_policy, 204 | sharding_strategy=self.fsdp_sharding_strategy, 205 | device_id=torch.cuda.current_device(), 206 | limit_all_gathers=True, 207 | use_orig_params=True, 208 | ) 209 | # Gradient Checkpoint Setup 210 | if self.enable_gradient_checkpointing: 211 | # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the 212 | # bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we 213 | # cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics! 214 | # 215 | # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer. 216 | non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) 217 | 218 | def check_fn(submodule: nn.Module) -> bool: 219 | return isinstance(submodule, self.llm_transformer_layer_cls) 220 | 221 | # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous! 222 | apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) 223 | 224 | # Barrier =>> Sharding takes a minute? 
225 | dist.barrier() 226 | 227 | # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` 228 | # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 229 | n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size 230 | if self.max_steps is None: 231 | num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size 232 | else: 233 | num_training_steps = self.max_steps 234 | 235 | if self.lr_scheduler_type == "linear-warmup+cosine-decay": 236 | # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) 237 | num_warmup_steps = int(num_training_steps * self.warmup_ratio) 238 | 239 | # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay 240 | # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! 241 | decay, no_decay = [], [] 242 | for name, param in self.vlm.named_parameters(): 243 | if not param.requires_grad: 244 | continue 245 | 246 | # Check on any parameters with fewer than 2 dimensions or with "bias" in the name 247 | if param.ndim <= 1 or name.endswith(".bias"): 248 | no_decay.append(param) 249 | else: 250 | decay.append(param) 251 | 252 | # Build Parameter Groups 253 | groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] 254 | 255 | # Create Optimizer & LR Scheduler 256 | self.optimizer = AdamW(groups, lr=self.learning_rate) 257 | self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) 258 | for param_group in self.optimizer.param_groups: 259 | param_group["lr"] = 0.0 260 | 261 | elif self.lr_scheduler_type == "constant": 262 | num_warmup_steps = 0 263 | 264 | # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay 265 | # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! 266 | decay, no_decay = [], [] 267 | for name, param in self.vlm.named_parameters(): 268 | if not param.requires_grad: 269 | continue 270 | 271 | # Check on any parameters with fewer than 2 dimensions or with "bias" in the name 272 | if param.ndim <= 1 or name.endswith(".bias"): 273 | no_decay.append(param) 274 | else: 275 | decay.append(param) 276 | 277 | # Build Parameter Groups 278 | groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] 279 | 280 | # Create Optimizer & LR Scheduler 281 | self.optimizer = AdamW(groups, lr=self.learning_rate) 282 | self.lr_scheduler = get_constant_schedule(self.optimizer) 283 | 284 | else: 285 | raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") 286 | 287 | # Finalize Setup =>> Log! 
288 | overwatch.info( 289 | "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n" 290 | f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" 291 | f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" 292 | f" |-> Distributed World Size = {overwatch.world_size()}\n" 293 | f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" 294 | f" |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" 295 | f" |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n" 296 | f" |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n" 297 | f" |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n" 298 | f" |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n" 299 | f" |-> Default AdamW LR = {self.learning_rate}\n" 300 | f" |-> AdamW Weight Decay = {self.weight_decay}\n" 301 | f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" 302 | f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" 303 | f" |-> Dataset Size = {n_train_examples} Examples\n" 304 | f" |-> Max Steps = {num_training_steps}\n" 305 | ) 306 | 307 | def clip_grad_norm(self) -> None: 308 | # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype* 309 | self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm) 310 | -------------------------------------------------------------------------------- /vla/__init__.py: -------------------------------------------------------------------------------- 1 | from .cogactvla import CogACT 2 | from .load import available_model_names, available_models, get_model_description, load, load_vla -------------------------------------------------------------------------------- /vla/load.py: -------------------------------------------------------------------------------- 1 | """ 2 | load.py 3 | 4 | Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical 5 | IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub). 
6 | """ 7 | 8 | import json 9 | import os 10 | from pathlib import Path 11 | from typing import List, Optional, Union 12 | 13 | from huggingface_hub import HfFileSystem, hf_hub_download 14 | 15 | from prismatic.conf import ModelConfig 16 | from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform 17 | from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY 18 | from prismatic.models.vlms import PrismaticVLM 19 | from prismatic.overwatch import initialize_overwatch 20 | 21 | from vla import CogACT 22 | 23 | # Initialize Overwatch =>> Wraps `logging.Logger` 24 | overwatch = initialize_overwatch(__name__) 25 | 26 | 27 | # === HF Hub Repository === 28 | HF_HUB_REPO = "TRI-ML/prismatic-vlms" 29 | 30 | # === Available Models === 31 | def available_models() -> List[str]: 32 | return list(MODEL_REGISTRY.keys()) 33 | 34 | 35 | def available_model_names() -> List[str]: 36 | return list(GLOBAL_REGISTRY.items()) 37 | 38 | 39 | def get_model_description(model_id_or_name: str) -> str: 40 | if model_id_or_name not in GLOBAL_REGISTRY: 41 | raise ValueError(f"Couldn't find `{model_id_or_name = }; check `prismatic.available_model_names()`") 42 | 43 | # Print Description & Return 44 | print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2)) 45 | 46 | return description 47 | 48 | 49 | # === Load Pretrained Model === 50 | def load( 51 | model_id_or_path: Union[str, Path], 52 | hf_token: Optional[str] = None, 53 | cache_dir: Optional[Union[str, Path]] = None, 54 | load_for_training: bool = False, 55 | ) -> PrismaticVLM: 56 | """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub.""" 57 | if os.path.isdir(model_id_or_path): 58 | overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`") 59 | 60 | # Get paths for `config.json` and pretrained checkpoint 61 | config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" 62 | assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" 63 | assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" 64 | else: 65 | if model_id_or_path not in GLOBAL_REGISTRY: 66 | raise ValueError(f"Couldn't find `{model_id_or_path = }; check `prismatic.available_model_names()`") 67 | 68 | overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])} from HF Hub") 69 | with overwatch.local_zero_first(): 70 | config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir) 71 | checkpoint_pt = hf_hub_download( 72 | repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir 73 | ) 74 | 75 | # Load Model Config from `config.json` 76 | with open(config_json, "r") as f: 77 | model_cfg = json.load(f)["model"] 78 | 79 | # = Load Individual Components necessary for Instantiating a VLM = 80 | # =>> Print Minimal Config 81 | overwatch.info( 82 | f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n" 83 | f" Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n" 84 | f" LLM Backbone =>> [bold]{model_cfg['llm_backbone_id']}[/]\n" 85 | f" Arch Specifier =>> [bold]{model_cfg['arch_specifier']}[/]\n" 86 | f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" 87 | ) 88 | 89 | # Load Vision Backbone 90 | overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]") 91 | vision_backbone, image_transform = 
get_vision_backbone_and_transform( 92 | model_cfg["vision_backbone_id"], 93 | model_cfg["image_resize_strategy"], 94 | ) 95 | 96 | # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()` 97 | overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers") 98 | llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( 99 | model_cfg["llm_backbone_id"], 100 | llm_max_length=model_cfg.get("llm_max_length", 2048), 101 | hf_token=hf_token, 102 | inference_mode=not load_for_training, 103 | ) 104 | 105 | # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile) 106 | overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint") 107 | vlm = PrismaticVLM.from_pretrained( 108 | checkpoint_pt, 109 | model_cfg["model_id"], 110 | vision_backbone, 111 | llm_backbone, 112 | arch_specifier=model_cfg["arch_specifier"], 113 | freeze_weights=not load_for_training, 114 | ) 115 | 116 | return vlm 117 | 118 | # === Load Pretrained VLA Model === 119 | def load_vla( 120 | model_id_or_path: Union[str, Path], 121 | hf_token: Optional[str] = None, 122 | cache_dir: Optional[Union[str, Path]] = None, 123 | load_for_training: bool = False, 124 | model_type: str = "pretrained", 125 | **kwargs, 126 | ) -> CogACT: 127 | """Loads a pretrained CogACT from either local disk or the HuggingFace Hub.""" 128 | 129 | # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to 130 | # checkpoint `.pt` file, rather than the top-level run directory! 131 | if os.path.isfile(model_id_or_path): 132 | overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`") 133 | 134 | # [Validate] Checkpoint Path should look like `.../<RUN_DIR>/checkpoints/<CHECKPOINT_PATH>.pt` 135 | assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
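        # Expected run directory layout, inferred from the checks in this branch (placeholder names are illustrative):
        #   <RUN_DIR>/
        #   |-> config.json
        #   |-> dataset_statistics.json
        #   |-> checkpoints/
        #       |-> <CHECKPOINT_PATH>.pt   <- `model_id_or_path` must point at this file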
136 | run_dir = checkpoint_pt.parents[1] 137 | 138 | # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint 139 | config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json" 140 | assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" 141 | assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" 142 | 143 | # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`model_id_or_path`) 144 | else: 145 | # Search HF Hub Repo via fsspec API 146 | overwatch.info(f"Checking HF for `{(hf_path := str(Path(model_id_or_path)))}`") 147 | if not (tmpfs := HfFileSystem()).exists(hf_path): 148 | raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`") 149 | 150 | valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/*.pt") 151 | if (len(valid_ckpts) == 0) or (len(valid_ckpts) != 1): 152 | raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/") 153 | 154 | target_ckpt = Path(valid_ckpts[-1]).name 155 | model_id_or_path = str(model_id_or_path) # Convert to string for HF Hub API 156 | overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`") 157 | with overwatch.local_zero_first(): 158 | # relpath = Path(model_type) / model_id_or_path 159 | config_json = hf_hub_download( 160 | repo_id=model_id_or_path, filename=f"{('config.json')!s}", cache_dir=cache_dir 161 | ) 162 | dataset_statistics_json = hf_hub_download( 163 | repo_id=model_id_or_path, filename=f"{('dataset_statistics.json')!s}", cache_dir=cache_dir 164 | ) 165 | checkpoint_pt = hf_hub_download( 166 | repo_id=model_id_or_path, filename=f"{(Path('checkpoints') / target_ckpt)!s}", cache_dir=cache_dir 167 | ) 168 | 169 | # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json` 170 | with open(config_json, "r") as f: 171 | vla_cfg = json.load(f)["vla"] 172 | model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])() 173 | 174 | # Load Dataset Statistics for Action Denormalization 175 | with open(dataset_statistics_json, "r") as f: 176 | norm_stats = json.load(f) 177 | 178 | # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) = 179 | # =>> Print Minimal Config 180 | overwatch.info( 181 | f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n" 182 | f" Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n" 183 | f" LLM Backbone =>> [bold]{model_cfg.llm_backbone_id}[/]\n" 184 | f" Arch Specifier =>> [bold]{model_cfg.arch_specifier}[/]\n" 185 | f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" 186 | ) 187 | 188 | # Load Vision Backbone 189 | overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]") 190 | vision_backbone, image_transform = get_vision_backbone_and_transform( 191 | model_cfg.vision_backbone_id, 192 | model_cfg.image_resize_strategy, 193 | ) 194 | 195 | # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()` 196 | overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers") 197 | llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( 198 | model_cfg.llm_backbone_id, 199 | llm_max_length=model_cfg.llm_max_length, 200 | hf_token=hf_token, 201 | inference_mode=not load_for_training, 202 | ) 203 | 204 | # Load VLM using `from_pretrained` (clobbers HF syntax... 
eventually should reconcile) 205 | overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint") 206 | 207 | vla = CogACT.from_pretrained( 208 | checkpoint_pt, 209 | model_cfg.model_id, 210 | vision_backbone, 211 | llm_backbone, 212 | arch_specifier=model_cfg.arch_specifier, 213 | freeze_weights=not load_for_training, 214 | norm_stats=norm_stats, 215 | **kwargs, 216 | ) 217 | 218 | return vla 219 | --------------------------------------------------------------------------------
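Usage sketch (illustrative, not part of the repository): the snippet below shows how the `load_vla` entry point defined in `vla/load.py` might be called from user code. Only the signature shown above is relied upon; the checkpoint path, token handling, and device placement are assumptions for demonstration.

import torch

from vla import load_vla

# A local path must point at the checkpoint `.pt` itself, i.e. `<RUN_DIR>/checkpoints/<CHECKPOINT_PATH>.pt`,
# with `config.json` and `dataset_statistics.json` sitting in `<RUN_DIR>`; any other string is treated as a
# Hugging Face Hub repo id and resolved via `HfFileSystem`. (Hypothetical path below.)
checkpoint = "runs/cogact-example/checkpoints/step-010000-epoch-01-loss=0.1234.pt"

vla = load_vla(
    checkpoint,
    hf_token=None,            # optional Hugging Face access token (forwarded to the LLM backbone loader)
    load_for_training=False,  # freezes weights and loads the LLM backbone in inference mode
)

# CogACT is wrapped in FSDP during training, i.e. it is a torch.nn.Module; standard placement applies.
device = "cuda" if torch.cuda.is_available() else "cpu"
vla = vla.to(device).eval()
print(f"Loaded {type(vla).__name__} on {device}")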