├── data ├── .gitkeep ├── LLaVA-Med │ └── .gitkeep ├── PathVQA │ └── .gitkeep ├── RadVQA │ └── .gitkeep ├── Slake │ └── .gitkeep └── PubMedVision │ └── .gitkeep ├── src ├── __init__.py ├── datasets │ ├── __init__.py │ ├── mscxr.py │ ├── radvqa.py │ ├── slakevqa.py │ ├── pubmed.py │ ├── pathvqa.py │ └── llavamed.py └── highlighter_modules │ ├── guidance.py │ └── utils.py ├── lora_weights └── .gitkeep ├── examples ├── results │ └── .gitkeep ├── images │ ├── chest_xray.jpg │ └── kidney_tissue.jpg └── input_queries.json ├── requirements.txt ├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml ├── Expert-CFG.iml └── misc.xml ├── Phi-3.5-vision-instruct ├── processor_config.json └── processing_phi3_v_cfg.py ├── Phi-3-vision-128k-instruct ├── preprocessor_config.json └── processing_phi3_v_cfg.py ├── .github └── workflows │ └── static.yml ├── README.md └── demo.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/LLaVA-Med/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/PathVQA/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/RadVQA/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/Slake/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lora_weights/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/PubMedVision/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/results/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/images/chest_xray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecoxial2007/Expert-CFG/HEAD/examples/images/chest_xray.jpg -------------------------------------------------------------------------------- /examples/images/kidney_tissue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecoxial2007/Expert-CFG/HEAD/examples/images/kidney_tissue.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flash_attn==2.5.8 2 | 
numpy==1.24.4 3 | Pillow==10.3.0 4 | Requests==2.31.0 5 | torch==2.3.0 6 | torchvision==0.18.0 7 | transformers==4.43.0 8 | accelerate==0.30.0 -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Phi-3.5-vision-instruct/processor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auto_map": { 3 | "AutoProcessor": "processing_phi3_v_cfg.Phi3VProcessor" 4 | }, 5 | "processor_class": "Phi3VProcessor" 6 | } 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/Expert-CFG.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /Phi-3-vision-128k-instruct/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auto_map": { 3 | "AutoProcessor": "processing_phi3_v_cfg.Phi3VProcessor", 4 | "AutoImageProcessor": "image_processing_phi3_v.Phi3VImageProcessor" 5 | }, 6 | "num_crops": 16, 7 | "image_mean": [ 8 | 0.48145466, 9 | 0.4578275, 10 | 0.40821073 11 | ], 12 | "image_processor_type": "Phi3VImageProcessor", 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "processor_class": "Phi3VProcessor", 19 | "num_img_tokens": 144 20 | } -------------------------------------------------------------------------------- /examples/input_queries.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "img_id": "chest_xray.jpg", 4 | "question": "Is there evidence of an aortic aneurysm?", 5 | "highlights": ["aortic aneurysm"], 6 | "answer_type": "CLOSED", 7 | "answer": "yes", 8 | "top_k_captions": ["The abnormality in this image is related to the cardiovascular system. This is a chest X-ray image. 
There is evidence of an aortic aneurysm."] 9 | }, 10 | { 11 | "img_id": "kidney_tissue.jpg", 12 | "question": "What does this image show?", 13 | "highlights": ["giant cell", "fibrosis"], 14 | "question_type": "what", 15 | "answer": "giant cells and fibrosis", 16 | "answer_id": 1742, 17 | "answer_type": "OPEN", 18 | "top_k_captions": ["In the image, there appears to be microscopic view of a kidney tissue sample under low magnification. Features such as giant cells and fibrosis can be observed, possibly indicating a pathological condition such as sarcoidosis affecting the urinary system."] 19 | } 20 | ] -------------------------------------------------------------------------------- /src/highlighter_modules/guidance.py: -------------------------------------------------------------------------------- 1 | # file for highlight guidance. 2 | # REF: CFG-LLM: Stay on topic with Classifier-Free Guidance 3 | # https://arxiv.org/abs/2306.17806 4 | from transformers.generation.logits_process import LogitsProcessor 5 | import torch 6 | 7 | 8 | class ProbCFGLogitsProcessor(LogitsProcessor): 9 | def __init__( 10 | self, 11 | guidance_scale: float, 12 | use_log: bool = False, # whether to use log softmax. 13 | ): 14 | self.guidance_scale = guidance_scale 15 | self.use_log = use_log 16 | 17 | def __call__(self, input_ids, scores): 18 | if self.use_log: 19 | scores = torch.nn.functional.log_softmax(scores, dim=-1) 20 | else: 21 | scores = torch.nn.functional.softmax(scores, dim=-1) 22 | 23 | bs = input_ids.shape[0] // 2 24 | cond_logits, uncond_logits = scores[:bs], scores[bs:] 25 | cond_logits = ( 26 | self.guidance_scale * (cond_logits - uncond_logits) + uncond_logits 27 | ) 28 | 29 | # directly copy two. 30 | logits = torch.cat([cond_logits, cond_logits], dim=0) 31 | return logits 32 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | # Simple workflow for deploying static content to GitHub Pages 2 | name: Deploy static content to Pages 3 | 4 | on: 5 | # Runs on pushes targeting the default branch 6 | push: 7 | branches: ["main"] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 13 | permissions: 14 | contents: read 15 | pages: write 16 | id-token: write 17 | 18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 20 | concurrency: 21 | group: "pages" 22 | cancel-in-progress: false 23 | 24 | jobs: 25 | # Single deploy job since we're just deploying 26 | deploy: 27 | environment: 28 | name: github-pages 29 | url: ${{ steps.deployment.outputs.page_url }} 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v4 34 | - name: Setup Pages 35 | uses: actions/configure-pages@v5 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v3 38 | with: 39 | # Upload entire repository 40 | path: '.' 41 | - name: Deploy to GitHub Pages 42 | id: deployment 43 | uses: actions/deploy-pages@v4 44 | -------------------------------------------------------------------------------- /src/highlighter_modules/utils.py: -------------------------------------------------------------------------------- 1 | # utilization functions for the highlighter. 
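# Expected behavior of txt_highlight_mask below (an illustrative sketch): for
#   txt_prompt = "Is there evidence of an aortic aneurysm?"
#   highlighted_idx_range = ["aortic aneurysm"]
# it returns a 0/1 mask over the prompt's tokens (1 for the sub-word tokens covered
# by "aortic aneurysm", 0 elsewhere) together with the token list itself; entries of
# highlighted_idx_range may also be [start, end] character index pairs instead of strings.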
2 | 3 | # return the mask and tokens for the highlighted text prompt. 4 | def txt_highlight_mask(tokenizer, txt_prompt, highlighted_idx_range): 5 | # Convert text to tokens 6 | tokens = tokenizer.tokenize(txt_prompt) 7 | 8 | # Initialize the mask 9 | mask = [0] * len(tokens) 10 | 11 | # Convert highlighted_idx_range to integer ranges 12 | ranges = [] 13 | for idx_range in highlighted_idx_range: 14 | if isinstance(idx_range, str): 15 | # Add a space before the string to avoid partial matches 16 | if idx_range[0] != " ": 17 | idx_range = " " + idx_range 18 | start_idx = txt_prompt.find(idx_range) 19 | if start_idx == -1: 20 | start_idx = txt_prompt.find( 21 | idx_range[1:] 22 | ) # remove the space and try again 23 | if start_idx == -1: 24 | continue # Skip if the string is not found 25 | end_idx = start_idx + len(idx_range) 26 | ranges.append((start_idx, end_idx)) 27 | elif isinstance(idx_range, list) and len(idx_range) == 2: 28 | ranges.append((idx_range[0], idx_range[1])) 29 | 30 | # Mark the highlighted ranges in the mask 31 | for start_idx, end_idx in ranges: 32 | start_token_idx = len(tokenizer.tokenize(txt_prompt[:start_idx])) 33 | end_token_idx = len(tokenizer.tokenize(txt_prompt[:end_idx])) 34 | # TODO: Include the start and end tokens that partially overlap with the highlighted area 35 | for i in range(start_token_idx, end_token_idx): 36 | mask[i] = 1 37 | 38 | return mask, tokens 39 | -------------------------------------------------------------------------------- /src/datasets/mscxr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | class MSCXDataset(Dataset): 7 | def __init__(self, json_path, img_root, transform=None): 8 | """ 9 | 初始化数据集类。 10 | 11 | 参数: 12 | json_path (str): JSON文件的路径,用于读取标注信息。 13 | img_root (str): 图像根目录的路径。 14 | transform (callable, optional): 图像转换方法,应用在每个PIL图像上。 15 | """ 16 | # 1. 读取并解析JSON文件 17 | with open(json_path, 'r') as file: 18 | data = json.load(file) 19 | 20 | # 2. 提取图像和注释信息 21 | self.images = data["images"] 22 | self.annotations = data["annotations"] 23 | self.img_root = img_root 24 | self.transform = transform 25 | 26 | # 3. 创建图像ID和注释的映射关系 27 | self.img_id_to_annotations = {} 28 | for ann in self.annotations: 29 | img_id = ann["image_id"] 30 | if img_id not in self.img_id_to_annotations: 31 | self.img_id_to_annotations[img_id] = [] 32 | self.img_id_to_annotations[img_id].append(ann) 33 | 34 | def __len__(self): 35 | """ 36 | 返回数据集中样本的总数。 37 | """ 38 | return len(self.images) 39 | 40 | def __getitem__(self, index): 41 | """ 42 | 根据索引返回一个样本,包括图像和相关的注释信息。 43 | """ 44 | # 4. 获取图像信息 45 | img_info = self.images[index] 46 | img_id = img_info["id"] 47 | img_path = os.path.join(self.img_root, img_info["path"]) 48 | 49 | # 5. 加载图像 50 | image = Image.open(img_path).convert('RGB') 51 | if self.transform: 52 | image = self.transform(image) 53 | 54 | # 6. 获取该图像的所有注释 55 | annotations = self.img_id_to_annotations.get(img_id, []) 56 | bboxes = [] 57 | labels = [] 58 | 59 | # 7. 
处理每个注释 60 | for ann in annotations: 61 | x1, y1 = ann["bbox"][:2] 62 | x2, y2 = ann["bbox"][2:] 63 | bboxes.append([x1, y1, x2, y2]) 64 | labels.append(ann["label_text"]) 65 | 66 | return { 67 | "image": image, 68 | "bboxes": bboxes, 69 | "labels": labels, 70 | "image_id": img_id 71 | } 72 | -------------------------------------------------------------------------------- /src/datasets/radvqa.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | 8 | class RADVQADataset(Dataset): 9 | def __init__(self, annotation_file='', vis_root='', transform=None): 10 | """ 11 | Initialize the dataset. 12 | 13 | Parameters: 14 | annotation_file (str): Path to the annotation file containing image IDs and captions. 15 | vis_root (str): Root directory where images are stored. 16 | transform (callable, optional): Optional transform to be applied on a PIL image. 17 | """ 18 | with open(annotation_file, 'r') as file: 19 | self.annotation = json.load(file) 20 | 21 | self.vis_root = vis_root 22 | self.img_ids = {ann['img_id']: idx for idx, ann in enumerate(self.annotation)} 23 | self.transform = transform 24 | 25 | def __len__(self): 26 | """ 27 | Return the total number of samples in the dataset. 28 | """ 29 | return len(self.annotation) 30 | 31 | def __getitem__(self, index): 32 | """ 33 | Retrieve a sample from the dataset at the specified index. 34 | """ 35 | ann = self.annotation[index] 36 | img_file = ann["img_id"] 37 | image_path = os.path.join(self.vis_root, img_file) 38 | image = Image.open(image_path).convert('RGB') 39 | 40 | if self.transform: 41 | image = self.transform(image) 42 | 43 | question = ann['question'] 44 | answer = ann['answer'] 45 | try: 46 | return { 47 | "image": image, 48 | "question": question, 49 | "answer": answer, 50 | "image_id": self.img_ids[ann["img_id"]], 51 | "top_k_captions": ann["top_k_captions"], 52 | "highlights": ann['highlights'] 53 | } 54 | except: 55 | return { 56 | "image": image, 57 | "question": question, 58 | "answer": answer, 59 | "image_id": self.img_ids[ann["img_id"]], 60 | } 61 | 62 | 63 | class RADVQADataCollator: 64 | def __init__(self, processor): 65 | self.processor = processor 66 | 67 | def __call__(self, examples): 68 | assert len(examples) == 1, 'Phi-3-V only supports batch_size == 1' 69 | example = examples[0] 70 | 71 | image = example['image'] 72 | question = example['question'] 73 | answer = example['answer'] 74 | prompt_message = { 75 | 'role': 'user', 76 | 'content': f'<|image_1|>\n{question}', 77 | } 78 | 79 | prompt = self.processor.tokenizer.apply_chat_template( 80 | [prompt_message], tokenize=False, add_generation_prompt=True 81 | ) 82 | answer = f'{answer}<|end|>\n<|endoftext|>' 83 | 84 | # mask questions for labels 85 | batch = self.processor(prompt, [image], return_tensors='pt') 86 | prompt_input_ids = batch['input_ids'] 87 | # Do not add bos token to answer 88 | answer_input_ids = self.processor.tokenizer( 89 | answer, add_special_tokens=False, return_tensors='pt' 90 | )['input_ids'] 91 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 92 | ignore_index = -100 93 | labels = torch.cat( 94 | [ 95 | torch.tensor([ignore_index] * len(prompt_input_ids[0])).unsqueeze(0), 96 | answer_input_ids, 97 | ], 98 | dim=1, 99 | ) 100 | 101 | batch['input_ids'] = input_ids 102 | del batch['attention_mask'] 103 | batch['labels'] = labels 104 | 105 | return batch 106 | 
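`RADVQADataCollator` above turns a single dataset example into model-ready tensors, masking the prompt positions of `labels` with -100 so that only the answer tokens are supervised. Below is a minimal wiring sketch; the annotation file name, the image root, and the use of a plain PyTorch `DataLoader` are illustrative assumptions rather than paths or APIs taken from this repository.

```python
# Minimal wiring sketch (assumed paths; the Phi-3-V collator expects batch_size == 1).
from torch.utils.data import DataLoader
from transformers import AutoProcessor

from src.datasets.radvqa import RADVQADataset, RADVQADataCollator

processor = AutoProcessor.from_pretrained(
    './Phi-3.5-vision-instruct', trust_remote_code=True, num_crops=16
)

dataset = RADVQADataset(
    annotation_file='data/RadVQA/train.json',   # hypothetical file name
    vis_root='data/RadVQA/images',              # hypothetical image folder
)

loader = DataLoader(
    dataset,
    batch_size=1,                               # the collator asserts batch_size == 1
    shuffle=True,
    collate_fn=RADVQADataCollator(processor),
)

batch = next(iter(loader))
# batch holds 'input_ids', 'pixel_values', 'image_sizes', and 'labels';
# prompt positions in 'labels' are set to -100, so the loss covers only the answer.
```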
-------------------------------------------------------------------------------- /src/datasets/slakevqa.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | 8 | class SLAKEVQADataset(Dataset): 9 | def __init__(self, annotation_file='', vis_root='', transform=None): 10 | """ 11 | Initialize the dataset. 12 | 13 | Parameters: 14 | annotation_file (str): Path to the annotation file containing image IDs and captions. 15 | vis_root (str): Root directory where images are stored. 16 | transform (callable, optional): Optional transform to be applied on a PIL image. 17 | """ 18 | with open(annotation_file, 'r') as file: 19 | self.annotation = json.load(file) 20 | 21 | self.vis_root = vis_root 22 | self.img_ids = {ann['img_id']: idx for idx, ann in enumerate(self.annotation)} 23 | self.transform = transform 24 | 25 | def __len__(self): 26 | """ 27 | Return the total number of samples in the dataset. 28 | """ 29 | return len(self.annotation) 30 | 31 | def __getitem__(self, index): 32 | """ 33 | Retrieve a sample from the dataset at the specified index. 34 | """ 35 | ann = self.annotation[index] 36 | img_file = ann["img_id"] 37 | 38 | image_path = os.path.join(self.vis_root, img_file) 39 | image = Image.open(image_path).convert('RGB') 40 | 41 | if self.transform: 42 | image = self.transform(image) 43 | 44 | question = ann['question'] 45 | answer = ann['answer'] 46 | try: 47 | return { 48 | "image": image, 49 | "question": question, 50 | "answer": answer, 51 | "image_id": self.img_ids[ann["img_id"]], 52 | "top_k_captions": ann["top_k_captions"], 53 | "highlights": ann['highlights'] 54 | } 55 | except: 56 | return { 57 | "image": image, 58 | "question": question, 59 | "answer": answer, 60 | "image_id": self.img_ids[ann["img_id"]], 61 | } 62 | 63 | 64 | class SLAKEVQADataCollator: 65 | def __init__(self, processor): 66 | self.processor = processor 67 | 68 | def __call__(self, examples): 69 | assert len(examples) == 1, 'Phi-3-V only supports batch_size == 1' 70 | example = examples[0] 71 | 72 | image = example['image'] 73 | question = example['question'] 74 | answer = example['answer'] 75 | prompt_message = { 76 | 'role': 'user', 77 | 'content': f'<|image_1|>\n{question}', 78 | } 79 | 80 | prompt = self.processor.tokenizer.apply_chat_template( 81 | [prompt_message], tokenize=False, add_generation_prompt=True 82 | ) 83 | answer = f'{answer}<|end|>\n<|endoftext|>' 84 | 85 | # mask questions for labels 86 | batch = self.processor(prompt, [image], return_tensors='pt') 87 | prompt_input_ids = batch['input_ids'] 88 | # Do not add bos token to answer 89 | answer_input_ids = self.processor.tokenizer( 90 | answer, add_special_tokens=False, return_tensors='pt' 91 | )['input_ids'] 92 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 93 | ignore_index = -100 94 | labels = torch.cat( 95 | [ 96 | torch.tensor([ignore_index] * len(prompt_input_ids[0])).unsqueeze(0), 97 | answer_input_ids, 98 | ], 99 | dim=1, 100 | ) 101 | 102 | batch['input_ids'] = input_ids 103 | del batch['attention_mask'] 104 | batch['labels'] = labels 105 | 106 | return batch 107 | -------------------------------------------------------------------------------- /src/datasets/pubmed.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | from torch.utils.data import Dataset 6 | from PIL import 
Image 7 | 8 | class PubMedAlignDataset(Dataset): 9 | def __init__(self, annotation_file='', vis_root='', transform=None): 10 | """ 11 | Initialize the dataset. 12 | 13 | Parameters: 14 | annotation_file (str): Path to the annotation file containing image IDs and captions. 15 | vis_root (str): Root directory where images are stored. 16 | transform (callable, optional): Optional transform to be applied on a PIL image. 17 | """ 18 | with open(annotation_file, 'r') as file: 19 | self.annotation = json.load(file)['annotations'] 20 | 21 | self.vis_root = vis_root 22 | self.img_ids = {} 23 | for idx, ann in enumerate(self.annotation): 24 | ann['id'] = idx 25 | self.img_ids[ann['id']] = idx 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | """ 30 | Return the total number of samples in the dataset. 31 | """ 32 | return len(self.annotation) 33 | 34 | def __getitem__(self, index): 35 | """ 36 | Retrieve a sample from the dataset at the specified index. 37 | """ 38 | ann = self.annotation[index] 39 | img_files = ann["image_id"][:2] 40 | images = [] 41 | for img_file in img_files: 42 | image_path = os.path.join(self.vis_root, img_file) 43 | image = Image.open(image_path).convert('RGB') 44 | if self.transform: 45 | image = self.transform(image) 46 | images.append(image) 47 | 48 | 49 | 50 | question = ann['conversations'][0]['value'].replace('', '').replace('\n', '').strip() 51 | answer = ann['conversations'][1]['value'].replace('\n', '').strip() 52 | 53 | return { 54 | "images": images, 55 | "question": question, 56 | "answer": answer, 57 | "image_id": self.img_ids[ann["id"]] 58 | } 59 | 60 | 61 | class PubMedVQADataCollator: 62 | def __init__(self, processor): 63 | self.processor = processor 64 | 65 | def __call__(self, examples): 66 | assert len(examples) == 1, 'Phi-3-V only supports batch_size == 1' 67 | example = examples[0] 68 | 69 | images = example['images'] 70 | image_references = ''.join([f"<|image_{i + 1}|>\n" for i in range(len(images))]) 71 | 72 | question = example['question'] 73 | answer = example['answer'] 74 | prompt_message = { 75 | 'role': 'user', 76 | 'content': f"{image_references}{question}", 77 | } 78 | 79 | prompt = self.processor.tokenizer.apply_chat_template( 80 | [prompt_message], tokenize=False, add_generation_prompt=True 81 | ) 82 | answer = f'{answer}<|end|>\n<|endoftext|>' 83 | 84 | # mask questions for labels 85 | batch = self.processor(prompt, images, return_tensors='pt') 86 | prompt_input_ids = batch['input_ids'] 87 | # Do not add bos token to answer 88 | answer_input_ids = self.processor.tokenizer( 89 | answer, add_special_tokens=False, return_tensors='pt' 90 | )['input_ids'] 91 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 92 | ignore_index = -100 93 | labels = torch.cat( 94 | [ 95 | torch.tensor([ignore_index] * len(prompt_input_ids[0])).unsqueeze(0), 96 | answer_input_ids, 97 | ], 98 | dim=1, 99 | ) 100 | 101 | batch['input_ids'] = input_ids 102 | del batch['attention_mask'] 103 | batch['labels'] = labels 104 | 105 | return batch 106 | -------------------------------------------------------------------------------- /src/datasets/pathvqa.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | 8 | class PathVQADataset(Dataset): 9 | def __init__(self, annotation_file='', vis_root='', transform=None): 10 | """ 11 | Initialize the dataset. 
12 | 13 | Parameters: 14 | annotation_file (str): Path to the annotation file containing image IDs and captions. 15 | vis_root (str): Root directory where images are stored. 16 | transform (callable, optional): Optional transform to be applied on a PIL image. 17 | """ 18 | with open(annotation_file, 'r') as file: 19 | self.annotation = json.load(file) 20 | 21 | self.vis_root = vis_root 22 | self.img_ids = {ann['img_id']: idx for idx, ann in enumerate(self.annotation)} 23 | self.transform = transform 24 | 25 | def __len__(self): 26 | """ 27 | Return the total number of samples in the dataset. 28 | """ 29 | return len(self.annotation) 30 | 31 | def __getitem__(self, index): 32 | """ 33 | Retrieve a sample from the dataset at the specified index. 34 | """ 35 | ann = self.annotation[index] 36 | img_file = ann["img_id"]+'.jpg' 37 | split = img_file.split("_")[0] 38 | image_path = os.path.join(self.vis_root, split, img_file) 39 | image = Image.open(image_path).convert('RGB') 40 | 41 | if self.transform: 42 | image = self.transform(image) 43 | 44 | question = ann['question'] 45 | answer = ann['answer'] 46 | 47 | try: 48 | return { 49 | "image": image, 50 | "question": question, 51 | "answer": answer, 52 | "image_id": self.img_ids[ann["img_id"]], 53 | "top_k_captions": ann["top_k_captions"], 54 | "highlights": ann['highlights'] 55 | } 56 | except: 57 | return { 58 | "image": image, 59 | "question": question, 60 | "answer": answer, 61 | "image_id": self.img_ids[ann["img_id"]], 62 | } 63 | 64 | 65 | class PathVQADataCollator: 66 | def __init__(self, processor): 67 | self.processor = processor 68 | 69 | def __call__(self, examples): 70 | assert len(examples) == 1, 'Phi-3-V only supports batch_size == 1' 71 | example = examples[0] 72 | 73 | image = example['image'] 74 | question = example['question'] 75 | answer = example['answer'] 76 | prompt_message = { 77 | 'role': 'user', 78 | 'content': f'<|image_1|>\n{question}', 79 | } 80 | 81 | prompt = self.processor.tokenizer.apply_chat_template( 82 | [prompt_message], tokenize=False, add_generation_prompt=True 83 | ) 84 | answer = f'{answer}<|end|>\n<|endoftext|>' 85 | 86 | # mask questions for labels 87 | batch = self.processor(prompt, [image], return_tensors='pt') 88 | prompt_input_ids = batch['input_ids'] 89 | # Do not add bos token to answer 90 | answer_input_ids = self.processor.tokenizer( 91 | answer, add_special_tokens=False, return_tensors='pt' 92 | )['input_ids'] 93 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 94 | ignore_index = -100 95 | labels = torch.cat( 96 | [ 97 | torch.tensor([ignore_index] * len(prompt_input_ids[0])).unsqueeze(0), 98 | answer_input_ids, 99 | ], 100 | dim=1, 101 | ) 102 | 103 | batch['input_ids'] = input_ids 104 | del batch['attention_mask'] 105 | batch['labels'] = labels 106 | 107 | return batch 108 | -------------------------------------------------------------------------------- /src/datasets/llavamed.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import random 5 | import torch 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | 9 | class LlavaMedAlignDataset(Dataset): 10 | def __init__(self, annotation_file='', vis_root='', transform=None): 11 | """ 12 | Initialize the dataset. 13 | 14 | Parameters: 15 | annotation_file (str): Path to the annotation file containing image IDs and captions. 16 | vis_root (str): Root directory where images are stored. 
17 | transform (callable, optional): Optional transform to be applied on a PIL image. 18 | """ 19 | with open(annotation_file, 'r') as file: 20 | self.annotation = json.load(file) 21 | 22 | self.vis_root = vis_root 23 | self.img_ids = {ann['id']: idx for idx, ann in enumerate(self.annotation)} 24 | self.transform = transform 25 | 26 | def __len__(self): 27 | """ 28 | Return the total number of samples in the dataset. 29 | """ 30 | return len(self.annotation) 31 | 32 | def __getitem__(self, index): 33 | """ 34 | Retrieve a sample from the dataset at the specified index. 35 | """ 36 | ann = self.annotation[index] 37 | 38 | img_file = ann['image'] 39 | image_path = os.path.join(self.vis_root, img_file) 40 | image = Image.open(image_path).convert('RGB') 41 | 42 | if self.transform: 43 | image = self.transform(image) 44 | 45 | 46 | 47 | first_instruction = ann['conversations'][0]['value'].replace('', '').replace('\n', '').strip() 48 | questions = [first_instruction] 49 | answers = [] 50 | 51 | for i, item in enumerate(ann["conversations"][1:]): 52 | if i % 2 == 0: # assistant 53 | assistant_answer = item["value"] 54 | answers.append(assistant_answer) 55 | else: 56 | human_instruction = item["value"] + " " 57 | questions.append(human_instruction) 58 | 59 | return { 60 | "image": image, 61 | "question": questions, 62 | "answer": answers, 63 | "image_id": self.img_ids[ann["id"]] 64 | } 65 | 66 | 67 | 68 | class LlavaMedVQADataCollator: 69 | def __init__(self, processor): 70 | self.processor = processor 71 | 72 | def __call__(self, examples): 73 | assert len(examples) == 1, 'Phi-3-V only supports batch_size == 1' 74 | example = examples[0] 75 | 76 | image = example['image'] 77 | idx = random.randint(0, len(example['question']) - 1) 78 | question = example['question'][idx] 79 | answer = example['answer'][idx] 80 | prompt_message = { 81 | 'role': 'user', 82 | 'content': f'<|image_1|>\n{question}', 83 | } 84 | 85 | prompt = self.processor.tokenizer.apply_chat_template( 86 | [prompt_message], tokenize=False, add_generation_prompt=True 87 | ) 88 | answer = f'{answer}<|end|>\n<|endoftext|>' 89 | 90 | # mask questions for labels 91 | batch = self.processor(prompt, [image], return_tensors='pt') 92 | prompt_input_ids = batch['input_ids'] 93 | # Do not add bos token to answer 94 | answer_input_ids = self.processor.tokenizer( 95 | answer, add_special_tokens=False, return_tensors='pt' 96 | )['input_ids'] 97 | input_ids = torch.cat([prompt_input_ids, answer_input_ids], dim=1) 98 | ignore_index = -100 99 | labels = torch.cat( 100 | [ 101 | torch.tensor([ignore_index] * len(prompt_input_ids[0])).unsqueeze(0), 102 | answer_input_ids, 103 | ], 104 | dim=1, 105 | ) 106 | 107 | batch['input_ids'] = input_ids 108 | del batch['attention_mask'] 109 | batch['labels'] = labels 110 | 111 | return batch 112 | 113 | # Usage example 114 | if __name__ == "__main__": 115 | from torchvision import transforms 116 | 117 | from tqdm import tqdm 118 | 119 | dataset = LlavaMedAlignDataset(annotation_file='/workspace/LLaVA-Med/align_train.json', vis_root='/workspace/medical_conversation') 120 | 121 | 122 | for sample in tqdm(dataset): 123 | print(sample['question'], sample['answer']) 124 | len += 1 125 | 126 | print(len) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Uncertainty-Driven Expert Control: Enhancing the Reliability of Medical Vision-Language Models 4 | 5 |
6 | 7 | ## 💡Overview 8 | 9 | **Expert-Controlled Classifier-Free Guidance** is a training-free, expert-in-the-loop framework designed to align medical vision-language models (MedVLMs) with clinical expertise. It integrates token-level uncertainty estimation, a BioMedCLIP-based medical multimodal Retrieval-Augmented Generation (RAG) module, and interactive expert revisions with highlight-based guidance. 10 | 11 | ## 🔨Setup 12 | ### 🔨Installation 13 | ``` 14 | conda create -n expert_cfg python=3.10 -y 15 | conda activate expert_cfg 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ### 🔨Pre-trained weights 20 | 21 | #### Baseline Model: 22 | Download the base models to the current directory and merge each with the corresponding `Phi-3-vision-128k-instruct` or `Phi-3.5-vision-instruct` folder provided in this repository. 23 | + Phi-3V: [Huggingface](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) 24 | + Phi-3.5V: [Huggingface](https://huggingface.co/microsoft/Phi-3.5-vision-instruct) 25 | 26 | 27 | #### Medical LoRA: 28 | LoRA weights for our fine-tuned Phi-3V-Med and Phi-3.5V-Med models: 29 | 30 | + Phi-3V-Med: [Huggingface](https://huggingface.co/ecoxial2007/Phi-3V-Med) 31 | + Phi-3.5V-Med: [Huggingface](https://huggingface.co/ecoxial2007/Phi-3.5V-Med) 32 | Download them to the `./lora_weights` folder. 33 | 34 | #### Demo 35 | ``` 36 | torchrun --nproc_per_node=1 demo.py \ 37 | --bf16 \ 38 | --use_lora \ 39 | --input_json 'examples/input_queries.json' \ 40 | --img_root 'examples/images' \ 41 | --save_path 'examples/results.json' \ 42 | --output_dir './lora_weights/logs_phi35_pubmed_instruct' 43 | ``` 44 | 45 | #### Medical Image & Text Encoder for RAG (optional): 46 | 47 | Download BiomedCLIP and place it in `./src/backbone/BiomedCLIP`. 48 | 49 | BiomedCLIP links: 50 | + [Huggingface](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) 51 | 52 | **Note**: Downloading the weights directly from Huggingface may run into network issues. To facilitate modifications, we have converted the original `.bin` file to PyTorch's `.pth` format. We recommend using the Baiduyun version. 53 | 54 | 55 | ### 📑Data Preparation 56 | Our data mainly comes from the publicly available, free online Pathology Education Informational Resource ([PEIR](https://peir.path.uab.edu/library/index.php?/category/2)) Digital Library. 57 | We test our model on: 58 | + [VQA-RAD](https://osf.io/89kps/) 59 | + [SLAKE](https://www.med-vqa.com/slake/) 60 | + [PathVQA](https://github.com/UCSD-AI4H/PathVQA) 61 | 62 | Medical alignment and instruction tuning: 63 | + [PubMedVision](https://huggingface.co/datasets/FreedomIntelligence/PubMedVision) 64 | + [LLaVA-Med](https://github.com/microsoft/LLaVA-Med) 65 | 66 | 67 | ### Prepare BiomedCLIP Pre-extracted Image Features 68 | Note: We recommend using our pre-extracted BioMedCLIP features. 
The original images can also be found in the links below: 69 | 70 | 71 | | Dataset | Pre-extracted Features & Original Images | 72 | |----------|------------------------------------------| 73 | | PEIR | [Baiduyun, Rename zzz2zip](https://pan.baidu.com/s/1sJp_3UzjIIvOiuyMB417GQ?pwd=6666)| 74 | | PEIR BioMedCLIP features & keyword & GPT3.5 rewrite caption | [Baiduyun](https://pan.baidu.com/s/1pqHhrxLL-ZdgEat0wNwLmQ?pwd=6666)| 75 | | PathVQA | [Baiduyun](https://pan.baidu.com/s/1b1SuDSbsNM1rVGzbx8utvg?pwd=6666)| 76 | | Slake | [Baiduyun](https://pan.baidu.com/s/1mfAoi9_HZkrk7OuyQIn4-w?pwd=6666)| 77 | | RADVQA | [Baiduyun](https://pan.baidu.com/s/1gBjAjq2L-iIMf0j05QsJ-w?pwd=6666)| 78 | 79 | 80 | 81 | 82 | 83 | 84 | ## 📝Acknowledgements 85 | We also reference the excellent repos of [Phi-3CookBook](https://github.com/microsoft/Phi-3CookBook), [HuatuoVision](https://huggingface.co/datasets/FreedomIntelligence/PubMedVision), [BioMedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224), in addition to other specific repos to the baseline and dataset we examined (see paper). 86 | 87 | ## 📝Citation 88 | If you find this paper useful, please consider staring 🌟 this repo and citing 📑 our paper: 89 | ``` 90 | @misc{liang2025uncertaintydriven, 91 | title={Uncertainty-Driven Expert Control: Enhancing the Reliability of Medical Vision-Language Models}, 92 | author={Xiao Liang and Di Wang and Zhicheng Jiao and Ronghan Li and Pengfei Yang and Quan Wang and Tat-Seng Chua}, 93 | year={2025}, 94 | eprint={2507.09209}, 95 | archivePrefix={arXiv}, 96 | primaryClass={cs.CV} 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import torch 5 | from accelerate import Accelerator 6 | from accelerate import DataLoaderConfiguration 7 | from tqdm import tqdm 8 | from transformers import ( 9 | AutoModelForCausalLM, 10 | AutoProcessor, 11 | BitsAndBytesConfig, 12 | Trainer, 13 | TrainingArguments, 14 | ) 15 | 16 | from src.datasets.radvqa import RADVQADataset 17 | from src.highlighter_modules.guidance import ProbCFGLogitsProcessor 18 | 19 | 20 | # suggested deepspeed config 21 | DS_CONFIG_DICT = { 22 | 'zero_optimization': { 23 | 'stage': 2, 24 | 'allgather_partitions': True, 25 | 'allgather_bucket_size': 5e8, 26 | 'overlap_comm': True, 27 | 'reduce_scatter': True, 28 | 'reduce_bucket_size': 5e8, 29 | 'contiguous_gradients': True, 30 | 'round_robin_gradients': True, 31 | }, 32 | 'fp16': { 33 | 'enabled': 'auto', 34 | 'loss_scale': 0, 35 | 'loss_scale_window': 1000, 36 | 'initial_scale_power': 16, 37 | 'hysteresis': 2, 38 | 'min_loss_scale': 1, 39 | }, 40 | 'bf16': {'enabled': 'auto'}, 41 | 'train_micro_batch_size_per_gpu': 'auto', 42 | 'train_batch_size': 'auto', 43 | 'gradient_accumulation_steps': 'auto', 44 | 'gradient_clipping': 'auto', 45 | } 46 | 47 | 48 | def create_dataset(args): 49 | output_file_test = args.input_json 50 | img_root = args.img_root 51 | eval_dataset = RADVQADataset(annotation_file=output_file_test, vis_root=img_root) 52 | return eval_dataset 53 | 54 | 55 | 56 | class NoGradHook: 57 | def __init__(self): 58 | self.prev_enabled = True 59 | 60 | def maybe_enable_grad_hook(self, *_): 61 | torch.set_grad_enabled(self.prev_enabled) 62 | 63 | def disable_grad_hook(self, *_): 64 | self.prev_enabled = torch.is_grad_enabled() 65 | torch.set_grad_enabled(False) 66 | 67 | 68 | def 
freeze_vision_model(model): 69 | vision_no_grad_hook = NoGradHook() 70 | vision_module = model.model.vision_embed_tokens 71 | vision_module.register_forward_pre_hook(vision_no_grad_hook.disable_grad_hook) 72 | vision_module.register_forward_hook(vision_no_grad_hook.maybe_enable_grad_hook) 73 | for p in vision_module.parameters(): 74 | p.requires_grad_(False) 75 | 76 | 77 | def create_model(model_name_or_path, use_flash_attention=False, use_qlora=False): 78 | bnb_config = ( 79 | BitsAndBytesConfig( 80 | load_in_4bit=True, 81 | bnb_4bit_quant_type='nf4', 82 | bnb_4bit_compute_dtype=torch.bfloat16 if use_flash_attention else torch.float16, 83 | ) 84 | if use_qlora 85 | else None 86 | ) 87 | 88 | model = AutoModelForCausalLM.from_pretrained( 89 | model_name_or_path, 90 | # Phi-3-V is originally trained in bf16 + flash attn 91 | # For fp16 mixed precision training, load in f32 to avoid hf accelerate error 92 | torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32, 93 | trust_remote_code=True, 94 | _attn_implementation='flash_attention_2' if use_flash_attention else 'eager', 95 | quantization_config=bnb_config, 96 | ) 97 | 98 | return model 99 | 100 | 101 | @torch.no_grad() 102 | def evaluate(model, processor, eval_dataset, args, save_path=None, disable_tqdm=False): 103 | rank = int(os.environ.get('RANK', 0)) 104 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 105 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 106 | 107 | model.eval() 108 | answers_unique = [] 109 | generated_texts_unique = [] 110 | 111 | eval_dataset_shard = eval_dataset 112 | 113 | for i in tqdm(range(len(eval_dataset)), disable=(rank != 0) or disable_tqdm): 114 | # Phi-3-V currently only supports batch_size == 1 115 | example = eval_dataset_shard[i] 116 | answers_unique.append(example['answer']) 117 | answers_unique.append(example['answer']) 118 | image = example['image'] 119 | question = example['question'] 120 | caption = example["top_k_captions"][0] 121 | highlights = example['highlights'] 122 | prompt_message = { 123 | 'role': 'user', 124 | 'content': f'{caption} <|image_1|>\n{question}', 125 | } 126 | prompt = processor.tokenizer.apply_chat_template( 127 | [prompt_message], tokenize=False, add_generation_prompt=True 128 | ) 129 | 130 | qs_highlighted_parts = highlights 131 | 132 | inputs = processor(prompt, [image], return_tensors='pt', qs_highlighted_parts=qs_highlighted_parts).to(f'cuda:{local_rank}') 133 | hl_mask_ = inputs['highlight_attention_mask'] 134 | hl_mask_[hl_mask_ == 1] = args.perturb_weight 135 | hl_mask_[hl_mask_ == 0] = args.attn 136 | cfg_batched_input = inputs['input_ids'].repeat(2, 1) 137 | pixel_values = inputs['pixel_values'].repeat(2, 1, 1, 1, 1) 138 | image_sizes = inputs['image_sizes'].repeat(2, 1) 139 | 140 | del inputs['highlight_attention_mask'] 141 | 142 | generated_outputs = model.generate( 143 | input_ids=cfg_batched_input, 144 | pixel_values=pixel_values, 145 | attention_mask=torch.cat([inputs['attention_mask'], hl_mask_], dim=0), 146 | image_sizes=image_sizes, 147 | eos_token_id=processor.tokenizer.eos_token_id, 148 | max_new_tokens=64, 149 | num_beams=args.num_beams, 150 | logits_processor=[ProbCFGLogitsProcessor(guidance_scale=args.cfg, use_log=True)], 151 | output_scores=True, 152 | return_dict_in_generate=True 153 | ) 154 | 155 | 156 | batch_index = 1 157 | prediction = processor.batch_decode( 158 | generated_outputs.sequences[:, inputs['input_ids'].size(1):], 159 | skip_special_tokens=True, 160 | clean_up_tokenization_spaces=False, 161 | ) 162 | prediction = 
prediction[0].strip().strip('.') 163 | 164 | print('Question:', example['question'], 'GT:',example['answer']) 165 | print('Prediction:', prediction) 166 | token_probs = [] 167 | generated_texts = [] 168 | for i, scores in enumerate(generated_outputs.scores): 169 | probs = torch.softmax(scores, dim=-1) 170 | generated_token_id = generated_outputs.sequences[batch_index, inputs['input_ids'].size(1) + len(token_probs)] 171 | token_prob = probs[batch_index, generated_token_id].item() 172 | token_probs.append(token_prob) 173 | 174 | # Print the decoded tokens and their probabilities 175 | print("Generated text and token probabilities:") 176 | for idx, prob in enumerate(token_probs): 177 | token = processor.decode(generated_outputs.sequences[batch_index, inputs['input_ids'].size(1) + idx]) 178 | print(f"{token} - Probability: {prob}") 179 | generated_texts.append(token) 180 | 181 | update_information = { 182 | 'question': example['question'], 183 | 'answer': example['answer'], 184 | 'prediction': prediction, 185 | 'token_probs': token_probs, 186 | 'token_preds': generated_texts 187 | } 188 | generated_texts_unique.append(update_information) 189 | 190 | 191 | 192 | if rank == 0: 193 | if save_path: 194 | with open(save_path, 'w') as f: 195 | json.dump(generated_texts_unique, f, indent=4) 196 | 197 | 198 | 199 | def main(): 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument( 202 | '--model_name_or_path', 203 | type=str, 204 | # default='./Phi-3-vision-128k-instruct', 205 | default='./Phi-3.5-vision-instruct', 206 | help='Model name or path to load from', 207 | ) 208 | parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention') 209 | parser.add_argument('--bf16', action='store_true', help='Use BF16') 210 | parser.add_argument('--use_lora', action='store_true', help='Use LoRA') 211 | parser.add_argument('--use_qlora', action='store_true', help='Use QLora') 212 | parser.add_argument('--output_dir', type=str, help='Output directory') 213 | parser.add_argument('--save_path', type=str, help='Save json path') 214 | parser.add_argument('--input_json', type=str, help='Question and Answer json path') 215 | parser.add_argument('--img_root', type=str, help='Image Folder') 216 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 217 | parser.add_argument('--num_crops', type=int, default=16, help='Number of maximum image crops') 218 | parser.add_argument('--no-tqdm', dest='tqdm', action='store_false', help='Disable tqdm') 219 | parser.add_argument('--lora_rank', type=int, default=64, help='LoRA rank') 220 | parser.add_argument( 221 | '--lora_alpha_ratio', type=float, default=2, help='LoRA alpha to rank ratio' 222 | ) 223 | parser.add_argument('--lora_dropout', type=float, default=0.0, help='LoRA dropout') 224 | parser.add_argument('--freeze_vision_model', action='store_true', help='Freeze vision model') 225 | 226 | parser.add_argument("--num_beams", type=int, default=1) 227 | parser.add_argument("--cfg", type=float, default=1.5) 228 | parser.add_argument("--attn", type=float, default=3.0) 229 | parser.add_argument("--perturb_weight", type=float, default=0.01) 230 | 231 | args = parser.parse_args() 232 | args.attention_weight = args.attn 233 | 234 | assert args.num_crops <= 16, 'num_crops must be less than or equal to 16' 235 | if args.use_qlora: 236 | args.use_lora = True 237 | 238 | dataloader_config = DataLoaderConfiguration( 239 | dispatch_batches=None, 240 | split_batches=False, 241 | even_batches=True, 242 | 
use_seedable_sampler=True 243 | ) 244 | accelerator = Accelerator(dataloader_config) 245 | 246 | with accelerator.local_main_process_first(): 247 | processor = AutoProcessor.from_pretrained( 248 | args.model_name_or_path, trust_remote_code=True, num_crops=args.num_crops 249 | ) 250 | model = create_model( 251 | args.model_name_or_path, 252 | use_flash_attention=args.use_flash_attention, 253 | use_qlora=args.use_qlora, 254 | ) 255 | 256 | eval_dataset = create_dataset(args) 257 | 258 | num_gpus = accelerator.num_processes 259 | print(f'training on {num_gpus} GPUs') 260 | assert args.batch_size % num_gpus == 0, 'Batch size must be divisible by the number of GPUs' 261 | 262 | 263 | # eval after fine-tuning 264 | if args.use_lora: 265 | # first try to clear GPU memory 266 | del model 267 | __import__('gc').collect() 268 | torch.cuda.empty_cache() 269 | 270 | # reload the model for inference 271 | # this part also serves as an example of how to load a trained model 272 | model = AutoModelForCausalLM.from_pretrained( 273 | args.model_name_or_path, 274 | # Phi-3-V is originally trained in bf16 + flash attn 275 | # For fp16 mixed precision training, load in f32 to avoid hf accelerate error 276 | torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32, 277 | trust_remote_code=True, 278 | _attn_implementation='flash_attention_2' if args.use_flash_attention else 'eager', 279 | ) 280 | model.load_adapter(args.output_dir) 281 | 282 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 283 | model = model.to(f'cuda:{local_rank}') 284 | evaluate( 285 | model, 286 | processor, 287 | eval_dataset, 288 | args, 289 | save_path=args.save_path, 290 | disable_tqdm=not args.tqdm, 291 | ) 292 | 293 | 294 | 295 | if __name__ == '__main__': 296 | main() 297 | -------------------------------------------------------------------------------- /Phi-3-vision-128k-instruct/processing_phi3_v_cfg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Processor class for Phi3-V. 
18 | """ 19 | import re 20 | from typing import List, Optional, Union 21 | 22 | import torch 23 | 24 | import transformers 25 | from transformers.feature_extraction_utils import BatchFeature 26 | from transformers.image_utils import ImageInput 27 | from transformers.processing_utils import ProcessorMixin 28 | from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy 29 | from src.highlighter_modules.utils import txt_highlight_mask 30 | from transformers.utils import TensorType 31 | from .image_processing_phi3_v import Phi3VImageProcessor 32 | 33 | transformers.Phi3VImageProcessor = Phi3VImageProcessor 34 | 35 | 36 | class Phi3VProcessor(ProcessorMixin): 37 | r""" 38 | Constructs a Phi3-V processor which wraps a Phi3-V image processor and a LLaMa tokenizer into a single processor. 39 | 40 | [`Phi3VProcessor`] offers all the functionalities of [`Phi3VImageProcessor`] and [`LlamaTokenizerFast`]. See the 41 | [`~Phi3VProcessor.__call__`] and [`~Phi3VProcessor.decode`] for more information. 42 | 43 | Args: 44 | image_processor ([`Phi3VImageProcessor`], *optional*): 45 | The image processor is a required input. 46 | tokenizer ([`LlamaTokenizerFast`], *optional*): 47 | The tokenizer is a required input. 48 | """ 49 | 50 | attributes = ["image_processor", "tokenizer"] 51 | image_processor_class = "Phi3VImageProcessor" 52 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") 53 | special_image_token = "<|image|>" 54 | 55 | def __init__(self, image_processor, tokenizer): 56 | self.image_processor = image_processor 57 | self.tokenizer = tokenizer 58 | self.tokenizer.padding_side = 'left' 59 | self.num_img_tokens = image_processor.num_img_tokens 60 | self.img_tokens = [f"<|image_{i + 1}|>" for i in range(1000000)] 61 | 62 | def __call__( 63 | self, 64 | text: Union[TextInput, List[TextInput]], 65 | images: ImageInput = None, 66 | padding: Union[bool, str, PaddingStrategy] = False, 67 | truncation: Union[bool, str, TruncationStrategy] = None, 68 | max_length=None, 69 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 70 | qs_highlighted_parts: List[str] = None 71 | ) -> BatchFeature: 72 | """ 73 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 74 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode 75 | the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to 76 | Phi3ImageProcessor's [`~Phi3ImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 77 | of the above two methods for more information. 78 | 79 | Args: 80 | text (`str`, `List[str]`, `List[List[str]]`): 81 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 82 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 83 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 84 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 85 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 86 | tensor. Both channels-first and channels-last formats are supported. 
87 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): 88 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 89 | index) among: 90 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 91 | sequence if provided). 92 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 93 | acceptable input length for the model if that argument is not provided. 94 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 95 | lengths). 96 | max_length (`int`, *optional*): 97 | Maximum length of the returned list and optionally padding length (see above). 98 | truncation (`bool`, *optional*): 99 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 100 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 101 | If set, will return tensors of a particular framework. Acceptable values are: 102 | 103 | - `'tf'`: Return TensorFlow `tf.constant` objects. 104 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 105 | - `'np'`: Return NumPy `np.ndarray` objects. 106 | - `'jax'`: Return JAX `jnp.ndarray` objects. 107 | 108 | Returns: 109 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 110 | 111 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 112 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 113 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 114 | `None`). 115 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 116 | """ 117 | if images is not None: 118 | image_inputs = self.image_processor(images, return_tensors=return_tensors) 119 | else: 120 | image_inputs = {} 121 | inputs = self._convert_images_texts_to_inputs(image_inputs, text, padding=padding, truncation=truncation, 122 | max_length=max_length, return_tensors=return_tensors, 123 | qs_highlighted_parts=qs_highlighted_parts) 124 | return inputs 125 | 126 | def calc_num_image_tokens(self, images: ImageInput): 127 | """ Calculate the number of image tokens for each image. 128 | Args: 129 | images (`ImageInput`): 130 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 131 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 132 | """ 133 | return self.image_processor.calc_num_image_tokens(images) 134 | 135 | def calc_num_image_tokens_from_image_size(self, width, height): 136 | """ Calculate the number of image token for an image with given width and height. 137 | Args: 138 | width (`int`): 139 | Width of the image. 140 | height (`int`): 141 | Height of the image. 
142 | """ 143 | return self.image_processor.calc_num_image_tokens_from_image_size(width, height) 144 | 145 | @property 146 | def special_image_token_id(self): 147 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token) 148 | 149 | def get_special_image_token_id(self): 150 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token) 151 | 152 | def tokenize_and_create_masks(self, pattern, texts, highlighted_parts): 153 | # 分割文本并保留图像标签 154 | prompt_chunks = re.split(pattern, texts) 155 | image_tags = re.findall(pattern, texts) 156 | 157 | input_ids = [] 158 | highlight_attention_mask = [] 159 | 160 | # 处理每个chunk和对应的图像标签 161 | for i, chunk in enumerate(prompt_chunks): 162 | # Tokenize the chunk 163 | chunk_ids = self.tokenizer.encode(chunk, add_special_tokens=False) 164 | 165 | # 生成对应chunk的高亮掩码 166 | chunk_highlight_mask = [0] * len(chunk_ids) 167 | for part in highlighted_parts: 168 | start = 0 169 | while start < len(chunk): 170 | start_index = chunk.find(part, start) 171 | if start_index == -1: 172 | break 173 | end_index = start_index + len(part) 174 | # 将字符索引转换为token索引 175 | start_token_idx = len(self.tokenizer.encode(chunk[:start_index], add_special_tokens=False)) 176 | end_token_idx = len(self.tokenizer.encode(chunk[:end_index], add_special_tokens=False)) 177 | for idx in range(start_token_idx, end_token_idx): 178 | if idx < len(chunk_highlight_mask): 179 | chunk_highlight_mask[idx] = 1 180 | start = end_index 181 | 182 | # 添加chunk的token ids和高亮掩码到总列表 183 | input_ids.extend(chunk_ids) 184 | highlight_attention_mask.extend(chunk_highlight_mask) 185 | 186 | # 如果还有图像标签,处理图像标签 187 | if i < len(image_tags): 188 | image_tag_ids = self.tokenizer.encode(image_tags[i], add_special_tokens=False) 189 | input_ids.extend(image_tag_ids) 190 | # 图像标签不高亮 191 | highlight_attention_mask.extend([0] * len(image_tag_ids)) 192 | 193 | return input_ids, highlight_attention_mask 194 | # 转换为torch tensor 195 | # input_ids_tensor = torch.tensor(input_ids).unsqueeze(0) 196 | # highlight_attention_mask_tensor = torch.tensor(highlight_attention_mask).unsqueeze(0) 197 | # 198 | # 199 | # return input_ids_tensor, highlight_attention_mask_tensor 200 | 201 | def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, 202 | return_tensors=None, qs_highlighted_parts=[]): 203 | 204 | if not len(images): 205 | model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, 206 | max_length=max_length) 207 | return BatchFeature(data={**model_inputs}) 208 | 209 | # print(texts, qs_highlighted_parts, '!!!!!this is in processing.py') 210 | 211 | pattern = r"<\|image_\d+\|>" 212 | 213 | if 'num_img_tokens' in images: 214 | num_img_tokens = images['num_img_tokens'] 215 | else: 216 | assert 'num_crops' in images, 'num_crops must be provided in images if num_img_tokens is not provided' 217 | num_crops = images['num_crops'] 218 | num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops] 219 | 220 | images, image_sizes = images['pixel_values'], images['image_sizes'] 221 | 222 | # image_tags needs to start from 1 to n 223 | image_tags = re.findall(pattern, texts) 224 | # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags] 225 | # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)] 226 | image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] 227 | unique_image_ids = sorted(list(set(image_ids))) 228 | # image_ids must start from 1, and 
must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5] 229 | # check the condition 230 | assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" 231 | # total images must be the same as the number of image tags 232 | assert len(unique_image_ids) == len( 233 | images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images" 234 | 235 | image_ids_pad = [[-iid] * num_img_tokens[iid - 1] for iid in image_ids] 236 | 237 | def insert_separator(X, sep_list): 238 | if len(X) > len(sep_list): 239 | sep_list.append([]) 240 | return [ele for sublist in zip(X, sep_list) for ele in sublist] 241 | 242 | highlight_attention_mask = [] 243 | prompt_chunks = [] 244 | 245 | 246 | # Generate highlight masks for each chunk 247 | for chunk in re.split(pattern, texts): 248 | chunk_mask, _ = txt_highlight_mask(self.tokenizer, chunk, qs_highlighted_parts) 249 | highlight_attention_mask.append([0] + chunk_mask) 250 | a = self.tokenizer(chunk) 251 | prompt_chunks.append(a.input_ids) 252 | 253 | offset = 0 254 | input_ids = [] 255 | combined_highlight_mask = [] 256 | zero_mask_padding = [[0] * len(pad) for pad in 257 | image_ids_pad] # Create zero padding mask with the same length as image ids 258 | 259 | for tokens, mask in zip(insert_separator(prompt_chunks, image_ids_pad), 260 | insert_separator(highlight_attention_mask, 261 | zero_mask_padding)): # Use zero_mask_padding here 262 | input_ids.extend(tokens[offset:]) 263 | combined_highlight_mask.extend(mask[offset:]) 264 | 265 | input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) 266 | attention_mask = (input_ids > -1000000).to(torch.long) 267 | combined_highlight_mask = torch.tensor(combined_highlight_mask, dtype=torch.long).unsqueeze(0) 268 | 269 | return BatchFeature(data={"input_ids": input_ids, 270 | "attention_mask": attention_mask, 271 | "highlight_attention_mask": combined_highlight_mask, 272 | "pixel_values": images, 273 | "image_sizes": image_sizes}) 274 | 275 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama 276 | def batch_decode(self, *args, **kwargs): 277 | """ 278 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 279 | refer to the docstring of this method for more information. 280 | """ 281 | return self.tokenizer.batch_decode(*args, **kwargs) 282 | 283 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama 284 | def decode(self, *args, **kwargs): 285 | """ 286 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 287 | the docstring of this method for more information. 
288 | """ 289 | return self.tokenizer.decode(*args, **kwargs) 290 | 291 | @property 292 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names 293 | def model_input_names(self): 294 | tokenizer_input_names = self.tokenizer.model_input_names 295 | image_processor_input_names = self.image_processor.model_input_names 296 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) -------------------------------------------------------------------------------- /Phi-3.5-vision-instruct/processing_phi3_v_cfg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | Processor class for Phi3-V. 18 | """ 19 | import re 20 | from typing import List, Optional, Union 21 | 22 | import torch 23 | 24 | import transformers 25 | from transformers.feature_extraction_utils import BatchFeature 26 | from transformers.image_utils import ImageInput 27 | from transformers.processing_utils import ProcessorMixin 28 | from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy 29 | from transformers.utils import TensorType 30 | from src.highlighter_modules.utils import txt_highlight_mask 31 | 32 | """Image processor class for Phi3-V.""" 33 | 34 | from typing import List, Optional, Union 35 | 36 | import numpy as np 37 | 38 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 39 | from transformers.image_transforms import ( 40 | convert_to_rgb, 41 | ) 42 | from transformers.image_utils import ( 43 | OPENAI_CLIP_MEAN, 44 | OPENAI_CLIP_STD, 45 | ImageInput, 46 | make_list_of_images, 47 | valid_images, 48 | ) 49 | from transformers.utils import TensorType, is_vision_available, logging 50 | 51 | from transformers import AutoImageProcessor 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | 56 | if is_vision_available(): 57 | from PIL import Image 58 | 59 | import torch 60 | import torchvision 61 | 62 | def padding_336(b): 63 | width, height = b.size 64 | tar = int(np.ceil(height / 336) * 336) 65 | top_padding = int((tar - height)/2) 66 | bottom_padding = tar - height - top_padding 67 | left_padding = 0 68 | right_padding = 0 69 | b = torchvision.transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255,255,255]) 70 | 71 | return b 72 | 73 | def calc_padded_size(width, height, padding_unit=336): 74 | target_height = int(np.ceil(height / padding_unit) * padding_unit) 75 | top_padding = int((target_height - height) / 2) 76 | bottom_padding = target_height - height - top_padding 77 | left_padding = 0 78 | right_padding = 0 79 | padded_width = width + left_padding + right_padding 80 | padded_height = height + top_padding + bottom_padding 81 | return padded_width, padded_height 82 | 83 | def HD_transform(img, hd_num=16): 84 | width, 
height = img.size 85 | trans = False 86 | if width < height: 87 | img = img.transpose(Image.TRANSPOSE) 88 | trans = True 89 | width, height = img.size 90 | ratio = (width/ height) 91 | scale = 1 92 | while scale*np.ceil(scale/ratio) <= hd_num: 93 | scale += 1 94 | scale -= 1 95 | new_w = int(scale * 336) 96 | new_h = int(new_w / ratio) 97 | 98 | img = torchvision.transforms.functional.resize(img, [new_h, new_w],) 99 | img = padding_336(img) 100 | width, height = img.size 101 | if trans: 102 | img = img.transpose(Image.TRANSPOSE) 103 | 104 | return img 105 | 106 | def calc_hd_transform_size(width, height, hd_num=16): 107 | transposed = False 108 | if width < height: 109 | width, height = height, width 110 | transposed = True 111 | 112 | ratio = width / height 113 | scale = 1 114 | while scale * np.ceil(scale / ratio) <= hd_num: 115 | scale += 1 116 | scale -= 1 117 | 118 | new_width = int(scale * 336) 119 | new_height = int(new_width / ratio) 120 | 121 | padded_width, padded_height = calc_padded_size(new_width, new_height) 122 | 123 | if transposed: 124 | padded_width, padded_height = padded_height, padded_width 125 | 126 | return padded_width, padded_height 127 | 128 | def pad_to_max_num_crops_tensor(images, max_crops=5): 129 | """ 130 | images: B x 3 x H x W, B<=max_crops 131 | """ 132 | B, _, H, W = images.shape 133 | if B < max_crops: 134 | pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) 135 | images = torch.cat([images, pad], dim=0) 136 | return images 137 | 138 | 139 | class Phi3VImageProcessor(BaseImageProcessor): 140 | r""" 141 | Constructs a Phi3 image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques 142 | for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512) 143 | 144 | Args: 145 | image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): 146 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of 147 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 148 | image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): 149 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 150 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 151 | Can be overridden by the `image_std` parameter in the `preprocess` method. 152 | do_convert_rgb (`bool`, *optional*, defaults to `True`): 153 | Whether to convert the image to RGB. 154 | """ 155 | 156 | model_input_names = ["pixel_values"] 157 | 158 | def __init__( 159 | self, 160 | num_crops: int = 1, 161 | image_mean: Optional[Union[float, List[float]]] = None, 162 | image_std: Optional[Union[float, List[float]]] = None, 163 | do_convert_rgb: bool = True, 164 | **kwargs, 165 | ) -> None: 166 | super().__init__(**kwargs) 167 | self.num_crops = num_crops 168 | self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN 169 | self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD 170 | self.do_convert_rgb = do_convert_rgb 171 | 172 | def calc_num_image_tokens( 173 | self, 174 | images: ImageInput 175 | ): 176 | """ Calculate the number of image tokens for each image. 177 | Args: 178 | images (`ImageInput`): 179 | Image to preprocess. 
Expects a single or batch of images with pixel values ranging from 0 to 255. If 180 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 181 | """ 182 | images = make_list_of_images(images) 183 | 184 | if not valid_images(images): 185 | raise ValueError( 186 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 187 | "torch.Tensor, tf.Tensor or jax.ndarray." 188 | ) 189 | 190 | images = [image.convert('RGB') for image in images] 191 | # (H, W, C) 192 | elems = [HD_transform(im, hd_num = self.num_crops) for im in images] 193 | shapes = [[im.size[1], im.size[0]] for im in elems] 194 | num_img_tokens = [int((h//336*w//336+1)*144 + 1 + (h//336+1)*12) for h, w in shapes] 195 | return num_img_tokens 196 | 197 | def calc_num_image_tokens_from_image_size(self, width, height): 198 | """ 199 | Calculate the number of image tokens for a given image size. 200 | Args: 201 | width (`int`): Width of the image. 202 | height (`int`): Height of the image. 203 | """ 204 | new_width, new_height = calc_hd_transform_size(width, height, hd_num=self.num_crops) 205 | num_img_tokens = int((new_height // 336 * new_width // 336 + 1) * 144 + 1 + (new_height // 336 + 1) * 12) 206 | return num_img_tokens 207 | 208 | def preprocess( 209 | self, 210 | images: ImageInput, 211 | image_mean: Optional[Union[float, List[float]]] = None, 212 | image_std: Optional[Union[float, List[float]]] = None, 213 | do_convert_rgb: bool = None, 214 | return_tensors: Optional[Union[str, TensorType]] = None, 215 | ): 216 | """ 217 | Args: 218 | images (`ImageInput`): 219 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 220 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 221 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 222 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 223 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 224 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to 225 | `True`. 226 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 227 | Whether to convert the image to RGB. 228 | return_tensors (`str` or `TensorType`, *optional*): 229 | The type of tensors to return. Can be one of: 230 | - Unset: Return a list of `np.ndarray`. 231 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 232 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 233 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 234 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 235 | """ 236 | image_mean = image_mean if image_mean is not None else self.image_mean 237 | image_std = image_std if image_std is not None else self.image_std 238 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 239 | 240 | images = make_list_of_images(images) 241 | 242 | if not valid_images(images): 243 | raise ValueError( 244 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 245 | "torch.Tensor, tf.Tensor or jax.ndarray." 
246 | ) 247 | 248 | if do_convert_rgb: 249 | images = [convert_to_rgb(image) for image in images] 250 | 251 | image_sizes = [] 252 | img_processor = torchvision.transforms.Compose([ 253 | torchvision.transforms.ToTensor(), 254 | torchvision.transforms.Normalize(image_mean, image_std) 255 | ]) 256 | 257 | # PIL images 258 | # HD_transform pads images to a size that is a multiple of 336 x 336 259 | # convert to RGB first 260 | images = [image.convert('RGB') for image in images] 261 | elems = [HD_transform(im, hd_num = self.num_crops) for im in images] 262 | # tensor transform and normalize 263 | hd_images = [img_processor(im) for im in elems] 264 | # create global image 265 | global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336), mode='bicubic',).to(im.dtype) for im in hd_images] 266 | 267 | # [(3, h, w)], where h, w are multiples of 336 268 | shapes = [[im.size(1), im.size(2)] for im in hd_images] 269 | num_img_tokens = [int(((h//336)*(w//336)+1)*144 + 1 + (h//336+1)*12) for h, w in shapes] 270 | # reshape to channel dimension -> (num_images, num_crops, 3, 336, 336) 271 | # (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336) 272 | hd_images_reshape = [im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336, 336).contiguous() for im, (h, w) in zip(hd_images, shapes)] 273 | # concat global image and local image 274 | hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)] 275 | 276 | # pad to max_num_crops 277 | image_transformed = [pad_to_max_num_crops_tensor(im, self.num_crops+1) for im in hd_images_reshape] 278 | image_transformed = torch.stack(image_transformed, dim=0) 279 | image_sizes = [torch.LongTensor(_shapes) for _shapes in shapes] 280 | padded_images = image_transformed 281 | image_sizes = shapes 282 | 283 | data = {"pixel_values": padded_images, 284 | "image_sizes": image_sizes, 285 | "num_img_tokens": num_img_tokens 286 | } 287 | 288 | return BatchFeature(data=data, tensor_type=return_tensors) 289 | 290 | AutoImageProcessor.register("Phi3VImageProcessor", Phi3VImageProcessor) 291 | 292 | transformers.Phi3VImageProcessor = Phi3VImageProcessor 293 | 294 | class Phi3VProcessor(ProcessorMixin): 295 | r""" 296 | Constructs a Phi3-V processor which wraps a Phi3-V image processor and a LLaMa tokenizer into a single processor. 297 | 298 | [`Phi3VProcessor`] offers all the functionalities of [`Phi3VImageProcessor`] and [`LlamaTokenizerFast`]. See the 299 | [`~Phi3VProcessor.__call__`] and [`~Phi3VProcessor.decode`] for more information. 300 | 301 | Args: 302 | image_processor ([`Phi3VImageProcessor`], *optional*): 303 | The image processor is a required input. 304 | tokenizer ([`LlamaTokenizerFast`], *optional*): 305 | The tokenizer is a required input. 
306 | """ 307 | 308 | attributes = ["image_processor", "tokenizer"] 309 | image_processor_class = "Phi3VImageProcessor" 310 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") 311 | special_image_token = "<|image|>" 312 | 313 | def __init__(self, image_processor, tokenizer): 314 | self.image_processor = image_processor 315 | self.tokenizer = tokenizer 316 | self.num_img_tokens = image_processor.num_img_tokens 317 | self.img_tokens = [f"<|image_{i+1}|>" for i in range(1000000)] 318 | 319 | def __call__( 320 | self, 321 | text: Union[TextInput, List[TextInput]], 322 | images: ImageInput = None, 323 | padding: Union[bool, str, PaddingStrategy] = False, 324 | truncation: Union[bool, str, TruncationStrategy] = None, 325 | max_length=None, 326 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 327 | qs_highlighted_parts: List[str] = None 328 | ) -> BatchFeature: 329 | """ 330 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 331 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode 332 | the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to 333 | Phi3ImageProcessor's [`~Phi3ImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 334 | of the above two methods for more information. 335 | 336 | Args: 337 | text (`str`, `List[str]`, `List[List[str]]`): 338 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 339 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 340 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 341 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 342 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 343 | tensor. Both channels-first and channels-last formats are supported. 344 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): 345 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 346 | index) among: 347 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 348 | sequence if provided). 349 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 350 | acceptable input length for the model if that argument is not provided. 351 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 352 | lengths). 353 | max_length (`int`, *optional*): 354 | Maximum length of the returned list and optionally padding length (see above). 355 | truncation (`bool`, *optional*): 356 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 357 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 358 | If set, will return tensors of a particular framework. Acceptable values are: 359 | 360 | - `'tf'`: Return TensorFlow `tf.constant` objects. 361 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 362 | - `'np'`: Return NumPy `np.ndarray` objects. 363 | - `'jax'`: Return JAX `jnp.ndarray` objects. 364 | 365 | Returns: 366 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 367 | 368 | - **input_ids** -- List of token ids to be fed to a model. 
Returned when `text` is not `None`. 369 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 370 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 371 | `None`). 372 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 373 | """ 374 | if images is not None: 375 | image_inputs = self.image_processor(images, return_tensors=return_tensors) 376 | else: 377 | image_inputs = {} 378 | inputs = self._convert_images_texts_to_inputs(image_inputs, text, padding=padding, truncation=truncation, 379 | max_length=max_length, return_tensors=return_tensors, 380 | qs_highlighted_parts=qs_highlighted_parts) 381 | 382 | return inputs 383 | 384 | def calc_num_image_tokens(self, images: ImageInput): 385 | """ Calculate the number of image tokens for each image. 386 | Args: 387 | images (`ImageInput`): 388 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 389 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 390 | """ 391 | return self.image_processor.calc_num_image_tokens(images) 392 | 393 | def calc_num_image_tokens_from_image_size(self, width, height): 394 | """ Calculate the number of image token for an image with given width and height. 395 | Args: 396 | width (`int`): 397 | Width of the image. 398 | height (`int`): 399 | Height of the image. 400 | """ 401 | return self.image_processor.calc_num_image_tokens_from_image_size(width, height) 402 | 403 | 404 | @property 405 | def special_image_token_id(self): 406 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token) 407 | 408 | def get_special_image_token_id(self): 409 | return self.tokenizer.convert_tokens_to_ids(self.special_image_token) 410 | 411 | def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None, qs_highlighted_parts=[]): 412 | 413 | if not len(images): 414 | model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length, padding_side='left') 415 | return BatchFeature(data={**model_inputs}) 416 | 417 | pattern = r"<\|image_\d+\|>" 418 | # prompt_chunks = [self.tokenizer(chunk).input_ids for chunk in re.split(pattern, texts)] 419 | 420 | if 'num_img_tokens' in images: 421 | num_img_tokens = images['num_img_tokens'] 422 | else: 423 | assert 'num_crops' in images, 'num_crops must be provided in images if num_img_tokens is not provided' 424 | num_crops = images['num_crops'] 425 | num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops] 426 | 427 | images, image_sizes = images['pixel_values'], images['image_sizes'] 428 | 429 | # image_tags needs to start from 1 to n 430 | image_tags = re.findall(pattern, texts) 431 | # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags] 432 | # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)] 433 | image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] 434 | unique_image_ids = sorted(list(set(image_ids))) 435 | # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5] 436 | # check the condition 437 | assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. 
[1, 2, 3], cannot be {unique_image_ids}" 438 | # total images must be the same as the number of image tags 439 | assert len(unique_image_ids) == len(images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images" 440 | 441 | image_ids_pad = [[-iid]*num_img_tokens[iid-1] for iid in image_ids] 442 | 443 | def insert_separator(X, sep_list): 444 | if len(X) > len(sep_list): 445 | sep_list.append([]) 446 | return [ele for sublist in zip(X, sep_list) for ele in sublist] 447 | 448 | # input_ids = [] 449 | # offset = 0 450 | # for x in insert_separator(prompt_chunks, image_ids_pad): 451 | # input_ids.extend(x[offset:]) 452 | # 453 | # input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) 454 | # attention_mask = (input_ids > -1000000).to(torch.long) 455 | # 456 | # return BatchFeature(data={"input_ids": input_ids, 457 | # "attention_mask": attention_mask, 458 | # "pixel_values": images, 459 | # "image_sizes": image_sizes}) 460 | highlight_attention_mask = [] 461 | prompt_chunks = [] 462 | # Generate highlight masks for each chunk 463 | for chunk in re.split(pattern, texts): 464 | chunk_mask, _ = txt_highlight_mask(self.tokenizer, chunk, qs_highlighted_parts) 465 | highlight_attention_mask.append([0] + chunk_mask) 466 | a = self.tokenizer(chunk) 467 | prompt_chunks.append(a.input_ids) 468 | 469 | offset = 0 470 | input_ids = [] 471 | combined_highlight_mask = [] 472 | zero_mask_padding = [[0] * len(pad) for pad in 473 | image_ids_pad] # Create zero padding mask with the same length as image ids 474 | 475 | for tokens, mask in zip(insert_separator(prompt_chunks, image_ids_pad), 476 | insert_separator(highlight_attention_mask, 477 | zero_mask_padding)): # Use zero_mask_padding here 478 | input_ids.extend(tokens[offset:]) 479 | combined_highlight_mask.extend(mask[offset:]) 480 | 481 | input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) 482 | attention_mask = (input_ids > -1000000).to(torch.long) 483 | combined_highlight_mask = torch.tensor(combined_highlight_mask, dtype=torch.long).unsqueeze(0) 484 | 485 | return BatchFeature(data={"input_ids": input_ids, 486 | "attention_mask": attention_mask, 487 | "highlight_attention_mask": combined_highlight_mask, 488 | "pixel_values": images, 489 | "image_sizes": image_sizes}) 490 | 491 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama 492 | def batch_decode(self, *args, **kwargs): 493 | """ 494 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 495 | refer to the docstring of this method for more information. 496 | """ 497 | return self.tokenizer.batch_decode(*args, **kwargs) 498 | 499 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama 500 | def decode(self, *args, **kwargs): 501 | """ 502 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 503 | the docstring of this method for more information. 
504 | """ 505 | return self.tokenizer.decode(*args, **kwargs) 506 | 507 | @property 508 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names 509 | def model_input_names(self): 510 | tokenizer_input_names = self.tokenizer.model_input_names 511 | image_processor_input_names = self.image_processor.model_input_names 512 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) --------------------------------------------------------------------------------