├── .dockerignore ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .python-version ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── docker.md └── remote_inference.md ├── examples ├── aloha_real │ ├── Dockerfile │ ├── README.md │ ├── compose.yml │ ├── constants.py │ ├── convert_aloha_data_to_lerobot.py │ ├── env.py │ ├── main.py │ ├── real_env.py │ ├── requirements.in │ ├── requirements.txt │ ├── robot_utils.py │ └── video_display.py ├── aloha_sim │ ├── Dockerfile │ ├── README.md │ ├── compose.yml │ ├── env.py │ ├── main.py │ ├── requirements.in │ ├── requirements.txt │ └── saver.py ├── droid │ ├── README.md │ └── main.py ├── inference.ipynb ├── libero │ ├── Dockerfile │ ├── README.md │ ├── compose.yml │ ├── convert_libero_data_to_lerobot.py │ ├── main.py │ ├── requirements.in │ └── requirements.txt ├── policy_records.ipynb ├── simple_client │ ├── Dockerfile │ ├── README.md │ ├── compose.yml │ ├── main.py │ ├── requirements.in │ └── requirements.txt └── umi │ ├── README.md │ ├── convert_umi_data_to_lerobot.py │ ├── imagecodecs_numcodecs.py │ └── umi_replay_buffer.py ├── figures └── fisheye-aug.png ├── packages └── openpi-client │ ├── pyproject.toml │ └── src │ └── openpi_client │ ├── __init__.py │ ├── action_chunk_broker.py │ ├── base_policy.py │ ├── image_tools.py │ ├── image_tools_test.py │ ├── msgpack_numpy.py │ ├── msgpack_numpy_test.py │ ├── runtime │ ├── agent.py │ ├── agents │ │ └── policy_agent.py │ ├── environment.py │ ├── runtime.py │ └── subscriber.py │ └── websocket_client_policy.py ├── pyproject.toml ├── scripts ├── __init__.py ├── augment_vl_data │ ├── augment.py │ ├── finger_mask.jpg │ ├── gripper_lens_mask.jpg │ ├── inpainted.jpg │ └── lens.jpg ├── compute_norm_stats.py ├── docker │ ├── compose.yml │ ├── install_docker_ubuntu22.sh │ ├── install_nvidia_container_toolkit.sh │ └── serve_policy.Dockerfile ├── serve_policy.py ├── train.py └── train_test.py ├── src └── openpi │ ├── __init__.py │ ├── conftest.py │ ├── models │ ├── __init__.py │ ├── gemma.py │ ├── gemma_fast.py │ ├── lora.py │ ├── lora_test.py │ ├── model.py │ ├── model_test.py │ ├── pi0.py │ ├── pi0_fast.py │ ├── pi0_fuse.py │ ├── pi0_test.py │ ├── siglip.py │ ├── tokenizer.py │ ├── tokenizer_test.py │ └── vit.py │ ├── policies │ ├── aloha_policy.py │ ├── droid_policy.py │ ├── libero_policy.py │ ├── policy.py │ ├── policy_config.py │ ├── policy_test.py │ ├── pose_repr_util.py │ ├── pose_util.py │ ├── umi_dataset.py │ └── umi_policy.py │ ├── py.typed │ ├── serving │ └── websocket_policy_server.py │ ├── shared │ ├── __init__.py │ ├── array_typing.py │ ├── download.py │ ├── download_test.py │ ├── image_tools.py │ ├── image_tools_test.py │ ├── nnx_utils.py │ ├── normalize.py │ └── normalize_test.py │ ├── training │ ├── checkpoints.py │ ├── config.py │ ├── data_loader.py │ ├── data_loader_test.py │ ├── optimizer.py │ ├── sharding.py │ ├── utils.py │ └── weight_loaders.py │ ├── transforms.py │ └── transforms_test.py ├── train_scripts ├── train_onetwovla_cocktail.sh ├── train_onetwovla_visual_grounding.sh ├── train_pi0_cocktail.sh └── train_pi0_visual_grounding.sh └── uv.lock /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv 2 | checkpoints 3 | data 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data directories. 
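# (These are populated by dataset downloads, training checkpoints, and W&B run logs.)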
2 | assets/ 3 | checkpoints/ 4 | data/ 5 | wandb/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | pretrain_model 170 | 171 | test_scripts/ 172 | google-cloud-sdk -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/aloha"] 2 | path = third_party/aloha 3 | url = git@github.com:Physical-Intelligence/aloha.git 4 | [submodule "third_party/libero"] 5 | path = third_party/libero 6 | url = git@github.com:Lifelong-Robot-Learning/LIBERO.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: third_party/ 2 | 3 | repos: 4 | - repo: https://github.com/astral-sh/uv-pre-commit 5 | # uv version. 6 | rev: 0.5.14 7 | hooks: 8 | - id: uv-lock 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | # Ruff version. 11 | rev: v0.8.6 12 | hooks: 13 | # Run the linter. 14 | - id: ruff 15 | args: [--fix] 16 | - id: ruff-format -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to openpi 2 | 3 | We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below. 4 | 5 | ## Issues and feature requests 6 | 7 | You are welcome to use the GitHub [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics. 8 | 9 | If you find a bug or other issue, please first check that the issue was not already reported (use the search bar on GitHub under Issues).
If the issue has not yet been reported, please include this information when filing a GitHub issue: 10 | 11 | - Your OS type and version, and the version of Python you are using 12 | - Code that allows us to reproduce your bug, including all dependencies 13 | - The traceback of any exception 14 | - Any other information that would help us, such as a screenshot 15 | 16 | In order for us to address any issue, we must be able to reproduce it. If you encountered the issue after making modifications to openpi, please reproduce the issue without those modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`. 17 | 18 | If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information: 19 | 20 | - The motivation for the feature 21 | - A description of the problem you are trying to solve or your use case 22 | - Enough information for us to understand the nature of the request 23 | - Some information about how you intend to use it (this might help us in understanding the motivation!) 24 | 25 | We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in! 26 | 27 | ## Submitting a pull request 28 | 29 | If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR if you would like to get a sense for whether we are likely to approve it. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe a change would make the code harder to maintain, or if reviewing it is out of scope for us), so contacting us in advance is a good way to find out whether your PR is likely to be merged into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating a PR, please consider the following: 30 | 31 | - Make sure that your PR has a clear title and description 32 | - Run `pre-commit` (install it using `pre-commit install` first), and run `ruff check .` and `ruff format .` 33 | - Make sure your PR passes all tests 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Ruiqian Nai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OneTwoVLA: A Unified Vision-Language-Action Model with Adaptive Reasoning 2 | 3 | [[Project Page]](https://one-two-vla.github.io/) [[Paper]](https://arxiv.org/abs/2505.11917) [[Processed Datasets]](https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset) 4 | 5 | [Fanqi Lin](https://fanqi-lin.github.io/)<sup>1,2,3,5</sup>\*, 6 | [Ruiqian Nai](https://richard-coder-nai.github.io/)<sup>1,2,3,5</sup>\*, 7 | [Yingdong Hu](https://yingdong-hu.github.io/)<sup>1,2,3</sup>\*, 8 | [Jiacheng You](https://scholar.google.com/citations?user=FiP-TVUAAAAJ)<sup>1,2,3</sup>, 9 | Junming Zhao<sup>1,4</sup>, 10 | [Yang Gao](https://yang-gao.weebly.com/)<sup>1,2,3,5</sup> 11 | 12 | <sup>1</sup>Tsinghua University, 13 | <sup>2</sup>Shanghai Qi Zhi Institute, 14 | <sup>3</sup>Shanghai AI Lab, 15 | <sup>4</sup>Fudan University, 16 | <sup>5</sup>Spirit AI 17 | 18 | \* indicates equal contributions 19 | 20 | 21 | ## 🛠️ Installation 22 | 23 | We manage Python dependencies with [uv](https://docs.astral.sh/uv/). If you haven't installed `uv`, please follow the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. 24 | 25 | Run the following to set up the environment: 26 | 27 | ```bash 28 | GIT_LFS_SKIP_SMUDGE=1 uv sync 29 | GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . 30 | ``` 31 | 32 | > NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency. 33 | 34 | For more details, refer to the original [openpi repository](https://github.com/Physical-Intelligence/openpi.git). 35 | 36 | ## 🚀 Training OneTwoVLA 37 | Download the datasets and place them under `$LEROBOT_HOME/umi/`. 38 | 39 | To train a OneTwoVLA model, run: 40 | ```bash 41 | bash train_scripts/train_<task>.sh 42 | ``` 43 | The available training scripts are: 44 | ```bash 45 | train_scripts 46 | |-- train_onetwovla_cocktail.sh 47 | |-- train_onetwovla_visual_grounding.sh 48 | |-- train_pi0_cocktail.sh 49 | |-- train_pi0_visual_grounding.sh 50 | ``` 51 | 52 | ## 🦾 Real-World Deployment 53 | We run inference using a policy server and a hardware client. The instructions for running the policy server can be found at [examples/umi/README.md](examples/umi/README.md), and we provide the UMI hardware client code in this [repository](https://github.com/Fanqi-Lin/OneTwoVLA-UMI-Client). 54 | 55 | ## 📷 Data 56 | We provide access to the following datasets: 57 | 58 | - `Robot Datasets`: Datasets for the `cocktail` and `open-world visual grounding` tasks. 59 | - `Vision-Language Datasets`: Datasets containing synthetic images and annotated reasoning for the `open-world visual grounding` task. 60 | 61 | All datasets are hosted on Hugging Face. You can find them [here](https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset). 62 | 63 | We provide code for converting the UMI data format to the LeRobot data format [here](examples/umi/convert_umi_data_to_lerobot.py).
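As a quick way to inspect the converter's expected arguments before running it on your own recordings, you can ask the script for its help text (we assume it exposes a standard `--help` flag through its CLI parser; the flag names it prints are the authoritative ones):

```bash
# Print the conversion script's usage; consult this output for the actual flag names.
uv run examples/umi/convert_umi_data_to_lerobot.py --help
```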
64 | 65 | ### Synthetic Image Augmentation 66 | 67 | To make the synthetic images more closely resemble real robot observations, we randomly apply several augmentations, including random fisheye distortion and compositing a robot gripper into the image with adaptive brightness adjustment. The implementation is available in [scripts/augment_vl_data/augment.py](scripts/augment_vl_data/augment.py). 68 | 69 | Here we show an example. From left to right, the images are: the original image, the image with fisheye distortion, the image with a composited robot gripper (with adaptive brightness adjustment), and the image with both augmentations applied. 70 | 71 | ![Augmentation example](figures/fisheye-aug.png) 72 | 73 | ## 🙏 Acknowledgements 74 | We express our sincere gratitude to the developers of [openpi](https://github.com/Physical-Intelligence/openpi.git) for open-sourcing their code. 75 | -------------------------------------------------------------------------------- /docs/docker.md: -------------------------------------------------------------------------------- 1 | ### Docker Setup 2 | 3 | All of the examples in this repo provide instructions both for running directly on the host and for running with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for the examples that depend on ROS, lets you avoid installing ROS and cluttering your machine. 4 | 5 | - Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/). 6 | - Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/). 7 | - To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). 8 | - The version of Docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`. 9 | - Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`. 10 | 11 | 12 | If you are starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`. 13 | 14 | During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached. -------------------------------------------------------------------------------- /docs/remote_inference.md: -------------------------------------------------------------------------------- 1 | 2 | # Running openpi models remotely 3 | 4 | We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and it also helps keep the robot and policy environments separate (e.g., to avoid dependency hell with the robot software). 5 | 6 | ## Starting a remote policy server 7 | 8 | To start a remote policy server, simply run the following command: 9 | 10 | ```bash 11 | uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO] 12 | ``` 13 | 14 | The `env` argument specifies which $\pi_0$ checkpoint should be loaded.
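For example, `--env=DROID` serves the default fine-tuned $\pi_0$-FAST-DROID checkpoint:

```bash
uv run scripts/serve_policy.py --env=DROID
```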
Under the hood, this script executes a command like the following. You can use this form directly to serve checkpoints you trained yourself (shown here for the DROID environment): 15 | 16 | ```bash 17 | uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid 18 | ``` 19 | 20 | This starts a policy server that serves the policy specified by the `config` and `dir` arguments. The policy is served on the specified port (default: 8000). 21 | 22 | ## Querying the remote policy server from your robot code 23 | 24 | We provide a client utility with minimal dependencies that you can easily embed into any robot codebase. 25 | 26 | First, install the `openpi-client` package in your robot environment: 27 | 28 | ```bash 29 | cd $OPENPI_ROOT/packages/openpi-client 30 | pip install -e . 31 | ``` 32 | 33 | Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this: 34 | 35 | ```python 36 | from openpi_client import websocket_client_policy 37 | 38 | policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000) 39 | action_chunk = policy_client.infer(example)["actions"] 40 | ``` 41 | 42 | Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py). 43 | -------------------------------------------------------------------------------- /examples/aloha_real/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the Aloha real environment. 2 | 3 | # Build the container: 4 | # docker build .
-t aloha_real -f examples/aloha_real/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash 8 | 9 | FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0 10 | SHELL ["/bin/bash", "-c"] 11 | 12 | ENV DEBIAN_FRONTEND=noninteractive 13 | RUN apt-get update && \ 14 | apt-get install -y --no-install-recommends \ 15 | cmake \ 16 | curl \ 17 | libffi-dev \ 18 | python3-rosdep \ 19 | python3-rosinstall \ 20 | python3-rosinstall-generator \ 21 | whiptail \ 22 | git \ 23 | wget \ 24 | openssh-client \ 25 | ros-noetic-cv-bridge \ 26 | ros-noetic-usb-cam \ 27 | ros-noetic-realsense2-camera \ 28 | keyboard-configuration 29 | 30 | WORKDIR /root 31 | RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh 32 | RUN chmod +x xsarm_amd64_install.sh 33 | RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n 34 | 35 | COPY ./third_party/aloha /root/interbotix_ws/src/aloha 36 | RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make 37 | 38 | # Install python 3.10 because this ROS image comes with 3.8 39 | RUN mkdir /python && \ 40 | cd /python && \ 41 | wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \ 42 | tar -zxvf Python-3.10.14.tgz && \ 43 | cd Python-3.10.14 && \ 44 | ls -lhR && \ 45 | ./configure --enable-optimizations && \ 46 | make install && \ 47 | echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \ 48 | echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \ 49 | cd ~ && rm -rf /python && \ 50 | rm -rf /var/lib/apt/lists/* 51 | 52 | COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv 53 | ENV UV_HTTP_TIMEOUT=120 54 | ENV UV_LINK_MODE=copy 55 | COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt 56 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 57 | RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml 58 | 59 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha 60 | WORKDIR /app 61 | 62 | # Create an entrypoint script to run the setup commands, followed by the command passed in. 63 | RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh 64 | #!/bin/bash 65 | source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@" 66 | EOF 67 | RUN chmod +x /usr/local/bin/entrypoint.sh 68 | 69 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 70 | CMD ["python3", "/app/examples/aloha_real/main.py"] 71 | -------------------------------------------------------------------------------- /examples/aloha_real/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/aloha_real/compose.yml up --build 3 | services: 4 | runtime: 5 | image: aloha_real 6 | depends_on: 7 | - aloha_ros_nodes 8 | - ros_master 9 | - openpi_server 10 | build: 11 | context: ../.. 12 | dockerfile: examples/aloha_real/Dockerfile 13 | init: true 14 | tty: true 15 | network_mode: host 16 | privileged: true 17 | volumes: 18 | - $PWD:/app 19 | - ../../data:/data 20 | 21 | aloha_ros_nodes: 22 | image: aloha_real 23 | depends_on: 24 | - ros_master 25 | build: 26 | context: ../.. 
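# The build context is the repo root so the Dockerfile can COPY ./third_party/aloha and ./packages/openpi-client.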
27 | dockerfile: examples/aloha_real/Dockerfile 28 | init: true 29 | tty: true 30 | network_mode: host 31 | privileged: true 32 | volumes: 33 | - /dev:/dev 34 | command: roslaunch --wait aloha ros_nodes.launch 35 | 36 | ros_master: 37 | image: ros:noetic-robot 38 | network_mode: host 39 | privileged: true 40 | command: 41 | - roscore 42 | 43 | openpi_server: 44 | image: openpi_server 45 | build: 46 | context: ../.. 47 | dockerfile: scripts/docker/serve_policy.Dockerfile 48 | init: true 49 | tty: true 50 | network_mode: host 51 | volumes: 52 | - $PWD:/app 53 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 54 | environment: 55 | - SERVER_ARGS 56 | - OPENPI_DATA_HOME=/openpi_assets 57 | - IS_DOCKER=true 58 | 59 | # Comment out this block if not running on a machine with GPUs. 60 | deploy: 61 | resources: 62 | reservations: 63 | devices: 64 | - driver: nvidia 65 | count: 1 66 | capabilities: [gpu] 67 | -------------------------------------------------------------------------------- /examples/aloha_real/constants.py: -------------------------------------------------------------------------------- 1 | # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). 2 | # ruff: noqa 3 | 4 | ### Task parameters 5 | 6 | ### ALOHA fixed constants 7 | DT = 0.001 8 | JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] 9 | START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] 10 | 11 | # Left finger position limits (qpos[7]), right_finger = -1 * left_finger 12 | MASTER_GRIPPER_POSITION_OPEN = 0.02417 13 | MASTER_GRIPPER_POSITION_CLOSE = 0.01244 14 | PUPPET_GRIPPER_POSITION_OPEN = 0.05800 15 | PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 16 | 17 | # Gripper joint limits (qpos[6]) 18 | MASTER_GRIPPER_JOINT_OPEN = 0.3083 19 | MASTER_GRIPPER_JOINT_CLOSE = -0.6842 20 | PUPPET_GRIPPER_JOINT_OPEN = 1.4910 21 | PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 22 | 23 | ############################ Helper functions ############################ 24 | 25 | MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / ( 26 | MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE 27 | ) 28 | PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / ( 29 | PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE 30 | ) 31 | MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = ( 32 | lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE 33 | ) 34 | PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = ( 35 | lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE 36 | ) 37 | MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) 38 | 39 | MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / ( 40 | MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE 41 | ) 42 | PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / ( 43 | PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE 44 | ) 45 | MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = ( 46 | lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE 47 | ) 48 | PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = ( 49 | lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE 50 | ) 51 | MASTER2PUPPET_JOINT_FN = lambda x: 
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) 52 | 53 | MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) 54 | PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) 55 | 56 | MASTER_POS2JOINT = ( 57 | lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) 58 | + MASTER_GRIPPER_JOINT_CLOSE 59 | ) 60 | MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN( 61 | (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) 62 | ) 63 | PUPPET_POS2JOINT = ( 64 | lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) 65 | + PUPPET_GRIPPER_JOINT_CLOSE 66 | ) 67 | PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN( 68 | (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) 69 | ) 70 | 71 | MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2 72 | -------------------------------------------------------------------------------- /examples/aloha_real/env.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional # noqa: UP035 2 | 3 | import einops 4 | from openpi_client import image_tools 5 | from openpi_client.runtime import environment as _environment 6 | from typing_extensions import override 7 | 8 | from examples.aloha_real import real_env as _real_env 9 | 10 | 11 | class AlohaRealEnvironment(_environment.Environment): 12 | """An environment for an Aloha robot on real hardware.""" 13 | 14 | def __init__( 15 | self, 16 | reset_position: Optional[List[float]] = None, # noqa: UP006,UP007 17 | render_height: int = 224, 18 | render_width: int = 224, 19 | ) -> None: 20 | self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position) 21 | self._render_height = render_height 22 | self._render_width = render_width 23 | 24 | self._ts = None 25 | 26 | @override 27 | def reset(self) -> None: 28 | self._ts = self._env.reset() 29 | 30 | @override 31 | def is_episode_complete(self) -> bool: 32 | return False 33 | 34 | @override 35 | def get_observation(self) -> dict: 36 | if self._ts is None: 37 | raise RuntimeError("Timestep is not set. 
Call reset() first.") 38 | 39 | obs = self._ts.observation 40 | for k in list(obs["images"].keys()): 41 | if "_depth" in k: 42 | del obs["images"][k] 43 | 44 | for cam_name in obs["images"]: 45 | img = image_tools.convert_to_uint8( 46 | image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width) 47 | ) 48 | obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w") 49 | 50 | return { 51 | "state": obs["qpos"], 52 | "images": obs["images"], 53 | } 54 | 55 | @override 56 | def apply_action(self, action: dict) -> None: 57 | self._ts = self._env.step(action["actions"]) 58 | -------------------------------------------------------------------------------- /examples/aloha_real/main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | 4 | from openpi_client import action_chunk_broker 5 | from openpi_client import websocket_client_policy as _websocket_client_policy 6 | from openpi_client.runtime import runtime as _runtime 7 | from openpi_client.runtime.agents import policy_agent as _policy_agent 8 | import tyro 9 | 10 | from examples.aloha_real import env as _env 11 | 12 | 13 | @dataclasses.dataclass 14 | class Args: 15 | host: str = "0.0.0.0" 16 | port: int = 8000 17 | 18 | action_horizon: int = 25 19 | 20 | num_episodes: int = 1 21 | max_episode_steps: int = 1000 22 | 23 | 24 | def main(args: Args) -> None: 25 | ws_client_policy = _websocket_client_policy.WebsocketClientPolicy( 26 | host=args.host, 27 | port=args.port, 28 | ) 29 | logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}") 30 | 31 | metadata = ws_client_policy.get_server_metadata() 32 | runtime = _runtime.Runtime( 33 | environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")), 34 | agent=_policy_agent.PolicyAgent( 35 | policy=action_chunk_broker.ActionChunkBroker( 36 | policy=ws_client_policy, 37 | action_horizon=args.action_horizon, 38 | ) 39 | ), 40 | subscribers=[], 41 | max_hz=50, 42 | num_episodes=args.num_episodes, 43 | max_episode_steps=args.max_episode_steps, 44 | ) 45 | 46 | runtime.run() 47 | 48 | 49 | if __name__ == "__main__": 50 | logging.basicConfig(level=logging.INFO, force=True) 51 | tyro.cli(main) 52 | -------------------------------------------------------------------------------- /examples/aloha_real/requirements.in: -------------------------------------------------------------------------------- 1 | Pillow 2 | dm_control 3 | einops 4 | h5py 5 | matplotlib 6 | modern_robotics 7 | msgpack 8 | numpy 9 | opencv-python 10 | packaging 11 | pexpect 12 | pyquaternion 13 | pyrealsense2 14 | pyyaml 15 | requests 16 | rospkg 17 | tyro 18 | websockets 19 | -------------------------------------------------------------------------------- /examples/aloha_real/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10 3 | absl-py==2.1.0 4 | # via 5 | # dm-control 6 | # dm-env 7 | # labmaze 8 | # mujoco 9 | catkin-pkg==1.0.0 10 | # via rospkg 11 | certifi==2024.8.30 12 | # via requests 13 | charset-normalizer==3.4.0 14 | # via requests 15 | contourpy==1.1.1 16 | # via matplotlib 17 | cycler==0.12.1 18 | # via matplotlib 19 | distro==1.9.0 20 | # via rospkg 21 | dm-control==1.0.23 22 | # via -r examples/aloha_real/requirements.in 23 | dm-env==1.6 24 | 
# via dm-control 25 | dm-tree==0.1.8 26 | # via 27 | # dm-control 28 | # dm-env 29 | docstring-parser==0.16 30 | # via tyro 31 | docutils==0.20.1 32 | # via catkin-pkg 33 | einops==0.8.0 34 | # via -r examples/aloha_real/requirements.in 35 | etils==1.3.0 36 | # via mujoco 37 | fonttools==4.55.2 38 | # via matplotlib 39 | glfw==2.8.0 40 | # via 41 | # dm-control 42 | # mujoco 43 | h5py==3.11.0 44 | # via -r examples/aloha_real/requirements.in 45 | idna==3.10 46 | # via requests 47 | importlib-resources==6.4.5 48 | # via etils 49 | kiwisolver==1.4.7 50 | # via matplotlib 51 | labmaze==1.0.6 52 | # via dm-control 53 | lxml==5.3.0 54 | # via dm-control 55 | markdown-it-py==3.0.0 56 | # via rich 57 | matplotlib==3.7.5 58 | # via -r examples/aloha_real/requirements.in 59 | mdurl==0.1.2 60 | # via markdown-it-py 61 | modern-robotics==1.1.1 62 | # via -r examples/aloha_real/requirements.in 63 | msgpack==1.1.0 64 | # via -r examples/aloha_real/requirements.in 65 | mujoco==3.2.3 66 | # via dm-control 67 | numpy==1.24.4 68 | # via 69 | # -r examples/aloha_real/requirements.in 70 | # contourpy 71 | # dm-control 72 | # dm-env 73 | # h5py 74 | # labmaze 75 | # matplotlib 76 | # modern-robotics 77 | # mujoco 78 | # opencv-python 79 | # pyquaternion 80 | # scipy 81 | opencv-python==4.10.0.84 82 | # via -r examples/aloha_real/requirements.in 83 | packaging==24.2 84 | # via 85 | # -r examples/aloha_real/requirements.in 86 | # matplotlib 87 | pexpect==4.9.0 88 | # via -r examples/aloha_real/requirements.in 89 | pillow==10.4.0 90 | # via 91 | # -r examples/aloha_real/requirements.in 92 | # matplotlib 93 | protobuf==5.29.1 94 | # via dm-control 95 | ptyprocess==0.7.0 96 | # via pexpect 97 | pygments==2.18.0 98 | # via rich 99 | pyopengl==3.1.7 100 | # via 101 | # dm-control 102 | # mujoco 103 | pyparsing==3.1.4 104 | # via 105 | # catkin-pkg 106 | # dm-control 107 | # matplotlib 108 | pyquaternion==0.9.9 109 | # via -r examples/aloha_real/requirements.in 110 | pyrealsense2==2.55.1.6486 111 | # via -r examples/aloha_real/requirements.in 112 | python-dateutil==2.9.0.post0 113 | # via 114 | # catkin-pkg 115 | # matplotlib 116 | pyyaml==6.0.2 117 | # via 118 | # -r examples/aloha_real/requirements.in 119 | # rospkg 120 | requests==2.32.3 121 | # via 122 | # -r examples/aloha_real/requirements.in 123 | # dm-control 124 | rich==13.9.4 125 | # via tyro 126 | rospkg==1.5.1 127 | # via -r examples/aloha_real/requirements.in 128 | scipy==1.10.1 129 | # via dm-control 130 | setuptools==75.3.0 131 | # via 132 | # catkin-pkg 133 | # dm-control 134 | # labmaze 135 | shtab==1.7.1 136 | # via tyro 137 | six==1.17.0 138 | # via python-dateutil 139 | tqdm==4.67.1 140 | # via dm-control 141 | typeguard==4.4.0 142 | # via tyro 143 | typing-extensions==4.12.2 144 | # via 145 | # etils 146 | # rich 147 | # typeguard 148 | # tyro 149 | tyro==0.9.2 150 | # via -r examples/aloha_real/requirements.in 151 | urllib3==2.2.3 152 | # via requests 153 | websockets==14.1 154 | # via -r examples/aloha_real/requirements.in 155 | zipp==3.20.2 156 | # via etils 157 | -------------------------------------------------------------------------------- /examples/aloha_real/video_display.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from openpi_client.runtime import subscriber as _subscriber 4 | from typing_extensions import override 5 | 6 | 7 | class VideoDisplay(_subscriber.Subscriber): 8 | """Displays video frames.""" 9 | 10 | def __init__(self) -> 
None: 11 | self._ax: plt.Axes | None = None 12 | self._plt_img: plt.Image | None = None 13 | 14 | @override 15 | def on_episode_start(self) -> None: 16 | plt.ion() 17 | self._ax = plt.subplot() 18 | self._plt_img = None 19 | 20 | @override 21 | def on_step(self, observation: dict, action: dict) -> None: 22 | assert self._ax is not None 23 | 24 | im = observation["image"][0] # [C, H, W] 25 | im = np.transpose(im, (1, 2, 0)) # [H, W, C] 26 | 27 | if self._plt_img is None: 28 | self._plt_img = self._ax.imshow(im) 29 | else: 30 | self._plt_img.set_data(im) 31 | plt.pause(0.001) 32 | 33 | @override 34 | def on_episode_end(self) -> None: 35 | plt.ioff() 36 | plt.close() 37 | -------------------------------------------------------------------------------- /examples/aloha_sim/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the Aloha simulation environment. 2 | 3 | # Build the container: 4 | # docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash 8 | 9 | FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78 10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 11 | 12 | RUN apt-get update && \ 13 | apt-get install -y \ 14 | libosmesa6-dev \ 15 | libgl1-mesa-glx \ 16 | libglew-dev \ 17 | libglfw3-dev \ 18 | libgles2-mesa-dev 19 | ENV MUJOCO_GL=egl 20 | 21 | WORKDIR /app 22 | 23 | # Copy from the cache instead of linking since it's a mounted volume 24 | ENV UV_LINK_MODE=copy 25 | 26 | # Write the virtual environment outside of the project directory so it doesn't 27 | # leak out of the container when we mount the application code. 28 | ENV UV_PROJECT_ENVIRONMENT=/.venv 29 | 30 | # Copy the requirements files so we can install dependencies. 31 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. 32 | # This strategy is best for development-style usage. 33 | COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt 34 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 35 | 36 | # Install python dependencies. 
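# `uv venv` creates the environment at $UV_PROJECT_ENVIRONMENT, and `uv pip sync` installs exactly the pinned versions from the compiled requirements files (removing anything not listed), which keeps the image reproducible.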
37 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT 38 | RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml 39 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src 40 | 41 | CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"] -------------------------------------------------------------------------------- /examples/aloha_sim/README.md: -------------------------------------------------------------------------------- 1 | # Run Aloha Sim 2 | 3 | ## With Docker 4 | 5 | ```bash 6 | export SERVER_ARGS="--env ALOHA_SIM" 7 | docker compose -f examples/aloha_sim/compose.yml up --build 8 | ``` 9 | 10 | ## Without Docker 11 | 12 | Terminal window 1: 13 | 14 | ```bash 15 | # Create virtual environment 16 | uv venv --python 3.10 examples/aloha_sim/.venv 17 | source examples/aloha_sim/.venv/bin/activate 18 | uv pip sync examples/aloha_sim/requirements.txt 19 | uv pip install -e packages/openpi-client 20 | 21 | # Run the simulation 22 | MUJOCO_GL=egl python examples/aloha_sim/main.py 23 | ``` 24 | 25 | Note: If you are seeing EGL errors, you may need to install the following dependencies: 26 | 27 | ```bash 28 | sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev 29 | ``` 30 | 31 | Terminal window 2: 32 | 33 | ```bash 34 | # Run the server 35 | uv run scripts/serve_policy.py --env ALOHA_SIM 36 | ``` 37 | -------------------------------------------------------------------------------- /examples/aloha_sim/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/aloha_sim/compose.yml up --build 3 | services: 4 | runtime: 5 | image: aloha_sim 6 | depends_on: 7 | - openpi_server 8 | build: 9 | context: ../.. 10 | dockerfile: examples/aloha_sim/Dockerfile 11 | init: true 12 | tty: true 13 | network_mode: host 14 | privileged: true 15 | volumes: 16 | - $PWD:/app 17 | - ../../data:/data 18 | 19 | openpi_server: 20 | image: openpi_server 21 | build: 22 | context: ../.. 23 | dockerfile: scripts/docker/serve_policy.Dockerfile 24 | init: true 25 | tty: true 26 | network_mode: host 27 | volumes: 28 | - $PWD:/app 29 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 30 | environment: 31 | - SERVER_ARGS 32 | - OPENPI_DATA_HOME=/openpi_assets 33 | - IS_DOCKER=true 34 | 35 | # Comment out this block if not running on a machine with GPUs. 
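# The reservation below requests one NVIDIA GPU through the nvidia container runtime; the host needs the NVIDIA Container Toolkit installed (see docs/docker.md).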
36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /examples/aloha_sim/env.py: -------------------------------------------------------------------------------- 1 | import gym_aloha # noqa: F401 2 | import gymnasium 3 | import numpy as np 4 | from openpi_client import image_tools 5 | from openpi_client.runtime import environment as _environment 6 | from typing_extensions import override 7 | 8 | 9 | class AlohaSimEnvironment(_environment.Environment): 10 | """An environment for an Aloha robot in simulation.""" 11 | 12 | def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None: 13 | np.random.seed(seed) 14 | self._rng = np.random.default_rng(seed) 15 | 16 | self._gym = gymnasium.make(task, obs_type=obs_type) 17 | 18 | self._last_obs = None 19 | self._done = True 20 | self._episode_reward = 0.0 21 | 22 | @override 23 | def reset(self) -> None: 24 | gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1))) 25 | self._last_obs = self._convert_observation(gym_obs) # type: ignore 26 | self._done = False 27 | self._episode_reward = 0.0 28 | 29 | @override 30 | def is_episode_complete(self) -> bool: 31 | return self._done 32 | 33 | @override 34 | def get_observation(self) -> dict: 35 | if self._last_obs is None: 36 | raise RuntimeError("Observation is not set. Call reset() first.") 37 | 38 | return self._last_obs # type: ignore 39 | 40 | @override 41 | def apply_action(self, action: dict) -> None: 42 | gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"]) 43 | self._last_obs = self._convert_observation(gym_obs) # type: ignore 44 | self._done = terminated or truncated 45 | self._episode_reward = max(self._episode_reward, reward) 46 | 47 | def _convert_observation(self, gym_obs: dict) -> dict: 48 | img = gym_obs["pixels"]["top"] 49 | img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224)) 50 | # Convert axis order from [H, W, C] --> [C, H, W] 51 | img = np.transpose(img, (2, 0, 1)) 52 | 53 | return { 54 | "state": gym_obs["agent_pos"], 55 | "images": {"cam_high": img}, 56 | } 57 | -------------------------------------------------------------------------------- /examples/aloha_sim/main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import pathlib 4 | 5 | import env as _env 6 | from openpi_client import action_chunk_broker 7 | from openpi_client import websocket_client_policy as _websocket_client_policy 8 | from openpi_client.runtime import runtime as _runtime 9 | from openpi_client.runtime.agents import policy_agent as _policy_agent 10 | import saver as _saver 11 | import tyro 12 | 13 | 14 | @dataclasses.dataclass 15 | class Args: 16 | out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos") 17 | 18 | task: str = "gym_aloha/AlohaTransferCube-v0" 19 | seed: int = 0 20 | 21 | action_horizon: int = 10 22 | 23 | host: str = "0.0.0.0" 24 | port: int = 8000 25 | 26 | display: bool = False 27 | 28 | 29 | def main(args: Args) -> None: 30 | runtime = _runtime.Runtime( 31 | environment=_env.AlohaSimEnvironment( 32 | task=args.task, 33 | seed=args.seed, 34 | ), 35 | agent=_policy_agent.PolicyAgent( 36 | policy=action_chunk_broker.ActionChunkBroker( 37 | policy=_websocket_client_policy.WebsocketClientPolicy( 38 | host=args.host, 39 | port=args.port, 40 | ), 41 | 
action_horizon=args.action_horizon, 42 | ) 43 | ), 44 | subscribers=[ 45 | _saver.VideoSaver(args.out_dir), 46 | ], 47 | max_hz=50, 48 | ) 49 | 50 | runtime.run() 51 | 52 | 53 | if __name__ == "__main__": 54 | logging.basicConfig(level=logging.INFO, force=True) 55 | tyro.cli(main) 56 | -------------------------------------------------------------------------------- /examples/aloha_sim/requirements.in: -------------------------------------------------------------------------------- 1 | gym-aloha 2 | imageio 3 | matplotlib 4 | msgpack 5 | numpy 6 | typing-extensions 7 | tyro 8 | websockets -------------------------------------------------------------------------------- /examples/aloha_sim/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10 3 | absl-py==2.1.0 4 | # via 5 | # dm-control 6 | # dm-env 7 | # labmaze 8 | # mujoco 9 | certifi==2024.8.30 10 | # via requests 11 | charset-normalizer==3.4.0 12 | # via requests 13 | cloudpickle==3.1.0 14 | # via gymnasium 15 | contourpy==1.3.1 16 | # via matplotlib 17 | cycler==0.12.1 18 | # via matplotlib 19 | dm-control==1.0.14 20 | # via gym-aloha 21 | dm-env==1.6 22 | # via dm-control 23 | dm-tree==0.1.8 24 | # via 25 | # dm-control 26 | # dm-env 27 | docstring-parser==0.16 28 | # via tyro 29 | farama-notifications==0.0.4 30 | # via gymnasium 31 | fonttools==4.55.2 32 | # via matplotlib 33 | glfw==2.8.0 34 | # via 35 | # dm-control 36 | # mujoco 37 | gym-aloha==0.1.1 38 | # via -r examples/aloha_sim/requirements.in 39 | gymnasium==1.0.0 40 | # via gym-aloha 41 | idna==3.10 42 | # via requests 43 | imageio==2.36.1 44 | # via 45 | # -r examples/aloha_sim/requirements.in 46 | # gym-aloha 47 | imageio-ffmpeg==0.5.1 48 | # via imageio 49 | kiwisolver==1.4.7 50 | # via matplotlib 51 | labmaze==1.0.6 52 | # via dm-control 53 | lxml==5.3.0 54 | # via dm-control 55 | markdown-it-py==3.0.0 56 | # via rich 57 | matplotlib==3.9.3 58 | # via -r examples/aloha_sim/requirements.in 59 | mdurl==0.1.2 60 | # via markdown-it-py 61 | msgpack==1.1.0 62 | # via -r examples/aloha_sim/requirements.in 63 | mujoco==2.3.7 64 | # via 65 | # dm-control 66 | # gym-aloha 67 | numpy==1.26.4 68 | # via 69 | # -r examples/aloha_sim/requirements.in 70 | # contourpy 71 | # dm-control 72 | # dm-env 73 | # gymnasium 74 | # imageio 75 | # labmaze 76 | # matplotlib 77 | # mujoco 78 | # scipy 79 | packaging==24.2 80 | # via matplotlib 81 | pillow==11.0.0 82 | # via 83 | # imageio 84 | # matplotlib 85 | protobuf==5.29.1 86 | # via dm-control 87 | psutil==6.1.0 88 | # via imageio 89 | pygments==2.18.0 90 | # via rich 91 | pyopengl==3.1.7 92 | # via 93 | # dm-control 94 | # mujoco 95 | pyparsing==3.2.0 96 | # via 97 | # dm-control 98 | # matplotlib 99 | python-dateutil==2.9.0.post0 100 | # via matplotlib 101 | requests==2.32.3 102 | # via dm-control 103 | rich==13.9.4 104 | # via tyro 105 | scipy==1.14.1 106 | # via dm-control 107 | setuptools==75.6.0 108 | # via 109 | # dm-control 110 | # imageio-ffmpeg 111 | # labmaze 112 | shtab==1.7.1 113 | # via tyro 114 | six==1.17.0 115 | # via python-dateutil 116 | tqdm==4.67.1 117 | # via dm-control 118 | typeguard==4.4.1 119 | # via tyro 120 | typing-extensions==4.12.2 121 | # via 122 | # -r examples/aloha_sim/requirements.in 123 | # gymnasium 124 | # rich 125 | # typeguard 126 | # tyro 127 | tyro==0.9.2 128 | # via -r 
examples/aloha_sim/requirements.in 129 | urllib3==2.2.3 130 | # via requests 131 | websockets==14.1 132 | # via -r examples/aloha_sim/requirements.in 133 | -------------------------------------------------------------------------------- /examples/aloha_sim/saver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import imageio 5 | import numpy as np 6 | from openpi_client.runtime import subscriber as _subscriber 7 | from typing_extensions import override 8 | 9 | 10 | class VideoSaver(_subscriber.Subscriber): 11 | """Saves episode data.""" 12 | 13 | def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None: 14 | out_dir.mkdir(parents=True, exist_ok=True) 15 | self._out_dir = out_dir 16 | self._images: list[np.ndarray] = [] 17 | self._subsample = subsample 18 | 19 | @override 20 | def on_episode_start(self) -> None: 21 | self._images = [] 22 | 23 | @override 24 | def on_step(self, observation: dict, action: dict) -> None: 25 | im = observation["images"]["cam_high"] # [C, H, W] 26 | im = np.transpose(im, (1, 2, 0)) # [H, W, C] 27 | self._images.append(im) 28 | 29 | @override 30 | def on_episode_end(self) -> None: 31 | existing = list(self._out_dir.glob("out_[0-9]*.mp4")) 32 | next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1 33 | out_path = self._out_dir / f"out_{next_idx}.mp4" 34 | 35 | logging.info(f"Saving video to {out_path}") 36 | imageio.mimwrite( 37 | out_path, 38 | [np.asarray(x) for x in self._images[:: self._subsample]], 39 | fps=50 // max(1, self._subsample), 40 | ) 41 | -------------------------------------------------------------------------------- /examples/droid/README.md: -------------------------------------------------------------------------------- 1 | # Run DROID 2 | 3 | This example shows how to run the fine-tuned $\pi_0$-FAST-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). We also offer a $\pi_0$-DROID model that is fine-tuned from $\pi_0$ and uses flow action decoding. You can use it by replacing `pi0_fast_droid` with `pi0_droid` in the commands below. In practice, we find that out-of-the-box, the $\pi_0$-FAST-DROID model is better at following language commands, so we recommend it as the default checkpoint for DROID evaluation. If you want to fine-tune on a DROID task that requires a fast-to-inference policy, you may still want to consider using the $\pi_0$-DROID model, since it decodes faster. For more details, please see the [FAST paper](https://pi.website/research/fast). 4 | 5 | 6 | ## Step 1: Start a policy server 7 | 8 | Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference. 9 | 10 | 1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi). 11 | 2. Start the OpenPI server via the following command: 12 | 13 | ```bash 14 | uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid 15 | ``` 16 | 17 | You can also run the equivalent command below: 18 | 19 | ```bash 20 | uv run scripts/serve_policy.py --env=DROID 21 | ``` 22 | 23 | ## Step 2: Run the DROID robot 24 | 25 | 1. 
Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC. 26 | 2. On the control laptop, activate your DROID conda environment. 27 | 3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`. 28 | 4. Install `tyro`, which we will use for command line parsing: `pip install tyro`. 29 | 5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory. 30 | 6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with). 31 | 7. Run the `main.py` file. Make sure to point the host IP address and port at the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choosing from ["left", "right"]. 32 | 33 | ```bash 34 | python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left" 35 | ``` 36 | 37 | The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting! 38 | 39 | # Troubleshooting 40 | 41 | | Issue | Solution | 42 | |-------|----------| 43 | | Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. | 44 | | Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. | 45 | | Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). | 46 | | Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, so make sure you are feeding the desired camera to the policy -- you can verify the view with `ZED_Explore`). Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort.
:) | 47 | -------------------------------------------------------------------------------- /examples/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import dataclasses\n", 10 | "\n", 11 | "import jax\n", 12 | "\n", 13 | "from openpi.models import model as _model\n", 14 | "from openpi.policies import droid_policy\n", 15 | "from openpi.policies import policy_config as _policy_config\n", 16 | "from openpi.shared import download\n", 17 | "from openpi.training import config as _config\n", 18 | "from openpi.training import data_loader as _data_loader" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# Policy inference\n", 26 | "\n", 27 | "The following example shows how to create a policy from a checkpoint and run inference on a dummy example." 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "config = _config.get_config(\"pi0_fast_droid\")\n", 37 | "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n", 38 | "\n", 39 | "# Create a trained policy.\n", 40 | "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n", 41 | "\n", 42 | "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n", 43 | "example = droid_policy.make_droid_example()\n", 44 | "result = policy.infer(example)\n", 45 | "\n", 46 | "# Delete the policy to free up memory.\n", 47 | "del policy\n", 48 | "\n", 49 | "print(\"Actions shape:\", result[\"actions\"].shape)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Working with a live model\n", 57 | "\n", 58 | "\n", 59 | "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "config = _config.get_config(\"pi0_aloha_sim\")\n", 69 | "\n", 70 | "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n", 71 | "key = jax.random.key(0)\n", 72 | "\n", 73 | "# Create a model from the checkpoint.\n", 74 | "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n", 75 | "\n", 76 | "# We can create fake observations and actions to test the model.\n", 77 | "obs, act = config.model.fake_obs(), config.model.fake_act()\n", 78 | "\n", 79 | "# Compute the training loss on the fake batch.\n", 80 | "loss = model.compute_loss(key, obs, act)\n", 81 | "print(\"Loss shape:\", loss.shape)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "Now, we are going to create a data loader and use a real batch of training data to compute the loss." 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# Reduce the batch size to reduce memory usage.\n", 98 | "config = dataclasses.replace(config, batch_size=2)\n", 99 | "\n", 100 | "# Load a single batch of data.
This is the same data that will be used during training.\n", 101 | "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n", 102 | "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n", 103 | "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n", 104 | "obs, act = next(iter(loader))\n", 105 | "\n", 106 | "# Compute the training loss on the real batch.\n", 107 | "loss = model.compute_loss(key, obs, act)\n", 108 | "\n", 109 | "# Delete the model to free up memory.\n", 110 | "del model\n", 111 | "\n", 112 | "print(\"Loss shape:\", loss.shape)" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": ".venv", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.11.9" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 2 137 | } 138 | -------------------------------------------------------------------------------- /examples/libero/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the LIBERO benchmark. 2 | 3 | # Build the container: 4 | # docker build . -t libero -f examples/libero/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash 8 | 9 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 11 | 12 | RUN apt-get update && \ 13 | apt-get install -y \ 14 | make \ 15 | g++ \ 16 | clang \ 17 | libosmesa6-dev \ 18 | libgl1-mesa-glx \ 19 | libglew-dev \ 20 | libglfw3-dev \ 21 | libgles2-mesa-dev \ 22 | libglib2.0-0 \ 23 | libsm6 \ 24 | libxrender1 \ 25 | libxext6 26 | 27 | WORKDIR /app 28 | 29 | # Copy from the cache instead of linking since it's a mounted volume 30 | ENV UV_LINK_MODE=copy 31 | 32 | # Write the virtual environment outside of the project directory so it doesn't 33 | # leak out of the container when we mount the application code. 34 | ENV UV_PROJECT_ENVIRONMENT=/.venv 35 | 36 | # Copy the requirements files so we can install dependencies. 37 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. 38 | # This strategy is best for development-style usage. 39 | COPY ./examples/libero/requirements.txt /tmp/requirements.txt 40 | COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt 41 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 42 | 43 | # Install python dependencies. 44 | RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT 45 | RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match 46 | ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero 47 | 48 | # Create a default config file to avoid an input prompt from LIBERO's init script.
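# (The heredoc below writes that config; it points LIBERO at the benchmark files
# vendored under /app/third_party/libero, which is mounted into the container at runtime.)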
49 | # https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py 50 | ENV LIBERO_CONFIG_PATH=/tmp/libero 51 | RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml 52 | benchmark_root: /app/third_party/libero/libero/libero 53 | bddl_files: /app/third_party/libero/libero/libero/bddl_files 54 | init_states: /app/third_party/libero/libero/libero/init_files 55 | datasets: /app/third_party/libero/libero/datasets 56 | assets: /app/third_party/libero/libero/libero/assets 57 | EOF 58 | 59 | CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"] 60 | -------------------------------------------------------------------------------- /examples/libero/README.md: -------------------------------------------------------------------------------- 1 | # LIBERO Benchmark 2 | 3 | This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO 4 | 5 | Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command. 6 | 7 | This example requires git submodules to be initialized. Don't forget to run: 8 | 9 | ```bash 10 | git submodule update --init --recursive 11 | ``` 12 | 13 | ## With Docker 14 | 15 | ```bash 16 | # Grant access to the X11 server: 17 | sudo xhost +local:docker 18 | 19 | export SERVER_ARGS="--env LIBERO" 20 | docker compose -f examples/libero/compose.yml up --build 21 | ``` 22 | 23 | ## Without Docker 24 | 25 | Terminal window 1: 26 | 27 | ```bash 28 | # Create virtual environment 29 | uv venv --python 3.8 examples/libero/.venv 30 | source examples/libero/.venv/bin/activate 31 | uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match 32 | uv pip install -e packages/openpi-client 33 | uv pip install -e third_party/libero 34 | export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero 35 | 36 | # Run the simulation 37 | python examples/libero/main.py 38 | ``` 39 | 40 | Terminal window 2: 41 | 42 | ```bash 43 | # Run the server 44 | uv run scripts/serve_policy.py --env LIBERO 45 | ``` 46 | 47 | ## Results 48 | 49 | If you follow the training instructions and hyperparameters in the `pi0_libero` and `pi0_fast_libero` configs, you should get results similar to the following: 50 | 51 | | Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average | 52 | |-------|---------------|---------------|-------------|-----------|---------| 53 | | π0-FAST @ 30k (finetuned) | 96.4 | 96.8 | 88.6 | 60.2 | 85.5 | 54 | | π0 @ 30k (finetuned) | 96.8 | 98.8 | 95.8 | 85.2 | 94.15 | 55 | 56 | Note that the hyperparameters for these runs are not tuned and $\pi_0$-FAST does not use a FAST tokenizer optimized for Libero. The results could likely be improved with more tuning; we mainly provide these numbers as an example of how to use openpi to fine-tune $\pi_0$ models on a new dataset. 57 | -------------------------------------------------------------------------------- /examples/libero/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/libero/compose.yml up --build 3 | services: 4 | runtime: 5 | image: libero 6 | depends_on: 7 | - openpi_server 8 | build: 9 | context: ../..
10 | dockerfile: examples/libero/Dockerfile 11 | init: true 12 | tty: true 13 | network_mode: host 14 | privileged: true 15 | volumes: 16 | - $PWD:/app 17 | - ../../data:/data 18 | - /tmp/.X11-unix:/tmp/.X11-unix:ro 19 | environment: 20 | - DISPLAY=$DISPLAY 21 | deploy: 22 | resources: 23 | reservations: 24 | devices: 25 | - driver: nvidia 26 | count: 1 27 | capabilities: [gpu] 28 | 29 | openpi_server: 30 | image: openpi_server 31 | build: 32 | context: ../.. 33 | dockerfile: scripts/docker/serve_policy.Dockerfile 34 | init: true 35 | tty: true 36 | network_mode: host 37 | volumes: 38 | - $PWD:/app 39 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 40 | environment: 41 | - SERVER_ARGS 42 | - OPENPI_DATA_HOME=/openpi_assets 43 | - IS_DOCKER=true 44 | 45 | # Comment out this block if not running on a machine with GPUs. 46 | deploy: 47 | resources: 48 | reservations: 49 | devices: 50 | - driver: nvidia 51 | count: 1 52 | capabilities: [gpu] 53 | -------------------------------------------------------------------------------- /examples/libero/convert_libero_data_to_lerobot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal example script for converting a dataset to LeRobot format. 3 | 4 | We use the Libero dataset (stored in RLDS) for this example, but it can be easily 5 | modified for any other data you have saved in a custom format. 6 | 7 | Usage: 8 | uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data 9 | 10 | If you want to push your dataset to the Hugging Face Hub, you can use the following command: 11 | uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub 12 | 13 | Note: to run the script, you need to install tensorflow_datasets: 14 | `uv pip install tensorflow tensorflow_datasets` 15 | 16 | You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds 17 | The resulting dataset will get saved to the $LEROBOT_HOME directory. 18 | Running this conversion script will take approximately 30 minutes. 
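Each episode is written frame-by-frame with `dataset.add_frame(...)` and finalized with `dataset.save_episode(...)` (see the loop below); adapting this script to your own format mostly means changing how images, state, and actions are extracted inside that inner loop.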
19 | """ 20 | 21 | import shutil 22 | 23 | from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME 24 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 25 | import tensorflow_datasets as tfds 26 | import tyro 27 | 28 | REPO_NAME = "xxx/libero" # Name of the output dataset, also used for the Hugging Face Hub 29 | RAW_DATASET_NAMES = [ 30 | "libero_10_no_noops", 31 | "libero_goal_no_noops", 32 | "libero_object_no_noops", 33 | "libero_spatial_no_noops", 34 | ] # For simplicity we will combine multiple Libero datasets into one training dataset 35 | 36 | 37 | def main(data_dir: str, *, push_to_hub: bool = False): 38 | # Clean up any existing dataset in the output directory 39 | output_path = LEROBOT_HOME / REPO_NAME 40 | if output_path.exists(): 41 | shutil.rmtree(output_path) 42 | 43 | # Create LeRobot dataset, define features to store 44 | # OpenPi assumes that proprio is stored in `state` and actions in `actions` 45 | # LeRobot assumes that the dtype of image data is `image` 46 | dataset = LeRobotDataset.create( 47 | repo_id=REPO_NAME, 48 | robot_type="panda", 49 | fps=10, 50 | features={ 51 | "image": { 52 | "dtype": "image", 53 | "shape": (256, 256, 3), 54 | "names": ["height", "width", "channel"], 55 | }, 56 | "wrist_image": { 57 | "dtype": "image", 58 | "shape": (256, 256, 3), 59 | "names": ["height", "width", "channel"], 60 | }, 61 | "state": { 62 | "dtype": "float32", 63 | "shape": (8,), 64 | "names": ["state"], 65 | }, 66 | "actions": { 67 | "dtype": "float32", 68 | "shape": (7,), 69 | "names": ["actions"], 70 | }, 71 | }, 72 | image_writer_threads=10, 73 | image_writer_processes=5, 74 | ) 75 | 76 | # Loop over raw Libero datasets and write episodes to the LeRobot dataset 77 | # You can modify this for your own data format 78 | for raw_dataset_name in RAW_DATASET_NAMES: 79 | raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train") 80 | for episode in raw_dataset: 81 | for step in episode["steps"].as_numpy_iterator(): 82 | dataset.add_frame( 83 | { 84 | "image": step["observation"]["image"], 85 | "wrist_image": step["observation"]["wrist_image"], 86 | "state": step["observation"]["state"], 87 | "actions": step["action"], 88 | } 89 | ) 90 | dataset.save_episode(task=step["language_instruction"].decode()) 91 | 92 | # Consolidate the dataset, skip computing stats since we will do that later 93 | dataset.consolidate(run_compute_stats=False) 94 | 95 | # Optionally push to the Hugging Face Hub 96 | if push_to_hub: 97 | dataset.push_to_hub( 98 | tags=["libero", "panda", "rlds"], 99 | private=False, 100 | push_videos=True, 101 | license="apache-2.0", 102 | ) 103 | 104 | 105 | if __name__ == "__main__": 106 | tyro.cli(main) 107 | -------------------------------------------------------------------------------- /examples/libero/requirements.in: -------------------------------------------------------------------------------- 1 | imageio[ffmpeg] 2 | numpy==1.22.4 3 | tqdm 4 | tyro 5 | PyYaml 6 | opencv-python==4.6.0.66 7 | torch==1.11.0+cu113 8 | torchvision==0.12.0+cu113 9 | torchaudio==0.11.0+cu113 10 | robosuite==1.4.1 11 | matplotlib==3.5.3 12 | -------------------------------------------------------------------------------- /examples/libero/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match 3 |
absl-py==2.1.0 4 | # via mujoco 5 | certifi==2024.12.14 6 | # via requests 7 | charset-normalizer==3.4.0 8 | # via requests 9 | cycler==0.12.1 10 | # via matplotlib 11 | docstring-parser==0.16 12 | # via tyro 13 | etils==1.3.0 14 | # via mujoco 15 | eval-type-backport==0.2.0 16 | # via tyro 17 | evdev==1.7.1 18 | # via pynput 19 | fonttools==4.55.3 20 | # via matplotlib 21 | glfw==1.12.0 22 | # via mujoco 23 | idna==3.10 24 | # via requests 25 | imageio==2.35.1 26 | # via -r examples/libero/requirements.in 27 | imageio-ffmpeg==0.5.1 28 | # via imageio 29 | importlib-metadata==8.5.0 30 | # via typeguard 31 | importlib-resources==6.4.5 32 | # via etils 33 | kiwisolver==1.4.7 34 | # via matplotlib 35 | llvmlite==0.36.0 36 | # via numba 37 | markdown-it-py==3.0.0 38 | # via rich 39 | matplotlib==3.5.3 40 | # via -r examples/libero/requirements.in 41 | mdurl==0.1.2 42 | # via markdown-it-py 43 | mujoco==3.2.3 44 | # via robosuite 45 | numba==0.53.1 46 | # via robosuite 47 | numpy==1.22.4 48 | # via 49 | # -r examples/libero/requirements.in 50 | # imageio 51 | # matplotlib 52 | # mujoco 53 | # numba 54 | # opencv-python 55 | # robosuite 56 | # scipy 57 | # torchvision 58 | opencv-python==4.6.0.66 59 | # via 60 | # -r examples/libero/requirements.in 61 | # robosuite 62 | packaging==24.2 63 | # via matplotlib 64 | pillow==10.4.0 65 | # via 66 | # imageio 67 | # matplotlib 68 | # robosuite 69 | # torchvision 70 | psutil==6.1.0 71 | # via imageio 72 | pygments==2.18.0 73 | # via rich 74 | pynput==1.7.7 75 | # via robosuite 76 | pyopengl==3.1.7 77 | # via mujoco 78 | pyparsing==3.1.4 79 | # via matplotlib 80 | python-dateutil==2.9.0.post0 81 | # via matplotlib 82 | python-xlib==0.33 83 | # via pynput 84 | pyyaml==6.0.2 85 | # via -r examples/libero/requirements.in 86 | requests==2.32.3 87 | # via torchvision 88 | rich==13.9.4 89 | # via tyro 90 | robosuite==1.4.1 91 | # via -r examples/libero/requirements.in 92 | scipy==1.10.1 93 | # via robosuite 94 | setuptools==75.3.0 95 | # via 96 | # imageio-ffmpeg 97 | # numba 98 | shtab==1.7.1 99 | # via tyro 100 | six==1.17.0 101 | # via 102 | # pynput 103 | # python-dateutil 104 | # python-xlib 105 | termcolor==2.4.0 106 | # via robosuite 107 | torch==1.11.0+cu113 108 | # via 109 | # -r examples/libero/requirements.in 110 | # torchaudio 111 | # torchvision 112 | torchaudio==0.11.0+cu113 113 | # via -r examples/libero/requirements.in 114 | torchvision==0.12.0+cu113 115 | # via -r examples/libero/requirements.in 116 | tqdm==4.67.1 117 | # via -r examples/libero/requirements.in 118 | typeguard==4.4.0 119 | # via tyro 120 | typing-extensions==4.12.2 121 | # via 122 | # etils 123 | # rich 124 | # torch 125 | # torchvision 126 | # typeguard 127 | # tyro 128 | tyro==0.9.2 129 | # via -r examples/libero/requirements.in 130 | urllib3==2.2.3 131 | # via requests 132 | zipp==3.20.2 133 | # via 134 | # etils 135 | # importlib-metadata 136 | # importlib-resources 137 | -------------------------------------------------------------------------------- /examples/policy_records.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pathlib\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "record_path = pathlib.Path(\"../policy_records\")\n", 14 | "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n", 15 | "\n", 16 | "records = []\n", 17 | "for i in range(num_steps):\n", 18 | " 
record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n", 19 | " records.append(record)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "print(\"length of records\", len(records))\n", 29 | "print(\"keys in records\", records[0].keys())\n", 30 | "\n", 31 | "for k in records[0]:\n", 32 | " print(f\"{k} shape: {records[0][k].shape}\")" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from PIL import Image\n", 42 | "\n", 43 | "\n", 44 | "def get_image(step: int, idx: int = 0):\n", 45 | " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n", 46 | " return img[idx].transpose(1, 2, 0)\n", 47 | "\n", 48 | "\n", 49 | "def show_image(step: int, idx_lst: list[int]):\n", 50 | " imgs = [get_image(step, idx) for idx in idx_lst]\n", 51 | " return Image.fromarray(np.hstack(imgs))\n", 52 | "\n", 53 | "\n", 54 | "for i in range(2):\n", 55 | " display(show_image(i, [0]))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 14, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "import pandas as pd\n", 65 | "\n", 66 | "\n", 67 | "def get_axis(name, axis):\n", 68 | " return np.array([record[name][axis] for record in records])\n", 69 | "\n", 70 | "\n", 71 | "# qpos is [..., 14] of type float:\n", 72 | "# 0-5: left arm joint angles\n", 73 | "# 6: left arm gripper\n", 74 | "# 7-12: right arm joint angles\n", 75 | "# 13: right arm gripper\n", 76 | "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n", 77 | "\n", 78 | "\n", 79 | "def make_data():\n", 80 | " cur_dim = 0\n", 81 | " in_data = {}\n", 82 | " out_data = {}\n", 83 | " for name, dim_size in names:\n", 84 | " for i in range(dim_size):\n", 85 | " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n", 86 | " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n", 87 | " cur_dim += 1\n", 88 | " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n", 89 | "\n", 90 | "\n", 91 | "in_data, out_data = make_data()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "for name in in_data.columns:\n", 101 | " data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n", 102 | " data.plot()" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": ".venv", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.11.9" 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 2 134 | } 135 | -------------------------------------------------------------------------------- /examples/simple_client/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the simple client. 2 | 3 | # Build the container: 4 | # docker build . 
-t simple_client -f examples/simple_client/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v .:/app simple_client /bin/bash 8 | 9 | FROM python:3.7-slim 10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 11 | 12 | WORKDIR /app 13 | 14 | # Copy from the cache instead of linking since it's a mounted volume 15 | ENV UV_LINK_MODE=copy 16 | 17 | # Write the virtual environment outside of the project directory so it doesn't 18 | # leak out of the container when we mount the application code. 19 | ENV UV_PROJECT_ENVIRONMENT=/.venv 20 | 21 | # Copy the requirements files so we can install dependencies. 22 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. 23 | # This strategy is best for development-style usage. 24 | COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt 25 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 26 | 27 | # Install python dependencies. 28 | RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT 29 | RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml 30 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src 31 | 32 | CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS" 33 | -------------------------------------------------------------------------------- /examples/simple_client/README.md: -------------------------------------------------------------------------------- 1 | # Simple Client 2 | 3 | A minimal client that sends observations to the server and prints the inference rate. 4 | 5 | You can specify which runtime environment to use with the `--env` flag. You can see the available options by running: 6 | 7 | ```bash 8 | uv run examples/simple_client/main.py --help 9 | ``` 10 | 11 | ## With Docker 12 | 13 | ```bash 14 | export SERVER_ARGS="--env ALOHA_SIM" 15 | docker compose -f examples/simple_client/compose.yml up --build 16 | ``` 17 | 18 | ## Without Docker 19 | 20 | Terminal window 1: 21 | 22 | ```bash 23 | uv run examples/simple_client/main.py --env DROID 24 | ``` 25 | 26 | Terminal window 2: 27 | 28 | ```bash 29 | uv run scripts/serve_policy.py --env DROID 30 | ``` 31 | -------------------------------------------------------------------------------- /examples/simple_client/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/simple_client/compose.yml up --build 3 | services: 4 | runtime: 5 | image: simple_client 6 | depends_on: 7 | - openpi_server 8 | build: 9 | context: ../.. 10 | dockerfile: examples/simple_client/Dockerfile 11 | init: true 12 | tty: true 13 | network_mode: host 14 | volumes: 15 | - $PWD:/app 16 | environment: 17 | - SERVER_ARGS 18 | 19 | openpi_server: 20 | image: openpi_server 21 | build: 22 | context: ../.. 23 | dockerfile: scripts/docker/serve_policy.Dockerfile 24 | init: true 25 | tty: true 26 | network_mode: host 27 | volumes: 28 | - $PWD:/app 29 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 30 | environment: 31 | - SERVER_ARGS 32 | - OPENPI_DATA_HOME=/openpi_assets 33 | - IS_DOCKER=true 34 | 35 | # Comment out this block if not running on a machine with GPUs.
36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /examples/simple_client/main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import enum 3 | import logging 4 | import time 5 | 6 | import numpy as np 7 | from openpi_client import websocket_client_policy as _websocket_client_policy 8 | import tyro 9 | 10 | 11 | class EnvMode(enum.Enum): 12 | """Supported environments.""" 13 | 14 | ALOHA = "aloha" 15 | ALOHA_SIM = "aloha_sim" 16 | DROID = "droid" 17 | LIBERO = "libero" 18 | 19 | 20 | @dataclasses.dataclass 21 | class Args: 22 | host: str = "0.0.0.0" 23 | port: int = 8000 24 | 25 | env: EnvMode = EnvMode.ALOHA_SIM 26 | num_steps: int = 10 27 | 28 | 29 | def main(args: Args) -> None: 30 | obs_fn = { 31 | EnvMode.ALOHA: _random_observation_aloha, 32 | EnvMode.ALOHA_SIM: _random_observation_aloha, 33 | EnvMode.DROID: _random_observation_droid, 34 | EnvMode.LIBERO: _random_observation_libero, 35 | }[args.env] 36 | 37 | policy = _websocket_client_policy.WebsocketClientPolicy( 38 | host=args.host, 39 | port=args.port, 40 | ) 41 | logging.info(f"Server metadata: {policy.get_server_metadata()}") 42 | 43 | # Send 1 observation to make sure the model is loaded. 44 | policy.infer(obs_fn()) 45 | 46 | start = time.time() 47 | for _ in range(args.num_steps): 48 | policy.infer(obs_fn()) 49 | end = time.time() 50 | 51 | print(f"Total time taken: {end - start:.2f} s") 52 | print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms") 53 | 54 | 55 | def _random_observation_aloha() -> dict: 56 | return { 57 | "state": np.ones((14,)), 58 | "images": { 59 | "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), 60 | "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), 61 | "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), 62 | "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8), 63 | }, 64 | "prompt": "do something", 65 | } 66 | 67 | 68 | def _random_observation_droid() -> dict: 69 | return { 70 | "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 71 | "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 72 | "observation/joint_position": np.random.rand(7), 73 | "observation/gripper_position": np.random.rand(1), 74 | "prompt": "do something", 75 | } 76 | 77 | 78 | def _random_observation_libero() -> dict: 79 | return { 80 | "observation/state": np.random.rand(8), 81 | "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 82 | "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 83 | "prompt": "do something", 84 | } 85 | 86 | 87 | if __name__ == "__main__": 88 | logging.basicConfig(level=logging.INFO) 89 | main(tyro.cli(Args)) 90 | -------------------------------------------------------------------------------- /examples/simple_client/requirements.in: -------------------------------------------------------------------------------- 1 | numpy 2 | tyro -------------------------------------------------------------------------------- /examples/simple_client/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile 
examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7 3 | backports-cached-property==1.0.2 4 | # via tyro 5 | docstring-parser==0.16 6 | # via tyro 7 | eval-type-backport==0.1.3 8 | # via tyro 9 | markdown-it-py==2.2.0 10 | # via rich 11 | mdurl==0.1.2 12 | # via markdown-it-py 13 | numpy==1.21.6 14 | # via -r examples/simple_client/requirements.in 15 | pygments==2.17.2 16 | # via rich 17 | rich==13.8.1 18 | # via tyro 19 | shtab==1.7.1 20 | # via tyro 21 | typing-extensions==4.7.1 22 | # via 23 | # markdown-it-py 24 | # rich 25 | # tyro 26 | tyro==0.9.1 27 | # via -r examples/simple_client/requirements.in 28 | -------------------------------------------------------------------------------- /examples/umi/README.md: -------------------------------------------------------------------------------- 1 | # OneTwoVLA Policy Server 2 | 3 | Here are the instructions for running the policy server for OneTwoVLA. We provide the code to launch the UMI client in this [repository](https://github.com/Fanqi-Lin/OneTwoVLA-UMI-Client). 4 | 5 | --- 6 | 7 | ## Set Up the Policy Server 8 | 9 | First, install the required dependencies: 10 | 11 | ```bash 12 | uv pip install pynput 13 | ``` 14 | > *Note: You may need `sudo` permissions.* 15 | 16 | Next, start the policy server on your desired port (e.g., `8000`): 17 | 18 | ```bash 19 | uv run scripts/serve_policy.py --port 8000 \ 20 | policy:checkpoint \ 21 | --policy.config=onetwovla_visual_grounding \ 22 | --policy.dir=/path/to/your/checkpoint 23 | ``` 24 | 25 | **Supported policy configurations:** 26 | - `onetwovla_visual_cocktail` 27 | - `onetwovla_visual_grounding` 28 | - `pi0_visual_cocktail` 29 | - `pi0_visual_grounding` 30 | -------------------------------------------------------------------------------- /figures/fisheye-aug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/figures/fisheye-aug.png -------------------------------------------------------------------------------- /packages/openpi-client/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "openpi-client" 3 | version = "0.1.0" 4 | requires-python = ">=3.7" 5 | dependencies = [ 6 | "dm-tree>=0.1.8", 7 | "msgpack>=1.0.5", 8 | "numpy>=1.21.6", 9 | "pillow>=9.0.0", 10 | "tree>=0.2.4", 11 | "websockets>=11.0", 12 | ] 13 | 14 | [build-system] 15 | requires = ["hatchling"] 16 | build-backend = "hatchling.build" 17 | 18 | [tool.uv] 19 | dev-dependencies = [ 20 | "pytest>=8.3.4", 21 | ] 22 | 23 | [tool.ruff] 24 | line-length = 120 25 | target-version = "py37" -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/action_chunk_broker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import tree 5 | from typing_extensions import override 6 | 7 | from openpi_client import base_policy as _base_policy 8 | 9 | 10 | class ActionChunkBroker(_base_policy.BasePolicy): 11 | """Wraps a policy to return action chunks one-at-a-time. 
12 | 13 | Assumes that the first dimension of all action fields is the chunk size. 14 | 15 | A new inference call to the inner policy is only made when the current 16 | chunk of actions is exhausted. 17 | """ 18 | 19 | def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int): 20 | self._policy = policy 21 | 22 | self._action_horizon = action_horizon 23 | self._cur_step: int = 0 24 | 25 | self._last_results: Dict[str, np.ndarray] | None = None 26 | 27 | @override 28 | def infer(self, obs: Dict) -> Dict: # noqa: UP006 29 | if self._last_results is None: 30 | self._last_results = self._policy.infer(obs) 31 | self._cur_step = 0 32 | 33 | results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results) 34 | self._cur_step += 1 35 | 36 | if self._cur_step >= self._action_horizon: 37 | self._last_results = None 38 | 39 | return results 40 | 41 | @override 42 | def reset(self) -> None: 43 | self._policy.reset() 44 | self._last_results = None 45 | self._cur_step = 0 46 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/base_policy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict 3 | 4 | 5 | class BasePolicy(abc.ABC): 6 | @abc.abstractmethod 7 | def infer(self, obs: Dict) -> Dict: 8 | """Infer actions from observations.""" 9 | 10 | def reset(self) -> None: 11 | """Reset the policy to its initial state.""" 12 | pass 13 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/image_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def convert_to_uint8(img: np.ndarray) -> np.ndarray: 6 | """Converts an image to uint8 if it is a float image. 7 | 8 | This is important for reducing the size of the image when sending it over the network. 9 | """ 10 | if np.issubdtype(img.dtype, np.floating): 11 | img = (255 * img).astype(np.uint8) 12 | return img 13 | 14 | 15 | def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray: 16 | """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height and width without distortion by padding with zeros. 17 | 18 | Args: 19 | images: A batch of images in [..., height, width, channel] format. 20 | height: The target height of the image. 21 | width: The target width of the image. 22 | method: The interpolation method to use. Default is bilinear. 23 | 24 | Returns: 25 | The resized images in [..., height, width, channel]. 26 | """ 27 | # If the images are already the correct size, return them as is. 28 | if images.shape[-3:-1] == (height, width): 29 | return images 30 | 31 | original_shape = images.shape 32 | 33 | images = images.reshape(-1, *original_shape[-3:]) 34 | resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images]) 35 | return resized.reshape(*original_shape[:-3], *resized.shape[-3:]) 36 | 37 | 38 | def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image: 39 | """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and 40 | width without distortion by padding with zeros. 41 | 42 | Note that, unlike the jax version in `openpi.shared.image_tools`, PIL sizes use (width, height) ordering rather than [..., height, width, channel].
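For example (illustrative numbers): resizing a 320x256 (width x height) image to width=80, height=60 gives ratio = max(320/80, 256/60) ~= 4.27, an intermediate resize to 75x60, and zero padding of 2 columns on the left and 3 on the right.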
43 | """ 44 | cur_width, cur_height = image.size 45 | if cur_width == width and cur_height == height: 46 | return image # No need to resize if the image is already the correct size. 47 | 48 | ratio = max(cur_width / width, cur_height / height) 49 | resized_height = int(cur_height / ratio) 50 | resized_width = int(cur_width / ratio) 51 | resized_image = image.resize((resized_width, resized_height), resample=method) 52 | 53 | zero_image = Image.new(resized_image.mode, (width, height), 0) 54 | pad_height = max(0, int((height - resized_height) / 2)) 55 | pad_width = max(0, int((width - resized_width) / 2)) 56 | zero_image.paste(resized_image, (pad_width, pad_height)) 57 | assert zero_image.size == (width, height) 58 | return zero_image 59 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/image_tools_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import openpi_client.image_tools as image_tools 4 | 5 | 6 | def test_resize_with_pad_shapes(): 7 | # Test case 1: Resize image with larger dimensions 8 | images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels) 9 | height = 20 10 | width = 20 11 | resized_images = image_tools.resize_with_pad(images, height, width) 12 | assert resized_images.shape == (2, height, width, 3) 13 | assert np.all(resized_images == 0) 14 | 15 | # Test case 2: Resize image with smaller dimensions 16 | images = np.zeros((3, 30, 30, 3), dtype=np.uint8) 17 | height = 15 18 | width = 15 19 | resized_images = image_tools.resize_with_pad(images, height, width) 20 | assert resized_images.shape == (3, height, width, 3) 21 | assert np.all(resized_images == 0) 22 | 23 | # Test case 3: Resize image with the same dimensions 24 | images = np.zeros((1, 50, 50, 3), dtype=np.uint8) 25 | height = 50 26 | width = 50 27 | resized_images = image_tools.resize_with_pad(images, height, width) 28 | assert resized_images.shape == (1, height, width, 3) 29 | assert np.all(resized_images == 0) 30 | 31 | # Test case 4: Resize image with odd-numbered padding 32 | images = np.zeros((1, 256, 320, 3), dtype=np.uint8) 33 | height = 60 34 | width = 80 35 | resized_images = image_tools.resize_with_pad(images, height, width) 36 | assert resized_images.shape == (1, height, width, 3) 37 | assert np.all(resized_images == 0) 38 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/msgpack_numpy.py: -------------------------------------------------------------------------------- 1 | """Adds NumPy array support to msgpack. 2 | 3 | msgpack is good for (de)serializing data over a network for multiple reasons: 4 | - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution) 5 | - msgpack is widely used and has good cross-language support 6 | - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed 7 | languages like Python and JavaScript 8 | - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster 9 | than pickle for serializing large arrays using the below strategy 10 | 11 | The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is 12 | that it falls back to pickle for object arrays.
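Example round-trip (a minimal sketch using the helpers defined below):

    import numpy as np
    data = {"obs": np.zeros((2, 3), dtype=np.float32)}
    restored = unpackb(packb(data))
    assert (restored["obs"] == data["obs"]).all()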
13 | """ 14 | 15 | import functools 16 | 17 | import msgpack 18 | import numpy as np 19 | 20 | 21 | def pack_array(obj): 22 | if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"): 23 | raise ValueError(f"Unsupported dtype: {obj.dtype}") 24 | 25 | if isinstance(obj, np.ndarray): 26 | return { 27 | b"__ndarray__": True, 28 | b"data": obj.tobytes(), 29 | b"dtype": obj.dtype.str, 30 | b"shape": obj.shape, 31 | } 32 | 33 | if isinstance(obj, np.generic): 34 | return { 35 | b"__npgeneric__": True, 36 | b"data": obj.item(), 37 | b"dtype": obj.dtype.str, 38 | } 39 | 40 | return obj 41 | 42 | 43 | def unpack_array(obj): 44 | if b"__ndarray__" in obj: 45 | return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) 46 | 47 | if b"__npgeneric__" in obj: 48 | return np.dtype(obj[b"dtype"]).type(obj[b"data"]) 49 | 50 | return obj 51 | 52 | 53 | Packer = functools.partial(msgpack.Packer, default=pack_array) 54 | packb = functools.partial(msgpack.packb, default=pack_array) 55 | 56 | Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) 57 | unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) 58 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/msgpack_numpy_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tree 4 | 5 | from openpi_client import msgpack_numpy 6 | 7 | 8 | def _check(expected, actual): 9 | if isinstance(expected, np.ndarray): 10 | assert expected.shape == actual.shape 11 | assert expected.dtype == actual.dtype 12 | assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") 13 | else: 14 | assert expected == actual 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "data", 19 | [ 20 | 1, # int 21 | 1.0, # float 22 | "hello", # string 23 | np.bool_(True), # boolean scalar 24 | np.array([1, 2, 3])[0], # int scalar 25 | np.str_("asdf"), # string scalar 26 | [1, 2, 3], # list 27 | {"key": "value"}, # dict 28 | {"key": [1, 2, 3]}, # nested dict 29 | np.array(1.0), # 0D array 30 | np.array([1, 2, 3], dtype=np.int32), # 1D integer array 31 | np.array(["asdf", "qwer"]), # string array 32 | np.array([True, False]), # boolean array 33 | np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array 34 | np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array 35 | np.array([np.nan, np.inf, -np.inf]), # special float values 36 | {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays 37 | [np.array([1, 2]), np.array([3, 4])], # list of arrays 38 | np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros 39 | np.ones((2, 3), dtype=np.float64), # 2D ones with double precision 40 | ], 41 | ) 42 | def test_pack_unpack(data): 43 | packed = msgpack_numpy.packb(data) 44 | unpacked = msgpack_numpy.unpackb(packed) 45 | tree.map_structure(_check, data, unpacked) 46 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/runtime/agent.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Agent(abc.ABC): 5 | """An Agent is the thing with agency, i.e. the entity that makes decisions. 6 | 7 | Agents receive observations about the state of the world, and return actions 8 | to take in response. 
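A concrete example is `runtime/agents/policy_agent.py` below, which simply defers to a Policy for action selection.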
9 | """ 10 | 11 | @abc.abstractmethod 12 | def get_action(self, observation: dict) -> dict: 13 | """Query the agent for the next action.""" 14 | 15 | @abc.abstractmethod 16 | def reset(self) -> None: 17 | """Reset the agent to its initial state.""" 18 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import override 2 | 3 | from openpi_client import base_policy as _base_policy 4 | from openpi_client.runtime import agent as _agent 5 | 6 | 7 | class PolicyAgent(_agent.Agent): 8 | """An agent that uses a policy to determine actions.""" 9 | 10 | def __init__(self, policy: _base_policy.BasePolicy) -> None: 11 | self._policy = policy 12 | 13 | @override 14 | def get_action(self, observation: dict) -> dict: 15 | return self._policy.infer(observation) 16 | 17 | def reset(self) -> None: 18 | self._policy.reset() 19 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/runtime/environment.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Environment(abc.ABC): 5 | """An Environment represents the robot and the environment it inhabits. 6 | 7 | The primary contract of environments is that they can be queried for observations 8 | about their state, and have actions applied to them to change that state. 9 | """ 10 | 11 | @abc.abstractmethod 12 | def reset(self) -> None: 13 | """Reset the environment to its initial state. 14 | 15 | This will be called once before starting each episode. 16 | """ 17 | 18 | @abc.abstractmethod 19 | def is_episode_complete(self) -> bool: 20 | """Allow the environment to signal that the episode is complete. 21 | 22 | This will be called after each step. It should return `True` if the episode is 23 | complete (either successfully or unsuccessfully), and `False` otherwise. 
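Environments that never signal completion themselves can rely on the runtime's `max_episode_steps` cutoff instead (see `runtime.py`).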
24 | """ 25 | 26 | @abc.abstractmethod 27 | def get_observation(self) -> dict: 28 | """Query the environment for the current state.""" 29 | 30 | @abc.abstractmethod 31 | def apply_action(self, action: dict) -> None: 32 | """Take an action in the environment.""" 33 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/runtime/runtime.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import time 4 | 5 | from openpi_client.runtime import agent as _agent 6 | from openpi_client.runtime import environment as _environment 7 | from openpi_client.runtime import subscriber as _subscriber 8 | 9 | 10 | class Runtime: 11 | """The core module orchestrating interactions between key components of the system.""" 12 | 13 | def __init__( 14 | self, 15 | environment: _environment.Environment, 16 | agent: _agent.Agent, 17 | subscribers: list[_subscriber.Subscriber], 18 | max_hz: float = 0, 19 | num_episodes: int = 1, 20 | max_episode_steps: int = 0, 21 | ) -> None: 22 | self._environment = environment 23 | self._agent = agent 24 | self._subscribers = subscribers 25 | self._max_hz = max_hz 26 | self._num_episodes = num_episodes 27 | self._max_episode_steps = max_episode_steps 28 | 29 | self._in_episode = False 30 | self._episode_steps = 0 31 | 32 | def run(self) -> None: 33 | """Runs the configured number of episodes; each episode ends when `mark_episode_complete()` is called or `max_episode_steps` is reached.""" 34 | for _ in range(self._num_episodes): 35 | self._run_episode() 36 | 37 | # Final reset; this is important for real environments to move the robot to its home position. 38 | self._environment.reset() 39 | 40 | def run_in_new_thread(self) -> threading.Thread: 41 | """Runs the runtime loop in a new thread.""" 42 | thread = threading.Thread(target=self.run) 43 | thread.start() 44 | return thread 45 | 46 | def mark_episode_complete(self) -> None: 47 | """Marks the end of an episode.""" 48 | self._in_episode = False 49 | 50 | def _run_episode(self) -> None: 51 | """Runs a single episode.""" 52 | logging.info("Starting episode...") 53 | self._environment.reset() 54 | self._agent.reset() 55 | for subscriber in self._subscribers: 56 | subscriber.on_episode_start() 57 | 58 | self._in_episode = True 59 | self._episode_steps = 0 60 | step_time = 1 / self._max_hz if self._max_hz > 0 else 0 61 | last_step_time = time.time() 62 | 63 | while self._in_episode: 64 | self._step() 65 | self._episode_steps += 1 66 | 67 | # Sleep to maintain the desired frame rate 68 | now = time.time() 69 | dt = now - last_step_time 70 | if dt < step_time: 71 | time.sleep(step_time - dt) 72 | last_step_time = time.time() 73 | else: 74 | last_step_time = now 75 | 76 | logging.info("Episode completed.") 77 | for subscriber in self._subscribers: 78 | subscriber.on_episode_end() 79 | 80 | def _step(self) -> None: 81 | """A single step of the runtime loop.""" 82 | observation = self._environment.get_observation() 83 | action = self._agent.get_action(observation) 84 | self._environment.apply_action(action) 85 | 86 | for subscriber in self._subscribers: 87 | subscriber.on_step(observation, action) 88 | 89 | if self._environment.is_episode_complete() or ( 90 | self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps 91 | ): 92 | self.mark_episode_complete() 93 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/runtime/subscriber.py:
-------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Subscriber(abc.ABC): 5 | """Subscribes to events in the runtime. 6 | 7 | Subscribers can be used to save data, visualize, etc. 8 | """ 9 | 10 | @abc.abstractmethod 11 | def on_episode_start(self) -> None: 12 | """Called when an episode starts.""" 13 | 14 | @abc.abstractmethod 15 | def on_step(self, observation: dict, action: dict) -> None: 16 | """Append a step to the episode.""" 17 | 18 | @abc.abstractmethod 19 | def on_episode_end(self) -> None: 20 | """Called when an episode ends.""" 21 | -------------------------------------------------------------------------------- /packages/openpi-client/src/openpi_client/websocket_client_policy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Dict, Tuple 4 | 5 | import websockets.sync.client 6 | from typing_extensions import override 7 | 8 | from openpi_client import base_policy as _base_policy 9 | from openpi_client import msgpack_numpy 10 | 11 | 12 | class WebsocketClientPolicy(_base_policy.BasePolicy): 13 | """Implements the Policy interface by communicating with a server over websocket. 14 | 15 | See WebsocketPolicyServer for a corresponding server implementation. 16 | """ 17 | 18 | def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None: 19 | self._uri = f"ws://{host}:{port}" 20 | self._packer = msgpack_numpy.Packer() 21 | self._ws, self._server_metadata = self._wait_for_server() 22 | 23 | def get_server_metadata(self) -> Dict: 24 | return self._server_metadata 25 | 26 | def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]: 27 | logging.info(f"Waiting for server at {self._uri}...") 28 | while True: 29 | try: 30 | conn = websockets.sync.client.connect(self._uri, compression=None, max_size=None) 31 | metadata = msgpack_numpy.unpackb(conn.recv()) 32 | return conn, metadata 33 | except ConnectionRefusedError: 34 | logging.info("Still waiting for server...") 35 | time.sleep(5) 36 | 37 | @override 38 | def infer(self, obs: Dict) -> Dict: # noqa: UP006 39 | data = self._packer.pack(obs) 40 | self._ws.send(data) 41 | response = self._ws.recv() 42 | if isinstance(response, str): 43 | # we're expecting bytes; if the server sends a string, it's an error. 
44 | raise RuntimeError(f"Error in inference server:\n{response}") 45 | return msgpack_numpy.unpackb(response) 46 | 47 | @override 48 | def reset(self) -> None: 49 | pass 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "openpi" 3 | version = "0.1.0" 4 | description = "Physical Intelligence open source repo" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | license = { file = "LICENSE" } 8 | dependencies = [ 9 | "augmax>=0.3.4", 10 | "dm-tree>=0.1.8", 11 | "einops>=0.8.0", 12 | "equinox>=0.11.8", 13 | "flatbuffers>=24.3.25", 14 | "flax==0.10.2", 15 | "fsspec[gcs]>=2024.6.0", 16 | "gym-aloha>=0.1.1", 17 | "imageio>=2.36.1", 18 | "jax[cuda12]==0.5.0", 19 | "jaxtyping==0.2.36", 20 | "lerobot", 21 | "ml_collections==1.0.0", 22 | "numpy>=1.26.4", 23 | "numpydantic>=1.6.6", 24 | "opencv-python>=4.10.0.84", 25 | "openpi-client", 26 | "orbax-checkpoint==0.11.1", 27 | "pillow>=11.0.0", 28 | "s3fs>=2024.9.0", 29 | "sentencepiece>=0.2.0", 30 | "torch>=2.5.1", 31 | "tqdm-loggable>=0.2", 32 | "typing-extensions>=4.12.2", 33 | "tyro>=0.9.5", 34 | "wandb==0.18.0", 35 | "boto3>=1.35.7", 36 | "types-boto3[boto3,s3]>=1.35.7", 37 | "filelock>=3.16.1", 38 | "beartype>=0.19.0", 39 | "treescope>=0.1.7", 40 | "transformers==4.48.1", 41 | "imagecodecs>=2024.12.30", 42 | ] 43 | 44 | 45 | [project.urls] 46 | Repository = "https://github.com/Richard-coder-Nai/onetwovla.git" 47 | 48 | [dependency-groups] 49 | dev = [ 50 | "pytest>=8.3.4", 51 | "ruff>=0.8.6", 52 | "pre-commit>=4.0.1", 53 | "ipykernel>=6.29.5", 54 | "ipywidgets>=8.1.5", 55 | "matplotlib>=3.10.0", 56 | "pynvml>=12.0.0", 57 | ] 58 | 59 | 60 | [tool.uv.sources] 61 | openpi-client = { workspace = true } 62 | lerobot = { git = "https://github.com/huggingface/lerobot", rev = "6674e368249472c91382eb54bb8501c94c7f0c56" } 63 | 64 | [tool.uv.workspace] 65 | members = ["packages/*"] 66 | 67 | [[tool.uv.index]] 68 | url = "https://pypi.tuna.tsinghua.edu.cn/simple" 69 | 70 | [tool.ruff] 71 | line-length = 120 72 | target-version = "py311" 73 | extend-exclude = ["docker", "third_party"] 74 | 75 | [tool.ruff.lint] 76 | # https://docs.astral.sh/ruff/rules/ 77 | select = [ 78 | "B", 79 | "C4", 80 | "DTZ", 81 | "E4", 82 | "E7", 83 | "E9", 84 | "F", 85 | "FBT", 86 | "FURB", 87 | "I", 88 | "ICN", 89 | "ISC", 90 | "LOG", 91 | "N", 92 | "PD", 93 | "PERF", 94 | "PIE", 95 | "PLC", 96 | "PLE", 97 | "PLR1", 98 | "PLR5", 99 | "PLW", 100 | "PT", 101 | "PTH", 102 | "Q", 103 | "RET", 104 | "RUF", 105 | "SIM", 106 | "SLF", 107 | "T10", 108 | "T20", 109 | "UP", 110 | "W", 111 | ] 112 | ignore = [ 113 | "F722", # Conflicts with array typing. 114 | "T201", # We use print statements. 115 | "PD008", # Lots of false positives. 116 | "ISC001", # Disabling to support ruff format. 117 | ] 118 | unfixable = [ 119 | "B905", # Fix defaults to strict=False, which is not what we want. 
120 | ] 121 | 122 | [tool.ruff.lint.isort] 123 | force-single-line = true 124 | force-sort-within-sections = true 125 | single-line-exclusions = ["collections.abc", "typing", "typing_extensions"] 126 | known-third-party = ["wandb"] 127 | 128 | [build-system] 129 | requires = ["hatchling"] 130 | build-backend = "hatchling.build" 131 | 132 | [tool.pytest.ini_options] 133 | markers = ["manual: should be run manually."] 134 | testpaths = ["src", "scripts", "packages"] 135 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/augment_vl_data/finger_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/finger_mask.jpg -------------------------------------------------------------------------------- /scripts/augment_vl_data/gripper_lens_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/gripper_lens_mask.jpg -------------------------------------------------------------------------------- /scripts/augment_vl_data/inpainted.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/inpainted.jpg -------------------------------------------------------------------------------- /scripts/augment_vl_data/lens.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/lens.jpg -------------------------------------------------------------------------------- /scripts/compute_norm_stats.py: -------------------------------------------------------------------------------- 1 | """Compute normalization statistics for a config. 2 | 3 | This script is used to compute the normalization statistics for a given config. It 4 | will compute the mean and standard deviation of the data in the dataset and save it 5 | to the config assets directory. 
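Example invocation (hypothetical; assumes the train config is selected by name via tyro, with `pi0_aloha_sim` standing in for whichever config you use):

    uv run scripts/compute_norm_stats.py --config-name pi0_aloha_sim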
6 | """
7 | 
8 | import numpy as np
9 | import tqdm
10 | import tyro
11 | 
12 | import openpi.shared.normalize as normalize
13 | import openpi.training.config as _config
14 | import openpi.training.data_loader as _data_loader
15 | import openpi.transforms as transforms
16 | 
17 | 
18 | class RemoveStrings(transforms.DataTransformFn):
19 |     def __call__(self, x: dict) -> dict:
20 |         return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
21 | 
22 | 
23 | def create_dataset(config: _config.TrainConfig) -> tuple[_config.DataConfig, _data_loader.Dataset]:
24 |     data_config = config.data.create(config.assets_dirs, config.model)
25 |     if data_config.repo_id is None:
26 |         raise ValueError("Data config must have a repo_id")
27 |     dataset, _ = _data_loader.create_dataset(data_config, config.model)
28 |     dataset = _data_loader.TransformedDataset(
29 |         dataset,
30 |         [
31 |             *data_config.repack_transforms.inputs,
32 |             *data_config.data_transforms.inputs,
33 |             # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
34 |             RemoveStrings(),
35 |         ],
36 |     )
37 |     return data_config, dataset
38 | 
39 | 
40 | def main(config: _config.TrainConfig, max_frames: int | None = None):
41 |     data_config, dataset = create_dataset(config)
42 | 
43 |     num_frames = len(dataset)
44 |     shuffle = False
45 | 
46 |     if max_frames is not None and max_frames < num_frames:
47 |         num_frames = max_frames
48 |         shuffle = True
49 | 
50 |     data_loader = _data_loader.TorchDataLoader(
51 |         dataset,
52 |         local_batch_size=1,
53 |         num_workers=8,
54 |         shuffle=shuffle,
55 |         num_batches=num_frames,
56 |     )
57 | 
58 |     keys = ["state", "actions"]
59 |     stats = {key: normalize.RunningStats() for key in keys}
60 | 
61 |     for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
62 |         for key in keys:
63 |             values = np.asarray(batch[key][0])
64 |             stats[key].update(values.reshape(-1, values.shape[-1]))
65 | 
66 |     norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
67 | 
68 |     output_path = config.assets_dirs / data_config.repo_id
69 |     if hasattr(data_config, "getitem_type"):
70 |         output_path = config.assets_dirs / data_config.repo_id / data_config.getitem_type
71 | 
72 |     print(f"Writing stats to: {output_path}")
73 |     normalize.save(output_path, norm_stats)
74 | 
75 | 
76 | if __name__ == "__main__":
77 |     main(_config.cli())
78 | 
--------------------------------------------------------------------------------
/scripts/docker/compose.yml:
--------------------------------------------------------------------------------
1 | # Run with:
2 | # docker compose -f scripts/docker/compose.yml up --build
3 | services:
4 |   openpi_server:
5 |     image: openpi_server
6 |     build:
7 |       # Build from the repo root so that the Dockerfile can copy the workspace packages.
8 |       context: ../..
9 |       dockerfile: scripts/docker/serve_policy.Dockerfile
10 |     init: true
11 |     tty: true
12 |     network_mode: host
13 |     # Mount the configured openpi data home to /openpi_assets inside the container.
14 |     # Mount AWS credentials inside the container.
15 |     volumes:
16 |       - $PWD:/app
17 |       - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
18 |     environment:
19 |       - SERVER_ARGS
20 |       - OPENPI_DATA_HOME=/openpi_assets
21 |       - IS_DOCKER=true
22 | 
23 |     # Comment out this block if not running on a machine with GPUs.
24 |     deploy:
25 |       resources:
26 |         reservations:
27 |           devices:
28 |             - driver: nvidia
29 |               count: 1
30 |               capabilities: [gpu]
31 | 
--------------------------------------------------------------------------------
/scripts/docker/install_docker_ubuntu22.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Add Docker's official GPG key:
4 | sudo apt-get update
5 | sudo apt-get install -y ca-certificates curl
6 | sudo install -m 0755 -d /etc/apt/keyrings
7 | sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
8 | sudo chmod a+r /etc/apt/keyrings/docker.asc
9 | 
10 | # Add the repository to Apt sources:
11 | echo \
12 |   "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
13 |   $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
14 |   sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
15 | sudo apt-get update
16 | 
17 | sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
18 | 
19 | # Add the current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc.).
20 | # See https://docs.docker.com/engine/install/linux-postinstall/
21 | username=$(whoami)
22 | sudo usermod -aG docker "$username"
23 | 
24 | # Configure docker to start automatically on system boot.
25 | sudo systemctl enable docker.service
26 | sudo systemctl enable containerd.service
27 | 
28 | # https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
29 | if [ -f ~/.docker/config.json ]; then
30 |   sed -i 's/credsStore/credStore/g' ~/.docker/config.json
31 | fi
32 | 
33 | echo ""
34 | echo "********************************************************************"
35 | echo "**** Restart to allow Docker permission changes to take effect. ****"
36 | echo "********************************************************************"
37 | echo ""
38 | 
--------------------------------------------------------------------------------
/scripts/docker/install_nvidia_container_toolkit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
4 | # NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
5 | 
6 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
7 |   curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
8 |   sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
9 |   sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
10 | 
11 | # NVIDIA's documentation omits 'sudo' in the following command, but it is required.
12 | sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
13 | sudo apt-get update
14 | sudo apt-get install -y nvidia-container-toolkit
15 | 
16 | sudo nvidia-ctk runtime configure --runtime=docker
17 | sudo systemctl restart docker
18 | 
--------------------------------------------------------------------------------
/scripts/docker/serve_policy.Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for serving a PI policy.
2 | # Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container 3 | 4 | # Build the container: 5 | # docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile 6 | 7 | # Run the container: 8 | # docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash 9 | 10 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 11 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 12 | 13 | WORKDIR /app 14 | 15 | # Needed because LeRobot uses git-lfs. 16 | RUN apt-get update && apt-get install -y git git-lfs 17 | 18 | # Copy from the cache instead of linking since it's a mounted volume 19 | ENV UV_LINK_MODE=copy 20 | 21 | # Write the virtual environment outside of the project directory so it doesn't 22 | # leak out of the container when we mount the application code. 23 | ENV UV_PROJECT_ENVIRONMENT=/.venv 24 | 25 | # Install the project's dependencies using the lockfile and settings 26 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT 27 | RUN --mount=type=cache,target=/root/.cache/uv \ 28 | --mount=type=bind,source=uv.lock,target=uv.lock \ 29 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 30 | --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \ 31 | --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \ 32 | GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev 33 | 34 | CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS" 35 | -------------------------------------------------------------------------------- /scripts/serve_policy.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import enum 3 | import logging 4 | import socket 5 | 6 | import tyro 7 | 8 | from openpi.policies import policy as _policy 9 | from openpi.policies import policy_config as _policy_config 10 | from openpi.serving import websocket_policy_server 11 | from openpi.training import config as _config 12 | 13 | 14 | class EnvMode(enum.Enum): 15 | """Supported environments.""" 16 | 17 | ALOHA = "aloha" 18 | ALOHA_SIM = "aloha_sim" 19 | DROID = "droid" 20 | LIBERO = "libero" 21 | FAST_BASE = "fast_base" 22 | BASE = "base" 23 | 24 | 25 | @dataclasses.dataclass 26 | class Checkpoint: 27 | """Load a policy from a trained checkpoint.""" 28 | 29 | # Training config name (e.g., "pi0_aloha_sim"). 30 | config: str 31 | # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000"). 32 | dir: str 33 | 34 | 35 | @dataclasses.dataclass 36 | class Default: 37 | """Use the default policy for the given environment.""" 38 | 39 | 40 | @dataclasses.dataclass 41 | class Args: 42 | """Arguments for the serve_policy script.""" 43 | 44 | # Environment to serve the policy for. This is only used when serving default policies. 45 | env: EnvMode = EnvMode.ALOHA_SIM 46 | 47 | # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default 48 | # prompt. 49 | default_prompt: str | None = None 50 | 51 | # Port to serve the policy on. 52 | port: int = 8000 53 | # Record the policy's behavior for debugging. 54 | record: bool = False 55 | 56 | # Specifies how to load the policy. If not provided, the default policy for the environment will be used. 
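    # For example, a specific checkpoint can be selected with tyro's subcommand syntax
    # (a sketch; the config name and directory below are illustrative):
    #   uv run scripts/serve_policy.py policy:checkpoint \
    #       --policy.config=pi0_aloha_sim --policy.dir=checkpoints/pi0_aloha_sim/exp/10000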
57 | policy: Checkpoint | Default = dataclasses.field(default_factory=Default) 58 | 59 | 60 | # Default checkpoints that should be used for each environment. 61 | DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = { 62 | EnvMode.ALOHA: Checkpoint( 63 | config="pi0_aloha", 64 | dir="s3://openpi-assets/checkpoints/pi0_base", 65 | ), 66 | EnvMode.ALOHA_SIM: Checkpoint( 67 | config="pi0_aloha_sim", 68 | dir="s3://openpi-assets/checkpoints/pi0_aloha_sim", 69 | ), 70 | EnvMode.DROID: Checkpoint( 71 | config="pi0_fast_droid", 72 | dir="s3://openpi-assets/checkpoints/pi0_fast_droid", 73 | ), 74 | EnvMode.LIBERO: Checkpoint( 75 | config="pi0_fast_libero", 76 | dir="s3://openpi-assets/checkpoints/pi0_fast_libero", 77 | ), 78 | EnvMode.FAST_BASE: Checkpoint( 79 | config="pi0_fast_umi", 80 | dir="s3://openpi-assets/checkpoints/pi0_fast_base", 81 | ), 82 | EnvMode.BASE: Checkpoint( 83 | config="pi0_umi", 84 | dir="s3://openpi-assets/checkpoints/pi0_base", 85 | ), 86 | } 87 | 88 | 89 | def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy: 90 | """Create a default policy for the given environment.""" 91 | if checkpoint := DEFAULT_CHECKPOINT.get(env): 92 | return _policy_config.create_trained_policy( 93 | _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt 94 | ) 95 | raise ValueError(f"Unsupported environment mode: {env}") 96 | 97 | 98 | def create_policy(args: Args) -> _policy.Policy: 99 | """Create a policy from the given arguments.""" 100 | match args.policy: 101 | case Checkpoint(): 102 | return _policy_config.create_trained_policy( 103 | _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt 104 | ) 105 | case Default(): 106 | return create_default_policy(args.env, default_prompt=args.default_prompt) 107 | 108 | 109 | def main(args: Args) -> None: 110 | policy = create_policy(args) 111 | policy_metadata = policy.metadata 112 | 113 | # Record the policy's behavior. 114 | if args.record: 115 | policy = _policy.PolicyRecorder(policy, "policy_records") 116 | 117 | hostname = socket.gethostname() 118 | local_ip = socket.gethostbyname(hostname) 119 | logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip) 120 | 121 | server = websocket_policy_server.WebsocketPolicyServer( 122 | policy=policy, 123 | host="0.0.0.0", 124 | port=args.port, 125 | metadata=policy_metadata, 126 | ) 127 | server.serve_forever() 128 | 129 | 130 | if __name__ == "__main__": 131 | logging.basicConfig(level=logging.INFO, force=True) 132 | main(tyro.cli(Args)) 133 | -------------------------------------------------------------------------------- /scripts/train_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import pathlib 4 | 5 | import pytest 6 | 7 | os.environ["JAX_PLATFORMS"] = "cpu" 8 | 9 | from openpi.training import config as _config 10 | 11 | from . 
import train 12 | 13 | 14 | @pytest.mark.parametrize("config_name", ["debug"]) 15 | def test_train(tmp_path: pathlib.Path, config_name: str): 16 | config = dataclasses.replace( 17 | _config._CONFIGS_DICT[config_name], # noqa: SLF001 18 | batch_size=2, 19 | checkpoint_base_dir=tmp_path / "checkpoint", 20 | exp_name="test", 21 | overwrite=False, 22 | resume=False, 23 | num_train_steps=2, 24 | log_interval=1, 25 | ) 26 | train.main(config) 27 | 28 | # test resuming 29 | config = dataclasses.replace(config, resume=True, num_train_steps=4) 30 | train.main(config) 31 | -------------------------------------------------------------------------------- /src/openpi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/__init__.py -------------------------------------------------------------------------------- /src/openpi/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pynvml 4 | import pytest 5 | 6 | 7 | def set_jax_cpu_backend_if_no_gpu() -> None: 8 | try: 9 | pynvml.nvmlInit() 10 | pynvml.nvmlShutdown() 11 | except pynvml.NVMLError: 12 | # No GPU found. 13 | os.environ["JAX_PLATFORMS"] = "cpu" 14 | 15 | 16 | def pytest_configure(config: pytest.Config) -> None: 17 | set_jax_cpu_backend_if_no_gpu() 18 | -------------------------------------------------------------------------------- /src/openpi/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/models/__init__.py -------------------------------------------------------------------------------- /src/openpi/models/lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | 4 | import flax.linen as nn 5 | import flax.struct as struct 6 | import jax.numpy as jnp 7 | 8 | import openpi.shared.array_typing as at 9 | 10 | 11 | @struct.dataclass 12 | class LoRAConfig: 13 | """Configuration for LoRA.""" 14 | 15 | # LoRA rank. 16 | rank: int 17 | # LoRA scaling factor. 18 | alpha: float = 1.0 19 | # Initialization function for LoRA parameters. 20 | init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01) 21 | # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732 22 | rslora: bool = False 23 | # Axes in the weight to apply LoRA to. Should typically be the last two axes. 24 | axes: tuple[int, int] = (-2, -1) 25 | # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation. 26 | label: str = "L" 27 | 28 | @property 29 | def scaling_value(self) -> float: 30 | return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank 31 | 32 | 33 | class Einsum(nn.Module): 34 | """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum.""" 35 | 36 | # Shape of the weight. 37 | shape: tuple[int, ...] 38 | # Initialization function for the weight. 39 | init_fn: nn.initializers.Initializer = nn.initializers.zeros 40 | # If not None, apply LoRA to the weight. 41 | lora_config: LoRAConfig | None = None 42 | 43 | def setup(self): 44 | self.w = self.param("w", self.init_fn, self.shape) 45 | 46 | if config := self.lora_config: 47 | # Setup LoRA parameters. 
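            # A sketch of the resulting factor shapes: with the default axes (-2, -1) and a
            # weight of shape (..., D, H), w_a has shape (..., D, rank) and w_b has shape
            # (..., rank, H), so the effective weight becomes w + scaling * (w_a @ w_b).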
48 | shape_a, shape_b = list(self.shape), list(self.shape) 49 | shape_a[config.axes[1]] = config.rank 50 | shape_b[config.axes[0]] = config.rank 51 | self.w_a = self.param("lora_a", config.init_fn, shape_a) 52 | self.w_b = self.param("lora_b", config.init_fn, shape_b) 53 | 54 | @nn.compact 55 | def __call__(self, eqn: str, x): 56 | dtype = x.dtype # original dtype, could be half-precision 57 | result = jnp.einsum(eqn, x, self.w.astype(dtype)) 58 | 59 | if config := self.lora_config: 60 | eqn_a, eqn_b = self._make_lora_eqns(eqn) 61 | lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype)) 62 | lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype)) 63 | result = result + lora * config.scaling_value 64 | 65 | return result 66 | 67 | def _make_lora_eqns(self, eqn: str) -> tuple[str, str]: 68 | if "L" in eqn: 69 | raise ValueError(f"L already in eqn: {eqn}") 70 | if not (m := re.match("(.*),(.*)->(.*)", eqn)): 71 | raise ValueError(f"Unsupported einsum eqn: {eqn}") 72 | lhs, rhs, out = m.groups() 73 | 74 | assert self.lora_config is not None 75 | a_label, b_label = (rhs[x] for x in self.lora_config.axes) 76 | label = self.lora_config.label 77 | 78 | a_rhs = rhs.replace(b_label, label) 79 | a_out = out.replace(b_label, label) 80 | eqn_a = f"{lhs},{a_rhs}->{a_out}" 81 | 82 | b_rhs = rhs.replace(a_label, label) 83 | eqn_b = f"{a_out},{b_rhs}->{out}" 84 | 85 | return eqn_a, eqn_b 86 | 87 | 88 | class FeedForward(nn.Module): 89 | """Feed forward module.""" 90 | 91 | features: int 92 | hidden_dim: int 93 | # If not None, apply LoRA to the weight. 94 | lora_config: LoRAConfig | None = None 95 | 96 | def setup(self): 97 | self.w_gating = self.param( 98 | "gating_einsum", 99 | nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), 100 | (2, self.features, self.hidden_dim), 101 | ) 102 | self.w_linear = self.param( 103 | "linear", 104 | nn.initializers.lecun_normal(in_axis=-2, out_axis=-1), 105 | (self.hidden_dim, self.features), 106 | ) 107 | self.w_gating_lora = None 108 | self.w_linear_lora = None 109 | if self.lora_config: 110 | # Setup LoRA parameters. 111 | # TODO: follow up with a simplified init_fn api. 
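            # Two factor pairs are created below: one for the stacked gating einsum (its
            # leading dim of 2 covers the gate and up projections) and one for the output
            # projection back to `features`.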
112 | self.w_gating_lora = ( 113 | self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)), 114 | self.param( 115 | "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim) 116 | ), 117 | ) 118 | self.w_linear_lora = ( 119 | self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)), 120 | self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)), 121 | ) 122 | 123 | @nn.compact 124 | def __call__(self, x): 125 | dtype = x.dtype # original dtype, could be half-precision 126 | ff_gate = self._dot( 127 | x, 128 | self.w_gating[0], 129 | None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]), 130 | ) 131 | gate_value = nn.gelu(ff_gate) 132 | 133 | ff1 = self._dot( 134 | x, 135 | self.w_gating[1], 136 | None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]), 137 | ) 138 | activations = gate_value * ff1 139 | 140 | outputs = self._dot(activations, self.w_linear, self.w_linear_lora) 141 | assert outputs.dtype == dtype 142 | return outputs 143 | 144 | def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array: 145 | base = jnp.dot(x, w.astype(x.dtype)) 146 | if lora_weights is None: 147 | return base 148 | return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype)) 149 | -------------------------------------------------------------------------------- /src/openpi/models/lora_test.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | import openpi.models.lora as lora 6 | 7 | 8 | def test_lora_einsum_params_shape(): 9 | shape = (3, 8, 32, 4) # (3KDH) 10 | einsum = lora.Einsum(shape) 11 | lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2)) 12 | lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2))) 13 | 14 | key = jax.random.key(0) 15 | x = jax.random.normal(key, (8, 64, 32)) # (BSD) 16 | eqn = "BSD,3KDH->3BSKH" 17 | 18 | # Ensure that lora parameters are not initialized when LoRA is not used. 19 | params = einsum.init(key, eqn, x) 20 | assert "lora_a" not in params["params"] 21 | assert "lora_b" not in params["params"] 22 | 23 | # Check that default axes work. 24 | params_lora0 = lora0.init(key, eqn, x) 25 | assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2) 26 | assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4) 27 | 28 | # Check that user provided axes work. 29 | params_lora1 = lora1.init(key, eqn, x) 30 | assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4) 31 | assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4) 32 | 33 | 34 | def test_lora_einsum_same_output(): 35 | shape = (3, 8, 32, 4) # (3KDH) 36 | einsum = lora.Einsum(shape) 37 | einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros)) 38 | 39 | key = jax.random.key(0) 40 | x = jax.random.normal(key, (8, 64, 32)) # (BSD) 41 | eqn = "BSD,3KDH->3BSKH" 42 | 43 | params = einsum.init(key, eqn, x) 44 | output = einsum.apply(params, eqn, x) 45 | 46 | params_lora = einsum_lora.init(key, eqn, x) 47 | output_lora = einsum_lora.apply(params_lora, eqn, x) 48 | 49 | # Results are the same since the LoRA parameters are initialized to zeros. 
50 | assert jnp.allclose(output, output_lora) 51 | 52 | 53 | def test_lora_ffn_params_shape(): 54 | ffn = lora.FeedForward(features=8, hidden_dim=32) 55 | ffn_lora = lora.FeedForward( 56 | features=8, 57 | hidden_dim=32, 58 | lora_config=lora.LoRAConfig(rank=2), 59 | ) 60 | 61 | key = jax.random.key(0) 62 | x = jax.random.normal(key, (2, 8)) 63 | 64 | params = ffn.init(key, x) 65 | assert params["params"]["gating_einsum"].shape == (2, 8, 32) 66 | assert params["params"]["linear"].shape == (32, 8) 67 | 68 | params_lora = ffn_lora.init(key, x) 69 | assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32) 70 | assert params_lora["params"]["linear"].shape == (32, 8) 71 | assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2) 72 | assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32) 73 | assert params_lora["params"]["linear_lora_a"].shape == (32, 2) 74 | assert params_lora["params"]["linear_lora_b"].shape == (2, 8) 75 | 76 | 77 | def test_lora_ffn_same_output(): 78 | ffn = lora.FeedForward(features=8, hidden_dim=32) 79 | ffn_lora = lora.FeedForward( 80 | features=8, 81 | hidden_dim=32, 82 | lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros), 83 | ) 84 | 85 | key = jax.random.key(0) 86 | x = jax.random.normal(key, (2, 8)) 87 | 88 | params = ffn.init(key, x) 89 | output = ffn.apply(params, x) 90 | 91 | params_lora = ffn_lora.init(key, x) 92 | output_lora = ffn_lora.apply(params_lora, x) 93 | 94 | assert jnp.allclose(output, output_lora) 95 | -------------------------------------------------------------------------------- /src/openpi/models/model_test.py: -------------------------------------------------------------------------------- 1 | from flax import nnx 2 | import jax 3 | import pytest 4 | 5 | from openpi.models import model as _model 6 | from openpi.models import pi0 7 | from openpi.models import pi0_fast 8 | from openpi.models import pi0_fuse 9 | from openpi.shared import download 10 | from openpi.shared import nnx_utils 11 | 12 | 13 | def test_pi0_model(): 14 | key = jax.random.key(0) 15 | config = pi0.Pi0Config() 16 | model = config.create(key) 17 | 18 | batch_size = 2 19 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 20 | 21 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 22 | assert loss.shape == (batch_size, config.action_horizon) 23 | 24 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) 25 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim) 26 | 27 | 28 | def test_pi0_lora_model(): 29 | key = jax.random.key(0) 30 | config = pi0.Pi0Config(paligemma_variant="gemma_2b_lora") 31 | model = config.create(key) 32 | 33 | batch_size = 2 34 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 35 | 36 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 37 | assert loss.shape == (batch_size, config.action_horizon) 38 | 39 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) 40 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim) 41 | 42 | 43 | def test_pi0_fast_model(): 44 | key = jax.random.key(0) 45 | config = pi0_fast.Pi0FASTConfig() 46 | model = config.create(key) 47 | 48 | batch_size = 2 49 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 50 | 51 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 52 | assert loss.shape == (batch_size,) 53 | 54 | actions, 
action_info = nnx_utils.module_jit(model.sample_actions)(key, obs)
55 |     assert actions.shape == (batch_size, 256)
56 | 
57 | 
58 | def test_pi0_fast_lora_model():
59 |     key = jax.random.key(0)
60 |     config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
61 |     model = config.create(key)
62 | 
63 |     batch_size = 2
64 |     obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
65 | 
66 |     loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
67 |     assert loss.shape == (batch_size,)
68 | 
69 |     actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs)
70 |     assert actions.shape == (batch_size, 256)
71 | 
72 |     lora_filter = nnx_utils.PathRegex(".*lora.*")
73 |     model_state = nnx.state(model)
74 | 
75 |     lora_state_elems = list(model_state.filter(lora_filter))
76 |     assert len(lora_state_elems) > 0
77 | 
78 | 
79 | def test_pi0_fuse_model():
80 |     key = jax.random.key(0)
81 |     config = pi0_fuse.Pi0FuseConfig()
82 |     model = config.create(key)
83 | 
84 |     batch_size = 2
85 |     obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
86 | 
87 |     loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
88 |     assert loss.shape == (batch_size,)
89 | 
90 |     actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs)
91 |     assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
92 | 
93 | 
94 | @pytest.mark.manual
95 | def test_model_restore():
96 |     key = jax.random.key(0)
97 |     config = pi0.Pi0Config()
98 | 
99 |     batch_size = 2
100 |     obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
101 | 
102 |     model = config.load(
103 |         _model.restore_params(download.maybe_download("s3://openpi-assets/checkpoints/pi0_base/params"))
104 |     )
105 | 
106 |     loss, loss_info = model.compute_loss(key, obs, act)
107 |     assert loss.shape == (batch_size, config.action_horizon)
108 | 
109 |     actions, action_info = model.sample_actions(key, obs, num_steps=10)
110 |     assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
111 | 
--------------------------------------------------------------------------------
/src/openpi/models/pi0_test.py:
--------------------------------------------------------------------------------
1 | import flax.nnx as nnx
2 | import jax
3 | 
4 | import openpi.models.pi0 as _pi0
5 | 
6 | 
7 | def _get_frozen_state(config: _pi0.Pi0Config) -> nnx.State:
8 |     abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
9 | 
10 |     freeze_filter = config.get_freeze_filter()
11 |     return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state()
12 | 
13 | 
14 | def test_pi0_full_finetune():
15 |     config = _pi0.Pi0Config()
16 |     state = _get_frozen_state(config)
17 |     assert len(state) == 0
18 | 
19 | 
20 | def test_pi0_gemma_lora():
21 |     config = _pi0.Pi0Config(paligemma_variant="gemma_2b_lora")
22 |     state = _get_frozen_state(config)
23 |     assert len(state) == 9
24 |     assert all("lora" not in p for p in state)
25 |     assert all("llm" in p for p in state)
26 |     assert all("_1" not in p for p in state)
27 | 
28 | 
29 | def test_pi0_action_expert_lora():
30 |     config = _pi0.Pi0Config(action_expert_variant="gemma_300m_lora")
31 |     state = _get_frozen_state(config)
32 |     # Excluding the embedder, the rest of the params should be the same as in test_pi0_gemma_lora.
33 |     assert len(state) == 8
34 |     assert all("lora" not in p for p in state)
35 |     assert all("llm" in p for p in state)
36 |     # All frozen params should have "_1" in their path since they belong to the action expert.
37 | assert all(any("_1" in p for p in path) for path in state) 38 | 39 | 40 | def test_pi0_all_lora(): 41 | config = _pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora") 42 | state = _get_frozen_state(config) 43 | # sum of gemma_lora and action_expert_lora's frozen params. 44 | assert len(state) == 17 45 | assert all("lora" not in p for p in state) 46 | assert all("llm" in p for p in state) 47 | -------------------------------------------------------------------------------- /src/openpi/models/tokenizer_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from openpi.models import tokenizer as _tokenizer 4 | 5 | 6 | def test_tokenize(): 7 | tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10) 8 | tokens, masks = tokenizer.tokenize("Hello, world!") 9 | 10 | assert tokens.shape == (10,) 11 | assert masks.shape == (10,) 12 | 13 | 14 | def test_fast_tokenizer(): 15 | prompt = "Hello, world!" 16 | state = np.random.rand(5).astype(np.float32) 17 | action = np.random.rand(3, 2).astype(np.float32) 18 | tokenizer = _tokenizer.FASTTokenizer(max_len=256) 19 | tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action) 20 | 21 | assert tokens.shape == (256,) 22 | assert token_masks.shape == (256,) 23 | assert ar_masks.shape == (256,) 24 | assert loss_masks.shape == (256,) 25 | 26 | act = tokenizer.extract_actions(tokens, 3, 2) 27 | assert act.shape == (3, 2) 28 | -------------------------------------------------------------------------------- /src/openpi/policies/droid_policy.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import einops 4 | import numpy as np 5 | 6 | from openpi import transforms 7 | from openpi.models import model as _model 8 | 9 | 10 | def make_droid_example() -> dict: 11 | """Creates a random input example for the Droid policy.""" 12 | return { 13 | "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 14 | "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 15 | "observation/joint_position": np.random.rand(7), 16 | "observation/gripper_position": np.random.rand(1), 17 | "prompt": "do something", 18 | } 19 | 20 | 21 | def _parse_image(image) -> np.ndarray: 22 | image = np.asarray(image) 23 | if np.issubdtype(image.dtype, np.floating): 24 | image = (255 * image).astype(np.uint8) 25 | if image.shape[0] == 3: 26 | image = einops.rearrange(image, "c h w -> h w c") 27 | return image 28 | 29 | 30 | @dataclasses.dataclass(frozen=True) 31 | class DroidInputs(transforms.DataTransformFn): 32 | # The action dimension of the model. Will be used to pad state and actions. 33 | action_dim: int 34 | 35 | # Determines which model will be used. 
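    # (PI0 fills the slots base_0_rgb/left_wrist_0_rgb/right_wrist_0_rgb and masks out the
    # missing right-wrist image; PI0_FAST fills base_0_rgb/base_1_rgb/wrist_0_rgb and keeps
    # all masks set. See the match statement below.)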
36 | model_type: _model.ModelType = _model.ModelType.PI0 37 | 38 | def __call__(self, data: dict) -> dict: 39 | state = np.concatenate([data["observation/joint_position"], data["observation/gripper_position"]]) 40 | state = transforms.pad_to_dim(state, self.action_dim) 41 | 42 | # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically 43 | # stores as float32 (C,H,W), gets skipped for policy inference 44 | base_image = _parse_image(data["observation/exterior_image_1_left"]) 45 | wrist_image = _parse_image(data["observation/wrist_image_left"]) 46 | 47 | match self.model_type: 48 | case _model.ModelType.PI0: 49 | names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") 50 | images = (base_image, wrist_image, np.zeros_like(base_image)) 51 | image_masks = (np.True_, np.True_, np.False_) 52 | case _model.ModelType.PI0_FAST: 53 | names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb") 54 | # We don't mask out padding images for FAST models. 55 | images = (base_image, np.zeros_like(base_image), wrist_image) 56 | image_masks = (np.True_, np.True_, np.True_) 57 | case _: 58 | raise ValueError(f"Unsupported model type: {self.model_type}") 59 | 60 | inputs = { 61 | "state": state, 62 | "image": dict(zip(names, images, strict=True)), 63 | "image_mask": dict(zip(names, image_masks, strict=True)), 64 | } 65 | 66 | if "actions" in data: 67 | inputs["actions"] = data["actions"] 68 | 69 | if "prompt" in data: 70 | inputs["prompt"] = data["prompt"] 71 | 72 | return inputs 73 | 74 | 75 | @dataclasses.dataclass(frozen=True) 76 | class DroidOutputs(transforms.DataTransformFn): 77 | def __call__(self, data: dict) -> dict: 78 | # Only return the first 8 dims. 79 | return {"actions": np.asarray(data["actions"][:, :8])} 80 | -------------------------------------------------------------------------------- /src/openpi/policies/libero_policy.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import einops 4 | import numpy as np 5 | 6 | from openpi import transforms 7 | from openpi.models import model as _model 8 | 9 | 10 | def make_libero_example() -> dict: 11 | """Creates a random input example for the Libero policy.""" 12 | return { 13 | "observation/state": np.random.rand(8), 14 | "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 15 | "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 16 | "prompt": "do something", 17 | } 18 | 19 | 20 | def _parse_image(image) -> np.ndarray: 21 | image = np.asarray(image) 22 | if np.issubdtype(image.dtype, np.floating): 23 | image = (255 * image).astype(np.uint8) 24 | if image.shape[0] == 3: 25 | image = einops.rearrange(image, "c h w -> h w c") 26 | return image 27 | 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class LiberoInputs(transforms.DataTransformFn): 31 | # The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST). 32 | action_dim: int 33 | 34 | # Determines which model will be used. 35 | model_type: _model.ModelType = _model.ModelType.PI0 36 | 37 | def __call__(self, data: dict) -> dict: 38 | mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST. 39 | 40 | # Get the state. We are padding from 8 to the model action dim. 41 | # For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped). 
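        # A concrete illustration (assuming pad_to_dim zero-pads on the right): with
        # action_dim=32, the 8-dim Libero state becomes 32-dim with 24 trailing zeros.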
42 | state = transforms.pad_to_dim(data["observation/state"], self.action_dim) 43 | 44 | # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically 45 | # stores as float32 (C,H,W), gets skipped for policy inference 46 | base_image = _parse_image(data["observation/image"]) 47 | wrist_image = _parse_image(data["observation/wrist_image"]) 48 | 49 | inputs = { 50 | "state": state, 51 | "image": { 52 | "base_0_rgb": base_image, 53 | "left_wrist_0_rgb": wrist_image, 54 | "right_wrist_0_rgb": np.zeros_like(base_image), 55 | }, 56 | "image_mask": { 57 | "base_0_rgb": np.True_, 58 | "left_wrist_0_rgb": np.True_, 59 | "right_wrist_0_rgb": np.False_ if mask_padding else np.True_, 60 | }, 61 | } 62 | 63 | # Actions are only available during training. 64 | if "actions" in data: 65 | # We are padding from 7 to the model action dim. 66 | # For pi0-FAST, this is a no-op (since action_dim = 7). 67 | actions = transforms.pad_to_dim(data["actions"], self.action_dim) 68 | inputs["actions"] = actions 69 | 70 | if "prompt" in data: 71 | inputs["prompt"] = data["prompt"] 72 | 73 | return inputs 74 | 75 | 76 | @dataclasses.dataclass(frozen=True) 77 | class LiberoOutputs(transforms.DataTransformFn): 78 | def __call__(self, data: dict) -> dict: 79 | # Only return the first 7 dims. 80 | return {"actions": np.asarray(data["actions"][:, :7])} 81 | -------------------------------------------------------------------------------- /src/openpi/policies/policy_config.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | import dataclasses 3 | import logging 4 | import pathlib 5 | from typing import Any 6 | 7 | import jax.numpy as jnp 8 | 9 | import openpi.models.model as _model 10 | import openpi.policies.policy as _policy 11 | import openpi.shared.download as download 12 | from openpi.training import checkpoints as _checkpoints 13 | from openpi.training import config as _config 14 | import openpi.transforms as transforms 15 | 16 | 17 | @dataclasses.dataclass 18 | class PolicyConfig: 19 | model: _model.BaseModel 20 | norm_stats: dict[str, transforms.NormStats] 21 | 22 | input_layers: Sequence[transforms.DataTransformFn] 23 | output_layers: Sequence[transforms.DataTransformFn] 24 | 25 | model_type: _model.ModelType = _model.ModelType.PI0 26 | default_prompt: str | None = None 27 | sample_kwargs: dict[str, Any] | None = None 28 | 29 | 30 | def create_trained_policy( 31 | train_config: _config.TrainConfig | _config.UMITrainConfig, 32 | checkpoint_dir: pathlib.Path | str, 33 | *, 34 | repack_transforms: transforms.Group | None = None, 35 | sample_kwargs: dict[str, Any] | None = None, 36 | default_prompt: str | None = None, 37 | norm_stats: dict[str, transforms.NormStats] | None = None, 38 | ) -> _policy.Policy | _policy.ReasoningPolicy: 39 | """Create a policy from a trained checkpoint. 40 | 41 | Args: 42 | train_config: The training config to use to create the model. 43 | checkpoint_dir: The directory to load the model from. 44 | repack_transforms: Optional transforms that will be applied before any other transforms. 45 | sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default 46 | kwargs will be used. 47 | default_prompt: The default prompt to use for the policy. Will inject the prompt into the input 48 | data if it doesn't already exist. 49 | norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded 50 | from the checkpoint directory. 
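
    Example (a sketch; the config name and checkpoint path are illustrative):

        train_config = _config.get_config("pi0_aloha_sim")
        policy = create_trained_policy(train_config, "checkpoints/pi0_aloha_sim/exp/10000")
        result = policy.infer(aloha_policy.make_aloha_example())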
51 |     """
52 |     repack_transforms = repack_transforms or transforms.Group()
53 |     checkpoint_dir = download.maybe_download(str(checkpoint_dir))
54 | 
55 |     logging.info("Loading model...")
56 |     model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
57 | 
58 |     data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
59 |     if norm_stats is None:
60 |         # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
61 |         # that the policy is using the same normalization stats as the original training process.
62 |         if data_config.asset_id is None:
63 |             raise ValueError("Asset id is required to load norm stats.")
64 |         norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
65 | 
66 |     extra_policy_kwargs = {}
67 | 
68 |     if train_config.model.model_type == _model.ModelType.PI0_FUSE:
69 |         assert isinstance(train_config, _config.UMITrainConfig)
70 |         policy_cls = _policy.ReasoningPolicy
71 |         extra_policy_kwargs['use_ref_img'] = train_config.use_reference_image
72 | 
73 |         if 'cocktail' in train_config.repo_id:
74 |             extra_policy_kwargs['initial_scene_plan'] = (
75 |                 'Scene description: TBD.\n'
76 |                 'Plan: TBD.\n'
77 |                 'What I have done: TBD.\n'
78 |                 'Now I need to: TBD.\n'
79 |             )
80 |         elif 'wild_move_to' in train_config.repo_id:
81 |             extra_policy_kwargs['initial_scene_plan'] = ''
82 |         else:
83 |             extra_policy_kwargs['initial_scene_plan'] = ''
84 |     else:
85 |         policy_cls = _policy.Policy
86 | 
87 |     return policy_cls(
88 |         model,
89 |         transforms=[
90 |             *repack_transforms.inputs,
91 |             transforms.InjectDefaultPrompt(default_prompt),
92 |             *data_config.data_transforms.inputs,
93 |             transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
94 |             *data_config.model_transforms.inputs,
95 |         ],
96 |         output_transforms=[
97 |             *data_config.model_transforms.outputs,
98 |             transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
99 |             *data_config.data_transforms.outputs,
100 |             *repack_transforms.outputs,
101 |         ],
102 |         sample_kwargs=sample_kwargs,
103 |         metadata=train_config.policy_metadata,
104 |         **extra_policy_kwargs,
105 |     )
106 | 
--------------------------------------------------------------------------------
/src/openpi/policies/policy_test.py:
--------------------------------------------------------------------------------
1 | from openpi_client import action_chunk_broker
2 | import pytest
3 | 
4 | from openpi.policies import aloha_policy
5 | from openpi.policies import policy_config as _policy_config
6 | from openpi.training import config as _config
7 | 
8 | 
9 | @pytest.mark.manual
10 | def test_infer():
11 |     config = _config.get_config("pi0_aloha_sim")
12 |     policy = _policy_config.create_trained_policy(config, "s3://openpi-assets/checkpoints/pi0_aloha_sim")
13 | 
14 |     example = aloha_policy.make_aloha_example()
15 |     result = policy.infer(example)
16 | 
17 |     assert result["actions"].shape == (config.model.action_horizon, 14)
18 | 
19 | 
20 | @pytest.mark.manual
21 | def test_broker():
22 |     config = _config.get_config("pi0_aloha_sim")
23 |     policy = _policy_config.create_trained_policy(config, "s3://openpi-assets/checkpoints/pi0_aloha_sim")
24 | 
25 |     broker = action_chunk_broker.ActionChunkBroker(
26 |         policy,
27 |         # Only execute the first half of the chunk.
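        # (The broker calls the underlying policy once per chunk and then replays the cached
        # actions one step at a time, so each infer() below returns a single 14-dim action.)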
28 | action_horizon=config.model.action_horizon // 2, 29 | ) 30 | 31 | example = aloha_policy.make_aloha_example() 32 | for _ in range(config.model.action_horizon): 33 | outputs = broker.infer(example) 34 | assert outputs["actions"].shape == (14,) 35 | -------------------------------------------------------------------------------- /src/openpi/policies/pose_repr_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_relative_pose(pos, rot, base_pos, base_rot_mat, 5 | rot_transformer_to_mat, 6 | rot_transformer_to_target, 7 | backward=False, 8 | delta=False): 9 | if not backward: 10 | # forward pass 11 | if not delta: 12 | output_pos = pos if base_pos is None else pos - base_pos 13 | output_rot = rot_transformer_to_target.forward( 14 | rot_transformer_to_mat.forward(rot) @ np.linalg.inv(base_rot_mat)) 15 | return output_pos, output_rot 16 | else: 17 | all_pos = np.concatenate([base_pos[None,...], pos], axis=0) 18 | output_pos = np.diff(all_pos, axis=0) 19 | 20 | rot_mat = rot_transformer_to_mat.forward(rot) 21 | all_rot_mat = np.concatenate([base_rot_mat[None,...], rot_mat], axis=0) 22 | prev_rot = np.linalg.inv(all_rot_mat[:-1]) 23 | curr_rot = all_rot_mat[1:] 24 | rot = np.matmul(curr_rot, prev_rot) 25 | output_rot = rot_transformer_to_target.forward(rot) 26 | return output_pos, output_rot 27 | 28 | else: 29 | # backward pass 30 | if not delta: 31 | output_pos = pos if base_pos is None else pos + base_pos 32 | output_rot = rot_transformer_to_mat.inverse( 33 | rot_transformer_to_target.inverse(rot) @ base_rot_mat) 34 | return output_pos, output_rot 35 | else: 36 | output_pos = np.cumsum(pos, axis=0) + base_pos 37 | 38 | rot_mat = rot_transformer_to_target.inverse(rot) 39 | output_rot_mat = np.zeros_like(rot_mat) 40 | curr_rot = base_rot_mat 41 | for i in range(len(rot_mat)): 42 | curr_rot = rot_mat[i] @ curr_rot 43 | output_rot_mat[i] = curr_rot 44 | output_rot = rot_transformer_to_mat.inverse(rot) 45 | return output_pos, output_rot 46 | 47 | 48 | def convert_pose_mat_rep(pose_mat, base_pose_mat, pose_rep='abs', backward=False): 49 | if not backward: 50 | # training transform 51 | if pose_rep == 'abs': 52 | return pose_mat 53 | elif pose_rep == 'rel': 54 | # legacy buggy implementation 55 | # for compatibility 56 | pos = pose_mat[...,:3,3] - base_pose_mat[:3,3] 57 | rot = pose_mat[...,:3,:3] @ np.linalg.inv(base_pose_mat[:3,:3]) 58 | out = np.copy(pose_mat) 59 | out[...,:3,:3] = rot 60 | out[...,:3,3] = pos 61 | return out 62 | elif pose_rep == 'relative': 63 | out = np.linalg.inv(base_pose_mat) @ pose_mat 64 | return out 65 | elif pose_rep == 'delta': 66 | all_pos = np.concatenate([base_pose_mat[None,:3,3], pose_mat[...,:3,3]], axis=0) 67 | out_pos = np.diff(all_pos, axis=0) 68 | 69 | all_rot_mat = np.concatenate([base_pose_mat[None,:3,:3], pose_mat[...,:3,:3]], axis=0) 70 | prev_rot = np.linalg.inv(all_rot_mat[:-1]) 71 | curr_rot = all_rot_mat[1:] 72 | out_rot = np.matmul(curr_rot, prev_rot) 73 | 74 | out = np.copy(pose_mat) 75 | out[...,:3,:3] = out_rot 76 | out[...,:3,3] = out_pos 77 | return out 78 | else: 79 | raise RuntimeError(f"Unsupported pose_rep: {pose_rep}") 80 | 81 | else: 82 | # eval transform 83 | if pose_rep == 'abs': 84 | return pose_mat 85 | elif pose_rep == 'rel': 86 | # legacy buggy implementation 87 | # for compatibility 88 | pos = pose_mat[...,:3,3] + base_pose_mat[:3,3] 89 | rot = pose_mat[...,:3,:3] @ base_pose_mat[:3,:3] 90 | out = np.copy(pose_mat) 91 | out[...,:3,:3] = rot 92 | 
out[...,:3,3] = pos 93 | return out 94 | elif pose_rep == 'relative': 95 | out = base_pose_mat @ pose_mat 96 | return out 97 | elif pose_rep == 'delta': 98 | output_pos = np.cumsum(pose_mat[...,:3,3], axis=0) + base_pose_mat[:3,3] 99 | 100 | output_rot_mat = np.zeros_like(pose_mat[...,:3,:3]) 101 | curr_rot = base_pose_mat[:3,:3] 102 | for i in range(len(pose_mat)): 103 | curr_rot = pose_mat[i,:3,:3] @ curr_rot 104 | output_rot_mat[i] = curr_rot 105 | 106 | out = np.copy(pose_mat) 107 | out[...,:3,:3] = output_rot_mat 108 | out[...,:3,3] = output_pos 109 | return out 110 | else: 111 | raise RuntimeError(f"Unsupported pose_rep: {pose_rep}") 112 | -------------------------------------------------------------------------------- /src/openpi/policies/pose_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.transform as st 3 | 4 | def pos_rot_to_mat(pos, rot): 5 | shape = pos.shape[:-1] 6 | mat = np.zeros(shape + (4,4), dtype=pos.dtype) 7 | mat[...,:3,3] = pos 8 | mat[...,:3,:3] = rot.as_matrix() 9 | mat[...,3,3] = 1 10 | return mat 11 | 12 | def mat_to_pos_rot(mat): 13 | pos = (mat[...,:3,3].T / mat[...,3,3].T).T 14 | rot = st.Rotation.from_matrix(mat[...,:3,:3]) 15 | return pos, rot 16 | 17 | def pos_rot_to_pose(pos, rot): 18 | shape = pos.shape[:-1] 19 | pose = np.zeros(shape+(6,), dtype=pos.dtype) 20 | pose[...,:3] = pos 21 | pose[...,3:] = rot.as_rotvec() 22 | return pose 23 | 24 | def pose_to_pos_rot(pose): 25 | pos = pose[...,:3] 26 | rot = st.Rotation.from_rotvec(pose[...,3:]) 27 | return pos, rot 28 | 29 | def pose_to_mat(pose): 30 | return pos_rot_to_mat(*pose_to_pos_rot(pose)) 31 | 32 | def mat_to_pose(mat): 33 | return pos_rot_to_pose(*mat_to_pos_rot(mat)) 34 | 35 | def transform_pose(tx, pose): 36 | """ 37 | tx: tx_new_old 38 | pose: tx_old_obj 39 | result: tx_new_obj 40 | """ 41 | pose_mat = pose_to_mat(pose) 42 | tf_pose_mat = tx @ pose_mat 43 | tf_pose = mat_to_pose(tf_pose_mat) 44 | return tf_pose 45 | 46 | def transform_point(tx, point): 47 | return point @ tx[:3,:3].T + tx[:3,3] 48 | 49 | def project_point(k, point): 50 | x = point @ k.T 51 | uv = x[...,:2] / x[...,[2]] 52 | return uv 53 | 54 | def apply_delta_pose(pose, delta_pose): 55 | new_pose = np.zeros_like(pose) 56 | 57 | # simple add for position 58 | new_pose[:3] = pose[:3] + delta_pose[:3] 59 | 60 | # matrix multiplication for rotation 61 | rot = st.Rotation.from_rotvec(pose[3:]) 62 | drot = st.Rotation.from_rotvec(delta_pose[3:]) 63 | new_pose[3:] = (drot * rot).as_rotvec() 64 | 65 | return new_pose 66 | 67 | def normalize(vec, tol=1e-7): 68 | return vec / np.maximum(np.linalg.norm(vec), tol) 69 | 70 | def rot_from_directions(from_vec, to_vec): 71 | from_vec = normalize(from_vec) 72 | to_vec = normalize(to_vec) 73 | axis = np.cross(from_vec, to_vec) 74 | axis = normalize(axis) 75 | angle = np.arccos(np.dot(from_vec, to_vec)) 76 | rotvec = axis * angle 77 | rot = st.Rotation.from_rotvec(rotvec) 78 | return rot 79 | 80 | def normalize(vec, eps=1e-12): 81 | norm = np.linalg.norm(vec, axis=-1) 82 | norm = np.maximum(norm, eps) 83 | out = (vec.T / norm).T 84 | return out 85 | 86 | def rot6d_to_mat(d6): 87 | a1, a2 = d6[..., :3], d6[..., 3:] 88 | b1 = normalize(a1) 89 | b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1 90 | b2 = normalize(b2) 91 | b3 = np.cross(b1, b2, axis=-1) 92 | out = np.stack((b1, b2, b3), axis=-2) 93 | return out 94 | 95 | def mat_to_rot6d(mat): 96 | batch_dim = mat.shape[:-2] 97 | out = mat[..., :2, 
:].copy().reshape(batch_dim + (6,)) 98 | return out 99 | 100 | def mat_to_pose10d(mat): 101 | pos = mat[...,:3,3] 102 | rotmat = mat[...,:3,:3] 103 | d6 = mat_to_rot6d(rotmat) 104 | d10 = np.concatenate([pos, d6], axis=-1) 105 | return d10 106 | 107 | def pose10d_to_mat(d10): 108 | pos = d10[...,:3] 109 | d6 = d10[...,3:] 110 | rotmat = rot6d_to_mat(d6) 111 | out = np.zeros(d10.shape[:-1]+(4,4), dtype=d10.dtype) 112 | out[...,:3,:3] = rotmat 113 | out[...,:3,3] = pos 114 | out[...,3,3] = 1 115 | return out 116 | -------------------------------------------------------------------------------- /src/openpi/policies/umi_policy.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import einops 4 | import numpy as np 5 | 6 | from openpi import transforms 7 | from openpi.models import model as _model 8 | 9 | 10 | def make_umi_example() -> dict: 11 | """Creates a random input example for the umi policy.""" 12 | return { 13 | "state": np.random.rand(48), 14 | "image_1": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 15 | "image_2": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 16 | "image_3": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 17 | "prompt": "do something", 18 | } 19 | 20 | 21 | def _parse_image(image) -> np.ndarray: 22 | image = np.asarray(image) 23 | if np.issubdtype(image.dtype, np.floating): 24 | image = (255 * image).astype(np.uint8) 25 | if image.shape[0] == 3: 26 | image = einops.rearrange(image, "c h w -> h w c") 27 | return image 28 | 29 | 30 | @dataclasses.dataclass(frozen=True) 31 | class UMIInputs(transforms.DataTransformFn): 32 | # The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST). 33 | action_dim: int 34 | 35 | # Determines which model will be used. 36 | model_type: _model.ModelType = _model.ModelType.PI0 37 | 38 | def __call__(self, data: dict) -> dict: 39 | mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST. 40 | 41 | # Get the state. We are padding from 8 to the model action dim. 42 | # For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped). 43 | state = transforms.pad_to_dim(data["state"], self.action_dim) 44 | 45 | history_length = 1 46 | while True: 47 | if f"image_{history_length + 1}" not in data: 48 | break 49 | history_length += 1 50 | 51 | # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically 52 | # stores as float32 (C,H,W), gets skipped for policy inference 53 | image_dict, image_mask_dict = {}, {} 54 | for i in range(history_length): 55 | image = _parse_image(data[f"image_{i + 1}"]) 56 | image_dict[f"{i}_rgb"] = image 57 | image_mask_dict[f"{i}_rgb"] = np.True_ 58 | 59 | if 'reference_image' in data.keys(): 60 | image = _parse_image(data['reference_image']) 61 | image_dict['reference_rgb'] = image 62 | image_mask_dict['reference_rgb'] = np.True_ 63 | 64 | add_prompt_info = None 65 | if 'condition' in data.keys(): 66 | if data['condition'] is None: 67 | image_dict['start_rgb'] = np.zeros_like(image_dict['0_rgb']) 68 | image_mask_dict['start_rgb'] = np.False_ 69 | else: 70 | image_dict['start_rgb'] = _parse_image(data['condition']['episode_start_image']) 71 | image_mask_dict['start_rgb'] = np.True_ 72 | add_prompt_info = '. Objects are located at ' + str(data['condition']['detect']) + '.' 
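                # A hypothetical example: a prompt "pick up the cup" would later be extended to
                # "pick up the cup. Objects are located at {...}." with the detect payload inlined.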
73 | 74 | inputs = { 75 | "state": state, 76 | "image": image_dict, 77 | "image_mask": image_mask_dict 78 | } 79 | 80 | # Actions are only available during training. 81 | if "actions" in data: 82 | # We are padding from 7 to the model action dim. 83 | # For pi0-FAST, this is a no-op (since action_dim = 7). 84 | actions = transforms.pad_to_dim(data["actions"], self.action_dim) 85 | inputs["actions"] = actions 86 | 87 | if "prompt" in data: 88 | inputs["prompt"] = data["prompt"] 89 | 90 | if add_prompt_info is not None: 91 | inputs["prompt"] += add_prompt_info 92 | 93 | if 'thought' in data.keys(): 94 | inputs['thought'] = data['thought'] 95 | inputs['act_with_outdated_thought'] = data['act_with_outdated_thought'] 96 | inputs['think_with_outdated_thought'] = data['think_with_outdated_thought'] 97 | 98 | return inputs 99 | 100 | 101 | @dataclasses.dataclass(frozen=True) 102 | class UMIOutputs(transforms.DataTransformFn): 103 | def __call__(self, data: dict) -> dict: 104 | # Only return the first 10 dims. 105 | data.update({"actions": np.asarray(data["actions"][:, :10])}) 106 | return data 107 | -------------------------------------------------------------------------------- /src/openpi/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/py.typed -------------------------------------------------------------------------------- /src/openpi/serving/websocket_policy_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import traceback 4 | import numpy as np 5 | 6 | from openpi_client import base_policy as _base_policy 7 | from openpi_client import msgpack_numpy 8 | import websockets.asyncio.server 9 | import websockets.frames 10 | 11 | 12 | class WebsocketPolicyServer: 13 | """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation. 14 | 15 | Currently only implements the `load` and `infer` methods. 
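
    A sketch of the matching client-side call (the class lives in
    openpi_client.websocket_client_policy; arguments are illustrative):

        policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
        result = policy.infer(obs)  # unpacked action dict returned by the server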
16 | """ 17 | 18 | def __init__( 19 | self, 20 | policy: _base_policy.BasePolicy, 21 | host: str = "0.0.0.0", 22 | port: int = 8000, 23 | metadata: dict | None = None, 24 | ) -> None: 25 | self._policy = policy 26 | self._host = host 27 | self._port = port 28 | self._metadata = metadata or {} 29 | logging.getLogger("websockets.server").setLevel(logging.INFO) 30 | 31 | def serve_forever(self) -> None: 32 | asyncio.run(self.run()) 33 | 34 | async def run(self): 35 | async with websockets.asyncio.server.serve( 36 | self._handler, 37 | self._host, 38 | self._port, 39 | compression=None, 40 | max_size=None, 41 | ) as server: 42 | await server.serve_forever() 43 | 44 | async def _handler(self, websocket: websockets.asyncio.server.ServerConnection): 45 | logging.info(f"Connection from {websocket.remote_address} opened") 46 | packer = msgpack_numpy.Packer() 47 | if hasattr(self._policy, 'start'): 48 | self._policy.start() 49 | 50 | await websocket.send(packer.pack(self._metadata)) 51 | 52 | while True: 53 | try: 54 | obs = msgpack_numpy.unpackb(await websocket.recv()) 55 | infer_task = asyncio.create_task(asyncio.to_thread(self._policy.infer, obs)) 56 | while True: 57 | if infer_task.done(): 58 | action = await infer_task 59 | await websocket.send(packer.pack(action)) 60 | break 61 | if getattr(self._policy, 'is_thinking', False): 62 | await websocket.send(packer.pack({"isthinking": np.True_})) 63 | _ = msgpack_numpy.unpackb(await websocket.recv()) 64 | await asyncio.sleep(0.02) 65 | except websockets.ConnectionClosed: 66 | logging.info(f"Connection from {websocket.remote_address} closed") 67 | break 68 | except Exception: 69 | await websocket.send(traceback.format_exc()) 70 | await websocket.close( 71 | code=websockets.frames.CloseCode.INTERNAL_ERROR, 72 | reason="Internal server error. Traceback included in previous frame.", 73 | ) 74 | raise 75 | -------------------------------------------------------------------------------- /src/openpi/shared/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/shared/__init__.py -------------------------------------------------------------------------------- /src/openpi/shared/array_typing.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools as ft 3 | import inspect 4 | from typing import TypeAlias, TypeVar, cast 5 | 6 | import beartype 7 | import jax 8 | import jax._src.tree_util as private_tree_util 9 | import jax.core 10 | from jaxtyping import Array # noqa: F401 11 | from jaxtyping import ArrayLike 12 | from jaxtyping import Bool # noqa: F401 13 | from jaxtyping import DTypeLike # noqa: F401 14 | from jaxtyping import Float 15 | from jaxtyping import Int # noqa: F401 16 | from jaxtyping import Key # noqa: F401 17 | from jaxtyping import Num # noqa: F401 18 | from jaxtyping import PyTree 19 | from jaxtyping import Real # noqa: F401 20 | from jaxtyping import UInt8 # noqa: F401 21 | from jaxtyping import config 22 | from jaxtyping import jaxtyped 23 | import jaxtyping._decorator 24 | 25 | # patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277. 26 | # the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`, 27 | # `jax.Sharding`, or even ) due to JAX tracing operations. 
this patch skips typechecking when the stack trace 28 | # contains `jax._src.tree_util`, which should only be the case during tree unflattening. 29 | _original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations # noqa: SLF001 30 | 31 | 32 | def _check_dataclass_annotations(self, typechecker): 33 | if not any( 34 | frame.frame.f_globals["__name__"] in {"jax._src.tree_util", "flax.nnx.transforms.compilation"} 35 | for frame in inspect.stack() 36 | ): 37 | return _original_check_dataclass_annotations(self, typechecker) 38 | return None 39 | 40 | 41 | jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations # noqa: SLF001 42 | 43 | KeyArrayLike: TypeAlias = jax.typing.ArrayLike 44 | Params: TypeAlias = PyTree[Float[ArrayLike, "..."]] 45 | 46 | T = TypeVar("T") 47 | 48 | 49 | # runtime type-checking decorator 50 | def typecheck(t: T) -> T: 51 | return cast(T, ft.partial(jaxtyped, typechecker=beartype.beartype)(t)) 52 | 53 | 54 | @contextlib.contextmanager 55 | def disable_typechecking(): 56 | initial = config.jaxtyping_disable 57 | config.update("jaxtyping_disable", True) # noqa: FBT003 58 | try: 59 | yield 60 | finally: 61 | # Restore the previous setting even if the wrapped block raises. 62 | config.update("jaxtyping_disable", initial) 63 | 64 | 65 | def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False): 66 | """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer 67 | error message than if `jax.tree.map` is naively used on PyTrees with different structures. 68 | """ 69 | 70 | if errors := list(private_tree_util.equality_errors(expected, got)): 71 | raise ValueError( 72 | "PyTrees have different structure:\n" 73 | + ( 74 | "\n".join( 75 | f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n" 76 | for path, thing1, thing2, explanation in errors 77 | ) 78 | ) 79 | ) 80 | 81 | if check_shapes or check_dtypes: 82 | 83 | def check(kp, x, y): 84 | if check_shapes and x.shape != y.shape: 85 | raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}") 86 | 87 | if check_dtypes and x.dtype != y.dtype: 88 | raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}") 89 | 90 | jax.tree_util.tree_map_with_path(check, expected, got) 91 | -------------------------------------------------------------------------------- /src/openpi/shared/download_test.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytest 4 | 5 | import openpi.shared.download as download 6 | 7 | 8 | @pytest.fixture(scope="session", autouse=True) 9 | def set_openpi_data_home(tmp_path_factory): 10 | temp_dir = tmp_path_factory.mktemp("openpi_data") 11 | with pytest.MonkeyPatch().context() as mp: 12 | mp.setenv("OPENPI_DATA_HOME", str(temp_dir)) 13 | yield 14 | 15 | 16 | def test_download_local(tmp_path: pathlib.Path): 17 | local_path = tmp_path / "local" 18 | local_path.touch() 19 | 20 | result = download.maybe_download(str(local_path)) 21 | assert result == local_path 22 | 23 | with pytest.raises(FileNotFoundError): 24 | download.maybe_download("bogus") 25 | 26 | 27 | def test_download_s3_dir(): 28 | remote_path = "s3://openpi-assets/testdata/random" 29 | 30 | local_path = download.maybe_download(remote_path) 31 | assert local_path.exists() 32 | 33 | new_local_path = download.maybe_download(remote_path) 34 | assert new_local_path == local_path 35 | 36 | 37 | def
test_download_s3(): 38 | remote_path = "s3://openpi-assets/testdata/random/random_512kb.bin" 39 | 40 | local_path = download.maybe_download(remote_path) 41 | assert local_path.exists() 42 | 43 | new_local_path = download.maybe_download(remote_path) 44 | assert new_local_path == local_path 45 | 46 | 47 | def test_download_fsspec(): 48 | remote_path = "gs://big_vision/paligemma_tokenizer.model" 49 | 50 | local_path = download.maybe_download(remote_path, gs={"token": "anon"}) 51 | assert local_path.exists() 52 | 53 | new_local_path = download.maybe_download(remote_path, gs={"token": "anon"}) 54 | assert new_local_path == local_path 55 | -------------------------------------------------------------------------------- /src/openpi/shared/image_tools.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import openpi.shared.array_typing as at 7 | 8 | 9 | @functools.partial(jax.jit, static_argnums=(1, 2, 3)) 10 | @at.typecheck 11 | def resize_with_pad( 12 | images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"], 13 | height: int, 14 | width: int, 15 | method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR, 16 | ) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]: 17 | """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion 18 | by padding with black. If the image is float32, it must be in the range [-1, 1]. 19 | """ 20 | has_batch_dim = images.ndim == 4 21 | if not has_batch_dim: 22 | images = images[None] # type: ignore 23 | cur_height, cur_width = images.shape[1:3] 24 | ratio = max(cur_width / width, cur_height / height) 25 | resized_height = int(cur_height / ratio) 26 | resized_width = int(cur_width / ratio) 27 | resized_images = jax.image.resize( 28 | images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method 29 | ) 30 | if images.dtype == jnp.uint8: 31 | # round from float back to uint8 32 | resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8) 33 | elif images.dtype == jnp.float32: 34 | resized_images = resized_images.clip(-1.0, 1.0) 35 | else: 36 | raise ValueError(f"Unsupported image dtype: {images.dtype}") 37 | 38 | pad_h0, remainder_h = divmod(height - resized_height, 2) 39 | pad_h1 = pad_h0 + remainder_h 40 | pad_w0, remainder_w = divmod(width - resized_width, 2) 41 | pad_w1 = pad_w0 + remainder_w 42 | padded_images = jnp.pad( 43 | resized_images, 44 | ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)), 45 | constant_values=0 if images.dtype == jnp.uint8 else -1.0, 46 | ) 47 | 48 | if not has_batch_dim: 49 | padded_images = padded_images[0] 50 | return padded_images 51 | -------------------------------------------------------------------------------- /src/openpi/shared/image_tools_test.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from openpi.shared import image_tools 4 | 5 | 6 | def test_resize_with_pad_shapes(): 7 | # Test case 1: Resize image with larger dimensions 8 | images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8) # Input images of shape (batch_size, height, width, channels) 9 | height = 20 10 | width = 20 11 | resized_images = image_tools.resize_with_pad(images, height, width) 12 | assert resized_images.shape == (2, height, width, 3) 13 | assert jnp.all(resized_images == 0) 14 | 15 | # Test case 2: Resize image with smaller 
dimensions 16 | images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8) 17 | height = 15 18 | width = 15 19 | resized_images = image_tools.resize_with_pad(images, height, width) 20 | assert resized_images.shape == (3, height, width, 3) 21 | assert jnp.all(resized_images == 0) 22 | 23 | # Test case 3: Resize image with the same dimensions 24 | images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8) 25 | height = 50 26 | width = 50 27 | resized_images = image_tools.resize_with_pad(images, height, width) 28 | assert resized_images.shape == (1, height, width, 3) 29 | assert jnp.all(resized_images == 0) 30 | 31 | # Test case 4: Resize image with odd-numbered padding 32 | images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8) 33 | height = 60 34 | width = 80 35 | resized_images = image_tools.resize_with_pad(images, height, width) 36 | assert resized_images.shape == (1, height, width, 3) 37 | assert jnp.all(resized_images == 0) 38 | -------------------------------------------------------------------------------- /src/openpi/shared/nnx_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import dataclasses 3 | import functools 4 | import inspect 5 | import re 6 | from typing import Any, ParamSpec, TypeVar 7 | 8 | import flax.nnx as nnx 9 | import jax 10 | 11 | P = ParamSpec("P") 12 | R = TypeVar("R") 13 | 14 | 15 | def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]: 16 | """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process. 17 | 18 | Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much 19 | more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module 20 | mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must 21 | traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details. 22 | 23 | `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by 24 | `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was 25 | when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded 26 | after the method call completes. 27 | """ 28 | if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)): 29 | raise ValueError("module_jit must only be used on bound methods of nnx.Modules.") 30 | 31 | graphdef, state = nnx.split(meth.__self__) 32 | 33 | def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R: 34 | module = nnx.merge(graphdef, state) 35 | return meth.__func__(module, *args, **kwargs) 36 | 37 | jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs) 38 | 39 | @functools.wraps(meth) 40 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 41 | return jitted_fn(state, *args, **kwargs) 42 | 43 | return wrapper 44 | 45 | 46 | @dataclasses.dataclass(frozen=True) 47 | class PathRegex: 48 | """NNX Filter that matches paths using a regex. 49 | 50 | By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument.
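    Example (editor's illustration): `PathRegex(r".*lora.*")` matches any parameter whose
    `/`-joined key path contains "lora" (cf. the LoRA merge regex in `weight_loaders.py`).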
51 | """ 52 | 53 | pattern: str | re.Pattern 54 | sep: str = "/" 55 | 56 | def __post_init__(self): 57 | if not isinstance(self.pattern, re.Pattern): 58 | object.__setattr__(self, "pattern", re.compile(self.pattern)) 59 | 60 | def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool: 61 | joined_path = self.sep.join(str(x) for x in path) 62 | assert isinstance(self.pattern, re.Pattern) 63 | return self.pattern.fullmatch(joined_path) is not None 64 | 65 | 66 | def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State: 67 | """Apply a function to the leaves of the state that match the filter.""" 68 | filtered_keys = set(state.filter(filter).flat_state()) 69 | return state.map(lambda k, v: fn(v) if k in filtered_keys else v) 70 | -------------------------------------------------------------------------------- /src/openpi/shared/normalize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | 4 | import numpy as np 5 | import numpydantic 6 | import pydantic 7 | 8 | 9 | @pydantic.dataclasses.dataclass 10 | class NormStats: 11 | mean: numpydantic.NDArray 12 | std: numpydantic.NDArray 13 | q01: numpydantic.NDArray | None = None # 1st quantile 14 | q99: numpydantic.NDArray | None = None # 99th quantile 15 | 16 | 17 | class RunningStats: 18 | """Compute running statistics of a batch of vectors.""" 19 | 20 | def __init__(self): 21 | self._count = 0 22 | self._mean = None 23 | self._mean_of_squares = None 24 | self._min = None 25 | self._max = None 26 | self._histograms = None 27 | self._bin_edges = None 28 | self._num_quantile_bins = 5000 # for computing quantiles on the fly 29 | 30 | def update(self, batch: np.ndarray) -> None: 31 | """ 32 | Update the running statistics with a batch of vectors. 33 | 34 | Args: 35 | vectors (np.ndarray): A 2D array where each row is a new vector. 36 | """ 37 | if batch.ndim == 1: 38 | batch = batch.reshape(-1, 1) 39 | num_elements, vector_length = batch.shape 40 | if self._count == 0: 41 | self._mean = np.mean(batch, axis=0) 42 | self._mean_of_squares = np.mean(batch**2, axis=0) 43 | self._min = np.min(batch, axis=0) 44 | self._max = np.max(batch, axis=0) 45 | self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] 46 | self._bin_edges = [ 47 | np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) 48 | for i in range(vector_length) 49 | ] 50 | else: 51 | if vector_length != self._mean.size: 52 | raise ValueError("The length of new vectors does not match the initialized vector length.") 53 | new_max = np.max(batch, axis=0) 54 | new_min = np.min(batch, axis=0) 55 | max_changed = np.any(new_max > self._max) 56 | min_changed = np.any(new_min < self._min) 57 | self._max = np.maximum(self._max, new_max) 58 | self._min = np.minimum(self._min, new_min) 59 | 60 | if max_changed or min_changed: 61 | self._adjust_histograms() 62 | 63 | self._count += num_elements 64 | 65 | batch_mean = np.mean(batch, axis=0) 66 | batch_mean_of_squares = np.mean(batch**2, axis=0) 67 | 68 | # Update running mean and mean of squares. 69 | self._mean += (batch_mean - self._mean) * (num_elements / self._count) 70 | self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count) 71 | 72 | self._update_histograms(batch) 73 | 74 | def get_statistics(self) -> NormStats: 75 | """ 76 | Compute and return the statistics of the vectors processed so far. 
77 | 78 | Returns: 79 | NormStats: The running mean, standard deviation, and 1st/99th quantiles. 80 | """ 81 | if self._count < 2: 82 | raise ValueError("Cannot compute statistics for fewer than 2 vectors.") 83 | 84 | variance = self._mean_of_squares - self._mean**2 85 | stddev = np.sqrt(np.maximum(0, variance)) 86 | q01, q99 = self._compute_quantiles([0.01, 0.99]) 87 | return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99) 88 | 89 | def _adjust_histograms(self): 90 | """Adjust histograms when min or max changes.""" 91 | for i in range(len(self._histograms)): 92 | old_edges = self._bin_edges[i] 93 | new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1) 94 | 95 | # Redistribute the existing histogram counts to the new bins 96 | new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i]) 97 | 98 | self._histograms[i] = new_hist 99 | self._bin_edges[i] = new_edges 100 | 101 | def _update_histograms(self, batch: np.ndarray) -> None: 102 | """Update histograms with new vectors.""" 103 | for i in range(batch.shape[1]): 104 | hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) 105 | self._histograms[i] += hist 106 | 107 | def _compute_quantiles(self, quantiles): 108 | """Compute quantiles based on histograms.""" 109 | results = [] 110 | for q in quantiles: 111 | target_count = q * self._count 112 | q_values = [] 113 | for hist, edges in zip(self._histograms, self._bin_edges, strict=True): 114 | cumsum = np.cumsum(hist) 115 | idx = np.searchsorted(cumsum, target_count) 116 | q_values.append(edges[idx]) 117 | results.append(np.array(q_values)) 118 | return results 119 | 120 | 121 | class _NormStatsDict(pydantic.BaseModel): 122 | norm_stats: dict[str, NormStats] 123 | 124 | 125 | def serialize_json(norm_stats: dict[str, NormStats]) -> str: 126 | """Serialize the running statistics to a JSON string.""" 127 | return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2) 128 | 129 | 130 | def deserialize_json(data: str) -> dict[str, NormStats]: 131 | """Deserialize the running statistics from a JSON string.""" 132 | return _NormStatsDict(**json.loads(data)).norm_stats 133 | 134 | 135 | def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None: 136 | """Save the normalization stats to a directory.""" 137 | path = pathlib.Path(directory) / "norm_stats.json" 138 | path.parent.mkdir(parents=True, exist_ok=True) 139 | path.write_text(serialize_json(norm_stats)) 140 | 141 | 142 | def load(directory: pathlib.Path | str) -> dict[str, NormStats]: 143 | """Load the normalization stats from a directory.""" 144 | path = pathlib.Path(directory) / "norm_stats.json" 145 | if not path.exists(): 146 | raise FileNotFoundError(f"Norm stats file not found at: {path}") 147 | return deserialize_json(path.read_text()) 148 | -------------------------------------------------------------------------------- /src/openpi/shared/normalize_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import openpi.shared.normalize as normalize 4 | 5 | 6 | def test_normalize_update(): 7 | arr = np.arange(12) 8 | 9 | stats = normalize.RunningStats() 10 | for i in range(0, len(arr), 3): 11 | stats.update(arr[i : i + 3]) 12 | results = stats.get_statistics() 13 | 14 | assert np.allclose(results.mean, np.mean(arr)) 15 | assert np.allclose(results.std, np.std(arr)) 16 | 17 | 18 | def test_serialize_deserialize(): 19 | stats = normalize.RunningStats() 20 | stats.update(np.arange(12)) 21 | 22 |
norm_stats = {"test": stats.get_statistics()} 23 | norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats)) 24 | assert np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean) 25 | assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std) 26 | -------------------------------------------------------------------------------- /src/openpi/training/checkpoints.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures as futures 2 | import dataclasses 3 | import logging 4 | from typing import Protocol 5 | 6 | from etils import epath 7 | import jax 8 | import orbax.checkpoint as ocp 9 | 10 | from openpi.shared import array_typing as at 11 | import openpi.shared.normalize as _normalize 12 | import openpi.training.data_loader as _data_loader 13 | import openpi.training.utils as training_utils 14 | 15 | 16 | def initialize_checkpoint_dir( 17 | checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool 18 | ) -> tuple[ocp.CheckpointManager, bool]: 19 | checkpoint_dir = epath.Path(checkpoint_dir).resolve() 20 | resuming = False 21 | if checkpoint_dir.exists(): 22 | if overwrite: 23 | checkpoint_dir.rmtree() 24 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 25 | logging.info(f"Wiped checkpoint directory {checkpoint_dir}") 26 | elif resume: 27 | resuming = True 28 | else: 29 | raise FileExistsError( 30 | f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume " 31 | "to indicate how to handle it." 32 | ) 33 | 34 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 35 | 36 | mngr = ocp.CheckpointManager( 37 | checkpoint_dir, 38 | item_handlers={ 39 | "assets": CallbackHandler(), 40 | "train_state": ocp.PyTreeCheckpointHandler(), 41 | "params": ocp.PyTreeCheckpointHandler(), 42 | }, 43 | options=ocp.CheckpointManagerOptions( 44 | max_to_keep=1, 45 | keep_period=keep_period, 46 | create=False, 47 | async_options=ocp.AsyncOptions(timeout_secs=7200), 48 | ), 49 | ) 50 | 51 | # special case: the checkpoint directory exists and the user requests to resume training, but the training run did 52 | # not get to the first checkpoint saved. in this case, we don't actually want the train script to try and restore a 53 | # checkpoint, since it will fail. 54 | if resuming and tuple(mngr.all_steps()) in [(), (0,)]: 55 | logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.") 56 | resuming = False 57 | 58 | return mngr, resuming 59 | 60 | 61 | def save_state( 62 | checkpoint_manager: ocp.CheckpointManager, 63 | state: training_utils.TrainState, 64 | data_loader: _data_loader.DataLoader, 65 | step: int, 66 | ): 67 | def save_assets(directory: epath.Path): 68 | # Save the normalization stats. 69 | data_config = data_loader.data_config() 70 | norm_stats = data_config.norm_stats 71 | if norm_stats is not None and data_config.asset_id is not None: 72 | _normalize.save(directory / data_config.asset_id, norm_stats) 73 | 74 | # Split params that can be used for inference into a separate item. 
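# (When EMA parameters are present, they are what gets saved under "params" rather than the raw training params; see `_split_params` at the bottom of this file.)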
75 | with at.disable_typechecking(): 76 | train_state, params = _split_params(state) 77 | items = { 78 | "assets": save_assets, 79 | "train_state": train_state, 80 | "params": {"params": params}, 81 | } 82 | checkpoint_manager.save(step, items) 83 | 84 | 85 | def restore_state( 86 | checkpoint_manager: ocp.CheckpointManager, 87 | state: training_utils.TrainState, 88 | data_loader: _data_loader.DataLoader, 89 | step: int | None = None, 90 | ) -> training_utils.TrainState: 91 | del data_loader 92 | 93 | with at.disable_typechecking(): 94 | # Split params that can be used for inference into a separate item. 95 | train_state, params = _split_params(state) 96 | restored = checkpoint_manager.restore( 97 | step, 98 | items={ 99 | "train_state": train_state, 100 | "params": {"params": params}, 101 | }, 102 | ) 103 | return _merge_params(restored["train_state"], restored["params"]) 104 | 105 | 106 | def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None: 107 | norm_stats_dir = epath.Path(assets_dir) / asset_id 108 | norm_stats = _normalize.load(norm_stats_dir) 109 | logging.info(f"Loaded norm stats from {norm_stats_dir}") 110 | return norm_stats 111 | 112 | 113 | class Callback(Protocol): 114 | def __call__(self, directory: epath.Path) -> None: ... 115 | 116 | 117 | class CallbackHandler(ocp.AsyncCheckpointHandler): 118 | """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring.""" 119 | 120 | def __init__(self): 121 | self._executor = futures.ThreadPoolExecutor(max_workers=1) 122 | 123 | def close(self): 124 | self._executor.shutdown() 125 | 126 | def save(self, directory: epath.Path, args: "CallbackSave"): 127 | if jax.process_index() == 0: 128 | args.callback(directory) 129 | 130 | async def async_save(self, directory: epath.Path, args: "CallbackSave") -> list[futures.Future]: 131 | return [self._executor.submit(self.save, directory, args)] 132 | 133 | def restore(self, *args, **kwargs): 134 | raise NotImplementedError("CallbackHandler does not support restore") 135 | 136 | 137 | @ocp.args.register_with_handler(CallbackHandler, for_save=True) 138 | @dataclasses.dataclass 139 | class CallbackSave(ocp.args.CheckpointArgs): 140 | callback: Callback 141 | 142 | 143 | @ocp.args.register_with_handler(CallbackHandler, for_restore=True) 144 | class CallbackRestore(ocp.args.CheckpointArgs): ... 145 | 146 | 147 | def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]: 148 | if state.ema_params is not None: 149 | params = state.ema_params 150 | train_state = dataclasses.replace(state, ema_params=None) 151 | else: 152 | params = state.params 153 | train_state = dataclasses.replace(state, params={}) 154 | return train_state, params 155 | 156 | 157 | def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState: 158 | # Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split. 
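# (A non-empty `train_state.params` can only come from the EMA branch of `_split_params`, since the non-EMA branch replaces `params` with an empty dict.)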
159 | if train_state.params: 160 | return dataclasses.replace(train_state, ema_params=params["params"]) 161 | return dataclasses.replace(train_state, params=params["params"]) 162 | -------------------------------------------------------------------------------- /src/openpi/training/data_loader_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import jax 4 | 5 | from openpi.models import pi0 6 | from openpi.training import config as _config 7 | from openpi.training import data_loader as _data_loader 8 | 9 | 10 | def test_torch_data_loader(): 11 | config = pi0.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) 12 | dataset = _data_loader.FakeDataset(config, 16) 13 | 14 | loader = _data_loader.TorchDataLoader( 15 | dataset, 16 | local_batch_size=4, 17 | num_batches=2, 18 | ) 19 | batches = list(loader) 20 | 21 | assert len(batches) == 2 22 | for batch in batches: 23 | assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) 24 | 25 | 26 | def test_torch_data_loader_infinite(): 27 | config = pi0.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) 28 | dataset = _data_loader.FakeDataset(config, 4) 29 | 30 | loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4) 31 | data_iter = iter(loader) 32 | 33 | for _ in range(10): 34 | _ = next(data_iter) 35 | 36 | 37 | def test_torch_data_loader_parallel(): 38 | config = pi0.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) 39 | dataset = _data_loader.FakeDataset(config, 10) 40 | 41 | loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2) 42 | batches = list(loader) 43 | 44 | assert len(batches) == 2 45 | 46 | for batch in batches: 47 | assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) 48 | 49 | 50 | def test_with_fake_dataset(): 51 | config = _config.get_config("debug") 52 | 53 | loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2) 54 | batches = list(loader) 55 | 56 | assert len(batches) == 2 57 | 58 | for batch in batches: 59 | assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch)) 60 | 61 | for _, actions in batches: 62 | assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) 63 | 64 | 65 | def test_with_real_dataset(): 66 | config = _config.get_config("pi0_aloha_sim") 67 | config = dataclasses.replace(config, batch_size=4) 68 | 69 | loader = _data_loader.create_data_loader( 70 | config, 71 | # Skip since we may not have the data available. 72 | skip_norm_stats=True, 73 | num_batches=2, 74 | shuffle=True, 75 | ) 76 | # Make sure that we can get the data config. 77 | assert loader.data_config().repo_id == config.data.repo_id 78 | 79 | batches = list(loader) 80 | 81 | assert len(batches) == 2 82 | 83 | for _, actions in batches: 84 | assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) 85 | -------------------------------------------------------------------------------- /src/openpi/training/optimizer.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Protocol, runtime_checkable 3 | 4 | import jax.numpy as jnp 5 | import optax 6 | 7 | import openpi.shared.array_typing as at 8 | 9 | 10 | @runtime_checkable 11 | class LRScheduleConfig(Protocol): 12 | def create(self) -> optax.Schedule: ... 
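# Editor's note: a minimal usage sketch (illustrative values, not repo defaults):
#
#   schedule = CosineDecaySchedule(warmup_steps=100, peak_lr=1e-4, decay_steps=10_000).create()
#   lr_at_step = schedule(500)  # scalar learning rate at optimization step 500
#
# Any config satisfying this protocol plugs into `create_optimizer` at the bottom of this file.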
13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class CosineDecaySchedule(LRScheduleConfig): 17 | """Cosine decay schedule with warmup.""" 18 | 19 | warmup_steps: int = 1_000 20 | peak_lr: float = 2.5e-5 21 | decay_steps: int = 30_000 22 | decay_lr: float = 2.5e-6 23 | 24 | def create(self) -> optax.Schedule: 25 | return optax.warmup_cosine_decay_schedule( 26 | init_value=self.peak_lr / (self.warmup_steps + 1), 27 | peak_value=self.peak_lr, 28 | warmup_steps=self.warmup_steps, 29 | decay_steps=self.decay_steps, 30 | end_value=self.decay_lr, 31 | ) 32 | 33 | 34 | @dataclasses.dataclass(frozen=True) 35 | class RsqrtDecaySchedule(LRScheduleConfig): 36 | """Inverse square root decay schedule with warmup.""" 37 | 38 | warmup_steps: int = 1_000 39 | peak_lr: float = 5e-5 40 | timescale: float = 10_000 41 | 42 | def create(self) -> optax.Schedule: 43 | return optax.join_schedules( 44 | [ 45 | optax.linear_schedule( 46 | init_value=self.peak_lr / (self.warmup_steps + 1), 47 | end_value=self.peak_lr, 48 | transition_steps=self.warmup_steps, 49 | ), 50 | lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale), 51 | ], 52 | [self.warmup_steps], 53 | ) 54 | 55 | 56 | @runtime_checkable 57 | class OptimizerConfig(Protocol): 58 | def create( 59 | self, 60 | lr: optax.ScalarOrSchedule, 61 | weight_decay_mask: at.PyTree | None = None, 62 | ) -> optax.GradientTransformation: ... 63 | 64 | 65 | @dataclasses.dataclass(frozen=True) 66 | class AdamW(OptimizerConfig): 67 | """AdamW optimizer.""" 68 | 69 | b1: float = 0.9 70 | b2: float = 0.95 71 | eps: float = 1e-8 72 | weight_decay: float = 1e-10 73 | clip_gradient_norm: float = 1.0 74 | 75 | def create( 76 | self, 77 | lr: optax.ScalarOrSchedule, 78 | weight_decay_mask: at.PyTree | None = None, 79 | ) -> optax.GradientTransformation: 80 | tx = optax.adamw( 81 | lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask 82 | ) 83 | 84 | return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx) 85 | 86 | 87 | @dataclasses.dataclass(frozen=True) 88 | class SGD(OptimizerConfig): 89 | """SGD optimizer.""" 90 | 91 | lr: float = 5e-5 92 | momentum: float = 0.9 93 | nesterov: bool = False 94 | 95 | def create( 96 | self, 97 | lr: optax.ScalarOrSchedule, 98 | weight_decay_mask: at.PyTree | None = None, 99 | ) -> optax.GradientTransformation: 100 | assert weight_decay_mask is None, "Weight decay is not supported for SGD" 101 | return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov) 102 | 103 | 104 | def create_optimizer( 105 | optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None 106 | ) -> optax.GradientTransformation: 107 | lr = lr_schedule.create() 108 | return optimizer.create(lr, weight_decay_mask=weight_decay_mask) 109 | -------------------------------------------------------------------------------- /src/openpi/training/sharding.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | 4 | import jax 5 | import numpy as np 6 | 7 | BATCH_AXIS = "batch" 8 | FSDP_AXIS = "fsdp" 9 | # In FSDP, we shard the data across both the batch and FSDP axes. 
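# Editor's note: e.g. with 8 devices and num_fsdp_devices=4, `make_mesh` below builds a (2, 4) mesh,
# so data sharded over DATA_AXIS is split across all 8 devices.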
10 | DATA_AXIS = (BATCH_AXIS, FSDP_AXIS) 11 | 12 | 13 | class _MeshState: 14 | active_mesh: jax.sharding.Mesh | None = None 15 | 16 | 17 | def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh: 18 | if jax.device_count() % num_fsdp_devices != 0: 19 | raise ValueError( 20 | f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}." 21 | ) 22 | mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices) 23 | return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS)) 24 | 25 | 26 | @contextlib.contextmanager 27 | def set_mesh(mesh: jax.sharding.Mesh): 28 | """Plumbing the mesh deep into the module tree is extremely cumbersome; until the JAX team lands a better API, a 29 | custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used 30 | in `activation_sharding_constraint` below.""" 31 | if _MeshState.active_mesh is not None: 32 | raise ValueError("Cannot nest set_mesh context managers.") 33 | _MeshState.active_mesh = mesh 34 | try: 35 | yield 36 | finally: 37 | _MeshState.active_mesh = None 38 | 39 | 40 | def activation_sharding_constraint(pytree): 41 | if _MeshState.active_mesh is None: 42 | return pytree 43 | return jax.lax.with_sharding_constraint( 44 | pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS)) 45 | ) 46 | 47 | 48 | def fsdp_sharding( 49 | pytree, 50 | mesh: jax.sharding.Mesh, 51 | *, 52 | min_size_mbytes: int = 4, # 4 MiB 53 | log: bool = False, 54 | ): 55 | """Apply FSDP sharding to a pytree of arrays based on the mesh shape. 56 | 57 | Args: 58 | pytree: The pytree to shard according to the mesh; only array-like leaves (i.e., objects with a `.shape` 59 | attribute) are considered for sharding. 60 | mesh: The mesh to shard the pytree over. 61 | min_size_mbytes: The minimum size of an array in MiB for it to be considered for sharding; any array smaller 62 | than this will be replicated. 63 | log: If true, logs the sharding decision for every array that is considered for sharding. 64 | 65 | Returns: 66 | A pytree of `jax.sharding.NamedSharding` objects matching the structure of the input.
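    Example (editor's sketch; `params` stands for any pytree of arrays or `jax.ShapeDtypeStruct`s):

        mesh = make_mesh(num_fsdp_devices=4)
        shardings = fsdp_sharding(params, mesh, log=True)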
67 | """ 68 | min_size_bytes = min_size_mbytes * 2**20 69 | 70 | def _shard_arr(kp, array: jax.ShapeDtypeStruct): 71 | # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging 72 | if mesh.shape[FSDP_AXIS] == 1: 73 | return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 74 | # replicate scalar and vector arrays 75 | if not hasattr(array, "shape"): 76 | return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 77 | if len(array.shape) < 2: 78 | return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 79 | # replicate small arrays 80 | if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes: 81 | return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 82 | 83 | # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension 84 | axes = np.argsort(array.shape)[::-1] 85 | spec = [None] * len(axes) 86 | for i in axes: 87 | if array.shape[i] % mesh.shape[FSDP_AXIS] == 0: 88 | if log: 89 | logging.info( 90 | f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}" 91 | ) 92 | spec[i] = FSDP_AXIS 93 | return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec)) 94 | 95 | # replicate if no valid sharding was found 96 | if log: 97 | logging.warning( 98 | f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}" 99 | ) 100 | return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 101 | 102 | return jax.tree_util.tree_map_with_path(_shard_arr, pytree) 103 | -------------------------------------------------------------------------------- /src/openpi/training/utils.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | from collections.abc import Callable 3 | from typing import Any 4 | 5 | import jax 6 | import optax 7 | import sentencepiece 8 | from flax import nnx, struct 9 | 10 | import openpi.shared.download as download 11 | from openpi.models import model as _model 12 | from openpi.shared import array_typing as at 13 | 14 | 15 | @at.typecheck 16 | @struct.dataclass 17 | class TrainState: 18 | step: at.Int[at.ArrayLike, ""] 19 | params: nnx.State 20 | model_def: nnx.GraphDef[_model.BaseModel] 21 | opt_state: optax.OptState 22 | tx: optax.GradientTransformation = struct.field(pytree_node=False) 23 | 24 | ema_decay: float | None = struct.field(pytree_node=False) 25 | ema_params: nnx.State | None = None 26 | 27 | 28 | @at.typecheck 29 | def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str: 30 | """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert 31 | the leaf values to more meaningful strings. 
32 | """ 33 | tree, _ = jax.tree_util.tree_flatten_with_path(tree) 34 | return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree) 35 | 36 | 37 | @at.typecheck 38 | def array_tree_to_info(tree: at.PyTree) -> str: 39 | """Converts a PyTree of arrays into a human-readable string for logging.""" 40 | return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}") 41 | 42 | 43 | def inspect_prompts( 44 | batch: tuple[_model.FuseObservation | _model.Observation, at.Array], 45 | ) -> str: 46 | """Converts a PyTree of prompts into a human-readable string for logging.""" 47 | tokenized_prompt = batch[0].tokenized_prompt 48 | path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) 49 | with path.open("rb") as f: 50 | tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) 51 | 52 | prompts = [] 53 | for tokenized in tokenized_prompt: 54 | tokens = tokenizer.decode(tokenized.tolist(),) 55 | prompts.append(tokens) 56 | print("prompts:") 57 | pprint.pprint(prompts) 58 | -------------------------------------------------------------------------------- /src/openpi/training/weight_loaders.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import re 4 | from typing import Protocol, runtime_checkable 5 | 6 | import flax.traverse_util 7 | import numpy as np 8 | 9 | import openpi.models.model as _model 10 | import openpi.shared.array_typing as at 11 | import openpi.shared.download as download 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @runtime_checkable 17 | class WeightLoader(Protocol): 18 | def load(self, params: at.Params) -> at.Params: 19 | """Loads the model weights. 20 | 21 | Args: 22 | params: Parameters of the model. This is a nested structure of array-like objects that 23 | represent the model's parameters. 24 | 25 | Returns: 26 | Loaded parameters. The structure must be identical to `params`. If returning a subset of 27 | the parameters the loader must merge the loaded parameters with `params`. 28 | """ 29 | 30 | 31 | @dataclasses.dataclass(frozen=True) 32 | class NoOpWeightLoader(WeightLoader): 33 | def load(self, params: at.Params) -> at.Params: 34 | return params 35 | 36 | 37 | @dataclasses.dataclass(frozen=True) 38 | class CheckpointWeightLoader(WeightLoader): 39 | """Loads an entire set of weights from a checkpoint. 40 | 41 | Compatible with: 42 | trained checkpoints: 43 | example: "./checkpoints////params" 44 | released checkpoints: 45 | example: "s3://openpi-assets/checkpoints//params" 46 | """ 47 | 48 | params_path: str 49 | 50 | def load(self, params: at.Params) -> at.Params: 51 | # We are loading np.ndarray and relying on the training code to properly convert and shard the params. 52 | loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray) 53 | # Add all missing LoRA weights. 54 | return _merge_params(loaded_params, params, missing_regex=".*lora.*") 55 | 56 | 57 | @dataclasses.dataclass(frozen=True) 58 | class PaliGemmaWeightLoader(WeightLoader): 59 | """Loads weights from the official PaliGemma checkpoint. 60 | 61 | This will overwrite existing weights with similar names while keeping all extra weights intact. 62 | This allows us to support the action expert which is used by the Pi0 model. 
63 | """ 64 | 65 | def load(self, params: at.Params) -> at.Params: 66 | path = download.maybe_download( 67 | "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz", gs={"token": "anon"} 68 | ) 69 | with path.open("rb") as f: 70 | flat_params = dict(np.load(f, allow_pickle=False)) 71 | loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]} 72 | # Add all missing weights. 73 | return _merge_params(loaded_params, params, missing_regex=".*") 74 | 75 | 76 | def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params: 77 | """Merges the loaded parameters with the reference parameters. 78 | 79 | Args: 80 | loaded_params: The parameters to merge. 81 | params: The reference parameters. 82 | missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters. 83 | 84 | Returns: 85 | A new dictionary with the merged parameters. 86 | """ 87 | flat_ref = flax.traverse_util.flatten_dict(params, sep="/") 88 | flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/") 89 | 90 | # First, take all weights that are a subset of the reference weights. 91 | result = {} 92 | for k, v in flat_loaded.items(): 93 | if k in flat_ref: 94 | result[k] = v.astype(flat_ref[k].dtype) 95 | 96 | # Then, merge any missing weights as defined by the missing regex. 97 | pattern = re.compile(missing_regex) 98 | for k in {k for k in flat_ref if pattern.fullmatch(k)}: 99 | if k not in result: 100 | result[k] = flat_ref[k] 101 | 102 | return flax.traverse_util.unflatten_dict(result, sep="/") 103 | -------------------------------------------------------------------------------- /src/openpi/transforms_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import openpi.models.tokenizer as _tokenizer 5 | import openpi.transforms as _transforms 6 | 7 | 8 | def test_repack_transform(): 9 | transform = _transforms.RepackTransform( 10 | structure={ 11 | "a": {"b": "b/c"}, 12 | "d": "e/f", 13 | } 14 | ) 15 | item = {"b": {"c": 1}, "e": {"f": 2}} 16 | assert transform(item) == {"a": {"b": 1}, "d": 2} 17 | 18 | 19 | def test_delta_actions(): 20 | item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} 21 | 22 | transform = _transforms.DeltaActions(mask=[False, True]) 23 | transformed = transform(item) 24 | 25 | assert np.all(transformed["state"] == np.array([1, 2, 3])) 26 | assert np.all(transformed["actions"] == np.array([[3, 2, 5], [5, 4, 7]])) 27 | 28 | 29 | def test_delta_actions_noop(): 30 | item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} 31 | 32 | # No-op when the mask is disabled. 33 | transform = _transforms.DeltaActions(mask=None) 34 | assert transform(item) is item 35 | 36 | # No-op when there are no actions in the input. 
37 | del item["actions"] 38 | transform = _transforms.DeltaActions(mask=[True, False]) 39 | assert transform(item) is item 40 | 41 | 42 | def test_absolute_actions(): 43 | item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} 44 | 45 | transform = _transforms.AbsoluteActions(mask=[False, True]) 46 | transformed = transform(item) 47 | 48 | assert np.all(transformed["state"] == np.array([1, 2, 3])) 49 | assert np.all(transformed["actions"] == np.array([[3, 6, 5], [5, 8, 7]])) 50 | 51 | 52 | def test_absolute_actions_noop(): 53 | item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} 54 | 55 | # No-op when the mask is disabled. 56 | transform = _transforms.AbsoluteActions(mask=None) 57 | assert transform(item) is item 58 | 59 | # No-op when there are no actions in the input. 60 | del item["actions"] 61 | transform = _transforms.AbsoluteActions(mask=[True, False]) 62 | assert transform(item) is item 63 | 64 | 65 | def test_make_bool_mask(): 66 | assert _transforms.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True) 67 | assert _transforms.make_bool_mask(2, 0, 2) == (True, True, True, True) 68 | 69 | 70 | def test_tokenize_prompt(): 71 | tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12) 72 | transform = _transforms.TokenizePrompt(tokenizer) 73 | 74 | data = transform({"prompt": "Hello, world!"}) 75 | 76 | tok_prompt, tok_mask = tokenizer.tokenize("Hello, world!") 77 | assert np.allclose(tok_prompt, data["tokenized_prompt"]) 78 | assert np.allclose(tok_mask, data["tokenized_prompt_mask"]) 79 | 80 | 81 | def test_tokenize_no_prompt(): 82 | transform = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer()) 83 | 84 | with pytest.raises(ValueError, match="Prompt is required"): 85 | transform({}) 86 | 87 | 88 | def test_transform_dict(): 89 | # Rename and remove keys. 90 | input = {"a": {"b": 1, "c": 2}} 91 | output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, input) 92 | assert output == {"a": {"c": 1}} 93 | 94 | # Raises an error since the renamed key conflicts with an existing key. 95 | with pytest.raises(ValueError, match="Key 'a/c' already exists in output"): 96 | _transforms.transform_dict({"a/b": "a/c"}, input) 97 | 98 | # Full match against the flattened key is required, so nothing is removed here. 99 | input = {"a": {"b": 1, "c": 2}} 100 | output = _transforms.transform_dict({"a": None}, input) 101 | assert output == input 102 | 103 | # The regex matches the entire key, so the entire input is removed. 104 | input = {"a": {"b": 1, "c": 2}} 105 | output = _transforms.transform_dict({"a.+": None}, input) 106 | assert output == {} 107 | 108 | # Replace keys using backreferences. All leaves named 'c' are replaced with 'd'. 109 | input = {"a": {"b": 1, "c": 1}, "b": {"c": 2}} 110 | output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, input) 111 | assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}} 112 | 113 | 114 | def test_extract_prompt_from_task(): 115 | transform = _transforms.PromptFromLeRobotTask({1: "Hello, world!"}) 116 | 117 | data = transform({"task_index": 1}) 118 | assert data["prompt"] == "Hello, world!"
119 | 120 | with pytest.raises(ValueError, match="task_index=2 not found in task mapping"): 121 | transform({"task_index": 2}) 122 | -------------------------------------------------------------------------------- /train_scripts/train_onetwovla_cocktail.sh: -------------------------------------------------------------------------------- 1 | logging_time=$(date "+%d-%H.%M.%S") 2 | now_seconds="${logging_time: -8}" 3 | now_date=$(date "+%Y.%m.%d") 4 | 5 | num_devices=$(nvidia-smi --list-gpus | wc -l) 6 | single_batch_size=20 7 | batch_size=$((num_devices * single_batch_size)) 8 | echo batch_size $batch_size 9 | 10 | single_val_batch_size=12 11 | val_batch_size=$((num_devices * single_val_batch_size)) 12 | echo val_batch_size $val_batch_size 13 | 14 | # the json file can be downloaded from https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset/tree/main/cocktail 15 | # ensure the dataset's path is $LEROBOT_HOME/umi/cocktail 16 | reasoning_json_path=/cephfs/nairuiqian/codespace/pi-motion-generalization/assets/univla_cocktail/umi/cocktail/necessary/cot.json 17 | 18 | # normalization stats 19 | # this can only run on a single GPU. 20 | # this code only needs to run once. 21 | # CUDA_VISIBLE_DEVICES=0 uv run scripts/compute_norm_stats.py onetwovla_cocktail --exp-name=computing-norm \ 22 | # --create_train_val_split --val_ratio=0.05 \ 23 | # --reasoning_json_path $reasoning_json_path \ 24 | # --is_computing_norm_stats 25 | 26 | XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 uv run scripts/train.py onetwovla_cocktail --exp-name=${now_date}/${now_seconds}/onetwovla-cocktail --batch-size=$batch_size --val-batch-size=$val_batch_size \ 27 | --reasoning_json_path $reasoning_json_path 28 | -------------------------------------------------------------------------------- /train_scripts/train_onetwovla_visual_grounding.sh: -------------------------------------------------------------------------------- 1 | logging_time=$(date "+%d-%H.%M.%S") 2 | now_seconds="${logging_time: -8}" 3 | now_date=$(date "+%Y.%m.%d") 4 | 5 | num_devices=$(nvidia-smi --list-gpus | wc -l) 6 | single_batch_size=20 7 | batch_size=$((num_devices * single_batch_size)) 8 | echo batch_size $batch_size 9 | 10 | single_val_batch_size=12 11 | val_batch_size=$((num_devices * single_val_batch_size)) 12 | echo val_batch_size $val_batch_size 13 | 14 | # the json file can be downloaded from https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset/tree/main/wild_move_to 15 | # ensure the dataset's path is $LEROBOT_HOME/umi/wild_move_to 16 | reasoning_json_path=/path/to/your/cot.json 17 | 18 | # normalization stats 19 | # this can only run on a single GPU. 20 | # this code only needs to run once. 
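# (Editor's note: the batch sizes above scale with GPU count; e.g. on a 4-GPU node,
# batch_size = 4 * 20 = 80 and val_batch_size = 4 * 12 = 48.)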
21 | CUDA_VISIBLE_DEVICES=0 uv run scripts/compute_norm_stats.py onetwovla_visual_grounding --exp-name=computing-norm \ 22 | --create_train_val_split --val_ratio=0.05 \ 23 | --reasoning_json_path $reasoning_json_path \ 24 | --is_computing_norm_stats 25 | 26 | XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 uv run scripts/train.py onetwovla_visual_grounding --exp-name=${now_date}/${now_seconds}/one-two-vla-visual-grounding --batch-size=$batch_size --val-batch-size=$val_batch_size \ 27 | --reasoning_json_path $reasoning_json_path 28 | -------------------------------------------------------------------------------- /train_scripts/train_pi0_cocktail.sh: -------------------------------------------------------------------------------- 1 | logging_time=$(date "+%d-%H.%M.%S") 2 | now_seconds="${logging_time: -8}" 3 | now_date=$(date "+%Y.%m.%d") 4 | 5 | num_devices=$(nvidia-smi --list-gpus | wc -l) 6 | single_batch_size=20 7 | batch_size=$((num_devices * single_batch_size)) 8 | echo batch_size $batch_size 9 | 10 | single_val_batch_size=12 11 | val_batch_size=$((num_devices * single_val_batch_size)) 12 | echo val_batch_size $val_batch_size 13 | 14 | # ensure the dataset's path is $LEROBOT_HOME/umi/cocktail 15 | 16 | # normalization stats 17 | # this can only run on a single GPU. 18 | # this code only needs to run once. 19 | CUDA_VISIBLE_DEVICES=0 uv run scripts/compute_norm_stats.py pi0_cocktail --exp-name=computing-norm \ 20 | --create_train_val_split --val_ratio=0.05 \ 21 | --is_computing_norm_stats 22 | 23 | XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 uv run scripts/train.py pi0_cocktail --exp-name=${now_date}/${now_seconds}/pi0-cocktail --batch-size=$batch_size --val-batch-size=$val_batch_size 24 | -------------------------------------------------------------------------------- /train_scripts/train_pi0_visual_grounding.sh: -------------------------------------------------------------------------------- 1 | logging_time=$(date "+%d-%H.%M.%S") 2 | now_seconds="${logging_time: -8}" 3 | now_date=$(date "+%Y.%m.%d") 4 | 5 | num_devices=$(nvidia-smi --list-gpus | wc -l) 6 | single_batch_size=20 7 | batch_size=$((num_devices * single_batch_size)) 8 | echo batch_size $batch_size 9 | 10 | single_val_batch_size=12 11 | val_batch_size=$((num_devices * single_val_batch_size)) 12 | echo val_batch_size $val_batch_size 13 | 14 | # the json file can be downloaded from https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset/tree/main/wild_move_to_no_vl 15 | # ensure the dataset's path is $LEROBOT_HOME/umi/wild_move_to_no_vl 16 | reasoning_json_path=/path/to/your/cot.json 17 | 18 | # normalization stats 19 | # this can only run on a single GPU. 20 | # this code only needs to run once. 21 | CUDA_VISIBLE_DEVICES=0 uv run scripts/compute_norm_stats.py pi0_visual_grounding --exp-name=computing-norm \ 22 | --create_train_val_split --val_ratio=0.05 \ 23 | --reasoning_json_path $reasoning_json_path \ 24 | --is_computing_norm_stats 25 | 26 | XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 uv run scripts/train.py pi0_visual_grounding --exp-name=${now_date}/${now_seconds}/pi0-visual-grounding --batch-size=$batch_size --val-batch-size=$val_batch_size \ 27 | --reasoning_json_path $reasoning_json_path 28 | --------------------------------------------------------------------------------