├── .gitignore ├── README.md ├── arguments.py ├── collators ├── __init__.py ├── base.py ├── chat_template_monkey_patch.py └── llava_1_5.py ├── datasets.py ├── ds_configs ├── zero2.json └── zero3.json ├── imgs ├── ddvqa.jpg ├── fakeclue_loki_result.jpg ├── framework.jpg ├── logo.jpg ├── overview.png └── result.jpg ├── loaders ├── __init__.py ├── base.py └── llava_1_5.py ├── requirements.txt ├── scripts ├── eval.py ├── eval.sh ├── eval_vllm.py └── train.sh ├── supported_models.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.egg-info 4 | dist 5 | 6 | # Other 7 | .DS_Store 8 | wandb 9 | output 10 | 11 | checkpoints 12 | ckpts* 13 | 14 | .ipynb_checkpoints 15 | *.ipynb 16 | 17 | *.ttf 18 | 19 | local* 20 | not_useful_for_now 21 | *slurm* 22 | 23 | example_data/videos/ego4d/*.mp4 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Spot the Fake: Large Multimodal Model-Based Synthetic Image Detection with Artifact Explanation 3 |

4 |
5 |
6 | 7 | [Siwei Wen](https://scholar.google.com/citations?user=kJRiUYwAAAAJ&hl=zh-CN)1,3*, 8 | [Junyan Ye](https://yejy53.github.io/)2,1*, 9 | [Peilin Feng](https://peilin-ff.github.io/)1,3, 10 | [Hengrui Kang](https://scholar.google.com/citations?user=kVbzWCAAAAAJ&hl=zh-CN)4,1,
11 | [Zichen Wen](https://scholar.google.com/citations?user=N-aPFvEAAAAJ&hl=zh-CN)4,1, 12 | [Yize Chen](https://openreview.net/profile?id=~Yize_Chen2)5, 13 | [Jiang Wu](https://scholar.google.com/citations?user=LHiiL7AAAAAJ&hl=zh-CN)1, 14 | [Wenjun Wu](https://openreview.net/profile?id=~wenjun_wu3)3, 15 | [Conghui He](https://conghui.github.io/)1, 16 | [Weijia Li](https://liweijia.github.io/)2,1† 17 | 18 | 1Shanghai Artificial Intelligence Laboratory, 2Sun Yat-sen University
19 | 3Beihang University, 4Shanghai Jiao Tong University, 5The Chinese University of Hong Kong, Shenzhen 20 | 21 |
22 | 23 |
24 | 25 | [![arXiv](https://img.shields.io/badge/Arxiv-2503.14905-AD1C18.svg?logo=arXiv)](https://arxiv.org/pdf/2503.14905) 26 | [![](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fopendatalab%2FFakeVLM&count_bg=%23C25AE6&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitor&edge_flat=false)](https://hits.seeyoufarm.com) 27 | [![GitHub issues](https://img.shields.io/github/issues/opendatalab/FakeVLM?color=critical&label=Issues)](https://github.com/opendatalab/FakeVLM/issues) 28 | [![GitHub Stars](https://img.shields.io/github/stars/opendatalab/FakeVLM?style=social)](https://github.com/opendatalab/FakeVLM/stargazers) 29 | [![Dataset](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-yellow)](https://huggingface.co/datasets/lingcco/FakeClue) 30 | [![Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/lingcco/fakeVLM) 31 |
32 | 33 | 39 | 40 | ## 📰 News 41 | - **[2025.9.24]**: 🎉 FakeVLM was accepted to NeurIPS 2025! 42 | - **[2025.4.15]**: 🤗 We are excited to release the FakeClue dataset. Check out [here](https://huggingface.co/datasets/lingcco/FakeClue). 43 | - **[2025.3.20]**: 🔥 We have released **Spot the Fake: Large Multimodal Model-Based Synthetic Image Detection with Artifact Explanation**. Check out the [paper](https://arxiv.org/abs/2503.14905). We present FakeClue dataset and FakeVLM model. 44 | 45 | ## FakeVLM Overview 46 | 47 | With the rapid advancement of Artificial Intelligence Generated Content (AIGC) technologies, synthetic images have become increasingly prevalent in everyday life, posing new challenges for authenticity assessment and detection. Despite the effectiveness of existing methods in evaluating image authenticity and locating forgeries, these approaches often lack human interpretability and do not fully address the growing complexity of synthetic data. To tackle these challenges, we introduce FakeVLM, a specialized large multimodal model designed for both general synthetic image and DeepFake detection tasks. FakeVLM not only excels in distinguishing real from fake images but also provides clear, natural language explanations for image artifacts, enhancing interpretability. Additionally, we present FakeClue, a comprehensive dataset containing over 100,000 images across seven categories, annotated with fine-grained artifact clues in natural language. FakeVLM demonstrates performance comparable to expert models while eliminating the need for additional classifiers, making it a robust solution for synthetic data detection. Extensive evaluations across multiple datasets confirm the superiority of FakeVLM in both authenticity classification and artifact explanation tasks, setting a new benchmark for synthetic image detection. 48 | 49 |
50 | framework 51 |
52 | 53 | ## Contributions 54 | 55 | - We propose FakeVLM, a multimodal large model designed for both general synthetic and deepfake image detection tasks. It excels at distinguishing real from fake images while also providing excellent interpretability for artifact details in synthetic images. 56 | - We introduce the FakeClue dataset, which includes a rich variety of image categories and fine-grained artifact annotations in natural language. 57 | - Our method has been extensively evaluated on multiple datasets, achieving outstanding performance in both synthetic detection and abnormal artifact explanation tasks. 58 | 59 | ## 🛠️ Installation 60 | Please clone our repository and change into that folder: 61 | ```bash 62 | git clone git@github.com:opendatalab/FakeVLM.git 63 | cd FakeVLM 64 | ``` 65 | 66 | Our model is based on the [lmms-finetune](https://github.com/zjysteven/lmms-finetune) environment. Please follow the steps below to configure the environment. 67 | ```bash 68 | conda create -n fakevlm python=3.10 -y 69 | conda activate fakevlm 70 | 71 | python -m pip install -r requirements.txt 72 | 73 | python -m pip install --no-cache-dir --no-build-isolation flash-attn 74 | ``` 75 | 76 | ## 📦 Dataset 77 | The directory containing the images should have the following structure: 78 | 79 | ``` 80 | playground 81 | └──data 82 | └──train 83 | |--doc 84 | |--fake 85 | |--real 86 | . 87 | . 88 | |--satellite 89 | └──test 90 | . 91 | . 92 | . 93 | ``` 94 | 95 | 96 | ## 📌 Usage 97 | ### 1. Data Preparation 98 | The training data can be downloaded from [here](https://huggingface.co/datasets/lingcco/FakeClue). 99 | 100 | Please download the dataset and unzip the images. An illustrative example of the expected JSON record format is shown after the Results section below. 101 | ### 2. Train 102 | 103 | Replace the data paths in `scripts/train.sh` with your own, and replace the original [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model path in `supported_models.py` with your own. 104 | 105 | ``` 106 | bash scripts/train.sh 107 | ``` 108 | 109 | ### 3. Evaluation 110 | 111 | We provide two scripts for evaluating the FakeVLM model. The trained FakeVLM model is available [here](https://huggingface.co/lingcco/fakeVLM). 112 | 113 | #### 1. Standard evaluation 114 | 115 | ``` 116 | bash scripts/eval.sh 117 | ``` 118 | 119 | #### 2. Evaluation with vLLM 120 | 121 | Given the size of the model and the scale of the data, we recommend using vLLM for evaluation. Please make sure vLLM is installed. 122 | 123 | 124 | ``` 125 | # change scripts/eval.py to scripts/eval_vllm.py in scripts/eval.sh 126 | bash scripts/eval.sh 127 | ``` 128 | ## 📊 Results 129 | Performance of 7 leading LMMs and FakeVLM on DD-VQA, FakeClue, and LOKI. 130 | 131 | - **FakeClue** 132 | Our dataset, introduced in this work. 133 | - **LOKI** 134 | A new benchmark for evaluating multimodal models in synthetic detection tasks. It includes **human-annotated fine-grained image artifacts**, enabling deeper analysis of artifact explanations. We used its image modality, covering categories like Animals, Humans, Scenery, and Documents. 135 | 136 | framework 137 | 138 | - **DD-VQA** 139 | A dataset for explaining facial artifacts, using **manual annotations** in a VQA format. Artifacts include blurred hairlines, mismatched eyebrows, rigid pupils, and unnatural shadows. It builds on FF++ data and emphasizes common-sense reasoning. 140 | 141 |
142 | framework 143 |
144 | 145 | To provide a comprehensive comparison of the model performance across the three datasets—FakeClue, LOKI, and DD-VQA—we present the following radar chart. This chart visually highlights the strengths and weaknesses of the 7 leading LMMs and FakeVLM, offering a clear depiction of their results in synthetic detection and artifact explanation tasks. 146 | 147 |
148 | result 149 |
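The training and evaluation JSON files referenced in `scripts/train.sh` and `scripts/eval.sh` follow the conversation format consumed by `datasets.py` and the evaluation scripts. The sketch below shows a single illustrative record: the image path, prompt wording, and answer text are hypothetical, the `<image>` placeholder follows the LLaVA convention used by the collator, and the `label`/`cate` fields are inferred from the evaluation scripts rather than documented elsewhere in this repository.

```python
# Illustrative record only: field values are placeholders, not taken from FakeClue.
import json

record = {
    "image": "train/fake/doc_000001.jpg",  # relative to IMAGE_FOLDER (train) or --data_base_test (eval)
    "label": 0,                            # inferred from eval.py: 1 = real, 0 = fake
    "cate": "doc",                         # category key used by eval_vllm.py when aggregating accuracy
    "conversations": [
        {"from": "human", "value": "<image>\nIs this image real or fake? Describe any artifacts you see."},
        {"from": "gpt", "value": "Fake. The character strokes are inconsistent and the layout shows blending artifacts."},
    ],
}

# Both TRAIN_DATA_PATH and --test_json_file point to a JSON list of such records.
with open("example_data.json", "w") as f:
    json.dump([record], f, indent=2)
```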
150 | 151 | ## 😄 Acknowledgement 152 | 153 | This repository is built upon the work of [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main), and our codebase is built upon [lmms-finetune](https://github.com/zjysteven/lmms-finetune). We appreciate their contributions and insights that have provided a strong foundation for our research. 154 | 155 | ## 📨 Contact 156 | 157 | If you have any questions or suggestions, please feel free to contact us 158 | at [466439420gh@gmail.com](466439420gh@gmail.com). 159 | 160 | ## 📝 Citation 161 | If you find our work interesting and helpful, please consider giving our repo a star. Additionally, if you would like to cite our work, please use the following format: 162 | ```shell 163 | @article{wen2025spot, 164 | title={Spot the fake: Large multimodal model-based synthetic image detection with artifact explanation}, 165 | author={Wen, Siwei and Ye, Junyan and Feng, Peilin and Kang, Hengrui and Wen, Zichen and Chen, Yize and Wu, Jiang and Wu, Wenjun and He, Conghui and Li, Weijia}, 166 | journal={arXiv preprint arXiv:2503.14905}, 167 | year={2025} 168 | } 169 | ``` 170 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List 2 | from dataclasses import dataclass, field 3 | 4 | import transformers 5 | 6 | from supported_models import MODEL_HF_PATH, MODEL_FAMILIES 7 | 8 | 9 | @dataclass 10 | class ModelArguments: 11 | model_id: str = field(default="llava-1.5-7b") 12 | model_local_path: Optional[str] = field(default=None) 13 | 14 | def __post_init__(self): 15 | assert self.model_id in MODEL_HF_PATH, f"Unknown model_id: {self.model_id}" 16 | self.model_hf_path: str = MODEL_HF_PATH[self.model_id] 17 | assert self.model_id in MODEL_FAMILIES, f"Unknown model_id: {self.model_id}" 18 | self.model_family_id: str = MODEL_FAMILIES[self.model_id] 19 | 20 | if not self.model_local_path: 21 | self.model_local_path = self.model_hf_path 22 | 23 | 24 | @dataclass 25 | class DataArguments: 26 | data_path: str = field( 27 | default=None, metadata={"help": "Path to the training data json file."} 28 | ) 29 | eval_data_path: Optional[str] = field( 30 | default=None, metadata={"help": "Path to the evaluation data json file."} 31 | ) 32 | image_folder: Optional[str] = field(default=None) 33 | video_folder: Optional[str] = field(default=None) 34 | num_frames: Optional[int] = field(default=8) 35 | user_key: Optional[str] = field(default="human") 36 | assistant_key: Optional[str] = field(default="gpt") 37 | 38 | 39 | @dataclass 40 | class TrainingArguments(transformers.TrainingArguments): 41 | model_max_length: int = field( 42 | default=1024, 43 | metadata={ 44 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
45 | }, 46 | ) 47 | use_flash_attn: bool = field(default=False) 48 | train_vision_encoder: bool = field(default=False) 49 | train_vision_projector: bool = field(default=False) 50 | mask_question_tokens: bool = field(default=True) 51 | 52 | def __post_init__(self): 53 | super().__post_init__() 54 | self.remove_unused_columns = False 55 | 56 | 57 | @dataclass 58 | class LoraArguments: 59 | use_lora: bool = field(default=True) 60 | use_vision_lora: bool = field(default=True) 61 | q_lora: bool = field(default=False) 62 | lora_r: int = field(default=8) 63 | lora_alpha: int = field(default=16) 64 | lora_dropout: float = field(default=0.05) 65 | lora_weight_path: str = "" 66 | lora_bias: str = "none" -------------------------------------------------------------------------------- /collators/__init__.py: -------------------------------------------------------------------------------- 1 | COLLATORS = {} 2 | 3 | def register_collator(name): 4 | def register_collator_cls(cls): 5 | if name in COLLATORS: 6 | return COLLATORS[name] 7 | COLLATORS[name] = cls 8 | return cls 9 | return register_collator_cls 10 | 11 | 12 | from .llava_1_5 import LLaVA15DataCollator -------------------------------------------------------------------------------- /collators/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Sequence, Optional 3 | 4 | import torch 5 | from transformers import PreTrainedTokenizer, AutoProcessor, AutoConfig 6 | 7 | 8 | class BaseDataCollator(ABC, object): 9 | """Collate examples for supervised fine-tuning.""" 10 | def __init__( 11 | self, 12 | config: Optional[AutoConfig] = None, 13 | tokenizer: Optional[PreTrainedTokenizer] = None, 14 | processor: Optional[AutoProcessor] = None, 15 | mask_question_tokens: bool = True 16 | ) -> None: 17 | self.config = config 18 | self.tokenizer = tokenizer 19 | self.processor = processor 20 | self.mask_question_tokens = mask_question_tokens 21 | 22 | @property 23 | def IGNORE_TOKEN_ID(self) -> int: 24 | return -100 25 | 26 | @property 27 | def PAD_TOKEN_ID(self) -> int: 28 | return self.tokenizer.pad_token_id 29 | 30 | @abstractmethod 31 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: ... 
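`collators/__init__.py` and `BaseDataCollator` above form a small plugin registry: each model family registers one collator class under its family id, and the trainer can then look that class up by id. The sketch below is hypothetical (the "my-model" id and `DummyCollator` are not part of this repository) and only illustrates how the pieces fit together; `collators/llava_1_5.py` further down is the real implementation for LLaVA-1.5.

```python
# Hypothetical sketch of the collator registry; "my-model" and DummyCollator do not exist in the repo.
from typing import Dict, Sequence

import torch

from collators import COLLATORS, register_collator
from collators.base import BaseDataCollator


@register_collator("my-model")
class DummyCollator(BaseDataCollator):
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # A real collator (see collators/llava_1_5.py) builds input_ids, attention_mask,
        # vision inputs, and labels with question tokens set to IGNORE_TOKEN_ID (-100).
        raise NotImplementedError


# Presumably how the trainer resolves a collator from the model family id:
collator_cls = COLLATORS["my-model"]
collator = collator_cls(config=None, tokenizer=None, processor=None, mask_question_tokens=True)
```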
-------------------------------------------------------------------------------- /collators/chat_template_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # a monkey patch for https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1709 2 | # this sets add_special_tokens=True to the tokenizer call 3 | 4 | import re 5 | from inspect import isfunction 6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union 7 | 8 | from transformers.tokenization_utils_base import BatchEncoding 9 | from transformers.utils import TensorType, get_json_schema, logging 10 | from transformers.utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices 11 | 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | def apply_chat_template( 17 | self, 18 | conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], 19 | tools: Optional[List[Dict]] = None, 20 | documents: Optional[List[Dict[str, str]]] = None, 21 | chat_template: Optional[str] = None, 22 | add_generation_prompt: bool = False, 23 | continue_final_message: bool = False, 24 | tokenize: bool = True, 25 | padding: bool = False, 26 | truncation: bool = False, 27 | max_length: Optional[int] = None, 28 | return_tensors: Optional[Union[str, TensorType]] = None, 29 | return_dict: bool = False, 30 | return_assistant_tokens_mask: bool = False, 31 | tokenizer_kwargs: Optional[Dict[str, Any]] = None, 32 | **kwargs, 33 | ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]: 34 | """ 35 | Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token 36 | ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to 37 | determine the format and control tokens to use when converting. 38 | 39 | Args: 40 | conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts 41 | with "role" and "content" keys, representing the chat history so far. 42 | tools (`List[Dict]`, *optional*): 43 | A list of tools (callable functions) that will be accessible to the model. If the template does not 44 | support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, 45 | giving the name, description and argument types for the tool. See our 46 | [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) 47 | for more information. 48 | documents (`List[Dict[str, str]]`, *optional*): 49 | A list of dicts representing documents that will be accessible to the model if it is performing RAG 50 | (retrieval-augmented generation). If the template does not support RAG, this argument will have no 51 | effect. We recommend that each document should be a dict containing "title" and "text" keys. Please 52 | see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) 53 | for examples of passing documents with chat templates. 54 | chat_template (`str`, *optional*): 55 | A Jinja template to use for this conversion. It is usually not necessary to pass anything to this 56 | argument, as the model's template will be used by default. 57 | add_generation_prompt (bool, *optional*): 58 | If this is set, a prompt with the token(s) that indicate 59 | the start of an assistant message will be appended to the formatted output. 
This is useful when you want to generate a response from the model. 60 | Note that this argument will be passed to the chat template, and so it must be supported in the 61 | template for this argument to have any effect. 62 | continue_final_message (bool, *optional*): 63 | If this is set, the chat will be formatted so that the final 64 | message in the chat is open-ended, without any EOS tokens. The model will continue this message 65 | rather than starting a new one. This allows you to "prefill" part of 66 | the model's response for it. Cannot be used at the same time as `add_generation_prompt`. 67 | tokenize (`bool`, defaults to `True`): 68 | Whether to tokenize the output. If `False`, the output will be a string. 69 | padding (`bool`, defaults to `False`): 70 | Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`. 71 | truncation (`bool`, defaults to `False`): 72 | Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. 73 | max_length (`int`, *optional*): 74 | Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If 75 | not specified, the tokenizer's `max_length` attribute will be used as a default. 76 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 77 | If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable 78 | values are: 79 | - `'tf'`: Return TensorFlow `tf.Tensor` objects. 80 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 81 | - `'np'`: Return NumPy `np.ndarray` objects. 82 | - `'jax'`: Return JAX `jnp.ndarray` objects. 83 | return_dict (`bool`, defaults to `False`): 84 | Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. 85 | tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. 86 | return_assistant_tokens_mask (`bool`, defaults to `False`): 87 | Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, 88 | the mask will contain 1. For user and system tokens, the mask will contain 0. 89 | This functionality is only available for chat templates that support it via the `{% generation %}` keyword. 90 | **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. 91 | 92 | Returns: 93 | `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This 94 | output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is 95 | set, will return a dict of tokenizer outputs instead. 96 | """ 97 | 98 | if return_dict and not tokenize: 99 | raise ValueError( 100 | "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict " 101 | "of tokenizer outputs to return." 102 | ) 103 | 104 | if return_assistant_tokens_mask and not return_dict: 105 | raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`") 106 | 107 | if tokenizer_kwargs is None: 108 | tokenizer_kwargs = {} 109 | 110 | chat_template = self.get_chat_template(chat_template, tools) 111 | 112 | if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template): 113 | logger.warning_once( 114 | "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword." 
115 | ) 116 | 117 | # Compilation function uses a cache to avoid recompiling the same template 118 | compiled_template = _compile_jinja_template(chat_template) 119 | 120 | if isinstance(conversation, (list, tuple)) and ( 121 | isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") 122 | ): 123 | conversations = conversation 124 | is_batched = True 125 | else: 126 | conversations = [conversation] 127 | is_batched = False 128 | 129 | if continue_final_message: 130 | if add_generation_prompt: 131 | raise ValueError( 132 | "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." 133 | ) 134 | if return_assistant_tokens_mask: 135 | raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") 136 | 137 | # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas 138 | if tools is not None: 139 | tool_schemas = [] 140 | for tool in tools: 141 | if isinstance(tool, dict): 142 | tool_schemas.append(tool) 143 | elif isfunction(tool): 144 | tool_schemas.append(get_json_schema(tool)) 145 | else: 146 | raise ValueError( 147 | "Tools should either be a JSON schema, or a callable function with type hints " 148 | "and a docstring suitable for auto-conversion to a schema." 149 | ) 150 | else: 151 | tool_schemas = None 152 | 153 | if documents is not None: 154 | for document in documents: 155 | if not isinstance(document, dict): 156 | raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!") 157 | 158 | rendered = [] 159 | all_generation_indices = [] 160 | template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present 161 | for chat in conversations: 162 | if hasattr(chat, "messages"): 163 | # Indicates it's a Conversation object 164 | chat = chat.messages 165 | if return_assistant_tokens_mask: 166 | rendered_chat, generation_indices = _render_with_assistant_indices( 167 | compiled_template=compiled_template, 168 | messages=chat, 169 | tools=tool_schemas, 170 | documents=documents, 171 | add_generation_prompt=add_generation_prompt, 172 | **template_kwargs, 173 | ) 174 | all_generation_indices.append(generation_indices) 175 | else: 176 | rendered_chat = compiled_template.render( 177 | messages=chat, 178 | tools=tool_schemas, 179 | documents=documents, 180 | add_generation_prompt=add_generation_prompt, 181 | **template_kwargs, 182 | ) 183 | if continue_final_message: 184 | final_message = chat[-1]["content"].strip() 185 | rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip() 186 | rendered.append(rendered_chat) 187 | 188 | if not is_batched: 189 | rendered = rendered[0] 190 | 191 | if tokenize: 192 | out = self( 193 | rendered, 194 | padding=padding, 195 | truncation=truncation, 196 | max_length=max_length, 197 | add_special_tokens=True, 198 | return_tensors=return_tensors, 199 | **tokenizer_kwargs, 200 | ) 201 | if return_dict: 202 | if return_assistant_tokens_mask: 203 | assistant_masks = [] 204 | if is_batched or return_tensors: 205 | input_ids = out["input_ids"] 206 | else: 207 | input_ids = [out["input_ids"]] 208 | for i in range(len(input_ids)): 209 | current_mask = [0] * len(input_ids[i]) 210 | for assistant_start_char, assistant_end_char in all_generation_indices[i]: 
211 | start_token = out.char_to_token(i, assistant_start_char) 212 | end_token = out.char_to_token(i, assistant_end_char - 1) 213 | if start_token is None: 214 | # start_token is out of bounds maybe due to truncation. 215 | break 216 | for token_id in range(start_token, end_token + 1 if end_token else len(input_ids)): 217 | current_mask[token_id] = 1 218 | assistant_masks.append(current_mask) 219 | out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0] 220 | return out 221 | else: 222 | return out["input_ids"] 223 | else: 224 | return rendered -------------------------------------------------------------------------------- /collators/llava_1_5.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List, Sequence, Union 3 | 4 | import numpy as np 5 | import PIL 6 | import torch 7 | from transformers.image_utils import get_image_size, to_numpy_array 8 | from transformers.models.llava.processing_llava import LlavaProcessorKwargs 9 | from transformers.utils import logging 10 | 11 | from . import register_collator 12 | from .base import BaseDataCollator 13 | from .chat_template_monkey_patch import apply_chat_template 14 | 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | @register_collator("llava-1.5") 20 | class LLaVA15DataCollator(BaseDataCollator): 21 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 22 | # monkey patch to include bos tokens 23 | self.tokenizer.apply_chat_template = apply_chat_template.__get__(self.tokenizer) 24 | 25 | output_kwargs = self.processor._merge_kwargs( 26 | LlavaProcessorKwargs, 27 | tokenizer_init_kwargs=self.tokenizer.init_kwargs, 28 | ) 29 | 30 | vision_inputs = dict() 31 | images: List[List[PIL.Image.Image]] = [x for instance in instances for x in instance["images"]] 32 | if len(images) > 0: 33 | vision_inputs.update(**self.processor.image_processor(images, return_tensors="pt", **output_kwargs["images_kwargs"])) 34 | 35 | # some parsing 36 | images: List[List[PIL.Image.Image]] = [instance["images"] for instance in instances] 37 | system_prompts: List[Union[str, None]] = [instance["system_prompt"] for instance in instances] 38 | conversations: List[List] = [instance["conversations"] for instance in instances] 39 | 40 | # constants 41 | max_len = self.tokenizer.model_max_length 42 | image_token_id = self.config.image_token_index 43 | patch_size = self.processor.patch_size 44 | vision_feature_select_strategy = self.processor.vision_feature_select_strategy 45 | 46 | input_ids = [] 47 | labels = [] 48 | 49 | for system_prompt, cur_images, cur_convs in zip(system_prompts, images, conversations): 50 | cur_num_images = 0 51 | cur_input_ids = [] 52 | cur_labels = [] 53 | 54 | cur_text = [] 55 | if system_prompt is not None: 56 | cur_text.append({ 57 | "role": "system", 58 | "content": [{"type": "text", "text": system_prompt}] 59 | }) 60 | 61 | for i, text in enumerate(cur_convs): 62 | if i % 2 == 0: 63 | num_images = len([m.start() for m in re.finditer("", text)]) 64 | cur_num_images += num_images 65 | 66 | # .strip(): whitespaces and newlines are handled by chat_template 67 | text = text.replace("", "").strip() 68 | 69 | cur_text.append({ 70 | "role": "user", 71 | "content": [{"type": "text", "text": text}] + \ 72 | [{"type": "image"}] * num_images 73 | }) 74 | else: 75 | cur_text.append({ 76 | "role": "assistant", 77 | "content": [ 78 | {"type": "text", "text": text}, 79 | ] 80 | }) 81 | 82 | assert len(cur_images) == 
cur_num_images, "Number of image tokens does not match the number of images" 83 | 84 | temp = self.processor.apply_chat_template( 85 | cur_text, 86 | add_generation_prompt=False, 87 | tokenize=True, 88 | return_assistant_tokens_mask=True, 89 | return_dict=True, 90 | return_tensors="pt", 91 | truncation=False # the assistant tokens mask seems wrong when truncation is enabled 92 | ) 93 | cur_input_ids = temp["input_ids"] 94 | cur_assistant_masks = torch.tensor(temp["assistant_masks"], dtype=torch.bool).unsqueeze(0) 95 | 96 | # expand image tokens 97 | temp_vision_inputs = self.processor.image_processor(cur_images, return_tensors="pt") 98 | if temp_vision_inputs.get("pixel_values") is not None: 99 | if patch_size is not None and vision_feature_select_strategy is not None: 100 | # Replace the image token with the expanded image token sequence 101 | pixel_values = temp_vision_inputs["pixel_values"] 102 | height, width = get_image_size(to_numpy_array(pixel_values[0])) 103 | num_image_tokens = (height // patch_size) * (width // patch_size) + 1 104 | if vision_feature_select_strategy == "default": 105 | num_image_tokens -= 1 106 | 107 | repeat = torch.where(cur_input_ids == image_token_id, num_image_tokens, 1).squeeze() 108 | cur_input_ids = cur_input_ids.repeat_interleave(repeat, dim=1) 109 | cur_assistant_masks = cur_assistant_masks.repeat_interleave(repeat, dim=1) 110 | else: 111 | logger.warning_once( 112 | "Expanding inputs for image tokens in LLaVa should be done in processing. " 113 | "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " 114 | "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " 115 | "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
116 | ) 117 | 118 | # a dirty hack to include eos token as part of the labels 119 | cur_assistant_masks[0, -1] = True 120 | 121 | # manual truncation 122 | if cur_input_ids.shape[1] > max_len: 123 | cur_input_ids = cur_input_ids[:, :max_len] 124 | cur_assistant_masks = cur_assistant_masks[:, :max_len] 125 | cur_labels = cur_input_ids.clone() 126 | 127 | # mask question tokens 128 | if self.mask_question_tokens: 129 | assert cur_labels.shape == cur_assistant_masks.shape, "Label and mask shapes do not match" 130 | cur_labels = torch.where(cur_assistant_masks, cur_labels, self.IGNORE_TOKEN_ID) 131 | 132 | assert cur_input_ids.shape == cur_labels.shape, "Input and label shapes do not match" 133 | 134 | # padding 135 | if cur_input_ids.shape[1] < max_len: 136 | cur_input_ids = torch.cat([ 137 | cur_input_ids, 138 | torch.full( 139 | (cur_input_ids.shape[0], max_len - cur_input_ids.shape[1]), 140 | self.PAD_TOKEN_ID, 141 | dtype=cur_input_ids.dtype, 142 | device=cur_input_ids.device 143 | ) 144 | ], dim=1) 145 | cur_labels = torch.cat([ 146 | cur_labels, 147 | torch.full( 148 | (cur_labels.shape[0], max_len - cur_labels.shape[1]), 149 | self.IGNORE_TOKEN_ID, 150 | dtype=cur_labels.dtype, 151 | device=cur_labels.device 152 | ) 153 | ], dim=1) 154 | 155 | input_ids.append(cur_input_ids) 156 | labels.append(cur_labels) 157 | 158 | input_ids = torch.cat(input_ids) 159 | labels = torch.cat(labels) 160 | 161 | return dict( 162 | **vision_inputs, 163 | input_ids=input_ids, 164 | labels=labels, 165 | attention_mask=input_ids.ne(self.PAD_TOKEN_ID), 166 | ) -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import av 2 | import os 3 | import json 4 | from PIL import Image 5 | from typing import Dict, List, Optional 6 | 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | 11 | TO_LOAD_IMAGE: Dict[str, bool] = { 12 | "llava-1.5": True, 13 | } 14 | 15 | 16 | def read_video_pyav(container, indices): 17 | ''' 18 | Decode the video with PyAV decoder. 19 | Args: 20 | container (`av.container.input.InputContainer`): PyAV container. 21 | indices (`List[int]`): List of frame indices to decode. 22 | Returns: 23 | result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). 24 | ''' 25 | frames = [] 26 | container.seek(0) 27 | start_index = indices[0] 28 | end_index = indices[-1] 29 | for i, frame in enumerate(container.decode(video=0)): 30 | if i > end_index: 31 | break 32 | if i >= start_index and i in indices: 33 | frames.append(frame) 34 | return np.stack([x.to_ndarray(format="rgb24") for x in frames]) 35 | 36 | 37 | class LazySupervisedDataset(Dataset): 38 | """Dataset for supervised fine-tuning 39 | which is generalized enough to handle both images and videos. 
40 | """ 41 | 42 | def __init__( 43 | self, 44 | data_path: str, 45 | model_family_id: str, 46 | image_folder: Optional[str] = None, 47 | video_folder: Optional[str] = None, 48 | num_frames: int = 8, 49 | user_key: str = "human", 50 | assistant_key: str = "gpt", 51 | ) -> None: 52 | super(LazySupervisedDataset, self).__init__() 53 | self.list_data_dict = json.load(open(data_path, "r")) 54 | self.image_folder = image_folder 55 | self.video_folder = video_folder 56 | self.num_frames = num_frames 57 | self.load_image = TO_LOAD_IMAGE[model_family_id] 58 | self.user_key = user_key 59 | self.assistant_key = assistant_key 60 | 61 | self.is_text_only = [ 62 | "image" not in source and "video" not in source 63 | for source in self.list_data_dict 64 | ] 65 | 66 | def __len__(self) -> int: 67 | return len(self.list_data_dict) 68 | 69 | def __getitem__(self, i) -> Dict[str, List]: 70 | source = self.list_data_dict[i] 71 | 72 | images = [] 73 | if "image" in source: 74 | # here we do not do any image preprocessing but rather 75 | # let the processor handle everything 76 | # in some cases this may cause slight differences 77 | # but should totally be fine (e.g., official llava-1.5 does padding, 78 | # but llava-1.5-hf (huggingface's implementation) does not) 79 | if isinstance(source["image"], list): 80 | image_sources = source["image"] 81 | elif isinstance(source["image"], str): 82 | image_sources = [source["image"]] 83 | else: 84 | raise ValueError(f"Invalid image source type: {type(source['image'])}") 85 | 86 | for image_path in image_sources: 87 | if self.image_folder is not None: 88 | image_path = os.path.join(self.image_folder, image_path) 89 | images.append( 90 | Image.open(image_path).convert("RGB") 91 | if self.load_image else image_path 92 | ) 93 | 94 | videos = [] 95 | if "video" in source: 96 | if isinstance(source["video"], list): 97 | video_sources = source["video"] 98 | elif isinstance(source["video"], str): 99 | video_sources = [source["video"]] 100 | else: 101 | raise ValueError(f"Invalid video source type: {type(source['video'])}") 102 | 103 | num_frames = [self.num_frames] * len(video_sources) 104 | 105 | for video_path, cur_num_frames in zip(video_sources, num_frames): 106 | if self.video_folder is not None: 107 | video_path = os.path.join(self.video_folder, video_path) 108 | 109 | container = av.open(video_path) 110 | total_frames = container.streams.video[0].frames 111 | indices = np.arange(0, total_frames, total_frames / cur_num_frames).astype(int) 112 | clip = read_video_pyav(container, indices) 113 | 114 | videos.append(clip) 115 | 116 | system_prompt = None 117 | if "system_prompt" in source: 118 | system_prompt = source["system_prompt"] 119 | 120 | convs = [] 121 | assert len(source["conversations"]) > 0, "No conversations found" 122 | for i, conv in enumerate(source["conversations"]): 123 | assert conv["from"] == (self.user_key if i % 2 == 0 else self.assistant_key), "Invalid conversation" 124 | convs.append(conv["value"]) 125 | assert len(convs) % 2 == 0, "Odd number of conversations" 126 | 127 | return dict( 128 | images=images, 129 | videos=videos, 130 | conversations=convs, 131 | system_prompt=system_prompt 132 | ) -------------------------------------------------------------------------------- /ds_configs/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 
10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": false, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /ds_configs/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /imgs/ddvqa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/ddvqa.jpg -------------------------------------------------------------------------------- /imgs/fakeclue_loki_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/fakeclue_loki_result.jpg -------------------------------------------------------------------------------- /imgs/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/framework.jpg -------------------------------------------------------------------------------- /imgs/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/logo.jpg -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/overview.png -------------------------------------------------------------------------------- /imgs/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/result.jpg -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- 1 | LOADERS = {} 2 | 3 | def register_loader(name): 4 | def register_loader_cls(cls): 5 | if name in LOADERS: 6 | return LOADERS[name] 7 | LOADERS[name] = cls 8 | return cls 9 | return register_loader_cls 10 | 11 | 12 | from .llava_1_5 import LLaVA15ModelLoader 
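The loader registry above mirrors the collator registry: `register_loader` maps a model family id to a loader class, and `loaders/llava_1_5.py` registers the LLaVA-1.5 loader. A hedged sketch of how a loader is expected to be invoked follows (the call site shown here is an assumption; the actual wiring lives in `train.py`).

```python
# Hypothetical usage sketch; the concrete loader class is defined in loaders/llava_1_5.py.
import torch

from loaders import LOADERS

loader = LOADERS["llava-1.5"](
    model_hf_path="llava-hf/llava-1.5-7b-hf",     # source of the processor/tokenizer
    model_local_path="llava-hf/llava-1.5-7b-hf",  # weights; replace with a local checkpoint path
    compute_dtype=torch.bfloat16,
    use_flash_attn=False,
)
model, tokenizer, processor, config = loader.load(load_model=True)
```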
-------------------------------------------------------------------------------- /loaders/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Tuple, Union, Optional 3 | 4 | import torch 5 | from transformers import PreTrainedModel, PreTrainedTokenizer, AutoProcessor, BitsAndBytesConfig 6 | 7 | 8 | class BaseModelLoader(ABC): 9 | def __init__( 10 | self, 11 | model_hf_path: str, 12 | model_local_path: str, 13 | compute_dtype: torch.dtype, 14 | bnb_config: Optional[BitsAndBytesConfig] = None, 15 | use_flash_attn: bool = False, 16 | device_map: Optional[Union[Dict, str]] = None, 17 | ) -> None: 18 | self.model_hf_path = model_hf_path 19 | self.model_local_path = model_local_path 20 | self.loading_kwargs = dict( 21 | torch_dtype=compute_dtype, 22 | quantization_config=bnb_config, 23 | device_map=device_map, 24 | ) 25 | if use_flash_attn: 26 | self.loading_kwargs["attn_implementation"] = "flash_attention_2" 27 | 28 | @abstractmethod 29 | def load(self, load_model: bool = True) -> Tuple[ 30 | PreTrainedModel, Union[None, PreTrainedTokenizer], Union[None, AutoProcessor] 31 | ]: ... -------------------------------------------------------------------------------- /loaders/llava_1_5.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from transformers import AutoProcessor, LlavaForConditionalGeneration, PreTrainedTokenizer, AutoConfig 4 | 5 | from . import register_loader 6 | from .base import BaseModelLoader 7 | 8 | 9 | @register_loader("llava-1.5") 10 | class LLaVA15ModelLoader(BaseModelLoader): 11 | def load(self, load_model: bool = True) -> Tuple[LlavaForConditionalGeneration, PreTrainedTokenizer, AutoProcessor, AutoConfig]: 12 | if load_model: 13 | model = LlavaForConditionalGeneration.from_pretrained( 14 | self.model_local_path, 15 | **self.loading_kwargs, 16 | ) 17 | model.config.hidden_size = model.language_model.config.hidden_size # useful for deepspeed 18 | else: 19 | model = None 20 | 21 | processor = AutoProcessor.from_pretrained(self.model_hf_path, add_eos_token=True) 22 | tokenizer = processor.tokenizer 23 | config = AutoConfig.from_pretrained(self.model_local_path) 24 | return model, tokenizer, processor, config -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.10.1 2 | aiofiles==24.1.0 3 | annotated-types==0.7.0 4 | anyio==4.10.0 5 | av @ file:///home/conda/feedstock_root/build_artifacts/av_1756861937826/work 6 | bitsandbytes==0.47.0 7 | Brotli==1.1.0 8 | certifi==2025.8.3 9 | charset-normalizer==3.4.3 10 | click==8.3.0 11 | contourpy==1.3.2 12 | cycler==0.12.1 13 | deepspeed==0.14.4 14 | einops==0.8.1 15 | exceptiongroup==1.3.0 16 | fastapi==0.117.1 17 | ffmpy==0.6.1 18 | filelock==3.19.1 19 | fonttools==4.60.0 20 | fsspec==2025.9.0 21 | gitdb==4.0.12 22 | GitPython==3.1.45 23 | gradio==5.46.11 24 | gradio_client==1.13.1 25 | groovy==0.1.2 26 | h11==0.16.0 27 | hf-xet==1.1.10 28 | hjson==3.1.0 29 | httpcore==1.0.9 30 | httpx==0.28.1 31 | huggingface-hub==0.35.0 32 | idna==3.10 33 | Jinja2==3.1.6 34 | kiwisolver==1.4.9 35 | markdown-it-py==4.0.0 36 | MarkupSafe==3.0.2 37 | matplotlib==3.10.6 38 | mdurl==0.1.2 39 | mpmath==1.3.0 40 | networkx==3.4.2 41 | ninja==1.13.0 42 | numpy @ 
file:///home/conda/feedstock_root/build_artifacts/numpy_1747544640217/work/dist/numpy-2.2.6-cp310-cp310-linux_x86_64.whl#sha256=d6d964caeef85d00073d27cd62b46883d275b3d8162f723f0fcabbd0b3cc3f9d 43 | nvidia-cublas-cu12==12.8.4.1 44 | nvidia-cuda-cupti-cu12==12.8.90 45 | nvidia-cuda-nvrtc-cu12==12.8.93 46 | nvidia-cuda-runtime-cu12==12.8.90 47 | nvidia-cudnn-cu12==9.10.2.21 48 | nvidia-cufft-cu12==11.3.3.83 49 | nvidia-cufile-cu12==1.13.1.3 50 | nvidia-curand-cu12==10.3.9.90 51 | nvidia-cusolver-cu12==11.7.3.90 52 | nvidia-cusparse-cu12==12.5.8.93 53 | nvidia-cusparselt-cu12==0.7.1 54 | nvidia-ml-py==13.580.82 55 | nvidia-nccl-cu12==2.27.3 56 | nvidia-nvjitlink-cu12==12.8.93 57 | nvidia-nvtx-cu12==12.8.90 58 | orjson==3.11.3 59 | packaging==25.0 60 | pandas==2.3.2 61 | peft==0.17.1 62 | pillow @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pillow_1758208668/work 63 | platformdirs==4.4.0 64 | protobuf==6.32.1 65 | psutil==7.1.0 66 | py-cpuinfo==9.0.0 67 | pydantic==2.11.9 68 | pydantic_core==2.33.2 69 | pydub==0.25.1 70 | Pygments==2.19.2 71 | pyparsing==3.2.5 72 | python-dateutil==2.9.0.post0 73 | python-multipart==0.0.20 74 | pytz==2025.2 75 | PyYAML==6.0.2 76 | regex==2025.9.18 77 | requests==2.32.5 78 | rich==14.1.0 79 | ruff==0.13.1 80 | safehttpx==0.1.6 81 | safetensors==0.6.2 82 | semantic-version==2.10.0 83 | sentencepiece==0.2.0 84 | sentry-sdk==2.38.0 85 | shellingham==1.5.4 86 | six==1.17.0 87 | smmap==5.0.2 88 | sniffio==1.3.1 89 | starlette==0.48.0 90 | sympy==1.14.0 91 | tiktoken==0.11.0 92 | tokenizers==0.20.3 93 | tomlkit==0.13.3 94 | torch==2.8.0 95 | torchvision==0.23.0 96 | tqdm==4.67.1 97 | transformers==4.45.2 98 | transformers-stream-generator==0.0.5 99 | triton==3.4.0 100 | typer==0.19.1 101 | typing-inspection==0.4.1 102 | typing_extensions==4.15.0 103 | tzdata==2025.2 104 | urllib3==2.5.0 105 | uvicorn==0.36.0 106 | wandb==0.22.0 107 | websockets==15.0.1 108 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from dataclasses import dataclass, field 5 | import transformers 6 | from torch.utils.data import Dataset, DataLoader 7 | from transformers import CLIPImageProcessor 8 | import pdb 9 | import json 10 | from transformers import AutoProcessor, LlavaForConditionalGeneration 11 | from tqdm import tqdm 12 | import random 13 | import numpy as np 14 | import torch 15 | import torchvision.transforms as T 16 | from PIL import Image 17 | from torchvision.transforms.functional import InterpolationMode 18 | import torch.nn as nn 19 | 20 | 21 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 22 | IMAGENET_STD = (0.229, 0.224, 0.225) 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description="Legion Model Training") 26 | 27 | # Model-specific settings 28 | parser.add_argument("--model_path", default="", type=str) 29 | parser.add_argument("--val_batch_size", default=1, type=int) 30 | parser.add_argument("--workers", default=1, type=int) 31 | parser.add_argument("--data_base_test", default="", type=str) 32 | parser.add_argument("--test_json_file", default="", type=str) 33 | parser.add_argument("--output_path", default="", type=str) 34 | return parser.parse_args() 35 | 36 | class legion_cls_dataset(Dataset): 37 | def __init__(self, args, train=True): 38 | super().__init__() 39 | self.args = args 40 | self.train = train 41 | if train == True: 42 | with 
open(args.train_json_file, 'r') as f: 43 | self.data = json.load(f) 44 | elif train == False: 45 | with open(args.test_json_file, 'r') as f: 46 | self.data = json.load(f) 47 | self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", revision='a272c74') 48 | 49 | def __len__(self): 50 | return len(self.data) 51 | 52 | def __getitem__(self, idx): 53 | if self.train == True: 54 | img_path = os.path.join(self.args.data_base_train, self.data[idx]['image']) 55 | else: 56 | img_path = os.path.join(self.args.data_base_test, self.data[idx]['image']) 57 | label = self.data[idx]['label'] 58 | 59 | image = Image.open(img_path) 60 | 61 | inputs = self.processor( 62 | text=self.data[idx]['conversations'][0]['value'], 63 | images=image, 64 | return_tensors="pt", 65 | padding="max_length", 66 | max_length=1024, 67 | truncation=True 68 | ) 69 | 70 | cate = 'deepfake' 71 | # torch.Size([n, 3, 448, 448]), int, int, str, str 72 | return inputs, [label], [img_path], [cate] 73 | 74 | 75 | 76 | 77 | def load_model(args): 78 | print("Loading model...") 79 | model = LlavaForConditionalGeneration.from_pretrained( 80 | args.model_path, 81 | torch_dtype=torch.bfloat16, 82 | low_cpu_mem_usage=True, 83 | use_flash_attention_2=True, 84 | revision='a272c74', 85 | ).eval().cuda() 86 | print("Successfully loaded model from:", args.model_path) 87 | return model 88 | 89 | def calculate_results_acc(results): 90 | acc_results = {} 91 | 92 | for cate in results: 93 | data = results[cate] 94 | 95 | right_real = data['right']['right_real'] 96 | right_fake = data['right']['right_fake'] 97 | wrong_real = data['wrong']['wrong_real'] 98 | wrong_fake = data['wrong']['wrong_fake'] 99 | 100 | total_real = right_real + wrong_real 101 | total_fake = right_fake + wrong_fake 102 | total = total_real + total_fake 103 | 104 | acc_total = (right_real + right_fake) / total if total != 0 else 0 105 | acc_real = right_real / total_real if total_real != 0 else 0 106 | acc_fake = right_fake / total_fake if total_fake != 0 else 0 107 | 108 | acc_results[cate] = { 109 | 'total_samples': total, 110 | 'total_accuracy': round(acc_total, 4), 111 | 'real_accuracy': round(acc_real, 4), 112 | 'fake_accuracy': round(acc_fake, 4), 113 | 'confusion_matrix': { 114 | 'right_real': right_real, 115 | 'wrong_real': wrong_real, 116 | 'right_fake': right_fake, 117 | 'wrong_fake': wrong_fake, 118 | } 119 | } 120 | 121 | global_stats = { 122 | 'total_right': sum(r['right']['right_real'] + r['right']['right_fake'] for r in results.values()), 123 | 'total_wrong': sum(r['wrong']['wrong_real'] + r['wrong']['wrong_fake'] for r in results.values()) 124 | } 125 | global_stats['global_accuracy'] = global_stats['total_right'] / (global_stats['total_right'] + global_stats['total_wrong']) 126 | 127 | return { 128 | 'category_acc': acc_results, 129 | 'global_stats': global_stats 130 | } 131 | 132 | 133 | def validate(args, model, cls_test_dataloader): 134 | processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", revision='a272c74') 135 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 136 | results = {} 137 | outputs = [] 138 | with torch.no_grad(): 139 | for inputs, labels, paths, cates in tqdm(cls_test_dataloader): 140 | 141 | inputs["input_ids"] = inputs["input_ids"].squeeze().to(device) 142 | inputs["attention_mask"] = inputs["attention_mask"].squeeze().to(device) 143 | inputs["pixel_values"] = inputs["pixel_values"].squeeze().to(device) 144 | output = model.generate(**inputs, max_new_tokens=256) 145 | pred_cls = [] 146 | 
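            # Note: the parsing below is a heuristic. It looks for "real" / "fake" in the first
            # sentence of the decoded answer, then in the second sentence, and falls back to a
            # random guess when neither word is found, so that accuracy can still be computed.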
147 | for i in range(output.shape[0]): 148 | response = processor.decode(output[i], skip_special_tokens=True).split('?')[-1] 149 | print(response) 150 | outputs.append({"image_path": paths[0][i], "output": response}) 151 | # pdb.set_trace() 152 | if 'real' in response.split('.')[0].lower(): 153 | pred_cls.append(1) 154 | elif 'fake' in response.split('.')[0].lower(): 155 | pred_cls.append(0) 156 | else: 157 | try: 158 | if 'real' in response.split('.')[1].lower(): 159 | pred_cls.append(1) 160 | elif 'fake' in response.split('.')[1].lower(): 161 | pred_cls.append(0) 162 | else: 163 | print(f"no fake or real in reponse:{response}") 164 | pred_cls.append(random.choice([0, 1])) 165 | except: 166 | print(f"no fake or real in reponse:{response}") 167 | pred_cls.append(random.choice([0, 1])) 168 | for label, pred, cate in zip(labels[0].tolist(), pred_cls, cates[0]): 169 | if cate not in results: 170 | results[cate] = {'right':{'right_fake':0, 'right_real':0}, 'wrong':{'wrong_fake':0, 'wrong_real':0}} 171 | if label == pred: 172 | if label == 1: 173 | results[cate]['right']['right_real'] += 1 174 | else: 175 | results[cate]['right']['right_fake'] += 1 176 | else: 177 | if label == 1: 178 | results[cate]['wrong']['wrong_real'] += 1 179 | else: 180 | results[cate]['wrong']['wrong_fake'] += 1 181 | 182 | os.makedirs('results', exist_ok=True) 183 | with open(args.output_path, "w") as file: 184 | json.dump(outputs, file, indent=2) 185 | acc = calculate_results_acc(results) 186 | print(acc) 187 | 188 | 189 | 190 | def main(): 191 | args = parse_args() 192 | model = load_model(args) 193 | model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 194 | cls_test_dataset = legion_cls_dataset(args, train=False) 195 | cls_test_dataloader = DataLoader( 196 | cls_test_dataset, 197 | batch_size=args.val_batch_size, 198 | shuffle=False, 199 | num_workers=args.workers, 200 | pin_memory=True, 201 | ) 202 | validate(args, model, cls_test_dataloader) 203 | 204 | 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | 210 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | python scripts/eval.py \ 2 | --model_path /path/to/your/checkpoint \ 3 | --val_batch_size 16 \ 4 | --workers 16 \ 5 | --output_path results/fakevlm.json \ 6 | --test_json_file "/path/to/your/test.json" \ 7 | --data_base_test "path/to/your/test_images" \ -------------------------------------------------------------------------------- /scripts/eval_vllm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from dataclasses import dataclass, field 5 | import transformers 6 | from torch.utils.data import Dataset, DataLoader 7 | from transformers import CLIPImageProcessor 8 | import pdb 9 | import json 10 | from transformers import AutoProcessor, LlavaForConditionalGeneration 11 | from tqdm import tqdm 12 | import random 13 | import numpy as np 14 | import torch 15 | import torchvision.transforms as T 16 | from PIL import Image 17 | from torchvision.transforms.functional import InterpolationMode 18 | from vllm import LLM, SamplingParams 19 | import pdb 20 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 21 | IMAGENET_STD = (0.229, 0.224, 0.225) 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="FakeVLM Model Testing") 25 | 26 | # Model-specific settings 27 | parser.add_argument("--model_path", 
default="", type=str) 28 | parser.add_argument("--val_batch_size", default=1, type=int) 29 | parser.add_argument("--workers", default=1, type=int) 30 | parser.add_argument("--data_base_test", default="", type=str) 31 | parser.add_argument("--test_json_file", default="", type=str) 32 | parser.add_argument("--output_path", default="", type=str) 33 | return parser.parse_args() 34 | 35 | class legion_cls_dataset(Dataset): 36 | def __init__(self, args, train=True): 37 | super().__init__() 38 | self.args = args 39 | self.train = train 40 | if train == True: 41 | with open(args.train_json_file, 'r') as f: 42 | self.data = json.load(f) 43 | elif train == False: 44 | with open(args.test_json_file, 'r') as f: 45 | self.data = json.load(f) 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, idx): 51 | if self.train == True: 52 | img_path = os.path.join(self.args.data_base_train, self.data[idx]['image']) 53 | else: 54 | img_path = os.path.join(self.args.data_base_test, self.data[idx]['image']) 55 | 56 | 57 | label = self.data[idx]['label'] 58 | 59 | cate = self.data[idx]['cate'] 60 | 61 | # torch.Size([n, 3, 448, 448]), int, int, str, str 62 | return [self.data[idx]['conversations'][0]['value']], [label], [img_path], [cate] 63 | 64 | 65 | 66 | 67 | # change this function to our own to support custom behaviors 68 | def load_model(args): 69 | print("Loading model...") 70 | llm = LLM(model=args.model_path, 71 | dtype="float16", 72 | max_model_len=800, 73 | tensor_parallel_size=torch.cuda.device_count() 74 | ) 75 | print("Successfully loaded model from:", args.model_path) 76 | return llm 77 | 78 | def calculate_results_acc(results): 79 | acc_results = {} 80 | 81 | for cate in results: 82 | data = results[cate] 83 | 84 | right_real = data['right']['right_real'] 85 | right_fake = data['right']['right_fake'] 86 | wrong_real = data['wrong']['wrong_real'] 87 | wrong_fake = data['wrong']['wrong_fake'] 88 | 89 | total_real = right_real + wrong_real 90 | total_fake = right_fake + wrong_fake 91 | total = total_real + total_fake 92 | 93 | acc_total = (right_real + right_fake) / total if total != 0 else 0 94 | acc_real = right_real / total_real if total_real != 0 else 0 95 | acc_fake = right_fake / total_fake if total_fake != 0 else 0 96 | 97 | acc_results[cate] = { 98 | 'total_samples': total, 99 | 'total_accuracy': round(acc_total, 4), 100 | 'real_accuracy': round(acc_real, 4), 101 | 'fake_accuracy': round(acc_fake, 4), 102 | 'overall': { 103 | 'right_real': right_real, 104 | 'wrong_real': wrong_real, 105 | 'right_fake': right_fake, 106 | 'wrong_fake': wrong_fake, 107 | } 108 | } 109 | 110 | global_stats = { 111 | 'total_right': sum(r['right']['right_real'] + r['right']['right_fake'] for r in results.values()), 112 | 'total_wrong': sum(r['wrong']['wrong_real'] + r['wrong']['wrong_fake'] for r in results.values()) 113 | } 114 | global_stats['global_accuracy'] = global_stats['total_right'] / (global_stats['total_right'] + global_stats['total_wrong']) 115 | 116 | return { 117 | 'category_acc': acc_results, 118 | 'global_stats': global_stats 119 | } 120 | 121 | 122 | 123 | def validate(args, model, cls_test_dataloader): 124 | results = {} 125 | output_result = [] 126 | sampling_params = SamplingParams( 127 | max_tokens=4096, 128 | temperature=0, 129 | ) 130 | 131 | with torch.no_grad(): 132 | for questions, labels, imgs, cates in tqdm(cls_test_dataloader): 133 | inputs = [] 134 | 135 | for question, img in zip(questions[0], imgs[0]): 136 | inputs.append({ 137 | "prompt": question, 
138 | "multi_modal_data": {"image": Image.open(img)}, 139 | }) 140 | # pdb.set_trace() 141 | outputs = model.generate(inputs, sampling_params=sampling_params) 142 | 143 | pred_cls = [] 144 | 145 | for i, output in enumerate(outputs): 146 | # pdb.set_trace() 147 | 148 | # #save result 149 | output_result.append({'id':imgs[0][i], 'caption':output.outputs[0].text}) 150 | 151 | 152 | response = output.outputs[0].text 153 | if 'real' in response.split('.')[0].lower(): 154 | pred_cls.append(1) 155 | elif 'fake' in response.split('.')[0].lower(): 156 | pred_cls.append(0) 157 | else: 158 | try: 159 | if 'real' in response.split('.')[1].lower(): 160 | pred_cls.append(1) 161 | elif 'fake' in response.split('.')[1].lower(): 162 | pred_cls.append(0) 163 | else: 164 | print(f"no fake or real in reponse:{response}") 165 | pred_cls.append(random.choice([0, 1])) 166 | except: 167 | print(f"no fake or real in reponse:{response}") 168 | pred_cls.append(random.choice([0, 1])) 169 | 170 | for label, pred, cate in zip(labels[0].tolist(), pred_cls, cates[0]): 171 | if cate not in results: 172 | results[cate] = {'right':{'right_fake':0, 'right_real':0}, 'wrong':{'wrong_fake':0, 'wrong_real':0}} 173 | if label == pred: 174 | if label == 1: 175 | results[cate]['right']['right_real'] += 1 176 | else: 177 | results[cate]['right']['right_fake'] += 1 178 | else: 179 | if label == 1: 180 | results[cate]['wrong']['wrong_real'] += 1 181 | else: 182 | results[cate]['wrong']['wrong_fake'] += 1 183 | 184 | # save result 185 | os.makedirs("results", exist_ok=True) 186 | with open(args.output_path, "w") as file: 187 | json.dump(output_result, file, indent=2) 188 | 189 | acc = calculate_results_acc(results) 190 | print(acc) 191 | 192 | 193 | 194 | def main(): 195 | args = parse_args() 196 | model = load_model(args) 197 | cls_test_dataset = legion_cls_dataset(args, train=False) 198 | cls_test_dataloader = DataLoader( 199 | cls_test_dataset, 200 | batch_size=args.val_batch_size, 201 | shuffle=False, 202 | num_workers=args.workers, 203 | pin_memory=True, 204 | ) 205 | validate(args, model, cls_test_dataloader) 206 | 207 | 208 | 209 | 210 | if __name__ == "__main__": 211 | main() 212 | 213 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | NUM_GPUS=8 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 3 | DISTRIBUTED_ARGS=" 4 | --nnodes=1 \ 5 | --nproc_per_node ${NUM_GPUS} \ 6 | --rdzv_backend c10d \ 7 | --rdzv_endpoint localhost:0 8 | " 9 | 10 | # according to your own case 11 | MODEL_ID=llava-1.5-7b # model id 12 | TRAIN_DATA_PATH="path/to/test.json" # path to the training data json file 13 | EVAL_DATA_PATH="path/to/test.json" # path to the evaluation data json file (optional) 14 | IMAGE_FOLDER="path/to/test_images" # path to the image root folder; if provided, the image paths in the json should be relative 15 | VIDEO_FOLDER="" # path to the video root folder; if provided, the video paths in the json should be relative 16 | NUM_FRAMES=8 # how many frames are sampled from each video 17 | 18 | TRAIN_VISION_ENCODER=True # whether train the vision encoder 19 | USE_VISION_LORA=False # whether use lora for vision encoder (only effective when `TRAIN_VISION_ENCODER` is True) 20 | TRAIN_VISION_PROJECTOR=True # whether train the vision projector (only full finetuning is supported) 21 | 22 | USE_LORA=False # whether use lora for llm 23 | Q_LORA=False # whether use q-lora for llm; only effective when `USE_LORA` 
is True 24 | LORA_R=8 # the lora rank (both llm and vision encoder) 25 | LORA_ALPHA=8 # the lora alpha (both llm and vision encoder) 26 | 27 | RUN_ID=${MODEL_ID}-fakevlm # a custom run id that determines the checkpoint folder and wandb run name 28 | 29 | DS_STAGE=zero2 # deepspeed stage; < zero2 | zero3 > 30 | PER_DEVICE_BATCH_SIZE=32 # batch size per GPU 31 | GRAD_ACCUM=1 # gradient accumulation steps 32 | NUM_EPOCHS=2 # number of training epochs 33 | 34 | LR=2e-5 # learning rate 35 | MODEL_MAX_LEN=1024 # maximum input length of the model 36 | 37 | 38 | torchrun $DISTRIBUTED_ARGS train.py \ 39 | --model_id $MODEL_ID \ 40 | --data_path $TRAIN_DATA_PATH \ 41 | --eval_data_path $EVAL_DATA_PATH \ 42 | --image_folder $IMAGE_FOLDER \ 43 | --video_folder $VIDEO_FOLDER \ 44 | --num_frames $NUM_FRAMES \ 45 | --output_dir ./checkpoints/$RUN_ID \ 46 | --report_to wandb \ 47 | --run_name $RUN_ID \ 48 | --deepspeed ./ds_configs/${DS_STAGE}.json \ 49 | --bf16 True \ 50 | --num_train_epochs $NUM_EPOCHS \ 51 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 52 | --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \ 53 | --gradient_accumulation_steps $GRAD_ACCUM \ 54 | --eval_strategy "no" \ 55 | --save_strategy "epoch" \ 56 | --save_total_limit 1 \ 57 | --learning_rate ${LR} \ 58 | --weight_decay 0. \ 59 | --warmup_ratio 0.03 \ 60 | --lr_scheduler_type "cosine" \ 61 | --logging_steps 1 \ 62 | --tf32 True \ 63 | --model_max_length $MODEL_MAX_LEN \ 64 | --gradient_checkpointing True \ 65 | --dataloader_num_workers 4 \ 66 | --train_vision_encoder $TRAIN_VISION_ENCODER \ 67 | --use_vision_lora $USE_VISION_LORA \ 68 | --train_vision_projector $TRAIN_VISION_PROJECTOR \ 69 | --use_lora $USE_LORA \ 70 | --q_lora $Q_LORA \ 71 | --lora_r $LORA_R \ 72 | --lora_alpha $LORA_ALPHA 73 | -------------------------------------------------------------------------------- /supported_models.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from collections import OrderedDict 3 | 4 | from collators import COLLATORS 5 | from datasets import TO_LOAD_IMAGE 6 | from loaders import LOADERS 7 | 8 | 9 | MODULE_KEYWORDS: Dict[str, Dict[str, List]] = { 10 | "llava-1.5": { 11 | "vision_encoder": ["vision_tower"], 12 | "vision_projector": ["multi_modal_projector"], 13 | "llm": ["language_model"] 14 | }, 15 | } 16 | 17 | 18 | MODEL_HF_PATH = OrderedDict() 19 | 20 | MODEL_FAMILIES = OrderedDict() 21 | 22 | 23 | def register_model(model_id: str, model_family_id: str, model_hf_path: str) -> None: 24 | if model_id in MODEL_HF_PATH or model_id in MODEL_FAMILIES: 25 | raise ValueError(f"Duplicate model_id: {model_id}") 26 | MODEL_HF_PATH[model_id] = model_hf_path 27 | MODEL_FAMILIES[model_id] = model_family_id 28 | 29 | register_model( 30 | model_id="llava-1.5-7b", 31 | model_family_id="llava-1.5", 32 | model_hf_path="llava-hf/llava-1.5-7b-hf" 33 | ) 34 | 35 | # sanity check 36 | for model_family_id in MODEL_FAMILIES.values(): 37 | assert model_family_id in COLLATORS, f"Collator not found for model family: {model_family_id}" 38 | assert model_family_id in LOADERS, f"Loader not found for model family: {model_family_id}" 39 | assert model_family_id in MODULE_KEYWORDS, f"Module keywords not found for model family: {model_family_id}" 40 | assert model_family_id in TO_LOAD_IMAGE, f"Image loading specification not found for model family: {model_family_id}" 41 | 42 | 43 | if __name__ == "__main__": 44 | temp = "Model ID" 45 | ljust = 30 46 | print("Supported models:") 47 | 
print(f" {temp.ljust(ljust)}: HuggingFace Path") 48 | print(" ------------------------------------------------") 49 | for model_id, model_hf_path in MODEL_HF_PATH.items(): 50 | print(f" {model_id.ljust(ljust)}: {model_hf_path}") 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["WANDB_PROJECT"]= "lmms-ft" 3 | from dataclasses import asdict 4 | import math 5 | from pathlib import Path 6 | from typing import List, Optional 7 | import yaml 8 | 9 | from accelerate.utils import DistributedType 10 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 11 | import torch 12 | import transformers 13 | from transformers import Trainer, deepspeed 14 | 15 | 16 | from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments 17 | from collators import COLLATORS 18 | from datasets import LazySupervisedDataset 19 | from loaders import LOADERS 20 | from supported_models import MODULE_KEYWORDS 21 | from utils import ( 22 | rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer, 23 | get_peft_state_maybe_zero_3, TrainerWithCustomSampler 24 | ) 25 | 26 | 27 | def train(): 28 | parser = transformers.HfArgumentParser( 29 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments) 30 | ) 31 | model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses() 32 | 33 | # dumping arguments 34 | output_dir = getattr(training_args, 'output_dir', None) 35 | assert output_dir is not None, "output_dir is required" 36 | args_dir = Path(output_dir) / "arguments" 37 | args_dir.mkdir(parents=True, exist_ok=True) 38 | yaml.dump(asdict(model_args), open(args_dir / "model.yaml", "w")) 39 | yaml.dump(asdict(data_args), open(args_dir / "data.yaml", "w")) 40 | yaml.dump(asdict(training_args), open(args_dir / "training.yaml", "w")) 41 | yaml.dump(asdict(lora_args), open(args_dir / "lora.yaml", "w")) 42 | 43 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 44 | if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False): 45 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED 46 | 47 | device_map = None 48 | if lora_args.q_lora: 49 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None 50 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): 51 | raise ValueError("FSDP or ZeRO3 are not incompatible with QLoRA.") 52 | 53 | # llm quantization config (for q-lora) 54 | bnb_config = None 55 | if lora_args.use_lora and lora_args.q_lora: 56 | from transformers import BitsAndBytesConfig 57 | rank0_print("Quantization for LLM enabled...") 58 | bnb_config = BitsAndBytesConfig( 59 | load_in_4bit=True, 60 | bnb_4bit_compute_dtype=compute_dtype, 61 | bnb_4bit_quant_type="nf4", 62 | ) 63 | 64 | # load model, tokenizer, processor 65 | rank0_print("Loading model, tokenizer, processor...") 66 | loader = LOADERS[model_args.model_family_id]( 67 | model_hf_path=model_args.model_hf_path, 68 | model_local_path=model_args.model_local_path, 69 | compute_dtype=compute_dtype, 70 | bnb_config=bnb_config, 71 | use_flash_attn=training_args.use_flash_attn, 72 | device_map=device_map, 73 | ) 74 | model, tokenizer, processor, config = loader.load() 75 | tokenizer.model_max_length = training_args.model_max_length 76 | 77 | 
    if training_args.gradient_checkpointing:
78 |         model.enable_input_require_grads()
79 | 
80 |     # freeze certain params
81 |     vision_encoder_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_encoder"]
82 |     if not training_args.train_vision_encoder:
83 |         rank0_print("Vision encoder is frozen... including:")
84 |         for module in vision_encoder_keys:
85 |             rank0_print(f"\t{module}")
86 |             eval(f"model.{module}").requires_grad_(False)
87 | 
88 |     vision_projector_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_projector"]
89 |     if not training_args.train_vision_projector:
90 |         rank0_print("Vision projector is frozen... including:")
91 |         for module in vision_projector_keys:
92 |             rank0_print(f"\t{module}")
93 |             eval(f"model.{module}").requires_grad_(False)
94 | 
95 |     # other components preparation (e.g., image_newline, vision_resampler)
96 |     # we will just freeze these
97 |     if "others" in MODULE_KEYWORDS[model_args.model_family_id]:
98 |         rank0_print("Other multimodal components are frozen... including:")
99 |         for other_key in MODULE_KEYWORDS[model_args.model_family_id]["others"]:
100 |             rank0_print(f"\t{other_key}")
101 |             eval(f"model.{other_key}").requires_grad_(False)
102 | 
103 |     # lora preparation
104 |     llm_keys = MODULE_KEYWORDS[model_args.model_family_id]["llm"]
105 |     if not (lora_args.use_lora or (training_args.train_vision_encoder and lora_args.use_vision_lora)):
106 |         rank0_print("No LoRA enabled...")
107 |     else:
108 |         named_modules = {n: m for n, m in model.named_modules()}
109 |         lora_modules = []
110 |         full_modules = []
111 | 
112 |         if training_args.train_vision_encoder and lora_args.use_vision_lora:
113 |             rank0_print("LoRA for vision encoder enabled...")
114 |             lora_modules.extend(find_all_linear_names(named_modules, vision_encoder_keys))
115 |         elif training_args.train_vision_encoder:
116 |             rank0_print("Vision encoder will be fully trained...")
117 |             full_modules.extend(vision_encoder_keys)
118 | 
119 |         if lora_args.use_lora:
120 |             rank0_print("LoRA for LLM enabled...")
121 |             lora_modules.extend(find_all_linear_names(named_modules, llm_keys))
122 |         else:
123 |             rank0_print("LLM will be fully trained...")
124 |             full_modules.extend(llm_keys)
125 | 
126 |         if training_args.train_vision_projector:
127 |             rank0_print("Vision projector will be fully trained...")
128 |             full_modules.extend(vision_projector_keys)
129 | 
130 |         lora_config = LoraConfig(
131 |             r=lora_args.lora_r,
132 |             lora_alpha=lora_args.lora_alpha,
133 |             target_modules=lora_modules,
134 |             modules_to_save=full_modules,
135 |             lora_dropout=lora_args.lora_dropout,
136 |             bias=lora_args.lora_bias,
137 |             task_type="CAUSAL_LM",
138 |         )
139 | 
140 |         if lora_args.q_lora:
141 |             model = prepare_model_for_kbit_training(
142 |                 model, use_gradient_checkpointing=training_args.gradient_checkpointing
143 |             )
144 | 
145 |         model = get_peft_model(model, lora_config)
146 | 
147 |     # for module in llm_keys:
148 |     #     rank0_print(f"\t{module}")
149 |     #     eval(f"model.{module}").requires_grad_(False)
150 | 
151 |     # print trainable parameters for inspection
152 |     rank0_print("Trainable parameters:")
153 |     for name, param in model.named_parameters():
154 |         if param.requires_grad:
155 |             rank0_print(f"\t{name}")
156 | 
157 |     # load data
158 |     rank0_print("Loading data...")
159 |     train_dataset = LazySupervisedDataset(
160 |         data_path=data_args.data_path,
161 |         image_folder=data_args.image_folder,
162 |         video_folder=data_args.video_folder,
163 |         num_frames=data_args.num_frames,
164 |         model_family_id=model_args.model_family_id,
165 |         user_key=data_args.user_key,
166
 |         assistant_key=data_args.assistant_key
167 |     )
168 |     if data_args.eval_data_path:
169 |         eval_dataset = LazySupervisedDataset(
170 |             data_path=data_args.eval_data_path,
171 |             image_folder=data_args.image_folder,
172 |             video_folder=data_args.video_folder,
173 |             num_frames=data_args.num_frames,
174 |             model_family_id=model_args.model_family_id,
175 |             user_key=data_args.user_key,
176 |             assistant_key=data_args.assistant_key
177 |         )
178 |     else:
179 |         eval_dataset = None
180 |         training_args.eval_strategy = "no"
181 | 
182 |     # data collator
183 |     data_collator = COLLATORS[model_args.model_family_id](
184 |         config=config,
185 |         tokenizer=tokenizer,
186 |         processor=processor,
187 |         mask_question_tokens=training_args.mask_question_tokens
188 |     )
189 | 
190 |     # trainer
191 |     trainer = TrainerWithCustomSampler(
192 |         model=model,
193 |         args=training_args,
194 |         data_collator=data_collator,
195 |         train_dataset=train_dataset,
196 |         eval_dataset=eval_dataset,
197 |     )
198 |     trainer.train()
199 |     trainer.save_state()
200 | 
201 |     safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
202 | 
203 | 
204 | if __name__ == "__main__":
205 |     train()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Dict, Optional
3 | 
4 | from deepspeed import zero
5 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
6 | 
7 | import torch
8 | import torch.distributed as dist
9 | from torch.utils.data import Sampler
10 | import transformers
11 | from transformers import Trainer
12 | from transformers.trainer import has_length
13 | 
14 | 
15 | class NoTextOnlyBatchSampler(Sampler):
16 |     r"""
17 |     Sampler that tries its best to sample batches such that no batch has only
18 |     text (unimodal) data. This is necessary for training with deepspeed.
19 |     """
20 | 
21 |     def __init__(
22 |         self,
23 |         batch_size: int,
24 |         world_size: int,
25 |         is_text_only: Optional[List[bool]] = None,
26 |         generator=None,
27 |     ):
28 |         if is_text_only is None:
29 |             raise ValueError("`is_text_only` must be provided.")
30 | 
31 |         self.batch_size = batch_size
32 |         self.world_size = world_size
33 |         self.is_text_only = is_text_only
34 |         self.generator = generator
35 |         self.mega_batch_size = batch_size * world_size
36 | 
37 |     def __len__(self):
38 |         return len(self.is_text_only)
39 | 
40 |     def __iter__(self):
41 |         # mm: multimodal, entry that has both text and image/video
42 |         # uni: unimodal, entry that has only text
43 |         mm_indices = [i for i, is_text_only in enumerate(self.is_text_only) if not is_text_only]
44 |         uni_indices = [i for i, is_text_only in enumerate(self.is_text_only) if is_text_only]
45 | 
46 |         num_batches = math.ceil((len(mm_indices) + len(uni_indices)) / self.mega_batch_size)
47 |         if len(mm_indices) < num_batches:
48 |             raise ValueError(
49 |                 f"{len(mm_indices)} multimodal entries, {num_batches} batches. "
50 |                 "Not enough multimodal data in the dataset, or the batch size is too small. "
51 |                 "There will be at least one batch that is text-only, which doesn't work with deepspeed. "
52 |                 "Try increasing the batch size first."
53 | ) 54 | 55 | # shuffle indices 56 | mm_indices = [mm_indices[i] for i in torch.randperm(len(mm_indices), generator=None).tolist()] 57 | uni_indices = [uni_indices[i] for i in torch.randperm(len(uni_indices), generator=None).tolist()] 58 | 59 | # distribute indices into batches 60 | num_uni_indices_in_mega_batch = [len(uni_indices) // num_batches] * num_batches 61 | for i in range(len(uni_indices) % num_batches): 62 | num_uni_indices_in_mega_batch[i] += 1 63 | 64 | mega_batches = [] 65 | cur_uni_index = 0 66 | cur_mm_index = 0 67 | for i, num_uni_indices in enumerate(num_uni_indices_in_mega_batch): 68 | mega_batch = [] 69 | mega_batch.extend(uni_indices[cur_uni_index:cur_uni_index + num_uni_indices]) 70 | cur_uni_index += num_uni_indices 71 | assert len(mega_batch) < self.mega_batch_size 72 | 73 | if i < num_batches - 1: 74 | increment = self.mega_batch_size - len(mega_batch) 75 | mega_batch.extend( 76 | mm_indices[cur_mm_index:cur_mm_index + increment] 77 | ) 78 | cur_mm_index += increment 79 | else: # last batch 80 | mega_batch.extend(mm_indices[cur_mm_index:]) 81 | assert len(mega_batch) <= self.mega_batch_size, "Last batch is too big." 82 | 83 | mega_batches.append(mega_batch) 84 | 85 | mega_batch_indices = torch.randperm(len(mega_batches), generator=self.generator) 86 | mega_batches = [mega_batches[i] for i in mega_batch_indices] 87 | indices = [i for mega_batch in mega_batches for i in mega_batch] 88 | return iter(indices) 89 | 90 | 91 | class TrainerWithCustomSampler(Trainer): 92 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 93 | if self.train_dataset is None or not has_length(self.train_dataset): 94 | return None 95 | 96 | is_text_only = self.train_dataset.is_text_only 97 | return NoTextOnlyBatchSampler( 98 | self.args.train_batch_size, 99 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 100 | is_text_only=is_text_only, 101 | ) 102 | 103 | def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[torch.utils.data.Sampler]: 104 | is_text_only = eval_dataset.is_text_only 105 | return NoTextOnlyBatchSampler( 106 | self.args.eval_batch_size, 107 | world_size=self.args.world_size, 108 | is_text_only=is_text_only, 109 | ) 110 | 111 | 112 | def find_all_linear_names(named_modules: Dict, target_modules: List[str]): 113 | cls = torch.nn.Linear 114 | lora_module_names = set() 115 | for name, module in named_modules.items(): 116 | if not any([module_name in name for module_name in target_modules]): 117 | continue 118 | 119 | if isinstance(module, cls): 120 | lora_module_names.add(name) 121 | 122 | for name in list(lora_module_names): 123 | if 'lm_head' in name: # needed for 16-bit 124 | lora_module_names.remove(name) 125 | 126 | return list(lora_module_names) 127 | 128 | 129 | def rank0_print(*args): 130 | if dist.is_initialized(): 131 | if dist.get_rank() == 0: 132 | print(*args) 133 | 134 | 135 | def maybe_zero_3(param): 136 | if hasattr(param, "ds_id"): 137 | with zero.GatheredParameters([param]): 138 | param = param.data.detach().cpu().clone() 139 | else: 140 | param = param.detach().cpu().clone() 141 | return param 142 | 143 | 144 | # Borrowed from peft.utils.get_peft_model_state_dict 145 | def get_peft_state_maybe_zero_3(named_params, bias): 146 | if bias == "none": 147 | to_return = {k: t for k, t in named_params if "lora_" in k} 148 | elif bias == "all": 149 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 150 | elif bias == "lora_only": 151 | to_return = {} 152 | maybe_lora_bias = 
{}
153 |         lora_bias_names = set()
154 |         for k, t in named_params:
155 |             if "lora_" in k:
156 |                 to_return[k] = t
157 |                 bias_name = k.split("lora_")[0] + "bias"
158 |                 lora_bias_names.add(bias_name)
159 |             elif "bias" in k:
160 |                 maybe_lora_bias[k] = t
161 |         # keep only the bias terms that belong to layers wrapped by LoRA
162 |         for k, t in maybe_lora_bias.items():
163 |             if k in lora_bias_names:
164 |                 to_return[k] = t
164 |     else:
165 |         raise NotImplementedError
166 |     to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
167 |     return to_return
168 | 
169 | 
170 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
171 |     """Collects the state dict and dumps it to disk."""
172 |     if trainer.deepspeed:
173 |         torch.cuda.synchronize()
174 |         trainer.save_model(output_dir)
175 |         return
176 | 
177 |     state_dict = trainer.model.state_dict()
178 |     if trainer.args.should_save:
179 |         cpu_state_dict = {
180 |             key: value.cpu()
181 |             for key, value in state_dict.items()
182 |         }
183 |         del state_dict
184 |         trainer._save(output_dir, state_dict=cpu_state_dict)
--------------------------------------------------------------------------------
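
For readers who want to sanity-check the custom batching logic in `utils.py` without launching a full training run, the following standalone sketch (illustrative only, not part of the repository) exercises `NoTextOnlyBatchSampler` with a toy `is_text_only` mask; the `batch_size` and `world_size` values are hypothetical.

```python
# Standalone sketch (illustrative, not part of the repo): exercising
# NoTextOnlyBatchSampler from utils.py with a toy modality mask.
# Assumes the repo root is on PYTHONPATH and its requirements
# (torch, deepspeed, transformers) are installed, since utils.py imports them.
from utils import NoTextOnlyBatchSampler

# 12 toy samples: indices 0-8 are multimodal, 9-11 are text-only.
is_text_only = [False] * 9 + [True] * 3

batch_size = 2   # per-device batch size (hypothetical value)
world_size = 2   # number of ranks; mega-batch = batch_size * world_size = 4
sampler = NoTextOnlyBatchSampler(
    batch_size=batch_size,
    world_size=world_size,
    is_text_only=is_text_only,
)

indices = list(iter(sampler))
mega = batch_size * world_size

# Each consecutive mega-batch should contain at least one multimodal sample,
# so no rank ends up with a purely text-only batch under deepspeed.
for start in range(0, len(indices), mega):
    mega_batch = indices[start:start + mega]
    assert any(not is_text_only[i] for i in mega_batch), mega_batch
    print(mega_batch)
```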