├── .gitignore
├── README.md
├── arguments.py
├── collators
│   ├── __init__.py
│   ├── base.py
│   ├── chat_template_monkey_patch.py
│   └── llava_1_5.py
├── datasets.py
├── ds_configs
│   ├── zero2.json
│   └── zero3.json
├── imgs
│   ├── ddvqa.jpg
│   ├── fakeclue_loki_result.jpg
│   ├── framework.jpg
│   ├── logo.jpg
│   ├── overview.png
│   └── result.jpg
├── loaders
│   ├── __init__.py
│   ├── base.py
│   └── llava_1_5.py
├── requirements.txt
├── scripts
│   ├── eval.py
│   ├── eval.sh
│   ├── eval_vllm.py
│   └── train.sh
├── supported_models.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | *.egg-info
4 | dist
5 |
6 | # Other
7 | .DS_Store
8 | wandb
9 | output
10 |
11 | checkpoints
12 | ckpts*
13 |
14 | .ipynb_checkpoints
15 | *.ipynb
16 |
17 | *.ttf
18 |
19 | local*
20 | not_useful_for_now
21 | *slurm*
22 |
23 | example_data/videos/ego4d/*.mp4
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
# Spot the Fake: Large Multimodal Model-Based Synthetic Image Detection with Artifact Explanation
3 |
4 |
5 |
6 |
7 | [Siwei Wen](https://scholar.google.com/citations?user=kJRiUYwAAAAJ&hl=zh-CN)1,3*,
8 | [Junyan Ye](https://yejy53.github.io/)2,1*,
9 | [Peilin Feng](https://peilin-ff.github.io/)1,3,
10 | [Hengrui Kang](https://scholar.google.com/citations?user=kVbzWCAAAAAJ&hl=zh-CN)4,1,
11 | [Zichen Wen](https://scholar.google.com/citations?user=N-aPFvEAAAAJ&hl=zh-CN)4,1,
12 | [Yize Chen](https://openreview.net/profile?id=~Yize_Chen2)5,
13 | [Jiang Wu](https://scholar.google.com/citations?user=LHiiL7AAAAAJ&hl=zh-CN)1,
14 | [Wenjun Wu](https://openreview.net/profile?id=~wenjun_wu3)3,
15 | [Conghui He](https://conghui.github.io/)1,
16 | [Weijia Li](https://liweijia.github.io/)2,1†
17 |
18 | 1Shanghai Artificial Intelligence Laboratory, 2Sun Yat-sen University
19 | 3Beihang University, 4Shanghai Jiao Tong University, 5The Chinese University of Hong Kong, Shenzhen
20 |
21 |
22 |
23 |
24 |
25 | [📄 Paper](https://arxiv.org/pdf/2503.14905)
27 | [Issues](https://github.com/opendatalab/FakeVLM/issues)
28 | [Stars](https://github.com/opendatalab/FakeVLM/stargazers)
29 | [🤗 Dataset](https://huggingface.co/datasets/lingcco/FakeClue)
30 | [🤗 Model](https://huggingface.co/lingcco/fakeVLM)
31 |
32 |
33 |
39 |
40 | ## 📰 News
41 | - **[2025.9.24]**: 🎉 FakeVLM was accepted to NeurIPS 2025!
42 | - **[2025.4.15]**: 🤗 We are excited to release the FakeClue dataset. Check out [here](https://huggingface.co/datasets/lingcco/FakeClue).
43 | - **[2025.3.20]**: 🔥 We have released **Spot the Fake: Large Multimodal Model-Based Synthetic Image Detection with Artifact Explanation**. Check out the [paper](https://arxiv.org/abs/2503.14905). We present the FakeClue dataset and the FakeVLM model.
44 |
45 | ## FakeVLM Overview
46 |
47 | With the rapid advancement of Artificial Intelligence Generated Content (AIGC) technologies, synthetic images have become increasingly prevalent in everyday life, posing new challenges for authenticity assessment and detection. Despite the effectiveness of existing methods in evaluating image authenticity and locating forgeries, these approaches often lack human interpretability and do not fully address the growing complexity of synthetic data. To tackle these challenges, we introduce FakeVLM, a specialized large multimodal model designed for both general synthetic image and DeepFake detection tasks. FakeVLM not only excels in distinguishing real from fake images but also provides clear, natural language explanations for image artifacts, enhancing interpretability. Additionally, we present FakeClue, a comprehensive dataset containing over 100,000 images across seven categories, annotated with fine-grained artifact clues in natural language. FakeVLM demonstrates performance comparable to expert models while eliminating the need for additional classifiers, making it a robust solution for synthetic data detection. Extensive evaluations across multiple datasets confirm the superiority of FakeVLM in both authenticity classification and artifact explanation tasks, setting a new benchmark for synthetic image detection.
48 |
49 |
50 |

51 |
52 |
53 | ## Contributions
54 |
55 | - We propose FakeVLM, a multimodal large model designed for both general synthetic and deepfake image detection tasks. It excels at distinguishing real from fake images while also providing excellent interpretability for artifact details in synthetic images.
56 | - We introduce the FakeClue dataset, which includes a rich variety of image categories and fine-grained artifact annotations in natural language.
57 | - Our method has been extensively evaluated on multiple datasets, achieving outstanding performance in both synthetic detection and abnormal artifact explanation tasks.
58 |
59 | ## 🛠️ Installation
60 | Please clone our repository and change into the project folder:
61 | ```bash
62 | git clone git@github.com:opendatalab/FakeVLM.git
63 | cd FakeVLM
64 | ```
65 |
66 | Our model is built on the [lmms-finetune](https://github.com/zjysteven/lmms-finetune) codebase. Please follow the steps below to configure the environment.
67 | ```bash
68 | conda create -n fakevlm python=3.10 -y
69 | conda activate fakevlm
70 |
71 | python -m pip install -r requirements.txt
72 |
73 | python -m pip install --no-cache-dir --no-build-isolation flash-attn
74 | ```
75 |
76 | ## 📦 Dataset
77 | The directory containing the images should have the following structure:
78 |
79 | ```
80 | playground
81 | └── data
82 |     ├── train
83 |     │   ├── doc
84 |     │   ├── fake
85 |     │   ├── real
86 |     │   ├── ...
87 |     │   └── satellite
88 |     └── test
89 |         └── ...
93 | ```
94 |
95 |
96 | ## 📌 Usage
97 | ### 1. Data Preparation
98 | The training data can be downloaded from [here](https://huggingface.co/datasets/lingcco/FakeClue).
99 |
100 | Please download the dataset and unzip the images.
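
If you prefer to script this step, the sketch below pulls the dataset with `huggingface_hub` (already pinned in `requirements.txt`) and unpacks any zip archives it finds. The archive names and their internal layout are assumptions; adjust the target directory so the images end up under `playground/data` as shown above.

```python
# Minimal download sketch (assumed archive layout; adapt to what the dataset repo actually contains).
import zipfile
from pathlib import Path

from huggingface_hub import snapshot_download

data_dir = Path("playground/data")
snapshot_download(
    repo_id="lingcco/FakeClue",
    repo_type="dataset",
    local_dir=str(data_dir),
)

# Extract any image archives shipped with the dataset next to where they live.
for archive in data_dir.rglob("*.zip"):
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(archive.parent)
```
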
101 | ### 2. Train
102 |
103 | Replace the data paths in `scripts/train.sh` with your own, and point the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) entry in `supported_models.py` to your local model path (a sketch of the expected entries follows the command below).
104 |
105 | ```
106 | bash scripts/train.sh
107 | ```
108 |
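For reference, `arguments.py` imports `MODEL_HF_PATH` and `MODEL_FAMILIES` from `supported_models.py`, so the edit amounts to pointing the `llava-1.5-7b` entry at your local model directory. The snippet below sketches what the relevant entries are assumed to look like, inferred from how `arguments.py` consumes them:

```python
# supported_models.py -- assumed structure; only the llava-1.5-7b entries are shown.
MODEL_HF_PATH = {
    # Point this at your local copy of llava-1.5-7b-hf, or keep the Hub id "llava-hf/llava-1.5-7b-hf".
    "llava-1.5-7b": "/path/to/your/llava-1.5-7b-hf",
}

MODEL_FAMILIES = {
    "llava-1.5-7b": "llava-1.5",
}
```
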
109 | ### 3. Evaluation
110 |
111 | We provide two scripts for evaluating FakeVLM. The trained FakeVLM model is available [here](https://huggingface.co/lingcco/fakeVLM).
112 |
113 | #### 1. Standard evaluation
114 |
115 | ```
116 | bash scripts/eval.sh
117 | ```
118 |
119 | #### 2. Evaluation with vLLM
120 |
121 | Given the size of the model and the amount of test data, we recommend using vLLM for evaluation. Please make sure vLLM is installed.
122 |
123 |
124 | ```
125 | # change scripts/eval.py to scripts/eval_vllm.py in scripts/eval.sh
126 | bash scripts/eval.sh
127 | ```
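
If you want to sanity-check the vLLM path on a single image before running the full script, the sketch below loosely mirrors what `scripts/eval_vllm.py` does internally. The checkpoint and image paths are placeholders, and the prompt wording is illustrative (the evaluation script reads its prompts from the test json).

```python
# Single-image sketch, loosely mirroring the vLLM evaluation path used by scripts/eval_vllm.py.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/your/checkpoint", dtype="float16")
sampling_params = SamplingParams(max_tokens=256, temperature=0)

# LLaVA-1.5-style prompt with an <image> placeholder (illustrative wording).
prompt = "USER: <image>\nIs this image real or fake? Explain any artifacts you see. ASSISTANT:"
outputs = llm.generate(
    [{"prompt": prompt, "multi_modal_data": {"image": Image.open("/path/to/an/image.jpg")}}],
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```
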
128 | ## 📊 Results
129 | Performance of 7 leading LMMs and FakeVLM on DD-VQA, FakeClue, and LOKI.
130 |
131 | - **FakeClue**
132 | Our proposed dataset.
133 | - **LOKI**
134 | A new benchmark for evaluating multimodal models in synthetic detection tasks. It includes **human-annotated fine-grained image artifacts**, enabling deeper analysis of artifact explanations. We used its image modality, covering categories like Animals, Humans, Scenery, and Documents.
135 |
136 |
137 |
138 | - **DD-VQA**
139 | A dataset for explaining facial artifacts, using **manual annotations** in a VQA format. Artifacts include blurred hairlines, mismatched eyebrows, rigid pupils, and unnatural shadows. It builds on FF++ data and emphasizes common-sense reasoning.
140 |
141 |
142 |

143 |
144 |
145 | To provide a comprehensive comparison of the model performance across the three datasets—FakeClue, LOKI, and DD-VQA—we present the following radar chart. This chart visually highlights the strengths and weaknesses of the 7 leading LMMs and FakeVLM, offering a clear depiction of their results in synthetic detection and artifact explanation tasks.
146 |
147 |
148 |

149 |
150 |
151 | ## 😄 Acknowledgement
152 |
153 | This repository is built upon the work of [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main), and our codebase is built upon [lmms-finetune](https://github.com/zjysteven/lmms-finetune). We appreciate their contributions and insights that have provided a strong foundation for our research.
154 |
155 | ## 📨 Contact
156 |
157 | If you have any questions or suggestions, please feel free to contact us
158 | at [466439420gh@gmail.com](mailto:466439420gh@gmail.com).
159 |
160 | ## 📝 Citation
161 | If you find our work interesting and helpful, please consider giving our repo a star. Additionally, if you would like to cite our work, please use the following format:
162 | ```bibtex
163 | @article{wen2025spot,
164 | title={Spot the fake: Large multimodal model-based synthetic image detection with artifact explanation},
165 | author={Wen, Siwei and Ye, Junyan and Feng, Peilin and Kang, Hengrui and Wen, Zichen and Chen, Yize and Wu, Jiang and Wu, Wenjun and He, Conghui and Li, Weijia},
166 | journal={arXiv preprint arXiv:2503.14905},
167 | year={2025}
168 | }
169 | ```
170 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, List
2 | from dataclasses import dataclass, field
3 |
4 | import transformers
5 |
6 | from supported_models import MODEL_HF_PATH, MODEL_FAMILIES
7 |
8 |
9 | @dataclass
10 | class ModelArguments:
11 | model_id: str = field(default="llava-1.5-7b")
12 | model_local_path: Optional[str] = field(default=None)
13 |
14 | def __post_init__(self):
15 | assert self.model_id in MODEL_HF_PATH, f"Unknown model_id: {self.model_id}"
16 | self.model_hf_path: str = MODEL_HF_PATH[self.model_id]
17 | assert self.model_id in MODEL_FAMILIES, f"Unknown model_id: {self.model_id}"
18 | self.model_family_id: str = MODEL_FAMILIES[self.model_id]
19 |
20 | if not self.model_local_path:
21 | self.model_local_path = self.model_hf_path
22 |
23 |
24 | @dataclass
25 | class DataArguments:
26 | data_path: str = field(
27 | default=None, metadata={"help": "Path to the training data json file."}
28 | )
29 | eval_data_path: Optional[str] = field(
30 | default=None, metadata={"help": "Path to the evaluation data json file."}
31 | )
32 | image_folder: Optional[str] = field(default=None)
33 | video_folder: Optional[str] = field(default=None)
34 | num_frames: Optional[int] = field(default=8)
35 | user_key: Optional[str] = field(default="human")
36 | assistant_key: Optional[str] = field(default="gpt")
37 |
38 |
39 | @dataclass
40 | class TrainingArguments(transformers.TrainingArguments):
41 | model_max_length: int = field(
42 | default=1024,
43 | metadata={
44 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
45 | },
46 | )
47 | use_flash_attn: bool = field(default=False)
48 | train_vision_encoder: bool = field(default=False)
49 | train_vision_projector: bool = field(default=False)
50 | mask_question_tokens: bool = field(default=True)
51 |
52 | def __post_init__(self):
53 | super().__post_init__()
54 | self.remove_unused_columns = False
55 |
56 |
57 | @dataclass
58 | class LoraArguments:
59 | use_lora: bool = field(default=True)
60 | use_vision_lora: bool = field(default=True)
61 | q_lora: bool = field(default=False)
62 | lora_r: int = field(default=8)
63 | lora_alpha: int = field(default=16)
64 | lora_dropout: float = field(default=0.05)
65 | lora_weight_path: str = ""
66 | lora_bias: str = "none"
--------------------------------------------------------------------------------
/collators/__init__.py:
--------------------------------------------------------------------------------
1 | COLLATORS = {}
2 |
3 | def register_collator(name):
4 | def register_collator_cls(cls):
5 | if name in COLLATORS:
6 | return COLLATORS[name]
7 | COLLATORS[name] = cls
8 | return cls
9 | return register_collator_cls
10 |
11 |
12 | from .llava_1_5 import LLaVA15DataCollator
--------------------------------------------------------------------------------
/collators/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Sequence, Optional
3 |
4 | import torch
5 | from transformers import PreTrainedTokenizer, AutoProcessor, AutoConfig
6 |
7 |
8 | class BaseDataCollator(ABC, object):
9 | """Collate examples for supervised fine-tuning."""
10 | def __init__(
11 | self,
12 | config: Optional[AutoConfig] = None,
13 | tokenizer: Optional[PreTrainedTokenizer] = None,
14 | processor: Optional[AutoProcessor] = None,
15 | mask_question_tokens: bool = True
16 | ) -> None:
17 | self.config = config
18 | self.tokenizer = tokenizer
19 | self.processor = processor
20 | self.mask_question_tokens = mask_question_tokens
21 |
22 | @property
23 | def IGNORE_TOKEN_ID(self) -> int:
24 | return -100
25 |
26 | @property
27 | def PAD_TOKEN_ID(self) -> int:
28 | return self.tokenizer.pad_token_id
29 |
30 | @abstractmethod
31 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: ...
--------------------------------------------------------------------------------
/collators/chat_template_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # a monkey patch for https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1709
2 | # this sets add_special_tokens=True to the tokenizer call
3 |
4 | import re
5 | from inspect import isfunction
6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
7 |
8 | from transformers.tokenization_utils_base import BatchEncoding
9 | from transformers.utils import TensorType, get_json_schema, logging
10 | from transformers.utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
11 |
12 |
13 | logger = logging.get_logger(__name__)
14 |
15 |
16 | def apply_chat_template(
17 | self,
18 | conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
19 | tools: Optional[List[Dict]] = None,
20 | documents: Optional[List[Dict[str, str]]] = None,
21 | chat_template: Optional[str] = None,
22 | add_generation_prompt: bool = False,
23 | continue_final_message: bool = False,
24 | tokenize: bool = True,
25 | padding: bool = False,
26 | truncation: bool = False,
27 | max_length: Optional[int] = None,
28 | return_tensors: Optional[Union[str, TensorType]] = None,
29 | return_dict: bool = False,
30 | return_assistant_tokens_mask: bool = False,
31 | tokenizer_kwargs: Optional[Dict[str, Any]] = None,
32 | **kwargs,
33 | ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
34 | """
35 | Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
36 | ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
37 | determine the format and control tokens to use when converting.
38 |
39 | Args:
40 | conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
41 | with "role" and "content" keys, representing the chat history so far.
42 | tools (`List[Dict]`, *optional*):
43 | A list of tools (callable functions) that will be accessible to the model. If the template does not
44 | support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
45 | giving the name, description and argument types for the tool. See our
46 | [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
47 | for more information.
48 | documents (`List[Dict[str, str]]`, *optional*):
49 | A list of dicts representing documents that will be accessible to the model if it is performing RAG
50 | (retrieval-augmented generation). If the template does not support RAG, this argument will have no
51 | effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
52 | see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
53 | for examples of passing documents with chat templates.
54 | chat_template (`str`, *optional*):
55 | A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
56 | argument, as the model's template will be used by default.
57 | add_generation_prompt (bool, *optional*):
58 | If this is set, a prompt with the token(s) that indicate
59 | the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
60 | Note that this argument will be passed to the chat template, and so it must be supported in the
61 | template for this argument to have any effect.
62 | continue_final_message (bool, *optional*):
63 | If this is set, the chat will be formatted so that the final
64 | message in the chat is open-ended, without any EOS tokens. The model will continue this message
65 | rather than starting a new one. This allows you to "prefill" part of
66 | the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
67 | tokenize (`bool`, defaults to `True`):
68 | Whether to tokenize the output. If `False`, the output will be a string.
69 | padding (`bool`, defaults to `False`):
70 | Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
71 | truncation (`bool`, defaults to `False`):
72 | Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
73 | max_length (`int`, *optional*):
74 | Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
75 | not specified, the tokenizer's `max_length` attribute will be used as a default.
76 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
77 | If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
78 | values are:
79 | - `'tf'`: Return TensorFlow `tf.Tensor` objects.
80 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
81 | - `'np'`: Return NumPy `np.ndarray` objects.
82 | - `'jax'`: Return JAX `jnp.ndarray` objects.
83 | return_dict (`bool`, defaults to `False`):
84 | Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
85 | tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
86 | return_assistant_tokens_mask (`bool`, defaults to `False`):
87 | Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
88 | the mask will contain 1. For user and system tokens, the mask will contain 0.
89 | This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
90 | **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
91 |
92 | Returns:
93 | `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
94 | output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
95 | set, will return a dict of tokenizer outputs instead.
96 | """
97 |
98 | if return_dict and not tokenize:
99 | raise ValueError(
100 | "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
101 | "of tokenizer outputs to return."
102 | )
103 |
104 | if return_assistant_tokens_mask and not return_dict:
105 | raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
106 |
107 | if tokenizer_kwargs is None:
108 | tokenizer_kwargs = {}
109 |
110 | chat_template = self.get_chat_template(chat_template, tools)
111 |
112 | if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
113 | logger.warning_once(
114 | "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
115 | )
116 |
117 | # Compilation function uses a cache to avoid recompiling the same template
118 | compiled_template = _compile_jinja_template(chat_template)
119 |
120 | if isinstance(conversation, (list, tuple)) and (
121 | isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
122 | ):
123 | conversations = conversation
124 | is_batched = True
125 | else:
126 | conversations = [conversation]
127 | is_batched = False
128 |
129 | if continue_final_message:
130 | if add_generation_prompt:
131 | raise ValueError(
132 | "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
133 | )
134 | if return_assistant_tokens_mask:
135 | raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
136 |
137 | # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
138 | if tools is not None:
139 | tool_schemas = []
140 | for tool in tools:
141 | if isinstance(tool, dict):
142 | tool_schemas.append(tool)
143 | elif isfunction(tool):
144 | tool_schemas.append(get_json_schema(tool))
145 | else:
146 | raise ValueError(
147 | "Tools should either be a JSON schema, or a callable function with type hints "
148 | "and a docstring suitable for auto-conversion to a schema."
149 | )
150 | else:
151 | tool_schemas = None
152 |
153 | if documents is not None:
154 | for document in documents:
155 | if not isinstance(document, dict):
156 | raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
157 |
158 | rendered = []
159 | all_generation_indices = []
160 | template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
161 | for chat in conversations:
162 | if hasattr(chat, "messages"):
163 | # Indicates it's a Conversation object
164 | chat = chat.messages
165 | if return_assistant_tokens_mask:
166 | rendered_chat, generation_indices = _render_with_assistant_indices(
167 | compiled_template=compiled_template,
168 | messages=chat,
169 | tools=tool_schemas,
170 | documents=documents,
171 | add_generation_prompt=add_generation_prompt,
172 | **template_kwargs,
173 | )
174 | all_generation_indices.append(generation_indices)
175 | else:
176 | rendered_chat = compiled_template.render(
177 | messages=chat,
178 | tools=tool_schemas,
179 | documents=documents,
180 | add_generation_prompt=add_generation_prompt,
181 | **template_kwargs,
182 | )
183 | if continue_final_message:
184 | final_message = chat[-1]["content"].strip()
185 | rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
186 | rendered.append(rendered_chat)
187 |
188 | if not is_batched:
189 | rendered = rendered[0]
190 |
191 | if tokenize:
192 | out = self(
193 | rendered,
194 | padding=padding,
195 | truncation=truncation,
196 | max_length=max_length,
197 | add_special_tokens=True,
198 | return_tensors=return_tensors,
199 | **tokenizer_kwargs,
200 | )
201 | if return_dict:
202 | if return_assistant_tokens_mask:
203 | assistant_masks = []
204 | if is_batched or return_tensors:
205 | input_ids = out["input_ids"]
206 | else:
207 | input_ids = [out["input_ids"]]
208 | for i in range(len(input_ids)):
209 | current_mask = [0] * len(input_ids[i])
210 | for assistant_start_char, assistant_end_char in all_generation_indices[i]:
211 | start_token = out.char_to_token(i, assistant_start_char)
212 | end_token = out.char_to_token(i, assistant_end_char - 1)
213 | if start_token is None:
214 | # start_token is out of bounds maybe due to truncation.
215 | break
216 | for token_id in range(start_token, end_token + 1 if end_token else len(input_ids)):
217 | current_mask[token_id] = 1
218 | assistant_masks.append(current_mask)
219 | out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
220 | return out
221 | else:
222 | return out["input_ids"]
223 | else:
224 | return rendered
--------------------------------------------------------------------------------
/collators/llava_1_5.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Dict, List, Sequence, Union
3 |
4 | import numpy as np
5 | import PIL
6 | import torch
7 | from transformers.image_utils import get_image_size, to_numpy_array
8 | from transformers.models.llava.processing_llava import LlavaProcessorKwargs
9 | from transformers.utils import logging
10 |
11 | from . import register_collator
12 | from .base import BaseDataCollator
13 | from .chat_template_monkey_patch import apply_chat_template
14 |
15 |
16 | logger = logging.get_logger(__name__)
17 |
18 |
19 | @register_collator("llava-1.5")
20 | class LLaVA15DataCollator(BaseDataCollator):
21 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
22 | # monkey patch to include bos tokens
23 | self.tokenizer.apply_chat_template = apply_chat_template.__get__(self.tokenizer)
24 |
25 | output_kwargs = self.processor._merge_kwargs(
26 | LlavaProcessorKwargs,
27 | tokenizer_init_kwargs=self.tokenizer.init_kwargs,
28 | )
29 |
30 | vision_inputs = dict()
31 | images: List[List[PIL.Image.Image]] = [x for instance in instances for x in instance["images"]]
32 | if len(images) > 0:
33 | vision_inputs.update(**self.processor.image_processor(images, return_tensors="pt", **output_kwargs["images_kwargs"]))
34 |
35 | # some parsing
36 | images: List[List[PIL.Image.Image]] = [instance["images"] for instance in instances]
37 | system_prompts: List[Union[str, None]] = [instance["system_prompt"] for instance in instances]
38 | conversations: List[List] = [instance["conversations"] for instance in instances]
39 |
40 | # constants
41 | max_len = self.tokenizer.model_max_length
42 | image_token_id = self.config.image_token_index
43 | patch_size = self.processor.patch_size
44 | vision_feature_select_strategy = self.processor.vision_feature_select_strategy
45 |
46 | input_ids = []
47 | labels = []
48 |
49 | for system_prompt, cur_images, cur_convs in zip(system_prompts, images, conversations):
50 | cur_num_images = 0
51 | cur_input_ids = []
52 | cur_labels = []
53 |
54 | cur_text = []
55 | if system_prompt is not None:
56 | cur_text.append({
57 | "role": "system",
58 | "content": [{"type": "text", "text": system_prompt}]
59 | })
60 |
61 | for i, text in enumerate(cur_convs):
62 | if i % 2 == 0:
63 | num_images = len([m.start() for m in re.finditer("<image>", text)])
64 | cur_num_images += num_images
65 |
66 | # .strip(): whitespaces and newlines are handled by chat_template
67 | text = text.replace("<image>", "").strip()
68 |
69 | cur_text.append({
70 | "role": "user",
71 | "content": [{"type": "text", "text": text}] + \
72 | [{"type": "image"}] * num_images
73 | })
74 | else:
75 | cur_text.append({
76 | "role": "assistant",
77 | "content": [
78 | {"type": "text", "text": text},
79 | ]
80 | })
81 |
82 | assert len(cur_images) == cur_num_images, "Number of image tokens does not match the number of images"
83 |
84 | temp = self.processor.apply_chat_template(
85 | cur_text,
86 | add_generation_prompt=False,
87 | tokenize=True,
88 | return_assistant_tokens_mask=True,
89 | return_dict=True,
90 | return_tensors="pt",
91 | truncation=False # the assistant tokens mask seems wrong when truncation is enabled
92 | )
93 | cur_input_ids = temp["input_ids"]
94 | cur_assistant_masks = torch.tensor(temp["assistant_masks"], dtype=torch.bool).unsqueeze(0)
95 |
96 | # expand image tokens
97 | temp_vision_inputs = self.processor.image_processor(cur_images, return_tensors="pt")
98 | if temp_vision_inputs.get("pixel_values") is not None:
99 | if patch_size is not None and vision_feature_select_strategy is not None:
100 | # Replace the image token with the expanded image token sequence
101 | pixel_values = temp_vision_inputs["pixel_values"]
102 | height, width = get_image_size(to_numpy_array(pixel_values[0]))
103 | num_image_tokens = (height // patch_size) * (width // patch_size) + 1
104 | if vision_feature_select_strategy == "default":
105 | num_image_tokens -= 1
106 |
107 | repeat = torch.where(cur_input_ids == image_token_id, num_image_tokens, 1).squeeze()
108 | cur_input_ids = cur_input_ids.repeat_interleave(repeat, dim=1)
109 | cur_assistant_masks = cur_assistant_masks.repeat_interleave(repeat, dim=1)
110 | else:
111 | logger.warning_once(
112 | "Expanding inputs for image tokens in LLaVa should be done in processing. "
113 | "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
114 | "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
115 | "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
116 | )
117 |
118 | # a dirty hack to include eos token as part of the labels
119 | cur_assistant_masks[0, -1] = True
120 |
121 | # manual truncation
122 | if cur_input_ids.shape[1] > max_len:
123 | cur_input_ids = cur_input_ids[:, :max_len]
124 | cur_assistant_masks = cur_assistant_masks[:, :max_len]
125 | cur_labels = cur_input_ids.clone()
126 |
127 | # mask question tokens
128 | if self.mask_question_tokens:
129 | assert cur_labels.shape == cur_assistant_masks.shape, "Label and mask shapes do not match"
130 | cur_labels = torch.where(cur_assistant_masks, cur_labels, self.IGNORE_TOKEN_ID)
131 |
132 | assert cur_input_ids.shape == cur_labels.shape, "Input and label shapes do not match"
133 |
134 | # padding
135 | if cur_input_ids.shape[1] < max_len:
136 | cur_input_ids = torch.cat([
137 | cur_input_ids,
138 | torch.full(
139 | (cur_input_ids.shape[0], max_len - cur_input_ids.shape[1]),
140 | self.PAD_TOKEN_ID,
141 | dtype=cur_input_ids.dtype,
142 | device=cur_input_ids.device
143 | )
144 | ], dim=1)
145 | cur_labels = torch.cat([
146 | cur_labels,
147 | torch.full(
148 | (cur_labels.shape[0], max_len - cur_labels.shape[1]),
149 | self.IGNORE_TOKEN_ID,
150 | dtype=cur_labels.dtype,
151 | device=cur_labels.device
152 | )
153 | ], dim=1)
154 |
155 | input_ids.append(cur_input_ids)
156 | labels.append(cur_labels)
157 |
158 | input_ids = torch.cat(input_ids)
159 | labels = torch.cat(labels)
160 |
161 | return dict(
162 | **vision_inputs,
163 | input_ids=input_ids,
164 | labels=labels,
165 | attention_mask=input_ids.ne(self.PAD_TOKEN_ID),
166 | )
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import av
2 | import os
3 | import json
4 | from PIL import Image
5 | from typing import Dict, List, Optional
6 |
7 | import numpy as np
8 | from torch.utils.data import Dataset
9 |
10 |
11 | TO_LOAD_IMAGE: Dict[str, bool] = {
12 | "llava-1.5": True,
13 | }
14 |
15 |
16 | def read_video_pyav(container, indices):
17 | '''
18 | Decode the video with PyAV decoder.
19 | Args:
20 | container (`av.container.input.InputContainer`): PyAV container.
21 | indices (`List[int]`): List of frame indices to decode.
22 | Returns:
23 | result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
24 | '''
25 | frames = []
26 | container.seek(0)
27 | start_index = indices[0]
28 | end_index = indices[-1]
29 | for i, frame in enumerate(container.decode(video=0)):
30 | if i > end_index:
31 | break
32 | if i >= start_index and i in indices:
33 | frames.append(frame)
34 | return np.stack([x.to_ndarray(format="rgb24") for x in frames])
35 |
36 |
37 | class LazySupervisedDataset(Dataset):
38 | """Dataset for supervised fine-tuning
39 | which is generalized enough to handle both images and videos.
40 | """
41 |
42 | def __init__(
43 | self,
44 | data_path: str,
45 | model_family_id: str,
46 | image_folder: Optional[str] = None,
47 | video_folder: Optional[str] = None,
48 | num_frames: int = 8,
49 | user_key: str = "human",
50 | assistant_key: str = "gpt",
51 | ) -> None:
52 | super(LazySupervisedDataset, self).__init__()
53 | self.list_data_dict = json.load(open(data_path, "r"))
54 | self.image_folder = image_folder
55 | self.video_folder = video_folder
56 | self.num_frames = num_frames
57 | self.load_image = TO_LOAD_IMAGE[model_family_id]
58 | self.user_key = user_key
59 | self.assistant_key = assistant_key
60 |
61 | self.is_text_only = [
62 | "image" not in source and "video" not in source
63 | for source in self.list_data_dict
64 | ]
65 |
66 | def __len__(self) -> int:
67 | return len(self.list_data_dict)
68 |
69 | def __getitem__(self, i) -> Dict[str, List]:
70 | source = self.list_data_dict[i]
71 |
72 | images = []
73 | if "image" in source:
74 | # here we do not do any image preprocessing but rather
75 | # let the processor handle everything
76 | # in some cases this may cause slight differences
77 | # but should totally be fine (e.g., official llava-1.5 does padding,
78 | # but llava-1.5-hf (huggingface's implementation) does not)
79 | if isinstance(source["image"], list):
80 | image_sources = source["image"]
81 | elif isinstance(source["image"], str):
82 | image_sources = [source["image"]]
83 | else:
84 | raise ValueError(f"Invalid image source type: {type(source['image'])}")
85 |
86 | for image_path in image_sources:
87 | if self.image_folder is not None:
88 | image_path = os.path.join(self.image_folder, image_path)
89 | images.append(
90 | Image.open(image_path).convert("RGB")
91 | if self.load_image else image_path
92 | )
93 |
94 | videos = []
95 | if "video" in source:
96 | if isinstance(source["video"], list):
97 | video_sources = source["video"]
98 | elif isinstance(source["video"], str):
99 | video_sources = [source["video"]]
100 | else:
101 | raise ValueError(f"Invalid video source type: {type(source['video'])}")
102 |
103 | num_frames = [self.num_frames] * len(video_sources)
104 |
105 | for video_path, cur_num_frames in zip(video_sources, num_frames):
106 | if self.video_folder is not None:
107 | video_path = os.path.join(self.video_folder, video_path)
108 |
109 | container = av.open(video_path)
110 | total_frames = container.streams.video[0].frames
111 | indices = np.arange(0, total_frames, total_frames / cur_num_frames).astype(int)
112 | clip = read_video_pyav(container, indices)
113 |
114 | videos.append(clip)
115 |
116 | system_prompt = None
117 | if "system_prompt" in source:
118 | system_prompt = source["system_prompt"]
119 |
120 | convs = []
121 | assert len(source["conversations"]) > 0, "No conversations found"
122 | for i, conv in enumerate(source["conversations"]):
123 | assert conv["from"] == (self.user_key if i % 2 == 0 else self.assistant_key), "Invalid conversation"
124 | convs.append(conv["value"])
125 | assert len(convs) % 2 == 0, "Odd number of conversations"
126 |
127 | return dict(
128 | images=images,
129 | videos=videos,
130 | conversations=convs,
131 | system_prompt=system_prompt
132 | )
--------------------------------------------------------------------------------
/ds_configs/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": false,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/ds_configs/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/imgs/ddvqa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/ddvqa.jpg
--------------------------------------------------------------------------------
/imgs/fakeclue_loki_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/fakeclue_loki_result.jpg
--------------------------------------------------------------------------------
/imgs/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/framework.jpg
--------------------------------------------------------------------------------
/imgs/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/logo.jpg
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/overview.png
--------------------------------------------------------------------------------
/imgs/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendatalab/FakeVLM/fdcdcf111ecaff13ded65e32ce34d04cf876bc6f/imgs/result.jpg
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | LOADERS = {}
2 |
3 | def register_loader(name):
4 | def register_loader_cls(cls):
5 | if name in LOADERS:
6 | return LOADERS[name]
7 | LOADERS[name] = cls
8 | return cls
9 | return register_loader_cls
10 |
11 |
12 | from .llava_1_5 import LLaVA15ModelLoader
--------------------------------------------------------------------------------
/loaders/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Tuple, Union, Optional
3 |
4 | import torch
5 | from transformers import PreTrainedModel, PreTrainedTokenizer, AutoProcessor, BitsAndBytesConfig
6 |
7 |
8 | class BaseModelLoader(ABC):
9 | def __init__(
10 | self,
11 | model_hf_path: str,
12 | model_local_path: str,
13 | compute_dtype: torch.dtype,
14 | bnb_config: Optional[BitsAndBytesConfig] = None,
15 | use_flash_attn: bool = False,
16 | device_map: Optional[Union[Dict, str]] = None,
17 | ) -> None:
18 | self.model_hf_path = model_hf_path
19 | self.model_local_path = model_local_path
20 | self.loading_kwargs = dict(
21 | torch_dtype=compute_dtype,
22 | quantization_config=bnb_config,
23 | device_map=device_map,
24 | )
25 | if use_flash_attn:
26 | self.loading_kwargs["attn_implementation"] = "flash_attention_2"
27 |
28 | @abstractmethod
29 | def load(self, load_model: bool = True) -> Tuple[
30 | PreTrainedModel, Union[None, PreTrainedTokenizer], Union[None, AutoProcessor]
31 | ]: ...
--------------------------------------------------------------------------------
/loaders/llava_1_5.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from transformers import AutoProcessor, LlavaForConditionalGeneration, PreTrainedTokenizer, AutoConfig
4 |
5 | from . import register_loader
6 | from .base import BaseModelLoader
7 |
8 |
9 | @register_loader("llava-1.5")
10 | class LLaVA15ModelLoader(BaseModelLoader):
11 | def load(self, load_model: bool = True) -> Tuple[LlavaForConditionalGeneration, PreTrainedTokenizer, AutoProcessor, AutoConfig]:
12 | if load_model:
13 | model = LlavaForConditionalGeneration.from_pretrained(
14 | self.model_local_path,
15 | **self.loading_kwargs,
16 | )
17 | model.config.hidden_size = model.language_model.config.hidden_size # useful for deepspeed
18 | else:
19 | model = None
20 |
21 | processor = AutoProcessor.from_pretrained(self.model_hf_path, add_eos_token=True)
22 | tokenizer = processor.tokenizer
23 | config = AutoConfig.from_pretrained(self.model_local_path)
24 | return model, tokenizer, processor, config
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.10.1
2 | aiofiles==24.1.0
3 | annotated-types==0.7.0
4 | anyio==4.10.0
5 | av
6 | bitsandbytes==0.47.0
7 | Brotli==1.1.0
8 | certifi==2025.8.3
9 | charset-normalizer==3.4.3
10 | click==8.3.0
11 | contourpy==1.3.2
12 | cycler==0.12.1
13 | deepspeed==0.14.4
14 | einops==0.8.1
15 | exceptiongroup==1.3.0
16 | fastapi==0.117.1
17 | ffmpy==0.6.1
18 | filelock==3.19.1
19 | fonttools==4.60.0
20 | fsspec==2025.9.0
21 | gitdb==4.0.12
22 | GitPython==3.1.45
23 | gradio==5.46.11
24 | gradio_client==1.13.1
25 | groovy==0.1.2
26 | h11==0.16.0
27 | hf-xet==1.1.10
28 | hjson==3.1.0
29 | httpcore==1.0.9
30 | httpx==0.28.1
31 | huggingface-hub==0.35.0
32 | idna==3.10
33 | Jinja2==3.1.6
34 | kiwisolver==1.4.9
35 | markdown-it-py==4.0.0
36 | MarkupSafe==3.0.2
37 | matplotlib==3.10.6
38 | mdurl==0.1.2
39 | mpmath==1.3.0
40 | networkx==3.4.2
41 | ninja==1.13.0
42 | numpy==2.2.6
43 | nvidia-cublas-cu12==12.8.4.1
44 | nvidia-cuda-cupti-cu12==12.8.90
45 | nvidia-cuda-nvrtc-cu12==12.8.93
46 | nvidia-cuda-runtime-cu12==12.8.90
47 | nvidia-cudnn-cu12==9.10.2.21
48 | nvidia-cufft-cu12==11.3.3.83
49 | nvidia-cufile-cu12==1.13.1.3
50 | nvidia-curand-cu12==10.3.9.90
51 | nvidia-cusolver-cu12==11.7.3.90
52 | nvidia-cusparse-cu12==12.5.8.93
53 | nvidia-cusparselt-cu12==0.7.1
54 | nvidia-ml-py==13.580.82
55 | nvidia-nccl-cu12==2.27.3
56 | nvidia-nvjitlink-cu12==12.8.93
57 | nvidia-nvtx-cu12==12.8.90
58 | orjson==3.11.3
59 | packaging==25.0
60 | pandas==2.3.2
61 | peft==0.17.1
62 | pillow
63 | platformdirs==4.4.0
64 | protobuf==6.32.1
65 | psutil==7.1.0
66 | py-cpuinfo==9.0.0
67 | pydantic==2.11.9
68 | pydantic_core==2.33.2
69 | pydub==0.25.1
70 | Pygments==2.19.2
71 | pyparsing==3.2.5
72 | python-dateutil==2.9.0.post0
73 | python-multipart==0.0.20
74 | pytz==2025.2
75 | PyYAML==6.0.2
76 | regex==2025.9.18
77 | requests==2.32.5
78 | rich==14.1.0
79 | ruff==0.13.1
80 | safehttpx==0.1.6
81 | safetensors==0.6.2
82 | semantic-version==2.10.0
83 | sentencepiece==0.2.0
84 | sentry-sdk==2.38.0
85 | shellingham==1.5.4
86 | six==1.17.0
87 | smmap==5.0.2
88 | sniffio==1.3.1
89 | starlette==0.48.0
90 | sympy==1.14.0
91 | tiktoken==0.11.0
92 | tokenizers==0.20.3
93 | tomlkit==0.13.3
94 | torch==2.8.0
95 | torchvision==0.23.0
96 | tqdm==4.67.1
97 | transformers==4.45.2
98 | transformers-stream-generator==0.0.5
99 | triton==3.4.0
100 | typer==0.19.1
101 | typing-inspection==0.4.1
102 | typing_extensions==4.15.0
103 | tzdata==2025.2
104 | urllib3==2.5.0
105 | uvicorn==0.36.0
106 | wandb==0.22.0
107 | websockets==15.0.1
108 |
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | from dataclasses import dataclass, field
5 | import transformers
6 | from torch.utils.data import Dataset, DataLoader
7 | from transformers import CLIPImageProcessor
8 | import pdb
9 | import json
10 | from transformers import AutoProcessor, LlavaForConditionalGeneration
11 | from tqdm import tqdm
12 | import random
13 | import numpy as np
14 | import torch
15 | import torchvision.transforms as T
16 | from PIL import Image
17 | from torchvision.transforms.functional import InterpolationMode
18 | import torch.nn as nn
19 |
20 |
21 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
22 | IMAGENET_STD = (0.229, 0.224, 0.225)
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description="FakeVLM Model Evaluation")
26 |
27 | # Model-specific settings
28 | parser.add_argument("--model_path", default="", type=str)
29 | parser.add_argument("--val_batch_size", default=1, type=int)
30 | parser.add_argument("--workers", default=1, type=int)
31 | parser.add_argument("--data_base_test", default="", type=str)
32 | parser.add_argument("--test_json_file", default="", type=str)
33 | parser.add_argument("--output_path", default="", type=str)
34 | return parser.parse_args()
35 |
36 | class legion_cls_dataset(Dataset):
37 | def __init__(self, args, train=True):
38 | super().__init__()
39 | self.args = args
40 | self.train = train
41 | if train == True:
42 | with open(args.train_json_file, 'r') as f:
43 | self.data = json.load(f)
44 | elif train == False:
45 | with open(args.test_json_file, 'r') as f:
46 | self.data = json.load(f)
47 | self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", revision='a272c74')
48 |
49 | def __len__(self):
50 | return len(self.data)
51 |
52 | def __getitem__(self, idx):
53 | if self.train == True:
54 | img_path = os.path.join(self.args.data_base_train, self.data[idx]['image'])
55 | else:
56 | img_path = os.path.join(self.args.data_base_test, self.data[idx]['image'])
57 | label = self.data[idx]['label']
58 |
59 | image = Image.open(img_path)
60 |
61 | inputs = self.processor(
62 | text=self.data[idx]['conversations'][0]['value'],
63 | images=image,
64 | return_tensors="pt",
65 | padding="max_length",
66 | max_length=1024,
67 | truncation=True
68 | )
69 |
70 | cate = 'deepfake'
71 | # torch.Size([n, 3, 448, 448]), int, int, str, str
72 | return inputs, [label], [img_path], [cate]
73 |
74 |
75 |
76 |
77 | def load_model(args):
78 | print("Loading model...")
79 | model = LlavaForConditionalGeneration.from_pretrained(
80 | args.model_path,
81 | torch_dtype=torch.bfloat16,
82 | low_cpu_mem_usage=True,
83 | use_flash_attention_2=True,
84 | revision='a272c74',
85 | ).eval().cuda()
86 | print("Successfully loaded model from:", args.model_path)
87 | return model
88 |
89 | def calculate_results_acc(results):
90 | acc_results = {}
91 |
92 | for cate in results:
93 | data = results[cate]
94 |
95 | right_real = data['right']['right_real']
96 | right_fake = data['right']['right_fake']
97 | wrong_real = data['wrong']['wrong_real']
98 | wrong_fake = data['wrong']['wrong_fake']
99 |
100 | total_real = right_real + wrong_real
101 | total_fake = right_fake + wrong_fake
102 | total = total_real + total_fake
103 |
104 | acc_total = (right_real + right_fake) / total if total != 0 else 0
105 | acc_real = right_real / total_real if total_real != 0 else 0
106 | acc_fake = right_fake / total_fake if total_fake != 0 else 0
107 |
108 | acc_results[cate] = {
109 | 'total_samples': total,
110 | 'total_accuracy': round(acc_total, 4),
111 | 'real_accuracy': round(acc_real, 4),
112 | 'fake_accuracy': round(acc_fake, 4),
113 | 'confusion_matrix': {
114 | 'right_real': right_real,
115 | 'wrong_real': wrong_real,
116 | 'right_fake': right_fake,
117 | 'wrong_fake': wrong_fake,
118 | }
119 | }
120 |
121 | global_stats = {
122 | 'total_right': sum(r['right']['right_real'] + r['right']['right_fake'] for r in results.values()),
123 | 'total_wrong': sum(r['wrong']['wrong_real'] + r['wrong']['wrong_fake'] for r in results.values())
124 | }
125 | global_stats['global_accuracy'] = global_stats['total_right'] / (global_stats['total_right'] + global_stats['total_wrong'])
126 |
127 | return {
128 | 'category_acc': acc_results,
129 | 'global_stats': global_stats
130 | }
131 |
132 |
133 | def validate(args, model, cls_test_dataloader):
134 | processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", revision='a272c74')
135 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
136 | results = {}
137 | outputs = []
138 | with torch.no_grad():
139 | for inputs, labels, paths, cates in tqdm(cls_test_dataloader):
140 |
141 | inputs["input_ids"] = inputs["input_ids"].squeeze().to(device)
142 | inputs["attention_mask"] = inputs["attention_mask"].squeeze().to(device)
143 | inputs["pixel_values"] = inputs["pixel_values"].squeeze().to(device)
144 | output = model.generate(**inputs, max_new_tokens=256)
145 | pred_cls = []
146 |
147 | for i in range(output.shape[0]):
148 | response = processor.decode(output[i], skip_special_tokens=True).split('?')[-1]
149 | print(response)
150 | outputs.append({"image_path": paths[0][i], "output": response})
151 | # pdb.set_trace()
152 | if 'real' in response.split('.')[0].lower():
153 | pred_cls.append(1)
154 | elif 'fake' in response.split('.')[0].lower():
155 | pred_cls.append(0)
156 | else:
157 | try:
158 | if 'real' in response.split('.')[1].lower():
159 | pred_cls.append(1)
160 | elif 'fake' in response.split('.')[1].lower():
161 | pred_cls.append(0)
162 | else:
163 | print(f"no fake or real in reponse:{response}")
164 | pred_cls.append(random.choice([0, 1]))
165 | except:
166 | print(f"no fake or real in reponse:{response}")
167 | pred_cls.append(random.choice([0, 1]))
168 | for label, pred, cate in zip(labels[0].tolist(), pred_cls, cates[0]):
169 | if cate not in results:
170 | results[cate] = {'right':{'right_fake':0, 'right_real':0}, 'wrong':{'wrong_fake':0, 'wrong_real':0}}
171 | if label == pred:
172 | if label == 1:
173 | results[cate]['right']['right_real'] += 1
174 | else:
175 | results[cate]['right']['right_fake'] += 1
176 | else:
177 | if label == 1:
178 | results[cate]['wrong']['wrong_real'] += 1
179 | else:
180 | results[cate]['wrong']['wrong_fake'] += 1
181 |
182 | os.makedirs('results', exist_ok=True)
183 | with open(args.output_path, "w") as file:
184 | json.dump(outputs, file, indent=2)
185 | acc = calculate_results_acc(results)
186 | print(acc)
187 |
188 |
189 |
190 | def main():
191 | args = parse_args()
192 | model = load_model(args)
193 | model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
194 | cls_test_dataset = legion_cls_dataset(args, train=False)
195 | cls_test_dataloader = DataLoader(
196 | cls_test_dataset,
197 | batch_size=args.val_batch_size,
198 | shuffle=False,
199 | num_workers=args.workers,
200 | pin_memory=True,
201 | )
202 | validate(args, model, cls_test_dataloader)
203 |
204 |
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
209 |
210 |
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | python scripts/eval.py \
2 | --model_path /path/to/your/checkpoint \
3 | --val_batch_size 16 \
4 | --workers 16 \
5 | --output_path results/fakevlm.json \
6 | --test_json_file "/path/to/your/test.json" \
7 | --data_base_test "/path/to/your/test_images"
--------------------------------------------------------------------------------
/scripts/eval_vllm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | from dataclasses import dataclass, field
5 | import transformers
6 | from torch.utils.data import Dataset, DataLoader
7 | from transformers import CLIPImageProcessor
8 | import pdb
9 | import json
10 | from transformers import AutoProcessor, LlavaForConditionalGeneration
11 | from tqdm import tqdm
12 | import random
13 | import numpy as np
14 | import torch
15 | import torchvision.transforms as T
16 | from PIL import Image
17 | from torchvision.transforms.functional import InterpolationMode
18 | from vllm import LLM, SamplingParams
19 | import pdb
20 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
21 | IMAGENET_STD = (0.229, 0.224, 0.225)
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description="FakeVLM Model Testing")
25 |
26 | # Model-specific settings
27 | parser.add_argument("--model_path", default="", type=str)
28 | parser.add_argument("--val_batch_size", default=1, type=int)
29 | parser.add_argument("--workers", default=1, type=int)
30 | parser.add_argument("--data_base_test", default="", type=str)
31 | parser.add_argument("--test_json_file", default="", type=str)
32 | parser.add_argument("--output_path", default="", type=str)
33 | return parser.parse_args()
34 |
35 | class legion_cls_dataset(Dataset):
36 | def __init__(self, args, train=True):
37 | super().__init__()
38 | self.args = args
39 | self.train = train
40 | if train == True:
41 | with open(args.train_json_file, 'r') as f:
42 | self.data = json.load(f)
43 | elif train == False:
44 | with open(args.test_json_file, 'r') as f:
45 | self.data = json.load(f)
46 |
47 | def __len__(self):
48 | return len(self.data)
49 |
50 | def __getitem__(self, idx):
51 | if self.train == True:
52 | img_path = os.path.join(self.args.data_base_train, self.data[idx]['image'])
53 | else:
54 | img_path = os.path.join(self.args.data_base_test, self.data[idx]['image'])
55 |
56 |
57 | label = self.data[idx]['label']
58 |
59 | cate = self.data[idx]['cate']
60 |
61 | # torch.Size([n, 3, 448, 448]), int, int, str, str
62 | return [self.data[idx]['conversations'][0]['value']], [label], [img_path], [cate]
63 |
64 |
65 |
66 |
67 | # change this function to our own to support custom behaviors
68 | def load_model(args):
69 | print("Loading model...")
70 | llm = LLM(model=args.model_path,
71 | dtype="float16",
72 | max_model_len=800,
73 | tensor_parallel_size=torch.cuda.device_count()
74 | )
75 | print("Successfully loaded model from:", args.model_path)
76 | return llm
77 |
78 | def calculate_results_acc(results):
79 | acc_results = {}
80 |
81 | for cate in results:
82 | data = results[cate]
83 |
84 | right_real = data['right']['right_real']
85 | right_fake = data['right']['right_fake']
86 | wrong_real = data['wrong']['wrong_real']
87 | wrong_fake = data['wrong']['wrong_fake']
88 |
89 | total_real = right_real + wrong_real
90 | total_fake = right_fake + wrong_fake
91 | total = total_real + total_fake
92 |
93 | acc_total = (right_real + right_fake) / total if total != 0 else 0
94 | acc_real = right_real / total_real if total_real != 0 else 0
95 | acc_fake = right_fake / total_fake if total_fake != 0 else 0
96 |
97 | acc_results[cate] = {
98 | 'total_samples': total,
99 | 'total_accuracy': round(acc_total, 4),
100 | 'real_accuracy': round(acc_real, 4),
101 | 'fake_accuracy': round(acc_fake, 4),
102 | 'overall': {
103 | 'right_real': right_real,
104 | 'wrong_real': wrong_real,
105 | 'right_fake': right_fake,
106 | 'wrong_fake': wrong_fake,
107 | }
108 | }
109 |
110 | global_stats = {
111 | 'total_right': sum(r['right']['right_real'] + r['right']['right_fake'] for r in results.values()),
112 | 'total_wrong': sum(r['wrong']['wrong_real'] + r['wrong']['wrong_fake'] for r in results.values())
113 | }
114 | global_stats['global_accuracy'] = global_stats['total_right'] / (global_stats['total_right'] + global_stats['total_wrong'])
115 |
116 | return {
117 | 'category_acc': acc_results,
118 | 'global_stats': global_stats
119 | }
120 |
121 |
122 |
123 | def validate(args, model, cls_test_dataloader):
124 | results = {}
125 | output_result = []
126 | sampling_params = SamplingParams(
127 | max_tokens=4096,
128 | temperature=0,
129 | )
130 |
131 | with torch.no_grad():
132 | for questions, labels, imgs, cates in tqdm(cls_test_dataloader):
133 | inputs = []
134 |
135 | for question, img in zip(questions[0], imgs[0]):
136 | inputs.append({
137 | "prompt": question,
138 | "multi_modal_data": {"image": Image.open(img)},
139 | })
140 | # pdb.set_trace()
141 | outputs = model.generate(inputs, sampling_params=sampling_params)
142 |
143 | pred_cls = []
144 |
145 | for i, output in enumerate(outputs):
146 |
147 |
148 |                 # save the raw model output for this image
149 | output_result.append({'id':imgs[0][i], 'caption':output.outputs[0].text})
150 |
151 |
152 | response = output.outputs[0].text
153 | if 'real' in response.split('.')[0].lower():
154 | pred_cls.append(1)
155 | elif 'fake' in response.split('.')[0].lower():
156 | pred_cls.append(0)
157 | else:
158 | try:
159 | if 'real' in response.split('.')[1].lower():
160 | pred_cls.append(1)
161 | elif 'fake' in response.split('.')[1].lower():
162 | pred_cls.append(0)
163 | else:
164 |                             print(f"no fake or real in response: {response}")
165 |                             pred_cls.append(random.choice([0, 1]))
166 |                     except IndexError:
167 |                         print(f"no fake or real in response: {response}")
168 |                         pred_cls.append(random.choice([0, 1]))
169 |
170 | for label, pred, cate in zip(labels[0].tolist(), pred_cls, cates[0]):
171 | if cate not in results:
172 | results[cate] = {'right':{'right_fake':0, 'right_real':0}, 'wrong':{'wrong_fake':0, 'wrong_real':0}}
173 | if label == pred:
174 | if label == 1:
175 | results[cate]['right']['right_real'] += 1
176 | else:
177 | results[cate]['right']['right_fake'] += 1
178 | else:
179 | if label == 1:
180 | results[cate]['wrong']['wrong_real'] += 1
181 | else:
182 | results[cate]['wrong']['wrong_fake'] += 1
183 |
184 |     # write the raw generations to the output file
185 |     os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True)
186 | with open(args.output_path, "w") as file:
187 | json.dump(output_result, file, indent=2)
188 |
189 | acc = calculate_results_acc(results)
190 | print(acc)
191 |
192 |
193 |
194 | def main():
195 | args = parse_args()
196 | model = load_model(args)
197 | cls_test_dataset = legion_cls_dataset(args, train=False)
198 | cls_test_dataloader = DataLoader(
199 | cls_test_dataset,
200 | batch_size=args.val_batch_size,
201 | shuffle=False,
202 | num_workers=args.workers,
203 | pin_memory=True,
204 | )
205 | validate(args, model, cls_test_dataloader)
206 |
207 |
208 |
209 |
210 | if __name__ == "__main__":
211 | main()
212 |
213 |
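214 | # Example invocation (illustrative; paths are placeholders, and the flag names for
215 | # the arguments not shown above are assumed to mirror the attribute names used in
216 | # the code, e.g. --model_path, --data_base_test, --val_batch_size, --workers):
217 | #   python scripts/eval_vllm.py --model_path path/to/fakevlm_checkpoint \
218 | #       --test_json_file path/to/test.json --data_base_test path/to/test_images \
219 | #       --output_path results/eval_vllm_output.json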
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | NUM_GPUS=8
2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
3 | DISTRIBUTED_ARGS="
4 | --nnodes=1 \
5 | --nproc_per_node ${NUM_GPUS} \
6 | --rdzv_backend c10d \
7 | --rdzv_endpoint localhost:0
8 | "
9 |
10 | # adjust the following to your own setup
11 | MODEL_ID=llava-1.5-7b # model id
12 | TRAIN_DATA_PATH="path/to/train.json" # path to the training data json file
13 | EVAL_DATA_PATH="path/to/test.json" # path to the evaluation data json file (optional)
14 | IMAGE_FOLDER="path/to/images" # path to the image root folder; if provided, the image paths in the json should be relative
15 | VIDEO_FOLDER="" # path to the video root folder; if provided, the video paths in the json should be relative
16 | NUM_FRAMES=8 # how many frames are sampled from each video
17 |
18 | TRAIN_VISION_ENCODER=True # whether to train the vision encoder
19 | USE_VISION_LORA=False # whether to use LoRA for the vision encoder (only effective when `TRAIN_VISION_ENCODER` is True)
20 | TRAIN_VISION_PROJECTOR=True # whether to train the vision projector (only full fine-tuning is supported)
21 |
22 | USE_LORA=False # whether to use LoRA for the LLM
23 | Q_LORA=False # whether to use Q-LoRA for the LLM; only effective when `USE_LORA` is True
24 | LORA_R=8 # LoRA rank (applies to both the LLM and the vision encoder)
25 | LORA_ALPHA=8 # LoRA alpha (applies to both the LLM and the vision encoder)
26 |
27 | RUN_ID=${MODEL_ID}-fakevlm # a custom run id that determines the checkpoint folder and wandb run name
28 |
29 | DS_STAGE=zero2 # deepspeed stage; < zero2 | zero3 >
30 | PER_DEVICE_BATCH_SIZE=32 # batch size per GPU
31 | GRAD_ACCUM=1 # gradient accumulation steps; effective global batch size = NUM_GPUS x PER_DEVICE_BATCH_SIZE x GRAD_ACCUM = 8 x 32 x 1 = 256
32 | NUM_EPOCHS=2 # number of training epochs
33 |
34 | LR=2e-5 # learning rate
35 | MODEL_MAX_LEN=1024 # maximum input length of the model
36 |
37 |
38 | torchrun $DISTRIBUTED_ARGS train.py \
39 | --model_id $MODEL_ID \
40 | --data_path $TRAIN_DATA_PATH \
41 | --eval_data_path $EVAL_DATA_PATH \
42 | --image_folder $IMAGE_FOLDER \
43 | --video_folder $VIDEO_FOLDER \
44 | --num_frames $NUM_FRAMES \
45 | --output_dir ./checkpoints/$RUN_ID \
46 | --report_to wandb \
47 | --run_name $RUN_ID \
48 | --deepspeed ./ds_configs/${DS_STAGE}.json \
49 | --bf16 True \
50 | --num_train_epochs $NUM_EPOCHS \
51 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
52 | --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
53 | --gradient_accumulation_steps $GRAD_ACCUM \
54 | --eval_strategy "no" \
55 | --save_strategy "epoch" \
56 | --save_total_limit 1 \
57 | --learning_rate ${LR} \
58 | --weight_decay 0. \
59 | --warmup_ratio 0.03 \
60 | --lr_scheduler_type "cosine" \
61 | --logging_steps 1 \
62 | --tf32 True \
63 | --model_max_length $MODEL_MAX_LEN \
64 | --gradient_checkpointing True \
65 | --dataloader_num_workers 4 \
66 | --train_vision_encoder $TRAIN_VISION_ENCODER \
67 | --use_vision_lora $USE_VISION_LORA \
68 | --train_vision_projector $TRAIN_VISION_PROJECTOR \
69 | --use_lora $USE_LORA \
70 | --q_lora $Q_LORA \
71 | --lora_r $LORA_R \
72 | --lora_alpha $LORA_ALPHA
73 |
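74 | # For a single-GPU run (illustrative), set NUM_GPUS=1 and CUDA_VISIBLE_DEVICES=0,
75 | # and increase GRAD_ACCUM to keep the effective global batch size comparable.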
--------------------------------------------------------------------------------
/supported_models.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | from collections import OrderedDict
3 |
4 | from collators import COLLATORS
5 | from datasets import TO_LOAD_IMAGE
6 | from loaders import LOADERS
7 |
8 |
9 | MODULE_KEYWORDS: Dict[str, Dict[str, List]] = {
10 | "llava-1.5": {
11 | "vision_encoder": ["vision_tower"],
12 | "vision_projector": ["multi_modal_projector"],
13 | "llm": ["language_model"]
14 | },
15 | }
16 |
17 |
18 | MODEL_HF_PATH = OrderedDict()
19 |
20 | MODEL_FAMILIES = OrderedDict()
21 |
22 |
23 | def register_model(model_id: str, model_family_id: str, model_hf_path: str) -> None:
24 | if model_id in MODEL_HF_PATH or model_id in MODEL_FAMILIES:
25 | raise ValueError(f"Duplicate model_id: {model_id}")
26 | MODEL_HF_PATH[model_id] = model_hf_path
27 | MODEL_FAMILIES[model_id] = model_family_id
28 |
29 | register_model(
30 | model_id="llava-1.5-7b",
31 | model_family_id="llava-1.5",
32 | model_hf_path="llava-hf/llava-1.5-7b-hf"
33 | )
34 |
35 | # sanity check
36 | for model_family_id in MODEL_FAMILIES.values():
37 | assert model_family_id in COLLATORS, f"Collator not found for model family: {model_family_id}"
38 | assert model_family_id in LOADERS, f"Loader not found for model family: {model_family_id}"
39 | assert model_family_id in MODULE_KEYWORDS, f"Module keywords not found for model family: {model_family_id}"
40 | assert model_family_id in TO_LOAD_IMAGE, f"Image loading specification not found for model family: {model_family_id}"
41 |
42 |
43 | if __name__ == "__main__":
44 | temp = "Model ID"
45 | ljust = 30
46 | print("Supported models:")
47 | print(f" {temp.ljust(ljust)}: HuggingFace Path")
48 | print(" ------------------------------------------------")
49 | for model_id, model_hf_path in MODEL_HF_PATH.items():
50 | print(f" {model_id.ljust(ljust)}: {model_hf_path}")
51 |
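52 | # To register another checkpoint of an already-supported family (illustrative; the
53 | # model id and HF path below are hypothetical examples), call register_model the
54 | # same way -- the sanity check above requires that the family already has a
55 | # collator, loader, module-keyword, and image-loading entry:
56 | #
57 | # register_model(
58 | #     model_id="llava-1.5-13b",
59 | #     model_family_id="llava-1.5",
60 | #     model_hf_path="llava-hf/llava-1.5-13b-hf"
61 | # )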
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["WANDB_PROJECT"] = "lmms-ft"
3 | from dataclasses import asdict
4 | import math
5 | from pathlib import Path
6 | from typing import List, Optional
7 | import yaml
8 |
9 | from accelerate.utils import DistributedType
10 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
11 | import torch
12 | import transformers
13 | from transformers import Trainer, deepspeed
14 |
15 |
16 | from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments
17 | from collators import COLLATORS
18 | from datasets import LazySupervisedDataset
19 | from loaders import LOADERS
20 | from supported_models import MODULE_KEYWORDS
21 | from utils import (
22 | rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer,
23 | get_peft_state_maybe_zero_3, TrainerWithCustomSampler
24 | )
25 |
26 |
27 | def train():
28 | parser = transformers.HfArgumentParser(
29 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
30 | )
31 | model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()
32 |
33 | # dumping arguments
34 | output_dir = getattr(training_args, 'output_dir', None)
35 | assert output_dir is not None, "output_dir is required"
36 | args_dir = Path(output_dir) / "arguments"
37 | args_dir.mkdir(parents=True, exist_ok=True)
38 | yaml.dump(asdict(model_args), open(args_dir / "model.yaml", "w"))
39 | yaml.dump(asdict(data_args), open(args_dir / "data.yaml", "w"))
40 | yaml.dump(asdict(training_args), open(args_dir / "training.yaml", "w"))
41 | yaml.dump(asdict(lora_args), open(args_dir / "lora.yaml", "w"))
42 |
43 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
44 | if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
45 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
46 |
47 | device_map = None
48 | if lora_args.q_lora:
49 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None
50 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
51 |             raise ValueError("FSDP and ZeRO-3 are incompatible with QLoRA.")
52 |
53 | # llm quantization config (for q-lora)
54 | bnb_config = None
55 | if lora_args.use_lora and lora_args.q_lora:
56 | from transformers import BitsAndBytesConfig
57 | rank0_print("Quantization for LLM enabled...")
58 | bnb_config = BitsAndBytesConfig(
59 | load_in_4bit=True,
60 | bnb_4bit_compute_dtype=compute_dtype,
61 | bnb_4bit_quant_type="nf4",
62 | )
63 |
64 | # load model, tokenizer, processor
65 | rank0_print("Loading model, tokenizer, processor...")
66 | loader = LOADERS[model_args.model_family_id](
67 | model_hf_path=model_args.model_hf_path,
68 | model_local_path=model_args.model_local_path,
69 | compute_dtype=compute_dtype,
70 | bnb_config=bnb_config,
71 | use_flash_attn=training_args.use_flash_attn,
72 | device_map=device_map,
73 | )
74 | model, tokenizer, processor, config = loader.load()
75 | tokenizer.model_max_length = training_args.model_max_length
76 |
77 | if training_args.gradient_checkpointing:
78 | model.enable_input_require_grads()
79 |
80 | # freeze certain params
81 | vision_encoder_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_encoder"]
82 | if not training_args.train_vision_encoder:
83 |         rank0_print("Vision encoder is frozen... including:")
84 | for module in vision_encoder_keys:
85 | rank0_print(f"\t{module}")
86 | eval(f"model.{module}").requires_grad_(False)
87 |
88 | vision_projector_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_projector"]
89 | if not training_args.train_vision_projector:
90 |         rank0_print("Vision projector is frozen... including:")
91 | for module in vision_projector_keys:
92 | rank0_print(f"\t{module}")
93 | eval(f"model.{module}").requires_grad_(False)
94 |
95 | # other components preparation (e.g., image_newline, vision_resampler)
96 | # we will just freeze these
97 | if "others" in MODULE_KEYWORDS[model_args.model_family_id]:
98 |         rank0_print("Other multimodal components are frozen... including:")
99 | for other_key in MODULE_KEYWORDS[model_args.model_family_id]["others"]:
100 | rank0_print(f"\t{other_key}")
101 | eval(f"model.{other_key}").requires_grad_(False)
102 |
103 | # lora preparation
104 | llm_keys = MODULE_KEYWORDS[model_args.model_family_id]["llm"]
105 | if not (lora_args.use_lora or (training_args.train_vision_encoder and lora_args.use_vision_lora)):
106 | rank0_print("No LoRA enabled...")
107 | else:
108 | named_modules = {n: m for n, m in model.named_modules()}
109 | lora_modules = []
110 | full_modules = []
111 |
112 | if training_args.train_vision_encoder and lora_args.use_vision_lora:
113 | rank0_print("LoRA for vision encoder enabled...")
114 | lora_modules.extend(find_all_linear_names(named_modules, vision_encoder_keys))
115 | elif training_args.train_vision_encoder:
116 | rank0_print("Vision encoder will be fully trained...")
117 | full_modules.extend(vision_encoder_keys)
118 |
119 | if lora_args.use_lora:
120 | rank0_print("LoRA for LLM enabled...")
121 | lora_modules.extend(find_all_linear_names(named_modules, llm_keys))
122 | else:
123 | rank0_print("LLM will be fully trained...")
124 | full_modules.extend(llm_keys)
125 |
126 | if training_args.train_vision_projector:
127 | rank0_print("Vision projector will be fully trained...")
128 | full_modules.extend(vision_projector_keys)
129 |
130 | lora_config = LoraConfig(
131 | r=lora_args.lora_r,
132 | lora_alpha=lora_args.lora_alpha,
133 | target_modules=lora_modules,
134 | modules_to_save=full_modules,
135 | lora_dropout=lora_args.lora_dropout,
136 | bias=lora_args.lora_bias,
137 | task_type="CAUSAL_LM",
138 | )
139 |
140 | if lora_args.q_lora:
141 | model = prepare_model_for_kbit_training(
142 | model, use_gradient_checkpointing=training_args.gradient_checkpointing
143 | )
144 |
145 | model = get_peft_model(model, lora_config)
146 |
150 |
151 | # print trainable parameters for inspection
152 | rank0_print("Trainable parameters:")
153 | for name, param in model.named_parameters():
154 | if param.requires_grad:
155 | rank0_print(f"\t{name}")
156 |
157 | # load data
158 | rank0_print("Loading data...")
159 | train_dataset = LazySupervisedDataset(
160 | data_path=data_args.data_path,
161 | image_folder=data_args.image_folder,
162 | video_folder=data_args.video_folder,
163 | num_frames=data_args.num_frames,
164 | model_family_id=model_args.model_family_id,
165 | user_key=data_args.user_key,
166 | assistant_key=data_args.assistant_key
167 | )
168 | if data_args.eval_data_path:
169 | eval_dataset = LazySupervisedDataset(
170 | data_path=data_args.eval_data_path,
171 | image_folder=data_args.image_folder,
172 | video_folder=data_args.video_folder,
173 | num_frames=data_args.num_frames,
174 | model_family_id=model_args.model_family_id,
175 | user_key=data_args.user_key,
176 | assistant_key=data_args.assistant_key
177 | )
178 | else:
179 | eval_dataset = None
180 | training_args.eval_strategy = "no"
181 |
182 | # data collator
183 | data_collator = COLLATORS[model_args.model_family_id](
184 | config=config,
185 | tokenizer=tokenizer,
186 | processor=processor,
187 | mask_question_tokens=training_args.mask_question_tokens
188 | )
189 |
190 | # trainer
191 | trainer = TrainerWithCustomSampler(
192 | model=model,
193 | args=training_args,
194 | data_collator=data_collator,
195 | train_dataset=train_dataset,
196 | eval_dataset=eval_dataset,
197 | )
198 | trainer.train()
199 | trainer.save_state()
200 |
201 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
202 |
203 |
204 | if __name__ == "__main__":
205 | train()
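206 |
207 | # train.py is normally launched with torchrun via scripts/train.sh. A minimal
208 | # single-process sketch for debugging (illustrative; paths are placeholders and
209 | # the remaining arguments are assumed to keep their defaults):
210 | #   python train.py --model_id llava-1.5-7b --data_path path/to/train.json \
211 | #       --image_folder path/to/images --output_dir ./checkpoints/debug --bf16 True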
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Dict, Optional
3 |
4 | from deepspeed import zero
5 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from torch.utils.data import Sampler
10 | import transformers
11 | from transformers import Trainer
12 | from transformers.trainer import has_length
13 |
14 |
15 | class NoTextOnlyBatchSampler(Sampler):
16 | r"""
17 | Sampler that tries its best to sample batches such that no batch has only
18 | text (unimodal) data. This is necessary for training with deepspeed.
19 | """
20 |
21 | def __init__(
22 | self,
23 | batch_size: int,
24 | world_size: int,
25 | is_text_only: Optional[List[bool]] = None,
26 | generator=None,
27 | ):
28 | if is_text_only is None:
29 | raise ValueError("`is_text_only` must be provided.")
30 |
31 | self.batch_size = batch_size
32 | self.world_size = world_size
33 | self.is_text_only = is_text_only
34 | self.generator = generator
35 | self.mega_batch_size = batch_size * world_size
36 |
37 | def __len__(self):
38 | return len(self.is_text_only)
39 |
40 | def __iter__(self):
41 | # mm: multimodal, entry that has both text and image/video
42 | # uni: unimodal, entry that has only text
43 | mm_indices = [i for i, is_text_only in enumerate(self.is_text_only) if not is_text_only]
44 | uni_indices = [i for i, is_text_only in enumerate(self.is_text_only) if is_text_only]
45 |
46 | num_batches = math.ceil((len(mm_indices) + len(uni_indices)) / self.mega_batch_size)
47 | if len(mm_indices) < num_batches:
48 | raise ValueError(
49 |                 f"{len(mm_indices)} multimodal entries, {num_batches} batches. "
50 | "Not enough multimodal data in the dataset, or the batch size is too small. "
51 | "There will be at least one batch that is text-only, which doesn't work with deepspeed. "
52 | "Try increasing the batch size first."
53 | )
54 |
55 | # shuffle indices
56 |         mm_indices = [mm_indices[i] for i in torch.randperm(len(mm_indices), generator=self.generator).tolist()]
57 |         uni_indices = [uni_indices[i] for i in torch.randperm(len(uni_indices), generator=self.generator).tolist()]
58 |
59 | # distribute indices into batches
60 | num_uni_indices_in_mega_batch = [len(uni_indices) // num_batches] * num_batches
61 | for i in range(len(uni_indices) % num_batches):
62 | num_uni_indices_in_mega_batch[i] += 1
63 |
64 | mega_batches = []
65 | cur_uni_index = 0
66 | cur_mm_index = 0
67 | for i, num_uni_indices in enumerate(num_uni_indices_in_mega_batch):
68 | mega_batch = []
69 | mega_batch.extend(uni_indices[cur_uni_index:cur_uni_index + num_uni_indices])
70 | cur_uni_index += num_uni_indices
71 | assert len(mega_batch) < self.mega_batch_size
72 |
73 | if i < num_batches - 1:
74 | increment = self.mega_batch_size - len(mega_batch)
75 | mega_batch.extend(
76 | mm_indices[cur_mm_index:cur_mm_index + increment]
77 | )
78 | cur_mm_index += increment
79 | else: # last batch
80 | mega_batch.extend(mm_indices[cur_mm_index:])
81 | assert len(mega_batch) <= self.mega_batch_size, "Last batch is too big."
82 |
83 | mega_batches.append(mega_batch)
84 |
85 | mega_batch_indices = torch.randperm(len(mega_batches), generator=self.generator)
86 | mega_batches = [mega_batches[i] for i in mega_batch_indices]
87 | indices = [i for mega_batch in mega_batches for i in mega_batch]
88 | return iter(indices)
89 |
90 |
91 | class TrainerWithCustomSampler(Trainer):
92 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
93 | if self.train_dataset is None or not has_length(self.train_dataset):
94 | return None
95 |
96 | is_text_only = self.train_dataset.is_text_only
97 | return NoTextOnlyBatchSampler(
98 | self.args.train_batch_size,
99 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
100 | is_text_only=is_text_only,
101 | )
102 |
103 | def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[torch.utils.data.Sampler]:
104 | is_text_only = eval_dataset.is_text_only
105 | return NoTextOnlyBatchSampler(
106 | self.args.eval_batch_size,
107 | world_size=self.args.world_size,
108 | is_text_only=is_text_only,
109 | )
110 |
111 |
112 | def find_all_linear_names(named_modules: Dict, target_modules: List[str]):
113 | cls = torch.nn.Linear
114 | lora_module_names = set()
115 | for name, module in named_modules.items():
116 | if not any([module_name in name for module_name in target_modules]):
117 | continue
118 |
119 | if isinstance(module, cls):
120 | lora_module_names.add(name)
121 |
122 | for name in list(lora_module_names):
123 | if 'lm_head' in name: # needed for 16-bit
124 | lora_module_names.remove(name)
125 |
126 | return list(lora_module_names)
127 |
128 |
129 | def rank0_print(*args):
130 |     # print only on rank 0 in distributed runs; always print in single-process runs
131 |     if not dist.is_initialized() or dist.get_rank() == 0:
132 |         print(*args)
133 |
134 |
135 | def maybe_zero_3(param):
136 | if hasattr(param, "ds_id"):
137 | with zero.GatheredParameters([param]):
138 | param = param.data.detach().cpu().clone()
139 | else:
140 | param = param.detach().cpu().clone()
141 | return param
142 |
143 |
144 | # Borrowed from peft.utils.get_peft_model_state_dict
145 | def get_peft_state_maybe_zero_3(named_params, bias):
146 | if bias == "none":
147 | to_return = {k: t for k, t in named_params if "lora_" in k}
148 | elif bias == "all":
149 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
150 | elif bias == "lora_only":
151 | to_return = {}
152 | maybe_lora_bias = {}
153 | lora_bias_names = set()
154 | for k, t in named_params:
155 | if "lora_" in k:
156 | to_return[k] = t
157 | bias_name = k.split("lora_")[0] + "bias"
158 | lora_bias_names.add(bias_name)
159 | elif "bias" in k:
160 | maybe_lora_bias[k] = t
161 |         for k, t in maybe_lora_bias.items():
162 |             if k in lora_bias_names:
163 |                 to_return[k] = t
164 | else:
165 | raise NotImplementedError
166 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
167 | return to_return
168 |
169 |
170 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
171 | """Collects the state dict and dump to disk."""
172 | if trainer.deepspeed:
173 | torch.cuda.synchronize()
174 | trainer.save_model(output_dir)
175 | return
176 |
177 | state_dict = trainer.model.state_dict()
178 | if trainer.args.should_save:
179 | cpu_state_dict = {
180 | key: value.cpu()
181 | for key, value in state_dict.items()
182 | }
183 | del state_dict
184 | trainer._save(output_dir, state_dict=cpu_state_dict)
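185 |
186 |
187 | # Worked example for NoTextOnlyBatchSampler (illustrative numbers): with
188 | # batch_size=2 and world_size=2 the mega-batch size is 4. Given 8 samples of
189 | # which 2 are text-only (and 6 multimodal), __iter__ builds ceil(8 / 4) = 2
190 | # mega-batches and assigns 1 text-only + 3 multimodal indices to each, spreading
191 | # the text-only entries out so that no mega-batch is text-only.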
--------------------------------------------------------------------------------