├── prismatic ├── py.typed ├── extern │ ├── __init__.py │ └── hf │ │ ├── __init__.py │ │ └── configuration_prismatic.py ├── models │ ├── backbones │ │ ├── __init__.py │ │ ├── llm │ │ │ ├── __init__.py │ │ │ ├── prompting │ │ │ │ ├── __init__.py │ │ │ │ ├── mistral_instruct_prompter.py │ │ │ │ ├── phi_prompter.py │ │ │ │ ├── base_prompter.py │ │ │ │ ├── vicuna_v15_prompter.py │ │ │ │ └── llama2_chat_prompter.py │ │ │ ├── phi.py │ │ │ ├── mistral.py │ │ │ └── llama2.py │ │ └── vision │ │ │ ├── __init__.py │ │ │ ├── dinov2_vit.py │ │ │ ├── in1k_vit.py │ │ │ ├── siglip_vit.py │ │ │ ├── clip_vit.py │ │ │ ├── dinoclip_vit.py │ │ │ ├── dinosiglip_vit.py │ │ │ └── base_vision.py │ ├── vlas │ │ ├── __init__.py │ │ └── openvla.py │ ├── vlms │ │ ├── __init__.py │ │ └── base_vlm.py │ ├── __init__.py │ ├── projectors.py │ ├── materialize.py │ └── action_heads.py ├── vla │ ├── datasets │ │ ├── rlds │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── goal_relabeling.py │ │ │ │ └── task_augmentation.py │ │ │ ├── __init__.py │ │ │ ├── oxe │ │ │ │ ├── __init__.py │ │ │ │ ├── materialize.py │ │ │ │ └── utils │ │ │ │ │ └── droid_utils.py │ │ │ ├── traj_transforms.py │ │ │ └── obs_transforms.py │ │ ├── __init__.py │ │ ├── data_config.yaml │ │ ├── cast_dataset.py │ │ └── dummy_dataset.py │ ├── __init__.py │ ├── constants.py │ ├── materialize.py │ └── action_tokenizer.py ├── overwatch │ ├── __init__.py │ └── overwatch.py ├── util │ ├── __init__.py │ ├── nn_utils.py │ └── torch_utils.py ├── preprocessing │ ├── datasets │ │ ├── __init__.py │ │ └── datasets.py │ ├── __init__.py │ ├── materialize.py │ └── download.py ├── __init__.py ├── training │ ├── __init__.py │ ├── strategies │ │ ├── __init__.py │ │ └── ddp.py │ ├── train_utils.py │ └── materialize.py └── conf │ ├── __init__.py │ ├── datasets.py │ └── vla.py ├── inference ├── 1_ex.jpg ├── goal_img.jpg └── current_img.jpg ├── config_nav ├── mbra_config.yaml └── mbra_and_dataset_config.yaml ├── .gitignore ├── SETUP.md ├── LICENSE ├── pyproject.toml └── README.md /prismatic/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/extern/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/extern/hf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prismatic/models/vlas/__init__.py: -------------------------------------------------------------------------------- 1 | from .openvla import OpenVLA 2 | -------------------------------------------------------------------------------- /prismatic/models/vlms/__init__.py: -------------------------------------------------------------------------------- 1 | from .prismatic import PrismaticVLM 2 | -------------------------------------------------------------------------------- /prismatic/overwatch/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .overwatch import initialize_overwatch 2 | -------------------------------------------------------------------------------- /prismatic/vla/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_vla_dataset_and_collator 2 | -------------------------------------------------------------------------------- /inference/1_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NHirose/OmniVLA/HEAD/inference/1_ex.jpg -------------------------------------------------------------------------------- /inference/goal_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NHirose/OmniVLA/HEAD/inference/goal_img.jpg -------------------------------------------------------------------------------- /prismatic/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_utils import check_bloat16_supported, set_global_seed 2 | -------------------------------------------------------------------------------- /inference/current_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NHirose/OmniVLA/HEAD/inference/current_img.jpg -------------------------------------------------------------------------------- /prismatic/preprocessing/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import AlignDataset, FinetuneDataset 2 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import make_interleaved_dataset, make_single_dataset 2 | -------------------------------------------------------------------------------- /prismatic/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import available_model_names, available_models, get_model_description, load 2 | -------------------------------------------------------------------------------- /prismatic/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_train_strategy 2 | from .metrics import Metrics, VLAMetrics 3 | -------------------------------------------------------------------------------- /prismatic/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import convert_to_jpg, download_extract 2 | from .materialize import get_dataset_and_collator 3 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_oxe_dataset_kwargs_and_weights 2 | from .mixtures import OXE_NAMED_MIXTURES 3 | -------------------------------------------------------------------------------- /prismatic/training/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_strategy import TrainingStrategy 2 | from .ddp import DDPStrategy 3 | from .fsdp import FSDPStrategy 4 | 
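# A hedged usage sketch: these strategy classes are normally materialized via
# `prismatic.training.get_train_strategy` (defined in `prismatic/training/materialize.py`
# later in this dump). Argument values below are illustrative placeholders only.
#
#   from prismatic.training import get_train_strategy
#
#   strategy = get_train_strategy(
#       "fsdp-full-shard", vlm=vlm, device_id=0, stage="finetune", epochs=1,
#       max_steps=None, global_batch_size=128, per_device_batch_size=16,
#       learning_rate=2e-5, weight_decay=0.0, max_grad_norm=1.0,
#       lr_scheduler_type="linear-warmup+cosine-decay", warmup_ratio=0.03,
#   )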
-------------------------------------------------------------------------------- /prismatic/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DatasetConfig, DatasetRegistry 2 | from .models import ModelConfig, ModelRegistry 3 | from .vla import VLAConfig, VLARegistry 4 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_llm import LLMBackbone 2 | from .llama2 import LLaMa2LLMBackbone 3 | from .mistral import MistralLLMBackbone 4 | from .phi import PhiLLMBackbone 5 | -------------------------------------------------------------------------------- /prismatic/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import available_model_names, available_models, get_model_description, load, load_vla 2 | from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm 3 | -------------------------------------------------------------------------------- /config_nav/mbra_config.yaml: -------------------------------------------------------------------------------- 1 | # parameters for MBRA 2 | context_size: 5 #history of image 3 | len_traj_pred: 8 #action length 4 | learn_angle: True 5 | obs_encoder: efficientnet-b0 6 | obs_encoding_size: 1024 7 | late_fusion: False 8 | mha_num_attention_heads: 4 9 | mha_num_attention_layers: 4 10 | mha_ff_dim_factor: 4 11 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_prompter import PromptBuilder, PurePromptBuilder 2 | from .llama2_chat_prompter import LLaMa2ChatPromptBuilder 3 | from .mistral_instruct_prompter import MistralInstructPromptBuilder 4 | from .phi_prompter import PhiPromptBuilder 5 | from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder 6 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_vision import ImageTransform, VisionBackbone 2 | from .clip_vit import CLIPViTBackbone 3 | from .dinoclip_vit import DinoCLIPViTBackbone 4 | from .dinosiglip_vit import DinoSigLIPViTBackbone 5 | from .dinov2_vit import DinoV2ViTBackbone 6 | from .in1k_vit import IN1KViTBackbone 7 | from .siglip_vit import SigLIPViTBackbone 8 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset 2 | #from .lelan_dataset import LeLaN_Dataset_multi, LeLaN_Dataset_openvla, LeLaN_Dataset_openvla_act, LeLaN_Dataset_openvla_act_MMN 3 | #from .vint_hf_dataset import ViNTLeRobotDataset_IL2_gps_map2_crop_shadow_MMN, EpisodeSampler_IL_MMN 4 | #from .vint_dataset import ViNT_Dataset_gps_MMN 5 | #from .bdd_dataset import BDD_Dataset_multi_MMN 6 | #from .cast_dataset import CAST_Dataset_MMN 7 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinov2_vit.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | dinov2_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! 8 | # => Reference: https://arxiv.org/abs/2309.16588 9 | DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} 10 | 11 | 12 | class DinoV2ViTBackbone(TimmViTBackbone): 13 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 14 | super().__init__( 15 | vision_backbone_id, 16 | DINOv2_VISION_BACKBONES[vision_backbone_id], 17 | image_resize_strategy, 18 | default_image_size=default_image_size, 19 | ) 20 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/in1k_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | in1k_vit.py 3 | 4 | Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) 5 | """ 6 | 7 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 8 | 9 | # Registry =>> Supported Vision Backbones (from TIMM) 10 | IN1K_VISION_BACKBONES = { 11 | "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", 12 | } 13 | 14 | 15 | class IN1KViTBackbone(TimmViTBackbone): 16 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 17 | super().__init__( 18 | vision_backbone_id, 19 | IN1K_VISION_BACKBONES[vision_backbone_id], 20 | image_resize_strategy, 21 | default_image_size=default_image_size, 22 | ) 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | omnivla-finetuned-cast/ 2 | omnivla-original/ 3 | omnivla-original-balance/ 4 | MBRA/ 5 | MBRA_BDD/ 6 | OmniVLA.egg-info/ 7 | prismatic/vla/datasets/sampler/ 8 | prismatic/vla/datasets/bdd_dataset_.py 9 | prismatic/vla/datasets/cast_dataset_.py 10 | prismatic/vla/datasets/lelan_dataset_.py 11 | prismatic/vla/datasets/gnm_dataset_.py 12 | prismatic/vla/datasets/frodobots_dataset_.py 13 | vla-scripts/train_omnivla_dataset_.py 14 | vla-scripts/train_omnivla_dataset__.py 15 | vla-scripts/finetune_nav_MMN_dataset_cast2.py 16 | config_nav/mbra_and_dataset_config_.yaml 17 | config_nav/mbra_and_dataset_config__.yaml 18 | config_nav/mbra_and_dataset_config_bellman.yaml 19 | config_nav/base_server_bellman.yaml 20 | visualization/ 21 | runs/ 22 | # Ignore all __pycache__ directories everywhere 23 | **/__pycache__/ 24 | # Ignore all wandb directories everywhere 25 | **/wandb/ 26 | -------------------------------------------------------------------------------- /SETUP.md: -------------------------------------------------------------------------------- 1 | # Setup Instructions 2 | 3 | ## Set Up Conda Environment 4 | 5 | ```bash 6 | # Create and activate conda environment 7 | conda create -n omnivla python=3.10 -y 8 | conda activate omnivla 9 | 10 | # Install PyTorch 11 | # Use a command specific to your machine: https://pytorch.org/get-started/locally/ 12 | pip3 install numpy==1.26.4 torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 13 | 14 | # Clone openvla-oft repo and pip install to download dependencies 15 | git clone https://github.com/NHirose/OmniVLA.git 16 | cd OmniVLA 17 | pip install -e . 
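# (Optional) Quick sanity check that the editable install and PyTorch import cleanly
python -c "import torch, prismatic; print(torch.__version__)"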
18 | 19 | # Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention) 20 | # =>> If you run into difficulty, try `pip cache remove flash_attn` first 21 | pip install packaging ninja 22 | ninja --version; echo $? # Verify Ninja --> should return exit code "0" 23 | pip install "flash-attn==2.5.5" --no-build-isolation 24 | ``` 25 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/siglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | siglip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) 8 | SIGLIP_VISION_BACKBONES = { 9 | "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", 10 | "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", 11 | "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", 12 | "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", 13 | "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", 14 | } 15 | 16 | 17 | class SigLIPViTBackbone(TimmViTBackbone): 18 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 19 | super().__init__( 20 | vision_backbone_id, 21 | SIGLIP_VISION_BACKBONES[vision_backbone_id], 22 | image_resize_strategy, 23 | default_image_size=default_image_size, 24 | ) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Original Copyright (c) 2025 Moo Jin Kim, Chelsea Finn, Percy Liang. 4 | Modifications Copyright (c) 2025 Noriaki Hirose, Catherine Glossop, Dhruv Shah, Sergey Levine 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | goal_relabeling.py 3 | 4 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 5 | Each function should add entries to the "task" dict. 
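Example (hedged sketch; `traj` follows the RLDS layout assumed below, i.e. a dict with
per-timestep "observation" entries and a "task" dict):

    traj = uniform(traj)  # adds randomly sampled future observations under traj["task"]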
6 | """ 7 | 8 | from typing import Dict 9 | 10 | import tensorflow as tf 11 | 12 | from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge 13 | 14 | 15 | def uniform(traj: Dict) -> Dict: 16 | """Relabels with a true uniform distribution over future states.""" 17 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 18 | 19 | # Select a random future index for each transition i in the range [i + 1, traj_len) 20 | rand = tf.random.uniform([traj_len]) 21 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 22 | high = tf.cast(traj_len, tf.float32) 23 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 24 | 25 | # Sometimes there are floating-point errors that cause an out-of-bounds 26 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 27 | 28 | # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) 29 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 30 | traj["task"] = tree_merge(traj["task"], goal) 31 | 32 | return traj 33 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/clip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | clip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported CLIP Vision Backbones (from TIMM) 8 | CLIP_VISION_BACKBONES = { 9 | "clip-vit-b": "vit_base_patch16_clip_224.openai", 10 | "clip-vit-l": "vit_large_patch14_clip_224.openai", 11 | "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", 12 | } 13 | 14 | 15 | # [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. 16 | # HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's 17 | # a decent approximation, the resulting features are *worse*; this was a super tricky bug 18 | # to identify, but luckily there's an easy fix (`override_act_layer`) 19 | class CLIPViTBackbone(TimmViTBackbone): 20 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 21 | super().__init__( 22 | vision_backbone_id, 23 | CLIP_VISION_BACKBONES[vision_backbone_id], 24 | image_resize_strategy, 25 | default_image_size=default_image_size, 26 | override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, 27 | ) 28 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/data_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # global params for diffusion model 3 | # normalized min and max 4 | action_stats: 5 | min: [-2.5, -4] # [min_dx, min_dy] 6 | max: [5, 4] # [max_dx, max_dy] 7 | 8 | # data specific params 9 | recon: 10 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters) 11 | 12 | # OPTIONAL (FOR VISUALIZATION ONLY) 13 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html 14 | camera_height: 0.95 # meters 15 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera 16 | camera_matrix: 17 | fx: 272.547000 18 | fy: 266.358000 19 | cx: 320.000000 20 | cy: 220.000000 21 | dist_coeffs: 22 | k1: -0.038483 23 | k2: -0.010456 24 | p1: 0.003930 25 | p2: -0.001007 26 | k3: 0.0 27 | 28 | scand: 29 | metric_waypoint_spacing: 0.38 30 | 31 | tartan_drive: 32 | 
metric_waypoint_spacing: 0.72 33 | 34 | go_stanford: 35 | metric_waypoint_spacing: 0.12 36 | 37 | # private datasets: 38 | cory_hall: 39 | metric_waypoint_spacing: 0.06 40 | 41 | seattle: 42 | metric_waypoint_spacing: 0.35 43 | 44 | racer: 45 | metric_waypoint_spacing: 0.38 46 | 47 | carla_intvns: 48 | metric_waypoint_spacing: 1.39 49 | 50 | carla_cil: 51 | metric_waypoint_spacing: 1.27 52 | 53 | carla_intvns: 54 | metric_waypoint_spacing: 1.39 55 | 56 | carla: 57 | metric_waypoint_spacing: 1.59 58 | image_path_func: get_image_path 59 | 60 | sacson: 61 | metric_waypoint_spacing: 0.12 62 | 63 | go_stanford4: 64 | metric_waypoint_spacing: 0.12 65 | 66 | go_stanford2: 67 | metric_waypoint_spacing: 0.12 68 | 69 | humanw: 70 | metric_waypoint_spacing: 0.12 71 | 72 | youtube: 73 | metric_waypoint_spacing: 0.12 74 | 75 | bdd: 76 | metric_waypoint_spacing: 0.12 77 | 78 | # add your own dataset params here: 79 | -------------------------------------------------------------------------------- /prismatic/vla/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Important constants for VLA training and evaluation. 3 | 4 | Attempts to automatically identify the correct constants to set based on the Python command used to launch 5 | training or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants. 6 | """ 7 | import sys 8 | from enum import Enum 9 | 10 | # Llama 2 token constants 11 | IGNORE_INDEX = -100 12 | ACTION_TOKEN_BEGIN_IDX = 31743 13 | STOP_INDEX = 2 # '' 14 | 15 | 16 | # Defines supported normalization schemes for action and proprioceptive state. 17 | class NormalizationType(str, Enum): 18 | # fmt: off 19 | NORMAL = "normal" # Normalize to Mean = 0, Stdev = 1 20 | BOUNDS = "bounds" # Normalize to Interval = [-1, 1] 21 | BOUNDS_Q99 = "bounds_q99" # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1] 22 | # fmt: on 23 | 24 | 25 | # Define constants for each robot platform 26 | OMNIVLA_CONSTANTS = { 27 | "NUM_ACTIONS_CHUNK": 8, 28 | "ACTION_DIM": 4, 29 | "POSE_DIM": 4, 30 | "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, 31 | } 32 | 33 | constants = OMNIVLA_CONSTANTS 34 | 35 | # Assign constants to global variables 36 | NUM_ACTIONS_CHUNK = constants["NUM_ACTIONS_CHUNK"] 37 | ACTION_DIM = constants["ACTION_DIM"] 38 | POSE_DIM = constants["POSE_DIM"] 39 | ACTION_PROPRIO_NORMALIZATION_TYPE = constants["ACTION_PROPRIO_NORMALIZATION_TYPE"] 40 | 41 | # Print which robot platform constants are being used (for debugging) 42 | print(f"Using OmniVLA constants:") 43 | print(f" NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}") 44 | print(f" ACTION_DIM = {ACTION_DIM}") 45 | print(f" POSE_DIM = {POSE_DIM}") 46 | print(f" ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}") 47 | print("If needed, manually set the correct constants in `prismatic/vla/constants.py`!") 48 | -------------------------------------------------------------------------------- /prismatic/models/projectors.py: -------------------------------------------------------------------------------- 1 | """Implementation of additional projectors for additional inputs to the VLA models.""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ProprioProjector(nn.Module): 7 | """ 8 | Projects proprio state inputs into the LLM's embedding space. 
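    Shape sketch (dimensions below are illustrative only):

        proj = ProprioProjector(llm_dim=4096, proprio_dim=4)
        out = proj(torch.zeros(2, 4))  # fc1 -> GELU -> fc2, giving shape (2, 4096)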
9 | """ 10 | def __init__(self, llm_dim: int, proprio_dim: int) -> None: 11 | super().__init__() 12 | self.llm_dim = llm_dim 13 | self.proprio_dim = proprio_dim 14 | 15 | self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True) 16 | self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) 17 | self.act_fn1 = nn.GELU() 18 | 19 | def forward(self, proprio: torch.Tensor = None) -> torch.Tensor: 20 | # proprio: (bsz, proprio_dim) 21 | projected_features = self.fc1(proprio) 22 | projected_features = self.act_fn1(projected_features) 23 | projected_features = self.fc2(projected_features) 24 | return projected_features 25 | 26 | 27 | class NoisyActionProjector(nn.Module): 28 | """ 29 | [Diffusion] Projects noisy action inputs into the LLM's embedding space. 30 | 31 | Note that since each action is tokenized into 7 tokens in OpenVLA (rather 32 | than having 1 token per action), each noisy action token will have dimension 1 33 | instead of 7. 34 | """ 35 | def __init__(self, llm_dim: int) -> None: 36 | super().__init__() 37 | self.llm_dim = llm_dim 38 | self.action_token_dim = 1 39 | 40 | self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True) 41 | self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) 42 | self.act_fn1 = nn.GELU() 43 | 44 | def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor: 45 | # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1) 46 | projected_features = self.fc1(noisy_actions) 47 | projected_features = self.act_fn1(projected_features) 48 | projected_features = self.fc2(projected_features) 49 | return projected_features 50 | -------------------------------------------------------------------------------- /prismatic/training/train_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for training/fine-tuning scripts.""" 2 | 3 | import torch 4 | 5 | from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX 6 | 7 | 8 | def get_current_action_mask(token_ids): 9 | # Create a tensor marking positions of IGNORE_INDEX 10 | newline_positions = token_ids != IGNORE_INDEX 11 | 12 | # Calculate cumulative sum to identify regions between newlines 13 | cumsum = torch.cumsum(newline_positions, dim=1) 14 | 15 | # Create the mask 16 | mask = (1 <= cumsum) & (cumsum <= ACTION_DIM) 17 | 18 | # Extract the action part only 19 | action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX 20 | mask = action_tokens_only_mask * mask 21 | 22 | return mask 23 | 24 | 25 | def get_next_actions_mask(token_ids): 26 | # Create a tensor marking positions of IGNORE_INDEX 27 | newline_positions = token_ids != IGNORE_INDEX 28 | 29 | # Calculate cumulative sum to identify regions between newlines 30 | cumsum = torch.cumsum(newline_positions, dim=1) 31 | 32 | # Create the mask 33 | mask = cumsum > ACTION_DIM 34 | 35 | # Extract the action part only 36 | action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX 37 | mask = action_tokens_only_mask * mask 38 | 39 | return mask 40 | 41 | 42 | def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask): 43 | correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask 44 | accuracy = correct_preds.sum().float() / mask.sum().float() 45 | return accuracy 46 | 47 | 48 | def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask): 49 | pred_continuous_actions = torch.tensor( 50 | action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy()) 51 | ) 
52 | true_continuous_actions = torch.tensor( 53 | action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy()) 54 | ) 55 | l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions) 56 | return l1_loss 57 | -------------------------------------------------------------------------------- /prismatic/util/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | nn_utils.py 3 | 4 | Utility functions and PyTorch submodule definitions. 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === 12 | class LinearProjector(nn.Module): 13 | def __init__(self, vision_dim: int, llm_dim: int) -> None: 14 | super().__init__() 15 | self.projector = nn.Linear(vision_dim, llm_dim, bias=True) 16 | 17 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 18 | return self.projector(img_patches) 19 | 20 | 21 | class MLPProjector(nn.Module): 22 | def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: 23 | super().__init__() 24 | if mlp_type == "gelu-mlp": 25 | self.projector = nn.Sequential( 26 | nn.Linear(vision_dim, llm_dim, bias=True), 27 | nn.GELU(), 28 | nn.Linear(llm_dim, llm_dim, bias=True), 29 | ) 30 | else: 31 | raise ValueError(f"Projector with `{mlp_type = }` is not supported!") 32 | 33 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 34 | return self.projector(img_patches) 35 | 36 | 37 | class FusedMLPProjector(nn.Module): 38 | def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: 39 | super().__init__() 40 | self.initial_projection_dim = fused_vision_dim * 4 41 | if mlp_type == "fused-gelu-mlp": 42 | self.projector = nn.Sequential( 43 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), 44 | nn.GELU(), 45 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True), 46 | nn.GELU(), 47 | nn.Linear(llm_dim, llm_dim, bias=True), 48 | ) 49 | else: 50 | raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") 51 | 52 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: 53 | return self.projector(fused_img_patches) 54 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | task_augmentation.py 3 | 4 | Contains basic logic for randomly zeroing out keys in the task specification. 5 | """ 6 | 7 | from typing import Dict 8 | 9 | import tensorflow as tf 10 | 11 | from prismatic.vla.datasets.rlds.utils.data_utils import to_padding 12 | 13 | 14 | def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: 15 | """ 16 | Randomly drops out either the goal images or the language instruction. Only does something if both of 17 | these are present. 18 | 19 | Args: 20 | traj: A dictionary containing trajectory data. Should have a "task" key. 21 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 22 | instruction is 1 - keep_image_prob. 
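    Example (hedged): with keep_image_prob=0.5, roughly half of the timesteps keep their
    goal images (language is padded out) while the rest keep only language -- except where
    the language instruction is already padding, in which case the images are always kept.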
23 | """ 24 | if "language_instruction" not in traj["task"]: 25 | return traj 26 | 27 | image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} 28 | if not image_keys: 29 | return traj 30 | 31 | traj_len = tf.shape(traj["action"])[0] 32 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 33 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 34 | 35 | for key in image_keys | {"language_instruction"}: 36 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 37 | # pad out the key 38 | traj["task"][key] = tf.where( 39 | should_keep, 40 | traj["task"][key], 41 | to_padding(traj["task"][key]), 42 | ) 43 | # zero out the pad mask dict for the key 44 | traj["task"]["pad_mask_dict"][key] = tf.where( 45 | should_keep, 46 | traj["task"]["pad_mask_dict"][key], 47 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 48 | ) 49 | 50 | # when no goal images are present, the goal timestep becomes the final timestep 51 | traj["task"]["timestep"] = tf.where( 52 | should_keep_images, 53 | traj["task"]["timestep"], 54 | traj_len - 1, 55 | ) 56 | 57 | return traj 58 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/phi.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi.py 3 | 4 | Class definition for all LLMs derived from PhiForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import PhiForCausalLM 12 | from transformers.models.phi.modeling_phi import PhiDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder 16 | 17 | # Registry ==> Support Phi Models (from HF Transformers) 18 | # fmt: off 19 | PHI_MODELS = { 20 | # === Phi-2 === 21 | "phi-2-3b": { 22 | "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2" 23 | } 24 | } 25 | # fmt: on 26 | 27 | 28 | class PhiLLMBackbone(HFCausalLLMBackbone): 29 | def __init__( 30 | self, 31 | llm_backbone_id: str, 32 | llm_max_length: int = 2048, 33 | hf_token: Optional[str] = None, 34 | inference_mode: bool = False, 35 | use_flash_attention_2: bool = True, 36 | ) -> None: 37 | super().__init__( 38 | llm_backbone_id, 39 | llm_max_length=llm_max_length, 40 | hf_token=hf_token, 41 | inference_mode=inference_mode, 42 | use_flash_attention_2=use_flash_attention_2, 43 | **PHI_MODELS[llm_backbone_id], 44 | ) 45 | 46 | # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize) 47 | self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) 48 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 49 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 50 | 51 | @property 52 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 53 | if self.identifier.startswith("phi-2"): 54 | return PhiPromptBuilder 55 | 56 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 57 | 58 | @property 59 | def transformer_layer_cls(self) -> Type[nn.Module]: 60 | return PhiDecoderLayer 61 | 62 | @property 63 | def half_precision_dtype(self) -> torch.dtype: 64 | return torch.bfloat16 65 | -------------------------------------------------------------------------------- /prismatic/vla/materialize.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and 5 | exports individual functions for clear control flow. 6 | """ 7 | 8 | from pathlib import Path 9 | from typing import Tuple, Type 10 | 11 | from torch.utils.data import Dataset 12 | from transformers import PreTrainedTokenizerBase 13 | 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.util.data_utils import PaddedCollatorForActionPrediction 17 | from prismatic.vla.action_tokenizer import ActionTokenizer 18 | from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset 19 | 20 | 21 | def get_vla_dataset_and_collator( 22 | data_root_dir: Path, 23 | data_mix: str, 24 | image_transform: ImageTransform, 25 | tokenizer: PreTrainedTokenizerBase, 26 | prompt_builder_fn: Type[PromptBuilder], 27 | default_image_resolution: Tuple[int, int, int], 28 | padding_side: str = "right", 29 | predict_stop_token: bool = True, 30 | shuffle_buffer_size: int = 100_000, 31 | train: bool = True, 32 | episodic: bool = False, 33 | image_aug: bool = False, 34 | ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: 35 | """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" 36 | action_tokenizer = ActionTokenizer(tokenizer) 37 | batch_transform = RLDSBatchTransform( 38 | action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token 39 | ) 40 | collator = PaddedCollatorForActionPrediction( 41 | tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side 42 | ) 43 | 44 | # Build RLDS Iterable Dataset 45 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset 46 | dataset = cls( 47 | data_root_dir, 48 | data_mix, 49 | batch_transform, 50 | resize_resolution=default_image_resolution[1:], 51 | shuffle_buffer_size=shuffle_buffer_size, 52 | train=train, 53 | image_aug=image_aug, 54 | ) 55 | 56 | return dataset, action_tokenizer, collator 57 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral_instruct_prompter.py 3 | 4 | Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorial.s 5 | 6 | Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | 14 | class MistralInstructPromptBuilder(PromptBuilder): 15 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 16 | super().__init__(model_family, system_prompt) 17 | 18 | # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)` 19 | # =>> Mistral Instruct *does not* use a System Prompt 20 | self.bos, self.eos = "", "" 21 | 22 | # Get role-specific "wrap" functions 23 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 24 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 25 | 26 | # === `self.prompt` gets built up over multiple turns === 27 | self.prompt, self.turn_count = "", 0 28 | 29 | 
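    # Illustration (hypothetical messages) of one human/gpt exchange after the wrap
    # functions above are applied, with the EOS token appended by `wrap_gpt`:
    #
    #   "[INST] What action should the robot take? [/INST] Move forward.{eos}"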
def add_turn(self, role: str, message: str) -> str: 30 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 31 | message = message.replace("", "").strip() 32 | 33 | if (self.turn_count % 2) == 0: 34 | human_message = self.wrap_human(message) 35 | wrapped_message = human_message 36 | else: 37 | gpt_message = self.wrap_gpt(message) 38 | wrapped_message = gpt_message 39 | 40 | # Update Prompt 41 | self.prompt += wrapped_message 42 | 43 | # Bump Turn Counter 44 | self.turn_count += 1 45 | 46 | # Return "wrapped_message" (effective string added to context) 47 | return wrapped_message 48 | 49 | def get_potential_prompt(self, message: str) -> None: 50 | # Assumes that it's always the user's (human's) turn! 51 | prompt_copy = str(self.prompt) 52 | 53 | human_message = self.wrap_human(message) 54 | prompt_copy += human_message 55 | 56 | return prompt_copy.removeprefix(self.bos).rstrip() 57 | 58 | def get_prompt(self) -> str: 59 | # Remove prefix because it gets auto-inserted by tokenizer! 60 | return self.prompt.removeprefix(self.bos).rstrip() 61 | -------------------------------------------------------------------------------- /prismatic/training/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, 5 | and strategy configurations. 6 | """ 7 | 8 | from typing import Callable, Optional 9 | 10 | import torch 11 | 12 | from prismatic.models.vlms import PrismaticVLM 13 | from prismatic.training.strategies import FSDPStrategy, TrainingStrategy 14 | 15 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
16 | TRAIN_STRATEGIES = { 17 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, 18 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, 19 | } 20 | 21 | 22 | def get_train_strategy( 23 | train_strategy: str, 24 | vlm: PrismaticVLM, 25 | device_id: int, 26 | stage: str, 27 | epochs: int, 28 | max_steps: Optional[int], 29 | global_batch_size: int, 30 | per_device_batch_size: int, 31 | learning_rate: float, 32 | weight_decay: float, 33 | max_grad_norm: float, 34 | lr_scheduler_type: str, 35 | warmup_ratio: float, 36 | enable_gradient_checkpointing: bool = True, 37 | enable_mixed_precision_training: bool = True, 38 | reduce_in_full_precision: bool = False, 39 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 40 | worker_init_fn: Optional[Callable[[int], None]] = None, 41 | ) -> TrainingStrategy: 42 | if train_strategy in TRAIN_STRATEGIES: 43 | strategy_cfg = TRAIN_STRATEGIES[train_strategy] 44 | strategy = strategy_cfg["cls"]( 45 | vlm=vlm, 46 | device_id=device_id, 47 | stage=stage, 48 | epochs=epochs, 49 | max_steps=max_steps, 50 | global_batch_size=global_batch_size, 51 | per_device_batch_size=per_device_batch_size, 52 | learning_rate=learning_rate, 53 | weight_decay=weight_decay, 54 | max_grad_norm=max_grad_norm, 55 | lr_scheduler_type=lr_scheduler_type, 56 | warmup_ratio=warmup_ratio, 57 | enable_gradient_checkpointing=enable_gradient_checkpointing, 58 | enable_mixed_precision_training=enable_mixed_precision_training, 59 | reduce_in_full_precision=reduce_in_full_precision, 60 | mixed_precision_dtype=mixed_precision_dtype, 61 | worker_init_fn=worker_init_fn, 62 | **strategy_cfg["kwargs"], 63 | ) 64 | return strategy 65 | else: 66 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") 67 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/mistral.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral.py 3 | 4 | Class definition for all LLMs derived from MistralForCausalLM. 
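Example (hedged; instantiating the backbone requires access to the referenced
Hugging Face weights):

    backbone = MistralLLMBackbone("mistral-v0.1-7b-instruct")
    prompt_builder = backbone.prompt_builder_fn("mistral")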
5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import MistralForCausalLM 12 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder 16 | 17 | # Registry =>> Support Mistral Models (from HF Transformers) 18 | # fmt: off 19 | MISTRAL_MODELS = { 20 | # === Base Mistral v0.1 === 21 | "mistral-v0.1-7b-pure": { 22 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" 23 | }, 24 | 25 | # === Mistral Instruct v0.1 === 26 | "mistral-v0.1-7b-instruct": { 27 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" 28 | } 29 | } 30 | # fmt: on 31 | 32 | 33 | class MistralLLMBackbone(HFCausalLLMBackbone): 34 | def __init__( 35 | self, 36 | llm_backbone_id: str, 37 | llm_max_length: int = 2048, 38 | hf_token: Optional[str] = None, 39 | inference_mode: bool = False, 40 | use_flash_attention_2: bool = True, 41 | ) -> None: 42 | super().__init__( 43 | llm_backbone_id, 44 | llm_max_length=llm_max_length, 45 | hf_token=hf_token, 46 | inference_mode=inference_mode, 47 | use_flash_attention_2=use_flash_attention_2, 48 | **MISTRAL_MODELS[llm_backbone_id], 49 | ) 50 | 51 | # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize) 52 | self.tokenizer.add_special_tokens({"pad_token": ""}) 53 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 54 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 55 | 56 | @property 57 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 58 | if self.identifier.endswith("-pure"): 59 | return PurePromptBuilder 60 | 61 | elif self.identifier.endswith("-instruct"): 62 | return MistralInstructPromptBuilder 63 | 64 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 65 | 66 | @property 67 | def transformer_layer_cls(self) -> Type[nn.Module]: 68 | return MistralDecoderLayer 69 | 70 | @property 71 | def half_precision_dtype(self) -> torch.dtype: 72 | return torch.bfloat16 73 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/phi_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi_prompter.py 3 | 4 | Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft. 5 | Also handles Phi special case BOS token additions. 6 | 7 | Reference: https://huggingface.co/microsoft/phi-2#qa-format 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | 15 | class PhiPromptBuilder(PromptBuilder): 16 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 17 | super().__init__(model_family, system_prompt) 18 | 19 | # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)` 20 | # =>> By default, does *not* append / tokens --> we handle that here (IMPORTANT)! 
21 | self.bos, self.eos = "<|endoftext|>", "<|endoftext|>" 22 | 23 | # Get role-specific "wrap" functions 24 | # =>> Note that placement of / were based on experiments generating from Phi-2 in Input/Output mode 25 | self.wrap_human = lambda msg: f"Input: {msg}\nOutput: " 26 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}" 27 | 28 | # === `self.prompt` gets built up over multiple turns === 29 | self.prompt, self.turn_count = "", 0 30 | 31 | def add_turn(self, role: str, message: str) -> str: 32 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 33 | message = message.replace("", "").strip() 34 | 35 | # Special Handling for "first" input --> prepend a token (expected by Prismatic) 36 | if self.turn_count == 0: 37 | bos_human_message = f"{self.bos}{self.wrap_human(message)}" 38 | wrapped_message = bos_human_message 39 | elif (self.turn_count % 2) == 0: 40 | human_message = self.wrap_human(message) 41 | wrapped_message = human_message 42 | else: 43 | gpt_message = self.wrap_gpt(message) 44 | wrapped_message = gpt_message 45 | 46 | # Update Prompt 47 | self.prompt += wrapped_message 48 | 49 | # Bump Turn Counter 50 | self.turn_count += 1 51 | 52 | # Return "wrapped_message" (effective string added to context) 53 | return wrapped_message 54 | 55 | def get_potential_prompt(self, message: str) -> None: 56 | # Assumes that it's always the user's (human's) turn! 57 | prompt_copy = str(self.prompt) 58 | 59 | human_message = self.wrap_human(message) 60 | prompt_copy += human_message 61 | 62 | return prompt_copy.rstrip() 63 | 64 | def get_prompt(self) -> str: 65 | return self.prompt.rstrip() 66 | -------------------------------------------------------------------------------- /prismatic/preprocessing/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for 5 | clear control flow. 
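Example (hedged sketch; arguments mirror `get_dataset_and_collator` below, and the
variables are placeholders for objects built elsewhere in the training pipeline):

    dataset, collator = get_dataset_and_collator(
        "finetune", dataset_cfg, image_transform, tokenizer,
        prompt_builder_fn=prompt_builder_fn, default_image_resolution=(3, 224, 224),
    )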
6 | """ 7 | 8 | from typing import Tuple, Type 9 | 10 | from torch.utils.data import Dataset 11 | from transformers import PreTrainedTokenizerBase 12 | 13 | from prismatic.conf import DatasetConfig 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset 17 | from prismatic.util.data_utils import PaddedCollatorForLanguageModeling 18 | 19 | # Dataset Initializers =>> Maps Stage --> cls() 20 | DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} 21 | 22 | 23 | def get_dataset_and_collator( 24 | stage: str, 25 | dataset_cfg: DatasetConfig, 26 | image_transform: ImageTransform, 27 | tokenizer: PreTrainedTokenizerBase, 28 | prompt_builder_fn: Type[PromptBuilder], 29 | default_image_resolution: Tuple[int, int, int], 30 | padding_side: str = "right", 31 | ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: 32 | dataset_cls = DATASET_INITIALIZER[stage] 33 | dataset_root_dir = dataset_cfg.dataset_root_dir 34 | collator = PaddedCollatorForLanguageModeling( 35 | tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side 36 | ) 37 | 38 | # Switch on `stage` 39 | if stage == "align": 40 | annotation_json, image_dir = dataset_cfg.align_stage_components 41 | dataset = dataset_cls( 42 | dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer 43 | ) 44 | return dataset, collator 45 | 46 | elif stage == "finetune": 47 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 48 | dataset = dataset_cls( 49 | dataset_root_dir / annotation_json, 50 | dataset_root_dir / image_dir, 51 | image_transform, 52 | tokenizer, 53 | prompt_builder_fn=prompt_builder_fn, 54 | ) 55 | return dataset, collator 56 | 57 | elif stage == "full-finetune": 58 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 59 | dataset = dataset_cls( 60 | dataset_root_dir / annotation_json, 61 | dataset_root_dir / image_dir, 62 | image_transform, 63 | tokenizer, 64 | prompt_builder_fn=prompt_builder_fn, 65 | ) 66 | return dataset, collator 67 | 68 | else: 69 | raise ValueError(f"Stage `{stage}` is not supported!") 70 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/base_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_prompter.py 3 | 4 | Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs. 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Optional 9 | 10 | 11 | class PromptBuilder(ABC): 12 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 13 | self.model_family = model_family 14 | 15 | # Only some models define a system prompt => let subclasses handle this logic! 16 | self.system_prompt = system_prompt 17 | 18 | @abstractmethod 19 | def add_turn(self, role: str, message: str) -> str: ... 20 | 21 | @abstractmethod 22 | def get_potential_prompt(self, user_msg: str) -> None: ... 23 | 24 | @abstractmethod 25 | def get_prompt(self) -> str: ... 
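# A hedged usage sketch of the builder protocol (concrete subclasses live below and in the
# sibling prompter modules): turns strictly alternate human -> gpt, and `get_prompt()`
# returns the string that is ultimately tokenized. The messages are hypothetical.
#
#   builder = PurePromptBuilder("openvla")
#   builder.add_turn("human", "What action should the robot take to reach the goal?")
#   builder.add_turn("gpt", "<predicted action string>")
#   prompt = builder.get_prompt()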
26 | 27 | 28 | class PurePromptBuilder(PromptBuilder): 29 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 30 | super().__init__(model_family, system_prompt) 31 | 32 | # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! 33 | self.bos, self.eos = "", "" 34 | 35 | # Get role-specific "wrap" functions 36 | self.wrap_human = lambda msg: f"In: {msg}\nOut: " 37 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 38 | 39 | # === `self.prompt` gets built up over multiple turns === 40 | self.prompt, self.turn_count = "", 0 41 | 42 | def add_turn(self, role: str, message: str) -> str: 43 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 44 | message = message.replace("", "").strip() 45 | 46 | if (self.turn_count % 2) == 0: 47 | human_message = self.wrap_human(message) 48 | wrapped_message = human_message 49 | else: 50 | gpt_message = self.wrap_gpt(message) 51 | wrapped_message = gpt_message 52 | 53 | # Update Prompt 54 | self.prompt += wrapped_message 55 | 56 | # Bump Turn Counter 57 | self.turn_count += 1 58 | 59 | # Return "wrapped_message" (effective string added to context) 60 | return wrapped_message 61 | 62 | def get_potential_prompt(self, message: str) -> None: 63 | # Assumes that it's always the user's (human's) turn! 64 | prompt_copy = str(self.prompt) 65 | 66 | human_message = self.wrap_human(message) 67 | prompt_copy += human_message 68 | 69 | return prompt_copy.removeprefix(self.bos).rstrip() 70 | 71 | def get_prompt(self) -> str: 72 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 73 | return self.prompt.removeprefix(self.bos).rstrip() 74 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "OmniVLA" 7 | authors = [ 8 | {name = "Noriaki Hirose"}, 9 | {name = "Catherine Glossop"}, 10 | {name = "Dhruv Shah"}, 11 | {name = "Sergey Levine"}, 12 | ] 13 | description = "OmniVLA: An Omni-Modal Vision-Language-Action Model for Robot Navigation" 14 | version = "0.0.1" 15 | readme = "README.md" 16 | requires-python = ">=3.8" 17 | keywords = ["vision-language-action model", "vision-based navigation", "robot learning"] 18 | license = {file = "LICENSE"} 19 | classifiers = [ 20 | "Development Status :: 3 - Alpha", 21 | "Intended Audience :: Developers", 22 | "Intended Audience :: Education", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3 :: Only", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | ] 33 | dependencies = [ 34 | "accelerate>=0.25.0", 35 | "draccus==0.8.0", 36 | "einops", 37 | # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README) 38 | "huggingface_hub==0.29.1", 39 | "json-numpy", 40 | "jsonlines", 41 | "matplotlib", 42 | "peft==0.11.1", 43 | "protobuf", 44 | "rich", 45 | "sentencepiece==0.1.99", 46 | "timm==0.9.10", 47 | "tokenizers==0.19.1", 48 | "torch==2.2.0", 49 | "torchvision==0.17.0", 50 | "torchaudio==2.2.0", 51 | "transformers @ 
git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding) 52 | "wandb", 53 | "tensorflow==2.15.0", 54 | "tensorflow_datasets==4.9.3", 55 | "tensorflow_graphics==2021.12.3", 56 | "dlimp @ git+https://github.com/moojink/dlimp_openvla", 57 | "diffusers==0.11.1", 58 | "imageio", 59 | "uvicorn", 60 | "fastapi", 61 | "utm", 62 | "lmdb", 63 | "zarr", 64 | "datasets", 65 | "efficientnet_pytorch", 66 | "av" 67 | ] 68 | 69 | [project.optional-dependencies] 70 | dev = [ 71 | "black>=24.2.0", 72 | "gpustat", 73 | "ipython", 74 | "pre-commit", 75 | "ruff>=0.2.2", 76 | ] 77 | sagemaker = [ 78 | "boto3", 79 | "sagemaker" 80 | ] 81 | 82 | [project.urls] 83 | homepage = "https://omnivla-nav.github.io" 84 | repository = "https://github.com/NHirose/OmniVLA" 85 | 86 | [tool.setuptools.packages.find] 87 | where = ["."] 88 | exclude = ["cache"] 89 | 90 | [tool.setuptools.package-data] 91 | "prismatic" = ["py.typed"] 92 | 93 | [tool.black] 94 | line-length = 121 95 | target-version = ["py38", "py39", "py310"] 96 | preview = true 97 | 98 | [tool.ruff] 99 | line-length = 121 100 | target-version = "py38" 101 | 102 | [tool.ruff.lint] 103 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 104 | ignore = ["F722"] 105 | 106 | [tool.ruff.lint.per-file-ignores] 107 | "__init__.py" = ["E402", "F401"] 108 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | # Default System Prompt for LLaVa Models 14 | SYS_PROMPTS = { 15 | "prismatic": ( 16 | "A chat between a curious user and an artificial intelligence assistant. " 17 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 18 | ), 19 | "openvla": ( 20 | "A chat between a curious user and an artificial intelligence assistant. " 21 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 
22 | ), 23 | } 24 | 25 | 26 | class VicunaV15ChatPromptBuilder(PromptBuilder): 27 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 28 | super().__init__(model_family, system_prompt) 29 | self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " 30 | 31 | # LLaMa-2 Specific 32 | self.bos, self.eos = "", "" 33 | 34 | # Get role-specific "wrap" functions 35 | self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " 36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 37 | 38 | # === `self.prompt` gets built up over multiple turns === 39 | self.prompt, self.turn_count = "", 0 40 | 41 | def add_turn(self, role: str, message: str) -> str: 42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 43 | message = message.replace("", "").strip() 44 | 45 | # Special Handling for "system" prompt (turn_count == 0) 46 | if self.turn_count == 0: 47 | sys_message = self.system_prompt + self.wrap_human(message) 48 | wrapped_message = sys_message 49 | elif (self.turn_count % 2) == 0: 50 | human_message = self.wrap_human(message) 51 | wrapped_message = human_message 52 | else: 53 | gpt_message = self.wrap_gpt(message) 54 | wrapped_message = gpt_message 55 | 56 | # Update Prompt 57 | self.prompt += wrapped_message 58 | 59 | # Bump Turn Counter 60 | self.turn_count += 1 61 | 62 | # Return "wrapped_message" (effective string added to context) 63 | return wrapped_message 64 | 65 | def get_potential_prompt(self, message: str) -> None: 66 | # Assumes that it's always the user's (human's) turn! 67 | prompt_copy = str(self.prompt) 68 | 69 | # Special Handling for "system" prompt (turn_count == 0) 70 | if self.turn_count == 0: 71 | sys_message = self.system_prompt + self.wrap_human(message) 72 | prompt_copy += sys_message 73 | 74 | else: 75 | human_message = self.wrap_human(message) 76 | prompt_copy += human_message 77 | 78 | return prompt_copy.removeprefix(self.bos).rstrip() 79 | 80 | def get_prompt(self) -> str: 81 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 82 | return self.prompt.removeprefix(self.bos).rstrip() 83 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama2_prompter.py 3 | 4 | Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern 5 | that's used by HF and other online tutorials. 6 | 7 | Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | # Default System Prompt for Prismatic Models 15 | SYS_PROMPTS = { 16 | "prismatic": ( 17 | "You are a helpful language and vision assistant. " 18 | "You are able to understand the visual content that the user provides, " 19 | "and assist the user with a variety of tasks using natural language." 20 | ), 21 | "openvla": ( 22 | "You are a helpful language and vision assistant. " 23 | "You are able to understand the visual content that the user provides, " 24 | "and assist the user with a variety of tasks using natural language." 
25 |     ),
26 | }
27 | 
28 | 
29 | def format_system_prompt(system_prompt: str) -> str:
30 |     return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n"
31 | 
32 | 
33 | class LLaMa2ChatPromptBuilder(PromptBuilder):
34 |     def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
35 |         super().__init__(model_family, system_prompt)
36 |         self.system_prompt = format_system_prompt(
37 |             SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt
38 |         )
39 | 
40 |         # LLaMa-2 Specific
41 |         self.bos, self.eos = "<s>", "</s>"
42 | 
43 |         # Get role-specific "wrap" functions
44 |         self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
45 |         self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
46 | 
47 |         # === `self.prompt` gets built up over multiple turns ===
48 |         self.prompt, self.turn_count = "", 0
49 | 
50 |     def add_turn(self, role: str, message: str) -> str:
51 |         assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
52 |         message = message.replace("<image>", "").strip()
53 | 
54 |         # Special Handling for "system" prompt (turn_count == 0)
55 |         if self.turn_count == 0:
56 |             sys_message = self.wrap_human(self.system_prompt + message)
57 |             wrapped_message = sys_message
58 |         elif (self.turn_count % 2) == 0:
59 |             human_message = self.wrap_human(message)
60 |             wrapped_message = human_message
61 |         else:
62 |             gpt_message = self.wrap_gpt(message)
63 |             wrapped_message = gpt_message
64 | 
65 |         # Update Prompt
66 |         self.prompt += wrapped_message
67 | 
68 |         # Bump Turn Counter
69 |         self.turn_count += 1
70 | 
71 |         # Return "wrapped_message" (effective string added to context)
72 |         return wrapped_message
73 | 
74 |     def get_potential_prompt(self, message: str) -> None:
75 |         # Assumes that it's always the user's (human's) turn!
76 |         prompt_copy = str(self.prompt)
77 | 
78 |         # Special Handling for "system" prompt (turn_count == 0)
79 |         if self.turn_count == 0:
80 |             sys_message = self.wrap_human(self.system_prompt + message)
81 |             prompt_copy += sys_message
82 | 
83 |         else:
84 |             human_message = self.wrap_human(message)
85 |             prompt_copy += human_message
86 | 
87 |         return prompt_copy.removeprefix(self.bos).rstrip()
88 | 
89 |     def get_prompt(self) -> str:
90 |         # Remove prefix <s> because it gets auto-inserted by tokenizer!
91 |         return self.prompt.removeprefix(self.bos).rstrip()
92 | 
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/llama2.py:
--------------------------------------------------------------------------------
1 | """
2 | llama2.py
3 | 
4 | Class definition for all LLMs derived from LlamaForCausalLM.
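A minimal usage sketch (illustrative only; assumes access to the gated LLaMa-2 / Vicuna weights on the Hugging Face Hub and a valid `hf_token`):

    backbone = LLaMa2LLMBackbone("vicuna-v15-7b", hf_token=hf_token)
    prompt_builder = backbone.prompt_builder_fn("openvla")  # instantiates VicunaV15ChatPromptBuilder
    prompt_builder.add_turn(role="human", message="What action should the robot take?")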
5 | """ 6 | 7 | from typing import Optional, Sequence, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import LlamaForCausalLM 12 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import ( 16 | LLaMa2ChatPromptBuilder, 17 | PromptBuilder, 18 | PurePromptBuilder, 19 | VicunaV15ChatPromptBuilder, 20 | ) 21 | 22 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 23 | # fmt: off 24 | LLAMA2_MODELS = { 25 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 26 | "llama2-7b-pure": { 27 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf" 28 | }, 29 | 30 | "llama2-13b-pure": { 31 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf" 32 | }, 33 | 34 | # === Meta LLaMa-2 Chat Models === 35 | "llama2-7b-chat": { 36 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf" 37 | }, 38 | 39 | "llama2-13b-chat": { 40 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf" 41 | }, 42 | 43 | # === Vicuna v1.5 Chat Models === 44 | "vicuna-v15-7b": { 45 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5" 46 | }, 47 | 48 | "vicuna-v15-13b": { 49 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5" 50 | }, 51 | } 52 | # fmt: on 53 | 54 | 55 | class LLaMa2LLMBackbone(HFCausalLLMBackbone): 56 | def __init__( 57 | self, 58 | llm_backbone_id: str, 59 | llm_max_length: int = 2048, 60 | hf_token: Optional[str] = None, 61 | inference_mode: bool = False, 62 | use_flash_attention_2: bool = True, 63 | ) -> None: 64 | super().__init__( 65 | llm_backbone_id, 66 | llm_max_length=llm_max_length, 67 | hf_token=hf_token, 68 | inference_mode=inference_mode, 69 | use_flash_attention_2=use_flash_attention_2, 70 | **LLAMA2_MODELS[llm_backbone_id], 71 | ) 72 | 73 | # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize) 74 | self.tokenizer.add_special_tokens({"pad_token": ""}) 75 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 76 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 77 | 78 | @property 79 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 80 | if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): 81 | return PurePromptBuilder 82 | 83 | elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): 84 | return LLaMa2ChatPromptBuilder 85 | 86 | elif self.identifier.startswith("vicuna"): 87 | return VicunaV15ChatPromptBuilder 88 | 89 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 90 | 91 | @property 92 | def transformer_layer_cls(self) -> Type[nn.Module]: 93 | return LlamaDecoderLayer 94 | 95 | @property 96 | def half_precision_dtype(self) -> torch.dtype: 97 | """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2.""" 98 | return torch.bfloat16 99 | 100 | @property 101 | def last_layer_finetune_modules(self) -> Sequence[nn.Module]: 102 | return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head) 103 | -------------------------------------------------------------------------------- /prismatic/vla/action_tokenizer.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | action_tokenizer.py 3 | 4 | Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions. 5 | """ 6 | 7 | from typing import List, Union 8 | 9 | import numpy as np 10 | from transformers import PreTrainedTokenizerBase 11 | 12 | 13 | class ActionTokenizer: 14 | def __init__( 15 | self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1 16 | ) -> None: 17 | """ 18 | Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens. 19 | 20 | NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens* 21 | appear at the end of the vocabulary! 22 | 23 | :param tokenizer: Base LLM/VLM tokenizer to extend. 24 | :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy. 25 | :param min_action: Minimum action value (for clipping, setting lower bound on bin interval). 26 | :param max_action: Maximum action value (for clipping, setting upper bound on bin interval). 27 | """ 28 | self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action 29 | 30 | # Create Uniform Bins + Compute Bin Centers 31 | self.bins = np.linspace(min_action, max_action, self.n_bins) 32 | self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0 33 | 34 | # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)` 35 | # =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary! 36 | self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1)) 37 | 38 | def __call__(self, action: np.ndarray) -> Union[str, List[str]]: 39 | """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:]).""" 40 | action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action)) 41 | discretized_action = np.digitize(action, self.bins) 42 | 43 | # Handle single element vs. batch 44 | if len(discretized_action.shape) == 1: 45 | return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action)) 46 | else: 47 | return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist()) 48 | 49 | def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray: 50 | """ 51 | Returns continuous actions for discrete action token IDs. 52 | 53 | NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the 54 | digitization returns bin indices between [1, # bins], inclusive, when there are actually only 55 | (# bins - 1) bin intervals. 56 | 57 | Therefore, if the digitization returns the last possible index, we map this to the last bin interval. 58 | 59 | EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns 60 | indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There 61 | is still one index (i==255) that would cause an out-of-bounds error if used to index into 62 | self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of 63 | the last bin center. We implement this simply via clipping between [0, 255 - 1]. 
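WORKED SKETCH =>> Assuming the default 256 bins over [-1, 1] and a Llama-style vocabulary of 32,000 tokens (an
                  illustrative figure, not a guarantee): tokenizing an action value of 0.0 gives
                  np.digitize(0.0, bins) == 128, i.e. token id 32000 - 128 = 31872. Decoding reverses this:
                  32000 - 31872 = 128, minus 1 yields bin-center index 127, and bin_centers[127] == 0.0.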
64 | """ 65 | discretized_actions = self.tokenizer.vocab_size - action_token_ids 66 | discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) 67 | 68 | return self.bin_centers[discretized_actions] 69 | 70 | @property 71 | def vocab_size(self) -> int: 72 | return self.n_bins 73 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/traj_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | traj_transforms.py 3 | 4 | Contains trajectory transforms used in the orca data pipeline. Trajectory transforms operate on a dictionary 5 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length). 6 | """ 7 | 8 | import logging 9 | from typing import Dict 10 | 11 | import tensorflow as tf 12 | 13 | 14 | def chunk_act_obs(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict: 15 | """ 16 | Chunks actions and observations into the given window_size. 17 | 18 | "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` 19 | observations from the past and the current observation. "action" is given a new axis (at index 1) of size 20 | `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current 21 | action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and 22 | indicates whether an observation should be considered padding (i.e. if it had come from a timestep 23 | before the start of the trajectory). 24 | """ 25 | traj_len = tf.shape(traj["action"])[0] 26 | action_dim = traj["action"].shape[-1] 27 | effective_traj_len = traj_len - future_action_window_size 28 | chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [effective_traj_len, window_size]) + tf.broadcast_to( 29 | tf.range(effective_traj_len)[:, None], [effective_traj_len, window_size] 30 | ) 31 | 32 | action_chunk_indices = tf.broadcast_to( 33 | tf.range(-window_size + 1, 1 + future_action_window_size), 34 | [effective_traj_len, window_size + future_action_window_size], 35 | ) + tf.broadcast_to( 36 | tf.range(effective_traj_len)[:, None], 37 | [effective_traj_len, window_size + future_action_window_size], 38 | ) 39 | 40 | floored_chunk_indices = tf.maximum(chunk_indices, 0) 41 | 42 | goal_timestep = tf.fill([effective_traj_len], traj_len - 1) 43 | 44 | floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) 45 | 46 | traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) 47 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) 48 | 49 | # indicates whether an entire observation is padding 50 | traj["observation"]["pad_mask"] = chunk_indices >= 0 51 | 52 | # Truncate other elements of the trajectory dict 53 | traj["task"] = tf.nest.map_structure(lambda x: tf.gather(x, tf.range(effective_traj_len)), traj["task"]) 54 | traj["dataset_name"] = tf.gather(traj["dataset_name"], tf.range(effective_traj_len)) 55 | traj["absolute_action_mask"] = tf.gather(traj["absolute_action_mask"], tf.range(effective_traj_len)) 56 | 57 | return traj 58 | 59 | 60 | def subsample(traj: Dict, subsample_length: int) -> Dict: 61 | """Subsamples trajectories to the given length.""" 62 | traj_len = tf.shape(traj["action"])[0] 63 | if traj_len > subsample_length: 64 | indices = 
tf.random.shuffle(tf.range(traj_len))[:subsample_length] 65 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) 66 | 67 | return traj 68 | 69 | 70 | def add_pad_mask_dict(traj: Dict) -> Dict: 71 | """ 72 | Adds a dictionary indicating which elements of the observation/task should be treated as padding. 73 | =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} 74 | """ 75 | traj_len = tf.shape(traj["action"])[0] 76 | 77 | for key in ["observation", "task"]: 78 | pad_mask_dict = {} 79 | for subkey in traj[key]: 80 | # Handles "language_instruction", "image_*", and "depth_*" 81 | if traj[key][subkey].dtype == tf.string: 82 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 83 | 84 | # All other keys should not be treated as padding 85 | else: 86 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) 87 | 88 | traj[key]["pad_mask_dict"] = pad_mask_dict 89 | 90 | return traj 91 | -------------------------------------------------------------------------------- /prismatic/util/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | torch_utils.py 3 | 4 | General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch. 5 | 6 | Random `set_global_seed` functionality is taken directly from PyTorch-Lighting: 7 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py 8 | 9 | This is pretty important to get right if we're every randomly generating our masks (or prefix dropout) inside our 10 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime 11 | we inject randomness from non-PyTorch sources (e.g., numpy, random)! 12 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ 13 | 14 | Terminology 15 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous! 16 | -> Rank :: Integer index of current process in the total world size 17 | -> Local Rank :: Local index on given node in [0, Devices per Node] 18 | """ 19 | 20 | import os 21 | import random 22 | from typing import Callable, Optional 23 | 24 | import numpy as np 25 | import torch 26 | 27 | # === Randomness === 28 | 29 | 30 | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]: 31 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`""" 32 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!" 33 | 34 | # Set Seed as an Environment Variable 35 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | 40 | return worker_init_function if get_worker_init_fn else None 41 | 42 | 43 | def worker_init_function(worker_id: int) -> None: 44 | """ 45 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: 46 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 47 | 48 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that 49 | you can run iterative splitting on to get new (predictable) randomness. 50 | 51 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. 
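A typical wiring, shown only as a sketch (`MyDataset` is a placeholder; assumes `LOCAL_RANK` is set in the
    environment, as it is under `torchrun`):

        from torch.utils.data import DataLoader

        worker_init_fn = set_global_seed(7, get_worker_init_fn=True)
        loader = DataLoader(MyDataset(), num_workers=4, worker_init_fn=worker_init_fn)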
52 | """ 53 | # Get current `rank` (if running distributed) and `process_seed` 54 | global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed() 55 | 56 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: 57 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness 58 | base_seed = process_seed - worker_id 59 | 60 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... 61 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) 62 | 63 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! 64 | np.random.seed(seed_seq.generate_state(4)) 65 | 66 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random 67 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2) 68 | 69 | # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 70 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) 71 | 72 | # Use 128 Bits for `random`, but express as integer instead of as an array 73 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() 74 | random.seed(random_seed) 75 | 76 | 77 | # === BFloat16 Support === 78 | 79 | 80 | def check_bloat16_supported() -> bool: 81 | try: 82 | import packaging.version 83 | import torch.cuda.nccl as nccl 84 | import torch.distributed as dist 85 | 86 | return ( 87 | (torch.version.cuda is not None) 88 | and torch.cuda.is_bf16_supported() 89 | and (packaging.version.parse(torch.version.cuda).release >= (11, 0)) 90 | and dist.is_nccl_available() 91 | and (nccl.version() >= (2, 10)) 92 | ) 93 | 94 | except Exception: 95 | return False 96 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/obs_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | obs_transforms.py 3 | 4 | Contains observation-level transforms used in the orca data pipeline. 5 | 6 | These transforms operate on the "observation" dictionary, and are applied at a per-frame level. 
7 | """ 8 | 9 | from typing import Dict, Tuple, Union 10 | 11 | import dlimp as dl 12 | import tensorflow as tf 13 | from absl import logging 14 | 15 | 16 | # ruff: noqa: B023 17 | def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict: 18 | """Augments images, skipping padding images.""" 19 | image_names = {key[6:] for key in obs if key.startswith("image_")} 20 | 21 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed 22 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image 23 | # name to augmentation dict) 24 | if "augment_order" in augment_kwargs: 25 | augment_kwargs = {name: augment_kwargs for name in image_names} 26 | 27 | for i, name in enumerate(image_names): 28 | if name not in augment_kwargs: 29 | continue 30 | kwargs = augment_kwargs[name] 31 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") 32 | obs[f"image_{name}"] = tf.cond( 33 | obs["pad_mask_dict"][f"image_{name}"], 34 | lambda: dl.transforms.augment_image( 35 | obs[f"image_{name}"], 36 | **kwargs, 37 | seed=seed + i, # augment each image differently 38 | ), 39 | lambda: obs[f"image_{name}"], # skip padding images 40 | ) 41 | 42 | return obs 43 | 44 | 45 | def decode_and_resize( 46 | obs: Dict, 47 | resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], 48 | depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], 49 | ) -> Dict: 50 | """Decodes images and depth images, and then optionally resizes them.""" 51 | image_names = {key[6:] for key in obs if key.startswith("image_")} 52 | depth_names = {key[6:] for key in obs if key.startswith("depth_")} 53 | 54 | if isinstance(resize_size, tuple): 55 | resize_size = {name: resize_size for name in image_names} 56 | if isinstance(depth_resize_size, tuple): 57 | depth_resize_size = {name: depth_resize_size for name in depth_names} 58 | 59 | for name in image_names: 60 | if name not in resize_size: 61 | logging.warning( 62 | f"No resize_size was provided for image_{name}. This will result in 1x1 " 63 | "padding images, which may cause errors if you mix padding and non-padding images." 64 | ) 65 | image = obs[f"image_{name}"] 66 | if image.dtype == tf.string: 67 | if tf.strings.length(image) == 0: 68 | # this is a padding image 69 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) 70 | else: 71 | image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8) 72 | elif image.dtype != tf.uint8: 73 | raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}") 74 | if name in resize_size: 75 | image = dl.transforms.resize_image(image, size=resize_size[name]) 76 | obs[f"image_{name}"] = image 77 | 78 | for name in depth_names: 79 | if name not in depth_resize_size: 80 | logging.warning( 81 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " 82 | "padding depth images, which may cause errors if you mix padding and non-padding images." 
83 | ) 84 | depth = obs[f"depth_{name}"] 85 | 86 | if depth.dtype == tf.string: 87 | if tf.strings.length(depth) == 0: 88 | depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32) 89 | else: 90 | depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0] 91 | elif depth.dtype != tf.float32: 92 | raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}") 93 | 94 | if name in depth_resize_size: 95 | depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name]) 96 | 97 | obs[f"depth_{name}"] = depth 98 | 99 | return obs 100 | -------------------------------------------------------------------------------- /prismatic/models/vlms/base_vlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_vlm.py 3 | 4 | Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, 5 | and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate 6 | from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, 7 | PALI, Fuyu) in the future. 8 | 9 | We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance 10 | (e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), 11 | prefer Protocol definitions instead. 12 | """ 13 | 14 | from __future__ import annotations 15 | 16 | from abc import ABC, abstractmethod 17 | from pathlib import Path 18 | from typing import Callable, List, Optional 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import GenerationMixin, PretrainedConfig 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from prismatic.models.backbones.llm import LLMBackbone 26 | from prismatic.models.backbones.llm.prompting import PromptBuilder 27 | from prismatic.models.backbones.vision import VisionBackbone 28 | 29 | 30 | # === Abstract Base Class for arbitrary Vision-Language Models === 31 | class VLM(nn.Module, GenerationMixin, ABC): 32 | def __init__( 33 | self, 34 | model_family: str, 35 | model_id: str, 36 | vision_backbone: VisionBackbone, 37 | llm_backbone: LLMBackbone, 38 | enable_mixed_precision_training: bool = True, 39 | ) -> None: 40 | super().__init__() 41 | self.model_family, self.model_id = model_family, model_id 42 | self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone 43 | self.enable_mixed_precision_training = enable_mixed_precision_training 44 | 45 | # Instance Attributes for a generic VLM 46 | self.all_module_keys, self.trainable_module_keys = None, None 47 | 48 | # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === 49 | self.generation_config = self.llm_backbone.llm.generation_config 50 | self.main_input_name = "input_ids" 51 | 52 | @property 53 | def device(self) -> torch.device: 54 | """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" 55 | return next(self.parameters()).device 56 | 57 | @classmethod 58 | @abstractmethod 59 | def from_pretrained( 60 | cls, 61 | pretrained_checkpoint: Path, 62 | model_family: str, 63 | model_id: str, 64 | vision_backbone: VisionBackbone, 65 | llm_backbone: LLMBackbone, 66 | **kwargs: str, 67 | ) -> VLM: ... 
68 | 69 | @abstractmethod 70 | def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... 71 | 72 | @abstractmethod 73 | def freeze_backbones(self, stage: str) -> None: ... 74 | 75 | @abstractmethod 76 | def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... 77 | 78 | @abstractmethod 79 | def get_fsdp_wrapping_policy(self) -> Callable: ... 80 | 81 | @abstractmethod 82 | def forward( 83 | self, 84 | input_ids: Optional[torch.LongTensor] = None, 85 | attention_mask: Optional[torch.Tensor] = None, 86 | pixel_values: Optional[torch.FloatTensor] = None, 87 | labels: Optional[torch.LongTensor] = None, 88 | inputs_embeds: Optional[torch.FloatTensor] = None, 89 | past_key_values: Optional[List[torch.FloatTensor]] = None, 90 | use_cache: Optional[bool] = None, 91 | output_attentions: Optional[bool] = None, 92 | output_hidden_states: Optional[bool] = None, 93 | return_dict: Optional[bool] = None, 94 | multimodal_indices: Optional[torch.LongTensor] = None, 95 | ) -> CausalLMOutputWithPast: ... 96 | 97 | # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === 98 | @staticmethod 99 | def can_generate() -> bool: 100 | return True 101 | 102 | @property 103 | def config(self) -> PretrainedConfig: 104 | return self.llm_backbone.llm.config 105 | 106 | # => Beam Search Utility 107 | def _reorder_cache(self, past_key_values, beam_idx): 108 | return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) 109 | -------------------------------------------------------------------------------- /config_nav/mbra_and_dataset_config.yaml: -------------------------------------------------------------------------------- 1 | # parameters for MBRA 2 | context_size: 5 #history of image 3 | len_traj_pred: 8 #action length 4 | learn_angle: True 5 | obs_encoder: efficientnet-b0 6 | obs_encoding_size: 1024 7 | late_fusion: False 8 | mha_num_attention_heads: 4 9 | mha_num_attention_layers: 4 10 | mha_ff_dim_factor: 4 11 | 12 | # parameters for training dataset 13 | image_size: [96, 96] # width, height 14 | context_type: temporal 15 | normalize: True 16 | num_workers: 10 17 | 18 | #For dataset path and setting 19 | datasets_CAST: 20 | path: /your_cast_data_location/ 21 | 22 | datasets_frodobots: 23 | root: /your_frodobots_data_location 24 | horizon_short: 20 25 | horizon_long: 100 26 | frodobot: 27 | train: 11434305 28 | test: 11434305 29 | negative_mining: True 30 | 31 | datasets_bdd: 32 | train: /your_data_location/data_splits/bdd/train/ 33 | test: /your_data_location/data_splits/bdd/test/ 34 | image: /your_bdd_data_location/BDD_dataset 35 | pickle: /your_bdd_data_location/BDD_dataset_pickle 36 | backside: False 37 | aug_seq: False 38 | only_front: True 39 | waypoint_spacing: 0.12 40 | 41 | datasets_gnm: 42 | distance: 43 | min_dist_cat: 0 44 | max_dist_cat: 20 45 | action: 46 | min_dist_cat: 3 47 | max_dist_cat: 20 48 | 49 | recon: 50 | data_folder: /your_gnm_data_location/recon 51 | train: /your_data_location/data_splits/recon/train/ 52 | test: /your_data_location/data_splits/recon/test/ 53 | end_slack: 3 54 | goals_per_obs: 1 55 | negative_mining: True # negative mining from the ViNG paper (Shah et al.) 
56 | go_stanford: 57 | data_folder: /your_gnm_data_location/go_stanford # datasets/stanford_go_new 58 | train: /your_data_location/data_splits/go_stanford/train/ 59 | test: /your_data_location/data_splits/go_stanford/test/ 60 | end_slack: 0 61 | goals_per_obs: 2 62 | negative_mining: True 63 | cory_hall: 64 | data_folder: /your_gnm_data_location/cory_hall 65 | train: /your_data_location/data_splits/cory_hall/train/ 66 | test: /your_data_location/data_splits/cory_hall/test/ 67 | end_slack: 3 68 | goals_per_obs: 1 69 | negative_mining: True 70 | tartan_drive: 71 | data_folder: /your_gnm_data_location/tartan_drive 72 | train: /your_data_location/data_splits/tartan_drive/train/ 73 | test: /your_data_location/data_splits/tartan_drive/test/ 74 | end_slack: 3 75 | goals_per_obs: 1 76 | negative_mining: True 77 | sacson: 78 | data_folder: /your_gnm_data_location/sacson 79 | train: /your_data_location/data_splits/sacson/train/ 80 | test: /your_data_location/data_splits/sacson/test/ 81 | end_slack: 3 82 | goals_per_obs: 1 83 | negative_mining: True 84 | seattle: 85 | data_folder: /your_gnm_data_location/seattle/ 86 | train: /your_data_location/data_splits/seattle/train/ 87 | test: /your_data_location/data_splits/seattle/test/ 88 | end_slack: 0 89 | goals_per_obs: 1 90 | negative_mining: True 91 | scand: 92 | data_folder: /your_gnm_data_location/scand/ 93 | train: /your_data_location/data_splits/scand/train/ 94 | test: /your_data_location/data_splits/scand/test/ 95 | end_slack: 0 96 | goals_per_obs: 1 97 | negative_mining: True 98 | 99 | datasets_lelan: 100 | go_stanford4: 101 | train: /your_data_location/data_splits/lelan_gs4/train/ 102 | test: /your_data_location/data_splits/lelan_gs4/test/ 103 | image: /your_lelan_data_location/dataset_LeLaN_gs4/image/ 104 | pickle: /your_lelan_data_location/dataset_LeLaN_gs4/pickle_nomad/ 105 | backside: False 106 | aug_seq: False 107 | only_front: False 108 | 109 | sacson: 110 | train: /your_data_location/data_splits/lelan_sacson/train/ 111 | test: /your_data_location/data_splits/lelan_sacson/test/ 112 | image: /your_lelan_data_location/dataset_LeLaN_sacson/ 113 | pickle: /your_lelan_data_location/dataset_LeLaN_sacson/ 114 | backside: False 115 | aug_seq: False 116 | only_front: False 117 | 118 | go_stanford2: 119 | train: /your_data_location/data_splits/lelan_gs2/train/ 120 | test: /your_data_location/data_splits/lelan_gs2/test/ 121 | image: /your_lelan_data_location/dataset_LeLaN_gs2/ 122 | pickle: /your_lelan_data_location/dataset_LeLaN_gs2/ 123 | backside: False 124 | aug_seq: False 125 | only_front: False 126 | 127 | humanw: 128 | train: /your_data_location/data_splits/lelan_humanw/train/ 129 | test: /your_data_location/data_splits/lelan_humanw/test/ 130 | image: /your_lelan_data_location/dataset_LeLaN_humanwalk/ 131 | pickle: /your_lelan_data_location/dataset_LeLaN_humanwalk/ 132 | backside: False 133 | aug_seq: False 134 | only_front: True 135 | 136 | youtube: 137 | train: /your_data_location/data_splits/lelan_youtube/train/ 138 | test: /your_data_location/data_splits/lelan_youtube/test/ 139 | image: /your_lelan_data_location/dataset_LeLaN_youtube/ 140 | pickle: /your_data_location/pickle_youtube/ 141 | backside: False 142 | aug_seq: False 143 | only_front: True 144 | -------------------------------------------------------------------------------- /prismatic/conf/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | datasets.py 3 | 4 | Draccus Dataclass Definition for a DatasetConfig object, with various 
registered subclasses for each dataset variant 5 | and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes: 6 | - Dataset Variant (Identifier) --> e.g., "llava-v15" 7 | - Align Stage Dataset Components (annotations, images) 8 | - Finetune Stage Dataset Components (annotations, images) 9 | - Dataset Root Directory (Path) 10 | """ 11 | 12 | from dataclasses import dataclass 13 | from enum import Enum, unique 14 | from pathlib import Path 15 | from typing import Tuple 16 | 17 | from draccus import ChoiceRegistry 18 | 19 | 20 | @dataclass 21 | class DatasetConfig(ChoiceRegistry): 22 | # fmt: off 23 | dataset_id: str # Unique ID that fully specifies a dataset variant 24 | 25 | # Dataset Components for each Stage in < align | finetune > 26 | align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage 27 | finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage 28 | 29 | dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root 30 | # fmt: on 31 | 32 | 33 | # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models) 34 | @dataclass 35 | class LLaVa_V15_Config(DatasetConfig): 36 | dataset_id: str = "llava-v15" 37 | 38 | align_stage_components: Tuple[Path, Path] = ( 39 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 40 | Path("download/llava-laion-cc-sbu-558k/"), 41 | ) 42 | finetune_stage_components: Tuple[Path, Path] = ( 43 | Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"), 44 | Path("download/llava-v1.5-instruct/"), 45 | ) 46 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 47 | 48 | 49 | # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training) 50 | @dataclass 51 | class LLaVa_Multimodal_Only_Config(DatasetConfig): 52 | dataset_id: str = "llava-multimodal" 53 | 54 | align_stage_components: Tuple[Path, Path] = ( 55 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 56 | Path("download/llava-laion-cc-sbu-558k/"), 57 | ) 58 | finetune_stage_components: Tuple[Path, Path] = ( 59 | Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"), 60 | Path("download/llava-v1.5-instruct/"), 61 | ) 62 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 63 | 64 | 65 | # LLaVa-v15 + LVIS-Instruct-4V 66 | @dataclass 67 | class LLaVa_LVIS4V_Config(DatasetConfig): 68 | dataset_id: str = "llava-lvis4v" 69 | 70 | align_stage_components: Tuple[Path, Path] = ( 71 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 72 | Path("download/llava-laion-cc-sbu-558k/"), 73 | ) 74 | finetune_stage_components: Tuple[Path, Path] = ( 75 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"), 76 | Path("download/llava-v1.5-instruct/"), 77 | ) 78 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 79 | 80 | 81 | # LLaVa-v15 + LRV-Instruct 82 | @dataclass 83 | class LLaVa_LRV_Config(DatasetConfig): 84 | dataset_id: str = "llava-lrv" 85 | 86 | align_stage_components: Tuple[Path, Path] = ( 87 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 88 | Path("download/llava-laion-cc-sbu-558k/"), 89 | ) 90 | finetune_stage_components: Tuple[Path, Path] = ( 91 | Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"), 92 | Path("download/llava-v1.5-instruct/"), 93 | ) 94 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 95 | 96 | 97 | # 
LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct 98 | @dataclass 99 | class LLaVa_LVIS4V_LRV_Config(DatasetConfig): 100 | dataset_id: str = "llava-lvis4v-lrv" 101 | 102 | align_stage_components: Tuple[Path, Path] = ( 103 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 104 | Path("download/llava-laion-cc-sbu-558k/"), 105 | ) 106 | finetune_stage_components: Tuple[Path, Path] = ( 107 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"), 108 | Path("download/llava-v1.5-instruct/"), 109 | ) 110 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 111 | 112 | 113 | # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! === 114 | @unique 115 | class DatasetRegistry(Enum): 116 | # === LLaVa v1.5 === 117 | LLAVA_V15 = LLaVa_V15_Config 118 | 119 | LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config 120 | 121 | LLAVA_LVIS4V = LLaVa_LVIS4V_Config 122 | LLAVA_LRV = LLaVa_LRV_Config 123 | 124 | LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config 125 | 126 | @property 127 | def dataset_id(self) -> str: 128 | return self.value.dataset_id 129 | 130 | 131 | # Register Datasets in Choice Registry 132 | for dataset_variant in DatasetRegistry: 133 | DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value) 134 | -------------------------------------------------------------------------------- /prismatic/models/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports 5 | individual functions for clear control flow. 6 | """ 7 | 8 | from typing import Optional, Tuple 9 | 10 | from transformers import PreTrainedTokenizerBase 11 | 12 | from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone 13 | from prismatic.models.backbones.vision import ( 14 | CLIPViTBackbone, 15 | DinoCLIPViTBackbone, 16 | DinoSigLIPViTBackbone, 17 | DinoV2ViTBackbone, 18 | ImageTransform, 19 | IN1KViTBackbone, 20 | SigLIPViTBackbone, 21 | VisionBackbone, 22 | ) 23 | from prismatic.models.vlms import PrismaticVLM 24 | 25 | # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === 26 | # fmt: off 27 | 28 | # === Vision Backbone Registry === 29 | VISION_BACKBONES = { 30 | # === 224px Backbones === 31 | "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 32 | "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 33 | "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, 34 | "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, 35 | "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 36 | 37 | # === Assorted CLIP Backbones === 38 | "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 39 | "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 40 | 41 | # === Assorted SigLIP Backbones === 42 | "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 43 | "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, 44 | "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 45 | "siglip-vit-so400m-384px": {"cls": 
SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 46 | 47 | # === Fused Backbones === 48 | "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 49 | "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 50 | } 51 | 52 | 53 | # === Language Model Registry === 54 | LLM_BACKBONES = { 55 | # === LLaMa-2 Pure (Non-Chat) Backbones === 56 | "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 57 | "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 58 | 59 | # === LLaMa-2 Chat Backbones === 60 | "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 61 | "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 62 | 63 | # === Vicuna-v1.5 Backbones === 64 | "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 65 | "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 66 | 67 | # === Mistral v0.1 Backbones === 68 | "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}}, 69 | "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}}, 70 | 71 | # === Phi-2 Backbone === 72 | "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}, 73 | } 74 | 75 | # fmt: on 76 | 77 | 78 | def get_vision_backbone_and_transform( 79 | vision_backbone_id: str, image_resize_strategy: str 80 | ) -> Tuple[VisionBackbone, ImageTransform]: 81 | """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" 82 | if vision_backbone_id in VISION_BACKBONES: 83 | vision_cfg = VISION_BACKBONES[vision_backbone_id] 84 | vision_backbone: VisionBackbone = vision_cfg["cls"]( 85 | vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"] 86 | ) 87 | image_transform = vision_backbone.get_image_transform() 88 | return vision_backbone, image_transform 89 | 90 | else: 91 | raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") 92 | 93 | 94 | def get_llm_backbone_and_tokenizer( 95 | llm_backbone_id: str, 96 | llm_max_length: int = 2048, 97 | hf_token: Optional[str] = None, 98 | inference_mode: bool = False, 99 | ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: 100 | if llm_backbone_id in LLM_BACKBONES: 101 | llm_cfg = LLM_BACKBONES[llm_backbone_id] 102 | llm_backbone: LLMBackbone = llm_cfg["cls"]( 103 | llm_backbone_id, 104 | llm_max_length=llm_max_length, 105 | hf_token=hf_token, 106 | inference_mode=inference_mode, 107 | **llm_cfg["kwargs"], 108 | ) 109 | tokenizer = llm_backbone.get_tokenizer() 110 | return llm_backbone, tokenizer 111 | 112 | else: 113 | raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") 114 | 115 | 116 | def get_vlm( 117 | model_id: str, 118 | arch_specifier: str, 119 | vision_backbone: VisionBackbone, 120 | llm_backbone: LLMBackbone, 121 | enable_mixed_precision_training: bool = True, 122 | ) -> PrismaticVLM: 123 | """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" 124 | return PrismaticVLM( 125 | model_id, 126 | vision_backbone, 127 | llm_backbone, 128 | enable_mixed_precision_training=enable_mixed_precision_training, 129 | arch_specifier=arch_specifier, 130 | ) 131 | -------------------------------------------------------------------------------- /prismatic/overwatch/overwatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | overwatch.py 3 | 4 | Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. 
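A minimal usage sketch (illustrative):

    overwatch = initialize_overwatch(__name__)
    overwatch.info("Loading checkpoint", ctx_level=1)  # rendered with an indented "|=>" prefix

When launched under `torchrun`/`accelerate` (i.e., WORLD_SIZE is set), `initialize_overwatch` returns a
DistributedOverwatch that logs INFO only on the main process; otherwise it returns a PureOverwatch.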
5 | """ 6 | 7 | import logging 8 | import logging.config 9 | import os 10 | from contextlib import nullcontext 11 | from logging import LoggerAdapter 12 | from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union 13 | 14 | # Overwatch Default Format String 15 | RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" 16 | 17 | # Set Logging Configuration 18 | LOG_CONFIG = { 19 | "version": 1, 20 | "disable_existing_loggers": True, 21 | "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, 22 | "handlers": { 23 | "console": { 24 | "class": "rich.logging.RichHandler", 25 | "formatter": "simple-console", 26 | "markup": True, 27 | "rich_tracebacks": True, 28 | "show_level": True, 29 | "show_path": True, 30 | "show_time": True, 31 | } 32 | }, 33 | "root": {"level": "INFO", "handlers": ["console"]}, 34 | } 35 | logging.config.dictConfig(LOG_CONFIG) 36 | 37 | 38 | # === Custom Contextual Logging Logic === 39 | class ContextAdapter(LoggerAdapter): 40 | CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} 41 | 42 | def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: 43 | ctx_level = kwargs.pop("ctx_level", 0) 44 | return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs 45 | 46 | 47 | class DistributedOverwatch: 48 | def __init__(self, name: str) -> None: 49 | """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" 50 | from accelerate import PartialState 51 | 52 | # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` 53 | # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! 54 | self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() 55 | 56 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 57 | self.debug = self.logger.debug 58 | self.info = self.logger.info 59 | self.warning = self.logger.warning 60 | self.error = self.logger.error 61 | self.critical = self.logger.critical 62 | 63 | # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
64 | self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) 65 | 66 | @property 67 | def rank_zero_only(self) -> Callable[..., Any]: 68 | return self.distributed_state.on_main_process 69 | 70 | @property 71 | def local_zero_only(self) -> Callable[..., Any]: 72 | return self.distributed_state.on_local_main_process 73 | 74 | @property 75 | def rank_zero_first(self) -> Callable[..., Any]: 76 | return self.distributed_state.main_process_first 77 | 78 | @property 79 | def local_zero_first(self) -> Callable[..., Any]: 80 | return self.distributed_state.local_main_process_first 81 | 82 | def is_rank_zero(self) -> bool: 83 | return self.distributed_state.is_main_process 84 | 85 | def rank(self) -> int: 86 | return self.distributed_state.process_index 87 | 88 | def local_rank(self) -> int: 89 | return self.distributed_state.local_process_index 90 | 91 | def world_size(self) -> int: 92 | return self.distributed_state.num_processes 93 | 94 | 95 | class PureOverwatch: 96 | def __init__(self, name: str) -> None: 97 | """Initializer for an Overwatch object that just wraps logging.""" 98 | self.logger = ContextAdapter(logging.getLogger(name), extra={}) 99 | 100 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 101 | self.debug = self.logger.debug 102 | self.info = self.logger.info 103 | self.warning = self.logger.warning 104 | self.error = self.logger.error 105 | self.critical = self.logger.critical 106 | 107 | # Logging Defaults =>> INFO 108 | self.logger.setLevel(logging.INFO) 109 | 110 | @staticmethod 111 | def get_identity_ctx() -> Callable[..., Any]: 112 | def identity(fn: Callable[..., Any]) -> Callable[..., Any]: 113 | return fn 114 | 115 | return identity 116 | 117 | @property 118 | def rank_zero_only(self) -> Callable[..., Any]: 119 | return self.get_identity_ctx() 120 | 121 | @property 122 | def local_zero_only(self) -> Callable[..., Any]: 123 | return self.get_identity_ctx() 124 | 125 | @property 126 | def rank_zero_first(self) -> Callable[..., Any]: 127 | return nullcontext 128 | 129 | @property 130 | def local_zero_first(self) -> Callable[..., Any]: 131 | return nullcontext 132 | 133 | @staticmethod 134 | def is_rank_zero() -> bool: 135 | return True 136 | 137 | @staticmethod 138 | def rank() -> int: 139 | return 0 140 | 141 | @staticmethod 142 | def world_size() -> int: 143 | return 1 144 | 145 | 146 | def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: 147 | return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) 148 | -------------------------------------------------------------------------------- /prismatic/models/vlas/openvla.py: -------------------------------------------------------------------------------- 1 | """ 2 | openvla.py 3 | 4 | PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around 5 | discretizing actions with the ActionTokenizer. 
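A minimal inference sketch (illustrative; `vla` stands for an already-loaded OpenVLA instance, and "bridge_orig" is a
placeholder for a dataset key present in its `norm_stats`):

    from PIL import Image

    image = Image.open("inference/current_img.jpg")
    action = vla.predict_action(image, "move toward the door", unnorm_key="bridge_orig")  # -> np.ndarray of deltas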
6 | """ 7 | 8 | from typing import Dict, List, Optional 9 | 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | from transformers import LlamaTokenizerFast 14 | 15 | from prismatic.models.vlms.prismatic import PrismaticVLM 16 | from prismatic.overwatch import initialize_overwatch 17 | from prismatic.vla.action_tokenizer import ActionTokenizer 18 | 19 | # Initialize Overwatch =>> Wraps `logging.Logger` 20 | overwatch = initialize_overwatch(__name__) 21 | 22 | 23 | class OpenVLA(PrismaticVLM): 24 | def __init__( 25 | self, 26 | *args, 27 | norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], 28 | action_tokenizer: ActionTokenizer, 29 | **kwargs, 30 | ) -> None: 31 | super().__init__(*args, **kwargs) 32 | self.norm_stats = norm_stats 33 | self.action_tokenizer = action_tokenizer 34 | 35 | @torch.inference_mode() 36 | def predict_action( 37 | self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str 38 | ) -> np.ndarray: 39 | """ 40 | Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes). 41 | 42 | @param image: PIL Image as [height, width, 3] 43 | @param instruction: Task instruction string 44 | @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model 45 | was trained only on a single dataset, and retrieves those statistics. 46 | 47 | @return Unnormalized (continuous) action vector --> end-effector deltas. 48 | """ 49 | image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer 50 | 51 | # Build VLA Prompt 52 | prompt_builder = self.get_prompt_builder() 53 | prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?") 54 | prompt_text = prompt_builder.get_prompt() 55 | 56 | # Prepare Inputs 57 | input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) 58 | if isinstance(tokenizer, LlamaTokenizerFast): 59 | # If the special empty token ('') does not already appear after the colon (':') token in the prompt 60 | # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time 61 | if not torch.all(input_ids[:, -1] == 29871): 62 | input_ids = torch.cat( 63 | (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 64 | ) 65 | else: 66 | raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}") 67 | 68 | # Preprocess Image 69 | pixel_values = image_transform(image) 70 | if isinstance(pixel_values, torch.Tensor): 71 | pixel_values = pixel_values[None, ...].to(self.device) 72 | elif isinstance(pixel_values, dict): 73 | pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} 74 | else: 75 | raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") 76 | 77 | # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` 78 | autocast_dtype = self.llm_backbone.half_precision_dtype 79 | with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): 80 | # fmt: off 81 | generated_ids = super(PrismaticVLM, self).generate( 82 | input_ids=input_ids, # Shape: [1, seq] 83 | pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, ...] 
84 | max_new_tokens=self.get_action_dim(unnorm_key), 85 | **kwargs 86 | ) 87 | # fmt: on 88 | 89 | # Extract predicted action tokens and translate into (normalized) continuous actions 90 | predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :] 91 | normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy()) 92 | 93 | # Un-normalize Actions 94 | action_norm_stats = self.get_action_stats(unnorm_key) 95 | mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) 96 | action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) 97 | actions = np.where( 98 | mask, 99 | 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low, 100 | normalized_actions, 101 | ) 102 | 103 | return actions 104 | 105 | @staticmethod 106 | def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str: 107 | if unnorm_key is None: 108 | assert len(norm_stats) == 1, ( 109 | f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following " 110 | f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}" 111 | ) 112 | unnorm_key = next(iter(norm_stats.keys())) 113 | 114 | # Error Handling 115 | assert ( 116 | unnorm_key in norm_stats 117 | ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}" 118 | 119 | return unnorm_key 120 | 121 | def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: 122 | """Dimensionality of the policy's action space.""" 123 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) 124 | 125 | return len(self.norm_stats[unnorm_key]["action"]["q01"]) 126 | 127 | def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict: 128 | """Dimensionality of the policy's action space.""" 129 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) 130 | 131 | return self.norm_stats[unnorm_key]["action"] 132 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for 5 | clear control flow. 
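A sketch of the intended call pattern (the dataset names are placeholders for entries in OXE_DATASET_CONFIGS /
OXE_NAMED_MIXTURES):

    per_dataset_kwargs, sampling_weights = get_oxe_dataset_kwargs_and_weights(
        data_root_dir=Path("/data/oxe"),
        mixture_spec=[("bridge_orig", 1.0), ("droid", 0.5)],
        load_camera_views=("primary",),
    )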
6 | """ 7 | 8 | from copy import deepcopy 9 | from pathlib import Path 10 | from typing import Any, Dict, List, Tuple 11 | 12 | from prismatic.overwatch import initialize_overwatch 13 | from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, STOP_INDEX 14 | from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding 15 | from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS 16 | 17 | # Initialize Overwatch =>> Wraps `logging.Logger` 18 | overwatch = initialize_overwatch(__name__) 19 | 20 | 21 | def make_oxe_dataset_kwargs( 22 | dataset_name: str, 23 | data_root_dir: Path, 24 | load_camera_views: Tuple[str] = ("primary",), 25 | load_depth: bool = False, 26 | load_proprio: bool = True, 27 | load_language: bool = True, 28 | action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, 29 | ) -> Dict[str, Any]: 30 | """Generates config (kwargs) for given dataset from Open-X Embodiment.""" 31 | dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name]) 32 | if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL]: 33 | raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 & JOINT_POS_BIMANUAL actions supported!") 34 | 35 | # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute! 36 | # Normalize all action dimensions *except* the gripper 37 | if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: 38 | dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] 39 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] 40 | elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6: 41 | dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True] 42 | dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False] 43 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL: 44 | dataset_kwargs["absolute_action_mask"] = [True] * 14 45 | dataset_kwargs["action_normalization_mask"] = [True] * 14 46 | dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type 47 | 48 | # Adjust Loaded Camera Views 49 | if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0: 50 | raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`") 51 | 52 | # Filter 53 | dataset_kwargs["image_obs_keys"] = { 54 | k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views 55 | } 56 | dataset_kwargs["depth_obs_keys"] = { 57 | k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views 58 | } 59 | 60 | # Eliminate Unnecessary Keys 61 | dataset_kwargs.pop("state_encoding") 62 | dataset_kwargs.pop("action_encoding") 63 | if not load_depth: 64 | dataset_kwargs.pop("depth_obs_keys") 65 | if not load_proprio: 66 | dataset_kwargs.pop("state_obs_keys") 67 | 68 | # Load Language 69 | if load_language: 70 | dataset_kwargs["language_key"] = "language_instruction" 71 | 72 | # Specify Standardization Transform 73 | dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[dataset_name] 74 | 75 | # Add any aux arguments 76 | if "aux_kwargs" in dataset_kwargs: 77 | dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs")) 78 | 79 | return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs} 80 | 81 | 82 | def get_oxe_dataset_kwargs_and_weights( 
83 | data_root_dir: Path, 84 | mixture_spec: List[Tuple[str, float]], 85 | load_camera_views: Tuple[str] = ("primary",), 86 | load_depth: bool = False, 87 | load_proprio: bool = True, 88 | load_language: bool = True, 89 | action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, 90 | ) -> Tuple[Dict[str, Any], List[float]]: 91 | """ 92 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs 93 | (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`. 94 | 95 | :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X) 96 | :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES` 97 | :param load_camera_views: Camera views to load; see `oxe.dataset_configs.py` for available views. 98 | :param load_depth: Load depth information in addition to camera RGB. 99 | :param load_proprio: Load proprioceptive state. 100 | :param load_language: Load language instructions. 101 | :param action_proprio_normalization_type: Normalization scheme to use for proprioceptive actions. 102 | 103 | return: Tuple of (per_dataset_kwargs, sampling_weights) 104 | """ 105 | included_datasets, filtered_mixture_spec = set(), [] 106 | for d_name, d_weight in mixture_spec: 107 | if d_name in included_datasets: 108 | overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`") 109 | continue 110 | 111 | included_datasets.add(d_name) 112 | filtered_mixture_spec.append((d_name, d_weight)) 113 | 114 | # Assemble Dataset Config (kwargs) and Weights 115 | per_dataset_kwargs, sampling_weights = [], [] 116 | for d_name, d_weight in filtered_mixture_spec: 117 | try: 118 | per_dataset_kwargs.append( 119 | make_oxe_dataset_kwargs( 120 | d_name, 121 | data_root_dir, 122 | load_camera_views, 123 | load_depth, 124 | load_proprio, 125 | load_language, 126 | action_proprio_normalization_type, 127 | ) 128 | ) 129 | sampling_weights.append(d_weight) 130 | 131 | except ValueError as e: 132 | overwatch.warning(f"Skipping `{d_name}` due to Error: {e}") 133 | 134 | return per_dataset_kwargs, sampling_weights 135 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py: -------------------------------------------------------------------------------- 1 | """Episode transforms for DROID dataset.""" 2 | 3 | from typing import Any, Dict 4 | 5 | import tensorflow as tf 6 | import tensorflow_graphics.geometry.transformation as tfg 7 | 8 | 9 | def rmat_to_euler(rot_mat): 10 | return tfg.euler.from_rotation_matrix(rot_mat) 11 | 12 | 13 | def euler_to_rmat(euler): 14 | return tfg.rotation_matrix_3d.from_euler(euler) 15 | 16 | 17 | def invert_rmat(rot_mat): 18 | return tfg.rotation_matrix_3d.inverse(rot_mat) 19 | 20 | 21 | def rotmat_to_rot6d(mat): 22 | """ 23 | Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). 24 | Args: 25 | mat: rotation matrix 26 | 27 | Returns: 6d vector (first two rows of rotation matrix) 28 | 29 | """ 30 | r6 = mat[..., :2, :] 31 | r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] 32 | r6_flat = tf.concat([r6_0, r6_1], axis=-1) 33 | return r6_flat 34 | 35 | 36 | def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): 37 | """ 38 | Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. 
39 | Args: 40 | velocity: 6d velocity action (3 x translation, 3 x rotation) 41 | wrist_in_robot_frame: 6d pose of the end-effector in robot base frame 42 | 43 | Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) 44 | 45 | """ 46 | R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) 47 | R_frame_inv = invert_rmat(R_frame) 48 | 49 | # world to wrist: dT_pi = R^-1 dT_rbt 50 | vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] 51 | 52 | # world to wrist: dR_pi = R^-1 dR_rbt R 53 | dR = euler_to_rmat(velocity[:, 3:6]) 54 | dR = R_frame_inv @ (dR @ R_frame) 55 | dR_r6 = rotmat_to_rot6d(dR) 56 | return tf.concat([vel_t, dR_r6], axis=-1) 57 | 58 | 59 | def rand_swap_exterior_images(img1, img2): 60 | """ 61 | Randomly swaps the two exterior images (for training with single exterior input). 62 | """ 63 | return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) 64 | 65 | 66 | def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 67 | """ 68 | DROID dataset transformation for actions expressed in *base* frame of the robot. 69 | """ 70 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] 71 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] 72 | 73 | trajectory["action"] = tf.concat( 74 | ( 75 | dt, 76 | dR, 77 | 1 - trajectory["action_dict"]["gripper_position"], 78 | ), 79 | axis=-1, 80 | ) 81 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( 82 | rand_swap_exterior_images( 83 | trajectory["observation"]["exterior_image_1_left"], 84 | trajectory["observation"]["exterior_image_2_left"], 85 | ) 86 | ) 87 | trajectory["observation"]["proprio"] = tf.concat( 88 | ( 89 | trajectory["observation"]["cartesian_position"], 90 | trajectory["observation"]["gripper_position"], 91 | ), 92 | axis=-1, 93 | ) 94 | return trajectory 95 | 96 | 97 | def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 98 | """ 99 | DROID dataset transformation for actions expressed in *wrist* frame of the robot. 100 | """ 101 | wrist_act = velocity_act_to_wrist_frame( 102 | trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] 103 | ) 104 | trajectory["action"] = tf.concat( 105 | ( 106 | wrist_act, 107 | trajectory["action_dict"]["gripper_position"], 108 | ), 109 | axis=-1, 110 | ) 111 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( 112 | rand_swap_exterior_images( 113 | trajectory["observation"]["exterior_image_1_left"], 114 | trajectory["observation"]["exterior_image_2_left"], 115 | ) 116 | ) 117 | trajectory["observation"]["proprio"] = tf.concat( 118 | ( 119 | trajectory["observation"]["cartesian_position"], 120 | trajectory["observation"]["gripper_position"], 121 | ), 122 | axis=-1, 123 | ) 124 | return trajectory 125 | 126 | 127 | def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 128 | """ 129 | DROID dataset transformation for actions expressed in *base* frame of the robot. 
130 | """ 131 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] 132 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] 133 | trajectory["action"] = tf.concat( 134 | ( 135 | dt, 136 | dR, 137 | 1 - trajectory["action_dict"]["gripper_position"], 138 | ), 139 | axis=-1, 140 | ) 141 | trajectory["observation"]["proprio"] = tf.concat( 142 | ( 143 | trajectory["observation"]["cartesian_position"], 144 | trajectory["observation"]["gripper_position"], 145 | ), 146 | axis=-1, 147 | ) 148 | return trajectory 149 | 150 | 151 | def zero_action_filter(traj: Dict) -> bool: 152 | """ 153 | Filters transitions whose actions are all-0 (only relative actions, no gripper action). 154 | Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". 155 | """ 156 | DROID_Q01 = tf.convert_to_tensor( 157 | [ 158 | -0.7776297926902771, 159 | -0.5803514122962952, 160 | -0.5795090794563293, 161 | -0.6464047729969025, 162 | -0.7041108310222626, 163 | -0.8895104378461838, 164 | ] 165 | ) 166 | DROID_Q99 = tf.convert_to_tensor( 167 | [ 168 | 0.7597932070493698, 169 | 0.5726242214441299, 170 | 0.7351000607013702, 171 | 0.6705610305070877, 172 | 0.6464948207139969, 173 | 0.8897542208433151, 174 | ] 175 | ) 176 | DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 177 | 178 | return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) 179 | -------------------------------------------------------------------------------- /prismatic/extern/hf/configuration_prismatic.py: -------------------------------------------------------------------------------- 1 | """ 2 | configuration_prismatic.py 3 | 4 | HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`. 5 | Default configuration specifies `siglip-224px+7b`. 
6 | """ 7 | 8 | from typing import Any, Dict, List, Optional 9 | 10 | from transformers import PretrainedConfig 11 | from transformers.models.auto import CONFIG_MAPPING 12 | 13 | # === Utilities for Mapping Prismatic names to HF names === 14 | # fmt: off 15 | VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = { 16 | "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224], 17 | 18 | "clip-vit-l-336px": [336], 19 | "siglip-vit-so400m-384px": [384], 20 | 21 | "dinoclip-vit-l-336px": [336, 336], 22 | "dinosiglip-vit-so-224px": [224, 224], 23 | "dinosiglip-vit-so-384px": [384, 384], 24 | } 25 | VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = { 26 | "clip-vit-l": ["vit_large_patch14_clip_224.openai"], 27 | "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"], 28 | 29 | "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"], 30 | "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"], 31 | 32 | "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"], 33 | "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"], 34 | 35 | "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"], 36 | "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"], 37 | "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"], 38 | } 39 | TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = { 40 | "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"], 41 | "dinov2-vit-l": [None], "in1k-vit-l": [None], 42 | "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None], 43 | "dinoclip-vit-l-336px": [None, "quick_gelu"], 44 | "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None] 45 | } 46 | 47 | LLM_BACKBONE_TO_HF_PATH = { 48 | "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf", 49 | "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", 50 | 51 | "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5", 52 | 53 | "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1", 54 | "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", 55 | 56 | "phi-2-3b": "microsoft/phi-2", 57 | } 58 | LLM_BACKBONE_TO_HF_METACLASS = { 59 | "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama", 60 | "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", 61 | 62 | "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral", 63 | 64 | "phi-2-3b": "phi", 65 | } 66 | 67 | VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys()) 68 | VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH) 69 | # fmt: on 70 | 71 | 72 | class PrismaticConfig(PretrainedConfig): 73 | model_type: str = "prismatic" 74 | is_composition: bool = False 75 | 76 | def __init__( 77 | self, 78 | vision_backbone_id: str = "siglip-vit-so400m", 79 | llm_backbone_id: str = "vicuna-v15-7b", 80 | arch_specifier: str = "no-align+gelu-mlp", 81 | use_fused_vision_backbone: Optional[bool] = None, 82 | image_resize_strategy: str = "letterbox", 83 | text_config: Optional[Dict[str, Any]] = None, 84 | llm_max_length: int = 2048, 85 | pad_token_id: int = 32000, 86 | pad_to_multiple_of: int = 64, 87 | output_projector_states: bool = False, 88 | **kwargs: str, 89 | ) -> None: 90 | if vision_backbone_id not in VALID_VISION_BACKBONES: 91 | raise ValueError(f"Vision backbone 
`{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }") 92 | 93 | if llm_backbone_id not in VALID_LLM_BACKBONES: 94 | raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }") 95 | 96 | # Set Prismatic Configuration Fields 97 | self.vision_backbone_id = vision_backbone_id 98 | self.llm_backbone_id = llm_backbone_id 99 | self.arch_specifier = arch_specifier 100 | self.output_projector_states = output_projector_states 101 | 102 | # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing 103 | self.use_fused_vision_backbone = ( 104 | use_fused_vision_backbone 105 | if use_fused_vision_backbone is not None 106 | else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"]) 107 | ) 108 | 109 | self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id] 110 | self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id] 111 | self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id] 112 | self.image_resize_strategy = image_resize_strategy 113 | 114 | self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id] 115 | self.llm_max_length = llm_max_length 116 | self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of 117 | 118 | # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming! 119 | self.text_config = ( 120 | CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config) 121 | if text_config is not None 122 | else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]() 123 | ) 124 | 125 | # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well... 126 | super().__init__(pad_token_id=pad_token_id, **kwargs) 127 | 128 | 129 | class OpenVLAConfig(PrismaticConfig): 130 | model_type: str = "openvla" 131 | 132 | def __init__( 133 | self, 134 | norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None, 135 | n_action_bins: int = 256, 136 | **kwargs: str, 137 | ) -> None: 138 | self.norm_stats, self.n_action_bins = norm_stats, n_action_bins 139 | 140 | super().__init__(**kwargs) 141 | -------------------------------------------------------------------------------- /prismatic/models/action_heads.py: -------------------------------------------------------------------------------- 1 | """Implementations of various action heads, which serve as alternatives to VLM sequential token prediction.""" 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | #from diffusers.schedulers.scheduling_ddim import DDIMScheduler 9 | from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, STOP_INDEX 10 | 11 | 12 | class SinusoidalPositionalEncoding(nn.Module): 13 | """ 14 | Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps. 
15 | 16 | For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,) 17 | Then the output would be a batch of 32 timestep embeddings -> shape (32, D) 18 | 19 | Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py 20 | """ 21 | 22 | def __init__(self, dim): 23 | super().__init__() 24 | self.dim = dim # dimensionality of the positional encoding 25 | 26 | def forward(self, x): 27 | # x: (batch_size,) 28 | device = x.device 29 | assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}" 30 | half_dim = self.dim // 2 31 | exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1) # shape: (D/2,) 32 | emb = torch.exp(exponent) # shape: (D/2,) 33 | emb = x[:, None] * emb[None, :] # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2) 34 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # shape: (batch_size, D) 35 | return emb 36 | 37 | 38 | class MLPResNetBlock(nn.Module): 39 | """One MLP ResNet block with a residual connection.""" 40 | def __init__(self, dim): 41 | super().__init__() 42 | self.dim = dim 43 | self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers 44 | nn.LayerNorm(dim), 45 | nn.Linear(dim, dim), 46 | nn.ReLU(), 47 | ) 48 | 49 | def forward(self, x): 50 | # x: (batch_size, hidden_dim) 51 | # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as 52 | # described here: https://arxiv.org/pdf/2002.04745.pdf 53 | identity = x 54 | x = self.ffn(x) 55 | x = x + identity 56 | return x 57 | 58 | 59 | class MLPResNet(nn.Module): 60 | """MLP with residual connection blocks.""" 61 | def __init__(self, num_blocks, input_dim, hidden_dim, output_dim): 62 | super().__init__() 63 | self.layer_norm1 = nn.LayerNorm(input_dim) 64 | self.fc1 = nn.Linear(input_dim, hidden_dim) 65 | self.relu = nn.ReLU() 66 | self.mlp_resnet_blocks = nn.ModuleList() 67 | for _ in range(num_blocks): 68 | self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) 69 | self.layer_norm2 = nn.LayerNorm(hidden_dim) 70 | self.fc2 = nn.Linear(hidden_dim, output_dim) 71 | 72 | def forward(self, x): 73 | # x: (batch_size, input_dim) 74 | x = self.layer_norm1(x) # shape: (batch_size, input_dim) 75 | x = self.fc1(x) # shape: (batch_size, hidden_dim) 76 | x = self.relu(x) # shape: (batch_size, hidden_dim) 77 | for block in self.mlp_resnet_blocks: 78 | x = block(x) # shape: (batch_size, hidden_dim) 79 | x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) 80 | x = self.fc2(x) # shape: (batch_size, output_dim) 81 | return x 82 | 83 | class MLPResNet_idcat(nn.Module): 84 | """MLP with residual connection blocks.""" 85 | def __init__(self, num_blocks, input_dim, hidden_dim, output_dim): 86 | super().__init__() 87 | self.layer_norm1 = nn.LayerNorm(input_dim) 88 | self.fc1 = nn.Linear(input_dim + 1, hidden_dim) 89 | self.relu = nn.ReLU() 90 | self.mlp_resnet_blocks = nn.ModuleList() 91 | for _ in range(num_blocks): 92 | self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) 93 | self.layer_norm2 = nn.LayerNorm(hidden_dim) 94 | self.fc2 = nn.Linear(hidden_dim, output_dim) 95 | 96 | def forward(self, x, taskid): 97 | # x: (batch_size, input_dim) 98 | x = self.layer_norm1(x) # shape: (batch_size, input_dim) 99 | x = torch.cat((x, taskid.unsqueeze(1).unsqueeze(2).repeat(1,8,1)), axis=2) 100 | x = self.fc1(x) # shape: (batch_size, hidden_dim) 101 | x = self.relu(x) # shape: 
(batch_size, hidden_dim) 102 | for block in self.mlp_resnet_blocks: 103 | x = block(x) # shape: (batch_size, hidden_dim) 104 | x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) 105 | x = self.fc2(x) # shape: (batch_size, output_dim) 106 | return x 107 | 108 | class L1RegressionActionHead_idcat(nn.Module): 109 | """Simple MLP-based action head that generates continuous actions via L1 regression, conditioned on a task ID.""" 110 | def __init__( 111 | self, 112 | input_dim=4096, 113 | hidden_dim=4096, 114 | action_dim=7, 115 | ): 116 | super().__init__() 117 | self.action_dim = action_dim 118 | self.model = MLPResNet_idcat( 119 | num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim 120 | ) 121 | 122 | def predict_action(self, actions_hidden_states, taskid): 123 | # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence 124 | # - shape: (batch_size, chunk_len * action_dim, hidden_dim) 125 | # taskid: per-sample task/modality ID used for conditioning 126 | # - shape: (batch_size,) 127 | batch_size = actions_hidden_states.shape[0] 128 | device = actions_hidden_states.device 129 | rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) 130 | action = self.model(rearranged_actions_hidden_states, taskid) 131 | return action 132 | 133 | class L1RegressionDistHead(nn.Module): 134 | """Simple MLP-based head that predicts a continuous distance value via L1 regression.""" 135 | def __init__( 136 | self, 137 | input_dim=4096, 138 | hidden_dim=4096, 139 | action_dim=1, 140 | ): 141 | super().__init__() 142 | self.action_dim = action_dim 143 | self.model = MLPResNet( 144 | num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim 145 | ) 146 | 147 | def predict_action(self, actions_hidden_states): 148 | # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence 149 | # - shape: (batch_size, chunk_len * action_dim, hidden_dim) 150 | # returns: predicted distance averaged over the action chunk 151 | # - shape: (batch_size,) 152 | batch_size = actions_hidden_states.shape[0] 153 | device = actions_hidden_states.device 154 | rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) 155 | dist = self.model(rearranged_actions_hidden_states).squeeze(2) 156 | dist_ave = dist.mean(dim=1, keepdim=False) 157 | return dist_ave 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OmniVLA: An Omni-Modal Vision-Language-Action Model for Robot Navigation 2 | [![Python](https://img.shields.io/badge/python-3.10-blue)](https://www.python.org) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://omnivla-nav.github.io) 5 | 6 | 7 | [Noriaki Hirose](https://sites.google.com/view/noriaki-hirose/)1, 2, [Catherine Glossop](https://catglossop.github.io/)1, [Dhruv Shah](https://robodhruv.github.io/)3, [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/)1 8 | 9 | 1 UC Berkeley (_Berkeley AI Research_), 2 Toyota Motor North America, 3 Princeton University 10 | 11 | ### Installation 12 | Please set up a conda environment (see instructions in [SETUP.md](SETUP.md)). 13 | 14 | ### Inference 15 | 1.
Download our checkpoints and place them in the repository's root directory. "omnivla-original" contains the OmniVLA checkpoints trained for the paper submission. "omnivla-original-balance" contains the OmniVLA checkpoints trained with the LeLaN dataset's data balance taken into account. "omnivla-finetuned-cast" contains the checkpoints finetuned on the [CAST](https://huggingface.co/datasets/catglossop/CAST-dataset) dataset. 16 | ``` 17 | git clone https://huggingface.co/NHirose/omnivla-original 18 | git clone https://huggingface.co/NHirose/omnivla-original-balance 19 | git clone https://huggingface.co/NHirose/omnivla-finetuned-cast 20 | ``` 21 | 2. Run OmniVLA using a sample current image, goal images, GPS pose, and language prompt. You can view the generated trajectory in the output figure 1_ex.jpg. 22 | ``` 23 | python inference/run_omnivla.py 24 | ``` 25 | 3. Change the goal modality: by default, our code generates actions based on the language prompt. To use a different modality, you can modify the settings around line 560 of run_omnivla.py. 26 | 27 | 4. Run OmniVLA to control a real robot. Modify "run_omnivla.py" to update the robot's state (camera image, GPS signal) and adjust the goal information accordingly. Then, feed the generated velocity commands to your robot. 28 | 29 | 5. To try the checkpoints finetuned on the CAST dataset, update the path and step number in "InferenceConfig" within "run_omnivla.py". 30 | 31 | ### Training 32 | We provide the training code along with a sample dataloader to help you quickly understand the required data loading structure. Since preparing the full training dataset is resource-intensive, we include this simplified code base for convenience. 33 | 34 | 1. Download the MBRA project code base: 35 | ``` 36 | cd .. 37 | git clone https://github.com/NHirose/Learning-to-Drive-Anywhere-with-MBRA.git 38 | ``` 39 | 2. Download the MBRA model: 40 | ``` 41 | cd OmniVLA_internal 42 | git clone https://huggingface.co/NHirose/MBRA/ 43 | ``` 44 | 3. You can set the training or debugging mode at line 10 in vla-scripts/train_omnivla.py. Note that even in debugging mode, the code requires at least 20 GB of GPU memory (we use an NVIDIA RTX 4090). 45 | 46 | 4. You can configure visualization at line 11 in vla-scripts/train_omnivla.py. During training, it should be set to False. 47 | 48 | 5. Train our policy from the OpenVLA checkpoints (please fill in X): 49 | ``` 50 | torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/train_omnivla.py --vla_path openvla/openvla-7b --dataset_name omnivla --num_images_in_input 2 --batch_size X --wandb_entity "X" --wandb_project "omnivla" 51 | ``` 52 | 6. Finetune our OmniVLA (please fill in X): 53 | ``` 54 | torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/train_omnivla.py --vla_path ./omnivla-original --dataset_name omnivla --num_images_in_input 2 --batch_size X --wandb_entity "X" --wandb_project "omnivla" 55 | ``` 56 | 7. (Memo) Finetune our OmniVLA on our large navigation dataset: 57 | ``` 58 | conda activate omnivla_2 59 | cd /media/noriaki/Noriaki_Data/OmniVLA 60 | torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/train_omnivla_dataset.py --vla_path ./omnivla-original --dataset_name omnivla --wandb_entity "noriaki-hirose" --wandb_project "omnivla" 61 | ``` 62 | 63 | ### Training with GNM, LeLaN, Frodobots, BDD and CAST datasets 64 | We provide training code that supports multiple public datasets. Before following the full training process, please first ensure that you can run the example training with the sample dataloader.
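For reference, the sketch below shows the per-sample dictionary structure that the dataloaders are expected to yield. The field names follow a subset of what `CAST_Dataset.__getitem__` in `prismatic/vla/datasets/cast_dataset.py` returns; the shapes, dtypes, the `-100` ignore value, and the `make_dummy_sample` helper are illustrative assumptions for demonstration, not values taken from the actual training configuration.

```
# Minimal sketch (not the actual training code): field names follow
# prismatic/vla/datasets/cast_dataset.py; shapes and the helper are illustrative assumptions.
import torch

def make_dummy_sample(num_actions=8, action_dim=4, img_size=224):
    return dict(
        pixel_values=torch.zeros(3, img_size, img_size),        # current observation (after image_transform)
        pixel_values_goal=torch.zeros(3, img_size, img_size),   # goal image (after image_transform)
        input_ids=torch.zeros(32, dtype=torch.long),             # tokenized prompt + action tokens
        labels=torch.full((32,), -100, dtype=torch.long),        # IGNORE_INDEX (-100 assumed) except supervised tokens
        dataset_name="cast",
        modality_id=7,                                            # goal-modality flag (e.g., 7 = language-only)
        actions=torch.zeros(num_actions, action_dim),             # normalized action chunk
        action_select_mask=torch.tensor(1.0),                     # 1.0 = supervise actions for this sample
        goal_pose=torch.zeros(action_dim),                        # relative goal pose (yaw encoded as sin/cos)
        temp_dist=0,                                              # temporal distance to the goal frame
        lan_prompt="move toward the chair",                       # language instruction
    )

sample = make_dummy_sample()
print({k: tuple(v.shape) if torch.is_tensor(v) else v for k, v in sample.items()})
```

The CAST dataloader additionally returns fields such as `obj_pose_norm`, `cur_image`, `goal_image_8`, and the PIL images used for visualization; see `cast_dataset.py` for the full list.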
65 | 66 | 1. Download all datasets from their original websites ([GNM](https://github.com/robodhruv/visualnav-transformer), [LeLaN](https://github.com/NHirose/learning-language-navigation), [Frodobots](https://github.com/NHirose/Learning-to-Drive-Anywhere-with-MBRA), [CAST](https://openvla-oft.github.io/)). Please verify that the downloaded datasets work properly in their original codebases, except for the BDD dataset. Note: please download the LeLaN dataset from [this link](https://huggingface.co/datasets/NHirose/LeLaN_dataset_NoMaD_traj/tree/main) instead of [the original link](https://drive.google.com/file/d/1IazHcIyPGO7ENswz8_sGCIGBXF8_sZJK/view). The updated dataset already includes the NoMaD trajectories used for collision-avoidance supervision, so you no longer need to run the NoMaD policy during training. Please carefully follow the usage procedure described in the [LeLaN codebase](https://github.com/NHirose/learning-language-navigation) when working with the dataset. 67 | 68 | 2. Download the modified BDD dataset with MBRA annotations from [here](https://huggingface.co/datasets/NHirose/BDD_OmniVLA) and extract it. The image sequences in the modified dataset remain subject to the [original BDD license](http://bdd-data.berkeley.edu/download.html), while the additional MBRA annotations are released under the MIT license. 69 | 70 | 3. Download the lerobot code base, which provides the dataloader for the Frodobots dataset: 71 | ``` 72 | git clone https://github.com/huggingface/lerobot.git 73 | ``` 74 | 4. Edit the data paths in config_nav/mbra_and_dataset_config.yaml. 75 | 76 | 5. Train our policy from the OpenVLA checkpoints (please fill in X): 77 | ``` 78 | torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/train_omnivla_dataset.py --vla_path ./omnivla-original --dataset_name omnivla --wandb_entity "X" --wandb_project "omnivla" 79 | ``` 80 | 81 | In our training setup, we use 8 NVIDIA H100 GPUs (80 GB each) across 8 nodes. The batch sizes are configured as [LeLaN, GNM, Frodobots, BDD] = [4, 1, 1, 1], with gradient accumulation set to 4 steps. When finetuning with the CAST dataset, we set the batch sizes to [LeLaN, CAST, GNM, Frodobots, BDD] = [2, 2, 1, 1, 1]. To change these values, you need to edit train_omnivla_dataset.py directly. 82 | 83 | ### Acknowledgement 84 | We implement our ideas and design choices on top of the pretrained OpenVLA checkpoints. Our work builds upon the [OpenVLA-OFT](https://openvla-oft.github.io/) codebase, with additional code added to create OmniVLA. As such, our implementation leverages many components of the OpenVLA-OFT codebase. We sincerely appreciate the effort and contributions of the OpenVLA-OFT team! 85 | 86 | ## Citing 87 | ``` 88 | @misc{hirose2025omnivla, 89 | title={OmniVLA: An Omni-Modal Vision-Language-Action Model for Robot Navigation}, 90 | author={Noriaki Hirose and Catherine Glossop and Dhruv Shah and Sergey Levine}, 91 | year={2025}, 92 | eprint={2509.19480}, 93 | archivePrefix={arXiv}, 94 | primaryClass={cs.RO}, 95 | url={https://arxiv.org/abs/2509.19480}, 96 | } 97 | ``` -------------------------------------------------------------------------------- /prismatic/training/strategies/ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | ddp.py 3 | 4 | Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most 5 | GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP.
6 | """ 7 | 8 | import shutil 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.optim import AdamW 15 | from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup 16 | 17 | from prismatic.overwatch import initialize_overwatch 18 | from prismatic.training.strategies.base_strategy import TrainingStrategy 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | class DDPStrategy(TrainingStrategy): 25 | @overwatch.rank_zero_only 26 | def save_checkpoint( 27 | self, 28 | run_dir: Path, 29 | global_step: int, 30 | epoch: int, 31 | train_loss: Optional[float] = None, 32 | only_trainable: bool = True, 33 | ) -> None: 34 | """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" 35 | assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" 36 | 37 | # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) 38 | model_state_dicts = { 39 | mkey: getattr(self.vlm.module, mkey).state_dict() 40 | for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) 41 | } 42 | optimizer_state_dict = self.optimizer.state_dict() 43 | 44 | # Set Checkpoint Path =>> Embed *minimal* training statistics! 45 | checkpoint_dir = run_dir / "checkpoints" 46 | if train_loss is None: 47 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" 48 | else: 49 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" 50 | 51 | # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` 52 | torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) 53 | shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") 54 | 55 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None: 56 | # Gradient Checkpointing Setup 57 | if self.enable_gradient_checkpointing: 58 | # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up 59 | # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF 60 | # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` 61 | # on `self.llm_backbone`. 62 | # 63 | # What does it actually do? 
--> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic 64 | # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 65 | # 66 | # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) 67 | # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb 68 | overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) 69 | self.vlm.llm_backbone.gradient_checkpointing_enable() 70 | 71 | # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) 72 | overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) 73 | self.vlm.to(self.device_id) 74 | 75 | # Wrap with Distributed Data Parallel 76 | # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that 77 | # is the same size/dtype as the model parameters; this will *double* GPU memory! 78 | # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel 79 | overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) 80 | self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) 81 | 82 | # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` 83 | # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 84 | trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] 85 | if self.max_steps is None: 86 | num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size 87 | else: 88 | num_training_steps = self.max_steps 89 | 90 | if self.lr_scheduler_type == "linear-warmup+cosine-decay": 91 | # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) 92 | num_warmup_steps = int(num_training_steps * self.warmup_ratio) 93 | 94 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 95 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 96 | self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) 97 | for param_group in self.optimizer.param_groups: 98 | param_group["lr"] = 0.0 99 | 100 | elif self.lr_scheduler_type == "constant": 101 | num_warmup_steps = 0 102 | 103 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 
104 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 105 | self.lr_scheduler = get_constant_schedule(self.optimizer) 106 | 107 | else: 108 | raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") 109 | 110 | # Finalize Setup =>> Log 111 | overwatch.info( 112 | "DDP Strategy =>> Finalized Training Setup:\n" 113 | f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" 114 | f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" 115 | f" |-> Distributed World Size = {overwatch.world_size()}\n" 116 | f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" 117 | f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" 118 | f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" 119 | f" |-> Default AdamW LR = {self.learning_rate}\n" 120 | f" |-> AdamW Weight Decay = {self.weight_decay}\n" 121 | f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" 122 | f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" 123 | f" |-> Dataset Size = {n_train_examples} Examples\n" 124 | f" |-> Max Steps = {num_training_steps}\n" 125 | ) 126 | 127 | def clip_grad_norm(self) -> None: 128 | torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) 129 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/cast_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchvision.transforms.functional as TF 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from typing import Type 8 | from prismatic.vla.constants import IGNORE_INDEX 9 | from prismatic.models.backbones.llm.prompting import PromptBuilder 10 | from prismatic.vla.action_tokenizer import ActionTokenizer 11 | from transformers import PreTrainedTokenizerBase 12 | from prismatic.models.backbones.vision import ImageTransform 13 | from vint_train.data.data_utils import calculate_sin_cos, to_local_coords 14 | 15 | 16 | class CAST_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | action_tokenizer: PreTrainedTokenizerBase, 20 | base_tokenizer: ActionTokenizer, 21 | image_transform: ImageTransform, 22 | prompt_builder_fn: Type[PromptBuilder], 23 | dataset_name, 24 | data_loc, 25 | data_size, 26 | features, 27 | predict_stop_token: bool = True, 28 | ): 29 | self.dataset_name = dataset_name 30 | self.data_loc = data_loc 31 | self.data_size = data_size 32 | self.features = features 33 | self.image_size = (96, 96) 34 | self.image_size_clip = (224, 224) 35 | 36 | self.action_tokenizer = action_tokenizer 37 | self.base_tokenizer = base_tokenizer 38 | self.prompt_builder = prompt_builder_fn 39 | self.predict_stop_token = predict_stop_token 40 | self.image_transform = image_transform 41 | 42 | def __len__(self): 43 | return self.data_size 44 | 45 | def _resize_norm(self, image, size): 46 | return TF.resize(image, size) 47 | 48 | def _compute_actions(self, action_yaw, goal_pose, metric_waypoint): 49 | positions = action_yaw[:, 0:2] 50 | yaw = action_yaw[:, 2] 51 | 52 | waypoints = to_local_coords(positions, positions[0], yaw[0]) 53 | goal_pos = to_local_coords(goal_pose[:, 0:2], positions[0], yaw[0]) 54 | 55 | yaw = yaw[1:] - yaw[0] 56 | actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1) 57 | yawg = goal_pose[:, 2:3] - yaw[0] 
58 | goal_pos = np.concatenate([goal_pos, yawg], axis=1) 59 | 60 | actions[:, :2] /= metric_waypoint 61 | goal_pos[:, :2] /= metric_waypoint 62 | 63 | return torch.from_numpy(actions), torch.from_numpy(goal_pos) 64 | 65 | def __getitem__(self, idx): 66 | folder_name = self.dataset_name.split("_convert") 67 | directory_location = self.data_loc + self.dataset_name + "/" + folder_name[0] + "/" 68 | 69 | len_action = 0 70 | while len_action < 10: 71 | traj = np.load(directory_location + f"traj_{idx:06d}.npz", allow_pickle=True) 72 | len_action = len(traj['action']) 73 | if len_action < 10: 74 | idx = random.randint(0, self.data_size - 1) 75 | 76 | num = random.randint(0, len(traj['action']) - 8 - 2) 77 | gid = max(len(traj['action']) - 1, num + 8) 78 | 79 | obs_dict = traj["observation"].item() 80 | cur_pilimg = obs_dict['image'][num] 81 | goal_pilimg = obs_dict['image'][gid] 82 | 83 | cur_obs = cur_pilimg.transpose(2, 0, 1) 84 | goal_obs = goal_pilimg.transpose(2, 0, 1) 85 | 86 | pil_img = Image.fromarray(cur_pilimg.astype(np.uint8)).resize(self.image_size_clip) 87 | pil_img_goal = Image.fromarray(goal_pilimg.astype(np.uint8)).resize(self.image_size_clip) 88 | 89 | pixel_values = self.image_transform(pil_img) 90 | pixel_values_g = self.image_transform(pil_img_goal) 91 | 92 | action_yaw = obs_dict['state'][num: num+8+1] 93 | goal_pose = obs_dict['state'][gid:gid+1] 94 | 95 | actions_norm, goal_pose_norm = self._compute_actions( 96 | action_yaw, goal_pose, traj["normalization_factor"] 97 | ) 98 | actions_torch = calculate_sin_cos(actions_norm) 99 | goal_pose_torch = calculate_sin_cos(goal_pose_norm) 100 | 101 | language_instruction = traj['language_instruction'][0] 102 | non_empty_prompts = [p for p in language_instruction if p] 103 | selected_prompt = random.choice(non_empty_prompts).decode('utf-8') 104 | 105 | actions = actions_torch 106 | current_action = actions[0] 107 | future_actions = actions[1:] 108 | future_actions_string = ''.join(self.action_tokenizer(future_actions)) 109 | current_action_string = self.action_tokenizer(current_action) 110 | action_chunk_string = current_action_string + future_actions_string 111 | action_chunk_len = len(action_chunk_string) 112 | 113 | lang = selected_prompt.lower() 114 | conversation = [ 115 | {"from": "human", "value": f"What action should the robot take to {lang}?"}, 116 | {"from": "gpt", "value": action_chunk_string}, 117 | ] 118 | 119 | prompt_builder = self.prompt_builder("openvla") 120 | for turn in conversation: 121 | prompt_builder.add_turn(turn["from"], turn["value"]) 122 | 123 | input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids 124 | labels = list(input_ids) 125 | 126 | max_token = 60 127 | if len(input_ids) > max_token: 128 | lang = "move toward XXXXX" 129 | conversation = [ 130 | {"from": "human", "value": f"What action should the robot take to {lang}?"}, 131 | {"from": "gpt", "value": action_chunk_string}, 132 | ] 133 | prompt_builder = self.prompt_builder("openvla") 134 | for turn in conversation: 135 | prompt_builder.add_turn(turn["from"], turn["value"]) 136 | 137 | input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids 138 | labels = list(input_ids) 139 | 140 | input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) 141 | 142 | obj_pose_norm = goal_pose_torch[0, 0:2] 143 | goal_pose_cos_sin = goal_pose_torch.squeeze(0) 144 | 145 | labels[: -(action_chunk_len + 1)] = IGNORE_INDEX 146 | if not self.predict_stop_token: 147 | labels[-1] = 
IGNORE_INDEX 148 | 149 | goal_id = 0 150 | cur_image_r = self._resize_norm(torch.from_numpy(cur_obs), self.image_size).repeat(6,1,1)/255.0 151 | goal_image_full_8_r = self._resize_norm(torch.from_numpy(goal_obs), self.image_size)/255.0 152 | 153 | dataset_name = "cast" 154 | modality_id = 7 # language only 155 | action_select_mask = torch.tensor(1.0) 156 | 157 | return dict( 158 | pixel_values=pixel_values, 159 | pixel_values_goal=pixel_values_g, 160 | input_ids=input_ids, 161 | labels=labels, 162 | dataset_name=dataset_name, 163 | modality_id=modality_id, 164 | actions=torch.as_tensor(actions_torch), 165 | action_select_mask=action_select_mask, 166 | goal_pose=goal_pose_cos_sin, 167 | obj_pose_norm=obj_pose_norm, 168 | img_PIL=pil_img, 169 | gimg_PIL=pil_img_goal, 170 | cur_image=cur_image_r, 171 | goal_image_8=goal_image_full_8_r, 172 | temp_dist=goal_id, 173 | lan_prompt=lang 174 | ) 175 | 176 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinoclip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinoclip_vit.py 3 | 4 | Vision backbone that returns concatenated features from both DINOv2 and CLIP. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | from functools import partial 9 | from typing import Callable, Dict, Tuple 10 | 11 | import timm 12 | import torch 13 | from PIL import Image 14 | from timm.models.vision_transformer import Block, VisionTransformer 15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 16 | from torchvision.transforms import Compose, Resize 17 | 18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple 19 | 20 | # Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) 21 | DINOCLIP_VISION_BACKBONES = { 22 | "dinoclip-vit-l-336px": { 23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 24 | "clip": "vit_large_patch14_clip_336.openai", 25 | }, 26 | } 27 | 28 | 29 | @dataclass 30 | class DinoCLIPImageTransform: 31 | dino_image_transform: ImageTransform 32 | clip_image_transform: ImageTransform 33 | is_prismatic: bool = True 34 | 35 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: 36 | return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} 37 | 38 | 39 | class DinoCLIPViTBackbone(VisionBackbone): 40 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 41 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) 42 | self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] 43 | self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] 44 | 45 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary 46 | self.dino_featurizer: VisionTransformer = timm.create_model( 47 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 48 | ) 49 | self.dino_featurizer.eval() 50 | 51 | self.clip_featurizer: VisionTransformer = timm.create_model( 52 | self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 53 | ) 54 | self.clip_featurizer.eval() 55 | 56 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility 57 | # => Note: By default set `get_intermediate_layers` to 
return the *SECOND-TO-LAST* layer patches! 58 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 59 | self.dino_featurizer.forward = unpack_tuple( 60 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) 61 | ) 62 | self.clip_featurizer.forward = unpack_tuple( 63 | partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) 64 | ) 65 | 66 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models 67 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) 68 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 69 | 70 | self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) 71 | self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 72 | 73 | # Initialize *both* Transforms 74 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) 75 | default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) 76 | if self.image_resize_strategy == "resize-naive": 77 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" 78 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" 79 | assert isinstance(default_dino_transform.transforms[0], Resize) 80 | assert isinstance(default_clip_transform.transforms[0], Resize) 81 | 82 | target_size = (self.default_image_size, self.default_image_size) 83 | dino_transform = Compose( 84 | [ 85 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), 86 | *default_dino_transform.transforms[1:], 87 | ] 88 | ) 89 | clip_transform = Compose( 90 | [ 91 | Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), 92 | *default_clip_transform.transforms[1:], 93 | ] 94 | ) 95 | 96 | self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) 97 | 98 | elif self.image_resize_strategy == "resize-crop": 99 | self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) 100 | 101 | elif self.image_resize_strategy == "letterbox": 102 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" 103 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" 104 | assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
105 | 106 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) 107 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) 108 | clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) 109 | 110 | # Build New Transform 111 | self.image_transform = DinoCLIPImageTransform( 112 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), 113 | Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), 114 | ) 115 | 116 | else: 117 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 118 | 119 | def get_fsdp_wrapping_policy(self) -> Callable: 120 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" 121 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 122 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 123 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 124 | 125 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: 126 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" 127 | dino_patches = self.dino_featurizer(pixel_values["dino"]) 128 | clip_patches = self.clip_featurizer(pixel_values["clip"]) 129 | 130 | return torch.cat([dino_patches, clip_patches], dim=2) 131 | 132 | @property 133 | def default_image_resolution(self) -> Tuple[int, int, int]: 134 | return self.dino_data_cfg["input_size"] 135 | 136 | @property 137 | def embed_dim(self) -> int: 138 | return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim 139 | 140 | @property 141 | def num_patches(self) -> int: 142 | assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches 143 | return self.dino_featurizer.patch_embed.num_patches 144 | 145 | @property 146 | def half_precision_dtype(self) -> torch.dtype: 147 | return torch.bfloat16 148 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinosiglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinosiglip_vit.py 3 | 4 | Vision backbone that returns concatenated features from both DINOv2 and SigLIP. 
5 | """ 6 | 7 | from dataclasses import dataclass 8 | from functools import partial 9 | from typing import Callable, Dict, Tuple 10 | 11 | import timm 12 | import torch 13 | from PIL import Image 14 | from timm.models.vision_transformer import Block, VisionTransformer 15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 16 | from torchvision.transforms import Compose, Resize 17 | 18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple 19 | 20 | # Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) 21 | DINOSigLIP_VISION_BACKBONES = { 22 | "dinosiglip-vit-so-224px": { 23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 24 | "siglip": "vit_so400m_patch14_siglip_224", 25 | }, 26 | "dinosiglip-vit-so-384px": { 27 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 28 | "siglip": "vit_so400m_patch14_siglip_384", 29 | }, 30 | } 31 | 32 | 33 | @dataclass 34 | class DinoSigLIPImageTransform: 35 | dino_image_transform: ImageTransform 36 | siglip_image_transform: ImageTransform 37 | is_prismatic: bool = True 38 | 39 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: 40 | return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)} 41 | 42 | 43 | class DinoSigLIPViTBackbone(VisionBackbone): 44 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 45 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) 46 | self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"] 47 | self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"] 48 | 49 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary 50 | self.dino_featurizer: VisionTransformer = timm.create_model( 51 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 52 | ) 53 | self.dino_featurizer.eval() 54 | 55 | self.siglip_featurizer: VisionTransformer = timm.create_model( 56 | self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 57 | ) 58 | self.siglip_featurizer.eval() 59 | 60 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility 61 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
62 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 63 | self.dino_featurizer.forward = unpack_tuple( 64 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) 65 | ) 66 | self.siglip_featurizer.forward = unpack_tuple( 67 | partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2}) 68 | ) 69 | 70 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models 71 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) 72 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 73 | 74 | self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer) 75 | self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 76 | 77 | # Initialize *both* Transforms 78 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) 79 | default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False) 80 | 81 | # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!! 82 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!" 83 | assert isinstance(default_siglip_transform.transforms[0], Resize) 84 | default_siglip_transform = Compose( 85 | [ 86 | Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation), 87 | *default_siglip_transform.transforms[1:], 88 | ] 89 | ) 90 | 91 | if self.image_resize_strategy == "resize-naive": 92 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" 93 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!" 94 | assert isinstance(default_dino_transform.transforms[0], Resize) 95 | assert isinstance(default_siglip_transform.transforms[0], Resize) 96 | 97 | target_size = (self.default_image_size, self.default_image_size) 98 | dino_transform = Compose( 99 | [ 100 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), 101 | *default_dino_transform.transforms[1:], 102 | ] 103 | ) 104 | siglip_transform = Compose( 105 | [ 106 | Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation), 107 | *default_siglip_transform.transforms[1:], 108 | ] 109 | ) 110 | 111 | self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform) 112 | 113 | elif self.image_resize_strategy == "resize-crop": 114 | self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform) 115 | 116 | elif self.image_resize_strategy == "letterbox": 117 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" 118 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!" 119 | assert ( 120 | "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg 121 | ), "DinoSigLIP `data_cfg` missing `mean`!" 
122 | 123 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) 124 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) 125 | siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]]) 126 | 127 | # Build New Transform 128 | self.image_transform = DinoSigLIPImageTransform( 129 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), 130 | Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]), 131 | ) 132 | 133 | else: 134 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 135 | 136 | def get_fsdp_wrapping_policy(self) -> Callable: 137 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" 138 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 139 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 140 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 141 | 142 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: 143 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" 144 | dino_patches = self.dino_featurizer(pixel_values["dino"]) 145 | siglip_patches = self.siglip_featurizer(pixel_values["siglip"]) 146 | 147 | return torch.cat([dino_patches, siglip_patches], dim=2) 148 | 149 | @property 150 | def default_image_resolution(self) -> Tuple[int, int, int]: 151 | return self.dino_data_cfg["input_size"] 152 | 153 | @property 154 | def embed_dim(self) -> int: 155 | return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim 156 | 157 | @property 158 | def num_patches(self) -> int: 159 | assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches 160 | return self.dino_featurizer.patch_embed.num_patches 161 | 162 | @property 163 | def half_precision_dtype(self) -> torch.dtype: 164 | return torch.bfloat16 165 | -------------------------------------------------------------------------------- /prismatic/preprocessing/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | download.py 3 | 4 | Utility functions for downloading and extracting various datasets to (local) disk. 5 | """ 6 | 7 | import os 8 | import shutil 9 | from pathlib import Path 10 | from typing import Dict, List, TypedDict 11 | from zipfile import ZipFile 12 | 13 | import requests 14 | from PIL import Image 15 | from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn 16 | from tqdm import tqdm 17 | 18 | from prismatic.overwatch import initialize_overwatch 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | # === Dataset Registry w/ Links === 25 | # fmt: off 26 | DatasetComponent = TypedDict( 27 | "DatasetComponent", 28 | {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool}, 29 | total=False 30 | ) 31 | 32 | DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = { 33 | # === LLaVa v1.5 Dataset(s) === 34 | 35 | # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5 36 | # models are finetuned on this split. We use this dataset for all experiments in our paper. 
37 | "llava-laion-cc-sbu-558k": [ 38 | { 39 | "name": "chat.json", # Contains the "chat" traces :: {"human" => , "gpt" => } 40 | "extract": False, 41 | "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json", 42 | "do_rename": True, 43 | }, 44 | { 45 | "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution) 46 | "extract": True, 47 | "extract_type": "directory", 48 | "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip", 49 | "do_rename": False, 50 | } 51 | ], 52 | 53 | "llava-v1.5-instruct": [ 54 | { 55 | "name": "llava_v1_5_mix665k.json", 56 | "extract": False, 57 | "url": ( 58 | "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json" 59 | ), 60 | "do_rename": True, 61 | }, 62 | { 63 | "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017 64 | "extract": True, 65 | "extract_type": "directory", 66 | "url": "http://images.cocodataset.org/zips/train2017.zip", 67 | "do_rename": True, 68 | }, 69 | { 70 | "name": "gqa/images", 71 | "extract": True, 72 | "extract_type": "directory", 73 | "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip", 74 | "do_rename": True, 75 | }, 76 | { 77 | "name": "ocr_vqa/images", 78 | "extract": True, 79 | "extract_type": "directory", 80 | "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip", 81 | "do_rename": True, 82 | }, 83 | { 84 | "name": "textvqa/train_images", 85 | "extract": True, 86 | "extract_type": "directory", 87 | "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip", 88 | "do_rename": True, 89 | }, 90 | { 91 | "name": "vg/VG_100K", 92 | "extract": True, 93 | "extract_type": "directory", 94 | "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip", 95 | "do_rename": True, 96 | }, 97 | { 98 | "name": "vg/VG_100K_2", 99 | "extract": True, 100 | "extract_type": "directory", 101 | "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip", 102 | "do_rename": True, 103 | }, 104 | ] 105 | } 106 | # fmt: on 107 | 108 | 109 | def convert_to_jpg(image_dir: Path) -> None: 110 | """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs.""" 111 | overwatch.info(f"Converting all Images in `{image_dir}` to JPG") 112 | 113 | for image_fn in tqdm(list(image_dir.iterdir())): 114 | if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists(): 115 | continue 116 | 117 | if image_fn.suffix == ".gif": 118 | gif = Image.open(image_fn) 119 | gif.seek(0) 120 | gif.convert("RGB").save(jpg_fn) 121 | elif image_fn.suffix == ".png": 122 | Image.open(image_fn).convert("RGB").save(jpg_fn) 123 | else: 124 | raise ValueError(f"Unexpected image format `{image_fn.suffix}`") 125 | 126 | 127 | def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path: 128 | """Utility function for downloading files from the internet, with a handy Rich-based progress bar.""" 129 | overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1) 130 | if dest_path.exists(): 131 | return dest_path 132 | 133 | # Otherwise --> fire an HTTP Request, with `stream = True` 134 | response = requests.get(url, stream=True) 135 | 136 | # Download w/ Transfer-Aware Progress 137 | # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py 138 | with 
Progress( 139 | TextColumn("[bold]{task.description} - {task.fields[fname]}"), 140 | BarColumn(bar_width=None), 141 | "[progress.percentage]{task.percentage:>3.1f}%", 142 | "•", 143 | DownloadColumn(), 144 | "•", 145 | TransferSpeedColumn(), 146 | transient=True, 147 | ) as dl_progress: 148 | dl_tid = dl_progress.add_task( 149 | "Downloading", fname=dest_path.name, total=int(response.headers.get("content-length", "None")) 150 | ) 151 | with open(dest_path, "wb") as f: 152 | for data in response.iter_content(chunk_size=chunk_size_bytes): 153 | dl_progress.advance(dl_tid, f.write(data)) 154 | 155 | return dest_path 156 | 157 | 158 | def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path: 159 | """Utility function for extracting compressed archives, with a handy Rich-based progress bar.""" 160 | assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!" 161 | overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1) 162 | 163 | # Extract w/ Progress 164 | with Progress( 165 | TextColumn("[bold]{task.description} - {task.fields[aname]}"), 166 | BarColumn(bar_width=None), 167 | "[progress.percentage]{task.percentage:>3.1f}%", 168 | "•", 169 | MofNCompleteColumn(), 170 | transient=True, 171 | ) as ext_progress: 172 | with ZipFile(archive_path) as zf: 173 | ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist())) 174 | extract_path = Path(zf.extract(members[0], download_dir)) 175 | if extract_type == "file": 176 | assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type} has > 1 member!" 177 | elif extract_type == "directory": 178 | for member in members[1:]: 179 | zf.extract(member, download_dir) 180 | ext_progress.advance(ext_tid) 181 | else: 182 | raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!") 183 | 184 | # Cleanup (if specified) 185 | if cleanup: 186 | archive_path.unlink() 187 | 188 | return extract_path 189 | 190 | 191 | def download_extract(dataset_id: str, root_dir: Path) -> None: 192 | """Download all files for a given dataset (querying registry above), extracting archives if necessary.""" 193 | os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True) 194 | 195 | # Download Files => Single-Threaded, with Progress Bar 196 | dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()] 197 | for dl_task in dl_tasks: 198 | dl_path = download_with_progress(dl_task["url"], download_dir) 199 | 200 | # Extract Files (if specified) --> Note (assumes ".zip" ONLY!) 201 | if dl_task["extract"]: 202 | dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"]) 203 | dl_path = dl_path.parent if dl_path.is_file() else dl_path 204 | 205 | # Rename Path --> dl_task["name"] 206 | if dl_task["do_rename"]: 207 | shutil.move(dl_path, download_dir / dl_task["name"]) 208 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/base_vision.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_vision.py 3 | 4 | Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility 5 | functions, and initialization logic. 
6 | 7 | We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision 8 | Transformer model for feature extraction. 9 | """ 10 | 11 | from abc import ABC, abstractmethod 12 | from dataclasses import dataclass 13 | from functools import partial 14 | from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union 15 | 16 | import timm 17 | import torch 18 | import torch.nn as nn 19 | import torchvision.transforms.functional as TVF 20 | from PIL.Image import Image 21 | from timm.models.vision_transformer import Block, VisionTransformer 22 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 23 | from torchvision.transforms import Compose, Resize 24 | 25 | 26 | # === Utility Functions for Monkey-Patching === 27 | def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: 28 | def wrapper(*args: Any, **kwargs: Any) -> Any: 29 | result = fn(*args, **kwargs) 30 | return result[0] if isinstance(result, tuple) else result 31 | 32 | return wrapper 33 | 34 | 35 | # === Interface for an Image Transform === 36 | class ImageTransform(Protocol): 37 | def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... 38 | 39 | 40 | # === Custom Torchvision Image Transforms === 41 | @dataclass 42 | class LetterboxPad: 43 | padding_fill_value: Tuple[int, int, int] 44 | 45 | def __call__(self, image: Image) -> Image: 46 | """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" 47 | (w, h), max_wh = image.size, max(image.size) 48 | horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2) 49 | padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) 50 | return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant") 51 | 52 | 53 | # === Abstract Base Class for arbitrary Vision Backbones === 54 | class VisionBackbone(nn.Module, ABC): 55 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 56 | super().__init__() 57 | self.identifier: str = vision_backbone_id 58 | self.image_resize_strategy: str = image_resize_strategy 59 | self.default_image_size: int = default_image_size 60 | 61 | # Instance attributes for a Vision Backbone 62 | self.featurizer: nn.Module = None 63 | self.image_transform: ImageTransform = None 64 | 65 | def get_image_transform(self) -> ImageTransform: 66 | return self.image_transform 67 | 68 | @abstractmethod 69 | def get_fsdp_wrapping_policy(self) -> Callable: ... 70 | 71 | @abstractmethod 72 | def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: 73 | """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features.""" 74 | raise NotImplementedError 75 | 76 | @property 77 | @abstractmethod 78 | def default_image_resolution(self) -> Tuple[int, int, int]: ... 79 | 80 | @property 81 | @abstractmethod 82 | def embed_dim(self) -> int: ... 83 | 84 | @property 85 | @abstractmethod 86 | def num_patches(self) -> int: ... 87 | 88 | @property 89 | @abstractmethod 90 | def half_precision_dtype(self) -> torch.dtype: ... 
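# (Editor's note) Sketch of how a caller consumes the interface above; `backbone` stands in for any concrete subclass (e.g., the TimmViTBackbone defined below) and the names are illustrative only:
#   >>> pixel_values = backbone.get_image_transform()(pil_image)   # torch.Tensor (or Dict[str, torch.Tensor]) for a single image
#   >>> patches = backbone(batched_pixel_values)                    # -> (B, backbone.num_patches, backbone.embed_dim)
# FSDP sharding and mixed-precision casting rely on `get_fsdp_wrapping_policy()` and `half_precision_dtype`.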
91 | 92 | 93 | # === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones === 94 | class TimmViTBackbone(VisionBackbone, ABC): 95 | def __init__( 96 | self, 97 | vision_backbone_id: str, 98 | timm_path_or_url: str, 99 | image_resize_strategy: str, 100 | default_image_size: int = 224, 101 | override_act_layer: Optional[str] = None, 102 | ) -> None: 103 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) 104 | self.timm_path_or_url = timm_path_or_url 105 | self.override_act_layer = override_act_layer 106 | self.dtype = torch.bfloat16 107 | 108 | # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary 109 | if self.override_act_layer is None: 110 | self.featurizer: VisionTransformer = timm.create_model( 111 | self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 112 | ) 113 | else: 114 | self.featurizer: VisionTransformer = timm.create_model( 115 | self.timm_path_or_url, 116 | pretrained=True, 117 | num_classes=0, 118 | img_size=self.default_image_size, 119 | act_layer=self.override_act_layer, 120 | ) 121 | self.featurizer.eval() 122 | 123 | # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility 124 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 125 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 126 | self.featurizer.forward = unpack_tuple( 127 | partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2}) 128 | ) 129 | 130 | # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!) 131 | assert isinstance(self.featurizer, VisionTransformer), ( 132 | "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, " 133 | "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!" 134 | ) 135 | 136 | # Get Config =>> Note :: Override default image size to ensure correct image transform 137 | self.data_cfg = timm.data.resolve_model_data_config(self.featurizer) 138 | self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 139 | 140 | # Initialize Default Image Transform --> Modified by `self.image_resize_strategy` 141 | default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False) 142 | 143 | # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)! 144 | if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url: 145 | assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" 146 | assert isinstance(default_image_transform.transforms[0], Resize) 147 | default_image_transform = Compose( 148 | [ 149 | Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation), 150 | *default_image_transform.transforms[1:], 151 | ] 152 | ) 153 | 154 | # Switch on `image_resize_strategy` 155 | if self.image_resize_strategy == "resize-naive": 156 | assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" 
157 | assert isinstance(default_image_transform.transforms[0], Resize) 158 | 159 | target_size = (self.default_image_size, self.default_image_size) 160 | self.image_transform = Compose( 161 | [ 162 | Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation), 163 | *default_image_transform.transforms[1:], 164 | ] 165 | ) 166 | 167 | elif self.image_resize_strategy == "resize-crop": 168 | self.image_transform = default_image_transform 169 | 170 | elif self.image_resize_strategy == "letterbox": 171 | assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" 172 | assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!" 173 | 174 | # Compute Padding Fill Value (rescaled normalization mean if applicable) 175 | fill = tuple([int(x * 255) for x in self.data_cfg["mean"]]) 176 | 177 | # Build New Transform 178 | self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms]) 179 | 180 | else: 181 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 182 | 183 | def get_fsdp_wrapping_policy(self) -> Callable: 184 | """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer.""" 185 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 186 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 187 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 188 | 189 | def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: 190 | """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features.""" 191 | return self.featurizer(pixel_values) 192 | 193 | @property 194 | def default_image_resolution(self) -> Tuple[int, int, int]: 195 | return self.data_cfg["input_size"] 196 | 197 | @property 198 | def embed_dim(self) -> int: 199 | return self.featurizer.embed_dim 200 | 201 | @property 202 | def num_patches(self) -> int: 203 | return self.featurizer.patch_embed.num_patches 204 | 205 | @property 206 | def half_precision_dtype(self) -> torch.dtype: 207 | return self.dtype 208 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | #import sys 2 | #sys.path.append('/media/noriaki/Noriaki_Data/Learning-to-Drive-Anywhere-with-MBRA/train/') 3 | #sys.path.append('/home/noriaki/Learning-to-Drive-Anywhere-with-MBRA2/train/') 4 | 5 | import numpy as np 6 | import os 7 | import pickle 8 | import yaml 9 | from typing import Any, Dict, List, Optional, Tuple, Type 10 | import tqdm 11 | import io 12 | import lmdb 13 | import utm 14 | import math 15 | 16 | import torch 17 | from torch.utils.data import Dataset 18 | import torchvision.transforms.functional as TF 19 | 20 | import random 21 | #import cv2 22 | import matplotlib.pyplot as plt 23 | 24 | from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, STOP_INDEX 25 | from prismatic.models.backbones.llm.prompting import PromptBuilder 26 | from prismatic.vla.action_tokenizer import ActionTokenizer 27 | from transformers import PreTrainedTokenizerBase 28 | from prismatic.models.backbones.vision import ImageTransform 29 | from PIL import Image 30 | from typing import Union 31 | 32 | from 
vint_train.data.data_utils import ( 33 | img_path_to_data, 34 | calculate_sin_cos, 35 | get_data_path, 36 | to_local_coords, 37 | ) 38 | 39 | class Dummy_Dataset(Dataset): 40 | def __init__( 41 | self, 42 | context_size: int, 43 | action_tokenizer: PreTrainedTokenizerBase, 44 | base_tokenizer: ActionTokenizer, 45 | image_transform: ImageTransform, 46 | prompt_builder_fn: Type[PromptBuilder], 47 | predict_stop_token: bool = True, 48 | ): 49 | self.context_size = context_size 50 | self.action_tokenizer = action_tokenizer 51 | self.base_tokenizer = base_tokenizer 52 | self.prompt_builder = prompt_builder_fn 53 | self.predict_stop_token = predict_stop_token 54 | self.image_transform = image_transform 55 | 56 | def __len__(self) -> int: 57 | return 10000 #dummy length 58 | 59 | def calculate_relative_position(self, x_a, y_a, x_b, y_b): 60 | return x_b - x_a, y_b - y_a 61 | 62 | def rotate_to_local_frame(self, delta_x, delta_y, heading_a_rad): 63 | rel_x = delta_x * math.cos(heading_a_rad) + delta_y * math.sin(heading_a_rad) 64 | rel_y = -delta_x * math.sin(heading_a_rad) + delta_y * math.cos(heading_a_rad) 65 | return rel_x, rel_y 66 | 67 | def _resize_norm(self, image, size): 68 | return TF.resize(image, size) 69 | 70 | def __getitem__(self, i: int) -> Tuple[torch.Tensor]: 71 | thres_dist = 30.0 72 | metric_waypoint_spacing = 0.1 73 | predict_stop_token=True 74 | 75 | # set the available modality id for each dataset 76 | # 0:"satellite only", 1:"pose and satellite", 2:"satellite and image", 3:"all", 4:"pose only", 5:"pose and image", 6:"image only", 7:"language only", 8:"language and pose" 77 | modality_list = [4, 6, 7] #[4, 5, 6, 7, 8] # Our sample data is no consistency between modalities. So we can take a solo modality. 78 | modality_id = random.choice(modality_list) 79 | 80 | inst_obj = "move toward blue trash bin" 81 | actions = np.random.rand(8, 4) #dummy action 82 | 83 | # Dummy current and goal location 84 | current_lat, current_lon, current_compass = 37.87371258374039, -122.26729417226024, 270.0 85 | cur_utm = utm.from_latlon(current_lat, current_lon) 86 | cur_compass = -float(current_compass) / 180.0 * math.pi # inverted compass 87 | 88 | goal_lat, goal_lon, goal_compass = 37.8738930785863, -122.26746181032362, 0.0 89 | goal_utm = utm.from_latlon(goal_lat, goal_lon) 90 | goal_compass = -float(goal_compass) / 180.0 * math.pi 91 | 92 | # Local goal position 93 | delta_x, delta_y = self.calculate_relative_position( 94 | cur_utm[0], cur_utm[1], goal_utm[0], goal_utm[1] 95 | ) 96 | relative_x, relative_y = self.rotate_to_local_frame(delta_x, delta_y, cur_compass) 97 | radius = np.sqrt(relative_x**2 + relative_y**2) 98 | if radius > thres_dist: 99 | relative_x *= thres_dist / radius 100 | relative_y *= thres_dist / radius 101 | 102 | goal_pose_loc_norm = np.array([ 103 | relative_y / metric_waypoint_spacing, 104 | -relative_x / metric_waypoint_spacing, 105 | np.cos(goal_compass - cur_compass), 106 | np.sin(goal_compass - cur_compass) 107 | ]) 108 | 109 | goal_pose_cos_sin = goal_pose_loc_norm 110 | current_image_PIL = Image.open("./inference/current_img.jpg").convert("RGB") 111 | goal_image_PIL = Image.open("./inference/goal_img.jpg").convert("RGB") 112 | 113 | IGNORE_INDEX = -100 114 | current_action = actions[0] 115 | future_actions = actions[1:] 116 | future_actions_string = ''.join(self.action_tokenizer(future_actions)) 117 | current_action_string = self.action_tokenizer(current_action) 118 | action_chunk_string = current_action_string + future_actions_string 119 | action_chunk_len = 
len(action_chunk_string) 120 | 121 | if modality_id != 7 and modality_id != 8: # When the language modality is not selected, we give the following placeholder prompt instead of masking out the language tokens. 122 | conversation = [ 123 | {"from": "human", "value": "No language instruction"}, 124 | {"from": "gpt", "value": action_chunk_string}, 125 | ] 126 | else: 127 | conversation = [ 128 | {"from": "human", "value": f"What action should the robot take to {inst_obj}?"}, 129 | {"from": "gpt", "value": action_chunk_string}, 130 | ] 131 | 132 | prompt_builder = self.prompt_builder("openvla") 133 | for turn in conversation: 134 | prompt_builder.add_turn(turn["from"], turn["value"]) 135 | 136 | # Tokenize 137 | input_ids = torch.tensor(self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids) 138 | labels = input_ids.clone() 139 | labels[:-(action_chunk_len + 1)] = IGNORE_INDEX 140 | if not predict_stop_token: 141 | labels[-1] = IGNORE_INDEX 142 | 143 | # Images for MBRA model 144 | image_obs_list = [] 145 | for ih in range(self.context_size + 1): 146 | image_obs_list.append(self._resize_norm(TF.to_tensor(current_image_PIL), (96, 96))) # In our real code, image_obs_list holds the history of past observation images; in this dummy dataset we simply repeat the current image. The detailed implementation follows the ViNT / NoMaD code base. 147 | image_obs = torch.cat(image_obs_list) 148 | image_goal = self._resize_norm(TF.to_tensor(goal_image_PIL), (96, 96)) 149 | 150 | # Data augmentation (random cropping) 151 | voffset = int(224.0*0.2*random.random()) 152 | hoffset = int(224.0*0.1*random.random()) 153 | PILbox = (hoffset, voffset, 224-hoffset, 224-voffset) 154 | current_image_PIL = current_image_PIL.crop(PILbox).resize((224,224)) 155 | goal_image_PIL = goal_image_PIL.crop(PILbox).resize((224,224)) 156 | 157 | # Data augmentation (horizontal flipping) 158 | if random.random() > 0.5: 159 | current_image_PIL = current_image_PIL.transpose(Image.FLIP_LEFT_RIGHT) 160 | goal_image_PIL = goal_image_PIL.transpose(Image.FLIP_LEFT_RIGHT) 161 | actions[:,1] = -actions[:,1] 162 | actions[:,3] = -actions[:,3] 163 | goal_pose_cos_sin[1] = -goal_pose_cos_sin[1] 164 | goal_pose_cos_sin[3] = -goal_pose_cos_sin[3] 165 | 166 | image_obs = torch.flip(image_obs, dims=[2]) 167 | image_goal = torch.flip(image_goal, dims=[2]) 168 | 169 | pixel_values_current = self.image_transform(current_image_PIL) 170 | pixel_values_goal = self.image_transform(goal_image_PIL) 171 | 172 | #action select 1.0: raw action, 0.0: MBRA synthetic action 173 | action_select_mask = torch.tensor(1.0) 174 | 175 | dataset_name = "dummy" 176 | return dict( 177 | pixel_values=pixel_values_current, #Current image for OmniVLA 178 | pixel_values_goal=pixel_values_goal, #Goal image for OmniVLA 179 | input_ids=input_ids, #language and action prompt, following OpenVLA-OFT 180 | labels=labels, #language and action prompt, following OpenVLA-OFT 181 | dataset_name=dataset_name, #dataset name 182 | modality_id=modality_id, #modality ID, 0:"satellite only", 1:"pose and satellite", 2:"satellite and image", 3:"all", 4:"pose only", 5:"pose and image", 6:"image only", 7:"language only", 8:"language and pose" 183 | actions=torch.as_tensor(actions), #action commands 184 | action_select_mask = action_select_mask,#action select mask, 1.0: raw action, 0.0: MBRA synthetic action 185 | goal_pose=goal_pose_cos_sin, #goal pose [X, Y, cos(yaw), sin(yaw)] 186 | obj_pose_norm=goal_pose_cos_sin[0:2], #obj pose [X, Y] (only used for the LeLaN dataset); dummy pose in this dummy dataset 187 |
img_PIL=current_image_PIL, #for visualization 188 | gimg_PIL=goal_image_PIL, #for visualization 189 | cur_image=image_obs, #History of image for MBRA 190 | goal_image_8=image_goal, #Goal image (8 step future) for MBRA 191 | temp_dist=10.0, #Temporal distance (We are not using in our training) 192 | lan_prompt=inst_obj 193 | ) 194 | 195 | -------------------------------------------------------------------------------- /prismatic/conf/vla.py: -------------------------------------------------------------------------------- 1 | """ 2 | vla.py 3 | 4 | Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and 5 | model configuration thereof. A given VLA model (`policy`) configures the following attributes: 6 | - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.) 7 | - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`) 8 | - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning) 9 | - Training / Optimization Hyperparameters 10 | """ 11 | 12 | from dataclasses import dataclass 13 | from enum import Enum, unique 14 | from pathlib import Path 15 | from typing import Optional, Union 16 | 17 | from draccus import ChoiceRegistry 18 | 19 | 20 | @dataclass 21 | class VLAConfig(ChoiceRegistry): 22 | # fmt: off 23 | vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant 24 | base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`) 25 | freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining) 26 | freeze_llm_backbone: bool # Freeze LLM Backbone parameters 27 | unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen) 28 | 29 | # Data Mixture Parameters 30 | data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`) 31 | shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE) 32 | 33 | # Optimization Parameters 34 | epochs: int # Epochs to Run (in case `max_steps` is not specified) 35 | max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`) 36 | 37 | expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware 38 | global_batch_size: int # Global Batch Size (divided across processes / world size) 39 | per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU) 40 | # =>> # of accumulation steps is auto-computed 41 | 42 | learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay) 43 | weight_decay: float # Weight Decay for AdamW Optimizer 44 | max_grad_norm: float # Max Grad Norm (for global gradient clipping) 45 | lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay") 46 | warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers) 47 | 48 | train_strategy: str # Train Strategy (default "fsdp-full-shard") 49 | 50 | # Enable Gradient/Activation Checkpointing (for the LLM Backbone) 51 | enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training 52 | 53 | # Mixed Precision Training via Torch Native AMP (`autocast`) 54 | enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision 55 | reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision 56 | 57 | # fmt: on 58 | 59 | 60 | # === OpenVLA Training Configurations === 61 | 62 | 63 | # = [8 GPU] Fast Iteration =>> SigLIP 
224px + Bridge = 64 | @dataclass 65 | class Exp_SigLIP_224px_Bridge(VLAConfig): 66 | vla_id: str = "siglip-224px+mx-bridge" 67 | base_vlm: Union[str, Path] = "siglip-224px+7b" 68 | 69 | freeze_vision_backbone: bool = False 70 | freeze_llm_backbone: bool = False 71 | unfreeze_last_llm_layer: bool = False 72 | 73 | # Data Mixture Parameters 74 | data_mix: str = "bridge" 75 | shuffle_buffer_size: int = 256_000 76 | 77 | # Optimization Parameters 78 | epochs: int = 1000 79 | max_steps: Optional[int] = None 80 | 81 | expected_world_size: int = 8 82 | global_batch_size: int = 256 83 | per_device_batch_size: int = 32 84 | 85 | learning_rate: float = 2e-5 86 | weight_decay: float = 0.0 87 | max_grad_norm: float = 1.0 88 | lr_scheduler_type: str = "constant" 89 | warmup_ratio: float = 0.0 90 | 91 | train_strategy: str = "fsdp-full-shard" 92 | 93 | 94 | # = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge = 95 | @dataclass 96 | class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): 97 | vla_id: str = "siglip-224px-icy+mx-bridge" 98 | base_vlm: Union[str, Path] = "siglip-224px+7b" 99 | freeze_vision_backbone: bool = True 100 | 101 | 102 | # = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge = 103 | @dataclass 104 | class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): 105 | vla_id: str = "prism-dinosiglip-224px+mx-bridge" 106 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" 107 | 108 | data_mix: str = "bridge" 109 | 110 | 111 | # = [64 GPU] SigLIP 224px + OXE Magic Soup = 112 | @dataclass 113 | class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge): 114 | vla_id: str = "siglip-224px+mx-oxe-magic-soup" 115 | base_vlm: Union[str, Path] = "siglip-224px+7b" 116 | 117 | data_mix: str = "oxe_magic_soup" 118 | 119 | expected_world_size: int = 64 120 | global_batch_size: int = 2048 121 | per_device_batch_size: int = 32 122 | 123 | 124 | # = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ = 125 | @dataclass 126 | class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge): 127 | vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus" 128 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" 129 | 130 | # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling! 
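# (Editor's note, inferred from the note above) Stage 1 used the commented-out `oxe_magic_soup_plus` mixture (which still contains DROID) for roughly the first 70% of training; the active `oxe_magic_soup_plus_minus` mixture reflects the remainder of training after DROID is re-sampled out.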
131 | # data_mix: str = "oxe_magic_soup_plus" 132 | data_mix: str = "oxe_magic_soup_plus_minus" 133 | 134 | expected_world_size: int = 64 135 | global_batch_size: int = 2048 136 | per_device_batch_size: int = 32 137 | 138 | 139 | # === OpenVLA Fine-tuning Configurations === 140 | 141 | 142 | # = [8 GPU] SigLIP 224px + T-DROID = 143 | @dataclass 144 | class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): 145 | vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl" 146 | base_vlm: Union[str, Path] = "siglip-224px+7b" 147 | 148 | data_mix: str = "tdroid_carrot_in_bowl" 149 | 150 | 151 | @dataclass 152 | class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge): 153 | vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot" 154 | base_vlm: Union[str, Path] = "siglip-224px+7b" 155 | 156 | data_mix: str = "tdroid_pour_corn_in_pot" 157 | 158 | 159 | # = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning = 160 | @dataclass 161 | class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): 162 | vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl" 163 | base_vlm: Union[str, Path] = "siglip-224px+7b" 164 | freeze_vision_backbone: bool = True 165 | freeze_llm_backbone: bool = False 166 | 167 | data_mix: str = "tdroid_carrot_in_bowl" 168 | 169 | 170 | @dataclass 171 | class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): 172 | vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl" 173 | base_vlm: Union[str, Path] = "siglip-224px+7b" 174 | freeze_vision_backbone: bool = True 175 | freeze_llm_backbone: bool = True 176 | unfreeze_last_llm_layer: bool = True 177 | 178 | data_mix: str = "tdroid_carrot_in_bowl" 179 | 180 | 181 | @dataclass 182 | class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): 183 | vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl" 184 | base_vlm: Union[str, Path] = "siglip-224px+7b" 185 | freeze_vision_backbone: bool = False 186 | freeze_llm_backbone: bool = True 187 | unfreeze_last_llm_layer: bool = True 188 | 189 | data_mix: str = "tdroid_carrot_in_bowl" 190 | 191 | 192 | # === [8 GPU] SigLIP 224px + FrankaWipe === 193 | @dataclass 194 | class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge): 195 | vla_id: str = "siglip-224px+mx-droid_wipe" 196 | base_vlm: Union[str, Path] = "siglip-224px+7b" 197 | 198 | data_mix: str = "droid_wipe" 199 | 200 | 201 | # === Define a VLA Registry Enum for Reference & Validation === 202 | @unique 203 | class VLARegistry(Enum): 204 | # Sanity Check Configurations =>> BridgeV2 205 | SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge 206 | DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge 207 | 208 | # SigLIP Frozen Backbone Experiment 209 | FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge 210 | 211 | # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup 212 | SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup 213 | 214 | # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++ 215 | DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus 216 | 217 | # === TDROID Fine-tuning Configs === 218 | SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl 219 | SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot 220 | 221 | SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl 222 | SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl 223 | 
SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl 224 | 225 | # === DROID Fine-tuning Configs === 226 | SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe 227 | 228 | @property 229 | def vla_id(self) -> str: 230 | return self.value.vla_id 231 | 232 | 233 | # Register VLAs in Choice Registry 234 | for vla_variant in VLARegistry: 235 | VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value) 236 | -------------------------------------------------------------------------------- /prismatic/preprocessing/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | datasets.py 3 | 4 | PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with 5 | utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected 6 | formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models). 7 | 8 | We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that 9 | random access image reading is relatively cheap/fast. 10 | """ 11 | 12 | import copy 13 | import json 14 | from pathlib import Path 15 | from typing import Dict, List, Tuple, Type 16 | 17 | import torch 18 | from PIL import Image 19 | from torch.utils.data import Dataset 20 | from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase 21 | 22 | from prismatic.models.backbones.llm.prompting import PromptBuilder 23 | from prismatic.models.backbones.vision import ImageTransform 24 | 25 | # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) 26 | IGNORE_INDEX = -100 27 | 28 | 29 | class AlignDataset(Dataset[Dict[str, torch.Tensor]]): 30 | def __init__( 31 | self, 32 | chat_json: Path, 33 | image_dir: Path, 34 | image_transform: ImageTransform, 35 | tokenizer: PreTrainedTokenizerBase, 36 | ) -> None: 37 | super().__init__() 38 | self.chat_json, self.image_dir = chat_json, image_dir 39 | self.image_transform, self.tokenizer = image_transform, tokenizer 40 | self.dataset_type = "align" 41 | 42 | # Create Prompt Template 43 | self.prompt_template = "{caption}" + self.tokenizer.eos_token 44 | 45 | # Load Chat JSON 46 | with open(self.chat_json, "r") as f: 47 | self.examples = json.load(f) 48 | 49 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 50 | """ 51 | Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard 52 | the "prompt" from the human, and instead directly predict the caption from the image. 53 | 54 | As a concrete example given the "raw data" for the first example: 55 | example = self.examples[0]["conversations"]` = { 56 | [ 57 | {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"}, 58 | {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"} 59 | ] 60 | } 61 | 62 | Return =>> self.tokenizer(" select luxury furniture 3 - inch gel memory foam mattress topper\n") 63 | 64 | :param idx: Index to retrieve from the dataset. 65 | 66 | :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor} 67 | """ 68 | image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"] 69 | assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"
70 | 71 | # Format Caption --> {caption}{eos_token} 72 | caption = self.prompt_template.format(caption=conversation[-1]["value"].strip()) 73 | 74 | # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens. 75 | # => Critically, we find that inserting *after* the BOS token leads to the strongest performance! 76 | # - input_ids = "<s> p1 p2 p3 ... \n" 77 | # - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <s> and p{1...K} with IGNORE) 78 | # 79 | # IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! 80 | input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0] 81 | labels = copy.deepcopy(input_ids) 82 | 83 | # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after) 84 | labels[0] = IGNORE_INDEX 85 | 86 | # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor]) 87 | pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) 88 | 89 | return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) 90 | 91 | def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]: 92 | """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example.""" 93 | modality_lengths = [] 94 | for example in self.examples: 95 | is_multimodal = "image" in example 96 | n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]]) 97 | modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words)) 98 | return modality_lengths 99 | 100 | def __len__(self) -> int: 101 | return len(self.examples) 102 | 103 | 104 | class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]): 105 | def __init__( 106 | self, 107 | instruct_json: Path, 108 | image_dir: Path, 109 | image_transform: ImageTransform, 110 | tokenizer: PreTrainedTokenizerBase, 111 | prompt_builder_fn: Type[PromptBuilder], 112 | ) -> None: 113 | super().__init__() 114 | self.instruct_json, self.image_dir = instruct_json, image_dir 115 | self.image_transform, self.tokenizer = image_transform, tokenizer 116 | self.prompt_builder_fn = prompt_builder_fn 117 | self.dataset_type = "finetune" 118 | 119 | # Load Instruct JSON 120 | with open(self.instruct_json, "r") as f: 121 | self.examples = json.load(f) 122 | 123 | # === Unimodal + Multimodal Handling === 124 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 125 | """ 126 | Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of 127 | dialog grounded in a single image. 128 | 129 | To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the 130 | methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example. 131 | 132 | :param idx: Index to retrieve from the dataset. 133 | 134 | :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor} 135 | """ 136 | conversation = self.examples[idx]["conversations"] 137 | 138 | # Create Prompt Builder --> add each message sequentially 139 | prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], [] 140 | for turn_idx, turn in enumerate(conversation): 141 | # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
142 | msg = prompt_builder.add_turn(turn["from"], turn["value"]) 143 | 144 | # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty! 145 | if isinstance(self.tokenizer, LlamaTokenizerFast): 146 | msg = msg.rstrip() 147 | 148 | # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling! 149 | elif isinstance(self.tokenizer, CodeGenTokenizerFast): 150 | pass 151 | 152 | else: 153 | raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!") 154 | 155 | # Tokenize Input IDs 156 | turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids 157 | 158 | # [CRITICAL] We do not want to take the loss for the "USER: " prompts =>> just the responses! 159 | turn_labels = ( 160 | [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids) 161 | ) 162 | 163 | # Add to Trackers 164 | input_ids.extend(turn_input_ids) 165 | labels.extend(turn_labels) 166 | 167 | # Tensorize =>> Set the token's label to IGNORE_INDEX (since we're inserting the image patches after) 168 | # - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! 169 | input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) 170 | 171 | # Handle Truncation (if necessary) 172 | input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length] 173 | 174 | # === Handle "unimodal" (language-only) vs. "multimodal" === 175 | if "image" in self.examples[idx]: 176 | image_path = Path(self.examples[idx]["image"]) 177 | 178 | # Set the token's label to IGNORE_INDEX (since we're inserting the image patches right after) 179 | labels[0] = IGNORE_INDEX 180 | 181 | # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor]) 182 | pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) 183 | 184 | return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) 185 | 186 | else: 187 | # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us! 188 | return dict(pixel_values=None, input_ids=input_ids, labels=labels) 189 | 190 | def get_modality_lengths(self) -> List[Tuple[bool, int]]: 191 | """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example.""" 192 | modality_lengths = [] 193 | for example in self.examples: 194 | is_multimodal = "image" in example 195 | n_words = sum([len(turn["value"].split()) for turn in example["conversations"]]) 196 | modality_lengths.append((is_multimodal, n_words)) 197 | return modality_lengths 198 | 199 | def __len__(self) -> int: 200 | return len(self.examples) 201 | --------------------------------------------------------------------------------