├── .dockerignore
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── .python-version
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── docker.md
│   └── remote_inference.md
├── examples
│   ├── aloha_real
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── constants.py
│   │   ├── convert_aloha_data_to_lerobot.py
│   │   ├── env.py
│   │   ├── main.py
│   │   ├── real_env.py
│   │   ├── requirements.in
│   │   ├── requirements.txt
│   │   ├── robot_utils.py
│   │   └── video_display.py
│   ├── aloha_sim
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── env.py
│   │   ├── main.py
│   │   ├── requirements.in
│   │   ├── requirements.txt
│   │   └── saver.py
│   ├── droid
│   │   ├── README.md
│   │   └── main.py
│   ├── inference.ipynb
│   ├── libero
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── convert_libero_data_to_lerobot.py
│   │   ├── main.py
│   │   ├── requirements.in
│   │   └── requirements.txt
│   ├── policy_records.ipynb
│   ├── simple_client
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── main.py
│   │   ├── requirements.in
│   │   └── requirements.txt
│   └── umi
│       ├── README.md
│       ├── convert_umi_data_to_lerobot.py
│       ├── imagecodecs_numcodecs.py
│       └── umi_replay_buffer.py
├── figures
│   └── fisheye-aug.png
├── packages
│   └── openpi-client
│       ├── pyproject.toml
│       └── src
│           └── openpi_client
│               ├── __init__.py
│               ├── action_chunk_broker.py
│               ├── base_policy.py
│               ├── image_tools.py
│               ├── image_tools_test.py
│               ├── msgpack_numpy.py
│               ├── msgpack_numpy_test.py
│               ├── runtime
│               │   ├── agent.py
│               │   ├── agents
│               │   │   └── policy_agent.py
│               │   ├── environment.py
│               │   ├── runtime.py
│               │   └── subscriber.py
│               └── websocket_client_policy.py
├── pyproject.toml
├── scripts
│   ├── __init__.py
│   ├── augment_vl_data
│   │   ├── augment.py
│   │   ├── finger_mask.jpg
│   │   ├── gripper_lens_mask.jpg
│   │   ├── inpainted.jpg
│   │   └── lens.jpg
│   ├── compute_norm_stats.py
│   ├── docker
│   │   ├── compose.yml
│   │   ├── install_docker_ubuntu22.sh
│   │   ├── install_nvidia_container_toolkit.sh
│   │   └── serve_policy.Dockerfile
│   ├── serve_policy.py
│   ├── train.py
│   └── train_test.py
├── src
│   └── openpi
│       ├── __init__.py
│       ├── conftest.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── gemma.py
│       │   ├── gemma_fast.py
│       │   ├── lora.py
│       │   ├── lora_test.py
│       │   ├── model.py
│       │   ├── model_test.py
│       │   ├── pi0.py
│       │   ├── pi0_fast.py
│       │   ├── pi0_fuse.py
│       │   ├── pi0_test.py
│       │   ├── siglip.py
│       │   ├── tokenizer.py
│       │   ├── tokenizer_test.py
│       │   └── vit.py
│       ├── policies
│       │   ├── aloha_policy.py
│       │   ├── droid_policy.py
│       │   ├── libero_policy.py
│       │   ├── policy.py
│       │   ├── policy_config.py
│       │   ├── policy_test.py
│       │   ├── pose_repr_util.py
│       │   ├── pose_util.py
│       │   ├── umi_dataset.py
│       │   └── umi_policy.py
│       ├── py.typed
│       ├── serving
│       │   └── websocket_policy_server.py
│       ├── shared
│       │   ├── __init__.py
│       │   ├── array_typing.py
│       │   ├── download.py
│       │   ├── download_test.py
│       │   ├── image_tools.py
│       │   ├── image_tools_test.py
│       │   ├── nnx_utils.py
│       │   ├── normalize.py
│       │   └── normalize_test.py
│       ├── training
│       │   ├── checkpoints.py
│       │   ├── config.py
│       │   ├── data_loader.py
│       │   ├── data_loader_test.py
│       │   ├── optimizer.py
│       │   ├── sharding.py
│       │   ├── utils.py
│       │   └── weight_loaders.py
│       ├── transforms.py
│       └── transforms_test.py
├── train_scripts
│   ├── train_onetwovla_cocktail.sh
│   ├── train_onetwovla_visual_grounding.sh
│   ├── train_pi0_cocktail.sh
│   └── train_pi0_visual_grounding.sh
└── uv.lock
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
2 | checkpoints
3 | data
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data directories.
2 | assets/
3 | checkpoints/
4 | data/
5 | wandb/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 | pretrain_model
170 |
171 | test_scripts/
172 | google-cloud-sdk
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/aloha"]
2 | path = third_party/aloha
3 | url = git@github.com:Physical-Intelligence/aloha.git
4 | [submodule "third_party/libero"]
5 | path = third_party/libero
6 | url = git@github.com:Lifelong-Robot-Learning/LIBERO.git
7 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: third_party/
2 |
3 | repos:
4 | - repo: https://github.com/astral-sh/uv-pre-commit
5 | # uv version.
6 | rev: 0.5.14
7 | hooks:
8 | - id: uv-lock
9 | - repo: https://github.com/astral-sh/ruff-pre-commit
10 | # Ruff version.
11 | rev: v0.8.6
12 | hooks:
13 | # Run the linter.
14 | - id: ruff
15 | args: [--fix]
16 | - id: ruff-format
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to openpi
2 |
3 | We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.
4 |
5 | ## Issues and feature requests
6 |
7 | You are welcome to use the Github [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.
8 |
9 | If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on Github under Issues). If the issue has not yet been reported, please include this information when filing a Github issue:
10 |
11 | - Your OS type and version and the version of Python you are using
12 | - Code that allows us to reproduce your bug, including all dependencies
13 | - Traceback of any exception
14 | - Any other information that would help us, such as a screenshot
15 |
16 | In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.
17 |
18 | If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:
19 |
20 | - The motivation for the feature
21 | - A description of the problem you are trying to solve or your use case
22 | - Enough information for us to understand the nature of the request
23 | - Some information for how you intend to use it (this might help us in understanding the motivation!)
24 |
25 | We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!
26 |
27 | ## Submitting a pull request
28 |
29 | If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting work on your PR, to get a sense of whether we are likely to approve it once submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe a change would make the code harder to maintain, or if reviewing it is out of scope for us), so contacting us in advance is a good way to gauge whether your PR is likely to be merged into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend that every contribution consider the following:
30 |
31 | - Make sure that your PR has a clear title and description
32 | - Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
33 | - Make sure your PR passes all tests
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Ruiqian Nai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OneTwoVLA: A Unified Vision-Language-Action Model with Adaptive Reasoning
2 |
3 | [[Project Page]](https://one-two-vla.github.io/) [[Paper]](https://arxiv.org/abs/2505.11917) [[Processed Datasets]](https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset)
4 |
5 | [Fanqi Lin](https://fanqi-lin.github.io/)<sup>1,2,3,5</sup>\*,
6 | [Ruiqian Nai](https://richard-coder-nai.github.io/)<sup>1,2,3,5</sup>\*,
7 | [Yingdong Hu](https://yingdong-hu.github.io/)<sup>1,2,3</sup>\*,
8 | [Jiacheng You](https://scholar.google.com/citations?user=FiP-TVUAAAAJ)<sup>1,2,3</sup>,
9 | Junming Zhao<sup>1,4</sup>,
10 | [Yang Gao](https://yang-gao.weebly.com/)<sup>1,2,3,5</sup>
11 |
12 | <sup>1</sup>Tsinghua University,
13 | <sup>2</sup>Shanghai Qi Zhi Institute,
14 | <sup>3</sup>Shanghai AI Lab,
15 | <sup>4</sup>Fudan University,
16 | <sup>5</sup>Spirit AI
17 |
18 | \* indicates equal contributions
19 |
20 |
21 | ## 🛠️ Installation
22 |
23 | We manage Python dependencies with [uv](https://docs.astral.sh/uv/). If you haven't installed `uv`, please follow the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up.
24 |
25 | Run the following to set up the environment:
26 |
27 | ```bash
28 | GIT_LFS_SKIP_SMUDGE=1 uv sync
29 | GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
30 | ```
31 |
32 | > NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
33 |
34 | For more details, refer to the original [openpi repository](https://github.com/Physical-Intelligence/openpi.git).
35 |
36 | ## 🚀 Training OneTwoVLA
37 | Download the datasets and place them under `$LEROBOT_HOME/umi/`.
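
If you prefer to fetch the data programmatically, here is a minimal sketch using `huggingface_hub` (the fallback `LEROBOT_HOME` path below is an assumption; set the environment variable to match your LeRobot setup):

```python
import os
from pathlib import Path

from huggingface_hub import snapshot_download

# Assumption: fall back to the default LeRobot cache location if LEROBOT_HOME is unset.
lerobot_home = Path(os.environ.get("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()

# Download the processed datasets into $LEROBOT_HOME/umi/.
snapshot_download(
    repo_id="Richard-Nai/onetwovla-dataset",
    repo_type="dataset",
    local_dir=lerobot_home / "umi",
)
```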
38 |
39 | To train a OneTwoVLA model, run:
40 | ```bash
41 | bash train_scripts/train_<task>.sh
42 | ```
43 | Available tasks are:
44 | ```bash
45 | train_scripts
46 | |-- train_onetwovla_cocktail.sh
47 | |-- train_onetwovla_visual_grounding.sh
48 | |-- train_pi0_cocktail.sh
49 | |-- train_pi0_visual_grounding.sh
50 | ```
51 |
52 | ## 🦾 Real-World Deployment
53 | We run inference using a policy server and a hardware client. The instructions for running the policy server can be found at [examples/umi/README.md](examples/umi/README.md), and we provide the UMI hardware client code in this [repository](https://github.com/Fanqi-Lin/OneTwoVLA-UMI-Client).
54 |
55 | ## 📷 Data
56 | We provide access to the following datasets:
57 |
58 | - `Robot Datasets`: Datasets for the `cocktail` and `open-world visual grounding` tasks.
59 | - `Vision-Language Datasets`: Datasets containing synthetic images and annotated reasoning for the `open-world visual grounding` task.
60 |
61 | All datasets are hosted on Hugging Face. You can find them [here](https://huggingface.co/datasets/Richard-Nai/onetwovla-dataset).
62 |
63 | We provide code for converting UMI data format to LeRobot data format [here](examples/umi/convert_umi_data_to_lerobot.py).
64 |
65 | ### Synthetic Image Augmentation
66 |
67 | To make the synthetic images more closely resemble real robot observations, we randomly apply several augmentations, including random fisheye distortion and compositing a robot gripper with adaptive brightness adjustments. The implementation is available in [scripts/augment_vl_data/augment.py](scripts/augment_vl_data/augment.py).
68 |
69 | Here we show an example. From left to right: the original image, the image with fisheye distortion, the image with a robot gripper composited using adaptive brightness adjustment, and the image with both augmentations applied.
70 |
71 | ![Example of synthetic image augmentation](figures/fisheye-aug.png)
72 |
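For intuition, the following is a minimal sketch of a radial fisheye-style distortion. The camera model and `strength` value are illustrative assumptions, not the exact implementation in [scripts/augment_vl_data/augment.py](scripts/augment_vl_data/augment.py):

```python
import cv2
import numpy as np

def fisheye_distort(img: np.ndarray, strength: float = 0.4) -> np.ndarray:
    """Barrel-distort an image; a toy stand-in for the real augmentation."""
    h, w = img.shape[:2]
    ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
    # Normalize pixel coordinates to [-1, 1] around the image center.
    xn = (xs - w / 2) / (w / 2)
    yn = (ys - h / 2) / (h / 2)
    r2 = xn**2 + yn**2
    # Sample from farther out as the radius grows, compressing the periphery.
    factor = 1 + strength * r2
    map_x = xn * factor * (w / 2) + w / 2
    map_y = yn * factor * (h / 2) + h / 2
    return cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR)
```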
73 | ## 🙏 Acknowledgements
74 | We express our sincere gratitude to the developers of [openpi](https://github.com/Physical-Intelligence/openpi.git) for open-sourcing their code.
75 |
--------------------------------------------------------------------------------
/docs/docker.md:
--------------------------------------------------------------------------------
1 | ### Docker Setup
2 |
3 | All of the examples in this repo provide instructions for running them directly as well as with Docker. Although not required, we recommend the Docker option: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
4 |
5 | - Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
6 | - Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
7 | - To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
8 | - The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
9 | - Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
10 |
11 |
12 | If you are starting from scratch on an Ubuntu 22.04 host, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
13 |
14 | During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
--------------------------------------------------------------------------------
/docs/remote_inference.md:
--------------------------------------------------------------------------------
1 |
2 | # Running openpi models remotely
3 |
4 | We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
5 |
6 | ## Starting a remote policy server
7 |
8 | To start a remote policy server, you can simply run the following command:
9 |
10 | ```bash
11 | uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
12 | ```
13 |
14 | The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
15 |
16 | ```bash
17 | uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
18 | ```
19 |
20 | This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
21 |
22 | ## Querying the remote policy server from your robot code
23 |
24 | We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
25 |
26 | First, install the `openpi-client` package in your robot environment:
27 |
28 | ```bash
29 | cd $OPENPI_ROOT/packages/openpi-client
30 | pip install -e .
31 | ```
32 |
33 | Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
34 |
35 | ```python
36 | from openpi_client import websocket_client_policy
37 |
38 | policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000)
39 | action_chunk = policy_client.infer(example)["actions"]
40 | ```
41 |
42 | Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
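
For instance, a DROID-style `example` could look like the sketch below. The key names and shapes here are assumptions for illustration; take the authoritative set from the policy inputs of the checkpoint you are serving (see the simple client example linked above):

```python
import numpy as np

# Sketch only: key names and shapes are assumptions for illustration;
# consult the policy inputs for the checkpoint you are actually serving.
example = {
    "observation/exterior_image_1_left": np.zeros((224, 224, 3), dtype=np.uint8),
    "observation/wrist_image_left": np.zeros((224, 224, 3), dtype=np.uint8),
    "observation/joint_position": np.zeros((7,), dtype=np.float32),
    "observation/gripper_position": np.zeros((1,), dtype=np.float32),
    "prompt": "pick up the fork",
}
action_chunk = policy_client.infer(example)["actions"]
```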
43 |
--------------------------------------------------------------------------------
/examples/aloha_real/Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for the Aloha real environment.
2 |
3 | # Build the container:
4 | # docker build . -t aloha_real -f examples/aloha_real/Dockerfile
5 |
6 | # Run the container:
7 | # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
8 |
9 | FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
10 | SHELL ["/bin/bash", "-c"]
11 |
12 | ENV DEBIAN_FRONTEND=noninteractive
13 | RUN apt-get update && \
14 | apt-get install -y --no-install-recommends \
15 | cmake \
16 | curl \
17 | libffi-dev \
18 | python3-rosdep \
19 | python3-rosinstall \
20 | python3-rosinstall-generator \
21 | whiptail \
22 | git \
23 | wget \
24 | openssh-client \
25 | ros-noetic-cv-bridge \
26 | ros-noetic-usb-cam \
27 | ros-noetic-realsense2-camera \
28 | keyboard-configuration
29 |
30 | WORKDIR /root
31 | RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
32 | RUN chmod +x xsarm_amd64_install.sh
33 | RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
34 |
35 | COPY ./third_party/aloha /root/interbotix_ws/src/aloha
36 | RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
37 |
38 | # Install python 3.10 because this ROS image comes with 3.8
39 | RUN mkdir /python && \
40 | cd /python && \
41 | wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
42 | tar -zxvf Python-3.10.14.tgz && \
43 | cd Python-3.10.14 && \
44 | ls -lhR && \
45 | ./configure --enable-optimizations && \
46 | make install && \
47 | echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
48 | echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
49 | cd ~ && rm -rf /python && \
50 | rm -rf /var/lib/apt/lists/*
51 |
52 | COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
53 | ENV UV_HTTP_TIMEOUT=120
54 | ENV UV_LINK_MODE=copy
55 | COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
56 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
57 | RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
58 |
59 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
60 | WORKDIR /app
61 |
62 | # Create an entrypoint script to run the setup commands, followed by the command passed in.
63 | RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
64 | #!/bin/bash
65 | source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
66 | EOF
67 | RUN chmod +x /usr/local/bin/entrypoint.sh
68 |
69 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
70 | CMD ["python3", "/app/examples/aloha_real/main.py"]
71 |
--------------------------------------------------------------------------------
/examples/aloha_real/compose.yml:
--------------------------------------------------------------------------------
1 | # Run with:
2 | # docker compose -f examples/aloha_real/compose.yml up --build
3 | services:
4 | runtime:
5 | image: aloha_real
6 | depends_on:
7 | - aloha_ros_nodes
8 | - ros_master
9 | - openpi_server
10 | build:
11 | context: ../..
12 | dockerfile: examples/aloha_real/Dockerfile
13 | init: true
14 | tty: true
15 | network_mode: host
16 | privileged: true
17 | volumes:
18 | - $PWD:/app
19 | - ../../data:/data
20 |
21 | aloha_ros_nodes:
22 | image: aloha_real
23 | depends_on:
24 | - ros_master
25 | build:
26 | context: ../..
27 | dockerfile: examples/aloha_real/Dockerfile
28 | init: true
29 | tty: true
30 | network_mode: host
31 | privileged: true
32 | volumes:
33 | - /dev:/dev
34 | command: roslaunch --wait aloha ros_nodes.launch
35 |
36 | ros_master:
37 | image: ros:noetic-robot
38 | network_mode: host
39 | privileged: true
40 | command:
41 | - roscore
42 |
43 | openpi_server:
44 | image: openpi_server
45 | build:
46 | context: ../..
47 | dockerfile: scripts/docker/serve_policy.Dockerfile
48 | init: true
49 | tty: true
50 | network_mode: host
51 | volumes:
52 | - $PWD:/app
53 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
54 | environment:
55 | - SERVER_ARGS
56 | - OPENPI_DATA_HOME=/openpi_assets
57 | - IS_DOCKER=true
58 |
59 | # Comment out this block if not running on a machine with GPUs.
60 | deploy:
61 | resources:
62 | reservations:
63 | devices:
64 | - driver: nvidia
65 | count: 1
66 | capabilities: [gpu]
67 |
--------------------------------------------------------------------------------
/examples/aloha_real/constants.py:
--------------------------------------------------------------------------------
1 | # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
2 | # ruff: noqa
3 |
4 | ### Task parameters
5 |
6 | ### ALOHA fixed constants
7 | DT = 0.001
8 | JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
9 | START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
10 |
11 | # Left finger position limits (qpos[7]), right_finger = -1 * left_finger
12 | MASTER_GRIPPER_POSITION_OPEN = 0.02417
13 | MASTER_GRIPPER_POSITION_CLOSE = 0.01244
14 | PUPPET_GRIPPER_POSITION_OPEN = 0.05800
15 | PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
16 |
17 | # Gripper joint limits (qpos[6])
18 | MASTER_GRIPPER_JOINT_OPEN = 0.3083
19 | MASTER_GRIPPER_JOINT_CLOSE = -0.6842
20 | PUPPET_GRIPPER_JOINT_OPEN = 1.4910
21 | PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
22 |
23 | ############################ Helper functions ############################
24 |
25 | MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
26 | MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
27 | )
28 | PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
29 | PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
30 | )
31 | MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
32 | lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
33 | )
34 | PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
35 | lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
36 | )
37 | MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
38 |
39 | MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
40 | MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
41 | )
42 | PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
43 | PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
44 | )
45 | MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
46 | lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
47 | )
48 | PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
49 | lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
50 | )
51 | MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
52 |
53 | MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
54 | PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
55 |
56 | MASTER_POS2JOINT = (
57 | lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
58 | + MASTER_GRIPPER_JOINT_CLOSE
59 | )
60 | MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
61 | (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
62 | )
63 | PUPPET_POS2JOINT = (
64 | lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
65 | + PUPPET_GRIPPER_JOINT_CLOSE
66 | )
67 | PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
68 | (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
69 | )
70 |
71 | MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
72 |
--------------------------------------------------------------------------------
/examples/aloha_real/env.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional # noqa: UP035
2 |
3 | import einops
4 | from openpi_client import image_tools
5 | from openpi_client.runtime import environment as _environment
6 | from typing_extensions import override
7 |
8 | from examples.aloha_real import real_env as _real_env
9 |
10 |
11 | class AlohaRealEnvironment(_environment.Environment):
12 | """An environment for an Aloha robot on real hardware."""
13 |
14 | def __init__(
15 | self,
16 | reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
17 | render_height: int = 224,
18 | render_width: int = 224,
19 | ) -> None:
20 | self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
21 | self._render_height = render_height
22 | self._render_width = render_width
23 |
24 | self._ts = None
25 |
26 | @override
27 | def reset(self) -> None:
28 | self._ts = self._env.reset()
29 |
30 | @override
31 | def is_episode_complete(self) -> bool:
32 | return False
33 |
34 | @override
35 | def get_observation(self) -> dict:
36 | if self._ts is None:
37 | raise RuntimeError("Timestep is not set. Call reset() first.")
38 |
39 | obs = self._ts.observation
40 | for k in list(obs["images"].keys()):
41 | if "_depth" in k:
42 | del obs["images"][k]
43 |
44 | for cam_name in obs["images"]:
45 | img = image_tools.convert_to_uint8(
46 | image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
47 | )
48 | obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
49 |
50 | return {
51 | "state": obs["qpos"],
52 | "images": obs["images"],
53 | }
54 |
55 | @override
56 | def apply_action(self, action: dict) -> None:
57 | self._ts = self._env.step(action["actions"])
58 |
--------------------------------------------------------------------------------
/examples/aloha_real/main.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import logging
3 |
4 | from openpi_client import action_chunk_broker
5 | from openpi_client import websocket_client_policy as _websocket_client_policy
6 | from openpi_client.runtime import runtime as _runtime
7 | from openpi_client.runtime.agents import policy_agent as _policy_agent
8 | import tyro
9 |
10 | from examples.aloha_real import env as _env
11 |
12 |
13 | @dataclasses.dataclass
14 | class Args:
15 | host: str = "0.0.0.0"
16 | port: int = 8000
17 |
18 | action_horizon: int = 25
19 |
20 | num_episodes: int = 1
21 | max_episode_steps: int = 1000
22 |
23 |
24 | def main(args: Args) -> None:
25 | ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
26 | host=args.host,
27 | port=args.port,
28 | )
29 | logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
30 |
31 | metadata = ws_client_policy.get_server_metadata()
32 | runtime = _runtime.Runtime(
33 | environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
34 | agent=_policy_agent.PolicyAgent(
35 | policy=action_chunk_broker.ActionChunkBroker(
36 | policy=ws_client_policy,
37 | action_horizon=args.action_horizon,
38 | )
39 | ),
40 | subscribers=[],
41 | max_hz=50,
42 | num_episodes=args.num_episodes,
43 | max_episode_steps=args.max_episode_steps,
44 | )
45 |
46 | runtime.run()
47 |
48 |
49 | if __name__ == "__main__":
50 | logging.basicConfig(level=logging.INFO, force=True)
51 | tyro.cli(main)
52 |
--------------------------------------------------------------------------------
/examples/aloha_real/requirements.in:
--------------------------------------------------------------------------------
1 | Pillow
2 | dm_control
3 | einops
4 | h5py
5 | matplotlib
6 | modern_robotics
7 | msgpack
8 | numpy
9 | opencv-python
10 | packaging
11 | pexpect
12 | pyquaternion
13 | pyrealsense2
14 | pyyaml
15 | requests
16 | rospkg
17 | tyro
18 | websockets
19 |
--------------------------------------------------------------------------------
/examples/aloha_real/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
3 | absl-py==2.1.0
4 | # via
5 | # dm-control
6 | # dm-env
7 | # labmaze
8 | # mujoco
9 | catkin-pkg==1.0.0
10 | # via rospkg
11 | certifi==2024.8.30
12 | # via requests
13 | charset-normalizer==3.4.0
14 | # via requests
15 | contourpy==1.1.1
16 | # via matplotlib
17 | cycler==0.12.1
18 | # via matplotlib
19 | distro==1.9.0
20 | # via rospkg
21 | dm-control==1.0.23
22 | # via -r examples/aloha_real/requirements.in
23 | dm-env==1.6
24 | # via dm-control
25 | dm-tree==0.1.8
26 | # via
27 | # dm-control
28 | # dm-env
29 | docstring-parser==0.16
30 | # via tyro
31 | docutils==0.20.1
32 | # via catkin-pkg
33 | einops==0.8.0
34 | # via -r examples/aloha_real/requirements.in
35 | etils==1.3.0
36 | # via mujoco
37 | fonttools==4.55.2
38 | # via matplotlib
39 | glfw==2.8.0
40 | # via
41 | # dm-control
42 | # mujoco
43 | h5py==3.11.0
44 | # via -r examples/aloha_real/requirements.in
45 | idna==3.10
46 | # via requests
47 | importlib-resources==6.4.5
48 | # via etils
49 | kiwisolver==1.4.7
50 | # via matplotlib
51 | labmaze==1.0.6
52 | # via dm-control
53 | lxml==5.3.0
54 | # via dm-control
55 | markdown-it-py==3.0.0
56 | # via rich
57 | matplotlib==3.7.5
58 | # via -r examples/aloha_real/requirements.in
59 | mdurl==0.1.2
60 | # via markdown-it-py
61 | modern-robotics==1.1.1
62 | # via -r examples/aloha_real/requirements.in
63 | msgpack==1.1.0
64 | # via -r examples/aloha_real/requirements.in
65 | mujoco==3.2.3
66 | # via dm-control
67 | numpy==1.24.4
68 | # via
69 | # -r examples/aloha_real/requirements.in
70 | # contourpy
71 | # dm-control
72 | # dm-env
73 | # h5py
74 | # labmaze
75 | # matplotlib
76 | # modern-robotics
77 | # mujoco
78 | # opencv-python
79 | # pyquaternion
80 | # scipy
81 | opencv-python==4.10.0.84
82 | # via -r examples/aloha_real/requirements.in
83 | packaging==24.2
84 | # via
85 | # -r examples/aloha_real/requirements.in
86 | # matplotlib
87 | pexpect==4.9.0
88 | # via -r examples/aloha_real/requirements.in
89 | pillow==10.4.0
90 | # via
91 | # -r examples/aloha_real/requirements.in
92 | # matplotlib
93 | protobuf==5.29.1
94 | # via dm-control
95 | ptyprocess==0.7.0
96 | # via pexpect
97 | pygments==2.18.0
98 | # via rich
99 | pyopengl==3.1.7
100 | # via
101 | # dm-control
102 | # mujoco
103 | pyparsing==3.1.4
104 | # via
105 | # catkin-pkg
106 | # dm-control
107 | # matplotlib
108 | pyquaternion==0.9.9
109 | # via -r examples/aloha_real/requirements.in
110 | pyrealsense2==2.55.1.6486
111 | # via -r examples/aloha_real/requirements.in
112 | python-dateutil==2.9.0.post0
113 | # via
114 | # catkin-pkg
115 | # matplotlib
116 | pyyaml==6.0.2
117 | # via
118 | # -r examples/aloha_real/requirements.in
119 | # rospkg
120 | requests==2.32.3
121 | # via
122 | # -r examples/aloha_real/requirements.in
123 | # dm-control
124 | rich==13.9.4
125 | # via tyro
126 | rospkg==1.5.1
127 | # via -r examples/aloha_real/requirements.in
128 | scipy==1.10.1
129 | # via dm-control
130 | setuptools==75.3.0
131 | # via
132 | # catkin-pkg
133 | # dm-control
134 | # labmaze
135 | shtab==1.7.1
136 | # via tyro
137 | six==1.17.0
138 | # via python-dateutil
139 | tqdm==4.67.1
140 | # via dm-control
141 | typeguard==4.4.0
142 | # via tyro
143 | typing-extensions==4.12.2
144 | # via
145 | # etils
146 | # rich
147 | # typeguard
148 | # tyro
149 | tyro==0.9.2
150 | # via -r examples/aloha_real/requirements.in
151 | urllib3==2.2.3
152 | # via requests
153 | websockets==14.1
154 | # via -r examples/aloha_real/requirements.in
155 | zipp==3.20.2
156 | # via etils
157 |
--------------------------------------------------------------------------------
/examples/aloha_real/video_display.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from openpi_client.runtime import subscriber as _subscriber
4 | from typing_extensions import override
5 |
6 |
7 | class VideoDisplay(_subscriber.Subscriber):
8 | """Displays video frames."""
9 |
10 | def __init__(self) -> None:
11 | self._ax: plt.Axes | None = None
12 | self._plt_img: plt.Image | None = None
13 |
14 | @override
15 | def on_episode_start(self) -> None:
16 | plt.ion()
17 | self._ax = plt.subplot()
18 | self._plt_img = None
19 |
20 | @override
21 | def on_step(self, observation: dict, action: dict) -> None:
22 | assert self._ax is not None
23 |
24 | im = observation["image"][0] # [C, H, W]
25 | im = np.transpose(im, (1, 2, 0)) # [H, W, C]
26 |
27 | if self._plt_img is None:
28 | self._plt_img = self._ax.imshow(im)
29 | else:
30 | self._plt_img.set_data(im)
31 | plt.pause(0.001)
32 |
33 | @override
34 | def on_episode_end(self) -> None:
35 | plt.ioff()
36 | plt.close()
37 |
--------------------------------------------------------------------------------
/examples/aloha_sim/Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for the Aloha simulation environment.
2 |
3 | # Build the container:
4 | # docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
5 |
6 | # Run the container:
7 | # docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
8 |
9 | FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
11 |
12 | RUN apt-get update && \
13 | apt-get install -y \
14 | libosmesa6-dev \
15 | libgl1-mesa-glx \
16 | libglew-dev \
17 | libglfw3-dev \
18 | libgles2-mesa-dev
19 | ENV MUJOCO_GL=egl
20 |
21 | WORKDIR /app
22 |
23 | # Copy from the cache instead of linking since it's a mounted volume
24 | ENV UV_LINK_MODE=copy
25 |
26 | # Write the virtual environment outside of the project directory so it doesn't
27 | # leak out of the container when we mount the application code.
28 | ENV UV_PROJECT_ENVIRONMENT=/.venv
29 |
30 | # Copy the requirements files so we can install dependencies.
31 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
32 | # This strategy is best for development-style usage.
33 | COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
34 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
35 |
36 | # Install python dependencies.
37 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
38 | RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
39 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
40 |
41 | CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
--------------------------------------------------------------------------------
/examples/aloha_sim/README.md:
--------------------------------------------------------------------------------
1 | # Run Aloha Sim
2 |
3 | ## With Docker
4 |
5 | ```bash
6 | export SERVER_ARGS="--env ALOHA_SIM"
7 | docker compose -f examples/aloha_sim/compose.yml up --build
8 | ```
9 |
10 | ## Without Docker
11 |
12 | Terminal window 1:
13 |
14 | ```bash
15 | # Create virtual environment
16 | uv venv --python 3.10 examples/aloha_sim/.venv
17 | source examples/aloha_sim/.venv/bin/activate
18 | uv pip sync examples/aloha_sim/requirements.txt
19 | uv pip install -e packages/openpi-client
20 |
21 | # Run the simulation
22 | MUJOCO_GL=egl python examples/aloha_sim/main.py
23 | ```
24 |
25 | Note: If you are seeing EGL errors, you may need to install the following dependencies:
26 |
27 | ```bash
28 | sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
29 | ```
30 |
31 | Terminal window 2:
32 |
33 | ```bash
34 | # Run the server
35 | uv run scripts/serve_policy.py --env ALOHA_SIM
36 | ```
37 |
--------------------------------------------------------------------------------
/examples/aloha_sim/compose.yml:
--------------------------------------------------------------------------------
1 | # Run with:
2 | # docker compose -f examples/aloha_sim/compose.yml up --build
3 | services:
4 | runtime:
5 | image: aloha_sim
6 | depends_on:
7 | - openpi_server
8 | build:
9 | context: ../..
10 | dockerfile: examples/aloha_sim/Dockerfile
11 | init: true
12 | tty: true
13 | network_mode: host
14 | privileged: true
15 | volumes:
16 | - $PWD:/app
17 | - ../../data:/data
18 |
19 | openpi_server:
20 | image: openpi_server
21 | build:
22 | context: ../..
23 | dockerfile: scripts/docker/serve_policy.Dockerfile
24 | init: true
25 | tty: true
26 | network_mode: host
27 | volumes:
28 | - $PWD:/app
29 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
30 | environment:
31 | - SERVER_ARGS
32 | - OPENPI_DATA_HOME=/openpi_assets
33 | - IS_DOCKER=true
34 |
35 | # Comment out this block if not running on a machine with GPUs.
36 | deploy:
37 | resources:
38 | reservations:
39 | devices:
40 | - driver: nvidia
41 | count: 1
42 | capabilities: [gpu]
43 |
--------------------------------------------------------------------------------
/examples/aloha_sim/env.py:
--------------------------------------------------------------------------------
1 | import gym_aloha # noqa: F401
2 | import gymnasium
3 | import numpy as np
4 | from openpi_client import image_tools
5 | from openpi_client.runtime import environment as _environment
6 | from typing_extensions import override
7 |
8 |
9 | class AlohaSimEnvironment(_environment.Environment):
10 | """An environment for an Aloha robot in simulation."""
11 |
12 | def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
13 | np.random.seed(seed)
14 | self._rng = np.random.default_rng(seed)
15 |
16 | self._gym = gymnasium.make(task, obs_type=obs_type)
17 |
18 | self._last_obs = None
19 | self._done = True
20 | self._episode_reward = 0.0
21 |
22 | @override
23 | def reset(self) -> None:
24 | gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
25 | self._last_obs = self._convert_observation(gym_obs) # type: ignore
26 | self._done = False
27 | self._episode_reward = 0.0
28 |
29 | @override
30 | def is_episode_complete(self) -> bool:
31 | return self._done
32 |
33 | @override
34 | def get_observation(self) -> dict:
35 | if self._last_obs is None:
36 | raise RuntimeError("Observation is not set. Call reset() first.")
37 |
38 | return self._last_obs # type: ignore
39 |
40 | @override
41 | def apply_action(self, action: dict) -> None:
42 | gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
43 | self._last_obs = self._convert_observation(gym_obs) # type: ignore
44 | self._done = terminated or truncated
45 | self._episode_reward = max(self._episode_reward, reward)
46 |
47 | def _convert_observation(self, gym_obs: dict) -> dict:
48 | img = gym_obs["pixels"]["top"]
49 | img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
50 | # Convert axis order from [H, W, C] --> [C, H, W]
51 | img = np.transpose(img, (2, 0, 1))
52 |
53 | return {
54 | "state": gym_obs["agent_pos"],
55 | "images": {"cam_high": img},
56 | }
57 |
--------------------------------------------------------------------------------
/examples/aloha_sim/main.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import logging
3 | import pathlib
4 |
5 | import env as _env
6 | from openpi_client import action_chunk_broker
7 | from openpi_client import websocket_client_policy as _websocket_client_policy
8 | from openpi_client.runtime import runtime as _runtime
9 | from openpi_client.runtime.agents import policy_agent as _policy_agent
10 | import saver as _saver
11 | import tyro
12 |
13 |
14 | @dataclasses.dataclass
15 | class Args:
16 | out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
17 |
18 | task: str = "gym_aloha/AlohaTransferCube-v0"
19 | seed: int = 0
20 |
21 | action_horizon: int = 10
22 |
23 | host: str = "0.0.0.0"
24 | port: int = 8000
25 |
26 | display: bool = False
27 |
28 |
29 | def main(args: Args) -> None:
30 | runtime = _runtime.Runtime(
31 | environment=_env.AlohaSimEnvironment(
32 | task=args.task,
33 | seed=args.seed,
34 | ),
35 | agent=_policy_agent.PolicyAgent(
36 | policy=action_chunk_broker.ActionChunkBroker(
37 | policy=_websocket_client_policy.WebsocketClientPolicy(
38 | host=args.host,
39 | port=args.port,
40 | ),
41 | action_horizon=args.action_horizon,
42 | )
43 | ),
44 | subscribers=[
45 | _saver.VideoSaver(args.out_dir),
46 | ],
47 | max_hz=50,
48 | )
49 |
50 | runtime.run()
51 |
52 |
53 | if __name__ == "__main__":
54 | logging.basicConfig(level=logging.INFO, force=True)
55 | tyro.cli(main)
56 |
--------------------------------------------------------------------------------
/examples/aloha_sim/requirements.in:
--------------------------------------------------------------------------------
1 | gym-aloha
2 | imageio
3 | matplotlib
4 | msgpack
5 | numpy
6 | typing-extensions
7 | tyro
8 | websockets
--------------------------------------------------------------------------------
/examples/aloha_sim/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
3 | absl-py==2.1.0
4 | # via
5 | # dm-control
6 | # dm-env
7 | # labmaze
8 | # mujoco
9 | certifi==2024.8.30
10 | # via requests
11 | charset-normalizer==3.4.0
12 | # via requests
13 | cloudpickle==3.1.0
14 | # via gymnasium
15 | contourpy==1.3.1
16 | # via matplotlib
17 | cycler==0.12.1
18 | # via matplotlib
19 | dm-control==1.0.14
20 | # via gym-aloha
21 | dm-env==1.6
22 | # via dm-control
23 | dm-tree==0.1.8
24 | # via
25 | # dm-control
26 | # dm-env
27 | docstring-parser==0.16
28 | # via tyro
29 | farama-notifications==0.0.4
30 | # via gymnasium
31 | fonttools==4.55.2
32 | # via matplotlib
33 | glfw==2.8.0
34 | # via
35 | # dm-control
36 | # mujoco
37 | gym-aloha==0.1.1
38 | # via -r examples/aloha_sim/requirements.in
39 | gymnasium==1.0.0
40 | # via gym-aloha
41 | idna==3.10
42 | # via requests
43 | imageio==2.36.1
44 | # via
45 | # -r examples/aloha_sim/requirements.in
46 | # gym-aloha
47 | imageio-ffmpeg==0.5.1
48 | # via imageio
49 | kiwisolver==1.4.7
50 | # via matplotlib
51 | labmaze==1.0.6
52 | # via dm-control
53 | lxml==5.3.0
54 | # via dm-control
55 | markdown-it-py==3.0.0
56 | # via rich
57 | matplotlib==3.9.3
58 | # via -r examples/aloha_sim/requirements.in
59 | mdurl==0.1.2
60 | # via markdown-it-py
61 | msgpack==1.1.0
62 | # via -r examples/aloha_sim/requirements.in
63 | mujoco==2.3.7
64 | # via
65 | # dm-control
66 | # gym-aloha
67 | numpy==1.26.4
68 | # via
69 | # -r examples/aloha_sim/requirements.in
70 | # contourpy
71 | # dm-control
72 | # dm-env
73 | # gymnasium
74 | # imageio
75 | # labmaze
76 | # matplotlib
77 | # mujoco
78 | # scipy
79 | packaging==24.2
80 | # via matplotlib
81 | pillow==11.0.0
82 | # via
83 | # imageio
84 | # matplotlib
85 | protobuf==5.29.1
86 | # via dm-control
87 | psutil==6.1.0
88 | # via imageio
89 | pygments==2.18.0
90 | # via rich
91 | pyopengl==3.1.7
92 | # via
93 | # dm-control
94 | # mujoco
95 | pyparsing==3.2.0
96 | # via
97 | # dm-control
98 | # matplotlib
99 | python-dateutil==2.9.0.post0
100 | # via matplotlib
101 | requests==2.32.3
102 | # via dm-control
103 | rich==13.9.4
104 | # via tyro
105 | scipy==1.14.1
106 | # via dm-control
107 | setuptools==75.6.0
108 | # via
109 | # dm-control
110 | # imageio-ffmpeg
111 | # labmaze
112 | shtab==1.7.1
113 | # via tyro
114 | six==1.17.0
115 | # via python-dateutil
116 | tqdm==4.67.1
117 | # via dm-control
118 | typeguard==4.4.1
119 | # via tyro
120 | typing-extensions==4.12.2
121 | # via
122 | # -r examples/aloha_sim/requirements.in
123 | # gymnasium
124 | # rich
125 | # typeguard
126 | # tyro
127 | tyro==0.9.2
128 | # via -r examples/aloha_sim/requirements.in
129 | urllib3==2.2.3
130 | # via requests
131 | websockets==14.1
132 | # via -r examples/aloha_sim/requirements.in
133 |
--------------------------------------------------------------------------------
/examples/aloha_sim/saver.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pathlib
3 |
4 | import imageio
5 | import numpy as np
6 | from openpi_client.runtime import subscriber as _subscriber
7 | from typing_extensions import override
8 |
9 |
10 | class VideoSaver(_subscriber.Subscriber):
11 | """Saves episode data."""
12 |
13 | def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
14 | out_dir.mkdir(parents=True, exist_ok=True)
15 | self._out_dir = out_dir
16 | self._images: list[np.ndarray] = []
17 | self._subsample = subsample
18 |
19 | @override
20 | def on_episode_start(self) -> None:
21 | self._images = []
22 |
23 | @override
24 | def on_step(self, observation: dict, action: dict) -> None:
25 | im = observation["images"]["cam_high"] # [C, H, W]
26 | im = np.transpose(im, (1, 2, 0)) # [H, W, C]
27 | self._images.append(im)
28 |
29 | @override
30 | def on_episode_end(self) -> None:
31 | existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
32 | next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
33 | out_path = self._out_dir / f"out_{next_idx}.mp4"
34 |
35 | logging.info(f"Saving video to {out_path}")
36 | imageio.mimwrite(
37 | out_path,
38 | [np.asarray(x) for x in self._images[:: self._subsample]],
39 | fps=50 // max(1, self._subsample),
40 | )
41 |
--------------------------------------------------------------------------------
/examples/droid/README.md:
--------------------------------------------------------------------------------
1 | # Run DROID
2 |
3 | This example shows how to run the fine-tuned $\pi_0$-FAST-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). We also offer a $\pi_0$-DROID model that is fine-tuned from $\pi_0$ and uses flow action decoding. You can use it by replacing `pi0_fast_droid` with `pi0_droid` in the commands below. In practice, we find that out-of-the-box, the $\pi_0$-FAST-DROID model is better at following language commands, so we recommend it as the default checkpoint for DROID evaluation. If you are fine-tuning on a DROID task that requires low inference latency, you may still want to consider the $\pi_0$-DROID model, since it decodes actions faster. For more details, please see the [FAST paper](https://pi.website/research/fast).
4 |
5 |
6 | ## Step 1: Start a policy server
7 |
8 | Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
9 |
10 | 1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
11 | 2. Start the OpenPI server via the following command:
12 |
13 | ```bash
14 | uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
15 | ```
16 |
17 | You can also run the equivalent command below:
18 |
19 | ```bash
20 | uv run scripts/serve_policy.py --env=DROID
21 | ```
22 |
23 | ## Step 2: Run the DROID robot
24 |
25 | 1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
26 | 2. On the control laptop, activate your DROID conda environment.
27 | 3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
28 | 4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
29 | 5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
30 | 6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
31 | 7. Run the `main.py` file. Make sure to point the host and port arguments at the policy server. (To check that the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the laptop.) Also make sure to specify which external camera to feed to the policy (we only input one external camera); choose from ["left", "right"].
32 |
33 | ```bash
34 | python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
35 | ```
36 |
37 | The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
38 |
39 | # Troubleshooting
40 |
41 | | Issue | Solution |
42 | |-------|----------|
43 | | Cannot reach policy server | Make sure the server is running and that the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
44 | | Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
45 | | Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
46 | | Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
47 |
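48 | # Example: querying the policy server from Python
49 |
50 | If you want to build your own client instead of using `main.py`, the minimal sketch below shows the core of the interaction. This is a sketch, not part of the DROID runtime: the host is a placeholder, and the random arrays stand in for real camera images and robot state (the observation keys match what the DROID runtime produces).
51 |
52 | ```python
53 | import numpy as np
54 | from openpi_client import websocket_client_policy
55 |
56 | # Connect to the policy server from Step 1. This blocks until the server is reachable.
57 | policy = websocket_client_policy.WebsocketClientPolicy(host="<your_policy_server_ip>", port=8000)
58 |
59 | # A dummy DROID-style observation; replace the random arrays with real sensor data.
60 | observation = {
61 |     "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
62 |     "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
63 |     "observation/joint_position": np.random.rand(7),
64 |     "observation/gripper_position": np.random.rand(1),
65 |     "prompt": "pick up the cup",
66 | }
67 |
68 | # The server returns a chunk of future actions; the first dimension is the chunk length.
69 | actions = policy.infer(observation)["actions"]
70 | print("Action chunk shape:", actions.shape)
71 | ```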
--------------------------------------------------------------------------------
/examples/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import dataclasses\n",
10 | "\n",
11 | "import jax\n",
12 | "\n",
13 | "from openpi.models import model as _model\n",
14 | "from openpi.policies import droid_policy\n",
15 | "from openpi.policies import policy_config as _policy_config\n",
16 | "from openpi.shared import download\n",
17 | "from openpi.training import config as _config\n",
18 | "from openpi.training import data_loader as _data_loader"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "# Policy inference\n",
26 | "\n",
27 | "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "config = _config.get_config(\"pi0_fast_droid\")\n",
37 | "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
38 | "\n",
39 | "# Create a trained policy.\n",
40 | "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
41 | "\n",
42 | "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
43 | "example = droid_policy.make_droid_example()\n",
44 | "result = policy.infer(example)\n",
45 | "\n",
46 | "# Delete the policy to free up memory.\n",
47 | "del policy\n",
48 | "\n",
49 | "print(\"Actions shape:\", result[\"actions\"].shape)"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "# Working with a live model\n",
57 | "\n",
58 | "\n",
59 | "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "config = _config.get_config(\"pi0_aloha_sim\")\n",
69 | "\n",
70 | "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
71 | "key = jax.random.key(0)\n",
72 | "\n",
73 | "# Create a model from the checkpoint.\n",
74 | "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
75 | "\n",
76 | "# We can create fake observations and actions to test the model.\n",
77 | "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
78 | "\n",
79 | "# Compute the training loss on the fake batch.\n",
80 | "loss = model.compute_loss(key, obs, act)\n",
81 | "print(\"Loss shape:\", loss.shape)"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "# Reduce the batch size to reduce memory usage.\n",
98 | "config = dataclasses.replace(config, batch_size=2)\n",
99 | "\n",
100 | "# Load a single batch of data. This is the same data that will be used during training.\n",
101 | "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
102 | "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
103 | "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
104 | "obs, act = next(iter(loader))\n",
105 | "\n",
106 | "# Compute the training loss on the real batch.\n",
107 | "loss = model.compute_loss(key, obs, act)\n",
108 | "\n",
109 | "# Delete the model to free up memory.\n",
110 | "del model\n",
111 | "\n",
112 | "print(\"Loss shape:\", loss.shape)"
113 | ]
114 | }
115 | ],
116 | "metadata": {
117 | "kernelspec": {
118 | "display_name": ".venv",
119 | "language": "python",
120 | "name": "python3"
121 | },
122 | "language_info": {
123 | "codemirror_mode": {
124 | "name": "ipython",
125 | "version": 3
126 | },
127 | "file_extension": ".py",
128 | "mimetype": "text/x-python",
129 | "name": "python",
130 | "nbconvert_exporter": "python",
131 | "pygments_lexer": "ipython3",
132 | "version": "3.11.9"
133 | }
134 | },
135 | "nbformat": 4,
136 | "nbformat_minor": 2
137 | }
138 |
--------------------------------------------------------------------------------
/examples/libero/Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for the LIBERO benchmark.
2 |
3 | # Build the container:
4 | # docker build . -t libero -f examples/libero/Dockerfile
5 |
6 | # Run the container:
7 | # docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
8 |
9 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
11 |
12 | RUN apt-get update && \
13 | apt-get install -y \
14 | make \
15 | g++ \
16 | clang \
17 | libosmesa6-dev \
18 | libgl1-mesa-glx \
19 | libglew-dev \
20 | libglfw3-dev \
21 | libgles2-mesa-dev \
22 | libglib2.0-0 \
23 | libsm6 \
24 | libxrender1 \
25 | libxext6
26 |
27 | WORKDIR /app
28 |
29 | # Copy from the cache instead of linking since it's a mounted volume
30 | ENV UV_LINK_MODE=copy
31 |
32 | # Write the virtual environment outside of the project directory so it doesn't
33 | # leak out of the container when we mount the application code.
34 | ENV UV_PROJECT_ENVIRONMENT=/.venv
35 |
36 | # Copy the requirements files so we can install dependencies.
37 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
38 | # This strategy is best for development-style usage.
39 | COPY ./examples/libero/requirements.txt /tmp/requirements.txt
40 | COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
41 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
42 |
43 | # Install python dependencies.
44 | RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
45 | RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
46 | ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
47 |
48 | # Create a default config file to avoid an input prompt from LIBERO's init script.
49 | # https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
50 | ENV LIBERO_CONFIG_PATH=/tmp/libero
51 | RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
52 | benchmark_root: /app/third_party/libero/libero/libero
53 | bddl_files: /app/third_party/libero/libero/libero/bddl_files
54 | init_states: /app/third_party/libero/libero/libero/init_files
55 | datasets: /app/third_party/libero/libero/datasets
56 | assets: /app/third_party/libero/libero/libero/assets
57 | EOF
58 |
59 | CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]
60 |
--------------------------------------------------------------------------------
/examples/libero/README.md:
--------------------------------------------------------------------------------
1 | # LIBERO Benchmark
2 |
3 | This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
4 |
5 | Note: when updating `requirements.txt` in this directory, you must add the extra flag `--extra-index-url https://download.pytorch.org/whl/cu113` to the `uv pip compile` command.
6 |
7 | This example requires git submodules to be initialized. Don't forget to run:
8 |
9 | ```bash
10 | git submodule update --init --recursive
11 | ```
12 |
13 | ## With Docker
14 |
15 | ```bash
16 | # Grant access to the X11 server:
17 | sudo xhost +local:docker
18 |
19 | export SERVER_ARGS="--env LIBERO"
20 | docker compose -f examples/libero/compose.yml up --build
21 | ```
22 |
23 | ## Without Docker
24 |
25 | Terminal window 1:
26 |
27 | ```bash
28 | # Create virtual environment
29 | uv venv --python 3.8 examples/libero/.venv
30 | source examples/libero/.venv/bin/activate
31 | uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
32 | uv pip install -e packages/openpi-client
33 | uv pip install -e third_party/libero
34 | export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
35 |
36 | # Run the simulation
37 | python examples/libero/main.py
38 | ```
39 |
40 | Terminal window 2:
41 |
42 | ```bash
43 | # Run the server
44 | uv run scripts/serve_policy.py --env LIBERO
45 | ```
46 |
47 | ## Results
48 |
49 | If you follow the training instructions and hyperparameters in the `pi0_libero` and `pi0_fast_libero` configs, you should get results similar to the following:
50 |
51 | | Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
52 | |-------|---------------|---------------|-------------|-----------|---------|
53 | | π0-FAST @ 30k (finetuned) | 96.4 | 96.8 | 88.6 | 60.2 | 85.5 |
54 | | π0 @ 30k (finetuned) | 96.8 | 98.8 | 95.8 | 85.2 | 94.15 |
55 |
56 | Note that the hyperparameters for these runs are not tuned and $\pi_0$-FAST does not use a FAST tokenizer optimized for Libero. The results could likely be improved with more tuning; we mainly provide them as an example of how to use openpi to fine-tune $\pi_0$ models on a new dataset.
57 |
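58 | ## Example: querying the policy server
59 |
60 | For reference, here is a minimal sketch of how a client queries the policy server from Python. It is a sketch, not part of the benchmark code: the random arrays are placeholders for real simulator frames and state, and the observation keys match the LIBERO example client.
61 |
62 | ```python
63 | import numpy as np
64 | from openpi_client import websocket_client_policy
65 |
66 | # Connect to the server started with `--env LIBERO` (blocks until reachable).
67 | policy = websocket_client_policy.WebsocketClientPolicy(host="0.0.0.0", port=8000)
68 |
69 | observation = {
70 |     "observation/state": np.random.rand(8),
71 |     "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
72 |     "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
73 |     "prompt": "put the bowl on the plate",
74 | }
75 |
76 | # Returns a chunk of actions; the first dimension is the chunk size.
77 | actions = policy.infer(observation)["actions"]
78 | ```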
--------------------------------------------------------------------------------
/examples/libero/compose.yml:
--------------------------------------------------------------------------------
1 | # Run with:
2 | # docker compose -f examples/libero/compose.yml up --build
3 | services:
4 | runtime:
5 | image: libero
6 | depends_on:
7 | - openpi_server
8 | build:
9 | context: ../..
10 | dockerfile: examples/libero/Dockerfile
11 | init: true
12 | tty: true
13 | network_mode: host
14 | privileged: true
15 | volumes:
16 | - $PWD:/app
17 | - ../../data:/data
18 | - /tmp/.X11-unix:/tmp/.X11-unix:ro
19 | environment:
20 | - DISPLAY=$DISPLAY
21 | deploy:
22 | resources:
23 | reservations:
24 | devices:
25 | - driver: nvidia
26 | count: 1
27 | capabilities: [gpu]
28 |
29 | openpi_server:
30 | image: openpi_server
31 | build:
32 | context: ../..
33 | dockerfile: scripts/docker/serve_policy.Dockerfile
34 | init: true
35 | tty: true
36 | network_mode: host
37 | volumes:
38 | - $PWD:/app
39 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
40 | environment:
41 | - SERVER_ARGS
42 | - OPENPI_DATA_HOME=/openpi_assets
43 | - IS_DOCKER=true
44 |
45 | # Comment out this block if not running on a machine with GPUs.
46 | deploy:
47 | resources:
48 | reservations:
49 | devices:
50 | - driver: nvidia
51 | count: 1
52 | capabilities: [gpu]
53 |
--------------------------------------------------------------------------------
/examples/libero/convert_libero_data_to_lerobot.py:
--------------------------------------------------------------------------------
1 | """
2 | Minimal example script for converting a dataset to LeRobot format.
3 |
4 | We use the Libero dataset (stored in RLDS) for this example, but it can be easily
5 | modified for any other data you have saved in a custom format.
6 |
7 | Usage:
8 | uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
9 |
10 | If you want to push your dataset to the Hugging Face Hub, you can use the following command:
11 | uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
12 |
13 | Note: to run the script, you need to install tensorflow_datasets:
14 | `uv pip install tensorflow tensorflow_datasets`
15 |
16 | You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
17 | The resulting dataset will get saved to the $LEROBOT_HOME directory.
18 | Running this conversion script will take approximately 30 minutes.
19 | """
20 |
21 | import shutil
22 |
23 | from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
24 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
25 | import tensorflow_datasets as tfds
26 | import tyro
27 |
28 | REPO_NAME = "xxx/libero" # Name of the output dataset, also used for the Hugging Face Hub
29 | RAW_DATASET_NAMES = [
30 | "libero_10_no_noops",
31 | "libero_goal_no_noops",
32 | "libero_object_no_noops",
33 | "libero_spatial_no_noops",
34 | ] # For simplicity we will combine multiple Libero datasets into one training dataset
35 |
36 |
37 | def main(data_dir: str, *, push_to_hub: bool = False):
38 | # Clean up any existing dataset in the output directory
39 | output_path = LEROBOT_HOME / REPO_NAME
40 | if output_path.exists():
41 | shutil.rmtree(output_path)
42 |
43 | # Create LeRobot dataset, define features to store
44 | # OpenPi assumes that proprio is stored in `state` and actions in `action`
45 | # LeRobot assumes that dtype of image data is `image`
46 | dataset = LeRobotDataset.create(
47 | repo_id=REPO_NAME,
48 | robot_type="panda",
49 | fps=10,
50 | features={
51 | "image": {
52 | "dtype": "image",
53 | "shape": (256, 256, 3),
54 | "names": ["height", "width", "channel"],
55 | },
56 | "wrist_image": {
57 | "dtype": "image",
58 | "shape": (256, 256, 3),
59 | "names": ["height", "width", "channel"],
60 | },
61 | "state": {
62 | "dtype": "float32",
63 | "shape": (8,),
64 | "names": ["state"],
65 | },
66 | "actions": {
67 | "dtype": "float32",
68 | "shape": (7,),
69 | "names": ["actions"],
70 | },
71 | },
72 | image_writer_threads=10,
73 | image_writer_processes=5,
74 | )
75 |
76 | # Loop over raw Libero datasets and write episodes to the LeRobot dataset
77 | # You can modify this for your own data format
78 | for raw_dataset_name in RAW_DATASET_NAMES:
79 | raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
80 | for episode in raw_dataset:
81 | for step in episode["steps"].as_numpy_iterator():
82 | dataset.add_frame(
83 | {
84 | "image": step["observation"]["image"],
85 | "wrist_image": step["observation"]["wrist_image"],
86 | "state": step["observation"]["state"],
87 | "actions": step["action"],
88 | }
89 | )
90 | dataset.save_episode(task=step["language_instruction"].decode())
91 |
92 | # Consolidate the dataset, skip computing stats since we will do that later
93 | dataset.consolidate(run_compute_stats=False)
94 |
95 | # Optionally push to the Hugging Face Hub
96 | if push_to_hub:
97 | dataset.push_to_hub(
98 | tags=["libero", "panda", "rlds"],
99 | private=False,
100 | push_videos=True,
101 | license="apache-2.0",
102 | )
103 |
104 |
105 | if __name__ == "__main__":
106 | tyro.cli(main)
107 |
--------------------------------------------------------------------------------
/examples/libero/requirements.in:
--------------------------------------------------------------------------------
1 | imageio[ffmpeg]
2 | numpy==1.22.4
3 | tqdm
4 | tyro
5 | PyYaml
6 | opencv-python==4.6.0.66
7 | torch==1.11.0+cu113
8 | torchvision==0.12.0+cu113
9 | torchaudio==0.11.0+cu113
10 | robosuite==1.4.1
11 | matplotlib==3.5.3
12 |
--------------------------------------------------------------------------------
/examples/libero/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
3 | absl-py==2.1.0
4 | # via mujoco
5 | certifi==2024.12.14
6 | # via requests
7 | charset-normalizer==3.4.0
8 | # via requests
9 | cycler==0.12.1
10 | # via matplotlib
11 | docstring-parser==0.16
12 | # via tyro
13 | etils==1.3.0
14 | # via mujoco
15 | eval-type-backport==0.2.0
16 | # via tyro
17 | evdev==1.7.1
18 | # via pynput
19 | fonttools==4.55.3
20 | # via matplotlib
21 | glfw==1.12.0
22 | # via mujoco
23 | idna==3.10
24 | # via requests
25 | imageio==2.35.1
26 | # via -r examples/libero/requirements.in
27 | imageio-ffmpeg==0.5.1
28 | # via imageio
29 | importlib-metadata==8.5.0
30 | # via typeguard
31 | importlib-resources==6.4.5
32 | # via etils
33 | kiwisolver==1.4.7
34 | # via matplotlib
35 | llvmlite==0.36.0
36 | # via numba
37 | markdown-it-py==3.0.0
38 | # via rich
39 | matplotlib==3.5.3
40 | # via -r examples/libero/requirements.in
41 | mdurl==0.1.2
42 | # via markdown-it-py
43 | mujoco==3.2.3
44 | # via robosuite
45 | numba==0.53.1
46 | # via robosuite
47 | numpy==1.22.4
48 | # via
49 | # -r examples/libero/requirements.in
50 | # imageio
51 | # matplotlib
52 | # mujoco
53 | # numba
54 | # opencv-python
55 | # robosuite
56 | # scipy
57 | # torchvision
58 | opencv-python==4.6.0.66
59 | # via
60 | # -r examples/libero/requirements.in
61 | # robosuite
62 | packaging==24.2
63 | # via matplotlib
64 | pillow==10.4.0
65 | # via
66 | # imageio
67 | # matplotlib
68 | # robosuite
69 | # torchvision
70 | psutil==6.1.0
71 | # via imageio
72 | pygments==2.18.0
73 | # via rich
74 | pynput==1.7.7
75 | # via robosuite
76 | pyopengl==3.1.7
77 | # via mujoco
78 | pyparsing==3.1.4
79 | # via matplotlib
80 | python-dateutil==2.9.0.post0
81 | # via matplotlib
82 | python-xlib==0.33
83 | # via pynput
84 | pyyaml==6.0.2
85 | # via -r examples/libero/requirements.in
86 | requests==2.32.3
87 | # via torchvision
88 | rich==13.9.4
89 | # via tyro
90 | robosuite==1.4.1
91 | # via -r examples/libero/requirements.in
92 | scipy==1.10.1
93 | # via robosuite
94 | setuptools==75.3.0
95 | # via
96 | # imageio-ffmpeg
97 | # numba
98 | shtab==1.7.1
99 | # via tyro
100 | six==1.17.0
101 | # via
102 | # pynput
103 | # python-dateutil
104 | # python-xlib
105 | termcolor==2.4.0
106 | # via robosuite
107 | torch==1.11.0+cu113
108 | # via
109 | # -r examples/libero/requirements.in
110 | # torchaudio
111 | # torchvision
112 | torchaudio==0.11.0+cu113
113 | # via -r examples/libero/requirements.in
114 | torchvision==0.12.0+cu113
115 | # via -r examples/libero/requirements.in
116 | tqdm==4.67.1
117 | # via -r examples/libero/requirements.in
118 | typeguard==4.4.0
119 | # via tyro
120 | typing-extensions==4.12.2
121 | # via
122 | # etils
123 | # rich
124 | # torch
125 | # torchvision
126 | # typeguard
127 | # tyro
128 | tyro==0.9.2
129 | # via -r examples/libero/requirements.in
130 | urllib3==2.2.3
131 | # via requests
132 | zipp==3.20.2
133 | # via
134 | # etils
135 | # importlib-metadata
136 | # importlib-resources
137 |
--------------------------------------------------------------------------------
/examples/policy_records.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pathlib\n",
10 | "\n",
11 | "import numpy as np\n",
12 | "\n",
13 | "record_path = pathlib.Path(\"../policy_records\")\n",
14 | "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
15 | "\n",
16 | "records = []\n",
17 | "for i in range(num_steps):\n",
18 | " record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
19 | " records.append(record)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "print(\"length of records\", len(records))\n",
29 | "print(\"keys in records\", records[0].keys())\n",
30 | "\n",
31 | "for k in records[0]:\n",
32 | " print(f\"{k} shape: {records[0][k].shape}\")"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "from PIL import Image\n",
42 | "\n",
43 | "\n",
44 | "def get_image(step: int, idx: int = 0):\n",
45 | " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
46 | " return img[idx].transpose(1, 2, 0)\n",
47 | "\n",
48 | "\n",
49 | "def show_image(step: int, idx_lst: list[int]):\n",
50 | " imgs = [get_image(step, idx) for idx in idx_lst]\n",
51 | " return Image.fromarray(np.hstack(imgs))\n",
52 | "\n",
53 | "\n",
54 | "for i in range(2):\n",
55 | " display(show_image(i, [0]))"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 14,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "import pandas as pd\n",
65 | "\n",
66 | "\n",
67 | "def get_axis(name, axis):\n",
68 | " return np.array([record[name][axis] for record in records])\n",
69 | "\n",
70 | "\n",
71 | "# qpos is [..., 14] of type float:\n",
72 | "# 0-5: left arm joint angles\n",
73 | "# 6: left arm gripper\n",
74 | "# 7-12: right arm joint angles\n",
75 | "# 13: right arm gripper\n",
76 | "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
77 | "\n",
78 | "\n",
79 | "def make_data():\n",
80 | " cur_dim = 0\n",
81 | " in_data = {}\n",
82 | " out_data = {}\n",
83 | " for name, dim_size in names:\n",
84 | " for i in range(dim_size):\n",
85 | " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
86 | " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
87 | " cur_dim += 1\n",
88 | " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
89 | "\n",
90 | "\n",
91 | "in_data, out_data = make_data()"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "for name in in_data.columns:\n",
101 | " data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
102 | " data.plot()"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": []
111 | }
112 | ],
113 | "metadata": {
114 | "kernelspec": {
115 | "display_name": ".venv",
116 | "language": "python",
117 | "name": "python3"
118 | },
119 | "language_info": {
120 | "codemirror_mode": {
121 | "name": "ipython",
122 | "version": 3
123 | },
124 | "file_extension": ".py",
125 | "mimetype": "text/x-python",
126 | "name": "python",
127 | "nbconvert_exporter": "python",
128 | "pygments_lexer": "ipython3",
129 | "version": "3.11.9"
130 | }
131 | },
132 | "nbformat": 4,
133 | "nbformat_minor": 2
134 | }
135 |
--------------------------------------------------------------------------------
/examples/simple_client/Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for the simple client.
2 |
3 | # Build the container:
4 | # docker build . -t simple_client -f examples/simple_client/Dockerfile
5 |
6 | # Run the container:
7 | # docker run --rm -it --network=host -v .:/app simple_client /bin/bash
8 |
9 | FROM python:3.7-slim
10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
11 |
12 | WORKDIR /app
13 |
14 | # Copy from the cache instead of linking since it's a mounted volume
15 | ENV UV_LINK_MODE=copy
16 |
17 | # Write the virtual environment outside of the project directory so it doesn't
18 | # leak out of the container when we mount the application code.
19 | ENV UV_PROJECT_ENVIRONMENT=/.venv
20 |
21 | # Copy the requirements files so we can install dependencies.
22 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
23 | # This strategy is best for development-style usage.
24 | COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
25 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
26 |
27 | # Install python dependencies.
28 | RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
29 | RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
30 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
31 |
32 | CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
33 |
--------------------------------------------------------------------------------
/examples/simple_client/README.md:
--------------------------------------------------------------------------------
1 | # Simple Client
2 |
3 | A minimal client that sends observations to the server and prints the inference rate.
4 |
5 | You can specify which runtime environment to use via the `--env` flag. You can see the available options by running:
6 |
7 | ```bash
8 | uv run examples/simple_client/main.py --help
9 | ```
10 |
11 | ## With Docker
12 |
13 | ```bash
14 | export SERVER_ARGS="--env ALOHA_SIM"
15 | docker compose -f examples/simple_client/compose.yml up --build
16 | ```
17 |
18 | ## Without Docker
19 |
20 | Terminal window 1:
21 |
22 | ```bash
23 | uv run examples/simple_client/main.py --env DROID
24 | ```
25 |
26 | Terminal window 2:
27 |
28 | ```bash
29 | uv run scripts/serve_policy.py --env DROID
30 | ```
31 |
--------------------------------------------------------------------------------
/examples/simple_client/compose.yml:
--------------------------------------------------------------------------------
1 | # Run with:
2 | # docker compose -f examples/simple_client/compose.yml up --build
3 | services:
4 | runtime:
5 | image: simple_client
6 | depends_on:
7 | - openpi_server
8 | build:
9 | context: ../..
10 | dockerfile: examples/simple_client/Dockerfile
11 | init: true
12 | tty: true
13 | network_mode: host
14 | volumes:
15 | - $PWD:/app
16 | environment:
17 | - SERVER_ARGS
18 |
19 | openpi_server:
20 | image: openpi_server
21 | build:
22 | context: ../..
23 | dockerfile: scripts/docker/serve_policy.Dockerfile
24 | init: true
25 | tty: true
26 | network_mode: host
27 | volumes:
28 | - $PWD:/app
29 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
30 | environment:
31 | - SERVER_ARGS
32 | - OPENPI_DATA_HOME=/openpi_assets
33 | - IS_DOCKER=true
34 |
35 | # Comment out this block if not running on a machine with GPUs.
36 | deploy:
37 | resources:
38 | reservations:
39 | devices:
40 | - driver: nvidia
41 | count: 1
42 | capabilities: [gpu]
43 |
--------------------------------------------------------------------------------
/examples/simple_client/main.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import enum
3 | import logging
4 | import time
5 |
6 | import numpy as np
7 | from openpi_client import websocket_client_policy as _websocket_client_policy
8 | import tyro
9 |
10 |
11 | class EnvMode(enum.Enum):
12 | """Supported environments."""
13 |
14 | ALOHA = "aloha"
15 | ALOHA_SIM = "aloha_sim"
16 | DROID = "droid"
17 | LIBERO = "libero"
18 |
19 |
20 | @dataclasses.dataclass
21 | class Args:
22 | host: str = "0.0.0.0"
23 | port: int = 8000
24 |
25 | env: EnvMode = EnvMode.ALOHA_SIM
26 | num_steps: int = 10
27 |
28 |
29 | def main(args: Args) -> None:
30 | obs_fn = {
31 | EnvMode.ALOHA: _random_observation_aloha,
32 | EnvMode.ALOHA_SIM: _random_observation_aloha,
33 | EnvMode.DROID: _random_observation_droid,
34 | EnvMode.LIBERO: _random_observation_libero,
35 | }[args.env]
36 |
37 | policy = _websocket_client_policy.WebsocketClientPolicy(
38 | host=args.host,
39 | port=args.port,
40 | )
41 | logging.info(f"Server metadata: {policy.get_server_metadata()}")
42 |
43 | # Send 1 observation to make sure the model is loaded.
44 | policy.infer(obs_fn())
45 |
46 | start = time.time()
47 | for _ in range(args.num_steps):
48 | policy.infer(obs_fn())
49 | end = time.time()
50 |
51 | print(f"Total time taken: {end - start:.2f} s")
52 | print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms")
53 |
54 |
55 | def _random_observation_aloha() -> dict:
56 | return {
57 | "state": np.ones((14,)),
58 | "images": {
59 | "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
60 | "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
61 | "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
62 | "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
63 | },
64 | "prompt": "do something",
65 | }
66 |
67 |
68 | def _random_observation_droid() -> dict:
69 | return {
70 | "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
71 | "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
72 | "observation/joint_position": np.random.rand(7),
73 | "observation/gripper_position": np.random.rand(1),
74 | "prompt": "do something",
75 | }
76 |
77 |
78 | def _random_observation_libero() -> dict:
79 | return {
80 | "observation/state": np.random.rand(8),
81 | "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
82 | "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
83 | "prompt": "do something",
84 | }
85 |
86 |
87 | if __name__ == "__main__":
88 | logging.basicConfig(level=logging.INFO)
89 | main(tyro.cli(Args))
90 |
--------------------------------------------------------------------------------
/examples/simple_client/requirements.in:
--------------------------------------------------------------------------------
1 | numpy
2 | tyro
--------------------------------------------------------------------------------
/examples/simple_client/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
3 | backports-cached-property==1.0.2
4 | # via tyro
5 | docstring-parser==0.16
6 | # via tyro
7 | eval-type-backport==0.1.3
8 | # via tyro
9 | markdown-it-py==2.2.0
10 | # via rich
11 | mdurl==0.1.2
12 | # via markdown-it-py
13 | numpy==1.21.6
14 | # via -r examples/simple_client/requirements.in
15 | pygments==2.17.2
16 | # via rich
17 | rich==13.8.1
18 | # via tyro
19 | shtab==1.7.1
20 | # via tyro
21 | typing-extensions==4.7.1
22 | # via
23 | # markdown-it-py
24 | # rich
25 | # tyro
26 | tyro==0.9.1
27 | # via -r examples/simple_client/requirements.in
28 |
--------------------------------------------------------------------------------
/examples/umi/README.md:
--------------------------------------------------------------------------------
1 | # OneTwoVLA Policy Server
2 |
3 | Here are the instructions for running the policy server for OneTwoVLA. We provide the code to launch the UMI client in this [repository](https://github.com/Fanqi-Lin/OneTwoVLA-UMI-Client).
4 |
5 | ---
6 |
7 | ## Set Up the Policy Server
8 |
9 | First, install the required dependencies:
10 |
11 | ```bash
12 | uv pip install pynput
13 | ```
14 | > *Note: You may need `sudo` permissions.*
15 |
16 | Next, start the policy server on your desired port (e.g., `8000`):
17 |
18 | ```bash
19 | uv run scripts/serve_policy.py --port 8000 \
20 | policy:checkpoint \
21 | --policy.config=onetwovla_visual_grounding \
22 | --policy.dir=/path/to/your/checkpoint
23 | ```
24 |
25 | **Supported policy configurations:**
26 | - `onetwovla_visual_cocktail`
27 | - `onetwovla_visual_grounding`
28 | - `pi0_visual_cocktail`
29 | - `pi0_visual_grounding`
30 |
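31 | Once the server is running, you can verify that it is reachable from any machine with `openpi-client` installed. A minimal check (sketch; the host is a placeholder):
32 |
33 | ```python
34 | from openpi_client import websocket_client_policy
35 |
36 | # Blocks until the policy server accepts the connection, then prints its metadata.
37 | client = websocket_client_policy.WebsocketClientPolicy(host="<policy_server_ip>", port=8000)
38 | print(client.get_server_metadata())
39 | ```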
--------------------------------------------------------------------------------
/figures/fisheye-aug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/figures/fisheye-aug.png
--------------------------------------------------------------------------------
/packages/openpi-client/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "openpi-client"
3 | version = "0.1.0"
4 | requires-python = ">=3.7"
5 | dependencies = [
6 | "dm-tree>=0.1.8",
7 | "msgpack>=1.0.5",
8 | "numpy>=1.21.6",
9 | "pillow>=9.0.0",
10 | "tree>=0.2.4",
11 | "websockets>=11.0",
12 | ]
13 |
14 | [build-system]
15 | requires = ["hatchling"]
16 | build-backend = "hatchling.build"
17 |
18 | [tool.uv]
19 | dev-dependencies = [
20 | "pytest>=8.3.4",
21 | ]
22 |
23 | [tool.ruff]
24 | line-length = 120
25 | target-version = "py37"
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
2 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/action_chunk_broker.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 | import tree
5 | from typing_extensions import override
6 |
7 | from openpi_client import base_policy as _base_policy
8 |
9 |
10 | class ActionChunkBroker(_base_policy.BasePolicy):
11 | """Wraps a policy to return action chunks one-at-a-time.
12 |
13 | Assumes that the first dimension of all action fields is the chunk size.
14 |
15 | A new inference call to the inner policy is only made when the current
16 | list of chunks is exhausted.
17 | """
18 |
19 | def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
20 | self._policy = policy
21 |
22 | self._action_horizon = action_horizon
23 | self._cur_step: int = 0
24 |
25 | self._last_results: Dict[str, np.ndarray] | None = None
26 |
27 | @override
28 | def infer(self, obs: Dict) -> Dict: # noqa: UP006
29 | if self._last_results is None:
30 | self._last_results = self._policy.infer(obs)
31 | self._cur_step = 0
32 |
33 | results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)
34 | self._cur_step += 1
35 |
36 | if self._cur_step >= self._action_horizon:
37 | self._last_results = None
38 |
39 | return results
40 |
41 | @override
42 | def reset(self) -> None:
43 | self._policy.reset()
44 | self._last_results = None
45 | self._cur_step = 0
46 |
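47 | # Hypothetical usage sketch (not part of the module): wrap a remote policy so that
48 | # each infer() call returns a single action, re-querying the server every 10 steps.
49 | #
50 | #   from openpi_client import action_chunk_broker, websocket_client_policy
51 | #
52 | #   policy = websocket_client_policy.WebsocketClientPolicy(host="0.0.0.0", port=8000)
53 | #   broker = action_chunk_broker.ActionChunkBroker(policy, action_horizon=10)
54 | #   action = broker.infer(obs)  # one action per call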
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/base_policy.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Dict
3 |
4 |
5 | class BasePolicy(abc.ABC):
6 | @abc.abstractmethod
7 | def infer(self, obs: Dict) -> Dict:
8 | """Infer actions from observations."""
9 |
10 | def reset(self) -> None:
11 | """Reset the policy to its initial state."""
12 | pass
13 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/image_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | def convert_to_uint8(img: np.ndarray) -> np.ndarray:
6 | """Converts an image to uint8 if it is a float image.
7 |
8 | This is important for reducing the size of the image when sending it over the network.
9 | """
10 | if np.issubdtype(img.dtype, np.floating):
11 | img = (255 * img).astype(np.uint8)
12 | return img
13 |
14 |
15 | def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
16 | """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height and width without distortion by padding with zeros.
17 |
18 | Args:
19 | images: A batch of images in [..., height, width, channel] format.
20 | height: The target height of the image.
21 | width: The target width of the image.
22 | method: The interpolation method to use. Default is bilinear.
23 |
24 | Returns:
25 | The resized images in [..., height, width, channel].
26 | """
27 | # If the images are already the correct size, return them as is.
28 | if images.shape[-3:-1] == (height, width):
29 | return images
30 |
31 | original_shape = images.shape
32 |
33 | images = images.reshape(-1, *original_shape[-3:])
34 | resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
35 | return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
36 |
37 |
38 | def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
39 | """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
40 | width without distortion by padding with zeros.
41 |
42 | Note that, unlike the JAX version, PIL uses (width, height) size ordering rather than (batch, height, width, channel) arrays.
43 | """
44 | cur_width, cur_height = image.size
45 | if cur_width == width and cur_height == height:
46 | return image # No need to resize if the image is already the correct size.
47 |
48 | ratio = max(cur_width / width, cur_height / height)
49 | resized_height = int(cur_height / ratio)
50 | resized_width = int(cur_width / ratio)
51 | resized_image = image.resize((resized_width, resized_height), resample=method)
52 |
53 | zero_image = Image.new(resized_image.mode, (width, height), 0)
54 | pad_height = max(0, int((height - resized_height) / 2))
55 | pad_width = max(0, int((width - resized_width) / 2))
56 | zero_image.paste(resized_image, (pad_width, pad_height))
57 | assert zero_image.size == (width, height)
58 | return zero_image
59 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/image_tools_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import openpi_client.image_tools as image_tools
4 |
5 |
6 | def test_resize_with_pad_shapes():
7 | # Test case 1: Resize image with larger dimensions
8 | images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
9 | height = 20
10 | width = 20
11 | resized_images = image_tools.resize_with_pad(images, height, width)
12 | assert resized_images.shape == (2, height, width, 3)
13 | assert np.all(resized_images == 0)
14 |
15 | # Test case 2: Resize image with smaller dimensions
16 | images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
17 | height = 15
18 | width = 15
19 | resized_images = image_tools.resize_with_pad(images, height, width)
20 | assert resized_images.shape == (3, height, width, 3)
21 | assert np.all(resized_images == 0)
22 |
23 | # Test case 3: Resize image with the same dimensions
24 | images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
25 | height = 50
26 | width = 50
27 | resized_images = image_tools.resize_with_pad(images, height, width)
28 | assert resized_images.shape == (1, height, width, 3)
29 | assert np.all(resized_images == 0)
30 |
31 | # Test case 4: Resize image with odd-numbered padding
32 | images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
33 | height = 60
34 | width = 80
35 | resized_images = image_tools.resize_with_pad(images, height, width)
36 | assert resized_images.shape == (1, height, width, 3)
37 | assert np.all(resized_images == 0)
38 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/msgpack_numpy.py:
--------------------------------------------------------------------------------
1 | """Adds NumPy array support to msgpack.
2 |
3 | msgpack is good for (de)serializing data over a network for multiple reasons:
4 | - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
5 | - msgpack is widely used and has good cross-language support
6 | - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
7 | languages like Python and JavaScript
8 | - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
9 | than pickle for serializing large arrays using the below strategy
10 |
11 | The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
12 | that it falls back to pickle for object arrays.
13 | """
14 |
15 | import functools
16 |
17 | import msgpack
18 | import numpy as np
19 |
20 |
21 | def pack_array(obj):
22 | if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
23 | raise ValueError(f"Unsupported dtype: {obj.dtype}")
24 |
25 | if isinstance(obj, np.ndarray):
26 | return {
27 | b"__ndarray__": True,
28 | b"data": obj.tobytes(),
29 | b"dtype": obj.dtype.str,
30 | b"shape": obj.shape,
31 | }
32 |
33 | if isinstance(obj, np.generic):
34 | return {
35 | b"__npgeneric__": True,
36 | b"data": obj.item(),
37 | b"dtype": obj.dtype.str,
38 | }
39 |
40 | return obj
41 |
42 |
43 | def unpack_array(obj):
44 | if b"__ndarray__" in obj:
45 | return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
46 |
47 | if b"__npgeneric__" in obj:
48 | return np.dtype(obj[b"dtype"]).type(obj[b"data"])
49 |
50 | return obj
51 |
52 |
53 | Packer = functools.partial(msgpack.Packer, default=pack_array)
54 | packb = functools.partial(msgpack.packb, default=pack_array)
55 |
56 | Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
57 | unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
58 |
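59 | # Round-trip sketch: arrays survive serialization with dtype and shape intact.
60 | #
61 | #   buf = packb({"qpos": np.zeros(14, dtype=np.float32)})
62 | #   restored = unpackb(buf)  # {"qpos": float32 array of shape (14,)}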
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import tree
4 |
5 | from openpi_client import msgpack_numpy
6 |
7 |
8 | def _check(expected, actual):
9 | if isinstance(expected, np.ndarray):
10 | assert expected.shape == actual.shape
11 | assert expected.dtype == actual.dtype
12 | assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
13 | else:
14 | assert expected == actual
15 |
16 |
17 | @pytest.mark.parametrize(
18 | "data",
19 | [
20 | 1, # int
21 | 1.0, # float
22 | "hello", # string
23 | np.bool_(True), # boolean scalar
24 | np.array([1, 2, 3])[0], # int scalar
25 | np.str_("asdf"), # string scalar
26 | [1, 2, 3], # list
27 | {"key": "value"}, # dict
28 | {"key": [1, 2, 3]}, # nested dict
29 | np.array(1.0), # 0D array
30 | np.array([1, 2, 3], dtype=np.int32), # 1D integer array
31 | np.array(["asdf", "qwer"]), # string array
32 | np.array([True, False]), # boolean array
33 | np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
34 | np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
35 | np.array([np.nan, np.inf, -np.inf]), # special float values
36 | {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
37 | [np.array([1, 2]), np.array([3, 4])], # list of arrays
38 | np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
39 | np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
40 | ],
41 | )
42 | def test_pack_unpack(data):
43 | packed = msgpack_numpy.packb(data)
44 | unpacked = msgpack_numpy.unpackb(packed)
45 | tree.map_structure(_check, data, unpacked)
46 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/runtime/agent.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Agent(abc.ABC):
5 | """An Agent is the thing with agency, i.e. the entity that makes decisions.
6 |
7 | Agents receive observations about the state of the world, and return actions
8 | to take in response.
9 | """
10 |
11 | @abc.abstractmethod
12 | def get_action(self, observation: dict) -> dict:
13 | """Query the agent for the next action."""
14 |
15 | @abc.abstractmethod
16 | def reset(self) -> None:
17 | """Reset the agent to its initial state."""
18 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import override
2 |
3 | from openpi_client import base_policy as _base_policy
4 | from openpi_client.runtime import agent as _agent
5 |
6 |
7 | class PolicyAgent(_agent.Agent):
8 | """An agent that uses a policy to determine actions."""
9 |
10 | def __init__(self, policy: _base_policy.BasePolicy) -> None:
11 | self._policy = policy
12 |
13 | @override
14 | def get_action(self, observation: dict) -> dict:
15 | return self._policy.infer(observation)
16 |
17 | def reset(self) -> None:
18 | self._policy.reset()
19 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/runtime/environment.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Environment(abc.ABC):
5 | """An Environment represents the robot and the environment it inhabits.
6 |
7 | The primary contract of environments is that they can be queried for observations
8 | about their state, and have actions applied to them to change that state.
9 | """
10 |
11 | @abc.abstractmethod
12 | def reset(self) -> None:
13 | """Reset the environment to its initial state.
14 |
15 | This will be called once before starting each episode.
16 | """
17 |
18 | @abc.abstractmethod
19 | def is_episode_complete(self) -> bool:
20 | """Allow the environment to signal that the episode is complete.
21 |
22 | This will be called after each step. It should return `True` if the episode is
23 | complete (either successfully or unsuccessfully), and `False` otherwise.
24 | """
25 |
26 | @abc.abstractmethod
27 | def get_observation(self) -> dict:
28 | """Query the environment for the current state."""
29 |
30 | @abc.abstractmethod
31 | def apply_action(self, action: dict) -> None:
32 | """Take an action in the environment."""
33 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/runtime/runtime.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import threading
3 | import time
4 |
5 | from openpi_client.runtime import agent as _agent
6 | from openpi_client.runtime import environment as _environment
7 | from openpi_client.runtime import subscriber as _subscriber
8 |
9 |
10 | class Runtime:
11 | """The core module orchestrating interactions between key components of the system."""
12 |
13 | def __init__(
14 | self,
15 | environment: _environment.Environment,
16 | agent: _agent.Agent,
17 | subscribers: list[_subscriber.Subscriber],
18 | max_hz: float = 0,
19 | num_episodes: int = 1,
20 | max_episode_steps: int = 0,
21 | ) -> None:
22 | self._environment = environment
23 | self._agent = agent
24 | self._subscribers = subscribers
25 | self._max_hz = max_hz
26 | self._num_episodes = num_episodes
27 | self._max_episode_steps = max_episode_steps
28 |
29 | self._in_episode = False
30 | self._episode_steps = 0
31 |
32 | def run(self) -> None:
33 | """Runs the configured number of episodes; each episode ends when the environment reports completion, the step limit is reached, or mark_episode_complete() is called."""
34 | for _ in range(self._num_episodes):
35 | self._run_episode()
36 |
37 | # Final reset; for real environments this is important to move the robot back to its home position.
38 | self._environment.reset()
39 |
40 | def run_in_new_thread(self) -> threading.Thread:
41 | """Runs the runtime loop in a new thread."""
42 | thread = threading.Thread(target=self.run)
43 | thread.start()
44 | return thread
45 |
46 | def mark_episode_complete(self) -> None:
47 | """Marks the end of an episode."""
48 | self._in_episode = False
49 |
50 | def _run_episode(self) -> None:
51 | """Runs a single episode."""
52 | logging.info("Starting episode...")
53 | self._environment.reset()
54 | self._agent.reset()
55 | for subscriber in self._subscribers:
56 | subscriber.on_episode_start()
57 |
58 | self._in_episode = True
59 | self._episode_steps = 0
60 | step_time = 1 / self._max_hz if self._max_hz > 0 else 0
61 | last_step_time = time.time()
62 |
63 | while self._in_episode:
64 | self._step()
65 | self._episode_steps += 1
66 |
67 | # Sleep to maintain the desired frame rate
68 | now = time.time()
69 | dt = now - last_step_time
70 | if dt < step_time:
71 | time.sleep(step_time - dt)
72 | last_step_time = time.time()
73 | else:
74 | last_step_time = now
75 |
76 | logging.info("Episode completed.")
77 | for subscriber in self._subscribers:
78 | subscriber.on_episode_end()
79 |
80 | def _step(self) -> None:
81 | """A single step of the runtime loop."""
82 | observation = self._environment.get_observation()
83 | action = self._agent.get_action(observation)
84 | self._environment.apply_action(action)
85 |
86 | for subscriber in self._subscribers:
87 | subscriber.on_step(observation, action)
88 |
89 | if self._environment.is_episode_complete() or (
90 | self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
91 | ):
92 | self.mark_episode_complete()
93 |
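94 | # Hypothetical wiring sketch (`MyEnvironment` is a stand-in for an Environment subclass):
95 | #
96 | #   from openpi_client import websocket_client_policy
97 | #   from openpi_client.runtime import runtime as _runtime
98 | #   from openpi_client.runtime.agents import policy_agent as _policy_agent
99 | #
100 | #   agent = _policy_agent.PolicyAgent(
101 | #       websocket_client_policy.WebsocketClientPolicy(host="0.0.0.0", port=8000))
102 | #   _runtime.Runtime(environment=MyEnvironment(), agent=agent, subscribers=[], max_hz=15).run()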
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/runtime/subscriber.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Subscriber(abc.ABC):
5 | """Subscribes to events in the runtime.
6 |
7 | Subscribers can be used to save data, visualize, etc.
8 | """
9 |
10 | @abc.abstractmethod
11 | def on_episode_start(self) -> None:
12 | """Called when an episode starts."""
13 |
14 | @abc.abstractmethod
15 | def on_step(self, observation: dict, action: dict) -> None:
16 | """Append a step to the episode."""
17 |
18 | @abc.abstractmethod
19 | def on_episode_end(self) -> None:
20 | """Called when an episode ends."""
21 |
--------------------------------------------------------------------------------
/packages/openpi-client/src/openpi_client/websocket_client_policy.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from typing import Dict, Tuple
4 |
5 | import websockets.sync.client
6 | from typing_extensions import override
7 |
8 | from openpi_client import base_policy as _base_policy
9 | from openpi_client import msgpack_numpy
10 |
11 |
12 | class WebsocketClientPolicy(_base_policy.BasePolicy):
13 | """Implements the Policy interface by communicating with a server over websocket.
14 |
15 | See WebsocketPolicyServer for a corresponding server implementation.
16 | """
17 |
18 | def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
19 | self._uri = f"ws://{host}:{port}"
20 | self._packer = msgpack_numpy.Packer()
21 | self._ws, self._server_metadata = self._wait_for_server()
22 |
23 | def get_server_metadata(self) -> Dict:
24 | return self._server_metadata
25 |
26 | def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
27 | logging.info(f"Waiting for server at {self._uri}...")
28 | while True:
29 | try:
30 | conn = websockets.sync.client.connect(self._uri, compression=None, max_size=None)
31 | metadata = msgpack_numpy.unpackb(conn.recv())
32 | return conn, metadata
33 | except ConnectionRefusedError:
34 | logging.info("Still waiting for server...")
35 | time.sleep(5)
36 |
37 | @override
38 | def infer(self, obs: Dict) -> Dict: # noqa: UP006
39 | data = self._packer.pack(obs)
40 | self._ws.send(data)
41 | response = self._ws.recv()
42 | if isinstance(response, str):
43 | # we're expecting bytes; if the server sends a string, it's an error.
44 | raise RuntimeError(f"Error in inference server:\n{response}")
45 | return msgpack_numpy.unpackb(response)
46 |
47 | @override
48 | def reset(self) -> None:
49 | pass
50 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "openpi"
3 | version = "0.1.0"
4 | description = "Physical Intelligence open source repo"
5 | readme = "README.md"
6 | requires-python = ">=3.11"
7 | license = { file = "LICENSE" }
8 | dependencies = [
9 | "augmax>=0.3.4",
10 | "dm-tree>=0.1.8",
11 | "einops>=0.8.0",
12 | "equinox>=0.11.8",
13 | "flatbuffers>=24.3.25",
14 | "flax==0.10.2",
15 | "fsspec[gcs]>=2024.6.0",
16 | "gym-aloha>=0.1.1",
17 | "imageio>=2.36.1",
18 | "jax[cuda12]==0.5.0",
19 | "jaxtyping==0.2.36",
20 | "lerobot",
21 | "ml_collections==1.0.0",
22 | "numpy>=1.26.4",
23 | "numpydantic>=1.6.6",
24 | "opencv-python>=4.10.0.84",
25 | "openpi-client",
26 | "orbax-checkpoint==0.11.1",
27 | "pillow>=11.0.0",
28 | "s3fs>=2024.9.0",
29 | "sentencepiece>=0.2.0",
30 | "torch>=2.5.1",
31 | "tqdm-loggable>=0.2",
32 | "typing-extensions>=4.12.2",
33 | "tyro>=0.9.5",
34 | "wandb==0.18.0",
35 | "boto3>=1.35.7",
36 | "types-boto3[boto3,s3]>=1.35.7",
37 | "filelock>=3.16.1",
38 | "beartype>=0.19.0",
39 | "treescope>=0.1.7",
40 | "transformers==4.48.1",
41 | "imagecodecs>=2024.12.30",
42 | ]
43 |
44 |
45 | [project.urls]
46 | Repository = "https://github.com/Richard-coder-Nai/onetwovla.git"
47 |
48 | [dependency-groups]
49 | dev = [
50 | "pytest>=8.3.4",
51 | "ruff>=0.8.6",
52 | "pre-commit>=4.0.1",
53 | "ipykernel>=6.29.5",
54 | "ipywidgets>=8.1.5",
55 | "matplotlib>=3.10.0",
56 | "pynvml>=12.0.0",
57 | ]
58 |
59 |
60 | [tool.uv.sources]
61 | openpi-client = { workspace = true }
62 | lerobot = { git = "https://github.com/huggingface/lerobot", rev = "6674e368249472c91382eb54bb8501c94c7f0c56" }
63 |
64 | [tool.uv.workspace]
65 | members = ["packages/*"]
66 |
67 | [[tool.uv.index]]
68 | url = "https://pypi.tuna.tsinghua.edu.cn/simple"
69 |
70 | [tool.ruff]
71 | line-length = 120
72 | target-version = "py311"
73 | extend-exclude = ["docker", "third_party"]
74 |
75 | [tool.ruff.lint]
76 | # https://docs.astral.sh/ruff/rules/
77 | select = [
78 | "B",
79 | "C4",
80 | "DTZ",
81 | "E4",
82 | "E7",
83 | "E9",
84 | "F",
85 | "FBT",
86 | "FURB",
87 | "I",
88 | "ICN",
89 | "ISC",
90 | "LOG",
91 | "N",
92 | "PD",
93 | "PERF",
94 | "PIE",
95 | "PLC",
96 | "PLE",
97 | "PLR1",
98 | "PLR5",
99 | "PLW",
100 | "PT",
101 | "PTH",
102 | "Q",
103 | "RET",
104 | "RUF",
105 | "SIM",
106 | "SLF",
107 | "T10",
108 | "T20",
109 | "UP",
110 | "W",
111 | ]
112 | ignore = [
113 | "F722", # Conflicts with array typing.
114 | "T201", # We use print statements.
115 | "PD008", # Lots of false positives.
116 | "ISC001", # Disabling to support ruff format.
117 | ]
118 | unfixable = [
119 | "B905", # Fix defaults to strict=False, which is not what we want.
120 | ]
121 |
122 | [tool.ruff.lint.isort]
123 | force-single-line = true
124 | force-sort-within-sections = true
125 | single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
126 | known-third-party = ["wandb"]
127 |
128 | [build-system]
129 | requires = ["hatchling"]
130 | build-backend = "hatchling.build"
131 |
132 | [tool.pytest.ini_options]
133 | markers = ["manual: should be run manually."]
134 | testpaths = ["src", "scripts", "packages"]
135 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/augment_vl_data/finger_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/finger_mask.jpg
--------------------------------------------------------------------------------
/scripts/augment_vl_data/gripper_lens_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/gripper_lens_mask.jpg
--------------------------------------------------------------------------------
/scripts/augment_vl_data/inpainted.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/inpainted.jpg
--------------------------------------------------------------------------------
/scripts/augment_vl_data/lens.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/scripts/augment_vl_data/lens.jpg
--------------------------------------------------------------------------------
/scripts/compute_norm_stats.py:
--------------------------------------------------------------------------------
1 | """Compute normalization statistics for a config.
2 |
3 | This script is used to compute the normalization statistics for a given config. It
4 | will compute the mean and standard deviation of the data in the dataset and save it
5 | to the config assets directory.
6 | """
7 |
8 | import numpy as np
9 | import tqdm
10 | import tyro
11 |
12 | import openpi.shared.normalize as normalize
13 | import openpi.training.config as _config
14 | import openpi.training.data_loader as _data_loader
15 | import openpi.transforms as transforms
16 |
17 |
18 | class RemoveStrings(transforms.DataTransformFn):
19 | def __call__(self, x: dict) -> dict:
20 | return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
21 |
22 |
23 | def create_dataset(config: _config.TrainConfig) -> tuple[_config.DataConfig, _data_loader.Dataset]:
24 | data_config = config.data.create(config.assets_dirs, config.model)
25 | if data_config.repo_id is None:
26 | raise ValueError("Data config must have a repo_id")
27 | dataset, _ = _data_loader.create_dataset(data_config, config.model)
28 | dataset = _data_loader.TransformedDataset(
29 | dataset,
30 | [
31 | *data_config.repack_transforms.inputs,
32 | *data_config.data_transforms.inputs,
33 | # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
34 | RemoveStrings(),
35 | ],
36 | )
37 | return data_config, dataset
38 |
39 |
40 | def main(config: _config.TrainConfig, max_frames: int | None = None):
41 | data_config, dataset = create_dataset(config)
42 |
43 | num_frames = len(dataset)
44 | shuffle = False
45 |
46 | if max_frames is not None and max_frames < num_frames:
47 | num_frames = max_frames
48 | shuffle = True
49 |
50 | data_loader = _data_loader.TorchDataLoader(
51 | dataset,
52 | local_batch_size=1,
53 | num_workers=8,
54 | shuffle=shuffle,
55 | num_batches=num_frames,
56 | )
57 |
58 | keys = ["state", "actions"]
59 | stats = {key: normalize.RunningStats() for key in keys}
60 |
61 | for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
62 | for key in keys:
63 | values = np.asarray(batch[key][0])
64 | stats[key].update(values.reshape(-1, values.shape[-1]))
65 |
66 |     norm_stats = {key: s.get_statistics() for key, s in stats.items()}
67 |
68 | output_path = config.assets_dirs / data_config.repo_id
69 | if hasattr(data_config, "getitem_type"):
70 | output_path = config.assets_dirs / data_config.repo_id / data_config.getitem_type
71 |
72 | print(f"Writing stats to: {output_path}")
73 | normalize.save(output_path, norm_stats)
74 |
75 |
76 | if __name__ == "__main__":
77 | main(_config.cli())
78 |
--------------------------------------------------------------------------------
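The script above only relies on `normalize.RunningStats` exposing `update()` and `get_statistics()` (both called in `main`); the class internals are not shown in this dump. A minimal sketch of the kind of streaming per-dimension accumulator it implies — the class name and internals here are assumptions, not the actual openpi implementation:

```python
import numpy as np


class RunningStatsSketch:
    """Hypothetical stand-in for openpi.shared.normalize.RunningStats.

    Accumulates per-dimension mean/std over batches of shape (N, D).
    """

    def __init__(self):
        self._count = 0
        self._sum = None     # per-dimension running sum
        self._sum_sq = None  # per-dimension running sum of squares

    def update(self, batch: np.ndarray) -> None:
        batch = np.asarray(batch, dtype=np.float64)
        if self._sum is None:
            self._sum = np.zeros(batch.shape[-1])
            self._sum_sq = np.zeros(batch.shape[-1])
        self._count += batch.shape[0]
        self._sum += batch.sum(axis=0)
        self._sum_sq += (batch**2).sum(axis=0)

    def get_statistics(self) -> dict:
        mean = self._sum / self._count
        var = self._sum_sq / self._count - mean**2
        return {"mean": mean, "std": np.sqrt(np.maximum(var, 0.0))}
```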
/scripts/docker/compose.yml:
--------------------------------------------------------------------------------
1 | # Run with:
2 | # docker compose -f scripts/docker/compose.yml up --build
3 | services:
4 | openpi_server:
5 | image: openpi_server
6 | build:
7 | context: ..
8 | dockerfile: scripts/docker/serve_policy.Dockerfile
9 | init: true
10 | tty: true
11 | network_mode: host
12 |     # Mount the repository at /app and the configured openpi data home at
13 |     # /openpi_assets inside the container.
14 | volumes:
15 | - $PWD:/app
16 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
17 | environment:
18 | - SERVER_ARGS
19 | - OPENPI_DATA_HOME=/openpi_assets
20 | - IS_DOCKER=true
21 |
22 | # Comment out this block if not running on a machine with GPUs.
23 | deploy:
24 | resources:
25 | reservations:
26 | devices:
27 | - driver: nvidia
28 | count: 1
29 | capabilities: [gpu]
30 |
--------------------------------------------------------------------------------
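Usage note: since the Dockerfile's `CMD` forwards `$SERVER_ARGS` to `scripts/serve_policy.py` and the compose file passes the variable through, extra server flags can presumably be supplied via the environment, e.g. `SERVER_ARGS="--env LIBERO" docker compose -f scripts/docker/compose.yml up --build` (the flag spelling is an assumption based on the `Args` dataclass in serve_policy.py below).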
/scripts/docker/install_docker_ubuntu22.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Add Docker's official GPG key:
4 | sudo apt-get update
5 | sudo apt-get install -y ca-certificates curl
6 | sudo install -m 0755 -d /etc/apt/keyrings
7 | sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
8 | sudo chmod a+r /etc/apt/keyrings/docker.asc
9 |
10 | # Add the repository to Apt sources:
11 | echo \
12 | "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
13 | $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
14 | sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
15 | sudo apt-get update
16 |
17 | sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
18 |
19 | # Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
20 | # See https://docs.docker.com/engine/install/linux-postinstall/
21 | username=$(whoami)
22 | sudo usermod -aG docker $username
23 |
24 | # Configure docker to start automatically on system boot.
25 | sudo systemctl enable docker.service
26 | sudo systemctl enable containerd.service
27 |
28 | # https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
29 | if [ -f ~/.docker/config.json ]; then
30 | sed -i 's/credsStore/credStore/g' ~/.docker/config.json
31 | fi
32 |
33 | echo ""
34 | echo "********************************************************************"
35 | echo "**** Restart to allow Docker permission changes to take effect. ****"
36 | echo "********************************************************************"
37 | echo ""
38 |
--------------------------------------------------------------------------------
/scripts/docker/install_nvidia_container_toolkit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
4 | # NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
5 |
6 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
7 | curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
8 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
9 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
10 |
11 | # NVIDIA's documentation omits 'sudo' in the following command, but it is required.
12 | sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
13 | sudo apt-get update
14 | sudo apt-get install -y nvidia-container-toolkit
15 |
16 | sudo nvidia-ctk runtime configure --runtime=docker
17 | sudo systemctl restart docker
18 |
--------------------------------------------------------------------------------
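After installation, a common sanity check (not part of this script) is to run a CUDA container and confirm the GPU is visible, e.g. `docker run --rm --gpus=all nvidia/cuda:12.2.2-base-ubuntu22.04 nvidia-smi`.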
/scripts/docker/serve_policy.Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for serving a PI policy.
2 | # Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
3 |
4 | # Build the container:
5 | # docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
6 |
7 | # Run the container:
8 | # docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
9 |
10 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
11 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
12 |
13 | WORKDIR /app
14 |
15 | # Needed because LeRobot uses git-lfs.
16 | RUN apt-get update && apt-get install -y git git-lfs
17 |
18 | # Copy from the cache instead of linking since it's a mounted volume
19 | ENV UV_LINK_MODE=copy
20 |
21 | # Write the virtual environment outside of the project directory so it doesn't
22 | # leak out of the container when we mount the application code.
23 | ENV UV_PROJECT_ENVIRONMENT=/.venv
24 |
25 | # Install the project's dependencies using the lockfile and settings
26 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
27 | RUN --mount=type=cache,target=/root/.cache/uv \
28 | --mount=type=bind,source=uv.lock,target=uv.lock \
29 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
30 | --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
31 | --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
32 | GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
33 |
34 | CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
35 |
--------------------------------------------------------------------------------
/scripts/serve_policy.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import enum
3 | import logging
4 | import socket
5 |
6 | import tyro
7 |
8 | from openpi.policies import policy as _policy
9 | from openpi.policies import policy_config as _policy_config
10 | from openpi.serving import websocket_policy_server
11 | from openpi.training import config as _config
12 |
13 |
14 | class EnvMode(enum.Enum):
15 | """Supported environments."""
16 |
17 | ALOHA = "aloha"
18 | ALOHA_SIM = "aloha_sim"
19 | DROID = "droid"
20 | LIBERO = "libero"
21 | FAST_BASE = "fast_base"
22 | BASE = "base"
23 |
24 |
25 | @dataclasses.dataclass
26 | class Checkpoint:
27 | """Load a policy from a trained checkpoint."""
28 |
29 | # Training config name (e.g., "pi0_aloha_sim").
30 | config: str
31 | # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
32 | dir: str
33 |
34 |
35 | @dataclasses.dataclass
36 | class Default:
37 | """Use the default policy for the given environment."""
38 |
39 |
40 | @dataclasses.dataclass
41 | class Args:
42 | """Arguments for the serve_policy script."""
43 |
44 | # Environment to serve the policy for. This is only used when serving default policies.
45 | env: EnvMode = EnvMode.ALOHA_SIM
46 |
47 | # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
48 | # prompt.
49 | default_prompt: str | None = None
50 |
51 | # Port to serve the policy on.
52 | port: int = 8000
53 | # Record the policy's behavior for debugging.
54 | record: bool = False
55 |
56 | # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
57 | policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
58 |
59 |
60 | # Default checkpoints that should be used for each environment.
61 | DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
62 | EnvMode.ALOHA: Checkpoint(
63 | config="pi0_aloha",
64 | dir="s3://openpi-assets/checkpoints/pi0_base",
65 | ),
66 | EnvMode.ALOHA_SIM: Checkpoint(
67 | config="pi0_aloha_sim",
68 | dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",
69 | ),
70 | EnvMode.DROID: Checkpoint(
71 | config="pi0_fast_droid",
72 | dir="s3://openpi-assets/checkpoints/pi0_fast_droid",
73 | ),
74 | EnvMode.LIBERO: Checkpoint(
75 | config="pi0_fast_libero",
76 | dir="s3://openpi-assets/checkpoints/pi0_fast_libero",
77 | ),
78 | EnvMode.FAST_BASE: Checkpoint(
79 | config="pi0_fast_umi",
80 | dir="s3://openpi-assets/checkpoints/pi0_fast_base",
81 | ),
82 | EnvMode.BASE: Checkpoint(
83 | config="pi0_umi",
84 | dir="s3://openpi-assets/checkpoints/pi0_base",
85 | ),
86 | }
87 |
88 |
89 | def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
90 | """Create a default policy for the given environment."""
91 | if checkpoint := DEFAULT_CHECKPOINT.get(env):
92 | return _policy_config.create_trained_policy(
93 | _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
94 | )
95 | raise ValueError(f"Unsupported environment mode: {env}")
96 |
97 |
98 | def create_policy(args: Args) -> _policy.Policy:
99 | """Create a policy from the given arguments."""
100 | match args.policy:
101 | case Checkpoint():
102 | return _policy_config.create_trained_policy(
103 | _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
104 | )
105 | case Default():
106 | return create_default_policy(args.env, default_prompt=args.default_prompt)
107 |
108 |
109 | def main(args: Args) -> None:
110 | policy = create_policy(args)
111 | policy_metadata = policy.metadata
112 |
113 | # Record the policy's behavior.
114 | if args.record:
115 | policy = _policy.PolicyRecorder(policy, "policy_records")
116 |
117 | hostname = socket.gethostname()
118 | local_ip = socket.gethostbyname(hostname)
119 | logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
120 |
121 | server = websocket_policy_server.WebsocketPolicyServer(
122 | policy=policy,
123 | host="0.0.0.0",
124 | port=args.port,
125 | metadata=policy_metadata,
126 | )
127 | server.serve_forever()
128 |
129 |
130 | if __name__ == "__main__":
131 | logging.basicConfig(level=logging.INFO, force=True)
132 | main(tyro.cli(Args))
133 |
--------------------------------------------------------------------------------
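Because `Args.policy` is a `Checkpoint | Default` union parsed by tyro, serving a specific checkpoint presumably uses tyro's subcommand syntax, e.g. `uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero --policy.dir=checkpoints/pi0_fast_libero/exp/10000`, while the default-policy path is simply `uv run scripts/serve_policy.py --env LIBERO`. The exact flag spellings are an assumption based on the dataclasses above, not commands quoted from the repo.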
/scripts/train_test.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import os
3 | import pathlib
4 |
5 | import pytest
6 |
7 | os.environ["JAX_PLATFORMS"] = "cpu"
8 |
9 | from openpi.training import config as _config
10 |
11 | from . import train
12 |
13 |
14 | @pytest.mark.parametrize("config_name", ["debug"])
15 | def test_train(tmp_path: pathlib.Path, config_name: str):
16 | config = dataclasses.replace(
17 | _config._CONFIGS_DICT[config_name], # noqa: SLF001
18 | batch_size=2,
19 | checkpoint_base_dir=tmp_path / "checkpoint",
20 | exp_name="test",
21 | overwrite=False,
22 | resume=False,
23 | num_train_steps=2,
24 | log_interval=1,
25 | )
26 | train.main(config)
27 |
28 | # test resuming
29 | config = dataclasses.replace(config, resume=True, num_train_steps=4)
30 | train.main(config)
31 |
--------------------------------------------------------------------------------
/src/openpi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/__init__.py
--------------------------------------------------------------------------------
/src/openpi/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pynvml
4 | import pytest
5 |
6 |
7 | def set_jax_cpu_backend_if_no_gpu() -> None:
8 | try:
9 | pynvml.nvmlInit()
10 | pynvml.nvmlShutdown()
11 | except pynvml.NVMLError:
12 | # No GPU found.
13 | os.environ["JAX_PLATFORMS"] = "cpu"
14 |
15 |
16 | def pytest_configure(config: pytest.Config) -> None:
17 | set_jax_cpu_backend_if_no_gpu()
18 |
--------------------------------------------------------------------------------
/src/openpi/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/models/__init__.py
--------------------------------------------------------------------------------
/src/openpi/models/lora.py:
--------------------------------------------------------------------------------
1 | import math
2 | import re
3 |
4 | import flax.linen as nn
5 | import flax.struct as struct
6 | import jax.numpy as jnp
7 |
8 | import openpi.shared.array_typing as at
9 |
10 |
11 | @struct.dataclass
12 | class LoRAConfig:
13 | """Configuration for LoRA."""
14 |
15 | # LoRA rank.
16 | rank: int
17 | # LoRA scaling factor.
18 | alpha: float = 1.0
19 | # Initialization function for LoRA parameters.
20 | init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
21 | # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
22 | rslora: bool = False
23 | # Axes in the weight to apply LoRA to. Should typically be the last two axes.
24 | axes: tuple[int, int] = (-2, -1)
25 | # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
26 | label: str = "L"
27 |
28 | @property
29 | def scaling_value(self) -> float:
30 | return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
31 |
32 |
33 | class Einsum(nn.Module):
34 | """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
35 |
36 | # Shape of the weight.
37 | shape: tuple[int, ...]
38 | # Initialization function for the weight.
39 | init_fn: nn.initializers.Initializer = nn.initializers.zeros
40 | # If not None, apply LoRA to the weight.
41 | lora_config: LoRAConfig | None = None
42 |
43 | def setup(self):
44 | self.w = self.param("w", self.init_fn, self.shape)
45 |
46 | if config := self.lora_config:
47 | # Setup LoRA parameters.
48 | shape_a, shape_b = list(self.shape), list(self.shape)
49 | shape_a[config.axes[1]] = config.rank
50 | shape_b[config.axes[0]] = config.rank
51 | self.w_a = self.param("lora_a", config.init_fn, shape_a)
52 | self.w_b = self.param("lora_b", config.init_fn, shape_b)
53 |
54 | @nn.compact
55 | def __call__(self, eqn: str, x):
56 | dtype = x.dtype # original dtype, could be half-precision
57 | result = jnp.einsum(eqn, x, self.w.astype(dtype))
58 |
59 | if config := self.lora_config:
60 | eqn_a, eqn_b = self._make_lora_eqns(eqn)
61 | lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
62 | lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
63 | result = result + lora * config.scaling_value
64 |
65 | return result
66 |
67 | def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
68 | if "L" in eqn:
69 | raise ValueError(f"L already in eqn: {eqn}")
70 | if not (m := re.match("(.*),(.*)->(.*)", eqn)):
71 | raise ValueError(f"Unsupported einsum eqn: {eqn}")
72 | lhs, rhs, out = m.groups()
73 |
74 | assert self.lora_config is not None
75 | a_label, b_label = (rhs[x] for x in self.lora_config.axes)
76 | label = self.lora_config.label
77 |
78 | a_rhs = rhs.replace(b_label, label)
79 | a_out = out.replace(b_label, label)
80 | eqn_a = f"{lhs},{a_rhs}->{a_out}"
81 |
82 | b_rhs = rhs.replace(a_label, label)
83 | eqn_b = f"{a_out},{b_rhs}->{out}"
84 |
85 | return eqn_a, eqn_b
86 |
87 |
88 | class FeedForward(nn.Module):
89 | """Feed forward module."""
90 |
91 | features: int
92 | hidden_dim: int
93 | # If not None, apply LoRA to the weight.
94 | lora_config: LoRAConfig | None = None
95 |
96 | def setup(self):
97 | self.w_gating = self.param(
98 | "gating_einsum",
99 | nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
100 | (2, self.features, self.hidden_dim),
101 | )
102 | self.w_linear = self.param(
103 | "linear",
104 | nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
105 | (self.hidden_dim, self.features),
106 | )
107 | self.w_gating_lora = None
108 | self.w_linear_lora = None
109 | if self.lora_config:
110 | # Setup LoRA parameters.
111 | # TODO: follow up with a simplified init_fn api.
112 | self.w_gating_lora = (
113 | self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
114 | self.param(
115 | "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
116 | ),
117 | )
118 | self.w_linear_lora = (
119 | self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
120 | self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
121 | )
122 |
123 | @nn.compact
124 | def __call__(self, x):
125 | dtype = x.dtype # original dtype, could be half-precision
126 | ff_gate = self._dot(
127 | x,
128 | self.w_gating[0],
129 | None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
130 | )
131 | gate_value = nn.gelu(ff_gate)
132 |
133 | ff1 = self._dot(
134 | x,
135 | self.w_gating[1],
136 | None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
137 | )
138 | activations = gate_value * ff1
139 |
140 | outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
141 | assert outputs.dtype == dtype
142 | return outputs
143 |
144 | def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
145 | base = jnp.dot(x, w.astype(x.dtype))
146 | if lora_weights is None:
147 | return base
148 | return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
149 |
--------------------------------------------------------------------------------
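The einsum bookkeeping in `Einsum` implements the usual LoRA update y = xW + s·(xA)B with s = alpha/rank (or alpha/√rank when `rslora` is set). A minimal numeric sketch of that decomposition, stripped of the Flax module machinery (variable names here are illustrative):

```python
import jax
import jax.numpy as jnp

d_in, d_out, rank, alpha = 32, 16, 2, 1.0
k_w, k_a, k_x = jax.random.split(jax.random.key(0), 3)

w = jax.random.normal(k_w, (d_in, d_out))          # frozen base weight
lora_a = jax.random.normal(k_a, (d_in, rank)) * 0.01
lora_b = jnp.zeros((rank, d_out))                  # zero init => adapter starts as a no-op

x = jax.random.normal(k_x, (8, d_in))
scaling = alpha / rank                             # LoRAConfig.scaling_value with rslora=False
y = x @ w + scaling * (x @ lora_a) @ lora_b

# With one factor zero-initialized, the adapted layer matches the base layer
# exactly; test_lora_einsum_same_output below checks the same property by
# zero-initializing both factors via init_fn.
assert jnp.allclose(y, x @ w)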
/src/openpi/models/lora_test.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | import openpi.models.lora as lora
6 |
7 |
8 | def test_lora_einsum_params_shape():
9 | shape = (3, 8, 32, 4) # (3KDH)
10 | einsum = lora.Einsum(shape)
11 | lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
12 | lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
13 |
14 | key = jax.random.key(0)
15 | x = jax.random.normal(key, (8, 64, 32)) # (BSD)
16 | eqn = "BSD,3KDH->3BSKH"
17 |
18 | # Ensure that lora parameters are not initialized when LoRA is not used.
19 | params = einsum.init(key, eqn, x)
20 | assert "lora_a" not in params["params"]
21 | assert "lora_b" not in params["params"]
22 |
23 | # Check that default axes work.
24 | params_lora0 = lora0.init(key, eqn, x)
25 | assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
26 | assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)
27 |
28 | # Check that user provided axes work.
29 | params_lora1 = lora1.init(key, eqn, x)
30 | assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
31 | assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
32 |
33 |
34 | def test_lora_einsum_same_output():
35 | shape = (3, 8, 32, 4) # (3KDH)
36 | einsum = lora.Einsum(shape)
37 | einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
38 |
39 | key = jax.random.key(0)
40 | x = jax.random.normal(key, (8, 64, 32)) # (BSD)
41 | eqn = "BSD,3KDH->3BSKH"
42 |
43 | params = einsum.init(key, eqn, x)
44 | output = einsum.apply(params, eqn, x)
45 |
46 | params_lora = einsum_lora.init(key, eqn, x)
47 | output_lora = einsum_lora.apply(params_lora, eqn, x)
48 |
49 | # Results are the same since the LoRA parameters are initialized to zeros.
50 | assert jnp.allclose(output, output_lora)
51 |
52 |
53 | def test_lora_ffn_params_shape():
54 | ffn = lora.FeedForward(features=8, hidden_dim=32)
55 | ffn_lora = lora.FeedForward(
56 | features=8,
57 | hidden_dim=32,
58 | lora_config=lora.LoRAConfig(rank=2),
59 | )
60 |
61 | key = jax.random.key(0)
62 | x = jax.random.normal(key, (2, 8))
63 |
64 | params = ffn.init(key, x)
65 | assert params["params"]["gating_einsum"].shape == (2, 8, 32)
66 | assert params["params"]["linear"].shape == (32, 8)
67 |
68 | params_lora = ffn_lora.init(key, x)
69 | assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
70 | assert params_lora["params"]["linear"].shape == (32, 8)
71 | assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
72 | assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
73 | assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
74 | assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
75 |
76 |
77 | def test_lora_ffn_same_output():
78 | ffn = lora.FeedForward(features=8, hidden_dim=32)
79 | ffn_lora = lora.FeedForward(
80 | features=8,
81 | hidden_dim=32,
82 | lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
83 | )
84 |
85 | key = jax.random.key(0)
86 | x = jax.random.normal(key, (2, 8))
87 |
88 | params = ffn.init(key, x)
89 | output = ffn.apply(params, x)
90 |
91 | params_lora = ffn_lora.init(key, x)
92 | output_lora = ffn_lora.apply(params_lora, x)
93 |
94 | assert jnp.allclose(output, output_lora)
95 |
--------------------------------------------------------------------------------
/src/openpi/models/model_test.py:
--------------------------------------------------------------------------------
1 | from flax import nnx
2 | import jax
3 | import pytest
4 |
5 | from openpi.models import model as _model
6 | from openpi.models import pi0
7 | from openpi.models import pi0_fast
8 | from openpi.models import pi0_fuse
9 | from openpi.shared import download
10 | from openpi.shared import nnx_utils
11 |
12 |
13 | def test_pi0_model():
14 | key = jax.random.key(0)
15 | config = pi0.Pi0Config()
16 | model = config.create(key)
17 |
18 | batch_size = 2
19 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
20 |
21 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
22 | assert loss.shape == (batch_size, config.action_horizon)
23 |
24 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
25 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
26 |
27 |
28 | def test_pi0_lora_model():
29 | key = jax.random.key(0)
30 | config = pi0.Pi0Config(paligemma_variant="gemma_2b_lora")
31 | model = config.create(key)
32 |
33 | batch_size = 2
34 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
35 |
36 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
37 | assert loss.shape == (batch_size, config.action_horizon)
38 |
39 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
40 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
41 |
42 |
43 | def test_pi0_fast_model():
44 | key = jax.random.key(0)
45 | config = pi0_fast.Pi0FASTConfig()
46 | model = config.create(key)
47 |
48 | batch_size = 2
49 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
50 |
51 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
52 | assert loss.shape == (batch_size,)
53 |
54 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs)
55 | assert actions.shape == (batch_size, 256)
56 |
57 |
58 | def test_pi0_fast_lora_model():
59 | key = jax.random.key(0)
60 | config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
61 | model = config.create(key)
62 |
63 | batch_size = 2
64 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
65 |
66 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
67 | assert loss.shape == (batch_size,)
68 |
69 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs)
70 | assert actions.shape == (batch_size, 256)
71 |
72 | lora_filter = nnx_utils.PathRegex(".*lora.*")
73 | model_state = nnx.state(model)
74 |
75 | lora_state_elems = list(model_state.filter(lora_filter))
76 | assert len(lora_state_elems) > 0
77 |
78 |
79 | def test_pi0_fuse_model():
80 | key = jax.random.key(0)
81 | config = pi0_fuse.Pi0FuseConfig()
82 | model = config.create(key)
83 |
84 | batch_size = 2
85 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
86 |
87 | loss, loss_info = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
88 | assert loss.shape == (batch_size,)
89 |
90 | actions, action_info = nnx_utils.module_jit(model.sample_actions)(key, obs)
91 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
92 |
93 |
94 | @pytest.mark.manual
95 | def test_model_restore():
96 | key = jax.random.key(0)
97 | config = pi0.Pi0Config()
98 |
99 | batch_size = 2
100 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
101 |
102 | model = config.load(
103 | _model.restore_params(download.maybe_download("s3://openpi-assets/checkpoints/pi0_base/params"))
104 | )
105 |
106 | loss, loss_info = model.compute_loss(key, obs, act)
107 | assert loss.shape == (batch_size, config.action_horizon)
108 |
109 | actions, action_info = model.sample_actions(key, obs, num_steps=10)
110 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
111 |
--------------------------------------------------------------------------------
/src/openpi/models/pi0_test.py:
--------------------------------------------------------------------------------
1 | import flax.nnx as nnx
2 | import jax
3 |
4 | import openpi.models.pi0 as _pi0
5 |
6 |
7 | def _get_frozen_state(config: _pi0.Pi0Config) -> nnx.State:
8 | abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
9 |
10 | freeze_filter = config.get_freeze_filter()
11 | return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state()
12 |
13 |
14 | def test_pi0_full_finetune():
15 | config = _pi0.Pi0Config()
16 | state = _get_frozen_state(config)
17 | assert len(state) == 0
18 |
19 |
20 | def test_pi0_gemma_lora():
21 | config = _pi0.Pi0Config(paligemma_variant="gemma_2b_lora")
22 | state = _get_frozen_state(config)
23 | assert len(state) == 9
24 | assert all("lora" not in p for p in state)
25 | assert all("llm" in p for p in state)
26 | assert all("_1" not in p for p in state)
27 |
28 |
29 | def test_pi0_action_expert_lora():
30 | config = _pi0.Pi0Config(action_expert_variant="gemma_300m_lora")
31 | state = _get_frozen_state(config)
32 |     # Excluding the embedder, the frozen params are the same as in the gemma_lora case.
33 | assert len(state) == 8
34 | assert all("lora" not in p for p in state)
35 | assert all("llm" in p for p in state)
36 | # all frozen params should have _1 in their path since it's the action expert.
37 | assert all(any("_1" in p for p in path) for path in state)
38 |
39 |
40 | def test_pi0_all_lora():
41 | config = _pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
42 | state = _get_frozen_state(config)
43 | # sum of gemma_lora and action_expert_lora's frozen params.
44 | assert len(state) == 17
45 | assert all("lora" not in p for p in state)
46 | assert all("llm" in p for p in state)
47 |
--------------------------------------------------------------------------------
/src/openpi/models/tokenizer_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from openpi.models import tokenizer as _tokenizer
4 |
5 |
6 | def test_tokenize():
7 | tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)
8 | tokens, masks = tokenizer.tokenize("Hello, world!")
9 |
10 | assert tokens.shape == (10,)
11 | assert masks.shape == (10,)
12 |
13 |
14 | def test_fast_tokenizer():
15 | prompt = "Hello, world!"
16 | state = np.random.rand(5).astype(np.float32)
17 | action = np.random.rand(3, 2).astype(np.float32)
18 | tokenizer = _tokenizer.FASTTokenizer(max_len=256)
19 | tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action)
20 |
21 | assert tokens.shape == (256,)
22 | assert token_masks.shape == (256,)
23 | assert ar_masks.shape == (256,)
24 | assert loss_masks.shape == (256,)
25 |
26 | act = tokenizer.extract_actions(tokens, 3, 2)
27 | assert act.shape == (3, 2)
28 |
--------------------------------------------------------------------------------
/src/openpi/policies/droid_policy.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import einops
4 | import numpy as np
5 |
6 | from openpi import transforms
7 | from openpi.models import model as _model
8 |
9 |
10 | def make_droid_example() -> dict:
11 | """Creates a random input example for the Droid policy."""
12 | return {
13 | "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
14 | "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
15 | "observation/joint_position": np.random.rand(7),
16 | "observation/gripper_position": np.random.rand(1),
17 | "prompt": "do something",
18 | }
19 |
20 |
21 | def _parse_image(image) -> np.ndarray:
22 | image = np.asarray(image)
23 | if np.issubdtype(image.dtype, np.floating):
24 | image = (255 * image).astype(np.uint8)
25 | if image.shape[0] == 3:
26 | image = einops.rearrange(image, "c h w -> h w c")
27 | return image
28 |
29 |
30 | @dataclasses.dataclass(frozen=True)
31 | class DroidInputs(transforms.DataTransformFn):
32 | # The action dimension of the model. Will be used to pad state and actions.
33 | action_dim: int
34 |
35 | # Determines which model will be used.
36 | model_type: _model.ModelType = _model.ModelType.PI0
37 |
38 | def __call__(self, data: dict) -> dict:
39 | state = np.concatenate([data["observation/joint_position"], data["observation/gripper_position"]])
40 | state = transforms.pad_to_dim(state, self.action_dim)
41 |
42 |         # Parse images to uint8 (H,W,C) if needed: LeRobot stores them as float32
43 |         # (C,H,W), while inputs at policy inference time are already uint8 (H,W,C).
44 | base_image = _parse_image(data["observation/exterior_image_1_left"])
45 | wrist_image = _parse_image(data["observation/wrist_image_left"])
46 |
47 | match self.model_type:
48 | case _model.ModelType.PI0:
49 | names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
50 | images = (base_image, wrist_image, np.zeros_like(base_image))
51 | image_masks = (np.True_, np.True_, np.False_)
52 | case _model.ModelType.PI0_FAST:
53 | names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb")
54 | # We don't mask out padding images for FAST models.
55 | images = (base_image, np.zeros_like(base_image), wrist_image)
56 | image_masks = (np.True_, np.True_, np.True_)
57 | case _:
58 | raise ValueError(f"Unsupported model type: {self.model_type}")
59 |
60 | inputs = {
61 | "state": state,
62 | "image": dict(zip(names, images, strict=True)),
63 | "image_mask": dict(zip(names, image_masks, strict=True)),
64 | }
65 |
66 | if "actions" in data:
67 | inputs["actions"] = data["actions"]
68 |
69 | if "prompt" in data:
70 | inputs["prompt"] = data["prompt"]
71 |
72 | return inputs
73 |
74 |
75 | @dataclasses.dataclass(frozen=True)
76 | class DroidOutputs(transforms.DataTransformFn):
77 | def __call__(self, data: dict) -> dict:
78 | # Only return the first 8 dims.
79 | return {"actions": np.asarray(data["actions"][:, :8])}
80 |
--------------------------------------------------------------------------------
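A usage sketch for the transform above (the invocation is assumed; in practice the training config wires these transforms into the data pipeline, and `action_dim=32` is just an illustrative choice):

```python
import numpy as np

from openpi.models import model as _model
from openpi.policies import droid_policy

example = droid_policy.make_droid_example()
tf = droid_policy.DroidInputs(action_dim=32, model_type=_model.ModelType.PI0)
inputs = tf(example)

# The 7 joint positions and 1 gripper position are concatenated and padded to
# action_dim; the absent right wrist camera is zero-filled and masked out.
assert inputs["state"].shape == (32,)
assert inputs["image"]["right_wrist_0_rgb"].sum() == 0
assert not inputs["image_mask"]["right_wrist_0_rgb"]
```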
/src/openpi/policies/libero_policy.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import einops
4 | import numpy as np
5 |
6 | from openpi import transforms
7 | from openpi.models import model as _model
8 |
9 |
10 | def make_libero_example() -> dict:
11 | """Creates a random input example for the Libero policy."""
12 | return {
13 | "observation/state": np.random.rand(8),
14 | "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
15 | "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
16 | "prompt": "do something",
17 | }
18 |
19 |
20 | def _parse_image(image) -> np.ndarray:
21 | image = np.asarray(image)
22 | if np.issubdtype(image.dtype, np.floating):
23 | image = (255 * image).astype(np.uint8)
24 | if image.shape[0] == 3:
25 | image = einops.rearrange(image, "c h w -> h w c")
26 | return image
27 |
28 |
29 | @dataclasses.dataclass(frozen=True)
30 | class LiberoInputs(transforms.DataTransformFn):
31 | # The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST).
32 | action_dim: int
33 |
34 | # Determines which model will be used.
35 | model_type: _model.ModelType = _model.ModelType.PI0
36 |
37 | def __call__(self, data: dict) -> dict:
38 | mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST.
39 |
40 | # Get the state. We are padding from 8 to the model action dim.
41 | # For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped).
42 | state = transforms.pad_to_dim(data["observation/state"], self.action_dim)
43 |
44 |         # Parse images to uint8 (H,W,C) if needed: LeRobot stores them as float32
45 |         # (C,H,W), while inputs at policy inference time are already uint8 (H,W,C).
46 | base_image = _parse_image(data["observation/image"])
47 | wrist_image = _parse_image(data["observation/wrist_image"])
48 |
49 | inputs = {
50 | "state": state,
51 | "image": {
52 | "base_0_rgb": base_image,
53 | "left_wrist_0_rgb": wrist_image,
54 | "right_wrist_0_rgb": np.zeros_like(base_image),
55 | },
56 | "image_mask": {
57 | "base_0_rgb": np.True_,
58 | "left_wrist_0_rgb": np.True_,
59 | "right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
60 | },
61 | }
62 |
63 | # Actions are only available during training.
64 | if "actions" in data:
65 | # We are padding from 7 to the model action dim.
66 | # For pi0-FAST, this is a no-op (since action_dim = 7).
67 | actions = transforms.pad_to_dim(data["actions"], self.action_dim)
68 | inputs["actions"] = actions
69 |
70 | if "prompt" in data:
71 | inputs["prompt"] = data["prompt"]
72 |
73 | return inputs
74 |
75 |
76 | @dataclasses.dataclass(frozen=True)
77 | class LiberoOutputs(transforms.DataTransformFn):
78 | def __call__(self, data: dict) -> dict:
79 | # Only return the first 7 dims.
80 | return {"actions": np.asarray(data["actions"][:, :7])}
81 |
--------------------------------------------------------------------------------
/src/openpi/policies/policy_config.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | import dataclasses
3 | import logging
4 | import pathlib
5 | from typing import Any
6 |
7 | import jax.numpy as jnp
8 |
9 | import openpi.models.model as _model
10 | import openpi.policies.policy as _policy
11 | import openpi.shared.download as download
12 | from openpi.training import checkpoints as _checkpoints
13 | from openpi.training import config as _config
14 | import openpi.transforms as transforms
15 |
16 |
17 | @dataclasses.dataclass
18 | class PolicyConfig:
19 | model: _model.BaseModel
20 | norm_stats: dict[str, transforms.NormStats]
21 |
22 | input_layers: Sequence[transforms.DataTransformFn]
23 | output_layers: Sequence[transforms.DataTransformFn]
24 |
25 | model_type: _model.ModelType = _model.ModelType.PI0
26 | default_prompt: str | None = None
27 | sample_kwargs: dict[str, Any] | None = None
28 |
29 |
30 | def create_trained_policy(
31 | train_config: _config.TrainConfig | _config.UMITrainConfig,
32 | checkpoint_dir: pathlib.Path | str,
33 | *,
34 | repack_transforms: transforms.Group | None = None,
35 | sample_kwargs: dict[str, Any] | None = None,
36 | default_prompt: str | None = None,
37 | norm_stats: dict[str, transforms.NormStats] | None = None,
38 | ) -> _policy.Policy | _policy.ReasoningPolicy:
39 | """Create a policy from a trained checkpoint.
40 |
41 | Args:
42 | train_config: The training config to use to create the model.
43 | checkpoint_dir: The directory to load the model from.
44 | repack_transforms: Optional transforms that will be applied before any other transforms.
45 | sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
46 | kwargs will be used.
47 | default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
48 | data if it doesn't already exist.
49 | norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
50 | from the checkpoint directory.
51 | """
52 | repack_transforms = repack_transforms or transforms.Group()
53 | checkpoint_dir = download.maybe_download(str(checkpoint_dir))
54 |
55 | logging.info("Loading model...")
56 | model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
57 |
58 | data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
59 | if norm_stats is None:
60 | # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
61 | # that the policy is using the same normalization stats as the original training process.
62 | if data_config.asset_id is None:
63 | raise ValueError("Asset id is required to load norm stats.")
64 | norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
65 |
66 | extra_policy_kwargs = {}
67 |
68 | if train_config.model.model_type == _model.ModelType.PI0_FUSE:
69 | assert isinstance(train_config, _config.UMITrainConfig)
70 |         policy_cls = _policy.ReasoningPolicy
71 | extra_policy_kwargs['use_ref_img'] = train_config.use_reference_image
72 |
73 | if 'cocktail' in train_config.repo_id:
74 | extra_policy_kwargs['initial_scene_plan'] = (
75 | 'Scene description: TBD.\n'
76 | 'Plan: TBD.\n'
77 | 'What I have done: TBD.\n'
78 | 'Now I need to: TBD.\n'
79 | )
80 |         else:
81 |             # Both the 'wild_move_to' tasks and any other repo start from an
82 |             # empty initial scene plan.
83 |             extra_policy_kwargs['initial_scene_plan'] = ''
84 | else:
85 |         policy_cls = _policy.Policy
86 |
87 |     return policy_cls(
88 | model,
89 | transforms=[
90 | *repack_transforms.inputs,
91 | transforms.InjectDefaultPrompt(default_prompt),
92 | *data_config.data_transforms.inputs,
93 | transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
94 | *data_config.model_transforms.inputs,
95 | ],
96 | output_transforms=[
97 | *data_config.model_transforms.outputs,
98 | transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
99 | *data_config.data_transforms.outputs,
100 | *repack_transforms.outputs,
101 | ],
102 | sample_kwargs=sample_kwargs,
103 | metadata=train_config.policy_metadata,
104 | **extra_policy_kwargs,
105 | )
106 |
--------------------------------------------------------------------------------
/src/openpi/policies/policy_test.py:
--------------------------------------------------------------------------------
1 | from openpi_client import action_chunk_broker
2 | import pytest
3 |
4 | from openpi.policies import aloha_policy
5 | from openpi.policies import policy_config as _policy_config
6 | from openpi.training import config as _config
7 |
8 |
9 | @pytest.mark.manual
10 | def test_infer():
11 | config = _config.get_config("pi0_aloha_sim")
12 | policy = _policy_config.create_trained_policy(config, "s3://openpi-assets/checkpoints/pi0_aloha_sim")
13 |
14 | example = aloha_policy.make_aloha_example()
15 | result = policy.infer(example)
16 |
17 | assert result["actions"].shape == (config.model.action_horizon, 14)
18 |
19 |
20 | @pytest.mark.manual
21 | def test_broker():
22 | config = _config.get_config("pi0_aloha_sim")
23 | policy = _policy_config.create_trained_policy(config, "s3://openpi-assets/checkpoints/pi0_aloha_sim")
24 |
25 | broker = action_chunk_broker.ActionChunkBroker(
26 | policy,
27 | # Only execute the first half of the chunk.
28 | action_horizon=config.model.action_horizon // 2,
29 | )
30 |
31 | example = aloha_policy.make_aloha_example()
32 | for _ in range(config.model.action_horizon):
33 | outputs = broker.infer(example)
34 | assert outputs["actions"].shape == (14,)
35 |
--------------------------------------------------------------------------------
/src/openpi/policies/pose_repr_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_relative_pose(pos, rot, base_pos, base_rot_mat,
5 | rot_transformer_to_mat,
6 | rot_transformer_to_target,
7 | backward=False,
8 | delta=False):
9 | if not backward:
10 | # forward pass
11 | if not delta:
12 | output_pos = pos if base_pos is None else pos - base_pos
13 | output_rot = rot_transformer_to_target.forward(
14 | rot_transformer_to_mat.forward(rot) @ np.linalg.inv(base_rot_mat))
15 | return output_pos, output_rot
16 | else:
17 | all_pos = np.concatenate([base_pos[None,...], pos], axis=0)
18 | output_pos = np.diff(all_pos, axis=0)
19 |
20 | rot_mat = rot_transformer_to_mat.forward(rot)
21 | all_rot_mat = np.concatenate([base_rot_mat[None,...], rot_mat], axis=0)
22 | prev_rot = np.linalg.inv(all_rot_mat[:-1])
23 | curr_rot = all_rot_mat[1:]
24 | rot = np.matmul(curr_rot, prev_rot)
25 | output_rot = rot_transformer_to_target.forward(rot)
26 | return output_pos, output_rot
27 |
28 | else:
29 | # backward pass
30 | if not delta:
31 | output_pos = pos if base_pos is None else pos + base_pos
32 | output_rot = rot_transformer_to_mat.inverse(
33 | rot_transformer_to_target.inverse(rot) @ base_rot_mat)
34 | return output_pos, output_rot
35 | else:
36 | output_pos = np.cumsum(pos, axis=0) + base_pos
37 |
38 | rot_mat = rot_transformer_to_target.inverse(rot)
39 | output_rot_mat = np.zeros_like(rot_mat)
40 | curr_rot = base_rot_mat
41 | for i in range(len(rot_mat)):
42 | curr_rot = rot_mat[i] @ curr_rot
43 | output_rot_mat[i] = curr_rot
44 |             output_rot = rot_transformer_to_mat.inverse(output_rot_mat)
45 | return output_pos, output_rot
46 |
47 |
48 | def convert_pose_mat_rep(pose_mat, base_pose_mat, pose_rep='abs', backward=False):
49 | if not backward:
50 | # training transform
51 | if pose_rep == 'abs':
52 | return pose_mat
53 | elif pose_rep == 'rel':
54 | # legacy buggy implementation
55 | # for compatibility
56 | pos = pose_mat[...,:3,3] - base_pose_mat[:3,3]
57 | rot = pose_mat[...,:3,:3] @ np.linalg.inv(base_pose_mat[:3,:3])
58 | out = np.copy(pose_mat)
59 | out[...,:3,:3] = rot
60 | out[...,:3,3] = pos
61 | return out
62 | elif pose_rep == 'relative':
63 | out = np.linalg.inv(base_pose_mat) @ pose_mat
64 | return out
65 | elif pose_rep == 'delta':
66 | all_pos = np.concatenate([base_pose_mat[None,:3,3], pose_mat[...,:3,3]], axis=0)
67 | out_pos = np.diff(all_pos, axis=0)
68 |
69 | all_rot_mat = np.concatenate([base_pose_mat[None,:3,:3], pose_mat[...,:3,:3]], axis=0)
70 | prev_rot = np.linalg.inv(all_rot_mat[:-1])
71 | curr_rot = all_rot_mat[1:]
72 | out_rot = np.matmul(curr_rot, prev_rot)
73 |
74 | out = np.copy(pose_mat)
75 | out[...,:3,:3] = out_rot
76 | out[...,:3,3] = out_pos
77 | return out
78 | else:
79 | raise RuntimeError(f"Unsupported pose_rep: {pose_rep}")
80 |
81 | else:
82 | # eval transform
83 | if pose_rep == 'abs':
84 | return pose_mat
85 | elif pose_rep == 'rel':
86 | # legacy buggy implementation
87 | # for compatibility
88 | pos = pose_mat[...,:3,3] + base_pose_mat[:3,3]
89 | rot = pose_mat[...,:3,:3] @ base_pose_mat[:3,:3]
90 | out = np.copy(pose_mat)
91 | out[...,:3,:3] = rot
92 | out[...,:3,3] = pos
93 | return out
94 | elif pose_rep == 'relative':
95 | out = base_pose_mat @ pose_mat
96 | return out
97 | elif pose_rep == 'delta':
98 | output_pos = np.cumsum(pose_mat[...,:3,3], axis=0) + base_pose_mat[:3,3]
99 |
100 | output_rot_mat = np.zeros_like(pose_mat[...,:3,:3])
101 | curr_rot = base_pose_mat[:3,:3]
102 | for i in range(len(pose_mat)):
103 | curr_rot = pose_mat[i,:3,:3] @ curr_rot
104 | output_rot_mat[i] = curr_rot
105 |
106 | out = np.copy(pose_mat)
107 | out[...,:3,:3] = output_rot_mat
108 | out[...,:3,3] = output_pos
109 | return out
110 | else:
111 | raise RuntimeError(f"Unsupported pose_rep: {pose_rep}")
112 |
--------------------------------------------------------------------------------
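A round-trip sketch for `convert_pose_mat_rep` (illustrative, assuming 4x4 homogeneous pose matrices as used throughout this file): encoding into the 'relative' representation during training and decoding with `backward=True` at eval time should recover the original poses.

```python
import numpy as np
import scipy.spatial.transform as st

from openpi.policies.pose_repr_util import convert_pose_mat_rep


def random_pose(rng):
    mat = np.eye(4)
    mat[:3, :3] = st.Rotation.from_rotvec(rng.normal(size=3)).as_matrix()
    mat[:3, 3] = rng.uniform(-1.0, 1.0, size=3)
    return mat


rng = np.random.default_rng(0)
base = random_pose(rng)
poses = np.stack([random_pose(rng) for _ in range(5)])

# Training transform into the base frame, then the eval transform back out.
rel = convert_pose_mat_rep(poses, base, pose_rep="relative", backward=False)
rec = convert_pose_mat_rep(rel, base, pose_rep="relative", backward=True)
assert np.allclose(rec, poses)
```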
/src/openpi/policies/pose_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.spatial.transform as st
3 |
4 | def pos_rot_to_mat(pos, rot):
5 | shape = pos.shape[:-1]
6 | mat = np.zeros(shape + (4,4), dtype=pos.dtype)
7 | mat[...,:3,3] = pos
8 | mat[...,:3,:3] = rot.as_matrix()
9 | mat[...,3,3] = 1
10 | return mat
11 |
12 | def mat_to_pos_rot(mat):
13 | pos = (mat[...,:3,3].T / mat[...,3,3].T).T
14 | rot = st.Rotation.from_matrix(mat[...,:3,:3])
15 | return pos, rot
16 |
17 | def pos_rot_to_pose(pos, rot):
18 | shape = pos.shape[:-1]
19 | pose = np.zeros(shape+(6,), dtype=pos.dtype)
20 | pose[...,:3] = pos
21 | pose[...,3:] = rot.as_rotvec()
22 | return pose
23 |
24 | def pose_to_pos_rot(pose):
25 | pos = pose[...,:3]
26 | rot = st.Rotation.from_rotvec(pose[...,3:])
27 | return pos, rot
28 |
29 | def pose_to_mat(pose):
30 | return pos_rot_to_mat(*pose_to_pos_rot(pose))
31 |
32 | def mat_to_pose(mat):
33 | return pos_rot_to_pose(*mat_to_pos_rot(mat))
34 |
35 | def transform_pose(tx, pose):
36 | """
37 | tx: tx_new_old
38 | pose: tx_old_obj
39 | result: tx_new_obj
40 | """
41 | pose_mat = pose_to_mat(pose)
42 | tf_pose_mat = tx @ pose_mat
43 | tf_pose = mat_to_pose(tf_pose_mat)
44 | return tf_pose
45 |
46 | def transform_point(tx, point):
47 | return point @ tx[:3,:3].T + tx[:3,3]
48 |
49 | def project_point(k, point):
50 | x = point @ k.T
51 | uv = x[...,:2] / x[...,[2]]
52 | return uv
53 |
54 | def apply_delta_pose(pose, delta_pose):
55 | new_pose = np.zeros_like(pose)
56 |
57 | # simple add for position
58 | new_pose[:3] = pose[:3] + delta_pose[:3]
59 |
60 | # matrix multiplication for rotation
61 | rot = st.Rotation.from_rotvec(pose[3:])
62 | drot = st.Rotation.from_rotvec(delta_pose[3:])
63 | new_pose[3:] = (drot * rot).as_rotvec()
64 |
65 | return new_pose
66 |
67 | # Vectorized normalization over the last axis; used by rot_from_directions and
68 | # the rot6d helpers below.
69 | def normalize(vec, eps=1e-12):
70 |     norm = np.linalg.norm(vec, axis=-1)
71 |     norm = np.maximum(norm, eps)
72 |     out = (vec.T / norm).T
73 |     return out
74 | 
75 | def rot_from_directions(from_vec, to_vec):
76 |     from_vec = normalize(from_vec)
77 |     to_vec = normalize(to_vec)
78 |     axis = np.cross(from_vec, to_vec)
79 |     axis = normalize(axis)
80 |     angle = np.arccos(np.dot(from_vec, to_vec))
81 |     rotvec = axis * angle
82 |     rot = st.Rotation.from_rotvec(rotvec)
83 |     return rot
84 | 
85 |
86 | def rot6d_to_mat(d6):
87 | a1, a2 = d6[..., :3], d6[..., 3:]
88 | b1 = normalize(a1)
89 | b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
90 | b2 = normalize(b2)
91 | b3 = np.cross(b1, b2, axis=-1)
92 | out = np.stack((b1, b2, b3), axis=-2)
93 | return out
94 |
95 | def mat_to_rot6d(mat):
96 | batch_dim = mat.shape[:-2]
97 | out = mat[..., :2, :].copy().reshape(batch_dim + (6,))
98 | return out
99 |
100 | def mat_to_pose10d(mat):
101 | pos = mat[...,:3,3]
102 | rotmat = mat[...,:3,:3]
103 | d6 = mat_to_rot6d(rotmat)
104 | d10 = np.concatenate([pos, d6], axis=-1)
105 | return d10
106 |
107 | def pose10d_to_mat(d10):
108 | pos = d10[...,:3]
109 | d6 = d10[...,3:]
110 | rotmat = rot6d_to_mat(d6)
111 | out = np.zeros(d10.shape[:-1]+(4,4), dtype=d10.dtype)
112 | out[...,:3,:3] = rotmat
113 | out[...,:3,3] = pos
114 | out[...,3,3] = 1
115 | return out
116 |
--------------------------------------------------------------------------------
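The 6D rotation encoding above stores the first two rows of the rotation matrix and re-orthonormalizes on decode (Gram-Schmidt plus a cross product), so the matrix → rot6d → matrix round trip is exact for valid rotations. A quick usage sketch of the helpers in this file:

```python
import numpy as np
import scipy.spatial.transform as st

from openpi.policies import pose_util

rot = st.Rotation.from_rotvec([0.1, 0.2, 0.3]).as_matrix()  # (3, 3)
d6 = pose_util.mat_to_rot6d(rot)                            # first two rows, flattened
assert np.allclose(pose_util.rot6d_to_mat(d6), rot)

# The same holds for the full pose encoding (3D position + 6D rotation).
mat = np.eye(4)
mat[:3, :3] = rot
mat[:3, 3] = [0.1, 0.2, 0.3]
assert np.allclose(pose_util.pose10d_to_mat(pose_util.mat_to_pose10d(mat)), mat)
```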
/src/openpi/policies/umi_policy.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import einops
4 | import numpy as np
5 |
6 | from openpi import transforms
7 | from openpi.models import model as _model
8 |
9 |
10 | def make_umi_example() -> dict:
11 | """Creates a random input example for the umi policy."""
12 | return {
13 | "state": np.random.rand(48),
14 | "image_1": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
15 | "image_2": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
16 | "image_3": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
17 | "prompt": "do something",
18 | }
19 |
20 |
21 | def _parse_image(image) -> np.ndarray:
22 | image = np.asarray(image)
23 | if np.issubdtype(image.dtype, np.floating):
24 | image = (255 * image).astype(np.uint8)
25 | if image.shape[0] == 3:
26 | image = einops.rearrange(image, "c h w -> h w c")
27 | return image
28 |
29 |
30 | @dataclasses.dataclass(frozen=True)
31 | class UMIInputs(transforms.DataTransformFn):
32 | # The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST).
33 | action_dim: int
34 |
35 | # Determines which model will be used.
36 | model_type: _model.ModelType = _model.ModelType.PI0
37 |
38 | def __call__(self, data: dict) -> dict:
39 | mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST.
40 |
41 |         # Get the state and pad it to the model action dim if needed (the UMI
42 |         # example state above is 48-dimensional).
43 | state = transforms.pad_to_dim(data["state"], self.action_dim)
44 |
45 | history_length = 1
46 | while True:
47 | if f"image_{history_length + 1}" not in data:
48 | break
49 | history_length += 1
50 |
51 |         # Parse images to uint8 (H,W,C) if needed: LeRobot stores them as float32
52 |         # (C,H,W), while inputs at policy inference time are already uint8 (H,W,C).
53 | image_dict, image_mask_dict = {}, {}
54 | for i in range(history_length):
55 | image = _parse_image(data[f"image_{i + 1}"])
56 | image_dict[f"{i}_rgb"] = image
57 | image_mask_dict[f"{i}_rgb"] = np.True_
58 |
59 |         if 'reference_image' in data:
60 | image = _parse_image(data['reference_image'])
61 | image_dict['reference_rgb'] = image
62 | image_mask_dict['reference_rgb'] = np.True_
63 |
64 | add_prompt_info = None
65 |         if 'condition' in data:
66 | if data['condition'] is None:
67 | image_dict['start_rgb'] = np.zeros_like(image_dict['0_rgb'])
68 | image_mask_dict['start_rgb'] = np.False_
69 | else:
70 | image_dict['start_rgb'] = _parse_image(data['condition']['episode_start_image'])
71 | image_mask_dict['start_rgb'] = np.True_
72 | add_prompt_info = '. Objects are located at ' + str(data['condition']['detect']) + '.'
73 |
74 | inputs = {
75 | "state": state,
76 | "image": image_dict,
77 | "image_mask": image_mask_dict
78 | }
79 |
80 | # Actions are only available during training.
81 | if "actions" in data:
82 |             # Pad the actions to the model action dim if needed (UMI actions are
83 |             # 10-dimensional; see UMIOutputs below).
84 | actions = transforms.pad_to_dim(data["actions"], self.action_dim)
85 | inputs["actions"] = actions
86 |
87 | if "prompt" in data:
88 | inputs["prompt"] = data["prompt"]
89 |
90 | if add_prompt_info is not None:
91 | inputs["prompt"] += add_prompt_info
92 |
93 |         if 'thought' in data:
94 | inputs['thought'] = data['thought']
95 | inputs['act_with_outdated_thought'] = data['act_with_outdated_thought']
96 | inputs['think_with_outdated_thought'] = data['think_with_outdated_thought']
97 |
98 | return inputs
99 |
100 |
101 | @dataclasses.dataclass(frozen=True)
102 | class UMIOutputs(transforms.DataTransformFn):
103 | def __call__(self, data: dict) -> dict:
104 |         # Keep only the first 10 action dims; pass all other outputs through.
105 |         data["actions"] = np.asarray(data["actions"][:, :10])
106 | return data
107 |
--------------------------------------------------------------------------------
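The `image_{i}` probing loop above lets `UMIInputs` accept a variable number of observation history frames without an explicit length argument. A usage sketch (assumed invocation; `action_dim=48` matches the example state, so padding is a no-op):

```python
from openpi.policies import umi_policy

example = umi_policy.make_umi_example()  # provides image_1 .. image_3
inputs = umi_policy.UMIInputs(action_dim=48)(example)

# Three history frames are discovered and renamed 0_rgb, 1_rgb, 2_rgb.
assert sorted(inputs["image"]) == ["0_rgb", "1_rgb", "2_rgb"]
assert inputs["state"].shape == (48,)
```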
/src/openpi/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/py.typed
--------------------------------------------------------------------------------
/src/openpi/serving/websocket_policy_server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import traceback
4 | 
5 | import numpy as np
6 | from openpi_client import base_policy as _base_policy
7 | from openpi_client import msgpack_numpy
8 | import websockets.asyncio.server
9 | import websockets.frames
10 |
11 |
12 | class WebsocketPolicyServer:
13 | """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
14 |
15 | Currently only implements the `load` and `infer` methods.
16 | """
17 |
18 | def __init__(
19 | self,
20 | policy: _base_policy.BasePolicy,
21 | host: str = "0.0.0.0",
22 | port: int = 8000,
23 | metadata: dict | None = None,
24 | ) -> None:
25 | self._policy = policy
26 | self._host = host
27 | self._port = port
28 | self._metadata = metadata or {}
29 | logging.getLogger("websockets.server").setLevel(logging.INFO)
30 |
31 | def serve_forever(self) -> None:
32 | asyncio.run(self.run())
33 |
34 | async def run(self):
35 | async with websockets.asyncio.server.serve(
36 | self._handler,
37 | self._host,
38 | self._port,
39 | compression=None,
40 | max_size=None,
41 | ) as server:
42 | await server.serve_forever()
43 |
44 | async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
45 | logging.info(f"Connection from {websocket.remote_address} opened")
46 | packer = msgpack_numpy.Packer()
47 | if hasattr(self._policy, 'start'):
48 | self._policy.start()
49 |
50 | await websocket.send(packer.pack(self._metadata))
51 |
52 | while True:
53 | try:
54 | obs = msgpack_numpy.unpackb(await websocket.recv())
55 | infer_task = asyncio.create_task(asyncio.to_thread(self._policy.infer, obs))
56 | while True:
57 | if infer_task.done():
58 | action = await infer_task
59 | await websocket.send(packer.pack(action))
60 | break
61 | if getattr(self._policy, 'is_thinking', False):
62 | await websocket.send(packer.pack({"isthinking": np.True_}))
63 | _ = msgpack_numpy.unpackb(await websocket.recv())
64 | await asyncio.sleep(0.02)
65 | except websockets.ConnectionClosed:
66 | logging.info(f"Connection from {websocket.remote_address} closed")
67 | break
68 | except Exception:
69 | await websocket.send(traceback.format_exc())
70 | await websocket.close(
71 | code=websockets.frames.CloseCode.INTERNAL_ERROR,
72 | reason="Internal server error. Traceback included in previous frame.",
73 | )
74 | raise
75 |
--------------------------------------------------------------------------------
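On the client side, the matching openpi-client wrapper (see packages/openpi-client/src/openpi_client/websocket_client_policy.py, not reproduced in this dump) can presumably be used roughly as follows; the constructor signature here is an assumption based on the file layout, so check the real module before relying on it:

```python
import numpy as np

from openpi_client import websocket_client_policy

# Connect to a running scripts/serve_policy.py instance. The server sends its
# policy metadata as the first msgpack_numpy-packed message.
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

obs = {  # environment-specific observation dict
    "state": np.random.rand(48),
    "image_1": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
    "prompt": "do something",
}
actions = client.infer(obs)["actions"]
```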
/src/openpi/shared/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fanqi-Lin/OneTwoVLA/2489662cf1d4f75dbeea3cf9f337c95391e2b836/src/openpi/shared/__init__.py
--------------------------------------------------------------------------------
/src/openpi/shared/array_typing.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import functools as ft
3 | import inspect
4 | from typing import TypeAlias, TypeVar, cast
5 |
6 | import beartype
7 | import jax
8 | import jax._src.tree_util as private_tree_util
9 | import jax.core
10 | from jaxtyping import Array # noqa: F401
11 | from jaxtyping import ArrayLike
12 | from jaxtyping import Bool # noqa: F401
13 | from jaxtyping import DTypeLike # noqa: F401
14 | from jaxtyping import Float
15 | from jaxtyping import Int # noqa: F401
16 | from jaxtyping import Key # noqa: F401
17 | from jaxtyping import Num # noqa: F401
18 | from jaxtyping import PyTree
19 | from jaxtyping import Real # noqa: F401
20 | from jaxtyping import UInt8 # noqa: F401
21 | from jaxtyping import config
22 | from jaxtyping import jaxtyped
23 | import jaxtyping._decorator
24 |
25 | # patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.
26 | # the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,
27 | # `jax.Sharding`, or even