', '').replace('\n', '').strip()
48 | questions = [first_instruction]
49 | answers = []
50 |
51 | for i, item in enumerate(ann["conversations"][1:]):
52 | if i % 2 == 0: # assistant
53 | assistant_answer = item["value"]
54 | answers.append(assistant_answer)
55 | else:
56 | human_instruction = item["value"] + " "
57 | questions.append(human_instruction)
58 |
59 | return {
60 | "image": image,
61 | "question": questions,
62 | "answer": answers,
63 | "image_id": self.img_ids[ann["id"]]
64 | }
65 |
66 |
67 |
68 | class LlavaMedVQADataCollator:
69 | def __init__(self, processor):
70 | self.processor = processor
71 |
72 | def __call__(self, examples):
73 | assert len(examples) == 1, 'Phi-3-V only supports batch_size == 1'
74 | example = examples[0]
75 |
76 | image = example['image']
77 | idx = random.randint(0, len(example['question']) - 1)
78 | question = example['question'][idx]
79 | answer = example['answer'][idx]
80 | prompt_message = {
81 | 'role': 'user',
82 | 'content': f'<|image_1|>\n{question}',
83 | }
84 |
85 | prompt = self.processor.tokenizer.apply_chat_template(
86 | [prompt_message], tokenize=False, add_generation_prompt=True
87 | )
88 | answer = f'{answer}<|end|>\n<|endoftext|>'
89 |
90 | # mask questions for labels
91 | batch = self.processor(prompt, [image], return_tensors='pt')
92 | prompt_input_ids = batch['input_ids']
93 | # Do not add bos token to answer
94 | answer_input_ids = self.processor.tokenizer(
95 | answer, add_special_tokens=False, return_tensors='pt'
96 | )['input_ids']
97 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1)
98 | ignore_index = -100
99 | labels = torch.cat(
100 | [
101 | torch.tensor([ignore_index] * len(prompt_input_ids[0])).unsqueeze(0),
102 | answer_input_ids,
103 | ],
104 | dim=1,
105 | )
106 |
107 | batch['input_ids'] = input_ids
108 | del batch['attention_mask']
109 | batch['labels'] = labels
110 |
111 | return batch
112 |
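# A hedged sketch (not the authors' exact recipe) of wiring LlavaMedVQADataCollator into a
# Hugging Face Trainer; the model path and output_dir below are assumptions based on this repo.
def build_trainer(train_dataset, model_name_or_path='./Phi-3.5-vision-instruct'):
    from transformers import AutoModelForCausalLM, AutoProcessor, Trainer, TrainingArguments

    processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    return Trainer(
        model=model,
        args=TrainingArguments(
            output_dir='./lora_weights/sketch_run',
            per_device_train_batch_size=1,  # the collator above asserts batch_size == 1
            remove_unused_columns=False,    # keep the raw 'image'/'question'/'answer' fields
        ),
        train_dataset=train_dataset,
        data_collator=LlavaMedVQADataCollator(processor),
    )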
113 | # Usage example
114 | if __name__ == "__main__":
115 | from torchvision import transforms
116 |
117 | from tqdm import tqdm
118 |
119 | dataset = LlavaMedAlignDataset(annotation_file='/workspace/LLaVA-Med/align_train.json', vis_root='/workspace/medical_conversation')
120 |
121 | num_samples = 0
122 | for sample in tqdm(dataset):
123 | print(sample['question'], sample['answer'])
124 | num_samples += 1
125 |
126 | print(num_samples)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Uncertainty-Driven Expert Control: Enhancing the Reliability of Medical Vision-Language Models
4 |
5 |
6 |
7 | ## 💡Overview
8 |
9 | **Expert-Controlled Classifier-Free Guidance** is a training-free, expert-in-the-loop framework designed to align MedVLMs with clinical expertise. It integrates token-level uncertainty estimation, a BiomedCLIP-based medical multimodal retrieval-augmented generation (RAG) module, and interactive expert revisions with highlight-based guidance.
10 |
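At a glance, decoding combines two signals: a per-token confidence score taken from the softmax probability of each generated token (which can flag tokens for expert review), and a classifier-free guidance step that contrasts the expert-highlighted input with the plain input. A minimal conceptual sketch is shown below; the actual implementation lives in `src/highlighter_modules/guidance.py` and `demo.py` and may differ in detail.

```
import torch

def token_confidence(step_scores, token_id):
    # softmax probability of the generated token; low values flag uncertain tokens
    return torch.softmax(step_scores, dim=-1)[token_id].item()

def guided_logits(highlighted, plain, scale=1.5):
    # classifier-free guidance (log space): push decoding toward the expert-highlighted branch
    return plain + scale * (highlighted - plain)
```
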
11 | ## 🔨Setup
12 | ### 🔨Installation
13 | ```
14 | conda create -n expert_cfg python=3.10 -y
15 | conda activate expert_cfg
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ### 🔨Pre-trained weights
20 |
21 | #### Baseline Model:
22 | Download the two models into the `Phi-3-vision-128k-instruct` and `Phi-3.5-vision-instruct` folders in the current directory, merging the downloaded files with the custom processing code (`processing_phi3_v_cfg.py`) already provided there.
23 | + Phi-3V: [Huggingface](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)
24 | + Phi-3.5V: [Huggingface](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)
25 |
26 |
27 | #### Medical LoRA:
28 | LoRA weights for our fine-tuned Phi-3V-Med and Phi-3.5V-Med models:
29 |
30 | + Phi-3V-Med: [Huggingface](https://huggingface.co/ecoxial2007/Phi-3V-Med)
31 | + Phi-3.5V-Med: [Huggingface](https://huggingface.co/ecoxial2007/Phi-3.5V-Med)
32 | Download them to the `./lora_weights` folder.
33 |
34 | #### Demo
35 | ```
36 | torchrun --nproc_per_node=1 demo.py \
37 | --bf16 \
38 | --use_lora \
39 | --input_json 'examples/input_queries.json' \
40 | --img_root 'examples/images' \
41 | --save_path 'examples/results.json' \
42 | --output_dir './lora_weights/logs_phi35_pubmed_instruct'
43 | ```
44 |
45 | #### Medical Image & Text Encoder for RAG (optional):
46 |
47 | Download BiomedCLIP and place it in `./src/backbone/BiomedCLIP`.
48 |
49 | BiomedCLIP links:
50 | + [Huggingface](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224)
51 |
52 | **Note**: Downloading the weights directly from Huggingface may run into network issues. To facilitate modifications, we have converted the original `.bin` file to PyTorch's `.pth`; we recommend using the Baiduyun version.
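
For reference, a minimal sketch of loading BiomedCLIP via `open_clip` (assuming the `open_clip_torch` package is installed) and extracting an L2-normalized image embedding; the image path is a placeholder, and you can instead point to the converted `.pth` weights:

```
import torch
from open_clip import create_model_from_pretrained
from PIL import Image

# load BiomedCLIP and its preprocessing transform from the Hugging Face Hub
model, preprocess = create_model_from_pretrained(
    'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
)
model.eval()

image = preprocess(Image.open('examples/images/sample.jpg')).unsqueeze(0)
with torch.no_grad():
    feat = model.encode_image(image)
    feat = feat / feat.norm(dim=-1, keepdim=True)  # normalize for cosine-similarity retrieval
```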
53 |
54 |
55 | ### 📑Data Preparation
56 | Our data mainly comes from the publicly available, free online Pathology Education Informational Resource ([PEIR](https://peir.path.uab.edu/library/index.php?/category/2)) Digital Library.
57 | We test our model on:
58 | + [VQA-RAD](https://osf.io/89kps/)
59 | + [SLAKE](https://www.med-vqa.com/slake/)
60 | + [PathVQA](https://github.com/UCSD-AI4H/PathVQA)
61 |
62 | Medical Alignment and Instruction Tuning:
63 | + [PubMedVision](https://huggingface.co/datasets/FreedomIntelligence/PubMedVision)
64 | + [LLaVA-Med](https://github.com/microsoft/LLaVA-Med)
65 |
66 |
67 | ### Prepare BiomedCLIP Pre-extracted Image Features
68 | Note: We recommend using our pre-extracted BiomedCLIP features. The original images can also be found at the links below:
69 |
70 |
71 | | Dataset | Pre-extracted Features & Original Images |
72 | |----------|------------------------------------------|
73 | | PEIR | [Baiduyun (rename `zzz` to `zip`)](https://pan.baidu.com/s/1sJp_3UzjIIvOiuyMB417GQ?pwd=6666)|
74 | | PEIR BiomedCLIP features, keywords & GPT-3.5 rewritten captions | [Baiduyun](https://pan.baidu.com/s/1pqHhrxLL-ZdgEat0wNwLmQ?pwd=6666)|
75 | | PathVQA | [Baiduyun](https://pan.baidu.com/s/1b1SuDSbsNM1rVGzbx8utvg?pwd=6666)|
76 | | Slake | [Baiduyun](https://pan.baidu.com/s/1mfAoi9_HZkrk7OuyQIn4-w?pwd=6666)|
77 | | RADVQA | [Baiduyun](https://pan.baidu.com/s/1gBjAjq2L-iIMf0j05QsJ-w?pwd=6666)|
78 |
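With the pre-extracted features, retrieval for the RAG module reduces to a cosine-similarity lookup. A minimal sketch (tensor and caption-list names below are placeholders; in `demo.py` the retrieved text is consumed as `example['top_k_captions']`):

```
import torch

def retrieve_top_k_captions(query_feat, corpus_feats, corpus_captions, k=1):
    # query_feat: (d,) normalized BiomedCLIP embedding of the query image
    # corpus_feats: (N, d) normalized pre-extracted features; corpus_captions: N strings
    sims = corpus_feats @ query_feat
    top_idx = torch.topk(sims, k).indices.tolist()
    return [corpus_captions[i] for i in top_idx]
```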
79 |
80 |
81 |
82 |
83 |
84 | ## 📝Acknowledgements
85 | We also reference the excellent repos of [Phi-3CookBook](https://github.com/microsoft/Phi-3CookBook), [HuatuoVision](https://huggingface.co/datasets/FreedomIntelligence/PubMedVision), and [BioMedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224), in addition to other repos specific to the baselines and datasets we examined (see paper).
86 |
87 | ## 📝Citation
88 | If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:
89 | ```
90 | @misc{liang2025uncertaintydriven,
91 | title={Uncertainty-Driven Expert Control: Enhancing the Reliability of Medical Vision-Language Models},
92 | author={Xiao Liang and Di Wang and Zhicheng Jiao and Ronghan Li and Pengfei Yang and Quan Wang and Tat-Seng Chua},
93 | year={2025},
94 | eprint={2507.09209},
95 | archivePrefix={arXiv},
96 | primaryClass={cs.CV}
97 | }
98 | ```
99 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import torch
5 | from accelerate import Accelerator
6 | from accelerate import DataLoaderConfiguration
7 | from tqdm import tqdm
8 | from transformers import (
9 | AutoModelForCausalLM,
10 | AutoProcessor,
11 | BitsAndBytesConfig,
12 | Trainer,
13 | TrainingArguments,
14 | )
15 |
16 | from src.datasets.radvqa import RADVQADataset
17 | from src.highlighter_modules.guidance import ProbCFGLogitsProcessor
18 |
19 |
20 | # suggested deepspeed config
21 | DS_CONFIG_DICT = {
22 | 'zero_optimization': {
23 | 'stage': 2,
24 | 'allgather_partitions': True,
25 | 'allgather_bucket_size': 5e8,
26 | 'overlap_comm': True,
27 | 'reduce_scatter': True,
28 | 'reduce_bucket_size': 5e8,
29 | 'contiguous_gradients': True,
30 | 'round_robin_gradients': True,
31 | },
32 | 'fp16': {
33 | 'enabled': 'auto',
34 | 'loss_scale': 0,
35 | 'loss_scale_window': 1000,
36 | 'initial_scale_power': 16,
37 | 'hysteresis': 2,
38 | 'min_loss_scale': 1,
39 | },
40 | 'bf16': {'enabled': 'auto'},
41 | 'train_micro_batch_size_per_gpu': 'auto',
42 | 'train_batch_size': 'auto',
43 | 'gradient_accumulation_steps': 'auto',
44 | 'gradient_clipping': 'auto',
45 | }
46 |
47 |
48 | def create_dataset(args):
49 | output_file_test = args.input_json
50 | img_root = args.img_root
51 | eval_dataset = RADVQADataset(annotation_file=output_file_test, vis_root=img_root)
52 | return eval_dataset
53 |
54 |
55 |
56 | class NoGradHook:
57 | def __init__(self):
58 | self.prev_enabled = True
59 |
60 | def maybe_enable_grad_hook(self, *_):
61 | torch.set_grad_enabled(self.prev_enabled)
62 |
63 | def disable_grad_hook(self, *_):
64 | self.prev_enabled = torch.is_grad_enabled()
65 | torch.set_grad_enabled(False)
66 |
67 |
68 | def freeze_vision_model(model):
69 | vision_no_grad_hook = NoGradHook()
70 | vision_module = model.model.vision_embed_tokens
71 | vision_module.register_forward_pre_hook(vision_no_grad_hook.disable_grad_hook)
72 | vision_module.register_forward_hook(vision_no_grad_hook.maybe_enable_grad_hook)
73 | for p in vision_module.parameters():
74 | p.requires_grad_(False)
75 |
76 |
77 | def create_model(model_name_or_path, use_flash_attention=False, use_qlora=False):
78 | bnb_config = (
79 | BitsAndBytesConfig(
80 | load_in_4bit=True,
81 | bnb_4bit_quant_type='nf4',
82 | bnb_4bit_compute_dtype=torch.bfloat16 if use_flash_attention else torch.float16,
83 | )
84 | if use_qlora
85 | else None
86 | )
87 |
88 | model = AutoModelForCausalLM.from_pretrained(
89 | model_name_or_path,
90 | # Phi-3-V is originally trained in bf16 + flash attn
91 | # For fp16 mixed precision training, load in f32 to avoid hf accelerate error
92 | torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
93 | trust_remote_code=True,
94 | _attn_implementation='flash_attention_2' if use_flash_attention else 'eager',
95 | quantization_config=bnb_config,
96 | )
97 |
98 | return model
99 |
100 |
101 | @torch.no_grad()
102 | def evaluate(model, processor, eval_dataset, args, save_path=None, disable_tqdm=False):
103 | rank = int(os.environ.get('RANK', 0))
104 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
105 | world_size = int(os.environ.get('WORLD_SIZE', 1))
106 |
107 | model.eval()
108 | answers_unique = []
109 | generated_texts_unique = []
110 |
111 | eval_dataset_shard = eval_dataset
112 |
113 | for i in tqdm(range(len(eval_dataset)), disable=(rank != 0) or disable_tqdm):
114 | # Phi-3-V currently only supports batch_size == 1
115 | example = eval_dataset_shard[i]
116 | answers_unique.append(example['answer'])
118 | image = example['image']
119 | question = example['question']
120 | caption = example["top_k_captions"][0]
121 | highlights = example['highlights']
122 | prompt_message = {
123 | 'role': 'user',
124 | 'content': f'{caption} <|image_1|>\n{question}',
125 | }
126 | prompt = processor.tokenizer.apply_chat_template(
127 | [prompt_message], tokenize=False, add_generation_prompt=True
128 | )
129 |
130 | qs_highlighted_parts = highlights
131 |
132 | inputs = processor(prompt, [image], return_tensors='pt', qs_highlighted_parts=qs_highlighted_parts).to(f'cuda:{local_rank}')
133 | hl_mask_ = inputs['highlight_attention_mask']
134 | hl_mask_[hl_mask_ == 1] = args.perturb_weight
135 | hl_mask_[hl_mask_ == 0] = args.attn
136 | cfg_batched_input = inputs['input_ids'].repeat(2, 1)
137 | pixel_values = inputs['pixel_values'].repeat(2, 1, 1, 1, 1)
138 | image_sizes = inputs['image_sizes'].repeat(2, 1)
139 |
140 | del inputs['highlight_attention_mask']
141 |
142 | generated_outputs = model.generate(
143 | input_ids=cfg_batched_input,
144 | pixel_values=pixel_values,
145 | attention_mask=torch.cat([inputs['attention_mask'], hl_mask_], dim=0),
146 | image_sizes=image_sizes,
147 | eos_token_id=processor.tokenizer.eos_token_id,
148 | max_new_tokens=64,
149 | num_beams=args.num_beams,
150 | logits_processor=[ProbCFGLogitsProcessor(guidance_scale=args.cfg, use_log=True)],
151 | output_scores=True,
152 | return_dict_in_generate=True
153 | )
154 |
155 |
156 | batch_index = 1
157 | prediction = processor.batch_decode(
158 | generated_outputs.sequences[:, inputs['input_ids'].size(1):],
159 | skip_special_tokens=True,
160 | clean_up_tokenization_spaces=False,
161 | )
162 | prediction = prediction[0].strip().strip('.')
163 |
164 | print('Question:', example['question'], 'GT:',example['answer'])
165 | print('Prediction:', prediction)
166 | token_probs = []
167 | generated_texts = []
168 | for scores in generated_outputs.scores:
169 | probs = torch.softmax(scores, dim=-1)
170 | generated_token_id = generated_outputs.sequences[batch_index, inputs['input_ids'].size(1) + len(token_probs)]
171 | token_prob = probs[batch_index, generated_token_id].item()
172 | token_probs.append(token_prob)
173 |
174 | # Print the decoded tokens and their probabilities
175 | print("Generated text and token probabilities:")
176 | for idx, prob in enumerate(token_probs):
177 | token = processor.decode(generated_outputs.sequences[batch_index, inputs['input_ids'].size(1) + idx])
178 | print(f"{token} - Probability: {prob}")
179 | generated_texts.append(token)
180 |
181 | update_information = {
182 | 'question': example['question'],
183 | 'answer': example['answer'],
184 | 'prediction': prediction,
185 | 'token_probs': token_probs,
186 | 'token_preds': generated_texts
187 | }
188 | generated_texts_unique.append(update_information)
189 |
190 |
191 |
192 | if rank == 0:
193 | if save_path:
194 | with open(save_path, 'w') as f:
195 | json.dump(generated_texts_unique, f, indent=4)
196 |
197 |
198 |
199 | def main():
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument(
202 | '--model_name_or_path',
203 | type=str,
204 | # default='./Phi-3-vision-128k-instruct',
205 | default='./Phi-3.5-vision-instruct',
206 | help='Model name or path to load from',
207 | )
208 | parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
209 | parser.add_argument('--bf16', action='store_true', help='Use BF16')
210 | parser.add_argument('--use_lora', action='store_true', help='Use LoRA')
211 | parser.add_argument('--use_qlora', action='store_true', help='Use QLora')
212 | parser.add_argument('--output_dir', type=str, help='Output directory')
213 | parser.add_argument('--save_path', type=str, help='Save json path')
214 | parser.add_argument('--input_json', type=str, help='Question and Answer json path')
215 | parser.add_argument('--img_root', type=str, help='Image Folder')
216 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
217 | parser.add_argument('--num_crops', type=int, default=16, help='Number of maximum image crops')
218 | parser.add_argument('--no-tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
219 | parser.add_argument('--lora_rank', type=int, default=64, help='LoRA rank')
220 | parser.add_argument(
221 | '--lora_alpha_ratio', type=float, default=2, help='LoRA alpha to rank ratio'
222 | )
223 | parser.add_argument('--lora_dropout', type=float, default=0.0, help='LoRA dropout')
224 | parser.add_argument('--freeze_vision_model', action='store_true', help='Freeze vision model')
225 |
226 | parser.add_argument("--num_beams", type=int, default=1)
227 | parser.add_argument("--cfg", type=float, default=1.5)
228 | parser.add_argument("--attn", type=float, default=3.0)
229 | parser.add_argument("--perturb_weight", type=float, default=0.01)
230 |
231 | args = parser.parse_args()
232 | args.attention_weight = args.attn
233 |
234 | assert args.num_crops <= 16, 'num_crops must be less than or equal to 16'
235 | if args.use_qlora:
236 | args.use_lora = True
237 |
238 | dataloader_config = DataLoaderConfiguration(
239 | dispatch_batches=None,
240 | split_batches=False,
241 | even_batches=True,
242 | use_seedable_sampler=True
243 | )
244 | accelerator = Accelerator(dataloader_config)
245 |
246 | with accelerator.local_main_process_first():
247 | processor = AutoProcessor.from_pretrained(
248 | args.model_name_or_path, trust_remote_code=True, num_crops=args.num_crops
249 | )
250 | model = create_model(
251 | args.model_name_or_path,
252 | use_flash_attention=args.use_flash_attention,
253 | use_qlora=args.use_qlora,
254 | )
255 |
256 | eval_dataset = create_dataset(args)
257 |
258 | num_gpus = accelerator.num_processes
259 | print(f'running evaluation on {num_gpus} GPUs')
260 | assert args.batch_size % num_gpus == 0, 'Batch size must be divisible by the number of GPUs'
261 |
262 |
263 | # eval after fine-tuning
264 | if args.use_lora:
265 | # first try to clear GPU memory
266 | del model
267 | __import__('gc').collect()
268 | torch.cuda.empty_cache()
269 |
270 | # reload the model for inference
271 | # this part also serves as an example of how to load a trained model
272 | model = AutoModelForCausalLM.from_pretrained(
273 | args.model_name_or_path,
274 | # Phi-3-V is originally trained in bf16 + flash attn
275 | # For fp16 mixed precision training, load in f32 to avoid hf accelerate error
276 | torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
277 | trust_remote_code=True,
278 | _attn_implementation='flash_attention_2' if args.use_flash_attention else 'eager',
279 | )
280 | model.load_adapter(args.output_dir)
281 |
282 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
283 | model = model.to(f'cuda:{local_rank}')
284 | evaluate(
285 | model,
286 | processor,
287 | eval_dataset,
288 | args,
289 | save_path=args.save_path,
290 | disable_tqdm=not args.tqdm,
291 | )
292 |
293 |
294 |
295 | if __name__ == '__main__':
296 | main()
297 |
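# For reference: a hedged sketch of a probability-space CFG logits processor in the spirit of
# ProbCFGLogitsProcessor (the real class lives in src/highlighter_modules/guidance.py and is not
# shown in this dump; its exact behavior may differ). In the 2-row batch built in evaluate(),
# row 0 carries the plain attention mask and row 1 the highlight-weighted mask.
from transformers import LogitsProcessor


class SketchCFGLogitsProcessor(LogitsProcessor):
    def __init__(self, guidance_scale=1.5, use_log=True):
        self.guidance_scale = guidance_scale
        self.use_log = use_log

    def __call__(self, input_ids, scores):
        plain, highlighted = scores[0], scores[1]
        if self.use_log:
            plain = torch.log_softmax(plain, dim=-1)
            highlighted = torch.log_softmax(highlighted, dim=-1)
        # classifier-free guidance: move the distribution toward the highlighted branch
        fused = plain + self.guidance_scale * (highlighted - plain)
        # broadcast the fused scores to both rows so decoding stays aligned across the batch
        return fused.unsqueeze(0).expand_as(scores).contiguous()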
--------------------------------------------------------------------------------
/Phi-3-vision-128k-instruct/processing_phi3_v_cfg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """
17 | Processor class for Phi3-V.
18 | """
19 | import re
20 | from typing import List, Optional, Union
21 |
22 | import torch
23 |
24 | import transformers
25 | from transformers.feature_extraction_utils import BatchFeature
26 | from transformers.image_utils import ImageInput
27 | from transformers.processing_utils import ProcessorMixin
28 | from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
29 | from src.highlighter_modules.utils import txt_highlight_mask
30 | from transformers.utils import TensorType
31 | from .image_processing_phi3_v import Phi3VImageProcessor
32 |
33 | transformers.Phi3VImageProcessor = Phi3VImageProcessor
34 |
35 |
36 | class Phi3VProcessor(ProcessorMixin):
37 | r"""
38 | Constructs a Phi3-V processor which wraps a Phi3-V image processor and a LLaMa tokenizer into a single processor.
39 |
40 | [`Phi3VProcessor`] offers all the functionalities of [`Phi3VImageProcessor`] and [`LlamaTokenizerFast`]. See the
41 | [`~Phi3VProcessor.__call__`] and [`~Phi3VProcessor.decode`] for more information.
42 |
43 | Args:
44 | image_processor ([`Phi3VImageProcessor`], *optional*):
45 | The image processor is a required input.
46 | tokenizer ([`LlamaTokenizerFast`], *optional*):
47 | The tokenizer is a required input.
48 | """
49 |
50 | attributes = ["image_processor", "tokenizer"]
51 | image_processor_class = "Phi3VImageProcessor"
52 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
53 | special_image_token = "<|image|>"
54 |
55 | def __init__(self, image_processor, tokenizer):
56 | self.image_processor = image_processor
57 | self.tokenizer = tokenizer
58 | self.tokenizer.padding_side = 'left'
59 | self.num_img_tokens = image_processor.num_img_tokens
60 | self.img_tokens = [f"<|image_{i + 1}|>" for i in range(1000000)]
61 |
62 | def __call__(
63 | self,
64 | text: Union[TextInput, List[TextInput]],
65 | images: ImageInput = None,
66 | padding: Union[bool, str, PaddingStrategy] = False,
67 | truncation: Union[bool, str, TruncationStrategy] = None,
68 | max_length=None,
69 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
70 | qs_highlighted_parts: List[str] = None
71 | ) -> BatchFeature:
72 | """
73 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
74 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
75 | the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
76 | Phi3ImageProcessor's [`~Phi3ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
77 | of the above two methods for more information.
78 |
79 | Args:
80 | text (`str`, `List[str]`, `List[List[str]]`):
81 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
82 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
83 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
84 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
85 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
86 | tensor. Both channels-first and channels-last formats are supported.
87 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
88 | Select a strategy to pad the returned sequences (according to the model's padding side and padding
89 | index) among:
90 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
91 | sequence is provided).
92 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
93 | acceptable input length for the model if that argument is not provided.
94 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
95 | lengths).
96 | max_length (`int`, *optional*):
97 | Maximum length of the returned list and optionally padding length (see above).
98 | truncation (`bool`, *optional*):
99 | Activates truncation to cut input sequences longer than `max_length` to `max_length`.
100 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
101 | If set, will return tensors of a particular framework. Acceptable values are:
102 |
103 | - `'tf'`: Return TensorFlow `tf.constant` objects.
104 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
105 | - `'np'`: Return NumPy `np.ndarray` objects.
106 | - `'jax'`: Return JAX `jnp.ndarray` objects.
107 |
108 | Returns:
109 | [`BatchFeature`]: A [`BatchFeature`] with the following fields:
110 |
111 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
112 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
113 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
114 | `None`).
115 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
116 | """
117 | if images is not None:
118 | image_inputs = self.image_processor(images, return_tensors=return_tensors)
119 | else:
120 | image_inputs = {}
121 | inputs = self._convert_images_texts_to_inputs(image_inputs, text, padding=padding, truncation=truncation,
122 | max_length=max_length, return_tensors=return_tensors,
123 | qs_highlighted_parts=qs_highlighted_parts)
124 | return inputs
125 |
126 | def calc_num_image_tokens(self, images: ImageInput):
127 | """ Calculate the number of image tokens for each image.
128 | Args:
129 | images (`ImageInput`):
130 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
131 | passing in images with pixel values between 0 and 1, set `do_rescale=False`.
132 | """
133 | return self.image_processor.calc_num_image_tokens(images)
134 |
135 | def calc_num_image_tokens_from_image_size(self, width, height):
136 | """ Calculate the number of image token for an image with given width and height.
137 | Args:
138 | width (`int`):
139 | Width of the image.
140 | height (`int`):
141 | Height of the image.
142 | """
143 | return self.image_processor.calc_num_image_tokens_from_image_size(width, height)
144 |
145 | @property
146 | def special_image_token_id(self):
147 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
148 |
149 | def get_special_image_token_id(self):
150 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
151 |
152 | def tokenize_and_create_masks(self, pattern, texts, highlighted_parts):
153 | # Split the text and keep the image tags
154 | prompt_chunks = re.split(pattern, texts)
155 | image_tags = re.findall(pattern, texts)
156 |
157 | input_ids = []
158 | highlight_attention_mask = []
159 |
160 | # Process each chunk and its corresponding image tag
161 | for i, chunk in enumerate(prompt_chunks):
162 | # Tokenize the chunk
163 | chunk_ids = self.tokenizer.encode(chunk, add_special_tokens=False)
164 |
165 | # Build the highlight mask for this chunk
166 | chunk_highlight_mask = [0] * len(chunk_ids)
167 | for part in highlighted_parts:
168 | start = 0
169 | while start < len(chunk):
170 | start_index = chunk.find(part, start)
171 | if start_index == -1:
172 | break
173 | end_index = start_index + len(part)
174 | # Convert character indices to token indices
175 | start_token_idx = len(self.tokenizer.encode(chunk[:start_index], add_special_tokens=False))
176 | end_token_idx = len(self.tokenizer.encode(chunk[:end_index], add_special_tokens=False))
177 | for idx in range(start_token_idx, end_token_idx):
178 | if idx < len(chunk_highlight_mask):
179 | chunk_highlight_mask[idx] = 1
180 | start = end_index
181 |
182 | # Append the chunk's token ids and highlight mask to the overall lists
183 | input_ids.extend(chunk_ids)
184 | highlight_attention_mask.extend(chunk_highlight_mask)
185 |
186 | # If an image tag follows this chunk, process it as well
187 | if i < len(image_tags):
188 | image_tag_ids = self.tokenizer.encode(image_tags[i], add_special_tokens=False)
189 | input_ids.extend(image_tag_ids)
190 | # Image tags are never highlighted
191 | highlight_attention_mask.extend([0] * len(image_tag_ids))
192 |
193 | return input_ids, highlight_attention_mask
194 | # Convert to torch tensors
195 | # input_ids_tensor = torch.tensor(input_ids).unsqueeze(0)
196 | # highlight_attention_mask_tensor = torch.tensor(highlight_attention_mask).unsqueeze(0)
197 | #
198 | #
199 | # return input_ids_tensor, highlight_attention_mask_tensor
200 |
201 | def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None,
202 | return_tensors=None, qs_highlighted_parts=[]):
203 |
204 | if not len(images):
205 | model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation,
206 | max_length=max_length)
207 | return BatchFeature(data={**model_inputs})
208 |
209 | # print(texts, qs_highlighted_parts, '!!!!!this is in processing.py')
210 |
211 | pattern = r"<\|image_\d+\|>"
212 |
213 | if 'num_img_tokens' in images:
214 | num_img_tokens = images['num_img_tokens']
215 | else:
216 | assert 'num_crops' in images, 'num_crops must be provided in images if num_img_tokens is not provided'
217 | num_crops = images['num_crops']
218 | num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops]
219 |
220 | images, image_sizes = images['pixel_values'], images['image_sizes']
221 |
222 | # image_tags needs to start from 1 to n
223 | image_tags = re.findall(pattern, texts)
224 | # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
225 | # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
226 | image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
227 | unique_image_ids = sorted(list(set(image_ids)))
228 | # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
229 | # check the condition
230 | assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
231 | # total images must be the same as the number of image tags
232 | assert len(unique_image_ids) == len(
233 | images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
234 |
235 | image_ids_pad = [[-iid] * num_img_tokens[iid - 1] for iid in image_ids]
236 |
237 | def insert_separator(X, sep_list):
238 | if len(X) > len(sep_list):
239 | sep_list.append([])
240 | return [ele for sublist in zip(X, sep_list) for ele in sublist]
241 |
242 | highlight_attention_mask = []
243 | prompt_chunks = []
244 |
245 |
246 | # Generate highlight masks for each chunk
247 | for chunk in re.split(pattern, texts):
248 | chunk_mask, _ = txt_highlight_mask(self.tokenizer, chunk, qs_highlighted_parts)
249 | highlight_attention_mask.append([0] + chunk_mask)
250 | a = self.tokenizer(chunk)
251 | prompt_chunks.append(a.input_ids)
252 |
253 | offset = 0
254 | input_ids = []
255 | combined_highlight_mask = []
256 | zero_mask_padding = [[0] * len(pad) for pad in
257 | image_ids_pad] # Create zero padding mask with the same length as image ids
258 |
259 | for tokens, mask in zip(insert_separator(prompt_chunks, image_ids_pad),
260 | insert_separator(highlight_attention_mask,
261 | zero_mask_padding)): # Use zero_mask_padding here
262 | input_ids.extend(tokens[offset:])
263 | combined_highlight_mask.extend(mask[offset:])
264 |
265 | input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
266 | attention_mask = (input_ids > -1000000).to(torch.long)
267 | combined_highlight_mask = torch.tensor(combined_highlight_mask, dtype=torch.long).unsqueeze(0)
268 |
269 | return BatchFeature(data={"input_ids": input_ids,
270 | "attention_mask": attention_mask,
271 | "highlight_attention_mask": combined_highlight_mask,
272 | "pixel_values": images,
273 | "image_sizes": image_sizes})
274 |
275 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
276 | def batch_decode(self, *args, **kwargs):
277 | """
278 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
279 | refer to the docstring of this method for more information.
280 | """
281 | return self.tokenizer.batch_decode(*args, **kwargs)
282 |
283 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
284 | def decode(self, *args, **kwargs):
285 | """
286 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
287 | the docstring of this method for more information.
288 | """
289 | return self.tokenizer.decode(*args, **kwargs)
290 |
291 | @property
292 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
293 | def model_input_names(self):
294 | tokenizer_input_names = self.tokenizer.model_input_names
295 | image_processor_input_names = self.image_processor.model_input_names
296 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
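# A toy, self-contained walk-through (not part of the original file) of how
# _convert_images_texts_to_inputs splices an image tag into the token stream: the tag
# <|image_1|> becomes num_img_tokens[0] copies of the negative id -1, and
# attention_mask = (input_ids > -1000000) stays 1 on those placeholders so the model can
# later swap them for visual embeddings. Token id values below are illustrative only.
def _placeholder_splicing_example():
    prompt_chunks = [[1, 310], [1, 1724, 947]]   # toy token ids around one <|image_1|> tag
    image_ids_pad = [[-1, -1, -1, -1]]           # -1 repeated num_img_tokens[0] times

    def insert_separator(X, sep_list):
        if len(X) > len(sep_list):
            sep_list.append([])
        return [ele for sublist in zip(X, sep_list) for ele in sublist]

    input_ids = []
    for x in insert_separator(prompt_chunks, image_ids_pad):
        input_ids.extend(x)

    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
    attention_mask = (input_ids > -1000000).to(torch.long)
    # input_ids      -> [[1, 310, -1, -1, -1, -1, 1, 1724, 947]]
    # attention_mask -> [[1,   1,  1,  1,  1,  1, 1,    1,   1]]
    return input_ids, attention_mask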
--------------------------------------------------------------------------------
/Phi-3.5-vision-instruct/processing_phi3_v_cfg.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """
17 | Processor class for Phi3-V.
18 | """
19 | import re
20 | from typing import List, Optional, Union
21 |
22 | import torch
23 |
24 | import transformers
25 | from transformers.feature_extraction_utils import BatchFeature
26 | from transformers.image_utils import ImageInput
27 | from transformers.processing_utils import ProcessorMixin
28 | from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
29 | from transformers.utils import TensorType
30 | from src.highlighter_modules.utils import txt_highlight_mask
31 |
32 | """Image processor class for Phi3-V."""
33 |
34 | from typing import List, Optional, Union
35 |
36 | import numpy as np
37 |
38 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
39 | from transformers.image_transforms import (
40 | convert_to_rgb,
41 | )
42 | from transformers.image_utils import (
43 | OPENAI_CLIP_MEAN,
44 | OPENAI_CLIP_STD,
45 | ImageInput,
46 | make_list_of_images,
47 | valid_images,
48 | )
49 | from transformers.utils import TensorType, is_vision_available, logging
50 |
51 | from transformers import AutoImageProcessor
52 |
53 | logger = logging.get_logger(__name__)
54 |
55 |
56 | if is_vision_available():
57 | from PIL import Image
58 |
59 | import torch
60 | import torchvision
61 |
62 | def padding_336(b):
63 | width, height = b.size
64 | tar = int(np.ceil(height / 336) * 336)
65 | top_padding = int((tar - height)/2)
66 | bottom_padding = tar - height - top_padding
67 | left_padding = 0
68 | right_padding = 0
69 | b = torchvision.transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255,255,255])
70 |
71 | return b
72 |
73 | def calc_padded_size(width, height, padding_unit=336):
74 | target_height = int(np.ceil(height / padding_unit) * padding_unit)
75 | top_padding = int((target_height - height) / 2)
76 | bottom_padding = target_height - height - top_padding
77 | left_padding = 0
78 | right_padding = 0
79 | padded_width = width + left_padding + right_padding
80 | padded_height = height + top_padding + bottom_padding
81 | return padded_width, padded_height
82 |
83 | def HD_transform(img, hd_num=16):
84 | width, height = img.size
85 | trans = False
86 | if width < height:
87 | img = img.transpose(Image.TRANSPOSE)
88 | trans = True
89 | width, height = img.size
90 | ratio = (width/ height)
91 | scale = 1
92 | while scale*np.ceil(scale/ratio) <= hd_num:
93 | scale += 1
94 | scale -= 1
95 | new_w = int(scale * 336)
96 | new_h = int(new_w / ratio)
97 |
98 | img = torchvision.transforms.functional.resize(img, [new_h, new_w],)
99 | img = padding_336(img)
100 | width, height = img.size
101 | if trans:
102 | img = img.transpose(Image.TRANSPOSE)
103 |
104 | return img
105 |
106 | def calc_hd_transform_size(width, height, hd_num=16):
107 | transposed = False
108 | if width < height:
109 | width, height = height, width
110 | transposed = True
111 |
112 | ratio = width / height
113 | scale = 1
114 | while scale * np.ceil(scale / ratio) <= hd_num:
115 | scale += 1
116 | scale -= 1
117 |
118 | new_width = int(scale * 336)
119 | new_height = int(new_width / ratio)
120 |
121 | padded_width, padded_height = calc_padded_size(new_width, new_height)
122 |
123 | if transposed:
124 | padded_width, padded_height = padded_height, padded_width
125 |
126 | return padded_width, padded_height
127 |
128 | def pad_to_max_num_crops_tensor(images, max_crops=5):
129 | """
130 | images: B x 3 x H x W, B<=max_crops
131 | """
132 | B, _, H, W = images.shape
133 | if B < max_crops:
134 | pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
135 | images = torch.cat([images, pad], dim=0)
136 | return images
137 |
138 |
139 | class Phi3VImageProcessor(BaseImageProcessor):
140 | r"""
141 | Constructs a Phi3 image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
142 | for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512)
143 |
144 | Args:
145 | image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
146 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of
147 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
148 | image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
149 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
150 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
151 | Can be overridden by the `image_std` parameter in the `preprocess` method.
152 | do_convert_rgb (`bool`, *optional*, defaults to `True`):
153 | Whether to convert the image to RGB.
154 | """
155 |
156 | model_input_names = ["pixel_values"]
157 |
158 | def __init__(
159 | self,
160 | num_crops: int = 1,
161 | image_mean: Optional[Union[float, List[float]]] = None,
162 | image_std: Optional[Union[float, List[float]]] = None,
163 | do_convert_rgb: bool = True,
164 | **kwargs,
165 | ) -> None:
166 | super().__init__(**kwargs)
167 | self.num_crops = num_crops
168 | self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
169 | self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
170 | self.do_convert_rgb = do_convert_rgb
171 |
172 | def calc_num_image_tokens(
173 | self,
174 | images: ImageInput
175 | ):
176 | """ Calculate the number of image tokens for each image.
177 | Args:
178 | images (`ImageInput`):
179 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
180 | passing in images with pixel values between 0 and 1, set `do_rescale=False`.
181 | """
182 | images = make_list_of_images(images)
183 |
184 | if not valid_images(images):
185 | raise ValueError(
186 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
187 | "torch.Tensor, tf.Tensor or jax.ndarray."
188 | )
189 |
190 | images = [image.convert('RGB') for image in images]
191 | # (H, W, C)
192 | elems = [HD_transform(im, hd_num = self.num_crops) for im in images]
193 | shapes = [[im.size[1], im.size[0]] for im in elems]
194 | num_img_tokens = [int((h//336*w//336+1)*144 + 1 + (h//336+1)*12) for h, w in shapes]
195 | return num_img_tokens
196 |
197 | def calc_num_image_tokens_from_image_size(self, width, height):
198 | """
199 | Calculate the number of image tokens for a given image size.
200 | Args:
201 | width (`int`): Width of the image.
202 | height (`int`): Height of the image.
203 | """
204 | new_width, new_height = calc_hd_transform_size(width, height, hd_num=self.num_crops)
205 | num_img_tokens = int((new_height // 336 * new_width // 336 + 1) * 144 + 1 + (new_height // 336 + 1) * 12)
206 | return num_img_tokens
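# Worked example of the count above (a hedged illustration, not part of the original file):
# with num_crops=4, a 2:1 landscape image is HD-transformed to new_width=672, new_height=336,
# giving (336//336 * 672 // 336 + 1) * 144 + 1 + (336//336 + 1) * 12
#        = (1*2 + 1) * 144 + 1 + 2 * 12 = 432 + 1 + 24 = 457 image tokens.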
207 |
208 | def preprocess(
209 | self,
210 | images: ImageInput,
211 | image_mean: Optional[Union[float, List[float]]] = None,
212 | image_std: Optional[Union[float, List[float]]] = None,
213 | do_convert_rgb: bool = None,
214 | return_tensors: Optional[Union[str, TensorType]] = None,
215 | ):
216 | """
217 | Args:
218 | images (`ImageInput`):
219 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
220 | passing in images with pixel values between 0 and 1, set `do_rescale=False`.
221 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
222 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
223 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
224 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
225 | `True`.
226 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
227 | Whether to convert the image to RGB.
228 | return_tensors (`str` or `TensorType`, *optional*):
229 | The type of tensors to return. Can be one of:
230 | - Unset: Return a list of `np.ndarray`.
231 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
232 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
233 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
234 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
235 | """
236 | image_mean = image_mean if image_mean is not None else self.image_mean
237 | image_std = image_std if image_std is not None else self.image_std
238 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
239 |
240 | images = make_list_of_images(images)
241 |
242 | if not valid_images(images):
243 | raise ValueError(
244 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
245 | "torch.Tensor, tf.Tensor or jax.ndarray."
246 | )
247 |
248 | if do_convert_rgb:
249 | images = [convert_to_rgb(image) for image in images]
250 |
251 | image_sizes = []
252 | img_processor = torchvision.transforms.Compose([
253 | torchvision.transforms.ToTensor(),
254 | torchvision.transforms.Normalize(image_mean, image_std)
255 | ])
256 |
257 | # PIL images
258 | # HD_transform pads images to a size that is a multiple of 336
259 | # convert to RGB first
260 | images = [image.convert('RGB') for image in images]
261 | elems = [HD_transform(im, hd_num = self.num_crops) for im in images]
262 | # tensor transform and normalize
263 | hd_images = [img_processor(im) for im in elems]
264 | # create global image
265 | global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336), mode='bicubic',).to(im.dtype) for im in hd_images]
266 |
267 | # [(3, h, w)], where h, w is multiple of 336
268 | shapes = [[im.size(1), im.size(2)] for im in hd_images]
269 | num_img_tokens = [int(((h//336)*(w//336)+1)*144 + 1 + (h//336+1)*12) for h, w in shapes]
270 | # reshape to channel dimension -> (num_images, num_crops, 3, 336, 336)
271 | # (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
272 | hd_images_reshape = [im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336, 336).contiguous() for im, (h, w) in zip(hd_images, shapes)]
273 | # concat global image and local image
274 | hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)]
275 |
276 | # pad to max_num_crops
277 | image_transformed = [pad_to_max_num_crops_tensor(im, self.num_crops+1) for im in hd_images_reshape]
278 | image_transformed = torch.stack(image_transformed, dim=0)
279 | image_sizes = [torch.LongTensor(_shapes) for _shapes in shapes]
280 | padded_images = image_transformed
281 | image_sizes = shapes
282 |
283 | data = {"pixel_values": padded_images,
284 | "image_sizes": image_sizes,
285 | "num_img_tokens": num_img_tokens
286 | }
287 |
288 | return BatchFeature(data=data, tensor_type=return_tensors)
289 |
290 | AutoImageProcessor.register("Phi3VImageProcessor", Phi3VImageProcessor)
291 |
292 | transformers.Phi3VImageProcessor = Phi3VImageProcessor
293 |
294 | class Phi3VProcessor(ProcessorMixin):
295 | r"""
296 | Constructs a Phi3-V processor which wraps a Phi3-V image processor and a LLaMa tokenizer into a single processor.
297 |
298 | [`Phi3VProcessor`] offers all the functionalities of [`Phi3VImageProcessor`] and [`LlamaTokenizerFast`]. See the
299 | [`~Phi3VProcessor.__call__`] and [`~Phi3VProcessor.decode`] for more information.
300 |
301 | Args:
302 | image_processor ([`Phi3VImageProcessor`], *optional*):
303 | The image processor is a required input.
304 | tokenizer ([`LlamaTokenizerFast`], *optional*):
305 | The tokenizer is a required input.
306 | """
307 |
308 | attributes = ["image_processor", "tokenizer"]
309 | image_processor_class = "Phi3VImageProcessor"
310 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
311 | special_image_token = "<|image|>"
312 |
313 | def __init__(self, image_processor, tokenizer):
314 | self.image_processor = image_processor
315 | self.tokenizer = tokenizer
316 | self.num_img_tokens = image_processor.num_img_tokens
317 | self.img_tokens = [f"<|image_{i+1}|>" for i in range(1000000)]
318 |
319 | def __call__(
320 | self,
321 | text: Union[TextInput, List[TextInput]],
322 | images: ImageInput = None,
323 | padding: Union[bool, str, PaddingStrategy] = False,
324 | truncation: Union[bool, str, TruncationStrategy] = None,
325 | max_length=None,
326 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
327 | qs_highlighted_parts: List[str] = None
328 | ) -> BatchFeature:
329 | """
330 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
331 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
332 | the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
333 | Phi3ImageProcessor's [`~Phi3ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
334 | of the above two methods for more information.
335 |
336 | Args:
337 | text (`str`, `List[str]`, `List[List[str]]`):
338 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
339 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
340 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
341 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
342 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
343 | tensor. Both channels-first and channels-last formats are supported.
344 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
345 | Select a strategy to pad the returned sequences (according to the model's padding side and padding
346 | index) among:
347 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
348 | sequence is provided).
349 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
350 | acceptable input length for the model if that argument is not provided.
351 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
352 | lengths).
353 | max_length (`int`, *optional*):
354 | Maximum length of the returned list and optionally padding length (see above).
355 | truncation (`bool`, *optional*):
356 | Activates truncation to cut input sequences longer than `max_length` to `max_length`.
357 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
358 | If set, will return tensors of a particular framework. Acceptable values are:
359 |
360 | - `'tf'`: Return TensorFlow `tf.constant` objects.
361 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
362 | - `'np'`: Return NumPy `np.ndarray` objects.
363 | - `'jax'`: Return JAX `jnp.ndarray` objects.
364 |
365 | Returns:
366 | [`BatchFeature`]: A [`BatchFeature`] with the following fields:
367 |
368 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
369 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
370 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
371 | `None`).
372 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
373 | """
374 | if images is not None:
375 | image_inputs = self.image_processor(images, return_tensors=return_tensors)
376 | else:
377 | image_inputs = {}
378 | inputs = self._convert_images_texts_to_inputs(image_inputs, text, padding=padding, truncation=truncation,
379 | max_length=max_length, return_tensors=return_tensors,
380 | qs_highlighted_parts=qs_highlighted_parts)
381 |
382 | return inputs
383 |
384 | def calc_num_image_tokens(self, images: ImageInput):
385 | """ Calculate the number of image tokens for each image.
386 | Args:
387 | images (`ImageInput`):
388 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
389 | passing in images with pixel values between 0 and 1, set `do_rescale=False`.
390 | """
391 | return self.image_processor.calc_num_image_tokens(images)
392 |
393 | def calc_num_image_tokens_from_image_size(self, width, height):
394 | """ Calculate the number of image token for an image with given width and height.
395 | Args:
396 | width (`int`):
397 | Width of the image.
398 | height (`int`):
399 | Height of the image.
400 | """
401 | return self.image_processor.calc_num_image_tokens_from_image_size(width, height)
402 |
403 |
404 | @property
405 | def special_image_token_id(self):
406 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
407 |
408 | def get_special_image_token_id(self):
409 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
410 |
411 | def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None, qs_highlighted_parts=[]):
412 |
413 | if not len(images):
414 | model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length, padding_side='left')
415 | return BatchFeature(data={**model_inputs})
416 |
417 | pattern = r"<\|image_\d+\|>"
418 | # prompt_chunks = [self.tokenizer(chunk).input_ids for chunk in re.split(pattern, texts)]
419 |
420 | if 'num_img_tokens' in images:
421 | num_img_tokens = images['num_img_tokens']
422 | else:
423 | assert 'num_crops' in images, 'num_crops must be provided in images if num_img_tokens is not provided'
424 | num_crops = images['num_crops']
425 | num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops]
426 |
427 | images, image_sizes = images['pixel_values'], images['image_sizes']
428 |
429 | # image_tags needs to start from 1 to n
430 | image_tags = re.findall(pattern, texts)
431 | # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
432 | # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
433 | image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
434 | unique_image_ids = sorted(list(set(image_ids)))
435 | # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
436 | # check the condition
437 | assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
438 | # total images must be the same as the number of image tags
439 | assert len(unique_image_ids) == len(images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
440 |
441 | image_ids_pad = [[-iid]*num_img_tokens[iid-1] for iid in image_ids]
442 |
443 | def insert_separator(X, sep_list):
444 | if len(X) > len(sep_list):
445 | sep_list.append([])
446 | return [ele for sublist in zip(X, sep_list) for ele in sublist]
447 |
448 | # input_ids = []
449 | # offset = 0
450 | # for x in insert_separator(prompt_chunks, image_ids_pad):
451 | # input_ids.extend(x[offset:])
452 | #
453 | # input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
454 | # attention_mask = (input_ids > -1000000).to(torch.long)
455 | #
456 | # return BatchFeature(data={"input_ids": input_ids,
457 | # "attention_mask": attention_mask,
458 | # "pixel_values": images,
459 | # "image_sizes": image_sizes})
460 | highlight_attention_mask = []
461 | prompt_chunks = []
462 | # Generate highlight masks for each chunk
463 | for chunk in re.split(pattern, texts):
464 | chunk_mask, _ = txt_highlight_mask(self.tokenizer, chunk, qs_highlighted_parts)
465 | highlight_attention_mask.append([0] + chunk_mask)
466 | a = self.tokenizer(chunk)
467 | prompt_chunks.append(a.input_ids)
468 |
469 | offset = 0
470 | input_ids = []
471 | combined_highlight_mask = []
472 | zero_mask_padding = [[0] * len(pad) for pad in
473 | image_ids_pad] # Create zero padding mask with the same length as image ids
474 |
475 | for tokens, mask in zip(insert_separator(prompt_chunks, image_ids_pad),
476 | insert_separator(highlight_attention_mask,
477 | zero_mask_padding)): # Use zero_mask_padding here
478 | input_ids.extend(tokens[offset:])
479 | combined_highlight_mask.extend(mask[offset:])
480 |
481 | input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
482 | attention_mask = (input_ids > -1000000).to(torch.long)
483 | combined_highlight_mask = torch.tensor(combined_highlight_mask, dtype=torch.long).unsqueeze(0)
484 |
485 | return BatchFeature(data={"input_ids": input_ids,
486 | "attention_mask": attention_mask,
487 | "highlight_attention_mask": combined_highlight_mask,
488 | "pixel_values": images,
489 | "image_sizes": image_sizes})
490 |
491 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
492 | def batch_decode(self, *args, **kwargs):
493 | """
494 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
495 | refer to the docstring of this method for more information.
496 | """
497 | return self.tokenizer.batch_decode(*args, **kwargs)
498 |
499 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
500 | def decode(self, *args, **kwargs):
501 | """
502 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
503 | the docstring of this method for more information.
504 | """
505 | return self.tokenizer.decode(*args, **kwargs)
506 |
507 | @property
508 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
509 | def model_input_names(self):
510 | tokenizer_input_names = self.tokenizer.model_input_names
511 | image_processor_input_names = self.image_processor.model_input_names
512 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
--------------------------------------------------------------------------------