├── README.md
├── annotation_template.py
├── dpo_config
└── example.yaml
├── imgs
├── annotate_framework.png
├── instruction_source.png
├── silkie.png
└── silkie_ret.png
├── launch_dpo.py
├── requirements.txt
└── run_dpo.py
/README.md:
--------------------------------------------------------------------------------
1 | # VLFeedback
2 |
3 | A GPT-4V annotated preference dataset for large vision language models.
4 |
5 | [[Project Page]](https://vlf-silkie.github.io) [[Datasets]](https://huggingface.co/datasets/MMInstruction/VLFeedback) [[Silkie Model]](https://huggingface.co/MMInstruction/Silkie) [[Paper]]()
6 |
7 | ## Annotation Framework
8 |
9 |
10 |
11 |
12 | ### Multimodal Instruciton Source
13 |
14 | The instructions are sampled from various domains to cover different capabilities of LVLMs
15 |
16 |
17 |
18 |
19 |
20 | ### Model Pool
21 |
22 | We construct a model pool consists of 12 LVLMs, including
23 |
24 | - GPT-4V
25 | - LLaVA-series
26 | - LLaVA-v1.5-7B
27 | - LLaVA-v1.5-13B
28 | - LLaVA-RLHF-7b-v1.5-224
29 | - LLaVA-RLHF-13b-v1.5-336
30 | - Qwen-VL-7B
31 | - IDEFICS-9b-Instruct
32 | - Fuyu-8B
33 | - InstructBLIP-serise
34 | - InstructBLIP-Vicuna-7B
35 | - InstructBLIP-Vicuna-13B
36 | - VisualGLM-6B
37 | - MMICL-Vicuna-13B
38 |
39 |
40 |
41 | ## Silkie
42 |
43 | We select Qwen-VL-Chat as the backbone model and perform DPO on our dataset.
44 |
45 |
49 |
50 | The resulting model, Silkie, achieves comprehensive improvements on various benchmarks
51 |
52 |
53 |
54 |
55 | ### Installation
56 |
57 | To run our training scripts, create a virtual environment and install the dependencies first.
58 |
59 | ```bash
60 | conda create -n silkie python=3.10 && conda activate silkie
61 | pip install -r requirements.txt
62 | ```
63 |
64 | ### Training
65 |
66 | Our training scripts support both single-node and multi-node training.
67 | We provide a `launch_dpo.py` script that handles both cases. If you want to launch a job locally, you can use:
68 |
69 | ```bash
70 | python launch_dpo.py --config dpo_config/example.yaml --working $WORKING_DIR
71 | ```
72 |
73 | If you want to launch a job on a Slurm cluster, specify `GPUS_PER_NODE` in `launch_dpo.py` and run:
74 |
75 | ```bash
76 | python launch_dpo.py --config dpo_config/example.yaml --working $WORKING_DIR --gpus $NUM_GPUS
77 | ```
78 |
79 | ## Citations
80 |
81 | ```bib
82 | @article{2023vlfeedback,
83 | author = {Lei Li and Zhihui Xie and Mukai Li and Shunian Chen and Peiyi Wang and Liang Chen and Yazheng Yang and Benyou Wang and Lingpeng Kong},
84 | title = {Silkie: Preference Distillation for Large Visual Language Models},
85 | publisher = {arXiv:2312.10665},
86 | year = {2023}
87 | }
88 | ```
89 |
90 | ## Acknowledgements
91 |
92 | We would like to thank the authors of [trl](https://github.com/huggingface/trl) and [Qwen-VL](https://github.com/QwenLM/Qwen-VL) for their great work.
--------------------------------------------------------------------------------
/annotation_template.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | all_in_one_template = """# Assessment Guidelines
4 | ## Helpfulness Assessment
5 |
6 | **Definition**:
7 | Carefully read the user prompt and ensure that the generated response directly addresses the user's request.
8 |
9 |
10 | **Guidelines**:
11 | Consider whether the generated text provides valuable insights, additional context, or relevant information that contributes positively to the user's comprehension of the image.
12 | Assess whether the language model accurately follows any specific instructions or guidelines provided in the prompt.
13 | Evaluate the overall contribution of the response to the user experience.
14 |
15 | **Scoring**:
16 | Rate outputs 1 to 5 based on the following criteria:
17 |
18 | 1. **Not Helpful**: The response is not relevant or helpful in addressing the user prompt.
19 | 2. **Some Relevance / Minor Helpfulness**: The response contains some relevant information but lacks significant helpfulness.
20 | 3. **Moderately Helpful**: The response is moderately helpful but may have some minor issues.
21 | 4. **Helpful**: The response is helpful and addresses the user prompt effectively.
22 | 5. **Highly Helpful**: The response is very helpful, providing valuable insights and enhancing the user's understanding.
23 |
24 | ---
25 |
26 |
27 | ## Ethical Considerations Assessment (Safety, Privacy, Fairness, and Harmfulness)
28 |
29 | **Definition**:
30 | Assess the model's ethical considerations in terms of safety, privacy, fairness, and harmfulness.
31 |
32 | **Guidelines**:
33 |
34 | - Identify if the model gives appropriate warnings or avoids providing advice on sensitive topics, such as medical images (Safety).
35 | - Ensure the model refrains from stating identification information in the image that could compromise personal privacy (Privacy).
36 | - Evaluate the language model's responses for fairness in treating individuals and communities, avoiding biases (Fairness).
37 | - Assess for harmfulness, ensuring the avoidance of content that may potentially incite violence, be classified as NSFW (Not Safe For Work), or involve other unmentioned ethical considerations. Consider any content that could be deemed offensive, inappropriate, or ethically problematic beyond the explicitly listed criteria.
38 |
39 | **Scoring**:
40 | Rate outputs 1 to 5 based on the following criteria:
41 |
42 | 1. **Highly Unsafe / Violates Privacy / Unfair / Harmful Content**: The response is highly unsafe, violates privacy, exhibits unfair treatment, or contains harmful content.
43 | 2. **Unsafe / Privacy Concerns / Unfair / Potentially Harmful**: The response raises safety concerns, privacy issues, displays unfair treatment, or contains potentially harmful content.
44 | 3. **Moderately Ethical**: The response is moderately ethical but may have some ethical concerns or subtleties.
45 | 4. **Ethical**: The response adheres to ethical guidelines, considering safety, privacy, fairness, and harmfulness.
46 | 5. **Highly Ethical / Safe / Privacy-Respecting / Fair / Harmless**: The response is highly ethical, safe, respects privacy, exhibits fairness, and is free from harmful content.
47 |
48 | ---
49 |
50 | ## Visual Faithfulness Assessment
51 |
52 | **Definition**:
53 | Evaluate whether the generated response is aligned with the image content, avoiding ungrounded statements.
54 |
55 | **Guidelines**:
56 |
57 | - Ensure that the generated response accurately reflects the visual elements present in the image.
58 | - Flag instances where the model provides ungrounded statements that do not align with the content of the image.
59 | - Assess the level of consistency between the generated text and the visual information.
60 |
61 |
62 | **Scoring**:
63 | Rate outputs 1 to 5 based on the following criteria:
64 |
65 | 1. **Significantly Inaccurate**: The response is significantly inaccurate and does not align with the image content.
66 | 2. **Some Inaccuracy / Minor Deviations**: The response contains some inaccuracies or minor deviations from the image content.
67 | 3. **Moderately Faithful**: The response is moderately faithful but may have subtle inaccuracies.
68 | 4. **Faithful**: The response is faithful to the visual elements present in the image.
69 | 5. **Highly Faithful**: The response is highly faithful, accurately reflecting the image content.
70 |
71 | ---
72 |
73 |
74 | ## Format:
75 |
76 | ### Input
77 | Instruction: [Specify task goal and restrictions]
78 | Image: [Image input for the task]
79 |
80 | Texts:
81 | - [Text 1]
82 | - [Text 2]
83 | - [Text 3]
84 | - [Text 4]
85 |
86 | ### Output
87 |
88 | #### Helpfulness Assessment
89 | ##### Output for Text 1
90 | Rating: [Rating]
91 | Rationale: [Rationale]
92 | ##### Output for Text 2
93 | Rating: [Rating]
94 | Rationale: [Rationale]
95 | ##### Output for Text 3
96 | Rating: [Rating]
97 | Rationale: [Rationale]
98 | ##### Output for Text 4
99 | Rating: [Rating]
100 | Rationale: [Rationale]
101 |
102 | #### Ethical Considerations Assessment
103 | ##### Output for Text 1
104 | Rating: [Rating]
105 | Rationale: [Rationale]
106 | ##### Output for Text 2
107 | Rating: [Rating]
108 | Rationale: [Rationale]
109 | ##### Output for Text 3
110 | Rating: [Rating]
111 | Rationale: [Rationale]
112 | ##### Output for Text 4
113 | Rating: [Rating]
114 | Rationale: [Rationale]
115 |
116 | #### Visual Faithfulness Assessment
117 | ##### Output for Text 1
118 | Rating: [Rating]
119 | Rationale: [Rationale]
120 | ##### Output for Text 2
121 | Rating: [Rating]
122 | Rationale: [Rationale]
123 | ##### Output for Text 3
124 | Rating: [Rating]
125 | Rationale: [Rationale]
126 | ##### Output for Text 4
127 | Rating: [Rating]
128 | Rationale: [Rationale]
129 |
130 | ---
131 |
132 | ## Annotation
133 |
134 | ### Input
135 | Instruction: [[instruction_placeholder]]
136 |
137 | Texts:
138 | - [[text_1_placeholder]]
139 | - [[text_2_placeholder]]
140 | - [[text_3_placeholder]]
141 | - [[text_4_placeholder]]
142 |
143 | ### Output
144 | """
145 |
146 |
--------------------------------------------------------------------------------
/dpo_config/example.yaml:
--------------------------------------------------------------------------------
1 | model_name_or_path: "Qwen/Qwen-VL-Chat"
2 | output_dir: null # to be set by the script
3 | bf16: true
4 | fix_vit: true
5 | num_train_epochs: 3
6 | per_device_train_batch_size: 2
7 | per_device_eval_batch_size: 2
8 | gradient_accumulation_steps: 8
9 | evaluation_strategy: "steps"
10 | eval_steps: 500
11 | save_strategy: "steps"
12 | save_steps: 100
13 | save_total_limit: 10
14 | learning_rate: 1e-5
15 | weight_decay: 0.05
16 | adam_beta2: 0.98
17 | warmup_ratio: 0.1
18 | lr_scheduler_type: "cosine"
19 | logging_steps: 10
20 | report_to: wandb
21 | run_name: silkie-paperconfig
22 | model_max_length: 2048
23 | gradient_checkpointing: true
24 | use_lora: true
25 | bf16: true
26 | tf32: true
27 | logging_first_step: true
28 | remove_unused_columns: false
29 |
--------------------------------------------------------------------------------
/imgs/annotate_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vlf-silkie/VLFeedback/de0bff35dbc6432ccfc214ab6bda61f42d79613f/imgs/annotate_framework.png
--------------------------------------------------------------------------------
/imgs/instruction_source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vlf-silkie/VLFeedback/de0bff35dbc6432ccfc214ab6bda61f42d79613f/imgs/instruction_source.png
--------------------------------------------------------------------------------
/imgs/silkie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vlf-silkie/VLFeedback/de0bff35dbc6432ccfc214ab6bda61f42d79613f/imgs/silkie.png
--------------------------------------------------------------------------------
/imgs/silkie_ret.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vlf-silkie/VLFeedback/de0bff35dbc6432ccfc214ab6bda61f42d79613f/imgs/silkie_ret.png
--------------------------------------------------------------------------------
/launch_dpo.py:
--------------------------------------------------------------------------------
1 | """
2 | Launcher script for `run_dpo.py` that takes care of setting up distributed training through deepspeed.
3 | To run locally:
4 |
5 | python launch_dpo.py --config dpo_config/example.yaml --working $WORKING_DIR
6 |
7 | In addition, the script also supports submitting jobs through slurm by using the --gpus argument.
8 | Multi-node training is also supported. For instance, the following command would launch a multi-node job
9 | on 2 nodes (each with 8 GPUs):
10 |
11 | python launch_dpo.py --config dpo_config/example.yaml --working $WORKING_DIR --gpus 16
12 | """
13 | import argparse
14 | import os
15 | import subprocess
16 | import sys
17 |
18 | import submitit
19 | import yaml
20 |
21 | GPUS_PER_NODE = 8
22 |
23 |
24 | def dict2args(d):
25 | args = []
26 | for k, v in d.items():
27 | args.append(f"--{k}")
28 | if isinstance(v, list):
29 | for x in v:
30 | args.append(str(x))
31 | else:
32 | args.append(str(v))
33 | return args
34 |
35 |
36 | def dpo_task(nodes, config):
37 | env = submitit.helpers.TorchDistributedEnvironment()
38 | ds_config = {
39 | "compute_environment": "LOCAL_MACHINE",
40 | "debug": False,
41 | "deepspeed_config": {
42 | "deepspeed_multinode_launcher": "standard",
43 | "gradient_accumulation_steps": config["gradient_accumulation_steps"],
44 | "offload_optimizer_device": "none",
45 | "offload_param_device": "none",
46 | "zero3_init_flag": False,
47 | "zero_stage": 2,
48 | },
49 | "distributed_type": "DEEPSPEED",
50 | "downcast_bf16": "no",
51 | "machine_rank": env.rank,
52 | "main_process_ip": env.master_addr,
53 | "main_process_port": env.master_port,
54 | "main_training_function": "main",
55 | "mixed_precision": "bf16",
56 | "num_machines": nodes,
57 | "num_processes": nodes * GPUS_PER_NODE,
58 | "rdzv_backend": "static",
59 | "same_network": True,
60 | "tpu_env": [],
61 | "tpu_use_cluster": False,
62 | "tpu_use_sudo": False,
63 | "use_cpu": False,
64 | }
65 | config_path = config["output_dir"] + f"/accelerate_config.rank{env.rank}.yaml"
66 | with open(config_path, mode="x", encoding="utf-8") as f:
67 | print(yaml.dump(ds_config), file=f)
68 | command = [
69 | "accelerate",
70 | "launch",
71 | "--config_file",
72 | config_path,
73 | "run_dpo.py",
74 | ] + dict2args(config)
75 | subprocess.run(command)
76 |
77 |
78 | def main():
79 | parser = argparse.ArgumentParser("Launch a DPO experiment")
80 | parser.add_argument("-c", "--config", required=True, help="Configuration YAML")
81 | parser.add_argument("-d", "--working", required=True, help="Working directory")
82 | parser.add_argument(
83 | "--gpus",
84 | default=None,
85 | type=int,
86 | help="Launch through slurm using the given number of GPUs",
87 | )
88 | args = parser.parse_args()
89 |
90 | os.makedirs(args.working, exist_ok=True)
91 | if os.listdir(args.working):
92 | print("ERROR: Working directory is not empty.", file=sys.stderr)
93 | sys.exit(-1)
94 |
95 | folder = args.working + "/submitit"
96 | if args.gpus is None: # Local
97 | executor = submitit.LocalExecutor(folder=folder)
98 | nodes = 1
99 | else: # Slurm
100 | assert args.gpus % GPUS_PER_NODE == 0
101 | nodes = args.gpus // GPUS_PER_NODE
102 | executor = submitit.AutoExecutor(folder=folder)
103 |
104 | executor.update_parameters(
105 | name="dpo",
106 | nodes=nodes,
107 | tasks_per_node=1,
108 | gpus_per_node=GPUS_PER_NODE,
109 | slurm_gpus_per_task=GPUS_PER_NODE,
110 | slurm_cpus_per_gpu=4,
111 | slurm_mem_per_gpu="100GB",
112 | timeout_min=60 * 24 * 365, # One year
113 | )
114 |
115 | with open(args.config, encoding="utf-8") as f:
116 | config = yaml.safe_load(f.read())
117 |
118 | config["output_dir"] = args.working
119 | job = executor.submit(lambda: dpo_task(nodes, config))
120 | print(f"Launched job {job.job_id}")
121 | if args.gpus is None: # Local
122 | job.results()
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | datasets==2.14.6
3 | deepspeed==0.11.0
4 | numpy==1.26.2
5 | peft==0.5.0
6 | PyYAML==6.0.1
7 | submitit==1.5.1
8 | torch==2.0.1
9 | torchvision==0.15.2
10 | transformers==4.32.1
11 | trl==0.7.2
12 | einops
13 | tiktoken
14 | matplotlib
15 | pillow
16 | transformers_stream_generator
17 | wandb
18 |
--------------------------------------------------------------------------------
/run_dpo.py:
--------------------------------------------------------------------------------
1 | """An example of finetuning Qwen-VL via Direct Preference Optimization (DPO)."""
2 |
3 | import json
4 | import logging
5 | import os
6 | from collections import defaultdict
7 | from dataclasses import dataclass, field
8 | from itertools import combinations
9 | from typing import Dict, List, Optional
10 |
11 | import datasets
12 | import numpy as np
13 | import torch.distributed
14 | import transformers
15 | from accelerate.utils import DistributedType
16 | from deepspeed import zero
17 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
18 | from peft import LoraConfig, prepare_model_for_kbit_training
19 | from transformers import GPTQConfig, deepspeed
20 | from transformers.trainer_pt_utils import LabelSmoother
21 | from trl.trainer import DPOTrainer
22 | from trl.trainer.utils import DPODataCollatorWithPadding
23 |
24 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
25 |
26 |
27 | @dataclass
28 | class ModelArguments:
29 | model_name_or_path: Optional[str] = field(default="Qwen/Qwen-VL-Chat")
30 |
31 |
32 | @dataclass
33 | class TrainingArguments(transformers.TrainingArguments):
34 | cache_dir: Optional[str] = field(default=None)
35 | model_max_length: int = field(
36 | default=8192,
37 | metadata={
38 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
39 | },
40 | )
41 | use_lora: bool = False
42 | fix_vit: bool = True
43 | beta: float = field(default=0.1)
44 | generate_during_eval: bool = field(default=False)
45 |
46 |
47 | @dataclass
48 | class LoraArguments:
49 | lora_r: int = 64
50 | lora_alpha: int = 16
51 | lora_dropout: float = 0.05
52 | lora_target_modules: List[str] = field(
53 | default_factory=lambda: [
54 | "c_attn",
55 | "attn.c_proj",
56 | "w1",
57 | "w2",
58 | ] ##["in_proj","out_proj","c_fc"]
59 | )
60 | lora_weight_path: str = ""
61 | lora_bias: str = "none"
62 | q_lora: bool = False
63 |
64 |
65 | def maybe_zero_3(param):
66 | if hasattr(param, "ds_id"):
67 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
68 | with zero.GatheredParameters([param]):
69 | param = param.data.detach().cpu().clone()
70 | else:
71 | param = param.detach().cpu().clone()
72 | return param
73 |
74 |
75 | # Borrowed from peft.utils.get_peft_model_state_dict
76 | def get_peft_state_maybe_zero_3(named_params, bias):
77 | if bias == "none":
78 | to_return = {k: t for k, t in named_params if "lora_" in k}
79 | elif bias == "all":
80 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
81 | elif bias == "lora_only":
82 | to_return = {}
83 | maybe_lora_bias = {}
84 | lora_bias_names = set()
85 | for k, t in named_params:
86 | if "lora_" in k:
87 | to_return[k] = t
88 | bias_name = k.split("lora_")[0] + "bias"
89 | lora_bias_names.add(bias_name)
90 | elif "bias" in k:
91 | maybe_lora_bias[k] = t
92 | for k, t in maybe_lora_bias:
93 | if bias_name in lora_bias_names:
94 | to_return[bias_name] = t
95 | else:
96 | raise NotImplementedError
97 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
98 | return to_return
99 |
100 |
101 | local_rank = None
102 |
103 |
104 | def rank0_print(*args):
105 | if local_rank == 0:
106 | print(*args)
107 |
108 |
109 | def safe_save_model_for_hf_trainer(
110 | trainer: transformers.Trainer, output_dir: str, bias="none"
111 | ):
112 | """Collects the state dict and dump to disk."""
113 | # check if zero3 mode enabled
114 | if deepspeed.is_deepspeed_zero3_enabled():
115 | state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
116 | else:
117 | if trainer.args.use_lora:
118 | state_dict = get_peft_state_maybe_zero_3(
119 | trainer.model.named_parameters(), bias
120 | )
121 | else:
122 | state_dict = trainer.model.state_dict()
123 | if trainer.args.should_save and trainer.args.local_rank == 0:
124 | trainer._save(output_dir, state_dict=state_dict)
125 |
126 |
127 | def preprocess(
128 | sources,
129 | tokenizer: transformers.PreTrainedTokenizer,
130 | max_len: int,
131 | system_message: str = "You are a helpful assistant.",
132 | ) -> Dict:
133 | roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
134 |
135 | im_start = tokenizer.im_start_id
136 | im_end = tokenizer.im_end_id
137 | nl_tokens = tokenizer("\n").input_ids
138 | _system = tokenizer("system").input_ids + nl_tokens
139 |
140 | # Apply prompt templates
141 | prompt_ids, prompt_targets = [], []
142 | answer_ids, answer_targets = [], []
143 | for i, source in enumerate(sources):
144 | if roles[source[0]["from"]] != roles["user"]:
145 | source = source[1:]
146 |
147 | input_id, target = [], []
148 | system = (
149 | [im_start]
150 | + _system
151 | + tokenizer(system_message).input_ids
152 | + [im_end]
153 | + nl_tokens
154 | )
155 | input_id += system
156 | target += (
157 | [im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens
158 | )
159 | assert len(input_id) == len(target)
160 | for j, sentence in enumerate(source):
161 | role = roles[sentence["from"]]
162 | _input_id = (
163 | tokenizer(role).input_ids
164 | + nl_tokens
165 | + tokenizer(sentence["value"]).input_ids
166 | + [im_end]
167 | + nl_tokens
168 | )
169 | input_id += _input_id
170 | if role == "<|im_start|>user":
171 | _target = (
172 | [im_start]
173 | + [IGNORE_TOKEN_ID] * (len(_input_id) - 3)
174 | + [im_end]
175 | + nl_tokens
176 | )
177 | prompt_ids.append(input_id[:])
178 | prompt_targets.append((target + _target)[:])
179 | elif role == "<|im_start|>assistant":
180 | _target = (
181 | [im_start]
182 | + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids)
183 | + _input_id[len(tokenizer(role).input_ids) + 1 : -2]
184 | + [im_end]
185 | + nl_tokens
186 | )
187 | answer_ids.append(_input_id[:])
188 | answer_targets.append(_target[:])
189 | else:
190 | raise NotImplementedError
191 | target += _target
192 | assert len(input_id) == len(target)
193 | assert len(prompt_ids[-1]) == len(prompt_targets[-1])
194 | assert len(answer_ids[-1]) == len(answer_targets[-1])
195 |
196 | prompt_sequence_tokens = dict(
197 | input_ids=prompt_ids,
198 | labels=prompt_targets,
199 | attention_mask=[
200 | [id != tokenizer.pad_token_id for id in ids] for ids in prompt_ids
201 | ],
202 | )
203 | answer_sequence_tokens = dict(
204 | input_ids=answer_ids,
205 | labels=answer_targets,
206 | attention_mask=[
207 | [id != tokenizer.pad_token_id for id in ids] for ids in answer_ids
208 | ],
209 | )
210 |
211 | return prompt_sequence_tokens, answer_sequence_tokens
212 |
213 |
214 | def read_jsonl(file_path):
215 | """Read a JSONL file and return a list of dictionaries."""
216 | with open(file_path, "r", encoding="utf-8") as file:
217 | return [json.loads(line) for line in file]
218 |
219 |
220 | def qwen_vl_prompt_format(prompt, img_paths):
221 | out = []
222 | for i, img_path in enumerate(img_paths):
223 | out.append(f"Picture {i + 1}:
{img_path}\n")
224 | out.append(prompt.strip())
225 | return "".join(out)
226 |
227 |
228 | def make_conv(prompt, answer):
229 | return [
230 | {
231 | "from": "user",
232 | "value": prompt,
233 | },
234 | {
235 | "from": "assistant",
236 | "value": answer,
237 | },
238 | ]
239 |
240 |
241 | @dataclass
242 | class QwenDPODataCollator(DPODataCollatorWithPadding):
243 | def tokenize_batch_element(
244 | self,
245 | prompt: str,
246 | chosen: str,
247 | rejected: str,
248 | ) -> Dict:
249 | """Tokenize a single batch element.
250 |
251 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
252 | in case the prompt + chosen or prompt + rejected responses is/are too long. First
253 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
254 |
255 | We also create the labels for the chosen/rejected responses, which are of length equal to
256 | the sum of the length of the prompt and the chosen/rejected response, with
257 | label_pad_token_id for the prompt tokens.
258 | """
259 | batch = {}
260 |
261 | # format for preprocessing
262 | chosen_conv = make_conv(prompt, chosen)
263 | rejected_conv = make_conv(prompt, rejected)
264 |
265 | # preprocess using Qwen-VL's own method
266 | # note that labels are already set here
267 | prompt_tokens, chosen_tokens = preprocess(
268 | [chosen_conv], self.tokenizer, self.max_length
269 | )
270 | _, rejected_tokens = preprocess(
271 | [rejected_conv], self.tokenizer, self.max_length
272 | )
273 | prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}
274 | chosen_tokens = {k: v[0] for k, v in chosen_tokens.items()}
275 | rejected_tokens = {k: v[0] for k, v in rejected_tokens.items()}
276 |
277 | eos_token_id = self.tokenizer.eos_token_id
278 | # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0)
279 | eos_indices_prompt = [
280 | i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id
281 | ]
282 | # attention mask these indices to eos_token_id
283 | new_attention_mask = [
284 | 0 if i in eos_indices_prompt else p
285 | for i, p in enumerate(prompt_tokens["attention_mask"])
286 | ]
287 | prompt_tokens["attention_mask"] = new_attention_mask
288 |
289 | # do the same for chosen and rejected
290 | eos_indices_chosen = [
291 | i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id
292 | ]
293 | new_attention_mask_c = [
294 | 0 if i in eos_indices_chosen else p
295 | for i, p in enumerate(chosen_tokens["attention_mask"])
296 | ]
297 | chosen_tokens["attention_mask"] = new_attention_mask_c
298 |
299 | eos_indices_rejected = [
300 | i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id
301 | ]
302 | new_attention_mask_r = [
303 | 0 if i in eos_indices_rejected else p
304 | for i, p in enumerate(rejected_tokens["attention_mask"])
305 | ]
306 | rejected_tokens["attention_mask"] = new_attention_mask_r
307 |
308 | # add EOS token to end of prompt
309 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
310 | chosen_tokens["labels"].append(self.tokenizer.eos_token_id)
311 | chosen_tokens["attention_mask"].append(1)
312 |
313 | rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
314 | rejected_tokens["labels"].append(self.tokenizer.eos_token_id)
315 | rejected_tokens["attention_mask"].append(1)
316 |
317 | longer_response_length = max(
318 | len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])
319 | )
320 |
321 | # if combined sequence is too long, truncate the prompt
322 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
323 | if self.truncation_mode == "keep_start":
324 | prompt_tokens = {
325 | k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()
326 | }
327 | elif self.truncation_mode == "keep_end":
328 | prompt_tokens = {
329 | k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()
330 | }
331 | else:
332 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
333 |
334 | # if that's still too long, truncate the response
335 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
336 | chosen_tokens = {
337 | k: v[: self.max_length - self.max_prompt_length]
338 | for k, v in chosen_tokens.items()
339 | }
340 | rejected_tokens = {
341 | k: v[: self.max_length - self.max_prompt_length]
342 | for k, v in rejected_tokens.items()
343 | }
344 |
345 | # Create labels
346 | chosen_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
347 | rejected_tokens = {
348 | k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens
349 | }
350 | chosen_tokens["labels"][: len(prompt_tokens["input_ids"])] = [
351 | self.label_pad_token_id
352 | ] * len(prompt_tokens["input_ids"])
353 | rejected_tokens["labels"][: len(prompt_tokens["input_ids"])] = [
354 | self.label_pad_token_id
355 | ] * len(prompt_tokens["input_ids"])
356 |
357 | for k, toks in {
358 | "chosen": chosen_tokens,
359 | "rejected": rejected_tokens,
360 | "prompt": prompt_tokens,
361 | }.items():
362 | for type_key, tokens in toks.items():
363 | if type_key == "token_type_ids":
364 | continue
365 | batch[f"{k}_{type_key}"] = tokens
366 |
367 | batch["prompt"] = prompt
368 | batch["chosen"] = prompt + chosen
369 | batch["rejected"] = prompt + rejected
370 | batch["chosen_response_only"] = chosen
371 | batch["rejected_response_only"] = rejected
372 |
373 | return batch
374 |
375 |
376 | def make_vlfeedback_paired_dataset(local_rank):
377 | ds = datasets.load_dataset("MMInstruction/VLFeedback", split="train")
378 |
379 | # format prompt
380 | if local_rank > 0:
381 | print("Waiting for main process to perform the mapping")
382 | torch.distributed.barrier()
383 |
384 | def set_format(sample):
385 | prompt = sample["prompt"]
386 | img_path = sample["img_path"]
387 | sample["prompt"] = qwen_vl_prompt_format(prompt, [img_path])
388 | return sample
389 |
390 | ds = ds.map(set_format)
391 |
392 | if local_rank == 0:
393 | print("Loading results from main process")
394 | torch.distributed.barrier()
395 |
396 | # make comparison pairs from completion list
397 | if local_rank > 0:
398 | print("Waiting for main process to perform the mapping")
399 | torch.distributed.barrier()
400 |
401 | def make_batch_pairs(sample):
402 | converted_sample = defaultdict(list)
403 |
404 | for sample_idx, comps in enumerate(sample["completions"]):
405 | prompt = sample["prompt"][sample_idx]
406 |
407 | for comp_idx1, comp_idx2 in combinations(range(len(comps["annotations"])), 2):
408 | anno1, anno2 = comps["annotations"][comp_idx1], comps["annotations"][comp_idx2]
409 |
410 | # get average scores
411 | try:
412 | avg_score1 = np.mean(
413 | [
414 | float(anno1[aspect]["Rating"])
415 | for aspect in anno1
416 | ]
417 | )
418 | avg_score2 = np.mean(
419 | [
420 | float(anno2[aspect]["Rating"])
421 | for aspect in anno2
422 | ]
423 | )
424 | except ValueError:
425 | continue
426 |
427 | # get chosen and rejected responses
428 | if avg_score1 > avg_score2:
429 | chosen = comps["response"][comp_idx1]
430 | rejected = comps["response"][comp_idx2]
431 | elif avg_score2 > avg_score1:
432 | chosen = comps["response"][comp_idx2]
433 | rejected = comps["response"][comp_idx1]
434 | else:
435 | continue
436 | converted_sample["prompt"].append(prompt)
437 | converted_sample["chosen"].append(chosen)
438 | converted_sample["rejected"].append(rejected)
439 |
440 | return converted_sample
441 |
442 | ds = ds.map(
443 | make_batch_pairs,
444 | batched=True,
445 | remove_columns=set(ds.column_names) - set(["prompt", "chosen", "rejected"]),
446 | )
447 |
448 | if local_rank == 0:
449 | print("Loading results from main process")
450 | torch.distributed.barrier()
451 |
452 | return ds
453 |
454 | def train():
455 | global local_rank
456 |
457 | os.environ["WANDB_PROJECT"] = "Silkie"
458 | parser = transformers.HfArgumentParser(
459 | (ModelArguments, TrainingArguments, LoraArguments)
460 | )
461 | (
462 | model_args,
463 | training_args,
464 | lora_args,
465 | ) = parser.parse_args_into_dataclasses()
466 |
467 | if getattr(training_args, "deepspeed", None) and getattr(
468 | lora_args, "q_lora", False
469 | ):
470 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
471 |
472 | local_rank = training_args.local_rank
473 |
474 | device_map = None
475 | world_size = int(os.environ.get("WORLD_SIZE", 1))
476 | ddp = world_size != 1
477 | if lora_args.q_lora:
478 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
479 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
480 | logging.warning("FSDP or ZeRO3 are not incompatible with QLoRA.")
481 |
482 | # Set RoPE scaling factor
483 | config = transformers.AutoConfig.from_pretrained(
484 | model_args.model_name_or_path,
485 | cache_dir=training_args.cache_dir,
486 | trust_remote_code=True,
487 | fp32=True,
488 | )
489 | config.use_cache = False
490 |
491 | # Load model and tokenizer
492 | model = transformers.AutoModelForCausalLM.from_pretrained(
493 | model_args.model_name_or_path,
494 | config=config,
495 | cache_dir=training_args.cache_dir,
496 | device_map=device_map,
497 | trust_remote_code=True,
498 | quantization_config=GPTQConfig(bits=4, disable_exllama=True)
499 | if training_args.use_lora and lora_args.q_lora
500 | else None,
501 | )
502 |
503 | if not training_args.use_lora:
504 | if (
505 | training_args.fix_vit
506 | and hasattr(model, "transformer")
507 | and hasattr(model.transformer, "visual")
508 | ):
509 | model.transformer.visual.requires_grad_(False)
510 | if hasattr(model.transformer.visual, "attn_pool"):
511 | model.transformer.visual.attn_pool.requires_grad_(True)
512 | tokenizer = transformers.AutoTokenizer.from_pretrained(
513 | model_args.model_name_or_path,
514 | cache_dir=training_args.cache_dir,
515 | model_max_length=training_args.model_max_length,
516 | padding_side="right",
517 | use_fast=False,
518 | trust_remote_code=True,
519 | )
520 | tokenizer.pad_token_id = tokenizer.eod_id
521 | tokenizer.eos_token_id = tokenizer.eod_id
522 |
523 | if training_args.use_lora:
524 | if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower():
525 | modules_to_save = None
526 | else:
527 | modules_to_save = ["wte", "lm_head"]
528 | lora_config = LoraConfig(
529 | r=lora_args.lora_r,
530 | lora_alpha=lora_args.lora_alpha,
531 | target_modules=lora_args.lora_target_modules,
532 | lora_dropout=lora_args.lora_dropout,
533 | bias=lora_args.lora_bias,
534 | task_type="CAUSAL_LM",
535 | modules_to_save=modules_to_save, # This argument serves for adding new tokens.
536 | )
537 | if lora_args.q_lora:
538 | model = prepare_model_for_kbit_training(
539 | model, use_gradient_checkpointing=training_args.gradient_checkpointing
540 | )
541 |
542 | if training_args.gradient_checkpointing:
543 | model.enable_input_require_grads()
544 |
545 | # Load data
546 | dataset = make_vlfeedback_paired_dataset(training_args.local_rank)
547 | dataset_split = dataset.train_test_split(test_size=0.005, seed=42)
548 | train_dataset = dataset_split["train"]
549 | eval_dataset = dataset_split["test"]
550 |
551 | # Start trainner
552 | trainer = DPOTrainer(
553 | model,
554 | args=training_args,
555 | beta=training_args.beta,
556 | train_dataset=train_dataset,
557 | eval_dataset=eval_dataset,
558 | data_collator=QwenDPODataCollator(
559 | tokenizer,
560 | max_length=training_args.model_max_length,
561 | max_prompt_length=training_args.model_max_length // 2,
562 | max_target_length=training_args.model_max_length // 2,
563 | label_pad_token_id=IGNORE_TOKEN_ID,
564 | padding_value=tokenizer.pad_token_id,
565 | truncation_mode="keep_end",
566 | ),
567 | tokenizer=tokenizer,
568 | max_length=training_args.model_max_length,
569 | peft_config=lora_config if training_args.use_lora else None,
570 | generate_during_eval=training_args.generate_during_eval,
571 | )
572 |
573 | trainer.train()
574 | trainer.save_state()
575 |
576 | safe_save_model_for_hf_trainer(
577 | trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias
578 | )
579 |
580 |
581 | if __name__ == "__main__":
582 | train()
583 |
--------------------------------------------------------------------------------