├── Models ├── RadFM │ ├── __init__.py │ ├── Language_files │ │ ├── special_tokens_map.json │ │ ├── tokenizer.model │ │ ├── tokenizer_config.json │ │ └── config.json │ ├── utils.py │ ├── multimodality_model.py │ ├── vit_3d.py │ ├── position_encoding.py │ ├── transformer_decoder.py │ ├── helpers.py │ ├── my_embedding_layer.py │ └── blocks.py ├── Med_Flamingo │ └── src │ │ ├── __init__.py │ │ └── utils.py ├── __init__.py ├── gemini.py ├── claude.py ├── gpt.py ├── blip2.py ├── instructblip.py ├── molmo.py ├── llava.py ├── llama.py ├── llava_med.py ├── med_flamingo.py ├── deepseek.py └── radfm.py ├── assets ├── banner.png ├── logo.png ├── example.png └── samples.png ├── data ├── stats │ ├── plots │ │ ├── DINO_hist.jpg │ │ └── BiomedCLIP_hist.jpg │ └── statistics.json └── few_shot_prompts │ ├── PMC1064097_F2.jpg │ ├── PMC1065025_F1.jpg │ └── PMC1087855_F3.jpg ├── configs ├── Models │ ├── gemini │ │ └── vanilla.json │ ├── claude │ │ └── vanilla.json │ ├── gpt │ │ └── vanilla.json │ ├── blip2 │ │ └── vanilla.json │ ├── molmo │ │ └── vanilla.json │ ├── radfm │ │ └── vanilla.json │ ├── deepseek │ │ └── vanilla.json │ ├── instructblip │ │ └── vanilla.json │ ├── llama │ │ └── vanilla.json │ ├── med_flamingo │ │ └── vanilla.json │ ├── llava │ │ └── vanilla.json │ └── llava_med │ │ └── vanilla.json └── prompts │ └── answering.json ├── requirements.txt ├── scripts ├── printing.py ├── download.py └── answering.py ├── LICENSE ├── utils ├── io_tools.py └── answering.py └── README.md /Models/RadFM/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Models/Med_Flamingo/src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Models/RadFM/Language_files/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/assets/banner.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/assets/logo.png -------------------------------------------------------------------------------- /assets/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/assets/example.png -------------------------------------------------------------------------------- /assets/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/assets/samples.png -------------------------------------------------------------------------------- /data/stats/plots/DINO_hist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/data/stats/plots/DINO_hist.jpg -------------------------------------------------------------------------------- /data/stats/plots/BiomedCLIP_hist.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/data/stats/plots/BiomedCLIP_hist.jpg -------------------------------------------------------------------------------- /data/few_shot_prompts/PMC1064097_F2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/data/few_shot_prompts/PMC1064097_F2.jpg -------------------------------------------------------------------------------- /data/few_shot_prompts/PMC1065025_F1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/data/few_shot_prompts/PMC1065025_F1.jpg -------------------------------------------------------------------------------- /data/few_shot_prompts/PMC1087855_F3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/data/few_shot_prompts/PMC1087855_F3.jpg -------------------------------------------------------------------------------- /configs/Models/gemini/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "temperature": 0.2, 3 | "deployment_name": "gemini-2.0-flash", 4 | "init_prompt_id": 1 5 | } -------------------------------------------------------------------------------- /Models/RadFM/Language_files/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MShahabSepehri/MediConfusion/main/Models/RadFM/Language_files/tokenizer.model -------------------------------------------------------------------------------- /configs/Models/claude/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "temperature": 0.2, 3 | "deployment_name": "claude-3-opus-20240229", 4 | "init_prompt_id": 1 5 | } -------------------------------------------------------------------------------- /configs/Models/gpt/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "temperature": 0.7, 3 | "deployment_name": "o1", 4 | "api_version": "2025-01-01-preview", 5 | "init_prompt_id": 1 6 | } -------------------------------------------------------------------------------- /Models/RadFM/Language_files/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} -------------------------------------------------------------------------------- /configs/Models/blip2/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_beams": 5, 3 | "max_new_tokens": 30, 4 | "top_p": 0.9, 5 | "temperature": 1, 6 | "model_id": "Salesforce/blip2-opt-2.7b" 7 | } -------------------------------------------------------------------------------- /configs/Models/molmo/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_beams": 1, 3 | "max_new_tokens": 30, 4 | "top_p": null, 5 | "temperature": 0.7, 6 | "model_id": "allenai/Molmo-7B-D-0924" 7 | } -------------------------------------------------------------------------------- /configs/Models/radfm/vanilla.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "max_new_tokens": 64, 3 | "language_files_path": "/data/models/radfm/Language_files", 4 | "model_path": "/data/models/radfm/pytorch_model.bin" 5 | } -------------------------------------------------------------------------------- /configs/Models/deepseek/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "deepseek-ai/deepseek-vl2", 3 | "temperature": 0.7, 4 | "init_prompt_id": 1, 5 | "max_new_tokens": 512, 6 | "do_sample": false 7 | } -------------------------------------------------------------------------------- /configs/Models/instructblip/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_beams": 5, 3 | "max_new_tokens": 128, 4 | "top_p": 0.9, 5 | "temperature": 1, 6 | "model_id": "Salesforce/instructblip-vicuna-7b" 7 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyaml 2 | openai 3 | google.generativeai 4 | open_clip_torch 5 | torch 6 | pandas 7 | oauth2client 8 | python-dotenv 9 | open_flamingo 10 | anthropic 11 | ipython 12 | peft 13 | transformers 14 | tqdm==4.66.2 15 | fairscale 16 | fire 17 | blobfile 18 | -------------------------------------------------------------------------------- /data/stats/statistics.json: -------------------------------------------------------------------------------- 1 | { 2 | "Cerebral": 79, 3 | "Vascular": 73, 4 | "Head and Neck": 67, 5 | "Spinal": 51, 6 | "Musculoskeletal": 42, 7 | "Cardiac": 52, 8 | "Gastrointestinal": 43, 9 | "Pulmonary": 20, 10 | "Nuclear Medicine": 14 11 | } -------------------------------------------------------------------------------- /configs/Models/llama/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_base": null, 3 | "top_p": null, 4 | "num_beams": 1, 5 | "temperature": 0.2, 6 | "init_prompt_id": 1, 7 | "max_new_tokens": 100, 8 | "model_id": "meta-llama/Llama-3.2-90B-Vision-Instruct", 9 | "token": "" 10 | } -------------------------------------------------------------------------------- /configs/Models/med_flamingo/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_beams": 5, 3 | "max_new_tokens": 10, 4 | "top_p": 0.9, 5 | "temperature": 1, 6 | "LLaMa_PATH": "/data/models/llama", 7 | "CHECKPOINT_PATH": "/data/models/med_flamingo/model.pt", 8 | "IMAGE_PATH": "./data/few_shot_prompts" 9 | } -------------------------------------------------------------------------------- /configs/Models/llava/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_base": null, 3 | "top_p": null, 4 | "num_beams": 1, 5 | "conv_mode": "vicuna_v1", 6 | "temperature": 0.2, 7 | "init_prompt_id": 1, 8 | "max_new_tokens": 100, 9 | "use_im_start_end": false, 10 | "model_id": "llava-hf/llava-v1.6-mistral-7b-hf" 11 | } -------------------------------------------------------------------------------- /configs/Models/llava_med/vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_path": "/data/models/llava-med-v1.5-mistral-7b", 3 | "model_base": null, 4 | "top_p": null, 5 | "num_beams": 1, 6 | "conv_mode": "mistral_instruct", 7 | "temperature": 0.2, 8 | "init_prompt_id": 
1, 9 | "use_im_start_end": false, 10 | "max_new_tokens": 100 11 | } -------------------------------------------------------------------------------- /Models/__init__.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | def get_image(image): 4 | if type(image) is str: 5 | try: 6 | return Image.open(image).convert("RGB") 7 | except Exception as e: 8 | print(f"Fail to read image: {image}") 9 | exit(-1) 10 | elif type(image) is Image.Image: 11 | return image 12 | else: 13 | raise NotImplementedError(f"Invalid type of Image: {type(image)}") -------------------------------------------------------------------------------- /Models/RadFM/Language_files/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/llama-13b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "bos_token_id": 0, 7 | "eos_token_id": 1, 8 | "hidden_act": "silu", 9 | "hidden_size": 5120, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 13824, 12 | "max_sequence_length": 2048, 13 | "model_type": "llama", 14 | "num_attention_heads": 40, 15 | "num_hidden_layers": 40, 16 | "pad_token_id": -1, 17 | "rms_norm_eps": 1e-06, 18 | "tie_word_embeddings": false, 19 | "torch_dtype": "float32", 20 | "transformers_version": "4.28.0.dev0", 21 | "use_cache": true, 22 | "vocab_size": 32000 23 | } 24 | -------------------------------------------------------------------------------- /scripts/printing.py: -------------------------------------------------------------------------------- 1 | import os, sys, pathlib 2 | sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())) 3 | 4 | import argparse 5 | from utils import io_tools 6 | from utils.answering import DEFAULT_MODEL_CONFIGS, BaseAnsweringModel 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--mllm_name", type=str, required=True, choices=set(DEFAULT_MODEL_CONFIGS.keys())) 11 | parser.add_argument("--mode", type=str, required=True, choices={'gpt4', 'mc', 'greedy', 'prefix'}) 12 | args = parser.parse_args() 13 | 14 | return args 15 | 16 | if __name__ == "__main__": 17 | args = get_args() 18 | ROOT = io_tools.get_root(__file__, 2) 19 | 20 | load_path = f'{ROOT}/Results/{args.mllm_name}/{args.mllm_name}_{args.mode}_score.json' 21 | 22 | score = io_tools.load_json(load_path) 23 | BaseAnsweringModel.print_score(score) 24 | -------------------------------------------------------------------------------- /Models/gemini.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from IPython.display import Image 4 | import google.generativeai as genai 5 | 6 | def configure_client(): 7 | # Load environment variables from .env file 8 | load_dotenv() 9 | 10 | # Access the API key 11 | genai.configure(api_key=os.getenv('GEMINI_API_KEY')) 12 | 13 | 14 | def load_model(init_prompt, temperature, deployment_name): 15 | configure_client() 16 | model = genai.GenerativeModel(deployment_name, 17 | generation_config=genai.GenerationConfig(temperature=temperature), 18 | system_instruction=init_prompt, 19 | ) 20 | return model 21 | 22 | # Based on https://github.com/google-gemini/cookbook/blob/main/quickstarts/Prompting.ipynb 23 | def ask_question(model, image_path, question): 24 | img = Image(image_path) 25 | response = model.generate_content([question, img]) 26 | return 
response.text -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 USC Center on AI Foundations for Science (AIF4S) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Models/claude.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from dotenv import load_dotenv 4 | from anthropic import Anthropic 5 | 6 | def get_base64_encoded_image(image_path): 7 | with open(image_path, "rb") as image_file: 8 | binary_data = image_file.read() 9 | base_64_encoded_data = base64.b64encode(binary_data) 10 | base64_string = base_64_encoded_data.decode('utf-8') 11 | return base64_string 12 | 13 | 14 | def get_client(): 15 | load_dotenv() 16 | client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) 17 | return client 18 | 19 | def ask_question(client, image_path, question, init_prompt, temperature, deployment_name): 20 | message_list = [ 21 | { 22 | "role": 'user', 23 | "content": [ 24 | {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": get_base64_encoded_image(image_path)}}, 25 | {"type": "text", "text": question} 26 | ] 27 | } 28 | ] 29 | response = client.messages.create( 30 | model=deployment_name, 31 | max_tokens=2048, 32 | messages=message_list, 33 | temperature=temperature, 34 | system=init_prompt, 35 | ) 36 | return response.content[0].text -------------------------------------------------------------------------------- /Models/Med_Flamingo/src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class AbstractProcessor(ABC): 6 | """ 7 | Abstract class for processors to show what methods they need to implement. 8 | Processors handle text encoding and image preprocessing. 9 | """ 10 | @abstractmethod 11 | def encode_text(self, prompt): 12 | pass 13 | 14 | @abstractmethod 15 | def preprocess_images(self, images: list): 16 | pass 17 | 18 | 19 | class FlamingoProcessor(AbstractProcessor): 20 | """ 21 | Processor class for Flamingo. 
22 | """ 23 | def __init__(self, tokenizer, vision_processor): 24 | """ 25 | OF does not use same vision processor, image_processor only transforms single image 26 | """ 27 | self.tokenizer = tokenizer 28 | self.vision_processor = vision_processor 29 | 30 | def encode_text(self, prompt): 31 | self.tokenizer.padding_side = "left" 32 | # For generation padding tokens should be on the left 33 | return self.tokenizer([prompt], 34 | return_tensors="pt", 35 | ) 36 | 37 | def preprocess_images(self, images: list): 38 | vision_x = [self.vision_processor(im).unsqueeze(0) for im in images] 39 | vision_x = torch.cat(vision_x, dim=0) 40 | return vision_x 41 | 42 | 43 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | import os, sys, pathlib 2 | sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())) 3 | 4 | import shutil 5 | import argparse 6 | import subprocess 7 | from tqdm import tqdm 8 | from utils import io_tools 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--save_path", type=str, default='data/images') 13 | parser.add_argument("--image_dict_path", type=str, default='data/image_dict.json') 14 | args = parser.parse_args() 15 | return args 16 | 17 | if __name__ == "__main__": 18 | args = get_args() 19 | image_dict = io_tools.load_json(args.image_dict_path) 20 | 21 | if os.path.isdir(args.save_path): 22 | shutil.rmtree(args.save_path) 23 | os.makedirs(args.save_path) 24 | 25 | for key in tqdm(image_dict.keys()): 26 | sample = image_dict.get(key) 27 | local = sample.get('local') 28 | roco_id = key.split('/')[-1].replace('.jpg', '') 29 | link = sample.get('dlink') 30 | file_name = sample.get('file_name') 31 | pmc_name = link.split('/')[-1].replace('.tar.gz', '') 32 | subprocess.call(['wget', '-q', link, '-P', f'{args.save_path}/']) 33 | subprocess.call(['tar', '-xzf', f'{args.save_path}/{pmc_name}.tar.gz', '-C', args.save_path]) 34 | shutil.copy(f'{args.save_path}/{pmc_name}/{file_name}', f'{args.save_path}/{local}.jpg') 35 | shutil.rmtree(f'{args.save_path}/{pmc_name}') 36 | os.remove(f'{args.save_path}/{pmc_name}.tar.gz') -------------------------------------------------------------------------------- /scripts/answering.py: -------------------------------------------------------------------------------- 1 | import os, sys, pathlib 2 | sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())) 3 | 4 | import argparse 5 | from utils import io_tools 6 | from utils.answering import ANSWERING_CLASS_DICT, DEFAULT_MODEL_CONFIGS 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--tr", type=int, default=3) 11 | parser.add_argument("--mllm_name", type=str, required=True, choices=set(DEFAULT_MODEL_CONFIGS.keys())) 12 | parser.add_argument("--resume_path", type=str, default=None) 13 | parser.add_argument("--model_args_path", type=str, default=None) 14 | parser.add_argument("--local_image_address", type=bool, default=True) 15 | parser.add_argument("--data_path", type=str, default='./data/images') 16 | parser.add_argument("--mode", type=str, required=True, choices={'gpt4', 'mc', 'greedy', 'prefix'}) 17 | parser.add_argument("--device", type=str, default='cuda') 18 | args = parser.parse_args() 19 | 20 | if args.model_args_path is None: 21 | args.model_args_path = DEFAULT_MODEL_CONFIGS.get(args.mllm_name) 22 | 23 | return args 24 | 25 | if __name__ == 
"__main__": 26 | args = get_args() 27 | ROOT = io_tools.get_root(__file__, 2) 28 | 29 | save_path = f'{ROOT}/Results/' 30 | 31 | answering_class = ANSWERING_CLASS_DICT.get(args.mllm_name) 32 | 33 | ans_obj = answering_class(args.model_args_path, args.mode, args.data_path, args.local_image_address, args.tr, args.device) 34 | ans_obj.evaluate(args.resume_path, save_path) 35 | -------------------------------------------------------------------------------- /utils/io_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import yaml 4 | import torch 5 | import pickle 6 | import pathlib 7 | import importlib 8 | 9 | 10 | def get_device(): 11 | if torch.cuda.is_available(): 12 | return 'cuda' 13 | return 'cpu' 14 | 15 | def get_obj_from_str(string, reload=False): 16 | module, cls = string.rsplit(".", 1) 17 | if reload: 18 | module_imp = importlib.import_module(module) 19 | importlib.reload(module_imp) 20 | return getattr(importlib.import_module(module, package=None), cls) 21 | 22 | 23 | def instantiate_from_config(config): 24 | if not "target" in config: 25 | if config == '__is_first_stage__': 26 | return None 27 | elif config == "__is_unconditional__": 28 | return None 29 | raise KeyError("Expected key `target` to instantiate.") 30 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 31 | 32 | 33 | def load_config_from_yaml(path): 34 | config_file = pathlib.Path(path) 35 | if config_file.exists(): 36 | with config_file.open('r') as f: 37 | d = yaml.safe_load(f) 38 | return d 39 | else: 40 | raise ValueError('Config file does not exist.') 41 | 42 | 43 | def str2int(s): 44 | return int.from_bytes(s.encode(), 'little') % (2 ** 32 - 1) 45 | 46 | 47 | def save_pickle(data, path): 48 | with open(path, 'wb') as file: 49 | pickle.dump(data, file) 50 | 51 | 52 | def load_pickle(path): 53 | with open(path, 'rb') as file: 54 | tmp = pickle.load(file) 55 | return tmp 56 | 57 | 58 | def check_and_create_dir(path): 59 | if not os.path.exists(path): 60 | os.makedirs(path) 61 | 62 | 63 | def get_root(file, num_returns=1): 64 | tmp = pathlib.Path(file) 65 | for _ in range(num_returns): 66 | tmp = tmp.parent.resolve() 67 | return tmp 68 | 69 | 70 | def load_json(path): 71 | with open(path, 'r') as f: 72 | return json.load(f) 73 | 74 | def save_json(data, path): 75 | with open(path, 'w') as f: 76 | json.dump(data, f, indent=4) 77 | 78 | def modify_json(data, path): 79 | old_data = load_json(path) 80 | new_keys = 0 81 | for key in data.keys(): 82 | if key in old_data: 83 | continue 84 | old_data[key] = data.get(key) 85 | new_keys += 1 86 | save_json(old_data, path) 87 | return new_keys 88 | 89 | def load_resume_dict(path): 90 | if path is None: 91 | return {} 92 | return load_json(path) -------------------------------------------------------------------------------- /Models/RadFM/utils.py: -------------------------------------------------------------------------------- 1 | from .blocks import ModifiedResNet,PMC_CLIP_cfg 2 | import torch 3 | from torchvision import transforms 4 | from PIL import Image 5 | import torch.nn as nn 6 | def extend_instance(obj, mixin): 7 | """Apply mixins to a class instance after creation""" 8 | base_cls = obj.__class__ 9 | base_cls_name = obj.__class__.__name__ 10 | obj.__class__ = type( 11 | base_cls_name, (mixin, base_cls), {} 12 | ) # mixin needs to go first for our forward() logic to work 13 | 14 | 15 | def getattr_recursive(obj, att): 16 | """ 17 | Return nested attribute of obj 18 | 
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 19 | """ 20 | if att == "": 21 | return obj 22 | i = att.find(".") 23 | if i < 0: 24 | return getattr(obj, att) 25 | else: 26 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 27 | 28 | 29 | def setattr_recursive(obj, att, val): 30 | """ 31 | Set nested attribute of obj 32 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 33 | """ 34 | if "." in att: 35 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 36 | setattr(obj, att.split(".")[-1], val) 37 | 38 | 39 | 40 | def get_visual_encoder(model_str): 41 | """ 42 | Args: 43 | str (_type_): str_to_model_path 44 | Return: 45 | vision_model, visual_dim, img_preprocessor 46 | """ 47 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 48 | img_preprocessor = transforms.Compose([ 49 | transforms.Resize((512,512), interpolation=Image.BICUBIC), 50 | transforms.ToTensor(), 51 | normalize, 52 | ]) 53 | if 'PMC-CLIP' in model_str: 54 | #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg'] 55 | vision_cfg = PMC_CLIP_cfg() 56 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 57 | vision_model = ModifiedResNet( 58 | layers=vision_cfg.layers, 59 | heads=vision_heads, 60 | output_dim = 768, 61 | image_size=vision_cfg.image_size, 62 | width=vision_cfg.width 63 | ) 64 | vision_model = vision_load_pretrain(vision_model,model_str) 65 | vision_model = nn.Sequential(*list(vision_model.children())[:-2]) 66 | visual_dim = 1024 67 | return vision_model,visual_dim,img_preprocessor 68 | 69 | def vision_load_pretrain(resnet,model_path): 70 | checkpoint = torch.load(model_path, map_location='cpu') 71 | state_dict = checkpoint['state_dict'] 72 | state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k} 73 | resnet.load_state_dict(state_dict) 74 | return resnet 75 | -------------------------------------------------------------------------------- /Models/gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | from openai import AzureOpenAI 4 | from dotenv import load_dotenv 5 | 6 | def get_client(max_retries=2, timeout=20, api_version="2025-01-01-preview", override=True): 7 | load_dotenv(override=override) 8 | client = AzureOpenAI(api_key=os.getenv("AZURE_OPENAI_API_KEY"), 9 | api_version=api_version, 10 | azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT"), 11 | max_retries=max_retries, 12 | timeout=timeout, 13 | ) 14 | 15 | return client 16 | 17 | def get_response(client, deployment_name, init_prompt, prompt, temperature, max_retry=3, print_error=True): 18 | counter = max_retry 19 | response = None 20 | while counter > 0: 21 | try: 22 | if deployment_name == 'o1': 23 | response = client.chat.completions.create(model=deployment_name, 24 | messages=[ 25 | {"role": "user", "content": prompt}, 26 | ], 27 | ) 28 | else: 29 | response = client.chat.completions.create(model=deployment_name, 30 | messages=[ 31 | {"role": "system", "content": init_prompt}, 32 | {"role": "user", "content": prompt}, 33 | ], 34 | temperature=temperature, 35 | ) 36 | response = response.choices[0].message.content 37 | break 38 | except Exception as e: 39 | if print_error: 40 | print(e) 41 | counter -= 1 42 | return response 43 | 44 | def encode_image(image_path): 45 | with open(image_path, "rb") as image_file: 46 | return base64.b64encode(image_file.read()).decode('utf-8') 47 | 48 | 49 | def 
ask_question(client, image_path, question, init_prompt, deployment_name, temperature): 50 | base64_image = encode_image(image_path) 51 | content = [ 52 | {"type": "text", 53 | "text": question 54 | }, 55 | {"type": "image_url", 56 | "image_url": { 57 | "url": f"data:image/jpeg;base64,{base64_image}" 58 | } 59 | } 60 | ] 61 | response = get_response(client=client, 62 | deployment_name=deployment_name, 63 | init_prompt=init_prompt, 64 | prompt=content, 65 | temperature=temperature) 66 | 67 | return response -------------------------------------------------------------------------------- /Models/RadFM/multimodality_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from transformers.models.llama import LlamaForCausalLM 3 | from transformers import AutoConfig 4 | from .my_embedding_layer import MyEmbedding 5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 6 | import tqdm.auto as tqdm 7 | import torch.nn as nn 8 | import torch 9 | from torch.utils.checkpoint import checkpoint 10 | from torch.autograd import Variable 11 | import numpy as np 12 | class MultiLLaMAForCausalLM(nn.Module): 13 | def __init__(self, lang_model_path): 14 | super(MultiLLaMAForCausalLM, self).__init__() 15 | try: 16 | self.lang_model = LlamaForCausalLM.from_pretrained(lang_model_path) 17 | except: 18 | config = AutoConfig.from_pretrained(lang_model_path) 19 | self.lang_model = LlamaForCausalLM(config) 20 | self.lang_model.gradient_checkpointing_enable() 21 | self.lang_model.enable_input_require_grads() 22 | # self.lang_model.requires_grad_(False) 23 | self.embedding_layer = MyEmbedding() 24 | self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight 25 | self.hidden_dim = 5120 26 | self.voc_size = 32000 27 | 28 | def forward(self, lang_x, vision_x, attention_mask, labels, loss_reweight, key_words_query): 29 | if labels.shape == lang_x.shape: 30 | self.embedding_layer.flag = 'Text' 31 | # lang_x = lang_x.to(vision_x.dtype) 32 | # lang_x = lang_x + torch.zeros(1, dtype=lang_x.dtype, device=lang_x.device, requires_grad=True) 33 | # vision_x = vision_x + torch.zeros(1, dtype=vision_x.dtype, device=vision_x.device, requires_grad=True) 34 | # input_embedding = checkpoint(self.embedding_layer, lang_x, vision_x) 35 | input_embedding,loss_match= self.embedding_layer(lang_x, vision_x, key_words_query) # ,loss_matching 36 | output = self.lang_model(inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels) 37 | logits = output['logits'] 38 | 39 | loss_reg = None 40 | if labels is not None: 41 | # Shift so that tokens < n predict n 42 | shift_logits = logits[..., :-1, :].contiguous() 43 | shift_labels = labels[..., 1:].contiguous() 44 | shift_loss_reweight = loss_reweight[...,1:].contiguous() 45 | # Flatten the tokens 46 | loss_fct = CrossEntropyLoss(reduction = 'none') 47 | shift_logits = shift_logits.view(-1, self.voc_size) 48 | shift_labels = shift_labels.view(-1) 49 | shift_loss_reweight = shift_loss_reweight.view(-1) 50 | # Enable model parallelism 51 | shift_labels = shift_labels.to(shift_logits.device) 52 | shift_loss_reweight = shift_loss_reweight.to(shift_logits.device) 53 | loss_reg = loss_fct(shift_logits, shift_labels) 54 | loss_reg = torch.sum(shift_loss_reweight*loss_reg)/torch.sum(shift_loss_reweight) 55 | loss = loss_reg 56 | if loss_match!= None: 57 | loss = 0.8*loss + 0.2*loss_match 58 | 59 | logits = output['logits'][..., :-1, :].contiguous().detach() 60 | total = len(labels) 61 | predictions 
= torch.argmax(logits, dim=-1) 62 | labels = labels[..., 1:].contiguous() 63 | Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100),dim = -1)) 64 | Accuracy = Acc /total 65 | 66 | return dict( 67 | # loss_reg = loss_reg, 68 | # loss_matching = loss_matching, 69 | logits = Accuracy, 70 | loss = output['loss'], 71 | ) 72 | ### useless for now ignore the folowing codes ### 73 | # if labels.shape == vision_x.shape: 74 | # self.embedding_layer.flag = 'Seg' 75 | # input_embedding = self.embedding_layer(lang_x, vision_x) 76 | 77 | def generate(self, lang_x,vision_x): 78 | self.embedding_layer.flag = 'Text' 79 | with torch.no_grad(): 80 | input_embedding,_ = self.embedding_layer(lang_x, vision_x) 81 | generation = self.lang_model.generate(inputs_embeds=input_embedding, max_new_tokens=200, top_k=50) 82 | return generation 83 | -------------------------------------------------------------------------------- /Models/blip2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 4 | 5 | 6 | def load_model(model_id, device='cuda'): 7 | processor = Blip2Processor.from_pretrained(model_id) 8 | model = Blip2ForConditionalGeneration.from_pretrained(model_id) 9 | model.to(device) 10 | return model, processor 11 | 12 | def ask_question(model, question, image_path, processor, num_beams, max_length, top_p, temperature, mode): 13 | image = Image.open(image_path).convert("RGB") 14 | if mode == 'prefix': 15 | return do_prefix_forward(model, question, image, processor) 16 | if mode == 'greedy': 17 | return do_forward(model, processor, image, question) 18 | elif mode in ['mc', 'gpt4']: 19 | return do_generation(model, 20 | processor, 21 | image, 22 | question, 23 | num_beams=num_beams, 24 | top_p=top_p, 25 | temperature=temperature, 26 | max_new_tokens=max_length) 27 | 28 | @torch.no_grad() 29 | def do_generation(model, 30 | processor, 31 | image, 32 | question, 33 | num_beams, 34 | top_p, 35 | temperature, 36 | max_new_tokens): 37 | device = model.device 38 | inputs = processor(images=image, text=question, return_tensors="pt").to(device=device) 39 | outputs = model.generate(**inputs, 40 | do_sample=(temperature > 0), 41 | num_beams=num_beams, 42 | top_p=top_p, 43 | temperature=temperature, 44 | max_new_tokens=max_new_tokens) 45 | generated_text = processor.decode(outputs[0], skip_special_tokens=True) 46 | return generated_text 47 | 48 | @torch.no_grad() 49 | def do_forward(model, processor, image, question): 50 | VALID_ANSWERS = ['A', 'B'] 51 | device = model.device 52 | TOKEN_IDs = [processor.tokenizer(x, return_tensors="pt", add_special_tokens=False).get('input_ids') for x in VALID_ANSWERS] 53 | inputs = processor(images=image, text=question, return_tensors="pt").to(device=device) 54 | logits = model.forward(**inputs).logits 55 | logits = logits[0, -1, :] 56 | logits = logits.reshape(-1, 1) 57 | soft_max = torch.nn.Softmax(dim=0) 58 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs][:len(VALID_ANSWERS)])) 59 | outputs = VALID_ANSWERS[probs.argmax().item()] 60 | return outputs 61 | 62 | @torch.no_grad() 63 | def do_prefix_forward(model, problem, image, processor): 64 | # PREFIX_PROMPT_TEMPLATE = "Question: {} Answer: {}" 65 | device = model.device 66 | PREFIX_PROMPT_TEMPLATE = problem.get('format') 67 | scores = [] 68 | questions = [] 69 | qs = problem["question"] 70 | 71 | for option in [problem["option_A"], problem["option_B"]]: 
72 | prompt = PREFIX_PROMPT_TEMPLATE.format(qs, option) 73 | questions.append(prompt) 74 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device) 75 | answer_tokens = processor.tokenizer.encode(' ' + option, add_special_tokens=False) 76 | num_answer_tokens = len(answer_tokens) 77 | input_ids = inputs["input_ids"] 78 | # try to find the answer tokens in input ids 79 | start_indices = [] 80 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 81 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 82 | start_indices.append(i) 83 | 84 | if len(start_indices) == 0: 85 | raise ValueError("Answer tokens not found in input_ids") 86 | answer_start = start_indices[-1] 87 | answer_start_from_back = answer_start - input_ids.size(1) 88 | with torch.inference_mode(): 89 | out = model(**inputs 90 | ) 91 | # shift by 1 compared to input 92 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 93 | probs = torch.nn.functional.softmax(logits, dim=-1) 94 | 95 | # Pick the probabilities corresponding to each of the answer tokens 96 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 97 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 98 | scores.append(prefix_score.item()) 99 | 100 | outputs = "A" if scores[0] > scores[1] else "B" 101 | return outputs -------------------------------------------------------------------------------- /Models/instructblip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration 4 | 5 | def load_model(model_id, device='cpu'): 6 | processor = InstructBlipProcessor.from_pretrained(model_id, device=device) 7 | model = InstructBlipForConditionalGeneration.from_pretrained(model_id) 8 | model.to(device=device) 9 | return model, processor 10 | 11 | 12 | def ask_question(model, question, image_path, processor, num_beams, max_length, top_p, temperature, mode): 13 | image = Image.open(image_path).convert("RGB") 14 | device = model.device 15 | if mode == 'prefix': 16 | return do_prefix_forward(model, question, image, processor) 17 | inputs = processor(images=image, text=question, return_tensors="pt").to(device=device, dtype=torch.float16) 18 | if mode == 'greedy': 19 | return do_forward(model, processor, inputs) 20 | elif mode in ['mc', 'gpt4']: 21 | return do_generation(model, 22 | processor, 23 | inputs, 24 | num_beams=num_beams, 25 | top_p=top_p, 26 | repetition_penalty=1.5, 27 | length_penalty=1, 28 | temperature=temperature, 29 | max_new_tokens=max_length) 30 | 31 | @torch.no_grad() 32 | def do_generation(model, 33 | processor, 34 | inputs, 35 | num_beams, 36 | top_p, 37 | repetition_penalty, 38 | length_penalty, 39 | temperature, 40 | max_new_tokens): 41 | 42 | outputs = model.generate( 43 | **inputs, 44 | num_beams=num_beams, 45 | do_sample=(temperature > 0), 46 | max_new_tokens=max_new_tokens, 47 | min_length=1, 48 | top_p=top_p, 49 | repetition_penalty=repetition_penalty, 50 | length_penalty=length_penalty, 51 | temperature=temperature, 52 | ) 53 | generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip() 54 | return generated_text 55 | 56 | @torch.no_grad() 57 | def do_forward(model, processor, inputs): 58 | VALID_ANSWERS = ['A', 'B'] 59 | TOKEN_IDs = [processor.tokenizer(x, return_tensors="pt", 
add_special_tokens=False).get('input_ids') for x in VALID_ANSWERS] 60 | logits = model.forward(**inputs).logits[0, -1, :] 61 | logits = logits.reshape(-1, 1) 62 | soft_max = torch.nn.Softmax(dim=0) 63 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 64 | outputs = VALID_ANSWERS[probs.argmax().item()] 65 | return outputs 66 | 67 | @torch.no_grad() 68 | def do_prefix_forward(model, problem, image, processor): 69 | # PREFIX_PROMPT_TEMPLATE = "Question: {} Answer: {}" 70 | device = model.device 71 | PREFIX_PROMPT_TEMPLATE = problem.get('format') 72 | scores = [] 73 | questions = [] 74 | qs = problem["question"] 75 | 76 | for option in [problem["option_A"], problem["option_B"]]: 77 | prompt = PREFIX_PROMPT_TEMPLATE.format(qs, option) 78 | questions.append(prompt) 79 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device, dtype=torch.float16) 80 | answer_tokens = processor.tokenizer.encode(" " + option, add_special_tokens=False)[1:] 81 | num_answer_tokens = len(answer_tokens) 82 | input_ids = inputs["input_ids"] 83 | 84 | # try to find the answer tokens in input ids 85 | start_indices = [] 86 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 87 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 88 | start_indices.append(i) 89 | 90 | if len(start_indices) == 0: 91 | raise ValueError("Answer tokens not found in input_ids") 92 | answer_start = start_indices[-1] 93 | answer_start_from_back = answer_start - input_ids.size(1) 94 | with torch.inference_mode(): 95 | out = model(**inputs 96 | ) 97 | # shift by 1 compared to input 98 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 99 | probs = torch.nn.functional.softmax(logits, dim=-1) 100 | 101 | # Pick the probabilities corresponding to each of the answer tokens 102 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 103 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 104 | scores.append(prefix_score.item()) 105 | outputs = "A" if scores[0] > scores[1] else "B" 106 | return outputs -------------------------------------------------------------------------------- /Models/molmo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig 4 | 5 | 6 | def load_model(model_id, device='cuda'): 7 | processor = AutoProcessor.from_pretrained( 8 | 'allenai/Molmo-7B-D-0924', 9 | trust_remote_code=True, 10 | torch_dtype='auto', 11 | device_map=device 12 | ) 13 | model = AutoModelForCausalLM.from_pretrained( 14 | model_id, 15 | trust_remote_code=True, 16 | torch_dtype='auto', 17 | device_map=device 18 | ) 19 | return model, processor 20 | 21 | def ask_question(model, question, image_path, processor, num_beams, max_length, top_p, temperature, mode): 22 | image = [Image.open(image_path).convert("RGB")] 23 | if mode == 'prefix': 24 | return do_prefix_forward(model, question, image, processor) 25 | inputs = processor.process(images=image, text=question) 26 | inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()} 27 | if mode == 'greedy': 28 | return do_forward(model, processor, inputs) 29 | elif mode in ['mc', 'gpt4']: 30 | return do_generation(model, 31 | processor, 32 | inputs, 33 | num_beams=num_beams, 34 | top_p=top_p, 35 | temperature=temperature, 36 | max_new_tokens=max_length) 37 | 38 | 
@torch.no_grad() 39 | def do_generation(model, 40 | processor, 41 | inputs, 42 | num_beams, 43 | top_p, 44 | temperature, 45 | max_new_tokens): 46 | gc = GenerationConfig( 47 | do_sample=(temperature > 0), 48 | num_beams=num_beams, 49 | top_p=top_p, 50 | temperature=temperature, 51 | max_new_tokens=max_new_tokens, 52 | stop_strings="<|endoftext|>", 53 | ) 54 | output = model.generate_from_batch(inputs, gc, tokenizer=processor.tokenizer) 55 | generated_tokens = output[0, inputs['input_ids'].size(1):] 56 | generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True) 57 | return generated_text 58 | 59 | @torch.no_grad() 60 | def do_forward(model, processor, inputs): 61 | VALID_ANSWERS = ['A', 'B'] 62 | TOKEN_IDs = [processor.tokenizer(x, return_tensors="pt", add_special_tokens=False).get('input_ids') for x in VALID_ANSWERS] 63 | logits = model.forward(**inputs).logits 64 | logits = logits[0, -1, :] 65 | logits = logits.reshape(-1, 1) 66 | soft_max = torch.nn.Softmax(dim=0) 67 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs][:len(VALID_ANSWERS)])) 68 | outputs = VALID_ANSWERS[probs.argmax().item()] 69 | return outputs 70 | 71 | @torch.no_grad() 72 | def do_prefix_forward(model, problem, image, processor): 73 | # PREFIX_PROMPT_TEMPLATE = "Question: {} Answer: {}" 74 | device = model.device 75 | PREFIX_PROMPT_TEMPLATE = problem.get('format') 76 | scores = [] 77 | 78 | qs = problem["question"] 79 | 80 | for option in [problem["option_A"], problem["option_B"]]: 81 | prompt = PREFIX_PROMPT_TEMPLATE.format(qs, option, return_tensors="pt") 82 | inputs = processor.process(images=image, text=prompt) 83 | inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()} 84 | answer_tokens = processor.tokenizer.encode(' ' + option, add_special_tokens=False) 85 | num_answer_tokens = len(answer_tokens) 86 | input_ids = inputs["input_ids"] 87 | # try to find the answer tokens in input ids 88 | start_indices = [] 89 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 90 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 91 | start_indices.append(i) 92 | 93 | if len(start_indices) == 0: 94 | raise ValueError("Answer tokens not found in input_ids") 95 | answer_start = start_indices[-1] 96 | answer_start_from_back = answer_start - input_ids.size(1) 97 | with torch.inference_mode(): 98 | out = model(**inputs 99 | ) 100 | # shift by 1 compared to input 101 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 102 | probs = torch.nn.functional.softmax(logits, dim=-1) 103 | 104 | # Pick the probabilities corresponding to each of the answer tokens 105 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 106 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 107 | scores.append(prefix_score.item()) 108 | 109 | outputs = "A" if scores[0] > scores[1] else "B" 110 | return outputs -------------------------------------------------------------------------------- /Models/RadFM/vit_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | from .position_encoding import PositionEmbeddingLearned3d 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | # classes 14 | 15 | class PreNorm(nn.Module): 16 | def __init__(self, dim, 
fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.attend = nn.Softmax(dim = -1) 46 | self.dropout = nn.Dropout(dropout) 47 | 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | 50 | self.to_out = nn.Sequential( 51 | nn.Linear(inner_dim, dim), 52 | nn.Dropout(dropout) 53 | ) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 58 | 59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 60 | 61 | attn = self.attend(dots) 62 | attn = self.dropout(attn) 63 | 64 | out = torch.matmul(attn, v) 65 | out = rearrange(out, 'b h n d -> b n (h d)') 66 | return self.to_out(out) 67 | 68 | class Transformer(nn.Module): 69 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 70 | super().__init__() 71 | self.layers = nn.ModuleList([]) 72 | for _ in range(depth): 73 | self.layers.append(nn.ModuleList([ 74 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 75 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 76 | ])) 77 | def forward(self, x): 78 | for attn, ff in self.layers: 79 | x = attn(x) + x 80 | x = ff(x) + x 81 | return x 82 | 83 | class ViT(nn.Module): 84 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 85 | super().__init__() 86 | image_height, image_width = pair(image_size) 87 | patch_height, patch_width = pair(image_patch_size) 88 | 89 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
90 | assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' 91 | 92 | self.patch_height = patch_height 93 | self.patch_width = patch_width 94 | self.frame_patch_size = frame_patch_size 95 | 96 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) 97 | patch_dim = channels * patch_height * patch_width * frame_patch_size 98 | 99 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 100 | 101 | self.to_patch_embedding = nn.Sequential( 102 | Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), 103 | nn.LayerNorm(patch_dim), 104 | nn.Linear(patch_dim, dim), 105 | nn.LayerNorm(dim), 106 | ) 107 | 108 | self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size)) 109 | self.dropout = nn.Dropout(emb_dropout) 110 | 111 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 112 | 113 | def forward(self, video): 114 | B, C, H, W, D = video.shape 115 | x = self.to_patch_embedding(video) 116 | b, n, _ = x.shape 117 | 118 | pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x) 119 | x += pos 120 | x = self.dropout(x) 121 | 122 | x = self.transformer(x) 123 | return x,pos 124 | -------------------------------------------------------------------------------- /Models/llava.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 3 | from llava.conversation import conv_templates 4 | from llava.mm_utils import tokenizer_image_token 5 | from llava.constants import IMAGE_TOKEN_INDEX 6 | 7 | 8 | def load_model(model_id, device='cuda'): 9 | processor = LlavaNextProcessor.from_pretrained(model_id, device_map=device) 10 | model = LlavaNextForConditionalGeneration.from_pretrained(model_id) 11 | model.to(device) 12 | return model, processor 13 | 14 | def ask_question(model, processor, question, image, mode, temperature=0.2, top_p=None, num_beams=1, max_new_tokens=100): 15 | if mode == 'prefix': 16 | return do_prefix_forward(model, question, image, processor) 17 | 18 | conversation = [ 19 | { 20 | "role": "user", 21 | "content": [ 22 | {"type": "text", "text": question}, 23 | {"type": "image"}, 24 | ], 25 | }, 26 | ] 27 | prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) 28 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device) 29 | 30 | if mode == 'greedy': 31 | outputs = do_forward(model, inputs, processor) 32 | elif mode in ['mc', 'gpt4']: 33 | outputs = do_generation(model, inputs, processor, temperature, top_p, num_beams, max_new_tokens) 34 | return outputs 35 | 36 | 37 | @torch.no_grad() 38 | def do_forward(model, inputs, processor): 39 | VALID_ANSWERS = ['A', 'B'] 40 | TOKEN_IDs = [processor.tokenizer.encode(x, return_tensors="pt", add_special_tokens=False) for x in VALID_ANSWERS] 41 | 42 | with torch.inference_mode(): 43 | out = model.forward(**inputs) 44 | 45 | logits = out.logits[0, -1, :] 46 | soft_max = torch.nn.Softmax(dim=0) 47 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 48 | outputs = VALID_ANSWERS[probs.argmax().item()] 49 | return outputs 50 | 51 | @torch.no_grad() 52 | def do_generation(model, inputs, processor, temperature, top_p, 
num_beams, max_new_tokens): 53 | output = model.generate(**inputs, 54 | temperature=temperature, 55 | do_sample=(temperature > 0), 56 | top_p=top_p, 57 | num_beams=num_beams, 58 | max_new_tokens=max_new_tokens) 59 | tmp = processor.decode(inputs.get('input_ids')[0], skip_special_tokens=True) 60 | txt = processor.decode(output[0], skip_special_tokens=True).replace(tmp, '') 61 | return txt 62 | 63 | @torch.no_grad() 64 | def do_prefix_forward(model, problem, image, processor): 65 | # python scripts/answering.py --vlm_name llava --mode prefix 66 | device = model.device 67 | scores = [] 68 | questions = [] 69 | qs = problem["question"] 70 | for option in [problem["option_A"], problem["option_B"]]: 71 | conv_template = [ 72 | { 73 | "role": "user", 74 | "content": [ 75 | {"type": "image"}, 76 | {"type": "text", "text": f"{qs}"}, 77 | ], 78 | }, 79 | { 80 | "role": "assistant", 81 | "content": [ 82 | {"type": "text", "text": f"{option}"}, 83 | ], 84 | } 85 | ] 86 | prompt = processor.apply_chat_template(conv_template, add_generation_prompt=True) 87 | questions.append(prompt) 88 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) 89 | answer_tokens = processor.tokenizer.encode(" " + option, add_special_tokens=False)[1:] 90 | num_answer_tokens = len(answer_tokens) 91 | input_ids = inputs["input_ids"] 92 | # try to find the answer tokens in input ids 93 | start_indices = [] 94 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 95 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 96 | start_indices.append(i) 97 | 98 | if len(start_indices) == 0: 99 | raise ValueError("Answer tokens not found in input_ids") 100 | answer_start = start_indices[-1] 101 | answer_start_from_back = answer_start - input_ids.size(1) 102 | 103 | with torch.inference_mode(): 104 | out = model(**inputs) 105 | # shift by 1 compared to input 106 | 107 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 108 | probs = torch.nn.functional.softmax(logits, dim=-1) 109 | # Pick the probabilities corresponding to each of the answer tokens 110 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 111 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 112 | scores.append(prefix_score.item()) 113 | outputs = "A" if scores[0] > scores[1] else "B" 114 | return outputs -------------------------------------------------------------------------------- /Models/llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import requests 3 | from PIL import Image 4 | from huggingface_hub import login 5 | # from accelerate import Accelerator 6 | from transformers import MllamaForConditionalGeneration, AutoProcessor 7 | 8 | def load_model(token, model_id, device='cuda'): 9 | if token is not None: 10 | login(token) 11 | else: 12 | login() 13 | model = MllamaForConditionalGeneration.from_pretrained( 14 | model_id, 15 | torch_dtype=torch.bfloat16, 16 | device_map=device, 17 | ) 18 | processor = AutoProcessor.from_pretrained(model_id) 19 | # model.to(device) 20 | return model, processor 21 | 22 | 23 | def ask_question(model, processor, question, image, mode, temperature=0.2, top_p=None, num_beams=1, max_new_tokens=100): 24 | if mode == 'prefix': 25 | return do_prefix_forward(model, question, image, processor) 26 | messages = [ 27 | {"role": "user", "content": [ 28 | {"type": "image"}, 29 | {"type": "text", "text": question} 30 
| ]} 31 | ] 32 | input_text = processor.apply_chat_template(messages, add_generation_prompt=True) 33 | inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device) 34 | 35 | if mode == 'greedy': 36 | outputs = do_forward(model, inputs, processor) 37 | elif mode in ['mc', 'gpt4']: 38 | outputs = do_generation(model, inputs, processor, temperature, top_p, num_beams, max_new_tokens) 39 | return outputs 40 | 41 | 42 | @torch.no_grad() 43 | def do_generation(model, inputs, processor, temperature, top_p, num_beams, max_new_tokens): 44 | output = model.generate(**inputs, 45 | temperature=temperature, 46 | do_sample=(temperature > 0), 47 | top_p=top_p, 48 | num_beams=num_beams, 49 | max_new_tokens=max_new_tokens) 50 | # raise ValueError(inputs.get('input_ids'), output[0]) 51 | tmp = processor.decode(inputs.get('input_ids')[0], skip_special_tokens=True) 52 | txt = processor.decode(output[0], skip_special_tokens=True).replace(tmp, '') 53 | return txt 54 | 55 | def do_forward(model, inputs, processor): 56 | VALID_ANSWERS = ['A', 'B'] 57 | TOKEN_IDs = [processor.tokenizer.encode(x, return_tensors="pt", add_special_tokens=False) for x in VALID_ANSWERS] 58 | 59 | with torch.inference_mode(): 60 | out = model.forward(**inputs) 61 | 62 | logits = out.logits[0, -1, :] 63 | soft_max = torch.nn.Softmax(dim=0) 64 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 65 | outputs = VALID_ANSWERS[probs.argmax().item()] 66 | return outputs 67 | 68 | 69 | @torch.no_grad() 70 | def do_prefix_forward(model, problem, image, processor): 71 | device = model.device 72 | scores = [] 73 | questions = [] 74 | qs = problem["question"] 75 | for option in [problem["option_A"], problem["option_B"]]: 76 | conv_template = [ 77 | { 78 | "role": "user", 79 | "content": [ 80 | {"type": "image"}, 81 | {"type": "text", "text": f"Question: {qs}"}, 82 | ], 83 | }, 84 | { 85 | "role": "assistant", 86 | "content": [ 87 | {"type": "text", "text": f"{option}"}, 88 | ], 89 | } 90 | ] 91 | prompt = processor.apply_chat_template(conv_template, add_generation_prompt=True) 92 | questions.append(prompt) 93 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) 94 | answer_tokens = processor.tokenizer.encode(option, add_special_tokens=False) 95 | num_answer_tokens = len(answer_tokens) 96 | 97 | input_ids = inputs["input_ids"] 98 | # try to find the answer tokens in input ids 99 | start_indices = [] 100 | 101 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 102 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 103 | start_indices.append(i) 104 | 105 | if len(start_indices) == 0: 106 | raise ValueError("Answer tokens not found in input_ids") 107 | answer_start = start_indices[-1] 108 | answer_start_from_back = answer_start - input_ids.size(1) 109 | 110 | with torch.inference_mode(): 111 | out = model(**inputs) 112 | # shift by 1 compared to input 113 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 114 | probs = torch.nn.functional.softmax(logits, dim=-1) 115 | 116 | # Pick the probabilities corresponding to each of the answer tokens 117 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 118 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 119 | scores.append(prefix_score.item()) 120 | outputs = "A" if scores[0] > scores[1] else "B" 121 | return outputs 
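# Standalone illustrative sketch (not part of llama.py or any other repository file) of the
# prefix-scoring rule that the do_prefix_forward helpers share across blip2.py, instructblip.py,
# molmo.py, llava.py, and llama.py: each answer option is scored by the geometric mean of its
# token probabilities under the model, and the option with the higher score wins. Toy random
# logits stand in for real model output here; the name `prefix_score` is hypothetical and used
# only for illustration.
import torch

def prefix_score(option_logits: torch.Tensor, option_token_ids: list) -> float:
    # option_logits: (num_answer_tokens, vocab_size) logits aligned so that row i predicts the
    # i-th token of the option (the "shift by 1 compared to input" step in the code above).
    probs = torch.nn.functional.softmax(option_logits, dim=-1)
    ids = torch.tensor(option_token_ids)
    # Probability assigned to each of the option's tokens at its own position.
    picked = probs[torch.arange(len(option_token_ids)), ids]
    # Length-normalized product, i.e. the geometric mean of the token probabilities.
    return torch.prod(picked.pow(1.0 / len(option_token_ids))).item()

if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 32
    logits_a = torch.randn(3, vocab_size)  # pretend logits for option A (3 tokens)
    logits_b = torch.randn(2, vocab_size)  # pretend logits for option B (2 tokens)
    scores = [prefix_score(logits_a, [5, 7, 9]), prefix_score(logits_b, [11, 13])]
    print("A" if scores[0] > scores[1] else "B")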
-------------------------------------------------------------------------------- /Models/RadFM/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | from einops.layers.torch import Rearrange 9 | from einops import rearrange, repeat 10 | 11 | class PositionEmbeddingSine(nn.Module): 12 | """ 13 | This is a more standard version of the position embedding, very similar to the one 14 | used by the Attention is all you need paper, generalized to work on images. 15 | """ 16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 17 | super().__init__() 18 | self.num_pos_feats = num_pos_feats 19 | self.temperature = temperature 20 | self.normalize = normalize 21 | if scale is not None and normalize is False: 22 | raise ValueError("normalize should be True if scale is passed") 23 | if scale is None: 24 | scale = 2 * math.pi 25 | self.scale = scale 26 | 27 | def forward(self, tensor_list): 28 | x = tensor_list.tensors 29 | mask = tensor_list.mask 30 | assert mask is not None 31 | not_mask = ~mask 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | return pos 48 | 49 | 50 | class PositionEmbeddingLearned(nn.Module): 51 | """ 52 | Absolute pos embedding, learned. 53 | """ 54 | def __init__(self, num_pos_feats=256): 55 | super().__init__() 56 | self.row_embed = nn.Embedding(50, num_pos_feats) 57 | self.col_embed = nn.Embedding(50, num_pos_feats) 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | nn.init.uniform_(self.row_embed.weight) 62 | nn.init.uniform_(self.col_embed.weight) 63 | 64 | def forward(self, tensor_list): 65 | x = tensor_list.tensors 66 | h, w = x.shape[-2:] 67 | i = torch.arange(w, device=x.device) 68 | j = torch.arange(h, device=x.device) 69 | x_emb = self.col_embed(i) 70 | y_emb = self.row_embed(j) 71 | pos = torch.cat([ 72 | x_emb.unsqueeze(0).repeat(h, 1, 1), 73 | y_emb.unsqueeze(1).repeat(1, w, 1), 74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 75 | return pos 76 | 77 | class PositionEmbeddingLearned3d(nn.Module): 78 | """ 79 | Absolute pos embedding, learned. 
80 | """ 81 | def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64): 82 | super().__init__() 83 | self.h_patch_num = h_patch_num 84 | self.w_patch_num = w_patch_num 85 | self.d_patch_num = d_patch_num 86 | self.row_embed = nn.Embedding(h_patch_num, num_pos_feats) 87 | self.col_embed = nn.Embedding(w_patch_num, num_pos_feats) 88 | self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats) 89 | self.reset_parameters() 90 | 91 | def reset_parameters(self): 92 | nn.init.uniform_(self.row_embed.weight) 93 | nn.init.uniform_(self.col_embed.weight) 94 | nn.init.uniform_(self.dep_embed.weight) 95 | 96 | def forward(self, B, h, w, d,x): 97 | i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1 98 | j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1 99 | k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1 100 | x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1) 101 | y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1) 102 | z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1) 103 | pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1) 104 | pos = rearrange(pos,'b h w d c -> b (h w d) c') 105 | return pos 106 | 107 | def build_position_encoding(args): 108 | N_steps = args.hidden_dim // 2 109 | if args.position_embedding in ('v2', 'sine'): 110 | # TODO find a better way of exposing other arguments 111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 112 | elif args.position_embedding in ('v3', 'learned'): 113 | position_embedding = PositionEmbeddingLearned(N_steps) 114 | else: 115 | raise ValueError(f"not supported {args.position_embedding}") 116 | 117 | return position_embedding 118 | 119 | # Pos = PositionEmbeddingLearned3d() 120 | # x = torch.randn((8,3,32,32,1)) 121 | # print(Pos(8,16,16,1,x)) -------------------------------------------------------------------------------- /configs/prompts/answering.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompts": { 3 | "gpt4": { 4 | "default": "Based on the image, answer following question.\n{}", 5 | "blip2": "Question: {} Answer:", 6 | "med_flamingo": "You are a helpful medical assistant. You are being provided with images, a two choice question about each image and an answer. Follow the examples and answer the last question. Question: **Q1**<|endofchunk|>Question: **Q2**|endofchunk|>Question: **Q3**<|endofchunk|>Question: {} Answer:" 7 | }, 8 | "greedy": { 9 | "default": "Based on the image, choose the correct option for the following question.\nQuestion: {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly.\nAnswer: ", 10 | "molmo": "Based on the image, choose the correct option for the following question.\nQuestion: {}\n(\"A\": {})\n(\"B\": {})\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\n#answer: (\"", 11 | "med_flamingo": "You are a helpful medical assistant. You are being provided with images, a two choice question about each image and an answer. Follow the examples and answer the last question. Question: **Q1**<|endofchunk|>Question: **Q2**|endofchunk|>Question: **Q3**<|endofchunk|>Question: {}\nA: {}\nB: {}\nAnswer:", 12 | "blip2": "Question: Based on the image, choose the correct option for the following question. {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. 
Answer:" 13 | }, 14 | "mc": { 15 | "default": "Based on the image, choose the correct option for the following question.\nQuestion: {}\n(\"A\": {})\n(\"B\": {})\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\n#answer: (\"", 16 | "llava_med": "Based on the image, choose the correct option for the following question.\nQuestion: {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\nAnswer: ", 17 | "gpt": "Based on the image, choose the correct option for the following question.\nQuestion: {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\nAnswer: ", 18 | "deepseek": "Based on the image, choose the correct option for the following question.\nQuestion: {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\nAnswer: ", 19 | "gemini": "Based on the image, choose the correct option for the following question.\nQuestion: {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\nAnswer: ", 20 | "claude": "Based on the image, choose the correct option for the following question.\nQuestion: {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. Your answer should be just one letter.\nAnswer: ", 21 | "medvint": "Question: {} The choices are: A: {} B: {} The Answer is: ", 22 | "blip2": "Question: Based on the image, choose the correct option for the following question. {}\nA: {}\nB: {}\nAnswer with the option's letter from the given choices directly. Answer:", 23 | "med_flamingo": "You are a helpful medical assistant. You are being provided with images, a two choice question about each image and an answer. Follow the examples and answer the last question. Question: **Q1**<|endofchunk|>Question: **Q2**|endofchunk|>Question: **Q3**<|endofchunk|>Question: {}\nA: {}\nB: {}\nAnswer:" 24 | }, 25 | "prefix": { 26 | "default": "Question: {} Answer: {}", 27 | "radfm": "{} {}", 28 | "med_flamingo": "You are a helpful medical assistant. You are being provided with images, a question about each image and an answer. Follow the examples and answer the last question. Question: **Q1**<|endofchunk|>Question: **Q2**<|endofchunk|>Question: **Q3**<|endofchunk|>Question: {} Answer: {}" 29 | } 30 | }, 31 | 32 | "init_prompts": { 33 | "default": "You are a helpful assitant expert in medical domain.", 34 | "llava_med": { 35 | "1": "You are a helpful and precise assistant for checking the quality of the answer." 36 | } 37 | }, 38 | 39 | "conversion": { 40 | "instruct_prompt": "We would like to request your feedback on the performance of an AI assistant in response to the user question displayed above. The user asks the question on observing an image. We have provided two possible answers, [Answer A] and [Answer B] to the question. Your job is to evaluate how close the AI assistant's answer is to each of the answers. You don't have to decide whether the answers are correct or not. Each answer should receive an overall score on a scale of 1 to 10, where a higher score indicates the AI assistant's answer is closer to the specific answer. After providing the scores, concisely provide your explanation for the given scores. Remember, you don't need to comment on the correctness of the answers. 
Please provide your answer in the following format:\nA: \nB: \nYour explanation: ", 41 | "role": "AI Assistant", 42 | "init_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", 43 | "gpt_deployment_name": "gpt4o", 44 | "temperature": 0.7 45 | } 46 | } -------------------------------------------------------------------------------- /Models/llava_med.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from utils import io_tools 4 | from llava.utils import disable_torch_init 5 | from llava.conversation import conv_templates 6 | from llava.model.builder import load_pretrained_model 7 | from llava.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from torch.nn import CrossEntropyLoss 10 | from torch.nn.functional import softmax 11 | 12 | def load_model(model_path, model_base): 13 | disable_torch_init() 14 | model_path = os.path.expanduser(model_path) 15 | model_name = get_model_name_from_path(model_path) 16 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name) 17 | return tokenizer, model, image_processor, context_len 18 | 19 | def get_input_id(tokenizer, question, conv_mode): 20 | # qs = convert_question(question, mm_use_im_start_end, use_options) 21 | conv = conv_templates[conv_mode].copy() 22 | conv.append_message(conv.roles[0], question) 23 | conv.append_message(conv.roles[1], None) 24 | prompt = conv.get_prompt() 25 | 26 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 27 | return input_ids 28 | 29 | @torch.no_grad() 30 | def do_forward(model, input_ids, image_tensor, image_size, tokenizer): 31 | VALID_ANSWERS = ['A', 'B'] 32 | TOKEN_IDs = [tokenizer.encode(x, return_tensors="pt", add_special_tokens=False) for x in VALID_ANSWERS] 33 | 34 | with torch.inference_mode(): 35 | out = model(input_ids, 36 | images=image_tensor.unsqueeze(0).half().cuda(), 37 | image_sizes=[image_size], 38 | ) 39 | 40 | logits = out.logits[0, -1, :] 41 | soft_max = torch.nn.Softmax(dim=0) 42 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 43 | outputs = VALID_ANSWERS[probs.argmax().item()] 44 | return outputs 45 | 46 | @torch.no_grad() 47 | def do_generation(model, input_ids, image_tensor, tokenizer, temperature, top_p, num_beams, max_new_tokens): 48 | with torch.inference_mode(): 49 | output_ids = model.generate(input_ids, 50 | images=image_tensor.unsqueeze(0).half().cuda(), 51 | do_sample=True if temperature > 0 else False, 52 | temperature=temperature, 53 | top_p=top_p, 54 | num_beams=num_beams, 55 | max_new_tokens=max_new_tokens, 56 | use_cache=True) 57 | 58 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 59 | return outputs 60 | 61 | def ask_question(model, question, image, image_processor, tokenizer, mode, conv_mode, temperature=0.2, top_p=None, num_beams=1, max_new_tokens=100): 62 | image_tensor = process_images([image], image_processor, model.config)[0] 63 | 64 | if mode == 'greedy': 65 | input_ids = get_input_id(tokenizer, question, conv_mode) 66 | outputs = do_forward(model, input_ids, image_tensor, image.size, tokenizer) 67 | elif mode in ['mc', 'gpt4']: 68 | input_ids = get_input_id(tokenizer, question, conv_mode) 69 | outputs = do_generation(model, input_ids, image_tensor, 
tokenizer, temperature, top_p, num_beams, max_new_tokens) 70 | elif mode == 'prefix': 71 | outputs = do_prefix_forward(model, question, image_tensor, image.size, tokenizer, conv_mode) 72 | return outputs 73 | 74 | @torch.no_grad() 75 | def do_prefix_forward(model, problem, image_tensor, image_size, tokenizer, conv_mode): 76 | scores = [] 77 | questions = [] 78 | qs = problem["question"] 79 | for option in [problem["option_A"], problem["option_B"]]: 80 | conv = conv_templates[conv_mode].copy() 81 | conv.append_message(conv.roles[0], qs) 82 | conv.append_message(conv.roles[1], option) 83 | prompt = conv.get_prompt() 84 | questions.append(prompt) 85 | answer_tokens = tokenizer.encode(" " + option, add_special_tokens=False)[1:] 86 | num_answer_tokens = len(answer_tokens) 87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 88 | # try to find the answer tokens in input ids 89 | start_indices = [] 90 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 91 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).cuda()): 92 | start_indices.append(i) 93 | 94 | if len(start_indices) == 0: 95 | raise ValueError("Answer tokens not found in input_ids") 96 | answer_start = start_indices[-1] 97 | answer_start_from_back = answer_start - input_ids.size(1) 98 | 99 | with torch.inference_mode(): 100 | out = model( 101 | input_ids, 102 | images=image_tensor.unsqueeze(0).half().cuda(), 103 | use_cache=True 104 | ) 105 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 106 | probs = torch.nn.functional.softmax(logits, dim=-1) 107 | 108 | # Pick the probabilities corresponding to each of the answer tokens 109 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).cuda().unsqueeze(0)) 110 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 111 | scores.append(prefix_score.item()) 112 | outputs = "A" if scores[0] > scores[1] else "B" 113 | return outputs -------------------------------------------------------------------------------- /Models/med_flamingo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from einops import repeat 4 | from .Med_Flamingo.src.utils import FlamingoProcessor 5 | from open_flamingo import create_model_and_transforms 6 | 7 | FEW_SHOT_IMAGES = [ 8 | 'PMC1064097_F2.jpg', 9 | 'PMC1065025_F1.jpg', 10 | 'PMC1087855_F3.jpg', 11 | ] 12 | FEW_SHOT_QUESTIONS = [ 13 | 'What radiological technique was used to confirm the diagnosis?', 14 | 'What did the CT scan show?', 15 | 'What is the purpose of the asterisk shown in the figure?', 16 | ] 17 | FEW_SHOW_ANSWERS = [ 18 | [1, 'Mammography'], 19 | [0, 'Cerebral edema'], 20 | [1, 'To indicate the normal lentoid shape of hypocotyl nuclei.'] 21 | ] 22 | FEW_SHOT_OPTIONS = [ 23 | ['A: CT Scan', 'B: Mammography'], 24 | ['A: Cerebral edema', 'B: Intracranial hemorrhage'], 25 | ['A: To indicate the formation of lobes around the contracting nucleus.', 'B: To indicate the normal lentoid shape of hypocotyl nuclei.'] 26 | ] 27 | 28 | def load_model(LLaMa_PATH, CHECKPOINT_PATH, device='cuda'): 29 | model, image_processor, tokenizer = create_model_and_transforms( 30 | clip_vision_encoder_path="ViT-L-14", 31 | clip_vision_encoder_pretrained="openai", 32 | lang_encoder_path=LLaMa_PATH, 33 | tokenizer_path=LLaMa_PATH, 34 | cross_attn_every_n_layers=4 35 | ) 36 | model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'), 
strict=False) 37 | model.to(device=device) 38 | model.eval() 39 | processor = FlamingoProcessor(tokenizer, image_processor) 40 | return model, processor 41 | 42 | def get_few_shot_sample(num, use_option): 43 | question = FEW_SHOT_QUESTIONS[num] 44 | answer = FEW_SHOW_ANSWERS[num] 45 | options = FEW_SHOT_OPTIONS[num] 46 | if use_option: 47 | return f'{question}\n{options[0]}\n{options[1]}\nAnswer: {options[answer[0]]}' 48 | return f'{question} Answer: {answer[1]}' 49 | 50 | def process_prompt(prompt, use_option): 51 | for q in range(len(FEW_SHOT_QUESTIONS)): 52 | prompt = prompt.replace(f'**Q{q+1}**', get_few_shot_sample(q, use_option)) 53 | return prompt 54 | 55 | def ask_question(model, processor, image_path, question, max_new_tokens, mode, IMAGE_DIR): 56 | tmp = [(f'{IMAGE_DIR}/{IM}') for IM in FEW_SHOT_IMAGES] 57 | tmp.append(image_path) 58 | images = [Image.open(image_path) for image_path in tmp] 59 | pixels = processor.preprocess_images(images) 60 | pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1) 61 | 62 | if mode == 'prefix': 63 | return do_prefix_forward(model, question, pixels, processor) 64 | 65 | question = process_prompt(question, use_option=(mode in ['mc', 'greedy'])) 66 | tokenized_data = processor.encode_text(question) 67 | if mode == 'greedy': 68 | return do_forward(model, processor, pixels, tokenized_data) 69 | elif mode in ['mc', 'gpt4']: 70 | return do_generation(model, processor, pixels, tokenized_data, max_new_tokens) 71 | 72 | 73 | @torch.no_grad() 74 | def do_forward(model, processor, pixels, tokenized_data): 75 | device = model.lang_encoder.device 76 | VALID_ANSWERS = ['A', 'B'] 77 | TOKEN_IDs = [processor.tokenizer(x, return_tensors="pt", add_special_tokens=False).get('input_ids') for x in VALID_ANSWERS] 78 | outputs = model.forward(vision_x=pixels.to(device), 79 | lang_x=tokenized_data["input_ids"].to(device), 80 | attention_mask=tokenized_data["attention_mask"].to(device)) 81 | logits = outputs.logits[0, -1, :].reshape(-1, 1) 82 | soft_max = torch.nn.Softmax(dim=0) 83 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 84 | outputs = VALID_ANSWERS[probs.argmax().item()] 85 | return outputs 86 | 87 | @torch.no_grad() 88 | def do_generation(model, processor, pixels, tokenized_data, max_new_tokens): 89 | device = model.lang_encoder.device 90 | generated_text = model.generate( 91 | vision_x=pixels.to(device), 92 | lang_x=tokenized_data["input_ids"].to(device), 93 | attention_mask=tokenized_data["attention_mask"].to(device), 94 | max_new_tokens=max_new_tokens, 95 | ) 96 | response = processor.tokenizer.decode(generated_text[0]).replace(' ', '').strip() 97 | tmp = processor.tokenizer.decode(tokenized_data.get('input_ids')[0]) 98 | response = response.replace(f'{tmp} ', '') 99 | while response[0] == ' ': 100 | response = response[1: ] 101 | return response 102 | 103 | @torch.no_grad() 104 | def do_prefix_forward(model, problem, pixels, processor): 105 | PREFIX_PROMPT_TEMPLATE = process_prompt(problem.get('format'), use_option=False) 106 | scores = [] 107 | questions = [] 108 | qs = problem["question"] 109 | device = model.lang_encoder.device 110 | for option in [problem["option_A"], problem["option_B"]]: 111 | prompt = PREFIX_PROMPT_TEMPLATE.format(qs, option) 112 | prompt = process_prompt(prompt, use_option=False) 113 | tokenized_data = processor.encode_text(prompt) 114 | questions.append(prompt) 115 | answer_tokens = processor.tokenizer.encode(" " + option, add_special_tokens=False)[1:] 116 | num_answer_tokens = len(answer_tokens) 117 | input_ids 
= tokenized_data['input_ids'] 118 | # try to find the answer tokens in input ids 119 | start_indices = [] 120 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 121 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens)): 122 | start_indices.append(i) 123 | 124 | if len(start_indices) == 0: 125 | raise ValueError("Answer tokens not found in input_ids") 126 | answer_start = start_indices[-1] 127 | answer_start_from_back = answer_start - input_ids.size(1) 128 | with torch.inference_mode(): 129 | outputs = model.forward(vision_x=pixels.to(device), 130 | lang_x=tokenized_data["input_ids"].to(device), 131 | attention_mask=tokenized_data["attention_mask"].to(device)) 132 | # shift by 1 compared to input 133 | logits = outputs.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 134 | probs = torch.nn.functional.softmax(logits, dim=-1) 135 | 136 | # Pick the probabilities corresponding to each of the answer tokens 137 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 138 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 139 | scores.append(prefix_score.item()) 140 | outputs = "A" if scores[0] > scores[1] else "B" 141 | return outputs -------------------------------------------------------------------------------- /Models/RadFM/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code modified from DETR tranformer: 3 | https://github.com/facebookresearch/detr 4 | Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 5 | """ 6 | 7 | import copy 8 | from typing import Optional, List 9 | import pickle as cp 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | 15 | 16 | class TransformerDecoder(nn.Module): 17 | 18 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 19 | super().__init__() 20 | self.layers = _get_clones(decoder_layer, num_layers) 21 | self.num_layers = num_layers 22 | self.norm = norm 23 | self.return_intermediate = return_intermediate 24 | 25 | def forward(self, tgt, memory, 26 | tgt_mask: Optional[Tensor] = None, 27 | memory_mask: Optional[Tensor] = None, 28 | tgt_key_padding_mask: Optional[Tensor] = None, 29 | memory_key_padding_mask: Optional[Tensor] = None, 30 | pos: Optional[Tensor] = None, 31 | query_pos: Optional[Tensor] = None): 32 | output = tgt 33 | T,B,C = memory.shape 34 | intermediate = [] 35 | atten_layers = [] 36 | for n,layer in enumerate(self.layers): 37 | 38 | residual=True 39 | output,ws = layer(output, memory, tgt_mask=tgt_mask, 40 | memory_mask=memory_mask, 41 | tgt_key_padding_mask=tgt_key_padding_mask, 42 | memory_key_padding_mask=memory_key_padding_mask, 43 | pos=pos, query_pos=query_pos,residual=residual) 44 | atten_layers.append(ws) 45 | if self.return_intermediate: 46 | intermediate.append(self.norm(output)) 47 | if self.norm is not None: 48 | output = self.norm(output) 49 | if self.return_intermediate: 50 | intermediate.pop() 51 | intermediate.append(output) 52 | 53 | if self.return_intermediate: 54 | return torch.stack(intermediate) 55 | return output,atten_layers 56 | 57 | 58 | 59 | class TransformerDecoderLayer(nn.Module): 60 | 61 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 62 | activation="relu", normalize_before=False): 63 | super().__init__() 64 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 65 | self.multihead_attn = 
nn.MultiheadAttention(d_model, nhead, dropout=dropout) 66 | # Implementation of Feedforward model 67 | self.linear1 = nn.Linear(d_model, dim_feedforward) 68 | self.dropout = nn.Dropout(dropout) 69 | self.linear2 = nn.Linear(dim_feedforward, d_model) 70 | 71 | self.norm1 = nn.LayerNorm(d_model) 72 | self.norm2 = nn.LayerNorm(d_model) 73 | self.norm3 = nn.LayerNorm(d_model) 74 | self.dropout1 = nn.Dropout(dropout) 75 | self.dropout2 = nn.Dropout(dropout) 76 | self.dropout3 = nn.Dropout(dropout) 77 | 78 | self.activation = _get_activation_fn(activation) 79 | self.normalize_before = normalize_before 80 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 81 | return tensor if pos is None else tensor + pos 82 | 83 | def forward_post(self, tgt, memory, 84 | tgt_mask: Optional[Tensor] = None, 85 | memory_mask: Optional[Tensor] = None, 86 | tgt_key_padding_mask: Optional[Tensor] = None, 87 | memory_key_padding_mask: Optional[Tensor] = None, 88 | pos: Optional[Tensor] = None, 89 | query_pos: Optional[Tensor] = None, 90 | residual=True): 91 | q = k = self.with_pos_embed(tgt, query_pos) 92 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 93 | key_padding_mask=tgt_key_padding_mask) 94 | tgt = self.norm1(tgt) 95 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 96 | key=self.with_pos_embed(memory, pos), 97 | value=memory, attn_mask=memory_mask, 98 | key_padding_mask=memory_key_padding_mask) 99 | 100 | 101 | # attn_weights [B,NUM_Q,T] 102 | tgt = tgt + self.dropout2(tgt2) 103 | tgt = self.norm2(tgt) 104 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 105 | tgt = tgt + self.dropout3(tgt2) 106 | tgt = self.norm3(tgt) 107 | return tgt,ws 108 | 109 | def forward_pre(self, tgt, memory, 110 | tgt_mask: Optional[Tensor] = None, 111 | memory_mask: Optional[Tensor] = None, 112 | tgt_key_padding_mask: Optional[Tensor] = None, 113 | memory_key_padding_mask: Optional[Tensor] = None, 114 | pos: Optional[Tensor] = None, 115 | query_pos: Optional[Tensor] = None): 116 | tgt2 = self.norm1(tgt) 117 | q = k = self.with_pos_embed(tgt2, query_pos) 118 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 119 | key_padding_mask=tgt_key_padding_mask) 120 | tgt = tgt + self.dropout1(tgt2) 121 | tgt2 = self.norm2(tgt) 122 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 123 | key=self.with_pos_embed(memory, pos), 124 | value=memory, attn_mask=memory_mask, 125 | key_padding_mask=memory_key_padding_mask) 126 | tgt = tgt + self.dropout2(tgt2) 127 | tgt2 = self.norm3(tgt) 128 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 129 | tgt = tgt + self.dropout3(tgt2) 130 | return tgt,attn_weights 131 | 132 | def forward(self, tgt, memory, 133 | tgt_mask: Optional[Tensor] = None, 134 | memory_mask: Optional[Tensor] = None, 135 | tgt_key_padding_mask: Optional[Tensor] = None, 136 | memory_key_padding_mask: Optional[Tensor] = None, 137 | pos: Optional[Tensor] = None, 138 | query_pos: Optional[Tensor] = None, 139 | residual=True): 140 | if self.normalize_before: 141 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 142 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 143 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 144 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) 145 | 146 | 147 | def _get_clones(module, N): 148 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 149 | 150 | 151 | 152 | def 
_get_activation_fn(activation): 153 | """Return an activation function given a string""" 154 | if activation == "relu": 155 | return F.relu 156 | if activation == "gelu": 157 | return F.gelu 158 | if activation == "glu": 159 | return F.glu 160 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 161 | -------------------------------------------------------------------------------- /Models/deepseek.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM 3 | 4 | def split_model(model_name): 5 | device_map = {} 6 | model_splits = { 7 | 'deepseek-ai/deepseek-vl2-small': [13, 14], # 2 GPU for 16b 8 | 'deepseek-ai/deepseek-vl2': [10, 10, 10], # 3 GPU for 27b 9 | } 10 | num_layers_per_gpu = model_splits[model_name] 11 | num_layers = sum(num_layers_per_gpu) 12 | layer_cnt = 0 13 | for i, num_layer in enumerate(num_layers_per_gpu): 14 | for j in range(num_layer): 15 | device_map[f'language.model.layers.{layer_cnt}'] = i 16 | layer_cnt += 1 17 | device_map['vision'] = 0 18 | device_map['projector'] = 0 19 | device_map['image_newline'] = 0 20 | device_map['view_seperator'] = 0 21 | device_map['language.model.embed_tokens'] = 0 22 | device_map['language.model.norm'] = 0 23 | device_map['language.lm_head'] = 0 24 | device_map[f'language.model.layers.{num_layers - 1}'] = 0 25 | return device_map 26 | 27 | def load_model(model_path): 28 | if 'janus' in model_path: 29 | from janus.models import MultiModalityCausalLM, VLChatProcessor 30 | global load_pil_images 31 | from janus.utils.io import load_pil_images 32 | vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) 33 | vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 34 | vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() 35 | janus = True 36 | 37 | elif 'deepseek-vl2' in model_path: 38 | from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM 39 | global load_pil_images 40 | from deepseek_vl2.utils.io import load_pil_images 41 | device_map = split_model(model_path) 42 | vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path) 43 | vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map=device_map) 44 | vl_gpt = vl_gpt.to(torch.bfloat16).eval() 45 | janus = False 46 | 47 | tokenizer = vl_chat_processor.tokenizer 48 | 49 | # vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() 50 | return vl_gpt, vl_chat_processor, tokenizer, janus 51 | 52 | def ask_question(model, processor, tokenizer, image_path, question, temperature, mode, max_new_tokens=512, do_sample=False, janus=False): 53 | if janus: 54 | gen_model = model.language_model 55 | content_tmp = "\n{}" 56 | else: 57 | gen_model = model.language 58 | content_tmp = "\n{}" 59 | if mode == 'prefix': 60 | return do_prefix_forward(model, gen_model, question, image_path, processor, tokenizer, content_tmp) 61 | conversation = [ 62 | { 63 | "role": "<|User|>", 64 | "content": content_tmp.format(question), 65 | "images": [image_path], 66 | }, 67 | {"role": "<|Assistant|>", "content": ""}, 68 | ] 69 | 70 | pil_images = load_pil_images(conversation) 71 | prepare_inputs = processor( 72 | conversations=conversation, images=pil_images, force_batchify=True 73 | ).to(model.device) 74 | 75 | inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs) 76 | 77 | if mode == 'greedy': 78 | return 
do_forward(gen_model, inputs_embeds, prepare_inputs.attention_mask, tokenizer) 79 | elif mode in ['mc', 'gpt4']: 80 | return do_generation(gen_model, inputs_embeds, prepare_inputs, tokenizer, max_new_tokens, do_sample, temperature) 81 | 82 | 83 | def do_generation(model, inputs_embeds, prepare_inputs, tokenizer, max_new_tokens, do_sample, temperature): 84 | if temperature > 0: 85 | do_sample = True 86 | outputs = model.generate( 87 | inputs_embeds=inputs_embeds, 88 | attention_mask=prepare_inputs.attention_mask, 89 | pad_token_id=tokenizer.eos_token_id, 90 | bos_token_id=tokenizer.bos_token_id, 91 | eos_token_id=tokenizer.eos_token_id, 92 | max_new_tokens=max_new_tokens, 93 | do_sample=do_sample, 94 | use_cache=True, 95 | temperature=temperature, 96 | ) 97 | response = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True) 98 | return response 99 | 100 | 101 | def do_forward(model, inputs_embeds, attention_mask, tokenizer): 102 | VALID_ANSWERS = ['A', 'B'] 103 | TOKEN_IDs = [tokenizer.encode(x, return_tensors="pt", add_special_tokens=False) for x in VALID_ANSWERS] 104 | 105 | with torch.inference_mode(): 106 | out = model.forward(inputs_embeds=inputs_embeds, 107 | attention_mask=attention_mask,) 108 | 109 | logits = out.logits[0, -1, :] 110 | soft_max = torch.nn.Softmax(dim=0) 111 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 112 | outputs = VALID_ANSWERS[probs.argmax().item()] 113 | return outputs 114 | 115 | 116 | @torch.no_grad() 117 | def do_prefix_forward(model, gen_model, problem, image, processor, tokenizer, content_tmp): 118 | # PREFIX_PROMPT_TEMPLATE = "Question: {} Answer: {}" 119 | device = model.device 120 | PREFIX_PROMPT_TEMPLATE = problem.get('format') 121 | scores = [] 122 | 123 | qs = problem["question"] 124 | 125 | for option in [problem["option_A"], problem["option_B"]]: 126 | # prompt = PREFIX_PROMPT_TEMPLATE.format(qs, option) 127 | 128 | conversation = [ 129 | { 130 | "role": "<|User|>", 131 | "content": content_tmp.format(qs), 132 | "images": [image], 133 | }, 134 | {"role": "<|Assistant|>", "content": f"{option}"}, 135 | ] 136 | pil_images = load_pil_images(conversation) 137 | inputs = processor(conversations=conversation, images=pil_images, force_batchify=True).to(model.device) 138 | inputs_embeds = model.prepare_inputs_embeds(**inputs) 139 | 140 | answer_tokens = tokenizer.encode(option, add_special_tokens=False) 141 | num_answer_tokens = len(answer_tokens) 142 | input_ids = inputs["input_ids"] 143 | # try to find the answer tokens in input ids 144 | start_indices = [] 145 | for i in range(input_ids.size(1) - num_answer_tokens + 1): 146 | if torch.equal(input_ids[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 147 | start_indices.append(i) 148 | 149 | if len(start_indices) == 0: 150 | raise ValueError("Answer tokens not found in input_ids") 151 | answer_start = start_indices[-1] 152 | answer_start_from_back = answer_start - input_ids.size(1) 153 | with torch.inference_mode(): 154 | # out = model(**inputs) 155 | out = gen_model.forward(inputs_embeds=inputs_embeds, attention_mask=inputs.attention_mask) 156 | # shift by 1 compared to input 157 | logits = out.logits[0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 158 | probs = torch.nn.functional.softmax(logits, dim=-1) 159 | 160 | # Pick the probabilities corresponding to each of the answer tokens 161 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 162 | prefix_score = 
torch.prod(probs.pow(1/num_answer_tokens)) 163 | scores.append(prefix_score.item()) 164 | 165 | outputs = "A" if scores[0] > scores[1] else "B" 166 | return outputs -------------------------------------------------------------------------------- /Models/radfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from utils import io_tools 4 | from torchvision import transforms 5 | from transformers import LlamaTokenizer 6 | from .RadFM.multimodality_model import MultiLLaMAForCausalLM 7 | 8 | ROOT = io_tools.get_root(__file__, 2) 9 | 10 | 11 | def get_tokenizer(tokenizer_path, max_img_size=100, image_num=32): 12 | ''' 13 | Initialize the image special tokens 14 | max_img_size denotes the max image put length and image_num denotes how many patch embeddings the image will be encoded to 15 | ''' 16 | if isinstance(tokenizer_path, str): 17 | image_padding_tokens = [] 18 | text_tokenizer = LlamaTokenizer.from_pretrained( 19 | tokenizer_path, 20 | ) 21 | special_token = {"additional_special_tokens": ["",""]} 22 | for i in range(max_img_size): 23 | image_padding_token = "" 24 | 25 | for j in range(image_num): 26 | image_token = "" 27 | image_padding_token = image_padding_token + image_token 28 | special_token["additional_special_tokens"].append("") 29 | image_padding_tokens.append(image_padding_token) 30 | text_tokenizer.add_special_tokens( 31 | special_token 32 | ) 33 | ## make sure the bos eos pad tokens are correct for LLaMA-like models 34 | text_tokenizer.pad_token_id = 0 35 | text_tokenizer.bos_token_id = 1 36 | text_tokenizer.eos_token_id = 2 37 | 38 | return text_tokenizer,image_padding_tokens 39 | 40 | def combine_and_preprocess(question,image_list,image_padding_tokens): 41 | 42 | transform = transforms.Compose([ 43 | transforms.RandomResizedCrop([512,512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 44 | transforms.ToTensor(), 45 | ]) 46 | images = [] 47 | new_qestions = [_ for _ in question] 48 | padding_index = 0 49 | for img in image_list: 50 | img_path = img['img_path'] 51 | position = img['position'] 52 | 53 | image = Image.open(img_path).convert('RGB') 54 | image = transform(image) 55 | image = image.unsqueeze(0).unsqueeze(-1) # c,w,h,d 56 | 57 | ## pre-process the img first 58 | target_H = 512 59 | target_W = 512 60 | target_D = 4 61 | # This can be different for 3D and 2D images. For demonstration we here set this as the default sizes for 2D images. 
62 | images.append(torch.nn.functional.interpolate(image, size=(target_H,target_W,target_D))) 63 | 64 | ## add img placeholder to text 65 | new_qestions[position] = "" + image_padding_tokens[padding_index] + "" + new_qestions[position] 66 | padding_index += 1 67 | 68 | vision_x = torch.cat(images,dim = 1).unsqueeze(0) #cat tensors and expand the batch_size dim 69 | text = ''.join(new_qestions) 70 | return text, vision_x, 71 | 72 | def load_model(model_path, device='cuda'): 73 | language_files_path = f'{ROOT}/Models/RadFM/Language_files' 74 | text_tokenizer, image_padding_tokens = get_tokenizer(language_files_path) 75 | model = MultiLLaMAForCausalLM(lang_model_path=language_files_path) 76 | ckpt = torch.load(model_path, map_location='cpu') 77 | model.load_state_dict(ckpt) 78 | model = model.to(device) 79 | model.eval() 80 | return model, text_tokenizer, image_padding_tokens 81 | 82 | def ask_question(model, question, image_path, text_tokenizer, image_padding_tokens, mode, device): 83 | image =[ 84 | { 85 | 'img_path': image_path, 86 | 'position': 0, #indicate where to put the images in the text string, range from [0,len(question)-1] 87 | }, # can add abitrary number of imgs 88 | ] 89 | if mode == 'prefix': 90 | return do_prefix_forward(model, question, text_tokenizer, image_padding_tokens, image, device) 91 | 92 | text, vision_x = combine_and_preprocess(question, image, image_padding_tokens) 93 | with torch.no_grad(): 94 | lang_x = text_tokenizer(text, max_length=2048, truncation=True, return_tensors="pt")['input_ids'].to(device) 95 | vision_x = vision_x.to(device) 96 | if mode == 'greedy': 97 | return do_forward(model, text_tokenizer, lang_x, vision_x) 98 | elif mode in ['mc', 'gpt4']: 99 | return do_generation(model, text_tokenizer, lang_x, vision_x) 100 | 101 | @torch.no_grad() 102 | def do_generation(model, text_tokenizer, lang_x, vision_x): 103 | generation = model.generate(lang_x, vision_x) 104 | generated_texts = text_tokenizer.batch_decode(generation, skip_special_tokens=True) 105 | return generated_texts[0] 106 | 107 | @torch.no_grad() 108 | def do_forward(model, text_tokenizer, lang_x, vision_x): 109 | VALID_ANSWERS = ['A', 'B'] 110 | TOKEN_IDs = [text_tokenizer.encode(x, return_tensors="pt", add_special_tokens=False) for x in VALID_ANSWERS] 111 | input_embedding, _= model.embedding_layer(lang_x, vision_x, key_words_query=None) 112 | out = model.lang_model(inputs_embeds=input_embedding, attention_mask=None, labels=None) 113 | logits = out['logits'][0, -1, :] 114 | soft_max = torch.nn.Softmax(dim=0) 115 | probs = soft_max(torch.cat([logits[x] for x in TOKEN_IDs])) 116 | outputs = VALID_ANSWERS[probs.argmax().item()] 117 | return outputs 118 | 119 | @torch.no_grad() 120 | def do_prefix_forward(model, problem, text_tokenizer, image_padding_tokens, image, device): 121 | # PREFIX_PROMPT_TEMPLATE = "{} {}" 122 | PREFIX_PROMPT_TEMPLATE = problem.get('format') 123 | scores = [] 124 | questions = [] 125 | qs = problem["question"] 126 | 127 | for option in [problem["option_A"], problem["option_B"]]: 128 | prompt = PREFIX_PROMPT_TEMPLATE.format(qs, option) 129 | questions.append(prompt) 130 | text, vision_x = combine_and_preprocess(prompt, image, image_padding_tokens) 131 | with torch.no_grad(): 132 | lang_x = text_tokenizer(text, max_length=2048, truncation=True, return_tensors="pt")['input_ids'].to(device) 133 | vision_x = vision_x.to(device=device) 134 | answer_tokens = text_tokenizer.encode(" " + option, add_special_tokens=False)[1:] 135 | num_answer_tokens = len(answer_tokens) 136 | 137 
| # try to find the answer tokens in input ids 138 | start_indices = [] 139 | for i in range(lang_x.size(1) - num_answer_tokens + 1): 140 | if torch.equal(lang_x[0, i:i+num_answer_tokens], torch.tensor(answer_tokens).to(device=device)): 141 | start_indices.append(i) 142 | 143 | if len(start_indices) == 0: 144 | raise ValueError("Answer tokens not found in input_ids") 145 | answer_start = start_indices[-1] 146 | answer_start_from_back = answer_start - lang_x.size(1) 147 | with torch.inference_mode(): 148 | input_embedding, _= model.embedding_layer(lang_x, vision_x, key_words_query=None) 149 | output = model.lang_model(inputs_embeds=input_embedding, attention_mask=None, labels=None) 150 | # shift by 1 compared to input 151 | logits = output['logits'][0, answer_start_from_back-1:answer_start_from_back-1+num_answer_tokens] 152 | probs = torch.nn.functional.softmax(logits, dim=-1) 153 | 154 | # Pick the probabilities corresponding to each of the answer tokens 155 | probs = torch.gather(probs, 1, torch.tensor(answer_tokens).to(device=device).unsqueeze(0)) 156 | prefix_score = torch.prod(probs.pow(1/num_answer_tokens)) 157 | scores.append(prefix_score.item()) 158 | outputs = "A" if scores[0] > scores[1] else "B" 159 | return outputs -------------------------------------------------------------------------------- /Models/RadFM/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | 25 | class PerceiverAttention(nn.Module): 26 | def __init__(self, *, dim, dim_head=64, heads=8): 27 | super().__init__() 28 | self.scale = dim_head**-0.5 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm_media = nn.LayerNorm(dim) 33 | self.norm_latents = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, T, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, T, n2, D) 46 | """ 47 | x = self.norm_media(x) 48 | latents = self.norm_latents(latents) 49 | 50 | h = self.heads 51 | 52 | q = self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 56 | q = q * self.scale 57 | 58 | # attention 59 | sim = einsum("... i d, ... j d -> ... i j", q, k) 60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 61 | attn = sim.softmax(dim=-1) 62 | 63 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 65 | return self.to_out(out) 66 | 67 | 68 | class PerceiverResampler(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | dim, 73 | depth=6, 74 | dim_head=64, 75 | heads=8, 76 | num_latents=64, 77 | max_num_media=None, 78 | max_num_frames=None, 79 | ff_mult=4, 80 | ): 81 | super().__init__() 82 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 83 | self.frame_embs = ( 84 | nn.Parameter(torch.randn(max_num_frames, dim)) 85 | if exists(max_num_frames) 86 | else None 87 | ) 88 | self.media_time_embs = ( 89 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 90 | if exists(max_num_media) 91 | else None 92 | ) 93 | 94 | self.layers = nn.ModuleList([]) 95 | for _ in range(depth): 96 | self.layers.append( 97 | nn.ModuleList( 98 | [ 99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 100 | FeedForward(dim=dim, mult=ff_mult), 101 | ] 102 | ) 103 | ) 104 | 105 | self.norm = nn.LayerNorm(dim) 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (torch.Tensor): image features 111 | shape (b, T, F, v, D) 112 | Returns: 113 | shape (b, T, n, D) where n is self.num_latents 114 | """ 115 | b, T, F, v = x.shape[:4] 116 | 117 | # frame and media time embeddings 118 | if exists(self.frame_embs): 119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 120 | x = x + frame_embs 121 | x = rearrange( 122 | x, "b T F v d -> b T (F v) d" 123 | ) # flatten the frame and spatial dimensions 124 | if exists(self.media_time_embs): 125 | x = x + self.media_time_embs[:T] 126 | 127 | # blocks 128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 129 | for attn, ff in self.layers: 130 | latents = attn(x, latents) + latents 131 | latents = ff(latents) + latents 132 | return self.norm(latents) 133 | 134 | 135 | # gated cross attention 136 | 137 | 138 | class MaskedCrossAttention(nn.Module): 139 | def __init__( 140 | self, 141 | *, 142 | dim, 143 | dim_visual, 144 | dim_head=64, 145 | heads=8, 146 | only_attend_immediate_media=True, 147 | ): 148 | super().__init__() 149 | self.scale = dim_head**-0.5 150 | self.heads = heads 151 | inner_dim = dim_head * heads 152 | 153 | self.norm = nn.LayerNorm(dim) 154 | 155 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 156 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 157 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 158 | 159 | # whether for text to only attend to immediate preceding image, or all previous images 160 | self.only_attend_immediate_media = only_attend_immediate_media 161 | 162 | def forward(self, x, media, media_locations=None, attend_previous=True): 163 | """ 164 | Args: 165 | x (torch.Tensor): text features 166 | shape (B, T_txt, D_txt) 167 | media (torch.Tensor): image features 168 | shape (B, T_img, n, D_img) where n is the dim of the latents 169 | media_locations: boolean mask identifying the media tokens in x 170 | shape (B, T_txt) 171 | attend_previous: bool 172 | If false, ignores immediately preceding image and starts attending when following image 173 | """ 174 | _, T_img, n = media.shape[:3] 175 | h = self.heads 176 | 177 | x = self.norm(x) 178 | 179 | q = self.to_q(x) 180 | media = rearrange(media, "b t n d -> b (t n) d") 181 | 182 | k, v = self.to_kv(media).chunk(2, dim=-1) 183 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 184 | 185 | q = q * self.scale 186 | 187 | sim = einsum("... i d, ... j d -> ... 
i j", q, k) 188 | 189 | if exists(media_locations): 190 | # at each boolean of True, increment the time counter (relative to media time) 191 | text_time = media_locations.cumsum(dim=-1) 192 | media_time = torch.arange(T_img, device=x.device) + 1 193 | 194 | if not attend_previous: 195 | text_time[~media_locations] += 1 196 | # make sure max is still the number of images in the sequence 197 | text_time[ 198 | text_time 199 | > repeat( 200 | torch.count_nonzero(media_locations, dim=1), 201 | "b -> b i", 202 | i=text_time.shape[1], 203 | ) 204 | ] = 0 205 | 206 | # text time must equal media time if only attending to most immediate image 207 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 208 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 209 | 210 | text_to_media_mask = mask_op( 211 | rearrange(text_time, "b i -> b 1 i 1"), 212 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 213 | ) 214 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 215 | 216 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 217 | attn = sim.softmax(dim=-1) 218 | 219 | if exists(media_locations) and self.only_attend_immediate_media: 220 | # any text without a preceding media needs to have attention zeroed out 221 | text_without_media_mask = text_time == 0 222 | text_without_media_mask = rearrange( 223 | text_without_media_mask, "b i -> b 1 i 1" 224 | ) 225 | attn = attn.masked_fill(text_without_media_mask, 0.0) 226 | 227 | out = einsum("... i j, ... j d -> ... i d", attn, v) 228 | out = rearrange(out, "b h n d -> b n (h d)") 229 | return self.to_out(out) 230 | 231 | 232 | class GatedCrossAttentionBlock(nn.Module): 233 | def __init__( 234 | self, 235 | *, 236 | dim, 237 | dim_visual, 238 | dim_head=64, 239 | heads=8, 240 | ff_mult=4, 241 | only_attend_immediate_media=True, 242 | ): 243 | super().__init__() 244 | self.attn = MaskedCrossAttention( 245 | dim=dim, 246 | dim_visual=dim_visual, 247 | dim_head=dim_head, 248 | heads=heads, 249 | only_attend_immediate_media=only_attend_immediate_media, 250 | ) 251 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 252 | 253 | self.ff = FeedForward(dim, mult=ff_mult) 254 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 255 | 256 | def forward( 257 | self, 258 | x, 259 | media, 260 | media_locations=None, 261 | attend_previous=True, 262 | ): 263 | x = ( 264 | self.attn( 265 | x, 266 | media, 267 | media_locations=media_locations, 268 | attend_previous=attend_previous, 269 | ) 270 | * self.attn_gate.tanh() 271 | + x 272 | ) 273 | x = self.ff(x) * self.ff_gate.tanh() + x 274 | 275 | return x 276 | -------------------------------------------------------------------------------- /Models/RadFM/my_embedding_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from .helpers import PerceiverResampler 5 | from .utils import get_visual_encoder 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | import torchvision 9 | from .vit_3d import ViT 10 | from einops.layers.torch import Rearrange 11 | from .transformer_decoder import TransformerDecoder,TransformerDecoderLayer 12 | from torch.utils.checkpoint import checkpoint 13 | from torch.autograd import Variable 14 | import random 15 | from transformers import AutoTokenizer, AutoModel, AutoConfig, BertModel 16 | 17 | class MyEmbedding(nn.Module): 18 | def __init__(self, 
num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256): 19 | super().__init__() 20 | self.num_embeddings = num_embeddings 21 | self.embedding_dim = embedding_dim 22 | self.weight = nn.Parameter(torch.torch.randn((num_embeddings, embedding_dim))) 23 | self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim))) 24 | self.flag = 'Text' 25 | self.patch_size = patch_size 26 | self.frame_patch_size = frame_patch_size 27 | self.seg_channel = seg_channel 28 | 29 | ## The bert model is useless for generation. Load it just for keeping model the same with the pre-train checkpoint. 30 | self.bert_tokenizer = AutoTokenizer.from_pretrained("/data/models/radfm/MedKEBERT") 31 | self.bert_model = BertModel(AutoConfig.from_pretrained("/data/models/radfm/MedKEBERT/")) 32 | self.bert_projection_fc = nn.Linear(768,vis_dim) 33 | 34 | ## the MedKEBERT can be downloaded from https://huggingface.co/xmcmic/Med-KEBERT/tree/main ## 35 | # self.bert_tokenizer = AutoTokenizer.from_pretrained("xmcmic/Med-KEBERT") 36 | # self.bert_model = AutoModel.from_pretrained("xmcmic/Med-KEBERT") 37 | # self.bert_projection_fc = nn.Linear(768,vis_dim) 38 | 39 | self.vision_encoder = ViT( 40 | image_size = 512, # image size 41 | frames = 512, # max number of frames 42 | image_patch_size = patch_size, # image patch size 43 | frame_patch_size = frame_patch_size, # frame patch size 44 | dim = vis_dim, 45 | depth = 12, 46 | heads = 8, 47 | mlp_dim = 2048, 48 | dropout = 0.1, 49 | emb_dropout = 0.1 50 | ) 51 | 52 | self.output_upscaling = nn.Sequential( 53 | nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2), 54 | nn.BatchNorm3d(vis_dim // 4), 55 | nn.GELU(), 56 | nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2), 57 | nn.GELU(), 58 | ) 59 | 60 | decoder_layer = TransformerDecoderLayer(d_model = vis_dim, nhead = 8, normalize_before=True) 61 | decoder_norm = nn.LayerNorm(vis_dim) 62 | self.transformer_decoder = TransformerDecoder(decoder_layer = decoder_layer, num_layers = 4, norm=decoder_norm) 63 | self.transformer_decoder_mlp = nn.Sequential( 64 | nn.Linear(vis_dim,vis_dim // 4), 65 | nn.GELU(), 66 | nn.Linear(vis_dim // 4,vis_dim // 8), 67 | nn.GELU(), 68 | ) 69 | self.vis_dim = vis_dim 70 | 71 | self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents = perceiver_num) 72 | self.fc = nn.Linear(self.vis_dim,self.embedding_dim) 73 | self.cls_head = nn.Linear(self.vis_dim // 8, 1) 74 | 75 | 76 | def forward(self, text_input, vision_x, key_words_query = None): 77 | if self.flag == 'Text': 78 | B,S,C,H,W,D = vision_x.shape 79 | vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d") 80 | 81 | 82 | vision_x, pos_embedding = self.vision_encoder(vision_x) 83 | # vision_x = Variable(vision_x,requires_grad=True) 84 | # vision_x, _ = checkpoint(self.vision_encoder,vision_x) 85 | 86 | vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S,F=1) 87 | 88 | loss_matching = None 89 | # if key_words_query != None: 90 | # # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding 91 | # query_words = [item for sublist in key_words_query for item in sublist] 92 | # query_words = list(set(query_words)) 93 | # if len(query_words)>16: 94 | # random.shuffle(query_words) 95 | # query_words = query_words[0:16] 96 | # if query_words != []: 97 | # contrastive_labels = torch.zeros(B,len(query_words)) #B Q 98 | # for i,sublist in enumerate(key_words_query): 99 | # for j,item in 
enumerate(query_words): 100 | # if item in sublist: 101 | # contrastive_labels[i,j] = 1 102 | # contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device) 103 | 104 | # with torch.no_grad(): 105 | # query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt") 106 | # query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D 107 | # query_words_embedding = self.bert_projection_fc(query_words_embedding) 108 | # query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D 109 | # _,N,_ = query_words_embedding.shape 110 | 111 | # image_embedding = vision_x.mean(dim=1) # B V D average pooling 去除掉多模态。 112 | # image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d") 113 | # pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:] 114 | 115 | # image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D 116 | # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D 117 | # query_words_embedding = query_words_embedding.transpose(0,1) # N B D 118 | 119 | # oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding) 120 | # oo_embedding = oo_embedding.transpose(0,1) # B Q D 121 | # oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d') 122 | # oo_embedding = self.transformer_decoder_mlp(oo_embedding) 123 | # oo_embedding = self.cls_head(oo_embedding).mean(dim = -1) 124 | # oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q 125 | # # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q 126 | # loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels) 127 | 128 | vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d) 129 | #vision_x = checkpoint(self.perceiver,vision_x) 130 | 131 | n = vision_x.shape[2] 132 | 133 | vision_x = rearrange(vision_x, "b s n d -> (b s n) d") 134 | vision_x = self.fc(vision_x) 135 | vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S) 136 | 137 | embedding_weight = torch.cat([self.weight, self.figure_token_weight],dim = 0) 138 | embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1) 139 | embedding_weight = torch.cat([embedding_weight,vision_x],dim = 1) 140 | text_input = F.one_hot(text_input,embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device) 141 | out_put = torch.matmul(text_input, embedding_weight) 142 | 143 | ## useless for now. 
ignore the folowing code## 144 | # if self.flag == 'Seg': 145 | # B,C,H,W,D = vision_x.shape 146 | # _,N,_ = text_input.shape 147 | # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D 148 | 149 | # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D 150 | # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D 151 | # text_input = text_input.transpose(0,1) # N B D 152 | 153 | # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding) 154 | # mask_embedding = mask_embedding.transpose(0,1) # B N D 155 | # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d') 156 | # mask_embedding = self.transformer_decoder_mlp(mask_embedding) 157 | # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8) 158 | 159 | # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim) 160 | # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4 161 | # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding) 162 | 163 | return out_put,loss_matching 164 | 165 | # model = MyEmbedding(vision_encoder_path = '') 166 | # text_input = torch.randint(low=0, high=3210, size=(4,2048)) 167 | # image_input = torch.randn((4,3,3,512,512,4)) 168 | # key_words_query = [[],[],[],['consoliation']] 169 | # print(model(text_input, image_input, key_words_query)) 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | drawing 3 |

4 | 5 |

Can you trust your AI radiologist?
Probing the reliability of
multimodal medical foundation models

6 |

Mohammad Shahab Sepehri, Zalan Fabian, Maryam Soltanolkotabi, Mahdi Soltanolkotabi

7 | 15 |

16 | | 🤗 Hugging Face | 📄 17 | Paper | 🌐 18 | Blog | 19 |

20 | 21 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/license/MIT) 22 | 23 |

24 | MediConfusion is a challenging medical Visual Question Answering (VQA) benchmark dataset that probes the failure modes of medical Multimodal Large Language Models (MLLMs) from a vision perspective. We reveal that state-of-the-art models are easily confused by image pairs that are otherwise visually dissimilar and clearly distinct to medical experts. These are some examples of confusing image pairs from the ROCO radiology dataset: 25 |

26 |

27 | drawing 28 |

29 |

30 | Our benchmark consists of 176 confusing pairs. A confusing pair is a set of two images that share the same question and the same answer options, but the correct answer differs between the two images, as illustrated below.
31 |
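For reference, each entry of `data/dataset.json` (the file that `utils/answering.py` evaluates against) stores one confusing pair: a shared question, the shared answer options, and a per-image correct option. The minimal sketch below loads one entry and prints these fields; the field names are taken from `utils/answering.py`, while the actual values depend on your local copy of the dataset.

```python
import json

# Inspect one confusing pair (field names as used in utils/answering.py).
with open("data/dataset.json") as f:
    dataset = json.load(f)

pair_id, pair = next(iter(dataset.items()))
print(pair["question"])                                # question shared by both images
print(pair["option_A"], "|", pair["option_B"])         # answer options shared by both images
print(pair["im_1_local"], "->", pair["im_1_correct"])  # image 1 and its correct option
print(pair["im_2_local"], "->", pair["im_2_correct"])  # image 2 and its (different) correct option
```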

32 |

33 | drawing 34 |

35 |

36 | We evaluate models based on their ability to answer the question correctly for both images within a confusing pair, which we call set accuracy. This metric indicates how well models can tell the two images apart: a model that selects the same answer option for both images in every pair will receive 0% set accuracy. We also report confusion, the proportion of confusing pairs for which the model chooses the same answer option for both images. 37 |
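As a minimal illustration of how the two metrics relate (the actual scoring, including per-category breakdowns and invalid-answer handling, lives in `utils/answering.py`), the sketch below computes set accuracy over all pairs and confusion over the pairs with two valid answers:

```python
# Minimal sketch of the two headline metrics. Each pair is summarised by the option
# chosen for image 1 and image 2 ('A', 'B', or '-' for an invalid/unparsable answer).
def score_pairs(choices, answers):
    both_correct = sum(c1 == a1 and c2 == a2 for (c1, c2), (a1, a2) in zip(choices, answers))
    valid = [(c1, c2) for c1, c2 in choices if c1 != '-' and c2 != '-']
    same_choice = sum(c1 == c2 for c1, c2 in valid)
    set_acc = 100 * both_correct / len(choices)                   # both images answered correctly
    confusion = 100 * same_choice / len(valid) if valid else 0.0  # same option picked for both images
    return set_acc, confusion

# A model that always picks 'A' gets 0% set accuracy and 100% confusion.
print(score_pairs([('A', 'A'), ('A', 'A')], [('A', 'B'), ('B', 'A')]))  # (0.0, 100.0)
```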

38 |

39 | Strikingly, all available models (open-source or proprietary) achieve performance below random guessing on MediConfusion, raising serious concerns about the reliability of existing medical MLLMs for healthcare deployment. 40 |

41 | 42 | ## 📊 Leaderboard 43 |
44 | 45 | 46 | | Rank | Model | Version | Set acc. (%) | Confusion (%) | 47 | | :--: | :--: | :--: | :--: | :--: | 48 | | 🏅️ | **[Gemini](https://deepmind.google/technologies/gemini/pro/)** | 2.0 Pro | **29.55** | 61.93 | 49 | | 🥈 | **Random Guessing** | - | **25.00** | 50.00 | 50 | | 🥉 | **[GPT](https://openai.com/index/learning-to-reason-with-llms/)** | o1 (release 20241217) | 21.69 | 72.99 | 51 | | 4 | **[Gemini](https://deepmind.google/technologies/gemini/pro/)** | 1.5 Pro | 19.89 | 58.52 | 52 | | 5 | **[GPT](https://openai.com/index/hello-gpt-4o/)** | 4o (release 20240513) | 18.75 | 75.00 | 53 | | 6 | [Deepseek VL2](https://github.com/deepseek-ai/DeepSeek-VL2) | - | 15.91 | 77.19 | 54 | | 7 | [Llama 3.2](https://www.llama.com/) | 90B-Vision-Instruct | 15.34 | 78.41 | 55 | | 8 | [InstructBLIP](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip) | Vicuna 7B | 12.50 | 80.35 | 56 | | 9 | [Molmo](https://molmo.allenai.org/) | 7B-D-0924 | 9.66 | 86.21 | 57 | | 10 | [LLaVA](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) | v1.6-Mistral 7B | 9.09 | 85.80 | 58 | | 11 | [Claude](https://claude.ai/new) | 3 Opus | 8.52 | 84.09 | 59 | | 12 | [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2) | Opt 2.7B | 6.82 | 86.93 | 60 | | 13 | [Molmo](https://molmo.allenai.org/) | 72B-0924 | 6.82 | 85.80 | 61 | | 14 | [RadFM](https://github.com/chaoyi-wu/RadFM) | - | 5.68 | 85.80 | 62 | | 15 | [Janus](https://github.com/deepseek-ai/Janus) | Pro 7B | 4.55 | 92.05 | 63 | | 16 | [Med-Flamingo](https://github.com/snap-stanford/med-flamingo) | - | 4.55 | 98.30 | 64 | | 17 | [LLaVA-Med](https://github.com/microsoft/LLaVA-Med) | v1.5-Mistral 7B | 1.14 | 97.16 | 65 | 66 |
67 | 68 | ## Updates 69 | 70 | - [2025/02/18] _DeepSeek_ family added to the supported models. 71 | - [2025/01/22] **MediConfusion** is accepted by ICLR 2025. 72 | - [2024/09/11] _Molmo_ family added to the supported models. 73 | - [2024/03/11] _Llama 3.2_ family added to the supported models. 74 | 75 | ## 📖 Table of Contents 76 | 77 | * [Requirements](#-requirements) 78 | * [Data Download](#data-download) 79 | * [Open-source Models](#open-source-models) 80 | * [Proprietary Models](#proprietary-models) 81 | * [Package Versions](#package-versions) 82 | * [Usage](#-usage) 83 | * [Evaluation](#evaluation) 84 | * [Arguments](#arguments) 85 | 86 | ## 🔧 Requirements 87 | 88 | Create and activate a `conda` environment with the following command: 89 | ``` 90 | conda create -n "mediconfusion" python=3.10 91 | conda activate mediconfusion 92 | ``` 93 | 94 | Use the following code to install requirements: 95 | 96 | ``` 97 | pip install -r requirements.txt 98 | ``` 99 | 100 | 101 | 102 | If you have any problem using the models, please follow the instructions below. 103 | 104 | ### Data Download 105 |

106 | The images in MediConfusion have to be downloaded directly from the source due to their license. 107 | To download all images (26 MB), use the following command: 108 | 109 |

110 | 111 | ``` 112 | python scripts/download.py 113 | ``` 114 | 115 | Alternatively, the images can be downloaded directly from [ROCO](https://github.com/razorx89/roco-dataset); in that case, set `local_image_address` to `False` and set `data_path` to the download folder when running the evaluation script (more details in [Usage](#-usage)).
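For clarity, this is how the evaluation code resolves image locations in the two addressing modes; the helper below simply mirrors the two path patterns used in `utils/answering.py`:

```python
# Condensed view of how utils/answering.py builds the image paths for a pair.
def image_paths(sample, data_path, local_image_address=True):
    if local_image_address:
        # images fetched by scripts/download.py, stored as <local id>.jpg under data_path
        return [f"{data_path}/{sample[k]}.jpg" for k in ("im_1_local", "im_2_local")]
    # images read from a ROCO checkout that lives inside data_path
    return [f"{data_path}/roco-dataset/data/{sample[k]}" for k in ("im_1", "im_2")]
```

In other words, with `local_image_address` set to `False`, `data_path` should point to the folder that contains your `roco-dataset` directory.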
116 | 117 | ### Open-source Models 118 | * `LLaVA-Med`: Follow the instructions [here](https://github.com/microsoft/LLaVA-Med/tree/main) and install `LLaVA-Med`. Download the model from [here](https://huggingface.co/microsoft/llava-med-7b-delta) and set `model_path` in the [config](./configs/Models/llava_med/vanilla.json) to its folder. 119 | * `LLaMA 3.2`: To download this model, you need to request access [here](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision). Then, add your token to the [config](./configs/Models/llama/vanilla.json). If you encounter a CUDA memory error, set `device` to `auto`. 120 | * `Molmo`: If you encounter a CUDA memory error, set `device` to `auto`. 121 | * `LLaMA`: Download the model from [here](https://huggingface.co/yahma/llama-7b-hf) and set `LLaMa_PATH` in the `MedFlamingo` [config](./configs/Models/med_flamingo/vanilla.json) to its folder. 122 | * `MedFlamingo`: Download the model from [here](https://huggingface.co/med-flamingo/med-flamingo) and set `CHECKPOINT_PATH` in the [config](./configs/Models/med_flamingo/vanilla.json) to its folder. 123 | * `RadFM`: Download the model from [here](https://huggingface.co/chaoyi-wu/RadFM) and set `model_path` in the [config](./configs/Models/radfm/vanilla.json) to its folder. 124 | * `DeepSeek VL2`: Follow the installation instructions [here](https://github.com/deepseek-ai/DeepSeek-VL2). 125 | * `DeepSeek Janus`: Follow the installation instructions [here](https://github.com/deepseek-ai/Janus). 126 | 127 | ### Proprietary Models 128 | To use proprietary models, save your API keys in the root directory of the repo in a file named `.env`, including the keys as in the example below. 129 | ``` 130 | GEMINI_API_KEY=YOUR_KEY 131 | AZURE_OPENAI_API_KEY=YOUR_KEY 132 | AZURE_OPENAI_ENDPOINT=YOUR_ENDPOINT 133 | ANTHROPIC_API_KEY=YOUR_KEY 134 | ``` 135 | 136 | ### Package Versions 137 | Different MLLMs need different versions of the `transformers` package. Please use the following versions for each MLLM.
138 | * `LLaVA-Med`: Use `transformers==4.36.2` 139 | * `RadFM`: Use `transformers==4.28.1` 140 | * `MedFlamingo`: Use `transformers==4.44.2` and install the `open-flamingo` package 141 | * `Gemini`: You need `python>=3.9` 142 | * `Molmo`: Use `transformers==4.45.2` 143 | * `DeepSeek`: Use `transformers==4.38.2` 144 | * `Other MLLMs`: Use `transformers==4.44.2` and `python>=3.8` 145 | 146 | ## 🔰 Usage 147 | 148 | ### Evaluation 149 | Before using the code, make sure to follow the instructions in [Requirements](#-requirements).
150 | You can create or change model configurations in `configs/Models/MODEL_NAME/`.
151 | To run the evaluation, use the following command: 152 | ``` 153 | python scripts/answering.py --mllm_name MODEL_NAME --mode MODE 154 | ``` 155 | The results will be saved in `Results/MODEL_NAME/`. You will see two files: one containing the final scores and one containing the answers provided by the model.
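For example, scoring LLaVA with multiple-choice (MC) evaluation, and later resuming that run from its saved answers file, would look roughly like this (flag names follow the [Arguments](#arguments) list below; the answers file follows the `MODEL_mode.json` naming used by `utils/answering.py`):

```
python scripts/answering.py --mllm_name llava --mode mc
python scripts/answering.py --mllm_name llava --mode mc --resume_path Results/llava/llava_mc.json
```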
156 | After running `answering.py`, you can print the results again with the command below: 157 | ``` 158 | python scripts/printing.py --mllm_name MODEL_NAME --mode MODE 159 | ``` 160 | ### Arguments 161 | * `mode`: This sets the evaluation method. Available options are `gpt4` (FF), `mc` (MC), `greedy` (GD), and `prefix` (PS). For proprietary models, you can only use the first two methods. 162 | * `mllm_name`: This is the name of your desired MLLM. Available options are `gpt` (GPT-4o), `gemini` (Gemini 1.5 Pro), `claude` (Claude 3 Opus), `llava` (LLaVA), `blip2` (BLIP-2), `instructblip` (InstructBLIP), `llava_med` (LLaVA-Med), `radfm` (RadFM), `med_flamingo` (Med-Flamingo), `llama` (Llama 3.2), `molmo` (Molmo), and `deepseek` (DeepSeek). 163 | * `model_args_path` (default: `configs/Models/MLLM_NAME/vanilla.json`): Path to the model's configuration file. 164 | * `tr` (default: 3): Threshold used for FF evaluation to select an option. If the difference between assigned scores is at least `tr`, we select the option with the higher score. 165 | * `resume_path` (default: `None`): If your run is interrupted and you want to resume evaluation, set this argument to the path of the answers file from the previous run. 166 | * `local_image_address` (default: `True`): If `False`, the code looks for the images based on their ROCO IDs. Otherwise, it looks for the images based on their local IDs. 167 | * `data_path` (default: `./data/images`): Path to the images. If you download the images using our script, this is `./data/images`. If you are not using local addressing, this is the path to the folder containing your [ROCO](https://github.com/razorx89/roco-dataset) download. 168 | * `device` (default: `cuda`): You can use `cuda` or `cpu`. For `LLaVA-Med`, our code does not support `cpu`. 169 | 170 | ## 📌 Citation 171 | 172 | If you use this code or our dataset, please cite our [paper](https://arxiv.org/abs/2409.15477). 173 | 174 | ```bibtex 175 | @inproceedings{sepehri2025mediconfusion, 176 | title={MediConfusion: Can you trust your {AI} radiologist? Probing the reliability of multimodal medical foundation models}, 177 | author={Mohammad Shahab Sepehri and Zalan Fabian and Maryam Soltanolkotabi and Mahdi Soltanolkotabi}, 178 | booktitle={The Thirteenth International Conference on Learning Representations}, 179 | year={2025}, 180 | url={https://openreview.net/forum?id=H9UnNgdq0g} 181 | } 182 | ``` 183 | -------------------------------------------------------------------------------- /Models/RadFM/blocks.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union, Callable, Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.utils.checkpoint import checkpoint 8 | 9 | class PMC_CLIP_cfg: 10 | backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE'] 11 | layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3] 12 | width: int = 64 13 | head_width: int = 64 14 | mlp_ratio: float = 4.0 15 | patch_size: int = 16 16 | image_size: Union[Tuple[int, int], int] = 224 17 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 18 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 19 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 20 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 21 | patch_dropout: float = 0.0 # patch dropout rate, no dropout by default 22 | drop_attention_rate: float = 0.
# Transformer Dropout 23 | patch_size: None 24 | 25 | class Bottleneck(nn.Module): 26 | expansion = 4 27 | 28 | def __init__(self, inplanes, planes, stride=1): 29 | super().__init__() 30 | 31 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 32 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu1 = nn.ReLU(inplace=True) 35 | 36 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.relu2 = nn.ReLU(inplace=True) 39 | 40 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 41 | 42 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 43 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 44 | self.relu3 = nn.ReLU(inplace=True) 45 | 46 | self.downsample = None 47 | self.stride = stride 48 | 49 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 50 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 51 | self.downsample = nn.Sequential(OrderedDict([ 52 | ("-1", nn.AvgPool2d(stride)), 53 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 54 | ("1", nn.BatchNorm2d(planes * self.expansion)) 55 | ])) 56 | 57 | def forward(self, x: torch.Tensor): 58 | identity = x 59 | 60 | out = self.relu1(self.bn1(self.conv1(x))) 61 | out = self.relu2(self.bn2(self.conv2(out))) 62 | out = self.avgpool(out) 63 | out = self.bn3(self.conv3(out)) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu3(out) 70 | return out 71 | 72 | 73 | class AttentionPool2d(nn.Module): 74 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 75 | super().__init__() 76 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 77 | self.k_proj = nn.Linear(embed_dim, embed_dim) 78 | self.q_proj = nn.Linear(embed_dim, embed_dim) 79 | self.v_proj = nn.Linear(embed_dim, embed_dim) 80 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 81 | self.num_heads = num_heads 82 | 83 | def forward(self, x): 84 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 85 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 86 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 87 | x, _ = F.multi_head_attention_forward( 88 | query=x, key=x, value=x, 89 | embed_dim_to_check=x.shape[-1], 90 | num_heads=self.num_heads, 91 | q_proj_weight=self.q_proj.weight, 92 | k_proj_weight=self.k_proj.weight, 93 | v_proj_weight=self.v_proj.weight, 94 | in_proj_weight=None, 95 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 96 | bias_k=None, 97 | bias_v=None, 98 | add_zero_attn=False, 99 | dropout_p=0, 100 | out_proj_weight=self.c_proj.weight, 101 | out_proj_bias=self.c_proj.bias, 102 | use_separate_proj_weight=True, 103 | training=self.training, 104 | need_weights=False 105 | ) 106 | 107 | return x[0] 108 | 109 | 110 | class ResNet(nn.Module): 111 | """ 112 | RN50 113 | """ 114 | 115 | def __init__( 116 | self, layers, output_dim, heads, image_size=224, width=64, 117 | block=Bottleneck, 118 | ): 119 | super().__init__() 120 | self.output_dim = output_dim 121 | self.image_size = image_size 122 | 123 | # the 1-layer stem 124 | self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, 
bias=False) 125 | self.bn1 = nn.BatchNorm2d(width) 126 | self.relu1 = nn.ReLU(inplace=True) 127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 128 | 129 | # residual layers 130 | self._inplanes = width # this is a *mutable* variable used during construction 131 | self.layer1 = self._make_layer(width, layers[0]) 132 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 133 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 134 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 136 | # self.head = nn.Linear(512 * 6, output_dim) 137 | self.head = nn.Linear(512 * block.expansion, output_dim) 138 | 139 | # embed_dim = width * 32 # the ResNet feature dimension 140 | # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 141 | 142 | self.init_parameters() 143 | 144 | def _make_layer( 145 | self, 146 | planes, blocks, stride=1, 147 | block=Bottleneck, 148 | ): 149 | layers = [block(self._inplanes, planes, stride)] 150 | 151 | self._inplanes = planes * block.expansion 152 | for _ in range(1, blocks): 153 | layers.append(block(self._inplanes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def init_parameters(self): 158 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 159 | for name, param in resnet_block.named_parameters(): 160 | if name.endswith("bn3.weight"): 161 | nn.init.zeros_(param) 162 | 163 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 164 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | if freeze_bn_stats: 168 | freeze_batch_norm_2d(self) 169 | 170 | @torch.jit.ignore 171 | def set_grad_checkpointing(self, enable=True): 172 | # FIXME support for non-transformer 173 | pass 174 | 175 | def stem(self, x): 176 | x = self.relu1(self.bn1(self.conv1(x))) 177 | x = self.maxpool(x) 178 | return x 179 | 180 | def forward(self, x): 181 | # x[0]: [batch_size, 3, 224, 224] 182 | # x[1]: [batch_size, 1] 183 | x = self.stem(x) # [batch_size, 64, 56, 56] 184 | x = self.layer1(x) 185 | x = self.layer2(x) 186 | x = self.layer3(x) 187 | x = self.layer4(x) # [batch_size, 2048, 7, 7] 188 | x = self.avgpool(x) # [batch_size, 2048, 1, 1] 189 | x = torch.flatten(x, 1) # [batch_size, 2048*1*1] 190 | x = self.head(x) # [batch_size, 1024] 191 | 192 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None) 193 | visual_output.update({ 194 | 'image_features': x, 195 | }) 196 | 197 | return visual_output 198 | 199 | 200 | class ModifiedResNet(nn.Module): 201 | """ 202 | A ResNet class that is similar to torchvision's but contains the following changes: 203 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
204 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 205 | - The final pooling layer is a QKV attention instead of an average pool 206 | """ 207 | 208 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 209 | super().__init__() 210 | self.output_dim = output_dim 211 | self.image_size = image_size 212 | 213 | # the 3-layer stem 214 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 215 | self.bn1 = nn.BatchNorm2d(width // 2) 216 | self.relu1 = nn.ReLU(inplace=True) 217 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 218 | self.bn2 = nn.BatchNorm2d(width // 2) 219 | self.relu2 = nn.ReLU(inplace=True) 220 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 221 | self.bn3 = nn.BatchNorm2d(width) 222 | self.relu3 = nn.ReLU(inplace=True) 223 | self.avgpool = nn.AvgPool2d(2) 224 | 225 | # residual layers 226 | self._inplanes = width # this is a *mutable* variable used during construction 227 | self.layer1 = self._make_layer(width, layers[0]) 228 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 229 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 230 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 231 | 232 | embed_dim = width * 32 # the ResNet feature dimension 233 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 234 | 235 | self.init_parameters() 236 | 237 | def _make_layer(self, planes, blocks, stride=1): 238 | layers = [Bottleneck(self._inplanes, planes, stride)] 239 | 240 | self._inplanes = planes * Bottleneck.expansion 241 | for _ in range(1, blocks): 242 | layers.append(Bottleneck(self._inplanes, planes)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def init_parameters(self): 247 | if self.attnpool is not None: 248 | std = self.attnpool.c_proj.in_features ** -0.5 249 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 250 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 251 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 252 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 253 | 254 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 255 | for name, param in resnet_block.named_parameters(): 256 | if name.endswith("bn3.weight"): 257 | nn.init.zeros_(param) 258 | 259 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 260 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 261 | for param in self.parameters(): 262 | param.requires_grad = False 263 | if freeze_bn_stats: 264 | freeze_batch_norm_2d(self) 265 | 266 | @torch.jit.ignore 267 | def set_grad_checkpointing(self, enable=True): 268 | # FIXME support for non-transformer 269 | pass 270 | 271 | def stem(self, x): 272 | x = self.relu1(self.bn1(self.conv1(x))) 273 | x = self.relu2(self.bn2(self.conv2(x))) 274 | x = self.relu3(self.bn3(self.conv3(x))) 275 | x = self.avgpool(x) 276 | return x 277 | 278 | def forward(self, x): 279 | x = self.stem(x) 280 | x = self.layer1(x) 281 | x = self.layer2(x) 282 | x = self.layer3(x) 283 | x = self.layer4(x) 284 | x = self.attnpool(x) 285 | 286 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None) 287 | visual_output.update({ 288 | 'image_features': x, 289 | }) 290 | 291 | return visual_output 292 | 293 | 294 | class LayerNorm(nn.LayerNorm): 295 | """Subclass torch's LayerNorm to handle fp16.""" 296 | 297 | def 
forward(self, x: torch.Tensor): 298 | orig_type = x.dtype 299 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 300 | return x.to(orig_type) 301 | 302 | 303 | class QuickGELU(nn.Module): 304 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 305 | def forward(self, x: torch.Tensor): 306 | return x * torch.sigmoid(1.702 * x) 307 | 308 | 309 | class ResidualAttentionBlock(nn.Module): 310 | def __init__( 311 | self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, 312 | drop_attention_rate: float = 0., 313 | ): 314 | super().__init__() 315 | 316 | self.attn = nn.MultiheadAttention( 317 | embed_dim=d_model, 318 | num_heads=n_head, 319 | dropout=drop_attention_rate, 320 | ) 321 | self.ln_1 = LayerNorm(d_model) 322 | mlp_width = int(d_model * mlp_ratio) 323 | self.mlp = nn.Sequential(OrderedDict([ 324 | ("c_fc", nn.Linear(d_model, mlp_width)), 325 | ("gelu", act_layer()), 326 | ("c_proj", nn.Linear(mlp_width, d_model)) 327 | ])) 328 | self.ln_2 = LayerNorm(d_model) 329 | 330 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 331 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 332 | 333 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 334 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) 335 | x = x + self.mlp(self.ln_2(x)) 336 | return x 337 | 338 | 339 | class PatchDropout(nn.Module): 340 | """ 341 | https://arxiv.org/abs/2212.00794 342 | """ 343 | 344 | def __init__(self, prob, exclude_first_token=True): 345 | super().__init__() 346 | assert 0 <= prob < 1. 347 | self.prob = prob 348 | self.exclude_first_token = exclude_first_token # exclude CLS token 349 | 350 | def forward(self, x): 351 | if not self.training or self.prob == 0.: 352 | return x 353 | 354 | if self.exclude_first_token: 355 | cls_tokens, x = x[:, :1], x[:, 1:] 356 | else: 357 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 358 | 359 | batch = x.size()[0] 360 | num_tokens = x.size()[1] 361 | 362 | batch_indices = torch.arange(batch) 363 | batch_indices = batch_indices[..., None] 364 | 365 | keep_prob = 1 - self.prob 366 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 367 | 368 | rand = torch.randn(batch, num_tokens) 369 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 370 | 371 | x = x[batch_indices, patch_indices_keep] 372 | 373 | if self.exclude_first_token: 374 | x = torch.cat((cls_tokens, x), dim=1) 375 | 376 | return x 377 | 378 | 379 | class Transformer(nn.Module): 380 | def __init__( 381 | self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, 382 | drop_attention_rate: float = 0., 383 | ): 384 | super().__init__() 385 | self.width = width 386 | self.layers = layers 387 | self.grad_checkpointing = False 388 | 389 | self.resblocks = nn.ModuleList([ 390 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate) 391 | for _ in range(layers) 392 | ]) 393 | 394 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 395 | for r in self.resblocks: 396 | if self.grad_checkpointing and not torch.jit.is_scripting(): 397 | x = checkpoint(r, x, attn_mask) 398 | else: 399 | x = r(x, attn_mask=attn_mask) 400 | return x -------------------------------------------------------------------------------- /utils/answering.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import time 3 | import warnings 4 | from tqdm import tqdm 5 | from PIL import Image 6 | from utils import io_tools 7 | from transformers import set_seed, logging 8 | 9 | 10 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 11 | logging.set_verbosity_error() 12 | 13 | ROOT = io_tools.get_root(__file__, 2) 14 | PROMPTS_LOC = f'{ROOT}/configs/prompts/answering.json' 15 | DATA_PATH = f'{ROOT}/data/dataset.json' 16 | STATS_PATH = f'{ROOT}/data/stats/statistics.json' 17 | DATA = io_tools.load_json(DATA_PATH) 18 | STATS = io_tools.load_json(STATS_PATH) 19 | PROMPTS = io_tools.load_json(PROMPTS_LOC) 20 | 21 | 22 | class BaseAnsweringModel(): 23 | def __init__(self, model_args_path, mode, data_path, local_image_address=True, tr=3, device='cuda'): 24 | self.key = None 25 | self.model_args_path = model_args_path 26 | self.conversion = io_tools.load_json(PROMPTS_LOC).get('conversion') 27 | self.mode = mode 28 | self.tr = tr 29 | self.data_path = data_path 30 | self.prompt_key = 'prompts' 31 | self.local_image_address = local_image_address 32 | self.device = device 33 | self.set_model_params() 34 | 35 | def set_model_params(self): 36 | args = io_tools.load_json(self.model_args_path) 37 | self.set_init_prompt(args.get('init_prompt_id')) 38 | self.temperature = args.get("temperature") 39 | self.num_beams = args.get('num_beams') 40 | self.max_new_tokens = args.get('max_new_tokens') 41 | self.top_p = args.get('top_p') 42 | if self.mode == 'mc': 43 | self.temperature = 0 44 | self.num_beams = 1 45 | self.top_p = None 46 | self.max_new_tokens = 32 47 | if self.mode == 'gpt4': 48 | self.clean_up = self.clean_up_gpt 49 | global gpt 50 | from Models import gpt 51 | else: 52 | self.clean_up = self.clean_up_manual 53 | return args 54 | 55 | def ask_question(self, question, options, image_list): 56 | return self.convert_question(question, options) 57 | 58 | def set_init_prompt(self, init_prompt_id): 59 | self.init_prompt = None 60 | if init_prompt_id is None: 61 | return 62 | tmp = PROMPTS.get('init_prompts').get(self.key) 63 | if tmp is not None: 64 | self.init_prompt = tmp.get(init_prompt_id) 65 | else: 66 | self.init_prompt = PROMPTS.get('init_prompts').get('default') 67 | 68 | def evaluate(self, resume_path, save_dir): 69 | results = io_tools.load_resume_dict(resume_path) 70 | score = self.create_score_table([], [], -1, -1, -1, -1, -1) 71 | save_path = self.check_folder(save_dir) 72 | for id in tqdm(DATA.keys()): 73 | if id in results.keys(): 74 | sample_score = results.get(id).get('score') 75 | else: 76 | sample = DATA.get(id) 77 | ans_dict, sample_score = self.sample_eval(sample) 78 | results[id] = {'answer': ans_dict, 'score': sample_score} 79 | 80 | self.update_score_table(score, sample_score) 81 | if save_path is not None: 82 | io_tools.save_json(results, f'{save_path}/{self.key}_{self.mode}.json') 83 | self.print_score(score) 84 | if save_path is not None: 85 | io_tools.save_json(score, f'{save_path}/{self.key}_{self.mode}_score.json') 86 | return results, score 87 | 88 | def sample_eval(self, sample): 89 | if self.local_image_address: 90 | image_list = [f"{self.data_path}/{sample.get(x)}.jpg" for x in ['im_1_local', 'im_2_local']] 91 | image_list = [f"{self.data_path}/{sample.get('im_1_local')}.jpg", f"{self.data_path}/{sample.get('im_2_local')}.jpg"] 92 | else: 93 | image_list = [f"{self.data_path}/roco-dataset/data/{sample.get(x)}" for x in ['im_1', 'im_2']] 94 | question = sample.get('question') 95 | options = [sample.get('option_A'), sample.get('option_B')] 96 | im1_ans = 
sample.get('im_1_correct') 97 | im2_ans = sample.get('im_2_correct') 98 | responses = self.ask_question(question, options, image_list) 99 | ans_dict = {'im1': self.clean_up(question, options, responses[0]), 100 | 'im2': self.clean_up(question, options, responses[1])} 101 | im1_correct, im1_invalid, im2_correct, im2_invalid, confused = self.get_score(ans_dict, im1_ans, im2_ans) 102 | scores = self.create_score_table(sample.get('category_1'), 103 | sample.get('category_2'), 104 | im1_correct, 105 | im2_correct, 106 | im1_invalid, 107 | im2_invalid, 108 | confused, 109 | ) 110 | 111 | 112 | return ans_dict, scores 113 | 114 | def get_clean_up_prompt(self, question, options, response): 115 | role = self.conversion.get('role') 116 | return (f'[Question]\n{question}\n\n' 117 | f'[Answer A]\n{options[0]}\n\n' 118 | f'[Answer B]\n{options[1]}\n\n' 119 | f'[{role}]\n{response}\n\n[End of {role}]\n\n' 120 | f'[System]\n{self.conversion.get("instruct_prompt")}\n\n') 121 | 122 | def get_score(self, ans_dict, im1_ans, im2_ans): 123 | im1_correct, c1 = self.check_answer(im1_ans, 124 | ans_dict.get('im1').get('A'), 125 | ans_dict.get('im1').get('B'), 126 | self.tr) 127 | invalid1 = (c1 == '-') 128 | im2_correct, c2 = self.check_answer(im2_ans, 129 | ans_dict.get('im2').get('A'), 130 | ans_dict.get('im2').get('B'), 131 | self.tr) 132 | invalid2 = (c2 == '-') 133 | confused = 1 * ((c1 == c2) and (c1 != '-')) 134 | ans_dict 135 | return im1_correct, invalid1, im2_correct, invalid2, confused 136 | 137 | def clean_up_gpt(self, question, options, answer): 138 | client = gpt.get_client() 139 | prompt = self.get_clean_up_prompt(question, options, answer) 140 | response = gpt.get_response(client=client, 141 | deployment_name=self.conversion.get('gpt_deployment_name'), 142 | init_prompt=self.conversion.get('init_prompt'), 143 | prompt=prompt, 144 | temperature=float(self.conversion.get('temperature')), 145 | ) 146 | ans = self.process_gpt_response(response) 147 | ans['full_answer'] = answer 148 | return ans 149 | 150 | def clean_up_manual(self, question, options, answer): 151 | labels = ['A', 'B'] 152 | scores = {'full_answer': answer} 153 | for key in labels: 154 | scores[key] = 0 155 | if answer is not None: 156 | answer = answer.replace('\n', ' ') 157 | tmp = answer.split(' ') 158 | for la in labels: 159 | valid_list = [f'{la}', f'{la}:', f'.{la}', f'.{la}:', f'{la}.', 160 | f'{la}\")', f'{la}\n', f'\n{la}', f'(\"{la}\":', 161 | f'(\"{la}\")', f'(\"{la}\").'] 162 | correct = any([x in tmp for x in valid_list]) 163 | if correct: 164 | scores[la] = 10 165 | tmp = [1 for x in scores.values() if x==10] 166 | if sum(tmp) > 1: 167 | for key in labels: 168 | scores[key] = 0 169 | return scores 170 | 171 | def convert_question(self, question, options): 172 | prompt_dict = PROMPTS.get(self.prompt_key).get(self.mode) 173 | if self.key in prompt_dict.keys(): 174 | key = self.key 175 | else: 176 | key = 'default' 177 | 178 | tmp = prompt_dict.get(key) 179 | if self.mode == 'gpt4': 180 | output = tmp.format(question) 181 | elif self.mode == 'greedy': 182 | output = tmp.format(question, options[0], options[1]) 183 | elif self.mode == 'mc': 184 | output = tmp.format(question, options[0], options[1]) 185 | elif self.mode == 'prefix': 186 | output = { 187 | "question": question, 188 | "option_A": options[0], 189 | "option_B": options[1], 190 | "format": tmp 191 | } 192 | return output 193 | 194 | def check_folder(self, save_dir): 195 | if save_dir is None: 196 | return None 197 | save_path = f'{save_dir}/{self.key}' 198 | if 
not os.path.isdir(save_path): 199 | os.makedirs(save_path) 200 | return save_path 201 | 202 | @staticmethod 203 | def update_score_table(score, sample_score): 204 | for key in score: 205 | tmp = score.get(key) 206 | for cat in tmp.keys(): 207 | tmp[cat] += sample_score.get(key).get(cat) 208 | 209 | @staticmethod 210 | def create_score_table(cat_1, cat_2, im1_correct, im2_correct, im1_invalid, im2_invalid, confused): 211 | scores = {'set_score': {}, 'individual_score': {}, 'confused': {}, 'invalid': {}, 'valids': {}, 'valid_pairs': {}} 212 | for v in scores.values(): 213 | for key in STATS.keys(): 214 | v[key] = 0 215 | v['total'] = 0 216 | if confused == 1: 217 | scores.get('confused')['total'] += 1 218 | for c in (cat_1 + cat_2): 219 | scores.get('confused')[c] += 1 220 | 221 | if im1_correct == 1: 222 | scores.get('individual_score')['total'] += 1 223 | for c in cat_1: 224 | scores.get('individual_score')[c] += 1 225 | if im2_correct == 1: 226 | scores.get('individual_score')['total'] += 1 227 | for c in cat_2: 228 | scores.get('individual_score')[c] += 1 229 | 230 | if im1_invalid == 1: 231 | scores.get('invalid')['total'] += 1 232 | for c in cat_1: 233 | scores.get('invalid')[c] += 1 234 | if im2_invalid == 1: 235 | scores.get('invalid')['total'] += 1 236 | for c in cat_2: 237 | scores.get('invalid')[c] += 1 238 | 239 | if (im1_invalid == 0) and (im2_invalid == 0): 240 | scores.get('valids')['total'] += 1 241 | for c in (cat_1 + cat_2): 242 | scores.get('valids')[c] += 1 243 | if confused == 0: 244 | scores.get('valid_pairs')['total'] += 1 245 | for c in (cat_1 + cat_2): 246 | scores.get('valid_pairs')[c] += 1 247 | 248 | if (im1_correct == 1) and (im2_correct == 1): 249 | scores.get('set_score')['total'] += 1 250 | for c in (cat_1 + cat_2): 251 | scores.get('set_score')[c] += 1 252 | return scores 253 | 254 | @staticmethod 255 | def print_score(score, precision=2): 256 | print('\n') 257 | # print_format = "{:<17} {:<10} {:<10} {:<10} {:<12} {:<17} {:<10}" 258 | print_format = "{:<17} {:<10} {:<10} {:<17} {:<15} {:<15} {:<15} {:<15} {:<15}" 259 | print(print_format.format('Category', 260 | 'Total', 261 | 'Set acc.', 262 | 'Individual acc.', 263 | 'Confused acc.', 264 | 'Valid pairs', 265 | 'Invalid acc.', 266 | 'Precision', 267 | 'Precision total', 268 | )) 269 | key_list = list(STATS.keys()) + ['total'] 270 | for cat in key_list: 271 | if cat == 'total': 272 | total = len(DATA) 273 | num = total / 100 274 | individual_acc = round(score.get('individual_score').get(cat) / num / 2, precision) 275 | invalid = round(score.get('invalid').get(cat) / num / 2, precision) 276 | txt = 'All' 277 | else: 278 | total = STATS.get(cat) 279 | num = total / 100 280 | individual_acc = round(score.get('individual_score').get(cat) / num, precision) 281 | invalid = round(score.get('invalid').get(cat) / num, precision) 282 | txt = cat 283 | 284 | set_acc = round(score.get('set_score').get(cat) / num, precision) 285 | valid_pairs = score.get('valids').get(cat) 286 | precision_total = score.get('valid_pairs').get(cat) 287 | 288 | confused = 0 289 | if score.get('valids').get(cat) > 0: 290 | confused = round(score.get('confused').get(cat) / score.get('valids').get(cat) * 100, precision) 291 | 292 | pr = 0 293 | if score.get('valid_pairs').get(cat) > 0: 294 | pr = round(score.get('set_score').get(cat) / score.get('valid_pairs').get(cat) * 100, precision) 295 | 296 | print(print_format.format(txt, total, set_acc, individual_acc, confused, valid_pairs, invalid, pr, precision_total)) 297 | 298 | @staticmethod 299 
| def process_gpt_response(response): 300 | if response is None: 301 | return { 302 | 'A': 0, 303 | 'B': 0, 304 | 'gpt_reason': '', 305 | } 306 | tmp = response.replace('\n\n', '\n').split('\n') 307 | ans = { 308 | 'A': int(tmp[0].replace('A: ', '')), 309 | 'B': int(tmp[1].replace('B: ', '')), 310 | 'gpt_reason': tmp[2].replace('Your explanation: ', ''), 311 | } 312 | return ans 313 | 314 | @staticmethod 315 | def check_answer(answer, a_score, b_score, tr): 316 | chosen = '-' 317 | if a_score >= b_score + tr: 318 | chosen = 'A' 319 | elif b_score >= a_score + tr: 320 | chosen = 'B' 321 | if chosen == answer: 322 | return 1, chosen 323 | return 0, chosen 324 | 325 | 326 | class GPTAnswering(BaseAnsweringModel): 327 | 328 | def set_model_params(self): 329 | global gpt 330 | from Models import gpt 331 | self.key = 'gpt' 332 | args = super().set_model_params() 333 | self.deployment_name = args.get("deployment_name") 334 | self.api_version = args.get("api_version") 335 | self.client = gpt.get_client() 336 | if self.mode in ['greedy', 'prefix']: 337 | raise ValueError(f'Cannot use forward for GPT!') 338 | 339 | def ask_question(self, question, options, image_list): 340 | qs = super().ask_question(question, options, image_list) 341 | response_list = [] 342 | for image in image_list: 343 | response = gpt.ask_question(self.client, image, qs, self.init_prompt, self.deployment_name, self.temperature) 344 | response_list.append(response) 345 | return response_list 346 | 347 | class DeepSeekAnswering(BaseAnsweringModel): 348 | 349 | def set_model_params(self): 350 | global deepseek 351 | from Models import deepseek 352 | self.key = 'deepseek' 353 | args = super().set_model_params() 354 | self.model_path = args.get("model_path") 355 | self.max_new_tokens = args.get("max_new_tokens") 356 | self.do_sample = args.get("do_sample") 357 | self.model, self.processor, self.tokenizer, self.janus = deepseek.load_model(self.model_path) 358 | # if self.mode in ['greedy', 'prefix']: 359 | # raise ValueError(f'Not implemented!') 360 | 361 | def ask_question(self, question, options, image_list): 362 | qs = super().ask_question(question, options, image_list) 363 | response_list = [] 364 | for image in image_list: 365 | response = deepseek.ask_question(self.model, self.processor, self.tokenizer, 366 | image, qs, self.temperature, self.mode, 367 | self.max_new_tokens, self.do_sample, self.janus) 368 | response_list.append(response) 369 | return response_list 370 | 371 | class ClaudeAnswering(BaseAnsweringModel): 372 | 373 | def set_model_params(self): 374 | global claude 375 | from Models import claude 376 | self.key = 'claude' 377 | args = super().set_model_params() 378 | self.deployment_name = args.get("deployment_name") 379 | self.client = claude.get_client() 380 | if self.mode in ['greedy', 'prefix']: 381 | raise ValueError(f'Cannot use forward for Claude!') 382 | 383 | def ask_question(self, question, options, image_list): 384 | qs = super().ask_question(question, options, image_list) 385 | response_list = [] 386 | for image in image_list: 387 | response = claude.ask_question(self.client, image, qs, self.init_prompt, self.temperature, self.deployment_name) 388 | response_list.append(response) 389 | return response_list 390 | 391 | 392 | class GeminiAnswering(BaseAnsweringModel): 393 | def set_model_params(self): 394 | global gemini 395 | from Models import gemini 396 | self.key = 'gemini' 397 | args = super().set_model_params() 398 | self.deployment_name = args.get("deployment_name") 399 | self.model = 
gemini.load_model(self.init_prompt, self.temperature, self.deployment_name) 400 | if self.mode in ['greedy', 'prefix']: 401 | raise ValueError(f'Cannot use forward for Gemini!') 402 | 403 | def ask_question(self, question, options, image_list): 404 | qs = super().ask_question(question, options, image_list) 405 | response_list = [] 406 | for image in image_list: 407 | flag = True 408 | counter = 0 409 | while flag: 410 | try: 411 | response = gemini.ask_question(self.model, image, qs) 412 | flag = False 413 | except Exception as e: 414 | counter += 1 415 | print(counter, e) 416 | time.sleep(10) 417 | response_list.append(response) 418 | return response_list 419 | 420 | 421 | class LLaVAMedAnswering(BaseAnsweringModel): 422 | 423 | def set_model_params(self): 424 | global llava_med 425 | from Models import llava_med 426 | self.key = 'llava_med' 427 | args = super().set_model_params() 428 | 429 | set_seed(0) 430 | tokenizer, model, image_processor, context_len = \ 431 | llava_med.load_model(args.get("model_path"), args.get("model_base")) 432 | 433 | self.model = model 434 | self.tokenizer = tokenizer 435 | self.image_processor = image_processor 436 | 437 | if self.device == 'cpu': 438 | warnings.warn('LLaVA-Med implepenation does not support CPU! Switching to CUDA.') 439 | self.device = 'cuda' 440 | 441 | self.conv_mode = args.get("conv_mode") 442 | self.use_im_start_end = args.get('use_im_start_end') 443 | 444 | def convert_question(self, question, options): 445 | tmp = super().convert_question(question, options) 446 | if self.mode == 'prefix': 447 | to_process = tmp["question"] 448 | else: 449 | to_process = tmp 450 | 451 | to_process = '\n' + to_process 452 | qs = to_process.replace(llava_med.DEFAULT_IMAGE_TOKEN, '').strip() 453 | if self.use_im_start_end: 454 | qs = llava_med.DEFAULT_IM_START_TOKEN + llava_med.DEFAULT_IMAGE_TOKEN + llava_med.DEFAULT_IM_END_TOKEN + '\n' + qs 455 | else: 456 | qs = llava_med.DEFAULT_IMAGE_TOKEN + '\n' + qs 457 | 458 | if self.mode == 'prefix': 459 | tmp["question"] = qs 460 | return tmp 461 | else: 462 | return qs 463 | 464 | def ask_question(self, question, options, image_list): 465 | question = super().ask_question(question, options, image_list) 466 | response_list = [] 467 | image_list = [Image.open(x) for x in image_list] 468 | for image in image_list: 469 | outputs = llava_med.ask_question(self.model, 470 | question, 471 | image, 472 | self.image_processor, 473 | self.tokenizer, 474 | self.mode, 475 | conv_mode=self.conv_mode, 476 | temperature=self.temperature, 477 | top_p=self.top_p, 478 | num_beams=self.num_beams, 479 | max_new_tokens=self.max_new_tokens) 480 | response_list.append(outputs) 481 | 482 | return response_list 483 | 484 | 485 | class LLaVAAnswering(BaseAnsweringModel): 486 | 487 | def set_model_params(self): 488 | global llava 489 | from Models import llava 490 | self.key = 'llava' 491 | args = super().set_model_params() 492 | 493 | set_seed(0) 494 | model_id = args.get('model_id') 495 | model, processor = llava.load_model(model_id, self.device) 496 | 497 | self.model = model 498 | self.processor = processor 499 | self.conv_mode = args.get("conv_mode") 500 | 501 | def ask_question(self, question, options, image_list): 502 | question = super().ask_question(question, options, image_list) 503 | response_list = [] 504 | image_list = [Image.open(x) for x in image_list] 505 | for image in image_list: 506 | outputs = llava.ask_question(self.model, 507 | self.processor, 508 | question, 509 | image, 510 | self.mode, 511 | 
temperature=self.temperature, 512 | top_p=self.top_p, 513 | num_beams=self.num_beams) 514 | response_list.append(outputs) 515 | 516 | return response_list 517 | 518 | 519 | class RadFMAnswering(BaseAnsweringModel): 520 | 521 | def set_model_params(self): 522 | global radfm 523 | from Models import radfm 524 | self.key = 'radfm' 525 | args = super().set_model_params() 526 | self.model_path = args.get("model_path") 527 | model, text_tokenizer, image_padding_tokens = radfm.load_model(self.model_path, self.device) 528 | self.model = model 529 | self.text_tokenizer = text_tokenizer 530 | self.image_padding_tokens = image_padding_tokens 531 | 532 | def ask_question(self, question, options, image_list): 533 | question = super().ask_question(question, options, image_list) 534 | response_list = [] 535 | for image_path in image_list: 536 | outputs = radfm.ask_question(self.model, 537 | question, 538 | image_path, 539 | self.text_tokenizer, 540 | self.image_padding_tokens, 541 | self.mode, 542 | self.device) 543 | response_list.append(outputs) 544 | return response_list 545 | 546 | class BLIP2Answering(BaseAnsweringModel): 547 | 548 | def set_model_params(self): 549 | global blip2 550 | from Models import blip2 551 | self.key = 'blip2' 552 | args = super().set_model_params() 553 | model_id = args.get('model_id') 554 | model, processor = blip2.load_model(model_id, self.device) 555 | self.model = model 556 | self.processor = processor 557 | 558 | def ask_question(self, question, options, image_list): 559 | question = super().ask_question(question, options, image_list) 560 | response_list = [] 561 | for image_path in image_list: 562 | outputs = blip2.ask_question(self.model, 563 | question, 564 | image_path, 565 | self.processor, 566 | self.num_beams, 567 | self.max_new_tokens, 568 | self.top_p, 569 | self.temperature, 570 | self.mode) 571 | response_list.append(outputs) 572 | return response_list 573 | 574 | class InstructBLIPAnswering(BaseAnsweringModel): 575 | 576 | def set_model_params(self): 577 | global instructblip 578 | from Models import instructblip 579 | self.key = 'instructblip' 580 | args = super().set_model_params() 581 | model_id = args.get('model_id') 582 | model, processor = instructblip.load_model(model_id, self.device) 583 | self.model = model 584 | self.processor = processor 585 | if self.temperature == 0: 586 | self.temperature = 0.1 587 | 588 | def ask_question(self, question, options, image_list): 589 | question = super().ask_question(question, options, image_list) 590 | response_list = [] 591 | for image_path in image_list: 592 | outputs = instructblip.ask_question(self.model, 593 | question, 594 | image_path, 595 | self.processor, 596 | self.num_beams, 597 | self.max_new_tokens, 598 | self.top_p, 599 | self.temperature, 600 | self.mode) 601 | response_list.append(outputs) 602 | return response_list 603 | 604 | 605 | class MolmoAnswering(BaseAnsweringModel): 606 | 607 | def set_model_params(self): 608 | global molmo 609 | from Models import molmo 610 | self.key = 'molmo' 611 | args = super().set_model_params() 612 | model_id = args.get('model_id') 613 | model, processor = molmo.load_model(model_id, self.device) 614 | self.model = model 615 | self.processor = processor 616 | 617 | def ask_question(self, question, options, image_list): 618 | question = super().ask_question(question, options, image_list) 619 | response_list = [] 620 | for image_path in image_list: 621 | outputs = molmo.ask_question(self.model, 622 | question, 623 | image_path, 624 | self.processor, 625 | 
self.num_beams, 626 | self.max_new_tokens, 627 | self.top_p, 628 | self.temperature, 629 | self.mode) 630 | response_list.append(outputs) 631 | return response_list 632 | 633 | 634 | class MedFlamingoAnswering(BaseAnsweringModel): 635 | 636 | def set_model_params(self): 637 | global med_flamingo 638 | from Models import med_flamingo 639 | self.key = 'med_flamingo' 640 | args = super().set_model_params() 641 | self.LLaMa_PATH = args.get('LLaMa_PATH') 642 | self.CHECKPOINT_PATH = args.get('CHECKPOINT_PATH') 643 | self.IMAGE_PATH = args.get('IMAGE_PATH') 644 | model, processor = med_flamingo.load_model(self.LLaMa_PATH, self.CHECKPOINT_PATH, self.device) 645 | self.model = model 646 | self.processor = processor 647 | 648 | def ask_question(self, question, options, image_list): 649 | question = super().ask_question(question, options, image_list) 650 | response_list = [] 651 | for image_path in image_list: 652 | outputs = med_flamingo.ask_question(self.model, 653 | self.processor, 654 | image_path, 655 | question, 656 | self.max_new_tokens, 657 | self.mode, 658 | self.IMAGE_PATH, 659 | ) 660 | response_list.append(outputs) 661 | return response_list 662 | 663 | 664 | class LlamaAnswering(BaseAnsweringModel): 665 | 666 | def set_model_params(self): 667 | global llama 668 | from Models import llama 669 | self.key = 'llama' 670 | args = super().set_model_params() 671 | 672 | token = args.get('token') 673 | model_id = args.get('model_id') 674 | model, processor = llama.load_model(token, model_id, self.device) 675 | 676 | self.model = model 677 | self.processor = processor 678 | 679 | def ask_question(self, question, options, image_list): 680 | question = super().ask_question(question, options, image_list) 681 | response_list = [] 682 | image_list = [Image.open(x) for x in image_list] 683 | for image in image_list: 684 | outputs = llama.ask_question(self.model, 685 | self.processor, 686 | question, 687 | image, 688 | self.mode, 689 | temperature=self.temperature, 690 | top_p=self.top_p, 691 | num_beams=self.num_beams) 692 | response_list.append(outputs) 693 | 694 | return response_list 695 | 696 | 697 | class MedVInTAnswering(BaseAnsweringModel): 698 | 699 | def set_model_params(self): 700 | global med_flamingo 701 | from Models import medvint 702 | self.key = 'medvint' 703 | args = super().set_model_params() 704 | self.model_args = medvint.ModelArguments() 705 | self.model_args.embed_dim = args.get("EMBED_DIM") 706 | self.model_args.pretrained_tokenizer = args.get("PRETRAINED_TOKENIZER") 707 | self.model_args.pretrained_model = args.get("PRETRAINED_MODEL") 708 | self.model_args.image_encoder = args.get("IMAGE_ENCODER") 709 | self.model_args.pmcclip_pretrained = args.get("PMCCLIP_PRETRAINED") 710 | self.model_args.clip_pretrained = args.get("CLIP_PRETRAINED") 711 | self.model_args.ckp = args.get("CKP") 712 | model, image_transform, tokenizer = medvint.load_model(self.model_args) 713 | self.model = model 714 | self.image_transform = image_transform 715 | self.tokenizer = tokenizer 716 | 717 | def ask_question(self, question, options, image_list): 718 | question = super().ask_question(question, options, image_list) 719 | image_list = [Image.open(x).convert('RGB') for x in image_list] 720 | response_list = [] 721 | for image in image_list: 722 | image = self.image_transform(image) 723 | outputs = med_flamingo.ask_question(self.model, 724 | self.tokenizer, 725 | question, 726 | image, 727 | ) 728 | response_list.append(outputs) 729 | return response_list 730 | ANSWERING_CLASS_DICT = { 731 | 'gpt': 
GPTAnswering, 732 | 'deepseek': DeepSeekAnswering, 733 | 'claude': ClaudeAnswering, 734 | 'gemini': GeminiAnswering, 735 | 'llava_med': LLaVAMedAnswering, 736 | 'llava': LLaVAAnswering, 737 | 'radfm': RadFMAnswering, 738 | 'blip2': BLIP2Answering, 739 | 'instructblip': InstructBLIPAnswering, 740 | 'med_flamingo': MedFlamingoAnswering, 741 | 'medvint': MedVInTAnswering, 742 | 'llama': LlamaAnswering, 743 | 'molmo': MolmoAnswering, 744 | } 745 | 746 | DEFAULT_MODEL_CONFIGS = { 747 | 'gpt': f'{ROOT}/configs/Models/gpt/vanilla.json', 748 | 'deepseek': f'{ROOT}/configs/Models/deepseek/vanilla.json', 749 | 'claude': f'{ROOT}/configs/Models/claude/vanilla.json', 750 | 'gemini': f'{ROOT}/configs/Models/gemini/vanilla.json', 751 | 'llava_med': f'{ROOT}/configs/Models/llava_med/vanilla.json', 752 | 'llava': f'{ROOT}/configs/Models/llava/vanilla.json', 753 | 'radfm': f'{ROOT}/configs/Models/radfm/vanilla.json', 754 | 'blip2': f'{ROOT}/configs/Models/blip2/vanilla.json', 755 | 'instructblip': f'{ROOT}/configs/Models/instructblip/vanilla.json', 756 | 'med_flamingo': f'{ROOT}/configs/Models/med_flamingo/vanilla.json', 757 | 'medvint': f'{ROOT}/configs/Models/medvint/vanilla.json', 758 | 'llama': f'{ROOT}/configs/Models/llama/vanilla.json', 759 | 'molmo': f'{ROOT}/configs/Models/molmo/vanilla.json', 760 | } 761 | --------------------------------------------------------------------------------