├── prismatic
│   ├── py.typed
│   ├── extern
│   │   ├── __init__.py
│   │   └── hf
│   │       ├── __init__.py
│   │       └── configuration_prismatic.py
│   ├── models
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── llm
│   │   │   │   ├── __init__.py
│   │   │   │   ├── prompting
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── mistral_instruct_prompter.py
│   │   │   │   │   ├── phi_prompter.py
│   │   │   │   │   ├── base_prompter.py
│   │   │   │   │   ├── vicuna_v15_prompter.py
│   │   │   │   │   └── llama2_chat_prompter.py
│   │   │   │   ├── phi.py
│   │   │   │   ├── mistral.py
│   │   │   │   └── llama2.py
│   │   │   └── vision
│   │   │       ├── __init__.py
│   │   │       ├── dinov2_vit.py
│   │   │       ├── in1k_vit.py
│   │   │       ├── siglip_vit.py
│   │   │       ├── clip_vit.py
│   │   │       ├── dinoclip_vit.py
│   │   │       ├── dinosiglip_vit.py
│   │   │       └── base_vision.py
│   │   ├── vlas
│   │   │   ├── __init__.py
│   │   │   └── openvla.py
│   │   ├── vlms
│   │   │   ├── __init__.py
│   │   │   └── base_vlm.py
│   │   ├── __init__.py
│   │   ├── projectors.py
│   │   ├── materialize.py
│   │   └── action_heads.py
│   ├── vla
│   │   ├── datasets
│   │   │   ├── rlds
│   │   │   │   ├── utils
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── goal_relabeling.py
│   │   │   │   │   └── task_augmentation.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── oxe
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── materialize.py
│   │   │   │   │   └── utils
│   │   │   │   │       └── droid_utils.py
│   │   │   │   ├── traj_transforms.py
│   │   │   │   └── obs_transforms.py
│   │   │   ├── __init__.py
│   │   │   ├── data_config.yaml
│   │   │   ├── cast_dataset.py
│   │   │   └── dummy_dataset.py
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── materialize.py
│   │   └── action_tokenizer.py
│   ├── overwatch
│   │   ├── __init__.py
│   │   └── overwatch.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── nn_utils.py
│   │   └── torch_utils.py
│   ├── preprocessing
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   └── datasets.py
│   │   ├── __init__.py
│   │   ├── materialize.py
│   │   └── download.py
│   ├── __init__.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   └── ddp.py
│   │   ├── train_utils.py
│   │   └── materialize.py
│   └── conf
│       ├── __init__.py
│       ├── datasets.py
│       └── vla.py
├── inference
│   ├── 1_ex.jpg
│   ├── goal_img.jpg
│   └── current_img.jpg
├── config_nav
│   ├── mbra_config.yaml
│   └── mbra_and_dataset_config.yaml
├── .gitignore
├── SETUP.md
├── LICENSE
├── pyproject.toml
└── README.md
/prismatic/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/prismatic/extern/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/prismatic/extern/hf/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/rlds/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/prismatic/models/vlas/__init__.py:
--------------------------------------------------------------------------------
1 | from .openvla import OpenVLA
2 |
--------------------------------------------------------------------------------
/prismatic/models/vlms/__init__.py:
--------------------------------------------------------------------------------
1 | from .prismatic import PrismaticVLM
2 |
--------------------------------------------------------------------------------
/prismatic/overwatch/__init__.py:
--------------------------------------------------------------------------------
1 | from .overwatch import initialize_overwatch
2 |
--------------------------------------------------------------------------------
/prismatic/vla/__init__.py:
--------------------------------------------------------------------------------
1 | from .materialize import get_vla_dataset_and_collator
2 |
--------------------------------------------------------------------------------
/inference/1_ex.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NHirose/OmniVLA/HEAD/inference/1_ex.jpg
--------------------------------------------------------------------------------
/inference/goal_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NHirose/OmniVLA/HEAD/inference/goal_img.jpg
--------------------------------------------------------------------------------
/prismatic/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .torch_utils import check_bloat16_supported, set_global_seed
2 |
--------------------------------------------------------------------------------
/inference/current_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NHirose/OmniVLA/HEAD/inference/current_img.jpg
--------------------------------------------------------------------------------
/prismatic/preprocessing/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import AlignDataset, FinetuneDataset
2 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/rlds/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import make_interleaved_dataset, make_single_dataset
2 |
--------------------------------------------------------------------------------
/prismatic/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import available_model_names, available_models, get_model_description, load
2 |
--------------------------------------------------------------------------------
/prismatic/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .materialize import get_train_strategy
2 | from .metrics import Metrics, VLAMetrics
3 |
--------------------------------------------------------------------------------
/prismatic/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .download import convert_to_jpg, download_extract
2 | from .materialize import get_dataset_and_collator
3 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/rlds/oxe/__init__.py:
--------------------------------------------------------------------------------
1 | from .materialize import get_oxe_dataset_kwargs_and_weights
2 | from .mixtures import OXE_NAMED_MIXTURES
3 |
--------------------------------------------------------------------------------
/prismatic/training/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_strategy import TrainingStrategy
2 | from .ddp import DDPStrategy
3 | from .fsdp import FSDPStrategy
4 |
--------------------------------------------------------------------------------
/prismatic/conf/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import DatasetConfig, DatasetRegistry
2 | from .models import ModelConfig, ModelRegistry
3 | from .vla import VLAConfig, VLARegistry
4 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_llm import LLMBackbone
2 | from .llama2 import LLaMa2LLMBackbone
3 | from .mistral import MistralLLMBackbone
4 | from .phi import PhiLLMBackbone
5 |
--------------------------------------------------------------------------------
/prismatic/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .load import available_model_names, available_models, get_model_description, load, load_vla
2 | from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
3 |
--------------------------------------------------------------------------------
/config_nav/mbra_config.yaml:
--------------------------------------------------------------------------------
1 | # parameters for MBRA
2 | context_size: 5 # image history length (number of past frames)
3 | len_traj_pred: 8 # action length (number of predicted waypoints)
4 | learn_angle: True
5 | obs_encoder: efficientnet-b0
6 | obs_encoding_size: 1024
7 | late_fusion: False
8 | mha_num_attention_heads: 4
9 | mha_num_attention_layers: 4
10 | mha_ff_dim_factor: 4
11 |
--------------------------------------------------------------------------------
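A minimal sketch of reading this config from Python. It assumes PyYAML is installed (it is not pinned explicitly in `pyproject.toml`); the field interpretations follow the comments in the file above.

```python
import yaml

# Load the MBRA parameters and read a couple of fields (path relative to the repo root).
with open("config_nav/mbra_config.yaml") as f:
    mbra_cfg = yaml.safe_load(f)

print(mbra_cfg["context_size"])   # 5 -> image history length
print(mbra_cfg["len_traj_pred"])  # 8 -> number of predicted waypoints (action length)
```
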
/prismatic/models/backbones/llm/prompting/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_prompter import PromptBuilder, PurePromptBuilder
2 | from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
3 | from .mistral_instruct_prompter import MistralInstructPromptBuilder
4 | from .phi_prompter import PhiPromptBuilder
5 | from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
6 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_vision import ImageTransform, VisionBackbone
2 | from .clip_vit import CLIPViTBackbone
3 | from .dinoclip_vit import DinoCLIPViTBackbone
4 | from .dinosiglip_vit import DinoSigLIPViTBackbone
5 | from .dinov2_vit import DinoV2ViTBackbone
6 | from .in1k_vit import IN1KViTBackbone
7 | from .siglip_vit import SigLIPViTBackbone
8 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset
2 | #from .lelan_dataset import LeLaN_Dataset_multi, LeLaN_Dataset_openvla, LeLaN_Dataset_openvla_act, LeLaN_Dataset_openvla_act_MMN
3 | #from .vint_hf_dataset import ViNTLeRobotDataset_IL2_gps_map2_crop_shadow_MMN, EpisodeSampler_IL_MMN
4 | #from .vint_dataset import ViNT_Dataset_gps_MMN
5 | #from .bdd_dataset import BDD_Dataset_multi_MMN
6 | #from .cast_dataset import CAST_Dataset_MMN
7 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/vision/dinov2_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | dinov2_vit.py
3 | """
4 |
5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
6 |
7 | # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers!
8 | # => Reference: https://arxiv.org/abs/2309.16588
9 | DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"}
10 |
11 |
12 | class DinoV2ViTBackbone(TimmViTBackbone):
13 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
14 | super().__init__(
15 | vision_backbone_id,
16 | DINOv2_VISION_BACKBONES[vision_backbone_id],
17 | image_resize_strategy,
18 | default_image_size=default_image_size,
19 | )
20 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/vision/in1k_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | in1k_vit.py
3 |
4 | Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K)
5 | """
6 |
7 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
8 |
9 | # Registry =>> Supported Vision Backbones (from TIMM)
10 | IN1K_VISION_BACKBONES = {
11 | "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k",
12 | }
13 |
14 |
15 | class IN1KViTBackbone(TimmViTBackbone):
16 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
17 | super().__init__(
18 | vision_backbone_id,
19 | IN1K_VISION_BACKBONES[vision_backbone_id],
20 | image_resize_strategy,
21 | default_image_size=default_image_size,
22 | )
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | omnivla-finetuned-cast/
2 | omnivla-original/
3 | omnivla-original-balance/
4 | MBRA/
5 | MBRA_BDD/
6 | OmniVLA.egg-info/
7 | prismatic/vla/datasets/sampler/
8 | prismatic/vla/datasets/bdd_dataset_.py
9 | prismatic/vla/datasets/cast_dataset_.py
10 | prismatic/vla/datasets/lelan_dataset_.py
11 | prismatic/vla/datasets/gnm_dataset_.py
12 | prismatic/vla/datasets/frodobots_dataset_.py
13 | vla-scripts/train_omnivla_dataset_.py
14 | vla-scripts/train_omnivla_dataset__.py
15 | vla-scripts/finetune_nav_MMN_dataset_cast2.py
16 | config_nav/mbra_and_dataset_config_.yaml
17 | config_nav/mbra_and_dataset_config__.yaml
18 | config_nav/mbra_and_dataset_config_bellman.yaml
19 | config_nav/base_server_bellman.yaml
20 | visualization/
21 | runs/
22 | # Ignore all __pycache__ directories everywhere
23 | **/__pycache__/
24 | # Ignore all wandb directories everywhere
25 | **/wandb/
26 |
--------------------------------------------------------------------------------
/SETUP.md:
--------------------------------------------------------------------------------
1 | # Setup Instructions
2 |
3 | ## Set Up Conda Environment
4 |
5 | ```bash
6 | # Create and activate conda environment
7 | conda create -n omnivla python=3.10 -y
8 | conda activate omnivla
9 |
10 | # Install PyTorch
11 | # Use a command specific to your machine: https://pytorch.org/get-started/locally/
12 | pip3 install numpy==1.26.4 torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0
13 |
14 | # Clone the OmniVLA repo and pip install to download dependencies
15 | git clone https://github.com/NHirose/OmniVLA.git
16 | cd OmniVLA
17 | pip install -e .
18 |
19 | # Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
20 | # =>> If you run into difficulty, try `pip cache remove flash_attn` first
21 | pip install packaging ninja
22 | ninja --version; echo $? # Verify Ninja --> should return exit code "0"
23 | pip install "flash-attn==2.5.5" --no-build-isolation
24 | ```
25 |
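26 | ## Quick Sanity Check (Optional)
27 | 
28 | A minimal check that the environment above imports cleanly; it assumes only the packages installed in
29 | the steps above and a CUDA-capable machine for the Flash Attention import.
30 | 
31 | ```python
32 | # Run inside the `omnivla` conda environment.
33 | import torch
34 | import flash_attn  # only needed for training
35 | 
36 | print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
37 | print("flash-attn:", flash_attn.__version__)
38 | ```
39 | 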
--------------------------------------------------------------------------------
/prismatic/models/backbones/vision/siglip_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | siglip_vit.py
3 | """
4 |
5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
6 |
7 | # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch)
8 | SIGLIP_VISION_BACKBONES = {
9 | "siglip-vit-b16-224px": "vit_base_patch16_siglip_224",
10 | "siglip-vit-b16-256px": "vit_base_patch16_siglip_256",
11 | "siglip-vit-b16-384px": "vit_base_patch16_siglip_384",
12 | "siglip-vit-so400m": "vit_so400m_patch14_siglip_224",
13 | "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384",
14 | }
15 |
16 |
17 | class SigLIPViTBackbone(TimmViTBackbone):
18 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
19 | super().__init__(
20 | vision_backbone_id,
21 | SIGLIP_VISION_BACKBONES[vision_backbone_id],
22 | image_resize_strategy,
23 | default_image_size=default_image_size,
24 | )
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Original Copyright (c) 2025 Moo Jin Kim, Chelsea Finn, Percy Liang.
4 | Modifications Copyright (c) 2025 Noriaki Hirose, Catherine Glossop, Dhruv Shah, Sergey Levine
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/rlds/utils/goal_relabeling.py:
--------------------------------------------------------------------------------
1 | """
2 | goal_relabeling.py
3 |
4 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required.
5 | Each function should add entries to the "task" dict.
6 | """
7 |
8 | from typing import Dict
9 |
10 | import tensorflow as tf
11 |
12 | from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge
13 |
14 |
15 | def uniform(traj: Dict) -> Dict:
16 | """Relabels with a true uniform distribution over future states."""
17 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0]
18 |
19 | # Select a random future index for each transition i in the range [i + 1, traj_len)
20 | rand = tf.random.uniform([traj_len])
21 | low = tf.cast(tf.range(traj_len) + 1, tf.float32)
22 | high = tf.cast(traj_len, tf.float32)
23 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
24 |
25 | # Sometimes there are floating-point errors that cause an out-of-bounds index
26 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1)
27 |
28 | # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly)
29 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"])
30 | traj["task"] = tree_merge(traj["task"], goal)
31 |
32 | return traj
33 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/vision/clip_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | clip_vit.py
3 | """
4 |
5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
6 |
7 | # Registry =>> Supported CLIP Vision Backbones (from TIMM)
8 | CLIP_VISION_BACKBONES = {
9 | "clip-vit-b": "vit_base_patch16_clip_224.openai",
10 | "clip-vit-l": "vit_large_patch14_clip_224.openai",
11 | "clip-vit-l-336px": "vit_large_patch14_clip_336.openai",
12 | }
13 |
14 |
15 | # [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch.
16 | # HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's
17 | # a decent approximation, the resulting features are *worse*; this was a super tricky bug
18 | # to identify, but luckily there's an easy fix (`override_act_layer`)
19 | class CLIPViTBackbone(TimmViTBackbone):
20 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
21 | super().__init__(
22 | vision_backbone_id,
23 | CLIP_VISION_BACKBONES[vision_backbone_id],
24 | image_resize_strategy,
25 | default_image_size=default_image_size,
26 | override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None,
27 | )
28 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/data_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | # global params for diffusion model
3 | # normalized min and max
4 | action_stats:
5 | min: [-2.5, -4] # [min_dx, min_dy]
6 | max: [5, 4] # [max_dx, max_dy]
7 |
8 | # data specific params
9 | recon:
10 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters)
11 |
12 | # OPTIONAL (FOR VISUALIZATION ONLY)
13 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html
14 | camera_height: 0.95 # meters
15 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera
16 | camera_matrix:
17 | fx: 272.547000
18 | fy: 266.358000
19 | cx: 320.000000
20 | cy: 220.000000
21 | dist_coeffs:
22 | k1: -0.038483
23 | k2: -0.010456
24 | p1: 0.003930
25 | p2: -0.001007
26 | k3: 0.0
27 |
28 | scand:
29 | metric_waypoint_spacing: 0.38
30 |
31 | tartan_drive:
32 | metric_waypoint_spacing: 0.72
33 |
34 | go_stanford:
35 | metric_waypoint_spacing: 0.12
36 |
37 | # private datasets:
38 | cory_hall:
39 | metric_waypoint_spacing: 0.06
40 |
41 | seattle:
42 | metric_waypoint_spacing: 0.35
43 |
44 | racer:
45 | metric_waypoint_spacing: 0.38
46 |
47 | carla_intvns:
48 | metric_waypoint_spacing: 1.39
49 |
50 | carla_cil:
51 | metric_waypoint_spacing: 1.27
52 |
53 | carla_intvns:
54 | metric_waypoint_spacing: 1.39
55 |
56 | carla:
57 | metric_waypoint_spacing: 1.59
58 | image_path_func: get_image_path
59 |
60 | sacson:
61 | metric_waypoint_spacing: 0.12
62 |
63 | go_stanford4:
64 | metric_waypoint_spacing: 0.12
65 |
66 | go_stanford2:
67 | metric_waypoint_spacing: 0.12
68 |
69 | humanw:
70 | metric_waypoint_spacing: 0.12
71 |
72 | youtube:
73 | metric_waypoint_spacing: 0.12
74 |
75 | bdd:
76 | metric_waypoint_spacing: 0.12
77 |
78 | # add your own dataset params here:
79 |
--------------------------------------------------------------------------------
/prismatic/vla/constants.py:
--------------------------------------------------------------------------------
1 | """
2 | Important constants for VLA training and evaluation.
3 |
4 | The constants are fixed to the OmniVLA platform values defined below and exposed as module-level
5 | globals; if your robot platform differs, edit them manually in this file.
6 | """
7 | import sys
8 | from enum import Enum
9 |
10 | # Llama 2 token constants
11 | IGNORE_INDEX = -100
12 | ACTION_TOKEN_BEGIN_IDX = 31743
13 | STOP_INDEX = 2  # '</s>'
14 |
15 |
16 | # Defines supported normalization schemes for action and proprioceptive state.
17 | class NormalizationType(str, Enum):
18 | # fmt: off
19 | NORMAL = "normal" # Normalize to Mean = 0, Stdev = 1
20 | BOUNDS = "bounds" # Normalize to Interval = [-1, 1]
21 | BOUNDS_Q99 = "bounds_q99" # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1]
22 | # fmt: on
23 |
24 |
25 | # Define constants for each robot platform
26 | OMNIVLA_CONSTANTS = {
27 | "NUM_ACTIONS_CHUNK": 8,
28 | "ACTION_DIM": 4,
29 | "POSE_DIM": 4,
30 | "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99,
31 | }
32 |
33 | constants = OMNIVLA_CONSTANTS
34 |
35 | # Assign constants to global variables
36 | NUM_ACTIONS_CHUNK = constants["NUM_ACTIONS_CHUNK"]
37 | ACTION_DIM = constants["ACTION_DIM"]
38 | POSE_DIM = constants["POSE_DIM"]
39 | ACTION_PROPRIO_NORMALIZATION_TYPE = constants["ACTION_PROPRIO_NORMALIZATION_TYPE"]
40 |
41 | # Print which robot platform constants are being used (for debugging)
42 | print(f"Using OmniVLA constants:")
43 | print(f" NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}")
44 | print(f" ACTION_DIM = {ACTION_DIM}")
45 | print(f" POSE_DIM = {POSE_DIM}")
46 | print(f" ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}")
47 | print("If needed, manually set the correct constants in `prismatic/vla/constants.py`!")
48 |
--------------------------------------------------------------------------------
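A small sketch of what these constants imply for one OmniVLA training example. The "one scalar token per action dimension" framing follows the `NoisyActionProjector` docstring in `prismatic/models/projectors.py` below; the numbers here are derived only from the values defined above.

```python
from prismatic.vla.constants import ACTION_DIM, NUM_ACTIONS_CHUNK, POSE_DIM

# 8 future waypoints x 4 action dimensions => 32 scalar action tokens per chunk
action_tokens_per_chunk = NUM_ACTIONS_CHUNK * ACTION_DIM
print(action_tokens_per_chunk)  # 32

# Proprio / pose inputs are POSE_DIM-dimensional vectors
print(POSE_DIM)  # 4
```
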
/prismatic/models/projectors.py:
--------------------------------------------------------------------------------
1 | """Implementation of additional projectors for additional inputs to the VLA models."""
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class ProprioProjector(nn.Module):
7 | """
8 | Projects proprio state inputs into the LLM's embedding space.
9 | """
10 | def __init__(self, llm_dim: int, proprio_dim: int) -> None:
11 | super().__init__()
12 | self.llm_dim = llm_dim
13 | self.proprio_dim = proprio_dim
14 |
15 | self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
16 | self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
17 | self.act_fn1 = nn.GELU()
18 |
19 | def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
20 | # proprio: (bsz, proprio_dim)
21 | projected_features = self.fc1(proprio)
22 | projected_features = self.act_fn1(projected_features)
23 | projected_features = self.fc2(projected_features)
24 | return projected_features
25 |
26 |
27 | class NoisyActionProjector(nn.Module):
28 | """
29 | [Diffusion] Projects noisy action inputs into the LLM's embedding space.
30 |
31 | Note that since each action is tokenized into 7 tokens in OpenVLA (rather
32 | than having 1 token per action), each noisy action token will have dimension 1
33 | instead of 7.
34 | """
35 | def __init__(self, llm_dim: int) -> None:
36 | super().__init__()
37 | self.llm_dim = llm_dim
38 | self.action_token_dim = 1
39 |
40 | self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
41 | self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
42 | self.act_fn1 = nn.GELU()
43 |
44 | def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
45 | # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
46 | projected_features = self.fc1(noisy_actions)
47 | projected_features = self.act_fn1(projected_features)
48 | projected_features = self.fc2(projected_features)
49 | return projected_features
50 |
--------------------------------------------------------------------------------
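A short usage sketch for the two projectors above; the batch size and `llm_dim` are illustrative values, not taken from any OmniVLA config.

```python
import torch

from prismatic.models.projectors import NoisyActionProjector, ProprioProjector

# Project a batch of 4-D proprio states into an (assumed) 4096-D LLM embedding space.
proprio_proj = ProprioProjector(llm_dim=4096, proprio_dim=4)
print(proprio_proj(torch.zeros(2, 4)).shape)  # torch.Size([2, 4096])

# One scalar token per (chunk step, action dim): 32 = 8 x 4, matching prismatic/vla/constants.py.
noisy_proj = NoisyActionProjector(llm_dim=4096)
print(noisy_proj(torch.zeros(2, 32, 1)).shape)  # torch.Size([2, 32, 4096])
```
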
/prismatic/training/train_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for training/fine-tuning scripts."""
2 |
3 | import torch
4 |
5 | from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
6 |
7 |
8 | def get_current_action_mask(token_ids):
9 | # Mark positions that are *not* IGNORE_INDEX (i.e., real supervised tokens)
10 | newline_positions = token_ids != IGNORE_INDEX
11 |
12 | # Cumulative sum gives the ordinal position of each non-ignored token
13 | cumsum = torch.cumsum(newline_positions, dim=1)
14 |
15 | # Create the mask
16 | mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)
17 |
18 | # Extract the action part only
19 | action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
20 | mask = action_tokens_only_mask * mask
21 |
22 | return mask
23 |
24 |
25 | def get_next_actions_mask(token_ids):
26 | # Mark positions that are *not* IGNORE_INDEX (i.e., real supervised tokens)
27 | newline_positions = token_ids != IGNORE_INDEX
28 |
29 | # Cumulative sum gives the ordinal position of each non-ignored token
30 | cumsum = torch.cumsum(newline_positions, dim=1)
31 |
32 | # Create the mask
33 | mask = cumsum > ACTION_DIM
34 |
35 | # Extract the action part only
36 | action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
37 | mask = action_tokens_only_mask * mask
38 |
39 | return mask
40 |
41 |
42 | def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
43 | correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
44 | accuracy = correct_preds.sum().float() / mask.sum().float()
45 | return accuracy
46 |
47 |
48 | def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
49 | pred_continuous_actions = torch.tensor(
50 | action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy())
51 | )
52 | true_continuous_actions = torch.tensor(
53 | action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy())
54 | )
55 | l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions)
56 | return l1_loss
57 |
--------------------------------------------------------------------------------
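A toy illustration of the masking utilities above. The token ids are made up (any id greater than ACTION_TOKEN_BEGIN_IDX = 31743 counts as an action token), and ACTION_DIM = 4 comes from `prismatic/vla/constants.py`.

```python
import torch

from prismatic.training.train_utils import get_current_action_mask, get_next_actions_mask

# Two ignored (prompt) positions followed by eight action-token positions (ids > 31743).
labels = torch.tensor([[-100, -100, 31800, 31801, 31802, 31803, 31810, 31811, 31812, 31813]])

print(get_current_action_mask(labels))  # True on the first ACTION_DIM (= 4) action positions
print(get_next_actions_mask(labels))    # True on the remaining action positions
```
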
/prismatic/util/nn_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | nn_utils.py
3 |
4 | Utility functions and PyTorch submodule definitions.
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
12 | class LinearProjector(nn.Module):
13 | def __init__(self, vision_dim: int, llm_dim: int) -> None:
14 | super().__init__()
15 | self.projector = nn.Linear(vision_dim, llm_dim, bias=True)
16 |
17 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
18 | return self.projector(img_patches)
19 |
20 |
21 | class MLPProjector(nn.Module):
22 | def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None:
23 | super().__init__()
24 | if mlp_type == "gelu-mlp":
25 | self.projector = nn.Sequential(
26 | nn.Linear(vision_dim, llm_dim, bias=True),
27 | nn.GELU(),
28 | nn.Linear(llm_dim, llm_dim, bias=True),
29 | )
30 | else:
31 | raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
32 |
33 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
34 | return self.projector(img_patches)
35 |
36 |
37 | class FusedMLPProjector(nn.Module):
38 | def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None:
39 | super().__init__()
40 | self.initial_projection_dim = fused_vision_dim * 4
41 | if mlp_type == "fused-gelu-mlp":
42 | self.projector = nn.Sequential(
43 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True),
44 | nn.GELU(),
45 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
46 | nn.GELU(),
47 | nn.Linear(llm_dim, llm_dim, bias=True),
48 | )
49 | else:
50 | raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!")
51 |
52 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
53 | return self.projector(fused_img_patches)
54 |
--------------------------------------------------------------------------------
/prismatic/vla/datasets/rlds/utils/task_augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | task_augmentation.py
3 |
4 | Contains basic logic for randomly zeroing out keys in the task specification.
5 | """
6 |
7 | from typing import Dict
8 |
9 | import tensorflow as tf
10 |
11 | from prismatic.vla.datasets.rlds.utils.data_utils import to_padding
12 |
13 |
14 | def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict:
15 | """
16 | Randomly drops out either the goal images or the language instruction. Only does something if both of
17 | these are present.
18 |
19 | Args:
20 | traj: A dictionary containing trajectory data. Should have a "task" key.
21 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language
22 | instruction is 1 - keep_image_prob.
23 | """
24 | if "language_instruction" not in traj["task"]:
25 | return traj
26 |
27 | image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")}
28 | if not image_keys:
29 | return traj
30 |
31 | traj_len = tf.shape(traj["action"])[0]
32 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob
33 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"]
34 |
35 | for key in image_keys | {"language_instruction"}:
36 | should_keep = should_keep_images if key in image_keys else ~should_keep_images
37 | # pad out the key
38 | traj["task"][key] = tf.where(
39 | should_keep,
40 | traj["task"][key],
41 | to_padding(traj["task"][key]),
42 | )
43 | # zero out the pad mask dict for the key
44 | traj["task"]["pad_mask_dict"][key] = tf.where(
45 | should_keep,
46 | traj["task"]["pad_mask_dict"][key],
47 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]),
48 | )
49 |
50 | # when no goal images are present, the goal timestep becomes the final timestep
51 | traj["task"]["timestep"] = tf.where(
52 | should_keep_images,
53 | traj["task"]["timestep"],
54 | traj_len - 1,
55 | )
56 |
57 | return traj
58 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/phi.py:
--------------------------------------------------------------------------------
1 | """
2 | phi.py
3 |
4 | Class definition for all LLMs derived from PhiForCausalLM.
5 | """
6 |
7 | from typing import Optional, Type
8 |
9 | import torch
10 | from torch import nn as nn
11 | from transformers import PhiForCausalLM
12 | from transformers.models.phi.modeling_phi import PhiDecoderLayer
13 |
14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
15 | from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder
16 |
17 | # Registry ==> Support Phi Models (from HF Transformers)
18 | # fmt: off
19 | PHI_MODELS = {
20 | # === Phi-2 ===
21 | "phi-2-3b": {
22 | "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
23 | }
24 | }
25 | # fmt: on
26 |
27 |
28 | class PhiLLMBackbone(HFCausalLLMBackbone):
29 | def __init__(
30 | self,
31 | llm_backbone_id: str,
32 | llm_max_length: int = 2048,
33 | hf_token: Optional[str] = None,
34 | inference_mode: bool = False,
35 | use_flash_attention_2: bool = True,
36 | ) -> None:
37 | super().__init__(
38 | llm_backbone_id,
39 | llm_max_length=llm_max_length,
40 | hf_token=hf_token,
41 | inference_mode=inference_mode,
42 | use_flash_attention_2=use_flash_attention_2,
43 | **PHI_MODELS[llm_backbone_id],
44 | )
45 |
46 | # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
47 | self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
48 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id
49 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
50 |
51 | @property
52 | def prompt_builder_fn(self) -> Type[PromptBuilder]:
53 | if self.identifier.startswith("phi-2"):
54 | return PhiPromptBuilder
55 |
56 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
57 |
58 | @property
59 | def transformer_layer_cls(self) -> Type[nn.Module]:
60 | return PhiDecoderLayer
61 |
62 | @property
63 | def half_precision_dtype(self) -> torch.dtype:
64 | return torch.bfloat16
65 |
--------------------------------------------------------------------------------
/prismatic/vla/materialize.py:
--------------------------------------------------------------------------------
1 | """
2 | materialize.py
3 |
4 | Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and
5 | exports individual functions for clear control flow.
6 | """
7 |
8 | from pathlib import Path
9 | from typing import Tuple, Type
10 |
11 | from torch.utils.data import Dataset
12 | from transformers import PreTrainedTokenizerBase
13 |
14 | from prismatic.models.backbones.llm.prompting import PromptBuilder
15 | from prismatic.models.backbones.vision import ImageTransform
16 | from prismatic.util.data_utils import PaddedCollatorForActionPrediction
17 | from prismatic.vla.action_tokenizer import ActionTokenizer
18 | from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset
19 |
20 |
21 | def get_vla_dataset_and_collator(
22 | data_root_dir: Path,
23 | data_mix: str,
24 | image_transform: ImageTransform,
25 | tokenizer: PreTrainedTokenizerBase,
26 | prompt_builder_fn: Type[PromptBuilder],
27 | default_image_resolution: Tuple[int, int, int],
28 | padding_side: str = "right",
29 | predict_stop_token: bool = True,
30 | shuffle_buffer_size: int = 100_000,
31 | train: bool = True,
32 | episodic: bool = False,
33 | image_aug: bool = False,
34 | ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]:
35 | """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions."""
36 | action_tokenizer = ActionTokenizer(tokenizer)
37 | batch_transform = RLDSBatchTransform(
38 | action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token
39 | )
40 | collator = PaddedCollatorForActionPrediction(
41 | tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side
42 | )
43 |
44 | # Build RLDS Iterable Dataset
45 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset
46 | dataset = cls(
47 | data_root_dir,
48 | data_mix,
49 | batch_transform,
50 | resize_resolution=default_image_resolution[1:],
51 | shuffle_buffer_size=shuffle_buffer_size,
52 | train=train,
53 | image_aug=image_aug,
54 | )
55 |
56 | return dataset, action_tokenizer, collator
57 |
--------------------------------------------------------------------------------
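A sketch of how `get_vla_dataset_and_collator` is typically wired up from a training script. The `image_transform`, `tokenizer`, and `prompt_builder_fn` arguments are assumed to come from already-materialized vision/LLM backbones (cf. `prismatic/models/materialize.py`), and the data path and mixture name are placeholders.

```python
from pathlib import Path

from prismatic.vla import get_vla_dataset_and_collator


def build_vla_data(image_transform, tokenizer, prompt_builder_fn):
    """Illustrative wrapper; the three arguments are assumed to be produced by the backbone factories."""
    return get_vla_dataset_and_collator(
        data_root_dir=Path("datasets/rlds"),  # placeholder path
        data_mix="bridge",                    # placeholder mixture name
        image_transform=image_transform,
        tokenizer=tokenizer,
        prompt_builder_fn=prompt_builder_fn,
        default_image_resolution=(3, 224, 224),
        shuffle_buffer_size=100_000,
        image_aug=True,
    )
```
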
/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | mistral_instruct_prompter.py
3 |
4 | Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials.
5 |
6 | Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
7 | """
8 |
9 | from typing import Optional
10 |
11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
12 |
13 |
14 | class MistralInstructPromptBuilder(PromptBuilder):
15 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
16 | super().__init__(model_family, system_prompt)
17 |
18 | # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)`
19 | # =>> Mistral Instruct *does not* use a System Prompt
20 | self.bos, self.eos = "<s>", "</s>"
21 |
22 | # Get role-specific "wrap" functions
23 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
24 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
25 |
26 | # === `self.prompt` gets built up over multiple turns ===
27 | self.prompt, self.turn_count = "", 0
28 |
29 | def add_turn(self, role: str, message: str) -> str:
30 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
31 | message = message.replace("</s>", "").strip()
32 |
33 | if (self.turn_count % 2) == 0:
34 | human_message = self.wrap_human(message)
35 | wrapped_message = human_message
36 | else:
37 | gpt_message = self.wrap_gpt(message)
38 | wrapped_message = gpt_message
39 |
40 | # Update Prompt
41 | self.prompt += wrapped_message
42 |
43 | # Bump Turn Counter
44 | self.turn_count += 1
45 |
46 | # Return "wrapped_message" (effective string added to context)
47 | return wrapped_message
48 |
49 | def get_potential_prompt(self, message: str) -> None:
50 | # Assumes that it's always the user's (human's) turn!
51 | prompt_copy = str(self.prompt)
52 |
53 | human_message = self.wrap_human(message)
54 | prompt_copy += human_message
55 |
56 | return prompt_copy.removeprefix(self.bos).rstrip()
57 |
58 | def get_prompt(self) -> str:
59 | # Remove prefix <s> because it gets auto-inserted by tokenizer!
60 | return self.prompt.removeprefix(self.bos).rstrip()
61 |
--------------------------------------------------------------------------------
/prismatic/training/materialize.py:
--------------------------------------------------------------------------------
1 | """
2 | materialize.py
3 |
4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones,
5 | and strategy configurations.
6 | """
7 |
8 | from typing import Callable, Optional
9 |
10 | import torch
11 |
12 | from prismatic.models.vlms import PrismaticVLM
13 | from prismatic.training.strategies import FSDPStrategy, TrainingStrategy
14 |
15 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented!
16 | TRAIN_STRATEGIES = {
17 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}},
18 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}},
19 | }
20 |
21 |
22 | def get_train_strategy(
23 | train_strategy: str,
24 | vlm: PrismaticVLM,
25 | device_id: int,
26 | stage: str,
27 | epochs: int,
28 | max_steps: Optional[int],
29 | global_batch_size: int,
30 | per_device_batch_size: int,
31 | learning_rate: float,
32 | weight_decay: float,
33 | max_grad_norm: float,
34 | lr_scheduler_type: str,
35 | warmup_ratio: float,
36 | enable_gradient_checkpointing: bool = True,
37 | enable_mixed_precision_training: bool = True,
38 | reduce_in_full_precision: bool = False,
39 | mixed_precision_dtype: torch.dtype = torch.bfloat16,
40 | worker_init_fn: Optional[Callable[[int], None]] = None,
41 | ) -> TrainingStrategy:
42 | if train_strategy in TRAIN_STRATEGIES:
43 | strategy_cfg = TRAIN_STRATEGIES[train_strategy]
44 | strategy = strategy_cfg["cls"](
45 | vlm=vlm,
46 | device_id=device_id,
47 | stage=stage,
48 | epochs=epochs,
49 | max_steps=max_steps,
50 | global_batch_size=global_batch_size,
51 | per_device_batch_size=per_device_batch_size,
52 | learning_rate=learning_rate,
53 | weight_decay=weight_decay,
54 | max_grad_norm=max_grad_norm,
55 | lr_scheduler_type=lr_scheduler_type,
56 | warmup_ratio=warmup_ratio,
57 | enable_gradient_checkpointing=enable_gradient_checkpointing,
58 | enable_mixed_precision_training=enable_mixed_precision_training,
59 | reduce_in_full_precision=reduce_in_full_precision,
60 | mixed_precision_dtype=mixed_precision_dtype,
61 | worker_init_fn=worker_init_fn,
62 | **strategy_cfg["kwargs"],
63 | )
64 | return strategy
65 | else:
66 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!")
67 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/mistral.py:
--------------------------------------------------------------------------------
1 | """
2 | mistral.py
3 |
4 | Class definition for all LLMs derived from MistralForCausalLM.
5 | """
6 |
7 | from typing import Optional, Type
8 |
9 | import torch
10 | from torch import nn as nn
11 | from transformers import MistralForCausalLM
12 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
13 |
14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
15 | from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder
16 |
17 | # Registry =>> Support Mistral Models (from HF Transformers)
18 | # fmt: off
19 | MISTRAL_MODELS = {
20 | # === Base Mistral v0.1 ===
21 | "mistral-v0.1-7b-pure": {
22 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
23 | },
24 |
25 | # === Mistral Instruct v0.1 ===
26 | "mistral-v0.1-7b-instruct": {
27 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
28 | }
29 | }
30 | # fmt: on
31 |
32 |
33 | class MistralLLMBackbone(HFCausalLLMBackbone):
34 | def __init__(
35 | self,
36 | llm_backbone_id: str,
37 | llm_max_length: int = 2048,
38 | hf_token: Optional[str] = None,
39 | inference_mode: bool = False,
40 | use_flash_attention_2: bool = True,
41 | ) -> None:
42 | super().__init__(
43 | llm_backbone_id,
44 | llm_max_length=llm_max_length,
45 | hf_token=hf_token,
46 | inference_mode=inference_mode,
47 | use_flash_attention_2=use_flash_attention_2,
48 | **MISTRAL_MODELS[llm_backbone_id],
49 | )
50 |
51 | # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
52 | self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
53 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id
54 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
55 |
56 | @property
57 | def prompt_builder_fn(self) -> Type[PromptBuilder]:
58 | if self.identifier.endswith("-pure"):
59 | return PurePromptBuilder
60 |
61 | elif self.identifier.endswith("-instruct"):
62 | return MistralInstructPromptBuilder
63 |
64 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
65 |
66 | @property
67 | def transformer_layer_cls(self) -> Type[nn.Module]:
68 | return MistralDecoderLayer
69 |
70 | @property
71 | def half_precision_dtype(self) -> torch.dtype:
72 | return torch.bfloat16
73 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/prompting/phi_prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | phi_prompter.py
3 |
4 | Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
5 | Also handles Phi special case BOS token additions.
6 |
7 | Reference: https://huggingface.co/microsoft/phi-2#qa-format
8 | """
9 |
10 | from typing import Optional
11 |
12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
13 |
14 |
15 | class PhiPromptBuilder(PromptBuilder):
16 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
17 | super().__init__(model_family, system_prompt)
18 |
19 | # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
20 | # =>> By default, does *not* append BOS / EOS tokens --> we handle that here (IMPORTANT)!
21 | self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"
22 |
23 | # Get role-specific "wrap" functions
24 | # =>> Note that placement of BOS / EOS tokens was based on experiments generating from Phi-2 in Input/Output mode
25 | self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
26 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"
27 |
28 | # === `self.prompt` gets built up over multiple turns ===
29 | self.prompt, self.turn_count = "", 0
30 |
31 | def add_turn(self, role: str, message: str) -> str:
32 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
33 | message = message.replace("", "").strip()
34 |
35 | # Special Handling for "first" input --> prepend a BOS token (expected by Prismatic)
36 | if self.turn_count == 0:
37 | bos_human_message = f"{self.bos}{self.wrap_human(message)}"
38 | wrapped_message = bos_human_message
39 | elif (self.turn_count % 2) == 0:
40 | human_message = self.wrap_human(message)
41 | wrapped_message = human_message
42 | else:
43 | gpt_message = self.wrap_gpt(message)
44 | wrapped_message = gpt_message
45 |
46 | # Update Prompt
47 | self.prompt += wrapped_message
48 |
49 | # Bump Turn Counter
50 | self.turn_count += 1
51 |
52 | # Return "wrapped_message" (effective string added to context)
53 | return wrapped_message
54 |
55 | def get_potential_prompt(self, message: str) -> None:
56 | # Assumes that it's always the user's (human's) turn!
57 | prompt_copy = str(self.prompt)
58 |
59 | human_message = self.wrap_human(message)
60 | prompt_copy += human_message
61 |
62 | return prompt_copy.rstrip()
63 |
64 | def get_prompt(self) -> str:
65 | return self.prompt.rstrip()
66 |
--------------------------------------------------------------------------------
/prismatic/preprocessing/materialize.py:
--------------------------------------------------------------------------------
1 | """
2 | materialize.py
3 |
4 | Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
5 | clear control flow.
6 | """
7 |
8 | from typing import Tuple, Type
9 |
10 | from torch.utils.data import Dataset
11 | from transformers import PreTrainedTokenizerBase
12 |
13 | from prismatic.conf import DatasetConfig
14 | from prismatic.models.backbones.llm.prompting import PromptBuilder
15 | from prismatic.models.backbones.vision import ImageTransform
16 | from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset
17 | from prismatic.util.data_utils import PaddedCollatorForLanguageModeling
18 |
19 | # Dataset Initializers =>> Maps Stage --> cls()
20 | DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}
21 |
22 |
23 | def get_dataset_and_collator(
24 | stage: str,
25 | dataset_cfg: DatasetConfig,
26 | image_transform: ImageTransform,
27 | tokenizer: PreTrainedTokenizerBase,
28 | prompt_builder_fn: Type[PromptBuilder],
29 | default_image_resolution: Tuple[int, int, int],
30 | padding_side: str = "right",
31 | ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
32 | dataset_cls = DATASET_INITIALIZER[stage]
33 | dataset_root_dir = dataset_cfg.dataset_root_dir
34 | collator = PaddedCollatorForLanguageModeling(
35 | tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side
36 | )
37 |
38 | # Switch on `stage`
39 | if stage == "align":
40 | annotation_json, image_dir = dataset_cfg.align_stage_components
41 | dataset = dataset_cls(
42 | dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
43 | )
44 | return dataset, collator
45 |
46 | elif stage == "finetune":
47 | annotation_json, image_dir = dataset_cfg.finetune_stage_components
48 | dataset = dataset_cls(
49 | dataset_root_dir / annotation_json,
50 | dataset_root_dir / image_dir,
51 | image_transform,
52 | tokenizer,
53 | prompt_builder_fn=prompt_builder_fn,
54 | )
55 | return dataset, collator
56 |
57 | elif stage == "full-finetune":
58 | annotation_json, image_dir = dataset_cfg.finetune_stage_components
59 | dataset = dataset_cls(
60 | dataset_root_dir / annotation_json,
61 | dataset_root_dir / image_dir,
62 | image_transform,
63 | tokenizer,
64 | prompt_builder_fn=prompt_builder_fn,
65 | )
66 | return dataset, collator
67 |
68 | else:
69 | raise ValueError(f"Stage `{stage}` is not supported!")
70 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/prompting/base_prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | base_prompter.py
3 |
4 | Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs.
5 | """
6 |
7 | from abc import ABC, abstractmethod
8 | from typing import Optional
9 |
10 |
11 | class PromptBuilder(ABC):
12 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
13 | self.model_family = model_family
14 |
15 | # Only some models define a system prompt => let subclasses handle this logic!
16 | self.system_prompt = system_prompt
17 |
18 | @abstractmethod
19 | def add_turn(self, role: str, message: str) -> str: ...
20 |
21 | @abstractmethod
22 | def get_potential_prompt(self, user_msg: str) -> None: ...
23 |
24 | @abstractmethod
25 | def get_prompt(self) -> str: ...
26 |
27 |
28 | class PurePromptBuilder(PromptBuilder):
29 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
30 | super().__init__(model_family, system_prompt)
31 |
32 | # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME!
33 | self.bos, self.eos = "<s>", "</s>"
34 |
35 | # Get role-specific "wrap" functions
36 | self.wrap_human = lambda msg: f"In: {msg}\nOut: "
37 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
38 |
39 | # === `self.prompt` gets built up over multiple turns ===
40 | self.prompt, self.turn_count = "", 0
41 |
42 | def add_turn(self, role: str, message: str) -> str:
43 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
44 | message = message.replace("</s>", "").strip()
45 |
46 | if (self.turn_count % 2) == 0:
47 | human_message = self.wrap_human(message)
48 | wrapped_message = human_message
49 | else:
50 | gpt_message = self.wrap_gpt(message)
51 | wrapped_message = gpt_message
52 |
53 | # Update Prompt
54 | self.prompt += wrapped_message
55 |
56 | # Bump Turn Counter
57 | self.turn_count += 1
58 |
59 | # Return "wrapped_message" (effective string added to context)
60 | return wrapped_message
61 |
62 | def get_potential_prompt(self, message: str) -> None:
63 | # Assumes that it's always the user's (human's) turn!
64 | prompt_copy = str(self.prompt)
65 |
66 | human_message = self.wrap_human(message)
67 | prompt_copy += human_message
68 |
69 | return prompt_copy.removeprefix(self.bos).rstrip()
70 |
71 | def get_prompt(self) -> str:
72 | # Remove prefix <s> (if exists) because it gets auto-inserted by tokenizer!
73 | return self.prompt.removeprefix(self.bos).rstrip()
74 |
--------------------------------------------------------------------------------
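A small sketch of the turn-taking contract defined above, using `PurePromptBuilder`; the messages are illustrative, and the exact BOS/EOS strings depend on the backbone tokenizer.

```python
from prismatic.models.backbones.llm.prompting import PurePromptBuilder

builder = PurePromptBuilder(model_family="prismatic")
builder.add_turn("human", "What action should the robot take to reach the goal?")
builder.add_turn("gpt", "Move 0.5 m forward, then turn left.")

# Produces "In: <user message>\nOut: <model message></s>" with the BOS prefix stripped.
print(builder.get_prompt())
```
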
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "OmniVLA"
7 | authors = [
8 | {name = "Noriaki Hirose"},
9 | {name = "Catherine Glossop"},
10 | {name = "Dhruv Shah"},
11 | {name = "Sergey Levine"},
12 | ]
13 | description = "OmniVLA: An Omni-Modal Vision-Language-Action Model for Robot Navigation"
14 | version = "0.0.1"
15 | readme = "README.md"
16 | requires-python = ">=3.8"
17 | keywords = ["vision-language-action model", "vision-based navigation", "robot learning"]
18 | license = {file = "LICENSE"}
19 | classifiers = [
20 | "Development Status :: 3 - Alpha",
21 | "Intended Audience :: Developers",
22 | "Intended Audience :: Education",
23 | "Intended Audience :: Science/Research",
24 | "License :: OSI Approved :: MIT License",
25 | "Operating System :: OS Independent",
26 | "Programming Language :: Python :: 3",
27 | "Programming Language :: Python :: 3.8",
28 | "Programming Language :: Python :: 3.9",
29 | "Programming Language :: Python :: 3.10",
30 | "Programming Language :: Python :: 3 :: Only",
31 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 | ]
33 | dependencies = [
34 | "accelerate>=0.25.0",
35 | "draccus==0.8.0",
36 | "einops",
37 | # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README)
38 | "huggingface_hub==0.29.1",
39 | "json-numpy",
40 | "jsonlines",
41 | "matplotlib",
42 | "peft==0.11.1",
43 | "protobuf",
44 | "rich",
45 | "sentencepiece==0.1.99",
46 | "timm==0.9.10",
47 | "tokenizers==0.19.1",
48 | "torch==2.2.0",
49 | "torchvision==0.17.0",
50 | "torchaudio==2.2.0",
51 | "transformers @ git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding)
52 | "wandb",
53 | "tensorflow==2.15.0",
54 | "tensorflow_datasets==4.9.3",
55 | "tensorflow_graphics==2021.12.3",
56 | "dlimp @ git+https://github.com/moojink/dlimp_openvla",
57 | "diffusers==0.11.1",
58 | "imageio",
59 | "uvicorn",
60 | "fastapi",
61 | "utm",
62 | "lmdb",
63 | "zarr",
64 | "datasets",
65 | "efficientnet_pytorch",
66 | "av"
67 | ]
68 |
69 | [project.optional-dependencies]
70 | dev = [
71 | "black>=24.2.0",
72 | "gpustat",
73 | "ipython",
74 | "pre-commit",
75 | "ruff>=0.2.2",
76 | ]
77 | sagemaker = [
78 | "boto3",
79 | "sagemaker"
80 | ]
81 |
82 | [project.urls]
83 | homepage = "https://omnivla-nav.github.io"
84 | repository = "https://github.com/NHirose/OmniVLA"
85 |
86 | [tool.setuptools.packages.find]
87 | where = ["."]
88 | exclude = ["cache"]
89 |
90 | [tool.setuptools.package-data]
91 | "prismatic" = ["py.typed"]
92 |
93 | [tool.black]
94 | line-length = 121
95 | target-version = ["py38", "py39", "py310"]
96 | preview = true
97 |
98 | [tool.ruff]
99 | line-length = 121
100 | target-version = "py38"
101 |
102 | [tool.ruff.lint]
103 | select = ["A", "B", "E", "F", "I", "RUF", "W"]
104 | ignore = ["F722"]
105 |
106 | [tool.ruff.lint.per-file-ignores]
107 | "__init__.py" = ["E402", "F401"]
108 |
--------------------------------------------------------------------------------
/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | vicuna_v15_prompter.py
3 |
4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts.
5 |
6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5
7 | """
8 |
9 | from typing import Optional
10 |
11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
12 |
13 | # Default System Prompt for LLaVa Models
14 | SYS_PROMPTS = {
15 | "prismatic": (
16 | "A chat between a curious user and an artificial intelligence assistant. "
17 | "The assistant gives helpful, detailed, and polite answers to the user's questions."
18 | ),
19 | "openvla": (
20 | "A chat between a curious user and an artificial intelligence assistant. "
21 | "The assistant gives helpful, detailed, and polite answers to the user's questions."
22 | ),
23 | }
24 |
25 |
26 | class VicunaV15ChatPromptBuilder(PromptBuilder):
27 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
28 | super().__init__(model_family, system_prompt)
29 | self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " "
30 |
31 | # LLaMa-2 Specific
32 | self.bos, self.eos = "<s>", "</s>"
33 |
34 | # Get role-specific "wrap" functions
35 | self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: "
36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
37 |
38 | # === `self.prompt` gets built up over multiple turns ===
39 | self.prompt, self.turn_count = "", 0
40 |
41 | def add_turn(self, role: str, message: str) -> str:
42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
43 | message = message.replace("<image>", "").strip()
44 |
45 | # Special Handling for "system" prompt (turn_count == 0)
46 | if self.turn_count == 0:
47 | sys_message = self.system_prompt + self.wrap_human(message)
48 | wrapped_message = sys_message
49 | elif (self.turn_count % 2) == 0:
50 | human_message = self.wrap_human(message)
51 | wrapped_message = human_message
52 | else:
53 | gpt_message = self.wrap_gpt(message)
54 | wrapped_message = gpt_message
55 |
56 | # Update Prompt
57 | self.prompt += wrapped_message
58 |
59 | # Bump Turn Counter
60 | self.turn_count += 1
61 |
62 | # Return "wrapped_message" (effective string added to context)
63 | return wrapped_message
64 |
65 | def get_potential_prompt(self, message: str) -> str:
66 | # Assumes that it's always the user's (human's) turn!
67 | prompt_copy = str(self.prompt)
68 |
69 | # Special Handling for "system" prompt (turn_count == 0)
70 | if self.turn_count == 0:
71 | sys_message = self.system_prompt + self.wrap_human(message)
72 | prompt_copy += sys_message
73 |
74 | else:
75 | human_message = self.wrap_human(message)
76 | prompt_copy += human_message
77 |
78 | return prompt_copy.removeprefix(self.bos).rstrip()
79 |
80 | def get_prompt(self) -> str:
81 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer!
82 | return self.prompt.removeprefix(self.bos).rstrip()
83 |
--------------------------------------------------------------------------------
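A minimal usage sketch for `VicunaV15ChatPromptBuilder`, assuming the `prismatic` package above is importable; the resulting string follows the `USER: ... ASSISTANT:` wrapping defined by `wrap_human`/`wrap_gpt`:

from prismatic.models.backbones.llm.prompting import VicunaV15ChatPromptBuilder

builder = VicunaV15ChatPromptBuilder(model_family="openvla")
builder.add_turn(role="human", message="What action should the robot take to reach the door?")

# Roughly: "<system prompt> USER: What action should the robot take to reach the door? ASSISTANT:"
print(builder.get_prompt())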
/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | llama2_chat_prompter.py
3 |
4 | Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern
5 | that's used by HF and other online tutorials.
6 |
7 | Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
8 | """
9 |
10 | from typing import Optional
11 |
12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
13 |
14 | # Default System Prompt for Prismatic Models
15 | SYS_PROMPTS = {
16 | "prismatic": (
17 | "You are a helpful language and vision assistant. "
18 | "You are able to understand the visual content that the user provides, "
19 | "and assist the user with a variety of tasks using natural language."
20 | ),
21 | "openvla": (
22 | "You are a helpful language and vision assistant. "
23 | "You are able to understand the visual content that the user provides, "
24 | "and assist the user with a variety of tasks using natural language."
25 | ),
26 | }
27 |
28 |
29 | def format_system_prompt(system_prompt: str) -> str:
30 | return f"<\n{system_prompt.strip()}\n<>\n\n"
31 |
32 |
33 | class LLaMa2ChatPromptBuilder(PromptBuilder):
34 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
35 | super().__init__(model_family, system_prompt)
36 | self.system_prompt = format_system_prompt(
37 | SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt
38 | )
39 |
40 | # LLaMa-2 Specific
41 | self.bos, self.eos = "<s>", "</s>"
42 |
43 | # Get role-specific "wrap" functions
44 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
45 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
46 |
47 | # === `self.prompt` gets built up over multiple turns ===
48 | self.prompt, self.turn_count = "", 0
49 |
50 | def add_turn(self, role: str, message: str) -> str:
51 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
52 | message = message.replace("<image>", "").strip()
53 |
54 | # Special Handling for "system" prompt (turn_count == 0)
55 | if self.turn_count == 0:
56 | sys_message = self.wrap_human(self.system_prompt + message)
57 | wrapped_message = sys_message
58 | elif (self.turn_count % 2) == 0:
59 | human_message = self.wrap_human(message)
60 | wrapped_message = human_message
61 | else:
62 | gpt_message = self.wrap_gpt(message)
63 | wrapped_message = gpt_message
64 |
65 | # Update Prompt
66 | self.prompt += wrapped_message
67 |
68 | # Bump Turn Counter
69 | self.turn_count += 1
70 |
71 | # Return "wrapped_message" (effective string added to context)
72 | return wrapped_message
73 |
74 | def get_potential_prompt(self, message: str) -> str:
75 | # Assumes that it's always the user's (human's) turn!
76 | prompt_copy = str(self.prompt)
77 |
78 | # Special Handling for "system" prompt (turn_count == 0)
79 | if self.turn_count == 0:
80 | sys_message = self.wrap_human(self.system_prompt + message)
81 | prompt_copy += sys_message
82 |
83 | else:
84 | human_message = self.wrap_human(message)
85 | prompt_copy += human_message
86 |
87 | return prompt_copy.removeprefix(self.bos).rstrip()
88 |
89 | def get_prompt(self) -> str:
90 | # Remove prefix because it gets auto-inserted by tokenizer!
91 | return self.prompt.removeprefix(self.bos).rstrip()
92 |
--------------------------------------------------------------------------------
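A companion sketch for `LLaMa2ChatPromptBuilder`, again assuming the `prismatic` package is importable; the first human turn folds the `<<SYS>>` block into the `[INST] ... [/INST]` wrapper described in the HF reference above:

from prismatic.models.backbones.llm.prompting import LLaMa2ChatPromptBuilder

builder = LLaMa2ChatPromptBuilder(model_family="openvla")
builder.add_turn(role="human", message="Describe the visual content in front of the robot.")

# Roughly: "[INST] <<SYS>>\n<system prompt>\n<</SYS>>\n\nDescribe the visual content ... [/INST]"
print(builder.get_prompt())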
/prismatic/models/backbones/llm/llama2.py:
--------------------------------------------------------------------------------
1 | """
2 | llama2.py
3 |
4 | Class definition for all LLMs derived from LlamaForCausalLM.
5 | """
6 |
7 | from typing import Optional, Sequence, Type
8 |
9 | import torch
10 | from torch import nn as nn
11 | from transformers import LlamaForCausalLM
12 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer
13 |
14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
15 | from prismatic.models.backbones.llm.prompting import (
16 | LLaMa2ChatPromptBuilder,
17 | PromptBuilder,
18 | PurePromptBuilder,
19 | VicunaV15ChatPromptBuilder,
20 | )
21 |
22 | # Registry =>> Support LLaMa-2 Models (from HF Transformers)
23 | # fmt: off
24 | LLAMA2_MODELS = {
25 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models ===
26 | "llama2-7b-pure": {
27 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf"
28 | },
29 |
30 | "llama2-13b-pure": {
31 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf"
32 | },
33 |
34 | # === Meta LLaMa-2 Chat Models ===
35 | "llama2-7b-chat": {
36 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf"
37 | },
38 |
39 | "llama2-13b-chat": {
40 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf"
41 | },
42 |
43 | # === Vicuna v1.5 Chat Models ===
44 | "vicuna-v15-7b": {
45 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5"
46 | },
47 |
48 | "vicuna-v15-13b": {
49 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5"
50 | },
51 | }
52 | # fmt: on
53 |
54 |
55 | class LLaMa2LLMBackbone(HFCausalLLMBackbone):
56 | def __init__(
57 | self,
58 | llm_backbone_id: str,
59 | llm_max_length: int = 2048,
60 | hf_token: Optional[str] = None,
61 | inference_mode: bool = False,
62 | use_flash_attention_2: bool = True,
63 | ) -> None:
64 | super().__init__(
65 | llm_backbone_id,
66 | llm_max_length=llm_max_length,
67 | hf_token=hf_token,
68 | inference_mode=inference_mode,
69 | use_flash_attention_2=use_flash_attention_2,
70 | **LLAMA2_MODELS[llm_backbone_id],
71 | )
72 |
73 | # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize)
74 | self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
75 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id
76 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
77 |
78 | @property
79 | def prompt_builder_fn(self) -> Type[PromptBuilder]:
80 | if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"):
81 | return PurePromptBuilder
82 |
83 | elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"):
84 | return LLaMa2ChatPromptBuilder
85 |
86 | elif self.identifier.startswith("vicuna"):
87 | return VicunaV15ChatPromptBuilder
88 |
89 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
90 |
91 | @property
92 | def transformer_layer_cls(self) -> Type[nn.Module]:
93 | return LlamaDecoderLayer
94 |
95 | @property
96 | def half_precision_dtype(self) -> torch.dtype:
97 | """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2."""
98 | return torch.bfloat16
99 |
100 | @property
101 | def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
102 | return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
103 |
--------------------------------------------------------------------------------
/prismatic/vla/action_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | action_tokenizer.py
3 |
4 | Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
5 | """
6 |
7 | from typing import List, Union
8 |
9 | import numpy as np
10 | from transformers import PreTrainedTokenizerBase
11 |
12 |
13 | class ActionTokenizer:
14 | def __init__(
15 | self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1
16 | ) -> None:
17 | """
18 | Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens.
19 |
20 | NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens*
21 | appear at the end of the vocabulary!
22 |
23 | :param tokenizer: Base LLM/VLM tokenizer to extend.
24 | :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy.
25 | :param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
26 | :param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
27 | """
28 | self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action
29 |
30 | # Create Uniform Bins + Compute Bin Centers
31 | self.bins = np.linspace(min_action, max_action, self.n_bins)
32 | self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
33 |
34 | # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)`
35 | # =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary!
36 | self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1))
37 |
38 | def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
39 | """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:])."""
40 | action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
41 | discretized_action = np.digitize(action, self.bins)
42 |
43 | # Handle single element vs. batch
44 | if len(discretized_action.shape) == 1:
45 | return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action))
46 | else:
47 | return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist())
48 |
49 | def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
50 | """
51 | Returns continuous actions for discrete action token IDs.
52 |
53 | NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the
54 | digitization returns bin indices between [1, # bins], inclusive, when there are actually only
55 | (# bins - 1) bin intervals.
56 |
57 | Therefore, if the digitization returns the last possible index, we map this to the last bin interval.
58 |
59 | EXAMPLE =>> Let's say self.bins has 256 values. Then self.bin_centers has 255 values. Digitization returns
60 | indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There
61 | is still one index (i==255) that would cause an out-of-bounds error if used to index into
62 | self.bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of
63 | the last bin center. We implement this simply via clipping between [0, 255 - 1].
64 | """
65 | discretized_actions = self.tokenizer.vocab_size - action_token_ids
66 | discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
67 |
68 | return self.bin_centers[discretized_actions]
69 |
70 | @property
71 | def vocab_size(self) -> int:
72 | return self.n_bins
73 |
--------------------------------------------------------------------------------
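A round-trip sketch for `ActionTokenizer`, assuming access to a LLaMA-style BPE tokenizer (the hub path below is illustrative and the checkpoint may be gated); it shows a 7-DoF action being clipped, binned into the last 256 vocabulary tokens, and decoded back to bin centers:

import numpy as np
from transformers import AutoTokenizer

from prismatic.vla.action_tokenizer import ActionTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
action_tokenizer = ActionTokenizer(tokenizer, bins=256, min_action=-1, max_action=1)

action = np.array([0.1, -0.25, 0.0, 0.5, -1.0, 1.0, 0.0])   # 7-DoF continuous action in [-1, 1]
action_text = action_tokenizer(action)                      # string rendering of the mapped "rare" tokens

# Decoding maps token IDs (vocab_size - bin index) back to bin centers, so the reconstruction error
# is bounded by the bin width (2 / 255 with the defaults above).
token_ids = tokenizer.vocab_size - np.digitize(np.clip(action, -1, 1), action_tokenizer.bins)
recovered = action_tokenizer.decode_token_ids_to_actions(token_ids)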
/prismatic/vla/datasets/rlds/traj_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | traj_transforms.py
3 |
4 | Contains trajectory transforms used in the orca data pipeline. Trajectory transforms operate on a dictionary
5 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length).
6 | """
7 |
8 | import logging
9 | from typing import Dict
10 |
11 | import tensorflow as tf
12 |
13 |
14 | def chunk_act_obs(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict:
15 | """
16 | Chunks actions and observations into the given window_size.
17 |
18 | "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1`
19 | observations from the past and the current observation. "action" is given a new axis (at index 1) of size
20 | `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current
21 | action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and
22 | indicates whether an observation should be considered padding (i.e. if it had come from a timestep
23 | before the start of the trajectory).
24 | """
25 | traj_len = tf.shape(traj["action"])[0]
26 | action_dim = traj["action"].shape[-1]
27 | effective_traj_len = traj_len - future_action_window_size
28 | chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [effective_traj_len, window_size]) + tf.broadcast_to(
29 | tf.range(effective_traj_len)[:, None], [effective_traj_len, window_size]
30 | )
31 |
32 | action_chunk_indices = tf.broadcast_to(
33 | tf.range(-window_size + 1, 1 + future_action_window_size),
34 | [effective_traj_len, window_size + future_action_window_size],
35 | ) + tf.broadcast_to(
36 | tf.range(effective_traj_len)[:, None],
37 | [effective_traj_len, window_size + future_action_window_size],
38 | )
39 |
40 | floored_chunk_indices = tf.maximum(chunk_indices, 0)
41 |
42 | goal_timestep = tf.fill([effective_traj_len], traj_len - 1)
43 |
44 | floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None])
45 |
46 | traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"])
47 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices)
48 |
49 | # indicates whether an entire observation is padding
50 | traj["observation"]["pad_mask"] = chunk_indices >= 0
51 |
52 | # Truncate other elements of the trajectory dict
53 | traj["task"] = tf.nest.map_structure(lambda x: tf.gather(x, tf.range(effective_traj_len)), traj["task"])
54 | traj["dataset_name"] = tf.gather(traj["dataset_name"], tf.range(effective_traj_len))
55 | traj["absolute_action_mask"] = tf.gather(traj["absolute_action_mask"], tf.range(effective_traj_len))
56 |
57 | return traj
58 |
59 |
60 | def subsample(traj: Dict, subsample_length: int) -> Dict:
61 | """Subsamples trajectories to the given length."""
62 | traj_len = tf.shape(traj["action"])[0]
63 | if traj_len > subsample_length:
64 | indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length]
65 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj)
66 |
67 | return traj
68 |
69 |
70 | def add_pad_mask_dict(traj: Dict) -> Dict:
71 | """
72 | Adds a dictionary indicating which elements of the observation/task should be treated as padding.
73 | =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding}
74 | """
75 | traj_len = tf.shape(traj["action"])[0]
76 |
77 | for key in ["observation", "task"]:
78 | pad_mask_dict = {}
79 | for subkey in traj[key]:
80 | # Handles "language_instruction", "image_*", and "depth_*"
81 | if traj[key][subkey].dtype == tf.string:
82 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0
83 |
84 | # All other keys should not be treated as padding
85 | else:
86 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool)
87 |
88 | traj[key]["pad_mask_dict"] = pad_mask_dict
89 |
90 | return traj
91 |
--------------------------------------------------------------------------------
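A shape sketch for `chunk_act_obs` on a toy trajectory (assumes TensorFlow in eager mode); with `window_size=2` and `future_action_window_size=3`, each of the `traj_len - 3` retained steps carries 2 stacked observations and 5 stacked actions:

import tensorflow as tf

from prismatic.vla.datasets.rlds.traj_transforms import chunk_act_obs

traj_len, action_dim = 10, 7
traj = {
    "observation": {"image_primary": tf.zeros([traj_len, 96, 96, 3], dtype=tf.uint8)},
    "task": {"language_instruction": tf.constant(["go to the door"] * traj_len)},
    "action": tf.zeros([traj_len, action_dim]),
    "dataset_name": tf.constant(["toy_dataset"] * traj_len),
    "absolute_action_mask": tf.zeros([traj_len, action_dim], dtype=tf.bool),
}

chunked = chunk_act_obs(traj, window_size=2, future_action_window_size=3)
print(chunked["observation"]["image_primary"].shape)  # (7, 2, 96, 96, 3)
print(chunked["action"].shape)                        # (7, 5, 7)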
/prismatic/util/torch_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | torch_utils.py
3 |
4 | General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch.
5 |
6 | Random `set_global_seed` functionality is taken directly from PyTorch-Lightning:
7 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
8 |
9 | This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
10 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
11 | we inject randomness from non-PyTorch sources (e.g., numpy, random)!
12 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
13 |
14 | Terminology
15 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogeneous!
16 | -> Rank :: Integer index of current process in the total world size
17 | -> Local Rank :: Local index on given node in [0, Devices per Node]
18 | """
19 |
20 | import os
21 | import random
22 | from typing import Callable, Optional
23 |
24 | import numpy as np
25 | import torch
26 |
27 | # === Randomness ===
28 |
29 |
30 | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
31 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
32 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
33 |
34 | # Set Seed as an Environment Variable
35 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 |
40 | return worker_init_function if get_worker_init_fn else None
41 |
42 |
43 | def worker_init_function(worker_id: int) -> None:
44 | """
45 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
46 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
47 |
48 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
49 | you can run iterative splitting on to get new (predictable) randomness.
50 |
51 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
52 | """
53 | # Get current `rank` (if running distributed) and `process_seed`
54 | global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()
55 |
56 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
57 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness
58 | base_seed = process_seed - worker_id
59 |
60 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
61 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
62 |
63 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
64 | np.random.seed(seed_seq.generate_state(4))
65 |
66 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
67 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
68 |
69 | # Torch manual seed takes 64 bits (so just specify a dtype of uint64)
70 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
71 |
72 | # Use 128 Bits for `random`, but express as integer instead of as an array
73 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
74 | random.seed(random_seed)
75 |
76 |
77 | # === BFloat16 Support ===
78 |
79 |
80 | def check_bloat16_supported() -> bool:
81 | try:
82 | import packaging.version
83 | import torch.cuda.nccl as nccl
84 | import torch.distributed as dist
85 |
86 | return (
87 | (torch.version.cuda is not None)
88 | and torch.cuda.is_bf16_supported()
89 | and (packaging.version.parse(torch.version.cuda).release >= (11, 0))
90 | and dist.is_nccl_available()
91 | and (nccl.version() >= (2, 10))
92 | )
93 |
94 | except Exception:
95 | return False
96 |
--------------------------------------------------------------------------------
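A seeding sketch assuming a single-process run; `LOCAL_RANK` is set manually here because `worker_init_function` reads it (torchrun/accelerate normally provide it), and a `__main__` guard is advisable on spawn-based platforms:

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from prismatic.util.torch_utils import set_global_seed

os.environ.setdefault("LOCAL_RANK", "0")
worker_init_fn = set_global_seed(7, get_worker_init_fn=True)

loader = DataLoader(
    TensorDataset(torch.arange(16).float()),
    batch_size=4,
    num_workers=2,
    worker_init_fn=worker_init_fn,  # each worker re-seeds numpy/random/torch from a distinct SeedSequence
)
for _batch in loader:
    pass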
/prismatic/vla/datasets/rlds/obs_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | obs_transforms.py
3 |
4 | Contains observation-level transforms used in the orca data pipeline.
5 |
6 | These transforms operate on the "observation" dictionary, and are applied at a per-frame level.
7 | """
8 |
9 | from typing import Dict, Tuple, Union
10 |
11 | import dlimp as dl
12 | import tensorflow as tf
13 | from absl import logging
14 |
15 |
16 | # ruff: noqa: B023
17 | def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict:
18 | """Augments images, skipping padding images."""
19 | image_names = {key[6:] for key in obs if key.startswith("image_")}
20 |
21 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed
22 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image
23 | # name to augmentation dict)
24 | if "augment_order" in augment_kwargs:
25 | augment_kwargs = {name: augment_kwargs for name in image_names}
26 |
27 | for i, name in enumerate(image_names):
28 | if name not in augment_kwargs:
29 | continue
30 | kwargs = augment_kwargs[name]
31 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}")
32 | obs[f"image_{name}"] = tf.cond(
33 | obs["pad_mask_dict"][f"image_{name}"],
34 | lambda: dl.transforms.augment_image(
35 | obs[f"image_{name}"],
36 | **kwargs,
37 | seed=seed + i, # augment each image differently
38 | ),
39 | lambda: obs[f"image_{name}"], # skip padding images
40 | )
41 |
42 | return obs
43 |
44 |
45 | def decode_and_resize(
46 | obs: Dict,
47 | resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]],
48 | depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]],
49 | ) -> Dict:
50 | """Decodes images and depth images, and then optionally resizes them."""
51 | image_names = {key[6:] for key in obs if key.startswith("image_")}
52 | depth_names = {key[6:] for key in obs if key.startswith("depth_")}
53 |
54 | if isinstance(resize_size, tuple):
55 | resize_size = {name: resize_size for name in image_names}
56 | if isinstance(depth_resize_size, tuple):
57 | depth_resize_size = {name: depth_resize_size for name in depth_names}
58 |
59 | for name in image_names:
60 | if name not in resize_size:
61 | logging.warning(
62 | f"No resize_size was provided for image_{name}. This will result in 1x1 "
63 | "padding images, which may cause errors if you mix padding and non-padding images."
64 | )
65 | image = obs[f"image_{name}"]
66 | if image.dtype == tf.string:
67 | if tf.strings.length(image) == 0:
68 | # this is a padding image
69 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8)
70 | else:
71 | image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8)
72 | elif image.dtype != tf.uint8:
73 | raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}")
74 | if name in resize_size:
75 | image = dl.transforms.resize_image(image, size=resize_size[name])
76 | obs[f"image_{name}"] = image
77 |
78 | for name in depth_names:
79 | if name not in depth_resize_size:
80 | logging.warning(
81 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 "
82 | "padding depth images, which may cause errors if you mix padding and non-padding images."
83 | )
84 | depth = obs[f"depth_{name}"]
85 |
86 | if depth.dtype == tf.string:
87 | if tf.strings.length(depth) == 0:
88 | depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32)
89 | else:
90 | depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0]
91 | elif depth.dtype != tf.float32:
92 | raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}")
93 |
94 | if name in depth_resize_size:
95 | depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name])
96 |
97 | obs[f"depth_{name}"] = depth
98 |
99 | return obs
100 |
--------------------------------------------------------------------------------
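A decode-and-resize sketch for a single frame in eager mode (assumes TensorFlow and `dlimp` are installed); the encoded primary view is decoded and resized, while the empty-string wrist view is treated as padding and replaced with zeros at the requested size:

import tensorflow as tf

from prismatic.vla.datasets.rlds.obs_transforms import decode_and_resize

obs = {
    "image_primary": tf.io.encode_jpeg(tf.zeros([128, 128, 3], dtype=tf.uint8)),  # encoded camera frame
    "image_wrist": tf.constant("", dtype=tf.string),                              # padding (missing) view
}

obs = decode_and_resize(
    obs,
    resize_size={"primary": (224, 224), "wrist": (224, 224)},
    depth_resize_size={},
)
print(obs["image_primary"].shape)  # (224, 224, 3)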
/prismatic/models/vlms/base_vlm.py:
--------------------------------------------------------------------------------
1 | """
2 | base_vlm.py
3 |
4 | Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions,
5 | and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate
6 | from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS,
7 | PALI, Fuyu) in the future.
8 |
9 | We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance
10 | (e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms),
11 | prefer Protocol definitions instead.
12 | """
13 |
14 | from __future__ import annotations
15 |
16 | from abc import ABC, abstractmethod
17 | from pathlib import Path
18 | from typing import Callable, List, Optional
19 |
20 | import torch
21 | import torch.nn as nn
22 | from transformers import GenerationMixin, PretrainedConfig
23 | from transformers.modeling_outputs import CausalLMOutputWithPast
24 |
25 | from prismatic.models.backbones.llm import LLMBackbone
26 | from prismatic.models.backbones.llm.prompting import PromptBuilder
27 | from prismatic.models.backbones.vision import VisionBackbone
28 |
29 |
30 | # === Abstract Base Class for arbitrary Vision-Language Models ===
31 | class VLM(nn.Module, GenerationMixin, ABC):
32 | def __init__(
33 | self,
34 | model_family: str,
35 | model_id: str,
36 | vision_backbone: VisionBackbone,
37 | llm_backbone: LLMBackbone,
38 | enable_mixed_precision_training: bool = True,
39 | ) -> None:
40 | super().__init__()
41 | self.model_family, self.model_id = model_family, model_id
42 | self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone
43 | self.enable_mixed_precision_training = enable_mixed_precision_training
44 |
45 | # Instance Attributes for a generic VLM
46 | self.all_module_keys, self.trainable_module_keys = None, None
47 |
48 | # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* ===
49 | self.generation_config = self.llm_backbone.llm.generation_config
50 | self.main_input_name = "input_ids"
51 |
52 | @property
53 | def device(self) -> torch.device:
54 | """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!"""
55 | return next(self.parameters()).device
56 |
57 | @classmethod
58 | @abstractmethod
59 | def from_pretrained(
60 | cls,
61 | pretrained_checkpoint: Path,
62 | model_family: str,
63 | model_id: str,
64 | vision_backbone: VisionBackbone,
65 | llm_backbone: LLMBackbone,
66 | **kwargs: str,
67 | ) -> VLM: ...
68 |
69 | @abstractmethod
70 | def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ...
71 |
72 | @abstractmethod
73 | def freeze_backbones(self, stage: str) -> None: ...
74 |
75 | @abstractmethod
76 | def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ...
77 |
78 | @abstractmethod
79 | def get_fsdp_wrapping_policy(self) -> Callable: ...
80 |
81 | @abstractmethod
82 | def forward(
83 | self,
84 | input_ids: Optional[torch.LongTensor] = None,
85 | attention_mask: Optional[torch.Tensor] = None,
86 | pixel_values: Optional[torch.FloatTensor] = None,
87 | labels: Optional[torch.LongTensor] = None,
88 | inputs_embeds: Optional[torch.FloatTensor] = None,
89 | past_key_values: Optional[List[torch.FloatTensor]] = None,
90 | use_cache: Optional[bool] = None,
91 | output_attentions: Optional[bool] = None,
92 | output_hidden_states: Optional[bool] = None,
93 | return_dict: Optional[bool] = None,
94 | multimodal_indices: Optional[torch.LongTensor] = None,
95 | ) -> CausalLMOutputWithPast: ...
96 |
97 | # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) ===
98 | @staticmethod
99 | def can_generate() -> bool:
100 | return True
101 |
102 | @property
103 | def config(self) -> PretrainedConfig:
104 | return self.llm_backbone.llm.config
105 |
106 | # => Beam Search Utility
107 | def _reorder_cache(self, past_key_values, beam_idx):
108 | return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx)
109 |
--------------------------------------------------------------------------------
/config_nav/mbra_and_dataset_config.yaml:
--------------------------------------------------------------------------------
1 | # parameters for MBRA
2 | context_size: 5 # image history length
3 | len_traj_pred: 8 # predicted action sequence length
4 | learn_angle: True
5 | obs_encoder: efficientnet-b0
6 | obs_encoding_size: 1024
7 | late_fusion: False
8 | mha_num_attention_heads: 4
9 | mha_num_attention_layers: 4
10 | mha_ff_dim_factor: 4
11 |
12 | # parameters for training dataset
13 | image_size: [96, 96] # width, height
14 | context_type: temporal
15 | normalize: True
16 | num_workers: 10
17 |
18 | # Dataset paths and settings
19 | datasets_CAST:
20 | path: /your_cast_data_location/
21 |
22 | datasets_frodobots:
23 | root: /your_frodobots_data_location
24 | horizon_short: 20
25 | horizon_long: 100
26 | frodobot:
27 | train: 11434305
28 | test: 11434305
29 | negative_mining: True
30 |
31 | datasets_bdd:
32 | train: /your_data_location/data_splits/bdd/train/
33 | test: /your_data_location/data_splits/bdd/test/
34 | image: /your_bdd_data_location/BDD_dataset
35 | pickle: /your_bdd_data_location/BDD_dataset_pickle
36 | backside: False
37 | aug_seq: False
38 | only_front: True
39 | waypoint_spacing: 0.12
40 |
41 | datasets_gnm:
42 | distance:
43 | min_dist_cat: 0
44 | max_dist_cat: 20
45 | action:
46 | min_dist_cat: 3
47 | max_dist_cat: 20
48 |
49 | recon:
50 | data_folder: /your_gnm_data_location/recon
51 | train: /your_data_location/data_splits/recon/train/
52 | test: /your_data_location/data_splits/recon/test/
53 | end_slack: 3
54 | goals_per_obs: 1
55 | negative_mining: True # negative mining from the ViNG paper (Shah et al.)
56 | go_stanford:
57 | data_folder: /your_gnm_data_location/go_stanford # datasets/stanford_go_new
58 | train: /your_data_location/data_splits/go_stanford/train/
59 | test: /your_data_location/data_splits/go_stanford/test/
60 | end_slack: 0
61 | goals_per_obs: 2
62 | negative_mining: True
63 | cory_hall:
64 | data_folder: /your_gnm_data_location/cory_hall
65 | train: /your_data_location/data_splits/cory_hall/train/
66 | test: /your_data_location/data_splits/cory_hall/test/
67 | end_slack: 3
68 | goals_per_obs: 1
69 | negative_mining: True
70 | tartan_drive:
71 | data_folder: /your_gnm_data_location/tartan_drive
72 | train: /your_data_location/data_splits/tartan_drive/train/
73 | test: /your_data_location/data_splits/tartan_drive/test/
74 | end_slack: 3
75 | goals_per_obs: 1
76 | negative_mining: True
77 | sacson:
78 | data_folder: /your_gnm_data_location/sacson
79 | train: /your_data_location/data_splits/sacson/train/
80 | test: /your_data_location/data_splits/sacson/test/
81 | end_slack: 3
82 | goals_per_obs: 1
83 | negative_mining: True
84 | seattle:
85 | data_folder: /your_gnm_data_location/seattle/
86 | train: /your_data_location/data_splits/seattle/train/
87 | test: /your_data_location/data_splits/seattle/test/
88 | end_slack: 0
89 | goals_per_obs: 1
90 | negative_mining: True
91 | scand:
92 | data_folder: /your_gnm_data_location/scand/
93 | train: /your_data_location/data_splits/scand/train/
94 | test: /your_data_location/data_splits/scand/test/
95 | end_slack: 0
96 | goals_per_obs: 1
97 | negative_mining: True
98 |
99 | datasets_lelan:
100 | go_stanford4:
101 | train: /your_data_location/data_splits/lelan_gs4/train/
102 | test: /your_data_location/data_splits/lelan_gs4/test/
103 | image: /your_lelan_data_location/dataset_LeLaN_gs4/image/
104 | pickle: /your_lelan_data_location/dataset_LeLaN_gs4/pickle_nomad/
105 | backside: False
106 | aug_seq: False
107 | only_front: False
108 |
109 | sacson:
110 | train: /your_data_location/data_splits/lelan_sacson/train/
111 | test: /your_data_location/data_splits/lelan_sacson/test/
112 | image: /your_lelan_data_location/dataset_LeLaN_sacson/
113 | pickle: /your_lelan_data_location/dataset_LeLaN_sacson/
114 | backside: False
115 | aug_seq: False
116 | only_front: False
117 |
118 | go_stanford2:
119 | train: /your_data_location/data_splits/lelan_gs2/train/
120 | test: /your_data_location/data_splits/lelan_gs2/test/
121 | image: /your_lelan_data_location/dataset_LeLaN_gs2/
122 | pickle: /your_lelan_data_location/dataset_LeLaN_gs2/
123 | backside: False
124 | aug_seq: False
125 | only_front: False
126 |
127 | humanw:
128 | train: /your_data_location/data_splits/lelan_humanw/train/
129 | test: /your_data_location/data_splits/lelan_humanw/test/
130 | image: /your_lelan_data_location/dataset_LeLaN_humanwalk/
131 | pickle: /your_lelan_data_location/dataset_LeLaN_humanwalk/
132 | backside: False
133 | aug_seq: False
134 | only_front: True
135 |
136 | youtube:
137 | train: /your_data_location/data_splits/lelan_youtube/train/
138 | test: /your_data_location/data_splits/lelan_youtube/test/
139 | image: /your_lelan_data_location/dataset_LeLaN_youtube/
140 | pickle: /your_data_location/pickle_youtube/
141 | backside: False
142 | aug_seq: False
143 | only_front: True
144 |
--------------------------------------------------------------------------------
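A loading sketch for the config above, assuming PyYAML is available and the repository root is the working directory; nested dataset entries come back as plain dictionaries:

import yaml

with open("config_nav/mbra_and_dataset_config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["context_size"], cfg["len_traj_pred"])    # 5 8
print(cfg["datasets_gnm"]["recon"]["data_folder"])  # /your_gnm_data_location/recon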
/prismatic/conf/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | datasets.py
3 |
4 | Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant
5 | and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes:
6 | - Dataset Variant (Identifier) --> e.g., "llava-v15"
7 | - Align Stage Dataset Components (annotations, images)
8 | - Finetune Stage Dataset Components (annotations, images)
9 | - Dataset Root Directory (Path)
10 | """
11 |
12 | from dataclasses import dataclass
13 | from enum import Enum, unique
14 | from pathlib import Path
15 | from typing import Tuple
16 |
17 | from draccus import ChoiceRegistry
18 |
19 |
20 | @dataclass
21 | class DatasetConfig(ChoiceRegistry):
22 | # fmt: off
23 | dataset_id: str # Unique ID that fully specifies a dataset variant
24 |
25 | # Dataset Components for each Stage in < align | finetune >
26 | align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage
27 | finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage
28 |
29 | dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root
30 | # fmt: on
31 |
32 |
33 | # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
34 | @dataclass
35 | class LLaVa_V15_Config(DatasetConfig):
36 | dataset_id: str = "llava-v15"
37 |
38 | align_stage_components: Tuple[Path, Path] = (
39 | Path("download/llava-laion-cc-sbu-558k/chat.json"),
40 | Path("download/llava-laion-cc-sbu-558k/"),
41 | )
42 | finetune_stage_components: Tuple[Path, Path] = (
43 | Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"),
44 | Path("download/llava-v1.5-instruct/"),
45 | )
46 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
47 |
48 |
49 | # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training)
50 | @dataclass
51 | class LLaVa_Multimodal_Only_Config(DatasetConfig):
52 | dataset_id: str = "llava-multimodal"
53 |
54 | align_stage_components: Tuple[Path, Path] = (
55 | Path("download/llava-laion-cc-sbu-558k/chat.json"),
56 | Path("download/llava-laion-cc-sbu-558k/"),
57 | )
58 | finetune_stage_components: Tuple[Path, Path] = (
59 | Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
60 | Path("download/llava-v1.5-instruct/"),
61 | )
62 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
63 |
64 |
65 | # LLaVa-v15 + LVIS-Instruct-4V
66 | @dataclass
67 | class LLaVa_LVIS4V_Config(DatasetConfig):
68 | dataset_id: str = "llava-lvis4v"
69 |
70 | align_stage_components: Tuple[Path, Path] = (
71 | Path("download/llava-laion-cc-sbu-558k/chat.json"),
72 | Path("download/llava-laion-cc-sbu-558k/"),
73 | )
74 | finetune_stage_components: Tuple[Path, Path] = (
75 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
76 | Path("download/llava-v1.5-instruct/"),
77 | )
78 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
79 |
80 |
81 | # LLaVa-v15 + LRV-Instruct
82 | @dataclass
83 | class LLaVa_LRV_Config(DatasetConfig):
84 | dataset_id: str = "llava-lrv"
85 |
86 | align_stage_components: Tuple[Path, Path] = (
87 | Path("download/llava-laion-cc-sbu-558k/chat.json"),
88 | Path("download/llava-laion-cc-sbu-558k/"),
89 | )
90 | finetune_stage_components: Tuple[Path, Path] = (
91 | Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
92 | Path("download/llava-v1.5-instruct/"),
93 | )
94 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
95 |
96 |
97 | # LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
98 | @dataclass
99 | class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
100 | dataset_id: str = "llava-lvis4v-lrv"
101 |
102 | align_stage_components: Tuple[Path, Path] = (
103 | Path("download/llava-laion-cc-sbu-558k/chat.json"),
104 | Path("download/llava-laion-cc-sbu-558k/"),
105 | )
106 | finetune_stage_components: Tuple[Path, Path] = (
107 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
108 | Path("download/llava-v1.5-instruct/"),
109 | )
110 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
111 |
112 |
113 | # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! ===
114 | @unique
115 | class DatasetRegistry(Enum):
116 | # === LLaVa v1.5 ===
117 | LLAVA_V15 = LLaVa_V15_Config
118 |
119 | LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config
120 |
121 | LLAVA_LVIS4V = LLaVa_LVIS4V_Config
122 | LLAVA_LRV = LLaVa_LRV_Config
123 |
124 | LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config
125 |
126 | @property
127 | def dataset_id(self) -> str:
128 | return self.value.dataset_id
129 |
130 |
131 | # Register Datasets in Choice Registry
132 | for dataset_variant in DatasetRegistry:
133 | DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value)
134 |
--------------------------------------------------------------------------------
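A minimal sketch of working with the registry: each config is a plain dataclass, so it can be instantiated and inspected directly, while draccus resolves variants by `dataset_id` through the `ChoiceRegistry` when parsing configs:

from prismatic.conf.datasets import DatasetRegistry, LLaVa_V15_Config

cfg = LLaVa_V15_Config()
print(cfg.dataset_id)                 # llava-v15
print(cfg.align_stage_components[0])  # download/llava-laion-cc-sbu-558k/chat.json

# Enumerate every registered variant through the enum
print([variant.dataset_id for variant in DatasetRegistry])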
/prismatic/models/materialize.py:
--------------------------------------------------------------------------------
1 | """
2 | materialize.py
3 |
4 | Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports
5 | individual functions for clear control flow.
6 | """
7 |
8 | from typing import Optional, Tuple
9 |
10 | from transformers import PreTrainedTokenizerBase
11 |
12 | from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
13 | from prismatic.models.backbones.vision import (
14 | CLIPViTBackbone,
15 | DinoCLIPViTBackbone,
16 | DinoSigLIPViTBackbone,
17 | DinoV2ViTBackbone,
18 | ImageTransform,
19 | IN1KViTBackbone,
20 | SigLIPViTBackbone,
21 | VisionBackbone,
22 | )
23 | from prismatic.models.vlms import PrismaticVLM
24 |
25 | # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs ===
26 | # fmt: off
27 |
28 | # === Vision Backbone Registry ===
29 | VISION_BACKBONES = {
30 | # === 224px Backbones ===
31 | "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
32 | "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
33 | "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}},
34 | "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}},
35 | "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
36 |
37 | # === Assorted CLIP Backbones ===
38 | "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
39 | "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}},
40 |
41 | # === Assorted SigLIP Backbones ===
42 | "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
43 | "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}},
44 | "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
45 | "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
46 |
47 | # === Fused Backbones ===
48 | "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}},
49 | "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
50 | }
51 |
52 |
53 | # === Language Model Registry ===
54 | LLM_BACKBONES = {
55 | # === LLaMa-2 Pure (Non-Chat) Backbones ===
56 | "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
57 | "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
58 |
59 | # === LLaMa-2 Chat Backbones ===
60 | "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
61 | "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
62 |
63 | # === Vicuna-v1.5 Backbones ===
64 | "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
65 | "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
66 |
67 | # === Mistral v0.1 Backbones ===
68 | "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
69 | "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
70 |
71 | # === Phi-2 Backbone ===
72 | "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
73 | }
74 |
75 | # fmt: on
76 |
77 |
78 | def get_vision_backbone_and_transform(
79 | vision_backbone_id: str, image_resize_strategy: str
80 | ) -> Tuple[VisionBackbone, ImageTransform]:
81 | """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform."""
82 | if vision_backbone_id in VISION_BACKBONES:
83 | vision_cfg = VISION_BACKBONES[vision_backbone_id]
84 | vision_backbone: VisionBackbone = vision_cfg["cls"](
85 | vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"]
86 | )
87 | image_transform = vision_backbone.get_image_transform()
88 | return vision_backbone, image_transform
89 |
90 | else:
91 | raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!")
92 |
93 |
94 | def get_llm_backbone_and_tokenizer(
95 | llm_backbone_id: str,
96 | llm_max_length: int = 2048,
97 | hf_token: Optional[str] = None,
98 | inference_mode: bool = False,
99 | ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]:
100 | if llm_backbone_id in LLM_BACKBONES:
101 | llm_cfg = LLM_BACKBONES[llm_backbone_id]
102 | llm_backbone: LLMBackbone = llm_cfg["cls"](
103 | llm_backbone_id,
104 | llm_max_length=llm_max_length,
105 | hf_token=hf_token,
106 | inference_mode=inference_mode,
107 | **llm_cfg["kwargs"],
108 | )
109 | tokenizer = llm_backbone.get_tokenizer()
110 | return llm_backbone, tokenizer
111 |
112 | else:
113 | raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!")
114 |
115 |
116 | def get_vlm(
117 | model_id: str,
118 | arch_specifier: str,
119 | vision_backbone: VisionBackbone,
120 | llm_backbone: LLMBackbone,
121 | enable_mixed_precision_training: bool = True,
122 | ) -> PrismaticVLM:
123 | """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM)."""
124 | return PrismaticVLM(
125 | model_id,
126 | vision_backbone,
127 | llm_backbone,
128 | enable_mixed_precision_training=enable_mixed_precision_training,
129 | arch_specifier=arch_specifier,
130 | )
131 |
--------------------------------------------------------------------------------
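An end-to-end factory sketch, assuming weights download from the HF Hub (gated LLaMa-2 checkpoints need an `hf_token`); the `image_resize_strategy`, `arch_specifier`, and model ID strings below are illustrative values rather than an exhaustive list:

from prismatic.models.materialize import (
    get_llm_backbone_and_tokenizer,
    get_vision_backbone_and_transform,
    get_vlm,
)

vision_backbone, image_transform = get_vision_backbone_and_transform(
    "dinosiglip-vit-so-224px", image_resize_strategy="resize-naive"
)
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer("llama2-7b-pure", hf_token="<your HF token>")
vlm = get_vlm(
    "prism-dinosiglip-224px-7b",
    arch_specifier="no-align+gelu-mlp",
    vision_backbone=vision_backbone,
    llm_backbone=llm_backbone,
)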
/prismatic/overwatch/overwatch.py:
--------------------------------------------------------------------------------
1 | """
2 | overwatch.py
3 |
4 | Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler.
5 | """
6 |
7 | import logging
8 | import logging.config
9 | import os
10 | from contextlib import nullcontext
11 | from logging import LoggerAdapter
12 | from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union
13 |
14 | # Overwatch Default Format String
15 | RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]"
16 |
17 | # Set Logging Configuration
18 | LOG_CONFIG = {
19 | "version": 1,
20 | "disable_existing_loggers": True,
21 | "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}},
22 | "handlers": {
23 | "console": {
24 | "class": "rich.logging.RichHandler",
25 | "formatter": "simple-console",
26 | "markup": True,
27 | "rich_tracebacks": True,
28 | "show_level": True,
29 | "show_path": True,
30 | "show_time": True,
31 | }
32 | },
33 | "root": {"level": "INFO", "handlers": ["console"]},
34 | }
35 | logging.config.dictConfig(LOG_CONFIG)
36 |
37 |
38 | # === Custom Contextual Logging Logic ===
39 | class ContextAdapter(LoggerAdapter):
40 | CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}}
41 |
42 | def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]:
43 | ctx_level = kwargs.pop("ctx_level", 0)
44 | return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs
45 |
46 |
47 | class DistributedOverwatch:
48 | def __init__(self, name: str) -> None:
49 | """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`."""
50 | from accelerate import PartialState
51 |
52 | # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun`
53 | # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all!
54 | self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState()
55 |
56 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually)
57 | self.debug = self.logger.debug
58 | self.info = self.logger.info
59 | self.warning = self.logger.warning
60 | self.error = self.logger.error
61 | self.critical = self.logger.critical
62 |
63 | # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others!
64 | self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR)
65 |
66 | @property
67 | def rank_zero_only(self) -> Callable[..., Any]:
68 | return self.distributed_state.on_main_process
69 |
70 | @property
71 | def local_zero_only(self) -> Callable[..., Any]:
72 | return self.distributed_state.on_local_main_process
73 |
74 | @property
75 | def rank_zero_first(self) -> Callable[..., Any]:
76 | return self.distributed_state.main_process_first
77 |
78 | @property
79 | def local_zero_first(self) -> Callable[..., Any]:
80 | return self.distributed_state.local_main_process_first
81 |
82 | def is_rank_zero(self) -> bool:
83 | return self.distributed_state.is_main_process
84 |
85 | def rank(self) -> int:
86 | return self.distributed_state.process_index
87 |
88 | def local_rank(self) -> int:
89 | return self.distributed_state.local_process_index
90 |
91 | def world_size(self) -> int:
92 | return self.distributed_state.num_processes
93 |
94 |
95 | class PureOverwatch:
96 | def __init__(self, name: str) -> None:
97 | """Initializer for an Overwatch object that just wraps logging."""
98 | self.logger = ContextAdapter(logging.getLogger(name), extra={})
99 |
100 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually)
101 | self.debug = self.logger.debug
102 | self.info = self.logger.info
103 | self.warning = self.logger.warning
104 | self.error = self.logger.error
105 | self.critical = self.logger.critical
106 |
107 | # Logging Defaults =>> INFO
108 | self.logger.setLevel(logging.INFO)
109 |
110 | @staticmethod
111 | def get_identity_ctx() -> Callable[..., Any]:
112 | def identity(fn: Callable[..., Any]) -> Callable[..., Any]:
113 | return fn
114 |
115 | return identity
116 |
117 | @property
118 | def rank_zero_only(self) -> Callable[..., Any]:
119 | return self.get_identity_ctx()
120 |
121 | @property
122 | def local_zero_only(self) -> Callable[..., Any]:
123 | return self.get_identity_ctx()
124 |
125 | @property
126 | def rank_zero_first(self) -> Callable[..., Any]:
127 | return nullcontext
128 |
129 | @property
130 | def local_zero_first(self) -> Callable[..., Any]:
131 | return nullcontext
132 |
133 | @staticmethod
134 | def is_rank_zero() -> bool:
135 | return True
136 |
137 | @staticmethod
138 | def rank() -> int:
139 | return 0
140 |
141 | @staticmethod
142 | def world_size() -> int:
143 | return 1
144 |
145 |
146 | def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]:
147 | return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name)
148 |
--------------------------------------------------------------------------------
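A logging sketch: in a single-process run (no `WORLD_SIZE` in the environment) `initialize_overwatch` returns a `PureOverwatch`, and `ctx_level` controls the indentation prefix added by `ContextAdapter`:

from prismatic.overwatch import initialize_overwatch

overwatch = initialize_overwatch(__name__)
overwatch.info("Loading VLA checkpoint")             # rendered with the "[*] " prefix
overwatch.info("Found cached weights", ctx_level=1)  # rendered with an indented "|=> " prefix

if overwatch.is_rank_zero():
    pass  # rank-gated work (checkpoint writes, metric logging, ...)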
/prismatic/models/vlas/openvla.py:
--------------------------------------------------------------------------------
1 | """
2 | openvla.py
3 |
4 | PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around
5 | discretizing actions with the ActionTokenizer.
6 | """
7 |
8 | from typing import Dict, List, Optional
9 |
10 | import numpy as np
11 | import torch
12 | from PIL import Image
13 | from transformers import LlamaTokenizerFast
14 |
15 | from prismatic.models.vlms.prismatic import PrismaticVLM
16 | from prismatic.overwatch import initialize_overwatch
17 | from prismatic.vla.action_tokenizer import ActionTokenizer
18 |
19 | # Initialize Overwatch =>> Wraps `logging.Logger`
20 | overwatch = initialize_overwatch(__name__)
21 |
22 |
23 | class OpenVLA(PrismaticVLM):
24 | def __init__(
25 | self,
26 | *args,
27 | norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]],
28 | action_tokenizer: ActionTokenizer,
29 | **kwargs,
30 | ) -> None:
31 | super().__init__(*args, **kwargs)
32 | self.norm_stats = norm_stats
33 | self.action_tokenizer = action_tokenizer
34 |
35 | @torch.inference_mode()
36 | def predict_action(
37 | self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str
38 | ) -> np.ndarray:
39 | """
40 | Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes).
41 |
42 | @param image: PIL Image as [height, width, 3]
43 | @param instruction: Task instruction string
44 | @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model
45 | was trained only on a single dataset, and retrieves those statistics.
46 |
47 | @return Unnormalized (continuous) action vector --> end-effector deltas.
48 | """
49 | image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
50 |
51 | # Build VLA Prompt
52 | prompt_builder = self.get_prompt_builder()
53 | prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?")
54 | prompt_text = prompt_builder.get_prompt()
55 |
56 | # Prepare Inputs
57 | input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
58 | if isinstance(tokenizer, LlamaTokenizerFast):
59 | # If the special empty token ('') does not already appear after the colon (':') token in the prompt
60 | # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
61 | if not torch.all(input_ids[:, -1] == 29871):
62 | input_ids = torch.cat(
63 | (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
64 | )
65 | else:
66 | raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}")
67 |
68 | # Preprocess Image
69 | pixel_values = image_transform(image)
70 | if isinstance(pixel_values, torch.Tensor):
71 | pixel_values = pixel_values[None, ...].to(self.device)
72 | elif isinstance(pixel_values, dict):
73 | pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
74 | else:
75 | raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
76 |
77 | # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
78 | autocast_dtype = self.llm_backbone.half_precision_dtype
79 | with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
80 | # fmt: off
81 | generated_ids = super(PrismaticVLM, self).generate(
82 | input_ids=input_ids, # Shape: [1, seq]
83 | pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, ...]
84 | max_new_tokens=self.get_action_dim(unnorm_key),
85 | **kwargs
86 | )
87 | # fmt: on
88 |
89 | # Extract predicted action tokens and translate into (normalized) continuous actions
90 | predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :]
91 | normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy())
92 |
93 | # Un-normalize Actions
94 | action_norm_stats = self.get_action_stats(unnorm_key)
95 | mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
96 | action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
97 | actions = np.where(
98 | mask,
99 | 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
100 | normalized_actions,
101 | )
102 |
103 | return actions
104 |
105 | @staticmethod
106 | def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str:
107 | if unnorm_key is None:
108 | assert len(norm_stats) == 1, (
109 | f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following "
110 | f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}"
111 | )
112 | unnorm_key = next(iter(norm_stats.keys()))
113 |
114 | # Error Handling
115 | assert (
116 | unnorm_key in norm_stats
117 | ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}"
118 |
119 | return unnorm_key
120 |
121 | def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
122 | """Dimensionality of the policy's action space."""
123 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
124 |
125 | return len(self.norm_stats[unnorm_key]["action"]["q01"])
126 |
127 | def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict:
128 | """Dimensionality of the policy's action space."""
129 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
130 |
131 | return self.norm_stats[unnorm_key]["action"]
132 |
--------------------------------------------------------------------------------
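An inference sketch for `predict_action`, assuming `vla` is an already-loaded `OpenVLA` instance (e.g., restored through the repository's checkpoint-loading utilities) that was trained on a single dataset, so `unnorm_key` can be omitted:

from PIL import Image

image = Image.open("inference/current_img.jpg")  # RGB observation from the robot's camera
action = vla.predict_action(image, instruction="move toward the open doorway", do_sample=False)
print(action.shape)  # (action_dim,) -- unnormalized continuous action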
/prismatic/vla/datasets/rlds/oxe/materialize.py:
--------------------------------------------------------------------------------
1 | """
2 | materialize.py
3 |
4 | Factory functions for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for
5 | clear control flow.
6 | """
7 |
8 | from copy import deepcopy
9 | from pathlib import Path
10 | from typing import Any, Dict, List, Tuple
11 |
12 | from prismatic.overwatch import initialize_overwatch
13 | from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, STOP_INDEX
14 | from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding
15 | from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS
16 |
17 | # Initialize Overwatch =>> Wraps `logging.Logger`
18 | overwatch = initialize_overwatch(__name__)
19 |
20 |
21 | def make_oxe_dataset_kwargs(
22 | dataset_name: str,
23 | data_root_dir: Path,
24 | load_camera_views: Tuple[str] = ("primary",),
25 | load_depth: bool = False,
26 | load_proprio: bool = True,
27 | load_language: bool = True,
28 | action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE,
29 | ) -> Dict[str, Any]:
30 | """Generates config (kwargs) for given dataset from Open-X Embodiment."""
31 | dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name])
32 | if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL]:
33 | raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 & JOINT_POS_BIMANUAL actions supported!")
34 |
35 | # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute!
36 | # Normalize all action dimensions *except* the gripper
37 | if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS:
38 | dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True]
39 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False]
40 | elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6:
41 | dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True]
42 | dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False]
43 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL:
44 | dataset_kwargs["absolute_action_mask"] = [True] * 14
45 | dataset_kwargs["action_normalization_mask"] = [True] * 14
46 | dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type
47 |
48 | # Adjust Loaded Camera Views
49 | if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0:
50 | raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`")
51 |
52 | # Filter
53 | dataset_kwargs["image_obs_keys"] = {
54 | k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views
55 | }
56 | dataset_kwargs["depth_obs_keys"] = {
57 | k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views
58 | }
59 |
60 | # Eliminate Unnecessary Keys
61 | dataset_kwargs.pop("state_encoding")
62 | dataset_kwargs.pop("action_encoding")
63 | if not load_depth:
64 | dataset_kwargs.pop("depth_obs_keys")
65 | if not load_proprio:
66 | dataset_kwargs.pop("state_obs_keys")
67 |
68 | # Load Language
69 | if load_language:
70 | dataset_kwargs["language_key"] = "language_instruction"
71 |
72 | # Specify Standardization Transform
73 | dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[dataset_name]
74 |
75 | # Add any aux arguments
76 | if "aux_kwargs" in dataset_kwargs:
77 | dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs"))
78 |
79 | return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs}
80 |
81 |
82 | def get_oxe_dataset_kwargs_and_weights(
83 | data_root_dir: Path,
84 | mixture_spec: List[Tuple[str, float]],
85 | load_camera_views: Tuple[str] = ("primary",),
86 | load_depth: bool = False,
87 | load_proprio: bool = True,
88 | load_language: bool = True,
89 | action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE,
90 | ) -> Tuple[Dict[str, Any], List[float]]:
91 | """
92 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs
93 | (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`.
94 |
95 | :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X)
96 | :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES`
97 | :param load_camera_views: Camera views to load; see `oxe.dataset_configs.py` for available views.
98 | :param load_depth: Load depth information in addition to camera RGB.
99 | :param load_proprio: Load proprioceptive state.
100 | :param load_language: Load language instructions.
101 | :param action_proprio_normalization_type: Normalization scheme to use for proprioceptive actions.
102 |
103 | :return: Tuple of (per_dataset_kwargs, sampling_weights)
104 | """
105 | included_datasets, filtered_mixture_spec = set(), []
106 | for d_name, d_weight in mixture_spec:
107 | if d_name in included_datasets:
108 | overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`")
109 | continue
110 |
111 | included_datasets.add(d_name)
112 | filtered_mixture_spec.append((d_name, d_weight))
113 |
114 | # Assemble Dataset Config (kwargs) and Weights
115 | per_dataset_kwargs, sampling_weights = [], []
116 | for d_name, d_weight in filtered_mixture_spec:
117 | try:
118 | per_dataset_kwargs.append(
119 | make_oxe_dataset_kwargs(
120 | d_name,
121 | data_root_dir,
122 | load_camera_views,
123 | load_depth,
124 | load_proprio,
125 | load_language,
126 | action_proprio_normalization_type,
127 | )
128 | )
129 | sampling_weights.append(d_weight)
130 |
131 | except ValueError as e:
132 | overwatch.warning(f"Skipping `{d_name}` due to Error: {e}")
133 |
134 | return per_dataset_kwargs, sampling_weights
135 |
--------------------------------------------------------------------------------
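
A minimal usage sketch of the factory above. The mixture spec, dataset name, and data root here are hypothetical placeholders; in practice the mixture comes from `oxe.mixtures.OXE_NAMED_MIXTURES` and dataset names must exist in `OXE_DATASET_CONFIGS`:

```
from pathlib import Path

from prismatic.vla.datasets.rlds.oxe.materialize import get_oxe_dataset_kwargs_and_weights

# Hypothetical (dataset_name, sampling_weight) pairs
mixture_spec = [
    ("bridge_orig", 1.0),
    ("bridge_orig", 0.5),  # duplicate entry => skipped with a warning
]

per_dataset_kwargs, sampling_weights = get_oxe_dataset_kwargs_and_weights(
    data_root_dir=Path("/data/oxe"),  # placeholder path
    mixture_spec=mixture_spec,
    load_camera_views=("primary",),
)
# `per_dataset_kwargs` and `sampling_weights` are then passed to `make_interleaved_dataset`
```
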
/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py:
--------------------------------------------------------------------------------
1 | """Episode transforms for DROID dataset."""
2 |
3 | from typing import Any, Dict
4 |
5 | import tensorflow as tf
6 | import tensorflow_graphics.geometry.transformation as tfg
7 |
8 |
9 | def rmat_to_euler(rot_mat):
10 | return tfg.euler.from_rotation_matrix(rot_mat)
11 |
12 |
13 | def euler_to_rmat(euler):
14 | return tfg.rotation_matrix_3d.from_euler(euler)
15 |
16 |
17 | def invert_rmat(rot_mat):
18 | return tfg.rotation_matrix_3d.inverse(rot_mat)
19 |
20 |
21 | def rotmat_to_rot6d(mat):
22 | """
23 | Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
24 | Args:
25 | mat: rotation matrix
26 |
27 | Returns: 6d vector (first two rows of rotation matrix)
28 |
29 | """
30 | r6 = mat[..., :2, :]
31 | r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
32 | r6_flat = tf.concat([r6_0, r6_1], axis=-1)
33 | return r6_flat
34 |
35 |
36 | def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
37 | """
38 | Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
39 | Args:
40 | velocity: 6d velocity action (3 x translation, 3 x rotation)
41 | wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
42 |
43 | Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
44 |
45 | """
46 | R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
47 | R_frame_inv = invert_rmat(R_frame)
48 |
49 | # world to wrist: dT_pi = R^-1 dT_rbt
50 | vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0]
51 |
52 | # world to wrist: dR_pi = R^-1 dR_rbt R
53 | dR = euler_to_rmat(velocity[:, 3:6])
54 | dR = R_frame_inv @ (dR @ R_frame)
55 | dR_r6 = rotmat_to_rot6d(dR)
56 | return tf.concat([vel_t, dR_r6], axis=-1)
57 |
58 |
59 | def rand_swap_exterior_images(img1, img2):
60 | """
61 | Randomly swaps the two exterior images (for training with single exterior input).
62 | """
63 | return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
64 |
65 |
66 | def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
67 | """
68 | DROID dataset transformation for actions expressed in *base* frame of the robot.
69 | """
70 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
71 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
72 |
73 | trajectory["action"] = tf.concat(
74 | (
75 | dt,
76 | dR,
77 | 1 - trajectory["action_dict"]["gripper_position"],
78 | ),
79 | axis=-1,
80 | )
81 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
82 | rand_swap_exterior_images(
83 | trajectory["observation"]["exterior_image_1_left"],
84 | trajectory["observation"]["exterior_image_2_left"],
85 | )
86 | )
87 | trajectory["observation"]["proprio"] = tf.concat(
88 | (
89 | trajectory["observation"]["cartesian_position"],
90 | trajectory["observation"]["gripper_position"],
91 | ),
92 | axis=-1,
93 | )
94 | return trajectory
95 |
96 |
97 | def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
98 | """
99 | DROID dataset transformation for actions expressed in *wrist* frame of the robot.
100 | """
101 | wrist_act = velocity_act_to_wrist_frame(
102 | trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
103 | )
104 | trajectory["action"] = tf.concat(
105 | (
106 | wrist_act,
107 | trajectory["action_dict"]["gripper_position"],
108 | ),
109 | axis=-1,
110 | )
111 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
112 | rand_swap_exterior_images(
113 | trajectory["observation"]["exterior_image_1_left"],
114 | trajectory["observation"]["exterior_image_2_left"],
115 | )
116 | )
117 | trajectory["observation"]["proprio"] = tf.concat(
118 | (
119 | trajectory["observation"]["cartesian_position"],
120 | trajectory["observation"]["gripper_position"],
121 | ),
122 | axis=-1,
123 | )
124 | return trajectory
125 |
126 |
127 | def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
128 | """
129 | DROID dataset transformation for actions expressed in *base* frame of the robot.
130 | """
131 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
132 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
133 | trajectory["action"] = tf.concat(
134 | (
135 | dt,
136 | dR,
137 | 1 - trajectory["action_dict"]["gripper_position"],
138 | ),
139 | axis=-1,
140 | )
141 | trajectory["observation"]["proprio"] = tf.concat(
142 | (
143 | trajectory["observation"]["cartesian_position"],
144 | trajectory["observation"]["gripper_position"],
145 | ),
146 | axis=-1,
147 | )
148 | return trajectory
149 |
150 |
151 | def zero_action_filter(traj: Dict) -> bool:
152 | """
153 | Filters transitions whose actions are all-0 (only relative actions, no gripper action).
154 | Note: this filter is applied *after* action normalization, so we need to compare to the "normalized 0".
155 | """
156 | DROID_Q01 = tf.convert_to_tensor(
157 | [
158 | -0.7776297926902771,
159 | -0.5803514122962952,
160 | -0.5795090794563293,
161 | -0.6464047729969025,
162 | -0.7041108310222626,
163 | -0.8895104378461838,
164 | ]
165 | )
166 | DROID_Q99 = tf.convert_to_tensor(
167 | [
168 | 0.7597932070493698,
169 | 0.5726242214441299,
170 | 0.7351000607013702,
171 | 0.6705610305070877,
172 | 0.6464948207139969,
173 | 0.8897542208433151,
174 | ]
175 | )
176 | DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
177 |
178 | return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
179 |
--------------------------------------------------------------------------------
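
As a TensorFlow-free sanity check on the rotation utilities above, the sketch below re-derives the R6 representation and the base-to-wrist transform dR_wrist = R_frame^-1 · dR · R_frame in NumPy; the rotation angles are arbitrary placeholders:

```
import numpy as np


def rot_z(theta):
    """Rotation matrix about the z-axis (placeholder example rotation)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotmat_to_rot6d(mat):
    # First two rows of the rotation matrix, flattened -> 6D representation
    return mat[:2, :].reshape(-1)


R_frame = rot_z(0.3)  # wrist orientation in the base frame (placeholder)
dR_base = rot_z(0.1)  # incremental rotation expressed in the base frame

# Same incremental rotation expressed in the wrist frame: dR_pi = R^-1 dR R
dR_wrist = R_frame.T @ dR_base @ R_frame

print(rotmat_to_rot6d(dR_wrist))  # 6D rotation component of the wrist-frame action
```
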
/prismatic/extern/hf/configuration_prismatic.py:
--------------------------------------------------------------------------------
1 | """
2 | configuration_prismatic.py
3 |
4 | HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
5 | Default configuration specifies `siglip-224px+7b`.
6 | """
7 |
8 | from typing import Any, Dict, List, Optional
9 |
10 | from transformers import PretrainedConfig
11 | from transformers.models.auto import CONFIG_MAPPING
12 |
13 | # === Utilities for Mapping Prismatic names to HF names ===
14 | # fmt: off
15 | VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
16 | "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
17 |
18 | "clip-vit-l-336px": [336],
19 | "siglip-vit-so400m-384px": [384],
20 |
21 | "dinoclip-vit-l-336px": [336, 336],
22 | "dinosiglip-vit-so-224px": [224, 224],
23 | "dinosiglip-vit-so-384px": [384, 384],
24 | }
25 | VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
26 | "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
27 | "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
28 |
29 | "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
30 | "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
31 |
32 | "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
33 | "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
34 |
35 | "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
36 | "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
37 | "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
38 | }
39 | TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
40 | "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
41 | "dinov2-vit-l": [None], "in1k-vit-l": [None],
42 | "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
43 | "dinoclip-vit-l-336px": [None, "quick_gelu"],
44 | "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
45 | }
46 |
47 | LLM_BACKBONE_TO_HF_PATH = {
48 | "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
49 | "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
50 |
51 | "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
52 |
53 | "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
54 | "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
55 |
56 | "phi-2-3b": "microsoft/phi-2",
57 | }
58 | LLM_BACKBONE_TO_HF_METACLASS = {
59 | "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
60 | "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
61 |
62 | "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
63 |
64 | "phi-2-3b": "phi",
65 | }
66 |
67 | VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
68 | VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
69 | # fmt: on
70 |
71 |
72 | class PrismaticConfig(PretrainedConfig):
73 | model_type: str = "prismatic"
74 | is_composition: bool = False
75 |
76 | def __init__(
77 | self,
78 | vision_backbone_id: str = "siglip-vit-so400m",
79 | llm_backbone_id: str = "vicuna-v15-7b",
80 | arch_specifier: str = "no-align+gelu-mlp",
81 | use_fused_vision_backbone: Optional[bool] = None,
82 | image_resize_strategy: str = "letterbox",
83 | text_config: Optional[Dict[str, Any]] = None,
84 | llm_max_length: int = 2048,
85 | pad_token_id: int = 32000,
86 | pad_to_multiple_of: int = 64,
87 | output_projector_states: bool = False,
88 | **kwargs: str,
89 | ) -> None:
90 | if vision_backbone_id not in VALID_VISION_BACKBONES:
91 | raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
92 |
93 | if llm_backbone_id not in VALID_LLM_BACKBONES:
94 | raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
95 |
96 | # Set Prismatic Configuration Fields
97 | self.vision_backbone_id = vision_backbone_id
98 | self.llm_backbone_id = llm_backbone_id
99 | self.arch_specifier = arch_specifier
100 | self.output_projector_states = output_projector_states
101 |
102 | # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
103 | self.use_fused_vision_backbone = (
104 | use_fused_vision_backbone
105 | if use_fused_vision_backbone is not None
106 | else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
107 | )
108 |
109 | self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
110 | self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
111 | self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
112 | self.image_resize_strategy = image_resize_strategy
113 |
114 | self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
115 | self.llm_max_length = llm_max_length
116 | self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
117 |
118 | # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
119 | self.text_config = (
120 | CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
121 | if text_config is not None
122 | else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
123 | )
124 |
125 | # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
126 | super().__init__(pad_token_id=pad_token_id, **kwargs)
127 |
128 |
129 | class OpenVLAConfig(PrismaticConfig):
130 | model_type: str = "openvla"
131 |
132 | def __init__(
133 | self,
134 | norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
135 | n_action_bins: int = 256,
136 | **kwargs: str,
137 | ) -> None:
138 | self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
139 |
140 | super().__init__(**kwargs)
141 |
--------------------------------------------------------------------------------
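
A short sketch of how these configuration classes are typically instantiated (requires `transformers`); the backbone identifiers are drawn from the registries above, and `norm_stats` is left empty here even though it would normally be populated from the training datasets:

```
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig, PrismaticConfig

# Defaults resolve to `siglip-vit-so400m` + `vicuna-v15-7b`
cfg = PrismaticConfig()
print(cfg.timm_model_ids, cfg.hf_llm_id)

# OpenVLA adds action de-normalization statistics and the number of action bins
vla_cfg = OpenVLAConfig(
    vision_backbone_id="dinosiglip-vit-so-224px",
    llm_backbone_id="llama2-7b-pure",
    n_action_bins=256,
    norm_stats=None,  # normally filled in from the training datasets
)
print(vla_cfg.use_fused_vision_backbone)  # True for the fused dino* backbones
```
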
/prismatic/models/action_heads.py:
--------------------------------------------------------------------------------
1 | """Implementations of various action heads, which serve as alternatives to VLM sequential token prediction."""
2 |
3 | import math
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | #from diffusers.schedulers.scheduling_ddim import DDIMScheduler
9 | from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, STOP_INDEX
10 |
11 |
12 | class SinusoidalPositionalEncoding(nn.Module):
13 | """
14 | Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps.
15 |
16 | For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,)
17 | Then the output would be a batch of 32 timestep embeddings -> shape (32, D)
18 |
19 | Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py
20 | """
21 |
22 | def __init__(self, dim):
23 | super().__init__()
24 | self.dim = dim # dimensionality of the positional encoding
25 |
26 | def forward(self, x):
27 | # x: (batch_size,)
28 | device = x.device
29 | assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}"
30 | half_dim = self.dim // 2
31 | exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1) # shape: (D/2,)
32 | emb = torch.exp(exponent) # shape: (D/2,)
33 | emb = x[:, None] * emb[None, :] # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2)
34 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # shape: (batch_size, D)
35 | return emb
36 |
37 |
38 | class MLPResNetBlock(nn.Module):
39 | """One MLP ResNet block with a residual connection."""
40 | def __init__(self, dim):
41 | super().__init__()
42 | self.dim = dim
43 | self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers
44 | nn.LayerNorm(dim),
45 | nn.Linear(dim, dim),
46 | nn.ReLU(),
47 | )
48 |
49 | def forward(self, x):
50 | # x: (batch_size, hidden_dim)
51 | # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as
52 | # described here: https://arxiv.org/pdf/2002.04745.pdf
53 | identity = x
54 | x = self.ffn(x)
55 | x = x + identity
56 | return x
57 |
58 |
59 | class MLPResNet(nn.Module):
60 | """MLP with residual connection blocks."""
61 | def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
62 | super().__init__()
63 | self.layer_norm1 = nn.LayerNorm(input_dim)
64 | self.fc1 = nn.Linear(input_dim, hidden_dim)
65 | self.relu = nn.ReLU()
66 | self.mlp_resnet_blocks = nn.ModuleList()
67 | for _ in range(num_blocks):
68 | self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim))
69 | self.layer_norm2 = nn.LayerNorm(hidden_dim)
70 | self.fc2 = nn.Linear(hidden_dim, output_dim)
71 |
72 | def forward(self, x):
73 | # x: (batch_size, input_dim)
74 | x = self.layer_norm1(x) # shape: (batch_size, input_dim)
75 | x = self.fc1(x) # shape: (batch_size, hidden_dim)
76 | x = self.relu(x) # shape: (batch_size, hidden_dim)
77 | for block in self.mlp_resnet_blocks:
78 | x = block(x) # shape: (batch_size, hidden_dim)
79 | x = self.layer_norm2(x) # shape: (batch_size, hidden_dim)
80 | x = self.fc2(x) # shape: (batch_size, output_dim)
81 | return x
82 |
83 | class MLPResNet_idcat(nn.Module):
84 | """MLP with residual connection blocks."""
85 | def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
86 | super().__init__()
87 | self.layer_norm1 = nn.LayerNorm(input_dim)
88 | self.fc1 = nn.Linear(input_dim + 1, hidden_dim)
89 | self.relu = nn.ReLU()
90 | self.mlp_resnet_blocks = nn.ModuleList()
91 | for _ in range(num_blocks):
92 | self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim))
93 | self.layer_norm2 = nn.LayerNorm(hidden_dim)
94 | self.fc2 = nn.Linear(hidden_dim, output_dim)
95 |
96 | def forward(self, x, taskid):
97 | # x: (batch_size, input_dim)
98 | x = self.layer_norm1(x) # shape: (batch_size, input_dim)
99 | x = torch.cat((x, taskid.unsqueeze(1).unsqueeze(2).repeat(1, 8, 1)), dim=2)  # append the task ID to every chunk position (assumes NUM_ACTIONS_CHUNK == 8)
100 | x = self.fc1(x) # shape: (batch_size, hidden_dim)
101 | x = self.relu(x) # shape: (batch_size, hidden_dim)
102 | for block in self.mlp_resnet_blocks:
103 | x = block(x) # shape: (batch_size, hidden_dim)
104 | x = self.layer_norm2(x) # shape: (batch_size, hidden_dim)
105 | x = self.fc2(x) # shape: (batch_size, output_dim)
106 | return x
107 |
108 | class L1RegressionActionHead_idcat(nn.Module):
109 | """Simple MLP-based action head that generates continuous actions via L1 regression."""
110 | def __init__(
111 | self,
112 | input_dim=4096,
113 | hidden_dim=4096,
114 | action_dim=7,
115 | ):
116 | super().__init__()
117 | self.action_dim = action_dim
118 | self.model = MLPResNet_idcat(
119 | num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim
120 | )
121 |
122 | def predict_action(self, actions_hidden_states, taskid):
123 | # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence
124 | # - shape: (batch_size, chunk_len * action_dim, hidden_dim)
125 | # taskid: per-example scalar task identifier, broadcast across the action chunk
126 | #   - shape: (batch_size,)
127 | batch_size = actions_hidden_states.shape[0]
128 | device = actions_hidden_states.device
129 | rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)
130 | action = self.model(rearranged_actions_hidden_states, taskid)
131 | return action
132 |
133 | class L1RegressionDistHead(nn.Module):
134 | """Simple MLP-based action head that generates continuous actions via L1 regression."""
135 | def __init__(
136 | self,
137 | input_dim=4096,
138 | hidden_dim=4096,
139 | action_dim=1,
140 | ):
141 | super().__init__()
142 | self.action_dim = action_dim
143 | self.model = MLPResNet(
144 | num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim
145 | )
146 |
147 | def predict_action(self, actions_hidden_states):
148 | # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence
149 | # - shape: (batch_size, chunk_len * action_dim, hidden_dim)
150 | # returns: scalar distance prediction per example, averaged over the chunk dimension
151 | #   - shape: (batch_size,)
152 | batch_size = actions_hidden_states.shape[0]
153 | device = actions_hidden_states.device
154 | rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)
155 | dist = self.model(rearranged_actions_hidden_states).squeeze(2)
156 | dist_ave = dist.mean(dim=1, keepdim=False)
157 | return dist_ave
158 |
--------------------------------------------------------------------------------
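
A small shape-level usage sketch for two of the modules above, with dummy tensors. It assumes the repository constants are NUM_ACTIONS_CHUNK = 8 and ACTION_DIM = 7, with a hidden size of 4096 (matching the defaults and the hard-coded repeat above), so the printed shapes change if those constants differ:

```
import torch

from prismatic.models.action_heads import L1RegressionDistHead, SinusoidalPositionalEncoding

# Timestep embedding: a batch of 4 scalar timesteps -> (4, 32) embeddings
pe = SinusoidalPositionalEncoding(dim=32)
print(pe(torch.tensor([0.0, 1.0, 2.0, 3.0])).shape)  # torch.Size([4, 32])

# Distance head: last hidden states for NUM_ACTIONS_CHUNK * ACTION_DIM action tokens
hidden_states = torch.randn(2, 8 * 7, 4096)  # (batch, chunk_len * action_dim, hidden_dim)
dist_head = L1RegressionDistHead(input_dim=4096, hidden_dim=4096, action_dim=1)
print(dist_head.predict_action(hidden_states).shape)  # torch.Size([2])
```
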
/README.md:
--------------------------------------------------------------------------------
1 | # OmniVLA: An Omni-Modal Vision-Language-Action Model for Robot Navigation
2 | [](https://www.python.org)
3 | [](https://opensource.org/licenses/MIT)
4 | [](https://omnivla-nav.github.io)
5 |
6 |
7 | [Noriaki Hirose](https://sites.google.com/view/noriaki-hirose/)1, 2, [Catherine Glossop](https://catglossop.github.io/)1, [Dhruv Shah](https://robodhruv.github.io/)3, [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/)1
8 |
9 | 1 UC Berkeley (_Berkeley AI Research_), 2 Toyota Motor North America, 3 Princeton University
10 |
11 | ### Installation
12 | Please set up a conda environment (see instructions in [SETUP.md](SETUP.md)).
13 |
14 | ### Inference
15 | 1. Download our checkpoints and place them in this repository's directory. "omnivla-original" contains the OmniVLA checkpoints trained for the paper submission, "omnivla-original-balance" contains the OmniVLA checkpoints trained to account for the data balance in the LeLaN dataset, and "omnivla-finetuned-cast" contains the checkpoints finetuned on the [CAST](https://huggingface.co/datasets/catglossop/CAST-dataset) dataset.
16 | ```
17 | git clone https://huggingface.co/NHirose/omnivla-original
18 | git clone https://huggingface.co/NHirose/omnivla-original-balance
19 | git clone https://huggingface.co/NHirose/omnivla-finetuned-cast
20 | ```
21 | 2. Run OmniVLA using a sample current image, goal images, GPS pose, and language prompt. You can view the generated trajectory in the output figure 1_ex.jpg.
22 | ```
23 | python inference/run_omnivla.py
24 | ```
25 | 3. Change the goal modality: by default, our code generates actions based on the language prompt. To use a different modality, modify the settings around line 560 of "run_omnivla.py".
26 |
27 | 4. Run OmniVLA to control a real robot. Modify "run_omnivla.py" to update the robot’s state (camera image, GPS signal) and adjust the goal information accordingly. Then, feed the generated velocity commands to your robot.
28 |
29 | 5. To try the checkpoints finetuned on the CAST dataset, update the path and step number in "InferenceConfig" within "run_omnivla.py".
30 |
31 | ### Training
32 | We provide the training code along with a sample dataloader to help you quickly understand the required data loading structure. Since preparing the full training dataset is resource-intensive, we include this simplified code base for convenience.
33 |
34 | 1. Download the MBRA project code base:
35 | ```
36 | cd ..
37 | git clone https://github.com/NHirose/Learning-to-Drive-Anywhere-with-MBRA.git
38 | ```
39 | 2. Download the MBRA model:
40 | ```
41 | cd OmniVLA_internal
42 | git clone https://huggingface.co/NHirose/MBRA/
43 | ```
44 | 3. You can set the training or debugging mode at line 10 in vla-scripts/train_omnivla.py. Note that even in debugging mode, the code requires at least 20 GB of GPU memory (we use an NVIDIA RTX 4090).
45 |
46 | 4. You can configure visualization at line 11 in vla-scripts/train_omnivla.py. During training, it should be set to False.
47 |
48 | 5. Train our policy from the OpenVLA checkpoints (please fill in each X):
49 | ```
50 | torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/train_omnivla.py --vla_path openvla/openvla-7b --dataset_name omnivla --num_images_in_input 2 --batch_size X --wandb_entity "X" --wandb_project "omnivla"
51 | ```
52 | 6. Finetune our OmniVLA (please fill in each X):
53 | ```
54 | torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/train_omnivla.py --vla_path ./omnivla-original --dataset_name omnivla --num_images_in_input 2 --batch_size X --wandb_entity "X" --wandb_project "omnivla"
55 | ```
56 | 7. For reference, the command we use to finetune OmniVLA on our large navigation dataset:
57 | ```
58 | conda activate omnivla_2
59 | cd /media/noriaki/Noriaki_Data/OmniVLA
60 | torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/train_omnivla_dataset.py --vla_path ./omnivla-original --dataset_name omnivla --wandb_entity "noriaki-hirose" --wandb_project "omnivla"
61 | ```
62 |
63 | ### Training with GNM, LeLaN, Frodobots, BDD and CAST datasets
64 | We provide training code that supports multiple public datasets. Before following the full training process, please first ensure that you can run the example training with the sample dataloader.
65 |
66 | 1. Download all datasets from their original websites ([GNM](https://github.com/robodhruv/visualnav-transformer), [LeLaN](https://github.com/NHirose/learning-language-navigation), [Frodobots](https://github.com/NHirose/Learning-to-Drive-Anywhere-with-MBRA), [CAST](https://huggingface.co/datasets/catglossop/CAST-dataset)). Please verify that the downloaded datasets work properly in their original codebases, except for the BDD dataset. Note: please download the LeLaN dataset from [this link](https://huggingface.co/datasets/NHirose/LeLaN_dataset_NoMaD_traj/tree/main) instead of [the original link](https://drive.google.com/file/d/1IazHcIyPGO7ENswz8_sGCIGBXF8_sZJK/view). The updated dataset already includes the NoMaD trajectories used for collision-avoidance supervision, so you no longer need to compute the NoMaD policy during training. Please carefully follow the usage procedure described in the [LeLaN codebase](https://github.com/NHirose/learning-language-navigation) when working with the dataset.
67 |
68 | 2. Download the modified BDD dataset with MBRA annotations from [here](https://huggingface.co/datasets/NHirose/BDD_OmniVLA) and extract it. The image sequences in the modified dataset remain subject to the [original BDD license](http://bdd-data.berkeley.edu/download.html), while the additional MBRA annotations are released under the MIT license.
69 |
70 | 3. Download the lerobot code base for the Frodobots dataset dataloader:
71 | ```
72 | git clone https://github.com/huggingface/lerobot.git
73 | ```
74 | 4. Edit the data path in config_nav/mbra_and_dataset_config.yaml.
75 |
76 | 5. Train our policy from the OpenVLA checkpoints (please fill in each X):
77 | ```
78 | torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/train_omnivla_dataset.py --vla_path ./omnivla-original --dataset_name omnivla --wandb_entity "X" --wandb_project "omnivla"
79 | ```
80 |
81 | In our training setup, we use 8 NVIDIA H100 GPUs (80 GB each) across 8 nodes. The per-dataset batch sizes are configured as [LeLaN, GNM, Frodobots, BDD] = [4, 1, 1, 1], with gradient accumulation set to 4 steps. When finetuning with the CAST dataset, we set the batch sizes to [LeLaN, CAST, GNM, Frodobots, BDD] = [2, 2, 1, 1, 1]. To change these values, edit train_omnivla_dataset.py directly.
82 |
83 | ### Acknowledgement
84 | We implement our ideas and design choices on top of pretrained checkpoints. Our work builds upon the [OpenVLA-OFT](https://openvla-oft.github.io/) codebase, with additional code added to create OmniVLA; as such, our implementation leverages many of its components. We sincerely appreciate the effort and contributions of the OpenVLA-OFT team!
85 |
86 | ## Citing
87 | ```
88 | @misc{hirose2025omnivla,
89 | title={OmniVLA: An Omni-Modal Vision-Language-Action Model for Robot Navigation},
90 | author={Noriaki Hirose and Catherine Glossop and Dhruv Shah and Sergey Levine},
91 | year={2025},
92 | eprint={2509.19480},
93 | archivePrefix={arXiv},
94 | primaryClass={cs.RO},
95 | url={https://arxiv.org/abs/2509.19480},
96 | }
97 | ```
--------------------------------------------------------------------------------
/prismatic/training/strategies/ddp.py:
--------------------------------------------------------------------------------
1 | """
2 | ddp.py
3 |
4 | Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most
5 | GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP.
6 | """
7 |
8 | import shutil
9 | from pathlib import Path
10 | from typing import Optional
11 |
12 | import torch
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.optim import AdamW
15 | from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup
16 |
17 | from prismatic.overwatch import initialize_overwatch
18 | from prismatic.training.strategies.base_strategy import TrainingStrategy
19 |
20 | # Initialize Overwatch =>> Wraps `logging.Logger`
21 | overwatch = initialize_overwatch(__name__)
22 |
23 |
24 | class DDPStrategy(TrainingStrategy):
25 | @overwatch.rank_zero_only
26 | def save_checkpoint(
27 | self,
28 | run_dir: Path,
29 | global_step: int,
30 | epoch: int,
31 | train_loss: Optional[float] = None,
32 | only_trainable: bool = True,
33 | ) -> None:
34 | """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default."""
35 | assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!"
36 |
37 | # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`)
38 | model_state_dicts = {
39 | mkey: getattr(self.vlm.module, mkey).state_dict()
40 | for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys)
41 | }
42 | optimizer_state_dict = self.optimizer.state_dict()
43 |
44 | # Set Checkpoint Path =>> Embed *minimal* training statistics!
45 | checkpoint_dir = run_dir / "checkpoints"
46 | if train_loss is None:
47 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt"
48 | else:
49 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt"
50 |
51 | # Save Checkpoint & Copy Latest to `latest-checkpoint.pt`
52 | torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path)
53 | shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt")
54 |
55 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None:
56 | # Gradient Checkpointing Setup
57 | if self.enable_gradient_checkpointing:
58 | # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up
59 | # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF
60 | # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable`
61 | # on `self.llm_backbone`.
62 | #
63 | # What does it actually do? --> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic
64 | # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706
65 | #
66 | # Additional Reference (to better understand gradient checkpointing in PyTorch writ large)
67 | # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
68 | overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1)
69 | self.vlm.llm_backbone.gradient_checkpointing_enable()
70 |
71 | # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate)
72 | overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1)
73 | self.vlm.to(self.device_id)
74 |
75 | # Wrap with Distributed Data Parallel
76 | # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that
77 | # is the same size/dtype as the model parameters; this will *double* GPU memory!
78 | # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel
79 | overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1)
80 | self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True)
81 |
82 | # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs`
83 | # => Optimizer should only operate on parameters that are *unfrozen* / trainable!
84 | trainable_params = [param for param in self.vlm.parameters() if param.requires_grad]
85 | if self.max_steps is None:
86 | num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size
87 | else:
88 | num_training_steps = self.max_steps
89 |
90 | if self.lr_scheduler_type == "linear-warmup+cosine-decay":
91 | # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05)
92 | num_warmup_steps = int(num_training_steps * self.warmup_ratio)
93 |
94 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!"
95 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay)
96 | self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps)
97 | for param_group in self.optimizer.param_groups:
98 | param_group["lr"] = 0.0
99 |
100 | elif self.lr_scheduler_type == "constant":
101 | num_warmup_steps = 0
102 |
103 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!"
104 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay)
105 | self.lr_scheduler = get_constant_schedule(self.optimizer)
106 |
107 | else:
108 | raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!")
109 |
110 | # Finalize Setup =>> Log
111 | overwatch.info(
112 | "DDP Strategy =>> Finalized Training Setup:\n"
113 | f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
114 | f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
115 | f" |-> Distributed World Size = {overwatch.world_size()}\n"
116 | f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
117 | f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
118 | f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n"
119 | f" |-> Default AdamW LR = {self.learning_rate}\n"
120 | f" |-> AdamW Weight Decay = {self.weight_decay}\n"
121 | f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
122 | f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"
123 | f" |-> Dataset Size = {n_train_examples} Examples\n"
124 | f" |-> Max Steps = {num_training_steps}\n"
125 | )
126 |
127 | def clip_grad_norm(self) -> None:
128 | torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm)
129 |
--------------------------------------------------------------------------------
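
A quick arithmetic sketch of how the scheduler lengths above are derived when `max_steps` is unset; the run configuration below is illustrative only:

```
# Hypothetical run configuration (illustrative numbers only)
n_train_examples = 100_000
epochs = 2
global_batch_size = 256  # per_device_batch_size * world_size * grad_accumulation_steps
warmup_ratio = 0.03

num_training_steps = (n_train_examples * epochs) // global_batch_size  # 781
num_warmup_steps = int(num_training_steps * warmup_ratio)              # 23

print(num_training_steps, num_warmup_steps)
```
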
/prismatic/vla/datasets/cast_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torchvision.transforms.functional as TF
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from typing import Type
8 | from prismatic.vla.constants import IGNORE_INDEX
9 | from prismatic.models.backbones.llm.prompting import PromptBuilder
10 | from prismatic.vla.action_tokenizer import ActionTokenizer
11 | from transformers import PreTrainedTokenizerBase
12 | from prismatic.models.backbones.vision import ImageTransform
13 | from vint_train.data.data_utils import calculate_sin_cos, to_local_coords
14 |
15 |
16 | class CAST_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | action_tokenizer: PreTrainedTokenizerBase,
20 | base_tokenizer: ActionTokenizer,
21 | image_transform: ImageTransform,
22 | prompt_builder_fn: Type[PromptBuilder],
23 | dataset_name,
24 | data_loc,
25 | data_size,
26 | features,
27 | predict_stop_token: bool = True,
28 | ):
29 | self.dataset_name = dataset_name
30 | self.data_loc = data_loc
31 | self.data_size = data_size
32 | self.features = features
33 | self.image_size = (96, 96)
34 | self.image_size_clip = (224, 224)
35 |
36 | self.action_tokenizer = action_tokenizer
37 | self.base_tokenizer = base_tokenizer
38 | self.prompt_builder = prompt_builder_fn
39 | self.predict_stop_token = predict_stop_token
40 | self.image_transform = image_transform
41 |
42 | def __len__(self):
43 | return self.data_size
44 |
45 | def _resize_norm(self, image, size):
46 | return TF.resize(image, size)
47 |
48 | def _compute_actions(self, action_yaw, goal_pose, metric_waypoint):
49 | positions = action_yaw[:, 0:2]
50 | yaw = action_yaw[:, 2]
51 |
52 | waypoints = to_local_coords(positions, positions[0], yaw[0])
53 | goal_pos = to_local_coords(goal_pose[:, 0:2], positions[0], yaw[0])
54 |
55 | yaw = yaw[1:] - yaw[0]
56 | actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1)
57 | yawg = goal_pose[:, 2:3] - yaw[0]
58 | goal_pos = np.concatenate([goal_pos, yawg], axis=1)
59 |
60 | actions[:, :2] /= metric_waypoint
61 | goal_pos[:, :2] /= metric_waypoint
62 |
63 | return torch.from_numpy(actions), torch.from_numpy(goal_pos)
64 |
65 | def __getitem__(self, idx):
66 | folder_name = self.dataset_name.split("_convert")
67 | directory_location = self.data_loc + self.dataset_name + "/" + folder_name[0] + "/"
68 |
69 | len_action = 0
70 | while len_action < 10:
71 | traj = np.load(directory_location + f"traj_{idx:06d}.npz", allow_pickle=True)
72 | len_action = len(traj['action'])
73 | if len_action < 10:
74 | idx = random.randint(0, self.data_size - 1)
75 |
76 | num = random.randint(0, len(traj['action']) - 8 - 2)
77 | gid = max(len(traj['action']) - 1, num + 8)
78 |
79 | obs_dict = traj["observation"].item()
80 | cur_pilimg = obs_dict['image'][num]
81 | goal_pilimg = obs_dict['image'][gid]
82 |
83 | cur_obs = cur_pilimg.transpose(2, 0, 1)
84 | goal_obs = goal_pilimg.transpose(2, 0, 1)
85 |
86 | pil_img = Image.fromarray(cur_pilimg.astype(np.uint8)).resize(self.image_size_clip)
87 | pil_img_goal = Image.fromarray(goal_pilimg.astype(np.uint8)).resize(self.image_size_clip)
88 |
89 | pixel_values = self.image_transform(pil_img)
90 | pixel_values_g = self.image_transform(pil_img_goal)
91 |
92 | action_yaw = obs_dict['state'][num: num+8+1]
93 | goal_pose = obs_dict['state'][gid:gid+1]
94 |
95 | actions_norm, goal_pose_norm = self._compute_actions(
96 | action_yaw, goal_pose, traj["normalization_factor"]
97 | )
98 | actions_torch = calculate_sin_cos(actions_norm)
99 | goal_pose_torch = calculate_sin_cos(goal_pose_norm)
100 |
101 | language_instruction = traj['language_instruction'][0]
102 | non_empty_prompts = [p for p in language_instruction if p]
103 | selected_prompt = random.choice(non_empty_prompts).decode('utf-8')
104 |
105 | actions = actions_torch
106 | current_action = actions[0]
107 | future_actions = actions[1:]
108 | future_actions_string = ''.join(self.action_tokenizer(future_actions))
109 | current_action_string = self.action_tokenizer(current_action)
110 | action_chunk_string = current_action_string + future_actions_string
111 | action_chunk_len = len(action_chunk_string)
112 |
113 | lang = selected_prompt.lower()
114 | conversation = [
115 | {"from": "human", "value": f"What action should the robot take to {lang}?"},
116 | {"from": "gpt", "value": action_chunk_string},
117 | ]
118 |
119 | prompt_builder = self.prompt_builder("openvla")
120 | for turn in conversation:
121 | prompt_builder.add_turn(turn["from"], turn["value"])
122 |
123 | input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids
124 | labels = list(input_ids)
125 |
126 | max_token = 60
127 | if len(input_ids) > max_token:
128 | lang = "move toward XXXXX"
129 | conversation = [
130 | {"from": "human", "value": f"What action should the robot take to {lang}?"},
131 | {"from": "gpt", "value": action_chunk_string},
132 | ]
133 | prompt_builder = self.prompt_builder("openvla")
134 | for turn in conversation:
135 | prompt_builder.add_turn(turn["from"], turn["value"])
136 |
137 | input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids
138 | labels = list(input_ids)
139 |
140 | input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)
141 |
142 | obj_pose_norm = goal_pose_torch[0, 0:2]
143 | goal_pose_cos_sin = goal_pose_torch.squeeze(0)
144 |
145 | labels[: -(action_chunk_len + 1)] = IGNORE_INDEX
146 | if not self.predict_stop_token:
147 | labels[-1] = IGNORE_INDEX
148 |
149 | goal_id = 0
150 | cur_image_r = self._resize_norm(torch.from_numpy(cur_obs), self.image_size).repeat(6,1,1)/255.0
151 | goal_image_full_8_r = self._resize_norm(torch.from_numpy(goal_obs), self.image_size)/255.0
152 |
153 | dataset_name = "cast"
154 | modality_id = 7 # language only
155 | action_select_mask = torch.tensor(1.0)
156 |
157 | return dict(
158 | pixel_values=pixel_values,
159 | pixel_values_goal=pixel_values_g,
160 | input_ids=input_ids,
161 | labels=labels,
162 | dataset_name=dataset_name,
163 | modality_id=modality_id,
164 | actions=torch.as_tensor(actions_torch),
165 | action_select_mask=action_select_mask,
166 | goal_pose=goal_pose_cos_sin,
167 | obj_pose_norm=obj_pose_norm,
168 | img_PIL=pil_img,
169 | gimg_PIL=pil_img_goal,
170 | cur_image=cur_image_r,
171 | goal_image_8=goal_image_full_8_r,
172 | temp_dist=goal_id,
173 | lan_prompt=lang
174 | )
175 |
176 |
--------------------------------------------------------------------------------
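
To make the action computation in `_compute_actions` more concrete, here is a small standalone NumPy sketch of expressing (x, y, yaw) poses in the local frame of the first pose and expanding the relative yaw into (sin, cos). This is a simplified re-implementation of what the `to_local_coords` / `calculate_sin_cos` helpers from `vint_train` are assumed to do, not the library code itself, and the trajectory values are placeholders:

```
import numpy as np


def to_local(points_xy, origin_xy, origin_yaw):
    """Rotate/translate 2D points into the frame defined by (origin_xy, origin_yaw)."""
    c, s = np.cos(origin_yaw), np.sin(origin_yaw)
    rot = np.array([[c, s], [-s, c]])  # world -> local rotation (rotation by -origin_yaw)
    return (points_xy - origin_xy) @ rot.T


# Placeholder trajectory: (x, y, yaw) at 4 timesteps
state = np.array(
    [[0.0, 0.0, 0.0],
     [0.5, 0.1, 0.1],
     [1.0, 0.3, 0.2],
     [1.5, 0.6, 0.3]]
)

waypoints = to_local(state[:, :2], state[0, :2], state[0, 2])
rel_yaw = state[1:, 2] - state[0, 2]

# Actions: local waypoints (skipping the origin) + relative yaw as (sin, cos)
actions = np.concatenate([waypoints[1:], np.sin(rel_yaw)[:, None], np.cos(rel_yaw)[:, None]], axis=-1)
print(actions.shape)  # (3, 4)
```
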
/prismatic/models/backbones/vision/dinoclip_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | dinoclip_vit.py
3 |
4 | Vision backbone that returns concatenated features from both DINOv2 and CLIP.
5 | """
6 |
7 | from dataclasses import dataclass
8 | from functools import partial
9 | from typing import Callable, Dict, Tuple
10 |
11 | import timm
12 | import torch
13 | from PIL import Image
14 | from timm.models.vision_transformer import Block, VisionTransformer
15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
16 | from torchvision.transforms import Compose, Resize
17 |
18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
19 |
20 | # Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers)
21 | DINOCLIP_VISION_BACKBONES = {
22 | "dinoclip-vit-l-336px": {
23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
24 | "clip": "vit_large_patch14_clip_336.openai",
25 | },
26 | }
27 |
28 |
29 | @dataclass
30 | class DinoCLIPImageTransform:
31 | dino_image_transform: ImageTransform
32 | clip_image_transform: ImageTransform
33 | is_prismatic: bool = True
34 |
35 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
36 | return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)}
37 |
38 |
39 | class DinoCLIPViTBackbone(VisionBackbone):
40 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
41 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
42 | self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"]
43 | self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"]
44 |
45 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
46 | self.dino_featurizer: VisionTransformer = timm.create_model(
47 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
48 | )
49 | self.dino_featurizer.eval()
50 |
51 | self.clip_featurizer: VisionTransformer = timm.create_model(
52 | self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
53 | )
54 | self.clip_featurizer.eval()
55 |
56 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
57 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
58 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
59 | self.dino_featurizer.forward = unpack_tuple(
60 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2})
61 | )
62 | self.clip_featurizer.forward = unpack_tuple(
63 | partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2})
64 | )
65 |
66 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
67 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
68 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
69 |
70 | self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer)
71 | self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
72 |
73 | # Initialize *both* Transforms
74 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
75 | default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False)
76 | if self.image_resize_strategy == "resize-naive":
77 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
78 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!"
79 | assert isinstance(default_dino_transform.transforms[0], Resize)
80 | assert isinstance(default_clip_transform.transforms[0], Resize)
81 |
82 | target_size = (self.default_image_size, self.default_image_size)
83 | dino_transform = Compose(
84 | [
85 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation),
86 | *default_dino_transform.transforms[1:],
87 | ]
88 | )
89 | clip_transform = Compose(
90 | [
91 | Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation),
92 | *default_clip_transform.transforms[1:],
93 | ]
94 | )
95 |
96 | self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform)
97 |
98 | elif self.image_resize_strategy == "resize-crop":
99 | self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform)
100 |
101 | elif self.image_resize_strategy == "letterbox":
102 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!"
103 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!"
104 | assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!"
105 |
106 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
107 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]])
108 | clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]])
109 |
110 | # Build New Transform
111 | self.image_transform = DinoCLIPImageTransform(
112 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]),
113 | Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]),
114 | )
115 |
116 | else:
117 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
118 |
119 | def get_fsdp_wrapping_policy(self) -> Callable:
120 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
121 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
122 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
123 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
124 |
125 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
126 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
127 | dino_patches = self.dino_featurizer(pixel_values["dino"])
128 | clip_patches = self.clip_featurizer(pixel_values["clip"])
129 |
130 | return torch.cat([dino_patches, clip_patches], dim=2)
131 |
132 | @property
133 | def default_image_resolution(self) -> Tuple[int, int, int]:
134 | return self.dino_data_cfg["input_size"]
135 |
136 | @property
137 | def embed_dim(self) -> int:
138 | return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim
139 |
140 | @property
141 | def num_patches(self) -> int:
142 | assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches
143 | return self.dino_featurizer.patch_embed.num_patches
144 |
145 | @property
146 | def half_precision_dtype(self) -> torch.dtype:
147 | return torch.bfloat16
148 |
--------------------------------------------------------------------------------
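
A shape-level sketch of how the fused backbone above is consumed: the image transform returns a dict with one tensor per featurizer, and `forward` concatenates the per-backbone patch features along the embedding dimension. The tensors below are random stand-ins (no TIMM weights are downloaded), assuming the 336px variant with 14x14 patches and a ViT-L embedding width of 1024 for both featurizers:

```
import torch

# Dummy per-backbone patch features: 336px / patch 14 -> 24 x 24 = 576 patches
batch_size, num_patches, embed_dim = 2, 576, 1024
dino_patches = torch.randn(batch_size, num_patches, embed_dim)
clip_patches = torch.randn(batch_size, num_patches, embed_dim)

# Equivalent of DinoCLIPViTBackbone.forward: concatenate along the embedding dimension
fused = torch.cat([dino_patches, clip_patches], dim=2)
print(fused.shape)  # torch.Size([2, 576, 2048]) == (batch, num_patches, dino_dim + clip_dim)
```
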
/prismatic/models/backbones/vision/dinosiglip_vit.py:
--------------------------------------------------------------------------------
1 | """
2 | dinosiglip_vit.py
3 |
4 | Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
5 | """
6 |
7 | from dataclasses import dataclass
8 | from functools import partial
9 | from typing import Callable, Dict, Tuple
10 |
11 | import timm
12 | import torch
13 | from PIL import Image
14 | from timm.models.vision_transformer import Block, VisionTransformer
15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
16 | from torchvision.transforms import Compose, Resize
17 |
18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
19 |
20 | # Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
21 | DINOSigLIP_VISION_BACKBONES = {
22 | "dinosiglip-vit-so-224px": {
23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
24 | "siglip": "vit_so400m_patch14_siglip_224",
25 | },
26 | "dinosiglip-vit-so-384px": {
27 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
28 | "siglip": "vit_so400m_patch14_siglip_384",
29 | },
30 | }
31 |
32 |
33 | @dataclass
34 | class DinoSigLIPImageTransform:
35 | dino_image_transform: ImageTransform
36 | siglip_image_transform: ImageTransform
37 | is_prismatic: bool = True
38 |
39 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
40 | return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)}
41 |
42 |
43 | class DinoSigLIPViTBackbone(VisionBackbone):
44 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
45 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
46 | self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"]
47 | self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"]
48 |
49 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
50 | self.dino_featurizer: VisionTransformer = timm.create_model(
51 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
52 | )
53 | self.dino_featurizer.eval()
54 |
55 | self.siglip_featurizer: VisionTransformer = timm.create_model(
56 | self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
57 | )
58 | self.siglip_featurizer.eval()
59 |
60 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
61 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
62 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
63 | self.dino_featurizer.forward = unpack_tuple(
64 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2})
65 | )
66 | self.siglip_featurizer.forward = unpack_tuple(
67 | partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2})
68 | )
69 |
70 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
71 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
72 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
73 |
74 | self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
75 | self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
76 |
77 | # Initialize *both* Transforms
78 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
79 | default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)
80 |
81 | # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
82 |         assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!"
83 | assert isinstance(default_siglip_transform.transforms[0], Resize)
84 | default_siglip_transform = Compose(
85 | [
86 | Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation),
87 | *default_siglip_transform.transforms[1:],
88 | ]
89 | )
90 |
91 | if self.image_resize_strategy == "resize-naive":
92 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
93 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!"
94 | assert isinstance(default_dino_transform.transforms[0], Resize)
95 | assert isinstance(default_siglip_transform.transforms[0], Resize)
96 |
97 | target_size = (self.default_image_size, self.default_image_size)
98 | dino_transform = Compose(
99 | [
100 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation),
101 | *default_dino_transform.transforms[1:],
102 | ]
103 | )
104 | siglip_transform = Compose(
105 | [
106 | Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation),
107 | *default_siglip_transform.transforms[1:],
108 | ]
109 | )
110 |
111 | self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform)
112 |
113 | elif self.image_resize_strategy == "resize-crop":
114 | self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform)
115 |
116 | elif self.image_resize_strategy == "letterbox":
117 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!"
118 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!"
119 | assert (
120 | "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg
121 | ), "DinoSigLIP `data_cfg` missing `mean`!"
122 |
123 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
124 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]])
125 | siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]])
126 |
127 | # Build New Transform
128 | self.image_transform = DinoSigLIPImageTransform(
129 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]),
130 | Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]),
131 | )
132 |
133 | else:
134 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
135 |
136 | def get_fsdp_wrapping_policy(self) -> Callable:
137 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
138 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
139 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
140 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
141 |
142 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
143 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
144 | dino_patches = self.dino_featurizer(pixel_values["dino"])
145 | siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
146 |
147 | return torch.cat([dino_patches, siglip_patches], dim=2)
148 |
149 | @property
150 | def default_image_resolution(self) -> Tuple[int, int, int]:
151 | return self.dino_data_cfg["input_size"]
152 |
153 | @property
154 | def embed_dim(self) -> int:
155 | return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
156 |
157 | @property
158 | def num_patches(self) -> int:
159 | assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
160 | return self.dino_featurizer.patch_embed.num_patches
161 |
162 | @property
163 | def half_precision_dtype(self) -> torch.dtype:
164 | return torch.bfloat16
165 |
--------------------------------------------------------------------------------
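
Usage note (not part of the repository): the backbone above is normally constructed through the Prismatic materializers, but a minimal sketch of exercising it directly looks like the following; the image path reuses `inference/current_img.jpg` from this repo, and the first call downloads both TIMM checkpoints.

# Minimal sketch (assumption): run one image through the fused DINOv2 + SigLIP backbone.
import torch
from PIL import Image

from prismatic.models.backbones.vision.dinosiglip_vit import DinoSigLIPViTBackbone

backbone = DinoSigLIPViTBackbone("dinosiglip-vit-so-224px", "resize-naive", default_image_size=224)
transform = backbone.get_image_transform()  # DinoSigLIPImageTransform returning {"dino": ..., "siglip": ...}

pixel_values = transform(Image.open("inference/current_img.jpg").convert("RGB"))
pixel_values = {k: v.unsqueeze(0) for k, v in pixel_values.items()}  # add a batch dimension per featurizer

with torch.no_grad():
    patches = backbone(pixel_values)  # shape [1, num_patches, dino_embed_dim + siglip_embed_dim]
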
/prismatic/preprocessing/download.py:
--------------------------------------------------------------------------------
1 | """
2 | download.py
3 |
4 | Utility functions for downloading and extracting various datasets to (local) disk.
5 | """
6 |
7 | import os
8 | import shutil
9 | from pathlib import Path
10 | from typing import Dict, List, TypedDict
11 | from zipfile import ZipFile
12 |
13 | import requests
14 | from PIL import Image
15 | from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
16 | from tqdm import tqdm
17 |
18 | from prismatic.overwatch import initialize_overwatch
19 |
20 | # Initialize Overwatch =>> Wraps `logging.Logger`
21 | overwatch = initialize_overwatch(__name__)
22 |
23 |
24 | # === Dataset Registry w/ Links ===
25 | # fmt: off
26 | DatasetComponent = TypedDict(
27 | "DatasetComponent",
28 | {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
29 | total=False
30 | )
31 |
32 | DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
33 | # === LLaVa v1.5 Dataset(s) ===
34 |
35 | # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5
36 | # models are finetuned on this split. We use this dataset for all experiments in our paper.
37 | "llava-laion-cc-sbu-558k": [
38 | {
39 |             "name": "chat.json",                # Contains the "chat" traces :: {"human" => <prompt>, "gpt" => <caption>}
40 | "extract": False,
41 | "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json",
42 | "do_rename": True,
43 | },
44 | {
45 | "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution)
46 | "extract": True,
47 | "extract_type": "directory",
48 | "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip",
49 | "do_rename": False,
50 | }
51 | ],
52 |
53 | "llava-v1.5-instruct": [
54 | {
55 | "name": "llava_v1_5_mix665k.json",
56 | "extract": False,
57 | "url": (
58 | "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json"
59 | ),
60 | "do_rename": True,
61 | },
62 | {
63 | "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017
64 | "extract": True,
65 | "extract_type": "directory",
66 | "url": "http://images.cocodataset.org/zips/train2017.zip",
67 | "do_rename": True,
68 | },
69 | {
70 | "name": "gqa/images",
71 | "extract": True,
72 | "extract_type": "directory",
73 | "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
74 | "do_rename": True,
75 | },
76 | {
77 | "name": "ocr_vqa/images",
78 | "extract": True,
79 | "extract_type": "directory",
80 | "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
81 | "do_rename": True,
82 | },
83 | {
84 | "name": "textvqa/train_images",
85 | "extract": True,
86 | "extract_type": "directory",
87 | "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
88 | "do_rename": True,
89 | },
90 | {
91 | "name": "vg/VG_100K",
92 | "extract": True,
93 | "extract_type": "directory",
94 | "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
95 | "do_rename": True,
96 | },
97 | {
98 | "name": "vg/VG_100K_2",
99 | "extract": True,
100 | "extract_type": "directory",
101 | "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
102 | "do_rename": True,
103 | },
104 | ]
105 | }
106 | # fmt: on
107 |
108 |
109 | def convert_to_jpg(image_dir: Path) -> None:
110 | """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs."""
111 | overwatch.info(f"Converting all Images in `{image_dir}` to JPG")
112 |
113 | for image_fn in tqdm(list(image_dir.iterdir())):
114 | if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists():
115 | continue
116 |
117 | if image_fn.suffix == ".gif":
118 | gif = Image.open(image_fn)
119 | gif.seek(0)
120 | gif.convert("RGB").save(jpg_fn)
121 | elif image_fn.suffix == ".png":
122 | Image.open(image_fn).convert("RGB").save(jpg_fn)
123 | else:
124 | raise ValueError(f"Unexpected image format `{image_fn.suffix}`")
125 |
126 |
127 | def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
128 | """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
129 | overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1)
130 | if dest_path.exists():
131 | return dest_path
132 |
133 | # Otherwise --> fire an HTTP Request, with `stream = True`
134 | response = requests.get(url, stream=True)
135 |
136 | # Download w/ Transfer-Aware Progress
137 | # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py
138 | with Progress(
139 | TextColumn("[bold]{task.description} - {task.fields[fname]}"),
140 | BarColumn(bar_width=None),
141 | "[progress.percentage]{task.percentage:>3.1f}%",
142 | "•",
143 | DownloadColumn(),
144 | "•",
145 | TransferSpeedColumn(),
146 | transient=True,
147 | ) as dl_progress:
148 | dl_tid = dl_progress.add_task(
149 | "Downloading", fname=dest_path.name, total=int(response.headers.get("content-length", "None"))
150 | )
151 | with open(dest_path, "wb") as f:
152 | for data in response.iter_content(chunk_size=chunk_size_bytes):
153 | dl_progress.advance(dl_tid, f.write(data))
154 |
155 | return dest_path
156 |
157 |
158 | def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
159 | """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
160 | assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
161 | overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1)
162 |
163 | # Extract w/ Progress
164 | with Progress(
165 | TextColumn("[bold]{task.description} - {task.fields[aname]}"),
166 | BarColumn(bar_width=None),
167 | "[progress.percentage]{task.percentage:>3.1f}%",
168 | "•",
169 | MofNCompleteColumn(),
170 | transient=True,
171 | ) as ext_progress:
172 | with ZipFile(archive_path) as zf:
173 | ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
174 | extract_path = Path(zf.extract(members[0], download_dir))
175 | if extract_type == "file":
176 |                 assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
177 | elif extract_type == "directory":
178 | for member in members[1:]:
179 | zf.extract(member, download_dir)
180 | ext_progress.advance(ext_tid)
181 | else:
182 | raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")
183 |
184 | # Cleanup (if specified)
185 | if cleanup:
186 | archive_path.unlink()
187 |
188 | return extract_path
189 |
190 |
191 | def download_extract(dataset_id: str, root_dir: Path) -> None:
192 | """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
193 | os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)
194 |
195 | # Download Files => Single-Threaded, with Progress Bar
196 | dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
197 | for dl_task in dl_tasks:
198 | dl_path = download_with_progress(dl_task["url"], download_dir)
199 |
200 | # Extract Files (if specified) --> Note (assumes ".zip" ONLY!)
201 | if dl_task["extract"]:
202 | dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
203 | dl_path = dl_path.parent if dl_path.is_file() else dl_path
204 |
205 | # Rename Path --> dl_task["name"]
206 | if dl_task["do_rename"]:
207 | shutil.move(dl_path, download_dir / dl_task["name"])
208 |
--------------------------------------------------------------------------------
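
Usage note (assumption, not in the repo): the registry-driven helpers above are typically invoked as below; the `data` root is illustrative.

# Sketch: fetch and unpack one registered dataset.
from pathlib import Path

from prismatic.preprocessing.download import download_extract

# Files land under data/download/llava-laion-cc-sbu-558k/; archives are unzipped in place,
# and entries with `do_rename=True` are moved to their registry `name`.
download_extract("llava-laion-cc-sbu-558k", root_dir=Path("data"))
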
/prismatic/models/backbones/vision/base_vision.py:
--------------------------------------------------------------------------------
1 | """
2 | base_vision.py
3 |
4 | Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
5 | functions, and initialization logic.
6 |
7 | We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
8 | Transformer model for feature extraction.
9 | """
10 |
11 | from abc import ABC, abstractmethod
12 | from dataclasses import dataclass
13 | from functools import partial
14 | from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
15 |
16 | import timm
17 | import torch
18 | import torch.nn as nn
19 | import torchvision.transforms.functional as TVF
20 | from PIL.Image import Image
21 | from timm.models.vision_transformer import Block, VisionTransformer
22 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
23 | from torchvision.transforms import Compose, Resize
24 |
25 |
26 | # === Utility Functions for Monkey-Patching ===
27 | def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
28 | def wrapper(*args: Any, **kwargs: Any) -> Any:
29 | result = fn(*args, **kwargs)
30 | return result[0] if isinstance(result, tuple) else result
31 |
32 | return wrapper
33 |
34 |
35 | # === Interface for an Image Transform ===
36 | class ImageTransform(Protocol):
37 | def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...
38 |
39 |
40 | # === Custom Torchvision Image Transforms ===
41 | @dataclass
42 | class LetterboxPad:
43 | padding_fill_value: Tuple[int, int, int]
44 |
45 | def __call__(self, image: Image) -> Image:
46 | """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
47 | (w, h), max_wh = image.size, max(image.size)
48 | horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
49 | padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
50 | return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")
51 |
52 |
53 | # === Abstract Base Class for arbitrary Vision Backbones ===
54 | class VisionBackbone(nn.Module, ABC):
55 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
56 | super().__init__()
57 | self.identifier: str = vision_backbone_id
58 | self.image_resize_strategy: str = image_resize_strategy
59 | self.default_image_size: int = default_image_size
60 |
61 | # Instance attributes for a Vision Backbone
62 |         self.featurizer: Optional[nn.Module] = None          # populated by concrete subclasses
63 |         self.image_transform: Optional[ImageTransform] = None
64 |
65 | def get_image_transform(self) -> ImageTransform:
66 | return self.image_transform
67 |
68 | @abstractmethod
69 | def get_fsdp_wrapping_policy(self) -> Callable: ...
70 |
71 | @abstractmethod
72 | def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
73 | """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
74 | raise NotImplementedError
75 |
76 | @property
77 | @abstractmethod
78 | def default_image_resolution(self) -> Tuple[int, int, int]: ...
79 |
80 | @property
81 | @abstractmethod
82 | def embed_dim(self) -> int: ...
83 |
84 | @property
85 | @abstractmethod
86 | def num_patches(self) -> int: ...
87 |
88 | @property
89 | @abstractmethod
90 | def half_precision_dtype(self) -> torch.dtype: ...
91 |
92 |
93 | # === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
94 | class TimmViTBackbone(VisionBackbone, ABC):
95 | def __init__(
96 | self,
97 | vision_backbone_id: str,
98 | timm_path_or_url: str,
99 | image_resize_strategy: str,
100 | default_image_size: int = 224,
101 | override_act_layer: Optional[str] = None,
102 | ) -> None:
103 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
104 | self.timm_path_or_url = timm_path_or_url
105 | self.override_act_layer = override_act_layer
106 | self.dtype = torch.bfloat16
107 |
108 | # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
109 | if self.override_act_layer is None:
110 | self.featurizer: VisionTransformer = timm.create_model(
111 | self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
112 | )
113 | else:
114 | self.featurizer: VisionTransformer = timm.create_model(
115 | self.timm_path_or_url,
116 | pretrained=True,
117 | num_classes=0,
118 | img_size=self.default_image_size,
119 | act_layer=self.override_act_layer,
120 | )
121 | self.featurizer.eval()
122 |
123 | # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
124 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
125 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
126 | self.featurizer.forward = unpack_tuple(
127 | partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
128 | )
129 |
130 | # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
131 | assert isinstance(self.featurizer, VisionTransformer), (
132 | "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
133 | "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
134 | )
135 |
136 | # Get Config =>> Note :: Override default image size to ensure correct image transform
137 | self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
138 | self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
139 |
140 | # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
141 | default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)
142 |
143 | # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
144 | if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
145 | assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
146 | assert isinstance(default_image_transform.transforms[0], Resize)
147 | default_image_transform = Compose(
148 | [
149 | Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation),
150 | *default_image_transform.transforms[1:],
151 | ]
152 | )
153 |
154 | # Switch on `image_resize_strategy`
155 | if self.image_resize_strategy == "resize-naive":
156 | assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
157 | assert isinstance(default_image_transform.transforms[0], Resize)
158 |
159 | target_size = (self.default_image_size, self.default_image_size)
160 | self.image_transform = Compose(
161 | [
162 | Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation),
163 | *default_image_transform.transforms[1:],
164 | ]
165 | )
166 |
167 | elif self.image_resize_strategy == "resize-crop":
168 | self.image_transform = default_image_transform
169 |
170 | elif self.image_resize_strategy == "letterbox":
171 | assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
172 | assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
173 |
174 | # Compute Padding Fill Value (rescaled normalization mean if applicable)
175 | fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
176 |
177 | # Build New Transform
178 | self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])
179 |
180 | else:
181 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
182 |
183 | def get_fsdp_wrapping_policy(self) -> Callable:
184 | """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
185 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
186 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
187 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
188 |
189 | def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
190 | """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
191 | return self.featurizer(pixel_values)
192 |
193 | @property
194 | def default_image_resolution(self) -> Tuple[int, int, int]:
195 | return self.data_cfg["input_size"]
196 |
197 | @property
198 | def embed_dim(self) -> int:
199 | return self.featurizer.embed_dim
200 |
201 | @property
202 | def num_patches(self) -> int:
203 | return self.featurizer.patch_embed.num_patches
204 |
205 | @property
206 | def half_precision_dtype(self) -> torch.dtype:
207 | return self.dtype
208 |
--------------------------------------------------------------------------------
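
Usage note: concrete TIMM backbones in this directory are thin wrappers over `TimmViTBackbone`; the sketch below shows the pattern (the class name and TIMM identifier are illustrative, not files in this repo).

# Sketch (illustrative): a concrete backbone built on TimmViTBackbone.
from prismatic.models.backbones.vision.base_vision import TimmViTBackbone


class ExampleCLIPViTBackbone(TimmViTBackbone):
    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        # Any TIMM VisionTransformer identifier can be passed through; this one is just an example.
        super().__init__(
            vision_backbone_id,
            "vit_large_patch14_clip_224.openai",
            image_resize_strategy,
            default_image_size=default_image_size,
        )
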
/prismatic/vla/datasets/dummy_dataset.py:
--------------------------------------------------------------------------------
1 | #import sys
2 | #sys.path.append('/media/noriaki/Noriaki_Data/Learning-to-Drive-Anywhere-with-MBRA/train/')
3 | #sys.path.append('/home/noriaki/Learning-to-Drive-Anywhere-with-MBRA2/train/')
4 |
5 | import numpy as np
6 | import os
7 | import pickle
8 | import yaml
9 | from typing import Any, Dict, List, Optional, Tuple, Type
10 | import tqdm
11 | import io
12 | import lmdb
13 | import utm
14 | import math
15 |
16 | import torch
17 | from torch.utils.data import Dataset
18 | import torchvision.transforms.functional as TF
19 |
20 | import random
21 | #import cv2
22 | import matplotlib.pyplot as plt
23 |
24 | from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, STOP_INDEX
25 | from prismatic.models.backbones.llm.prompting import PromptBuilder
26 | from prismatic.vla.action_tokenizer import ActionTokenizer
27 | from transformers import PreTrainedTokenizerBase
28 | from prismatic.models.backbones.vision import ImageTransform
29 | from PIL import Image
30 | from typing import Union
31 |
32 | from vint_train.data.data_utils import (
33 | img_path_to_data,
34 | calculate_sin_cos,
35 | get_data_path,
36 | to_local_coords,
37 | )
38 |
39 | class Dummy_Dataset(Dataset):
40 | def __init__(
41 | self,
42 | context_size: int,
43 |         action_tokenizer: ActionTokenizer,
44 |         base_tokenizer: PreTrainedTokenizerBase,
45 | image_transform: ImageTransform,
46 | prompt_builder_fn: Type[PromptBuilder],
47 | predict_stop_token: bool = True,
48 | ):
49 | self.context_size = context_size
50 | self.action_tokenizer = action_tokenizer
51 | self.base_tokenizer = base_tokenizer
52 | self.prompt_builder = prompt_builder_fn
53 | self.predict_stop_token = predict_stop_token
54 | self.image_transform = image_transform
55 |
56 | def __len__(self) -> int:
57 | return 10000 #dummy length
58 |
59 | def calculate_relative_position(self, x_a, y_a, x_b, y_b):
60 | return x_b - x_a, y_b - y_a
61 |
62 | def rotate_to_local_frame(self, delta_x, delta_y, heading_a_rad):
63 | rel_x = delta_x * math.cos(heading_a_rad) + delta_y * math.sin(heading_a_rad)
64 | rel_y = -delta_x * math.sin(heading_a_rad) + delta_y * math.cos(heading_a_rad)
65 | return rel_x, rel_y
66 |
67 | def _resize_norm(self, image, size):
68 | return TF.resize(image, size)
69 |
70 | def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
71 | thres_dist = 30.0
72 | metric_waypoint_spacing = 0.1
73 |         predict_stop_token = self.predict_stop_token  # honor the constructor flag rather than hard-coding True
74 |
75 | # set the available modality id for each dataset
76 | # 0:"satellite only", 1:"pose and satellite", 2:"satellite and image", 3:"all", 4:"pose only", 5:"pose and image", 6:"image only", 7:"language only", 8:"language and pose"
77 |         modality_list = [4, 6, 7]  # [4, 5, 6, 7, 8] =>> our sample data has no consistency across modalities, so we sample a single modality per example
78 | modality_id = random.choice(modality_list)
79 |
80 | inst_obj = "move toward blue trash bin"
81 | actions = np.random.rand(8, 4) #dummy action
82 |
83 | # Dummy current and goal location
84 | current_lat, current_lon, current_compass = 37.87371258374039, -122.26729417226024, 270.0
85 | cur_utm = utm.from_latlon(current_lat, current_lon)
86 | cur_compass = -float(current_compass) / 180.0 * math.pi # inverted compass
87 |
88 | goal_lat, goal_lon, goal_compass = 37.8738930785863, -122.26746181032362, 0.0
89 | goal_utm = utm.from_latlon(goal_lat, goal_lon)
90 | goal_compass = -float(goal_compass) / 180.0 * math.pi
91 |
92 | # Local goal position
93 | delta_x, delta_y = self.calculate_relative_position(
94 | cur_utm[0], cur_utm[1], goal_utm[0], goal_utm[1]
95 | )
96 | relative_x, relative_y = self.rotate_to_local_frame(delta_x, delta_y, cur_compass)
97 | radius = np.sqrt(relative_x**2 + relative_y**2)
98 | if radius > thres_dist:
99 | relative_x *= thres_dist / radius
100 | relative_y *= thres_dist / radius
101 |
102 | goal_pose_loc_norm = np.array([
103 | relative_y / metric_waypoint_spacing,
104 | -relative_x / metric_waypoint_spacing,
105 | np.cos(goal_compass - cur_compass),
106 | np.sin(goal_compass - cur_compass)
107 | ])
108 |
109 | goal_pose_cos_sin = goal_pose_loc_norm
110 | current_image_PIL = Image.open("./inference/current_img.jpg").convert("RGB")
111 | goal_image_PIL = Image.open("./inference/goal_img.jpg").convert("RGB")
112 |
113 |         IGNORE_INDEX = -100  # same value as the IGNORE_INDEX imported from prismatic.vla.constants
114 | current_action = actions[0]
115 | future_actions = actions[1:]
116 | future_actions_string = ''.join(self.action_tokenizer(future_actions))
117 | current_action_string = self.action_tokenizer(current_action)
118 | action_chunk_string = current_action_string + future_actions_string
119 | action_chunk_len = len(action_chunk_string)
120 |
121 |         if modality_id not in (7, 8):  # when the language modality is not selected, feed the placeholder prompt below instead of masking out the tokens
122 | conversation = [
123 | {"from": "human", "value": "No language instruction"},
124 | {"from": "gpt", "value": action_chunk_string},
125 | ]
126 | else:
127 | conversation = [
128 | {"from": "human", "value": f"What action should the robot take to {inst_obj}?"},
129 | {"from": "gpt", "value": action_chunk_string},
130 | ]
131 |
132 | prompt_builder = self.prompt_builder("openvla")
133 | for turn in conversation:
134 | prompt_builder.add_turn(turn["from"], turn["value"])
135 |
136 | # Tokenize
137 | input_ids = torch.tensor(self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids)
138 | labels = input_ids.clone()
139 | labels[:-(action_chunk_len + 1)] = IGNORE_INDEX
140 | if not predict_stop_token:
141 | labels[-1] = IGNORE_INDEX
142 |
143 | # Images for MBRA model
144 | image_obs_list = []
145 | for ih in range(self.context_size + 1):
146 |             image_obs_list.append(self._resize_norm(TF.to_tensor(current_image_PIL), (96, 96)))  # In the real dataset, image_obs_list is the history of past observations; this dummy dataset repeats the current image. The implementation otherwise follows the ViNT / NoMaD codebases.
147 | image_obs = torch.cat(image_obs_list)
148 | image_goal = self._resize_norm(TF.to_tensor(goal_image_PIL), (96, 96))
149 |
150 | # Data augmentation (random cropping)
151 | voffset = int(224.0*0.2*random.random())
152 | hoffset = int(224.0*0.1*random.random())
153 | PILbox = (hoffset, voffset, 224-hoffset, 224-voffset)
154 | current_image_PIL = current_image_PIL.crop(PILbox).resize((224,224))
155 | goal_image_PIL = goal_image_PIL.crop(PILbox).resize((224,224))
156 |
157 | # Data augmentation (horizontal flipping)
158 | if random.random() > 0.5:
159 | current_image_PIL = current_image_PIL.transpose(Image.FLIP_LEFT_RIGHT)
160 | goal_image_PIL = goal_image_PIL.transpose(Image.FLIP_LEFT_RIGHT)
161 | actions[:,1] = -actions[:,1]
162 | actions[:,3] = -actions[:,3]
163 | goal_pose_cos_sin[1] = -goal_pose_cos_sin[1]
164 | goal_pose_cos_sin[3] = -goal_pose_cos_sin[3]
165 |
166 | image_obs = torch.flip(image_obs, dims=[2])
167 | image_goal = torch.flip(image_goal, dims=[2])
168 |
169 | pixel_values_current = self.image_transform(current_image_PIL)
170 | pixel_values_goal = self.image_transform(goal_image_PIL)
171 |
172 | #action select 1.0: raw action, 0.0: MBRA synthetic action
173 | action_select_mask = torch.tensor(1.0)
174 |
175 | dataset_name = "dummy"
176 |         return dict(
177 |             pixel_values=pixel_values_current,      # current image for OmniVLA
178 |             pixel_values_goal=pixel_values_goal,    # goal image for OmniVLA
179 |             input_ids=input_ids,                    # tokenized language + action prompt, following OpenVLA-OFT
180 |             labels=labels,                          # labels for the prompt (non-action tokens masked to IGNORE_INDEX)
181 |             dataset_name=dataset_name,              # dataset name
182 |             modality_id=modality_id,                # modality ID, 0:"satellite only", 1:"pose and satellite", 2:"satellite and image", 3:"all", 4:"pose only", 5:"pose and image", 6:"image only", 7:"language only", 8:"language and pose"
183 |             actions=torch.as_tensor(actions),       # action commands
184 |             action_select_mask=action_select_mask,  # 1.0: raw action, 0.0: MBRA synthetic action
185 |             goal_pose=goal_pose_cos_sin,            # goal pose [X, Y, cos(yaw), sin(yaw)]
186 |             obj_pose_norm=goal_pose_cos_sin[0:2],   # object pose [X, Y] (only used for the LeLaN dataset; a dummy value here)
187 |             img_PIL=current_image_PIL,              # for visualization
188 |             gimg_PIL=goal_image_PIL,                # for visualization
189 |             cur_image=image_obs,                    # history of observation images for MBRA
190 |             goal_image_8=image_goal,                # goal image (8 steps in the future) for MBRA
191 |             temp_dist=10.0,                         # temporal distance (unused during training)
192 |             lan_prompt=inst_obj,                    # language instruction
193 |         )
194 |
195 |
--------------------------------------------------------------------------------
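
Reference note: the goal-pose arithmetic in `__getitem__` (UTM difference, rotation into the robot frame, distance clipping) can be read standalone; a self-contained sketch of the same computation, with a function name of my choosing, is below.

# Sketch (illustrative): the local-goal computation used by Dummy_Dataset.__getitem__.
import math

import utm


def local_goal(cur_lat, cur_lon, cur_compass_deg, goal_lat, goal_lon, thres_dist=30.0):
    cur_e, cur_n, *_ = utm.from_latlon(cur_lat, cur_lon)
    goal_e, goal_n, *_ = utm.from_latlon(goal_lat, goal_lon)
    heading = -math.radians(cur_compass_deg)  # inverted compass, as in the dataset
    dx, dy = goal_e - cur_e, goal_n - cur_n
    rel_x = dx * math.cos(heading) + dy * math.sin(heading)
    rel_y = -dx * math.sin(heading) + dy * math.cos(heading)
    radius = math.hypot(rel_x, rel_y)
    if radius > thres_dist:  # clip goals farther than thres_dist meters
        rel_x, rel_y = rel_x * thres_dist / radius, rel_y * thres_dist / radius
    return rel_x, rel_y


print(local_goal(37.87371258374039, -122.26729417226024, 270.0, 37.8738930785863, -122.26746181032362))
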
/prismatic/conf/vla.py:
--------------------------------------------------------------------------------
1 | """
2 | vla.py
3 |
4 | Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and
5 | model configuration thereof. A given VLA model (`policy`) configures the following attributes:
6 | - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.)
7 | - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`)
8 | - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
9 | - Training / Optimization Hyperparameters
10 | """
11 |
12 | from dataclasses import dataclass
13 | from enum import Enum, unique
14 | from pathlib import Path
15 | from typing import Optional, Union
16 |
17 | from draccus import ChoiceRegistry
18 |
19 |
20 | @dataclass
21 | class VLAConfig(ChoiceRegistry):
22 | # fmt: off
23 | vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant
24 | base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`)
25 | freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining)
26 | freeze_llm_backbone: bool # Freeze LLM Backbone parameters
27 | unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen)
28 |
29 | # Data Mixture Parameters
30 | data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`)
31 | shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE)
32 |
33 | # Optimization Parameters
34 | epochs: int # Epochs to Run (in case `max_steps` is not specified)
35 | max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`)
36 |
37 | expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware
38 | global_batch_size: int # Global Batch Size (divided across processes / world size)
39 | per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU)
40 | # =>> # of accumulation steps is auto-computed
41 |
42 | learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
43 | weight_decay: float # Weight Decay for AdamW Optimizer
44 | max_grad_norm: float # Max Grad Norm (for global gradient clipping)
45 | lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
46 | warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers)
47 |
48 | train_strategy: str # Train Strategy (default "fsdp-full-shard")
49 |
50 | # Enable Gradient/Activation Checkpointing (for the LLM Backbone)
51 | enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training
52 |
53 | # Mixed Precision Training via Torch Native AMP (`autocast`)
54 | enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision
55 | reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision
56 |
57 | # fmt: on
58 |
59 |
60 | # === OpenVLA Training Configurations ===
61 |
62 |
63 | # = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
64 | @dataclass
65 | class Exp_SigLIP_224px_Bridge(VLAConfig):
66 | vla_id: str = "siglip-224px+mx-bridge"
67 | base_vlm: Union[str, Path] = "siglip-224px+7b"
68 |
69 | freeze_vision_backbone: bool = False
70 | freeze_llm_backbone: bool = False
71 | unfreeze_last_llm_layer: bool = False
72 |
73 | # Data Mixture Parameters
74 | data_mix: str = "bridge"
75 | shuffle_buffer_size: int = 256_000
76 |
77 | # Optimization Parameters
78 | epochs: int = 1000
79 | max_steps: Optional[int] = None
80 |
81 | expected_world_size: int = 8
82 | global_batch_size: int = 256
83 | per_device_batch_size: int = 32
84 |
85 | learning_rate: float = 2e-5
86 | weight_decay: float = 0.0
87 | max_grad_norm: float = 1.0
88 | lr_scheduler_type: str = "constant"
89 | warmup_ratio: float = 0.0
90 |
91 | train_strategy: str = "fsdp-full-shard"
92 |
93 |
94 | # = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
95 | @dataclass
96 | class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
97 | vla_id: str = "siglip-224px-icy+mx-bridge"
98 | base_vlm: Union[str, Path] = "siglip-224px+7b"
99 | freeze_vision_backbone: bool = True
100 |
101 |
102 | # = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
103 | @dataclass
104 | class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
105 | vla_id: str = "prism-dinosiglip-224px+mx-bridge"
106 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
107 |
108 | data_mix: str = "bridge"
109 |
110 |
111 | # = [64 GPU] SigLIP 224px + OXE Magic Soup =
112 | @dataclass
113 | class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
114 | vla_id: str = "siglip-224px+mx-oxe-magic-soup"
115 | base_vlm: Union[str, Path] = "siglip-224px+7b"
116 |
117 | data_mix: str = "oxe_magic_soup"
118 |
119 | expected_world_size: int = 64
120 | global_batch_size: int = 2048
121 | per_device_batch_size: int = 32
122 |
123 |
124 | # = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
125 | @dataclass
126 | class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
127 | vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
128 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
129 |
130 | # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling!
131 | # data_mix: str = "oxe_magic_soup_plus"
132 | data_mix: str = "oxe_magic_soup_plus_minus"
133 |
134 | expected_world_size: int = 64
135 | global_batch_size: int = 2048
136 | per_device_batch_size: int = 32
137 |
138 |
139 | # === OpenVLA Fine-tuning Configurations ===
140 |
141 |
142 | # = [8 GPU] SigLIP 224px + T-DROID =
143 | @dataclass
144 | class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
145 | vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
146 | base_vlm: Union[str, Path] = "siglip-224px+7b"
147 |
148 | data_mix: str = "tdroid_carrot_in_bowl"
149 |
150 |
151 | @dataclass
152 | class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
153 | vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
154 | base_vlm: Union[str, Path] = "siglip-224px+7b"
155 |
156 | data_mix: str = "tdroid_pour_corn_in_pot"
157 |
158 |
159 | # = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
160 | @dataclass
161 | class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
162 | vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
163 | base_vlm: Union[str, Path] = "siglip-224px+7b"
164 | freeze_vision_backbone: bool = True
165 | freeze_llm_backbone: bool = False
166 |
167 | data_mix: str = "tdroid_carrot_in_bowl"
168 |
169 |
170 | @dataclass
171 | class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
172 | vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
173 | base_vlm: Union[str, Path] = "siglip-224px+7b"
174 | freeze_vision_backbone: bool = True
175 | freeze_llm_backbone: bool = True
176 | unfreeze_last_llm_layer: bool = True
177 |
178 | data_mix: str = "tdroid_carrot_in_bowl"
179 |
180 |
181 | @dataclass
182 | class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
183 | vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
184 | base_vlm: Union[str, Path] = "siglip-224px+7b"
185 | freeze_vision_backbone: bool = False
186 | freeze_llm_backbone: bool = True
187 | unfreeze_last_llm_layer: bool = True
188 |
189 | data_mix: str = "tdroid_carrot_in_bowl"
190 |
191 |
192 | # === [8 GPU] SigLIP 224px + FrankaWipe ===
193 | @dataclass
194 | class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
195 | vla_id: str = "siglip-224px+mx-droid_wipe"
196 | base_vlm: Union[str, Path] = "siglip-224px+7b"
197 |
198 | data_mix: str = "droid_wipe"
199 |
200 |
201 | # === Define a VLA Registry Enum for Reference & Validation ===
202 | @unique
203 | class VLARegistry(Enum):
204 | # Sanity Check Configurations =>> BridgeV2
205 | SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
206 | DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge
207 |
208 | # SigLIP Frozen Backbone Experiment
209 | FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge
210 |
211 | # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
212 | SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup
213 |
214 | # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
215 | DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus
216 |
217 | # === TDROID Fine-tuning Configs ===
218 | SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
219 | SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot
220 |
221 | SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
222 | SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
223 | SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl
224 |
225 | # === DROID Fine-tuning Configs ===
226 | SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe
227 |
228 | @property
229 | def vla_id(self) -> str:
230 | return self.value.vla_id
231 |
232 |
233 | # Register VLAs in Choice Registry
234 | for vla_variant in VLARegistry:
235 | VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value)
236 |
--------------------------------------------------------------------------------
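
Extension note: new experiments follow the same subclass-and-register pattern as the configs above; a hedged sketch is below (the VLA ID and data mixture are hypothetical and must match an entry in the OXE mixture registry).

# Sketch (hypothetical IDs): defining and registering an additional VLA experiment config.
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from prismatic.conf.vla import Exp_SigLIP_224px_Bridge, VLAConfig


@dataclass
class Exp_SigLIP_224px_MyMix(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px+mx-my-mix"          # hypothetical unique ID
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    data_mix: str = "my_mix"                        # hypothetical mixture name

# Either add a member to VLARegistry (so the registration loop above picks it up), or register directly:
VLAConfig.register_subclass(Exp_SigLIP_224px_MyMix.vla_id, Exp_SigLIP_224px_MyMix)
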
/prismatic/preprocessing/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | datasets.py
3 |
4 | PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with
5 | utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected
6 | formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models).
7 |
8 | We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that
9 | random access image reading is relatively cheap/fast.
10 | """
11 |
12 | import copy
13 | import json
14 | from pathlib import Path
15 | from typing import Dict, List, Tuple, Type
16 |
17 | import torch
18 | from PIL import Image
19 | from torch.utils.data import Dataset
20 | from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase
21 |
22 | from prismatic.models.backbones.llm.prompting import PromptBuilder
23 | from prismatic.models.backbones.vision import ImageTransform
24 |
25 | # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
26 | IGNORE_INDEX = -100
27 |
28 |
29 | class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
30 | def __init__(
31 | self,
32 | chat_json: Path,
33 | image_dir: Path,
34 | image_transform: ImageTransform,
35 | tokenizer: PreTrainedTokenizerBase,
36 | ) -> None:
37 | super().__init__()
38 | self.chat_json, self.image_dir = chat_json, image_dir
39 | self.image_transform, self.tokenizer = image_transform, tokenizer
40 | self.dataset_type = "align"
41 |
42 | # Create Prompt Template
43 | self.prompt_template = "{caption}" + self.tokenizer.eos_token
44 |
45 | # Load Chat JSON
46 | with open(self.chat_json, "r") as f:
47 | self.examples = json.load(f)
48 |
49 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
50 | """
51 | Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard
52 | the "prompt" from the human, and instead directly predict the caption from the image.
53 |
54 | As a concrete example given the "raw data" for the first example:
55 |             example = self.examples[0]["conversations"] = {
56 |                 [
57 |                     {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
58 |                     {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
59 |                 ]
60 |             }
61 |
62 |         Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")
63 |
64 | :param idx: Index to retrieve from the dataset.
65 |
66 | :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
67 | """
68 | image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
69 |         assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"
70 |
71 | # Format Caption --> {caption}{eos_token}
72 | caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())
73 |
74 | # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens.
75 | # => Critically, we find that inserting *after* the BOS token leads to the strongest performance!
76 |         #       - input_ids = "<s> p1 p2 p3 ... \n"
77 |         #       - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <s> and p{1...K} with IGNORE)
78 | #
79 | # IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
80 | input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
81 | labels = copy.deepcopy(input_ids)
82 |
83 |         # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
84 | labels[0] = IGNORE_INDEX
85 |
86 | # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
87 | pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
88 |
89 | return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
90 |
91 | def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
92 | """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
93 | modality_lengths = []
94 | for example in self.examples:
95 | is_multimodal = "image" in example
96 |             n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
97 | modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
98 | return modality_lengths
99 |
100 | def __len__(self) -> int:
101 | return len(self.examples)
102 |
103 |
104 | class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
105 | def __init__(
106 | self,
107 | instruct_json: Path,
108 | image_dir: Path,
109 | image_transform: ImageTransform,
110 | tokenizer: PreTrainedTokenizerBase,
111 | prompt_builder_fn: Type[PromptBuilder],
112 | ) -> None:
113 | super().__init__()
114 | self.instruct_json, self.image_dir = instruct_json, image_dir
115 | self.image_transform, self.tokenizer = image_transform, tokenizer
116 | self.prompt_builder_fn = prompt_builder_fn
117 | self.dataset_type = "finetune"
118 |
119 | # Load Instruct JSON
120 | with open(self.instruct_json, "r") as f:
121 | self.examples = json.load(f)
122 |
123 | # === Unimodal + Multimodal Handling ===
124 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
125 | """
126 | Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
127 | dialog grounded in a single image.
128 |
129 | To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
130 | methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.
131 |
132 | :param idx: Index to retrieve from the dataset.
133 |
134 | :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
135 | """
136 | conversation = self.examples[idx]["conversations"]
137 |
138 | # Create Prompt Builder --> add each message sequentially
139 | prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
140 | for turn_idx, turn in enumerate(conversation):
141 | # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
142 | msg = prompt_builder.add_turn(turn["from"], turn["value"])
143 |
144 | # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
145 | if isinstance(self.tokenizer, LlamaTokenizerFast):
146 | msg = msg.rstrip()
147 |
148 | # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
149 | elif isinstance(self.tokenizer, CodeGenTokenizerFast):
150 | pass
151 |
152 | else:
153 | raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")
154 |
155 | # Tokenize Input IDs
156 | turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids
157 |
158 | # [CRITICAL] We do not want to take the loss for the "USER: " prompts =>> just the responses!
159 | turn_labels = (
160 | [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
161 | )
162 |
163 | # Add to Trackers
164 | input_ids.extend(turn_input_ids)
165 | labels.extend(turn_labels)
166 |
167 |         # Tensorize =>> Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches after)
168 | # - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
169 | input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)
170 |
171 | # Handle Truncation (if necessary)
172 | input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]
173 |
174 | # === Handle "unimodal" (language-only) vs. "multimodal" ===
175 | if "image" in self.examples[idx]:
176 | image_path = Path(self.examples[idx]["image"])
177 |
178 |             # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
179 | labels[0] = IGNORE_INDEX
180 |
181 | # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
182 | pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
183 |
184 | return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
185 |
186 | else:
187 | # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
188 | return dict(pixel_values=None, input_ids=input_ids, labels=labels)
189 |
190 | def get_modality_lengths(self) -> List[Tuple[bool, int]]:
191 | """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
192 | modality_lengths = []
193 | for example in self.examples:
194 | is_multimodal = "image" in example
195 | n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
196 | modality_lengths.append((is_multimodal, n_words))
197 | return modality_lengths
198 |
199 | def __len__(self) -> int:
200 | return len(self.examples)
201 |
--------------------------------------------------------------------------------
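
Usage note (assumption): wiring `AlignDataset` to the files fetched by `download.py`; the tokenizer and transform below are stand-ins for the ones the VLM materializers would provide.

# Sketch: AlignDataset over the llava-laion-cc-sbu-558k split downloaded earlier.
from pathlib import Path

from torchvision.transforms import Compose, Resize, ToTensor
from transformers import AutoTokenizer

from prismatic.preprocessing.datasets.datasets import AlignDataset

root = Path("data/download/llava-laion-cc-sbu-558k")                # hypothetical download root
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")   # any tokenizer with an EOS token
image_transform = Compose([Resize((224, 224)), ToTensor()])         # stand-in for the backbone transform

dataset = AlignDataset(root / "chat.json", root / "images", image_transform, tokenizer)
example = dataset[0]  # {"pixel_values": Tensor, "input_ids": Tensor, "labels": Tensor}
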