├── vla_network ├── utils │ └── constant.py ├── type │ ├── __init__.py │ └── data_type.py ├── data_preprocessing │ ├── prompt.py │ ├── vla_data_collator.py │ ├── preprocess.py │ ├── tokenizer.py │ └── token_pattern.py ├── config │ ├── __init__.py │ └── define.py ├── model │ ├── vla │ │ ├── projector.py │ │ ├── flow_matching.py │ │ └── __init__.py │ ├── backbone_2d │ │ ├── __init__.py │ │ └── dinosiglip_vit.py │ └── backbone_llm │ │ ├── internlm │ │ ├── tokenization_internlm2_fast.py │ │ ├── configuration_internlm2.py │ │ └── modeling_internlm2.py │ │ └── __init__.py └── scripts │ ├── offline_test.py │ └── serve.py ├── figs ├── teaser.jpg ├── playground.gif └── real-world.png ├── visualization ├── trial-20250507120350_data.npy └── trial-20250507120350_visualization.png ├── .gitignore ├── requirements.txt ├── setup.py └── README.md /vla_network/utils/constant.py: -------------------------------------------------------------------------------- 1 | IGNORE_INDEX=-100 -------------------------------------------------------------------------------- /figs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/GraspVLA/HEAD/figs/teaser.jpg -------------------------------------------------------------------------------- /figs/playground.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/GraspVLA/HEAD/figs/playground.gif -------------------------------------------------------------------------------- /figs/real-world.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/GraspVLA/HEAD/figs/real-world.png -------------------------------------------------------------------------------- /vla_network/type/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_type import RawVLAData, BatchVLAData # type: ignore 2 | -------------------------------------------------------------------------------- /vla_network/data_preprocessing/prompt.py: -------------------------------------------------------------------------------- 1 | COT_PROMPT = lambda prompt: f"In: What action should the robot take to {prompt}?\nOut: " 2 | -------------------------------------------------------------------------------- /visualization/trial-20250507120350_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/GraspVLA/HEAD/visualization/trial-20250507120350_data.npy -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | **/__pycache__ 4 | **.egg-info 5 | 6 | data 7 | tmp 8 | exps 9 | ckpt 10 | wandb 11 | venv 12 | .idea 13 | vla_data 14 | -------------------------------------------------------------------------------- /visualization/trial-20250507120350_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/GraspVLA/HEAD/visualization/trial-20250507120350_visualization.png -------------------------------------------------------------------------------- /vla_network/config/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .define import Backbone2DConfig, VLADataConfig, 
BasicModelConfig, LLMConfig, VLAModelConfig, BasicConfig, VLAConfig, ActionExpertConfig, FlowMatchingConfig, ImageTransform # type: ignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.7.1 2 | torchvision==0.22.1 3 | transformers==4.53.1 4 | numpy==1.26.4 5 | Pillow>=8.0.0 6 | pydantic>=2.0.0 7 | timm>=0.9.0 8 | tqdm>=4.60.0 9 | pyzmq>=25.0.0 10 | transforms3d>=0.4.0 11 | safetensors>=0.3.0 12 | typing-extensions>=4.0.0 13 | opencv-python 14 | matplotlib 15 | termcolor -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="vla_network", 5 | version="0.0.1", 6 | author="galbot_vla_team", 7 | description="model code for vla", 8 | long_description=open("README.md").read(), 9 | packages=find_packages(), 10 | python_requires=">=3.8", 11 | ) 12 | -------------------------------------------------------------------------------- /vla_network/model/vla/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class FusedMLPProjector(nn.Module): 6 | def __init__(self, fused_vision_dim: int, llm_dim: int) -> None: 7 | super().__init__() 8 | self.initial_projection_dim = fused_vision_dim * 4 9 | self.projector = nn.Sequential( 10 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), 11 | nn.GELU(), 12 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True), 13 | nn.GELU(), 14 | nn.Linear(llm_dim, llm_dim, bias=True), 15 | ) 16 | 17 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: 18 | return self.projector(fused_img_patches) 19 | -------------------------------------------------------------------------------- /vla_network/model/backbone_2d/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | from torch import nn 4 | 5 | from vla_network.config import Backbone2DConfig, ImageTransform 6 | 7 | 8 | class Backbone2D(nn.Module, ABC): 9 | config: Backbone2DConfig 10 | image_transform: ImageTransform 11 | 12 | def __init__(self, config: Backbone2DConfig) -> None: 13 | super().__init__() 14 | self.config = config 15 | 16 | @property 17 | @abstractmethod 18 | def feature_dim(self) -> int: ... 
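    # Each subclass reports the width of its per-patch visual features here
    # (for the DinoSigLIP backbone this is the sum of the DINO and SigLIP embedding dims).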
19 | 20 | @abstractmethod 21 | def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: 22 | raise NotImplementedError 23 | 24 | @staticmethod 25 | def init(config: Backbone2DConfig) -> "Backbone2D": 26 | if config.name == "dinosiglip": 27 | from .dinosiglip_vit import DinoSigLIPViTBackbone 28 | 29 | return DinoSigLIPViTBackbone(config) 30 | else: 31 | raise NotImplementedError 32 | -------------------------------------------------------------------------------- /vla_network/data_preprocessing/vla_data_collator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | from vla_network.type import BatchVLAData 6 | 7 | from vla_network.config import VLADataConfig 8 | 9 | 10 | def vla_collator(config: VLADataConfig, datas: List[BatchVLAData]) -> dict: 11 | kwargs = dict() 12 | 13 | pad_idx = config.tokenizer.pad_token_id 14 | max_len = config.tokenizer.model_max_length 15 | input_ids = pad_sequence( 16 | [data.input_ids[0] for data in datas], 17 | batch_first=True, 18 | padding_value=pad_idx, 19 | ) 20 | robot_input_ids = pad_sequence( 21 | [data.robot_input_ids[0] for data in datas], 22 | batch_first=True, 23 | padding_value=pad_idx, 24 | ) 25 | kwargs["input_ids"] = input_ids[:, :max_len] 26 | kwargs["robot_input_ids"] = robot_input_ids 27 | kwargs["attention_mask"] = kwargs["input_ids"] != pad_idx 28 | kwargs["robot_attention_mask"] = robot_input_ids != pad_idx 29 | 30 | for k in ['images', 'action', 'proprio', 'goal', 'is_action']: 31 | if getattr(datas[0], k, None) is not None: 32 | kwargs[k] = torch.cat([getattr(data, k) for data in datas], dim=0) 33 | else: 34 | kwargs[k] = None 35 | 36 | return kwargs -------------------------------------------------------------------------------- /vla_network/type/data_type.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, List, Any, Dict 3 | from pydantic import BaseModel, ConfigDict, BeforeValidator, PlainSerializer 4 | from typing_extensions import Annotated 5 | import numpy as np 6 | import torch 7 | 8 | def nd_array_custom_before_validator(x): 9 | # custom before validation logic for np.ndarray 10 | return np.array(x) 11 | 12 | 13 | def nd_array_custom_serializer(x): 14 | # custom serialization logic for np.ndarray 15 | return x.tolist() 16 | 17 | NdArray = Annotated[ 18 | np.ndarray, 19 | BeforeValidator(nd_array_custom_before_validator), 20 | PlainSerializer(nd_array_custom_serializer, return_type=list, when_used="json"), 21 | ] # A wrapper of np.ndarray in order to make it usable in pydantic 22 | 23 | class RawVLAData(BaseModel): 24 | model_config = ConfigDict(arbitrary_types_allowed=True) 25 | 26 | # instruction 27 | instruction: Optional[str] = None 28 | can_be_anything: bool = False 29 | 30 | # Observation 31 | images: Optional[Dict[str, NdArray]] = None 32 | bboxs: Optional[Dict[str, NdArray]] = None 33 | pcs: Optional[Dict[str, NdArray]] = None 34 | proprio: NdArray = None 35 | proprio_flag: NdArray = None 36 | for_rel_proprio: NdArray = None 37 | for_rel_proprio_flag: NdArray = None 38 | 39 | # Action 40 | action: Optional[NdArray] = None 41 | action_flag: Optional[NdArray] = None 42 | goal: Optional[NdArray] = None 43 | goal_trans: Optional[NdArray] = None 44 | goal_rot: Optional[NdArray] = None 45 | 46 | 47 | @dataclass 48 | class BatchVLAData: 49 | debug: List[Any] # TODO: Conflict with 
huggingface now 50 | # With the following things: 51 | # dataset_name: List[[str]] 52 | # data_id: List[[str]] 53 | # orig_instruction: List[[str]] 54 | # instruction: List[[str]] 55 | 56 | # tokens 57 | input_ids: torch.Tensor # (B, N_token) 58 | robot_input_ids: torch.Tensor # (B, N_robot_token) 59 | labels: Optional[torch.Tensor] # (B, N_token) 60 | robot_labels: Optional[torch.Tensor] # (B, N_robot_token) 61 | attention_mask: torch.Tensor # (B, N_token) 62 | robot_attention_mask: torch.Tensor # (B, N_robot_token) 63 | 64 | # robot 65 | action: torch.Tensor # (B, T_action, D_action) 66 | proprio: torch.Tensor # (B, T_proprio, D_proprio) 67 | goal: Optional[torch.Tensor] # (B, D_goal) 68 | 69 | # Images 70 | images: torch.Tensor # (B, T_image, N_backbone, C, H, W) 71 | 72 | # type 73 | is_action: torch.Tensor # (B,) 74 | 75 | # inference 76 | inference_kwargs: Optional[list] = None -------------------------------------------------------------------------------- /vla_network/model/backbone_2d/dinosiglip_vit.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | from torch import nn 4 | from dataclasses import dataclass 5 | from typing import List 6 | from timm.models.vision_transformer import VisionTransformer 7 | from PIL import Image 8 | from torchvision.transforms import Compose, Resize 9 | 10 | from . import Backbone2D, Backbone2DConfig 11 | from vla_network.config import ImageTransform 12 | 13 | # Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) 14 | DINOSigLIP_NAMES = { 15 | 224: { 16 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 17 | "siglip": "vit_so400m_patch14_siglip_224", 18 | }, 19 | 384: { 20 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 21 | "siglip": "vit_so400m_patch14_siglip_384", 22 | }, 23 | } 24 | 25 | 26 | @dataclass 27 | class CombineImageTransform: 28 | transforms: List[ImageTransform] 29 | 30 | def __call__(self, img: Image, **kwargs: str) -> torch.Tensor: 31 | return torch.stack([t(img, **kwargs) for t in self.transforms], dim=0) 32 | 33 | 34 | class ViT(nn.Module): 35 | model: VisionTransformer 36 | 37 | def __init__(self, model: VisionTransformer) -> None: 38 | super().__init__() 39 | self.model = model 40 | self.n = len(self.model.blocks) - 2 41 | 42 | @property 43 | def embed_dim(self) -> int: 44 | return self.model.embed_dim 45 | 46 | def forward(self, x: torch.Tensor) -> torch.Tensor: 47 | return self.model.get_intermediate_layers(x, n={self.n})[0] 48 | 49 | 50 | class DinoSigLIPViTBackbone(Backbone2D): 51 | # from parent class 52 | config: Backbone2DConfig 53 | image_transform: CombineImageTransform 54 | 55 | models: List[str] 56 | dino: ViT 57 | siglip: ViT 58 | 59 | def __init__(self, config: Backbone2DConfig) -> None: 60 | super().__init__(config) 61 | self.models = ["dino", "siglip"] 62 | 63 | transforms = [] 64 | for model_type in self.models: 65 | name = DINOSigLIP_NAMES[config.image_size][model_type] 66 | model: ViT = ViT( 67 | timm.create_model( 68 | name, pretrained=True, num_classes=0, img_size=config.image_size 69 | ) 70 | ) 71 | model.eval() 72 | 73 | model_cfg = timm.data.resolve_model_data_config(model.model) 74 | model_cfg["input_size"] = (3, config.image_size, config.image_size) 75 | transform = timm.data.create_transform(**model_cfg, is_training=False) 76 | 77 | # Replace the resize transform with the target size 78 | target_size = (config.image_size, config.image_size) 79 | resize_transform = Compose( 80 | [ 81 | Resize( 82 | target_size, 
interpolation=transform.transforms[0].interpolation 83 | ), 84 | *transform.transforms[1:], 85 | ] 86 | ) 87 | 88 | setattr(self, model_type, model) 89 | transforms.append(resize_transform) 90 | self.image_transform = CombineImageTransform(transforms) 91 | 92 | @property 93 | def feature_dim(self) -> int: 94 | return self.dino.embed_dim + self.siglip.embed_dim 95 | 96 | def forward(self, images: torch.Tensor) -> torch.Tensor: 97 | b, n, _, *chw = images.shape 98 | feats = [] 99 | for i, k in enumerate(self.models): 100 | feat = getattr(self, k)(images[:, :, i].reshape(b * n, *chw)) 101 | feats.append(feat.reshape(b, -1, feat.shape[-1])) 102 | return torch.cat(feats, dim=-1) 103 | 104 | if __name__ == '__main__': 105 | for name in DINOSigLIP_NAMES[224].values(): 106 | model: ViT = ViT( 107 | timm.create_model( 108 | name, pretrained=True, num_classes=0, img_size=224 109 | ) 110 | ) 111 | -------------------------------------------------------------------------------- /vla_network/model/backbone_llm/internlm/tokenization_internlm2_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | """Tokenization Fast class for InternLM.""" 19 | import os 20 | from shutil import copyfile 21 | from typing import Any, Dict, Optional 22 | 23 | from tokenizers import processors 24 | 25 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 26 | from transformers.utils import logging 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"} 31 | 32 | 33 | # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast 34 | class InternLM2TokenizerFast(PreTrainedTokenizerFast): 35 | vocab_files_names = VOCAB_FILES_NAMES 36 | padding_side = "left" 37 | model_input_names = ["input_ids", "attention_mask"] 38 | _auto_class = "AutoTokenizer" 39 | 40 | def __init__( 41 | self, 42 | vocab_file, 43 | unk_token="", 44 | bos_token="", 45 | eos_token="", 46 | pad_token="", 47 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 48 | add_bos_token=True, 49 | add_eos_token=False, 50 | decode_with_prefix_space=False, 51 | clean_up_tokenization_spaces=False, 52 | **kwargs, 53 | ): 54 | super().__init__( 55 | vocab_file=vocab_file, 56 | unk_token=unk_token, 57 | bos_token=bos_token, 58 | eos_token=eos_token, 59 | pad_token=pad_token, 60 | sp_model_kwargs=sp_model_kwargs, 61 | add_bos_token=add_bos_token, 62 | add_eos_token=add_eos_token, 63 | decode_with_prefix_space=decode_with_prefix_space, 64 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 65 | **kwargs, 66 | ) 67 | self._add_bos_token = add_bos_token 68 | self._add_eos_token = add_eos_token 69 | self.update_post_processor() 70 | self.vocab_file = vocab_file 71 | 72 | @property 73 | def can_save_slow_tokenizer(self) -> bool: 74 | return os.path.isfile(self.vocab_file) if self.vocab_file else False 75 | 76 | def update_post_processor(self): 77 | """ 78 | Updates the underlying post processor with the current `bos_token` and `eos_token`. 
79 | """ 80 | bos = self.bos_token 81 | bos_token_id = self.bos_token_id 82 | if bos is None and self.add_bos_token: 83 | raise ValueError("add_bos_token = True but bos_token = None") 84 | 85 | eos = self.eos_token 86 | eos_token_id = self.eos_token_id 87 | if eos is None and self.add_eos_token: 88 | raise ValueError("add_eos_token = True but eos_token = None") 89 | 90 | single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" 91 | pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" 92 | 93 | special_tokens = [] 94 | if self.add_bos_token: 95 | special_tokens.append((bos, bos_token_id)) 96 | if self.add_eos_token: 97 | special_tokens.append((eos, eos_token_id)) 98 | self._tokenizer.post_processor = processors.TemplateProcessing( 99 | single=single, pair=pair, special_tokens=special_tokens 100 | ) 101 | 102 | @property 103 | def add_eos_token(self): 104 | return self._add_eos_token 105 | 106 | @property 107 | def add_bos_token(self): 108 | return self._add_bos_token 109 | 110 | @add_eos_token.setter 111 | def add_eos_token(self, value): 112 | self._add_eos_token = value 113 | self.update_post_processor() 114 | 115 | @add_bos_token.setter 116 | def add_bos_token(self, value): 117 | self._add_bos_token = value 118 | self.update_post_processor() -------------------------------------------------------------------------------- /vla_network/data_preprocessing/preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Optional 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from transformers import PreTrainedTokenizerBase 7 | from transforms3d.euler import mat2euler, euler2mat 8 | 9 | from vla_network.type import BatchVLAData, RawVLAData 10 | from vla_network.config import VLADataConfig, ImageTransform 11 | 12 | from .tokenizer import RobotTokenizer 13 | from .token_pattern import get_token_pattern 14 | 15 | 16 | def resize_with_bbox( 17 | image: Image.Image, 18 | bbox: Optional[np.ndarray], 19 | target_size: Tuple[int, int], 20 | random_padding: bool = True, 21 | ) -> Tuple[Image.Image, Optional[np.ndarray]]: 22 | """ 23 | Resize the image to target size. Pad if necessary. 24 | Also computes the bbox on the resized & padded image. 
25 | """ 26 | original_size = image.size 27 | ratio = min(target_size[0] / original_size[0], target_size[1] / original_size[1]) 28 | new_size = (int(original_size[0] * ratio), int(original_size[1] * ratio)) 29 | image = image.resize(new_size, Image.LANCZOS) 30 | 31 | new_image = Image.new("RGB", target_size) 32 | if random_padding: 33 | paste_x = random.randint(0, target_size[0] - new_size[0]) 34 | paste_y = random.randint(0, target_size[1] - new_size[1]) 35 | else: 36 | paste_x = (target_size[0] - new_size[0]) // 2 37 | paste_y = (target_size[1] - new_size[1]) // 2 38 | new_image.paste(image, (paste_x, paste_y)) 39 | 40 | if bbox is not None: 41 | new_bbox = bbox * ratio 42 | new_bbox[0] += paste_x 43 | new_bbox[1] += paste_y 44 | new_bbox[2] += paste_x 45 | new_bbox[3] += paste_y 46 | new_bbox = np.array([int(t) for t in new_bbox]) 47 | else: 48 | new_bbox = None 49 | 50 | return new_image, new_bbox 51 | 52 | class DataPreprocessor: 53 | config: VLADataConfig 54 | robot_tokenizer: RobotTokenizer 55 | tokenizer: PreTrainedTokenizerBase 56 | image_transform: ImageTransform 57 | 58 | def __init__(self, config: VLADataConfig): 59 | self.config = config 60 | self.tokenizer = config.tokenizer 61 | config.tokenizer = None 62 | self.robot_tokenizer = RobotTokenizer.init(config, self.tokenizer.vocab_size) 63 | config.tokenizer = self.tokenizer 64 | self.image_transform = config.image_transform 65 | if config.pred == 'cot_flow_matching': 66 | self.pattern = get_token_pattern(config, 'cot_action') 67 | 68 | def load(self, data: dict): 69 | self.robot_tokenizer.load(data) 70 | 71 | def transform_img_bbox(self, raw_images: Dict[str, np.ndarray], raw_bboxs: Optional[Dict[str, np.ndarray]]) -> Tuple[torch.Tensor, Optional[np.ndarray]]: 72 | pixel_values: List[Dict[str, torch.Tensor]] = [] 73 | bboxs: List[np.ndarray] = [] 74 | 75 | img_key = self.config.img_key 76 | assert all(len(raw_images[k]) == self.config.img_steps for k in img_key) 77 | for i in range(self.config.img_steps): 78 | for img_k in img_key: 79 | img, bbox = resize_with_bbox( 80 | Image.fromarray(raw_images[img_k][i]), 81 | raw_bboxs[img_k][i] if raw_bboxs is not None else None, 82 | (self.config.image_size, self.config.image_size), 83 | ) 84 | pixel_value = self.image_transform(img) 85 | if bbox is not None: 86 | bbox = bbox / self.config.image_size * 2 - 1 87 | pixel_values.append(pixel_value) 88 | bboxs.append(bbox) 89 | pixel_values = torch.stack(pixel_values)[None] 90 | bboxs = np.stack(bboxs) if bboxs[0] is not None else None 91 | return pixel_values, bboxs 92 | 93 | 94 | def transform(self, raw_data: RawVLAData, inference: bool = False) -> BatchVLAData: 95 | 96 | pixel_values, bboxs = self.transform_img_bbox(raw_data.images, raw_data.bboxs) 97 | 98 | trans_dic = dict(proprio=raw_data.proprio, action=raw_data.action, goal=None) 99 | assert len(trans_dic["proprio"]) == self.config.proprio_len 100 | 101 | text_ids = self.tokenizer(raw_data.instruction, add_special_tokens=True).input_ids 102 | 103 | debug_dict = None 104 | inference_kwargs = [dict( 105 | text_ids=text_ids, 106 | hist_proprio=self.robot_tokenizer.proprio(trans_dic['proprio'][:-1]), 107 | cur_proprio=self.robot_tokenizer.proprio(trans_dic['proprio'][-1]), 108 | )] 109 | token_result = self.pattern.update_tokens( 110 | output=[], 111 | **inference_kwargs[0] 112 | ) 113 | input_ids = token_result.input_ids 114 | robot_input_ids = token_result.robot_input_ids 115 | 116 | return BatchVLAData( 117 | debug=[debug_dict], 118 | input_ids=torch.tensor(input_ids)[None], 119 | 
labels=None, 120 | attention_mask=torch.ones(len(input_ids))[None].bool(), 121 | robot_input_ids=torch.tensor(robot_input_ids)[None], 122 | robot_attention_mask=torch.ones(len(robot_input_ids))[None].bool(), 123 | robot_labels=None, 124 | images=pixel_values, 125 | action=None, 126 | proprio=torch.from_numpy(self.robot_tokenizer.norm_proprio(trans_dic['proprio'])).float()[None], 127 | goal=None, 128 | is_action=torch.ones(1).bool(), 129 | inference_kwargs=inference_kwargs, 130 | ) -------------------------------------------------------------------------------- /vla_network/scripts/offline_test.py: -------------------------------------------------------------------------------- 1 | from urllib import request 2 | import zmq 3 | from PIL import Image 4 | import io 5 | import cv2 6 | from matplotlib import pyplot as plt 7 | import numpy as np 8 | from termcolor import colored 9 | import os 10 | 11 | import argparse 12 | arg_parser = argparse.ArgumentParser() 13 | arg_parser.add_argument("--port", type=str, default="6666") 14 | 15 | def validate_server(host: str = "127.0.0.1", port: int = 6666, timeout: int = 5) -> bool: 16 | """ 17 | Validate that the server is running and returns a valid dict. 18 | 19 | Args: 20 | host: Server hostname 21 | port: Server port 22 | timeout: Timeout in seconds 23 | 24 | Returns: 25 | True if server returns valid dict, False otherwise 26 | """ 27 | context = zmq.Context() 28 | socket = context.socket(zmq.REQ) 29 | socket.setsockopt(zmq.RCVTIMEO, timeout * 1000) 30 | 31 | try: 32 | socket.connect(f"tcp://{host}:{port}") 33 | 34 | # Create test data matching agent.py format 35 | mock_image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) 36 | mock_proprio = [np.random.randn(7) for _ in range(4)] 37 | 38 | test_data = { 39 | 'front_view_image': [mock_image], 40 | 'side_view_image': [mock_image], 41 | 'proprio_array': mock_proprio, 42 | 'text': 'Validation test instruction', 43 | } 44 | 45 | socket.send_pyobj(test_data) 46 | response = socket.recv_pyobj() 47 | 48 | # Check if response is a valid dict 49 | if not isinstance(response, dict): 50 | print(f"✗ Server returned {type(response)}, expected dict") 51 | return False 52 | 53 | print(colored(f"✓ Server at {host}:{port} returned valid dict", 'green')) 54 | return True 55 | 56 | except zmq.Again: 57 | print(colored(f"✗ Server at {host}:{port} timeout after {timeout}s", 'red')) 58 | return False 59 | except Exception as e: 60 | print(colored(f"✗ Error connecting to server at {host}:{port}: {e}", 'red')) 61 | return False 62 | finally: 63 | socket.close() 64 | context.term() 65 | 66 | 67 | def rename_request_keys(request): 68 | if 'image_array' in request: 69 | request['front_view_image'] = request.pop('image_array') 70 | if 'image_wrist_array' in request: 71 | request['side_view_image'] = request.pop('image_wrist_array') 72 | 73 | 74 | def visualize_response(request, response, vis=False): 75 | bbox = response['debug']['bbox'] 76 | 77 | if request['compressed']: 78 | front_image = Image.open(io.BytesIO(request['front_view_image'][0])) 79 | side_image = Image.open(io.BytesIO(request['side_view_image'][0])) 80 | else: 81 | front_image = request['front_view_image'][0] 82 | side_image = request['side_view_image'][0] 83 | front_image = np.array(front_image) 84 | front_bbox = bbox[0] 85 | resized_bbox = (front_bbox / 224 * 256).astype(int) # hack: the order of image and bbox is different 86 | cv2.rectangle(front_image, (resized_bbox[0], resized_bbox[1]), (resized_bbox[2], resized_bbox[3]), (0, 255, 0), 2) 87 | 88 
| side_image = np.array(side_image) 89 | side_bbox = bbox[1] 90 | resized_bbox = (side_bbox / 224 * 256).astype(int) # hack: the order of image and bbox is different 91 | cv2.rectangle(side_image, (resized_bbox[0], resized_bbox[1]), (resized_bbox[2], resized_bbox[3]), (0, 255, 0), 2) 92 | 93 | merged_image = np.concatenate((front_image, side_image), axis=1) 94 | if vis: 95 | plt.imshow(merged_image) 96 | return merged_image 97 | 98 | 99 | def main(): 100 | 101 | args = arg_parser.parse_args() 102 | 103 | if (validate_server(port=args.port) == True): 104 | context = zmq.Context() 105 | socket = context.socket(zmq.REQ) 106 | socket.connect(f"tcp://127.0.0.1:{args.port}") 107 | 108 | trial_caption = 'trial-20250507120350' 109 | data = np.load("visualization/trial-20250507120350_data.npy", allow_pickle=True).item() 110 | print(f"Task: {data['request']['text']}") 111 | 112 | fig = plt.figure(figsize=(6, 2)) 113 | plt.suptitle(f"Task: {data['request']['text']}", fontsize=16) 114 | 115 | request = data['request'] 116 | response = data['response'] 117 | 118 | rename_request_keys(request) 119 | 120 | # show the result of the original model 121 | left_image = visualize_response(request, response) 122 | 123 | try: 124 | socket.send_pyobj(request) 125 | new_response = socket.recv_pyobj() 126 | except Exception as e: 127 | print(f"Socket communication failed: {e}") 128 | 129 | # show the result of the current model 130 | right_image = visualize_response(request, new_response) 131 | 132 | # our model result 133 | plt.subplot(1, 2, 1) 134 | plt.imshow(left_image) 135 | plt.title(f"Our Result") 136 | plt.axis('off') 137 | 138 | # current model result 139 | plt.subplot(1, 2, 2) 140 | plt.imshow(right_image) 141 | plt.title(f"Your Result") 142 | plt.axis('off') 143 | 144 | plt.tight_layout(rect=[0, 0, 1, 0.97]) 145 | os.makedirs("visualization", exist_ok=True) 146 | plt.savefig(f"visualization/{trial_caption}_visualization.png", dpi=200) 147 | print(f"Saved figure as \"visualization/{trial_caption}_visualization.png\".") 148 | socket.close() 149 | context.term() 150 | 151 | else: 152 | print(f"Please make sure the model server is running at tcp://127.0.0.1:{args.port}.") 153 | return 154 | 155 | if __name__ == "__main__": 156 | main() -------------------------------------------------------------------------------- /vla_network/config/define.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Type, Union, Dict 2 | from pydantic import BaseModel, Field, ConfigDict 3 | from PIL import Image 4 | import torch 5 | import importlib 6 | from transformers import ( 7 | PreTrainedModel, 8 | PreTrainedTokenizerBase, 9 | PreTrainedTokenizerFast, 10 | ) 11 | 12 | class ImageTransform: 13 | def __call__( 14 | self, img: Image, **kwargs: str 15 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... 16 | 17 | # TODO: which keys should be in the basic config, shared in VA and VLA? 
18 | # Don't add keys without asking other people 19 | class BasicDataConfig(BaseModel): 20 | exp_name: Optional[str] = Field(default=None) 21 | robot: str 22 | proprio_len: int 23 | action_len: int 24 | action_dim: int = Field(default=None) 25 | goal_dim: Optional[int] = Field(default=None) 26 | action_rel_len: int 27 | dt_steps: int 28 | 29 | def setup(self): 30 | pass 31 | 32 | 33 | class VLADataConfig(BasicDataConfig): 34 | # TODO: sort them in a better way 35 | model_config = ConfigDict(arbitrary_types_allowed=True) 36 | 37 | tokenizer: Optional[PreTrainedTokenizerBase] = Field(init=False, default=None) 38 | image_transform: Optional[ImageTransform] = Field(init=False, default=None) 39 | action_token_num: int 40 | img_steps: int 41 | img_key: Optional[List[str]] 42 | image_size: Optional[int] 43 | anything_prob: float 44 | robot_rep: str 45 | goal_rep: Optional[str] 46 | tokenizer_type: str 47 | tokenizer_ratio_limit: float 48 | count_num: int 49 | trans_noise: float 50 | rot_noise: float 51 | brightness_img: str 52 | brightness_threshold: float 53 | crop_mode: Dict[str, str] 54 | proprio_dim: Optional[int] = Field(default=None) 55 | use_bbox: int 56 | pred: Optional[str] = Field(init=False, default=None) 57 | 58 | def setup(self): 59 | super().setup() 60 | 61 | if self.action_dim is None: 62 | if self.robot_rep in ['xyz_rpy', 'xyz_rpy_rot']: 63 | self.action_dim = 7 # xyz rpy gripper 64 | if self.goal_dim is None: 65 | if self.goal_rep == 'xyz_rpy': 66 | self.goal_dim = 6 # xyz rpy 67 | elif self.goal_rep == 'xyz_rot': 68 | self.goal_dim = 12 # xyz rotmat 69 | if self.proprio_dim is None: 70 | if self.robot_rep == 'xyz_rpy': 71 | self.proprio_dim = 7 # xyz rpy gripper 72 | elif self.robot_rep == 'xyz_rpy_rot': 73 | self.proprio_dim = 13 # xyz rotmat gripper 74 | 75 | 76 | @property 77 | def img_num(self) -> int: 78 | return len(self.img_key) * self.img_steps 79 | 80 | 81 | LLM_CONFIG = { 82 | "meta-llama/Llama-2-7b-hf": { 83 | "family": "llama2", 84 | "model_cls": ("transformers", "LlamaForCausalLM"), 85 | "token_cls": ("transformers", "AutoTokenizer"), 86 | }, 87 | "internlm/internlm2-1_8b": { 88 | "family": "internlm", 89 | "model_cls": ( 90 | "vla_network.model.backbone_llm.internlm.modeling_internlm2", 91 | "InternLM2ForCausalLM", 92 | ), 93 | "token_cls": ( 94 | "vla_network.model.backbone_llm.internlm.tokenization_internlm2_fast", 95 | "InternLM2TokenizerFast", 96 | ), 97 | }, 98 | } 99 | 100 | 101 | # TODO: which keys should be in the basic config, shared in VA and VLA? 
102 | # Don't add keys without asking other people 103 | class BasicModelConfig(BaseModel): 104 | pass 105 | 106 | class LLMConfig(BaseModel): 107 | name: str 108 | max_len: int = Field(default=2048) 109 | special_tokens: List[str] = Field(default_factory=lambda: []) 110 | pad_multiple_of: int = Field(default=64) 111 | attn_implementation: str 112 | 113 | @property 114 | def family(self) -> str: 115 | return LLM_CONFIG[self.name]["family"] 116 | 117 | @staticmethod 118 | def get_cls(package: str, name: str): 119 | module = importlib.import_module(package) 120 | return getattr(module, name) 121 | 122 | @property 123 | def model_cls(self) -> Type[PreTrainedModel]: 124 | cls_package, cls_name = LLM_CONFIG[self.name]["model_cls"] 125 | return self.get_cls(cls_package, cls_name) 126 | 127 | @property 128 | def token_cls(self) -> Type[PreTrainedTokenizerFast]: 129 | cls_package, cls_name = LLM_CONFIG[self.name]["token_cls"] 130 | return self.get_cls(cls_package, cls_name) 131 | 132 | 133 | class Backbone2DConfig(BaseModel): 134 | name: str 135 | image_size: int 136 | 137 | class ActionExpertConfig(BaseModel): 138 | hidden_size_scale: Optional[int] = Field(default=None) 139 | intermediate_size_scale: Optional[int] = Field(default=None) 140 | hidden_size: Optional[int] = Field(init=False, default=None) 141 | intermediate_size: Optional[int] = Field(init=False, default=None) 142 | hidden_act: Optional[str] = Field(init=False, default=None) 143 | 144 | class FlowMatchingConfig(BaseModel): 145 | beta_alpha: float 146 | beta_beta: float 147 | time_min: float 148 | time_max: float 149 | 150 | class VLAModelConfig(BasicModelConfig): 151 | backbone_2d: Backbone2DConfig 152 | llm: LLMConfig 153 | ckpt: str 154 | pred: str # flow_matching or token_pred 155 | action_len: int = Field(init=False, default=None) 156 | action_dim: int = Field(init=False, default=None) 157 | proprio_dim: int = Field(init=False, default=None) 158 | action_expert: int 159 | action_expert_cfg: Optional[ActionExpertConfig] = None 160 | flow_matching_cfg: Optional[FlowMatchingConfig] = None 161 | 162 | def to_dict(self): 163 | return self.model_dump() 164 | 165 | 166 | class BasicConfig(BaseModel): 167 | data: BasicDataConfig 168 | model: BasicModelConfig 169 | dummy: str = None # For unneeded arguments 170 | 171 | 172 | class VLAConfig(BasicConfig): 173 | data: VLADataConfig 174 | model: VLAModelConfig 175 | dummy: str = None # For unneeded arguments -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraspVLA: a Grasping Foundation Model Pre-trained on Billion-scale Synthetic Action Data 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2505.03233-df2a2a.svg)](https://arxiv.org/pdf/2505.03233) 4 | [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://pku-epic.github.io/GraspVLA-web/) 5 | 6 | [**Model Server**](#model-server) | [**Simulation Playground**](#simulation-playground) | [**Real World Control Interface**](#real-world-control-interface) 7 | 8 | We present a cost-effective pretraining paradigm for VLA models using only synthetic data, achieving direct sim-to-real transfer and strong zero-shot generalizability for robotic grasping. Key contributions include: 9 | 10 | - **SynGrasp-1B**: a billion-frame synthetic grasping dataset, spanning 240 object categories and 10,000+ objects. 
11 | 12 | - **GraspVLA**: a VLA model pretrained on SynGrasp-1B that achieves zero-shot generalization to real-world grasping without fine-tuning. 13 | 14 | - **Unified CoT Framework**: GraspVLA integrates autoregressive perception and flow-matching-based action generation into a single reasoning process, enabling joint training on synthetic action data and internet-scale semantic data for open-vocabulary grasping. 15 | 16 | ![teaser](./figs/teaser.jpg) 17 | 18 | ## Latest Updates 19 | - [2025-07-25] Release the GraspVLA model, [simulation playground](https://github.com/MiYanDoris/GraspVLA-playground) and [real world control interface](https://github.com/MiYanDoris/GraspVLA-real-world-controller). 20 | - [2025-07-19] Release the [supplementary material](https://arxiv.org/pdf/2505.03233). 21 | 22 | ## Model Server 23 | 24 | Please follow the steps below to start the model server. We provide the checkpoint of GraspVLA on both [huggingface](https://huggingface.co/shengliangd/GraspVLA) and [Baidu cloud](https://pan.baidu.com/s/1DOJbKrKzdBcEIrFQ_NcuLw?pwd=6666). GraspVLA achieves 200ms inference latency using ~9GB of GPU memory when running on a single NVIDIA RTX L40s GPU. 25 | 26 | ### Step 1: Clone the Repository 27 | ```bash 28 | git clone https://github.com/PKU-EPIC/GraspVLA 29 | cd GraspVLA 30 | ``` 31 | 32 | ### Step 2: Set Up Python Environment 33 | Create and activate a conda environment with the required dependencies, for example: 34 | ```bash 35 | conda create -n GraspVLA python=3.9.19 36 | conda activate GraspVLA 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ### Step 3: Download Model Weights 41 | 42 | If you want to download from huggingface: 43 | ``` 44 | pip install -U "huggingface_hub" 45 | # set HF_ENDPOINT if you encounter connection issues: 46 | # export HF_ENDPOINT=https://hf-mirror.com 47 | hf download shengliangd/GraspVLA 48 | ``` 49 | The model weight will be placed at `~/.cache/huggingface/hub/models--shengliangd--GraspVLA/snapshots/f291eac1d3494c5c13c3d420af4e5bc987f23c3e/checkpoint/model.safetensors`. 50 | 51 | ### Step 4: Launch the Model Server 52 | Run the model server with your desired configuration, for example: 53 | ```bash 54 | python3 -u -m vla_network.scripts.serve --path you-path-to-model.safetensors --port 6666 55 | ``` 56 | 57 | #### Required arguments: 58 | 59 | - --path — Path to the model.safetensors file. 60 | 61 | - --port — Port number on which the server will listen for incoming requests. 62 | 63 | #### Optional arguments: 64 | 65 | - --compile: Enable model compilation (default: False). Speeds up inference (500ms → 200ms) but adds ~3 minutes to startup time. Recommended for large-scale evaluations (e.g., LIBERO benchmark). 66 | 67 | Success: The message `Started server on port ` indicates the server is ready. 68 | 69 | ### Offline Test And Visualization 70 | 71 | To quickly test GraspVLA without setting up a simulation or real-world environment, use the `offline_test` script. It runs offline inference on pre-recorded requests and compares results. The script saves a comparison image in `visualization`, with our reference output (left) and your model’s output (from the specified port) on the right. 72 | 73 | ```bash 74 | python3 -u -m vla_network.scripts.offline_test --port 75 | ``` 76 | 77 | ### Repository Structure 78 | 79 | High-level overview of `vla_network` file-tree: 80 | 81 | + `config/` - Contains basic configurations for our model. 82 | + `data_preprocessing/` - Includes tools for preprocessing raw data into model-ready formats. 
83 | + `model/` - Code for defining and loading the main model structure. 84 | + `scripts/` - Contains the `serve.py` file that starts a model server and the `offline_test.py` file that performs offline visualization. 85 | + `type/` - Data type definitions used in our model. 86 | + `utils/` - Contains some constants used in our model. 87 | 88 | ## Simulation Playground 89 | We provide a simulation playground for GraspVLA here: [GraspVLA-playground](https://github.com/MiYanDoris/GraspVLA-playground). This repository includes both the evaluation code used for GraspVLA in the [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) benchmark and an enhanced playground environment built on top of it. The playground provides an easy-to-use interface to evaluate GraspVLA across diverse objects, layouts, and environments. 90 | ![playground](figs/playground.gif) 91 | 92 | ## Real World Control Interface 93 | We provide a [real-world control interface](https://github.com/MiYanDoris/GraspVLA-real-world-controller) for deploying GraspVLA in physical environments. This interface enables: 94 | 95 | - Zero-shot evaluation on real-world objects. 96 | 97 | - Both blocking and non-blocking control modes. 98 | 99 | - Real-time visualization of intermediate COT results, including 2D bounding boxes and 3D grasp poses. 100 | 101 | camera_setup 102 | 103 | ## Citation 104 | 105 | If you find this work useful, please cite: 106 | 107 | ```bibtex 108 | @article{deng2025graspvla, 109 | title={GraspVLA: a Grasping Foundation Model Pre-trained on Billion-scale Synthetic Action Data}, 110 | author={Shengliang Deng and Mi Yan and Songlin Wei and Haixin Ma and Yuxin Yang and Jiayi Chen and Zhiqi Zhang and Taoyu Yang and Xuheng Zhang and Wenhao Zhang and Heming Cui and Zhizheng Zhang and He Wang}, 111 | year={2025}, 112 | eprint={2505.03233}, 113 | archivePrefix={arXiv}, 114 | primaryClass={cs.RO}, 115 | url={https://arxiv.org/abs/2505.03233} 116 | } 117 | ``` 118 | 119 | [![License](https://licensebuttons.net/l/by-nc/4.0/88x31.png)](LICENSE) 120 | -------------------------------------------------------------------------------- /vla_network/model/vla/flow_matching.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union, Callable 2 | import torch 3 | from torch import nn 4 | 5 | from vla_network.config import FlowMatchingConfig 6 | 7 | def posemb_sincos(pos: torch.Tensor, embedding_dim: int, min_period: float = 4e-3, max_period: float = 4.0) -> torch.Tensor: 8 | if embedding_dim % 2 != 0: 9 | raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2") 10 | fraction = torch.linspace(0.0, 1.0, embedding_dim // 2, device=pos.device, dtype=pos.dtype) 11 | period = min_period * (max_period / min_period) ** fraction 12 | sinusoid_input = torch.einsum("i,j->ij", pos, 1.0 / period * 2 * torch.pi) 13 | return torch.cat([torch.sin(sinusoid_input), torch.cos(sinusoid_input)], dim=-1) 14 | 15 | class BaseFlowMatchingModule(nn.Module): 16 | 17 | def __init__(self, config: FlowMatchingConfig, embed_dim: int): 18 | super().__init__() 19 | self.config = config 20 | self.embed_dim = embed_dim 21 | 22 | def get_time_embedding(self, t: torch.Tensor) -> torch.Tensor: 23 | t_embed = posemb_sincos(t, self.embed_dim) 24 | return t_embed 25 | 26 | def sample_time(self, batch_shape: tuple, device: str, dtype: str) -> torch.FloatTensor: 27 | time = torch.distributions.Beta( 28 | torch.tensor(self.config.beta_alpha, device=device, dtype=torch.float32), 29 
| torch.tensor(self.config.beta_beta, device=device, dtype=torch.float32) 30 | ).sample(batch_shape).to(dtype) 31 | time = time * (self.config.time_max - self.config.time_min) + self.config.time_min 32 | return time 33 | 34 | def sample_noise(self, shape: tuple, device: str, dtype: str) -> torch.FloatTensor: 35 | return torch.randn(shape, device=device, dtype=dtype) 36 | 37 | def diffuse(self, x_1: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 38 | if noise is None: 39 | noise = torch.randn_like(x_1) 40 | x_t = t * noise + (1 - t) * x_1 41 | u_t = noise - x_1 42 | return x_t, u_t 43 | 44 | 45 | def update(self, x_t: torch.FloatTensor, v_t: torch.FloatTensor, dt: Union[float, torch.FloatTensor], timestep: Union[float, torch.FloatTensor]) -> torch.FloatTensor: 46 | return x_t + dt * v_t 47 | 48 | def denoise(self, compute_v_t: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor], x_t: torch.FloatTensor, iter_num: int) -> torch.FloatTensor: 49 | device, dtype = x_t.device, x_t.dtype 50 | time_vec = torch.ones((len(x_t),), device=device, dtype=dtype) 51 | dt = 1.0 / iter_num 52 | time_steps = torch.linspace(1.0, dt, iter_num, device=device, dtype=dtype) 53 | for t in time_steps: 54 | time_vec[:] = t 55 | v_t = compute_v_t(x_t, time_vec) 56 | x_t = self.update(x_t, v_t, -dt, time_vec) 57 | return x_t 58 | 59 | class VLAFlowMatchingModule(BaseFlowMatchingModule): 60 | 61 | def __init__(self, config: FlowMatchingConfig, action_dim: int, llm_dim: int, action_len: int, proprio_dim: int): 62 | super().__init__(config=config, embed_dim=llm_dim) 63 | self.action_len = action_len 64 | self.action_dim = action_dim 65 | 66 | self.proprior_proj = nn.Linear(proprio_dim, llm_dim) 67 | self.action_in_proj = nn.Linear(action_dim, llm_dim) 68 | self.action_time_mlp = nn.Sequential( 69 | nn.Linear(llm_dim * 2, llm_dim), 70 | nn.SiLU(), 71 | nn.Linear(llm_dim, llm_dim) 72 | ) 73 | self.action_out_proj = nn.Linear(llm_dim, action_dim) 74 | 75 | def sample_noise_and_time(self, action: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 76 | batch_shape = action.shape[:-2] 77 | device, dtype = action.device, action.dtype 78 | 79 | time = self.sample_time(batch_shape, device=device, dtype=dtype) 80 | time_expanded = time[..., None, None] 81 | noise = super().sample_noise(action.shape, device=device, dtype=dtype) 82 | x_t, u_t = self.diffuse(action, time_expanded, noise) 83 | 84 | return x_t, u_t, time 85 | 86 | def sample_noise(self, batch_size: int, device: str, dtype: str) -> torch.FloatTensor: 87 | return super().sample_noise((batch_size, self.action_len, self.action_dim), device=device, dtype=dtype) 88 | 89 | def embed_suffix_flow_matching( 90 | self, 91 | proprio: torch.FloatTensor, 92 | noisy_actions: torch.FloatTensor, 93 | timestep: torch.FloatTensor 94 | ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.BoolTensor]: 95 | #device = proprio.device 96 | dtype = self.proprior_proj.weight.dtype 97 | #batch_size = proprio.shape[0] 98 | 99 | proprio_embed = self.proprior_proj(proprio.to(dtype)) 100 | embeds = self.embed_suffix_flow_matching_embeds(proprio_embed, noisy_actions, timestep) 101 | input_mask, block_mask = self.get_suffix_masks(proprio_embed) 102 | 103 | return embeds, input_mask, block_mask 104 | 105 | def get_suffix_masks(self, proprio_embed): 106 | batch_size = proprio_embed.shape[0] 107 | device = proprio_embed.device 108 | total_len = proprio_embed.shape[1] + self.action_len 109 | 
input_mask = torch.ones((batch_size, total_len), dtype=torch.bool, device=device) 110 | block_mask = torch.zeros(total_len, dtype=torch.bool, device=device) 111 | block_mask[0], block_mask[proprio_embed.shape[1]] = True, True 112 | return input_mask, block_mask 113 | 114 | def embed_suffix_flow_matching_embeds( 115 | self, 116 | proprio_embed: torch.FloatTensor, 117 | noisy_actions: torch.FloatTensor, 118 | timestep: torch.FloatTensor 119 | ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.BoolTensor]: 120 | action_embeds = self.action_in_proj(noisy_actions) 121 | time_embeds = self.get_time_embedding(timestep) 122 | time_embeds = time_embeds[:, None, :].expand(-1, self.action_len, -1) 123 | action_time_embeds = self.action_time_mlp( 124 | torch.cat([action_embeds, time_embeds], dim=-1) 125 | ) 126 | embeds = torch.cat([proprio_embed, action_time_embeds], dim=1) 127 | return embeds 128 | 129 | def get_v_t(self, hidden_states: torch.FloatTensor): 130 | return self.action_out_proj(hidden_states.to(self.action_out_proj.weight.dtype)) 131 | 132 | 133 | -------------------------------------------------------------------------------- /vla_network/scripts/serve.py: -------------------------------------------------------------------------------- 1 | # FIXME: import urchin before others, otherwise segfault, unknown reason 2 | 3 | import os 4 | if 'DEBUG_PORT' in os.environ: 5 | import debugpy 6 | debugpy.listen(int(os.environ['DEBUG_PORT'])) 7 | print(f'waiting for debugger to attach...') 8 | debugpy.wait_for_client() 9 | 10 | import argparse 11 | arg_parser = argparse.ArgumentParser() 12 | arg_parser.add_argument("--port", type=str, required=True) 13 | arg_parser.add_argument("--path", type=str, required=True) 14 | arg_parser.add_argument("--compile", action="store_true") 15 | 16 | 17 | import PIL 18 | import io 19 | import os 20 | from typing import List 21 | import zmq 22 | import pickle 23 | import time 24 | import numpy as np 25 | from tqdm import tqdm 26 | from vla_network.model.vla import VLAAgent 27 | from vla_network.data_preprocessing.prompt import COT_PROMPT 28 | import torch 29 | torch.autograd.set_grad_enabled(False) 30 | 31 | 32 | def interpolate_delta_actions(delta_actions, n): 33 | """ 34 | Interpolate m delta_actions to m*n delta_actions. 35 | 36 | actions: list of actions, each action is (delta x, delta y, delta z, delta roll, delta pitch, delta yaw, gripper open/close). 
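    Translation is split evenly across the n sub-steps, the rotation is converted to
    axis-angle with the angle divided by n, and the gripper command is applied only on
    the last sub-step (0 means "keep the current gripper state").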
37 | """ 38 | import transforms3d as t3d 39 | ret = [] 40 | for delta_action in delta_actions: 41 | xyzs = 1 / n * np.array([delta_action[:3]]*n) 42 | axangle_ax, axangle_angle = t3d.euler.euler2axangle(*delta_action[3:6]) 43 | eulers = [t3d.euler.axangle2euler(axangle_ax, axangle_angle / n)]*n 44 | grippers = np.array([[0.]] * (n-1) + [[delta_action[-1]]]) # 0 for no change of gripper state 45 | ret.extend(np.concatenate([xyzs, eulers, grippers], axis=-1)) 46 | return ret 47 | 48 | 49 | def infer_single_sample(vla_model: VLAAgent, sample: dict): 50 | input_data = [] 51 | if sample.get('compressed', False): 52 | for key in ['front_view_image', 'side_view_image']: 53 | decompressed_image_array = [] 54 | for compressed_image in sample[key]: 55 | decompressed_image_array.append(np.array(PIL.Image.open(io.BytesIO(compressed_image)))) 56 | sample[key] = decompressed_image_array 57 | sample['compressed'] = False 58 | proprio_array = np.array([sample['proprio_array'][-4], sample['proprio_array'][-1]]) 59 | proprio_array[:, -1] = (proprio_array[:, -1] + 1) / 2 60 | input_data.append({ 61 | 'env_id': 0, 62 | 'text': COT_PROMPT(sample['text']), 63 | 'proprio_array': proprio_array, 64 | 'front_view_image': [sample['front_view_image'][-1]], 65 | 'side_view_image': [sample['side_view_image'][-1]], 66 | }) 67 | results = vla_model(input_data) # the model recieve a list of input samples 68 | result = results[0] # only one sample 69 | action = result['action'] 70 | # Quantize last dimension of action to 0, 1 using 0.4, 0.6 as bin boundaries 71 | last_dim = action[:, -1] 72 | last_dim = np.where(last_dim < 0.4, -1, np.where(last_dim > 0.6, 1, 0)) 73 | action = np.concatenate([action[:, :-1], last_dim[:, None]], axis=-1) 74 | action = interpolate_delta_actions(action, 2) 75 | debug = {} 76 | if 'goal' in result: 77 | debug['pose'] = result['goal'] 78 | if 'bbox' in result: 79 | debug['bbox'] = result['bbox'] 80 | return { 81 | 'result': action, 82 | 'env_id': 0, 83 | 'debug': debug, 84 | } 85 | 86 | 87 | def warmup(vla_model: VLAAgent): 88 | SAMPLES = [ 89 | { 90 | 'text': 'pick up elephant', 91 | 'front_view_image': [np.zeros((256, 256, 3), dtype=np.uint8)], 92 | 'side_view_image': [np.zeros((256, 256, 3), dtype=np.uint8)], 93 | 'proprio_array': [np.concatenate([np.zeros((6,), dtype=np.float32), np.ones((1,), dtype=np.float32)])]*4, 94 | }, 95 | { 96 | 'text': 'pick up toy large elephant', 97 | 'front_view_image': [np.zeros((256, 256, 3), dtype=np.uint8)], 98 | 'side_view_image': [np.zeros((256, 256, 3), dtype=np.uint8)], 99 | 'proprio_array': [np.concatenate([np.zeros((6,), dtype=np.float32), np.ones((1,), dtype=np.float32)])]*4, 100 | }, 101 | { 102 | 'text': 'pick up toy car', 103 | 'front_view_image': [np.zeros((256, 256, 3), dtype=np.uint8)], 104 | 'side_view_image': [np.zeros((256, 256, 3), dtype=np.uint8)], 105 | 'proprio_array': [np.concatenate([np.zeros((6,), dtype=np.float32), np.ones((1,), dtype=np.float32)])]*4, 106 | }, 107 | ] 108 | NUM_TESTS = 5 109 | print('warming up...') 110 | for i in tqdm(range(NUM_TESTS)): 111 | ret = infer_single_sample(vla_model, SAMPLES[i%len(SAMPLES)]) 112 | print('check the latency after warm up:') 113 | for i in tqdm(range(NUM_TESTS)): 114 | ret = infer_single_sample(vla_model, SAMPLES[i%len(SAMPLES)]) 115 | 116 | 117 | def main(): 118 | args = arg_parser.parse_args() 119 | vla_model = VLAAgent(args.path, compile=args.compile) 120 | 121 | vla_model.preprocessor.config.robot_rep = "identity" 122 | 123 | assert vla_model.data_cfg.action_rel_len == 0 124 | 125 | 
warmup(vla_model) 126 | 127 | context = zmq.Context() 128 | socket = context.socket(zmq.ROUTER) 129 | socket.bind(f"tcp://*:{args.port}") 130 | 131 | print(f"Started server on port {args.port}") 132 | 133 | requests = [] 134 | 135 | while True: 136 | # run inference if data is ready 137 | if (len(requests) > 0): 138 | client_id, data_received = requests[0] 139 | 140 | tbegin = time.time() 141 | print(f'start processing a request') 142 | result = infer_single_sample(vla_model, data_received) 143 | tend = time.time() 144 | print(f'finished a request in {tend-tbegin:.3f}s') 145 | 146 | socket.send_multipart([ 147 | client_id, 148 | b'', 149 | pickle.dumps({ 150 | 'info': 'success', 151 | 'env_id': result['env_id'], 152 | 'result': result['result'], 153 | 'debug': result['debug'], 154 | }) 155 | ]) 156 | 157 | requests = requests[1:] 158 | 159 | # try getting new sample 160 | try: 161 | client_id, empty, data = socket.recv_multipart(zmq.DONTWAIT) 162 | data = pickle.loads(data) 163 | requests.append((client_id, data)) 164 | except zmq.Again: 165 | pass 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /vla_network/data_preprocessing/tokenizer.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Callable, Dict, List 3 | from tqdm import trange 4 | import numpy as np 5 | 6 | from vla_network.config import VLADataConfig 7 | 8 | robot_tokenizer = None 9 | 10 | class RobotTokenizer: 11 | config: VLADataConfig 12 | 13 | def __init__(self, config: VLADataConfig, vocab_size: int): 14 | self.config = config 15 | self.vocab_size = vocab_size 16 | 17 | @staticmethod 18 | def init(config: VLADataConfig, vocab_size: int): 19 | global robot_tokenizer 20 | if robot_tokenizer is None: 21 | config = deepcopy(config) 22 | if config.tokenizer_type == "uniform": 23 | robot_tokenizer = UniformRobotTokenizer(config, vocab_size) 24 | elif config.tokenizer_type == "ratio_min_max_uniform": 25 | robot_tokenizer = RatioMinMaxUniformRobotTokenizer(config, vocab_size) 26 | return robot_tokenizer 27 | 28 | def bbox(self, bbox: np.ndarray) -> np.ndarray: 29 | raise NotImplementedError 30 | 31 | def proprio(self, proprio: np.ndarray) -> np.ndarray: 32 | raise NotImplementedError 33 | 34 | def action(self, action: np.ndarray) -> np.ndarray: 35 | raise NotImplementedError 36 | 37 | def inv_action(self, action: np.ndarray) -> np.ndarray: 38 | raise NotImplementedError 39 | 40 | def goal(self, goal: np.ndarray) -> np.ndarray: 41 | raise NotImplementedError 42 | 43 | def inv_goal(self, goal: np.ndarray) -> np.ndarray: 44 | raise NotImplementedError 45 | 46 | def save(self) -> dict: 47 | return {} 48 | 49 | 50 | class UniformRobotTokenizer(RobotTokenizer): 51 | config: VLADataConfig 52 | bins: np.ndarray 53 | 54 | def __init__(self, config: VLADataConfig, vocab_size: int): 55 | super().__init__(config, vocab_size) 56 | self.bins = np.linspace(-1.0, 1.0, config.action_token_num) 57 | 58 | def uniform_tokenize(self, x: np.ndarray) -> np.ndarray: 59 | x = x.flatten() 60 | discretized_action = np.clip(np.digitize(x, self.bins), a_min=1, a_max=self.config.action_token_num) 61 | return self.vocab_size - discretized_action 62 | 63 | def uniform_detokenize(self, x: np.ndarray) -> np.ndarray: 64 | y = self.vocab_size - x 65 | return ( 66 | self.bins[np.clip(y - 1, a_min=0, a_max=self.config.action_token_num - 1)] 67 | + self.bins[np.clip(y, a_min=0, 
a_max=self.config.action_token_num - 1)] 68 | ) / 2 69 | 70 | def bbox(self, bbox: np.ndarray) -> np.ndarray: 71 | return self.uniform_tokenize(bbox) 72 | 73 | def proprio(self, proprio: np.ndarray) -> np.ndarray: 74 | return self.uniform_tokenize(proprio) 75 | 76 | def action(self, action: np.ndarray) -> np.ndarray: 77 | return self.uniform_tokenize(action) 78 | 79 | def goal(self, goal: np.ndarray) -> np.ndarray: 80 | return self.uniform_tokenize(goal) 81 | 82 | def inv_action(self, action: np.ndarray) -> np.ndarray: 83 | return self.uniform_detokenize(action) 84 | 85 | def inv_goal(self, goal: np.ndarray) -> np.ndarray: 86 | return self.uniform_detokenize(goal) 87 | 88 | 89 | class RatioMinMaxUniformRobotTokenizer(RobotTokenizer): 90 | config: VLADataConfig 91 | 92 | def __init__(self, config: VLADataConfig, vocab_size: int): 93 | super().__init__(config, vocab_size) 94 | self.uniform_tokenizer = UniformRobotTokenizer(config, vocab_size) 95 | 96 | def bbox(self, bbox: np.ndarray) -> np.ndarray: 97 | return self.uniform_tokenizer.bbox(bbox) 98 | 99 | def proprio(self, proprio: np.ndarray) -> np.ndarray: 100 | proprio = self.norm_proprio(proprio) 101 | return self.uniform_tokenizer.proprio(proprio) 102 | 103 | def action(self, action: np.ndarray) -> np.ndarray: 104 | action = self.norm_action(action) 105 | return self.uniform_tokenizer.action(action) 106 | 107 | def goal(self, goal: np.ndarray) -> np.ndarray: 108 | goal = self.norm_goal(goal) 109 | return self.uniform_tokenizer.goal(goal) 110 | 111 | def inv_action(self, action: np.ndarray) -> np.ndarray: 112 | action = self.uniform_tokenizer.inv_action(action) 113 | return self.inv_norm_action(action) 114 | 115 | def inv_goal(self, goal: np.ndarray) -> np.ndarray: 116 | goal = self.uniform_tokenizer.inv_goal(goal) 117 | return self.inv_norm_goal(goal) 118 | 119 | def norm(self, x: np.ndarray, min_v: np.ndarray, max_v: np.ndarray): 120 | return (x - min_v) / (max_v - min_v) * 2 - 1 121 | 122 | def inv_norm(self, x: np.ndarray, min_v: np.ndarray, max_v: np.ndarray): 123 | return (x + 1) / 2 * (max_v - min_v) + min_v 124 | 125 | def norm_proprio(self, proprio: np.ndarray): 126 | return self.norm(proprio, self.min_proprio, self.max_proprio) 127 | 128 | def norm_action(self, action: np.ndarray): 129 | return self.norm(action, self.min_action, self.max_action) 130 | 131 | def norm_goal(self, goal: np.ndarray): 132 | return self.norm(goal, self.min_proprio[:-1], self.max_proprio[:-1]) 133 | 134 | def inv_norm_action(self, action: np.ndarray): 135 | return self.inv_norm(action, self.min_action, self.max_action) 136 | 137 | def inv_norm_goal(self, goal: np.ndarray): 138 | return self.inv_norm(goal, self.min_proprio[:-1], self.max_proprio[:-1]) 139 | 140 | def setup(self, get_func: Callable[[], Dict[str, np.ndarray]]): 141 | keys = list(get_func().keys()) 142 | results = [[] for _ in keys] 143 | for _ in trange(self.config.count_num, desc="setup proprio action"): 144 | dic = get_func() 145 | for i in range(len(keys)): 146 | results[i].append(dic[keys[i]]) 147 | for i in range(len(keys)): 148 | results[i] = np.stack(results[i]) 149 | 150 | def set_min_max(data: np.ndarray, eps: float = 1e-7): 151 | data = data.reshape(-1, data.shape[-1]) 152 | return (np.percentile(data, self.config.tokenizer_ratio_limit * 100, axis=0) - eps, 153 | np.percentile(data, (1 - self.config.tokenizer_ratio_limit) * 100, axis=0) + eps 154 | ) 155 | 156 | self.min_proprio, self.max_proprio = set_min_max(results[keys.index("proprio")]) 157 | self.min_action, self.max_action 
= set_min_max(results[keys.index("action")]) 158 | 159 | def store_names(self) -> List[str]: 160 | ret = [] 161 | for x in ["min", "max"]: 162 | for y in ["proprio", "action"]: 163 | ret.append(f"{x}_{y}") 164 | return ret 165 | 166 | def save(self) -> dict: 167 | ret = dict() 168 | for n in self.store_names(): 169 | if getattr(self, n) is not None: 170 | ret[n] = getattr(self, n) 171 | return ret 172 | 173 | def load(self, data: dict): 174 | for n in self.store_names(): 175 | if n in data: 176 | setattr(self, n, data[n]) 177 | else: 178 | setattr(self, n, None) 179 | 180 | -------------------------------------------------------------------------------- /vla_network/data_preprocessing/token_pattern.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass, field 3 | from typing import Callable, Dict, List, Optional, Tuple, Union 4 | from pydantic import BaseModel, Field 5 | import torch 6 | import numpy as np 7 | from vla_network.utils.constant import IGNORE_INDEX 8 | from vla_network.config import VLADataConfig 9 | 10 | UNFINISHED = "unfinished_" 11 | 12 | def to_flatten_list(x: Union["torch.Tensor", np.ndarray, list]) -> list: # type: ignore 13 | """ 14 | Convert a tensor, numpy array, or list to a flattened list. 15 | 16 | Parameters 17 | ---------- 18 | x : Union["torch.Tensor", np.ndarray, list] 19 | The input data to convert. 20 | 21 | Returns 22 | ------- 23 | list 24 | The flattened list. 25 | """ 26 | if torch is not None and isinstance(x, torch.Tensor): 27 | return x.reshape(-1).tolist() 28 | elif isinstance(x, np.ndarray): 29 | return x.reshape(-1).tolist() 30 | elif isinstance(x, list): 31 | return np.array(x).reshape(-1).tolist() 32 | else: 33 | raise ValueError(f"Unsupported type {type(x)}") 34 | 35 | class TokenInfo(BaseModel): 36 | key: str # the key in the input dict 37 | length: Optional[int] # the length of the token, None means no limit 38 | est: bool # whether loss is calculated for those tokens 39 | as_input: bool # whether those tokens are input tokens 40 | # terminate function to determine if the token sequence is complete 41 | # the default is to check if the length of the tokens is equal to the length of the token info 42 | terminate: Callable[["TokenInfo", List[int]], bool] = Field(default=lambda tinfo, tokens: len(tokens) == tinfo.length) 43 | 44 | def model_post_init(self, __context): 45 | assert self.key not in ['terminate', 'input_ids', 'robot_input_ids'], f"key {self.key} is not allowed in TokenInfo" 46 | assert not self.key.startswith(UNFINISHED), f"key {self.key} is not allowed in TokenInfo" 47 | 48 | @dataclass 49 | class TokenResult: 50 | # whether the token sequence is complete 51 | terminate: bool = field(default_factory=lambda: False) 52 | # the new input ids of the tokens 53 | input_ids: List[int] = field(default_factory=lambda: []) 54 | # the new robot input ids of the tokens 55 | robot_input_ids: List[int] = field(default_factory=lambda: []) 56 | # allow arbitrary key(str): value(List[int]) 57 | # note that there are two requirements of the keys (see TokenInfo.model_post_init): 58 | # 1. they should not conflict with the above keys 59 | # 2. 
they should not start with "unfinished_" since we use this prefix to indicate that the token is not finished 60 | 61 | # things not listed above is also shown in this func 62 | def __str__(self) -> str: 63 | ret = 'TokenResult:\n' 64 | for k, v in self.__dict__.items(): 65 | ret += f'\t{k}={v}\n' 66 | return ret 67 | 68 | def __repr__(self) -> str: 69 | return self.__str__() 70 | 71 | class TokenPattern(BaseModel): 72 | # the token info for the input tokens, N 73 | infos: List[Optional[TokenInfo]] 74 | # the token info for the robot input tokens 75 | robot_infos: List[Optional[TokenInfo]] 76 | 77 | # get the input ids and labels for the input tokens 78 | def get_input_id_label(self, **kwargs: Dict[str, List[int]])-> Tuple[List[int], List[int]]: 79 | return self.get_id_label_inner(self.infos, **kwargs) 80 | 81 | # get the input ids and labels for the robot input tokens 82 | def get_robot_input_id_label(self, **kwargs: Dict[str, List[int]]) -> Tuple[List[int], List[int]]: 83 | return self.get_id_label_inner(self.robot_infos, **kwargs) 84 | 85 | @staticmethod 86 | def get_id_label_inner(infos: List[TokenInfo], **kwargs: Dict[str, List[int]]) -> Tuple[List[int], List[int]]: 87 | input_ids, labels = [], [] 88 | for info in infos: 89 | if info is None: 90 | continue 91 | value = to_flatten_list(kwargs.get(info.key, [])) 92 | assert info.length is None or len(value) == info.length, f"key {info.key} length {len(value)} != {info.length}" 93 | # add to input ids 94 | input_ids.extend(to_flatten_list(value)) 95 | if info.est: 96 | # add to labels 97 | labels.extend(to_flatten_list(value)) 98 | else: 99 | # ignore those tokens's loss 100 | labels.extend([IGNORE_INDEX] * len(value)) 101 | return input_ids, labels 102 | 103 | # suppose we output some of the tokens (maybe unfinished), we need to update the input ids and labels 104 | # can be used in generation 105 | def update_tokens(self, output: List[int], **kwargs: Dict[str, List[int]]) -> TokenResult: 106 | output = deepcopy(to_flatten_list(output)) 107 | ret = TokenResult(terminate=False) 108 | # shallow copy the input ids so that we can add tokens to it 109 | for ids, infos in [(ret.input_ids, self.infos), (ret.robot_input_ids, self.robot_infos)]: 110 | for info in infos: 111 | if info is None: 112 | continue 113 | 114 | if info.as_input: 115 | # if the token is as_input, then we should find them in input and add them to ids 116 | if info.key in kwargs: 117 | value = to_flatten_list(kwargs[info.key]) 118 | ids.extend(value) 119 | else: 120 | value = None 121 | setattr(ret, info.key, value) 122 | else: 123 | # the token is not as_input, then we should find the tokens in the output 124 | cur = [] 125 | while True: 126 | if info.terminate(info, cur): 127 | # if this part finished, then we should save to ret and break 128 | setattr(ret, info.key, cur) 129 | break 130 | elif len(output) == 0: 131 | # if the output is empty, then the rest of tokens haven't been predicted 132 | # save the current tokens and return 133 | # terminate is False 134 | setattr(ret, UNFINISHED+info.key, cur) 135 | return ret 136 | else: 137 | # the next token is one of the output tokens 138 | token_id = output.pop(0) 139 | ids.append(token_id) 140 | cur.append(token_id) 141 | 142 | # all the tokens are predicted 143 | assert len(output) == 0, f"output is not empty, {output}" 144 | ret.terminate = True 145 | return ret 146 | # fmt: off 147 | 148 | def get_cot_action_pattern(config: VLADataConfig) -> TokenPattern: 149 | return TokenPattern( 150 | infos=[ 151 | 
TokenInfo(key='text_ids', length=None, est=False, as_input=True), 152 | TokenInfo(key='bbox', length=config.img_num * 4, est=True, as_input=False) if config.use_bbox else None, 153 | TokenInfo(key='hist_proprio', length=(config.proprio_len-1) * config.proprio_dim, est=False, as_input=True), 154 | TokenInfo(key='cur_proprio', length=config.proprio_dim, est=False, as_input=True), 155 | TokenInfo(key='goal', length=config.goal_dim, est=True, as_input=False) if config.goal_dim is not None else None, 156 | TokenInfo(key='eos', length=1, est=True, as_input=False), 157 | ], 158 | robot_infos=[ 159 | ], 160 | ) 161 | 162 | 163 | def get_token_pattern(config: VLADataConfig, name: str) -> TokenPattern: 164 | return dict( 165 | cot_action=get_cot_action_pattern, 166 | )[name](config) -------------------------------------------------------------------------------- /vla_network/model/backbone_llm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from torch import nn 5 | import copy 6 | from transformers import PreTrainedModel, PreTrainedTokenizerBase 7 | from transformers.modeling_outputs import CausalLMOutputWithPast 8 | from vla_network.config import LLMConfig 9 | 10 | 11 | PAD_TOKEN = "" 12 | 13 | 14 | class LLMBackbone(nn.Module): 15 | config: LLMConfig 16 | llm: PreTrainedModel 17 | tokenizer: PreTrainedTokenizerBase 18 | 19 | def __init__(self, config: LLMConfig, train: bool = False) -> None: 20 | super().__init__() 21 | 22 | config = copy.deepcopy(config) 23 | 24 | self.config = config 25 | 26 | self.llm = config.model_cls.from_pretrained( 27 | "internlm/internlm2-1_8b", 28 | attn_implementation=config.attn_implementation, 29 | do_sample=False, 30 | temperature=1.0, 31 | top_p=1.0, 32 | trust_remote_code=True, 33 | torch_dtype=torch.bfloat16, 34 | ) 35 | 36 | self.llm.config.use_cache = True 37 | 38 | self.tokenizer = config.token_cls.from_pretrained( 39 | "internlm/internlm2-1_8b", 40 | model_max_length=self.config.max_len, 41 | padding_side="right", 42 | trust_remote_code=True, 43 | ) 44 | self.tokenizer.add_special_tokens( 45 | {"additional_special_tokens": self.config.special_tokens} 46 | ) 47 | self.tokenizer.add_special_tokens({"pad_token": PAD_TOKEN}) 48 | self.llm.config.pad_token_id = self.tokenizer.pad_token_id 49 | self.llm.resize_token_embeddings( 50 | len(self.tokenizer), pad_to_multiple_of=config.pad_multiple_of 51 | ) 52 | 53 | @property 54 | def input_dim(self) -> int: 55 | return self.input_embedding.embedding_dim 56 | 57 | def forward( 58 | self, 59 | input_ids: Optional[torch.LongTensor] = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | return_dict: Optional[bool] = None, 69 | ) -> CausalLMOutputWithPast: 70 | """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss""" 71 | return self.llm( 72 | input_ids=input_ids, 73 | attention_mask=attention_mask, 74 | position_ids=position_ids, 75 | past_key_values=past_key_values, 76 | inputs_embeds=inputs_embeds, 77 | labels=labels, 78 | use_cache=use_cache, 79 | output_attentions=output_attentions, 80 | 
output_hidden_states=output_hidden_states, 81 | return_dict=return_dict, 82 | ) 83 | 84 | def generate( 85 | self, 86 | max_token_num: int, 87 | attention_mask: Optional[torch.Tensor] = None, 88 | position_ids: Optional[torch.LongTensor] = None, 89 | cache: Optional[dict] = None, 90 | inputs_embeds: Optional[torch.FloatTensor] = None, 91 | requires_past_key_values: bool = False, 92 | ): 93 | """Contains optimization for generating a given number of tokens. 94 | NOTE: the returned cache should not contain things computed from the last generated token. 95 | """ 96 | assert inputs_embeds.shape[0] == 1, "only single sample for now" 97 | return self.generate_normal( 98 | max_token_num, 99 | attention_mask=attention_mask, 100 | position_ids=position_ids, 101 | cache=cache, 102 | inputs_embeds=inputs_embeds, 103 | ) 104 | 105 | def generate_normal( 106 | self, 107 | max_token_num: int, 108 | attention_mask: Optional[torch.Tensor] = None, 109 | position_ids: Optional[torch.LongTensor] = None, 110 | cache: Optional[dict] = None, 111 | inputs_embeds: Optional[torch.FloatTensor] = None, 112 | ): 113 | """ 114 | TODO: currently does not check termination, generates max_token_num for all sequences. 115 | """ 116 | assert attention_mask is None 117 | assert position_ids is None 118 | if cache is None: 119 | cache = {} 120 | device = inputs_embeds.device if inputs_embeds is not None else next(self.llm.parameters()).device 121 | batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else attention_mask.shape[0] 122 | generated_tokens = torch.zeros((batch_size, max_token_num), dtype=torch.long, device=device) 123 | past_key_values = cache.get("past_key_values") 124 | 125 | # construct attention mask and position ids, precompute for input_len + max_token_num so that we can simply slice the tensors during generation 126 | full_length = inputs_embeds.shape[1] + max_token_num 127 | full_length += past_key_values[0][0].shape[2] if past_key_values is not None else 0 128 | position_ids = torch.arange(full_length, device=device).unsqueeze(0) 129 | if self.llm.config.attn_implementation == "flex_attention": 130 | attention_mask = torch.tril(torch.ones((full_length, full_length), device=device)).unsqueeze(0).unsqueeze(0) 131 | else: 132 | attention_mask = torch.ones((batch_size, full_length), device=device) 133 | 134 | # pad to multiples of PAD_TO to avoid torch recompile with varying seq len 135 | PAD_TO = 16 136 | num_padding = cache["num_padding"] if "num_padding" in cache else ((PAD_TO - (full_length % PAD_TO)) % PAD_TO) 137 | if num_padding > 0: 138 | # pad inputs_embeds only in prefill stage 139 | if past_key_values is None: 140 | pad_embeds = torch.zeros((inputs_embeds.shape[0], num_padding, inputs_embeds.shape[2]), dtype=inputs_embeds.dtype, device=inputs_embeds.device) 141 | inputs_embeds = torch.cat([pad_embeds, inputs_embeds], dim=1) 142 | 143 | # pad position_ids 144 | pad_pos = torch.zeros((position_ids.shape[0], num_padding), dtype=position_ids.dtype, device=position_ids.device) 145 | position_ids = torch.cat([pad_pos, position_ids], dim=1) 146 | 147 | # pad attention_mask 148 | if self.llm.config.attn_implementation == "flex_attention": 149 | pad_mask = torch.zeros((*attention_mask.shape[:2], num_padding, full_length), dtype=attention_mask.dtype, device=attention_mask.device) 150 | attention_mask = torch.cat([pad_mask, attention_mask], dim=2) 151 | pad_mask2 = torch.zeros((*attention_mask.shape[:2], full_length+num_padding, num_padding), dtype=attention_mask.dtype, 
device=attention_mask.device) 152 | attention_mask = torch.cat([pad_mask2, attention_mask], dim=3) 153 | else: 154 | # attention_mask: (batch_size, full_length) 155 | pad_mask = torch.zeros((attention_mask.shape[0], num_padding), dtype=attention_mask.dtype, device=attention_mask.device) 156 | attention_mask = torch.cat([pad_mask, attention_mask], dim=1) 157 | 158 | for i in range(max_token_num): 159 | if past_key_values is not None: 160 | past_length = past_key_values[0][0].shape[2] 161 | else: 162 | past_length = 0 163 | outputs = self.llm( 164 | inputs_embeds=inputs_embeds, 165 | attention_mask=attention_mask[:, :, past_length:past_length+inputs_embeds.shape[1], :past_length+inputs_embeds.shape[1]] if attention_mask is not None else None, 166 | position_ids=position_ids[:, past_length:past_length+inputs_embeds.shape[1]] if attention_mask is not None else None, 167 | past_key_values=past_key_values, 168 | use_cache=True, 169 | return_dict=True, 170 | ) 171 | next_token_logits = outputs.logits[:, -1, :] 172 | next_tokens = next_token_logits.argmax(dim=-1) 173 | generated_tokens[:, i] = next_tokens 174 | 175 | past_key_values = outputs.past_key_values 176 | inputs_embeds = self.llm.get_input_embeddings()(next_tokens.unsqueeze(-1)) 177 | 178 | return generated_tokens.tolist(), {**cache, "past_key_values": past_key_values, "num_padding": num_padding} 179 | 180 | def gradient_checkpointing_enable( 181 | self, gradient_checkpointing_kwargs: Optional[dict] = None 182 | ): 183 | self.llm.gradient_checkpointing_enable(gradient_checkpointing_kwargs) 184 | 185 | @property 186 | def input_embedding(self) -> nn.Embedding: 187 | return self.llm.get_input_embeddings() 188 | -------------------------------------------------------------------------------- /vla_network/model/backbone_llm/internlm/configuration_internlm2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on transformers/src/transformers/models/llama/configuration_llama.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ InternLM2 model configuration""" 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 25 | 26 | 27 | # Modified from transformers.model.llama.configuration_llama.LlamaConfig 28 | class InternLM2Config(PretrainedConfig): 29 | r""" 30 | This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate 31 | an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a 32 | configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. 
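    A minimal, illustrative construction with default arguments (the import path below
    assumes this repository's package layout):

    >>> from vla_network.model.backbone_llm.internlm.configuration_internlm2 import InternLM2Config
    >>> config = InternLM2Config()
    >>> config.num_key_value_heads == config.num_attention_heads  # MHA unless num_key_value_heads is set
    True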
33 | 34 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 35 | documentation from [`PretrainedConfig`] for more information. 36 | 37 | 38 | Args: 39 | vocab_size (`int`, *optional*, defaults to 32000): 40 | Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the 41 | `inputs_ids` passed when calling [`InternLM2Model`] 42 | hidden_size (`int`, *optional*, defaults to 4096): 43 | Dimension of the hidden representations. 44 | intermediate_size (`int`, *optional*, defaults to 11008): 45 | Dimension of the MLP representations. 46 | num_hidden_layers (`int`, *optional*, defaults to 32): 47 | Number of hidden layers in the Transformer decoder. 48 | num_attention_heads (`int`, *optional*, defaults to 32): 49 | Number of attention heads for each attention layer in the Transformer decoder. 50 | num_key_value_heads (`int`, *optional*): 51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 53 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 55 | by meanpooling all the original heads within that group. For more details checkout [this 56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 57 | `num_attention_heads`. 58 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 59 | The non-linear activation function (function or string) in the decoder. 60 | max_position_embeddings (`int`, *optional*, defaults to 2048): 61 | The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens. 62 | initializer_range (`float`, *optional*, defaults to 0.02): 63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 64 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 65 | The epsilon used by the rms normalization layers. 66 | use_cache (`bool`, *optional*, defaults to `True`): 67 | Whether or not the model should return the last key/values attentions (not used by all models). Only 68 | relevant if `config.is_decoder=True`. 69 | pad_token_id (`int`, *optional*): 70 | Padding token id. 71 | bos_token_id (`int`, *optional*, defaults to 1): 72 | Beginning of stream token id. 73 | eos_token_id (`int`, *optional*, defaults to 2): 74 | End of stream token id. 75 | pretraining_tp (`int`, *optional*, defaults to 1): 76 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 77 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) 78 | to understand more about it. This value is necessary to ensure exact reproducibility 79 | of the pretraining results. Please refer to [this 80 | issue](https://github.com/pytorch/pytorch/issues/76232). 81 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 82 | Whether to tie weight embeddings 83 | rope_theta (`float`, *optional*, defaults to 10000.0): 84 | The base period of the RoPE embeddings. 85 | rope_scaling (`Dict`, *optional*): 86 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 87 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. 
The expected format is 88 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 89 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 90 | these scaling strategies behave: 91 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 92 | experimental feature, subject to breaking API changes in future versions. 93 | """ 94 | _auto_class = "AutoConfig" 95 | model_type = "internlm2" 96 | keys_to_ignore_at_inference = ["past_key_values"] 97 | 98 | def __init__( # pylint: disable=W0102 99 | self, 100 | vocab_size=103168, 101 | hidden_size=4096, 102 | intermediate_size=11008, 103 | num_hidden_layers=32, 104 | num_attention_heads=32, 105 | num_key_value_heads=None, 106 | hidden_act="silu", 107 | max_position_embeddings=2048, 108 | initializer_range=0.02, 109 | rms_norm_eps=1e-6, 110 | use_cache=True, 111 | pad_token_id=0, 112 | bos_token_id=1, 113 | eos_token_id=2, 114 | pretraining_tp=1, 115 | tie_word_embeddings=False, 116 | bias=True, 117 | rope_theta=10000, 118 | rope_scaling=None, 119 | attn_implementation=None, 120 | # add this to construct shape-compatible backbones for flow matching 121 | head_dim=None, 122 | **kwargs, 123 | ): 124 | self.vocab_size = vocab_size 125 | self.max_position_embeddings = max_position_embeddings 126 | self.hidden_size = hidden_size 127 | self.intermediate_size = intermediate_size 128 | self.num_hidden_layers = num_hidden_layers 129 | self.num_attention_heads = num_attention_heads 130 | self.bias = bias 131 | self.head_dim = head_dim 132 | 133 | if num_key_value_heads is None: 134 | num_key_value_heads = num_attention_heads 135 | self.num_key_value_heads = num_key_value_heads 136 | 137 | self.hidden_act = hidden_act 138 | self.initializer_range = initializer_range 139 | self.rms_norm_eps = rms_norm_eps 140 | self.pretraining_tp = pretraining_tp 141 | self.use_cache = use_cache 142 | self.rope_theta = rope_theta 143 | self.rope_scaling = rope_scaling 144 | self._rope_scaling_validation() 145 | self.attn_implementation = attn_implementation 146 | if self.attn_implementation is None: 147 | self.attn_implementation = "eager" 148 | 149 | super().__init__( 150 | pad_token_id=pad_token_id, 151 | bos_token_id=bos_token_id, 152 | eos_token_id=eos_token_id, 153 | tie_word_embeddings=tie_word_embeddings, 154 | **kwargs, 155 | ) 156 | 157 | def _rope_scaling_validation(self): 158 | """ 159 | Validate the `rope_scaling` configuration. 
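        For example, an illustrative value such as `{"type": "dynamic", "factor": 2.0}`
        passes validation, while an unknown `type`, a missing `factor`, or a factor
        below 1.0 raises a `ValueError`.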
160 | """ 161 | if self.rope_scaling is None: 162 | return 163 | 164 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 165 | raise ValueError( 166 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 167 | f"got {self.rope_scaling}" 168 | ) 169 | rope_scaling_type = self.rope_scaling.get("type", None) 170 | rope_scaling_factor = self.rope_scaling.get("factor", None) 171 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 172 | raise ValueError( 173 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 174 | ) 175 | if ( 176 | rope_scaling_factor is None 177 | or not isinstance(rope_scaling_factor, (float, int)) 178 | or rope_scaling_factor < 1.0 179 | ): 180 | raise ValueError( 181 | f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} " 182 | f"of type {type(rope_scaling_factor)}" 183 | ) -------------------------------------------------------------------------------- /vla_network/model/vla/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union, Tuple, Dict, Any 2 | import numpy as np 3 | import copy 4 | import torch 5 | from torch import nn 6 | from transformers import GenerationMixin, PreTrainedTokenizerBase, PreTrainedModel 7 | from transformers.modeling_outputs import CausalLMOutputWithPast 8 | from transformers.modeling_outputs import ModelOutput 9 | import json 10 | import re 11 | from safetensors.torch import load_file 12 | import os 13 | 14 | from vla_network.type import RawVLAData 15 | from vla_network.model.backbone_2d import Backbone2D 16 | from vla_network.model.backbone_llm import LLMBackbone 17 | from vla_network.config import VLAModelConfig, ImageTransform, VLADataConfig, ActionExpertConfig 18 | from vla_network.data_preprocessing.preprocess import DataPreprocessor 19 | from vla_network.data_preprocessing.vla_data_collator import vla_collator 20 | from vla_network.data_preprocessing.token_pattern import TokenPattern, TokenResult 21 | from vla_network.utils.constant import IGNORE_INDEX 22 | from .projector import FusedMLPProjector 23 | from .flow_matching import VLAFlowMatchingModule 24 | 25 | 26 | def update_state_dict(state_dict: dict) -> dict: 27 | # update if load from prism vlm 28 | if "llm_backbone" in state_dict: 29 | state_dict["llm"] = state_dict.pop("llm_backbone") 30 | if "vision_backbone" in state_dict: 31 | state_dict["backbone_2d"] = dict() 32 | for k, v in state_dict.pop("vision_backbone").items(): 33 | state_dict["backbone_2d"][k.replace("_featurizer", ".model")] = v 34 | return state_dict 35 | 36 | 37 | def make_block_attn_mask(input_mask, block_mask): 38 | cumsum = torch.cumsum(block_mask, dim=0) 39 | causal_num = (cumsum == 0).sum() 40 | causal_mask = torch.tril(torch.ones((input_mask.shape[1], input_mask.shape[1]), dtype=torch.bool, device=input_mask.device)) 41 | if causal_num != len(block_mask): 42 | block_attn_mask = cumsum[None, causal_num:] <= cumsum[causal_num:, None] 43 | causal_mask[causal_num:, causal_num:] = block_attn_mask 44 | valid_mask = input_mask[:, None, :] * input_mask[:, :, None] 45 | return torch.logical_and(causal_mask, valid_mask)[:, None] 46 | 47 | 48 | def load_safetensors(path: str) -> dict: 49 | return load_file(path) 50 | 51 | 52 | def load_model(model: nn.Module, ckpt_path: str) -> nn.Module: 53 | ckpt = load_safetensors(ckpt_path) 54 | model.load_state_dict(ckpt) 55 | return model 56 | 57 | 
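# Illustrative sketch of `make_block_attn_mask` (toy shapes; values chosen for clarity):
# it returns a boolean mask of shape (batch, 1, seq, seq) that is causal over the prefix
# (positions whose cumulative `block_mask` is 0), block-wise bidirectional over the
# suffix, and intersected with the padding mask. With a 4-token prefix and a single
# 2-token suffix block (block starts marked by 1):
#
#     input_mask = torch.ones(1, 6, dtype=torch.bool)
#     block_mask = torch.tensor([0, 0, 0, 0, 1, 0], dtype=torch.bool)
#     mask = make_block_attn_mask(input_mask, block_mask)  # shape (1, 1, 6, 6)
#
# The last two rows of mask[0, 0] are identical: both suffix tokens attend to the full
# prefix and to each other, which is how the action expert's flow-matching suffix
# attends in `generate_flow_matching` below.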
class VLA(nn.Module, GenerationMixin): 58 | config: VLAModelConfig 59 | backbone_2d: Backbone2D 60 | llm: LLMBackbone 61 | projector: nn.Module 62 | train_modules: List[str] 63 | tokenizer: PreTrainedTokenizerBase 64 | image_transform: ImageTransform 65 | is_train: bool 66 | 67 | def __init__(self, config: VLAModelConfig): 68 | super().__init__() 69 | self.config = config 70 | 71 | # TODO: Check whether add this to __init__ or not 72 | def init(self, train: bool = False): 73 | self.backbone_2d = Backbone2D.init(self.config.backbone_2d) 74 | self.backbone_2d_dim = self.backbone_2d.feature_dim 75 | self.image_transform = self.backbone_2d.image_transform 76 | self.is_train = train 77 | 78 | self.llm = LLMBackbone(self.config.llm, train=train) 79 | self.llm_dim = self.llm.input_dim 80 | self.tokenizer = self.llm.tokenizer 81 | 82 | if self.config.action_expert: 83 | self.action_expert = self.create_action_expert_from_llm(self.llm.llm, self.config.action_expert_cfg) 84 | 85 | # Set Weight Initialization Seed for Projector Consistency 86 | torch.manual_seed(self.backbone_2d_dim) 87 | 88 | self.projector = FusedMLPProjector(self.backbone_2d_dim, self.llm_dim) 89 | 90 | if self.config.pred == "cot_flow_matching": 91 | self.flow_module = VLAFlowMatchingModule( 92 | config=self.config.flow_matching_cfg, 93 | action_dim=self.config.action_dim, 94 | llm_dim=self.action_expert.config.hidden_size, 95 | action_len=self.config.action_len, 96 | proprio_dim=self.config.proprio_dim, 97 | ) 98 | 99 | @staticmethod 100 | def create_action_expert_from_llm(llm: PreTrainedModel, action_expert_config: ActionExpertConfig): 101 | config = copy.deepcopy(llm.config) 102 | if config.attn_implementation != "flex_attention": 103 | config.attn_implementation = "flex_attention" 104 | config.hidden_size = config.hidden_size // action_expert_config.hidden_size_scale 105 | config.intermediate_size = config.intermediate_size // action_expert_config.intermediate_size_scale 106 | config.hidden_act = config.hidden_act 107 | config.head_dim = llm.model.layers[0].attention.head_dim 108 | model_cls = type(llm) 109 | return model_cls._from_config(config) 110 | 111 | def from_pretrained(self, path: Optional[str] = None) -> "VLA": 112 | if path is None: 113 | path = self.config.ckpt 114 | state_dict = torch.load(path, map_location="cpu", weights_only=True)["model"] 115 | state_dict = update_state_dict(state_dict) 116 | if "backbone_2d" in state_dict: 117 | self.backbone_2d.load_state_dict(state_dict["backbone_2d"]) 118 | self.projector.load_state_dict(state_dict["projector"]) 119 | return self 120 | 121 | @staticmethod 122 | def insert_img_info(orig: torch.Tensor, img_info: torch.Tensor) -> torch.Tensor: 123 | return torch.cat([orig[:, :1], img_info, orig[:, 1:]], dim=1) # fmt: skip 124 | 125 | @staticmethod 126 | def insert_img_info_single(orig: torch.Tensor, img_info: torch.Tensor) -> torch.Tensor: 127 | return torch.cat([orig[:1], img_info, orig[1:]], dim=0) # fmt: skip 128 | 129 | def get_proj_feat_2d(self, images: torch.FloatTensor) -> torch.FloatTensor: 130 | with torch.set_grad_enabled(False): 131 | feat_2d = self.backbone_2d(images) 132 | proj_feat_2d = self.projector(feat_2d) 133 | return proj_feat_2d 134 | 135 | def embed_prefix( 136 | self, 137 | input_ids: torch.LongTensor = None, 138 | attention_mask: torch.Tensor = None, 139 | images: Optional[torch.FloatTensor] = None, 140 | labels: Optional[torch.LongTensor] = None, 141 | proj_feat_2d: Optional[torch.FloatTensor] = None, 142 | ) -> tuple[torch.FloatTensor, 
torch.BoolTensor, torch.BoolTensor, torch.LongTensor]: 143 | 144 | b = len(input_ids) 145 | if proj_feat_2d is None: 146 | proj_feat_2d = self.get_proj_feat_2d(images) 147 | n_img_token = proj_feat_2d.shape[1] 148 | 149 | input_embed = self.llm.input_embedding(input_ids) 150 | mm_input_embed = self.insert_img_info(input_embed, proj_feat_2d).to( 151 | input_embed.dtype 152 | ) 153 | 154 | img_attn_mask = torch.ones( 155 | (b, n_img_token), dtype=torch.bool, device=attention_mask.device 156 | ) 157 | mm_attn_mask = self.insert_img_info(attention_mask, img_attn_mask) 158 | 159 | n_mm_token = mm_attn_mask.shape[1] 160 | mm_block_mask = torch.zeros( 161 | (n_mm_token, ), dtype=torch.bool, 162 | device=attention_mask.device 163 | ) 164 | 165 | if labels is None: 166 | mm_labels = None 167 | else: 168 | img_labels = torch.full( 169 | (b, n_img_token), IGNORE_INDEX, dtype=labels.dtype, device=labels.device 170 | ) 171 | mm_labels = self.insert_img_info(labels, img_labels) 172 | 173 | return mm_input_embed, mm_attn_mask, mm_block_mask, mm_labels 174 | 175 | def gradient_checkpointing_enable( 176 | self, gradient_checkpointing_kwargs: Optional[dict] = None 177 | ): 178 | self.llm.gradient_checkpointing_enable(gradient_checkpointing_kwargs) 179 | 180 | # TODO: remove unused inputs 181 | # TODO: what should be the output type of this function? 182 | def generate( 183 | self, 184 | input_ids: torch.LongTensor = None, 185 | robot_input_ids: torch.LongTensor = None, 186 | attention_mask: torch.Tensor = None, 187 | robot_attention_mask: torch.Tensor = None, 188 | images: Optional[torch.FloatTensor] = None, 189 | proprio: Optional[torch.FloatTensor] = None, 190 | # TODO: maybe requires runtime config 191 | max_token_num: int = int(1e10), 192 | flow_matching_iter: int = 10, 193 | inference_kwargs: List[dict] = None, 194 | token_pattern: Optional[TokenPattern] = None, 195 | ) -> Tuple[TokenResult, Any]: 196 | # TODO: This is a temporary solution 197 | # Latter we will change to C++ implementation 198 | # So don't care about the performance 199 | 200 | proj_feat_2d = self.get_proj_feat_2d(images) 201 | prefix_embeds, prefix_mask, prefix_block_mask, _ = self.embed_prefix( 202 | input_ids=input_ids, 203 | attention_mask=attention_mask, 204 | proj_feat_2d=proj_feat_2d, 205 | labels=None 206 | ) 207 | 208 | if self.config.pred == "cot_flow_matching": 209 | # generate bbox and goal tokens autoregressively 210 | cot_parse, kv_cache = self.generate_autoregressive( 211 | input_ids=input_ids, 212 | robot_input_ids=robot_input_ids, 213 | proj_feat_2d=proj_feat_2d, 214 | attention_mask=attention_mask, 215 | robot_attention_mask=robot_attention_mask, 216 | max_token_num=max_token_num, 217 | token_pattern=token_pattern, 218 | inference_kwargs=inference_kwargs, 219 | require_kv_cache=True, 220 | ) 221 | 222 | input_ids = torch.tensor(cot_parse.input_ids, device=input_ids.device)[None] 223 | _, prefix_mask, prefix_block_mask, _ = self.embed_prefix( 224 | input_ids=input_ids, 225 | attention_mask=torch.ones_like(input_ids).bool(), 226 | proj_feat_2d=proj_feat_2d, 227 | labels=None 228 | ) 229 | 230 | padded_prefix_length = kv_cache[0][0].shape[2] 231 | num_paddings = padded_prefix_length - prefix_mask.shape[1] 232 | if num_paddings > 0: 233 | pad_mask = torch.zeros((prefix_mask.shape[0], num_paddings), dtype=prefix_mask.dtype, device=prefix_mask.device) 234 | prefix_mask = torch.cat([pad_mask, prefix_mask], dim=1) 235 | pad_block_mask = torch.zeros((num_paddings,), dtype=prefix_block_mask.dtype, 
device=prefix_block_mask.device) 236 | prefix_block_mask = torch.cat([pad_block_mask, prefix_block_mask], dim=0) 237 | 238 | # generate actions using flow matching 239 | action = self.generate_flow_matching( 240 | prefix_kv_cache=kv_cache, 241 | prefix_mask=prefix_mask, 242 | prefix_block_mask=prefix_block_mask, 243 | proprio=proprio, 244 | flow_matching_iter=flow_matching_iter, 245 | ) 246 | ret = cot_parse, action 247 | else: 248 | raise NotImplementedError(f"Prediction type {self.config.pred} is not implemented.") 249 | return ret 250 | 251 | def generate_flow_matching(self, prefix_kv_cache, prefix_mask, prefix_block_mask, proprio, flow_matching_iter): 252 | device, dtype = prefix_kv_cache[0][0].device, prefix_kv_cache[0][0].dtype 253 | assert self.config.action_expert 254 | assert self.config.llm.attn_implementation == "flex_attention" 255 | proprio = proprio.to(dtype) 256 | # TODO: should move to flow matching module instead of here 257 | noise = self.flow_module.sample_noise( 258 | batch_size=len(proprio), 259 | device=device, 260 | dtype=dtype 261 | ) 262 | proprio_embeds = self.flow_module.proprior_proj(proprio) 263 | suffix_mask, suffix_block_mask = self.flow_module.get_suffix_masks(proprio_embeds) 264 | 265 | full_input_mask = torch.cat((prefix_mask, suffix_mask), dim=1) 266 | full_block_mask = torch.cat((prefix_block_mask, suffix_block_mask), axis=0) 267 | full_attn_mask = make_block_attn_mask(full_input_mask, full_block_mask).to(dtype) 268 | full_position_ids = torch.cumsum(full_input_mask, dim=1) - 1 269 | suffix_attn_mask = full_attn_mask[:, :, -suffix_mask.shape[1]:, ...] 270 | suffix_position_ids = full_position_ids[:, -suffix_mask.shape[1]:] 271 | 272 | prefix_kv_cache = tuple(prefix_kv_cache) 273 | 274 | def compute_v_t(x_t: torch.Tensor, time_vec: torch.Tensor): 275 | suffix_embeds = self.flow_module.embed_suffix_flow_matching_embeds(proprio_embeds, x_t, time_vec) 276 | action_expert_output = self.action_expert( 277 | attention_mask=suffix_attn_mask, 278 | position_ids=suffix_position_ids, 279 | inputs_embeds=suffix_embeds, 280 | past_key_values=prefix_kv_cache, use_cache=True, output_hidden_states=True, 281 | ) 282 | 283 | action_hidden_states = action_expert_output.hidden_states[-1][:, -self.config.action_len:] 284 | v_t = self.flow_module.get_v_t(action_hidden_states) 285 | return v_t 286 | 287 | x_0 = self.flow_module.denoise(compute_v_t, noise, flow_matching_iter) 288 | return x_0 289 | 290 | def generate_autoregressive(self, input_ids, robot_input_ids, proj_feat_2d, attention_mask, robot_attention_mask, max_token_num, token_pattern, inference_kwargs, require_kv_cache=False) -> Tuple[TokenPattern, Optional[Any]]: 291 | """Returns token pattern and kv cache. 292 | Requires batch size == 1 and no padding and no block attention. 293 | require_key_values enforces returning all_key_values in the cache. 294 | Note that this all_key_values includes things computed with the last token for flow matching, take care! 
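        (`require_key_values` above refers to the `require_kv_cache` argument in the
        signature.) A sketch of the caller's contract, matching how `generate` uses it:

            cot_parse, kv_cache = self.generate_autoregressive(..., require_kv_cache=True)
            padded_prefix_length = kv_cache[0][0].shape[2]  # prefix length, including left padding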
295 | """ 296 | assert input_ids.shape[0] == 1, "only support single sample for now" 297 | cache = None 298 | current_input_embeddings = [] 299 | current_input_mask = [] 300 | current_block_mask = [] 301 | pending = 0 302 | total_length = 0 303 | output = [] 304 | for idx, token_info in enumerate([*token_pattern.infos, *token_pattern.robot_infos]): 305 | if token_info is None: 306 | continue 307 | if token_info.as_input: 308 | embeddings = self.llm.input_embedding(torch.tensor(inference_kwargs[0][token_info.key], device=input_ids.device)) 309 | if idx == 0: 310 | # insert the proj_feat_2d after the first embedding 311 | embeddings = self.insert_img_info_single(embeddings, proj_feat_2d[0]) 312 | current_input_embeddings.append(embeddings) 313 | current_block_mask.extend([0] * embeddings.shape[0]) 314 | current_input_mask.extend([1] * embeddings.shape[0]) 315 | pending += embeddings.shape[0] 316 | total_length += embeddings.shape[0] 317 | continue 318 | 319 | # let the network generate, then clear pending, and update kv cache 320 | 321 | generated_tokens, cache = self.llm.generate( 322 | max_token_num=token_info.length, 323 | inputs_embeds=torch.concat(current_input_embeddings, dim=0).unsqueeze(0), 324 | cache=cache, 325 | ) 326 | total_length += len(generated_tokens[0]) 327 | output.extend(generated_tokens[0]) 328 | 329 | # reset pending tokens, it should be the embedding of the last generated token 330 | # assumes the kv cache does not contain the last token 331 | current_input_embeddings = [self.llm.input_embedding(torch.tensor(generated_tokens[0][-1:], dtype=torch.long, device=input_ids.device))] 332 | current_input_mask = [1] 333 | current_block_mask = [0] 334 | pending = 1 335 | 336 | # check completion 337 | parse_ret = token_pattern.update_tokens(output, **inference_kwargs[0]) 338 | if parse_ret.terminate or len(output) >= max_token_num: 339 | break 340 | kv_cache = None 341 | if require_kv_cache and len(current_input_embeddings) != 0: 342 | _, cache_with_past_key_values = self.llm.generate( 343 | max_token_num=1, 344 | inputs_embeds=torch.concat(current_input_embeddings, dim=0).unsqueeze(0), 345 | cache=cache, 346 | ) 347 | kv_cache = cache_with_past_key_values['past_key_values'] 348 | return parse_ret, kv_cache 349 | 350 | 351 | class VLAAgent(): 352 | def __init__(self, path: Optional[str] = None, exp_name: Optional[str]=None, iter: Optional[int] = None, device: str = 'cuda:0', compile=False): 353 | self.path, self.exp_name, self.device, self.iter = path, exp_name, device, iter 354 | self.model_cfg, self.data_cfg, self.model, self.preprocessor = self.load_vla(path, exp_name, iter, device, compile) 355 | self.token_pattern = self.preprocessor.pattern 356 | 357 | def load_vla( 358 | self, path: Optional[str]=None, exp_name: Optional[str]=None, iter: Optional[int] = None, device: str = "cuda:0", compile=False, 359 | ) -> Tuple[VLAModelConfig, VLADataConfig, VLA, DataPreprocessor]: 360 | # TODO: return cfg as a VLAConfig 361 | cfg_path = os.path.join(os.path.dirname(path), '..', 'config.json') 362 | with open(cfg_path, "r", encoding="utf-8") as f: 363 | cfg = json.load(f) 364 | data_cfg = VLADataConfig.model_validate(cfg["data"]) 365 | model_cfg = VLAModelConfig.model_validate(cfg["model"]) 366 | model: VLA = VLA(model_cfg) 367 | model.init(train=False) 368 | model = load_model(model, path) 369 | model = model.to(device).eval() 370 | if compile: 371 | model.llm.llm = torch.compile(model.llm.llm, dynamic=True) 372 | model.backbone_2d = torch.compile(model.backbone_2d) 373 | if 
hasattr(model, 'action_expert'): 374 | model.action_expert = torch.compile(model.action_expert, dynamic=True) 375 | data_cfg.tokenizer = model.tokenizer 376 | data_cfg.image_size = model.config.backbone_2d.image_size 377 | data_cfg.image_transform = model.image_transform 378 | data_cfg.pred = model_cfg.pred 379 | preprocessor = DataPreprocessor(data_cfg) 380 | preprocessor_path = os.path.join(os.path.dirname(path), '..', 'preprocessor.npz') 381 | preprocessor.load(np.load(preprocessor_path)) 382 | return model_cfg, data_cfg, model, preprocessor 383 | 384 | def sample_action(self, raw: RawVLAData): 385 | with torch.no_grad(): 386 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 387 | x = self.preprocessor.transform(raw, inference=True) 388 | model_input = {k:v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in vla_collator(self.data_cfg, [x]).items()} 389 | token_result, action_result = self.model.generate( 390 | input_ids=model_input['input_ids'].to(self.device), 391 | robot_input_ids=model_input['robot_input_ids'].to(self.device), 392 | attention_mask=model_input['attention_mask'].to(self.device), 393 | robot_attention_mask=model_input['robot_attention_mask'].to(self.device), 394 | images=model_input['images'].to(self.device), 395 | proprio=model_input['proprio'].to(self.device), 396 | inference_kwargs=x.inference_kwargs, 397 | token_pattern=self.token_pattern, 398 | max_token_num=100, 399 | ) 400 | ret = {} 401 | if self.model_cfg.pred == "cot_flow_matching": 402 | ret['action'] = self.preprocessor.robot_tokenizer.inv_norm_action(action_result.float().cpu().numpy()[0]) 403 | if hasattr(token_result, 'goal'): 404 | goal = self.preprocessor.robot_tokenizer.inv_goal(np.array(token_result.goal)) 405 | ret['goal'] = (goal[:3], goal[3:6]) 406 | if hasattr(token_result, 'bbox'): 407 | ret['bbox'] = (self.preprocessor.robot_tokenizer.uniform_tokenizer.uniform_detokenize(np.array(token_result.bbox).reshape(-1, 4)) + 1)/2*224 408 | else: 409 | raise NotImplementedError() 410 | return ret 411 | 412 | 413 | def __call__(self, samples: List) -> List[Dict[str, Any]]: 414 | rets = [] 415 | for sample in samples: 416 | raw = RawVLAData( 417 | dataset_name="dummy", 418 | data_id=str(sample['env_id']), 419 | frame=0, 420 | instruction=sample['text'], 421 | images=dict( 422 | front=np.stack(sample['front_view_image']), 423 | side=np.stack(sample['side_view_image']), 424 | ), 425 | proprio=np.stack(sample['proprio_array']), 426 | ) 427 | rets.append(self.sample_action(raw)) 428 | return rets 429 | -------------------------------------------------------------------------------- /vla_network/model/backbone_llm/internlm/modeling_internlm2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/modeling_llama.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch InternLM2 model.""" 17 | import math 18 | from typing import List, Optional, Tuple, Union 19 | import warnings 20 | from dataclasses import dataclass 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | import torch.utils.checkpoint # type: ignore 25 | from einops import rearrange 26 | from torch import nn 27 | from torch.nn import CrossEntropyLoss 28 | from transformers.activations import ACT2FN 29 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 30 | from transformers.modeling_outputs import ( 31 | BaseModelOutputWithPast, 32 | CausalLMOutputWithPast, 33 | ) 34 | from transformers.modeling_utils import PreTrainedModel 35 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 36 | from transformers.utils import ( 37 | add_start_docstrings, 38 | add_start_docstrings_to_model_forward, 39 | logging, 40 | replace_return_docstrings, 41 | ) 42 | 43 | try: 44 | from transformers.generation.streamers import BaseStreamer 45 | except Exception: 46 | BaseStreamer = None 47 | 48 | from .configuration_internlm2 import InternLM2Config 49 | 50 | 51 | try: 52 | from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore 53 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # type: ignore 54 | except: 55 | pass 56 | 57 | try: 58 | support_bf16_triu = torch.__version__ >= "2.1.0" 59 | except Exception: 60 | support_bf16_triu = False 61 | 62 | logger = logging.get_logger(__name__) 63 | 64 | _CONFIG_FOR_DOC = "InternLM2Config" 65 | 66 | 67 | def _get_unpad_data(attention_mask): 68 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 69 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 70 | max_seqlen_in_batch = seqlens_in_batch.max().item() 71 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102 72 | return ( 73 | indices, 74 | cu_seqlens, 75 | max_seqlen_in_batch, 76 | ) 77 | 78 | 79 | class InternLM2RMSNorm(nn.Module): 80 | """InternLM2RMSNorm is equivalent to T5LayerNorm.""" 81 | 82 | def __init__(self, hidden_size, eps=1e-6): 83 | super().__init__() 84 | self.weight = nn.Parameter(torch.ones(hidden_size)) 85 | self.variance_epsilon = eps 86 | 87 | def forward(self, hidden_states): 88 | input_dtype = hidden_states.dtype 89 | hidden_states = hidden_states.to(torch.float32) 90 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 91 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 92 | return self.weight * hidden_states.to(input_dtype) 93 | 94 | 95 | ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm) 96 | 97 | 98 | class InternLM2RotaryEmbedding(nn.Module): 99 | """Rotary Position Embedding for the InternLM2 model. 
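    It caches the inverse frequencies as a buffer and returns `(cos, sin)` tensors cast
    to the dtype of its input, e.g. `cos, sin = rotary_emb(value_states, position_ids)`
    as in `InternLM2Attention.forward` below.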
Credits to the Reddit user /u/lucidrains.""" 100 | 101 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 102 | super().__init__() 103 | self.scaling_factor = scaling_factor 104 | self.dim = dim 105 | self.max_position_embeddings = max_position_embeddings 106 | self.base = base 107 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) 108 | self.register_buffer("inv_freq", inv_freq, persistent=False) 109 | # For BC we register cos and sin cached 110 | self.max_seq_len_cached = max_position_embeddings 111 | 112 | @torch.no_grad() 113 | def forward(self, x, position_ids): 114 | # x: [bs, num_attention_heads, seq_len, head_size] 115 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 116 | position_ids_expanded = position_ids[:, None, :].float() 117 | # Force float32 since bfloat16 loses precision on long contexts 118 | # See https://github.com/huggingface/transformers/pull/29285 119 | device_type = x.device.type 120 | device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" 121 | with torch.autocast(device_type=device_type, enabled=False): 122 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) 123 | emb = torch.cat((freqs, freqs), dim=-1) 124 | cos = emb.cos() 125 | sin = emb.sin() 126 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 127 | 128 | 129 | class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): 130 | """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 131 | 132 | def forward(self, x, position_ids): 133 | # difference to the original RoPE: a scaling factor is aplied to the position ids 134 | position_ids = position_ids.float() / self.scaling_factor 135 | cos, sin = super().forward(x, position_ids) 136 | return cos, sin 137 | 138 | 139 | class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): 140 | """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. 141 | Credits to the Reddit users /u/bloc97 and /u/emozilla""" 142 | 143 | def forward(self, x, position_ids): 144 | # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length 145 | seq_len = torch.max(position_ids) + 1 146 | if seq_len > self.max_position_embeddings: 147 | base = self.base * ( 148 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 149 | ) ** (self.dim / (self.dim - 2)) 150 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)) 151 | self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation 152 | 153 | cos, sin = super().forward(x, position_ids) 154 | return cos, sin 155 | 156 | 157 | def rotate_half(x): 158 | """Rotates half the hidden dims of the input.""" 159 | x1 = x[..., : x.shape[-1] // 2] 160 | x2 = x[..., x.shape[-1] // 2 :] 161 | return torch.cat((-x2, x1), dim=-1) 162 | 163 | 164 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument 165 | """Applies Rotary Position Embedding to the query and key tensors. 166 | 167 | Args: 168 | q (`torch.Tensor`): The query tensor. 169 | k (`torch.Tensor`): The key tensor. 170 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 171 | sin (`torch.Tensor`): The sine part of the rotary embedding. 
172 | position_ids (`torch.Tensor`, *optional*): 173 | Deprecated and unused. 174 | unsqueeze_dim (`int`, *optional*, defaults to 1): 175 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 176 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 177 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 178 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 179 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 180 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 181 | Returns: 182 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 183 | """ 184 | cos = cos.unsqueeze(unsqueeze_dim) 185 | sin = sin.unsqueeze(unsqueeze_dim) 186 | q_embed = (q * cos) + (rotate_half(q) * sin) 187 | k_embed = (k * cos) + (rotate_half(k) * sin) 188 | return q_embed, k_embed 189 | 190 | 191 | class InternLM2MLP(nn.Module): 192 | """MLP for InternLM2 model.""" 193 | 194 | def __init__(self, config): 195 | super().__init__() 196 | self.config = config 197 | self.hidden_size = config.hidden_size 198 | self.intermediate_size = config.intermediate_size 199 | self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 200 | self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 201 | self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 202 | self.act_fn = ACT2FN[config.hidden_act] 203 | 204 | def forward(self, x): 205 | down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) 206 | 207 | return down_proj 208 | 209 | 210 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 211 | """ 212 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 213 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 214 | """ 215 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 216 | if n_rep == 1: 217 | return hidden_states 218 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 219 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 220 | 221 | 222 | class InternLM2Attention(nn.Module): 223 | """Multi-headed attention from 'Attention Is All You Need' paper""" 224 | 225 | def __init__(self, config: InternLM2Config, layer_idx: Optional[int] = None): 226 | super().__init__() 227 | self.config = config 228 | self.layer_idx = layer_idx 229 | if layer_idx is None: 230 | logger.warning_once( 231 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " 232 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " 233 | "when creating this class." 
234 | ) 235 | 236 | self.hidden_size = config.hidden_size 237 | self.num_heads = config.num_attention_heads 238 | self.head_dim = self.hidden_size // self.num_heads if config.head_dim is None else config.head_dim 239 | self.num_key_value_heads = config.num_key_value_heads 240 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 241 | self.max_position_embeddings = config.max_position_embeddings 242 | self.rope_theta = config.rope_theta 243 | self.is_causal = True 244 | 245 | # if (self.head_dim * self.num_heads) != self.hidden_size: 246 | # raise ValueError( 247 | # f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 248 | # f" and `num_heads`: {self.num_heads})." 249 | # ) 250 | 251 | self.wqkv = nn.Linear( 252 | self.hidden_size, 253 | (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, 254 | bias=config.bias, 255 | ) 256 | self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) 257 | 258 | self._init_rope() 259 | 260 | def _init_rope(self): 261 | if self.config.rope_scaling is None: 262 | self.rotary_emb = InternLM2RotaryEmbedding( 263 | self.head_dim, 264 | max_position_embeddings=self.max_position_embeddings, 265 | base=self.rope_theta, 266 | ) 267 | else: 268 | scaling_type = self.config.rope_scaling["type"] 269 | scaling_factor = self.config.rope_scaling["factor"] 270 | if scaling_type == "linear": 271 | self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( 272 | self.head_dim, 273 | max_position_embeddings=self.max_position_embeddings, 274 | scaling_factor=scaling_factor, 275 | base=self.rope_theta, 276 | ) 277 | elif scaling_type == "dynamic": 278 | self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( 279 | self.head_dim, 280 | max_position_embeddings=self.max_position_embeddings, 281 | scaling_factor=scaling_factor, 282 | base=self.rope_theta, 283 | ) 284 | else: 285 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 286 | 287 | def forward( 288 | self, 289 | hidden_states: torch.Tensor, 290 | attention_mask: Optional[torch.Tensor] = None, 291 | position_ids: Optional[torch.LongTensor] = None, 292 | past_key_value: Optional[Cache] = None, 293 | output_attentions: bool = False, 294 | use_cache: bool = False, # pylint: disable=unused-argument 295 | cache_position: Optional[torch.LongTensor] = None, 296 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 297 | bsz, q_len, _ = hidden_states.size() 298 | 299 | if self.config.pretraining_tp > 1: 300 | # split qkv_states by tp size 301 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 302 | qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0) 303 | qkv_states = torch.cat( 304 | [F.linear(hidden_states, qkv_slice) for qkv_slice in qkv_slices], dim=-1 # pylint: disable=E1102 305 | ) 306 | else: 307 | qkv_states = self.wqkv(hidden_states) 308 | 309 | qkv_states = rearrange( 310 | qkv_states, 311 | "b q (h gs d) -> b q h gs d", 312 | gs=2 + self.num_key_value_groups, 313 | d=self.head_dim, 314 | ) 315 | 316 | query_states = qkv_states[..., : self.num_key_value_groups, :] 317 | query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d").transpose(1, 2) 318 | key_states = qkv_states[..., -2, :].transpose(1, 2) 319 | value_states = qkv_states[..., -1, :].transpose(1, 2) 320 | 321 | cos, sin = self.rotary_emb(value_states, position_ids) 322 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, 
position_ids) 323 | 324 | if past_key_value is not None: 325 | # sin and cos are specific to RoPE models; cache_position needed for the static cache 326 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 327 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 328 | 329 | key_states = repeat_kv(key_states, self.num_key_value_groups) 330 | value_states = repeat_kv(value_states, self.num_key_value_groups) 331 | 332 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 333 | 334 | if attention_mask is not None: # no matter the length, we just slice it 335 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] 336 | attn_weights = attn_weights + causal_mask 337 | 338 | # upcast attention to fp32 339 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 340 | attn_output = torch.matmul(attn_weights, value_states) 341 | 342 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 343 | raise ValueError( 344 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 345 | f" {attn_output.size()}" 346 | ) 347 | 348 | attn_output = attn_output.transpose(1, 2).contiguous() 349 | 350 | attn_output = attn_output.reshape(bsz, q_len, -1) 351 | 352 | if self.config.pretraining_tp > 1: 353 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 354 | o_proj_slices = self.wo.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 355 | attn_output = sum( 356 | [ 357 | F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102 358 | for i in range(self.config.pretraining_tp) 359 | ] 360 | ) 361 | else: 362 | attn_output = self.wo(attn_output) 363 | 364 | if not output_attentions: 365 | attn_weights = None 366 | 367 | return attn_output, attn_weights, past_key_value 368 | 369 | 370 | class InternLM2FlexAttention(InternLM2Attention): 371 | """ 372 | InternLM2 block attention module. This module enables each token to access its own block and all the preceding blocks. 373 | """ 374 | def __init__(self, config: InternLM2Config, layer_idx: Optional[int] = None, action_expert_cfg=None): 375 | super().__init__(config, layer_idx) 376 | warnings.warn('Flex Attention not implemented, use pytorch version instead') 377 | self.is_causal = False 378 | 379 | 380 | INTERNLM2_ATTENTION_CLASSES = { 381 | "eager": InternLM2Attention, 382 | "flex_attention": InternLM2FlexAttention, 383 | } 384 | 385 | 386 | # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2 387 | class InternLM2DecoderLayer(nn.Module): 388 | """InternLM2 Decoder Layer. 
This module is a single layer of the InternLM2 model.""" 389 | 390 | def __init__(self, config: InternLM2Config, layer_idx: int): 391 | super().__init__() 392 | self.hidden_size = config.hidden_size 393 | self.layer_idx = layer_idx 394 | 395 | self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config, layer_idx=layer_idx) 396 | 397 | self.feed_forward = InternLM2MLP(config) 398 | self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 399 | self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 400 | 401 | def forward( 402 | self, 403 | hidden_states: torch.Tensor, 404 | attention_mask: Optional[torch.Tensor] = None, 405 | position_ids: Optional[torch.LongTensor] = None, 406 | past_key_value: Optional[Cache] = None, 407 | output_attentions: Optional[bool] = False, 408 | use_cache: Optional[bool] = False, 409 | cache_position: Optional[torch.LongTensor] = None, 410 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 411 | """ 412 | Args: 413 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 414 | attention_mask (`torch.FloatTensor`, *optional*): 415 | attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, 416 | query_sequence_length, key_sequence_length)` if default attention is used. 417 | output_attentions (`bool`, *optional*): 418 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 419 | returned tensors for more detail. 420 | use_cache (`bool`, *optional*): 421 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 422 | (see `past_key_values`). 423 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 424 | """ 425 | residual = hidden_states 426 | 427 | hidden_states = self.attention_norm(hidden_states) 428 | 429 | # Self Attention 430 | hidden_states, self_attn_weights, present_key_value = self.attention( 431 | hidden_states=hidden_states, 432 | attention_mask=attention_mask, 433 | position_ids=position_ids, 434 | past_key_value=past_key_value, 435 | output_attentions=output_attentions, 436 | use_cache=use_cache, 437 | cache_position=cache_position, 438 | ) 439 | hidden_states = residual + hidden_states 440 | # Fully Connected 441 | residual = hidden_states 442 | hidden_states = self.ffn_norm(hidden_states) 443 | hidden_states = self.feed_forward(hidden_states) 444 | hidden_states = residual + hidden_states 445 | 446 | outputs = (hidden_states,) 447 | 448 | if output_attentions: 449 | outputs += (self_attn_weights,) 450 | 451 | if use_cache: 452 | outputs += (present_key_value,) 453 | 454 | return outputs 455 | 456 | 457 | InternLM2_START_DOCSTRING = r""" 458 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 459 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 460 | etc.) 461 | 462 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 463 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 464 | and behavior. 465 | 466 | Parameters: 467 | config ([`InternLM2Config`]): 468 | Model configuration class with all the parameters of the model. 
Initializing with a config file does not 469 | load the weights associated with the model, only the configuration. Check out the 470 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 471 | """ 472 | 473 | 474 | # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 475 | @add_start_docstrings( 476 | "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.", 477 | InternLM2_START_DOCSTRING, 478 | ) 479 | class InternLM2PreTrainedModel(PreTrainedModel): 480 | """ 481 | InternLM2 pretraiend model's base class. 482 | """ 483 | 484 | config_class = InternLM2Config 485 | base_model_prefix = "model" 486 | supports_gradient_checkpointing = True 487 | _no_split_modules = ["InternLM2DecoderLayer"] 488 | _skip_keys_device_placement = ["past_key_values"] 489 | _supports_flash_attn_2 = True 490 | _supports_sdpa = True 491 | _supports_cache_class = True 492 | _supports_quantized_cache = True 493 | _supports_static_cache = True 494 | 495 | def _init_weights(self, module): 496 | std = self.config.initializer_range 497 | if isinstance(module, nn.Linear): 498 | module.weight.data.normal_(mean=0.0, std=std) 499 | if module.bias is not None: 500 | module.bias.data.zero_() 501 | elif isinstance(module, nn.Embedding): 502 | module.weight.data.normal_(mean=0.0, std=std) 503 | if module.padding_idx is not None: 504 | module.weight.data[module.padding_idx].zero_() 505 | 506 | 507 | InternLM2_INPUTS_DOCSTRING = r""" 508 | Args: 509 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 510 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 511 | it. 512 | 513 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 514 | [`PreTrainedTokenizer.__call__`] for details. 515 | 516 | [What are input IDs?](../glossary#input-ids) 517 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 518 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 519 | 520 | - 1 for tokens that are **not masked**, 521 | - 0 for tokens that are **masked**. 522 | 523 | [What are attention masks?](../glossary#attention-mask) 524 | 525 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 526 | [`PreTrainedTokenizer.__call__`] for details. 527 | 528 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 529 | `past_key_values`). 530 | 531 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 532 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 533 | information on the default strategy. 534 | 535 | - 1 indicates the head is **not masked**, 536 | - 0 indicates the head is **masked**. 537 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 538 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 539 | config.n_positions - 1]`. 540 | 541 | [What are position IDs?](../glossary#position-ids) 542 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 543 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 544 | blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values` 545 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 546 | 547 | Two formats are allowed: 548 | - a [`~cache_utils.Cache`] instance; 549 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 550 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 551 | cache format. 552 | 553 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 554 | legacy cache format will be returned. 555 | 556 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 557 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 558 | of shape `(batch_size, sequence_length)`. 559 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 560 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 561 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 562 | model's internal embedding lookup matrix. 563 | use_cache (`bool`, *optional*): 564 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 565 | `past_key_values`). 566 | output_attentions (`bool`, *optional*): 567 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 568 | tensors for more detail. 569 | output_hidden_states (`bool`, *optional*): 570 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 571 | more detail. 572 | return_dict (`bool`, *optional*): 573 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 574 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): 575 | Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, 576 | this tensor is not affected by padding. It is used to update the cache in the correct position and to infer 577 | the complete sequence length. 578 | """ 579 | 580 | 581 | # Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2 582 | @add_start_docstrings( 583 | "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.", 584 | InternLM2_START_DOCSTRING, 585 | ) 586 | class InternLM2Model(InternLM2PreTrainedModel): 587 | """ 588 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLM2DecoderLayer`] 589 | 590 | Args: 591 | config: InternLM2Config 592 | """ 593 | 594 | _auto_class = "AutoModel" 595 | 596 | def __init__(self, config: InternLM2Config): 597 | super().__init__(config) 598 | self.padding_idx = config.pad_token_id 599 | self.vocab_size = config.vocab_size 600 | self.config = config 601 | 602 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 603 | 604 | self.layers = nn.ModuleList( 605 | [InternLM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 606 | ) 607 | self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 608 | 609 | self.gradient_checkpointing = False 610 | # Initialize weights and apply final processing 611 | self.post_init() 612 | 613 | def get_input_embeddings(self): 614 | return self.tok_embeddings 615 | 616 | def set_input_embeddings(self, value): 617 | self.tok_embeddings = value 618 | 619 | @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) 620 | def forward( 621 | self, 622 | input_ids: torch.LongTensor = None, 623 | attention_mask: Optional[torch.Tensor] = None, 624 | position_ids: Optional[torch.LongTensor] = None, 625 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 626 | inputs_embeds: Optional[torch.FloatTensor] = None, 627 | use_cache: Optional[bool] = None, 628 | output_attentions: Optional[bool] = None, 629 | output_hidden_states: Optional[bool] = None, 630 | return_dict: Optional[bool] = None, 631 | cache_position: Optional[torch.LongTensor] = None, 632 | ) -> Union[Tuple, BaseModelOutputWithPast]: 633 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 634 | output_hidden_states = ( 635 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 636 | ) 637 | use_cache = use_cache if use_cache is not None else self.config.use_cache 638 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 639 | 640 | if (input_ids is None) ^ (inputs_embeds is not None): 641 | raise ValueError( 642 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 643 | ) 644 | 645 | if self.gradient_checkpointing and self.training and use_cache: 646 | logger.warning_once( 647 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
648 | ) 649 | use_cache = False 650 | 651 | if inputs_embeds is None: 652 | inputs_embeds = self.tok_embeddings(input_ids) 653 | 654 | return_legacy_cache = False 655 | if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) 656 | return_legacy_cache = True 657 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 658 | 659 | if cache_position is None: 660 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 661 | cache_position = torch.arange( 662 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 663 | ) 664 | if position_ids is None: 665 | position_ids = cache_position.unsqueeze(0) 666 | 667 | causal_mask = self._update_causal_mask( 668 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions 669 | ) 670 | 671 | # embed positions 672 | hidden_states = inputs_embeds 673 | 674 | # decoder layers 675 | all_hidden_states = () if output_hidden_states else None 676 | all_self_attns = () if output_attentions else None 677 | next_decoder_cache = None 678 | 679 | for decoder_layer in self.layers: 680 | if output_hidden_states: 681 | all_hidden_states += (hidden_states,) 682 | 683 | if self.gradient_checkpointing and self.training: 684 | layer_outputs = self._gradient_checkpointing_func( 685 | decoder_layer.__call__, 686 | hidden_states, 687 | causal_mask, 688 | position_ids, 689 | past_key_values, 690 | output_attentions, 691 | use_cache, 692 | cache_position, 693 | ) 694 | else: 695 | layer_outputs = decoder_layer( 696 | hidden_states, 697 | attention_mask=causal_mask, 698 | position_ids=position_ids, 699 | past_key_value=past_key_values, 700 | output_attentions=output_attentions, 701 | use_cache=use_cache, 702 | cache_position=cache_position, 703 | ) 704 | 705 | hidden_states = layer_outputs[0] 706 | 707 | if use_cache: 708 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 709 | 710 | if output_attentions: 711 | all_self_attns += (layer_outputs[1],) 712 | 713 | hidden_states = self.norm(hidden_states) 714 | 715 | # add hidden states from the last decoder layer 716 | if output_hidden_states: 717 | all_hidden_states += (hidden_states,) 718 | 719 | next_cache = next_decoder_cache if use_cache else None 720 | if return_legacy_cache: 721 | next_cache = next_cache.to_legacy_cache() 722 | 723 | if not return_dict: 724 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 725 | return BaseModelOutputWithPast( 726 | last_hidden_state=hidden_states, 727 | past_key_values=next_cache, 728 | hidden_states=all_hidden_states, 729 | attentions=all_self_attns, 730 | ) 731 | def _update_causal_mask( 732 | self, 733 | attention_mask: torch.Tensor, 734 | input_tensor: torch.Tensor, 735 | cache_position: torch.Tensor, 736 | past_key_values: Cache, 737 | output_attentions: bool, 738 | ): 739 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 740 | using_static_cache = isinstance(past_key_values, StaticCache) 741 | 742 | dtype, device = input_tensor.dtype, input_tensor.device 743 | min_dtype = torch.finfo(dtype).min 744 | sequence_length = input_tensor.shape[1] 745 | if using_static_cache: 746 | target_length = past_key_values.get_max_length() 747 | else: 748 | target_length = ( 749 | attention_mask.shape[-1] 750 | if isinstance(attention_mask, torch.Tensor) 751 | else past_seen_tokens + sequence_length + 1 752 | ) 753 | 754 | if 
self.config.attn_implementation == "flex_attention": 755 | assert attention_mask is not None and attention_mask.dim() == 4 756 | attention_mask = attention_mask.to(dtype).masked_fill(attention_mask == 0, min_dtype) 757 | return attention_mask 758 | 759 | if attention_mask is not None and attention_mask.dim() == 4: 760 | # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing 761 | if attention_mask.max() != 0: 762 | raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") 763 | causal_mask = attention_mask 764 | else: 765 | causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) 766 | if sequence_length != 1: 767 | if support_bf16_triu or dtype == torch.float32: 768 | causal_mask = torch.triu(causal_mask, diagonal=1) 769 | else: 770 | triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool() 771 | causal_mask.masked_fill_(~triu_mask, 0) 772 | causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) 773 | causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) 774 | if attention_mask is not None: 775 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit 776 | mask_length = attention_mask.shape[-1] 777 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] 778 | padding_mask = padding_mask == 0 779 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( 780 | padding_mask, min_dtype 781 | ) 782 | 783 | return causal_mask 784 | 785 | 786 | # Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM 787 | class InternLM2ForCausalLM(InternLM2PreTrainedModel): 788 | """Causal language model (CLM) for InternLM2.""" 789 | 790 | _auto_class = "AutoModelForCausalLM" 791 | _tied_weights_keys = ["output.weight"] 792 | 793 | def __init__(self, config): 794 | super().__init__(config) 795 | self.model = InternLM2Model(config) 796 | self.vocab_size = config.vocab_size 797 | self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 798 | 799 | # Initialize weights and apply final processing 800 | self.post_init() 801 | 802 | def get_input_embeddings(self): 803 | return self.model.tok_embeddings 804 | 805 | def set_input_embeddings(self, value): 806 | self.model.tok_embeddings = value 807 | 808 | def get_output_embeddings(self): 809 | return self.output 810 | 811 | def set_output_embeddings(self, new_embeddings): 812 | self.output = new_embeddings 813 | 814 | def set_decoder(self, decoder): 815 | self.model = decoder 816 | 817 | def get_decoder(self): 818 | return self.model 819 | 820 | @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) 821 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 822 | def forward( 823 | self, 824 | input_ids: Optional[torch.LongTensor] = None, 825 | attention_mask: Optional[torch.Tensor] = None, 826 | position_ids: Optional[torch.LongTensor] = None, 827 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 828 | inputs_embeds: Optional[torch.FloatTensor] = None, 829 | labels: Optional[torch.LongTensor] = None, 830 | use_cache: Optional[bool] = None, 831 | output_attentions: Optional[bool] = None, 832 | output_hidden_states: Optional[bool] = None, 833 | return_dict: Optional[bool] = None, 834 | cache_position: Optional[torch.LongTensor] = None, 835 | ) 
-> Union[Tuple, CausalLMOutputWithPast]: 836 | r""" 837 | Args: 838 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 839 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 840 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 841 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 842 | 843 | Returns: 844 | 845 | Example: 846 | 847 | ```python 848 | >>> from transformers import AutoTokenizer, InternLM2ForCausalLM 849 | 850 | >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") 851 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") 852 | 853 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 854 | >>> inputs = tokenizer(prompt, return_tensors="pt") 855 | 856 | >>> # Generate 857 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 858 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 859 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 860 | ```""" 861 | 862 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 863 | output_hidden_states = ( 864 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 865 | ) 866 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 867 | 868 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 869 | outputs = self.model( 870 | input_ids=input_ids, 871 | attention_mask=attention_mask, 872 | position_ids=position_ids, 873 | past_key_values=past_key_values, 874 | inputs_embeds=inputs_embeds, 875 | use_cache=use_cache, 876 | output_attentions=output_attentions, 877 | output_hidden_states=output_hidden_states, 878 | return_dict=return_dict, 879 | cache_position=cache_position, 880 | ) 881 | 882 | hidden_states = outputs[0] 883 | if self.config.pretraining_tp > 1: 884 | output_slices = self.output.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 885 | logits = [ 886 | F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable 887 | for i in range(self.config.pretraining_tp) 888 | ] 889 | logits = torch.cat(logits, dim=-1) 890 | else: 891 | logits = self.output(hidden_states) 892 | logits = logits.float() 893 | 894 | loss = None 895 | if labels is not None: 896 | # Shift so that tokens < n predict n 897 | shift_logits = logits[..., :-1, :].contiguous() 898 | shift_labels = labels[..., 1:].contiguous() 899 | # Flatten the tokens 900 | loss_fct = CrossEntropyLoss() 901 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 902 | shift_labels = shift_labels.view(-1) 903 | # Enable model parallelism 904 | shift_labels = shift_labels.to(shift_logits.device) 905 | loss = loss_fct(shift_logits, shift_labels) 906 | 907 | if not return_dict: 908 | output = (logits,) + outputs[1:] 909 | return (loss,) + output if loss is not None else output 910 | 911 | return CausalLMOutputWithPast( 912 | loss=loss, 913 | logits=logits, 914 | past_key_values=outputs.past_key_values, 915 | hidden_states=outputs.hidden_states, 916 | attentions=outputs.attentions, 917 | ) --------------------------------------------------------------------------------
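The rotary-embedding helper in `modeling_internlm2.py` above applies `cos`/`sin` tables of shape `[batch_size, seq_len, head_dim]` to queries and keys of shape `[batch_size, heads, seq_len, head_dim]`, using `unsqueeze_dim=1` to insert the missing head axis. The sketches that follow are not part of the repository; they only restate individual pieces of the file with toy values. First, the RoPE broadcasting itself (shapes and the base-10000 frequency table are illustrative, not taken from this model's config):

```python
import torch

def rotate_half(x):
    # Standalone stand-in for the file's rotate_half: swap the two halves of the
    # last dimension and negate the second half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# cos/sin come from the rotary embedding module with shape [batch, seq_len, head_dim].
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                        # [seq_len, head_dim]
cos = emb.cos().expand(batch, -1, -1)                          # [batch, seq_len, head_dim]
sin = emb.sin().expand(batch, -1, -1)

# unsqueeze_dim=1 inserts the head axis so cos/sin broadcast over q and k.
q_embed = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
k_embed = k * cos.unsqueeze(1) + rotate_half(k) * sin.unsqueeze(1)
print(q_embed.shape, k_embed.shape)   # torch.Size([2, 4, 6, 8]) twice
```

With a `[batch_size, seq_len, heads, head_dim]` layout the same formula works with `unsqueeze_dim=2`, exactly as the docstring notes.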
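`InternLM2MLP` above is a gated MLP: two parallel up-projections `w1` and `w3`, an activation on the `w1` branch, an elementwise product, and a down-projection `w2`. A short sketch with made-up sizes; SiLU is assumed here, while the real activation is selected by `config.hidden_act`:

```python
import torch
from torch import nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 44
w1 = nn.Linear(hidden_size, intermediate_size, bias=False)   # gate projection
w3 = nn.Linear(hidden_size, intermediate_size, bias=False)   # up projection
w2 = nn.Linear(intermediate_size, hidden_size, bias=False)   # down projection

x = torch.randn(2, 5, hidden_size)       # (batch, seq_len, hidden_size)
y = w2(F.silu(w1(x)) * w3(x))            # same formula as InternLM2MLP.forward
print(y.shape)                           # torch.Size([2, 5, 16])
```

The two up-projections make this the gated (SwiGLU-style) variant rather than a plain two-layer MLP.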
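`repeat_kv` expands the grouped key/value heads so they line up with the query heads, and its docstring states that it matches `torch.repeat_interleave` on the head dimension. A quick standalone check of that equivalence with toy shapes:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper above: insert a repeat axis, expand, then fold it
    # into the head dimension.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 8, 10, 64)                      # (batch, kv_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)                   # -> (batch, 32, seq_len, head_dim)
reference = torch.repeat_interleave(kv, repeats=4, dim=1)
print(expanded.shape, torch.equal(expanded, reference))   # torch.Size([2, 32, 10, 64]) True
```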
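In `InternLM2Attention.forward` above, a single `wqkv` projection packs queries, keys and values per key/value group: the output is viewed as `(b, q, h, gs, d)` with `gs = 2 + num_key_value_groups`, the first `num_key_value_groups` slots of each group are query heads, and the last two are the key and the value. A shape-only sketch of that unpacking with toy sizes; `einops.rearrange` is replaced by plain `view`/`reshape` so the snippet needs no extra dependency:

```python
import torch

bsz, q_len = 2, 5
num_heads, num_key_value_heads, head_dim = 8, 2, 16
num_key_value_groups = num_heads // num_key_value_heads        # 4 query heads per kv head
gs = 2 + num_key_value_groups                                  # group size: queries + key + value

packed_dim = (num_heads + 2 * num_key_value_heads) * head_dim  # width of the wqkv output
qkv_states = torch.randn(bsz, q_len, packed_dim)

# "b q (h gs d) -> b q h gs d"
qkv_states = qkv_states.view(bsz, q_len, num_key_value_heads, gs, head_dim)

query_states = qkv_states[..., :num_key_value_groups, :]       # (b, q, kv_heads, groups, d)
query_states = query_states.reshape(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = qkv_states[..., -2, :].transpose(1, 2)            # (b, kv_heads, q, d)
value_states = qkv_states[..., -1, :].transpose(1, 2)

print(query_states.shape, key_states.shape, value_states.shape)
# torch.Size([2, 8, 5, 16]) torch.Size([2, 2, 5, 16]) torch.Size([2, 2, 5, 16])
```

The slicing convention (queries first, then key, then value within each group) has to match how the pretrained `wqkv` weights are laid out, which is why the model unpacks with this exact pattern rather than using three separate projections.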
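The eager attention path above computes `softmax(QKᵀ/√d + mask)·V` with the softmax upcast to fp32. A minimal reproduction of those lines, checked against `torch.nn.functional.scaled_dot_product_attention` (toy shapes; no grouped heads, KV cache, or padding here):

```python
import math
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 6, 32
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Additive causal mask: 0 on and below the diagonal, the dtype minimum above it.
min_value = torch.finfo(q.dtype).min
causal_mask = torch.full((q_len, q_len), min_value).triu(diagonal=1)[None, None]

attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)

reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(attn_output, reference, atol=1e-5))   # True up to numerics
```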
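`InternLM2Model.forward` above accepts either a `Cache` object or the legacy tuple format, wraps the latter via `DynamicCache.from_legacy_cache`, and derives `cache_position` from how many tokens are already cached. A small sketch of that bookkeeping, assuming a `transformers` version that exposes `DynamicCache` with this interface (layer count, shapes, and hidden size are toy values):

```python
import torch
from transformers import DynamicCache

past_key_values = DynamicCache()

# Pretend layer 0 has already cached 4 tokens of keys/values.
k = torch.randn(1, 8, 4, 64)    # (batch, kv_heads, cached_len, head_dim)
v = torch.randn(1, 8, 4, 64)
past_key_values.update(k, v, layer_idx=0)

# Two fresh tokens arrive; cache_position continues where the cache left off,
# mirroring the logic in InternLM2Model.forward.
inputs_embeds = torch.randn(1, 2, 2048)
past_seen_tokens = past_key_values.get_seq_length()
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
position_ids = cache_position.unsqueeze(0)

print(past_seen_tokens, cache_position, position_ids.shape)
# 4 tensor([4, 5]) torch.Size([1, 2])
```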
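The default branch of `_update_causal_mask` above fills a `(seq, seq)` matrix with the dtype minimum above the diagonal, expands it to `(batch, 1, seq, seq)`, and then also pushes positions that the 2D padding mask marks as padding down to the minimum. A condensed sketch of that recipe; the `cache_position` handling used during generation is omitted, and the batch contents are made up:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
batch, seq_len = 2, 5

# 2D padding mask: 1 = real token, 0 = padding (second sequence is padded on the left).
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

# Causal part: min_dtype strictly above the diagonal, 0 on and below it.
causal_mask = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
causal_mask = causal_mask[None, None].expand(batch, 1, -1, -1).clone()

# Padding part: where the causal mask allows attention (0) but the key token is padding (0),
# the sum is exactly 0, and that position is also pushed to min_dtype.
padding_mask = causal_mask + attention_mask[:, None, None, :].to(dtype)
causal_mask = causal_mask.masked_fill(padding_mask == 0, min_dtype)

print((causal_mask == 0).squeeze(1).int())   # 1 where attention is allowed
```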
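`InternLM2FlexAttention` above is documented as letting each token attend to its own block and to all preceding blocks, and the `flex_attention` branch of `_update_causal_mask` expects a ready-made 4D mask whose zero entries mark blocked positions (those zeros are replaced by the dtype minimum before the mask is added to the scores). A sketch of how such a block-causal mask could be assembled from per-token block ids; the block layout below is purely hypothetical, not read from this repository's data pipeline:

```python
import torch

# Hypothetical layout: 4 prompt/vision tokens in block 0, 3 state tokens in block 1,
# 3 action tokens in block 2. A token may attend within its own block and to every
# earlier block, but never to a later block.
block_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
allowed = block_ids[None, :] <= block_ids[:, None]          # (seq, seq), True = may attend

# 4D mask as consumed by the flex_attention branch: nonzero = allowed, 0 = blocked.
attention_mask = allowed[None, None].to(torch.float32)      # (batch=1, 1, seq, seq)
min_dtype = torch.finfo(torch.float32).min
additive_mask = attention_mask.masked_fill(attention_mask == 0, min_dtype)

print(allowed.int())   # block-lower-triangular pattern
```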
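The loss in `InternLM2ForCausalLM.forward` above shifts logits and labels by one position so that tokens `< n` predict token `n`, and label indices set to `-100` are excluded from the cross-entropy, as the docstring describes. A toy version of exactly that shift (vocabulary size and labels are made up):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 11, 6
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.tensor([[-100, -100, 7, 3, 9, 2]])   # -100 marks positions to ignore

# Shift so that the logits at position t are scored against the label at t + 1.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)   # ignore_index defaults to -100
print(loss)
```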