├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── real-world-exp_1.png └── teaser_univla.png ├── docs └── real-world-deployment.md ├── experiments └── robot │ ├── calvin │ ├── calvin_env_wrapper.py │ ├── calvin_model.py │ └── run_calvin_eval_ddp.py │ ├── libero │ ├── libero_requirements.txt │ ├── libero_utils.py │ ├── regenerate_libero_dataset.py │ ├── run_libero_eval.py │ └── run_libero_eval_blocking.py │ ├── openvla_utils.py │ └── robot_utils.py ├── latent_action_model ├── config │ ├── lam-stage-1.yaml │ └── lam-stage-2.yaml ├── genie │ ├── dataset.py │ ├── model.py │ └── modules │ │ ├── __init__.py │ │ ├── blocks.py │ │ └── lam.py ├── main.py └── train.sh ├── prismatic ├── __init__.py ├── conf │ ├── __init__.py │ ├── datasets.py │ ├── models.py │ └── vla.py ├── extern │ ├── __init__.py │ └── hf │ │ ├── __init__.py │ │ ├── configuration_prismatic.py │ │ ├── modeling_prismatic.py │ │ └── processing_prismatic.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── llm │ │ │ ├── __init__.py │ │ │ ├── base_llm.py │ │ │ ├── llama2.py │ │ │ ├── mistral.py │ │ │ ├── phi.py │ │ │ └── prompting │ │ │ │ ├── __init__.py │ │ │ │ ├── base_prompter.py │ │ │ │ ├── llama2_chat_prompter.py │ │ │ │ ├── mistral_instruct_prompter.py │ │ │ │ ├── phi_prompter.py │ │ │ │ └── vicuna_v15_prompter.py │ │ └── vision │ │ │ ├── __init__.py │ │ │ ├── base_vision.py │ │ │ ├── clip_vit.py │ │ │ ├── dinoclip_vit.py │ │ │ ├── dinosiglip_vit.py │ │ │ ├── dinov2_vit.py │ │ │ ├── in1k_vit.py │ │ │ └── siglip_vit.py │ ├── load.py │ ├── materialize.py │ ├── policy │ │ └── transformer_utils.py │ ├── registry.py │ ├── vlas │ │ ├── __init__.py │ │ └── openvla.py │ └── vlms │ │ ├── __init__.py │ │ ├── base_vlm.py │ │ └── prismatic.py ├── overwatch │ ├── __init__.py │ └── overwatch.py ├── preprocessing │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── datasets.py │ ├── download.py │ └── materialize.py ├── py.typed ├── training │ ├── __init__.py │ ├── materialize.py │ ├── metrics.py │ └── strategies │ │ ├── __init__.py │ │ ├── base_strategy.py │ │ ├── ddp.py │ │ └── fsdp.py ├── util │ ├── __init__.py │ ├── batching_utils.py │ ├── data_utils.py │ ├── nn_utils.py │ └── torch_utils.py └── vla │ ├── __init__.py │ ├── action_tokenizer.py │ ├── datasets │ ├── __init__.py │ ├── datasets.py │ ├── real_world_dataset.py │ └── rlds │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── obs_transforms.py │ │ ├── oxe │ │ ├── __init__.py │ │ ├── configs.py │ │ ├── materialize.py │ │ ├── mixtures.py │ │ ├── transforms.py │ │ └── utils │ │ │ └── droid_utils.py │ │ ├── traj_transforms.py │ │ └── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── goal_relabeling.py │ │ └── task_augmentation.py │ └── materialize.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── vla-scripts ├── finetune_libero.py ├── finetune_realworld.py ├── real_world_deployment.py ├── train.py └── train.sh /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., 
cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | **/libero_log/ 7 | **/eval_logs/ 8 | vla_log/ 9 | logs/ 10 | wandb/ 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # Ruff 137 | .ruff_cache/ 138 | 139 | # Auth Tokens / Hidden Files 140 | .hf_token 141 | .wandb_api_key 142 | .*_token 143 | .*api_key 144 | 145 | # IDE Caches 146 | .idea/ 147 | .vscode/ 148 | 149 | # Mac OS 150 | .DS_Store 151 | 152 | # Caches and Datasets 153 | cache/ 154 | data/ -------------------------------------------------------------------------------- /assets/real-world-exp_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/assets/real-world-exp_1.png -------------------------------------------------------------------------------- /assets/teaser_univla.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/assets/teaser_univla.png -------------------------------------------------------------------------------- /docs/real-world-deployment.md: -------------------------------------------------------------------------------- 1 | ## Real-World Deployment 2 | 3 | ### Finetune 4 | 5 | #### Data Pre-processing 6 | 7 | 8 | > [!TIP] 9 | > If your datasets are in LeRobot format, you can convert them using the LeRobot2RLDS tool from [Any4Lerobot](https://github.com/Tavish9/any4lerobot). After conversion, you can load the data using the same pipeline we employed for LIBERO training. 
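Before running the provided conversion script, you may want to sanity-check a recorded episode. Below is a minimal sketch of inspecting one hdf5 episode with `h5py`, assuming the key layout documented just below; the file name `episode_0.hdf5` is a hypothetical example, so adapt the path to your own recordings.

```python
# Minimal inspection sketch -- "episode_0.hdf5" is a hypothetical file name;
# the keys follow the layout described in this section (action / observations/qpos / observations/images/cam_high).
import h5py

with h5py.File("episode_0.hdf5", "r") as f:
    actions = f["action"][:]                      # [episode_len, 7]
    qpos = f["observations/qpos"][:]              # [episode_len, 7] joint angles
    cam_high = f["observations/images/cam_high"]  # [episode_len, 480, 640, 3] RGB frames
    print(actions.shape, qpos.shape, cam_high.shape)
```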
10 | 11 | - Here we provide a script to process hdf5 files from an Agilex arm: ```/prismatic/vla/datasets/real_world_dataset.py``` 12 | - The data is structured as follows: 13 | 14 | ``` 15 | ├── action 16 | ├── observations 17 | │   ├── images 18 | │   │   ├── cam_high 19 | │   ├── qpos 20 | ``` 21 | 22 | The keys above are described as follows: 23 | 24 | | Keys     | Description                                               | Shape                       | 25 | | -------- | --------------------------------------------------------- | --------------------------- | 26 | | action   | Collected real-world action data                          | [episode_len, 7]            | 27 | | cam_high | Image frames captured by the RGB camera named 'cam_high'  | [episode_len, 480, 640, 3]  | 28 | | qpos     | Proprioceptive data of the robot arm (joint angles)       | [episode_len, 7]            | 29 | 30 | 31 | 32 | #### Training 33 | 34 | - ```./vla-scripts/finetune_realworld.py```: Finetune UniVLA on real-world data 35 | 36 | 37 | 38 | ```bash 39 | torchrun --standalone --nnodes 1 --nproc-per-node 8 finetune_realworld.py \ 40 |   --batch_size 4 \                  # Adjust based on your compute setup 41 |   --grad_accumulation_steps 2 \     # Gradient accumulation for a larger effective batch size 42 |   --max_steps 10000 \               # Number of training steps; adjust based on your data volume 43 |   --save_steps 2500 \               # Step interval for saving intermediate checkpoints 44 |   --window_size 10 \                # Frame interval for the LAM and the action chunk size; adjust based on your data frequency 45 |   --run_root_dir "./real-world-log" 46 | ``` 47 | 48 | 49 | 50 | ### Inference 51 | 52 | Once you have finished fine-tuning UniVLA and obtained an action decoder head tailored to your embodiment's action space, you can deploy it and see how it works! 53 | 54 | Here we provide a deployment example, ```./vla-scripts/real_world_deployment.py```, which centers on the `UniVLAInference` class. 55 | 56 | 57 | > [!NOTE] 58 | > Because deployment code differs across embodiments, we present a general example below for reference. Action chunking is also implemented within the `UniVLAInference` class. 59 | 60 | ```python 61 | # Instantiate the UniVLA policy 62 | policy = UniVLAInference(saved_model_path=saved_model_path, pred_action_horizon=12, decoder_path=decoder_path) 63 | 64 | # curr_image is read from a camera 65 | resized_curr_image = torchvision.transforms.Resize((224, 224))(torch.flip(curr_image[0],(1,)))  # Resize + BGR2RGB (skip the flip if your camera already outputs RGB) 66 | 67 | # Create example inputs 68 | task_instruction = 'Store the screwdriver' 69 | proprio = torch.zeros((1,7)) 70 | 71 | # Sample actions 72 | all_actions = policy.step(resized_curr_image, task_instruction, proprio) 73 | ``` 74 | 75 | -------------------------------------------------------------------------------- /experiments/robot/calvin/calvin_env_wrapper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Any, Dict, Tuple, Union 4 | 5 | import gym 6 | import numpy as np 7 | import torch 8 | 9 | from calvin_env.envs.play_table_env import get_env 10 | from calvin_env.utils.utils import EglDeviceNotFoundError, get_egl_device_id 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class CalvinEnvWrapperRaw(gym.Wrapper): 16 | def __init__(self, abs_datasets_dir, observation_space, device, show_gui=False, **kwargs): 17 | """Environment wrapper which returns raw observations. 
18 | 19 | Args: 20 | abs_datasets_dir: absolute datset directory 21 | observation_sapce: {'rgb_obs': ['rgb_static', 'rgb_gripper'], 'depth_obs': [], 'state_obs': ['robot_obs'], 'actions': ['rel_actions'], 'language': ['language']} 22 | """ 23 | self.set_egl_device(device) 24 | env = get_env( 25 | abs_datasets_dir, show_gui=show_gui, obs_space=observation_space, **kwargs 26 | ) 27 | super(CalvinEnvWrapperRaw, self).__init__(env) 28 | self.observation_space_keys = observation_space 29 | self.device = device 30 | self.relative_actions = "rel_actions" in self.observation_space_keys["actions"] 31 | logger.info(f"Initialized PlayTableEnv for device {self.device}") 32 | 33 | @staticmethod 34 | def set_egl_device(device): 35 | if "EGL_VISIBLE_DEVICES" in os.environ: 36 | logger.warning("Environment variable EGL_VISIBLE_DEVICES is already set. Is this intended?") 37 | # modified: cuda_id = device.index if device.type == "cuda" else 0 38 | cuda_id = torch.cuda.current_device() 39 | try: 40 | egl_id = get_egl_device_id(cuda_id) 41 | except EglDeviceNotFoundError: 42 | logger.warning( 43 | "Couldn't find correct EGL device. Setting EGL_VISIBLE_DEVICE=0. " 44 | "When using DDP with many GPUs this can lead to OOM errors. " 45 | "Did you install PyBullet correctly? Please refer to calvin env README" 46 | ) 47 | egl_id = 0 48 | os.environ["EGL_VISIBLE_DEVICES"] = str(egl_id) 49 | logger.info(f"EGL_DEVICE_ID {egl_id} <==> CUDA_DEVICE_ID {cuda_id}") 50 | 51 | def step( 52 | self, action_tensor: torch.Tensor 53 | ) -> Tuple[Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]], int, bool, Dict]: 54 | if self.relative_actions: 55 | action = action_tensor#.squeeze().cpu().detach().numpy() 56 | assert len(action) == 7 57 | else: 58 | if action_tensor.shape[-1] == 7: 59 | slice_ids = [3, 6] 60 | elif action_tensor.shape[-1] == 8: 61 | slice_ids = [3, 7] 62 | else: 63 | logger.error("actions are required to have length 8 (for euler angles) or 9 (for quaternions)") 64 | raise NotImplementedError 65 | action = np.split(action_tensor, slice_ids) 66 | o, r, d, i = self.env.step(action) 67 | 68 | obs = o # use raw observation 69 | return obs, r, d, i 70 | 71 | def reset( 72 | self, 73 | reset_info: Dict[str, Any] = None, 74 | batch_idx: int = 0, 75 | seq_idx: int = 0, 76 | scene_obs: Any = None, 77 | robot_obs: Any = None, 78 | ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: 79 | if reset_info is not None: 80 | obs = self.env.reset( 81 | robot_obs=reset_info["robot_obs"][batch_idx, seq_idx], 82 | scene_obs=reset_info["scene_obs"][batch_idx, seq_idx], 83 | ) 84 | elif scene_obs is not None or robot_obs is not None: 85 | obs = self.env.reset(scene_obs=scene_obs, robot_obs=robot_obs) 86 | else: 87 | obs = self.env.reset() 88 | 89 | return obs # use raw observation 90 | 91 | def get_info(self): 92 | return self.env.get_info() 93 | 94 | def get_obs(self): 95 | obs = self.env.get_obs() 96 | return obs # use raw observation -------------------------------------------------------------------------------- /experiments/robot/calvin/calvin_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torchvision.transforms as T 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from PIL import Image 7 | from einops import rearrange 8 | import time 9 | 10 | import torch.nn as nn 11 | 12 | from calvin_agent.models.calvin_base_model import CalvinBaseModel 13 | 14 | from experiments.robot.robot_utils import ( 15 | DATE_TIME, 16 | 
get_latent_action, 17 | get_image_resize_size, 18 | get_model, 19 | invert_gripper_action, 20 | normalize_gripper_action, 21 | set_seed_everywhere, 22 | ) 23 | from experiments.robot.openvla_utils import get_processor 24 | from prismatic.models.policy.transformer_utils import MAPBlock 25 | 26 | 27 | 28 | class ActionDecoder(torch.nn.Module): 29 | def __init__(self, window_size = 5, hidden_dim=512): 30 | super().__init__() 31 | self.latent_action_pool = MAPBlock(n_latents = 1, vis_dim = 4096, embed_dim = hidden_dim, n_heads = hidden_dim//64) 32 | self.visual_pool = MAPBlock(n_latents = 1, vis_dim = 4096, embed_dim = hidden_dim, n_heads = hidden_dim//64) 33 | 34 | 35 | self.proj = nn.Sequential( 36 | nn.Linear(hidden_dim, 7 * window_size), 37 | nn.Tanh(), 38 | ) 39 | 40 | def forward(self, latent_action_tokens, visual_embed): 41 | latent_action_tokens = latent_action_tokens[:, -4:] 42 | visual_embed = self.visual_pool(visual_embed) 43 | action = self.proj(self.latent_action_pool(latent_action_tokens, init_embed=visual_embed)) 44 | 45 | return action 46 | 47 | 48 | 49 | class ActionDecoderWrapper(nn.Module): 50 | def __init__(self, window_size=5): 51 | super().__init__() 52 | self.net = ActionDecoder(window_size) 53 | 54 | self.temporal_size = 12 55 | self.temporal_mask = torch.flip(torch.triu(torch.ones(self.temporal_size, self.temporal_size, dtype=torch.bool)), dims=[1]).numpy() 56 | 57 | self.action_buffer = np.zeros((self.temporal_mask.shape[0], self.temporal_mask.shape[0], 7)) 58 | self.action_buffer_mask = np.zeros((self.temporal_mask.shape[0], self.temporal_mask.shape[0]), dtype=np.bool_) 59 | 60 | # Action chunking with temporal aggregation 61 | balancing_factor = 0.01 62 | self.temporal_weights = np.array([np.exp(-1 * balancing_factor * i) for i in range(self.temporal_size)])[:, None] 63 | 64 | 65 | def reset(self): 66 | self.action_buffer = np.zeros((self.temporal_mask.shape[0], self.temporal_mask.shape[0], 7)) 67 | self.action_buffer_mask = np.zeros((self.temporal_mask.shape[0], self.temporal_mask.shape[0]), dtype=np.bool_) 68 | 69 | 70 | def forward(self, latent_actions, visual_embed): 71 | # Forward action decoder 72 | pred_action = self.net(latent_actions.to(torch.float), visual_embed.to(torch.float)).reshape(-1, 12, 7) 73 | pred_action = np.array(pred_action.tolist())[0] 74 | 75 | return pred_action[0] 76 | 77 | 78 | class WrappedModel(torch.nn.Module): 79 | def __init__(self, cfg): 80 | super().__init__() 81 | 82 | # Load action decoder 83 | self.action_decoder = ActionDecoderWrapper(cfg.window_size) 84 | self.action_decoder.net.load_state_dict(torch.load(cfg.action_decoder_path)) 85 | 86 | # Load VLA 87 | self.vla = get_model(cfg) 88 | 89 | 90 | 91 | class WrappedCalvinEvaluation(CalvinBaseModel): 92 | def __init__(self, cfg, wrapped_model): 93 | super().__init__() 94 | self.cfg = cfg 95 | 96 | self.model = wrapped_model 97 | # [OpenVLA] Get Hugging Face processor 98 | self.processor = get_processor(cfg) 99 | self.prev_hist_action = [''] 100 | 101 | 102 | 103 | def reset(self,): 104 | """ 105 | Called at the start of each rollout; resets the action buffer and the action history. 106 | """ 107 | self.model.module.action_decoder.reset() 108 | self.prev_hist_action = [''] 109 | 110 | 111 | def step(self, obs, instruction, step): 112 | """ 113 | Args: 114 | obs: environment observations 115 | instruction: language task instruction 116 | Returns: 117 | action: predicted action 118 | """ 119 | img = obs["rgb_obs"]['rgb_static'] 120 | 121 | observation = { 122 | "full_image": img, 123 | "state": [], 124 | } 125 | 126 | # Query model to get latent action 127 | 
latent_action, visual_embed, generated_ids = get_latent_action( 128 | self.cfg, 129 | self.model.module.vla, 130 | observation, 131 | instruction, 132 | processor=self.processor, 133 | ) 134 | 135 | # Get decoded action 136 | action = self.model.module.action_decoder(latent_action, visual_embed) 137 | 138 | return action -------------------------------------------------------------------------------- /experiments/robot/libero/libero_requirements.txt: -------------------------------------------------------------------------------- 1 | imageio[ffmpeg] 2 | robosuite==1.4.1 3 | bddl 4 | easydict 5 | cloudpickle 6 | gym 7 | -------------------------------------------------------------------------------- /experiments/robot/libero/libero_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating policies in LIBERO simulation environments.""" 2 | 3 | import math 4 | import os 5 | 6 | import imageio 7 | import numpy as np 8 | import tensorflow as tf 9 | from libero.libero import get_libero_path 10 | from libero.libero.envs import OffScreenRenderEnv 11 | 12 | from experiments.robot.robot_utils import ( 13 | DATE, 14 | DATE_TIME, 15 | ) 16 | 17 | 18 | def get_libero_env(task, model_family, resolution=256): 19 | """Initializes and returns the LIBERO environment, along with the task description.""" 20 | task_description = task.language 21 | task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) 22 | env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} 23 | env = OffScreenRenderEnv(**env_args) 24 | env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state 25 | return env, task_description 26 | 27 | 28 | def get_libero_dummy_action(model_family: str): 29 | """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" 30 | return [0, 0, 0, 0, 0, 0, -1] 31 | 32 | 33 | def resize_image(img, resize_size): 34 | """ 35 | Takes numpy array corresponding to a single image and returns resized image as numpy array. 36 | 37 | NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow 38 | the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training. 
39 | """ 40 | assert isinstance(resize_size, tuple) 41 | # Resize to image size expected by model 42 | img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder 43 | img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back 44 | img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True) 45 | img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8) 46 | img = img.numpy() 47 | return img 48 | 49 | 50 | def get_libero_image(obs, resize_size): 51 | """Extracts image from observations and preprocesses it.""" 52 | assert isinstance(resize_size, int) or isinstance(resize_size, tuple) 53 | if isinstance(resize_size, int): 54 | resize_size = (resize_size, resize_size) 55 | img = obs["agentview_image"] 56 | img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing 57 | img = resize_image(img, resize_size) 58 | return img 59 | 60 | 61 | def save_rollout_video(rollout_images, idx, success, task_description, log_file=None): 62 | """Saves an MP4 replay of an episode.""" 63 | rollout_dir = f"./rollouts/{DATE}" 64 | os.makedirs(rollout_dir, exist_ok=True) 65 | processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] 66 | mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4" 67 | video_writer = imageio.get_writer(mp4_path, fps=30) 68 | for img in rollout_images: 69 | video_writer.append_data(img) 70 | video_writer.close() 71 | print(f"Saved rollout MP4 at path {mp4_path}") 72 | if log_file is not None: 73 | log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") 74 | return mp4_path 75 | 76 | 77 | def quat2axisangle(quat): 78 | """ 79 | Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 80 | 81 | Converts quaternion to axis-angle format. 82 | Returns a unit vector direction scaled by its angle in radians. 83 | 84 | Args: 85 | quat (np.array): (x,y,z,w) vec4 float angles 86 | 87 | Returns: 88 | np.array: (ax,ay,az) axis-angle exponential coordinates 89 | """ 90 | # clip quaternion 91 | if quat[3] > 1.0: 92 | quat[3] = 1.0 93 | elif quat[3] < -1.0: 94 | quat[3] = -1.0 95 | 96 | den = np.sqrt(1.0 - quat[3] * quat[3]) 97 | if math.isclose(den, 0.0): 98 | # This is (close to) a zero degree rotation, immediately return 99 | return np.zeros(3) 100 | 101 | return (quat[:3] * 2.0 * math.acos(quat[3])) / den 102 | -------------------------------------------------------------------------------- /experiments/robot/robot_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating robot policies in various environments.""" 2 | 3 | import os 4 | import random 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from experiments.robot.openvla_utils import ( 11 | get_vla, 12 | get_vla_action, 13 | get_vla_latent_action, 14 | ) 15 | 16 | # Initialize important constants and pretty-printing mode in NumPy. 17 | ACTION_DIM = 7 18 | DATE = time.strftime("%Y_%m_%d") 19 | DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") 20 | DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 21 | np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) 22 | 23 | # Initialize system prompt for OpenVLA v0.1. 
24 | OPENVLA_V01_SYSTEM_PROMPT = ( 25 | "A chat between a curious user and an artificial intelligence assistant. " 26 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 27 | ) 28 | 29 | 30 | def set_seed_everywhere(seed: int): 31 | """Sets the random seed for Python, NumPy, and PyTorch functions.""" 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | np.random.seed(seed) 35 | random.seed(seed) 36 | torch.backends.cudnn.deterministic = True 37 | torch.backends.cudnn.benchmark = False 38 | os.environ["PYTHONHASHSEED"] = str(seed) 39 | 40 | 41 | def get_model(cfg, wrap_diffusion_policy_for_droid=False): 42 | """Load model for evaluation.""" 43 | if cfg.model_family == "openvla": 44 | model = get_vla(cfg) 45 | else: 46 | raise ValueError("Unexpected `model_family` found in config.") 47 | print(f"Loaded model: {type(model)}") 48 | return model 49 | 50 | 51 | def get_image_resize_size(cfg): 52 | """ 53 | Gets image resize size for a model class. 54 | If `resize_size` is an int, then the resized image will be a square. 55 | Else, the image will be a rectangle. 56 | """ 57 | if cfg.model_family == "openvla": 58 | resize_size = 224 59 | else: 60 | raise ValueError("Unexpected `model_family` found in config.") 61 | return resize_size 62 | 63 | 64 | def get_action(cfg, model, obs, task_label, processor=None): 65 | """Queries the model to get an action.""" 66 | if cfg.model_family == "openvla": 67 | action = get_vla_action( 68 | model, processor, cfg.pretrained_checkpoint, obs, task_label, cfg.unnorm_key, center_crop=cfg.center_crop, 69 | ) 70 | assert action.shape == (ACTION_DIM,) 71 | else: 72 | raise ValueError("Unexpected `model_family` found in config.") 73 | return action 74 | 75 | 76 | def get_latent_action(cfg, model, obs, task_label, processor=None, hist_action=''): 77 | """Queries the model to get an action.""" 78 | latent_action = get_vla_latent_action( 79 | model, processor, cfg.pretrained_checkpoint, obs, task_label, cfg.unnorm_key, center_crop=cfg.center_crop, hist_action=hist_action, 80 | ) 81 | 82 | return latent_action 83 | 84 | 85 | def normalize_gripper_action(action, binarize=True): 86 | """ 87 | Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1]. 88 | Necessary for some environments (not Bridge) because the dataset wrapper standardizes gripper actions to [0,1]. 89 | Note that unlike the other action dimensions, the gripper action is not normalized to [-1,+1] by default by 90 | the dataset wrapper. 91 | 92 | Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1 93 | """ 94 | # Just normalize the last action to [-1,+1]. 95 | orig_low, orig_high = 0.0, 1.0 96 | action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1 97 | 98 | if binarize: 99 | # Binarize to -1 or +1. 100 | action[..., -1] = np.sign(action[..., -1]) 101 | 102 | return action 103 | 104 | 105 | def invert_gripper_action(action): 106 | """ 107 | Flips the sign of the gripper action (last dimension of action vector). 108 | This is necessary for some environments where -1 = open, +1 = close, since 109 | the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open. 
110 | """ 111 | action[..., -1] = action[..., -1] * -1.0 112 | return action 113 | -------------------------------------------------------------------------------- /latent_action_model/config/lam-stage-1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | image_channels: 3 3 | 4 | lam_model_dim: 768 5 | lam_latent_dim: 128 6 | lam_num_latents: 16 7 | lam_patch_size: 14 8 | lam_enc_blocks: 12 9 | lam_dec_blocks: 12 10 | lam_num_heads: 12 11 | 12 | vq_beta: 0.25 13 | log_interval: 5000 14 | log_path: ./logs 15 | optimizer: 16 | class_path: torch.optim.AdamW 17 | init_args: 18 | lr: 1e-4 19 | weight_decay: 1e-2 20 | 21 | task_name: task_centric_lam_stage1 22 | make_data_pair: &make_data_pair false 23 | stage: stage-1 24 | 25 | data: 26 | data_root: /path/to/your/rlds_data_collection 27 | data_mix: omni_magic_soup_plus_plus # Manip. + Navi. + Human 28 | batch_size: 64 29 | resolution: 224 30 | num_frames: 16 # TODO 31 | episodic: false 32 | shuffle_buffer_size: 45000 # works fine for 1,600 GB memories, plz adjust based on yout setup 33 | image_aug: true 34 | 35 | trainer: 36 | max_epochs: 20 37 | accelerator: gpu 38 | num_nodes: 1 39 | devices: 8 40 | strategy: ddp_find_unused_parameters_false 41 | precision: 16-mixed 42 | log_every_n_steps: 1000 43 | gradient_clip_val: 0.1 44 | 45 | callbacks: 46 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 47 | init_args: 48 | dirpath: ./logs/task_centric_lam_stage1 49 | verbose: true 50 | save_last: true 51 | save_top_k: -1 52 | every_n_train_steps: 20000 53 | 54 | logger: 55 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 56 | init_args: 57 | save_dir: ./logs 58 | name: task_centric_lam_stage1 -------------------------------------------------------------------------------- /latent_action_model/config/lam-stage-2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | image_channels: 3 3 | 4 | lam_model_dim: 768 5 | lam_latent_dim: 128 6 | lam_num_latents: 16 7 | lam_patch_size: 14 8 | lam_enc_blocks: 12 9 | lam_dec_blocks: 12 10 | lam_num_heads: 12 11 | 12 | vq_beta: 0.25 13 | log_interval: 5000 14 | log_path: ./logs 15 | optimizer: 16 | class_path: torch.optim.AdamW 17 | init_args: 18 | lr: 1e-4 19 | weight_decay: 1e-2 20 | 21 | task_name: task_centric_lam_stage2 22 | make_data_pair: &make_data_pair false 23 | stage: stage-2 24 | stage_one_ckpt: ./logs/task_centric_lam_stage1/epoch=0-step=100000.ckpt 25 | 26 | data: 27 | data_root: /path/to/your/rlds_data_collection 28 | data_mix: omni_magic_soup_plus_plus # Manip. + Navi. 
+ Human 29 | batch_size: 64 30 | resolution: 224 31 | num_frames: 16 # TODO 32 | episodic: false 33 | shuffle_buffer_size: 45000 # works fine for 1,600 GB memories, plz adjust based on yout setup 34 | image_aug: true 35 | 36 | trainer: 37 | max_epochs: 20 38 | accelerator: gpu 39 | num_nodes: 1 40 | devices: 8 41 | strategy: ddp_find_unused_parameters_false 42 | precision: 16-mixed 43 | log_every_n_steps: 1000 44 | gradient_clip_val: 0.1 45 | 46 | callbacks: 47 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 48 | init_args: 49 | dirpath: ./logs/task_centric_lam_stage2 50 | verbose: true 51 | save_last: true 52 | save_top_k: -1 53 | every_n_train_steps: 20000 54 | 55 | logger: 56 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 57 | init_args: 58 | save_dir: ./logs 59 | name: task_centric_lam_stage2 -------------------------------------------------------------------------------- /latent_action_model/genie/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | from os import listdir, makedirs, path 3 | from random import choices, randint 4 | from typing import Any, Callable, Dict 5 | 6 | import cv2 as cv 7 | import torch 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | from lightning import LightningDataModule 11 | from torch import Tensor 12 | from torch.utils.data import DataLoader, Dataset, IterableDataset 13 | from torch.utils.data import get_worker_info 14 | import torchvision.transforms as transforms 15 | from dataclasses import dataclass 16 | 17 | from prismatic.util import set_global_seed 18 | from prismatic.util.data_utils import CollatorForLatentAction, CollatorForMultiViewVideo 19 | from prismatic.vla.datasets import RLDSDataset, EpisodicRLDSDataset, RLDSBatchTransformVideo 20 | 21 | 22 | 23 | def exists(var) -> bool: 24 | return var is not None 25 | 26 | 27 | def default(var, val) -> Any: 28 | return var if exists(var) else val 29 | 30 | 31 | def default_worker_init_fn(worker_id: int) -> None: 32 | torch.manual_seed(torch.initial_seed() + worker_id) 33 | worker_info = get_worker_info() 34 | 35 | if exists(worker_info): 36 | dataset = worker_info.dataset 37 | glob_start = dataset._start 38 | glob_end = dataset._end 39 | 40 | per_worker = int((glob_end - glob_start) / worker_info.num_workers) 41 | worker_id = worker_info.id 42 | 43 | dataset._start = glob_start + worker_id * per_worker 44 | dataset._end = min(dataset._start + per_worker, glob_end) 45 | 46 | 47 | class LightningDataset(LightningDataModule): 48 | """ 49 | Abstract LightningDataModule that represents a dataset we can train a Lightning module on. 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | *args, 55 | batch_size: int = 8, 56 | num_workers: int = 64, 57 | train_shuffle: bool = True, 58 | val_shuffle: bool = False, 59 | val_batch_size: int = None, 60 | worker_init_fn: Callable = None, 61 | collate_fn: Callable = None, 62 | train_sampler: Callable = None, 63 | test_sampler: Callable = None, 64 | val_sampler: Callable = None 65 | ) -> None: 66 | super(LightningDataset, self).__init__() 67 | self.train_dataset = None 68 | self.test_dataset = None 69 | self.val_dataset = None 70 | 71 | val_batch_size = default(val_batch_size, batch_size) 72 | 73 | self.num_workers = 0 # For RLDS parallelism 74 | self.batch_size = batch_size 75 | self.val_batch_size = val_batch_size 76 | 77 | # shuffle unspecified for iteratable datasets 78 | # self.train_shuffle = train_shuffle 79 | # self.val_shuffle = val_shuffle 80 | 81 | self.train_sampler = train_sampler 82 | self.test_sampler = test_sampler 83 | self.val_sampler = val_sampler 84 | self.collate_fn = collate_fn 85 | self.worker_init_fn = worker_init_fn 86 | 87 | def train_dataloader(self) -> DataLoader: 88 | if isinstance(self.train_dataset, IterableDataset): 89 | worker_init_fn = default(self.worker_init_fn, default_worker_init_fn) 90 | else: 91 | worker_init_fn = self.worker_init_fn 92 | return DataLoader( 93 | self.train_dataset, 94 | sampler=self.train_sampler, 95 | batch_size=self.batch_size, 96 | # shuffle=self.train_shuffle, 97 | collate_fn=self.collate_fn, 98 | num_workers=self.num_workers, 99 | worker_init_fn=worker_init_fn 100 | ) 101 | 102 | def val_dataloader(self) -> DataLoader: 103 | if isinstance(self.val_dataset, IterableDataset): 104 | worker_init_fn = default(self.worker_init_fn, default_worker_init_fn) 105 | else: 106 | worker_init_fn = self.worker_init_fn 107 | return DataLoader( 108 | self.val_dataset, 109 | sampler=self.val_sampler, 110 | batch_size=self.val_batch_size, 111 | # shuffle=self.val_shuffle, 112 | collate_fn=self.collate_fn, 113 | num_workers=self.num_workers, 114 | worker_init_fn=worker_init_fn 115 | ) 116 | 117 | def test_dataloader(self) -> DataLoader: 118 | if isinstance(self.test_dataset, IterableDataset): 119 | worker_init_fn = default(self.worker_init_fn, default_worker_init_fn) 120 | else: 121 | worker_init_fn = self.worker_init_fn 122 | return DataLoader( 123 | self.test_dataset, 124 | sampler=self.test_sampler, 125 | batch_size=self.val_batch_size, 126 | # shuffle=self.val_shuffle, 127 | collate_fn=self.collate_fn, 128 | num_workers=self.num_workers, 129 | worker_init_fn=worker_init_fn 130 | ) 131 | 132 | 133 | 134 | from PIL import Image 135 | import random 136 | 137 | @dataclass 138 | class random_crop_resize(): 139 | def __init__( 140 | self, 141 | target_size=224 142 | ): 143 | self.target_size = target_size 144 | self.to_tensor = transforms.ToTensor() 145 | 146 | def __call__(self, image): 147 | width, height = image.size 148 | 149 | if width < height: 150 | crop_size = width 151 | else: 152 | crop_size = height 153 | 154 | left = random.randint(0, width - crop_size) 155 | top = random.randint(0, height - crop_size) 156 | 157 | image_cropped = image.crop((left, top, left + crop_size, top + crop_size)) 158 | image_resized = image_cropped.resize((self.target_size, self.target_size), Image.BILINEAR) 159 | image_resized = self.to_tensor(image_resized) 160 | 161 | return image_resized 162 | 163 | 164 | 165 | class LightningOpenX(LightningDataset): 166 | """ 167 | This dataset samples video recorded using a random agent 168 | playing the gym 
environments defined in the Procgen Benchmark, 169 | see Cobbe et al. ICML (2020). 170 | """ 171 | 172 | def __init__( 173 | self, 174 | data_root: str, 175 | data_mix: str, 176 | batch_size:int = 16, 177 | resolution: int = 256, 178 | num_frames: int = 16, 179 | episodic: bool = False, 180 | shuffle_buffer_size: int = 100_000, 181 | image_aug:bool = False, 182 | **kwargs 183 | ) -> None: 184 | super(LightningOpenX, self).__init__(**kwargs) 185 | 186 | self.data_root_dir = data_root 187 | self.data_mix = data_mix 188 | 189 | self.batch_size = batch_size 190 | self.resolution = (resolution, resolution) 191 | self.num_frames = num_frames 192 | 193 | self.episodic = episodic 194 | self.shuffle_buffer_size = shuffle_buffer_size 195 | self.image_aug = image_aug 196 | 197 | self.num_workers = 0 # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism! 198 | self.worker_init_fn = set_global_seed(42, get_worker_init_fn=True) 199 | 200 | self.batch_transform = RLDSBatchTransformVideo( 201 | image_transform=transforms.ToTensor() 202 | ) 203 | self.collate_fn = CollatorForLatentAction() 204 | 205 | self.save_hyperparameters() 206 | 207 | def setup(self, stage: str) -> None: 208 | cls = RLDSDataset if not self.episodic else EpisodicRLDSDataset 209 | if stage == "fit": 210 | self.train_dataset = cls( 211 | self.data_root_dir, 212 | self.data_mix, 213 | self.batch_transform, 214 | resize_resolution=self.resolution, 215 | shuffle_buffer_size=self.shuffle_buffer_size, 216 | train=True, 217 | image_aug=self.image_aug, 218 | training_phase='lam', 219 | ) 220 | self.val_dataset = cls( 221 | self.data_root_dir, 222 | self.data_mix, 223 | self.batch_transform, 224 | resize_resolution=self.resolution, 225 | shuffle_buffer_size=self.shuffle_buffer_size, 226 | train=False, 227 | image_aug=False, 228 | training_phase='lam', 229 | ) 230 | elif stage == "test": 231 | self.test_dataset = cls( 232 | self.data_root_dir, 233 | self.data_mix, 234 | self.batch_transform, 235 | resize_resolution=self.resolution, 236 | shuffle_buffer_size=self.shuffle_buffer_size, 237 | train=True, 238 | image_aug=False, 239 | training_phase='lam', 240 | ) 241 | else: 242 | raise ValueError(f"Invalid stage: {stage}") 243 | 244 | 245 | -------------------------------------------------------------------------------- /latent_action_model/genie/model.py: -------------------------------------------------------------------------------- 1 | from os import listdir, makedirs, path 2 | from typing import Callable, Dict, Iterable, Tuple 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import piq 7 | import torch 8 | import wandb 9 | from PIL import Image 10 | from einops import rearrange 11 | from lightning import LightningModule 12 | from torch import Tensor 13 | from torch.optim import AdamW, Optimizer 14 | from accelerate import PartialState 15 | 16 | OptimizerCallable = Callable[[Iterable], Optimizer] 17 | 18 | from genie.modules import UncontrolledDINOLatentActionModel, ControllableDINOLatentActionModel 19 | import logging 20 | logging.basicConfig(format='%(message)s', level=logging.INFO) 21 | 22 | 23 | 24 | class DINO_LAM(LightningModule): 25 | """ 26 | A latent action model operates at the DINO latent space 27 | """ 28 | 29 | def __init__( 30 | self, 31 | image_channels: int = 3, 32 | # Latent action model 33 | lam_model_dim: int = 512, 34 | lam_latent_dim: int = 32, 35 | lam_num_latents: int = 8, 36 | lam_patch_size: int = 16, 37 | lam_enc_blocks: int = 8, 38 | lam_dec_blocks: int = 8, 39 | lam_num_heads: int = 
8, 40 | lam_dropout: float = 0.0, 41 | vq_beta: float = 0.25, 42 | log_interval: int = 1000, 43 | log_path: str = "log_imgs", 44 | task_name: str = 'lam_openx', 45 | stage: str = 'stage-1', 46 | optimizer: OptimizerCallable = AdamW, 47 | make_data_pair: bool = False, 48 | stage_one_ckpt: str = None, 49 | ) -> None: 50 | super(DINO_LAM, self).__init__() 51 | assert stage in ['stage-1', 'stage-2'] 52 | 53 | lam = UncontrolledDINOLatentActionModel if stage == 'stage-1' else ControllableDINOLatentActionModel 54 | 55 | self.lam = lam( 56 | in_dim=image_channels, 57 | model_dim=lam_model_dim, 58 | latent_dim=lam_latent_dim, 59 | num_latents=lam_num_latents, 60 | patch_size=lam_patch_size, 61 | enc_blocks=lam_enc_blocks, 62 | dec_blocks=lam_dec_blocks, 63 | num_heads=lam_num_heads, 64 | dropout=lam_dropout, 65 | ) 66 | 67 | if stage_one_ckpt and path.exists(stage_one_ckpt): 68 | lam_ckpt = torch.load(stage_one_ckpt)['state_dict'] 69 | stage1_ckpt = {} 70 | for key in lam_ckpt.keys(): 71 | if 'vq' in key or 'action_latent' in key: 72 | stage1_ckpt[key.replace("lam.", "")] = lam_ckpt[key] 73 | self.lam.load_state_dict(stage1_ckpt, strict=False) 74 | 75 | 76 | self.lam_num_latents = lam_num_latents 77 | self.vq_beta = vq_beta 78 | self.log_interval = log_interval 79 | self.log_path = log_path 80 | self.optimizer = optimizer 81 | self.make_data_pair = make_data_pair 82 | 83 | self.save_hyperparameters() 84 | 85 | self.task_name = task_name 86 | self.distributed_state = PartialState() 87 | if self.distributed_state.is_main_process: 88 | wandb.init(name=task_name, reinit=True) 89 | 90 | def shared_step(self, batch: Dict) -> Tuple: 91 | # batch: keys['videos', 'task_instruction', 'action', 'dataset_names'] 92 | 93 | outputs = self.lam(batch) 94 | gt_future_frames = outputs["target"] 95 | 96 | # Compute loss 97 | mse_loss = ((gt_future_frames - outputs["recon"]) ** 2).mean() 98 | q_loss = ((outputs["emb"].detach() - outputs["z"]) ** 2).mean() 99 | commit_loss = ((outputs["emb"] - outputs["z"].detach()) ** 2).mean() 100 | 101 | loss = mse_loss + q_loss + self.vq_beta * commit_loss 102 | 103 | # Optimize uncontrollable queries in stage-2 (the codebook is frozen though) 104 | if "z_q_uncontrol" in outputs.keys(): 105 | q_loss_uncontrol = ((outputs["emb_uncontrol"].detach() - outputs["z_uncontrol"]) ** 2).mean() 106 | commit_loss_uncontrol = ((outputs["emb_uncontrol"]- outputs["z_uncontrol"].detach()) ** 2).mean() 107 | loss = loss + q_loss_uncontrol + self.vq_beta * commit_loss_uncontrol 108 | 109 | # Compute code usage 110 | unique, counts = torch.unique(outputs["indices"], return_counts=True) 111 | index_counts = torch.zeros(self.lam_num_latents, dtype=torch.long).cuda() 112 | index_counts[unique] = counts 113 | code_usage = (index_counts != 0).float().mean() 114 | 115 | loss_logs = ( 116 | ("mse_loss", mse_loss), 117 | ("q_loss", q_loss), 118 | ("commit_loss", commit_loss), 119 | ("code_usage", code_usage), 120 | ) 121 | 122 | if "indices_uncontrol" in outputs.keys(): 123 | unique, counts = torch.unique(outputs["indices_uncontrol"], return_counts=True) 124 | index_counts = torch.zeros(32, dtype=torch.long).cuda() 125 | index_counts[unique] = counts 126 | uncontrol_code_usage = (index_counts != 0).float().mean() 127 | 128 | loss_logs = ( 129 | ("mse_loss", mse_loss), 130 | ("q_loss", q_loss), 131 | ("commit_loss", commit_loss), 132 | ("q_loss_uncontrol", q_loss_uncontrol), 133 | ("commit_loss_uncontrol", commit_loss_uncontrol), 134 | ("code_usage", code_usage), 135 | ("code_usage_uncontrol", 
uncontrol_code_usage), 136 | ) 137 | 138 | return outputs, loss, loss_logs 139 | 140 | 141 | 142 | def training_step(self, batch: Dict, batch_idx: int) -> Tensor: 143 | # Compute the training loss 144 | outputs, loss, aux_losses = self.shared_step(batch) 145 | 146 | 147 | # Log the training loss 148 | self.log_dict( 149 | {**{"train_loss": loss}, **{f"train/{k}": v for k, v in aux_losses}}, 150 | prog_bar=True, 151 | logger=True, 152 | on_step=True, 153 | on_epoch=True, 154 | sync_dist=True 155 | ) 156 | 157 | if self.distributed_state.is_main_process: 158 | wandb.log({**{"train_loss": loss}, **{f"train/{k}": v for k, v in aux_losses}}) 159 | 160 | return loss 161 | 162 | 163 | @torch.no_grad() 164 | def test_step(self, batch: Dict, batch_idx: int) -> Tensor: 165 | # Compute the test loss 166 | outputs, loss, aux_losses = self.shared_step(batch) 167 | 168 | # Log the test loss 169 | self.log_dict( 170 | {**{"test_loss": loss}, **{f"test/{k}": v for k, v in aux_losses}}, 171 | prog_bar=True, 172 | logger=True, 173 | on_step=True, 174 | on_epoch=True, 175 | sync_dist=True 176 | ) 177 | 178 | return loss 179 | 180 | def on_train_epoch_end(self): 181 | self.lam.vq.random_restart() 182 | self.lam.vq.reset_usage() 183 | 184 | def on_test_epoch_end(self): 185 | if self.make_data_pair: 186 | completed = len(listdir("output_pairs")) 187 | todo_name = listdir("../data/retro")[completed] 188 | makedirs(f"output_pairs/{todo_name}") 189 | top_indices = torch.topk(self.lam.vq.usage, 16, largest=True, sorted=True).indices 190 | top_latents = self.lam.vq.codebook(top_indices) 191 | torch.save(top_latents, f"output_pairs/{todo_name}/top_16.pt") 192 | with open(f"output_pairs/{todo_name}/top_16.txt", "w") as f: 193 | f.write(" ".join([str(i) for i in top_indices.tolist()])) 194 | 195 | self.plot_usage_distribution(self.lam.vq.usage, "unsorted_usage") 196 | self.plot_usage_distribution(self.lam.vq.usage.sort().values, "sorted_usage") 197 | 198 | def plot_usage_distribution(self, usage, filename): 199 | data = usage.cpu().numpy() 200 | n = 1 201 | for n in range(1, 10): 202 | if (2 ** n) ** 2 <= len(data) < (2 ** (n + 1)) ** 2: 203 | break 204 | data = data.reshape(2 ** n, -1) 205 | fig, ax = plt.subplots() 206 | cax = ax.matshow(data, interpolation="nearest") 207 | fig.colorbar(cax) 208 | plt.axis("off") 209 | plt.gca().set_axis_off() 210 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 211 | plt.margins(0, 0) 212 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 213 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 214 | plt.savefig(f"{filename}.png", bbox_inches="tight", pad_inches=0.0) 215 | plt.close() 216 | 217 | def configure_optimizers(self) -> Optimizer: 218 | optim = self.optimizer(self.parameters()) 219 | return optim 220 | -------------------------------------------------------------------------------- /latent_action_model/genie/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from latent_action_model.genie.modules.lam import UncontrolledDINOLatentActionModel, ControllableDINOLatentActionModel 2 | -------------------------------------------------------------------------------- /latent_action_model/main.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.cli import LightningCLI 2 | from genie.dataset import LightningOpenX 3 | from genie.model import DINO_LAM 4 | 5 | cli = LightningCLI( 6 | DINO_LAM, 7 | LightningOpenX, 8 | 
seed_everything_default=42, 9 | ) 10 | -------------------------------------------------------------------------------- /latent_action_model/train.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nnodes 1 --nproc-per-node 8 main.py fit \ 2 | --config config/lam-stage-1.yaml \ 3 | 2>&1 | tee lam-stage-1.log -------------------------------------------------------------------------------- /prismatic/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import available_model_names, available_models, get_model_description, load 2 | 3 | 4 | 5 | __version__ = "0.0.1" 6 | __project__ = "OmniEmbodiment" 7 | __author__ = "Qingwen Bu" 8 | __license__ = "Apache License 2.0" 9 | __email__ = "qwbu01@sjtu.edu.cn" -------------------------------------------------------------------------------- /prismatic/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DatasetConfig, DatasetRegistry 2 | from .models import ModelConfig, ModelRegistry 3 | from .vla import VLAConfig, VLARegistry 4 | -------------------------------------------------------------------------------- /prismatic/conf/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | datasets.py 3 | 4 | Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant 5 | and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes: 6 | - Dataset Variant (Identifier) --> e.g., "llava-v15" 7 | - Align Stage Dataset Components (annotations, images) 8 | - Finetune Stage Dataset Components (annotations, images) 9 | - Dataset Root Directory (Path) 10 | """ 11 | 12 | from dataclasses import dataclass 13 | from enum import Enum, unique 14 | from pathlib import Path 15 | from typing import Tuple 16 | 17 | from draccus import ChoiceRegistry 18 | 19 | 20 | @dataclass 21 | class DatasetConfig(ChoiceRegistry): 22 | # fmt: off 23 | dataset_id: str # Unique ID that fully specifies a dataset variant 24 | 25 | # Dataset Components for each Stage in < align | finetune > 26 | align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage 27 | finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage 28 | 29 | dataset_root_dir: Path # Path to dataset root directory; others paths are relative to root 30 | # fmt: on 31 | 32 | 33 | # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models) 34 | @dataclass 35 | class LLaVa_V15_Config(DatasetConfig): 36 | dataset_id: str = "llava-v15" 37 | 38 | align_stage_components: Tuple[Path, Path] = ( 39 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 40 | Path("download/llava-laion-cc-sbu-558k/"), 41 | ) 42 | finetune_stage_components: Tuple[Path, Path] = ( 43 | Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"), 44 | Path("download/llava-v1.5-instruct/"), 45 | ) 46 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 47 | 48 | 49 | # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training) 50 | @dataclass 51 | class LLaVa_Multimodal_Only_Config(DatasetConfig): 52 | dataset_id: str = "llava-multimodal" 53 | 54 | align_stage_components: Tuple[Path, Path] = ( 55 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 56 
| Path("download/llava-laion-cc-sbu-558k/"), 57 | ) 58 | finetune_stage_components: Tuple[Path, Path] = ( 59 | Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"), 60 | Path("download/llava-v1.5-instruct/"), 61 | ) 62 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 63 | 64 | 65 | # LLaVa-v15 + LVIS-Instruct-4V 66 | @dataclass 67 | class LLaVa_LVIS4V_Config(DatasetConfig): 68 | dataset_id: str = "llava-lvis4v" 69 | 70 | align_stage_components: Tuple[Path, Path] = ( 71 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 72 | Path("download/llava-laion-cc-sbu-558k/"), 73 | ) 74 | finetune_stage_components: Tuple[Path, Path] = ( 75 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"), 76 | Path("download/llava-v1.5-instruct/"), 77 | ) 78 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 79 | 80 | 81 | # LLaVa-v15 + LRV-Instruct 82 | @dataclass 83 | class LLaVa_LRV_Config(DatasetConfig): 84 | dataset_id: str = "llava-lrv" 85 | 86 | align_stage_components: Tuple[Path, Path] = ( 87 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 88 | Path("download/llava-laion-cc-sbu-558k/"), 89 | ) 90 | finetune_stage_components: Tuple[Path, Path] = ( 91 | Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"), 92 | Path("download/llava-v1.5-instruct/"), 93 | ) 94 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 95 | 96 | 97 | # LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct 98 | @dataclass 99 | class LLaVa_LVIS4V_LRV_Config(DatasetConfig): 100 | dataset_id: str = "llava-lvis4v-lrv" 101 | 102 | align_stage_components: Tuple[Path, Path] = ( 103 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 104 | Path("download/llava-laion-cc-sbu-558k/"), 105 | ) 106 | finetune_stage_components: Tuple[Path, Path] = ( 107 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"), 108 | Path("download/llava-v1.5-instruct/"), 109 | ) 110 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 111 | 112 | 113 | # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! === 114 | @unique 115 | class DatasetRegistry(Enum): 116 | # === LLaVa v1.5 === 117 | LLAVA_V15 = LLaVa_V15_Config 118 | 119 | LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config 120 | 121 | LLAVA_LVIS4V = LLaVa_LVIS4V_Config 122 | LLAVA_LRV = LLaVa_LRV_Config 123 | 124 | LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config 125 | 126 | @property 127 | def dataset_id(self) -> str: 128 | return self.value.dataset_id 129 | 130 | 131 | # Register Datasets in Choice Registry 132 | for dataset_variant in DatasetRegistry: 133 | DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value) 134 | -------------------------------------------------------------------------------- /prismatic/conf/vla.py: -------------------------------------------------------------------------------- 1 | """ 2 | vla.py 3 | 4 | Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and 5 | model configuration thereof. A given VLA model (`policy`) configures the following attributes: 6 | - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.) 
7 | - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`) 8 | - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning) 9 | - Training / Optimization Hyperparameters 10 | """ 11 | 12 | from dataclasses import dataclass 13 | from enum import Enum, unique 14 | from pathlib import Path 15 | from typing import Optional, Union 16 | 17 | from draccus import ChoiceRegistry 18 | 19 | 20 | @dataclass 21 | class VLAConfig(ChoiceRegistry): 22 | # fmt: off 23 | vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant 24 | base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`) 25 | freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining) 26 | freeze_llm_backbone: bool # Freeze LLM Backbone parameters 27 | unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen) 28 | 29 | # Data Mixture Parameters 30 | data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`) 31 | shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE) 32 | 33 | # Optimization Parameters 34 | epochs: int # Epochs to Run (in case `max_steps` is not specified) 35 | max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`) 36 | 37 | expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware 38 | global_batch_size: int # Global Batch Size (divided across processes / world size) 39 | per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU) 40 | # =>> # of accumulation steps is auto-computed 41 | 42 | learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay) 43 | weight_decay: float # Weight Decay for AdamW Optimizer 44 | max_grad_norm: float # Max Grad Norm (for global gradient clipping) 45 | lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay") 46 | warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers) 47 | 48 | train_strategy: str # Train Strategy (default "fsdp-full-shard") 49 | 50 | # Enable Gradient/Activation Checkpointing (for the LLM Backbone) 51 | enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training 52 | 53 | # Mixed Precision Training via Torch Native AMP (`autocast`) 54 | enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision 55 | reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision 56 | 57 | # fmt: on 58 | 59 | 60 | # === OpenVLA Training Configurations === 61 | 62 | 63 | # = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge = 64 | @dataclass 65 | class Exp_SigLIP_224px_Bridge(VLAConfig): 66 | vla_id: str = "siglip-224px+mx-bridge" 67 | base_vlm: Union[str, Path] = "siglip-224px+7b" 68 | 69 | freeze_vision_backbone: bool = False 70 | freeze_llm_backbone: bool = False 71 | unfreeze_last_llm_layer: bool = True 72 | 73 | # Data Mixture Parameters 74 | data_mix: str = "oxe_magic_soup_plus" 75 | shuffle_buffer_size: int = 20_000 76 | 77 | # Optimization Parameters 78 | epochs: int = 10 79 | max_steps: Optional[int] = None 80 | 81 | expected_world_size: int = 8 82 | global_batch_size: int = 256 83 | per_device_batch_size: int = 32 84 | 85 | learning_rate: float = 2e-5 86 | weight_decay: float = 0.0 87 | max_grad_norm: float = 1.0 88 | lr_scheduler_type: str = "constant" 89 | warmup_ratio: float = 0.0 90 | 91 | train_strategy: 
str = "fsdp-full-shard" 92 | 93 | 94 | 95 | 96 | 97 | # = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge = 98 | @dataclass 99 | class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): 100 | vla_id: str = "prism-dinosiglip-224px+mx-bridge" 101 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" 102 | 103 | data_mix: str = "bridge" 104 | 105 | 106 | @dataclass 107 | class Exp_DinoSigLIP_224px_Human(Exp_SigLIP_224px_Bridge): 108 | vla_id: str = "prism-dinosiglip-224px+mx-human" 109 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" 110 | 111 | data_mix: str = "Ego4D" 112 | 113 | 114 | # = [32 GPU Pre-training] DINO-SigLIP 224px + Magic Soup++ = 115 | @dataclass 116 | class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge): 117 | vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus" 118 | base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" 119 | 120 | data_mix: str = "omni_magic_soup_plus" # OpenX (Manipulation + Navigation) 121 | # data_mix: str = "omni_magic_soup_plus_plus" # OpenX + Humam 122 | 123 | expected_world_size: int = 32 124 | global_batch_size: int = 1024 125 | per_device_batch_size: int = 32 126 | 127 | 128 | 129 | 130 | # === Define a VLA Registry Enum for Reference & Validation === 131 | @unique 132 | class VLARegistry(Enum): 133 | # Sanity Check Configurations =>> BridgeV2 134 | SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge 135 | 136 | # Pre-training on Bridge-v2 data only 137 | DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge 138 | 139 | # Pre-training on Human data only 140 | DINOSIGLIP_224PX_MX_HUMAN = Exp_DinoSigLIP_224px_Human 141 | 142 | # Pre-training on full dataset 143 | DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus 144 | 145 | 146 | @property 147 | def vla_id(self) -> str: 148 | return self.value.vla_id 149 | 150 | 151 | # Register VLAs in Choice Registry 152 | for vla_variant in VLARegistry: 153 | VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value) 154 | -------------------------------------------------------------------------------- /prismatic/extern/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/prismatic/extern/__init__.py -------------------------------------------------------------------------------- /prismatic/extern/hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/prismatic/extern/hf/__init__.py -------------------------------------------------------------------------------- /prismatic/extern/hf/configuration_prismatic.py: -------------------------------------------------------------------------------- 1 | """ 2 | configuration_prismatic.py 3 | 4 | HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`. 5 | Default configuration specifies `siglip-224px+7b`. 
6 | """ 7 | 8 | from typing import Any, Dict, List, Optional 9 | 10 | from transformers import PretrainedConfig 11 | from transformers.models.auto import CONFIG_MAPPING 12 | 13 | # === Utilities for Mapping Prismatic names to HF names === 14 | # fmt: off 15 | VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = { 16 | "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224], 17 | 18 | "clip-vit-l-336px": [336], 19 | "siglip-vit-so400m-384px": [384], 20 | 21 | "dinoclip-vit-l-336px": [336, 336], 22 | "dinosiglip-vit-so-224px": [224, 224], 23 | "dinosiglip-vit-so-384px": [384, 384], 24 | } 25 | VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = { 26 | "clip-vit-l": ["vit_large_patch14_clip_224.openai"], 27 | "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"], 28 | 29 | "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"], 30 | "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"], 31 | 32 | "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"], 33 | "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"], 34 | 35 | "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"], 36 | "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"], 37 | "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"], 38 | } 39 | TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = { 40 | "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"], 41 | "dinov2-vit-l": [None], "in1k-vit-l": [None], 42 | "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None], 43 | "dinoclip-vit-l-336px": [None, "quick_gelu"], 44 | "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None] 45 | } 46 | 47 | LLM_BACKBONE_TO_HF_PATH = { 48 | "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf", 49 | "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", 50 | 51 | "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5", 52 | 53 | "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1", 54 | "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", 55 | 56 | "phi-2-3b": "microsoft/phi-2", 57 | } 58 | LLM_BACKBONE_TO_HF_METACLASS = { 59 | "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama", 60 | "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", 61 | 62 | "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral", 63 | 64 | "phi-2-3b": "phi", 65 | } 66 | 67 | VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys()) 68 | VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH) 69 | # fmt: on 70 | 71 | 72 | class PrismaticConfig(PretrainedConfig): 73 | model_type: str = "prismatic" 74 | is_composition: bool = False 75 | 76 | def __init__( 77 | self, 78 | vision_backbone_id: str = "siglip-vit-so400m", 79 | llm_backbone_id: str = "vicuna-v15-7b", 80 | arch_specifier: str = "no-align+gelu-mlp", 81 | use_fused_vision_backbone: Optional[bool] = None, 82 | image_resize_strategy: str = "letterbox", 83 | text_config: Optional[Dict[str, Any]] = None, 84 | llm_max_length: int = 2048, 85 | pad_token_id: int = 32000, 86 | pad_to_multiple_of: int = 64, 87 | output_projector_states: bool = False, 88 | **kwargs: str, 89 | ) -> None: 90 | if vision_backbone_id not in VALID_VISION_BACKBONES: 91 | raise ValueError(f"Vision backbone 
`{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }") 92 | 93 | if llm_backbone_id not in VALID_LLM_BACKBONES: 94 | raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }") 95 | 96 | # Set Prismatic Configuration Fields 97 | self.vision_backbone_id = vision_backbone_id 98 | self.llm_backbone_id = llm_backbone_id 99 | self.arch_specifier = arch_specifier 100 | self.output_projector_states = output_projector_states 101 | 102 | # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing 103 | self.use_fused_vision_backbone = ( 104 | use_fused_vision_backbone 105 | if use_fused_vision_backbone is not None 106 | else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"]) 107 | ) 108 | 109 | self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id] 110 | self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id] 111 | self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id] 112 | self.image_resize_strategy = image_resize_strategy 113 | 114 | self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id] 115 | self.llm_max_length = llm_max_length 116 | self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of 117 | 118 | # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming! 119 | self.text_config = ( 120 | CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config) 121 | if text_config is not None 122 | else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]() 123 | ) 124 | 125 | # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well... 126 | super().__init__(pad_token_id=pad_token_id, **kwargs) 127 | 128 | 129 | class OpenVLAConfig(PrismaticConfig): 130 | model_type: str = "openvla" 131 | 132 | def __init__( 133 | self, 134 | norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None, 135 | n_action_bins: int = 256, 136 | **kwargs: str, 137 | ) -> None: 138 | self.norm_stats, self.n_action_bins = norm_stats, n_action_bins 139 | 140 | super().__init__(**kwargs) 141 | -------------------------------------------------------------------------------- /prismatic/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import available_model_names, available_models, get_model_description, load, load_vla 2 | from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm 3 | -------------------------------------------------------------------------------- /prismatic/models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/prismatic/models/backbones/__init__.py -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_llm import LLMBackbone 2 | from .llama2 import LLaMa2LLMBackbone 3 | from .mistral import MistralLLMBackbone 4 | from .phi import PhiLLMBackbone 5 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/llama2.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | llama2.py 3 | 4 | Class definition for all LLMs derived from LlamaForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Sequence, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import LlamaForCausalLM 12 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import ( 16 | LLaMa2ChatPromptBuilder, 17 | PromptBuilder, 18 | PurePromptBuilder, 19 | VicunaV15ChatPromptBuilder, 20 | ) 21 | 22 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 23 | # fmt: off 24 | LLAMA2_MODELS = { 25 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 26 | "llama2-7b-pure": { 27 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf" 28 | }, 29 | 30 | "llama2-13b-pure": { 31 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf" 32 | }, 33 | 34 | # === Meta LLaMa-2 Chat Models === 35 | "llama2-7b-chat": { 36 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf" 37 | }, 38 | 39 | "llama2-13b-chat": { 40 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf" 41 | }, 42 | 43 | # === Vicuna v1.5 Chat Models === 44 | "vicuna-v15-7b": { 45 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5" 46 | }, 47 | 48 | "vicuna-v15-13b": { 49 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5" 50 | }, 51 | } 52 | # fmt: on 53 | 54 | 55 | class LLaMa2LLMBackbone(HFCausalLLMBackbone): 56 | def __init__( 57 | self, 58 | llm_backbone_id: str, 59 | llm_max_length: int = 2048, 60 | hf_token: Optional[str] = None, 61 | inference_mode: bool = False, 62 | use_flash_attention_2: bool = True, 63 | ) -> None: 64 | super().__init__( 65 | llm_backbone_id, 66 | llm_max_length=llm_max_length, 67 | hf_token=hf_token, 68 | inference_mode=inference_mode, 69 | use_flash_attention_2=use_flash_attention_2, 70 | **LLAMA2_MODELS[llm_backbone_id], 71 | ) 72 | 73 | # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize) 74 | self.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 75 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 76 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 77 | 78 | @property 79 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 80 | if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): 81 | return PurePromptBuilder 82 | 83 | elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): 84 | return LLaMa2ChatPromptBuilder 85 | 86 | elif self.identifier.startswith("vicuna"): 87 | return VicunaV15ChatPromptBuilder 88 | 89 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 90 | 91 | @property 92 | def transformer_layer_cls(self) -> Type[nn.Module]: 93 | return LlamaDecoderLayer 94 | 95 | @property 96 | def half_precision_dtype(self) -> torch.dtype: 97 | """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2.""" 98 | return torch.bfloat16 99 | 100 | @property 101 | def last_layer_finetune_modules(self) -> Sequence[nn.Module]: 102 | return (self.llm.model.embed_tokens, 
self.llm.model.layers[-1], self.llm.lm_head) 103 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/mistral.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral.py 3 | 4 | Class definition for all LLMs derived from MistralForCausalLM. 5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import MistralForCausalLM 12 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder 16 | 17 | # Registry =>> Support Mistral Models (from HF Transformers) 18 | # fmt: off 19 | MISTRAL_MODELS = { 20 | # === Base Mistral v0.1 === 21 | "mistral-v0.1-7b-pure": { 22 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" 23 | }, 24 | 25 | # === Mistral Instruct v0.1 === 26 | "mistral-v0.1-7b-instruct": { 27 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" 28 | } 29 | } 30 | # fmt: on 31 | 32 | 33 | class MistralLLMBackbone(HFCausalLLMBackbone): 34 | def __init__( 35 | self, 36 | llm_backbone_id: str, 37 | llm_max_length: int = 2048, 38 | hf_token: Optional[str] = None, 39 | inference_mode: bool = False, 40 | use_flash_attention_2: bool = True, 41 | ) -> None: 42 | super().__init__( 43 | llm_backbone_id, 44 | llm_max_length=llm_max_length, 45 | hf_token=hf_token, 46 | inference_mode=inference_mode, 47 | use_flash_attention_2=use_flash_attention_2, 48 | **MISTRAL_MODELS[llm_backbone_id], 49 | ) 50 | 51 | # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize) 52 | self.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 53 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 54 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 55 | 56 | @property 57 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 58 | if self.identifier.endswith("-pure"): 59 | return PurePromptBuilder 60 | 61 | elif self.identifier.endswith("-instruct"): 62 | return MistralInstructPromptBuilder 63 | 64 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 65 | 66 | @property 67 | def transformer_layer_cls(self) -> Type[nn.Module]: 68 | return MistralDecoderLayer 69 | 70 | @property 71 | def half_precision_dtype(self) -> torch.dtype: 72 | return torch.bfloat16 73 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/phi.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi.py 3 | 4 | Class definition for all LLMs derived from PhiForCausalLM. 
5 | """ 6 | 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from torch import nn as nn 11 | from transformers import PhiForCausalLM 12 | from transformers.models.phi.modeling_phi import PhiDecoderLayer 13 | 14 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 15 | from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder 16 | 17 | # Registry ==> Support Phi Models (from HF Transformers) 18 | # fmt: off 19 | PHI_MODELS = { 20 | # === Phi-2 === 21 | "phi-2-3b": { 22 | "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2" 23 | } 24 | } 25 | # fmt: on 26 | 27 | 28 | class PhiLLMBackbone(HFCausalLLMBackbone): 29 | def __init__( 30 | self, 31 | llm_backbone_id: str, 32 | llm_max_length: int = 2048, 33 | hf_token: Optional[str] = None, 34 | inference_mode: bool = False, 35 | use_flash_attention_2: bool = True, 36 | ) -> None: 37 | super().__init__( 38 | llm_backbone_id, 39 | llm_max_length=llm_max_length, 40 | hf_token=hf_token, 41 | inference_mode=inference_mode, 42 | use_flash_attention_2=use_flash_attention_2, 43 | **PHI_MODELS[llm_backbone_id], 44 | ) 45 | 46 | # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize) 47 | self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) 48 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 49 | self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 50 | 51 | @property 52 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 53 | if self.identifier.startswith("phi-2"): 54 | return PhiPromptBuilder 55 | 56 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 57 | 58 | @property 59 | def transformer_layer_cls(self) -> Type[nn.Module]: 60 | return PhiDecoderLayer 61 | 62 | @property 63 | def half_precision_dtype(self) -> torch.dtype: 64 | return torch.bfloat16 65 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_prompter import PromptBuilder, PurePromptBuilder 2 | from .llama2_chat_prompter import LLaMa2ChatPromptBuilder 3 | from .mistral_instruct_prompter import MistralInstructPromptBuilder 4 | from .phi_prompter import PhiPromptBuilder 5 | from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder 6 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/base_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_prompter.py 3 | 4 | Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs. 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Optional 9 | 10 | 11 | class PromptBuilder(ABC): 12 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 13 | self.model_family = model_family 14 | 15 | # Only some models define a system prompt => let subclasses handle this logic! 16 | self.system_prompt = system_prompt 17 | 18 | @abstractmethod 19 | def add_turn(self, role: str, message: str) -> str: ... 20 | 21 | @abstractmethod 22 | def get_potential_prompt(self, user_msg: str) -> None: ... 23 | 24 | @abstractmethod 25 | def get_prompt(self) -> str: ... 
26 | 27 | 28 | class PurePromptBuilder(PromptBuilder): 29 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 30 | super().__init__(model_family, system_prompt) 31 | 32 | # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! 33 | self.bos, self.eos = "<s>", "</s>" 34 | 35 | # Get role-specific "wrap" functions 36 | self.wrap_human = lambda msg: f"In: {msg}\nOut: " 37 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 38 | 39 | # === `self.prompt` gets built up over multiple turns === 40 | self.prompt, self.turn_count = "", 0 41 | 42 | def add_turn(self, role: str, message: str) -> str: 43 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 44 | message = message.replace("<image>", "").strip() 45 | 46 | if (self.turn_count % 2) == 0: 47 | human_message = self.wrap_human(message) 48 | wrapped_message = human_message 49 | else: 50 | gpt_message = self.wrap_gpt(message) 51 | wrapped_message = gpt_message 52 | 53 | # Update Prompt 54 | self.prompt += wrapped_message 55 | 56 | # Bump Turn Counter 57 | self.turn_count += 1 58 | 59 | # Return "wrapped_message" (effective string added to context) 60 | return wrapped_message 61 | 62 | def get_potential_prompt(self, message: str) -> None: 63 | # Assumes that it's always the user's (human's) turn! 64 | prompt_copy = str(self.prompt) 65 | 66 | human_message = self.wrap_human(message) 67 | prompt_copy += human_message 68 | 69 | return prompt_copy.removeprefix(self.bos).rstrip() 70 | 71 | def get_prompt(self) -> str: 72 | # Remove prefix <bos> (if exists) because it gets auto-inserted by tokenizer! 73 | return self.prompt.removeprefix(self.bos).rstrip() 74 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama2_prompter.py 3 | 4 | Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern 5 | that's used by HF and other online tutorials. 6 | 7 | Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | # Default System Prompt for Prismatic Models 15 | SYS_PROMPTS = { 16 | "prismatic": ( 17 | "You are a helpful language and vision assistant. " 18 | "You are able to understand the visual content that the user provides, " 19 | "and assist the user with a variety of tasks using natural language." 20 | ), 21 | "openvla": ( 22 | "You are a helpful language and vision assistant. " 23 | "You are able to understand the visual content that the user provides, " 24 | "and assist the user with a variety of tasks using natural language." 
25 | ), 26 | } 27 | 28 | 29 | def format_system_prompt(system_prompt: str) -> str: 30 | return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n" 31 | 32 | 33 | class LLaMa2ChatPromptBuilder(PromptBuilder): 34 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 35 | super().__init__(model_family, system_prompt) 36 | self.system_prompt = format_system_prompt( 37 | SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt 38 | ) 39 | 40 | # LLaMa-2 Specific 41 | self.bos, self.eos = "<s>", "</s>" 42 | 43 | # Get role-specific "wrap" functions 44 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 45 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 46 | 47 | # === `self.prompt` gets built up over multiple turns === 48 | self.prompt, self.turn_count = "", 0 49 | 50 | def add_turn(self, role: str, message: str) -> str: 51 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 52 | message = message.replace("<image>", "").strip() 53 | 54 | # Special Handling for "system" prompt (turn_count == 0) 55 | if self.turn_count == 0: 56 | sys_message = self.wrap_human(self.system_prompt + message) 57 | wrapped_message = sys_message 58 | elif (self.turn_count % 2) == 0: 59 | human_message = self.wrap_human(message) 60 | wrapped_message = human_message 61 | else: 62 | gpt_message = self.wrap_gpt(message) 63 | wrapped_message = gpt_message 64 | 65 | # Update Prompt 66 | self.prompt += wrapped_message 67 | 68 | # Bump Turn Counter 69 | self.turn_count += 1 70 | 71 | # Return "wrapped_message" (effective string added to context) 72 | return wrapped_message 73 | 74 | def get_potential_prompt(self, message: str) -> None: 75 | # Assumes that it's always the user's (human's) turn! 76 | prompt_copy = str(self.prompt) 77 | 78 | # Special Handling for "system" prompt (turn_count == 0) 79 | if self.turn_count == 0: 80 | sys_message = self.wrap_human(self.system_prompt + message) 81 | prompt_copy += sys_message 82 | 83 | else: 84 | human_message = self.wrap_human(message) 85 | prompt_copy += human_message 86 | 87 | return prompt_copy.removeprefix(self.bos).rstrip() 88 | 89 | def get_prompt(self) -> str: 90 | # Remove prefix <bos> because it gets auto-inserted by tokenizer! 
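# [Editor's note] Illustrative only: with the default "prismatic"/"openvla" system prompt and a single
# human turn `msg`, the string returned below has the form
#   "[INST] <<SYS>>\n{system prompt}\n<</SYS>>\n\n{msg} [/INST]"
# (the trailing space appended by `wrap_human` is removed by `.rstrip()`).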
91 | return self.prompt.removeprefix(self.bos).rstrip() 92 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral_instruct_prompter.py 3 | 4 | Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials. 5 | 6 | Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | 14 | class MistralInstructPromptBuilder(PromptBuilder): 15 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 16 | super().__init__(model_family, system_prompt) 17 | 18 | # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)` 19 | # =>> Mistral Instruct *does not* use a System Prompt 20 | self.bos, self.eos = "<s>", "</s>" 21 | 22 | # Get role-specific "wrap" functions 23 | self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " 24 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 25 | 26 | # === `self.prompt` gets built up over multiple turns === 27 | self.prompt, self.turn_count = "", 0 28 | 29 | def add_turn(self, role: str, message: str) -> str: 30 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 31 | message = message.replace("<image>", "").strip() 32 | 33 | if (self.turn_count % 2) == 0: 34 | human_message = self.wrap_human(message) 35 | wrapped_message = human_message 36 | else: 37 | gpt_message = self.wrap_gpt(message) 38 | wrapped_message = gpt_message 39 | 40 | # Update Prompt 41 | self.prompt += wrapped_message 42 | 43 | # Bump Turn Counter 44 | self.turn_count += 1 45 | 46 | # Return "wrapped_message" (effective string added to context) 47 | return wrapped_message 48 | 49 | def get_potential_prompt(self, message: str) -> None: 50 | # Assumes that it's always the user's (human's) turn! 51 | prompt_copy = str(self.prompt) 52 | 53 | human_message = self.wrap_human(message) 54 | prompt_copy += human_message 55 | 56 | return prompt_copy.removeprefix(self.bos).rstrip() 57 | 58 | def get_prompt(self) -> str: 59 | # Remove prefix <bos> because it gets auto-inserted by tokenizer! 60 | return self.prompt.removeprefix(self.bos).rstrip() 61 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/phi_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi_prompter.py 3 | 4 | Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft. 5 | Also handles Phi special case BOS token additions. 6 | 7 | Reference: https://huggingface.co/microsoft/phi-2#qa-format 8 | """ 9 | 10 | from typing import Optional 11 | 12 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 13 | 14 | 15 | class PhiPromptBuilder(PromptBuilder): 16 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 17 | super().__init__(model_family, system_prompt) 18 | 19 | # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)` 20 | # =>> By default, does *not* append <bos> / <eos> tokens --> we handle that here (IMPORTANT)! 
21 | self.bos, self.eos = "<|endoftext|>", "<|endoftext|>" 22 | 23 | # Get role-specific "wrap" functions 24 | # =>> Note that placement of / were based on experiments generating from Phi-2 in Input/Output mode 25 | self.wrap_human = lambda msg: f"Input: {msg}\nOutput: " 26 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}" 27 | 28 | # === `self.prompt` gets built up over multiple turns === 29 | self.prompt, self.turn_count = "", 0 30 | 31 | def add_turn(self, role: str, message: str) -> str: 32 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 33 | message = message.replace("", "").strip() 34 | 35 | # Special Handling for "first" input --> prepend a token (expected by Prismatic) 36 | if self.turn_count == 0: 37 | bos_human_message = f"{self.bos}{self.wrap_human(message)}" 38 | wrapped_message = bos_human_message 39 | elif (self.turn_count % 2) == 0: 40 | human_message = self.wrap_human(message) 41 | wrapped_message = human_message 42 | else: 43 | gpt_message = self.wrap_gpt(message) 44 | wrapped_message = gpt_message 45 | 46 | # Update Prompt 47 | self.prompt += wrapped_message 48 | 49 | # Bump Turn Counter 50 | self.turn_count += 1 51 | 52 | # Return "wrapped_message" (effective string added to context) 53 | return wrapped_message 54 | 55 | def get_potential_prompt(self, message: str) -> None: 56 | # Assumes that it's always the user's (human's) turn! 57 | prompt_copy = str(self.prompt) 58 | 59 | human_message = self.wrap_human(message) 60 | prompt_copy += human_message 61 | 62 | return prompt_copy.rstrip() 63 | 64 | def get_prompt(self) -> str: 65 | return self.prompt.rstrip() 66 | -------------------------------------------------------------------------------- /prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | # Default System Prompt for LLaVa Models 14 | SYS_PROMPTS = { 15 | "prismatic": ( 16 | "A chat between a curious user and an artificial intelligence assistant. " 17 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 18 | ), 19 | "openvla": ( 20 | "A chat between a curious user and an artificial intelligence assistant. " 21 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 
22 | ), 23 | } 24 | 25 | 26 | class VicunaV15ChatPromptBuilder(PromptBuilder): 27 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 28 | super().__init__(model_family, system_prompt) 29 | self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " 30 | 31 | # LLaMa-2 Specific 32 | self.bos, self.eos = "<s>", "</s>" 33 | 34 | # Get role-specific "wrap" functions 35 | self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " 36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 37 | 38 | # === `self.prompt` gets built up over multiple turns === 39 | self.prompt, self.turn_count = "", 0 40 | 41 | def add_turn(self, role: str, message: str) -> str: 42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 43 | message = message.replace("<image>", "").strip() 44 | 45 | # Special Handling for "system" prompt (turn_count == 0) 46 | if self.turn_count == 0: 47 | sys_message = self.system_prompt + self.wrap_human(message) 48 | wrapped_message = sys_message 49 | elif (self.turn_count % 2) == 0: 50 | human_message = self.wrap_human(message) 51 | wrapped_message = human_message 52 | else: 53 | gpt_message = self.wrap_gpt(message) 54 | wrapped_message = gpt_message 55 | 56 | # Update Prompt 57 | self.prompt += wrapped_message 58 | 59 | # Bump Turn Counter 60 | self.turn_count += 1 61 | 62 | # Return "wrapped_message" (effective string added to context) 63 | return wrapped_message 64 | 65 | def get_potential_prompt(self, message: str) -> None: 66 | # Assumes that it's always the user's (human's) turn! 67 | prompt_copy = str(self.prompt) 68 | 69 | # Special Handling for "system" prompt (turn_count == 0) 70 | if self.turn_count == 0: 71 | sys_message = self.system_prompt + self.wrap_human(message) 72 | prompt_copy += sys_message 73 | 74 | else: 75 | human_message = self.wrap_human(message) 76 | prompt_copy += human_message 77 | 78 | return prompt_copy.removeprefix(self.bos).rstrip() 79 | 80 | def get_prompt(self) -> str: 81 | # Remove prefix <bos> (if exists) because it gets auto-inserted by tokenizer! 82 | return self.prompt.removeprefix(self.bos).rstrip() 83 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_vision import ImageTransform, VisionBackbone 2 | from .clip_vit import CLIPViTBackbone 3 | from .dinoclip_vit import DinoCLIPViTBackbone 4 | from .dinosiglip_vit import DinoSigLIPViTBackbone 5 | from .dinov2_vit import DinoV2ViTBackbone 6 | from .in1k_vit import IN1KViTBackbone 7 | from .siglip_vit import SigLIPViTBackbone 8 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/clip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | clip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported CLIP Vision Backbones (from TIMM) 8 | CLIP_VISION_BACKBONES = { 9 | "clip-vit-b": "vit_base_patch16_clip_224.openai", 10 | "clip-vit-l": "vit_large_patch14_clip_224.openai", 11 | "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", 12 | } 13 | 14 | 15 | # [IMPORTANT] By default, TIMM initializes OpenAI CLIP models with the standard GELU activation from PyTorch. 
16 | # HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's 17 | # a decent approximation, the resulting features are *worse*; this was a super tricky bug 18 | # to identify, but luckily there's an easy fix (`override_act_layer`) 19 | class CLIPViTBackbone(TimmViTBackbone): 20 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 21 | super().__init__( 22 | vision_backbone_id, 23 | CLIP_VISION_BACKBONES[vision_backbone_id], 24 | image_resize_strategy, 25 | default_image_size=default_image_size, 26 | override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, 27 | ) 28 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinoclip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinoclip_vit.py 3 | 4 | Vision backbone that returns concatenated features from both DINOv2 and CLIP. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | from functools import partial 9 | from typing import Callable, Dict, Tuple 10 | 11 | import timm 12 | import torch 13 | from PIL import Image 14 | from timm.models.vision_transformer import Block, VisionTransformer 15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 16 | from torchvision.transforms import Compose, Resize 17 | 18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple 19 | 20 | # Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) 21 | DINOCLIP_VISION_BACKBONES = { 22 | "dinoclip-vit-l-336px": { 23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 24 | "clip": "vit_large_patch14_clip_336.openai", 25 | }, 26 | } 27 | 28 | 29 | @dataclass 30 | class DinoCLIPImageTransform: 31 | dino_image_transform: ImageTransform 32 | clip_image_transform: ImageTransform 33 | is_prismatic: bool = True 34 | 35 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: 36 | return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} 37 | 38 | 39 | class DinoCLIPViTBackbone(VisionBackbone): 40 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 41 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) 42 | self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] 43 | self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] 44 | 45 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary 46 | self.dino_featurizer: VisionTransformer = timm.create_model( 47 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 48 | ) 49 | self.dino_featurizer.eval() 50 | 51 | self.clip_featurizer: VisionTransformer = timm.create_model( 52 | self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 53 | ) 54 | self.clip_featurizer.eval() 55 | 56 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility 57 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
58 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 59 | self.dino_featurizer.forward = unpack_tuple( 60 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) 61 | ) 62 | self.clip_featurizer.forward = unpack_tuple( 63 | partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) 64 | ) 65 | 66 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models 67 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) 68 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 69 | 70 | self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) 71 | self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 72 | 73 | # Initialize *both* Transforms 74 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) 75 | default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) 76 | if self.image_resize_strategy == "resize-naive": 77 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" 78 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" 79 | assert isinstance(default_dino_transform.transforms[0], Resize) 80 | assert isinstance(default_clip_transform.transforms[0], Resize) 81 | 82 | target_size = (self.default_image_size, self.default_image_size) 83 | dino_transform = Compose( 84 | [ 85 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), 86 | *default_dino_transform.transforms[1:], 87 | ] 88 | ) 89 | clip_transform = Compose( 90 | [ 91 | Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), 92 | *default_clip_transform.transforms[1:], 93 | ] 94 | ) 95 | 96 | self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) 97 | 98 | elif self.image_resize_strategy == "resize-crop": 99 | self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) 100 | 101 | elif self.image_resize_strategy == "letterbox": 102 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" 103 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" 104 | assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
105 | 106 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) 107 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) 108 | clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) 109 | 110 | # Build New Transform 111 | self.image_transform = DinoCLIPImageTransform( 112 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), 113 | Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), 114 | ) 115 | 116 | else: 117 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 118 | 119 | def get_fsdp_wrapping_policy(self) -> Callable: 120 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" 121 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 122 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 123 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 124 | 125 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: 126 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" 127 | dino_patches = self.dino_featurizer(pixel_values["dino"]) 128 | clip_patches = self.clip_featurizer(pixel_values["clip"]) 129 | 130 | return torch.cat([dino_patches, clip_patches], dim=2) 131 | 132 | @property 133 | def default_image_resolution(self) -> Tuple[int, int, int]: 134 | return self.dino_data_cfg["input_size"] 135 | 136 | @property 137 | def embed_dim(self) -> int: 138 | return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim 139 | 140 | @property 141 | def num_patches(self) -> int: 142 | assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches 143 | return self.dino_featurizer.patch_embed.num_patches 144 | 145 | @property 146 | def half_precision_dtype(self) -> torch.dtype: 147 | return torch.bfloat16 148 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinosiglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinosiglip_vit.py 3 | 4 | Vision backbone that returns concatenated features from both DINOv2 and SigLIP. 
5 | """ 6 | 7 | from dataclasses import dataclass 8 | from functools import partial 9 | from typing import Callable, Dict, Tuple 10 | 11 | import timm 12 | import torch 13 | from PIL import Image 14 | from timm.models.vision_transformer import Block, VisionTransformer 15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 16 | from torchvision.transforms import Compose, Resize 17 | 18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple 19 | 20 | # Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) 21 | DINOSigLIP_VISION_BACKBONES = { 22 | "dinosiglip-vit-so-224px": { 23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 24 | "siglip": "vit_so400m_patch14_siglip_224", 25 | }, 26 | "dinosiglip-vit-so-384px": { 27 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 28 | "siglip": "vit_so400m_patch14_siglip_384", 29 | }, 30 | } 31 | 32 | 33 | @dataclass 34 | class DinoSigLIPImageTransform: 35 | dino_image_transform: ImageTransform 36 | siglip_image_transform: ImageTransform 37 | is_prismatic: bool = True 38 | 39 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: 40 | return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)} 41 | 42 | 43 | class DinoSigLIPViTBackbone(VisionBackbone): 44 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 45 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) 46 | self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"] 47 | self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"] 48 | 49 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary 50 | self.dino_featurizer: VisionTransformer = timm.create_model( 51 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 52 | ) 53 | self.dino_featurizer.eval() 54 | 55 | self.siglip_featurizer: VisionTransformer = timm.create_model( 56 | self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 57 | ) 58 | self.siglip_featurizer.eval() 59 | 60 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility 61 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
62 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 63 | self.dino_featurizer.forward = unpack_tuple( 64 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) 65 | ) 66 | self.siglip_featurizer.forward = unpack_tuple( 67 | partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2}) 68 | ) 69 | 70 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models 71 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) 72 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 73 | 74 | self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer) 75 | self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 76 | 77 | # Initialize *both* Transforms 78 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) 79 | default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False) 80 | 81 | # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!! 82 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!" 83 | assert isinstance(default_siglip_transform.transforms[0], Resize) 84 | default_siglip_transform = Compose( 85 | [ 86 | Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation), 87 | *default_siglip_transform.transforms[1:], 88 | ] 89 | ) 90 | 91 | if self.image_resize_strategy == "resize-naive": 92 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" 93 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!" 94 | assert isinstance(default_dino_transform.transforms[0], Resize) 95 | assert isinstance(default_siglip_transform.transforms[0], Resize) 96 | 97 | target_size = (self.default_image_size, self.default_image_size) 98 | dino_transform = Compose( 99 | [ 100 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), 101 | *default_dino_transform.transforms[1:], 102 | ] 103 | ) 104 | siglip_transform = Compose( 105 | [ 106 | Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation), 107 | *default_siglip_transform.transforms[1:], 108 | ] 109 | ) 110 | 111 | self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform) 112 | 113 | elif self.image_resize_strategy == "resize-crop": 114 | self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform) 115 | 116 | elif self.image_resize_strategy == "letterbox": 117 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" 118 | assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!" 119 | assert ( 120 | "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg 121 | ), "DinoSigLIP `data_cfg` missing `mean`!" 
122 | 123 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) 124 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) 125 | siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]]) 126 | 127 | # Build New Transform 128 | self.image_transform = DinoSigLIPImageTransform( 129 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), 130 | Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]), 131 | ) 132 | 133 | else: 134 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 135 | 136 | def get_fsdp_wrapping_policy(self) -> Callable: 137 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" 138 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 139 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 140 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 141 | 142 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: 143 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" 144 | # print(pixel_values.shape) 145 | if isinstance(pixel_values, dict): 146 | dino_patches = self.dino_featurizer(pixel_values["dino"]) 147 | siglip_patches = self.siglip_featurizer(pixel_values["siglip"]) 148 | else: 149 | dino_patches = self.dino_featurizer(pixel_values[:,:3]) 150 | siglip_patches = self.siglip_featurizer(pixel_values[:,3:]) 151 | 152 | return torch.cat([dino_patches, siglip_patches], dim=2) 153 | 154 | @property 155 | def default_image_resolution(self) -> Tuple[int, int, int]: 156 | return self.dino_data_cfg["input_size"] 157 | 158 | @property 159 | def embed_dim(self) -> int: 160 | return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim 161 | 162 | @property 163 | def num_patches(self) -> int: 164 | assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches 165 | return self.dino_featurizer.patch_embed.num_patches 166 | 167 | @property 168 | def half_precision_dtype(self) -> torch.dtype: 169 | return torch.bfloat16 170 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/dinov2_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinov2_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! 
8 | # => Reference: https://arxiv.org/abs/2309.16588 9 | DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} 10 | 11 | 12 | class DinoV2ViTBackbone(TimmViTBackbone): 13 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 14 | super().__init__( 15 | vision_backbone_id, 16 | DINOv2_VISION_BACKBONES[vision_backbone_id], 17 | image_resize_strategy, 18 | default_image_size=default_image_size, 19 | ) 20 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/in1k_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | in1k_vit.py 3 | 4 | Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) 5 | """ 6 | 7 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 8 | 9 | # Registry =>> Supported Vision Backbones (from TIMM) 10 | IN1K_VISION_BACKBONES = { 11 | "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", 12 | } 13 | 14 | 15 | class IN1KViTBackbone(TimmViTBackbone): 16 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 17 | super().__init__( 18 | vision_backbone_id, 19 | IN1K_VISION_BACKBONES[vision_backbone_id], 20 | image_resize_strategy, 21 | default_image_size=default_image_size, 22 | ) 23 | -------------------------------------------------------------------------------- /prismatic/models/backbones/vision/siglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | siglip_vit.py 3 | """ 4 | 5 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 6 | 7 | # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) 8 | SIGLIP_VISION_BACKBONES = { 9 | "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", 10 | "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", 11 | "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", 12 | "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", 13 | "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", 14 | } 15 | 16 | 17 | class SigLIPViTBackbone(TimmViTBackbone): 18 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 19 | super().__init__( 20 | vision_backbone_id, 21 | SIGLIP_VISION_BACKBONES[vision_backbone_id], 22 | image_resize_strategy, 23 | default_image_size=default_image_size, 24 | ) 25 | -------------------------------------------------------------------------------- /prismatic/models/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports 5 | individual functions for clear control flow. 
6 | """ 7 | 8 | from typing import Optional, Tuple 9 | 10 | from transformers import PreTrainedTokenizerBase 11 | 12 | from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone 13 | from prismatic.models.backbones.vision import ( 14 | CLIPViTBackbone, 15 | DinoCLIPViTBackbone, 16 | DinoSigLIPViTBackbone, 17 | DinoV2ViTBackbone, 18 | ImageTransform, 19 | IN1KViTBackbone, 20 | SigLIPViTBackbone, 21 | VisionBackbone, 22 | ) 23 | from prismatic.models.vlms import PrismaticVLM 24 | 25 | # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === 26 | # fmt: off 27 | 28 | # === Vision Backbone Registry === 29 | VISION_BACKBONES = { 30 | # === 224px Backbones === 31 | "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 32 | "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 33 | "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, 34 | "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, 35 | "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 36 | 37 | # === Assorted CLIP Backbones === 38 | "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 39 | "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 40 | 41 | # === Assorted SigLIP Backbones === 42 | "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 43 | "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, 44 | "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 45 | "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 46 | 47 | # === Fused Backbones === 48 | "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 49 | "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 50 | } 51 | 52 | 53 | # === Language Model Registry === 54 | LLM_BACKBONES = { 55 | # === LLaMa-2 Pure (Non-Chat) Backbones === 56 | "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 57 | "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 58 | 59 | # === LLaMa-2 Chat Backbones === 60 | "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 61 | "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 62 | 63 | # === Vicuna-v1.5 Backbones === 64 | "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 65 | "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 66 | 67 | # === Mistral v0.1 Backbones === 68 | "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}}, 69 | "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}}, 70 | 71 | # === Phi-2 Backbone === 72 | "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}, 73 | } 74 | 75 | # fmt: on 76 | 77 | 78 | def get_vision_backbone_and_transform( 79 | vision_backbone_id: str, image_resize_strategy: str 80 | ) -> Tuple[VisionBackbone, ImageTransform]: 81 | """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" 82 | if vision_backbone_id in VISION_BACKBONES: 83 | vision_cfg = VISION_BACKBONES[vision_backbone_id] 84 | vision_backbone: VisionBackbone = vision_cfg["cls"]( 85 | vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"] 86 | ) 87 | 
image_transform = vision_backbone.get_image_transform() 88 | return vision_backbone, image_transform 89 | 90 | else: 91 | raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") 92 | 93 | 94 | def get_llm_backbone_and_tokenizer( 95 | llm_backbone_id: str, 96 | llm_max_length: int = 2048, 97 | hf_token: Optional[str] = None, 98 | inference_mode: bool = False, 99 | ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: 100 | if llm_backbone_id in LLM_BACKBONES: 101 | llm_cfg = LLM_BACKBONES[llm_backbone_id] 102 | llm_backbone: LLMBackbone = llm_cfg["cls"]( 103 | llm_backbone_id, 104 | llm_max_length=llm_max_length, 105 | hf_token=hf_token, 106 | inference_mode=inference_mode, 107 | **llm_cfg["kwargs"], 108 | ) 109 | tokenizer = llm_backbone.get_tokenizer() 110 | return llm_backbone, tokenizer 111 | 112 | else: 113 | raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") 114 | 115 | 116 | def get_vlm( 117 | model_id: str, 118 | arch_specifier: str, 119 | vision_backbone: VisionBackbone, 120 | llm_backbone: LLMBackbone, 121 | enable_mixed_precision_training: bool = True, 122 | ) -> PrismaticVLM: 123 | """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" 124 | return PrismaticVLM( 125 | model_id, 126 | vision_backbone, 127 | llm_backbone, 128 | enable_mixed_precision_training=enable_mixed_precision_training, 129 | arch_specifier=arch_specifier, 130 | ) 131 | -------------------------------------------------------------------------------- /prismatic/models/policy/transformer_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import einsum 8 | import torch.nn.functional as F 9 | from torch.autograd import Function 10 | from torch.nn.init import constant_, xavier_uniform_ 11 | from einops import rearrange, repeat 12 | # from torch import einsum 13 | 14 | 15 | # helpers 16 | def _is_power_of_2(n): 17 | if (not isinstance(n, int)) or (n < 0): 18 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 19 | return (n & (n - 1) == 0) and n != 0 20 | 21 | 22 | # RMSNorm -- Better, simpler alternative to LayerNorm 23 | class RMSNorm(nn.Module): 24 | def __init__(self, dim: int, eps: float = 1e-8) -> None: 25 | super().__init__() 26 | self.scale, self.eps = dim**-0.5, eps 27 | self.g = nn.Parameter(torch.ones(dim)) 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 31 | return x / norm.clamp(min=self.eps) * self.g 32 | 33 | 34 | # SwishGLU -- A Gated Linear Unit (GLU) with the Swish activation; always better than GELU MLP! 35 | class SwishGLU(nn.Module): 36 | def __init__(self, in_dim: int, out_dim: int) -> None: 37 | super().__init__() 38 | self.act, self.project = nn.SiLU(), nn.Linear(in_dim, 2 * out_dim) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | projected, gate = self.project(x).tensor_split(2, dim=-1) 42 | return projected * self.act(gate) 43 | 44 | 45 | # As defined in Set Transformers () -- basically the above, additionally taking in 46 | # a set of $k$ learned "seed vectors" that are used to "pool" information. 
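# For intuition, a rough single-head sketch of what this pooling computes (shapes and weight names are
# illustrative only, not a drop-in implementation):
#
#     seeds = nn.Parameter(torch.zeros(k, d))                         # k learned "query" vectors
#     q, keys, vals = W_q(seeds), W_k(x), W_v(x)                      # x : [batch, n, d] input tokens
#     pooled = softmax(q @ keys.transpose(-2, -1) / sqrt(d)) @ vals   # -> [batch, k, d], one output per seed
#
# `MAPAttention` below is the multi-headed version of this (queries from the seeds, fused key/value projection
# over the inputs); `MAPBlock` wraps it with an input projection, RMSNorm/LayerNorm, residuals, and a
# SwishGLU/GELU MLP.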
47 | class MAPAttention(nn.Module): 48 | def __init__(self, embed_dim: int, n_heads: int) -> None: 49 | """Multi-Input Multi-Headed Attention Operation""" 50 | super().__init__() 51 | assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!" 52 | self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5 53 | 54 | # Projections (no bias) --> separate for Q (seed vector), and KV ("pool" inputs) 55 | self.q, self.kv = nn.Linear(embed_dim, embed_dim, bias=False), nn.Linear(embed_dim, 2 * embed_dim, bias=False) 56 | self.proj = nn.Linear(embed_dim, embed_dim) 57 | 58 | def forward(self, seed: torch.Tensor, x: torch.Tensor, attention_mask = None) -> torch.Tensor: 59 | (B_s, K, C_s), (B_x, N, C_x) = seed.shape, x.shape 60 | assert C_s == C_x, "Seed vectors and pool inputs must have the same embedding dimensionality!" 61 | 62 | # Project Seed Vectors to `queries` 63 | q = self.q(seed).reshape(B_s, K, self.n_heads, C_s // self.n_heads).permute(0, 2, 1, 3) 64 | kv = self.kv(x).reshape(B_x, N, 2, self.n_heads, C_x // self.n_heads).permute(2, 0, 3, 1, 4) 65 | k, v = kv.unbind(0) 66 | 67 | # Attention --> compute weighted sum over values! 68 | scores = q @ (k.transpose(-2, -1) * self.scale) 69 | # print(scores.shape) 70 | if attention_mask is not None: 71 | attention_mask = ( 72 | attention_mask[None, None, :, :].repeat(1, self.n_heads, 1, 1) #.flatten(0, 1) 73 | ) 74 | scores.masked_fill_(attention_mask == 0, float("-inf")) 75 | attn = scores.softmax(dim=-1) 76 | 77 | 78 | vals = (attn @ v).transpose(1, 2).reshape(B_s, K, C_s) 79 | 80 | # Project back to `embed_dim` 81 | return self.proj(vals) 82 | 83 | 84 | class MAPBlock(nn.Module): 85 | def __init__( 86 | self, 87 | n_latents: int, 88 | vis_dim: int, 89 | embed_dim: int, 90 | n_heads: int, 91 | mlp_ratio: float = 4.0, 92 | do_rms_norm: bool = True, 93 | do_swish_glu: bool = True, 94 | ) -> None: 95 | """Multiheaded Attention Pooling Block -- note that for MAP, we adopt earlier post-norm conventions.""" 96 | super().__init__() 97 | self.n_latents, self.embed_dim, self.n_heads = n_latents, embed_dim, n_heads 98 | 99 | # Projection Operator 100 | self.projection = nn.Linear(vis_dim, self.embed_dim) 101 | 102 | # Initialize Latents 103 | self.latents = nn.Parameter(torch.zeros(self.n_latents, self.embed_dim), requires_grad=True) 104 | nn.init.normal_(self.latents, std=0.02) 105 | 106 | # Custom MAP Attention (seed, encoder outputs) -> seed 107 | self.attn_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6) 108 | self.attn = MAPAttention(self.embed_dim, n_heads=self.n_heads) 109 | 110 | # Position-wise Feed-Forward Components 111 | self.mlp_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6) 112 | self.mlp = nn.Sequential( 113 | # Handle SwishGLU vs. GELU MLP... 
114 | ( 115 | SwishGLU(self.embed_dim, int(mlp_ratio * self.embed_dim)) 116 | if do_swish_glu 117 | else nn.Sequential(nn.Linear(self.embed_dim, int(mlp_ratio * self.embed_dim)), nn.GELU()) 118 | ), 119 | nn.Linear(int(mlp_ratio * self.embed_dim), self.embed_dim), 120 | ) 121 | 122 | def forward(self, x: torch.Tensor, mask = None, init_embed = None) -> torch.Tensor: 123 | latents = repeat(self.latents, "n_latents d -> bsz n_latents d", bsz=x.shape[0]) 124 | latents = latents + init_embed.unsqueeze(1) if init_embed is not None else latents 125 | latents = self.attn_norm(latents + self.attn(latents, self.projection(x), mask)) 126 | latents = self.mlp_norm(latents + self.mlp(latents)) 127 | return latents.squeeze(dim=1) 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /prismatic/models/vlas/__init__.py: -------------------------------------------------------------------------------- 1 | from .openvla import OpenVLA 2 | -------------------------------------------------------------------------------- /prismatic/models/vlas/openvla.py: -------------------------------------------------------------------------------- 1 | """ 2 | openvla.py 3 | 4 | PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around 5 | discretizing actions with the ActionTokenizer. 6 | """ 7 | 8 | from typing import Dict, List, Optional 9 | 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | from transformers import LlamaTokenizerFast 14 | 15 | from prismatic.models.vlms.prismatic import PrismaticVLM 16 | from prismatic.overwatch import initialize_overwatch 17 | from prismatic.vla.action_tokenizer import ActionTokenizer 18 | 19 | # Initialize Overwatch =>> Wraps `logging.Logger` 20 | overwatch = initialize_overwatch(__name__) 21 | 22 | 23 | class OpenVLA(PrismaticVLM): 24 | def __init__( 25 | self, 26 | *args, 27 | norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], 28 | action_tokenizer: ActionTokenizer, 29 | **kwargs, 30 | ) -> None: 31 | super().__init__(*args, **kwargs) 32 | self.norm_stats = norm_stats 33 | self.action_tokenizer = action_tokenizer 34 | 35 | @torch.inference_mode() 36 | def predict_action( 37 | self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str 38 | ) -> np.ndarray: 39 | """ 40 | Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes). 41 | 42 | @param image: PIL Image as [height, width, 3] 43 | @param instruction: Task instruction string 44 | @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model 45 | was trained only on a single dataset, and retrieves those statistics. 46 | 47 | @return Unnormalized (continuous) action vector --> end-effector deltas. 
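        Illustrative call (a sketch only -- the image path and `unnorm_key` value are placeholders, and `vla` is
        assumed to be an already-loaded OpenVLA instance):

            >>> action = vla.predict_action(
            ...     image=Image.open("frame.png"),            # current RGB observation
            ...     instruction="pick up the red block",
            ...     unnorm_key="<dataset-name>",              # dataset whose statistics un-normalize the action
            ... )
            >>> action.shape                                  # (action_dim,) -- e.g., xyz + rotation + gripper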
48 | """ 49 | image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer 50 | 51 | # Build VLA Prompt 52 | prompt_builder = self.get_prompt_builder() 53 | prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?") 54 | prompt_text = prompt_builder.get_prompt() 55 | 56 | # Prepare Inputs 57 | input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) 58 | if isinstance(tokenizer, LlamaTokenizerFast): 59 | # If the special empty token ('') does not already appear after the colon (':') token in the prompt 60 | # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time 61 | if not torch.all(input_ids[:, -1] == 29871): 62 | input_ids = torch.cat( 63 | (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 64 | ) 65 | else: 66 | raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}") 67 | 68 | # Preprocess Image 69 | pixel_values = image_transform(image) 70 | if isinstance(pixel_values, torch.Tensor): 71 | pixel_values = pixel_values[None, ...].to(self.device) 72 | elif isinstance(pixel_values, dict): 73 | pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} 74 | else: 75 | raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") 76 | 77 | # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` 78 | autocast_dtype = self.llm_backbone.half_precision_dtype 79 | with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): 80 | # fmt: off 81 | generated_ids = super(PrismaticVLM, self).generate( 82 | input_ids=input_ids, # Shape: [1, seq] 83 | pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, ...] 
84 | max_new_tokens=self.get_action_dim(unnorm_key), 85 | **kwargs 86 | ) 87 | # fmt: on 88 | 89 | # Extract predicted action tokens and translate into (normalized) continuous actions 90 | predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :] 91 | normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy()) 92 | 93 | # Un-normalize Actions 94 | action_norm_stats = self.get_action_stats(unnorm_key) 95 | mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) 96 | action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) 97 | actions = np.where( 98 | mask, 99 | 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low, 100 | normalized_actions, 101 | ) 102 | 103 | return actions 104 | 105 | @staticmethod 106 | def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str: 107 | if unnorm_key is None: 108 | assert len(norm_stats) == 1, ( 109 | f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following " 110 | f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}" 111 | ) 112 | unnorm_key = next(iter(norm_stats.keys())) 113 | 114 | # Error Handling 115 | assert ( 116 | unnorm_key in norm_stats 117 | ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}" 118 | 119 | return unnorm_key 120 | 121 | def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: 122 | """Dimensionality of the policy's action space.""" 123 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) 124 | 125 | return len(self.norm_stats[unnorm_key]["action"]["q01"]) 126 | 127 | def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict: 128 | """Dimensionality of the policy's action space.""" 129 | unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) 130 | 131 | return self.norm_stats[unnorm_key]["action"] 132 | -------------------------------------------------------------------------------- /prismatic/models/vlms/__init__.py: -------------------------------------------------------------------------------- 1 | from .prismatic import PrismaticVLM 2 | -------------------------------------------------------------------------------- /prismatic/models/vlms/base_vlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_vlm.py 3 | 4 | Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, 5 | and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate 6 | from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, 7 | PALI, Fuyu) in the future. 8 | 9 | We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance 10 | (e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), 11 | prefer Protocol definitions instead. 
12 | """ 13 | 14 | from __future__ import annotations 15 | 16 | from abc import ABC, abstractmethod 17 | from pathlib import Path 18 | from typing import Callable, List, Optional 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import GenerationMixin, PretrainedConfig 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from prismatic.models.backbones.llm import LLMBackbone 26 | from prismatic.models.backbones.llm.prompting import PromptBuilder 27 | from prismatic.models.backbones.vision import VisionBackbone 28 | 29 | 30 | # === Abstract Base Class for arbitrary Vision-Language Models === 31 | class VLM(nn.Module, GenerationMixin, ABC): 32 | def __init__( 33 | self, 34 | model_family: str, 35 | model_id: str, 36 | vision_backbone: VisionBackbone, 37 | llm_backbone: LLMBackbone, 38 | enable_mixed_precision_training: bool = True, 39 | ) -> None: 40 | super().__init__() 41 | self.model_family, self.model_id = model_family, model_id 42 | self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone 43 | self.enable_mixed_precision_training = enable_mixed_precision_training 44 | 45 | # Instance Attributes for a generic VLM 46 | self.all_module_keys, self.trainable_module_keys = None, None 47 | 48 | # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === 49 | self.generation_config = self.llm_backbone.llm.generation_config 50 | self.main_input_name = "input_ids" 51 | 52 | @property 53 | def device(self) -> torch.device: 54 | """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" 55 | return next(self.parameters()).device 56 | 57 | @classmethod 58 | @abstractmethod 59 | def from_pretrained( 60 | cls, 61 | pretrained_checkpoint: Path, 62 | model_family: str, 63 | model_id: str, 64 | vision_backbone: VisionBackbone, 65 | llm_backbone: LLMBackbone, 66 | **kwargs: str, 67 | ) -> VLM: ... 68 | 69 | @abstractmethod 70 | def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... 71 | 72 | @abstractmethod 73 | def freeze_backbones(self, stage: str) -> None: ... 74 | 75 | @abstractmethod 76 | def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... 77 | 78 | @abstractmethod 79 | def get_fsdp_wrapping_policy(self) -> Callable: ... 80 | 81 | @abstractmethod 82 | def forward( 83 | self, 84 | input_ids: Optional[torch.LongTensor] = None, 85 | attention_mask: Optional[torch.Tensor] = None, 86 | pixel_values: Optional[torch.FloatTensor] = None, 87 | labels: Optional[torch.LongTensor] = None, 88 | inputs_embeds: Optional[torch.FloatTensor] = None, 89 | past_key_values: Optional[List[torch.FloatTensor]] = None, 90 | use_cache: Optional[bool] = None, 91 | output_attentions: Optional[bool] = None, 92 | output_hidden_states: Optional[bool] = None, 93 | return_dict: Optional[bool] = None, 94 | multimodal_indices: Optional[torch.LongTensor] = None, 95 | ) -> CausalLMOutputWithPast: ... 
96 | 97 | # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === 98 | @staticmethod 99 | def can_generate() -> bool: 100 | return True 101 | 102 | @property 103 | def config(self) -> PretrainedConfig: 104 | return self.llm_backbone.llm.config 105 | 106 | # => Beam Search Utility 107 | def _reorder_cache(self, past_key_values, beam_idx): 108 | return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) 109 | -------------------------------------------------------------------------------- /prismatic/overwatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .overwatch import initialize_overwatch 2 | -------------------------------------------------------------------------------- /prismatic/overwatch/overwatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | overwatch.py 3 | 4 | Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. 5 | """ 6 | 7 | import logging 8 | import logging.config 9 | import os 10 | from contextlib import nullcontext 11 | from logging import LoggerAdapter 12 | from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union 13 | 14 | # Overwatch Default Format String 15 | RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" 16 | 17 | # Set Logging Configuration 18 | LOG_CONFIG = { 19 | "version": 1, 20 | "disable_existing_loggers": True, 21 | "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, 22 | "handlers": { 23 | "console": { 24 | "class": "rich.logging.RichHandler", 25 | "formatter": "simple-console", 26 | "markup": True, 27 | "rich_tracebacks": True, 28 | "show_level": True, 29 | "show_path": True, 30 | "show_time": True, 31 | } 32 | }, 33 | "root": {"level": "INFO", "handlers": ["console"]}, 34 | } 35 | logging.config.dictConfig(LOG_CONFIG) 36 | 37 | 38 | # === Custom Contextual Logging Logic === 39 | class ContextAdapter(LoggerAdapter): 40 | CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} 41 | 42 | def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: 43 | ctx_level = kwargs.pop("ctx_level", 0) 44 | return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs 45 | 46 | 47 | class DistributedOverwatch: 48 | def __init__(self, name: str) -> None: 49 | """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" 50 | from accelerate import PartialState 51 | 52 | # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` 53 | # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! 54 | self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() 55 | 56 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 57 | self.debug = self.logger.debug 58 | self.info = self.logger.info 59 | self.warning = self.logger.warning 60 | self.error = self.logger.error 61 | self.critical = self.logger.critical 62 | 63 | # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
64 | self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) 65 | 66 | @property 67 | def rank_zero_only(self) -> Callable[..., Any]: 68 | return self.distributed_state.on_main_process 69 | 70 | @property 71 | def local_zero_only(self) -> Callable[..., Any]: 72 | return self.distributed_state.on_local_main_process 73 | 74 | @property 75 | def rank_zero_first(self) -> Callable[..., Any]: 76 | return self.distributed_state.main_process_first 77 | 78 | @property 79 | def local_zero_first(self) -> Callable[..., Any]: 80 | return self.distributed_state.local_main_process_first 81 | 82 | def is_rank_zero(self) -> bool: 83 | return self.distributed_state.is_main_process 84 | 85 | def rank(self) -> int: 86 | return self.distributed_state.process_index 87 | 88 | def local_rank(self) -> int: 89 | return self.distributed_state.local_process_index 90 | 91 | def world_size(self) -> int: 92 | return self.distributed_state.num_processes 93 | 94 | 95 | class PureOverwatch: 96 | def __init__(self, name: str) -> None: 97 | """Initializer for an Overwatch object that just wraps logging.""" 98 | self.logger = ContextAdapter(logging.getLogger(name), extra={}) 99 | 100 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 101 | self.debug = self.logger.debug 102 | self.info = self.logger.info 103 | self.warning = self.logger.warning 104 | self.error = self.logger.error 105 | self.critical = self.logger.critical 106 | 107 | # Logging Defaults =>> INFO 108 | self.logger.setLevel(logging.INFO) 109 | 110 | @staticmethod 111 | def get_identity_ctx() -> Callable[..., Any]: 112 | def identity(fn: Callable[..., Any]) -> Callable[..., Any]: 113 | return fn 114 | 115 | return identity 116 | 117 | @property 118 | def rank_zero_only(self) -> Callable[..., Any]: 119 | return self.get_identity_ctx() 120 | 121 | @property 122 | def local_zero_only(self) -> Callable[..., Any]: 123 | return self.get_identity_ctx() 124 | 125 | @property 126 | def rank_zero_first(self) -> Callable[..., Any]: 127 | return nullcontext 128 | 129 | @property 130 | def local_zero_first(self) -> Callable[..., Any]: 131 | return nullcontext 132 | 133 | @staticmethod 134 | def is_rank_zero() -> bool: 135 | return True 136 | 137 | @staticmethod 138 | def rank() -> int: 139 | return 0 140 | 141 | @staticmethod 142 | def world_size() -> int: 143 | return 1 144 | 145 | 146 | def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: 147 | return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) 148 | -------------------------------------------------------------------------------- /prismatic/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import convert_to_jpg, download_extract 2 | from .materialize import get_dataset_and_collator 3 | -------------------------------------------------------------------------------- /prismatic/preprocessing/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import AlignDataset, FinetuneDataset 2 | -------------------------------------------------------------------------------- /prismatic/preprocessing/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | download.py 3 | 4 | Utility functions for downloading and extracting various datasets to (local) disk. 
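Illustrative usage (the dataset ID must be a key of the `DATASET_REGISTRY` defined below; the root directory is a
placeholder):

    >>> download_extract("llava-laion-cc-sbu-558k", root_dir=Path("data"))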
5 | """ 6 | 7 | import os 8 | import shutil 9 | from pathlib import Path 10 | from typing import Dict, List, TypedDict 11 | from zipfile import ZipFile 12 | 13 | import requests 14 | from PIL import Image 15 | from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn 16 | from tqdm import tqdm 17 | 18 | from prismatic.overwatch import initialize_overwatch 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | # === Dataset Registry w/ Links === 25 | # fmt: off 26 | DatasetComponent = TypedDict( 27 | "DatasetComponent", 28 | {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool}, 29 | total=False 30 | ) 31 | 32 | DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = { 33 | # === LLaVa v1.5 Dataset(s) === 34 | 35 | # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5 36 | # models are finetuned on this split. We use this dataset for all experiments in our paper. 37 | "llava-laion-cc-sbu-558k": [ 38 | { 39 | "name": "chat.json", # Contains the "chat" traces :: {"human" => , "gpt" => } 40 | "extract": False, 41 | "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json", 42 | "do_rename": True, 43 | }, 44 | { 45 | "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution) 46 | "extract": True, 47 | "extract_type": "directory", 48 | "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip", 49 | "do_rename": False, 50 | } 51 | ], 52 | 53 | "llava-v1.5-instruct": [ 54 | { 55 | "name": "llava_v1_5_mix665k.json", 56 | "extract": False, 57 | "url": ( 58 | "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json" 59 | ), 60 | "do_rename": True, 61 | }, 62 | { 63 | "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017 64 | "extract": True, 65 | "extract_type": "directory", 66 | "url": "http://images.cocodataset.org/zips/train2017.zip", 67 | "do_rename": True, 68 | }, 69 | { 70 | "name": "gqa/images", 71 | "extract": True, 72 | "extract_type": "directory", 73 | "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip", 74 | "do_rename": True, 75 | }, 76 | { 77 | "name": "ocr_vqa/images", 78 | "extract": True, 79 | "extract_type": "directory", 80 | "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip", 81 | "do_rename": True, 82 | }, 83 | { 84 | "name": "textvqa/train_images", 85 | "extract": True, 86 | "extract_type": "directory", 87 | "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip", 88 | "do_rename": True, 89 | }, 90 | { 91 | "name": "vg/VG_100K", 92 | "extract": True, 93 | "extract_type": "directory", 94 | "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip", 95 | "do_rename": True, 96 | }, 97 | { 98 | "name": "vg/VG_100K_2", 99 | "extract": True, 100 | "extract_type": "directory", 101 | "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip", 102 | "do_rename": True, 103 | }, 104 | ] 105 | } 106 | # fmt: on 107 | 108 | 109 | def convert_to_jpg(image_dir: Path) -> None: 110 | """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs.""" 111 | overwatch.info(f"Converting all Images in `{image_dir}` to JPG") 112 | 113 | for image_fn in tqdm(list(image_dir.iterdir())): 114 | if 
image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists(): 115 | continue 116 | 117 | if image_fn.suffix == ".gif": 118 | gif = Image.open(image_fn) 119 | gif.seek(0) 120 | gif.convert("RGB").save(jpg_fn) 121 | elif image_fn.suffix == ".png": 122 | Image.open(image_fn).convert("RGB").save(jpg_fn) 123 | else: 124 | raise ValueError(f"Unexpected image format `{image_fn.suffix}`") 125 | 126 | 127 | def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path: 128 | """Utility function for downloading files from the internet, with a handy Rich-based progress bar.""" 129 | overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1) 130 | if dest_path.exists(): 131 | return dest_path 132 | 133 | # Otherwise --> fire an HTTP Request, with `stream = True` 134 | response = requests.get(url, stream=True) 135 | 136 | # Download w/ Transfer-Aware Progress 137 | # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py 138 | with Progress( 139 | TextColumn("[bold]{task.description} - {task.fields[fname]}"), 140 | BarColumn(bar_width=None), 141 | "[progress.percentage]{task.percentage:>3.1f}%", 142 | "•", 143 | DownloadColumn(), 144 | "•", 145 | TransferSpeedColumn(), 146 | transient=True, 147 | ) as dl_progress: 148 | dl_tid = dl_progress.add_task( 149 | "Downloading", fname=dest_path.name, total=int(response.headers.get("content-length", "None")) 150 | ) 151 | with open(dest_path, "wb") as f: 152 | for data in response.iter_content(chunk_size=chunk_size_bytes): 153 | dl_progress.advance(dl_tid, f.write(data)) 154 | 155 | return dest_path 156 | 157 | 158 | def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path: 159 | """Utility function for extracting compressed archives, with a handy Rich-based progress bar.""" 160 | assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!" 161 | overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1) 162 | 163 | # Extract w/ Progress 164 | with Progress( 165 | TextColumn("[bold]{task.description} - {task.fields[aname]}"), 166 | BarColumn(bar_width=None), 167 | "[progress.percentage]{task.percentage:>3.1f}%", 168 | "•", 169 | MofNCompleteColumn(), 170 | transient=True, 171 | ) as ext_progress: 172 | with ZipFile(archive_path) as zf: 173 | ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist())) 174 | extract_path = Path(zf.extract(members[0], download_dir)) 175 | if extract_type == "file": 176 | assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type} has > 1 member!" 
177 | elif extract_type == "directory": 178 | for member in members[1:]: 179 | zf.extract(member, download_dir) 180 | ext_progress.advance(ext_tid) 181 | else: 182 | raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!") 183 | 184 | # Cleanup (if specified) 185 | if cleanup: 186 | archive_path.unlink() 187 | 188 | return extract_path 189 | 190 | 191 | def download_extract(dataset_id: str, root_dir: Path) -> None: 192 | """Download all files for a given dataset (querying registry above), extracting archives if necessary.""" 193 | os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True) 194 | 195 | # Download Files => Single-Threaded, with Progress Bar 196 | dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()] 197 | for dl_task in dl_tasks: 198 | dl_path = download_with_progress(dl_task["url"], download_dir) 199 | 200 | # Extract Files (if specified) --> Note (assumes ".zip" ONLY!) 201 | if dl_task["extract"]: 202 | dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"]) 203 | dl_path = dl_path.parent if dl_path.is_file() else dl_path 204 | 205 | # Rename Path --> dl_task["name"] 206 | if dl_task["do_rename"]: 207 | shutil.move(dl_path, download_dir / dl_task["name"]) 208 | -------------------------------------------------------------------------------- /prismatic/preprocessing/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for 5 | clear control flow. 6 | """ 7 | 8 | from typing import Tuple, Type 9 | 10 | from torch.utils.data import Dataset 11 | from transformers import PreTrainedTokenizerBase 12 | 13 | from prismatic.conf import DatasetConfig 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset 17 | from prismatic.util.data_utils import PaddedCollatorForLanguageModeling 18 | 19 | # Dataset Initializers =>> Maps Stage --> cls() 20 | DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} 21 | 22 | 23 | def get_dataset_and_collator( 24 | stage: str, 25 | dataset_cfg: DatasetConfig, 26 | image_transform: ImageTransform, 27 | tokenizer: PreTrainedTokenizerBase, 28 | prompt_builder_fn: Type[PromptBuilder], 29 | default_image_resolution: Tuple[int, int, int], 30 | padding_side: str = "right", 31 | ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: 32 | dataset_cls = DATASET_INITIALIZER[stage] 33 | dataset_root_dir = dataset_cfg.dataset_root_dir 34 | collator = PaddedCollatorForLanguageModeling( 35 | tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side 36 | ) 37 | 38 | # Switch on `stage` 39 | if stage == "align": 40 | annotation_json, image_dir = dataset_cfg.align_stage_components 41 | dataset = dataset_cls( 42 | dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer 43 | ) 44 | return dataset, collator 45 | 46 | elif stage == "finetune": 47 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 48 | dataset = dataset_cls( 49 | dataset_root_dir / annotation_json, 50 | dataset_root_dir / image_dir, 51 | image_transform, 52 | tokenizer, 53 | 
prompt_builder_fn=prompt_builder_fn, 54 | ) 55 | return dataset, collator 56 | 57 | elif stage == "full-finetune": 58 | annotation_json, image_dir = dataset_cfg.finetune_stage_components 59 | dataset = dataset_cls( 60 | dataset_root_dir / annotation_json, 61 | dataset_root_dir / image_dir, 62 | image_transform, 63 | tokenizer, 64 | prompt_builder_fn=prompt_builder_fn, 65 | ) 66 | return dataset, collator 67 | 68 | else: 69 | raise ValueError(f"Stage `{stage}` is not supported!") 70 | -------------------------------------------------------------------------------- /prismatic/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/prismatic/py.typed -------------------------------------------------------------------------------- /prismatic/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_train_strategy 2 | from .metrics import Metrics, VLAMetrics 3 | -------------------------------------------------------------------------------- /prismatic/training/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, 5 | and strategy configurations. 6 | """ 7 | 8 | from typing import Callable, Optional 9 | 10 | import torch 11 | 12 | from prismatic.models.vlms import PrismaticVLM 13 | from prismatic.training.strategies import FSDPStrategy, TrainingStrategy 14 | 15 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
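# Illustrative call from a training script (a sketch only -- every hyperparameter value below is a placeholder,
# not a default prescribed by this module):
#   >>> strategy = get_train_strategy(
#   ...     train_strategy="fsdp-full-shard", vlm=vlm, device_id=device_id, stage="finetune",
#   ...     epochs=1, max_steps=None, global_batch_size=256, per_device_batch_size=16,
#   ...     learning_rate=2e-5, weight_decay=0.0, max_grad_norm=1.0,
#   ...     lr_scheduler_type="linear-warmup+cosine-decay", warmup_ratio=0.03,
#   ... )
#   >>> strategy.run_setup(run_dir=run_dir, n_train_examples=len(train_dataset))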
16 | TRAIN_STRATEGIES = { 17 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, 18 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, 19 | } 20 | 21 | 22 | def get_train_strategy( 23 | train_strategy: str, 24 | vlm: PrismaticVLM, 25 | device_id: int, 26 | stage: str, 27 | epochs: int, 28 | max_steps: Optional[int], 29 | global_batch_size: int, 30 | per_device_batch_size: int, 31 | learning_rate: float, 32 | weight_decay: float, 33 | max_grad_norm: float, 34 | lr_scheduler_type: str, 35 | warmup_ratio: float, 36 | enable_gradient_checkpointing: bool = True, 37 | enable_mixed_precision_training: bool = True, 38 | reduce_in_full_precision: bool = False, 39 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 40 | worker_init_fn: Optional[Callable[[int], None]] = None, 41 | ) -> TrainingStrategy: 42 | if train_strategy in TRAIN_STRATEGIES: 43 | strategy_cfg = TRAIN_STRATEGIES[train_strategy] 44 | strategy = strategy_cfg["cls"]( 45 | vlm=vlm, 46 | device_id=device_id, 47 | stage=stage, 48 | epochs=epochs, 49 | max_steps=max_steps, 50 | global_batch_size=global_batch_size, 51 | per_device_batch_size=per_device_batch_size, 52 | learning_rate=learning_rate, 53 | weight_decay=weight_decay, 54 | max_grad_norm=max_grad_norm, 55 | lr_scheduler_type=lr_scheduler_type, 56 | warmup_ratio=warmup_ratio, 57 | enable_gradient_checkpointing=enable_gradient_checkpointing, 58 | enable_mixed_precision_training=enable_mixed_precision_training, 59 | reduce_in_full_precision=reduce_in_full_precision, 60 | mixed_precision_dtype=mixed_precision_dtype, 61 | worker_init_fn=worker_init_fn, 62 | **strategy_cfg["kwargs"], 63 | ) 64 | return strategy 65 | else: 66 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") 67 | -------------------------------------------------------------------------------- /prismatic/training/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_strategy import TrainingStrategy 2 | from .ddp import DDPStrategy 3 | from .fsdp import FSDPStrategy 4 | -------------------------------------------------------------------------------- /prismatic/training/strategies/ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | ddp.py 3 | 4 | Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most 5 | GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP. 
6 | """ 7 | 8 | import shutil 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.optim import AdamW 15 | from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup 16 | 17 | from prismatic.overwatch import initialize_overwatch 18 | from prismatic.training.strategies.base_strategy import TrainingStrategy 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | class DDPStrategy(TrainingStrategy): 25 | @overwatch.rank_zero_only 26 | def save_checkpoint( 27 | self, 28 | run_dir: Path, 29 | global_step: int, 30 | epoch: int, 31 | train_loss: Optional[float] = None, 32 | only_trainable: bool = True, 33 | ) -> None: 34 | """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" 35 | assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" 36 | 37 | # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) 38 | model_state_dicts = { 39 | mkey: getattr(self.vlm.module, mkey).state_dict() 40 | for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) 41 | } 42 | optimizer_state_dict = self.optimizer.state_dict() 43 | 44 | # Set Checkpoint Path =>> Embed *minimal* training statistics! 45 | checkpoint_dir = run_dir / "checkpoints" 46 | if train_loss is None: 47 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" 48 | else: 49 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" 50 | 51 | # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` 52 | torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) 53 | shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") 54 | 55 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None: 56 | # Gradient Checkpointing Setup 57 | if self.enable_gradient_checkpointing: 58 | # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up 59 | # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF 60 | # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` 61 | # on `self.llm_backbone`. 62 | # 63 | # What does it actually do? 
--> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic 64 | # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 65 | # 66 | # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) 67 | # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb 68 | overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) 69 | self.vlm.llm_backbone.gradient_checkpointing_enable() 70 | 71 | # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) 72 | overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) 73 | self.vlm.to(self.device_id) 74 | 75 | # Wrap with Distributed Data Parallel 76 | # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that 77 | # is the same size/dtype as the model parameters; this will *double* GPU memory! 78 | # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel 79 | overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) 80 | self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) 81 | 82 | # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` 83 | # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 84 | trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] 85 | if self.max_steps is None: 86 | num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size 87 | else: 88 | num_training_steps = self.max_steps 89 | 90 | if self.lr_scheduler_type == "linear-warmup+cosine-decay": 91 | # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) 92 | num_warmup_steps = int(num_training_steps * self.warmup_ratio) 93 | 94 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 95 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 96 | self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) 97 | for param_group in self.optimizer.param_groups: 98 | param_group["lr"] = 0.0 99 | 100 | elif self.lr_scheduler_type == "constant": 101 | num_warmup_steps = 0 102 | 103 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 
104 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 105 | self.lr_scheduler = get_constant_schedule(self.optimizer) 106 | 107 | else: 108 | raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") 109 | 110 | # Finalize Setup =>> Log 111 | overwatch.info( 112 | "DDP Strategy =>> Finalized Training Setup:\n" 113 | f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" 114 | f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" 115 | f" |-> Distributed World Size = {overwatch.world_size()}\n" 116 | f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" 117 | f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" 118 | f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" 119 | f" |-> Default AdamW LR = {self.learning_rate}\n" 120 | f" |-> AdamW Weight Decay = {self.weight_decay}\n" 121 | f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" 122 | f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" 123 | f" |-> Dataset Size = {n_train_examples} Examples\n" 124 | f" |-> Max Steps = {num_training_steps}\n" 125 | ) 126 | 127 | def clip_grad_norm(self) -> None: 128 | torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) 129 | -------------------------------------------------------------------------------- /prismatic/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_utils import check_bloat16_supported, set_global_seed 2 | -------------------------------------------------------------------------------- /prismatic/util/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | nn_utils.py 3 | 4 | Utility functions and PyTorch submodule definitions. 
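Illustrative usage (a sketch only -- the dimensions are placeholders for a vision backbone's patch width and an
LLM's embedding width):

    >>> projector = MLPProjector(vision_dim=1024, llm_dim=4096)
    >>> projected_patches = projector(img_patches)    # [..., 1024] -> [..., 4096]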
5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === 12 | class LinearProjector(nn.Module): 13 | def __init__(self, vision_dim: int, llm_dim: int) -> None: 14 | super().__init__() 15 | self.projector = nn.Linear(vision_dim, llm_dim, bias=True) 16 | 17 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 18 | return self.projector(img_patches) 19 | 20 | 21 | class MLPProjector(nn.Module): 22 | def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: 23 | super().__init__() 24 | if mlp_type == "gelu-mlp": 25 | self.projector = nn.Sequential( 26 | nn.Linear(vision_dim, llm_dim, bias=True), 27 | nn.GELU(), 28 | nn.Linear(llm_dim, llm_dim, bias=True), 29 | ) 30 | else: 31 | raise ValueError(f"Projector with `{mlp_type = }` is not supported!") 32 | 33 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 34 | return self.projector(img_patches) 35 | 36 | 37 | class FusedMLPProjector(nn.Module): 38 | def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: 39 | super().__init__() 40 | self.initial_projection_dim = fused_vision_dim * 4 41 | if mlp_type == "fused-gelu-mlp": 42 | self.projector = nn.Sequential( 43 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), 44 | nn.GELU(), 45 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True), 46 | nn.GELU(), 47 | nn.Linear(llm_dim, llm_dim, bias=True), 48 | ) 49 | else: 50 | raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") 51 | 52 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: 53 | return self.projector(fused_img_patches) 54 | -------------------------------------------------------------------------------- /prismatic/util/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | torch_utils.py 3 | 4 | General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch. 5 | 6 | Random `set_global_seed` functionality is taken directly from PyTorch-Lighting: 7 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py 8 | 9 | This is pretty important to get right if we're every randomly generating our masks (or prefix dropout) inside our 10 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime 11 | we inject randomness from non-PyTorch sources (e.g., numpy, random)! 12 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ 13 | 14 | Terminology 15 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous! 16 | -> Rank :: Integer index of current process in the total world size 17 | -> Local Rank :: Local index on given node in [0, Devices per Node] 18 | """ 19 | 20 | import os 21 | import random 22 | from typing import Callable, Optional 23 | 24 | import numpy as np 25 | import torch 26 | 27 | # === Randomness === 28 | 29 | 30 | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]: 31 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`""" 32 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!" 
33 | 34 | # Set Seed as an Environment Variable 35 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | 40 | return worker_init_function if get_worker_init_fn else None 41 | 42 | 43 | def worker_init_function(worker_id: int) -> None: 44 | """ 45 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: 46 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 47 | 48 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that 49 | you can run iterative splitting on to get new (predictable) randomness. 50 | 51 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. 52 | """ 53 | # Get current `rank` (if running distributed) and `process_seed` 54 | global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed() 55 | 56 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: 57 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness 58 | base_seed = process_seed - worker_id 59 | 60 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... 61 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) 62 | 63 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! 64 | np.random.seed(seed_seq.generate_state(4)) 65 | 66 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random 67 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2) 68 | 69 | # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 70 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) 71 | 72 | # Use 128 Bits for `random`, but express as integer instead of as an array 73 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() 74 | random.seed(random_seed) 75 | 76 | 77 | # === BFloat16 Support === 78 | 79 | 80 | def check_bloat16_supported() -> bool: 81 | try: 82 | import packaging.version 83 | import torch.cuda.nccl as nccl 84 | import torch.distributed as dist 85 | 86 | return ( 87 | (torch.version.cuda is not None) 88 | and torch.cuda.is_bf16_supported() 89 | and (packaging.version.parse(torch.version.cuda).release >= (11, 0)) 90 | and dist.is_nccl_available() 91 | and (nccl.version() >= (2, 10)) 92 | ) 93 | 94 | except Exception: 95 | return False 96 | -------------------------------------------------------------------------------- /prismatic/vla/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_vla_dataset_and_collator, get_latent_vla_dataset_and_collator 2 | -------------------------------------------------------------------------------- /prismatic/vla/action_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | action_tokenizer.py 3 | 4 | Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions. 
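Worked example of the encoding scheme (illustrative, using the default settings below): with 256 uniform bins over
[-1, 1], an action value of 0.0 falls into bin index 128 under `np.digitize`, and is therefore emitted as
vocabulary token id 32000 - 128 = 31872, i.e., one of the last 256 tokens of a LlamaTokenizer-style vocabulary.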
5 | """ 6 | 7 | from typing import List, Union 8 | 9 | import numpy as np 10 | from transformers import PreTrainedTokenizerBase 11 | 12 | 13 | class ActionTokenizer: 14 | def __init__( 15 | self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1 16 | ) -> None: 17 | """ 18 | Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens. 19 | 20 | NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens* 21 | appear at the end of the vocabulary! 22 | 23 | :param tokenizer: Base LLM/VLM tokenizer to extend. 24 | :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy. 25 | :param min_action: Minimum action value (for clipping, setting lower bound on bin interval). 26 | :param max_action: Maximum action value (for clipping, setting upper bound on bin interval). 27 | """ 28 | self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action 29 | 30 | # Create Uniform Bins + Compute Bin Centers 31 | self.bins = np.linspace(min_action, max_action, self.n_bins) 32 | self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0 33 | 34 | # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)` 35 | # =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary! 36 | self.action_token_begin_idx: int = int(32000 - (self.n_bins + 1)) 37 | 38 | def __call__(self, action: np.ndarray) -> Union[str, List[str]]: 39 | """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:]).""" 40 | action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action)) 41 | discretized_action = np.digitize(action, self.bins) 42 | 43 | # Handle single element vs. batch 44 | if len(discretized_action.shape) == 1: 45 | return self.tokenizer.decode(list(32000 - discretized_action)) 46 | else: 47 | return self.tokenizer.batch_decode((32000 - discretized_action).tolist()) 48 | 49 | def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray: 50 | """ 51 | Returns continuous actions for discrete action token IDs. 52 | 53 | NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the 54 | digitization returns bin indices between [1, # bins], inclusive, when there are actually only 55 | (# bins - 1) bin intervals. 56 | 57 | Therefore, if the digitization returns the last possible index, we map this to the last bin interval. 58 | 59 | EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns 60 | indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There 61 | is still one index (i==255) that would cause an out-of-bounds error if used to index into 62 | self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of 63 | the last bin center. We implement this simply via clipping between [0, 255 - 1]. 
64 | """ 65 | discretized_actions = 32000 - action_token_ids 66 | discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) 67 | 68 | return self.bin_centers[discretized_actions] 69 | 70 | @property 71 | def vocab_size(self) -> int: 72 | return self.n_bins 73 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, \ 2 | RLDSDataset, RLDSBatchTransformVideo, RLDSBatchTransformLatentAction,\ 3 | RLDSBatchTransformLIBERO, RLDSBatchTransformLIBERO_withHis 4 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import make_interleaved_dataset, make_single_dataset 2 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/obs_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | obs_transforms.py 3 | 4 | Contains observation-level transforms used in the orca data pipeline. 5 | 6 | These transforms operate on the "observation" dictionary, and are applied at a per-frame level. 7 | """ 8 | 9 | from typing import Dict, Tuple, Union 10 | 11 | import dlimp as dl 12 | import tensorflow as tf 13 | from absl import logging 14 | 15 | 16 | # ruff: noqa: B023 17 | def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict: 18 | """Augments images, skipping padding images.""" 19 | image_names = {key[6:] for key in obs if key.startswith("image_")} 20 | 21 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed 22 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image 23 | # name to augmentation dict) 24 | if "augment_order" in augment_kwargs: 25 | augment_kwargs = {name: augment_kwargs for name in image_names} 26 | 27 | for i, name in enumerate(image_names): 28 | if name not in augment_kwargs: 29 | continue 30 | kwargs = augment_kwargs[name] 31 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") 32 | obs[f"image_{name}"] = tf.cond( 33 | obs["pad_mask_dict"][f"image_{name}"], 34 | lambda: dl.transforms.augment_image( 35 | obs[f"image_{name}"], 36 | **kwargs, 37 | seed=seed + i, # augment each image differently 38 | ), 39 | lambda: obs[f"image_{name}"], # skip padding images 40 | ) 41 | 42 | return obs 43 | 44 | 45 | def decode_and_resize( 46 | obs: Dict, 47 | resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], 48 | depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], 49 | ) -> Dict: 50 | """Decodes images and depth images, and then optionally resizes them.""" 51 | image_names = {key[6:] for key in obs if key.startswith("image_")} 52 | depth_names = {key[6:] for key in obs if key.startswith("depth_")} 53 | print('image_names', image_names) 54 | # print('depth_names', depth_names) 55 | if isinstance(resize_size, tuple): 56 | resize_size = {name: resize_size for name in image_names} 57 | if isinstance(depth_resize_size, tuple): 58 | depth_resize_size = {name: depth_resize_size for name in depth_names} 59 | 60 | print('keys', obs.keys()) 61 | for name in image_names: 62 | if name not in resize_size: 
63 | logging.warning( 64 | f"No resize_size was provided for image_{name}. This will result in 1x1 " 65 | "padding images, which may cause errors if you mix padding and non-padding images." 66 | ) 67 | image = obs[f"image_{name}"] 68 | if image.dtype == tf.string: 69 | if tf.strings.length(image) == 0: 70 | # this is a padding image 71 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) 72 | else: 73 | image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8) 74 | elif image.dtype != tf.uint8: 75 | raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}") 76 | if name in resize_size: 77 | image = dl.transforms.resize_image(image, size=resize_size[name]) 78 | obs[f"image_{name}"] = image 79 | 80 | for name in depth_names: 81 | if name not in depth_resize_size: 82 | logging.warning( 83 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " 84 | "padding depth images, which may cause errors if you mix padding and non-padding images." 85 | ) 86 | depth = obs[f"depth_{name}"] 87 | 88 | if depth.dtype == tf.string: 89 | if tf.strings.length(depth) == 0: 90 | depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32) 91 | else: 92 | depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0] 93 | elif depth.dtype != tf.float32: 94 | raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}") 95 | 96 | if name in depth_resize_size: 97 | depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name]) 98 | 99 | obs[f"depth_{name}"] = depth 100 | 101 | return obs 102 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_oxe_dataset_kwargs_and_weights 2 | from .mixtures import OXE_NAMED_MIXTURES 3 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for 5 | clear control flow. 
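    Minimal usage sketch of the factory functions defined in this module (the data root path and the
    choice of the "bridge" mixture are illustrative assumptions):

        from pathlib import Path

        from prismatic.vla.datasets.rlds.oxe import OXE_NAMED_MIXTURES, get_oxe_dataset_kwargs_and_weights

        per_dataset_kwargs, sampling_weights = get_oxe_dataset_kwargs_and_weights(
            data_root_dir=Path("/data/open_x_embodiment"),    # assumed local RLDS/TFDS root
            mixture_spec=OXE_NAMED_MIXTURES["bridge"],        # list of (dataset_name, sampling_weight) pairs
            load_camera_views=("primary",),
        )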
6 | """ 7 | 8 | from copy import deepcopy 9 | from pathlib import Path 10 | from typing import Any, Dict, List, Tuple 11 | 12 | from prismatic.overwatch import initialize_overwatch 13 | from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding 14 | from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS 15 | from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType 16 | 17 | # Initialize Overwatch =>> Wraps `logging.Logger` 18 | overwatch = initialize_overwatch(__name__) 19 | 20 | 21 | def make_oxe_dataset_kwargs( 22 | dataset_name: str, 23 | data_root_dir: Path, 24 | load_camera_views: Tuple[str] = ("primary",), 25 | load_depth: bool = False, 26 | load_proprio: bool = True, 27 | load_language: bool = True, 28 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 29 | ) -> Dict[str, Any]: 30 | """Generates config (kwargs) for given dataset from Open-X Embodiment.""" 31 | dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name]) 32 | if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6]: 33 | raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 actions supported!") 34 | 35 | # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute! 36 | # Normalize all action dimensions *except* the gripper 37 | if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: 38 | dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] 39 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] 40 | elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6: 41 | dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True] 42 | dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False] 43 | dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type 44 | 45 | # Adjust Loaded Camera Views 46 | if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0: 47 | raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`") 48 | 49 | # Filter 50 | dataset_kwargs["image_obs_keys"] = { 51 | k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views 52 | } 53 | dataset_kwargs["depth_obs_keys"] = { 54 | k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views 55 | } 56 | 57 | # Eliminate Unnecessary Keys 58 | dataset_kwargs.pop("state_encoding") 59 | dataset_kwargs.pop("action_encoding") 60 | if not load_depth: 61 | dataset_kwargs.pop("depth_obs_keys") 62 | if not load_proprio: 63 | dataset_kwargs.pop("state_obs_keys") 64 | 65 | # Load Language 66 | if load_language: 67 | dataset_kwargs["language_key"] = "language_instruction" 68 | 69 | # Specify Standardization Transform 70 | dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[dataset_name] 71 | 72 | # Add any aux arguments 73 | if "aux_kwargs" in dataset_kwargs: 74 | dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs")) 75 | 76 | return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs} 77 | 78 | 79 | def get_oxe_dataset_kwargs_and_weights( 80 | data_root_dir: Path, 81 | mixture_spec: List[Tuple[str, float]], 82 | load_camera_views: Tuple[str] = ("primary",), 83 | load_depth: bool = False, 84 | load_proprio: bool = True, 85 | load_language: bool = True, 86 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 87 | ) -> 
Tuple[Dict[str, Any], List[float]]: 88 | """ 89 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs 90 | (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`. 91 | 92 | :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X) 93 | :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES` 94 | :param load_camera_views: Camera views to load; see `oxe.dataset_configs.py` for available views. 95 | :param load_depth: Load depth information in addition to camera RGB. 96 | :param load_proprio: Load proprioceptive state. 97 | :param load_language: Load language instructions. 98 | :param action_proprio_normalization_type: Normalization scheme to use for proprioceptive actions. 99 | 100 | return: Tuple of (per_dataset_kwargs, sampling_weights) 101 | """ 102 | included_datasets, filtered_mixture_spec = set(), [] 103 | for d_name, d_weight in mixture_spec: 104 | if d_name in included_datasets: 105 | overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`") 106 | continue 107 | 108 | included_datasets.add(d_name) 109 | filtered_mixture_spec.append((d_name, d_weight)) 110 | 111 | # Assemble Dataset Config (kwargs) and Weights 112 | per_dataset_kwargs, sampling_weights = [], [] 113 | for d_name, d_weight in filtered_mixture_spec: 114 | try: 115 | per_dataset_kwargs.append( 116 | make_oxe_dataset_kwargs( 117 | d_name, 118 | data_root_dir, 119 | load_camera_views, 120 | load_depth, 121 | load_proprio, 122 | load_language, 123 | action_proprio_normalization_type, 124 | ) 125 | ) 126 | sampling_weights.append(d_weight) 127 | 128 | except ValueError as e: 129 | overwatch.warning(f"Skipping `{d_name}` due to Error: {e}") 130 | 131 | return per_dataset_kwargs, sampling_weights 132 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/mixtures.py: -------------------------------------------------------------------------------- 1 | """ 2 | mixtures.py 3 | 4 | Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets. 
Each dataset is associated with 5 | a float "sampling weight" 6 | """ 7 | 8 | from typing import Dict, List, Tuple 9 | 10 | # fmt: off 11 | OXE_NAMED_MIXTURES: Dict[str, List[Tuple[str, float]]] = { 12 | # === Bridge V2 Dataset === 13 | "bridge": [ 14 | ("bridge_oxe", 1.0), # Version of Bridge V2 in Open-X GCP Bucket 15 | # ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website 16 | ], 17 | 18 | "droid": [ 19 | ("droid", 1.0), 20 | ], 21 | 22 | # === Human-data Only === 23 | "Ego4D": [ 24 | ("ego4d_split_1", 1.0), 25 | ("ego4d_split_2", 1.0), 26 | ("ego4d_split_3", 1.0), 27 | ("ego4d_split_4", 1.0), 28 | ], 29 | 30 | 31 | "roboset": [ 32 | ("roboset", 1.0), 33 | ], 34 | 35 | "stanford_robocook_converted_externally_to_rlds": [ 36 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 37 | ], 38 | 39 | # === [Moderate-Scale] Bridge++ Mixtures === 40 | "bridge_rt_1": [ 41 | # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket 42 | ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website 43 | ("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale) 44 | ], 45 | 46 | "rt_1": [ 47 | ("fractal20220817_data", 1.0), 48 | ], 49 | 50 | # === UniVLA Magic Soup+ === 51 | "omni_magic_soup_plus": [ 52 | ("fractal20220817_data", 0.5), 53 | ("kuka", 0.1), 54 | ("bridge_oxe", 1.0), 55 | ("taco_play", 2.0), 56 | ("jaco_play", 1.0), 57 | ("berkeley_cable_routing", 1.0), 58 | ("roboturk", 2.0), 59 | ("viola", 2.0), 60 | ("berkeley_autolab_ur5", 2.0), 61 | ("toto", 1.0), 62 | ("language_table", 0.1), 63 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 64 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 65 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 66 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 67 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 68 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 69 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 70 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 71 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 72 | ("utaustin_mutex", 1.0), 73 | ("berkeley_fanuc_manipulation", 2.0), 74 | ("cmu_stretch", 1.0), 75 | ("bc_z", 0.2), 76 | ("fmb", 1.0), 77 | ("dobbe", 0.2), 78 | ## Datasets for Navigation 79 | ("berkeley_gnm_recon", 1.0), 80 | ("berkeley_gnm_cory_hall", 1.0), 81 | ("berkeley_gnm_sac_son", 1.0), 82 | ], 83 | 84 | # === UniVLA Magic Soup++ === 85 | "omni_magic_soup_plus_plus": [ 86 | ("fractal20220817_data", 0.5), 87 | ("kuka", 0.1), 88 | ("bridge_oxe", 1.0), 89 | ("taco_play", 2.0), 90 | ("jaco_play", 1.0), 91 | ("berkeley_cable_routing", 1.0), 92 | ("roboturk", 2.0), 93 | ("viola", 2.0), 94 | ("berkeley_autolab_ur5", 2.0), 95 | ("toto", 1.0), 96 | ("language_table", 0.1), 97 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 98 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 99 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 100 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 101 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 102 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 103 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 104 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 105 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 106 | ("utaustin_mutex", 1.0), 107 | ("berkeley_fanuc_manipulation", 2.0), 108 | 
("cmu_stretch", 1.0), 109 | ("bc_z", 0.2), 110 | ("fmb", 1.0), 111 | ("dobbe", 0.2), 112 | ## Datasets for Navigation 113 | ("berkeley_gnm_recon", 2.0), 114 | ("berkeley_gnm_cory_hall", 2.0), 115 | ("berkeley_gnm_sac_son", 2.0), 116 | ## Human Datasets 117 | ("ego4d_split_1", 1.0), 118 | ("ego4d_split_2", 1.0), 119 | ("ego4d_split_3", 1.0), 120 | ("ego4d_split_4", 1.0), 121 | ], 122 | 123 | # === T-DROID Dataset === 124 | "tdroid_carrot_in_bowl": [ 125 | ("tdroid_carrot_in_bowl", 1.0), 126 | ], 127 | "tdroid_pour_corn_in_pot": [ 128 | ("tdroid_pour_corn_in_pot", 1.0), 129 | ], 130 | "tdroid_flip_pot_upright": [ 131 | ("tdroid_flip_pot_upright", 1.0), 132 | ], 133 | "tdroid_move_object_onto_plate": [ 134 | ("tdroid_move_object_onto_plate", 1.0), 135 | ], 136 | "tdroid_knock_object_over": [ 137 | ("tdroid_knock_object_over", 1.0), 138 | ], 139 | "tdroid_cover_object_with_towel": [ 140 | ("tdroid_cover_object_with_towel", 1.0), 141 | ], 142 | 143 | # === DROID Finetuning Datasets === 144 | "droid_wipe": [ 145 | ("droid_wipe", 1.0), 146 | ], 147 | 148 | # === LIBERO Datasets (Modified Versions) === 149 | "libero_spatial_no_noops": [ 150 | ("libero_spatial_no_noops", 1.0), 151 | ], 152 | "libero_object_no_noops": [ 153 | ("libero_object_no_noops", 1.0), 154 | ], 155 | "libero_goal_no_noops": [ 156 | ("libero_goal_no_noops", 1.0), 157 | ], 158 | "libero_10_no_noops": [ 159 | ("libero_10_no_noops", 1.0), 160 | ], 161 | "libero_10_no_noops_mini": [ 162 | ("libero_10_no_noops_mini", 1.0), 163 | ], 164 | "libero_goal_no_noops_mini": [ 165 | ("libero_goal_no_noops_mini", 1.0), 166 | ], 167 | "libero_goal_no_noops_half": [ 168 | ("libero_goal_no_noops_half", 1.0), 169 | ], 170 | "libero_10_no_noops_half": [ 171 | ("libero_10_no_noops_half", 1.0), 172 | ], 173 | "libero_goal_no_noops_quad": [ 174 | ("libero_goal_no_noops_quad", 1.0), 175 | ], 176 | "libero_10_no_noops_quad": [ 177 | ("libero_10_no_noops_quad", 1.0), 178 | ], 179 | "libero_combined": [ 180 | ("libero_combined", 1.0), 181 | ], 182 | } 183 | # fmt: on 184 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py: -------------------------------------------------------------------------------- 1 | """Episode transforms for DROID dataset.""" 2 | 3 | from typing import Any, Dict 4 | 5 | import tensorflow as tf 6 | import tensorflow_graphics.geometry.transformation as tfg 7 | 8 | 9 | def rmat_to_euler(rot_mat): 10 | return tfg.euler.from_rotation_matrix(rot_mat) 11 | 12 | 13 | def euler_to_rmat(euler): 14 | return tfg.rotation_matrix_3d.from_euler(euler) 15 | 16 | 17 | def invert_rmat(rot_mat): 18 | return tfg.rotation_matrix_3d.inverse(rot_mat) 19 | 20 | 21 | def rotmat_to_rot6d(mat): 22 | """ 23 | Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). 24 | Args: 25 | mat: rotation matrix 26 | 27 | Returns: 6d vector (first two rows of rotation matrix) 28 | 29 | """ 30 | r6 = mat[..., :2, :] 31 | r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] 32 | r6_flat = tf.concat([r6_0, r6_1], axis=-1) 33 | return r6_flat 34 | 35 | 36 | def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): 37 | """ 38 | Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. 
39 | Args: 40 | velocity: 6d velocity action (3 x translation, 3 x rotation) 41 | wrist_in_robot_frame: 6d pose of the end-effector in robot base frame 42 | 43 | Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) 44 | 45 | """ 46 | R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) 47 | R_frame_inv = invert_rmat(R_frame) 48 | 49 | # world to wrist: dT_pi = R^-1 dT_rbt 50 | vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] 51 | 52 | # world to wrist: dR_pi = R^-1 dR_rbt R 53 | dR = euler_to_rmat(velocity[:, 3:6]) 54 | dR = R_frame_inv @ (dR @ R_frame) 55 | dR_r6 = rotmat_to_rot6d(dR) 56 | return tf.concat([vel_t, dR_r6], axis=-1) 57 | 58 | 59 | def rand_swap_exterior_images(img1, img2): 60 | """ 61 | Randomly swaps the two exterior images (for training with single exterior input). 62 | """ 63 | return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) 64 | 65 | 66 | def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 67 | """ 68 | DROID dataset transformation for actions expressed in *base* frame of the robot. 69 | """ 70 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] 71 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] 72 | 73 | trajectory["action"] = tf.concat( 74 | ( 75 | dt, 76 | dR, 77 | 1 - trajectory["action_dict"]["gripper_position"], 78 | ), 79 | axis=-1, 80 | ) 81 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( 82 | rand_swap_exterior_images( 83 | trajectory["observation"]["exterior_image_1_left"], 84 | trajectory["observation"]["exterior_image_2_left"], 85 | ) 86 | ) 87 | trajectory["observation"]["proprio"] = tf.concat( 88 | ( 89 | trajectory["observation"]["cartesian_position"], 90 | trajectory["observation"]["gripper_position"], 91 | ), 92 | axis=-1, 93 | ) 94 | print(trajectory['observation'].keys()) 95 | return trajectory 96 | 97 | 98 | def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 99 | """ 100 | DROID dataset transformation for actions expressed in *wrist* frame of the robot. 101 | """ 102 | wrist_act = velocity_act_to_wrist_frame( 103 | trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] 104 | ) 105 | trajectory["action"] = tf.concat( 106 | ( 107 | wrist_act, 108 | trajectory["action_dict"]["gripper_position"], 109 | ), 110 | axis=-1, 111 | ) 112 | trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( 113 | rand_swap_exterior_images( 114 | trajectory["observation"]["exterior_image_1_left"], 115 | trajectory["observation"]["exterior_image_2_left"], 116 | ) 117 | ) 118 | trajectory["observation"]["proprio"] = tf.concat( 119 | ( 120 | trajectory["observation"]["cartesian_position"], 121 | trajectory["observation"]["gripper_position"], 122 | ), 123 | axis=-1, 124 | ) 125 | return trajectory 126 | 127 | 128 | def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: 129 | """ 130 | DROID dataset transformation for actions expressed in *base* frame of the robot. 
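    Shape sketch of the action assembly performed below (a NumPy stand-in with made-up values; the
    actual transform operates on TensorFlow tensors covering the whole trajectory):

        import numpy as np

        T = 2                                        # illustrative trajectory length
        cartesian_velocity = np.zeros((T, 6))        # 3 translation + 3 rotation deltas per step
        gripper_position = np.array([[0.0], [1.0]])  # raw gripper channel, shape (T, 1)
        action = np.concatenate(
            [cartesian_velocity[:, :3], cartesian_velocity[:, 3:6], 1.0 - gripper_position], axis=-1
        )
        assert action.shape == (T, 7)                # 7d base-frame action per timestep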
131 | """ 132 | dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] 133 | dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] 134 | trajectory["action"] = tf.concat( 135 | ( 136 | dt, 137 | dR, 138 | 1 - trajectory["action_dict"]["gripper_position"], 139 | ), 140 | axis=-1, 141 | ) 142 | trajectory["observation"]["proprio"] = tf.concat( 143 | ( 144 | trajectory["observation"]["cartesian_position"], 145 | trajectory["observation"]["gripper_position"], 146 | ), 147 | axis=-1, 148 | ) 149 | return trajectory 150 | 151 | 152 | def zero_action_filter(traj: Dict) -> bool: 153 | """ 154 | Filters transitions whose actions are all-0 (only relative actions, no gripper action). 155 | Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". 156 | """ 157 | DROID_Q01 = tf.convert_to_tensor( 158 | [ 159 | -0.7776297926902771, 160 | -0.5803514122962952, 161 | -0.5795090794563293, 162 | -0.6464047729969025, 163 | -0.7041108310222626, 164 | -0.8895104378461838, 165 | ] 166 | ) 167 | DROID_Q99 = tf.convert_to_tensor( 168 | [ 169 | 0.7597932070493698, 170 | 0.5726242214441299, 171 | 0.7351000607013702, 172 | 0.6705610305070877, 173 | 0.6464948207139969, 174 | 0.8897542208433151, 175 | ] 176 | ) 177 | DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 178 | 179 | return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) 180 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/traj_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | traj_transforms.py 3 | 4 | Contains trajectory transforms used in the orca data pipeline. Trajectory transforms operate on a dictionary 5 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length). 
6 | """ 7 | 8 | import logging 9 | from typing import Dict 10 | 11 | import tensorflow as tf 12 | 13 | def chunk_act_obs(traj, window_size, future_action_window_size): 14 | traj_len = tf.shape(traj["action"])[0] 15 | action_dim = traj["action"].shape[-1] 16 | 17 | # Create indices for the first and last elements within the window size 18 | first_indices = tf.range(traj_len)[:, None] # First index is the current timestep 19 | last_indices = tf.maximum(first_indices + (window_size - 1), 0) # Last index is the end of the window 20 | 21 | # Combine first and last indices into a single tensor 22 | chunk_indices = tf.concat([first_indices, last_indices], axis=1) # Shape: [traj_len, 2] 23 | 24 | # Create action_chunk_indices for the first and last elements 25 | action_first_indices = first_indices 26 | action_last_indices = tf.minimum(first_indices + (window_size + future_action_window_size - 1), traj_len - 1) 27 | action_chunk_indices = tf.concat([action_first_indices, action_last_indices], axis=1) # Shape: [traj_len, 2] 28 | 29 | # Ensure indices are bounded 30 | floored_chunk_indices = tf.maximum(tf.minimum(chunk_indices, traj_len - 1), 0) 31 | 32 | if "timestep" in traj["task"]: 33 | goal_timestep = traj["task"]["timestep"] 34 | else: 35 | goal_timestep = tf.fill([traj_len], traj_len - 1) 36 | 37 | 38 | floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) 39 | 40 | traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) 41 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) 42 | 43 | # indicates whether an entire observation is padding 44 | traj["observation"]["pad_mask"] = chunk_indices >= 0 45 | 46 | # If no absolute_action_mask was provided, assume all actions are relative 47 | if "absolute_action_mask" not in traj and future_action_window_size > 0: 48 | logging.warning( 49 | "future_action_window_size > 0 but no absolute_action_mask was provided. " 50 | "Assuming all actions are relative for the purpose of making neutral actions." 51 | ) 52 | absolute_action_mask = traj.get("absolute_action_mask", tf.zeros([traj_len, action_dim], dtype=tf.bool)) 53 | neutral_actions = tf.where( 54 | absolute_action_mask[:, None, :], 55 | traj["action"], # absolute actions are repeated (already done during chunking) 56 | tf.zeros_like(traj["action"]), # relative actions are zeroed 57 | ) 58 | 59 | # Actions past the goal timestep become neutral 60 | action_past_goal = action_chunk_indices > goal_timestep[:, None] 61 | traj["action"] = tf.where(action_past_goal[:, :, None], neutral_actions, traj["action"]) 62 | 63 | return traj 64 | 65 | 66 | def chunk_act_obs_libero(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict: 67 | """ 68 | Chunks actions and observations into the given window_size. 69 | 70 | "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` 71 | observations from the past and the current observation. "action" is given a new axis (at index 1) of size 72 | `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current 73 | action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and 74 | indicates whether an observation should be considered padding (i.e. if it had come from a timestep 75 | before the start of the trajectory). 
76 | """ 77 | traj_len = tf.shape(traj["action"])[0] 78 | action_dim = traj["action"].shape[-1] 79 | chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [traj_len, window_size]) + tf.broadcast_to( 80 | tf.range(traj_len)[:, None], [traj_len, window_size] 81 | ) 82 | print('chunk_indices', chunk_indices) 83 | action_chunk_indices = tf.broadcast_to( 84 | tf.range(-window_size + 1, 1 + future_action_window_size), 85 | [traj_len, window_size + future_action_window_size], 86 | ) + tf.broadcast_to( 87 | tf.range(traj_len)[:, None], 88 | [traj_len, window_size + future_action_window_size], 89 | ) 90 | 91 | floored_chunk_indices = tf.maximum(chunk_indices, 0) 92 | 93 | if "timestep" in traj["task"]: 94 | goal_timestep = traj["task"]["timestep"] 95 | else: 96 | goal_timestep = tf.fill([traj_len], traj_len - 1) 97 | 98 | floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) 99 | 100 | traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) 101 | traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) 102 | 103 | # indicates whether an entire observation is padding 104 | traj["observation"]["pad_mask"] = chunk_indices >= 0 105 | 106 | # if no absolute_action_mask was provided, assume all actions are relative 107 | if "absolute_action_mask" not in traj and future_action_window_size > 0: 108 | logging.warning( 109 | "future_action_window_size > 0 but no absolute_action_mask was provided. " 110 | "Assuming all actions are relative for the purpose of making neutral actions." 111 | ) 112 | absolute_action_mask = traj.get("absolute_action_mask", tf.zeros([traj_len, action_dim], dtype=tf.bool)) 113 | neutral_actions = tf.where( 114 | absolute_action_mask[:, None, :], 115 | traj["action"], # absolute actions are repeated (already done during chunking) 116 | tf.zeros_like(traj["action"]), # relative actions are zeroed 117 | ) 118 | 119 | # actions past the goal timestep become neutral 120 | action_past_goal = action_chunk_indices > goal_timestep[:, None] 121 | traj["action"] = tf.where(action_past_goal[:, :, None], neutral_actions, traj["action"]) 122 | 123 | return traj 124 | 125 | 126 | def subsample(traj: Dict, subsample_length: int) -> Dict: 127 | """Subsamples trajectories to the given length.""" 128 | traj_len = tf.shape(traj["action"])[0] 129 | if traj_len > subsample_length: 130 | indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] 131 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) 132 | 133 | return traj 134 | 135 | 136 | def add_pad_mask_dict(traj: Dict) -> Dict: 137 | """ 138 | Adds a dictionary indicating which elements of the observation/task should be treated as padding. 
139 | =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} 140 | """ 141 | traj_len = tf.shape(traj["action"])[0] 142 | 143 | for key in ["observation", "task"]: 144 | pad_mask_dict = {} 145 | for subkey in traj[key]: 146 | # Handles "language_instruction", "image_*", and "depth_*" 147 | if traj[key][subkey].dtype == tf.string: 148 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 149 | 150 | # All other keys should not be treated as padding 151 | else: 152 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) 153 | 154 | traj[key]["pad_mask_dict"] = pad_mask_dict 155 | 156 | return traj 157 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/UniVLA/c9c6788514a707daf536848197b9a5024fc85b6e/prismatic/vla/datasets/rlds/utils/__init__.py -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | goal_relabeling.py 3 | 4 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 5 | Each function should add entries to the "task" dict. 6 | """ 7 | 8 | from typing import Dict 9 | 10 | import tensorflow as tf 11 | 12 | from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge 13 | 14 | 15 | def uniform(traj: Dict) -> Dict: 16 | """Relabels with a true uniform distribution over future states.""" 17 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 18 | 19 | # Select a random future index for each transition i in the range [i + 1, traj_len) 20 | rand = tf.random.uniform([traj_len]) 21 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 22 | high = tf.cast(traj_len, tf.float32) 23 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 24 | 25 | # Sometimes there are floating-point errors that cause an out-of-bounds 26 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 27 | 28 | # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) 29 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 30 | traj["task"] = tree_merge(traj["task"], goal) 31 | 32 | return traj 33 | -------------------------------------------------------------------------------- /prismatic/vla/datasets/rlds/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | task_augmentation.py 3 | 4 | Contains basic logic for randomly zeroing out keys in the task specification. 5 | """ 6 | 7 | from typing import Dict 8 | 9 | import tensorflow as tf 10 | 11 | from prismatic.vla.datasets.rlds.utils.data_utils import to_padding 12 | 13 | 14 | def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: 15 | """ 16 | Randomly drops out either the goal images or the language instruction. Only does something if both of 17 | these are present. 18 | 19 | Args: 20 | traj: A dictionary containing trajectory data. Should have a "task" key. 21 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 22 | instruction is 1 - keep_image_prob. 
23 | """ 24 | if "language_instruction" not in traj["task"]: 25 | return traj 26 | 27 | image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} 28 | if not image_keys: 29 | return traj 30 | 31 | traj_len = tf.shape(traj["action"])[0] 32 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 33 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 34 | 35 | for key in image_keys | {"language_instruction"}: 36 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 37 | # pad out the key 38 | traj["task"][key] = tf.where( 39 | should_keep, 40 | traj["task"][key], 41 | to_padding(traj["task"][key]), 42 | ) 43 | # zero out the pad mask dict for the key 44 | traj["task"]["pad_mask_dict"][key] = tf.where( 45 | should_keep, 46 | traj["task"]["pad_mask_dict"][key], 47 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 48 | ) 49 | 50 | # when no goal images are present, the goal timestep becomes the final timestep 51 | traj["task"]["timestep"] = tf.where( 52 | should_keep_images, 53 | traj["task"]["timestep"], 54 | traj_len - 1, 55 | ) 56 | 57 | return traj 58 | -------------------------------------------------------------------------------- /prismatic/vla/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and 5 | exports individual functions for clear control flow. 6 | """ 7 | 8 | from pathlib import Path 9 | from typing import Tuple, Type 10 | 11 | from torch.utils.data import Dataset 12 | from transformers import PreTrainedTokenizerBase 13 | 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder 15 | from prismatic.models.backbones.vision import ImageTransform 16 | from prismatic.util.data_utils import PaddedCollatorForActionPrediction 17 | from prismatic.vla.action_tokenizer import ActionTokenizer 18 | from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSBatchTransformLatentAction, RLDSDataset 19 | 20 | 21 | def get_vla_dataset_and_collator( 22 | data_root_dir: Path, 23 | data_mix: str, 24 | image_transform: ImageTransform, 25 | tokenizer: PreTrainedTokenizerBase, 26 | prompt_builder_fn: Type[PromptBuilder], 27 | default_image_resolution: Tuple[int, int, int], 28 | padding_side: str = "right", 29 | predict_stop_token: bool = True, 30 | shuffle_buffer_size: int = 100_000, 31 | train: bool = True, 32 | episodic: bool = False, 33 | image_aug: bool = False, 34 | ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: 35 | """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" 36 | action_tokenizer = ActionTokenizer(tokenizer) 37 | batch_transform = RLDSBatchTransform( 38 | action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token 39 | ) 40 | collator = PaddedCollatorForActionPrediction( 41 | tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side 42 | ) 43 | 44 | # Build RLDS Iterable Dataset 45 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset 46 | dataset = cls( 47 | data_root_dir, 48 | data_mix, 49 | batch_transform, 50 | resize_resolution=default_image_resolution[1:], 51 | shuffle_buffer_size=shuffle_buffer_size, 52 | train=train, 53 | image_aug=image_aug, 54 | ) 55 | 56 | return dataset, 
action_tokenizer, collator 57 | 58 | 59 | def get_latent_vla_dataset_and_collator( 60 | data_root_dir: Path, 61 | data_mix: str, 62 | image_transform: ImageTransform, 63 | image_transform_lam: ImageTransform, 64 | latent_action_tokenizer: PreTrainedTokenizerBase, 65 | tokenizer: PreTrainedTokenizerBase, 66 | prompt_builder_fn: Type[PromptBuilder], 67 | default_image_resolution: Tuple[int, int, int], 68 | padding_side: str = "right", 69 | predict_stop_token: bool = True, 70 | shuffle_buffer_size: int = 100_000, 71 | train: bool = True, 72 | episodic: bool = False, 73 | image_aug: bool = False, 74 | ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: 75 | """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" 76 | # action_tokenizer = ActionTokenizer(tokenizer) 77 | 78 | batch_transform = RLDSBatchTransformLatentAction( 79 | action_tokenizer=latent_action_tokenizer, 80 | base_tokenizer=tokenizer, 81 | image_transform=image_transform, 82 | image_transform_lam=image_transform_lam, 83 | prompt_builder_fn=prompt_builder_fn 84 | ) 85 | 86 | collator = PaddedCollatorForActionPrediction( 87 | tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side 88 | ) 89 | 90 | 91 | # Build RLDS Iterable Dataset 92 | cls = RLDSDataset if not episodic else EpisodicRLDSDataset 93 | dataset = cls( 94 | data_root_dir, 95 | data_mix, 96 | batch_transform, 97 | resize_resolution=default_image_resolution[1:], 98 | shuffle_buffer_size=shuffle_buffer_size, 99 | train=train, 100 | image_aug=image_aug, 101 | training_phase='pre-training', 102 | ) 103 | 104 | return dataset, tokenizer, collator -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "univla" 7 | authors = [ 8 | {name = "Qingwen Bu", email="qwbu01@sjtu.edu.cn"}, 9 | ] 10 | description = "UniVLA: Learning to Act Anywhere with Task-centric Latent Actions" 11 | version = "1.0.0" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | keywords = ["robotic manipulation", "vision-language-action models", "latent action"] 15 | license = {file = "LICENSE"} 16 | classifiers = [ 17 | "Development Status :: 3 - Alpha", 18 | "Intended Audience :: Developers", 19 | "Intended Audience :: Education", 20 | "Intended Audience :: Science/Research", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.8", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3 :: Only", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | ] 30 | dependencies = [ 31 | "absl_py==2.1.0", 32 | "accelerate==0.32.1", 33 | "dlimp @ git+https://github.com/moojink/dlimp_openvla", 34 | "draccus==0.8.0", 35 | "einops==0.8.1", 36 | "ema_pytorch==0.5.1", 37 | "gym==0.26.2", 38 | "h5py==3.11.0", 39 | "huggingface_hub==0.26.1", 40 | "hydra-core==1.3.2", 41 | "imageio==2.34.2", 42 | "jsonlines==4.0.0", 43 | "lightning==2.4.0", 44 | "matplotlib==3.10.1", 45 | "moviepy==1.0.3", 46 | "numpy==1.26.4", 47 | "omegaconf==2.3.0", 48 | "opencv_python==4.10.0.84", 49 | "packaging==24.1", 50 | "peft==0.11.1", 51 | "Pillow==11.2.1", 52 | "piq==0.8.0", 53 | 
"pyquaternion==0.9.9", 54 | "pytorch_lightning==1.8.6", 55 | "PyYAML==6.0.1", 56 | "Requests==2.32.3", 57 | "rich==14.0.0", 58 | "robosuite==1.4.1", 59 | "rotary_embedding_torch==0.8.4", 60 | "setuptools==57.5.0", 61 | "tensorflow==2.15.0", 62 | "tensorflow_datasets==4.9.3", 63 | "tensorflow_graphics==2021.12.3", 64 | "termcolor==3.0.1", 65 | "timm==0.9.10", 66 | "tokenizers==0.19.1", 67 | "tqdm==4.66.4", 68 | "transformers==4.40.1" 69 | ] 70 | 71 | [project.optional-dependencies] 72 | dev = [ 73 | "black>=24.2.0", 74 | "gpustat", 75 | "ipython", 76 | "pre-commit", 77 | "ruff>=0.2.2", 78 | ] 79 | sagemaker = [ 80 | "boto3", 81 | "sagemaker" 82 | ] 83 | 84 | [project.urls] 85 | homepage = "https://opendrivelab.com/UniVLA/" 86 | 87 | 88 | [tool.setuptools.packages.find] 89 | where = ["."] 90 | exclude = ["cache"] 91 | 92 | [tool.setuptools.package-data] 93 | "prismatic" = ["py.typed"] 94 | 95 | [tool.black] 96 | line-length = 121 97 | target-version = ["py38", "py39", "py310"] 98 | preview = true 99 | 100 | [tool.ruff] 101 | line-length = 121 102 | target-version = "py38" 103 | 104 | [tool.ruff.lint] 105 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 106 | ignore = ["F722"] 107 | 108 | [tool.ruff.lint.per-file-ignores] 109 | "__init__.py" = ["E402", "F401"] 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl_py==2.1.0 2 | accelerate==0.32.1 3 | braceexpand=0.1.7 4 | Calvin.egg==info 5 | calvin_env==0.0.1 6 | git+https://github.com/moojink/dlimp_openvla 7 | draccus==0.8.0 8 | einops==0.8.1 9 | ema_pytorch==0.5.1 10 | gym==0.26.2 11 | h5py==3.11.0 12 | huggingface_hub==0.26.1 13 | hydra-core==1.3.2 14 | imageio==2.34.2 15 | jsonlines==4.0.0 16 | lightning==2.4.0 17 | matplotlib==3.10.1 18 | moviepy==1.0.3 19 | numpy==1.26.4 20 | omegaconf==2.3.0 21 | opencv_python==4.10.0.84 22 | openvla==0.0.3 23 | packaging==24.1 24 | peft==0.11.1 25 | Pillow==11.2.1 26 | piq==0.8.0 27 | pyquaternion==0.9.9 28 | pytorch_lightning==1.8.6 29 | PyYAML==6.0.1 30 | PyYAML==6.0.2 31 | Requests==2.32.3 32 | rich==14.0.0 33 | robosuite==1.4.1 34 | rotary_embedding_torch==0.8.4 35 | setuptools==57.5.0 36 | tensorflow==2.15.0 37 | tensorflow_datasets==4.9.3 38 | tensorflow_graphics==2021.12.3 39 | termcolor==3.0.1 40 | timm==0.9.10 41 | tokenizers==0.19.1 42 | torch==2.2.0 43 | torchvision==0.17.0 44 | tqdm==4.66.4 45 | transformers==4.40.1 46 | webdataset==0.2.111 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from os import path as op 3 | import re 4 | 5 | from setuptools import find_packages, setup 6 | 7 | 8 | def _read(f): 9 | return open(op.join(op.dirname(__file__), f)).read() if op.exists(f) else "" 10 | 11 | 12 | _meta = _read("prismatic/__init__.py") 13 | 14 | 15 | def find_meta(_meta, string): 16 | l_match = re.search(r"^" + string + r'\s*=\s*"(.*)"', _meta, re.M) 17 | if l_match: 18 | return l_match.group(1) 19 | raise RuntimeError(f"Unable to find {string} string.") 20 | 21 | 22 | # install_requires = [ 23 | # l for l in _read("requirements.txt").split("\n") if l and not l.startswith("#") and not l.startswith("-") 24 | # ] 25 | 26 | meta = dict( 27 | name=find_meta(_meta, "__project__"), 28 | version=find_meta(_meta, "__version__"), 29 | license=find_meta(_meta, "__license__"), 30 | description="UniVLA", 
31 | platforms=("Any"), 32 | zip_safe=False, 33 | author=find_meta(_meta, "__author__"), 34 | author_email=find_meta(_meta, "__email__"), 35 | url="https://github.com/OpenDriveLab/UniVLA", 36 | packages=find_packages(exclude=["tests"]), 37 | # install_requires=install_requires, 38 | ) 39 | 40 | if __name__ == "__main__": 41 | print("find_package", find_packages(exclude=["tests"])) 42 | setup(**meta) -------------------------------------------------------------------------------- /vla-scripts/train.sh: -------------------------------------------------------------------------------- 1 | export LD_LIBRARY_PATH=/home/pai/envs/openvla/lib/python3.10/site-packages/nvidia/cudnn/lib:$LD_LIBRARY_PATH 2 | GPUS_PER_NODE=8 3 | NNODES=4 4 | MASTER_PORT=${MASTER_PORT:-28596} 5 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 6 | RANK=${RANK:-0} 7 | 8 | 9 | # Run your training script with torchrun 10 | torchrun --nproc_per_node ${GPUS_PER_NODE} --nnodes ${NNODES} --node_rank ${RANK} --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} train.py \ 11 | --vla.type prism-dinosiglip-224px+mx-oxe-magic-soup-plus \ 12 | --run_root_dir "vla_log" \ --------------------------------------------------------------------------------