├── multimodal_ropes ├── __init__.py ├── freq_allocation │ ├── __init__.py │ ├── videorope.py │ ├── mrope_i.py │ ├── mhrope.py │ ├── hope.py │ ├── vanilla.py │ └── mrope.py ├── pos_design │ ├── __init__.py │ ├── vanilla.py │ ├── mrope.py │ ├── videorope.py │ ├── hope.py │ ├── mrope_i.py │ └── circlerope.py ├── configs │ ├── __init__.py │ ├── vanilla.py │ ├── videorope.py │ ├── mrope.py │ ├── mrope_i.py │ ├── hope.py │ ├── mhrope.py │ └── circlerope.py └── entry.py ├── pyproject.toml ├── LICENSE ├── test.py └── README.md /multimodal_ropes/__init__.py: -------------------------------------------------------------------------------- 1 | from .entry import get_multimodal_rope, get_multimodal_rope_config 2 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/__init__.py: -------------------------------------------------------------------------------- 1 | from .vanilla import RopeEmbedding 2 | from .mrope import MRopeEmbedding 3 | from .mrope_i import MRopeInterleaveEmbedding 4 | from .mhrope import MHRopeEmbedding 5 | from .videorope import VideoRopeEmbedding 6 | from .hope import HopeEmbedding 7 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/__init__.py: -------------------------------------------------------------------------------- 1 | from .vanilla import get_vanilla_rope_index 2 | from .mrope import get_mrope_index 3 | from .mrope_i import get_mrope_interleave_index 4 | from .videorope import get_videorope_index 5 | from .hope import get_hope_index 6 | from .circlerope import get_circlerope_index 7 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .vanilla import VanillaRopeConfig 2 | from .mrope import MRopeConfig 3 | from .mrope_i import MRopeInterleaveConfig 4 | from .mhrope import MHRopeConfig 5 | from .videorope import VideoRopeConfig 6 | from .hope import HopeConfig 7 | from .circlerope import CircleRopeConfig 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "multimodal-ropes" 7 | version = "0.1.0" 8 | description = "Multimodal RoPE implementations" 9 | dependencies = [ 10 | "torch", 11 | "transformers==4.57.1" 12 | ] 13 | 14 | [tool.setuptools.packages.find] 15 | where = ["."] 16 | include = ["multimodal_ropes*"] -------------------------------------------------------------------------------- /multimodal_ropes/configs/vanilla.py: -------------------------------------------------------------------------------- 1 | class VanillaRopeConfig: 2 | """ 3 | Configuration class for RoPE. 
4 | """ 5 | 6 | def __init__(self, dim: int, base: int = 10000, **kwargs): 7 | self.name = "vanilla-rope" 8 | self.dim = dim 9 | self.base = base 10 | 11 | def __repr__(self): 12 | attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items()) 13 | return f"{self.__class__.__name__}({attrs})" 14 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/videorope.py: -------------------------------------------------------------------------------- 1 | from .mrope import MRopeConfig 2 | 3 | 4 | class VideoRopeConfig(MRopeConfig): 5 | """ 6 | Configuration class for VideoRoPE. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | dim: int, 12 | base: int = 10000, 13 | mrope_section: list[int] = [16, 24, 24], 14 | temporal_stride: float = 2.0, 15 | **kwargs, 16 | ): 17 | super().__init__(dim, base, mrope_section, temporal_stride, **kwargs) 18 | self.name = "videorope" 19 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/mrope.py: -------------------------------------------------------------------------------- 1 | from .vanilla import VanillaRopeConfig 2 | 3 | 4 | class MRopeConfig(VanillaRopeConfig): 5 | """ 6 | Configuration class for MRoPE/MRoPE-I. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | dim: int, 12 | base: int = 10000, 13 | mrope_section: list[int] = [16, 24, 24], 14 | temporal_stride: float = 2.0, 15 | **kwargs, 16 | ): 17 | super().__init__(dim, base, **kwargs) 18 | self.name = "mrope" 19 | self.mrope_section = mrope_section 20 | self.temporal_stride = temporal_stride 21 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/mrope_i.py: -------------------------------------------------------------------------------- 1 | from .mrope import MRopeConfig 2 | 3 | 4 | class MRopeInterleaveConfig(MRopeConfig): 5 | """ 6 | Configuration class for MRoPE/MRoPE-I. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | dim: int, 12 | base: int = 10000, 13 | mrope_section: list[int] = [24, 20, 20], 14 | temporal_stride: float = 1.0, 15 | spatial_reset: bool = False, 16 | **kwargs, 17 | ): 18 | super().__init__(dim, base, mrope_section, temporal_stride, **kwargs) 19 | self.name = "mrope-interleave" 20 | self.spatial_reset = spatial_reset 21 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/hope.py: -------------------------------------------------------------------------------- 1 | from .mrope import MRopeConfig 2 | 3 | 4 | class HopeConfig(MRopeConfig): 5 | """ 6 | Configuration class for HoPE. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | dim: int, 12 | base: int = 10000, 13 | mrope_section: list[int] = [16, 24, 24], 14 | temporal_stride: float = 2.0, 15 | temporal_stride_lst: list[float] = [0.5, 0.75, 1.0, 1.25, 1.5], 16 | **kwargs, 17 | ): 18 | super().__init__(dim, base, mrope_section, temporal_stride, **kwargs) 19 | self.name = "hope" 20 | self.temporal_stride_lst = temporal_stride_lst 21 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/mhrope.py: -------------------------------------------------------------------------------- 1 | from .mrope_i import MRopeInterleaveConfig 2 | 3 | 4 | class MHRopeConfig(MRopeInterleaveConfig): 5 | """ 6 | Configuration class for MRoPE/MRoPE-I. 
7 | """ 8 | 9 | def __init__( 10 | self, 11 | dim: int, 12 | base: int = 10000, 13 | mrope_section: list[int] = [2, 3, 3], 14 | temporal_stride: float = 1.0, 15 | num_key_value_heads: int = 8, 16 | **kwargs, 17 | ): 18 | """ 19 | Configuration class for MHRoPE. 20 | mrope_section means the number of heads for each dimension, like [2, 3, 3] for T, H, W. 21 | """ 22 | super().__init__(dim, base, mrope_section, temporal_stride, **kwargs) 23 | self.name = "mhrope" 24 | self.num_key_value_heads = num_key_value_heads 25 | -------------------------------------------------------------------------------- /multimodal_ropes/configs/circlerope.py: -------------------------------------------------------------------------------- 1 | from .mrope import MRopeConfig 2 | 3 | 4 | class CircleRopeConfig(MRopeConfig): 5 | """ 6 | Configuration class for CircleRoPE. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | dim: int, 12 | base: int = 10000, 13 | mrope_section: list[int] = [16, 24, 24], 14 | temporal_stride: float = 1.0, 15 | move_to_origin: bool = False, 16 | move_to_positive: bool = False, 17 | dff_rate: bool = False, 18 | method: str = "circle", 19 | radius: float = -1, 20 | alpha: float = -1, 21 | **kwargs, 22 | ): 23 | super().__init__(dim, base, mrope_section, temporal_stride, **kwargs) 24 | self.name = "circlerope" 25 | self.move_to_origin = move_to_origin 26 | self.move_to_positive = move_to_positive 27 | self.dff_rate = dff_rate 28 | self.method = method 29 | self.radius = radius 30 | self.alpha = alpha 31 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/videorope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..configs.videorope import VideoRopeConfig 4 | from .mrope import MRopeEmbedding 5 | 6 | 7 | class VideoRopeEmbedding(MRopeEmbedding): 8 | def __init__(self, config, device=None, extra_config: VideoRopeConfig = None): 9 | super().__init__(config, device, extra_config) 10 | 11 | def apply_transformation(self, freqs, mrope_section): 12 | """Follow the order of hwhwhwhwtttt... to reorganize the frequency layout. 13 | args: 14 | x: (3, bs, seq_len, head_dim // 2) 15 | mrope_section: (3,) 16 | returns: 17 | x_t: (bs, seq_len, head_dim // 2) 18 | """ 19 | freqs_t = freqs[0] # just overwrite the first dimension T 20 | for dim, offset in enumerate((1, 2), start=1): # H, W 21 | length = mrope_section[dim] * 2 22 | idx = slice(offset, length, 2) 23 | freqs_t[..., idx] = freqs[dim, ..., idx] 24 | return freqs_t 25 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/mrope_i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..configs.mrope_i import MRopeInterleaveConfig 4 | from .mrope import MRopeEmbedding 5 | 6 | 7 | class MRopeInterleaveEmbedding(MRopeEmbedding): 8 | def __init__(self, config, device=None, extra_config: MRopeInterleaveConfig = None): 9 | super().__init__(config, device, extra_config) 10 | 11 | def apply_transformation(self, freqs, mrope_section): 12 | """Apply interleaved MRoPE to 3D rotary embeddings. 13 | Reorganizes frequency layout from chunked [TTT...HHH...WWW] to 14 | interleaved [THWTHWHTHW...TT], preserving frequency continuity. 
15 | args: 16 | x: (3, bs, seq_len, head_dim // 2) 17 | mrope_section: (3,) 18 | returns: 19 | x_t: (bs, seq_len, head_dim // 2) 20 | """ 21 | freqs_t = freqs[0] # just overwrite the first dimension T 22 | for dim, offset in enumerate((1, 2), start=1): # H, W 23 | length = mrope_section[dim] * 3 24 | idx = slice(offset, length, 3) 25 | freqs_t[..., idx] = freqs[dim, ..., idx] 26 | return freqs_t 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Qwen Team, Alibaba Group. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/mhrope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..configs.mhrope import MHRopeConfig 4 | from .mrope import MRopeEmbedding 5 | 6 | 7 | class MHRopeEmbedding(MRopeEmbedding): 8 | def __init__(self, config, device=None, extra_config: MHRopeConfig = None): 9 | super().__init__(config, device, extra_config) 10 | 11 | def apply_transformation(self, freqs, mrope_section): 12 | """Apply Multi-Head RoPE to 3D rotary embeddings. 13 | args: 14 | x: (3, bs, seq_len, head_dim // 2) 15 | mrope_section: (3,) 16 | returns: 17 | x_t: (bs, seq_len, head_dim // 2) 18 | """ 19 | 20 | batch_size, seq_length, dim = freqs.shape[1:] 21 | freqs = torch.cat( 22 | [ 23 | freqs[m, :, None, :, :].repeat(1, num, 1, 1) 24 | for m, num in enumerate(mrope_section) 25 | ], 26 | dim=1, 27 | ) 28 | if ( 29 | sum(mrope_section) < self.extra_config.num_key_value_heads 30 | ): # padding unused heads with zeros, e.g. 
q * cos(embed) + rotate_half(q) * sin(embed) = q 31 | freqs = torch.cat( 32 | [ 33 | freqs, 34 | torch.zeros( 35 | batch_size, 36 | self.extra_config.num_key_value_heads - sum(mrope_section), 37 | seq_length, 38 | dim, 39 | device=freqs.device, 40 | dtype=freqs.dtype, 41 | ), 42 | ], 43 | dim=1, 44 | ) 45 | 46 | return freqs 47 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/hope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.modeling_rope_utils import dynamic_rope_update 3 | 4 | from ..configs.hope import HopeConfig 5 | from .videorope import VideoRopeEmbedding 6 | 7 | 8 | class HopeEmbedding(VideoRopeEmbedding): 9 | def __init__(self, config, device=None, extra_config: HopeConfig = None): 10 | super().__init__(config, device, extra_config) 11 | 12 | # HoPE follow the same frequency allocation as VideoRoPE, but apply NoPE to t dimension. 13 | # DIfferent from the official implementation, We achieve NoPE by setting the t pos to 0. 14 | @torch.no_grad() 15 | @dynamic_rope_update 16 | def forward(self, x, position_ids): 17 | if position_ids.ndim == 2: 18 | position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) 19 | 20 | # HoPE: set t pos to 0 21 | position_ids[0].masked_fill_(position_ids[0] > 0, 0) 22 | 23 | inv_freq_expanded = ( 24 | self.inv_freq[None, None, :, None] 25 | .float() 26 | .expand(3, position_ids.shape[1], -1, 1) 27 | ) 28 | position_ids_expanded = position_ids[ 29 | :, :, None, : 30 | ].float() # shape (3, bs, 1, positions) 31 | 32 | device_type = ( 33 | x.device.type 34 | if isinstance(x.device.type, str) and x.device.type != "mps" 35 | else "cpu" 36 | ) 37 | with torch.autocast(device_type=device_type, enabled=False): # Force float32 38 | freqs = ( 39 | inv_freq_expanded.float() @ position_ids_expanded.float() 40 | ).transpose(2, 3) 41 | freqs = self.apply_transformation( 42 | freqs, self.extra_config.mrope_section 43 | ) # use extra_config to avoid confusion with self.mrope_section 44 | emb = torch.cat((freqs, freqs), dim=-1) 45 | cos = emb.cos() * self.attention_scaling 46 | sin = emb.sin() * self.attention_scaling 47 | 48 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 49 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from collections.abc import Callable 5 | from typing import Optional 6 | from transformers.modeling_rope_utils import dynamic_rope_update, ROPE_INIT_FUNCTIONS 7 | 8 | from ..configs.vanilla import VanillaRopeConfig 9 | 10 | 11 | class RopeEmbedding(nn.Module): 12 | inv_freq: torch.Tensor 13 | 14 | def __init__(self, config, device=None, extra_config: VanillaRopeConfig = None): 15 | super().__init__() 16 | self.max_seq_len_cached = config.max_position_embeddings 17 | self.original_max_seq_len = config.max_position_embeddings 18 | 19 | self.config = config 20 | self.extra_config = extra_config 21 | 22 | self.rope_type = self.config.rope_scaling["rope_type"] 23 | self.rope_init_fn: Callable = ROPE_INIT_FUNCTIONS[self.rope_type] 24 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) 25 | 26 | self.register_buffer("inv_freq", inv_freq, persistent=False) 27 | self.original_inv_freq = inv_freq 28 | 29 | @torch.no_grad() 30 | @dynamic_rope_update 31 | def 
forward(self, x, position_ids): 32 | if position_ids.ndim == 3: 33 | position_ids = position_ids[ 34 | 0 35 | ] # vanilla rope does not have multiple position ids for one token 36 | inv_freq_expanded = ( 37 | self.inv_freq[None, :, None] 38 | .float() 39 | .expand(position_ids.shape[0], -1, 1) 40 | .to(x.device) 41 | ) 42 | position_ids_expanded = position_ids[:, None, :].float() 43 | 44 | device_type = ( 45 | x.device.type 46 | if isinstance(x.device.type, str) and x.device.type != "mps" 47 | else "cpu" 48 | ) 49 | with torch.autocast(device_type=device_type, enabled=False): # Force float32 50 | freqs = ( 51 | inv_freq_expanded.float() @ position_ids_expanded.float() 52 | ).transpose(1, 2) 53 | emb = torch.cat((freqs, freqs), dim=-1) 54 | cos = emb.cos() * self.attention_scaling 55 | sin = emb.sin() * self.attention_scaling 56 | 57 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 58 | -------------------------------------------------------------------------------- /multimodal_ropes/entry.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from .configs import * 4 | from .pos_design import * 5 | from .freq_allocation import * 6 | 7 | import logging 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | 11 | SUPPORT_MM_ROPES = [ 12 | "vanilla-rope", 13 | "mrope", 14 | "mrope-interleave", 15 | "mhrope", 16 | "videorope", 17 | "hope", 18 | "circlerope", 19 | ] 20 | 21 | 22 | MAPPINGS_NAME_TO_CONFIG = { 23 | "vanilla-rope": VanillaRopeConfig, 24 | "mrope": MRopeConfig, 25 | "mrope-interleave": MRopeInterleaveConfig, 26 | "mhrope": MHRopeConfig, 27 | "videorope": VideoRopeConfig, 28 | "hope": HopeConfig, 29 | "circlerope": CircleRopeConfig, 30 | } 31 | 32 | 33 | MAPPINGS_NAME_TO_POS_DESIGN = { 34 | "vanilla-rope": get_vanilla_rope_index, 35 | "mrope": get_mrope_index, 36 | "mrope-interleave": get_mrope_interleave_index, 37 | "mhrope": get_mrope_interleave_index, 38 | "videorope": get_videorope_index, 39 | "hope": get_hope_index, 40 | "circlerope": get_circlerope_index, 41 | } 42 | 43 | 44 | MAPPINGS_NAME_TO_FREQ_ALLOCATION = { 45 | "vanilla-rope": RopeEmbedding, 46 | "mrope": MRopeEmbedding, 47 | "mrope-interleave": MRopeInterleaveEmbedding, 48 | "mhrope": MHRopeEmbedding, 49 | "videorope": VideoRopeEmbedding, 50 | "hope": HopeEmbedding, 51 | "circlerope": MRopeEmbedding, 52 | } 53 | 54 | 55 | def get_multimodal_rope_config(rope_name: str, **kwargs) -> VanillaRopeConfig: 56 | assert rope_name in SUPPORT_MM_ROPES, f"RoPE type {rope_name} not supported." 57 | rope_config_class = MAPPINGS_NAME_TO_CONFIG[rope_name] 58 | return rope_config_class(**kwargs) 59 | 60 | 61 | def get_multimodal_rope(rope_name: str, *args, **kwargs): 62 | assert rope_name in SUPPORT_MM_ROPES, f"RoPE type {rope_name} not supported." 
63 | 64 | rope_config_class = MAPPINGS_NAME_TO_CONFIG[rope_name] 65 | config = rope_config_class(*args, **kwargs) 66 | 67 | logging.info(f"Config: {config}") 68 | 69 | pos_design_func = MAPPINGS_NAME_TO_POS_DESIGN[config.name] 70 | freq_allocation_class = MAPPINGS_NAME_TO_FREQ_ALLOCATION[config.name] 71 | 72 | def patched_pos_design_func(*args, **kwargs): 73 | return pos_design_func(extra_config=config, *args, **kwargs) 74 | 75 | rope_embed_factory = partial(freq_allocation_class, extra_config=config) 76 | 77 | return patched_pos_design_func, rope_embed_factory 78 | -------------------------------------------------------------------------------- /multimodal_ropes/freq_allocation/mrope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers.modeling_rope_utils import dynamic_rope_update 4 | 5 | from ..configs.mrope import MRopeConfig 6 | from .vanilla import RopeEmbedding 7 | 8 | 9 | class MRopeEmbedding(RopeEmbedding): 10 | def __init__(self, config, device=None, extra_config: MRopeConfig = None): 11 | super().__init__(config, device, extra_config) 12 | 13 | self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) 14 | 15 | @torch.no_grad() 16 | @dynamic_rope_update 17 | def forward(self, x, position_ids): 18 | # In contrast to other models, Qwen3VL has different position ids for the grids 19 | # So we expand the inv_freq to shape (3, ...) 20 | if position_ids.ndim == 2: 21 | position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) 22 | inv_freq_expanded = ( 23 | self.inv_freq[None, None, :, None] 24 | .float() 25 | .expand(3, position_ids.shape[1], -1, 1) 26 | ) 27 | position_ids_expanded = position_ids[ 28 | :, :, None, : 29 | ].float() # shape (3, bs, 1, positions) 30 | 31 | device_type = ( 32 | x.device.type 33 | if isinstance(x.device.type, str) and x.device.type != "mps" 34 | else "cpu" 35 | ) 36 | with torch.autocast(device_type=device_type, enabled=False): # Force float32 37 | freqs = ( 38 | inv_freq_expanded.float() @ position_ids_expanded.float() 39 | ).transpose(2, 3) 40 | freqs = self.apply_transformation( 41 | freqs, self.extra_config.mrope_section 42 | ) # use extra_config to avoid confusion with self.mrope_section 43 | emb = torch.cat((freqs, freqs), dim=-1) 44 | cos = emb.cos() * self.attention_scaling 45 | sin = emb.sin() * self.attention_scaling 46 | 47 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 48 | 49 | def apply_transformation(self, freqs, mrope_section): 50 | """Apply MRoPE to 3D rotary embeddings. 
51 | args: 52 | x: (3, bs, seq_len, head_dim // 2) 53 | mrope_section: (3,) 54 | returns: 55 | x_t: (bs, seq_len, head_dim // 2) 56 | """ 57 | freqs = torch.cat( 58 | [m[i % 3] for i, m in enumerate(freqs.split(mrope_section, dim=-1))], dim=-1 59 | ) 60 | return freqs 61 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from transformers import AutoProcessor, Qwen3VLModel, Qwen3VLForConditionalGeneration 4 | from transformers.models.qwen3_vl.modeling_qwen3_vl import ( 5 | apply_rotary_pos_emb as original_apply_rotary_pos_emb, 6 | ) 7 | from multimodal_ropes.entry import get_multimodal_rope 8 | 9 | import logging 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | 14 | def rotate_half(x): 15 | """Rotates half the hidden dims of the input.""" 16 | x1 = x[..., : x.shape[-1] // 2] 17 | x2 = x[..., x.shape[-1] // 2 :] 18 | return torch.cat((-x2, x1), dim=-1) 19 | 20 | 21 | def apply_multihead_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 22 | # repeat interleave for GQA, (num_heads / num_key_value_heads) querys -> 1 key, repeat freqs of key 23 | # q [bs, num_heads, seq_len, head_dim] 24 | # k [bs, num_key_value_heads, seq_len, head_dim] 25 | # cos, sin [bs, num_key_value_heads, seq_len, head_dim] -> [bs, num_heads, seq_len, head_dim] 26 | n_repeat = q.shape[1] // cos.shape[1] 27 | q_embed = (q * cos.repeat_interleave(n_repeat, dim=1)) + ( 28 | rotate_half(q) * sin.repeat_interleave(n_repeat, dim=1) 29 | ) 30 | k_embed = (k * cos) + (rotate_half(k) * sin) 31 | return q_embed, k_embed 32 | 33 | 34 | def monkey_patch_qwen3vl(rope_name, **kwargs): 35 | rope_index, rope_embed = get_multimodal_rope(rope_name, **kwargs) 36 | 37 | logging.info(f"Begin to patch Qwen3VLModel with {rope_name}.") 38 | Qwen3VLModel.get_rope_index = rope_index 39 | transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding = ( 40 | rope_embed 41 | ) 42 | 43 | if rope_name == "mhrope": 44 | # patch apply_rotary_embed 45 | transformers.models.qwen3_vl.modeling_qwen3_vl.apply_rotary_pos_emb = ( 46 | apply_multihead_rotary_pos_emb 47 | ) 48 | logging.info( 49 | "MHRoPE: Patched apply_rotary_pos_emb with apply_multihead_rotary_pos_emb." 
50 | ) 51 | 52 | logging.info(f"Patched Qwen3VLModel with {rope_name}.") 53 | 54 | test_forward() 55 | 56 | if rope_name == "mhrope": 57 | transformers.models.qwen3_vl.modeling_qwen3_vl.apply_rotary_pos_emb = ( 58 | original_apply_rotary_pos_emb 59 | ) 60 | logging.info("MHRoPE: Restored original apply_rotary_pos_emb.") 61 | 62 | 63 | def test_forward(): 64 | ckpt = "Qwen/Qwen3-VL-2B-Instruct" 65 | model = Qwen3VLForConditionalGeneration.from_pretrained( 66 | ckpt, 67 | attn_implementation="flash_attention_2", 68 | device_map="auto", 69 | dtype=torch.bfloat16, 70 | ) 71 | processor = AutoProcessor.from_pretrained(ckpt) 72 | 73 | messages = [ 74 | { 75 | "role": "user", 76 | "content": [ 77 | { 78 | "type": "image", 79 | "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", 80 | }, 81 | {"type": "text", "text": "Describe this image."}, 82 | ], 83 | } 84 | ] 85 | 86 | # Preparation for inference 87 | inputs = processor.apply_chat_template( 88 | messages, 89 | tokenize=True, 90 | add_generation_prompt=True, 91 | return_dict=True, 92 | return_tensors="pt", 93 | ) 94 | inputs = inputs.to(model.device) 95 | 96 | # Inference: Generation of the output 97 | generated_ids = model.generate(**inputs, max_new_tokens=128) 98 | generated_ids_trimmed = [ 99 | out_ids[len(in_ids) :] 100 | for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 101 | ] 102 | output_text = processor.batch_decode( 103 | generated_ids_trimmed, 104 | skip_special_tokens=True, 105 | clean_up_tokenization_spaces=False, 106 | ) 107 | print(output_text) 108 | 109 | 110 | if __name__ == "__main__": 111 | common_kwargs = dict( 112 | dim=128, 113 | base=5000000, 114 | ) 115 | monkey_patch_qwen3vl("vanilla-rope", **common_kwargs) 116 | 117 | # MRoPE, MRoPE-I, MHRoPE 118 | monkey_patch_qwen3vl( 119 | "mrope", mrope_section=[16, 24, 24], temporal_stride=1, **common_kwargs 120 | ) 121 | monkey_patch_qwen3vl( 122 | "mrope-interleave", 123 | mrope_section=[24, 20, 20], 124 | temporal_stride=1, 125 | spatial_reset=True, 126 | **common_kwargs, 127 | ) 128 | monkey_patch_qwen3vl( 129 | "mhrope", 130 | num_key_value_heads=8, 131 | mrope_section=[2, 3, 3], 132 | temporal_stride=1, 133 | spatial_reset=True, 134 | **common_kwargs, 135 | ) 136 | 137 | # VideoRoPE and HoPE, temporal_stride is a float 138 | monkey_patch_qwen3vl( 139 | "videorope", mrope_section=[16, 24, 24], temporal_stride=2.0, **common_kwargs 140 | ) 141 | monkey_patch_qwen3vl( 142 | "hope", 143 | mrope_section=[16, 24, 24], 144 | temporal_stride=2.0, 145 | temporal_stride_lst=[0.5, 0.75, 1.0, 1.25, 1.5], 146 | **common_kwargs, 147 | ) 148 | monkey_patch_qwen3vl( 149 | "circlerope", 150 | mrope_section=[16, 24, 24], 151 | temporal_stride=1, 152 | move_to_origin=True, 153 | dff_rate=True, 154 | method="circle", 155 | radius=10, 156 | alpha=0.5, 157 | **common_kwargs, 158 | ) 159 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from ..configs.vanilla import VanillaRopeConfig 5 | 6 | 7 | def get_vanilla_rope_index( 8 | self, 9 | input_ids: Optional[torch.LongTensor] = None, 10 | image_grid_thw: Optional[torch.LongTensor] = None, 11 | video_grid_thw: Optional[torch.LongTensor] = None, 12 | attention_mask: Optional[torch.Tensor] = None, 13 | extra_config: VanillaRopeConfig = None, 14 | ) -> tuple[torch.Tensor, torch.Tensor]: 15 | # qwen3vl use 
timestamps to seperate videos, like , the video_grid_thw should also be split 16 | # if you are using qwen2/2.5vl, please remove them 17 | if video_grid_thw is not None: 18 | video_grid_thw = torch.repeat_interleave( 19 | video_grid_thw, video_grid_thw[:, 0], dim=0 20 | ) 21 | video_grid_thw[:, 0] = 1 22 | 23 | spatial_merge_size = self.config.vision_config.spatial_merge_size 24 | image_token_id = self.config.image_token_id 25 | video_token_id = self.config.video_token_id 26 | vision_start_token_id = self.config.vision_start_token_id 27 | mrope_position_deltas = [] 28 | if input_ids is not None and ( 29 | image_grid_thw is not None or video_grid_thw is not None 30 | ): 31 | total_input_ids = input_ids 32 | if attention_mask is None: 33 | attention_mask = torch.ones_like(total_input_ids) 34 | position_ids = torch.ones( 35 | 3, 36 | input_ids.shape[0], 37 | input_ids.shape[1], 38 | dtype=input_ids.dtype, 39 | device=input_ids.device, 40 | ) 41 | image_index, video_index = 0, 0 42 | attention_mask = attention_mask.to(total_input_ids.device) 43 | for i, input_ids in enumerate(total_input_ids): 44 | input_ids = input_ids[attention_mask[i] == 1] 45 | image_nums, video_nums = 0, 0 46 | vision_start_indices = torch.argwhere( 47 | input_ids == vision_start_token_id 48 | ).squeeze(1) 49 | vision_tokens = input_ids[vision_start_indices + 1] 50 | image_nums = (vision_tokens == image_token_id).sum() 51 | video_nums = (vision_tokens == video_token_id).sum() 52 | input_tokens = input_ids.tolist() 53 | llm_pos_ids_list: list = [] 54 | st = 0 55 | remain_images, remain_videos = image_nums, video_nums 56 | for _ in range(image_nums + video_nums): 57 | if image_token_id in input_tokens and remain_images > 0: 58 | ed_image = input_tokens.index(image_token_id, st) 59 | else: 60 | ed_image = len(input_tokens) + 1 61 | if video_token_id in input_tokens and remain_videos > 0: 62 | ed_video = input_tokens.index(video_token_id, st) 63 | else: 64 | ed_video = len(input_tokens) + 1 65 | if ed_image < ed_video: 66 | t, h, w = ( 67 | image_grid_thw[image_index][0], 68 | image_grid_thw[image_index][1], 69 | image_grid_thw[image_index][2], 70 | ) 71 | image_index += 1 72 | remain_images -= 1 73 | ed = ed_image 74 | 75 | else: 76 | t, h, w = ( 77 | video_grid_thw[video_index][0], 78 | video_grid_thw[video_index][1], 79 | video_grid_thw[video_index][2], 80 | ) 81 | video_index += 1 82 | remain_videos -= 1 83 | ed = ed_video 84 | llm_grid_t, llm_grid_h, llm_grid_w = ( 85 | t.item(), 86 | h.item() // spatial_merge_size, 87 | w.item() // spatial_merge_size, 88 | ) 89 | text_len = ed - st 90 | 91 | st_idx = ( 92 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 93 | ) 94 | llm_pos_ids_list.append( 95 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 96 | ) 97 | 98 | # 1. vanilla rope just increase the position ids linearly 99 | token_num = llm_grid_t * llm_grid_h * llm_grid_w 100 | t_index = torch.arange(token_num) 101 | 102 | # 2. 
repeat 3 times for compatibility with mrope dimensions 103 | llm_pos_ids_list.append( 104 | t_index.view(1, -1).expand(3, -1) + text_len + st_idx 105 | ) 106 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 107 | 108 | if st < len(input_tokens): 109 | st_idx = ( 110 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 111 | ) 112 | text_len = len(input_tokens) - st 113 | llm_pos_ids_list.append( 114 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 115 | ) 116 | 117 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 118 | position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( 119 | position_ids.device 120 | ) 121 | mrope_position_deltas.append( 122 | llm_positions.max() + 1 - len(total_input_ids[i]) 123 | ) 124 | mrope_position_deltas = torch.tensor( 125 | mrope_position_deltas, device=input_ids.device 126 | ).unsqueeze(1) 127 | return position_ids, mrope_position_deltas 128 | else: 129 | if attention_mask is not None: 130 | position_ids = attention_mask.long().cumsum(-1) - 1 131 | position_ids.masked_fill_(attention_mask == 0, 1) 132 | position_ids = ( 133 | position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) 134 | ) 135 | max_position_ids = position_ids.max(0, keepdim=False)[0].max( 136 | -1, keepdim=True 137 | )[0] 138 | mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] 139 | else: 140 | position_ids = ( 141 | torch.arange(input_ids.shape[1], device=input_ids.device) 142 | .view(1, 1, -1) 143 | .expand(3, input_ids.shape[0], -1) 144 | ) 145 | mrope_position_deltas = torch.zeros( 146 | [input_ids.shape[0], 1], 147 | device=input_ids.device, 148 | dtype=input_ids.dtype, 149 | ) 150 | 151 | return position_ids, mrope_position_deltas 152 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/mrope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from ..configs.mrope import MRopeConfig 5 | 6 | 7 | def get_mrope_index( 8 | self, 9 | input_ids: Optional[torch.LongTensor] = None, 10 | image_grid_thw: Optional[torch.LongTensor] = None, 11 | video_grid_thw: Optional[torch.LongTensor] = None, 12 | attention_mask: Optional[torch.Tensor] = None, 13 | extra_config: MRopeConfig = None, 14 | ) -> tuple[torch.Tensor, torch.Tensor]: 15 | # qwen3vl use timestamps to seperate videos, like , the video_grid_thw should also be split 16 | # if you are using qwen2/2.5vl, please remove them 17 | if video_grid_thw is not None: 18 | video_grid_thw = torch.repeat_interleave( 19 | video_grid_thw, video_grid_thw[:, 0], dim=0 20 | ) 21 | video_grid_thw[:, 0] = 1 22 | 23 | spatial_merge_size = self.config.vision_config.spatial_merge_size 24 | image_token_id = self.config.image_token_id 25 | video_token_id = self.config.video_token_id 26 | vision_start_token_id = self.config.vision_start_token_id 27 | mrope_position_deltas = [] 28 | if input_ids is not None and ( 29 | image_grid_thw is not None or video_grid_thw is not None 30 | ): 31 | total_input_ids = input_ids 32 | if attention_mask is None: 33 | attention_mask = torch.ones_like(total_input_ids) 34 | position_ids = torch.ones( 35 | 3, 36 | input_ids.shape[0], 37 | input_ids.shape[1], 38 | dtype=input_ids.dtype, 39 | device=input_ids.device, 40 | ) 41 | image_index, video_index = 0, 0 42 | attention_mask = attention_mask.to(total_input_ids.device) 43 | for i, input_ids in enumerate(total_input_ids): 44 | input_ids = 
input_ids[attention_mask[i] == 1] 45 | image_nums, video_nums = 0, 0 46 | vision_start_indices = torch.argwhere( 47 | input_ids == vision_start_token_id 48 | ).squeeze(1) 49 | vision_tokens = input_ids[vision_start_indices + 1] 50 | image_nums = (vision_tokens == image_token_id).sum() 51 | video_nums = (vision_tokens == video_token_id).sum() 52 | input_tokens = input_ids.tolist() 53 | llm_pos_ids_list: list = [] 54 | st = 0 55 | remain_images, remain_videos = image_nums, video_nums 56 | for _ in range(image_nums + video_nums): 57 | if image_token_id in input_tokens and remain_images > 0: 58 | ed_image = input_tokens.index(image_token_id, st) 59 | else: 60 | ed_image = len(input_tokens) + 1 61 | if video_token_id in input_tokens and remain_videos > 0: 62 | ed_video = input_tokens.index(video_token_id, st) 63 | else: 64 | ed_video = len(input_tokens) + 1 65 | if ed_image < ed_video: 66 | t, h, w = ( 67 | image_grid_thw[image_index][0], 68 | image_grid_thw[image_index][1], 69 | image_grid_thw[image_index][2], 70 | ) 71 | image_index += 1 72 | remain_images -= 1 73 | ed = ed_image 74 | 75 | else: 76 | t, h, w = ( 77 | video_grid_thw[video_index][0], 78 | video_grid_thw[video_index][1], 79 | video_grid_thw[video_index][2], 80 | ) 81 | video_index += 1 82 | remain_videos -= 1 83 | ed = ed_video 84 | llm_grid_t, llm_grid_h, llm_grid_w = ( 85 | t.item(), 86 | h.item() // spatial_merge_size, 87 | w.item() // spatial_merge_size, 88 | ) 89 | text_len = ed - st 90 | 91 | st_idx = ( 92 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 93 | ) 94 | llm_pos_ids_list.append( 95 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 96 | ) 97 | 98 | # default stride = 1, follow qwen2-vl 99 | t_index = ( 100 | torch.arange(llm_grid_t) 101 | .view(-1, 1) 102 | .expand(-1, llm_grid_h * llm_grid_w) 103 | .flatten() 104 | ) * extra_config.temporal_stride 105 | h_index = ( 106 | torch.arange(llm_grid_h) 107 | .view(1, -1, 1) 108 | .expand(llm_grid_t, -1, llm_grid_w) 109 | .flatten() 110 | ) 111 | w_index = ( 112 | torch.arange(llm_grid_w) 113 | .view(1, 1, -1) 114 | .expand(llm_grid_t, llm_grid_h, -1) 115 | .flatten() 116 | ) 117 | llm_pos_ids_list.append( 118 | torch.stack([t_index, h_index, w_index]) + text_len + st_idx 119 | ) 120 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 121 | 122 | if st < len(input_tokens): 123 | st_idx = ( 124 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 125 | ) 126 | text_len = len(input_tokens) - st 127 | llm_pos_ids_list.append( 128 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 129 | ) 130 | 131 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 132 | position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( 133 | position_ids.device 134 | ) 135 | mrope_position_deltas.append( 136 | llm_positions.max() + 1 - len(total_input_ids[i]) 137 | ) 138 | mrope_position_deltas = torch.tensor( 139 | mrope_position_deltas, device=input_ids.device 140 | ).unsqueeze(1) 141 | return position_ids, mrope_position_deltas 142 | else: 143 | if attention_mask is not None: 144 | position_ids = attention_mask.long().cumsum(-1) - 1 145 | position_ids.masked_fill_(attention_mask == 0, 1) 146 | position_ids = ( 147 | position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) 148 | ) 149 | max_position_ids = position_ids.max(0, keepdim=False)[0].max( 150 | -1, keepdim=True 151 | )[0] 152 | mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] 153 | else: 154 | position_ids = ( 155 | 
torch.arange(input_ids.shape[1], device=input_ids.device) 156 | .view(1, 1, -1) 157 | .expand(3, input_ids.shape[0], -1) 158 | ) 159 | mrope_position_deltas = torch.zeros( 160 | [input_ids.shape[0], 1], 161 | device=input_ids.device, 162 | dtype=input_ids.dtype, 163 | ) 164 | 165 | return position_ids, mrope_position_deltas 166 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/videorope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from ..configs.videorope import VideoRopeConfig 5 | 6 | 7 | def get_videorope_index( 8 | self, 9 | input_ids: Optional[torch.LongTensor] = None, 10 | image_grid_thw: Optional[torch.LongTensor] = None, 11 | video_grid_thw: Optional[torch.LongTensor] = None, 12 | attention_mask: Optional[torch.Tensor] = None, 13 | extra_config: VideoRopeConfig = None, 14 | ) -> tuple[torch.Tensor, torch.Tensor]: 15 | # qwen3vl use timestamps to seperate videos, like , the video_grid_thw should also be split 16 | # if you are using qwen2/2.5vl, please remove them 17 | if video_grid_thw is not None: 18 | video_grid_thw = torch.repeat_interleave( 19 | video_grid_thw, video_grid_thw[:, 0], dim=0 20 | ) 21 | video_grid_thw[:, 0] = 1 22 | 23 | spatial_merge_size = self.config.vision_config.spatial_merge_size 24 | image_token_id = self.config.image_token_id 25 | video_token_id = self.config.video_token_id 26 | vision_start_token_id = self.config.vision_start_token_id 27 | mrope_position_deltas = [] 28 | if input_ids is not None and ( 29 | image_grid_thw is not None or video_grid_thw is not None 30 | ): 31 | total_input_ids = input_ids 32 | if attention_mask is None: 33 | attention_mask = torch.ones_like(total_input_ids) 34 | position_ids = torch.ones( 35 | 3, 36 | input_ids.shape[0], 37 | input_ids.shape[1], 38 | dtype=torch.float, 39 | device=input_ids.device, 40 | ) 41 | image_index, video_index = 0, 0 42 | attention_mask = attention_mask.to(total_input_ids.device) 43 | for i, input_ids in enumerate(total_input_ids): 44 | input_ids = input_ids[attention_mask[i] == 1] 45 | image_nums, video_nums = 0, 0 46 | vision_start_indices = torch.argwhere( 47 | input_ids == vision_start_token_id 48 | ).squeeze(1) 49 | vision_tokens = input_ids[vision_start_indices + 1] 50 | image_nums = (vision_tokens == image_token_id).sum() 51 | video_nums = (vision_tokens == video_token_id).sum() 52 | input_tokens = input_ids.tolist() 53 | llm_pos_ids_list: list = [] 54 | st = 0 55 | remain_images, remain_videos = image_nums, video_nums 56 | for _ in range(image_nums + video_nums): 57 | if image_token_id in input_tokens and remain_images > 0: 58 | ed_image = input_tokens.index(image_token_id, st) 59 | else: 60 | ed_image = len(input_tokens) + 1 61 | if video_token_id in input_tokens and remain_videos > 0: 62 | ed_video = input_tokens.index(video_token_id, st) 63 | else: 64 | ed_video = len(input_tokens) + 1 65 | if ed_image < ed_video: 66 | t, h, w = ( 67 | image_grid_thw[image_index][0], 68 | image_grid_thw[image_index][1], 69 | image_grid_thw[image_index][2], 70 | ) 71 | image_index += 1 72 | remain_images -= 1 73 | ed = ed_image 74 | 75 | else: 76 | t, h, w = ( 77 | video_grid_thw[video_index][0], 78 | video_grid_thw[video_index][1], 79 | video_grid_thw[video_index][2], 80 | ) 81 | video_index += 1 82 | remain_videos -= 1 83 | ed = ed_video 84 | llm_grid_t, llm_grid_h, llm_grid_w = ( 85 | t.item(), 86 | h.item() // spatial_merge_size, 87 | w.item() // 
spatial_merge_size, 88 | ) 89 | text_len = ed - st 90 | 91 | st_idx = ( 92 | llm_pos_ids_list[-1][0].max() + 1 93 | if len(llm_pos_ids_list) > 0 94 | else 0 95 | ) 96 | llm_pos_ids_list.append( 97 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 98 | ) 99 | 100 | # body-diagonal symmetry 101 | t_index = ( 102 | torch.arange(llm_grid_t) 103 | .view(-1, 1) 104 | .expand(-1, llm_grid_h * llm_grid_w) 105 | .flatten() 106 | ) 107 | h_index = ( 108 | torch.arange(llm_grid_h) 109 | .view(1, -1, 1) 110 | .expand(llm_grid_t, -1, llm_grid_w) 111 | .flatten() 112 | - (llm_grid_h - 1) // 2 113 | ) 114 | 115 | w_index = ( 116 | torch.arange(llm_grid_w) 117 | .view(1, 1, -1) 118 | .expand(llm_grid_t, llm_grid_h, -1) 119 | .flatten() 120 | - (llm_grid_w - 1) // 2 121 | ) 122 | 123 | # time dim adjust step size 124 | t_index = t_index * extra_config.temporal_stride 125 | 126 | t_index = t_index + text_len + st_idx 127 | h_index = h_index + t_index 128 | w_index = w_index + t_index 129 | llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index])) 130 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 131 | 132 | if st < len(input_tokens): 133 | st_idx = ( 134 | llm_pos_ids_list[-1][0].max() + 1 135 | if len(llm_pos_ids_list) > 0 136 | else 0 137 | ) 138 | text_len = len(input_tokens) - st 139 | llm_pos_ids_list.append( 140 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 141 | ) 142 | 143 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 144 | position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( 145 | position_ids.device 146 | ) 147 | mrope_position_deltas.append( 148 | llm_positions[0].max() + 1 - len(total_input_ids[i]) 149 | ) 150 | mrope_position_deltas = torch.tensor( 151 | mrope_position_deltas, device=input_ids.device 152 | ).unsqueeze(1) 153 | return position_ids, mrope_position_deltas 154 | else: 155 | if attention_mask is not None: 156 | position_ids = attention_mask.long().cumsum(-1) - 1 157 | position_ids.masked_fill_(attention_mask == 0, 1) 158 | position_ids = ( 159 | position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) 160 | ) 161 | max_position_ids = position_ids.max(0, keepdim=False)[0].max( 162 | -1, keepdim=True 163 | )[0] 164 | mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] 165 | else: 166 | position_ids = ( 167 | torch.arange(input_ids.shape[1], device=input_ids.device) 168 | .view(1, 1, -1) 169 | .expand(3, input_ids.shape[0], -1) 170 | ) 171 | mrope_position_deltas = torch.zeros( 172 | [input_ids.shape[0], 1], 173 | device=input_ids.device, 174 | dtype=input_ids.dtype, 175 | ) 176 | 177 | return position_ids, mrope_position_deltas 178 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/hope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from typing import Optional 4 | 5 | from ..configs.hope import HopeConfig 6 | 7 | 8 | def get_hope_index( 9 | self, 10 | input_ids: Optional[torch.LongTensor] = None, 11 | image_grid_thw: Optional[torch.LongTensor] = None, 12 | video_grid_thw: Optional[torch.LongTensor] = None, 13 | attention_mask: Optional[torch.Tensor] = None, 14 | extra_config: HopeConfig = None, 15 | ) -> tuple[torch.Tensor, torch.Tensor]: 16 | # qwen3vl use timestamps to seperate videos, like , the video_grid_thw should also be split 17 | # if you are using qwen2/2.5vl, please remove them 18 | if video_grid_thw is not None: 19 | video_grid_thw = 
torch.repeat_interleave( 20 | video_grid_thw, video_grid_thw[:, 0], dim=0 21 | ) 22 | video_grid_thw[:, 0] = 1 23 | 24 | spatial_merge_size = self.config.vision_config.spatial_merge_size 25 | image_token_id = self.config.image_token_id 26 | video_token_id = self.config.video_token_id 27 | vision_start_token_id = self.config.vision_start_token_id 28 | mrope_position_deltas = [] 29 | if input_ids is not None and ( 30 | image_grid_thw is not None or video_grid_thw is not None 31 | ): 32 | total_input_ids = input_ids 33 | if attention_mask is None: 34 | attention_mask = torch.ones_like(total_input_ids) 35 | position_ids = torch.ones( 36 | 3, 37 | input_ids.shape[0], 38 | input_ids.shape[1], 39 | dtype=torch.float, 40 | device=input_ids.device, 41 | ) 42 | image_index, video_index = 0, 0 43 | attention_mask = attention_mask.to(total_input_ids.device) 44 | for i, input_ids in enumerate(total_input_ids): 45 | input_ids = input_ids[attention_mask[i] == 1] 46 | image_nums, video_nums = 0, 0 47 | vision_start_indices = torch.argwhere( 48 | input_ids == vision_start_token_id 49 | ).squeeze(1) 50 | vision_tokens = input_ids[vision_start_indices + 1] 51 | image_nums = (vision_tokens == image_token_id).sum() 52 | video_nums = (vision_tokens == video_token_id).sum() 53 | input_tokens = input_ids.tolist() 54 | llm_pos_ids_list: list = [] 55 | st = 0 56 | remain_images, remain_videos = image_nums, video_nums 57 | for _ in range(image_nums + video_nums): 58 | if image_token_id in input_tokens and remain_images > 0: 59 | ed_image = input_tokens.index(image_token_id, st) 60 | else: 61 | ed_image = len(input_tokens) + 1 62 | if video_token_id in input_tokens and remain_videos > 0: 63 | ed_video = input_tokens.index(video_token_id, st) 64 | else: 65 | ed_video = len(input_tokens) + 1 66 | if ed_image < ed_video: 67 | t, h, w = ( 68 | image_grid_thw[image_index][0], 69 | image_grid_thw[image_index][1], 70 | image_grid_thw[image_index][2], 71 | ) 72 | image_index += 1 73 | remain_images -= 1 74 | ed = ed_image 75 | 76 | else: 77 | t, h, w = ( 78 | video_grid_thw[video_index][0], 79 | video_grid_thw[video_index][1], 80 | video_grid_thw[video_index][2], 81 | ) 82 | video_index += 1 83 | remain_videos -= 1 84 | ed = ed_video 85 | llm_grid_t, llm_grid_h, llm_grid_w = ( 86 | t.item(), 87 | h.item() // spatial_merge_size, 88 | w.item() // spatial_merge_size, 89 | ) 90 | text_len = ed - st 91 | 92 | st_idx = ( 93 | llm_pos_ids_list[-1][0].max() + 1 94 | if len(llm_pos_ids_list) > 0 95 | else 0 96 | ) 97 | llm_pos_ids_list.append( 98 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 99 | ) 100 | 101 | # body-diagonal symmetry 102 | t_index = ( 103 | torch.arange(llm_grid_t) 104 | .view(-1, 1) 105 | .expand(-1, llm_grid_h * llm_grid_w) 106 | .flatten() 107 | ) 108 | h_index = ( 109 | torch.arange(llm_grid_h) 110 | .view(1, -1, 1) 111 | .expand(llm_grid_t, -1, llm_grid_w) 112 | .flatten() 113 | - (llm_grid_h - 1) // 2 114 | ) 115 | 116 | w_index = ( 117 | torch.arange(llm_grid_w) 118 | .view(1, 1, -1) 119 | .expand(llm_grid_t, llm_grid_h, -1) 120 | .flatten() 121 | - (llm_grid_w - 1) // 2 122 | ) 123 | 124 | # dynamic scaling 125 | scalor = random.choice(extra_config.temporal_stride_lst) 126 | t_index = t_index * scalor 127 | 128 | t_index = t_index + text_len + st_idx 129 | h_index = h_index + t_index 130 | w_index = w_index + t_index 131 | llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index])) 132 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 133 | 134 | if st < len(input_tokens): 135 | st_idx = ( 
136 | llm_pos_ids_list[-1][0].max() + 1 137 | if len(llm_pos_ids_list) > 0 138 | else 0 139 | ) 140 | text_len = len(input_tokens) - st 141 | llm_pos_ids_list.append( 142 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 143 | ) 144 | 145 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 146 | position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( 147 | position_ids.device 148 | ) 149 | mrope_position_deltas.append( 150 | llm_positions[0].max() + 1 - len(total_input_ids[i]) 151 | ) 152 | mrope_position_deltas = torch.tensor( 153 | mrope_position_deltas, device=input_ids.device 154 | ).unsqueeze(1) 155 | return position_ids, mrope_position_deltas 156 | else: 157 | if attention_mask is not None: 158 | position_ids = attention_mask.long().cumsum(-1) - 1 159 | position_ids.masked_fill_(attention_mask == 0, 1) 160 | position_ids = ( 161 | position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) 162 | ) 163 | max_position_ids = position_ids.max(0, keepdim=False)[0].max( 164 | -1, keepdim=True 165 | )[0] 166 | mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] 167 | else: 168 | position_ids = ( 169 | torch.arange(input_ids.shape[1], device=input_ids.device) 170 | .view(1, 1, -1) 171 | .expand(3, input_ids.shape[0], -1) 172 | ) 173 | mrope_position_deltas = torch.zeros( 174 | [input_ids.shape[0], 1], 175 | device=input_ids.device, 176 | dtype=input_ids.dtype, 177 | ) 178 | 179 | return position_ids, mrope_position_deltas 180 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Revisiting Multimodal Positional Encoding in Vision–Language Models 2 | 3 | This repository is the official implementation of [Revisiting Multimodal Positional Encoding in Vision–Language Models](https://arxiv.org/abs/2510.23095). 4 | 5 | Multimodal position encoding is essential for vision-language models, yet there has been little systematic investigation into multimodal position encoding. We conduct a comprehensive analysis of *multimodal Rotary Positional Embedding (RoPE)* by examining its two core components: *position design* and *frequency allocation*. Through extensive experiments, we identify three key guidelines: *positional coherence, full frequency utilization, and preservation of textual priors*—ensuring unambiguous layout, rich representation, and faithful transfer from the pre-trained LLM. Based on these insights, we propose **Multi-Head RoPE (MHRoPE)** and **MRoPE-Interleave (MRoPE-I)**, two simple and plug-and-play variants that require no architectural changes. Our methods consistently outperform existing approaches across diverse benchmarks, with significant improvements in both general and fine-grained multimodal understanding. 6 | 7 | 8 | position-design 9 | frequency-allocation 10 | 11 | ## News 12 | 13 | - 2025.10 All variants of [Qwen3-VL](https://github.com/QwenLM/Qwen3-VL) now adopt MRoPE-Interleave w/o *spatial-reset*! 14 | 15 | ## Todo List: Implementations of Multimodal RoPE Variants 16 | 17 | To enhance usability and consistency, we are refactoring various multimodal RoPE implementations into a unified interface, similar to [`Qwen3VLTextRotaryEmbedding`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L278). This effort is expected to be completed within one week (target date: November 9, 2025). 
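The variants tracked below differ mainly in how they allocate rotary frequency indices across the temporal (T) and spatial (H, W) axes. As a concrete illustration (not part of the package, and using toy section sizes), the sketch below contrasts the chunked layout used by MRoPE with the strided, interleaved layout used by MRoPE-I, mirroring the `apply_transformation` logic in `freq_allocation/mrope.py` and `freq_allocation/mrope_i.py`:

```python
# Illustrative sketch only (not part of this package): which axis
# (0 = T, 1 = H, 2 = W) owns each rotary frequency index under the two layouts.
# Toy section sizes are used here; the real configs use e.g. [16, 24, 24] / [24, 20, 20].

def chunked_axis_layout(mrope_section):
    """Chunked MRoPE layout: TTT...HHH...WWW (cf. MRopeEmbedding.apply_transformation)."""
    layout = []
    for axis, length in enumerate(mrope_section):
        layout += [axis] * length
    return layout

def interleaved_axis_layout(mrope_section):
    """MRoPE-I layout: start from T everywhere, then overwrite strided slots with H and W
    (cf. MRopeInterleaveEmbedding.apply_transformation), giving THWTHW...TT."""
    layout = [0] * sum(mrope_section)
    for axis in (1, 2):
        layout[axis : mrope_section[axis] * 3 : 3] = [axis] * mrope_section[axis]
    return layout

print(chunked_axis_layout([4, 2, 2]))      # [0, 0, 0, 0, 1, 1, 2, 2]
print(interleaved_axis_layout([4, 2, 2]))  # [0, 1, 2, 0, 1, 2, 0, 0]
```

In the chunked layout each axis is confined to a contiguous frequency band, whereas the interleaved layout spreads every axis across the full band, which is the "full frequency utilization" guideline highlighted above.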
18 | 
19 | - [x] [Vanilla RoPE](https://arxiv.org/abs/2104.09864): Applies vanilla RoPE directly to multimodal sequences, discarding the spatio-temporal structure of visual content, yet it remains a strong baseline.
20 | - [x] [MRoPE](https://arxiv.org/abs/2409.12191): Introduced in Qwen2-VL, upcasts 1D positions to three axes (t, h, w) and splits the feature dimension across different positional axes.
21 | - [x] Our MRoPE-I: Applies an interleaved frequency allocation strategy, preserving a more complete frequency band for each positional axis; spatial reset is incorporated into the positional design to enhance visual attention.
22 | - [x] Our MHRoPE: Employs head-wise frequency allocation to maximize utility across attention heads; spatial reset is also used.
23 | - [x] [VideoRoPE](https://arxiv.org/pdf/2502.05173): Optimizes MRoPE's frequency allocation by assigning the temporal dimension to low-frequency bands and adopting a diagonal positional design. Unlike the official implementation, we vectorize the computation for faster execution.
24 | - [x] [HoPE](https://arxiv.org/abs/2505.20444): Built on VideoRoPE, applies positional scaling in the position design and uses NoPE (no positional encoding) on the temporal dimension during frequency allocation. Unlike the official implementation, we reset the temporal positions and vectorize the computation for faster execution.
25 | - [x] [CircleRoPE](https://arxiv.org/abs/2505.16416): A novel positional design that maps image tokens onto a circular trajectory orthogonal to the text token indices, effectively mitigating cross-modal positional bias during generation. While the original CircleRoPE was designed for static images, we extend it to support video inputs by stacking circular rings along the temporal dimension; see [here](https://github.com/JJJYmmm/Multimodal-RoPEs/blob/64f8a141326c0ec079f8f05da42483a600028662/multimodal_ropes/pos_design/circlerope.py#L122-L130). Note that, for simplicity, this repo does not support the AGE mode.
26 | - [ ] [V2PE](https://arxiv.org/abs/2412.09616)
27 | - [ ] [ILRoPE](https://arxiv.org/abs/2505.05472v1) / [OmniRoPE](https://arxiv.org/abs/2506.18871)
28 | - [ ] [MMRoPE](https://arxiv.org/abs/2507.08801)
29 | - [ ] [GRAPE](https://openreview.net/forum?id=itoNJ3gJl2)
30 | - [ ] More variants... Feel free to open an issue or pull request, and I will add them here. 🤗
31 | 
32 | ## Usage
33 | 
34 | We organize the various multimodal RoPE implementations on top of the `transformers` library by decoupling each of them into two components: **position design** and **frequency allocation**.
35 | 
36 | ### Installation
37 | 
38 | You can install the `multimodal-ropes` package directly from the repository:
39 | 
40 | ```bash
41 | git clone https://github.com/JJJYmmm/Multimodal-RoPEs.git
42 | pip install -e .
43 | # Successfully installed multimodal-ropes-0.1.0
44 | ```
45 | 
46 | ### Integration with Vision–Language Models (e.g., Qwen3-VL)
47 | 
48 | The package provides a simple interface to plug in different multimodal RoPE variants. Below is an example of how to patch `Qwen3-VL` with your preferred RoPE configuration:
49 | > Note that adapting to a new positional encoding typically requires additional training or fine-tuning to achieve optimal performance.
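If you only want to see which hyperparameters a variant resolves to before patching anything, the package root also exports `get_multimodal_rope_config`. A minimal sketch, using the same MRoPE-I values as the patching example below:

```python
from multimodal_ropes import get_multimodal_rope_config

# Build and print the resolved MRoPE-I configuration without touching a model.
cfg = get_multimodal_rope_config(
    "mrope-interleave",
    dim=128,
    base=5_000_000,
    mrope_section=[24, 20, 20],
    temporal_stride=1,
    spatial_reset=True,
)
print(cfg)
# MRopeInterleaveConfig(name='mrope-interleave', dim=128, base=5000000,
#                       mrope_section=[24, 20, 20], temporal_stride=1, spatial_reset=True)
```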
50 | 51 | ```python 52 | from multimodal_ropes.entry import get_multimodal_rope 53 | 54 | def rotate_half(x): 55 | """Rotates half the hidden dimensions of the input.""" 56 | x1 = x[..., : x.shape[-1] // 2] 57 | x2 = x[..., x.shape[-1] // 2 :] 58 | return torch.cat((-x2, x1), dim=-1) 59 | 60 | def apply_multihead_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 61 | """ 62 | Applies multi-head rotary positional embeddings for MHA/GQA. 63 | 64 | Args: 65 | q: [bs, num_heads, seq_len, head_dim] 66 | k: [bs, num_key_value_heads, seq_len, head_dim] 67 | cos, sin: [bs, num_key_value_heads, seq_len, head_dim] → broadcast to num_heads 68 | """ 69 | n_repeat = q.shape[1] // cos.shape[1] 70 | q_embed = (q * cos.repeat_interleave(n_repeat, dim=1)) + (rotate_half(q) * sin.repeat_interleave(n_repeat, dim=1)) 71 | k_embed = (k * cos) + (rotate_half(k) * sin) 72 | return q_embed, k_embed 73 | 74 | def monkey_patch_qwen3vl(rope_name, **kwargs): 75 | rope_index, rope_embed = get_multimodal_rope(rope_name, **kwargs) 76 | 77 | logging.info(f"Begin patching Qwen3VLModel with {rope_name}.") 78 | Qwen3VLModel.get_rope_index = rope_index 79 | transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding = rope_embed 80 | 81 | if rope_name == "mhrope": 82 | # Special handling for MHRoPE: replace the rotary embedding application function 83 | transformers.models.qwen3_vl.modeling_qwen3_vl.apply_rotary_pos_emb = apply_multihead_rotary_pos_emb 84 | logging.info("MHRoPE: Replaced apply_rotary_pos_emb with multi-head version.") 85 | 86 | logging.info(f"Successfully patched Qwen3VLModel with {rope_name}.") 87 | test_forward() # Optional: run a minimal forward pass to verify 88 | 89 | # Common RoPE configuration 90 | common_kwargs = dict( 91 | dim=128, 92 | base=5_000_000, 93 | ) 94 | 95 | # Examples of patching with different RoPE variants 96 | monkey_patch_qwen3vl("vanilla-rope", **common_kwargs) 97 | monkey_patch_qwen3vl("mrope-interleave", mrope_section=[24, 20, 20], temporal_stride=1, spatial_reset=True, **common_kwargs) 98 | monkey_patch_qwen3vl("mhrope", num_key_value_heads=8, mrope_section=[2, 3, 3], temporal_stride=1, spatial_reset=True, **common_kwargs) 99 | ``` 100 | 101 | For more comprehensive examples and testing utilities, please refer to [`test.py`](test.py). 102 | 103 | ## Citation 104 | 105 | If you find this repository is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 106 | 107 | ```bibtex 108 | @misc{huang2025revisitingmultimodalpositionalencoding, 109 | title={Revisiting Multimodal Positional Encoding in Vision-Language Models}, 110 | author={Jie Huang and Xuejing Liu and Sibo Song and Ruibing Hou and Hong Chang and Junyang Lin and Shuai Bai}, 111 | journal={arXiv preprint arXiv:2510.23095}, 112 | year={2025} 113 | } 114 | ``` 115 | 116 | ## License 117 | 118 | The content of this project itself is licensed under [LICENSE](LICENSE). 
119 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/mrope_i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from ..configs.mrope_i import MRopeInterleaveConfig 5 | 6 | 7 | def get_mrope_interleave_index( 8 | self, 9 | input_ids: Optional[torch.LongTensor] = None, 10 | image_grid_thw: Optional[torch.LongTensor] = None, 11 | video_grid_thw: Optional[torch.LongTensor] = None, 12 | attention_mask: Optional[torch.Tensor] = None, 13 | extra_config: MRopeInterleaveConfig = None, 14 | ) -> tuple[torch.Tensor, torch.Tensor]: 15 | # qwen3vl use timestamps to seperate videos, like , the video_grid_thw should also be split 16 | # if you are using qwen2/2.5vl, please remove them 17 | if video_grid_thw is not None: 18 | video_grid_thw = torch.repeat_interleave( 19 | video_grid_thw, video_grid_thw[:, 0], dim=0 20 | ) 21 | video_grid_thw[:, 0] = 1 22 | 23 | spatial_merge_size = self.config.vision_config.spatial_merge_size 24 | image_token_id = self.config.image_token_id 25 | video_token_id = self.config.video_token_id 26 | vision_start_token_id = self.config.vision_start_token_id 27 | mrope_position_deltas = [] 28 | if input_ids is not None and ( 29 | image_grid_thw is not None or video_grid_thw is not None 30 | ): 31 | total_input_ids = input_ids 32 | if attention_mask is None: 33 | attention_mask = torch.ones_like(total_input_ids) 34 | position_ids = torch.ones( 35 | 3, 36 | input_ids.shape[0], 37 | input_ids.shape[1], 38 | dtype=input_ids.dtype, 39 | device=input_ids.device, 40 | ) 41 | image_index, video_index = 0, 0 42 | attention_mask = attention_mask.to(total_input_ids.device) 43 | for i, input_ids in enumerate(total_input_ids): 44 | input_ids = input_ids[attention_mask[i] == 1] 45 | image_nums, video_nums = 0, 0 46 | vision_start_indices = torch.argwhere( 47 | input_ids == vision_start_token_id 48 | ).squeeze(1) 49 | vision_tokens = input_ids[vision_start_indices + 1] 50 | image_nums = (vision_tokens == image_token_id).sum() 51 | video_nums = (vision_tokens == video_token_id).sum() 52 | input_tokens = input_ids.tolist() 53 | llm_pos_ids_list: list = [] 54 | st = 0 55 | remain_images, remain_videos = image_nums, video_nums 56 | for _ in range(image_nums + video_nums): 57 | if image_token_id in input_tokens and remain_images > 0: 58 | ed_image = input_tokens.index(image_token_id, st) 59 | else: 60 | ed_image = len(input_tokens) + 1 61 | if video_token_id in input_tokens and remain_videos > 0: 62 | ed_video = input_tokens.index(video_token_id, st) 63 | else: 64 | ed_video = len(input_tokens) + 1 65 | if ed_image < ed_video: 66 | t, h, w = ( 67 | image_grid_thw[image_index][0], 68 | image_grid_thw[image_index][1], 69 | image_grid_thw[image_index][2], 70 | ) 71 | image_index += 1 72 | remain_images -= 1 73 | ed = ed_image 74 | 75 | else: 76 | t, h, w = ( 77 | video_grid_thw[video_index][0], 78 | video_grid_thw[video_index][1], 79 | video_grid_thw[video_index][2], 80 | ) 81 | video_index += 1 82 | remain_videos -= 1 83 | ed = ed_video 84 | llm_grid_t, llm_grid_h, llm_grid_w = ( 85 | t.item(), 86 | h.item() // spatial_merge_size, 87 | w.item() // spatial_merge_size, 88 | ) 89 | text_len = ed - st 90 | 91 | st_idx = ( 92 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 93 | ) 94 | llm_pos_ids_list.append( 95 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 96 | ) 97 | 98 | # t_index is always 0 because llm_grid_t is 
always 1 (we use timestamps to encode the temporal information for videos) 99 | t_index = ( 100 | torch.arange(llm_grid_t) 101 | .view(-1, 1) 102 | .expand(-1, llm_grid_h * llm_grid_w) 103 | .flatten() 104 | ) * extra_config.temporal_stride 105 | h_index = ( 106 | torch.arange(llm_grid_h) 107 | .view(1, -1, 1) 108 | .expand(llm_grid_t, -1, llm_grid_w) 109 | .flatten() 110 | ) 111 | w_index = ( 112 | torch.arange(llm_grid_w) 113 | .view(1, 1, -1) 114 | .expand(llm_grid_t, llm_grid_h, -1) 115 | .flatten() 116 | ) 117 | if extra_config.spatial_reset: 118 | mm_pos_ids = torch.stack([t_index, h_index, w_index]) 119 | # calculate the token id to avoid too narrow stride caused by .max() line 120 120 | vision_end_token_id = torch.full( 121 | (3, 1), torch.max(mm_pos_ids).item() + 1 + text_len + st_idx 122 | ) 123 | mm_pos_ids[0] += text_len + st_idx 124 | llm_pos_ids_list.append( 125 | torch.cat([mm_pos_ids, vision_end_token_id], dim=1) 126 | ) 127 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w + 1 128 | else: 129 | llm_pos_ids_list.append( 130 | torch.stack([t_index, h_index, w_index]) + text_len + st_idx 131 | ) 132 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 133 | 134 | if st < len(input_tokens): 135 | st_idx = ( 136 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 137 | ) 138 | text_len = len(input_tokens) - st 139 | llm_pos_ids_list.append( 140 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 141 | ) 142 | 143 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 144 | position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( 145 | position_ids.device 146 | ) 147 | mrope_position_deltas.append( 148 | llm_positions.max() + 1 - len(total_input_ids[i]) 149 | ) 150 | mrope_position_deltas = torch.tensor( 151 | mrope_position_deltas, device=input_ids.device 152 | ).unsqueeze(1) 153 | return position_ids, mrope_position_deltas 154 | else: 155 | if attention_mask is not None: 156 | position_ids = attention_mask.long().cumsum(-1) - 1 157 | position_ids.masked_fill_(attention_mask == 0, 1) 158 | position_ids = ( 159 | position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) 160 | ) 161 | max_position_ids = position_ids.max(0, keepdim=False)[0].max( 162 | -1, keepdim=True 163 | )[0] 164 | mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] 165 | else: 166 | position_ids = ( 167 | torch.arange(input_ids.shape[1], device=input_ids.device) 168 | .view(1, 1, -1) 169 | .expand(3, input_ids.shape[0], -1) 170 | ) 171 | mrope_position_deltas = torch.zeros( 172 | [input_ids.shape[0], 1], 173 | device=input_ids.device, 174 | dtype=input_ids.dtype, 175 | ) 176 | 177 | return position_ids, mrope_position_deltas 178 | -------------------------------------------------------------------------------- /multimodal_ropes/pos_design/circlerope.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Optional 4 | 5 | from ..configs.circlerope import CircleRopeConfig 6 | 7 | 8 | def get_circlerope_index( 9 | self, 10 | input_ids: Optional[torch.LongTensor] = None, 11 | image_grid_thw: Optional[torch.LongTensor] = None, 12 | video_grid_thw: Optional[torch.LongTensor] = None, 13 | attention_mask: Optional[torch.Tensor] = None, 14 | extra_config: CircleRopeConfig = None, 15 | ) -> tuple[torch.Tensor, torch.Tensor]: 16 | # qwen3vl use timestamps to seperate videos, like , the video_grid_thw should also be split 17 | # if you are using qwen2/2.5vl, please 
remove them 18 | if video_grid_thw is not None: 19 | video_grid_thw = torch.repeat_interleave( 20 | video_grid_thw, video_grid_thw[:, 0], dim=0 21 | ) 22 | video_grid_thw[:, 0] = 1 23 | 24 | spatial_merge_size = self.config.vision_config.spatial_merge_size 25 | image_token_id = self.config.image_token_id 26 | video_token_id = self.config.video_token_id 27 | vision_start_token_id = self.config.vision_start_token_id 28 | mrope_position_deltas = [] 29 | if input_ids is not None and ( 30 | image_grid_thw is not None or video_grid_thw is not None 31 | ): 32 | total_input_ids = input_ids 33 | if attention_mask is None: 34 | attention_mask = torch.ones_like(total_input_ids) 35 | position_ids = torch.ones( 36 | 3, 37 | input_ids.shape[0], 38 | input_ids.shape[1], 39 | # dtype=input_ids.dtype, 40 | dtype=torch.float, 41 | device=input_ids.device, 42 | ) 43 | image_index, video_index = 0, 0 44 | attention_mask = attention_mask.to(total_input_ids.device) 45 | for i, input_ids in enumerate(total_input_ids): 46 | input_ids = input_ids[attention_mask[i] == 1] 47 | image_nums, video_nums = 0, 0 48 | vision_start_indices = torch.argwhere( 49 | input_ids == vision_start_token_id 50 | ).squeeze(1) 51 | vision_tokens = input_ids[vision_start_indices + 1] 52 | image_nums = (vision_tokens == image_token_id).sum() 53 | video_nums = (vision_tokens == video_token_id).sum() 54 | input_tokens = input_ids.tolist() 55 | llm_pos_ids_list: list = [] 56 | st = 0 57 | remain_images, remain_videos = image_nums, video_nums 58 | for _ in range(image_nums + video_nums): 59 | if image_token_id in input_tokens and remain_images > 0: 60 | ed_image = input_tokens.index(image_token_id, st) 61 | else: 62 | ed_image = len(input_tokens) + 1 63 | if video_token_id in input_tokens and remain_videos > 0: 64 | ed_video = input_tokens.index(video_token_id, st) 65 | else: 66 | ed_video = len(input_tokens) + 1 67 | if ed_image < ed_video: 68 | t, h, w = ( 69 | image_grid_thw[image_index][0], 70 | image_grid_thw[image_index][1], 71 | image_grid_thw[image_index][2], 72 | ) 73 | image_index += 1 74 | remain_images -= 1 75 | ed = ed_image 76 | 77 | else: 78 | t, h, w = ( 79 | video_grid_thw[video_index][0], 80 | video_grid_thw[video_index][1], 81 | video_grid_thw[video_index][2], 82 | ) 83 | video_index += 1 84 | remain_videos -= 1 85 | ed = ed_video 86 | llm_grid_t, llm_grid_h, llm_grid_w = ( 87 | t.item(), 88 | h.item() // spatial_merge_size, 89 | w.item() // spatial_merge_size, 90 | ) 91 | text_len = ed - st 92 | 93 | st_idx = ( 94 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 95 | ) 96 | llm_pos_ids_list.append( 97 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 98 | ) 99 | 100 | # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) 101 | t_index = ( 102 | torch.arange(llm_grid_t) 103 | .view(-1, 1) 104 | .expand(-1, llm_grid_h * llm_grid_w) 105 | .view(-1, llm_grid_h, llm_grid_w) 106 | ) * extra_config.temporal_stride 107 | h_index = ( 108 | torch.arange(llm_grid_h) 109 | .view(1, -1, 1) 110 | .expand(llm_grid_t, -1, llm_grid_w) 111 | ) 112 | w_index = ( 113 | torch.arange(llm_grid_w) 114 | .view(1, 1, -1) 115 | .expand(llm_grid_t, llm_grid_h, -1) 116 | ) 117 | 118 | llm_pos_ids = _circle_projection( 119 | w_index, h_index, t_index, extra_config 120 | ) 121 | 122 | # original circle rope only supports image input, we extend it to support video input 123 | # by increasing the time dimension linearly 124 | llm_pos_ids = ( 125 | 
llm_pos_ids.repeat(1, llm_grid_t) 126 | + torch.arange(llm_grid_t) 127 | .view(1, -1) 128 | .repeat_interleave(llm_grid_h * llm_grid_w, dim=1) 129 | * extra_config.temporal_stride 130 | ) 131 | 132 | llm_pos_ids_list.append(llm_pos_ids + text_len + st_idx) 133 | st = ed + llm_grid_t * llm_grid_h * llm_grid_w 134 | 135 | if st < len(input_tokens): 136 | st_idx = ( 137 | llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 138 | ) 139 | text_len = len(input_tokens) - st 140 | llm_pos_ids_list.append( 141 | torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx 142 | ) 143 | 144 | llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) 145 | position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( 146 | position_ids.device 147 | ) 148 | mrope_position_deltas.append( 149 | llm_positions.max() + 1 - len(total_input_ids[i]) 150 | ) 151 | mrope_position_deltas = torch.tensor( 152 | mrope_position_deltas, device=input_ids.device 153 | ).unsqueeze(1) 154 | return position_ids, mrope_position_deltas 155 | else: 156 | if attention_mask is not None: 157 | position_ids = attention_mask.long().cumsum(-1) - 1 158 | position_ids.masked_fill_(attention_mask == 0, 1) 159 | position_ids = ( 160 | position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) 161 | ) 162 | max_position_ids = position_ids.max(0, keepdim=False)[0].max( 163 | -1, keepdim=True 164 | )[0] 165 | mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] 166 | else: 167 | position_ids = ( 168 | torch.arange(input_ids.shape[1], device=input_ids.device) 169 | .view(1, 1, -1) 170 | .expand(3, input_ids.shape[0], -1) 171 | ) 172 | mrope_position_deltas = torch.zeros( 173 | [input_ids.shape[0], 1], 174 | device=input_ids.device, 175 | dtype=input_ids.dtype, 176 | ) 177 | 178 | return position_ids, mrope_position_deltas 179 | 180 | 181 | # modified from https://github.com/lose4578/CircleRoPE 182 | def _circle_projection( 183 | w_index, h_index, t_index, extra_config: CircleRopeConfig = None 184 | ): 185 | # Load circle rope configurations 186 | move_to_origin = extra_config.move_to_origin 187 | move_to_positive = extra_config.move_to_positive 188 | dff_rate = extra_config.dff_rate 189 | method = extra_config.method 190 | radius = extra_config.radius 191 | alpha = extra_config.alpha 192 | 193 | # Stack original coordinates 194 | ori_coords = torch.stack((w_index, h_index, t_index), dim=0) 195 | if move_to_origin: 196 | # Move coordinates to origin if specified 197 | ori_coords = move_to_origin_coords(ori_coords) 198 | 199 | # Determine radius: auto or fixed value 200 | if "auto" in str(radius): 201 | if radius == "auto": 202 | radius_scale = 1 203 | else: 204 | _, radius_scale = radius.split("-") 205 | # Calculate radius based on the maximum absolute coordinate value 206 | radius = ori_coords.max().abs() * float(radius_scale) 207 | else: 208 | radius = float(radius) 209 | 210 | # Perform circle projection 211 | convert_coords = circle_projection( 212 | ori_coords, text_vector=[1, 1, 1], radius=radius, alpha=alpha, method=method 213 | ) 214 | 215 | # Apply differential rate if specified 216 | if dff_rate: 217 | no_circle_convert_coords = circle_projection( 218 | ori_coords, text_vector=[1, 1, 1], radius=-1, alpha=-1, method="no_circle" 219 | ) 220 | # Linearly interpolate between circle projection and original coordinates 221 | convert_coords = ( 222 | 1 - dff_rate 223 | ) * convert_coords + dff_rate * no_circle_convert_coords 224 | 225 | # Move coordinates to positive axis if specified 
226 | if move_to_positive: 227 | if move_to_positive == "auto": 228 | offset = 0 229 | else: 230 | offset = float(move_to_positive) 231 | convert_coords = move_to_positive_axis(convert_coords, offset=offset) 232 | 233 | # Flatten coordinate dimensions 234 | w_index = convert_coords[0].flatten() 235 | h_index = convert_coords[1].flatten() 236 | t_index = convert_coords[2].flatten() 237 | 238 | # Stack coordinates for language model position IDs 239 | llm_pos_ids = torch.stack([t_index, h_index, w_index]) 240 | 241 | return llm_pos_ids 242 | 243 | 244 | def move_to_origin_coords(coords): 245 | """ 246 | Moves the center of the cube to the origin (stacked coordinates version). 247 | Parameters: 248 | coords: Tensor of shape (3, depth, height, width) 249 | Channel order corresponds to [x, y, z] axis coordinates. 250 | Returns: 251 | new_coords: Center-aligned coordinate tensor, maintaining the same shape. 252 | """ 253 | # Calculate the center point for each axis [x_center, y_center, z_center] 254 | max_vals = torch.amax( 255 | coords, dim=(1, 2, 3) 256 | ) # Get maximum value along spatial dimensions 257 | min_vals = torch.amin( 258 | coords, dim=(1, 2, 3) 259 | ) # Get minimum value along spatial dimensions 260 | centers = (max_vals + min_vals) / 2.0 261 | # Adjust dimensions for broadcast subtraction (3, 1, 1, 1) 262 | centers = centers.view(-1, 1, 1, 1) 263 | # Perform translation 264 | new_coords = coords - centers 265 | 266 | return new_coords 267 | 268 | 269 | def move_to_positive_axis(coords, offset=0): 270 | # Find the absolute minimum value across all coordinates 271 | min_vals = torch.abs(torch.min(coords)) 272 | # Create a tensor of these minimum values for shifting 273 | centers = torch.tensor([min_vals, min_vals, min_vals]).view(-1, 1, 1, 1) 274 | 275 | # Shift coordinates to be positive and add an optional offset 276 | new_coords = coords + centers + offset 277 | 278 | return new_coords 279 | 280 | 281 | def circle_projection( 282 | coords, text_vector=[1, 1, 1], radius=1.0, alpha=0.5, method="circle", rotate=True 283 | ): 284 | """ 285 | Maps a point cloud to the circumference of a circle on a plane perpendicular to the given text_vector. 286 | Parameters: 287 | coords: [3, N] or [3, D, H, W] point cloud or stacked coordinates. 288 | text_vector: [3] Normal vector of the target plane. 289 | radius: Target circle radius. 290 | alpha: Nonlinear coefficient (0-1, controls distribution density). 291 | method: 'circle' for mapping to circle, 'no_circle' for no mapping. 292 | rotate: Boolean, whether to rotate the plane. 
293 | """ 294 | 295 | # Original non-linear circular mapping 296 | if method == "circle": 297 | coord_circle = map_to_circle(coords, radius, alpha) 298 | elif method == "no_circle": 299 | # Pass through coordinates if no circle mapping is specified 300 | coord_circle = coords 301 | else: 302 | raise ValueError(f"Invalid circle projection method: {method}") 303 | 304 | if rotate: 305 | # Rotate the plane to be perpendicular to the text_vector 306 | coord_circle = rotate_plane_perpendicular_to_vector(coord_circle, text_vector) 307 | 308 | return coord_circle 309 | 310 | 311 | def rotate_plane_perpendicular_to_vector(coord_circle, text_vector): 312 | data_device = coord_circle.device 313 | data_dtype = coord_circle.dtype 314 | 315 | # Construct the target plane coordinate system 316 | text = torch.tensor(text_vector, dtype=data_dtype, device=data_device).float() 317 | text_norm = torch.norm(text) 318 | if text_norm < 1e-6: 319 | raise ValueError("text_vector cannot be zero vector") 320 | text_unit = text / text_norm # Normalize the text vector 321 | 322 | # Construct an orthogonal basis 323 | if torch.abs(text_unit[0]) < 1e-6 and torch.abs(text_unit[1]) < 1e-6: 324 | # Handle the case where the vector is along the z-axis 325 | u = torch.tensor([1.0, 0.0, 0.0], device=data_device, dtype=data_dtype) 326 | v = torch.tensor([0.0, 1.0, 0.0], device=data_device, dtype=data_dtype) 327 | else: 328 | # Construct the first orthogonal vector u 329 | u = torch.stack( 330 | [-text_unit[1], text_unit[0], torch.tensor(0.0, device=data_device)] 331 | ) 332 | u = u / torch.norm(u) # Normalize u 333 | # Construct the second orthogonal vector v using cross product 334 | v = torch.cross(text_unit, u, dim=0) 335 | v = v / torch.norm(v) # Normalize v 336 | 337 | # Project the circle points onto the new coordinate system 338 | x_components = ( 339 | coord_circle[0] * u[0] + coord_circle[1] * v[0] 340 | ) # Contribution to new X from original X and Y 341 | y_components = ( 342 | coord_circle[0] * u[1] + coord_circle[1] * v[1] 343 | ) # Contribution to new Y from original X and Y 344 | z_components = ( 345 | coord_circle[0] * u[2] + coord_circle[1] * v[2] 346 | ) # Contribution to new Z from original X and Y 347 | 348 | coord_componets = torch.stack([x_components, y_components, z_components]) 349 | 350 | return coord_componets 351 | 352 | 353 | def map_to_circle(tensor, radius=1.0, alpha=0.5): 354 | """ 355 | Maps points on a plane (z coordinate is 0) to the edge of a circle centered at (0, 0, 0) with the given radius. 356 | 357 | Parameters: 358 | tensor: A tensor of shape (3, 1, H, W), where the three channels are x, y, z coordinates (here z coordinates are all 0). 359 | radius: The radius of the mapped circle, default 1.0. 360 | alpha: Value range [0,1], represents the weight of the normalized original angle; default 0.5. 361 | 362 | Returns: 363 | A tensor of the same shape as the input tensor, where each point on the plane is mapped to the edge of the circle, and the z coordinate remains unchanged. 
364 | """ 365 | # Extract x, y, z components; here x and y are tensors of shape (H, W) 366 | x = tensor[0, 0] 367 | y = tensor[1, 0] 368 | z = tensor[2, 0] # z coordinates are preserved 369 | H, W = x.shape 370 | 371 | def get_norm_theta(): 372 | # Method 1: Calculate angle using original coordinates, then linearly normalize to [0, 2π] 373 | theta_orig = torch.atan2(y, x) 374 | theta_min = theta_orig.min() 375 | theta_max = theta_orig.max() 376 | theta_range = theta_max - theta_min 377 | if theta_range > 0: 378 | theta_uniform = (theta_orig - theta_min) / theta_range * (2 * math.pi) 379 | else: 380 | # Handle cases with a single point or all points collinear through origin 381 | theta_uniform = theta_orig 382 | return theta_uniform 383 | 384 | def get_index_theta(): 385 | # Method 2: Generate uniformly distributed angles based on grid indices 386 | indices = torch.arange( 387 | H * W, dtype=torch.float32, device=tensor.device 388 | ).reshape(H, W) 389 | theta_uniform = indices / (H * W) * (2 * math.pi) 390 | return theta_uniform 391 | 392 | # The larger alpha is, the closer it is to the normalized original angle. 393 | # When alpha=0, the grid index method is fully used. 394 | # When alpha=1, the original coordinate calculation angle is fully used. 395 | theta_norm = get_norm_theta() 396 | theta_index = get_index_theta() 397 | # Combine the two methods for calculating theta based on alpha 398 | theta_uniform = alpha * theta_norm + (1 - alpha) * theta_index 399 | 400 | # Generate mapped x, y coordinates based on the calculated uniform angle 401 | new_x = radius * torch.cos(theta_uniform) 402 | new_y = radius * torch.sin(theta_uniform) 403 | 404 | # Combine the three channels and maintain the shape as (3, 1, H, W) 405 | new_tensor = torch.stack([new_x, new_y, z], dim=0).unsqueeze(1) 406 | 407 | return new_tensor 408 | --------------------------------------------------------------------------------