├── openpi-SF ├── .python-version ├── scripts │ ├── __init__.py │ ├── train_test.py │ └── docker │ │ ├── compose.yml │ │ ├── install_nvidia_container_toolkit.sh │ │ ├── install_docker_ubuntu22.sh │ │ └── serve_policy.Dockerfile ├── src │ ├── openpi │ │ ├── __init__.py │ │ ├── py.typed │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── tokenizer_test.py │ │ │ ├── pi0_test.py │ │ │ ├── model_test.py │ │ │ └── lora_test.py │ │ ├── shared │ │ │ ├── __init__.py │ │ │ ├── image_tools_test.py │ │ │ ├── normalize_test.py │ │ │ ├── download_test.py │ │ │ ├── nnx_utils.py │ │ │ └── array_typing.py │ │ ├── models_pytorch │ │ │ ├── transformers_replace │ │ │ │ └── models │ │ │ │ │ └── siglip │ │ │ │ │ └── check.py │ │ │ └── projectors.py │ │ ├── conftest.py │ │ ├── policies │ │ │ ├── policy_test.py │ │ │ └── droid_policy.py │ │ ├── training │ │ │ ├── utils.py │ │ │ └── data_loader_test.py │ │ └── serving │ │ │ └── websocket_policy_server.py │ └── vggt │ │ ├── __init__.py │ │ ├── dependency │ │ ├── track_modules │ │ │ └── __init__.py │ │ └── __init__.py │ │ ├── heads │ │ └── track_modules │ │ │ └── __init__.py │ │ ├── layers │ │ ├── __init__.py │ │ ├── layer_scale.py │ │ ├── drop_path.py │ │ ├── mlp.py │ │ ├── swiglu_ffn.py │ │ ├── patch_embed.py │ │ └── attention.py │ │ ├── pyproject.toml │ │ └── utils │ │ └── helper.py ├── .dockerignore ├── packages │ └── openpi-client │ │ ├── src │ │ └── openpi_client │ │ │ ├── __init__.py │ │ │ ├── base_policy.py │ │ │ ├── runtime │ │ │ ├── agent.py │ │ │ ├── subscriber.py │ │ │ ├── agents │ │ │ │ └── policy_agent.py │ │ │ ├── environment.py │ │ │ └── runtime.py │ │ │ ├── image_tools_test.py │ │ │ ├── action_chunk_broker.py │ │ │ ├── msgpack_numpy_test.py │ │ │ ├── msgpack_numpy.py │ │ │ ├── websocket_client_policy.py │ │ │ └── image_tools.py │ │ └── pyproject.toml ├── examples │ ├── simple_client │ │ ├── requirements.in │ │ ├── README.md │ │ ├── requirements.txt │ │ ├── compose.yml │ │ └── Dockerfile │ ├── aloha_sim │ │ ├── requirements.in │ │ ├── README.md │ │ ├── compose.yml │ │ ├── saver.py │ │ ├── main.py │ │ ├── Dockerfile │ │ ├── env.py │ │ └── requirements.txt │ ├── libero │ │ ├── requirements.in │ │ ├── compose.yml │ │ ├── Dockerfile │ │ ├── README.md │ │ └── requirements.txt │ └── aloha_real │ │ ├── requirements.in │ │ ├── video_display.py │ │ ├── main.py │ │ ├── compose.yml │ │ ├── env.py │ │ ├── Dockerfile │ │ └── constants.py ├── .gitmodules ├── .pre-commit-config.yaml └── docs │ └── docker.md ├── openvla-SF ├── prismatic │ ├── py.typed │ ├── extern │ │ ├── __init__.py │ │ └── hf │ │ │ └── __init__.py │ ├── models │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── llm │ │ │ │ ├── __init__.py │ │ │ │ ├── prompting │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── mistral_instruct_prompter.py │ │ │ │ │ ├── base_prompter.py │ │ │ │ │ ├── phi_prompter.py │ │ │ │ │ ├── vicuna_v15_prompter.py │ │ │ │ │ └── llama2_chat_prompter.py │ │ │ │ ├── phi.py │ │ │ │ └── mistral.py │ │ │ └── vision │ │ │ │ ├── __init__.py │ │ │ │ ├── dinov2_vit.py │ │ │ │ ├── in1k_vit.py │ │ │ │ ├── siglip_vit.py │ │ │ │ └── clip_vit.py │ │ ├── vlas │ │ │ └── __init__.py │ │ ├── vlms │ │ │ └── __init__.py │ │ └── __init__.py │ ├── vla │ │ ├── datasets │ │ │ ├── rlds │ │ │ │ ├── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── goal_relabeling.py │ │ │ │ │ └── task_augmentation.py │ │ │ │ ├── __init__.py │ │ │ │ └── oxe │ │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── materialize.py │ │ └── constants.py │ ├── overwatch │ │ └── __init__.py │ ├── util │ │ ├── __init__.py │ │ ├── 
nn_utils.py │ │ └── pooling_utils.py │ ├── preprocessing │ │ ├── datasets │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── materialize.py │ ├── __init__.py │ ├── training │ │ ├── __init__.py │ │ ├── strategies │ │ │ └── __init__.py │ │ ├── train_utils.py │ │ └── materialize.py │ └── conf │ │ └── __init__.py ├── vggt │ ├── dependency │ │ ├── track_modules │ │ │ └── __init__.py │ │ └── __init__.py │ ├── heads │ │ └── track_modules │ │ │ └── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── layer_scale.py │ │ ├── drop_path.py │ │ ├── mlp.py │ │ ├── swiglu_ffn.py │ │ ├── patch_embed.py │ │ └── attention.py │ └── utils │ │ └── helper.py ├── experiments │ └── robot │ │ ├── libero │ │ ├── libero_requirements.txt │ │ ├── sample_libero_spatial_observation.pkl │ │ └── libero_utils.py │ │ └── aloha │ │ ├── requirements_aloha.txt │ │ └── aloha_utils.py ├── train.sh ├── .gitignore ├── vla-scripts │ ├── merge_lora_weights_and_save.py │ └── extern │ │ └── verify_openvla.py └── pyproject.toml ├── figs └── teaser.png └── LICENSE /openpi-SF/.python-version: -------------------------------------------------------------------------------- 1 | 3.11 -------------------------------------------------------------------------------- /openpi-SF/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/shared/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/extern/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/extern/hf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/vggt/dependency/track_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openpi-SF/.dockerignore: 
-------------------------------------------------------------------------------- 1 | .venv 2 | checkpoints 3 | data 4 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/dependency/track_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/datasets/rlds/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/vlas/__init__.py: -------------------------------------------------------------------------------- 1 | from .openvla import OpenVLA 2 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/vlms/__init__.py: -------------------------------------------------------------------------------- 1 | from .prismatic import PrismaticVLM 2 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/overwatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .overwatch import initialize_overwatch 2 | -------------------------------------------------------------------------------- /figs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenHelix-Team/Spatial-Forcing/HEAD/figs/teaser.png -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_vla_dataset_and_collator 2 | -------------------------------------------------------------------------------- /openpi-SF/examples/simple_client/requirements.in: -------------------------------------------------------------------------------- 1 | numpy>=1.22.4,<2.0.0 2 | rich 3 | tqdm 4 | tyro 5 | polars -------------------------------------------------------------------------------- /openvla-SF/prismatic/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_utils import check_bloat16_supported, set_global_seed 2 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/preprocessing/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import AlignDataset, FinetuneDataset 2 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/datasets/rlds/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import make_interleaved_dataset, make_single_dataset 2 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import available_model_names, available_models, get_model_description, load 2 | 
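The `__init__.py` files listed above define the public import surface of each package; in particular, `prismatic/__init__.py` re-exports `available_model_names`, `available_models`, `get_model_description`, and `load`. A minimal usage sketch, assuming these entry points behave as their names suggest (the model id string below is a placeholder, not one registered in this repo):

```python
# Sketch only: consuming the re-exports from prismatic/__init__.py shown above.
# Assumes available_models() yields registry ids and load() accepts one of them;
# "prism-dinosiglip+7b" is a placeholder id used purely for illustration.
from prismatic import available_models, get_model_description, load

for model_id in available_models():
    print(model_id, "->", get_model_description(model_id))

vlm = load("prism-dinosiglip+7b")  # returns a PrismaticVLM (see prismatic/models/vlms)
```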
-------------------------------------------------------------------------------- /openvla-SF/prismatic/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_train_strategy 2 | from .metrics import Metrics, VLAMetrics 3 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset 2 | -------------------------------------------------------------------------------- /openvla-SF/experiments/robot/libero/libero_requirements.txt: -------------------------------------------------------------------------------- 1 | imageio[ffmpeg] 2 | robosuite==1.4.1 3 | bddl 4 | easydict 5 | cloudpickle 6 | gym 7 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/requirements.in: -------------------------------------------------------------------------------- 1 | gym-aloha 2 | imageio 3 | matplotlib 4 | msgpack 5 | numpy>=1.22.4,<2.0.0 6 | typing-extensions 7 | tyro 8 | websockets -------------------------------------------------------------------------------- /openvla-SF/prismatic/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import convert_to_jpg, download_extract 2 | from .materialize import get_dataset_and_collator 3 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/datasets/rlds/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_oxe_dataset_kwargs_and_weights 2 | from .mixtures import OXE_NAMED_MIXTURES 3 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/training/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_strategy import TrainingStrategy 2 | from .ddp import DDPStrategy 3 | from .fsdp import FSDPStrategy 4 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DatasetConfig, DatasetRegistry 2 | from .models import ModelConfig, ModelRegistry 3 | from .vla import VLAConfig, VLARegistry 4 | -------------------------------------------------------------------------------- /openvla-SF/experiments/robot/libero/sample_libero_spatial_observation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenHelix-Team/Spatial-Forcing/HEAD/openvla-SF/experiments/robot/libero/sample_libero_spatial_observation.pkl -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_llm import LLMBackbone 2 | from .llama2 import LLaMa2LLMBackbone 3 | from .mistral import MistralLLMBackbone 4 | from .phi import PhiLLMBackbone 5 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py: 
-------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | def check_whether_transformers_replace_is_installed_correctly(): 4 | return transformers.__version__ == "4.53.2" -------------------------------------------------------------------------------- /openpi-SF/src/vggt/dependency/__init__.py: -------------------------------------------------------------------------------- 1 | from .track_modules.track_refine import refine_track 2 | from .track_modules.blocks import BasicEncoder, ShallowEncoder 3 | from .track_modules.base_track_predictor import BaseTrackerPredictor 4 | -------------------------------------------------------------------------------- /openvla-SF/vggt/dependency/__init__.py: -------------------------------------------------------------------------------- 1 | from .track_modules.track_refine import refine_track 2 | from .track_modules.blocks import BasicEncoder, ShallowEncoder 3 | from .track_modules.base_track_predictor import BaseTrackerPredictor 4 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import available_model_names, available_models, get_model_description, load, load_vla 2 | from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm 3 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/heads/track_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /openvla-SF/vggt/heads/track_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /openpi-SF/examples/libero/requirements.in: -------------------------------------------------------------------------------- 1 | imageio[ffmpeg] 2 | numpy==1.22.4 3 | tqdm 4 | tyro 5 | PyYaml 6 | opencv-python==4.6.0.66 7 | torch==1.11.0+cu113 8 | torchvision==0.12.0+cu113 9 | torchaudio==0.11.0+cu113 10 | robosuite==1.4.1 11 | matplotlib==3.5.3 12 | -------------------------------------------------------------------------------- /openpi-SF/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/aloha"] 2 | path = third_party/aloha 3 | url = https://github.com/Physical-Intelligence/aloha.git 4 | [submodule "third_party/libero"] 5 | path = third_party/libero 6 | url = https://github.com/Lifelong-Robot-Learning/LIBERO.git 7 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/requirements.in: -------------------------------------------------------------------------------- 1 | Pillow 2 | dm_control 3 | einops 4 | h5py 5 | matplotlib 6 | modern_robotics 7 | msgpack 8 | numpy>=1.22.4,<2.0.0 9 | opencv-python 10 | packaging 11 | pexpect 12 | pyquaternion 13 | pyrealsense2 14 | pyyaml 15 | requests 16 | rospkg 17 | tyro 18 | websockets 19 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/prompting/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_prompter import PromptBuilder, PurePromptBuilder 2 | from .llama2_chat_prompter import LLaMa2ChatPromptBuilder 3 | from .mistral_instruct_prompter import MistralInstructPromptBuilder 4 | from .phi_prompter import PhiPromptBuilder 5 | from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder 6 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/base_policy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict 3 | 4 | 5 | class BasePolicy(abc.ABC): 6 | @abc.abstractmethod 7 | def infer(self, obs: Dict) -> Dict: 8 | """Infer actions from observations.""" 9 | 10 | def reset(self) -> None: 11 | """Reset the policy to its initial state.""" 12 | pass 13 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_vision import ImageTransform, VisionBackbone 2 | from .clip_vit import CLIPViTBackbone 3 | from .dinoclip_vit import DinoCLIPViTBackbone 4 | from .dinosiglip_vit import DinoSigLIPViTBackbone 5 | from .dinov2_vit import DinoV2ViTBackbone 6 | from .in1k_vit import IN1KViTBackbone 7 | from .siglip_vit import SigLIPViTBackbone 8 | -------------------------------------------------------------------------------- /openpi-SF/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: third_party/ 2 | 3 | repos: 4 | - repo: https://github.com/astral-sh/uv-pre-commit 5 | # uv version. 6 | rev: 0.5.14 7 | hooks: 8 | - id: uv-lock 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | # Ruff version. 11 | rev: v0.8.6 12 | hooks: 13 | # Run the linter. 
14 | - id: ruff 15 | args: [--fix] 16 | - id: ruff-format -------------------------------------------------------------------------------- /openvla-SF/experiments/robot/aloha/requirements_aloha.txt: -------------------------------------------------------------------------------- 1 | numpy<2 2 | draccus 3 | torchvision 4 | torch 5 | pyquaternion 6 | pyyaml 7 | rospkg 8 | pexpect 9 | mujoco==2.3.7 10 | dm_control==1.0.14 11 | opencv-python 12 | matplotlib 13 | einops 14 | packaging 15 | h5py 16 | traitlets 17 | ipdb 18 | IPython 19 | modern_robotics 20 | Pillow 21 | termcolor 22 | imageio[ffmpeg] 23 | uvicorn 24 | fastapi 25 | requests 26 | json_numpy 27 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pynvml 4 | import pytest 5 | 6 | 7 | def set_jax_cpu_backend_if_no_gpu() -> None: 8 | try: 9 | pynvml.nvmlInit() 10 | pynvml.nvmlShutdown() 11 | except pynvml.NVMLError: 12 | # No GPU found. 13 | os.environ["JAX_PLATFORMS"] = "cpu" 14 | 15 | 16 | def pytest_configure(config: pytest.Config) -> None: 17 | set_jax_cpu_backend_if_no_gpu() 18 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "openpi-client" 3 | version = "0.1.0" 4 | requires-python = ">=3.7" 5 | dependencies = [ 6 | "dm-tree>=0.1.8", 7 | "msgpack>=1.0.5", 8 | "numpy>=1.22.4,<2.0.0", 9 | "pillow>=9.0.0", 10 | "tree>=0.2.4", 11 | "websockets>=11.0", 12 | ] 13 | 14 | [build-system] 15 | requires = ["hatchling"] 16 | build-backend = "hatchling.build" 17 | 18 | [tool.uv] 19 | dev-dependencies = ["pytest>=8.3.4"] 20 | 21 | [tool.ruff] 22 | line-length = 120 23 | target-version = "py37" 24 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/runtime/agent.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Agent(abc.ABC): 5 | """An Agent is the thing with agency, i.e. 
the entity that makes decisions. 6 | 7 | Agents receive observations about the state of the world, and return actions 8 | to take in response. 9 | """ 10 | 11 | @abc.abstractmethod 12 | def get_action(self, observation: dict) -> dict: 13 | """Query the agent for the next action.""" 14 | 15 | @abc.abstractmethod 16 | def reset(self) -> None: 17 | """Reset the agent to its initial state.""" 18 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/runtime/subscriber.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Subscriber(abc.ABC): 5 | """Subscribes to events in the runtime. 6 | 7 | Subscribers can be used to save data, visualize, etc. 8 | """ 9 | 10 | @abc.abstractmethod 11 | def on_episode_start(self) -> None: 12 | """Called when an episode starts.""" 13 | 14 | @abc.abstractmethod 15 | def on_step(self, observation: dict, action: dict) -> None: 16 | """Append a step to the episode.""" 17 | 18 | @abc.abstractmethod 19 | def on_episode_end(self) -> None: 20 | """Called when an episode ends.""" 21 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import override 2 | 3 | from openpi_client import base_policy as _base_policy 4 | from openpi_client.runtime import agent as _agent 5 | 6 | 7 | class PolicyAgent(_agent.Agent): 8 | """An agent that uses a policy to determine actions.""" 9 | 10 | def __init__(self, policy: _base_policy.BasePolicy) -> None: 11 | self._policy = policy 12 | 13 | @override 14 | def get_action(self, observation: dict) -> dict: 15 | return self._policy.infer(observation) 16 | 17 | def reset(self) -> None: 18 | self._policy.reset() 19 | -------------------------------------------------------------------------------- /openpi-SF/examples/simple_client/README.md: -------------------------------------------------------------------------------- 1 | # Simple Client 2 | 3 | A minimal client that sends observations to the server and prints the inference rate. 4 | 5 | You can specify which runtime environment to use using the `--env` flag. You can see the available options by running: 6 | 7 | ```bash 8 | uv run examples/simple_client/main.py --help 9 | ``` 10 | 11 | ## With Docker 12 | 13 | ```bash 14 | export SERVER_ARGS="--env ALOHA_SIM" 15 | docker compose -f examples/simple_client/compose.yml up --build 16 | ``` 17 | 18 | ## Without Docker 19 | 20 | Terminal window 1: 21 | 22 | ```bash 23 | uv run examples/simple_client/main.py --env DROID 24 | ``` 25 | 26 | Terminal window 2: 27 | 28 | ```bash 29 | uv run scripts/serve_policy.py --env DROID 30 | ``` 31 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/vision/dinov2_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinov2_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! 
8 | # => Reference: https://arxiv.org/abs/2309.16588 9 | DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} 10 | 11 | 12 | class DinoV2ViTBackbone(TimmViTBackbone): 13 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 14 | super().__init__( 15 | vision_backbone_id, 16 | DINOv2_VISION_BACKBONES[vision_backbone_id], 17 | image_resize_strategy, 18 | default_image_size=default_image_size, 19 | ) 20 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/vision/in1k_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | in1k_vit.py 3 | 4 | Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) 5 | """ 6 | 7 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 8 | 9 | # Registry =>> Supported Vision Backbones (from TIMM) 10 | IN1K_VISION_BACKBONES = { 11 | "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", 12 | } 13 | 14 | 15 | class IN1KViTBackbone(TimmViTBackbone): 16 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 17 | super().__init__( 18 | vision_backbone_id, 19 | IN1K_VISION_BACKBONES[vision_backbone_id], 20 | image_resize_strategy, 21 | default_image_size=default_image_size, 22 | ) 23 | -------------------------------------------------------------------------------- /openpi-SF/scripts/train_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import pathlib 4 | 5 | import pytest 6 | 7 | os.environ["JAX_PLATFORMS"] = "cpu" 8 | 9 | from openpi.training import config as _config 10 | 11 | from . import train 12 | 13 | 14 | @pytest.mark.parametrize("config_name", ["debug"]) 15 | def test_train(tmp_path: pathlib.Path, config_name: str): 16 | config = dataclasses.replace( 17 | _config._CONFIGS_DICT[config_name], # noqa: SLF001 18 | batch_size=2, 19 | checkpoint_base_dir=str(tmp_path / "checkpoint"), 20 | exp_name="test", 21 | overwrite=False, 22 | resume=False, 23 | num_train_steps=2, 24 | log_interval=1, 25 | ) 26 | train.main(config) 27 | 28 | # test resuming 29 | config = dataclasses.replace(config, resume=True, num_train_steps=4) 30 | train.main(config) 31 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: 17 | super().__init__() 18 | self.inplace = inplace 19 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 20 | 21 | def forward(self, x: Tensor) -> Tensor: 22 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 23 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: 17 | super().__init__() 18 | self.inplace = inplace 19 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 20 | 21 | def forward(self, x: Tensor) -> Tensor: 22 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 23 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/README.md: -------------------------------------------------------------------------------- 1 | # Run Aloha Sim 2 | 3 | ## With Docker 4 | 5 | ```bash 6 | export SERVER_ARGS="--env ALOHA_SIM" 7 | docker compose -f examples/aloha_sim/compose.yml up --build 8 | ``` 9 | 10 | ## Without Docker 11 | 12 | Terminal window 1: 13 | 14 | ```bash 15 | # Create virtual environment 16 | uv venv --python 3.10 examples/aloha_sim/.venv 17 | source examples/aloha_sim/.venv/bin/activate 18 | uv pip sync examples/aloha_sim/requirements.txt 19 | uv pip install -e packages/openpi-client 20 | 21 | # Run the simulation 22 | MUJOCO_GL=egl python examples/aloha_sim/main.py 23 | ``` 24 | 25 | Note: If you are seeing EGL errors, you may need to install the following dependencies: 26 | 27 | ```bash 28 | sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev 29 | ``` 30 | 31 | Terminal window 2: 32 | 33 | ```bash 34 | # Run the server 35 | uv run scripts/serve_policy.py --env ALOHA_SIM 36 | ``` 37 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models/tokenizer_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from openpi.models import tokenizer as _tokenizer 4 | 5 | 6 | def test_tokenize(): 7 | tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10) 8 | tokens, masks = tokenizer.tokenize("Hello, world!") 9 | 10 | assert tokens.shape == (10,) 11 | assert masks.shape == (10,) 12 | 13 | 14 | def test_fast_tokenizer(): 15 | prompt = "Hello, world!" 
16 | state = np.random.rand(5).astype(np.float32) 17 | action = np.random.rand(3, 2).astype(np.float32) 18 | tokenizer = _tokenizer.FASTTokenizer(max_len=256) 19 | tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action) 20 | 21 | assert tokens.shape == (256,) 22 | assert token_masks.shape == (256,) 23 | assert ar_masks.shape == (256,) 24 | assert loss_masks.shape == (256,) 25 | 26 | act = tokenizer.extract_actions(tokens, 3, 2) 27 | assert act.shape == (3, 2) 28 | -------------------------------------------------------------------------------- /openpi-SF/scripts/docker/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f scripts/docker/compose.yml up --build 3 | services: 4 | openpi_server: 5 | image: openpi_server 6 | build: 7 | context: ../.. 8 | dockerfile: scripts/docker/serve_policy.Dockerfile 9 | init: true 10 | tty: true 11 | network_mode: host 12 | # Populate configured openpi data home to /openpi_assets inside the container. 13 | # Populate aws credential inside the container. 14 | volumes: 15 | - $PWD:/app 16 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 17 | environment: 18 | - SERVER_ARGS 19 | - OPENPI_DATA_HOME=/openpi_assets 20 | - IS_DOCKER=true 21 | 22 | # Comment out this block if not running on a machine with GPUs. 23 | deploy: 24 | resources: 25 | reservations: 26 | devices: 27 | - driver: nvidia 28 | count: 1 29 | capabilities: [gpu] 30 | -------------------------------------------------------------------------------- /openpi-SF/examples/simple_client/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9 3 | docstring-parser==0.16 4 | # via tyro 5 | markdown-it-py==3.0.0 6 | # via rich 7 | mdurl==0.1.2 8 | # via markdown-it-py 9 | numpy==1.26.4 10 | # via -r examples/simple_client/requirements.in 11 | polars==1.30.0 12 | # via -r examples/simple_client/requirements.in 13 | pygments==2.19.1 14 | # via rich 15 | rich==14.0.0 16 | # via 17 | # -r examples/simple_client/requirements.in 18 | # tyro 19 | shtab==1.7.2 20 | # via tyro 21 | tqdm==4.67.1 22 | # via -r examples/simple_client/requirements.in 23 | typeguard==4.4.2 24 | # via tyro 25 | typing-extensions==4.13.2 26 | # via 27 | # typeguard 28 | # tyro 29 | tyro==0.9.22 30 | # via -r examples/simple_client/requirements.in 31 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/vision/siglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | siglip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) 8 | SIGLIP_VISION_BACKBONES = { 9 | "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", 10 | "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", 11 | "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", 12 | "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", 13 | "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", 14 | } 15 | 16 | 17 | class SigLIPViTBackbone(TimmViTBackbone): 18 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, 
default_image_size: int = 224) -> None: 19 | super().__init__( 20 | vision_backbone_id, 21 | SIGLIP_VISION_BACKBONES[vision_backbone_id], 22 | image_resize_strategy, 23 | default_image_size=default_image_size, 24 | ) 25 | -------------------------------------------------------------------------------- /openvla-SF/train.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nnodes 1 --nproc-per-node 8 vla-scripts/finetune_align.py \ 2 | --vla_path ckpts/openvla-7b \ 3 | --vggt_path ckpts/model.pt \ 4 | --data_root_dir data/libero/ \ 5 | --dataset_name libero_spatial_no_noops \ 6 | --run_root_dir ckpts/training_results/ \ 7 | --pooling_func bilinear \ 8 | --vla_layers_align 24 \ 9 | --vggt_layers_align -1 \ 10 | --align_loss_type cosine \ 11 | --align_loss_coeff 0.5 \ 12 | --use_l1_regression True \ 13 | --use_diffusion False \ 14 | --use_film False \ 15 | --use_vlm_norm True \ 16 | --use_vggt_pe True \ 17 | --num_images_in_input 2 \ 18 | --use_proprio True \ 19 | --batch_size 8 \ 20 | --learning_rate 5e-4 \ 21 | --num_steps_before_decay 100000 \ 22 | --max_steps 150005 \ 23 | --save_freq 10000 \ 24 | --save_latest_checkpoint_only True \ 25 | --merge_lora_during_training False \ 26 | --image_aug True \ 27 | --lora_rank 32 \ 28 | --wandb_entity "YOUR_WANDB_ENTITY" \ 29 | --wandb_project "YOUR_WANDB_PROJECT" \ 30 | --run_id_override "YOUR_RUN_ID" -------------------------------------------------------------------------------- /openpi-SF/scripts/docker/install_nvidia_container_toolkit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs. 4 | # NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html 5 | 6 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg && 7 | curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | 8 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | 9 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list 10 | 11 | # NVIDIA's documenation omits 'sudo' in the following command, but it is required. 
12 | sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list 13 | sudo apt-get update 14 | sudo apt-get install -y nvidia-container-toolkit 15 | 16 | sudo nvidia-ctk runtime configure --runtime=docker 17 | sudo systemctl restart docker 18 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/video_display.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from openpi_client.runtime import subscriber as _subscriber 4 | from typing_extensions import override 5 | 6 | 7 | class VideoDisplay(_subscriber.Subscriber): 8 | """Displays video frames.""" 9 | 10 | def __init__(self) -> None: 11 | self._ax: plt.Axes | None = None 12 | self._plt_img: plt.Image | None = None 13 | 14 | @override 15 | def on_episode_start(self) -> None: 16 | plt.ion() 17 | self._ax = plt.subplot() 18 | self._plt_img = None 19 | 20 | @override 21 | def on_step(self, observation: dict, action: dict) -> None: 22 | assert self._ax is not None 23 | 24 | im = observation["image"][0] # [C, H, W] 25 | im = np.transpose(im, (1, 2, 0)) # [H, W, C] 26 | 27 | if self._plt_img is None: 28 | self._plt_img = self._ax.imshow(im) 29 | else: 30 | self._plt_img.set_data(im) 31 | plt.pause(0.001) 32 | 33 | @override 34 | def on_episode_end(self) -> None: 35 | plt.ioff() 36 | plt.close() 37 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/aloha_sim/compose.yml up --build 3 | services: 4 | runtime: 5 | image: aloha_sim 6 | depends_on: 7 | - openpi_server 8 | build: 9 | context: ../.. 10 | dockerfile: examples/aloha_sim/Dockerfile 11 | init: true 12 | tty: true 13 | network_mode: host 14 | privileged: true 15 | volumes: 16 | - $PWD:/app 17 | - ../../data:/data 18 | 19 | openpi_server: 20 | image: openpi_server 21 | build: 22 | context: ../.. 23 | dockerfile: scripts/docker/serve_policy.Dockerfile 24 | init: true 25 | tty: true 26 | network_mode: host 27 | volumes: 28 | - $PWD:/app 29 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 30 | environment: 31 | - SERVER_ARGS 32 | - OPENPI_DATA_HOME=/openpi_assets 33 | - IS_DOCKER=true 34 | 35 | # Comment out this block if not running on a machine with GPUs. 36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /openpi-SF/examples/simple_client/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/simple_client/compose.yml up --build 3 | services: 4 | runtime: 5 | image: simple_client 6 | depends_on: 7 | - openpi_server 8 | build: 9 | context: ../.. 10 | dockerfile: examples/simple_client/Dockerfile 11 | init: true 12 | tty: true 13 | network_mode: host 14 | volumes: 15 | - $PWD:/app 16 | environment: 17 | - SERVER_ARGS 18 | 19 | openpi_server: 20 | image: openpi_server 21 | build: 22 | context: ../.. 
23 | dockerfile: scripts/docker/serve_policy.Dockerfile 24 | init: true 25 | tty: true 26 | network_mode: host 27 | volumes: 28 | - $PWD:/app 29 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 30 | environment: 31 | - SERVER_ARGS 32 | - OPENPI_DATA_HOME=/openpi_assets 33 | - IS_DOCKER=true 34 | 35 | # Comment out this block if not running on a machine with GPUs. 36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Fuhao Li, Wenxuan Song, Han Zhao, Jingbo Wang, Pengxiang Ding, Donglin Wang, Long Zeng, Haoang Li. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/runtime/environment.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Environment(abc.ABC): 5 | """An Environment represents the robot and the environment it inhabits. 6 | 7 | The primary contract of environments is that they can be queried for observations 8 | about their state, and have actions applied to them to change that state. 9 | """ 10 | 11 | @abc.abstractmethod 12 | def reset(self) -> None: 13 | """Reset the environment to its initial state. 14 | 15 | This will be called once before starting each episode. 16 | """ 17 | 18 | @abc.abstractmethod 19 | def is_episode_complete(self) -> bool: 20 | """Allow the environment to signal that the episode is complete. 21 | 22 | This will be called after each step. It should return `True` if the episode is 23 | complete (either successfully or unsuccessfully), and `False` otherwise. 
24 | """ 25 | 26 | @abc.abstractmethod 27 | def get_observation(self) -> dict: 28 | """Query the environment for the current state.""" 29 | 30 | @abc.abstractmethod 31 | def apply_action(self, action: dict) -> None: 32 | """Take an action in the environment.""" 33 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}] 3 | dependencies = [ 4 | "numpy<2", 5 | "Pillow", 6 | "huggingface_hub", 7 | "einops", 8 | "safetensors", 9 | "opencv-python", 10 | ] 11 | name = "vggt" 12 | requires-python = ">= 3.10" 13 | version = "0.0.1" 14 | 15 | [project.optional-dependencies] 16 | demo = [ 17 | "gradio==5.17.1", 18 | "viser==0.2.23", 19 | "tqdm", 20 | "hydra-core", 21 | "omegaconf", 22 | "opencv-python", 23 | "scipy", 24 | "onnxruntime", 25 | "requests", 26 | "trimesh", 27 | "matplotlib", 28 | ] 29 | 30 | # Using setuptools as the build backend 31 | [build-system] 32 | requires = ["setuptools>=61.0", "wheel"] 33 | build-backend = "setuptools.build_meta" 34 | 35 | # setuptools configuration 36 | [tool.setuptools.packages.find] 37 | where = ["."] 38 | include = ["vggt*"] 39 | 40 | # Pixi configuration 41 | [tool.pixi.workspace] 42 | channels = ["conda-forge"] 43 | platforms = ["linux-64"] 44 | 45 | [tool.pixi.pypi-dependencies] 46 | vggt = { path = ".", editable = true } 47 | 48 | [tool.pixi.environments] 49 | default = { solve-group = "default" } 50 | demo = { features = ["demo"], solve-group = "default" } 51 | 52 | [tool.pixi.tasks] 53 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/policies/policy_test.py: -------------------------------------------------------------------------------- 1 | from openpi_client import action_chunk_broker 2 | import pytest 3 | 4 | from openpi.policies import aloha_policy 5 | from openpi.policies import policy_config as _policy_config 6 | from openpi.training import config as _config 7 | 8 | 9 | @pytest.mark.manual 10 | def test_infer(): 11 | config = _config.get_config("pi0_aloha_sim") 12 | policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim") 13 | 14 | example = aloha_policy.make_aloha_example() 15 | result = policy.infer(example) 16 | 17 | assert result["actions"].shape == (config.model.action_horizon, 14) 18 | 19 | 20 | @pytest.mark.manual 21 | def test_broker(): 22 | config = _config.get_config("pi0_aloha_sim") 23 | policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim") 24 | 25 | broker = action_chunk_broker.ActionChunkBroker( 26 | policy, 27 | # Only execute the first half of the chunk. 28 | action_horizon=config.model.action_horizon // 2, 29 | ) 30 | 31 | example = aloha_policy.make_aloha_example() 32 | for _ in range(config.model.action_horizon): 33 | outputs = broker.infer(example) 34 | assert outputs["actions"].shape == (14,) 35 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/datasets/rlds/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | goal_relabeling.py 3 | 4 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 5 | Each function should add entries to the "task" dict. 
6 | """ 7 | 8 | from typing import Dict 9 | 10 | import tensorflow as tf 11 | 12 | from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge 13 | 14 | 15 | def uniform(traj: Dict) -> Dict: 16 | """Relabels with a true uniform distribution over future states.""" 17 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 18 | 19 | # Select a random future index for each transition i in the range [i + 1, traj_len) 20 | rand = tf.random.uniform([traj_len]) 21 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 22 | high = tf.cast(traj_len, tf.float32) 23 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 24 | 25 | # Sometimes there are floating-point errors that cause an out-of-bounds 26 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 27 | 28 | # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) 29 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 30 | traj["task"] = tree_merge(traj["task"], goal) 31 | 32 | return traj 33 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/training/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any 3 | 4 | from flax import nnx 5 | from flax import struct 6 | import jax 7 | import optax 8 | 9 | from openpi.models import model as _model 10 | from openpi.shared import array_typing as at 11 | 12 | 13 | @at.typecheck 14 | @struct.dataclass 15 | class TrainState: 16 | step: at.Int[at.ArrayLike, ""] 17 | params: nnx.State 18 | model_def: nnx.GraphDef[_model.BaseModel] 19 | opt_state: optax.OptState 20 | tx: optax.GradientTransformation = struct.field(pytree_node=False) 21 | 22 | ema_decay: float | None = struct.field(pytree_node=False) 23 | ema_params: nnx.State | None = None 24 | 25 | 26 | @at.typecheck 27 | def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str: 28 | """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert 29 | the leaf values to more meaningful strings. 30 | """ 31 | tree, _ = jax.tree_util.tree_flatten_with_path(tree) 32 | return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree) 33 | 34 | 35 | @at.typecheck 36 | def array_tree_to_info(tree: at.PyTree) -> str: 37 | """Converts a PyTree of arrays into a human-readable string for logging.""" 38 | return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}") 39 | -------------------------------------------------------------------------------- /openpi-SF/examples/simple_client/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the simple client. 2 | 3 | # Build the container: 4 | # docker build . -t simple_client -f examples/simple_client/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v .:/app simple_client /bin/bash 8 | 9 | FROM python:3.7-slim 10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 11 | 12 | WORKDIR /app 13 | 14 | # Copy from the cache instead of linking since it's a mounted volume 15 | ENV UV_LINK_MODE=copy 16 | 17 | # Write the virtual environment outside of the project directory so it doesn't 18 | # leak out of the container when we mount the application code. 19 | ENV UV_PROJECT_ENVIRONMENT=/.venv 20 | 21 | # Copy the requirements files so we can install dependencies. 
22 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. 23 | # This strategy is best for development-style usage. 24 | COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt 25 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 26 | 27 | # Install python dependencies. 28 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT 29 | RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml 30 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src 31 | 32 | CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS" 33 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/vision/clip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | clip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported CLIP Vision Backbones (from TIMM) 8 | CLIP_VISION_BACKBONES = { 9 | "clip-vit-b": "vit_base_patch16_clip_224.openai", 10 | "clip-vit-l": "vit_large_patch14_clip_224.openai", 11 | "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", 12 | } 13 | 14 | 15 | # [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. 16 | # HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's 17 | # a decent approximation, the resulting features are *worse*; this was a super tricky bug 18 | # to identify, but luckily there's an easy fix (`override_act_layer`) 19 | class CLIPViTBackbone(TimmViTBackbone): 20 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 21 | super().__init__( 22 | vision_backbone_id, 23 | CLIP_VISION_BACKBONES[vision_backbone_id], 24 | image_resize_strategy, 25 | default_image_size=default_image_size, 26 | override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, 27 | ) 28 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/saver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import imageio 5 | import numpy as np 6 | from openpi_client.runtime import subscriber as _subscriber 7 | from typing_extensions import override 8 | 9 | 10 | class VideoSaver(_subscriber.Subscriber): 11 | """Saves episode data.""" 12 | 13 | def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None: 14 | out_dir.mkdir(parents=True, exist_ok=True) 15 | self._out_dir = out_dir 16 | self._images: list[np.ndarray] = [] 17 | self._subsample = subsample 18 | 19 | @override 20 | def on_episode_start(self) -> None: 21 | self._images = [] 22 | 23 | @override 24 | def on_step(self, observation: dict, action: dict) -> None: 25 | im = observation["images"]["cam_high"] # [C, H, W] 26 | im = np.transpose(im, (1, 2, 0)) # [H, W, C] 27 | 
self._images.append(im) 28 | 29 | @override 30 | def on_episode_end(self) -> None: 31 | existing = list(self._out_dir.glob("out_[0-9]*.mp4")) 32 | next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1 33 | out_path = self._out_dir / f"out_{next_idx}.mp4" 34 | 35 | logging.info(f"Saving video to {out_path}") 36 | imageio.mimwrite( 37 | out_path, 38 | [np.asarray(x) for x in self._images[:: self._subsample]], 39 | fps=50 // max(1, self._subsample), 40 | ) 41 | -------------------------------------------------------------------------------- /openpi-SF/examples/libero/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/libero/compose.yml up --build 3 | services: 4 | runtime: 5 | image: libero 6 | depends_on: 7 | - openpi_server 8 | build: 9 | context: ../.. 10 | dockerfile: examples/libero/Dockerfile 11 | init: true 12 | tty: true 13 | network_mode: host 14 | privileged: true 15 | volumes: 16 | - $PWD:/app 17 | - ../../data:/data 18 | - /tmp/.X11-unix:/tmp/.X11-unix:ro 19 | environment: 20 | - CLIENT_ARGS 21 | - DISPLAY=$DISPLAY 22 | - MUJOCO_GL=${MUJOCO_GL:-egl} 23 | deploy: 24 | resources: 25 | reservations: 26 | devices: 27 | - driver: nvidia 28 | count: 1 29 | capabilities: [gpu] 30 | 31 | openpi_server: 32 | image: openpi_server 33 | build: 34 | context: ../.. 35 | dockerfile: scripts/docker/serve_policy.Dockerfile 36 | init: true 37 | tty: true 38 | network_mode: host 39 | volumes: 40 | - $PWD:/app 41 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 42 | environment: 43 | - SERVER_ARGS 44 | - OPENPI_DATA_HOME=/openpi_assets 45 | - IS_DOCKER=true 46 | 47 | # Comment out this block if not running on a machine with GPUs. 48 | deploy: 49 | resources: 50 | reservations: 51 | devices: 52 | - driver: nvidia 53 | count: 1 54 | capabilities: [gpu] 55 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/shared/image_tools_test.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from openpi.shared import image_tools 4 | 5 | 6 | def test_resize_with_pad_shapes(): 7 | # Test case 1: Resize image with larger dimensions 8 | images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8) # Input images of shape (batch_size, height, width, channels) 9 | height = 20 10 | width = 20 11 | resized_images = image_tools.resize_with_pad(images, height, width) 12 | assert resized_images.shape == (2, height, width, 3) 13 | assert jnp.all(resized_images == 0) 14 | 15 | # Test case 2: Resize image with smaller dimensions 16 | images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8) 17 | height = 15 18 | width = 15 19 | resized_images = image_tools.resize_with_pad(images, height, width) 20 | assert resized_images.shape == (3, height, width, 3) 21 | assert jnp.all(resized_images == 0) 22 | 23 | # Test case 3: Resize image with the same dimensions 24 | images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8) 25 | height = 50 26 | width = 50 27 | resized_images = image_tools.resize_with_pad(images, height, width) 28 | assert resized_images.shape == (1, height, width, 3) 29 | assert jnp.all(resized_images == 0) 30 | 31 | # Test case 3: Resize image with odd-numbered padding 32 | images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8) 33 | height = 60 34 | width = 80 35 | resized_images = image_tools.resize_with_pad(images, height, width) 36 | assert resized_images.shape == (1, height, width, 3) 37 | 
assert jnp.all(resized_images == 0) 38 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/image_tools_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import openpi_client.image_tools as image_tools 4 | 5 | 6 | def test_resize_with_pad_shapes(): 7 | # Test case 1: Resize image with larger dimensions 8 | images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels) 9 | height = 20 10 | width = 20 11 | resized_images = image_tools.resize_with_pad(images, height, width) 12 | assert resized_images.shape == (2, height, width, 3) 13 | assert np.all(resized_images == 0) 14 | 15 | # Test case 2: Resize image with smaller dimensions 16 | images = np.zeros((3, 30, 30, 3), dtype=np.uint8) 17 | height = 15 18 | width = 15 19 | resized_images = image_tools.resize_with_pad(images, height, width) 20 | assert resized_images.shape == (3, height, width, 3) 21 | assert np.all(resized_images == 0) 22 | 23 | # Test case 3: Resize image with the same dimensions 24 | images = np.zeros((1, 50, 50, 3), dtype=np.uint8) 25 | height = 50 26 | width = 50 27 | resized_images = image_tools.resize_with_pad(images, height, width) 28 | assert resized_images.shape == (1, height, width, 3) 29 | assert np.all(resized_images == 0) 30 | 31 | # Test case 3: Resize image with odd-numbered padding 32 | images = np.zeros((1, 256, 320, 3), dtype=np.uint8) 33 | height = 60 34 | width = 80 35 | resized_images = image_tools.resize_with_pad(images, height, width) 36 | assert resized_images.shape == (1, height, width, 3) 37 | assert np.all(resized_images == 0) 38 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import pathlib 4 | 5 | import env as _env 6 | from openpi_client import action_chunk_broker 7 | from openpi_client import websocket_client_policy as _websocket_client_policy 8 | from openpi_client.runtime import runtime as _runtime 9 | from openpi_client.runtime.agents import policy_agent as _policy_agent 10 | import saver as _saver 11 | import tyro 12 | 13 | 14 | @dataclasses.dataclass 15 | class Args: 16 | out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos") 17 | 18 | task: str = "gym_aloha/AlohaTransferCube-v0" 19 | seed: int = 0 20 | 21 | action_horizon: int = 10 22 | 23 | host: str = "0.0.0.0" 24 | port: int = 8000 25 | 26 | display: bool = False 27 | 28 | 29 | def main(args: Args) -> None: 30 | runtime = _runtime.Runtime( 31 | environment=_env.AlohaSimEnvironment( 32 | task=args.task, 33 | seed=args.seed, 34 | ), 35 | agent=_policy_agent.PolicyAgent( 36 | policy=action_chunk_broker.ActionChunkBroker( 37 | policy=_websocket_client_policy.WebsocketClientPolicy( 38 | host=args.host, 39 | port=args.port, 40 | ), 41 | action_horizon=args.action_horizon, 42 | ) 43 | ), 44 | subscribers=[ 45 | _saver.VideoSaver(args.out_dir), 46 | ], 47 | max_hz=50, 48 | ) 49 | 50 | runtime.run() 51 | 52 | 53 | if __name__ == "__main__": 54 | logging.basicConfig(level=logging.INFO, force=True) 55 | tyro.cli(main) 56 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 
| import logging 3 | 4 | from openpi_client import action_chunk_broker 5 | from openpi_client import websocket_client_policy as _websocket_client_policy 6 | from openpi_client.runtime import runtime as _runtime 7 | from openpi_client.runtime.agents import policy_agent as _policy_agent 8 | import tyro 9 | 10 | from examples.aloha_real import env as _env 11 | 12 | 13 | @dataclasses.dataclass 14 | class Args: 15 | host: str = "0.0.0.0" 16 | port: int = 8000 17 | 18 | action_horizon: int = 25 19 | 20 | num_episodes: int = 1 21 | max_episode_steps: int = 1000 22 | 23 | 24 | def main(args: Args) -> None: 25 | ws_client_policy = _websocket_client_policy.WebsocketClientPolicy( 26 | host=args.host, 27 | port=args.port, 28 | ) 29 | logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}") 30 | 31 | metadata = ws_client_policy.get_server_metadata() 32 | runtime = _runtime.Runtime( 33 | environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")), 34 | agent=_policy_agent.PolicyAgent( 35 | policy=action_chunk_broker.ActionChunkBroker( 36 | policy=ws_client_policy, 37 | action_horizon=args.action_horizon, 38 | ) 39 | ), 40 | subscribers=[], 41 | max_hz=50, 42 | num_episodes=args.num_episodes, 43 | max_episode_steps=args.max_episode_steps, 44 | ) 45 | 46 | runtime.run() 47 | 48 | 49 | if __name__ == "__main__": 50 | logging.basicConfig(level=logging.INFO, force=True) 51 | tyro.cli(main) 52 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the Aloha simulation environment. 2 | 3 | # Build the container: 4 | # docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash 8 | 9 | FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78 10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 11 | 12 | RUN apt-get update && \ 13 | apt-get install -y \ 14 | libosmesa6-dev \ 15 | libgl1-mesa-glx \ 16 | libglew-dev \ 17 | libglfw3-dev \ 18 | libgles2-mesa-dev 19 | ENV MUJOCO_GL=egl 20 | 21 | WORKDIR /app 22 | 23 | # Copy from the cache instead of linking since it's a mounted volume 24 | ENV UV_LINK_MODE=copy 25 | 26 | # Write the virtual environment outside of the project directory so it doesn't 27 | # leak out of the container when we mount the application code. 28 | ENV UV_PROJECT_ENVIRONMENT=/.venv 29 | 30 | # Copy the requirements files so we can install dependencies. 31 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. 32 | # This strategy is best for development-style usage. 33 | COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt 34 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 35 | 36 | # Install python dependencies. 
37 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT 38 | RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml 39 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src 40 | 41 | CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"] -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/action_chunk_broker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import tree 5 | from typing_extensions import override 6 | 7 | from openpi_client import base_policy as _base_policy 8 | 9 | 10 | class ActionChunkBroker(_base_policy.BasePolicy): 11 | """Wraps a policy to return action chunks one-at-a-time. 12 | 13 | Assumes that the first dimension of all action fields is the chunk size. 14 | 15 | A new inference call to the inner policy is only made when the current 16 | list of chunks is exhausted. 17 | """ 18 | 19 | def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int): 20 | self._policy = policy 21 | self._action_horizon = action_horizon 22 | self._cur_step: int = 0 23 | 24 | self._last_results: Dict[str, np.ndarray] | None = None 25 | 26 | @override 27 | def infer(self, obs: Dict) -> Dict: # noqa: UP006 28 | if self._last_results is None: 29 | self._last_results = self._policy.infer(obs) 30 | self._cur_step = 0 31 | 32 | def slicer(x): 33 | if isinstance(x, np.ndarray): 34 | return x[self._cur_step, ...] 35 | else: 36 | return x 37 | 38 | results = tree.map_structure(slicer, self._last_results) 39 | self._cur_step += 1 40 | 41 | if self._cur_step >= self._action_horizon: 42 | self._last_results = None 43 | 44 | return results 45 | 46 | @override 47 | def reset(self) -> None: 48 | self._policy.reset() 49 | self._last_results = None 50 | self._cur_step = 0 51 | -------------------------------------------------------------------------------- /openpi-SF/scripts/docker/install_docker_ubuntu22.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Add Docker's official GPG key: 4 | sudo apt-get update 5 | sudo apt-get install -y ca-certificates curl 6 | sudo install -m 0755 -d /etc/apt/keyrings 7 | sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc 8 | sudo chmod a+r /etc/apt/keyrings/docker.asc 9 | 10 | # Add the repository to Apt sources: 11 | echo \ 12 | "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \ 13 | $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | 14 | sudo tee /etc/apt/sources.list.d/docker.list >/dev/null 15 | sudo apt-get update 16 | 17 | sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin 18 | 19 | # Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc). 20 | # See https://docs.docker.com/engine/install/linux-postinstall/ 21 | username=$(whoami) 22 | sudo usermod -aG docker $username 23 | 24 | # Configure docker to start automatically on system boot. 
25 | sudo systemctl enable docker.service 26 | sudo systemctl enable containerd.service 27 | 28 | # https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5 29 | if [ ~/.docker/config.json ]; then 30 | sed -i 's/credsStore/credStore/g' ~/.docker/config.json 31 | fi 32 | 33 | echo "" 34 | echo "********************************************************************" 35 | echo "**** Restart to allow Docker permission changes to take effect. ****" 36 | echo "********************************************************************" 37 | echo "" 38 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/shared/normalize_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import openpi.shared.normalize as normalize 4 | 5 | 6 | def test_normalize_update(): 7 | arr = np.arange(12).reshape(4, 3) # 4 vectors of length 3 8 | 9 | stats = normalize.RunningStats() 10 | for i in range(len(arr)): 11 | stats.update(arr[i : i + 1]) # Update with one vector at a time 12 | results = stats.get_statistics() 13 | 14 | assert np.allclose(results.mean, np.mean(arr, axis=0)) 15 | assert np.allclose(results.std, np.std(arr, axis=0)) 16 | 17 | 18 | def test_serialize_deserialize(): 19 | stats = normalize.RunningStats() 20 | stats.update(np.arange(12).reshape(4, 3)) # 4 vectors of length 3 21 | 22 | norm_stats = {"test": stats.get_statistics()} 23 | norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats)) 24 | assert np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean) 25 | assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std) 26 | 27 | 28 | def test_multiple_batch_dimensions(): 29 | # Test with multiple batch dimensions: (2, 3, 4) where 4 is vector dimension 30 | batch_shape = (2, 3, 4) 31 | arr = np.random.rand(*batch_shape) 32 | 33 | stats = normalize.RunningStats() 34 | stats.update(arr) # Should handle (2, 3, 4) -> reshape to (6, 4) 35 | results = stats.get_statistics() 36 | 37 | # Flatten batch dimensions and compute expected stats 38 | flattened = arr.reshape(-1, arr.shape[-1]) # (6, 4) 39 | expected_mean = np.mean(flattened, axis=0) 40 | expected_std = np.std(flattened, axis=0) 41 | 42 | assert np.allclose(results.mean, expected_mean) 43 | assert np.allclose(results.std, expected_std) 44 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/shared/download_test.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytest 4 | 5 | import openpi.shared.download as download 6 | 7 | 8 | @pytest.fixture(scope="session", autouse=True) 9 | def set_openpi_data_home(tmp_path_factory): 10 | temp_dir = tmp_path_factory.mktemp("openpi_data") 11 | with pytest.MonkeyPatch().context() as mp: 12 | mp.setenv("OPENPI_DATA_HOME", str(temp_dir)) 13 | yield 14 | 15 | 16 | def test_download_local(tmp_path: pathlib.Path): 17 | local_path = tmp_path / "local" 18 | local_path.touch() 19 | 20 | result = download.maybe_download(str(local_path)) 21 | assert result == local_path 22 | 23 | with pytest.raises(FileNotFoundError): 24 | download.maybe_download("bogus") 25 | 26 | 27 | def test_download_gs_dir(): 28 | remote_path = "gs://openpi-assets/testdata/random" 29 | 30 | local_path = download.maybe_download(remote_path) 31 | assert local_path.exists() 32 | 33 | new_local_path = download.maybe_download(remote_path) 34 
| assert new_local_path == local_path 35 | 36 | 37 | def test_download_gs(): 38 | remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin" 39 | 40 | local_path = download.maybe_download(remote_path) 41 | assert local_path.exists() 42 | 43 | new_local_path = download.maybe_download(remote_path) 44 | assert new_local_path == local_path 45 | 46 | 47 | def test_download_fsspec(): 48 | remote_path = "gs://big_vision/paligemma_tokenizer.model" 49 | 50 | local_path = download.maybe_download(remote_path, gs={"token": "anon"}) 51 | assert local_path.exists() 52 | 53 | new_local_path = download.maybe_download(remote_path, gs={"token": "anon"}) 54 | assert new_local_path == local_path 55 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/compose.yml: -------------------------------------------------------------------------------- 1 | # Run with: 2 | # docker compose -f examples/aloha_real/compose.yml up --build 3 | services: 4 | runtime: 5 | image: aloha_real 6 | depends_on: 7 | - aloha_ros_nodes 8 | - ros_master 9 | - openpi_server 10 | build: 11 | context: ../.. 12 | dockerfile: examples/aloha_real/Dockerfile 13 | init: true 14 | tty: true 15 | network_mode: host 16 | privileged: true 17 | volumes: 18 | - $PWD:/app 19 | - ../../data:/data 20 | 21 | aloha_ros_nodes: 22 | image: aloha_real 23 | depends_on: 24 | - ros_master 25 | build: 26 | context: ../.. 27 | dockerfile: examples/aloha_real/Dockerfile 28 | init: true 29 | tty: true 30 | network_mode: host 31 | privileged: true 32 | volumes: 33 | - /dev:/dev 34 | command: roslaunch --wait aloha ros_nodes.launch 35 | 36 | ros_master: 37 | image: ros:noetic-robot 38 | network_mode: host 39 | privileged: true 40 | command: 41 | - roscore 42 | 43 | openpi_server: 44 | image: openpi_server 45 | build: 46 | context: ../.. 47 | dockerfile: scripts/docker/serve_policy.Dockerfile 48 | init: true 49 | tty: true 50 | network_mode: host 51 | volumes: 52 | - $PWD:/app 53 | - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets 54 | environment: 55 | - SERVER_ARGS 56 | - OPENPI_DATA_HOME=/openpi_assets 57 | - IS_DOCKER=true 58 | 59 | # Comment out this block if not running on a machine with GPUs. 
60 | deploy: 61 | resources: 62 | reservations: 63 | devices: 64 | - driver: nvidia 65 | count: 1 66 | capabilities: [gpu] 67 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models/pi0_test.py: -------------------------------------------------------------------------------- 1 | import flax.nnx as nnx 2 | import jax 3 | 4 | import openpi.models.pi0_config as _pi0_config 5 | 6 | 7 | def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State: 8 | abstract_model = nnx.eval_shape(config.create, jax.random.key(0)) 9 | 10 | freeze_filter = config.get_freeze_filter() 11 | return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state() 12 | 13 | 14 | def test_pi0_full_finetune(): 15 | config = _pi0_config.Pi0Config() 16 | state = _get_frozen_state(config) 17 | assert len(state) == 0 18 | 19 | 20 | def test_pi0_gemma_lora(): 21 | config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora") 22 | state = _get_frozen_state(config) 23 | assert len(state) == 9 24 | assert all("lora" not in p for p in state) 25 | assert all("llm" in p for p in state) 26 | assert all("_1" not in p for p in state) 27 | 28 | 29 | def test_pi0_action_expert_lora(): 30 | config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora") 31 | state = _get_frozen_state(config) 32 | # excluding embedder, rest of the params should be same as gemma_lora. 33 | assert len(state) == 8 34 | assert all("lora" not in p for p in state) 35 | assert all("llm" in p for p in state) 36 | # all frozen params should have _1 in their path since it's the action expert. 37 | assert all(any("_1" in p for p in path) for path in state) 38 | 39 | 40 | def test_pi0_all_lora(): 41 | config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora") 42 | state = _get_frozen_state(config) 43 | # sum of gemma_lora and action_expert_lora's frozen params. 
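    # (9 frozen entries from the gemma_2b_lora case + 8 from the gemma_300m_lora case above = 17.)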
44 | assert len(state) == 17 45 | assert all("lora" not in p for p in state) 46 | assert all("llm" in p for p in state) 47 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tree 4 | 5 | from openpi_client import msgpack_numpy 6 | 7 | 8 | def _check(expected, actual): 9 | if isinstance(expected, np.ndarray): 10 | assert expected.shape == actual.shape 11 | assert expected.dtype == actual.dtype 12 | assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") 13 | else: 14 | assert expected == actual 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "data", 19 | [ 20 | 1, # int 21 | 1.0, # float 22 | "hello", # string 23 | np.bool_(True), # boolean scalar 24 | np.array([1, 2, 3])[0], # int scalar 25 | np.str_("asdf"), # string scalar 26 | [1, 2, 3], # list 27 | {"key": "value"}, # dict 28 | {"key": [1, 2, 3]}, # nested dict 29 | np.array(1.0), # 0D array 30 | np.array([1, 2, 3], dtype=np.int32), # 1D integer array 31 | np.array(["asdf", "qwer"]), # string array 32 | np.array([True, False]), # boolean array 33 | np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array 34 | np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array 35 | np.array([np.nan, np.inf, -np.inf]), # special float values 36 | {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays 37 | [np.array([1, 2]), np.array([3, 4])], # list of arrays 38 | np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros 39 | np.ones((2, 3), dtype=np.float64), # 2D ones with double precision 40 | ], 41 | ) 42 | def test_pack_unpack(data): 43 | packed = msgpack_numpy.packb(data) 44 | unpacked = msgpack_numpy.unpackb(packed) 45 | tree.map_structure(_check, data, unpacked) 46 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/env.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional # noqa: UP035 2 | 3 | import einops 4 | from openpi_client import image_tools 5 | from openpi_client.runtime import environment as _environment 6 | from typing_extensions import override 7 | 8 | from examples.aloha_real import real_env as _real_env 9 | 10 | 11 | class AlohaRealEnvironment(_environment.Environment): 12 | """An environment for an Aloha robot on real hardware.""" 13 | 14 | def __init__( 15 | self, 16 | reset_position: Optional[List[float]] = None, # noqa: UP006,UP007 17 | render_height: int = 224, 18 | render_width: int = 224, 19 | ) -> None: 20 | self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position) 21 | self._render_height = render_height 22 | self._render_width = render_width 23 | 24 | self._ts = None 25 | 26 | @override 27 | def reset(self) -> None: 28 | self._ts = self._env.reset() 29 | 30 | @override 31 | def is_episode_complete(self) -> bool: 32 | return False 33 | 34 | @override 35 | def get_observation(self) -> dict: 36 | if self._ts is None: 37 | raise RuntimeError("Timestep is not set. 
Call reset() first.") 38 | 39 | obs = self._ts.observation 40 | for k in list(obs["images"].keys()): 41 | if "_depth" in k: 42 | del obs["images"][k] 43 | 44 | for cam_name in obs["images"]: 45 | img = image_tools.convert_to_uint8( 46 | image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width) 47 | ) 48 | obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w") 49 | 50 | return { 51 | "state": obs["qpos"], 52 | "images": obs["images"], 53 | } 54 | 55 | @override 56 | def apply_action(self, action: dict) -> None: 57 | self._ts = self._env.step(action["actions"]) 58 | -------------------------------------------------------------------------------- /openpi-SF/docs/docker.md: -------------------------------------------------------------------------------- 1 | ### Docker Setup 2 | 3 | All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended because it simplifies software installation, produces a more stable environment, and, for examples which depend on ROS, lets you avoid installing ROS and cluttering your machine. 4 | 5 | - Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/). 6 | - Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/). 7 | - To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). 8 | - The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`. 9 | - Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`. 10 | 11 | 12 | If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`. 13 | 14 | Build the Docker image and start the container with the following command: 15 | ```bash 16 | docker compose -f scripts/docker/compose.yml up --build 17 | ``` 18 | 19 | To build and run the Docker image for a specific example, use the following command: 20 | ```bash 21 | docker compose -f examples/<example_name>/compose.yml up --build 22 | ``` 23 | where `<example_name>` is the name of the example you want to run. 24 | 25 | During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached. -------------------------------------------------------------------------------- /openpi-SF/scripts/docker/serve_policy.Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for serving a PI policy. 2 | # Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container 3 | 4 | # Build the container: 5 | # docker build . 
-t openpi_server -f scripts/docker/serve_policy.Dockerfile 6 | 7 | # Run the container: 8 | # docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash 9 | 10 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 11 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 12 | 13 | WORKDIR /app 14 | 15 | # Needed because LeRobot uses git-lfs. 16 | RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang 17 | 18 | # Copy from the cache instead of linking since it's a mounted volume 19 | ENV UV_LINK_MODE=copy 20 | 21 | # Write the virtual environment outside of the project directory so it doesn't 22 | # leak out of the container when we mount the application code. 23 | ENV UV_PROJECT_ENVIRONMENT=/.venv 24 | 25 | # Install the project's dependencies using the lockfile and settings 26 | RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT 27 | RUN --mount=type=cache,target=/root/.cache/uv \ 28 | --mount=type=bind,source=uv.lock,target=uv.lock \ 29 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 30 | --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \ 31 | --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \ 32 | GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev 33 | 34 | # Copy transformers_replace files while preserving directory structure 35 | COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/ 36 | RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace 37 | 38 | CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS" 39 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/training/train_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for training/fine-tuning scripts.""" 2 | 3 | import torch 4 | 5 | from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX 6 | 7 | 8 | def get_current_action_mask(token_ids): 9 | # Create a tensor marking positions of IGNORE_INDEX 10 | newline_positions = token_ids != IGNORE_INDEX 11 | 12 | # Calculate cumulative sum to identify regions between newlines 13 | cumsum = torch.cumsum(newline_positions, dim=1) 14 | 15 | # Create the mask 16 | mask = (1 <= cumsum) & (cumsum <= ACTION_DIM) 17 | 18 | # Extract the action part only 19 | action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX 20 | mask = action_tokens_only_mask * mask 21 | 22 | return mask 23 | 24 | 25 | def get_next_actions_mask(token_ids): 26 | # Create a tensor marking positions of IGNORE_INDEX 27 | newline_positions = token_ids != IGNORE_INDEX 28 | 29 | # Calculate cumulative sum to identify regions between newlines 30 | cumsum = torch.cumsum(newline_positions, dim=1) 31 | 32 | # Create the mask 33 | mask = cumsum > ACTION_DIM 34 | 35 | # Extract the action part only 36 | action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX 37 | mask = action_tokens_only_mask * mask 38 | 39 | return mask 40 | 41 | 42 | def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask): 43 | correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask 44 | accuracy = correct_preds.sum().float() / mask.sum().float() 45 | return accuracy 46 
| 47 | 48 | def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask): 49 | pred_continuous_actions = torch.tensor( 50 | action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy()) 51 | ) 52 | true_continuous_actions = torch.tensor( 53 | action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy()) 54 | ) 55 | l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions) 56 | return l1_loss 57 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/util/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | nn_utils.py 3 | 4 | Utility functions and PyTorch submodule definitions. 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === 12 | class LinearProjector(nn.Module): 13 | def __init__(self, vision_dim: int, llm_dim: int) -> None: 14 | super().__init__() 15 | self.projector = nn.Linear(vision_dim, llm_dim, bias=True) 16 | 17 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 18 | return self.projector(img_patches) 19 | 20 | 21 | class MLPProjector(nn.Module): 22 | def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: 23 | super().__init__() 24 | if mlp_type == "gelu-mlp": 25 | self.projector = nn.Sequential( 26 | nn.Linear(vision_dim, llm_dim, bias=True), 27 | nn.GELU(), 28 | nn.Linear(llm_dim, llm_dim, bias=True), 29 | ) 30 | else: 31 | raise ValueError(f"Projector with `{mlp_type = }` is not supported!") 32 | 33 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 34 | return self.projector(img_patches) 35 | 36 | 37 | class FusedMLPProjector(nn.Module): 38 | def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: 39 | super().__init__() 40 | self.initial_projection_dim = fused_vision_dim * 4 41 | if mlp_type == "fused-gelu-mlp": 42 | self.projector = nn.Sequential( 43 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), 44 | nn.GELU(), 45 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True), 46 | nn.GELU(), 47 | nn.Linear(llm_dim, llm_dim, bias=True), 48 | ) 49 | else: 50 | raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") 51 | 52 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: 53 | return self.projector(fused_img_patches) 54 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/env.py: -------------------------------------------------------------------------------- 1 | import gym_aloha # noqa: F401 2 | import gymnasium 3 | import numpy as np 4 | from openpi_client import image_tools 5 | from openpi_client.runtime import environment as _environment 6 | from typing_extensions import override 7 | 8 | 9 | class AlohaSimEnvironment(_environment.Environment): 10 | """An environment for an Aloha robot in simulation.""" 11 | 12 | def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None: 13 | np.random.seed(seed) 14 | self._rng = np.random.default_rng(seed) 15 | 16 | self._gym = gymnasium.make(task, obs_type=obs_type) 17 | 18 | self._last_obs = None 19 | self._done = True 20 | self._episode_reward = 0.0 21 | 22 | @override 23 | def reset(self) -> None: 24 | gym_obs, _ = 
self._gym.reset(seed=int(self._rng.integers(2**32 - 1))) 25 | self._last_obs = self._convert_observation(gym_obs) # type: ignore 26 | self._done = False 27 | self._episode_reward = 0.0 28 | 29 | @override 30 | def is_episode_complete(self) -> bool: 31 | return self._done 32 | 33 | @override 34 | def get_observation(self) -> dict: 35 | if self._last_obs is None: 36 | raise RuntimeError("Observation is not set. Call reset() first.") 37 | 38 | return self._last_obs # type: ignore 39 | 40 | @override 41 | def apply_action(self, action: dict) -> None: 42 | gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"]) 43 | self._last_obs = self._convert_observation(gym_obs) # type: ignore 44 | self._done = terminated or truncated 45 | self._episode_reward = max(self._episode_reward, reward) 46 | 47 | def _convert_observation(self, gym_obs: dict) -> dict: 48 | img = gym_obs["pixels"]["top"] 49 | img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224)) 50 | # Convert axis order from [H, W, C] --> [C, H, W] 51 | img = np.transpose(img, (2, 0, 1)) 52 | 53 | return { 54 | "state": gym_obs["agent_pos"], 55 | "images": {"cam_high": img}, 56 | } 57 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/datasets/rlds/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | task_augmentation.py 3 | 4 | Contains basic logic for randomly zeroing out keys in the task specification. 5 | """ 6 | 7 | from typing import Dict 8 | 9 | import tensorflow as tf 10 | 11 | from prismatic.vla.datasets.rlds.utils.data_utils import to_padding 12 | 13 | 14 | def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: 15 | """ 16 | Randomly drops out either the goal images or the language instruction. Only does something if both of 17 | these are present. 18 | 19 | Args: 20 | traj: A dictionary containing trajectory data. Should have a "task" key. 21 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 22 | instruction is 1 - keep_image_prob. 
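    Example (illustrative; assumes `traj` carries the "task"/"action"/"pad_mask_dict" structure used below):

        traj = delete_task_conditioning(traj, keep_image_prob=0.5)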
23 | """ 24 | if "language_instruction" not in traj["task"]: 25 | return traj 26 | 27 | image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} 28 | if not image_keys: 29 | return traj 30 | 31 | traj_len = tf.shape(traj["action"])[0] 32 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 33 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 34 | 35 | for key in image_keys | {"language_instruction"}: 36 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 37 | # pad out the key 38 | traj["task"][key] = tf.where( 39 | should_keep, 40 | traj["task"][key], 41 | to_padding(traj["task"][key]), 42 | ) 43 | # zero out the pad mask dict for the key 44 | traj["task"]["pad_mask_dict"][key] = tf.where( 45 | should_keep, 46 | traj["task"]["pad_mask_dict"][key], 47 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 48 | ) 49 | 50 | # when no goal images are present, the goal timestep becomes the final timestep 51 | traj["task"]["timestep"] = tf.where( 52 | should_keep_images, 53 | traj["task"]["timestep"], 54 | traj_len - 1, 55 | ) 56 | 57 | return traj 58 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/msgpack_numpy.py: -------------------------------------------------------------------------------- 1 | """Adds NumPy array support to msgpack. 2 | 3 | msgpack is good for (de)serializing data over a network for multiple reasons: 4 | - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution) 5 | - msgpack is widely used and has good cross-language support 6 | - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed 7 | languages like Python and JavaScript 8 | - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster 9 | than pickle for serializing large arrays using the below strategy 10 | 11 | The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is 12 | that it falls back to pickle for object arrays. 
13 | """ 14 | 15 | import functools 16 | 17 | import msgpack 18 | import numpy as np 19 | 20 | 21 | def pack_array(obj): 22 | if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"): 23 | raise ValueError(f"Unsupported dtype: {obj.dtype}") 24 | 25 | if isinstance(obj, np.ndarray): 26 | return { 27 | b"__ndarray__": True, 28 | b"data": obj.tobytes(), 29 | b"dtype": obj.dtype.str, 30 | b"shape": obj.shape, 31 | } 32 | 33 | if isinstance(obj, np.generic): 34 | return { 35 | b"__npgeneric__": True, 36 | b"data": obj.item(), 37 | b"dtype": obj.dtype.str, 38 | } 39 | 40 | return obj 41 | 42 | 43 | def unpack_array(obj): 44 | if b"__ndarray__" in obj: 45 | return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"]) 46 | 47 | if b"__npgeneric__" in obj: 48 | return np.dtype(obj[b"dtype"]).type(obj[b"data"]) 49 | 50 | return obj 51 | 52 | 53 | Packer = functools.partial(msgpack.Packer, default=pack_array) 54 | packb = functools.partial(msgpack.packb, default=pack_array) 55 | 56 | Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array) 57 | unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array) 58 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/phi.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi.py 3 | 4 | Class definition for all LLMs derived from PhiForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import PhiForCausalLM 12 | from transformers.models.phi.modeling_phi import PhiDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder 16 | 17 | # Registry ==> Support Phi Models (from HF Transformers) 18 | # fmt: off 19 | PHI_MODELS = { 20 | # === Phi-2 === 21 | "phi-2-3b": { 22 | "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2" 23 | } 24 | } 25 | # fmt: on 26 | 27 | 28 | class PhiLLMBackbone(HFCausalLLMBackbone): 29 | def __init__( 30 | self, 31 | llm_backbone_id: str, 32 | llm_max_length: int = 2048, 33 | hf_token: Optional[str] = None, 34 | inference_mode: bool = False, 35 | use_flash_attention_2: bool = True, 36 | ) -> None: 37 | super().__init__( 38 | llm_backbone_id, 39 | llm_max_length=llm_max_length, 40 | hf_token=hf_token, 41 | inference_mode=inference_mode, 42 | use_flash_attention_2=use_flash_attention_2, 43 | **PHI_MODELS[llm_backbone_id], 44 | ) 45 | 46 | # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize) 47 | self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) 48 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 49 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 50 | 51 | @property 52 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 53 | if self.identifier.startswith("phi-2"): 54 | return PhiPromptBuilder 55 | 56 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 57 | 58 | @property 59 | def transformer_layer_cls(self) -> Type[nn.Module]: 60 | return PhiDecoderLayer 61 | 62 | @property 63 | def half_precision_dtype(self) -> torch.dtype: 64 | return torch.bfloat16 65 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/materialize.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and 5 | exports individual functions for clear control flow. 6 | """ 7 | 8 | from pathlib import Path 9 | from typing import Tuple, Type 10 | 11 | from torch.utils.data import Dataset 12 | from transformers import PreTrainedTokenizerBase 13 | 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.util.data_utils import PaddedCollatorForActionPrediction 17 | from prismatic.vla.action_tokenizer import ActionTokenizer 18 | from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset 19 | 20 | 21 | def get_vla_dataset_and_collator( 22 | data_root_dir: Path, 23 | data_mix: str, 24 | image_transform: ImageTransform, 25 | tokenizer: PreTrainedTokenizerBase, 26 | prompt_builder_fn: Type[PromptBuilder], 27 | default_image_resolution: Tuple[int, int, int], 28 | padding_side: str = "right", 29 | predict_stop_token: bool = True, 30 | shuffle_buffer_size: int = 100_000, 31 | train: bool = True, 32 | episodic: bool = False, 33 | image_aug: bool = False, 34 | ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: 35 | """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" 36 | action_tokenizer = ActionTokenizer(tokenizer) 37 | batch_transform = RLDSBatchTransform( 38 | action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token 39 | ) 40 | collator = PaddedCollatorForActionPrediction( 41 | tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side 42 | ) 43 | 44 | # Build RLDS Iterable Dataset 45 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset 46 | dataset = cls( 47 | data_root_dir, 48 | data_mix, 49 | batch_transform, 50 | resize_resolution=default_image_resolution[1:], 51 | shuffle_buffer_size=shuffle_buffer_size, 52 | train=train, 53 | image_aug=image_aug, 54 | ) 55 | 56 | return dataset, action_tokenizer, collator 57 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/websocket_client_policy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Dict, Optional, Tuple 4 | 5 | from typing_extensions import override 6 | import websockets.sync.client 7 | 8 | from openpi_client import base_policy as _base_policy 9 | from openpi_client import msgpack_numpy 10 | 11 | 12 | class WebsocketClientPolicy(_base_policy.BasePolicy): 13 | """Implements the Policy interface by communicating with a server over websocket. 14 | 15 | See WebsocketPolicyServer for a corresponding server implementation. 
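    Example (host/port are illustrative; a policy server must already be listening there):

        policy = WebsocketClientPolicy(host="localhost", port=8000)
        result = policy.infer(obs)  # obs: dict of numpy arrays, e.g. {"state": ..., "images": {...}}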
16 | """ 17 | 18 | def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None: 19 | self._uri = f"ws://{host}" 20 | if port is not None: 21 | self._uri += f":{port}" 22 | self._packer = msgpack_numpy.Packer() 23 | self._api_key = api_key 24 | self._ws, self._server_metadata = self._wait_for_server() 25 | 26 | def get_server_metadata(self) -> Dict: 27 | return self._server_metadata 28 | 29 | def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]: 30 | logging.info(f"Waiting for server at {self._uri}...") 31 | while True: 32 | try: 33 | headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None 34 | conn = websockets.sync.client.connect( 35 | self._uri, compression=None, max_size=None, additional_headers=headers 36 | ) 37 | metadata = msgpack_numpy.unpackb(conn.recv()) 38 | return conn, metadata 39 | except ConnectionRefusedError: 40 | logging.info("Still waiting for server...") 41 | time.sleep(5) 42 | 43 | @override 44 | def infer(self, obs: Dict) -> Dict: # noqa: UP006 45 | data = self._packer.pack(obs) 46 | self._ws.send(data) 47 | response = self._ws.recv() 48 | if isinstance(response, str): 49 | # we're expecting bytes; if the server sends a string, it's an error. 50 | raise RuntimeError(f"Error in inference server:\n{response}") 51 | return msgpack_numpy.unpackb(response) 52 | 53 | @override 54 | def reset(self) -> None: 55 | pass 56 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | # try: 39 | # if XFORMERS_ENABLED: 40 | # from xformers.ops import SwiGLU 41 | 42 | # XFORMERS_AVAILABLE = True 43 | # warnings.warn("xFormers is available (SwiGLU)") 44 | # else: 45 | # warnings.warn("xFormers is disabled (SwiGLU)") 46 | # raise ImportError 47 | # except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | # warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) 68 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | # try: 39 | # if XFORMERS_ENABLED: 40 | # from xformers.ops import SwiGLU 41 | 42 | # XFORMERS_AVAILABLE = True 43 | # warnings.warn("xFormers is available (SwiGLU)") 44 | # else: 45 | # warnings.warn("xFormers is disabled (SwiGLU)") 46 | # raise ImportError 47 | # except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | # warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) 68 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral_instruct_prompter.py 3 | 4 | Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials. 5 | 6 | Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | 14 | class MistralInstructPromptBuilder(PromptBuilder): 15 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 16 | super().__init__(model_family, system_prompt) 17 | 18 | # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)` 19 | # =>> Mistral Instruct *does not* use a System Prompt 20 | self.bos, self.eos = "<s>", "</s>" 21 | 22 | # Get role-specific "wrap" functions 23 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 24 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 25 | 26 | # === `self.prompt` gets built up over multiple turns === 27 | self.prompt, self.turn_count = "", 0 28 | 29 | def add_turn(self, role: str, message: str) -> str: 30 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 31 | message = message.replace("</s>", "").strip() 32 | 33 | if (self.turn_count % 2) == 0: 34 | human_message = self.wrap_human(message) 35 | 
wrapped_message = human_message 36 | else: 37 | gpt_message = self.wrap_gpt(message) 38 | wrapped_message = gpt_message 39 | 40 | # Update Prompt 41 | self.prompt += wrapped_message 42 | 43 | # Bump Turn Counter 44 | self.turn_count += 1 45 | 46 | # Return "wrapped_message" (effective string added to context) 47 | return wrapped_message 48 | 49 | def get_potential_prompt(self, message: str) -> None: 50 | # Assumes that it's always the user's (human's) turn! 51 | prompt_copy = str(self.prompt) 52 | 53 | human_message = self.wrap_human(message) 54 | prompt_copy += human_message 55 | 56 | return prompt_copy.removeprefix(self.bos).rstrip() 57 | 58 | def get_prompt(self) -> str: 59 | # Remove prefix because it gets auto-inserted by tokenizer! 60 | return self.prompt.removeprefix(self.bos).rstrip() 61 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/utils/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | 10 | def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: 11 | """ 12 | If mask has more than max_trues True values, 13 | randomly keep only max_trues of them and set the rest to False. 14 | """ 15 | # 1D positions of all True entries 16 | true_indices = np.flatnonzero(mask) # shape = (N_true,) 17 | 18 | # if already within budget, return as-is 19 | if true_indices.size <= max_trues: 20 | return mask 21 | 22 | # randomly pick which True positions to keep 23 | sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) 24 | 25 | # build new flat mask: True only at sampled positions 26 | limited_flat_mask = np.zeros(mask.size, dtype=bool) 27 | limited_flat_mask[sampled_indices] = True 28 | 29 | # restore original shape 30 | return limited_flat_mask.reshape(mask.shape) 31 | 32 | 33 | def create_pixel_coordinate_grid(num_frames, height, width): 34 | """ 35 | Creates a grid of pixel coordinates and frame indices for all frames. 
36 | Returns: 37 | tuple: A tuple containing: 38 | - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) 39 | with x, y coordinates and frame indices 40 | - y_coords (numpy.ndarray): Array of y coordinates for all frames 41 | - x_coords (numpy.ndarray): Array of x coordinates for all frames 42 | - f_coords (numpy.ndarray): Array of frame indices for all frames 43 | """ 44 | # Create coordinate grids for a single frame 45 | y_grid, x_grid = np.indices((height, width), dtype=np.float32) 46 | x_grid = x_grid[np.newaxis, :, :] 47 | y_grid = y_grid[np.newaxis, :, :] 48 | 49 | # Broadcast to all frames 50 | x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) 51 | y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) 52 | 53 | # Create frame indices and broadcast 54 | f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] 55 | f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) 56 | 57 | # Stack coordinates and frame indices 58 | points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) 59 | 60 | return points_xyf 61 | -------------------------------------------------------------------------------- /openvla-SF/vggt/utils/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | 10 | def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: 11 | """ 12 | If mask has more than max_trues True values, 13 | randomly keep only max_trues of them and set the rest to False. 14 | """ 15 | # 1D positions of all True entries 16 | true_indices = np.flatnonzero(mask) # shape = (N_true,) 17 | 18 | # if already within budget, return as-is 19 | if true_indices.size <= max_trues: 20 | return mask 21 | 22 | # randomly pick which True positions to keep 23 | sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) 24 | 25 | # build new flat mask: True only at sampled positions 26 | limited_flat_mask = np.zeros(mask.size, dtype=bool) 27 | limited_flat_mask[sampled_indices] = True 28 | 29 | # restore original shape 30 | return limited_flat_mask.reshape(mask.shape) 31 | 32 | 33 | def create_pixel_coordinate_grid(num_frames, height, width): 34 | """ 35 | Creates a grid of pixel coordinates and frame indices for all frames. 
36 | Returns: 37 | tuple: A tuple containing: 38 | - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) 39 | with x, y coordinates and frame indices 40 | - y_coords (numpy.ndarray): Array of y coordinates for all frames 41 | - x_coords (numpy.ndarray): Array of x coordinates for all frames 42 | - f_coords (numpy.ndarray): Array of frame indices for all frames 43 | """ 44 | # Create coordinate grids for a single frame 45 | y_grid, x_grid = np.indices((height, width), dtype=np.float32) 46 | x_grid = x_grid[np.newaxis, :, :] 47 | y_grid = y_grid[np.newaxis, :, :] 48 | 49 | # Broadcast to all frames 50 | x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) 51 | y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) 52 | 53 | # Create frame indices and broadcast 54 | f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] 55 | f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) 56 | 57 | # Stack coordinates and frame indices 58 | points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) 59 | 60 | return points_xyf 61 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/training/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, 5 | and strategy configurations. 6 | """ 7 | 8 | from typing import Callable, Optional 9 | 10 | import torch 11 | 12 | from prismatic.models.vlms import PrismaticVLM 13 | from prismatic.training.strategies import FSDPStrategy, TrainingStrategy 14 | 15 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
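# Illustrative sketch (added comment, not in the original source) of how a registry entry below is consumed;
# this mirrors `get_train_strategy` further down, with the `vlm` object and remaining arguments assumed to exist:
#   cfg = TRAIN_STRATEGIES["fsdp-full-shard"]
#   strategy = cfg["cls"](vlm=vlm, device_id=0, ..., **cfg["kwargs"])  # i.e. FSDPStrategy(sharding_strategy="full-shard", ...)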
16 | TRAIN_STRATEGIES = { 17 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, 18 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, 19 | } 20 | 21 | 22 | def get_train_strategy( 23 | train_strategy: str, 24 | vlm: PrismaticVLM, 25 | device_id: int, 26 | stage: str, 27 | epochs: int, 28 | max_steps: Optional[int], 29 | global_batch_size: int, 30 | per_device_batch_size: int, 31 | learning_rate: float, 32 | weight_decay: float, 33 | max_grad_norm: float, 34 | lr_scheduler_type: str, 35 | warmup_ratio: float, 36 | enable_gradient_checkpointing: bool = True, 37 | enable_mixed_precision_training: bool = True, 38 | reduce_in_full_precision: bool = False, 39 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 40 | worker_init_fn: Optional[Callable[[int], None]] = None, 41 | ) -> TrainingStrategy: 42 | if train_strategy in TRAIN_STRATEGIES: 43 | strategy_cfg = TRAIN_STRATEGIES[train_strategy] 44 | strategy = strategy_cfg["cls"]( 45 | vlm=vlm, 46 | device_id=device_id, 47 | stage=stage, 48 | epochs=epochs, 49 | max_steps=max_steps, 50 | global_batch_size=global_batch_size, 51 | per_device_batch_size=per_device_batch_size, 52 | learning_rate=learning_rate, 53 | weight_decay=weight_decay, 54 | max_grad_norm=max_grad_norm, 55 | lr_scheduler_type=lr_scheduler_type, 56 | warmup_ratio=warmup_ratio, 57 | enable_gradient_checkpointing=enable_gradient_checkpointing, 58 | enable_mixed_precision_training=enable_mixed_precision_training, 59 | reduce_in_full_precision=reduce_in_full_precision, 60 | mixed_precision_dtype=mixed_precision_dtype, 61 | worker_init_fn=worker_init_fn, 62 | **strategy_cfg["kwargs"], 63 | ) 64 | return strategy 65 | else: 66 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") 67 | -------------------------------------------------------------------------------- /openpi-SF/examples/libero/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the LIBERO benchmark. 2 | 3 | # Build the container: 4 | # docker build . -t libero -f examples/libero/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash 8 | 9 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 10 | COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ 11 | 12 | RUN apt-get update && \ 13 | apt-get install -y \ 14 | make \ 15 | g++ \ 16 | clang \ 17 | libosmesa6-dev \ 18 | libgl1-mesa-glx \ 19 | libglew-dev \ 20 | libglfw3-dev \ 21 | libgles2-mesa-dev \ 22 | libglib2.0-0 \ 23 | libsm6 \ 24 | libxrender1 \ 25 | libxext6 26 | 27 | WORKDIR /app 28 | 29 | # Copy from the cache instead of linking since it's a mounted volume 30 | ENV UV_LINK_MODE=copy 31 | 32 | # Write the virtual environment outside of the project directory so it doesn't 33 | # leak out of the container when we mount the application code. 34 | ENV UV_PROJECT_ENVIRONMENT=/.venv 35 | 36 | # Copy the requirements files so we can install dependencies. 37 | # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. 38 | # This strategy is best for development-style usage. 
39 | COPY ./examples/libero/requirements.txt /tmp/requirements.txt 40 | COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt 41 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 42 | 43 | # Install python dependencies. 44 | RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT 45 | RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match 46 | ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero 47 | 48 | # Create a default config file to avoid an input prompt from LIBERO's init script. 49 | # https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py 50 | ENV LIBERO_CONFIG_PATH=/tmp/libero 51 | RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml 52 | benchmark_root: /app/third_party/libero/libero/libero 53 | bddl_files: /app/third_party/libero/libero/libero/bddl_files 54 | init_states: /app/third_party/libero/libero/libero/init_files 55 | datasets: /app/third_party/libero/libero/datasets 56 | assets: /app/third_party/libero/libero/libero/assets 57 | EOF 58 | 59 | CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"] 60 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/image_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def convert_to_uint8(img: np.ndarray) -> np.ndarray: 6 | """Converts an image to uint8 if it is a float image. 7 | 8 | This is important for reducing the size of the image when sending it over the network. 9 | """ 10 | if np.issubdtype(img.dtype, np.floating): 11 | img = (255 * img).astype(np.uint8) 12 | return img 13 | 14 | 15 | def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray: 16 | """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height. 17 | 18 | Args: 19 | images: A batch of images in [..., height, width, channel] format. 20 | height: The target height of the image. 21 | width: The target width of the image. 22 | method: The interpolation method to use. Default is bilinear. 23 | 24 | Returns: 25 | The resized images in [..., height, width, channel]. 26 | """ 27 | # If the images are already the correct size, return them as is. 28 | if images.shape[-3:-1] == (height, width): 29 | return images 30 | 31 | original_shape = images.shape 32 | 33 | images = images.reshape(-1, *original_shape[-3:]) 34 | resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images]) 35 | return resized.reshape(*original_shape[:-3], *resized.shape[-3:]) 36 | 37 | 38 | def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image: 39 | """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and 40 | width without distortion by padding with zeros. 41 | 42 | Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c]. 43 | """ 44 | cur_width, cur_height = image.size 45 | if cur_width == width and cur_height == height: 46 | return image # No need to resize if the image is already the correct size. 
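# Worked example of the arithmetic below (added comment, not part of the original source): fitting a
# 640x480 image into 224x224 gives ratio = max(640/224, 480/224) ~= 2.857, so the image is resized to
# 224x168 and pasted at offsets (pad_width=0, pad_height=28), i.e. centered on a zero-filled 224x224
# canvas with equal bars above and below.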
47 | 48 | ratio = max(cur_width / width, cur_height / height) 49 | resized_height = int(cur_height / ratio) 50 | resized_width = int(cur_width / ratio) 51 | resized_image = image.resize((resized_width, resized_height), resample=method) 52 | 53 | zero_image = Image.new(resized_image.mode, (width, height), 0) 54 | pad_height = max(0, int((height - resized_height) / 2)) 55 | pad_width = max(0, int((width - resized_width) / 2)) 56 | zero_image.paste(resized_image, (pad_width, pad_height)) 57 | assert zero_image.size == (width, height) 58 | return zero_image 59 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/mistral.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral.py 3 | 4 | Class definition for all LLMs derived from MistralForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import MistralForCausalLM 12 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder 16 | 17 | # Registry =>> Support Mistral Models (from HF Transformers) 18 | # fmt: off 19 | MISTRAL_MODELS = { 20 | # === Base Mistral v0.1 === 21 | "mistral-v0.1-7b-pure": { 22 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" 23 | }, 24 | 25 | # === Mistral Instruct v0.1 === 26 | "mistral-v0.1-7b-instruct": { 27 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" 28 | } 29 | } 30 | # fmt: on 31 | 32 | 33 | class MistralLLMBackbone(HFCausalLLMBackbone): 34 | def __init__( 35 | self, 36 | llm_backbone_id: str, 37 | llm_max_length: int = 2048, 38 | hf_token: Optional[str] = None, 39 | inference_mode: bool = False, 40 | use_flash_attention_2: bool = True, 41 | ) -> None: 42 | super().__init__( 43 | llm_backbone_id, 44 | llm_max_length=llm_max_length, 45 | hf_token=hf_token, 46 | inference_mode=inference_mode, 47 | use_flash_attention_2=use_flash_attention_2, 48 | **MISTRAL_MODELS[llm_backbone_id], 49 | ) 50 | 51 | # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra <PAD> token (and resize) 52 | self.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 53 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 54 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 55 | 56 | @property 57 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 58 | if self.identifier.endswith("-pure"): 59 | return PurePromptBuilder 60 | 61 | elif self.identifier.endswith("-instruct"): 62 | return MistralInstructPromptBuilder 63 | 64 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 65 | 66 | @property 67 | def transformer_layer_cls(self) -> Type[nn.Module]: 68 | return MistralDecoderLayer 69 | 70 | @property 71 | def half_precision_dtype(self) -> torch.dtype: 72 | return torch.bfloat16 73 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/prompting/base_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_prompter.py 3 | 4 | Abstract class definition of a multi-turn prompt builder for ensuring consistent
formatting for chat-based LLMs. 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Optional 9 | 10 | 11 | class PromptBuilder(ABC): 12 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 13 | self.model_family = model_family 14 | 15 | # Only some models define a system prompt => let subclasses handle this logic! 16 | self.system_prompt = system_prompt 17 | 18 | @abstractmethod 19 | def add_turn(self, role: str, message: str) -> str: ... 20 | 21 | @abstractmethod 22 | def get_potential_prompt(self, user_msg: str) -> None: ... 23 | 24 | @abstractmethod 25 | def get_prompt(self) -> str: ... 26 | 27 | 28 | class PurePromptBuilder(PromptBuilder): 29 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 30 | super().__init__(model_family, system_prompt) 31 | 32 | self.bos, self.eos = "<s>", "</s>" 33 | 34 | # Get role-specific "wrap" functions 35 | self.wrap_human = lambda msg: f"In: {msg}\nOut: " 36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 37 | 38 | # === `self.prompt` gets built up over multiple turns === 39 | self.prompt, self.turn_count = "", 0 40 | 41 | def add_turn(self, role: str, message: str) -> str: 42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 43 | message = message.replace("<image>", "").strip() 44 | 45 | if (self.turn_count % 2) == 0: 46 | human_message = self.wrap_human(message) 47 | wrapped_message = human_message 48 | else: 49 | gpt_message = self.wrap_gpt(message) 50 | wrapped_message = gpt_message 51 | 52 | # Update Prompt 53 | self.prompt += wrapped_message 54 | 55 | # Bump Turn Counter 56 | self.turn_count += 1 57 | 58 | # Return "wrapped_message" (effective string added to context) 59 | return wrapped_message 60 | 61 | def get_potential_prompt(self, message: str) -> None: 62 | # Assumes that it's always the user's (human's) turn! 63 | prompt_copy = str(self.prompt) 64 | 65 | human_message = self.wrap_human(message) 66 | prompt_copy += human_message 67 | 68 | return prompt_copy.removeprefix(self.bos).rstrip() 69 | 70 | def get_prompt(self) -> str: 71 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 72 | return self.prompt.removeprefix(self.bos).rstrip() 73 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/prompting/phi_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi_prompter.py 3 | 4 | Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft. 5 | Also handles Phi special case BOS token additions. 6 | 7 | Reference: https://huggingface.co/microsoft/phi-2#qa-format 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | 15 | class PhiPromptBuilder(PromptBuilder): 16 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 17 | super().__init__(model_family, system_prompt) 18 | 19 | # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)` 20 | # =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
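# For illustration (added comment, not in the original source): with the wrap functions defined below, the very
# first human turn is emitted as f"{self.bos}Input: {message}\nOutput: ", and every model turn ends with a
# newline followed by self.eos.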
21 | self.bos, self.eos = "<|endoftext|>", "<|endoftext|>" 22 | 23 | # Get role-specific "wrap" functions 24 | # =>> Note that placement of <bos> / <eos> were based on experiments generating from Phi-2 in Input/Output mode 25 | self.wrap_human = lambda msg: f"Input: {msg}\nOutput: " 26 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}" 27 | 28 | # === `self.prompt` gets built up over multiple turns === 29 | self.prompt, self.turn_count = "", 0 30 | 31 | def add_turn(self, role: str, message: str) -> str: 32 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 33 | message = message.replace("<image>", "").strip() 34 | 35 | # Special Handling for "first" input --> prepend a <bos> token (expected by Prismatic) 36 | if self.turn_count == 0: 37 | bos_human_message = f"{self.bos}{self.wrap_human(message)}" 38 | wrapped_message = bos_human_message 39 | elif (self.turn_count % 2) == 0: 40 | human_message = self.wrap_human(message) 41 | wrapped_message = human_message 42 | else: 43 | gpt_message = self.wrap_gpt(message) 44 | wrapped_message = gpt_message 45 | 46 | # Update Prompt 47 | self.prompt += wrapped_message 48 | 49 | # Bump Turn Counter 50 | self.turn_count += 1 51 | 52 | # Return "wrapped_message" (effective string added to context) 53 | return wrapped_message 54 | 55 | def get_potential_prompt(self, message: str) -> None: 56 | # Assumes that it's always the user's (human's) turn! 57 | prompt_copy = str(self.prompt) 58 | 59 | human_message = self.wrap_human(message) 60 | prompt_copy += human_message 61 | 62 | return prompt_copy.rstrip() 63 | 64 | def get_prompt(self) -> str: 65 | return self.prompt.rstrip() 66 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/preprocessing/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for 5 | clear control flow.
6 | """ 7 | 8 | from typing import Tuple, Type 9 | 10 | from torch.utils.data import Dataset 11 | from transformers import PreTrainedTokenizerBase 12 | 13 | from prismatic.conf import DatasetConfig 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset 17 | from prismatic.util.data_utils import PaddedCollatorForLanguageModeling 18 | 19 | # Dataset Initializers =>> Maps Stage --> cls() 20 | DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} 21 | 22 | 23 | def get_dataset_and_collator( 24 | stage: str, 25 | dataset_cfg: DatasetConfig, 26 | image_transform: ImageTransform, 27 | tokenizer: PreTrainedTokenizerBase, 28 | prompt_builder_fn: Type[PromptBuilder], 29 | default_image_resolution: Tuple[int, int, int], 30 | padding_side: str = "right", 31 | ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: 32 | dataset_cls = DATASET_INITIALIZER[stage] 33 | dataset_root_dir = dataset_cfg.dataset_root_dir 34 | collator = PaddedCollatorForLanguageModeling( 35 | tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side 36 | ) 37 | 38 | # Switch on `stage` 39 | if stage == "align": 40 | annotation_json, image_dir = dataset_cfg.align_stage_components 41 | dataset = dataset_cls( 42 | dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer 43 | ) 44 | return dataset, collator 45 | 46 | elif stage == "finetune": 47 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 48 | dataset = dataset_cls( 49 | dataset_root_dir / annotation_json, 50 | dataset_root_dir / image_dir, 51 | image_transform, 52 | tokenizer, 53 | prompt_builder_fn=prompt_builder_fn, 54 | ) 55 | return dataset, collator 56 | 57 | elif stage == "full-finetune": 58 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 59 | dataset = dataset_cls( 60 | dataset_root_dir / annotation_json, 61 | dataset_root_dir / image_dir, 62 | image_transform, 63 | tokenizer, 64 | prompt_builder_fn=prompt_builder_fn, 65 | ) 66 | return dataset, collator 67 | 68 | else: 69 | raise ValueError(f"Stage `{stage}` is not supported!") 70 | -------------------------------------------------------------------------------- /openpi-SF/examples/libero/README.md: -------------------------------------------------------------------------------- 1 | # LIBERO Benchmark 2 | 3 | This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO 4 | 5 | Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command. 6 | 7 | This example requires git submodules to be initialized. 
Don't forget to run: 8 | 9 | ```bash 10 | git submodule update --init --recursive 11 | ``` 12 | 13 | ## With Docker (recommended) 14 | 15 | ```bash 16 | # Grant access to the X11 server: 17 | sudo xhost +local:docker 18 | 19 | # To run with the default checkpoint and task suite: 20 | SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build 21 | 22 | # To run with glx for Mujoco instead (use this if you have egl errors): 23 | MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build 24 | ``` 25 | 26 | You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`). 27 | For example: 28 | 29 | ```bash 30 | # To load a custom checkpoint (located in the top-level openpi/ directory): 31 | export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint" 32 | 33 | # To run the libero_10 task suite: 34 | export CLIENT_ARGS="--args.task-suite-name libero_10" 35 | ``` 36 | 37 | ## Without Docker (not recommended) 38 | 39 | Terminal window 1: 40 | 41 | ```bash 42 | # Create virtual environment 43 | uv venv --python 3.8 examples/libero/.venv 44 | source examples/libero/.venv/bin/activate 45 | uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match 46 | uv pip install -e packages/openpi-client 47 | uv pip install -e third_party/libero 48 | export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero 49 | 50 | # Run the simulation 51 | python examples/libero/main.py 52 | 53 | # To run with glx for Mujoco instead (use this if you have egl errors): 54 | MUJOCO_GL=glx python examples/libero/main.py 55 | ``` 56 | 57 | Terminal window 2: 58 | 59 | ```bash 60 | # Run the server 61 | uv run scripts/serve_policy.py --env LIBERO 62 | ``` 63 | 64 | ## Results 65 | 66 | If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This 67 | checkpoint was trained in openpi with the `pi05_libero` config. 
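For example, a reproduction run could look roughly like this (whether `--policy.dir` accepts a `gs://` path directly is an assumption; if it does not, download the checkpoint locally first and point `--policy.dir` at the local copy):

```bash
SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir gs://openpi-assets/checkpoints/pi05_libero/" \
CLIENT_ARGS="--args.task-suite-name libero_10" \
docker compose -f examples/libero/compose.yml up --build
```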
68 | 69 | | Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average | 70 | |-------|---------------|---------------|-------------|-----------|---------| 71 | | π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 72 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/training/data_loader_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import jax 4 | 5 | from openpi.models import pi0_config 6 | from openpi.training import config as _config 7 | from openpi.training import data_loader as _data_loader 8 | 9 | 10 | def test_torch_data_loader(): 11 | config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) 12 | dataset = _data_loader.FakeDataset(config, 16) 13 | 14 | loader = _data_loader.TorchDataLoader( 15 | dataset, 16 | local_batch_size=4, 17 | num_batches=2, 18 | ) 19 | batches = list(loader) 20 | 21 | assert len(batches) == 2 22 | for batch in batches: 23 | assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) 24 | 25 | 26 | def test_torch_data_loader_infinite(): 27 | config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) 28 | dataset = _data_loader.FakeDataset(config, 4) 29 | 30 | loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4) 31 | data_iter = iter(loader) 32 | 33 | for _ in range(10): 34 | _ = next(data_iter) 35 | 36 | 37 | def test_torch_data_loader_parallel(): 38 | config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48) 39 | dataset = _data_loader.FakeDataset(config, 10) 40 | 41 | loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2) 42 | batches = list(loader) 43 | 44 | assert len(batches) == 2 45 | 46 | for batch in batches: 47 | assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch)) 48 | 49 | 50 | def test_with_fake_dataset(): 51 | config = _config.get_config("debug") 52 | 53 | loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2) 54 | batches = list(loader) 55 | 56 | assert len(batches) == 2 57 | 58 | for batch in batches: 59 | assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch)) 60 | 61 | for _, actions in batches: 62 | assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) 63 | 64 | 65 | def test_with_real_dataset(): 66 | config = _config.get_config("pi0_aloha_sim") 67 | config = dataclasses.replace(config, batch_size=4) 68 | 69 | loader = _data_loader.create_data_loader( 70 | config, 71 | # Skip since we may not have the data available. 72 | skip_norm_stats=True, 73 | num_batches=2, 74 | shuffle=True, 75 | ) 76 | # Make sure that we can get the data config. 
77 | assert loader.data_config().repo_id == config.data.repo_id 78 | 79 | batches = list(loader) 80 | 81 | assert len(batches) == 2 82 | 83 | for _, actions in batches: 84 | assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim) 85 | -------------------------------------------------------------------------------- /openvla-SF/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ruff 132 | .ruff_cache/ 133 | 134 | # Auth Tokens / Hidden Files 135 | .hf_token 136 | .wandb_api_key 137 | .*_token 138 | .*api_key 139 | 140 | # IDE Caches 141 | .idea/ 142 | .vscode/ 143 | 144 | # Mac OS 145 | .DS_Store 146 | 147 | # Caches, Datasets, and Checkpoints 148 | cache/ 149 | data/ 150 | ckpts/ 151 | 152 | # Rollout videos and wandb logs 153 | rollouts/ 154 | wandb/ -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models_pytorch/projectors.py: -------------------------------------------------------------------------------- 1 | """Implementation of additional projectors for additional inputs to the VLA models.""" 2 | import torch 3 | import torch.nn as nn 4 | import openpi.models.gemma as _gemma 5 | 6 | class AlignProjector(nn.Module): 7 | """ 8 | calculate the alignment between LLM and VGGT embeddings. 9 | """ 10 | def __init__( 11 | self, 12 | llm_dim: int, 13 | vggt_dim: int, 14 | use_vlm_norm: bool = False, 15 | ) -> None: 16 | super().__init__() 17 | 18 | self.llm_dim = llm_dim 19 | self.vggt_dim = vggt_dim 20 | 21 | self.fc1 = nn.Linear(self.llm_dim, 2 * self.vggt_dim, bias=True) 22 | self.fc2 = nn.Linear(2 * self.vggt_dim, 2 * self.vggt_dim, bias=True) 23 | self.act_fn1 = nn.GELU() 24 | 25 | self.vlm_norm = nn.LayerNorm(llm_dim) if use_vlm_norm else None 26 | 27 | self.initialize_weights() 28 | 29 | def initialize_weights(self): 30 | # Initialize transformer layers: 31 | def _basic_init(module): 32 | if isinstance(module, nn.Linear): 33 | torch.nn.init.xavier_uniform_(module.weight) 34 | if module.bias is not None: 35 | nn.init.constant_(module.bias, 0) 36 | self.apply(_basic_init) 37 | 38 | def align_dimension(self, LLM_embedding: torch.Tensor = None) -> torch.Tensor: 39 | if self.vlm_norm is not None: 40 | LLM_embedding = self.vlm_norm(LLM_embedding) 41 | projected_features = self.fc1(LLM_embedding) 42 | projected_features = self.act_fn1(projected_features) 43 | projected_features = self.fc2(projected_features) 44 | return projected_features 45 | 46 | def compute_align_loss_cosine(self, vision_hidden, vggt_hidden, align_mask): 47 | # vision_hidden has a shape of (bs, N, D) 48 | def mean_flat(x): 49 | return torch.mean(x, dim=list(range(1, len(x.size())))) 50 | align_loss = 0 51 | bsz = vision_hidden.shape[0] 52 | for _vision, _vggt, _mask in zip(vision_hidden, vggt_hidden, align_mask): 53 | _vision = torch.nn.functional.normalize(_vision, dim=-1) 54 | _vggt = torch.nn.functional.normalize(_vggt, dim=-1) 55 | # align_loss += 1 - torch.mean(vision_hidden * vggt_hidden).sum(dim=-1).mean() 56 | align_loss += 1 - mean_flat((_vision * _vggt)[_mask].sum(dim=-1)) # Cosine similarity loss 57 | align_loss /= bsz # Average over batch size 58 | return align_loss 59 | 60 | def forward(self, LLM_emb, target_emb, align_mask): 61 | # project vla dimension and calculate align loss 62 | LLM_emb = self.align_dimension(LLM_emb) 63 | align_loss = 
self.compute_align_loss_cosine(LLM_emb, target_emb, align_mask).mean() # mean for sequence length 64 | return align_loss 65 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/util/pooling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from typing import List, Dict, Tuple, Union 5 | from einops import rearrange 6 | 7 | from vggt.heads.utils import create_uv_grid, position_grid_to_embed 8 | 9 | 10 | def _interpolate( 11 | x: torch.Tensor, 12 | size: Tuple[int, int] = None, 13 | scale_factor: float = None, 14 | mode: str = "bilinear", 15 | align_corners: bool = True, 16 | ) -> torch.Tensor: 17 | """ 18 | Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 19 | """ 20 | if size is None: 21 | size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) 22 | 23 | INT_MAX = 1610612736 24 | 25 | input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] 26 | 27 | if input_elements > INT_MAX: 28 | chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) 29 | interpolated_chunks = [ 30 | nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks 31 | ] 32 | x = torch.cat(interpolated_chunks, dim=0) 33 | return x.contiguous() 34 | else: 35 | return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) 36 | 37 | def _apply_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: 38 | """ 39 | Apply positional embedding to tensor x. 40 | """ 41 | patch_w = x.shape[-1] 42 | patch_h = x.shape[-2] 43 | pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) 44 | pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) 45 | pos_embed = pos_embed * ratio 46 | pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) 47 | return x + pos_embed 48 | 49 | def interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe): 50 | (patch_h, patch_w) = patch_hw 51 | (img_h, img_w) = img_hw 52 | bs, N, S, D = hidden.shape 53 | re_sample_ratio = 1 / np.sqrt(N * S / reference.shape[1]) 54 | 55 | _hidden = hidden.permute(0, 1, 3, 2) 56 | _hidden = _hidden.reshape(bs*N, D, patch_h, patch_w) 57 | if use_vggt_pe: 58 | _hidden = _apply_pos_embed(_hidden, img_w, img_h) 59 | hidden_pooling = _interpolate( 60 | _hidden, scale_factor=re_sample_ratio, mode=pooling_func, align_corners=True 61 | ) 62 | hidden_pooling = hidden_pooling.reshape(bs, N, D, -1).permute(0, 1, 3, 2).reshape(bs, -1, D) 63 | return hidden_pooling 64 | 65 | 66 | def custom_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe): 67 | if pooling_func in ['bilinear']: 68 | return interpolate_pooling(hidden, patch_hw, img_hw, reference, pooling_func, use_vggt_pe) 69 | else: 70 | raise NotImplementedError(f"Pooling function {pooling_func} is not implemented.") 71 | 72 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for the Aloha real environment. 2 | 3 | # Build the container: 4 | # docker build . 
-t aloha_real -f examples/aloha_real/Dockerfile 5 | 6 | # Run the container: 7 | # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash 8 | 9 | FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc 10 | SHELL ["/bin/bash", "-c"] 11 | 12 | ENV DEBIAN_FRONTEND=noninteractive 13 | RUN apt-get update && \ 14 | apt-get install -y --no-install-recommends \ 15 | cmake \ 16 | curl \ 17 | libffi-dev \ 18 | python3-rosdep \ 19 | python3-rosinstall \ 20 | python3-rosinstall-generator \ 21 | whiptail \ 22 | git \ 23 | wget \ 24 | openssh-client \ 25 | ros-noetic-cv-bridge \ 26 | ros-noetic-usb-cam \ 27 | ros-noetic-realsense2-camera \ 28 | keyboard-configuration 29 | 30 | WORKDIR /root 31 | RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh 32 | RUN chmod +x xsarm_amd64_install.sh 33 | RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n 34 | 35 | COPY ./third_party/aloha /root/interbotix_ws/src/aloha 36 | RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make 37 | 38 | # Install python 3.10 because this ROS image comes with 3.8 39 | RUN mkdir /python && \ 40 | cd /python && \ 41 | wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \ 42 | tar -zxvf Python-3.10.14.tgz && \ 43 | cd Python-3.10.14 && \ 44 | ls -lhR && \ 45 | ./configure --enable-optimizations && \ 46 | make install && \ 47 | echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \ 48 | echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \ 49 | cd ~ && rm -rf /python && \ 50 | rm -rf /var/lib/apt/lists/* 51 | 52 | COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv 53 | ENV UV_HTTP_TIMEOUT=120 54 | ENV UV_LINK_MODE=copy 55 | COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt 56 | COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml 57 | RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml 58 | 59 | ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha 60 | WORKDIR /app 61 | 62 | # Create an entrypoint script to run the setup commands, followed by the command passed in. 63 | RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh 64 | #!/bin/bash 65 | source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@" 66 | EOF 67 | RUN chmod +x /usr/local/bin/entrypoint.sh 68 | 69 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 70 | CMD ["python3", "/app/examples/aloha_real/main.py"] 71 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/vla/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Important constants for VLA training and evaluation. 3 | 4 | Attempts to automatically identify the correct constants to set based on the Python command used to launch 5 | training or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants. 6 | """ 7 | import sys 8 | from enum import Enum 9 | 10 | # Llama 2 token constants 11 | IGNORE_INDEX = -100 12 | ACTION_TOKEN_BEGIN_IDX = 31743 13 | STOP_INDEX = 2 # '' 14 | 15 | 16 | # Defines supported normalization schemes for action and proprioceptive state. 
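# Minimal sketch (added comment, not part of the original module) of what BOUNDS_Q99 below computes, assuming
# q01 / q99 are the per-dimension 1st / 99th percentile statistics of the dataset:
#   normalized   = 2 * (action - q01) / (q99 - q01) - 1     # maps [q01, q99] -> [-1, 1]
#   unnormalized = 0.5 * (normalized + 1) * (q99 - q01) + q01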
17 | class NormalizationType(str, Enum): 18 | # fmt: off 19 | NORMAL = "normal" # Normalize to Mean = 0, Stdev = 1 20 | BOUNDS = "bounds" # Normalize to Interval = [-1, 1] 21 | BOUNDS_Q99 = "bounds_q99" # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1] 22 | # fmt: on 23 | 24 | 25 | # Define constants for each robot platform 26 | LIBERO_CONSTANTS = { 27 | "NUM_ACTIONS_CHUNK": 8, 28 | "ACTION_DIM": 7, 29 | "PROPRIO_DIM": 8, 30 | "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, 31 | } 32 | 33 | ALOHA_CONSTANTS = { 34 | "NUM_ACTIONS_CHUNK": 30, 35 | "ACTION_DIM": 14, 36 | "PROPRIO_DIM": 14, 37 | "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS, 38 | } 39 | 40 | BRIDGE_CONSTANTS = { 41 | "NUM_ACTIONS_CHUNK": 5, 42 | "ACTION_DIM": 7, 43 | "PROPRIO_DIM": 7, 44 | "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, 45 | } 46 | 47 | 48 | # Function to detect robot platform from command line arguments 49 | def detect_robot_platform(): 50 | cmd_args = " ".join(sys.argv).lower() 51 | 52 | if "libero" in cmd_args: 53 | return "LIBERO" 54 | elif "aloha" in cmd_args: 55 | return "ALOHA" 56 | elif "bridge" in cmd_args: 57 | return "BRIDGE" 58 | else: 59 | # Default to LIBERO if unclear 60 | return "LIBERO" 61 | 62 | 63 | # Determine which robot platform to use 64 | ROBOT_PLATFORM = detect_robot_platform() 65 | 66 | # Set the appropriate constants based on the detected platform 67 | if ROBOT_PLATFORM == "LIBERO": 68 | constants = LIBERO_CONSTANTS 69 | elif ROBOT_PLATFORM == "ALOHA": 70 | constants = ALOHA_CONSTANTS 71 | elif ROBOT_PLATFORM == "BRIDGE": 72 | constants = BRIDGE_CONSTANTS 73 | 74 | # Assign constants to global variables 75 | NUM_ACTIONS_CHUNK = constants["NUM_ACTIONS_CHUNK"] 76 | ACTION_DIM = constants["ACTION_DIM"] 77 | PROPRIO_DIM = constants["PROPRIO_DIM"] 78 | ACTION_PROPRIO_NORMALIZATION_TYPE = constants["ACTION_PROPRIO_NORMALIZATION_TYPE"] 79 | 80 | # Print which robot platform constants are being used (for debugging) 81 | print(f"Using {ROBOT_PLATFORM} constants:") 82 | print(f" NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}") 83 | print(f" ACTION_DIM = {ACTION_DIM}") 84 | print(f" PROPRIO_DIM = {PROPRIO_DIM}") 85 | print(f" ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}") 86 | print("If needed, manually set the correct constants in `prismatic/vla/constants.py`!") 87 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/shared/nnx_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import dataclasses 3 | import functools 4 | import inspect 5 | import re 6 | from typing import Any, ParamSpec, TypeVar 7 | 8 | import flax.nnx as nnx 9 | import jax 10 | 11 | P = ParamSpec("P") 12 | R = TypeVar("R") 13 | 14 | 15 | def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]: 16 | """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process. 17 | 18 | Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much 19 | more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module 20 | mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must 21 | traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details. 
22 | 23 | `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by 24 | `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was 25 | when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded 26 | after the method call completes. 27 | """ 28 | if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)): 29 | raise ValueError("module_jit must only be used on bound methods of nnx.Modules.") 30 | 31 | graphdef, state = nnx.split(meth.__self__) 32 | 33 | def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R: 34 | module = nnx.merge(graphdef, state) 35 | return meth.__func__(module, *args, **kwargs) 36 | 37 | jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs) 38 | 39 | @functools.wraps(meth) 40 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 41 | return jitted_fn(state, *args, **kwargs) 42 | 43 | return wrapper 44 | 45 | 46 | @dataclasses.dataclass(frozen=True) 47 | class PathRegex: 48 | """NNX Filter that matches paths using a regex. 49 | 50 | By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument. 51 | """ 52 | 53 | pattern: str | re.Pattern 54 | sep: str = "/" 55 | 56 | def __post_init__(self): 57 | if not isinstance(self.pattern, re.Pattern): 58 | object.__setattr__(self, "pattern", re.compile(self.pattern)) 59 | 60 | def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool: 61 | joined_path = self.sep.join(str(x) for x in path) 62 | assert isinstance(self.pattern, re.Pattern) 63 | return self.pattern.fullmatch(joined_path) is not None 64 | 65 | 66 | def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State: 67 | """Apply a function to the leaves of the state that match the filter.""" 68 | filtered_keys = set(state.filter(filter).flat_state()) 69 | return state.map(lambda k, v: fn(v) if k in filtered_keys else v) 70 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) 51 | 52 | self.img_size = image_HW 53 | self.patch_size = patch_HW 54 | self.patches_resolution = patch_grid_size 55 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 56 | 57 | self.in_chans = in_chans 58 | self.embed_dim = embed_dim 59 | 60 | self.flatten_embedding = flatten_embedding 61 | 62 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 63 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 64 | 65 | def forward(self, x: Tensor) -> Tensor: 66 | _, _, H, W = x.shape 67 | patch_H, patch_W = self.patch_size 68 | 69 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 70 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 71 | 72 | x = self.proj(x) # B C H W 73 | H, W = x.size(2), x.size(3) 74 | x = x.flatten(2).transpose(1, 2) # B HW C 75 | x = self.norm(x) 76 | if not self.flatten_embedding: 77 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 78 | return x 79 | 80 | def flops(self) -> float: 81 | Ho, Wo = self.patches_resolution 82 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 83 | if self.norm is not None: 84 | flops += Ho * Wo * self.embed_dim 85 | return flops 86 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) 51 | 52 | self.img_size = image_HW 53 | self.patch_size = patch_HW 54 | self.patches_resolution = patch_grid_size 55 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 56 | 57 | self.in_chans = in_chans 58 | self.embed_dim = embed_dim 59 | 60 | self.flatten_embedding = flatten_embedding 61 | 62 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 63 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 64 | 65 | def forward(self, x: Tensor) -> Tensor: 66 | _, _, H, W = x.shape 67 | patch_H, patch_W = self.patch_size 68 | 69 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 70 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 71 | 72 | x = self.proj(x) # B C H W 73 | H, W = x.size(2), x.size(3) 74 | x = x.flatten(2).transpose(1, 2) # B HW C 75 | x = self.norm(x) 76 | if not self.flatten_embedding: 77 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 78 | return x 79 | 80 | def flops(self) -> float: 81 | Ho, Wo = self.patches_resolution 82 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 83 | if self.norm is not None: 84 | flops += Ho * Wo * self.embed_dim 85 | return flops 86 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_sim/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10 3 | absl-py==2.1.0 4 | # via 5 | # dm-control 6 | # dm-env 7 | # labmaze 8 | # mujoco 9 | certifi==2024.8.30 10 | # via requests 11 | charset-normalizer==3.4.0 12 | # via requests 13 | cloudpickle==3.1.0 14 | # via gymnasium 15 | contourpy==1.3.1 16 | # via matplotlib 17 | cycler==0.12.1 18 | # via matplotlib 19 | dm-control==1.0.14 20 | # via gym-aloha 21 | dm-env==1.6 22 | # via dm-control 23 | dm-tree==0.1.8 24 | # via 25 | # dm-control 26 | # dm-env 27 | docstring-parser==0.16 28 | # via tyro 29 | farama-notifications==0.0.4 30 | # via gymnasium 31 | fonttools==4.55.2 32 | # via matplotlib 33 | glfw==2.8.0 34 | # via 35 | # dm-control 36 | # mujoco 37 | gym-aloha==0.1.1 38 | # via -r examples/aloha_sim/requirements.in 39 | gymnasium==1.0.0 40 | # via gym-aloha 41 | idna==3.10 42 | # via requests 43 | imageio==2.36.1 44 | # via 45 | # -r examples/aloha_sim/requirements.in 46 | # gym-aloha 47 | imageio-ffmpeg==0.5.1 48 | # via imageio 49 | kiwisolver==1.4.7 50 | # via matplotlib 51 | labmaze==1.0.6 52 | # via dm-control 53 | lxml==5.3.0 54 | # via dm-control 55 | markdown-it-py==3.0.0 56 | # via rich 57 | matplotlib==3.9.3 58 | # via -r examples/aloha_sim/requirements.in 59 | mdurl==0.1.2 60 | # via markdown-it-py 61 | msgpack==1.1.0 62 | # via -r examples/aloha_sim/requirements.in 63 | mujoco==2.3.7 64 | # via 65 | # dm-control 66 | # gym-aloha 67 | numpy==1.26.4 
68 | # via 69 | # -r examples/aloha_sim/requirements.in 70 | # contourpy 71 | # dm-control 72 | # dm-env 73 | # gymnasium 74 | # imageio 75 | # labmaze 76 | # matplotlib 77 | # mujoco 78 | # scipy 79 | packaging==24.2 80 | # via matplotlib 81 | pillow==11.0.0 82 | # via 83 | # imageio 84 | # matplotlib 85 | protobuf==5.29.1 86 | # via dm-control 87 | psutil==6.1.0 88 | # via imageio 89 | pygments==2.18.0 90 | # via rich 91 | pyopengl==3.1.7 92 | # via 93 | # dm-control 94 | # mujoco 95 | pyparsing==3.2.0 96 | # via 97 | # dm-control 98 | # matplotlib 99 | python-dateutil==2.9.0.post0 100 | # via matplotlib 101 | requests==2.32.3 102 | # via dm-control 103 | rich==13.9.4 104 | # via tyro 105 | scipy==1.14.1 106 | # via dm-control 107 | setuptools==75.6.0 108 | # via 109 | # dm-control 110 | # imageio-ffmpeg 111 | # labmaze 112 | shtab==1.7.1 113 | # via tyro 114 | six==1.17.0 115 | # via python-dateutil 116 | tqdm==4.67.1 117 | # via dm-control 118 | typeguard==4.4.1 119 | # via tyro 120 | typing-extensions==4.12.2 121 | # via 122 | # -r examples/aloha_sim/requirements.in 123 | # gymnasium 124 | # rich 125 | # typeguard 126 | # tyro 127 | tyro==0.9.2 128 | # via -r examples/aloha_sim/requirements.in 129 | urllib3==2.2.3 130 | # via requests 131 | websockets==14.1 132 | # via -r examples/aloha_sim/requirements.in 133 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models/model_test.py: -------------------------------------------------------------------------------- 1 | from flax import nnx 2 | import jax 3 | import pytest 4 | 5 | from openpi.models import model as _model 6 | from openpi.models import pi0_config 7 | from openpi.models import pi0_fast 8 | from openpi.shared import download 9 | from openpi.shared import nnx_utils 10 | 11 | 12 | def test_pi0_model(): 13 | key = jax.random.key(0) 14 | config = pi0_config.Pi0Config() 15 | model = config.create(key) 16 | 17 | batch_size = 2 18 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 19 | 20 | loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 21 | assert loss.shape == (batch_size, config.action_horizon) 22 | 23 | actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) 24 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim) 25 | 26 | 27 | def test_pi0_lora_model(): 28 | key = jax.random.key(0) 29 | config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora") 30 | model = config.create(key) 31 | 32 | batch_size = 2 33 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 34 | 35 | loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 36 | assert loss.shape == (batch_size, config.action_horizon) 37 | 38 | actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) 39 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim) 40 | 41 | 42 | def test_pi0_fast_model(): 43 | key = jax.random.key(0) 44 | config = pi0_fast.Pi0FASTConfig() 45 | model = config.create(key) 46 | 47 | batch_size = 2 48 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 49 | 50 | loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 51 | assert loss.shape == (batch_size,) 52 | 53 | actions = nnx_utils.module_jit(model.sample_actions)(key, obs) 54 | assert actions.shape == (batch_size, 256) 55 | 56 | 57 | def test_pi0_fast_lora_model(): 58 | key = jax.random.key(0) 59 | config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora") 60 | 
model = config.create(key) 61 | 62 | batch_size = 2 63 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 64 | 65 | loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) 66 | assert loss.shape == (batch_size,) 67 | 68 | actions = nnx_utils.module_jit(model.sample_actions)(key, obs) 69 | assert actions.shape == (batch_size, 256) 70 | 71 | lora_filter = nnx_utils.PathRegex(".*lora.*") 72 | model_state = nnx.state(model) 73 | 74 | lora_state_elems = list(model_state.filter(lora_filter)) 75 | assert len(lora_state_elems) > 0 76 | 77 | 78 | @pytest.mark.manual 79 | def test_model_restore(): 80 | key = jax.random.key(0) 81 | config = pi0_config.Pi0Config() 82 | 83 | batch_size = 2 84 | obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) 85 | 86 | model = config.load( 87 | _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params")) 88 | ) 89 | 90 | loss = model.compute_loss(key, obs, act) 91 | assert loss.shape == (batch_size, config.action_horizon) 92 | 93 | actions = model.sample_actions(key, obs, num_steps=10) 94 | assert actions.shape == (batch_size, model.action_horizon, model.action_dim) 95 | -------------------------------------------------------------------------------- /openvla-SF/vla-scripts/merge_lora_weights_and_save.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loads a checkpoint that only has a LoRA adapter (no merged model) and merges the adapter 3 | into the base OpenVLA model. Saves the final checkpoint in the same directory. 4 | 5 | Make sure to specify the correct base checkpoint when running this script. For example, 6 | - if you fine-tuned the default OpenVLA-7B model without modifications, then `--base_checkpoint=="openvla/openvla-7b"` 7 | - if you fine-tuned a different model or resumed fine-tuning from a different checkpoint, then specify that base checkpoint 8 | - if you fine-tuned the default OpenVLA-7B model with modifications to `modeling_prismatic.py` (OpenVLA class definition), 9 | then the base checkpoint path should point to the checkpoint containing the modifications 10 | 11 | Usage: 12 | python vla-scripts/merge_lora_weights_and_save.py \ 13 | --base_checkpoint openvla/openvla-7b \ 14 | --lora_finetuned_checkpoint_dir /PATH/TO/CHECKPOINT/DIR/ 15 | """ 16 | 17 | import os 18 | import time 19 | from dataclasses import dataclass 20 | from pathlib import Path 21 | from typing import Union 22 | 23 | import draccus 24 | import torch 25 | from peft import PeftModel 26 | from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor 27 | 28 | from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig 29 | from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction 30 | from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor 31 | 32 | 33 | @dataclass 34 | class ConvertConfig: 35 | # fmt: off 36 | 37 | base_checkpoint: Union[str, Path] = "" # Base model checkpoint path/dir (either openvla/openvla-7b or whichever model you fine-tuned / resumed training from) 38 | lora_finetuned_checkpoint_dir: Union[str, Path] = "" # Checkpoint directory containing the LoRA adapter 39 | 40 | # fmt: on 41 | 42 | 43 | @draccus.wrap() 44 | def main(cfg: ConvertConfig) -> None: 45 | # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) 46 | AutoConfig.register("openvla", OpenVLAConfig) 47 | AutoImageProcessor.register(OpenVLAConfig, 
PrismaticImageProcessor) 48 | AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) 49 | AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) 50 | 51 | # Load Model using HF AutoClasses 52 | print(f"Loading base model: {cfg.base_checkpoint}") 53 | vla = AutoModelForVision2Seq.from_pretrained( 54 | cfg.base_checkpoint, 55 | torch_dtype=torch.bfloat16, 56 | low_cpu_mem_usage=True, 57 | trust_remote_code=True, 58 | ) 59 | 60 | # Load LoRA weights and merge into base model, then save final checkpoint 61 | print("Merging LoRA weights into base model...") 62 | start_time = time.time() 63 | merged_vla = PeftModel.from_pretrained(vla, os.path.join(cfg.lora_finetuned_checkpoint_dir, "lora_adapter")).to( 64 | "cuda" 65 | ) 66 | merged_vla = merged_vla.merge_and_unload() 67 | merged_vla.save_pretrained(cfg.lora_finetuned_checkpoint_dir) 68 | print(f"\nMerging complete! Time elapsed (sec): {time.time() - start_time}") 69 | print(f"\nSaved merged model checkpoint at:\n{cfg.lora_finetuned_checkpoint_dir}") 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /openvla-SF/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | # Default System Prompt for LLaVa Models 14 | SYS_PROMPTS = { 15 | "prismatic": ( 16 | "A chat between a curious user and an artificial intelligence assistant. " 17 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 18 | ), 19 | "openvla": ( 20 | "A chat between a curious user and an artificial intelligence assistant. " 21 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 
22 | ), 23 | } 24 | 25 | 26 | class VicunaV15ChatPromptBuilder(PromptBuilder): 27 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 28 | super().__init__(model_family, system_prompt) 29 | self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " 30 | 31 | # LLaMa-2 Specific 32 | self.bos, self.eos = "<s>", "</s>" 33 | 34 | # Get role-specific "wrap" functions 35 | self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " 36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 37 | 38 | # === `self.prompt` gets built up over multiple turns === 39 | self.prompt, self.turn_count = "", 0 40 | 41 | def add_turn(self, role: str, message: str) -> str: 42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 43 | message = message.replace("<image>", "").strip() 44 | 45 | # Special Handling for "system" prompt (turn_count == 0) 46 | if self.turn_count == 0: 47 | sys_message = self.system_prompt + self.wrap_human(message) 48 | wrapped_message = sys_message 49 | elif (self.turn_count % 2) == 0: 50 | human_message = self.wrap_human(message) 51 | wrapped_message = human_message 52 | else: 53 | gpt_message = self.wrap_gpt(message) 54 | wrapped_message = gpt_message 55 | 56 | # Update Prompt 57 | self.prompt += wrapped_message 58 | 59 | # Bump Turn Counter 60 | self.turn_count += 1 61 | 62 | # Return "wrapped_message" (effective string added to context) 63 | return wrapped_message 64 | 65 | def get_potential_prompt(self, message: str) -> str: 66 | # Assumes that it's always the user's (human's) turn! 67 | prompt_copy = str(self.prompt) 68 | 69 | # Special Handling for "system" prompt (turn_count == 0) 70 | if self.turn_count == 0: 71 | sys_message = self.system_prompt + self.wrap_human(message) 72 | prompt_copy += sys_message 73 | 74 | else: 75 | human_message = self.wrap_human(message) 76 | prompt_copy += human_message 77 | 78 | return prompt_copy.removeprefix(self.bos).rstrip() 79 | 80 | def get_prompt(self) -> str: 81 | # Remove prefix <s> (if exists) because it gets auto-inserted by tokenizer!
82 | return self.prompt.removeprefix(self.bos).rstrip() 83 | -------------------------------------------------------------------------------- /openvla-SF/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "spatial-forcing" 7 | authors = [ 8 | {name = "Fuhao Li", email="lfh23@mails.tsinghua.edu.cn"}, 9 | {name = "Wenxuan Song", email="songwenxuan0115@gmail.com"}, 10 | ] 11 | description = "Spatial Forcing: Implicit Spatial Representation Alignment for Vision-language-action Model" 12 | version = "0.0.1" 13 | readme = "README.md" 14 | requires-python = ">=3.8" 15 | keywords = ["vision-language-actions models", "representation supervision", "robot learning"] 16 | license = {file = "LICENSE"} 17 | classifiers = [ 18 | "Development Status :: 3 - Alpha", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Education", 21 | "Intended Audience :: Science/Research", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.8", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3 :: Only", 29 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 30 | ] 31 | dependencies = [ 32 | "accelerate>=0.25.0", 33 | "draccus==0.8.0", 34 | "einops", 35 | # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README) 36 | "huggingface_hub", 37 | "numpy==1.24.4", 38 | "jsonlines", 39 | "matplotlib", 40 | "peft==0.11.1", 41 | "protobuf", 42 | "rich", 43 | "sentencepiece==0.1.99", 44 | "timm==0.9.10", 45 | "tokenizers==0.19.1", 46 | "torch==2.2.0", 47 | "torchvision==0.17.0", 48 | "torchaudio==2.2.0", 49 | "transformers @ git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding) 50 | "wandb", 51 | "tensorflow==2.15.0", 52 | "tensorflow_datasets==4.9.3", 53 | "tensorflow_graphics==2021.12.3", 54 | "dlimp @ git+https://github.com/moojink/dlimp_openvla", 55 | "diffusers==0.33.0", 56 | "imageio", 57 | "uvicorn", 58 | "fastapi", 59 | "json-numpy", 60 | "pillow", 61 | "safetensors", 62 | ] 63 | 64 | [project.optional-dependencies] 65 | dev = [ 66 | "black>=24.2.0", 67 | "gpustat", 68 | "ipython", 69 | "pre-commit", 70 | "ruff>=0.2.2", 71 | ] 72 | sagemaker = [ 73 | "boto3", 74 | "sagemaker" 75 | ] 76 | 77 | [project.urls] 78 | homepage = "https://github.com/OpenHelix-Team/Spatial-Forcing" 79 | repository = "https://github.com/OpenHelix-Team/Spatial-Forcing" 80 | documentation = "https://github.com/OpenHelix-Team/Spatial-Forcing" 81 | 82 | [tool.setuptools.packages.find] 83 | where = ["."] 84 | exclude = ["cache"] 85 | 86 | [tool.setuptools.package-data] 87 | "prismatic" = ["py.typed"] 88 | "vggt" = ["py.typed"] 89 | 90 | [tool.black] 91 | line-length = 121 92 | target-version = ["py38", "py39", "py310"] 93 | preview = true 94 | 95 | [tool.ruff] 96 | line-length = 121 97 | target-version = "py38" 98 | 99 | [tool.ruff.lint] 100 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 101 | ignore = ["F722"] 102 | 103 | [tool.ruff.lint.per-file-ignores] 104 | "__init__.py" = ["E402", "F401"] 105 | -------------------------------------------------------------------------------- /openvla-SF/experiments/robot/aloha/aloha_utils.py: 
-------------------------------------------------------------------------------- 1 | """Utils for evaluating policies in real-world ALOHA environments.""" 2 | 3 | import os 4 | 5 | import imageio 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from experiments.robot.aloha.real_env import make_real_env 10 | from experiments.robot.robot_utils import ( 11 | DATE, 12 | DATE_TIME, 13 | ) 14 | 15 | 16 | def get_next_task_label(task_label): 17 | """Prompt the user to input the next task.""" 18 | if task_label == "": 19 | user_input = "" 20 | while user_input == "": 21 | user_input = input("Enter the task name: ") 22 | task_label = user_input 23 | else: 24 | user_input = input("Enter the task name (or leave blank to repeat the previous task): ") 25 | if user_input == "": 26 | pass # Do nothing -> Let task_label be the same 27 | else: 28 | task_label = user_input 29 | print(f"Task: {task_label}") 30 | return task_label 31 | 32 | 33 | def get_aloha_env(): 34 | """Initializes and returns the ALOHA environment.""" 35 | env = make_real_env(init_node=True) 36 | return env 37 | 38 | 39 | def resize_image_for_preprocessing(img): 40 | """ 41 | Takes numpy array corresponding to a single image and resizes to 256x256, exactly as done 42 | in the ALOHA data preprocessing script, which is used before converting the dataset to RLDS. 43 | """ 44 | ALOHA_PREPROCESS_SIZE = 256 45 | img = np.array( 46 | Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC) 47 | ) # BICUBIC is default; specify explicitly to make it clear 48 | return img 49 | 50 | 51 | def get_aloha_image(obs): 52 | """Extracts third-person image from observations and preprocesses it.""" 53 | # obs: dm_env._environment.TimeStep 54 | img = obs.observation["images"]["cam_high"] 55 | img = resize_image_for_preprocessing(img) 56 | return img 57 | 58 | 59 | def get_aloha_wrist_images(obs): 60 | """Extracts both wrist camera images from observations and preprocesses them.""" 61 | # obs: dm_env._environment.TimeStep 62 | left_wrist_img = obs.observation["images"]["cam_left_wrist"] 63 | right_wrist_img = obs.observation["images"]["cam_right_wrist"] 64 | left_wrist_img = resize_image_for_preprocessing(left_wrist_img) 65 | right_wrist_img = resize_image_for_preprocessing(right_wrist_img) 66 | return left_wrist_img, right_wrist_img 67 | 68 | 69 | def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None): 70 | """Saves an MP4 replay of an episode.""" 71 | rollout_dir = f"./rollouts/{DATE}" 72 | os.makedirs(rollout_dir, exist_ok=True) 73 | processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] 74 | filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}" 75 | if notes is not None: 76 | filetag += f"--{notes}" 77 | mp4_path = f"{filetag}.mp4" 78 | video_writer = imageio.get_writer(mp4_path, fps=25) 79 | for img in rollout_images: 80 | video_writer.append_data(img) 81 | video_writer.close() 82 | print(f"Saved rollout MP4 at path {mp4_path}") 83 | if log_file is not None: 84 | log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") 85 | return mp4_path 86 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/serving/websocket_policy_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import http 3 | import logging 4 | import time 5 | 
import traceback 6 | 7 | from openpi_client import base_policy as _base_policy 8 | from openpi_client import msgpack_numpy 9 | import websockets.asyncio.server as _server 10 | import websockets.frames 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class WebsocketPolicyServer: 16 | """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation. 17 | 18 | Currently only implements the `load` and `infer` methods. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | policy: _base_policy.BasePolicy, 24 | host: str = "0.0.0.0", 25 | port: int | None = None, 26 | metadata: dict | None = None, 27 | ) -> None: 28 | self._policy = policy 29 | self._host = host 30 | self._port = port 31 | self._metadata = metadata or {} 32 | logging.getLogger("websockets.server").setLevel(logging.INFO) 33 | 34 | def serve_forever(self) -> None: 35 | asyncio.run(self.run()) 36 | 37 | async def run(self): 38 | async with _server.serve( 39 | self._handler, 40 | self._host, 41 | self._port, 42 | compression=None, 43 | max_size=None, 44 | process_request=_health_check, 45 | ) as server: 46 | await server.serve_forever() 47 | 48 | async def _handler(self, websocket: _server.ServerConnection): 49 | logger.info(f"Connection from {websocket.remote_address} opened") 50 | packer = msgpack_numpy.Packer() 51 | 52 | await websocket.send(packer.pack(self._metadata)) 53 | 54 | prev_total_time = None 55 | while True: 56 | try: 57 | start_time = time.monotonic() 58 | obs = msgpack_numpy.unpackb(await websocket.recv()) 59 | 60 | infer_time = time.monotonic() 61 | action = self._policy.infer(obs) 62 | infer_time = time.monotonic() - infer_time 63 | 64 | action["server_timing"] = { 65 | "infer_ms": infer_time * 1000, 66 | } 67 | if prev_total_time is not None: 68 | # We can only record the last total time since we also want to include the send time. 69 | action["server_timing"]["prev_total_ms"] = prev_total_time * 1000 70 | 71 | await websocket.send(packer.pack(action)) 72 | prev_total_time = time.monotonic() - start_time 73 | 74 | except websockets.ConnectionClosed: 75 | logger.info(f"Connection from {websocket.remote_address} closed") 76 | break 77 | except Exception: 78 | await websocket.send(traceback.format_exc()) 79 | await websocket.close( 80 | code=websockets.frames.CloseCode.INTERNAL_ERROR, 81 | reason="Internal server error. Traceback included in previous frame.", 82 | ) 83 | raise 84 | 85 | 86 | def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None: 87 | if request.path == "/healthz": 88 | return connection.respond(http.HTTPStatus.OK, "OK\n") 89 | # Continue with the normal request handling. 90 | return None 91 | -------------------------------------------------------------------------------- /openvla-SF/vggt/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | import torch.nn.functional as F 17 | 18 | XFORMERS_AVAILABLE = False 19 | 20 | 21 | class Attention(nn.Module): 22 | def __init__( 23 | self, 24 | dim: int, 25 | num_heads: int = 8, 26 | qkv_bias: bool = True, 27 | proj_bias: bool = True, 28 | attn_drop: float = 0.0, 29 | proj_drop: float = 0.0, 30 | norm_layer: nn.Module = nn.LayerNorm, 31 | qk_norm: bool = False, 32 | fused_attn: bool = True, # use F.scaled_dot_product_attention or not 33 | rope=None, 34 | ) -> None: 35 | super().__init__() 36 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 37 | self.num_heads = num_heads 38 | self.head_dim = dim // num_heads 39 | self.scale = self.head_dim**-0.5 40 | self.fused_attn = fused_attn 41 | 42 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 43 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 44 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | self.rope = rope 49 | 50 | def forward(self, x: Tensor, pos=None) -> Tensor: 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 53 | q, k, v = qkv.unbind(0) 54 | q, k = self.q_norm(q), self.k_norm(k) 55 | 56 | if self.rope is not None: 57 | q = self.rope(q, pos) 58 | k = self.rope(k, pos) 59 | 60 | if self.fused_attn: 61 | x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) 62 | else: 63 | q = q * self.scale 64 | attn = q @ k.transpose(-2, -1) 65 | attn = attn.softmax(dim=-1) 66 | attn = self.attn_drop(attn) 67 | x = attn @ v 68 | 69 | x = x.transpose(1, 2).reshape(B, N, C) 70 | x = self.proj(x) 71 | x = self.proj_drop(x) 72 | return x 73 | 74 | 75 | class MemEffAttention(Attention): 76 | def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: 77 | assert pos is None 78 | if not XFORMERS_AVAILABLE: 79 | if attn_bias is not None: 80 | raise AssertionError("xFormers is required for using nested tensors") 81 | return super().forward(x) 82 | 83 | B, N, C = x.shape 84 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 85 | 86 | q, k, v = unbind(qkv, 2) 87 | 88 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 89 | x = x.reshape([B, N, C]) 90 | 91 | x = self.proj(x) 92 | x = self.proj_drop(x) 93 | return x 94 | -------------------------------------------------------------------------------- /openpi-SF/src/vggt/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | import torch.nn.functional as F 17 | 18 | XFORMERS_AVAILABLE = False 19 | 20 | 21 | class Attention(nn.Module): 22 | def __init__( 23 | self, 24 | dim: int, 25 | num_heads: int = 8, 26 | qkv_bias: bool = True, 27 | proj_bias: bool = True, 28 | attn_drop: float = 0.0, 29 | proj_drop: float = 0.0, 30 | norm_layer: nn.Module = nn.LayerNorm, 31 | qk_norm: bool = False, 32 | fused_attn: bool = True, # use F.scaled_dot_product_attention or not 33 | rope=None, 34 | ) -> None: 35 | super().__init__() 36 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 37 | self.num_heads = num_heads 38 | self.head_dim = dim // num_heads 39 | self.scale = self.head_dim**-0.5 40 | self.fused_attn = fused_attn 41 | 42 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 43 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 44 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | self.rope = rope 49 | 50 | def forward(self, x: Tensor, pos=None) -> Tensor: 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 53 | q, k, v = qkv.unbind(0) 54 | q, k = self.q_norm(q), self.k_norm(k) 55 | 56 | if self.rope is not None: 57 | q = self.rope(q, pos) 58 | k = self.rope(k, pos) 59 | 60 | if self.fused_attn: 61 | x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) 62 | else: 63 | q = q * self.scale 64 | attn = q @ k.transpose(-2, -1) 65 | attn = attn.softmax(dim=-1) 66 | attn = self.attn_drop(attn) 67 | x = attn @ v 68 | 69 | x = x.transpose(1, 2).reshape(B, N, C) 70 | x = self.proj(x) 71 | x = self.proj_drop(x) 72 | return x 73 | 74 | 75 | class MemEffAttention(Attention): 76 | def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: 77 | assert pos is None 78 | if not XFORMERS_AVAILABLE: 79 | if attn_bias is not None: 80 | raise AssertionError("xFormers is required for using nested tensors") 81 | return super().forward(x) 82 | 83 | B, N, C = x.shape 84 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 85 | 86 | q, k, v = unbind(qkv, 2) 87 | 88 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 89 | x = x.reshape([B, N, C]) 90 | 91 | x = self.proj(x) 92 | x = self.proj_drop(x) 93 | return x 94 | -------------------------------------------------------------------------------- /openpi-SF/packages/openpi-client/src/openpi_client/runtime/runtime.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import time 4 | 5 | from openpi_client.runtime import agent as _agent 6 | from openpi_client.runtime import environment as _environment 7 | from openpi_client.runtime import subscriber as _subscriber 8 | 9 | 10 | class Runtime: 11 | """The core module orchestrating interactions between key components of the system.""" 12 | 13 | def __init__( 14 | self, 15 | environment: _environment.Environment, 16 | agent: _agent.Agent, 17 | subscribers: list[_subscriber.Subscriber], 18 | max_hz: float = 
0, 19 | num_episodes: int = 1, 20 | max_episode_steps: int = 0, 21 | ) -> None: 22 | self._environment = environment 23 | self._agent = agent 24 | self._subscribers = subscribers 25 | self._max_hz = max_hz 26 | self._num_episodes = num_episodes 27 | self._max_episode_steps = max_episode_steps 28 | 29 | self._in_episode = False 30 | self._episode_steps = 0 31 | 32 | def run(self) -> None: 33 | """Runs the runtime loop continuously until stop() is called or the environment is done.""" 34 | for _ in range(self._num_episodes): 35 | self._run_episode() 36 | 37 | # Final reset, this is important for real environments to move the robot to its home position. 38 | self._environment.reset() 39 | 40 | def run_in_new_thread(self) -> threading.Thread: 41 | """Runs the runtime loop in a new thread.""" 42 | thread = threading.Thread(target=self.run) 43 | thread.start() 44 | return thread 45 | 46 | def mark_episode_complete(self) -> None: 47 | """Marks the end of an episode.""" 48 | self._in_episode = False 49 | 50 | def _run_episode(self) -> None: 51 | """Runs a single episode.""" 52 | logging.info("Starting episode...") 53 | self._environment.reset() 54 | self._agent.reset() 55 | for subscriber in self._subscribers: 56 | subscriber.on_episode_start() 57 | 58 | self._in_episode = True 59 | self._episode_steps = 0 60 | step_time = 1 / self._max_hz if self._max_hz > 0 else 0 61 | last_step_time = time.time() 62 | 63 | while self._in_episode: 64 | self._step() 65 | self._episode_steps += 1 66 | 67 | # Sleep to maintain the desired frame rate 68 | now = time.time() 69 | dt = now - last_step_time 70 | if dt < step_time: 71 | time.sleep(step_time - dt) 72 | last_step_time = time.time() 73 | else: 74 | last_step_time = now 75 | 76 | logging.info("Episode completed.") 77 | for subscriber in self._subscribers: 78 | subscriber.on_episode_end() 79 | 80 | def _step(self) -> None: 81 | """A single step of the runtime loop.""" 82 | observation = self._environment.get_observation() 83 | action = self._agent.get_action(observation) 84 | self._environment.apply_action(action) 85 | 86 | for subscriber in self._subscribers: 87 | subscriber.on_step(observation, action) 88 | 89 | if self._environment.is_episode_complete() or ( 90 | self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps 91 | ): 92 | self.mark_episode_complete() 93 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/policies/droid_policy.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import einops 4 | import numpy as np 5 | 6 | from openpi import transforms 7 | from openpi.models import model as _model 8 | 9 | 10 | def make_droid_example() -> dict: 11 | """Creates a random input example for the Droid policy.""" 12 | return { 13 | "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 14 | "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8), 15 | "observation/joint_position": np.random.rand(7), 16 | "observation/gripper_position": np.random.rand(1), 17 | "prompt": "do something", 18 | } 19 | 20 | 21 | def _parse_image(image) -> np.ndarray: 22 | image = np.asarray(image) 23 | if np.issubdtype(image.dtype, np.floating): 24 | image = (255 * image).astype(np.uint8) 25 | if image.shape[0] == 3: 26 | image = einops.rearrange(image, "c h w -> h w c") 27 | return image 28 | 29 | 30 | @dataclasses.dataclass(frozen=True) 31 | class 
DroidInputs(transforms.DataTransformFn): 32 | # Determines which model will be used. 33 | model_type: _model.ModelType 34 | 35 | def __call__(self, data: dict) -> dict: 36 | gripper_pos = np.asarray(data["observation/gripper_position"]) 37 | if gripper_pos.ndim == 0: 38 | # Ensure gripper position is a 1D array, not a scalar, so we can concatenate with joint positions 39 | gripper_pos = gripper_pos[np.newaxis] 40 | state = np.concatenate([data["observation/joint_position"], gripper_pos]) 41 | 42 | # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically 43 | # stores as float32 (C,H,W), gets skipped for policy inference 44 | base_image = _parse_image(data["observation/exterior_image_1_left"]) 45 | wrist_image = _parse_image(data["observation/wrist_image_left"]) 46 | 47 | match self.model_type: 48 | case _model.ModelType.PI0 | _model.ModelType.PI05: 49 | names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") 50 | images = (base_image, wrist_image, np.zeros_like(base_image)) 51 | image_masks = (np.True_, np.True_, np.False_) 52 | case _model.ModelType.PI0_FAST: 53 | names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb") 54 | # We don't mask out padding images for FAST models. 55 | images = (base_image, np.zeros_like(base_image), wrist_image) 56 | image_masks = (np.True_, np.True_, np.True_) 57 | case _: 58 | raise ValueError(f"Unsupported model type: {self.model_type}") 59 | 60 | inputs = { 61 | "state": state, 62 | "image": dict(zip(names, images, strict=True)), 63 | "image_mask": dict(zip(names, image_masks, strict=True)), 64 | } 65 | 66 | if "actions" in data: 67 | inputs["actions"] = np.asarray(data["actions"]) 68 | 69 | if "prompt" in data: 70 | if isinstance(data["prompt"], bytes): 71 | data["prompt"] = data["prompt"].decode("utf-8") 72 | inputs["prompt"] = data["prompt"] 73 | 74 | return inputs 75 | 76 | 77 | @dataclasses.dataclass(frozen=True) 78 | class DroidOutputs(transforms.DataTransformFn): 79 | def __call__(self, data: dict) -> dict: 80 | # Only return the first 8 dims. 81 | return {"actions": np.asarray(data["actions"][:, :8])} 82 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/models/lora_test.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | import openpi.models.lora as lora 6 | 7 | 8 | def test_lora_einsum_params_shape(): 9 | shape = (3, 8, 32, 4) # (3KDH) 10 | einsum = lora.Einsum(shape) 11 | lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2)) 12 | lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2))) 13 | 14 | key = jax.random.key(0) 15 | x = jax.random.normal(key, (8, 64, 32)) # (BSD) 16 | eqn = "BSD,3KDH->3BSKH" 17 | 18 | # Ensure that lora parameters are not initialized when LoRA is not used. 19 | params = einsum.init(key, eqn, x) 20 | assert "lora_a" not in params["params"] 21 | assert "lora_b" not in params["params"] 22 | 23 | # Check that default axes work. 24 | params_lora0 = lora0.init(key, eqn, x) 25 | assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2) 26 | assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4) 27 | 28 | # Check that user provided axes work. 
29 | params_lora1 = lora1.init(key, eqn, x) 30 | assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4) 31 | assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4) 32 | 33 | 34 | def test_lora_einsum_same_output(): 35 | shape = (3, 8, 32, 4) # (3KDH) 36 | einsum = lora.Einsum(shape) 37 | einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros)) 38 | 39 | key = jax.random.key(0) 40 | x = jax.random.normal(key, (8, 64, 32)) # (BSD) 41 | eqn = "BSD,3KDH->3BSKH" 42 | 43 | params = einsum.init(key, eqn, x) 44 | output = einsum.apply(params, eqn, x) 45 | 46 | params_lora = einsum_lora.init(key, eqn, x) 47 | output_lora = einsum_lora.apply(params_lora, eqn, x) 48 | 49 | # Results are the same since the LoRA parameters are initialized to zeros. 50 | assert jnp.allclose(output, output_lora) 51 | 52 | 53 | def test_lora_ffn_params_shape(): 54 | ffn = lora.FeedForward(features=8, hidden_dim=32) 55 | ffn_lora = lora.FeedForward( 56 | features=8, 57 | hidden_dim=32, 58 | lora_config=lora.LoRAConfig(rank=2), 59 | ) 60 | 61 | key = jax.random.key(0) 62 | x = jax.random.normal(key, (2, 8)) 63 | 64 | params = ffn.init(key, x) 65 | assert params["params"]["gating_einsum"].shape == (2, 8, 32) 66 | assert params["params"]["linear"].shape == (32, 8) 67 | 68 | params_lora = ffn_lora.init(key, x) 69 | assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32) 70 | assert params_lora["params"]["linear"].shape == (32, 8) 71 | assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2) 72 | assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32) 73 | assert params_lora["params"]["linear_lora_a"].shape == (32, 2) 74 | assert params_lora["params"]["linear_lora_b"].shape == (2, 8) 75 | 76 | 77 | def test_lora_ffn_same_output(): 78 | ffn = lora.FeedForward(features=8, hidden_dim=32) 79 | ffn_lora = lora.FeedForward( 80 | features=8, 81 | hidden_dim=32, 82 | lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros), 83 | ) 84 | 85 | key = jax.random.key(0) 86 | x = jax.random.normal(key, (2, 8)) 87 | 88 | params = ffn.init(key, x) 89 | output = ffn.apply(params, x) 90 | 91 | params_lora = ffn_lora.init(key, x) 92 | output_lora = ffn_lora.apply(params_lora, x) 93 | 94 | assert jnp.allclose(output, output_lora) 95 | -------------------------------------------------------------------------------- /openvla-SF/experiments/robot/libero/libero_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating policies in LIBERO simulation environments.""" 2 | 3 | import math 4 | import os 5 | 6 | import imageio 7 | import numpy as np 8 | import tensorflow as tf 9 | from libero.libero import get_libero_path 10 | from libero.libero.envs import OffScreenRenderEnv 11 | 12 | from experiments.robot.robot_utils import ( 13 | DATE, 14 | DATE_TIME, 15 | ) 16 | 17 | 18 | def get_libero_env(task, model_family, resolution=256): 19 | """Initializes and returns the LIBERO environment, along with the task description.""" 20 | task_description = task.language 21 | task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) 22 | env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} 23 | env = OffScreenRenderEnv(**env_args) 24 | env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state 25 | return env, task_description 26 | 27 | 28 | def 
get_libero_dummy_action(model_family: str): 29 | """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" 30 | return [0, 0, 0, 0, 0, 0, -1] 31 | 32 | 33 | def get_libero_image(obs): 34 | """Extracts third-person image from observations and preprocesses it.""" 35 | img = obs["agentview_image"] 36 | img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing 37 | return img 38 | 39 | 40 | def get_libero_wrist_image(obs): 41 | """Extracts wrist camera image from observations and preprocesses it.""" 42 | img = obs["robot0_eye_in_hand_image"] 43 | img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing 44 | return img 45 | 46 | 47 | def save_rollout_video(rollout_images, idx, success, task_description, log_file=None): 48 | """Saves an MP4 replay of an episode.""" 49 | rollout_dir = f"./rollouts/{DATE}" 50 | os.makedirs(rollout_dir, exist_ok=True) 51 | processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] 52 | mp4_path = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}.mp4" 53 | video_writer = imageio.get_writer(mp4_path, fps=30) 54 | for img in rollout_images: 55 | video_writer.append_data(img) 56 | video_writer.close() 57 | print(f"Saved rollout MP4 at path {mp4_path}") 58 | if log_file is not None: 59 | log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") 60 | return mp4_path 61 | 62 | 63 | def quat2axisangle(quat): 64 | """ 65 | Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 66 | 67 | Converts quaternion to axis-angle format. 68 | Returns a unit vector direction scaled by its angle in radians. 69 | 70 | Args: 71 | quat (np.array): (x,y,z,w) vec4 float angles 72 | 73 | Returns: 74 | np.array: (ax,ay,az) axis-angle exponential coordinates 75 | """ 76 | # clip quaternion 77 | if quat[3] > 1.0: 78 | quat[3] = 1.0 79 | elif quat[3] < -1.0: 80 | quat[3] = -1.0 81 | 82 | den = np.sqrt(1.0 - quat[3] * quat[3]) 83 | if math.isclose(den, 0.0): 84 | # This is (close to) a zero degree rotation, immediately return 85 | return np.zeros(3) 86 | 87 | return (quat[:3] * 2.0 * math.acos(quat[3])) / den 88 | -------------------------------------------------------------------------------- /openpi-SF/examples/aloha_real/constants.py: -------------------------------------------------------------------------------- 1 | # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). 
2 | # ruff: noqa 3 | 4 | ### Task parameters 5 | 6 | ### ALOHA fixed constants 7 | DT = 0.001 8 | JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] 9 | START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] 10 | 11 | # Left finger position limits (qpos[7]), right_finger = -1 * left_finger 12 | MASTER_GRIPPER_POSITION_OPEN = 0.02417 13 | MASTER_GRIPPER_POSITION_CLOSE = 0.01244 14 | PUPPET_GRIPPER_POSITION_OPEN = 0.05800 15 | PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 16 | 17 | # Gripper joint limits (qpos[6]) 18 | MASTER_GRIPPER_JOINT_OPEN = 0.3083 19 | MASTER_GRIPPER_JOINT_CLOSE = -0.6842 20 | PUPPET_GRIPPER_JOINT_OPEN = 1.4910 21 | PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 22 | 23 | ############################ Helper functions ############################ 24 | 25 | MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / ( 26 | MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE 27 | ) 28 | PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / ( 29 | PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE 30 | ) 31 | MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = ( 32 | lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE 33 | ) 34 | PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = ( 35 | lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE 36 | ) 37 | MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) 38 | 39 | MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / ( 40 | MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE 41 | ) 42 | PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / ( 43 | PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE 44 | ) 45 | MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = ( 46 | lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE 47 | ) 48 | PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = ( 49 | lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE 50 | ) 51 | MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) 52 | 53 | MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) 54 | PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) 55 | 56 | MASTER_POS2JOINT = ( 57 | lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) 58 | + MASTER_GRIPPER_JOINT_CLOSE 59 | ) 60 | MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN( 61 | (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) 62 | ) 63 | PUPPET_POS2JOINT = ( 64 | lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) 65 | + PUPPET_GRIPPER_JOINT_CLOSE 66 | ) 67 | PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN( 68 | (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) 69 | ) 70 | 71 | MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2 72 | 
-------------------------------------------------------------------------------- /openpi-SF/examples/libero/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match 3 | absl-py==2.1.0 4 | # via mujoco 5 | certifi==2024.12.14 6 | # via requests 7 | charset-normalizer==3.4.0 8 | # via requests 9 | cycler==0.12.1 10 | # via matplotlib 11 | docstring-parser==0.16 12 | # via tyro 13 | etils==1.3.0 14 | # via mujoco 15 | eval-type-backport==0.2.0 16 | # via tyro 17 | evdev==1.7.1 18 | # via pynput 19 | fonttools==4.55.3 20 | # via matplotlib 21 | glfw==1.12.0 22 | # via mujoco 23 | idna==3.10 24 | # via requests 25 | imageio==2.35.1 26 | # via -r examples/libero/requirements.in 27 | imageio-ffmpeg==0.5.1 28 | # via imageio 29 | importlib-metadata==8.5.0 30 | # via typeguard 31 | importlib-resources==6.4.5 32 | # via etils 33 | kiwisolver==1.4.7 34 | # via matplotlib 35 | llvmlite==0.36.0 36 | # via numba 37 | markdown-it-py==3.0.0 38 | # via rich 39 | matplotlib==3.5.3 40 | # via -r examples/libero/requirements.in 41 | mdurl==0.1.2 42 | # via markdown-it-py 43 | mujoco==3.2.3 44 | # via robosuite 45 | numba==0.53.1 46 | # via robosuite 47 | numpy==1.22.4 48 | # via 49 | # -r examples/libero/requirements.in 50 | # imageio 51 | # matplotlib 52 | # mujoco 53 | # numba 54 | # opencv-python 55 | # robosuite 56 | # scipy 57 | # torchvision 58 | opencv-python==4.6.0.66 59 | # via 60 | # -r examples/libero/requirements.in 61 | # robosuite 62 | packaging==24.2 63 | # via matplotlib 64 | pillow==10.4.0 65 | # via 66 | # imageio 67 | # matplotlib 68 | # robosuite 69 | # torchvision 70 | psutil==6.1.0 71 | # via imageio 72 | pygments==2.18.0 73 | # via rich 74 | pynput==1.7.7 75 | # via robosuite 76 | pyopengl==3.1.7 77 | # via mujoco 78 | pyparsing==3.1.4 79 | # via matplotlib 80 | python-dateutil==2.9.0.post0 81 | # via matplotlib 82 | python-xlib==0.33 83 | # via pynput 84 | pyyaml==6.0.2 85 | # via -r examples/libero/requirements.in 86 | requests==2.32.3 87 | # via torchvision 88 | rich==13.9.4 89 | # via tyro 90 | robosuite==1.4.1 91 | # via -r examples/libero/requirements.in 92 | scipy==1.10.1 93 | # via robosuite 94 | setuptools==75.3.0 95 | # via 96 | # imageio-ffmpeg 97 | # numba 98 | shtab==1.7.1 99 | # via tyro 100 | six==1.17.0 101 | # via 102 | # pynput 103 | # python-dateutil 104 | # python-xlib 105 | termcolor==2.4.0 106 | # via robosuite 107 | torch==1.11.0+cu113 108 | # via 109 | # -r examples/libero/requirements.in 110 | # torchaudio 111 | # torchvision 112 | torchaudio==0.11.0+cu113 113 | # via -r examples/libero/requirements.in 114 | torchvision==0.12.0+cu113 115 | # via -r examples/libero/requirements.in 116 | tqdm==4.67.1 117 | # via -r examples/libero/requirements.in 118 | typeguard==4.4.0 119 | # via tyro 120 | typing-extensions==4.12.2 121 | # via 122 | # etils 123 | # rich 124 | # torch 125 | # torchvision 126 | # typeguard 127 | # tyro 128 | tyro==0.9.2 129 | # via -r examples/libero/requirements.in 130 | urllib3==2.2.3 131 | # via requests 132 | zipp==3.20.2 133 | # via 134 | # etils 135 | # importlib-metadata 136 | # importlib-resources 137 | -------------------------------------------------------------------------------- /openvla-SF/vla-scripts/extern/verify_openvla.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | verify_openvla.py 3 | 4 | Given an HF-exported OpenVLA model, attempt to load via AutoClasses, and verify forward() and predict_action(). 5 | """ 6 | 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoModelForVision2Seq, AutoProcessor 13 | 14 | # === Verification Arguments 15 | MODEL_PATH = "openvla/openvla-7b" 16 | SYSTEM_PROMPT = ( 17 | "A chat between a curious user and an artificial intelligence assistant. " 18 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 19 | ) 20 | INSTRUCTION = "put spoon on towel" 21 | 22 | 23 | def get_openvla_prompt(instruction: str) -> str: 24 | if "v01" in MODEL_PATH: 25 | return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT:" 26 | else: 27 | return f"In: What action should the robot take to {instruction.lower()}?\nOut:" 28 | 29 | 30 | @torch.inference_mode() 31 | def verify_openvla() -> None: 32 | print(f"[*] Verifying OpenVLAForActionPrediction using Model `{MODEL_PATH}`") 33 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 34 | 35 | # Load Processor & VLA 36 | print("[*] Instantiating Processor and Pretrained OpenVLA") 37 | processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) 38 | 39 | # === BFLOAT16 + FLASH-ATTN MODE === 40 | print("[*] Loading in BF16 with Flash-Attention Enabled") 41 | vla = AutoModelForVision2Seq.from_pretrained( 42 | MODEL_PATH, 43 | attn_implementation="flash_attention_2", 44 | torch_dtype=torch.bfloat16, 45 | low_cpu_mem_usage=True, 46 | trust_remote_code=True, 47 | ).to(device) 48 | 49 | # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === 50 | # print("[*] Loading in 8-Bit Quantization Mode") 51 | # vla = AutoModelForVision2Seq.from_pretrained( 52 | # MODEL_PATH, 53 | # attn_implementation="flash_attention_2", 54 | # torch_dtype=torch.float16, 55 | # quantization_config=BitsAndBytesConfig(load_in_8bit=True), 56 | # low_cpu_mem_usage=True, 57 | # trust_remote_code=True, 58 | # ) 59 | 60 | # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === 61 | # print("[*] Loading in 4-Bit Quantization Mode") 62 | # vla = AutoModelForVision2Seq.from_pretrained( 63 | # MODEL_PATH, 64 | # attn_implementation="flash_attention_2", 65 | # torch_dtype=torch.float16, 66 | # quantization_config=BitsAndBytesConfig(load_in_4bit=True), 67 | # low_cpu_mem_usage=True, 68 | # trust_remote_code=True, 69 | # ) 70 | 71 | print("[*] Iterating with Randomly Generated Images") 72 | for _ in range(100): 73 | prompt = get_openvla_prompt(INSTRUCTION) 74 | image = Image.fromarray(np.asarray(np.random.rand(256, 256, 3) * 255, dtype=np.uint8)) 75 | 76 | # === BFLOAT16 MODE === 77 | inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) 78 | 79 | # === 8-BIT/4-BIT QUANTIZATION MODE === 80 | # inputs = processor(prompt, image).to(device, dtype=torch.float16) 81 | 82 | # Run OpenVLA Inference 83 | start_time = time.time() 84 | action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False) 85 | print(f"\t=>> Time: {time.time() - start_time:.4f} || Action: {action}") 86 | 87 | 88 | if __name__ == "__main__": 89 | verify_openvla() 90 | -------------------------------------------------------------------------------- 
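The chat prompters in this tree (vicuna_v15_prompter.py above and llama2_chat_prompter.py below) share the same PromptBuilder pattern: alternating human/gpt turns are wrapped and appended to a running prompt string, and get_prompt() strips the BOS prefix that the tokenizer re-inserts; the hard-coded v01 prompt in verify_openvla.py above mirrors the Vicuna builder's output. A minimal usage sketch (illustrative only, not a file from this repo), assuming the prismatic package is importable:

from prismatic.models.backbones.llm.prompting.vicuna_v15_prompter import VicunaV15ChatPromptBuilder

builder = VicunaV15ChatPromptBuilder(model_family="openvla")
# Turn 0 is a human turn; the default system prompt is prepended automatically.
builder.add_turn(role="human", message="What action should the robot take to put spoon on towel?")
print(builder.get_prompt())
# Prints the system prompt followed by "USER: What action should the robot take to put spoon on towel? ASSISTANT:"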
/openvla-SF/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama2_chat_prompter.py 3 | 4 | Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern 5 | that's used by HF and other online tutorials. 6 | 7 | Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | # Default System Prompt for Prismatic Models 15 | SYS_PROMPTS = { 16 | "prismatic": ( 17 | "You are a helpful language and vision assistant. " 18 | "You are able to understand the visual content that the user provides, " 19 | "and assist the user with a variety of tasks using natural language." 20 | ), 21 | "openvla": ( 22 | "You are a helpful language and vision assistant. " 23 | "You are able to understand the visual content that the user provides, " 24 | "and assist the user with a variety of tasks using natural language." 25 | ), 26 | } 27 | 28 | 29 | def format_system_prompt(system_prompt: str) -> str: 30 | return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n" 31 | 32 | 33 | class LLaMa2ChatPromptBuilder(PromptBuilder): 34 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 35 | super().__init__(model_family, system_prompt) 36 | self.system_prompt = format_system_prompt( 37 | SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt 38 | ) 39 | 40 | # LLaMa-2 Specific 41 | self.bos, self.eos = "<s>", "</s>" 42 | 43 | # Get role-specific "wrap" functions 44 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 45 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 46 | 47 | # === `self.prompt` gets built up over multiple turns === 48 | self.prompt, self.turn_count = "", 0 49 | 50 | def add_turn(self, role: str, message: str) -> str: 51 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 52 | message = message.replace("<image>", "").strip() 53 | 54 | # Special Handling for "system" prompt (turn_count == 0) 55 | if self.turn_count == 0: 56 | sys_message = self.wrap_human(self.system_prompt + message) 57 | wrapped_message = sys_message 58 | elif (self.turn_count % 2) == 0: 59 | human_message = self.wrap_human(message) 60 | wrapped_message = human_message 61 | else: 62 | gpt_message = self.wrap_gpt(message) 63 | wrapped_message = gpt_message 64 | 65 | # Update Prompt 66 | self.prompt += wrapped_message 67 | 68 | # Bump Turn Counter 69 | self.turn_count += 1 70 | 71 | # Return "wrapped_message" (effective string added to context) 72 | return wrapped_message 73 | 74 | def get_potential_prompt(self, message: str) -> str: 75 | # Assumes that it's always the user's (human's) turn! 76 | prompt_copy = str(self.prompt) 77 | 78 | # Special Handling for "system" prompt (turn_count == 0) 79 | if self.turn_count == 0: 80 | sys_message = self.wrap_human(self.system_prompt + message) 81 | prompt_copy += sys_message 82 | 83 | else: 84 | human_message = self.wrap_human(message) 85 | prompt_copy += human_message 86 | 87 | return prompt_copy.removeprefix(self.bos).rstrip() 88 | 89 | def get_prompt(self) -> str: 90 | # Remove prefix <s> because it gets auto-inserted by tokenizer!
91 | return self.prompt.removeprefix(self.bos).rstrip() 92 | -------------------------------------------------------------------------------- /openpi-SF/src/openpi/shared/array_typing.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools as ft 3 | import inspect 4 | from typing import TypeAlias, TypeVar, cast 5 | 6 | import beartype 7 | import jax 8 | import jax._src.tree_util as private_tree_util 9 | import jax.core 10 | from jaxtyping import ArrayLike 11 | from jaxtyping import Bool # noqa: F401 12 | from jaxtyping import DTypeLike # noqa: F401 13 | from jaxtyping import Float 14 | from jaxtyping import Int # noqa: F401 15 | from jaxtyping import Key # noqa: F401 16 | from jaxtyping import Num # noqa: F401 17 | from jaxtyping import PyTree 18 | from jaxtyping import Real # noqa: F401 19 | from jaxtyping import UInt8 # noqa: F401 20 | from jaxtyping import config 21 | from jaxtyping import jaxtyped 22 | import jaxtyping._decorator 23 | import torch 24 | 25 | # patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277. 26 | # the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`, 27 | # `jax.Sharding`, or even plain placeholder objects) due to JAX tracing operations. this patch skips typechecking when the stack trace 28 | # contains `jax._src.tree_util`, which should only be the case during tree unflattening. 29 | _original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations # noqa: SLF001 30 | # Redefine Array to include both JAX arrays and PyTorch tensors 31 | Array = jax.Array | torch.Tensor 32 | 33 | 34 | def _check_dataclass_annotations(self, typechecker): 35 | if not any( 36 | frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"} 37 | for frame in inspect.stack() 38 | ): 39 | return _original_check_dataclass_annotations(self, typechecker) 40 | return None 41 | 42 | 43 | jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations # noqa: SLF001 44 | 45 | KeyArrayLike: TypeAlias = jax.typing.ArrayLike 46 | Params: TypeAlias = PyTree[Float[ArrayLike, "..."]] 47 | 48 | T = TypeVar("T") 49 | 50 | 51 | # runtime type-checking decorator 52 | def typecheck(t: T) -> T: 53 | return cast(T, ft.partial(jaxtyped, typechecker=beartype.beartype)(t)) 54 | 55 | 56 | @contextlib.contextmanager 57 | def disable_typechecking(): 58 | initial = config.jaxtyping_disable 59 | config.update("jaxtyping_disable", True) # noqa: FBT003 60 | yield 61 | config.update("jaxtyping_disable", initial) 62 | 63 | 64 | def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False): 65 | """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer 66 | error message than if `jax.tree.map` is naively used on PyTrees with different structures.
67 | """ 68 | 69 | if errors := list(private_tree_util.equality_errors(expected, got)): 70 | raise ValueError( 71 | "PyTrees have different structure:\n" 72 | + ( 73 | "\n".join( 74 | f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n" 75 | for path, thing1, thing2, explanation in errors 76 | ) 77 | ) 78 | ) 79 | 80 | if check_shapes or check_dtypes: 81 | 82 | def check(kp, x, y): 83 | if check_shapes and x.shape != y.shape: 84 | raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}") 85 | 86 | if check_dtypes and x.dtype != y.dtype: 87 | raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}") 88 | 89 | jax.tree_util.tree_map_with_path(check, expected, got) 90 | --------------------------------------------------------------------------------
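A short usage sketch for the array_typing helpers above (illustrative only, not a file from this repo; `scale_actions` is an invented function): `typecheck` wraps a callable with jaxtyping + beartype so shape annotations are enforced at call time, and `check_pytree_equality` raises a keypath-level error when two parameter trees disagree.

import jax
import jax.numpy as jnp
from jaxtyping import Float

from openpi.shared import array_typing as at


@at.typecheck
def scale_actions(actions: Float[jax.Array, "batch horizon dim"], factor: float) -> Float[jax.Array, "batch horizon dim"]:
    # Argument and return shapes/dtypes are validated against the annotations on every call.
    return actions * factor


chunk = scale_actions(jnp.ones((2, 10, 8), dtype=jnp.float32), 0.5)
print(chunk.shape)  # (2, 10, 8)

# Structural comparison of two parameter trees; passes silently here, raises a descriptive error on mismatch.
at.check_pytree_equality(expected={"w": chunk}, got={"w": chunk}, check_shapes=True, check_dtypes=True)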