├── LICENSE
├── README.md
├── Visual-ARFT
├── README.md
├── assets
│ ├── dataset.png
│ ├── framework.png
│ ├── tesear.png
│ └── test.txt
├── demo
│ ├── .env
│ ├── README.md
│ ├── cases
│ │ ├── 23.jpg
│ │ ├── annotation2.jpg
│ │ ├── coding_case.json
│ │ ├── crop_13_ori.png
│ │ ├── crop_8_ori.png
│ │ ├── dark_1_proc.png
│ │ ├── none_2_proc.png
│ │ ├── rotation180_1_proc.png
│ │ └── rotation90+dark_0_proc.png
│ ├── coding_demo.ipynb
│ ├── search_demo.ipynb
│ └── web_search.py
├── evaluation_coding
│ ├── README.md
│ ├── evaluation.ipynb
│ ├── evaluation_mat_coding_ngpu_7b_df.py
│ ├── evaluation_mat_coding_visual_arft.py
│ └── evaluation_results
│ │ ├── 3b_mat_coding_grpo_step1400.json
│ │ ├── 3b_mat_coding_grpo_step1400_v2.json
│ │ ├── 7b_mat_coding_grpo_step1400.json
│ │ └── 7b_mat_coding_grpo_step1400_v2.json
├── evaluation_search
│ ├── .env
│ ├── MAT-Search-Benchmark
│ │ ├── evaluation.ipynb
│ │ ├── evaluation_mat_search_ngpu_7b_df.py
│ │ └── evaluation_mat_search_ngpu_7b_visual_arft.py
│ ├── README.md
│ ├── evaluation_results
│ │ ├── 2wiki
│ │ │ ├── 3b_step200_2wiki_web_n4.json
│ │ │ └── 7b_step200_2wiki_web_n4.json
│ │ ├── MAT-Search
│ │ │ ├── 3b_step200_mat_search_web_n4.json
│ │ │ └── 7b_step200_mat_search_web_n4.json
│ │ ├── bamboogle
│ │ │ ├── 3b_step200_bamboogle_web_n4.json
│ │ │ └── 7b_step200_bamboogle_web_n4.json
│ │ ├── hotpotqa
│ │ │ ├── 3b_step200_hotpotqa_web_n4.json
│ │ │ └── 7b_step200_hotpotqa_web_n4.json
│ │ └── musique
│ │ │ ├── 3b_step200_musique_web_n4.json
│ │ │ └── 7b_step200_musique_web_n4.json
│ ├── tools
│ │ └── web_search.py
│ └── traditional_benchmark
│ │ ├── evaluation.ipynb
│ │ ├── evaluation_2wikimultihopqa_ngpu_7b_df.py
│ │ ├── evaluation_2wikimultihopqa_ngpu_7b_visual_arft.py
│ │ ├── evaluation_bamboogle_ngpu_7b_df.py
│ │ ├── evaluation_bamboogle_ngpu_7b_visual_arft.py
│ │ ├── evaluation_hotpotqa_ngpu_7b_df.py
│ │ ├── evaluation_hotpotqa_ngpu_7b_visual_arft.py
│ │ ├── evaluation_musique_ngpu_7b_df.py
│ │ └── evaluation_musique_ngpu_7b_visual_arft.py
├── setup.sh
└── src
│ ├── scripts
│ ├── run_grpo_agent_code_3b_1_2k_new2_gpu8.sh
│ ├── run_grpo_agent_code_7b_1_2k_new2_gpu8.sh
│ ├── run_grpo_agent_search_3b_gpu8.sh
│ └── run_grpo_agent_search_7b_gpu8.sh
│ └── visual_arft
│ ├── .gitignore
│ ├── LICENSE
│ ├── Makefile
│ ├── configs
│ ├── ddp.yaml
│ ├── qwen2vl_sft_config.yaml
│ ├── zero2.yaml
│ └── zero3.yaml
│ ├── local_scripts
│ ├── create_vision_cot_data.py
│ ├── lmms_eval_qwen2vl.sh
│ ├── prepare_hf_data.py
│ ├── train_aria_moe.sh
│ ├── train_qwen2_vl.sh
│ ├── zero1_no_optimizer.json
│ ├── zero2.json
│ ├── zero3.json
│ ├── zero3.yaml
│ └── zero3_offload.json
│ ├── run_grpo.sh
│ ├── setup.cfg
│ ├── setup.py
│ ├── src
│ └── open_r1
│ │ ├── __init__.py
│ │ ├── evaluate.py
│ │ ├── generate.py
│ │ ├── grpo.py
│ │ ├── grpo_agent_code.py
│ │ ├── grpo_agent_search.py
│ │ ├── sft.py
│ │ └── trainer
│ │ ├── __init__.py
│ │ ├── grpo_trainer.py
│ │ ├── grpo_trainer_aid.py
│ │ ├── grpo_trainer_mp.py
│ │ ├── vllm_grpo_trainer.py
│ │ └── vllm_grpo_trainer_modified.py
│ └── temp_image.png
├── assets
├── case_cls.png
├── case_lisa.png
├── framework.png
├── pokeymon.jpg
├── radar.png
└── teaser.png
├── classification
├── Qwen2_VL_classification_infere.py
└── val_data
│ ├── fgvc_aircraft.pth
│ ├── fgvc_aircraft.txt
│ ├── oxford_flowers.pth
│ ├── oxford_flowers.txt
│ ├── pets.pth
│ ├── pets.txt
│ ├── stanford_cars.pth
│ └── stanford_cars.txt
├── coco_evaluation
├── Qwen2_VL_coco_infere.py
├── coco_evaluation.py
├── evaluation.ipynb
├── exist_map_coco_Qwen2_vl_2B_baseline.json
└── exist_map_coco_Qwen2_vl_7B_baseline.json
├── dataset
├── README.md
└── build_dataset.ipynb
├── demo
├── README.md
└── lisa_demo.ipynb
├── lisa_evaluation
├── Qwen2_VL_lisa_infere.py
├── Qwen2_VL_lisa_infere.sh
├── README.md
├── box2mask.py
├── evaluation.ipynb
├── gen_box_ann.py
├── gen_sft.py
├── mask_iou.py
└── merge_eval.py
├── lvis_evaluation
├── Qwen2_VL_lvis_infere.py
├── exist_map_lvis_Qwen2_vl_2B_baseline.json
└── exist_map_lvis_Qwen2_vl_7B_baseline.json
├── setup.sh
└── src
├── scripts
├── 2B_aircraft_4_shot.sh
├── 2B_base65cate_6k.sh
├── 2B_lisa_grounding.sh
├── 7B_base65cate_6k.sh
└── example.sh
└── virft
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── configs
├── ddp.yaml
├── zero2.yaml
└── zero3.yaml
├── local_scripts
├── create_vision_cot_data.py
├── lmms_eval_qwen2vl.sh
├── prepare_hf_data.py
├── train_aria_moe.sh
├── train_qwen2_vl.sh
├── zero2.json
├── zero3.json
├── zero3.yaml
└── zero3_offload.json
├── setup.cfg
├── setup.py
├── slurm
├── evaluate.slurm
├── generate.slurm
└── sft.slurm
└── src
└── open_r1
├── __init__.py
├── evaluate.py
├── generate.py
├── grpo.py
├── grpo_classification.py
├── grpo_lisa.py
├── sft.py
└── trainer
├── __init__.py
├── grpo_trainer.py
├── grpo_trainer_aid.py
├── grpo_trainer_mp.py
├── history
├── grpo_trainer_v1.py
└── vllm_grpo_trainer_v1.py
├── vllm_grpo_trainer.py
└── vllm_grpo_trainer_modified.py
/Visual-ARFT/assets/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/assets/dataset.png
--------------------------------------------------------------------------------
/Visual-ARFT/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/assets/framework.png
--------------------------------------------------------------------------------
/Visual-ARFT/assets/tesear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/assets/tesear.png
--------------------------------------------------------------------------------
/Visual-ARFT/assets/test.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Visual-ARFT/demo/.env:
--------------------------------------------------------------------------------
1 | SERPER_API = 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
--------------------------------------------------------------------------------
/Visual-ARFT/demo/README.md:
--------------------------------------------------------------------------------
1 | # Instructions for the demos
2 |
3 | ## Prepare Datasets and Models
4 | You can download the MAT Benchmark from 🤗Datasets.
5 |
6 | You can download our models: 🤗Visual-ARFT-Search and 🤗Visual-ARFT-Coding.
7 |
8 | ## Inference
9 | Replace the model and benchmark paths in our demos, then run `coding_demo.ipynb` or `search_demo.ipynb` step by step.
10 |
11 | > 🔔 If you want to try **Visual-ARFT-Search**, put your Serper API key in `.env` to enable the web search function. (Registration includes 2,500 free queries.)
12 |
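13 | ## Quick check of the search tool (optional)
14 |
15 | A minimal sketch for verifying that your Serper key works before opening the notebooks. It reuses `web_search.py` from this folder and assumes the key is already in `.env`; the query string is only an illustration:
16 |
17 | ```python
18 | from web_search import web_search_SERPER_API
19 |
20 | # 'fast' mode returns a list of {'title', 'href', 'body'} dicts built from the Serper API response
21 | results = web_search_SERPER_API("Who directed Uncommon Valor?", search_num=2, search_mode='fast')
22 | for r in results:
23 |     print(r['title'], '->', r['href'])
24 | ```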
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/23.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/23.jpg
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/annotation2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/annotation2.jpg
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/coding_case.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "id": 12,
4 | "question": "What is the serving size of the beverage?",
5 | "answers": [
6 | "1 Can (473ml)"
7 | ],
8 | "ori_image_path": "rotation180_1_ori.png",
9 | "processed_image_path": "rotation180_1_proc.png",
10 | "type": [
11 | "rotation180"
12 | ],
13 | "split": "simple"
14 | },
15 | {
16 | "id": 22,
17 | "question": "从图中提取: 发票号码, 发票代码",
18 | "answers": [
19 | "{'发票号码': '16734841', '发票代码': '112001970102'}"
20 | ],
21 | "ori_image_path": "dark_1_ori.png",
22 | "processed_image_path": "dark_1_proc.png",
23 | "type": [
24 | "dark"
25 | ],
26 | "split": "simple"
27 | },
28 | {
29 | "id": 63,
30 | "question": "What is the title of the movie featured in the image?",
31 | "answers": [
32 | "Uncommon Valor"
33 | ],
34 | "ori_image_path": "none_2_ori.png",
35 | "processed_image_path": "none_2_proc.png",
36 | "type": [
37 | "none"
38 | ],
39 | "split": "simple"
40 | },
41 | {
42 | "id": 71,
43 | "question": "从图中提取: 发票号码, 发票代码, 金额",
44 | "answers": [
45 | "{'发票号码': '00134721', '发票代码': '151002065029', '金额': '#'}"
46 | ],
47 | "ori_image_path": "rotation90+dark_0_ori.png",
48 | "processed_image_path": "rotation90+dark_0_proc.png",
49 | "type": [
50 | "rotation90",
51 | "dark"
52 | ],
53 | "split": "hard"
54 | },
55 | {
56 | "id": 119,
57 | "question": "Recognize the text within the [366, 723, 506, 851] of the image. The coordinates have been normalized ranging from 0 to 1000 by the image width and height. ",
58 | "answers": [
59 | "SINCE 1994"
60 | ],
61 | "ori_image_path": "crop_8_ori.png",
62 | "processed_image_path": "crop_8_proc.png",
63 | "type": [
64 | "crop"
65 | ],
66 | "split": "hard"
67 | },
68 | {
69 | "id": 124,
70 | "question": "Recognize the text within the [489, 155, 595, 259] of the image. The coordinates have been normalized ranging from 0 to 1000 by the image width and height. ",
71 | "answers": [
72 | "Spaghetti\nGOEMON"
73 | ],
74 | "ori_image_path": "crop_13_ori.png",
75 | "processed_image_path": "crop_13_proc.png",
76 | "type": [
77 | "crop"
78 | ],
79 | "split": "hard"
80 | }
81 | ]
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/crop_13_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/crop_13_ori.png
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/crop_8_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/crop_8_ori.png
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/dark_1_proc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/dark_1_proc.png
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/none_2_proc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/none_2_proc.png
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/rotation180_1_proc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/rotation180_1_proc.png
--------------------------------------------------------------------------------
/Visual-ARFT/demo/cases/rotation90+dark_0_proc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/demo/cases/rotation90+dark_0_proc.png
--------------------------------------------------------------------------------
/Visual-ARFT/demo/web_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | from duckduckgo_search import DDGS
3 | import json
4 | import requests
5 | from typing import List, Literal, Optional, Union
6 | from firecrawl import FirecrawlApp
7 | from dotenv import load_dotenv
8 | load_dotenv()
9 |
10 | # Use the free DuckDuckGo search for web retrieval, and Firecrawl to convert pages to Markdown.
11 | def web_search_DDG(query: Optional[str], search_num: int = 2, search_mode: str = 'fast') -> Optional[List[str]]:
12 | assert search_mode == 'fast' or search_mode == 'pro'
13 | if search_mode == 'fast':
14 | assert type(query)==str
15 | results = DDGS().text(query, max_results=search_num)
16 | return results
17 | elif search_mode == 'pro':
18 | assert type(query)==str
19 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API"))
20 | results = DDGS().text(query, max_results=search_num)
21 | for result in results:
22 | web_url = result['href']
23 | # firecrawl_app returns markdown and metadata
24 | web_content = firecrawl_app.scrape_url(web_url)
25 | web_content_markdown = web_content['markdown']
26 | web_content_metadata = web_content['metadata']
27 | result['web_content_markdown'] = web_content_markdown
28 | result['web_content_metadata'] = web_content_metadata
29 | return results
30 |
31 | def web_search_SERPER_API(query: Optional[str], search_num=2, search_mode: str = 'fast') -> Optional[List[str]]:
32 | assert search_mode == 'fast' or search_mode == 'pro'
33 | if search_mode == 'fast':
34 | assert type(query)==str
35 | url = "https://google.serper.dev/search"
36 | payload = json.dumps({"q": query, "num": search_num})
37 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'}
38 | response = requests.request("POST", url, headers=headers, data=payload)
39 | response = json.loads(response.text)
40 | results = []
41 | for item in response['organic']:
42 | results.append(
43 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']}
44 | )
45 | return results
46 | elif search_mode == 'pro':
47 | assert type(query)==str
48 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API"))
49 | url = "https://google.serper.dev/search"
50 | payload = json.dumps({"q": query, "num": search_num})
51 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'}
52 | response = requests.request("POST", url, headers=headers, data=payload)
53 | response = json.loads(response.text)
54 | results = []
55 | for item in response['organic']:
56 | results.append(
57 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']}
58 | )
59 | for result in results:
60 | web_url = result['href']
61 | # firecrawl_app returns markdown and metadata
62 | web_content = firecrawl_app.scrape_url(web_url)
63 | web_content_markdown = web_content['markdown']
64 | web_content_metadata = web_content['metadata']
65 | result['web_content_markdown'] = web_content_markdown
66 | result['web_content_metadata'] = web_content_metadata
67 | return results
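68 |
69 | if __name__ == "__main__":
70 |     # Minimal usage sketch (not part of the evaluation pipeline): assumes SERPER_API is set in .env,
71 |     # runs a 'fast' Serper search for an arbitrary query, and prints the title/URL of each hit.
72 |     for hit in web_search_SERPER_API("multi-hop question answering", search_num=2, search_mode='fast'):
73 |         print(hit['title'], '->', hit['href'])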
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_coding/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation on MAT-Coding
2 | In this folder, we provide two Python files for evaluation on **MAT-Coding**.
3 | `evaluation_mat_coding_ngpu_7b_df.py` is used to evaluate the results of direct inference, while `evaluation_mat_coding_visual_arft.py` is used to evaluate the inference results of the **Visual-ARFT-Coding** model.
4 |
5 | ## Download Dataset and Models
6 | You can download the MAT Benchmark from 🤗Datasets. Use `/MAT/MAT-Benchmark/MAT-Coding.json`.
7 |
8 | You can download our model: 🤗Visual-ARFT-Coding
9 |
10 | ## Inference on MAT-Coding
11 |
12 | To run `evaluation_mat_coding_visual_arft.py`, you need to replace the paths to the model and dataset:
13 |
14 | - Line 175: Replace **model_name** with the actual model path (**Visual-ARFT-Coding**).
15 | - Line 184: Replace **json_path** with the actual dataset path (MAT-Coding.json).
16 | - Line 201: Replace **input_image_path** with the actual image path.
17 | - Line 342: Set the results save path.
18 |
19 | > 🔔The intermediate image processing result will be saved as `cache.jpg` in the current directory, as specified in Line 206.
20 |
21 | To run `evaluation_mat_coding_ngpu_7b_df.py`, you need to replace the paths to the model and dataset:
22 |
23 | - Line 20: Replace **model_name** with the actual model path (**Original Qwen2.5-VL without Visual-ARFT**).
24 | - Line 30: Replace **json_path** with the actual dataset path (MAT-Coding.json).
25 | - Line 42: Replace **input_image_path** with the actual image path.
26 | - Line 103: Set the results save path.
27 |
28 | > ⏳ The inference time for **Visual-ARFT-Coding-7B** is approximately 1.5 hours, while the **3B model** takes around 50 minutes. Inference with the **original Qwen2.5-VL without Visual-ARFT** is much faster.
29 |
30 |
31 | ## Evaluation
32 | After obtaining the inference results, run `evaluation.ipynb` step by step; the notebook reports the final evaluation scores.
33 |
34 | We have saved the inference results of the `Visual-ARFT-Coding-7B` model in the `evaluation_results` folder. You can use them directly to reproduce the evaluation scores.
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
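43 | ## Result format (for reference)
44 |
45 | A minimal sketch for peeking at the provided results before running `evaluation.ipynb`. It assumes you are in this folder; the `pred_answer` / `gt` fields follow what the inference scripts above save:
46 |
47 | ```python
48 | import json
49 |
50 | # Each entry holds the model prediction and the ground-truth answer(s) for one MAT-Coding question.
51 | with open('evaluation_results/7b_mat_coding_grpo_step1400.json', 'r', encoding='utf-8') as f:
52 |     results = json.load(f)
53 | print(len(results))
54 | print(results[0]['pred_answer'], '||', results[0]['gt'])
55 | ```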
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_coding/evaluation_mat_coding_ngpu_7b_df.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import torch
4 | import string
5 | import numpy as np
6 | from tqdm import tqdm
7 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
8 | from qwen_vl_utils import process_vision_info
9 | import matplotlib.pyplot as plt
10 | import matplotlib.image as mpimg
11 | from PIL import Image
12 |
13 |
14 | # ANSI color codes for console output
15 | RED = '\033[91m'
16 | GREEN = '\033[92m'
17 | YELLOW = '\033[93m'
18 | RESET = '\033[0m'  # reset color
19 |
20 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct"
21 |
22 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
23 | model_path,
24 | torch_dtype=torch.bfloat16,
25 | attn_implementation="flash_attention_2",
26 | device_map='cuda:0',
27 | )
28 | processor = AutoProcessor.from_pretrained(model_path)
29 |
30 | with open('/share_data/MAT/MAT-Benchmark/MAT-Coding.json', 'r') as file:
31 | wikimultihopqa = json.load(file)
32 | print(len(wikimultihopqa))
33 |
34 |
35 | combine_results = []
36 | for item in tqdm(wikimultihopqa[:]):
37 | print("########################################")
38 | if item['type'][0] == 'crop':
39 | input_image_path = item['ori_image_path']
40 | else:
41 | input_image_path = item['processed_image_path']
42 | input_image_path = '/share_data/MAT/MAT-Benchmark/MAT-Coding-image/' + input_image_path
43 | query = item['question']
44 | data_type = item['type']
45 | item_id = item['id']
46 | answer = item['answers']
47 |
48 | # ### Direct Inference
49 | input_text = query + '\n' + "Answer the question directly. The answer should be very brief."
50 |
51 | ### CoT
52 | # input_text = SYSTEM_PROMPT + '\n' + query + "\n" + "You must output your thinking processs in . The answer between should be very brief."
53 |
54 | print(RED+input_text+RESET)
55 | print(GREEN+str(answer)+RESET)
56 |
57 | try:
58 |         ###### Run one round of inference ######
59 | messages = [
60 | { "role": "user", "content": [{"type": "image","image": input_image_path}, {"type": "text", "text": input_text}]}
61 | ]
62 | # Preparation for inference
63 | text = processor.apply_chat_template(
64 | messages, tokenize=False, add_generation_prompt=True
65 | )
66 | image_inputs, video_inputs = process_vision_info(messages)
67 | inputs = processor(
68 | text=[text],
69 | images=image_inputs,
70 | videos=video_inputs,
71 | padding=True,
72 | return_tensors="pt",
73 | )
74 | inputs = inputs.to(model.device)
75 |
76 | # Inference: Generation of the output
77 | generated_ids = model.generate(**inputs, max_new_tokens=2048)
78 | generated_ids_trimmed = [
79 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
80 | ]
81 | output_text = processor.batch_decode(
82 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
83 | )
84 | result = output_text[0]
85 |
86 | pred_answer = result
87 | print(YELLOW+result+RESET)
88 |
89 | # match = re.search(r"(.*?)", result, re.DOTALL)
90 | # if match:
91 | # pred_answer = match.group(1).strip()
92 | # print(YELLOW+result+RESET)
93 | # print(YELLOW+pred_answer+RESET)
94 |
95 |     except Exception as e:
96 |         pred_answer = None
97 |         print(f"ERROR OCCURRED: {e}")
98 | combine_results.append(
99 | {'pred_answer': pred_answer, 'gt': answer, 'query': query}
100 | )
101 |
102 | print(len(combine_results))
103 | with open(f"./7B_df.json", "w", encoding="utf-8") as f:
104 | json.dump(combine_results, f, ensure_ascii=False, indent=4)
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/.env:
--------------------------------------------------------------------------------
1 | SERPER_API = 'xxxxxxxxxxxxxxxxxxxxxxxxx'
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/MAT-Search-Benchmark/evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import re\n",
10 | "import string\n",
11 | "def normalize(s):\n",
12 | " def remove_articles(text):\n",
13 | " return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\n",
14 | "\n",
15 | " def white_space_fix(text):\n",
16 | " return \" \".join(text.split())\n",
17 | "\n",
18 | " def remove_punc(text):\n",
19 | " exclude = set(string.punctuation)\n",
20 | " return \"\".join(ch for ch in text if ch not in exclude)\n",
21 | "\n",
22 | " def lower(text):\n",
23 | " return text.lower()\n",
24 | "\n",
25 | " return white_space_fix(remove_articles(remove_punc(lower(s))))\n",
26 | "\n",
27 | "def compute_f1(prediction, ground_truth):\n",
28 | " if prediction is None:\n",
29 | " return 0.0\n",
30 | " prediction_tokens = normalize(prediction).split()\n",
31 | " ground_truth_tokens = normalize(ground_truth).split()\n",
32 | "\n",
33 | " common = set(prediction_tokens) & set(ground_truth_tokens)\n",
34 | " num_same = len(common)\n",
35 | "\n",
36 | " if num_same == 0:\n",
37 | " return 0.0\n",
38 | "\n",
39 | " precision = num_same / len(prediction_tokens)\n",
40 | " recall = num_same / len(ground_truth_tokens)\n",
41 | " f1 = 2 * precision * recall / (precision + recall)\n",
42 | " return f1\n",
43 | "\n",
44 | "def exact_match_score(prediction, ground_truth):\n",
45 | " if prediction is None:\n",
46 | " return 0.0\n",
47 | " return int(normalize(prediction) == normalize(ground_truth))\n",
48 | "\n",
49 | "def evaluate(predictions):\n",
50 | " total = len(predictions)\n",
51 | " f1_total = []\n",
52 | " em_total = []\n",
53 | "\n",
54 | " for item in predictions:\n",
55 | " # if item['pred_answer_ori'] == None:\n",
56 | " pred = item['pred_answer']\n",
57 | " # else:\n",
58 | " # pred = item['pred_answer_ori']\n",
59 | " gts = item['gt']\n",
60 | "\n",
61 |     "        # If gt is a str, wrap it in a list for uniform handling\n",
62 | " if isinstance(gts, str):\n",
63 | " gts = [gts]\n",
64 | "\n",
65 | " f1 = max([compute_f1(pred, gt) for gt in gts])\n",
66 | " em = max([exact_match_score(pred, gt) for gt in gts])\n",
67 | " if em == 1:\n",
68 | " f1 = 1\n",
69 | "\n",
70 | " f1_total.append(f1)\n",
71 | " em_total.append(em)\n",
72 | "\n",
73 | " return {\n",
74 | " \"avg_f1\": sum(f1_total) / total if total > 0 else 0,\n",
75 | " \"avg_em\": sum(em_total) / total if total > 0 else 0,\n",
76 | " \"simple_f1\": sum(f1_total[:75]) / 75 if total > 0 else 0,\n",
77 | " \"simple_em\": sum(em_total[:75]) / 75 if total > 0 else 0,\n",
78 | " \"hard_f1\": sum(f1_total[75:]) / 75 if total > 0 else 0,\n",
79 | " \"hard_em\": sum(em_total[75:]) / 75 if total > 0 else 0,\n",
80 | " }"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "# 12576 125 7405 2417\n",
90 | "\n",
91 | "import json\n",
92 | "with open('7B_df.json', 'r') as f:\n",
93 | " combine_results = json.load(f)\n",
94 | "print(len(combine_results))\n",
95 | "\n",
96 | "count_none = 0\n",
97 | "for item in combine_results:\n",
98 | " if item['pred_answer'] == None:\n",
99 | " count_none += 1\n",
100 | "print(count_none)\n",
101 | "results = evaluate(combine_results)\n",
102 | "print(*[f\"{results[k]*100:.2f}\" for k in ['simple_f1', 'simple_em', 'hard_f1', 'hard_em', 'avg_f1', 'avg_em']])"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": []
111 | }
112 | ],
113 | "metadata": {
114 | "kernelspec": {
115 | "display_name": "r1-v",
116 | "language": "python",
117 | "name": "python3"
118 | },
119 | "language_info": {
120 | "codemirror_mode": {
121 | "name": "ipython",
122 | "version": 3
123 | },
124 | "file_extension": ".py",
125 | "mimetype": "text/x-python",
126 | "name": "python",
127 | "nbconvert_exporter": "python",
128 | "pygments_lexer": "ipython3",
129 | "version": "3.11.11"
130 | },
131 | "orig_nbformat": 4
132 | },
133 | "nbformat": 4,
134 | "nbformat_minor": 2
135 | }
136 |
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/MAT-Search-Benchmark/evaluation_mat_search_ngpu_7b_df.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4 | import json
5 | import torch
6 | from tqdm import tqdm
7 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
8 | from qwen_vl_utils import process_vision_info
9 |
10 | # ANSI color codes for console output
11 | RED = '\033[91m'
12 | GREEN = '\033[92m'
13 | YELLOW = '\033[93m'
14 | RESET = '\033[0m'  # reset color
15 |
16 |
17 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct"
18 |
19 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
20 | model_path,
21 | torch_dtype=torch.bfloat16,
22 | attn_implementation="flash_attention_2",
23 | device_map='cuda:0',
24 | )
25 | processor = AutoProcessor.from_pretrained(model_path)
26 |
27 | with open('/share_data/MAT/MAT-Benchmark/MAT-Coding.json', 'r') as file:
28 | wikimultihopqa = json.load(file)
29 | print(len(wikimultihopqa))
30 |
31 | combine_results = []
32 | for i in tqdm(range(len(wikimultihopqa))):
33 | pred_answer = None
34 | query = wikimultihopqa[i]['question']
35 | answer = wikimultihopqa[i]['answer']
36 | image_path = wikimultihopqa[i]['image_path']
37 | input_image_path = '/MAT/MAT-Benchmark/MAT-Coding-image/' + image_path
38 | item_id = wikimultihopqa[i]['id']
39 |
40 | # ### Direct Inference
41 | input_text = query + '\n' + "Answer the question directly. The answer should be very brief."
42 |
43 | ### CoT
44 | # input_text = SYSTEM_PROMPT + '\n' + query + "\n" + "You must output your thinking processs in . The answer between should be very brief."
45 |
46 | print(RED+input_text+RESET)
47 | print(GREEN+str(answer)+RESET)
48 |
49 | try:
50 |         ###### Run one round of inference ######
51 | messages = [
52 | { "role": "user", "content": [{"type": "image","image": input_image_path}, {"type": "text", "text": input_text}]}
53 | ]
54 | # Preparation for inference
55 | text = processor.apply_chat_template(
56 | messages, tokenize=False, add_generation_prompt=True
57 | )
58 | image_inputs, video_inputs = process_vision_info(messages)
59 | inputs = processor(
60 | text=[text],
61 | images=image_inputs,
62 | videos=video_inputs,
63 | padding=True,
64 | return_tensors="pt",
65 | )
66 | inputs = inputs.to(model.device)
67 |
68 | # Inference: Generation of the output
69 | generated_ids = model.generate(**inputs, max_new_tokens=2048)
70 | generated_ids_trimmed = [
71 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
72 | ]
73 | output_text = processor.batch_decode(
74 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
75 | )
76 | result = output_text[0]
77 |
78 | pred_answer = result
79 | print(YELLOW+result+RESET)
80 |
81 | # match = re.search(r"(.*?)", result, re.DOTALL)
82 | # if match:
83 | # pred_answer = match.group(1).strip()
84 | # print(YELLOW+result+RESET)
85 | # print(YELLOW+pred_answer+RESET)
86 |
87 | except Exception as e:
88 |         print("ERROR OCCURRED")
89 |         print(e)
90 | combine_results.append(
91 | {'pred_answer': pred_answer, 'gt': answer, 'query': query}
92 | )
93 |
94 | print(len(combine_results))
95 | with open(f"./7B_df.json", "w", encoding="utf-8") as f:
96 | json.dump(combine_results, f, ensure_ascii=False, indent=4)
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation on MAT-Search
2 | In this folder, we provide scripts for evaluation on **MAT-Search** and four other multi-hop QA benchmarks (2WikiMultihopQA, HotpotQA, MuSiQue, Bamboogle).
3 | The `MAT-Search-Benchmark` folder is used to evaluate results on **MAT-Search**, while the `traditional_benchmark` folder is used to evaluate inference results on the four other multi-hop QA benchmarks.
4 |
5 | ## Download Dataset and Models
6 | You can download the MAT Benchmark from 🤗Datasets. Use `/MAT/MAT-Benchmark/MAT-Search.json`.
7 |
8 | You can download our model: 🤗Visual-ARFT-Search
9 |
10 | You can download the four other multi-hop QA benchmarks from 🤗Datasets.
11 |
12 | ## Web Search API
13 | We use SerperAPI for web search. You can start by registering an account to receive 2,500 free queries, and then add your API key to the `.env` file.
14 |
15 | ## Inference on MAT-Search
16 |
17 | First, `cd MAT-Search-Benchmark`. This folder contains `evaluation_mat_search_ngpu_7b_df.py` and `evaluation_mat_search_ngpu_7b_visual_arft.py`: the first script evaluates the results of direct inference, while the second evaluates the results of the **Visual-ARFT-Search** model.
18 |
19 | To run `evaluation_mat_search_ngpu_7b_visual_arft.py`, you need to replace the paths to the model and dataset:
20 |
21 | - Line 59: Replace **model_name** with the actual model path (**Visual-ARFT-Search**).
22 | - Line 76: Replace **json_path** with the actual dataset path (MAT-Search.json).
23 | - Line 97: Replace **input_image_path** with the actual image path.
24 | - Line 192: Set the results save path.
25 |
26 | To run `evaluation_mat_search_ngpu_7b_df.py`, you need to replace the paths to the model and dataset:
27 |
28 | - Line 17: Replace **model_name** with the actual model path (**Original Qwen2.5-VL without Visual-ARFT**).
29 | - Line 27: Replace **json_path** with the actual dataset path (MAT-Search.json).
30 | - Line 37: Replace **input_image_path** with the actual image path.
31 | - Line 95: Set the results save path.
32 |
33 | > ⏳ `evaluation_mat_search_ngpu_7b_visual_arft.py` supports multi-GPU execution. Inference with **Visual-ARFT-Search-3B/7B** takes around ten minutes on 8 GPUs.
34 |
35 | ## Inference on the Other Multi-hop QA Benchmarks
36 |
37 | First, `cd traditional_benchmark`, then replace the paths as instructed above.
38 |
39 | **2Wiki** contains over 12k questions and **HotpotQA** has over 7k, so the inference process may take a relatively long time. However, all scripts in the `traditional_benchmark` folder support multi-GPU inference.
40 |
41 | ## Evaluation
42 | After obtaining the inference results, run `evaluation.ipynb` in each folder step by step; the notebook reports the final evaluation scores.
43 |
44 | We have also saved the inference results of the `Visual-ARFT-Search-7B` model in the `evaluation_results` folder. You can use them directly to reproduce the evaluation scores.
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
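53 | ## How the multi-GPU sharding works (for reference)
54 |
55 | All `*_ngpu_*` scripts split the benchmark across the visible GPUs with a simple ceil split and write one `*_{rank}.json` result file per GPU, which `evaluation.ipynb` then merges. A condensed sketch of that sharding logic (mirroring the code in the scripts):
56 |
57 | ```python
58 | import math
59 |
60 | def shard(data, rank, world_size):
61 |     # Each rank processes a contiguous chunk of ceil(len(data) / world_size) items.
62 |     chunk = math.ceil(len(data) / world_size)
63 |     return data[rank * chunk : (rank + 1) * chunk]
64 | ```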
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/tools/web_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | from duckduckgo_search import DDGS
3 | import json
4 | import requests
5 | from typing import List, Literal, Optional, Union
6 | from firecrawl import FirecrawlApp
7 | from dotenv import load_dotenv
8 | load_dotenv()
9 |
10 | # Use the free DuckDuckGo search for web retrieval, and Firecrawl to convert pages to Markdown.
11 | def web_search_DDG(query: Optional[str], search_num: int = 2, search_mode: str = 'fast') -> Optional[List[str]]:
12 | assert search_mode == 'fast' or search_mode == 'pro'
13 | if search_mode == 'fast':
14 | assert type(query)==str
15 | results = DDGS().text(query, max_results=search_num)
16 | return results
17 | elif search_mode == 'pro':
18 | assert type(query)==str
19 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API"))
20 | results = DDGS().text(query, max_results=search_num)
21 | for result in results:
22 | web_url = result['href']
23 | # firecrawl_app returns markdown and metadata
24 | web_content = firecrawl_app.scrape_url(web_url)
25 | web_content_markdown = web_content['markdown']
26 | web_content_metadata = web_content['metadata']
27 | result['web_content_markdown'] = web_content_markdown
28 | result['web_content_metadata'] = web_content_metadata
29 | return results
30 |
31 | def web_search_SERPER_API(query: Optional[str], search_num=2, search_mode: str = 'fast') -> Optional[List[str]]:
32 | assert search_mode == 'fast' or search_mode == 'pro'
33 | if search_mode == 'fast':
34 | assert type(query)==str
35 | url = "https://google.serper.dev/search"
36 | payload = json.dumps({"q": query, "num": search_num})
37 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'}
38 | response = requests.request("POST", url, headers=headers, data=payload)
39 | response = json.loads(response.text)
40 | results = []
41 | for item in response['organic']:
42 | results.append(
43 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']}
44 | )
45 | return results
46 | elif search_mode == 'pro':
47 | assert type(query)==str
48 | firecrawl_app = FirecrawlApp(api_key=os.getenv("FIRECRAWL_API"))
49 | url = "https://google.serper.dev/search"
50 | payload = json.dumps({"q": query, "num": search_num})
51 | headers = {'X-API-KEY': os.getenv('SERPER_API'), 'Content-Type': 'application/json'}
52 | response = requests.request("POST", url, headers=headers, data=payload)
53 | response = json.loads(response.text)
54 | results = []
55 | for item in response['organic']:
56 | results.append(
57 | {'title': item['title'], 'href':item['link'], 'body': item['snippet']}
58 | )
59 | for result in results:
60 | web_url = result['href']
61 | # firecrawl_app returns markdown and metadata
62 | web_content = firecrawl_app.scrape_url(web_url)
63 | web_content_markdown = web_content['markdown']
64 | web_content_metadata = web_content['metadata']
65 | result['web_content_markdown'] = web_content_markdown
66 | result['web_content_metadata'] = web_content_metadata
67 | return results
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/traditional_benchmark/evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import re\n",
10 | "import string\n",
11 | "def normalize(s):\n",
12 | " def remove_articles(text):\n",
13 | " return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\n",
14 | "\n",
15 | " def white_space_fix(text):\n",
16 | " return \" \".join(text.split())\n",
17 | "\n",
18 | " def remove_punc(text):\n",
19 | " exclude = set(string.punctuation)\n",
20 | " return \"\".join(ch for ch in text if ch not in exclude)\n",
21 | "\n",
22 | " def lower(text):\n",
23 | " return text.lower()\n",
24 | "\n",
25 | " return white_space_fix(remove_articles(remove_punc(lower(s))))\n",
26 | "\n",
27 | "def compute_f1(prediction, ground_truth):\n",
28 | " if prediction is None:\n",
29 | " return 0.0\n",
30 | " prediction_tokens = normalize(prediction).split()\n",
31 | " ground_truth_tokens = normalize(ground_truth).split()\n",
32 | "\n",
33 | " common = set(prediction_tokens) & set(ground_truth_tokens)\n",
34 | " num_same = len(common)\n",
35 | "\n",
36 | " if num_same == 0:\n",
37 | " return 0.0\n",
38 | "\n",
39 | " precision = num_same / len(prediction_tokens)\n",
40 | " recall = num_same / len(ground_truth_tokens)\n",
41 | " f1 = 2 * precision * recall / (precision + recall)\n",
42 | " return f1\n",
43 | "\n",
44 | "def exact_match_score(prediction, ground_truth):\n",
45 | " if prediction is None:\n",
46 | " return 0.0\n",
47 | " return int(normalize(prediction) == normalize(ground_truth))\n",
48 | "\n",
49 | "def evaluate(predictions):\n",
50 | " total = len(predictions)\n",
51 | " f1_total = 0\n",
52 | " em_total = 0\n",
53 | "\n",
54 | " for item in predictions:\n",
55 | " # if item['pred_answer_ori'] == None:\n",
56 | " pred = item['pred_answer']\n",
57 | " # else:\n",
58 | " # pred = item['pred_answer_ori']\n",
59 | " gts = item['gt']\n",
60 | "\n",
61 |     "        # If gt is a str, wrap it in a list for uniform handling\n",
62 | " if isinstance(gts, str):\n",
63 | " gts = [gts]\n",
64 | "\n",
65 | " f1 = max([compute_f1(pred, gt) for gt in gts])\n",
66 | " em = max([exact_match_score(pred, gt) for gt in gts])\n",
67 | " if em == 1:\n",
68 | " f1 = 1\n",
69 | "\n",
70 | " f1_total += f1\n",
71 | " em_total += em\n",
72 | "\n",
73 | " return {\n",
74 | " \"avg_f1\": f1_total / total if total > 0 else 0,\n",
75 | " \"avg_em\": em_total / total if total > 0 else 0\n",
76 | " }"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "import json\n",
86 | "\n",
87 | "# results list:\n",
88 | "json_file_paths = [\n",
89 | " \"7b_step200_musique_web_n4_0.json\",\n",
90 | " \"7b_step200_musique_web_n4_1.json\",\n",
91 | " \"7b_step200_musique_web_n4_2.json\",\n",
92 | " \"7b_step200_musique_web_n4_3.json\",\n",
93 | " \"7b_step200_musique_web_n4_4.json\",\n",
94 | " \"7b_step200_musique_web_n4_5.json\",\n",
95 | " \"7b_step200_musique_web_n4_6.json\",\n",
96 | " \"7b_step200_musique_web_n4_7.json\",\n",
97 | "]\n",
98 | "\n",
99 | "# combined to one file\n",
100 | "merged_data = []\n",
101 | "for path in json_file_paths:\n",
102 | " with open(path, \"r\", encoding=\"utf-8\") as f:\n",
103 | " data = json.load(f)\n",
104 | " merged_data.extend(data)\n",
105 | "\n",
106 | "print(len(merged_data))\n",
107 | "# save\n",
108 | "with open(\"7b_step200_musique_web_n4.json\", \"w\", encoding=\"utf-8\") as f:\n",
109 | " json.dump(merged_data, f, ensure_ascii=False, indent=2)\n"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "import json\n",
119 | "with open('./evaluation_results/musique/7b_step200_musique_web_n4.json', 'r') as f:\n",
120 | " combine_results = json.load(f)\n",
121 | "print(len(combine_results))\n",
122 | "\n",
123 | "count_none = 0\n",
124 | "for item in combine_results:\n",
125 | " if item['pred_answer'] == None:\n",
126 | " count_none += 1\n",
127 | "print(count_none)\n",
128 | "results = evaluate(combine_results)\n",
129 | "print(\"Average F1:\", results['avg_f1'])\n",
130 | "print(\"Average EM:\", results['avg_em'])"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": []
139 | }
140 | ],
141 | "metadata": {
142 | "kernelspec": {
143 | "display_name": "r1-v",
144 | "language": "python",
145 | "name": "python3"
146 | },
147 | "language_info": {
148 | "codemirror_mode": {
149 | "name": "ipython",
150 | "version": 3
151 | },
152 | "file_extension": ".py",
153 | "mimetype": "text/x-python",
154 | "name": "python",
155 | "nbconvert_exporter": "python",
156 | "pygments_lexer": "ipython3",
157 | "version": "3.11.11"
158 | },
159 | "orig_nbformat": 4
160 | },
161 | "nbformat": 4,
162 | "nbformat_minor": 2
163 | }
164 |
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_2wikimultihopqa_ngpu_7b_df.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from tqdm import tqdm
4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5 |
6 | # ANSI color codes for console output
7 | RED = '\033[91m'
8 | GREEN = '\033[92m'
9 | YELLOW = '\033[93m'
10 | RESET = '\033[0m'  # reset color
11 |
12 | import logging
13 | logging.basicConfig()
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 |
17 | import functools
18 | import itertools
19 | import multiprocessing as mp
20 | from argparse import ArgumentParser
21 | from multiprocessing import Pool
22 |
23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct"
24 |
25 | def run(rank, world_size):
26 |     ### For multi-GPU runs, load with device_map='cpu' and then move the model to each GPU; 'auto' must not be used here
27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
28 | model_path,
29 | torch_dtype=torch.bfloat16,
30 | attn_implementation="flash_attention_2",
31 | device_map='cpu',
32 | )
33 | processor = AutoProcessor.from_pretrained(model_path)
34 |
35 | model = model.to(torch.device(rank))
36 | model = model.eval()
37 |
38 |     # Load the dataset
39 |     wikimultihopqa = []
40 |     with open('/2wikimultihopqa/dev.jsonl', 'r', encoding='utf-8') as f:
41 |         for line in f:
42 |             if line.strip():  # skip blank lines
43 | wikimultihopqa.append(json.loads(line))
44 | print(len(wikimultihopqa))
45 |
46 | # wikimultihopqa = wikimultihopqa[:4]
47 |
48 | print(wikimultihopqa[0]['question'])
49 | print(wikimultihopqa[0]['golden_answers'])
50 | print("Rank:" + str(rank))
51 | print("World Size:" + str(world_size))
52 | import math
53 | split_length = math.ceil(len(wikimultihopqa)/world_size)
54 | print("Split Chunk Length:" + str(split_length))
55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)]
56 | print(len(split_wikimultihopqa))
57 | wikimultihopqa = split_wikimultihopqa
58 |
59 | combine_results = []
60 | for i in tqdm(range(len(wikimultihopqa))):
61 | pred_answer = None
62 | query = wikimultihopqa[i]['question']
63 | answer = wikimultihopqa[i]['golden_answers']
64 | item_id = wikimultihopqa[i]['id']
65 | # + 'Answer the question directly.'
66 | input_text = query + '\n' + 'Only Return the answer.'
67 | # print("################################################################")
68 | # print(query)
69 | # print(answer)
70 | try:
71 | messages = [
72 | { "role": "user",
73 | "content": [{"type": "text", "text": input_text}]}
74 | ]
75 | # Preparation for inference
76 | text = processor.apply_chat_template(
77 | messages, tokenize=False, add_generation_prompt=True
78 | )
79 | # image_inputs, video_inputs = process_vision_info(messages)
80 | inputs = processor(
81 | text=[text],
82 | # images=image_inputs,
83 | # videos=video_inputs,
84 | padding=True,
85 | return_tensors="pt",
86 | )
87 | inputs = inputs.to(model.device)
88 |
89 | # Inference: Generation of the output
90 | generated_ids = model.generate(**inputs, max_new_tokens=2048)
91 | generated_ids_trimmed = [
92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
93 | ]
94 | output_text = processor.batch_decode(
95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
96 | )
97 | result = output_text[0]
98 | # print(result)
99 | pred_answer = result
100 |
101 | except Exception as e:
102 |             logger.info("ERROR OCCURRED")
103 |             logger.info(e)
104 | if pred_answer != None:
105 | combine_results.append(
106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query}
107 | )
108 | else:
109 | combine_results.append(
110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query}
111 | )
112 | with open(f"7b_ori_2wiki_direct_infere_{rank}.json", "w", encoding="utf-8") as f:
113 | json.dump(combine_results, f, ensure_ascii=False, indent=4)
114 | return combine_results
115 |
116 | def main():
117 | multiprocess = torch.cuda.device_count() >= 2
118 | mp.set_start_method('spawn')
119 | if multiprocess:
120 | logger.info('started generation')
121 | n_gpus = torch.cuda.device_count()
122 | world_size = n_gpus
123 | with Pool(world_size) as pool:
124 | func = functools.partial(run, world_size=world_size)
125 | result_lists = pool.map(func, range(world_size))
126 |
127 | global_results = []
128 | for i in range(world_size):
129 | global_results = global_results + result_lists[i]
130 |
131 | logger.info("Done")
132 | logger.info('finished running')
133 | else:
134 | logger.info("Not enough GPUs")
135 |
136 | if __name__ == "__main__":
137 | main()
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_bamboogle_ngpu_7b_df.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from tqdm import tqdm
4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5 |
6 | # ANSI color codes for console output
7 | RED = '\033[91m'
8 | GREEN = '\033[92m'
9 | YELLOW = '\033[93m'
10 | RESET = '\033[0m'  # reset color
11 |
12 | import logging
13 | logging.basicConfig()
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 |
17 | import functools
18 | import itertools
19 | import multiprocessing as mp
20 | from argparse import ArgumentParser
21 | from multiprocessing import Pool
22 |
23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct"
24 |
25 | def run(rank, world_size):
26 |     ### For multi-GPU runs, load with device_map='cpu' and then move the model to each GPU; 'auto' must not be used here
27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
28 | model_path,
29 | torch_dtype=torch.bfloat16,
30 | attn_implementation="flash_attention_2",
31 | device_map='cpu',
32 | )
33 | processor = AutoProcessor.from_pretrained(model_path)
34 |
35 | model = model.to(torch.device(rank))
36 | model = model.eval()
37 |
38 |     # Load the dataset
39 |     wikimultihopqa = []
40 |     with open('/bamboogle/test.jsonl', 'r', encoding='utf-8') as f:
41 |         for line in f:
42 |             if line.strip():  # skip blank lines
43 | wikimultihopqa.append(json.loads(line))
44 | print(len(wikimultihopqa))
45 |
46 | # wikimultihopqa = wikimultihopqa[:4]
47 |
48 | print(wikimultihopqa[0]['question'])
49 | print(wikimultihopqa[0]['golden_answers'])
50 | print("Rank:" + str(rank))
51 | print("World Size:" + str(world_size))
52 | import math
53 | split_length = math.ceil(len(wikimultihopqa)/world_size)
54 | print("Split Chunk Length:" + str(split_length))
55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)]
56 | print(len(split_wikimultihopqa))
57 | wikimultihopqa = split_wikimultihopqa
58 |
59 | combine_results = []
60 | for i in tqdm(range(len(wikimultihopqa))):
61 | pred_answer = None
62 | query = wikimultihopqa[i]['question']
63 | answer = wikimultihopqa[i]['golden_answers']
64 | item_id = wikimultihopqa[i]['id']
65 | # + 'Answer the question directly.'
66 | input_text = query + '\n' + 'Only Return the answer.'
67 | # print("################################################################")
68 | # print(query)
69 | # print(answer)
70 | try:
71 | messages = [
72 | { "role": "user",
73 | "content": [{"type": "text", "text": input_text}]}
74 | ]
75 | # Preparation for inference
76 | text = processor.apply_chat_template(
77 | messages, tokenize=False, add_generation_prompt=True
78 | )
79 | # image_inputs, video_inputs = process_vision_info(messages)
80 | inputs = processor(
81 | text=[text],
82 | # images=image_inputs,
83 | # videos=video_inputs,
84 | padding=True,
85 | return_tensors="pt",
86 | )
87 | inputs = inputs.to(model.device)
88 |
89 | # Inference: Generation of the output
90 | generated_ids = model.generate(**inputs, max_new_tokens=2048)
91 | generated_ids_trimmed = [
92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
93 | ]
94 | output_text = processor.batch_decode(
95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
96 | )
97 | result = output_text[0]
98 | # print(result)
99 | pred_answer = result
100 |
101 | except Exception as e:
102 |             logger.info("ERROR OCCURRED")
103 |             logger.info(e)
104 | if pred_answer != None:
105 | combine_results.append(
106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query}
107 | )
108 | else:
109 | combine_results.append(
110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query}
111 | )
112 | with open(f"7b_ori_bamboogle_direct_infere_{rank}.json", "w", encoding="utf-8") as f:
113 | json.dump(combine_results, f, ensure_ascii=False, indent=4)
114 | return combine_results
115 |
116 | def main():
117 | multiprocess = torch.cuda.device_count() >= 2
118 | mp.set_start_method('spawn')
119 | if multiprocess:
120 | logger.info('started generation')
121 | n_gpus = torch.cuda.device_count()
122 | world_size = n_gpus
123 | with Pool(world_size) as pool:
124 | func = functools.partial(run, world_size=world_size)
125 | result_lists = pool.map(func, range(world_size))
126 |
127 | global_results = []
128 | for i in range(world_size):
129 | global_results = global_results + result_lists[i]
130 |
131 | logger.info("Done")
132 | logger.info('finished running')
133 | else:
134 | logger.info("Not enough GPUs")
135 |
136 | if __name__ == "__main__":
137 | main()
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_hotpotqa_ngpu_7b_df.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from tqdm import tqdm
4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5 |
6 | # ANSI color codes for console output
7 | RED = '\033[91m'
8 | GREEN = '\033[92m'
9 | YELLOW = '\033[93m'
10 | RESET = '\033[0m'  # reset color
11 |
12 | import logging
13 | logging.basicConfig()
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 |
17 | import functools
18 | import itertools
19 | import multiprocessing as mp
20 | from argparse import ArgumentParser
21 | from multiprocessing import Pool
22 |
23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct"
24 |
25 | def run(rank, world_size):
26 |     ### For multi-GPU runs, load with device_map='cpu' and then move the model to each GPU; 'auto' must not be used here
27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
28 | model_path,
29 | torch_dtype=torch.bfloat16,
30 | attn_implementation="flash_attention_2",
31 | device_map='cpu',
32 | )
33 | processor = AutoProcessor.from_pretrained(model_path)
34 |
35 | model = model.to(torch.device(rank))
36 | model = model.eval()
37 |
38 |     # Load the dataset
39 |     wikimultihopqa = []
40 |     with open('/hotpotqa/dev.jsonl', 'r', encoding='utf-8') as f:
41 |         for line in f:
42 |             if line.strip():  # skip blank lines
43 | wikimultihopqa.append(json.loads(line))
44 | print(len(wikimultihopqa))
45 |
46 | # wikimultihopqa = wikimultihopqa[:4]
47 |
48 | print(wikimultihopqa[0]['question'])
49 | print(wikimultihopqa[0]['golden_answers'])
50 | print("Rank:" + str(rank))
51 | print("World Size:" + str(world_size))
52 | import math
53 | split_length = math.ceil(len(wikimultihopqa)/world_size)
54 | print("Split Chunk Length:" + str(split_length))
55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)]
56 | print(len(split_wikimultihopqa))
57 | wikimultihopqa = split_wikimultihopqa
58 |
59 | combine_results = []
60 | for i in tqdm(range(len(wikimultihopqa))):
61 | pred_answer = None
62 | query = wikimultihopqa[i]['question']
63 | answer = wikimultihopqa[i]['golden_answers']
64 | item_id = wikimultihopqa[i]['id']
65 | # + 'Answer the question directly.'
66 | input_text = query + '\n' + 'Only Return the answer.'
67 | # print("################################################################")
68 | # print(query)
69 | # print(answer)
70 | try:
71 | messages = [
72 | { "role": "user",
73 | "content": [{"type": "text", "text": input_text}]}
74 | ]
75 | # Preparation for inference
76 | text = processor.apply_chat_template(
77 | messages, tokenize=False, add_generation_prompt=True
78 | )
79 | # image_inputs, video_inputs = process_vision_info(messages)
80 | inputs = processor(
81 | text=[text],
82 | # images=image_inputs,
83 | # videos=video_inputs,
84 | padding=True,
85 | return_tensors="pt",
86 | )
87 | inputs = inputs.to(model.device)
88 |
89 | # Inference: Generation of the output
90 | generated_ids = model.generate(**inputs, max_new_tokens=2048)
91 | generated_ids_trimmed = [
92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
93 | ]
94 | output_text = processor.batch_decode(
95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
96 | )
97 | result = output_text[0]
98 | # print(result)
99 | pred_answer = result
100 |
101 | except Exception as e:
102 |             logger.info("ERROR OCCURRED")
103 |             logger.info(e)
104 | if pred_answer != None:
105 | combine_results.append(
106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query}
107 | )
108 | else:
109 | combine_results.append(
110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query}
111 | )
112 | with open(f"7b_ori_hotpotqa_direct_infere_{rank}.json", "w", encoding="utf-8") as f:
113 | json.dump(combine_results, f, ensure_ascii=False, indent=4)
114 | return combine_results
115 |
116 | def main():
117 | multiprocess = torch.cuda.device_count() >= 2
118 | mp.set_start_method('spawn')
119 | if multiprocess:
120 | logger.info('started generation')
121 | n_gpus = torch.cuda.device_count()
122 | world_size = n_gpus
123 | with Pool(world_size) as pool:
124 | func = functools.partial(run, world_size=world_size)
125 | result_lists = pool.map(func, range(world_size))
126 |
127 | global_results = []
128 | for i in range(world_size):
129 | global_results = global_results + result_lists[i]
130 |
131 | logger.info("Done")
132 | logger.info('finished running')
133 | else:
134 | logger.info("Not enough GPUs")
135 |
136 | if __name__ == "__main__":
137 | main()
--------------------------------------------------------------------------------
/Visual-ARFT/evaluation_search/traditional_benchmark/evaluation_musique_ngpu_7b_df.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from tqdm import tqdm
4 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5 |
6 | # ANSI color codes for console output
7 | RED = '\033[91m'
8 | GREEN = '\033[92m'
9 | YELLOW = '\033[93m'
10 | RESET = '\033[0m'  # reset color
11 |
12 | import logging
13 | logging.basicConfig()
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 |
17 | import functools
18 | import itertools
19 | import multiprocessing as mp
20 | from argparse import ArgumentParser
21 | from multiprocessing import Pool
22 |
23 | model_path = "/shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-7B-Instruct"
24 |
25 | def run(rank, world_size):
26 |     ### For multi-GPU runs, load with device_map='cpu' and then move the model to each GPU; 'auto' must not be used here
27 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
28 | model_path,
29 | torch_dtype=torch.bfloat16,
30 | attn_implementation="flash_attention_2",
31 | device_map='cpu',
32 | )
33 | processor = AutoProcessor.from_pretrained(model_path)
34 |
35 | model = model.to(torch.device(rank))
36 | model = model.eval()
37 |
38 |     # Load the dataset
39 |     wikimultihopqa = []
40 |     with open('/musique/dev.jsonl', 'r', encoding='utf-8') as f:
41 |         for line in f:
42 |             if line.strip():  # skip blank lines
43 | wikimultihopqa.append(json.loads(line))
44 | print(len(wikimultihopqa))
45 |
46 | # wikimultihopqa = wikimultihopqa[:4]
47 |
48 | print(wikimultihopqa[0]['question'])
49 | print(wikimultihopqa[0]['golden_answers'])
50 | print("Rank:" + str(rank))
51 | print("World Size:" + str(world_size))
52 | import math
53 | split_length = math.ceil(len(wikimultihopqa)/world_size)
54 | print("Split Chunk Length:" + str(split_length))
55 | split_wikimultihopqa = wikimultihopqa[int(rank*split_length) : int((rank+1)*split_length)]
56 | print(len(split_wikimultihopqa))
57 | wikimultihopqa = split_wikimultihopqa
58 |
59 | combine_results = []
60 | for i in tqdm(range(len(wikimultihopqa))):
61 | pred_answer = None
62 | query = wikimultihopqa[i]['question']
63 | answer = wikimultihopqa[i]['golden_answers']
64 | item_id = wikimultihopqa[i]['id']
65 | # + 'Answer the question directly.'
66 | input_text = query + '\n' + 'Only Return the answer.'
67 | # print("################################################################")
68 | # print(query)
69 | # print(answer)
70 | try:
71 | messages = [
72 | { "role": "user",
73 | "content": [{"type": "text", "text": input_text}]}
74 | ]
75 | # Preparation for inference
76 | text = processor.apply_chat_template(
77 | messages, tokenize=False, add_generation_prompt=True
78 | )
79 | # image_inputs, video_inputs = process_vision_info(messages)
80 | inputs = processor(
81 | text=[text],
82 | # images=image_inputs,
83 | # videos=video_inputs,
84 | padding=True,
85 | return_tensors="pt",
86 | )
87 | inputs = inputs.to(model.device)
88 |
89 | # Inference: Generation of the output
90 | generated_ids = model.generate(**inputs, max_new_tokens=2048)
91 | generated_ids_trimmed = [
92 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
93 | ]
94 | output_text = processor.batch_decode(
95 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
96 | )
97 | result = output_text[0]
98 | # print(result)
99 | pred_answer = result
100 |
101 | except Exception as e:
102 | logger.info("ERROR OCCURRED")
103 | logger.info(e)
104 | if pred_answer is not None:
105 | combine_results.append(
106 | {'id': item_id, 'pred_answer': pred_answer, 'gt': answer, 'query': query}
107 | )
108 | else:
109 | combine_results.append(
110 | {'id': item_id, 'pred_answer': None, 'gt': answer, 'query': query}
111 | )
112 | with open(f"7b_ori_musique_direct_infere_{rank}.json", "w", encoding="utf-8") as f:
113 | json.dump(combine_results, f, ensure_ascii=False, indent=4)
114 | return combine_results
115 |
116 | def main():
117 | multiprocess = torch.cuda.device_count() >= 2
118 | mp.set_start_method('spawn')
119 | if multiprocess:
120 | logger.info('started generation')
121 | n_gpus = torch.cuda.device_count()
122 | world_size = n_gpus
123 | with Pool(world_size) as pool:
124 | func = functools.partial(run, world_size=world_size)
125 | result_lists = pool.map(func, range(world_size))
126 |
127 | global_results = []
128 | for i in range(world_size):
129 | global_results = global_results + result_lists[i]
130 |
131 | logger.info("Done")
132 | logger.info('finished running')
133 | else:
134 | logger.info("Not enough GPUs")
135 |
136 | if __name__ == "__main__":
137 | main()
--------------------------------------------------------------------------------
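The direct-inference scripts above each dump one JSON shard per GPU rank (for example 7b_ori_musique_direct_infere_{rank}.json). Below is a minimal sketch of gathering those shards before scoring; the glob pattern, the assumption that gt is a list of gold-answer strings, and the containment-style hit count are illustrative only and are not the metric used in the repo's evaluation notebooks.

import glob
import json

def merge_shards(pattern="7b_ori_musique_direct_infere_*.json"):
    # Collect the per-rank dumps written by run() into one list.
    merged = []
    for path in sorted(glob.glob(pattern)):
        with open(path, "r", encoding="utf-8") as f:
            merged.extend(json.load(f))
    return merged

if __name__ == "__main__":
    results = merge_shards()
    # Toy check (assumption): count predictions containing any gold answer string.
    hits = sum(
        1
        for r in results
        if r["pred_answer"] is not None
        and any(g.lower() in r["pred_answer"].lower() for g in r["gt"])
    )
    print(f"{hits}/{len(results)} predictions contain a gold answer")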
/Visual-ARFT/setup.sh:
--------------------------------------------------------------------------------
1 | cd src/visual_arft
2 | pip install -e ".[dev]"
3 |
4 | # Additional modules
5 | pip install wandb==0.18.3
6 | pip install tensorboardx
7 | pip install qwen_vl_utils torchvision
8 | pip install flash-attn --no-build-isolation
9 |
10 | # vLLM support
11 | pip install vllm==0.7.2
12 |
13 | # fix transformers version
14 | pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
15 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/scripts/run_grpo_agent_code_3b_1_2k_new2_gpu8.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true" # Enable debug logging if you want to see the model's rollouts during RL
2 | export LOG_PATH="./log_qwen25vl_3b_grpo_agent_code_1_2k_new2_gpu8.txt"
3 |
4 | torchrun --nproc_per_node="8" \
5 | --nnodes="1" \
6 | --node_rank="0" \
7 | --master_addr="127.0.0.1" \
8 | --master_port="12345" \
9 | /src/visual_arft/src/open_r1/grpo_agent_code.py \
10 | --output_dir /share_models/Qwen2.5-VL-3B-Instruct_GRPO_agent_code_1_2k_new2_gpu8 \
11 | --model_name_or_path /shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-3B-Instruct \
12 | --dataset_name /train_data/rft_agent_code_1_2k.json \
13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \
14 | --max_prompt_length 2048 \
15 | --per_device_train_batch_size 1 \
16 | --gradient_accumulation_steps 2 \
17 | --logging_steps 1 \
18 | --bf16 true \
19 | --report_to wandb \
20 | --gradient_checkpointing true \
21 | --attn_implementation flash_attention_2 \
22 | --max_pixels 401408 \
23 | --num_train_epochs 10 \
24 | --run_name Qwen25-VL-3B-GRPO-Agent-code-1_2k-new2-gpu8 \
25 | --save_steps 100 \
26 | --save_only_model true \
27 | --num_generations 8
--------------------------------------------------------------------------------
/Visual-ARFT/src/scripts/run_grpo_agent_code_7b_1_2k_new2_gpu8.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true" # Enable debug logging if you want to see the model's rollouts during RL
2 | export LOG_PATH="./log_qwen25vl_7b_grpo_agent_code_1_2k_new2_gpu8.txt"
3 |
4 | torchrun --nproc_per_node="8" \
5 | --nnodes="1" \
6 | --node_rank="0" \
7 | --master_addr="127.0.0.1" \
8 | --master_port="12345" \
9 | /src/visual_arft/src/open_r1/grpo_agent_code.py \
10 | --output_dir /share_models/Qwen2.5-VL-7B-Instruct_GRPO_agent_code_1_2k_new2_gpu8 \
11 | --model_name_or_path /share_model/Qwen2.5-VL-7B-Instruct \
12 | --dataset_name /train_data/rft_agent_code_1_2k.json \
13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \
14 | --max_prompt_length 2048 \
15 | --per_device_train_batch_size 1 \
16 | --gradient_accumulation_steps 2 \
17 | --logging_steps 1 \
18 | --bf16 true \
19 | --report_to wandb \
20 | --gradient_checkpointing true \
21 | --attn_implementation flash_attention_2 \
22 | --max_pixels 401408 \
23 | --num_train_epochs 10 \
24 | --run_name Qwen25-VL-7B-GRPO-Agent-code-1_2k-new2-gpu8 \
25 | --save_steps 100 \
26 | --save_only_model true \
27 | --num_generations 8
--------------------------------------------------------------------------------
/Visual-ARFT/src/scripts/run_grpo_agent_search_3b_gpu8.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true" # Enable debug logging if you want to see the model's rollouts during RL
2 | export LOG_PATH="./log_qwen25vl_3b_grpo_agent_search_data20_63_gpu8.txt"
3 |
4 | torchrun --nproc_per_node="8" \
5 | --nnodes="1" \
6 | --node_rank="0" \
7 | --master_addr="127.0.0.1" \
8 | --master_port="12345" \
9 | /src/visual_arft/src/open_r1/grpo_agent_search.py \
10 | --output_dir /share_models/Qwen2.5-VL-3B-Instruct_GRPO_agent_search_data20_63_gpu8 \
11 | --model_name_or_path /shared/mllm_ckpts/models--Qwen--Qwen2.5-VL-3B-Instruct \
12 | --dataset_name /train_data/rft_agent_20.json \
13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \
14 | --max_prompt_length 2048 \
15 | --per_device_train_batch_size 1 \
16 | --gradient_accumulation_steps 2 \
17 | --logging_steps 1 \
18 | --bf16 true \
19 | --report_to wandb \
20 | --gradient_checkpointing true \
21 | --attn_implementation flash_attention_2 \
22 | --max_pixels 401408 \
23 | --num_train_epochs 400 \
24 | --run_name Qwen25-VL-3B-GRPO-Agent-Search-data20-63-gpu8 \
25 | --save_steps 100 \
26 | --save_only_model true \
27 | --num_generations 8
--------------------------------------------------------------------------------
/Visual-ARFT/src/scripts/run_grpo_agent_search_7b_gpu8.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true" # Enable debug logging if you want to see the model's rollouts during RL
2 | export LOG_PATH="./log_qwen25vl_7b_grpo_agent_search_data20_63_gpu8.txt"
3 |
4 | torchrun --nproc_per_node="8" \
5 | --nnodes="1" \
6 | --node_rank="0" \
7 | --master_addr="127.0.0.1" \
8 | --master_port="12345" \
9 | /src/visual_arft/src/open_r1/grpo_agent_search.py \
10 | --output_dir /share_models/Qwen2.5-VL-7B-Instruct_GRPO_agent_search_data20_63_gpu8 \
11 | --model_name_or_path /share_model/Qwen2.5-VL-7B-Instruct \
12 | --dataset_name /train_data/rft_agent_20.json \
13 | --deepspeed /src/visual_arft/local_scripts/zero3_offload.json \
14 | --max_prompt_length 2048 \
15 | --per_device_train_batch_size 1 \
16 | --gradient_accumulation_steps 2 \
17 | --logging_steps 1 \
18 | --bf16 true \
19 | --report_to wandb \
20 | --gradient_checkpointing true \
21 | --attn_implementation flash_attention_2 \
22 | --max_pixels 401408 \
23 | --num_train_epochs 400 \
24 | --run_name Qwen25-VL-7B-GRPO-Agent-Search-data20-63-gpu8 \
25 | --save_steps 100 \
26 | --save_only_model true \
27 | --num_generations 8
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # Temp folders
174 | data/
175 | wandb/
176 | scripts/
177 | checkpoints/
178 | .vscode/
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: style quality
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := src
7 |
8 | style:
9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py
10 | isort $(check_dirs) setup.py
11 |
12 | quality:
13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14 | isort --check-only $(check_dirs) setup.py
15 | flake8 --max-line-length 119 $(check_dirs) setup.py
16 |
17 |
18 | # Evaluation
19 |
20 | evaluate:
21 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/configs/ddp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/configs/qwen2vl_sft_config.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: Qwen/Qwen2-VL-2B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 |
6 | # Data training arguments
7 | dataset_name: MMInstruction/Clevr_CoGenT_TrainA_R1
8 | dataset_configs:
9 | - all
10 | preprocessing_num_workers: 8
11 |
12 | # SFT trainer config
13 | bf16: true
14 | do_eval: true
15 | eval_strategy: "no"
16 | gradient_accumulation_steps: 4
17 | gradient_checkpointing: true
18 | gradient_checkpointing_kwargs:
19 | use_reentrant: false
20 | hub_model_id: Qwen2-VL-2B-Instruct-SFT
21 | hub_strategy: every_save
22 | learning_rate: 2.0e-05
23 | log_level: info
24 | logging_steps: 5
25 | logging_strategy: steps
26 | lr_scheduler_type: cosine
27 | packing: true
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 1
31 | output_dir: data/Qwen2-VL-2B-Instruct-SFT
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 4
34 | per_device_train_batch_size: 4
35 | push_to_hub: true
36 | report_to:
37 | - wandb
38 | save_strategy: "no"
39 | seed: 42
40 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/configs/zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/configs/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/create_vision_cot_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import concurrent.futures
4 | import io
5 | import json
6 | import os
7 | import random
8 | import re
9 | import time
10 | from concurrent.futures import ThreadPoolExecutor
11 | from functools import partial
12 | from io import BytesIO
13 | from typing import Dict, List
14 |
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | import pandas as pd
18 | from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
19 | from tqdm import tqdm
20 |
21 | import bytedtos
22 | import seaborn as sns
23 | import yaml
24 | from openai import AzureOpenAI
25 | from PIL import Image
26 | from pillow_avif import AvifImagePlugin
27 |
28 |
29 | PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
30 |
31 | Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
32 |
33 | Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
34 |
35 | Input Format:
36 | Original Question: {original_question}
37 | Original Answer: {original_answer}
38 |
39 | Output Format:
40 | Question: [rewrite the question if necessary]
41 | Answer: [answer with reasoning steps, including calculations where applicable]
42 | <think>step-by-step reasoning process</think>
43 | <answer>easy to verify answer</answer>
44 | """
45 |
46 |
47 | def get_image_data_url(image_input):
48 | if isinstance(image_input, str) and image_input.startswith("data:"):
49 | return image_input
50 |
51 | if isinstance(image_input, str) and image_input.startswith("http"):
52 | image_input = load_image(image_input)
53 |
54 | if isinstance(image_input, str):
55 | image_input = Image.open(image_input)
56 |
57 | if not isinstance(image_input, Image.Image):
58 | raise ValueError("Unsupported image input type")
59 |
60 | if image_input.mode != "RGB":
61 | image_input = image_input.convert("RGB")
62 |
63 | buffer = BytesIO()
64 | image_input.save(buffer, format="JPEG")
65 | img_bytes = buffer.getvalue()
66 | base64_data = base64.b64encode(img_bytes).decode("utf-8")
67 | return f"data:image/jpeg;base64,{base64_data}"
68 |
69 |
70 | def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
71 | if image is None:
72 | return None
73 |
74 | data_url_list = [get_image_data_url(image)]
75 | client = AzureOpenAI(
76 | azure_endpoint="YOUR_AZURE_ENDPOINT",
77 | api_version="2023-07-01-preview",
78 | api_key="YOUR_API_KEY",
79 | )
80 |
81 | for attempt in range(max_retries):
82 | try:
83 | messages = [
84 | {
85 | "role": "system",
86 | "content": "You are an expert to analyze the image and provide useful information for users.",
87 | },
88 | {
89 | "role": "user",
90 | "content": [
91 | {"type": "text", "text": prompt},
92 | ],
93 | },
94 | ]
95 |
96 | for data_url in data_url_list:
97 | messages[1]["content"].insert(
98 | 0, {"type": "image_url", "image_url": {"url": data_url}}
99 | )
100 |
101 | response = client.chat.completions.create(
102 | model="gpt-4o-2024-08-06",
103 | messages=messages,
104 | temperature=0.2,
105 | max_tokens=8192,
106 | )
107 | return response.choices[0].message.content
108 |
109 | except Exception as e:
110 | if attempt == max_retries - 1:
111 | raise Exception(
112 | f"Failed after {max_retries} attempts. Last error: {str(e)}"
113 | )
114 | delay = initial_delay * (2**attempt) + random.uniform(
115 | 0, 0.1 * initial_delay * (2**attempt)
116 | )
117 | time.sleep(delay)
118 |
119 |
120 | def process_single_item(example):
121 | try:
122 | image_path = example["image_path"]
123 | formatted_prompt = PROMPT_FORMAT.format(
124 | original_question=example["question"], original_answer=example["answer"]
125 | )
126 |
127 | response = gpt4o_query(image_path, formatted_prompt)
128 | example["gpt4o_response"] = response
129 | return example
130 | except Exception as e:
131 | print(f"Error processing item: {str(e)}")
132 | example["gpt4o_response"] = None
133 | return example
134 |
135 |
136 | def main():
137 | dataset_path = "path/to/your/dataset"
138 | full_dataset = load_from_disk(dataset_path)
139 |
140 | processed_dataset = full_dataset.map(
141 | function=partial(process_single_item),
142 | num_proc=256,
143 | desc="Processing dataset with GPT-4o",
144 | keep_in_memory=True,
145 | )
146 |
147 | output_path = f"{dataset_path}_processed"
148 | processed_dataset.save_to_disk(output_path)
149 | print(f"Processed dataset saved to: {output_path}")
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
--------------------------------------------------------------------------------
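process_single_item() above reads the fields image_path, question, and answer from each record. A small sketch of preparing a dataset with that layout for load_from_disk() in main(); the toy rows and the /tmp path are made up for illustration.

from datasets import Dataset

# Minimal record layout consumed by process_single_item(): image_path, question, answer.
toy = Dataset.from_dict(
    {
        "image_path": ["/path/to/img_0.jpg"],
        "question": ["How many apples are on the table?"],
        "answer": ["3"],
    }
)
toy.save_to_disk("/tmp/toy_vision_cot")  # point dataset_path in main() at this directory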
/Visual-ARFT/src/visual_arft/local_scripts/lmms_eval_qwen2vl.sh:
--------------------------------------------------------------------------------
1 | export HF_HOME=""
2 | export HF_TOKEN=""
3 | export HF_HUB_ENABLE_HF_TRANSFER="1"
4 |
5 | export API_TYPE=""
6 | export AZURE_ENDPOINT=""
7 | export AZURE_API_KEY=""
8 | export API_VERSION=""
9 | export MODEL_VERSION=""
10 | export NAVIT_ATTENTION_IMPLEMENTATION="eager"
11 |
12 | # Prompt for installation with 3-second timeout
13 | read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
14 | if [ "$install_deps" = "YES" ]; then
15 | # Prepare the environment
16 | pip3 install --upgrade pip
17 | pip3 install -U setuptools
18 |
19 | cd
20 | if [ ! -d "maas_engine" ]; then
21 | git clone
22 | else
23 | echo "maas_engine directory already exists, skipping clone"
24 | fi
25 | cd maas_engine
26 | git pull
27 | git checkout
28 | pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
29 |
30 | current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
31 | if [ "$current_version" != "4.46.2" ]; then
32 | echo "Installing transformers 4.46.2 (current version: $current_version)"
33 | pip3 install transformers==4.46.2
34 | else
35 | echo "transformers 4.46.2 is already installed"
36 | fi
37 |
38 | cd
39 | rm -rf
40 | pip3 install -e .
41 | pip3 install -U pydantic
42 | pip3 install Levenshtein
43 | pip3 install nltk
44 | python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
45 | fi
46 |
47 | TASKS=mmmu_val,mathvista_testmini,mmmu_pro
48 | MODEL_BASENAME=qwen2_vl
49 |
50 | model_checkpoint=""
51 | echo "MODEL_BASENAME: ${MODEL_BASENAME}"
52 | cd
53 |
54 | python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
55 | --model qwen2_vl \
56 | --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
57 | --tasks ${TASKS} \
58 | --batch_size 1 \
59 | --log_samples \
60 | --log_samples_suffix ${MODEL_BASENAME} \
61 | --output_path ./logs
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/prepare_hf_data.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | import random
5 | from typing import List, Dict
6 | import numpy as np
7 | from concurrent.futures import ThreadPoolExecutor
8 | from tqdm import tqdm
9 | import datasets
10 |
11 | import io
12 | from datasets import load_dataset, load_from_disk, concatenate_datasets
13 | from PIL import Image
14 | from tqdm import tqdm
15 | from functools import partial
16 | from pillow_avif import AvifImagePlugin
17 | from datasets import Dataset
18 | import json
19 | import yaml
20 | import os
21 | import re
22 | import time
23 | import random
24 | import base64
25 | from openai import AzureOpenAI
26 | import concurrent.futures
27 | from typing import List, Dict
28 | import argparse
29 | import time
30 |
31 |
32 | def extract_problem_solution(gpt4o_response):
33 | # Split the response into parts
34 | parts = gpt4o_response.split("<think>")
35 |
36 | # Extract the problem (first part before any tags)
37 | problem = parts[0].strip()
38 | # Remove "Question:" prefix if it exists
39 | problem = re.sub(r"^Question:\s*", "", problem)
40 | # Remove "Answer:" at the end of the problem
41 | problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
42 |
43 | # Combine all the reasoning steps into a single block
44 | think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
45 | solution = f"<think>{' '.join(think_parts)}</think>"
46 |
47 | # Add the final answer if it exists, removing "Answer:" prefix
48 | if "<answer>" in gpt4o_response:
49 | final_answer = (
50 | gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
51 | )
52 | final_answer = re.sub(r"^Answer:\s*", "", final_answer)
53 | solution += f"\n\n<answer>{final_answer}</answer>"
54 |
55 | return problem, solution
56 |
57 |
58 | def load_image_from_path(image_path):
59 | try:
60 | img = Image.open(image_path)
61 | return img
62 | except Exception as e:
63 | print(f"Error loading image {image_path}: {str(e)}")
64 | return None
65 |
66 |
67 | def process_raw_data(raw_data):
68 | # Parse the raw data if it's a string
69 | if isinstance(raw_data, str):
70 | data = json.loads(raw_data)
71 | else:
72 | data = raw_data
73 |
74 | # Extract problem and solution
75 | try:
76 | problem, solution = extract_problem_solution(data["gpt4o_response"])
77 | image = load_image_from_path(data["image_path"])
78 |
79 | return {
80 | "image": image,
81 | "problem": problem,
82 | "solution": solution,
83 | "original_question": data["question"],
84 | "original_answer": data["answer"],
85 | }
86 | except Exception as e:
87 | print(f"Error processing data {data}: {str(e)}")
88 | return {
89 | "image": None,
90 | "problem": None,
91 | "solution": None,
92 | "original_question": None,
93 | "original_answer": None,
94 | }
95 |
96 |
97 | raw_data_list = [
98 | "/path/to/reasoning_data_with_response_90k_verified",
99 | ]
100 |
101 | raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
102 |
103 | processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
104 |
105 | hf_dict = {
106 | "image": [],
107 | "problem": [],
108 | "solution": [],
109 | "original_question": [],
110 | "original_answer": [],
111 | }
112 |
113 | for item in tqdm(processed_data):
114 | hf_dict["image"].append(item["image"])
115 | hf_dict["problem"].append(item["problem"])
116 | hf_dict["solution"].append(item["solution"])
117 | hf_dict["original_question"].append(item["original_question"])
118 | hf_dict["original_answer"].append(item["original_answer"])
119 |
120 |
121 | features = datasets.Features(
122 | {
123 | "image": datasets.Image(),
124 | "problem": datasets.Value("string"),
125 | "solution": datasets.Value("string"),
126 | "original_question": datasets.Value("string"),
127 | "original_answer": datasets.Value("string"),
128 | }
129 | )
130 |
131 |
132 | def has_empty_tags(text):
133 | # Pattern to match empty tag pairs like <think></think>
134 | pattern = r"<[^>]+></[^>]+>"
135 | return bool(re.search(pattern, text))
136 |
137 |
138 | def has_answer_pattern(text):
139 | if "Answer:" in text:
140 | return True
141 | return False
142 |
143 |
144 | def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
145 | # Assuming the image is in a format that can be checked for dimensions
146 | # You might need to adjust this depending on how the image is stored in your dataset
147 | try:
148 | image = example["image"] # or however your image is accessed
149 | if isinstance(image, dict) and "height" in image and "width" in image:
150 | return image["height"] >= 28 and image["width"] >= 28
151 | # If image is a PIL Image or similar
152 | return image.height >= 28 and image.width >= 28
153 | except:
154 | return False
155 |
156 |
157 | ds = datasets.Dataset.from_dict(hf_dict, features=features)
158 | ds = ds.filter(
159 | lambda x: not has_empty_tags(x["solution"])
160 | and not has_answer_pattern(x["problem"])
161 | and has_valid_image_size(x)
162 | and x["image"] is not None,
163 | num_proc=128,
164 | )
165 | # Push to Hugging Face Hub
166 | ds.push_to_hub("path/to/your/dataset")
167 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/train_aria_moe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_BLOCKING_WAIT=0
4 | export TOKENIZERS_PARALLELISM=false
5 | export OMP_NUM_THREADS=8
6 | export NCCL_IB_DISABLE=0
7 | export NCCL_IB_GID_INDEX=3
8 | export NCCL_SOCKET_IFNAME=eth0
9 | export NCCL_DEBUG=INFO
10 |
11 | # CONFIG Huggingface
12 | # export HF_TOKEN=""
13 | export HF_TOKEN=""
14 | export HF_HOME="$HOME/.cache/huggingface"
15 | export HF_HUB_ENABLE_HF_TRANSFER="1"
16 |
17 | export NCCL_DEBUG=INFO
18 |
19 | GPUS="0,1,2,3,4,5,6,7"
20 |
21 | # Take the first port of worker0
22 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
23 | port=${ports[0]}
24 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
25 |
26 | echo "total workers: ${ARNOLD_WORKER_NUM}"
27 | echo "cur worker id: ${ARNOLD_ID}"
28 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
29 | echo "master ip: ${METIS_WORKER_0_HOST}"
30 | echo "master port: ${port}"
31 | echo "master port in cmd: ${port_in_cmd}"
32 |
33 | # export WANDB_BASE_URL=https://api.wandb.ai
34 | # export WANDB_API_KEY=""
35 | # wandb login $WANDB_API_KEY
36 |
37 | export WANDB_BASE_URL=https://api.wandb.ai
38 | export WANDB_PROJECT=vision-reasoning
39 | export WANDB_API_KEY=""
40 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
41 | wandb login $WANDB_API_KEY
42 |
43 | cd /home/tiger/multimodal-open-r1
44 | # pip3 install vllm==0.6.6.post1
45 | pip3 install -e ".[dev]"
46 | pip3 install wandb==0.18.3
47 |
48 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
49 | --nnodes="${ARNOLD_WORKER_NUM}" \
50 | --node_rank="${ARNOLD_ID}" \
51 | --master_addr="${METIS_WORKER_0_HOST}" \
52 | --master_port="${port_in_cmd}" \
53 | src/open_r1/grpo.py \
54 | --deepspeed scripts/zero3.json \
55 | --output_dir Aria-GRPO-mini_cot_80k \
56 | --model_name_or_path rhymes-ai/Aria \
57 | --dataset_name luodian/mini_cot_80k \
58 | --max_prompt_length 8192 \
59 | --per_device_train_batch_size 1 \
60 | --gradient_accumulation_steps 1 \
61 | --logging_steps 1 \
62 | --bf16 \
63 | --report_to wandb \
64 | --gradient_checkpointing true \
65 | --attn_implementation eager \
66 | --save_total_limit 8 \
67 | --num_train_epochs 1 \
68 | --run_name $WANDB_RUN_NAME
69 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/train_qwen2_vl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_BLOCKING_WAIT=0
4 | export TOKENIZERS_PARALLELISM=false
5 | export OMP_NUM_THREADS=8
6 | export NCCL_IB_DISABLE=0
7 | export NCCL_IB_GID_INDEX=3
8 | export NCCL_SOCKET_IFNAME=eth0
9 | export NCCL_DEBUG=INFO
10 |
11 | GPUS="0,1,2,3,4,5,6,7"
12 |
13 | # Take the first port of worker0
14 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
15 | port=${ports[0]}
16 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
17 |
18 | echo "total workers: ${ARNOLD_WORKER_NUM}"
19 | echo "cur worker id: ${ARNOLD_ID}"
20 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
21 | echo "master ip: ${METIS_WORKER_0_HOST}"
22 | echo "master port: ${port}"
23 | echo "master port in cmd: ${port_in_cmd}"
24 |
25 | # export WANDB_BASE_URL=https://api.wandb.ai
26 | # export WANDB_API_KEY=""
27 | # wandb login $WANDB_API_KEY
28 |
29 | export WANDB_BASE_URL=https://api.wandb.ai
30 | export WANDB_PROJECT=vision-reasoning
31 | export WANDB_API_KEY=""
32 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
33 | wandb login $WANDB_API_KEY
34 |
35 | cd /home/tiger/multimodal-open-r1
36 | # pip3 install vllm==0.6.6.post1
37 | pip3 install -e ".[dev]"
38 | pip3 install wandb==0.18.3
39 |
40 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
41 | --nnodes="${ARNOLD_WORKER_NUM}" \
42 | --node_rank="${ARNOLD_ID}" \
43 | --master_addr="${METIS_WORKER_0_HOST}" \
44 | --master_port="${port_in_cmd}" \
45 | src/open_r1/grpo.py \
46 | --deepspeed scripts/zero3.json \
47 | --output_dir checkpoints/${WANDB_RUN_NAME} \
48 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
49 | --dataset_name luodian/${DATASET_NAME} \
50 | --max_prompt_length 8192 \
51 | --per_device_train_batch_size 1 \
52 | --gradient_accumulation_steps 1 \
53 | --logging_steps 1 \
54 | --bf16 \
55 | --report_to wandb \
56 | --gradient_checkpointing true \
57 | --attn_implementation flash_attention_2 \
58 | --max_pixels 2359296 \
59 | --save_total_limit 8 \
60 | --num_train_epochs 1 \
61 | --run_name $WANDB_RUN_NAME
62 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/zero1_no_optimizer.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 1e9,
6 | "overlap_comm": false,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 1e9,
9 | "contiguous_gradients": true
10 | },
11 | "fp16": {
12 | "enabled": "auto",
13 | "auto_cast": true,
14 | "loss_scale": 0,
15 | "initial_scale_power": 32,
16 | "loss_scale_window": 1000,
17 | "hysteresis": 2,
18 | "min_loss_scale": 1
19 | },
20 | "bf16": {
21 | "enabled": "auto"
22 | },
23 | "gradient_accumulation_steps": "auto",
24 | "gradient_clipping": "auto",
25 | "steps_per_print": 1,
26 | "train_batch_size": "auto",
27 | "train_micro_batch_size_per_gpu": "auto",
28 | "wall_clock_breakdown": true
29 | }
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": false
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": false
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/local_scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "train_batch_size": "auto",
45 | "train_micro_batch_size_per_gpu": "auto",
46 | "steps_per_print": 1e5,
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/run_grpo.sh:
--------------------------------------------------------------------------------
1 | cd src/visual_arft
2 |
3 | export DEBUG_MODE="true"
4 | export LOG_PATH="./debug_log_2b.txt"
5 |
6 |
7 |
8 | torchrun --nproc_per_node="8" \
9 | --nnodes="1" \
10 | --node_rank="0" \
11 | --master_addr="127.0.0.1" \
12 | --master_port="12345" \
13 | src/open_r1/grpo.py \
14 | --output_dir \
15 | --model_name_or_path \
16 | --dataset_name \
17 | --max_prompt_length 1024 \
18 | --per_device_train_batch_size 1 \
19 | --gradient_accumulation_steps 2 \
20 | --logging_steps 1 \
21 | --bf16 \
22 | --report_to wandb \
23 | --gradient_checkpointing false \
24 | --attn_implementation flash_attention_2 \
25 | --max_pixels 401408 \
26 | --num_train_epochs 2 \
27 | --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \
28 | --save_steps 100 \
29 | --save_only_model true
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = open_r1
7 | known_third_party =
8 | transformers
9 | datasets
10 | fugashi
11 | git
12 | h5py
13 | matplotlib
14 | nltk
15 | numpy
16 | packaging
17 | pandas
18 | psutil
19 | pytest
20 | rouge_score
21 | sacrebleu
22 | seqeval
23 | sklearn
24 | streamlit
25 | torch
26 | tqdm
27 |
28 | line_length = 119
29 | lines_after_imports = 2
30 | multi_line_output = 3
31 | use_parentheses = True
32 |
33 | [flake8]
34 | ignore = E203, E501, E741, W503, W605
35 | max-line-length = 119
36 | per-file-ignores =
37 | # imported but unused
38 | __init__.py: F401
39 |
40 | [tool:pytest]
41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
16 |
17 |
18 | import re
19 | import shutil
20 | from pathlib import Path
21 |
22 | from setuptools import find_packages, setup
23 |
24 |
25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
27 | if stale_egg_info.exists():
28 | print(
29 | (
30 | "Warning: {} exists.\n\n"
31 | "If you recently updated open_r1, this is expected,\n"
32 | "but it may prevent open_r1 from installing in editable mode.\n\n"
33 | "This directory is automatically generated by Python's packaging tools.\n"
34 | "I will remove it now.\n\n"
35 | "See https://github.com/pypa/pip/issues/5466 for details.\n"
36 | ).format(stale_egg_info)
37 | )
38 | shutil.rmtree(stale_egg_info)
39 |
40 |
41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
43 | _deps = [
44 | "accelerate>=1.2.1",
45 | "bitsandbytes>=0.42.0",
46 | "black>=24.4.2",
47 | "datasets>=3.2.0",
48 | "deepspeed==0.15.4",
49 | "distilabel[vllm,ray,openai]>=1.5.2",
50 | "einops>=0.8.0",
51 | "flake8>=6.0.0",
52 | "hf_transfer>=0.1.4",
53 | "huggingface-hub[cli]>=0.19.2,<1.0",
54 | "isort>=5.12.0",
55 | "liger_kernel==0.5.2",
56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
57 | "math-verify", # Used for math verification in grpo
58 | "packaging>=23.0",
59 | "parameterized>=0.9.0",
60 | "pytest",
61 | "safetensors>=0.3.3",
62 | "sentencepiece>=0.1.99",
63 | "torch>=2.5.1",
64 | # "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef",
65 | "trl==0.14.0",
66 | "vllm==0.6.6.post1",
67 | "wandb>=0.19.1",
68 | "pillow",
69 | ]
70 |
71 | # this is a lookup table with items like:
72 | #
73 | # tokenizers: "tokenizers==0.9.4"
74 | # packaging: "packaging"
75 | #
76 | # some of the values are versioned whereas others aren't.
77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
78 |
79 |
80 | def deps_list(*pkgs):
81 | return [deps[pkg] for pkg in pkgs]
82 |
83 |
84 | extras = {}
85 | extras["tests"] = deps_list("pytest", "parameterized")
86 | extras["torch"] = deps_list("torch")
87 | extras["quality"] = deps_list("black", "isort", "flake8")
88 | extras["eval"] = deps_list("lighteval", "math-verify")
89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
90 |
91 | # core dependencies shared across the whole project - keep this to a bare minimum :)
92 | install_requires = [
93 | deps["accelerate"],
94 | deps["bitsandbytes"],
95 | deps["einops"],
96 | deps["datasets"],
97 | deps["deepspeed"],
98 | deps["hf_transfer"],
99 | deps["huggingface-hub"],
100 | deps["liger_kernel"],
101 | deps["packaging"], # utilities from PyPA to e.g., compare versions
102 | deps["safetensors"],
103 | deps["sentencepiece"],
104 | # deps["transformers"],
105 | deps["trl"],
106 | ]
107 |
108 | setup(
109 | name="r1-v",
110 | version="0.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
111 | author="The r1-v team and the Hugging Face team (past and future)",
112 | description="R1-V",
113 | license="Apache",
114 | url="https://github.com/Deep-Agent/R1-V",
115 | package_dir={"": "src"},
116 | packages=find_packages("src"),
117 | zip_safe=False,
118 | extras_require=extras,
119 | python_requires=">=3.10.9",
120 | install_requires=install_requires,
121 | classifiers=[
122 | "Development Status :: 3 - Alpha",
123 | "Intended Audience :: Developers",
124 | "Intended Audience :: Education",
125 | "Intended Audience :: Science/Research",
126 | "License :: OSI Approved :: Apache Software License",
127 | "Operating System :: OS Independent",
128 | "Programming Language :: Python :: 3",
129 | "Programming Language :: Python :: 3.10",
130 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
131 | ],
132 | )
133 |
--------------------------------------------------------------------------------
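The _deps list in setup.py above is converted into a name-to-requirement lookup so install_requires and the extras can be assembled by bare package name. A standalone illustration of that comprehension on a shortened list (the three entries are copied from _deps):

import re

_deps = ["accelerate>=1.2.1", "deepspeed==0.15.4", "math-verify"]
# Same comprehension as setup.py: map each bare package name to its full requirement string.
deps = {
    b: a
    for a, b in (
        re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps
    )
}
print(deps["deepspeed"])    # -> deepspeed==0.15.4
print(deps["math-verify"])  # -> math-verify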
/Visual-ARFT/src/visual_arft/src/open_r1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/src/visual_arft/src/open_r1/__init__.py
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/src/open_r1/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom evaluation tasks for LightEval."""
16 |
17 | from lighteval.metrics.dynamic_metrics import (
18 | ExprExtractionConfig,
19 | LatexExtractionConfig,
20 | multilingual_extractive_match_metric,
21 | )
22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig
23 | from lighteval.tasks.requests import Doc
24 | from lighteval.utils.language import Language
25 |
26 |
27 | metric = multilingual_extractive_match_metric(
28 | language=Language.ENGLISH,
29 | fallback_mode="first_match",
30 | precision=5,
31 | gold_extraction_target=(LatexExtractionConfig(),),
32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
33 | aggregation_function=max,
34 | )
35 |
36 |
37 | def prompt_fn(line, task_name: str = None):
38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
39 | return Doc(
40 | task_name=task_name,
41 | query=line["problem"],
42 | choices=[line["solution"]],
43 | gold_index=0,
44 | )
45 |
46 |
47 | # Define tasks
48 | aime24 = LightevalTaskConfig(
49 | name="aime24",
50 | suite=["custom"],
51 | prompt_function=prompt_fn,
52 | hf_repo="HuggingFaceH4/aime_2024",
53 | hf_subset="default",
54 | hf_avail_splits=["train"],
55 | evaluation_splits=["train"],
56 | few_shots_split=None,
57 | few_shots_select=None,
58 | generation_size=32768,
59 | metric=[metric],
60 | version=1,
61 | )
62 | math_500 = LightevalTaskConfig(
63 | name="math_500",
64 | suite=["custom"],
65 | prompt_function=prompt_fn,
66 | hf_repo="HuggingFaceH4/MATH-500",
67 | hf_subset="default",
68 | hf_avail_splits=["test"],
69 | evaluation_splits=["test"],
70 | few_shots_split=None,
71 | few_shots_select=None,
72 | generation_size=32768,
73 | metric=[metric],
74 | version=1,
75 | )
76 |
77 | # Add tasks to the table
78 | TASKS_TABLE = []
79 | TASKS_TABLE.append(aime24)
80 | TASKS_TABLE.append(math_500)
81 |
82 | # MODULE LOGIC
83 | if __name__ == "__main__":
84 | print([t["name"] for t in TASKS_TABLE])
85 | print(len(TASKS_TABLE))
86 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/src/open_r1/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | from distilabel.llms import OpenAILLM
18 | from distilabel.pipeline import Pipeline
19 | from distilabel.steps.tasks import TextGeneration
20 |
21 |
22 | def build_distilabel_pipeline(
23 | model: str,
24 | base_url: str = "http://localhost:8000/v1",
25 | prompt_column: Optional[str] = None,
26 | temperature: Optional[float] = None,
27 | top_p: Optional[float] = None,
28 | max_new_tokens: int = 8192,
29 | num_generations: int = 1,
30 | ) -> Pipeline:
31 | generation_kwargs = {"max_new_tokens": max_new_tokens}
32 |
33 | if temperature is not None:
34 | generation_kwargs["temperature"] = temperature
35 |
36 | if top_p is not None:
37 | generation_kwargs["top_p"] = top_p
38 |
39 | with Pipeline().ray() as pipeline:
40 | TextGeneration(
41 | llm=OpenAILLM(
42 | base_url=base_url,
43 | api_key="something",
44 | model=model,
45 | # thinking can take some time...
46 | timeout=10 * 60,
47 | generation_kwargs=generation_kwargs,
48 | ),
49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
51 | num_generations=num_generations,
52 | )
53 |
54 | return pipeline
55 |
56 |
57 | if __name__ == "__main__":
58 | import argparse
59 |
60 | from datasets import load_dataset
61 |
62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
63 | parser.add_argument(
64 | "--hf-dataset",
65 | type=str,
66 | required=True,
67 | help="HuggingFace dataset to load",
68 | )
69 | parser.add_argument(
70 | "--hf-dataset-config",
71 | type=str,
72 | required=False,
73 | help="Dataset config to use",
74 | )
75 | parser.add_argument(
76 | "--hf-dataset-split",
77 | type=str,
78 | default="train",
79 | help="Dataset split to use",
80 | )
81 | parser.add_argument("--prompt-column", type=str, default="prompt")
82 | parser.add_argument(
83 | "--model",
84 | type=str,
85 | required=True,
86 | help="Model name to use for generation",
87 | )
88 | parser.add_argument(
89 | "--vllm-server-url",
90 | type=str,
91 | default="http://localhost:8000/v1",
92 | help="URL of the vLLM server",
93 | )
94 | parser.add_argument(
95 | "--temperature",
96 | type=float,
97 | help="Temperature for generation",
98 | )
99 | parser.add_argument(
100 | "--top-p",
101 | type=float,
102 | help="Top-p value for generation",
103 | )
104 | parser.add_argument(
105 | "--max-new-tokens",
106 | type=int,
107 | default=8192,
108 | help="Maximum number of new tokens to generate",
109 | )
110 | parser.add_argument(
111 | "--num-generations",
112 | type=int,
113 | default=1,
114 | help="Number of generations per problem",
115 | )
116 | parser.add_argument(
117 | "--hf-output-dataset",
118 | type=str,
119 | required=False,
120 | help="HuggingFace repo to push results to",
121 | )
122 | parser.add_argument(
123 | "--private",
124 | action="store_true",
125 | help="Whether to make the output dataset private when pushing to HF Hub",
126 | )
127 |
128 | args = parser.parse_args()
129 |
130 | print("\nRunning with arguments:")
131 | for arg, value in vars(args).items():
132 | print(f" {arg}: {value}")
133 | print()
134 |
135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
136 | dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
137 | print("Dataset loaded!")
138 |
139 | pipeline = build_distilabel_pipeline(
140 | model=args.model,
141 | base_url=args.vllm_server_url,
142 | prompt_column=args.prompt_column,
143 | temperature=args.temperature,
144 | top_p=args.top_p,
145 | max_new_tokens=args.max_new_tokens,
146 | num_generations=args.num_generations,
147 | )
148 |
149 | print("Running generation pipeline...")
150 | distiset = pipeline.run(dataset=dataset, use_cache=False)
151 | print("Generation pipeline finished!")
152 |
153 | if args.hf_output_dataset:
154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private)
156 | print("Dataset pushed!")
157 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/src/open_r1/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .grpo_trainer import Qwen2VLGRPOTrainer
2 | from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
3 | from .vllm_grpo_trainer_modified import Qwen2VLGRPOVLLMTrainerModified
4 | from .grpo_trainer_aid import Qwen2VLGRPOTrainer_AID
5 | from .grpo_trainer_visual_rft import Qwen2VLGRPOTrainer_Visual_RFT
6 | from .grpo_trainer_mp import Qwen2VLGRPOTrainer_MP
7 |
8 | __all__ = [
9 | "Qwen2VLGRPOTrainer",
10 | "Qwen2VLGRPOVLLMTrainer",
11 | "Qwen2VLGRPOVLLMTrainerModified",
12 | "Qwen2VLGRPOTrainer_AID",
13 | "Qwen2VLGRPOTrainer_Visual_RFT",
14 | "Qwen2VLGRPOTrainer_MP",
15 | ]
16 |
--------------------------------------------------------------------------------
/Visual-ARFT/src/visual_arft/temp_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/Visual-ARFT/src/visual_arft/temp_image.png
--------------------------------------------------------------------------------
/assets/case_cls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/case_cls.png
--------------------------------------------------------------------------------
/assets/case_lisa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/case_lisa.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/framework.png
--------------------------------------------------------------------------------
/assets/pokeymon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/pokeymon.jpg
--------------------------------------------------------------------------------
/assets/radar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/radar.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/assets/teaser.png
--------------------------------------------------------------------------------
/classification/val_data/fgvc_aircraft.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/fgvc_aircraft.pth
--------------------------------------------------------------------------------
/classification/val_data/fgvc_aircraft.txt:
--------------------------------------------------------------------------------
1 | 707-320
2 | 727-200
3 | 737-200
4 | 737-300
5 | 737-400
6 | 737-500
7 | 737-600
8 | 737-700
9 | 737-800
10 | 737-900
11 | 747-100
12 | 747-200
13 | 747-300
14 | 747-400
15 | 757-200
16 | 757-300
17 | 767-200
18 | 767-300
19 | 767-400
20 | 777-200
21 | 777-300
22 | A300B4
23 | A310
24 | A318
25 | A319
26 | A320
27 | A321
28 | A330-200
29 | A330-300
30 | A340-200
31 | A340-300
32 | A340-500
33 | A340-600
34 | A380
35 | ATR-42
36 | ATR-72
37 | An-12
38 | BAE 146-200
39 | BAE 146-300
40 | BAE-125
41 | Beechcraft 1900
42 | Boeing 717
43 | C-130
44 | C-47
45 | CRJ-200
46 | CRJ-700
47 | CRJ-900
48 | Cessna 172
49 | Cessna 208
50 | Cessna 525
51 | Cessna 560
52 | Challenger 600
53 | DC-10
54 | DC-3
55 | DC-6
56 | DC-8
57 | DC-9-30
58 | DH-82
59 | DHC-1
60 | DHC-6
61 | DHC-8-100
62 | DHC-8-300
63 | DR-400
64 | Dornier 328
65 | E-170
66 | E-190
67 | E-195
68 | EMB-120
69 | ERJ 135
70 | ERJ 145
71 | Embraer Legacy 600
72 | Eurofighter Typhoon
73 | F-16A/B
74 | F/A-18
75 | Falcon 2000
76 | Falcon 900
77 | Fokker 100
78 | Fokker 50
79 | Fokker 70
80 | Global Express
81 | Gulfstream IV
82 | Gulfstream V
83 | Hawk T1
84 | Il-76
85 | L-1011
86 | MD-11
87 | MD-80
88 | MD-87
89 | MD-90
90 | Metroliner
91 | Model B200
92 | PA-28
93 | SR-20
94 | Saab 2000
95 | Saab 340
96 | Spitfire
97 | Tornado
98 | Tu-134
99 | Tu-154
100 | Yak-42
--------------------------------------------------------------------------------
/classification/val_data/oxford_flowers.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/oxford_flowers.pth
--------------------------------------------------------------------------------
/classification/val_data/oxford_flowers.txt:
--------------------------------------------------------------------------------
1 | pink primrose
2 | hard-leaved pocket orchid
3 | canterbury bells
4 | sweet pea
5 | english marigold
6 | tiger lily
7 | moon orchid
8 | bird of paradise
9 | monkshood
10 | globe thistle
11 | snapdragon
12 | colts foot
13 | king protea
14 | spear thistle
15 | yellow iris
16 | globe-flower
17 | purple coneflower
18 | peruvian lily
19 | balloon flower
20 | giant white arum lily
21 | fire lily
22 | pincushion flower
23 | fritillary
24 | red ginger
25 | grape hyacinth
26 | corn poppy
27 | prince of wales feathers
28 | stemless gentian
29 | artichoke
30 | sweet william
31 | carnation
32 | garden phlox
33 | love in the mist
34 | mexican aster
35 | alpine sea holly
36 | ruby-lipped cattleya
37 | cape flower
38 | great masterwort
39 | siam tulip
40 | lenten rose
41 | barbeton daisy
42 | daffodil
43 | sword lily
44 | poinsettia
45 | bolero deep blue
46 | wallflower
47 | marigold
48 | buttercup
49 | oxeye daisy
50 | common dandelion
51 | petunia
52 | wild pansy
53 | primula
54 | sunflower
55 | pelargonium
56 | bishop of llandaff
57 | gaura
58 | geranium
59 | orange dahlia
60 | pink-yellow dahlia
61 | cautleya spicata
62 | japanese anemone
63 | black-eyed susan
64 | silverbush
65 | californian poppy
66 | osteospermum
67 | spring crocus
68 | bearded iris
69 | windflower
70 | tree poppy
71 | gazania
72 | azalea
73 | water lily
74 | rose
75 | thorn apple
76 | morning glory
77 | passion flower
78 | lotus
79 | toad lily
80 | anthurium
81 | frangipani
82 | clematis
83 | hibiscus
84 | columbine
85 | desert-rose
86 | tree mallow
87 | magnolia
88 | cyclamen
89 | watercress
90 | canna lily
91 | hippeastrum
92 | bee balm
93 | ball moss
94 | foxglove
95 | bougainvillea
96 | camellia
97 | mallow
98 | mexican petunia
99 | bromelia
100 | blanket flower
101 | trumpet creeper
102 | blackberry lily
--------------------------------------------------------------------------------
/classification/val_data/pets.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/pets.pth
--------------------------------------------------------------------------------
/classification/val_data/pets.txt:
--------------------------------------------------------------------------------
1 | abyssinian
2 | american_bulldog
3 | american_pit_bull_terrier
4 | basset_hound
5 | beagle
6 | bengal
7 | birman
8 | bombay
9 | boxer
10 | british_shorthair
11 | chihuahua
12 | egyptian_mau
13 | english_cocker_spaniel
14 | english_setter
15 | german_shorthaired
16 | great_pyrenees
17 | havanese
18 | japanese_chin
19 | keeshond
20 | leonberger
21 | maine_coon
22 | miniature_pinscher
23 | newfoundland
24 | persian
25 | pomeranian
26 | pug
27 | ragdoll
28 | russian_blue
29 | saint_bernard
30 | samoyed
31 | scottish_terrier
32 | shiba_inu
33 | siamese
34 | sphynx
35 | staffordshire_bull_terrier
36 | wheaten_terrier
37 | yorkshire_terrier
--------------------------------------------------------------------------------
/classification/val_data/stanford_cars.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/classification/val_data/stanford_cars.pth
--------------------------------------------------------------------------------
/classification/val_data/stanford_cars.txt:
--------------------------------------------------------------------------------
1 | 2000 AM General Hummer SUV
2 | 2012 Acura RL Sedan
3 | 2012 Acura TL Sedan
4 | 2008 Acura TL Type-S
5 | 2012 Acura TSX Sedan
6 | 2001 Acura Integra Type R
7 | 2012 Acura ZDX Hatchback
8 | 2012 Aston Martin V8 Vantage Convertible
9 | 2012 Aston Martin V8 Vantage Coupe
10 | 2012 Aston Martin Virage Convertible
11 | 2012 Aston Martin Virage Coupe
12 | 2008 Audi RS 4 Convertible
13 | 2012 Audi A5 Coupe
14 | 2012 Audi TTS Coupe
15 | 2012 Audi R8 Coupe
16 | 1994 Audi V8 Sedan
17 | 1994 Audi 100 Sedan
18 | 1994 Audi 100 Wagon
19 | 2011 Audi TT Hatchback
20 | 2011 Audi S6 Sedan
21 | 2012 Audi S5 Convertible
22 | 2012 Audi S5 Coupe
23 | 2012 Audi S4 Sedan
24 | 2007 Audi S4 Sedan
25 | 2012 Audi TT RS Coupe
26 | 2012 BMW ActiveHybrid 5 Sedan
27 | 2012 BMW 1 Series Convertible
28 | 2012 BMW 1 Series Coupe
29 | 2012 BMW 3 Series Sedan
30 | 2012 BMW 3 Series Wagon
31 | 2007 BMW 6 Series Convertible
32 | 2007 BMW X5 SUV
33 | 2012 BMW X6 SUV
34 | 2012 BMW M3 Coupe
35 | 2010 BMW M5 Sedan
36 | 2010 BMW M6 Convertible
37 | 2012 BMW X3 SUV
38 | 2012 BMW Z4 Convertible
39 | 2012 Bentley Continental Supersports Conv. Convertible
40 | 2009 Bentley Arnage Sedan
41 | 2011 Bentley Mulsanne Sedan
42 | 2012 Bentley Continental GT Coupe
43 | 2007 Bentley Continental GT Coupe
44 | 2007 Bentley Continental Flying Spur Sedan
45 | 2009 Bugatti Veyron 16.4 Convertible
46 | 2009 Bugatti Veyron 16.4 Coupe
47 | 2012 Buick Regal GS
48 | 2007 Buick Rainier SUV
49 | 2012 Buick Verano Sedan
50 | 2012 Buick Enclave SUV
51 | 2012 Cadillac CTS-V Sedan
52 | 2012 Cadillac SRX SUV
53 | 2007 Cadillac Escalade EXT Crew Cab
54 | 2012 Chevrolet Silverado 1500 Hybrid Crew Cab
55 | 2012 Chevrolet Corvette Convertible
56 | 2012 Chevrolet Corvette ZR1
57 | 2007 Chevrolet Corvette Ron Fellows Edition Z06
58 | 2012 Chevrolet Traverse SUV
59 | 2012 Chevrolet Camaro Convertible
60 | 2010 Chevrolet HHR SS
61 | 2007 Chevrolet Impala Sedan
62 | 2012 Chevrolet Tahoe Hybrid SUV
63 | 2012 Chevrolet Sonic Sedan
64 | 2007 Chevrolet Express Cargo Van
65 | 2012 Chevrolet Avalanche Crew Cab
66 | 2010 Chevrolet Cobalt SS
67 | 2010 Chevrolet Malibu Hybrid Sedan
68 | 2009 Chevrolet TrailBlazer SS
69 | 2012 Chevrolet Silverado 2500HD Regular Cab
70 | 2007 Chevrolet Silverado 1500 Classic Extended Cab
71 | 2007 Chevrolet Express Van
72 | 2007 Chevrolet Monte Carlo Coupe
73 | 2007 Chevrolet Malibu Sedan
74 | 2012 Chevrolet Silverado 1500 Extended Cab
75 | 2012 Chevrolet Silverado 1500 Regular Cab
76 | 2009 Chrysler Aspen SUV
77 | 2010 Chrysler Sebring Convertible
78 | 2012 Chrysler Town and Country Minivan
79 | 2010 Chrysler 300 SRT-8
80 | 2008 Chrysler Crossfire Convertible
81 | 2008 Chrysler PT Cruiser Convertible
82 | 2002 Daewoo Nubira Wagon
83 | 2012 Dodge Caliber Wagon
84 | 2007 Dodge Caliber Wagon
85 | 1997 Dodge Caravan Minivan
86 | 2010 Dodge Ram Pickup 3500 Crew Cab
87 | 2009 Dodge Ram Pickup 3500 Quad Cab
88 | 2009 Dodge Sprinter Cargo Van
89 | 2012 Dodge Journey SUV
90 | 2010 Dodge Dakota Crew Cab
91 | 2007 Dodge Dakota Club Cab
92 | 2008 Dodge Magnum Wagon
93 | 2011 Dodge Challenger SRT8
94 | 2012 Dodge Durango SUV
95 | 2007 Dodge Durango SUV
96 | 2012 Dodge Charger Sedan
97 | 2009 Dodge Charger SRT-8
98 | 1998 Eagle Talon Hatchback
99 | 2012 FIAT 500 Abarth
100 | 2012 FIAT 500 Convertible
101 | 2012 Ferrari FF Coupe
102 | 2012 Ferrari California Convertible
103 | 2012 Ferrari 458 Italia Convertible
104 | 2012 Ferrari 458 Italia Coupe
105 | 2012 Fisker Karma Sedan
106 | 2012 Ford F-450 Super Duty Crew Cab
107 | 2007 Ford Mustang Convertible
108 | 2007 Ford Freestar Minivan
109 | 2009 Ford Expedition EL SUV
110 | 2012 Ford Edge SUV
111 | 2011 Ford Ranger SuperCab
112 | 2006 Ford GT Coupe
113 | 2012 Ford F-150 Regular Cab
114 | 2007 Ford F-150 Regular Cab
115 | 2007 Ford Focus Sedan
116 | 2012 Ford E-Series Wagon Van
117 | 2012 Ford Fiesta Sedan
118 | 2012 GMC Terrain SUV
119 | 2012 GMC Savana Van
120 | 2012 GMC Yukon Hybrid SUV
121 | 2012 GMC Acadia SUV
122 | 2012 GMC Canyon Extended Cab
123 | 1993 Geo Metro Convertible
124 | 2010 HUMMER H3T Crew Cab
125 | 2009 HUMMER H2 SUT Crew Cab
126 | 2012 Honda Odyssey Minivan
127 | 2007 Honda Odyssey Minivan
128 | 2012 Honda Accord Coupe
129 | 2012 Honda Accord Sedan
130 | 2012 Hyundai Veloster Hatchback
131 | 2012 Hyundai Santa Fe SUV
132 | 2012 Hyundai Tucson SUV
133 | 2012 Hyundai Veracruz SUV
134 | 2012 Hyundai Sonata Hybrid Sedan
135 | 2007 Hyundai Elantra Sedan
136 | 2012 Hyundai Accent Sedan
137 | 2012 Hyundai Genesis Sedan
138 | 2012 Hyundai Sonata Sedan
139 | 2012 Hyundai Elantra Touring Hatchback
140 | 2012 Hyundai Azera Sedan
141 | 2012 Infiniti G Coupe IPL
142 | 2011 Infiniti QX56 SUV
143 | 2008 Isuzu Ascender SUV
144 | 2012 Jaguar XK XKR
145 | 2012 Jeep Patriot SUV
146 | 2012 Jeep Wrangler SUV
147 | 2012 Jeep Liberty SUV
148 | 2012 Jeep Grand Cherokee SUV
149 | 2012 Jeep Compass SUV
150 | 2008 Lamborghini Reventon Coupe
151 | 2012 Lamborghini Aventador Coupe
152 | 2012 Lamborghini Gallardo LP 570-4 Superleggera
153 | 2001 Lamborghini Diablo Coupe
154 | 2012 Land Rover Range Rover SUV
155 | 2012 Land Rover LR2 SUV
156 | 2011 Lincoln Town Car Sedan
157 | 2012 MINI Cooper Roadster Convertible
158 | 2012 Maybach Landaulet Convertible
159 | 2011 Mazda Tribute SUV
160 | 2012 McLaren MP4-12C Coupe
161 | 1993 Mercedes-Benz 300-Class Convertible
162 | 2012 Mercedes-Benz C-Class Sedan
163 | 2009 Mercedes-Benz SL-Class Coupe
164 | 2012 Mercedes-Benz E-Class Sedan
165 | 2012 Mercedes-Benz S-Class Sedan
166 | 2012 Mercedes-Benz Sprinter Van
167 | 2012 Mitsubishi Lancer Sedan
168 | 2012 Nissan Leaf Hatchback
169 | 2012 Nissan NV Passenger Van
170 | 2012 Nissan Juke Hatchback
171 | 1998 Nissan 240SX Coupe
172 | 1999 Plymouth Neon Coupe
173 | 2012 Porsche Panamera Sedan
174 | 2012 Ram C/V Cargo Van Minivan
175 | 2012 Rolls-Royce Phantom Drophead Coupe Convertible
176 | 2012 Rolls-Royce Ghost Sedan
177 | 2012 Rolls-Royce Phantom Sedan
178 | 2012 Scion xD Hatchback
179 | 2009 Spyker C8 Convertible
180 | 2009 Spyker C8 Coupe
181 | 2007 Suzuki Aerio Sedan
182 | 2012 Suzuki Kizashi Sedan
183 | 2012 Suzuki SX4 Hatchback
184 | 2012 Suzuki SX4 Sedan
185 | 2012 Tesla Model S Sedan
186 | 2012 Toyota Sequoia SUV
187 | 2012 Toyota Camry Sedan
188 | 2012 Toyota Corolla Sedan
189 | 2012 Toyota 4Runner SUV
190 | 2012 Volkswagen Golf Hatchback
191 | 1991 Volkswagen Golf Hatchback
192 | 2012 Volkswagen Beetle Hatchback
193 | 2012 Volvo C30 Hatchback
194 | 1993 Volvo 240 Sedan
195 | 2007 Volvo XC90 SUV
196 | 2012 smart fortwo Convertible
--------------------------------------------------------------------------------
/coco_evaluation/coco_evaluation.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import copy
3 | import io
4 | import itertools
5 | import json
6 | import logging
7 | import os
8 | import warnings
9 |
10 | import numpy as np
11 | from pycocotools.cocoeval import COCOeval
12 | from tabulate import tabulate
13 |
14 | import cv2
15 | import torch
16 | from pycocotools.coco import COCO
17 |
18 | logger = logging.getLogger("coco")
19 |
20 |
21 | def xyxy2xywh(bbox):
22 | """
23 | change bbox to coco format
24 | :param bbox: [x1, y1, x2, y2]
25 | :return: [x, y, w, h]
26 | """
27 | return [
28 | bbox[0],
29 | bbox[1],
30 | bbox[2] - bbox[0],
31 | bbox[3] - bbox[1],
32 | ]
33 |
34 |
35 | class CocoDetectionEvaluator:
36 | def __init__(self, ann_path):
37 |
38 | self.coco_api = COCO(ann_path)
39 | self.cat_ids = sorted(self.coco_api.getCatIds())
40 | self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
41 | self.cats = self.coco_api.loadCats(self.cat_ids)
42 | self.class_names = [cat["name"] for cat in self.cats]
43 | self.img_ids = sorted(self.coco_api.imgs.keys())
44 | img_info = self.coco_api.loadImgs(self.img_ids)
45 |
46 | # assert hasattr(dataset, "coco_api")
47 | # self.class_names = dataset.class_names
48 | # self.coco_api = dataset.coco_api
49 | # self.cat_ids = dataset.cat_ids
50 | self.metric_names = ["mAP", "AP_50", "AP_75", "AP_small", "AP_m", "AP_l"]
51 |
52 | def results2json(self, results):
53 | """
54 | results: {image_id: {label: [bboxes...] } }
55 | :return coco json format: {image_id:
56 | category_id:
57 | bbox:
58 | score: }
59 | """
60 | json_results = []
61 | for image_id, dets in results.items():
62 | for label, bboxes in dets.items():
63 | category_id = self.cat_ids[label]
64 | for bbox in bboxes:
65 | score = float(bbox[4])
66 | detection = dict(
67 | image_id=int(image_id),
68 | category_id=int(category_id),
69 | bbox=xyxy2xywh(bbox),
70 | score=score,
71 | )
72 | json_results.append(detection)
73 | return json_results
74 |
75 | def evaluate(self, results, save_dir, rank=-1):
76 | ### Original
77 | # results_json = self.results2json(results)
78 | ### lzy modified
79 | with open(results, 'r') as json_file:
80 | results_json = json.load(json_file)
81 | if len(results_json) == 0:
82 | warnings.warn(
83 | "Detection result is empty! Please check whether "
84 | "training set is too small (need to increase val_interval "
85 | "in config and train more epochs). Or check annotation "
86 | "correctness."
87 | )
88 | empty_eval_results = {}
89 | for key in self.metric_names:
90 | empty_eval_results[key] = 0
91 | return empty_eval_results
92 | json_path = os.path.join(save_dir, "results{}.json".format(rank))
93 | json.dump(results_json, open(json_path, "w"))
94 | coco_dets = self.coco_api.loadRes(json_path)
95 | coco_eval = COCOeval(
96 | copy.deepcopy(self.coco_api), copy.deepcopy(coco_dets), "bbox"
97 | )
98 | coco_eval.evaluate()
99 | coco_eval.accumulate()
100 |
101 | # use logger to log coco eval results
102 | redirect_string = io.StringIO()
103 | with contextlib.redirect_stdout(redirect_string):
104 | coco_eval.summarize()
105 | logger.info("\n" + redirect_string.getvalue())
106 |
107 | # print per class AP
108 | headers = ["class", "AP50", "mAP"]
109 | columns = 6
110 | per_class_ap50s = []
111 | per_class_maps = []
112 | precisions = coco_eval.eval["precision"]
113 | # dimension of precisions: [TxRxKxAxM]
114 | # precision has dims (iou, recall, cls, area range, max dets)
115 | assert len(self.class_names) == precisions.shape[2]
116 |
117 | ### lzy modified
118 | per_class_results = []
119 | for idx, name in enumerate(self.class_names):
120 | # area range index 0: all area ranges
121 | # max dets index -1: typically 100 per image
122 | precision_50 = precisions[0, :, idx, 0, -1]
123 | precision_50 = precision_50[precision_50 > -1]
124 | ap50 = np.mean(precision_50) if precision_50.size else float("nan")
125 | per_class_ap50s.append(float(ap50 * 100))
126 |
127 | precision = precisions[:, :, idx, 0, -1]
128 | precision = precision[precision > -1]
129 | ap = np.mean(precision) if precision.size else float("nan")
130 | per_class_maps.append(float(ap * 100))
131 | per_class_results.append({name:float(ap * 100)})
132 |
133 | num_cols = min(columns, len(self.class_names) * len(headers))
134 | flatten_results = []
135 | for name, ap50, mAP in zip(self.class_names, per_class_ap50s, per_class_maps):
136 | flatten_results += [name, ap50, mAP]
137 |
138 | row_pair = itertools.zip_longest(
139 | *[flatten_results[i::num_cols] for i in range(num_cols)]
140 | )
141 | table_headers = headers * (num_cols // len(headers))
142 | table = tabulate(
143 | row_pair,
144 | tablefmt="pipe",
145 | floatfmt=".1f",
146 | headers=table_headers,
147 | numalign="left",
148 | )
149 | logger.info("\n" + table)
150 |
151 | aps = coco_eval.stats[:6]
152 | eval_results = {}
153 | for k, v in zip(self.metric_names, aps):
154 | eval_results[k] = v
155 | return eval_results, per_class_results
156 |
--------------------------------------------------------------------------------
/coco_evaluation/evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "d8fcd1a9-d7c5-4049-bfff-27dbc49b029d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from coco_evaluation import CocoDetectionEvaluator"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "aa90e100-072b-4f55-999c-a8e03d56fe87",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "evaluator = CocoDetectionEvaluator('./data/coco/annotations/instances_val2017.json')"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "585276a5-cfb2-4e2b-a9de-30a46cbf30c5",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "results, per_class_results = evaluator.evaluate('./prediction_Qwen2_vl_2B_GRPO_coco.json', './results')"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "92d76ac4-44aa-43b5-8123-d871453e6750",
37 | "metadata": {
38 | "scrolled": true
39 | },
40 | "outputs": [],
41 | "source": [
42 | "### mAP and AP for all categories\n",
43 | "results, per_class_results"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "014151c7-5a90-4d46-9782-7195cb7147ed",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "### mAP and AP for selected categories\n",
54 | "selected_cate = ['bus', 'train', 'fire hydrant', 'stop sign', 'cat', 'dog', 'bed', 'toilet']\n",
55 | "# selected_cate = ['mouse', 'fork', 'hot dog', 'cat', 'airplane', 'suitcase', 'parking meter', 'sandwich', 'train', 'hair drier', 'toilet', 'toaster', 'snowboard', 'frisbee', 'bear']\n",
56 | "results, per_class_results\n",
57 | "AP_sum = 0\n",
58 | "for item in per_class_results:\n",
59 | " for key, value in item.items():\n",
60 | " if key in selected_cate:\n",
61 | " print(f\"Key: {key}, Value: {value}\")\n",
62 | " AP_sum += value\n",
63 | "print(\"mAP for selected categories: \", (AP_sum)/(len(selected_cate)))"
64 | ]
65 | }
66 | ],
67 | "metadata": {
68 | "kernelspec": {
69 | "display_name": "Python 3 (ipykernel)",
70 | "language": "python",
71 | "name": "python3"
72 | },
73 | "language_info": {
74 | "codemirror_mode": {
75 | "name": "ipython",
76 | "version": 3
77 | },
78 | "file_extension": ".py",
79 | "mimetype": "text/x-python",
80 | "name": "python",
81 | "nbconvert_exporter": "python",
82 | "pygments_lexer": "ipython3",
83 | "version": "3.10.13"
84 | }
85 | },
86 | "nbformat": 4,
87 | "nbformat_minor": 5
88 | }
89 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dataset/build_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "2f84d496-a970-452d-99a1-2f67718dff9b",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import time\n",
11 | "from datasets import DatasetDict, Dataset\n",
12 | "from PIL import Image\n",
13 | "import json"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "id": "5a8ba1f2-d465-4b78-a48a-744f591a14ab",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "\"\"\"\n",
24 | "turn your json to DatasetDict\n",
25 | "\"\"\"\n",
26 | "def json_to_dataset(json_file_path):\n",
27 | " # read json file\n",
28 | " with open(json_file_path, 'r') as f:\n",
29 | " data = json.load(f)\n",
30 | "\n",
31 | " image_paths = [item['image_path'] for item in data]\n",
32 | " problems = [item['problem'] for item in data]\n",
33 | " solutions = [item['solution'] for item in data]\n",
34 | "\n",
35 | " images = [Image.open(image_path).convert('RGBA') for image_path in image_paths]\n",
36 | "\n",
37 | " dataset_dict = {\n",
38 | " 'image': images,\n",
39 | " 'problem': problems,\n",
40 | " 'solution': solutions\n",
41 | " }\n",
42 | "\n",
43 | " dataset = Dataset.from_dict(dataset_dict)\n",
44 | " dataset_dict = DatasetDict({\n",
45 | " 'train': dataset\n",
46 | " })\n",
47 | " return dataset_dict\n",
48 | "\n",
49 | "\n",
50 | "time1 = time.asctime()\n",
51 | "print(time1)\n",
52 | "### Your dataset in JSON file format consists of three parts: image, problem and solution\n",
53 | "dataset_dict = json_to_dataset('your_dataset_json_file.json')\n",
54 | "time2 = time.asctime()\n",
55 | "print(time2)"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "id": "0e2a20b4-131f-49fa-baee-4c2675479de3",
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "\"\"\"\n",
66 | "save to your local disk\n",
67 | "\"\"\"\n",
68 | "def save_dataset(dataset_dict, save_path):\n",
69 | " # save DatasetDict to your disk\n",
70 | " dataset_dict.save_to_disk(save_path)\n",
71 | "\n",
72 | "save_path = './share_data/your_local_dataset'\n",
73 | "save_dataset(dataset_dict, save_path)"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "id": "22aa1938-3b68-4df3-896d-74c8b7c854c3",
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "\"\"\"\n",
84 | "read from your local disk\n",
85 | "\"\"\"\n",
86 | "def load_dataset(save_path):\n",
87 | " # load DatasetDict\n",
88 | " return DatasetDict.load_from_disk(save_path)\n",
89 | "\n",
90 | "test_dataset_dict = load_dataset('./share_data/your_local_dataset')"
91 | ]
92 | }
93 | ],
94 | "metadata": {
95 | "kernelspec": {
96 | "display_name": "Python 3 (ipykernel)",
97 | "language": "python",
98 | "name": "python3"
99 | },
100 | "language_info": {
101 | "codemirror_mode": {
102 | "name": "ipython",
103 | "version": 3
104 | },
105 | "file_extension": ".py",
106 | "mimetype": "text/x-python",
107 | "name": "python",
108 | "nbconvert_exporter": "python",
109 | "pygments_lexer": "ipython3",
110 | "version": "3.10.13"
111 | }
112 | },
113 | "nbformat": 4,
114 | "nbformat_minor": 5
115 | }
116 |
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
1 | ## Inference on LISA Dataset
2 | We've uploaded the model trained with 239 samples from the LISA dataset (🤗 Huggingface). You can use the following code for inference to test the model's **reasoning grounding** capability (or use `lisa_demo.ipynb`).
3 | ```python
4 | import torch
5 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
6 | from qwen_vl_utils import process_vision_info
7 | import json
8 | import os
9 | from PIL import Image
10 | import logging
11 | from tqdm import tqdm
12 | import re
13 | import math
14 |
15 | logging.basicConfig(level=logging.INFO)
16 | torch.manual_seed(1234)
17 | img2description = dict()
18 |
19 | SYSTEM_PROMPT = (
20 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
21 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
22 | "process and answer are enclosed within and tags, respectively, i.e., "
23 | " reasoning process here answer here "
24 | )
25 |
26 | model = Qwen2VLForConditionalGeneration.from_pretrained(
27 | "Zery/Qwen2-VL-7B_visual_rft_lisa_IoU_reward", device_map="auto"
28 | ).eval()
29 |
30 | processor = AutoProcessor.from_pretrained("Zery/Qwen2-VL-7B_visual_rft_lisa_IoU_reward")
31 |
32 | def prepare_inputs(img_path, instruction):
33 | messages = [
34 | {"role": "system", "content": SYSTEM_PROMPT},
35 | {
36 | "role": "user",
37 | "content": [
38 | {"type": "image", "image": img_path},
39 | {"type": "text", "text": f"Output the bounding box in the image corresponding to the instruction: {instruction}. Output the thinking process in and your grouding box. Following \" thinking process \n(x1,y1),(x2,y2))\" format."}
40 | ]
41 | }
42 | ]
43 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
44 | image_inputs, _ = process_vision_info(messages)
45 | inputs = processor(
46 | text=[text],
47 | images=image_inputs,
48 | padding=True,
49 | return_tensors="pt",
50 | )
51 | return inputs.to("cuda")
52 |
53 | image_path = "assets/pokeymon.jpg"
54 | inputs = prepare_inputs(image_path, "the pokeymon that can perform Thunderbolt. Output the thinking process in as much detail as possible")
55 |
56 | with torch.no_grad():
57 | generated_ids = model.generate(**inputs, max_new_tokens=128)
58 | response = processor.batch_decode(
59 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
60 | )[0]
61 | print(response)
62 | ```
63 |
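The predicted box in the answer uses coordinates normalized to the 0–1000 range (this is how `lisa_evaluation/Qwen2_VL_lisa_infere.py` interprets them). Below is a minimal sketch for mapping a response back to pixel coordinates; the helper name `parse_box` is ours:

```python
import re

def parse_box(response, width, height):
    """Parse '(x1,y1),(x2,y2)' from the model response and rescale the
    0-1000 normalized coordinates to pixel coordinates."""
    pattern = r"\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)"
    matches = re.findall(pattern, response)
    if not matches:
        return None
    x1, y1, x2, y2 = map(int, matches[0])
    return [x1 / 1000 * width, y1 / 1000 * height,
            x2 / 1000 * width, y2 / 1000 * height]

# Example usage with the inference snippet above:
# box = parse_box(response, *Image.open(image_path).size)
```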
--------------------------------------------------------------------------------
/lisa_evaluation/Qwen2_VL_lisa_infere.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
3 | from qwen_vl_utils import process_vision_info
4 | import json
5 | import os
6 | from PIL import Image
7 | import logging
8 | from tqdm import tqdm
9 | import re
10 | # from process_utils import pred_2_point, extract_bbox
11 | import math
12 |
13 | logging.basicConfig(level=logging.INFO)
14 | torch.manual_seed(1234)
15 | img2description = dict()
16 |
17 | SYSTEM_PROMPT = (
18 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
19 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
20 | "process and answer are enclosed within and tags, respectively, i.e., "
21 | " reasoning process here answer here "
22 | )
23 |
24 | # Load Qwen2-VL-2B model and processor
25 | model = Qwen2VLForConditionalGeneration.from_pretrained(
26 | "/path/to/your/checkpoint-498", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2"
27 | ).eval()
28 |
29 | processor = AutoProcessor.from_pretrained("/path/to/your/checkpoint-498")
30 |
31 | logging.info("Model and processor loaded successfully")
32 |
33 | def process_image(image_path):
34 | if not os.path.exists(image_path):
35 | raise FileNotFoundError(f"Image not found: {image_path}")
36 | return Image.open(image_path)
37 |
38 | def prepare_inputs(img_path, instruction):
39 | messages = [
40 | {"role": "system", "content": SYSTEM_PROMPT},
41 | {
42 | "role": "user",
43 | "content": [
44 | {"type": "image", "image": img_path},
45 | {"type": "text", "text": f"Output the bounding box in the image corresponding to the instruction: {instruction}. Output the thinking process in and your grouding box. Following \" thinking process \n(x1,y1),(x2,y2))\" format."}
46 | ]
47 | }
48 | ]
49 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
50 | image_inputs, _ = process_vision_info(messages)
51 | inputs = processor(
52 | text=[text],
53 | images=image_inputs,
54 | padding=True,
55 | return_tensors="pt",
56 | )
57 | return inputs.to("cuda")
58 |
59 | def extract_bbox(response):
60 | try:
61 | match = re.search(r"\[(\d+),(\d+),(\d+),(\d+)\]", response)
62 | if match:
63 | return [int(match.group(i)) for i in range(1, 5)]
64 | else:
65 | raise ValueError("Invalid response format")
66 | except Exception as e:
67 | logging.error(f"Error extracting bbox: {e}")
68 | return None
69 |
70 | def compute_iou(boxA, boxB):
71 | """
72 | Compute IoU (Intersection over Union).
73 | boxA, boxB format: [x1, y1, x2, y2]
74 | """
75 | xA = max(boxA[0], boxB[0])
76 | yA = max(boxA[1], boxB[1])
77 | xB = min(boxA[2], boxB[2])
78 | yB = min(boxA[3], boxB[3])
79 |
80 | interArea = max(0, xB - xA) * max(0, yB - yA)
81 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
82 | boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
83 |
84 | iou = interArea / float(boxAArea + boxBArea - interArea + 1e-6)
85 | return iou
86 |
87 | def evaluate_model(tasks):
88 | results = []
89 | box_res =[]
90 | for task in tasks:
91 | logging.info(f"Processing task: {task}")
92 | ious = []
93 | screenspot_data = json.load(open(f"path/to/lisa_{task}.json", 'r'))
94 | data_per_gpu = math.ceil(len(screenspot_data) / int(os.environ['SPLIT_NUM']))
95 | start_idx = int(os.environ['SPLIT']) * data_per_gpu
96 | end_idx = min(start_idx + data_per_gpu, len(screenspot_data))
97 | screenspot_data = screenspot_data[start_idx:end_idx]
98 |
99 | for item in tqdm(screenspot_data):
100 | img_path = item['image_path']
101 | try:
102 | image = process_image(img_path)
103 | w, h = image.size
104 | instruction = item["instruction"][0]
105 | bbox = item["boxes"][0]
106 | bbox = [
107 | bbox[0] / image.size[0],
108 | bbox[1] / image.size[1],
109 | bbox[2] / image.size[0],
110 | bbox[3] / image.size[1],
111 | ]
112 | inputs = prepare_inputs(img_path, instruction)
113 |
114 | with torch.no_grad():
115 | generated_ids = model.generate(**inputs, max_new_tokens=128)
116 | response = processor.batch_decode(
117 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
118 | )[0]
119 | print(response)
120 | pattern = r"\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)"
121 | matches = re.findall(pattern, response)
122 | x1, y1, x2, y2 = map(int, matches[0])
123 | pred_bbox = [int(x1) / 1000, int(y1) / 1000, int(x2) / 1000, int(y2) / 1000]
124 |
125 | iou = compute_iou(pred_bbox, bbox)
126 | box_res.append(
127 | {
128 | "image_pth": item['image_path'],
129 | "pred_bbox": pred_bbox,
130 | "thinking_process": response
131 | }
132 | )
133 | ious.append(iou)
134 | except Exception as e:
135 | ious.append(0)
136 | json.dump(box_res, open(f"tmp/resbox_{os.environ['SPLIT']}_r1_w_think_7b.json", 'w'), indent=4)
137 | json.dump(ious, open(f"tmp/res_{os.environ['SPLIT']}.json", 'w'), indent=4)
138 | return
139 |
140 | if __name__ == "__main__":
141 | import argparse
142 |
143 | parser = argparse.ArgumentParser()
144 | parser.add_argument('--task', type=str, required=True)
145 | args = parser.parse_args()
146 |
147 | if args.task == "all":
148 | tasks = ["val", "test"]
149 | else:
150 | tasks = [args.task]
151 |
152 | results = evaluate_model(tasks)
153 |
--------------------------------------------------------------------------------
/lisa_evaluation/Qwen2_VL_lisa_infere.sh:
--------------------------------------------------------------------------------
1 | TASKS=("test" "val")
2 | # Adjust to your gpu num
3 | GPU_IDS=(0 1 2 3 4 5 6 7)
4 | SPLIT_NUM=8
5 |
6 | for task in "${TASKS[@]}"; do
7 | echo "Starting inference for task: $task"
8 |
9 | # Iterate over GPUs and SPLIT indices
10 | for i in "${!GPU_IDS[@]}"; do
11 | GPU_ID=${GPU_IDS[$i]}
12 | SPLIT=$i
13 | echo "Launching task=$task on GPU=$GPU_ID with SPLIT=$SPLIT"
14 | SPLIT=$SPLIT SPLIT_NUM=$SPLIT_NUM python Qwen2_VL_lisa_infere.py \
15 | --task $task &
16 | sleep 1
17 | done
18 | wait
19 | echo "Merging results for task: $task"
20 | SPLIT_NUM=$SPLIT_NUM python merge_eval.py >> res.txt
21 | done
22 |
23 | echo "All tasks completed!"
24 |
--------------------------------------------------------------------------------
/lisa_evaluation/README.md:
--------------------------------------------------------------------------------
1 | ## ViRFT for reasoning grounding
2 |
3 | ## Training
4 | 1. Download the [LISA dataset](https://github.com/dvlab-research/LISA).
5 | 2. Use `gen_box_ann.py` to generate boxes from the masks (the resulting record format is sketched after the code block below).
6 | 3. Use `gen_sft.py` to generate SFT/Visual-RFT training annotations.
7 | 4. Use `src/scripts/2B_lisa_grounding.sh` to train the model, with the annotation path changed to the annotations generated in step 3.
8 |
9 | After training the model, replace the model path in `Qwen2_VL_lisa_infere.py` with your own checkpoint.
10 |
11 | ```python
12 | # Load Qwen2-VL-2B model and processor
13 | model = Qwen2VLForConditionalGeneration.from_pretrained(
14 | "/path/to/your/checkpoint-498", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2"
15 | ).eval()
16 |
17 | processor = AutoProcessor.from_pretrained("/path/to/your/checkpoint-498")
18 | ```
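For reference, `gen_box_ann.py` from step 2 writes one record per training image in roughly the shape sketched below; `gen_sft.py` and `Qwen2_VL_lisa_infere.py` consume the same fields. All values here are placeholders:

```python
# Sketch of a single lisa_train.json record (placeholder values).
record = {
    "image_path": "/path/to/LISA-main/data/train/example.jpg",
    # list of referring expressions taken from the LISA "text" field
    "instruction": ["the object the query refers to"],
    # (x_min, y_min, x_max, y_max) in absolute pixel coordinates
    "boxes": [(10.0, 20.0, 200.0, 180.0)],
}
```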
19 |
20 | To compute gIoU, follow the process below.
21 | 1. Use `box2mask.py` to extract masks with [SAM](https://github.com/facebookresearch/segment-anything).
22 | 2. Use `mask_iou.py` to compute the mask IoU (a sketch of the gIoU/cIoU computation follows the commands below).
23 |
24 | ```shell
25 | cd lisa_evaluation
26 | bash Qwen2_VL_lisa_infere.sh
27 | ```
28 |
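`mask_iou.py` reports two numbers: gIoU, the mean of the per-image mask IoUs, and cIoU, the pooled intersection divided by the pooled union over the whole split. A minimal sketch of the same computation (the helper name `giou_ciou` is ours):

```python
import numpy as np

def giou_ciou(pred_masks, gt_masks):
    """gIoU: mean of per-image IoUs; cIoU: pooled intersection / pooled union.
    Each pair of masks must be binary arrays of the same shape."""
    ious, total_inter, total_union = [], 0, 0
    for pred, gt in zip(pred_masks, gt_masks):
        inter = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        ious.append(inter / (union + 1e-6))
        total_inter += inter
        total_union += union
    giou = float(np.mean(ious)) if ious else 0.0
    ciou = total_inter / (total_union + 1e-6)
    return giou, ciou
```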
--------------------------------------------------------------------------------
/lisa_evaluation/box2mask.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import numpy as np
5 | import cv2
6 | from pycocotools import mask as maskUtils
7 | from segment_anything import sam_model_registry, SamPredictor
8 |
9 | def load_json(json_path):
10 | with open(json_path, 'r') as f:
11 | data = json.load(f)
12 | return data
13 |
14 | def load_image(image_path):
15 | image = cv2.imread(image_path)
16 | if image is None:
17 | raise FileNotFoundError(f"Image not found at {image_path}")
18 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
19 | return image
20 |
21 | def scale_bbox(bbox, width, height):
22 | x1, y1, x2, y2 = bbox
23 | return [int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)]
24 |
25 | def mask_to_rle(mask):
26 | rle = maskUtils.encode(np.asfortranarray(mask.astype(np.uint8)))
27 | rle["counts"] = rle["counts"].decode("utf-8")
28 | return rle
29 |
30 | def save_rle_to_json(rle_data_list, save_path):
31 | with open(save_path, 'w') as f:
32 | json.dump(rle_data_list, f)
33 | print(f"Saved all RLE data to {save_path}")
34 |
35 | def main(json_path, sam_checkpoint, output_json_path, model_type="vit_h"):
36 | device = "cuda" if torch.cuda.is_available() else "cpu"
37 |
38 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
39 | sam.to(device=device)
40 | predictor = SamPredictor(sam)
41 |
42 | data = load_json(json_path)
43 |
44 | all_rle_data = []
45 |
46 | for idx, item in enumerate(data):
47 | image_pth = item["image_pth"]
48 | pred_bbox = item["pred_bbox"]
49 |
50 | image = load_image(image_pth)
51 | height, width, _ = image.shape
52 |
53 | abs_bbox = scale_bbox(pred_bbox, width, height)
54 |
55 | predictor.set_image(image)
56 |
57 | transformed_bbox = predictor.transform.apply_boxes_torch(
58 | torch.tensor([abs_bbox], device=device), image.shape[:2]
59 | )
60 | masks, scores, _ = predictor.predict_torch(
61 | point_coords=None,
62 | point_labels=None,
63 | boxes=transformed_bbox,
64 | multimask_output=False,
65 | )
66 |
67 | mask = masks[0][0].cpu().numpy()
68 |
69 | rle = mask_to_rle(mask)
70 |
71 | rle_data = {
72 | "image_pth": image_pth,
73 | "pred_bbox": pred_bbox,
74 | "rle_mask": rle
75 | }
76 | all_rle_data.append(rle_data)
77 |
78 | save_rle_to_json(all_rle_data, output_json_path)
79 |
80 | if __name__ == "__main__":
81 | json_path = f"/path/to/resbox"
82 | sam_checkpoint = "sam_vit_h_4b8939.pth"
83 | output_json_path = f"./all_masks_rle.json"
84 |
85 | main(json_path, sam_checkpoint, output_json_path)
86 |
--------------------------------------------------------------------------------
/lisa_evaluation/evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "d8fcd1a9-d7c5-4049-bfff-27dbc49b029d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from coco_evaluation import CocoDetectionEvaluator"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "aa90e100-072b-4f55-999c-a8e03d56fe87",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "evaluator = CocoDetectionEvaluator('./data/coco/annotations/instances_val2017.json')"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "585276a5-cfb2-4e2b-a9de-30a46cbf30c5",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "results, per_class_results = evaluator.evaluate('./prediction_Qwen2_vl_2B_GRPO_coco.json', './results')"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "92d76ac4-44aa-43b5-8123-d871453e6750",
37 | "metadata": {
38 | "scrolled": true
39 | },
40 | "outputs": [],
41 | "source": [
42 | "### mAP and AP for all categories\n",
43 | "results, per_class_results"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "014151c7-5a90-4d46-9782-7195cb7147ed",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "### mAP and AP for selected categories\n",
54 | "selected_cate = ['bus', 'train', 'fire hydrant', 'stop sign', 'cat', 'dog', 'bed', 'toilet']\n",
55 | "# selected_cate = ['mouse', 'fork', 'hot dog', 'cat', 'airplane', 'suitcase', 'parking meter', 'sandwich', 'train', 'hair drier', 'toilet', 'toaster', 'snowboard', 'frisbee', 'bear']\n",
56 | "results, per_class_results\n",
57 | "AP_sum = 0\n",
58 | "for item in per_class_results:\n",
59 | " for key, value in item.items():\n",
60 | " if key in selected_cate:\n",
61 | " print(f\"Key: {key}, Value: {value}\")\n",
62 | " AP_sum += value\n",
63 | "print(\"mAP for selected categories: \", (AP_sum)/(len(selected_cate)))"
64 | ]
65 | }
66 | ],
67 | "metadata": {
68 | "kernelspec": {
69 | "display_name": "Python 3 (ipykernel)",
70 | "language": "python",
71 | "name": "python3"
72 | },
73 | "language_info": {
74 | "codemirror_mode": {
75 | "name": "ipython",
76 | "version": 3
77 | },
78 | "file_extension": ".py",
79 | "mimetype": "text/x-python",
80 | "name": "python",
81 | "nbconvert_exporter": "python",
82 | "pygments_lexer": "ipython3",
83 | "version": "3.10.13"
84 | }
85 | },
86 | "nbformat": 4,
87 | "nbformat_minor": 5
88 | }
89 |
--------------------------------------------------------------------------------
/lisa_evaluation/gen_box_ann.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from PIL import Image, ImageDraw
4 |
5 | res = []
6 | base_path = "/path/to/your/LISA-main/data/train"
7 |
8 | for pth in os.listdir(base_path):
9 | if pth.endswith(".json"):
10 | json_path = os.path.join(base_path, pth)
11 |
12 | with open(json_path, 'r') as f:
13 | item = json.load(f)
14 |
15 | instruct = item["text"]
16 | shapes = item["shapes"]
17 |
18 | boxes = []
19 | for shape in shapes[:1]:
20 | points = shape["points"]
21 | x_coords = [p[0] for p in points]
22 | y_coords = [p[1] for p in points]
23 |
24 | x_min, x_max = min(x_coords), max(x_coords)
25 | y_min, y_max = min(y_coords), max(y_coords)
26 | boxes.append((x_min, y_min, x_max, y_max))
27 |
28 | img_path = json_path.replace(".json", ".jpg")
29 | if os.path.exists(img_path):
30 | res.append({
31 | "image_path": img_path,
32 | "instruction": instruct,
33 | "boxes": boxes
34 | })
35 |
36 | json.dump(res, open("lisa_train.json", 'w'), indent=4)
37 |
--------------------------------------------------------------------------------
/lisa_evaluation/gen_sft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from PIL import Image
4 | res = []
5 | index = 0
6 | for i, item in enumerate(json.load(open("path/to/your/lisa_train.json", 'r'))):
7 | for instruct in item['instruction']:
8 | w, h= Image.open(item['image_path']).size
9 | res.append({
10 | "id": f"lisa_{index}",
11 | "conversations": [
12 | {
13 | "from": "user",
14 | "value": f"
{item['image_path']}\n Output the bounding box in the image corresponding to the instruction: {instruct}"
15 | },
16 | {
17 | "from": "assistant",
18 | "value": f"({int(item['boxes'][0][0] / w * 1000)},{int(item['boxes'][0][1] / h * 1000)}),({int(item['boxes'][0][2] / w * 1000)},{int(item['boxes'][0][3] / h * 1000)})"
19 | }
20 | ]
21 | })
22 | index += 1
23 | json.dump(res, open("lisa_train_sft.json", 'w'), indent=4)
24 |
--------------------------------------------------------------------------------
/lisa_evaluation/mask_iou.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from pycocotools import mask as maskUtils
5 | from PIL import Image, ImageDraw
6 | from tqdm import tqdm
7 |
8 | def polygon_to_mask(polygon, size):
9 | mask = Image.new('L', (size[1], size[0]), 0)
10 | ImageDraw.Draw(mask).polygon([tuple(point) for point in polygon], outline=1, fill=1)
11 | return np.array(mask, dtype=np.uint8)
12 |
13 | def rle_to_mask(rle):
14 | rle_decoded = maskUtils.decode(rle)
15 | return rle_decoded.astype(np.uint8)
16 |
17 | def compute_iou(pred_mask, gt_mask):
18 | intersection = np.logical_and(pred_mask, gt_mask).sum()
19 | union = np.logical_or(pred_mask, gt_mask).sum()
20 | iou = intersection / (union + 1e-6)
21 | return intersection, union, iou
22 |
23 | res = []
24 | base_path = "/path/to/LISA-main/data/test"
25 |
26 | test_res = {}
27 |
28 | with open(f"all_masks_rle.json", 'r') as f:
29 | for item in json.load(f):
30 | test_res[item['image_pth']] = item['rle_mask']
31 |
32 | total_intersection = 0
33 | total_union = 0
34 | ious = []
35 |
36 | for pth in tqdm(os.listdir(base_path)):
37 | if pth.endswith(".json"):
38 | json_path = os.path.join(base_path, pth)
39 |
40 | with open(json_path, 'r') as f:
41 | item = json.load(f)
42 |
43 | try:
44 | gt_polygon = item["shapes"][0]['points']
45 | except:
46 | print("no res")
47 | continue
48 |
49 | image_pth = f"/path/to/LISA-main/data/test/{pth.replace('.json', '.jpg')}"
50 | pred_rle = test_res.get(image_pth, None)
51 |
52 |
53 | with Image.open(image_pth) as img:
54 | img_width, img_height = img.size
55 | real_size = [img_height, img_width]
56 |
57 | if pred_rle is not None:
58 | size = pred_rle['size']
59 |
60 | if size != real_size:
61 | print("shape mismatch")
62 | pred_mask = np.zeros((size[0], size[1]), dtype=np.uint8)
63 | continue
64 | pred_mask = rle_to_mask(pred_rle)
65 | else:
66 | print(f"Warning: No prediction found for {image_pth}. Setting IoU to 0.")
67 | pred_mask = np.zeros((real_size[0], real_size[1]), dtype=np.uint8)
68 |
69 | gt_mask = polygon_to_mask(gt_polygon, real_size)
70 | intersection, union, iou = compute_iou(pred_mask, gt_mask)
71 |
72 | total_intersection += intersection
73 | total_union += union
74 |
75 | ious.append(iou)
76 |
77 | gIoU = np.mean(ious) if ious else 0
78 | cIoU = total_intersection / (total_union + 1e-6)
79 |
80 | print(f"gIoU (Global IoU): {gIoU:.4f}")
81 | print(f"cIoU (Cumulative IoU): {cIoU:.4f}")
82 |
--------------------------------------------------------------------------------
/lisa_evaluation/merge_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | merged = []
4 | for i in range(int(os.environ['SPLIT_NUM'])):
5 | data = json.load(open(f"tmp/res_{i}.json", 'r'))
6 | merged += data
7 | print(f"mIoU: {sum(merged) / len(merged)}")
8 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | cd src/virft
2 | pip install -e ".[dev]"
3 |
4 | # Additional modules
5 | pip install wandb==0.18.3
6 | pip install tensorboardx
7 | pip install qwen_vl_utils torchvision
8 | pip install flash-attn --no-build-isolation
9 |
10 | # vLLM support
11 | pip install vllm==0.7.2
12 |
13 | # fix transformers version
14 | pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
15 |
--------------------------------------------------------------------------------
/src/scripts/2B_aircraft_4_shot.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true"
2 | export LOG_PATH="./debug_log_2b_GRPO_aircraft_4_shot.txt"
3 |
4 | export DATA_PATH=./share_data/ViRFT_CLS_fgvc_aircraft_4_shot
5 | export CKPT_PATH=./share_models/Qwen2-VL-2B-Instruct
6 | export SAVE_PATH=./share_models/Qwen2-VL-2B-Instruct_GRPO_aircraft_4_shot
7 |
8 |
9 | torchrun --nproc_per_node="8" \
10 | --nnodes="1" \
11 | --node_rank="0" \
12 | --master_addr="127.0.0.1" \
13 | --master_port="12345" \
14 | src/open_r1/grpo_classification.py \
15 | --output_dir ${SAVE_PATH} \
16 | --model_name_or_path ${CKPT_PATH} \
17 | --dataset_name ${DATA_PATH} \
18 | --deepspeed ./local_scripts/zero3.json \
19 | --max_prompt_length 1024 \
20 | --per_device_train_batch_size 1 \
21 | --gradient_accumulation_steps 2 \
22 | --logging_steps 1 \
23 | --bf16 \
24 | --report_to wandb \
25 | --gradient_checkpointing false \
26 | --attn_implementation flash_attention_2 \
27 | --max_pixels 401408 \
28 | --num_train_epochs 1 \
29 | --run_name Qwen2-VL-2B_GRPO_aircraft100_4shot \
30 | --save_steps 100 \
31 | --save_only_model true \
32 | --num_generations 8
33 |
--------------------------------------------------------------------------------
/src/scripts/2B_base65cate_6k.sh:
--------------------------------------------------------------------------------
1 | export DATA_PATH=./share_data/base65cate_6k_think
2 | export CKPT_PATH=./share_models/Qwen2-VL-2B-Instruct
3 | export SAVE_PATH=./share_models/Qwen2-VL-2B-Instruct_GRPO_coco_base65cate_6k
4 |
5 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
6 | export LOG_PATH="./debug_log_2b_GRPO_coco_base65cate_6k.txt"
7 |
8 | torchrun --nproc_per_node="8" \
9 | --nnodes="1" \
10 | --node_rank="0" \
11 | --master_addr="127.0.0.1" \
12 | --master_port="12345" \
13 | src/open_r1/grpo.py \
14 | --output_dir ${SAVE_PATH} \
15 | --model_name_or_path ${CKPT_PATH} \
16 | --dataset_name ${DATA_PATH} \
17 | --deepspeed ./local_scripts/zero3.json \
18 | --max_prompt_length 1024 \
19 | --per_device_train_batch_size 1 \
20 | --gradient_accumulation_steps 2 \
21 | --logging_steps 1 \
22 | --bf16 \
23 | --report_to wandb \
24 | --gradient_checkpointing false \
25 | --attn_implementation flash_attention_2 \
26 | --max_pixels 401408 \
27 | --num_train_epochs 2 \
28 | --run_name Qwen2-VL-2B_GRPO_coco_base65cate_6k \
29 | --save_steps 100 \
30 | --save_only_model true \
31 | --num_generations 8
32 |
--------------------------------------------------------------------------------
/src/scripts/2B_lisa_grounding.sh:
--------------------------------------------------------------------------------
1 | # cd src/open-r1-multimodal
2 |
3 | export DEBUG_MODE="true"
4 | export LOG_PATH="./debug_log_2b.txt"
5 | torchrun --nproc_per_node="8" \
6 | --nnodes="1" \
7 | --node_rank="0" \
8 | --master_addr="127.0.0.1" \
9 | --master_port="12346" \
10 | src/open_r1/grpo_gui_grounding_lisa.py \
11 | --output_dir "out/lisa_train_GIoU" \
12 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
13 | --dataset_name NOT_USED \
14 | --deepspeed local_scripts/zero3.json \
15 | --max_prompt_length 1024 \
16 | --per_device_train_batch_size 1 \
17 | --gradient_accumulation_steps 2 \
18 | --logging_steps 1 \
19 | --bf16 \
20 | --gradient_checkpointing true \
21 | --attn_implementation flash_attention_2 \
22 | --max_pixels 401408 \
23 | --num_train_epochs 6 \
24 | --run_name Qwen2-VL-2B-GRPO-groud_lisa_train \
25 | --save_steps 50 \
26 | --save_only_model true \
27 | --num_generations 8 # number of outputs G in GRPO; reducing it gives faster training and lower memory cost, but higher variance
28 |
--------------------------------------------------------------------------------
/src/scripts/7B_base65cate_6k.sh:
--------------------------------------------------------------------------------
1 | export DATA_PATH=./share_data/base65cate_6k_think
2 | export CKPT_PATH=./share_models/Qwen2-VL-7B-Instruct
3 | export SAVE_PATH=./share_models/Qwen2-VL-7B-Instruct_GRPO_coco_base65cate_6k
4 |
5 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
6 | export LOG_PATH="./debug_log_7b_GRPO_coco_base65cate_6k.txt"
7 |
8 | torchrun --nproc_per_node="8" \
9 | --nnodes="1" \
10 | --node_rank="0" \
11 | --master_addr="127.0.0.1" \
12 | --master_port="12345" \
13 | src/open_r1/grpo.py \
14 | --output_dir ${SAVE_PATH} \
15 | --model_name_or_path ${CKPT_PATH} \
16 | --dataset_name ${DATA_PATH} \
17 | --deepspeed ./local_scripts/zero3.json \
18 | --max_prompt_length 1024 \
19 | --per_device_train_batch_size 1 \
20 | --gradient_accumulation_steps 2 \
21 | --logging_steps 1 \
22 | --bf16 \
23 | --report_to wandb \
24 | --gradient_checkpointing False \
25 | --attn_implementation flash_attention_2 \
26 | --max_pixels 401408 \
27 | --num_train_epochs 2 \
28 | --run_name Qwen2-VL-7B_GRPO_coco_base65cate_6k \
29 | --save_steps 100 \
30 | --save_only_model true \
31 | --num_generations 4
32 |
--------------------------------------------------------------------------------
/src/scripts/example.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true"
2 | export LOG_PATH="./debug_log_2b_GRPO_coco_base65cate_6k.txt"
3 |
4 | export DATA_PATH=./share_data/base65cate_6k_think
5 | export CKPT_PATH=./share_models/Qwen2-VL-2B-Instruct
6 | export SAVE_PATH=./share_models/Qwen2-VL-2B-Instruct_GRPO_coco_base65cate_6k
7 |
8 |
9 | torchrun --nproc_per_node="8" \
10 | --nnodes="1" \
11 | --node_rank="0" \
12 | --master_addr="127.0.0.1" \
13 | --master_port="12345" \
14 | src/open_r1/grpo.py \
15 | --output_dir ${SAVE_PATH} \
16 | --model_name_or_path ${CKPT_PATH} \
17 | --dataset_name ${DATA_PATH} \
18 | --deepspeed local_scripts/zero3.json \
19 | --max_prompt_length 1024 \
20 | --per_device_train_batch_size 1 \
21 | --gradient_accumulation_steps 2 \
22 | --logging_steps 1 \
23 | --bf16 \
24 | --report_to wandb \
25 | --gradient_checkpointing false \
26 | --attn_implementation flash_attention_2 \
27 | --max_pixels 401408 \
28 | --num_train_epochs 2 \
29 | --run_name Qwen2-VL-2B_GRPO_coco_base65cate_6k \
30 | --save_steps 100 \
31 | --save_only_model true \
32 | --num_generations 8
33 |
--------------------------------------------------------------------------------
/src/virft/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # Temp folders
174 | data/
175 | wandb/
176 | scripts/
177 | checkpoints/
178 | .vscode/
--------------------------------------------------------------------------------
/src/virft/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: style quality
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := src
7 |
8 | style:
9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py
10 | isort $(check_dirs) setup.py
11 |
12 | quality:
13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14 | isort --check-only $(check_dirs) setup.py
15 | flake8 --max-line-length 119 $(check_dirs) setup.py
16 |
17 |
18 | # Evaluation
19 |
20 | evaluate:
21 |
--------------------------------------------------------------------------------
/src/virft/README.md:
--------------------------------------------------------------------------------
1 | # Visual-RFT
2 |
--------------------------------------------------------------------------------
/src/virft/configs/ddp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/src/virft/configs/zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
--------------------------------------------------------------------------------
/src/virft/configs/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/src/virft/local_scripts/create_vision_cot_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import concurrent.futures
4 | import io
5 | import json
6 | import os
7 | import random
8 | import re
9 | import time
10 | from concurrent.futures import ThreadPoolExecutor
11 | from functools import partial
12 | from io import BytesIO
13 | from typing import Dict, List
14 |
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | import pandas as pd
18 | from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
19 | from tqdm import tqdm
20 |
21 | import bytedtos
22 | import seaborn as sns
23 | import yaml
24 | from openai import AzureOpenAI
25 | from PIL import Image
26 | from pillow_avif import AvifImagePlugin
27 | from transformers.image_utils import load_image  # resolves http(s) URLs / paths to PIL images in get_image_data_url()
28 |
29 | PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
30 |
31 | Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
32 |
33 | Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
34 |
35 | Input Format:
36 | Original Question: {original_question}
37 | Original Answer: {original_answer}
38 |
39 | Output Format:
40 | Question: [rewrite the question if necessary]
41 | Answer: [answer with reasoning steps, including calculations where applicable]
42 | <think>step-by-step reasoning process</think>
43 | <answer>easy to verify answer</answer>
44 | """
45 |
46 |
47 | def get_image_data_url(image_input):
48 | if isinstance(image_input, str) and image_input.startswith("data:"):
49 | return image_input
50 |
51 | if isinstance(image_input, str) and image_input.startswith("http"):
52 | image_input = load_image(image_input)
53 |
54 | if isinstance(image_input, str):
55 | image_input = Image.open(image_input)
56 |
57 | if not isinstance(image_input, Image.Image):
58 | raise ValueError("Unsupported image input type")
59 |
60 | if image_input.mode != "RGB":
61 | image_input = image_input.convert("RGB")
62 |
63 | buffer = BytesIO()
64 | image_input.save(buffer, format="JPEG")
65 | img_bytes = buffer.getvalue()
66 | base64_data = base64.b64encode(img_bytes).decode("utf-8")
67 | return f"data:image/jpeg;base64,{base64_data}"
68 |
69 |
70 | def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
71 | if image is None:
72 | return None
73 |
74 | data_url_list = [get_image_data_url(image)]
75 | client = AzureOpenAI(
76 | azure_endpoint="YOUR_AZURE_ENDPOINT",
77 | api_version="2023-07-01-preview",
78 | api_key="YOUR_API_KEY",
79 | )
80 |
81 | for attempt in range(max_retries):
82 | try:
83 | messages = [
84 | {
85 | "role": "system",
86 | "content": "You are an expert to analyze the image and provide useful information for users.",
87 | },
88 | {
89 | "role": "user",
90 | "content": [
91 | {"type": "text", "text": prompt},
92 | ],
93 | },
94 | ]
95 |
96 | for data_url in data_url_list:
97 | messages[1]["content"].insert(
98 | 0, {"type": "image_url", "image_url": {"url": data_url}}
99 | )
100 |
101 | response = client.chat.completions.create(
102 | model="gpt-4o-2024-08-06",
103 | messages=messages,
104 | temperature=0.2,
105 | max_tokens=8192,
106 | )
107 | return response.choices[0].message.content
108 |
109 | except Exception as e:
110 | if attempt == max_retries - 1:
111 | raise Exception(
112 | f"Failed after {max_retries} attempts. Last error: {str(e)}"
113 | )
114 | delay = initial_delay * (2**attempt) + random.uniform(
115 | 0, 0.1 * initial_delay * (2**attempt)
116 | )
117 | time.sleep(delay)
118 |
119 |
120 | def process_single_item(example):
121 | try:
122 | image_path = example["image_path"]
123 | formatted_prompt = PROMPT_FORMAT.format(
124 | original_question=example["question"], original_answer=example["answer"]
125 | )
126 |
127 | response = gpt4o_query(image_path, formatted_prompt)
128 | example["gpt4o_response"] = response
129 | return example
130 | except Exception as e:
131 | print(f"Error processing item: {str(e)}")
132 | example["gpt4o_response"] = None
133 | return example
134 |
135 |
136 | def main():
137 | dataset_path = "path/to/your/dataset"
138 | full_dataset = load_from_disk(dataset_path)
139 |
140 | processed_dataset = full_dataset.map(
141 | function=partial(process_single_item),
142 | num_proc=256,
143 | desc="Processing dataset with GPT-4o",
144 | keep_in_memory=True,
145 | )
146 |
147 | output_path = f"{dataset_path}_processed"
148 | processed_dataset.save_to_disk(output_path)
149 | print(f"Processed dataset saved to: {output_path}")
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
--------------------------------------------------------------------------------
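
Note: the retry loop in gpt4o_query above sleeps with exponential backoff plus up to 10% jitter before each retry. A minimal, self-contained sketch of that delay schedule (the printed values are illustrative only, not part of the script):

import random

initial_delay, max_retries = 3, 5
for attempt in range(max_retries):
    # base delay doubles each attempt (~3s, 6s, 12s, 24s, 48s), plus up to 10% random jitter
    delay = initial_delay * (2**attempt) + random.uniform(0, 0.1 * initial_delay * (2**attempt))
    print(f"attempt {attempt}: would sleep ~{delay:.1f}s before retrying")
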
/src/virft/local_scripts/lmms_eval_qwen2vl.sh:
--------------------------------------------------------------------------------
1 | export HF_HOME=""
2 | export HF_TOKEN=""
3 | export HF_HUB_ENABLE_HF_TRANSFER="1"
4 |
5 | export API_TYPE=""
6 | export AZURE_ENDPOINT=""
7 | export AZURE_API_KEY=""
8 | export API_VERSION=""
9 | export MODEL_VERSION=""
10 | export NAVIT_ATTENTION_IMPLEMENTATION="eager"
11 |
12 | # Prompt for installation with 3-second timeout
13 | read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
14 | if [ "$install_deps" = "YES" ]; then
15 | # Prepare the environment
16 | pip3 install --upgrade pip
17 | pip3 install -U setuptools
18 |
19 | cd
20 | if [ ! -d "maas_engine" ]; then
21 | git clone
22 | else
23 | echo "maas_engine directory already exists, skipping clone"
24 | fi
25 | cd maas_engine
26 | git pull
27 | git checkout
28 | pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
29 |
30 | current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
31 | if [ "$current_version" != "4.46.2" ]; then
32 | echo "Installing transformers 4.46.2 (current version: $current_version)"
33 | pip3 install transformers==4.46.2
34 | else
35 | echo "transformers 4.46.2 is already installed"
36 | fi
37 |
38 | cd
39 | rm -rf
40 | pip3 install -e .
41 | pip3 install -U pydantic
42 | pip3 install Levenshtein
43 | pip3 install nltk
44 | python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
45 | fi
46 |
47 | TASKS=mmmu_val,mathvista_testmini,mmmu_pro
48 | MODEL_BASENAME=qwen2_vl
49 |
50 | model_checkpoint=""
51 | echo "MODEL_BASENAME: ${MODEL_BASENAME}"
52 | cd
53 |
54 | python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
55 | --model qwen2_vl \
56 | --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
57 | --tasks ${TASKS} \
58 | --batch_size 1 \
59 | --log_samples \
60 | --log_samples_suffix ${MODEL_BASENAME} \
61 | --output_path ./logs
--------------------------------------------------------------------------------
/src/virft/local_scripts/prepare_hf_data.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | import random
5 | from typing import List, Dict
6 | import numpy as np
7 | from concurrent.futures import ThreadPoolExecutor
8 | from tqdm import tqdm
9 | import datasets
10 |
11 | import io
12 | from datasets import load_dataset, load_from_disk, concatenate_datasets
13 | from PIL import Image
14 | from tqdm import tqdm
15 | from functools import partial
16 | from pillow_avif import AvifImagePlugin
17 | from datasets import Dataset
18 | import json
19 | import yaml
20 | import os
21 | import re
22 | import time
23 | import random
24 | import base64
25 | from openai import AzureOpenAI
26 | import concurrent.futures
27 | from typing import List, Dict
28 | import argparse
29 | import time
30 |
31 |
32 | def extract_problem_solution(gpt4o_response):
33 | # Split the response into parts
34 | parts = gpt4o_response.split("<think>")
35 |
36 | # Extract the problem (first part before any tags)
37 | problem = parts[0].strip()
38 | # Remove "Question:" prefix if it exists
39 | problem = re.sub(r"^Question:\s*", "", problem)
40 | # Remove "Answer:" at the end of the problem
41 | problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
42 |
43 | # Combine all the reasoning steps into a single block
44 | think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
45 | solution = f"<think>{' '.join(think_parts)}</think>"
46 |
47 | # Add the final answer if it exists, removing "Answer:" prefix
48 | if "<answer>" in gpt4o_response:
49 | final_answer = (
50 | gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
51 | )
52 | final_answer = re.sub(r"^Answer:\s*", "", final_answer)
53 | solution += f"\n\n<answer>{final_answer}</answer>"
54 |
55 | return problem, solution
56 |
57 |
58 | def load_image_from_path(image_path):
59 | try:
60 | img = Image.open(image_path)
61 | return img
62 | except Exception as e:
63 | print(f"Error loading image {image_path}: {str(e)}")
64 | return None
65 |
66 |
67 | def process_raw_data(raw_data):
68 | # Parse the raw data if it's a string
69 | if isinstance(raw_data, str):
70 | data = json.loads(raw_data)
71 | else:
72 | data = raw_data
73 |
74 | # Extract problem and solution
75 | try:
76 | problem, solution = extract_problem_solution(data["gpt4o_response"])
77 | image = load_image_from_path(data["image_path"])
78 |
79 | return {
80 | "image": image,
81 | "problem": problem,
82 | "solution": solution,
83 | "original_question": data["question"],
84 | "original_answer": data["answer"],
85 | }
86 | except Exception as e:
87 | print(f"Error processing data {data}: {str(e)}")
88 | return {
89 | "image": None,
90 | "problem": None,
91 | "solution": None,
92 | "original_question": None,
93 | "original_answer": None,
94 | }
95 |
96 |
97 | raw_data_list = [
98 | "/path/to/reasoning_data_with_response_90k_verified",
99 | ]
100 |
101 | raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
102 |
103 | processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
104 |
105 | hf_dict = {
106 | "image": [],
107 | "problem": [],
108 | "solution": [],
109 | "original_question": [],
110 | "original_answer": [],
111 | }
112 |
113 | for item in tqdm(processed_data):
114 | hf_dict["image"].append(item["image"])
115 | hf_dict["problem"].append(item["problem"])
116 | hf_dict["solution"].append(item["solution"])
117 | hf_dict["original_question"].append(item["original_question"])
118 | hf_dict["original_answer"].append(item["original_answer"])
119 |
120 |
121 | features = datasets.Features(
122 | {
123 | "image": datasets.Image(),
124 | "problem": datasets.Value("string"),
125 | "solution": datasets.Value("string"),
126 | "original_question": datasets.Value("string"),
127 | "original_answer": datasets.Value("string"),
128 | }
129 | )
130 |
131 |
132 | def has_empty_tags(text):
133 | # Pattern to match empty tags like <think></think> or <answer></answer>
134 | pattern = r"<[^>]+></[^>]+>"
135 | return bool(re.search(pattern, text))
136 |
137 |
138 | def has_answer_pattern(text):
139 | if "Answer:" in text:
140 | return True
141 | return False
142 |
143 |
144 | def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
145 | # Assuming the image is in a format that can be checked for dimensions
146 | # You might need to adjust this depending on how the image is stored in your dataset
147 | try:
148 | image = example["image"] # or however your image is accessed
149 | if isinstance(image, dict) and "height" in image and "width" in image:
150 | return image["height"] >= 28 and image["width"] >= 28
151 | # If image is a PIL Image or similar
152 | return image.height >= 28 and image.width >= 28
153 | except:
154 | return False
155 |
156 |
157 | ds = datasets.Dataset.from_dict(hf_dict, features=features)
158 | ds = ds.filter(
159 | lambda x: not has_empty_tags(x["solution"])
160 | and not has_answer_pattern(x["problem"])
161 | and has_valid_image_size(x)
162 | and x["image"] is not None,
163 | num_proc=128,
164 | )
165 | # Push to Hugging Face Hub
166 | ds.push_to_hub("path/to/your/dataset")
167 |
--------------------------------------------------------------------------------
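
Note: a hedged illustration of what extract_problem_solution above returns, assuming the <think>/<answer> tag format used elsewhere in this repo and that the function (plus `import re`) is in scope, e.g. pasted into a REPL. The sample response string is hypothetical:

import re  # extract_problem_solution relies on re

resp = (
    "Question: How many apples are on the table?\n"
    "Answer: <think>Let me count: 2 red and 1 green apple, so 2 + 1 = 3.</think>"
    "<answer>3</answer>"
)
problem, solution = extract_problem_solution(resp)
print(problem)   # How many apples are on the table?
print(solution)  # <think>Let me count: ...</think> followed by a blank line and <answer>3</answer>
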
/src/virft/local_scripts/train_aria_moe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_BLOCKING_WAIT=0
4 | export TOKENIZERS_PARALLELISM=false
5 | export OMP_NUM_THREADS=8
6 | export NCCL_IB_DISABLE=0
7 | export NCCL_IB_GID_INDEX=3
8 | export NCCL_SOCKET_IFNAME=eth0
9 | export NCCL_DEBUG=INFO
10 |
11 | # CONFIG Huggingface
12 | # export HF_TOKEN=""
13 | export HF_TOKEN=""
14 | export HF_HOME="$HOME/.cache/huggingface"
15 | export HF_HUB_ENABLE_HF_TRANSFER="1"
16 |
17 | export NCCL_DEBUG=INFO
18 |
19 | GPUS="0,1,2,3,4,5,6,7"
20 |
21 | # take the first port assigned to worker0
22 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
23 | port=${ports[0]}
24 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
25 |
26 | echo "total workers: ${ARNOLD_WORKER_NUM}"
27 | echo "cur worker id: ${ARNOLD_ID}"
28 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
29 | echo "master ip: ${METIS_WORKER_0_HOST}"
30 | echo "master port: ${port}"
31 | echo "master port in cmd: ${port_in_cmd}"
32 |
33 | # export WANDB_BASE_URL=https://api.wandb.ai
34 | # export WANDB_API_KEY=""
35 | # wandb login $WANDB_API_KEY
36 |
37 | export WANDB_BASE_URL=https://api.wandb.ai
38 | export WANDB_PROJECT=vision-reasoning
39 | export WANDB_API_KEY=""
40 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
41 | wandb login $WANDB_API_KEY
42 |
43 | cd /home/tiger/multimodal-open-r1
44 | # pip3 install vllm==0.6.6.post1
45 | pip3 install -e ".[dev]"
46 | pip3 install wandb==0.18.3
47 |
48 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
49 | --nnodes="${ARNOLD_WORKER_NUM}" \
50 | --node_rank="${ARNOLD_ID}" \
51 | --master_addr="${METIS_WORKER_0_HOST}" \
52 | --master_port="${port_in_cmd}" \
53 | src/open_r1/grpo.py \
54 | --deepspeed scripts/zero3.json \
55 | --output_dir Aria-GRPO-mini_cot_80k \
56 | --model_name_or_path rhymes-ai/Aria \
57 | --dataset_name luodian/mini_cot_80k \
58 | --max_prompt_length 8192 \
59 | --per_device_train_batch_size 1 \
60 | --gradient_accumulation_steps 1 \
61 | --logging_steps 1 \
62 | --bf16 \
63 | --report_to wandb \
64 | --gradient_checkpointing true \
65 | --attn_implementation eager \
66 | --save_total_limit 8 \
67 | --num_train_epochs 1 \
68 | --run_name $WANDB_RUN_NAME
69 |
--------------------------------------------------------------------------------
/src/virft/local_scripts/train_qwen2_vl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_BLOCKING_WAIT=0
4 | export TOKENIZERS_PARALLELISM=false
5 | export OMP_NUM_THREADS=8
6 | export NCCL_IB_DISABLE=0
7 | export NCCL_IB_GID_INDEX=3
8 | export NCCL_SOCKET_IFNAME=eth0
9 | export NCCL_DEBUG=INFO
10 |
11 | GPUS="0,1,2,3,4,5,6,7"
12 |
13 | # take the first port assigned to worker0
14 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
15 | port=${ports[0]}
16 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
17 |
18 | echo "total workers: ${ARNOLD_WORKER_NUM}"
19 | echo "cur worker id: ${ARNOLD_ID}"
20 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
21 | echo "master ip: ${METIS_WORKER_0_HOST}"
22 | echo "master port: ${port}"
23 | echo "master port in cmd: ${port_in_cmd}"
24 |
25 | # export WANDB_BASE_URL=https://api.wandb.ai
26 | # export WANDB_API_KEY=""
27 | # wandb login $WANDB_API_KEY
28 |
29 | export WANDB_BASE_URL=https://api.wandb.ai
30 | export WANDB_PROJECT=vision-reasoning
31 | export WANDB_API_KEY=""
32 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
33 | wandb login $WANDB_API_KEY
34 |
35 | cd /home/tiger/multimodal-open-r1
36 | # pip3 install vllm==0.6.6.post1
37 | pip3 install -e ".[dev]"
38 | pip3 install wandb==0.18.3
39 |
40 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
41 | --nnodes="${ARNOLD_WORKER_NUM}" \
42 | --node_rank="${ARNOLD_ID}" \
43 | --master_addr="${METIS_WORKER_0_HOST}" \
44 | --master_port="${port_in_cmd}" \
45 | src/open_r1/grpo.py \
46 | --deepspeed scripts/zero3.json \
47 | --output_dir checkpoints/${WANDB_RUN_NAME} \
48 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
49 | --dataset_name luodian/${DATASET_NAME} \
50 | --max_prompt_length 8192 \
51 | --per_device_train_batch_size 1 \
52 | --gradient_accumulation_steps 1 \
53 | --logging_steps 1 \
54 | --bf16 \
55 | --report_to wandb \
56 | --gradient_checkpointing true \
57 | --attn_implementation flash_attention_2 \
58 | --max_pixels 2359296 \
59 | --save_total_limit 8 \
60 | --num_train_epochs 1 \
61 | --run_name $WANDB_RUN_NAME
62 |
--------------------------------------------------------------------------------
/src/virft/local_scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/src/virft/local_scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
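
Note: the "auto" placeholders in zero2.json/zero3.json above are resolved at launch time by the Hugging Face Trainer/accelerate DeepSpeed integration from the training arguments (batch size, gradient clipping, etc.). A small sanity-check sketch, assuming it is run from the /src/virft directory:

import json

# list the top-level fields left as "auto" (nested ones, e.g. fp16.enabled, are not checked here)
with open("local_scripts/zero3.json") as f:
    cfg = json.load(f)

print([k for k, v in cfg.items() if v == "auto"])
# ['gradient_accumulation_steps', 'gradient_clipping', 'train_batch_size', 'train_micro_batch_size_per_gpu']
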
/src/virft/local_scripts/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/src/virft/local_scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "train_batch_size": "auto",
45 | "train_micro_batch_size_per_gpu": "auto",
46 | "steps_per_print": 1e5,
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
/src/virft/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = open_r1
7 | known_third_party =
8 | transformers
9 | datasets
10 | fugashi
11 | git
12 | h5py
13 | matplotlib
14 | nltk
15 | numpy
16 | packaging
17 | pandas
18 | psutil
19 | pytest
20 | rouge_score
21 | sacrebleu
22 | seqeval
23 | sklearn
24 | streamlit
25 | torch
26 | tqdm
27 |
28 | line_length = 119
29 | lines_after_imports = 2
30 | multi_line_output = 3
31 | use_parentheses = True
32 |
33 | [flake8]
34 | ignore = E203, E501, E741, W503, W605
35 | max-line-length = 119
36 | per-file-ignores =
37 | # imported but unused
38 | __init__.py: F401
39 |
40 | [tool:pytest]
41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/src/virft/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
16 |
17 |
18 | import re
19 | import shutil
20 | from pathlib import Path
21 |
22 | from setuptools import find_packages, setup
23 |
24 |
25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
27 | if stale_egg_info.exists():
28 | print(
29 | (
30 | "Warning: {} exists.\n\n"
31 | "If you recently updated open_r1, this is expected,\n"
32 | "but it may prevent open_r1 from installing in editable mode.\n\n"
33 | "This directory is automatically generated by Python's packaging tools.\n"
34 | "I will remove it now.\n\n"
35 | "See https://github.com/pypa/pip/issues/5466 for details.\n"
36 | ).format(stale_egg_info)
37 | )
38 | shutil.rmtree(stale_egg_info)
39 |
40 |
41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
43 | _deps = [
44 | "accelerate>=1.2.1",
45 | "bitsandbytes>=0.43.0",
46 | "black>=24.4.2",
47 | "datasets>=3.2.0",
48 | "deepspeed==0.15.4",
49 | "distilabel[vllm,ray,openai]>=1.5.2",
50 | "einops>=0.8.0",
51 | "flake8>=6.0.0",
52 | "hf_transfer>=0.1.4",
53 | "huggingface-hub[cli]>=0.19.2,<1.0",
54 | "isort>=5.12.0",
55 | "liger_kernel==0.5.2",
56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
57 | "math-verify", # Used for math verification in grpo
58 | "packaging>=23.0",
59 | "parameterized>=0.9.0",
60 | "pytest",
61 | "safetensors>=0.3.3",
62 | "sentencepiece>=0.1.99",
63 | "torch>=2.5.1",
64 | "transformers @ git+https://github.com/huggingface/transformers.git@main",
65 | "trl @ git+https://github.com/huggingface/trl.git@main",
66 | "vllm==0.6.6.post1",
67 | "wandb>=0.19.1",
68 | "pillow",
69 | ]
70 |
71 | # this is a lookup table with items like:
72 | #
73 | # tokenizers: "tokenizers==0.9.4"
74 | # packaging: "packaging"
75 | #
76 | # some of the values are versioned whereas others aren't.
77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
78 |
79 |
80 | def deps_list(*pkgs):
81 | return [deps[pkg] for pkg in pkgs]
82 |
83 |
84 | extras = {}
85 | extras["tests"] = deps_list("pytest", "parameterized")
86 | extras["torch"] = deps_list("torch")
87 | extras["quality"] = deps_list("black", "isort", "flake8")
88 | extras["eval"] = deps_list("lighteval", "math-verify")
89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
90 |
91 | # core dependencies shared across the whole project - keep this to a bare minimum :)
92 | install_requires = [
93 | deps["accelerate"],
94 | deps["bitsandbytes"],
95 | deps["einops"],
96 | deps["datasets"],
97 | deps["deepspeed"],
98 | deps["hf_transfer"],
99 | deps["huggingface-hub"],
100 | deps["liger_kernel"],
101 | deps["packaging"], # utilities from PyPA to e.g., compare versions
102 | deps["safetensors"],
103 | deps["sentencepiece"],
104 | deps["transformers"],
105 | deps["trl"],
106 | ]
107 |
108 | setup(
109 | name="open-r1",
110 | version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
111 | author="The Hugging Face team (past and future)",
112 | author_email="lewis@huggingface.co",
113 | description="Open R1",
114 | long_description="This is the Visual-RFT project.",
115 | long_description_content_type="text/markdown",
116 | keywords="llm inference-time compute reasoning",
117 | license="Apache",
118 | url="https://github.com/huggingface/open-r1",
119 | package_dir={"": "src"},
120 | packages=find_packages("src"),
121 | zip_safe=False,
122 | extras_require=extras,
123 | python_requires=">=3.10.9",
124 | install_requires=install_requires,
125 | classifiers=[
126 | "Development Status :: 3 - Alpha",
127 | "Intended Audience :: Developers",
128 | "Intended Audience :: Education",
129 | "Intended Audience :: Science/Research",
130 | "License :: OSI Approved :: Apache Software License",
131 | "Operating System :: OS Independent",
132 | "Programming Language :: Python :: 3",
133 | "Programming Language :: Python :: 3.10",
134 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
135 | ],
136 | )
137 |
--------------------------------------------------------------------------------
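
Note: a runnable sketch of the lookup table described in the setup.py comment above — each requirement string in _deps is keyed by its bare package name, so deps_list() and install_requires can reference packages without repeating version pins:

import re

_deps = ["accelerate>=1.2.1", "huggingface-hub[cli]>=0.19.2,<1.0", "math-verify"]
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
print(deps)
# {'accelerate': 'accelerate>=1.2.1', 'huggingface-hub': 'huggingface-hub[cli]>=0.19.2,<1.0', 'math-verify': 'math-verify'}
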
/src/virft/slurm/evaluate.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=open-r1-evaluate
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --exclusive
6 | #SBATCH --gres=gpu:8
7 | #SBATCH --partition=hopper-prod
8 | #SBATCH --time=01:59:00
9 | #SBATCH --output=./logs/evaluate/%x-%j.out
10 | #SBATCH --err=./logs/evaluate/%x-%j.err
11 |
12 | # Usage: sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B aime24
13 |
14 | set -x -e
15 |
16 | source ~/.bashrc
17 | conda activate openr1
18 | module load cuda/12.1
19 | echo "START TIME: $(date)"
20 | echo "PYTHON ENV: $(which python)"
21 |
22 |
23 | NUM_GPUS=8
24 | MODEL=$1
25 | TASK=$2
26 | MODEL_ARGS="pretrained=$MODEL,dtype=float16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
27 | OUTPUT_DIR=data/evals/$MODEL
28 |
29 |
30 | # force crashing on nccl issues like hanging broadcast
31 | export NCCL_ASYNC_ERROR_HANDLING=1
32 | # export NCCL_DEBUG=INFO
33 | # export NCCL_DEBUG_SUBSYS=COLL
34 | # export NCCL_SOCKET_NTHREADS=1
35 | # export NCCL_NSOCKS_PERTHREAD=1
36 | # export CUDA_LAUNCH_BLOCKING=1
37 |
38 | # Specific configuration optimized for the Hugging Face Compute Cluster
39 | # Be ye warned this may not work on other clusters!
40 | module load cuda/12.1
41 |
42 | lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
43 | --custom-tasks src/open_r1/evaluate.py \
44 | --use-chat-template \
45 | --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
46 | --output-dir $OUTPUT_DIR
47 |
48 |
49 | echo "END TIME: $(date)"
50 |
--------------------------------------------------------------------------------
/src/virft/slurm/sft.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=open-r1-sft
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --exclusive
6 | #SBATCH --gres=gpu:8
7 | #SBATCH --partition=hopper-prod
8 | #SBATCH --output=./logs/%x-%j.out
9 | #SBATCH --err=./logs/%x-%j.err
10 |
11 | set -x -e
12 |
13 | source ~/.bashrc
14 | conda activate openr1
15 | module load cuda/12.1
16 | echo "START TIME: $(date)"
17 | echo "PYTHON ENV: $(which python)"
18 |
19 | MODEL_PATH=$1
20 | DATASET_PATH=$2
21 | ACCELERATOR=$3
22 |
23 | # Training setup
24 | NUM_NODES=$SLURM_NNODES
25 | GPUS_PER_NODE=8
26 | WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
27 |
28 | # so processes know who to talk to
29 | MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
30 | MASTER_PORT=6000
31 |
32 | export CMD=" \
33 | src/open_r1/sft.py \
34 | --model_name_or_path $MODEL_PATH \
35 | --dataset_name $DATASET_PATH \
36 | --use_liger_kernel true \
37 | --learning_rate 2.0e-5 \
38 | --num_train_epochs 1 \
39 | --packing \
40 | --max_seq_length 4096 \
41 | --per_device_train_batch_size 4 \
42 | --per_device_eval_batch_size 4 \
43 | --gradient_accumulation_steps 4 \
44 | --gradient_checkpointing \
45 | --bf16 \
46 | --logging_steps 5 \
47 | --eval_strategy steps \
48 | --eval_steps 100 \
49 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
50 | "
51 |
52 | export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
53 | --config_file configs/$ACCELERATOR.yaml \
54 | --gradient_accumulation_steps 4 \
55 | --num_machines $NUM_NODES \
56 | --num_processes $WORLD_SIZE \
57 | --main_process_ip $MASTER_ADDR \
58 | --main_process_port $MASTER_PORT \
59 | --machine_rank \$SLURM_PROCID \
60 | --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
61 | --max_restarts 1 \
62 | --role \$(hostname -s): \
63 | --tee 3 \
64 | "
65 |
66 | # force crashing on nccl issues like hanging broadcast
67 | export NCCL_ASYNC_ERROR_HANDLING=1
68 | # export NCCL_DEBUG=INFO
69 | # export NCCL_DEBUG_SUBSYS=COLL
70 | # export NCCL_SOCKET_NTHREADS=1
71 | # export NCCL_NSOCKS_PERTHREAD=1
72 | # export CUDA_LAUNCH_BLOCKING=1
73 |
74 | # Specific configuration optimized for the Hugging Face Compute Cluster
75 | # Be ye warned this may not work on other clusters!
76 | module load cuda/12.1
77 |
78 | # srun error handling:
79 | # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
80 | # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
81 | SRUN_ARGS=" \
82 | --wait=60 \
83 | --kill-on-bad-exit=1 \
84 | "
85 |
86 | clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
87 |
88 | echo "END TIME: $(date)"
89 |
--------------------------------------------------------------------------------
/src/virft/src/open_r1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liuziyu77/Visual-RFT/94866331ff38c516581697882a00f947c8bfda8a/src/virft/src/open_r1/__init__.py
--------------------------------------------------------------------------------
/src/virft/src/open_r1/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom evaluation tasks for LightEval."""
16 |
17 | from lighteval.metrics.dynamic_metrics import (
18 | ExprExtractionConfig,
19 | LatexExtractionConfig,
20 | multilingual_extractive_match_metric,
21 | )
22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig
23 | from lighteval.tasks.requests import Doc
24 | from lighteval.utils.language import Language
25 |
26 |
27 | metric = multilingual_extractive_match_metric(
28 | language=Language.ENGLISH,
29 | fallback_mode="first_match",
30 | precision=5,
31 | gold_extraction_target=(LatexExtractionConfig(),),
32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
33 | aggregation_function=max,
34 | )
35 |
36 |
37 | def prompt_fn(line, task_name: str = None):
38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
39 | return Doc(
40 | task_name=task_name,
41 | query=line["problem"],
42 | choices=[line["solution"]],
43 | gold_index=0,
44 | )
45 |
46 |
47 | # Define tasks
48 | aime24 = LightevalTaskConfig(
49 | name="aime24",
50 | suite=["custom"],
51 | prompt_function=prompt_fn,
52 | hf_repo="HuggingFaceH4/aime_2024",
53 | hf_subset="default",
54 | hf_avail_splits=["train"],
55 | evaluation_splits=["train"],
56 | few_shots_split=None,
57 | few_shots_select=None,
58 | generation_size=32768,
59 | metric=[metric],
60 | version=1,
61 | )
62 | math_500 = LightevalTaskConfig(
63 | name="math_500",
64 | suite=["custom"],
65 | prompt_function=prompt_fn,
66 | hf_repo="HuggingFaceH4/MATH-500",
67 | hf_subset="default",
68 | hf_avail_splits=["test"],
69 | evaluation_splits=["test"],
70 | few_shots_split=None,
71 | few_shots_select=None,
72 | generation_size=32768,
73 | metric=[metric],
74 | version=1,
75 | )
76 |
77 | # Add tasks to the table
78 | TASKS_TABLE = []
79 | TASKS_TABLE.append(aime24)
80 | TASKS_TABLE.append(math_500)
81 |
82 | # MODULE LOGIC
83 | if __name__ == "__main__":
84 | print([t.name for t in TASKS_TABLE])
85 | print(len(TASKS_TABLE))
86 |
--------------------------------------------------------------------------------
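
Note: a hedged usage sketch of prompt_fn above, assuming lighteval is installed and the open_r1 package is importable; the dataset row shown is hypothetical:

from open_r1.evaluate import prompt_fn

doc = prompt_fn({"problem": "What is 1 + 1?", "solution": "2"}, task_name="custom|aime24")
print(doc.query)       # What is 1 + 1?
print(doc.choices)     # ['2']
print(doc.gold_index)  # 0
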
/src/virft/src/open_r1/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | from distilabel.llms import OpenAILLM
18 | from distilabel.pipeline import Pipeline
19 | from distilabel.steps.tasks import TextGeneration
20 |
21 |
22 | def build_distilabel_pipeline(
23 | model: str,
24 | base_url: str = "http://localhost:8000/v1",
25 | prompt_column: Optional[str] = None,
26 | temperature: Optional[float] = None,
27 | top_p: Optional[float] = None,
28 | max_new_tokens: int = 8192,
29 | num_generations: int = 1,
30 | ) -> Pipeline:
31 | generation_kwargs = {"max_new_tokens": max_new_tokens}
32 |
33 | if temperature is not None:
34 | generation_kwargs["temperature"] = temperature
35 |
36 | if top_p is not None:
37 | generation_kwargs["top_p"] = top_p
38 |
39 | with Pipeline().ray() as pipeline:
40 | TextGeneration(
41 | llm=OpenAILLM(
42 | base_url=base_url,
43 | api_key="something",
44 | model=model,
45 | # thinking can take some time...
46 | timeout=10 * 60,
47 | generation_kwargs=generation_kwargs,
48 | ),
49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
51 | num_generations=num_generations,
52 | )
53 |
54 | return pipeline
55 |
56 |
57 | if __name__ == "__main__":
58 | import argparse
59 |
60 | from datasets import load_dataset
61 |
62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
63 | parser.add_argument(
64 | "--hf-dataset",
65 | type=str,
66 | required=True,
67 | help="HuggingFace dataset to load",
68 | )
69 | parser.add_argument(
70 | "--hf-dataset-config",
71 | type=str,
72 | required=False,
73 | help="Dataset config to use",
74 | )
75 | parser.add_argument(
76 | "--hf-dataset-split",
77 | type=str,
78 | default="train",
79 | help="Dataset split to use",
80 | )
81 | parser.add_argument("--prompt-column", type=str, default="prompt")
82 | parser.add_argument(
83 | "--model",
84 | type=str,
85 | required=True,
86 | help="Model name to use for generation",
87 | )
88 | parser.add_argument(
89 | "--vllm-server-url",
90 | type=str,
91 | default="http://localhost:8000/v1",
92 | help="URL of the vLLM server",
93 | )
94 | parser.add_argument(
95 | "--temperature",
96 | type=float,
97 | help="Temperature for generation",
98 | )
99 | parser.add_argument(
100 | "--top-p",
101 | type=float,
102 | help="Top-p value for generation",
103 | )
104 | parser.add_argument(
105 | "--max-new-tokens",
106 | type=int,
107 | default=8192,
108 | help="Maximum number of new tokens to generate",
109 | )
110 | parser.add_argument(
111 | "--num-generations",
112 | type=int,
113 | default=1,
114 | help="Number of generations per problem",
115 | )
116 | parser.add_argument(
117 | "--hf-output-dataset",
118 | type=str,
119 | required=False,
120 | help="HuggingFace repo to push results to",
121 | )
122 | parser.add_argument(
123 | "--private",
124 | action="store_true",
125 | help="Whether to make the output dataset private when pushing to HF Hub",
126 | )
127 |
128 | args = parser.parse_args()
129 |
130 | print("\nRunning with arguments:")
131 | for arg, value in vars(args).items():
132 | print(f" {arg}: {value}")
133 | print()
134 |
135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
136 | dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
137 | print("Dataset loaded!")
138 |
139 | pipeline = build_distilabel_pipeline(
140 | model=args.model,
141 | base_url=args.vllm_server_url,
142 | prompt_column=args.prompt_column,
143 | temperature=args.temperature,
144 | top_p=args.top_p,
145 | max_new_tokens=args.max_new_tokens,
146 | num_generations=args.num_generations,
147 | )
148 |
149 | print("Running generation pipeline...")
150 | distiset = pipeline.run(dataset=dataset, use_cache=False)
151 | print("Generation pipeline finished!")
152 |
153 | if args.hf_output_dataset:
154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private)
156 | print("Dataset pushed!")
157 |
--------------------------------------------------------------------------------
/src/virft/src/open_r1/sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Supervised fine-tuning script for decoder language models.
17 |
18 | Usage:
19 |
20 | # On 1 node of 8 x H100s
21 | accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
22 | --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
23 | --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
24 | --learning_rate 2.0e-5 \
25 | --num_train_epochs 1 \
26 | --packing \
27 | --max_seq_length 4096 \
28 | --per_device_train_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --gradient_checkpointing \
31 | --bf16 \
32 | --logging_steps 5 \
33 | --eval_strategy steps \
34 | --eval_steps 100 \
35 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
36 | """
37 |
38 | from datasets import load_dataset
39 | from transformers import AutoTokenizer
40 |
41 | from trl import (
42 | ModelConfig,
43 | ScriptArguments,
44 | SFTConfig,
45 | SFTTrainer,
46 | TrlParser,
47 | get_kbit_device_map,
48 | get_peft_config,
49 | get_quantization_config,
50 | )
51 |
52 |
53 | def main(script_args, training_args, model_args):
54 | ################
55 | # Model init kwargs & Tokenizer
56 | ################
57 | quantization_config = get_quantization_config(model_args)
58 | model_kwargs = dict(
59 | revision=model_args.model_revision,
60 | trust_remote_code=model_args.trust_remote_code,
61 | attn_implementation=model_args.attn_implementation,
62 | torch_dtype=model_args.torch_dtype,
63 | use_cache=False if training_args.gradient_checkpointing else True,
64 | device_map=get_kbit_device_map() if quantization_config is not None else None,
65 | quantization_config=quantization_config,
66 | )
67 | training_args.model_init_kwargs = model_kwargs
68 | tokenizer = AutoTokenizer.from_pretrained(
69 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
70 | )
71 | tokenizer.pad_token = tokenizer.eos_token
72 |
73 | ################
74 | # Dataset
75 | ################
76 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
77 |
78 | ################
79 | # Training
80 | ################
81 | trainer = SFTTrainer(
82 | model=model_args.model_name_or_path,
83 | args=training_args,
84 | train_dataset=dataset[script_args.dataset_train_split],
85 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
86 | processing_class=tokenizer,
87 | peft_config=get_peft_config(model_args),
88 | )
89 |
90 | trainer.train()
91 |
92 | # Save and push to hub
93 | trainer.save_model(training_args.output_dir)
94 | if training_args.push_to_hub:
95 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
96 |
97 |
98 | if __name__ == "__main__":
99 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
100 | script_args, training_args, model_args = parser.parse_args_and_config()
101 | main(script_args, training_args, model_args)
102 |
--------------------------------------------------------------------------------
/src/virft/src/open_r1/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .grpo_trainer import Qwen2VLGRPOTrainer
2 | from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
3 | from .grpo_trainer_mp import Qwen2VLGRPOTrainer_MP
4 | from .grpo_trainer_aid import Qwen2VLGRPOTrainer_AID
5 |
6 | __all__ = ["Qwen2VLGRPOTrainer", "Qwen2VLGRPOVLLMTrainer", "Qwen2VLGRPOTrainer_MP", "Qwen2VLGRPOTrainer_AID"]
7 |
--------------------------------------------------------------------------------