├── prismatic ├── py.typed ├── extern │ ├── __init__.py │ └── hf │ │ ├── __init__.py │ │ └── configuration_prismatic.py ├── models │ ├── backbones │ │ ├── __init__.py │ │ ├── llm │ │ │ ├── __init__.py │ │ │ ├── prompting │ │ │ │ ├── __init__.py │ │ │ │ ├── mistral_instruct_prompter.py │ │ │ │ ├── phi_prompter.py │ │ │ │ ├── base_prompter.py │ │ │ │ ├── qwen_prompter.py │ │ │ │ ├── vicuna_v15_prompter.py │ │ │ │ └── llama2_chat_prompter.py │ │ │ ├── phi.py │ │ │ ├── mistral.py │ │ │ ├── qwen25.py │ │ │ └── llama2.py │ │ └── vision │ │ │ ├── __init__.py │ │ │ ├── dinov2_vit.py │ │ │ ├── in1k_vit.py │ │ │ ├── siglip_vit.py │ │ │ ├── clip_vit.py │ │ │ └── dinoclip_vit.py │ ├── vlas │ │ ├── __init__.py │ │ └── openvla.py │ ├── vlms │ │ ├── __init__.py │ │ └── base_vlm.py │ ├── __init__.py │ └── materialize.py ├── vla │ ├── datasets │ │ ├── rlds │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── goal_relabeling.py │ │ │ │ └── task_augmentation.py │ │ │ ├── __init__.py │ │ │ ├── oxe │ │ │ │ ├── __init__.py │ │ │ │ ├── materialize.py │ │ │ │ └── utils │ │ │ │ │ └── droid_utils.py │ │ │ ├── obs_transforms.py │ │ │ └── traj_transforms.py │ │ └── __init__.py │ ├── __init__.py │ ├── action_dataset_materialize.py │ └── materialize.py ├── overwatch │ ├── __init__.py │ └── overwatch.py ├── util │ ├── __init__.py │ ├── nn_utils.py │ ├── torch_utils.py │ └── data_utils.py ├── preprocessing │ ├── datasets │ │ └── __init__.py │ ├── __init__.py │ └── materialize.py ├── __init__.py ├── training │ ├── __init__.py │ ├── strategies │ │ ├── __init__.py │ │ └── ddp.py │ └── materialize.py └── conf │ ├── __init__.py │ └── datasets.py ├── experiments └── robot │ ├── libero │ ├── libero_requirements.txt │ └── libero_utils.py │ ├── simpler │ ├── simpler_benchmark.py │ └── simpler_utils.py │ ├── robot_utils.py │ └── bridge │ ├── bridgev2_utils.py │ ├── widowx_env.py │ └── run_bridgev2_eval.py ├── requirements-min.txt ├── Makefile ├── .pre-commit-config.yaml ├── LICENSE ├── scripts ├── preprocess.py ├── additional-datasets │ ├── lvis_instruct_4v.py │ └── lrv_instruct.py ├── generate.py └── extern │ └── verify_prismatic.py ├── .gitignore ├── pyproject.toml └── vla-scripts ├── extern └── verify_openvla.py ├── deploy.py └── pretrain_vq.py /prismatic/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/extern/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/extern/hf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/models/vlas/__init__.py: -------------------------------------------------------------------------------- 1 | from .openvla import OpenVLA 2 | -------------------------------------------------------------------------------- /prismatic/models/vlms/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .prismatic import PrismaticVLM 2 | -------------------------------------------------------------------------------- /prismatic/overwatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .overwatch import initialize_overwatch 2 | -------------------------------------------------------------------------------- /prismatic/vla/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_vla_dataset_and_collator 2 | -------------------------------------------------------------------------------- /prismatic/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_utils import check_bloat16_supported, set_global_seed 2 | -------------------------------------------------------------------------------- /prismatic/preprocessing/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import AlignDataset, FinetuneDataset 2 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import make_interleaved_dataset, make_single_dataset 2 | -------------------------------------------------------------------------------- /prismatic/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import available_model_names, available_models, get_model_description, load 2 | -------------------------------------------------------------------------------- /prismatic/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_train_strategy 2 | from .metrics import Metrics, VLAMetrics 3 | -------------------------------------------------------------------------------- /experiments/robot/libero/libero_requirements.txt: -------------------------------------------------------------------------------- 1 | imageio[ffmpeg] 2 | robosuite 3 | bddl 4 | easydict 5 | cloudpickle 6 | gym 7 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset 2 | -------------------------------------------------------------------------------- /requirements-min.txt: -------------------------------------------------------------------------------- 1 | timm==0.9.10 2 | tokenizers==0.19.1 3 | torch>=2.2.0 4 | torchvision>=0.16.0 5 | transformers==4.40.1 6 | -------------------------------------------------------------------------------- /prismatic/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import convert_to_jpg, download_extract 2 | from .materialize import get_dataset_and_collator 3 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_oxe_dataset_kwargs_and_weights 2 | from .mixtures import OXE_NAMED_MIXTURES 3 | 
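# Usage sketch (illustrative; not a file from the repository dump): the `__init__.py` modules
# above re-export the package's public entry points (`available_models`, `get_model_description`,
# `load`, `load_vla`, `get_vla_dataset_and_collator`). The snippet below shows how those
# re-exports might be consumed; the model ID string is an assumed example, not taken from this dump.
from prismatic import available_models, get_model_description, load

print(available_models())                              # enumerate registered Prismatic VLM IDs
print(get_model_description("prism-dinosiglip+7b"))    # assumed ID -- substitute any listed name
vlm = load("prism-dinosiglip+7b")                      # load the corresponding pretrained VLM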
-------------------------------------------------------------------------------- /prismatic/training/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_strategy import TrainingStrategy 2 | from .ddp import DDPStrategy 3 | from .fsdp import FSDPStrategy 4 | -------------------------------------------------------------------------------- /prismatic/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DatasetConfig, DatasetRegistry 2 | from .models import ModelConfig, ModelRegistry 3 | from .vla import VLAConfig, VLARegistry 4 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_llm import LLMBackbone 2 | from .llama2 import LLaMa2LLMBackbone 3 | from .mistral import MistralLLMBackbone 4 | from .phi import PhiLLMBackbone 5 | -------------------------------------------------------------------------------- /prismatic/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import available_model_names, available_models, get_model_description, load, load_vla 2 | from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm 3 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_prompter import PromptBuilder, PurePromptBuilder 2 | from .llama2_chat_prompter import LLaMa2ChatPromptBuilder 3 | from .mistral_instruct_prompter import MistralInstructPromptBuilder 4 | from .phi_prompter import PhiPromptBuilder 5 | from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder 6 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_vision import ImageTransform, VisionBackbone 2 | from .clip_vit import CLIPViTBackbone 3 | from .dinoclip_vit import DinoCLIPViTBackbone 4 | from .dinosiglip_vit import DinoSigLIPViTBackbone 5 | from .dinov2_vit import DinoV2ViTBackbone 6 | from .in1k_vit import IN1KViTBackbone 7 | from .siglip_vit import SigLIPViTBackbone 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help clean check autoformat 2 | .DEFAULT: help 3 | 4 | # Generates a useful overview/help message for various make features - add to this as necessary! 5 | help: 6 | @echo "make clean" 7 | @echo " Remove all temporary pyc/pycache files" 8 | @echo "make check" 9 | @echo " Run code style and linting (black, ruff) *without* changing files!" 10 | @echo "make autoformat" 11 | @echo " Run code styling (black, ruff) and update in place - committing with pre-commit also does this." 12 | 13 | clean: 14 | find . -name "*.pyc" | xargs rm -f && \ 15 | find . -name "__pycache__" | xargs rm -rf 16 | 17 | check: 18 | black --check . 19 | ruff check --show-source . 20 | 21 | autoformat: 22 | black . 23 | ruff check --fix --show-fixes . 
24 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | 5 | repos: 6 | - repo: https://github.com/astral-sh/ruff-pre-commit 7 | rev: v0.2.2 8 | hooks: 9 | - id: ruff 10 | args: [ --fix, --exit-non-zero-on-fix ] 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 24.2.0 14 | hooks: 15 | - id: black 16 | 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.5.0 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-ast 22 | - id: check-case-conflict 23 | - id: check-merge-conflict 24 | - id: check-toml 25 | - id: check-yaml 26 | - id: end-of-file-fixer 27 | - id: trailing-whitespace 28 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinov2_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinov2_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! 8 | # => Reference: https://arxiv.org/abs/2309.16588 9 | DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} 10 | 11 | 12 | class DinoV2ViTBackbone(TimmViTBackbone): 13 | def __init__( 14 | self, 15 | vision_backbone_id: str, 16 | image_resize_strategy: str, 17 | default_image_size: int = 224, 18 | image_sequence_len: int = 1, 19 | ) -> None: 20 | super().__init__( 21 | vision_backbone_id, 22 | DINOv2_VISION_BACKBONES[vision_backbone_id], 23 | image_resize_strategy, 24 | default_image_size=default_image_size, 25 | image_sequence_len=image_sequence_len, 26 | ) 27 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/in1k_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | in1k_vit.py 3 | 4 | Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) 5 | """ 6 | 7 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 8 | 9 | # Registry =>> Supported Vision Backbones (from TIMM) 10 | IN1K_VISION_BACKBONES = { 11 | "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", 12 | } 13 | 14 | 15 | class IN1KViTBackbone(TimmViTBackbone): 16 | def __init__( 17 | self, 18 | vision_backbone_id: str, 19 | image_resize_strategy: str, 20 | default_image_size: int = 224, 21 | image_sequence_len: int = 1, 22 | ) -> None: 23 | super().__init__( 24 | vision_backbone_id, 25 | IN1K_VISION_BACKBONES[vision_backbone_id], 26 | image_resize_strategy, 27 | default_image_size=default_image_size, 28 | image_sequence_len=image_sequence_len, 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Moo Jin Kim, Karl Pertsch, Siddharth Karamcheti. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/siglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | siglip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) 8 | SIGLIP_VISION_BACKBONES = { 9 | "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", 10 | "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", 11 | "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", 12 | "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", 13 | "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", 14 | } 15 | 16 | 17 | class SigLIPViTBackbone(TimmViTBackbone): 18 | def __init__( 19 | self, 20 | vision_backbone_id: str, 21 | image_resize_strategy: str, 22 | default_image_size: int = 224, 23 | image_sequence_len: int = 1, 24 | ) -> None: 25 | super().__init__( 26 | vision_backbone_id, 27 | SIGLIP_VISION_BACKBONES[vision_backbone_id], 28 | image_resize_strategy, 29 | default_image_size=default_image_size, 30 | image_sequence_len=image_sequence_len, 31 | ) 32 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | goal_relabeling.py 3 | 4 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 5 | Each function should add entries to the "task" dict. 
6 | """ 7 | 8 | from typing import Dict 9 | 10 | import tensorflow as tf 11 | 12 | from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge 13 | 14 | 15 | def uniform(traj: Dict) -> Dict: 16 | """Relabels with a true uniform distribution over future states.""" 17 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 18 | 19 | # Select a random future index for each transition i in the range [i + 1, traj_len) 20 | rand = tf.random.uniform([traj_len]) 21 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 22 | high = tf.cast(traj_len, tf.float32) 23 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 24 | 25 | # Sometimes there are floating-point errors that cause an out-of-bounds 26 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 27 | 28 | # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) 29 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 30 | traj["task"] = tree_merge(traj["task"], goal) 31 | 32 | return traj 33 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/clip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | clip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported CLIP Vision Backbones (from TIMM) 8 | CLIP_VISION_BACKBONES = { 9 | "clip-vit-b": "vit_base_patch16_clip_224.openai", 10 | "clip-vit-l": "vit_large_patch14_clip_224.openai", 11 | "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", 12 | } 13 | 14 | 15 | # [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. 16 | # HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's 17 | # a decent approximation, the resulting features are *worse*; this was a super tricky bug 18 | # to identify, but luckily there's an easy fix (`override_act_layer`) 19 | class CLIPViTBackbone(TimmViTBackbone): 20 | def __init__( 21 | self, 22 | vision_backbone_id: str, 23 | image_resize_strategy: str, 24 | default_image_size: int = 224, 25 | image_sequence_len: int = 1, 26 | ) -> None: 27 | super().__init__( 28 | vision_backbone_id, 29 | CLIP_VISION_BACKBONES[vision_backbone_id], 30 | image_resize_strategy, 31 | default_image_size=default_image_size, 32 | image_sequence_len=image_sequence_len, 33 | override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, 34 | ) 35 | -------------------------------------------------------------------------------- /experiments/robot/simpler/simpler_benchmark.py: -------------------------------------------------------------------------------- 1 | from experiments.robot.simpler.simpler_utils import get_simpler_env 2 | 3 | BENCHMARK_MAPPING = {} 4 | 5 | 6 | def register_benchmark(target_class): 7 | """We design the mapping to be case-INsensitive.""" 8 | BENCHMARK_MAPPING[target_class.__name__.lower()] = target_class 9 | 10 | 11 | def get_benchmark(benchmark_name): 12 | return BENCHMARK_MAPPING[benchmark_name.lower()] 13 | 14 | 15 | ### 16 | 17 | task_map = { 18 | "simpler_widowx": [ 19 | "widowx_spoon_on_towel", 20 | "widowx_carrot_on_plate", 21 | "widowx_stack_cube", 22 | "widowx_put_eggplant_in_basket", 23 | ], 24 | "simpler_widowx_carrot": [ 25 | "widowx_carrot_on_plate", 26 | ], 27 | } 28 | 29 | 30 | class Benchmark: 31 | def _make_benchmark(self): 32 | 
self.tasks = task_map[self.name] 33 | 34 | def get_task(self, i): 35 | return self.tasks[i] 36 | 37 | def make(self, *args, **kwargs): 38 | return self.env_fn(*args, **kwargs) 39 | 40 | @property 41 | def n_tasks(self): 42 | return len(self.tasks) 43 | 44 | 45 | class SimplerBenchmark(Benchmark): 46 | def __init__(self): 47 | super().__init__() 48 | self.env_fn = get_simpler_env 49 | self.state_dim = 7 50 | 51 | 52 | @register_benchmark 53 | class SIMPLER_WIDOWX(SimplerBenchmark): 54 | def __init__(self): 55 | super().__init__() 56 | self.name = "simpler_widowx" 57 | self._make_benchmark() 58 | 59 | 60 | @register_benchmark 61 | class SIMPLER_WIDOWX_CARROT(SimplerBenchmark): 62 | def __init__(self): 63 | super().__init__() 64 | self.name = "simpler_widowx_carrot" 65 | self._make_benchmark() 66 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | preprocess.py 3 | 4 | Core script for automatically downloading raw VLM pretraining datasets. Supports downloading the following datasets: 5 | - LLaVA v1.5 Datasets (for both training stages) [`llava-laion-cc-sbu-558k`, `llava-v1.5-instruct`] 6 | - Stage 1 :: Projection Matrix Alignment between Vision Encoder & Pretrained LLM on CC-3M-595K (Custom) 7 | - Stage 2 :: Projection & LLM Finetuning on LLaVa v1.5 Instruct (including various vision-language train sets) 8 | 9 | By default, runs download & extraction automatically. 10 | 11 | Run with: `python scripts/preprocess.py --dataset_id ` 12 | """ 13 | 14 | from dataclasses import dataclass 15 | from pathlib import Path 16 | 17 | import draccus 18 | 19 | from prismatic.overwatch import initialize_overwatch 20 | from prismatic.preprocessing import convert_to_jpg, download_extract 21 | 22 | # Initialize Overwatch =>> Wraps `logging.Logger` 23 | overwatch = initialize_overwatch(__name__) 24 | 25 | 26 | @dataclass 27 | class PreprocessConfig: 28 | # fmt: off 29 | dataset_id: str = "llava-v1.5-instruct" # Unique identifier for dataset to process (see above) 30 | root_dir: Path = Path("data") # Path to root directory for storing datasets 31 | 32 | # fmt: on 33 | 34 | 35 | @draccus.wrap() 36 | def preprocess(cfg: PreprocessConfig) -> None: 37 | overwatch.info(f"Downloading & Extracting `{cfg.dataset_id}` to `{cfg.root_dir / 'download'}") 38 | download_extract(cfg.dataset_id, root_dir=cfg.root_dir) 39 | 40 | # Special Handling for OCR VQA Images (for `llava-v1.5-instruct`) --> convert GIFs/PNGs to JPG 41 | if cfg.dataset_id == "llava-v1.5-instruct": 42 | convert_to_jpg(cfg.root_dir / "download" / cfg.dataset_id / "ocr_vqa" / "images") 43 | 44 | 45 | if __name__ == "__main__": 46 | preprocess() 47 | -------------------------------------------------------------------------------- /prismatic/vla/action_dataset_materialize.py: -------------------------------------------------------------------------------- 1 | """Materialization for RAW dataset w/ images / actions / instructions. 
(no tokens)""" 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Any, Dict, Tuple 6 | 7 | from prismatic.vla.datasets.datasets import EpisodicRLDSDataset, RLDSDataset 8 | 9 | 10 | @dataclass 11 | class RLDSActionBatchTransform: 12 | include_images: bool = True 13 | 14 | def __call__(self, rlds_batch: Dict[str, Any]) -> Dict[str, Any]: 15 | """Converts a RLDS batch to the format expected by the OpenVLA collator/models.""" 16 | dataset_name, action = rlds_batch["dataset_name"], rlds_batch["action"] 17 | lang = rlds_batch["task"]["language_instruction"].decode().lower() 18 | 19 | batch = dict(instruction=lang, action=action, dataset_name=dataset_name) 20 | if self.include_images: 21 | batch["image"] = rlds_batch["observation"]["image_primary"][0] 22 | 23 | return batch 24 | 25 | 26 | def get_vla_action_dataset( 27 | data_root_dir: Path, 28 | data_mix: str, 29 | default_image_resolution: Tuple[int, int, int], 30 | shuffle_buffer_size: int = 100_000, 31 | train: bool = True, 32 | episodic: bool = False, 33 | image_aug: bool = False, 34 | future_action_window_size: int = 0, 35 | include_images: bool = True, 36 | ): 37 | """Only get the image / action / instruction, don't do any tokenization.""" 38 | 39 | # TODO new batch transform 40 | batch_transform = RLDSActionBatchTransform(include_images=include_images) 41 | 42 | # Build RLDS Iterable Dataset & Return 43 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset 44 | dataset = cls( 45 | data_root_dir, 46 | data_mix, 47 | batch_transform, 48 | resize_resolution=default_image_resolution[1:], 49 | shuffle_buffer_size=shuffle_buffer_size, 50 | train=train, 51 | image_aug=image_aug, 52 | # did not add support for below kwargs with episodic dataset 53 | future_action_window_size=future_action_window_size, 54 | ) 55 | 56 | return dataset 57 | -------------------------------------------------------------------------------- /prismatic/util/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | nn_utils.py 3 | 4 | Utility functions and PyTorch submodule definitions. 
5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === 12 | class LinearProjector(nn.Module): 13 | def __init__(self, vision_dim: int, llm_dim: int) -> None: 14 | super().__init__() 15 | self.projector = nn.Linear(vision_dim, llm_dim, bias=True) 16 | 17 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 18 | return self.projector(img_patches) 19 | 20 | 21 | class MLPProjector(nn.Module): 22 | def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: 23 | super().__init__() 24 | if mlp_type == "gelu-mlp": 25 | self.projector = nn.Sequential( 26 | nn.Linear(vision_dim, llm_dim, bias=True), 27 | nn.GELU(), 28 | nn.Linear(llm_dim, llm_dim, bias=True), 29 | ) 30 | else: 31 | raise ValueError(f"Projector with `{mlp_type = }` is not supported!") 32 | 33 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 34 | return self.projector(img_patches) 35 | 36 | 37 | class FusedMLPProjector(nn.Module): 38 | def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: 39 | super().__init__() 40 | self.initial_projection_dim = fused_vision_dim * 4 41 | if mlp_type == "fused-gelu-mlp": 42 | self.projector = nn.Sequential( 43 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), 44 | nn.GELU(), 45 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True), 46 | nn.GELU(), 47 | nn.Linear(llm_dim, llm_dim, bias=True), 48 | ) 49 | else: 50 | raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") 51 | 52 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: 53 | return self.projector(fused_img_patches) 54 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | task_augmentation.py 3 | 4 | Contains basic logic for randomly zeroing out keys in the task specification. 5 | """ 6 | 7 | from typing import Dict 8 | 9 | import tensorflow as tf 10 | 11 | from prismatic.vla.datasets.rlds.utils.data_utils import to_padding 12 | 13 | 14 | def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: 15 | """ 16 | Randomly drops out either the goal images or the language instruction. Only does something if both of 17 | these are present. 18 | 19 | Args: 20 | traj: A dictionary containing trajectory data. Should have a "task" key. 21 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 22 | instruction is 1 - keep_image_prob. 
23 | """ 24 | if "language_instruction" not in traj["task"]: 25 | return traj 26 | 27 | image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} 28 | if not image_keys: 29 | return traj 30 | 31 | traj_len = tf.shape(traj["action"])[0] 32 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 33 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 34 | 35 | for key in image_keys | {"language_instruction"}: 36 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 37 | # pad out the key 38 | traj["task"][key] = tf.where( 39 | should_keep, 40 | traj["task"][key], 41 | to_padding(traj["task"][key]), 42 | ) 43 | # zero out the pad mask dict for the key 44 | traj["task"]["pad_mask_dict"][key] = tf.where( 45 | should_keep, 46 | traj["task"]["pad_mask_dict"][key], 47 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 48 | ) 49 | 50 | # when no goal images are present, the goal timestep becomes the final timestep 51 | traj["task"]["timestep"] = tf.where( 52 | should_keep_images, 53 | traj["task"]["timestep"], 54 | traj_len - 1, 55 | ) 56 | 57 | return traj 58 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/phi.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi.py 3 | 4 | Class definition for all LLMs derived from PhiForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import PhiForCausalLM 12 | from transformers.models.phi.modeling_phi import PhiDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder 16 | 17 | # Registry ==> Support Phi Models (from HF Transformers) 18 | # fmt: off 19 | PHI_MODELS = { 20 | # === Phi-2 === 21 | "phi-2-3b": { 22 | "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2" 23 | } 24 | } 25 | # fmt: on 26 | 27 | 28 | class PhiLLMBackbone(HFCausalLLMBackbone): 29 | def __init__( 30 | self, 31 | llm_backbone_id: str, 32 | llm_max_length: int = 2048, 33 | hf_token: Optional[str] = None, 34 | inference_mode: bool = False, 35 | use_flash_attention_2: bool = True, 36 | ) -> None: 37 | super().__init__( 38 | llm_backbone_id, 39 | llm_max_length=llm_max_length, 40 | hf_token=hf_token, 41 | inference_mode=inference_mode, 42 | use_flash_attention_2=use_flash_attention_2, 43 | **PHI_MODELS[llm_backbone_id], 44 | ) 45 | 46 | # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize) 47 | self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) 48 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 49 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 50 | 51 | @property 52 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 53 | if self.identifier.startswith("phi-2"): 54 | return PhiPromptBuilder 55 | 56 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 57 | 58 | @property 59 | def transformer_layer_cls(self) -> Type[nn.Module]: 60 | return PhiDecoderLayer 61 | 62 | @property 63 | def half_precision_dtype(self) -> torch.dtype: 64 | return torch.bfloat16 65 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | mistral_instruct_prompter.py 3 | 4 | Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorial.s 5 | 6 | Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | 14 | class MistralInstructPromptBuilder(PromptBuilder): 15 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 16 | super().__init__(model_family, system_prompt) 17 | 18 | # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)` 19 | # =>> Mistral Instruct *does not* use a System Prompt 20 | self.bos, self.eos = "", "" 21 | 22 | # Get role-specific "wrap" functions 23 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 24 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 25 | 26 | # === `self.prompt` gets built up over multiple turns === 27 | self.prompt, self.turn_count = "", 0 28 | 29 | def add_turn(self, role: str, message: str) -> str: 30 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 31 | message = message.replace("", "").strip() 32 | 33 | if (self.turn_count % 2) == 0: 34 | human_message = self.wrap_human(message) 35 | wrapped_message = human_message 36 | else: 37 | gpt_message = self.wrap_gpt(message) 38 | wrapped_message = gpt_message 39 | 40 | # Update Prompt 41 | self.prompt += wrapped_message 42 | 43 | # Bump Turn Counter 44 | self.turn_count += 1 45 | 46 | # Return "wrapped_message" (effective string added to context) 47 | return wrapped_message 48 | 49 | def get_potential_prompt(self, message: str) -> None: 50 | # Assumes that it's always the user's (human's) turn! 51 | prompt_copy = str(self.prompt) 52 | 53 | human_message = self.wrap_human(message) 54 | prompt_copy += human_message 55 | 56 | return prompt_copy.removeprefix(self.bos).rstrip() 57 | 58 | def get_prompt(self) -> str: 59 | # Remove prefix because it gets auto-inserted by tokenizer! 60 | return self.prompt.removeprefix(self.bos).rstrip() 61 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/mistral.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral.py 3 | 4 | Class definition for all LLMs derived from MistralForCausalLM. 
5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import MistralForCausalLM 12 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder 16 | 17 | # Registry =>> Support Mistral Models (from HF Transformers) 18 | # fmt: off 19 | MISTRAL_MODELS = { 20 | # === Base Mistral v0.1 === 21 | "mistral-v0.1-7b-pure": { 22 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" 23 | }, 24 | 25 | # === Mistral Instruct v0.1 === 26 | "mistral-v0.1-7b-instruct": { 27 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" 28 | } 29 | } 30 | # fmt: on 31 | 32 | 33 | class MistralLLMBackbone(HFCausalLLMBackbone): 34 | def __init__( 35 | self, 36 | llm_backbone_id: str, 37 | llm_max_length: int = 2048, 38 | hf_token: Optional[str] = None, 39 | inference_mode: bool = False, 40 | use_flash_attention_2: bool = True, 41 | ) -> None: 42 | super().__init__( 43 | llm_backbone_id, 44 | llm_max_length=llm_max_length, 45 | hf_token=hf_token, 46 | inference_mode=inference_mode, 47 | use_flash_attention_2=use_flash_attention_2, 48 | **MISTRAL_MODELS[llm_backbone_id], 49 | ) 50 | 51 | # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize) 52 | self.tokenizer.add_special_tokens({"pad_token": ""}) 53 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 54 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 55 | 56 | @property 57 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 58 | if self.identifier.endswith("-pure"): 59 | return PurePromptBuilder 60 | 61 | elif self.identifier.endswith("-instruct"): 62 | return MistralInstructPromptBuilder 63 | 64 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 65 | 66 | @property 67 | def transformer_layer_cls(self) -> Type[nn.Module]: 68 | return MistralDecoderLayer 69 | 70 | @property 71 | def half_precision_dtype(self) -> torch.dtype: 72 | return torch.bfloat16 73 | -------------------------------------------------------------------------------- /prismatic/training/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, 5 | and strategy configurations. 6 | """ 7 | 8 | from typing import Callable, Optional 9 | 10 | import torch 11 | 12 | from prismatic.models.vlms import PrismaticVLM 13 | from prismatic.training.strategies import FSDPStrategy, TrainingStrategy 14 | 15 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
16 | TRAIN_STRATEGIES = { 17 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, 18 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, 19 | } 20 | 21 | 22 | def get_train_strategy( 23 | train_strategy: str, 24 | vlm: PrismaticVLM, 25 | device_id: int, 26 | stage: str, 27 | epochs: int, 28 | max_steps: Optional[int], 29 | global_batch_size: int, 30 | per_device_batch_size: int, 31 | learning_rate: float, 32 | weight_decay: float, 33 | max_grad_norm: float, 34 | lr_scheduler_type: str, 35 | warmup_ratio: float, 36 | enable_gradient_checkpointing: bool = True, 37 | enable_mixed_precision_training: bool = True, 38 | reduce_in_full_precision: bool = False, 39 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 40 | worker_init_fn: Optional[Callable[[int], None]] = None, 41 | save_every_n_steps: Optional[int] = None, 42 | ) -> TrainingStrategy: 43 | if train_strategy in TRAIN_STRATEGIES: 44 | strategy_cfg = TRAIN_STRATEGIES[train_strategy] 45 | strategy = strategy_cfg["cls"]( 46 | vlm=vlm, 47 | device_id=device_id, 48 | stage=stage, 49 | epochs=epochs, 50 | max_steps=max_steps, 51 | global_batch_size=global_batch_size, 52 | per_device_batch_size=per_device_batch_size, 53 | learning_rate=learning_rate, 54 | weight_decay=weight_decay, 55 | max_grad_norm=max_grad_norm, 56 | lr_scheduler_type=lr_scheduler_type, 57 | warmup_ratio=warmup_ratio, 58 | enable_gradient_checkpointing=enable_gradient_checkpointing, 59 | enable_mixed_precision_training=enable_mixed_precision_training, 60 | reduce_in_full_precision=reduce_in_full_precision, 61 | mixed_precision_dtype=mixed_precision_dtype, 62 | worker_init_fn=worker_init_fn, 63 | save_every_n_steps=save_every_n_steps, 64 | **strategy_cfg["kwargs"], 65 | ) 66 | return strategy 67 | else: 68 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") 69 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/phi_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi_prompter.py 3 | 4 | Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft. 5 | Also handles Phi special case BOS token additions. 6 | 7 | Reference: https://huggingface.co/microsoft/phi-2#qa-format 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | 15 | class PhiPromptBuilder(PromptBuilder): 16 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 17 | super().__init__(model_family, system_prompt) 18 | 19 | # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)` 20 | # =>> By default, does *not* append / tokens --> we handle that here (IMPORTANT)! 
21 | self.bos, self.eos = "<|endoftext|>", "<|endoftext|>" 22 | 23 | # Get role-specific "wrap" functions 24 | # =>> Note that placement of / were based on experiments generating from Phi-2 in Input/Output mode 25 | self.wrap_human = lambda msg: f"Input: {msg}\nOutput: " 26 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}" 27 | 28 | # === `self.prompt` gets built up over multiple turns === 29 | self.prompt, self.turn_count = "", 0 30 | 31 | def add_turn(self, role: str, message: str) -> str: 32 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 33 | message = message.replace("", "").strip() 34 | 35 | # Special Handling for "first" input --> prepend a token (expected by Prismatic) 36 | if self.turn_count == 0: 37 | bos_human_message = f"{self.bos}{self.wrap_human(message)}" 38 | wrapped_message = bos_human_message 39 | elif (self.turn_count % 2) == 0: 40 | human_message = self.wrap_human(message) 41 | wrapped_message = human_message 42 | else: 43 | gpt_message = self.wrap_gpt(message) 44 | wrapped_message = gpt_message 45 | 46 | # Update Prompt 47 | self.prompt += wrapped_message 48 | 49 | # Bump Turn Counter 50 | self.turn_count += 1 51 | 52 | # Return "wrapped_message" (effective string added to context) 53 | return wrapped_message 54 | 55 | def get_potential_prompt(self, message: str) -> None: 56 | # Assumes that it's always the user's (human's) turn! 57 | prompt_copy = str(self.prompt) 58 | 59 | human_message = self.wrap_human(message) 60 | prompt_copy += human_message 61 | 62 | return prompt_copy.rstrip() 63 | 64 | def get_prompt(self) -> str: 65 | return self.prompt.rstrip() 66 | -------------------------------------------------------------------------------- /prismatic/preprocessing/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for 5 | clear control flow. 
6 | """ 7 | 8 | from typing import Tuple, Type 9 | 10 | from torch.utils.data import Dataset 11 | from transformers import PreTrainedTokenizerBase 12 | 13 | from prismatic.conf import DatasetConfig 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset 17 | from prismatic.util.data_utils import PaddedCollatorForLanguageModeling 18 | 19 | # Dataset Initializers =>> Maps Stage --> cls() 20 | DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} 21 | 22 | 23 | def get_dataset_and_collator( 24 | stage: str, 25 | dataset_cfg: DatasetConfig, 26 | image_transform: ImageTransform, 27 | tokenizer: PreTrainedTokenizerBase, 28 | prompt_builder_fn: Type[PromptBuilder], 29 | default_image_resolution: Tuple[int, int, int], 30 | padding_side: str = "right", 31 | ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: 32 | dataset_cls = DATASET_INITIALIZER[stage] 33 | dataset_root_dir = dataset_cfg.dataset_root_dir 34 | collator = PaddedCollatorForLanguageModeling( 35 | tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side 36 | ) 37 | 38 | # Switch on `stage` 39 | if stage == "align": 40 | annotation_json, image_dir = dataset_cfg.align_stage_components 41 | dataset = dataset_cls( 42 | dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer 43 | ) 44 | return dataset, collator 45 | 46 | elif stage == "finetune": 47 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 48 | dataset = dataset_cls( 49 | dataset_root_dir / annotation_json, 50 | dataset_root_dir / image_dir, 51 | image_transform, 52 | tokenizer, 53 | prompt_builder_fn=prompt_builder_fn, 54 | ) 55 | return dataset, collator 56 | 57 | elif stage == "full-finetune": 58 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 59 | dataset = dataset_cls( 60 | dataset_root_dir / annotation_json, 61 | dataset_root_dir / image_dir, 62 | image_transform, 63 | tokenizer, 64 | prompt_builder_fn=prompt_builder_fn, 65 | ) 66 | return dataset, collator 67 | 68 | else: 69 | raise ValueError(f"Stage `{stage}` is not supported!") 70 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/base_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_prompter.py 3 | 4 | Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs. 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Optional 9 | 10 | 11 | class PromptBuilder(ABC): 12 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 13 | self.model_family = model_family 14 | 15 | # Only some models define a system prompt => let subclasses handle this logic! 16 | self.system_prompt = system_prompt 17 | 18 | @abstractmethod 19 | def add_turn(self, role: str, message: str) -> str: ... 20 | 21 | @abstractmethod 22 | def get_potential_prompt(self, user_msg: str) -> None: ... 23 | 24 | @abstractmethod 25 | def get_prompt(self) -> str: ... 
26 | 27 | 28 | class PurePromptBuilder(PromptBuilder): 29 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 30 | super().__init__(model_family, system_prompt) 31 | 32 | # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! 33 | self.bos, self.eos = "", "" 34 | 35 | # Get role-specific "wrap" functions 36 | self.wrap_human = lambda msg: f"In: {msg}\nOut: " 37 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 38 | 39 | # === `self.prompt` gets built up over multiple turns === 40 | self.prompt, self.turn_count = "", 0 41 | 42 | def add_turn(self, role: str, message: str) -> str: 43 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 44 | message = message.replace("", "").strip() 45 | 46 | if (self.turn_count % 2) == 0: 47 | human_message = self.wrap_human(message) 48 | wrapped_message = human_message 49 | else: 50 | gpt_message = self.wrap_gpt(message) 51 | wrapped_message = gpt_message 52 | 53 | # Update Prompt 54 | self.prompt += wrapped_message 55 | 56 | # Bump Turn Counter 57 | self.turn_count += 1 58 | 59 | # Return "wrapped_message" (effective string added to context) 60 | return wrapped_message 61 | 62 | def get_potential_prompt(self, message: str) -> None: 63 | # Assumes that it's always the user's (human's) turn! 64 | prompt_copy = str(self.prompt) 65 | 66 | human_message = self.wrap_human(message) 67 | prompt_copy += human_message 68 | 69 | return prompt_copy.removeprefix(self.bos).rstrip() 70 | 71 | def get_prompt(self) -> str: 72 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 73 | return self.prompt.removeprefix(self.bos).rstrip() 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ruff 132 | .ruff_cache/ 133 | 134 | # Auth Tokens / Hidden Files 135 | .hf_token 136 | .wandb_api_key 137 | .*_token 138 | .*api_key 139 | 140 | # IDE Caches 141 | .idea/ 142 | .vscode/ 143 | 144 | # Mac OS 145 | .DS_Store 146 | 147 | # Caches and Datasets 148 | cache/ 149 | data/ 150 | 151 | # Rollout videos and wandb logs 152 | rollouts/ 153 | wandb/ 154 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "openvla" 7 | authors = [ 8 | {name = "Moo Jin Kim", email="moojink@stanford.edu"}, 9 | {name = "Karl Pertsch", email="pertsch@berkeley.edu"}, 10 | {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"}, 11 | ] 12 | description = "OpenVLA: Vision-Language-Action Models for Robotics" 13 | version = "0.0.3" 14 | readme = "README.md" 15 | requires-python = ">=3.8" 16 | keywords = ["vision-language-actions models", "multimodal pretraining", "robot learning"] 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "Intended Audience :: Education", 22 | "Intended Audience :: Science/Research", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ] 32 | dependencies = [ 33 | "accelerate>=0.25.0", 34 | "draccus==0.8.0", 35 | "einops", 36 | # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README) 37 | "huggingface_hub", 38 | "json-numpy", 39 | "jsonlines", 40 | "matplotlib", 41 | "peft==0.11.1", 42 | "protobuf", 43 | "rich", 44 | "sentencepiece==0.1.99", 45 | "timm==0.9.10", 46 | "tokenizers==0.19.1", 47 | "torch==2.2.0", 48 | "torchvision==0.17.0", 49 | "torchaudio==2.2.0", 50 | "transformers==4.40.1", 51 | "wandb", 52 | "tensorflow==2.15.0", 53 | "tensorflow_datasets==4.9.3", 54 | "tensorflow_graphics==2021.12.3", 55 | "dlimp @ git+https://github.com/moojink/dlimp_openvla" 56 | ] 57 | 58 | [project.optional-dependencies] 59 | dev = [ 60 | "black>=24.2.0", 61 | "gpustat", 62 | "ipython", 63 | "pre-commit", 64 | "ruff>=0.2.2", 65 | ] 66 | sagemaker = [ 67 | "boto3", 68 | "sagemaker" 69 | ] 70 | 71 | 
[project.urls] 72 | homepage = "https://github.com/openvla/openvla" 73 | repository = "https://github.com/openvla/openvla" 74 | documentation = "https://github.com/openvla/openvla" 75 | 76 | [tool.setuptools.packages.find] 77 | where = ["."] 78 | exclude = ["cache"] 79 | 80 | [tool.setuptools.package-data] 81 | "prismatic" = ["py.typed"] 82 | 83 | [tool.black] 84 | line-length = 121 85 | target-version = ["py38", "py39", "py310"] 86 | preview = true 87 | 88 | [tool.ruff] 89 | line-length = 121 90 | target-version = "py38" 91 | 92 | [tool.ruff.lint] 93 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 94 | ignore = ["F722"] 95 | 96 | [tool.ruff.lint.per-file-ignores] 97 | "__init__.py" = ["E402", "F401"] 98 | -------------------------------------------------------------------------------- /experiments/robot/simpler/simpler_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import simpler_env 3 | import tensorflow as tf 4 | from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict 5 | from transforms3d.euler import euler2axangle 6 | 7 | from experiments.robot.robot_utils import normalize_gripper_action 8 | 9 | 10 | def get_simpler_img(env, obs, resize_size): 11 | """ 12 | Takes in environment and observation and returns resized image as numpy array. 13 | 14 | NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow 15 | the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training. 16 | """ 17 | assert isinstance(resize_size, int) 18 | image = get_image_from_maniskill2_obs_dict(env, obs) 19 | 20 | # Preprocess the image the exact same way that the Berkeley Bridge folks did it 21 | # to minimize distribution shift. 22 | # NOTE (Moo Jin): Yes, we resize down to 256x256 first even though the image may end up being 23 | # resized up to a different resolution by some models. This is just so that we're in-distribution 24 | # w.r.t. the original preprocessing at train time. 25 | IMAGE_BASE_PREPROCESS_SIZE = 128 26 | # Resize to image size expected by model 27 | image = tf.image.encode_jpeg(image) # Encode as JPEG, as done in RLDS dataset builder 28 | image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8) # Immediately decode back 29 | image = tf.image.resize( 30 | image, (IMAGE_BASE_PREPROCESS_SIZE, IMAGE_BASE_PREPROCESS_SIZE), method="lanczos3", antialias=True 31 | ) 32 | image = tf.image.resize(image, (resize_size, resize_size), method="lanczos3", antialias=True) 33 | image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8) 34 | return image.numpy() 35 | 36 | 37 | def get_simpler_env(task, model_family): 38 | """Initializes and returns the Simpler environment along with the task description.""" 39 | env = simpler_env.make(task) 40 | return env 41 | 42 | 43 | def get_simpler_dummy_action(model_family: str): 44 | if model_family == "octo": 45 | # TODO: don't hardcode the action horizon for Octo 46 | return np.tile(np.array([0, 0, 0, 0, 0, 0, -1])[None], (4, 1)) 47 | else: 48 | return np.array([0, 0, 0, 0, 0, 0, -1]) 49 | 50 | 51 | def convert_maniskill(action): 52 | """ 53 | Applies transforms to raw VLA action that Maniskill simpler_env env expects. 54 | Converts rotation to axis_angle. 55 | Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1] and binarizes. 
56 | """ 57 | assert action.shape[0] == 7 58 | 59 | # Change rotation to axis-angle 60 | action = action.copy() 61 | roll, pitch, yaw = action[3], action[4], action[5] 62 | action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw) 63 | action[3:6] = action_rotation_ax * action_rotation_angle 64 | 65 | # Binarize final gripper dimension & map to [-1...1] 66 | return normalize_gripper_action(action) 67 | -------------------------------------------------------------------------------- /scripts/additional-datasets/lvis_instruct_4v.py: -------------------------------------------------------------------------------- 1 | """ 2 | scripts/additional-datasets/lvis_instruct4v.py 3 | 4 | Standalone script for pre-processing the LVIS-Instruct4V (language/chat) data (`lvis_instruct4v_220k.json`). This 5 | dataset is curated from LVIS images (subset of COCO yet again), but chat data is synthesized from GPT4-Vision. 6 | 7 | This script downloads the raw data, merges with the LLaVa v15 data, and performs any other data normalization, saving 8 | the resulting `.json` file(s) to the `data/download/llava-v1.5-instruct/` directory. 9 | 10 | Make sure to download the COCO Val 2017 (LVIS) data to `data/download/llava-v1.5-instruct/coco`: 11 | => cd data/download/llava-v1.5-instruct/coco 12 | => wget http://images.cocodataset.org/zips/val2017.zip 13 | => unzip val2017.zip; rm val2017.zip 14 | 15 | References: "To See is to Believe: Prompting GPT-4V for Better Visual Instruction Tuning" 16 | => Paper: https://arxiv.org/abs/2311.07574 17 | => Github / Data: https://github.com/X2FD/LVIS-INSTRUCT4V || https://huggingface.co/datasets/X2FD/LVIS-Instruct4V 18 | """ 19 | 20 | import json 21 | import os 22 | import random 23 | from pathlib import Path 24 | 25 | from tqdm import tqdm 26 | 27 | from prismatic.preprocessing.download import download_with_progress 28 | 29 | # === Constants === 30 | DATA_URL = "https://huggingface.co/datasets/X2FD/LVIS-Instruct4V/resolve/main/lvis_instruct4v_220k.json" 31 | DOWNLOAD_DIR = Path("data/download/llava-v1.5-instruct") 32 | RAW_JSON_FILE = DOWNLOAD_DIR / "lvis_instruct4v_220k.json" 33 | 34 | # JSON Files for "merged" variant of the dataset (with `llava_v1_5_mix665k.json`) 35 | BASE_JSON_FILE = DOWNLOAD_DIR / "llava_v1_5_mix665k.json" 36 | MERGED_JSON_FILE = DOWNLOAD_DIR / "llava_v1_5_lvis4v_mix888k.json" 37 | 38 | 39 | def build_lvis_instruct_4v() -> None: 40 | print("[*] Downloading and Formatting `LVIS-Instruct-4V` Dataset!") 41 | 42 | # Set Random Seed 43 | random.seed(7) 44 | 45 | # Download Dataset JSON 46 | os.makedirs(DOWNLOAD_DIR, exist_ok=True) 47 | if not RAW_JSON_FILE.exists(): 48 | download_with_progress(DATA_URL, DOWNLOAD_DIR) 49 | 50 | # Open JSON File --> verify image existence! 51 | print("[*] Loading LVIS Instruct4V Data!") 52 | with open(RAW_JSON_FILE, "r") as f: 53 | data = json.load(f) 54 | 55 | # Iterate & Verify 56 | for example in tqdm(data, desc="[*] Verifying all Images in LVIS Instruct4V"): 57 | image_path = example["image"] 58 | assert (DOWNLOAD_DIR / image_path).exists(), f"Missing Image `{image_path}`" 59 | 60 | # Create Stacked Dataset =>> Shuffle for Good Measure! 
61 | print("[*] Loading LLaVa v1.5 Data!") 62 | with open(BASE_JSON_FILE, "r") as f: 63 | llava_v15_data = json.load(f) 64 | 65 | # Combine & Shuffle & Write 66 | full_data = llava_v15_data + data 67 | 68 | random.shuffle(full_data) 69 | random.shuffle(full_data) 70 | random.shuffle(full_data) 71 | 72 | with open(MERGED_JSON_FILE, "w") as f: 73 | json.dump(full_data, f) 74 | 75 | 76 | if __name__ == "__main__": 77 | build_lvis_instruct_4v() 78 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/qwen_prompter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 4 | 5 | SYS_PROMPTS = { 6 | "prismatic": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", 7 | "openvla": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", 8 | } 9 | 10 | 11 | class QwenPromptBuilder(PromptBuilder): 12 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 13 | super().__init__(model_family, system_prompt) 14 | 15 | self.system_prompt = (SYS_PROMPTS[model_family] if system_prompt is None else system_prompt).strip() 16 | 17 | # Note =>> Qwen Tokenizer is an instance of `Qwen2Tokenizer(Fast)` 18 | # =>> By default, there is *no* BOS token, so we define one manually. 19 | self.bos = self.start = "<|im_start|>" # NOTE this is not used 20 | self.eos = "<|endoftext|>" 21 | 22 | self.end = "<|im_end|>" 23 | 24 | # Get role-specific "wrap" functions 25 | # =>> Note that placement of the start/end tokens was based on experiments generating from Phi-2 in Input/Output mode 26 | self.wrap_system = lambda msg: f"{self.start}system\n{msg}{self.end}\n" 27 | self.wrap_human = lambda msg: f"{self.start}user\n{msg}{self.end}\n{self.start}assistant\n" 28 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.end}\n" 29 | 30 | # === `self.prompt` gets built up over multiple turns === 31 | self.prompt, self.turn_count = "", 0 32 | 33 | def add_turn(self, role: str, message: str) -> str: 34 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 35 | message = message.replace("<image>", "").strip() 36 | 37 | # Special Handling for "first" input --> add an optional system prompt to the beginning. 38 | if self.turn_count == 0 and self.system_prompt is not None: 39 | self.prompt += self.wrap_system(self.system_prompt) 40 | 41 | if (self.turn_count % 2) == 0: 42 | human_message = self.wrap_human(message) 43 | wrapped_message = human_message 44 | else: 45 | gpt_message = self.wrap_gpt(message) 46 | wrapped_message = gpt_message 47 | 48 | # Update Prompt 49 | self.prompt += wrapped_message 50 | 51 | # Bump Turn Counter 52 | self.turn_count += 1 53 | 54 | # Return "wrapped_message" (effective string added to context) 55 | return wrapped_message 56 | 57 | def get_potential_prompt(self, message: str) -> str: 58 | # Assumes that it's always the user's (human's) turn! 59 | prompt_copy = str(self.prompt) 60 | 61 | human_message = self.wrap_human(message) 62 | prompt_copy += human_message 63 | 64 | return prompt_copy 65 | 66 | def get_prompt(self) -> str: 67 | # add EOS if we ended on a "gpt" role (turns is a multiple of 2) 68 | if self.turn_count % 2 == 0: 69 | # remove the newline before EOS 70 | assert self.prompt[-1] == "\n", f"malformed prompt ({self.prompt}) missing newline before EOS append!"
71 | return self.prompt[:-1] + self.eos 72 | 73 | return self.prompt 74 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | # Default System Prompt for LLaVa Models 14 | SYS_PROMPTS = { 15 | "prismatic": ( 16 | "A chat between a curious user and an artificial intelligence assistant. " 17 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 18 | ), 19 | "openvla": ( 20 | "A chat between a curious user and an artificial intelligence assistant. " 21 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 22 | ), 23 | } 24 | 25 | 26 | class VicunaV15ChatPromptBuilder(PromptBuilder): 27 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 28 | super().__init__(model_family, system_prompt) 29 | self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " 30 | 31 | # LLaMa-2 Specific 32 | self.bos, self.eos = "<s>", "</s>" 33 | 34 | # Get role-specific "wrap" functions 35 | self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " 36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 37 | 38 | # === `self.prompt` gets built up over multiple turns === 39 | self.prompt, self.turn_count = "", 0 40 | 41 | def add_turn(self, role: str, message: str) -> str: 42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 43 | message = message.replace("<image>", "").strip() 44 | 45 | # Special Handling for "system" prompt (turn_count == 0) 46 | if self.turn_count == 0: 47 | sys_message = self.system_prompt + self.wrap_human(message) 48 | wrapped_message = sys_message 49 | elif (self.turn_count % 2) == 0: 50 | human_message = self.wrap_human(message) 51 | wrapped_message = human_message 52 | else: 53 | gpt_message = self.wrap_gpt(message) 54 | wrapped_message = gpt_message 55 | 56 | # Update Prompt 57 | self.prompt += wrapped_message 58 | 59 | # Bump Turn Counter 60 | self.turn_count += 1 61 | 62 | # Return "wrapped_message" (effective string added to context) 63 | return wrapped_message 64 | 65 | def get_potential_prompt(self, message: str) -> str: 66 | # Assumes that it's always the user's (human's) turn! 67 | prompt_copy = str(self.prompt) 68 | 69 | # Special Handling for "system" prompt (turn_count == 0) 70 | if self.turn_count == 0: 71 | sys_message = self.system_prompt + self.wrap_human(message) 72 | prompt_copy += sys_message 73 | 74 | else: 75 | human_message = self.wrap_human(message) 76 | prompt_copy += human_message 77 | 78 | return prompt_copy.removeprefix(self.bos).rstrip() 79 | 80 | def get_prompt(self) -> str: 81 | # Remove <s> prefix (if exists) because it gets auto-inserted by tokenizer! 82 | return self.prompt.removeprefix(self.bos).rstrip() 83 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/qwen25.py: -------------------------------------------------------------------------------- 1 | """ 2 | qwen25.py 3 | 4 | Class definition for all LLMs derived from Qwen2ForCausalLM.
5 | """ 6 | 7 | from typing import Optional, Sequence, Type 8 | 9 | import torch 10 | from transformers import AutoModelForCausalLM 11 | from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer 12 | 13 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 14 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 15 | from prismatic.models.backbones.llm.prompting.qwen_prompter import QwenPromptBuilder 16 | 17 | # Registry =>> Support Qwen-2.5 Models (from HF Transformers) 18 | # fmt: off 19 | QWEN25_MODELS = { 20 | # === Pure Qwen2.5 (non-instruct/chat-tuned) Models === 21 | "qwen25-0_5b-extra": { 22 | "llm_family": "qwen2.5", "llm_cls": AutoModelForCausalLM, "hf_hub_path": "Qwen/Qwen2.5-0.5B" 23 | }, 24 | "qwen25-0_5b-pure": { 25 | "llm_family": "qwen2.5", "llm_cls": AutoModelForCausalLM, "hf_hub_path": "Qwen/Qwen2.5-0.5B" 26 | }, 27 | "qwen25-1_5b-pure": { 28 | "llm_family": "qwen2.5", "llm_cls": AutoModelForCausalLM, "hf_hub_path": "Qwen/Qwen2.5-1.5B" 29 | }, 30 | "qwen25-3b-pure": { 31 | "llm_family": "qwen2.5", "llm_cls": AutoModelForCausalLM, "hf_hub_path": "Qwen/Qwen2.5-3B" 32 | }, 33 | "qwen25-7b-pure": { 34 | "llm_family": "qwen2.5", "llm_cls": AutoModelForCausalLM, "hf_hub_path": "Qwen/Qwen2.5-7B" 35 | }, 36 | 37 | } 38 | # fmt: on 39 | 40 | 41 | class Qwen25LLMBackbone(HFCausalLLMBackbone): 42 | def __init__( 43 | self, 44 | llm_backbone_id: str, 45 | llm_max_length: int = 2048, 46 | hf_token: Optional[str] = None, 47 | inference_mode: bool = False, 48 | use_flash_attention_2: bool = True, 49 | num_extra_tokens: int = 0, 50 | ) -> None: 51 | super().__init__( 52 | llm_backbone_id, 53 | llm_max_length=llm_max_length, 54 | hf_token=hf_token, 55 | inference_mode=inference_mode, 56 | use_flash_attention_2=use_flash_attention_2, 57 | **QWEN25_MODELS[llm_backbone_id], 58 | ) 59 | 60 | # add some more special tokens 61 | if num_extra_tokens > 0: 62 | added = self.tokenizer.add_tokens([f"<|extra_{i}|>" for i in range(num_extra_tokens)]) 63 | assert added == num_extra_tokens, f"Added {added} of {num_extra_tokens} extra tokens to tokenizer!" 64 | print(f"Added {num_extra_tokens} extra tokens.") 65 | 66 | # there is already a special token for Qwen 67 | # self.tokenizer.add_special_tokens({"pad_token": ""}) 68 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 69 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 70 | 71 | @property 72 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 73 | return QwenPromptBuilder 74 | 75 | @property 76 | def transformer_layer_cls(self) -> Type[torch.nn.Module]: 77 | return Qwen2DecoderLayer 78 | 79 | @property 80 | def half_precision_dtype(self) -> torch.dtype: 81 | return torch.bfloat16 82 | 83 | @property 84 | def last_layer_finetune_modules(self) -> Sequence[torch.nn.Module]: 85 | # TODO not sure that this works 86 | return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head) 87 | -------------------------------------------------------------------------------- /vla-scripts/extern/verify_openvla.py: -------------------------------------------------------------------------------- 1 | """ 2 | verify_openvla.py 3 | 4 | Given an HF-exported OpenVLA model, attempt to load via AutoClasses, and verify forward() and predict_action(). 
5 | """ 6 | 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoModelForVision2Seq, AutoProcessor 13 | 14 | # === Verification Arguments 15 | MODEL_PATH = "openvla/openvla-7b" 16 | SYSTEM_PROMPT = ( 17 | "A chat between a curious user and an artificial intelligence assistant. " 18 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 19 | ) 20 | INSTRUCTION = "put spoon on towel" 21 | 22 | 23 | def get_openvla_prompt(instruction: str) -> str: 24 | if "v01" in MODEL_PATH: 25 | return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT:" 26 | else: 27 | return f"In: What action should the robot take to {instruction.lower()}?\nOut:" 28 | 29 | 30 | @torch.inference_mode() 31 | def verify_openvla() -> None: 32 | print(f"[*] Verifying OpenVLAForActionPrediction using Model `{MODEL_PATH}`") 33 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 34 | 35 | # Load Processor & VLA 36 | print("[*] Instantiating Processor and Pretrained OpenVLA") 37 | processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) 38 | 39 | # === BFLOAT16 + FLASH-ATTN MODE === 40 | print("[*] Loading in BF16 with Flash-Attention Enabled") 41 | vla = AutoModelForVision2Seq.from_pretrained( 42 | MODEL_PATH, 43 | attn_implementation="flash_attention_2", 44 | torch_dtype=torch.bfloat16, 45 | low_cpu_mem_usage=True, 46 | trust_remote_code=True, 47 | ).to(device) 48 | 49 | # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === 50 | # print("[*] Loading in 8-Bit Quantization Mode") 51 | # vla = AutoModelForVision2Seq.from_pretrained( 52 | # MODEL_PATH, 53 | # attn_implementation="flash_attention_2", 54 | # torch_dtype=torch.float16, 55 | # quantization_config=BitsAndBytesConfig(load_in_8bit=True), 56 | # low_cpu_mem_usage=True, 57 | # trust_remote_code=True, 58 | # ) 59 | 60 | # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === 61 | # print("[*] Loading in 4-Bit Quantization Mode") 62 | # vla = AutoModelForVision2Seq.from_pretrained( 63 | # MODEL_PATH, 64 | # attn_implementation="flash_attention_2", 65 | # torch_dtype=torch.float16, 66 | # quantization_config=BitsAndBytesConfig(load_in_4bit=True), 67 | # low_cpu_mem_usage=True, 68 | # trust_remote_code=True, 69 | # ) 70 | 71 | print("[*] Iterating with Randomly Generated Images") 72 | for _ in range(100): 73 | prompt = get_openvla_prompt(INSTRUCTION) 74 | image = Image.fromarray(np.asarray(np.random.rand(256, 256, 3) * 255, dtype=np.uint8)) 75 | 76 | # === BFLOAT16 MODE === 77 | inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) 78 | 79 | # === 8-BIT/4-BIT QUANTIZATION MODE === 80 | # inputs = processor(prompt, image).to(device, dtype=torch.float16) 81 | 82 | # Run OpenVLA Inference 83 | start_time = time.time() 84 | action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False) 85 | print(f"\t=>> Time: {time.time() - start_time:.4f} || Action: {action}") 86 | 87 | 88 | if __name__ == "__main__": 89 | verify_openvla() 90 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama2_prompter.py 3 | 4 | Defines a PromptBuilder for building LLaMa-2 Chat 
Prompts --> not sure if this is "optimal", but this is the pattern 5 | that's used by HF and other online tutorials. 6 | 7 | Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | # Default System Prompt for Prismatic Models 15 | SYS_PROMPTS = { 16 | "prismatic": ( 17 | "You are a helpful language and vision assistant. " 18 | "You are able to understand the visual content that the user provides, " 19 | "and assist the user with a variety of tasks using natural language." 20 | ), 21 | "openvla": ( 22 | "You are a helpful language and vision assistant. " 23 | "You are able to understand the visual content that the user provides, " 24 | "and assist the user with a variety of tasks using natural language." 25 | ), 26 | } 27 | 28 | 29 | def format_system_prompt(system_prompt: str) -> str: 30 | return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n" 31 | 32 | 33 | class LLaMa2ChatPromptBuilder(PromptBuilder): 34 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 35 | super().__init__(model_family, system_prompt) 36 | self.system_prompt = format_system_prompt( 37 | SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt 38 | ) 39 | 40 | # LLaMa-2 Specific 41 | self.bos, self.eos = "<s>", "</s>" 42 | 43 | # Get role-specific "wrap" functions 44 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 45 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 46 | 47 | # === `self.prompt` gets built up over multiple turns === 48 | self.prompt, self.turn_count = "", 0 49 | 50 | def add_turn(self, role: str, message: str) -> str: 51 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 52 | message = message.replace("<image>", "").strip() 53 | 54 | # Special Handling for "system" prompt (turn_count == 0) 55 | if self.turn_count == 0: 56 | sys_message = self.wrap_human(self.system_prompt + message) 57 | wrapped_message = sys_message 58 | elif (self.turn_count % 2) == 0: 59 | human_message = self.wrap_human(message) 60 | wrapped_message = human_message 61 | else: 62 | gpt_message = self.wrap_gpt(message) 63 | wrapped_message = gpt_message 64 | 65 | # Update Prompt 66 | self.prompt += wrapped_message 67 | 68 | # Bump Turn Counter 69 | self.turn_count += 1 70 | 71 | # Return "wrapped_message" (effective string added to context) 72 | return wrapped_message 73 | 74 | def get_potential_prompt(self, message: str) -> str: 75 | # Assumes that it's always the user's (human's) turn! 76 | prompt_copy = str(self.prompt) 77 | 78 | # Special Handling for "system" prompt (turn_count == 0) 79 | if self.turn_count == 0: 80 | sys_message = self.wrap_human(self.system_prompt + message) 81 | prompt_copy += sys_message 82 | 83 | else: 84 | human_message = self.wrap_human(message) 85 | prompt_copy += human_message 86 | 87 | return prompt_copy.removeprefix(self.bos).rstrip() 88 | 89 | def get_prompt(self) -> str: 90 | # Remove <s> prefix because it gets auto-inserted by tokenizer!
91 | return self.prompt.removeprefix(self.bos).rstrip() 92 | -------------------------------------------------------------------------------- /prismatic/vla/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and 5 | exports individual functions for clear control flow. 6 | """ 7 | 8 | from pathlib import Path 9 | from typing import Tuple, Type 10 | 11 | from torch.utils.data import Dataset 12 | from transformers import PreTrainedTokenizerBase 13 | 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.models.backbones.vision.base_vision import WrapSequenceImageTransform 17 | from prismatic.util.data_utils import PaddedCollatorForActionPrediction 18 | from prismatic.vla.action_tokenizer import ACTION_TOKENIZERS, ActionTokenizer 19 | from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset 20 | 21 | 22 | def get_vla_dataset_and_collator( 23 | data_root_dir: Path, 24 | data_mix: str, 25 | image_transform: ImageTransform, 26 | tokenizer: PreTrainedTokenizerBase, 27 | prompt_builder_fn: Type[PromptBuilder], 28 | default_image_resolution: Tuple[int, int, int], 29 | padding_side: str = "right", 30 | predict_stop_token: bool = True, 31 | shuffle_buffer_size: int = 100_000, 32 | train: bool = True, 33 | episodic: bool = False, 34 | image_aug: bool = False, 35 | action_tokenizer: str = "action_tokenizer", 36 | future_action_window_size: int = 0, 37 | image_window_size: int = 1, 38 | use_wrist_image: bool = False, 39 | ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: 40 | """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" 41 | 42 | action_tokenizer: ActionTokenizer = ACTION_TOKENIZERS[action_tokenizer](tokenizer) 43 | 44 | # get the future action window needed from the tokenizer 45 | future_action_window_size = max(action_tokenizer.required_future_horizon, future_action_window_size) 46 | 47 | load_camera_views = ("primary", "wrist") if use_wrist_image else ("primary",) 48 | 49 | # get the observation history from the image_transform (only needed if its a WrapSequence transform) 50 | if isinstance(image_transform, WrapSequenceImageTransform): 51 | if use_wrist_image: 52 | # expects groupings of two in image sequence len 53 | assert image_transform.sequence_len % 2 == 0, "With wrist images, image transform must expect 2N images!" 
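# e.g., a WrapSequenceImageTransform expecting 4 images corresponds to an image window of 2 (one primary + one wrist frame per step)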
54 | image_window_size = max(image_transform.sequence_len // 2, image_window_size) 55 | else: 56 | image_window_size = max(image_transform.sequence_len, image_window_size) 57 | 58 | batch_transform = RLDSBatchTransform( 59 | action_tokenizer, 60 | tokenizer, 61 | image_transform, 62 | prompt_builder_fn, 63 | predict_stop_token=predict_stop_token, 64 | image_window_size=image_window_size, 65 | use_wrist_image=use_wrist_image, 66 | ) 67 | collator = PaddedCollatorForActionPrediction( 68 | tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side 69 | ) 70 | 71 | # Build RLDS Iterable Dataset 72 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset 73 | dataset = cls( 74 | data_root_dir, 75 | data_mix, 76 | batch_transform, 77 | resize_resolution=default_image_resolution[1:], 78 | shuffle_buffer_size=shuffle_buffer_size, 79 | train=train, 80 | image_aug=image_aug, 81 | future_action_window_size=future_action_window_size, 82 | image_window_size=image_window_size, 83 | load_camera_views=load_camera_views, 84 | ) 85 | 86 | return dataset, action_tokenizer, collator 87 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/llama2.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama2.py 3 | 4 | Class definition for all LLMs derived from LlamaForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Sequence, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import LlamaForCausalLM 12 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import ( 16 | LLaMa2ChatPromptBuilder, 17 | PromptBuilder, 18 | PurePromptBuilder, 19 | VicunaV15ChatPromptBuilder, 20 | ) 21 | 22 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 23 | # fmt: off 24 | LLAMA2_MODELS = { 25 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 26 | "llama2-7b-pure": { 27 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf" 28 | }, 29 | 30 | "llama2-13b-pure": { 31 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf" 32 | }, 33 | 34 | # === Meta LLaMa-2 Chat Models === 35 | "llama2-7b-chat": { 36 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf" 37 | }, 38 | 39 | "llama2-13b-chat": { 40 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf" 41 | }, 42 | 43 | # === Vicuna v1.5 Chat Models === 44 | "vicuna-v15-7b": { 45 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5" 46 | }, 47 | 48 | "vicuna-v15-13b": { 49 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5" 50 | }, 51 | } 52 | # fmt: on 53 | 54 | 55 | class LLaMa2LLMBackbone(HFCausalLLMBackbone): 56 | def __init__( 57 | self, 58 | llm_backbone_id: str, 59 | llm_max_length: int = 2048, 60 | hf_token: Optional[str] = None, 61 | inference_mode: bool = False, 62 | use_flash_attention_2: bool = True, 63 | ) -> None: 64 | super().__init__( 65 | llm_backbone_id, 66 | llm_max_length=llm_max_length, 67 | hf_token=hf_token, 68 | inference_mode=inference_mode, 69 | use_flash_attention_2=use_flash_attention_2, 70 | **LLAMA2_MODELS[llm_backbone_id], 71 | ) 72 | 73 | # [Special Case] LLaMa-2 
PAD Token Handling --> for clarity, we add an extra token (and resize) 74 | self.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 75 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 76 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 77 | 78 | @property 79 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 80 | if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): 81 | return PurePromptBuilder 82 | 83 | elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): 84 | return LLaMa2ChatPromptBuilder 85 | 86 | elif self.identifier.startswith("vicuna"): 87 | return VicunaV15ChatPromptBuilder 88 | 89 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 90 | 91 | @property 92 | def transformer_layer_cls(self) -> Type[nn.Module]: 93 | return LlamaDecoderLayer 94 | 95 | @property 96 | def half_precision_dtype(self) -> torch.dtype: 97 | """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2.""" 98 | return torch.bfloat16 99 | 100 | @property 101 | def last_layer_finetune_modules(self) -> Sequence[nn.Module]: 102 | return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head) 103 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/obs_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | obs_transforms.py 3 | 4 | Contains observation-level transforms used in the orca data pipeline. 5 | 6 | These transforms operate on the "observation" dictionary, and are applied at a per-frame level. 7 | """ 8 | 9 | from typing import Dict, Tuple, Union 10 | 11 | import dlimp as dl 12 | import tensorflow as tf 13 | from absl import logging 14 | 15 | 16 | # ruff: noqa: B023 17 | def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict: 18 | """Augments images, skipping padding images.""" 19 | image_names = {key[6:] for key in obs if key.startswith("image_")} 20 | 21 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed 22 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image 23 | # name to augmentation dict) 24 | if "augment_order" in augment_kwargs: 25 | augment_kwargs = {name: augment_kwargs for name in image_names} 26 | 27 | for i, name in enumerate(image_names): 28 | if name not in augment_kwargs: 29 | continue 30 | kwargs = augment_kwargs[name] 31 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") 32 | obs[f"image_{name}"] = tf.cond( 33 | obs["pad_mask_dict"][f"image_{name}"], 34 | lambda: dl.transforms.augment_image( 35 | obs[f"image_{name}"], 36 | **kwargs, 37 | seed=seed + i, # augment each image differently 38 | ), 39 | lambda: obs[f"image_{name}"], # skip padding images 40 | ) 41 | 42 | return obs 43 | 44 | 45 | def decode_and_resize( 46 | obs: Dict, 47 | resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], 48 | depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], 49 | ) -> Dict: 50 | """Decodes images and depth images, and then optionally resizes them.""" 51 | image_names = {key[6:] for key in obs if key.startswith("image_")} 52 | depth_names = {key[6:] for key in obs if key.startswith("depth_")} 53 | 54 | if isinstance(resize_size, tuple): 55 | resize_size = {name: resize_size for name in image_names} 56 | if
isinstance(depth_resize_size, tuple): 57 | depth_resize_size = {name: depth_resize_size for name in depth_names} 58 | 59 | for name in image_names: 60 | if name not in resize_size: 61 | logging.warning( 62 | f"No resize_size was provided for image_{name}. This will result in 1x1 " 63 | "padding images, which may cause errors if you mix padding and non-padding images." 64 | ) 65 | image = obs[f"image_{name}"] 66 | if image.dtype == tf.string: 67 | if tf.strings.length(image) == 0: 68 | # this is a padding image 69 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) 70 | else: 71 | image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8) 72 | elif image.dtype != tf.uint8: 73 | raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}") 74 | if name in resize_size: 75 | image = dl.transforms.resize_image(image, size=resize_size[name]) 76 | obs[f"image_{name}"] = image 77 | 78 | for name in depth_names: 79 | if name not in depth_resize_size: 80 | logging.warning( 81 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " 82 | "padding depth images, which may cause errors if you mix padding and non-padding images." 83 | ) 84 | depth = obs[f"depth_{name}"] 85 | 86 | if depth.dtype == tf.string: 87 | if tf.strings.length(depth) == 0: 88 | depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32) 89 | else: 90 | depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0] 91 | elif depth.dtype != tf.float32: 92 | raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}") 93 | 94 | if name in depth_resize_size: 95 | depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name]) 96 | 97 | obs[f"depth_{name}"] = depth 98 | 99 | return obs 100 | -------------------------------------------------------------------------------- /prismatic/models/vlms/base_vlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_vlm.py 3 | 4 | Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, 5 | and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate 6 | from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, 7 | PALI, Fuyu) in the future. 8 | 9 | We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance 10 | (e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), 11 | prefer Protocol definitions instead. 
12 | """ 13 | 14 | from __future__ import annotations 15 | 16 | from abc import ABC, abstractmethod 17 | from pathlib import Path 18 | from typing import Callable, List, Optional 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import GenerationMixin, PretrainedConfig 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from prismatic.models.backbones.llm import LLMBackbone 26 | from prismatic.models.backbones.llm.prompting import PromptBuilder 27 | from prismatic.models.backbones.vision import VisionBackbone 28 | 29 | 30 | # === Abstract Base Class for arbitrary Vision-Language Models === 31 | class VLM(nn.Module, GenerationMixin, ABC): 32 | def __init__( 33 | self, 34 | model_family: str, 35 | model_id: str, 36 | vision_backbone: VisionBackbone, 37 | llm_backbone: LLMBackbone, 38 | enable_mixed_precision_training: bool = True, 39 | ) -> None: 40 | super().__init__() 41 | self.model_family, self.model_id = model_family, model_id 42 | self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone 43 | self.enable_mixed_precision_training = enable_mixed_precision_training 44 | 45 | # Instance Attributes for a generic VLM 46 | self.all_module_keys, self.trainable_module_keys = None, None 47 | 48 | # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === 49 | self.generation_config = self.llm_backbone.llm.generation_config 50 | self.main_input_name = "input_ids" 51 | 52 | @property 53 | def device(self) -> torch.device: 54 | """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" 55 | return next(self.parameters()).device 56 | 57 | @classmethod 58 | @abstractmethod 59 | def from_pretrained( 60 | cls, 61 | pretrained_checkpoint: Path, 62 | model_family: str, 63 | model_id: str, 64 | vision_backbone: VisionBackbone, 65 | llm_backbone: LLMBackbone, 66 | **kwargs: str, 67 | ) -> VLM: ... 68 | 69 | @abstractmethod 70 | def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... 71 | 72 | @abstractmethod 73 | def freeze_backbones(self, stage: str) -> None: ... 74 | 75 | @abstractmethod 76 | def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... 77 | 78 | @abstractmethod 79 | def get_fsdp_wrapping_policy(self) -> Callable: ... 80 | 81 | @abstractmethod 82 | def forward( 83 | self, 84 | input_ids: Optional[torch.LongTensor] = None, 85 | attention_mask: Optional[torch.Tensor] = None, 86 | pixel_values: Optional[torch.FloatTensor] = None, 87 | labels: Optional[torch.LongTensor] = None, 88 | inputs_embeds: Optional[torch.FloatTensor] = None, 89 | past_key_values: Optional[List[torch.FloatTensor]] = None, 90 | use_cache: Optional[bool] = None, 91 | output_attentions: Optional[bool] = None, 92 | output_hidden_states: Optional[bool] = None, 93 | return_dict: Optional[bool] = None, 94 | multimodal_indices: Optional[torch.LongTensor] = None, 95 | ) -> CausalLMOutputWithPast: ... 
96 | 97 | # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === 98 | @staticmethod 99 | def can_generate() -> bool: 100 | return True 101 | 102 | @property 103 | def config(self) -> PretrainedConfig: 104 | return self.llm_backbone.llm.config 105 | 106 | # => Beam Search Utility 107 | def _reorder_cache(self, past_key_values, beam_idx): 108 | return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) 109 | -------------------------------------------------------------------------------- /experiments/robot/libero/libero_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating policies in LIBERO simulation environments.""" 2 | 3 | import math 4 | import os 5 | 6 | import imageio 7 | import numpy as np 8 | import tensorflow as tf 9 | from libero.libero import get_libero_path 10 | from libero.libero.envs import OffScreenRenderEnv 11 | from PIL import Image 12 | 13 | from experiments.robot.robot_utils import ( 14 | DATE, 15 | DATE_TIME, 16 | ) 17 | 18 | 19 | def get_libero_env(task, model_family, resolution=256): 20 | """Initializes and returns the LIBERO environment, along with the task description.""" 21 | task_description = task.language 22 | task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) 23 | env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} 24 | env = OffScreenRenderEnv(**env_args) 25 | env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state 26 | return env, task_description 27 | 28 | 29 | def get_libero_dummy_action(model_family: str): 30 | """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" 31 | return [0, 0, 0, 0, 0, 0, -1] 32 | 33 | 34 | def resize_image(img, resize_size): 35 | """ 36 | Takes numpy array corresponding to a single image and returns resized image as numpy array. 37 | 38 | NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow 39 | the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training. 
40 | """ 41 | assert isinstance(resize_size, tuple) 42 | # Resize to image size expected by model 43 | img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder 44 | img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back 45 | img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True) 46 | img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8) 47 | img = img.numpy() 48 | return img 49 | 50 | 51 | def get_libero_image(obs, resize_size, key="agentview_image"): 52 | """Extracts image from observations and preprocesses it.""" 53 | assert isinstance(resize_size, int) or isinstance(resize_size, tuple) 54 | if isinstance(resize_size, int): 55 | resize_size = (resize_size, resize_size) 56 | img = obs[key] 57 | img = np.flipud(img) 58 | # img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing 59 | img = Image.fromarray(img) 60 | img = img.resize(resize_size, Image.Resampling.LANCZOS) # resize to size seen at train time 61 | img = img.convert("RGB") 62 | return np.array(img) 63 | 64 | 65 | def save_rollout_video(rollout_images, idx, success, task_description, log_file=None): 66 | """Saves an MP4 replay of an episode.""" 67 | rollout_dir = f"./rollouts/{DATE}" 68 | os.makedirs(rollout_dir, exist_ok=True) 69 | processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] 70 | mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4" 71 | video_writer = imageio.get_writer(mp4_path, fps=30) 72 | for img in rollout_images: 73 | video_writer.append_data(img) 74 | video_writer.close() 75 | print(f"Saved rollout MP4 at path {mp4_path}") 76 | if log_file is not None: 77 | log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") 78 | return mp4_path 79 | 80 | 81 | def quat2axisangle(quat): 82 | """ 83 | Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 84 | 85 | Converts quaternion to axis-angle format. 86 | Returns a unit vector direction scaled by its angle in radians. 87 | 88 | Args: 89 | quat (np.array): (x,y,z,w) vec4 float angles 90 | 91 | Returns: 92 | np.array: (ax,ay,az) axis-angle exponential coordinates 93 | """ 94 | # clip quaternion 95 | if quat[3] > 1.0: 96 | quat[3] = 1.0 97 | elif quat[3] < -1.0: 98 | quat[3] = -1.0 99 | 100 | den = np.sqrt(1.0 - quat[3] * quat[3]) 101 | if math.isclose(den, 0.0): 102 | # This is (close to) a zero degree rotation, immediately return 103 | return np.zeros(3) 104 | 105 | return (quat[:3] * 2.0 * math.acos(quat[3])) / den 106 | -------------------------------------------------------------------------------- /experiments/robot/robot_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating robot policies in various environments.""" 2 | 3 | import math 4 | import os 5 | import random 6 | import time 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from experiments.robot.openvla_utils import ( 12 | get_prismatic_vla, 13 | get_prismatic_vla_action, 14 | get_vla, 15 | get_vla_action, 16 | ) 17 | 18 | # Initialize important constants and pretty-printing mode in NumPy. 
19 | ACTION_DIM = 7 20 | DATE = time.strftime("%Y_%m_%d") 21 | DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") 22 | DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 23 | np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) 24 | 25 | # Initialize system prompt for OpenVLA v0.1. 26 | OPENVLA_V01_SYSTEM_PROMPT = ( 27 | "A chat between a curious user and an artificial intelligence assistant. " 28 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 29 | ) 30 | 31 | 32 | def round_to_n(x, n=1): 33 | # round scalar to n significant figure(s) 34 | return round(x, -int(math.floor(math.log10(abs(x))) + (n - 1))) 35 | 36 | 37 | def hr_name(float_arg, fp=None): 38 | if fp is not None: 39 | float_arg = round_to_n(float_arg, n=fp) 40 | return str(float_arg).replace(".", "_") 41 | 42 | 43 | def set_seed_everywhere(seed: int): 44 | """Sets the random seed for Python, NumPy, and PyTorch functions.""" 45 | torch.manual_seed(seed) 46 | torch.cuda.manual_seed_all(seed) 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | os.environ["PYTHONHASHSEED"] = str(seed) 52 | 53 | 54 | def get_model(cfg, wrap_diffusion_policy_for_droid=False): 55 | """Load model for evaluation.""" 56 | if cfg.model_family == "prismatic": 57 | model = get_prismatic_vla(cfg) 58 | elif cfg.model_family == "openvla": 59 | model = get_vla(cfg) 60 | else: 61 | raise ValueError(f"Unexpected `model_family` found in config ({cfg.model_family}).") 62 | print(f"Loaded model: {type(model)}") 63 | return model 64 | 65 | 66 | def get_image_resize_size(cfg): 67 | """ 68 | Gets image resize size for a model class. 69 | If `resize_size` is an int, then the resized image will be a square. 70 | Else, the image will be a rectangle. 71 | """ 72 | if cfg.model_family == "prismatic": 73 | resize_size = 224 74 | elif cfg.model_family == "openvla": 75 | resize_size = 224 76 | else: 77 | raise ValueError("Unexpected `model_family` found in config.") 78 | return resize_size 79 | 80 | 81 | def get_action(cfg, model, obs, task_label, processor=None): 82 | """Queries the model to get an action.""" 83 | if cfg.model_family == "prismatic": 84 | action = get_prismatic_vla_action( 85 | model, processor, cfg.pretrained_checkpoint, obs, task_label, cfg.unnorm_key, center_crop=cfg.center_crop 86 | ) 87 | assert action.shape == (ACTION_DIM,) 88 | elif cfg.model_family == "openvla": 89 | action = get_vla_action( 90 | model, processor, cfg.pretrained_checkpoint, obs, task_label, cfg.unnorm_key, center_crop=cfg.center_crop 91 | ) 92 | assert action.shape == (ACTION_DIM,) 93 | else: 94 | raise ValueError("Unexpected `model_family` found in config.") 95 | return action 96 | 97 | 98 | def normalize_gripper_action(action, binarize=True): 99 | """ 100 | Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1]. 101 | Necessary for some environments (not Bridge) because the dataset wrapper standardizes gripper actions to [0,1]. 102 | Note that unlike the other action dimensions, the gripper action is not normalized to [-1,+1] by default by 103 | the dataset wrapper. 104 | 105 | Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1 106 | """ 107 | # Just normalize the last action to [-1,+1]. 
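# e.g., a raw gripper value of 0.8 maps to 2 * (0.8 - 0) / (1 - 0) - 1 = 0.6, which then binarizes to +1 below.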
108 | orig_low, orig_high = 0.0, 1.0 109 | action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1 110 | 111 | if binarize: 112 | # Binarize to -1 or +1. 113 | action[..., -1] = np.sign(action[..., -1]) 114 | 115 | return action 116 | 117 | 118 | def invert_gripper_action(action): 119 | """ 120 | Flips the sign of the gripper action (last dimension of action vector). 121 | This is necessary for some environments where -1 = open, +1 = close, since 122 | the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open. 123 | """ 124 | action[..., -1] = action[..., -1] * -1.0 125 | return action 126 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/traj_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | traj_transforms.py 3 | 4 | Contains trajectory transforms used in the orca data pipeline. Trajectory transforms operate on a dictionary 5 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length). 6 | """ 7 | 8 | import logging 9 | from typing import Dict 10 | 11 | import tensorflow as tf 12 | 13 | 14 | def chunk_act_obs(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict: 15 | """ 16 | Chunks actions and observations into the given window_size. 17 | 18 | "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` 19 | observations from the past and the current observation. "action" is given a new axis (at index 1) of size 20 | `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current 21 | action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and 22 | indicates whether an observation should be considered padding (i.e. if it had come from a timestep 23 | before the start of the trajectory). 24 | """ 25 | traj_len = tf.shape(traj["action"])[0] 26 | action_dim = traj["action"].shape[-1] 27 | chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [traj_len, window_size]) + tf.broadcast_to( 28 | tf.range(traj_len)[:, None], [traj_len, window_size] 29 | ) 30 | 31 | action_chunk_indices = tf.broadcast_to( 32 | tf.range(-window_size + 1, 1 + future_action_window_size), 33 | [traj_len, window_size + future_action_window_size], 34 | ) + tf.broadcast_to( 35 | tf.range(traj_len)[:, None], 36 | [traj_len, window_size + future_action_window_size], 37 | ) 38 | 39 | floored_chunk_indices = tf.maximum(chunk_indices, 0) 40 | 41 | if "timestep" in traj["task"]: 42 | goal_timestep = traj["task"]["timestep"] 43 | else: 44 | goal_timestep = tf.fill([traj_len], traj_len - 1) 45 | 46 | floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) 47 | 48 | traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) 49 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) 50 | 51 | # indicates whether an entire observation is padding 52 | traj["observation"]["pad_mask"] = chunk_indices >= 0 53 | 54 | # if no absolute_action_mask was provided, assume all actions are relative 55 | if "absolute_action_mask" not in traj and future_action_window_size > 0: 56 | logging.warning( 57 | "future_action_window_size > 0 but no absolute_action_mask was provided. " 58 | "Assuming all actions are relative for the purpose of making neutral actions." 
59 | ) 60 | absolute_action_mask = traj.get("absolute_action_mask", tf.zeros([traj_len, action_dim], dtype=tf.bool)) 61 | neutral_actions = tf.where( 62 | absolute_action_mask[:, None, :], 63 | traj["action"], # absolute actions are repeated (already done during chunking) 64 | tf.zeros_like(traj["action"]), # relative actions are zeroed 65 | ) 66 | 67 | # actions past the goal timestep become neutral 68 | action_past_goal = action_chunk_indices > goal_timestep[:, None] 69 | traj["action"] = tf.where(action_past_goal[:, :, None], neutral_actions, traj["action"]) 70 | 71 | return traj 72 | 73 | 74 | def subsample(traj: Dict, subsample_length: int) -> Dict: 75 | """Subsamples trajectories to the given length.""" 76 | traj_len = tf.shape(traj["action"])[0] 77 | if traj_len > subsample_length: 78 | indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] 79 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) 80 | 81 | return traj 82 | 83 | 84 | def add_pad_mask_dict(traj: Dict) -> Dict: 85 | """ 86 | Adds a dictionary indicating which elements of the observation/task should be treated as padding. 87 | =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} 88 | """ 89 | traj_len = tf.shape(traj["action"])[0] 90 | 91 | for key in ["observation", "task"]: 92 | pad_mask_dict = {} 93 | for subkey in traj[key]: 94 | # Handles "language_instruction", "image_*", and "depth_*" 95 | if traj[key][subkey].dtype == tf.string: 96 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 97 | 98 | # All other keys should not be treated as padding 99 | else: 100 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) 101 | 102 | traj[key]["pad_mask_dict"] = pad_mask_dict 103 | 104 | return traj 105 | -------------------------------------------------------------------------------- /prismatic/util/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | torch_utils.py 3 | 4 | General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch. 5 | 6 | Random `set_global_seed` functionality is taken directly from PyTorch-Lighting: 7 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py 8 | 9 | This is pretty important to get right if we're every randomly generating our masks (or prefix dropout) inside our 10 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime 11 | we inject randomness from non-PyTorch sources (e.g., numpy, random)! 12 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ 13 | 14 | Terminology 15 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous! 16 | -> Rank :: Integer index of current process in the total world size 17 | -> Local Rank :: Local index on given node in [0, Devices per Node] 18 | """ 19 | 20 | import os 21 | import random 22 | from typing import Callable, Optional 23 | 24 | import numpy as np 25 | import torch 26 | 27 | # === Randomness === 28 | 29 | 30 | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]: 31 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`""" 32 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!" 
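# (NumPy's legacy seeding below only accepts unsigned 32-bit integers, hence the bounds check above.)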
33 | 34 | # Set Seed as an Environment Variable 35 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | 40 | return worker_init_function if get_worker_init_fn else None 41 | 42 | 43 | def worker_init_function(worker_id: int) -> None: 44 | """ 45 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: 46 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 47 | 48 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that 49 | you can run iterative splitting on to get new (predictable) randomness. 50 | 51 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. 52 | """ 53 | # Get current `rank` (if running distributed) and `process_seed` 54 | global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed() 55 | 56 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: 57 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness 58 | base_seed = process_seed - worker_id 59 | 60 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... 61 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) 62 | 63 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! 64 | np.random.seed(seed_seq.generate_state(4)) 65 | 66 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random 67 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2) 68 | 69 | # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 70 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) 71 | 72 | # Use 128 Bits for `random`, but express as integer instead of as an array 73 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() 74 | random.seed(random_seed) 75 | 76 | 77 | # === BFloat16 Support === 78 | 79 | 80 | def check_bloat16_supported() -> bool: 81 | try: 82 | import packaging.version 83 | import torch.cuda.nccl as nccl 84 | import torch.distributed as dist 85 | 86 | return ( 87 | (torch.version.cuda is not None) 88 | and torch.cuda.is_bf16_supported() 89 | and (packaging.version.parse(torch.version.cuda).release >= (11, 0)) 90 | and dist.is_nccl_available() 91 | and (nccl.version() >= (2, 10)) 92 | ) 93 | 94 | except Exception: 95 | return False 96 | 97 | 98 | # === Other helpers === 99 | 100 | 101 | def sequence_combine_call_split(sequence: torch.Tensor, fn: Callable): 102 | # image sequence must be (B, T, ...) 103 | B, T = sequence.shape[:2] 104 | flat_sequence = sequence.reshape([-1, *sequence.shape[2:]]) 105 | # outputs will be (B*T, ...) 106 | flat_outputs = fn(flat_sequence) 107 | return flat_outputs.reshape([B, T, *flat_outputs.shape[1:]]) 108 | 109 | 110 | def merge_two_dims(tensor: torch.Tensor, start_dim: int = 0): 111 | # wrap around 112 | if start_dim < 0: 113 | start_dim = len(tensor.shape) + start_dim 114 | assert start_dim >= 0 115 | # check the next dimension is also within bounds 116 | assert len(tensor.shape) > start_dim + 1, "Start dimension for merge is too big!" 
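# e.g., merging dims 1 and 2 of a (B, T, H, W) tensor with start_dim=1 yields shape (B, T * H, W)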
117 | 118 | # merge the dimension 119 | return tensor.reshape([*tensor.shape[:start_dim], -1, *tensor.shape[start_dim + 2 :]]) 120 | -------------------------------------------------------------------------------- /scripts/generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate.py 3 | 4 | Simple CLI script to interactively test generating from a pretrained VLM; provides a minimal REPL for specify image 5 | URLs, prompts, and language generation parameters. 6 | 7 | Run with: python scripts/generate.py --model_path 8 | """ 9 | 10 | import os 11 | from dataclasses import dataclass 12 | from pathlib import Path 13 | from typing import Union 14 | 15 | import draccus 16 | import requests 17 | import torch 18 | from PIL import Image 19 | 20 | from prismatic import load 21 | from prismatic.overwatch import initialize_overwatch 22 | 23 | # Initialize Overwatch =>> Wraps `logging.Logger` 24 | overwatch = initialize_overwatch(__name__) 25 | 26 | 27 | # Default Image URL (Beignets) 28 | DEFAULT_IMAGE_URL = ( 29 | "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" 30 | ) 31 | 32 | 33 | @dataclass 34 | class GenerateConfig: 35 | # fmt: off 36 | model_path: Union[str, Path] = ( # Path to Pretrained VLM (on disk or HF Hub) 37 | "prism-dinosiglip+7b" 38 | ) 39 | 40 | # HF Hub Credentials (required for Gated Models like LLaMa-2) 41 | hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token 42 | 43 | # Default Generation Parameters =>> subscribes to HuggingFace's GenerateMixIn API 44 | do_sample: bool = False 45 | temperature: float = 1.0 46 | max_new_tokens: int = 512 47 | min_length: int = 1 48 | 49 | # fmt: on 50 | 51 | 52 | @draccus.wrap() 53 | def generate(cfg: GenerateConfig) -> None: 54 | overwatch.info(f"Initializing Generation Playground with Prismatic Model `{cfg.model_path}`") 55 | hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token] 56 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 57 | 58 | # Load the pretrained VLM --> uses default `load()` function 59 | vlm = load(cfg.model_path, hf_token=hf_token) 60 | vlm.to(device, dtype=torch.bfloat16) 61 | 62 | # Initial Setup 63 | image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB") 64 | prompt_builder = vlm.get_prompt_builder() 65 | system_prompt = prompt_builder.system_prompt 66 | 67 | # REPL Welcome Message 68 | print( 69 | "[*] Dropping into Prismatic VLM REPL with Default Generation Setup => Initial Conditions:\n" 70 | f" => Prompt Template:\n\n{prompt_builder.get_potential_prompt('')}\n\n" 71 | f" => Default Image URL: `{DEFAULT_IMAGE_URL}`\n===\n" 72 | ) 73 | 74 | # REPL 75 | repl_prompt = ( 76 | "|=>> Enter (i)mage to fetch image from URL, (p)rompt to update prompt template, (q)uit to exit, or any other" 77 | " key to enter input questions: " 78 | ) 79 | while True: 80 | user_input = input(repl_prompt) 81 | 82 | if user_input.lower().startswith("q"): 83 | print("\n|=>> Received (q)uit signal => Exiting...") 84 | return 85 | 86 | elif user_input.lower().startswith("i"): 87 | # Note => a new image starts a _new_ conversation (for now) 88 | url = input("\n|=>> Enter Image URL: ") 89 | image = Image.open(requests.get(url, stream=True).raw).convert("RGB") 90 | prompt_builder = vlm.get_prompt_builder(system_prompt=system_prompt) 91 | 92 | elif user_input.lower().startswith("p"): 93 | if 
system_prompt is None: 94 | print("\n|=>> Model does not support `system_prompt`!") 95 | continue 96 | 97 | # Note => a new system prompt starts a _new_ conversation 98 | system_prompt = input("\n|=>> Enter New System Prompt: ") 99 | prompt_builder = vlm.get_prompt_builder(system_prompt=system_prompt) 100 | print( 101 | "\n[*] Set New System Prompt:\n" 102 | f" => Prompt Template:\n{prompt_builder.get_potential_prompt('')}\n\n" 103 | ) 104 | 105 | else: 106 | print("\n[*] Entering Chat Session - CTRL-C to start afresh!\n===\n") 107 | try: 108 | while True: 109 | message = input("|=>> Enter Prompt: ") 110 | 111 | # Build Prompt 112 | prompt_builder.add_turn(role="human", message=message) 113 | prompt_text = prompt_builder.get_prompt() 114 | 115 | # Generate from the VLM 116 | generated_text = vlm.generate( 117 | image, 118 | prompt_text, 119 | do_sample=cfg.do_sample, 120 | temperature=cfg.temperature, 121 | max_new_tokens=cfg.max_new_tokens, 122 | min_length=cfg.min_length, 123 | ) 124 | prompt_builder.add_turn(role="gpt", message=generated_text) 125 | print(f"\t|=>> VLM Response >>> {generated_text}\n") 126 | 127 | except KeyboardInterrupt: 128 | print("\n===\n") 129 | continue 130 | 131 | 132 | if __name__ == "__main__": 133 | generate() 134 | -------------------------------------------------------------------------------- /prismatic/conf/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | datasets.py 3 | 4 | Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant 5 | and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes: 6 | - Dataset Variant (Identifier) --> e.g., "llava-v15" 7 | - Align Stage Dataset Components (annotations, images) 8 | - Finetune Stage Dataset Components (annotations, images) 9 | - Dataset Root Directory (Path) 10 | """ 11 | 12 | import os 13 | from dataclasses import dataclass 14 | from enum import Enum, unique 15 | from pathlib import Path 16 | from typing import Tuple 17 | 18 | from draccus import ChoiceRegistry 19 | 20 | DEFAULT_DATA_ROOT = "/tmp/datasets" 21 | 22 | 23 | @dataclass 24 | class DatasetConfig(ChoiceRegistry): 25 | # fmt: off 26 | dataset_id: str # Unique ID that fully specifies a dataset variant 27 | 28 | # Dataset Components for each Stage in < align | finetune > 29 | align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage 30 | finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage 31 | 32 | dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root 33 | # fmt: on 34 | 35 | 36 | # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models) 37 | @dataclass 38 | class LLaVa_V15_Config(DatasetConfig): 39 | dataset_id: str = "llava-v15" 40 | 41 | align_stage_components: Tuple[Path, Path] = ( 42 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 43 | Path("download/llava-laion-cc-sbu-558k/"), 44 | ) 45 | finetune_stage_components: Tuple[Path, Path] = ( 46 | Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"), 47 | Path("download/llava-v1.5-instruct/"), 48 | ) 49 | dataset_root_dir: Path = Path(os.environ.get("PRISMATIC_DATA_ROOT", DEFAULT_DATA_ROOT)) 50 | 51 | 52 | # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training) 53 | @dataclass 54 | class 
LLaVa_Multimodal_Only_Config(DatasetConfig):
55 |     dataset_id: str = "llava-multimodal"
56 | 
57 |     align_stage_components: Tuple[Path, Path] = (
58 |         Path("download/llava-laion-cc-sbu-558k/chat.json"),
59 |         Path("download/llava-laion-cc-sbu-558k/"),
60 |     )
61 |     finetune_stage_components: Tuple[Path, Path] = (
62 |         Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
63 |         Path("download/llava-v1.5-instruct/"),
64 |     )
65 |     dataset_root_dir: Path = Path(os.environ.get("PRISMATIC_DATA_ROOT", DEFAULT_DATA_ROOT))
66 | 
67 | 
68 | # LLaVa-v15 + LVIS-Instruct-4V
69 | @dataclass
70 | class LLaVa_LVIS4V_Config(DatasetConfig):
71 |     dataset_id: str = "llava-lvis4v"
72 | 
73 |     align_stage_components: Tuple[Path, Path] = (
74 |         Path("download/llava-laion-cc-sbu-558k/chat.json"),
75 |         Path("download/llava-laion-cc-sbu-558k/"),
76 |     )
77 |     finetune_stage_components: Tuple[Path, Path] = (
78 |         Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
79 |         Path("download/llava-v1.5-instruct/"),
80 |     )
81 |     dataset_root_dir: Path = Path(os.environ.get("PRISMATIC_DATA_ROOT", DEFAULT_DATA_ROOT))
82 | 
83 | 
84 | # LLaVa-v15 + LRV-Instruct
85 | @dataclass
86 | class LLaVa_LRV_Config(DatasetConfig):
87 |     dataset_id: str = "llava-lrv"
88 | 
89 |     align_stage_components: Tuple[Path, Path] = (
90 |         Path("download/llava-laion-cc-sbu-558k/chat.json"),
91 |         Path("download/llava-laion-cc-sbu-558k/"),
92 |     )
93 |     finetune_stage_components: Tuple[Path, Path] = (
94 |         Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
95 |         Path("download/llava-v1.5-instruct/"),
96 |     )
97 |     dataset_root_dir: Path = Path(os.environ.get("PRISMATIC_DATA_ROOT", DEFAULT_DATA_ROOT))
98 | 
99 | 
100 | # LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
101 | @dataclass
102 | class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
103 |     dataset_id: str = "llava-lvis4v-lrv"
104 | 
105 |     align_stage_components: Tuple[Path, Path] = (
106 |         Path("download/llava-laion-cc-sbu-558k/chat.json"),
107 |         Path("download/llava-laion-cc-sbu-558k/"),
108 |     )
109 |     finetune_stage_components: Tuple[Path, Path] = (
110 |         Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
111 |         Path("download/llava-v1.5-instruct/"),
112 |     )
113 |     dataset_root_dir: Path = Path(os.environ.get("PRISMATIC_DATA_ROOT", DEFAULT_DATA_ROOT))
114 | 
115 | 
116 | # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here!
=== 117 | @unique 118 | class DatasetRegistry(Enum): 119 | # === LLaVa v1.5 === 120 | LLAVA_V15 = LLaVa_V15_Config 121 | 122 | LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config 123 | 124 | LLAVA_LVIS4V = LLaVa_LVIS4V_Config 125 | LLAVA_LRV = LLaVa_LRV_Config 126 | 127 | LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config 128 | 129 | @property 130 | def dataset_id(self) -> str: 131 | return self.value.dataset_id 132 | 133 | 134 | # Register Datasets in Choice Registry 135 | for dataset_variant in DatasetRegistry: 136 | DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value) 137 | -------------------------------------------------------------------------------- /experiments/robot/bridge/bridgev2_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating policies in real-world BridgeData V2 environments.""" 2 | 3 | import os 4 | import sys 5 | import time 6 | 7 | import imageio 8 | import numpy as np 9 | import tensorflow as tf 10 | import torch 11 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs 12 | 13 | sys.path.append(".") 14 | from experiments.robot.bridge.widowx_env import WidowXGym 15 | 16 | # Initialize important constants and pretty-printing mode in NumPy. 17 | ACTION_DIM = 7 18 | BRIDGE_PROPRIO_DIM = 7 19 | DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") 20 | DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 21 | np.set_printoptions(formatter={"float": lambda x: "{0:0.2f}".format(x)}) 22 | 23 | 24 | def get_widowx_env_params(cfg): 25 | """Gets (mostly default) environment parameters for the WidowX environment.""" 26 | env_params = WidowXConfigs.DefaultEnvParams.copy() 27 | env_params["override_workspace_boundaries"] = cfg.bounds 28 | env_params["camera_topics"] = cfg.camera_topics 29 | env_params["return_full_image"] = True 30 | return env_params 31 | 32 | 33 | def get_widowx_env(cfg, model=None): 34 | """Get WidowX control environment.""" 35 | # Set up the WidowX environment parameters 36 | env_params = get_widowx_env_params(cfg) 37 | start_state = np.concatenate([cfg.init_ee_pos, cfg.init_ee_quat]) 38 | env_params["start_state"] = list(start_state) 39 | # Set up the WidowX client 40 | widowx_client = WidowXClient(host=cfg.host_ip, port=cfg.port) 41 | widowx_client.init(env_params) 42 | env = WidowXGym( 43 | widowx_client, 44 | cfg=cfg, 45 | blocking=cfg.blocking, 46 | ) 47 | return env 48 | 49 | 50 | def get_next_task_label(task_label): 51 | """Prompt the user to input the next task.""" 52 | if task_label == "": 53 | user_input = "" 54 | while user_input == "": 55 | user_input = input("Enter the task name: ") 56 | task_label = user_input 57 | else: 58 | user_input = input("Enter the task name (or leave blank to repeat the previous task): ") 59 | if user_input == "": 60 | pass # Do nothing -> Let task_label be the same 61 | else: 62 | task_label = user_input 63 | print(f"Task: {task_label}") 64 | return task_label 65 | 66 | 67 | def save_rollout_video(rollout_images, idx): 68 | """Saves an MP4 replay of an episode.""" 69 | os.makedirs("./rollouts", exist_ok=True) 70 | mp4_path = f"./rollouts/rollout-{DATE_TIME}-{idx+1}.mp4" 71 | video_writer = imageio.get_writer(mp4_path, fps=5) 72 | for img in rollout_images: 73 | video_writer.append_data(img) 74 | video_writer.close() 75 | print(f"Saved rollout MP4 at path {mp4_path}") 76 | 77 | 78 | def save_rollout_data(rollout_orig_images, rollout_images, rollout_states, rollout_actions, idx): 79 | """ 80 | 
Saves rollout data from an episode. 81 | 82 | Args: 83 | rollout_orig_images (list): Original rollout images (before preprocessing). 84 | rollout_images (list): Preprocessed images. 85 | rollout_states (list): Proprioceptive states. 86 | rollout_actions (list): Predicted actions. 87 | idx (int): Episode index. 88 | """ 89 | os.makedirs("./rollouts", exist_ok=True) 90 | path = f"./rollouts/rollout-{DATE_TIME}-{idx+1}.npz" 91 | # Convert lists to numpy arrays 92 | orig_images_array = np.array(rollout_orig_images) 93 | images_array = np.array(rollout_images) 94 | states_array = np.array(rollout_states) 95 | actions_array = np.array(rollout_actions) 96 | # Save to a single .npz file 97 | np.savez(path, orig_images=orig_images_array, images=images_array, states=states_array, actions=actions_array) 98 | print(f"Saved rollout data at path {path}") 99 | 100 | 101 | def resize_image(img, resize_size): 102 | """ 103 | Takes numpy array corresponding to a single image and returns resized image as numpy array. 104 | 105 | NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow 106 | the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training. 107 | """ 108 | assert isinstance(resize_size, tuple) 109 | # Resize to image size expected by model 110 | img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder 111 | img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back 112 | img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True) 113 | img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8) 114 | img = img.numpy() 115 | return img 116 | 117 | 118 | def get_preprocessed_image(obs, resize_size): 119 | """Extracts image from observations and preprocesses it.""" 120 | assert isinstance(resize_size, int) or isinstance(resize_size, tuple) 121 | if isinstance(resize_size, int): 122 | resize_size = (resize_size, resize_size) 123 | obs["full_image"] = resize_image(obs["full_image"], resize_size) 124 | return obs["full_image"] 125 | 126 | 127 | def refresh_obs(obs, env): 128 | """Fetches new observations from the environment and updates the current observations.""" 129 | new_obs = env.get_observation() 130 | obs["full_image"] = new_obs["full_image"] 131 | obs["image_primary"] = new_obs["image_primary"] 132 | obs["proprio"] = new_obs["proprio"] 133 | return obs 134 | -------------------------------------------------------------------------------- /prismatic/overwatch/overwatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | overwatch.py 3 | 4 | Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. 
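Example usage (illustrative sketch, mirroring how other modules in this repository initialize their logger):

    from prismatic.overwatch import initialize_overwatch

    overwatch = initialize_overwatch(__name__)
    overwatch.info("Top-level status message")
    overwatch.info("Nested detail", ctx_level=1)  # `ctx_level` selects the `[*]` vs. nested `|=>` prefix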
5 | """ 6 | 7 | import logging 8 | import logging.config 9 | import os 10 | from contextlib import nullcontext 11 | from logging import LoggerAdapter 12 | from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union 13 | 14 | # Overwatch Default Format String 15 | RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" 16 | 17 | # Set Logging Configuration 18 | LOG_CONFIG = { 19 | "version": 1, 20 | "disable_existing_loggers": True, 21 | "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, 22 | "handlers": { 23 | "console": { 24 | "class": "rich.logging.RichHandler", 25 | "formatter": "simple-console", 26 | "markup": True, 27 | "rich_tracebacks": True, 28 | "show_level": True, 29 | "show_path": True, 30 | "show_time": True, 31 | } 32 | }, 33 | "root": {"level": "INFO", "handlers": ["console"]}, 34 | } 35 | logging.config.dictConfig(LOG_CONFIG) 36 | 37 | 38 | # === Custom Contextual Logging Logic === 39 | class ContextAdapter(LoggerAdapter): 40 | CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} 41 | 42 | def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: 43 | ctx_level = kwargs.pop("ctx_level", 0) 44 | return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs 45 | 46 | 47 | class DistributedOverwatch: 48 | def __init__(self, name: str) -> None: 49 | """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" 50 | from accelerate import PartialState 51 | 52 | # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` 53 | # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! 54 | self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() 55 | 56 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 57 | self.debug = self.logger.debug 58 | self.info = self.logger.info 59 | self.warning = self.logger.warning 60 | self.error = self.logger.error 61 | self.critical = self.logger.critical 62 | 63 | # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
64 | self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) 65 | 66 | @property 67 | def rank_zero_only(self) -> Callable[..., Any]: 68 | return self.distributed_state.on_main_process 69 | 70 | @property 71 | def local_zero_only(self) -> Callable[..., Any]: 72 | return self.distributed_state.on_local_main_process 73 | 74 | @property 75 | def rank_zero_first(self) -> Callable[..., Any]: 76 | return self.distributed_state.main_process_first 77 | 78 | @property 79 | def local_zero_first(self) -> Callable[..., Any]: 80 | return self.distributed_state.local_main_process_first 81 | 82 | def is_rank_zero(self) -> bool: 83 | return self.distributed_state.is_main_process 84 | 85 | def rank(self) -> int: 86 | return self.distributed_state.process_index 87 | 88 | def local_rank(self) -> int: 89 | return self.distributed_state.local_process_index 90 | 91 | def world_size(self) -> int: 92 | return self.distributed_state.num_processes 93 | 94 | 95 | class PureOverwatch: 96 | def __init__(self, name: str) -> None: 97 | """Initializer for an Overwatch object that just wraps logging.""" 98 | self.logger = ContextAdapter(logging.getLogger(name), extra={}) 99 | 100 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 101 | self.debug = self.logger.debug 102 | self.info = self.logger.info 103 | self.warning = self.logger.warning 104 | self.error = self.logger.error 105 | self.critical = self.logger.critical 106 | 107 | # Logging Defaults =>> INFO 108 | self.logger.setLevel(logging.INFO) 109 | 110 | @staticmethod 111 | def get_identity_ctx() -> Callable[..., Any]: 112 | def identity(fn: Callable[..., Any]) -> Callable[..., Any]: 113 | return fn 114 | 115 | return identity 116 | 117 | @property 118 | def rank_zero_only(self) -> Callable[..., Any]: 119 | return self.get_identity_ctx() 120 | 121 | @property 122 | def local_zero_only(self) -> Callable[..., Any]: 123 | return self.get_identity_ctx() 124 | 125 | @property 126 | def rank_zero_first(self) -> Callable[..., Any]: 127 | return nullcontext 128 | 129 | @property 130 | def local_zero_first(self) -> Callable[..., Any]: 131 | return nullcontext 132 | 133 | @staticmethod 134 | def is_rank_zero() -> bool: 135 | return True 136 | 137 | @staticmethod 138 | def rank() -> int: 139 | return 0 140 | 141 | @staticmethod 142 | def world_size() -> int: 143 | return 1 144 | 145 | 146 | def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: 147 | return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) 148 | -------------------------------------------------------------------------------- /scripts/extern/verify_prismatic.py: -------------------------------------------------------------------------------- 1 | """ 2 | verify_prismatic.py 3 | 4 | Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate(). 
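Run with: python scripts/extern/verify_prismatic.py

Note: the active loading path below uses bfloat16 + Flash-Attention 2 (requires `pip install flash-attn`); the
autocast, native-bf16, and 8-bit / 4-bit quantization paths are kept as commented-out alternatives (the quantized
paths additionally need `bitsandbytes` and a `BitsAndBytesConfig` import from `transformers`).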
5 | """ 6 | 7 | import time 8 | 9 | import requests 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoModelForVision2Seq, AutoProcessor 13 | 14 | # === Verification Arguments === 15 | MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b" 16 | DEFAULT_IMAGE_URL = ( 17 | "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" 18 | ) 19 | 20 | if "-prism-" in MODEL_PATH: 21 | SAMPLE_PROMPTS_FOR_GENERATION = [ 22 | "In: What is sitting in the coffee?\nOut:", 23 | "In: What's the name of the food on the plate?\nOut:", 24 | "In: caption.\nOut:", 25 | "In: how many beinets..?\nOut:", 26 | "In: Can you give me a lyrical description of the scene\nOut:", 27 | ] 28 | else: 29 | SYSTEM_PROMPT = ( 30 | "A chat between a curious user and an artificial intelligence assistant. " 31 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 32 | ) 33 | SAMPLE_PROMPTS_FOR_GENERATION = [ 34 | f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:", 35 | f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:", 36 | f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:", 37 | f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:", 38 | f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:", 39 | ] 40 | 41 | 42 | @torch.inference_mode() 43 | def verify_prismatic() -> None: 44 | print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`") 45 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 46 | 47 | # Load Processor & VLM 48 | print("[*] Instantiating Processor and Pretrained VLM") 49 | processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) 50 | 51 | # === AUTOCAST MODE === 52 | # print("[*] Loading in BF16 Autocast Mode") 53 | # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to( 54 | # device, dtype=torch.bfloat16 55 | # ) 56 | 57 | # === NATIVE BFLOAT16 MODE === 58 | # print("[*] Loading in BF16") 59 | # vlm = AutoModelForVision2Seq.from_pretrained( 60 | # MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True 61 | # ).to(device) 62 | 63 | # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] === 64 | print("[*] Loading in BF16 with Flash-Attention Enabled") 65 | vlm = AutoModelForVision2Seq.from_pretrained( 66 | MODEL_PATH, 67 | attn_implementation="flash_attention_2", 68 | torch_dtype=torch.bfloat16, 69 | low_cpu_mem_usage=True, 70 | trust_remote_code=True, 71 | ).to(device) 72 | 73 | # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === 74 | # print("[*] Loading in 8-Bit Quantization Mode") 75 | # vlm = AutoModelForVision2Seq.from_pretrained( 76 | # MODEL_PATH, 77 | # attn_implementation="flash_attention_2", 78 | # torch_dtype=torch.float16, 79 | # quantization_config=BitsAndBytesConfig(load_in_8bit=True), 80 | # low_cpu_mem_usage=True, 81 | # trust_remote_code=True, 82 | # ) 83 | 84 | # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === 85 | # print("[*] Loading in 4-Bit Quantization Mode") 86 | # vlm = AutoModelForVision2Seq.from_pretrained( 87 | # MODEL_PATH, 88 | # attn_implementation="flash_attention_2", 89 | # torch_dtype=torch.float16, 90 | # quantization_config=BitsAndBytesConfig(load_in_4bit=True), 91 | # low_cpu_mem_usage=True, 92 | 
# trust_remote_code=True, 93 | # ) 94 | 95 | # Iterate over Sample Prompts =>> Generate 96 | image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB") 97 | num_tokens, total_time = 0, 0.0 98 | 99 | print("[*] Iterating over Sample Prompts\n===\n") 100 | for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION): 101 | # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) === 102 | # inputs = processor(prompt, image).to(device) 103 | # 104 | # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py` 105 | # # =>> Running in native BF16 is also fine (but leads to slightly different generations) 106 | # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): 107 | # gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512) 108 | 109 | # === BFLOAT16 MODE === 110 | inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) 111 | 112 | # === 8-BIT/4-BIT QUANTIZATION MODE === 113 | # inputs = processor(prompt, image).to(device, dtype=torch.float16) 114 | 115 | # Run Inference 116 | gen_ids = None 117 | for _ in range(5): 118 | start_time = time.time() 119 | gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512) 120 | total_time += time.time() - start_time 121 | 122 | gen_ids = gen_ids[0, inputs.input_ids.shape[1] :] 123 | num_tokens += len(gen_ids) 124 | 125 | # === 126 | gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip() 127 | print(f"[{idx + 1}] Input Prompt => {prompt}\n Generated => {gen_text}\n") 128 | 129 | # Compute Tokens / Second 130 | print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }") 131 | 132 | 133 | if __name__ == "__main__": 134 | verify_prismatic() 135 | -------------------------------------------------------------------------------- /prismatic/models/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports 5 | individual functions for clear control flow. 
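Example (illustrative sketch; the backbone IDs must be keys of the registries defined below, the `model_id` and
`arch_specifier` strings are placeholders, and `image_sequence_len=1` assumes a single-frame input):

    vision_backbone, image_transform = get_vision_backbone_and_transform(
        "siglip-vit-so400m", image_resize_strategy="letterbox", image_sequence_len=1
    )
    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer("vicuna-v15-7b", llm_max_length=2048)
    vlm = get_vlm("siglip-224px+7b", "no-align+gelu-mlp", vision_backbone, llm_backbone)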
6 | """ 7 | 8 | from typing import Optional, Tuple 9 | 10 | from transformers import PreTrainedTokenizerBase 11 | 12 | from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone 13 | from prismatic.models.backbones.llm.qwen25 import Qwen25LLMBackbone 14 | from prismatic.models.backbones.vision import ( 15 | CLIPViTBackbone, 16 | DinoCLIPViTBackbone, 17 | DinoSigLIPViTBackbone, 18 | DinoV2ViTBackbone, 19 | ImageTransform, 20 | IN1KViTBackbone, 21 | SigLIPViTBackbone, 22 | VisionBackbone, 23 | ) 24 | from prismatic.models.vlms import PrismaticVLM 25 | 26 | # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === 27 | # fmt: off 28 | 29 | # === Vision Backbone Registry === 30 | VISION_BACKBONES = { 31 | # === 224px Backbones === 32 | "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 33 | "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 34 | "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, 35 | "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, 36 | "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 37 | 38 | # === Assorted CLIP Backbones === 39 | "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 40 | "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 41 | 42 | # === Assorted SigLIP Backbones === 43 | "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 44 | "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, 45 | "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 46 | "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 47 | 48 | # === Fused Backbones === 49 | "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 50 | "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 51 | } 52 | 53 | 54 | # === Language Model Registry === 55 | LLM_BACKBONES = { 56 | # === LLaMa-2 Pure (Non-Chat) Backbones === 57 | "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 58 | "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 59 | 60 | # === LLaMa-2 Chat Backbones === 61 | "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 62 | "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 63 | 64 | # === Vicuna-v1.5 Backbones === 65 | "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 66 | "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 67 | 68 | # === Mistral v0.1 Backbones === 69 | "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}}, 70 | "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}}, 71 | 72 | # === Phi-2 Backbone === 73 | "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}, 74 | 75 | # === Qwen2.5 Backbone === 76 | "qwen25-0_5b-pure": {"cls": Qwen25LLMBackbone, "kwargs": {}}, 77 | "qwen25-0_5b-extra": {"cls": Qwen25LLMBackbone, "kwargs": {"num_extra_tokens": 256}}, 78 | "qwen25-1_5b-pure": {"cls": Qwen25LLMBackbone, "kwargs": {}}, 79 | "qwen25-3b-pure": {"cls": Qwen25LLMBackbone, "kwargs": {}}, 80 | } 81 | 82 | # fmt: on 83 | 84 | 85 | def get_vision_backbone_and_transform( 86 | vision_backbone_id: str, 87 | image_resize_strategy: str, 88 | 
image_sequence_len: int, 89 | ) -> Tuple[VisionBackbone, ImageTransform]: 90 | """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" 91 | if vision_backbone_id in VISION_BACKBONES: 92 | vision_cfg = VISION_BACKBONES[vision_backbone_id] 93 | vision_backbone: VisionBackbone = vision_cfg["cls"]( 94 | vision_backbone_id, image_resize_strategy, image_sequence_len=image_sequence_len, **vision_cfg["kwargs"] 95 | ) 96 | image_transform = vision_backbone.get_image_transform() 97 | return vision_backbone, image_transform 98 | 99 | else: 100 | raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") 101 | 102 | 103 | def get_llm_backbone_and_tokenizer( 104 | llm_backbone_id: str, 105 | llm_max_length: int = 2048, 106 | hf_token: Optional[str] = None, 107 | inference_mode: bool = False, 108 | ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: 109 | if llm_backbone_id in LLM_BACKBONES: 110 | llm_cfg = LLM_BACKBONES[llm_backbone_id] 111 | llm_backbone: LLMBackbone = llm_cfg["cls"]( 112 | llm_backbone_id, 113 | llm_max_length=llm_max_length, 114 | hf_token=hf_token, 115 | inference_mode=inference_mode, 116 | **llm_cfg["kwargs"], 117 | ) 118 | tokenizer = llm_backbone.get_tokenizer() 119 | return llm_backbone, tokenizer 120 | 121 | else: 122 | raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") 123 | 124 | 125 | def get_vlm( 126 | model_id: str, 127 | arch_specifier: str, 128 | vision_backbone: VisionBackbone, 129 | llm_backbone: LLMBackbone, 130 | enable_mixed_precision_training: bool = True, 131 | ) -> PrismaticVLM: 132 | """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" 133 | return PrismaticVLM( 134 | model_id, 135 | vision_backbone, 136 | llm_backbone, 137 | enable_mixed_precision_training=enable_mixed_precision_training, 138 | arch_specifier=arch_specifier, 139 | ) 140 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for 5 | clear control flow. 
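Example (illustrative sketch; dataset names and weights are placeholders -- in practice the `mixture_spec`
usually comes from `OXE_NAMED_MIXTURES`):

    per_dataset_kwargs, sampling_weights = get_oxe_dataset_kwargs_and_weights(
        data_root_dir=Path("/data/oxe"),
        mixture_spec=[("bridge_dataset", 1.0), ("fractal20220817_data", 0.5)],
        load_camera_views=("primary",),
    )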
6 | """ 7 | 8 | from copy import deepcopy 9 | from pathlib import Path 10 | from typing import Any, Dict, List, Tuple 11 | 12 | from prismatic.overwatch import initialize_overwatch 13 | from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding 14 | from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS 15 | from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType 16 | 17 | # Initialize Overwatch =>> Wraps `logging.Logger` 18 | overwatch = initialize_overwatch(__name__) 19 | 20 | 21 | def make_oxe_dataset_kwargs( 22 | dataset_name: str, 23 | data_root_dir: Path, 24 | load_camera_views: Tuple[str] = ("primary",), 25 | load_depth: bool = False, 26 | load_proprio: bool = True, 27 | load_language: bool = True, 28 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 29 | ) -> Dict[str, Any]: 30 | """Generates config (kwargs) for given dataset from Open-X Embodiment.""" 31 | dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name]) 32 | if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6]: 33 | raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 actions supported!") 34 | 35 | # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute! 36 | # Normalize all action dimensions *except* the gripper 37 | if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: 38 | dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] 39 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] 40 | elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6: 41 | dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True] 42 | dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False] 43 | dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type 44 | 45 | # Adjust Loaded Camera Views 46 | if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0: 47 | raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`") 48 | 49 | # Filter 50 | dataset_kwargs["image_obs_keys"] = { 51 | k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views 52 | } 53 | dataset_kwargs["depth_obs_keys"] = { 54 | k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views 55 | } 56 | 57 | # Eliminate Unnecessary Keys 58 | dataset_kwargs.pop("state_encoding") 59 | dataset_kwargs.pop("action_encoding") 60 | if not load_depth: 61 | dataset_kwargs.pop("depth_obs_keys") 62 | if not load_proprio: 63 | dataset_kwargs.pop("state_obs_keys") 64 | 65 | # Load Language 66 | if load_language: 67 | dataset_kwargs["language_key"] = "language_instruction" 68 | 69 | # Specify Standardization Transform 70 | dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[dataset_name] 71 | 72 | # Add any aux arguments 73 | if "aux_kwargs" in dataset_kwargs: 74 | dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs")) 75 | 76 | return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs} 77 | 78 | 79 | def get_oxe_dataset_kwargs_and_weights( 80 | data_root_dir: Path, 81 | mixture_spec: List[Tuple[str, float]], 82 | load_camera_views: Tuple[str] = ("primary",), 83 | load_depth: bool = False, 84 | load_proprio: bool = True, 85 | load_language: bool = True, 86 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 87 | ) -> 
Tuple[Dict[str, Any], List[float]]: 88 | """ 89 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs 90 | (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`. 91 | 92 | :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X) 93 | :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES` 94 | :param load_camera_views: Camera views to load; see `oxe.dataset_configs.py` for available views. 95 | :param load_depth: Load depth information in addition to camera RGB. 96 | :param load_proprio: Load proprioceptive state. 97 | :param load_language: Load language instructions. 98 | :param action_proprio_normalization_type: Normalization scheme to use for proprioceptive actions. 99 | 100 | return: Tuple of (per_dataset_kwargs, sampling_weights) 101 | """ 102 | included_datasets, filtered_mixture_spec = set(), [] 103 | for d_name, d_weight in mixture_spec: 104 | if d_name in included_datasets: 105 | overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`") 106 | continue 107 | 108 | included_datasets.add(d_name) 109 | filtered_mixture_spec.append((d_name, d_weight)) 110 | 111 | # Assemble Dataset Config (kwargs) and Weights 112 | per_dataset_kwargs, sampling_weights = [], [] 113 | for d_name, d_weight in filtered_mixture_spec: 114 | try: 115 | per_dataset_kwargs.append( 116 | make_oxe_dataset_kwargs( 117 | d_name, 118 | data_root_dir, 119 | load_camera_views, 120 | load_depth, 121 | load_proprio, 122 | load_language, 123 | action_proprio_normalization_type, 124 | ) 125 | ) 126 | sampling_weights.append(d_weight) 127 | 128 | except ValueError as e: 129 | overwatch.warning(f"Skipping `{d_name}` due to Error: {e}") 130 | 131 | return per_dataset_kwargs, sampling_weights 132 | -------------------------------------------------------------------------------- /vla-scripts/deploy.py: -------------------------------------------------------------------------------- 1 | """ 2 | deploy.py 3 | 4 | Provide a lightweight server/client implementation for deploying OpenVLA models (through the HF AutoClass API) over a 5 | REST API. This script implements *just* the server, with specific dependencies and instructions below. 6 | 7 | Note that for the *client*, usage just requires numpy/json-numpy, and requests; example usage below! 
8 | 9 | Dependencies: 10 | => Server (runs OpenVLA model on GPU): `pip install uvicorn fastapi json-numpy` 11 | => Client: `pip install requests json-numpy` 12 | 13 | Client (Standalone) Usage (assuming a server running on 0.0.0.0:8000): 14 | 15 | ``` 16 | import requests 17 | import json_numpy 18 | json_numpy.patch() 19 | import numpy as np 20 | 21 | action = requests.post( 22 | "http://0.0.0.0:8000/act", 23 | json={"image": np.zeros((256, 256, 3), dtype=np.uint8), "instruction": "do something"} 24 | ).json() 25 | 26 | Note that if your server is not accessible on the open web, you can use ngrok, or forward ports to your client via ssh: 27 | => `ssh -L 8000:localhost:8000 ssh USER@` 28 | """ 29 | 30 | import os.path 31 | 32 | # ruff: noqa: E402 33 | import json_numpy 34 | 35 | json_numpy.patch() 36 | import json 37 | import logging 38 | import traceback 39 | from dataclasses import dataclass 40 | from pathlib import Path 41 | from typing import Any, Dict, Optional, Union 42 | 43 | import draccus 44 | import torch 45 | import uvicorn 46 | from fastapi import FastAPI 47 | from fastapi.responses import JSONResponse 48 | from PIL import Image 49 | from transformers import AutoModelForVision2Seq, AutoProcessor 50 | 51 | # === Utilities === 52 | SYSTEM_PROMPT = ( 53 | "A chat between a curious user and an artificial intelligence assistant. " 54 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 55 | ) 56 | 57 | 58 | def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str: 59 | if "v01" in openvla_path: 60 | return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT:" 61 | else: 62 | return f"In: What action should the robot take to {instruction.lower()}?\nOut:" 63 | 64 | 65 | # === Server Interface === 66 | class OpenVLAServer: 67 | def __init__(self, openvla_path: Union[str, Path], attn_implementation: Optional[str] = "flash_attention_2") -> Path: 68 | """ 69 | A simple server for OpenVLA models; exposes `/act` to predict an action for a given image + instruction. 70 | => Takes in {"image": np.ndarray, "instruction": str, "unnorm_key": Optional[str]} 71 | => Returns {"action": np.ndarray} 72 | """ 73 | self.openvla_path, self.attn_implementation = openvla_path, attn_implementation 74 | self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 75 | 76 | # Load VLA Model using HF AutoClasses 77 | self.processor = AutoProcessor.from_pretrained(self.openvla_path, trust_remote_code=True) 78 | self.vla = AutoModelForVision2Seq.from_pretrained( 79 | self.openvla_path, 80 | attn_implementation=attn_implementation, 81 | torch_dtype=torch.bfloat16, 82 | low_cpu_mem_usage=True, 83 | trust_remote_code=True, 84 | ).to(self.device) 85 | 86 | # [Hacky] Load Dataset Statistics from Disk (if passing a path to a fine-tuned model) 87 | if os.path.isdir(self.openvla_path): 88 | with open(Path(self.openvla_path) / "dataset_statistics.json", "r") as f: 89 | self.vla.norm_stats = json.load(f) 90 | 91 | def predict_action(self, payload: Dict[str, Any]) -> str: 92 | try: 93 | if double_encode := "encoded" in payload: 94 | # Support cases where `json_numpy` is hard to install, and numpy arrays are "double-encoded" as strings 95 | assert len(payload.keys()) == 1, "Only uses encoded payload!" 
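                # `json_numpy.patch()` (invoked at import time above) swaps in numpy-aware `json.dumps`/`json.loads`,
                # so decoding the nested string here reconstructs the `np.ndarray` image on the server side.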
96 |                 payload = json.loads(payload["encoded"])
97 | 
98 |             # Parse payload components
99 |             image, instruction = payload["image"], payload["instruction"]
100 |             unnorm_key = payload.get("unnorm_key", None)
101 | 
102 |             # Run VLA Inference
103 |             prompt = get_openvla_prompt(instruction, self.openvla_path)
104 |             inputs = self.processor(prompt, Image.fromarray(image).convert("RGB")).to(self.device, dtype=torch.bfloat16)
105 |             action = self.vla.predict_action(**inputs, unnorm_key=unnorm_key, do_sample=False)
106 |             if double_encode:
107 |                 return JSONResponse(json_numpy.dumps(action))
108 |             else:
109 |                 return JSONResponse(action)
110 |         except:  # noqa: E722
111 |             logging.error(traceback.format_exc())
112 |             logging.warning(
113 |                 "Your request threw an error; make sure your request complies with the expected format:\n"
114 |                 "{'image': np.ndarray, 'instruction': str}\n"
115 |                 "You can optionally provide an `unnorm_key: str` to specify the dataset statistics you want to use for "
116 |                 "de-normalizing the output actions."
117 |             )
118 |             return "error"
119 | 
120 |     def run(self, host: str = "0.0.0.0", port: int = 8000) -> None:
121 |         self.app = FastAPI()
122 |         self.app.post("/act")(self.predict_action)
123 |         uvicorn.run(self.app, host=host, port=port)
124 | 
125 | 
126 | @dataclass
127 | class DeployConfig:
128 |     # fmt: off
129 |     openvla_path: Union[str, Path] = "openvla/openvla-7b"    # HF Hub Path (or path to local run directory)
130 | 
131 |     # Server Configuration
132 |     host: str = "0.0.0.0"                                    # Host IP Address
133 |     port: int = 8000                                         # Host Port
134 | 
135 |     # fmt: on
136 | 
137 | 
138 | @draccus.wrap()
139 | def deploy(cfg: DeployConfig) -> None:
140 |     server = OpenVLAServer(cfg.openvla_path)
141 |     server.run(cfg.host, port=cfg.port)
142 | 
143 | 
144 | if __name__ == "__main__":
145 |     deploy()
146 | 
-------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py: --------------------------------------------------------------------------------
1 | """Episode transforms for DROID dataset."""
2 | 
3 | from typing import Any, Dict
4 | 
5 | import tensorflow as tf
6 | import tensorflow_graphics.geometry.transformation as tfg
7 | 
8 | 
9 | def rmat_to_euler(rot_mat):
10 |     return tfg.euler.from_rotation_matrix(rot_mat)
11 | 
12 | 
13 | def euler_to_rmat(euler):
14 |     return tfg.rotation_matrix_3d.from_euler(euler)
15 | 
16 | 
17 | def invert_rmat(rot_mat):
18 |     return tfg.rotation_matrix_3d.inverse(rot_mat)
19 | 
20 | 
21 | def rotmat_to_rot6d(mat):
22 |     """
23 |     Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
24 |     Args:
25 |         mat: rotation matrix
26 | 
27 |     Returns: 6d vector (first two rows of rotation matrix)
28 | 
29 |     """
30 |     r6 = mat[..., :2, :]
31 |     r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
32 |     r6_flat = tf.concat([r6_0, r6_1], axis=-1)
33 |     return r6_flat
34 | 
35 | 
36 | def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
37 |     """
38 |     Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
39 | Args: 40 | velocity: 6d velocity action (3 x translation, 3 x rotation) 41 | wrist_in_robot_frame: 6d pose of the end-effector in robot base frame 42 | 43 | Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) 44 | 45 | """ 46 | R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) 47 | R_frame_inv = invert_rmat(R_frame) 48 | 49 | # world to wrist: dT_pi = R^-1 dT_rbt 50 | vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] 51 | 52 | # world to wrist: dR_pi = R^-1 dR_rbt R 53 | dR = euler_to_rmat(velocity[:, 3:6]) 54 | dR = R_frame_inv @ (dR @ R_frame) 55 | dR_r6 = rotmat_to_rot6d(dR) 56 | return tf.concat([vel_t, dR_r6], axis=-1) 57 | 58 | 59 | def rand_swap_exterior_images(img1, img2): 60 | """ 61 | Randomly swaps the two exterior images (for training with single exterior input). 62 | """ 63 | return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) 64 | 65 | 66 | def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 67 | """ 68 | DROID dataset transformation for actions expressed in *base* frame of the robot. 69 | """ 70 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] 71 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] 72 | 73 | trajectory["action"] = tf.concat( 74 | ( 75 | dt, 76 | dR, 77 | 1 - trajectory["action_dict"]["gripper_position"], 78 | ), 79 | axis=-1, 80 | ) 81 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( 82 | rand_swap_exterior_images( 83 | trajectory["observation"]["exterior_image_1_left"], 84 | trajectory["observation"]["exterior_image_2_left"], 85 | ) 86 | ) 87 | trajectory["observation"]["proprio"] = tf.concat( 88 | ( 89 | trajectory["observation"]["cartesian_position"], 90 | trajectory["observation"]["gripper_position"], 91 | ), 92 | axis=-1, 93 | ) 94 | return trajectory 95 | 96 | 97 | def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 98 | """ 99 | DROID dataset transformation for actions expressed in *wrist* frame of the robot. 100 | """ 101 | wrist_act = velocity_act_to_wrist_frame( 102 | trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] 103 | ) 104 | trajectory["action"] = tf.concat( 105 | ( 106 | wrist_act, 107 | trajectory["action_dict"]["gripper_position"], 108 | ), 109 | axis=-1, 110 | ) 111 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( 112 | rand_swap_exterior_images( 113 | trajectory["observation"]["exterior_image_1_left"], 114 | trajectory["observation"]["exterior_image_2_left"], 115 | ) 116 | ) 117 | trajectory["observation"]["proprio"] = tf.concat( 118 | ( 119 | trajectory["observation"]["cartesian_position"], 120 | trajectory["observation"]["gripper_position"], 121 | ), 122 | axis=-1, 123 | ) 124 | return trajectory 125 | 126 | 127 | def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 128 | """ 129 | DROID dataset transformation for actions expressed in *base* frame of the robot. 
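    The resulting `action` is 7-D: 3 Cartesian velocity terms, 3 rotational velocity terms, and one inverted
    gripper term (`1 - gripper_position`), i.e., the same layout as `droid_baseact_transform` above.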
130 | """ 131 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] 132 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] 133 | trajectory["action"] = tf.concat( 134 | ( 135 | dt, 136 | dR, 137 | 1 - trajectory["action_dict"]["gripper_position"], 138 | ), 139 | axis=-1, 140 | ) 141 | trajectory["observation"]["proprio"] = tf.concat( 142 | ( 143 | trajectory["observation"]["cartesian_position"], 144 | trajectory["observation"]["gripper_position"], 145 | ), 146 | axis=-1, 147 | ) 148 | return trajectory 149 | 150 | 151 | def zero_action_filter(traj: Dict) -> bool: 152 | """ 153 | Filters transitions whose actions are all-0 (only relative actions, no gripper action). 154 | Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". 155 | """ 156 | DROID_Q01 = tf.convert_to_tensor( 157 | [ 158 | -0.7776297926902771, 159 | -0.5803514122962952, 160 | -0.5795090794563293, 161 | -0.6464047729969025, 162 | -0.7041108310222626, 163 | -0.8895104378461838, 164 | ] 165 | ) 166 | DROID_Q99 = tf.convert_to_tensor( 167 | [ 168 | 0.7597932070493698, 169 | 0.5726242214441299, 170 | 0.7351000607013702, 171 | 0.6705610305070877, 172 | 0.6464948207139969, 173 | 0.8897542208433151, 174 | ] 175 | ) 176 | DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 177 | 178 | return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) 179 | -------------------------------------------------------------------------------- /prismatic/extern/hf/configuration_prismatic.py: -------------------------------------------------------------------------------- 1 | """ 2 | configuration_prismatic.py 3 | 4 | HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`. 5 | Default configuration specifies `siglip-224px+7b`. 
6 | """ 7 | 8 | from typing import Any, Dict, List, Optional 9 | 10 | from transformers import PretrainedConfig 11 | from transformers.models.auto import CONFIG_MAPPING 12 | 13 | # === Utilities for Mapping Prismatic names to HF names === 14 | # fmt: off 15 | VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = { 16 | "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224], 17 | 18 | "clip-vit-l-336px": [336], 19 | "siglip-vit-so400m-384px": [384], 20 | 21 | "dinoclip-vit-l-336px": [336, 336], 22 | "dinosiglip-vit-so-224px": [224, 224], 23 | "dinosiglip-vit-so-384px": [384, 384], 24 | } 25 | VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = { 26 | "clip-vit-l": ["vit_large_patch14_clip_224.openai"], 27 | "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"], 28 | 29 | "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"], 30 | "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"], 31 | 32 | "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"], 33 | "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"], 34 | 35 | "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"], 36 | "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"], 37 | "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"], 38 | } 39 | TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = { 40 | "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"], 41 | "dinov2-vit-l": [None], "in1k-vit-l": [None], 42 | "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None], 43 | "dinoclip-vit-l-336px": [None, "quick_gelu"], 44 | "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None] 45 | } 46 | 47 | LLM_BACKBONE_TO_HF_PATH = { 48 | "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf", 49 | "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", 50 | 51 | "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5", 52 | 53 | "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1", 54 | "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", 55 | 56 | "phi-2-3b": "microsoft/phi-2", 57 | } 58 | LLM_BACKBONE_TO_HF_METACLASS = { 59 | "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama", 60 | "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", 61 | 62 | "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral", 63 | 64 | "phi-2-3b": "phi", 65 | } 66 | 67 | VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys()) 68 | VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH) 69 | # fmt: on 70 | 71 | 72 | class PrismaticConfig(PretrainedConfig): 73 | model_type: str = "prismatic" 74 | is_composition: bool = False 75 | 76 | def __init__( 77 | self, 78 | vision_backbone_id: str = "siglip-vit-so400m", 79 | llm_backbone_id: str = "vicuna-v15-7b", 80 | arch_specifier: str = "no-align+gelu-mlp", 81 | use_fused_vision_backbone: Optional[bool] = None, 82 | image_resize_strategy: str = "letterbox", 83 | text_config: Optional[Dict[str, Any]] = None, 84 | llm_max_length: int = 2048, 85 | pad_token_id: int = 32000, 86 | pad_to_multiple_of: int = 64, 87 | output_projector_states: bool = False, 88 | **kwargs: str, 89 | ) -> None: 90 | if vision_backbone_id not in VALID_VISION_BACKBONES: 91 | raise ValueError(f"Vision backbone 
`{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }") 92 | 93 | if llm_backbone_id not in VALID_LLM_BACKBONES: 94 | raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }") 95 | 96 | # Set Prismatic Configuration Fields 97 | self.vision_backbone_id = vision_backbone_id 98 | self.llm_backbone_id = llm_backbone_id 99 | self.arch_specifier = arch_specifier 100 | self.output_projector_states = output_projector_states 101 | 102 | # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing 103 | self.use_fused_vision_backbone = ( 104 | use_fused_vision_backbone 105 | if use_fused_vision_backbone is not None 106 | else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"]) 107 | ) 108 | 109 | self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id] 110 | self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id] 111 | self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id] 112 | self.image_resize_strategy = image_resize_strategy 113 | 114 | self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id] 115 | self.llm_max_length = llm_max_length 116 | self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of 117 | 118 | # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming! 119 | self.text_config = ( 120 | CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config) 121 | if text_config is not None 122 | else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]() 123 | ) 124 | 125 | # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well... 126 | super().__init__(pad_token_id=pad_token_id, **kwargs) 127 | 128 | 129 | class OpenVLAConfig(PrismaticConfig): 130 | model_type: str = "openvla" 131 | 132 | def __init__( 133 | self, 134 | norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None, 135 | n_action_bins: int = 256, 136 | **kwargs: str, 137 | ) -> None: 138 | self.norm_stats, self.n_action_bins = norm_stats, n_action_bins 139 | 140 | super().__init__(**kwargs) 141 | -------------------------------------------------------------------------------- /prismatic/util/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | data_utils.py 3 | 4 | General utilities and classes for facilitating data loading and collation. 
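Example (illustrative sketch; `dataset` and `tokenizer` are assumed to exist, and the collator kwargs mirror the
dataclass fields defined below):

    from torch.utils.data import DataLoader

    collator = PaddedCollatorForLanguageModeling(
        model_max_length=2048,
        pad_token_id=tokenizer.pad_token_id,
        default_image_resolution=(3, 224, 224),
    )
    loader = DataLoader(dataset, batch_size=16, collate_fn=collator)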
5 | """ 6 | 7 | from dataclasses import dataclass 8 | from typing import Callable, Dict, Sequence, Tuple 9 | 10 | import torch 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) 14 | IGNORE_INDEX = -100 15 | 16 | 17 | def tree_map(fn: Callable, tree: dict) -> dict: 18 | """Maps a function over a nested dictionary.""" 19 | return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} 20 | 21 | 22 | def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict: 23 | """Maps a function over a nested dictionary.""" 24 | return { 25 | k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items() 26 | } 27 | 28 | 29 | @dataclass 30 | class PaddedCollatorForLanguageModeling: 31 | model_max_length: int 32 | pad_token_id: int 33 | default_image_resolution: Tuple[int, int, int] 34 | padding_side: str = "right" 35 | pixel_values_dtype: torch.dtype = torch.float32 36 | 37 | def __post_init__(self) -> None: 38 | self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype) 39 | 40 | def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: 41 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 42 | pixel_values = [instance["pixel_values"] for instance in instances] 43 | 44 | # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!) 45 | # => Handle padding via RNN Utils => `pad_sequence` 46 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) 47 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 48 | 49 | # Truncate (if necessary) 50 | input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] 51 | 52 | # Get `attention_mask` by checking for `pad_token_id` 53 | attention_mask = input_ids.ne(self.pad_token_id) 54 | 55 | # === Handle "unimodal" (language-only) vs. 
"multimodal" === 56 | 57 | # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily 58 | multimodal_indices = torch.tensor( 59 | [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long 60 | ) 61 | 62 | # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None 63 | if len(multimodal_indices) == 0: 64 | pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))]) 65 | elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor): 66 | pixel_values = torch.stack( 67 | [ 68 | pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values 69 | for idx in range(len(input_ids)) 70 | ] 71 | ) 72 | elif isinstance(pv_example, dict): 73 | pixel_values = { 74 | k: torch.stack( 75 | [ 76 | pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values 77 | for idx in range(len(input_ids)) 78 | ] 79 | ) 80 | for k in pv_example 81 | } 82 | else: 83 | raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") 84 | 85 | return dict( 86 | pixel_values=pixel_values, 87 | input_ids=input_ids, 88 | attention_mask=attention_mask, 89 | labels=labels, 90 | multimodal_indices=multimodal_indices, 91 | ) 92 | 93 | 94 | @dataclass 95 | class PaddedCollatorForActionPrediction: 96 | model_max_length: int 97 | pad_token_id: int 98 | padding_side: str = "right" 99 | pixel_values_dtype: torch.dtype = torch.float32 100 | 101 | def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: 102 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 103 | pixel_values = [instance["pixel_values"] for instance in instances] 104 | if "dataset_name" in instances[0]: 105 | dataset_names = [instance["dataset_name"] for instance in instances] 106 | else: 107 | dataset_names = None 108 | 109 | # For now, we only support Tokenizers with `padding_side = "right"` during training 110 | # => Handle padding via RNN Utils => `pad_sequence` 111 | assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`" 112 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) 113 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 114 | 115 | # Truncate (if necessary) 116 | input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] 117 | 118 | # Get `attention_mask` by checking for `pad_token_id` 119 | attention_mask = input_ids.ne(self.pad_token_id) 120 | 121 | # [Contract] For VLA Training =>> No "Unimodal" Data! 122 | assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!" 
123 | 124 | # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor] 125 | if isinstance(pixel_values[0], torch.Tensor): 126 | pixel_values = torch.stack(pixel_values) 127 | elif isinstance(pixel_values[0], dict): 128 | pixel_values = { 129 | k: torch.stack([pixel_values[idx][k] for idx in range(len(input_ids))]) for k in pixel_values[0] 130 | } 131 | else: 132 | raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") 133 | 134 | output = dict( 135 | pixel_values=pixel_values, 136 | input_ids=input_ids, 137 | attention_mask=attention_mask, 138 | labels=labels, 139 | ) 140 | if dataset_names is not None: 141 | output["dataset_names"] = dataset_names 142 | return output 143 | -------------------------------------------------------------------------------- /vla-scripts/pretrain_vq.py: -------------------------------------------------------------------------------- 1 | """Pretrain the VQ VAE tokenizer for an arbitrary tfds dataset. 2 | 3 | VQVAE should be installed w/ VQ-Bet instructions: https://github.com/jayLEE0301/vq_bet_official. 4 | This script should work with arbitrary data mixtures as defined in rlds/oxe/mixtures.py. 5 | Works with arbitrary length action chunks (see --future_action_horizon below). 6 | 7 | --- 8 | Example with bridge for 256 bins, ac_dim=7, ac_chunk=8, and num residual rounds = 7: 9 | python vla-scripts/pretrain_vq.py --data_dir $WHERE_IS_BRIDGE \\ 10 | --data_mix bridge_dataset --action_dim 7 --future_action_horizon 7 --vqvae_n_embed 256 11 | 12 | """ 13 | 14 | import argparse 15 | import json 16 | import os 17 | import shutil 18 | from pathlib import Path 19 | 20 | import torch 21 | import tqdm 22 | from vqvae.vqvae import VqVae 23 | 24 | import wandb 25 | from prismatic.vla.action_dataset_materialize import get_vla_action_dataset 26 | 27 | 28 | def main(): 29 | p = argparse.ArgumentParser() 30 | 31 | # dataset arguments 32 | p.add_argument("--data_dir", type=str, required=True, help="Where to look for the tfrecords") 33 | p.add_argument("--data_mix", type=str, required=True, help="The name of the data [mix] to use") 34 | p.add_argument("--save_folder", type=str, default="vq/", help="Folder to save the final vq model (under )") 35 | p.add_argument("--shuffle_buffer_size", type=int, default=256_000) 36 | 37 | # train arguments 38 | p.add_argument("--wandb_project", type=str, default="prismatic-vq-vla") 39 | p.add_argument("--wandb_entity", type=str, default=None) 40 | p.add_argument("--batch_size", type=int, default=1028) 41 | p.add_argument("--epochs", type=int, default=200) 42 | p.add_argument("--save_every_n_epochs", type=int, default=2) 43 | p.add_argument("--device", type=str, default="cuda") 44 | 45 | # residual VQ arguments 46 | p.add_argument("--action_dim", type=int, required=True, help="Action dimension (usually 7)") 47 | p.add_argument("--future_action_horizon", type=int, default=9, help="How many FUTURE actions to include in VQ") 48 | p.add_argument("--n_latent_dims", type=int, default=512, help="Underlying VQ latent dimension") 49 | p.add_argument("--default_image_resolution", type=int, nargs=3, default=[3, 224, 224]) 50 | p.add_argument( 51 | "--vqvae_n_embed", type=int, default=128, help="Number of token options per round, corresponds to binning width." 
52 | ) 53 | p.add_argument( 54 | "--vqvae_groups", 55 | type=int, 56 | default=None, 57 | help="number of residual rounds (i.e., output ac dim), defaults to ac dim", 58 | ) 59 | p.add_argument("--load_dir", type=str, default=None) 60 | p.add_argument("--encoder_loss_multiplier", type=float, default=1.0) 61 | p.add_argument("--act_scale", type=float, default=1.0) 62 | 63 | args = p.parse_args() 64 | 65 | if args.vqvae_groups is None: 66 | args.vqvae_groups = args.action_dim 67 | 68 | exp_name = ( 69 | f"pretrain_vq+mx-{args.data_mix}+fach-{args.future_action_horizon}" 70 | f"+ng-{args.vqvae_groups}+nemb-{args.vqvae_n_embed}+nlatent-{args.n_latent_dims}" 71 | ) 72 | 73 | vla_dataset = get_vla_action_dataset( 74 | args.data_dir, 75 | args.data_mix, 76 | shuffle_buffer_size=args.shuffle_buffer_size, 77 | image_aug=False, 78 | future_action_window_size=args.future_action_horizon, 79 | default_image_resolution=tuple(args.default_image_resolution), 80 | include_images=False, 81 | ) 82 | 83 | vq_config = { 84 | "input_dim_w": args.action_dim, # action dimension 85 | # argparse fields 86 | "input_dim_h": args.future_action_horizon + 1, 87 | "n_latent_dims": args.n_latent_dims, 88 | "vqvae_n_embed": args.vqvae_n_embed, 89 | "vqvae_groups": args.vqvae_groups, 90 | "eval": False, 91 | "device": args.device, 92 | "load_dir": args.load_dir, 93 | "encoder_loss_multiplier": args.encoder_loss_multiplier, 94 | "act_scale": args.act_scale, 95 | } 96 | 97 | vqvae_model = VqVae(**vq_config) 98 | 99 | wandb.init(name=exp_name, project=args.wandb_project, entity=args.wandb_entity, config=vars(args)) 100 | 101 | # make all required directories. 102 | save_path = Path(args.save_folder) / exp_name 103 | save_path.mkdir(parents=True, exist_ok=False) 104 | (save_path / "checkpoints").mkdir() 105 | 106 | # save to experiment 107 | with open(save_path / "config.json", "w") as f: 108 | json.dump(vq_config, f, indent=4) 109 | 110 | train_loader = torch.utils.data.DataLoader( 111 | vla_dataset, 112 | batch_size=args.batch_size, 113 | num_workers=0, 114 | ) 115 | loader_iter = iter(train_loader) 116 | 117 | step_count = 0 118 | 119 | for epoch in tqdm.trange(args.epochs): 120 | for _ in tqdm.trange(len(train_loader)): 121 | batch = next(loader_iter) 122 | 123 | # N T D 124 | act = batch["action"].to(args.device) 125 | 126 | ( 127 | encoder_loss, 128 | vq_loss_state, 129 | vq_code, 130 | vqvae_recon_loss, 131 | ) = vqvae_model.vqvae_update( 132 | act 133 | ) # N T D 134 | 135 | wandb.log({"pretrain/n_different_codes": len(torch.unique(vq_code))}) 136 | wandb.log({"pretrain/n_different_combinations": len(torch.unique(vq_code, dim=0))}) 137 | wandb.log({"pretrain/encoder_loss": encoder_loss}) 138 | wandb.log({"pretrain/vq_loss_state": vq_loss_state}) 139 | wandb.log({"pretrain/vqvae_recon_loss": vqvae_recon_loss}) 140 | 141 | step_count += 1 142 | 143 | if args.save_every_n_epochs > 0 and (epoch + 1) % args.save_every_n_epochs == 0: 144 | print(f"Saving checkpoint after {epoch + 1} epoch(s) and {step_count} steps.") 145 | state_dict = vqvae_model.state_dict() 146 | torch.save(state_dict, os.path.join(save_path, "checkpoints/model.pt")) 147 | shutil.copy( 148 | os.path.join(save_path, "checkpoints/model.pt"), 149 | os.path.join(save_path, f"checkpoints/step-{step_count}-epoch-{epoch + 1}.pt"), 150 | ) 151 | 152 | # SAVE AT THE END 153 | print("Saving last checkpoint...") 154 | torch.save(state_dict, os.path.join(save_path, "checkpoints/model.pt")) 155 | print("Done.") 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 
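# Note: the final `torch.save(state_dict, ...)` in `main()` reuses the `state_dict` captured by the last periodic
# checkpoint, so it assumes the periodic-save branch runs at least once (i.e., `--save_every_n_epochs` is > 0 and
# no larger than `--epochs`); otherwise `state_dict` is never bound.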
160 | -------------------------------------------------------------------------------- /prismatic/models/vlas/openvla.py: -------------------------------------------------------------------------------- 1 | """ 2 | openvla.py 3 | 4 | PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around 5 | discretizing actions with the ActionTokenizer. 6 | """ 7 | 8 | from typing import Dict, List, Optional, Union 9 | 10 | import numpy as np 11 | import torch 12 | from PIL.Image import Image as Img 13 | from transformers import LlamaTokenizerFast 14 | from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast 15 | 16 | from prismatic.models.vlms.prismatic import PrismaticVLM 17 | from prismatic.overwatch import initialize_overwatch 18 | from prismatic.vla.action_tokenizer import ActionTokenizer 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | class OpenVLA(PrismaticVLM): 25 | def __init__( 26 | self, 27 | *args, 28 | norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], 29 | action_tokenizer: ActionTokenizer, 30 | **kwargs, 31 | ) -> None: 32 | super().__init__(*args, **kwargs) 33 | self.norm_stats = norm_stats 34 | self.action_tokenizer = action_tokenizer 35 | 36 | @torch.inference_mode() 37 | def predict_action( 38 | self, image: Union[Img, List[Img]], instruction: str, unnorm_key: Optional[str] = None, **kwargs: str 39 | ) -> np.ndarray: 40 | """ 41 | Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes). 42 | 43 | @param image: PIL Image as [height, width, 3] 44 | @param instruction: Task instruction string 45 | @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model 46 | was trained only on a single dataset, and retrieves those statistics. 47 | 48 | @return Unnormalized (continuous) action vector --> end-effector deltas. 49 | """ 50 | image_transform, tokenizer = self.vision_backbone.get_image_transform(), self.llm_backbone.tokenizer 51 | 52 | # Build VLA Prompt 53 | prompt_builder = self.get_prompt_builder() 54 | prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?") 55 | prompt_text = prompt_builder.get_prompt() 56 | 57 | # Prepare Inputs 58 | input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) 59 | if isinstance(tokenizer, LlamaTokenizerFast): 60 | # If the special empty token ('') does not already appear after the colon (':') token in the prompt 61 | # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time 62 | if not torch.all(input_ids[:, -1] == 29871): 63 | input_ids = torch.cat( 64 | (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 65 | ) 66 | elif isinstance(tokenizer, Qwen2TokenizerFast): 67 | # do nothing here. I think... 
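        # Qwen2's BPE tokenizer has no analogue of Llama's trailing empty token (id 29871), so no
        # adjustment to `input_ids` should be required here.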
68 | pass 69 | else: 70 | raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}") 71 | 72 | # Preprocess Image 73 | pixel_values = image_transform(image) 74 | if isinstance(pixel_values, torch.Tensor): 75 | pixel_values = pixel_values[None, ...].to(self.device) 76 | elif isinstance(pixel_values, dict): 77 | pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} 78 | else: 79 | raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") 80 | 81 | # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` 82 | autocast_dtype = self.llm_backbone.half_precision_dtype 83 | with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): 84 | # fmt: off 85 | generated_ids = super(PrismaticVLM, self).generate( 86 | input_ids=input_ids, # Shape: [1, seq] 87 | pixel_values=pixel_values, # Shape: [1, (opt T,) 3, res, res] or Dict[str, ...] 88 | max_new_tokens=self.get_action_dim(unnorm_key), 89 | **kwargs 90 | ) 91 | # fmt: on 92 | 93 | # Extract predicted action tokens and translate into (normalized) continuous actions 94 | predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :] 95 | normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy()) 96 | 97 | # Un-normalize Actions 98 | action_norm_stats = self.get_action_stats(unnorm_key) 99 | mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) 100 | action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) 101 | actions = np.where( 102 | mask, 103 | 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low, 104 | normalized_actions, 105 | ) 106 | 107 | return actions 108 | 109 | @staticmethod 110 | def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str: 111 | if unnorm_key is None: 112 | assert len(norm_stats) == 1, ( 113 | f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following " 114 | f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}" 115 | ) 116 | unnorm_key = next(iter(norm_stats.keys())) 117 | 118 | # Error Handling 119 | assert ( 120 | unnorm_key in norm_stats 121 | ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}" 122 | 123 | return unnorm_key 124 | 125 | def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: 126 | """Dimensionality of the policy's action space.""" 127 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) 128 | 129 | return len(self.norm_stats[unnorm_key]["action"]["q01"]) 130 | 131 | def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict: 132 | """Dimensionality of the policy's action space.""" 133 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) 134 | 135 | return self.norm_stats[unnorm_key]["action"] 136 | -------------------------------------------------------------------------------- /experiments/robot/bridge/widowx_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | WidowXGym environment definition. 
3 | 4 | Copied over from the Octo eval code linked below, with some modifications: 5 | https://github.com/octo-models/octo/blob/main/examples/envs/widowx_env.py 6 | """ 7 | 8 | import time 9 | from typing import Dict 10 | 11 | import gym 12 | import numpy as np 13 | from pyquaternion import Quaternion 14 | from widowx_envs.widowx_env_service import WidowXClient 15 | 16 | 17 | def state_to_eep(xyz_coor, zangle: float): 18 | """ 19 | Implements the state to end-effector pose function, returning a 4x4 matrix. 20 | Refer to `widowx_controller/widowx_controller.py` in the `bridge_data_robot` codebase. 21 | """ 22 | assert len(xyz_coor) == 3 23 | DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]]) 24 | new_pose = np.eye(4) 25 | new_pose[:3, -1] = xyz_coor 26 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion(matrix=DEFAULT_ROTATION) 27 | new_pose[:3, :3] = new_quat.rotation_matrix 28 | return new_pose 29 | 30 | 31 | def wait_for_obs(widowx_client): 32 | """Fetches an observation from the WidowXClient.""" 33 | obs = widowx_client.get_observation() 34 | while obs is None: 35 | print("Waiting for observations...") 36 | obs = widowx_client.get_observation() 37 | time.sleep(1) 38 | return obs 39 | 40 | 41 | def convert_obs(obs, im_size): 42 | """Preprocesses image and proprio observations.""" 43 | # Preprocess image 44 | image_obs = (obs["image"].reshape(3, im_size, im_size).transpose(1, 2, 0) * 255).astype(np.uint8) 45 | # Add padding to proprio to match RLDS training 46 | proprio = np.concatenate([obs["state"][:6], [0], obs["state"][-1:]]) 47 | return { 48 | "image_primary": image_obs, 49 | "full_image": obs["full_image"], 50 | "proprio": proprio, 51 | } 52 | 53 | 54 | def null_obs(img_size): 55 | """Returns a dummy observation with all-zero image and proprio.""" 56 | return { 57 | "image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8), 58 | "proprio": np.zeros((8,), dtype=np.float64), 59 | } 60 | 61 | 62 | class WidowXGym(gym.Env): 63 | """ 64 | A Gym environment for the WidowX controller provided by: 65 | https://github.com/rail-berkeley/bridge_data_robot 66 | """ 67 | 68 | def __init__( 69 | self, 70 | widowx_client: WidowXClient, 71 | cfg: Dict, 72 | im_size: int = 256, 73 | blocking: bool = True, 74 | ): 75 | self.widowx_client = widowx_client 76 | self.im_size = im_size 77 | self.blocking = blocking 78 | self.observation_space = gym.spaces.Dict( 79 | { 80 | "image_primary": gym.spaces.Box( 81 | low=np.zeros((im_size, im_size, 3)), 82 | high=255 * np.ones((im_size, im_size, 3)), 83 | dtype=np.uint8, 84 | ), 85 | "full_image": gym.spaces.Box( 86 | low=np.zeros((480, 640, 3)), 87 | high=255 * np.ones((480, 640, 3)), 88 | dtype=np.uint8, 89 | ), 90 | "proprio": gym.spaces.Box(low=np.ones((8,)) * -1, high=np.ones((8,)), dtype=np.float64), 91 | } 92 | ) 93 | self.action_space = gym.spaces.Box(low=np.zeros((7,)), high=np.ones((7,)), dtype=np.float64) 94 | self.cfg = cfg 95 | 96 | def step(self, action): 97 | self.widowx_client.step_action(action, blocking=self.blocking) 98 | 99 | raw_obs = self.widowx_client.get_observation() 100 | 101 | truncated = False 102 | if raw_obs is None: 103 | # this indicates a loss of connection with the server 104 | # due to an exception in the last step so end the trajectory 105 | truncated = True 106 | obs = null_obs(self.im_size) # obs with all zeros 107 | else: 108 | obs = convert_obs(raw_obs, self.im_size) 109 | 110 | return obs, 0, False, truncated, {} 111 | 112 | def reset(self, seed=None, 
options=None): 113 | super().reset(seed=seed) 114 | 115 | self.widowx_client.reset() 116 | self.move_to_start_state() 117 | 118 | raw_obs = wait_for_obs(self.widowx_client) 119 | obs = convert_obs(raw_obs, self.im_size) 120 | 121 | return obs, {} 122 | 123 | def get_observation(self): 124 | raw_obs = wait_for_obs(self.widowx_client) 125 | obs = convert_obs(raw_obs, self.im_size) 126 | return obs 127 | 128 | def move_to_start_state(self): 129 | successful = False 130 | while not successful: 131 | try: 132 | # Get XYZ position from user. 133 | init_x, init_y, init_z = self.cfg.init_ee_pos 134 | x_val = input(f"Enter x value of gripper starting position (leave empty for default == {init_x}): ") 135 | if x_val == "": 136 | x_val = init_x 137 | y_val = input(f"Enter y value of gripper starting position (leave empty for default == {init_y}): ") 138 | if y_val == "": 139 | y_val = init_y 140 | z_val = input(f"Enter z value of gripper starting position (leave empty for default == {init_z}): ") 141 | if z_val == "": 142 | z_val = init_z 143 | # Fix initial orientation and add user's commanded XYZ into start transform. 144 | # Initial orientation: gripper points ~15 degrees away from the standard orientation (quat=[0, 0, 0, 1]). 145 | transform = np.array( 146 | [ 147 | [0.267, 0.000, 0.963, float(x_val)], 148 | [0.000, 1.000, 0.000, float(y_val)], 149 | [-0.963, 0.000, 0.267, float(z_val)], 150 | [0.00, 0.00, 0.00, 1.00], 151 | ] 152 | ) 153 | # IMPORTANT: It is very important to move to reset position with blocking==True. 154 | # Otherwise, the controller's `_reset_previous_qpos()` call will be called immediately after 155 | # the move command is given -- and before the move is complete -- and the initial state will 156 | # be totally incorrect. 157 | self.widowx_client.move(transform, duration=0.8, blocking=True) 158 | successful = True 159 | except Exception as e: 160 | print(e) 161 | -------------------------------------------------------------------------------- /scripts/additional-datasets/lrv_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | scripts/additional-datasets/lrv_instruct.py 3 | 4 | Standalone script for pre-processing the LRV-Instruct data (including the chart/diagram reasoning split). This isn't 5 | full conversational chat data, but rather each example has an input prompt and output response; we'll use this structure 6 | to format the data equivalently to the LLaVa-v1.5 dataset. 7 | 8 | In general, LRV Instruct provides *both positive and negative* examples -- where a negative example is a question or 9 | instruction that is *not answerable* or *irrelevant*; the goal of this dataset is to reduce hallucinations in VLMs. 10 | 11 | This script downloads the raw instruct data (three different JSON files), as well as the image files; the non-chart 12 | images come from Visual Genome, but are hosted separately by the LRV Instruct authors and use different image IDs, so 13 | we're downloading this data (again) for simplicity. The chart images come from the LRV Instruct authors, and are sourced 14 | from statista.com. All file URLS are here: https://github.com/FuxiaoLiu/LRV-Instruction/blob/main/download.txt#L20 15 | 16 | Note that we are using the *coordinate-free* data (due to noted inaccuracies in the original coordinates). 
17 | 18 | Make sure to download the images first to `data/download/llava-v1.5-instruct/lrv` 19 | => cd data/download/llava-v1.5-instruct/lrv 20 | => [Visual Genome] gdown https://drive.google.com/uc?id=1k9MNV-ImEV9BYEOeLEIb4uGEUZjd3QbM 21 | => `tar -xvf image.tar.gz; mv image lrv-vg; rm image.tar.gz` 22 | => [Chart Data] gdown https://drive.google.com/uc?id=1Dey-undzW2Nl21CYLFSkP_Y4RrfRJkYd 23 | => `unzip chart_image.zip; rm -rf __MACOSX; mv chart_image lrv-chart; rm chart_image.zip` 24 | 25 | Download the raw JSON files to the same directory - `data/download/llava-v1.5-instruct/lrv` 26 | => [LRV Instruct Pt. 1] gdown https://drive.google.com/uc?id=1pWkxE2kqpys1VdwBi99ZXN6-XY5SqhwU 27 | => `filter_cap1.json` 28 | => [LRV Instruct Pt. II] gdown https://drive.google.com/uc?id=1NTxkuRPlvDn7aWaJpK_yb0p5r0cxPLNZ 29 | => `filter_cap_more1.json` 30 | => [Chart Instruct] gdown https://drive.google.com/uc?id=13j2U-ectsYGR92r6J5hPdhT8T5ezItHF 31 | => `chart_release_update.json` 32 | 33 | References: "Mitigating Hallucination in Large Multi-Modal Models via Robust Instruction Tuning" 34 | => Paper: https://arxiv.org/abs/2306.14565 35 | => Github / Data: https://github.com/FuxiaoLiu/LRV-Instruction 36 | """ 37 | 38 | import json 39 | import random 40 | from pathlib import Path 41 | 42 | from tqdm import tqdm 43 | 44 | # === Constants === 45 | BASE_DIR = Path("data/download/llava-v1.5-instruct") 46 | LRV_DIR = BASE_DIR / "lrv" 47 | 48 | VG_JSON_FILES, VG_IMG_DIR = [LRV_DIR / "filter_cap1.json", LRV_DIR / "filter_cap_more1.json"], LRV_DIR / "lrv-vg" 49 | CHART_JSON_FILE, CHART_IMG_DIR = LRV_DIR / "chart_release_update.json", LRV_DIR / "lrv-chart" 50 | 51 | # JSON Files for "merged" variants fo the dataset (with `llava_v1_5_mix665k.json` and `llava_v1_5_lvis4v_mix888k.json` 52 | BASE_JSON_FILE = BASE_DIR / "llava_v1_5_mix665k.json" 53 | BASE_LVIS_JSON_FILE = BASE_DIR / "llava_v1_5_lvis4v_mix888k.json" 54 | 55 | MERGED_BASE_LRV_JSON_FILE = BASE_DIR / "llava_v1_5_lrv_mix1008k.json" 56 | MERGED_BASE_LVIS_LRV_JSON_FILE = BASE_DIR / "llava_v1_5_lvis4v_lrv_mix1231k.json" 57 | 58 | 59 | def build_lrv_instruct() -> None: 60 | print("[*] Downloading and Formatting `LRV-Instruct` Dataset!") 61 | 62 | # Set Random Seed 63 | random.seed(7) 64 | 65 | # Open VG JSON Files 66 | vg_examples = [] 67 | for fn in VG_JSON_FILES: 68 | with open(fn, "r") as f: 69 | vg_examples.extend(json.load(f)) 70 | 71 | # Iterate through VG Examples & Verify Image Existence 72 | for example in tqdm(vg_examples, desc="[*] Verifying all VG Images in LRV Instruct"): 73 | image_id = example["image_id"] 74 | assert (VG_IMG_DIR / f"{image_id}.jpg").exists(), f"Missing Image `{image_id}.jpg`" 75 | 76 | # Open Chart JSON File 77 | with open(CHART_JSON_FILE, "r") as f: 78 | chart_examples = json.load(f) 79 | 80 | # Iterate through Chart Examples & Verify Image Existence 81 | for example in tqdm(chart_examples, desc="[*] Verifying all Chart Images in LRV Instruct"): 82 | image_path = example["image_id"] 83 | assert (CHART_IMG_DIR / image_path).exists(), f"Missing Image `{image_path}`" 84 | 85 | # Reformat VG Examples as LLaVa "Chat" Style => List[Entry] where each Entry is a Dictionary: 86 | # => "id": str 87 | # => "image": str -- Relative path from `BASE_DIR` 88 | # => "conversations: List[Turn] where Turn is a Dictionary: 89 | # => {"from": "human", "value": "\n{VG_EXAMPLE['question']}"} 90 | # => {"from": "gpt", "value": "{VG_EXAMPLE['answer']}"} 91 | vg_chat_json = [] 92 | for vg_example in tqdm(vg_examples, desc="[*] Converting all VG 
Examples to LLaVa Format"): 93 | vg_chat_json.append( 94 | { 95 | "id": vg_example["image_id"], 96 | "image": f"lrv/lrv-vg/{vg_example['image_id']}.jpg", 97 | "conversations": [ 98 | {"from": "human", "value": f"\n{vg_example['question'].strip()}"}, 99 | {"from": "gpt", "value": vg_example["answer"].strip()}, 100 | ], 101 | } 102 | ) 103 | 104 | # Reformat Chart Examples as LLaVa "Chat" Style 105 | chart_chat_json = [] 106 | for chart_example in tqdm(chart_examples, desc="[*] Converting all Chart Examples to LLaVa Format"): 107 | chart_chat_json.append( 108 | { 109 | "id": Path(chart_example["image_id"]).stem, 110 | "image": f"lrv/lrv-chart/{chart_example['image_id']}", 111 | "conversations": [ 112 | {"from": "human", "value": f"\n{chart_example['question'].strip()}"}, 113 | {"from": "gpt", "value": chart_example["answer"].strip()}, 114 | ], 115 | } 116 | ) 117 | 118 | # Merge and Create Full LRV Chat Data =>> Total of 342,799 Examples 119 | lrv_data = vg_chat_json + chart_chat_json 120 | 121 | # Create Stacked Datasets =>> Shuffle for Good Measure! 122 | print("[*] Loading LLaVa v1.5 Data!") 123 | with open(BASE_JSON_FILE, "r") as f: 124 | llava_v15_data = json.load(f) 125 | 126 | # Combine & Shuffle & Write 127 | llava_lrv_data = llava_v15_data + lrv_data 128 | 129 | random.shuffle(llava_lrv_data) 130 | random.shuffle(llava_lrv_data) 131 | random.shuffle(llava_lrv_data) 132 | 133 | with open(MERGED_BASE_LRV_JSON_FILE, "w") as f: 134 | json.dump(llava_lrv_data, f) 135 | 136 | print("[*] Loading LLaVa v1.5 + LVIS-4V Instruct Data!") 137 | with open(BASE_LVIS_JSON_FILE, "r") as f: 138 | llava_v15_lvis_data = json.load(f) 139 | 140 | # Combine & Shuffle & Write 141 | full_data = llava_v15_lvis_data + lrv_data 142 | 143 | random.shuffle(full_data) 144 | random.shuffle(full_data) 145 | random.shuffle(full_data) 146 | 147 | with open(MERGED_BASE_LVIS_LRV_JSON_FILE, "w") as f: 148 | json.dump(full_data, f) 149 | 150 | 151 | if __name__ == "__main__": 152 | build_lrv_instruct() 153 | -------------------------------------------------------------------------------- /prismatic/training/strategies/ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | ddp.py 3 | 4 | Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most 5 | GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP. 6 | """ 7 | 8 | import shutil 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.optim import AdamW 15 | from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup 16 | 17 | from prismatic.overwatch import initialize_overwatch 18 | from prismatic.training.strategies.base_strategy import TrainingStrategy 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | class DDPStrategy(TrainingStrategy): 25 | @overwatch.rank_zero_only 26 | def save_checkpoint( 27 | self, 28 | run_dir: Path, 29 | global_step: int, 30 | epoch: int, 31 | train_loss: Optional[float] = None, 32 | only_trainable: bool = True, 33 | ) -> None: 34 | """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" 35 | assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" 
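        # The checkpoint written below is a nested dict: {"model": {module_key: state_dict, ...}, "optimizer": ...}.
        # Minimal reload sketch (illustrative; the exact module keys depend on `trainable_module_keys`, and the
        # "projector" key used here is just an example):
        #   ckpt = torch.load(checkpoint_path, map_location="cpu")
        #   self.vlm.module.projector.load_state_dict(ckpt["model"]["projector"])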
36 | 37 | # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) 38 | model_state_dicts = { 39 | mkey: getattr(self.vlm.module, mkey).state_dict() 40 | for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) 41 | } 42 | optimizer_state_dict = self.optimizer.state_dict() 43 | 44 | # Set Checkpoint Path =>> Embed *minimal* training statistics! 45 | checkpoint_dir = run_dir / "checkpoints" 46 | if train_loss is None: 47 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" 48 | else: 49 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" 50 | 51 | # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` 52 | torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) 53 | shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") 54 | 55 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None: 56 | # Gradient Checkpointing Setup 57 | if self.enable_gradient_checkpointing: 58 | # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up 59 | # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF 60 | # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` 61 | # on `self.llm_backbone`. 62 | # 63 | # What does it actually do? --> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic 64 | # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 65 | # 66 | # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) 67 | # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb 68 | overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) 69 | self.vlm.llm_backbone.gradient_checkpointing_enable() 70 | 71 | # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) 72 | overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) 73 | self.vlm.to(self.device_id) 74 | 75 | # Wrap with Distributed Data Parallel 76 | # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that 77 | # is the same size/dtype as the model parameters; this will *double* GPU memory! 78 | # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel 79 | overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) 80 | self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) 81 | 82 | # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` 83 | # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 
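        # Worked example for the step computation just below (illustrative numbers): 665,000 examples x 1 epoch
        # with a global batch size of 128 gives (665_000 * 1) // 128 = 5,195 optimizer steps.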
84 | trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] 85 | if self.max_steps is None: 86 | num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size 87 | else: 88 | num_training_steps = self.max_steps 89 | 90 | if self.lr_scheduler_type == "linear-warmup+cosine-decay": 91 | # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) 92 | num_warmup_steps = int(num_training_steps * self.warmup_ratio) 93 | 94 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 95 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 96 | self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) 97 | for param_group in self.optimizer.param_groups: 98 | param_group["lr"] = 0.0 99 | 100 | elif self.lr_scheduler_type == "constant": 101 | num_warmup_steps = 0 102 | 103 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 104 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 105 | self.lr_scheduler = get_constant_schedule(self.optimizer) 106 | 107 | else: 108 | raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") 109 | 110 | # Finalize Setup =>> Log 111 | overwatch.info( 112 | "DDP Strategy =>> Finalized Training Setup:\n" 113 | f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" 114 | f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" 115 | f" |-> Distributed World Size = {overwatch.world_size()}\n" 116 | f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" 117 | f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" 118 | f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" 119 | f" |-> Default AdamW LR = {self.learning_rate}\n" 120 | f" |-> AdamW Weight Decay = {self.weight_decay}\n" 121 | f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" 122 | f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" 123 | f" |-> Dataset Size = {n_train_examples} Examples\n" 124 | f" |-> Max Steps = {num_training_steps}\n" 125 | ) 126 | 127 | def clip_grad_norm(self) -> None: 128 | torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) 129 | -------------------------------------------------------------------------------- /experiments/robot/bridge/run_bridgev2_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | run_bridgev2_eval.py 3 | 4 | Runs a model in a real-world Bridge V2 environment. 
5 | 6 | Usage: 7 | # OpenVLA: 8 | python experiments/robot/bridge/run_bridgev2_eval.py \\ 9 | --model_family openvla \\ 10 | --pretrained_checkpoint openvla/openvla-7b 11 | """ 12 | 13 | import sys 14 | import time 15 | from dataclasses import dataclass, field 16 | from pathlib import Path 17 | from typing import Dict, List, Union 18 | 19 | import draccus 20 | 21 | # Append current directory so that interpreter can find experiments.robot 22 | sys.path.append(".") 23 | from experiments.robot.bridge.bridgev2_utils import ( 24 | get_next_task_label, 25 | get_preprocessed_image, 26 | get_widowx_env, 27 | refresh_obs, 28 | save_rollout_data, 29 | save_rollout_video, 30 | ) 31 | from experiments.robot.openvla_utils import get_processor 32 | from experiments.robot.robot_utils import ( 33 | get_action, 34 | get_image_resize_size, 35 | get_model, 36 | ) 37 | 38 | 39 | @dataclass 40 | class GenerateConfig: 41 | # fmt: off 42 | 43 | ################################################################################################################# 44 | # Model-specific parameters 45 | ################################################################################################################# 46 | model_family: str = "openvla" # Model family 47 | pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path 48 | load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization 49 | load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization 50 | 51 | center_crop: bool = False # Center crop? (if trained w/ random crop image aug) 52 | 53 | ################################################################################################################# 54 | # WidowX environment-specific parameters 55 | ################################################################################################################# 56 | host_ip: str = "localhost" 57 | port: int = 5556 58 | 59 | # Note: Setting initial orientation with a 30 degree offset, which makes the robot appear more natural 60 | init_ee_pos: List[float] = field(default_factory=lambda: [0.3, -0.09, 0.26]) 61 | init_ee_quat: List[float] = field(default_factory=lambda: [0, -0.259, 0, -0.966]) 62 | bounds: List[List[float]] = field(default_factory=lambda: [ 63 | [0.1, -0.20, -0.01, -1.57, 0], 64 | [0.45, 0.25, 0.30, 1.57, 0], 65 | ] 66 | ) 67 | 68 | camera_topics: List[Dict[str, str]] = field(default_factory=lambda: [{"name": "/blue/image_raw"}]) 69 | 70 | blocking: bool = False # Whether to use blocking control 71 | max_episodes: int = 50 # Max number of episodes to run 72 | max_steps: int = 60 # Max number of timesteps per episode 73 | control_frequency: float = 5 # WidowX control frequency 74 | 75 | ################################################################################################################# 76 | # Utils 77 | ################################################################################################################# 78 | save_data: bool = False # Whether to save rollout data (images, actions, etc.) 79 | 80 | # fmt: on 81 | 82 | 83 | @draccus.wrap() 84 | def eval_model_in_bridge_env(cfg: GenerateConfig) -> None: 85 | assert cfg.pretrained_checkpoint is not None, "cfg.pretrained_checkpoint must not be None!" 86 | assert not cfg.center_crop, "`center_crop` should be disabled for Bridge evaluations!" 
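    # (Center cropping is only appropriate for checkpoints trained with random-crop image augmentation,
    # which the public Bridge V2 checkpoints presumably were not -- hence the assert above.)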
87 | 88 | # [OpenVLA] Set action un-normalization key 89 | cfg.unnorm_key = "bridge_orig" 90 | 91 | # Load model 92 | model = get_model(cfg) 93 | 94 | # [OpenVLA] Get Hugging Face processor 95 | processor = None 96 | if cfg.model_family == "openvla": 97 | processor = get_processor(cfg) 98 | 99 | # Initialize the WidowX environment 100 | env = get_widowx_env(cfg, model) 101 | 102 | # Get expected image dimensions 103 | resize_size = get_image_resize_size(cfg) 104 | 105 | # Start evaluation 106 | task_label = "" 107 | episode_idx = 0 108 | while episode_idx < cfg.max_episodes: 109 | # Get task description from user 110 | task_label = get_next_task_label(task_label) 111 | 112 | # Reset environment 113 | obs, _ = env.reset() 114 | 115 | # Setup 116 | t = 0 117 | step_duration = 1.0 / cfg.control_frequency 118 | replay_images = [] 119 | if cfg.save_data: 120 | rollout_images = [] 121 | rollout_states = [] 122 | rollout_actions = [] 123 | 124 | # Start episode 125 | input(f"Press Enter to start episode {episode_idx+1}...") 126 | print("Starting episode... Press Ctrl-C to terminate episode early!") 127 | last_tstamp = time.time() 128 | while t < cfg.max_steps: 129 | try: 130 | curr_tstamp = time.time() 131 | if curr_tstamp > last_tstamp + step_duration: 132 | print(f"t: {t}") 133 | print(f"Previous step elapsed time (sec): {curr_tstamp - last_tstamp:.2f}") 134 | last_tstamp = time.time() 135 | 136 | # Refresh the camera image and proprioceptive state 137 | obs = refresh_obs(obs, env) 138 | 139 | # Save full (not preprocessed) image for replay video 140 | replay_images.append(obs["full_image"]) 141 | 142 | # Get preprocessed image 143 | obs["full_image"] = get_preprocessed_image(obs, resize_size) 144 | 145 | # Query model to get action 146 | action = get_action( 147 | cfg, 148 | model, 149 | obs, 150 | task_label, 151 | processor=processor, 152 | ) 153 | 154 | # [If saving rollout data] Save preprocessed image, robot state, and action 155 | if cfg.save_data: 156 | rollout_images.append(obs["full_image"]) 157 | rollout_states.append(obs["proprio"]) 158 | rollout_actions.append(action) 159 | 160 | # Execute action 161 | print("action:", action) 162 | obs, _, _, _, _ = env.step(action) 163 | t += 1 164 | 165 | except (KeyboardInterrupt, Exception) as e: 166 | if isinstance(e, KeyboardInterrupt): 167 | print("\nCaught KeyboardInterrupt: Terminating episode early.") 168 | else: 169 | print(f"\nCaught exception: {e}") 170 | break 171 | 172 | # Save a replay video of the episode 173 | save_rollout_video(replay_images, episode_idx) 174 | 175 | # [If saving rollout data] Save rollout data 176 | if cfg.save_data: 177 | save_rollout_data(replay_images, rollout_images, rollout_states, rollout_actions, idx=episode_idx) 178 | 179 | # Redo episode or continue 180 | if input("Enter 'r' if you want to redo the episode, or press Enter to continue: ") != "r": 181 | episode_idx += 1 182 | 183 | 184 | if __name__ == "__main__": 185 | eval_model_in_bridge_env() 186 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinoclip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinoclip_vit.py 3 | 4 | Vision backbone that returns concatenated features from both DINOv2 and CLIP. 
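Patch features from the two featurizers are concatenated along the embedding dimension
(1024 + 1024 = 2048 for this ViT-L pairing).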
5 | """ 6 | 7 | from dataclasses import dataclass 8 | from functools import partial 9 | from typing import Callable, Dict, Tuple 10 | 11 | import timm 12 | import torch 13 | from PIL import Image 14 | from timm.models.vision_transformer import Block, VisionTransformer 15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 16 | from torchvision.transforms import Compose, Resize 17 | 18 | from prismatic.models.backbones.vision.base_vision import ( 19 | ImageTransform, 20 | LetterboxPad, 21 | VisionBackbone, 22 | compute_sequence_patches, 23 | unpack_tuple, 24 | ) 25 | 26 | # Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) 27 | DINOCLIP_VISION_BACKBONES = { 28 | "dinoclip-vit-l-336px": { 29 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 30 | "clip": "vit_large_patch14_clip_336.openai", 31 | }, 32 | } 33 | 34 | 35 | @dataclass 36 | class DinoCLIPImageTransform: 37 | dino_image_transform: ImageTransform 38 | clip_image_transform: ImageTransform 39 | is_prismatic: bool = True 40 | 41 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: 42 | return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} 43 | 44 | 45 | class DinoCLIPViTBackbone(VisionBackbone): 46 | def __init__( 47 | self, 48 | vision_backbone_id: str, 49 | image_resize_strategy: str, 50 | default_image_size: int = 224, 51 | image_sequence_len: int = 1, 52 | ) -> None: 53 | super().__init__( 54 | vision_backbone_id, 55 | image_resize_strategy, 56 | default_image_size=default_image_size, 57 | image_sequence_len=image_sequence_len, 58 | ) 59 | self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] 60 | self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] 61 | 62 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary 63 | self.dino_featurizer: VisionTransformer = timm.create_model( 64 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 65 | ) 66 | self.dino_featurizer.eval() 67 | 68 | self.clip_featurizer: VisionTransformer = timm.create_model( 69 | self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 70 | ) 71 | self.clip_featurizer.eval() 72 | 73 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility 74 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
75 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 76 | self.dino_featurizer.forward = unpack_tuple( 77 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) 78 | ) 79 | self.clip_featurizer.forward = unpack_tuple( 80 | partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) 81 | ) 82 | 83 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models 84 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) 85 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 86 | 87 | self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) 88 | self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 89 | 90 | # Initialize *both* Transforms 91 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) 92 | default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) 93 | if self.image_resize_strategy == "resize-naive": 94 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" 95 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" 96 | assert isinstance(default_dino_transform.transforms[0], Resize) 97 | assert isinstance(default_clip_transform.transforms[0], Resize) 98 | 99 | target_size = (self.default_image_size, self.default_image_size) 100 | dino_transform = Compose( 101 | [ 102 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), 103 | *default_dino_transform.transforms[1:], 104 | ] 105 | ) 106 | clip_transform = Compose( 107 | [ 108 | Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), 109 | *default_clip_transform.transforms[1:], 110 | ] 111 | ) 112 | 113 | self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) 114 | 115 | elif self.image_resize_strategy == "resize-crop": 116 | self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) 117 | 118 | elif self.image_resize_strategy == "letterbox": 119 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" 120 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" 121 | assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
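            # Letterbox padding fills the added borders with the (0-255 scaled) per-channel normalization mean,
            # so padded pixels map to ~0 after the downstream ToTensor + Normalize steps.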
122 | 123 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) 124 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) 125 | clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) 126 | 127 | # Build New Transform 128 | self.image_transform = DinoCLIPImageTransform( 129 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), 130 | Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), 131 | ) 132 | 133 | else: 134 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 135 | 136 | def get_fsdp_wrapping_policy(self) -> Callable: 137 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" 138 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 139 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 140 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 141 | 142 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: 143 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" 144 | if self.image_sequence_len == 1: 145 | dino_patches = self.dino_featurizer(pixel_values["dino"]) 146 | clip_patches = self.clip_featurizer(pixel_values["clip"]) 147 | else: 148 | featurizers = { 149 | "dino": self.dino_featurizer, 150 | "clip": self.clip_featurizer, 151 | } 152 | patches = compute_sequence_patches(pixel_values, featurizers, self.image_sequence_len) 153 | dino_patches, clip_patches = patches["dino"], patches["clip"] 154 | return torch.cat([dino_patches, clip_patches], dim=2) 155 | 156 | @property 157 | def default_image_resolution(self) -> Tuple[int, int, int]: 158 | return self.dino_data_cfg["input_size"] 159 | 160 | @property 161 | def embed_dim(self) -> int: 162 | return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim 163 | 164 | @property 165 | def num_patches(self) -> int: 166 | assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches 167 | return self.dino_featurizer.patch_embed.num_patches * self.image_sequence_len 168 | 169 | @property 170 | def half_precision_dtype(self) -> torch.dtype: 171 | return torch.bfloat16 172 | --------------------------------------------------------------------------------
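A minimal usage sketch for the DinoCLIP backbone above (illustrative only: it assumes the TIMM weights for `dinoclip-vit-l-336px` can be downloaded, and that `VisionBackbone` is an `nn.Module` so the backbone is directly callable):

import torch
from PIL import Image

from prismatic.models.backbones.vision.dinoclip_vit import DinoCLIPViTBackbone

# Build the fused backbone at 336px with naive resizing (downloads DINOv2 + CLIP ViT-L weights via TIMM).
backbone = DinoCLIPViTBackbone("dinoclip-vit-l-336px", image_resize_strategy="resize-naive", default_image_size=336)

# Transform a stand-in PIL image; DinoCLIPImageTransform returns one tensor per featurizer.
image = Image.new("RGB", (640, 480), color=(127, 127, 127))
pixel_values = {k: v.unsqueeze(0) for k, v in backbone.image_transform(image).items()}

# Forward pass yields concatenated patch features: [1, num_patches, dino_dim + clip_dim].
with torch.no_grad():
    patches = backbone(pixel_values)
assert patches.shape[-1] == backbone.embed_dim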