├── OdysseyAgent
│   ├── config.json
│   ├── configuration_qwen.py
│   ├── generation_config.json
│   ├── modeling_qwen.py
│   ├── pytorch_model.bin.index.json
│   ├── qwen.tiktoken
│   ├── qwen_generation_utils.py
│   ├── special_tokens_map.json
│   ├── tokenization_qwen.py
│   ├── tokenizer_config.json
│   └── visual.py
├── Quickstart.md
├── README.md
├── assets
│   ├── dataset_overview.jpg
│   ├── pipeline.jpg
│   └── pipeline.png
├── data
│   ├── format_converter.py
│   └── preprocessing.py
├── introduction.md
└── src
    ├── SimSun.ttf
    ├── data_loader.py
    ├── eval_mm
    │   ├── GUIOdyssey_action_matching.py
    │   └── evaluate_GUIOdyssey.py
    ├── finetune.py
    ├── finetune
    │   ├── ds_config_zero1.json
    │   ├── ds_config_zero2.json
    │   ├── ds_config_zero3.json
    │   ├── finetune_ds.sh
    │   ├── finetune_lora_ds.sh
    │   ├── finetune_lora_single_gpu.sh
    │   ├── finetune_qlora_ds.sh
    │   └── finetune_qlora_single_gpu.sh
    ├── merge_weight.py
    ├── qwen_generation_utils.py
    ├── requirements.txt
    └── script
        ├── eval.sh
        └── train.sh

/OdysseyAgent/config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "_name_or_path": "OdysseyAgent",
 3 |   "architectures": [
 4 |     "QWenLMHeadModel"
 5 |   ],
 6 |   "attn_dropout_prob": 0.0,
 7 |   "auto_map": {
 8 |     "AutoConfig": "configuration_qwen.QWenConfig",
 9 |     "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
10 |   },
11 |   "bf16": true,
12 |   "emb_dropout_prob": 0.0,
13 |   "fp16": false,
14 |   "fp32": false,
15 |   "hidden_size": 4096,
16 |   "his_len": 4,
17 |   "initializer_range": 0.02,
18 |   "intermediate_size": 22016,
19 |   "kv_channels": 128,
20 |   "layer_norm_epsilon": 1e-06,
21 |   "max_position_embeddings": 8192,
22 |   "model_type": "qwen",
23 |   "no_bias": true,
24 |   "num_attention_heads": 32,
25 |   "num_hidden_layers": 32,
26 |   "onnx_safe": null,
27 |   "rotary_emb_base": 10000,
28 |   "rotary_pct": 1.0,
29 |   "scale_attn_weights": true,
30 |   "seq_length": 2048,
31 |   "tie_word_embeddings": false,
32 |   "tokenizer_type": "QWenTokenizer",
33 |   "torch_dtype": "bfloat16",
34 |   "transformers_version": "4.32.0",
35 |   "use_cache": true,
36 |   "use_dynamic_ntk": true,
37 |   "use_flash_attn": false,
38 |   "use_logn_attn": true,
39 |   "visual": {
40 |     "heads": 16,
41 |     "image_size": 448,
42 |     "image_start_id": 151857,
43 |     "layers": 48,
44 |     "mlp_ratio": 4.9231,
45 |     "output_dim": 4096,
46 |     "patch_size": 14,
47 |     "width": 1664
48 |   },
49 |   "vocab_size": 151936
50 | }
51 | 
--------------------------------------------------------------------------------
/OdysseyAgent/configuration_qwen.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Alibaba Cloud.
 2 | #
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
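# NOTE: usage sketch (annotation added for readability; the local path and
# kwargs below are assumptions, not part of the original file). Because
# config.json maps AutoConfig / AutoModelForCausalLM to the modules in this
# folder via "auto_map", the checkpoint can be loaded with trust_remote_code:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("./OdysseyAgent", trust_remote_code=True)
#     model = AutoModelForCausalLM.from_pretrained("./OdysseyAgent", trust_remote_code=True, bf16=True)
#
# OdysseyAgent-specific fields such as "his_len" (how many history screenshots
# the HisResampler in modeling_qwen.py compresses into a single image slot)
# are not declared in QWenConfig below; they arrive through **kwargs and are
# attached as attributes by PretrainedConfig.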
5 | 6 | from transformers import PretrainedConfig 7 | 8 | 9 | class QWenConfig(PretrainedConfig): 10 | model_type = "qwen" 11 | keys_to_ignore_at_inference = ["past_key_values"] 12 | 13 | def __init__( 14 | self, 15 | vocab_size=151936, 16 | hidden_size=4096, 17 | num_hidden_layers=32, 18 | num_attention_heads=32, 19 | emb_dropout_prob=0.0, 20 | attn_dropout_prob=0.0, 21 | layer_norm_epsilon=1e-6, 22 | initializer_range=0.02, 23 | max_position_embeddings=8192, 24 | scale_attn_weights=True, 25 | use_cache=True, 26 | bf16=False, 27 | fp16=False, 28 | fp32=False, 29 | kv_channels=128, 30 | rotary_pct=1.0, 31 | rotary_emb_base=10000, 32 | use_dynamic_ntk=True, 33 | use_logn_attn=True, 34 | use_flash_attn="auto", 35 | intermediate_size=22016, 36 | no_bias=True, 37 | tie_word_embeddings=False, 38 | **kwargs, 39 | ): 40 | self.vocab_size = vocab_size 41 | self.hidden_size = hidden_size 42 | self.intermediate_size = intermediate_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.num_attention_heads = num_attention_heads 45 | self.emb_dropout_prob = emb_dropout_prob 46 | self.attn_dropout_prob = attn_dropout_prob 47 | self.layer_norm_epsilon = layer_norm_epsilon 48 | self.initializer_range = initializer_range 49 | self.scale_attn_weights = scale_attn_weights 50 | self.use_cache = use_cache 51 | self.max_position_embeddings = max_position_embeddings 52 | self.bf16 = bf16 53 | self.fp16 = fp16 54 | self.fp32 = fp32 55 | self.kv_channels = kv_channels 56 | self.rotary_pct = rotary_pct 57 | self.rotary_emb_base = rotary_emb_base 58 | self.use_dynamic_ntk = use_dynamic_ntk 59 | self.use_logn_attn = use_logn_attn 60 | self.use_flash_attn = use_flash_attn 61 | self.no_bias = no_bias 62 | super().__init__( 63 | tie_word_embeddings=tie_word_embeddings, 64 | **kwargs 65 | ) 66 | -------------------------------------------------------------------------------- /OdysseyAgent/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "transformers_version": "4.32.0" 4 | } 5 | -------------------------------------------------------------------------------- /OdysseyAgent/modeling_qwen.py: -------------------------------------------------------------------------------- 1 | print('OdysseyAgent') 2 | 3 | import importlib 4 | import math 5 | from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator 6 | import os, json 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from torch.cuda.amp import autocast 12 | 13 | from torch.nn import CrossEntropyLoss 14 | from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList 15 | from transformers.generation.logits_process import LogitsProcessorList 16 | 17 | if TYPE_CHECKING: 18 | from transformers.generation.streamers import BaseStreamer 19 | from transformers.generation.utils import GenerateOutput 20 | from transformers.modeling_outputs import ( 21 | BaseModelOutputWithPast, 22 | CausalLMOutputWithPast, 23 | ) 24 | from transformers.modeling_utils import PreTrainedModel 25 | from transformers.utils import logging 26 | 27 | try: 28 | from einops import rearrange 29 | except ImportError: 30 | rearrange = None 31 | from torch import nn 32 | 33 | SUPPORT_CUDA = torch.cuda.is_available() 34 | SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() 35 | SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 36 | 37 | from torch.nn.init 
import trunc_normal_ 38 | import sys 39 | sys.path.append('../OdysseyAgent') 40 | 41 | from configuration_qwen import QWenConfig 42 | from qwen_generation_utils import ( 43 | HistoryType, 44 | make_context, 45 | decode_tokens, 46 | get_stop_words_ids, 47 | StopWordsLogitsProcessor, 48 | ) 49 | from visual import VisionTransformer 50 | 51 | IMAGE_HISTORY = '../data/his_index.json' 52 | 53 | USE_RESAMPLER = True 54 | 55 | print(IMAGE_HISTORY) 56 | logger = logging.get_logger(__name__) 57 | 58 | _CHECKPOINT_FOR_DOC = "qwen" 59 | _CONFIG_FOR_DOC = "QWenConfig" 60 | 61 | QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] 62 | 63 | _ERROR_BAD_CHAT_FORMAT = """\ 64 | We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". 65 | If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). 66 | 我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 67 | 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 68 | """ 69 | 70 | _SENTINEL = object() 71 | _ERROR_STREAM_IN_CHAT = """\ 72 | Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). 73 | 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 74 | """ 75 | 76 | apply_rotary_emb_func = None 77 | rms_norm = None 78 | 79 | 80 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 81 | def _make_causal_mask( 82 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 83 | ): 84 | """ 85 | Make causal mask used for bi-directional self-attention. 86 | """ 87 | bsz, tgt_len = input_ids_shape 88 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 89 | mask_cond = torch.arange(mask.size(-1), device=device) 90 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 91 | mask = mask.to(dtype) 92 | 93 | if past_key_values_length > 0: 94 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 95 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 96 | 97 | 98 | # Copied from transformers.models.bart.modeling_bart._expand_mask 99 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 100 | """ 101 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
102 | """ 103 | bsz, src_len = mask.size() 104 | tgt_len = tgt_len if tgt_len is not None else src_len 105 | 106 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 107 | 108 | inverted_mask = 1.0 - expanded_mask 109 | 110 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 111 | 112 | def get_abs_pos(abs_pos, tgt_size): 113 | # abs_pos: L, C 114 | # tgt_size: M 115 | # return: M, C 116 | src_size = int(math.sqrt(abs_pos.size(0))) 117 | tgt_size = int(math.sqrt(tgt_size)) 118 | dtype = abs_pos.dtype 119 | 120 | if src_size != tgt_size: 121 | return F.interpolate( 122 | abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), 123 | size=(tgt_size, tgt_size), 124 | mode="bicubic", 125 | align_corners=False, 126 | ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) 127 | else: 128 | return abs_pos 129 | 130 | 131 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 132 | """ 133 | grid_size: int of the grid height and width 134 | return: 135 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 136 | """ 137 | grid_h = np.arange(grid_size, dtype=np.float32) 138 | grid_w = np.arange(grid_size, dtype=np.float32) 139 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 140 | grid = np.stack(grid, axis=0) 141 | 142 | grid = grid.reshape([2, 1, grid_size, grid_size]) 143 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 144 | if cls_token: 145 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 146 | return pos_embed 147 | 148 | 149 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 150 | assert embed_dim % 2 == 0 151 | 152 | # use half of dimensions to encode grid_h 153 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 154 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 155 | 156 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 157 | return emb 158 | 159 | 160 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 161 | """ 162 | embed_dim: output dimension for each position 163 | pos: a list of positions to be encoded: size (M,) 164 | out: (M, D) 165 | """ 166 | assert embed_dim % 2 == 0 167 | omega = np.arange(embed_dim // 2, dtype=np.float32) 168 | omega /= embed_dim / 2. 169 | omega = 1. 
/ 10000**omega # (D/2,) 170 | 171 | pos = pos.reshape(-1) # (M,) 172 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 173 | 174 | emb_sin = np.sin(out) # (M, D/2) 175 | emb_cos = np.cos(out) # (M, D/2) 176 | 177 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 178 | return emb 179 | 180 | 181 | 182 | class HisResampler(nn.Module): 183 | def __init__( 184 | self, 185 | embed_dim=4096, 186 | num_heads=32, 187 | grid_size=16, 188 | kv_dim=None, 189 | norm_layer=nn.LayerNorm 190 | ): 191 | super().__init__() 192 | self.num_queries = grid_size ** 2 193 | self.embed_dim = embed_dim 194 | self.num_heads = num_heads 195 | 196 | self.pos_embed = nn.Parameter( 197 | torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() 198 | ).requires_grad_(False) 199 | 200 | self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) 201 | trunc_normal_(self.query, std=.02) 202 | 203 | if kv_dim is not None and kv_dim != embed_dim: 204 | self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) 205 | else: 206 | self.kv_proj = nn.Identity() 207 | 208 | self.attn = nn.MultiheadAttention(embed_dim, num_heads) 209 | self.ln_q = norm_layer(embed_dim) 210 | self.ln_kv = norm_layer(embed_dim) 211 | 212 | self.ln_post = norm_layer(embed_dim) 213 | self.proj = nn.Parameter((embed_dim** -0.5) * torch.randn(embed_dim, embed_dim)) 214 | 215 | self.apply(self._init_weights) 216 | 217 | def _init_weights(self, m): 218 | if isinstance(m, nn.Linear): 219 | trunc_normal_(m.weight, std=.02) 220 | if isinstance(m, nn.Linear) and m.bias is not None: 221 | nn.init.constant_(m.bias, 0) 222 | elif isinstance(m, nn.LayerNorm): 223 | nn.init.constant_(m.bias, 0) 224 | nn.init.constant_(m.weight, 1.0) 225 | 226 | def forward(self, x, attn_mask=None): 227 | 228 | x = self.kv_proj(x) 229 | x = self.ln_kv(x).permute(1, 0, 2) 230 | 231 | N = x.shape[1] 232 | q = self.ln_q(self.query) 233 | out = self.attn( 234 | self._repeat(q, N), 235 | x, 236 | x, 237 | attn_mask=attn_mask)[0] 238 | out = out.permute(1, 0, 2) 239 | out = self.ln_post(out) 240 | out = out @ self.proj 241 | return out 242 | 243 | def _repeat(self, query, N: int): 244 | return query.unsqueeze(1).repeat(1, N, 1) 245 | 246 | 247 | 248 | class QWenAttention(nn.Module): 249 | def __init__(self, config): 250 | super().__init__() 251 | 252 | self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) 253 | self.seq_length = config.seq_length 254 | 255 | self.hidden_size = config.hidden_size 256 | self.split_size = config.hidden_size 257 | self.num_heads = config.num_attention_heads 258 | self.head_dim = self.hidden_size // self.num_heads 259 | 260 | self.scale_attn_weights = True 261 | 262 | self.projection_size = config.kv_channels * config.num_attention_heads 263 | 264 | assert self.projection_size % config.num_attention_heads == 0 265 | self.hidden_size_per_attention_head = ( 266 | self.projection_size // config.num_attention_heads 267 | ) 268 | 269 | self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) 270 | 271 | self.c_proj = nn.Linear( 272 | config.hidden_size, self.projection_size, bias=not config.no_bias 273 | ) 274 | 275 | self.is_fp32 = not (config.bf16 or config.fp16) 276 | self.bf16 = config.bf16 277 | 278 | self.use_dynamic_ntk = config.use_dynamic_ntk 279 | self.use_logn_attn = config.use_logn_attn 280 | 281 | logn_list = [ 282 | math.log(i, self.seq_length) if i > self.seq_length else 1 283 | for i in range(1, 32768) 284 | ] 285 | self.logn_tensor = torch.tensor(logn_list)[None, :, 
None, None] 286 | 287 | self.attn_dropout = nn.Dropout(config.attn_dropout_prob) 288 | 289 | def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None): 290 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 291 | 292 | if self.scale_attn_weights: 293 | attn_weights = attn_weights / torch.full( 294 | [], 295 | value.size(-1) ** 0.5, 296 | dtype=attn_weights.dtype, 297 | device=attn_weights.device, 298 | ) 299 | 300 | query_length, key_length = query.size(-2), key.size(-2) 301 | # causal_mask = self.bias[ 302 | # :, :, key_length - query_length : key_length, :key_length 303 | # ] 304 | # mask_value = torch.finfo(attn_weights.dtype).min 305 | # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( 306 | # attn_weights.device 307 | # ) 308 | # attn_weights = torch.where( 309 | # causal_mask, attn_weights.to(attn_weights.dtype), mask_value 310 | # ) 311 | attn_weights = attn_weights + attention_mask 312 | 313 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 314 | 315 | attn_weights = attn_weights.type(value.dtype) 316 | attn_weights = self.attn_dropout(attn_weights) 317 | 318 | if head_mask is not None: 319 | attn_weights = attn_weights * head_mask 320 | 321 | attn_output = torch.matmul(attn_weights, value) 322 | attn_output = attn_output.transpose(1, 2) 323 | 324 | return attn_output, attn_weights 325 | 326 | def _upcast_and_reordered_attn( 327 | self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None 328 | ): 329 | bsz, num_heads, q_seq_len, dk = query.size() 330 | _, _, k_seq_len, _ = key.size() 331 | 332 | attn_weights = torch.empty( 333 | bsz * num_heads, 334 | q_seq_len, 335 | k_seq_len, 336 | dtype=torch.float32, 337 | device=query.device, 338 | ) 339 | 340 | scale_factor = 1.0 341 | if self.scale_attn_weights: 342 | scale_factor /= float(value.size(-1)) ** 0.5 343 | 344 | with autocast(enabled=False): 345 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( 346 | -1, dk, k_seq_len 347 | ) 348 | attn_weights = torch.baddbmm( 349 | attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor 350 | ) 351 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 352 | 353 | query_length, key_length = query.size(-2), key.size(-2) 354 | causal_mask = registered_causal_mask[ 355 | :, :, key_length - query_length : key_length, :key_length 356 | ] 357 | mask_value = torch.finfo(attn_weights.dtype).min 358 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( 359 | attn_weights.device 360 | ) 361 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 362 | 363 | if attention_mask is not None: 364 | attn_weights = attn_weights + attention_mask 365 | 366 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 367 | 368 | if attn_weights.dtype != torch.float32: 369 | raise RuntimeError( 370 | "Error with upcasting, attn_weights does not have dtype torch.float32" 371 | ) 372 | attn_weights = attn_weights.type(value.dtype) 373 | attn_weights = self.attn_dropout(attn_weights) 374 | 375 | if head_mask is not None: 376 | attn_weights = attn_weights * head_mask 377 | 378 | attn_output = torch.matmul(attn_weights, value) 379 | 380 | return attn_output, attn_weights 381 | 382 | def _split_heads(self, tensor, num_heads, attn_head_size): 383 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 384 | tensor = tensor.view(new_shape) 385 | return tensor 386 | 387 | def _merge_heads(self, tensor, num_heads, attn_head_size): 388 
| tensor = tensor.contiguous() 389 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) 390 | return tensor.view(new_shape) 391 | 392 | def forward( 393 | self, 394 | hidden_states: Optional[Tuple[torch.FloatTensor]], 395 | rotary_pos_emb: Optional[List[torch.Tensor]] = None, 396 | registered_causal_mask: Optional[torch.Tensor] = None, 397 | layer_past: Optional[Tuple[torch.Tensor]] = None, 398 | attention_mask: Optional[torch.FloatTensor] = None, 399 | head_mask: Optional[torch.FloatTensor] = None, 400 | encoder_hidden_states: Optional[torch.Tensor] = None, 401 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 402 | output_attentions: Optional[bool] = False, 403 | use_cache: Optional[bool] = False, 404 | ): 405 | 406 | mixed_x_layer = self.c_attn(hidden_states) 407 | 408 | query, key, value = mixed_x_layer.split(self.split_size, dim=2) 409 | 410 | query = self._split_heads(query, self.num_heads, self.head_dim) 411 | key = self._split_heads(key, self.num_heads, self.head_dim) 412 | value = self._split_heads(value, self.num_heads, self.head_dim) 413 | 414 | if rotary_pos_emb is not None: 415 | cur_len = query.shape[1] 416 | rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] 417 | rotary_pos_emb = (rotary_pos_emb,) * 2 418 | q_pos_emb, k_pos_emb = rotary_pos_emb 419 | # Slice the pos emb for current inference 420 | query = apply_rotary_pos_emb(query, q_pos_emb) 421 | key = apply_rotary_pos_emb(key, k_pos_emb) 422 | 423 | if layer_past is not None: 424 | past_key, past_value = layer_past[0], layer_past[1] 425 | key = torch.cat((past_key, key), dim=1) 426 | value = torch.cat((past_value, value), dim=1) 427 | 428 | if use_cache: 429 | present = (key, value) 430 | else: 431 | present = None 432 | 433 | if self.use_logn_attn and not self.training: 434 | if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype: 435 | self.logn_tensor = self.logn_tensor.to(query.device).type_as(query) 436 | seq_start = key.size(1) - query.size(1) 437 | seq_end = key.size(1) 438 | logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] 439 | query = query * logn_tensor.expand_as(query) 440 | 441 | query = query.permute(0, 2, 1, 3) 442 | key = key.permute(0, 2, 1, 3) 443 | value = value.permute(0, 2, 1, 3) 444 | attn_output, attn_weight = self._attn( 445 | query, key, value, registered_causal_mask, attention_mask, head_mask 446 | ) 447 | context_layer = self._merge_heads( 448 | attn_output, self.num_heads, self.head_dim 449 | ) 450 | 451 | attn_output = self.c_proj(context_layer) 452 | 453 | outputs = (attn_output, present) 454 | if output_attentions: 455 | outputs += (attn_weight,) 456 | 457 | return outputs 458 | 459 | 460 | class QWenMLP(nn.Module): 461 | def __init__(self, config): 462 | super().__init__() 463 | self.w1 = nn.Linear( 464 | config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias 465 | ) 466 | self.w2 = nn.Linear( 467 | config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias 468 | ) 469 | ff_dim_in = config.intermediate_size // 2 470 | self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) 471 | 472 | def forward(self, hidden_states): 473 | a1 = self.w1(hidden_states) 474 | a2 = self.w2(hidden_states) 475 | intermediate_parallel = a1 * F.silu(a2) 476 | output = self.c_proj(intermediate_parallel) 477 | return output 478 | 479 | class QWenBlock(nn.Module): 480 | def __init__(self, config): 481 | super().__init__() 482 | hidden_size = config.hidden_size 483 | self.bf16 = 
config.bf16 484 | 485 | self.ln_1 = RMSNorm( 486 | hidden_size, 487 | eps=config.layer_norm_epsilon, 488 | ) 489 | self.attn = QWenAttention(config) 490 | self.ln_2 = RMSNorm( 491 | hidden_size, 492 | eps=config.layer_norm_epsilon, 493 | ) 494 | 495 | self.mlp = QWenMLP(config) 496 | 497 | def forward( 498 | self, 499 | hidden_states: Optional[Tuple[torch.FloatTensor]], 500 | rotary_pos_emb: Optional[List[torch.Tensor]] = None, 501 | registered_causal_mask: Optional[torch.Tensor] = None, 502 | layer_past: Optional[Tuple[torch.Tensor]] = None, 503 | attention_mask: Optional[torch.FloatTensor] = None, 504 | head_mask: Optional[torch.FloatTensor] = None, 505 | encoder_hidden_states: Optional[torch.Tensor] = None, 506 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 507 | use_cache: Optional[bool] = False, 508 | output_attentions: Optional[bool] = False, 509 | ): 510 | layernorm_output = self.ln_1(hidden_states) 511 | 512 | attn_outputs = self.attn( 513 | layernorm_output, 514 | rotary_pos_emb, 515 | registered_causal_mask=registered_causal_mask, 516 | layer_past=layer_past, 517 | attention_mask=attention_mask, 518 | head_mask=head_mask, 519 | use_cache=use_cache, 520 | output_attentions=output_attentions, 521 | ) 522 | attn_output = attn_outputs[0] 523 | 524 | outputs = attn_outputs[1:] 525 | 526 | residual = hidden_states 527 | layernorm_input = attn_output + residual 528 | 529 | layernorm_output = self.ln_2(layernorm_input) 530 | 531 | residual = layernorm_input 532 | mlp_output = self.mlp(layernorm_output) 533 | hidden_states = residual + mlp_output 534 | 535 | if use_cache: 536 | outputs = (hidden_states,) + outputs 537 | else: 538 | outputs = (hidden_states,) + outputs[1:] 539 | 540 | return outputs 541 | 542 | 543 | class QWenPreTrainedModel(PreTrainedModel): 544 | config_class = QWenConfig 545 | base_model_prefix = "transformer" 546 | is_parallelizable = False 547 | supports_gradient_checkpointing = True 548 | _no_split_modules = ["QWenBlock"] 549 | 550 | def __init__(self, *inputs, **kwargs): 551 | super().__init__(*inputs, **kwargs) 552 | 553 | def _init_weights(self, module): 554 | """Initialize the weights.""" 555 | if isinstance(module, nn.Linear): 556 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 557 | if module.bias is not None: 558 | module.bias.data.zero_() 559 | elif isinstance(module, nn.Embedding): 560 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 561 | if module.padding_idx is not None: 562 | module.weight.data[module.padding_idx].zero_() 563 | elif isinstance(module, RMSNorm): 564 | module.weight.data.fill_(1.0) 565 | 566 | for name, p in module.named_parameters(): 567 | if name == "c_proj.weight": 568 | p.data.normal_( 569 | mean=0.0, 570 | std=( 571 | self.config.initializer_range 572 | / math.sqrt(2 * self.config.num_hidden_layers) 573 | ), 574 | ) 575 | 576 | def _set_gradient_checkpointing(self, module, value=False): 577 | if isinstance(module, QWenModel): 578 | module.gradient_checkpointing = value 579 | 580 | 581 | class QWenModel(QWenPreTrainedModel): 582 | _keys_to_ignore_on_load_missing = ["attn.masked_bias"] 583 | 584 | def __init__(self, config): 585 | super().__init__(config) 586 | self.his_len = config.his_len 587 | self.vocab_size = config.vocab_size 588 | self.num_hidden_layers = config.num_hidden_layers 589 | self.embed_dim = config.hidden_size 590 | 591 | self.gradient_checkpointing = False 592 | self.use_dynamic_ntk = config.use_dynamic_ntk 593 | self.seq_length = config.seq_length 594 
| 595 | self.wte = nn.Embedding(self.vocab_size, self.embed_dim) 596 | 597 | self.drop = nn.Dropout(config.emb_dropout_prob) 598 | 599 | if config.rotary_pct == 1.0: 600 | self.rotary_ndims = None 601 | else: 602 | assert config.rotary_pct < 1 603 | self.rotary_ndims = int( 604 | config.kv_channels * config.rotary_pct 605 | ) 606 | dim = ( 607 | self.rotary_ndims 608 | if self.rotary_ndims is not None 609 | else config.kv_channels 610 | ) 611 | self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) 612 | 613 | self.use_flash_attn = config.use_flash_attn 614 | self.is_fp32 = not (config.bf16 or config.fp16) 615 | self.registered_causal_mask = None 616 | 617 | self.h = nn.ModuleList( 618 | [ 619 | QWenBlock( 620 | config 621 | ) 622 | for i in range(config.num_hidden_layers) 623 | ] 624 | ) 625 | self.ln_f = RMSNorm( 626 | self.embed_dim, 627 | eps=config.layer_norm_epsilon, 628 | ) 629 | 630 | self.visual = VisionTransformer(**config.visual) 631 | 632 | self.post_init() 633 | 634 | if USE_RESAMPLER: 635 | print('init RESAMPLER') 636 | self.his_resampler = HisResampler() 637 | 638 | self.imgtoken_dict = {} 639 | if os.path.isdir(IMAGE_HISTORY): 640 | for subdata in os.listdir(IMAGE_HISTORY): 641 | sub_img_dict = json.load(open(os.path.join(IMAGE_HISTORY, subdata))) 642 | self.imgtoken_dict.update(sub_img_dict) 643 | else: 644 | self.imgtoken_dict = json.load(open(IMAGE_HISTORY)) 645 | 646 | print('imgtoken_dict cache len:', len(self.imgtoken_dict)) 647 | 648 | def set_input_embeddings(self, new_embeddings): 649 | self.wte = new_embeddings 650 | 651 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 652 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 653 | # create causal mask 654 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 655 | combined_attention_mask = None 656 | if input_shape[-1] > 1: 657 | combined_attention_mask = _make_causal_mask( 658 | input_shape, 659 | inputs_embeds.dtype, 660 | device=inputs_embeds.device, 661 | past_key_values_length=past_key_values_length, 662 | ) 663 | 664 | if attention_mask is not None: 665 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 666 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 667 | inputs_embeds.device 668 | ) 669 | combined_attention_mask = ( 670 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 671 | ) 672 | 673 | return combined_attention_mask 674 | 675 | 676 | def forward( 677 | self, 678 | input_ids: Optional[torch.LongTensor] = None, 679 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 680 | attention_mask: Optional[torch.FloatTensor] = None, 681 | token_type_ids: Optional[torch.LongTensor] = None, 682 | position_ids: Optional[torch.LongTensor] = None, 683 | head_mask: Optional[torch.FloatTensor] = None, 684 | inputs_embeds: Optional[torch.FloatTensor] = None, 685 | encoder_hidden_states: Optional[torch.Tensor] = None, 686 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 687 | use_cache: Optional[bool] = None, 688 | output_attentions: Optional[bool] = None, 689 | output_hidden_states: Optional[bool] = None, 690 | return_dict: Optional[bool] = None, 691 | ): 692 | device = input_ids.device if input_ids is not None else inputs_embeds.device 693 | 694 | if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']): 695 | bos_pos = 
torch.where(input_ids == self.config.visual['image_start_id']) 696 | eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1) 697 | assert (bos_pos[0] == eos_pos[0]).all() 698 | img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) 699 | now_images = [] 700 | his_images = [] 701 | C_list = [] 702 | images = [] 703 | his_idx = [] 704 | his_image_temp = [] 705 | for idx, (i, a, b) in enumerate(img_pos): 706 | image = input_ids[i][a + 1 : b - 1].tolist() 707 | image = image[ : image.index(self.config.visual['image_start_id'] + 2)] 708 | image_path = bytes(image).decode('utf-8') 709 | 710 | if image_path.startswith('image-history: '): 711 | his_idx.append(idx) 712 | image_path = image_path.replace('image-history: ', '') 713 | his_list = self.imgtoken_dict[image_path][-self.his_len:] # t0 - tn-1 714 | assert len(his_list) > 0, his_list 715 | 716 | his_images.extend(his_list) 717 | his_image_temp.append(his_list) 718 | 719 | else: 720 | now_images.append(image_path) 721 | 722 | now_images = self.visual.encode(now_images) 723 | 724 | if len(his_images) > 0: 725 | his_images = self.visual.encode(his_images) 726 | his_tkn = None 727 | 728 | start_pos = 0 729 | for his_scr in his_image_temp: 730 | his_len = len(his_scr) 731 | his_img_feature = his_images[start_pos: start_pos + his_len] # [b, l, d] 732 | if USE_RESAMPLER: 733 | his_img_feature = his_img_feature.reshape(1, -1, his_img_feature.size(-1)) 734 | his_vis_tkn = self.his_resampler(his_img_feature) # [l, d] 735 | else: 736 | raise ValueError("You cannot run without History Redsampler!") 737 | his_tkn = his_vis_tkn if his_tkn is None else torch.concat((his_tkn, his_vis_tkn), dim=0) 738 | start_pos += his_len 739 | assert start_pos == len(his_images) 740 | his_images = his_tkn 741 | 742 | now_p, his_p = 0, 0 743 | for j in range(len(img_pos)): 744 | if j not in his_idx: 745 | images.append(now_images[now_p]) 746 | now_p += 1 747 | else: 748 | images.append(his_images[his_p]) 749 | his_p += 1 750 | images = torch.stack(images, dim=0) 751 | assert len(images) == len(img_pos) == len(now_images) + len(his_images) 752 | 753 | fake_images = None 754 | elif self.training: 755 | fake_images=torch.zeros(1,3,224,224).to( 756 | dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device) 757 | images = self.visual(fake_images) 758 | else: 759 | fake_images = None 760 | images = None 761 | 762 | output_attentions = ( 763 | output_attentions 764 | if output_attentions is not None 765 | else self.config.output_attentions 766 | ) 767 | output_hidden_states = ( 768 | output_hidden_states 769 | if output_hidden_states is not None 770 | else self.config.output_hidden_states 771 | ) 772 | use_cache = use_cache if use_cache is not None else self.config.use_cache 773 | return_dict = ( 774 | return_dict if return_dict is not None else self.config.use_return_dict 775 | ) 776 | 777 | if input_ids is not None and inputs_embeds is not None: 778 | raise ValueError( 779 | "You cannot specify both input_ids and inputs_embeds at the same time" 780 | ) 781 | elif input_ids is not None: 782 | input_shape = input_ids.size() 783 | input_ids = input_ids.view(-1, input_shape[-1]) 784 | batch_size = input_ids.shape[0] 785 | elif inputs_embeds is not None: 786 | input_shape = inputs_embeds.size()[:-1] 787 | batch_size = inputs_embeds.shape[0] 788 | else: 789 | raise ValueError("You have to specify either input_ids or inputs_embeds") 790 | 791 | 792 | if token_type_ids is not None: 793 | token_type_ids = token_type_ids.view(-1, 
input_shape[-1]) 794 | if position_ids is not None: 795 | position_ids = position_ids.view(-1, input_shape[-1]) 796 | 797 | if past_key_values is None: 798 | past_length = 0 799 | past_key_values = tuple([None] * len(self.h)) 800 | else: 801 | past_length = past_key_values[0][0].size(-2) 802 | 803 | if position_ids is None: 804 | position_ids = torch.arange( 805 | past_length, 806 | input_shape[-1] + past_length, 807 | dtype=torch.long, 808 | device=device, 809 | ) 810 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 811 | 812 | encoder_attention_mask = None 813 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 814 | 815 | if inputs_embeds is None: 816 | inputs_embeds = self.wte(input_ids) 817 | 818 | if batch_size <= 0: 819 | raise ValueError("batch_size has to be defined and > 0") 820 | attention_mask = self._prepare_decoder_attention_mask( 821 | attention_mask, input_shape, inputs_embeds, past_length 822 | ) 823 | 824 | hidden_states = inputs_embeds 825 | 826 | kv_seq_len = hidden_states.size()[1] 827 | if past_key_values[0] is not None: 828 | # past key values[0][0] shape: bs * seq_len * head_num * dim 829 | kv_seq_len += past_key_values[0][0].shape[1] 830 | if ( 831 | self.use_dynamic_ntk 832 | and kv_seq_len == hidden_states.size()[1] 833 | and not self.training 834 | ): 835 | context_value = math.log(kv_seq_len / self.seq_length, 2) + 1 836 | ntk_alpha = 2 ** math.ceil(context_value) - 1 837 | ntk_alpha = max(ntk_alpha, 1) 838 | else: 839 | ntk_alpha = self.rotary_emb._ntk_alpha_cached 840 | 841 | rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) 842 | for idx in range(len(rotary_pos_emb)): 843 | rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device) 844 | hidden_states = self.drop(hidden_states).clone() 845 | 846 | if fake_images is not None: 847 | hidden_states = hidden_states + images.mean()*0 848 | elif images is not None: 849 | for idx, (i, a, b) in enumerate(img_pos): 850 | hidden_states[i][a + 1 : b] = images[idx] 851 | output_shape = input_shape + (hidden_states.size(-1),) 852 | 853 | if self.gradient_checkpointing and self.training: 854 | if use_cache: 855 | logger.warning_once( 856 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
857 | ) 858 | use_cache = False 859 | 860 | presents = () if use_cache else None 861 | all_self_attentions = () if output_attentions else None 862 | all_hidden_states = () if output_hidden_states else None 863 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 864 | 865 | if output_hidden_states: 866 | all_hidden_states = all_hidden_states + (hidden_states,) 867 | 868 | if self.gradient_checkpointing and self.training: 869 | 870 | def create_custom_forward(module): 871 | def custom_forward(*inputs): 872 | # None for past_key_value 873 | return module(*inputs, use_cache, output_attentions) 874 | 875 | return custom_forward 876 | 877 | outputs = torch.utils.checkpoint.checkpoint( 878 | create_custom_forward(block), 879 | hidden_states, 880 | rotary_pos_emb, 881 | self.registered_causal_mask, 882 | None, 883 | attention_mask, 884 | head_mask[i], 885 | encoder_hidden_states, 886 | encoder_attention_mask, 887 | ) 888 | else: 889 | outputs = block( 890 | hidden_states, 891 | layer_past=layer_past, 892 | rotary_pos_emb=rotary_pos_emb, 893 | registered_causal_mask=self.registered_causal_mask, 894 | attention_mask=attention_mask, 895 | head_mask=head_mask[i], 896 | encoder_hidden_states=encoder_hidden_states, 897 | encoder_attention_mask=encoder_attention_mask, 898 | use_cache=use_cache, 899 | output_attentions=output_attentions, 900 | ) 901 | 902 | hidden_states = outputs[0] 903 | if use_cache is True: 904 | presents = presents + (outputs[1],) 905 | 906 | if output_attentions: 907 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 908 | 909 | hidden_states = self.ln_f(hidden_states) 910 | hidden_states = hidden_states.view(output_shape) 911 | # Add last hidden state 912 | if output_hidden_states: 913 | all_hidden_states = all_hidden_states + (hidden_states,) 914 | 915 | if not return_dict: 916 | return tuple( 917 | v for v in [hidden_states, presents, all_hidden_states] if v is not None 918 | ) 919 | 920 | return BaseModelOutputWithPast( 921 | last_hidden_state=hidden_states, 922 | past_key_values=presents, 923 | hidden_states=all_hidden_states, 924 | attentions=all_self_attentions, 925 | ) 926 | 927 | 928 | class QWenLMHeadModel(QWenPreTrainedModel): 929 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] 930 | _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] 931 | 932 | def __init__(self, config): 933 | super().__init__(config) 934 | assert ( 935 | config.bf16 + config.fp16 + config.fp32 <= 1 936 | ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" 937 | 938 | autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 939 | 940 | if autoset_precision: 941 | if SUPPORT_BF16: 942 | logger.warn( 943 | "The model is automatically converting to bf16 for faster inference. " 944 | "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 945 | ) 946 | config.bf16 = True 947 | elif SUPPORT_FP16: 948 | logger.warn( 949 | "The model is automatically converting to fp16 for faster inference. " 950 | "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
951 | ) 952 | config.fp16 = True 953 | else: 954 | config.fp32 = True 955 | 956 | if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: 957 | logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") 958 | if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: 959 | logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") 960 | if config.fp32: 961 | if SUPPORT_BF16: 962 | logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") 963 | elif SUPPORT_FP16: 964 | logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") 965 | 966 | self.transformer = QWenModel(config) 967 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 968 | 969 | if config.bf16: 970 | self.transformer.bfloat16() 971 | self.lm_head.bfloat16() 972 | if config.fp16: 973 | self.transformer.half() 974 | self.lm_head.half() 975 | self.post_init() 976 | 977 | def get_output_embeddings(self): 978 | return self.lm_head 979 | 980 | def set_output_embeddings(self, new_embeddings): 981 | self.lm_head = new_embeddings 982 | 983 | def prepare_inputs_for_generation( 984 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 985 | ): 986 | token_type_ids = kwargs.get("token_type_ids", None) 987 | if past_key_values: 988 | input_ids = input_ids[:, -1].unsqueeze(-1) 989 | if token_type_ids is not None: 990 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 991 | 992 | attention_mask = kwargs.get("attention_mask", None) 993 | position_ids = kwargs.get("position_ids", None) 994 | 995 | if attention_mask is not None and position_ids is None: 996 | position_ids = attention_mask.long().cumsum(-1) - 1 997 | position_ids.masked_fill_(attention_mask == 0, 1) 998 | if past_key_values: 999 | position_ids = position_ids[:, -1].unsqueeze(-1) 1000 | else: 1001 | position_ids = None 1002 | 1003 | if inputs_embeds is not None and past_key_values is None: 1004 | model_inputs = {"inputs_embeds": inputs_embeds} 1005 | else: 1006 | model_inputs = {"input_ids": input_ids} 1007 | 1008 | model_inputs.update( 1009 | { 1010 | "past_key_values": past_key_values, 1011 | "use_cache": kwargs.get("use_cache"), 1012 | "position_ids": position_ids, 1013 | "attention_mask": attention_mask, 1014 | "token_type_ids": token_type_ids, 1015 | } 1016 | ) 1017 | return model_inputs 1018 | 1019 | def forward( 1020 | self, 1021 | input_ids: Optional[torch.LongTensor] = None, 1022 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1023 | attention_mask: Optional[torch.FloatTensor] = None, 1024 | token_type_ids: Optional[torch.LongTensor] = None, 1025 | position_ids: Optional[torch.LongTensor] = None, 1026 | head_mask: Optional[torch.FloatTensor] = None, 1027 | inputs_embeds: Optional[torch.FloatTensor] = None, 1028 | encoder_hidden_states: Optional[torch.Tensor] = None, 1029 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 1030 | labels: Optional[torch.LongTensor] = None, 1031 | use_cache: Optional[bool] = None, 1032 | output_attentions: Optional[bool] = None, 1033 | output_hidden_states: Optional[bool] = None, 1034 | return_dict: Optional[bool] = None, 1035 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1036 | 1037 | return_dict = ( 1038 | return_dict if return_dict is not None else self.config.use_return_dict 1039 | ) 1040 | 
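        # The multimodal QWenModel below replaces every image span marked by
        # config.visual["image_start_id"] with visual features from the ViT
        # encoder (or from the HisResampler for "image-history:" entries);
        # lm_head then maps the hidden states to vocabulary logits and, when
        # `labels` are provided, a shift-by-one cross-entropy loss is computed.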
1041 | transformer_outputs = self.transformer( 1042 | input_ids, 1043 | past_key_values=past_key_values, 1044 | attention_mask=attention_mask, 1045 | token_type_ids=token_type_ids, 1046 | position_ids=position_ids, 1047 | head_mask=head_mask, 1048 | inputs_embeds=inputs_embeds, 1049 | encoder_hidden_states=encoder_hidden_states, 1050 | encoder_attention_mask=encoder_attention_mask, 1051 | use_cache=use_cache, 1052 | output_attentions=output_attentions, 1053 | output_hidden_states=output_hidden_states, 1054 | return_dict=return_dict, 1055 | ) 1056 | hidden_states = transformer_outputs[0] 1057 | 1058 | lm_logits = self.lm_head(hidden_states) 1059 | 1060 | loss = None 1061 | if labels is not None: 1062 | labels = labels.to(lm_logits.device) 1063 | shift_logits = lm_logits[..., :-1, :].contiguous() 1064 | shift_labels = labels[..., 1:].contiguous() 1065 | loss_fct = CrossEntropyLoss() 1066 | loss = loss_fct( 1067 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) 1068 | ) 1069 | 1070 | if not return_dict: 1071 | output = (lm_logits,) + transformer_outputs[1:] 1072 | return ((loss,) + output) if loss is not None else output 1073 | 1074 | return CausalLMOutputWithPast( 1075 | loss=loss, 1076 | logits=lm_logits, 1077 | past_key_values=transformer_outputs.past_key_values, 1078 | hidden_states=transformer_outputs.hidden_states, 1079 | attentions=transformer_outputs.attentions, 1080 | ) 1081 | 1082 | @staticmethod 1083 | def _reorder_cache( 1084 | past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor 1085 | ) -> Tuple[Tuple[torch.Tensor]]: 1086 | 1087 | return tuple( 1088 | tuple( 1089 | past_state.index_select(0, beam_idx.to(past_state.device)) 1090 | for past_state in layer_past 1091 | ) 1092 | for layer_past in past_key_values 1093 | ) 1094 | 1095 | def chat( 1096 | self, 1097 | tokenizer: PreTrainedTokenizer, 1098 | query: str, 1099 | history: Optional[HistoryType], 1100 | system: str = "You are a helpful assistant.", 1101 | append_history: bool = True, 1102 | stream: Optional[bool] = _SENTINEL, 1103 | stop_words_ids: Optional[List[List[int]]] = None, 1104 | generation_config: Optional[GenerationConfig] = None, 1105 | **kwargs, 1106 | ) -> Tuple[str, HistoryType]: 1107 | generation_config = generation_config if generation_config is not None else self.generation_config 1108 | 1109 | assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT 1110 | assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT 1111 | if history is None: 1112 | history = [] 1113 | if stop_words_ids is None: 1114 | stop_words_ids = [] 1115 | 1116 | max_window_size = kwargs.get('max_window_size', None) 1117 | if max_window_size is None: 1118 | max_window_size = generation_config.max_window_size 1119 | raw_text, context_tokens = make_context( 1120 | tokenizer, 1121 | query, 1122 | history=history, 1123 | system=system, 1124 | max_window_size=max_window_size, 1125 | chat_format=generation_config.chat_format, 1126 | ) 1127 | 1128 | stop_words_ids.extend(get_stop_words_ids( 1129 | generation_config.chat_format, tokenizer 1130 | )) 1131 | input_ids = torch.tensor([context_tokens]).to(self.device) 1132 | outputs = self.generate( 1133 | input_ids, 1134 | stop_words_ids=stop_words_ids, 1135 | return_dict_in_generate=False, 1136 | generation_config=generation_config, 1137 | **kwargs, 1138 | ) 1139 | 1140 | response = decode_tokens( 1141 | outputs[0], 1142 | tokenizer, 1143 | raw_text_len=len(raw_text), 1144 | context_length=len(context_tokens), 1145 | 
chat_format=generation_config.chat_format, 1146 | verbose=False, 1147 | errors='replace' 1148 | ) 1149 | 1150 | if append_history: 1151 | history.append((query, response)) 1152 | 1153 | return response, history 1154 | 1155 | def chat_stream( 1156 | self, 1157 | tokenizer: PreTrainedTokenizer, 1158 | query: str, 1159 | history: Optional[HistoryType], 1160 | system: str = "You are a helpful assistant.", 1161 | stop_words_ids: Optional[List[List[int]]] = None, 1162 | logits_processor: Optional[LogitsProcessorList] = None, 1163 | generation_config: Optional[GenerationConfig] = None, 1164 | **kwargs, 1165 | ) -> Generator[str, Any, None]: 1166 | generation_config = generation_config if generation_config is not None else self.generation_config 1167 | assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT 1168 | if history is None: 1169 | history = [] 1170 | if stop_words_ids is None: 1171 | stop_words_ids = [] 1172 | 1173 | max_window_size = kwargs.get('max_window_size', None) 1174 | if max_window_size is None: 1175 | max_window_size = generation_config.max_window_size 1176 | raw_text, context_tokens = make_context( 1177 | tokenizer, 1178 | query, 1179 | history=history, 1180 | system=system, 1181 | max_window_size=max_window_size, 1182 | chat_format=generation_config.chat_format, 1183 | ) 1184 | 1185 | stop_words_ids.extend(get_stop_words_ids( 1186 | generation_config.chat_format, tokenizer 1187 | )) 1188 | if stop_words_ids is not None: 1189 | stop_words_logits_processor = StopWordsLogitsProcessor( 1190 | stop_words_ids=stop_words_ids, 1191 | eos_token_id=generation_config.eos_token_id, 1192 | ) 1193 | if logits_processor is None: 1194 | logits_processor = LogitsProcessorList([stop_words_logits_processor]) 1195 | else: 1196 | logits_processor.append(stop_words_logits_processor) 1197 | input_ids = torch.tensor([context_tokens]).to(self.device) 1198 | 1199 | from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig 1200 | self.__class__.generate_stream = NewGenerationMixin.generate 1201 | self.__class__.sample_stream = NewGenerationMixin.sample_stream 1202 | stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) 1203 | 1204 | def stream_generator(): 1205 | outputs = [] 1206 | for token in self.generate_stream( 1207 | input_ids, 1208 | return_dict_in_generate=False, 1209 | generation_config=stream_config, 1210 | logits_processor=logits_processor, 1211 | seed=-1, 1212 | **kwargs): 1213 | outputs.append(token.item()) 1214 | yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore', keep_image_special=True) 1215 | 1216 | return stream_generator() 1217 | 1218 | def generate( 1219 | self, 1220 | inputs: Optional[torch.Tensor] = None, 1221 | generation_config: Optional[GenerationConfig] = None, 1222 | logits_processor: Optional[LogitsProcessorList] = None, 1223 | stopping_criteria: Optional[StoppingCriteriaList] = None, 1224 | prefix_allowed_tokens_fn: Optional[ 1225 | Callable[[int, torch.Tensor], List[int]] 1226 | ] = None, 1227 | synced_gpus: Optional[bool] = None, 1228 | assistant_model: Optional["PreTrainedModel"] = None, 1229 | streamer: Optional["BaseStreamer"] = None, 1230 | **kwargs, 1231 | ) -> Union[GenerateOutput, torch.LongTensor]: 1232 | generation_config = generation_config if generation_config is not None else self.generation_config 1233 | 1234 | # Process stop_words_ids. 
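        # stop_words_ids may come either from the kwargs or from the
        # generation_config; when present, they are wrapped in a
        # StopWordsLogitsProcessor (defined in qwen_generation_utils.py) and
        # appended to logits_processor before delegating to the parent
        # GenerationMixin.generate().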
1235 | stop_words_ids = kwargs.pop("stop_words_ids", None) 1236 | if stop_words_ids is None and generation_config is not None: 1237 | stop_words_ids = getattr(generation_config, "stop_words_ids", None) 1238 | if stop_words_ids is None: 1239 | stop_words_ids = getattr(generation_config, "stop_words_ids", None) 1240 | 1241 | if stop_words_ids is not None: 1242 | stop_words_logits_processor = StopWordsLogitsProcessor( 1243 | stop_words_ids=stop_words_ids, 1244 | eos_token_id=generation_config.eos_token_id, 1245 | ) 1246 | if logits_processor is None: 1247 | logits_processor = LogitsProcessorList([stop_words_logits_processor]) 1248 | else: 1249 | logits_processor.append(stop_words_logits_processor) 1250 | 1251 | return super().generate( 1252 | inputs, 1253 | generation_config=generation_config, 1254 | logits_processor=logits_processor, 1255 | stopping_criteria=stopping_criteria, 1256 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 1257 | synced_gpus=synced_gpus, 1258 | assistant_model=assistant_model, 1259 | streamer=streamer, 1260 | **kwargs, 1261 | ) 1262 | 1263 | 1264 | class RotaryEmbedding(torch.nn.Module): 1265 | def __init__(self, dim, base=10000): 1266 | super().__init__() 1267 | self.dim = dim 1268 | self.base = base 1269 | self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 1270 | if importlib.util.find_spec("einops") is None: 1271 | raise RuntimeError("einops is required for Rotary Embedding") 1272 | 1273 | self._rotary_pos_emb_cache = None 1274 | self._seq_len_cached = 0 1275 | self._ntk_alpha_cached = 1.0 1276 | 1277 | def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): 1278 | seqlen = max_seq_len + offset 1279 | if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: 1280 | base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) 1281 | self.inv_freq = 1.0 / ( 1282 | base 1283 | ** ( 1284 | torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() 1285 | / self.dim 1286 | ) 1287 | ) 1288 | self._seq_len_cached = max(2 * seqlen, 16) 1289 | self._ntk_alpha_cached = ntk_alpha 1290 | seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) 1291 | freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) 1292 | 1293 | emb = torch.cat((freqs, freqs), dim=-1) 1294 | from einops import rearrange 1295 | 1296 | emb = rearrange(emb, "n d -> 1 n 1 d") 1297 | 1298 | cos, sin = emb.cos(), emb.sin() 1299 | self._rotary_pos_emb_cache = [cos, sin] 1300 | 1301 | def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): 1302 | self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) 1303 | cos, sin = self._rotary_pos_emb_cache 1304 | return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]] 1305 | 1306 | 1307 | def _rotate_half(x): 1308 | from einops import rearrange 1309 | 1310 | x = rearrange(x, "... (j d) -> ... 
j d", j=2) 1311 | x1, x2 = x.unbind(dim=-2) 1312 | return torch.cat((-x2, x1), dim=-1) 1313 | 1314 | 1315 | def apply_rotary_pos_emb(t, freqs): 1316 | cos, sin = freqs 1317 | if apply_rotary_emb_func is not None and t.is_cuda: 1318 | t_ = t.float() 1319 | cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2] 1320 | sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2] 1321 | output = apply_rotary_emb_func(t_, cos, sin).type_as(t) 1322 | return output 1323 | else: 1324 | rot_dim = freqs[0].shape[-1] 1325 | cos, sin = freqs 1326 | t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] 1327 | t_ = t_.float() 1328 | t_pass_ = t_pass_.float() 1329 | t_ = (t_ * cos) + (_rotate_half(t_) * sin) 1330 | return torch.cat((t_, t_pass_), dim=-1).type_as(t) 1331 | 1332 | 1333 | class RMSNorm(torch.nn.Module): 1334 | def __init__(self, dim: int, eps: float = 1e-6): 1335 | super().__init__() 1336 | self.eps = eps 1337 | self.weight = nn.Parameter(torch.ones(dim)) 1338 | 1339 | def _norm(self, x): 1340 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 1341 | 1342 | def forward(self, x): 1343 | if rms_norm is not None and x.is_cuda: 1344 | return rms_norm(x, self.weight, self.eps) 1345 | else: 1346 | output = self._norm(x.float()).type_as(x) 1347 | return output * self.weight 1348 | -------------------------------------------------------------------------------- /OdysseyAgent/qwen_generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Generation support.""" 7 | 8 | from typing import Tuple, List, Union, Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from transformers import PreTrainedTokenizer 14 | from transformers import logging 15 | from transformers.generation import LogitsProcessor 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | # Types. 20 | HistoryType = List[Tuple[str, str]] 21 | TokensType = List[int] 22 | BatchTokensType = List[List[int]] 23 | 24 | 25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: 26 | for tokens in batch: 27 | context_length = len(tokens) 28 | if context_length < seq_length: 29 | tokens.extend([pad_id] * (seq_length - context_length)) 30 | return batch 31 | 32 | 33 | def get_ltor_masks_and_position_ids( 34 | data, 35 | eod_token, 36 | reset_position_ids, 37 | reset_attention_mask, 38 | eod_mask_loss, 39 | ): 40 | """Build masks and position id for left to right model.""" 41 | 42 | # Extract batch size and sequence length. 43 | micro_batch_size, seq_length = data.size() 44 | 45 | # Attention mask (lower triangular). 46 | if reset_attention_mask: 47 | att_mask_batch = micro_batch_size 48 | else: 49 | att_mask_batch = 1 50 | attention_mask = torch.tril( 51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) 52 | ).view(att_mask_batch, 1, seq_length, seq_length) 53 | 54 | # Loss mask. 55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 56 | if eod_mask_loss: 57 | loss_mask[data == eod_token] = 0.0 58 | 59 | # Position ids. 60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) 61 | position_ids = position_ids.unsqueeze(0).expand_as(data) 62 | # We need to clone as the ids will be modifed based on batch index. 
63 | if reset_position_ids: 64 | position_ids = position_ids.clone() 65 | 66 | if reset_position_ids or reset_attention_mask: 67 | # Loop through the batches: 68 | for b in range(micro_batch_size): 69 | 70 | # Find indecies where EOD token is. 71 | eod_index = position_ids[b, data[b] == eod_token] 72 | # Detach indecies from positions if going to modify positions. 73 | if reset_position_ids: 74 | eod_index = eod_index.clone() 75 | 76 | # Loop through EOD indecies: 77 | prev_index = 0 78 | for j in range(eod_index.size()[0]): 79 | i = eod_index[j] 80 | # Mask attention loss. 81 | if reset_attention_mask: 82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 83 | # Reset positions. 84 | if reset_position_ids: 85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index 86 | prev_index = i + 1 87 | 88 | # Convert attention mask to binary: 89 | attention_mask = attention_mask < 0.5 90 | 91 | return attention_mask, loss_mask, position_ids 92 | 93 | 94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int): 95 | """Generate batch from context tokens.""" 96 | # Move to GPU. 97 | tokens = context_tokens.contiguous().to(context_tokens.device) 98 | # Get the attention mask and postition ids. 99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 100 | tokens, 101 | eod_id, 102 | reset_position_ids=False, 103 | reset_attention_mask=False, 104 | eod_mask_loss=False, 105 | ) 106 | return tokens, attention_mask, position_ids 107 | 108 | 109 | def get_stop_words_ids(chat_format, tokenizer): 110 | if chat_format == "raw": 111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 112 | elif chat_format == "chatml": 113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 114 | else: 115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 116 | return stop_words_ids 117 | 118 | 119 | def make_context( 120 | tokenizer: PreTrainedTokenizer, 121 | query: str, 122 | history: List[Tuple[str, str]] = None, 123 | system: str = "", 124 | max_window_size: int = 6144, 125 | chat_format: str = "chatml", 126 | ): 127 | if history is None: 128 | history = [] 129 | 130 | if chat_format == "chatml": 131 | im_start, im_end = "<|im_start|>", "<|im_end|>" 132 | im_start_tokens = [tokenizer.im_start_id] 133 | im_end_tokens = [tokenizer.im_end_id] 134 | nl_tokens = tokenizer.encode("\n") 135 | 136 | def _tokenize_str(role, content): 137 | return f"{role}\n{content}", tokenizer.encode( 138 | role, allowed_special=set(tokenizer.IMAGE_ST) 139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST)) 140 | 141 | system_text, system_tokens_part = _tokenize_str("system", system) 142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 143 | 144 | raw_text = "" 145 | context_tokens = [] 146 | 147 | for turn_query, turn_response in reversed(history): 148 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 150 | if turn_response is not None: 151 | response_text, response_tokens_part = _tokenize_str( 152 | "assistant", turn_response 153 | ) 154 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 155 | 156 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 157 | prev_chat = ( 158 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 159 | ) 160 | else: 161 | next_context_tokens = nl_tokens + query_tokens + nl_tokens 162 | prev_chat = f"\n{im_start}{query_text}{im_end}\n" 163 | 
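            # History is consumed from newest to oldest (reversed(history));
            # each turn is prepended to context_tokens only while the running
            # size stays below max_window_size, so the oldest turns are the
            # first to be dropped.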
164 | current_context_size = ( 165 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 166 | ) 167 | if current_context_size < max_window_size: 168 | context_tokens = next_context_tokens + context_tokens 169 | raw_text = prev_chat + raw_text 170 | else: 171 | break 172 | 173 | context_tokens = system_tokens + context_tokens 174 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 175 | context_tokens += ( 176 | nl_tokens 177 | + im_start_tokens 178 | + _tokenize_str("user", query)[1] 179 | + im_end_tokens 180 | + nl_tokens 181 | + im_start_tokens 182 | + tokenizer.encode("assistant") 183 | + nl_tokens 184 | ) 185 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 186 | 187 | elif chat_format == "raw": 188 | raw_text = query 189 | context_tokens = tokenizer.encode(raw_text) 190 | else: 191 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 192 | 193 | return raw_text, context_tokens 194 | 195 | 196 | def _decode_default( 197 | tokens: List[int], 198 | *, 199 | stop_words: List[str], 200 | eod_words: List[str], 201 | tokenizer: PreTrainedTokenizer, 202 | raw_text_len: int, 203 | verbose: bool = False, 204 | return_end_reason: bool = False, 205 | errors: str='replace', 206 | ): 207 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] 208 | if verbose: 209 | print("\nRaw Generate: ", trim_decode_tokens) 210 | 211 | end_reason = f"Gen length {len(tokens)}" 212 | for stop_word in stop_words: 213 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 214 | for eod_word in eod_words: 215 | if eod_word in trim_decode_tokens: 216 | end_reason = f"Gen {eod_word!r}" 217 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] 218 | trim_decode_tokens = trim_decode_tokens.strip() 219 | if verbose: 220 | print("\nEnd Reason:", end_reason) 221 | print("\nGenerate: ", trim_decode_tokens) 222 | 223 | if return_end_reason: 224 | return trim_decode_tokens, end_reason 225 | else: 226 | return trim_decode_tokens 227 | 228 | 229 | def _decode_chatml( 230 | tokens: List[int], 231 | *, 232 | stop_words: List[str], 233 | eod_token_ids: List[int], 234 | tokenizer: PreTrainedTokenizer, 235 | raw_text_len: int, 236 | context_length: int, 237 | verbose: bool = False, 238 | return_end_reason: bool = False, 239 | errors: str='replace' 240 | ): 241 | end_reason = f"Gen length {len(tokens)}" 242 | eod_token_idx = context_length 243 | for eod_token_idx in range(context_length, len(tokens)): 244 | if tokens[eod_token_idx] in eod_token_ids: 245 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" 246 | break 247 | 248 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] 249 | if verbose: 250 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) 251 | print("\nRaw Generate:", trim_decode_tokens) 252 | print("\nEnd Reason:", end_reason) 253 | for stop_word in stop_words: 254 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 255 | trim_decode_tokens = trim_decode_tokens.strip() 256 | if verbose: 257 | print("\nGenerate:", trim_decode_tokens) 258 | 259 | if return_end_reason: 260 | return trim_decode_tokens, end_reason 261 | else: 262 | return trim_decode_tokens 263 | 264 | 265 | def decode_tokens( 266 | tokens: Union[torch.LongTensor, TokensType], 267 | tokenizer: PreTrainedTokenizer, 268 | raw_text_len: int, 269 | context_length: int, 270 | chat_format: str, 271 | verbose: bool = False, 272 | 
return_end_reason: bool = False, 273 | errors: str="replace", 274 | ) -> str: 275 | if torch.is_tensor(tokens): 276 | tokens = tokens.cpu().numpy().tolist() 277 | 278 | if chat_format == "chatml": 279 | return _decode_chatml( 280 | tokens, 281 | stop_words=[], 282 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], 283 | tokenizer=tokenizer, 284 | raw_text_len=raw_text_len, 285 | context_length=context_length, 286 | verbose=verbose, 287 | return_end_reason=return_end_reason, 288 | errors=errors, 289 | ) 290 | elif chat_format == "raw": 291 | return _decode_default( 292 | tokens, 293 | stop_words=["<|endoftext|>"], 294 | eod_words=["<|endoftext|>"], 295 | tokenizer=tokenizer, 296 | raw_text_len=raw_text_len, 297 | verbose=verbose, 298 | return_end_reason=return_end_reason, 299 | errors=errors, 300 | ) 301 | else: 302 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 303 | 304 | 305 | class StopWordsLogitsProcessor(LogitsProcessor): 306 | """ 307 | :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration. 308 | 309 | Args: 310 | stop_words_ids (:obj:`List[List[int]]`): 311 | List of list of token ids of stop ids. In order to get the tokens of the words 312 | that should not appear in the generated text, use :obj:`tokenizer(bad_word, 313 | add_prefix_space=True).input_ids`. 314 | eos_token_id (:obj:`int`): 315 | The id of the `end-of-sequence` token. 316 | """ 317 | 318 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): 319 | 320 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: 321 | raise ValueError( 322 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." 323 | ) 324 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): 325 | raise ValueError( 326 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." 327 | ) 328 | if any( 329 | any( 330 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) 331 | for token_id in stop_word_ids 332 | ) 333 | for stop_word_ids in stop_words_ids 334 | ): 335 | raise ValueError( 336 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." 
337 | ) 338 | 339 | self.stop_words_ids = list( 340 | filter( 341 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids 342 | ) 343 | ) 344 | self.eos_token_id = eos_token_id 345 | for stop_token_seq in self.stop_words_ids: 346 | assert ( 347 | len(stop_token_seq) > 0 348 | ), "Stop words token sequences {} cannot have an empty list".format( 349 | stop_words_ids 350 | ) 351 | 352 | def __call__( 353 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 354 | ) -> torch.FloatTensor: 355 | stopped_samples = self._calc_stopped_samples(input_ids) 356 | for i, should_stop in enumerate(stopped_samples): 357 | if should_stop: 358 | scores[i, self.eos_token_id] = float(2**15) 359 | return scores 360 | 361 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 362 | if len(tokens) == 0: 363 | # if bad word tokens is just one token always ban it 364 | return True 365 | elif len(tokens) > len(prev_tokens): 366 | # if bad word tokens are longer then prev input_ids they can't be equal 367 | return False 368 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 369 | # if tokens match 370 | return True 371 | else: 372 | return False 373 | 374 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 375 | stopped_samples = [] 376 | for prev_input_ids_slice in prev_input_ids: 377 | match = False 378 | for stop_token_seq in self.stop_words_ids: 379 | if self._tokens_match(prev_input_ids_slice, stop_token_seq): 380 | # if tokens do not match continue 381 | match = True 382 | break 383 | stopped_samples.append(match) 384 | 385 | return stopped_samples 386 | 387 | 388 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 389 | """This function has been mostly taken from huggingface conversational 390 | ai code at 391 | https://medium.com/huggingface/how-to-build-a-state-of-the-art- 392 | conversational-ai-with-transfer-learning-2d818ac26313""" 393 | 394 | if top_k > 0: 395 | # Remove all tokens with a probability less than the 396 | # last token of the top-k 397 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 398 | logits[indices_to_remove] = filter_value 399 | 400 | if top_p > 0.0: 401 | # Cconvert to 1D 402 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 403 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 404 | 405 | # Remove tokens with cumulative probability above the threshold 406 | sorted_indices_to_remove = cumulative_probs > top_p 407 | # Shift the indices to the right to keep also the first token 408 | # above the threshold 409 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 410 | sorted_indices_to_remove[..., 0] = 0 411 | for i in range(sorted_indices.size(0)): 412 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] 413 | logits[i][indices_to_remove] = filter_value 414 | 415 | return logits 416 | 417 | 418 | def switch(val1, val2, boolean): 419 | boolean = boolean.type_as(val1) 420 | return (1 - boolean) * val1 + boolean * val2 421 | -------------------------------------------------------------------------------- /OdysseyAgent/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /OdysseyAgent/tokenization_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba 
Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Tokenization classes for QWen.""" 7 | 8 | import base64 9 | import logging 10 | import os 11 | import requests 12 | import unicodedata 13 | from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional 14 | 15 | import tiktoken 16 | import numpy as np 17 | from PIL import Image 18 | from PIL import ImageFont 19 | from PIL import ImageDraw 20 | from transformers import PreTrainedTokenizer, AddedToken 21 | from transformers.utils import try_to_load_from_cache 22 | 23 | import matplotlib.colors as mcolors 24 | from matplotlib.font_manager import FontProperties 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"} 30 | FONT_PATH = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf") 31 | if FONT_PATH is None: 32 | if not os.path.exists("SimSun.ttf"): 33 | ttf = requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/SimSun.ttf") 34 | open("SimSun.ttf", "wb").write(ttf.content) 35 | FONT_PATH = "SimSun.ttf" 36 | 37 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 38 | ENDOFTEXT = "<|endoftext|>" 39 | IMSTART = "<|im_start|>" 40 | IMEND = "<|im_end|>" 41 | # as the default behavior is changed to allow special tokens in 42 | # regular texts, the surface forms of special tokens need to be 43 | # as different as possible to minimize the impact 44 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) 45 | SPECIAL_TOKENS = ( 46 | ENDOFTEXT, 47 | IMSTART, 48 | IMEND, 49 | ) + EXTRAS 50 | IMG_TOKEN_SPAN = 256 51 | 52 | 53 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: 54 | with open(tiktoken_bpe_file, "rb") as f: 55 | contents = f.read() 56 | return { 57 | base64.b64decode(token): int(rank) 58 | for token, rank in (line.split() for line in contents.splitlines() if line) 59 | } 60 | 61 | def _list_find( 62 | input_list: List[Any], 63 | candidates: Tuple[Any], 64 | start: int = 0, 65 | ): 66 | for i in range(start, len(input_list)): 67 | if input_list[i] in candidates: 68 | return i 69 | return -1 70 | 71 | def _replace_closed_tag( 72 | input_tokens: List[Any], 73 | start_tags: Union[Any, Tuple[Any]], 74 | end_tags: Union[Any, Tuple[Any]], 75 | inclusive_replace_func: Callable, 76 | exclusive_replace_func: Callable = lambda x: x, 77 | ): 78 | if isinstance(start_tags, (str, int)): 79 | start_tags = (start_tags,) 80 | if isinstance(end_tags, (str, int)): 81 | end_tags = (end_tags,) 82 | assert len(start_tags) == len(end_tags) 83 | 84 | output_tokens = [] 85 | end = 0 86 | while True: 87 | start = _list_find(input_tokens, start_tags, end) 88 | if start == -1: 89 | break 90 | output_tokens.extend(exclusive_replace_func(input_tokens[end : start])) 91 | tag_idx = start_tags.index(input_tokens[start]) 92 | end = _list_find(input_tokens, (end_tags[tag_idx],), start) 93 | if end == -1: 94 | raise ValueError("Unclosed image token") 95 | output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1])) 96 | end += 1 97 | output_tokens.extend(exclusive_replace_func(input_tokens[end : ])) 98 | return output_tokens 99 | 100 | class QWenTokenizer(PreTrainedTokenizer): 101 | """QWen tokenizer.""" 102 | 103 | vocab_files_names = VOCAB_FILES_NAMES 104 | 105 | def __init__( 106 | self, 107 | vocab_file, 108 | errors="replace", 
109 | image_start_tag='<img>', 110 | image_end_tag='</img>', 111 | image_pad_tag='<imgpad>', 112 | ref_start_tag='<ref>', 113 | ref_end_tag='</ref>', 114 | box_start_tag='<box>', 115 | box_end_tag='</box>', 116 | quad_start_tag='<quad>', 117 | quad_end_tag='</quad>', 118 | **kwargs, 119 | ): 120 | super().__init__(**kwargs) 121 | self.image_start_tag = image_start_tag 122 | self.image_end_tag = image_end_tag 123 | self.image_pad_tag = image_pad_tag 124 | self.ref_start_tag = ref_start_tag 125 | self.ref_end_tag = ref_end_tag 126 | self.box_start_tag = box_start_tag 127 | self.box_end_tag = box_end_tag 128 | self.quad_start_tag = quad_start_tag 129 | self.quad_end_tag = quad_end_tag 130 | self.IMAGE_ST = ( 131 | ref_start_tag, ref_end_tag, 132 | box_start_tag, box_end_tag, 133 | quad_start_tag, quad_end_tag, 134 | image_start_tag, image_end_tag, 135 | image_pad_tag 136 | ) 137 | 138 | self.errors = errors # how to handle errors in decoding 139 | 140 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int] 141 | self.special_tokens = { 142 | token: index 143 | for index, token in enumerate( 144 | SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks) 145 | ) 146 | } 147 | self.img_start_id = self.special_tokens[self.image_start_tag] 148 | self.img_end_id = self.special_tokens[self.image_end_tag] 149 | self.img_pad_id = self.special_tokens[self.image_pad_tag] 150 | self.ref_start_id = self.special_tokens[self.ref_start_tag] 151 | self.ref_end_id = self.special_tokens[self.ref_end_tag] 152 | self.box_start_id = self.special_tokens[self.box_start_tag] 153 | self.box_end_id = self.special_tokens[self.box_end_tag] 154 | self.quad_start_id = self.special_tokens[self.quad_start_tag] 155 | self.quad_end_id = self.special_tokens[self.quad_end_tag] 156 | self.image_special_tokens = set([ 157 | self.ref_start_id, self.ref_end_id, self.box_start_id, self.box_end_id, 158 | self.quad_start_id, self.quad_end_id, 159 | ]) 160 | 161 | enc = tiktoken.Encoding( 162 | "Qwen", 163 | pat_str=PAT_STR, 164 | mergeable_ranks=self.mergeable_ranks, 165 | special_tokens=self.special_tokens, 166 | ) 167 | assert ( 168 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab 169 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" 170 | 171 | self.decoder = { 172 | v: k for k, v in self.mergeable_ranks.items() 173 | } # type: dict[int, bytes|str] 174 | self.decoder.update({v: k for k, v in self.special_tokens.items()}) 175 | 176 | self.tokenizer = enc # type: tiktoken.Encoding 177 | 178 | self.eod_id = self.tokenizer.eot_token 179 | self.im_start_id = self.special_tokens[IMSTART] 180 | self.im_end_id = self.special_tokens[IMEND] 181 | 182 | def __getstate__(self): 183 | # for pickle lovers 184 | state = self.__dict__.copy() 185 | del state['tokenizer'] 186 | return state 187 | 188 | def __setstate__(self, state): 189 | # tokenizer is not python native; don't pass it; rebuild it 190 | self.__dict__.update(state) 191 | enc = tiktoken.Encoding( 192 | "Qwen", 193 | pat_str=PAT_STR, 194 | mergeable_ranks=self.mergeable_ranks, 195 | special_tokens=self.special_tokens, 196 | ) 197 | self.tokenizer = enc 198 | 199 | 200 | def __len__(self) -> int: 201 | return self.tokenizer.n_vocab 202 | 203 | def get_vocab(self) -> Dict[bytes, int]: 204 | return self.mergeable_ranks 205 | 206 | def convert_tokens_to_ids( 207 | self, tokens: Union[bytes, str, List[Union[bytes, str]]] 208 | ) -> List[int]: 209 | ids = [] 210 | if isinstance(tokens, (str, bytes)): 211 | if tokens in self.special_tokens: 212 | return
self.special_tokens[tokens] 213 | else: 214 | return self.mergeable_ranks.get(tokens) 215 | for token in tokens: 216 | if token in self.special_tokens: 217 | ids.append(self.special_tokens[token]) 218 | else: 219 | ids.append(self.mergeable_ranks.get(token)) 220 | return ids 221 | 222 | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: 223 | if not special_tokens and new_tokens: 224 | raise ValueError('Adding regular tokens is not supported') 225 | for token in new_tokens: 226 | surface_form = token.content if isinstance(token, AddedToken) else token 227 | if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST: 228 | raise ValueError('Adding unknown special tokens is not supported') 229 | return 0 230 | 231 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: 232 | """ 233 | Save only the vocabulary of the tokenizer (vocabulary). 234 | 235 | Returns: 236 | `Tuple(str)`: Paths to the files saved. 237 | """ 238 | file_path = os.path.join(save_directory, "qwen.tiktoken") 239 | with open(file_path, "w", encoding="utf8") as w: 240 | for k, v in self.mergeable_ranks.items(): 241 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" 242 | w.write(line) 243 | return (file_path,) 244 | 245 | def tokenize( 246 | self, 247 | text: str, 248 | allowed_special: Union[Set, str] = "all", 249 | disallowed_special: Union[Collection, str] = (), 250 | **kwargs, 251 | ) -> List[Union[bytes, str]]: 252 | """ 253 | Converts a string in a sequence of tokens. 254 | 255 | Args: 256 | text (`str`): 257 | The sequence to be encoded. 258 | allowed_special (`Literal["all"]` or `set`): 259 | The surface forms of the tokens to be encoded as special tokens in regular texts. 260 | Default to "all". 261 | disallowed_special (`Literal["all"]` or `Collection`): 262 | The surface forms of the tokens that should not be in regular texts and trigger errors. 263 | Default to an empty tuple. 264 | 265 | kwargs (additional keyword arguments, *optional*): 266 | Will be passed to the underlying model specific encode method. 267 | 268 | Returns: 269 | `List[bytes|str]`: The list of tokens. 270 | """ 271 | tokens = [] 272 | text = unicodedata.normalize("NFC", text) 273 | 274 | # this implementation takes a detour: text -> token id -> token surface forms 275 | for t in self.tokenizer.encode( 276 | text, allowed_special=allowed_special, disallowed_special=disallowed_special 277 | ): 278 | tokens.append(self.decoder[t]) 279 | 280 | def _encode_imgurl(img_tokens): 281 | assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag 282 | img_tokens = img_tokens[1:-1] 283 | img_url = b''.join(img_tokens) 284 | out_img_tokens = list(map(self.decoder.get, img_url)) 285 | if len(out_img_tokens) > IMG_TOKEN_SPAN: 286 | raise ValueError("The content in {}..{} is too long".format( 287 | self.image_start_tag, self.image_end_tag)) 288 | out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens))) 289 | out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag] 290 | return out_img_tokens 291 | 292 | return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl) 293 | 294 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: 295 | """ 296 | Converts a sequence of tokens in a single string. 
297 | """ 298 | text = "" 299 | temp = b"" 300 | for t in tokens: 301 | if isinstance(t, str): 302 | if temp: 303 | text += temp.decode("utf-8", errors=self.errors) 304 | temp = b"" 305 | text += t 306 | elif isinstance(t, bytes): 307 | temp += t 308 | else: 309 | raise TypeError("token should only be of type types or str") 310 | if temp: 311 | text += temp.decode("utf-8", errors=self.errors) 312 | return text 313 | 314 | @property 315 | def vocab_size(self): 316 | return self.tokenizer.n_vocab 317 | 318 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]: 319 | """Converts an id to a token, special tokens included""" 320 | if index in self.decoder: 321 | return self.decoder[index] 322 | raise ValueError("unknown ids") 323 | 324 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int: 325 | """Converts a token to an id using the vocab, special tokens included""" 326 | if token in self.special_tokens: 327 | return self.special_tokens[token] 328 | if token in self.mergeable_ranks: 329 | return self.mergeable_ranks[token] 330 | raise ValueError("unknown token") 331 | 332 | def _tokenize(self, text: str, **kwargs): 333 | """ 334 | Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based 335 | vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). 336 | 337 | Do NOT take care of added tokens. 338 | """ 339 | raise NotImplementedError 340 | 341 | def _decode( 342 | self, 343 | token_ids: Union[int, List[int]], 344 | skip_special_tokens: bool = False, 345 | errors: str = None, 346 | **kwargs, 347 | ) -> str: 348 | if isinstance(token_ids, int): 349 | token_ids = [token_ids] 350 | 351 | def _decode_imgurl(img_token_ids): 352 | assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id 353 | img_token_ids = img_token_ids[1:-1] 354 | img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)] 355 | img_url = bytes(img_token_ids).decode('utf-8') 356 | return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id] 357 | 358 | token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl) 359 | 360 | if skip_special_tokens: 361 | if kwargs.get('keep_image_special', False): 362 | token_ids = [i for i in token_ids if i < self.eod_id 363 | or i in self.image_special_tokens] 364 | else: 365 | token_ids = [i for i in token_ids if i < self.eod_id] 366 | return self.tokenizer.decode(token_ids, errors=errors or self.errors) 367 | 368 | def to_list_format(self, text: str): 369 | text = unicodedata.normalize("NFC", text) 370 | token_ids = self.tokenizer.encode( 371 | text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,))) 372 | 373 | def _encode_vl_info(tokens): 374 | if len(tokens) == 0: 375 | return [] 376 | if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id: 377 | key = 'image' 378 | elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id: 379 | key = 'ref' 380 | elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id: 381 | key = 'box' 382 | elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id: 383 | key = 'quad' 384 | else: 385 | _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x 386 | return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}] 387 | _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x 388 | val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8') 389 | 
return [{key: val}] 390 | 391 | return _replace_closed_tag( 392 | token_ids, 393 | (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id), 394 | (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id), 395 | _encode_vl_info, 396 | _encode_vl_info, 397 | ) 398 | 399 | def from_list_format(self, list_format: List[Dict]): 400 | text = '' 401 | num_images = 0 402 | for ele in list_format: 403 | if 'image' in ele: 404 | num_images += 1 405 | text += f'Picture {num_images}: ' 406 | text += self.image_start_tag + ele['image'] + self.image_end_tag 407 | text += '\n' 408 | elif 'text' in ele: 409 | text += ele['text'] 410 | elif 'box' in ele: 411 | if 'ref' in ele: 412 | text += self.ref_start_tag + ele['ref'] + self.ref_end_tag 413 | for box in ele['box']: 414 | text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag 415 | else: 416 | raise ValueError("Unsupport element: " + str(ele)) 417 | return text 418 | 419 | def _fetch_latest_picture(self, response, history): 420 | if history is None: 421 | history = [] 422 | _history = history + [(response, None)] 423 | for q, r in _history[::-1]: 424 | for ele in self.to_list_format(q)[::-1]: 425 | if 'image' in ele: 426 | return ele['image'] 427 | return None 428 | 429 | def _fetch_all_box_with_ref(self, text): 430 | list_format = self.to_list_format(text) 431 | output = [] 432 | for i, ele in enumerate(list_format): 433 | if 'box' in ele: 434 | bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(','))) 435 | assert len(bbox) == 4 436 | output.append({'box': bbox}) 437 | if i > 0 and 'ref' in list_format[i-1]: 438 | output[-1]['ref'] = list_format[i-1]['ref'].strip() 439 | return output 440 | 441 | def draw_bbox_on_latest_picture( 442 | self, 443 | response, 444 | history=None, 445 | ) -> Optional[Image.Image]: 446 | image = self._fetch_latest_picture(response, history) 447 | if image is None: 448 | return None 449 | if image.startswith("http://") or image.startswith("https://"): 450 | image = Image.open(requests.get(image, stream=True).raw).convert("RGB") 451 | h, w = image.height, image.width 452 | else: 453 | image = np.asarray(Image.open(image).convert("RGB")) 454 | h, w = image.shape[0], image.shape[1] 455 | visualizer = Visualizer(image) 456 | 457 | boxes = self._fetch_all_box_with_ref(response) 458 | if not boxes: 459 | return None 460 | color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color 461 | for box in boxes: 462 | if 'ref' in box: # random new color for new refexps 463 | color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) 464 | x1, y1, x2, y2 = box['box'] 465 | x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h)) 466 | visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color) 467 | if 'ref' in box: 468 | visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left") 469 | return visualizer.output 470 | 471 | 472 | import colorsys 473 | import logging 474 | import math 475 | import numpy as np 476 | import matplotlib as mpl 477 | import matplotlib.colors as mplc 478 | import matplotlib.figure as mplfigure 479 | import torch 480 | from matplotlib.backends.backend_agg import FigureCanvasAgg 481 | from PIL import Image 482 | import random 483 | 484 | logger = logging.getLogger(__name__) 485 | 486 | 487 | class VisImage: 488 | def __init__(self, img, scale=1.0): 489 | self.img = img 490 | self.scale = scale 491 | self.width, 
self.height = img.shape[1], img.shape[0] 492 | self._setup_figure(img) 493 | 494 | def _setup_figure(self, img): 495 | fig = mplfigure.Figure(frameon=False) 496 | self.dpi = fig.get_dpi() 497 | # add a small 1e-2 to avoid precision lost due to matplotlib's truncation 498 | # (https://github.com/matplotlib/matplotlib/issues/15363) 499 | fig.set_size_inches( 500 | (self.width * self.scale + 1e-2) / self.dpi, 501 | (self.height * self.scale + 1e-2) / self.dpi, 502 | ) 503 | self.canvas = FigureCanvasAgg(fig) 504 | # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) 505 | ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) 506 | ax.axis("off") 507 | self.fig = fig 508 | self.ax = ax 509 | self.reset_image(img) 510 | 511 | def reset_image(self, img): 512 | img = img.astype("uint8") 513 | self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") 514 | 515 | def save(self, filepath): 516 | self.fig.savefig(filepath) 517 | 518 | def get_image(self): 519 | canvas = self.canvas 520 | s, (width, height) = canvas.print_to_buffer() 521 | 522 | buffer = np.frombuffer(s, dtype="uint8") 523 | 524 | img_rgba = buffer.reshape(height, width, 4) 525 | rgb, alpha = np.split(img_rgba, [3], axis=2) 526 | return rgb.astype("uint8") 527 | 528 | 529 | class Visualizer: 530 | def __init__(self, img_rgb, metadata=None, scale=1.0): 531 | self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) 532 | self.font_path = FONT_PATH 533 | self.output = VisImage(self.img, scale=scale) 534 | self.cpu_device = torch.device("cpu") 535 | 536 | # too small texts are useless, therefore clamp to 14 537 | self._default_font_size = max( 538 | np.sqrt(self.output.height * self.output.width) // 30, 15 // scale 539 | ) 540 | 541 | def draw_text( 542 | self, 543 | text, 544 | position, 545 | *, 546 | font_size=None, 547 | color="g", 548 | horizontal_alignment="center", 549 | rotation=0, 550 | ): 551 | if not font_size: 552 | font_size = self._default_font_size 553 | 554 | # since the text background is dark, we don't want the text to be dark 555 | color = np.maximum(list(mplc.to_rgb(color)), 0.2) 556 | color[np.argmax(color)] = max(0.8, np.max(color)) 557 | 558 | x, y = position 559 | self.output.ax.text( 560 | x, 561 | y, 562 | text, 563 | size=font_size * self.output.scale, 564 | fontproperties=FontProperties(fname=self.font_path), 565 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, 566 | verticalalignment="top", 567 | horizontalalignment=horizontal_alignment, 568 | color=color, 569 | zorder=10, 570 | rotation=rotation, 571 | ) 572 | return self.output 573 | 574 | def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): 575 | 576 | x0, y0, x1, y1 = box_coord 577 | width = x1 - x0 578 | height = y1 - y0 579 | 580 | linewidth = max(self._default_font_size / 4, 1) 581 | 582 | self.output.ax.add_patch( 583 | mpl.patches.Rectangle( 584 | (x0, y0), 585 | width, 586 | height, 587 | fill=False, 588 | edgecolor=edge_color, 589 | linewidth=linewidth * self.output.scale, 590 | alpha=alpha, 591 | linestyle=line_style, 592 | ) 593 | ) 594 | return self.output 595 | 596 | def get_output(self): 597 | 598 | return self.output 599 | -------------------------------------------------------------------------------- /OdysseyAgent/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auto_map": { 3 | "AutoTokenizer": [ 4 | "Qwen/Qwen-VL-Chat--tokenization_qwen.QWenTokenizer", 5 | null 6 | ] 7 | }, 8 | 
"clean_up_tokenization_spaces": true, 9 | "model_max_length": 8192, 10 | "tokenizer_class": "QWenTokenizer" 11 | } 12 | -------------------------------------------------------------------------------- /OdysseyAgent/visual.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from collections import OrderedDict 7 | import math 8 | import requests 9 | from io import BytesIO 10 | from functools import partial 11 | from PIL import Image 12 | from typing import Callable, Optional, Sequence, Tuple, List 13 | import numpy as np 14 | 15 | import torch 16 | from torch import nn 17 | from torch.nn import functional as F 18 | from torch.nn.init import trunc_normal_ 19 | from torchvision import transforms 20 | from torchvision.transforms import InterpolationMode 21 | 22 | 23 | def get_abs_pos(abs_pos, tgt_size): 24 | # abs_pos: L, C 25 | # tgt_size: M 26 | # return: M, C 27 | src_size = int(math.sqrt(abs_pos.size(0))) 28 | tgt_size = int(math.sqrt(tgt_size)) 29 | dtype = abs_pos.dtype 30 | 31 | if src_size != tgt_size: 32 | return F.interpolate( 33 | abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), 34 | size=(tgt_size, tgt_size), 35 | mode="bicubic", 36 | align_corners=False, 37 | ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) 38 | else: 39 | return abs_pos 40 | 41 | # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 42 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 43 | """ 44 | grid_size: int of the grid height and width 45 | return: 46 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 47 | """ 48 | grid_h = np.arange(grid_size, dtype=np.float32) 49 | grid_w = np.arange(grid_size, dtype=np.float32) 50 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 51 | grid = np.stack(grid, axis=0) 52 | 53 | grid = grid.reshape([2, 1, grid_size, grid_size]) 54 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 55 | if cls_token: 56 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 57 | return pos_embed 58 | 59 | 60 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 61 | assert embed_dim % 2 == 0 62 | 63 | # use half of dimensions to encode grid_h 64 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 65 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 66 | 67 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 68 | return emb 69 | 70 | 71 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 72 | """ 73 | embed_dim: output dimension for each position 74 | pos: a list of positions to be encoded: size (M,) 75 | out: (M, D) 76 | """ 77 | assert embed_dim % 2 == 0 78 | omega = np.arange(embed_dim // 2, dtype=np.float32) 79 | omega /= embed_dim / 2. 80 | omega = 1. 
/ 10000**omega # (D/2,) 81 | 82 | pos = pos.reshape(-1) # (M,) 83 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 84 | 85 | emb_sin = np.sin(out) # (M, D/2) 86 | emb_cos = np.cos(out) # (M, D/2) 87 | 88 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 89 | return emb 90 | 91 | 92 | class Resampler(nn.Module): 93 | """ 94 | A 2D perceiver-resampler network with one cross attention layers by 95 | (grid_size**2) learnable queries and 2d sincos pos_emb 96 | Outputs: 97 | A tensor with the shape of (grid_size**2, embed_dim) 98 | """ 99 | def __init__( 100 | self, 101 | grid_size, 102 | embed_dim, 103 | num_heads, 104 | kv_dim=None, 105 | norm_layer=nn.LayerNorm 106 | ): 107 | super().__init__() 108 | self.num_queries = grid_size ** 2 109 | self.embed_dim = embed_dim 110 | self.num_heads = num_heads 111 | 112 | self.pos_embed = nn.Parameter( 113 | torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() 114 | ).requires_grad_(False) 115 | 116 | self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) 117 | trunc_normal_(self.query, std=.02) 118 | 119 | if kv_dim is not None and kv_dim != embed_dim: 120 | self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) 121 | else: 122 | self.kv_proj = nn.Identity() 123 | 124 | self.attn = nn.MultiheadAttention(embed_dim, num_heads) 125 | self.ln_q = norm_layer(embed_dim) 126 | self.ln_kv = norm_layer(embed_dim) 127 | 128 | # self.apply(self._init_weights) 129 | 130 | def _init_weights(self, m): 131 | if isinstance(m, nn.Linear): 132 | trunc_normal_(m.weight, std=.02) 133 | if isinstance(m, nn.Linear) and m.bias is not None: 134 | nn.init.constant_(m.bias, 0) 135 | elif isinstance(m, nn.LayerNorm): 136 | nn.init.constant_(m.bias, 0) 137 | nn.init.constant_(m.weight, 1.0) 138 | 139 | def forward(self, x, attn_mask=None): 140 | 141 | pos_embed = get_abs_pos(self.pos_embed, x.size(1)) 142 | 143 | x = self.kv_proj(x) 144 | x = self.ln_kv(x).permute(1, 0, 2) 145 | 146 | N = x.shape[1] 147 | q = self.ln_q(self.query) 148 | out = self.attn( 149 | self._repeat(q, N) + self.pos_embed.unsqueeze(1), 150 | x + pos_embed.unsqueeze(1), 151 | x, 152 | attn_mask=attn_mask)[0] 153 | return out.permute(1, 0, 2) 154 | 155 | def _repeat(self, query, N: int): 156 | return query.unsqueeze(1).repeat(1, N, 1) 157 | 158 | 159 | class VisualAttention(nn.Module): 160 | """self-attention layer class. 161 | 162 | Self-attention layer takes input with size [s, b, h] 163 | and returns output of the same size. 164 | """ 165 | 166 | def __init__(self, embed_dim, num_heads, 167 | bias=True, kdim=None, vdim=None): 168 | super(VisualAttention, self).__init__() 169 | self.embed_dim = embed_dim 170 | self.kdim = kdim if kdim is not None else embed_dim 171 | self.vdim = vdim if vdim is not None else embed_dim 172 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 173 | 174 | self.num_heads = num_heads 175 | 176 | # Per attention head and per partition values. 177 | assert embed_dim % num_heads == 0 178 | self.hidden_size_per_attention_head = embed_dim // num_heads 179 | self.num_attention_heads_per_partition = num_heads 180 | self.hidden_size_per_partition = embed_dim 181 | 182 | # Strided linear layer. 
183 | assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' 184 | self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) 185 | self.out_proj = nn.Linear(embed_dim, embed_dim) 186 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) 187 | 188 | def forward(self, query, key, value, attn_mask = None): 189 | # query/key/value: [sq, b, h] 190 | sq, b, _ = query.size() 191 | 192 | assert torch.allclose(query, key), 'Only Support Self-Attention Currently' 193 | sk = sq 194 | mixed_x_layer = self.in_proj(query) 195 | 196 | # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] 197 | new_tensor_shape = mixed_x_layer.size()[:-1] + \ 198 | (self.num_attention_heads_per_partition, 199 | 3 * self.hidden_size_per_attention_head) 200 | mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) 201 | 202 | # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] 203 | query_layer, key_layer, value_layer = mixed_x_layer.split( 204 | self.hidden_size_per_attention_head, dim=-1) 205 | 206 | # [sq, b, np, hn] -> [sq, b * np, hn] 207 | query_layer = query_layer.view(sq, 208 | b * self.num_attention_heads_per_partition, 209 | self.hidden_size_per_attention_head).transpose(0, 1) 210 | # [sk, b, np, hn] -> [sk, b * np, hn] 211 | key_layer = key_layer.view(sk, 212 | b * self.num_attention_heads_per_partition, 213 | self.hidden_size_per_attention_head).transpose(0, 1) 214 | 215 | q_scaled = query_layer / self.norm_factor 216 | if attn_mask is not None: 217 | attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1)) 218 | else: 219 | attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) 220 | attention_probs = attention_probs.softmax(dim=-1) 221 | 222 | value_layer = value_layer.view(sk, 223 | b * self.num_attention_heads_per_partition, 224 | self.hidden_size_per_attention_head).transpose(0, 1) 225 | 226 | # matmul: [b * np, sq, hn] 227 | context_layer = torch.bmm(attention_probs, value_layer) 228 | 229 | # change view [b, np, sq, hn] 230 | context_layer = context_layer.view(b, 231 | self.num_attention_heads_per_partition, 232 | sq, self.hidden_size_per_attention_head) 233 | 234 | # [b, np, sq, hn] --> [sq, b, np, hn] 235 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous() 236 | 237 | # [sq, b, np, hn] --> [sq, b, hp] 238 | new_context_layer_shape = context_layer.size()[:-2] + \ 239 | (self.hidden_size_per_partition,) 240 | context_layer = context_layer.view(*new_context_layer_shape) 241 | 242 | output = self.out_proj(context_layer) 243 | 244 | return output 245 | 246 | 247 | class VisualAttentionBlock(nn.Module): 248 | def __init__( 249 | self, 250 | d_model: int, 251 | n_head: int, 252 | mlp_ratio: float = 4.0, 253 | act_layer: Callable = nn.GELU, 254 | norm_layer: Callable = nn.LayerNorm, 255 | is_cross_attention: bool = False, 256 | ): 257 | super().__init__() 258 | 259 | self.ln_1 = norm_layer(d_model) 260 | if is_cross_attention: 261 | self.ln_1_kv = norm_layer(d_model) 262 | 263 | self.ln_2 = norm_layer(d_model) 264 | mlp_width = int(d_model * mlp_ratio) 265 | self.attn = VisualAttention(d_model, n_head) 266 | self.mlp = nn.Sequential(OrderedDict([ 267 | ("c_fc", nn.Linear(d_model, mlp_width)), 268 | ("gelu", act_layer()), 269 | ("c_proj", nn.Linear(mlp_width, d_model)) 270 | ])) 271 | 272 | def attention( 273 | self, 274 | q_x: torch.Tensor, 275 | k_x: Optional[torch.Tensor] = None, 276 | v_x: Optional[torch.Tensor] = None, 277 | attn_mask: Optional[torch.Tensor] = None, 278 | ): 279 | k_x = k_x if k_x is not None else q_x 280 | v_x = v_x if v_x is not 
None else q_x 281 | 282 | attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None 283 | return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) 284 | 285 | def forward( 286 | self, 287 | q_x: torch.Tensor, 288 | k_x: Optional[torch.Tensor] = None, 289 | v_x: Optional[torch.Tensor] = None, 290 | attn_mask: Optional[torch.Tensor] = None, 291 | ): 292 | k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None 293 | v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None 294 | 295 | x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) 296 | x = x + self.mlp(self.ln_2(x)) 297 | return x 298 | 299 | 300 | class TransformerBlock(nn.Module): 301 | def __init__( 302 | self, 303 | width: int, 304 | layers: int, 305 | heads: int, 306 | mlp_ratio: float = 4.0, 307 | act_layer: Callable = nn.GELU, 308 | norm_layer: Callable = nn.LayerNorm, 309 | ): 310 | super().__init__() 311 | self.width = width 312 | self.layers = layers 313 | 314 | self.resblocks = nn.ModuleList([ 315 | VisualAttentionBlock( 316 | width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) 317 | for _ in range(layers) 318 | ]) 319 | 320 | def get_cast_dtype(self) -> torch.dtype: 321 | return self.resblocks[0].mlp.c_fc.weight.dtype 322 | 323 | def get_cast_device(self) -> torch.device: 324 | return self.resblocks[0].mlp.c_fc.weight.device 325 | 326 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 327 | for r in self.resblocks: 328 | x = r(x, attn_mask=attn_mask) 329 | return x 330 | 331 | 332 | class VisionTransformer(nn.Module): 333 | 334 | def __init__( 335 | self, 336 | image_size: int, 337 | patch_size: int, 338 | width: int, 339 | layers: int, 340 | heads: int, 341 | mlp_ratio: float, 342 | n_queries: int = 256, 343 | output_dim: int = 512, 344 | **kwargs 345 | ): 346 | super().__init__() 347 | image_height, image_width = self.image_size = (image_size, image_size) 348 | patch_height, patch_width = self.patch_size = (patch_size, patch_size) 349 | self.grid_size = (image_height // patch_height, image_width // patch_width) 350 | self.output_dim = output_dim 351 | 352 | mean = (0.48145466, 0.4578275, 0.40821073) 353 | std = (0.26862954, 0.26130258, 0.27577711) 354 | self.image_transform = transforms.Compose([ 355 | transforms.Resize( 356 | (image_size, image_size), 357 | interpolation=InterpolationMode.BICUBIC 358 | ), 359 | transforms.ToTensor(), 360 | transforms.Normalize(mean=mean, std=std), 361 | ]) 362 | 363 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 364 | 365 | # class embeddings and positional embeddings 366 | scale = width ** -0.5 367 | self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) 368 | 369 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 370 | act_layer = nn.GELU 371 | 372 | self.ln_pre = norm_layer(width) 373 | self.transformer = TransformerBlock( 374 | width, 375 | layers, 376 | heads, 377 | mlp_ratio, 378 | act_layer=act_layer, 379 | norm_layer=norm_layer, 380 | ) 381 | 382 | self.attn_pool = Resampler( 383 | grid_size=int(math.sqrt(n_queries)), 384 | embed_dim=output_dim, 385 | num_heads=output_dim // 128, 386 | kv_dim=width, 387 | norm_layer=norm_layer, 388 | ) 389 | self.ln_post = norm_layer(output_dim) 390 | self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) 391 | 392 | def forward(self, x: torch.Tensor): 393 | x = x.to( 394 | 
dtype=self.transformer.get_cast_dtype(), 395 | device=self.transformer.get_cast_device(), 396 | ) 397 | # to patches 398 | x = self.conv1(x) # shape = [*, width, grid, grid] 399 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 400 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 401 | 402 | x = x + get_abs_pos(self.positional_embedding, x.size(1)) 403 | 404 | x = self.ln_pre(x) 405 | 406 | x = x.permute(1, 0, 2) # NLD -> LND 407 | x = self.transformer(x) 408 | x = x.permute(1, 0, 2) # LND -> NLD 409 | 410 | x = self.attn_pool(x) 411 | x = self.ln_post(x) 412 | x = x @ self.proj 413 | 414 | return x 415 | 416 | def encode(self, image_paths: List[str]): 417 | images = [] 418 | for image_path in image_paths: 419 | if image_path.startswith("http://") or image_path.startswith("https://"): 420 | image = Image.open(requests.get(image_path, stream=True).raw) 421 | else: 422 | image = Image.open(image_path) 423 | image = image.convert("RGB") 424 | images.append(self.image_transform(image)) 425 | images = torch.stack(images, dim=0) 426 | return self(images) 427 | -------------------------------------------------------------------------------- /Quickstart.md: -------------------------------------------------------------------------------- 1 | # 🚀Quick Start 2 | 3 | ## Data preprocessing 4 | 5 | Please follow the **Dataset Access** section of the [README.md](README.md) to prepare the data, and run the `preprocessing.py` script as instructed. Ensure that the structure of the `./data` directory is as shown below: 6 | 7 | ``` 8 | GUI-Odyssey 9 | ├── data 10 | │ ├── annotations 11 | │ │ └── *.json 12 | │ ├── screenshots 13 | │ │ └── *.png 14 | │ ├── splits 15 | │ │ ├── app_split.json 16 | │ │ ├── device_split.json 17 | │ │ ├── random_split.json 18 | │ │ └── task_split.json 19 | │ ├── format_converter.py 20 | │ └── preprocessing.py 21 | └── ... 22 | ``` 23 | 24 | Next, run the following command to generate chat-format data for training and testing. The `his_len` parameter can be set to specify the length of historical information: 25 | 26 | ```shell 27 | cd data 28 | python format_converter.py --his_len 4 29 | ``` 30 | 31 | ## Build OdysseyAgent upon Qwen-VL-Chat 32 | 33 | The OdysseyAgent is bulit upon [Qwen-VL](https://github.com/QwenLM/Qwen-VL). 34 | 35 | Before running, set up the environment and install the required packages: 36 | 37 | ```shell 38 | cd src 39 | pip install -r requirements.txt 40 | ``` 41 | 42 | Next, initialize `OdysseyAgent` using the weights from `Qwen-VL-Chat`: 43 | 44 | ```shell 45 | python merge_weight.py 46 | ``` 47 | 48 | Further, we also provide four variants of OdysseyAgent: 49 | - [OdysseyAgent-Random](https://huggingface.co/hflqf88888/OdysseyAgent-random) 50 | - [OdysseyAgent-Task](https://huggingface.co/hflqf88888/OdysseyAgent-task) 51 | - [OdysseyAgent-Device](https://huggingface.co/hflqf88888/OdysseyAgent-device) 52 | - [OdysseyAgent-App](https://huggingface.co/hflqf88888/OdysseyAgent-app) 53 | 54 | Each fine-tuned on `Train-Random`, `Train-Task`, `Train-Device`, and `Train-App` respectively. 55 | 56 | ### Fine-tuning 57 | 58 | Specify the path to the `OdysseyAgent` and the chat-format training data generated in the `Data preprocessing` stage (one of the four splits) in the `script/train.sh` file. 
Then, run the following command: 59 | 60 | ```shell 61 | cd src 62 | bash script/train.sh 63 | ``` 64 | 65 | ### Evalutaion 66 | 67 | Specify the path to the checkpoint and dataset split (one of `app_split`, `device_split`, `random_split`, `task_split`) in the `script/eval.sh` file. Then, run the following command: 68 | 69 | ```shell 70 | cd src 71 | bash script/eval.sh 72 | ``` 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GUI Odyssey 2 | 3 | **This repository is the official implementation of GUI Odyssey.** 4 | 5 | > [GUI Odyssey: A Comprehensive Dataset for Cross-App GUI Navigation on Mobile Devices](https://arxiv.org/abs/2406.08451) 6 | > Quanfeng Lu, Wenqi Shao✉️⭐️, Zitao Liu, Fanqing Meng, Boxuan Li, Botong Chen, Siyuan Huang, Kaipeng Zhang, Yu Qiao, Ping Luo✉️ 7 | > ✉️ Wenqi Shao (shaowenqi@pjlab.org.cn) and Ping Luo (pluo@cs.hku.hk) are correponding authors. 8 | > ⭐️ Wenqi Shao is project leader. 9 | 10 | 11 | ## 💡 News 12 | 13 | - `2024/06/24`: The data of [GUI Odyssey](https://arxiv.org/pdf/2406.08451) is released! Please check out [OpenGVLab/GUI-Odyssey](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey)! 14 | - `2024/06/13`: The paper of [GUI Odyssey](https://arxiv.org/pdf/2406.08451) is released! 15 | 16 | 17 | 18 | ## 🔆 Introduction 19 | GUI Odyssey is a comprehensive dataset for training and evaluating **cross-app** navigation agents. GUI Odyssey consists of 7,735 episodes from 6 mobile devices, spanning 6 types of cross-app tasks, 201 apps, and 1.4K app combos. 20 | ![overview](assets/dataset_overview.jpg) 21 | 22 | 23 | ## 🛠️ Data collection pipeline 24 | GUI Odyssey comprises six categories of navigation tasks. For each category, we construct instruction templates with items and apps selected from a predefined pool, resulting in a vast array of unique instructions for annotating GUI episodes. Human demonstrations on an Android emulator capture the metadata of each episode in a comprehensive format. After rigorous quality checks, GUI Odyssey includes 7,735 validated cross-app GUI navigation episodes. 25 | ![pipeline](assets/pipeline.png) 26 | 27 | 28 | ## 📝 Statistics 29 | 30 |
31 |
32 | Splits | # Episodes | # Unique Prompts | # Avg. Steps | Data location | Model
33 | :---------: | :---------: | :-----------: | :--------------: | :-----------: | :-----------:
34 | **Total** | **7,735** | **7,735** | **15.4** | [GUI-Odyssey](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey) | [OdysseyAgent](https://huggingface.co/collections/hflqf88888/gui-odyssey-6683bac37ad6fe37b1215c18)
35 | Train-Random \& Test-Random | 5,802 / 1,933 | 5,802 / 1,933 | 15.4 / 15.2 | [random_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-Random](https://huggingface.co/hflqf88888/OdysseyAgent-random)
36 | Train-Task \& Test-Task | 6,719 / 1,016 | 6,719 / 1,016 | 15.0 / 17.6 | [task_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-Task](https://huggingface.co/hflqf88888/OdysseyAgent-task)
37 | Train-Device \& Test-Device | 6,473 / 1,262 | 6,473 / 1,262 | 15.4 / 15.0 | [device_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-Device](https://huggingface.co/hflqf88888/OdysseyAgent-device)
38 | Train-App \& Test-App | 6,596 / 1,139 | 6,596 / 1,139 | 15.4 / 15.3 | [app_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-App](https://huggingface.co/hflqf88888/OdysseyAgent-app)
39 |
40 |
41 | 42 | ## 💫 Dataset Access 43 | 44 | The whole GUI Odyssey is hosted on [Huggingface](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey). 45 | 46 | Clone the entire dataset from Huggingface: 47 | ```shell 48 | git clone https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey 49 | ``` 50 | And then move the cloned dataset into `./data` directory. After that, the structure of `./data` should look like this: 51 | 52 | 53 | ``` 54 | GUI-Odyssey 55 | ├── data 56 | │ ├── annotations 57 | │ │ └── *.json 58 | │ ├── screenshots 59 | │ │ └── data_* 60 | │ │ └── *.png 61 | │ ├── splits 62 | │ │ ├── app_split.json 63 | │ │ ├── device_split.json 64 | │ │ ├── random_split.json 65 | │ │ └── task_split.json 66 | │ ├── format_converter.py 67 | │ └── preprocessing.py 68 | └── ... 69 | ``` 70 | 71 | Then organize the screenshots folder: 72 | 73 | ```shell 74 | cd data 75 | python preprocessing.py 76 | ``` 77 | 78 | Finally, the structure of `./data` should look like this: 79 | 80 | ``` 81 | GUI-Odyssey 82 | ├── data 83 | │ ├── annotations 84 | │ │ └── *.json 85 | │ ├── screenshots 86 | │ │ └── *.png 87 | │ ├── splits 88 | │ │ ├── app_split.json 89 | │ │ ├── device_split.json 90 | │ │ ├── random_split.json 91 | │ │ └── task_split.json 92 | │ ├── format_converter.py 93 | │ └── preprocessing.py 94 | └── ... 95 | ``` 96 | 97 | 98 | ## ⚙️ Detailed Data Information 99 | Please refer to [this](introduction.md). 100 | 101 | 102 | ## 🚀 Quick Start 103 | 104 | Please refer to [this](Quickstart.md) to quick start. 105 | 106 | ## 📖 Release Process 107 | 108 | - [x] Dataset 109 | - [x] Screenshots of GUI Odyssey 110 | - [x] annotations of GUI Odyssey 111 | - [x] split files of GUI Odyssey 112 | - [x] Code 113 | - [x] data preprocessing code 114 | - [x] inference code 115 | - [x] Models 116 | 117 | 118 | ## 🖊️ Citation 119 | If you feel GUI Odyssey useful in your project or research, please kindly use the following BibTeX entry to cite our paper. Thanks! 
120 | ```bib 121 | @article{lu2024gui, 122 | title={GUI Odyssey: A Comprehensive Dataset for Cross-App GUI Navigation on Mobile Devices}, 123 | author={Lu, Quanfeng and Shao, Wenqi and Liu, Zitao and Meng, Fanqing and Li, Boxuan and Chen, Botong and Huang, Siyuan and Zhang, Kaipeng and Qiao, Yu and Luo, Ping}, 124 | journal={arXiv preprint arXiv:2406.08451}, 125 | year={2024} 126 | } 127 | ``` 128 | 129 | 132 | -------------------------------------------------------------------------------- /assets/dataset_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/assets/dataset_overview.jpg -------------------------------------------------------------------------------- /assets/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/assets/pipeline.jpg -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/assets/pipeline.png -------------------------------------------------------------------------------- /data/format_converter.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import argparse 3 | import numpy as np 4 | 5 | ### 6 | current_path = os.path.abspath(__file__) 7 | DATA_DIR = os.path.dirname(current_path) 8 | pic_base = os.path.join(DATA_DIR, 'screenshots') 9 | anno_base = os.path.join(DATA_DIR, 'annotations') 10 | train_anno_base = os.path.join(DATA_DIR, 'train_anno') 11 | test_anno_base = os.path.join(DATA_DIR, 'test_anno') 12 | split_base = os.path.join(DATA_DIR, 'splits') 13 | 14 | PROMPT = "I'm looking for guidance on how to " 15 | 16 | ### 17 | 18 | def decode_action(action, info): 19 | if action == 'CLICK' or action == "LONG_PRESS": 20 | if info == 'KEY_HOME': 21 | gt = 'PRESS_HOME' 22 | elif info == 'KEY_BACK': 23 | gt = 'PRESS_BACK' 24 | elif info == 'KEY_APPSELECT': 25 | gt = 'PRESS_RECENT' 26 | elif type(info) == list: 27 | gt = f"{action}: {tuple(info[0])}" 28 | else: 29 | raise ValueError(f'Unknown click action {info}') 30 | 31 | elif action == 'SCROLL': 32 | start = np.array(info[0]) 33 | end = np.array(info[1]) 34 | delta = end - start 35 | delta_abs = np.abs(delta) 36 | lr = 'LEFT' if delta[0] < 0 else 'RIGHT' 37 | ud = 'UP' if delta[1] < 0 else 'DOWN' 38 | if delta_abs[0] > delta_abs[1]: 39 | gt = f"SCROLL: {lr}" 40 | else: 41 | gt = f"SCROLL: {ud}" 42 | 43 | elif action == 'TEXT': 44 | gt = f'TYPE: {info}' 45 | elif action == 'COMPLETE': 46 | gt = action 47 | elif action == 'INCOMPLETE': 48 | gt = 'IMPOSSIBLE' 49 | else: 50 | raise ValueError(f'Unknown action {action}') 51 | return gt 52 | 53 | def build_train_chat(split_fp='./splits/random_split.json', his_len=4): 54 | os.makedirs(train_anno_base, exist_ok=True) 55 | name = os.path.basename(split_fp).split('.')[0] 56 | train_split = json.load(open(split_fp))['train'] 57 | res = [] 58 | idx = 0 59 | for f in train_split: 60 | this_res = [] 61 | fp = os.path.join(anno_base, f) 62 | data = json.load(open(fp)) 63 | instruction = data['task_info']['instruction'] 64 | steps = data['steps'] 65 | 66 | history_screenshot, history_action = [], [] 67 | 68 | for step in steps: 
69 | image = step['screenshot'] 70 | action = step['action'] 71 | info = step['info'] 72 | 73 | gt = decode_action(action, info) 74 | img_abs_path = os.path.join(pic_base, image) 75 | value = f"Picture 1: {img_abs_path}\n{PROMPT}{instruction}" 76 | 77 | his_str = '' 78 | for hidx, act in enumerate(history_action[-his_len:]): 79 | his_str += f'{hidx + 1}. {act}\n' 80 | 81 | if len(history_action) > 0 and his_len > 0: 82 | value += f'\nPrevious screenshots: image-history: {img_abs_path}' 83 | value += f'\nPrevious Actions: {his_str}' 84 | 85 | conversations = [{"from": "user", "value": value}, {"from": "assistant", "value": gt}] 86 | 87 | this_res.append({ 88 | 'id': f'GUIOdyssey_{name}_{idx}', 89 | 'image': img_abs_path, 90 | 'conversations': conversations, 91 | 'history': str(history_screenshot), 92 | }) 93 | idx += 1 94 | 95 | history_screenshot.append(img_abs_path) 96 | history_action.append(gt) 97 | 98 | res.extend(this_res) 99 | 100 | json.dump(res, open(os.path.join(train_anno_base, os.path.basename(split_fp)), 'w'), indent=4, ensure_ascii=False) 101 | 102 | 103 | def build_test(split_fp='./splits/random_split.json', his_len=4): 104 | os.makedirs(test_anno_base, exist_ok=True) 105 | name = os.path.basename(split_fp).split('.')[0] 106 | test_split = json.load(open(split_fp))['test'] 107 | res = [] 108 | idx = 0 109 | for f in test_split: 110 | this_res = [] 111 | fp = os.path.join(anno_base, f) 112 | data = json.load(open(fp)) 113 | instruction = data['task_info']['instruction'] 114 | steps = data['steps'] 115 | category = data['task_info']['category'] 116 | step_length = data['step_length'] 117 | 118 | history_screenshot = [] 119 | history_action = [] 120 | 121 | for step in steps: 122 | image = step['screenshot'] 123 | img_abs_path = os.path.join(pic_base, image) 124 | action = step['action'] 125 | info = step['info'] 126 | gt = decode_action(action, info) 127 | 128 | this_res.append({ 129 | 'id': f'GUIOdyssey_{name}_{idx}', 130 | 'image': img_abs_path, 131 | 'question': instruction, 132 | 'answer': gt, 133 | 'category': category, 134 | 'step_length': step_length, 135 | 'history_action': str(history_action), 136 | 'history_screenshot': str(history_screenshot), 137 | }) 138 | idx += 1 139 | 140 | history_screenshot.append(img_abs_path) 141 | history_action.append(gt) 142 | 143 | res.extend(this_res) 144 | 145 | json.dump(res, open(os.path.join(test_anno_base, os.path.basename(split_fp)), 'w'), indent=4, ensure_ascii=False) 146 | 147 | 148 | def make_his_idx(train_base=train_anno_base, test_base=test_anno_base): 149 | savep = './his_index.json' 150 | his_dict = {} 151 | for subsplit in os.listdir(train_base): 152 | subp = os.path.join(train_base, subsplit) 153 | 154 | data_all = json.load(open(subp)) 155 | for data in data_all: 156 | img = data['image'] 157 | history = eval(data['history']) 158 | if img in his_dict: 159 | assert his_dict[img] == history 160 | else: 161 | his_dict[img] = history 162 | 163 | for subsplit in os.listdir(test_base): 164 | subp = os.path.join(test_base, subsplit) 165 | data_all = json.load(open(subp)) 166 | for data in data_all: 167 | img = data['image'] 168 | history = eval(data['history_screenshot']) 169 | if img in his_dict: 170 | assert his_dict[img] == history 171 | else: 172 | his_dict[img] = history 173 | 174 | print(len(his_dict)) 175 | json.dump(his_dict, open(savep, 'w'), indent=4, ensure_ascii=False) 176 | 177 | 178 | def main(his_len): 179 | for f in os.listdir(split_base): 180 | fp = os.path.join(split_base, f) 181 | build_train_chat(fp, 
his_len) 182 | build_test(fp, his_len) 183 | 184 | make_his_idx() 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('--his_len', type=int, default=4) 189 | args = parser.parse_args() 190 | main(args.his_len) -------------------------------------------------------------------------------- /data/preprocessing.py: -------------------------------------------------------------------------------- 1 | import os, shutil, pathlib 2 | 3 | screenshot_bp = './screenshots' 4 | anno_bp = './annotations' 5 | 6 | def reindex_screenshot(): 7 | parent_dir = pathlib.Path(screenshot_bp) 8 | for subdir in parent_dir.iterdir(): 9 | if subdir.is_dir(): 10 | for file_path in subdir.iterdir(): 11 | if file_path.is_file(): 12 | shutil.move(str(file_path), str(parent_dir / file_path.name)) 13 | print(f'{str(subdir)} ok.') 14 | subdir.rmdir() 15 | 16 | 17 | 18 | 19 | if __name__ == '__main__': 20 | reindex_screenshot() -------------------------------------------------------------------------------- /introduction.md: -------------------------------------------------------------------------------- 1 | # ⚙️ Data Structure 2 | 3 | 4 | ## Data Fields 5 | 6 | Each field of annotation is as follows: 7 | 8 | * `episode_id`(str): the unique identifier of this episode. 9 | * `device_info`(dict): the detailed information of the virtual device from which the episode was collected. 10 | * `product`(str): the product name of the emulator. 11 | * `release_version`(str): the Android API level of the emulator. 12 | * `sdk_version`(str): the version of the software development kit used for the emulator. 13 | * `h`(int): the height of the device screen. 14 | * `w`(int): the width of the device screen. 15 | * `device_name`(str): the name of the virtual device, one of **Pixel Fold**, **Pixel Tablet**, **Pixel 8 Pro**, **Pixel 7 Pro**, **Medium Phone**, **Small Phone** 16 | * `task_info`(dict): the detailed information of the task from which the episode was collected. 17 | * `category`(str): the category of this task, one of **Multi_Apps**, **Web_Shopping**, **General_Tool**, **Information_Management**, **Media_Entertainment**, **Social_Sharing** 18 | * `app`(list[str]): the Apps used for this task. 19 | * `meta_task`(str): the template for this task, e.g., "Search for the next {} and set a reminder." 20 | * `task`(str): the specific task created by filling in the meta-task, e.g., "Search for the next New York Fashion Week and set a reminder." 21 | * `instruction`(str): the detailed and rephrased version of the task, including specific tools or applications, e.g., "Utilize DuckDuckgo to find the dates for the next New York Fashion Week and then use TickTick to set a reminder for the event." 22 | * `step_length`(int): the total number of steps in this episode. 23 | * `steps`(list[dict]): each individual step of this episode. Including the following fields: 24 | * `step`(int): each step within the episode is identified by a zero-indexed step number, indicating its position in sequence within the episode. For example, if the *step* is 1, it corresponds to the second step of the episode. 25 | * `screenshot`(str): the current screenshot of this step 26 | * `action`(str): the corresponding action of this step, one of **CLICK**, **SCROLL**, **LONG_PRESS**, **TYPE**, **COMPLETE**, **IMPOSSIBLE**, **HOME**, **BACK** 27 | * `info`(Union[str, list[list]]): provides specific details required to perform the action specified in the *action* field. 
Note that all the coordinates are normalized to the range of [0, 1000]. 28 | * if action is *CLICK*, info contains the coordinates (x, y) to click on or one of the special keys *KEY_HOME*, *KEY_BACK*, *KEY_RECENT*. 29 | * if action is *LONG_PRESS*, info contains the coordinates (x, y) for the long press. 30 | * if action is *SCROLL*, info contains the starting (x1, y1) and ending (x2, y2) coordinates of the scroll action. 31 | * if action is any other value, info is empty (""). 32 | * `ps`(str): provides additional details or context depending on the value of the action field. 33 | * if action is *COMPLETE* or *IMPOSSIBLE*: may contain any additional information from the annotator about why the task is complete or why it was impossible to complete. 34 | * if action is *SCROLL*: contains the complete trajectory of the scroll action. 35 | 36 | 37 | 38 | ## Data Splits 39 | We can evaluate the in- and out-of-domain performance of an agent by splitting GUI Odyssey in two complementary ways: a purely random split, and held-out splits organized by task, device, or app: 40 | 41 | * **random_split**: randomly splits the dataset into a training set and a test set with a ratio of $3:1$. 42 | 43 | The other three splits organize the data so that the training set covers a portion of the apps/tasks/devices and the test set covers the remaining ones: 44 | 45 | 46 | * **task_split**: proportionally samples meta-tasks from six categories. The tasks in the test set differ significantly from those in the training set. This partitioning method allows for a robust assessment of an agent's generalization capabilities across diverse tasks. 47 | 48 | * **device_split**: selects episodes annotated on the *Fold Phone*, which differs significantly from other devices such as smartphones and tablets, as the test set. 49 | 50 | * **app_split**: splits based on the apps. The apps in the test set differ significantly from those in the training set. 51 | 52 | Each of the four splits above has a corresponding JSON file with the following fields: 53 | * `train`(list[str]): the list of annotation filenames for the training set; each filename corresponds to an *episode_id*. 54 | * `test`(list[str]): the list of annotation filenames for the test set; each filename corresponds to an *episode_id*.
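
As a quick illustration, the minimal sketch below loads one test episode of the `random_split` and walks through its steps using the fields described above. It assumes the extracted `annotations/` and `splits/` folders sit under `data/`, matching the layout that `data/format_converter.py` expects; adjust `DATA_DIR` to your local copy.

```python
import json
import os

DATA_DIR = "data"  # assumed location of the extracted dataset; adjust as needed

# Split files only store annotation filenames, which double as episode identifiers.
split = json.load(open(os.path.join(DATA_DIR, "splits", "random_split.json")))
episode_file = split["test"][0]

episode = json.load(open(os.path.join(DATA_DIR, "annotations", episode_file)))
print("instruction:", episode["task_info"]["instruction"])
print("device:", episode["device_info"]["device_name"], "| steps:", episode["step_length"])

for step in episode["steps"]:
    action, info = step["action"], step["info"]
    if action in ("CLICK", "LONG_PRESS") and isinstance(info, list):
        # Coordinates are normalized to [0, 1000]; info[0] holds the (x, y) point.
        print(step["step"], action, tuple(info[0]))
    else:
        print(step["step"], action, info)
```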
-------------------------------------------------------------------------------- /src/SimSun.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/src/SimSun.ttf -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import random 4 | import json, os 5 | 6 | 7 | class GUIOdysseyDataset(Dataset): 8 | def __init__(self, args): 9 | super().__init__() 10 | 11 | self.dataset = args.dataset 12 | 13 | self.data = self.load_GUIOdyssey() 14 | self.len = len(self.data) 15 | 16 | random.shuffle(self.data) 17 | print(self.len) 18 | 19 | def __len__(self): 20 | return self.len 21 | 22 | def __getitem__(self, idx): 23 | return self.data[idx] 24 | 25 | def load_GUIOdyssey(self): 26 | d = json.load(open(self.dataset)) 27 | return d -------------------------------------------------------------------------------- /src/eval_mm/GUIOdyssey_action_matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | TEXT_ANLS_THRESHOLD = 0.5 4 | CLICK_COORD_THRESHOLD = 0.14 5 | 6 | def levenshtein_distance(s1, s2): 7 | if len(s1) > len(s2): 8 | s1, s2 = s2, s1 9 | 10 | distances = range(len(s1) + 1) 11 | for i2, c2 in enumerate(s2): 12 | distances_ = [i2+1] 13 | for i1, c1 in enumerate(s1): 14 | if c1 == c2: 15 | distances_.append(distances[i1]) 16 | else: 17 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 18 | distances = distances_ 19 | return distances[-1] 20 | 21 | 22 | def text_matching(gt, pred): 23 | gt = gt.strip() 24 | pred = pred.strip() 25 | if gt in pred or pred in gt: 26 | return True 27 | 28 | dist = levenshtein_distance(gt, pred) 29 | length = max(len(gt), len(pred)) 30 | value = 0.0 if length == 0 else float(dist) / float(length) 31 | value = 1 - value 32 | return value >= TEXT_ANLS_THRESHOLD 33 | 34 | 35 | def click_matching(gt_info, pred_info): 36 | if type(pred_info) == str: 37 | pred_info = eval(pred_info) 38 | if type(gt_info) == str: 39 | gt_info = eval(gt_info) 40 | 41 | pred = np.asarray(pred_info) / 1000 42 | gt = np.asarray(gt_info) / 1000 43 | 44 | return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD 45 | 46 | 47 | 48 | def action_matching(pred_action, pred_info, gt_action, gt_info): 49 | pred_action = pred_action.strip() 50 | if type(pred_info) == str: 51 | pred_info = pred_info.strip() 52 | gt_action = gt_action.strip() 53 | if type(gt_info) == str: 54 | gt_info = gt_info.strip() 55 | 56 | if pred_action != gt_action: 57 | return {'is_correct': 'no', 'info': 'action_fail'} 58 | 59 | if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']: 60 | return {'is_correct': 'yes', 'info': 'action_correct'} 61 | 62 | elif gt_action == 'TYPE': 63 | text_flag = text_matching(gt_info, pred_info) 64 | 65 | if text_flag: 66 | return {'is_correct': 'yes', 'info': 'type_correct'} 67 | else: 68 | return {'is_correct': 'no', 'info': 'type_fail'} 69 | 70 | elif gt_action == 'SCROLL': 71 | if gt_info.lower() == pred_info.lower(): 72 | return {'is_correct': 'yes', 'info': 'scroll_correct'} 73 | else: 74 | return {'is_correct': 'no', 'info': 'scroll_fail'} 75 | 76 | elif gt_action == 'CLICK' or gt_action == 'LONG_PRESS': 77 | click_flag = click_matching(gt_info, pred_info) 78 | 79 | if 
click_flag: 80 | return {'is_correct': 'yes', 'info': 'click_correct'} 81 | else: 82 | return {'is_correct': 'no', 'info': 'click_fail'} 83 | 84 | else: 85 | raise ValueError('Invalid action type') -------------------------------------------------------------------------------- /src/eval_mm/evaluate_GUIOdyssey.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os, sys 5 | import torch 6 | import random 7 | import time 8 | from functools import partial 9 | from tqdm import tqdm 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | import warnings 13 | from GUIOdyssey_action_matching import action_matching 14 | import numpy as np 15 | 16 | warnings.filterwarnings("ignore") 17 | current_path = os.path.abspath(__file__) 18 | DATA_DIR = os.path.dirname(os.path.dirname(os.path.dirname(current_path))) 19 | 20 | sys.path.append(os.path.join(DATA_DIR, 'OdysseyAgent')) 21 | sys.path.append(os.path.join(DATA_DIR, 'src')) 22 | from qwen_generation_utils import make_context, decode_tokens 23 | 24 | 25 | IMAGE_HISTORY = True 26 | 27 | ds_collections = { 28 | 'app_split': { 29 | 'test': '../data/test_anno/app_split.json', 30 | 'metric': 'macro' 31 | }, 32 | 'device_split': { 33 | 'test': '../data/test_anno/device_split.json', 34 | 'metric': 'macro' 35 | }, 36 | 'random_split': { 37 | 'test': '../data/test_anno/random_split.json', 38 | 'metric': 'micro' 39 | }, 40 | 'task_split': { 41 | 'test': '../data/test_anno/task_split.json', 42 | 'metric': 'macro' 43 | } 44 | } 45 | 46 | 47 | def simple_decode(gt): 48 | gts = gt.split(':') 49 | gt_action = gts[0].strip() 50 | if len(gts) > 1: 51 | action = gt_action 52 | info = gts[1].strip() 53 | if action in ['CLICK', "LONG_PRESS"]: 54 | info = eval(info) 55 | else: 56 | action = gt_action 57 | info = "" 58 | return {"action": action, "info": info} 59 | 60 | def stat_result(eval_dict, metric): 61 | text_correct = sum([1 for _ in eval_dict if _['info'] == 'type_correct']) 62 | type_correct = sum([1 for _ in eval_dict if _['info'] != 'action_fail']) 63 | text_total = sum([1 for _ in eval_dict if _['info'].startswith('type_')]) 64 | 65 | if metric == 'macro': 66 | action_correct = sum([1 for _ in eval_dict if _['is_correct'] == 'yes']) 67 | AMS = round(action_correct / len(eval_dict) * 100, 2) 68 | SR_cnt, SR_tot, SR = check_SR(eval_dict) 69 | elif metric == 'micro': 70 | task_cate_dict = {} 71 | acc_list = [] 72 | SR_list = [] 73 | for sample in eval_dict: 74 | cat = sample['more_info']['category'] 75 | if cat not in task_cate_dict: 76 | task_cate_dict[cat] = [] 77 | task_cate_dict[cat].append(sample) 78 | assert len(task_cate_dict) == 6 79 | for k, v in task_cate_dict.items(): 80 | SR_cnt, SR_tot, SR = check_SR(v) 81 | SR_list.append((SR)) 82 | acc = round(sum([1 for x in v if x['is_correct'] == 'yes']) / len(v) * 100, 2) 83 | acc_list.append(acc) 84 | print(f'category: {k}, AMS: {acc}, SR: {SR}') 85 | 86 | AMS = np.round(np.mean(acc_list), 2) 87 | SR = np.round(np.mean(SR_list), 2) 88 | 89 | else: 90 | raise ValueError(f'No metric {metric} found.') 91 | 92 | info = { 93 | 'AMS': AMS, 94 | 'SR': SR, 95 | 'total': len(eval_dict), 96 | 'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100), 97 | 'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_correct / text_total * 100), 98 | } 99 | 100 | print(info) 101 | return info 102 | 103 | def 
action_matching_evaluation(pred_output, metric='macro'): 104 | eval_dict = [] 105 | for idx, sample in enumerate(pred_output): 106 | question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info'] 107 | sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info} 108 | 109 | gt_simple_info = simple_decode(gt) 110 | gt_action = gt_simple_info['action'] 111 | gt_info = gt_simple_info['info'] 112 | 113 | try: 114 | pred_simple_info = simple_decode(pred) 115 | pred_action = pred_simple_info['action'] 116 | pred_info = pred_simple_info['info'] 117 | except: 118 | print('eval err:', idx, pred) 119 | log_info = {'is_correct': 'no', 'info': 'invalid'} 120 | sample_eval_dict.update(log_info) 121 | eval_dict.append(sample_eval_dict) 122 | continue 123 | 124 | try: 125 | check_match = action_matching(pred_action, pred_info, gt_action, gt_info) 126 | except Exception as exc: 127 | print('eval err:', gt, pred, exc) 128 | check_match = {'is_correct': 'no', 'info': 'invalid'} 129 | 130 | sample_eval_dict.update(check_match) 131 | eval_dict.append(sample_eval_dict) 132 | 133 | 134 | info = stat_result(eval_dict, metric) 135 | metrics = {"info": info, "pred": eval_dict} 136 | return metrics 137 | 138 | 139 | 140 | def check_SR(eval_dict): 141 | episode_dict = {} 142 | steps_map = {} 143 | for data in eval_dict: 144 | if 'img' in data: img = data['img'] 145 | elif 'image' in data: img = data['image'] 146 | else: img = data['question'].split('')[0].split('')[1] 147 | img = os.path.basename(img) 148 | tail = img.split('_')[-1] 149 | episode = img.replace(f'_{tail}', '') 150 | if episode not in episode_dict: 151 | episode_dict[episode] = [] 152 | else: 153 | assert steps_map[episode] == data['more_info']['step_length'] 154 | 155 | info = data['is_correct'] 156 | episode_dict[episode].append(info) 157 | steps_map[episode] = data['more_info']['step_length'] 158 | 159 | cnt, tot = 0, 0 160 | for k, v in episode_dict.items(): 161 | if len(v) != steps_map[k]: 162 | print(f'step length of {k} does not match.') 163 | continue 164 | tot += 1 165 | v = list(set(v)) 166 | if len(v) == 1 and v[0] == 'yes': 167 | cnt += 1 168 | 169 | SR = round(cnt / tot * 100, 2) 170 | print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}') 171 | return cnt, tot, SR 172 | 173 | 174 | def rank0_print(*args): 175 | if torch.distributed.get_rank() == 0: 176 | print(*args) 177 | 178 | 179 | def collate_fn(batches, tokenizer): 180 | question = [_['question'] for _ in batches] 181 | raw_texts = [_['raw_text'] for _ in batches] 182 | gt = [_['gt'] for _ in batches] 183 | more_info = [_['more_info'] for _ in batches] 184 | input_ids = tokenizer(raw_texts, return_tensors='pt', padding='longest') 185 | 186 | return question, raw_texts, input_ids.input_ids, input_ids.attention_mask, gt, more_info 187 | 188 | 189 | class LazySupervisedDataset(torch.utils.data.Dataset): 190 | def __init__(self, datapath, tokenizer: transformers.PreTrainedTokenizer, his_len, max_window_size, chat_format): 191 | super(LazySupervisedDataset, self).__init__() 192 | self.aitw = json.load(open(datapath)) 193 | if len(self.aitw) > 50000: 194 | self.aitw = random.sample(self.aitw, len(self.aitw) // 10) 195 | self.tokenizer = tokenizer 196 | self.max_window_size = max_window_size 197 | self.chat_format = chat_format 198 | self.his_len = his_len 199 | 200 | def __len__(self): 201 | return len(self.aitw) 202 | 203 | def __getitem__(self, idx): 204 | data = self.aitw[idx] 205 | img = 
data['image'] 206 | question = f"Picture 1: {img}\nI'm looking for guidance on how to {data['question']}" 207 | answer = data['answer'] 208 | 209 | history_action = eval(data['history_action'])[-self.his_len:] 210 | if IMAGE_HISTORY: 211 | if len(history_action) > 0: 212 | his_img = f'\nPrevious screenshots: image-history: {img}' 213 | his_str = '\nPrevious Actions: ' 214 | for idx, hi in enumerate(history_action): 215 | his_str += f"{idx+1}. {hi}\n" 216 | 217 | question = f"{question}{his_img}{his_str}" 218 | else: 219 | if len(history_action) > 0: 220 | his_str = '\nPrevious Actions: ' 221 | for idx, hi in enumerate(history_action): 222 | his_str += f"{idx+1}. {hi}\n" 223 | 224 | question = f"{question}{his_str}" 225 | 226 | raw_text, _ = make_context(self.tokenizer, question, system="You are a helpful assistant.", max_window_size=self.max_window_size, chat_format=self.chat_format) 227 | more_info = {'category': data['category'], 'step_length': data['step_length']} 228 | return { 229 | 'raw_text': raw_text, 230 | 'question': question, 231 | 'gt': answer, 232 | 'more_info': more_info 233 | } 234 | 235 | 236 | class InferenceSampler(torch.utils.data.sampler.Sampler): 237 | def __init__(self, size): 238 | self._size = int(size) 239 | assert size > 0 240 | self._rank = torch.distributed.get_rank() 241 | self._world_size = torch.distributed.get_world_size() 242 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 243 | 244 | @staticmethod 245 | def _get_local_indices(total_size, world_size, rank): 246 | shard_size = total_size // world_size 247 | left = total_size % world_size 248 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 249 | 250 | begin = sum(shard_sizes[:rank]) 251 | end = min(sum(shard_sizes[:rank + 1]), total_size) 252 | return range(begin, end) 253 | 254 | def __iter__(self): 255 | yield from self._local_indices 256 | 257 | def __len__(self): 258 | return len(self._local_indices) 259 | 260 | 261 | if __name__ == '__main__': 262 | parser = argparse.ArgumentParser() 263 | parser.add_argument('--checkpoint', type=str, default='/path/to/model') 264 | parser.add_argument('--dataset', type=str, default='random_split') 265 | parser.add_argument('--batch-size', type=int, default=4) 266 | parser.add_argument('--num-workers', type=int, default=12) 267 | parser.add_argument('--seed', type=int, default=2024) 268 | parser.add_argument('--output-path', type=str, default='output_res') 269 | parser.add_argument('--image-history', type=str, default='yes') 270 | parser.add_argument('--his_len', type=int, default=4) 271 | args = parser.parse_args() 272 | 273 | if args.image_history == 'no': 274 | IMAGE_HISTORY = False 275 | else: 276 | IMAGE_HISTORY = True 277 | 278 | torch.distributed.init_process_group(backend='nccl', world_size=int(os.getenv('WORLD_SIZE', '1')), rank=int(os.getenv('RANK', '0')),) 279 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 280 | rank0_print(args) 281 | rank0_print('load model...') 282 | model = AutoModelForCausalLM.from_pretrained(args.checkpoint, device_map='cuda', trust_remote_code=True).eval() 283 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True) 284 | tokenizer.padding_side = 'left' 285 | tokenizer.pad_token_id = tokenizer.eod_id 286 | 287 | rank0_print('init test set...') 288 | random.seed(args.seed) 289 | datapath = ds_collections[args.dataset]['test'] 290 | dataset = LazySupervisedDataset(datapath=datapath, tokenizer=tokenizer, his_len=args.his_len, 
max_window_size=6144, chat_format='chatml') 291 | 292 | dataloader = torch.utils.data.DataLoader( 293 | dataset=dataset, 294 | sampler=InferenceSampler(len(dataset)), 295 | batch_size=args.batch_size, 296 | num_workers=args.num_workers, 297 | pin_memory=True, 298 | drop_last=False, 299 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 300 | ) 301 | 302 | rank0_print(f'len of dataloader: {len(dataloader)}') 303 | outputs = [] 304 | for _, (question, raw_texts, input_ids, attention_mask, gt, more_info) in tqdm(enumerate(dataloader)): 305 | try: 306 | batch_input_ids = input_ids.to(model.device) 307 | batch_input_attention_mask = attention_mask.to(model.device) 308 | 309 | batch_out_ids = model.generate( 310 | input_ids=batch_input_ids, 311 | attention_mask=batch_input_attention_mask, 312 | do_sample=False, 313 | num_beams=1, 314 | length_penalty=1, 315 | num_return_sequences=1, 316 | use_cache=True, 317 | pad_token_id=tokenizer.eod_id, 318 | eos_token_id=tokenizer.eod_id, 319 | min_new_tokens=1, 320 | max_new_tokens=30, 321 | ) 322 | 323 | padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))] 324 | batch_response = [decode_tokens(batch_out_ids[i][padding_lens[i]:], tokenizer, raw_text_len=len(raw_texts[i]), context_length=(batch_input_ids[i].size(0)-padding_lens[i]), 325 | chat_format="chatml", verbose=False, errors='replace') for i in range(len(raw_texts))] 326 | 327 | for q, pred, _gt, info in zip(question, batch_response, gt, more_info): 328 | outputs.append({ 329 | 'question': q, 330 | 'pred': str(pred), 331 | 'gt': _gt, 332 | 'more_info': info, 333 | }) 334 | except Exception as e: 335 | print('error', e) 336 | print(_) 337 | continue 338 | 339 | print(f'rank {torch.distributed.get_rank()} finished inference.') 340 | torch.distributed.barrier() 341 | 342 | world_size = torch.distributed.get_world_size() 343 | merged_outputs = [None for _ in range(world_size)] 344 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 345 | 346 | merged_outputs = [json.loads(_) for _ in merged_outputs] 347 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 348 | 349 | if torch.distributed.get_rank() == 0: 350 | print(f"Saving predict result ...") 351 | # time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 352 | os.makedirs(args.output_path, exist_ok=True) 353 | model_name = str(args.checkpoint).replace('/', '_') 354 | savefile = os.path.join(args.output_path, f'{model_name}_{args.dataset}.json') 355 | json.dump(merged_outputs, open(savefile, 'w'), indent=4, ensure_ascii=False) 356 | 357 | print(f"Evaluating {args.dataset} ...") 358 | metrics = action_matching_evaluation(merged_outputs, metric=ds_collections[args.dataset]['metric']) 359 | 360 | output_data = {'dataset': args.dataset, 'model': model_name, 'metrics': metrics} 361 | json.dump(output_data, open(savefile, 'w'), indent=4, ensure_ascii=False) 362 | 363 | torch.distributed.barrier() 364 | -------------------------------------------------------------------------------- /src/finetune.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import logging 3 | import os, sys 4 | from typing import Dict, Optional, List 5 | import torch 6 | from torch.utils.data import Dataset 7 | from deepspeed import zero 8 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 9 | import transformers 10 | from transformers import Trainer, GPTQConfig, 
deepspeed 11 | from transformers.trainer_pt_utils import LabelSmoother 12 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 13 | from accelerate.utils import DistributedType 14 | from data_loader import GUIOdysseyDataset 15 | from PIL import Image 16 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 17 | 18 | current_path = os.path.abspath(__file__) 19 | DATA_DIR = os.path.dirname(os.path.dirname(current_path)) 20 | OdysseyAgent_model_path = os.path.join(DATA_DIR, 'OdysseyAgent') 21 | print(OdysseyAgent_model_path) 22 | sys.path.append(OdysseyAgent_model_path) 23 | 24 | 25 | @dataclass 26 | class ModelArguments: 27 | model_name_or_path: Optional[str] = field(default="/path/to/model") 28 | 29 | 30 | @dataclass 31 | class DataArguments: 32 | dataset: str = field(default='/path/to/chat_format_annotations', metadata={"help": "the dataset"}) 33 | 34 | @dataclass 35 | class TrainingArguments(transformers.TrainingArguments): 36 | cache_dir: Optional[str] = field(default=None) 37 | optim: str = field(default="adamw_torch") 38 | model_max_length: int = field( 39 | default=8192, 40 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 41 | ) 42 | use_lora: bool = False 43 | fix_vit: bool = True 44 | 45 | 46 | @dataclass 47 | class LoraArguments: 48 | lora_r: int = 64 49 | lora_alpha: int = 16 50 | lora_dropout: float = 0.05 51 | lora_target_modules: List[str] = field( 52 | default_factory=lambda: ["c_attn", "attn.c_proj", "w1", "w2"] ##["in_proj","out_proj","c_fc"] 53 | ) 54 | lora_weight_path: str = "" 55 | lora_bias: str = "none" 56 | q_lora: bool = False 57 | 58 | 59 | def maybe_zero_3(param): 60 | if hasattr(param, "ds_id"): 61 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE 62 | with zero.GatheredParameters([param]): 63 | param = param.data.detach().cpu().clone() 64 | else: 65 | param = param.detach().cpu().clone() 66 | return param 67 | 68 | 69 | # Borrowed from peft.utils.get_peft_model_state_dict 70 | def get_peft_state_maybe_zero_3(named_params, bias): 71 | if bias == "none": 72 | to_return = {k: t for k, t in named_params if "lora_" in k} 73 | elif bias == "all": 74 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 75 | elif bias == "lora_only": 76 | to_return = {} 77 | maybe_lora_bias = {} 78 | lora_bias_names = set() 79 | for k, t in named_params: 80 | if "lora_" in k: 81 | to_return[k] = t 82 | bias_name = k.split("lora_")[0] + "bias" 83 | lora_bias_names.add(bias_name) 84 | elif "bias" in k: 85 | maybe_lora_bias[k] = t 86 | for k, t in maybe_lora_bias: 87 | if bias_name in lora_bias_names: 88 | to_return[bias_name] = t 89 | else: 90 | raise NotImplementedError 91 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} 92 | return to_return 93 | 94 | local_rank = None 95 | 96 | def rank0_print(*args): 97 | if local_rank == 0: 98 | print(*args) 99 | 100 | 101 | def set_seed(seed: int): 102 | import numpy as np 103 | import random 104 | torch.manual_seed(seed) 105 | if torch.cuda.is_available(): 106 | torch.cuda.manual_seed_all(seed) 107 | np.random.seed(seed) 108 | random.seed(seed) 109 | os.environ["PYTHONHASHSEED"] = str(seed) 110 | 111 | 112 | def print_params(model): 113 | total_params = sum([p.numel() for p in model.parameters()]) 114 | trainable_params = sum([p.numel() for p in model.parameters() if p.requires_grad]) 115 | 116 | rank0_print(f"Total Parameters: {total_params}") 117 | rank0_print(f"Trainable Parameters: {trainable_params}") 118 | 119 | 120 | def 
safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"): 121 | """Collects the state dict and dump to disk.""" 122 | # check if zero3 mode enabled 123 | if deepspeed.is_deepspeed_zero3_enabled(): 124 | state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() 125 | else: 126 | if trainer.args.use_lora: 127 | state_dict = get_peft_state_maybe_zero_3( 128 | trainer.model.named_parameters(), bias 129 | ) 130 | else: 131 | state_dict = trainer.model.state_dict() 132 | if trainer.args.should_save and trainer.args.local_rank == 0: 133 | trainer._save(output_dir, state_dict=state_dict) 134 | 135 | 136 | def preprocess( 137 | sources, 138 | tokenizer: transformers.PreTrainedTokenizer, 139 | max_len: int, 140 | system_message: str = "You are a helpful assistant." 141 | ) -> Dict: 142 | roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} 143 | 144 | im_start = tokenizer.im_start_id 145 | im_end = tokenizer.im_end_id 146 | nl_tokens = tokenizer('\n').input_ids 147 | _system = tokenizer('system').input_ids + nl_tokens 148 | _user = tokenizer('user').input_ids + nl_tokens 149 | _assistant = tokenizer('assistant').input_ids + nl_tokens 150 | 151 | # Apply prompt templates 152 | input_ids, targets = [], [] 153 | for i, source in enumerate(sources): 154 | if roles[source[0]["from"]] != roles["user"]: 155 | source = source[1:] 156 | 157 | input_id, target = [], [] 158 | system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens 159 | input_id += system 160 | target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens 161 | assert len(input_id) == len(target) 162 | for j, sentence in enumerate(source): 163 | role = roles[sentence["from"]] 164 | _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens 165 | input_id += _input_id 166 | if role == '<|im_start|>user': 167 | _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens 168 | elif role == '<|im_start|>assistant': 169 | _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \ 170 | _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens 171 | else: 172 | raise NotImplementedError 173 | target += _target 174 | assert len(input_id) == len(target) 175 | input_id += [tokenizer.pad_token_id] * (max_len - len(input_id)) 176 | target += [IGNORE_TOKEN_ID] * (max_len - len(target)) 177 | input_ids.append(input_id[:max_len]) 178 | targets.append(target[:max_len]) 179 | input_ids = torch.tensor(input_ids, dtype=torch.int) 180 | targets = torch.tensor(targets, dtype=torch.int) 181 | 182 | return dict( 183 | input_ids=input_ids, 184 | labels=targets, 185 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 186 | ) 187 | 188 | 189 | def load_img(url): 190 | img = Image.open(url) 191 | return img 192 | 193 | class LazySupervisedDataset(Dataset): 194 | """Dataset for supervised fine-tuning.""" 195 | 196 | def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer, max_len: int, preprocess_strategy=None): 197 | super(LazySupervisedDataset, self).__init__() 198 | self.dataset = dataset 199 | self.tokenizer = tokenizer 200 | self.max_len = max_len 201 | self.preprocess_strategy = preprocess_strategy 202 | 203 | rank0_print("Formatting inputs...Skip in lazy mode") 204 | self.tokenizer = tokenizer 205 | self.cached_data_dict = {} 206 | 207 | def __len__(self): 208 | return len(self.dataset) 209 | 210 | def 
__getitem__(self, i) -> Dict[str, torch.Tensor]: 211 | try: 212 | if i in self.cached_data_dict: 213 | return self.cached_data_dict[i] 214 | ret = self.preprocess_strategy([self.dataset[i]["conversations"]], self.tokenizer, self.max_len) 215 | if "image" in self.dataset[i]: # test whether the image is valid 216 | img = load_img(self.dataset[i]["image"]) 217 | ret = dict( 218 | input_ids=ret["input_ids"][0], 219 | labels=ret["labels"][0], 220 | attention_mask=ret["attention_mask"][0], 221 | ) 222 | self.cached_data_dict[i] = ret 223 | 224 | return ret 225 | except Exception as e: 226 | print('get sample error:', str(e), self.dataset[i]) 227 | return self.__getitem__((i+1) % len(self.dataset)) 228 | 229 | 230 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, max_len,) -> Dict: 231 | """Make dataset and collator for supervised fine-tuning.""" 232 | dataset_cls = LazySupervisedDataset 233 | rank0_print("Loading data...") 234 | 235 | train_ = GUIOdysseyDataset(data_args) 236 | 237 | preprocess_strategy = preprocess 238 | 239 | train_dataset = dataset_cls(train_, tokenizer=tokenizer, max_len=max_len, preprocess_strategy=preprocess_strategy) 240 | eval_dataset = None 241 | 242 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 243 | 244 | 245 | def train(): 246 | set_seed(201830168) 247 | global local_rank 248 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments)) 249 | ( 250 | model_args, 251 | data_args, 252 | training_args, 253 | lora_args, 254 | ) = parser.parse_args_into_dataclasses() 255 | if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False): 256 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED 257 | 258 | 259 | local_rank = training_args.local_rank 260 | 261 | device_map = None 262 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 263 | ddp = world_size != 1 264 | if lora_args.q_lora: 265 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None 266 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): 267 | logging.warning("FSDP or ZeRO3 are not incompatible with QLoRA.") 268 | print(world_size) 269 | # Set RoPE scaling factor 270 | config = transformers.AutoConfig.from_pretrained( 271 | model_args.model_name_or_path, 272 | cache_dir=training_args.cache_dir, 273 | trust_remote_code=True, 274 | ) 275 | config.use_cache = False 276 | 277 | # Load model and tokenizer 278 | model = transformers.AutoModelForCausalLM.from_pretrained( 279 | model_args.model_name_or_path, 280 | config=config, 281 | cache_dir=training_args.cache_dir, 282 | device_map=device_map, 283 | trust_remote_code=True, 284 | quantization_config=GPTQConfig(bits=4, disable_exllama=True) if training_args.use_lora and lora_args.q_lora else None, 285 | ) 286 | if not training_args.use_lora: 287 | model.transformer.requires_grad_(True) 288 | if hasattr(model,'transformer') and hasattr(model.transformer,'visual'): 289 | model.transformer.visual.requires_grad_(False) 290 | if hasattr(model.transformer.visual,'attn_pool'): 291 | model.transformer.visual.attn_pool.requires_grad_(True) 292 | print_params(model) 293 | 294 | tokenizer = transformers.AutoTokenizer.from_pretrained( 295 | model_args.model_name_or_path, 296 | cache_dir=training_args.cache_dir, 297 | model_max_length=training_args.model_max_length, 298 | padding_side="right", 299 | use_fast=False, 300 | trust_remote_code=True, 301 | ) 302 | tokenizer.pad_token_id 
= tokenizer.eod_id 303 | 304 | if training_args.use_lora: 305 | if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower(): 306 | modules_to_save = None 307 | else: 308 | modules_to_save = ["wte", "lm_head"] 309 | lora_config = LoraConfig( 310 | r=lora_args.lora_r, 311 | lora_alpha=lora_args.lora_alpha, 312 | target_modules=lora_args.lora_target_modules, 313 | lora_dropout=lora_args.lora_dropout, 314 | bias=lora_args.lora_bias, 315 | task_type="CAUSAL_LM", 316 | modules_to_save=modules_to_save # This argument serves for adding new tokens. 317 | ) 318 | if lora_args.q_lora: 319 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) 320 | 321 | model = get_peft_model(model, lora_config) 322 | 323 | if training_args.gradient_checkpointing: 324 | model.enable_input_require_grads() 325 | 326 | rank0_print(training_args) 327 | # Load data 328 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length) 329 | train_module = {"train_dataset": data_module["train_dataset"], "eval_dataset": data_module["eval_dataset"]} 330 | rank0_print("training total: {}".format(len(data_module["train_dataset"]))) 331 | 332 | # Start trainner 333 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **train_module) 334 | resume_from_checkpoint = os.path.exists(training_args.output_dir) and len(os.listdir(training_args.output_dir)) != 0 335 | if resume_from_checkpoint: 336 | rank0_print('load from checkpoint.') 337 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 338 | trainer.save_state() 339 | 340 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias) 341 | 342 | 343 | if __name__ == "__main__": 344 | train() 345 | -------------------------------------------------------------------------------- /src/finetune/ds_config_zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 1, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "allgather_partitions": true, 39 | "allgather_bucket_size": 2e8, 40 | "overlap_comm": true, 41 | "reduce_scatter": true, 42 | "reduce_bucket_size": 2e8, 43 | "contiguous_gradients": true 44 | }, 45 | 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 100, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /src/finetune/ds_config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 
"bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 2, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "allgather_partitions": true, 39 | "allgather_bucket_size": 2e8, 40 | "overlap_comm": true, 41 | "reduce_scatter": true, 42 | "reduce_bucket_size": 2e8, 43 | "contiguous_gradients": true 44 | }, 45 | 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 100, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /src/finetune/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 3, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "offload_param": { 39 | "device": "none", 40 | "pin_memory": true 41 | }, 42 | "overlap_comm": true, 43 | "contiguous_gradients": true, 44 | "sub_group_size": 1e9, 45 | "reduce_bucket_size": "auto", 46 | "stage3_prefetch_bucket_size": "auto", 47 | "stage3_param_persistence_threshold": "auto", 48 | "stage3_max_live_parameters": 1e9, 49 | "stage3_max_reuse_distance": 1e9, 50 | "stage3_gather_16bit_weights_on_model_save": true 51 | }, 52 | 53 | "gradient_accumulation_steps": "auto", 54 | "gradient_clipping": "auto", 55 | "steps_per_print": 100, 56 | "train_batch_size": "auto", 57 | "train_micro_batch_size_per_gpu": "auto", 58 | "wall_clock_breakdown": false 59 | } 60 | -------------------------------------------------------------------------------- /src/finetune/finetune_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | GPUS_PER_NODE=8 6 | NNODES=1 7 | NODE_RANK=0 8 | MASTER_ADDR=localhost 9 | MASTER_PORT=6001 10 | 11 | MODEL="Qwen/Qwen-VL-Chat" #"Qwen/Qwen-VL-Chat"/"Qwen/Qwen-VL" # Set the path if you do not want to load from huggingface directly 12 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 13 | # See the section for finetuning in README for more information. 
14 | DATA="path_to_data" 15 | 16 | DISTRIBUTED_ARGS=" 17 | --nproc_per_node $GPUS_PER_NODE \ 18 | --nnodes $NNODES \ 19 | --node_rank $NODE_RANK \ 20 | --master_addr $MASTER_ADDR \ 21 | --master_port $MASTER_PORT 22 | " 23 | 24 | torchrun $DISTRIBUTED_ARGS finetune.py \ 25 | --model_name_or_path $MODEL \ 26 | --data_path $DATA \ 27 | --bf16 True \ 28 | --fix_vit True \ 29 | --output_dir output_qwen \ 30 | --num_train_epochs 5 \ 31 | --per_device_train_batch_size 1 \ 32 | --per_device_eval_batch_size 1 \ 33 | --gradient_accumulation_steps 16 \ 34 | --evaluation_strategy "no" \ 35 | --save_strategy "steps" \ 36 | --save_steps 1000 \ 37 | --save_total_limit 10 \ 38 | --learning_rate 1e-5 \ 39 | --weight_decay 0.1 \ 40 | --adam_beta2 0.95 \ 41 | --warmup_ratio 0.01 \ 42 | --lr_scheduler_type "cosine" \ 43 | --logging_steps 1 \ 44 | --report_to "none" \ 45 | --model_max_length 2048 \ 46 | --gradient_checkpointing True \ 47 | --lazy_preprocess True \ 48 | --deepspeed finetune/ds_config_zero3.json 49 | -------------------------------------------------------------------------------- /src/finetune/finetune_lora_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | GPUS_PER_NODE=8 6 | NNODES=1 7 | NODE_RANK=0 8 | MASTER_ADDR=localhost 9 | MASTER_PORT=6001 10 | 11 | MODEL="Qwen/Qwen-VL-Chat" #"Qwen/Qwen-VL-Chat"/"Qwen/Qwen-VL" Set the path if you do not want to load from huggingface directly 12 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 13 | # See the section for finetuning in README for more information. 14 | DATA="path_to_data" 15 | 16 | DISTRIBUTED_ARGS=" 17 | --nproc_per_node $GPUS_PER_NODE \ 18 | --nnodes $NNODES \ 19 | --node_rank $NODE_RANK \ 20 | --master_addr $MASTER_ADDR \ 21 | --master_port $MASTER_PORT 22 | " 23 | 24 | torchrun $DISTRIBUTED_ARGS finetune.py \ 25 | --model_name_or_path $MODEL \ 26 | --data_path $DATA \ 27 | --bf16 True \ 28 | --fix_vit True \ 29 | --output_dir output_qwen \ 30 | --num_train_epochs 5 \ 31 | --per_device_train_batch_size 2 \ 32 | --per_device_eval_batch_size 1 \ 33 | --gradient_accumulation_steps 8 \ 34 | --evaluation_strategy "no" \ 35 | --save_strategy "steps" \ 36 | --save_steps 1000 \ 37 | --save_total_limit 10 \ 38 | --learning_rate 1e-5 \ 39 | --weight_decay 0.1 \ 40 | --adam_beta2 0.95 \ 41 | --warmup_ratio 0.01 \ 42 | --lr_scheduler_type "cosine" \ 43 | --logging_steps 1 \ 44 | --report_to "none" \ 45 | --model_max_length 2048 \ 46 | --lazy_preprocess True \ 47 | --use_lora \ 48 | --gradient_checkpointing \ 49 | --deepspeed finetune/ds_config_zero2.json -------------------------------------------------------------------------------- /src/finetune/finetune_lora_single_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | 6 | MODEL="Qwen/Qwen-VL-Chat" #"Qwen/Qwen-VL-Chat"/"Qwen/Qwen-VL" # Set the path if you do not want to load from huggingface directly 7 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 8 | # See the section for finetuning in README for more information. 
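# For this repo, the chat-format annotations emitted by data/format_converter.py fit the
# description above; one entry of that list looks roughly like the sketch below
# (illustrative only -- the id, paths, and action value are made up):
#   {
#     "id": "GUIOdyssey_random_split_0",
#     "image": "/abs/path/to/screenshots/some_episode_step.png",
#     "conversations": [
#       {"from": "user", "value": "Picture 1: ... <instruction and previous actions>"},
#       {"from": "assistant", "value": "CLICK: (512, 331)"}
#     ],
#     "history": "[...]"
#   }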
9 | DATA="path_to_data" 10 | 11 | export CUDA_VISIBLE_DEVICES=0 12 | 13 | python finetune.py \ 14 | --model_name_or_path $MODEL \ 15 | --data_path $DATA \ 16 | --bf16 True \ 17 | --fix_vit True \ 18 | --output_dir output_qwen \ 19 | --num_train_epochs 5 \ 20 | --per_device_train_batch_size 1 \ 21 | --per_device_eval_batch_size 1 \ 22 | --gradient_accumulation_steps 8 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 1000 \ 26 | --save_total_limit 10 \ 27 | --learning_rate 1e-5 \ 28 | --weight_decay 0.1 \ 29 | --adam_beta2 0.95 \ 30 | --warmup_ratio 0.01 \ 31 | --lr_scheduler_type "cosine" \ 32 | --logging_steps 1 \ 33 | --report_to "none" \ 34 | --model_max_length 2048 \ 35 | --lazy_preprocess True \ 36 | --gradient_checkpointing \ 37 | --use_lora -------------------------------------------------------------------------------- /src/finetune/finetune_qlora_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | GPUS_PER_NODE=8 6 | NNODES=1 7 | NODE_RANK=0 8 | MASTER_ADDR=localhost 9 | MASTER_PORT=6001 10 | 11 | MODEL="Qwen/Qwen-VL-Chat-Int4" # Qwen/Qwen-VL-Chat-Int4 Set the path if you do not want to load from huggingface directly 12 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 13 | # See the section for finetuning in README for more information. 14 | DATA="path_to_data" 15 | 16 | 17 | DISTRIBUTED_ARGS=" 18 | --nproc_per_node $GPUS_PER_NODE \ 19 | --nnodes $NNODES \ 20 | --node_rank $NODE_RANK \ 21 | --master_addr $MASTER_ADDR \ 22 | --master_port $MASTER_PORT 23 | " 24 | 25 | # Remember to use --fp16 instead of --bf16 due to autogptq 26 | torchrun $DISTRIBUTED_ARGS finetune.py \ 27 | --model_name_or_path $MODEL \ 28 | --data_path $DATA \ 29 | --fp16 True \ 30 | --fix_vit True \ 31 | --output_dir output_qwen \ 32 | --num_train_epochs 5 \ 33 | --per_device_train_batch_size 2 \ 34 | --per_device_eval_batch_size 1 \ 35 | --gradient_accumulation_steps 8 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 1000 \ 39 | --save_total_limit 10 \ 40 | --learning_rate 1e-5 \ 41 | --weight_decay 0.1 \ 42 | --adam_beta2 0.95 \ 43 | --warmup_ratio 0.01 \ 44 | --lr_scheduler_type "cosine" \ 45 | --logging_steps 1 \ 46 | --report_to "none" \ 47 | --model_max_length 2048 \ 48 | --lazy_preprocess True \ 49 | --use_lora \ 50 | --q_lora \ 51 | --gradient_checkpointing \ 52 | --deepspeed finetune/ds_config_zero2.json -------------------------------------------------------------------------------- /src/finetune/finetune_qlora_single_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | MODEL="Qwen/Qwen-VL-Chat-Int4" # Qwen/Qwen-VL-Chat-Int4 Set the path if you do not want to load from huggingface directly 6 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 7 | # See the section for finetuning in README for more information. 
8 | DATA="path_to_data" 9 | 10 | export CUDA_VISIBLE_DEVICES=0 11 | 12 | # Remember to use --fp16 instead of --bf16 due to autogptq 13 | python finetune.py \ 14 | --model_name_or_path $MODEL \ 15 | --data_path $DATA \ 16 | --fp16 True \ 17 | --fix_vit True \ 18 | --output_dir output_qwen \ 19 | --num_train_epochs 5 \ 20 | --per_device_train_batch_size 1 \ 21 | --per_device_eval_batch_size 1 \ 22 | --gradient_accumulation_steps 8 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 1000 \ 26 | --save_total_limit 10 \ 27 | --learning_rate 1e-5 \ 28 | --weight_decay 0.1 \ 29 | --adam_beta2 0.95 \ 30 | --warmup_ratio 0.01 \ 31 | --lr_scheduler_type "cosine" \ 32 | --logging_steps 1 \ 33 | --report_to "none" \ 34 | --model_max_length 2048 \ 35 | --lazy_preprocess True \ 36 | --gradient_checkpointing \ 37 | --use_lora \ 38 | --q_lora \ 39 | --deepspeed finetune/ds_config_zero2.json 40 | -------------------------------------------------------------------------------- /src/merge_weight.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from transformers.generation import GenerationConfig 3 | from transformers import modeling_utils 4 | import torch 5 | 6 | QWENVL_PATH = 'Qwen/Qwen-VL-Chat' 7 | 8 | import sys 9 | sys.path.append('../OdysseyAgent') 10 | from configuration_qwen import QWenConfig 11 | from modeling_qwen import QWenLMHeadModel 12 | 13 | torch.manual_seed(1234) 14 | import json, random 15 | import time 16 | 17 | device = 'cpu' 18 | 19 | def load_qwen(model_name=QWENVL_PATH): 20 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map=None, trust_remote_code=True) 21 | return model 22 | 23 | def load_qwen_tokenizer(model_name=QWENVL_PATH): 24 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 25 | return tokenizer 26 | 27 | 28 | def load_new_qwen(bp='../OdysseyAgent/config.json'): 29 | cfg = QWenConfig(**json.load(open(bp))) 30 | model = QWenLMHeadModel(cfg) 31 | return model 32 | 33 | def merge_weight(qwen, odysseyAgent): 34 | qwen_dict = qwen.state_dict() 35 | odysseyAgent_dict = odysseyAgent.state_dict() 36 | for k in qwen_dict.keys(): 37 | if k in odysseyAgent_dict: 38 | odysseyAgent_dict[k] = qwen_dict[k] 39 | odysseyAgent.load_state_dict(odysseyAgent_dict) 40 | return odysseyAgent 41 | 42 | 43 | def copy_QwenVL(bp='../OdysseyAgent'): 44 | tokenizer = load_qwen_tokenizer() 45 | tokenizer.save_pretrained(bp) 46 | qwen_model = load_qwen() 47 | new_qwen_model = load_new_qwen() 48 | print('start merging weight...') 49 | new_model = merge_weight(qwen_model, new_qwen_model) 50 | print('saving...') 51 | new_model.save_pretrained(bp) 52 | 53 | if __name__ == '__main__': 54 | copy_QwenVL() 55 | -------------------------------------------------------------------------------- /src/qwen_generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | """Generation support.""" 7 | 8 | from typing import Tuple, List, Union, Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from transformers import PreTrainedTokenizer 14 | from transformers import logging 15 | from transformers.generation import LogitsProcessor 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | # Types. 20 | HistoryType = List[Tuple[str, str]] 21 | TokensType = List[int] 22 | BatchTokensType = List[List[int]] 23 | 24 | 25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: 26 | for tokens in batch: 27 | context_length = len(tokens) 28 | if context_length < seq_length: 29 | tokens.extend([pad_id] * (seq_length - context_length)) 30 | return batch 31 | 32 | 33 | def get_ltor_masks_and_position_ids( 34 | data, 35 | eod_token, 36 | reset_position_ids, 37 | reset_attention_mask, 38 | eod_mask_loss, 39 | ): 40 | """Build masks and position id for left to right model.""" 41 | 42 | # Extract batch size and sequence length. 43 | micro_batch_size, seq_length = data.size() 44 | 45 | # Attention mask (lower triangular). 46 | if reset_attention_mask: 47 | att_mask_batch = micro_batch_size 48 | else: 49 | att_mask_batch = 1 50 | attention_mask = torch.tril( 51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) 52 | ).view(att_mask_batch, 1, seq_length, seq_length) 53 | 54 | # Loss mask. 55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 56 | if eod_mask_loss: 57 | loss_mask[data == eod_token] = 0.0 58 | 59 | # Position ids. 60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) 61 | position_ids = position_ids.unsqueeze(0).expand_as(data) 62 | # We need to clone as the ids will be modifed based on batch index. 63 | if reset_position_ids: 64 | position_ids = position_ids.clone() 65 | 66 | if reset_position_ids or reset_attention_mask: 67 | # Loop through the batches: 68 | for b in range(micro_batch_size): 69 | 70 | # Find indecies where EOD token is. 71 | eod_index = position_ids[b, data[b] == eod_token] 72 | # Detach indecies from positions if going to modify positions. 73 | if reset_position_ids: 74 | eod_index = eod_index.clone() 75 | 76 | # Loop through EOD indecies: 77 | prev_index = 0 78 | for j in range(eod_index.size()[0]): 79 | i = eod_index[j] 80 | # Mask attention loss. 81 | if reset_attention_mask: 82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 83 | # Reset positions. 84 | if reset_position_ids: 85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index 86 | prev_index = i + 1 87 | 88 | # Convert attention mask to binary: 89 | attention_mask = attention_mask < 0.5 90 | 91 | return attention_mask, loss_mask, position_ids 92 | 93 | 94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int): 95 | """Generate batch from context tokens.""" 96 | # Move to GPU. 97 | tokens = context_tokens.contiguous().to(context_tokens.device) 98 | # Get the attention mask and postition ids. 
99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 100 | tokens, 101 | eod_id, 102 | reset_position_ids=False, 103 | reset_attention_mask=False, 104 | eod_mask_loss=False, 105 | ) 106 | return tokens, attention_mask, position_ids 107 | 108 | 109 | def get_stop_words_ids(chat_format, tokenizer): 110 | if chat_format == "raw": 111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 112 | elif chat_format == "chatml": 113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 114 | else: 115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 116 | return stop_words_ids 117 | 118 | 119 | def make_context( 120 | tokenizer: PreTrainedTokenizer, 121 | query: str, 122 | history: List[Tuple[str, str]] = None, 123 | system: str = "", 124 | max_window_size: int = 6144, 125 | chat_format: str = "chatml", 126 | ): 127 | if history is None: 128 | history = [] 129 | 130 | if chat_format == "chatml": 131 | im_start, im_end = "<|im_start|>", "<|im_end|>" 132 | im_start_tokens = [tokenizer.im_start_id] 133 | im_end_tokens = [tokenizer.im_end_id] 134 | nl_tokens = tokenizer.encode("\n") 135 | 136 | def _tokenize_str(role, content): 137 | return f"{role}\n{content}", tokenizer.encode( 138 | role, allowed_special=set() 139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) 140 | 141 | system_text, system_tokens_part = _tokenize_str("system", system) 142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 143 | 144 | raw_text = "" 145 | context_tokens = [] 146 | 147 | for turn_query, turn_response in reversed(history): 148 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 150 | response_text, response_tokens_part = _tokenize_str( 151 | "assistant", turn_response 152 | ) 153 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 154 | 155 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 156 | prev_chat = ( 157 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 158 | ) 159 | 160 | current_context_size = ( 161 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 162 | ) 163 | if current_context_size < max_window_size: 164 | context_tokens = next_context_tokens + context_tokens 165 | raw_text = prev_chat + raw_text 166 | else: 167 | break 168 | 169 | context_tokens = system_tokens + context_tokens 170 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 171 | context_tokens += ( 172 | nl_tokens 173 | + im_start_tokens 174 | + _tokenize_str("user", query)[1] 175 | + im_end_tokens 176 | + nl_tokens 177 | + im_start_tokens 178 | + tokenizer.encode("assistant") 179 | + nl_tokens 180 | ) 181 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 182 | 183 | elif chat_format == "raw": 184 | raw_text = query 185 | context_tokens = tokenizer.encode(raw_text) 186 | else: 187 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 188 | 189 | return raw_text, context_tokens 190 | 191 | 192 | def _decode_default( 193 | tokens: List[int], 194 | *, 195 | stop_words: List[str], 196 | eod_words: List[str], 197 | tokenizer: PreTrainedTokenizer, 198 | raw_text_len: int, 199 | verbose: bool = False, 200 | return_end_reason: bool = False, 201 | errors: str='replace', 202 | ): 203 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] 204 | if verbose: 205 | print("\nRaw Generate: ", 
trim_decode_tokens) 206 | 207 | end_reason = f"Gen length {len(tokens)}" 208 | for stop_word in stop_words: 209 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 210 | for eod_word in eod_words: 211 | if eod_word in trim_decode_tokens: 212 | end_reason = f"Gen {eod_word!r}" 213 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] 214 | trim_decode_tokens = trim_decode_tokens.strip() 215 | if verbose: 216 | print("\nEnd Reason:", end_reason) 217 | print("\nGenerate: ", trim_decode_tokens) 218 | 219 | if return_end_reason: 220 | return trim_decode_tokens, end_reason 221 | else: 222 | return trim_decode_tokens 223 | 224 | 225 | def _decode_chatml( 226 | tokens: List[int], 227 | *, 228 | stop_words: List[str], 229 | eod_token_ids: List[int], 230 | tokenizer: PreTrainedTokenizer, 231 | raw_text_len: int, 232 | context_length: int, 233 | verbose: bool = False, 234 | return_end_reason: bool = False, 235 | errors: str='replace' 236 | ): 237 | end_reason = f"Gen length {len(tokens)}" 238 | eod_token_idx = context_length 239 | for eod_token_idx in range(context_length, len(tokens)): 240 | if tokens[eod_token_idx] in eod_token_ids: 241 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" 242 | break 243 | 244 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] 245 | if verbose: 246 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) 247 | print("\nRaw Generate:", trim_decode_tokens) 248 | print("\nEnd Reason:", end_reason) 249 | for stop_word in stop_words: 250 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 251 | trim_decode_tokens = trim_decode_tokens.strip() 252 | if verbose: 253 | print("\nGenerate:", trim_decode_tokens) 254 | 255 | if return_end_reason: 256 | return trim_decode_tokens, end_reason 257 | else: 258 | return trim_decode_tokens 259 | 260 | 261 | def decode_tokens( 262 | tokens: Union[torch.LongTensor, TokensType], 263 | tokenizer: PreTrainedTokenizer, 264 | raw_text_len: int, 265 | context_length: int, 266 | chat_format: str, 267 | verbose: bool = False, 268 | return_end_reason: bool = False, 269 | errors: str="replace", 270 | ) -> str: 271 | if torch.is_tensor(tokens): 272 | tokens = tokens.cpu().numpy().tolist() 273 | 274 | if chat_format == "chatml": 275 | return _decode_chatml( 276 | tokens, 277 | stop_words=[], 278 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], 279 | tokenizer=tokenizer, 280 | raw_text_len=raw_text_len, 281 | context_length=context_length, 282 | verbose=verbose, 283 | return_end_reason=return_end_reason, 284 | errors=errors, 285 | ) 286 | elif chat_format == "raw": 287 | return _decode_default( 288 | tokens, 289 | stop_words=["<|endoftext|>"], 290 | eod_words=["<|endoftext|>"], 291 | tokenizer=tokenizer, 292 | raw_text_len=raw_text_len, 293 | verbose=verbose, 294 | return_end_reason=return_end_reason, 295 | errors=errors, 296 | ) 297 | else: 298 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 299 | 300 | 301 | class StopWordsLogitsProcessor(LogitsProcessor): 302 | """ 303 | :class:`transformers.LogitsProcessor` that forces generation to stop once any of the specified stop-word sequences has been generated. 304 | Args: 305 | stop_words_ids (:obj:`List[List[int]]`): 306 | List of token-id lists, one per stop word. In order to get the token ids of the words 307 | that should stop generation, use :obj:`tokenizer(stop_word, 308 | add_prefix_space=True).input_ids`. 
309 | eos_token_id (:obj:`int`): 310 | The id of the `end-of-sequence` token. 311 | """ 312 | 313 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): 314 | 315 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: 316 | raise ValueError( 317 | f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}." 318 | ) 319 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): 320 | raise ValueError( 321 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." 322 | ) 323 | if any( 324 | any( 325 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) 326 | for token_id in stop_word_ids 327 | ) 328 | for stop_word_ids in stop_words_ids 329 | ): 330 | raise ValueError( 331 | f"Each list in `stop_words_ids` has to be a list of non-negative integers, but is {stop_words_ids}." 332 | ) 333 | 334 | self.stop_words_ids = list( 335 | filter( 336 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids 337 | ) 338 | ) 339 | self.eos_token_id = eos_token_id 340 | for stop_token_seq in self.stop_words_ids: 341 | assert ( 342 | len(stop_token_seq) > 0 343 | ), "Stop words token sequences {} cannot contain an empty list".format( 344 | stop_words_ids 345 | ) 346 | 347 | def __call__( 348 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 349 | ) -> torch.FloatTensor: 350 | stopped_samples = self._calc_stopped_samples(input_ids) 351 | for i, should_stop in enumerate(stopped_samples): 352 | if should_stop: 353 | scores[i, self.eos_token_id] = float(2**15) 354 | return scores 355 | 356 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 357 | if len(tokens) == 0: 358 | # An empty stop sequence matches any prefix. 359 | return True 360 | elif len(tokens) > len(prev_tokens): 361 | # A stop sequence longer than the generated prefix cannot match. 362 | return False 363 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 364 | # The last generated tokens equal the stop sequence. 365 | return True 366 | else: 367 | return False 368 | 369 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 370 | stopped_samples = [] 371 | for prev_input_ids_slice in prev_input_ids: 372 | match = False 373 | for stop_token_seq in self.stop_words_ids: 374 | if self._tokens_match(prev_input_ids_slice, stop_token_seq): 375 | # This sample has produced a stop sequence. 376 | match = True 377 | break 378 | stopped_samples.append(match) 379 | 380 | return stopped_samples 381 | 382 | 383 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 384 | """This function has been mostly taken from huggingface conversational 385 | ai code at 386 | https://medium.com/huggingface/how-to-build-a-state-of-the-art- 387 | conversational-ai-with-transfer-learning-2d818ac26313""" 388 | 389 | if top_k > 0: 390 | # Remove all tokens with a probability less than the 391 | # last token of the top-k 392 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 393 | logits[indices_to_remove] = filter_value 394 | 395 | if top_p > 0.0: 396 | # Sort logits in descending order (nucleus / top-p filtering). 397 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 398 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 399 | 400 | # Remove tokens with cumulative probability above the threshold 401 | sorted_indices_to_remove = cumulative_probs > top_p 402 | # Shift the indices to the right to keep also the first token 403 | # above the threshold 404 | 
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 405 | sorted_indices_to_remove[..., 0] = 0 406 | for i in range(sorted_indices.size(0)): 407 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] 408 | logits[i][indices_to_remove] = filter_value 409 | 410 | return logits 411 | 412 | 413 | def switch(val1, val2, boolean): 414 | boolean = boolean.type_as(val1) 415 | return (1 - boolean) * val1 + boolean * val2 416 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==2.0.1 3 | torchvision==0.15.2 4 | torchaudio==2.0.2 5 | transformers==4.32.0 6 | accelerate 7 | tiktoken 8 | einops 9 | transformers_stream_generator==0.0.4 10 | scipy 11 | pillow 12 | tensorboard 13 | matplotlib 14 | peft 15 | deepspeed 16 | -------------------------------------------------------------------------------- /src/script/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | checkpoint=/path/to/checkpoint 3 | ds=app_split # one of "app_split", "device_split", "random_split", "task_split" 4 | DIR=`pwd` 5 | 6 | exp_name=OdysseyAgent_$ds 7 | mkdir -p output/"$exp_name" 8 | 9 | GPUS_PER_NODE=8 10 | NNODES=1 11 | NODE_RANK=0 12 | MASTER_ADDR=localhost 13 | MASTER_PORT=$((RANDOM % 30001 + 20000)) 14 | GPUS=$((GPUS_PER_NODE * NNODES)) 15 | 16 | DISTRIBUTED_ARGS=" 17 | --nproc_per_node $GPUS_PER_NODE \ 18 | --nnodes $NNODES \ 19 | --node_rank $NODE_RANK \ 20 | --master_addr $MASTER_ADDR \ 21 | --master_port $MASTER_PORT 22 | " 23 | 24 | echo $ds 25 | echo $checkpoint 26 | torchrun $DISTRIBUTED_ARGS eval_mm/evaluate_GUIOdyssey.py \ 27 | --checkpoint $checkpoint --dataset $ds --batch-size 16 --his_len 4 -------------------------------------------------------------------------------- /src/script/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL="/path/to/model" # "../OdysseyAgent" 3 | DATA_ROOT=/path/to/chat_format_annotation # ../data/train_anno/app_split.json ../data/train_anno/task_split.json 4 | 5 | exp_name=OdysseyAgent 6 | mkdir -p output/"$exp_name" 7 | OUTPUT_DIR=output/"$exp_name" 8 | export CUDA_DEVICE_MAX_CONNECTIONS=1 9 | DIR=`pwd` 10 | 11 | GPUS_PER_NODE=8 12 | NNODES=1 13 | NODE_RANK=0 14 | MASTER_ADDR=localhost 15 | MASTER_PORT=$((RANDOM % 10001 + 40000)) 16 | GPUS=$((GPUS_PER_NODE * NNODES)) 17 | 18 | DISTRIBUTED_ARGS=" 19 | --nproc_per_node $GPUS_PER_NODE \ 20 | --nnodes $NNODES \ 21 | --node_rank $NODE_RANK \ 22 | --master_addr $MASTER_ADDR \ 23 | --master_port $MASTER_PORT 24 | " 25 | torchrun $DISTRIBUTED_ARGS finetune.py \ 26 | --model_name_or_path $MODEL \ 27 | --dataset $DATA_ROOT \ 28 | --fp16 True \ 29 | --fix_vit True \ 30 | --output_dir $OUTPUT_DIR \ 31 | --num_train_epochs 1 \ 32 | --per_device_train_batch_size 2 \ 33 | --per_device_eval_batch_size 1 \ 34 | --gradient_accumulation_steps 8 \ 35 | --evaluation_strategy "no" \ 36 | --save_strategy "steps" \ 37 | --save_steps 5000 \ 38 | --save_total_limit 300 \ 39 | --learning_rate 2e-5 \ 40 | --weight_decay 0.1 \ 41 | --adam_beta2 0.95 \ 42 | --warmup_ratio 0.01 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --report_to "none" \ 46 | --model_max_length 800 \ 47 | --gradient_checkpointing True \ 48 | --deepspeed finetune/ds_config_zero2.json 49 | 50 | 
--------------------------------------------------------------------------------
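
Usage note: the snippet below is a minimal sketch of how the helpers in qwen_generation_utils.py (make_context, get_stop_words_ids, StopWordsLogitsProcessor, decode_tokens) fit together for a single text-only ChatML query, assuming it is run next to src/qwen_generation_utils.py. The checkpoint path, query string, and generation settings are illustrative placeholders rather than values taken from the repository; the scripted evaluation path remains eval_mm/evaluate_GUIOdyssey.py as driven by script/eval.sh.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessorList

from qwen_generation_utils import (
    StopWordsLogitsProcessor,
    decode_tokens,
    get_stop_words_ids,
    make_context,
)

checkpoint = "/path/to/OdysseyAgent"  # placeholder: directory holding the released weights
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, device_map="auto", trust_remote_code=True
).eval()

# Build the ChatML prompt (system + history + current query) and its token ids.
raw_text, context_tokens = make_context(
    tokenizer,
    query="Describe the next UI action.",  # placeholder query
    history=[],
    system="You are a helpful assistant.",
    chat_format="chatml",
)

# Once a stop sequence (<|im_end|> / <|im_start|>) has been generated, the
# processor boosts the EOS logit so that generation halts.
stop_words_ids = get_stop_words_ids("chatml", tokenizer)
logits_processor = LogitsProcessorList(
    [StopWordsLogitsProcessor(stop_words_ids=stop_words_ids, eos_token_id=tokenizer.eod_id)]
)

input_ids = torch.tensor([context_tokens], device=model.device)
output_ids = model.generate(
    input_ids,
    logits_processor=logits_processor,
    max_new_tokens=128,
)

# Strip the prompt and any trailing ChatML markers from the decoded output.
response = decode_tokens(
    output_ids[0],
    tokenizer,
    raw_text_len=len(raw_text),
    context_length=len(context_tokens),
    chat_format="chatml",
)
print(response)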