├── OdysseyAgent
│   ├── config.json
│   ├── configuration_qwen.py
│   ├── generation_config.json
│   ├── modeling_qwen.py
│   ├── pytorch_model.bin.index.json
│   ├── qwen.tiktoken
│   ├── qwen_generation_utils.py
│   ├── special_tokens_map.json
│   ├── tokenization_qwen.py
│   ├── tokenizer_config.json
│   └── visual.py
├── Quickstart.md
├── README.md
├── assets
│   ├── dataset_overview.jpg
│   ├── pipeline.jpg
│   └── pipeline.png
├── data
│   ├── format_converter.py
│   └── preprocessing.py
├── introduction.md
└── src
    ├── SimSun.ttf
    ├── data_loader.py
    ├── eval_mm
    │   ├── GUIOdyssey_action_matching.py
    │   └── evaluate_GUIOdyssey.py
    ├── finetune.py
    ├── finetune
    │   ├── ds_config_zero1.json
    │   ├── ds_config_zero2.json
    │   ├── ds_config_zero3.json
    │   ├── finetune_ds.sh
    │   ├── finetune_lora_ds.sh
    │   ├── finetune_lora_single_gpu.sh
    │   ├── finetune_qlora_ds.sh
    │   └── finetune_qlora_single_gpu.sh
    ├── merge_weight.py
    ├── qwen_generation_utils.py
    ├── requirements.txt
    └── script
        ├── eval.sh
        └── train.sh
/OdysseyAgent/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "OdysseyAgent",
3 | "architectures": [
4 | "QWenLMHeadModel"
5 | ],
6 | "attn_dropout_prob": 0.0,
7 | "auto_map": {
8 | "AutoConfig": "configuration_qwen.QWenConfig",
9 | "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
10 | },
11 | "bf16": true,
12 | "emb_dropout_prob": 0.0,
13 | "fp16": false,
14 | "fp32": false,
15 | "hidden_size": 4096,
16 | "his_len": 4,
17 | "initializer_range": 0.02,
18 | "intermediate_size": 22016,
19 | "kv_channels": 128,
20 | "layer_norm_epsilon": 1e-06,
21 | "max_position_embeddings": 8192,
22 | "model_type": "qwen",
23 | "no_bias": true,
24 | "num_attention_heads": 32,
25 | "num_hidden_layers": 32,
26 | "onnx_safe": null,
27 | "rotary_emb_base": 10000,
28 | "rotary_pct": 1.0,
29 | "scale_attn_weights": true,
30 | "seq_length": 2048,
31 | "tie_word_embeddings": false,
32 | "tokenizer_type": "QWenTokenizer",
33 | "torch_dtype": "bfloat16",
34 | "transformers_version": "4.32.0",
35 | "use_cache": true,
36 | "use_dynamic_ntk": true,
37 | "use_flash_attn": false,
38 | "use_logn_attn": true,
39 | "visual": {
40 | "heads": 16,
41 | "image_size": 448,
42 | "image_start_id": 151857,
43 | "layers": 48,
44 | "mlp_ratio": 4.9231,
45 | "output_dim": 4096,
46 | "patch_size": 14,
47 | "width": 1664
48 | },
49 | "vocab_size": 151936
50 | }
51 |
--------------------------------------------------------------------------------
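
Note: the "auto_map" block in config.json is what lets the Hugging Face Auto classes resolve to the custom code shipped alongside the weights. Below is a minimal loading sketch; the checkpoint path "./OdysseyAgent" is an assumption, and trust_remote_code=True is required so that transformers executes the bundled configuration_qwen.py and modeling_qwen.py.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True resolves "auto_map" to configuration_qwen.QWenConfig
# and modeling_qwen.QWenLMHeadModel defined in this directory.
tokenizer = AutoTokenizer.from_pretrained("./OdysseyAgent", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "./OdysseyAgent",
    torch_dtype=torch.bfloat16,  # matches "bf16": true / "torch_dtype": "bfloat16" above
    trust_remote_code=True,
).eval()
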
/OdysseyAgent/configuration_qwen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from transformers import PretrainedConfig
7 |
8 |
9 | class QWenConfig(PretrainedConfig):
10 | model_type = "qwen"
11 | keys_to_ignore_at_inference = ["past_key_values"]
12 |
13 | def __init__(
14 | self,
15 | vocab_size=151936,
16 | hidden_size=4096,
17 | num_hidden_layers=32,
18 | num_attention_heads=32,
19 | emb_dropout_prob=0.0,
20 | attn_dropout_prob=0.0,
21 | layer_norm_epsilon=1e-6,
22 | initializer_range=0.02,
23 | max_position_embeddings=8192,
24 | scale_attn_weights=True,
25 | use_cache=True,
26 | bf16=False,
27 | fp16=False,
28 | fp32=False,
29 | kv_channels=128,
30 | rotary_pct=1.0,
31 | rotary_emb_base=10000,
32 | use_dynamic_ntk=True,
33 | use_logn_attn=True,
34 | use_flash_attn="auto",
35 | intermediate_size=22016,
36 | no_bias=True,
37 | tie_word_embeddings=False,
38 | **kwargs,
39 | ):
40 | self.vocab_size = vocab_size
41 | self.hidden_size = hidden_size
42 | self.intermediate_size = intermediate_size
43 | self.num_hidden_layers = num_hidden_layers
44 | self.num_attention_heads = num_attention_heads
45 | self.emb_dropout_prob = emb_dropout_prob
46 | self.attn_dropout_prob = attn_dropout_prob
47 | self.layer_norm_epsilon = layer_norm_epsilon
48 | self.initializer_range = initializer_range
49 | self.scale_attn_weights = scale_attn_weights
50 | self.use_cache = use_cache
51 | self.max_position_embeddings = max_position_embeddings
52 | self.bf16 = bf16
53 | self.fp16 = fp16
54 | self.fp32 = fp32
55 | self.kv_channels = kv_channels
56 | self.rotary_pct = rotary_pct
57 | self.rotary_emb_base = rotary_emb_base
58 | self.use_dynamic_ntk = use_dynamic_ntk
59 | self.use_logn_attn = use_logn_attn
60 | self.use_flash_attn = use_flash_attn
61 | self.no_bias = no_bias
62 | super().__init__(
63 | tie_word_embeddings=tie_word_embeddings,
64 | **kwargs
65 | )
66 |
--------------------------------------------------------------------------------
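
Note: config.json carries keys such as "his_len" and "visual" that QWenConfig.__init__ does not declare; they fall through **kwargs into PretrainedConfig, which attaches unknown keys as plain attributes. That is how modeling_qwen.py can later read config.his_len and config.visual. A small sketch, with illustrative values rather than the full dictionaries from config.json:

from configuration_qwen import QWenConfig

# Unrecognized keyword arguments become attributes on the config object.
cfg = QWenConfig(his_len=4, visual={"image_size": 448, "patch_size": 14})
print(cfg.his_len)               # 4
print(cfg.visual["image_size"])  # 448
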
/OdysseyAgent/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "transformers_version": "4.32.0"
4 | }
5 |
--------------------------------------------------------------------------------
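
Note: the shipped generation_config.json only records "_from_model_config". The chat helpers in modeling_qwen.py additionally expect attributes such as chat_format (asserted to equal "chatml"), max_window_size, and stop/eos settings on the generation config. A hedged sketch of supplying them at runtime, assuming the model and tokenizer from the earlier loading sketch; the concrete values are assumptions rather than values taken from this repository:

from transformers import GenerationConfig

# model.chat() asserts generation_config.chat_format == "chatml" and reads
# generation_config.max_window_size when building the prompt via make_context().
model.generation_config = GenerationConfig(
    chat_format="chatml",
    max_window_size=6144,            # same default that make_context() uses
    max_new_tokens=512,
    do_sample=False,
    eos_token_id=tokenizer.eod_id,   # the QWen tokenizer exposes eod_id
    pad_token_id=tokenizer.eod_id,
)
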
/OdysseyAgent/modeling_qwen.py:
--------------------------------------------------------------------------------
1 | print('OdysseyAgent')
2 |
3 | import importlib
4 | import math
5 | from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
6 | import os, json
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from torch.cuda.amp import autocast
12 |
13 | from torch.nn import CrossEntropyLoss
14 | from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
15 | from transformers.generation.logits_process import LogitsProcessorList
16 |
17 | if TYPE_CHECKING:
18 | from transformers.generation.streamers import BaseStreamer
19 | from transformers.generation.utils import GenerateOutput
20 | from transformers.modeling_outputs import (
21 | BaseModelOutputWithPast,
22 | CausalLMOutputWithPast,
23 | )
24 | from transformers.modeling_utils import PreTrainedModel
25 | from transformers.utils import logging
26 |
27 | try:
28 | from einops import rearrange
29 | except ImportError:
30 | rearrange = None
31 | from torch import nn
32 |
33 | SUPPORT_CUDA = torch.cuda.is_available()
34 | SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
35 | SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
36 |
37 | from torch.nn.init import trunc_normal_
38 | import sys
39 | sys.path.append('../OdysseyAgent')
40 |
41 | from configuration_qwen import QWenConfig
42 | from qwen_generation_utils import (
43 | HistoryType,
44 | make_context,
45 | decode_tokens,
46 | get_stop_words_ids,
47 | StopWordsLogitsProcessor,
48 | )
49 | from visual import VisionTransformer
50 |
51 | IMAGE_HISTORY = '../data/his_index.json'
52 |
53 | USE_RESAMPLER = True
54 |
55 | print(IMAGE_HISTORY)
56 | logger = logging.get_logger(__name__)
57 |
58 | _CHECKPOINT_FOR_DOC = "qwen"
59 | _CONFIG_FOR_DOC = "QWenConfig"
60 |
61 | QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
62 |
63 | _ERROR_BAD_CHAT_FORMAT = """\
64 | We detected that you are probably using the pretrained model (rather than the chat model) for chatting, since the chat_format in generation_config is not "chatml".
65 | If you are using the model downloaded directly from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
66 | 我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
67 | 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
68 | """
69 |
70 | _SENTINEL = object()
71 | _ERROR_STREAM_IN_CHAT = """\
72 | Passing the argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
73 | 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
74 | """
75 |
76 | apply_rotary_emb_func = None
77 | rms_norm = None
78 |
79 |
80 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81 | def _make_causal_mask(
82 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
83 | ):
84 | """
85 | Make causal mask used for bi-directional self-attention.
86 | """
87 | bsz, tgt_len = input_ids_shape
88 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
89 | mask_cond = torch.arange(mask.size(-1), device=device)
90 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
91 | mask = mask.to(dtype)
92 |
93 | if past_key_values_length > 0:
94 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
95 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
96 |
97 |
98 | # Copied from transformers.models.bart.modeling_bart._expand_mask
99 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
100 | """
101 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
102 | """
103 | bsz, src_len = mask.size()
104 | tgt_len = tgt_len if tgt_len is not None else src_len
105 |
106 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
107 |
108 | inverted_mask = 1.0 - expanded_mask
109 |
110 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111 |
112 | def get_abs_pos(abs_pos, tgt_size):
113 | # abs_pos: L, C
114 | # tgt_size: M
115 | # return: M, C
116 | src_size = int(math.sqrt(abs_pos.size(0)))
117 | tgt_size = int(math.sqrt(tgt_size))
118 | dtype = abs_pos.dtype
119 |
120 | if src_size != tgt_size:
121 | return F.interpolate(
122 | abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
123 | size=(tgt_size, tgt_size),
124 | mode="bicubic",
125 | align_corners=False,
126 | ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
127 | else:
128 | return abs_pos
129 |
130 |
131 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
132 | """
133 | grid_size: int of the grid height and width
134 | return:
135 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
136 | """
137 | grid_h = np.arange(grid_size, dtype=np.float32)
138 | grid_w = np.arange(grid_size, dtype=np.float32)
139 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
140 | grid = np.stack(grid, axis=0)
141 |
142 | grid = grid.reshape([2, 1, grid_size, grid_size])
143 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
144 | if cls_token:
145 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
146 | return pos_embed
147 |
148 |
149 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
150 | assert embed_dim % 2 == 0
151 |
152 | # use half of dimensions to encode grid_h
153 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
154 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
155 |
156 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
157 | return emb
158 |
159 |
160 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
161 | """
162 | embed_dim: output dimension for each position
163 | pos: a list of positions to be encoded: size (M,)
164 | out: (M, D)
165 | """
166 | assert embed_dim % 2 == 0
167 | omega = np.arange(embed_dim // 2, dtype=np.float32)
168 | omega /= embed_dim / 2.
169 | omega = 1. / 10000**omega # (D/2,)
170 |
171 | pos = pos.reshape(-1) # (M,)
172 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
173 |
174 | emb_sin = np.sin(out) # (M, D/2)
175 | emb_cos = np.cos(out) # (M, D/2)
176 |
177 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
178 | return emb
179 |
180 |
181 |
182 | class HisResampler(nn.Module):
183 | def __init__(
184 | self,
185 | embed_dim=4096,
186 | num_heads=32,
187 | grid_size=16,
188 | kv_dim=None,
189 | norm_layer=nn.LayerNorm
190 | ):
191 | super().__init__()
192 | self.num_queries = grid_size ** 2
193 | self.embed_dim = embed_dim
194 | self.num_heads = num_heads
195 |
196 | self.pos_embed = nn.Parameter(
197 | torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
198 | ).requires_grad_(False)
199 |
200 | self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
201 | trunc_normal_(self.query, std=.02)
202 |
203 | if kv_dim is not None and kv_dim != embed_dim:
204 | self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
205 | else:
206 | self.kv_proj = nn.Identity()
207 |
208 | self.attn = nn.MultiheadAttention(embed_dim, num_heads)
209 | self.ln_q = norm_layer(embed_dim)
210 | self.ln_kv = norm_layer(embed_dim)
211 |
212 | self.ln_post = norm_layer(embed_dim)
213 | self.proj = nn.Parameter((embed_dim** -0.5) * torch.randn(embed_dim, embed_dim))
214 |
215 | self.apply(self._init_weights)
216 |
217 | def _init_weights(self, m):
218 | if isinstance(m, nn.Linear):
219 | trunc_normal_(m.weight, std=.02)
220 | if isinstance(m, nn.Linear) and m.bias is not None:
221 | nn.init.constant_(m.bias, 0)
222 | elif isinstance(m, nn.LayerNorm):
223 | nn.init.constant_(m.bias, 0)
224 | nn.init.constant_(m.weight, 1.0)
225 |
226 | def forward(self, x, attn_mask=None):
227 |
228 | x = self.kv_proj(x)
229 | x = self.ln_kv(x).permute(1, 0, 2)
230 |
231 | N = x.shape[1]
232 | q = self.ln_q(self.query)
233 | out = self.attn(
234 | self._repeat(q, N),
235 | x,
236 | x,
237 | attn_mask=attn_mask)[0]
238 | out = out.permute(1, 0, 2)
239 | out = self.ln_post(out)
240 | out = out @ self.proj
241 | return out
242 |
243 | def _repeat(self, query, N: int):
244 | return query.unsqueeze(1).repeat(1, N, 1)
245 |
246 |
247 |
248 | class QWenAttention(nn.Module):
249 | def __init__(self, config):
250 | super().__init__()
251 |
252 | self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
253 | self.seq_length = config.seq_length
254 |
255 | self.hidden_size = config.hidden_size
256 | self.split_size = config.hidden_size
257 | self.num_heads = config.num_attention_heads
258 | self.head_dim = self.hidden_size // self.num_heads
259 |
260 | self.scale_attn_weights = True
261 |
262 | self.projection_size = config.kv_channels * config.num_attention_heads
263 |
264 | assert self.projection_size % config.num_attention_heads == 0
265 | self.hidden_size_per_attention_head = (
266 | self.projection_size // config.num_attention_heads
267 | )
268 |
269 | self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
270 |
271 | self.c_proj = nn.Linear(
272 | config.hidden_size, self.projection_size, bias=not config.no_bias
273 | )
274 |
275 | self.is_fp32 = not (config.bf16 or config.fp16)
276 | self.bf16 = config.bf16
277 |
278 | self.use_dynamic_ntk = config.use_dynamic_ntk
279 | self.use_logn_attn = config.use_logn_attn
280 |
281 | logn_list = [
282 | math.log(i, self.seq_length) if i > self.seq_length else 1
283 | for i in range(1, 32768)
284 | ]
285 | self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
286 |
287 | self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
288 |
289 | def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
290 | attn_weights = torch.matmul(query, key.transpose(-1, -2))
291 |
292 | if self.scale_attn_weights:
293 | attn_weights = attn_weights / torch.full(
294 | [],
295 | value.size(-1) ** 0.5,
296 | dtype=attn_weights.dtype,
297 | device=attn_weights.device,
298 | )
299 |
300 | query_length, key_length = query.size(-2), key.size(-2)
301 | # causal_mask = self.bias[
302 | # :, :, key_length - query_length : key_length, :key_length
303 | # ]
304 | # mask_value = torch.finfo(attn_weights.dtype).min
305 | # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
306 | # attn_weights.device
307 | # )
308 | # attn_weights = torch.where(
309 | # causal_mask, attn_weights.to(attn_weights.dtype), mask_value
310 | # )
311 | attn_weights = attn_weights + attention_mask
312 |
313 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
314 |
315 | attn_weights = attn_weights.type(value.dtype)
316 | attn_weights = self.attn_dropout(attn_weights)
317 |
318 | if head_mask is not None:
319 | attn_weights = attn_weights * head_mask
320 |
321 | attn_output = torch.matmul(attn_weights, value)
322 | attn_output = attn_output.transpose(1, 2)
323 |
324 | return attn_output, attn_weights
325 |
326 | def _upcast_and_reordered_attn(
327 | self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
328 | ):
329 | bsz, num_heads, q_seq_len, dk = query.size()
330 | _, _, k_seq_len, _ = key.size()
331 |
332 | attn_weights = torch.empty(
333 | bsz * num_heads,
334 | q_seq_len,
335 | k_seq_len,
336 | dtype=torch.float32,
337 | device=query.device,
338 | )
339 |
340 | scale_factor = 1.0
341 | if self.scale_attn_weights:
342 | scale_factor /= float(value.size(-1)) ** 0.5
343 |
344 | with autocast(enabled=False):
345 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
346 | -1, dk, k_seq_len
347 | )
348 | attn_weights = torch.baddbmm(
349 | attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
350 | )
351 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
352 |
353 | query_length, key_length = query.size(-2), key.size(-2)
354 | causal_mask = registered_causal_mask[
355 | :, :, key_length - query_length : key_length, :key_length
356 | ]
357 | mask_value = torch.finfo(attn_weights.dtype).min
358 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
359 | attn_weights.device
360 | )
361 | attn_weights = torch.where(causal_mask, attn_weights, mask_value)
362 |
363 | if attention_mask is not None:
364 | attn_weights = attn_weights + attention_mask
365 |
366 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
367 |
368 | if attn_weights.dtype != torch.float32:
369 | raise RuntimeError(
370 | "Error with upcasting, attn_weights does not have dtype torch.float32"
371 | )
372 | attn_weights = attn_weights.type(value.dtype)
373 | attn_weights = self.attn_dropout(attn_weights)
374 |
375 | if head_mask is not None:
376 | attn_weights = attn_weights * head_mask
377 |
378 | attn_output = torch.matmul(attn_weights, value)
379 |
380 | return attn_output, attn_weights
381 |
382 | def _split_heads(self, tensor, num_heads, attn_head_size):
383 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
384 | tensor = tensor.view(new_shape)
385 | return tensor
386 |
387 | def _merge_heads(self, tensor, num_heads, attn_head_size):
388 | tensor = tensor.contiguous()
389 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
390 | return tensor.view(new_shape)
391 |
392 | def forward(
393 | self,
394 | hidden_states: Optional[Tuple[torch.FloatTensor]],
395 | rotary_pos_emb: Optional[List[torch.Tensor]] = None,
396 | registered_causal_mask: Optional[torch.Tensor] = None,
397 | layer_past: Optional[Tuple[torch.Tensor]] = None,
398 | attention_mask: Optional[torch.FloatTensor] = None,
399 | head_mask: Optional[torch.FloatTensor] = None,
400 | encoder_hidden_states: Optional[torch.Tensor] = None,
401 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
402 | output_attentions: Optional[bool] = False,
403 | use_cache: Optional[bool] = False,
404 | ):
405 |
406 | mixed_x_layer = self.c_attn(hidden_states)
407 |
408 | query, key, value = mixed_x_layer.split(self.split_size, dim=2)
409 |
410 | query = self._split_heads(query, self.num_heads, self.head_dim)
411 | key = self._split_heads(key, self.num_heads, self.head_dim)
412 | value = self._split_heads(value, self.num_heads, self.head_dim)
413 |
414 | if rotary_pos_emb is not None:
415 | cur_len = query.shape[1]
416 | rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
417 | rotary_pos_emb = (rotary_pos_emb,) * 2
418 | q_pos_emb, k_pos_emb = rotary_pos_emb
419 | # Slice the pos emb for current inference
420 | query = apply_rotary_pos_emb(query, q_pos_emb)
421 | key = apply_rotary_pos_emb(key, k_pos_emb)
422 |
423 | if layer_past is not None:
424 | past_key, past_value = layer_past[0], layer_past[1]
425 | key = torch.cat((past_key, key), dim=1)
426 | value = torch.cat((past_value, value), dim=1)
427 |
428 | if use_cache:
429 | present = (key, value)
430 | else:
431 | present = None
432 |
433 | if self.use_logn_attn and not self.training:
434 | if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
435 | self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
436 | seq_start = key.size(1) - query.size(1)
437 | seq_end = key.size(1)
438 | logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
439 | query = query * logn_tensor.expand_as(query)
440 |
441 | query = query.permute(0, 2, 1, 3)
442 | key = key.permute(0, 2, 1, 3)
443 | value = value.permute(0, 2, 1, 3)
444 | attn_output, attn_weight = self._attn(
445 | query, key, value, registered_causal_mask, attention_mask, head_mask
446 | )
447 | context_layer = self._merge_heads(
448 | attn_output, self.num_heads, self.head_dim
449 | )
450 |
451 | attn_output = self.c_proj(context_layer)
452 |
453 | outputs = (attn_output, present)
454 | if output_attentions:
455 | outputs += (attn_weight,)
456 |
457 | return outputs
458 |
459 |
460 | class QWenMLP(nn.Module):
461 | def __init__(self, config):
462 | super().__init__()
463 | self.w1 = nn.Linear(
464 | config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
465 | )
466 | self.w2 = nn.Linear(
467 | config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
468 | )
469 | ff_dim_in = config.intermediate_size // 2
470 | self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
471 |
472 | def forward(self, hidden_states):
473 | a1 = self.w1(hidden_states)
474 | a2 = self.w2(hidden_states)
475 | intermediate_parallel = a1 * F.silu(a2)
476 | output = self.c_proj(intermediate_parallel)
477 | return output
478 |
479 | class QWenBlock(nn.Module):
480 | def __init__(self, config):
481 | super().__init__()
482 | hidden_size = config.hidden_size
483 | self.bf16 = config.bf16
484 |
485 | self.ln_1 = RMSNorm(
486 | hidden_size,
487 | eps=config.layer_norm_epsilon,
488 | )
489 | self.attn = QWenAttention(config)
490 | self.ln_2 = RMSNorm(
491 | hidden_size,
492 | eps=config.layer_norm_epsilon,
493 | )
494 |
495 | self.mlp = QWenMLP(config)
496 |
497 | def forward(
498 | self,
499 | hidden_states: Optional[Tuple[torch.FloatTensor]],
500 | rotary_pos_emb: Optional[List[torch.Tensor]] = None,
501 | registered_causal_mask: Optional[torch.Tensor] = None,
502 | layer_past: Optional[Tuple[torch.Tensor]] = None,
503 | attention_mask: Optional[torch.FloatTensor] = None,
504 | head_mask: Optional[torch.FloatTensor] = None,
505 | encoder_hidden_states: Optional[torch.Tensor] = None,
506 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
507 | use_cache: Optional[bool] = False,
508 | output_attentions: Optional[bool] = False,
509 | ):
510 | layernorm_output = self.ln_1(hidden_states)
511 |
512 | attn_outputs = self.attn(
513 | layernorm_output,
514 | rotary_pos_emb,
515 | registered_causal_mask=registered_causal_mask,
516 | layer_past=layer_past,
517 | attention_mask=attention_mask,
518 | head_mask=head_mask,
519 | use_cache=use_cache,
520 | output_attentions=output_attentions,
521 | )
522 | attn_output = attn_outputs[0]
523 |
524 | outputs = attn_outputs[1:]
525 |
526 | residual = hidden_states
527 | layernorm_input = attn_output + residual
528 |
529 | layernorm_output = self.ln_2(layernorm_input)
530 |
531 | residual = layernorm_input
532 | mlp_output = self.mlp(layernorm_output)
533 | hidden_states = residual + mlp_output
534 |
535 | if use_cache:
536 | outputs = (hidden_states,) + outputs
537 | else:
538 | outputs = (hidden_states,) + outputs[1:]
539 |
540 | return outputs
541 |
542 |
543 | class QWenPreTrainedModel(PreTrainedModel):
544 | config_class = QWenConfig
545 | base_model_prefix = "transformer"
546 | is_parallelizable = False
547 | supports_gradient_checkpointing = True
548 | _no_split_modules = ["QWenBlock"]
549 |
550 | def __init__(self, *inputs, **kwargs):
551 | super().__init__(*inputs, **kwargs)
552 |
553 | def _init_weights(self, module):
554 | """Initialize the weights."""
555 | if isinstance(module, nn.Linear):
556 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
557 | if module.bias is not None:
558 | module.bias.data.zero_()
559 | elif isinstance(module, nn.Embedding):
560 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
561 | if module.padding_idx is not None:
562 | module.weight.data[module.padding_idx].zero_()
563 | elif isinstance(module, RMSNorm):
564 | module.weight.data.fill_(1.0)
565 |
566 | for name, p in module.named_parameters():
567 | if name == "c_proj.weight":
568 | p.data.normal_(
569 | mean=0.0,
570 | std=(
571 | self.config.initializer_range
572 | / math.sqrt(2 * self.config.num_hidden_layers)
573 | ),
574 | )
575 |
576 | def _set_gradient_checkpointing(self, module, value=False):
577 | if isinstance(module, QWenModel):
578 | module.gradient_checkpointing = value
579 |
580 |
581 | class QWenModel(QWenPreTrainedModel):
582 | _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
583 |
584 | def __init__(self, config):
585 | super().__init__(config)
586 | self.his_len = config.his_len
587 | self.vocab_size = config.vocab_size
588 | self.num_hidden_layers = config.num_hidden_layers
589 | self.embed_dim = config.hidden_size
590 |
591 | self.gradient_checkpointing = False
592 | self.use_dynamic_ntk = config.use_dynamic_ntk
593 | self.seq_length = config.seq_length
594 |
595 | self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
596 |
597 | self.drop = nn.Dropout(config.emb_dropout_prob)
598 |
599 | if config.rotary_pct == 1.0:
600 | self.rotary_ndims = None
601 | else:
602 | assert config.rotary_pct < 1
603 | self.rotary_ndims = int(
604 | config.kv_channels * config.rotary_pct
605 | )
606 | dim = (
607 | self.rotary_ndims
608 | if self.rotary_ndims is not None
609 | else config.kv_channels
610 | )
611 | self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
612 |
613 | self.use_flash_attn = config.use_flash_attn
614 | self.is_fp32 = not (config.bf16 or config.fp16)
615 | self.registered_causal_mask = None
616 |
617 | self.h = nn.ModuleList(
618 | [
619 | QWenBlock(
620 | config
621 | )
622 | for i in range(config.num_hidden_layers)
623 | ]
624 | )
625 | self.ln_f = RMSNorm(
626 | self.embed_dim,
627 | eps=config.layer_norm_epsilon,
628 | )
629 |
630 | self.visual = VisionTransformer(**config.visual)
631 |
632 | self.post_init()
633 |
634 | if USE_RESAMPLER:
635 | print('init RESAMPLER')
636 | self.his_resampler = HisResampler()
637 |
638 | self.imgtoken_dict = {}
639 | if os.path.isdir(IMAGE_HISTORY):
640 | for subdata in os.listdir(IMAGE_HISTORY):
641 | sub_img_dict = json.load(open(os.path.join(IMAGE_HISTORY, subdata)))
642 | self.imgtoken_dict.update(sub_img_dict)
643 | else:
644 | self.imgtoken_dict = json.load(open(IMAGE_HISTORY))
645 |
646 | print('imgtoken_dict cache len:', len(self.imgtoken_dict))
647 |
648 | def set_input_embeddings(self, new_embeddings):
649 | self.wte = new_embeddings
650 |
651 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
652 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
653 | # create causal mask
654 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
655 | combined_attention_mask = None
656 | if input_shape[-1] > 1:
657 | combined_attention_mask = _make_causal_mask(
658 | input_shape,
659 | inputs_embeds.dtype,
660 | device=inputs_embeds.device,
661 | past_key_values_length=past_key_values_length,
662 | )
663 |
664 | if attention_mask is not None:
665 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
666 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
667 | inputs_embeds.device
668 | )
669 | combined_attention_mask = (
670 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
671 | )
672 |
673 | return combined_attention_mask
674 |
675 |
676 | def forward(
677 | self,
678 | input_ids: Optional[torch.LongTensor] = None,
679 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
680 | attention_mask: Optional[torch.FloatTensor] = None,
681 | token_type_ids: Optional[torch.LongTensor] = None,
682 | position_ids: Optional[torch.LongTensor] = None,
683 | head_mask: Optional[torch.FloatTensor] = None,
684 | inputs_embeds: Optional[torch.FloatTensor] = None,
685 | encoder_hidden_states: Optional[torch.Tensor] = None,
686 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
687 | use_cache: Optional[bool] = None,
688 | output_attentions: Optional[bool] = None,
689 | output_hidden_states: Optional[bool] = None,
690 | return_dict: Optional[bool] = None,
691 | ):
692 | device = input_ids.device if input_ids is not None else inputs_embeds.device
693 |
694 | if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
695 | bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
696 | eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
697 | assert (bos_pos[0] == eos_pos[0]).all()
698 | img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
699 | now_images = []
700 | his_images = []
701 | C_list = []
702 | images = []
703 | his_idx = []
704 | his_image_temp = []
705 | for idx, (i, a, b) in enumerate(img_pos):
706 | image = input_ids[i][a + 1 : b - 1].tolist()
707 | image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
708 | image_path = bytes(image).decode('utf-8')
709 |
710 | if image_path.startswith('image-history: '):
711 | his_idx.append(idx)
712 | image_path = image_path.replace('image-history: ', '')
713 | his_list = self.imgtoken_dict[image_path][-self.his_len:] # t0 - tn-1
714 | assert len(his_list) > 0, his_list
715 |
716 | his_images.extend(his_list)
717 | his_image_temp.append(his_list)
718 |
719 | else:
720 | now_images.append(image_path)
721 |
722 | now_images = self.visual.encode(now_images)
723 |
724 | if len(his_images) > 0:
725 | his_images = self.visual.encode(his_images)
726 | his_tkn = None
727 |
728 | start_pos = 0
729 | for his_scr in his_image_temp:
730 | his_len = len(his_scr)
731 | his_img_feature = his_images[start_pos: start_pos + his_len] # [b, l, d]
732 | if USE_RESAMPLER:
733 | his_img_feature = his_img_feature.reshape(1, -1, his_img_feature.size(-1))
734 | his_vis_tkn = self.his_resampler(his_img_feature) # [l, d]
735 | else:
736 | raise ValueError("You cannot run without the History Resampler!")
737 | his_tkn = his_vis_tkn if his_tkn is None else torch.concat((his_tkn, his_vis_tkn), dim=0)
738 | start_pos += his_len
739 | assert start_pos == len(his_images)
740 | his_images = his_tkn
741 |
742 | now_p, his_p = 0, 0
743 | for j in range(len(img_pos)):
744 | if j not in his_idx:
745 | images.append(now_images[now_p])
746 | now_p += 1
747 | else:
748 | images.append(his_images[his_p])
749 | his_p += 1
750 | images = torch.stack(images, dim=0)
751 | assert len(images) == len(img_pos) == len(now_images) + len(his_images)
752 |
753 | fake_images = None
754 | elif self.training:
755 | fake_images=torch.zeros(1,3,224,224).to(
756 | dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
757 | images = self.visual(fake_images)
758 | else:
759 | fake_images = None
760 | images = None
761 |
762 | output_attentions = (
763 | output_attentions
764 | if output_attentions is not None
765 | else self.config.output_attentions
766 | )
767 | output_hidden_states = (
768 | output_hidden_states
769 | if output_hidden_states is not None
770 | else self.config.output_hidden_states
771 | )
772 | use_cache = use_cache if use_cache is not None else self.config.use_cache
773 | return_dict = (
774 | return_dict if return_dict is not None else self.config.use_return_dict
775 | )
776 |
777 | if input_ids is not None and inputs_embeds is not None:
778 | raise ValueError(
779 | "You cannot specify both input_ids and inputs_embeds at the same time"
780 | )
781 | elif input_ids is not None:
782 | input_shape = input_ids.size()
783 | input_ids = input_ids.view(-1, input_shape[-1])
784 | batch_size = input_ids.shape[0]
785 | elif inputs_embeds is not None:
786 | input_shape = inputs_embeds.size()[:-1]
787 | batch_size = inputs_embeds.shape[0]
788 | else:
789 | raise ValueError("You have to specify either input_ids or inputs_embeds")
790 |
791 |
792 | if token_type_ids is not None:
793 | token_type_ids = token_type_ids.view(-1, input_shape[-1])
794 | if position_ids is not None:
795 | position_ids = position_ids.view(-1, input_shape[-1])
796 |
797 | if past_key_values is None:
798 | past_length = 0
799 | past_key_values = tuple([None] * len(self.h))
800 | else:
801 | past_length = past_key_values[0][0].size(-2)
802 |
803 | if position_ids is None:
804 | position_ids = torch.arange(
805 | past_length,
806 | input_shape[-1] + past_length,
807 | dtype=torch.long,
808 | device=device,
809 | )
810 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
811 |
812 | encoder_attention_mask = None
813 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
814 |
815 | if inputs_embeds is None:
816 | inputs_embeds = self.wte(input_ids)
817 |
818 | if batch_size <= 0:
819 | raise ValueError("batch_size has to be defined and > 0")
820 | attention_mask = self._prepare_decoder_attention_mask(
821 | attention_mask, input_shape, inputs_embeds, past_length
822 | )
823 |
824 | hidden_states = inputs_embeds
825 |
826 | kv_seq_len = hidden_states.size()[1]
827 | if past_key_values[0] is not None:
828 | # past key values[0][0] shape: bs * seq_len * head_num * dim
829 | kv_seq_len += past_key_values[0][0].shape[1]
830 | if (
831 | self.use_dynamic_ntk
832 | and kv_seq_len == hidden_states.size()[1]
833 | and not self.training
834 | ):
835 | context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
836 | ntk_alpha = 2 ** math.ceil(context_value) - 1
837 | ntk_alpha = max(ntk_alpha, 1)
838 | else:
839 | ntk_alpha = self.rotary_emb._ntk_alpha_cached
840 |
841 | rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
842 | for idx in range(len(rotary_pos_emb)):
843 | rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
844 | hidden_states = self.drop(hidden_states).clone()
845 |
846 | if fake_images is not None:
847 | hidden_states = hidden_states + images.mean()*0
848 | elif images is not None:
849 | for idx, (i, a, b) in enumerate(img_pos):
850 | hidden_states[i][a + 1 : b] = images[idx]
851 | output_shape = input_shape + (hidden_states.size(-1),)
852 |
853 | if self.gradient_checkpointing and self.training:
854 | if use_cache:
855 | logger.warning_once(
856 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
857 | )
858 | use_cache = False
859 |
860 | presents = () if use_cache else None
861 | all_self_attentions = () if output_attentions else None
862 | all_hidden_states = () if output_hidden_states else None
863 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
864 |
865 | if output_hidden_states:
866 | all_hidden_states = all_hidden_states + (hidden_states,)
867 |
868 | if self.gradient_checkpointing and self.training:
869 |
870 | def create_custom_forward(module):
871 | def custom_forward(*inputs):
872 | # None for past_key_value
873 | return module(*inputs, use_cache, output_attentions)
874 |
875 | return custom_forward
876 |
877 | outputs = torch.utils.checkpoint.checkpoint(
878 | create_custom_forward(block),
879 | hidden_states,
880 | rotary_pos_emb,
881 | self.registered_causal_mask,
882 | None,
883 | attention_mask,
884 | head_mask[i],
885 | encoder_hidden_states,
886 | encoder_attention_mask,
887 | )
888 | else:
889 | outputs = block(
890 | hidden_states,
891 | layer_past=layer_past,
892 | rotary_pos_emb=rotary_pos_emb,
893 | registered_causal_mask=self.registered_causal_mask,
894 | attention_mask=attention_mask,
895 | head_mask=head_mask[i],
896 | encoder_hidden_states=encoder_hidden_states,
897 | encoder_attention_mask=encoder_attention_mask,
898 | use_cache=use_cache,
899 | output_attentions=output_attentions,
900 | )
901 |
902 | hidden_states = outputs[0]
903 | if use_cache is True:
904 | presents = presents + (outputs[1],)
905 |
906 | if output_attentions:
907 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
908 |
909 | hidden_states = self.ln_f(hidden_states)
910 | hidden_states = hidden_states.view(output_shape)
911 | # Add last hidden state
912 | if output_hidden_states:
913 | all_hidden_states = all_hidden_states + (hidden_states,)
914 |
915 | if not return_dict:
916 | return tuple(
917 | v for v in [hidden_states, presents, all_hidden_states] if v is not None
918 | )
919 |
920 | return BaseModelOutputWithPast(
921 | last_hidden_state=hidden_states,
922 | past_key_values=presents,
923 | hidden_states=all_hidden_states,
924 | attentions=all_self_attentions,
925 | )
926 |
927 |
928 | class QWenLMHeadModel(QWenPreTrainedModel):
929 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
930 | _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
931 |
932 | def __init__(self, config):
933 | super().__init__(config)
934 | assert (
935 | config.bf16 + config.fp16 + config.fp32 <= 1
936 | ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
937 |
938 | autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
939 |
940 | if autoset_precision:
941 | if SUPPORT_BF16:
942 | logger.warn(
943 | "The model is automatically converting to bf16 for faster inference. "
944 | "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
945 | )
946 | config.bf16 = True
947 | elif SUPPORT_FP16:
948 | logger.warn(
949 | "The model is automatically converting to fp16 for faster inference. "
950 | "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
951 | )
952 | config.fp16 = True
953 | else:
954 | config.fp32 = True
955 |
956 | if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
957 | logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
958 | if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
959 | logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
960 | if config.fp32:
961 | if SUPPORT_BF16:
962 | logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
963 | elif SUPPORT_FP16:
964 | logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
965 |
966 | self.transformer = QWenModel(config)
967 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
968 |
969 | if config.bf16:
970 | self.transformer.bfloat16()
971 | self.lm_head.bfloat16()
972 | if config.fp16:
973 | self.transformer.half()
974 | self.lm_head.half()
975 | self.post_init()
976 |
977 | def get_output_embeddings(self):
978 | return self.lm_head
979 |
980 | def set_output_embeddings(self, new_embeddings):
981 | self.lm_head = new_embeddings
982 |
983 | def prepare_inputs_for_generation(
984 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
985 | ):
986 | token_type_ids = kwargs.get("token_type_ids", None)
987 | if past_key_values:
988 | input_ids = input_ids[:, -1].unsqueeze(-1)
989 | if token_type_ids is not None:
990 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
991 |
992 | attention_mask = kwargs.get("attention_mask", None)
993 | position_ids = kwargs.get("position_ids", None)
994 |
995 | if attention_mask is not None and position_ids is None:
996 | position_ids = attention_mask.long().cumsum(-1) - 1
997 | position_ids.masked_fill_(attention_mask == 0, 1)
998 | if past_key_values:
999 | position_ids = position_ids[:, -1].unsqueeze(-1)
1000 | else:
1001 | position_ids = None
1002 |
1003 | if inputs_embeds is not None and past_key_values is None:
1004 | model_inputs = {"inputs_embeds": inputs_embeds}
1005 | else:
1006 | model_inputs = {"input_ids": input_ids}
1007 |
1008 | model_inputs.update(
1009 | {
1010 | "past_key_values": past_key_values,
1011 | "use_cache": kwargs.get("use_cache"),
1012 | "position_ids": position_ids,
1013 | "attention_mask": attention_mask,
1014 | "token_type_ids": token_type_ids,
1015 | }
1016 | )
1017 | return model_inputs
1018 |
1019 | def forward(
1020 | self,
1021 | input_ids: Optional[torch.LongTensor] = None,
1022 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1023 | attention_mask: Optional[torch.FloatTensor] = None,
1024 | token_type_ids: Optional[torch.LongTensor] = None,
1025 | position_ids: Optional[torch.LongTensor] = None,
1026 | head_mask: Optional[torch.FloatTensor] = None,
1027 | inputs_embeds: Optional[torch.FloatTensor] = None,
1028 | encoder_hidden_states: Optional[torch.Tensor] = None,
1029 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
1030 | labels: Optional[torch.LongTensor] = None,
1031 | use_cache: Optional[bool] = None,
1032 | output_attentions: Optional[bool] = None,
1033 | output_hidden_states: Optional[bool] = None,
1034 | return_dict: Optional[bool] = None,
1035 | ) -> Union[Tuple, CausalLMOutputWithPast]:
1036 |
1037 | return_dict = (
1038 | return_dict if return_dict is not None else self.config.use_return_dict
1039 | )
1040 |
1041 | transformer_outputs = self.transformer(
1042 | input_ids,
1043 | past_key_values=past_key_values,
1044 | attention_mask=attention_mask,
1045 | token_type_ids=token_type_ids,
1046 | position_ids=position_ids,
1047 | head_mask=head_mask,
1048 | inputs_embeds=inputs_embeds,
1049 | encoder_hidden_states=encoder_hidden_states,
1050 | encoder_attention_mask=encoder_attention_mask,
1051 | use_cache=use_cache,
1052 | output_attentions=output_attentions,
1053 | output_hidden_states=output_hidden_states,
1054 | return_dict=return_dict,
1055 | )
1056 | hidden_states = transformer_outputs[0]
1057 |
1058 | lm_logits = self.lm_head(hidden_states)
1059 |
1060 | loss = None
1061 | if labels is not None:
1062 | labels = labels.to(lm_logits.device)
1063 | shift_logits = lm_logits[..., :-1, :].contiguous()
1064 | shift_labels = labels[..., 1:].contiguous()
1065 | loss_fct = CrossEntropyLoss()
1066 | loss = loss_fct(
1067 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1068 | )
1069 |
1070 | if not return_dict:
1071 | output = (lm_logits,) + transformer_outputs[1:]
1072 | return ((loss,) + output) if loss is not None else output
1073 |
1074 | return CausalLMOutputWithPast(
1075 | loss=loss,
1076 | logits=lm_logits,
1077 | past_key_values=transformer_outputs.past_key_values,
1078 | hidden_states=transformer_outputs.hidden_states,
1079 | attentions=transformer_outputs.attentions,
1080 | )
1081 |
1082 | @staticmethod
1083 | def _reorder_cache(
1084 | past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1085 | ) -> Tuple[Tuple[torch.Tensor]]:
1086 |
1087 | return tuple(
1088 | tuple(
1089 | past_state.index_select(0, beam_idx.to(past_state.device))
1090 | for past_state in layer_past
1091 | )
1092 | for layer_past in past_key_values
1093 | )
1094 |
1095 | def chat(
1096 | self,
1097 | tokenizer: PreTrainedTokenizer,
1098 | query: str,
1099 | history: Optional[HistoryType],
1100 | system: str = "You are a helpful assistant.",
1101 | append_history: bool = True,
1102 | stream: Optional[bool] = _SENTINEL,
1103 | stop_words_ids: Optional[List[List[int]]] = None,
1104 | generation_config: Optional[GenerationConfig] = None,
1105 | **kwargs,
1106 | ) -> Tuple[str, HistoryType]:
1107 | generation_config = generation_config if generation_config is not None else self.generation_config
1108 |
1109 | assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
1110 | assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
1111 | if history is None:
1112 | history = []
1113 | if stop_words_ids is None:
1114 | stop_words_ids = []
1115 |
1116 | max_window_size = kwargs.get('max_window_size', None)
1117 | if max_window_size is None:
1118 | max_window_size = generation_config.max_window_size
1119 | raw_text, context_tokens = make_context(
1120 | tokenizer,
1121 | query,
1122 | history=history,
1123 | system=system,
1124 | max_window_size=max_window_size,
1125 | chat_format=generation_config.chat_format,
1126 | )
1127 |
1128 | stop_words_ids.extend(get_stop_words_ids(
1129 | generation_config.chat_format, tokenizer
1130 | ))
1131 | input_ids = torch.tensor([context_tokens]).to(self.device)
1132 | outputs = self.generate(
1133 | input_ids,
1134 | stop_words_ids=stop_words_ids,
1135 | return_dict_in_generate=False,
1136 | generation_config=generation_config,
1137 | **kwargs,
1138 | )
1139 |
1140 | response = decode_tokens(
1141 | outputs[0],
1142 | tokenizer,
1143 | raw_text_len=len(raw_text),
1144 | context_length=len(context_tokens),
1145 | chat_format=generation_config.chat_format,
1146 | verbose=False,
1147 | errors='replace'
1148 | )
1149 |
1150 | if append_history:
1151 | history.append((query, response))
1152 |
1153 | return response, history
1154 |
1155 | def chat_stream(
1156 | self,
1157 | tokenizer: PreTrainedTokenizer,
1158 | query: str,
1159 | history: Optional[HistoryType],
1160 | system: str = "You are a helpful assistant.",
1161 | stop_words_ids: Optional[List[List[int]]] = None,
1162 | logits_processor: Optional[LogitsProcessorList] = None,
1163 | generation_config: Optional[GenerationConfig] = None,
1164 | **kwargs,
1165 | ) -> Generator[str, Any, None]:
1166 | generation_config = generation_config if generation_config is not None else self.generation_config
1167 | assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
1168 | if history is None:
1169 | history = []
1170 | if stop_words_ids is None:
1171 | stop_words_ids = []
1172 |
1173 | max_window_size = kwargs.get('max_window_size', None)
1174 | if max_window_size is None:
1175 | max_window_size = generation_config.max_window_size
1176 | raw_text, context_tokens = make_context(
1177 | tokenizer,
1178 | query,
1179 | history=history,
1180 | system=system,
1181 | max_window_size=max_window_size,
1182 | chat_format=generation_config.chat_format,
1183 | )
1184 |
1185 | stop_words_ids.extend(get_stop_words_ids(
1186 | generation_config.chat_format, tokenizer
1187 | ))
1188 | if stop_words_ids is not None:
1189 | stop_words_logits_processor = StopWordsLogitsProcessor(
1190 | stop_words_ids=stop_words_ids,
1191 | eos_token_id=generation_config.eos_token_id,
1192 | )
1193 | if logits_processor is None:
1194 | logits_processor = LogitsProcessorList([stop_words_logits_processor])
1195 | else:
1196 | logits_processor.append(stop_words_logits_processor)
1197 | input_ids = torch.tensor([context_tokens]).to(self.device)
1198 |
1199 | from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
1200 | self.__class__.generate_stream = NewGenerationMixin.generate
1201 | self.__class__.sample_stream = NewGenerationMixin.sample_stream
1202 | stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
1203 |
1204 | def stream_generator():
1205 | outputs = []
1206 | for token in self.generate_stream(
1207 | input_ids,
1208 | return_dict_in_generate=False,
1209 | generation_config=stream_config,
1210 | logits_processor=logits_processor,
1211 | seed=-1,
1212 | **kwargs):
1213 | outputs.append(token.item())
1214 | yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore', keep_image_special=True)
1215 |
1216 | return stream_generator()
1217 |
1218 | def generate(
1219 | self,
1220 | inputs: Optional[torch.Tensor] = None,
1221 | generation_config: Optional[GenerationConfig] = None,
1222 | logits_processor: Optional[LogitsProcessorList] = None,
1223 | stopping_criteria: Optional[StoppingCriteriaList] = None,
1224 | prefix_allowed_tokens_fn: Optional[
1225 | Callable[[int, torch.Tensor], List[int]]
1226 | ] = None,
1227 | synced_gpus: Optional[bool] = None,
1228 | assistant_model: Optional["PreTrainedModel"] = None,
1229 | streamer: Optional["BaseStreamer"] = None,
1230 | **kwargs,
1231 | ) -> Union[GenerateOutput, torch.LongTensor]:
1232 | generation_config = generation_config if generation_config is not None else self.generation_config
1233 |
1234 | # Process stop_words_ids.
1235 | stop_words_ids = kwargs.pop("stop_words_ids", None)
1236 | if stop_words_ids is None and generation_config is not None:
1237 | stop_words_ids = getattr(generation_config, "stop_words_ids", None)
1238 | if stop_words_ids is None:
1239 | stop_words_ids = getattr(generation_config, "stop_words_ids", None)
1240 |
1241 | if stop_words_ids is not None:
1242 | stop_words_logits_processor = StopWordsLogitsProcessor(
1243 | stop_words_ids=stop_words_ids,
1244 | eos_token_id=generation_config.eos_token_id,
1245 | )
1246 | if logits_processor is None:
1247 | logits_processor = LogitsProcessorList([stop_words_logits_processor])
1248 | else:
1249 | logits_processor.append(stop_words_logits_processor)
1250 |
1251 | return super().generate(
1252 | inputs,
1253 | generation_config=generation_config,
1254 | logits_processor=logits_processor,
1255 | stopping_criteria=stopping_criteria,
1256 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1257 | synced_gpus=synced_gpus,
1258 | assistant_model=assistant_model,
1259 | streamer=streamer,
1260 | **kwargs,
1261 | )
1262 |
1263 |
1264 | class RotaryEmbedding(torch.nn.Module):
1265 | def __init__(self, dim, base=10000):
1266 | super().__init__()
1267 | self.dim = dim
1268 | self.base = base
1269 | self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
1270 | if importlib.util.find_spec("einops") is None:
1271 | raise RuntimeError("einops is required for Rotary Embedding")
1272 |
1273 | self._rotary_pos_emb_cache = None
1274 | self._seq_len_cached = 0
1275 | self._ntk_alpha_cached = 1.0
1276 |
1277 | def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
1278 | seqlen = max_seq_len + offset
1279 | if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
1280 | base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
1281 | self.inv_freq = 1.0 / (
1282 | base
1283 | ** (
1284 | torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
1285 | / self.dim
1286 | )
1287 | )
1288 | self._seq_len_cached = max(2 * seqlen, 16)
1289 | self._ntk_alpha_cached = ntk_alpha
1290 | seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
1291 | freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
1292 |
1293 | emb = torch.cat((freqs, freqs), dim=-1)
1294 | from einops import rearrange
1295 |
1296 | emb = rearrange(emb, "n d -> 1 n 1 d")
1297 |
1298 | cos, sin = emb.cos(), emb.sin()
1299 | self._rotary_pos_emb_cache = [cos, sin]
1300 |
1301 | def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
1302 | self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
1303 | cos, sin = self._rotary_pos_emb_cache
1304 | return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]
1305 |
1306 |
1307 | def _rotate_half(x):
1308 | from einops import rearrange
1309 |
1310 | x = rearrange(x, "... (j d) -> ... j d", j=2)
1311 | x1, x2 = x.unbind(dim=-2)
1312 | return torch.cat((-x2, x1), dim=-1)
1313 |
1314 |
1315 | def apply_rotary_pos_emb(t, freqs):
1316 | cos, sin = freqs
1317 | if apply_rotary_emb_func is not None and t.is_cuda:
1318 | t_ = t.float()
1319 | cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
1320 | sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
1321 | output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
1322 | return output
1323 | else:
1324 | rot_dim = freqs[0].shape[-1]
1325 | cos, sin = freqs
1326 | t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
1327 | t_ = t_.float()
1328 | t_pass_ = t_pass_.float()
1329 | t_ = (t_ * cos) + (_rotate_half(t_) * sin)
1330 | return torch.cat((t_, t_pass_), dim=-1).type_as(t)
1331 |
1332 |
1333 | class RMSNorm(torch.nn.Module):
1334 | def __init__(self, dim: int, eps: float = 1e-6):
1335 | super().__init__()
1336 | self.eps = eps
1337 | self.weight = nn.Parameter(torch.ones(dim))
1338 |
1339 | def _norm(self, x):
1340 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1341 |
1342 | def forward(self, x):
1343 | if rms_norm is not None and x.is_cuda:
1344 | return rms_norm(x, self.weight, self.eps)
1345 | else:
1346 | output = self._norm(x.float()).type_as(x)
1347 | return output * self.weight
1348 |
--------------------------------------------------------------------------------
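
For reference, a minimal single-turn inference sketch against the chat() helper defined above, assuming the model and tokenizer from the earlier loading sketch and assuming the bundled tokenization_qwen.py keeps Qwen-VL's <img>...</img> image markup; the screenshot path and query text are placeholders.

# chat() builds a ChatML prompt with make_context(), generates with the custom
# stop-word handling in generate(), and returns (response, updated_history).
query = (
    "Picture 1: <img>screenshots/step_0005.png</img>\n"
    "Which UI element should be tapped next to continue the task?"
)
response, history = model.chat(tokenizer, query=query, history=None)
print(response)
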
/OdysseyAgent/qwen_generation_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Generation support."""
7 |
8 | from typing import Tuple, List, Union, Iterable
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from transformers import PreTrainedTokenizer
14 | from transformers import logging
15 | from transformers.generation import LogitsProcessor
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 | # Types.
20 | HistoryType = List[Tuple[str, str]]
21 | TokensType = List[int]
22 | BatchTokensType = List[List[int]]
23 |
24 |
25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26 | for tokens in batch:
27 | context_length = len(tokens)
28 | if context_length < seq_length:
29 | tokens.extend([pad_id] * (seq_length - context_length))
30 | return batch
31 |
32 |
33 | def get_ltor_masks_and_position_ids(
34 | data,
35 | eod_token,
36 | reset_position_ids,
37 | reset_attention_mask,
38 | eod_mask_loss,
39 | ):
40 | """Build masks and position id for left to right model."""
41 |
42 | # Extract batch size and sequence length.
43 | micro_batch_size, seq_length = data.size()
44 |
45 | # Attention mask (lower triangular).
46 | if reset_attention_mask:
47 | att_mask_batch = micro_batch_size
48 | else:
49 | att_mask_batch = 1
50 | attention_mask = torch.tril(
51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52 | ).view(att_mask_batch, 1, seq_length, seq_length)
53 |
54 | # Loss mask.
55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56 | if eod_mask_loss:
57 | loss_mask[data == eod_token] = 0.0
58 |
59 | # Position ids.
60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61 | position_ids = position_ids.unsqueeze(0).expand_as(data)
62 | # We need to clone as the ids will be modified based on batch index.
63 | if reset_position_ids:
64 | position_ids = position_ids.clone()
65 |
66 | if reset_position_ids or reset_attention_mask:
67 | # Loop through the batches:
68 | for b in range(micro_batch_size):
69 |
70 | # Find indices where the EOD token is.
71 | eod_index = position_ids[b, data[b] == eod_token]
72 | # Detach indices from positions if we are going to modify positions.
73 | if reset_position_ids:
74 | eod_index = eod_index.clone()
75 |
76 | # Loop through EOD indices:
77 | prev_index = 0
78 | for j in range(eod_index.size()[0]):
79 | i = eod_index[j]
80 | # Mask attention loss.
81 | if reset_attention_mask:
82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83 | # Reset positions.
84 | if reset_position_ids:
85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index
86 | prev_index = i + 1
87 |
88 | # Convert attention mask to binary:
89 | attention_mask = attention_mask < 0.5
90 |
91 | return attention_mask, loss_mask, position_ids
92 |
93 |
94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95 | """Generate batch from context tokens."""
96 | # Keep tokens contiguous on their current device.
97 | tokens = context_tokens.contiguous().to(context_tokens.device)
98 | # Get the attention mask and position ids.
99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100 | tokens,
101 | eod_id,
102 | reset_position_ids=False,
103 | reset_attention_mask=False,
104 | eod_mask_loss=False,
105 | )
106 | return tokens, attention_mask, position_ids
107 |
108 |
109 | def get_stop_words_ids(chat_format, tokenizer):
110 | if chat_format == "raw":
111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112 | elif chat_format == "chatml":
113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114 | else:
115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116 | return stop_words_ids
117 |
118 |
119 | def make_context(
120 | tokenizer: PreTrainedTokenizer,
121 | query: str,
122 | history: List[Tuple[str, str]] = None,
123 | system: str = "",
124 | max_window_size: int = 6144,
125 | chat_format: str = "chatml",
126 | ):
127 | if history is None:
128 | history = []
129 |
130 | if chat_format == "chatml":
131 | im_start, im_end = "<|im_start|>", "<|im_end|>"
132 | im_start_tokens = [tokenizer.im_start_id]
133 | im_end_tokens = [tokenizer.im_end_id]
134 | nl_tokens = tokenizer.encode("\n")
135 |
136 | def _tokenize_str(role, content):
137 | return f"{role}\n{content}", tokenizer.encode(
138 | role, allowed_special=set(tokenizer.IMAGE_ST)
139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
140 |
141 | system_text, system_tokens_part = _tokenize_str("system", system)
142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143 |
144 | raw_text = ""
145 | context_tokens = []
146 |
147 | for turn_query, turn_response in reversed(history):
148 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150 | if turn_response is not None:
151 | response_text, response_tokens_part = _tokenize_str(
152 | "assistant", turn_response
153 | )
154 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
155 |
156 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
157 | prev_chat = (
158 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
159 | )
160 | else:
161 | next_context_tokens = nl_tokens + query_tokens + nl_tokens
162 | prev_chat = f"\n{im_start}{query_text}{im_end}\n"
163 |
164 | current_context_size = (
165 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
166 | )
167 | if current_context_size < max_window_size:
168 | context_tokens = next_context_tokens + context_tokens
169 | raw_text = prev_chat + raw_text
170 | else:
171 | break
172 |
173 | context_tokens = system_tokens + context_tokens
174 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
175 | context_tokens += (
176 | nl_tokens
177 | + im_start_tokens
178 | + _tokenize_str("user", query)[1]
179 | + im_end_tokens
180 | + nl_tokens
181 | + im_start_tokens
182 | + tokenizer.encode("assistant")
183 | + nl_tokens
184 | )
185 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
186 |
187 | elif chat_format == "raw":
188 | raw_text = query
189 | context_tokens = tokenizer.encode(raw_text)
190 | else:
191 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
192 |
193 | return raw_text, context_tokens
194 |
195 |
196 | def _decode_default(
197 | tokens: List[int],
198 | *,
199 | stop_words: List[str],
200 | eod_words: List[str],
201 | tokenizer: PreTrainedTokenizer,
202 | raw_text_len: int,
203 | verbose: bool = False,
204 | return_end_reason: bool = False,
205 | errors: str='replace',
206 | ):
207 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
208 | if verbose:
209 | print("\nRaw Generate: ", trim_decode_tokens)
210 |
211 | end_reason = f"Gen length {len(tokens)}"
212 | for stop_word in stop_words:
213 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
214 | for eod_word in eod_words:
215 | if eod_word in trim_decode_tokens:
216 | end_reason = f"Gen {eod_word!r}"
217 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
218 | trim_decode_tokens = trim_decode_tokens.strip()
219 | if verbose:
220 | print("\nEnd Reason:", end_reason)
221 | print("\nGenerate: ", trim_decode_tokens)
222 |
223 | if return_end_reason:
224 | return trim_decode_tokens, end_reason
225 | else:
226 | return trim_decode_tokens
227 |
228 |
229 | def _decode_chatml(
230 | tokens: List[int],
231 | *,
232 | stop_words: List[str],
233 | eod_token_ids: List[int],
234 | tokenizer: PreTrainedTokenizer,
235 | raw_text_len: int,
236 | context_length: int,
237 | verbose: bool = False,
238 | return_end_reason: bool = False,
239 | errors: str='replace'
240 | ):
241 | end_reason = f"Gen length {len(tokens)}"
242 | eod_token_idx = context_length
243 | for eod_token_idx in range(context_length, len(tokens)):
244 | if tokens[eod_token_idx] in eod_token_ids:
245 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
246 | break
247 |
248 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
249 | if verbose:
250 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
251 | print("\nRaw Generate:", trim_decode_tokens)
252 | print("\nEnd Reason:", end_reason)
253 | for stop_word in stop_words:
254 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
255 | trim_decode_tokens = trim_decode_tokens.strip()
256 | if verbose:
257 | print("\nGenerate:", trim_decode_tokens)
258 |
259 | if return_end_reason:
260 | return trim_decode_tokens, end_reason
261 | else:
262 | return trim_decode_tokens
263 |
264 |
265 | def decode_tokens(
266 | tokens: Union[torch.LongTensor, TokensType],
267 | tokenizer: PreTrainedTokenizer,
268 | raw_text_len: int,
269 | context_length: int,
270 | chat_format: str,
271 | verbose: bool = False,
272 | return_end_reason: bool = False,
273 | errors: str="replace",
274 | ) -> str:
275 | if torch.is_tensor(tokens):
276 | tokens = tokens.cpu().numpy().tolist()
277 |
278 | if chat_format == "chatml":
279 | return _decode_chatml(
280 | tokens,
281 | stop_words=[],
282 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
283 | tokenizer=tokenizer,
284 | raw_text_len=raw_text_len,
285 | context_length=context_length,
286 | verbose=verbose,
287 | return_end_reason=return_end_reason,
288 | errors=errors,
289 | )
290 | elif chat_format == "raw":
291 | return _decode_default(
292 | tokens,
293 | stop_words=["<|endoftext|>"],
294 | eod_words=["<|endoftext|>"],
295 | tokenizer=tokenizer,
296 | raw_text_len=raw_text_len,
297 | verbose=verbose,
298 | return_end_reason=return_end_reason,
299 | errors=errors,
300 | )
301 | else:
302 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
303 |
304 |
305 | class StopWordsLogitsProcessor(LogitsProcessor):
306 | """
307 | :class:`transformers.LogitsProcessor` that forces generation to stop once any of the specified token sequences appears.
308 |
309 | Args:
310 | stop_words_ids (:obj:`List[List[int]]`):
311 | Lists of token ids for the stop sequences. In order to get the tokens of the words
312 | that should not appear in the generated text, use :obj:`tokenizer(bad_word,
313 | add_prefix_space=True).input_ids`.
314 | eos_token_id (:obj:`int`):
315 | The id of the `end-of-sequence` token.
316 | """
317 |
318 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
319 |
320 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
321 | raise ValueError(
322 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
323 | )
324 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
325 | raise ValueError(
326 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
327 | )
328 | if any(
329 | any(
330 | (not isinstance(token_id, (int, np.integer)) or token_id < 0)
331 | for token_id in stop_word_ids
332 | )
333 | for stop_word_ids in stop_words_ids
334 | ):
335 | raise ValueError(
336 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
337 | )
338 |
339 | self.stop_words_ids = list(
340 | filter(
341 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
342 | )
343 | )
344 | self.eos_token_id = eos_token_id
345 | for stop_token_seq in self.stop_words_ids:
346 | assert (
347 | len(stop_token_seq) > 0
348 | ), "Stop words token sequences {} cannot have an empty list".format(
349 | stop_words_ids
350 | )
351 |
352 | def __call__(
353 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
354 | ) -> torch.FloatTensor:
355 | stopped_samples = self._calc_stopped_samples(input_ids)
356 | for i, should_stop in enumerate(stopped_samples):
357 | if should_stop:
358 | scores[i, self.eos_token_id] = float(2**15)
359 | return scores
360 |
361 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
362 | if len(tokens) == 0:
363 | # an empty stop sequence matches unconditionally
364 | return True
365 | elif len(tokens) > len(prev_tokens):
366 | # if the stop sequence is longer than prev input_ids, it can't match
367 | return False
368 | elif prev_tokens[-len(tokens) :].tolist() == tokens:
369 | # if tokens match
370 | return True
371 | else:
372 | return False
373 |
374 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
375 | stopped_samples = []
376 | for prev_input_ids_slice in prev_input_ids:
377 | match = False
378 | for stop_token_seq in self.stop_words_ids:
379 | if self._tokens_match(prev_input_ids_slice, stop_token_seq):
380 | # a stop sequence matched; mark this sample as stopped
381 | match = True
382 | break
383 | stopped_samples.append(match)
384 |
385 | return stopped_samples
386 |
387 |
388 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
389 | """This function has been mostly taken from huggingface conversational
390 | ai code at
391 | https://medium.com/huggingface/how-to-build-a-state-of-the-art-
392 | conversational-ai-with-transfer-learning-2d818ac26313"""
393 |
394 | if top_k > 0:
395 | # Remove all tokens with a probability less than the
396 | # last token of the top-k
397 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
398 | logits[indices_to_remove] = filter_value
399 |
400 | if top_p > 0.0:
401 | # Sort the logits in descending order
402 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
403 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
404 |
405 | # Remove tokens with cumulative probability above the threshold
406 | sorted_indices_to_remove = cumulative_probs > top_p
407 | # Shift the indices to the right to keep also the first token
408 | # above the threshold
409 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
410 | sorted_indices_to_remove[..., 0] = 0
411 | for i in range(sorted_indices.size(0)):
412 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
413 | logits[i][indices_to_remove] = filter_value
414 |
415 | return logits
416 |
417 |
418 | def switch(val1, val2, boolean):
419 | boolean = boolean.type_as(val1)
420 | return (1 - boolean) * val1 + boolean * val2
421 |
--------------------------------------------------------------------------------
/OdysseyAgent/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/OdysseyAgent/tokenization_qwen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Tokenization classes for QWen."""
7 |
8 | import base64
9 | import logging
10 | import os
11 | import requests
12 | import unicodedata
13 | from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
14 |
15 | import tiktoken
16 | import numpy as np
17 | from PIL import Image
18 | from PIL import ImageFont
19 | from PIL import ImageDraw
20 | from transformers import PreTrainedTokenizer, AddedToken
21 | from transformers.utils import try_to_load_from_cache
22 |
23 | import matplotlib.colors as mcolors
24 | from matplotlib.font_manager import FontProperties
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}
30 | FONT_PATH = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf")
31 | if FONT_PATH is None:
32 | if not os.path.exists("SimSun.ttf"):
33 | ttf = requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/SimSun.ttf")
34 | open("SimSun.ttf", "wb").write(ttf.content)
35 | FONT_PATH = "SimSun.ttf"
36 |
37 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
38 | ENDOFTEXT = "<|endoftext|>"
39 | IMSTART = "<|im_start|>"
40 | IMEND = "<|im_end|>"
41 | # as the default behavior is changed to allow special tokens in
42 | # regular texts, the surface forms of special tokens need to be
43 | # as different as possible to minimize the impact
44 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
45 | SPECIAL_TOKENS = (
46 | ENDOFTEXT,
47 | IMSTART,
48 | IMEND,
49 | ) + EXTRAS
50 | IMG_TOKEN_SPAN = 256
51 |
52 |
53 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
54 | with open(tiktoken_bpe_file, "rb") as f:
55 | contents = f.read()
56 | return {
57 | base64.b64decode(token): int(rank)
58 | for token, rank in (line.split() for line in contents.splitlines() if line)
59 | }
60 |
61 | def _list_find(
62 | input_list: List[Any],
63 | candidates: Tuple[Any],
64 | start: int = 0,
65 | ):
66 | for i in range(start, len(input_list)):
67 | if input_list[i] in candidates:
68 | return i
69 | return -1
70 |
71 | def _replace_closed_tag(
72 | input_tokens: List[Any],
73 | start_tags: Union[Any, Tuple[Any]],
74 | end_tags: Union[Any, Tuple[Any]],
75 | inclusive_replace_func: Callable,
76 | exclusive_replace_func: Callable = lambda x: x,
77 | ):
78 | if isinstance(start_tags, (str, int)):
79 | start_tags = (start_tags,)
80 | if isinstance(end_tags, (str, int)):
81 | end_tags = (end_tags,)
82 | assert len(start_tags) == len(end_tags)
83 |
84 | output_tokens = []
85 | end = 0
86 | while True:
87 | start = _list_find(input_tokens, start_tags, end)
88 | if start == -1:
89 | break
90 | output_tokens.extend(exclusive_replace_func(input_tokens[end : start]))
91 | tag_idx = start_tags.index(input_tokens[start])
92 | end = _list_find(input_tokens, (end_tags[tag_idx],), start)
93 | if end == -1:
94 | raise ValueError("Unclosed image token")
95 | output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1]))
96 | end += 1
97 | output_tokens.extend(exclusive_replace_func(input_tokens[end : ]))
98 | return output_tokens
99 |
100 | class QWenTokenizer(PreTrainedTokenizer):
101 | """QWen tokenizer."""
102 |
103 | vocab_files_names = VOCAB_FILES_NAMES
104 |
105 | def __init__(
106 | self,
107 | vocab_file,
108 | errors="replace",
109 | image_start_tag='<img>',
110 | image_end_tag='</img>',
111 | image_pad_tag='<imgpad>',
112 | ref_start_tag='<ref>',
113 | ref_end_tag='</ref>',
114 | box_start_tag='<box>',
115 | box_end_tag='</box>',
116 | quad_start_tag='<quad>',
117 | quad_end_tag='</quad>',
118 | **kwargs,
119 | ):
120 | super().__init__(**kwargs)
121 | self.image_start_tag = image_start_tag
122 | self.image_end_tag = image_end_tag
123 | self.image_pad_tag = image_pad_tag
124 | self.ref_start_tag = ref_start_tag
125 | self.ref_end_tag = ref_end_tag
126 | self.box_start_tag = box_start_tag
127 | self.box_end_tag = box_end_tag
128 | self.quad_start_tag = quad_start_tag
129 | self.quad_end_tag = quad_end_tag
130 | self.IMAGE_ST = (
131 | ref_start_tag, ref_end_tag,
132 | box_start_tag, box_end_tag,
133 | quad_start_tag, quad_end_tag,
134 | image_start_tag, image_end_tag,
135 | image_pad_tag
136 | )
137 |
138 | self.errors = errors # how to handle errors in decoding
139 |
140 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
141 | self.special_tokens = {
142 | token: index
143 | for index, token in enumerate(
144 | SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
145 | )
146 | }
147 | self.img_start_id = self.special_tokens[self.image_start_tag]
148 | self.img_end_id = self.special_tokens[self.image_end_tag]
149 | self.img_pad_id = self.special_tokens[self.image_pad_tag]
150 | self.ref_start_id = self.special_tokens[self.ref_start_tag]
151 | self.ref_end_id = self.special_tokens[self.ref_end_tag]
152 | self.box_start_id = self.special_tokens[self.box_start_tag]
153 | self.box_end_id = self.special_tokens[self.box_end_tag]
154 | self.quad_start_id = self.special_tokens[self.quad_start_tag]
155 | self.quad_end_id = self.special_tokens[self.quad_end_tag]
156 | self.image_special_tokens = set([
157 | self.ref_start_id, self.ref_end_id, self.box_start_id, self.box_end_id,
158 | self.quad_start_id, self.quad_end_id,
159 | ])
160 |
161 | enc = tiktoken.Encoding(
162 | "Qwen",
163 | pat_str=PAT_STR,
164 | mergeable_ranks=self.mergeable_ranks,
165 | special_tokens=self.special_tokens,
166 | )
167 | assert (
168 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
169 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
170 |
171 | self.decoder = {
172 | v: k for k, v in self.mergeable_ranks.items()
173 | } # type: dict[int, bytes|str]
174 | self.decoder.update({v: k for k, v in self.special_tokens.items()})
175 |
176 | self.tokenizer = enc # type: tiktoken.Encoding
177 |
178 | self.eod_id = self.tokenizer.eot_token
179 | self.im_start_id = self.special_tokens[IMSTART]
180 | self.im_end_id = self.special_tokens[IMEND]
181 |
182 | def __getstate__(self):
183 | # for pickle lovers
184 | state = self.__dict__.copy()
185 | del state['tokenizer']
186 | return state
187 |
188 | def __setstate__(self, state):
189 | # tokenizer is not python native; don't pass it; rebuild it
190 | self.__dict__.update(state)
191 | enc = tiktoken.Encoding(
192 | "Qwen",
193 | pat_str=PAT_STR,
194 | mergeable_ranks=self.mergeable_ranks,
195 | special_tokens=self.special_tokens,
196 | )
197 | self.tokenizer = enc
198 |
199 |
200 | def __len__(self) -> int:
201 | return self.tokenizer.n_vocab
202 |
203 | def get_vocab(self) -> Dict[bytes, int]:
204 | return self.mergeable_ranks
205 |
206 | def convert_tokens_to_ids(
207 | self, tokens: Union[bytes, str, List[Union[bytes, str]]]
208 | ) -> List[int]:
209 | ids = []
210 | if isinstance(tokens, (str, bytes)):
211 | if tokens in self.special_tokens:
212 | return self.special_tokens[tokens]
213 | else:
214 | return self.mergeable_ranks.get(tokens)
215 | for token in tokens:
216 | if token in self.special_tokens:
217 | ids.append(self.special_tokens[token])
218 | else:
219 | ids.append(self.mergeable_ranks.get(token))
220 | return ids
221 |
222 | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
223 | if not special_tokens and new_tokens:
224 | raise ValueError('Adding regular tokens is not supported')
225 | for token in new_tokens:
226 | surface_form = token.content if isinstance(token, AddedToken) else token
227 | if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST:
228 | raise ValueError('Adding unknown special tokens is not supported')
229 | return 0
230 |
231 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
232 | """
233 | Save only the vocabulary of the tokenizer (the tiktoken BPE ranks).
234 |
235 | Returns:
236 | `Tuple(str)`: Paths to the files saved.
237 | """
238 | file_path = os.path.join(save_directory, "qwen.tiktoken")
239 | with open(file_path, "w", encoding="utf8") as w:
240 | for k, v in self.mergeable_ranks.items():
241 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
242 | w.write(line)
243 | return (file_path,)
244 |
245 | def tokenize(
246 | self,
247 | text: str,
248 | allowed_special: Union[Set, str] = "all",
249 | disallowed_special: Union[Collection, str] = (),
250 | **kwargs,
251 | ) -> List[Union[bytes, str]]:
252 | """
253 | Converts a string into a sequence of tokens.
254 |
255 | Args:
256 | text (`str`):
257 | The sequence to be encoded.
258 | allowed_special (`Literal["all"]` or `set`):
259 | The surface forms of the tokens to be encoded as special tokens in regular texts.
260 | Defaults to "all".
261 | disallowed_special (`Literal["all"]` or `Collection`):
262 | The surface forms of the tokens that should not be in regular texts and trigger errors.
263 | Defaults to an empty tuple.
264 |
265 | kwargs (additional keyword arguments, *optional*):
266 | Will be passed to the underlying model specific encode method.
267 |
268 | Returns:
269 | `List[bytes|str]`: The list of tokens.
270 | """
271 | tokens = []
272 | text = unicodedata.normalize("NFC", text)
273 |
274 | # this implementation takes a detour: text -> token id -> token surface forms
275 | for t in self.tokenizer.encode(
276 | text, allowed_special=allowed_special, disallowed_special=disallowed_special
277 | ):
278 | tokens.append(self.decoder[t])
279 |
280 | def _encode_imgurl(img_tokens):
281 | assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
282 | img_tokens = img_tokens[1:-1]
283 | img_url = b''.join(img_tokens)
284 | out_img_tokens = list(map(self.decoder.get, img_url))
285 | if len(out_img_tokens) > IMG_TOKEN_SPAN:
286 | raise ValueError("The content in {}..{} is too long".format(
287 | self.image_start_tag, self.image_end_tag))
288 | out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
289 | out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
290 | return out_img_tokens
291 |
292 | return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
293 |
294 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
295 | """
296 | Converts a sequence of tokens into a single string.
297 | """
298 | text = ""
299 | temp = b""
300 | for t in tokens:
301 | if isinstance(t, str):
302 | if temp:
303 | text += temp.decode("utf-8", errors=self.errors)
304 | temp = b""
305 | text += t
306 | elif isinstance(t, bytes):
307 | temp += t
308 | else:
309 | raise TypeError("token should only be of type bytes or str")
310 | if temp:
311 | text += temp.decode("utf-8", errors=self.errors)
312 | return text
313 |
314 | @property
315 | def vocab_size(self):
316 | return self.tokenizer.n_vocab
317 |
318 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
319 | """Converts an id to a token, special tokens included"""
320 | if index in self.decoder:
321 | return self.decoder[index]
322 | raise ValueError("unknown ids")
323 |
324 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
325 | """Converts a token to an id using the vocab, special tokens included"""
326 | if token in self.special_tokens:
327 | return self.special_tokens[token]
328 | if token in self.mergeable_ranks:
329 | return self.mergeable_ranks[token]
330 | raise ValueError("unknown token")
331 |
332 | def _tokenize(self, text: str, **kwargs):
333 | """
334 | Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
335 | vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
336 |
337 | Do NOT take care of added tokens.
338 | """
339 | raise NotImplementedError
340 |
341 | def _decode(
342 | self,
343 | token_ids: Union[int, List[int]],
344 | skip_special_tokens: bool = False,
345 | errors: str = None,
346 | **kwargs,
347 | ) -> str:
348 | if isinstance(token_ids, int):
349 | token_ids = [token_ids]
350 |
351 | def _decode_imgurl(img_token_ids):
352 | assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
353 | img_token_ids = img_token_ids[1:-1]
354 | img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)]
355 | img_url = bytes(img_token_ids).decode('utf-8')
356 | return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]
357 |
358 | token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
359 |
360 | if skip_special_tokens:
361 | if kwargs.get('keep_image_special', False):
362 | token_ids = [i for i in token_ids if i < self.eod_id
363 | or i in self.image_special_tokens]
364 | else:
365 | token_ids = [i for i in token_ids if i < self.eod_id]
366 | return self.tokenizer.decode(token_ids, errors=errors or self.errors)
367 |
368 | def to_list_format(self, text: str):
369 | text = unicodedata.normalize("NFC", text)
370 | token_ids = self.tokenizer.encode(
371 | text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))
372 |
373 | def _encode_vl_info(tokens):
374 | if len(tokens) == 0:
375 | return []
376 | if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
377 | key = 'image'
378 | elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
379 | key = 'ref'
380 | elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
381 | key = 'box'
382 | elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
383 | key = 'quad'
384 | else:
385 | _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
386 | return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
387 | _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
388 | val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8')
389 | return [{key: val}]
390 |
391 | return _replace_closed_tag(
392 | token_ids,
393 | (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
394 | (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
395 | _encode_vl_info,
396 | _encode_vl_info,
397 | )
398 |
399 | def from_list_format(self, list_format: List[Dict]):
400 | text = ''
401 | num_images = 0
402 | for ele in list_format:
403 | if 'image' in ele:
404 | num_images += 1
405 | text += f'Picture {num_images}: '
406 | text += self.image_start_tag + ele['image'] + self.image_end_tag
407 | text += '\n'
408 | elif 'text' in ele:
409 | text += ele['text']
410 | elif 'box' in ele:
411 | if 'ref' in ele:
412 | text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
413 | for box in ele['box']:
414 | text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
415 | else:
416 | raise ValueError("Unsupported element: " + str(ele))
417 | return text
418 |
419 | def _fetch_latest_picture(self, response, history):
420 | if history is None:
421 | history = []
422 | _history = history + [(response, None)]
423 | for q, r in _history[::-1]:
424 | for ele in self.to_list_format(q)[::-1]:
425 | if 'image' in ele:
426 | return ele['image']
427 | return None
428 |
429 | def _fetch_all_box_with_ref(self, text):
430 | list_format = self.to_list_format(text)
431 | output = []
432 | for i, ele in enumerate(list_format):
433 | if 'box' in ele:
434 | bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
435 | assert len(bbox) == 4
436 | output.append({'box': bbox})
437 | if i > 0 and 'ref' in list_format[i-1]:
438 | output[-1]['ref'] = list_format[i-1]['ref'].strip()
439 | return output
440 |
441 | def draw_bbox_on_latest_picture(
442 | self,
443 | response,
444 | history=None,
445 | ) -> Optional[Image.Image]:
446 | image = self._fetch_latest_picture(response, history)
447 | if image is None:
448 | return None
449 | if image.startswith("http://") or image.startswith("https://"):
450 | image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
451 | h, w = image.height, image.width
452 | else:
453 | image = np.asarray(Image.open(image).convert("RGB"))
454 | h, w = image.shape[0], image.shape[1]
455 | visualizer = Visualizer(image)
456 |
457 | boxes = self._fetch_all_box_with_ref(response)
458 | if not boxes:
459 | return None
460 | color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
461 | for box in boxes:
462 | if 'ref' in box: # random new color for new refexps
463 | color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
464 | x1, y1, x2, y2 = box['box']
465 | x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
466 | visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
467 | if 'ref' in box:
468 | visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
469 | return visualizer.output
470 |
471 |
472 | import colorsys
473 | import logging
474 | import math
475 | import numpy as np
476 | import matplotlib as mpl
477 | import matplotlib.colors as mplc
478 | import matplotlib.figure as mplfigure
479 | import torch
480 | from matplotlib.backends.backend_agg import FigureCanvasAgg
481 | from PIL import Image
482 | import random
483 |
484 | logger = logging.getLogger(__name__)
485 |
486 |
487 | class VisImage:
488 | def __init__(self, img, scale=1.0):
489 | self.img = img
490 | self.scale = scale
491 | self.width, self.height = img.shape[1], img.shape[0]
492 | self._setup_figure(img)
493 |
494 | def _setup_figure(self, img):
495 | fig = mplfigure.Figure(frameon=False)
496 | self.dpi = fig.get_dpi()
497 | # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
498 | # (https://github.com/matplotlib/matplotlib/issues/15363)
499 | fig.set_size_inches(
500 | (self.width * self.scale + 1e-2) / self.dpi,
501 | (self.height * self.scale + 1e-2) / self.dpi,
502 | )
503 | self.canvas = FigureCanvasAgg(fig)
504 | # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
505 | ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
506 | ax.axis("off")
507 | self.fig = fig
508 | self.ax = ax
509 | self.reset_image(img)
510 |
511 | def reset_image(self, img):
512 | img = img.astype("uint8")
513 | self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
514 |
515 | def save(self, filepath):
516 | self.fig.savefig(filepath)
517 |
518 | def get_image(self):
519 | canvas = self.canvas
520 | s, (width, height) = canvas.print_to_buffer()
521 |
522 | buffer = np.frombuffer(s, dtype="uint8")
523 |
524 | img_rgba = buffer.reshape(height, width, 4)
525 | rgb, alpha = np.split(img_rgba, [3], axis=2)
526 | return rgb.astype("uint8")
527 |
528 |
529 | class Visualizer:
530 | def __init__(self, img_rgb, metadata=None, scale=1.0):
531 | self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
532 | self.font_path = FONT_PATH
533 | self.output = VisImage(self.img, scale=scale)
534 | self.cpu_device = torch.device("cpu")
535 |
536 | # too-small text is unreadable, so clamp the default font size to a minimum
537 | self._default_font_size = max(
538 | np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
539 | )
540 |
541 | def draw_text(
542 | self,
543 | text,
544 | position,
545 | *,
546 | font_size=None,
547 | color="g",
548 | horizontal_alignment="center",
549 | rotation=0,
550 | ):
551 | if not font_size:
552 | font_size = self._default_font_size
553 |
554 | # since the text background is dark, we don't want the text to be dark
555 | color = np.maximum(list(mplc.to_rgb(color)), 0.2)
556 | color[np.argmax(color)] = max(0.8, np.max(color))
557 |
558 | x, y = position
559 | self.output.ax.text(
560 | x,
561 | y,
562 | text,
563 | size=font_size * self.output.scale,
564 | fontproperties=FontProperties(fname=self.font_path),
565 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
566 | verticalalignment="top",
567 | horizontalalignment=horizontal_alignment,
568 | color=color,
569 | zorder=10,
570 | rotation=rotation,
571 | )
572 | return self.output
573 |
574 | def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
575 |
576 | x0, y0, x1, y1 = box_coord
577 | width = x1 - x0
578 | height = y1 - y0
579 |
580 | linewidth = max(self._default_font_size / 4, 1)
581 |
582 | self.output.ax.add_patch(
583 | mpl.patches.Rectangle(
584 | (x0, y0),
585 | width,
586 | height,
587 | fill=False,
588 | edgecolor=edge_color,
589 | linewidth=linewidth * self.output.scale,
590 | alpha=alpha,
591 | linestyle=line_style,
592 | )
593 | )
594 | return self.output
595 |
596 | def get_output(self):
597 |
598 | return self.output
599 |
--------------------------------------------------------------------------------
/OdysseyAgent/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "auto_map": {
3 | "AutoTokenizer": [
4 | "Qwen/Qwen-VL-Chat--tokenization_qwen.QWenTokenizer",
5 | null
6 | ]
7 | },
8 | "clean_up_tokenization_spaces": true,
9 | "model_max_length": 8192,
10 | "tokenizer_class": "QWenTokenizer"
11 | }
12 |
--------------------------------------------------------------------------------
/OdysseyAgent/visual.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from collections import OrderedDict
7 | import math
8 | import requests
9 | from io import BytesIO
10 | from functools import partial
11 | from PIL import Image
12 | from typing import Callable, Optional, Sequence, Tuple, List
13 | import numpy as np
14 |
15 | import torch
16 | from torch import nn
17 | from torch.nn import functional as F
18 | from torch.nn.init import trunc_normal_
19 | from torchvision import transforms
20 | from torchvision.transforms import InterpolationMode
21 |
22 |
23 | def get_abs_pos(abs_pos, tgt_size):
24 | # abs_pos: L, C
25 | # tgt_size: M
26 | # return: M, C
27 | src_size = int(math.sqrt(abs_pos.size(0)))
28 | tgt_size = int(math.sqrt(tgt_size))
29 | dtype = abs_pos.dtype
30 |
31 | if src_size != tgt_size:
32 | return F.interpolate(
33 | abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
34 | size=(tgt_size, tgt_size),
35 | mode="bicubic",
36 | align_corners=False,
37 | ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
38 | else:
39 | return abs_pos
40 |
41 | # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
42 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
43 | """
44 | grid_size: int of the grid height and width
45 | return:
46 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
47 | """
48 | grid_h = np.arange(grid_size, dtype=np.float32)
49 | grid_w = np.arange(grid_size, dtype=np.float32)
50 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
51 | grid = np.stack(grid, axis=0)
52 |
53 | grid = grid.reshape([2, 1, grid_size, grid_size])
54 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
55 | if cls_token:
56 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
57 | return pos_embed
58 |
59 |
60 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
61 | assert embed_dim % 2 == 0
62 |
63 | # use half of dimensions to encode grid_h
64 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
65 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
66 |
67 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
68 | return emb
69 |
70 |
71 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
72 | """
73 | embed_dim: output dimension for each position
74 | pos: a list of positions to be encoded: size (M,)
75 | out: (M, D)
76 | """
77 | assert embed_dim % 2 == 0
78 | omega = np.arange(embed_dim // 2, dtype=np.float32)
79 | omega /= embed_dim / 2.
80 | omega = 1. / 10000**omega # (D/2,)
81 |
82 | pos = pos.reshape(-1) # (M,)
83 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
84 |
85 | emb_sin = np.sin(out) # (M, D/2)
86 | emb_cos = np.cos(out) # (M, D/2)
87 |
88 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
89 | return emb
90 |
91 |
92 | class Resampler(nn.Module):
93 | """
94 | A 2D perceiver-resampler network with a single cross-attention layer that pools
95 | the input with (grid_size**2) learnable queries and a 2D sincos pos_emb
96 | Outputs:
97 | A tensor with the shape of (grid_size**2, embed_dim)
98 | """
99 | def __init__(
100 | self,
101 | grid_size,
102 | embed_dim,
103 | num_heads,
104 | kv_dim=None,
105 | norm_layer=nn.LayerNorm
106 | ):
107 | super().__init__()
108 | self.num_queries = grid_size ** 2
109 | self.embed_dim = embed_dim
110 | self.num_heads = num_heads
111 |
112 | self.pos_embed = nn.Parameter(
113 | torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
114 | ).requires_grad_(False)
115 |
116 | self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
117 | trunc_normal_(self.query, std=.02)
118 |
119 | if kv_dim is not None and kv_dim != embed_dim:
120 | self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
121 | else:
122 | self.kv_proj = nn.Identity()
123 |
124 | self.attn = nn.MultiheadAttention(embed_dim, num_heads)
125 | self.ln_q = norm_layer(embed_dim)
126 | self.ln_kv = norm_layer(embed_dim)
127 |
128 | # self.apply(self._init_weights)
129 |
130 | def _init_weights(self, m):
131 | if isinstance(m, nn.Linear):
132 | trunc_normal_(m.weight, std=.02)
133 | if isinstance(m, nn.Linear) and m.bias is not None:
134 | nn.init.constant_(m.bias, 0)
135 | elif isinstance(m, nn.LayerNorm):
136 | nn.init.constant_(m.bias, 0)
137 | nn.init.constant_(m.weight, 1.0)
138 |
139 | def forward(self, x, attn_mask=None):
140 |
141 | pos_embed = get_abs_pos(self.pos_embed, x.size(1))
142 |
143 | x = self.kv_proj(x)
144 | x = self.ln_kv(x).permute(1, 0, 2)
145 |
146 | N = x.shape[1]
147 | q = self.ln_q(self.query)
148 | out = self.attn(
149 | self._repeat(q, N) + self.pos_embed.unsqueeze(1),
150 | x + pos_embed.unsqueeze(1),
151 | x,
152 | attn_mask=attn_mask)[0]
153 | return out.permute(1, 0, 2)
154 |
155 | def _repeat(self, query, N: int):
156 | return query.unsqueeze(1).repeat(1, N, 1)
157 |
158 |
159 | class VisualAttention(nn.Module):
160 | """self-attention layer class.
161 |
162 | Self-attention layer takes input with size [s, b, h]
163 | and returns output of the same size.
164 | """
165 |
166 | def __init__(self, embed_dim, num_heads,
167 | bias=True, kdim=None, vdim=None):
168 | super(VisualAttention, self).__init__()
169 | self.embed_dim = embed_dim
170 | self.kdim = kdim if kdim is not None else embed_dim
171 | self.vdim = vdim if vdim is not None else embed_dim
172 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
173 |
174 | self.num_heads = num_heads
175 |
176 | # Per attention head and per partition values.
177 | assert embed_dim % num_heads == 0
178 | self.hidden_size_per_attention_head = embed_dim // num_heads
179 | self.num_attention_heads_per_partition = num_heads
180 | self.hidden_size_per_partition = embed_dim
181 |
182 | # Strided linear layer.
183 | assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
184 | self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
185 | self.out_proj = nn.Linear(embed_dim, embed_dim)
186 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
187 |
188 | def forward(self, query, key, value, attn_mask = None):
189 | # query/key/value: [sq, b, h]
190 | sq, b, _ = query.size()
191 |
192 | assert torch.allclose(query, key), 'Only Support Self-Attention Currently'
193 | sk = sq
194 | mixed_x_layer = self.in_proj(query)
195 |
196 | # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
197 | new_tensor_shape = mixed_x_layer.size()[:-1] + \
198 | (self.num_attention_heads_per_partition,
199 | 3 * self.hidden_size_per_attention_head)
200 | mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
201 |
202 | # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
203 | query_layer, key_layer, value_layer = mixed_x_layer.split(
204 | self.hidden_size_per_attention_head, dim=-1)
205 |
206 | # [sq, b, np, hn] -> [sq, b * np, hn]
207 | query_layer = query_layer.view(sq,
208 | b * self.num_attention_heads_per_partition,
209 | self.hidden_size_per_attention_head).transpose(0, 1)
210 | # [sk, b, np, hn] -> [sk, b * np, hn]
211 | key_layer = key_layer.view(sk,
212 | b * self.num_attention_heads_per_partition,
213 | self.hidden_size_per_attention_head).transpose(0, 1)
214 |
215 | q_scaled = query_layer / self.norm_factor
216 | if attn_mask is not None:
217 | attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
218 | else:
219 | attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
220 | attention_probs = attention_probs.softmax(dim=-1)
221 |
222 | value_layer = value_layer.view(sk,
223 | b * self.num_attention_heads_per_partition,
224 | self.hidden_size_per_attention_head).transpose(0, 1)
225 |
226 | # matmul: [b * np, sq, hn]
227 | context_layer = torch.bmm(attention_probs, value_layer)
228 |
229 | # change view [b, np, sq, hn]
230 | context_layer = context_layer.view(b,
231 | self.num_attention_heads_per_partition,
232 | sq, self.hidden_size_per_attention_head)
233 |
234 | # [b, np, sq, hn] --> [sq, b, np, hn]
235 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
236 |
237 | # [sq, b, np, hn] --> [sq, b, hp]
238 | new_context_layer_shape = context_layer.size()[:-2] + \
239 | (self.hidden_size_per_partition,)
240 | context_layer = context_layer.view(*new_context_layer_shape)
241 |
242 | output = self.out_proj(context_layer)
243 |
244 | return output
245 |
246 |
247 | class VisualAttentionBlock(nn.Module):
248 | def __init__(
249 | self,
250 | d_model: int,
251 | n_head: int,
252 | mlp_ratio: float = 4.0,
253 | act_layer: Callable = nn.GELU,
254 | norm_layer: Callable = nn.LayerNorm,
255 | is_cross_attention: bool = False,
256 | ):
257 | super().__init__()
258 |
259 | self.ln_1 = norm_layer(d_model)
260 | if is_cross_attention:
261 | self.ln_1_kv = norm_layer(d_model)
262 |
263 | self.ln_2 = norm_layer(d_model)
264 | mlp_width = int(d_model * mlp_ratio)
265 | self.attn = VisualAttention(d_model, n_head)
266 | self.mlp = nn.Sequential(OrderedDict([
267 | ("c_fc", nn.Linear(d_model, mlp_width)),
268 | ("gelu", act_layer()),
269 | ("c_proj", nn.Linear(mlp_width, d_model))
270 | ]))
271 |
272 | def attention(
273 | self,
274 | q_x: torch.Tensor,
275 | k_x: Optional[torch.Tensor] = None,
276 | v_x: Optional[torch.Tensor] = None,
277 | attn_mask: Optional[torch.Tensor] = None,
278 | ):
279 | k_x = k_x if k_x is not None else q_x
280 | v_x = v_x if v_x is not None else q_x
281 |
282 | attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
283 | return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)
284 |
285 | def forward(
286 | self,
287 | q_x: torch.Tensor,
288 | k_x: Optional[torch.Tensor] = None,
289 | v_x: Optional[torch.Tensor] = None,
290 | attn_mask: Optional[torch.Tensor] = None,
291 | ):
292 | k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
293 | v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
294 |
295 | x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
296 | x = x + self.mlp(self.ln_2(x))
297 | return x
298 |
299 |
300 | class TransformerBlock(nn.Module):
301 | def __init__(
302 | self,
303 | width: int,
304 | layers: int,
305 | heads: int,
306 | mlp_ratio: float = 4.0,
307 | act_layer: Callable = nn.GELU,
308 | norm_layer: Callable = nn.LayerNorm,
309 | ):
310 | super().__init__()
311 | self.width = width
312 | self.layers = layers
313 |
314 | self.resblocks = nn.ModuleList([
315 | VisualAttentionBlock(
316 | width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
317 | for _ in range(layers)
318 | ])
319 |
320 | def get_cast_dtype(self) -> torch.dtype:
321 | return self.resblocks[0].mlp.c_fc.weight.dtype
322 |
323 | def get_cast_device(self) -> torch.device:
324 | return self.resblocks[0].mlp.c_fc.weight.device
325 |
326 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
327 | for r in self.resblocks:
328 | x = r(x, attn_mask=attn_mask)
329 | return x
330 |
331 |
332 | class VisionTransformer(nn.Module):
333 |
334 | def __init__(
335 | self,
336 | image_size: int,
337 | patch_size: int,
338 | width: int,
339 | layers: int,
340 | heads: int,
341 | mlp_ratio: float,
342 | n_queries: int = 256,
343 | output_dim: int = 512,
344 | **kwargs
345 | ):
346 | super().__init__()
347 | image_height, image_width = self.image_size = (image_size, image_size)
348 | patch_height, patch_width = self.patch_size = (patch_size, patch_size)
349 | self.grid_size = (image_height // patch_height, image_width // patch_width)
350 | self.output_dim = output_dim
351 |
352 | mean = (0.48145466, 0.4578275, 0.40821073)
353 | std = (0.26862954, 0.26130258, 0.27577711)
354 | self.image_transform = transforms.Compose([
355 | transforms.Resize(
356 | (image_size, image_size),
357 | interpolation=InterpolationMode.BICUBIC
358 | ),
359 | transforms.ToTensor(),
360 | transforms.Normalize(mean=mean, std=std),
361 | ])
362 |
363 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
364 |
365 | # class embeddings and positional embeddings
366 | scale = width ** -0.5
367 | self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
368 |
369 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
370 | act_layer = nn.GELU
371 |
372 | self.ln_pre = norm_layer(width)
373 | self.transformer = TransformerBlock(
374 | width,
375 | layers,
376 | heads,
377 | mlp_ratio,
378 | act_layer=act_layer,
379 | norm_layer=norm_layer,
380 | )
381 |
382 | self.attn_pool = Resampler(
383 | grid_size=int(math.sqrt(n_queries)),
384 | embed_dim=output_dim,
385 | num_heads=output_dim // 128,
386 | kv_dim=width,
387 | norm_layer=norm_layer,
388 | )
389 | self.ln_post = norm_layer(output_dim)
390 | self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
391 |
392 | def forward(self, x: torch.Tensor):
393 | x = x.to(
394 | dtype=self.transformer.get_cast_dtype(),
395 | device=self.transformer.get_cast_device(),
396 | )
397 | # to patches
398 | x = self.conv1(x) # shape = [*, width, grid, grid]
399 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
400 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
401 |
402 | x = x + get_abs_pos(self.positional_embedding, x.size(1))
403 |
404 | x = self.ln_pre(x)
405 |
406 | x = x.permute(1, 0, 2) # NLD -> LND
407 | x = self.transformer(x)
408 | x = x.permute(1, 0, 2) # LND -> NLD
409 |
410 | x = self.attn_pool(x)
411 | x = self.ln_post(x)
412 | x = x @ self.proj
413 |
414 | return x
415 |
416 | def encode(self, image_paths: List[str]):
417 | images = []
418 | for image_path in image_paths:
419 | if image_path.startswith("http://") or image_path.startswith("https://"):
420 | image = Image.open(requests.get(image_path, stream=True).raw)
421 | else:
422 | image = Image.open(image_path)
423 | image = image.convert("RGB")
424 | images.append(self.image_transform(image))
425 | images = torch.stack(images, dim=0)
426 | return self(images)
427 |
--------------------------------------------------------------------------------
/Quickstart.md:
--------------------------------------------------------------------------------
1 | # 🚀 Quick Start
2 |
3 | ## Data preprocessing
4 |
5 | Please follow the **Dataset Access** section of the [README.md](README.md) to prepare the data, and run the `preprocessing.py` script as instructed. Ensure that the structure of the `./data` directory is as shown below:
6 |
7 | ```
8 | GUI-Odyssey
9 | ├── data
10 | │ ├── annotations
11 | │ │ └── *.json
12 | │ ├── screenshots
13 | │ │ └── *.png
14 | │ ├── splits
15 | │ │ ├── app_split.json
16 | │ │ ├── device_split.json
17 | │ │ ├── random_split.json
18 | │ │ └── task_split.json
19 | │ ├── format_converter.py
20 | │ └── preprocessing.py
21 | └── ...
22 | ```
23 |
24 | Next, run the following command to generate chat-format data for training and testing. The `his_len` argument specifies how many previous steps (actions and screenshots) are kept as history:
25 |
26 | ```shell
27 | cd data
28 | python format_converter.py --his_len 4
29 | ```
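
The converter writes chat-format JSON files under `data/train_anno/` and `data/test_anno/` (one per converted split). As a rough sketch, a training record produced by `build_train_chat` in `format_converter.py` looks like the example below; the paths and instruction are illustrative only, and later steps of an episode additionally append the previous screenshots and a numbered list of previous actions to the user message:

```python
# Minimal sketch of one record in data/train_anno/random_split.json (values are illustrative).
example_record = {
    "id": "GUIOdyssey_random_split_0",
    "image": "/abs/path/to/data/screenshots/episode_x_step_0.png",
    "conversations": [
        {
            "from": "user",
            "value": "Picture 1: <img>/abs/path/to/data/screenshots/episode_x_step_0.png</img>\n"
                     "I'm looking for guidance on how to <task instruction>",
        },
        # ground-truth action string, e.g. CLICK / SCROLL / TYPE / PRESS_* / COMPLETE
        {"from": "assistant", "value": "CLICK: (540, 1200)"},
    ],
    "history": "[]",  # str() of the list of previous screenshot paths
}
```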
30 |
31 | ## Build OdysseyAgent upon Qwen-VL-Chat
32 |
33 | OdysseyAgent is built upon [Qwen-VL](https://github.com/QwenLM/Qwen-VL).
34 |
35 | Before running, set up the environment and install the required packages:
36 |
37 | ```shell
38 | cd src
39 | pip install -r requirements.txt
40 | ```
41 |
42 | Next, initialize `OdysseyAgent` using the weights from `Qwen-VL-Chat`:
43 |
44 | ```shell
45 | python merge_weight.py
46 | ```
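
To sanity-check the merged weights before fine-tuning, a minimal loading sketch is shown below. It assumes the merged checkpoint was saved to `./OdysseyAgent` (adjust to wherever `merge_weight.py` writes it) and that the model keeps Qwen-VL-Chat's `chat()` interface; the screenshot path and instruction are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "./OdysseyAgent"  # assumed output path of merge_weight.py

tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, device_map="auto", trust_remote_code=True, bf16=True
).eval()

# Build a Qwen-VL-style query from an image plus a text instruction.
query = tokenizer.from_list_format([
    {"image": "path/to/a/screenshot.png"},                      # placeholder
    {"text": "I'm looking for guidance on how to <task> ..."},  # placeholder
])
response, _ = model.chat(tokenizer, query=query, history=None)
print(response)
```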
47 |
48 | We also provide four pre-trained variants of OdysseyAgent:
49 | - [OdysseyAgent-Random](https://huggingface.co/hflqf88888/OdysseyAgent-random)
50 | - [OdysseyAgent-Task](https://huggingface.co/hflqf88888/OdysseyAgent-task)
51 | - [OdysseyAgent-Device](https://huggingface.co/hflqf88888/OdysseyAgent-device)
52 | - [OdysseyAgent-App](https://huggingface.co/hflqf88888/OdysseyAgent-app)
53 |
54 | Each is fine-tuned on `Train-Random`, `Train-Task`, `Train-Device`, and `Train-App`, respectively.
55 |
56 | ### Fine-tuning
57 |
58 | Specify the path to the `OdysseyAgent` and the chat-format training data generated in the `Data preprocessing` stage (one of the four splits) in the `script/train.sh` file. Then, run the following command:
59 |
60 | ```shell
61 | cd src
62 | bash script/train.sh
63 | ```
64 |
65 | ### Evaluation
66 |
67 | Specify the path to the checkpoint and dataset split (one of `app_split`, `device_split`, `random_split`, `task_split`) in the `script/eval.sh` file. Then, run the following command:
68 |
69 | ```shell
70 | cd src
71 | bash script/eval.sh
72 | ```
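
The predicted actions are matched against the ground-truth action strings produced by `data/format_converter.py` (see `decode_action`). For reference, the action space looks like this; the coordinates and text are illustrative:

```python
# Ground-truth action strings emitted by decode_action() in data/format_converter.py.
ACTION_SPACE = [
    "CLICK: (540, 1200)",       # tap at an (x, y) point
    "LONG_PRESS: (540, 1200)",  # long press at an (x, y) point
    "PRESS_HOME",               # CLICK / LONG_PRESS on KEY_HOME
    "PRESS_BACK",               # CLICK / LONG_PRESS on KEY_BACK
    "PRESS_RECENT",             # CLICK / LONG_PRESS on KEY_APPSELECT
    "SCROLL: UP",               # or DOWN / LEFT / RIGHT, from the dominant swipe axis
    "TYPE: some text",          # text input
    "COMPLETE",                 # episode finished successfully
    "IMPOSSIBLE",               # episode marked INCOMPLETE in the raw annotation
]
```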
73 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GUI Odyssey
2 |
3 | **This repository is the official implementation of GUI Odyssey.**
4 |
5 | > [GUI Odyssey: A Comprehensive Dataset for Cross-App GUI Navigation on Mobile Devices](https://arxiv.org/abs/2406.08451)
6 | > Quanfeng Lu, Wenqi Shao✉️⭐️, Zitao Liu, Fanqing Meng, Boxuan Li, Botong Chen, Siyuan Huang, Kaipeng Zhang, Yu Qiao, Ping Luo✉️
7 | > ✉️ Wenqi Shao (shaowenqi@pjlab.org.cn) and Ping Luo (pluo@cs.hku.hk) are the corresponding authors.
8 | > ⭐️ Wenqi Shao is the project leader.
9 |
10 |
11 | ## 💡 News
12 |
13 | - `2024/06/24`: The data of [GUI Odyssey](https://arxiv.org/pdf/2406.08451) is released! Please check out [OpenGVLab/GUI-Odyssey](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey)!
14 | - `2024/06/13`: The paper of [GUI Odyssey](https://arxiv.org/pdf/2406.08451) is released!
15 |
16 |
17 |
18 | ## 🔆 Introduction
19 | GUI Odyssey is a comprehensive dataset for training and evaluating **cross-app** navigation agents. GUI Odyssey consists of 7,735 episodes from 6 mobile devices, spanning 6 types of cross-app tasks, 201 apps, and 1.4K app combos.
20 | ![GUI Odyssey dataset overview](assets/dataset_overview.jpg)
21 |
22 |
23 | ## 🛠️ Data collection pipeline
24 | GUI Odyssey comprises six categories of navigation tasks. For each category, we construct instruction templates with items and apps selected from a predefined pool, resulting in a vast array of unique instructions for annotating GUI episodes. Human demonstrations on an Android emulator capture the metadata of each episode in a comprehensive format. After rigorous quality checks, GUI Odyssey includes 7,735 validated cross-app GUI navigation episodes.
25 | ![GUI Odyssey data collection pipeline](assets/pipeline.png)
26 |
27 |
28 | ## 📝 Statistics
29 |
30 |
31 |
32 | Splits | # Episodes | # Unique Prompts | # Avg. Steps | Data location | Model
33 | :---------: | :---------: | :-----------: | :--------------: | :-----------: | :-----------:
34 | **Total** | **7,735** | **7,735** | **15.4** | [GUI-Odyssey](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey) | [OdysseyAgent](https://huggingface.co/collections/hflqf88888/gui-odyssey-6683bac37ad6fe37b1215c18)
35 | Train-Random \& Test-Random | 5,802 / 1,933 | 5,802 / 1,933 | 15.4 / 15.2 | [random_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-Random](https://huggingface.co/hflqf88888/OdysseyAgent-random)
36 | Train-Task \& Test-Task | 6,719 / 1,016 | 6,719 / 1,016 | 15.0 / 17.6 | [task_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-Task](https://huggingface.co/hflqf88888/OdysseyAgent-task)
37 | Train-Device \& Test-Device | 6,473 / 1,262 | 6,473 / 1,262 | 15.4 / 15.0 | [device_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-Device](https://huggingface.co/hflqf88888/OdysseyAgent-device)
38 | Train-App \& Test-App | 6,596 / 1,139 | 6,596 / 1,139 | 15.4 / 15.3 | [app_split.json](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey/tree/main/splits) | [OdysseyAgent-App](https://huggingface.co/hflqf88888/OdysseyAgent-app)
39 |
40 |
41 |
42 | ## 💫 Dataset Access
43 |
44 | The whole GUI Odyssey is hosted on [Huggingface](https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey).
45 |
46 | Clone the entire dataset from Huggingface:
47 | ```shell
48 | git clone https://huggingface.co/datasets/OpenGVLab/GUI-Odyssey
49 | ```
50 | Then move the cloned dataset into the `./data` directory. After that, the structure of `./data` should look like this:
51 |
52 |
53 | ```
54 | GUI-Odyssey
55 | ├── data
56 | │ ├── annotations
57 | │ │ └── *.json
58 | │ ├── screenshots
59 | │ │ └── data_*
60 | │ │ └── *.png
61 | │ ├── splits
62 | │ │ ├── app_split.json
63 | │ │ ├── device_split.json
64 | │ │ ├── random_split.json
65 | │ │ └── task_split.json
66 | │ ├── format_converter.py
67 | │ └── preprocessing.py
68 | └── ...
69 | ```
70 |
71 | Then organize the screenshots folder:
72 |
73 | ```shell
74 | cd data
75 | python preprocessing.py
76 | ```
77 |
78 | Finally, the structure of `./data` should look like this:
79 |
80 | ```
81 | GUI-Odyssey
82 | ├── data
83 | │ ├── annotations
84 | │ │ └── *.json
85 | │ ├── screenshots
86 | │ │ └── *.png
87 | │ ├── splits
88 | │ │ ├── app_split.json
89 | │ │ ├── device_split.json
90 | │ │ ├── random_split.json
91 | │ │ └── task_split.json
92 | │ ├── format_converter.py
93 | │ └── preprocessing.py
94 | └── ...
95 | ```
96 |
97 |
98 | ## ⚙️ Detailed Data Information
99 | Please refer to [this document](introduction.md).
100 |
101 |
102 | ## 🚀 Quick Start
103 |
104 | Please refer to [this guide](Quickstart.md) to get started quickly.
105 |
106 | ## 📖 Release Process
107 |
108 | - [x] Dataset
109 | - [x] Screenshots of GUI Odyssey
110 | - [x] Annotations of GUI Odyssey
111 | - [x] Split files of GUI Odyssey
112 | - [x] Code
113 | - [x] Data preprocessing code
114 | - [x] Inference code
115 | - [x] Models
116 |
117 |
118 | ## 🖊️ Citation
119 | If you find GUI Odyssey useful for your project or research, please cite our paper using the following BibTeX entry. Thanks!
120 | ```bib
121 | @article{lu2024gui,
122 | title={GUI Odyssey: A Comprehensive Dataset for Cross-App GUI Navigation on Mobile Devices},
123 | author={Lu, Quanfeng and Shao, Wenqi and Liu, Zitao and Meng, Fanqing and Li, Boxuan and Chen, Botong and Huang, Siyuan and Zhang, Kaipeng and Qiao, Yu and Luo, Ping},
124 | journal={arXiv preprint arXiv:2406.08451},
125 | year={2024}
126 | }
127 | ```
128 |
129 |
132 |
--------------------------------------------------------------------------------
/assets/dataset_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/assets/dataset_overview.jpg
--------------------------------------------------------------------------------
/assets/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/assets/pipeline.jpg
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/assets/pipeline.png
--------------------------------------------------------------------------------
/data/format_converter.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import argparse
3 | import numpy as np
4 |
5 | ###
6 | current_path = os.path.abspath(__file__)
7 | DATA_DIR = os.path.dirname(current_path)
8 | pic_base = os.path.join(DATA_DIR, 'screenshots')
9 | anno_base = os.path.join(DATA_DIR, 'annotations')
10 | train_anno_base = os.path.join(DATA_DIR, 'train_anno')
11 | test_anno_base = os.path.join(DATA_DIR, 'test_anno')
12 | split_base = os.path.join(DATA_DIR, 'splits')
13 |
14 | PROMPT = "I'm looking for guidance on how to "
15 |
16 | ###
17 |
18 | def decode_action(action, info):
19 | if action == 'CLICK' or action == "LONG_PRESS":
20 | if info == 'KEY_HOME':
21 | gt = 'PRESS_HOME'
22 | elif info == 'KEY_BACK':
23 | gt = 'PRESS_BACK'
24 | elif info == 'KEY_APPSELECT':
25 | gt = 'PRESS_RECENT'
26 | elif type(info) == list:
27 | gt = f"{action}: {tuple(info[0])}"
28 | else:
29 | raise ValueError(f'Unknown click action {info}')
30 |
31 | elif action == 'SCROLL':
32 | start = np.array(info[0])
33 | end = np.array(info[1])
34 | delta = end - start
35 | delta_abs = np.abs(delta)
36 | lr = 'LEFT' if delta[0] < 0 else 'RIGHT'
37 | ud = 'UP' if delta[1] < 0 else 'DOWN'
38 | if delta_abs[0] > delta_abs[1]:
39 | gt = f"SCROLL: {lr}"
40 | else:
41 | gt = f"SCROLL: {ud}"
42 |
43 | elif action == 'TEXT':
44 | gt = f'TYPE: {info}'
45 | elif action == 'COMPLETE':
46 | gt = action
47 | elif action == 'INCOMPLETE':
48 | gt = 'IMPOSSIBLE'
49 | else:
50 | raise ValueError(f'Unknown action {action}')
51 | return gt
52 |
53 | def build_train_chat(split_fp='./splits/random_split.json', his_len=4):
54 | os.makedirs(train_anno_base, exist_ok=True)
55 | name = os.path.basename(split_fp).split('.')[0]
56 | train_split = json.load(open(split_fp))['train']
57 | res = []
58 | idx = 0
59 | for f in train_split:
60 | this_res = []
61 | fp = os.path.join(anno_base, f)
62 | data = json.load(open(fp))
63 | instruction = data['task_info']['instruction']
64 | steps = data['steps']
65 |
66 | history_screenshot, history_action = [], []
67 |
68 | for step in steps:
69 | image = step['screenshot']
70 | action = step['action']
71 | info = step['info']
72 |
73 | gt = decode_action(action, info)
74 | img_abs_path = os.path.join(pic_base, image)
75 | value = f"Picture 1:
{img_abs_path}\n{PROMPT}{instruction}"
76 |
77 | his_str = ''
78 | for hidx, act in enumerate(history_action[-his_len:]):
79 | his_str += f'{hidx + 1}. {act}\n'
80 |
81 | if len(history_action) > 0 and his_len > 0:
82 | value += f'\nPrevious screenshots: <img>image-history: {img_abs_path}</img>'
83 | value += f'\nPrevious Actions: {his_str}'
84 |
85 | conversations = [{"from": "user", "value": value}, {"from": "assistant", "value": gt}]
86 |
87 | this_res.append({
88 | 'id': f'GUIOdyssey_{name}_{idx}',
89 | 'image': img_abs_path,
90 | 'conversations': conversations,
91 | 'history': str(history_screenshot),
92 | })
93 | idx += 1
94 |
95 | history_screenshot.append(img_abs_path)
96 | history_action.append(gt)
97 |
98 | res.extend(this_res)
99 |
100 | json.dump(res, open(os.path.join(train_anno_base, os.path.basename(split_fp)), 'w'), indent=4, ensure_ascii=False)
101 |
102 |
103 | def build_test(split_fp='./splits/random_split.json', his_len=4):
104 | os.makedirs(test_anno_base, exist_ok=True)
105 | name = os.path.basename(split_fp).split('.')[0]
106 | test_split = json.load(open(split_fp))['test']
107 | res = []
108 | idx = 0
109 | for f in test_split:
110 | this_res = []
111 | fp = os.path.join(anno_base, f)
112 | data = json.load(open(fp))
113 | instruction = data['task_info']['instruction']
114 | steps = data['steps']
115 | category = data['task_info']['category']
116 | step_length = data['step_length']
117 |
118 | history_screenshot = []
119 | history_action = []
120 |
121 | for step in steps:
122 | image = step['screenshot']
123 | img_abs_path = os.path.join(pic_base, image)
124 | action = step['action']
125 | info = step['info']
126 | gt = decode_action(action, info)
127 |
128 | this_res.append({
129 | 'id': f'GUIOdyssey_{name}_{idx}',
130 | 'image': img_abs_path,
131 | 'question': instruction,
132 | 'answer': gt,
133 | 'category': category,
134 | 'step_length': step_length,
135 | 'history_action': str(history_action),
136 | 'history_screenshot': str(history_screenshot),
137 | })
138 | idx += 1
139 |
140 | history_screenshot.append(img_abs_path)
141 | history_action.append(gt)
142 |
143 | res.extend(this_res)
144 |
145 | json.dump(res, open(os.path.join(test_anno_base, os.path.basename(split_fp)), 'w'), indent=4, ensure_ascii=False)
146 |
147 |
148 | def make_his_idx(train_base=train_anno_base, test_base=test_anno_base):
149 | savep = './his_index.json'
150 | his_dict = {}
151 | for subsplit in os.listdir(train_base):
152 | subp = os.path.join(train_base, subsplit)
153 |
154 | data_all = json.load(open(subp))
155 | for data in data_all:
156 | img = data['image']
157 | history = eval(data['history'])
158 | if img in his_dict:
159 | assert his_dict[img] == history
160 | else:
161 | his_dict[img] = history
162 |
163 | for subsplit in os.listdir(test_base):
164 | subp = os.path.join(test_base, subsplit)
165 | data_all = json.load(open(subp))
166 | for data in data_all:
167 | img = data['image']
168 | history = eval(data['history_screenshot'])
169 | if img in his_dict:
170 | assert his_dict[img] == history
171 | else:
172 | his_dict[img] = history
173 |
174 | print(len(his_dict))
175 | json.dump(his_dict, open(savep, 'w'), indent=4, ensure_ascii=False)
176 |
177 |
178 | def main(his_len):
179 | for f in os.listdir(split_base):
180 | fp = os.path.join(split_base, f)
181 | build_train_chat(fp, his_len)
182 | build_test(fp, his_len)
183 |
184 | make_his_idx()
185 |
186 | if __name__ == "__main__":
187 | parser = argparse.ArgumentParser()
188 | parser.add_argument('--his_len', type=int, default=4)
189 | args = parser.parse_args()
190 | main(args.his_len)
--------------------------------------------------------------------------------
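For reference, the short sketch below (not part of the repository) shows how `decode_action` from `data/format_converter.py` above maps raw annotation actions to the agent's action vocabulary. It assumes the working directory is `./data` so that the module can be imported directly.

```python
# Minimal usage sketch for decode_action (assumes cwd is ./data).
from format_converter import decode_action

# CLICK with coordinates: info is a list of points; the first point is used.
print(decode_action('CLICK', [[320, 540]]))               # CLICK: (320, 540)

# CLICK on a special key is rewritten to a dedicated action.
print(decode_action('CLICK', 'KEY_HOME'))                 # PRESS_HOME

# SCROLL is reduced to a direction from its start/end points.
print(decode_action('SCROLL', [[500, 800], [500, 200]]))  # SCROLL: UP

# TEXT and INCOMPLETE are renamed to TYPE and IMPOSSIBLE.
print(decode_action('TEXT', 'New York Fashion Week'))     # TYPE: New York Fashion Week
print(decode_action('INCOMPLETE', ''))                    # IMPOSSIBLE
```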
/data/preprocessing.py:
--------------------------------------------------------------------------------
1 | import os, shutil, pathlib
2 |
3 | screenshot_bp = './screenshots'
4 | anno_bp = './annotations'
5 |
6 | def reindex_screenshot():
7 | parent_dir = pathlib.Path(screenshot_bp)
8 | for subdir in parent_dir.iterdir():
9 | if subdir.is_dir():
10 | for file_path in subdir.iterdir():
11 | if file_path.is_file():
12 | shutil.move(str(file_path), str(parent_dir / file_path.name))
13 | print(f'{str(subdir)} ok.')
14 | subdir.rmdir()
15 |
16 |
17 |
18 |
19 | if __name__ == '__main__':
20 | reindex_screenshot()
--------------------------------------------------------------------------------
/introduction.md:
--------------------------------------------------------------------------------
1 | # ⚙️ Data Structure
2 |
3 |
4 | ## Data Fields
5 |
6 | Each field of the annotation is as follows (a minimal loading sketch is given after the list):
7 |
8 | * `episode_id`(str): the unique identifier of this episode.
9 | * `device_info`(dict): the detailed information of the virtual device from which the episode was collected.
10 | * `product`(str): the product name of the emulator.
11 | * `release_version`(str): the Android API level of the emulator.
12 | * `sdk_version`(str): the version of the software development kit used for the emulator.
13 | * `h`(int): the height of the device screen.
14 | * `w`(int): the width of the device screen.
15 | * `device_name`(str): the name of the virtual device, one of **Pixel Fold**, **Pixel Tablet**, **Pixel 8 Pro**, **Pixel 7 Pro**, **Medium Phone**, **Small Phone**
16 | * `task_info`(dict): the detailed information of the task from which the episode was collected.
17 | * `category`(str): the category of this task, one of **Multi_Apps**, **Web_Shopping**, **General_Tool**, **Information_Management**, **Media_Entertainment**, **Social_Sharing**
18 | * `app`(list[str]): the Apps used for this task.
19 | * `meta_task`(str): the template for this task, e.g., "Search for the next {} and set a reminder."
20 | * `task`(str): the specific task created by filling in the meta-task, e.g., "Search for the next New York Fashion Week and set a reminder."
21 | * `instruction`(str): the detailed and rephrased version of the task, including specific tools or applications, e.g., "Utilize DuckDuckgo to find the dates for the next New York Fashion Week and then use TickTick to set a reminder for the event."
22 | * `step_length`(int): the total number of steps in this episode.
23 | * `steps`(list[dict]): each individual step of this episode. Including the following fields:
24 | * `step`(int): each step within the episode is identified by a zero-indexed step number, indicating its position in sequence within the episode. For example, if the *step* is 1, it corresponds to the second step of the episode.
25 | * `screenshot`(str): the current screenshot of this step
26 | * `action`(str): the corresponding action of this step, one of **CLICK**, **SCROLL**, **LONG_PRESS**, **TYPE**, **COMPLETE**, **IMPOSSIBLE**, **HOME**, **BACK**
27 | * `info`(Union[str, list[list]]): provides specific details required to perform the action specified in the *action* field. Note that all the coordinates are normalized to the range of [0, 1000].
28 | * if action is *CLICK*, info contains the coordinates(x, y) to click on or one of the special keys *KEY_HOME*, *KEY_BACK*, *KEY_RECENT*.
29 | * if action is *LONG_PRESS*, info contains the coordinates(x, y) for the long press.
30 | * if action is *SCROLL*, info contains the starting(x1, y1) and ending(x2, y2) coordinates of the scroll action.
31 | * if action is any other value, info is empty ("").
32 | * `ps`(str): provides additional details or context depending on the value of the action field.
33 | * if action is *COMPLETE* or *IMPOSSIBLE*: may contain any additional information from the annotator about why the task is complete or why it was impossible to complete.
34 | * if action is *SCROLL*: contains the complete trajectory of the scroll action.
35 |
36 |
37 |
38 | ## Data Splits
39 | We can evaluate the in- and out-of-domain performance of the agent by splitting GUI Odyssey in two ways:
40 | 
41 | * **random_split**: randomly splits the dataset into training and test sets with a ratio of $3:1$;
42 | 
43 | * the remaining three splits organize the data so that the training set covers a portion of the apps/tasks/devices and the test set covers the remaining apps/tasks/devices:
44 |
45 |
46 | * **task_split**: proportionally samples meta-tasks from six categories. The tasks in the test set differ significantly from those in the training set. This partitioning method allows for a robust assessment of an agent's generalization capabilities across diverse tasks.
47 |
48 | * **device_split**: selects episodes annotated on the *Fold Phone*, which differs significantly from other devices such as smartphones and tablets, as the test set.
49 |
50 | * **app_split**: splits based on the apps. The apps in the test set differ significantly from those in the training set.
51 |
52 | Each of the four splits mentioned above has a corresponding JSON file, and the fields in each JSON file are as follows (a small loading sketch follows the list):
53 | * `train`(list[str]): the list of annotation filenames for the training set; each filename corresponds to an *episode_id*.
54 | * `test`(list[str]): the list of annotation filenames for the test set; each filename corresponds to an *episode_id*.
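A minimal sketch of how a split file can be combined with the annotations; `random_split.json` is just one of the four splits listed above, and the paths assume the `./data` layout from the README.

```python
import json, os

split = json.load(open(os.path.join('data', 'splits', 'random_split.json')))
print(len(split['train']), 'training episodes /', len(split['test']), 'test episodes')

# Each entry is an annotation filename under data/annotations/.
first = split['train'][0]
episode = json.load(open(os.path.join('data', 'annotations', first)))
print(first, '->', episode['task_info']['task'])
```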
--------------------------------------------------------------------------------
/src/SimSun.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/GUI-Odyssey/8bf1c756691f077e3d9439dc16103a2cf9493245/src/SimSun.ttf
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import random
4 | import json, os
5 |
6 |
7 | class GUIOdysseyDataset(Dataset):
8 | def __init__(self, args):
9 | super().__init__()
10 |
11 | self.dataset = args.dataset
12 |
13 | self.data = self.load_GUIOdyssey()
14 | self.len = len(self.data)
15 |
16 | random.shuffle(self.data)
17 | print(self.len)
18 |
19 | def __len__(self):
20 | return self.len
21 |
22 | def __getitem__(self, idx):
23 | return self.data[idx]
24 |
25 | def load_GUIOdyssey(self):
26 | d = json.load(open(self.dataset))
27 | return d
--------------------------------------------------------------------------------
/src/eval_mm/GUIOdyssey_action_matching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | TEXT_ANLS_THRESHOLD = 0.5
4 | CLICK_COORD_THRESHOLD = 0.14
5 |
6 | def levenshtein_distance(s1, s2):
7 | if len(s1) > len(s2):
8 | s1, s2 = s2, s1
9 |
10 | distances = range(len(s1) + 1)
11 | for i2, c2 in enumerate(s2):
12 | distances_ = [i2+1]
13 | for i1, c1 in enumerate(s1):
14 | if c1 == c2:
15 | distances_.append(distances[i1])
16 | else:
17 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
18 | distances = distances_
19 | return distances[-1]
20 |
21 |
22 | def text_matching(gt, pred):
23 | gt = gt.strip()
24 | pred = pred.strip()
25 | if gt in pred or pred in gt:
26 | return True
27 |
28 | dist = levenshtein_distance(gt, pred)
29 | length = max(len(gt), len(pred))
30 | value = 0.0 if length == 0 else float(dist) / float(length)
31 | value = 1 - value
32 | return value >= TEXT_ANLS_THRESHOLD
33 |
34 |
35 | def click_matching(gt_info, pred_info):
36 | if type(pred_info) == str:
37 | pred_info = eval(pred_info)
38 | if type(gt_info) == str:
39 | gt_info = eval(gt_info)
40 |
41 | pred = np.asarray(pred_info) / 1000
42 | gt = np.asarray(gt_info) / 1000
43 |
44 | return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD
45 |
46 |
47 |
48 | def action_matching(pred_action, pred_info, gt_action, gt_info):
49 | pred_action = pred_action.strip()
50 | if type(pred_info) == str:
51 | pred_info = pred_info.strip()
52 | gt_action = gt_action.strip()
53 | if type(gt_info) == str:
54 | gt_info = gt_info.strip()
55 |
56 | if pred_action != gt_action:
57 | return {'is_correct': 'no', 'info': 'action_fail'}
58 |
59 | if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']:
60 | return {'is_correct': 'yes', 'info': 'action_correct'}
61 |
62 | elif gt_action == 'TYPE':
63 | text_flag = text_matching(gt_info, pred_info)
64 |
65 | if text_flag:
66 | return {'is_correct': 'yes', 'info': 'type_correct'}
67 | else:
68 | return {'is_correct': 'no', 'info': 'type_fail'}
69 |
70 | elif gt_action == 'SCROLL':
71 | if gt_info.lower() == pred_info.lower():
72 | return {'is_correct': 'yes', 'info': 'scroll_correct'}
73 | else:
74 | return {'is_correct': 'no', 'info': 'scroll_fail'}
75 |
76 | elif gt_action == 'CLICK' or gt_action == 'LONG_PRESS':
77 | click_flag = click_matching(gt_info, pred_info)
78 |
79 | if click_flag:
80 | return {'is_correct': 'yes', 'info': 'click_correct'}
81 | else:
82 | return {'is_correct': 'no', 'info': 'click_fail'}
83 |
84 | else:
85 | raise ValueError('Invalid action type')
--------------------------------------------------------------------------------
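To make the matching rules above concrete, here is a small usage sketch (not part of the repository). It calls `action_matching` directly with already-decoded predictions and ground truths, mirroring how `evaluate_GUIOdyssey.py` below uses it.

```python
from GUIOdyssey_action_matching import action_matching

# Clicks match when the Euclidean distance between prediction and ground truth
# is at most 0.14 after normalizing the [0, 1000] coordinates to [0, 1].
print(action_matching('CLICK', (510, 505), 'CLICK', (480, 500)))
# {'is_correct': 'yes', 'info': 'click_correct'}

# Typed text matches if one string contains the other or the ANLS score is >= 0.5.
print(action_matching('TYPE', 'helo world', 'TYPE', 'hello world'))
# {'is_correct': 'yes', 'info': 'type_correct'}

# Scroll directions must match exactly (case-insensitive).
print(action_matching('SCROLL', 'up', 'SCROLL', 'DOWN'))
# {'is_correct': 'no', 'info': 'scroll_fail'}
```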
/src/eval_mm/evaluate_GUIOdyssey.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os, sys
5 | import torch
6 | import random
7 | import time
8 | from functools import partial
9 | from tqdm import tqdm
10 | import transformers
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 | import warnings
13 | from GUIOdyssey_action_matching import action_matching
14 | import numpy as np
15 |
16 | warnings.filterwarnings("ignore")
17 | current_path = os.path.abspath(__file__)
18 | DATA_DIR = os.path.dirname(os.path.dirname(os.path.dirname(current_path)))
19 |
20 | sys.path.append(os.path.join(DATA_DIR, 'OdysseyAgent'))
21 | sys.path.append(os.path.join(DATA_DIR, 'src'))
22 | from qwen_generation_utils import make_context, decode_tokens
23 |
24 |
25 | IMAGE_HISTORY = True
26 |
27 | ds_collections = {
28 | 'app_split': {
29 | 'test': '../data/test_anno/app_split.json',
30 | 'metric': 'macro'
31 | },
32 | 'device_split': {
33 | 'test': '../data/test_anno/device_split.json',
34 | 'metric': 'macro'
35 | },
36 | 'random_split': {
37 | 'test': '../data/test_anno/random_split.json',
38 | 'metric': 'micro'
39 | },
40 | 'task_split': {
41 | 'test': '../data/test_anno/task_split.json',
42 | 'metric': 'macro'
43 | }
44 | }
45 |
46 |
47 | def simple_decode(gt):
48 | gts = gt.split(':')
49 | gt_action = gts[0].strip()
50 | if len(gts) > 1:
51 | action = gt_action
52 | info = gts[1].strip()
53 | if action in ['CLICK', "LONG_PRESS"]:
54 | info = eval(info)
55 | else:
56 | action = gt_action
57 | info = ""
58 | return {"action": action, "info": info}
59 |
60 | def stat_result(eval_dict, metric):
61 | text_correct = sum([1 for _ in eval_dict if _['info'] == 'type_correct'])
62 | type_correct = sum([1 for _ in eval_dict if _['info'] != 'action_fail'])
63 | text_total = sum([1 for _ in eval_dict if _['info'].startswith('type_')])
64 |
65 | if metric == 'macro':
66 | action_correct = sum([1 for _ in eval_dict if _['is_correct'] == 'yes'])
67 | AMS = round(action_correct / len(eval_dict) * 100, 2)
68 | SR_cnt, SR_tot, SR = check_SR(eval_dict)
69 | elif metric == 'micro':
70 | task_cate_dict = {}
71 | acc_list = []
72 | SR_list = []
73 | for sample in eval_dict:
74 | cat = sample['more_info']['category']
75 | if cat not in task_cate_dict:
76 | task_cate_dict[cat] = []
77 | task_cate_dict[cat].append(sample)
78 | assert len(task_cate_dict) == 6
79 | for k, v in task_cate_dict.items():
80 | SR_cnt, SR_tot, SR = check_SR(v)
81 | SR_list.append((SR))
82 | acc = round(sum([1 for x in v if x['is_correct'] == 'yes']) / len(v) * 100, 2)
83 | acc_list.append(acc)
84 | print(f'category: {k}, AMS: {acc}, SR: {SR}')
85 |
86 | AMS = np.round(np.mean(acc_list), 2)
87 | SR = np.round(np.mean(SR_list), 2)
88 |
89 | else:
90 | raise ValueError(f'No metric {metric} found.')
91 |
92 | info = {
93 | 'AMS': AMS,
94 | 'SR': SR,
95 | 'total': len(eval_dict),
96 | 'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100),
97 | 'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_correct / text_total * 100),
98 | }
99 |
100 | print(info)
101 | return info
102 |
103 | def action_matching_evaluation(pred_output, metric='macro'):
104 | eval_dict = []
105 | for idx, sample in enumerate(pred_output):
106 | question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info']
107 | sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info}
108 |
109 | gt_simple_info = simple_decode(gt)
110 | gt_action = gt_simple_info['action']
111 | gt_info = gt_simple_info['info']
112 |
113 | try:
114 | pred_simple_info = simple_decode(pred)
115 | pred_action = pred_simple_info['action']
116 | pred_info = pred_simple_info['info']
117 | except:
118 | print('eval err:', idx, pred)
119 | log_info = {'is_correct': 'no', 'info': 'invalid'}
120 | sample_eval_dict.update(log_info)
121 | eval_dict.append(sample_eval_dict)
122 | continue
123 |
124 | try:
125 | check_match = action_matching(pred_action, pred_info, gt_action, gt_info)
126 | except Exception as exc:
127 | print('eval err:', gt, pred, exc)
128 | check_match = {'is_correct': 'no', 'info': 'invalid'}
129 |
130 | sample_eval_dict.update(check_match)
131 | eval_dict.append(sample_eval_dict)
132 |
133 |
134 | info = stat_result(eval_dict, metric)
135 | metrics = {"info": info, "pred": eval_dict}
136 | return metrics
137 |
138 |
139 |
140 | def check_SR(eval_dict):
141 | episode_dict = {}
142 | steps_map = {}
143 | for data in eval_dict:
144 | if 'img' in data: img = data['img']
145 | elif 'image' in data: img = data['image']
146 | else: img = data['question'].split('</img>')[0].split('<img>')[1]
147 | img = os.path.basename(img)
148 | tail = img.split('_')[-1]
149 | episode = img.replace(f'_{tail}', '')
150 | if episode not in episode_dict:
151 | episode_dict[episode] = []
152 | else:
153 | assert steps_map[episode] == data['more_info']['step_length']
154 |
155 | info = data['is_correct']
156 | episode_dict[episode].append(info)
157 | steps_map[episode] = data['more_info']['step_length']
158 |
159 | cnt, tot = 0, 0
160 | for k, v in episode_dict.items():
161 | if len(v) != steps_map[k]:
162 | print(f'step length of {k} does not match.')
163 | continue
164 | tot += 1
165 | v = list(set(v))
166 | if len(v) == 1 and v[0] == 'yes':
167 | cnt += 1
168 |
169 | SR = round(cnt / tot * 100, 2)
170 | print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}')
171 | return cnt, tot, SR
172 |
173 |
174 | def rank0_print(*args):
175 | if torch.distributed.get_rank() == 0:
176 | print(*args)
177 |
178 |
179 | def collate_fn(batches, tokenizer):
180 | question = [_['question'] for _ in batches]
181 | raw_texts = [_['raw_text'] for _ in batches]
182 | gt = [_['gt'] for _ in batches]
183 | more_info = [_['more_info'] for _ in batches]
184 | input_ids = tokenizer(raw_texts, return_tensors='pt', padding='longest')
185 |
186 | return question, raw_texts, input_ids.input_ids, input_ids.attention_mask, gt, more_info
187 |
188 |
189 | class LazySupervisedDataset(torch.utils.data.Dataset):
190 | def __init__(self, datapath, tokenizer: transformers.PreTrainedTokenizer, his_len, max_window_size, chat_format):
191 | super(LazySupervisedDataset, self).__init__()
192 | self.aitw = json.load(open(datapath))
193 | if len(self.aitw) > 50000:
194 | self.aitw = random.sample(self.aitw, len(self.aitw) // 10)
195 | self.tokenizer = tokenizer
196 | self.max_window_size = max_window_size
197 | self.chat_format = chat_format
198 | self.his_len = his_len
199 |
200 | def __len__(self):
201 | return len(self.aitw)
202 |
203 | def __getitem__(self, idx):
204 | data = self.aitw[idx]
205 | img = data['image']
206 | question = f"Picture 1:
{img}\nI'm looking for guidance on how to {data['question']}"
207 | answer = data['answer']
208 |
209 | history_action = eval(data['history_action'])[-self.his_len:]
210 | if IMAGE_HISTORY:
211 | if len(history_action) > 0:
212 | his_img = f'\nPrevious screenshots: <img>image-history: {img}</img>'
213 | his_str = '\nPrevious Actions: '
214 | for idx, hi in enumerate(history_action):
215 | his_str += f"{idx+1}. {hi}\n"
216 |
217 | question = f"{question}{his_img}{his_str}"
218 | else:
219 | if len(history_action) > 0:
220 | his_str = '\nPrevious Actions: '
221 | for idx, hi in enumerate(history_action):
222 | his_str += f"{idx+1}. {hi}\n"
223 |
224 | question = f"{question}{his_str}"
225 |
226 | raw_text, _ = make_context(self.tokenizer, question, system="You are a helpful assistant.", max_window_size=self.max_window_size, chat_format=self.chat_format)
227 | more_info = {'category': data['category'], 'step_length': data['step_length']}
228 | return {
229 | 'raw_text': raw_text,
230 | 'question': question,
231 | 'gt': answer,
232 | 'more_info': more_info
233 | }
234 |
235 |
236 | class InferenceSampler(torch.utils.data.sampler.Sampler):
237 | def __init__(self, size):
238 | self._size = int(size)
239 | assert size > 0
240 | self._rank = torch.distributed.get_rank()
241 | self._world_size = torch.distributed.get_world_size()
242 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
243 |
244 | @staticmethod
245 | def _get_local_indices(total_size, world_size, rank):
246 | shard_size = total_size // world_size
247 | left = total_size % world_size
248 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
249 |
250 | begin = sum(shard_sizes[:rank])
251 | end = min(sum(shard_sizes[:rank + 1]), total_size)
252 | return range(begin, end)
253 |
254 | def __iter__(self):
255 | yield from self._local_indices
256 |
257 | def __len__(self):
258 | return len(self._local_indices)
259 |
260 |
261 | if __name__ == '__main__':
262 | parser = argparse.ArgumentParser()
263 | parser.add_argument('--checkpoint', type=str, default='/path/to/model')
264 | parser.add_argument('--dataset', type=str, default='random_split')
265 | parser.add_argument('--batch-size', type=int, default=4)
266 | parser.add_argument('--num-workers', type=int, default=12)
267 | parser.add_argument('--seed', type=int, default=2024)
268 | parser.add_argument('--output-path', type=str, default='output_res')
269 | parser.add_argument('--image-history', type=str, default='yes')
270 | parser.add_argument('--his_len', type=int, default=4)
271 | args = parser.parse_args()
272 |
273 | if args.image_history == 'no':
274 | IMAGE_HISTORY = False
275 | else:
276 | IMAGE_HISTORY = True
277 |
278 | torch.distributed.init_process_group(backend='nccl', world_size=int(os.getenv('WORLD_SIZE', '1')), rank=int(os.getenv('RANK', '0')),)
279 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
280 | rank0_print(args)
281 | rank0_print('load model...')
282 | model = AutoModelForCausalLM.from_pretrained(args.checkpoint, device_map='cuda', trust_remote_code=True).eval()
283 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True)
284 | tokenizer.padding_side = 'left'
285 | tokenizer.pad_token_id = tokenizer.eod_id
286 |
287 | rank0_print('init test set...')
288 | random.seed(args.seed)
289 | datapath = ds_collections[args.dataset]['test']
290 | dataset = LazySupervisedDataset(datapath=datapath, tokenizer=tokenizer, his_len=args.his_len, max_window_size=6144, chat_format='chatml')
291 |
292 | dataloader = torch.utils.data.DataLoader(
293 | dataset=dataset,
294 | sampler=InferenceSampler(len(dataset)),
295 | batch_size=args.batch_size,
296 | num_workers=args.num_workers,
297 | pin_memory=True,
298 | drop_last=False,
299 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
300 | )
301 |
302 | rank0_print(f'len of dataloader: {len(dataloader)}')
303 | outputs = []
304 | for _, (question, raw_texts, input_ids, attention_mask, gt, more_info) in tqdm(enumerate(dataloader)):
305 | try:
306 | batch_input_ids = input_ids.to(model.device)
307 | batch_input_attention_mask = attention_mask.to(model.device)
308 |
309 | batch_out_ids = model.generate(
310 | input_ids=batch_input_ids,
311 | attention_mask=batch_input_attention_mask,
312 | do_sample=False,
313 | num_beams=1,
314 | length_penalty=1,
315 | num_return_sequences=1,
316 | use_cache=True,
317 | pad_token_id=tokenizer.eod_id,
318 | eos_token_id=tokenizer.eod_id,
319 | min_new_tokens=1,
320 | max_new_tokens=30,
321 | )
322 |
323 | padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
324 | batch_response = [decode_tokens(batch_out_ids[i][padding_lens[i]:], tokenizer, raw_text_len=len(raw_texts[i]), context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
325 | chat_format="chatml", verbose=False, errors='replace') for i in range(len(raw_texts))]
326 |
327 | for q, pred, _gt, info in zip(question, batch_response, gt, more_info):
328 | outputs.append({
329 | 'question': q,
330 | 'pred': str(pred),
331 | 'gt': _gt,
332 | 'more_info': info,
333 | })
334 | except Exception as e:
335 | print('error', e)
336 | print(_)
337 | continue
338 |
339 | print(f'rank {torch.distributed.get_rank()} finished inference.')
340 | torch.distributed.barrier()
341 |
342 | world_size = torch.distributed.get_world_size()
343 | merged_outputs = [None for _ in range(world_size)]
344 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
345 |
346 | merged_outputs = [json.loads(_) for _ in merged_outputs]
347 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
348 |
349 | if torch.distributed.get_rank() == 0:
350 | print(f"Saving predict result ...")
351 | # time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
352 | os.makedirs(args.output_path, exist_ok=True)
353 | model_name = str(args.checkpoint).replace('/', '_')
354 | savefile = os.path.join(args.output_path, f'{model_name}_{args.dataset}.json')
355 | json.dump(merged_outputs, open(savefile, 'w'), indent=4, ensure_ascii=False)
356 |
357 | print(f"Evaluating {args.dataset} ...")
358 | metrics = action_matching_evaluation(merged_outputs, metric=ds_collections[args.dataset]['metric'])
359 |
360 | output_data = {'dataset': args.dataset, 'model': model_name, 'metrics': metrics}
361 | json.dump(output_data, open(savefile, 'w'), indent=4, ensure_ascii=False)
362 |
363 | torch.distributed.barrier()
364 |
--------------------------------------------------------------------------------
/src/finetune.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import logging
3 | import os, sys
4 | from typing import Dict, Optional, List
5 | import torch
6 | from torch.utils.data import Dataset
7 | from deepspeed import zero
8 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
9 | import transformers
10 | from transformers import Trainer, GPTQConfig, deepspeed
11 | from transformers.trainer_pt_utils import LabelSmoother
12 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
13 | from accelerate.utils import DistributedType
14 | from data_loader import GUIOdysseyDataset
15 | from PIL import Image
16 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
17 |
18 | current_path = os.path.abspath(__file__)
19 | DATA_DIR = os.path.dirname(os.path.dirname(current_path))
20 | OdysseyAgent_model_path = os.path.join(DATA_DIR, 'OdysseyAgent')
21 | print(OdysseyAgent_model_path)
22 | sys.path.append(OdysseyAgent_model_path)
23 |
24 |
25 | @dataclass
26 | class ModelArguments:
27 | model_name_or_path: Optional[str] = field(default="/path/to/model")
28 |
29 |
30 | @dataclass
31 | class DataArguments:
32 | dataset: str = field(default='/path/to/chat_format_annotations', metadata={"help": "the dataset"})
33 |
34 | @dataclass
35 | class TrainingArguments(transformers.TrainingArguments):
36 | cache_dir: Optional[str] = field(default=None)
37 | optim: str = field(default="adamw_torch")
38 | model_max_length: int = field(
39 | default=8192,
40 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
41 | )
42 | use_lora: bool = False
43 | fix_vit: bool = True
44 |
45 |
46 | @dataclass
47 | class LoraArguments:
48 | lora_r: int = 64
49 | lora_alpha: int = 16
50 | lora_dropout: float = 0.05
51 | lora_target_modules: List[str] = field(
52 | default_factory=lambda: ["c_attn", "attn.c_proj", "w1", "w2"] ##["in_proj","out_proj","c_fc"]
53 | )
54 | lora_weight_path: str = ""
55 | lora_bias: str = "none"
56 | q_lora: bool = False
57 |
58 |
59 | def maybe_zero_3(param):
60 | if hasattr(param, "ds_id"):
61 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
62 | with zero.GatheredParameters([param]):
63 | param = param.data.detach().cpu().clone()
64 | else:
65 | param = param.detach().cpu().clone()
66 | return param
67 |
68 |
69 | # Borrowed from peft.utils.get_peft_model_state_dict
70 | def get_peft_state_maybe_zero_3(named_params, bias):
71 | if bias == "none":
72 | to_return = {k: t for k, t in named_params if "lora_" in k}
73 | elif bias == "all":
74 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
75 | elif bias == "lora_only":
76 | to_return = {}
77 | maybe_lora_bias = {}
78 | lora_bias_names = set()
79 | for k, t in named_params:
80 | if "lora_" in k:
81 | to_return[k] = t
82 | bias_name = k.split("lora_")[0] + "bias"
83 | lora_bias_names.add(bias_name)
84 | elif "bias" in k:
85 | maybe_lora_bias[k] = t
86 | for k, t in maybe_lora_bias.items():
87 | if k in lora_bias_names:
88 | to_return[k] = t
89 | else:
90 | raise NotImplementedError
91 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
92 | return to_return
93 |
94 | local_rank = None
95 |
96 | def rank0_print(*args):
97 | if local_rank == 0:
98 | print(*args)
99 |
100 |
101 | def set_seed(seed: int):
102 | import numpy as np
103 | import random
104 | torch.manual_seed(seed)
105 | if torch.cuda.is_available():
106 | torch.cuda.manual_seed_all(seed)
107 | np.random.seed(seed)
108 | random.seed(seed)
109 | os.environ["PYTHONHASHSEED"] = str(seed)
110 |
111 |
112 | def print_params(model):
113 | total_params = sum([p.numel() for p in model.parameters()])
114 | trainable_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
115 |
116 | rank0_print(f"Total Parameters: {total_params}")
117 | rank0_print(f"Trainable Parameters: {trainable_params}")
118 |
119 |
120 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
121 | """Collects the state dict and dump to disk."""
122 | # check if zero3 mode enabled
123 | if deepspeed.is_deepspeed_zero3_enabled():
124 | state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
125 | else:
126 | if trainer.args.use_lora:
127 | state_dict = get_peft_state_maybe_zero_3(
128 | trainer.model.named_parameters(), bias
129 | )
130 | else:
131 | state_dict = trainer.model.state_dict()
132 | if trainer.args.should_save and trainer.args.local_rank == 0:
133 | trainer._save(output_dir, state_dict=state_dict)
134 |
135 |
136 | def preprocess(
137 | sources,
138 | tokenizer: transformers.PreTrainedTokenizer,
139 | max_len: int,
140 | system_message: str = "You are a helpful assistant."
141 | ) -> Dict:
142 | roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
143 |
144 | im_start = tokenizer.im_start_id
145 | im_end = tokenizer.im_end_id
146 | nl_tokens = tokenizer('\n').input_ids
147 | _system = tokenizer('system').input_ids + nl_tokens
148 | _user = tokenizer('user').input_ids + nl_tokens
149 | _assistant = tokenizer('assistant').input_ids + nl_tokens
150 |
151 | # Apply prompt templates
152 | input_ids, targets = [], []
153 | for i, source in enumerate(sources):
154 | if roles[source[0]["from"]] != roles["user"]:
155 | source = source[1:]
156 |
157 | input_id, target = [], []
158 | system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
159 | input_id += system
160 | target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens
161 | assert len(input_id) == len(target)
162 | for j, sentence in enumerate(source):
163 | role = roles[sentence["from"]]
164 | _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
165 | input_id += _input_id
166 | if role == '<|im_start|>user':
167 | _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
168 | elif role == '<|im_start|>assistant':
169 | _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
170 | _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
171 | else:
172 | raise NotImplementedError
173 | target += _target
174 | assert len(input_id) == len(target)
175 | input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
176 | target += [IGNORE_TOKEN_ID] * (max_len - len(target))
177 | input_ids.append(input_id[:max_len])
178 | targets.append(target[:max_len])
179 | input_ids = torch.tensor(input_ids, dtype=torch.int)
180 | targets = torch.tensor(targets, dtype=torch.int)
181 |
182 | return dict(
183 | input_ids=input_ids,
184 | labels=targets,
185 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
186 | )
187 |
188 |
189 | def load_img(url):
190 | img = Image.open(url)
191 | return img
192 |
193 | class LazySupervisedDataset(Dataset):
194 | """Dataset for supervised fine-tuning."""
195 |
196 | def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer, max_len: int, preprocess_strategy=None):
197 | super(LazySupervisedDataset, self).__init__()
198 | self.dataset = dataset
199 | self.tokenizer = tokenizer
200 | self.max_len = max_len
201 | self.preprocess_strategy = preprocess_strategy
202 |
203 | rank0_print("Formatting inputs...Skip in lazy mode")
204 | self.tokenizer = tokenizer
205 | self.cached_data_dict = {}
206 |
207 | def __len__(self):
208 | return len(self.dataset)
209 |
210 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
211 | try:
212 | if i in self.cached_data_dict:
213 | return self.cached_data_dict[i]
214 | ret = self.preprocess_strategy([self.dataset[i]["conversations"]], self.tokenizer, self.max_len)
215 | if "image" in self.dataset[i]: # test whether the image is valid
216 | img = load_img(self.dataset[i]["image"])
217 | ret = dict(
218 | input_ids=ret["input_ids"][0],
219 | labels=ret["labels"][0],
220 | attention_mask=ret["attention_mask"][0],
221 | )
222 | self.cached_data_dict[i] = ret
223 |
224 | return ret
225 | except Exception as e:
226 | print('get sample error:', str(e), self.dataset[i])
227 | return self.__getitem__((i+1) % len(self.dataset))
228 |
229 |
230 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, max_len,) -> Dict:
231 | """Make dataset and collator for supervised fine-tuning."""
232 | dataset_cls = LazySupervisedDataset
233 | rank0_print("Loading data...")
234 |
235 | train_ = GUIOdysseyDataset(data_args)
236 |
237 | preprocess_strategy = preprocess
238 |
239 | train_dataset = dataset_cls(train_, tokenizer=tokenizer, max_len=max_len, preprocess_strategy=preprocess_strategy)
240 | eval_dataset = None
241 |
242 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
243 |
244 |
245 | def train():
246 | set_seed(201830168)
247 | global local_rank
248 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments))
249 | (
250 | model_args,
251 | data_args,
252 | training_args,
253 | lora_args,
254 | ) = parser.parse_args_into_dataclasses()
255 | if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
256 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
257 |
258 |
259 | local_rank = training_args.local_rank
260 |
261 | device_map = None
262 | world_size = int(os.environ.get("WORLD_SIZE", 1))
263 | ddp = world_size != 1
264 | if lora_args.q_lora:
265 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
266 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
267 | logging.warning("FSDP or ZeRO3 is not compatible with QLoRA.")
268 | print(world_size)
269 | # Set RoPE scaling factor
270 | config = transformers.AutoConfig.from_pretrained(
271 | model_args.model_name_or_path,
272 | cache_dir=training_args.cache_dir,
273 | trust_remote_code=True,
274 | )
275 | config.use_cache = False
276 |
277 | # Load model and tokenizer
278 | model = transformers.AutoModelForCausalLM.from_pretrained(
279 | model_args.model_name_or_path,
280 | config=config,
281 | cache_dir=training_args.cache_dir,
282 | device_map=device_map,
283 | trust_remote_code=True,
284 | quantization_config=GPTQConfig(bits=4, disable_exllama=True) if training_args.use_lora and lora_args.q_lora else None,
285 | )
286 | if not training_args.use_lora:
287 | model.transformer.requires_grad_(True)
288 | if hasattr(model,'transformer') and hasattr(model.transformer,'visual'):
289 | model.transformer.visual.requires_grad_(False)
290 | if hasattr(model.transformer.visual,'attn_pool'):
291 | model.transformer.visual.attn_pool.requires_grad_(True)
292 | print_params(model)
293 |
294 | tokenizer = transformers.AutoTokenizer.from_pretrained(
295 | model_args.model_name_or_path,
296 | cache_dir=training_args.cache_dir,
297 | model_max_length=training_args.model_max_length,
298 | padding_side="right",
299 | use_fast=False,
300 | trust_remote_code=True,
301 | )
302 | tokenizer.pad_token_id = tokenizer.eod_id
303 |
304 | if training_args.use_lora:
305 | if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower():
306 | modules_to_save = None
307 | else:
308 | modules_to_save = ["wte", "lm_head"]
309 | lora_config = LoraConfig(
310 | r=lora_args.lora_r,
311 | lora_alpha=lora_args.lora_alpha,
312 | target_modules=lora_args.lora_target_modules,
313 | lora_dropout=lora_args.lora_dropout,
314 | bias=lora_args.lora_bias,
315 | task_type="CAUSAL_LM",
316 | modules_to_save=modules_to_save # This argument serves for adding new tokens.
317 | )
318 | if lora_args.q_lora:
319 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
320 |
321 | model = get_peft_model(model, lora_config)
322 |
323 | if training_args.gradient_checkpointing:
324 | model.enable_input_require_grads()
325 |
326 | rank0_print(training_args)
327 | # Load data
328 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length)
329 | train_module = {"train_dataset": data_module["train_dataset"], "eval_dataset": data_module["eval_dataset"]}
330 | rank0_print("training total: {}".format(len(data_module["train_dataset"])))
331 |
332 | # Start trainer
333 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **train_module)
334 | resume_from_checkpoint = os.path.exists(training_args.output_dir) and len(os.listdir(training_args.output_dir)) != 0
335 | if resume_from_checkpoint:
336 | rank0_print('load from checkpoint.')
337 | trainer.train(resume_from_checkpoint=resume_from_checkpoint)
338 | trainer.save_state()
339 |
340 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
341 |
342 |
343 | if __name__ == "__main__":
344 | train()
345 |
--------------------------------------------------------------------------------
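For orientation, `LazySupervisedDataset`/`preprocess` above consume the chat-format records emitted by `data/format_converter.py`. A single record looks roughly like the sketch below; the paths, instruction text, and answer are invented for illustration.

```python
# Illustrative training record (field names follow data/format_converter.py;
# the concrete paths and texts are made up).
record = {
    "id": "GUIOdyssey_random_split_0",
    "image": "/abs/path/to/data/screenshots/some_episode_0.png",
    "conversations": [
        {"from": "user",
         "value": "Picture 1: <img>/abs/path/to/data/screenshots/some_episode_0.png</img>\n"
                  "I'm looking for guidance on how to ..."},
        {"from": "assistant", "value": "CLICK: (320, 540)"},
    ],
    # String-encoded list of the screenshots seen earlier in the episode.
    "history": "[]",
}
```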
/src/finetune/ds_config_zero1.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 |
23 | "scheduler": {
24 | "type": "WarmupLR",
25 | "params": {
26 | "warmup_min_lr": "auto",
27 | "warmup_max_lr": "auto",
28 | "warmup_num_steps": "auto"
29 | }
30 | },
31 |
32 | "zero_optimization": {
33 | "stage": 1,
34 | "offload_optimizer": {
35 | "device": "none",
36 | "pin_memory": true
37 | },
38 | "allgather_partitions": true,
39 | "allgather_bucket_size": 2e8,
40 | "overlap_comm": true,
41 | "reduce_scatter": true,
42 | "reduce_bucket_size": 2e8,
43 | "contiguous_gradients": true
44 | },
45 |
46 | "gradient_accumulation_steps": "auto",
47 | "gradient_clipping": "auto",
48 | "steps_per_print": 100,
49 | "train_batch_size": "auto",
50 | "train_micro_batch_size_per_gpu": "auto",
51 | "wall_clock_breakdown": false
52 | }
--------------------------------------------------------------------------------
/src/finetune/ds_config_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 |
23 | "scheduler": {
24 | "type": "WarmupLR",
25 | "params": {
26 | "warmup_min_lr": "auto",
27 | "warmup_max_lr": "auto",
28 | "warmup_num_steps": "auto"
29 | }
30 | },
31 |
32 | "zero_optimization": {
33 | "stage": 2,
34 | "offload_optimizer": {
35 | "device": "none",
36 | "pin_memory": true
37 | },
38 | "allgather_partitions": true,
39 | "allgather_bucket_size": 2e8,
40 | "overlap_comm": true,
41 | "reduce_scatter": true,
42 | "reduce_bucket_size": 2e8,
43 | "contiguous_gradients": true
44 | },
45 |
46 | "gradient_accumulation_steps": "auto",
47 | "gradient_clipping": "auto",
48 | "steps_per_print": 100,
49 | "train_batch_size": "auto",
50 | "train_micro_batch_size_per_gpu": "auto",
51 | "wall_clock_breakdown": false
52 | }
--------------------------------------------------------------------------------
/src/finetune/ds_config_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 |
23 | "scheduler": {
24 | "type": "WarmupLR",
25 | "params": {
26 | "warmup_min_lr": "auto",
27 | "warmup_max_lr": "auto",
28 | "warmup_num_steps": "auto"
29 | }
30 | },
31 |
32 | "zero_optimization": {
33 | "stage": 3,
34 | "offload_optimizer": {
35 | "device": "none",
36 | "pin_memory": true
37 | },
38 | "offload_param": {
39 | "device": "none",
40 | "pin_memory": true
41 | },
42 | "overlap_comm": true,
43 | "contiguous_gradients": true,
44 | "sub_group_size": 1e9,
45 | "reduce_bucket_size": "auto",
46 | "stage3_prefetch_bucket_size": "auto",
47 | "stage3_param_persistence_threshold": "auto",
48 | "stage3_max_live_parameters": 1e9,
49 | "stage3_max_reuse_distance": 1e9,
50 | "stage3_gather_16bit_weights_on_model_save": true
51 | },
52 |
53 | "gradient_accumulation_steps": "auto",
54 | "gradient_clipping": "auto",
55 | "steps_per_print": 100,
56 | "train_batch_size": "auto",
57 | "train_micro_batch_size_per_gpu": "auto",
58 | "wall_clock_breakdown": false
59 | }
60 |
--------------------------------------------------------------------------------
/src/finetune/finetune_ds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | GPUS_PER_NODE=8
6 | NNODES=1
7 | NODE_RANK=0
8 | MASTER_ADDR=localhost
9 | MASTER_PORT=6001
10 |
11 | MODEL="Qwen/Qwen-VL-Chat" #"Qwen/Qwen-VL-Chat"/"Qwen/Qwen-VL" # Set the path if you do not want to load from huggingface directly
12 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
13 | # See the section for finetuning in README for more information.
14 | DATA="path_to_data"
15 |
16 | DISTRIBUTED_ARGS="
17 | --nproc_per_node $GPUS_PER_NODE \
18 | --nnodes $NNODES \
19 | --node_rank $NODE_RANK \
20 | --master_addr $MASTER_ADDR \
21 | --master_port $MASTER_PORT
22 | "
23 |
24 | torchrun $DISTRIBUTED_ARGS finetune.py \
25 | --model_name_or_path $MODEL \
26 | --data_path $DATA \
27 | --bf16 True \
28 | --fix_vit True \
29 | --output_dir output_qwen \
30 | --num_train_epochs 5 \
31 | --per_device_train_batch_size 1 \
32 | --per_device_eval_batch_size 1 \
33 | --gradient_accumulation_steps 16 \
34 | --evaluation_strategy "no" \
35 | --save_strategy "steps" \
36 | --save_steps 1000 \
37 | --save_total_limit 10 \
38 | --learning_rate 1e-5 \
39 | --weight_decay 0.1 \
40 | --adam_beta2 0.95 \
41 | --warmup_ratio 0.01 \
42 | --lr_scheduler_type "cosine" \
43 | --logging_steps 1 \
44 | --report_to "none" \
45 | --model_max_length 2048 \
46 | --gradient_checkpointing True \
47 | --lazy_preprocess True \
48 | --deepspeed finetune/ds_config_zero3.json
49 |
--------------------------------------------------------------------------------
/src/finetune/finetune_lora_ds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | GPUS_PER_NODE=8
6 | NNODES=1
7 | NODE_RANK=0
8 | MASTER_ADDR=localhost
9 | MASTER_PORT=6001
10 |
11 | MODEL="Qwen/Qwen-VL-Chat" #"Qwen/Qwen-VL-Chat"/"Qwen/Qwen-VL" Set the path if you do not want to load from huggingface directly
12 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
13 | # See the section for finetuning in README for more information.
14 | DATA="path_to_data"
15 |
16 | DISTRIBUTED_ARGS="
17 | --nproc_per_node $GPUS_PER_NODE \
18 | --nnodes $NNODES \
19 | --node_rank $NODE_RANK \
20 | --master_addr $MASTER_ADDR \
21 | --master_port $MASTER_PORT
22 | "
23 |
24 | torchrun $DISTRIBUTED_ARGS finetune.py \
25 | --model_name_or_path $MODEL \
26 | --data_path $DATA \
27 | --bf16 True \
28 | --fix_vit True \
29 | --output_dir output_qwen \
30 | --num_train_epochs 5 \
31 | --per_device_train_batch_size 2 \
32 | --per_device_eval_batch_size 1 \
33 | --gradient_accumulation_steps 8 \
34 | --evaluation_strategy "no" \
35 | --save_strategy "steps" \
36 | --save_steps 1000 \
37 | --save_total_limit 10 \
38 | --learning_rate 1e-5 \
39 | --weight_decay 0.1 \
40 | --adam_beta2 0.95 \
41 | --warmup_ratio 0.01 \
42 | --lr_scheduler_type "cosine" \
43 | --logging_steps 1 \
44 | --report_to "none" \
45 | --model_max_length 2048 \
46 | --lazy_preprocess True \
47 | --use_lora \
48 | --gradient_checkpointing \
49 | --deepspeed finetune/ds_config_zero2.json
--------------------------------------------------------------------------------
/src/finetune/finetune_lora_single_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 |
6 | MODEL="Qwen/Qwen-VL-Chat" #"Qwen/Qwen-VL-Chat"/"Qwen/Qwen-VL" # Set the path if you do not want to load from huggingface directly
7 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
8 | # See the section for finetuning in README for more information.
9 | DATA="path_to_data"
10 |
11 | export CUDA_VISIBLE_DEVICES=0
12 |
13 | python finetune.py \
14 | --model_name_or_path $MODEL \
15 | --data_path $DATA \
16 | --bf16 True \
17 | --fix_vit True \
18 | --output_dir output_qwen \
19 | --num_train_epochs 5 \
20 | --per_device_train_batch_size 1 \
21 | --per_device_eval_batch_size 1 \
22 | --gradient_accumulation_steps 8 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 1000 \
26 | --save_total_limit 10 \
27 | --learning_rate 1e-5 \
28 | --weight_decay 0.1 \
29 | --adam_beta2 0.95 \
30 | --warmup_ratio 0.01 \
31 | --lr_scheduler_type "cosine" \
32 | --logging_steps 1 \
33 | --report_to "none" \
34 | --model_max_length 2048 \
35 | --lazy_preprocess True \
36 | --gradient_checkpointing \
37 | --use_lora
--------------------------------------------------------------------------------
/src/finetune/finetune_qlora_ds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | GPUS_PER_NODE=8
6 | NNODES=1
7 | NODE_RANK=0
8 | MASTER_ADDR=localhost
9 | MASTER_PORT=6001
10 |
11 | MODEL="Qwen/Qwen-VL-Chat-Int4" # Qwen/Qwen-VL-Chat-Int4 Set the path if you do not want to load from huggingface directly
12 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
13 | # See the section for finetuning in README for more information.
14 | DATA="path_to_data"
15 |
16 |
17 | DISTRIBUTED_ARGS="
18 | --nproc_per_node $GPUS_PER_NODE \
19 | --nnodes $NNODES \
20 | --node_rank $NODE_RANK \
21 | --master_addr $MASTER_ADDR \
22 | --master_port $MASTER_PORT
23 | "
24 |
25 | # Remember to use --fp16 instead of --bf16 due to autogptq
26 | torchrun $DISTRIBUTED_ARGS finetune.py \
27 | --model_name_or_path $MODEL \
28 | --data_path $DATA \
29 | --fp16 True \
30 | --fix_vit True \
31 | --output_dir output_qwen \
32 | --num_train_epochs 5 \
33 | --per_device_train_batch_size 2 \
34 | --per_device_eval_batch_size 1 \
35 | --gradient_accumulation_steps 8 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 1000 \
39 | --save_total_limit 10 \
40 | --learning_rate 1e-5 \
41 | --weight_decay 0.1 \
42 | --adam_beta2 0.95 \
43 | --warmup_ratio 0.01 \
44 | --lr_scheduler_type "cosine" \
45 | --logging_steps 1 \
46 | --report_to "none" \
47 | --model_max_length 2048 \
48 | --lazy_preprocess True \
49 | --use_lora \
50 | --q_lora \
51 | --gradient_checkpointing \
52 | --deepspeed finetune/ds_config_zero2.json
--------------------------------------------------------------------------------
/src/finetune/finetune_qlora_single_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | MODEL="Qwen/Qwen-VL-Chat-Int4" # Qwen/Qwen-VL-Chat-Int4 Set the path if you do not want to load from huggingface directly
6 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
7 | # See the section for finetuning in README for more information.
8 | DATA="path_to_data"
9 |
10 | export CUDA_VISIBLE_DEVICES=0
11 |
12 | # Remember to use --fp16 instead of --bf16 due to autogptq
13 | python finetune.py \
14 | --model_name_or_path $MODEL \
15 | --data_path $DATA \
16 | --fp16 True \
17 | --fix_vit True \
18 | --output_dir output_qwen \
19 | --num_train_epochs 5 \
20 | --per_device_train_batch_size 1 \
21 | --per_device_eval_batch_size 1 \
22 | --gradient_accumulation_steps 8 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 1000 \
26 | --save_total_limit 10 \
27 | --learning_rate 1e-5 \
28 | --weight_decay 0.1 \
29 | --adam_beta2 0.95 \
30 | --warmup_ratio 0.01 \
31 | --lr_scheduler_type "cosine" \
32 | --logging_steps 1 \
33 | --report_to "none" \
34 | --model_max_length 2048 \
35 | --lazy_preprocess True \
36 | --gradient_checkpointing \
37 | --use_lora \
38 | --q_lora \
39 | --deepspeed finetune/ds_config_zero2.json
40 |
--------------------------------------------------------------------------------
/src/merge_weight.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from transformers.generation import GenerationConfig
3 | from transformers import modeling_utils
4 | import torch
5 |
6 | QWENVL_PATH = 'Qwen/Qwen-VL-Chat'
7 |
8 | import sys
9 | sys.path.append('../OdysseyAgent')
10 | from configuration_qwen import QWenConfig
11 | from modeling_qwen import QWenLMHeadModel
12 |
13 | torch.manual_seed(1234)
14 | import json, random
15 | import time
16 |
17 | device = 'cpu'
18 |
19 | def load_qwen(model_name=QWENVL_PATH):
20 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map=None, trust_remote_code=True)
21 | return model
22 |
23 | def load_qwen_tokenizer(model_name=QWENVL_PATH):
24 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
25 | return tokenizer
26 |
27 |
28 | def load_new_qwen(bp='../OdysseyAgent/config.json'):
29 | cfg = QWenConfig(**json.load(open(bp)))
30 | model = QWenLMHeadModel(cfg)
31 | return model
32 |
33 | def merge_weight(qwen, odysseyAgent):
34 | qwen_dict = qwen.state_dict()
35 | odysseyAgent_dict = odysseyAgent.state_dict()
36 | for k in qwen_dict.keys():
37 | if k in odysseyAgent_dict:
38 | odysseyAgent_dict[k] = qwen_dict[k]
39 | odysseyAgent.load_state_dict(odysseyAgent_dict)
40 | return odysseyAgent
41 |
42 |
43 | def copy_QwenVL(bp='../OdysseyAgent'):
44 | tokenizer = load_qwen_tokenizer()
45 | tokenizer.save_pretrained(bp)
46 | qwen_model = load_qwen()
47 | new_qwen_model = load_new_qwen()
48 | print('start merging weight...')
49 | new_model = merge_weight(qwen_model, new_qwen_model)
50 | print('saving...')
51 | new_model.save_pretrained(bp)
52 |
53 | if __name__ == '__main__':
54 | copy_QwenVL()
55 |
--------------------------------------------------------------------------------
/src/qwen_generation_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Generation support."""
7 |
8 | from typing import Tuple, List, Union, Iterable
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from transformers import PreTrainedTokenizer
14 | from transformers import logging
15 | from transformers.generation import LogitsProcessor
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 | # Types.
20 | HistoryType = List[Tuple[str, str]]
21 | TokensType = List[int]
22 | BatchTokensType = List[List[int]]
23 |
24 |
25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26 | for tokens in batch:
27 | context_length = len(tokens)
28 | if context_length < seq_length:
29 | tokens.extend([pad_id] * (seq_length - context_length))
30 | return batch
31 |
32 |
33 | def get_ltor_masks_and_position_ids(
34 | data,
35 | eod_token,
36 | reset_position_ids,
37 | reset_attention_mask,
38 | eod_mask_loss,
39 | ):
40 | """Build masks and position id for left to right model."""
41 |
42 | # Extract batch size and sequence length.
43 | micro_batch_size, seq_length = data.size()
44 |
45 | # Attention mask (lower triangular).
46 | if reset_attention_mask:
47 | att_mask_batch = micro_batch_size
48 | else:
49 | att_mask_batch = 1
50 | attention_mask = torch.tril(
51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52 | ).view(att_mask_batch, 1, seq_length, seq_length)
53 |
54 | # Loss mask.
55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56 | if eod_mask_loss:
57 | loss_mask[data == eod_token] = 0.0
58 |
59 | # Position ids.
60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61 | position_ids = position_ids.unsqueeze(0).expand_as(data)
62 | # We need to clone as the ids will be modified based on batch index.
63 | if reset_position_ids:
64 | position_ids = position_ids.clone()
65 |
66 | if reset_position_ids or reset_attention_mask:
67 | # Loop through the batches:
68 | for b in range(micro_batch_size):
69 |
70 | # Find indices where the EOD token is.
71 | eod_index = position_ids[b, data[b] == eod_token]
72 | # Detach indices from positions if going to modify positions.
73 | if reset_position_ids:
74 | eod_index = eod_index.clone()
75 |
76 | # Loop through EOD indices:
77 | prev_index = 0
78 | for j in range(eod_index.size()[0]):
79 | i = eod_index[j]
80 | # Mask out attention across the EOD boundary.
81 | if reset_attention_mask:
82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83 | # Reset positions.
84 | if reset_position_ids:
85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index
86 | prev_index = i + 1
87 |
88 | # Convert attention mask to binary:
89 | attention_mask = attention_mask < 0.5
90 |
91 | return attention_mask, loss_mask, position_ids
92 |
93 |
94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95 | """Generate batch from context tokens."""
96 | # Keep the tokens contiguous on their current device.
97 | tokens = context_tokens.contiguous().to(context_tokens.device)
98 | # Get the attention mask and position ids.
99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100 | tokens,
101 | eod_id,
102 | reset_position_ids=False,
103 | reset_attention_mask=False,
104 | eod_mask_loss=False,
105 | )
106 | return tokens, attention_mask, position_ids
107 |
108 |
109 | def get_stop_words_ids(chat_format, tokenizer):
110 | if chat_format == "raw":
111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112 | elif chat_format == "chatml":
113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114 | else:
115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116 | return stop_words_ids
117 |
118 |
119 | def make_context(
120 | tokenizer: PreTrainedTokenizer,
121 | query: str,
122 | history: List[Tuple[str, str]] = None,
123 | system: str = "",
124 | max_window_size: int = 6144,
125 | chat_format: str = "chatml",
126 | ):
127 | if history is None:
128 | history = []
129 |
130 | if chat_format == "chatml":
131 | im_start, im_end = "<|im_start|>", "<|im_end|>"
132 | im_start_tokens = [tokenizer.im_start_id]
133 | im_end_tokens = [tokenizer.im_end_id]
134 | nl_tokens = tokenizer.encode("\n")
135 |
136 | def _tokenize_str(role, content):
137 | return f"{role}\n{content}", tokenizer.encode(
138 | role, allowed_special=set()
139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
140 |
141 | system_text, system_tokens_part = _tokenize_str("system", system)
142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143 |
144 | raw_text = ""
145 | context_tokens = []
146 |
147 | for turn_query, turn_response in reversed(history):
148 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150 | response_text, response_tokens_part = _tokenize_str(
151 | "assistant", turn_response
152 | )
153 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154 |
155 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156 | prev_chat = (
157 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158 | )
159 |
160 | current_context_size = (
161 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162 | )
163 | if current_context_size < max_window_size:
164 | context_tokens = next_context_tokens + context_tokens
165 | raw_text = prev_chat + raw_text
166 | else:
167 | break
168 |
169 | context_tokens = system_tokens + context_tokens
170 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171 | context_tokens += (
172 | nl_tokens
173 | + im_start_tokens
174 | + _tokenize_str("user", query)[1]
175 | + im_end_tokens
176 | + nl_tokens
177 | + im_start_tokens
178 | + tokenizer.encode("assistant")
179 | + nl_tokens
180 | )
181 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182 |
183 | elif chat_format == "raw":
184 | raw_text = query
185 | context_tokens = tokenizer.encode(raw_text)
186 | else:
187 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188 |
189 | return raw_text, context_tokens
190 |
191 |
192 | def _decode_default(
193 | tokens: List[int],
194 | *,
195 | stop_words: List[str],
196 | eod_words: List[str],
197 | tokenizer: PreTrainedTokenizer,
198 | raw_text_len: int,
199 | verbose: bool = False,
200 | return_end_reason: bool = False,
201 | errors: str='replace',
202 | ):
203 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
204 | if verbose:
205 | print("\nRaw Generate: ", trim_decode_tokens)
206 |
207 | end_reason = f"Gen length {len(tokens)}"
208 | for stop_word in stop_words:
209 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
210 | for eod_word in eod_words:
211 | if eod_word in trim_decode_tokens:
212 | end_reason = f"Gen {eod_word!r}"
213 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
214 | trim_decode_tokens = trim_decode_tokens.strip()
215 | if verbose:
216 | print("\nEnd Reason:", end_reason)
217 | print("\nGenerate: ", trim_decode_tokens)
218 |
219 | if return_end_reason:
220 | return trim_decode_tokens, end_reason
221 | else:
222 | return trim_decode_tokens
223 |
224 |
225 | def _decode_chatml(
226 | tokens: List[int],
227 | *,
228 | stop_words: List[str],
229 | eod_token_ids: List[int],
230 | tokenizer: PreTrainedTokenizer,
231 | raw_text_len: int,
232 | context_length: int,
233 | verbose: bool = False,
234 | return_end_reason: bool = False,
235 | errors: str='replace'
236 | ):
237 | end_reason = f"Gen length {len(tokens)}"
238 | eod_token_idx = context_length
239 | for eod_token_idx in range(context_length, len(tokens)):
240 | if tokens[eod_token_idx] in eod_token_ids:
241 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
242 | break
243 |
244 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
245 | if verbose:
246 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
247 | print("\nRaw Generate:", trim_decode_tokens)
248 | print("\nEnd Reason:", end_reason)
249 | for stop_word in stop_words:
250 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
251 | trim_decode_tokens = trim_decode_tokens.strip()
252 | if verbose:
253 | print("\nGenerate:", trim_decode_tokens)
254 |
255 | if return_end_reason:
256 | return trim_decode_tokens, end_reason
257 | else:
258 | return trim_decode_tokens
259 |
260 |
261 | def decode_tokens(
262 | tokens: Union[torch.LongTensor, TokensType],
263 | tokenizer: PreTrainedTokenizer,
264 | raw_text_len: int,
265 | context_length: int,
266 | chat_format: str,
267 | verbose: bool = False,
268 | return_end_reason: bool = False,
269 | errors: str="replace",
270 | ) -> str:
271 | if torch.is_tensor(tokens):
272 | tokens = tokens.cpu().numpy().tolist()
273 |
274 | if chat_format == "chatml":
275 | return _decode_chatml(
276 | tokens,
277 | stop_words=[],
278 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
279 | tokenizer=tokenizer,
280 | raw_text_len=raw_text_len,
281 | context_length=context_length,
282 | verbose=verbose,
283 | return_end_reason=return_end_reason,
284 | errors=errors,
285 | )
286 | elif chat_format == "raw":
287 | return _decode_default(
288 | tokens,
289 | stop_words=["<|endoftext|>"],
290 | eod_words=["<|endoftext|>"],
291 | tokenizer=tokenizer,
292 | raw_text_len=raw_text_len,
293 | verbose=verbose,
294 | return_end_reason=return_end_reason,
295 | errors=errors,
296 | )
297 | else:
298 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
299 |
300 |
301 | class StopWordsLogitsProcessor(LogitsProcessor):
302 | """
303 | :class:`transformers.LogitsProcessor` that forces generation to stop once any of the specified stop sequences appears.
304 | Args:
305 | stop_words_ids (:obj:`List[List[int]]`):
306 | List of lists of token ids for the stop sequences. To obtain the token ids
307 | of a stop word, use :obj:`tokenizer(stop_word,
308 | add_prefix_space=True).input_ids`.
309 | eos_token_id (:obj:`int`):
310 | The id of the `end-of-sequence` token.
311 | """
312 |
313 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
314 |
315 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
316 | raise ValueError(
317 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
318 | )
319 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
320 | raise ValueError(
321 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
322 | )
323 | if any(
324 | any(
325 | (not isinstance(token_id, (int, np.integer)) or token_id < 0)
326 | for token_id in stop_word_ids
327 | )
328 | for stop_word_ids in stop_words_ids
329 | ):
330 | raise ValueError(
331 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
332 | )
333 |
334 | self.stop_words_ids = list(
335 | filter(
336 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
337 | )
338 | )
339 | self.eos_token_id = eos_token_id
340 | for stop_token_seq in self.stop_words_ids:
341 | assert (
342 | len(stop_token_seq) > 0
343 | ), "Stop words token sequences {} cannot have an empty list".format(
344 | stop_words_ids
345 | )
346 |
347 | def __call__(
348 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
349 | ) -> torch.FloatTensor:
350 | stopped_samples = self._calc_stopped_samples(input_ids)
351 | for i, should_stop in enumerate(stopped_samples):
352 | if should_stop:
353 | scores[i, self.eos_token_id] = float(2**15)
354 | return scores
355 |
356 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
357 | if len(tokens) == 0:
358 | # an empty stop sequence trivially matches
359 | return True
360 | elif len(tokens) > len(prev_tokens):
361 | # if the stop sequence is longer than prev input_ids, it cannot match
362 | return False
363 | elif prev_tokens[-len(tokens) :].tolist() == tokens:
364 | # if tokens match
365 | return True
366 | else:
367 | return False
368 |
369 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
370 | stopped_samples = []
371 | for prev_input_ids_slice in prev_input_ids:
372 | match = False
373 | for stop_token_seq in self.stop_words_ids:
374 | if self._tokens_match(prev_input_ids_slice, stop_token_seq):
375 | # stop sequence found; mark this sample as stopped
376 | match = True
377 | break
378 | stopped_samples.append(match)
379 |
380 | return stopped_samples
381 |
382 |
383 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
384 | """This function has been mostly taken from huggingface conversational
385 | ai code at
386 | https://medium.com/huggingface/how-to-build-a-state-of-the-art-
387 | conversational-ai-with-transfer-learning-2d818ac26313"""
388 |
389 | if top_k > 0:
390 | # Remove all tokens with a probability less than the
391 | # last token of the top-k
392 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
393 | logits[indices_to_remove] = filter_value
394 |
395 | if top_p > 0.0:
396 | # Sort logits in descending order for nucleus (top-p) filtering
397 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
398 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
399 |
400 | # Remove tokens with cumulative probability above the threshold
401 | sorted_indices_to_remove = cumulative_probs > top_p
402 | # Shift the indices to the right to keep also the first token
403 | # above the threshold
404 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
405 | sorted_indices_to_remove[..., 0] = 0
406 | for i in range(sorted_indices.size(0)):
407 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
408 | logits[i][indices_to_remove] = filter_value
409 |
410 | return logits
411 |
412 |
413 | def switch(val1, val2, boolean):
414 | boolean = boolean.type_as(val1)
415 | return (1 - boolean) * val1 + boolean * val2
416 |
--------------------------------------------------------------------------------
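The helpers in qwen_generation_utils.py are used by the fine-tuning and evaluation code. A minimal sketch of how make_context, get_stop_words_ids, StopWordsLogitsProcessor and decode_tokens fit together for a single chatml-format, text-only generation, assuming a Qwen-style tokenizer/model loaded with trust_remote_code and a placeholder query:

    # Sketch: one generation round-trip through the helpers above.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import LogitsProcessorList
    from qwen_generation_utils import (
        make_context, decode_tokens, get_stop_words_ids, StopWordsLogitsProcessor,
    )

    MODEL = 'Qwen/Qwen-VL-Chat'  # placeholder; a merged OdysseyAgent checkpoint works the same way
    tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True).eval()

    query = 'Describe the next UI action.'  # placeholder query
    raw_text, context_tokens = make_context(
        tokenizer, query, history=[], system='You are a helpful assistant.', chat_format='chatml',
    )

    stop_words_ids = get_stop_words_ids('chatml', tokenizer)
    processors = LogitsProcessorList(
        [StopWordsLogitsProcessor(stop_words_ids=stop_words_ids, eos_token_id=tokenizer.eod_id)]
    )

    input_ids = torch.tensor([context_tokens])
    output = model.generate(input_ids, logits_processor=processors, max_new_tokens=64)
    response = decode_tokens(
        output[0], tokenizer,
        raw_text_len=len(raw_text), context_length=len(context_tokens),
        chat_format='chatml', verbose=False,
    )
    print(response)

decode_tokens drops everything up to raw_text_len and truncates at the first im_start/im_end token, so the printed string is just the assistant turn.
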
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | torch==2.0.1
3 | torchvision==0.15.2
4 | torchaudio==2.0.2
5 | transformers==4.32.0
6 | accelerate
7 | tiktoken
8 | einops
9 | transformers_stream_generator==0.0.4
10 | scipy
11 | torchvision
12 | pillow
13 | tensorboard
14 | matplotlib
15 | peft
16 | deepspeed
17 |
--------------------------------------------------------------------------------
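The pins above target the cu117 wheel index; a quick, purely illustrative sanity check that the pinned stack is the one actually installed before launching training or evaluation:

    # Sketch: verify the versions pinned in requirements.txt.
    import torch
    import transformers

    assert torch.__version__.startswith('2.0.1'), torch.__version__
    assert transformers.__version__ == '4.32.0', transformers.__version__
    print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
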
/src/script/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | checkpoint=/path/to/checkpoint
3 | ds=app_split # one of "app_split", "device_split", "random_split", "task_split"
4 | DIR=`pwd`
5 |
6 | exp_name=OdysseyAgent_$ds
7 | mkdir -p output/"$exp_name"
8 |
9 | GPUS_PER_NODE=8
10 | NNODES=1
11 | NODE_RANK=0
12 | MASTER_ADDR=localhost
13 | MASTER_PORT=$((RANDOM % 30001 + 20000))
14 | GPUS=$((GPUS_PER_NODE * NNODES))
15 |
16 | DISTRIBUTED_ARGS="
17 | --nproc_per_node $GPUS_PER_NODE \
18 | --nnodes $NNODES \
19 | --node_rank $NODE_RANK \
20 | --master_addr $MASTER_ADDR \
21 | --master_port $MASTER_PORT
22 | "
23 |
24 | echo $ds
25 | echo $checkpoint
26 | torchrun $DISTRIBUTED_ARGS eval_mm/evaluate_GUIOdyssey.py \
27 | --checkpoint $checkpoint --dataset $ds --batch-size 16 --his_len 4
--------------------------------------------------------------------------------
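eval.sh evaluates one split per run with 8 GPUs per node. A minimal sketch of sweeping all four splits with the same flags, here as a single-process torchrun launch (the single-GPU setting and the checkpoint placeholder are assumptions for illustration):

    # Sketch: repeat the torchrun command from eval.sh for every GUI Odyssey split.
    import subprocess

    CHECKPOINT = '/path/to/checkpoint'  # placeholder, as in eval.sh

    for split in ('app_split', 'device_split', 'random_split', 'task_split'):
        subprocess.run(
            ['torchrun', '--nproc_per_node', '1',
             'eval_mm/evaluate_GUIOdyssey.py',
             '--checkpoint', CHECKPOINT,
             '--dataset', split,
             '--batch-size', '16',
             '--his_len', '4'],
            check=True,
        )
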
/src/script/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | MODEL="/path/to/model" # "../OdysseyAgent"
3 | DATA_ROOT=/path/to/chat_format_annotation # ../data/train_anno/app_split.json ../data/train_anno/task_split.json
4 |
5 | exp_name=OdysseyAgent
6 | mkdir -p output/"$exp_name"
7 | OUTPUT_DIR=output/"$exp_name"
8 | export CUDA_DEVICE_MAX_CONNECTIONS=1
9 | DIR=`pwd`
10 |
11 | GPUS_PER_NODE=8
12 | NNODES=1
13 | NODE_RANK=0
14 | MASTER_ADDR=localhost
15 | MASTER_PORT=$((RANDOM % 10001 + 40000))
16 | GPUS=$((GPUS_PER_NODE * NNODES))
17 |
18 | DISTRIBUTED_ARGS="
19 | --nproc_per_node $GPUS_PER_NODE \
20 | --nnodes $NNODES \
21 | --node_rank $NODE_RANK \
22 | --master_addr $MASTER_ADDR \
23 | --master_port $MASTER_PORT
24 | "
25 | torchrun $DISTRIBUTED_ARGS finetune.py \
26 | --model_name_or_path $MODEL \
27 | --dataset $DATA_ROOT \
28 | --fp16 True \
29 | --fix_vit True \
30 | --output_dir $OUTPUT_DIR \
31 | --num_train_epochs 1 \
32 | --per_device_train_batch_size 2 \
33 | --per_device_eval_batch_size 1 \
34 | --gradient_accumulation_steps 8 \
35 | --evaluation_strategy "no" \
36 | --save_strategy "steps" \
37 | --save_steps 5000 \
38 | --save_total_limit 300 \
39 | --learning_rate 2e-5 \
40 | --weight_decay 0.1 \
41 | --adam_beta2 0.95 \
42 | --warmup_ratio 0.01 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --report_to "none" \
46 | --model_max_length 800 \
47 | --gradient_checkpointing True \
48 | --deepspeed finetune/ds_config_zero2.json
49 |
50 |
--------------------------------------------------------------------------------
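With the flags above, the effective optimizer-step batch size follows from the per-device batch size, gradient accumulation, and world size (simple arithmetic; ZeRO-2 sharding does not change it):

    # Sketch: global batch size implied by train.sh.
    gpus_per_node = 8
    nnodes = 1
    per_device_train_batch_size = 2
    gradient_accumulation_steps = 8

    world_size = gpus_per_node * nnodes
    global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
    print(global_batch_size)  # 2 * 8 * 8 = 128 samples per optimizer step
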